In this HackerRank Max Transform problem, we are given an array A and need to find the sum of the elements of S(S(A)), i.e. the max transform of the max transform of A. Since the answer can be very large, we output it modulo $10^9 + 7$.

HackerRank Max Transform problem solution


Problem solution in Python programming.

#!/bin/python3

import math
import os
import random
import re
import sys

# Complete the solve function below.

import math
import os
import random
import re
import sys
# Raise the interpreter's recursion limit.
# NOTE(review): brutesolo, make_brute and ends all traverse with an explicit
# stack rather than recursion, so this looks unnecessary — confirm before removing.
sys.setrecursionlimit(9999999)
from decimal import Decimal
def t1(n):
    """Return the triangular number n*(n+1)/2 as an exact Decimal.

    n*(n+1) is always even, so integer floor division is exact. True division
    would route the value through a float and silently lose precision once the
    product exceeds 2**53. Accepts an int or an integral Decimal (countsplit
    calls t1(t1(n))).
    """
    return Decimal(n * (n + 1) // 2)


def t2(n):
    """Return the tetrahedral number n*(n+1)*(n+2)/6 as an exact Decimal.

    A product of three consecutive integers is always divisible by 6, so
    floor division is exact. This avoids the float rounding that true division
    introduces for large n.
    """
    return Decimal(n * (n + 1) * (n + 2) // 6)


def u2(n):
    """Return n*(n+2)*(2n+5)/24 as a Decimal.

    The integer product is always divisible by 3, so the result is a multiple
    of 1/8 and has a finite decimal expansion. Dividing in Decimal therefore
    keeps it exact (within the context precision). Float division followed by
    Decimal() drops the fractional part once the value is large.
    """
    return Decimal(n * (n + 2) * (2 * n + 5)) / 24


def countzip(a, b):
    # Closed form built from u2 at the sum and at the gap of a and b, plus t2 of the gap.
    gap = abs(a - b)
    spread = u2(a + b) - u2(gap)
    return spread + t2(gap)


def countends(x, n, ex):
    # Full countzip(n, ex) span minus the countzip(x, ex) and countzip(n - 1 - x, 0) pieces.
    whole = countzip(n, ex)
    head = countzip(x, ex)
    tail = countzip(n - 1 - x, 0)
    return whole - head - tail


def countsplit(x, n):
    # t1(t1(n)) is the element count of S(S(A)) for len(A) == n; remove the
    # t1(x) and countzip(n - x - 1, x - 1) portions.
    everything = t1(t1(n))
    return everything - t1(x) - countzip(n - x - 1, x - 1)


K = 20
# lg[v] == floor(log2(v)) for 1 <= v < 2**K (lg[0] stays 0). make_rangemax uses
# it to choose the sparse-table level for a query window of width v.
lg = [0] * (1 << K)
for i in range(2, 1 << K):
    lg[i] = lg[i >> 1] + 1


def make_rangemax(A):
    """Build a sparse table over A and return a range-argmax query.

    The returned rangemax(i, j) gives the index of a maximum of A[i:j]
    (half-open; callers pass i < j). Ties resolve to the leftmost index.
    """
    n = len(A)
    assert 1 << K > n

    def pick(p, q):
        # Same tie rule as max(p, q, key=A.__getitem__): q wins only if strictly larger.
        return q if A[q] > A[p] else p

    table = [list(range(n))]
    for k in range(K - 1):
        below = table[k]
        half = 1 << k
        row = list(below)
        for i in range(n - half):
            row[i] = pick(below[i], below[i + half])
        table.append(row)

    def rangemax(i, j):
        k = lg[j - i]
        return pick(table[k][i], table[k][j - (1 << k)])

    return rangemax


def brutesolo(A):
    """Return the sum of max over all subarrays of A (the sum of S(A)).

    Splits each range at its maximum. That element is the max of exactly
    (m - lo + 1) * (hi - m) subarrays of A[lo:hi] containing it.
    """
    rangemax = make_rangemax(A)
    pending = [(0, len(A))]
    total = 0
    while pending:
        lo, hi = pending.pop()
        if lo == hi:
            continue
        m = rangemax(lo, hi)
        pending.append((lo, m))
        pending.append((m + 1, hi))
        total += A[m] * (m - lo + 1) * (hi - m)
    return total


def make_brute(A):
    """Return (brute, rangemax) closures over A.

    brute(i, j) splits A[i:j] recursively at its maximum (using an explicit
    stack). Each maximum is weighted by countends(offset, width, 0) within its
    sub-range.
    """
    rangemax = make_rangemax(A)

    def brute(i, j):
        pending = [(i, j)]
        total = 0
        while pending:
            lo, hi = pending.pop()
            if lo == hi:
                continue
            m = rangemax(lo, hi)
            pending.append((lo, m))
            pending.append((m + 1, hi))
            total += A[m] * countends(m - lo, hi - lo, 0)
        return total

    return brute, rangemax


def ends(A, B):
    """Combine prefixes A[:i] and B[:j], repeatedly peeling off the larger prefix maximum.

    At each step the side whose prefix max is larger (B wins only when strictly
    larger) contributes its tail via brute plus max * countends(...). That
    prefix is then cut back to the max's position. When one side is exhausted,
    the other's remaining prefix is added via brute.
    """
    brutea, rangemaxa = make_brute(A)
    bruteb, rangemaxb = make_brute(B)

    ans = 0
    i, j = len(A), len(B)
    while i and j:
        x = rangemaxa(0, i)
        y = rangemaxb(0, j)
        if A[x] < B[y]:
            ans += bruteb(y + 1, j)
            ans += B[y] * countends(y, j, i)
            j = y
        else:
            ans += brutea(x + 1, i)
            ans += A[x] * countends(x, i, j)
            i = x

    if i == 0:
        ans += bruteb(0, j)
    else:
        ans += brutea(0, i)
    return ans


def maxpairs(a):
    # Element-wise max of each adjacent pair: one element shorter than a (empty for len(a) < 2).
    return list(map(max, a, a[1:]))


def solve(A):
    """Return the sum of S(S(A)) modulo 10**9 + 7.

    Splits A at the first occurrence of its maximum x. The total is the sum of:
    - brutesolo of the left part;
    - ends() of the reversed right part against maxpairs of the left part;
    - A[x] * countsplit(x, n).
    """
    n = len(A)
    x = max(range(n), key=A.__getitem__)
    left = A[:x]
    right_reversed = A[x + 1:][::-1]
    total = (brutesolo(left)
             + ends(right_reversed, maxpairs(left))
             + A[x] * countsplit(x, n))
    return int(total % (10**9 + 7))



if __name__ == '__main__':
    # Read n and the array from stdin and write the answer to the file named by
    # OUTPUT_PATH. The with-block guarantees the file is closed even if input
    # parsing or solve() raises.
    with open(os.environ['OUTPUT_PATH'], 'w') as fptr:
        n = int(input())

        A = list(map(int, input().rstrip().split()))

        result = solve(A)

        fptr.write(str(result) + '\n')


Problem solution in Java programming.

import java.io.*;
import java.math.*;
import java.security.*;
import java.text.*;
import java.util.*;
import java.util.concurrent.*;
import java.util.regex.*;

public class Solution {

    /** Modulus for the answer: 10^9 + 7. */
    static final int SUM_DIV = 1000000007;

    /**
     * A run of equal values on the input viewed as a circle. {@code start} and {@code end} are
     * inclusive indices; {@code end < start} means the run wraps from the last index to the front.
     */
    static class Plateau {
        final int start;
        final int end;
        final int v;

        Plateau(int start, int end, int v) {
            this.start = start;
            this.end = end;
            this.v = v;
        }

        @Override
        public String toString() {
            return new StringJoiner(", ",  "[", "]")
                    .add("start=" + start)
                    .add("end=" + end)
                    .add("v=" + v)
                    .toString();
        }
    }

    /**
     * Returns the sum of S(S(input)) modulo 10^9 + 7, where S is the max transform.
     *
     * <p>Runs of equal values ("plateaus") are merged around the circle. Whenever the current
     * plateau is lower than both neighbours, it is raised to the lower neighbour's value. The
     * raise amount times a closed-form count ({@link #calculateCounts} for a non-wrapping
     * plateau, {@link #countInverse} for a wrapping one) is subtracted. Once a single plateau
     * remains, the answer is its value times {@link #totalCount}, plus the accumulated
     * correction.
     *
     * @param input the array A; may be empty
     * @return the sum modulo {@link #SUM_DIV}; 0 for an empty array
     */
    static int solve(int[] input) {
        if (input.length == 0) {
            // S(S([])) is empty, so its sum is 0. Without this guard cur would be null below.
            return 0;
        }
        final Map<Integer, Plateau> mapStart = new HashMap<>(input.length * 2);
        final Map<Integer, Plateau> mapEnd = new HashMap<>(input.length * 2);
        for (int i = 0; i < input.length; ++i) {
            Plateau p = new Plateau(i, i, input[i]);
            mapStart.put(i, p);
            mapEnd.put(i, p);
        }
        // Running correction, reduced modulo SUM_DIV after every update.
        long subtract = 0;
        // The plateau currently being grown is kept out of both maps.
        Plateau cur = mapStart.remove(0);
        mapEnd.remove(0);

        for (;;) {
            if (mapStart.isEmpty()) {
                // No other plateau remains: cur spans the whole array at its final value.
                long total = totalCount(input.length);
                long result = ((((long) cur.v) * total + subtract) + SUM_DIV) % SUM_DIV;
                return (int) result;
            }
            // Neighbours on the circle; normalize() wraps the index arithmetic.
            Plateau prev = mapEnd.get(normalize(cur.start - 1, input));
            if (prev.v == cur.v) {
                // Equal neighbour on the left: absorb it.
                cur = new Plateau(prev.start, cur.end, cur.v);
                mapStart.remove(prev.start);
                mapEnd.remove(prev.end);
                continue;
            }
            Plateau next = mapStart.get(normalize(cur.end + 1, input));
            if (next.v == cur.v) {
                // Equal neighbour on the right: absorb it.
                cur = new Plateau(cur.start, next.end, cur.v);
                mapStart.remove(next.start);
                mapEnd.remove(next.end);
                continue;
            }
            if (next.v > cur.v && prev.v > cur.v) {
                // Local minimum: raise it to the lower neighbour and subtract the raise times
                // the closed-form count for a plateau of this shape.
                int nextV = Math.min(next.v, prev.v);
                long delta = (long) (nextV - cur.v);
                if (cur.end >= cur.start) {
                    delta *= calculateCounts(normalize(cur.end - cur.start + 1, input));
                } else {
                    // Plateau wraps past the last index back to the front.
                    delta *= countInverse(input.length - cur.start,
                            normalize(cur.end + 1 - cur.start, input));
                }
                subtract -= delta;
                subtract %= SUM_DIV;
                cur = new Plateau(cur.start, cur.end, nextV);
                continue;
            }

            // At least one neighbour is lower: park cur back in the maps and continue from the
            // lower neighbour (the left one if it is lower, otherwise the right one).
            boolean back = prev.v < cur.v;
            Plateau successor = back ? prev : next;
            mapStart.remove(successor.start);
            mapEnd.remove(successor.end);
            mapStart.put(cur.start, cur);
            mapEnd.put(cur.end, cur);
            cur = successor;
        }
    }

    /** Maps {@code idx} (which may be -length .. 2*length-1) onto 0 .. length-1. */
    private static int normalize(int idx, int[] input) {
        return (idx + input.length) % input.length;
    }

    /**
     * Number of elements of S(S(A)) for an array of length {@code n}, modulo {@link #SUM_DIV}.
     * |S(A)| = n(n+1)/2 and |S(S(A))| = m(m+1)/2 with m = |S(A)|. Halving after reducing m is
     * valid because m(m+1) is even and 2 is invertible modulo the prime.
     */
    private static long totalCount(long n) {
        long s1Size = n * (n + 1) / 2  % SUM_DIV;
        long s2Size = (s1Size * (s1Size + 1) / 2) % SUM_DIV;
        return s2Size;
    }

    /** Tetrahedral number n(n+1)(n+2)/6 modulo {@link #SUM_DIV}. */
    private static long calculateCounts(long n) {
        return (n * n * n + 3 * n * n + 2 * n) / 6 % SUM_DIV;
    }

    /**
     * Closed-form count used for a plateau that wraps past the end of the array.
     *
     * @param c1 number of plateau cells from its start to the last array index
     * @param l  total plateau length
     * @return the count modulo {@link #SUM_DIV}. The derivation is not recorded here; for
     *     {@code c1 > l/2} it is reduced to the {@code c1 <= l/2} branch.
     */
    private static long countInverse(long c1, long l) {
        if (c1 <= l / 2) {
            return (-4 * c1 * c1 * c1 + l * l * l +
                    6 * c1 * c1 * l - 3 * c1 * l * l
                    - 3 * c1 * l + 3 * l * l - 2 * c1 + 2 * l) / 6 % SUM_DIV;
        } else {
            return ((countInverse(l - c1 - 1, l) - temp(c1 + 1) % SUM_DIV + temp(l - c1) % SUM_DIV) +
                    SUM_DIV) % SUM_DIV;
        }
    }

    /** Triangular number n(n+1)/2 (not reduced). */
    private static long temp(long n) {
        return (n * n + n) / 2;
    }


    private static final Scanner scanner = new Scanner(System.in);

    /**
     * Reads {@code n} and the array from stdin and writes the answer to the file named by the
     * {@code OUTPUT_PATH} environment variable. Both streams are closed even on failure.
     */
    public static void main(String[] args) throws IOException {
        try (Scanner in = scanner;
             BufferedWriter bufferedWriter =
                     new BufferedWriter(new FileWriter(System.getenv("OUTPUT_PATH")))) {
            int n = in.nextInt();
            in.skip("(\r\n|[\n\r\u2028\u2029\u0085])?");

            int[] A = new int[n];

            // Tolerate leading/trailing and repeated whitespace between the numbers.
            String[] AItems = in.nextLine().trim().split("\\s+");

            for (int i = 0; i < n; i++) {
                A[i] = Integer.parseInt(AItems[i]);
            }

            int result = solve(A);

            bufferedWriter.write(String.valueOf(result));
            bufferedWriter.newLine();
        }
    }
}


Problem solution in C++ programming.

#include <cstdio>
#include <iostream>
#include <sstream>
#include <deque>
#include <queue>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <vector>
#include <map>
#include <set>
#include <string>
#include <cstdlib>
#include <ctime>
using namespace std; 

// Modulus for every count, and the capacity of the position/value-indexed arrays.
#define P 1000000007
#define N 1100000

// Disjoint-set state over positions 1..n:
//   used[i] - position i has been activated by add()
//   fa[i]   - parent pointer; sum[root] - length of the run rooted there
//   f[len]  - per-run contribution table filled in main()
int used[N];
int fa[N];
int sum[N];
int f[N];
// now: running sum of f[] over current runs; ans: answer;
// T: |S(S(A))| mod P; cc: value computed by the last add()
int now, ans, T, cc;
// V[v] lists the (1-based) positions holding value v.
vector <int> V[N];
int n;
int a[N];

// Find the root of x's set, compressing the whole path onto the root.
// Iterative on purpose: merge() does no union-by-size, so parent chains can
// grow to O(n) length (e.g. a run extended one cell at a time from one side).
// A recursive find would then recurse that deep and can overflow the stack.
int gf(int x) {
    int root = x;
    while (fa[root] != root)
        root = fa[root];
    while (fa[x] != root) {
        int next = fa[x];
        fa[x] = root;
        x = next;
    }
    return root;
}

void merge(int x, int y) {
    x = gf(x);
    y = gf(y);
    sum[x] += sum[y];
    fa[y] = x;
}

// Activate position x (1-based). main() calls this for every position whose
// value equals the current threshold, so active positions are those with
// value <= threshold. It maintains:
//   now - sum over maximal runs of active positions of f[run length] (mod P)
//   cc  - now plus a cross-boundary correction (see below). main() adds
//         T - cc per threshold, so cc is presumably the number of elements of
//         S(S(A)) whose value does not exceed the threshold.
void add(int x) {
    used[x] = 1;
    sum[x] = 1;
    // Fuse with an active left neighbour: drop its run's f[] term before merging.
    if (used[x - 1]) {
        now = (now - f[sum[gf(x - 1)]] + P) % P;
        merge(x, x - 1);
    }
    // Same for the right neighbour (used[n + 1] is never set).
    if (used[x + 1]) {
        now = (now - f[sum[gf(x + 1)]] + P) % P;
        merge(x, x + 1);
    }
    // Add the f[] term of the merged run containing x.
    now = (now + f[sum[gf(x)]]) % P;
    // L / R: length of the active run containing position 1 / position n
    // (0 when that position is inactive, since sum[] is only set by add()).
    int L = sum[gf(1)], R = sum[gf(n)];
    // printf("?? %d %d %d\n", x, L, R);
    // x is reused as the overlap size for the correction term.
    x = min(R, L - 1);
    if (x <= 0) {
        cc = now;
        return ;
    }
    cc = now;
    // printf("?? %d %d\n", cc, x);
    // NOTE(review): presumably counts subarrays of S(A) that straddle the
    // boundary between consecutive rows of S(A) (a suffix run followed by a
    // prefix run); the closed form's derivation is not shown — confirm.
    cc = (cc + 1LL * x * L * (R + 1)) % P;
    cc = (cc - 1LL * x * (x + 1) / 2 % P * (L + R + 1)) % P;
    cc = (cc + 1LL * x * (x + 1) * (2 * x + 1) / 6) % P;
    // The subtraction above may leave cc negative; bring it back into [0, P).
    cc = (cc + P) % P;
    // printf("! %d\n", cc);
    return ;
}

int main() {
    scanf("%d", &n);

    // Read the 1-based array, bucket positions by value, and track the maximum.
    int ma = 0;
    for (int i = 1; i <= n; i++) {
        scanf("%d", &a[i]);
        V[a[i]].push_back(i);
        if (a[i] > ma)
            ma = a[i];
    }

    // T = m*(m+1)/2 mod P with m = n*(n+1)/2: the number of elements of S(S(A)).
    T = 1LL * n * (n + 1) / 2 % P;
    T = 1LL * T * (T + 1) / 2 % P;

    // f[len] = (len(len+1)(2len+1)/6 + len(len+1)/2) / 2; every position starts as its own set.
    for (int i = 1; i <= n; i++) {
        f[i] = (1LL * i * (i + 1) * (2 * i + 1) / 6 + 1LL * i * (i + 1) / 2) / 2 % P;
        fa[i] = i;
    }
    now = 0;

    // Sweep thresholds v = 0 .. ma-1. After activating the positions holding v,
    // add T - cc to the answer.
    for (int v = 0; v < ma; v++) {
        for (size_t j = 0; j < V[v].size(); j++)
            add(V[v][j]);
        ans = (ans + T - cc) % P;
    }
    ans = (ans + P) % P;
    printf("%d\n", ans);
    return 0;
}


Problem solution in C programming.

#pragma GCC optimize ("Ofast")
#pragma GCC target ("sse4")
#include<stdio.h>
#include<string.h>
#include<stdlib.h>
/* Modulus, and the modular inverse of 2 (2 * 500000004 == 1 mod 1000000007). */
const int mod = 1000000007;
const int _2 = 500000004;
/* N: input length; MX: maximum input value; tp: monotonic-stack top. */
int N;
int MX = 0;
int tp;
/* a: input (1-based); i_1[i]: 1+2+...+i mod mod; st: monotonic stack. */
int a[200010], i_1[200010], st[200010];
/* mxl/mxr: dominance range of each a[i]; sxl/sxr: prefix/suffix maxima. */
int mxl[200010], mxr[200010];
int sxl[200010], sxr[200010];
/* M: |S(S(A))| mod mod; CNT: elements not yet assigned a value; ANS: answer. */
long long M;
long long CNT;
long long ANS = 0;
/* Given the two extents x and y of a dominance range (order-insensitive),
   derive a count k from the i_1 prefix sums. Add w*k to ANS and remove k from
   the unassigned count CNT, keeping CNT in [0, mod). */
void calc(int w, int x, int y)
{
    int hi = x < y ? y : x;
    int lo = x < y ? x : y;
    int k;

    if( hi != lo )
    {
        k = ( ( (long long)lo * ( i_1[hi-1] - i_1[lo] ) % mod + (long long)( hi + lo ) * i_1[lo] % mod ) % mod + mod ) % mod;
    }
    else
    {
        k = ( ( (long long)( hi + lo ) * i_1[lo] % mod - (long long)hi * hi % mod ) % mod + mod ) % mod;
    }

    ANS = ( ANS + (long long)w * k ) % mod;
    CNT = CNT - k;
    if( CNT < 0 )
        CNT = CNT + mod;
}
/* Prefix-side accounting: unless x == 1 or y == 0, compute a count k from x
   and y, add w*k to ANS, and remove k from CNT (kept in [0, mod)). */
void calcl(int w, int x, int y)
{
    if( x != 1 && y != 0 )
    {
        int k = y < x
            ? i_1[y]
            : (int)( ( i_1[x-1] + (long long)( y - x + 1 ) * ( x - 1 ) ) % mod );

        ANS = ( ANS + (long long)w * k ) % mod;
        CNT -= k;
        if( CNT < 0 )
            CNT += mod;
    }
}
/* Suffix-side accounting: unless x == 0 or y == 1, compute a count k from x
   and y, add w*k to ANS, and remove k from CNT (kept in [0, mod)). */
void calcr(int w, int x, int y)
{
    if( x != 0 && y != 1 )
    {
        int k = x >= y + 1
            ? i_1[y-1]
            : (int)( ( i_1[x] + (long long)( y - x - 1 ) * x ) % mod );

        ANS = ( ANS + (long long)w * k ) % mod;
        CNT -= k;
        if( CNT < 0 )
            CNT += mod;
    }
}
int main()
{
    int p;
    /* Read N and the 1-based array a, tracking its maximum MX.
       NOTE(review): arrays hold 200010 entries; N is assumed to stay below that. */
    scanf("%d", &N);
    for( int i = 1 ; i <= N ; i++ )
    {
        scanf("%d", &a[i]);
        MX = MX > a[i] ? MX : a[i];
    }
    /* M = m*(m+1)/2 mod p with m = N*(N+1)/2 (i.e. |S(S(A))|), using _2 as the
       inverse of 2. CNT starts at M; each calc* call subtracts the elements it
       assigns a value to, and whatever remains is charged MX at the end. */
    M = ( (long long)N * ( N + 1 ) >> 1 ) % mod;
    M = (long long)M * ( M + 1 ) % mod * _2 % mod;
    CNT = M;
    /* i_1[i] = 1 + 2 + ... + i (mod p). */
    for( int i = 1 ; i <= N ; i++ )
    {
        i_1[i] = ( i_1[i-1] + i ) % mod;
    }
    /* sxl[i] = max(a[1..i]) (prefix maximum). */
    for( int i = 1 ; i <= N ; i++ )
    {
        sxl[i] = sxl[i-1] > a[i] ? sxl[i-1] : a[i];
    }
    /* sxr[i] = max(a[i..N]) (suffix maximum); sxr[N+1] is 0 from static init. */
    for( int i = N ; i ; i-- )
    
    {
        sxr[i] = sxr[i+1] > a[i] ? sxr[i+1] : a[i];
    }
    /* mxl[i]: one past the nearest j < i with a[j] > a[i] (1 if none). */
    tp = 0;
    for( int i = 1 ; i <= N ; i++ )
    {
        while( tp > 0 && a[st[tp]] <= a[i] )
        {
            tp--;
        }
        if(tp)
        {
            mxl[i] = st[tp] + 1;
        }
        else
        {
            mxl[i] = 1;
        }
        st[++tp] = i;
    }
    /* mxr[i]: one before the nearest j > i with a[j] >= a[i] (N if none).
       The asymmetric tie-break (<= above, < here) makes i the rightmost
       position of the maximum in every subarray [l, r] with
       mxl[i] <= l <= i <= r <= mxr[i]. */
    tp = 0;
    for( int i = N ; i ; i-- )
    {
        while( tp > 0 && a[st[tp]] < a[i] )
        {
            tp--;
        }
        if(tp)
        {
            mxr[i] = st[tp] - 1;
        }
        else
        {
            mxr[i] = N;
        }
        st[++tp] = i;
    }
    /* Charge each a[i] for the count calc derives from its left extent
       (i - mxl[i] + 1) and right extent (mxr[i] - i + 1). */
    for( int i = 1 ; i <= N ; i++ )
    {
        calc(a[i], i-mxl[i]+1, mxr[i]-i+1);
    }
    /* For each prefix end i, with g = max(a[1..i]), move p down while the
       suffix maximum from p is below g (never below i), then charge g via
       calcl with the N-p positions after p.
       NOTE(review): presumably handles windows of S(A) that straddle two
       consecutive rows (a suffix of A followed by a prefix) — confirm. */
    p = N;
    for( int i = 1 ; i <= N ; i++ )
    {
        int g = sxl[i];
        while( p > i && sxr[p] < g )
        {
            p--;
        }
        while( p < i )
        {
            p++;
        }
        calcl(g, i, N-p);
    }
    /* Mirror image for suffix starts: g = max(a[i..N]). Here the prefix
       maximum must be strictly greater (<= advances p), matching the
       asymmetric tie-break used above. */
    p = 1;
    for( int i = N ; i ; i-- )
    {
        int g = sxr[i];
        while( p < i && sxl[p] <= g )
        {
            p++;
        }
        while( p > i )
        {
            p--;
        }
        calcr(g, N-i+1, p-1);
    }
    /* All elements not assigned above take the global maximum. */
    CNT = ( CNT % mod + mod ) % mod;
    ANS = ( ANS + (long long)CNT * MX ) % mod;
    printf("%lld", ANS);
    return 0;
}