Find Distance in a Binary Tree

Updated on 28 May, 2025

Problem Statement

Given the root of a binary tree and two integer values p and q, return the distance between the nodes that hold those values. Every node stores a unique integer, and the distance between two nodes is the number of edges on the path connecting them. If the tree is viewed as an unweighted, undirected graph whose edges are the parent-child links, the task is to find the length, in edges, of the shortest path between the two nodes (in a tree this path is also the only simple path).
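Taken literally, the graph view suggests a breadth-first search. The sketch below illustrates that view only and is not the approach this article develops: it assumes a minimal TreeNode class and a hypothetical helper distance_bfs, links every node to its parent and children, and runs BFS from p until it reaches q. It uses node values as graph keys, which works because values are unique.

```python
from collections import deque
from typing import Dict, List, Optional


class TreeNode:
    """Minimal binary-tree node assumed for this sketch."""

    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def distance_bfs(root: TreeNode, p: int, q: int) -> int:
    """Edges between the nodes holding p and q, via BFS on the tree as an undirected graph."""
    # Undirected adjacency keyed by value (values are unique): parent <-> child links.
    adj: Dict[int, List[int]] = {root.val: []}
    stack = [root]
    while stack:
        node = stack.pop()
        for child in (node.left, node.right):
            if child is not None:
                adj.setdefault(child.val, []).append(node.val)
                adj[node.val].append(child.val)
                stack.append(child)

    # Breadth-first search from p; dist[v] is the edge count from p to v.
    dist = {p: 0}
    queue = deque([p])
    while queue:
        cur = queue.popleft()
        if cur == q:
            return dist[cur]
        for nxt in adj[cur]:
            if nxt not in dist:
                dist[nxt] = dist[cur] + 1
                queue.append(nxt)
    return -1  # only reached if q is not in the tree


# The tree used in the examples.
root = TreeNode(3,
                TreeNode(5, TreeNode(6), TreeNode(2, TreeNode(7), TreeNode(4))),
                TreeNode(1, TreeNode(0), TreeNode(8)))
print(distance_bfs(root, 5, 0))  # 3
print(distance_bfs(root, 5, 7))  # 2
```

This needs O(n) extra memory for the adjacency lists and the distance map, whereas the recursive approaches discussed later need only O(h) stack space for a tree of height h.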

Examples

Example 1

Input:

root = [3,5,1,6,2,0,8,null,null,7,4], p = 5, q = 0

Output:

3

Explanation:

There are 3 edges between 5 and 0: 5-3-1-0.

Example 2

Input:

root = [3,5,1,6,2,0,8,null,null,7,4], p = 5, q = 7

Output:

2

Explanation:

There are 2 edges between 5 and 7: 5-2-7.

Example 3

Input:

root = [3,5,1,6,2,0,8,null,null,7,4], p = 5, q = 5

Output:

0

Explanation:

The distance between a node and itself is 0.
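The root arrays in these examples list the tree level by level, left to right, with null standing for a missing child. A minimal sketch of rebuilding such a tree in Python, where None plays the role of null; TreeNode and build_tree are hypothetical names chosen for this illustration:

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    """Minimal binary-tree node assumed for this sketch."""

    def __init__(self, val: int) -> None:
        self.val = val
        self.left: Optional["TreeNode"] = None
        self.right: Optional["TreeNode"] = None


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Build a tree from a level-order list in which None marks a missing child."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next two entries are this node's left and right children.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


root = build_tree([3, 5, 1, 6, 2, 0, 8, None, None, 7, 4])
print(root.left.right.left.val)  # 7, reached by the path 3-5-2-7
```

Missing children consume a slot in the list but are never enqueued, which is why the two nulls after 6 and 2's position in the queue line up so that 7 and 4 become the children of 2.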

Constraints

  • The number of nodes in the tree is in the range [1, 10^4].
  • 0 <= Node.val <= 10^9
  • All Node.val are unique.
  • p and q are values in the tree.

Approach and Intuition

A standard way to calculate the distance between two nodes in a binary tree is to find their Lowest Common Ancestor (LCA). The LCA of two nodes p and q is the deepest node that has both p and q as descendants (where a node can be a descendant of itself).

  1. Identify the LCA:

    • Start at the root and recursively determine which subtree (left or right) contains each value. The deepest node at which the paths to p and q diverge is the LCA; if one value is an ancestor of the other, that ancestor is the LCA.
  2. Calculate the distance:

    • Once the LCA is established, the problem reduces to finding the distance from the LCA to each node (p and q). This can be achieved by starting at the LCA and counting the steps needed to reach each of the nodes.
    • The total distance between p and q is then the sum of the distances from the LCA to p and from the LCA to q.
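A minimal Python sketch of these two steps, assuming a simple TreeNode class and the hypothetical helpers lowest_common_ancestor, depth_from and find_distance. The LCA routine relies on the guarantee that both p and q are present in the tree.

```python
from typing import Optional


class TreeNode:
    """Minimal binary-tree node assumed for this sketch."""

    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def lowest_common_ancestor(node: Optional[TreeNode], p: int, q: int) -> Optional[TreeNode]:
    """LCA of values p and q in the subtree rooted at node (None if neither occurs)."""
    if node is None or node.val == p or node.val == q:
        return node
    left = lowest_common_ancestor(node.left, p, q)
    right = lowest_common_ancestor(node.right, p, q)
    if left is not None and right is not None:
        return node  # p and q lie in different subtrees: their paths diverge here
    return left if left is not None else right


def depth_from(node: Optional[TreeNode], target: int, depth: int = 0) -> int:
    """Edges from node down to the node holding target, or -1 if target is not below node."""
    if node is None:
        return -1
    if node.val == target:
        return depth
    found = depth_from(node.left, target, depth + 1)
    if found != -1:
        return found
    return depth_from(node.right, target, depth + 1)


def find_distance(root: TreeNode, p: int, q: int) -> int:
    # Step 1: locate the LCA. Step 2: add the edge counts from the LCA to p and to q.
    lca = lowest_common_ancestor(root, p, q)
    return depth_from(lca, p) + depth_from(lca, q)


root = TreeNode(3,
                TreeNode(5, TreeNode(6), TreeNode(2, TreeNode(7), TreeNode(4))),
                TreeNode(1, TreeNode(0), TreeNode(8)))
print(find_distance(root, 5, 0))  # 3
print(find_distance(root, 5, 7))  # 2
print(find_distance(root, 5, 5))  # 0
```

When p equals q, the LCA is that node itself and both depths are 0, so no special case is needed.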

Simplified Steps:

  • Find the path from the root to p and the path from the root to q.
  • The last common node in both paths is the LCA.
  • From the LCA, count the number of edges to p and q.
  • Sum the counts to get the total distance.
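The path-based steps can be sketched as follows, under the same assumptions about TreeNode; path_to and distance_via_paths are hypothetical names. Each path stores nodes from the root down to the target, so the shared prefix ends at the LCA and the distance is whatever remains of each path after that prefix.

```python
from typing import List, Optional


class TreeNode:
    """Minimal binary-tree node assumed for this sketch."""

    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def path_to(root: Optional[TreeNode], target: int) -> List[TreeNode]:
    """Nodes on the path from root to the node holding target (empty if absent)."""
    path: List[TreeNode] = []

    def dfs(node: Optional[TreeNode]) -> bool:
        if node is None:
            return False
        path.append(node)
        if node.val == target or dfs(node.left) or dfs(node.right):
            return True
        path.pop()  # target is not in this subtree: backtrack
        return False

    dfs(root)
    return path


def distance_via_paths(root: TreeNode, p: int, q: int) -> int:
    path_p = path_to(root, p)
    path_q = path_to(root, q)
    # Count the shared prefix; its last node is the LCA.
    common = 0
    while (common < len(path_p) and common < len(path_q)
           and path_p[common] is path_q[common]):
        common += 1
    # Edges from the LCA to p plus edges from the LCA to q.
    return (len(path_p) - common) + (len(path_q) - common)


root = TreeNode(3,
                TreeNode(5, TreeNode(6), TreeNode(2, TreeNode(7), TreeNode(4))),
                TreeNode(1, TreeNode(0), TreeNode(8)))
print(distance_via_paths(root, 5, 0))  # 3: paths [3, 5] and [3, 1, 0] share only 3
```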

This approach visits each node a constant number of times, so it runs in O(n) time for a tree with n nodes and uses O(h) extra space for the recursion stack or the stored paths, where h is the height of the tree. Because every value in the tree is unique, each of p and q identifies exactly one node, so there is no need to handle duplicate values when searching for them.

Solutions

cpp
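#include <algorithm>  // std::max

// Assumes the usual TreeNode definition with members val, left and right.
class Solution {
public:
    // Returns the number of edges between the nodes holding start and end.
    int calculateDistance(TreeNode* root, int start, int end) {
        return findDistance(root, start, end, 0);
    }

private:
    // Returns 0 when no target lies in this subtree (or start == end).
    // Otherwise returns the depth of the target found here, or, once both
    // targets have been found below this node, the distance between them.
    int findDistance(TreeNode* node, int start, int end, int currentDepth) {
        if (node == nullptr || start == end) {
            return 0;
        }

        // This node is one target: look for the other one below it,
        // counting depth afresh from this node.
        if (node->val == start || node->val == end) {
            int distanceLeft = findDistance(node->left, start, end, 1);
            int distanceRight = findDistance(node->right, start, end, 1);
            return (distanceLeft > 0 || distanceRight > 0)
                       ? std::max(distanceLeft, distanceRight)
                       : currentDepth;
        }

        int leftDepth = findDistance(node->left, start, end, currentDepth + 1);
        int rightDepth = findDistance(node->right, start, end, currentDepth + 1);
        int combinedDepth = leftDepth + rightDepth;

        // One target on each side: this node is the LCA, so remove the
        // root-to-LCA segment that both depths include. Otherwise the
        // non-zero side (if any) is passed up unchanged.
        if (leftDepth != 0 && rightDepth != 0) {
            combinedDepth -= 2 * currentDepth;
        }
        return combinedDepth;
    }
};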

In the provided C++ solution, the goal is to find the distance between two nodes in a binary tree. The solution implements a recursive approach to determine the shortest path between the specified start and end node values.

  • The primary function, calculateDistance, starts the process.
  • It calls the helper findDistance with the tree's root, the target values start and end, and an initial depth of zero.
  • findDistance returns 0 when the current node is nullptr, which signals that no target lies in that subtree, or when start and end are equal, because a node is at distance 0 from itself.
  • If the current node's value matches start or end, the function searches both children for the other target, restarting the depth count at 1. A non-zero result from either child is the depth of the other target below the current node, so the larger of the two is returned; if both are zero, the function returns currentDepth, the current node's depth measured from the root (or from the other target, when that target is an ancestor).
  • Otherwise, it recurses into both children with currentDepth + 1, so a non-zero result from a child is the depth of a target found in that subtree.
  • If both recursive calls return non-zero values, one target lies in each subtree and the current node is the LCA. The sum of the two depths counts the path from the root to the LCA twice, so the function subtracts 2 * currentDepth. If only one side is non-zero, that value is passed up unchanged.
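In the branch where both children return non-zero values, the two results are the root depths d(p) and d(q) of the targets, and currentDepth is the depth d(a) of their LCA a, with the root at depth 0. The subtraction is therefore the identity below, checked against the example tree; the pair p = 7, q = 4 is an extra illustration, not one of the listed examples.

```latex
\[
\operatorname{dist}(p, q) = \bigl(d(p) - d(a)\bigr) + \bigl(d(q) - d(a)\bigr) = d(p) + d(q) - 2\,d(a)
\]
% Example 1 (p = 5, q = 0): a = 3, d(a) = 0, d(5) = 1, d(0) = 2
\[
\operatorname{dist}(5, 0) = 1 + 2 - 2 \cdot 0 = 3
\]
% p = 7, q = 4: a = 2, d(a) = 2, d(7) = d(4) = 3
\[
\operatorname{dist}(7, 4) = 3 + 3 - 2 \cdot 2 = 2
\]
```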

Because each node is visited once, the traversal runs in O(n) time and uses O(h) stack space, where h is the height of the tree. Note that the helper's return value carries two meanings: the depth of a single target found so far or, once both targets have been seen below a node, the final distance, which every ancestor simply passes up because its other subtree returns 0.

java
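// Assumes the usual TreeNode class with fields val, left and right.
public class Solution {

    // Returns the number of edges between the nodes holding n1 and n2.
    public int calculateDistance(TreeNode root, int n1, int n2) {
        return findDepth(root, n1, n2, 0);
    }

    // Returns 0 if neither target is in this subtree (or n1 == n2); otherwise a
    // target's depth or, once both targets are found, the distance between them.
    private int findDepth(TreeNode node, int n1, int n2, int level) {
        if (node == null || n1 == n2) {
            return 0;
        }

        // Current node is a target: search below it for the other, restarting at depth 1.
        if (node.val == n1 || node.val == n2) {
            int leftDepth = findDepth(node.left, n1, n2, 1);
            int rightDepth = findDepth(node.right, n1, n2, 1);

            return (leftDepth > 0 || rightDepth > 0) ? Math.max(leftDepth, rightDepth) : level;
        }

        int left = findDepth(node.left, n1, n2, level + 1);
        int right = findDepth(node.right, n1, n2, level + 1);
        int calculatedDistance = left + right;

        // A target on each side means this node is the LCA: drop the doubled root-to-LCA part.
        if (left != 0 && right != 0) {
            calculatedDistance -= 2 * level;
        }

        return calculatedDistance;
    }
}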

In this solution, you calculate the distance between two nodes in a binary tree using recursive depth-first search (DFS) in Java. The algorithm operates by comparing the node values with target values (n1 and n2). Here's how the algorithm in the Solution class approaches the problem:

  • The calculateDistance method initializes the process by calling findDepth with the root of the tree and the target node values, starting at level 0.
  • The findDepth method includes several checks:
    • It returns 0 if the current node is null (no target in this subtree) or if the two target values are the same (a node is at distance 0 from itself).
    • If the current node's value matches n1 or n2, it searches the left and right subtrees for the other target, restarting the level at 1, and returns the depth at which that target is found; if neither subtree contains it, it returns the current level.
    • If one target is found in each subtree, the current node is the lowest common ancestor (LCA). Both returned values are depths measured from the root, so the method subtracts 2 * level to remove the root-to-LCA segment that was counted twice.
  • The level increases by 1 with each recursive call into a child.

The method relies on a single recursive traversal: it locates both targets and combines their levels at their lowest common ancestor, rather than exploring or storing every root-to-node path.

python
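class Solution:
    # Assumes the usual TreeNode class with attributes val, left and right.
    def getDistance(self, root, node1, node2):
        """Return the number of edges between the nodes holding node1 and node2."""
        return self._findDepth(root, node1, node2, 0)

    def _findDepth(self, node, val1, val2, level):
        # 0 means "no target in this subtree"; equal values are distance 0 apart.
        if node is None or val1 == val2:
            return 0

        # Current node is a target: look below it for the other, restarting at depth 1.
        if node.val == val1 or node.val == val2:
            left_depth = self._findDepth(node.left, val1, val2, 1)
            right_depth = self._findDepth(node.right, val1, val2, 1)

            return max(left_depth, right_depth) if left_depth > 0 or right_depth > 0 else level

        left_depth = self._findDepth(node.left, val1, val2, level + 1)
        right_depth = self._findDepth(node.right, val1, val2, level + 1)
        total_depth = left_depth + right_depth

        # One target in each subtree: this node is the LCA.
        if left_depth != 0 and right_depth != 0:
            total_depth -= 2 * level

        return total_depth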

The provided solution written in Python implements a method to find the distance between two nodes in a binary tree. The main function getDistance starts the process by calling _findDepth, which traverses the tree recursively.

  • The algorithm returns a distance of 0 if the current node is None (no target in this subtree) or if the two target values are the same.
  • If the current node equals one of the target values (node1 or node2), the algorithm searches both child subtrees for the other target, starting again with a depth of 1, and returns the depth at which it is found, or the current level if it is not below this node.
  • If the current node's value matches neither target value, it recursively explores both left and right subtrees, incrementing the level with each recursive call.
  • The algorithm sums the depths returned from the left and right subtree searches.
  • If both subtrees return a non-zero depth (meaning the two targets lie in different branches of the current node, which is therefore their LCA), the algorithm subtracts twice the current level, because both depths include the path from the root to the LCA.

Each node is visited at most once, so the traversal runs in O(n) time and O(h) stack space, where h is the height of the tree. The result is the number of edges on the path between node1 and node2.
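With up to 10^4 nodes, a chain-shaped tree can make the recursive _findDepth exceed CPython's default recursion limit of 1000 (see sys.getrecursionlimit). One hedged alternative, not part of the original solution, avoids recursion entirely: record each node's parent and depth with an explicit stack, then lift the two targets toward their LCA, counting one edge per step. The names TreeNode and distance_iterative are assumptions of this sketch.

```python
from typing import Dict, Optional


class TreeNode:
    """Minimal binary-tree node assumed for this sketch."""

    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def distance_iterative(root: TreeNode, p: int, q: int) -> int:
    """Edges between the nodes holding p and q, using parent links instead of recursion."""
    # The root maps to itself; the climbing loops below never step above it.
    parent: Dict[int, int] = {root.val: root.val}
    depth: Dict[int, int] = {root.val: 0}
    stack = [root]
    while stack:
        node = stack.pop()
        for child in (node.left, node.right):
            if child is not None:
                parent[child.val] = node.val
                depth[child.val] = depth[node.val] + 1
                stack.append(child)

    a, b, steps = p, q, 0
    # Lift the deeper target until both are at the same depth.
    while depth[a] > depth[b]:
        a = parent[a]
        steps += 1
    while depth[b] > depth[a]:
        b = parent[b]
        steps += 1
    # Lift both together until they meet at the LCA.
    while a != b:
        a = parent[a]
        b = parent[b]
        steps += 2
    return steps


# A 5000-node chain built with left children only; no recursion is involved.
nodes = [TreeNode(i) for i in range(5000)]
for i in range(4999):
    nodes[i].left = nodes[i + 1]
print(distance_iterative(nodes[0], 0, 4999))  # 4999
```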
