1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61
#include <bits/stdc++.h>
#include <cstdio>
#define int long long
#define o(a,b) st[a][b]
#define sz(k) st[k].size()
#define p(k) st[k].size()-1
#define q(k) st[k].size()-2
using namespace std;
// Reads the next signed decimal integer from stdin.
// Any non-digit characters before the number are skipped; a '-' seen while
// skipping makes the result negative. The character that terminates the
// number is consumed.
// Returns 0 if end of input is reached before any digit. Holding getchar()'s
// result in a char cannot represent EOF, so the skip loop would never end.
inline long long gin(){
    int c = getchar();  // wider than char so EOF stays distinguishable
    long long value = 0;
    int sign = 1;
    while (c < '0' || c > '9') {
        if (c == EOF) return 0;
        if (c == '-') sign = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        value = value * 10 + (c - '0');
        c = getchar();
    }
    return value * sign;
}
// Global state shared by calc/find/main; statics are zero-initialised and
// every array is used with 1-based indices.
const long long N = 100005;  // 1e5 + 5: capacity of the per-element arrays
long long n;                 // number of input elements
long long a[N];              // a[i]: value of element i
long long s[N];              // s[i]: occurrences of a[i] among a[1..i]
long long f[N];              // f[i]: DP result for the prefix ending at i
// The next two are indexed by element value, so values are assumed to lie in
// [0, 10004] -- nothing checks this.
long long b[10005];          // b[v]: running occurrence counter for value v
vector<long long> st[10005]; // st[v]: candidate-index stack for value v
// DP transition value for candidate index x when the occurrence count used
// for it is `sum`: f[x-1] + a[x] * sum^2.
// Reads f[x - 1], so x must be >= 1.
inline long long calc(long long x, long long sum){
    const long long gain = a[x] * sum * sum;
    const long long before = f[x - 1];
    return gain + before;
}
// Binary-searches counts m in [1, n] for the smallest m at which candidate x
// is at least as good as candidate y, i.e.
//   calc(x, m - s[x] + 1) >= calc(y, m - s[y] + 1).
// Returns n + 1 if no probed m satisfies it.
// The search assumes the predicate is monotone in m. For x before y with
// a[x] == a[y] >= 0 (as in the calls from main), the difference of the two
// calc values is linear and non-decreasing in m, so that holds there.
inline long long find(long long x, long long y){
    long long answer = n + 1;  // sentinel: x never catches up within [1, n]
    long long lo = 1;
    long long hi = n;
    while (lo <= hi) {
        const long long mid = (lo + hi) / 2;
        const bool xCaughtUp = calc(x, mid - s[x] + 1) >= calc(y, mid - s[y] + 1);
        if (!xCaughtUp) {
            lo = mid + 1;
            continue;
        }
        answer = mid;   // feasible; keep looking for an earlier one
        hi = mid - 1;
    }
    return answer;
}
signed main(){ n=gin(); for(int i=1;i<=n;i++){ a[i]=gin(); s[i]=++b[a[i]]; } for(int i=1;i<=n;i++){ int c=a[i]; while(sz(c)>=2 && find(o(c,q(c)),o(c,p(c)))<=find(o(c,p(c)),i)) st[c].pop_back(); st[c].push_back(i); while(sz(c)>=2 && find(o(c,q(c)),o(c,p(c)))<=s[i]) st[c].pop_back(); f[i]=calc(o(c,p(c)),s[i]-s[o(c,p(c))]+1); } printf("%lld\n",f[n]); return 0; }
|