Open In App

Number of K-length paths in a Tree

Last Updated : 12 Apr, 2023
Improve
Improve
Like Article
Like
Save
Share
Report

Given a tree of N nodes and an integer K, the task is to find the total number of paths having length K.

Examples:

Input: N = 5, K = 2
tree =           1
                   /  \
                 2    5
               /  \
             3    4
Output: 4
Explanation: The paths having length 2 are
1 – 2 – 3, 1 – 2 – 4, 2 – 1 – 5, 3 – 2 – 4

Input: N = 2, K = 2
tree =       1
                /
              2
Output: 0
Explanation: There is no path in the tree having length 2.

Intuition: The main idea is to find the K-length paths from each node and add them.

  1. Find the number of K-length paths ‘originating’ from a given node ‘node’. ‘Originating’ here means, ‘node’ will have the smallest depth among all the nodes in the path. For example, 2-length paths originating from 1 are shown in the below diagram.
  2. Sum above value for all the nodes and it will be the required answer.

Naive Approach: To compute the K-length paths originating from ‘node’ two DFS are used. Say this entire process is: paths_originating_from(node)

  1. Suppose ‘node’ has multiple children and currently child ‘u’ is being processed.
  2. For all the previous children, the frequency of nodes at a particular depth has been calculated. More formally, freq[d] gives the number of nodes at depth ‘d’ when only children of ‘node’ before ‘u’ have been processed.
  3. If there is a node ‘x’ at depth ‘d’, number of K length paths originating from ‘node’ and passing through ‘x’ will be freq[K – d].
  4. The first DFS would contribute to the final answer, and the second DFS would update the freq[] array for future use.
  5. Sum-up ‘paths_originating_from(node)’ for all nodes of the tree, this will be the required answer.

See the image below to understand the 2nd point better.

Below is the implementation of the above approach.

C++




// C++ code to implement above approach
#include <bits/stdc++.h>
using namespace std;
 
// mx_depth: deepest level reached by dfs() for the node being processed
// ans: running count of K-length paths
int mx_depth = 0, ans = 0;
// N: number of nodes in the tree, K: required path length
int N, K;
// freq[d]: number of nodes recorded at depth d below the node being processed
vector<int> freq;
// Adjacency list of the tree, nodes are stored 0-based
vector<vector<int> > g;
 
// This dfs is responsible for calculating ans
// and updating freq vector
void dfs(int node, int par, int depth,
         bool contri)
{
    if (depth > K)
        return;
    mx_depth = max(mx_depth, depth);
 
    if (contri) {
        ans += freq[K - depth];
    }
    else {
        freq[depth]++;
    }
 
    for (auto nebr : g[node]) {
        if (nebr != par) {
            dfs(nebr, node, depth + 1,
                contri);
        }
    }
}
 
// Function to calculate K length paths
// originating from node
void paths_originating_from(int node,
                            int par)
{
    mx_depth = 0;
    freq[0] = 1;
 
    // For every not-removed nebr,
    // calculate its contribution,
    // then update freq vector for it
    for (auto nebr : g[node]) {
        if (nebr != par) {
            dfs(nebr, node, 1, true);
            dfs(nebr, node, 1, false);
        }
    }
 
    // Re-initialize freq vector
    for (int i = 0; i <= mx_depth; ++i) {
        freq[i] = 0;
    }
 
    // Repeat the same for children
    for (auto nebr : g[node]) {
        if (nebr != par) {
            paths_originating_from(nebr,
                                   node);
        }
    }
}
 
// Utility method to add edges to tree
void edge(int a, int b)
{
    a--;
    b--;
    g[a].push_back(b);
    g[b].push_back(a);
}
 
// Driver code
int main()
{
    N = 5, K = 2;
    freq = vector<int>(N);
    g = vector<vector<int> >(N);
 
    edge(1, 2);
    edge(1, 5);
    edge(2, 3);
    edge(2, 4);
   
    paths_originating_from(0, -1);
    cout << ans << endl;
}


Java




import java.io.*;
import java.util.*;
 
public class Main {
  // N: number of nodes, K: required path length,
  // mx_depth: deepest level reached for the node being processed,
  // ans: running count of K-length paths
  static int N, K, mx_depth = 0, ans = 0;

  /**
   * Adds the undirected edge a - b to the adjacency list.
   *
   * @param a first endpoint, 1-based
   * @param b second endpoint, 1-based
   * @param g adjacency list holding 0-based node indices
   */
  static void edge(int a, int b,
                   ArrayList<ArrayList<Integer> > g)
  {
    int u = a - 1;
    int v = b - 1;
    g.get(u).add(v);
    g.get(v).add(u);
  }

  /**
   * Visits {@code node} and everything below it, never going deeper than K.
   *
   * @param node   node being visited
   * @param par    neighbour the traversal came from
   * @param depth  distance of {@code node} from the node being processed
   * @param contri true: add freq[K - depth] to ans for each visited node;
   *               false: record each visited node in freq[depth]
   * @param freq   per-depth node counts
   * @param g      adjacency list of the tree
   */
  static void dfs(int node, int par, int depth,
                  boolean contri, ArrayList<Integer> freq,
                  ArrayList<ArrayList<Integer> > g)
  {
    if (depth > K) {
      return;
    }
    if (depth > mx_depth) {
      mx_depth = depth;
    }

    if (!contri) {
      freq.set(depth, freq.get(depth) + 1);
    }
    else {
      ans += freq.get(K - depth);
    }

    for (int nebr : g.get(node)) {
      if (nebr == par) {
        continue;
      }
      dfs(nebr, node, depth + 1, contri, freq, g);
    }
  }

  /**
   * Counts the K-length paths whose topmost node is {@code node}, then
   * repeats the work for every child of {@code node}.
   *
   * @param node node treated as the top of the paths
   * @param par  neighbour the recursion came from (-1 for the root)
   * @param freq per-depth node counts, cleared again before returning
   * @param g    adjacency list of the tree
   */
  static void
    paths_originating_from(int node, int par,
                           ArrayList<Integer> freq,
                           ArrayList<ArrayList<Integer> > g)
  {
    mx_depth = 0;
    // node itself is the only node at depth 0
    freq.set(0, 1);

    // One child subtree at a time: first match it against what has been
    // recorded so far, then record the subtree itself
    for (int nebr : g.get(node)) {
      if (nebr == par) {
        continue;
      }
      dfs(nebr, node, 1, true, freq, g);
      dfs(nebr, node, 1, false, freq, g);
    }

    // Clear only the freq slots that were touched
    for (int d = mx_depth; d >= 0; d--) {
      freq.set(d, 0);
    }

    // Now treat every child as the topmost node
    for (int nebr : g.get(node)) {
      if (nebr == par) {
        continue;
      }
      paths_originating_from(nebr, node, freq, g);
    }
  }

  /** Builds the 5-node sample tree and prints its number of 2-length paths. */
  public static void main(String[] args)
  {
    N = 5;
    K = 2;
    ArrayList<Integer> freq = new ArrayList<Integer>();
    ArrayList<ArrayList<Integer> > g
      = new ArrayList<ArrayList<Integer> >();
    for (int i = 0; i < N; i++) {
      freq.add(0);
      g.add(new ArrayList<Integer>());
    }

    int[][] edges = { { 1, 2 }, { 1, 5 }, { 2, 3 }, { 2, 4 } };
    for (int[] e : edges) {
      edge(e[0], e[1], g);
    }
    paths_originating_from(0, -1, freq, g);
    System.out.println(ans);
  }
}
 
// This code is contributed by garg28harsh.


Python3




# Python code to implement above approach
# Deepest level reached by dfs() for the node being processed
mx_depth = 0
# Running count of K-length paths
ans = 0
# Number of nodes in the tree
N = 5
# Required path length
K = 2
# freq[d]: number of nodes recorded at depth d below the node being processed
freq = [0] * N
# Adjacency list of the tree, nodes are stored 0-based
g = [[] for _ in range(N)]
 
# This dfs is responsible for calculating ans
# and updating freq vector
 
 
def dfs(node, par, depth, contri):
    """Visit `node` (reached from `par`) and everything below it, up to depth K.

    With `contri` true each visited node adds freq[K - depth] to `ans`;
    with `contri` false each visited node is recorded in freq[depth].
    """
    global mx_depth, ans
    if depth > K:
        return
    if depth > mx_depth:
        mx_depth = depth

    if not contri:
        freq[depth] += 1
    else:
        ans += freq[K - depth]

    for child in g[node]:
        if child == par:
            continue
        dfs(child, node, depth + 1, contri)
 
# Function to calculate K length paths
# originating from node
 
 
def paths_originating_from(node, par):
    """Count the K-length paths whose topmost node is `node`, then recurse.

    `par` is the neighbour the recursion came from (-1 for the root).
    The result is accumulated in the global `ans` by dfs(); freq is left
    all-zero again when the function returns.
    """
    global mx_depth
    mx_depth = 0
    # `node` itself is the only node at depth 0
    freq[0] = 1

    # For every child subtree, first calculate its contribution against
    # the subtrees recorded so far, then record the subtree in freq
    for nebr in g[node]:
        if nebr != par:
            dfs(nebr, node, 1, True)
            dfs(nebr, node, 1, False)

    # Re-initialize the touched part of freq IN PLACE. Rebinding freq to a
    # new list of length mx_depth + 1 would shrink it, and a later
    # freq[K - depth] lookup would raise IndexError.
    for d in range(mx_depth + 1):
        freq[d] = 0

    # Repeat the same for children
    for nebr in g[node]:
        if nebr != par:
            paths_originating_from(nebr, node)
 
# Utility method to add edges to tree
 
 
def edge(a, b):
    """Add the undirected edge a - b; a and b are 1-based, g is 0-based."""
    u, v = a - 1, b - 1
    g[u].append(v)
    g[v].append(u)
 
 
# Driver code
# Sample tree: 1-2, 1-5, 2-3, 2-4
for a, b in ((1, 2), (1, 5), (2, 3), (2, 4)):
    edge(a, b)

# Start at node 0, which has no parent
paths_originating_from(0, -1)
print(ans)


Javascript




// Javascript code to implement above approach
 
// Shared state used by dfs() and paths_originating_from():
//   mx_depth - deepest level reached for the node being processed
//   ans      - running count of K-length paths
//   N, K     - number of nodes and required path length (set by the driver)
//   freq     - freq[d] is the number of nodes recorded at depth d
// The counter must be declared as mx_depth, the name every function uses;
// otherwise the functions create an implicit global (a ReferenceError in
// strict mode).
let mx_depth = 0, ans = 0, N, K, freq = [];
// Adjacency list of the tree, nodes are stored 0-based
let g = new Array(1000).fill(0).map(() => new Array(0))
 
// This dfs is responsible for calculating ans
// and updating freq vector
// Visits 'node' (reached from 'par') and everything below it, never going
// deeper than K. With contri true each visited node adds freq[K - depth]
// to ans; with contri false each visited node is recorded in freq[depth].
function dfs(node, par, depth, contri) {
    if (depth > K)
        return;
    mx_depth = Math.max(mx_depth, depth);

    if (!contri) {
        freq[depth]++;
    }
    else {
        ans += freq[K - depth];
    }

    const around = g[node];
    for (let i = 0; i < around.length; i++) {
        const nebr = around[i];
        if (nebr == par) {
            continue;
        }
        dfs(nebr, node, depth + 1, contri);
    }
}
 
// Function to calculate K length paths
// originating from node
// Counts the K-length paths whose topmost node is 'node', then repeats the
// work for every child of 'node' ('par' is the neighbour we came from).
function paths_originating_from(node, par) {
    mx_depth = 0;
    // 'node' itself is the only node at depth 0
    freq[0] = 1;

    // One child subtree at a time: first match it against what has been
    // recorded so far, then record the subtree itself
    const around = g[node];
    for (let i = 0; i < around.length; i++) {
        const nebr = around[i];
        if (nebr == par) {
            continue;
        }
        dfs(nebr, node, 1, true);
        dfs(nebr, node, 1, false);
    }

    // Clear only the freq slots that were touched
    for (let d = mx_depth; d >= 0; d--) {
        freq[d] = 0;
    }

    // Now treat every child as the topmost node
    for (let i = 0; i < around.length; i++) {
        const nebr = around[i];
        if (nebr == par) {
            continue;
        }
        paths_originating_from(nebr, node);
    }
}
 
// Utility method to add edges to tree
// Adds the undirected edge a - b; a and b are 1-based, g is 0-based
function edge(a, b) {
    const u = a - 1;
    const v = b - 1;
    g[u].push(v);
    g[v].push(u);
}
 
// Driver code
 
N = 5;
K = 2;
freq = Array.from({ length: N }, () => 0);
g = Array.from({ length: 1000 }, () => []);

// Sample tree: 1-2, 1-5, 2-3, 2-4
for (const [a, b] of [[1, 2], [1, 5], [2, 3], [2, 4]]) {
    edge(a, b);
}

// Start at node 0, which has no parent
paths_originating_from(0, -1);
console.log(ans)


C#




// C# code to implement above approach
using System;
using System.Collections.Generic;
 
class Program {
    // mx_depth: deepest level reached for the node being processed
    // ans: running count of K-length paths
    static int mx_depth = 0, ans = 0;
    // N: number of nodes, K: required path length
    static int N, K;
    // freq[d]: number of nodes recorded at depth d
    static List<int> freq;
    // Adjacency list of the tree, nodes are stored 0-based
    static List<List<int> > g;

    // Visits 'node' (reached from 'par') and everything below it, never
    // going deeper than K. With contri true each visited node adds
    // freq[K - depth] to ans; with contri false each visited node is
    // recorded in freq[depth].
    static void dfs(int node, int par, int depth,
                    bool contri)
    {
        if (depth > K)
            return;
        if (depth > mx_depth)
            mx_depth = depth;

        if (!contri) {
            freq[depth] += 1;
        }
        else {
            ans += freq[K - depth];
        }

        List<int> around = g[node];
        for (int i = 0; i < around.Count; i++) {
            if (around[i] == par)
                continue;
            dfs(around[i], node, depth + 1, contri);
        }
    }

    // Counts the K-length paths whose topmost node is 'node', then repeats
    // the work for every child of 'node' ('par' is where we came from).
    static void paths_originating_from(int node, int par)
    {
        mx_depth = 0;
        // 'node' itself is the only node at depth 0
        freq[0] = 1;

        // One child subtree at a time: first match it against what has
        // been recorded so far, then record the subtree itself
        List<int> around = g[node];
        for (int i = 0; i < around.Count; i++) {
            if (around[i] == par)
                continue;
            dfs(around[i], node, 1, true);
            dfs(around[i], node, 1, false);
        }

        // Clear only the freq slots that were touched
        for (int d = mx_depth; d >= 0; --d) {
            freq[d] = 0;
        }

        // Now treat every child as the topmost node
        for (int i = 0; i < around.Count; i++) {
            if (around[i] == par)
                continue;
            paths_originating_from(around[i], node);
        }
    }

    // Adds the undirected edge a - b; a and b are 1-based, g is 0-based
    static void edge(int a, int b)
    {
        int u = a - 1;
        int v = b - 1;
        g[u].Add(v);
        g[v].Add(u);
    }

    // Builds the 5-node sample tree and prints its number of 2-length paths
    static void Main()
    {
        N = 5;
        K = 2;
        freq = new List<int>(new int[N]);
        g = new List<List<int> >(N);
        for (int i = 0; i < N; i++) {
            g.Add(new List<int>());
        }

        int[,] edges = { { 1, 2 }, { 1, 5 }, { 2, 3 }, { 2, 4 } };
        for (int i = 0; i < edges.GetLength(0); i++) {
            edge(edges[i, 0], edges[i, 1]);
        }

        paths_originating_from(0, -1);
        Console.WriteLine(ans);
    }
}
 
// This code is contributed by rutikbhosale.


Output

4

Time Complexity: O(N * H) where H is the height of the tree which can be N at max
Auxiliary Space: O(N)

Efficient Approach: This approach is based on the concept of Centroid Decomposition. The steps are as follows:

  1. Find the centroid C of the current tree T.
  2. All ‘not-removed’ nodes reachable from C belong to its sub-tree. Call paths_originating_from(C), then mark C as ‘removed’.
  3. Repeat the above process for the component of every ‘not-removed’ neighbor of C.

The following figure shows a tree with current centroid and its sub-tree. Note that nodes with thick borders have already been selected as centroids previously and do not belong to the sub-tree of the current centroid.

Below is the implementation of the above approach.

C++




// C++ code to implement above approach
#include <bits/stdc++.h>
using namespace std;
 
// Struct for centroid decomposition
// Centroid decomposition that counts the paths of exactly k edges in a tree.
// Usage: construct, add the n - 1 edges with edge(), call decompose(0, -1),
// then read ans.
struct CD {
    // n, k     - number of nodes and required path length
    // mx_depth - deepest level reached while processing the current centroid
    // ans      - number of k-length paths counted so far
    // removed  - removed[v] is true once v has been used as a centroid
    // size     - size[v] is the subtree size computed by the last get_size()
    // freq     - freq[d] is the number of nodes recorded at depth d below
    //            the current centroid
    // g        - adjacency list of the tree (0-based)
    int n, k, mx_depth, ans;
    vector<bool> removed;
    vector<int> size, freq;
    vector<vector<int> > g;

    // Creates an edgeless structure for n1 nodes and path length k1
    CD(int n1, int k1)
    {
        n = n1;
        k = k1;
        ans = mx_depth = 0;

        g.resize(n);
        size.resize(n);
        freq.resize(n);
        removed.assign(n, false);
    }

    // Adds the undirected edge u - v (u and v are 1-based)
    void edge(int u, int v)
    {
        u--;
        v--;
        g[u].push_back(v);
        g[v].push_back(u);
    }

    // Computes size[] for the subtree of 'node' (entered from 'par'),
    // ignoring removed nodes, and returns size[node]
    int get_size(int node, int par)
    {
        if (removed[node])
            return 0;
        size[node] = 1;

        for (auto nebr : g[node]) {
            if (nebr != par) {
                size[node] += get_size(nebr,
                                       node);
            }
        }

        return size[node];
    }

    // Returns the centroid of the component of 'node' whose total size
    // is 'sz'; size[] must have been filled by get_size() first
    int get_centroid(int node, int par,
                     int sz)
    {
        for (auto nebr : g[node]) {
            // Move towards the only subtree that is heavier than sz / 2
            if (nebr != par && !removed[nebr]
                && size[nebr] > sz / 2) {
                return get_centroid(nebr,
                                    node, sz);
            }
        }
        return node;
    }

    // Counts the paths through the centroid of the component containing
    // 'node', removes that centroid and recurses into what is left
    void decompose(int node, int par)
    {
        get_size(node, -1);

        // c is centroid of subtree 'node'
        int c = get_centroid(node, par,
                             size[node]);

        // Find paths_originating_from 'c'
        paths_originating_from(c);

        // Mark this centroid (and only this node) as removed
        removed[c] = true;

        // Find the centroids of the components around c
        for (auto nebr : g[c]) {
            if (!removed[nebr]) {
                decompose(nebr, c);
            }
        }
    }

    // Visits 'node' and everything not removed below it, up to depth k.
    // contri == true : add freq[k - depth] to ans for each visited node
    // contri == false: record each visited node in freq[depth]
    void dfs(int node, int par, int depth,
             bool contri)
    {
        if (depth > k)
            return;
        mx_depth = max(mx_depth, depth);

        if (contri) {
            // freq has n entries; a depth of n or more cannot exist in
            // the tree, so such a lookup contributes nothing
            if (k - depth < (int)freq.size())
                ans += freq[k - depth];
        }
        else {
            freq[depth]++;
        }

        for (auto nebr : g[node]) {
            if (nebr != par &&
                !removed[nebr]) {
                dfs(nebr, node,
                    depth + 1, contri);
            }
        }
    }

    // Counts the k-length paths that pass through 'node' inside its
    // current (not removed) component
    void paths_originating_from(int node)
    {
        mx_depth = 0;
        // 'node' itself is the only node at depth 0
        freq[0] = 1;

        // For every not-removed nebr,
        // calculate its contribution,
        // then update freq vector for it
        for (auto nebr : g[node]) {
            if (!removed[nebr]) {
                dfs(nebr, node, 1, true);
                dfs(nebr, node, 1, false);
            }
        }

        // Re-initialize the touched part of freq
        for (int i = 0; i <= mx_depth; ++i) {
            freq[i] = 0;
        }
    }
};
 
// Driver code
int main()
{
    int N = 5, K = 2;
 
    CD cd_s(N, K);
    cd_s.edge(1, 2);
    cd_s.edge(1, 5);
    cd_s.edge(2, 3);
    cd_s.edge(2, 4);
 
    cd_s.decompose(0, -1);
    cout << cd_s.ans;
    return 0;
}


Java




// Java Code to implement above approach
import java.util.*;
 
public class Solution
{

    /**
     * Centroid decomposition that counts the paths of exactly k edges in a
     * tree. Usage: construct, add the n - 1 edges with {@link #edge}, call
     * {@code decompose(0, -1)}, then read {@code ans}.
     */
    static class CD
    {

        // n, k     - number of nodes and required path length
        // mx_depth - deepest level reached while processing the
        //            current centroid
        // ans      - number of k-length paths counted so far
        int n, k, mx_depth, ans;
        // removed[v] is true once v has been used as a centroid
        boolean[] removed;
        // size[v]: subtree size computed by the last get_size();
        // freq[d]: number of nodes recorded at depth d below the
        // current centroid
        int[] size, freq;
        // Adjacency list of the tree (0-based)
        ArrayList<ArrayList<Integer> > g;

        /**
         * Creates an edgeless structure.
         *
         * @param n1 number of nodes
         * @param k1 required path length
         */
        CD(int n1, int k1)
        {
            n = n1;
            k = k1;
            ans = mx_depth = 0;

            g = new ArrayList<>();
            for (int i = 0; i < n; i++) {
                g.add(new ArrayList<>());
            }
            size = new int[n];
            freq = new int[n];
            removed = new boolean[n];
        }

        /** Adds the undirected edge u - v; u and v are 1-based. */
        public void edge(int u, int v)
        {
            u--;
            v--;
            g.get(u).add(v);
            g.get(v).add(u);
        }

        /**
         * Fills size[] for the subtree of {@code node} entered from
         * {@code par}, ignoring removed nodes.
         *
         * @return size[node], or 0 when {@code node} is removed
         */
        public int get_size(int node, int par)
        {
            if (removed[node]) {
                return 0;
            }
            size[node] = 1;

            for (int nebr : g.get(node)) {
                if (nebr != par) {
                    size[node] += get_size(nebr, node);
                }
            }

            return size[node];
        }

        /**
         * Returns the centroid of the component of {@code node} whose
         * total size is {@code sz}; size[] must have been filled by
         * get_size() first.
         */
        int get_centroid(int node, int par, int sz)
        {
            for (int nebr : g.get(node)) {
                // Move towards the only subtree heavier than sz / 2
                if (nebr != par && !removed[nebr]
                    && size[nebr] > sz / 2) {
                    return get_centroid(nebr, node, sz);
                }
            }
            return node;
        }

        /**
         * Counts the paths through the centroid of the component that
         * contains {@code node}, removes that centroid and recurses into
         * the remaining components.
         *
         * @param node any not-removed node of the component
         * @param par  the previously removed centroid, or -1
         */
        public void decompose(int node, int par)
        {
            get_size(node, -1);

            // c is centroid of subtree 'node'
            int c = get_centroid(node, par, size[node]);

            // Find paths_originating_from 'c'
            paths_originating_from(c);

            // Mark this centroid (and only this node) as removed
            removed[c] = true;

            // Find other centroids
            for (int nebr : g.get(c)) {
                if (!removed[nebr]) {
                    decompose(nebr, c);
                }
            }
        }

        /**
         * Visits {@code node} and everything not removed below it, up to
         * depth k.
         *
         * @param contri true: add freq[k - depth] to ans for each visited
         *               node; false: record each visited node in
         *               freq[depth]
         */
        void dfs(int node, int par, int depth,
                 boolean contri)
        {
            if (depth > k) {
                return;
            }
            mx_depth = Math.max(mx_depth, depth);

            if (contri) {
                // freq has n entries; a depth of n or more cannot exist
                // in the tree, so such a lookup contributes nothing
                if (k - depth < freq.length) {
                    ans += freq[k - depth];
                }
            }
            else {
                freq[depth]++;
            }

            for (int nebr : g.get(node)) {
                if (nebr != par && !removed[nebr]) {
                    dfs(nebr, node, depth + 1, contri);
                }
            }
        }

        /**
         * Counts the k-length paths that pass through {@code node} inside
         * its current (not removed) component.
         */
        void paths_originating_from(int node)
        {
            mx_depth = 0;
            // node itself is the only node at depth 0
            freq[0] = 1;

            // For every not-removed nebr,
            // calculate its contribution,
            // then update freq vector for it
            for (int nebr : g.get(node)) {
                if (!removed[nebr]) {
                    dfs(nebr, node, 1, true);
                    dfs(nebr, node, 1, false);
                }
            }

            // Re-initialize the touched part of freq
            for (int i = 0; i <= mx_depth; ++i) {
                freq[i] = 0;
            }
        }
    }

    /** Builds the 5-node sample tree and prints its 2-length path count. */
    public static void main(String[] args)
    {
        int N = 5, K = 2;

        CD cd_s = new CD(N, K);
        cd_s.edge(1, 2);
        cd_s.edge(1, 5);
        cd_s.edge(2, 3);
        cd_s.edge(2, 4);

        cd_s.decompose(0, -1);
        System.out.println(cd_s.ans);
    }
}
 
// This code is contributed by karandeep1234


Python3




# Python code to implement above approach
 
# Struct for centroid decomposition
 
 
class CD:
    """Centroid decomposition that counts the paths of exactly k edges.

    Usage: construct, add the n - 1 edges with edge(), call
    decompose(0, -1), then read `ans`.
    """

    def __init__(self, n1, k1):
        """Create an edgeless structure for n1 nodes and path length k1."""
        self.n = n1
        self.k = k1
        # ans: paths counted so far; mx_depth: deepest level reached while
        # processing the current centroid
        self.ans = self.mx_depth = 0
        # removed[v] is True once v has been used as a centroid
        self.removed = [False]*n1
        # size[v]: subtree size computed by the last get_size()
        self.size = [0]*n1
        # freq[d]: number of nodes recorded at depth d below the centroid
        self.freq = [0]*n1
        # Adjacency list of the tree (0-based)
        self.g = [[] for i in range(n1)]

    def edge(self, u, v):
        """Add the undirected edge u - v; u and v are 1-based."""
        u -= 1
        v -= 1
        self.g[u].append(v)
        self.g[v].append(u)

    def get_size(self, node, par):
        """Fill size[] below `node`, ignoring removed nodes; return size[node]."""
        if self.removed[node]:
            return 0
        self.size[node] = 1

        for nebr in self.g[node]:
            if nebr != par:
                self.size[node] += self.get_size(nebr, node)

        return self.size[node]

    def get_centroid(self, node, par, sz):
        """Return the centroid of the component of `node` with total size `sz`.

        size[] must have been filled by get_size() first.
        """
        for nebr in self.g[node]:
            # Move towards the only subtree heavier than half the component
            if nebr != par and not self.removed[nebr] and self.size[nebr] > sz / 2:
                return self.get_centroid(nebr, node, sz)
        return node

    def decompose(self, node, par):
        """Count paths through the centroid of `node`'s component, then recurse.

        `par` is the previously removed centroid, or -1 for the first call.
        """
        self.get_size(node, -1)

        # c is centroid of subtree 'node'
        c = self.get_centroid(node, par, self.size[node])

        # Find paths_originating_from 'c'
        self.paths_originating_from(c)

        # Mark this centroid (and only this node) as removed
        self.removed[c] = True

        # Find other centroids among the neighbours of c
        for nebr in self.g[c]:
            if not self.removed[nebr]:
                self.decompose(nebr, c)

    def dfs(self, node, par, depth, contri):
        """Visit `node` and everything not removed below it, up to depth k.

        With `contri` true each visited node adds freq[k - depth] to ans;
        with `contri` false each visited node is recorded in freq[depth].
        """
        if depth > self.k:
            return
        self.mx_depth = max(self.mx_depth, depth)

        if contri:
            # freq has n entries; a depth of n or more cannot exist in the
            # tree, so such a lookup contributes nothing
            if self.k - depth < len(self.freq):
                self.ans += self.freq[self.k - depth]
        else:
            self.freq[depth] += 1

        for nebr in self.g[node]:
            if nebr != par and not self.removed[nebr]:
                self.dfs(nebr, node, depth + 1, contri)

    def paths_originating_from(self, node):
        """Count the k-length paths through `node` inside its component."""
        self.mx_depth = 0
        # `node` itself is the only node at depth 0
        self.freq[0] = 1

        # For every not-removed nebr,
        # calculate its contribution,
        # then update freq vector for it
        for nebr in self.g[node]:
            if not self.removed[nebr]:
                self.dfs(nebr, node, 1, True)
                self.dfs(nebr, node, 1, False)

        # Re-initialize the touched part of freq
        for i in range(self.mx_depth+1):
            self.freq[i] = 0
 
 
# Driver code
if __name__ == "__main__":
    N, K = 5, 2

    # Sample tree: 1-2, 1-5, 2-3, 2-4
    cd_s = CD(N, K)
    for a, b in ((1, 2), (1, 5), (2, 3), (2, 4)):
        cd_s.edge(a, b)
    cd_s.decompose(0, -1)
    print(cd_s.ans)


Javascript




// Centroid decomposition that counts the paths of exactly k edges in a tree.
// Usage: construct, add the n - 1 edges with edge(), call decompose(0, -1),
// then read ans.
class CD {
    // Creates an edgeless structure for n1 nodes and path length k1
    constructor(n1, k1) {
        this.n = n1;
        this.k = k1;
        // ans: paths counted so far; mx_depth: deepest level reached while
        // processing the current centroid
        this.ans = this.mx_depth = 0;
        // removed[v] is true once v has been used as a centroid
        this.removed = Array(n1).fill(false);
        // size[v]: subtree size computed by the last get_size()
        this.size = Array(n1).fill(0);
        // freq[d]: number of nodes recorded at depth d below the centroid
        this.freq = Array(n1).fill(0);
        // Adjacency list of the tree (0-based)
        this.g = Array(n1).fill().map(() => []);
    }

    // Adds the undirected edge u - v; u and v are 1-based
    edge(u, v) {
        u -= 1;
        v -= 1;
        this.g[u].push(v);
        this.g[v].push(u);
    }

    // Fills size[] below 'node', ignoring removed nodes; returns size[node]
    get_size(node, par) {
        if (this.removed[node]) {
            return 0;
        }
        this.size[node] = 1;

        for (let nebr of this.g[node]) {
            if (nebr !== par) {
                this.size[node] += this.get_size(nebr, node);
            }
        }

        return this.size[node];
    }

    // Returns the centroid of the component of 'node' with total size 'sz';
    // size[] must have been filled by get_size() first
    get_centroid(node, par, sz) {
        for (let nebr of this.g[node]) {
            // Move towards the only subtree heavier than half the component
            if (!this.removed[nebr] && nebr !== par && this.size[nebr] > sz / 2) {
                return this.get_centroid(nebr, node, sz);
            }
        }
        return node;
    }

    // Counts the paths through the centroid of the component containing
    // 'node', removes that centroid and recurses into what is left
    decompose(node, par) {
        this.get_size(node, -1);
        let c = this.get_centroid(node, par, this.size[node]);
        this.paths_originating_from(c);
        // Mark this centroid (and only this node) as removed
        this.removed[c] = true;

        // Find other centroids among the neighbours of c
        for (let nebr of this.g[c]) {
            if (!this.removed[nebr]) {
                this.decompose(nebr, c);
            }
        }
    }

    // Visits 'node' and everything not removed below it, up to depth k.
    // contri true : add freq[k - depth] to ans for each visited node
    // contri false: record each visited node in freq[depth]
    dfs(node, par, depth, contri) {
        if (depth > this.k) {
            return;
        }
        this.mx_depth = Math.max(this.mx_depth, depth);

        if (contri) {
            // freq has n entries; reading past the end would yield
            // undefined and turn ans into NaN
            if (this.k - depth < this.freq.length) {
                this.ans += this.freq[this.k - depth];
            }
        } else {
            this.freq[depth] += 1;
        }

        for (let nebr of this.g[node]) {
            if (nebr !== par && !this.removed[nebr]) {
                this.dfs(nebr, node, depth + 1, contri);
            }
        }
    }

    // Counts the k-length paths through 'node' inside its component
    paths_originating_from(node) {
        this.mx_depth = 0;
        // 'node' itself is the only node at depth 0
        this.freq[0] = 1;

        // For every not-removed neighbour, first count its contribution,
        // then record its subtree in freq
        for (let nebr of this.g[node]) {
            if (!this.removed[nebr]) {
                this.dfs(nebr, node, 1, true);
                this.dfs(nebr, node, 1, false);
            }
        }

        // Re-initialize the touched part of freq
        for (let i = 0; i < this.mx_depth + 1; i++) {
            this.freq[i] = 0;
        }
    }
}
 
const N = 5;
const K = 2;

// Sample tree: 1-2, 1-5, 2-3, 2-4
const cd_s = new CD(N, K);
for (const [u, v] of [[1, 2], [1, 5], [2, 3], [2, 4]]) {
    cd_s.edge(u, v);
}
cd_s.decompose(0, -1);
console.log(cd_s.ans);


C#




// C# code to implement above approach
using System;
using System.Collections.Generic;
 
// Centroid decomposition that counts the paths of exactly k edges in a tree.
// Usage: construct, add the n - 1 edges with Edge(), then call Solve().
class CD {
    // n, k     - number of nodes and required path length
    // mx_depth - deepest level reached while processing the current centroid
    // ans      - number of k-length paths counted so far
    private int n, k, mx_depth, ans;
    // removed[v] is true once v has been used as a centroid
    private List<bool> removed;
    // size[v]: subtree size computed by the last GetSize();
    // freq[d]: number of nodes recorded at depth d below the centroid
    private List<int> size, freq;
    // Adjacency list of the tree (0-based)
    private List<List<int> > g;

    // Creates an edgeless structure for n1 nodes and path length k1
    public CD(int n1, int k1)
    {
        n = n1;
        k = k1;
        ans = mx_depth = 0;

        g = new List<List<int> >();
        for (int i = 0; i < n; i++)
            g.Add(new List<int>());

        size = new List<int>();
        size.AddRange(new int[n]);

        freq = new List<int>();
        freq.AddRange(new int[n]);

        removed = new List<bool>();
        removed.AddRange(new bool[n]);
    }

    // Adds the undirected edge u - v; u and v are 1-based
    public void Edge(int u, int v)
    {
        u--;
        v--;
        g[u].Add(v);
        g[v].Add(u);
    }

    // Fills size[] below 'node', ignoring removed nodes;
    // returns size[node]
    private int GetSize(int node, int par)
    {
        if (removed[node])
            return 0;
        size[node] = 1;

        foreach(int nebr in g[node])
        {
            if (nebr != par)
                size[node] += GetSize(nebr, node);
        }
        return size[node];
    }

    // Returns the centroid of the component of 'node' with total size
    // 'sz'; size[] must have been filled by GetSize() first
    private int GetCentroid(int node, int par, int sz)
    {
        foreach(int nebr in g[node])
        {
            // Move towards the only subtree heavier than sz / 2
            if (!removed[nebr] && nebr != par
                && size[nebr] > sz / 2)
                return GetCentroid(nebr, node, sz);
        }
        return node;
    }

    // Counts the paths through the centroid of the component containing
    // 'node', removes that centroid and recurses into what is left
    private void Decompose(int node, int par)
    {
        GetSize(node, -1);
        // c is centroid of subtree 'node'
        int c = GetCentroid(node, par, size[node]);
        // Find paths_originating_from 'c'
        PathsOriginatingFrom(c);
        // Mark this centroid (and only this node) as removed
        removed[c] = true;
        // Find other centroids among the neighbours of c
        foreach(int nebr in g[c])
        {
            if (!removed[nebr])
                Decompose(nebr, c);
        }
    }

    // Visits 'node' and everything not removed below it, up to depth k.
    // contri true : add freq[k - depth] to ans for each visited node
    // contri false: record each visited node in freq[depth]
    private void Dfs(int node, int par, int depth,
                     bool contri)
    {
        if (depth > k)
            return;
        mx_depth = Math.Max(mx_depth, depth);

        if (contri) {
            // freq has n entries; a depth of n or more cannot exist in
            // the tree, so such a lookup contributes nothing
            if (k - depth < freq.Count)
                ans += freq[k - depth];
        }
        else
            freq[depth]++;

        foreach(int nebr in g[node])
        {
            if (!removed[nebr] && nebr != par)
                Dfs(nebr, node, depth + 1, contri);
        }
    }

    // Counts the k-length paths through 'node' inside its component
    private void PathsOriginatingFrom(int node)
    {
        mx_depth = 0;
        // 'node' itself is the only node at depth 0
        freq[0] = 1;
        // For every not-removed nebr,
        // calculate its contribution,
        // then update freq vector for it

        foreach(int nebr in g[node])
        {
            if (!removed[nebr]) {
                Dfs(nebr, node, 1, true);
                Dfs(nebr, node, 1, false);
            }
        }

        // Re-initialize the touched part of freq
        for (int i = 0; i <= mx_depth; i++)
            freq[i] = 0;
    }

    // Runs the decomposition from node 0 and returns the path count
    public int Solve()
    {
        Decompose(0, -1);
        return ans;
    }
}
// Driver code
class Program {
    // Builds the 5-node sample tree and prints its number of 2-length paths
    static void Main(string[] args)
    {
        const int N = 5;
        const int K = 2;
        int[,] edges = { { 1, 2 }, { 1, 5 }, { 2, 3 }, { 2, 4 } };

        CD cd_s = new CD(N, K);
        for (int i = 0; i < edges.GetLength(0); i++)
            cd_s.Edge(edges[i, 0], edges[i, 1]);

        Console.WriteLine(cd_s.Solve());
    }
}
// This code is generated by Chetan Bargal


Output

4

Time Complexity: O(N * log(N)), where log N is the height of the centroid tree (each node is processed once for every centroid above it, i.e. at most O(log N) times)
Auxiliary Space: O(N)



Like Article
Suggest improvement
Previous
Next
Share your thoughts in the comments

Similar Reads