
Querying the number of distinct colors in a subtree of a colored tree using BIT

Last Updated : 13 Dec, 2023

Prerequisites: Binary Indexed Tree (BIT), DFS

Given a rooted tree T with ‘n’ nodes, where each node has a color given by the array color[] (color[i] is the integer color of the ith node), answer ‘Q’ queries of the following type: 

  • distinct u – print the number of distinct colors among the nodes in the subtree rooted at ‘u’

Examples:  

            1
           / \
          2   3
         /|\  | \
        4 5 6 7  8
             /| \
            9 10 11
color[] = {0, 2, 3, 3, 4, 1, 3, 4, 3, 2, 1, 1}
Node:   1  2  3  4  5  6  7  8  9  10  11
Color:  2  3  3  4  1  3  4  3  2   1   1
(Node values and colors start from index 1; color[0] is unused.)
distinct 3 -> output should be '4'.
The subtree rooted at 3 contains six nodes: 3, 7, 8, 9, 10 and 11.
Their colors are 3, 4, 3, 2, 1 and 1, i.e. four distinct colors (3, 4, 2 and 1).
distinct 2 -> output should be '3' (nodes 2, 4, 5 and 6 have colors 3, 4, 1 and 3).
distinct 7 -> output should be '3' (nodes 7, 9, 10 and 11 have colors 4, 2, 1 and 1).
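The following naive baseline pins down the query semantics. It is not the method of this article. It walks the subtree of u and collects the colors in a std::set. This is a sketch that assumes a 1-indexed undirected adjacency list and a tree rooted at a given node. The names computeParents, collectColors and distinctColorsNaive are hypothetical. Each query costs time proportional to the subtree size, so the baseline is mainly useful for checking the faster solution on small inputs.

```cpp
#include <set>
#include <vector>

// Naive baseline (hypothetical helper names, not the BIT method):
// `adj` is a 1-indexed undirected adjacency list and color[v] is the color of v.

// Record the parent of every node for the tree rooted at u (parent of root = 0).
void computeParents(const std::vector<std::vector<int>>& adj, int u, int p,
                    std::vector<int>& par) {
    par[u] = p;
    for (int v : adj[u])
        if (v != p) computeParents(adj, v, u, par);
}

// Insert the color of every node in the subtree of u into `seen`.
void collectColors(const std::vector<std::vector<int>>& adj,
                   const std::vector<int>& par, const std::vector<int>& color,
                   int u, std::set<int>& seen) {
    seen.insert(color[u]);
    for (int v : adj[u])
        if (v != par[u]) collectColors(adj, par, color, v, seen);  // children only
}

// Number of distinct colors in the subtree rooted at u (tree rooted at `root`).
int distinctColorsNaive(const std::vector<std::vector<int>>& adj,
                        const std::vector<int>& color, int root, int u) {
    std::vector<int> par(adj.size(), 0);
    computeParents(adj, root, 0, par);
    std::set<int> seen;
    collectColors(adj, par, color, u, seen);
    return static_cast<int>(seen.size());
}
```

On the sample tree rooted at 1, distinctColorsNaive(adj, color, 1, 3) returns 4. The calls for u = 2 and u = 7 both return 3.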

Building a solution in steps:  

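  1. Flatten the tree using DFS (an Euler tour). Keep a timer that is incremented both when a node is entered and when it is left, and store these times in two arrays: vis_time[i] is the time node i is first visited, and end_time[i] is the time its DFS call finishes.
  2. In the same DFS call, store the color of every node in an array flat_tree[] at both indices vis_time[i] and end_time[i] for the ith node. The DFS enters and leaves every descendant of u between entering and leaving u. The colors of u's subtree therefore occupy exactly the contiguous range flat_tree[vis_time[u] .. end_time[u]].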
    Note: size of the array flat_tree[] will be 2n.
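Steps 1 and 2 can be sketched as follows. The sketch assumes the same 1-indexed adjacency list as the article's code. Instead of a vis[] array, it skips the parent node, which is equivalent on a tree. The names EulerTour, eulerDfs and buildEulerTour are hypothetical.

```cpp
#include <vector>

// Euler-tour flattening (hypothetical names): each node's color is written to
// `flat` once when its DFS call starts and once when it ends, so `flat` holds
// 2n values at positions 1..2n (flat[0] is a dummy).
struct EulerTour {
    std::vector<int> tin;   // plays the role of vis_time[]
    std::vector<int> tout;  // plays the role of end_time[]
    std::vector<int> flat;  // plays the role of flat_tree[]
};

void eulerDfs(const std::vector<std::vector<int>>& adj,
              const std::vector<int>& color, int v, int parent, EulerTour& et) {
    et.flat.push_back(color[v]);                        // entering v
    et.tin[v] = static_cast<int>(et.flat.size()) - 1;
    for (int c : adj[v])
        if (c != parent) eulerDfs(adj, color, c, v, et);  // children only
    et.flat.push_back(color[v]);                        // leaving v
    et.tout[v] = static_cast<int>(et.flat.size()) - 1;
}

EulerTour buildEulerTour(const std::vector<std::vector<int>>& adj,
                         const std::vector<int>& color, int n, int root) {
    EulerTour et;
    et.tin.assign(n + 1, 0);
    et.tout.assign(n + 1, 0);
    et.flat.assign(1, 0);  // dummy entry so real positions start at 1
    eulerDfs(adj, color, root, 0, et);
    return et;
}
```

Add the edges in the same order as the article's driver code. The sample tree then gives flat = 2 3 4 4 1 1 3 3 3 3 4 2 2 1 1 1 1 4 3 3 3 2 at positions 1 to 22. The subtree of node 3 is the range [10, 21].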

The problem is now reduced to finding the number of distinct elements in the range [vis_time[u], end_time[u]] of the array flat_tree[] for each query of the specified type. To do so, we process the queries off-line. That is, we handle the queries in a different order from the one given in the question, store the results, and finally print each result in the order specified in the question.
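The reduction can be checked on its own before going off-line. The sketch below uses a hypothetical function, distinctInRange, which counts the distinct values of flat[l..r] with a hash set. With l = vis_time[u] and r = end_time[u], it answers the original query; every node appears twice in that range, which does not change the count. It takes O(r - l) time per query, so Q queries cost O(nQ) in the worst case. The offline BIT technique avoids that cost.

```cpp
#include <unordered_set>
#include <vector>

// Naive range check (hypothetical name): number of distinct values in
// flat[l..r], 1-indexed and inclusive.
int distinctInRange(const std::vector<int>& flat, int l, int r) {
    std::unordered_set<int> seen(flat.begin() + l, flat.begin() + r + 1);
    return static_cast<int>(seen.size());
}
```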

Steps:  

  1. First, we pre-process the array flat_tree[]. We maintain table[] (an array of vectors), where table[i] stores, in increasing order, all the indices of flat_tree[] that have value i. That is, if flat_tree[j] = i, then j is one of the elements of table[i].
  2. In the BIT, we set ‘1’ at the ith index if we want the ith element of flat_tree[] to be counted by the query() method. We also maintain an array traverser[]: traverser[i] is the position in table[i] of the next occurrence of value i that is not yet marked in the BIT.
  3. We now initialize the BIT by setting ‘1’ at the first occurrence of every element in flat_tree[]. We also increment the corresponding traverser[] entry by 1 so that it points to the next occurrence of that element. That is, if flat_tree[i] occurs for the first time, traverser[flat_tree[i]] is incremented by 1.
  4. At this point, query(R) on the BIT returns the number of distinct elements of flat_tree[] in the range [1, R].
  5. We sort all the queries in increasing order of vis_time[]. For the ith query, on node u, let li = vis_time[u] and ri = end_time[u]. Sorting the queries by li gives us an advantage: when processing the ith query, no later query has an ‘l’ smaller than li. So we can remove from the BIT all occurrences at positions up to li - 1, and use the traverser[] array to mark the next occurrence of each removed element instead. Afterwards, every element's marked position is its first occurrence at or after li. Then query(ri) returns the number of distinct elements in the range [li, ri].
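The steps above can be sketched on a plain array, separated from the tree. This is a minimal sketch, not the article's exact code, and it makes three assumptions. It takes a 1-indexed input array (a[0] unused). It replaces the fixed-size table[] and traverser[] arrays with std::unordered_map, so values need not be small integers. It sorts query indices rather than (l, r, index) triples. The name distinctOffline is hypothetical.

```cpp
#include <algorithm>
#include <numeric>
#include <unordered_map>
#include <utility>
#include <vector>

// Offline count of distinct values in ranges [l, r] of a 1-indexed array `a`
// (a[0] unused), following steps 1-5. Assumption: unordered_map stands in
// for the fixed-size table[] and traverser[] arrays.
std::vector<int> distinctOffline(const std::vector<int>& a,
                                 const std::vector<std::pair<int, int>>& qs) {
    const int m = static_cast<int>(a.size()) - 1;  // valid positions are 1..m
    std::vector<int> bit(m + 1, 0);
    auto update = [&](int i, int val) {
        for (; i <= m; i += i & -i) bit[i] += val;
    };
    auto query = [&](int i) {  // number of marked positions in 1..i
        int s = 0;
        for (; i > 0; i -= i & -i) s += bit[i];
        return s;
    };

    // table[x]: positions of value x in increasing order.
    // traverser[x]: index into table[x] of the next occurrence not yet marked.
    std::unordered_map<int, std::vector<int>> table;
    std::unordered_map<int, int> traverser;
    for (int i = 1; i <= m; ++i) {
        std::vector<int>& occ = table[a[i]];
        occ.push_back(i);
        if (occ.size() == 1) {  // first occurrence: mark it in the BIT
            update(i, 1);
            traverser[a[i]] = 1;
        }
    }

    // Visit queries by increasing l, but remember their original indices.
    std::vector<int> order(qs.size());
    std::iota(order.begin(), order.end(), 0);
    std::sort(order.begin(), order.end(),
              [&](int x, int y) { return qs[x].first < qs[y].first; });

    std::vector<int> ans(qs.size());
    int j = 1;  // every position < j has already been unmarked
    for (int k : order) {
        for (; j < qs[k].first; ++j) {
            const int x = a[j];
            const std::vector<int>& occ = table[x];
            update(occ[traverser[x] - 1], -1);  // the marked occurrence is position j
            if (traverser[x] < static_cast<int>(occ.size())) {
                update(occ[traverser[x]], 1);   // move the mark to the next occurrence
                ++traverser[x];
            }
        }
        ans[k] = query(qs[k].second);  // all marks in [1, r] now lie in [l, r]
    }
    return ans;
}
```

Take the sample flat_tree[] with the ranges (10, 21), (2, 9) and (11, 18) for nodes 3, 2 and 7. The function returns {4, 3, 3} in the original query order.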

C++




// A C++ program implementing the above design
#include<bits/stdc++.h>
#define max_color 1000005
#define maxn 100005
using namespace std;
 
// Note: All elements of global arrays are
// initially zero
// All the arrays have been described above
// bit[] is indexed by positions of flat_tree[], which go
// up to 2n, so it needs 2 * maxn slots
int bit[2 * maxn], vis_time[maxn], end_time[maxn];
int flat_tree[2 * maxn];
vector<int> tree[maxn];
vector<int> table[max_color];
int traverser[max_color];
 
bool vis[maxn];
int tim = 0;
 
//li, ri and index are stored in queries vector
//in that order, as the sort function will use
//the value li for comparison
vector< pair< pair<int, int>, int> > queries;
 
//ans[i] stores answer to ith query
int ans[maxn];
 
//update function to add val to idx in BIT
void update(int idx, int val)
{
    while ( idx < 2 * maxn )
    {
        bit[idx] += val;
        idx += idx & -idx;
    }
}
 
//query function to find sum(1, idx) in BIT
int query(int idx)
{
    int res = 0;
    while ( idx > 0 )
    {
        res += bit[idx];
        idx -= idx & -idx;
    }
    return res;
}
 
void dfs(int v, int color[])
{
    //mark the node visited
    vis[v] = 1;
 
    //set visiting time of the node v
    vis_time[v] = ++tim;
 
    //use the color of node v to fill flat_tree[]
    flat_tree[tim] = color[v];
 
    vector<int>::iterator it;
    for (it=tree[v].begin(); it!=tree[v].end(); it++)
        if (!vis[*it])
            dfs(*it, color);
 
 
    // set ending time for node v
    end_time[v] = ++tim;
 
    // setting its color in flat_tree[] again
    flat_tree[tim] = color[v];
}
 
//function to add an edge(u, v) to the tree
void addEdge(int u, int v)
{
    tree[u].push_back(v);
    tree[v].push_back(u);
}
 
//function to build the table[] and also add
//first occurrences of elements to the BIT
void hashMarkFirstOccurrences(int n)
{
    for (int i = 1 ; i <= 2 * n ; i++)
    {
        table[flat_tree[i]].push_back(i);
 
        //if it is the first occurrence of the element
        //then add it to the BIT and increment traverser
        if (table[flat_tree[i]].size() == 1)
        {
            //add the occurrence to bit
            update(i, 1);
 
            //make traverser point to next occurrence
            traverser[flat_tree[i]]++;
        }
    }
}
 
//function to process all the queries and store their answers
void processQueries()
{
    int j = 1;
    for (int i=0; i<queries.size(); i++)
    {
        //for each query remove all the occurrences before its li
        //li is the visiting time of the node
        //which is stored in first element of first pair
        for ( ; j < queries[i].first.first ; j++ )
        {
            int elem = flat_tree[j];
 
            //update(i, -1) removes an element at ith index
            //in the BIT
            update( table[elem][traverser[elem] - 1], -1);
 
            //if there is another occurrence of the same element
            if ( traverser[elem] < table[elem].size() )
            {
                //add the occurrence to the BIT and
                //increment traverser
                update(table[elem][ traverser[elem] ], 1);
                traverser[elem]++;
            }
        }
 
        //store the answer for the query, the index of the query
        //is the second element of the pair
        //And ri is stored in second element of the first pair
        ans[queries[i].second] = query(queries[i].first.second);
    }
}
 
// Count distinct colors in subtrees rooted with qVer[0],
// qVer[1], ...qVer[qn-1]
void countDistinctColors(int color[], int n, int qVer[], int qn)
{
    // build the flat_tree[], vis_time[] and end_time[]
    dfs(1, color);
 
    // add one query (vis_time, end_time, original index) per query vertex
    for (int i=0; i<qn; i++)
        queries.push_back(make_pair(make_pair(vis_time[qVer[i]],
                                    end_time[qVer[i]]), i) );
 
    // sort the queries in order of increasing vis_time
    sort(queries.begin(), queries.end());
 
    // make table[] and set '1' at first occurrences of elements
    hashMarkFirstOccurrences(n);
 
    // process queries
    processQueries();
 
    // print all the answers, in order asked
    // in the question
    for (int i=0; i<queries.size() ; i++)
    {
        cout << "Distinct colors in the corresponding subtree is: "
             << ans[i] << endl;
    }
}
 
//driver code
int main()
{
    /*
            1
           / \
          2   3
         /|\  | \
        4 5 6 7  8
             /| \
            9 10 11    */
    int n = 11;
    int color[] = {0, 2, 3, 3, 4, 1, 3, 4, 3, 2, 1, 1};
 
    // add all the edges to the tree
    addEdge(1, 2);
    addEdge(1, 3);
    addEdge(2, 4);
    addEdge(2, 5);
    addEdge(2, 6);
    addEdge(3, 7);
    addEdge(3, 8);
    addEdge(7, 9);
    addEdge(7, 10);
    addEdge(7, 11);
 
 
    int qVer[] = {3, 2, 7};
    int qn = sizeof(qVer)/sizeof(qVer[0]);
 
    countDistinctColors(color, n, qVer, qn);
 
    return 0;
}


Java




import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
 
public class Main {
    private static final int maxColor = 1000005;
    private static final int maxn = 100005;
     
    // bit[] is indexed by positions of flatTree[], which go up to 2n,
    // so it needs 2 * maxn slots
    private static int[] bit = new int[2 * maxn];
    private static int[] visTime = new int[maxn];
    private static int[] endTime = new int[maxn];
    private static int[] flatTree = new int[2 * maxn];
    private static List<Integer>[] tree = new ArrayList[maxn];
    private static List<Integer>[] table = new ArrayList[maxColor];
    private static int[] traverser = new int[maxColor];
    private static boolean[] vis = new boolean[maxn];
    private static int tim = 0;
     
    private static List<Pair<Pair<Integer, Integer>, Integer>> queries = new ArrayList<>();
    private static int[] ans = new int[maxn];
     
    private static void update(int idx, int val) {
        while (idx < bit.length) {
            bit[idx] += val;
            idx += idx & -idx;
        }
    }
     
    private static int query(int idx) {
        int res = 0;
        while (idx > 0) {
            res += bit[idx];
            idx -= idx & -idx;
        }
        return res;
    }
     
    private static void dfs(int v, int[] color) {
        vis[v] = true;
        visTime[v] = ++tim;
        flatTree[tim] = color[v];
        for (int u : tree[v]) {
            if (!vis[u]) {
                dfs(u, color);
            }
        }
        endTime[v] = ++tim;
        flatTree[tim] = color[v];
    }
     
    private static void addEdge(int u, int v) {
        tree[u].add(v);
        tree[v].add(u);
    }
     
    private static void hashMarkFirstOccurrences(int n) {
        for (int i = 1; i <= 2 * n; i++) {
            table[flatTree[i]].add(i);
            if (table[flatTree[i]].size() == 1) {
                update(i, 1);
                traverser[flatTree[i]]++;
            }
        }
    }
     
    private static void processQueries() {
        int j = 1;
        for (Pair<Pair<Integer, Integer>, Integer> q : queries) {
            // remove every occurrence at a position before li = q.first.first
            for (; j < q.first.first; j++) {
                int elem = flatTree[j];
                update(table[elem].get(traverser[elem] - 1), -1);
                if (traverser[elem] < table[elem].size()) {
                    update(table[elem].get(traverser[elem]), 1);
                    traverser[elem]++;
                }
            }
            // ri = q.first.second; q.second is the original query index
            ans[q.second] = query(q.first.second);
        }
    }
     
    private static void countDistinctColors(int[] color, int n, int[] qVer, int qn) {
        dfs(1, color);
        for (int i = 0; i < qn; i++) {
            queries.add(new Pair<>(new Pair<>(visTime[qVer[i]], endTime[qVer[i]]), i));
        }
        Collections.sort(queries);
        hashMarkFirstOccurrences(n);
        processQueries();
        for (int i = 0; i < queries.size(); i++) {
            System.out.println("Distinct colors in the corresponding subtree is: " + ans[i]);
        }
    }
     
    public static void main(String[] args) {
        int n = 11;
        int[] color = {0, 2, 3, 3, 4, 1, 3, 4, 3, 2, 1, 1};
         
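        // Initialize every adjacency list and every occurrence list;
        // table[] is indexed by color, so it needs maxColor entries
        for (int i = 0; i < maxn; i++) {
            tree[i] = new ArrayList<>();
        }
        for (int i = 0; i < maxColor; i++) {
            table[i] = new ArrayList<>();
        }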
 
         
        addEdge(1, 2);
        addEdge(1, 3);
        addEdge(2, 4);
        addEdge(2, 5);
        addEdge(2, 6);
        addEdge(3, 7);
        addEdge(3, 8);
        addEdge(7, 9);
        addEdge(7, 10);
        addEdge(7, 11);
         
        int[] qVer = {3, 2, 7};
        int qn = qVer.length;
         
        countDistinctColors(color, n, qVer, qn);
    }
     
    static class Pair<A, B> implements Comparable<Pair<A, B>> {
        A first;
        B second;
         
        public Pair(A first, B second) {
            this.first = first;
            this.second = second;
        }
         
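        @Override
        @SuppressWarnings("unchecked")
        public int compareTo(Pair<A, B> other) {
            // Lexicographic order: compare first, then second. compareTo is used
            // instead of equals() because Pair does not override equals().
            int c = ((Comparable<A>) this.first).compareTo(other.first);
            return c != 0 ? c : ((Comparable<B>) this.second).compareTo(other.second);
        }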
    }
}


Python3




# All elements of global arrays are initially zero
bit = [0] * (2 * 100005)  # Binary Indexed Tree (BIT) over flat_tree positions (up to 2n)
vis_time = [0] * 100005  # Visiting time for nodes
end_time = [0] * 100005  # Ending time for nodes
flat_tree = [0] * (2 * 100005)  # Flattened tree array
tree = [[] for _ in range(100005)]  # Tree adjacency list
table = [[] for _ in range(1000005)]  # Table to store occurrences of colors
traverser = [0] * 1000005  # Keeps track of occurrences for each color
vis = [False] * 100005  # Visited nodes
tim = 0  # Time variable for node traversal
queries = []  # Queries to process
ans = [0] * 100005  # Stores answers to queries
 
# Update function to add val to idx in BIT
def update(idx, val):
    while idx < len(bit):
        bit[idx] += val
        idx += idx & -idx
 
# Query function to find sum(1, idx) in BIT
def query(idx):
    res = 0
    while idx > 0:
        res += bit[idx]
        idx -= idx & -idx
    return res
 
def dfs(v, color):
    global tim
    vis[v] = True
    vis_time[v] = tim = tim + 1
    flat_tree[tim] = color[v]  # Flattening the tree with node colors
 
    for node in tree[v]:  # Traverse through adjacent nodes
        if not vis[node]:
            dfs(node, color)
 
    end_time[v] = tim = tim + 1
    flat_tree[tim] = color[v]
 
def addEdge(u, v):
    tree[u].append(v)  # Add edges to the tree
    tree[v].append(u)
 
def hashMarkFirstOccurrences(n):
    # Loop through the flattened tree to mark first occurrences of colors
    for i in range(1, 2 * n + 1):
        table[flat_tree[i]].append(i)
        if len(table[flat_tree[i]]) == 1:
            update(i, 1)  # Update BIT for first occurrences
            traverser[flat_tree[i]] += 1
 
def processQueries():
    j = 1
    for i in range(len(queries)):
        # Process queries and update BIT accordingly
        while j < queries[i][0][0]:
            elem = flat_tree[j]
            update(table[elem][traverser[elem] - 1], -1)
 
            if traverser[elem] < len(table[elem]):
                update(table[elem][traverser[elem]], 1)
                traverser[elem] += 1
            j += 1
 
        ans[queries[i][1]] = query(queries[i][0][1])  # Store query answers
 
def countDistinctColors(color, n, qVer, qn):
    dfs(1, color)  # Start depth-first search from node 1
 
    for i in range(qn):
        queries.append(((vis_time[qVer[i]], end_time[qVer[i]]), i))  # Prepare queries
 
    queries.sort()  # Sort queries based on visiting time and ending time
    hashMarkFirstOccurrences(n)  # Mark first occurrences in the flattened tree
    processQueries()  # Process queries and update BIT
 
    for i in range(len(queries)):
        print(f"Distinct colors in the corresponding subtree is: {ans[i]}")  # Print query answers
 
if __name__ == "__main__":
    # Sample tree structure and colors
    n = 11
    color = [0, 2, 3, 3, 4, 1, 3, 4, 3, 2, 1, 1]
 
    addEdge(1, 2)  # Add edges to construct the tree
    addEdge(1, 3)
    addEdge(2, 4)
    addEdge(2, 5)
    addEdge(2, 6)
    addEdge(3, 7)
    addEdge(3, 8)
    addEdge(7, 9)
    addEdge(7, 10)
    addEdge(7, 11)
 
    qVer = [3, 2, 7]  # Query nodes
    qn = len(qVer)
 
    countDistinctColors(color, n, qVer, qn)  # Count distinct colors in subtrees rooted at query nodes


C#




using System;
using System.Collections.Generic;
using System.Linq;
 
class Program
{
  // Note: All elements of global arrays are
  // initially zero
  // All the arrays have been described above
  const int max_color = 1000005;
  const int maxn = 100005;
  // bit[] is indexed by positions of flat_tree[], which go
  // up to 2n, so it needs 2 * maxn slots
  static int[] bit = new int[2 * maxn];
  static int[] vis_time = new int[maxn];
  static int[] end_time = new int[maxn];
  static int[] flat_tree = new int[2 * maxn];
  static List<int>[] tree = Enumerable.Repeat(0, maxn).Select(x => new List<int>()).ToArray();
  static List<int>[] table = Enumerable.Repeat(0, max_color).Select(x => new List<int>()).ToArray();
  static int[] traverser = new int[max_color];
  static bool[] vis = new bool[maxn];
  static int tim = 0;

  // li, ri and index are stored in queries vector
  // in that order, as the sort function will use
  // the value li for comparison
  static List<Tuple<Tuple<int, int>, int>> queries = new List<Tuple<Tuple<int, int>, int>>();

  // ans[i] stores answer to ith query
  static int[] ans = new int[maxn];

  // update function to add val to idx in BIT
  static void Update(int idx, int val)
  {
    while (idx < bit.Length)
    {
      bit[idx] += val;
      idx += idx & -idx;
    }
  }
  // query function to find sum(1, idx) in BIT
  static int Query(int idx)
  {
    int res = 0;
    while (idx > 0)
    {
      res += bit[idx];
      idx -= idx & -idx;
    }
    return res;
  }
 
  static void Dfs(int v, int[] color)
  {
    // mark the node visited
    vis[v] = true;
    vis_time[v] = ++tim;
    flat_tree[tim] = color[v];
    foreach (int it in tree[v])
      if (!vis[it])
        Dfs(it, color);
    end_time[v] = ++tim;
    flat_tree[tim] = color[v];
  }
  //function to add edges to graph
  static void addEdge(int u, int v)
  {
    tree[u].Add(v);
    tree[v].Add(u);
  }
  // function to build the table[] and also add
  // first occurrences of elements to the BIT
  static void HashMarkFirstOccurrences(int n)
  {
    for (int i = 1; i <= 2 * n; i++)
    {
      // if it is the first occurrence of the element
      // then add it to the BIT and increment traverser
      table[flat_tree[i]].Add(i);
      if (table[flat_tree[i]].Count == 1)
      {
        Update(i, 1);
        traverser[flat_tree[i]]++;
      }
    }
  }
   
  // function to process all the queries and store their answers
  static void ProcessQueries()
  {
    int j = 1;
     
    // for each query remove all the occurrences before its li
    // li is the visiting time of the node
    // which is stored in first element of first pair
    for (int i = 0; i < queries.Count; i++)
    {
      for (; j < queries[i].Item1.Item1; j++)
      {
        int elem = flat_tree[j];
        Update(table[elem][traverser[elem] - 1], -1);
        if (traverser[elem] < table[elem].Count)
        {
          Update(table[elem][traverser[elem]], 1);
          traverser[elem]++;
        }
      }
      ans[queries[i].Item2] = Query(queries[i].Item1.Item2);
    }
  }
   
  // Count distinct colors in subtrees rooted with qVer[0],
  // qVer[1], ...qVer[qn-1]
  static void countDistinctColors(int[] color, int n, int[] qVer, int qn)
  {
     
    // build the flat_tree[], vis_time[] and end_time[]
    Dfs(1, color);
     
    // add one query (vis_time, end_time, original index) per query vertex
    for (int i = 0; i < qn; i++)
      queries.Add(new Tuple<Tuple<int, int>, int>(new Tuple<int, int>(vis_time[qVer[i]], end_time[qVer[i]]), i));
    queries.Sort();
    HashMarkFirstOccurrences(n);
    ProcessQueries();
    // print all the answers, in order asked
    // in the question
    for (int i = 0; i < queries.Count; i++)
      Console.WriteLine("Distinct colors in the corresponding subtree is: {0}", ans[i]);
  }
 
  static void Main(string[] args)
  {
    /*
                    1
                   / \
                  2   3
                 /|\  | \
                4 5 6 7  8
                     /| \
                    9 10 11    */
    int n = 11;
    int[] color = { 0, 2, 3, 3, 4, 1, 3, 4, 3, 2, 1, 1 };
 
 
    // add all the edges to the tree
    addEdge(1, 2);
    addEdge(1, 3);
    addEdge(2, 4);
    addEdge(2, 5);
    addEdge(2, 6);
    addEdge(3, 7);
    addEdge(3, 8);
    addEdge(7, 9);
    addEdge(7, 10);
    addEdge(7, 11);
 
    int[] qVer = { 3, 2, 7 };
    int qn = qVer.Length;
 
    countDistinctColors(color, n, qVer, qn);
 
  }
}


Javascript




// Constants for maximum color, maximum nodes, and initializing arrays
const max_color = 1000005;
const maxn = 100005;
const bit = new Array(2 * maxn).fill(0); // Binary Indexed Tree over flat_tree positions (up to 2n)
const vis_time = new Array(maxn).fill(0); // Visit time for nodes
const end_time = new Array(maxn).fill(0); // End time for nodes
const flat_tree = new Array(2 * maxn).fill(0); // Flattened tree structure
const tree = Array.from({ length: maxn }, () => []); // Graph/tree structure
const table = Array.from({ length: max_color }, () => []); // Table for elements' occurrences
const traverser = new Array(max_color).fill(0); // Index of the next unmarked occurrence of each color
const vis = new Array(maxn).fill(false); // Tracks visited nodes
let tim = 0; // Time counter

const queries = []; // Array to store queries

const ans = new Array(maxn).fill(0); // Array to store answers to queries

// Function to update Binary Indexed Tree
function Update(idx, val) {
  while (idx < bit.length) {
    bit[idx] += val;
    idx += idx & -idx;
  }
}
 
// Function to query Binary Indexed Tree
function Query(idx) {
  let res = 0;
  while (idx > 0) {
    res += bit[idx];
    idx -= idx & -idx;
  }
  return res;
}
 
// Depth-first search traversal on the tree
function Dfs(v, color) {
  vis[v] = true;
  vis_time[v] = ++tim;
  flat_tree[tim] = color[v];
  tree[v].forEach((it) => {
    if (!vis[it]) Dfs(it, color);
  });
  end_time[v] = ++tim;
  flat_tree[tim] = color[v];
}
 
// Function to add edges to the tree/graph
function addEdge(u, v) {
  tree[u].push(v);
  tree[v].push(u);
}
 
// Function to populate table and BIT with first occurrences
function HashMarkFirstOccurrences(n) {
  for (let i = 1; i <= 2 * n; i++) {
    table[flat_tree[i]].push(i);
    if (table[flat_tree[i]].length === 1) {
      Update(i, 1);
      traverser[flat_tree[i]]++;
    }
  }
}
 
// Function to process queries and store answers
function ProcessQueries() {
  let j = 1;
  for (let i = 0; i < queries.length; i++) {
    for (; j < queries[i][0][0]; j++) {
      const elem = flat_tree[j];
      Update(table[elem][traverser[elem] - 1], -1);
      if (traverser[elem] < table[elem].length) {
        Update(table[elem][traverser[elem]], 1);
        traverser[elem]++;
      }
    }
    ans[queries[i][1]] = Query(queries[i][0][1]);
  }
}
 
// Function to count distinct colors in subtrees
function countDistinctColors(color, n, qVer, qn) {
  Dfs(1, color); // Traverse the tree to generate visit and end times
  for (let i = 0; i < qn; i++) {
    // Push queries based on visit and end times to queries array
    queries.push([[vis_time[qVer[i]], end_time[qVer[i]]], i]);
  }
  queries.sort((a, b) => a[0][0] - b[0][0]); // Sort queries based on visit times
  HashMarkFirstOccurrences(n); // Initialize BIT and table with first occurrences
  ProcessQueries(); // Process queries to calculate distinct colors
  for (let i = 0; i < queries.length; i++) {
    console.log(`Distinct colors in the corresponding subtree is: ${ans[i]}`); // Print the answers
  }
}
 
// Define the tree structure and colors
const n = 11;
const color = [0, 2, 3, 3, 4, 1, 3, 4, 3, 2, 1, 1];
 
// Define edges in the tree
addEdge(1, 2);
addEdge(1, 3);
addEdge(2, 4);
addEdge(2, 5);
addEdge(2, 6);
addEdge(3, 7);
addEdge(3, 8);
addEdge(7, 9);
addEdge(7, 10);
addEdge(7, 11);
 
// Define query vertices and call the function to count distinct colors
const qVer = [3, 2, 7];
const qn = qVer.length;
countDistinctColors(color, n, qVer, qn);


Output: 

Distinct colors in the corresponding subtree is: 4
Distinct colors in the corresponding subtree is: 3
Distinct colors in the corresponding subtree is: 3
Time Complexity: O((n + Q) log n + Q log Q)
Auxiliary Space: O(n + max_color), for the arrays sized by maxn and max_color
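The total cost can be counted phase by phase for the C++ program. A BIT operation on 2n positions costs O(log 2n) = O(log n). Each position of flat_tree[] is marked at most once and unmarked at most once:

```latex
T(n, Q) = \underbrace{O(n)}_{\text{DFS}}
        + \underbrace{O(n \log n)}_{\text{first occurrences}}
        + \underbrace{O(Q \log Q)}_{\text{sorting queries}}
        + \underbrace{O(n \log n)}_{\text{sweep over } 2n \text{ positions}}
        + \underbrace{O(Q \log n)}_{\text{prefix queries}}
        = O\big((n + Q)\log n + Q \log Q\big)
```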



 


