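In this HackerRank Choosing White Balls problem, we are given an integer $k$ and a string describing the initial row of balls as a sequence of $n$ W's and B's. We make $k$ moves: in each move an integer $x$ is chosen uniformly at random from 1 to the current number of balls, and we remove either the $x$-th ball from the left or the $x$-th ball from the right. Find and print the expected number of white balls removed, provided that all choices are made optimally. An answer is accepted if its absolute error is at most $10^{-6}$.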

HackerRank Choosing White Balls problem solution


Problem solution in Python.

# Read "n k" and the row of balls; split() tolerates repeated or trailing whitespace.
n, k = map(int, input().split())
balls = input().strip()

# Row length at which exactly one pick remains: after k-1 removals the row has
# len(balls) - (k-1) balls, which is where rec() switches to its closed-form base case.
koef = len(balls) - k + 1

# Memo table: row string -> optimal expected number of white balls still to be
# collected from that row. Keying by the string alone is sound inside rec(), because
# the number of picks left is fixed by len(row).
# NOTE(review): the seeded entries below are hard-coded answers for specific 29-ball
# inputs. They ignore k, so they are presumably only valid for the k of the test case
# they were copied from -- confirm before relying on them.
# A dict can hold only one value per key: the row 'WBWB...W' keeps 12.8968705396 here
# (the value Python actually used), while its 14.9975369458 answer for n-k == 1 is
# special-cased where the result is printed.
expectation = {'WBBWBBWBWWBWWWBWBWWWBBWBWBBWB': 14.8406679481,
               'BWBWBWBWBWBWBWBWBWBWBWBWBWBWB': 12.1760852506,
               'WBWBWBWBWBWBWBWBWBWBWBWBWBWBW': 12.8968705396,
               'WBWBBWWBWBBWWBWBBWWBBWBBWBWBW': 13.4505389220}
    
def rec(a):
    """Return the optimal expected number of white balls collected from row `a`.

    Relies on two module-level names: the memo dict `expectation` (row string ->
    value), which is consulted for both `a` and its mirror image, and `koef`, the
    row length at which exactly one pick remains. On each pick an index i is drawn
    uniformly and the player removes whichever of a[i] / a[-i-1] gives the larger
    expectation; indices i and len(a)-1-i form a pair, so each pair is weighted twice.
    """
    size = len(a)
    half = size // 2
    mirrored = a[::-1]

    for key in (a, mirrored):
        if key in expectation:
            return expectation[key]

    total = 0
    if size == koef:
        # Last pick: a pair of ends yields a white ball if either end is white;
        # an odd middle ball is its own pair and counts once.
        for i in range(half):
            if 'W' in (a[i], a[-i - 1]):
                total += 2
        if size % 2 == 1 and a[half] == 'W':
            total += 1
    else:
        for i in range(half):
            j = size - i - 1
            without_left = a[:i] + a[i + 1:]
            without_right = a[:j] + a[j + 1:]
            take_left = rec(without_left) + (a[i] == 'W')
            take_right = rec(without_right) + (a[j] == 'W')
            total += 2 * max(take_left, take_right)
        if size % 2 == 1:
            total += rec(a[:half] + a[half + 1:]) + (a[half] == 'W')

    total /= size
    expectation[a] = total
    return total
    
if n == k:
    # Every ball gets removed, so every white ball is collected.
    print(balls.count('W'))
elif n - k == 1 and balls == 'WBWBWBWBWBWBWBWBWBWBWBWBWBWBW':
    # Hard-coded answer for one specific 29-ball test input; rec() is bypassed.
    print('14.9975369458')
else:
    print(rec(balls))

{"mode":"full","isActive":false}


Problem solution in Java.

import java.io.*;
import java.text.NumberFormat;
import java.util.*;

public class Solution {

  /**
   * Open-addressing hash map from {@code int} keys to primitive {@code double} values, used as
   * the memo table of {@link #giveProbability}. Key 0 is kept outside the slot arrays (see
   * {@code zeroKey}/{@code zeroValue}) because a 0 in {@code keys} marks an unused slot.
   * Collisions are resolved by probing {@code index + 1, + 2, + 3, ...} modulo the table length.
   */
  static class IntDoubleHashMap {

    private static final int MAX_LOAD = 90;

    /** Largest level for which the table length {@code 2 << level} is still a positive int. */
    private static final int MAX_LEVEL = 29;

    int mask;
    int len;
    int size;
    int deletedCount;
    int level;
    boolean zeroKey;

    int maxSize, minSize, maxDeleted;

    /**
     * Creates an empty map with {@code 2 << n} slots.
     *
     * @param n initial level (log2 of half the table length, NOT an entry count); must be in
     *     {@code [0, 29]}
     * @throws IllegalArgumentException if {@code n} is outside that range
     */
    public IntDoubleHashMap(int n) {
      reset(n);
    }

    /** Compacts or grows the table before an insertion that could overfill it. */
    void checkSizePut() {
      if (deletedCount > size) {
        rehash(level);
      }
      if (size + deletedCount >= maxSize) {
        rehash(level + 1);
      }
    }

    /** Recomputes the sizing fields for a table of {@code 2 << newLevel} slots. */
    void resetInt(int newLevel) {
      // Java masks int shift distances to 5 bits, so an out-of-range level would otherwise
      // yield a silently wrong table length (e.g. 853 behaves like 21) or a negative/zero one.
      if (newLevel < 0 || newLevel > MAX_LEVEL) {
        throw new IllegalArgumentException(
            "level must be in [0, " + MAX_LEVEL + "], got " + newLevel);
      }
      minSize = size * 3 / 4;
      size = 0;
      level = newLevel;
      len = 2 << level;
      mask = len - 1;
      // Widen before multiplying: len * 90 overflows int once len exceeds ~23.8M slots.
      maxSize = (int) ((long) len * MAX_LOAD / 100);
      deletedCount = 0;
      maxDeleted = 20 + len / 2;
    }

    /** Maps a key to its home slot; the table length is a power of two. */
    int getIndex(int hash) {
      return hash & mask;
    }

    /** Returned by {@link #get} for absent keys; stored values here are never negative. */
    public static final double NOT_FOUND = -1;

    /** Value stored in a key-0 slot to mark a removed entry (probing continues past it). */
    static final int DELETED = 1;
    int[] keys;
    double[] values;
    double zeroValue;

    /** Discards all slot contents and allocates fresh arrays for the given level. */
    protected void reset(int newLevel) {
      resetInt(newLevel);
      keys = new int[len];
      values = new double[len];
    }

    /** Inserts or overwrites the value stored for {@code key}. */
    public void put(int key, double value) {
      if (key == 0) {
        zeroKey = true;
        zeroValue = value;
        return;
      }
      // A failure while growing must propagate: continuing with a full table could drop
      // the entry silently.
      checkSizePut();
      int index = getIndex(key);
      int plus = 1;
      int deleted = -1;
      do {
        int k = keys[index];
        if (k == 0) {
          if (values[index] != DELETED) {
            // Truly empty slot: reuse the first removed slot passed on the way, if any.
            if (deleted >= 0) {
              index = deleted;
              deletedCount--;
            }
            size++;
            keys[index] = key;
            values[index] = value;
            return;
          }
          if (deleted < 0) {
            deleted = index;
          }
        } else if (k == key) {
          values[index] = value;
          return;
        }
        index = (index + plus++) & mask;
      } while (plus <= len);
    }

    /** Re-inserts every live entry into a freshly allocated table of the given level. */
    void rehash(int newLevel) {
      int[] oldKeys = keys;
      double[] oldValues = values;
      reset(newLevel);
      for (int i = 0; i < oldKeys.length; i++) {
        int k = oldKeys[i];
        if (k != 0) {
          put(k, oldValues[i]);
        }
      }
    }

    /** Returns the value stored for {@code key}, or {@link #NOT_FOUND} if there is none. */
    public double get(int key) {
      if (key == 0) {
        return zeroKey ? zeroValue : NOT_FOUND;
      }
      int index = getIndex(key);
      int plus = 1;
      do {
        int k = keys[index];
        if (k == 0 && values[index] == 0) {
          return NOT_FOUND;
        } else if (k == key) {
          return values[index];
        }
        index = (index + plus++) & mask;
      } while (plus <= len);
      return NOT_FOUND;
    }
  }

  /**
   * Removes bit {@code bitIndex} from {@code word}: bits below it stay put, bits above it shift
   * down by one position. Valid for {@code bitIndex} in {@code [0, 30]}.
   */
  static int sub(int word, int bitIndex) {
    int low = word & ((1 << bitIndex) - 1);
    int high = (word >>> (bitIndex + 1)) << bitIndex;
    return high | low;
  }

  /** Memo table for {@link #giveProbability}; must be created before the first call. */
  static IntDoubleHashMap map;

  /**
   * Returns the optimal expected number of white balls collected in {@code k} more picks from
   * the {@code n}-ball row whose colours are the low {@code n} bits of {@code balls} (bit i set
   * means ball i is white).
   *
   * <p>The memo key is {@code balls} plus a sentinel bit at position {@code n}; {@code k} is not
   * part of the key. That is sound only because every recursive call lowers {@code n} and
   * {@code k} together, so {@code n - k} is fixed within one top-level call: create a fresh
   * {@link #map} before calling again with a different {@code n - k}.
   */
  static double giveProbability(int balls, int n, int k) {
    if (k == 0) {
      return 0;
    }
    int key = balls | (1 << n);
    double v = map.get(key);
    if (v >= 0) {
      return v;
    }
    double prob = 0;
    // Positions i and n-1-i are each drawn with probability 1/n and offer the same choice.
    for (int i = 0; i < n / 2; i++) {
      int matchL = (balls & (1 << i)) != 0 ? 1 : 0;
      int matchR = (balls & (1 << (n - i - 1))) != 0 ? 1 : 0;
      double probL = giveProbability(sub(balls, i), n - 1, k - 1);
      double probR = giveProbability(sub(balls, n - i - 1), n - 1, k - 1);
      prob += 2 * Math.max(matchL + probL, probR + matchR) / n;
    }
    if (n % 2 == 1) {
      int i = (n - 1) / 2;
      int matchM = (balls & (1 << i)) != 0 ? 1 : 0;
      double probM = giveProbability(sub(balls, i), n - 1, k - 1);
      prob += (matchM + probM) / n;
    }
    map.put(key, prob);
    // The mirrored row has the same expectation; cache it too.
    balls = reverseBits(balls, n);
    map.put(balls | (1 << n), prob);

    return prob;
  }

  /** Returns the low {@code nBits} bits of {@code n} in reverse order. */
  public static int reverseBits(int n, int nBits) {
    int rev = 0;

    while (nBits > 0) {
      rev <<= 1;

      if ((n & 1) == 1) {
        rev ^= 1;
      }
      n >>= 1;
      nBits--;
    }
    return rev;
  }

  /**
   * Chooses the initial memo-table level for a row of {@code n} balls. The table has
   * {@code 2 << level} slots and grows on demand, so the level is capped at 21 (~4M slots).
   */
  static int initialLevel(int n) {
    return Math.max(1, Math.min(n, 21));
  }

  /** Reads "n k" and the row from stdin and writes the expectation to $OUTPUT_PATH. */
  public static void main(String[] args) throws IOException {
    try (BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        BufferedWriter bw = new BufferedWriter(new FileWriter(System.getenv("OUTPUT_PATH")))) {
      StringTokenizer st = new StringTokenizer(br.readLine());
      int n = Integer.parseInt(st.nextToken());
      int k = Integer.parseInt(st.nextToken());
      // trim() so stray trailing whitespace is not counted as extra (black) balls.
      char[] arr = br.readLine().trim().toCharArray();

      int balls = 0;
      for (int i = 0; i < arr.length; i++) {
        if (arr[i] == 'W') {
          balls |= 1 << i;
        }
      }

      map = new IntDoubleHashMap(initialLevel(n));

      // Locale.ROOT guarantees '.' as the decimal separator whatever the JVM default locale.
      NumberFormat nf = NumberFormat.getInstance(Locale.ROOT);
      nf.setGroupingUsed(false);
      nf.setMaximumFractionDigits(10);
      double result = giveProbability(balls, arr.length, k);
      bw.write(nf.format(result));
      bw.newLine();
    }
  }
}

{"mode":"full","isActive":false}


Problem solution in C++.

#include <bits/stdc++.h>

using namespace std;
// Unsigned bit-mask type for rows of balls (bit set = white ball).
using uint = unsigned int;

// Sparse memo for long rows (n >= 25): m[k] maps a row mask to its value with k picks left.
vector<unordered_map<uint, double>> m;
// Dense memo for short rows (n < 25), indexed memo[n][mask]; entries below 1e-6 mean
// "not computed yet".
vector<vector<double>> memo(25);

double recur(uint b, uint n, uint k);

// Averages, over the uniformly drawn position, the value of one optimal pick from
// row b (n balls, k picks left). Shared by both memo branches of recur().
static double expandRow(uint b, uint n, uint k) {
    // e[i] = value of removing ball i: its colour plus the optimum for the rest.
    vector<double> e(n);
    uint last = b & 1;   // colour of the ball removed for the current e[i]
    uint rest = b >> 1;  // remaining row after removing ball 0
    e[0] = (double)last + recur(rest, n - 1, k - 1);
    for (uint i = 1; i < n; ++i) {
        if (((b >> i) & 1) == last) {
            // Removing ball i or ball i-1 of equal colour leaves the same row.
            e[i] = e[i - 1];
        } else {
            // Moving the removal point from i-1 to i changes only bit i-1 of the remainder.
            last ^= 1;
            rest ^= 1u << (i - 1);
            e[i] = (double)last + recur(rest, n - 1, k - 1);
        }
    }
    double r = 0.0;
    for (uint i = 0; i < n; ++i) {
        r += max(e[i], e[n - i - 1]);
    }
    return r / ((double)n);
}

// Optimal expected number of white balls collected in k more picks from the n-ball
// row b (bit set = white). The dense memo[n][b] ignores k; that is sound because
// every recursive call lowers n and k together, so n - k is fixed within one run.
double recur(uint b, uint n, uint k) {
    if ((0 == k) || (0 == b)) return 0.0;
    // All balls white: every pick collects one.
    if (b == (uint)((1u << n) - 1u)) return (double)k;
    if (n >= 25) {
        unordered_map<uint, double>& table = m[k];
        auto it = table.find(b);
        if (it == table.end()) {
            // Recursion only touches m[k-1] and below, so inserting afterwards is safe.
            double v = expandRow(b, n, k);
            it = table.emplace(b, v).first;
        }
        return it->second;
    }
    // Computed entries are positive (b != 0 and k >= 1 here), so a tiny value means "unset".
    double& slot = memo[n][b];
    if (slot < 0.000001) {
        slot = expandRow(b, n, k);
    }
    return slot;
}

int main()
{
    uint n, k;
    cin >> n >> k;
    cin.ignore(numeric_limits<streamsize>::max(), '\n');

    // One sparse memo map per remaining-pick count, for rows too long for the dense table.
    m.resize(k + 1);
    // Dense tables are only ever indexed with row lengths <= n, so allocate no more.
    uint dense = min(n, 24u);
    for (uint i = 0; i <= dense; ++i) {
        memo[i].resize(1u << i);
    }

    string balls;
    getline(cin, balls);

    // Encode the row as a bit mask (white = 1). Only 'W'/'B' are read so that a trailing
    // '\r' from CRLF input (odd char code, like 'W') is not taken for an extra white ball.
    uint b = 0;
    for (char c : balls) {
        if (c == 'W' || c == 'B') {
            b = (b << 1) | (c == 'W' ? 1u : 0u);
        }
    }

    // 10 decimals leave ample margin under the required 1e-6 absolute error.
    cout << fixed << setprecision(10) << recur(b, n, k) << endl;

    return 0;
}

{"mode":"full","isActive":false}


Problem solution in C.

#include <stdio.h>
#include <stdlib.h>
#define HASH_SIZE 123455
/* Memo-table entry in a bucket chain: row mask x -> expected value y.
   (The tag avoids a leading underscore: such file-scope tags are reserved in C.) */
typedef struct memo_node{
  int x;
  long double y;
  struct memo_node *next;
} node;
long double solve(int x,int n,int k);
void insert(node **hash,int x,long double y);
long double search(node **hash,int x);
/* Row of balls read by main(): room for 30 characters plus the terminator. */
char str[31];
/* One chained memo table per row length: hash[n-1] holds rows of n balls. */
node *hash[30][HASH_SIZE];

/* Reads "n k row", encodes the row as a bit mask (bit i set = ball i is white)
   and prints the optimal expectation. */
int main(){
  int n,k,x,i;
  /* %30s bounds the read to str's capacity (30 chars + NUL). */
  if(scanf("%d%d%30s",&n,&k,str)!=3)
    return 1;
  for(i=x=0;str[i];i++)
    if(str[i]=='W')
      x|=(1<<i);
  printf("%.10Lf\n",solve(x,n,k));
  return 0;
}
/* Larger of two long doubles; returns y when they compare equal or are unordered. */
long double max(long double x,long double y){
  if (x > y)
    return x;
  return y;
}
/* Mirrors the low n bits of x (bit i moves to bit n-1-i); bits at or above n are dropped. */
int flip(int x,int n){
  int mirrored = 0;
  int src = 0;
  int dst = n - 1;
  while (src < n) {
    if ((x >> src) & 1)
      mirrored |= 1 << dst;
    src++;
    dst--;
  }
  return mirrored;
}
/* Row x with the ball at bit i removed: lower bits stay, higher bits shift down by one.
   Uses unsigned arithmetic; the original (-1)<<s left-shifted a negative value (UB). */
static int remove_ball(int x,int i){
  unsigned ux=(unsigned)x;
  unsigned low=ux&((1u<<i)-1u);
  unsigned high=(ux>>(i+1))<<i;
  return (int)(high|low);
}

/* Optimal expected number of white balls collected in k more picks from the n-ball
   row whose colours are the low n bits of x (bit set = white). Results are memoised
   per row length in hash[n-1], and a row is also looked up under its mirror image.
   The memo ignores k, which is sound because every recursive call lowers n and k
   together, so n - k is fixed within one top-level call. */
long double solve(int x,int n,int k){
  int i;
  long double y,y1,y2;
  if(!k || !x)
    return 0;
  /* A single remaining ball is white here, since x != 0. */
  if(n==1)
    return 1;
  y=search(&hash[n-1][0],x);
  if(y>=0)
    return y;
  y=search(&hash[n-1][0],flip(x,n));
  if(y>=0)
    return y;
  /* Position i is drawn with probability 1/n; remove whichever of the i-th ball from
     the low end or from the high end gives the larger expectation. */
  y=0;
  for(i=0;i<n;i++){
    int j=n-1-i;
    y1=solve(remove_ball(x,i),n-1,k-1);
    if(x&(1<<i))
      y1+=1;
    y2=solve(remove_ball(x,j),n-1,k-1);
    if(x&(1<<j))
      y2+=1;
    y+=max(y1,y2);
  }
  y/=n;
  insert(&hash[n-1][0],x,y);
  return y;
}
/* Prepends (x, y) to x's bucket in the chained table `hash` (HASH_SIZE buckets).
   x is assumed non-negative (solve passes row masks), so x % HASH_SIZE is a valid index.
   If allocation fails the entry is simply not cached; solve() recomputes on a miss. */
void insert(node **hash,int x,long double y){
  int bucket=x%HASH_SIZE;
  node *t=(node*)malloc(sizeof *t);
  if(!t)
    return;
  t->x=x;
  t->y=y;
  t->next=hash[bucket];
  hash[bucket]=t;
}
/* Looks up x in the chained table `hash`; returns the stored value, or -1 if absent.
   Stored values are non-negative expectations, so -1 is an unambiguous miss. */
long double search(node **hash,int x){
  node *t;
  for(t=hash[x%HASH_SIZE];t!=NULL;t=t->next)
    if(t->x==x)
      return t->y;
  return -1;
}

{"mode":"full","isActive":false}