Problem Statement

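LeetCode 1740, Find Distance in a Binary Tree: given the root of a binary tree and two integers p and q, return the number of edges on the path between the node with value p and the node with value q.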

Intuition

When I first thought about solving the problem of finding the distance between two nodes in a binary tree, I realized that the key challenge was to locate their Lowest Common Ancestor (LCA). Once I find the LCA, the distance between the two nodes can be easily computed by summing up their respective distances from this ancestor.
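
One way to sanity-check this intuition is to compare the two root-to-node paths directly: the last node they share is the LCA, and whatever remains of each path past it is that node's distance from the LCA. The sketch below is only an illustration, not the solution used later in this post. The helper names `path_to` and `distance_via_paths` are my own, and it assumes node values are unique and that both p and q are present in the tree.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def path_to(root: Optional[TreeNode], target: int) -> Optional[List[int]]:
    """Return the values on the path from root down to target, or None."""
    if root is None:
        return None
    if root.val == target:
        return [root.val]
    for child in (root.left, root.right):
        sub_path = path_to(child, target)
        if sub_path is not None:
            return [root.val] + sub_path
    return None


def distance_via_paths(root: TreeNode, p: int, q: int) -> int:
    """Distance = edges from the LCA to p plus edges from the LCA to q."""
    path_p = path_to(root, p)
    path_q = path_to(root, q)
    # Count the shared prefix; its last node is the LCA.
    shared = 0
    while (shared < min(len(path_p), len(path_q))
           and path_p[shared] == path_q[shared]):
        shared += 1
    # Each path has (length - shared) edges below the LCA.
    return (len(path_p) - shared) + (len(path_q) - shared)


# Tree used for illustration: [3,5,1,6,2,0,8,null,null,7,4] in level order.
root = TreeNode(
    3,
    TreeNode(5, TreeNode(6), TreeNode(2, TreeNode(7), TreeNode(4))),
    TreeNode(1, TreeNode(0), TreeNode(8)),
)
print(distance_via_paths(root, 5, 0))  # 3 (5 -> 3 -> 1 -> 0)
print(distance_via_paths(root, 5, 7))  # 2 (5 -> 2 -> 7)
```

Here the paths to 5 and 0 are [3, 5] and [3, 1, 0]. They share only the root, so the LCA is 3 and the distance is 1 + 2 = 3.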

Approach

  1. Finding the LCA:

    • I implemented a recursive function find_lca to traverse the binary tree and identify the LCA of the two nodes, p and q.
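    • This function returns True if either node is found in the subtree rooted at the current node. It sets self.lca at the node where at least two of three flags hold: a target in the left subtree, a target in the right subtree, or the current node itself being p or q. This covers both the case where p and q sit in different subtrees and the case where one of them is an ancestor of the other.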
  2. Calculating Distances:

    • Once the LCA is found, I used another recursive function dfs to compute the distance from the LCA to each of the two nodes, p and q.
    • The distance between p and q is then the sum of the distances from the LCA to each of these nodes.
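
The flag-counting step is easier to see in isolation. Below is a minimal sketch of just the LCA search, written as a standalone function rather than a method. The names `lowest_common_ancestor` and `visit` are illustrative, and it assumes unique values with p != q (when p == q no node ever collects two flags, so it returns None).

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def lowest_common_ancestor(root, p, q):
    found = None  # the node recorded as the LCA, if any

    def visit(node):
        nonlocal found
        if node is None:
            return False
        in_left = visit(node.left)
        in_right = visit(node.right)
        is_target = node.val in (p, q)
        # Booleans add as integers, so this counts how many flags are set.
        if in_left + in_right + is_target >= 2:
            found = node
        return in_left or in_right or is_target

    visit(root)
    return found


root = TreeNode(
    3,
    TreeNode(5, TreeNode(6), TreeNode(2, TreeNode(7), TreeNode(4))),
    TreeNode(1, TreeNode(0), TreeNode(8)),
)
# p and q in different subtrees of the root.
print(lowest_common_ancestor(root, 5, 0).val)  # 3
# p is an ancestor of q, so p itself is the LCA.
print(lowest_common_ancestor(root, 5, 7).val)  # 5
```

Ancestors of the LCA only ever see one flag set (the subtree containing the LCA), so they never overwrite the recorded node.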

Complexity

  • Time complexity:

    • Finding the LCA takes O(n) time in the worst case, where n is the number of nodes in the tree, since we potentially visit every node once.
    • Calculating the distances from the LCA to each node also takes O(n) in the worst case, though it’s generally faster since we start from the LCA.
    • Therefore, the overall time complexity is O(n).
  • Space complexity:

    • The space complexity is O(h), where h is the height of the tree. This is due to the recursive call stack used in both find_lca and dfs functions.
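
Because the recursion depth follows the tree height, a heavily skewed tree can exceed CPython's default recursion limit (typically 1000 frames). As a rough sketch of one way around that, the version below swaps recursion for an explicit stack and a parent map. It is a different technique from the one in this post, and it trades the O(h) call stack for O(n) heap memory. The function name `find_distance_iterative` is my own, and it assumes unique values with both p and q present.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def find_distance_iterative(root, p, q):
    # Map each value to its parent's value using an explicit stack.
    parent = {root.val: None}
    stack = [root]
    while stack:
        node = stack.pop()
        for child in (node.left, node.right):
            if child is not None:
                parent[child.val] = node.val
                stack.append(child)

    # Record how many edges separate p from each of its ancestors.
    steps_from_p = {}
    current, steps = p, 0
    while current is not None:
        steps_from_p[current] = steps
        current = parent[current]
        steps += 1

    # Climb from q until we reach an ancestor of p: that node is the LCA.
    current, steps = q, 0
    while current not in steps_from_p:
        current = parent[current]
        steps += 1
    return steps_from_p[current] + steps


# A left-leaning chain deeper than the default recursion limit.
chain = TreeNode(0)
node = chain
for value in range(1, 3000):
    node.left = TreeNode(value)
    node = node.left
print(find_distance_iterative(chain, 0, 2999))  # 2999
```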

Code

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def findDistance(self, root: Optional[TreeNode], p: int, q: int) -> int:
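        self.lca = None

        def find_lca(node):
            # Returns True if p or q occurs in the subtree rooted at node.
            if not node:
                return False
            L = find_lca(node.left)
            R = find_lca(node.right)
            mid = node.val == q or node.val == p
            # Booleans add as integers: two or more hits mean node is the LCA.
            if L + R + mid >= 2:
                self.lca = node
                return True
            return L or R or mid

        find_lca(root)

        def dfs(node, target, dist):
            # Returns the number of edges from the starting node to target.
            # 0 doubles as "not found", which is safe here because only the
            # starting node itself (the LCA) can be at distance 0.
            if not node:
                return 0
            if node.val == target:
                return dist
            return dfs(node.left, target, dist + 1) or \
                dfs(node.right, target, dist + 1)

        # If p == q, self.lca stays None and both calls return 0.
        dist_q = dfs(self.lca, q, 0)
        dist_p = dfs(self.lca, p, 0)
        return dist_q + dist_p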

Editorial

Approach 1: Lowest Common Ancestor + DFS

class Solution:
    def findDistance(self, root, p, q):
        # Find the lowest common ancestor of p and q.
        lca = self.__find_LCA(root, p, q)
        return self.__depth(lca, p) + self.__depth(lca, q)

    # Function to find the LCA of the given nodes.
    def __find_LCA(self, root, p, q):
        if root is None or root.val == p or root.val == q:
            return root
        left = self.__find_LCA(root.left, p, q)
        right = self.__find_LCA(root.right, p, q)
        if left is not None and right is not None:
            return root
        return left if left is not None else right

    # Function to find the depth of the node with respect to LCA.
    def __depth(self, root, target, current_depth=0):
        # Node not found
        if root is None:
            return -1
        if root.val == target:
            return current_depth

        # Check left subtree
        left_depth = self.__depth(root.left, target, current_depth + 1)
        if left_depth != -1:
            return left_depth

        # Otherwise the answer, if any, is in the right subtree (-1 if absent)
        return self.__depth(root.right, target, current_depth + 1)

Approach 2: Lowest Common Ancestor + BFS

from collections import deque

class Solution:
    def findDistance(self, root, p, q):
        lca = self._find_LCA(root, p, q)
        bfs = deque([lca])
        distance = 0
        depth = 0
        foundp = False
        foundq = False
        while bfs and (not foundp or not foundq):
            size = len(bfs)
            for i in range(size):
                node = bfs.popleft()  # Dequeue the node
                if node.val == p:
                    distance += depth
                    foundp = True
                if node.val == q:
                    distance += depth
                    foundq = True
                if node.left:
                    bfs.append(node.left)  # Enqueue left child
                if node.right:
                    bfs.append(node.right)  # Enqueue right child
            depth += 1
        return distance

    def _find_LCA(self, root, p, q):
        if root is None or root.val == p or root.val == q:
            return root
        left = self._find_LCA(root.left, p, q)
        right = self._find_LCA(root.right, p, q)
        if left and right:
            return root
        return left if left else right

Approach 3: One pass (Based on Lowest Common Ancestor)

class Solution:
    def findDistance(self, root, p, q):
        return self.__distance(root, p, q, 0)

    # Private helper function
    def __distance(self, root, p, q, depth):
        if root is None or p == q:
            return 0

        # If the current node is p or q, search its subtrees with depth reset
        # to 1. A positive result means the other target lies below, so that
        # result is already the answer; otherwise report this node's depth.
        if root.val == p or root.val == q:
            left = self.__distance(root.left, p, q, 1)
            right = self.__distance(root.right, p, q, 1)

            return max(left, right) if left > 0 or right > 0 else depth

        # Otherwise, calculate the ret_distance as sum of ret_distance of left
        # and right subtree.
        left = self.__distance(root.left, p, q, depth + 1)
        right = self.__distance(root.right, p, q, depth + 1)
        ret_distance = left + right

        # If the current node is the LCA, subtract twice its depth.
        if left != 0 and right != 0:
            ret_distance -= 2 * depth

        return ret_distance
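
To cross-check any of the solutions above on small inputs, it can help to have an independent reference that does not rely on the LCA at all. The sketch below treats the tree as an undirected graph and runs a plain BFS from p to q. The helpers `build_tree` (which reads a LeetCode-style level-order list with None for missing children) and `distance_by_graph_bfs` are my own illustrative names, and unique node values are assumed.

```python
from collections import deque


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values):
    """Build a tree from a level-order list; None marks a missing child."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        if values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def distance_by_graph_bfs(root, p, q):
    """Shortest path length in edges between values p and q, or -1."""
    # Record parent-child links in both directions, keyed by value.
    neighbors = {}
    stack = [root]
    while stack:
        node = stack.pop()
        neighbors.setdefault(node.val, [])
        for child in (node.left, node.right):
            if child is not None:
                neighbors[node.val].append(child.val)
                neighbors.setdefault(child.val, []).append(node.val)
                stack.append(child)

    # Standard BFS from p; the first time we pop q we have its distance.
    seen = {p}
    queue = deque([(p, 0)])
    while queue:
        value, dist = queue.popleft()
        if value == q:
            return dist
        for nxt in neighbors.get(value, []):
            if nxt not in seen:
                seen.add(nxt)
                queue.append((nxt, dist + 1))
    return -1  # q was not reachable, so it is not in the tree


tree = build_tree([3, 5, 1, 6, 2, 0, 8, None, None, 7, 4])
print(distance_by_graph_bfs(tree, 5, 0))  # 3
print(distance_by_graph_bfs(tree, 5, 7))  # 2
print(distance_by_graph_bfs(tree, 5, 5))  # 0
```

Since a tree has exactly one simple path between any two nodes, the BFS distance is the same quantity the LCA-based approaches compute.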