// 2019-09-10 12:29:53 C++
//
// C++
//
// Copy Copied! Full
#include <iostream>
#include <iterator>
#include <set>
#include <vector>

// Sum, over every index pair l < r, of the second-largest value in p[l..r].
//
// Precondition: p is a permutation of 1..p.size() (main() checks this).
//
// Values are visited from largest to smallest. When value v (at 1-based
// position pos) is visited, `larger` holds the positions of all values greater
// than v, plus two sentinel copies each of 0 and n + 1. Let l1 > l2 be the two
// nearest such positions left of pos and r1 < r2 the two nearest to the right.
// Because values are distinct, v is the second largest of a range exactly when
// the range contains exactly one larger value, i.e. when
//   * l2 < l <= l1 and pos <= r < r1   (the maximum sits at l1), or
//   * l1 < l <= pos and r1 <= r < r2   (the maximum sits at r1).
// The duplicated sentinels make a term vanish when its neighbour is missing.
// Runs in O(n log n).
long long second_sum(const std::vector<int>& p) {
    const int n = static_cast<int>(p.size());

    // pos_of[v] = 1-based position of value v.
    std::vector<int> pos_of(n + 1, 0);
    for (int i = 0; i < n; ++i) {
        pos_of[p[i]] = i + 1;
    }

    std::multiset<int> larger{0, 0, n + 1, n + 1};
    long long total = 0;
    for (int v = n; v >= 1; --v) {
        const int pos = pos_of[v];
        // First larger position to the right of pos. The sentinels guarantee
        // at least two elements on each side of this iterator.
        const auto right = larger.lower_bound(pos);
        const long long r1 = *right;
        const long long r2 = *std::next(right);
        const long long l1 = *std::prev(right);
        const long long l2 = *std::prev(right, 2);

        // Up to n^2 ranges per value, so the count needs 64 bits.
        const long long ranges = (l1 - l2) * (r1 - pos) + (pos - l1) * (r2 - r1);
        total += v * ranges;
        larger.insert(pos);
    }
    return total;
}

// Reads N followed by a permutation P_1..P_N of 1..N from stdin and prints
// second_sum(P). Exits with status 1 on malformed input.
int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n = 0;
    if (!(std::cin >> n) || n < 0) {
        return 1;
    }

    // Reject anything that is not a permutation of 1..n rather than indexing
    // out of bounds or leaving a value without a position.
    std::vector<int> p(n);
    std::vector<char> seen(n + 1, 0);
    for (int& x : p) {
        if (!(std::cin >> x) || x < 1 || x > n || seen[x]) {
            return 1;
        }
        seen[x] = 1;
    }

    std::cout << second_sum(p) << '\n';
    return 0;
}
// RECOMMEND