In the HackerRank Stone Game problem, Alice and Bob play the game of Nim with $n$ piles of stones of sizes $p_0, p_1, \ldots, p_{n-1}$. If Alice moves first, she loses if and only if the xor sum (also called the Nim sum) $p_0 \oplus p_1 \oplus \cdots \oplus p_{n-1}$ is zero.

Since Bob already knows who will win (assuming optimal play), he decides to cheat by removing some stones from some of the piles before the game starts. However, to reduce the risk of suspicion, he must leave at least one pile unchanged. Your task is to count the number of ways Bob can remove stones so that Alice is forced to lose. In other words, count the tuples $(q_0, \ldots, q_{n-1})$ with $0 \le q_i \le p_i$, $q_0 \oplus \cdots \oplus q_{n-1} = 0$, and $q_i = p_i$ for at least one $i$. Since this number can be very large, output it modulo $10^9 + 7$. Assume that both players play optimally.

HackerRank Stone Game problem solution


Problem solution in Python.

import operator as op
import functools as ft
from sys import stderr
# Every count in this solution is reported modulo 10^9 + 7.
MOD = 10**9 + 7

def readcase():
    """Read one test case from standard input.

    The first line holds the number of piles and the second line the pile
    sizes separated by whitespace.  Returns the pile sizes as a list of ints;
    an AssertionError signals that the declared count and the number of sizes
    disagree.
    """
    count = int(input())
    sizes = list(map(int, input().split()))
    assert count == len(sizes)
    return sizes

def numsolns(piles):
    """Count the ways to shrink the piles to a zero Nim sum while leaving at
    least one pile unchanged, modulo MOD.

    Pile i may end at any size in [0, piles[i]].  The answer is the number of
    zero-xor choices minus those in which every pile is strictly reduced,
    i.e. the zero-xor choices with sizes in [0, piles[i] - 1].
    """
    if any(pile == 0 for pile in piles):
        # An empty pile cannot be reduced, so it always stays unchanged and
        # no zero-xor choice is excluded.  The reduced list built below
        # would drop it as if it could be shrunk to size 0, which is wrong
        # for a pile that is already empty.
        return numunrestrictedsolns(piles) % MOD
    # A pile of size 1 can only be reduced to 0, which never affects the
    # xor, so it is left out of the strictly-reduced count.
    reduced = [pile - 1 for pile in piles if pile > 1]
    return (numunrestrictedsolns(piles) - numunrestrictedsolns(reduced)) % MOD

def numunrestrictedsolns(piles, MOD=MOD):
    """Count tuples q with 0 <= q[i] <= piles[i] and q[0] ^ q[1] ^ ... == 0, modulo MOD.

    Every q != piles is classified by the highest bit b at which some q[i]
    differs from piles[i] (there q[i] has a 0 where piles[i] has a 1, and all
    bits above b agree) and by the first index i that clears bit b.  The bits
    of q[i] below b are then unconstrained, so once every other pile is fixed
    exactly one choice of them zeroes the xor.  Each (b, i) class is therefore
    a product of per-pile choice counts, restricted by a parity condition on
    bit b.  The tuple q == piles is counted separately when the piles' xor is 0.
    """
    if len(piles) == 0:
        return 1
    xorall = ft.reduce(op.xor, piles)
    # Highest bit present in any pile: no pile can be cleared above it.
    leftmost = ft.reduce(op.or_, piles).bit_length() - 1
    # Bits above b are left untouched, so they must already xor to zero;
    # b therefore cannot lie below xorall's highest set bit.
    rightmost = max(0, xorall.bit_length() - 1)
    ans = 0
    for first1 in range(rightmost, leftmost+1):
        # premult: ways to choose the piles before index i.  Those piles match
        # their own bits >= first1 and take any lower part not exceeding theirs.
        premult = 1
        matchbit = 1 << first1
        for i, bigalt in enumerate(piles):
            if bigalt & matchbit != 0:
                # Pile i is the first to clear bit first1.  even/odd count the
                # choices for the later piles by the parity of how many of
                # them also clear that bit.
                even = 1
                odd = 0
                for pile in piles[i+1:]:
                    # ~-matchbit == matchbit - 1, the mask of bits below first1.
                    # Matching the pile on bits >= first1 leaves
                    # (pile & mask) + 1 choices for the lower part.
                    neweven = (1 + (pile & ~-matchbit)) * even
                    newodd = (1 + (pile & ~-matchbit)) * odd
                    if pile & matchbit != 0:
                        # Clearing bit first1 frees all matchbit lower values
                        # and flips the parity.
                        neweven += matchbit * odd
                        newodd += matchbit * even
                    even, odd = neweven % MOD, newodd % MOD
                # Bit first1 of the final xor is xorall's bit, flipped once by
                # pile i and once per later clearing pile; it must end at 0.
                ans += (even if xorall & matchbit != 0 else odd) * premult % MOD
            premult = (premult * ((bigalt & ~-matchbit) + 1)) % MOD
    if xorall == 0:
        # The untouched tuple q == piles.
        ans += 1
    return ans % MOD

# Script entry: read the single test case and print the number of ways.
answer = numsolns(readcase())
print(answer)


Problem solution in Java.

import java.io.*;
import java.math.*;
import java.text.*;
import java.util.*;
import java.util.regex.*;

public class Solution {

  static final long MODULO = 1_000_000_007;

  /** Number of bit positions examined; every bound must be a non-negative {@code int}. */
  private static final int BITS = 31;

  /**
   * Exclusive per-pile upper bounds read by {@link #solve(int, int)}, 1-indexed (index 0 unused).
   * Retained for callers that fill it directly; {@link #countWays(int[])} does not use it.
   */
  static int[] mi;

  /**
   * Counts tuples {@code (q_1, ..., q_n)} with {@code 0 <= q_i < mi[i]} whose xor equals {@code
   * k}, modulo {@link #MODULO}, reading the bounds from the shared {@link #mi} array.
   */
  static int solve(int n, int k) {
    return solve(mi, n, k);
  }

  /**
   * Counts tuples {@code (q_1, ..., q_n)} with {@code 0 <= q_i < limits[i]} whose xor equals
   * {@code k}, modulo {@link #MODULO}.
   *
   * <p>Each {@code q_i < limits[i]} agrees with {@code limits[i]} above some bit {@code b} where
   * {@code limits[i]} has a 1 and {@code q_i} has a 0. Below {@code b}, {@code q_i} is free
   * ({@code 2^b} choices). {@code dp[i][top][par]} counts choices for the first {@code i} piles,
   * where {@code top} is the largest such {@code b} seen so far and {@code par} is bit {@code top}
   * of their xor. The pile owning {@code top} absorbs the low bits of the xor, so its free
   * choices are multiplied in only when a pile with a larger {@code b} replaces it.
   *
   * @param limits 1-indexed exclusive bounds (index 0 ignored), each non-negative
   * @param n number of piles
   * @param k required xor of the chosen values
   * @return the count modulo {@link #MODULO}
   */
  static int solve(int[] limits, int n, int k) {
    // prefixXor[i][b] = bit b (masked: 0 or 1 << b) of limits[1] ^ ... ^ limits[i].
    int[][] prefixXor = new int[n + 1][BITS + 1];
    for (int i = 1; i <= n; i++) {
      for (int b = 0; b < BITS; b++) {
        prefixXor[i][b] = prefixXor[i - 1][b] ^ (limits[i] & (1 << b));
      }
    }

    // highMatches[b]: the untouched bits >= b of all limits already xor to k's bits.
    boolean[] highMatches = new boolean[BITS + 1];
    highMatches[BITS] = true;
    for (int b = BITS - 1; b >= 0; b--) {
      highMatches[b] = highMatches[b + 1] && prefixXor[n][b] == (k & (1 << b));
    }

    int[][][] dp = new int[n + 1][BITS + 1][2];
    dp[0][0][0] = 1;
    for (int i = 1; i <= n; i++) {
      for (int top = 0; top < BITS; top++) {
        for (int par = 0; par < 2; par++) {
          long ways = dp[i - 1][top][par];
          if (ways == 0) {
            continue;
          }
          for (int b = 0; b < BITS; b++) {
            if ((limits[i] & (1 << b)) == 0) {
              continue;
            }
            int freeBits;
            int nextTop;
            int nextPar;
            if (b > top) {
              // Pile i becomes the absorbing pile; the previous owner's low bits are now free.
              freeBits = top;
              nextTop = b;
              nextPar = prefixXor[i - 1][b] != 0 ? 1 : 0;
            } else if (b == top) {
              // q_i has a 0 at bit top, so the parity there is unchanged.
              freeBits = b;
              nextTop = top;
              nextPar = par;
            } else {
              // q_i copies limits[i] at bit top.
              freeBits = b;
              nextTop = top;
              nextPar = par ^ ((limits[i] >>> top) & 1);
            }
            // ways < MODULO and 2^freeBits <= 2^30, so the product fits in a long.
            dp[i][nextTop][nextPar] =
                (int) ((dp[i][nextTop][nextPar] + ways * (1L << freeBits)) % MODULO);
          }
        }
      }
    }

    long result = 0;
    for (int b = BITS - 1; b >= 0 && highMatches[b + 1]; b--) {
      result = (result + dp[n][b][(k >>> b) & 1]) % MODULO;
    }
    return (int) result;
  }

  /**
   * Counts the ways to shrink each pile {@code i} to a size in {@code [0, piles[i]]} so that the
   * Nim sum becomes zero while at least one pile keeps its size, modulo {@link #MODULO}.
   *
   * @param piles pile sizes, 0-indexed
   * @return the number of ways modulo {@link #MODULO}
   * @throws IllegalArgumentException if a pile is negative or equals {@link Integer#MAX_VALUE}
   */
  static int countWays(int[] piles) {
    int n = piles.length;
    int[] limits = new int[n + 1];
    for (int i = 0; i < n; i++) {
      if (piles[i] < 0 || piles[i] == Integer.MAX_VALUE) {
        throw new IllegalArgumentException("pile " + i + " out of range: " + piles[i]);
      }
      limits[i + 1] = piles[i] + 1; // q_i <= p_i  <=>  q_i < p_i + 1
    }
    long all = solve(limits, n, 0);
    for (int i = 0; i < n; i++) {
      limits[i + 1] = piles[i]; // every pile strictly reduced
    }
    long allReduced = solve(limits, n, 0);
    return (int) ((all - allReduced + MODULO) % MODULO);
  }

  /** Reads whitespace-separated ints, continuing onto further lines as needed. */
  private static final class IntReader {
    private final BufferedReader in;
    private StringTokenizer tokens = new StringTokenizer("");

    IntReader(BufferedReader in) {
      this.in = in;
    }

    int nextInt() throws IOException {
      while (!tokens.hasMoreTokens()) {
        String line = in.readLine();
        if (line == null) {
          throw new EOFException("unexpected end of input");
        }
        tokens = new StringTokenizer(line);
      }
      return Integer.parseInt(tokens.nextToken());
    }
  }

  /** Opens the file named by OUTPUT_PATH when it is set, otherwise standard output. */
  private static BufferedWriter openOutput() throws IOException {
    String path = System.getenv("OUTPUT_PATH");
    return new BufferedWriter(
        path == null ? new OutputStreamWriter(System.out) : new FileWriter(path));
  }

  public static void main(String[] args) throws IOException {
    try (BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        BufferedWriter bw = openOutput()) {
      IntReader reader = new IntReader(br);
      int n = reader.nextInt();
      if (n < 0) {
        throw new IllegalArgumentException("negative pile count: " + n);
      }
      int[] piles = new int[n];
      for (int i = 0; i < n; i++) {
        piles[i] = reader.nextInt();
      }
      bw.write(String.valueOf(countWays(piles)));
      bw.newLine();
    }
  }
}


Problem solution in C++.

#include<stdio.h>
#include<string.h>

typedef long long LL;        // wide type for modular products
const int MODULO = 1000000007;
const int N = 120;           // capacity of the per-pile arrays (piles use indices 1..N-1)
int dp[N][32][2];            // dp[pile][highest reduced bit][parity of that bit in the xor]
int vx[N][32];               // vx[i][b] == (mi[1] ^ ... ^ mi[i]) & (1 << b)
int valid[32];               // valid[b]: xor of all mi agrees with k on every bit >= b
int mi[N];                   // 1-indexed exclusive upper bounds, one per pile

int solve(int n,int k) {
    for(int i = 0;i <= 31; i++) vx[0][i] = 0;

        memset(dp,0,sizeof(dp));
        dp[0][0][0] = 1;

        for(int i = 1; i <= n; i++)
            for(int j = 0; j < 31; j++)
                vx[i][j] = (vx[i-1][j] ^ (mi[i]&(1<<j)));

        valid[31] = 1;

        for(int i = 30; i >= 0; i--)
            valid[i] = valid[i+1] && (vx[n][i] == (k & (1<<i)));

        for(int i = 1; i <= n; i++)
            for(int j = 0; j < 31; j++)
                for(int kj = 0; kj < 2; kj++) {
                    if (dp[i-1][j][kj] == 0) continue;

                    for(int k = 0; k < 31; k++)
                        if (mi[i] & (1<<(k))) {
                            int small, tmpj, tmpkj;

                            if (k > j) {
                                small = j;
                                tmpj = k;
                                tmpkj = (vx[i-1][k] ? 1 : 0);
                            } else {
                                small = k;
                                tmpj = j;

                                if (k == j)
                                    tmpkj = kj;
                                else
                                    tmpkj = kj^((mi[i] & (1<<j)) ? 1 : 0);
                            }

                            dp[i][tmpj][tmpkj] = (dp[i][tmpj][tmpkj] + ((LL) dp[i-1][j][kj]) * (1<<small)) % MODULO;
                        }
                }

        int res = 0;

        for(int i = 30; i >= 0; i--)
            if (valid[i+1])
                res = (res + dp[n][i][(k & (1<<i)) ? 1 : 0]) % MODULO;
            else break;

        return res;
}

int main() {
    int n, k = 0, res;

    scanf("%d", &n);

    for(int i = 1; i <= n; i++) {
        scanf("%d", &mi[i]);
        mi[i]++;
    }

    res = solve(n,k);

    for(int i = 1; i <= n; i++) mi[i]--;
    res = (res + MODULO - solve(n,k)) % MODULO;
    printf("%d\n", res);

    return 0;
}


Problem solution in C.

#include <assert.h>
#include <limits.h>
#include <math.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
/* Modulus for every count computed below (10^9 + 7). */
#define MOD 1000000007

/* Reads one line from stdin into a heap buffer; a trailing newline is removed. */
char* readline();
/* Splits a string on spaces using strtok, so the input string is modified. */
char** split_string(char*);

/*
 * stoneGame counts the ways to shrink the piles to a zero Nim sum while
 * keeping at least one pile unchanged; betterStoneGame is its helper.
 */

/*
 * Counts the choices (q_0, ..., q_{p_count-1}) with 0 <= q_j <= p[j] whose xor
 * is zero, modulo MOD. Pile sizes must be non-negative.
 *
 * Bits are processed from the most significant one down, with currp holding
 * each pile's bits that are still undecided. At bit i, a pile with bit i set
 * either keeps it (currp[j] + 1 - 2^i choices for the lower bits) or clears it
 * (2^i free choices below). numi0/numi1 count all such combinations by the
 * parity of bit i in the xor. allabove0/allabove1 count only those in which no
 * pile cleared bit i. When some pile cleared bit i, exactly one in 2^i
 * combinations has a zero xor on the lower bits, hence the multiplication by
 * the modular inverse of 2^i. Combinations in which every pile kept its bit
 * continue to the next bit. If they survive every bit (q == p with zero xor),
 * they contribute 1.
 *
 * Intermediate products approach 2 * 10^18, so they are kept in long long
 * rather than long, which is only 32 bits on some platforms.
 */
long betterStoneGame(int p_count, int* p){
    if (p_count <= 0) {
        /* No piles: only the empty choice exists, and its xor is zero. */
        return 1;
    }

    long long total = 0;
    long long currp[p_count];
    for (int i = 0; i < p_count; i++) {
        currp[i] = p[i];
    }

    /* invtwopower[i] = (2^i)^(-1) mod MOD; 500000004 is the inverse of 2. */
    long long invtwopower[32];
    invtwopower[0] = 1;
    for (int i = 0; i < 31; i++) {
        invtwopower[i + 1] = (invtwopower[i] * 500000004LL) % MOD;
    }

    int addall0 = 1;
    for (int i = 31; i >= 0; i--) {
        /* 1LL avoids the undefined 1 << 31 on a 32-bit int. */
        const long long bit = 1LL << i;
        long long numi0 = 1;
        long long numi1 = 0;
        long long allabove0 = 1;
        long long allabove1 = 0;
        int numabove = 0;
        for (int j = 0; j < p_count; j++) {
            long long oldi0 = numi0;
            long long oldi1 = numi1;
            long long oldabove0 = allabove0;
            long long oldabove1 = allabove1;
            if (((currp[j] >> i) & 1) == 1) {
                long long keep = (currp[j] + 1 - bit) % MOD;  /* keep bit i */
                long long drop = bit % MOD;                   /* clear bit i */
                numi0 = (drop * oldi0 + keep * oldi1) % MOD;
                numi1 = (drop * oldi1 + keep * oldi0) % MOD;
                allabove0 = (keep * oldabove1) % MOD;
                allabove1 = (keep * oldabove0) % MOD;
                numabove++;
                currp[j] -= bit;
            }
            else {
                long long ways = (currp[j] + 1) % MOD;
                numi0 = (ways * oldi0) % MOD;
                numi1 = (ways * oldi1) % MOD;
                allabove0 = (ways * oldabove0) % MOD;
                allabove1 = (ways * oldabove1) % MOD;
            }
        }

        if ((numabove & 1) == 1) {
            /* Keeping every bit leaves an odd bit i: only clearing choices remain. */
            total = (total + numi0 * invtwopower[i]) % MOD;
            addall0 = 0;
            break;
        }
        total = (total + (numi0 + MOD - allabove0) * invtwopower[i]) % MOD;
    }
    total = (total + addall0) % MOD;
    return (long)total;
}

/*
 * Counts the ways to shrink the piles to a zero xor while leaving at least one
 * pile unchanged, modulo MOD. This is (choices with q_j <= p[j]) minus
 * (choices with q_j <= p[j] - 1 for every j).
 */
long stoneGame(int p_count, int* p) {
    if (p_count <= 0) {
        /* No pile exists that could stay unchanged. */
        return 0;
    }

    long all = betterStoneGame(p_count, p);

    for (int i = 0; i < p_count; i++) {
        if (p[i] == 0) {
            /* An empty pile cannot be reduced, so it is always unchanged. */
            return all % MOD;
        }
    }

    int *subp = malloc((size_t)p_count * sizeof *subp);
    if (subp == NULL) {
        exit(EXIT_FAILURE);
    }
    for (int i = 0; i < p_count; i++) {
        subp[i] = p[i] - 1;
    }
    long reduced = betterStoneGame(p_count, subp);
    free(subp);

    return (all + MOD - reduced) % MOD;
}

int main()
{
    /* OUTPUT_PATH names the answer file; fall back to stdout when it is absent. */
    const char* output_path = getenv("OUTPUT_PATH");
    FILE* fptr = output_path ? fopen(output_path, "w") : stdout;

    if (!fptr) { exit(EXIT_FAILURE); }

    char* p_count_endptr;
    char* p_count_str = readline();

    if (!p_count_str) { exit(EXIT_FAILURE); }

    int p_count = strtol(p_count_str, &p_count_endptr, 10);

    if (p_count_endptr == p_count_str || *p_count_endptr != '\0' || p_count < 0) { exit(EXIT_FAILURE); }

    char* p_line = readline();

    if (!p_line) { exit(EXIT_FAILURE); }

    /* Tokens point into p_line, so p_line is freed only after they are parsed. */
    char** p_temp = split_string(p_line);

    if (p_count > 0 && !p_temp) { exit(EXIT_FAILURE); }

    int* p = malloc(p_count * sizeof(int));

    if (p_count > 0 && !p) { exit(EXIT_FAILURE); }

    for (int p_itr = 0; p_itr < p_count; p_itr++) {
        char* p_item_endptr;
        char* p_item_str = *(p_temp + p_itr);
        int p_item = strtol(p_item_str, &p_item_endptr, 10);

        if (p_item_endptr == p_item_str || *p_item_endptr != '\0') { exit(EXIT_FAILURE); }

        *(p + p_itr) = p_item;
    }

    long result = stoneGame(p_count, p);

    fprintf(fptr, "%ld\n", result);

    if (fptr != stdout) {
        fclose(fptr);
    }

    free(p);
    free(p_temp);
    free(p_line);
    free(p_count_str);

    return 0;
}

/*
 * Reads one line from stdin into a newly allocated, NUL-terminated buffer and
 * removes a trailing newline. At end of input an empty string is returned.
 * Returns NULL only if memory allocation fails; the caller frees the result.
 */
char* readline() {
    size_t alloc_length = 1024;
    size_t data_length = 0;
    char* data = malloc(alloc_length);

    if (!data) { return NULL; }

    while (true) {
        char* cursor = data + data_length;
        char* line = fgets(cursor, alloc_length - data_length, stdin);

        if (!line) { break; }

        data_length += strlen(cursor);

        if (data_length < alloc_length - 1 || data[data_length - 1] == '\n') { break; }

        size_t new_length = alloc_length << 1;
        /* Keep the old buffer reachable so it can be freed if growing fails. */
        char* grown = realloc(data, new_length);

        if (!grown) {
            free(data);
            return NULL;
        }

        data = grown;
        alloc_length = new_length;
    }

    /* data_length may be 0 at end of input, so never index data[-1]. */
    if (data_length > 0 && data[data_length - 1] == '\n') {
        data_length--;
    }
    data[data_length] = '\0';

    /* Shrink to the string plus its terminator; keep the larger buffer if that fails. */
    char* fitted = realloc(data, data_length + 1);

    return fitted ? fitted : data;
}

/*
 * Splits str in place on spaces (via strtok, which overwrites separators).
 * Returns a heap array of pointers into str, terminated by a NULL entry, or
 * NULL when str is NULL, contains no tokens, or memory runs out. The caller
 * frees the array but not the tokens.
 */
char** split_string(char* str) {
    if (!str) {
        return NULL;
    }

    char** splits = NULL;
    size_t count = 0;
    char* token = strtok(str, " ");

    while (token) {
        /* Room for the new token plus the NULL terminator. */
        char** grown = realloc(splits, sizeof(char*) * (count + 2));
        if (!grown) {
            free(splits);
            return NULL;
        }
        splits = grown;

        splits[count++] = token;
        splits[count] = NULL;

        token = strtok(NULL, " ");
    }

    return splits;
}