Header Ad

HackerRank Kth Ancestor problem solution

In the HackerRank Kth Ancestor problem, Susan likes to play with graphs, and the tree data structure is one of her favorites. She has designed a problem and wants to know whether anyone can solve it: from time to time she adds or removes a leaf node, and your task is to find the Kth parent (ancestor) of a given node at any instant.

HackerRank Kth Ancestor problem solution


Problem solution in Python.

import sys
from collections import defaultdict, namedtuple
from array import array


# Query records built by read_instructions and consumed by solve_queries.
AddNode = namedtuple('AddNode', ['child', 'parent'])     # attach `child` below `parent`
RemoveNode = namedtuple('RemoveNode', ['node'])           # detach the leaf `node`
QueryParent = namedtuple('QueryParent', ['node', 'kth'])  # ask for the kth ancestor of `node`


def log(msg):
    """Emit one diagnostic line on stderr so stdout carries only answers."""
    stream = sys.stderr
    print(msg, file=stream)


def find_all_leaf_nodes(tree):
    """Yield every leaf reachable from the sentinel root ``0`` of ``tree``.

    A node is a leaf when it has no *remaining* children.  ``tree.children`` may
    map a node to an empty set after all of its children were removed, so a
    plain membership test would wrongly treat such a node as internal.

    Nodes are yielded in depth-first order; ``tree`` is not modified.
    """
    stack = [0]
    while stack:
        node = stack.pop()
        # .get() (not [] or `in`) so a defaultdict is never grown by the lookup
        # and an empty child set still counts as "no children".
        kids = tree.children.get(node)
        if kids:
            stack.extend(kids)
        else:
            yield node


def solve_queries(tree, queries):
    """Apply ``queries`` to ``tree`` in order, yielding one answer per ``QueryParent``.

    ``AddNode`` and ``RemoveNode`` mutate ``tree`` and produce no output.  This is
    a generator, so the mutations happen only as the caller consumes it.

    Raises:
        TypeError: if a query is not an AddNode, RemoveNode or QueryParent.
    """
    for q in queries:
        if isinstance(q, AddNode):
            tree.add_node(q.child, q.parent)
        elif isinstance(q, RemoveNode):
            tree.remove_leaf(q.node)
        elif isinstance(q, QueryParent):
            yield tree.get_kth_parent(q.node, q.kth)
        else:
            # Previously ignored silently, which hid malformed query lists.
            raise TypeError('Unknown query %r' % (q,))


def read_ints(reader):
    """Yield each non-blank line of ``reader`` as a tuple of ints.

    Whitespace-only lines are skipped.  Otherwise they would become empty
    tuples, and callers that index the result with ``[0]`` would fail.
    """
    for line in reader:
        fields = line.split()
        if fields:
            yield tuple(map(int, fields))


class Tree(object):
    """Rooted tree supporting leaf insertion/removal and k-th ancestor queries.

    Node ``0`` is a sentinel above the real root.  It is at level 0, and the
    real root (added with parent ``0``) is at level 1.  A k-th ancestor that
    would lie above the real root is therefore reported as ``0``.

    To shorten long walks, a node whose level is a multiple of 10, 100 or 1000
    (and strictly greater than it) remembers its 10th, 100th or 1000th ancestor
    in ``ten_p``, ``hundred_p`` or ``thousand_p``.  These shortcuts stay valid
    because only leaves are removed, so no live node ever loses an ancestor.
    """

    def __init__(self):
        self.children = defaultdict(set)  # node -> set of direct children
        self.parents = dict()             # node -> direct parent (0 for the real root)
        self.levels = {0: 0}              # node -> depth; the sentinel 0 is level 0
        self.ten_p = dict()               # node -> its 10th ancestor
        self.hundred_p = dict()           # node -> its 100th ancestor
        self.thousand_p = dict()          # node -> its 1000th ancestor
        self.cache_hits = [0, 0, 0]       # shortcut uses: [10-step, 100-step, 1000-step]

    def add_node(self, child, parent):
        """Attach ``child`` below ``parent``, which must already be in the tree.

        Raises:
            KeyError: if ``parent`` has not been added yet.
        """
        level = self.levels[parent] + 1
        self.levels[child] = level
        self.children[parent].add(child)
        self.parents[child] = parent
        # All ancestors of `child` already exist, so its shortcuts can be
        # computed now.  Finer shortcuts are stored first so the coarser
        # computations below can reuse them.
        if level > 10 and level % 10 == 0:
            self.ten_p[child] = self.get_kth_parent(child, 10)
        if level > 100 and level % 100 == 0:
            self.hundred_p[child] = self.get_kth_parent(child, 100)
        if level > 1000 and level % 1000 == 0:
            self.thousand_p[child] = self.get_kth_parent(child, 1000)

    def remove_leaf(self, node):
        """Detach the leaf ``node`` and forget any shortcuts stored for it.

        Raises:
            KeyError: if ``node`` is not in the tree.
        """
        self.levels.pop(node)
        parent = self.parents.pop(node)
        self.children[parent].remove(node)
        # Shortcut entries exist only for some levels; pop with a default
        # instead of try/except-pass.
        self.ten_p.pop(node, None)
        self.hundred_p.pop(node, None)
        self.thousand_p.pop(node, None)

    def get_kth_parent(self, node, max_back):
        """Return the ``max_back``-th ancestor of ``node``, or 0 if there is none.

        Returns 0 when ``node`` is not in the tree (including the sentinel 0)
        or when ``max_back`` exceeds the level of ``node``.
        """
        if node not in self.parents:
            return 0
        if self.levels[node] < max_back:
            return 0
        remaining = max_back
        while node != 0 and remaining != 0:
            # `>=` (not `>`) so an exact 1000/100/10 step can still take a shortcut.
            if remaining >= 1000 and node in self.thousand_p:
                self.cache_hits[2] += 1
                node = self.thousand_p[node]
                remaining -= 1000
            elif remaining >= 100 and node in self.hundred_p:
                self.cache_hits[1] += 1
                node = self.hundred_p[node]
                remaining -= 100
            elif remaining >= 10 and node in self.ten_p:
                self.cache_hits[0] += 1
                node = self.ten_p[node]
                remaining -= 10
            else:
                node = self.parents[node]
                remaining -= 1
        return node


def read_instructions(int_lines):
    """Parse test cases from ``int_lines`` into ``(tree, queries)`` pairs.

    ``int_lines`` is an iterator of int tuples, as produced by ``read_ints``.
    The first tuple holds the number of test cases.  Each case is a node count
    P, P ``(child, parent)`` tuples, a query count Q, and Q query tuples:
    ``(0, parent, child)`` adds a leaf, ``(1, node)`` removes a leaf and
    ``(2, node, k)`` asks for the k-th ancestor.

    The initial edges may be listed in any order (a child may appear before its
    parent), so they are inserted breadth-first from the sentinel parent ``0``.
    Edges not connected to ``0`` are skipped.

    Raises:
        ValueError: for an unknown query type.
    """
    from collections import deque

    number_cases = next(int_lines)[0]
    for _case in range(number_cases):
        nodes_in_tree = next(int_lines)[0]
        # Read exactly P lines; the previous enumerate/break loop never broke
        # when P == 0 and swallowed the rest of the input.
        edges = [next(int_lines) for _edge in range(nodes_in_tree)]
        children_of = defaultdict(list)
        for child, parent in edges:
            children_of[parent].append(child)

        tree = Tree()
        pending = deque([0])
        while pending:
            parent = pending.popleft()
            # pop() so each parent is expanded at most once, even on bad input.
            for child in children_of.pop(parent, ()):
                tree.add_node(child, parent)
                pending.append(child)

        number_queries = next(int_lines)[0]
        queries = list()
        for _query in range(number_queries):
            vals = next(int_lines)
            if vals[0] == 0:
                # notice reversal, to make same as input
                queries.append(AddNode(vals[2], vals[1]))
            elif vals[0] == 1:
                queries.append(RemoveNode(vals[1]))
            elif vals[0] == 2:
                queries.append(QueryParent(vals[1], vals[2]))
            else:
                raise ValueError('Do not know how to handle query of type %d in %s' % (vals[0], vals))
        yield (tree, queries)


def main():
    """Read every test case from stdin and print one answer per ancestor query."""
    for tree, queries in read_instructions(read_ints(sys.stdin)):
        # One write per test case instead of one print() per answer.
        sys.stdout.write(''.join('{}\n'.format(answer)
                                 for answer in solve_queries(tree, queries)))
        # Logged inside the loop: with zero test cases `tree` was never bound
        # and the old post-loop log raised UnboundLocalError.
        log('Cached: %s' % (tree.cache_hits,))


if __name__ == '__main__':
    main()

{"mode":"full","isActive":false}


Problem solution in Java.

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;

public class Solution {
    static InputStream is;
    static PrintWriter out;
    static String INPUT = "";

    /** Largest node id; ids run 1..MAX_NODE and 0 is the sentinel parent of the root. */
    static final int MAX_NODE = 100000;
    /** Binary-lifting levels; 2^LOG exceeds the deepest possible chain of MAX_NODE nodes. */
    static final int LOG = 17;

    /**
     * Reads every test case from {@link #is} and answers its queries on {@link #out}.
     *
     * spar[d][v] is the 2^d-th ancestor of v, or -1 when it does not exist
     * (v absent, or the walk passes above the sentinel 0). Query lines are
     * "0 Y X" (add leaf X under Y), "1 X" (remove leaf X) and
     * "2 X K" (print the K-th ancestor of X, or 0 if there is none).
     */
    static void solve()
    {
        for(int T = ni(); T >= 1;T--){
            int n = MAX_NODE + 1;
            int m = ni();
            int[] from = new int[m];
            int[] to = new int[m];
            for(int i = 0;i < m;i++){
                from[i] = ni();
                to[i] = ni();
            }
            // Edges may arrive in any order, so orient them with a BFS from the sentinel 0.
            int[][] g = packU(n, from, to);
            int[] par = parents(g, 0);

            int[][] spar = new int[LOG][];
            spar[0] = par;
            for(int d = 1;d < LOG;d++){
                spar[d] = new int[n];
                for(int i = 0;i < n;i++){
                    spar[d][i] = spar[d-1][i] == -1 ? -1 : spar[d-1][spar[d-1][i]];
                }
            }
            int Q = ni();
            for(int z = 0;z < Q;z++){
                int type = ni();
                if(type == 0){
                    // insert: the new leaf's ancestors already have complete rows
                    int y = ni(), x = ni();
                    spar[0][x] = y;
                    for(int d = 1;d < LOG;d++){
                        spar[d][x] = spar[d-1][x] == -1 ? -1 : spar[d-1][spar[d-1][x]];
                    }
                }else if(type == 1){
                    // remove: a leaf is nobody's ancestor, so clearing its own row suffices
                    int y = ni();
                    for(int d = 0;d < LOG;d++){
                        spar[d][y] = -1;
                    }
                }else if(type == 2){
                    // kth
                    int y = ni(), K = ni();
                    out.println(kthAncestor(spar, y, K));
                }
            }
        }
    }

    /**
     * Returns the K-th ancestor of y, or 0 if it does not exist.
     * A K with any bit at or above spar.length exceeds every possible depth, so it
     * is answered 0. Previously those high bits were silently ignored, which
     * returned a too-shallow ancestor.
     */
    private static int kthAncestor(int[][] spar, int y, int K)
    {
        if(K >= 1 << spar.length)return 0;
        for(int d = 0;d < spar.length && y != -1;d++){
            if((K >>> d & 1) == 1)y = spar[d][y];
        }
        return y == -1 ? 0 : y;
    }

    /** Builds an undirected adjacency list for n vertices from parallel edge arrays. */
    static int[][] packU(int n, int[] from, int[] to) {
        int[][] g = new int[n][];
        int[] p = new int[n];
        for(int f : from)
            p[f]++;
        for(int t : to)
            p[t]++;
        for(int i = 0;i < n;i++)
            g[i] = new int[p[i]];
        for(int i = 0;i < from.length;i++){
            g[from[i]][--p[from[i]]] = to[i];
            g[to[i]][--p[to[i]]] = from[i];
        }
        return g;
    }

    /** BFS from root over a tree; returns each vertex's parent, -1 for root and unreached vertices. */
    public static int[] parents(int[][] g, int root)
    {
        int n = g.length;
        int[] par = new int[n];
        Arrays.fill(par, -1);

        int[] q = new int[n];
        q[0] = root;
        for(int p = 0, r = 1;p < r;p++) {
            int cur = q[p];
            for(int nex : g[cur]){
                if(par[cur] != nex){
                    q[r++] = nex;
                    par[nex] = cur;
                }
            }
        }
        return par;
    }

    public static void main(String[] args) throws Exception
    {
        long S = System.currentTimeMillis();
        is = INPUT.isEmpty() ? System.in : new ByteArrayInputStream(INPUT.getBytes());
        out = new PrintWriter(System.out);

        solve();
        out.flush();
        long G = System.currentTimeMillis();
        tr(G-S+"ms");
    }

    private static byte[] inbuf = new byte[1024];
    static int lenbuf = 0, ptrbuf = 0;

    /** Returns the next input byte, or -1 once the stream is exhausted. */
    private static int readByte()
    {
        if(lenbuf == -1)throw new InputMismatchException();
        if(ptrbuf >= lenbuf){
            ptrbuf = 0;
            try { lenbuf = is.read(inbuf); } catch (IOException e) { throw new InputMismatchException(); }
            if(lenbuf <= 0)return -1;
        }
        return inbuf[ptrbuf++];
    }

    /** Reads the next (optionally negative) decimal int, skipping any non-digit separators. */
    private static int ni()
    {
        int num = 0, b;
        boolean minus = false;
        while((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'));
        if(b == '-'){
            minus = true;
            b = readByte();
        }

        while(true){
            if(b >= '0' && b <= '9'){
                num = num * 10 + (b - '0');
            }else{
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    /** Debug trace, printed only when running against the embedded INPUT. */
    private static void tr(Object... o) { if(INPUT.length() != 0)System.out.println(Arrays.deepToString(o)); }
}

{"mode":"full","isActive":false}


Problem solution in C++.

#include<iostream>
#include<cstdio>
#include<cmath>
#include<string>
#include<algorithm>
#include<vector>
using namespace std;

const int MAXN = 100100;  // node ids are < MAXN; 0 is the sentinel "no parent"
const int MAXK = 20;      // binary-lifting levels; 2^MAXK exceeds any possible depth
// parent[v][j] is the 2^j-th ancestor of v, or 0 when it does not exist.
int parent[MAXN][MAXK];
// cntChild[v] counts the live children of v.
int cntChild[MAXN];

// Returns the k-th ancestor of x, or 0 if x is out of range or has no such ancestor.
// Any k >= 2^MAXK is deeper than the tree can be, so it is answered 0 directly
// instead of indexing past the end of parent[x] (the old highest-bit search did).
static int kthAncestor(int x, int k) {
    if (x < 0 || x >= MAXN || k < 0 || k >= (1 << MAXK))
        return 0;
    // Power-of-two jumps commute, so the set bits of k can be applied low to high.
    for (int j = 0; j < MAXK && x != 0; j++) {
        if ((k >> j) & 1)
            x = parent[x][j];
    }
    return x;
}

// Reads one test case from std::cin and writes one line per type-2 query to std::cout.
void solve() {
    for (int i = 0; i < MAXN; i++) {
        for (int j = 0; j < MAXK; j++)
            parent[i][j] = 0;
        cntChild[i] = 0;
    }
    int N;
    std::cin >> N;
    for (int i = 0; i < N; i++) {
        int x, y;
        std::cin >> x >> y;  // y is the parent of x (0 for the root)
        parent[x][0] = y;
        cntChild[y]++;
    }
    // Level i is built from level i-1 only after every direct parent is known,
    // so the initial edges may arrive in any order.
    for (int i = 1; i < MAXK; i++)
        for (int v = 1; v < MAXN; v++)
            parent[v][i] = parent[parent[v][i - 1]][i - 1];

    int Q;
    std::cin >> Q;
    for (int q = 0; q < Q; q++) {
        int kind;
        std::cin >> kind;
        if (kind == 0) {
            // "0 Y X": attach new leaf X below Y; all of X's ancestors already exist.
            int x, y;
            std::cin >> y >> x;
            parent[x][0] = y;
            cntChild[y]++;
            for (int i = 1; i < MAXK; i++)
                parent[x][i] = parent[parent[x][i - 1]][i - 1];
        } else if (kind == 1) {
            // "1 X": remove leaf X. The problem guarantees X is a leaf, so no other
            // row references it. The previous `for(;;);` check hung forever otherwise.
            int x;
            std::cin >> x;
            cntChild[parent[x][0]]--;
            for (int j = 0; j < MAXK; j++)
                parent[x][j] = 0;
        } else if (kind == 2) {
            int x, k;
            std::cin >> x >> k;
            std::cout << kthAncestor(x, k) << '\n';  // '\n': no flush per answer
        }
    }
}
// Reads the number of test cases T, then runs solve() once per case.
int main() {
    // All I/O goes through iostreams, so detach from C stdio and untie cin
    // from cout to keep large inputs fast.
    std::ios::sync_with_stdio(false);
    std::cin.tie(0);

    int T = 0;
    if (!(std::cin >> T))
        return 0;  // empty input: previously T was read uninitialized
    for (int i = 0; i < T; i++)
        solve();

    return 0;
}

{"mode":"full","isActive":false}


Problem solution in C.

#include <assert.h>
#include <limits.h>
#include <math.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#define MAXNODE 300000
#define LOGMAX 20

char* readline();
char** split_string(char*);

/* True when v can index the ancestor table. */
static bool kth_valid_node(int v) {
    return v >= 0 && v < MAXNODE;
}

/*
 * Answers kth-ancestor queries with a binary-lifting table.
 *
 * n, edges:     n initial (child, parent) pairs; parent 0 marks the root.
 * q, queries:   q query lines: "0 Y X" adds leaf X under Y, "1 X" removes
 *               leaf X, "2 X K" reports the K-th ancestor of X (0 if none).
 * result_count: set to the number of answers returned.
 *
 * Returns a heap array with one answer per "2" query, which the caller frees.
 * Returns NULL if the table cannot be allocated; *result_count is then 0.
 * Exits with EXIT_FAILURE on a malformed or unknown query.
 */
int* kthparent(int n, int** edges, int q, char** queries, int* result_count){
    /* up[j][v] is the 2^j-th ancestor of v (0 = none). The table is about
       24 MB with 4-byte ints. It used to be a local array, which overflows a
       typical stack, so it now lives on the heap. */
    int (*up)[MAXNODE] = calloc(LOGMAX, sizeof *up);
    int *answers = malloc((size_t)(q > 0 ? q : 1) * sizeof *answers);
    *result_count = 0;
    if (!up || !answers) {
        free(up);
        free(answers);
        return NULL;
    }

    for (int i = 0; i < n; i++) {
        int child = edges[i][0];
        int parent = edges[i][1];
        if (kth_valid_node(child) && kth_valid_node(parent)) {
            up[0][child] = parent;
        }
    }
    /* Level j is filled only after level j-1 is complete, so the initial
       edges may come in any order. */
    for (int j = 1; j < LOGMAX; j++) {
        for (int v = 0; v < MAXNODE; v++) {
            up[j][v] = up[j - 1][up[j - 1][v]];
        }
    }

    for (int i = 0; i < q; i++) {
        int type = -1, a = 0, b = 0;
        /* sscanf leaves the query text intact. split_string's token array
           was leaked on every query. */
        int fields = sscanf(queries[i], "%d %d %d", &type, &a, &b);
        bool ok = fields >= 2 && (type == 1 || fields >= 3);
        if (!ok || type < 0 || type > 2) {
            free(up);
            free(answers);
            exit(EXIT_FAILURE);
        }

        if (type == 0) {
            int parent = a, leaf = b;
            if (kth_valid_node(parent) && kth_valid_node(leaf)) {
                up[0][leaf] = parent;
                for (int j = 1; j < LOGMAX; j++) {
                    up[j][leaf] = up[j - 1][up[j - 1][leaf]];
                }
            }
        } else if (type == 1) {
            if (kth_valid_node(a)) {
                for (int j = 0; j < LOGMAX; j++) {
                    up[j][a] = 0;
                }
            }
        } else {
            int node = a, k = b;
            /* Any k >= 2^LOGMAX is deeper than MAXNODE nodes can reach. The
               old loop silently ignored those high bits. */
            if (!kth_valid_node(node) || k < 0 || k >= (1 << LOGMAX)) {
                node = 0;
            }
            for (int j = 0; j < LOGMAX && node != 0; j++) {
                if ((k >> j) & 1) {
                    node = up[j][node];
                }
            }
            answers[(*result_count)++] = node;
        }
    }

    free(up);
    return answers;
}

/*
 * Reads t test cases: n edge lines "child parent", then q query lines.
 * Prints one answer per ancestor query and releases every allocation per case.
 */
int main() {

    int t;
    if (scanf("%d\n", &t) != 1) {
        return 0;
    }
    for (int i = 0; i < t; i++) {
        int n;
        if (scanf("%d\n", &n) != 1 || n < 0) {
            return EXIT_FAILURE;
        }
        /* One block for all pairs, plus the int* view kthparent expects. */
        int (*pairs)[2] = malloc((size_t)(n > 0 ? n : 1) * sizeof *pairs);
        int** edges = malloc((size_t)(n > 0 ? n : 1) * sizeof(int*));
        if (!pairs || !edges) {
            return EXIT_FAILURE;
        }
        for (int j = 0; j < n; j++) {
            edges[j] = pairs[j];
            if (scanf("%d %d\n", edges[j], edges[j] + 1) != 2) {
                return EXIT_FAILURE;
            }
        }

        int q;
        if (scanf("%d\n", &q) != 1 || q < 0) {
            return EXIT_FAILURE;
        }
        char** queries = malloc((size_t)(q > 0 ? q : 1) * sizeof(char*));
        if (!queries) {
            return EXIT_FAILURE;
        }
        for (int j = 0; j < q; j++) {
            queries[j] = readline();
            if (!queries[j]) {
                return EXIT_FAILURE;
            }
        }

        int result_count = 0;
        int* result = kthparent(n, edges, q, queries, &result_count);
        for (int j = 0; j < result_count; j++) {
            printf("%d\n", result[j]);
        }

        /* Previously nothing was freed, so memory grew with every test case. */
        free(result);
        for (int j = 0; j < q; j++) {
            free(queries[j]);
        }
        free(queries);
        free(edges);
        free(pairs);
    }
    return 0;
}


/*
 * Reads one line from stdin into a heap buffer that grows as needed.
 * The trailing '\n' is removed and the result is always NUL-terminated.
 * At EOF it returns an empty string; it returns NULL only on allocation
 * failure. The caller frees the result.
 */
char* readline() {
    size_t alloc_length = 1024;
    size_t data_length = 0;
    char* data = malloc(alloc_length);
    if (!data) {
        return NULL;
    }
    data[0] = '\0';  /* defined contents even if the first fgets hits EOF */

    while (true) {
        char* cursor = data + data_length;
        char* line = fgets(cursor, (int)(alloc_length - data_length), stdin);

        if (!line) { break; }

        data_length += strlen(cursor);

        if (data_length < alloc_length - 1 || data[data_length - 1] == '\n') { break; }

        size_t new_length = alloc_length << 1;
        char* grown = realloc(data, new_length);
        if (!grown) {
            /* The old block is still valid; free it rather than leak or deref NULL. */
            free(data);
            return NULL;
        }
        data = grown;
        alloc_length = new_length;
    }

    /* Guard data_length: at EOF it is 0 and data[-1] used to be read. */
    if (data_length > 0 && data[data_length - 1] == '\n') {
        data[--data_length] = '\0';
    }

    /* Keep room for the terminator. Shrinking to data_length dropped it when
       the last line had no trailing newline. */
    char* shrunk = realloc(data, data_length + 1);
    return shrunk ? shrunk : data;
}

/*
 * Splits str in place on spaces (via strtok, so str is modified).
 * Returns a heap array of pointers into str, terminated by a NULL entry so
 * callers can find the token count, or NULL on allocation failure.
 * The caller frees the array but not the tokens.
 */
char** split_string(char* str) {
    char** splits = malloc(sizeof(char*));
    if (!splits) {
        return NULL;
    }

    int count = 0;
    for (char* token = strtok(str, " "); token; token = strtok(NULL, " ")) {
        /* +2: room for this token and the NULL terminator. */
        char** grown = realloc(splits, sizeof(char*) * (count + 2));
        if (!grown) {
            free(splits);  /* the old array used to leak here */
            return NULL;
        }
        splits = grown;
        splits[count++] = token;
    }

    splits[count] = NULL;
    return splits;
}

{"mode":"full","isActive":false}


Post a Comment

0 Comments