Header Ad

HackerRank White Falcon And Tree problem solution

In this HackerRank White Falcon And Tree problem solution, you are given a tree with N nodes, each of which holds a linear function ax + b. There are two kinds of queries. The first assigns a new function ax + b to every node on the path between two nodes u and v. The second applies the functions along the path from u to v, in order, to a given value x and reports the result modulo 10^9 + 7.

HackerRank White Falcon And Tree problem solution


Problem solution in Java programming.

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;

/**
 * White Falcon And Tree. Every tree node holds a linear function f(x) = co*x + c
 * (mod 1e9+7). Query type 1 assigns a*x + b to every node on the path u..v; query
 * type 2 pushes x through the functions on the path from u to v and prints the result.
 * The tree is split into heavy-light chains, each backed by one SegmentTreeNodePlus.
 */
public class Solution {
    
    static InputStream is;
    static PrintWriter out;
    // Embedded test input; when empty, main() reads from System.in instead.
    static String INPUT = "";
    static int mod = 1000000007;
    
    /**
     * Reads the tree and the queries via ni() and writes each type-2 answer to out.
     */
    static void solve()
    {
        int n = ni();
        // co[i] = {a, b}: node i initially holds a*x + b.
        int[][] co = new int[n][];
        for(int i = 0;i < n;i++){
            co[i] = new int[]{ni(), ni()};
        }
        // Edges arrive 1-based; store them 0-based.
        int[] from = new int[n-1];
        int[] to = new int[n-1];
        for(int i = 0;i < n-1;i++){
            from[i] = ni()-1;
            to[i] = ni()-1;
        }
        int[][] g = packU(n, from, to);
        int[][] pars = parents3(g, 0);
        int[] par = pars[0], ord = pars[1], dep = pars[2];
        // clus[v]: chain id of v; cluspath[c]: nodes of chain c, topmost first;
        // clusiind[v]: index of v inside its chain.
        int[] clus = decomposeToHeavyLight(g, par, ord);
        int[][] cluspath = clusPaths(clus, ord);
        int[] clusiind = clusIInd(cluspath, n);
        // One segment tree per chain, indexed by position within the chain.
        SegmentTreeNodePlus[] sts = new SegmentTreeNodePlus[cluspath.length];
        for(int i = 0;i < cluspath.length;i++){
            int[][] lco = new int[cluspath[i].length][];
            for(int j = 0;j < cluspath[i].length;j++){
                lco[j] = co[cluspath[i][j]];
            }
            sts[i] = new SegmentTreeNodePlus(lco);
        }
        
        // Binary-lifting ancestor table used by lca2.
        int[][] spar = logstepParents(par);
        int Q = ni();
        for(int z = 0;z < Q;z++){
            int t = ni();
            if(t == 1){
                // Type 1: assign a*x + b to every node on the path u..v.
                int u = ni()-1, v = ni()-1, a = ni(), b = ni();
                int lca = lca2(u, v, spar, dep);
                int[][] pr = query2(u, lca, v, clus, cluspath, clusiind, par);
                for(int[] e : pr){
                    // e = {chain, start index, end index}; direction is irrelevant for assignment.
                    sts[e[0]].update(Math.min(e[1], e[2]), Math.max(e[1], e[2])+1, a, b);
                }
            }else{
                // Type 2: apply the path functions to x, starting at u and ending at v.
                int u = ni()-1, v = ni()-1;
                long x = ni();
                int lca = lca2(u, v, spar, dep);
                int[][] pr = query2(u, lca, v, clus, cluspath, clusiind, par);
                for(int[] e : pr){
                    if(e[1] <= e[2]){
                        // Walking down the chain (increasing index): compose left to right.
                        x = sts[e[0]].apply(e[1], e[2]+1, x, false);
                    }else{
                        // Walking up the chain (decreasing index): compose right to left.
                        x = sts[e[0]].apply(e[2], e[1]+1, x, true);
                    }
                }
                out.println(x);
            }
        }
    }
    
    /**
     * Returns the lowest common ancestor of a and b.
     * The deeper node is first lifted to the other's depth. Then both are lifted
     * together in decreasing power-of-two steps while their ancestors differ.
     * The parent of the final pair is the answer.
     *
     * @param spar  table from logstepParents: spar[k][v] is the 2^k-th ancestor of v (or -1)
     * @param depth depth of every node
     */
    public static int lca2(int a, int b, int[][] spar, int[] depth) {
        if (depth[a] < depth[b]) {
            b = ancestor(b, depth[b] - depth[a], spar);
        } else if (depth[a] > depth[b]) {
            a = ancestor(a, depth[a] - depth[b], spar);
        }

        if (a == b)
            return a;
        int sa = a, sb = b;
        // [low, high] is a binary-search window on how far the ancestors still differ.
        for (int low = 0, high = depth[a], t = Integer.highestOneBit(high), k = Integer
                .numberOfTrailingZeros(t); t > 0; t >>>= 1, k--) {
            if ((low ^ high) >= t) {
                if (spar[k][sa] != spar[k][sb]) {
                    low |= t;
                    sa = spar[k][sa];
                    sb = spar[k][sb];
                } else {
                    high = low | t - 1;
                }
            }
        }
        return spar[0][sa];
    }

    /**
     * Returns the m-th ancestor of a, or -1 once the walk passes the root.
     */
    protected static int ancestor(int a, int m, int[][] spar) {
        for (int i = 0; m > 0 && a != -1; m >>>= 1, i++) {
            if ((m & 1) == 1)
                a = spar[i][a];
        }
        return a;
    }

    /**
     * Builds the binary-lifting table: pars[j][i] is the 2^j-th ancestor of i, or -1.
     * Row 0 is the given parent array itself (not a copy).
     */
    public static int[][] logstepParents(int[] par) {
        int n = par.length;
        int m = Integer.numberOfTrailingZeros(Integer.highestOneBit(n - 1)) + 1;
        int[][] pars = new int[m][n];
        pars[0] = par;
        for (int j = 1; j < m; j++) {
            for (int i = 0; i < n; i++) {
                pars[j][i] = pars[j - 1][i] == -1 ? -1
                        : pars[j - 1][pars[j - 1][i]];
            }
        }
        return pars;
    }
    
    /**
     * Segment tree over one heavy-light chain supporting range assignment of a linear
     * function and ordered application of a range of functions to a value.
     * Each node stores the composition of the functions in its range:
     * - co is the common multiplier;
     * - lc is the additive constant when composing left to right;
     * - rc is the additive constant when composing right to left.
     * cover[i] is a pending assignment {a, b} for all positions under internal node i.
     */
    public static class SegmentTreeNodePlus {
        // N: element count; M: node-array size; H: first leaf index (element i is node H+i).
        public int M, H, N;
        public Node[] node;
        public int[][] cover;
        
        /** A composed linear function over a range: x -> co*x + (lc or rc). */
        private static class Node
        {
            long co;
            long lc;
            long rc;
            
            // Identity function.
            public Node() {
                co = 1;
                lc = rc = 0;
            }

            public Node(long co, long lc, long rc) {
                this.co = co;
                this.lc = lc;
                this.rc = rc;
            }
            
            /** Applies the range composed left-to-right (dir false) or right-to-left (dir true). */
            public long apply(long x, boolean dir)
            {
                if(!dir){
                    return (co * x + lc) % mod;
                }else{
                    return (co * x + rc) % mod;
                }
            }
        }
        
        /** Builds the tree from co[i] = {a, b} for each chain position i. */
        public SegmentTreeNodePlus(int[][] co)
        {
            N = co.length;
            M = Integer.highestOneBit(Math.max(N-1, 1))<<2;
            H = M>>>1;
            
            node = new Node[M];
            cover = new int[H][];
            for(int i = 0;i < N;i++){
                node[H+i] = new Node(co[i][0], co[i][1], co[i][1]);
            }
            for(int i = H-1;i >= 1;i--)propagate(i);
        }
        
        // Recomputes internal node cur; the last argument is the number of leaf slots under cur.
        private void propagate(int cur)
        {
            node[cur] = prop2(node[2*cur], node[2*cur+1], cover[cur], node[cur], H/Integer.highestOneBit(cur));
        }
        
        static final int mod = 1000000007;
        
        /**
         * Combines children L and R into C (allocated if null), or, when cover is set,
         * sets C to cover's function composed with itself len times.
         * Returns null if both children are absent.
         */
        private Node prop2(Node L, Node R, int[] cover, Node C, int len)
        {
            if(L != null && R != null){
                if(C == null)C = new Node();
                if(cover == null){
                    // co is shared; lc applies L then R, rc applies R then L.
                    C.co = L.co * R.co % mod;
                    C.lc = (R.co * L.lc + R.lc) % mod;
                    C.rc = (L.co * R.rc + L.rc) % mod;
                }else{
                    // Repeated squaring of f(x) = co*x + c; len is a power of two here.
                    long co = cover[0], c = cover[1];
                    for(int x = len;x > 1;x >>>= 1){
                        long nco = co * co % mod;
                        long nc = (co * c + c) % mod;
                        co = nco;
                        c = nc;
                    }
                    C.co = co;
                    C.lc = C.rc = c;
                }
                return C;
            }else if(L != null){
                return prop1(L, cover, C, len);
            }else if(R != null){
                return prop1(R, cover, C, len);
            }else{
                return null;
            }
        }
        
        // Single-child variant of prop2: copies the child, or applies the cover as in prop2.
        private Node prop1(Node L, int[] cover, Node C, int len)
        {
            if(C == null)C = new Node();
            if(cover == null){
                C.co = L.co;
                C.lc = L.lc;
                C.rc = L.rc;
            }else{
                long co = cover[0], c = cover[1];
                for(int x = len;x > 1;x >>>= 1){
                    long nco = co * co % mod;
                    long nc = (co * c + c) % mod;
                    co = nco;
                    c = nc;
                }
                C.co = co;
                C.lc = C.rc = c;
            }
            return C;
        }
        
        // {a, b} of the current update call; a fresh array per call, shared by every node it covers.
        int[] temp = null;
        
        /** Assigns a*x + b to every position in [l, r). */
        public void update(int l, int r, int a, int b) { 
            temp = new int[]{a, b};
            if(l < r)update(l, r, a, b, 0, H, 1); }
        
        // Recursive assignment; node cur covers positions [cl, cr).
        protected void update(int l, int r, int a, int b, int cl, int cr, int cur)
        {
            if(cur >= H){
                // Leaf: overwrite its function directly.
                node[cur].co = a;
                node[cur].lc = node[cur].rc = b;
            }else if(l <= cl && cr <= r){
                // Fully covered: record a pending assignment and recompute this node.
                cover[cur] = temp;
                propagate(cur);
            }else{
                int mid = cl+cr>>>1;
                boolean bp = false;
                // Push an existing pending assignment down one level before descending.
                if(cover[cur] != null){
                    if(2*cur < H){
                        cover[2*cur] = cover[cur];
                        cover[2*cur+1] = cover[cur];
                        cover[cur] = null;
                        bp = true;
                    }else{
                        node[2*cur].co = cover[cur][0];
                        node[2*cur].lc = node[2*cur].rc = cover[cur][1];
                        node[2*cur+1].co = cover[cur][0];
                        node[2*cur+1].lc = node[2*cur+1].rc = cover[cur][1];
                        cover[cur] = null;
                    }
                }
                if(cl < r && l < mid){
                    update(l, r, a, b, cl, mid, 2*cur);
                }else if(bp){
                    // Child not visited but just received a cover: refresh it.
                    propagate(2*cur);
                }
                
                if(mid < r && l < cr){
                    update(l, r, a, b, mid, cr, 2*cur+1);
                }else if(bp){
                    propagate(2*cur+1);
                }
                propagate(cur);
            }
        }
        
        /**
         * Applies the functions at positions [l, r) to x and returns the result.
         * Positions are taken in increasing order when dir is false, decreasing when true.
         */
        public long apply(int l, int r, long x, boolean dir) {
            return apply(l, r, x, dir, 0, H, 1);
        }
        
        // Recursive apply; node cur covers positions [cl, cr).
        protected long apply(int l, int r, long x, boolean dir, int cl, int cr, int cur)
        {
            if(l <= cl && cr <= r){
                return node[cur].apply(x, dir);
            }else{
                int mid = cl+cr>>>1;
                if(cover[cur] != null){
                    // All overlapping positions share one function: apply it
                    // (overlap length) times by binary exponentiation.
                    long co = cover[cur][0], c = cover[cur][1];
                    for(int h = Math.min(r, cr) - Math.max(l, cl);h > 0;h >>>= 1){
                        if((h&1) == 1){
                            x = (co * x + c) % mod;
                        }
                        long nco = co * co % mod;
                        long nc = (co * c + c) % mod;
                        co = nco;
                        c = nc;
                    }
                    return x;
                }
                if(!dir){
                    if(cl < r && l < mid){
                        x = apply(l, r, x, dir, cl, mid, 2*cur);
                    }
                    if(mid < r && l < cr){
                        x = apply(l, r, x, dir, mid, cr, 2*cur+1);
                    }
                }else{
                    if(mid < r && l < cr){
                        x = apply(l, r, x, dir, mid, cr, 2*cur+1);
                    }
                    if(cl < r && l < mid){
                        x = apply(l, r, x, dir, cl, mid, 2*cur);
                    }
                }
                return x;
            }
        }
    }
    
    /**
     * Breadth-first search from root.
     * Returns {parent (-1 for the root), BFS order, depth}.
     * NOTE(review): depth[0] (not depth[root]) is initialised; this only matters if root != 0.
     */
    public static int[][] parents3(int[][] g, int root) {
        int n = g.length;
        int[] par = new int[n];
        Arrays.fill(par, -1);

        int[] depth = new int[n];
        depth[0] = 0;

        int[] q = new int[n];
        q[0] = root;
        for (int p = 0, r = 1; p < r; p++) {
            int cur = q[p];
            for (int nex : g[cur]) {
                if (par[cur] != nex) {
                    q[r++] = nex;
                    par[nex] = cur;
                    depth[nex] = depth[cur] + 1;
                }
            }
        }
        return new int[][] { par, q, depth };
    }
    
    /**
     * Assigns every node a chain id.
     * A node extends its parent's chain when its subtree size is at least half the
     * parent's; otherwise the parent's first listed child extends the chain.
     */
    public static int[] decomposeToHeavyLight(int[][] g, int[] par, int[] ord)
    {
        int n = g.length;
        int[] size = new int[n];
        Arrays.fill(size, 1);
        // Subtree sizes, accumulated in reverse BFS order.
        for(int i = n-1;i > 0;i--)size[par[ord[i]]] += size[ord[i]];
        
        int[] clus = new int[n];
        Arrays.fill(clus, -1);
        int p = 0;
        outer:
        for(int i = 0;i < n;i++){
            int u = ord[i];
            if(clus[u] == -1)clus[u] = p++;
            for(int v : g[u]){
                if(par[u] != v && size[v] >= size[u]/2){
                    clus[v] = clus[u];
                    continue outer;
                }
            }
            for(int v : g[u]){
                if(par[u] != v){
                    clus[v] = clus[u];
                    break;
                }
            }
        }
        return clus;
    }
    
    /**
     * Groups nodes by chain id.
     * Within each chain, nodes keep the order they have in ord (BFS order), so the
     * topmost node comes first.
     */
    public static int[][] clusPaths(int[] clus, int[] ord)
    {
        int n = clus.length;
        int[] rp = new int[n];
        int sup = 0;
        for(int i = 0;i < n;i++){
            rp[clus[i]]++;
            sup = Math.max(sup, clus[i]);
        }
        sup++;
        
        int[][] row = new int[sup][];
        for(int i = 0;i < sup;i++)row[i] = new int[rp[i]];
        
        for(int i = n-1;i >= 0;i--){
            row[clus[ord[i]]][--rp[clus[ord[i]]]] = ord[i];
        }
        return row;
    }
    
    /** Returns, for each node, its index within its chain's path. */
    public static int[] clusIInd(int[][] clusPath, int n)
    {
        int[] iind = new int[n];
        for(int[] path : clusPath){
            for(int i = 0;i < path.length;i++){
                iind[path[i]] = i;
            }
        }
        return iind;
    }
    
    /**
     * Splits the path x -> anc -> y (anc being their LCA) into chain segments
     * {chain, start index, end index}, listed in walk order: from x up to anc
     * (inclusive), then from just below anc down to y.
     * NOTE(review): the fixed stack assumes at most 60 segments per path.
     */
    public static int[][] query2(int x, int anc, int y, int[] clus, int[][] cluspath, int[] clusiind, int[] par)
    {
        int[][] stack = new int[60][];
        int sp = 0;
        
        // x side: climb chain by chain, each segment running upward to index 0.
        int cx = clus[x]; 
        int indx = clusiind[x]; 
        while(cx != clus[anc]){
            stack[sp++] = new int[]{cx, indx, 0};
            int con = par[cluspath[cx][0]];
            indx = clusiind[con];
            cx = clus[con];
        }
        stack[sp++] = new int[]{cx, indx, clusiind[anc]};
        
        // y side: collected bottom-up, reversed afterwards to run top-down.
        int top = sp;
        int cy = clus[y];
        int indy = clusiind[y]; 
        while(cy != clus[anc]){
            stack[sp++] = new int[]{cy, 0, indy};
            int con = par[cluspath[cy][0]];
            indy = clusiind[con];
            cy = clus[con];
        }
        // anc itself was already emitted on the x side.
        if(clusiind[anc] < indy){
            stack[sp++] = new int[]{cy, clusiind[anc]+1, indy};
        }
        for(int p = top, q = sp-1;p < q;p++,q--){
            int[] dum = stack[p]; stack[p] = stack[q]; stack[q] = dum;
        }
        return Arrays.copyOf(stack, sp);
    }
    
    /** Builds undirected adjacency lists from the edge arrays. */
    static int[][] packU(int n, int[] from, int[] to) {
        int[][] g = new int[n][];
        int[] p = new int[n];
        for (int f : from)
            p[f]++;
        for (int t : to)
            p[t]++;
        for (int i = 0; i < n; i++)
            g[i] = new int[p[i]];
        for (int i = 0; i < from.length; i++) {
            g[from[i]][--p[from[i]]] = to[i];
            g[to[i]][--p[to[i]]] = from[i];
        }
        return g;
    }
    
    /** Entry point: reads from INPUT if non-empty, otherwise from System.in. */
    public static void main(String[] args) throws Exception
    {
        is = INPUT.isEmpty() ? System.in : new ByteArrayInputStream(INPUT.getBytes());
        out = new PrintWriter(System.out);
        
        solve();
        out.flush();
    }
    
    // Fast-input helpers below. solve() only uses ni(); the others are unused here.
    
    /** Returns true when only whitespace remains in the input. */
    private static boolean eof()
    {
        if(lenbuf == -1)return true;
        int lptr = ptrbuf;
        while(lptr < lenbuf)if(!isSpaceChar(inbuf[lptr++]))return false;
        
        try {
            is.mark(1000);
            while(true){
                int b = is.read();
                if(b == -1){
                    is.reset();
                    return true;
                }else if(!isSpaceChar(b)){
                    is.reset();
                    return false;
                }
            }
        } catch (IOException e) {
            return true;
        }
    }
    
    private static byte[] inbuf = new byte[1024];
    static int lenbuf = 0, ptrbuf = 0;
    
    /** Returns the next input byte, or -1 at end of input. */
    private static int readByte()
    {
        if(lenbuf == -1)throw new InputMismatchException();
        if(ptrbuf >= lenbuf){
            ptrbuf = 0;
            try { lenbuf = is.read(inbuf); } catch (IOException e) { throw new InputMismatchException(); }
            if(lenbuf <= 0)return -1;
        }
        return inbuf[ptrbuf++];
    }
    
    // Anything outside printable ASCII '!'..'~' counts as a separator.
    private static boolean isSpaceChar(int c) { return !(c >= 33 && c <= 126); }
    private static int skip() { int b; while((b = readByte()) != -1 && isSpaceChar(b)); return b; }
    
    private static double nd() { return Double.parseDouble(ns()); }
    private static char nc() { return (char)skip(); }
    
    /** Reads the next whitespace-delimited token. */
    private static String ns()
    {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        while(!(isSpaceChar(b))){ // when nextLine, (isSpaceChar(b) && b != ' ')
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }
    
    /** Reads up to n characters of the next token. */
    private static char[] ns(int n)
    {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while(p < n && !(isSpaceChar(b))){
            buf[p++] = (char)b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }
    
    private static char[][] nm(int n, int m)
    {
        char[][] map = new char[n][];
        for(int i = 0;i < n;i++)map[i] = ns(m);
        return map;
    }
    
    private static int[] na(int n)
    {
        int[] a = new int[n];
        for(int i = 0;i < n;i++)a[i] = ni();
        return a;
    }
    
    /** Parses the next (optionally negative) int, skipping other characters first. */
    private static int ni()
    {
        int num = 0, b;
        boolean minus = false;
        while((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'));
        if(b == '-'){
            minus = true;
            b = readByte();
        }
        
        while(true){
            if(b >= '0' && b <= '9'){
                num = num * 10 + (b - '0');
            }else{
                return minus ? -num : num;
            }
            b = readByte();
        }
    }
    
    /** Long counterpart of ni(). */
    private static long nl()
    {
        long num = 0;
        int b;
        boolean minus = false;
        while((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'));
        if(b == '-'){
            minus = true;
            b = readByte();
        }
        
        while(true){
            if(b >= '0' && b <= '9'){
                num = num * 10 + (b - '0');
            }else{
                return minus ? -num : num;
            }
            b = readByte();
        }
    }
}


Problem solution in C++ programming.

#define _CRT_SECURE_NO_WARNINGS
#include <string>
#include <vector>
#include <algorithm>
#include <numeric>
#include <set>
#include <map>
#include <queue>
#include <iostream>
#include <sstream>
#include <cstdio>
#include <cmath>
#include <ctime>
#include <cstring>
#include <cctype>
#include <cassert>
#include <limits>
#include <functional>
// Loop helpers: rep iterates [0, n), rer iterates [l, u] inclusive, reu iterates [l, u).
#define rep(i,n) for(int (i)=0;(i)<(int)(n);++(i))
#define rer(i,l,u) for(int (i)=(int)(l);(i)<=(int)(u);++(i))
#define reu(i,l,u) for(int (i)=(int)(l);(i)<(int)(u);++(i))
// aut(r, v) declares r initialised from v: `auto` on MSVC / C++11+, GCC `typeof` otherwise.
#if defined(_MSC_VER) || __cplusplus > 199711L
#define aut(r,v) auto r = (v)
#else
#define aut(r,v) typeof(v) r = (v)
#endif
// each(it, o): iterate container o with iterator it.
#define each(it,o) for(aut(it, (o).begin()); it != (o).end(); ++ it)
#define all(o) (o).begin(), (o).end()
#define pb(x) push_back(x)
#define mp(x,y) make_pair((x),(y))
#define mset(m,v) memset(m,v,sizeof(m))
#define INF 0x3f3f3f3f
#define INFL 0x3f3f3f3f3f3f3f3fLL
using namespace std;
typedef vector<int> vi; typedef pair<int,int> pii; typedef vector<pair<int,int> > vpii;
typedef long long ll; typedef vector<long long> vl; typedef pair<long long,long long> pll; typedef vector<pair<long long,long long> > vpll;
typedef vector<string> vs; typedef long double ld;
// amin / amax: assign y to x if y is smaller / larger.
template<typename T, typename U> inline void amin(T &x, U y) { if(y < x) x = y; }
template<typename T, typename U> inline void amax(T &x, U y) { if(x < y) x = y; }

// Integer modulo MOD, kept as a canonical residue in [0, MOD).
template<int MOD>
struct ModInt {
    static const int Mod = MOD;
    unsigned x;

    ModInt(): x(0) { }
    // Reduce a signed value into [0, MOD).
    ModInt(signed sig) {
        int r = sig % MOD;
        x = r < 0 ? r + MOD : r;
    }
    ModInt(signed long long sig) {
        int r = sig % MOD;
        x = r < 0 ? r + MOD : r;
    }
    int get() const { return (int)x; }

    ModInt &operator+=(ModInt that) {
        x += that.x;
        if(x >= MOD) x -= MOD;
        return *this;
    }
    ModInt &operator-=(ModInt that) {
        x += MOD - that.x;
        if(x >= MOD) x -= MOD;
        return *this;
    }
    ModInt &operator*=(ModInt that) {
        const unsigned long long product = (unsigned long long)x * that.x;
        x = (unsigned)(product % MOD);
        return *this;
    }

    ModInt operator+(ModInt that) const { ModInt r(*this); r += that; return r; }
    ModInt operator-(ModInt that) const { ModInt r(*this); r -= that; return r; }
    ModInt operator*(ModInt that) const { ModInt r(*this); r *= that; return r; }
};
typedef ModInt<1000000007> mint;

//y = ax + b
struct LinearExpr {
    mint a, b;
    LinearExpr(): a(1), b(0) { }
    LinearExpr(mint a_, mint b_): a(a_), b(b_) { }
    LinearExpr(const LinearExpr &val, int) { a = val.a, b = val.b; }
    LinearExpr &operator+=(const LinearExpr &that) {
        b = b * that.a + that.b;
        a = a * that.a;
        return *this;
    }
    LinearExpr operator+(const LinearExpr &that) const {
        return LinearExpr(*this) += that;
    }
    LinearExpr operator*(int k) const {
        LinearExpr a = *this, r;
        while(k) {
            if(k & 1) r += a;
            a += a;
            k >>= 1;
        }
        return r;
    }
    mint evalute(mint x) const { return a * x + b; }
};

typedef LinearExpr Val;

struct Sum {
    LinearExpr forward, backward;
    Sum(): forward(), backward() { }
    Sum(const Val &val, int): forward(val), backward(val) { }
    Sum &operator+=(const Sum &that) {
        forward += that.forward;
        backward = that.backward + backward;
        return *this;
    }
    Sum operator+(const Sum &that) const { return Sum(*this) += that; }
};

struct Laziness {
    bool fill;
    LinearExpr expr;
    Laziness(): fill(false) { }
    Laziness(LinearExpr expr_): fill(true), expr(expr_) { }
    Laziness &operator+=(const Laziness &that) {
        if(that.fill)
            *this = that;
        return *this;
    }
    void addToVal(Val &val, int) const {
        if(fill)
            val = expr;
    }
    void addToSum(Sum &sum, int left, int right) const {
        if(fill) {
            LinearExpr multiplicated = expr * (right - left);
            sum.forward = sum.backward = multiplicated;
        }
    }
};

// Lazy segment tree over Val leaves with range assignment (Laziness) and ordered
// range aggregation (Sum). Internal nodes are 1..n-1, and leaf p is addressed as
// node n+p and stored in leafs[p]. n is the leaf count rounded up to a power of two.
struct SegmentTree {
    vector<Val> leafs;
    vector<Sum> nodes;
    vector<Laziness> laziness;
    // [leftpos[i], rightpos[i]) is the leaf range under internal node i.
    vector<int> leftpos, rightpos;
    // n2: first internal node whose children are leaves.
    int n, n2;
    void init(int n_, const Val &v = Val()) { init(vector<Val>(n_, v)); }
    // Builds the tree over u; padding leaves are Val() (the identity function).
    void init(const vector<Val> &u) {
        n = 1; while(n < (int)u.size()) n *= 2;
        n2 = (n - 1) / 2 + 1;
        leafs = u; leafs.resize(n, Val());
        nodes.resize(n);
        for(int i = n-1; i >= n2; -- i)
            nodes[i] = Sum(leafs[i*2-n], i*2-n) + Sum(leafs[i*2+1-n], i*2+1-n);
        for(int i = n2-1; i > 0; -- i)
            nodes[i] = nodes[i*2] + nodes[i*2+1];
        laziness.assign(n, Laziness());

        leftpos.resize(n); rightpos.resize(n);
        for(int i = n-1; i >= n2; -- i) {
            leftpos[i] = i*2-n;
            rightpos[i] = (i*2+1-n) + 1;
        }
        for(int i = n2-1; i > 0; -- i) {
            leftpos[i] = leftpos[i*2];
            rightpos[i] = rightpos[i*2+1];
        }
    }
    // Returns leaf i after pushing pending assignments down to it.
    Val get(int i) {
        int indices[128];
        int k = getIndices(indices, i, i+1);
        propagateRange(indices, k);
        return leafs[i];
    }
    // Aggregates [i, j) taking nodes alternately from both ends. The combination
    // order is not left to right, so this is only correct for order-independent sums.
    Sum getRangeCommutative(int i, int j) {
        int indices[128];
        int k = getIndices(indices, i, j);
        propagateRange(indices, k);
        Sum res = Sum();
        for(int l = i + n, r = j + n; l < r; l >>= 1, r >>= 1) {
            if(l & 1) res += sum(l ++);
            if(r & 1) res += sum(-- r);
        }
        return res;
    }
    // Aggregates [i, j) strictly left to right.
    Sum getRange(int i, int j) {
        int indices[128];
        int k = getIndices(indices, i, j);
        propagateRange(indices, k);
        Sum res = Sum();
        // Aligned blocks [i, i + lowbit(i)) growing from the left end.
        for(; i && i + (i&-i) <= j; i += i&-i)
            res += sum((n+i) / (i&-i));
        // Remaining blocks are collected from the right end, then appended in
        // reverse so the overall order stays left to right.
        for(k = 0; i < j; j -= j&-j)
            indices[k ++] = (n+j) / (j&-j) - 1;
        while(-- k >= 0) res += sum(indices[k]);
        return res;
    }
    // Replaces leaf i and recomputes its ancestors.
    void set(int i, const Val &x) {
        int indices[128];
        int k = getIndices(indices, i, i+1);
        propagateRange(indices, k);
        leafs[i] = x;
        mergeRange(indices, k);
    }
    // Applies the pending operation x to every leaf in [i, j).
    void addToRange(int i, int j, const Laziness &x) {
        if(i >= j) return;
        int indices[128];
        int k = getIndices(indices, i, j);
        propagateRange(indices, k);
        int l = i + n, r = j + n;
        // Boundary leaves are updated directly; inner blocks get lazy tags.
        if(l & 1) { int p = (l ++) - n; x.addToVal(leafs[p], p); }
        if(r & 1) { int p = (-- r) - n; x.addToVal(leafs[p], p); }
        for(l >>= 1, r >>= 1; l < r; l >>= 1, r >>= 1) {
            if(l & 1) laziness[l ++] += x;
            if(r & 1) laziness[-- r] += x;
        }
        mergeRange(indices, k);
    }
private:
    // Collects the internal ancestors of the boundary leaves of [i, j), bottom-up.
    // Returns how many were written.
    int getIndices(int indices[], int i, int j) const {
        int k = 0, l, r;
        if(i >= j) return 0;
        for(l = (n + i) >> 1, r = (n + j - 1) >> 1; l != r; l >>= 1, r >>= 1) {
            indices[k ++] = l;
            indices[k ++] = r;
        }
        for(; l; l >>= 1) indices[k ++] = l;
        return k;
    }
    // Pushes pending operations down along the collected nodes, top-down.
    void propagateRange(int indices[], int k) {
        for(int i = k - 1; i >= 0; -- i)
            propagate(indices[i]);
    }
    // Recomputes the collected nodes from their children, bottom-up.
    void mergeRange(int indices[], int k) {
        for(int i = 0; i < k; ++ i)
            merge(indices[i]);
    }
    // Applies node i's pending operation to its aggregate and hands it to its children.
    inline void propagate(int i) {
        if(i >= n) return;
        laziness[i].addToSum(nodes[i], leftpos[i], rightpos[i]);
        if(i * 2 < n) {
            laziness[i * 2] += laziness[i];
            laziness[i * 2 + 1] += laziness[i];
        }else {
            laziness[i].addToVal(leafs[i * 2 - n], i * 2 - n);
            laziness[i].addToVal(leafs[i * 2 + 1 - n], i * 2 + 1 - n);
        }
        laziness[i] = Laziness();
    }
    inline void merge(int i) {
        if(i >= n) return;
        nodes[i] = sum(i * 2) + sum(i * 2 + 1);
    }
    // Up-to-date aggregate of node i: an internal node's Sum, or a one-leaf Sum.
    inline Sum sum(int i) {
        propagate(i);
        return i < n ? nodes[i] : Sum(leafs[i - n], i - n);
    }
};

// Heavy-path decomposition of a rooted tree. Each vertex gets a color (chain id)
// and a position along its chain (0 = chain head). Chains are stored contiguously
// in sortednodes, head first.
struct CentroidPathDecomposition {
    vector<int> colors, positions;    //Vertex -> Color, Vertex -> Offset
    vector<int> lengths, parents, branches;    //Color -> Int, Color -> Color, Color -> Offset
    vector<int> parentnodes, depths;    //Vertex -> Vertex, Vertex -> Int
    // sortednodes is a DFS order in which chain c occupies
    // [offsets[c], offsets[c] + lengths[c]), and [lefts[v], rights[v]) is the
    // range of v's subtree.
    vector<int> sortednodes, offsets;    //Index -> Vertex, Color -> Index
    vector<int> lefts, rights;    //Vertex -> Index

    struct BuildDFSState {
        int i, len, parent;
        BuildDFSState() { }
        BuildDFSState(int i_, int l, int p): i(i_), len(l), parent(p) { }
    };

    // Builds the decomposition of tree g (adjacency lists) rooted at root, using an
    // explicit stack. A state's `parent` field is a tag:
    // - a color, or -1 for the root: vertex i starts a new chain hanging off
    //   position `len` of that parent chain;
    // - -2: vertex i continues the current chain at position `len`;
    // - -3: i's whole subtree has been emitted, so record rights[i].
    // The heaviest child is pushed last, so it is popped next and the chain stays
    // contiguous in sortednodes.
    void build(const vector<vi> &g, int root) {
        int n = g.size();

        colors.assign(n, -1); positions.assign(n, -1);
        lengths.clear(); parents.clear(); branches.clear();
        parentnodes.assign(n, -1); depths.assign(n, -1);

        sortednodes.clear(); offsets.clear();
        lefts.assign(n, -1); rights.assign(n, -1);

        vector<int> subtreesizes;
        measure(g, root, subtreesizes);

        typedef BuildDFSState State;
        depths[root] = 0;
        vector<State> s;
        s.push_back(State(root, 0, -1));
        while(!s.empty()) {
            State t = s.back(); s.pop_back();
            int i = t.i, len = t.len;
            int index = sortednodes.size();
            int color = lengths.size();

            if(t.parent == -3) {
                rights[i] = index;
                continue;
            }

            if(t.parent != -2) {
                assert(parents.size() == color);
                parents.push_back(t.parent);
                branches.push_back(len);
                offsets.push_back(index);
                len = 0;
            }
            colors[i] = color;
            positions[i] = len;

            lefts[i] = index;
            sortednodes.push_back(i);

            int maxsize = -1, maxj = -1;
            each(j, g[i]) if(colors[*j] == -1) {
                if(maxsize < subtreesizes[*j]) {
                    maxsize = subtreesizes[*j];
                    maxj = *j;
                }
                parentnodes[*j] = i;
                depths[*j] = depths[i] + 1;
            }
            s.push_back(State(i, -1, -3));
            if(maxj == -1) {
                lengths.push_back(len + 1);
            }else {
                each(j, g[i]) if(colors[*j] == -1 && *j != maxj)
                    s.push_back(State(*j, len, color));
                s.push_back(State(maxj, len + 1, -2));
            }
        }
    }
    
    // Returns the color and chain position of vertex v.
    void get(int v, int &c, int &p) const {
        c = colors[v]; p = positions[v];
    }
    // Moves (c, p) to the vertex that chain c hangs off. Returns false when c was
    // the root chain (c becomes -1).
    bool go_up(int &c, int &p) const {
        p = branches[c]; c = parents[c];
        return c != -1;
    }

    inline const int *nodesBegin(int c) const { return &sortednodes[0] + offsets[c]; }
    // One past the last vertex of chain c. Uses lengths[c] rather than offsets[c+1],
    // which does not exist for the last chain.
    inline const int *nodesEnd(int c) const { return &sortednodes[0] + offsets[c] + lengths[c]; }

private:
    // Iterative post-order computation of subtree sizes rooted at root.
    // A value of -2 marks a vertex whose children are still being processed.
    void measure(const vector<vi> &g, int root, vector<int> &out_subtreesizes) const {
        out_subtreesizes.assign(g.size(), -1);
        vector<int> s;
        s.push_back(root);
        while(!s.empty()) {
            int i = s.back(); s.pop_back();
            if(out_subtreesizes[i] == -2) {
                int total = 1;
                each(j, g[i]) if(out_subtreesizes[*j] != -2)
                    total += out_subtreesizes[*j];
                out_subtreesizes[i] = total;
            }else {
                s.push_back(i);
                each(j, g[i]) if(out_subtreesizes[*j] == -1)
                    s.push_back(*j);
                out_subtreesizes[i] = -2;
            }
        }
    }
};

// Returns the lowest common ancestor of x and y.
// While x and y are on different chains, the one whose chain head is deeper moves
// up to its parent chain. Once they share a chain, the shallower position wins.
int lowest_common_ancestor(const CentroidPathDecomposition &cpd, int x, int y) {
    int chainX, posX, chainY, posY;
    cpd.get(x, chainX, posX);
    cpd.get(y, chainY, posY);
    while(chainX != chainY) {
        const int headDepthX = cpd.depths[*cpd.nodesBegin(chainX)];
        const int headDepthY = cpd.depths[*cpd.nodesBegin(chainY)];
        if(headDepthY > headDepthX) cpd.go_up(chainY, posY);
        else cpd.go_up(chainX, posX);
    }
    const int topPos = posX < posY ? posX : posY;
    return cpd.nodesBegin(chainX)[topPos];
}


// Reads the tree and builds a heavy-path decomposition. One segment tree covers all
// chains, laid out in cpd.sortednodes order. Queries:
//   1 u v a b : assign a*x + b to every node on the path u..v
//   2 u v x   : print the functions on the path from u to v applied in order to x
// Returns 1 on malformed input or an unknown query type.
int main() {
    int N;
    if(scanf("%d", &N) != 1) return 1;
    vector<Val> initval(N);
    rep(i, N) {
        int a, b;
        if(scanf("%d%d", &a, &b) != 2) return 1;
        initval[i] = LinearExpr(a, b);
    }
    vector<vi> g(N);
    rep(i, N-1) {
        int x, y;
        if(scanf("%d%d", &x, &y) != 2) return 1;
        -- x, -- y;
        g[x].push_back(y);
        g[y].push_back(x);
    }
    CentroidPathDecomposition cpd;
    cpd.build(g, 0);
    // Segment-tree position k holds the function of vertex cpd.sortednodes[k].
    vector<Val> permutatedInitval(N);
    rep(i, N)
        permutatedInitval[i] = initval[cpd.sortednodes[i]];
    SegmentTree segt;
    segt.init(permutatedInitval);
    vector<pii> path;
    int Q;
    if(scanf("%d", &Q) != 1) return 1;
    rep(ii, Q) {
        int ty;
        if(scanf("%d", &ty) != 1) return 1;
        if(ty == 1) {
            int u, v, a, b;
            if(scanf("%d%d%d%d", &u, &v, &a, &b) != 4) return 1;
            -- u, -- v;
            Laziness laziness(LinearExpr(a, b));
            int w = lowest_common_ancestor(cpd, u, v), wc, wp;
            cpd.get(w, wc, wp);
            rep(uv, 2) {
                int c, p;
                cpd.get(uv == 0 ? u : v, c, p);
                while(1) {
                    // On w's chain, the v side (uv == 1) starts below w, which
                    // the u side already covered.
                    int top = c == wc ? wp + uv : 0;
                    int o = cpd.offsets[c];
                    // Assign over chain positions [top, p] (inclusive).
                    segt.addToRange(o + top, o + p + 1, laziness);
                    if(c == wc) break;
                    cpd.go_up(c, p);
                }
            }
        }else if(ty == 2) {
            int u, v, x;
            if(scanf("%d%d%d", &u, &v, &x) != 3) return 1;
            -- u, -- v;
            LinearExpr expr;
            int w = lowest_common_ancestor(cpd, u, v), wc, wp;
            cpd.get(w, wc, wp);
            rep(uv, 2) {
                path.clear();
                int c, p;
                cpd.get(uv == 0 ? u : v, c, p);
                while(1) {
                    int top = c == wc ? wp + uv : 0;
                    int o = cpd.offsets[c];
                    // Chain segment covering positions [top, p] (inclusive).
                    path.push_back(mp(o + top, o + p));
                    if(c == wc) break;
                    cpd.go_up(c, p);
                }
                if(uv == 0) {
                    // u side: segments in climbing order, each composed bottom-up.
                    for(int i = 0; i < (int)path.size(); ++ i) {
                        int top = path[i].first, bottom = path[i].second;
                        expr += segt.getRange(top, bottom + 1).backward;
                    }
                }else {
                    // v side: segments in descending order, each composed top-down.
                    for(int i = (int)path.size() - 1; i >= 0; -- i) {
                        int top = path[i].first, bottom = path[i].second;
                        expr += segt.getRange(top, bottom + 1).forward;
                    }
                }
            }
            mint ans = expr.evalute(x);
            printf("%d\n", ans.get());
        }else return 1;
    }
    return 0;
}


Problem solution in C programming.

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
/* Adjacency-list entry: one directed half of an undirected tree edge. */
typedef struct _lnode{
  int x;                 /* neighbouring vertex (0-based) */
  int w;                 /* edge weight; main() always passes 1 and no visible code reads it */
  struct _lnode *next;   /* next entry in the same vertex's list */
} lnode;
/* Segment-tree node over one heavy-light chain.  The linear function
   f(x) = A*x + B of a vertex is stored as the 2x2 matrix [[A, B], [0, 1]];
   position 0 of a chain is its head (the vertex closest to the root). */
typedef struct _tree{
  long long sum[2][2];   /* ordered product M[l]*M[l+1]*...*M[r] over the node's range */
  long long sum1[2][2];  /* reversed product M[r]*...*M[l] over the same range */
  long long offset1;     /* pending lazy assignment of A, or -1 when none is pending */
  long long offset2;     /* pending lazy assignment of B, or -1 when none is pending */
} tree;
/* Modulus for every matrix entry. */
#define MOD 1000000007
void insert_edge(int x,int y,int w);
void dfs0(int u);
void dfs1(int u,int c);
void preprocess();
int lca(int a,int b);
void sum(int v,int tl,int tr,int l,int r,tree *t,long long ans[][2],int f);
void range_update(int v,int tl,int tr,int pos1,int pos2,long long o1,long long o2,tree *t);
void merge(long long a[][2],long long b[][2],long long ans[][2]);
void push(int v,int tl,int tr,tree *t);
void range_solve(int x,int y,int a,int b);
int min(int x,int y);
int max(int x,int y);
long long solve(int x,int y,int a);
void one(long long*a,int SIZE);
void mul(long long*a,long long*b,int SIZE);
void powm(long long*a,int n,long long*res,int SIZE);
/* Global state, fixed capacity of 100000 vertices:
   N              number of vertices
   cn             number of heavy-light chains created so far
   A, B           initial coefficients of each vertex's function A*x + B
   level          depth of each vertex (root 0 has depth 0)
   DP[i][v]       2^i-th ancestor of v (the root is its own parent)
   subtree_size   size of the subtree rooted at v
   special        heavy child of v (largest subtree), or -1 for a leaf
   node_chain     chain that v belongs to
   node_idx       position of v inside its chain (0 = chain head)
   chain_head     top-most vertex of each chain
   chain_len      number of vertices in each chain */
int N,cn,A[100000],B[100000],level[100000],DP[18][100000],subtree_size[100000],special[100000],node_chain[100000],node_idx[100000],chain_head[100000],chain_len[100000]={0};
/* Adjacency lists, one per vertex. */
lnode *table[100000]={0};
/* Per-chain segment tree: 4*chain_len nodes, root at index 1. */
tree *chain[100000];

/* Reads the tree, the per-vertex coefficients and the queries from stdin.
   Query "1 u v a b" assigns f(x) = a*x + b to every vertex on the path u..v.
   Any other query type reads "u v x" and prints the composition of the
   path's functions evaluated at x.
   Returns 1 on malformed or truncated input, on out-of-range vertex ids,
   or when N exceeds the fixed 100000-vertex capacity of the global tables
   (previously such input silently overflowed those arrays). */
int main(){
  int Q,x,y,a,b,i;
  if(scanf("%d",&N)!=1 || N<1 || N>100000)
    return 1;
  for(i=0;i<N;i++)
    if(scanf("%d%d",A+i,B+i)!=2)
      return 1;
  for(i=0;i<N-1;i++){
    if(scanf("%d%d",&x,&y)!=2 || x<1 || x>N || y<1 || y>N)
      return 1;
    insert_edge(x-1,y-1,1);
  }
  preprocess();
  if(scanf("%d",&Q)!=1)
    return 1;
  while(Q--){
    if(scanf("%d",&x)!=1)
      return 1;
    switch(x){
      case 1:
        if(scanf("%d%d%d%d",&x,&y,&a,&b)!=4 || x<1 || x>N || y<1 || y>N)
          return 1;
        range_solve(x-1,y-1,a,b);
        break;
      default:
        if(scanf("%d%d%d",&x,&y,&a)!=3 || x<1 || x>N || y<1 || y>N)
          return 1;
        printf("%lld\n",solve(x-1,y-1,a));
    }
  }
  return 0;
}
/* Adds the undirected edge x--y with weight w by pushing one entry onto the
   front of each endpoint's adjacency list (first x's list, then y's).
   Exits the program with status 1 if memory cannot be allocated, instead of
   dereferencing a NULL pointer. */
void insert_edge(int x,int y,int w){
  int ends[2][2]={{x,y},{y,x}};   /* {owner list, neighbour} for each direction */
  int k;
  for(k=0;k<2;k++){
    lnode *t=malloc(sizeof(lnode));
    if(t==NULL){
      fprintf(stderr,"insert_edge: out of memory\n");
      exit(1);
    }
    t->x=ends[k][1];
    t->w=w;
    t->next=table[ends[k][0]];
    table[ends[k][0]]=t;
  }
}
/* Roots the subtree at u. The caller must already have set DP[0][u] and
   level[u]. For every vertex v in that subtree it fills:
   DP[0][v] (parent), level[v], subtree_size[v], and special[v]. special[v]
   is the child with the largest subtree (the first such child in
   adjacency-list order on ties), or -1 for a leaf.
   Iterative on purpose: a recursive DFS nests once per tree level and can
   exhaust the call stack on path-shaped trees with ~10^5 vertices. */
void dfs0(int u){
  static int order[100000];   /* vertices in BFS order: every parent precedes its children */
  int head=0,tail=0,k;
  lnode *x;
  order[tail++]=u;
  /* Pass 1: discover the subtree, recording parent and depth. */
  while(head<tail){
    int v=order[head++];
    for(x=table[v];x;x=x->next)
      if(x->x!=DP[0][v]){
        DP[0][x->x]=v;
        level[x->x]=level[v]+1;
        order[tail++]=x->x;
      }
  }
  /* Pass 2: children come after their parent in order[], so a reverse sweep
     sees every child's final subtree size before its parent needs it. */
  for(k=tail-1;k>=0;k--){
    int v=order[k];
    subtree_size[v]=1;
    special[v]=-1;
    for(x=table[v];x;x=x->next)
      if(x->x!=DP[0][v]){
        subtree_size[v]+=subtree_size[x->x];
        if(special[v]==-1 || subtree_size[x->x]>subtree_size[special[v]])
          special[v]=x->x;
      }
  }
}
/* Heavy-light decomposition of the subtree rooted at u, which is placed in
   chain c. Each vertex receives node_chain[] and node_idx[], its position in
   the chain counted from the chain head.
   The heavy child special[v] continues its parent's chain. Every other child
   opens chain cn, which is recorded in chain_head[], and cn is incremented.
   An explicit stack visits vertices in exactly the order the recursive
   version did, so chain numbering is unchanged. Deep (path-like) trees no
   longer risk overflowing the call stack. */
void dfs1(int u,int c){
  static int stack_node[100000];     /* vertex at each level of the DFS */
  static lnode *stack_edge[100000];  /* next adjacency entry to examine at that level */
  int depth=1;
  node_chain[u]=c;
  node_idx[u]=chain_len[c]++;
  stack_node[0]=u;
  stack_edge[0]=table[u];
  while(depth>0){
    int v=stack_node[depth-1];
    lnode *e=stack_edge[depth-1];
    int w,wc;
    if(e==NULL){            /* all children of v handled: go back to its parent */
      depth--;
      continue;
    }
    stack_edge[depth-1]=e->next;
    w=e->x;
    if(w==DP[0][v])         /* skip the edge back to the parent */
      continue;
    if(w==special[v]){
      wc=node_chain[v];     /* heavy child extends v's chain */
    }else{
      chain_head[cn]=w;     /* light child starts a new chain */
      wc=cn++;
    }
    node_chain[w]=wc;
    node_idx[w]=chain_len[wc]++;
    stack_node[depth]=w;
    stack_edge[depth]=table[w];
    depth++;
  }
}
/* Builds every structure the queries need, rooted at vertex 0:
   parents and depths (dfs0), the binary-lifting table DP, the heavy-light
   chains (dfs1), and one lazily-assigned segment tree per chain seeded with
   each vertex's initial coefficients (A[i], B[i]).
   Exits with status 1 if a segment tree cannot be allocated, instead of
   writing through a NULL pointer. */
void preprocess(){
  int i,j;
  level[0]=0;
  DP[0][0]=0;            /* the root is its own parent, so lifting saturates at 0 */
  dfs0(0);
  for(i=1;i<18;i++)
    for(j=0;j<N;j++)
      DP[i][j] = DP[i-1][DP[i-1][j]];
  cn=1;
  chain_head[0]=0;
  dfs1(0,0);
  for(i=0;i<cn;i++){
    size_t nodes=4*(size_t)chain_len[i];   /* 1-indexed heap layout needs < 4*len slots */
    size_t k;
    chain[i]=calloc(nodes,sizeof(tree));     /* zeroes sum/sum1 like the old memset */
    if(chain[i]==NULL){
      fprintf(stderr,"preprocess: out of memory\n");
      exit(1);
    }
    for(k=0;k<nodes;k++)
      chain[i][k].offset1=chain[i][k].offset2=-1;   /* -1: no pending assignment */
  }
  for(i=0;i<N;i++)
    range_update(1,0,chain_len[node_chain[i]]-1,node_idx[i],node_idx[i],A[i],B[i],chain[node_chain[i]]);
}
/* Lowest common ancestor of vertices a and b, via the binary-lifting table
   DP[k][v] (the 2^k-th ancestor of v) and the depths in level[]. */
int lca(int a,int b){
  int k,diff;
  if(level[b]<level[a]){   /* make b the deeper of the two */
    int tmp=b;
    b=a;
    a=tmp;
  }
  /* Lift b by exactly diff levels, one power of two per set bit. */
  diff=level[b]-level[a];
  for(k=0;k<18;k++)
    if((diff>>k)&1)
      b=DP[k][b];
  if(a!=b){
    /* Climb both while their ancestors differ; they end just below the LCA. */
    for(k=17;k>=0;k--)
      if(DP[k][a]!=DP[k][b]){
        a=DP[k][a];
        b=DP[k][b];
      }
    a=DP[0][a];
  }
  return a;
}
/* Queries segment tree t (node v covers [tl, tr]) for the product of the
   matrices at positions [l, r] and stores it in ans.
   f == 0 yields M[l]*...*M[r]; f != 0 yields the reversed product M[r]*...*M[l].
   An empty range (l > r) yields the identity matrix. */
void sum(int v,int tl,int tr,int l,int r,tree *t,long long ans[][2],int f){
  long long lo_part[2][2],hi_part[2][2];
  int mid;
  push(v,tl,tr,t);
  if(l>r){
    one(&ans[0][0],2);
    return;
  }
  if(l==tl && r==tr){
    memcpy(ans,f?t[v].sum1:t[v].sum,sizeof t[v].sum);
    return;
  }
  mid=(tl+tr)/2;
  sum(2*v,tl,mid,l,min(r,mid),t,lo_part,f);
  sum(2*v+1,mid+1,tr,max(l,mid+1),r,t,hi_part,f);
  if(f)
    merge(hi_part,lo_part,ans);
  else
    merge(lo_part,hi_part,ans);
}
/* Assigns the matrix [[o1, o2], [0, 1]] to every position in [pos1, pos2]
   of segment tree t (node v covers [tl, tr]). Fully covered nodes only
   record the assignment lazily. Partially covered nodes recurse and then
   rebuild both ordered products from their freshly pushed children. */
void range_update(int v,int tl,int tr,int pos1,int pos2,long long o1,long long o2,tree *t){
  int mid,lc=2*v,rc=2*v+1;
  push(v,tl,tr,t);
  if(pos1>tr || pos2<tl)
    return;
  if(tl>=pos1 && tr<=pos2){
    t[v].offset1=o1;
    t[v].offset2=o2;
    return;
  }
  mid=(tl+tr)/2;
  range_update(lc,tl,mid,pos1,pos2,o1,o2,t);
  range_update(rc,mid+1,tr,pos1,pos2,o1,o2,t);
  push(lc,tl,mid,t);
  push(rc,mid+1,tr,t);
  merge(t[lc].sum,t[rc].sum,t[v].sum);       /* left * right */
  merge(t[rc].sum1,t[lc].sum1,t[v].sum1);    /* right * left */
}
/* ans = a * b for 2x2 matrices, each entry reduced modulo MOD.
   ans is written entry by entry while a and b are still being read, so it
   must not alias either input. */
void merge(long long a[][2],long long b[][2],long long ans[][2]){
  int row,col;
  for(row=0;row<2;row++)
    for(col=0;col<2;col++)
      ans[row][col]=(a[row][0]*b[0][col]+a[row][1]*b[1][col])%MOD;
}
/* Materialises a pending assignment on segment node v covering [tl, tr].
   The node then holds the same matrix M = [[offset1, offset2], [0, 1]] at
   all tr-tl+1 positions, so both ordered products equal M^(tr-tl+1). The
   assignment is handed down to the two children (unless v is a leaf) and
   cleared here. Does nothing when no assignment is pending (-1 sentinel). */
void push(int v,int tl,int tr,tree *t){
  tree *node=&t[v];
  long long base[2][2];
  int child;
  if(node->offset1!=-1 && node->offset2!=-1){
    base[0][0]=node->offset1;
    base[0][1]=node->offset2;
    base[1][0]=0;
    base[1][1]=1;
    powm(base[0],tr-tl+1,node->sum[0],2);
    memcpy(node->sum1,node->sum,sizeof node->sum);
    if(tl!=tr)
      for(child=2*v;child<=2*v+1;child++){
        t[child].offset1=node->offset1;
        t[child].offset2=node->offset2;
      }
    node->offset1=-1;
    node->offset2=-1;
  }
}
/* Assigns f(v) = a*v + b to every vertex on the tree path x..y, inclusive.
   Each side climbs chain by chain to the LCA's chain. The x side covers the
   LCA itself; the y side stops just below it, so the LCA is updated once. */
void range_solve(int x,int y,int a,int b){
  int ca=lca(x,y),top=node_chain[ca],c;
  for(c=node_chain[x];c!=top;c=node_chain[x]){
    range_update(1,0,chain_len[c]-1,0,node_idx[x],a,b,chain[c]);
    x=DP[0][chain_head[c]];
  }
  range_update(1,0,chain_len[top]-1,node_idx[ca],node_idx[x],a,b,chain[top]);
  for(c=node_chain[y];c!=top;c=node_chain[y]){
    range_update(1,0,chain_len[c]-1,0,node_idx[y],a,b,chain[c]);
    y=DP[0][chain_head[c]];
  }
  if(node_idx[y]!=node_idx[ca])
    range_update(1,0,chain_len[top]-1,node_idx[ca]+1,node_idx[y],a,b,chain[top]);
}
/* Smaller of two ints. */
int min(int x,int y){
  if(y<=x)
    return y;
  return x;
}
/* Larger of two ints. */
int max(int x,int y){
  if(y>=x)
    return y;
  return x;
}
/* Evaluates the composition of all vertex functions on the path x..y at the
   value a. The function of x is applied first and the function of y last.
   The result is (P * [a, 1]^T)[0] mod MOD, where P is the product of the
   path's [[A, B], [0, 1]] matrices with the last-applied matrix leftmost. */
long long solve(int x,int y,int a){
  int ca=lca(x,y),top=node_chain[ca],c;
  long long seg[2][2],acc[2][2],tmp[2][2],up[2][2];
  /* Upward half x -> lca: each higher segment is applied after what we have. */
  one(&acc[0][0],2);
  for(c=node_chain[x];c!=top;c=node_chain[x]){
    sum(1,0,chain_len[c]-1,0,node_idx[x],chain[c],seg,0);
    memcpy(tmp,acc,sizeof acc);
    merge(seg,tmp,acc);
    x=DP[0][chain_head[c]];
  }
  sum(1,0,chain_len[top]-1,node_idx[ca],node_idx[x],chain[top],seg,0);
  merge(seg,acc,up);
  /* Downward half below lca -> y: each lower segment is applied after the higher ones. */
  one(&acc[0][0],2);
  for(c=node_chain[y];c!=top;c=node_chain[y]){
    sum(1,0,chain_len[c]-1,0,node_idx[y],chain[c],seg,1);
    memcpy(tmp,acc,sizeof acc);
    merge(tmp,seg,acc);
    y=DP[0][chain_head[c]];
  }
  if(node_idx[y]!=node_idx[ca]){
    sum(1,0,chain_len[top]-1,node_idx[ca]+1,node_idx[y],chain[top],seg,1);
    memcpy(tmp,acc,sizeof acc);
    merge(tmp,seg,acc);
  }
  merge(acc,up,seg);
  return (a*seg[0][0]+seg[0][1])%MOD;
}
/* Overwrites the row-major SIZE x SIZE matrix a with the identity. */
void one(long long*a,int SIZE){
    int k,cells=SIZE*SIZE;
    for(k=0;k<cells;k++)
        a[k]=(k/SIZE==k%SIZE);   /* 1 on the diagonal (row == column), else 0 */
}
/* In-place product a = a * b of row-major SIZE x SIZE matrices, each entry
   reduced modulo MOD. The result is built in a scratch buffer before being
   copied back, so a and b may be the same matrix (squaring). */
void mul(long long*a,long long*b,int SIZE){
    long long prod[SIZE*SIZE];
    int row,col,k,cells=SIZE*SIZE;
    for(row=0;row<SIZE;row++)
        for(col=0;col<SIZE;col++){
            long long acc=0;
            for(k=0;k<SIZE;k++)
                acc=(acc+a[row*SIZE+k]*b[k*SIZE+col])%MOD;
            prod[row*SIZE+col]=acc;
        }
    for(k=0;k<cells;k++)
        a[k]=prod[k];
}
/* res = a^n (mod MOD) for a row-major SIZE x SIZE matrix, by repeated
   squaring. Note: a is used as scratch space and is overwritten. */
void powm(long long*a,int n,long long*res,int SIZE){
    one(res,SIZE);
    while(n>0){
        if(n&1){            /* odd exponent: move one factor into the result */
            mul(res,a,SIZE);
            n-=1;
        }else{              /* even exponent: square the base, halve the exponent */
            mul(a,a,SIZE);
            n>>=1;
        }
    }
}


Post a Comment

0 Comments