HackerEarth Longest Paths in Tree problem solution

In this HackerEarth Longest Paths in Tree problem you are given a tree. A simple path of length m is a sequence of vertices v1, v2, ..., vm such that:
• All vi are distinct.
• vi and vi+1 are connected by edge for 1 <= i <= m - 1.
For each vertex, find the length of the longest simple path that goes through this vertex. Also count the number of such paths. Two paths are considered distinct if their sets of vertices differ.

HackerEarth Longest Paths in Tree problem solution.

#include <bits/stdc++.h>

using namespace std;

typedef long double ld;
typedef long long ll;
// A family of equally long paths: (length, number of paths of that length).
typedef pair<int, ll> Path;

// Field aliases for Path. These are macros, so every later occurrence of the
// identifiers `len` and `cnt` in this file is rewritten to first / second.
#define len first
#define cnt second

const int M = 500100;  // size of the per-vertex arrays below
const int ROOT = 0;    // the traversal starts at vertex ROOT % n

vector<int> g[M];  // adjacency lists, indexed by 0-based vertex
int n;             // number of vertices
int alen[M];       // alen[v] + 1 is printed as the longest-path length for v
// Number of longest paths through v. Must be 64-bit: dfs2 stores products of
// ll path counts here, which do not fit in a 32-bit int for large trees.
ll acnt[M];
Path down[M];  // down[v]: result of dfs1 for the subtree hanging below v

cin >> n;
if (n == 1) {
cout << "1 1\n";
exit(0);
}
for (int i = 0; i < n - 1; ++i) {
int u, v;
cin >> u >> v;
--u, --v;
g[u].push_back(v);
g[v].push_back(u);
}
}

// Merges two path families: the longer one wins; on equal lengths the
// counts are added together.
Path operator+(const Path &lhs, const Path &rhs) {
  if (rhs.len > lhs.len) {
    return rhs;
  }
  if (rhs.len < lhs.len) {
    return lhs;
  }
  return Path(lhs.len, lhs.cnt + rhs.cnt);
}

// Returns a copy of `a` whose length is one larger; the count is unchanged.
Path up(const Path &a) {
  return Path(a.len + 1, a.cnt);
}

void dfs1(int v, int from) {
down[v] = Path(0, 1);
for (int to : g[v])
if (to != from) {
dfs1(to, v);
down[v] = down[v] + up(down[to]);
}
}

void dfs2(int v, int from, Path UP) {
int len1 = -1, len2 = -1;
ll cnt1 = 0;
ll sum11 = 0, sum12 = 0, sum2 = 0;

auto add = [&](const Path &x) {
int len = x.len;
ll cnt = x.cnt;
if (len > len1) {
len2 = len1;
sum2 = sum11;
len1 = len;
sum11 = cnt;
sum12 = cnt * cnt;
cnt1 = 1;
} else if (len == len1) {
sum11 += cnt;
sum12 += cnt * cnt;
cnt1++;
} else if (len > len2) {
len2 = len;
sum2 = cnt;
} else if (len == len2) {
sum2 += cnt;
}
};

for (int to : g[v])
if (to != from) {
}

assert(cnt1 >= 1);
if (cnt1 > 1) {
alen[v] = 2 * len1;
acnt[v] = (sum11 * sum11 - sum12) / 2;
} else if (len2 == -1) {
alen[v] = len1;
acnt[v] = sum11;
} else {
alen[v] = len1 + len2;
acnt[v] = sum11 * sum2;
}

for (int to : g[v])
if (to != from) {
if (cnt1 > 1) {
if (down[to].len + 1 != len1) {
dfs2(to, v, up(Path(len1, sum11)));
} else {
dfs2(to, v, up(Path(len1, sum11 - down[to].cnt)));
}
} else {
if (down[to].len + 1 == len1) {
dfs2(to, v, up(Path(len2, sum2)));
} else {
dfs2(to, v, up(Path(len1, sum11)));
}
}
}
}

// Runs both tree passes from vertex ROOT % n and prints, for every vertex i,
// "alen[i] + 1" and "acnt[i]" on one line.
void kill() {
  const int root = ROOT % n;
  dfs1(root, -1);
#ifdef LOCAL
  // Debug dump of the first pass; only emitted in local builds so that normal
  // runs do not write (and flush) one stderr line per vertex.
  for (int i = 0; i < n; ++i) {
    cerr << i << ": " << down[i].len << " - " << down[i].cnt << '\n';
  }
#endif
  dfs2(root, -1, Path(0, 1));
  for (int i = 0; i < n; ++i) {
    cout << alen[i] + 1 << " " << acnt[i] << "\n";
  }
}

// Program entry point: optionally redirects stdin to "a.in" in LOCAL builds,
// switches iostreams to fast mode and runs the solver.
int main() {
#ifdef LOCAL
  // The freopen call must not live inside assert(): with NDEBUG defined the
  // assert argument is not evaluated and stdin would never be redirected.
  if (!freopen("a.in", "r", stdin)) {
    cerr << "cannot open a.in" << endl;
    return 1;
  }
#endif

  ios_base::sync_with_stdio(false);
  cin.tie(nullptr);  // no interactive output, so cin need not flush cout
  // NOTE(review): no call that reads the input is visible in this function,
  // yet kill() evaluates ROOT % n. Confirm that n and the adjacency lists are
  // filled before this point.
  kill();
  return 0;
}

Second solution

import java.io.OutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.PrintWriter;
import java.io.BufferedWriter;
import java.util.InputMismatchException;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Stream;
import java.io.Writer;
import java.io.OutputStreamWriter;
import java.io.InputStream;

/**
* Built using CHelper plug-in
* Actual solution is at the top
*/
public class Main {
    /**
     * Reads a tree from standard input and prints, for every vertex, the length of the
     * longest simple path through it and the number of such paths.
     */
    public static void main(String[] args) {
        InputStream inputStream = System.in;
        OutputStream outputStream = System.out;
        InputReader in = new InputReader(inputStream);
        OutputWriter out = new OutputWriter(outputStream);
        LongestPathsInTree solver = new LongestPathsInTree();
        solver.solve(1, in, out);
        out.close();
    }

    /** Two-pass (down, then rerooting) tree DP rooted at vertex 0. */
    static class LongestPathsInTree {
        /** Stack size requested for the solver thread; both DFS passes recurse once per tree level. */
        private static final long SOLVER_STACK_BYTES = 1L << 28;

        public int n;
        /** d1[v]: largest value of d1[child] + 1 over the children of v (0 for a leaf). */
        int[] d1;
        /** d2[v]: second largest distinct such value, as computed in dfs2 (0 if none). */
        int[] d2;
        /** up[v]: length of the longest path that leaves v through its parent. */
        int[] up;
        /** c1[v]: number of downward paths attaining d1[v]. */
        int[] c1;
        /** c2[v]: number of downward paths attaining d2[v]. */
        int[] c2;
        /** cup[v]: number of paths attaining up[v]. */
        int[] cup;
        List<Integer>[] graph;
        /** res[v]: number of vertices on a longest simple path through v. */
        public long[] res;
        /** ans[v]: number of longest simple paths through v. */
        public long[] ans;

        /**
         * Runs {@link #_solve} on a dedicated thread with a large stack and waits for it.
         *
         * @throws IllegalStateException if the solver thread fails or the wait is interrupted
         */
        public void solve(int testNumber, InputReader in, OutputWriter out) {
            Throwable[] failure = new Throwable[1];
            Thread t = new Thread(null, () -> _solve(testNumber, in, out), "1", SOLVER_STACK_BYTES);
            // Record any failure so it is reported to the caller instead of being lost.
            t.setUncaughtExceptionHandler((thread, error) -> failure[0] = error);
            t.start();
            try {
                t.join();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException("interrupted while waiting for the solver thread", e);
            }
            if (failure[0] != null) {
                throw new IllegalStateException("solver thread failed", failure[0]);
            }
        }

        /**
         * Reads n and n - 1 edges (1-based endpoints), runs both passes and prints
         * "res[i] ans[i]" for every vertex i in index order.
         */
        public void _solve(int testNumber, InputReader in, OutputWriter out) {
            n = in.nextInt();
            @SuppressWarnings("unchecked")
            List<Integer>[] adjacency = new List[n];
            for (int i = 0; i < n; i++) {
                adjacency[i] = new ArrayList<>();
            }
            graph = adjacency;
            for (int i = 0; i < n - 1; i++) {
                int a = in.nextInt() - 1, b = in.nextInt() - 1;
                graph[a].add(b);
                graph[b].add(a);
            }
            d1 = new int[n];
            d2 = new int[n];
            up = new int[n];
            c1 = new int[n];
            c2 = new int[n];
            cup = new int[n];
            res = new long[n];
            ans = new long[n];
            dfs(0, -1);
            cup[0] = 1;
            dfs2(0, -1, 0);
            for (int i = 0; i < n; i++) {
                out.println(res[i] + " " + ans[i]);
            }
        }

        /** First pass: fills d1 / c1 for the subtree of {@code node}; {@code par} is its parent or -1. */
        public void dfs(int node, int par) {
            d1[node] = 0;
            c1[node] = 1;
            for (int next : graph[node]) {
                if (next == par) {
                    continue;
                }
                dfs(next, node);
                if (d1[next] + 1 > d1[node]) {
                    d1[node] = d1[next] + 1;
                    c1[node] = 0;
                }
                if (d1[next] + 1 == d1[node]) {
                    c1[node] += c1[next];
                }
            }
        }

        /**
         * Second pass: computes res / ans for {@code node} and pushes the "through the parent"
         * values (up / cup) to its children.
         *
         * @param node    current vertex
         * @param par     parent of {@code node}, or -1 for the root
         * @param frompar longest path length reaching {@code node} through its parent;
         *                {@code cup[node]} must already hold the matching count
         */
        public void dfs2(int node, int par, int frompar) {
            up[node] = frompar;
            // mx1 / mx2: two largest distinct child branch lengths, with counts x1 / x2.
            int mx1 = 0, mx2 = 0;
            // r1 / r2: two largest branch lengths over all branches (parent branch included);
            // r2 equals r1 when the maximum is attained by more than one branch.
            int r1 = up[node], r2 = -1;
            int x1 = 0, x2 = 0;
            for (int next : graph[node]) {
                if (next == par) {
                    continue;
                }
                if (d1[next] + 1 > r1) {
                    r2 = r1;
                    r1 = d1[next] + 1;
                } else if (d1[next] + 1 > r2) {
                    r2 = d1[next] + 1;
                }

                if (d1[next] + 1 > mx1) {
                    mx2 = mx1;
                    mx1 = d1[next] + 1;
                    x2 = x1;
                    x1 = 0;
                }
                if (d1[next] + 1 == mx1) {
                    x1 += c1[next];
                }

                if (d1[next] + 1 < mx1 && d1[next] + 1 > mx2) {
                    mx2 = d1[next] + 1;
                    x2 = 0;
                }
                if (d1[next] + 1 == mx2) {
                    x2 += c1[next];
                }
            }
            d1[node] = mx1;
            d2[node] = mx2;
            c1[node] = x1;
            c2[node] = x2;
            long total1 = 0; // paths over branches of length r1
            long total2 = 0; // paths over branches of length r2
            long xx = 0;     // sum of squared per-branch counts for length r1
            if (up[node] == r1) {
                total1 += cup[node];
                xx += 1L * cup[node] * cup[node];
            }
            if (up[node] == r2) {
                total2 += cup[node];
            }
            for (int next : graph[node]) {
                if (next == par) {
                    continue;
                }
                if (d1[next] + 1 == r1) {
                    total1 += c1[next];
                    xx += 1L * c1[next] * c1[next];
                }
                if (d1[next] + 1 == r2) {
                    total2 += c1[next];
                }
            }
            if (r2 == -1) {
                // No child branch at all: the path simply ends at this vertex.
                r2 = 0;
                total2 = 1;
            }
            if (r1 == 0) {
                // Single-vertex tree: makes the pair formula below evaluate to 1.
                total1 = 2;
            }
            res[node] = r1 + r2 + 1;
            if (r1 == r2) {
                // Unordered pairs of paths taken from two different longest branches.
                ans[node] = (total1 * total1 - xx) / 2;
            } else {
                ans[node] = total1 * total2;
            }

            for (int next : graph[node]) {
                if (next == par) {
                    continue;
                }
                // Best child branch of node that avoids the subtree of next.
                int mxn = mx1, count = x1;
                if (d1[next] + 1 == mx1) {
                    count -= c1[next];
                    if (count == 0) {
                        mxn = mx2;
                        count = x2;
                    }
                }

                // Merge with the branch that goes through node's own parent.
                if (up[node] > mxn) {
                    mxn = up[node];
                    count = cup[node];
                } else if (up[node] == mxn) {
                    count += cup[node];
                }
                cup[next] = count;
                dfs2(next, node, mxn + 1);
            }
        }

    }

    /** Buffered writer that prints its arguments separated by single spaces. */
    static class OutputWriter {
        private final PrintWriter writer;

        public OutputWriter(OutputStream outputStream) {
            writer = new PrintWriter(new BufferedWriter(new OutputStreamWriter(outputStream)));
        }

        public OutputWriter(Writer writer) {
            this.writer = new PrintWriter(writer);
        }

        public void print(Object... objects) {
            for (int i = 0; i < objects.length; i++) {
                if (i != 0) {
                    writer.print(' ');
                }
                writer.print(objects[i]);
            }
        }

        public void println(Object... objects) {
            print(objects);
            writer.println();
        }

        public void close() {
            writer.close();
        }

    }

    /** Minimal byte-buffered reader for whitespace-separated integers. */
    static class InputReader {
        private InputStream stream;
        private byte[] buf = new byte[1024];
        private int curChar;
        private int numChars;

        public InputReader(InputStream stream) {
            this.stream = stream;
        }

        /**
         * Returns the next input byte, or -1 once the stream is exhausted.
         *
         * @throws InputMismatchException on an I/O error or when called again after
         *                                end of stream was reported
         */
        public int read() {
            if (this.numChars == -1) {
                throw new InputMismatchException();
            }
            if (this.curChar >= this.numChars) {
                this.curChar = 0;
                try {
                    this.numChars = this.stream.read(this.buf);
                } catch (IOException e) {
                    throw new InputMismatchException();
                }
                if (this.numChars <= 0) {
                    return -1;
                }
            }
            return this.buf[this.curChar++];
        }

        /**
         * Skips whitespace and parses an optionally negative decimal integer that is
         * terminated by whitespace or end of input.
         *
         * @throws InputMismatchException if the token is not a well-formed integer
         */
        public int nextInt() {
            int c = read();
            while (isSpaceChar(c)) {
                c = read();
            }

            int sgn = 1;
            if (c == '-') {
                sgn = -1;
                c = read();
            }

            int res = 0;
            while (c >= '0' && c <= '9') {
                res *= 10;
                res += c - '0';
                c = read();
                if (isSpaceChar(c)) {
                    return res * sgn;
                }
            }

            throw new InputMismatchException();
        }

        /** Space, newline, carriage return, tab and end of input (-1) count as separators. */
        public static boolean isSpaceChar(int c) {
            return c == 32 || c == 10 || c == 13 || c == 9 || c == -1;
        }

    }
}