HackerEarth Xsquare And Array Operations problem solution

In this HackerEarth problem, Xsquare loves to play with arrays. Today he has an array A consisting of N distinct integers, and he wants to perform the following operation on A:

Select a pair of consecutive integers, say $(A_i, A_{i+1})$ for some $1 \le i < N$, and replace the pair with the single value $\max(A_i, A_{i+1})$.
Replace N with the new size of the array A (one less than before).
This operation costs Xsquare $\max(A_i, A_{i+1})$ rupees.
After N-1 such operations only one element is left. Xsquare wants to find the optimal strategy for reducing the given array A to a single element.

A strategy is optimal if its total cost is the minimum over all possible strategies.

#include <bits/stdc++.h>
using namespace std ;
#define LL long long int
#define ft first
#define sd second
#define PII pair<int,int>
#define MAXN 100001
#define MAXM 1001
#define mp make_pair
#define f_in(st) freopen(st,"r",stdin)
#define f_out(st) freopen(st,"w",stdout)
#define sc(x) scanf("%d",&x)
#define scll(x) scanf("%lld",&x)
#define pr(x) printf("%lld\n",x)
#define pb push_back
#define MOD 1000000007
// Segment-tree payload: the largest value in a range and the (1-based)
// position at which it occurs.
// The defaults (0, -1) serve as the "empty range" identity for SegTree::_merge;
// this works because main asserts every input value is >= 1.
class node{
public :
    int max_value = 0 ;
    int max_index = -1 ;
} ;

class SegTree{

public :
node st[4*MAXN] ;
vector<int> A ;
int N ;

SegTree(int N,vector<int> &A){
this->N = N ;
(this->A).resize(N+1) ;
for(int i=1;i<=N;i++){
(this->A)[i] = A[i] ;

void _merge(node &a,node &b,node &c){
if(b.max_value < c.max_value){
a.max_value = c.max_value ;
a.max_index = c.max_index ;
a.max_value = b.max_value ;
a.max_index = b.max_index ;

void buildst(int idx,int ss,int se){
if(ss == se){
st[idx].max_value = A[ss] ;
st[idx].max_index = ss ;
return ;
int mid = (ss+se)/2 ;
buildst(2*idx,ss,mid) ;
buildst(2*idx+1,mid+1,se) ;
_merge(st[idx],st[2*idx],st[2*idx+1]) ;

void update(int idx,int ss,int se,int pos,int val){
if(ss == se){
st[idx].max_value = val ;
return ;
int mid = (ss+se)/2 ;
if(pos <= mid)
update(2*idx,ss,mid,pos,val) ;
update(2*idx+1,mid+1,se,pos,val) ;
_merge(st[idx],st[2*idx],st[2*idx+1]) ;

node query(int idx,int ss,int se,int L,int R){
node ret ;
if(L > se || R < ss)
return ret ;
if(L <= ss && se <= R)
return st[idx] ;
int mid = (ss+se)/2 ;
node left = query(2*idx,ss,mid,L,R) ;
node right = query(2*idx+1,mid+1,se,L,R) ;
_merge(ret,left,right) ;
return ret ;

} ;

int T = 0;               // number of test cases read by main
int N = 0;               // size of the current test case
vector<int> A;           // current input, 1-based (main sizes it N+1; A[0] unused)
SegTree *obj = nullptr;  // segment tree over A, created in main and used by solve
// Minimum total cost to reduce A[ss..se] (1-based, inclusive) to one element.
//
// The range maximum is charged once if it sits at an end of the range, and
// twice (once per side) if it is interior. The pieces on either side of it are
// then handled independently. An explicit work stack replaces the recursion,
// because monotonic inputs would otherwise nest about N calls deep.
//
// Requires obj to hold a segment tree built over the current A[1..N].
LL solve(int ss,int se){
    LL total = 0 ;
    vector<pair<int,int>> pending ;
    pending.emplace_back(ss, se) ;
    while(!pending.empty()){
        const int lo = pending.back().first ;
        const int hi = pending.back().second ;
        pending.pop_back() ;
        if(hi <= lo)
            continue ;                      // zero or one element: nothing to merge
        if(hi - lo + 1 == 2){
            total += max(A[lo], A[hi]) ;    // exactly one merge; no query needed
            continue ;
        }
        node ret = obj->query(1, 1, N, lo, hi) ;
        if(ret.max_index == lo){
            total += ret.max_value ;
            pending.emplace_back(lo + 1, hi) ;
        }else if(ret.max_index == hi){
            total += ret.max_value ;
            pending.emplace_back(lo, hi - 1) ;
        }else{
            total += 2LL * ret.max_value ;
            pending.emplace_back(lo, ret.max_index - 1) ;
            pending.emplace_back(ret.max_index + 1, hi) ;
        }
    }
    return total ;
}
// Reads T test cases from stdin. For each one it reads N and A[1..N] and
// prints the minimum merge cost.
// The SegTree object is created once and re-seeded for later tests. Each
// SegTree embeds a fixed 4*MAXN node array, so allocating a fresh one per test
// would cost O(MAXN) per test regardless of N, and the original never freed it.
int main(){
    sc(T) ;
    assert( T <= 100000 && T >= 1 ) ;
    while(T--){
        sc(N) ;
        assert( N >= 1 && N <= 100000 ) ;
        A.assign(N + 1, 0) ;
        for(int i = 1; i <= N; i++){
            sc(A[i]) ;
            assert(A[i] >= 1 && A[i] <= 1e9) ;
        }
        if(obj == nullptr){
            obj = new SegTree(N, A) ;
        }else{
            // Same state the constructor would produce: size N, 1-based copy of A.
            obj->N = N ;
            obj->A = A ;
        }
        obj->buildst(1, 1, N) ;
        pr(solve(1, N)) ;
    }
    delete obj ;
    obj = nullptr ;
    return 0 ;
}

Second solution

#include <bits/stdc++.h>
#define lli long long
#define MAX 100005
using namespace std;

lli A[MAX] = {};   // current test case, stored 0-based in A[0..n-1]
int n = 0;         // number of elements in the current test case

struct node {
lli mx;
int idx;
node() { }
node(lli mx, int idx)
this->mx = mx;
this->idx = idx;

// Returns whichever node holds the larger mx; on a tie p2 is returned.
node combine(const node &p1, const node &p2)
{
    return ( p1.mx > p2.mx ) ? p1 : p2;
}

void build(int where, int left, int right)
if ( left > right ) return;
if ( left == right ) {
tree[where].mx = A[left];
tree[where].idx = left;
int mid = (left+right)/2;
build(where*2, left, mid);
build(where*2+1, mid+1, right);
tree[where] = combine(tree[where*2], tree[where*2+1]);

// Returns the node holding the maximum of A over [i, j] intersected with the
// node range [left, right].
// With no overlap it returns the sentinel node(-1, -1). This loses every
// combine() only if all inputs exceed -1; the problem presumably guarantees
// positive values (TODO confirm against the constraints).
node query(int where, int left, int right, int i, int j)
{
    if ( left > right || left > j || right < i ) return node(-1,-1);
    if ( left >= i && right <= j ) return tree[where];
    int mid = (left+right)/2;
    return combine(query(where*2, left, mid, i, j),
                   query(where*2+1, mid+1, right, i, j));
}

// Minimum total cost to reduce A[left..right] (0-based, inclusive) to one element.
//
// The range maximum is charged once if it is at either end of the range, and
// twice (once per side) if it is interior. The parts to its left and right are
// then solved independently. An explicit work stack replaces the recursion,
// whose depth reaches about n for monotonic input.
//
// Requires build(1, 0, n-1) to have been run for the current A.
lli f(int left, int right)
{
    lli total = 0;
    vector<pair<int, int>> pending;
    pending.emplace_back(left, right);
    while ( !pending.empty() ) {
        const int lo = pending.back().first;
        const int hi = pending.back().second;
        pending.pop_back();
        if ( lo >= hi ) continue;   // zero or one element: nothing to merge
        node val = query(1, 0, n-1, lo, hi);
        if ( val.idx == lo || val.idx == hi ) total += val.mx;
        else total += 2LL * val.mx;
        pending.emplace_back(lo, val.idx - 1);
        pending.emplace_back(val.idx + 1, hi);
    }
    return total;
}

int main()
int t;
cin >> t;
while ( t-- ) {
cin >> n;
for ( int i = 0; i < n; i++ ) cin >> A[i];
lli ans = f(0,n-1);
cout << ans << endl;
return 0;

