In this HackerRank Kth Ancestor problem, Susan likes to play with graphs, and the tree data structure is one of her favorites. She has designed a problem and wants to know whether anyone can solve it: starting from a rooted tree, she sometimes adds or removes a leaf node. Your task is to find the Kth parent (ancestor) of a node at any instant, printing 0 when no such ancestor exists.

HackerRank Kth Ancestor problem solution


Problem solution in Python.

import sys
from collections import defaultdict, namedtuple
from array import array


# Query records built by read_instructions and dispatched by solve_queries.
AddNode = namedtuple('AddNode', ['child', 'parent'])      # attach leaf `child` under `parent`
RemoveNode = namedtuple('RemoveNode', ['node'])            # detach leaf `node`
QueryParent = namedtuple('QueryParent', ['node', 'kth'])   # ask for the kth ancestor of `node`


def log(msg):
    """Write *msg* (converted to text) followed by a newline to stderr."""
    sys.stderr.write(f"{msg}\n")


def find_all_leaf_nodes(tree, root=0):
    """Yield every leaf reachable from *root* in ``tree.children``.

    A node counts as a leaf when it has no children.  That includes nodes
    whose child set exists but is empty: ``Tree.children`` is a defaultdict,
    so an entry can outlive the removal of its last child.  Lookups use
    ``.get`` so no new defaultdict entries are created.  The traversal order
    is unspecified.
    """
    pending = [root]
    while pending:
        node = pending.pop()
        kids = tree.children.get(node)
        if kids:
            pending.extend(kids)
        else:
            yield node


def solve_queries(tree, queries):
    """Apply *queries* to *tree* in order, yielding one answer per QueryParent.

    AddNode inserts a leaf, RemoveNode deletes a leaf, and QueryParent
    yields ``tree.get_kth_parent(node, kth)``.  Objects of any other type
    are ignored.  Because this is a generator, the tree is only mutated as
    the caller consumes the results.
    """
    for q in queries:
        if isinstance(q, AddNode):
            tree.add_node(q.child, q.parent)
        elif isinstance(q, RemoveNode):
            tree.remove_leaf(q.node)
        elif isinstance(q, QueryParent):
            yield tree.get_kth_parent(q.node, q.kth)


def read_ints(reader):
    """Yield each non-blank line of *reader* as a tuple of ints.

    Blank (or whitespace-only) lines are skipped.  Downstream code indexes
    ``[0]`` and unpacks fixed-size rows, so an empty tuple would crash it.
    """
    for line in reader:
        fields = line.split()
        if fields:
            yield tuple(map(int, fields))


class Tree(object):
    """Rooted tree supporting leaf insertion/removal and k-th ancestor lookups.

    Node 0 is a virtual root at level 0; every real node records its parent
    and level.  To shorten ancestor walks, a node whose level is a multiple
    of 10/100/1000 (and strictly greater than that step) also caches a
    pointer to its 10th/100th/1000th ancestor.
    """

    def __init__(self):
        self.children = defaultdict(set)  # node -> set of child nodes
        self.parents = dict()             # node -> parent node (0 for top-level nodes)
        self.levels = dict()              # node -> depth; the virtual root 0 is level 0
        self.levels[0] = 0
        self.ten_p = dict()               # node -> 10th ancestor (sparse cache)
        self.hundred_p = dict()           # node -> 100th ancestor (sparse cache)
        self.thousand_p = dict()          # node -> 1000th ancestor (sparse cache)
        self.cache_hits = [0, 0, 0]       # jumps taken via ten_p, hundred_p, thousand_p

    def add_node(self, child, parent):
        """Attach *child* as a leaf under *parent*.

        Raises KeyError if *parent* is not in the tree.
        """
        level = self.levels[parent] + 1
        self.levels[child] = level
        self.children[parent].add(child)
        self.parents[child] = parent
        # Skip pointers are computed once, at insertion.  They stay valid
        # because only leaves are removed, so a live node's ancestors never change.
        if level > 10 and level % 10 == 0:
            self.ten_p[child] = self.get_kth_parent(child, 10)
        if level > 100 and level % 100 == 0:
            self.hundred_p[child] = self.get_kth_parent(child, 100)
        if level > 1000 and level % 1000 == 0:
            self.thousand_p[child] = self.get_kth_parent(child, 1000)

    def remove_leaf(self, node):
        """Detach leaf *node* and drop all bookkeeping kept for it.

        Raises KeyError (leaving the tree unchanged) if *node* is not a
        removable node, including the virtual root 0.
        """
        # Pop parents first: it is the lookup that fails for unknown nodes
        # and for 0, before anything has been modified.
        parent = self.parents.pop(node)
        self.levels.pop(node)
        self.children[parent].remove(node)
        # The defaultdict may still hold an empty child set for this leaf.
        if not self.children.get(node):
            self.children.pop(node, None)
        self.ten_p.pop(node, None)
        self.hundred_p.pop(node, None)
        self.thousand_p.pop(node, None)

    def get_kth_parent(self, node, max_back):
        """Return the *max_back*-th ancestor of *node*, or 0 if there is none.

        Unknown nodes (never added, or already removed) give 0.  Walking
        exactly ``level`` steps reaches the virtual root and also gives 0.
        """
        if node not in self.parents:
            return 0
        if self.levels[node] < max_back:
            return 0
        remaining = max_back
        while node != 0 and remaining != 0:
            # Take the largest cached jump that does not overshoot.
            if remaining >= 1000 and node in self.thousand_p:
                self.cache_hits[2] += 1
                node = self.thousand_p[node]
                remaining -= 1000
            elif remaining >= 100 and node in self.hundred_p:
                self.cache_hits[1] += 1
                node = self.hundred_p[node]
                remaining -= 100
            elif remaining >= 10 and node in self.ten_p:
                self.cache_hits[0] += 1
                node = self.ten_p[node]
                remaining -= 10
            else:
                node = self.parents[node]
                remaining -= 1
        return node


def _read_tree(int_lines, node_count):
    """Read *node_count* ``(child, parent)`` rows and build a Tree from them.

    Rows may arrive in any order (a child can be listed before its parent),
    so they are buffered first.  Nodes are then inserted top-down from the
    virtual root 0, which guarantees each parent exists before its children.

    Raises ValueError if some rows are not connected to root 0.
    """
    kids = defaultdict(list)
    for _ in range(node_count):
        child, parent = next(int_lines)
        kids[parent].append(child)
    tree = Tree()
    stack = [0]
    while stack:
        parent = stack.pop()
        # pop() visits each parent once, even on malformed (cyclic) input.
        for child in kids.pop(parent, ()):
            tree.add_node(child, parent)
            stack.append(child)
    if kids:
        raise ValueError('nodes not connected to root 0 under parents: %s' % sorted(kids))
    return tree


def _read_queries(int_lines, query_count):
    """Read *query_count* query rows and convert them to query records."""
    queries = list()
    for _ in range(query_count):
        vals = next(int_lines)
        kind = vals[0]
        if kind == 0:
            # input is "0 parent child"; AddNode takes (child, parent)
            queries.append(AddNode(vals[2], vals[1]))
        elif kind == 1:
            queries.append(RemoveNode(vals[1]))
        elif kind == 2:
            queries.append(QueryParent(vals[1], vals[2]))
        else:
            raise Exception('Do not know how to handle query of type %d in %s' % (vals[0], vals))
    return queries


def read_instructions(int_lines):
    """Yield ``(tree, queries)`` for every test case in *int_lines*.

    *int_lines* must be an iterator of int tuples (e.g. from read_ints).
    The layout is: case count; then, per case, a node count, that many
    ``child parent`` rows, a query count and that many query rows.
    Exactly the declared number of rows is consumed, so a count of 0
    reads nothing further.
    """
    number_cases = next(int_lines)[0]
    for _ in range(number_cases):
        nodes_in_tree = next(int_lines)[0]
        tree = _read_tree(int_lines, nodes_in_tree)
        number_queries = next(int_lines)[0]
        yield (tree, _read_queries(int_lines, number_queries))


def main():
    """Solve every test case read from stdin and print the answers.

    One line is printed per QueryParent query, in input order, with all
    output written in a single call.  The cache-hit counters of the last
    test case's tree are logged to stderr.
    """
    answers = []
    tree = None  # stays None when the input declares zero test cases
    for tree, queries in read_instructions(read_ints(sys.stdin)):
        # Consume the generator fully before the next case is read.
        answers.extend(solve_queries(tree, queries))
    if answers:
        sys.stdout.write('\n'.join(map(str, answers)) + '\n')
    if tree is not None:
        log('Cached: %s' % (tree.cache_hits))


if __name__ == '__main__':
    main()

{"mode":"full","isActive":false}


Problem solution in Java.

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;

/**
 * Kth Ancestor: answers k-th ancestor queries on a tree that gains and loses
 * leaves, using a binary-lifting table spar[d][v] = 2^d-th ancestor of v.
 */
public class Solution {
    static InputStream is;
    static PrintWriter out;
    // When non-empty, input is read from this string instead of stdin and tr() prints debug output.
    static String INPUT = "";
    
    /**
     * Processes every test case. For each one: reads m "child parent" edges,
     * roots the tree at node 0 via BFS, and builds spar (-1 = no such ancestor).
     * Then it handles Q queries: "0 y x" adds leaf x under y, "1 y" removes
     * leaf y, and "2 y K" prints the K-th ancestor of y (0 if it does not exist).
     */
    static void solve()
    {
        for(int T = ni(); T >= 1;T--){
            // Node ids are assumed to lie in [0, 100000] — TODO confirm against the problem constraints.
            int n = 100001;
            int m = ni();
            int[] from = new int[m];
            int[] to = new int[m];
            for(int i = 0;i < m;i++){
                from[i] = ni();
                to[i] = ni();
            }
            int[][] g = packU(n, from, to);
            // par[v] = parent of v when rooted at 0; -1 for 0 itself and for absent nodes.
            int[] par = parents(g, 0);
            
            // 17 lifting levels: the query loop below only looks at bits 0..16 of K.
            int[][] spar = new int[17][n];
            for(int i = 0;i < n;i++){
                spar[0][i] = par[i];
            }
            for(int d = 1;d < 17;d++){
                for(int i = 0;i < n;i++){
                    spar[d][i] = spar[d-1][i] == -1 ? -1 : spar[d-1][spar[d-1][i]];
                }
            }
            int Q = ni();
            for(int z = 0;z < Q;z++){
                int type = ni();
                if(type == 0){
                    // insert: x's row is derived from its parent y's row, which is already complete
                    int y = ni(), x = ni();
                    spar[0][x] = y;
                    for(int d = 1;d < 17;d++){
                        spar[d][x] = spar[d-1][x] == -1 ? -1 : spar[d-1][spar[d-1][x]];
                    }
                }else if(type == 1){
                    // remove: clearing y's row makes any later K > 0 query on y print 0
                    int y = ni();
                    for(int d = 0;d < 17;d++){
                        spar[d][y] = -1;
                    }
                }else if(type == 2){
                    // kth
                    int y = ni(), K = ni();
                    for(int d = 0;d < 17;d++){
                        // Parses as (K << (31-d)) < 0: true exactly when bit d of K is set.
                        if(K<<31-d<0){
                            y = spar[d][y];
                            if(y == -1)break;
                        }
                    }
                    // -1 means the ancestor does not exist; the problem expects 0.
                    if(y == -1)y = 0;
                    out.println(y);
                }
            }
        }
    }
    
    /**
     * Builds an undirected adjacency list for nodes 0..n-1: each edge
     * (from[i], to[i]) is stored in both endpoints' rows. Degrees are counted
     * first so every row is allocated once at its exact size.
     */
    static int[][] packU(int n, int[] from, int[] to) {
        int[][] g = new int[n][];
        int[] p = new int[n];
        for(int f : from)
            p[f]++;
        for(int t : to)
            p[t]++;
        for(int i = 0;i < n;i++)
            g[i] = new int[p[i]];
        for(int i = 0;i < from.length;i++){
            g[from[i]][--p[from[i]]] = to[i];
            g[to[i]][--p[to[i]]] = from[i];
        }
        return g;
    }
    
    /**
     * Breadth-first search from root over the undirected graph g. Returns
     * par[v] = v's parent in the BFS tree, or -1 for root and for nodes not
     * reachable from root. Assumes g is a forest (a cycle would re-enqueue
     * nodes and overflow q) — TODO confirm input is always a tree.
     */
    public static int[] parents(int[][] g, int root)
    {
        int n = g.length;
        int[] par = new int[n];
        Arrays.fill(par, -1);
        
        int[] q = new int[n];
        q[0] = root;
        for(int p = 0, r = 1;p < r;p++) {
            int cur = q[p];
            for(int nex : g[cur]){
                // skip the edge leading back to the node we came from
                if(par[cur] != nex){
                    q[r++] = nex;
                    par[nex] = cur;
                }
            }
        }
        return par;
    }
    
    // Entry point: reads from stdin (or INPUT), runs solve(), and reports elapsed time via tr().
    public static void main(String[] args) throws Exception
    {
        long S = System.currentTimeMillis();
        is = INPUT.isEmpty() ? System.in : new ByteArrayInputStream(INPUT.getBytes());
        out = new PrintWriter(System.out);
        
        solve();
        out.flush();
        long G = System.currentTimeMillis();
        tr(G-S+"ms");
    }
    
    // Returns true if no non-whitespace input remains. It checks the unread buffer,
    // then peeks the stream using mark/reset. Unused in this solution.
    private static boolean eof()
    {
        if(lenbuf == -1)return true;
        int lptr = ptrbuf;
        while(lptr < lenbuf)if(!isSpaceChar(inbuf[lptr++]))return false;
        
        try {
            is.mark(1000);
            while(true){
                int b = is.read();
                if(b == -1){
                    is.reset();
                    return true;
                }else if(!isSpaceChar(b)){
                    is.reset();
                    return false;
                }
            }
        } catch (IOException e) {
            return true;
        }
    }
    
    // Input buffer: bytes [ptrbuf, lenbuf) are still unread; lenbuf becomes -1 once the stream reports its end.
    private static byte[] inbuf = new byte[1024];
    static int lenbuf = 0, ptrbuf = 0;
    
    // Returns the next input byte, refilling inbuf as needed; returns -1 at end of input.
    // Once the stream has reported -1, any further call throws InputMismatchException.
    private static int readByte()
    {
        if(lenbuf == -1)throw new InputMismatchException();
        if(ptrbuf >= lenbuf){
            ptrbuf = 0;
            try { lenbuf = is.read(inbuf); } catch (IOException e) { throw new InputMismatchException(); }
            if(lenbuf <= 0)return -1;
        }
        return inbuf[ptrbuf++];
    }
    
    // Any byte outside printable ASCII 33..126 counts as whitespace.
    private static boolean isSpaceChar(int c) { return !(c >= 33 && c <= 126); }
    // Skips whitespace; returns the first non-whitespace byte, or -1 at end of input.
    private static int skip() { int b; while((b = readByte()) != -1 && isSpaceChar(b)); return b; }
    
    // Reads the next token as a double (unused here).
    private static double nd() { return Double.parseDouble(ns()); }
    // Reads the next non-whitespace character (unused here).
    private static char nc() { return (char)skip(); }
    
    // Reads the next whitespace-delimited token (unused here).
    private static String ns()
    {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        while(!(isSpaceChar(b))){ // when nextLine, (isSpaceChar(b) && b != ' ')
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }
    
    // Reads up to n characters of the next token; the array is trimmed if the token is shorter.
    private static char[] ns(int n)
    {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while(p < n && !(isSpaceChar(b))){
            buf[p++] = (char)b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }
    
    // Reads an n-row character grid, each row up to m characters (unused here).
    private static char[][] nm(int n, int m)
    {
        char[][] map = new char[n][];
        for(int i = 0;i < n;i++)map[i] = ns(m);
        return map;
    }
    
    // Reads n ints into a new array (unused here).
    private static int[] na(int n)
    {
        int[] a = new int[n];
        for(int i = 0;i < n;i++)a[i] = ni();
        return a;
    }
    
    // Parses the next optionally negative int, skipping any bytes that are neither
    // digits nor '-'. Returns 0 at end of input. Overflow is not checked.
    private static int ni()
    {
        int num = 0, b;
        boolean minus = false;
        while((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'));
        if(b == '-'){
            minus = true;
            b = readByte();
        }
        
        while(true){
            if(b >= '0' && b <= '9'){
                num = num * 10 + (b - '0');
            }else{
                return minus ? -num : num;
            }
            b = readByte();
        }
    }
    
    // long version of ni() (unused here).
    private static long nl()
    {
        long num = 0;
        int b;
        boolean minus = false;
        while((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'));
        if(b == '-'){
            minus = true;
            b = readByte();
        }
        
        while(true){
            if(b >= '0' && b <= '9'){
                num = num * 10 + (b - '0');
            }else{
                return minus ? -num : num;
            }
            b = readByte();
        }
    }
    
    // Debug print; active only when INPUT is non-empty (local testing).
    private static void tr(Object... o) { if(INPUT.length() != 0)System.out.println(Arrays.deepToString(o)); }
}

{"mode":"full","isActive":false}


Problem solution in C++.

#include<iostream>
#include<cstdio>
#include<cmath>
#include<string>
#include<algorithm>
#include<vector>
using namespace std;

// Node ids must be < MAXN; MAXK lifting levels cover jumps up to 2^MAXK - 1.
const int MAXN = 100100;
const int MAXK = 20;
// parent[v][j] = 2^j-th ancestor of v, or 0 when there is none.
// A row of zeros also means "v is not in the tree"; row 0 always stays 0.
int parent[MAXN][MAXK];

// Processes one test case from cin:
//   N lines "X Y"   -> Y is the parent of X (Y == 0 for the root)
//   Q query lines:
//     "0 Y X"       -> add leaf X under Y
//     "1 X"         -> remove leaf X
//     "2 X K"       -> print the K-th ancestor of X, or 0 if it does not exist
void solve() {
    // Clear the table left over from the previous test case.
    for(int i=0; i<MAXN; i++)
        for(int j=0; j<MAXK; j++)
            parent[i][j] = 0;

    int N;
    cin>>N;
    for(int i=0; i<N; i++) {
        int x,y;
        cin>>x>>y;
        parent[x][0] = y;
    }
    // Level j is filled for every node from the complete level j-1, so the
    // order in which edges appeared in the input does not matter.
    for(int j=1; j<MAXK; j++)
        for(int v=1; v<MAXN; v++)
            parent[v][j] = parent[parent[v][j-1]][j-1];

    int Q;
    cin>>Q;
    for(int q=0; q<Q; q++) {
        int kind;
        cin>>kind;
        if(kind==0) {
            int x,y;
            cin>>y>>x;
            parent[x][0] = y;
            // y's row is already complete, so x's row follows directly from it.
            for(int j=1; j<MAXK; j++)
                parent[x][j] = parent[parent[x][j-1]][j-1];
        } else if(kind==1) {
            int x;
            cin>>x;
            // x is a leaf, so no other row points at it; clearing its own row suffices.
            for(int j=0; j<MAXK; j++)
                parent[x][j] = 0;
        } else if(kind==2) {
            int x,k;
            cin>>x>>k;
            // Jump 2^j levels for every set bit j of k; once x reaches 0 it stays 0.
            for(int j=0; j<MAXK && k>0 && x!=0; j++, k>>=1)
                if(k & 1)
                    x = parent[x][j];
            // Bits left over mean k >= 2^MAXK > MAXN (deeper than any storable
            // tree) or k < 0; either way there is no such ancestor.
            if(k != 0)
                x = 0;
            cout<<x<<'\n';
        }
    }
}
// Reads the number of test cases T and solves each one in turn.
int main() {
    // solve() reads and writes only through cin/cout (no scanf/printf), so it is
    // safe to detach the C++ streams from C stdio and untie cout from cin,
    // which makes large inputs much faster.
    ios::sync_with_stdio(false);
    cin.tie(0);
    int T;
    cin>>T;
    for(int i=0; i<T; i++)
        solve();
    return 0;
}

{"mode":"full","isActive":false}


Problem solution in C.

#include <assert.h>
#include <limits.h>
#include <math.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#define MAXNODE 300000
#define LOGMAX 20

char* readline();
char** split_string(char*);

/*
 * Answers the queries of one test case using a binary-lifting table.
 *   n, edges     : initial tree as n pairs {child, parent}; parent 0 marks the root
 *   q, queries   : raw query lines "0 Y X" (add leaf X under Y), "1 X" (remove
 *                  leaf X), "2 X K" (K-th ancestor of X). They are tokenized in
 *                  place by split_string, so the strings are modified.
 *   result_count : set to the number of type-2 answers
 * Returns a malloc'd array with one answer per type-2 query (0 when the
 * ancestor does not exist); the caller frees it. Exits on an unknown query
 * type or on allocation failure.
 */
int* kthparent(int n, int** edges, int q, char** queries, int* result_count){
    /* pow2parent[j][v] = 2^j-th ancestor of v, 0 when there is none.
       LOGMAX * MAXNODE ints is far too large for the stack, so the table is
       allocated on the heap; calloc also zero-fills it. */
    int (*pow2parent)[MAXNODE] = calloc(LOGMAX, sizeof *pow2parent);
    /* There are at most q answers, so one allocation suffices (+1 avoids malloc(0)). */
    int *toreturn = malloc(((size_t)q + 1) * sizeof(int));
    if(pow2parent == NULL || toreturn == NULL){
        exit(EXIT_FAILURE);
    }

    for(int i = 0; i < n; i++){
        pow2parent[0][edges[i][0]] = edges[i][1];
    }
    /* Level i is built from the complete level i-1, so edge order does not matter. */
    for(int i = 1; i < LOGMAX; i++){
        for(int j = 0; j < MAXNODE; j++){
            pow2parent[i][j] = pow2parent[i - 1][pow2parent[i - 1][j]];
        }
    }

    *result_count = 0;
    for(int i = 0; i < q; i++){
        char** splitquery = split_string(queries[i]);
        if(splitquery == NULL){
            /* empty line or allocation failure */
            exit(EXIT_FAILURE);
        }
        if(queries[i][0] == '0'){
            int parent = atoi(splitquery[1]);
            int leaf = atoi(splitquery[2]);
            pow2parent[0][leaf] = parent;
            /* the parent's row is complete, so the leaf's row follows directly */
            for(int j = 1; j < LOGMAX; j++){
                pow2parent[j][leaf] = pow2parent[j - 1][pow2parent[j - 1][leaf]];
            }
        }
        else if(queries[i][0] == '1'){
            int leaf = atoi(splitquery[1]);
            for(int j = 0; j < LOGMAX; j++){
                pow2parent[j][leaf] = 0;
            }
        }
        else if(queries[i][0] == '2'){
            int currnode = atoi(splitquery[1]);
            int target = atoi(splitquery[2]);
            if(target < 0 || (target >> LOGMAX) != 0){
                /* Bits beyond the table: assuming 2^LOGMAX > MAXNODE, no node is that deep. */
                currnode = 0;
            }
            else{
                /* jump 2^j levels for each set bit j; node 0 maps to itself */
                for(int j = 0; j < LOGMAX && currnode != 0; j++){
                    if(((target >> j) & 1) == 1){
                        currnode = pow2parent[j][currnode];
                    }
                }
            }
            toreturn[(*result_count)++] = currnode;
        }
        else{
            exit(EXIT_FAILURE);
        }
        /* tokens point into queries[i]; only the pointer array is owned here */
        free(splitquery);
    }
    free(pow2parent);
    return toreturn;
}

/*
 * Reads T test cases from stdin. Each case has n "child parent" lines, a
 * query count q and q query lines. Prints one line per type-2 answer.
 * Returns EXIT_FAILURE if a count or edge cannot be parsed.
 */
int main() {

    int t;
    /* "\n" in the format skips all whitespace, so readline starts on the next line */
    if(scanf("%d\n", &t) != 1){
        return EXIT_FAILURE;
    }
    for(int i = 0; i < t; i++){
        int n;
        if(scanf("%d\n", &n) != 1){
            return EXIT_FAILURE;
        }
        int** edges = malloc(n*sizeof(int*));
        for(int j = 0; j < n; j++){
            edges[j] = malloc(2*sizeof(int));
            if(scanf("%d %d\n", edges[j], edges[j] + 1) != 2){
                return EXIT_FAILURE;
            }
        }
        int q;
        if(scanf("%d\n", &q) != 1){
            return EXIT_FAILURE;
        }
        char** queries = malloc(q*sizeof(char*));
        for(int j = 0; j < q; j++){
            queries[j] = readline();
        }

        int result_count;
        int* result = kthparent(n, edges, q, queries, &result_count);
        for(int j = 0; j < result_count; j++){
            printf("%d\n", result[j]);
        }

        /* release everything allocated for this test case */
        free(result);
        for(int j = 0; j < q; j++){
            free(queries[j]);
        }
        free(queries);
        for(int j = 0; j < n; j++){
            free(edges[j]);
        }
        free(edges);
    }
    return 0;
}


/*
 * Reads one line from stdin into a freshly malloc'd, NUL-terminated buffer,
 * growing the buffer as needed and stripping the trailing '\n'. The caller
 * frees the result. At end of input it returns an empty string; it returns
 * NULL only if memory allocation fails.
 */
char* readline() {
    size_t alloc_length = 1024;
    size_t data_length = 0;
    char* data = malloc(alloc_length);
    if (!data) { return NULL; }
    /* fgets leaves the buffer untouched when it hits EOF immediately */
    data[0] = '\0';

    while (true) {
        char* cursor = data + data_length;
        char* line = fgets(cursor, alloc_length - data_length, stdin);

        if (!line) { break; }

        data_length += strlen(cursor);

        if (data_length < alloc_length - 1 || data[data_length - 1] == '\n') { break; }

        /* buffer filled without reaching end of line: double it and keep reading */
        size_t new_length = alloc_length << 1;
        char* grown = realloc(data, new_length);
        if (!grown) {
            free(data);
            return NULL;
        }
        data = grown;
        alloc_length = new_length;
    }

    if (data_length > 0 && data[data_length - 1] == '\n') {
        data[--data_length] = '\0';
    }

    /* shrink to fit, keeping room for the terminator; on failure keep the larger buffer */
    char* shrunk = realloc(data, data_length + 1);
    return shrunk ? shrunk : data;
}

/*
 * Splits str in place on spaces using strtok, so str is modified and the
 * returned tokens point into it. Returns a malloc'd array of token pointers
 * followed by a NULL entry; the caller frees the array but not the tokens.
 * Returns NULL when str has no tokens or allocation fails.
 */
char** split_string(char* str) {
    char** splits = NULL;
    char* token = strtok(str, " ");

    int spaces = 0;

    while (token) {
        /* one extra slot keeps the array NULL-terminated */
        char** grown = realloc(splits, sizeof(char*) * (spaces + 2));
        if (!grown) {
            free(splits);
            return NULL;
        }
        splits = grown;

        splits[spaces++] = token;
        splits[spaces] = NULL;

        token = strtok(NULL, " ");
    }

    return splits;
}

{"mode":"full","isActive":false}