Header Ad

HackerRank Dynamic Summation problem solution

In this HackerRank Dynamic Summation problem, you are given a tree of N nodes, each uniquely numbered in the range [1, N]. Each node also has a value that is initially 0. You need to perform the following two operations on the tree.

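  1. Update operation `U r t a b`: with the tree rooted at node r, add a^b + (a+1)^b + (b+1)^a to the value of every node in the subtree of node t.
  2. Report operation `R r t m`: with the tree rooted at node r, print the sum of the values of all nodes in the subtree of node t, modulo m.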

HackerRank Dynamic Summation problem solution


Problem solution in Java programming.

import java.io.*;
import java.util.*;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.BitSet;
import java.util.InputMismatchException;

public class DynamicSummationUWISolution {
    static InputStream is;
    static PrintWriter out;
    static String INPUT = "";

    /**
     * Pairwise-coprime moduli under which node values are maintained. Their product is divisible
     * by every m in [1, 101], so a subtree sum modulo m can be rebuilt from the residues via CRT:
     * <ul>
     *   <li>908107200 = 2^6 * 3^4 * 5^2 * 7^2 * 11 * 13</li>
     *   <li>247110827 = 17 * 19 * 23 * 29 * 31 * 37</li>
     *   <li>259106347 = 41 * 43 * 47 * 53 * 59</li>
     *   <li>1673450759 = 61 * 67 * 71 * 73 * 79</li>
     *   <li>72370439 = 83 * 89 * 97 * 101</li>
     * </ul>
     * Every entry is below 2^31, so the product of two residues always fits in a long.
     */
    static final int[] mods = {908107200, 247110827, 259106347, 1673450759, 72370439};

    /**
     * Reads the tree and the query stream from {@link #is} and writes the answer of every
     * report query to {@link #out}.
     *
     * <p>The tree is rooted once at node 0 and numbered in preorder, so "the subtree of t when the
     * tree is rooted at r" is a signed combination of at most two preorder ranges (see
     * {@link #subtreeRanges}). For every modulus in {@link #mods} a range-add / range-sum Fenwick
     * pair stores node values reduced modulo that modulus; keeping cells reduced prevents the
     * silent 64-bit overflow that raw sums would hit with 1e5 updates over 1e5 nodes. A report
     * reduces each residue modulo gcd(mods[j], m) and recombines them with {@link #crt}.
     */
    static void solve()
    {
        int P = mods.length;

        int n = ni();
        int[] from = new int[n-1];
        int[] to = new int[n-1];
        for(int i = 0;i < n-1;i++){
            from[i] = ni()-1;
            to[i] = ni()-1;
        }
        int[][] g = packU(n, from, to);
        int[][] pars = parents3(g, 0);
        int[] par = pars[0], dep = pars[2];
        int[][] spar = logstepParents(par);
        int[][] rights = makeRights(g, par, 0);
        int[] iord = rights[1], right = rights[2];

        // ft[j][0], ft[j][1]: Fenwick pair for modulus mods[j]; every cell stays in [0, mods[j]).
        long[][][] ft = new long[P][2][n+1];

        int Q = ni();
        for(int q = 0;q < Q;q++){
            char type = nc();
            if(type == 'U'){
                int r = ni()-1;
                int t = ni()-1;
                long a = nl();
                long b = nl();
                int[][] qs = subtreeRanges(r, t, n, dep, spar, iord, right);

                for(int j = 0;j < P;j++){
                    long mod = mods[j];
                    // v = a^b + (a+1)^b + (b+1)^a (mod mods[j]).
                    long v = pow(a, b, mod);
                    v += pow(a+1, b, mod);
                    if(v >= mod)v -= mod;
                    v += pow(b+1, a, mod);
                    if(v >= mod)v -= mod;
                    for(int[] z : qs){
                        // A negative sign is applied as the additive inverse so cells stay non-negative.
                        long w = z[2] > 0 ? v : (mod - v) % mod;
                        addRangeFenwickMod(ft[j][0], ft[j][1], z[0], z[1], w, mod);
                    }
                }
            }else if(type == 'R'){
                int r = ni()-1;
                int t = ni()-1;
                int m = ni();
                if(m == 1){
                    out.println(0);
                }else{
                    int[][] qs = subtreeRanges(r, t, n, dep, spar, iord, right);
                    long[] divs = new long[P];
                    long[] vals = new long[P];
                    for(int j = 0;j < P;j++){
                        long big = mods[j];
                        long sum = 0;
                        for(int[] z : qs){
                            long s = sumRangeFenwickMod(ft[j][0], ft[j][1], z[1], big)
                                    - sumRangeFenwickMod(ft[j][0], ft[j][1], z[0]-1, big);
                            sum += s * z[2];
                        }
                        sum %= big;
                        if(sum < 0)sum += big;
                        // gcd(mods[j], m) divides mods[j], so reducing the residue again is consistent.
                        int div = gcd(mods[j], m);
                        divs[j] = div;
                        vals[j] = sum % div;
                    }
                    out.println(crt(divs, vals));
                }
            }
        }

    }

    /**
     * Describes the subtree of {@code t}, taken with the tree rerooted at {@code r}, as signed
     * preorder ranges {@code {lo, hi, sign}} of the tree rooted at node 0.
     *
     * @param r     new root (0-based)
     * @param t     node whose subtree is wanted (0-based)
     * @param n     number of nodes
     * @param dep   depth of every node in the tree rooted at 0
     * @param spar  binary-lifting table from {@link #logstepParents}
     * @param iord  preorder index of every node
     * @param right last preorder index inside the subtree starting at each preorder index
     * @return one or two ranges; values over a {@code -1} range are to be subtracted
     */
    private static int[][] subtreeRanges(int r, int t, int n, int[] dep, int[][] spar, int[] iord, int[] right)
    {
        if(r == t){
            // Rooted at t itself: the whole tree.
            return new int[][]{
                    {0, n-1, 1}
            };
        }
        if(dep[r] > dep[t] && ancestor(r, dep[r]-dep[t], spar) == t){
            // t is a proper ancestor of r: everything except t's child subtree that contains r.
            int ct = ancestor(r, dep[r]-dep[t]-1, spar);
            return new int[][]{
                    {0, n-1, 1},
                    {iord[ct], right[iord[ct]], -1},
            };
        }
        // Otherwise rerooting at r leaves t's subtree unchanged.
        return new int[][]{
                {iord[t], right[iord[t]], 1}
        };
    }

    /** Returns gcd(a, b) for non-negative arguments using the Euclidean algorithm. */
    public static int gcd(int a, int b) {
        while (b > 0){
            int c = a;
            a = b;
            b = c % b;
        }
        return a;
    }

    /**
     * Extended Euclid: returns {@code {g, x, y}} with {@code g = gcd(|a|, |b|)} and
     * {@code a*x + b*y = g}, or {@code null} when {@code a} or {@code b} is 0.
     */
    public static long[] exGCD(long a, long b)
    {
        if(a == 0 || b == 0)return null;
        int as = Long.signum(a);
        int bs = Long.signum(b);
        a = Math.abs(a); b = Math.abs(b);
        long p = 1, q = 0, r = 0, s = 1;
        while(b > 0){
            long c = a / b;
            long d;
            d = a; a = b; b = d % b;
            d = p; p = q; q = d - c * q;
            d = r; r = s; s = d - c * s;
        }
        return new long[]{a, p * as, r * bs};
    }

    /**
     * Combines the congruences x = mods[i] (mod divs[i]); the moduli need not be coprime.
     *
     * @return a solution modulo lcm(divs) (normalized into [0, lcm) once two or more congruences
     *         are merged), or -1 if the system is inconsistent
     */
    public static long crt(final long[] divs, final long[] mods)
    {
        long div = divs[0];
        long mod = mods[0];
        for(int i = 1;i < divs.length;i++){
            long[] apr = exGCD(div, divs[i]);
            if((mods[i] - mod) % apr[0] != 0)return -1;
            long ndiv = div / apr[0] * divs[i];
            long da = div / apr[0];
            long nmod = (mul(mul(apr[1], mods[i]-mod, ndiv), da, ndiv)+mod)%ndiv;
            if(nmod < 0)nmod += ndiv;
            div = ndiv;
            mod = nmod;
        }
        return mod;
    }

    /** Returns (a * b) mod {@code mod} by double-and-add, so no intermediate exceeds 2*mod. */
    public static long mul(long a, long b, long mod)
    {
        a %= mod; if(a < 0)a += mod;
        b %= mod; if(b < 0)b += mod;
        long ret = 0;
        int x = 63-Long.numberOfLeadingZeros(b);
        for(;x >= 0;x--){
            ret += ret;
            if(ret >= mod)ret -= mod;
            if(b<<63-x<0){
                ret += a;
                if(ret >= mod)ret -= mod;
            }
        }
        return ret;
    }

    /** Adds {@code v} to positions [0, i] of a range-add / range-sum Fenwick pair (raw longs). */
    public static void addRangeFenwick(long[] ft0, long[] ft1, int i, long v)
    {
        addFenwick(ft1, i+1, -v);
        addFenwick(ft1, 0, v);
        addFenwick(ft0, i+1, v*(i+1));
    }

    /** Adds {@code v} to positions [a, b] of a range-add / range-sum Fenwick pair (raw longs). */
    public static void addRangeFenwick(long[] ft0, long[] ft1, int a, int b, long v)
    {
        if(a <= b){
            addFenwick(ft1, b+1, -v);
            addFenwick(ft0, b+1, v*(b+1));
            addFenwick(ft1, a, v);
            addFenwick(ft0, a, -v*a);
        }
    }

    /** Returns the sum of positions [0, i] of a range-add / range-sum Fenwick pair. */
    public static long sumRangeFenwick(long[] ft0, long[] ft1, int i)
    {
        return sumFenwick(ft1, i) * (i+1) + sumFenwick(ft0, i);
    }

    /** Recovers the individual position values stored in a range-add / range-sum Fenwick pair. */
    public static long[] restoreRangeFenwick(long[] ft0, long[] ft1)
    {
        int n = ft0.length-1;
        long[] ret = new long[n];
        for(int i = 0;i < n;i++)ret[i] = sumRangeFenwick(ft0, ft1, i);
        for(int i = n-1;i >= 1;i--)ret[i] -= ret[i-1];
        return ret;
    }

    /** Prefix sum over 0-based positions [0, i] of a point-update Fenwick tree. */
    public static long sumFenwick(long[] ft, int i)
    {
        long sum = 0;
        for(i++;i > 0;i -= i&-i)sum += ft[i];
        return sum;
    }

    /** Adds {@code v} at 0-based position {@code i} of a point-update Fenwick tree. */
    public static void addFenwick(long[] ft, int i, long v)
    {
        if(v == 0)return;
        int n = ft.length;
        for(i++;i < n;i += i&-i)ft[i] += v;
    }

    /**
     * Modular variant of {@link #addRangeFenwick(long[], long[], int, int, long)}: adds {@code v}
     * to positions [a, b] while keeping every cell in [0, mod). Requires {@code 0 <= v} and
     * {@code mod < 2^31} so that {@code v * (b+1)} cannot overflow.
     */
    private static void addRangeFenwickMod(long[] ft0, long[] ft1, int a, int b, long v, long mod)
    {
        if(a > b)return;
        v %= mod;
        addFenwickMod(ft1, b+1, (mod - v) % mod, mod);
        addFenwickMod(ft0, b+1, v * (b+1) % mod, mod);
        addFenwickMod(ft1, a, v, mod);
        addFenwickMod(ft0, a, (mod - v * a % mod) % mod, mod);
    }

    /** Sum of positions [0, i] modulo {@code mod} for a pair maintained by {@link #addRangeFenwickMod}; i = -1 gives 0. */
    private static long sumRangeFenwickMod(long[] ft0, long[] ft1, int i, long mod)
    {
        return (sumFenwickMod(ft1, i, mod) * (i+1) + sumFenwickMod(ft0, i, mod)) % mod;
    }

    /** Prefix sum modulo {@code mod} of a Fenwick tree whose cells are in [0, mod). */
    private static long sumFenwickMod(long[] ft, int i, long mod)
    {
        long sum = 0;
        for(i++;i > 0;i -= i&-i){
            sum += ft[i];
            if(sum >= mod)sum -= mod;
        }
        return sum;
    }

    /** Adds {@code v} (in [0, mod)) at 0-based position {@code i}, keeping cells in [0, mod). */
    private static void addFenwickMod(long[] ft, int i, long v, long mod)
    {
        if(v == 0)return;
        for(i++;i < ft.length;i += i&-i){
            ft[i] += v;
            if(ft[i] >= mod)ft[i] -= mod;
        }
    }

    /**
     * Returns a^n mod {@code mod} for n >= 0 by left-to-right binary exponentiation.
     * Intermediate products are residues squared, so mod must be below about 3e9.
     */
    public static long pow(long a, long n, long mod) {
        a %= mod;
        long ret = 1;
        int x = 63 - Long.numberOfLeadingZeros(n);
        for(;x >= 0;x--){
            ret = ret * ret % mod;
            if(n << 63 - x < 0)
                ret = ret * a % mod;
        }
        return ret;
    }

    /** Returns all primes {@code <= n} in ascending order (bit-packed sieve over odd numbers). */
    public static int[] sieveEratosthenes(int n) {
        if(n <= 32){
            int[] primes = { 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31 };
            for(int i = 0;i < primes.length;i++){
                if(n < primes[i]){
                    return Arrays.copyOf(primes, i);
                }
            }
            return primes;
        }

        int u = n + 32;
        double lu = Math.log(u);
        int[] ret = new int[(int) (u / lu + u / lu / lu * 1.5)];
        ret[0] = 2;
        int pos = 1;

        // Bit k of isp stands for the odd number 2k+3; a set bit means "skip" (composite or pre-handled).
        int[] isp = new int[(n + 1) / 32 / 2 + 1];
        int sup = (n + 1) / 32 / 2 + 1;

        int[] tprimes = { 3, 5, 7, 11, 13, 17, 19, 23, 29, 31 };
        for(int tp : tprimes){
            ret[pos++] = tp;
            int[] ptn = new int[tp];
            for(int i = (tp - 3) / 2;i < tp << 5;i += tp)
                ptn[i >> 5] |= 1 << (i & 31);
            for(int i = 0;i < tp;i++){
                for(int j = i;j < sup;j += tp)
                    isp[j] |= ptn[i];
            }
        }

        // 3,5,7
        // 2x+3=n
        // magic: lookup table turning the isolated lowest set bit (j & -j) into its bit index.
        int[] magic = { 0, 1, 23, 2, 29, 24, 19, 3, 30, 27, 25, 11, 20, 8, 4,
                13, 31, 22, 28, 18, 26, 10, 7, 12, 21, 17, 9, 6, 16, 5, 15, 14 };
        int h = n / 2;
        for(int i = 0;i < sup;i++){
            for(int j = ~isp[i];j != 0;j &= j - 1){
                int pp = i << 5 | magic[(j & -j) * 0x076be629 >>> 27];
                int p = 2 * pp + 3;
                if(p > n)
                    break;
                ret[pos++] = p;
                for(int q = pp;q <= h;q += p)
                    isp[q >> 5] |= 1 << (q & 31);
            }
        }

        return Arrays.copyOf(ret, pos);
    }

    /** Lowest common ancestor of {@code a} and {@code b} using binary lifting plus a depth binary search. */
    public static int lca2(int a, int b, int[][] spar, int[] depth) {
        if(depth[a] < depth[b]){
            b = ancestor(b, depth[b] - depth[a], spar);
        }else if(depth[a] > depth[b]){
            a = ancestor(a, depth[a] - depth[b], spar);
        }

        if(a == b)
            return a;
        int sa = a, sb = b;
        for(int low = 0, high = depth[a], t = Integer.highestOneBit(high), k = Integer
                .numberOfTrailingZeros(t);t > 0;t >>>= 1, k--){
            if((low ^ high) >= t){
                if(spar[k][sa] != spar[k][sb]){
                    low |= t;
                    sa = spar[k][sa];
                    sb = spar[k][sb];
                }else{
                    high = low | t - 1;
                }
            }
        }
        return spar[0][sa];
    }

    /** Returns the {@code m}-th ancestor of {@code a}, or -1 when it lies above the root. */
    protected static int ancestor(int a, int m, int[][] spar) {
        for(int i = 0;m > 0 && a != -1;m >>>= 1, i++){
            if((m & 1) == 1)
                a = spar[i][a];
        }
        return a;
    }

    /** Builds the binary-lifting table: {@code pars[k][v]} is the 2^k-th ancestor of v or -1. */
    public static int[][] logstepParents(int[] par) {
        int n = par.length;
        int m = Integer.numberOfTrailingZeros(Integer.highestOneBit(n - 1)) + 1;
        int[][] pars = new int[m][n];
        pars[0] = par;
        for(int j = 1;j < m;j++){
            for(int i = 0;i < n;i++){
                pars[j][i] = pars[j - 1][i] == -1 ? -1
                        : pars[j - 1][pars[j - 1][i]];
            }
        }
        return pars;
    }

    /**
     * BFS from {@code root} over the tree {@code g}.
     *
     * @return {@code {par, order, depth}}: parent of each node (-1 for the root), BFS visiting
     *         order, and depth of each node
     */
    public static int[][] parents3(int[][] g, int root) {
        int n = g.length;
        int[] par = new int[n];
        Arrays.fill(par, -1);

        int[] depth = new int[n];
        depth[root] = 0;

        int[] q = new int[n];
        q[0] = root;
        for(int p = 0, r = 1;p < r;p++){
            int cur = q[p];
            for(int nex : g[cur]){
                if(par[cur] != nex){
                    q[r++] = nex;
                    par[nex] = cur;
                    depth[nex] = depth[cur] + 1;
                }
            }
        }
        return new int[][] { par, q, depth };
    }

    /**
     * Returns the nodes in the order an iterative stack DFS from {@code root} pops them. Each
     * node comes before its descendants and every subtree occupies a contiguous block.
     */
    public static int[] sortByPreorder(int[][] g, int root){
        int n = g.length;
        int[] stack = new int[n];
        int[] ord = new int[n];
        BitSet ved = new BitSet();
        stack[0] = root;
        int p = 1;
        int r = 0;
        ved.set(root);
        while(p > 0){
            int cur = stack[p-1];
            ord[r++] = cur;
            p--;
            for(int e : g[cur]){
                if(!ved.get(e)){
                    stack[p++] = e;
                    ved.set(e);
                }
            }
        }
        return ord;
    }

    /**
     * Computes the preorder layout of the tree.
     *
     * @return {@code {ord, iord, right}}: preorder sequence, preorder index of each node, and for
     *         each preorder index the last preorder index inside that node's subtree
     */
    public static int[][] makeRights(int[][] g, int[] par, int root)
    {
        int n = g.length;
        int[] ord = sortByPreorder(g, root);
        int[] iord = new int[n];
        for(int i = 0;i < n;i++)iord[ord[i]] = i;

        int[] right = new int[n];
        for(int i = n-1;i >= 0;i--){
            int v = i;
            for(int e : g[ord[i]]){
                if(e != par[ord[i]]){
                    v = Math.max(v, right[iord[e]]);
                }
            }
            right[i] = v;
        }
        return new int[][]{ord, iord, right};
    }

    /** Packs an undirected edge list into adjacency arrays for {@code n} nodes. */
    static int[][] packU(int n, int[] from, int[] to) {
        int[][] g = new int[n][];
        int[] p = new int[n];
        for(int f : from)
            p[f]++;
        for(int t : to)
            p[t]++;
        for(int i = 0;i < n;i++)
            g[i] = new int[p[i]];
        for(int i = 0;i < from.length;i++){
            g[from[i]][--p[from[i]]] = to[i];
            g[to[i]][--p[to[i]]] = from[i];
        }
        return g;
    }

    /** Entry point: reads standard input (or {@link #INPUT} when non-empty) and solves it. */
    public static void main(String[] args) throws Exception
    {
        long S = System.currentTimeMillis();
        is = INPUT.isEmpty() ? System.in : new ByteArrayInputStream(INPUT.getBytes());
        out = new PrintWriter(System.out);

        solve();
        out.flush();
        long G = System.currentTimeMillis();
        tr(G-S+"ms");
    }

    /** Returns true when only whitespace remains in the input. */
    private static boolean eof()
    {
        if(lenbuf == -1)return true;
        int lptr = ptrbuf;
        while(lptr < lenbuf)if(!isSpaceChar(inbuf[lptr++]))return false;

        try {
            is.mark(1000);
            while(true){
                int b = is.read();
                if(b == -1){
                    is.reset();
                    return true;
                }else if(!isSpaceChar(b)){
                    is.reset();
                    return false;
                }
            }
        } catch (IOException e) {
            return true;
        }
    }

    private static byte[] inbuf = new byte[1024];
    static int lenbuf = 0, ptrbuf = 0;

    /** Returns the next input byte, refilling the buffer as needed; -1 at end of input. */
    private static int readByte()
    {
        if(lenbuf == -1)throw new InputMismatchException();
        if(ptrbuf >= lenbuf){
            ptrbuf = 0;
            try { lenbuf = is.read(inbuf); } catch (IOException e) { throw new InputMismatchException(); }
            if(lenbuf <= 0)return -1;
        }
        return inbuf[ptrbuf++];
    }

    private static boolean isSpaceChar(int c) { return !(c >= 33 && c <= 126); }
    private static int skip() { int b; while((b = readByte()) != -1 && isSpaceChar(b)); return b; }

    private static double nd() { return Double.parseDouble(ns()); }
    private static char nc() { return (char)skip(); }

    /** Reads the next whitespace-delimited token. */
    private static String ns()
    {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        while(!(isSpaceChar(b))){ // when nextLine, (isSpaceChar(b) && b != ' ')
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    /** Reads up to {@code n} non-whitespace characters. */
    private static char[] ns(int n)
    {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while(p < n && !(isSpaceChar(b))){
            buf[p++] = (char)b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    /** Reads an n-row character grid, each row at most m characters. */
    private static char[][] nm(int n, int m)
    {
        char[][] map = new char[n][];
        for(int i = 0;i < n;i++)map[i] = ns(m);
        return map;
    }

    /** Reads {@code n} ints. */
    private static int[] na(int n)
    {
        int[] a = new int[n];
        for(int i = 0;i < n;i++)a[i] = ni();
        return a;
    }

    /** Reads the next (optionally negative) decimal int, skipping non-numeric bytes before it. */
    private static int ni()
    {
        int num = 0, b;
        boolean minus = false;
        while((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'));
        if(b == '-'){
            minus = true;
            b = readByte();
        }

        while(true){
            if(b >= '0' && b <= '9'){
                num = num * 10 + (b - '0');
            }else{
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    /** Reads the next (optionally negative) decimal long, skipping non-numeric bytes before it. */
    private static long nl()
    {
        long num = 0;
        int b;
        boolean minus = false;
        while((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'));
        if(b == '-'){
            minus = true;
            b = readByte();
        }

        while(true){
            if(b >= '0' && b <= '9'){
                num = num * 10 + (b - '0');
            }else{
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    /** Debug print, active only when {@link #INPUT} is non-empty. */
    private static void tr(Object... o) { if(INPUT.length() != 0)System.out.println(Arrays.deepToString(o)); }
}


Problem solution in C++ programming.

#include<cstdio>
#include<vector>
using namespace std;

const int MAXN = 100000;

// Adjacency lists, indexed by the 1-based node ids read from input.
vector<int> v[MAXN + 1];
// Children of every node in the DFS tree rooted at node 1, in visiting order (filled by dfs).
vector<int> child[MAXN + 1];
// DFS visited flags.
bool visited[MAXN + 1];
// DFS numbering: startTime[x] is x's own number, endTime[x] the largest number inside x's subtree.
int startTime[MAXN + 1];
int endTime[MAXN + 1];

// Numbers the subtree of `now` in DFS order, starting at `start`, recording each node's
// children in child[] and its [startTime, endTime] interval. Returns the last number used.
int dfs(int now, int start)
{
	visited[now] = true;
	startTime[now] = start;
	for (size_t k = 0; k < v[now].size(); k++)
	{
		int next = v[now][k];
		if (visited[next])
			continue;
		child[now].push_back(next);
		start = dfs(next, start + 1);
	}
	endTime[now] = start;
	return start;
}

// Five pairwise-coprime moduli, each fitting in an int. Together they contain every prime
// power up to 101 (2^6 in MOD[4], 3^4 and 5^2 in MOD[1], 7^2 in MOD[3]), so any m <= 101
// splits into prime powers that each divide exactly one entry.
const int NMOD = 5;
int MOD[NMOD] = {
	11 * 101 * 13 * 97 * 17 * 89,
	19 * 83 * 23 * 81 * 25 * 29,
	31 * 79 * 37 * 73 * 41,
	43 * 47 * 49 * 53 * 59,
	61 * 64 * 67 * 71
};

// The 26 primes up to 101, used to factor m.
int p[26] = {
	2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41,
	43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101
};

// modidx[w]: index of the first MOD entry sharing a factor with w (filled by init()).
int modidx[102];

// Greatest common divisor by the iterative Euclidean algorithm.
int gcd(int a, int b)
{
	while (b != 0)
	{
		int rem = a % b;
		a = b;
		b = rem;
	}
	return a;
}

// For every w in [1, 101], records in modidx[w] the first modulus index j with
// gcd(w, MOD[j]) != 1; entries with no such j (only w == 1) are left untouched.
void init()
{
	for (int value = 1; value <= 101; value++)
	{
		int j = 0;
		while (j < NMOD && gcd(value, MOD[j]) == 1)
			j++;
		if (j < NMOD)
			modidx[value] = j;
	}
}

// One lazy segment tree per modulus: tree[] holds segment sums, lazy[] pending range additions.
long long tree[NMOD][4 * MAXN];
long long lazy[NMOD][4 * MAXN];

// Extended Euclid: stores in x, y coefficients with a*x + b*y = gcd(a, b) and returns the gcd.
// Iterative form using the same quotient sequence as the recursive definition.
long long ext_euclid(long long a, long long b, long long& x, long long& y)
{
	long long rPrev = a, rCur = b;
	long long sPrev = 1, sCur = 0;
	long long tPrev = 0, tCur = 1;
	while (rCur != 0)
	{
		long long q = rPrev / rCur;
		long long next;
		next = rPrev - q * rCur; rPrev = rCur; rCur = next;
		next = sPrev - q * sCur; sPrev = sCur; sCur = next;
		next = tPrev - q * tCur; tPrev = tCur; tCur = next;
	}
	x = sPrev;
	y = tPrev;
	return rPrev;
}

/*
long long linear_modular_equation_system(long long B[], long long W[], long long k)
{
	long long d, x, y, m, n = 1;
	long long a = 0;
	for (long long i = 0; i < k; i++)
		n *= W[i];
	for (long long i = 0; i < k; i++)
	{
		m = n / W[i];
		d = ext_euclid(W[i], m, x, y);
		while (y < 0) y += W[i];
		a += y * (m * B[i] % n);
	}
	while (a < 0) a += n;
	return a % n;
}
*/

// Solves the system x = B[i] (mod W[i]) for i in [0, k), assuming the W[i] are pairwise
// coprime, and returns the unique solution in [0, W[0] * ... * W[k-1]).
// (The previous version returned the product itself instead of 0 when the residue was 0.)
long long linear_modular_equation_system(long long B[], long long W[], long long k)
{
	long long x, y, n = 1;
	for (long long i = 0; i < k; i++)
		n *= W[i];
	long long a = 0;
	for (long long i = 0; i < k; i++)
	{
		long long m = n / W[i];
		// W[i]*x + m*y = 1, so y*m is 1 modulo W[i] and 0 modulo every other W[j].
		ext_euclid(W[i], m, x, y);
		a = (a + y * m % n * B[i]) % n;
	}
	// C++ % keeps the dividend's sign; shift a negative residue into [0, n).
	return a < 0 ? a + n : a;
}

// Adds `value` to every position in [i, j] of segment tree number `modidx` (arithmetic mod MOD[modidx]).
// `node` is the current tree node and covers positions [a, b]; the root is node 1.
// Note: the parameter `modidx` shadows the global modidx[] table inside this function.
void updateTree(long long node, long long a, long long b, long long i, long long j, long long value, long long modidx)
{
	// Apply this node's pending addition to its sum and hand it down to the children.
	if (lazy[modidx][node] != 0)
	{
		tree[modidx][node] = (tree[modidx][node] + (long long)(b - a + 1) * lazy[modidx][node]) % MOD[modidx];

		if (a != b)
		{
			lazy[modidx][node * 2] = (lazy[modidx][node * 2] + lazy[modidx][node]) % MOD[modidx];
			lazy[modidx][node * 2 + 1] = (lazy[modidx][node * 2 + 1] + lazy[modidx][node]) % MOD[modidx];
		}

		lazy[modidx][node] = 0;
	}

	// No overlap with [i, j]. The push above still ran, so the parent can rely on this node's sum.
	if (a > b || a > j || b < i) return;

	// Fully covered: update this node's sum and defer the children's share via lazy[].
	if (a >= i && b <= j)
	{
		tree[modidx][node] = (tree[modidx][node] + (long long)(b - a + 1) * value) % MOD[modidx];

		if (a != b)
		{
			lazy[modidx][node * 2] = (lazy[modidx][node * 2] + value) % MOD[modidx];
			lazy[modidx][node * 2 + 1] = (lazy[modidx][node * 2 + 1] + value) % MOD[modidx];
		}

		return;
	}

	// Partial overlap: recurse into both halves, then rebuild this node's sum.
	updateTree(node * 2, a, (a + b) / 2, i, j, value, modidx);
	updateTree(node * 2 + 1, (a + b) / 2 + 1, b, i, j, value, modidx);

	tree[modidx][node] = (tree[modidx][node * 2] + tree[modidx][node * 2 + 1]) % MOD[modidx];
}

// Returns the sum of positions [i, j] of segment tree number `modidx`, modulo MOD[modidx].
// `node` covers positions [a, b]; the root is node 1. Pending lazy additions on visited
// nodes are pushed down as a side effect.
long long queryTree(long long node, long long a, long long b, long long i, long long j, long long modidx)
{
	// No overlap with [i, j]: contributes nothing (the lazy value can stay pending).
	if (a > b || a > j || b < i) return 0;

	// Apply this node's pending addition and hand it down to the children.
	if (lazy[modidx][node] != 0)
	{
		tree[modidx][node] = (tree[modidx][node] + (long long)(b - a + 1) * lazy[modidx][node]) % MOD[modidx];
		if (a != b)
		{
			lazy[modidx][node * 2] = (lazy[modidx][node * 2] + lazy[modidx][node]) % MOD[modidx];
			lazy[modidx][node * 2 + 1] = (lazy[modidx][node * 2 + 1] + lazy[modidx][node]) % MOD[modidx];
		}
		lazy[modidx][node] = 0;
	}

	// Fully covered: this node's sum is exact now.
	if (a >= i && b <= j) return tree[modidx][node];

	// Partial overlap: combine both halves.
	long long q1 = queryTree(node * 2, a, (a + b) / 2, i, j, modidx);
	long long q2 = queryTree(node * 2 + 1, (a + b) / 2 + 1, b, i, j, modidx);
	long long tmp = q1 + q2;
	return tmp % MOD[modidx];
}

// Returns a^b mod `mod` by right-to-left binary exponentiation (b >= 0). Residues stay below
// `mod` (an int), so their products fit in a long long.
int ModExp(long long a, long long b, int mod)
{
	long long result = 1;
	long long base = a % mod;
	for (; b; b >>= 1)
	{
		if (b & 1)
			result = result * base % mod;
		base = base * base % mod;
	}
	return (int)result;
}

// Value added by an update query: (a^b + (a+1)^b + (b+1)^a) mod `mod`.
int calc(long long a, long long b, int mod)
{
	long long total = 0;
	total += ModExp(a, b, mod);
	total += ModExp(a + 1, b, mod);
	total += ModExp(b + 1, a, mod);
	return (int)(total % mod);
}

// Returns the index in child[root] of the child whose [startTime, endTime] interval contains
// [st, en]. Children are stored in DFS order, so their intervals are sorted and disjoint.
int binarySearchChild(int root, int st, int en)
{
	const vector<int>& kids = child[root];
	int left = 0;
	int right = (int)kids.size();

	while (right - left > 1)
	{
		int middle = (left + right) / 2;
		int kid = kids[middle];
		if (st >= startTime[kid] && en <= endTime[kid])
			return middle;
		if (en <= startTime[kid])
			right = middle;
		else
			left = middle + 1;
	}
	return left;
}

int main()
{
	init();

	int n;
	scanf("%d", &n);
	for (int i = 0; i < n - 1; i++)
	{
		int x, y;
		scanf("%d%d", &x, &y);
		v[x].push_back(y);
		v[y].push_back(x);
	}
	dfs(1, 1);

	int q;
	scanf("%d", &q);
	getchar();
	while (q--)
	{
		char ch;
		int r, t, m;
		long long a, b;
		ch = getchar();
		if (ch == 'U')
		{
			scanf("%d%d%lld%lld", &r, &t, &a, &b);
			getchar();

			int sr = startTime[r], er = endTime[r];
			int st = startTime[t], et = endTime[t];

			int val[NMOD];

			for (int i = 0; i < NMOD; i++)
				val[i] = calc(a, b, MOD[i]);

			if (sr == st && er == et)
			{
				for (int i = 0; i < NMOD; i++)
					updateTree(1, startTime[1], endTime[1], startTime[1], endTime[1], val[i], i);
			}
			else if (sr > st && er <= et)
			{
				for (int i = 0; i < NMOD; i++)
					updateTree(1, startTime[1], endTime[1], startTime[1], endTime[1], val[i], i);
				int childIdx = binarySearchChild(t, sr, er);
				for (int i = 0; i < NMOD; i++)
					updateTree(1, startTime[1], endTime[1], startTime[child[t][childIdx]], endTime[child[t][childIdx]], MOD[i] - val[i], i);
			}
			else
			{
				for (int i = 0; i < NMOD; i++)
					updateTree(1, startTime[1], endTime[1], st, et, val[i], i);
			}
		}
		else
		{
			scanf("%d%d%d", &r, &t, &m);
			getchar();

			if (m == 1)
			{
				printf("0\n");
				continue;
			}

			int sr = startTime[r], er = endTime[r];
			int st = startTime[t], et = endTime[t];

			int _m = m;

			long long W[6], B[6], k = 0;
			for (int i = 0; i < 26; i++)
			{
				int tmp = 1;
				bool flag = false;
				while (m % p[i] == 0)
				{
					flag = true;
					m /= p[i];
					tmp *= p[i];
				}
				if (flag)
					W[k++] = tmp;
			}

			if (sr == st && er == et)
			{
				for (int i = 0; i < k; i++)
					B[i] = queryTree(1, startTime[1], endTime[1], startTime[1], endTime[1], modidx[W[i]]) % W[i];
				printf("%lld\n", linear_modular_equation_system(B, W, k) % _m);
			}
			else if (sr > st && er <= et)
			{
				int childIdx = binarySearchChild(t, sr, er);
				for (int i = 0; i < k; i++)
				{
					long long tmp1 = queryTree(1, startTime[1], endTime[1], startTime[1], endTime[1], modidx[W[i]]) % W[i];
					long long tmp2 = queryTree(1, startTime[1], endTime[1], startTime[child[t][childIdx]], endTime[child[t][childIdx]], modidx[W[i]]) % W[i];
					B[i] = (tmp1 + W[i] - tmp2) % W[i];
				}
				printf("%lld\n", linear_modular_equation_system(B, W, k) % _m);
			}
			else
			{
				for (int i = 0; i < k; i++)
					B[i] = queryTree(1, startTime[1], endTime[1], st, et, modidx[W[i]]) % W[i];
				printf("%lld\n", linear_modular_equation_system(B, W, k) % _m);
			}
		}
	}
	return 0;
}


Problem solution in C programming.

#include <stdio.h>
#include <stdlib.h>
/* Number of primes up to 101 (entries of prime[]) and number of moduli in pmod[]. */
#define PNUM 26
#define MODNUM 5
/* Segment-tree node. `offset` is an addition applied to the node's whole segment and never
   pushed to the children; `val` is the segment sum, including offsets stored at this node and
   below (see s_sum_update / s_sum_query). */
typedef struct whatever{
long long offset;
long long val;
} node;
/* Singly linked adjacency-list cell. */
typedef struct _list{
int x;
struct _list *next;
} list;
void update(int r,int t,long long A,long long B);
void query(int r,int t,int m);
void s_sum_update(int n,int b,int e,int i,
int j,long long val,node*tree,int mod);
long long s_sum_query(int n,int b,int e,
int i,int j,long long offset,node*tree,int mod);
void insert_edge(int x,int y);
void dfs(int x,int level);
void dfs2(int x,int level);
int isA(int x,int y);
long long modPow(long long a,long long x,int mod);
int get_i(int*a,int num,int size);
int med(int*a,int size);
long long crt(long long *mod_prime,
long long *list,int size,long long P);
void ext_gcd(long long a,long long b,
long long *x,long long *y);
/* All primes up to 101. */
int prime[PNUM]={2, 3, 5, 7, 11, 13, 17,
19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61,
 67, 71, 73, 79, 83, 89, 97, 101};
/* Pairwise-coprime moduli (each below 2^31):
   908107200  = 2^6*3^4*5^2*7^2*11*13
   247110827  = 17*19*23*29*31*37
   259106347  = 41*43*47*53*59
   1673450759 = 61*67*71*73*79
   72370439   = 83*89*97*101
   Each prime power up to 101 divides exactly one entry. */
int pmod[MODNUM]={908107200, 247110827, 
259106347, 1673450759, 72370439};
/* p_idx[i]: index of the pmod entry divisible by prime[i]. */
int p_idx[PNUM]=
{0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2,
 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4};
/* b[x] / e[x]: preorder number of x and the largest preorder number in its subtree (root 0).
   trace: visited flags for dfs/dfs2. l[x]: depth of x. level_size[d]: nodes at depth d.
   sort_size: next preorder number. idx is not referenced in the visible code. */
int b[100000],e[100000],trace[100000]={0},
idx[100000],level_size[100000]={0},l[100000],
N,sort_size=0;
/* Per depth d, in DFS order: level_x[d][k] is a node and level_i[d][k] its preorder number. */
int *level_x[100000]={0},*level_i[100000]={0};
/* Adjacency lists. */
list *table[100000]={0};
/* One segment tree per modulus; 400000 = 4 * 100000 node slots. */
node tree[MODNUM][400000]={0};

/* Reads the tree and the queries from stdin, dispatching each query to update() or query().
   The command token is read with "%1s" so it can never overflow the 2-byte buffer (the old
   unbounded "%s" could), and the loop stops cleanly at end of input. */
int main(){
  int Q,r,t,x,y,i;
  long long A,B;
  char str[2];
  if(scanf("%d",&N)!=1)
    return 0;
  for(i=0;i<N-1;i++){
    scanf("%d%d",&x,&y);
    insert_edge(x-1,y-1);
  }
  dfs(0,0);
  /* First pass counted nodes per depth; allocate the per-level arrays, then refill them in dfs2. */
  for(i=0;i<N;i++)
    if(level_size[i]){
      level_x[i]=(int*)malloc(level_size[i]*sizeof(int));
      level_i[i]=(int*)malloc(level_size[i]*sizeof(int));
      level_size[i]=0;
    }
  for(i=0;i<N;i++)
    trace[i]=0;
  dfs2(0,0);
  if(scanf("%d",&Q)!=1)
    return 0;
  while(Q--){
    if(scanf("%1s",str)!=1)
      break;
    if(str[0]=='U'){
      scanf("%d%d%lld%lld",&r,&t,&A,&B);
      update(r-1,t-1,A,B);
    }
    else{
      scanf("%d%d%d",&r,&t,&x);
      query(r-1,t-1,x);
    }
  }
  return 0;
}
/* Adds A^B + (A+1)^B + (B+1)^A to every node in the subtree of t, with the tree rooted at r
   (0-based ids), in each per-modulus segment tree. If t is an ancestor of r (or r itself),
   the whole tree is updated and, for t != r, t's child on the path to r is compensated with
   the additive inverse. The child lookup does not depend on the modulus, so it is done once
   instead of once per modulus. */
void update(int r,int t,long long A,long long B){
  int i,lv,k,c=-1,whole=isA(t,r);
  long long val;
  if(whole&&t!=r){
    /* level_i[lv] holds preorder numbers of depth-lv nodes in ascending order; get_i counts
       those below b[r]. That is r itself when r is at depth lv, else the entry just before. */
    lv=l[t]+1;
    k=get_i(level_i[lv],b[r],level_size[lv]);
    c=(lv==l[r])?level_x[lv][k]:level_x[lv][k-1];
  }
  for(i=0;i<MODNUM;i++){
    val=(modPow(A,B,pmod[i])+modPow(A+1,B,pmod[i])+modPow(
        B+1,A,pmod[i]))%pmod[i];
    if(whole){
      s_sum_update(1,0,N-1,b[0],e[0],val,&tree[i][0],pmod[i]);
      if(c!=-1)
        s_sum_update(1,0,N-1,b[c],e[c],pmod[i]-val,&tree[i][0],pmod[i]);
    }
    else
      s_sum_update(1,0,N-1,b[t],e[t],val,&tree[i][0],pmod[i]);
  }
}
/* Prints the sum of the values in the subtree of t, with the tree rooted at r (0-based ids),
   modulo m. The sum is computed modulo every pmod entry, reduced to the prime-power factors
   of m and recombined with crt. Only prime factors up to 101 (entries of prime[]) are
   handled. The child lookup is hoisted out of the per-modulus loop, and that loop is now
   braced explicitly. */
void query(int r,int t,int m){
  int rprime_size=0,i,lv,k,c=-1,whole=isA(t,r);
  long long mm[MODNUM],rprime[PNUM],rlist[PNUM],p=m;
  if(m==1){
    printf("0\n");
    return;
  }
  if(whole&&t!=r){
    /* c = child of t on the path towards r (see update for the get_i reasoning). */
    lv=l[t]+1;
    k=get_i(level_i[lv],b[r],level_size[lv]);
    c=(lv==l[r])?level_x[lv][k]:level_x[lv][k-1];
  }
  for(i=0;i<MODNUM;i++){
    if(whole){
      mm[i]=s_sum_query(1,0,N-1,b[0],e[0],0,&tree[i][0],pmod[i]);
      if(c!=-1)
        mm[i]=(mm[i]-s_sum_query(1,0,N-1,b[c],e[c],
            0,&tree[i][0],pmod[i])+pmod[i])%pmod[i];
    }
    else
      mm[i]=s_sum_query(1,0,N-1,b[t],e[t],0,&tree[i][0],pmod[i]);
  }
  /* Split m into prime powers; pmod[p_idx[i]] is the modulus divisible by prime[i]. */
  for(i=0;i<PNUM;i++)
    if(m%prime[i]==0){
      rprime[rprime_size]=1;
      while(p%prime[i]==0){
        p/=prime[i];
        rprime[rprime_size]*=prime[i];
      }
      rlist[rprime_size]=mm[p_idx[i]]%rprime[rprime_size];
      rprime_size++;
    }
  printf("%lld\n",crt(rprime,rlist,rprime_size,m));
}
/* Adds val to every position in [i, j] of the segment tree `tree` (all values mod `mod`).
   Node n covers [b, e]; the root is n = 1. Fully covered nodes store the addition in
   `offset` (never pushed down) and in `val`; partially covered nodes rebuild `val` from the
   children plus their own offset. Note: parameters b and e shadow the global b[]/e[] arrays. */
void s_sum_update(int n,int b,int e,int i,
int j,long long val,node*tree,int mod){
if(b>e||i>j||b>j||e<i)
return;
if(b>=i&&e<=j){
tree[n].offset=(tree[n].offset+val)%mod;
tree[n].val=(tree[n].val+(e-b+1)*val)%mod;
return;
}
s_sum_update(n*2,b,(b+e)/2,i,j,val,tree,mod);
s_sum_update(n*2+1,(b+e)/2+1,e,i,j,val,tree,mod);
/* Children's sums exclude this node's offset, so add it back over the full segment. */
tree[n].val=(tree[n*2].val+tree[
    n*2+1].val+tree[n].offset*(e-b+1))%mod;
return;
}
/* Returns the sum of positions [i, j] of `tree` modulo `mod`. Node n covers [b, e]; `offset`
   is the total of the ancestors' offsets (pass 0 at the root), which must be added to every
   position because s_sum_update never pushes offsets down. Parameters b and e shadow the
   global b[]/e[] arrays. */
long long s_sum_query(int n,int b,int e,
int i,int j,long long offset,node*tree,int mod){
if(b>e||i>j||b>j||e<i)
return 0;
if(b>=i&&e<=j)
return (tree[n].val+(e-b+1)*offset)%mod;
offset=(offset+tree[n].offset)%mod;
return (s_sum_query(n*2,b,(b+e)/2,i,j,
offset,tree,mod)+s_sum_query(n*2+1,(
b+e)/2+1,e,i,j,offset,tree,mod))%mod;
}
/* Records the undirected edge x-y by prepending each endpoint to the other's adjacency list. */
void insert_edge(int x,int y){
  list *cell=(list*)malloc(sizeof(list));
  cell->x=x;
  cell->next=table[y];
  table[y]=cell;
  cell=(list*)malloc(sizeof(list));
  cell->x=y;
  cell->next=table[x];
  table[x]=cell;
}
/* First DFS pass from x at depth `level`: assigns preorder numbers b[]/e[], depths l[], and
   counts the nodes on every level in level_size[]. */
void dfs(int x,int level){
  list *it;
  trace[x]=1;
  b[x]=sort_size++;
  l[x]=level;
  level_size[level]++;
  it=table[x];
  while(it){
    if(!trace[it->x])
      dfs(it->x,level+1);
    it=it->next;
  }
  e[x]=sort_size-1;
}
/* Second DFS pass (same visiting order as dfs): appends each node and its preorder number to
   the arrays of its level, so every level's entries end up in ascending preorder. */
void dfs2(int x,int level){
  list *it;
  int slot;
  trace[x]=1;
  slot=level_size[level];
  level_i[level][slot]=b[x];
  level_x[level][slot]=x;
  level_size[level]=slot+1;
  it=table[x];
  while(it){
    if(!trace[it->x])
      dfs2(it->x,level+1);
    it=it->next;
  }
}
/* Nonzero when x is an ancestor of y (or y itself): y's preorder interval lies inside x's. */
int isA(int x,int y){
  int starts_inside=b[y]>=b[x];
  int ends_inside=e[y]<=e[x];
  return starts_inside && ends_inside;
}
/* Returns a^x mod `mod` (x >= 0) by right-to-left binary exponentiation. */
long long modPow(long long a,long long x,int mod){
  long long base=a%mod,acc=1;
  for(;x>0;x>>=1){
    if(x&1)
      acc=acc*base%mod;
    base=base*base%mod;
  }
  return acc;
}
/* For an ascending array a[0..size), returns how many elements are strictly smaller than num
   (binary search; the recursion of the original is unrolled into a loop). */
int get_i(int*a,int num,int size){
  int count=0,upper;
  while(size!=0){
    if(num>med(a,size)){
      upper=(size+1)>>1;
      count+=upper;
      a+=upper;
      size>>=1;
    }
    else
      size=(size-1)>>1;
  }
  return count;
}
/* Returns the lower-middle element a[(size-1)/2] of a non-empty array. */
int med(int*a,int size){
  int mid=(size-1)>>1;
  return a[mid];
}
/* Chinese remainder theorem: given pairwise-coprime mod_prime[k] whose product is P and
   residues list[k], returns the x in [0, P) with x = list[k] (mod mod_prime[k]) for all k. */
long long crt(long long *mod_prime,
long long *list,int size,long long P){
  long long k,unused,inv,cofactor,acc=0;
  for(k=0;k<size;k++){
    cofactor=P/mod_prime[k];
    /* inv is the inverse of cofactor modulo mod_prime[k]. */
    ext_gcd(mod_prime[k],cofactor,&unused,&inv);
    while(inv<0)
      inv+=P;
    acc+=list[k]*inv%P*cofactor%P;
    acc%=P;
  }
  return acc;
}
/* Extended Euclid: stores in *x, *y coefficients with a*(*x) + b*(*y) = gcd(a, b).
   Iterative form driven by the same quotient sequence as the recursive definition. */
void ext_gcd(long long a,long long b,
long long *x,long long *y){
  long long r0=a,r1=b,s0=1,s1=0,t0=0,t1=1,q,tmp;
  while(r1){
    q=r0/r1;
    tmp=r0-q*r1; r0=r1; r1=tmp;
    tmp=s0-q*s1; s0=s1; s1=tmp;
    tmp=t0-q*t1; t0=t1; t1=tmp;
  }
  *x=s0;
  *y=t0;
}


Post a Comment

0 Comments