In this HackerEarth Simple Sum problem solution, you are given an array of N integers $A_1, A_2, \ldots, A_N$ and have to find the simple sum of this array. The simple sum is defined as $\sum_{i=1}^{N} \sum_{j=i}^{N} \max(A_i, A_{i+1}, \ldots, A_j) \cdot (A_i \mid A_j)$, where $\mid$ denotes the bitwise OR operator.


HackerEarth Simple Sum problem solution


First solution

#include <bits/stdc++.h>
using namespace std;
// Unsigned 64-bit type for values and the running answer. Assuming
// N <= 1e5 and A_i <= 1e4, the total stays below ~1e18, which fits.
typedef unsigned long long ll;

// Loop helpers: rep runs i over 0..n-1, frep runs i over 1..n.
#define rep(i,n) for(int i=0;i<n;i++)
#define frep(i,n) for(int i=1;i<=n;i++)

// Number of low bits examined per value (covers all values below 2^15).
const int MAXBITS = 15;

// A[1..N]: input array, 1-based; A[0] stays 0.
ll A[100011];
// dp[i][k]: index of a maximum of A[i .. i + 2^k - 1] (sparse table).
int dp[100011][20];
// cnt[i][b]: how many of A[1..i] have bit b set; row 0 is all zeros.
ll cnt[100011][20];
// Accumulated Simple Sum.
ll ans = 0;
int query(int i,int j) {
int len = j-i+1;
len = log2(len);
int p = dp[i][len];
int q = dp[j-(1<<len)+1][len];
if(A[p]>A[q]) return p;
return q;
}
void calc(int i,int j) {
if(i==j) {
ans+=A[i]*A[i];
return;
}
if(i>j) return;
int m = query(i,j);
if(m-i<=j-m) {
for(int k=i;k<=m;k++) {
rep(p,MAXBITS) {
if(A[k]&(1<<p)) {
ans+=A[m]*(1LL<<p)*(ll)(j-m+1);
} else{
ans+=A[m]*(1LL<<p)*(ll)(cnt[j][p]-cnt[m-1][p]);
}
}
}
} else{
for(int k=m;k<=j;k++) {
rep(p,MAXBITS) {
if(A[k]&(1<<p)) {
ans+=A[m]*(1LL<<p)*(ll)(m-i+1);
} else{
ans+=A[m]*(1LL<<p)*(ll)(cnt[m][p]-cnt[i-1][p]);
}
}
}
}
//ans = 0;
calc(i,m-1);
calc(m+1,j);
}
int main() {
freopen("in10.txt","r",stdin);
freopen("out10.txt","w",stdout);

int N;
cin >> N;
frep(i,N) {
cin >> A[i];
dp[i][0] = i;
rep(j,MAXBITS) {
cnt[i][j] = cnt[i-1][j];
if(A[i]&(1<<j)) {
cnt[i][j]++;
}
}
}
int p,q;
for(int s=1;s<20;s++) {
frep(i,N-(1<<s)+1) {
p = dp[i][s-1];
q = dp[i+(1<<(s-1))][s-1];
if(A[p]>A[q]) dp[i][s] = p;
else dp[i][s] = q;
}
}
calc(1,N);
cout << ans;
}

Second solution

#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>

#include <algorithm>
#include <utility>
#include <vector>
using namespace std;

// Input limits: at most MAXN elements, each in [1, MAXM]. MAXBITS low bits
// are examined per value (10000 < 2^14, so 16 bits is more than enough).
const int MAXN = 100 * 1000;
const int MAXM = 10 * 1000;
const int MAXBITS = 16;

// a[0..n-1]: input values, 0-based.
int a[MAXN];
// cnt[b][k]: how many of a[0..k-1] have bit b set (per-bit prefix counts).
int cnt[MAXBITS][MAXN + 1];

// A candidate maximum: an element's value paired with its index in a[].
struct MaxInfo
{
    int value;  // the element's value
    int i;      // the element's 0-based position
    MaxInfo() {}  // members deliberately left uninitialized
    MaxInfo(int v, int idx) : value(v), i(idx) {}
};

// Sparse table of maxima: st[k][j] is the maximum of a[j .. j + 2^k - 1].
// Level 0 holds the single elements.
MaxInfo st[20][MAXN];

// Compares two candidates by value only; their indices play no part. Under
// this ordering std::max returns its first argument when the values are equal.
inline bool operator < (const MaxInfo& lhs, const MaxInfo& rhs)
{
    return rhs.value > lhs.value;
}

// Returns the maximum of a[l..r] (0-based, requires l <= r) together with its
// index. It combines the two overlapping sparse-table windows of length 2^k
// that cover the range; on equal values the left window's entry is returned.
MaxInfo getMax(int l, int r)
{
    const int k = static_cast<int>(log2(r - l + 1));
    const MaxInfo& fromLeft = st[k][l];
    const MaxInfo& fromRight = st[k][r - (1 << k) + 1];
    return std::max(fromLeft, fromRight);
}

long long divideAndConquer(int l, int r)
{
if (l > r) {
return 0;
}
if (l == r) {
return (long long) a[l] * a[l];
}
int mid = l + r >> 1;

MaxInfo maxValue = getMax(l, r);
long long ret = divideAndConquer(l, maxValue.i - 1) + divideAndConquer(maxValue.i + 1, r);
if (maxValue.i < mid) {
for (int i = l; i <= maxValue.i; ++ i) {
for (int bit = 0; bit < MAXBITS; ++ bit) {
if (a[i] >> bit & 1) {
ret += (long long)maxValue.value * (r - maxValue.i + 1) * (1LL << bit);
} else {
ret += (long long)maxValue.value * (cnt[bit][r + 1] - cnt[bit][maxValue.i]) * (1LL << bit);
}
}
}
} else {
for (int j = maxValue.i; j <= r; ++ j) {
for (int bit = 0; bit < MAXBITS; ++ bit) {
if (a[j] >> bit & 1) {
ret += (long long)maxValue.value * (maxValue.i - l + 1) * (1LL << bit);
} else {
ret += (long long)maxValue.value * (cnt[bit][maxValue.i + 1] - cnt[bit][l]) * (1LL << bit);
}
}
}
}
return ret;
}

// Reads n and a[0..n-1] from stdin and checks them against the input limits.
// Then builds the per-bit prefix counts and the sparse table, and prints the
// Simple Sum.
int main()
{
    // scanf runs outside assert(): under NDEBUG the whole assert expression is
    // discarded, so a read placed inside it would never happen. The explicit
    // checks still reject bad input, instead of overrunning the fixed-size
    // arrays, when asserts are disabled.
    int n = 0;
    const bool nOk = scanf("%d", &n) == 1 && 1 <= n && n <= MAXN;
    assert(nOk);
    if (!nOk) {
        return 1;
    }
    for (int i = 0; i < n; ++i) {
        const bool valueOk = scanf("%d", &a[i]) == 1 && 1 <= a[i] && a[i] <= MAXM;
        assert(valueOk);
        if (!valueOk) {
            return 1;
        }
        st[0][i] = MaxInfo(a[i], i);
        for (int bit = 0; bit < MAXBITS; ++bit) {
            cnt[bit][i + 1] = cnt[bit][i] + ((a[i] >> bit) & 1);
        }
    }
    // st[k + 1][j] = max(st[k][j], st[k][j + 2^k]) for every full window.
    for (int i = 0, len = 1; len < n; ++i, len *= 2) {
        for (int j = 0; j + len * 2 <= n; ++j) {
            st[i + 1][j] = std::max(st[i][j], st[i][j + len]);
        }
    }
    printf("%lld\n", divideAndConquer(0, n - 1));
    return 0;
}