MIT, CC BY-SA 4.0

Iterative Tree Node Counting

I love coding, but one of the major hurdles I haven’t been able to overcome is how slowly I actually write code, which is a death sentence during coding interviews. To combat this, I’ve been doing competitive programming over the last several months, and I’ve seen some moderate improvement. One of the great strengths of these contests is that, in addition to building raw speed, they have exposed me to a variety of data structures and algorithm implementations I would never otherwise have come across. Such was the case when I got stuck on the following problem:

Given an N-node connected tree with N - 1 edges, and three randomly picked nodes a, b and c, find the expected value of d(a, b) + d(b, c) + d(c, a). The catch is that the distance function d() changes over time, and the expected value must be recalculated after each change. By definition, an expected value has the form E(X) = Σ x·P(x). Evaluated naively, this means summing over every possible triple of nodes, which is very costly when N is large, and the whole sum would have to be redone after every change. Instead, we would like to bring the recalculation time down to O(1) if possible, paying a one-time precalculation cost to determine the original expected value.
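To make the cost of the naive approach concrete, here is a brute-force baseline. It rests on assumptions the post does not spell out: nodes are numbered 0..N-1, the tree is given as a hypothetical list of weighted edges (u, v, w), and a, b, c are three distinct nodes chosen uniformly at random. It is a sketch, not the author's code.

```python
from itertools import combinations


def all_pairs_distances(n, edges):
    """Distances between every pair of nodes 0..n-1 in a weighted tree.

    `edges` is an assumed input format: a list of (u, v, w) tuples.
    """
    adj = [[] for _ in range(n)]
    for u, v, w in edges:
        adj[u].append((v, w))
        adj[v].append((u, w))
    dist = [[0] * n for _ in range(n)]
    for src in range(n):
        # Iterative traversal from src; in a tree each node is reached once
        stack = [(src, -1, 0)]
        while stack:
            u, parent, d = stack.pop()
            dist[src][u] = d
            for v, w in adj[u]:
                if v != parent:
                    stack.append((v, u, d + w))
    return dist


def expected_triple_distance_naive(n, edges):
    # Average of d(a,b) + d(b,c) + d(c,a) over all C(n, 3) distinct triples
    dist = all_pairs_distances(n, edges)
    triples = list(combinations(range(n), 3))
    total = sum(dist[a][b] + dist[b][c] + dist[c][a] for a, b, c in triples)
    return total / len(triples)


print(expected_triple_distance_naive(3, [(0, 1, 1), (1, 2, 2)]))  # → 6.0
```

This costs O(N²) for the distances plus O(N³) for the triples, and all of it must be redone whenever d() changes. That is exactly what the precalculation is meant to avoid.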

As it turns out, this is possible. To do it, however, we must first calculate the number of nodes in the subtree rooted at each node. Consider a tree with root R:

     R
    / \
   A   B
  /
 C

Here, we wish to build the data structure { R : 4, A : 2, B : 1, C : 1 }, which maps each node to the size of its subtree (the node itself included). This is easily done by piggybacking on DFS:

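CACHE = {}
def count_nodes_recursive(r):
	# Size of r's subtree = 1 (r itself) + the sizes of all child subtrees.
	# Mutating the global dict does not require a `global` statement.
	CACHE[r] = sum(count_nodes_recursive(c) for c in r.children) + 1
	return CACHE[r]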
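The post never defines the node type, so the sketch below assumes a hypothetical minimal Node class with a name and a children list. It builds the example tree and checks that the recursive function produces the mapping described above:

```python
class Node:
    # Hypothetical minimal node type; the post only relies on a .children list
    def __init__(self, name, children=None):
        self.name = name
        self.children = children or []


CACHE = {}
def count_nodes_recursive(r):
    # Size of r's subtree = 1 (r itself) + the sizes of all child subtrees
    CACHE[r] = sum(count_nodes_recursive(c) for c in r.children) + 1
    return CACHE[r]


# The example tree: R has children A and B, and A has child C
C = Node("C")
A = Node("A", [C])
B = Node("B")
R = Node("R", [A, B])

count_nodes_recursive(R)
sizes = {n.name: size for n, size in CACHE.items()}
print(sizes)  # → {'C': 1, 'A': 2, 'B': 1, 'R': 4}
```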

Unfortunately, recursion and Python do not mix very well. CPython's default recursion limit is 1000, so even this trivial chain of 1000 nested calls fails with a RecursionError:

def foo(c, l):
	# Each call adds one stack frame until c reaches l
	if c < l:
		foo(c + 1, l)
foo(1, 1000)
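The same limit applies to count_nodes_recursive(). As a sketch, again assuming a hypothetical Node class, a path-shaped tree only slightly deeper than sys.getrecursionlimit() is enough to break it. In practice it fails even earlier, because the generator expression inside sum() adds its own frame at every level:

```python
import sys


class Node:
    # Hypothetical minimal node type; the post only relies on a .children list
    def __init__(self, children=None):
        self.children = children or []


CACHE = {}
def count_nodes_recursive(r):
    CACHE[r] = sum(count_nodes_recursive(c) for c in r.children) + 1
    return CACHE[r]


# A path-shaped tree a little deeper than the recursion limit
depth = sys.getrecursionlimit() + 100
root = Node()
tail = root
for _ in range(depth - 1):
    child = Node()
    tail.children.append(child)
    tail = child

try:
    count_nodes_recursive(root)
    overflowed = False
except RecursionError:
    overflowed = True
print(overflowed)  # → True
```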

For this problem, the tree depth could feasibly exceed the recursion limit (sys.getrecursionlimit()), so we need another approach. The usual workaround is to manage our own stack and manually feed it the nodes to be searched:

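from collections import deque

def dfs(r):
	# Preorder DFS with an explicit stack instead of the call stack
	l = []
	stack = deque([r])
	while stack:
		n = stack.pop()  # use a new name so r still refers to the root
		l.append(n)
		for c in n.children:
			stack.append(c)
	assert l[0] is r
	return l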

Unfortunately, as the assertion shows, one thing does not change: we still visit the tree from the root toward the leaves. count_nodes_recursive(), however, needs the results for a node's children before it can compute that node's own result: CACHE[r] = sum(...) + 1. This suggests that we need to reverse the processing order:

def reverse_dfs(r):
	l = dfs(r)[::-1]
	assert(l[-1] == r)
	return l
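To see both orders on the example tree, here is a self-contained copy of the two traversals, using the same hypothetical Node class as before:

```python
from collections import deque


class Node:
    # Hypothetical minimal node type; the post only relies on a .children list
    def __init__(self, name, children=None):
        self.name = name
        self.children = children or []


def dfs(r):
    # Preorder via an explicit stack: every parent precedes its children
    l = []
    stack = deque([r])
    while stack:
        n = stack.pop()
        l.append(n)
        for c in n.children:
            stack.append(c)
    return l


def reverse_dfs(r):
    # Reversing the preorder puts every child before its parent
    return dfs(r)[::-1]


C = Node("C")
A = Node("A", [C])
B = Node("B")
R = Node("R", [A, B])

pre = [n.name for n in dfs(R)]
post = [n.name for n in reverse_dfs(R)]
print(pre)   # → ['R', 'B', 'A', 'C']
print(post)  # → ['C', 'A', 'B', 'R']
```

B is popped before A because it was pushed last. The exact sibling order does not matter here. What matters is that C comes before A, and both A and B come before R.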

This guarantees that, in the order given by reverse_dfs(), every node comes after all of its descendants. In dfs() a child is only pushed after its parent has been popped and recorded, so in the reversed list every child appears before its parent. We can use this ordering to percolate our node counts up toward the root of the tree, storing them in the global dictionary CACHE:

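CACHE = {}
def count_nodes_iterative(r):
	# Children come before parents in reverse_dfs() order, so each
	# CACHE[c] is already final when its parent n is processed
	for n in reverse_dfs(r):
		CACHE[n] = 1
		for c in n.children:
			CACHE[n] += CACHE[c]
	return CACHE[r]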

Because of the reversed stack ordering, every child c of n has already been processed by the time we reach n. CACHE[c] therefore already holds the final size of c's subtree and can simply be added into CACHE[n], which is how the counts propagate upward. And we’re done!
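As a quick sanity check, here is a self-contained copy of the iterative version, with a hypothetical Node class and reverse_dfs() condensed into one function. It runs on a path-shaped tree ten times deeper than the recursion limit, which is exactly the kind of input that breaks the recursive version:

```python
import sys
from collections import deque


class Node:
    # Hypothetical minimal node type; the post only relies on a .children list
    def __init__(self, children=None):
        self.children = children or []


def reverse_dfs(r):
    # Explicit-stack preorder, reversed so that children precede parents
    l = []
    stack = deque([r])
    while stack:
        n = stack.pop()
        l.append(n)
        stack.extend(n.children)
    return l[::-1]


CACHE = {}
def count_nodes_iterative(r):
    for n in reverse_dfs(r):
        CACHE[n] = 1
        for c in n.children:
            CACHE[n] += CACHE[c]
    return CACHE[r]


# A path-shaped tree far deeper than the recursion limit
depth = 10 * sys.getrecursionlimit()
root = Node()
tail = root
for _ in range(depth - 1):
    child = Node()
    tail.children.append(child)
    tail = child

total = count_nodes_iterative(root)
print(total == depth)  # → True
```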

Full code (including question-specific considerations) is on pastebin.
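For readers without the pastebin code, here is one way the subtree counts can give O(1) recalculation. This sketch is not the author's code, and it rests on stated assumptions: a, b, c are three distinct nodes chosen uniformly at random, "d() changes" means a single edge weight is replaced, and the tree is given as a hypothetical list of (u, v, w) edges. Suppose the lower endpoint of an edge (relative to the root) has a subtree of size s. That edge lies on the path of exactly s·(N - s) unordered node pairs. So the sum of all pairwise distances is S = Σ w·s·(N - s). By symmetry, each of (a, b), (b, c), (c, a) is a uniformly random pair, which gives E = 3·S / C(N, 2). Changing one weight from w to w' shifts S by (w' - w)·s·(N - s).

```python
from math import comb


class ExpectedTripleDistance:
    """Hypothetical helper: maintains E[d(a,b) + d(b,c) + d(c,a)] for three
    distinct, uniformly random nodes (assumes n >= 3) under weight updates."""

    def __init__(self, n, edges):
        # edges: assumed list of (u, v, w) tuples over nodes 0..n-1
        self.n = n
        self.weights = [w for _, _, w in edges]
        adj = [[] for _ in range(n)]
        for i, (u, v, _) in enumerate(edges):
            adj[u].append((v, i))
            adj[v].append((u, i))
        # Iterative DFS from node 0 records parents and a parent-first order
        parent = [-1] * n
        parent_edge = [-1] * n
        seen = [False] * n
        seen[0] = True
        order = []
        stack = [0]
        while stack:
            u = stack.pop()
            order.append(u)
            for v, i in adj[u]:
                if not seen[v]:
                    seen[v] = True
                    parent[v] = u
                    parent_edge[v] = i
                    stack.append(v)
        # Reversed order: children before parents, so sizes percolate upward
        size = [1] * n
        for u in reversed(order):
            if parent[u] != -1:
                size[parent[u]] += size[u]
        # pairs[i] = number of unordered node pairs whose path uses edge i
        self.pairs = [0] * len(edges)
        for v in range(n):
            if parent_edge[v] != -1:
                self.pairs[parent_edge[v]] = size[v] * (n - size[v])
        self.total = sum(w * p for w, p in zip(self.weights, self.pairs))

    def expected(self):
        # 3 pairs per triple, each a uniformly random unordered pair
        return 3 * self.total / comb(self.n, 2)

    def update(self, i, w):
        # O(1): only edge i's contribution to the pairwise sum changes
        self.total += (w - self.weights[i]) * self.pairs[i]
        self.weights[i] = w


t = ExpectedTripleDistance(3, [(0, 1, 1), (1, 2, 2)])
print(t.expected())  # → 6.0
t.update(1, 5)
print(t.expected())  # → 12.0
```

All the expensive work (the traversal and the subtree sizes) happens once in the constructor. After that, each update() touches a single edge.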