Problem of The Day: Find Distance in a Binary Tree
Problem Statement
Intuition
When I first thought about finding the distance between two nodes in a binary tree, I realized that the key challenge was locating their Lowest Common Ancestor (LCA). Once the LCA is found, the distance between the two nodes is simply the sum of their respective distances from this ancestor.
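To make the idea concrete, here is a minimal sketch on a small hand-built tree. The tree, the `PARENT` map, and the helper names are hypothetical and only illustrate the intuition; this is not the solution code used later, which works on `TreeNode` objects without parent links.

```python
# Hypothetical example tree, stored as child -> parent links (the root maps to None).
#         3
#       /   \
#      5     1
#     / \   / \
#    6   2 0   8
#       / \
#      7   4
PARENT = {3: None, 5: 3, 1: 3, 6: 5, 2: 5, 0: 1, 8: 1, 7: 2, 4: 2}


def path_to_root(value):
    """Return the values from `value` up to the root, inclusive."""
    path = []
    while value is not None:
        path.append(value)
        value = PARENT[value]
    return path


def distance_via_lca(p, q):
    """Edges from p up to the LCA plus edges from q up to the LCA."""
    # Position in the path = number of edges climbed from q to that ancestor.
    steps_from_q = {value: steps for steps, value in enumerate(path_to_root(q))}
    # Climbing from p, the first value that is also an ancestor of q is the LCA.
    for steps_from_p, value in enumerate(path_to_root(p)):
        if value in steps_from_q:
            return steps_from_p + steps_from_q[value]
    raise ValueError("p and q are not in the same tree")


print(distance_via_lca(5, 0))  # 3: path 5 -> 3 -> 1 -> 0
print(distance_via_lca(7, 4))  # 2: path 7 -> 2 -> 4
```

Here the LCA of 5 and 0 is the root 3, so the distance is 1 + 2 = 3.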
Approach
- Finding the LCA:
  - I implemented a recursive function find_lca to traverse the binary tree and identify the LCA of the two nodes, p and q.
  - This function returns True if either of the nodes is found in the subtree rooted at the current node. It sets self.lca at the node where at least two of these three conditions hold: a target is in the left subtree, a target is in the right subtree, or the current node is itself a target.
- Calculating Distances:
  - Once the LCA is found, I used another recursive function dfs to compute the distance from the LCA to each of the two nodes, p and q.
  - The distance between p and q is then the sum of the distances from the LCA to each of these nodes.
Complexity
- Time complexity:
  - Finding the LCA takes O(n) time in the worst case, where n is the number of nodes in the tree, since we potentially visit every node once.
  - Calculating the distances from the LCA to each node also takes O(n) in the worst case, though it’s generally faster since we start from the LCA.
  - Therefore, the overall time complexity is O(n).
- Space complexity:
  - The space complexity is O(h), where h is the height of the tree. This is due to the recursive call stack used in both the find_lca and dfs functions. For a skewed tree, h can be as large as n.
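As a rough illustration of why the stack bound is stated in terms of h rather than n, the sketch below compares the height of a skewed tree with that of a balanced tree of the same size. The builders `build_chain` and `build_balanced` and the `height` helper are hypothetical names introduced only for this example, and `TreeNode` is redefined so the snippet runs on its own.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_chain(n):
    """Left-skewed tree of n nodes: every node has only a left child."""
    root = None
    for val in range(n, 0, -1):
        root = TreeNode(val, left=root)
    return root


def build_balanced(lo, hi):
    """Height-balanced tree holding the values lo..hi."""
    if lo > hi:
        return None
    mid = (lo + hi) // 2
    return TreeNode(mid, build_balanced(lo, mid - 1), build_balanced(mid + 1, hi))


def height(node):
    """Nodes on the longest root-to-leaf path.

    A recursive DFS keeps roughly this many frames on the call stack at its peak.
    """
    if node is None:
        return 0
    return 1 + max(height(node.left), height(node.right))


n = 127
print(height(build_chain(n)))        # 127: the stack grows linearly with n
print(height(build_balanced(1, n)))  # 7: the stack grows with log2(n + 1)
```

For very deep skewed trees, a recursive solution in Python can hit the interpreter's recursion limit, which is one practical consequence of the O(h) bound.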
Code
from typing import Optional

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def findDistance(self, root: Optional[TreeNode], p: int, q: int) -> int:
        self.lca = None

        def find_lca(node):
            if not node:
                return False
            L = find_lca(node.left)
            R = find_lca(node.right)
            mid = node.val == q or node.val == p
            # Booleans add as 0/1, so two hits mean this node is the LCA.
            if L + R + mid >= 2:
                self.lca = node
                return True
            return L or R or mid

        find_lca(root)

        def dfs(node, target, dist):
            # Returns 0 when the target is absent, so `or` falls through
            # to the other subtree.
            if not node:
                return 0
            if node.val == target:
                return dist
            return dfs(node.left, target, dist + 1) or \
                   dfs(node.right, target, dist + 1)

        dist_q = dfs(self.lca, q, 0)
        dist_p = dfs(self.lca, p, 0)
        return dist_q + dist_p
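To try the approach outside of LeetCode, a small harness along these lines may help. It is a sketch under a few assumptions: `build_tree` is a hypothetical helper that reads a LeetCode-style level-order list, `find_distance` is a standalone restatement of the two-step idea above (not the exact class), and the sample tree is made up for illustration.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Build a tree from a level-order list where None marks a missing child."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        if values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def find_distance(root: Optional[TreeNode], p: int, q: int) -> int:
    """Locate the LCA of p and q, then sum the two depths below it."""
    lca = None

    def find_lca(node):
        nonlocal lca
        if not node:
            return False
        left = find_lca(node.left)
        right = find_lca(node.right)
        mid = node.val == p or node.val == q
        if left + right + mid >= 2:  # booleans count as 0/1
            lca = node
        return left or right or mid

    def depth_below(node, target, dist):
        if not node:
            return -1  # target is not in this subtree
        if node.val == target:
            return dist
        found = depth_below(node.left, target, dist + 1)
        return found if found != -1 else depth_below(node.right, target, dist + 1)

    find_lca(root)
    if lca is None:  # happens when p == q: only one "hit" is ever counted
        return 0
    return depth_below(lca, p, 0) + depth_below(lca, q, 0)


root = build_tree([3, 5, 1, 6, 2, 0, 8, None, None, 7, 4])
print(find_distance(root, 5, 0))  # 3
print(find_distance(root, 5, 7))  # 2
print(find_distance(root, 5, 5))  # 0
```

The second call covers the case where one target is an ancestor of the other: the LCA is node 5 itself, so one of the two depths is 0.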
Editorial
Approach 1: Brute Force (Lowest Common Ancestor and Depth-First Search)
class Solution:
    def findDistance(self, root, p, q):
        # Find the lowest common ancestor of p and q.
        lca = self.__find_LCA(root, p, q)
        return self.__depth(lca, p) + self.__depth(lca, q)

    # Function to find the LCA of the given nodes.
    def __find_LCA(self, root, p, q):
        if root is None or root.val == p or root.val == q:
            return root
        left = self.__find_LCA(root.left, p, q)
        right = self.__find_LCA(root.right, p, q)
        if left is not None and right is not None:
            return root
        return left if left is not None else right

    # Function to find the depth of the node with respect to LCA.
    def __depth(self, root, target, current_depth=0):
        # Node not found
        if root is None:
            return -1
        if root.val == target:
            return current_depth
        # Check left subtree
        left_depth = self.__depth(root.left, target, current_depth + 1)
        if left_depth != -1:
            return left_depth
        # If not in left subtree, it is guaranteed to be in right subtree
        return self.__depth(root.right, target, current_depth + 1)
Approach 2: Lowest Common Ancestor and Breadth-First Search
from collections import deque


class Solution:
    def findDistance(self, root, p, q):
        lca = self._find_LCA(root, p, q)
        bfs = deque([lca])
        distance = 0
        depth = 0
        foundp = False
        foundq = False
        # Level-order traversal from the LCA; depth is the current level.
        while bfs and (not foundp or not foundq):
            size = len(bfs)
            for _ in range(size):
                node = bfs.popleft()  # Dequeue the node
                if node.val == p:
                    distance += depth
                    foundp = True
                if node.val == q:
                    distance += depth
                    foundq = True
                if node.left:
                    bfs.append(node.left)  # Enqueue left child
                if node.right:
                    bfs.append(node.right)  # Enqueue right child
            depth += 1
        return distance

    def _find_LCA(self, root, p, q):
        if root is None or root.val == p or root.val == q:
            return root
        left = self._find_LCA(root.left, p, q)
        right = self._find_LCA(root.right, p, q)
        if left and right:
            return root
        return left if left else right
Approach 3: One pass (Based on Lowest Common Ancestor)
class Solution:
    def findDistance(self, root, p, q):
        return self.__distance(root, p, q, 0)

    # Private helper function
    def __distance(self, root, p, q, depth):
        if root is None or p == q:
            return 0

        # If root is p or q, search below it with the depth restarted at 1.
        # A positive result is the distance from this node to the other
        # target; otherwise report this node's own depth.
        if root.val == p or root.val == q:
            left = self.__distance(root.left, p, q, 1)
            right = self.__distance(root.right, p, q, 1)
            return max(left, right) if left > 0 or right > 0 else depth

        # Otherwise, calculate the ret_distance as sum of ret_distance of left
        # and right subtree.
        left = self.__distance(root.left, p, q, depth + 1)
        right = self.__distance(root.right, p, q, depth + 1)
        ret_distance = left + right

        # If current node is the LCA, subtract twice of depth.
        if left != 0 and right != 0:
            ret_distance -= 2 * depth
        return ret_distance
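The "subtract twice of depth" step relies on the identity distance(p, q) = depth(p) + depth(q) - 2 * depth(LCA), with all depths measured from the root. The sketch below checks that identity with three separate traversals. It is not the one-pass code itself, and the tuple-based `TREE` and the helper names are hypothetical, chosen only to keep the example short.

```python
# Hypothetical tree as nested tuples: (value, left_subtree, right_subtree).
#         3
#       /   \
#      5     1
#     / \   / \
#    6   2 0   8
#       / \
#      7   4
TREE = (3,
        (5, (6, None, None), (2, (7, None, None), (4, None, None))),
        (1, (0, None, None), (8, None, None)))


def depth_of(tree, target, depth=0):
    """Depth of `target` in edges from the root, or -1 if it is absent."""
    if tree is None:
        return -1
    value, left, right = tree
    if value == target:
        return depth
    found = depth_of(left, target, depth + 1)
    return found if found != -1 else depth_of(right, target, depth + 1)


def lca_value(tree, p, q):
    """Value stored at the lowest common ancestor of p and q, or None."""
    if tree is None:
        return None
    value, left, right = tree
    if value == p or value == q:
        return value
    in_left = lca_value(left, p, q)
    in_right = lca_value(right, p, q)
    if in_left is not None and in_right is not None:
        return value
    return in_left if in_left is not None else in_right


def distance_by_depths(tree, p, q):
    """depth(p) + depth(q) - 2 * depth(LCA), all measured from the root."""
    lca = lca_value(tree, p, q)
    return depth_of(tree, p) + depth_of(tree, q) - 2 * depth_of(tree, lca)


print(distance_by_depths(TREE, 7, 4))  # 3 + 3 - 2 * 2 = 2
print(distance_by_depths(TREE, 7, 0))  # 3 + 2 - 2 * 0 = 5
```

The one-pass version folds the same arithmetic into a single recursion: the two subtrees return the depths of p and q, and the node where both are non-zero subtracts its own depth twice.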