# Iterative Tree Node Counting

I love coding, but one of the major hurdles I haven’t been able to pass is the sluggish speed at which I actually write code – this is a death sentence during coding interviews. To combat this, I’ve been working on competitive programming over the last several months, and I’ve seen some moderate improvement in my coding as a result. One of the great strengths of these contests is that in addition to gaining raw speed, I’ve been exposed to a variety of data structures and algorithm implementations which I would otherwise never have come across. Such was the case when I got stuck on the following problem:

Given an *N*-node connected tree, with *N - 1* edges and three randomly picked
nodes `a`, `b`, `c`, find the expected value of `d(a, b) + d(b, c) + d(c, a)`.
The catch here is that the distance function `d()` is subject to change after
some amount of time, and the expected value must then be recalculated. From
probability, we know expected values are of the form `E(X) = SUM(xP(x))`,
which, given that *N* may be quite large, would make a naive recalculation
rather costly. As such, we would like to reduce the recalculation time to
*O*(1) if possible, paying a one-time precalculation cost to determine the
original expected value.

As it turns out, this is possible – however, in order to do this, we must
calculate *the number of nodes in the subtree rooted at each node*. Consider a
tree with root `R`:

```
    R
   / \
  A   B
 /
C
```

Here, we wish to build the data structure `{ R : 4, A : 2, B : 1, C : 1 }`. This
is very easily done by piggybacking off of DFS:

```
CACHE = {}
def count_nodes_recursive(r):
    global CACHE
    # A subtree's size is one (the node itself) plus the sizes of its
    # children's subtrees.
    CACHE[r] = sum(count_nodes_recursive(c) for c in r.children) + 1
    return CACHE[r]
```
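As a quick sanity check, here is the recursive version run on the example tree, assuming a minimal `Node` class with a `children` list (a class the snippets in this post take for granted but never define):

```python
CACHE = {}

class Node:
    """Minimal tree node with a list of children (an assumption; the post
    never defines the node type)."""
    def __init__(self, name):
        self.name = name
        self.children = []

def count_nodes_recursive(r):
    # Subtree size = 1 (this node) + sizes of all child subtrees.
    CACHE[r] = sum(count_nodes_recursive(c) for c in r.children) + 1
    return CACHE[r]

# Build the example tree: R -> (A, B), A -> C.
R, A, B, C = Node("R"), Node("A"), Node("B"), Node("C")
R.children = [A, B]
A.children = [C]

count_nodes_recursive(R)
print({n.name: CACHE[n] for n in (R, A, B, C)})  # {'R': 4, 'A': 2, 'B': 1, 'C': 1}
```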

Unfortunately, recursion and Python do not mix very well – the default recursion limit is set to 1000:

```
def foo(c, l):
    if c < l:
        foo(c + 1, l)

foo(1, 1000)  # hits the default recursion limit
```
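To see the failure without waiting on a thousand frames, we can lower the limit ourselves and catch the resulting `RecursionError` (the limit of 100 here is arbitrary, chosen only to make the demonstration quick):

```python
import sys

def foo(c, l):
    if c < l:
        foo(c + 1, l)

# Temporarily lower the limit so the failure is quick and deterministic.
old_limit = sys.getrecursionlimit()
sys.setrecursionlimit(100)

hit_limit = False
try:
    foo(1, 1000)
except RecursionError:
    hit_limit = True
finally:
    sys.setrecursionlimit(old_limit)

print(hit_limit)  # True
```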

For this problem, the tree depth could feasibly exceed the recursion limit
(`sys.getrecursionlimit()`), and therefore we need another approach to
solve this problem. Typically, we would allocate our own stack, and
manually feed it the nodes to be searched:

```
def dfs(r):
    from collections import deque as stack
    l = []
    q = stack([ r ])
    while len(q):
        n = q.pop()  # use a fresh name: reusing r would break the assertion below
        l.append(n)
        for c in n.children:
            q.append(c)
    assert(l[0] == r)
    return l
```
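To see the root-to-leaves order concretely, here is the same `dfs()` run on the example tree (again assuming the minimal `Node` class; note that the exact sibling order depends on stack pop order):

```python
from collections import deque

class Node:
    # Minimal node type with a children list (an assumption of this sketch).
    def __init__(self, name):
        self.name = name
        self.children = []

def dfs(r):
    l = []
    q = deque([r])
    while len(q):
        n = q.pop()
        l.append(n)
        for c in n.children:
            q.append(c)
    assert l[0] == r  # the root is always visited first
    return l

# Example tree: R -> (A, B), A -> C.
R, A, B, C = Node("R"), Node("A"), Node("B"), Node("C")
R.children = [A, B]
A.children = [C]

print([n.name for n in dfs(R)])  # ['R', 'B', 'A', 'C']
```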

Unfortunately, as the assertion shows, what *doesn’t* change here is the fact
that we are iterating from the *root* of the tree towards its *leaves*.
However, when we consider `count_nodes_recursive()`, the *root* node’s result
needs the results from the *leaves* first: `CACHE[r] = sum(...) + 1`. This
suggests that we will need to reverse the stack’s processing order:

```
def reverse_dfs(r):
    l = dfs(r)[::-1]
    assert(l[-1] == r)
    return l
```

In this way, we guarantee that every node in the order given by
`reverse_dfs()` is processed *after* all of its children. We can
utilize this ordering to percolate our node counts towards the root of the tree
by utilizing a global variable `CACHE`:

```
CACHE = {}
def count_nodes_iterative(r):
    global CACHE
    for n in reverse_dfs(r):
        CACHE[n] = 1
        for c in n.children:
            CACHE[n] += CACHE[c]
    return CACHE[r]
```
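Putting the pieces together on the example tree (again with the assumed minimal `Node` class; this sketch returns the cache rather than using a global, but is otherwise the same algorithm):

```python
from collections import deque

class Node:
    # Minimal node type with a children list (an assumption of this sketch).
    def __init__(self, name):
        self.name = name
        self.children = []

def reverse_dfs(r):
    # Iterative DFS, reversed: every node appears after all of its children.
    l, q = [], deque([r])
    while len(q):
        n = q.pop()
        l.append(n)
        q.extend(n.children)
    return l[::-1]

def count_nodes_iterative(r):
    cache = {}
    for n in reverse_dfs(r):
        # Children were processed earlier, so their counts are already cached.
        cache[n] = 1 + sum(cache[c] for c in n.children)
    return cache

# Example tree: R -> (A, B), A -> C.
R, A, B, C = Node("R"), Node("A"), Node("B"), Node("C")
R.children = [A, B]
A.children = [C]

counts = count_nodes_iterative(R)
print({n.name: counts[n] for n in (R, A, B, C)})  # {'R': 4, 'A': 2, 'B': 1, 'C': 1}
```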

Due to our reverse queue ordering, and because `c` is guaranteed to be a child
of `n`, every child’s count is already in `CACHE` by the time we reach `n`, so
the counts propagate properly upwards. And we’re done!

Full code (including question-specific considerations) is on pastebin.