Splay Trees in Julia Lang

Let us implement splay trees in Julia, a relatively new and popular dynamically typed language with multiple dispatch.

We model a node in the splay tree as an abstract data type. Each node contains a single integer data value and references to its left and right subtrees. A subtree can be empty, as in the case of leaf nodes, and we model this with the Nil type. This is similar to the Maybe type in Haskell: it lets us specify that a subtree is missing. We construct a node, root, which is the root of the splay tree, and use it as a global variable. We assume that only integers are stored and that all elements are distinct. Julia passes arguments by sharing: a function can mutate the fields of a node it receives, but rebinding the argument to a different node is not visible to the caller. Since splaying changes which node is the root, we keep root in a global variable and reassign it inside the functions.

Node ADT

type Nil end
typealias MayBe{T} Union{T,Nil}

type splay 
    data::Int
    left::MayBe{splay}
    right::MayBe{splay}
end

splay( data :: Int) = splay(data, Nil(), Nil() )
root = splay(0)                # root is the root of the splay tree
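
For readers on Julia 1.x, where `type` has become `mutable struct` and `typealias` is no longer available, the same node could be sketched roughly as follows. This is an assumption-laden translation rather than the code tested in this post: the names `Node` and `isleaf` are illustrative, and `nothing` stands in for `Nil()`.

```julia
# A sketch for Julia 1.x; `Node` and `isleaf` are hypothetical names.
mutable struct Node
    data::Int
    left::Union{Node,Nothing}     # `nothing` marks a missing subtree
    right::Union{Node,Nothing}
end

# Convenience constructor for a node without children.
Node(data::Int) = Node(data, nothing, nothing)

# A leaf has neither a left nor a right subtree.
isleaf(n::Node) = n.left === nothing && n.right === nothing

t = Node(2, Node(1), Node(3))
println(isleaf(t))         # false
println(isleaf(t.left))    # true
```

Using `Union{Node,Nothing}` plays the same role as `MayBe{splay}` above: a field either holds a node or explicitly records that the subtree is absent.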

The splay tree supports find, insert, and delete. The amortized cost of each of these operations is $O(\log{n})$ over a sufficiently long sequence of operations.
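
As a reminder of where the logarithmic bound comes from, the usual potential-function argument can be summarized as below. The symbols $s(x)$, $r(x)$ and $\Phi$ are notation introduced here for illustration and do not appear in the code; this is a summary of the standard statement, not a proof.

```latex
$$ s(x) = \text{number of nodes in the subtree rooted at } x, \qquad
   r(x) = \log_2 s(x), \qquad
   \Phi(T) = \sum_{x \in T} r(x) $$

$$ \text{amortized number of rotations to splay } x \text{ to the root } t
   \;\le\; 3\,\bigl(r(t) - r(x)\bigr) + 1 \;\le\; 3\log_2 n + 1 = O(\log n) $$
```

The last step uses $r(t) = \log_2 n$ for the root of an $n$-node tree and $r(x) \ge 0$.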

Find

Recall that the ordering property of a binary search tree states that, at every node, the data stored in the left subtree is smaller than the data stored at the node, which in turn is smaller than any data stored in the right subtree. So, to search for a key, we compare the key with the value stored at the node and branch accordingly. The code below is recursive and is not the most efficient. If the key is in the tree, we splay the node containing the key; otherwise we splay the last node visited, that is, the parent of the empty position where the search failed.

function find(node::splay, data)
    global root
    if (data == root.data)          # the key is already at the root
        return
    end
    if (data == node.data)          # key found: splay it to the root
        splay!(node)
        return
    end
    # descend; if the child is missing, the search fails and we splay the last node visited
    if (data < node.data)
        (typeof(node.left) != Nil ? find(node.left, data) : splay!(node))
    else
        (typeof(node.right) != Nil ? find(node.right, data) : splay!(node))
    end
end
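
The descent itself can also be written as a loop. The following is a minimal sketch for Julia 1.x, assuming a hypothetical `Node` type with `nothing` for missing subtrees; `search` is an illustrative name, and the sketch only locates the node that `find` would splay, without splaying it.

```julia
# Sketch for Julia 1.x; `Node` and `search` are hypothetical names.
mutable struct Node
    data::Int
    left::Union{Node,Nothing}
    right::Union{Node,Nothing}
end
Node(data::Int) = Node(data, nothing, nothing)

# Walk down from `root`; return the node holding `key`, or the last
# node visited when the key is absent.
function search(root::Node, key::Int)
    node = root
    while true
        if key < node.data
            next = node.left
        elseif key > node.data
            next = node.right
        else
            return node              # key found
        end
        next === nothing && return node   # fell off the tree
        node = next
    end
end

t = Node(4, Node(2, Node(1), Node(3)), Node(6))
println(search(t, 3).data)   # 3: the key is present
println(search(t, 5).data)   # 6: the search for 5 fails below node 6
```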

Delete

To delete a key, we first find it in the splay tree, which reorganizes the tree through the splay function. If the key is in the tree, it is now at the root, and we delete the root as in an ordinary BST. If the key is not at the root, it is not in the tree and we do nothing. Several cases are considered. In the general case, the data at the root is replaced by its in-order predecessor (or by its successor when the left subtree is empty). The subtree left behind by that replacement node needs to be reattached to the tree, so we find the node's parent and stitch the subtree back as a child.

function delete(key)
    global root
    global debugprint

    if (debugprint == true)
        println("delete ", key)
    end

    find(root, key)

    # key is at the root, left subtree is Nil, right not Nil, left subtree of the right child is Nil
    if (root.data == key && typeof(root.left)==Nil && typeof(root.right)!=Nil && typeof(root.right.left)==Nil) 
        root = root.right
        prn(root)
        return
    end
    
    # symmetric case to above
    if (root.data == key && typeof(root.right)==Nil && typeof(root.left)!=Nil && typeof(root.left.right)==Nil) 
        root = root.left
        prn(root)
        return
    end

    if (root.data == key) && (typeof(root.left)!=Nil)      # left subtree is not Nil
        n = root.left
        while ( typeof(n.right)!=Nil)            # get the max in the left subtree
            n = n.right
        end
        p = findParent(n.data)
        if (p.data == root.data)                 # n is the left child of the root
            p.left = n.left
        else                                     # n is the right child of p
            p.right = n.left
        end
        root.data = n.data                
        prn(root)
        return
    end

    if (root.data == key) && (typeof(root.right)!=Nil)      # left subtree is Nil, right is !Nil
        n = root.right
        while ( typeof(n.left)!=Nil)            # get the min in the right subtree
            n = n.left
        end
        p = findParent(n.data)
        p.left = n.right
        root.data = n.data                
        prn(root)
        return
    end
    (root.data == key ? root = Nil() : return)        # both subtrees are empty, so the tree becomes empty
end
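
The "delete the root as in BST" step can be isolated from the global state. The sketch below is written for Julia 1.x under the assumption of a hypothetical `Node` type with `nothing` for missing subtrees; `remove_root` and `inorder` are illustrative names. Instead of calling `findParent`, it tracks the parent while descending to the maximum of the left subtree, and it returns the new root rather than assigning a global.

```julia
# Sketch for Julia 1.x; `Node`, `remove_root` and `inorder` are hypothetical names.
mutable struct Node
    data::Int
    left::Union{Node,Nothing}
    right::Union{Node,Nothing}
end
Node(data::Int) = Node(data, nothing, nothing)

# Remove the key stored at `root` and return the root of the remaining
# tree (`nothing` if the tree becomes empty).
function remove_root(root::Node)
    root.left === nothing && return root.right    # also covers a lone root
    root.right === nothing && return root.left
    # Find the maximum of the left subtree together with its parent.
    parent, n = root, root.left
    while n.right !== nothing
        parent, n = n, n.right
    end
    # Unlink n: it is a left child only when it sits directly under the root.
    if parent === root
        parent.left = n.left
    else
        parent.right = n.left
    end
    root.data = n.data       # the predecessor's key replaces the deleted key
    return root
end

inorder(t::Nothing) = Int[]
inorder(t::Node) = vcat(inorder(t.left), t.data, inorder(t.right))

t = Node(4, Node(2, Node(1), Node(3)), Node(6))
t = remove_root(t)
println(inorder(t))    # [1, 2, 3, 6]
```

The `parent === root` branch is the corner case where the predecessor is the root's own left child; attaching its subtree on the wrong side there would overwrite the root's right subtree.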

Insert

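We descend from the root, as in a binary search, to find the position where the new key belongs. A new node is created there and splayed all the way to the top using splay!. A key that is already in the tree is ignored.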

function insert(node::splay, data ) 
    
    if (data < node.data )
        if (typeof(node.left)== Nil) 
            node.left = splay(data, Nil(), Nil() )
            splay!( node.left )                # splay the node
            return
        else
            insert(node.left, data)
        end
    end
    if (data > node.data )
        if (typeof(node.right)== Nil) 
            node.right = splay(data, Nil(), Nil())
            splay!( node.right )                # splay the node
            return
        else
            insert(node.right, data)
        end
    end
    #return node         
end
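
For comparison, the descent can be done without recursion. This is a minimal sketch for Julia 1.x, assuming a hypothetical `Node` type with `nothing` for missing subtrees; `bst_insert!` and `inorder` are illustrative names. It performs only the plain BST insertion and returns the node that would then be splayed.

```julia
# Sketch for Julia 1.x; `Node`, `bst_insert!` and `inorder` are hypothetical names.
mutable struct Node
    data::Int
    left::Union{Node,Nothing}
    right::Union{Node,Nothing}
end
Node(data::Int) = Node(data, nothing, nothing)

# Insert `key` below `root` and return the node that holds it.
# If the key is already present, the existing node is returned.
function bst_insert!(root::Node, key::Int)
    node = root
    while key != node.data
        if key < node.data
            if node.left === nothing
                node.left = Node(key)
                return node.left
            end
            node = node.left
        else
            if node.right === nothing
                node.right = Node(key)
                return node.right
            end
            node = node.right
        end
    end
    return node
end

inorder(t::Nothing) = Int[]
inorder(t::Node) = vcat(inorder(t.left), t.data, inorder(t.right))

t = Node(0)
for k in [3, 1, 2]
    bst_insert!(t, k)
end
println(inorder(t))    # [0, 1, 2, 3]
```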

Splay

The basic operation that reorganizes the data is called splay. Splaying a node x restructures the subtree rooted at the grandparent of x if one exists; otherwise it restructures the subtree rooted at the parent of x. Three cases are to be considered: zig (x has no grandparent), zig-zig (x and its parent are both left children or both right children), and zig-zag (one of them is a left child and the other a right child). For details of the three rotations, see the Wikipedia entry on splay trees. See lecture 12 in the book by Kozen for details of the analysis.

This is a recursive implementation. The node x is splayed until it is the root of the tree.

function splay!( x )  # splay the node x until it becomes the root of the tree
    global root

    global debugprint

    if (debugprint == true)
        println()
        println()
        println("Splay ", x.data)
        println()
        prn(root)
    end

    p = findParent(x.data)
    g = findParent(p.data)


    #println(" p, g ", p.data, " ", g.data)

    if (x.data == root.data)     # x is already at the root
        return 
    end

    if (p.data == g.data) && (p.data != x.data)     # no grandparent    # Simple rotation
        if (x.data < p.data)
            p.left = x.right
            x.right = p
        end
        if (x.data > p.data)
            p.right = x.left
            x.left = p
        end
        root = x
        splay!(x)        # recursively splay x to the top
        # println(root)
        return
    end

    # x has a parent and a grandparent
    # r is the node where the subtree will hang
    r = findParent(g.data)

    if (x.data < p.data < g.data)                        # Zig Zig
        g.left = p.right
        p.left = x.right
        p.right = g
        x.right = p
        if (g.data == root.data)
            root = x
            #println("zero ", root)
        else
            (r.data > x.data? r.left =x : r.right =x)
        end
        splay!(x)        # recursively splay x to the top
        return
    end

    if (x.data > p.data > g.data)                        # Zig Zig
        g.right = p.left
        p.right = x.left
        p.left = g
        x.left = p
        if (g.data == root.data)
            root = x
        else
            (r.data > x.data? r.left =x : r.right =x)
        end
        splay!(x)        # recursively splay x to the top
        return
    end

    if (p.data < x.data < g.data)                        # Zig Zag
        p.right = x.left
        g.left = x.right
        x.left = p
        x.right = g
        if (g.data == root.data)
            root = x
        else
            (r.data > x.data? r.left =x : r.right =x)
        end
        splay!(x)        # recursively splay x to the top
        return
    end

    if (g.data < x.data < p.data)                        # Zig Zag
        p.left = x.right
        g.right = x.left
        x.left = g
        x.right = p
        if (g.data == root.data)
            root = x
        else
            (r.data > x.data? r.left =x : r.right =x)
        end
        splay!(x)        # recursively splay x to the top
        return
    end
end
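
A different way to organize the same three cases, without the global root and without parent lookups, is a recursive function that returns the new subtree root. The sketch below is written for Julia 1.x and is not the implementation used in this post: `Node`, `rotate_right`, `rotate_left`, `splay_to_root` and `inorder` are hypothetical names. It pairs rotations starting from the top of the search path, so the resulting shape can differ from the bottom-up splay above.

```julia
# Sketch for Julia 1.x; all names here are hypothetical.
mutable struct Node
    data::Int
    left::Union{Node,Nothing}
    right::Union{Node,Nothing}
end
Node(data::Int) = Node(data, nothing, nothing)

# Rotate right: the left child of p moves up and p becomes its right child.
function rotate_right(p::Node)
    x = p.left
    p.left = x.right
    x.right = p
    return x
end

# Rotate left: mirror image of rotate_right.
function rotate_left(p::Node)
    x = p.right
    p.right = x.left
    x.left = p
    return x
end

# Bring `key` (or the last node on its search path) to the top of the
# subtree rooted at `t` and return the new subtree root.
splay_to_root(t::Nothing, key::Int) = nothing
function splay_to_root(t::Node, key::Int)
    key == t.data && return t
    if key < t.data
        t.left === nothing && return t
        if key < t.left.data                       # zig-zig: rotate at t first
            t.left.left = splay_to_root(t.left.left, key)
            t = rotate_right(t)
        elseif key > t.left.data                   # zig-zag: rotate at the child first
            t.left.right = splay_to_root(t.left.right, key)
            if t.left.right !== nothing
                t.left = rotate_left(t.left)
            end
        end
        # second rotation of the pair, or the single zig
        return t.left === nothing ? t : rotate_right(t)
    else
        t.right === nothing && return t
        if key > t.right.data                      # zig-zig
            t.right.right = splay_to_root(t.right.right, key)
            t = rotate_left(t)
        elseif key < t.right.data                  # zig-zag
            t.right.left = splay_to_root(t.right.left, key)
            if t.right.left !== nothing
                t.right = rotate_right(t.right)
            end
        end
        return t.right === nothing ? t : rotate_left(t)
    end
end

inorder(t::Nothing) = Int[]
inorder(t::Node) = vcat(inorder(t.left), t.data, inorder(t.right))

t = Node(3, Node(2, Node(1), nothing), Node(4, nothing, Node(5)))
t = splay_to_root(t, 1)
println(t.data)        # 1
println(inorder(t))    # [1, 2, 3, 4, 5]
```

Because the function returns the new root, the caller rebinds its own variable (`t = splay_to_root(t, 1)`), which sidesteps the need for a global.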

Helper Functions

# Print the tree using in-order traversal.
# The indentation shrinks by 5 columns per level, so trees with
# a node at depth 12 or more will generate a run time error.

prn(node::splay) = ppst(node, 60)

function ppst(node::splay,  right::Int)
    if (typeof(node.left)!=Nil)
        ppst(node.left,  right - 5)
    end
    println( repeat(" ", right-5) * "-----" * string(node.data))
    if (typeof(node.right)!=Nil)
        ppst(node.right,  right - 5)
    end
end


# helper function to find the parent of the node that
# contains the key. We use it to stitch the subtree as a child of
# the parent after a rotation is performed. If the key is at the
# root, the root itself is returned. The key must be in the tree.

function findParent(key)
    node, par = root, root
    while (node.data != key)
        par = node
        if (key < node.data)
            node = node.left
        else
            node = node.right
        end
    end
    return par
end

Test Cases

Let us start with a root that contains 0. The results of inserting the elements 1,2,..,5 are shown below. If you rotate the screen counterclockwise by 90 degrees, the figure below is the same as the typical drawing of a binary tree, with the root at the top, the left subtree on the left side, and the right subtree to the right of the root.

debugprint = true
for i in 1:5
    insert(root,i)
end
Splay 1

                                                       -----0
                                                  -----1


Splay 1

                                                  -----0
                                                       -----1


Splay 2

                                                  -----0
                                                       -----1
                                                  -----2


Splay 2

                                             -----0
                                                  -----1
                                                       -----2


Splay 3

                                             -----0
                                                  -----1
                                                       -----2
                                                  -----3


Splay 3

                                        -----0
                                             -----1
                                                  -----2
                                                       -----3


Splay 4

                                        -----0
                                             -----1
                                                  -----2
                                                       -----3
                                                  -----4


Splay 4

                                   -----0
                                        -----1
                                             -----2
                                                  -----3
                                                       -----4


Splay 5

                                   -----0
                                        -----1
                                             -----2
                                                  -----3
                                                       -----4
                                                  -----5


Splay 5

                              -----0
                                   -----1
                                        -----2
                                             -----3
                                                  -----4
                                                       -----5

find(root,0) repeatedly splays the node with 0 until it is the new root, as shown below.

find(root,0)
Splay 0

                              -----0
                                   -----1
                                        -----2
                                             -----3
                                                  -----4
                                                       -----5


Splay 0

                                        -----0
                                   -----1
                              -----2
                                             -----3
                                                  -----4
                                                       -----5


Splay 0

                                                  -----0
                                        -----1
                                   -----2
                                             -----3
                                        -----4
                                                       -----5


Splay 0

                                                       -----0
                                        -----1
                                   -----2
                                             -----3
                                        -----4
                                                  -----5

The splay trees obtained by removing the elements 0,1,2,3,4, in that order, are shown below.

for i in 0:4
    delete(i)
end
delete 0
                                                       -----1
                                        -----2
                                             -----3
                                        -----4
                                                  -----5
delete 1
                                                       -----2
                                             -----3
                                        -----4
                                                  -----5
delete 2
                                                       -----3
                                             -----4
                                                  -----5
delete 3
                                                       -----4
                                                  -----5
delete 4
                                                       -----5

Next, we build a splay tree from scratch by inserting the integers 1 to 100,000 in random order. We time each insertion; in the figure below, the time taken to insert each element is shown in blue. We then delete the elements one by one (all but the last), in the same random order used for the construction, and plot the cost of each delete operation, shown in yellow. If you see a spike at the start, it is due to JIT compilation.

size = 100000
y=zeros(size)

debugprint = false
inp = randperm(size)

using PyPlot
root = splay(0)
ctr = 1
for j in inp 
    y[ctr] = @elapsed    insert(root,j)
    ctr = ctr + 1
end


p = plot( y)

prn(n::splay) = return        # redefine prn() to suppress printing

ctr = size
for j in inp[1:size-1] 
    y[ctr] = @elapsed    delete(j)
    ctr = ctr - 1
end

p= plot( y[2:size])

Finally, we benchmark an insert followed by a delete of the same key.

using BenchmarkTools
root = splay(0)
for j in inp 
    insert(root,j)
end
t = @benchmark (insert(root, 100001), delete(100001))
median(t)
BenchmarkTools.TrialEstimate: 
  time:             16.756 μs
  gctime:           0.000 ns (0.00%)
  memory:           592 bytes
  allocs:           36

The code is available here. It has been tested on julia 0.5.2. This code is purely illustrative; you might want to remove the global root and the recursion from the code above.