Header Ad

HackerRank Balanced Forest Interview preparation kit solution

In this HackerRank Balanced Forest problem from the Interview Preparation Kit, you need to complete the balancedForest function. It must return an integer representing the minimum value c[w] of a node that can be added to allow the creation of a balanced forest, or -1 if it is not possible.

HackerRank Balanced Forest Interview preparation kit solution


Problem solution in Python programming.

from operator import attrgetter
from itertools import groupby
from sys import stderr

class Node:
    """A tree vertex: its position in the input, its value and its neighbours."""

    def __init__(self, index, value):
        # Neighbour list; filled by the tree reader with adjacent Node objects.
        self.children = []
        self.index, self.value = index, value
        
def readtree():
    """Read one tree from standard input and return its list of nodes.

    Expected input: a line with the node count, a line with that many node
    values, then one line per edge holding two 1-based node numbers.  Each
    edge is recorded in the ``children`` list of both endpoints.
    """
    count = int(input())
    weights = readints()
    assert count == len(weights)
    nodes = []
    for position, weight in enumerate(weights):
        nodes.append(Node(position, weight))
    edges_left = count - 1
    while edges_left > 0:
        first, second = readints()
        one = nodes[first - 1]
        other = nodes[second - 1]
        one.children.append(other)
        other.children.append(one)
        edges_left -= 1
    return nodes

def readints():
    """Read one line from standard input and return its whitespace-separated integers."""
    # str.split() with no separator already ignores leading/trailing blanks.
    return list(map(int, input().split()))

def findbestbal(nodes):
    """Return the smallest value of one extra node that balances the forest.

    The tree is rooted at ``nodes[0]`` (``rootify`` fills in ``parent``,
    ``jumpup``, ``depth`` and the subtree sum ``totalval``).  The search then
    looks for two edge cuts that leave three parts of which two share a sum
    ``T`` and the third is no larger, so a single added node of value
    ``3 * T - total`` evens them out.  Candidate values of ``T`` are tried in
    increasing order, hence the first hit is the minimum.

    NOTE(review): the reasoning below assumes every node value is positive,
    so subtree sums strictly grow towards the root -- confirm against the
    problem constraints.

    Args:
        nodes: list of ``Node`` objects; their ``children`` lists are
            rewritten in place by ``rootify``.

    Returns:
        The minimum value to add, or -1 if no pair of cuts works (a
        single-node tree always yields -1).
    """
    if len(nodes) == 1:
        return -1
    rootify(nodes[0])
    total = nodes[0].totalval
    # Placeholder standing for "an empty part" (sum 0).  It has no parent or
    # jumpup attribute, so uptototalval() reports None for it.
    dummynode = Node(None, None)
    dummynode.totalval = 0
    # sortnode[k] holds every node whose subtree sum is the k-th smallest.
    by_total = attrgetter('totalval')
    sortnode = [list(g) for _, g in groupby(sorted([dummynode] + nodes, key=by_total), by_total)]
    # First group large enough to be one of the two equal parts (3 * T >= total).
    for ihi, n in enumerate(sortnode):
        if 3 * n[0].totalval >= total:
            break
    else:
        assert False, "the root's subtree sum always satisfies 3 * T >= total"
    ilo = ihi - 1
    for ihi in range(ihi, len(sortnode)):
        hi = sortnode[ihi][0].totalval
        lo = sortnode[ilo][0].totalval
        # Small parts `lo` whose matching equal part x = (total - lo) / 2 is
        # below `hi` must be examined before `hi` itself.
        while 2 * hi + lo > total:
            if lo == 0:
                return -1
            if (total - lo) % 2 == 0:
                x = (total - lo) // 2
                for lonode in sortnode[ilo]:
                    # An ancestor with sum x + lo splits into lo (this
                    # subtree), x (ancestor minus subtree) and x (the rest).
                    if uptototalval(lonode, x + lo):
                        return x - lo
            ilo -= 1
            lo = sortnode[ilo][0].totalval
        # Two different nodes sharing the sum `hi`: cut both off.
        if len(sortnode[ihi]) > 1:
            return 3 * hi - total
        hinode = sortnode[ihi][0]
        if 2 * hi + lo == total:
            # A node with sum `lo` that does not sit below hinode gives the
            # parts hi, lo and total - hi - lo == hi.  The dummy (lo == 0)
            # lands here when the tree halves exactly.
            for lonode in sortnode[ilo]:
                if uptototalval(lonode, hi) != hinode:
                    return hi - lo
        y = total - 2 * hi
        # hinode nested in an ancestor of sum 2 * hi (parts hi, hi, y) or of
        # sum hi + y (parts hi, y, hi).
        if uptototalval(hinode, 2 * hi) or uptototalval(hinode, hi + y):
            return hi - y

def rootify(root):
    """Turn the undirected neighbour lists into a tree rooted at *root*.

    Walks the tree breadth-first and, for every node, sets ``parent``,
    ``depth`` and ``jumpup`` (the ancestor whose depth is the node's depth
    with its lowest set bit cleared) and removes the parent from the node's
    ``children`` list.  Afterwards ``totalval`` is computed bottom-up as the
    node's value plus the ``totalval`` of its children.
    """
    root.parent = None
    root.jumpup = None
    root.depth = 0
    order = []
    frontier = [root]
    while frontier:
        upcoming = []
        for parent in frontier:
            child_depth = parent.depth + 1
            # All children of `parent` share the same jump target.
            skip = uptodepth(parent, child_depth & (child_depth - 1))
            for kid in parent.children:
                kid.parent = parent
                kid.children.remove(parent)
                kid.depth = child_depth
                kid.jumpup = skip
                upcoming.append(kid)
        order.extend(frontier)
        frontier = upcoming
    # Reverse breadth-first order visits every child before its parent.
    for node in reversed(order):
        node.totalval = node.value + sum(kid.totalval for kid in node.children)
            
def uptodepth(node, depth):
    """Return the ancestor of *node* (or *node* itself) located at *depth*.

    Climbs using the ``jumpup`` shortcut whenever it does not pass the target
    depth, otherwise steps to ``parent``.  The shortcut test used to be
    ``<=``, which jumped only when the shortcut ended at or above the target:
    for an arbitrary target it could overshoot and return a shallower
    ancestor, and otherwise it degraded to a parent-by-parent walk.

    Args:
        node: object with ``depth``, ``parent`` and ``jumpup`` attributes.
        depth: target depth, expected to satisfy ``0 <= depth``.

    Returns:
        The ancestor at exactly *depth*; *node* unchanged when it is already
        at or above that depth.
    """
    while node.depth > depth:
        # Depth shrinks on the way up, so a jump is safe only while it stays
        # at or below the target.
        if node.jumpup.depth >= depth:
            node = node.jumpup
        else:
            node = node.parent
    return node
            
def uptototalval(node, totalval):
    """Return the closest ancestor-or-self of *node* whose ``totalval`` equals *totalval*.

    Climbs while the current subtree sum is too small, taking the ``jumpup``
    shortcut whenever its sum does not exceed the target and the ``parent``
    link otherwise.

    Args:
        node: object with ``totalval`` and, normally, ``parent``/``jumpup``.
        totalval: subtree sum to look for.

    Returns:
        The matching node, or None when the climb passes the target, reaches
        the root, or *node* lacks the tree links (for example a placeholder
        that only carries ``totalval``).
    """
    try:
        while node.totalval < totalval:
            if node.parent is None:
                return None
            if node.jumpup.totalval <= totalval:
                node = node.jumpup
            else:
                node = node.parent
    except AttributeError:
        # Only missing links are treated as "no such ancestor"; any other
        # error is a real bug and is allowed to surface.
        return None
    return node if node.totalval == totalval else None
    
def main():
    """Read the number of test cases, then solve and print each one."""
    ncases = int(input())
    for _ in range(ncases):
        print(findbestbal(readtree()))


# Guarded so that importing this module does not start reading stdin.
if __name__ == "__main__":
    main()


Problem solution in Java Programming.

import java.io.*;
import java.util.*;

/**
 * Balanced Forest: for each query, finds the minimum value of one extra node
 * that lets the tree be cut into three trees of equal sum, or -1.
 *
 * Approach: remove every edge in turn, then search the larger of the two
 * resulting parts for a second cut that completes a valid split.
 */
public class Solution {

    /**
     * Reads the queries from standard input and prints one answer per query.
     * Per query: n, then n coin values, then n-1 edges as 1-based node ids.
     *
     * @throws IllegalArgumentException if q or n is outside the accepted range
     */
    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        int q = scanner.nextInt();
        if (q < 1 || q > 5) {
            // The message used to claim an upper bound of 50000 while the
            // check enforces 5; it now reports the bound actually enforced.
            throw new IllegalArgumentException("1<=Q<=5");
        }
        while (q > 0) {
            int n = scanner.nextInt();
            if (n < 1 || n > 50000) {
                throw new IllegalArgumentException("1<=N<=50000");
            }
            List<Node> nodes = new ArrayList<Node>(n);
            List<GraphNode> graph = new ArrayList<GraphNode>(n);
            List<TreeNode> tree = new ArrayList<TreeNode>(n);
            for (int i = 1; i <= n; i++) {
                Node node = new Node(i, scanner.nextInt());
                nodes.add(node);
                graph.add(new GraphNode(node));
                tree.add(new TreeNode(node));
            }
            List<Edge> edges = new ArrayList<Edge>();
            for (int i = 0; i < n - 1; i++) {
                Edge edge = new Edge(scanner.nextInt(), scanner.nextInt());
                edges.add(edge);
                addEdge(graph, edge);
            }
            // Root the tree at node 1 (list index 0).
            graphsToTree(graph.get(0), tree, new HashSet<Node>());
            System.out.println(findMinCw(tree, edges, tree.get(0)));
            q--;
        }
    }

    /** Immutable-by-convention payload of a vertex: its 1-based id and its coins. */
    private static class Node {
        int nodeId;
        long coins = 0;

        public Node(int nodeId, long coins) {
            this.nodeId = nodeId;
            this.coins = coins;
        }

        public int getNodeId() {
            return nodeId;
        }

        public long getCoins() {
            return coins;
        }

        /**
         * Two nodes are equal when their ids match. This overrides
         * Object.equals(Object) (the previous equals(Node) was only an
         * overload, so hash-based collections never called it).
         */
        @Override
        public boolean equals(Object other) {
            if (this == other) {
                return true;
            }
            if (!(other instanceof Node)) {
                return false;
            }
            return ((Node) other).getNodeId() == nodeId;
        }

        @Override
        public int hashCode() {
            return nodeId;
        }

        @Override
        public String toString() {
            return String.format("nodeId:%s, coins:%d", nodeId, coins);
        }
    }

    /** Vertex of the undirected input graph with its set of neighbours. */
    private static class GraphNode {

        final Node node;
        Set<GraphNode> connectedNodes = new HashSet<GraphNode>();

        public GraphNode(Node node) {
            this.node = node;
        }

        public Node getNode() {
            return node;
        }

        public void addConnection(GraphNode addNode) {
            connectedNodes.add(addNode);
        }

        public void removeConnection(GraphNode removeNode) {
            connectedNodes.remove(removeNode);
        }

        public Set<GraphNode> getConnectedNodes() {
            return connectedNodes;
        }

        @Override
        public String toString() {
            return String.format("Node:%s, connectedNodes[%s]", node.toString(), connectedNodesToString());
        }

        private String connectedNodesToString() {
            StringBuilder builder = new StringBuilder();
            for (GraphNode node : connectedNodes) {
                builder.append("Node:").append(node).append(",");
            }
            return builder.toString();
        }
    }

    /**
     * Vertex of the rooted tree. totalCoins is kept equal to the coins of the
     * whole subtree by propagating every attach/detach up the parent chain.
     */
    private static class TreeNode {
        final Node node;
        TreeNode parentNode;
        Set<TreeNode> childNodes = new HashSet<TreeNode>();
        long totalCoins;

        public TreeNode(Node node) {
            this.node = node;
            totalCoins = node.getCoins();
        }

        public Node getNode() {
            return node;
        }

        public long getTotalCoins() {
            return totalCoins;
        }

        public Set<TreeNode> getChildNodes() {
            return childNodes;
        }

        /**
         * Sets or clears (null) the parent.
         *
         * @throws RuntimeException if a parent is already set and a new non-null one is given
         */
        public void setParent(TreeNode node) {
            if (parentNode != null && node != null) {
                throw new RuntimeException("Multiple parent is not supported. parent:" + parentNode + " current:" + this + " new parent:" + node);
            }
            parentNode = node;
        }

        /** Attaches node as a child and adds its subtree total to this node and all ancestors. */
        public void addChildNode(TreeNode node) {
            childNodes.add(node);
            totalCoins += node.getTotalCoins();
            node.setParent(this);
            if (parentNode != null) {
                parentNode.addChildCoins(node.getTotalCoins());
            }
        }

        /** Adds childCoins to this node's total and to every ancestor's total. */
        public void addChildCoins(long childCoins) {
            totalCoins += childCoins;
            if (parentNode != null) {
                parentNode.addChildCoins(childCoins);
            }
        }

        /** Detaches node and subtracts its subtree total from this node and all ancestors. */
        public void removeChildNode(TreeNode node) {
            childNodes.remove(node);
            totalCoins -= node.getTotalCoins();
            node.setParent(null);
            if (parentNode != null) {
                parentNode.removeChildCoins(node.getTotalCoins());
            }
        }

        /** Subtracts childCoins from this node's total and from every ancestor's total. */
        public void removeChildCoins(long childCoins) {
            totalCoins -= childCoins;
            if (parentNode != null) {
                parentNode.removeChildCoins(childCoins);
            }
        }

        public boolean isParentOf(TreeNode childNode) {
            return childNodes.contains(childNode);
        }

        public boolean isRoot() {
            return parentNode == null;
        }

        @Override
        public String toString() {
            return String.format("Node:%s, totalCoins:%d, parentNode:%s, childNodes:[%s]", node.toString(), totalCoins,
                (parentNode != null) ? parentNode.getNode().toString() : "NULL", childNodes.toString());
        }
    }

    /** An input edge as a pair of 1-based node ids; the pair can be swapped in place. */
    private static class Edge {
        int node1;
        int node2;

        public Edge(int node1, int node2) {
            this.node1 = node1;
            this.node2 = node2;
        }

        public int getNode1() {
            return node1;
        }

        public int getNode2() {
            return node2;
        }

        public void swapNode() {
            int tmpNode = node1;
            node1 = node2;
            node2 = tmpNode;
        }

        @Override
        public String toString() {
            return "node1:" + node1 + " node2:" + node2;
        }
    }

    /** Records the edge in both endpoints' neighbour sets. */
    private static void addEdge(List<GraphNode> nodes, Edge edge) {
        nodes.get(edge.getNode1() - 1).addConnection(nodes.get(edge.getNode2() - 1));
        nodes.get(edge.getNode2() - 1).addConnection(nodes.get(edge.getNode1() - 1));
    }

    /** Removes the edge from both endpoints' neighbour sets. */
    private static void removeEdge(List<GraphNode> nodes, Edge edge) {
        nodes.get(edge.getNode1() - 1).removeConnection(nodes.get(edge.getNode2() - 1));
        nodes.get(edge.getNode2() - 1).removeConnection(nodes.get(edge.getNode1() - 1));
    }

    /**
     * Depth-first walk from graph that attaches every not-yet-visited
     * neighbour as a child of the matching TreeNode.
     */
    private static void graphsToTree(GraphNode graph, List<TreeNode> treeNodes, Set<Node> visitedNode) {
        TreeNode treeNode = treeNodes.get(graph.getNode().getNodeId() - 1);
        visitedNode.add(graph.getNode());
        for (GraphNode connectedNode : graph.getConnectedNodes()) {
            if (!visitedNode.contains(connectedNode.getNode())) {
                treeNode.addChildNode(treeNodes.get(connectedNode.getNode().getNodeId() - 1));
                visitedNode.add(connectedNode.getNode());
                graphsToTree(connectedNode, treeNodes, visitedNode);
            }
        }
    }

    /**
     * Cuts the tree edge and returns the root of the detached subtree.
     * The edge is swapped if needed so that node1 is always the parent side,
     * which is what addTreeEdge relies on to re-attach it.
     */
    private static TreeNode removeTreeEdge(List<TreeNode> nodes, Edge edge) {
        TreeNode node1 = nodes.get(edge.getNode1() - 1);
        TreeNode node2 = nodes.get(edge.getNode2() - 1);

        if (node1.isParentOf(node2)) {
            // node1 is parent of node2
            node1.removeChildNode(node2);
            return node2;
        } else {
            node2.removeChildNode(node1);
            edge.swapNode();
            return node1;
        }
    }

    /** Re-attaches node2 of the edge under node1; rootNode is not used. */
    private static void addTreeEdge(List<TreeNode> nodes, Edge edge, TreeNode rootNode) {
        TreeNode node1 = nodes.get(edge.getNode1() - 1);
        TreeNode node2 = nodes.get(edge.getNode2() - 1);

        node1.addChildNode(node2);
    }

    /**
     * DFS below tree looking for a child subtree whose total, or whose
     * complement within searchRoot, equals expectedValue. Branches whose
     * total is not above expectedValue are pruned.
     */
    private static boolean findSubTreeWithValue(TreeNode searchRoot, TreeNode tree, long expectedValue) {
        if (searchRoot.getTotalCoins() <= expectedValue || tree.getTotalCoins() <= expectedValue) {
            return false;
        }
        for (TreeNode subTree : tree.getChildNodes()) {
            long subTreeCoins = subTree.getTotalCoins();
            long remainingCoins = searchRoot.getTotalCoins() - subTreeCoins;

            if (subTreeCoins == expectedValue || remainingCoins == expectedValue) {
                return true;
            }
            if (findSubTreeWithValue(searchRoot, subTree, expectedValue)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Tries every edge as the first cut and returns the smallest value that
     * must be added to obtain three equal trees, or -1 if none exists.
     * Every removed edge is re-attached before the next one is tried.
     */
    public static long findMinCw(List<TreeNode> nodes, List<Edge> edges, TreeNode rootNode) {
        long minCw = -1;
        for (int i = 0; i < edges.size(); i++) {
            Edge removeEdge1 = edges.get(i);
            TreeNode tree1 = removeTreeEdge(nodes, removeEdge1);

            long nodes1Coins = rootNode.getTotalCoins();
            long nodes2Coins = tree1.getTotalCoins();

            long largeSetCoins, smallSetCoins;
            TreeNode treeToSplit = null;

            if (nodes1Coins == nodes2Coins) {
                // Two equal halves: the added node forms the third tree on its own.
                long cw = nodes1Coins;
                if (minCw < 0 || cw < minCw) {
                    minCw = cw;
                }
                addTreeEdge(nodes, removeEdge1, rootNode);
                continue;
            } else if (nodes1Coins > nodes2Coins) {
                largeSetCoins = nodes1Coins;
                smallSetCoins = nodes2Coins;
                treeToSplit = rootNode;
            } else {
                largeSetCoins = nodes2Coins;
                smallSetCoins = nodes1Coins;
                treeToSplit = tree1;
            }

            long expectedCw = -1;
            // Option 1: halve the large part; the small part receives the new node.
            long expectedCw1 = -1;
            long searchValue;
            if (largeSetCoins % 2 == 0) {
                expectedCw1 = largeSetCoins / 2L - smallSetCoins;
            }
            // Option 2: split the large part into "small" and a remainder that receives the new node.
            long expectedCw2 = smallSetCoins - (largeSetCoins - smallSetCoins);

            if (expectedCw1 >= 0 && expectedCw2 >= 0) {
                expectedCw = Math.min(expectedCw1, expectedCw2);
            } else if (expectedCw1 >= 0) {
                expectedCw = expectedCw1;
            } else if (expectedCw2 >= 0) {
                expectedCw = expectedCw2;
            }

            if (expectedCw < 0 || (minCw > 0 && expectedCw > minCw)) {
                addTreeEdge(nodes, removeEdge1, rootNode);
                continue;
            }

            if (expectedCw == expectedCw1) {
                searchValue = largeSetCoins / 2L;
            } else {
                searchValue = smallSetCoins;
            }

            if (findSubTreeWithValue(treeToSplit, treeToSplit, searchValue)) {
                if (minCw < 0 || expectedCw < minCw) {
                    minCw = expectedCw;
                }
            }

            addTreeEdge(nodes, removeEdge1, rootNode);
        }
        return minCw;
    }

}


Problem solution in C++ programming.

#include <iostream>
#include <cstdio>
#include <vector>
#include <algorithm>
#include <string>
#include <set>
#include <map>
#include <queue>
#include <stack>
#include <deque>
#include <cassert>
#include <stdlib.h>

using namespace std;


typedef long long ll;

// Sentinel meaning "no valid split found yet"; solve() prints -1 when res still equals it.
const ll INF = (ll) 1e18;
// Capacity of the per-vertex arrays below.
const int N = (int) 5e4 + 10;

// Adjacency lists, 0-based vertex ids (filled and cleared in solve()).
vector<int> g[N];
// Value of each vertex as read from the input.
ll c[N];
// Sum of c over the subtree of each vertex, computed by dfs().
ll f[N];
// Smallest amount found so far that has to be added (minimised in upd()).
ll res = INF;
// Sum of all vertex values of the current query.
ll tot = 0;
// Visited marks used by dfs().
bool was[N];

// Given the sums of three parts, lowers res to (equal - third) for every way
// of choosing two equal parts whose value is not exceeded by the third.
void upd(ll a, ll b, ll c) {
    // Each row: the two parts that must match, then the remaining part.
    const ll choices[3][3] = {{a, b, c}, {a, c, b}, {b, c, a}};
    for (const auto& pick : choices) {
        const bool pairMatches = pick[0] == pick[1];
        if (pairMatches && pick[2] <= pick[0]) {
            res = min(res, pick[0] - pick[2]);
        }
    }
}

// Merges the smaller set into the larger one, deletes the smaller and returns
// the larger. Before merging, every value x of the smaller set is paired with
// values of the larger set that would complete a split (see the upd() calls).
set<ll>* unite(set<ll>* a, set<ll>* b) {
    set<ll>* small = a;
    set<ll>* large = b;
    if (small->size() > large->size()) {
        swap(small, large);
    }

    for (set<ll>::const_iterator it = small->begin(); it != small->end(); ++it) {
        const ll x = *it;
        const ll rest = tot - 2 * x;  // what remains once two parts of size x are taken
        if (large->count(rest)) {
            upd(rest, x, x);
        }
        if (large->count(x)) {
            upd(x, x, rest);
        }
        if ((tot - x) % 2 == 0) {
            const ll half = (tot - x) / 2;  // two equal parts next to a part of size x
            if (large->count(half)) {
                upd(half, x, half);
            }
        }
    }

    large->insert(small->begin(), small->end());
    delete small;
    return large;
}

// Computes f[v] (subtree sum of v) and returns a heap-allocated set holding the
// subtree sums of v and all its descendants; the caller owns the set.
// Along the way, pairs "v + one descendant" that form a valid split are
// reported to upd().
set<ll>* dfs(int v) {
    was[v] = true;
    f[v] = c[v];

    set<ll>* below = new set<ll>();
    for (size_t k = 0; k < g[v].size(); ++k) {
        const int to = g[v][k];
        if (was[to]) {
            continue;
        }
        set<ll>* fromChild = dfs(to);
        f[v] += f[to];
        below = unite(below, fromChild);
    }

    const ll own = f[v];
    const ll outside = tot - own;      // everything that is not in v's subtree
    const ll surplus = 2 * own - tot;  // own - outside
    if (own % 2 == 0 && below->count(own / 2)) {
        upd(own / 2, own / 2, outside);
    }
    if (below->count(outside)) {
        upd(outside, surplus, outside);
    }
    if (below->count(surplus)) {
        upd(surplus, outside, outside);
    }

    below->insert(own);
    return below;
}

void solve() {
    int n;
    cin >> n;
    for (int i = 0; i < N; i++) {
        was[i] = false;
        g[i].clear();
        c[i] = 0;
    }
    tot = 0;
    res = INF;
    for (int i = 0; i < n; i++) {
        cin >> c[i];
        tot += c[i];
    }
    for (int i = 0; i < n - 1; i++) {
        int x, y;
        cin >> x >> y;
        --x;
        --y;
        g[x].push_back(y);
        g[y].push_back(x);
    }
    set<ll>* s = dfs(0);
    //for (int i = 0; i < n; i++)
    //    cerr << f[i] << " ";
    //cerr << endl;
    delete s;
    if (res == INF)
        res = -1;
    cout << res << endl;
    // cerr << "----------" << endl;
}

// Reads the number of queries and answers them one by one.
int main() {
    ios_base::sync_with_stdio(0);

    int p;
    cin >> p;
    while (p != 0) {
        --p;
        solve();
    }
    return 0;
}


Problem solution in C programming.

#include <assert.h>
#include <limits.h>
#include <math.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

char* readline();
char** split_string(char*);

/* One vertex of the tree, stored in an array indexed by 0-based vertex id. */
typedef struct Node
{
    int64_t sum;     // sum of data over this vertex's whole subtree (set by sumTree)
    int64_t testSum; // working copy of sum, adjusted while part of the tree is cut away
    int data;        // value of this vertex
    int parent;      // index of the parent vertex, -1 for the root
    int *child;      // indices of all neighbours (the parent is included)
    int childCnt;    // number of entries in child
} Node;

/*
 * Fills sum (and testSum) for the vertex at index and for everything below it.
 * root is the node array; parent links must already be set so that the
 * parent entry in each neighbour list can be skipped.
 */
void sumTree(Node *root, int index)
{
    Node *self = &root[index];

    self->sum = self->data;
    for (int k = 0; k < self->childCnt; ++k)
    {
        int next = self->child[k];
        if (next != self->parent)
        {
            sumTree(root, next);
            self->sum += root[next].sum;
        }
    }
    self->testSum = self->sum;
}

/*
 * Orients the tree: records parent as the parent of vertex root, then
 * recurses into every other neighbour with root as their parent.
 */
void updateTree(Node *tree, int root, int parent)
{
    Node *self = &tree[root];

    self->parent = parent;
    for (int k = 0; k < self->childCnt; ++k)
    {
        int next = self->child[k];
        if (next != parent)
        {
            updateTree(tree, next, root);
        }
    }
}

/*
 * Searches below root for a child subtree whose sum equals targetSum, ignoring
 * the subtree that hangs off branch_root. Sets *bFound to 1 on the first hit
 * and stops exploring.
 *
 * Returns testSum directly when it is already below targetSum (nothing in this
 * subtree can match); otherwise returns data plus the sums returned by the
 * children visited so far.
 */
int64_t childSum(Node *tree, int root, int  branch_root, int64_t targetSum, bool *bFound)
{
    const Node *self = &tree[root];

    if (self->testSum < targetSum)
    {
        return self->testSum;
    }

    int64_t gathered = 0;
    int k = 0;
    while (k < self->childCnt && !*bFound)
    {
        int next = self->child[k++];
        if (next == self->parent || next == branch_root)
        {
            continue;
        }

        int64_t below = childSum(tree, next, branch_root, targetSum, bFound);
        if (below == targetSum)
        {
            *bFound = 1;
            break;
        }
        gathered += below;
    }
    return gathered + self->data;
}

// Complete the balancedForest function below.
/*
 * Returns the smallest value of one extra node that lets the tree be cut into
 * three trees of equal sum, or -1 when that is impossible.
 *
 * c_count / c       : number of vertices and their values.
 * edges_rows / edges: edges as pairs of 1-based vertex ids.
 * edges_columns     : not used.
 *
 * Every non-root vertex i stands for the cut of the edge above it. The larger
 * of the two resulting parts is then searched (childSum) for a second cut.
 * All memory allocated here is released before returning.
 */
int64_t balancedForest(int c_count, int* c, int edges_rows, int edges_columns, int** edges) {
    int i;

    // Build the neighbour lists; each edge is stored at both endpoints.
    Node *tree = (Node *)calloc(c_count, sizeof(Node));
    for (i = 0; i < c_count; i++)
    {
        tree[i].data = c[i];
        tree[i].childCnt = 0;
        tree[i].child = NULL;
        tree[i].parent = -1;
        tree[i].sum = 0;
    }
    for (i = 0; i < edges_rows; i++)
    {
        int pa = edges[i][0] - 1;
        int ch = edges[i][1] - 1;
        tree[pa].child = (int *)realloc(tree[pa].child, (1 + tree[pa].childCnt) * sizeof(int));
        tree[pa].child[tree[pa].childCnt] = ch;
        tree[pa].childCnt++;
        tree[ch].child = (int *)realloc(tree[ch].child, (1 + tree[ch].childCnt) * sizeof(int));
        tree[ch].child[tree[ch].childCnt] = pa;
        tree[ch].childCnt++;
    }

    // Root the tree at the first vertex and compute all subtree sums.
    int root = 0;
    updateTree(tree, root, -1);
    sumTree(tree, root);

    int64_t treeSum = tree[root].sum;
    int64_t minW = -1;
    for (i = 0; i < c_count; i++)
    {
        if (i == root) continue;

        int64_t sumI = tree[i].sum;      // part below the cut
        int64_t sumJ = treeSum - sumI;   // part that keeps the root

        // Two equal halves: the added node becomes the third tree by itself.
        if (sumI == sumJ)
        {
            if (minW<0 || minW>sumI) minW = sumI;
            continue;
        }

        bool bFound = 0;
        int64_t targetSum;
        int searchRoot = root;
        int branchRoot = i;
        int64_t w = 0;
        // Make sumI the smaller part; the larger part is the one searched.
        if (sumI > sumJ)
        {
            targetSum = sumI;
            sumI = sumJ;
            sumJ = targetSum;
            searchRoot = i;
            branchRoot = root;
        }
        if ((sumI << 1) < sumJ)
        {
            // The large part has to split into two halves; the small part gets the new node.
            targetSum = sumJ >> 1;
            if (sumJ - targetSum != targetSum) continue;
            w = targetSum - sumI;
        }
        else
        {
            // The large part splits into sumI and a remainder that gets the new node.
            targetSum = sumI;
            w = targetSum - (sumJ - sumI);
        }

        // Cannot beat the best answer found so far.
        if (minW >= 0 && minW < w) continue;

        // When searching the root side, take the cut subtree out of the
        // testSum of all its ancestors first.
        if (searchRoot == root)
        {
            int curr = tree[branchRoot].parent;
            int64_t branchSum = tree[branchRoot].sum;
            while (curr != -1)
            {
                tree[curr].testSum -= branchSum;
                curr = tree[curr].parent;
            }
        }

        childSum(tree, searchRoot, branchRoot, targetSum, &bFound);
        if (bFound)
        {
            if (minW == -1 || minW > w) minW = w;
        }

        // Restore testSum on the same ancestor chain.
        if (searchRoot == root)
        {
            int curr = tree[branchRoot].parent;
            while (curr != -1)
            {
                tree[curr].testSum = tree[curr].sum;
                curr = tree[curr].parent;
            }
        }
    }

    // Release the neighbour lists and the node array (previously leaked).
    for (i = 0; i < c_count; i++)
    {
        free(tree[i].child);
    }
    free(tree);

    return minW;
}

/*
 * Reads q queries from stdin and writes one answer per line to the file named
 * by the OUTPUT_PATH environment variable. Exits with EXIT_FAILURE when the
 * output file cannot be opened or a number fails to parse.
 */
int main()
{
    const char* output_path = getenv("OUTPUT_PATH");
    if (!output_path) { exit(EXIT_FAILURE); }

    FILE* fptr = fopen(output_path, "w");
    if (!fptr) { exit(EXIT_FAILURE); }

    char* q_endptr;
    char* q_str = readline();
    int q = strtol(q_str, &q_endptr, 10);

    if (q_endptr == q_str || *q_endptr != '\0') { exit(EXIT_FAILURE); }
    free(q_str);

    for (int q_itr = 0; q_itr < q; q_itr++) {
        char* n_endptr;
        char* n_str = readline();
        int n = strtol(n_str, &n_endptr, 10);

        if (n_endptr == n_str || *n_endptr != '\0') { exit(EXIT_FAILURE); }
        free(n_str);

        // Keep the line buffer: the tokens returned by split_string point into it.
        char* c_line = readline();
        char** c_temp = split_string(c_line);

        int* c = malloc(n * sizeof(int));

        for (int i = 0; i < n; i++) {
            char* c_item_endptr;
            char* c_item_str = *(c_temp + i);
            int c_item = strtol(c_item_str, &c_item_endptr, 10);

            if (c_item_endptr == c_item_str || *c_item_endptr != '\0') { exit(EXIT_FAILURE); }

            *(c + i) = c_item;
        }
        free(c_temp);
        free(c_line);

        int c_count = n;

        int** edges = malloc((n - 1) * sizeof(int*));

        for (int i = 0; i < n - 1; i++) {
            *(edges + i) = malloc(2 * (sizeof(int)));

            char* edges_line = readline();
            char** edges_item_temp = split_string(edges_line);

            for (int j = 0; j < 2; j++) {
                char* edges_item_endptr;
                char* edges_item_str = *(edges_item_temp + j);
                int edges_item = strtol(edges_item_str, &edges_item_endptr, 10);

                if (edges_item_endptr == edges_item_str || *edges_item_endptr != '\0') { exit(EXIT_FAILURE); }

                *(*(edges + i) + j) = edges_item;
            }
            free(edges_item_temp);
            free(edges_line);
        }

        int edges_rows = n - 1;
        int edges_columns = 2;

        int64_t result = balancedForest(c_count, c, edges_rows, edges_columns, edges);

        // int64_t is not guaranteed to be long long; cast to match %lld.
        fprintf(fptr, "%lld\n", (long long) result);

        for (int i = 0; i < n - 1; i++) {
            free(*(edges + i));
        }
        free(edges);
        free(c);
    }

    fclose(fptr);

    return 0;
}

/*
 * Reads one line from stdin into a heap buffer that grows as needed and
 * returns it without the trailing newline. At end of input an empty string is
 * returned. The caller owns the returned buffer.
 */
char* readline() {
    size_t alloc_length = 1024;
    size_t data_length = 0;
    char* data = malloc(alloc_length);

    while (true) {
        char* cursor = data + data_length;
        char* line = fgets(cursor, alloc_length - data_length, stdin);

        if (!line) {
            break;
        }

        data_length += strlen(cursor);

        // Stop once the buffer was not filled completely or the line ended.
        if (data_length < alloc_length - 1 || data[data_length - 1] == '\n') {
            break;
        }

        alloc_length <<= 1;

        data = realloc(data, alloc_length);
    }

    // data_length is 0 when nothing could be read; do not look at data[-1].
    if (data_length > 0 && data[data_length - 1] == '\n') {
        data[data_length - 1] = '\0';

        data = realloc(data, data_length);
    } else {
        data = realloc(data, data_length + 1);

        data[data_length] = '\0';
    }

    return data;
}

/*
 * Splits str in place on spaces (strtok) and returns a heap array of pointers
 * to the tokens inside str. Returns NULL when there are no tokens or when
 * memory runs out; in the latter case the partial array is released instead
 * of being leaked. The caller frees the array, not the tokens.
 */
char** split_string(char* str) {
    char** splits = NULL;
    int count = 0;

    for (char* token = strtok(str, " "); token; token = strtok(NULL, " ")) {
        char** grown = realloc(splits, sizeof(char*) * (count + 1));

        if (!grown) {
            // realloc leaves the old block untouched on failure.
            free(splits);
            return NULL;
        }

        splits = grown;
        splits[count++] = token;
    }

    return splits;
}


Problem solution in JavaScript programming.

'use strict';

const fs = require('fs');

// Collects all of stdin; once it ends, inputString becomes an array of
// right-trimmed lines and main() is started.
let inputString = '';
let currentLine = 0;

process.stdin.resume();
process.stdin.setEncoding('utf-8');

process.stdin.on('data', function (chunk) {
    inputString = inputString + chunk;
});

process.stdin.on('end', () => {
    const withoutTrailingSpace = inputString.replace(/\s*$/, '');
    inputString = withoutTrailingSpace
        .split('\n')
        .map(line => line.replace(/\s*$/, ''));

    main();
});

// Returns the next buffered input line and advances the cursor.
function readLine() {
    const line = inputString[currentLine];
    currentLine += 1;
    return line;
}

// Complete the balancedForest function below.
/**
 * Returns the smallest value of one extra node that lets the tree be cut into
 * three trees of equal sum, or -1 when that is impossible.
 *
 * @param {number[]} c      value of each node (node i+1 has value c[i])
 * @param {number[][]} edges pairs of 1-based node numbers
 * @returns {number}
 */
function balancedForest(c, edges) {
    const nodes = c.map(cost => ({ cost, adj: [], visited: false, solved: false }));

    for (let [a,b] of edges) {
        nodes[a-1].adj.push(b-1);
        nodes[b-1].adj.push(a-1);
    }

    // First pass: turn every node's cost into the sum of its subtree
    // (tree rooted at the first node).
    const dfs = n => {
        if (n.visited) return 0;
        n.visited = true;

        for (let a of n.adj)
            n.cost += dfs(nodes[a]);
        return n.cost;
    }

    const sum = dfs(nodes[0]);

    // `sum` doubles as the "nothing found" marker for min.
    let min = sum;
    const excsum = {}; // subtree sums of finished subtrees (disjoint from the current node)
    const incsum = {}; // subtree sums of the current node's ancestors

    const solve = n => {
        if (n.solved) return;
        n.solved = true;

        // cost_a: this subtree is one of the two equal parts.
        const cost_a = 3 * n.cost - sum;
        // cost_b: this subtree is the small part that receives the new node.
        const cost_b = (sum - n.cost) / 2 - n.cost;

        // can split in two equal subtrees?
        if (sum % 2 === 0 && n.cost === (sum / 2)) min = Math.min(min, sum / 2);

        if (cost_a >= 0 && (
            excsum[n.cost] // another subtree with equal cost?
            || excsum[sum - 2 * n.cost] // another subtree with 1/3 cost
            || incsum[sum - n.cost] // edge to remove
            // An ancestor holding exactly twice this sum splits into two equal
            // parts, leaving sum - 2 * cost. This case was missing before.
            || incsum[2 * n.cost])
        ) min = Math.min(min, cost_a);

        if (cost_b >= 0 && (sum - n.cost) % 2 === 0) {
            if (excsum[(sum - n.cost) / 2] || incsum[(sum + n.cost) / 2])
                min = Math.min(min, cost_b);
        }

        incsum[n.cost] = true;
        for (let a of n.adj) solve(nodes[a]);
        delete incsum[n.cost];
        excsum[n.cost] = true;
    }

    solve(nodes[0]);
    return min === sum ? -1 : min;
}

// Reads every query from the buffered input, solves it and writes one answer
// per line to the file named by the OUTPUT_PATH environment variable.
function main() {
    const out = fs.createWriteStream(process.env.OUTPUT_PATH);
    const readNumbers = () => readLine().split(' ').map(field => parseInt(field, 10));

    const queryCount = parseInt(readLine(), 10);
    let handled = 0;

    while (handled < queryCount) {
        const n = parseInt(readLine(), 10);
        const c = readNumbers();

        const edges = Array(n - 1);
        for (let row = 0; row < n - 1; row += 1) {
            edges[row] = readNumbers();
        }

        const answer = balancedForest(c, edges);
        out.write(answer + '\n');

        handled += 1;
    }

    out.end();
}


Post a Comment

0 Comments