In this HackerRank Coprime Paths problem solution, you are given an undirected, connected graph, G, with n nodes and m edges, where m = n-1. Each node i is initially assigned a value, node_i, that has at most 3 prime divisors.

You must answer q queries in the form u v. For each query, find and print the number of (x, y) pairs of nodes on the path between u and v such that gcd(node_x, node_y) = 1 and the length of the path between u and v is minimal among all paths from u to v.

HackerRank Coprime Paths problem solution


Problem solution in Java.

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;

/**
 * Solution to the HackerRank "Coprime Paths" problem.
 *
 * <p>Reads a tree (n nodes, n-1 edges) with one integer value per node and, for every
 * query {@code (u, v)}, prints the number of pairs of nodes on the path between
 * {@code u} and {@code v} whose values are coprime.
 *
 * <p>Queries are answered offline: each path is mapped to a window of a doubled Euler
 * tour (every node appears once on entry and once on exit), the windows are ordered in
 * sqrt-sized blocks, and a sliding window is moved from query to query. Coprime pairs are
 * maintained with a Mobius-weighted counter over the squarefree divisors of the values
 * currently inside the window.
 */
public class F {
    /** Size bound shared by the sieve tables and the divisor counter. */
    private static final int SIEVE_LIMIT = 10000005;

    InputStream is;
    PrintWriter out;
    /** When non-empty, used as the input instead of {@code System.in}. */
    String INPUT = "";

    /** Number of coprime pairs among the nodes currently active in the window. */
    long ret;
    /** freq[v]: how many tour positions of node v lie in the current window. */
    int[] freq;
    /** pfreq[d]: number of active nodes whose (squarefree) value is divisible by d. */
    int[] pfreq;
    EulerTour et;
    int[] lpf = enumLowestPrimeFactors(SIEVE_LIMIT);
    int[] mob = enumMobiusByLPF(SIEVE_LIMIT, lpf);
    /** Node values, replaced in {@link #solve()} by the product of their distinct primes. */
    int[] a;

    /**
     * Reads the whole input, answers all queries and writes one answer per line to
     * {@link #out}.
     */
    void solve()
    {
        int n = ni(), Q = ni();
        a = na(n);
        // Only the set of prime factors matters for coprimality, so keep the squarefree part.
        for (int i = 0; i < n; i++) {
            a[i] = radical(a[i]);
        }

        int[] from = new int[n - 1];
        int[] to = new int[n - 1];
        for (int i = 0; i < n - 1; i++) {
            from[i] = ni() - 1;
            to[i] = ni() - 1;
        }
        int[][] g = packU(n, from, to);
        int[][] pars = parents3(g, 0);
        int[] par = pars[0], dep = pars[2];

        et = nodalEulerTour(g, 0);
        int[][] spar = logstepParents(par);

        // qs[i] = {left, right} window on the Euler tour; special[i] = LCA that has to be
        // added separately because it lies outside the window (-1 if not needed).
        int[][] qs = new int[Q][];
        int[] special = new int[Q];
        Arrays.fill(special, -1);
        for (int i = 0; i < Q; i++) {
            int x = ni() - 1, y = ni() - 1;
            int lca = lca2(x, y, spar, dep);
            if (lca == x) {
                qs[i] = new int[]{et.first[x], et.first[y]};
            } else if (lca == y) {
                qs[i] = new int[]{et.first[y], et.first[x]};
            } else if (et.first[x] < et.first[y]) {
                qs[i] = new int[]{et.last[x], et.first[y]};
                special[i] = lca;
            } else {
                qs[i] = new int[]{et.last[y], et.first[x]};
                special[i] = lca;
            }
        }

        long[] pqs = sqrtSort(qs, 2 * n - 1);

        int L = 0, R = -1;
        freq = new int[n];

        long[] ans = new long[Q];
        pfreq = new int[SIEVE_LIMIT];
        for (long pa : pqs) {
            // The low 25 bits of a packed key hold the original query index.
            int ind = (int) (pa & (1 << 25) - 1);
            int ql = qs[ind][0], qr = qs[ind][1];
            // Grow the window first, then shrink it, so it never becomes inverted.
            while (R < qr) {
                change(++R, 1);
            }
            while (L > ql) {
                change(--L, 1);
            }
            while (R > qr) {
                change(R--, -1);
            }
            while (L < ql) {
                change(L++, -1);
            }
            if (special[ind] != -1) {
                change(et.first[special[ind]], 1);
            }

            ans[ind] = ret;
            // Undo the temporary insertion of the LCA before moving to the next query.
            if (special[ind] != -1) {
                change(et.first[special[ind]], -1);
            }
        }
        for (long v : ans) {
            out.println(v);
        }
    }

    /**
     * Returns the product of the distinct prime factors of {@code value}, using the
     * {@link #lpf} table. Values not greater than 1 yield 1.
     */
    private int radical(int value)
    {
        int previousPrime = -1;
        int product = 1;
        for (int rest = value; rest > 1; rest /= lpf[rest]) {
            if (previousPrime != lpf[rest]) {
                product *= lpf[rest];
                previousPrime = lpf[rest];
            }
        }
        return product;
    }

    /** Debug helper: prints every non-zero entry of {@code o} as {@code index:value}. */
    public static void trnz(int... o)
    {
        for (int i = 0; i < o.length; i++) {
            if (o[i] != 0) {
                System.out.print(i + ":" + o[i] + " ");
            }
        }
        System.out.println();
    }

    /**
     * Computes the Mobius function for {@code 1..n} from a lowest-prime-factor table.
     *
     * @param n   largest argument to compute
     * @param lpf table with {@code lpf[i]} = lowest prime factor of {@code i} for
     *            {@code 2 <= i <= n} and {@code lpf[1] == 0}
     * @return array of length {@code n + 1}; entry 0 is left at 0
     */
    public static int[] enumMobiusByLPF(int n, int[] lpf)
    {
        int[] mob = new int[n + 1];
        mob[1] = 1;
        for (int i = 2; i <= n; i++) {
            int j = i / lpf[i];
            // If the lowest prime divides i twice, mob[i] stays 0.
            if (lpf[j] != lpf[i]) {
                mob[i] = -mob[j];
            }
        }
        return mob;
    }

    /**
     * Visits every divisor of {@code cur} obtained by dropping a subset of the primes of
     * {@code n} and applies the update {@code d} (+1 or -1) to {@link #pfreq}, adjusting
     * {@link #ret} by the Mobius-weighted counts.
     *
     * @param cur divisor built so far (starts as the node value)
     * @param n   remaining part whose primes are still undecided (starts as the node value)
     * @param d   +1 to insert the value, -1 to remove it
     */
    void dfs(int cur, int n, int d)
    {
        if (n == 1) {
            // Insertion counts against the state before the increment,
            // removal against the state after the decrement.
            if (d > 0) {
                ret += mob[cur] * pfreq[cur];
            }
            pfreq[cur] += d;
            if (d < 0) {
                ret -= mob[cur] * pfreq[cur];
            }
            return;
        }

        dfs(cur, n / lpf[n], d);
        dfs(cur / lpf[n], n / lpf[n], d);
    }

    /**
     * Adds ({@code d = 1}) or removes ({@code d = -1}) tour position {@code x} from the
     * window. A node contributes to {@link #ret} only while exactly one of its tour
     * positions is inside the window.
     */
    void change(int x, int d)
    {
        int ind = et.vs[x];
        if (freq[ind] == 1) {
            dfs(a[ind], a[ind], -1);
        }
        freq[ind] += d;
        if (freq[ind] == 1) {
            dfs(a[ind], a[ind], 1);
        }
    }

    /**
     * Orders queries for the sliding-window sweep.
     *
     * <p>Each query is packed into a {@code long}: block index of the left end (bits 50
     * and up), the right end (bits 25..49, mirrored in odd blocks so the right pointer
     * alternates direction), and the query index (low 25 bits). The packed keys are sorted.
     *
     * @param qs queries as {@code {left, right}} pairs
     * @param n  length used to derive the block size {@code (int) sqrt(n)}; must be at
     *           least 1
     * @return sorted packed keys; the query index is {@code key & ((1 << 25) - 1)}
     */
    public static long[] sqrtSort(int[][] qs, int n)
    {
        int m = qs.length;
        long[] pack = new long[m];
        int S = (int) Math.sqrt(n);
        for (int i = 0; i < m; i++) {
            long block = (long) qs[i][0] / S;
            long rightKey = (qs[i][0] / S & 1) == 0 ? qs[i][1] : (1 << 25) - 1 - qs[i][1];
            pack[i] = block << 50 | rightKey << 25 | i;
        }
        Arrays.sort(pack);
        return pack;
    }

    /**
     * Lowest common ancestor of {@code a} and {@code b}.
     *
     * @param spar  ancestor table from {@link #logstepParents(int[])}
     * @param depth depth of every node
     */
    public static int lca2(int a, int b, int[][] spar, int[] depth) {
        // Bring both nodes to the same depth.
        if (depth[a] < depth[b]) {
            b = ancestor(b, depth[b] - depth[a], spar);
        } else if (depth[a] > depth[b]) {
            a = ancestor(a, depth[a] - depth[b], spar);
        }

        if (a == b) {
            return a;
        }
        int sa = a, sb = b;
        // Lift both nodes in lockstep for as long as their ancestors differ;
        // low/high bracket the number of steps taken.
        for (int low = 0, high = depth[a], t = Integer.highestOneBit(high), k = Integer
                .numberOfTrailingZeros(t); t > 0; t >>>= 1, k--) {
            if ((low ^ high) >= t) {
                if (spar[k][sa] != spar[k][sb]) {
                    low |= t;
                    sa = spar[k][sa];
                    sb = spar[k][sb];
                } else {
                    high = low | t - 1;
                }
            }
        }
        return spar[0][sa];
    }

    /** Returns the {@code m}-th ancestor of {@code a}, or -1 once the root is passed. */
    protected static int ancestor(int a, int m, int[][] spar) {
        for (int i = 0; m > 0 && a != -1; m >>>= 1, i++) {
            if ((m & 1) == 1) {
                a = spar[i][a];
            }
        }
        return a;
    }

    /**
     * Builds the ancestor table: row {@code j} holds the {@code 2^j}-th ancestor of every
     * node, or -1 if it does not exist. Row 0 is {@code par} itself (not a copy).
     *
     * @param par parent of every node, -1 for the root
     */
    public static int[][] logstepParents(int[] par) {
        int n = par.length;
        // For n <= 1 the bit trick below would request 33 (n == 1) or 32 (n == 0) rows,
        // because highestOneBit(0) == 0 and numberOfTrailingZeros(0) == 32.
        int m = n <= 1 ? 1 : Integer.numberOfTrailingZeros(Integer.highestOneBit(n - 1)) + 1;
        int[][] pars = new int[m][n];
        pars[0] = par;
        for (int j = 1; j < m; j++) {
            for (int i = 0; i < n; i++) {
                pars[j][i] = pars[j - 1][i] == -1 ? -1 : pars[j - 1][pars[j - 1][i]];
            }
        }
        return pars;
    }

    /** Result of {@link #nodalEulerTour(int[][], int)}. */
    public static class EulerTour
    {
        /** Node at each tour position (length 2n). */
        public int[] vs;
        /** Position at which each node is entered. */
        public int[] first;
        /** Position at which each node is left. */
        public int[] last;

        public EulerTour(int[] vs, int[] f, int[] l) {
            this.vs = vs;
            this.first = f;
            this.last = l;
        }
    }

    /**
     * Iterative depth-first traversal that records every node twice: once when it is
     * entered and once when all of its neighbours have been processed.
     *
     * @param g    adjacency lists
     * @param root start node
     */
    public static EulerTour nodalEulerTour(int[][] g, int root)
    {
        int n = g.length;
        int[] vs = new int[2 * n];
        int[] f = new int[n];
        int[] l = new int[n];
        int p = 0;
        Arrays.fill(f, -1);

        // stack[k] is the node at recursion depth k, inds[k] its next neighbour index.
        int[] stack = new int[n];
        int[] inds = new int[n];
        int sp = 0;
        stack[sp++] = root;
        outer:
        while (sp > 0) {
            int cur = stack[sp - 1], ind = inds[sp - 1];
            if (ind == 0) {
                vs[p] = cur;
                f[cur] = p;
                p++;
            }
            while (ind < g[cur].length) {
                int nex = g[cur][ind++];
                if (f[nex] == -1) {
                    // Unvisited neighbour: descend, remembering where to resume.
                    inds[sp - 1] = ind;
                    stack[sp] = nex;
                    inds[sp] = 0;
                    sp++;
                    continue outer;
                }
            }
            inds[sp - 1] = ind;
            if (ind == g[cur].length) {
                vs[p] = cur;
                l[cur] = p;
                p++;
                sp--;
            }
        }

        return new EulerTour(vs, f, l);
    }

    /**
     * Breadth-first traversal of a tree.
     *
     * @param g    adjacency lists of a tree
     * @param root start node
     * @return {@code {parent, visitOrder, depth}}; the parent of the root is -1
     */
    public static int[][] parents3(int[][] g, int root) {
        int n = g.length;
        int[] par = new int[n];
        Arrays.fill(par, -1);

        int[] depth = new int[n];
        depth[root] = 0;

        int[] q = new int[n];
        q[0] = root;
        for (int p = 0, r = 1; p < r; p++) {
            int cur = q[p];
            for (int nex : g[cur]) {
                // In a tree the only already-visited neighbour is the parent.
                if (par[cur] != nex) {
                    q[r++] = nex;
                    par[nex] = cur;
                    depth[nex] = depth[cur] + 1;
                }
            }
        }
        return new int[][] { par, q, depth };
    }

    /**
     * Builds adjacency lists of an undirected graph from parallel edge arrays.
     * Each list is filled from its end, so later edges appear first.
     */
    static int[][] packU(int n, int[] from, int[] to) {
        int[][] g = new int[n][];
        int[] p = new int[n];
        for (int f : from) {
            p[f]++;
        }
        for (int t : to) {
            p[t]++;
        }
        for (int i = 0; i < n; i++) {
            g[i] = new int[p[i]];
        }
        for (int i = 0; i < from.length; i++) {
            g[from[i]][--p[from[i]]] = to[i];
            g[to[i]][--p[to[i]]] = from[i];
        }
        return g;
    }

    /**
     * Linear sieve producing the lowest prime factor of every number up to {@code n}.
     *
     * @return array of length {@code n + 1}; entries 0 and 1 are 0
     */
    public static int[] enumLowestPrimeFactors(int n) {
        int tot = 0;
        int[] lpf = new int[n + 1];
        int u = n + 32;
        double lu = Math.log(u);
        // Capacity estimate for the number of primes up to u.
        int[] primes = new int[(int) (u / lu + u / lu / lu * 1.5)];
        for (int i = 2; i <= n; i++) {
            lpf[i] = i;
        }
        for (int p = 2; p <= n; p++) {
            if (lpf[p] == p) {
                primes[tot++] = p;
            }
            for (int i = 0; i < tot && primes[i] <= lpf[p]; i++) {
                // Multiply in long so the bound check cannot be fooled by int overflow.
                long multiple = (long) primes[i] * p;
                if (multiple > n) {
                    break;
                }
                lpf[(int) multiple] = primes[i];
            }
        }
        return lpf;
    }

    /** Wires up input/output, runs {@link #solve()} and flushes the output. */
    void run() throws Exception
    {
        is = INPUT.isEmpty() ? System.in : new ByteArrayInputStream(INPUT.getBytes());
        out = new PrintWriter(System.out);

        long s = System.currentTimeMillis();
        solve();
        out.flush();
        // Timing is only printed when running on the embedded INPUT string.
        if (!INPUT.isEmpty()) {
            tr(System.currentTimeMillis() - s + "ms");
        }
    }

    public static void main(String[] args) throws Exception { new F().run(); }

    private byte[] inbuf = new byte[1024];
    public int lenbuf = 0, ptrbuf = 0;

    /**
     * Returns the next input byte, or -1 at end of input.
     *
     * @throws InputMismatchException if called again after end of input was reported, or
     *                                if the underlying stream fails (the I/O error is kept
     *                                as the cause)
     */
    private int readByte()
    {
        if (lenbuf == -1) {
            throw new InputMismatchException();
        }
        if (ptrbuf >= lenbuf) {
            ptrbuf = 0;
            try {
                lenbuf = is.read(inbuf);
            } catch (IOException e) {
                InputMismatchException wrapped = new InputMismatchException(e.getMessage());
                wrapped.initCause(e);
                throw wrapped;
            }
            if (lenbuf <= 0) {
                return -1;
            }
        }
        return inbuf[ptrbuf++];
    }

    /** Everything outside the printable ASCII range 33..126 counts as a separator. */
    private boolean isSpaceChar(int c) { return !(c >= 33 && c <= 126); }

    /** Skips separators and returns the first other byte, or -1 at end of input. */
    private int skip()
    {
        int b;
        while ((b = readByte()) != -1 && isSpaceChar(b)) {
            // keep skipping
        }
        return b;
    }

    private double nd() { return Double.parseDouble(ns()); }
    private char nc() { return (char) skip(); }

    /** Reads the next separator-delimited token. */
    private String ns()
    {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        while (!(isSpaceChar(b))) {
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    /** Reads the next token, truncated to at most {@code n} characters. */
    private char[] ns(int n)
    {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while (p < n && !(isSpaceChar(b))) {
            buf[p++] = (char) b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    /** Reads {@code n} tokens of at most {@code m} characters each. */
    private char[][] nm(int n, int m)
    {
        char[][] map = new char[n][];
        for (int i = 0; i < n; i++) {
            map[i] = ns(m);
        }
        return map;
    }

    /** Reads {@code n} integers. */
    private int[] na(int n)
    {
        int[] a = new int[n];
        for (int i = 0; i < n; i++) {
            a[i] = ni();
        }
        return a;
    }

    /** Reads the next (optionally negative) decimal integer, skipping anything before it. */
    private int ni()
    {
        int num = 0, b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-')) {
            // skip to the first digit or minus sign
        }
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    /** Same as {@link #ni()} but accumulates into a {@code long}. */
    private long nl()
    {
        long num = 0;
        int b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-')) {
            // skip to the first digit or minus sign
        }
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private static void tr(Object... o) { System.out.println(Arrays.deepToString(o)); }
}

{"mode":"full","isActive":false}


Problem solution in C++.

#include <bits/stdc++.h>

using namespace std;

// N: capacity of the per-node arrays; LN: number of levels in the ancestor table parent[][].
const int N = 25123, LN = 15;
// Largest value covered by the factor table pr[] that main() builds.
const int MaxVal = 10000000;

using ll = long long;

// Adjacency lists of the tree (0-based nodes), filled in main() from the edge list.
vector<int> adj[N];

// depth[u]: depth assigned by dfs(). parent[k][u]: ancestor table; level 0 is set by dfs(),
// higher levels are filled in main(); -1 marks "no such ancestor".
int depth[N], parent[LN][N];
// ST[u] / EN[u]: timestamps taken when dfs() enters / leaves u; cur_time is the running clock.
int ST[N], EN[N], cur_time = 0;
// vec[t]: node that received timestamp t (every node appears at ST[u] and at EN[u]).
int vec[2 * N];

// Recursive depth-first walk from u. Stamps u with an entry time (ST) and an exit time (EN),
// records the node for both timestamps in vec, and stores u's depth and direct parent.
// d is the depth of u, prev the node we arrived from (-1 for the start node).
void dfs(int u = 0, int d = 0, int prev = -1) {
    const int enter = cur_time++;
    ST[u] = enter;
    vec[enter] = u;

    parent[0][u] = prev;
    depth[u] = d;

    for (size_t k = 0; k < adj[u].size(); ++k) {
        const int child = adj[u][k];
        // Skip the edge that leads back to where we came from.
        if (child != prev) {
            dfs(child, d + 1, u);
        }
    }

    const int leave = cur_time++;
    EN[u] = leave;
    vec[leave] = u;
}

// Lowest common ancestor of u and v, using the depth[] array and the ancestor table parent[][].
int lca(int u, int v) {
    // Make u the deeper of the two nodes.
    if (depth[v] > depth[u]) swap(u, v);

    // Lift u by the depth difference, one set bit at a time.
    const int gap = depth[u] - depth[v];
    for (int bit = 0; bit < LN; ++bit) {
        if (gap & (1 << bit)) {
            u = parent[bit][u];
        }
    }

    if (u == v) return u;

    // Lift both nodes from the largest jump down, as long as their ancestors still differ.
    int level = LN;
    while (level-- > 0) {
        const int up_u = parent[level][u];
        const int up_v = parent[level][v];
        if (up_u == up_v) continue;
        u = up_u;
        v = up_v;
    }

    return parent[0][u];
}

// primes[i]: distinct prime factors of node i's value, built in main().
vector<int> primes[N];
// upd[i]: (table, id) pairs naming the vp[table][id] counters that node i touches.
vector<pair<int, int>> upd[N];

// pr[x]: smallest prime factor recorded by the sieve in main() for composite x, 0 otherwise.
// S: block size used to order the queries.
int pr[MaxVal + 1], S;

// ans[i]: answer of the i-th query, in input order.
ll ans[N];
// used[u]: whether node u is currently part of the active set.
bool used[N];

// vp[t][id]: number of active nodes containing the prime (t = 0), prime pair (t = 1)
// or prime triple (t = 2) with that id.
int vp[3][4 * N];

// Reads the tree and the queries, answers every query offline with a sliding window over the
// entry/exit timestamps produced by dfs(), and prints the answers in input order.
int main() {
    // -1 marks "no ancestor" in every level of the table.
    memset(parent, -1, sizeof parent);

    // Sieve: for every composite j store its smallest prime factor; primes keep pr == 0.
    for (int i = 2; i <= MaxVal; i++) {
        if (!pr[i]) {
            for (int j = i + i; j <= MaxVal; j += i) {
                if (!pr[j]) pr[j] = i;
            }
        }
    }

    int n, q;
    scanf("%d %d", &n, &q);

    // Block size for ordering the queries (the timestamp sequence has 2 * n entries).
    S = sqrt(2 * n);

    // printf("S = %d\n", S);

    // Compress single primes, prime pairs and prime triples to dense ids (index into vp[0..2]).
    map<int, int> p1;
    map<tuple<int, int>, int> p2;
    map<tuple<int, int, int>, int> p3;

    for (int i = 0; i < n; i++) {
        int id, x;
        scanf("%d", &x);

        auto& v = primes[i];

        // Peel off smallest prime factors until x is prime (pr[x] == 0) or 1.
        while (pr[x]) {
            v.push_back(pr[x]);
            x /= pr[x];
        }

        if (x > 1) {
            v.push_back(x);
        }

        // Factors come out in non-decreasing order, so unique() leaves the distinct primes.
        assert(is_sorted(begin(v), end(v)));
        v.resize(unique(begin(v), end(v)) - begin(v));

        for (int k = 0; k < v.size(); k++) {
            // Look up or assign the id of the single prime v[k].
            id = p1.size();
            if (p1.count(v[k])) id = p1[v[k]];
            else p1[v[k]] = id;

            upd[i].emplace_back(0, id);

            for (int j = k + 1; j < v.size(); j++) {
                // Look up or assign the id of the pair (v[k], v[j]).
                auto tmp = make_tuple(v[k], v[j]);
                id = p2.size();
                if (p2.count(tmp)) id = p2[tmp];
                else p2[tmp] = id;

                upd[i].emplace_back(1, id);
            }
        }

        // Only the case of exactly three distinct primes registers a triple.
        if (v.size() == 3) {
            auto tmp = make_tuple(v[0], v[1], v[2]);
            id = p3.size();
            if (p3.count(tmp)) id = p3[tmp];
            else p3[tmp] = id;

            upd[i].emplace_back(2, id);
        }
    }

    // Read the n - 1 edges (1-based in the input, stored 0-based).
    for (int i = 1; i < n; i++) {
        int u, v;
        scanf("%d %d", &u, &v);
        u--, v--;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }

    // Fills depth, parent[0], ST, EN and vec starting from node 0.
    dfs();

    // Level i of the ancestor table is two jumps of level i - 1.
    for (int i = 1; i < LN; i++) {
        for (int j = 0; j < n; j++) {
            if (parent[i - 1][j] != -1) {
                parent[i][j] = parent[i - 1][parent[i - 1][j]];
            }
        }
    }

    // A query is (left, right, original index, extra node); extra node is -1 or the LCA.
    using qt = tuple<int, int, int, int>;
    vector<qt> queries;
    for (int i = 0; i < q; i++) {
        int u, v;
        scanf("%d %d", &u, &v);
        u--, v--;

        // Make u the endpoint that dfs() entered first.
        if (ST[u] > ST[v]) swap(u, v);

        int p = lca(u, v);
        // u is an ancestor of v: the window runs from u's entry to v's entry.
        if (p == u) queries.emplace_back(ST[u], ST[v], i, -1);
        else {
            // Otherwise the window runs from u's exit to v's entry; the LCA p lies outside
            // that window and is accounted for separately when the query is answered.
            queries.emplace_back(EN[u], ST[v], i, p);
        }
    }

    // Order by block of the left end, then by decreasing right end.
    sort(begin(queries), end(queries), [](const qt& a, const qt& b) -> bool {
        int l1, r1, i1, p1;
        int l2, r2, i2, p2;

        tie(l1, r1, i1, p1) = a;
        tie(l2, r2, i2, p2) = b;

        if (l1 / S != l2 / S) return l1 / S < l2 / S;
        return r1 > r2;
    });

    // active: number of nodes in the active set; cur: coprime pairs among them.
    int active = 0;
    ll cur = 0;

    // Adds node u to the active set. The number of active nodes coprime to u is obtained by
    // inclusion-exclusion: all, minus shared single primes, plus shared pairs, minus triples.
    auto insert = [&](int u) {
        int tmp = active;

        for (auto& p : upd[u]) {
            // Table 1 (pairs) is added; tables 0 and 2 (singles, triples) are subtracted.
            if (p.first & 1) tmp += vp[p.first][p.second];
            else tmp -= vp[p.first][p.second];
        }

        cur += tmp;

        used[u] = true;
        active++;

        for (auto& p : upd[u]) {
            vp[p.first][p.second]++;
        }        
    };

    // Removes node u from the active set: take u out of the counters first, then subtract
    // the number of remaining active nodes that are coprime to u.
    auto remove = [&](int u) {
        used[u] = false;
        active--;

        for (auto& p : upd[u]) {
            vp[p.first][p.second]--;
        }

        int tmp = active;

        for (auto& p : upd[u]) {
            if (p.first & 1) tmp += vp[p.first][p.second];
            else tmp -= vp[p.first][p.second];
        }

        cur -= tmp;        
    };

    // Current window [L, R] over the timestamp sequence; starts empty.
    int L = 0, R = -1;
    for (auto& t : queries) {
        int l, r, i, p;
        tie(l, r, i, p) = t;

        // printf("%d %d %d %d\n", l, r, i, p);

        // Every timestamp crossed toggles its node: a node seen twice in the window
        // (entry and exit) ends up outside the active set.
        while (R < r) {
            R++;
            if (used[vec[R]]) remove(vec[R]);
            else insert(vec[R]);
        }

        while (R > r) {
            if (used[vec[R]]) remove(vec[R]);
            else insert(vec[R]);
            R--;
        }

        while (L < l) {
            if (used[vec[L]]) remove(vec[L]);
            else insert(vec[L]);
            L++;
        }

        while (L > l) {
            L--;
            if (used[vec[L]]) remove(vec[L]);
            else insert(vec[L]);
        }

        ans[i] = cur;

        // printf("cur: %d\n", cur);

        // Add the pairs formed by the extra node p with the active set, without
        // actually inserting p.
        if (p != -1) {
            int tmp = active;

            for (auto& k : upd[p]) {
                if (k.first & 1) tmp += vp[k.first][k.second];
                else tmp -= vp[k.first][k.second];
            }

            ans[i] += tmp;
        }
    }

    for (int i = 0; i < q; i++) {
        printf("%lld\n", ans[i]);
    }

    return 0;
}

{"mode":"full","isActive":false}