In this HackerRank Similar Pair problem, you are given a tree whose nodes are labeled from 1 to n, and you must find the number of similar pairs in the tree. A pair of nodes (a, b) is a similar pair if node a is an ancestor of node b and the absolute difference between a and b is at most k, i.e. |a − b| ≤ k.

HackerRank Similar Pair problem solution


Problem solution in Python.

import resource  # NOTE(review): not referenced anywhere in this script
import sys
# NOTE(review): dfs() below walks the tree with an explicit stack rather than
# recursion, so this very high recursion limit appears to be unnecessary.
sys.setrecursionlimit(2000000)

def add(x, v):
    """Fenwick-tree point update: add ``v`` at 0-based position ``x``.

    The tree lives in the module-level list ``a`` and is 1-based internally,
    so position ``x`` is stored at slot ``x + 1``; ``n`` is the logical size.
    """
    i = x + 1
    while i <= n:
        a[i] += v
        # Jump to the next slot whose range covers position i.
        i += i & (-i)

def que(x):
    """Fenwick-tree prefix sum over 0-based positions ``0..x`` (inclusive).

    Out-of-range arguments are tolerated: a negative ``x`` yields 0, and an
    ``x`` past the end is clamped so the whole array is summed.
    """
    idx = min(n, x + 1)
    if idx <= 0:
        return 0
    total = 0
    while idx:
        total += a[idx]
        # Clear the lowest set bit to move to the preceding covering slot.
        idx &= idx - 1
    return total

st = []
vis = {}
def dfs(x):
    """Count similar pairs in the subtree rooted at 0-based node ``x``.

    Walks the tree depth-first using the explicit stack ``st`` (no recursion).
    A node is added to the Fenwick tree when first reached and removed once
    all its children are done. So when a node is first reached, the tree holds
    exactly the nodes on its root path (its ancestors). The number of those
    within ``T`` of it is added to the global ``ans``.

    Consumes ``nx``: child lists are popped as they are visited.
    """
    global ans
    st.append(x)
    while st:
        node = st[-1]
        if node not in vis:
            # First arrival: count ancestors in [node - T, node + T], then open node.
            ans += que(node + T) - que(node - T - 1)
            add(node, 1)
            vis[node] = 1
        pending = nx[node]
        if not pending:
            # All children finished: close node so later nodes don't see it as an ancestor.
            st.pop()
            add(node, -1)
        else:
            st.append(pending.pop())

# Input: "n k", then n - 1 lines "parent child" with 1-based labels.
# Labels are stored 0-based: nx[p] lists p's children, pre[c] is c's parent.
n, T = (int(x) for x in input().split())
a = [0 for i in range(4 * n)]
nx = [[] for i in range(n)]
pre = [-1 for i in range(n)]
for i in range(n - 1):
    s, e = (int(x) - 1 for x in input().split())
    nx[s].append(e)
    pre[e] = s

# Climb parent links to the root. Start from index 0 (label 1), which always
# exists for n >= 1; index 1 does not exist when n == 1.
s = 0
while pre[s] != -1:
    s = pre[s]
ans = 0
dfs(s)
print(ans)

{"mode":"full","isActive":false}


Problem solution in Java.

import java.io.*;
import java.util.*;

public class Solution {
    /**
     * Children of each vertex, indexed by 1-based label. Replaced by
     * {@link #countSimilarPairs} with an array sized to the actual vertex count,
     * so the initial capacity no longer limits the input size.
     */
    public static LinkedList<Integer>[] nodes = newAdjacency(100002);

    /** Vertex count, similarity threshold k, and root label of the tree being processed. */
    static int n, t, root;

    /**
     * Reads {@code n k} followed by {@code n - 1} "parent child" lines from standard
     * input and prints the number of similar pairs.
     */
    public static void main(String[] args) {
        Scanner scan = new Scanner(System.in);
        int nodeCount = scan.nextInt();
        int k = scan.nextInt();
        int[][] edges = new int[Math.max(nodeCount - 1, 0)][2];
        for (int[] edge : edges) {
            edge[0] = scan.nextInt();
            edge[1] = scan.nextInt();
        }
        System.out.println(countSimilarPairs(nodeCount, k, edges));
    }

    /**
     * Counts pairs (a, b) such that a is a proper ancestor of b and |a - b| <= k.
     *
     * @param nodeCount number of vertices, labelled 1..nodeCount
     * @param k similarity threshold
     * @param edges {@code {parent, child}} pairs describing the tree
     * @return the number of similar pairs (can exceed the int range)
     * @throws IllegalArgumentException if every vertex has a parent (no root)
     */
    public static long countSimilarPairs(int nodeCount, int k, int[][] edges) {
        n = nodeCount;
        t = k;
        if (nodeCount <= 0) {
            return 0;
        }

        nodes = newAdjacency(nodeCount + 1);
        for (int i = 1; i <= nodeCount; i++) {
            nodes[i] = new LinkedList<>();
        }

        int[] indegree = new int[nodeCount + 1];
        for (int[] edge : edges) {
            nodes[edge[0]].addFirst(edge[1]);
            indegree[edge[1]]++;
        }

        // The root is the unique vertex with no incoming edge.
        root = 0;
        for (int i = 1; i <= nodeCount; i++) {
            if (indegree[i] == 0) {
                root = i;
                break;
            }
        }
        if (root == 0) {
            throw new IllegalArgumentException("no root: every vertex has a parent");
        }

        long[] stree = new long[4 * nodeCount + 1];
        long[] pairs = new long[1];
        depthSearch(root, stree, pairs);
        return pairs[0];
    }

    /**
     * Adds to {@code pairs[0]} the number of similar pairs whose descendant lies in
     * the subtree of {@code nodeval}. {@code stree} marks the vertices on the
     * current root path, and it is restored to its entry state on return.
     *
     * <p>Uses an explicit stack instead of recursion. The recursion depth would
     * equal the tree height, which is n for a path-shaped tree, and that can
     * overflow the default thread stack for large n.
     */
    public static void depthSearch(int nodeval, long[] stree, long[] pairs) {
        // Parallel stacks: vertices on the current root path, and for each of them
        // an iterator over the children not yet descended into.
        Deque<Integer> path = new ArrayDeque<>();
        Deque<Iterator<Integer>> unvisited = new ArrayDeque<>();

        openVertex(nodeval, stree, pairs);
        path.push(nodeval);
        unvisited.push(nodes[nodeval].iterator());

        while (!path.isEmpty()) {
            Iterator<Integer> children = unvisited.peek();
            if (children.hasNext()) {
                int child = children.next();
                openVertex(child, stree, pairs);
                path.push(child);
                unvisited.push(nodes[child].iterator());
            } else {
                // Every descendant is done: unmark this vertex so it is no longer
                // counted as an ancestor of vertices visited later.
                unvisited.pop();
                updateTree(stree, 1, 1, n, path.pop(), -1);
            }
        }
    }

    /**
     * Counts marked vertices (current ancestors) whose label lies within t of
     * {@code vertex}, adds that to {@code pairs[0]}, then marks {@code vertex}.
     */
    private static void openVertex(int vertex, long[] stree, long[] pairs) {
        // Clamp to [1, n] by comparing against t, so vertex +/- t cannot overflow.
        int min = (t >= vertex - 1) ? 1 : vertex - t;
        int max = (t >= n - vertex) ? n : vertex + t;
        if (min <= max) {
            pairs[0] += query(stree, 1, 1, n, min, max);
        }
        updateTree(stree, 1, 1, n, vertex, 1);
    }

    /** Creates an array of {@code size} null adjacency slots. */
    @SuppressWarnings("unchecked") // generic array creation needs an unchecked cast
    private static LinkedList<Integer>[] newAdjacency(int size) {
        return (LinkedList<Integer>[]) new LinkedList<?>[size];
    }

    /**
     * Segment-tree point update: adds {@code opt} to position {@code val} and to every
     * segment on the path from {@code node} (covering {@code [tl, tr]}) down to its leaf.
     */
    public static void updateTree(long[] tree, int node, int tl, int tr, int val, long opt) {
        if (val < tl || val > tr || tl > tr) {
            return;
        }

        tree[node] += opt;

        int m = (tl + tr) >> 1;

        if (tl == tr) {
            return;
        } else if (val <= m) {
            updateTree(tree, node << 1, tl, m, val, opt);
        } else {
            updateTree(tree, node << 1 | 1, m + 1, tr, val, opt);
        }
    }

    /**
     * Segment-tree range sum over {@code [min, max]}, for the segment {@code [tl, tr]}
     * stored at {@code node}. The range is narrowed to each child's half on descent.
     */
    public static long query(long[] tree, int node, int tl, int tr, int min, int max) {
        if (max < tl || min > tr) {
            return 0;
        } else if (max == tr && min == tl) {
            return tree[node];
        } else {
            int mid = (tl + tr) >> 1;
            int lmax = (mid < max) ? mid : max;
            int rmin = (min > mid) ? min : mid + 1;
            return query(tree, node << 1, tl, mid, min, lmax)
                    + query(tree, node << 1 | 1, mid + 1, tr, rmin, max);
        }
    }
}

{"mode":"full","isActive":false}


Problem solution in C++.

#include<iostream>
#include<vector>
using namespace std;
// Children of each vertex, indexed by label (1..N).
std::vector<int> graph[110001];
// T: similarity threshold k; N: number of vertices.
int T, N;
// deg[v]: number of edges entering v; the root is the vertex with deg 0.
// NOTE(review): graph has 110001 slots but deg only 100001; both assume N <= 100000.
int deg[100001] = {0};
// Segment tree over labels 1..N (root at index 1, children 2i and 2i+1).
long long ST[100001*4] = {0};

// Point update: adds `val` to the leaf for label `node` inside the segment
// [b, e] stored at ST[idx], then recomputes the sums on the way back up.
void update(int node, int b, int e, int idx, int val)  {
    if (node < b || node > e) {
        return;  // label lies outside this segment
    }
    if (b == e) {
        ST[idx] += val;
        return;
    }
    const int mid = (b + e) >> 1;
    const int left = 2 * idx;
    const int right = left + 1;
    // Only the half containing `node` can change; the other half is left untouched.
    if (node <= mid) {
        update(node, b, mid, left, val);
    } else {
        update(node, mid + 1, e, right, val);
    }
    ST[idx] = ST[left] + ST[right];
}

// Range sum: total stored in ST for labels in [l, r], restricted to the
// segment [b, e] held at ST[idx].
long long Query(int l, int r, int b, int e, int idx) {
    const bool disjoint = (r < b) || (e < l);
    if (disjoint) {
        return 0;
    }
    const bool covered = (b >= l) && (e <= r);
    if (covered) {
        return ST[idx];
    }
    const int mid = (b + e) >> 1;
    const long long leftSum = Query(l, r, b, mid, idx * 2);
    const long long rightSum = Query(l, r, mid + 1, e, idx * 2 + 1);
    return leftSum + rightSum;
}

// Counts vertices currently recorded in ST (the open root path) whose label
// lies within T of v, then records v itself. Helper for SimilarPairs.
static long long OpenVertex(int v) {
    // Clamp to [1, N] by comparing against T, so v - T / v + T are only
    // computed when they stay inside that range.
    int l = (T >= v - 1) ? 1 : v - T;
    int r = (T >= N - v) ? N : v + T;
    long long found = Query(l, r, 1, N, 1);
    update(v, 1, N, 1, 1);
    return found;
}

// Returns the number of pairs (a, b) with b in the subtree of `node`, a an
// ancestor of b recorded in ST, and |a - b| <= T. Every vertex added to ST
// here is removed again before returning.
//
// Uses an explicit stack instead of recursion. The recursion depth would
// equal the tree height, which is N for a path-shaped tree.
long long SimilarPairs(int node) {
    std::vector<int> path;             // vertices on the open root path
    std::vector<std::size_t> nextChild; // per path entry: next child index to descend into

    long long res = OpenVertex(node);
    path.push_back(node);
    nextChild.push_back(0);

    while (!path.empty()) {
        const int v = path.back();
        if (nextChild.back() < graph[v].size()) {
            const int child = graph[v][nextChild.back()++];
            res += OpenVertex(child);
            path.push_back(child);
            nextChild.push_back(0);
        } else {
            // All children finished: v stops being an ancestor of later vertices.
            update(v, 1, N, 1, -1);
            path.pop_back();
            nextChild.pop_back();
        }
    }
    return res;
}

// Reads "N K" and N - 1 "parent child" edges, then prints the number of
// similar pairs. Returns 1 if the header cannot be read.
int main() {
    // Up to 2N integers are read: decouple iostreams from stdio for speed.
    std::ios::sync_with_stdio(false);
    std::cin.tie(0);

    if (!(std::cin >> N >> T)) {
        return 1;
    }

    for (int i = 0; i < N - 1; i++) {
        int x, y;
        std::cin >> x >> y;
        graph[x].push_back(y);
        deg[y]++;
    }

    // The root is the vertex with no incoming edge. Default to 1 so `root`
    // always holds a defined value even if no such vertex is found.
    int root = 1;
    for (int i = 1; i <= N; i++) {
        if (!deg[i]) {
            root = i;
            break;
        }
    }

    std::cout << SimilarPairs(root) << std::endl;
    return 0;
}

{"mode":"full","isActive":false}


Problem solution in C.

#include "stdio.h"
#include "stdlib.h"
#include "string.h"
#include "math.h"

/* One tree vertex. Each node is linked two ways: into the tree
 * (parent / child_list / peer_next) and into a bucket list of the global
 * hash table (hash_next). */
typedef struct Node Node;
struct Node
{
    Node *parent;      /* NULL for the root */
    Node *peer_next;   /* next sibling in the parent's child_list */
    Node *child_list;  /* first child; solve() advances it as children finish */
    int   val;         /* vertex label as read from input */
    Node *hash_next;   /* next node in the same hash bucket */
};

/* Running total of similar pairs found so far. */
unsigned long long int count;
/* n: number of vertices; T: the similarity threshold k. */
unsigned int n, T;
/* Number of vertices currently marked in hash[] by solve(). */
unsigned int size;
/* n buckets: a hash table keyed by val % n while building, then a
 * per-label marker array (index val - 1) during solve(). */
Node **hash;
/* Topmost vertex of the tree, located by build(). */
Node *root = NULL;

/* Returns the distance |a - b| between two labels as an unsigned value. */
unsigned int diff(int a, int b)
{
    int larger = (a > b) ? a : b;
    int smaller = (a > b) ? b : a;
    return larger - smaller;
}

/*
 * Adds to `count` the number of ancestors of x whose label is within T of
 * x->val. Relies on solve() having set hash[label - 1] non-NULL for each
 * ancestor of x and `size` to the number of such marks. The root (no parent)
 * contributes nothing.
 *
 * One of three counting strategies is picked by comparing `size` and T
 * against n (presumably to bound the work per call):
 *   1. (n - T) < size: start from `size` and subtract marked labels that
 *      fall outside [val - T, val + T].
 *      Note n and T are unsigned, so if T > n the subtraction wraps and this
 *      branch is skipped.
 *   2. T > size: walk the parent chain directly and test each ancestor.
 *   3. otherwise: scan hash[] over the window [val - T, val + T] and count marks.
 */
void countup(Node *x)
{
    int i,val;
    if(!x || !x->parent) return;
    if((n-T) < size)
    {
        count+=size;
        /* indices 0 .. val-2-T are labels below val - T */
        for(i=0;i<(((x->val-1)>T)?(x->val-1-T):0); i++)
            if(hash[i]) count--;
        /* indices val+T .. n-1 are labels above val + T */
        for(i=(((x->val+T)>n)?n:(x->val+T));i<n; i++)
            if(hash[i]) count--;
    }
    else if(T > size)
    {
        val=x->val;
        x=x->parent;
        while(x)
        {
            if(diff(val,x->val) <= T) count++;
            x=x->parent;
        }
    }
    else
    {
        /* indices val-1-T .. val+T-1, i.e. labels val-T .. val+T, clamped to [0, n) */
        for(i=((x->val-1)>T)?(x->val-1-T):0; i<(((x->val+T)>n)?n:(x->val+T)); i++)
        {
            if(hash[i])
            {
                //printf("%2d, 0x%x\n",i,hash[i]);
                count++;
            }
        }
    }
}

/*
 * Counts all similar pairs into `count` by a post-order walk from `root`
 * that uses parent pointers instead of recursion or an explicit stack.
 *
 * hash[] (filled by build()) is cleared and reused as a marker array:
 * hash[label - 1] points at a node while it is an ancestor of the node being
 * descended to, and `size` counts the marks. The walk descends along
 * child_list. When a node has no remaining children it is passed to countup(),
 * unlinked from its parent's child_list, and freed.
 *
 * Destroys the tree: every node reachable from `root` is freed, and `root`
 * itself is left pointing at freed memory.
 */
void solve()
{
    Node *tmp=root;
    Node *tmp1;
    int i;
    for(i=0;i<n;i++) hash[i]=NULL;
    size=0;
    while(tmp)
    {
        /* Descend to the first remaining child, marking each node passed through. */
        while(tmp->child_list)
        {
            /* NOTE(review): label - 1 is already in [0, n-1] for labels 1..n; the % n looks defensive */
            hash[(tmp->val-1)%n]=tmp;
            size++;
            tmp=tmp->child_list;
        }

        /* tmp has no children left: its marked ancestors are exactly its root path. */
        countup(tmp);
        tmp1=tmp;
        tmp=tmp->parent;
        if(tmp)// && (tmp->child_list == tmp1))
        {
            /* Unmark the parent (re-marked above if it still has children) and
             * drop tmp1, which was the head of its child_list, from that list. */
            hash[(tmp->val-1)%n]=NULL;
            size--;
            tmp->child_list=tmp1->peer_next;
        }
        //printf("node = %3d (count = %d)\n",tmp1->val,count);
        free(tmp1);
    }
}

/*
 * Allocates a zero-initialised Node labelled `val`.
 * On allocation failure, prints a message to stderr and exits with
 * EXIT_FAILURE, instead of writing through a NULL pointer.
 */
Node* allocate(unsigned int val)
{
    Node *node=calloc(1, sizeof(Node));
    if(!node)
    {
        fprintf(stderr, "allocate: out of memory for node %u\n", val);
        exit(EXIT_FAILURE);
    }
    node->val=val;
    return node;
}
/*
 * Returns the node labelled `val` from bucket hash[val % n]. If no such node
 * exists yet, a new one is allocated and appended to the end of that bucket.
 */
Node* insert(unsigned int val)
{
    Node **slot = &hash[val % n];
    while (*slot)
    {
        if ((*slot)->val == val)
            return *slot;
        slot = &(*slot)->hash_next;
    }
    return (*slot = allocate(val));
}

/*
 * Prepends `child` to parent's child_list and records the back-link to
 * `parent`. Does nothing if either pointer is NULL.
 */
void connect(Node *parent, Node *child)
{
    if (parent && child)
    {
        child->peer_next = parent->child_list;
        child->parent = parent;
        parent->child_list = child;
    }
}

/*
 * Reads up to n - 1 "parent child" lines from stdin. Each endpoint is looked up
 * or created via insert() and the pair is linked with connect(). Reading stops
 * early if a line cannot be parsed.
 * Sets `root` to the topmost ancestor of the node labelled 1, or NULL when no
 * node was created (n <= 1).
 */
void build(){

    int a,b;
    unsigned int i;
    Node *parent,*child;

    root=NULL;
    if(n==0) return;   /* avoids n - 1 wrapping around and the % n below */
    for(i=0;i+1<n;i++)
    {
        if(scanf("%d %d",&a,&b)!=2)
            break;
        parent=insert(a);
        child=insert(b);
        connect(parent,child);
    }
    /* insert() files label v under bucket v % n, so label 1 lives in bucket
     * 1 % n (bucket 0 when n == 1, where hash[1] would be out of bounds). */
    root=hash[1 % n];
    while(root && root->parent) root=root->parent;
}

/*
 * Debug helper. Prints the subtree rooted at `node` one vertex per line,
 * indented two spaces per level, as "val (parent_val)". The parent value is
 * printed as 0 when there is no parent.
 */
void print(Node *node, int level)
{
    Node *child;
    int pad;

    if (node == NULL)
        return;
    for (pad = level; pad != 0; pad--)
        printf("  ");
    printf("%d (%d)\n", node->val, (node->parent != NULL) ? node->parent->val : 0);
    for (child = node->child_list; child != NULL; child = child->peer_next)
        print(child, level + 1);
}

/*
 * Reads "n T", builds the tree, counts similar pairs, and prints the count.
 * Returns 1 if the header cannot be read and -1 if the bucket array cannot
 * be allocated.
 */
int main(){
    count=0;
    /* n and T are unsigned int, so %u is the matching conversion. */
    if(scanf("%u %u",&n,&T)!=2) return 1;
    if(n==0)
    {
        /* No vertices means no pairs; this also avoids a zero-sized allocation. */
        printf("0\n");
        return 0;
    }
    /* calloc zero-fills the buckets, and the NULL check happens before the
     * array is touched. */
    hash=calloc(n, sizeof(Node*));
    if (!hash) return -1;
    build();
    solve();
    free(hash);
    printf("%llu\n",count);
    return 0;
}
                    

{"mode":"full","isActive":false}