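In this HackerRank Super Maximum Cost Queries problem solution, we are given a tree in which every edge has an integer weight. The cost of a path from some node X to some other node Y is the maximum edge weight on that path, and for each query [L, R] we need to count how many different paths have a cost in the inclusive range [L, R].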

HackerRank Super Maximum Cost Queries problem solution


Problem solution in Python programming.

#!/bin/python3

import os
from operator import itemgetter
from itertools import groupby
from bisect import bisect


class Node:
    """One element of the disjoint-set forest used by ``Find``/``Union``.

    Attributes (set in ``__init__``):
        id: the caller-supplied identifier of the element.
        parent: parent pointer; a set's root points to itself.
        rank: rank consulted by ``Union`` to pick which root survives.
        size: number of elements in the set (maintained on roots).
        paths: accumulated path count for the set (maintained on roots).
    """

    def __init__(self, id):
        self.id, self.parent = id, self
        self.rank = self.paths = 0
        self.size = 1

    def __repr__(self):
        return 'Node({!r})'.format(self.id)


def Find(x):
    """Return the root (representative) of the set containing ``x``.

    Every node visited on the way up is re-pointed directly at the root
    (full path compression), so later lookups are shorter.
    """
    # First pass: walk up to the root.
    root = x
    while root.parent is not root:
        root = root.parent
    # Second pass: re-point each node on the path straight at the root.
    node = x
    while node is not root:
        node.parent, node = root, node.parent
    return root


def Union(x, y):
    """Merge the sets containing ``x`` and ``y``.

    Returns:
        The number of node pairs newly connected by the merge, i.e. the
        product of the two set sizes, or 0 when ``x`` and ``y`` already
        share a set. The surviving root's ``paths`` grows by the absorbed
        root's ``paths`` plus that product.
    """
    big, small = Find(x), Find(y)
    if big is small:
        return 0

    # Union by rank: the lower-rank root is hung under the higher-rank one.
    if big.rank < small.rank:
        big, small = small, big

    created = big.size * small.size
    big.paths += small.paths + created

    small.parent = big
    big.size += small.size
    if big.rank == small.rank:
        big.rank += 1

    return created

# Complete the solve function below.
def solve(tree, queries):
    """Count, for each query (L, R), the paths whose cost lies in [L, R].

    The cost of a path is the largest edge weight on it. Edges are merged in
    increasing weight order with ``Union``. Joining components of sizes a and
    b creates a*b new paths whose maximum edge is the current weight. A
    cumulative count per distinct weight then answers every query with two
    binary searches.

    Args:
        tree: sequence of [u, v, w] edges (node ids u, v; integer weight w).
            It is not modified.
        queries: sequence of (L, R) pairs.

    Returns:
        A list holding, for each query, the number of paths with cost in
        [L, R].
    """
    edges = sorted(tree, key=itemgetter(2))

    # weights[i] is the i-th distinct weight and path_count[i] the number of
    # paths whose cost is <= weights[i]. Index 0 is a (weight 0, 0 paths)
    # sentinel.
    weights, path_count = [0], [0]
    nodes = {}

    def node(node_id):
        # Create Node objects lazily. dict.setdefault would construct a
        # throwaway Node on every lookup.
        found = nodes.get(node_id)
        if found is None:
            found = nodes[node_id] = Node(node_id)
        return found

    for weight, group in groupby(edges, key=itemgetter(2)):
        total = path_count[-1]
        for path in group:
            total += Union(node(path[0]), node(path[1]))
        weights.append(weight)
        path_count.append(total)

    res = []
    for L, R in queries:
        # Li: last distinct weight <= L-1 (i.e. < L for integer weights).
        # Ri: last distinct weight <= R.
        # Clamp at 0 so bounds below every weight hit the sentinel instead of
        # wrapping around to path_count[-1].
        Li = max(bisect(weights, L - 1) - 1, 0)
        Ri = max(bisect(weights, R) - 1, 0)
        res.append(path_count[Ri] - path_count[Li])
    return res
    
    

if __name__ == '__main__':
    # Input: "n q", then n-1 edge lines "u v w", then q query lines "L R".
    n, q = map(int, input().split())
    tree = [list(map(int, input().rstrip().split())) for _ in range(n - 1)]
    queries = [list(map(int, input().rstrip().split())) for _ in range(q)]

    result = solve(tree, queries)

    # `with` closes the output file even if writing raises.
    with open(os.environ['OUTPUT_PATH'], 'w') as fptr:
        fptr.write('\n'.join(map(str, result)))
        fptr.write('\n')


Problem solution in Java programming.

import java.io.*;
import java.math.*;
import java.text.*;
import java.util.*;
import java.util.regex.*;

public class Solution {

    /**
     * Counts, for every query {@code [l, r]}, the paths of a weighted tree whose cost lies in
     * {@code [l, r]}. The cost of a path is the largest edge weight on it.
     *
     * <p>Edges are merged into a union-find in increasing weight order. When an edge of weight
     * {@code w} joins components of sizes {@code a} and {@code b}, exactly {@code a * b} new
     * paths appear, and {@code w} is their maximum edge. A prefix sum over those per-edge counts
     * answers each query with two binary searches over the sorted weights.
     *
     * @param tree edges as {@code {u, v, w}} rows with 1-based node ids; may be empty (a
     *     single-node tree)
     * @param queries {@code {l, r}} rows
     * @return the path count for each query, in query order
     */
    static long[] solve(final int[][] tree, final int[][] queries) {
        final int edgeCount = tree.length;
        final int n = edgeCount + 1;
        final Integer[] order = indexes(edgeCount);
        Arrays.sort(order, Comparator.comparingInt(idx -> tree[idx][2]));

        final UF uf = new UF(n);
        final long[] count = new long[edgeCount];   // paths whose maximum is sorted edge i
        final int[] weights = new int[edgeCount];   // weight of sorted edge i (non-decreasing)
        for (int i = 0; i < edgeCount; i++) {
            final int[] edge = tree[order[i]];
            final int u = edge[0] - 1;
            final int v = edge[1] - 1;
            // Sizes are read before the merge: every pair with one node on each side gets
            // this edge as its new maximum.
            count[i] = 1L * uf.size(u) * uf.size(v);
            weights[i] = edge[2];
            uf.connect(u, v);
        }

        final long[] pf = prefixSum(count);

        final long[] answers = new long[queries.length];
        for (int qIdx = 0; qIdx < queries.length; qIdx++) {
            final int lIdx = lbs(weights, queries[qIdx][0]); // first weight >= l
            final int rIdx = rbs(weights, queries[qIdx][1]); // last weight <= r
            // lIdx > rIdx means no weight lies in [l, r]. This also covers the edge-free
            // tree (pf is empty) and l > r.
            answers[qIdx] = lIdx > rIdx ? 0L : pf[rIdx] - (lIdx > 0 ? pf[lIdx - 1] : 0L);
        }
        return answers;
    }

    /** Returns the first index whose value is {@code >= value}, or {@code a.length} if none. */
    private static int lbs(final int[] a, final int value) {
        int lo = 0;
        int hi = a.length - 1;
        while (lo <= hi) {
            final int mid = lo + (hi - lo) / 2;
            if (a[mid] < value) {
                lo = mid + 1;
            } else {
                hi = mid - 1;
            }
        }
        return lo;
    }

    /** Returns the last index whose value is {@code <= value}, or {@code -1} if none. */
    private static int rbs(final int[] a, final int value) {
        int lo = 0;
        int hi = a.length - 1;
        while (lo <= hi) {
            final int mid = lo + (hi - lo) / 2;
            if (a[mid] <= value) {
                lo = mid + 1;
            } else {
                hi = mid - 1;
            }
        }
        return hi;
    }

    /** Returns the inclusive prefix sums of {@code a}; an empty input yields an empty array. */
    private static long[] prefixSum(final long[] a) {
        final long[] pf = new long[a.length];
        long running = 0L;
        for (int i = 0; i < a.length; i++) {
            running += a[i];
            pf[i] = running;
        }
        return pf;
    }

    /** Disjoint-set forest over {@code 0..n-1} with union by size and path halving. */
    private static final class UF {
        private final int[] p;
        private final int[] s;

        UF(final int n) {
            p = new int[n];
            for (int i = 0; i < n; i++) {
                p[i] = i;
            }
            s = new int[n];
            Arrays.fill(s, 1);
        }

        /** Returns the size of the component containing {@code i}. */
        int size(final int i) {
            return s[root(i)];
        }

        /** Merges the components of {@code a} and {@code b}; the smaller goes under the larger. */
        void connect(final int a, final int b) {
            final int aRoot = root(a);
            final int bRoot = root(b);
            if (aRoot == bRoot) {
                return;
            }
            final int minRoot = s[aRoot] <= s[bRoot] ? aRoot : bRoot;
            final int maxRoot = minRoot == aRoot ? bRoot : aRoot;
            p[minRoot] = maxRoot;
            s[maxRoot] += s[minRoot];
        }

        private int root(int i) {
            while (p[i] != i) {
                p[i] = p[p[i]]; // path halving
                i = p[i];
            }
            return i;
        }
    }

    /** Returns {@code [0, 1, ..., n - 1]} boxed, so it can be sorted with a comparator. */
    private static Integer[] indexes(final int n) {
        final Integer[] indexes = new Integer[n];
        for (int i = 0; i < n; i++) {
            indexes[i] = i;
        }
        return indexes;
    }

    /** Reads one line and parses its first {@code count} whitespace-separated integers. */
    private static int[] readInts(final Scanner scanner, final int count) {
        final String[] tokens = scanner.nextLine().trim().split("\\s+");
        final int[] values = new int[count];
        for (int i = 0; i < count; i++) {
            values[i] = Integer.parseInt(tokens[i]);
        }
        return values;
    }

    public static void main(String[] args) throws IOException {
        try (Scanner scanner = new Scanner(System.in);
                BufferedWriter bufferedWriter =
                        new BufferedWriter(new FileWriter(System.getenv("OUTPUT_PATH")))) {
            final int[] nq = readInts(scanner, 2);
            final int n = nq[0];
            final int q = nq[1];

            final int[][] tree = new int[n - 1][];
            for (int i = 0; i < n - 1; i++) {
                tree[i] = readInts(scanner, 3);
            }

            final int[][] queries = new int[q][];
            for (int i = 0; i < q; i++) {
                queries[i] = readInts(scanner, 2);
            }

            final long[] result = solve(tree, queries);

            for (int i = 0; i < result.length; i++) {
                bufferedWriter.write(String.valueOf(result[i]));
                if (i != result.length - 1) {
                    bufferedWriter.write("\n");
                }
            }
            bufferedWriter.newLine();
        }
    }
}


Problem solution in C++ programming.

#include <stdio.h>
#include <algorithm>
#include <assert.h>
#include <set>
#include <map>
#include <complex>
#include <iostream>
#include <time.h>
#include <stack>
#include <stdlib.h>
#include <memory.h>
#include <bitset>
#include <math.h>
#include <string>
#include <string.h>
#include <queue>
#include <vector>

using namespace std;

// Capacity of the node-indexed arrays (main uses node ids 1..n).
// NOTE(review): assumes n stays below about 1e5 + 10.
const int MaxN = 1e5 + 10;
// NOTE(review): not referenced anywhere in this program.
const int INF = 1e9;
// Used by main only as a sentinel weight appended to the answer table;
// presumably larger than any edge weight / query bound.
const int MOD = 1e9 + 7;

// n nodes, q queries; sz[v] is the component size and who[v] the DSU parent.
int n, q, sz[MaxN], who[MaxN];
// Adjacency lists of (neighbor, weight). NOTE(review): never read by the query logic.
vector < pair < int, int > > graph[MaxN];
// Running total of node pairs (paths) connected so far by unite().
long long cnt = 0;

// Returns the representative of v's set. Every node on the way up is
// re-pointed directly at it (full path compression).
int get(int v) {
  int root = v;
  while (who[root] != root) {
    root = who[root];
  }
  while (v != root) {
    int next = who[v];
    who[v] = root;
    v = next;
  }
  return root;
}

// Merges the components containing a and b. When they differ, it adds the
// number of newly connected node pairs (sz[a] * sz[b]) to cnt. The smaller
// component is attached under the larger one (union by size), which keeps
// the trees O(log n) deep and so bounds the work done by get(). Choosing the
// direction by root-id parity gave no such bound.
void unite(int a, int b) {
  a = get(a);
  b = get(b);
  if (a == b) {
    return;
  }
  if (sz[a] > sz[b]) {
    swap(a, b);  // ensure a is the smaller root
  }
  cnt += 1LL * sz[a] * sz[b];
  sz[b] += sz[a];
  who[a] = b;
}

int main() {
//  freopen("input.txt", "r", stdin);
//  ios::sync_with_stdio(false);
//  cin.tie(NULL);
  scanf("%d%d", &n, &q);
  vector < pair < int, pair < int, int > > > edges;
  for (int i = 0; i < n - 1; ++i) {
    int u, v, w;
    scanf("%d%d%d", &u, &v, &w);
    graph[u].push_back(make_pair(v, w));
    graph[v].push_back(make_pair(u, w));
    edges.push_back(make_pair(w, make_pair(u, v)));
  }
  vector < pair < int, long long > > ans;
  ans.push_back(make_pair(0, 0LL));
  for (int i = 1; i <= n; ++i) {
    who[i] = i;
    sz[i] = 1;
  }
  sort(edges.begin(), edges.end());
  for (int i = 0, j; i < (int)edges.size(); i = j) {
    for (j = i; j < (int)edges.size() && edges[j].first == edges[i].first; ++j);
    for (int k = i; k < j; ++k) {
      unite(edges[k].second.first, edges[k].second.second);
    }
    ans.push_back(make_pair(edges[i].first, cnt));
  }
  ans.push_back(make_pair(MOD, ans.back().second));
  for (int i = 0; i < q; ++i) {
    int l, r;
    scanf("%d%d", &l, &r);
    printf("%lld\n", (--upper_bound(ans.begin(), ans.end(), make_pair(r, 1LL << 60)))->second -
           (--lower_bound(ans.begin(), ans.end(), make_pair(l, -1LL << 60)))->second);
  }
  return 0;
}


Problem solution in C programming.

#include<stdio.h>
#include<stdlib.h>
/* U: unsigned long long, used for path counts. */
typedef long long unsigned U;
/* u: unsigned int, used for node ids, edge indices and weights. */
typedef unsigned u;
/*
 * Edge i joins X[i] and Y[i] with weight W[i]; N holds edge indices
 * (sorted by weight in main); l is the number of valid entries in B/V.
 * NOTE(review): fixed capacities assume n stays below about 111111.
 */
u X[111111],Y[111111],W[111111],N[111111],l;
/*
 * B[k]: k-th recorded distinct weight (B[0] stays 0 as a sentinel);
 * S: component sizes; D: DSU parent pointers;
 * Q/Qi: scratch stack of nodes to re-point at the root.
 */
u B[111111],S[111111],D[111111],Q[222222],Qi;
/* V[k]: number of paths whose cost is <= B[k] (V[0] stays 0). */
U V[111111];
int F(const void*x,const void*y)
{
    if(W[*(u*)x]>W[*(u*)y])return 1;
    if(W[*(u*)x]<W[*(u*)y])return-1;
    return 0;
}
/*
 * Binary search over B[0..l-1] (ascending). Returns the largest index whose
 * weight is <= v, falling back to index 0 (the weight-0 sentinel).
 */
u Z(u v)
{
    u lo = 0, hi = l;
    for (;;)
    {
        u mi = (lo + hi) >> 1;
        if (mi <= lo)          /* range [lo, hi) has narrowed to one slot */
            break;
        if (B[mi] > v)
            hi = mi;
        else
            lo = mi;
    }
    return lo;
}
/*
 * Reads n, q, the n-1 weighted edges and q queries (L, R). For each query it
 * prints the number of paths whose maximum edge weight lies in [L, R].
 */
int main()
{
    /* i starts at (u)-1 so that the first ++i yields 0; l=1 keeps B[0]/V[0] as the zero sentinel. */
    u n,q,i=-1,j,k,x,y;U r=0;l=1;
    /* Read edge i into X/Y/W, then set N[i]=i and make DSU entry i a singleton (D[i]=i, S[i]=1). */
    for(scanf("%u%u",&n,&q);++i<n-1;S[D[N[i]=i]=i]=1)scanf("%u%u%u",X+i,Y+i,W+i);
    /*
     * Initialise entries n-1 and n the same way and set W[n-1] = W[n] = (u)-1.
     * Since the sort below only touches N[0..n-2], N[n-1] = n-1 remains and
     * acts as an end-of-edges sentinel.
     * NOTE(review): assumes no real weight equals (u)-1.
     */
    for(;i<=n;W[i++]=-1)S[D[N[i]=i]=i]=1;
    /* Sort the edge indices N[0..n-2] by ascending weight. */
    qsort(N,n-1,sizeof(u),F);
    /* Process edges by increasing weight k until the sentinel weight is reached. */
    for(i=-1;(k=W[N[++i]])!=-1u;)
    {
        /* Walk to both roots, pushing visited non-root nodes onto Q for compression. */
        for(x=X[N[i]];D[x]!=x;x=D[x])Q[Qi++]=x;
        for(y=Y[N[i]];D[y]!=y;y=D[y])Q[Qi++]=y;
        /* Joining components of sizes S[x] and S[y] creates S[x]*S[y] new paths of cost k. */
        r+=S[x]*(U)S[y];
        /*
         * The larger-numbered root is attached under the smaller one; x ends
         * up as the surviving root.
         * NOTE(review): assumes the input is a tree. A repeated/cyclic edge
         * (x == y) would double S[x] and over-count r.
         */
        if(x>y){D[x]=y;S[y]+=S[x];x=y;}
        else{D[y]=x;S[x]+=S[y];}
        /* Point every visited node directly at the surviving root. */
        while(Qi)D[Q[--Qi]]=x;
        /* After the last edge of a weight group, record (weight, cumulative path count). */
        if(k!=W[N[i+1]]){B[l]=k;V[l++]=r;}
    }
    while(q--)
    {
        scanf("%u%u",&i,&j);
        /* x: last recorded weight <= L, stepped back to the last weight < L; y: last recorded weight <= R. */
        x=Z(i);y=Z(j);
        if(B[x]==i)--x;
        printf("%llu\n",V[y]-V[x]);
    }
    return 0;
}


Problem solution in JavaScript programming.

'use strict';

const fs = require('fs');

// Buffer all of stdin as UTF-8 text. When the stream ends, turn it into an
// array of trimmed lines (consumed through readLine) and run main().
process.stdin.resume();
process.stdin.setEncoding('utf-8');

let inputString = '';
let currentLine = 0;

const stdinChunks = [];

process.stdin.on('data', (chunk) => {
    stdinChunks.push(chunk);
});

process.stdin.on('end', () => {
    const text = inputString + stdinChunks.join('');
    inputString = text
        .trim()
        .split('\n')
        .map((line) => line.trim());

    main();
});

// Returns the next buffered input line and advances the cursor
// (undefined once all lines have been consumed).
function readLine() {
    const line = inputString[currentLine];
    currentLine += 1;
    return line;
}

// Disjoint-set (union-find) over ids 0..len with union by size.
// union(a, b, operation) calls `operation` with the number of node pairs a
// merge newly connects (product of the two component sizes). When a and b
// are already connected, it does nothing.
const UF = class {
    constructor(len) {
        this.parents = Array.from({ length: len + 1 }, (_, i) => i);
        this.sizes = new Array(len + 1).fill(1);
    }

    find(a) {
        let node = a;
        for (let up = this.parents[node]; up !== node; up = this.parents[node]) {
            node = up;
        }
        return node;
    }

    union(a, b, operation) {
        const ra = this.find(a);
        const rb = this.find(b);
        if (ra === rb) {
            return;
        }

        const sa = this.sizes[ra];
        const sb = this.sizes[rb];
        operation(sa * sb);

        // Smaller tree goes under the larger; on a tie, b's root goes under a's.
        const [child, parent] = sa < sb ? [ra, rb] : [rb, ra];
        this.parents[child] = parent;
        this.sizes[parent] += this.sizes[child];
    }
}


// Counts, for each query [L, R], the paths whose cost (maximum edge weight on
// the path) lies in [L, R]. Edges are merged in increasing weight order, and
// merging components of sizes a and b adds a*b paths of the current weight.
// `tree` rows are [u, v, w] with 1-based node ids. `tree` is not modified.
const solve = (tree, queries) => {
    const len = tree.length;
    const uf = new UF(len + 2);
    // Sort a copy: Array.prototype.sort works in place and would reorder the caller's array.
    const edges = [...tree].sort((a, b) => a[2] - b[2]);

    // costs[k] is the k-th distinct weight (costs[0] = 0 is a sentinel), and
    // cumulative[k] is the number of paths whose cost is <= costs[k].
    const costs = [0];
    const cumulative = [0];
    for (const [u, v, w] of edges) {
        if (w !== costs[costs.length - 1]) {
            costs.push(w);
            cumulative.push(cumulative[cumulative.length - 1]);
        }
        uf.union(u, v, (pathCount) => {
            cumulative[cumulative.length - 1] += pathCount;
        });
    }

    // Number of paths with cost <= x. Returns 0 when x is below every recorded
    // cost, instead of undefined/NaN.
    const pathsUpTo = (x) => {
        let lo = 0;
        let hi = costs.length; // invariant: the first index with costs[i] > x lies in [lo, hi]
        while (lo < hi) {
            const mid = (lo + hi) >>> 1;
            if (costs[mid] <= x) {
                lo = mid + 1;
            } else {
                hi = mid;
            }
        }
        return lo === 0 ? 0 : cumulative[lo - 1];
    };

    return queries.map(([l, r]) => pathsUpTo(r) - pathsUpTo(l - 1));
}

// Parses "n q", then n-1 edge rows and q query rows, from the buffered input.
// Writes one answer per line to OUTPUT_PATH.
function main() {
    const ws = fs.createWriteStream(process.env.OUTPUT_PATH);
    const parseRow = () => readLine().split(' ').map((token) => parseInt(token, 10));

    const [n, q] = parseRow();

    const tree = Array(n-1);
    for (let row = 0; row < n-1; row += 1) {
        tree[row] = parseRow();
    }

    const queries = Array(q);
    for (let row = 0; row < q; row += 1) {
        queries[row] = parseRow();
    }

    const result = solve(tree, queries);

    ws.write(`${result.join("\n")}\n`);
    ws.end();
}