Splay Trees in Julia Lang

Let us implement splay trees in Julia, a relatively new and popular dynamically typed language with multiple dispatch.

We model a node in the splay tree as a composite type. Each node contains a single integer data value and two references, one to the left subtree and one to the right subtree. A subtree can be empty, as in the case of leaf nodes, and we model this with the Nil type. This is similar to the Maybe type in Haskell: it lets us state that a subtree is missing. We construct a node, root, as the root of the splay tree and use it as a global variable. We assume that only integer data is stored and that all elements are distinct. Julia passes arguments by sharing, so rebinding an argument inside a function does not change the caller's variable. Since the operations below replace the root, we keep it in a global variable that the functions update.

type Nil end
typealias MayBe{T} Union{T,Nil}

type splay
    data::Int
    left::MayBe{splay}
    right::MayBe{splay}
end

splay(data::Int) = splay(data, Nil(), Nil())
root = splay(0)                # root is the root of the splay tree
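
For readers on a current Julia release, the same node model might be sketched as follows. This is an illustration under assumptions rather than the article's tested code: it uses Julia 1.x syntax (`mutable struct`, with `nothing` in place of `Nil`), and the names `Node` and `isempty_subtree` are hypothetical.

```julia
# Sketch in Julia 1.x syntax (assumption: the article's code targets 0.5.2).
# `Node` and `isempty_subtree` are hypothetical names, not from the article.
mutable struct Node
    data::Int
    left::Union{Node,Nothing}     # `nothing` marks a missing subtree
    right::Union{Node,Nothing}
end

# Convenience constructor for a leaf.
Node(data::Int) = Node(data, nothing, nothing)

# True when a subtree is missing.
isempty_subtree(t) = t === nothing

root = Node(0)
root.right = Node(1)              # fields are mutable, so the tree can be rewired
```

Here `Union{Node,Nothing}` plays the role of `MayBe{splay}`, and the test `t === nothing` replaces `typeof(t) == Nil`.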


The splay tree supports find, insert, and delete. The amortized cost of each of these operations is $O(\log{n})$ over a sufficiently long sequence of operations.
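
For reference, the standard statement behind this bound comes from Sleator and Tarjan's analysis and is not derived in this article. Assign each node $x$ the rank $r(x) = \log_2 s(x)$, where $s(x)$ is the number of nodes in the subtree rooted at $x$, and let the potential be the sum of all ranks. The amortized cost $\hat{c}$ of splaying $x$ in a tree with root $t$ and $n$ nodes then satisfies

$$\hat{c} \;\le\; 3\,\bigl(r(t) - r(x)\bigr) + 1 \;\le\; 3\log_2 n + 1 \;=\; O(\log n).$$

The second inequality uses $r(t) = \log_2 n$ and $r(x) \ge 0$.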

Find

Recall that the ordering property of a binary search tree states that at every node, the data stored in the left subtree is smaller than the data stored at the node, which in turn is smaller than any data stored in the right subtree. So, to search for a key, we compare the key with the value stored at the node and branch accordingly. The code below is recursive and is not the most efficient. If the key is in the tree, we splay the node containing the key; otherwise we splay the last node reached before the search failed.

function find(node::splay, data)
    global root
    if (data == root.data) || (typeof(node) == Nil)
        return
    end
    if (data == node.data)
        splay!(node)
        return
    end
    # descend toward data; if that side is empty, splay the last node reached
    if (data < node.data)
        typeof(node.left) != Nil ? find(node.left, data) : splay!(node)
    else
        typeof(node.right) != Nil ? find(node.right, data) : splay!(node)
    end
end
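
Since the recursive search is described as not the most efficient, here is a minimal sketch of an iterative descent that returns the node to splay. It is an illustration under assumptions: Julia 1.x syntax, a hypothetical `Node` type with `nothing` for missing subtrees, and a hypothetical function name `search`. It performs only the search, not the splay.

```julia
# Hypothetical Julia 1.x node type; `nothing` stands in for Nil.
mutable struct Node
    data::Int
    left::Union{Node,Nothing}
    right::Union{Node,Nothing}
end
Node(data::Int) = Node(data, nothing, nothing)

# Walk down from `root` without recursion.
# Returns (node, found): the node holding `key`, or the last node
# visited before the search ran into a missing subtree.
function search(root::Node, key::Int)
    node = root
    while true
        key == node.data && return (node, true)
        child = key < node.data ? node.left : node.right
        child === nothing && return (node, false)
        node = child
    end
end

# Tree with root 3 and children 1 and 5.
t = Node(3, Node(1), Node(5))
hit, found = search(t, 5)      # key present
stop, found2 = search(t, 4)    # key absent: search stops at node 5
```

In both outcomes the returned node is the one a splay tree would bring to the root.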


Delete

To delete a key, we first find it in the splay tree. This reorganizes the tree by calling the splay function, so if the key is in the tree, it is now at the root. We then delete the root as in a BST. If the key is not at the root, we do nothing. Several cases are considered, depending on which subtrees of the root are empty. When the root's value is replaced by its in-order predecessor or successor, the subtree of that node needs to be reattached to the tree, so we find its parent and stitch the subtree back as a child.

function delete(key)
    global root
    global debugprint

    if (debugprint == true)
        println("delete ", key)
    end

    find(root, key)               # if key is present, it is now at the root

    if (root.data != key)         # key is not in the tree: nothing to delete
        return
    end

    # left subtree is Nil, right not Nil, left subtree of the right child is Nil
    if (typeof(root.left)==Nil && typeof(root.right)!=Nil && typeof(root.right.left)==Nil)
        root = root.right
        prn(root)
        return
    end

    # symmetric case to above
    if (typeof(root.right)==Nil && typeof(root.left)!=Nil && typeof(root.left.right)==Nil)
        root = root.left
        prn(root)
        return
    end

    if (typeof(root.left)!=Nil)              # left subtree is not Nil
        n = root.left
        while (typeof(n.right)!=Nil)         # get the max in the left subtree
            n = n.right
        end
        p = findParent(n.data)
        if (p === root)                      # n is the left child of the root
            root.left = n.left
        else
            p.right = n.left
        end
        root.data = n.data
        prn(root)
        return
    end

    if (typeof(root.right)!=Nil)             # left subtree is Nil, right is not Nil
        n = root.right
        while (typeof(n.left)!=Nil)          # get the min in the right subtree
            n = n.left
        end
        p = findParent(n.data)
        p.left = n.right                     # n is never root.right here; that case returned above
        root.data = n.data
        prn(root)
        return
    end

    root = Nil()                             # both subtrees are empty
end
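
The root-removal step can also be written without looking up a parent. The sketch below is an illustration under assumptions: Julia 1.x syntax, a hypothetical `Node` type, and hypothetical helpers `pop_max`, `delete_root`, and `inorder`. It assumes the key to delete is already at the root and replaces it with the largest key of the left subtree.

```julia
# Hypothetical Julia 1.x node type; `nothing` stands in for Nil.
mutable struct Node
    data::Int
    left::Union{Node,Nothing}
    right::Union{Node,Nothing}
end
Node(data::Int) = Node(data, nothing, nothing)

# Remove the largest key from a non-empty subtree.
# Returns that key and the subtree that remains (possibly `nothing`).
function pop_max(node::Node)
    if node.right === nothing
        return node.data, node.left
    end
    m, rest = pop_max(node.right)
    node.right = rest
    return m, node
end

# Delete the key stored at `root`; returns the new root (or `nothing`).
function delete_root(root::Node)
    root.left === nothing && return root.right
    m, rest = pop_max(root.left)
    root.left = rest
    root.data = m                 # the predecessor takes the root's place
    return root
end

# In-order listing of the keys, used to check the result.
inorder(t::Nothing) = Int[]
inorder(t::Node) = vcat(inorder(t.left), [t.data], inorder(t.right))

t = delete_root(Node(4, Node(2, Node(1), Node(3)), Node(6)))
u = delete_root(Node(4, Node(2), Node(6)))   # predecessor is the root's left child
```

Because `pop_max` hands back the remaining subtree, the case where the predecessor is the root's own left child needs no special handling.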


Insert

We use binary search to find the position for the new key. A new node is created there and splayed all the way to the root using splay!.

function insert(node::splay, data)
    if (data < node.data)
        if (typeof(node.left) == Nil)
            node.left = splay(data, Nil(), Nil())
            splay!(node.left)                # splay the node
            return
        else
            insert(node.left, data)
        end
    end
    if (data > node.data)
        if (typeof(node.right) == Nil)
            node.right = splay(data, Nil(), Nil())
            splay!(node.right)               # splay the node
            return
        else
            insert(node.right, data)
        end
    end
end
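
The placement step can be done with a loop instead of recursion. The following is a minimal sketch under assumptions: Julia 1.x syntax, a hypothetical `Node` type, and a hypothetical function `bst_insert!`. It only attaches the new node and returns it. Splaying that node to the root is left out.

```julia
# Hypothetical Julia 1.x node type; `nothing` stands in for Nil.
mutable struct Node
    data::Int
    left::Union{Node,Nothing}
    right::Union{Node,Nothing}
end
Node(data::Int) = Node(data, nothing, nothing)

# Attach `key` below `root` by iterative binary search.
# Returns the node holding `key` (the existing one if the key is already there).
function bst_insert!(root::Node, key::Int)
    node = root
    while true
        if key == node.data
            return node                    # keys are distinct: nothing to add
        elseif key < node.data
            if node.left === nothing
                node.left = Node(key)
                return node.left
            end
            node = node.left
        else
            if node.right === nothing
                node.right = Node(key)
                return node.right
            end
            node = node.right
        end
    end
end

t = Node(5)
foreach(k -> bst_insert!(t, k), (2, 8, 1, 9))
n = bst_insert!(t, 3)                      # lands as the right child of 2
```

The returned node is the one that would then be passed to the splay step.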


Splay

The basic operation that reorganizes the tree is called splay. Splaying a node x restructures the subtree rooted at the grandparent of x if one exists; otherwise it restructures the subtree rooted at the parent of x. Three cases are considered: a simple rotation when x has no grandparent, zig-zig when x and its parent are both left children or both right children, and zig-zag when one of them is a left child and the other a right child. For details of the three rotations, see the Wikipedia entry on splay trees. See lecture 12 in the book by Kozen for details of the analysis.

This is a recursive implementation. The node x is splayed until it is the root of the tree.

function splay!(x)  # splay node x until it is the root of the tree
    global root
    global debugprint

    if (debugprint == true)
        println()
        println()
        println("Splay ", x.data)
        println()
        prn(root)
    end

    p = findParent(x.data)
    g = findParent(p.data)

    if (x.data == root.data)     # x is already at the root
        return
    end

    if (p.data == g.data) && (p.data != x.data)     # no grandparent: simple rotation
        if (x.data < p.data)
            p.left = x.right
            x.right = p
        end
        if (x.data > p.data)
            p.right = x.left
            x.left = p
        end
        root = x
        splay!(x)        # recursively splay x to the top
        return
    end

    # x has a parent and a grandparent
    # r is the node where the subtree will hang
    r = findParent(g.data)

    if (x.data < p.data < g.data)                        # Zig Zig (left, left)
        g.left = p.right
        p.left = x.right
        p.right = g
        x.right = p
        if (g.data == root.data)
            root = x
        else
            (r.data > x.data ? r.left = x : r.right = x)
        end
        splay!(x)        # recursively splay x to the top
        return
    end

    if (x.data > p.data > g.data)                        # Zig Zig (right, right)
        g.right = p.left
        p.right = x.left
        p.left = g
        x.left = p
        if (g.data == root.data)
            root = x
        else
            (r.data > x.data ? r.left = x : r.right = x)
        end
        splay!(x)        # recursively splay x to the top
        return
    end

    if (p.data < x.data < g.data)                        # Zig Zag (p left of g, x right of p)
        p.right = x.left
        g.left = x.right
        x.left = p
        x.right = g
        if (g.data == root.data)
            root = x
        else
            (r.data > x.data ? r.left = x : r.right = x)
        end
        splay!(x)        # recursively splay x to the top
        return
    end

    if (g.data < x.data < p.data)                        # Zig Zag (p right of g, x left of p)
        p.left = x.right
        g.right = x.left
        x.left = g
        x.right = p
        if (g.data == root.data)
            root = x
        else
            (r.data > x.data ? r.left = x : r.right = x)
        end
        splay!(x)        # recursively splay x to the top
        return
    end
end
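
For comparison, splaying can be written without a global root or parent lookups by having each call return the new root of its subtree. The sketch below uses the well-known recursive formulation built from two rotations. It is not the article's implementation, and it rests on assumptions: Julia 1.x syntax, a hypothetical `Node` type, and the hypothetical names `rotate_right`, `rotate_left`, `splay_to_root`, and `inorder`.

```julia
# Hypothetical Julia 1.x node type; `nothing` stands in for Nil.
mutable struct Node
    data::Int
    left::Union{Node,Nothing}
    right::Union{Node,Nothing}
end
Node(data::Int) = Node(data, nothing, nothing)

# Rotate the left child of x above x; returns the new subtree root.
function rotate_right(x::Node)
    y = x.left
    x.left = y.right
    y.right = x
    return y
end

# Rotate the right child of x above x; returns the new subtree root.
function rotate_left(x::Node)
    y = x.right
    x.right = y.left
    y.left = x
    return y
end

splay_to_root(t::Nothing, key::Int) = nothing

# Bring `key` (or the last node on its search path) to the root of t.
function splay_to_root(t::Node, key::Int)
    if key < t.data
        t.left === nothing && return t
        if key < t.left.data                          # zig-zig
            t.left.left = splay_to_root(t.left.left, key)
            t = rotate_right(t)
        elseif key > t.left.data                      # zig-zag
            t.left.right = splay_to_root(t.left.right, key)
            if t.left.right !== nothing
                t.left = rotate_left(t.left)
            end
        end
        return t.left === nothing ? t : rotate_right(t)
    elseif key > t.data
        t.right === nothing && return t
        if key > t.right.data                         # zig-zig
            t.right.right = splay_to_root(t.right.right, key)
            t = rotate_left(t)
        elseif key < t.right.data                     # zig-zag
            t.right.left = splay_to_root(t.right.left, key)
            if t.right.left !== nothing
                t.right = rotate_right(t.right)
            end
        end
        return t.right === nothing ? t : rotate_left(t)
    end
    return t                                          # key is already at the root
end

# In-order listing of the keys, used to check the result.
inorder(t::Nothing) = Int[]
inorder(t::Node) = vcat(inorder(t.left), [t.data], inorder(t.right))

# Left chain 5 -> 4 -> ... -> 0, then splay the deepest key.
chain = foldl((t, k) -> Node(k, t, nothing), 1:5; init = Node(0))
top = splay_to_root(chain, 0)
```

The caller keeps the returned node as the new root, which removes the need for a global variable.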


Helper Functions

# Print the tree using in order traversal.
# The root is indented 55 columns and each level 5 columns less,
# so a node more than 11 levels below the root will generate
# a run time error.

prn(node::splay) = ppst(node, 60)

function ppst(node::splay, right::Int)
    if (typeof(node.left) != Nil)
        ppst(node.left, right - 5)
    end
    println(repeat(" ", right - 5) * "-----" * string(node.data))
    if (typeof(node.right) != Nil)
        ppst(node.right, right - 5)
    end
end

# Helper function to find the parent of the node that
# contains the key. We use it to stitch the subtree as a child of
# the parent after a rotation is performed.
# For the root, the root itself is returned.

function findParent(key)
    node, par = root, root
    while (node.data != key)
        par = node
        if (key < node.data)
            node = node.left
        else
            node = node.right
        end
    end
    return par
end
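
Each call to findParent walks down from the root again. One way to avoid the repeated walks, sketched here under assumptions, is to record the search path once and read the parent and grandparent from its end. The sketch uses Julia 1.x syntax, a hypothetical `Node` type, and a hypothetical function `path_to`.

```julia
# Hypothetical Julia 1.x node type; `nothing` stands in for Nil.
mutable struct Node
    data::Int
    left::Union{Node,Nothing}
    right::Union{Node,Nothing}
end
Node(data::Int) = Node(data, nothing, nothing)

# Nodes visited from the root down to the node holding `key`.
# Returns an empty vector when the key is not in the tree.
function path_to(root::Node, key::Int)
    path = Node[]
    node = root
    while node !== nothing
        push!(path, node)
        node.data == key && return path
        node = key < node.data ? node.left : node.right
    end
    return Node[]
end

t = Node(4, Node(2, Node(1), Node(3)), Node(6))
path = path_to(t, 3)
keys_on_path = [n.data for n in path]
p = path[end-1]     # parent of the node holding 3
g = path[end-2]     # grandparent of the node holding 3
```

The path has one entry per level, so the parent, grandparent, and great-grandparent are all available from a single descent.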


Test Cases

Let us start with a root that contains 0. The result of inserting the elements 1, 2, ..., 5 is shown below. If you rotate the screen counterclockwise by 90 degrees, the figure below would be the same as the typical drawing of a binary tree, with the root on top, the left subtree on the left side, and the right subtree to the right of the root.

debugprint = true
for i in 1:5
    insert(root, i)
end

Splay 1

-----0
-----1

Splay 1

-----0
-----1

Splay 2

-----0
-----1
-----2

Splay 2

-----0
-----1
-----2

Splay 3

-----0
-----1
-----2
-----3

Splay 3

-----0
-----1
-----2
-----3

Splay 4

-----0
-----1
-----2
-----3
-----4

Splay 4

-----0
-----1
-----2
-----3
-----4

Splay 5

-----0
-----1
-----2
-----3
-----4
-----5

Splay 5

-----0
-----1
-----2
-----3
-----4
-----5


find(root, 0) repeatedly splays the node with 0 until it is the new root, as shown below.

find(root,0)

Splay 0

-----0
-----1
-----2
-----3
-----4
-----5

Splay 0

-----0
-----1
-----2
-----3
-----4
-----5

Splay 0

-----0
-----1
-----2
-----3
-----4
-----5

Splay 0

-----0
-----1
-----2
-----3
-----4
-----5


The splay trees obtained by removing the elements 0, 1, 2, 3, 4, in that order, are shown below.

for i in 0:4
    delete(i)
end

delete 0
-----1
-----2
-----3
-----4
-----5
delete 1
-----2
-----3
-----4
-----5
delete 2
-----3
-----4
-----5
delete 3
-----4
-----5
delete 4
-----5

Next, we build a splay tree from scratch by inserting 100,000 distinct random elements, and we time each insertion. In the figure below, the time taken to insert each element is shown in blue. We then delete the elements one by one (all but the last) in the same random order as used in the construction, and plot the cost of each delete operation, shown in yellow. If you see a spike at the start, it is due to JIT compilation.

size = 100000
y = zeros(size)

debugprint = false
inp = randperm(size)

using PyPlot
root = splay(0)
ctr = 1
for j in inp
    y[ctr] = @elapsed insert(root, j)
    ctr = ctr + 1
end

p = plot(y)

prn(n::splay) = return        # redefine prn() to suppress printing

ctr = size
for j in inp[1:size-1]
    y[ctr] = @elapsed delete(j)
    ctr = ctr - 1
end

p = plot(y[2:size])
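
To keep the compilation spike out of such measurements, one option is to call the function once before timing it. The sketch below illustrates the idea with `@elapsed` on a hypothetical stand-in workload `work`, not the tree operations above. It assumes Julia 1.x syntax.

```julia
# Hypothetical workload standing in for insert or delete.
work(n) = sum(i -> i * i, 1:n)

# The first call includes JIT compilation of `work`.
first_call = @elapsed work(1000)

# Later calls measure only the run time of the compiled code.
later = [@elapsed(work(1000)) for _ in 1:5]
```

Comparing `first_call` with the entries of `later` shows how much of the initial spike comes from compilation on a given machine.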


Finally, we benchmark an insert followed by a delete of the same key.

using BenchmarkTools
root = splay(0)
for j in inp
    insert(root, j)
end
t = @benchmark (insert(root, 100001), delete(100001))
median(t)

BenchmarkTools.TrialEstimate:
time:             16.756 μs
gctime:           0.000 ns (0.00%)
memory:           592 bytes
allocs:           36


The code is available here. It has been tested on Julia 0.5.2. This code is purely illustrative; you might want to remove the global root and the recursion from the code above.