Let us implement splay trees in Julia, a relatively new and popular dynamically typed language with multiple dispatch.
We model a node in the splay tree as an abstract data type. Each node contains a single integer data value and two references, to the left and the right subtrees. A subtree can be empty, as in the case of leaf nodes, and we model this with the Nil type. Together with the type alias MayBe, this is similar to the Maybe type in Haskell: it lets us specify that a subtree is missing. We construct a node, root, which is the root of the splay tree, and use it as a global variable. We assume that only integral data is stored and that all elements are distinct. Julia passes arguments by sharing: a function can modify the fields of a node it receives, but assigning a new node to its parameter does not change the caller's variable. Since splaying changes which node is the root, we keep the root in a global variable that the functions can reassign.
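To make the point about argument passing concrete, here is a small sketch in Julia 1.x syntax. The type Box and the functions rebind, mutate! and replace_root! are hypothetical names used only for illustration: mutating a field through an argument is visible to the caller, rebinding the argument is not, and only a global declaration lets a function replace the root binding itself.

```julia
# Hypothetical stand-in for a tree node (Julia 1.x syntax).
mutable struct Box
    data::Int
end

# Rebinding the parameter only changes the function's local name.
rebind(b::Box) = (b = Box(99); nothing)

# Mutating a field changes the object that the caller also sees.
mutate!(b::Box) = (b.data = 99; nothing)

root = Box(0)

# Declaring `root` global lets the function rebind the global variable.
function replace_root!()
    global root
    root = Box(42)
    return nothing
end

a = Box(1)
rebind(a)
println(a.data)      # prints 1: the caller's binding is untouched
mutate!(a)
println(a.data)      # prints 99: the shared object was modified
replace_root!()
println(root.data)   # prints 42: the global now refers to a new Box
```

An alternative that avoids the global is to have each operation return the new root and let the caller reassign it, e.g. `root = insert(root, key)` with a hypothetical insert that returns the root.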
Node ADT
type Nil end
typealias MayBe{T} Union{T,Nil}
type splay
    data::Int
    left::MayBe{splay}
    right::MayBe{splay}
end
splay( data :: Int) = splay(data, Nil(), Nil() )
root = splay(0) # the root of the splay tree
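The type and typealias keywords above are Julia 0.5 syntax; they were deprecated in Julia 0.6 and removed in Julia 1.0. A sketch of the same node ADT in Julia 1.x syntax, assuming the rest of the code keeps using Nil as the empty-subtree marker:

```julia
struct Nil end                       # singleton marking an empty subtree
const MayBe{T} = Union{T, Nil}       # replaces `typealias MayBe{T} Union{T,Nil}`

mutable struct splay                 # `mutable struct` replaces `type`
    data::Int
    left::MayBe{splay}
    right::MayBe{splay}
end

splay(data::Int) = splay(data, Nil(), Nil())   # leaf constructor

root = splay(0)                      # the root of the splay tree
```

Idiomatic Julia 1.x code more often uses the built-in `nothing` (of type `Nothing`) for a missing subtree, i.e. `Union{splay, Nothing}`; Nil is kept here only to match the code above.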
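The splay tree supports find, insert, and delete. The amortized cost of each of these operations is $O(\log n)$: a single operation can take time proportional to the height of the tree, which may be linear in $n$, but over a sufficiently long sequence of operations the average cost per operation is $O(\log n)$.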
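The $O(\log n)$ bound comes from the standard potential-function argument of Sleator and Tarjan, summarized here for reference since the post does not spell it out. Let $s(v)$ be the number of nodes in the subtree rooted at $v$, let $r(v) = \log_2 s(v)$, and let the potential of the tree be $\Phi = \sum_v r(v)$. Counting one unit per rotation, the access lemma bounds the amortized cost (actual cost plus change in $\Phi$) of splaying a node $x$ in an $n$-node tree with root $t$:

$$\hat{c}(x) \;\le\; 3\bigl(r(t) - r(x)\bigr) + 1 \;\le\; 3\log_2 n + 1.$$

Since $0 \le \Phi \le n \log_2 n$, the actual cost of any $m$ splays satisfies

$$\sum_{i=1}^{m} c_i \;=\; \sum_{i=1}^{m} \hat{c}_i + \Phi_0 - \Phi_m \;\le\; m\,(3\log_2 n + 1) + n\log_2 n \;=\; O\bigl((m+n)\log n\bigr),$$

which is $O(\log n)$ per operation once $m \ge n$.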
Find
Recall that the ordering property of a binary search tree states that at every node, the data stored in the left subtree is smaller than the data stored at the node, which in turn is smaller than any data stored in the right subtree. So, to search for a key, we compare the key with the value stored at the node and branch accordingly. The code below is recursive and is not the most efficient. If the key is in the tree, we splay the node containing it; otherwise we splay the last node visited, that is, the parent of the empty subtree where the search failed.
function find(node::splay, data)
    global root
    if (data == root.data) || (typeof(node) == Nil)
        return
    end
    if (data == node.data)
        splay!(node)
        return
    end
    (data < node.data ? (typeof(node.left) != Nil ? find(node.left, data) : splay!(node)) :
                        (typeof(node.right) != Nil ? find(node.right, data) : splay!(node)) )
end
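The recursion in find can be replaced by a loop. In the sketch below (Julia 1.x; the Node type, which uses nothing instead of Nil, and the function search_for_splay are hypothetical names, not part of the code above), the function returns the node that find would pass to splay!: the node holding the key, or the last node visited when the key is absent.

```julia
# Hypothetical node type: `nothing` marks an empty subtree.
mutable struct Node
    data::Int
    left::Union{Node, Nothing}
    right::Union{Node, Nothing}
end
Node(data::Int) = Node(data, nothing, nothing)

# Iterative binary search tree descent. Returns the node containing `key`,
# or the last node on the search path if `key` is not in the tree.
function search_for_splay(root::Node, key::Int)
    node = root
    while true
        key == node.data && return node
        child = key < node.data ? node.left : node.right
        child === nothing && return node   # the search fails below this node
        node = child
    end
end
```

For example, with `t = Node(5, Node(3, Node(1), Node(4)), Node(8))`, `search_for_splay(t, 2)` returns the node holding 1, whose empty right subtree is where 2 would go.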
Delete
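To delete a key, we first call find, which splays the tree so that, if the key is present, it ends up at the root. If the root does not hold the key, the key is not in the tree and we do nothing. Otherwise we delete the root as in an ordinary BST: when the root has a left subtree, we copy the largest key of that subtree into the root and unlink the node that held it; otherwise we do the same with the smallest key of the right subtree. Two special cases, in which a child of the root can simply be promoted, are handled first. Unlinking a node means reattaching its only subtree to its parent, which we locate with findParent.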
function delete(key)
    global root
    global debugprint
    if (debugprint == true)
        println("delete ", key)
    end
    find(root, key)
    if (root.data != key) # key is not in the tree: nothing to delete
        return
    end
    # left subtree is Nil, right not Nil, left subtree of the right child is Nil
    if (typeof(root.left)==Nil && typeof(root.right)!=Nil && typeof(root.right.left)==Nil)
        root = root.right
        prn(root)
        return
    end
    # symmetric case to above
    if (typeof(root.right)==Nil && typeof(root.left)!=Nil && typeof(root.left.right)==Nil)
        root = root.left
        prn(root)
        return
    end
    if (typeof(root.left)!=Nil) # left subtree is not Nil
        n = root.left
        while ( typeof(n.right)!=Nil) # get the max in the left subtree
            n = n.right
        end
        p = findParent(n.data)
        if (p === root) # n is the root's left child and has no right child
            p.left = n.left
        else
            p.right = n.left
        end
        root.data = n.data
        prn(root)
        return
    end
    if (typeof(root.right)!=Nil) # left subtree is Nil, right is not Nil
        n = root.right
        while ( typeof(n.left)!=Nil) # get the min in the right subtree
            n = n.left
        end
        p = findParent(n.data) # never the root here: that case is handled above
        p.left = n.right
        root.data = n.data
        prn(root)
        return
    end
    root = Nil() # both subtrees are empty: the tree is now empty
end
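The root-deletion step can also be written without the second search that findParent performs, by remembering the parent while descending to the predecessor. The sketch below (Julia 1.x; the Node type, which uses nothing for an empty subtree, and the names delete_root and inorder are hypothetical) removes the key stored at the root of a plain BST and returns the new root. It also covers the case where the predecessor is the root's own left child, which must be unlinked through root.left rather than through a right pointer.

```julia
# Hypothetical node type: `nothing` marks an empty subtree.
mutable struct Node
    data::Int
    left::Union{Node, Nothing}
    right::Union{Node, Nothing}
end
Node(data::Int) = Node(data, nothing, nothing)

# Remove the key stored at `root`; return the new root (possibly `nothing`).
function delete_root(root::Node)
    if root.left === nothing
        return root.right            # promote the right subtree
    elseif root.right === nothing
        return root.left             # promote the left subtree
    end
    # Both subtrees present: find the predecessor (max of the left subtree),
    # tracking its parent on the way down.
    parent, n = root, root.left
    while n.right !== nothing
        parent, n = n, n.right
    end
    if parent === root
        parent.left = n.left         # predecessor was the root's left child
    else
        parent.right = n.left        # bypass the predecessor
    end
    root.data = n.data               # move the predecessor's key into the root
    return root
end

# In-order list of keys, used to check that the ordering survives.
function inorder(node, out = Int[])
    node === nothing && return out
    inorder(node.left, out)
    push!(out, node.data)
    inorder(node.right, out)
    return out
end
```

In a splay tree one would first splay the key to the root, as delete does with find, and then set the root to `delete_root(root)`, allowing for the result to be `nothing` when the last key is removed.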
Insert
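We walk down the tree as in binary search to find the empty subtree where the new key belongs. A new node is created there and splayed all the way to the top using splay!. A key that is already present is ignored.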
function insert(node::splay, data)
    if (data < node.data)
        if (typeof(node.left) == Nil)
            node.left = splay(data, Nil(), Nil())
            splay!(node.left) # splay the new node to the root
            return
        else
            insert(node.left, data)
        end
    elseif (data > node.data)
        if (typeof(node.right) == Nil)
            node.right = splay(data, Nil(), Nil())
            splay!(node.right) # splay the new node to the root
            return
        else
            insert(node.right, data)
        end
    end
    # data == node.data: the key is already present, nothing to do
end
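The descent in insert can likewise be written iteratively. In this sketch (Julia 1.x; the Node type, which uses nothing for an empty subtree, and bst_insert! are hypothetical names), the function attaches the new leaf and returns it, which is the node that would then be handed to a splay routine such as splay!; it returns nothing for a duplicate key.

```julia
# Hypothetical node type: `nothing` marks an empty subtree.
mutable struct Node
    data::Int
    left::Union{Node, Nothing}
    right::Union{Node, Nothing}
end
Node(data::Int) = Node(data, nothing, nothing)

# Iterative BST insertion: descend as in binary search, attach a new leaf at
# the empty subtree where `key` belongs, and return that leaf. Returns
# `nothing` if `key` is already present, mirroring insert's handling of
# duplicates.
function bst_insert!(root::Node, key::Int)
    node = root
    while true
        if key < node.data
            if node.left === nothing
                node.left = Node(key)
                return node.left
            end
            node = node.left
        elseif key > node.data
            if node.right === nothing
                node.right = Node(key)
                return node.right
            end
            node = node.right
        else
            return nothing           # duplicate key: nothing is inserted
        end
    end
end
```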
Splay
The basic operation that reorganizes the tree is called splay. One splaying step at a node x restructures the subtree rooted at the grandparent of x if one exists; otherwise it restructures the subtree rooted at the parent of x. Three cases are considered, labelled in the code below: a simple rotation (zig) when the parent of x is the root, zig-zig when x and its parent are both left children or both right children, and zig-zag when one is a left child and the other a right child. For details of the rotations, see the Wikipedia entry on splay trees, and see Lecture 12 in the book by Kozen for details of the analysis.
This is a recursive implementation: after each step, splay! calls itself until x is the root of the tree.
function splay!( x ) # splay node x up to the root of the tree
    global root
    global debugprint
    if (debugprint == true)
        println()
        println()
        println("Splay ", x.data)
        println()
        prn(root)
    end
    p = findParent(x.data)
    g = findParent(p.data)
    if (x.data == root.data) # x is already at the root
        return
    end
    if (p.data == g.data) && (p.data != x.data) # no grandparent: simple rotation (zig)
        if (x.data < p.data)
            p.left = x.right
            x.right = p
        end
        if (x.data > p.data)
            p.right = x.left
            x.left = p
        end
        root = x
        splay!(x) # recursively splay x to the top
        return
    end
    # x has a parent and a grandparent
    # r is the node where the subtree will hang
    r = findParent(g.data)
    if (x.data < p.data < g.data) # Zig Zig
        g.left = p.right
        p.left = x.right
        p.right = g
        x.right = p
        if (g.data == root.data)
            root = x
        else
            (r.data > x.data ? r.left = x : r.right = x)
        end
        splay!(x) # recursively splay x to the top
        return
    end
    if (x.data > p.data > g.data) # Zig Zig (mirror image)
        g.right = p.left
        p.right = x.left
        p.left = g
        x.left = p
        if (g.data == root.data)
            root = x
        else
            (r.data > x.data ? r.left = x : r.right = x)
        end
        splay!(x) # recursively splay x to the top
        return
    end
    if (p.data < x.data < g.data) # Zig Zag
        p.right = x.left
        g.left = x.right
        x.left = p
        x.right = g
        if (g.data == root.data)
            root = x
        else
            (r.data > x.data ? r.left = x : r.right = x)
        end
        splay!(x) # recursively splay x to the top
        return
    end
    if (g.data < x.data < p.data) # Zig Zag (mirror image)
        p.left = x.right
        g.right = x.left
        x.left = g
        x.right = p
        if (g.data == root.data)
            root = x
        else
            (r.data > x.data ? r.left = x : r.right = x)
        end
        splay!(x) # recursively splay x to the top
        return
    end
end
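For comparison, here is a different but widely used formulation, sketched in Julia 1.x: a recursive splay by key that returns the new root of the subtree instead of looking up parents with findParent and updating a global root. The Node type (with nothing for an empty subtree) and the names rotate_right, rotate_left, splay_key and inorder are hypothetical. It applies the same zig, zig-zig and zig-zag steps, but pairs them starting from the top of the search path, so when the node being splayed is at odd depth the single rotation happens at the bottom rather than at the top, and the resulting tree can differ from the one splay! builds.

```julia
# Hypothetical node type: `nothing` marks an empty subtree.
mutable struct Node
    data::Int
    left::Union{Node, Nothing}
    right::Union{Node, Nothing}
end
Node(data::Int) = Node(data, nothing, nothing)

# Single rotations; each returns the new root of the rotated subtree.
function rotate_right(x::Node)       # x.left moves up
    y = x.left
    x.left = y.right
    y.right = x
    return y
end

function rotate_left(x::Node)        # x.right moves up
    y = x.right
    x.right = y.left
    y.left = x
    return y
end

# Splay `key` (or the last node on its search path) to the root of the
# subtree `t`, and return the new subtree root.
function splay_key(t::Union{Node, Nothing}, key::Int)
    (t === nothing || t.data == key) && return t
    if key < t.data
        t.left === nothing && return t
        if key < t.left.data                        # zig-zig
            t.left.left = splay_key(t.left.left, key)
            t = rotate_right(t)
        elseif key > t.left.data                    # zig-zag
            t.left.right = splay_key(t.left.right, key)
            t.left.right !== nothing && (t.left = rotate_left(t.left))
        end
        return t.left === nothing ? t : rotate_right(t)   # final rotation
    else
        t.right === nothing && return t
        if key > t.right.data                       # zig-zig (mirror image)
            t.right.right = splay_key(t.right.right, key)
            t = rotate_left(t)
        elseif key < t.right.data                   # zig-zag (mirror image)
            t.right.left = splay_key(t.right.left, key)
            t.right.left !== nothing && (t.right = rotate_right(t.right))
        end
        return t.right === nothing ? t : rotate_left(t)
    end
end

# In-order list of keys, used to check that splaying keeps the ordering.
function inorder(node, out = Int[])
    node === nothing && return out
    inorder(node.left, out)
    push!(out, node.data)
    inorder(node.right, out)
    return out
end
```

With this style, `root = splay_key(root, key)` takes the place of find followed by splay!. On the left-leaning chain 5, 4, 3, 2, 1, 0 from the test cases, splaying 0 leaves 4 as the right child of the new root, whereas splay! leaves 5 there; both trees satisfy the ordering property.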
Helper Functions
# Print the tree using in-order traversal. The root is indented 55
# spaces and each level below it 5 spaces less, so a node more than
# 11 levels below the root gives a negative indentation and a run-time error.
prn(node::splay) = ppst(node, 60)
function ppst(node::splay, right::Int)
    if (typeof(node.left)!=Nil)
        ppst(node.left, right - 5)
    end
    println( repeat(" ", right-5) * "-----" * string(node.data))
    if (typeof(node.right)!=Nil)
        ppst(node.right, right - 5)
    end
end
# helper function to find the parent of the node that
# contains the key. We use it to stitch the subtree as a child of
# the parent after a rotation is performed. The key must be in the
# tree; the parent of the root is the root itself.
function findParent(key)
    node, par = root, root
    while (node.data != key)
        par = node
        if (key < node.data)
            node = node.left
        else
            node = node.right
        end
    end
    return par
end
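Two variations on these helpers, sketched in Julia 1.x with a hypothetical Node type that uses nothing for an empty subtree (height, print_tree and find_with_parent are illustrative names): a printer with the same layout as ppst whose left margin grows with the height of the tree, so deep trees never need a negative indentation, and a one-pass search that returns a node together with its parent and returns nothing, rather than failing, when the key is absent.

```julia
# Hypothetical node type: `nothing` marks an empty subtree.
mutable struct Node
    data::Int
    left::Union{Node, Nothing}
    right::Union{Node, Nothing}
end
Node(data::Int) = Node(data, nothing, nothing)

# Height of a subtree; an empty subtree has height -1.
height(node) = node === nothing ? -1 : 1 + max(height(node.left), height(node.right))

# Same layout as ppst: in-order, each level 5 columns further left, but the
# root's margin is 5 * height, so the deepest leaf starts in column 0.
function print_tree(io::IO, node, right::Int)
    node === nothing && return
    print_tree(io, node.left, right - 5)
    println(io, repeat(" ", right - 5), "-----", node.data)
    print_tree(io, node.right, right - 5)
end
print_tree(io::IO, root::Node) = print_tree(io, root, 5 * (height(root) + 1))
print_tree(root::Node) = print_tree(stdout, root)

# One-pass search returning (node, parent): `parent` is nothing for the
# root, and `node` is nothing when the key is absent.
function find_with_parent(root::Node, key::Int)
    parent, node = nothing, root
    while node !== nothing && node.data != key
        parent = node
        node = key < node.data ? node.left : node.right
    end
    return node, parent
end
```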
Test Cases
Let us start with a root that contains 0. The output of inserting the elements 1, 2, ..., 5 is shown below. Because splay! prints the tree each time it is called, every "Splay k" header appears once before the rotation and once more after the new node has reached the root. If you rotate the screen 90 degrees counterclockwise, the printout looks like the typical drawing of a binary tree, with the root at the top, the left subtree on the left, and the right subtree on the right.
debugprint = true
for i in 1:5
    insert(root, i)
end
Splay 1
                                                       -----0
                                                  -----1
Splay 1
                                                  -----0
                                                       -----1
Splay 2
                                                  -----0
                                                       -----1
                                                  -----2
Splay 2
                                             -----0
                                                  -----1
                                                       -----2
Splay 3
                                             -----0
                                                  -----1
                                                       -----2
                                                  -----3
Splay 3
                                        -----0
                                             -----1
                                                  -----2
                                                       -----3
Splay 4
                                        -----0
                                             -----1
                                                  -----2
                                                       -----3
                                                  -----4
Splay 4
                                   -----0
                                        -----1
                                             -----2
                                                  -----3
                                                       -----4
Splay 5
                                   -----0
                                        -----1
                                             -----2
                                                  -----3
                                                       -----4
                                                  -----5
Splay 5
                              -----0
                                   -----1
                                        -----2
                                             -----3
                                                  -----4
                                                       -----5
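find(root, 0) splays the node containing 0 until it is the new root, as shown below: two zig-zig steps followed by a final simple rotation, with the tree printed before each step and once more at the end.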
find(root,0)
Splay 0
                              -----0
                                   -----1
                                        -----2
                                             -----3
                                                  -----4
                                                       -----5
Splay 0
                                        -----0
                                   -----1
                              -----2
                                             -----3
                                                  -----4
                                                       -----5
Splay 0
                                                  -----0
                                        -----1
                                   -----2
                                             -----3
                                        -----4
                                                       -----5
Splay 0
                                                       -----0
                                        -----1
                                   -----2
                                             -----3
                                        -----4
                                                  -----5
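The trees obtained by deleting the elements 0, 1, 2, 3, 4, in that order, are shown below. Each of these keys is already at the root when delete calls find, so no splaying takes place and only the resulting trees are printed.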
for i in 0:4
    delete(i)
end
delete 0
                                                       -----1
                                        -----2
                                             -----3
                                        -----4
                                                  -----5
delete 1
                                                       -----2
                                             -----3
                                        -----4
                                                  -----5
delete 2
                                                       -----3
                                             -----4
                                                  -----5
delete 3
                                                       -----4
                                                  -----5
delete 4
                                                       -----5
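Reading the printouts works for small trees; for larger ones, such as the timing runs, the ordering property can be checked mechanically. A minimal sketch in Julia 1.x, assuming a hypothetical Node type that uses nothing for an empty subtree (is_bst is an illustrative name) and relying on the post's assumption that keys are distinct:

```julia
# Hypothetical node type: `nothing` marks an empty subtree.
mutable struct Node
    data::Int
    left::Union{Node, Nothing}
    right::Union{Node, Nothing}
end
Node(data::Int) = Node(data, nothing, nothing)

# True if every key in the subtree lies strictly between `lo` and `hi`
# (`nothing` means unbounded) and the ordering property holds at every node.
function is_bst(node, lo = nothing, hi = nothing)
    node === nothing && return true
    lo !== nothing && node.data <= lo && return false
    hi !== nothing && node.data >= hi && return false
    return is_bst(node.left, lo, node.data) && is_bst(node.right, node.data, hi)
end
```

Calling `@assert is_bst(root)` after each insert or delete would catch a broken reattachment immediately; adapting it to the splay type above means replacing the `=== nothing` tests with `typeof(...) == Nil` checks.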
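To measure the running time on a larger input, we insert the keys of a random permutation of 1 to 100000 into a fresh tree, timing each insertion with @elapsed, and then delete all of them except the last one, timing each deletion. The per-operation times are stored in y and plotted with PyPlot.
nkeys = 100000 # number of keys (a new name avoids shadowing Base.size)
y = zeros(nkeys)
debugprint = false
inp = randperm(nkeys)
using PyPlot
root = splay(0)
ctr = 1
for j in inp
    y[ctr] = @elapsed insert(root, j)
    ctr = ctr + 1
end
p = plot(y)
prn(n::splay) = return # redefine prn() to suppress printing
ctr = nkeys
for j in inp[1:nkeys-1]
    y[ctr] = @elapsed delete(j)
    ctr = ctr - 1
end
p = plot(y[2:nkeys])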
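Finally, we benchmark inserting the key 100001, which is larger than every key in the tree, immediately followed by deleting it, so the set of keys is the same before every sample. The median estimate is shown below.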
using BenchmarkTools
root = splay(0)
for j in inp
    insert(root, j)
end
t = @benchmark (insert(root, 100001), delete(100001))
median(t)
BenchmarkTools.TrialEstimate:
time: 16.756 μs
gctime: 0.000 ns (0.00%)
memory: 592 bytes
allocs: 36
The code is available here. It has been tested on Julia 0.5.2.
This code is purely illustrative; you might want to remove the global root and the recursion from the code above.