In this HackerRank Tree Pruning problem, a tree $T$ has $N$ vertices numbered from 1 to $N$ and is rooted at vertex 1. Each vertex $i$ has an integer weight $W_i$ associated with it, and $T$'s total weight is the sum of the weights of its nodes. A single remove operation removes the subtree rooted at some arbitrary vertex $u$ from tree $T$.

Given $T$, perform up to $K$ remove operations so that the total weight of the remaining vertices in $T$ is maximized. Then print $T$'s maximal total weight on a new line.

HackerRank Tree Pruning problem solution


Problem solution in Python.

#!/bin/python3

import os
import sys

#
# Complete the treePrunning function below.
#

from collections import defaultdict

# Padding value for removal counts that cannot be reached. Assumes every real
# subtree total is greater than -1e15 (TODO confirm against problem constraints).
INF = -(1e15)


def _merge_child(acc, child, k):
    """Combine a vertex's partial removal table with one child's table.

    Both tables are *trimmed*: entry ``i`` is the best weight achievable with
    exactly ``i`` removals, and every stored entry is reachable (no padding).

    :param acc: table for the vertex plus the children merged so far.
    :param child: table for the child's subtree with the child itself kept.
    :param k: maximum number of removals allowed.
    :return: the merged trimmed table, of length at most ``k + 1``.
    """
    # The child contributes either j removals inside its subtree, or exactly
    # one removal that cuts its whole subtree away.
    width = min(k, len(acc) - 1 + max(len(child) - 1, 1)) + 1
    merged = [float("-inf")] * width
    for i, base in enumerate(acc):
        # Cut the child's entire subtree with a single operation.
        if i + 1 < width and base > merged[i + 1]:
            merged[i + 1] = base
        # Keep the child and spend j removals inside its subtree.
        for j in range(min(len(child), width - i)):
            total = base + child[j]
            if total > merged[i + j]:
                merged[i + j] = total
    return merged


def dfs(x, f, g, k, weights):
    """Return the removal table for the subtree rooted at ``x``.

    Entry ``i`` of the returned list (length ``k + 1``) is the maximum total
    weight of ``x``'s subtree, with ``x`` itself kept, after exactly ``i``
    remove operations on its descendants. Unreachable counts hold ``INF``.

    The traversal is iterative, so deep (path-shaped) trees do not hit
    Python's recursion limit.

    :param x: root vertex of the subtree (0-based).
    :param f: parent of ``x``; it is excluded from the traversal (-1 for the root).
    :param g: mapping vertex -> list of neighbouring vertices.
    :param k: maximum number of removals.
    :param weights: vertex weights indexed by vertex.
    """
    parent = {x: f}
    order = [x]
    stack = [x]
    while stack:
        node = stack.pop()
        for nxt in g[node]:
            if nxt != parent[node]:
                parent[nxt] = node
                order.append(nxt)
                stack.append(nxt)

    # Every child is discovered after its parent, so the reversed discovery
    # order finishes children first. Child tables are popped once merged, so
    # only the tables of not-yet-finished parents stay alive.
    pending = {}
    for node in reversed(order):
        table = [weights[node]]
        for nxt in g[node]:
            if nxt != parent[node]:
                table = _merge_child(table, pending.pop(nxt), k)
        pending[node] = table

    table = pending[x]
    # Pad to the historical fixed length k + 1 so callers see the same shape.
    return table + [INF] * (k + 1 - len(table))


def treePrunning(k, weights, edges):
    """Return the maximum total weight left after at most ``k`` subtree removals.

    :param k: maximum number of remove operations.
    :param weights: vertex weights; ``weights[i]`` belongs to vertex ``i + 1``.
    :param edges: pairs ``(u, v)`` of 1-based vertex numbers; the tree is rooted at vertex 1.
    :return: the maximal remaining total weight.
    """
    g = defaultdict(list)
    for u, v in edges:
        g[u - 1].append(v - 1)
        g[v - 1].append(u - 1)

    best = max(dfs(0, -1, g, k, weights))
    # Removing the root's subtree empties the tree (weight 0). It costs one
    # operation, so it is only available when k >= 1.
    return max(best, 0) if k >= 1 else best


    
if __name__ == '__main__':
    # HackerRank harness: read "n k", the n weights and n-1 edges from stdin,
    # then write the answer to the file named by $OUTPUT_PATH.
    first_line = input().split()
    n = int(first_line[0])
    k = int(first_line[1])

    weights = list(map(int, input().rstrip().split()))
    tree = [list(map(int, input().rstrip().split())) for _ in range(n - 1)]

    result = treePrunning(k, weights, tree)

    # Context manager guarantees the output file is closed even if writing fails.
    with open(os.environ['OUTPUT_PATH'], 'w') as fptr:
        fptr.write(str(result) + '\n')
 


Problem solution in Java.

import java.util.*;
import java.io.*;

public class Main {
    public static void main(String[] args) throws IOException {
        FastScanner in = new FastScanner(System.in);
        PrintWriter out = new PrintWriter(System.out);
        new Main().run(in, out);
        out.close();
    }


    int n;              // number of vertices
    int K;              // maximum number of remove operations
    List<Integer>[] adj; // adjacency lists, 0-based vertices
    int[] w;            // vertex weights, 0-based

    /**
     * Reads the tree, then prints the best remaining weight reachable with at most K removals.
     */
    @SuppressWarnings("unchecked")
    void run(FastScanner in, PrintWriter out) {

        n = in.nextInt();
        K = in.nextInt();
        w = new int[n];
        adj = new List[n];
        for (int i = 0; i < n; i++) adj[i] = new ArrayList<>();
        for (int i = 0; i < n; i++) w[i] = in.nextInt();
        for (int i = 1; i < n; i++) {
            int u = in.nextInt() - 1;
            int v = in.nextInt() - 1;
            adj[u].add(v);
            adj[v].add(u);
        }

        long[] dp = go(0, -1);
        long max = Long.MIN_VALUE;
        for (int k = 0; k <= K; k++) {
            max = Math.max(max, dp[k]);
        }
        out.println(max);
    }

    /**
     * Removal table for the subtree rooted at u (p is u's parent, -1 for the root).
     *
     * Entry k is the best weight left in u's subtree using k removals; when K >= 1,
     * entry 1 also covers removing u's whole subtree (weight 0). Unreachable entries
     * hold Long.MIN_VALUE. Implemented iteratively so deep trees cannot overflow the
     * call stack; intermediate tables are trimmed to their reachable prefix to bound memory.
     */
    long[] go(int u, int p) {
        int[] parent = new int[n];
        int[] order = new int[n];
        int[] stack = new int[n];
        int count = 0;
        int top = 0;
        parent[u] = p;
        stack[top++] = u;
        while (top > 0) {
            int x = stack[--top];
            order[count++] = x;
            for (int y : adj[x]) {
                if (y == parent[x]) continue;
                parent[y] = x;
                stack[top++] = y;
            }
        }

        // Reverse discovery order finishes every child before its parent.
        long[][] tables = new long[n][];
        for (int idx = count - 1; idx >= 0; idx--) {
            int x = order[idx];
            long[] acc = {w[x]};
            for (int y : adj[x]) {
                if (y == parent[x]) continue;
                acc = merge(acc, tables[y]);
                tables[y] = null; // child table no longer needed
            }
            // Removing x's whole subtree costs one operation and leaves weight 0.
            if (K >= 1) {
                if (acc.length == 1) {
                    acc = new long[]{acc[0], 0};
                } else {
                    acc[1] = Math.max(acc[1], 0);
                }
            }
            tables[x] = acc;
        }

        long[] result = new long[K + 1];
        Arrays.fill(result, Long.MIN_VALUE);
        System.arraycopy(tables[u], 0, result, 0, tables[u].length);
        return result;
    }

    /**
     * Knapsack-style merge of a partial vertex table with one child's table.
     * Both inputs contain only reachable entries, and so does the result.
     */
    private long[] merge(long[] acc, long[] child) {
        int width = Math.min(K, acc.length + child.length - 2) + 1;
        long[] merged = new long[width];
        Arrays.fill(merged, Long.MIN_VALUE);
        for (int i = 0; i < acc.length; i++) {
            for (int j = 0; j < child.length && i + j < width; j++) {
                merged[i + j] = Math.max(merged[i + j], acc[i] + child[j]);
            }
        }
        return merged;
    }

    static class FastScanner {
        BufferedReader br;
        StringTokenizer st;

        public FastScanner(InputStream in) {
            br = new BufferedReader(new InputStreamReader(in));
            st = null;
        }

        /** Next whitespace-separated token; fails loudly on EOF or I/O error instead of looping. */
        String next() {
            while (st == null || !st.hasMoreElements()) {
                String line;
                try {
                    line = br.readLine();
                } catch (IOException e) {
                    throw new UncheckedIOException(e);
                }
                if (line == null) {
                    throw new NoSuchElementException("unexpected end of input");
                }
                st = new StringTokenizer(line);
            }
            return st.nextToken();
        }

        int nextInt() {
            return Integer.parseInt(next());
        }

        long nextLong() {
            return Long.parseLong(next());
        }
    }
}


Problem solution in C++.

#include <cmath>
#include <cstdio>
#include <vector>
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cstdlib>
#include <climits>
using namespace std;
// 64-bit type for subtree weight sums.
typedef long long ll;

// Capacities of the static tables. Vertices are 1-based (1..N) and removal
// counts run 0..K, so inputs with N >= MAXN or K >= MAXK would overflow them.
const int MAXN = 100005;
const int MAXK = 205;
// wgt[v]: weight of vertex v, read in main.
int wgt[MAXN];
// dp[v][0]: total weight of v's subtree; dp[v][i] (i >= 1): best weight of
// v's subtree after between 1 and i removals (0 = whole subtree removed). Filled by dfs.
ll dp[MAXN][MAXK];
// net[v]: adjacency list of vertex v.
vector< vector<int> > net;
// N: number of vertices; K: maximum number of remove operations.
int N, K;

// Fill dp[cur][0..K] for the subtree rooted at cur (par = cur's parent, 0 for the root).
// dp[cur][0] becomes the full subtree weight; for i >= 1, dp[cur][i] is the best
// weight with between 1 and i removals, clamped to >= 0 (removing cur itself).
// Recursive: depth equals the tree height — NOTE(review): a path-shaped tree of
// ~10^5 vertices may exhaust the call stack on small-stack platforms.
void dfs(int cur, int par){
    // w accumulates cur's weight plus every child's full subtree weight.
    ll w = wgt[cur];
    // Index in net[cur] of the first child; -1 if cur has no children.
    int first = -1;
    for(int i = 0; i < net[cur].size(); i++){
        int u = net[cur][i];
        if(u == par) continue;
        if(first == -1) first = i;
        dfs(u, cur);
        w += dp[u][0];
    }
    // Seed the table from the first child: the other children stay whole, the
    // first child uses i removals (dp[first][i] replaces its full weight).
    if(first != -1)
        for(int i = 0; i <= K; i++) dp[cur][i] = w + dp[net[cur][first]][i] - dp[net[cur][first]][0]; 
    // Merge remaining children. If first == -1 every neighbour is par, so nothing merges.
    for(int i = first +1; i < net[cur].size(); i++){
        int u = net[cur][i];
        if(u == par ) continue;
        // In-place 0/1 knapsack: j descends so dp[cur][l] (l < j) still holds the
        // value from before this child; l == j adds 0 and never changes dp[cur][j].
        // Costs O(K^2) per merged child.
        for(int j = K; j >=0; j--){
            for(int l = j; l >= 0; l--){
                if(dp[cur][j] < dp[cur][l] + dp[u][j-l] - dp[u][0]) dp[cur][j] = dp[cur][l] + dp[u][j-l] - dp[u][0];
            }
        }
    } 
    dp[cur][0] = w;
    // With at least one removal available, cur's whole subtree can be cut (weight 0).
    // For a leaf this also replaces the sentinel main stored in dp[cur][i].
    for(int i = 1; i <= K; ++i)
        if (dp[cur][i] < 0) dp[cur][i] = 0;
}

int main() {
    /* Enter your code here. Read input from STDIN. Print output to STDOUT */  
    cin>>N>>K;
    for(int i = 1; i <= N; i++) cin>>wgt[i];
    net.resize(N+1);
    for(int i = 0; i < N-1; i++){
        int u,v;
        cin>>u>>v;
        net[u].push_back(v);
        net[v].push_back(u);
    }
    for(int i = 1; i<= N; i++){
        for(int j = 0; j<=K; j++){
            dp[i][j] = -1*INT_MAX;
        }
    }
    dfs(1,0);
    ll ans = dp[1][0];
    for(int i = 1; i <= K; i++) ans = max(ans, dp[1][i]);
    cout<<ans<<endl;
    return 0;
}


Problem solution in C.

#include<stdio.h>
#include<stdlib.h>
/* One cell of a singly linked adjacency list: a neighbour of some vertex.
   The struct tag has no leading underscore; such file-scope tags are reserved in C. */
typedef struct node_s
{
    int x;                  /* neighbouring vertex (0-based) */
    struct node_s *next;    /* next cell of the same list, or NULL */
}node;
/* a[v]: weight of vertex v (0-based, as read by main).
   b[p]: weight of the vertex placed at position p of dfs's preorder.
   size[p]: subtree size of the vertex at preorder position p.
   trace[v]: set to 1 once dfs has visited vertex v.
   NN: number of vertices placed in the preorder so far.
   The array sizes imply N <= 100000 and K <= 200. */
int a[100000], b[100000], size[100000], trace[100000] = {0}, NN = 0;
/* dp[i][j]: best kept weight among the first i preorder vertices using at most j removals. */
long long dp[100001][201];
/* table[v]: head of vertex v's adjacency list, built by insert_edge. */
node *table[100000] = {0};
/* Prepend vertex `to` to the adjacency list of vertex `from`; exits on allocation failure. */
static void add_arc(int from, int to)
{
    node *t = malloc(sizeof *t);
    if( t == NULL )
    {
        fprintf(stderr, "out of memory\n");
        exit(EXIT_FAILURE);
    }
    t -> x = to;
    t -> next = table[from];
    table[from] = t;
}
/* Record the undirected edge {x, y} (0-based) as one arc in each direction. */
void insert_edge(int x, int y)
{
    add_arc(x, y);
    add_arc(y, x);
    return;
}
/* Lay out the vertices reachable from x in DFS preorder: b[p] receives the
   weight of the p-th visited vertex and size[p] the size of its subtree.
   An explicit stack replaces recursion so a path-shaped tree of 100000
   vertices cannot overflow the call stack; neighbours are visited in
   adjacency-list order, exactly as the recursive version did. */
void dfs(int x)
{
    static int start[100000];      /* preorder position of each open vertex */
    static node *cursor[100000];   /* next unexplored adjacency cell of each open vertex */
    int top = 0;
    node *t;

    trace[x] = 1;
    start[top] = NN;
    cursor[top] = table[x];
    top++;
    b[NN++] = a[x];
    while( top > 0 )
    {
        t = cursor[top-1];
        if( t == NULL )
        {
            /* All neighbours handled: this vertex's subtree is complete. */
            top--;
            size[start[top]] = NN - start[top];
            continue;
        }
        cursor[top-1] = t -> next;
        if( !trace[t -> x] )
        {
            int y = t -> x;
            trace[y] = 1;
            start[top] = NN;
            cursor[top] = table[y];
            top++;
            b[NN++] = a[y];
        }
    }
    return;
}
/* Return the larger of two long long values. */
long long max(long long x, long long y)
{
    if( x < y )
        return y;
    return x;
}
int main()
{
    int N, K, x, y, i, j;
    long long sum;
    scanf("%d%d", &N, &K);
    for( i = 0 ; i < N ; i++ )
    {
        scanf("%d", a + i);
    }
    for( i = 0 ; i < N - 1 ; i++ )
    {
        scanf("%d%d", &x, &y);
        insert_edge(x-1, y-1);
    }
    dfs(0);
    for( i = 0 ; i <= K ; i++ )
        dp[0][i] = 0;
    for( i = 1, sum = 0 ; i <= N ; i++ )
    {
        sum += b[i-1];
        for( j = 0 ; j <= K ; j++ )
        {
            dp[i][j] = sum;
        }
    }
    for( i = 1, sum = 0 ; i <= N ; i++ )
    {
        for( j = 0 ; j <= K ; j++ )
        {
            if( j != K )
            {
                dp[i+size[i-1]-1][j+1] = max(dp[i+size[i-1]-1][j+1], dp[i-1][j]);
            }
            dp[i][j] = max(dp[i][j], dp[i-1][j] + b[i-1]);
        }
    }
    printf("%lld", dp[N][K]);
    return 0;
}