#include <bits/stdc++.h>
using namespace std;

// Partial solution for the case where the tree rooted at node 1 is a single
// path, so node 1 has at most two children and each of them starts a chain.
// Values from the first child's subtree are collected in a[1..cnt1], values
// from every other child's subtree in b[1..cnt2]. Both lists are sorted in
// descending order and paired position by position, paying the larger value of
// each pair; the unmatched tail of the longer list is paid in full, and M[1]
// is always added on its own.

const int MAXN = 200005;

long long M[MAXN];
int f[MAXN];          // f[i]: parent of node i (i >= 2)
vector<int> G[MAXN];  // G[u]: children of u, in input order
long long a[MAXN], b[MAXN];
int cnt1, cnt2;

// Reads one (optionally negative) decimal integer from stdin.
// Returns 0 if EOF is reached before any digit appears.
inline long long read() {
    long long x = 0, sign = 1;
    int ch = getchar();  // int, not char, so EOF can be detected
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') sign = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * sign;
}

int main() {
    int n = (int)read();
    for (int i = 1; i <= n; ++i) M[i] = read();
    for (int i = 2; i <= n; ++i) {
        f[i] = (int)read();
        G[f[i]].push_back(i);
    }

    // Collect every non-root value into a[] (first child's subtree) or b[]
    // (all other children's subtrees). Done iteratively: on a path, a
    // recursive DFS would nest about n calls deep and can overflow the stack.
    vector<pair<int, bool> > todo;  // (node, lies in the first child's subtree?)
    for (size_t k = 0; k < G[1].size(); ++k) todo.push_back(make_pair(G[1][k], k == 0));
    while (!todo.empty()) {
        int u = todo.back().first;
        bool inFirst = todo.back().second;
        todo.pop_back();
        if (inFirst)
            a[++cnt1] = M[u];
        else
            b[++cnt2] = M[u];
        for (int v : G[u]) todo.push_back(make_pair(v, inFirst));
    }

    sort(a + 1, a + 1 + cnt1, greater<long long>());
    sort(b + 1, b + 1 + cnt2, greater<long long>());

    long long ans = 0;
    for (int i = 1; i <= min(cnt1, cnt2); ++i) {
        ans += max(a[i], b[i]);
    }
    if (cnt1 < cnt2)
        for (int i = cnt1 + 1; i <= cnt2; ++i) ans += b[i];
    else
        for (int i = cnt2 + 1; i <= cnt1; ++i) ans += a[i];
    printf("%lld\n", ans + M[1]);
    return 0;
}
#include <bits/stdc++.h>
using namespace std;

// Full solution. Every node u owns a max-heap Q[u]. Once all children of u
// have been merged into Q[u] (see Merge), M[u] is pushed as an entry of its
// own, and Q[u] is then merged into its parent's heap. The answer is the sum
// of the entries left in Q[1].

const int MAXN = 200005;

priority_queue<int> Q[MAXN];
int stk[MAXN], top;  // scratch buffer used by Merge

// Merges heap x into heap y. The two heaps are paired largest-with-largest and
// only the larger element of each pair is kept; the unmatched remainder of the
// bigger heap is kept unchanged (the heap form of the two-chain pairing).
// Afterwards Q[y] holds the result and Q[x] is empty.
inline void Merge(int x, int y) {
    // Make x the smaller heap. Swapping priority_queues exchanges their
    // storage, so the pairing loop below only walks the smaller heap.
    if (Q[x].size() > Q[y].size()) swap(Q[x], Q[y]);
    top = 0;
    while (!Q[x].empty()) {
        stk[++top] = max(Q[x].top(), Q[y].top());
        Q[x].pop();
        Q[y].pop();
    }
    for (int i = 1; i <= top; ++i) Q[y].push(stk[i]);
}

vector<int> G[MAXN];  // G[u]: children of u
int M[MAXN], par[MAXN];

// Reads one (optionally negative) decimal integer from stdin.
// Returns 0 if EOF is reached before any digit appears.
inline int read() {
    int x = 0, sign = 1;
    int ch = getchar();  // int, not char, so EOF can be detected
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') sign = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * sign;
}

int main() {
    int n = read();
    for (int i = 1; i <= n; ++i) M[i] = read();
    for (int i = 2; i <= n; ++i) {
        par[i] = read();
        G[par[i]].push_back(i);
    }

    // Breadth-first order from the root: every node appears after its parent.
    // Walking it backwards finishes all children before their parent without
    // recursion (a recursive DFS would nest about n calls deep on a path).
    vector<int> order;
    order.reserve(n);
    order.push_back(1);
    for (size_t k = 0; k < order.size(); ++k)
        for (int v : G[order[k]]) order.push_back(v);

    // Siblings are merged in a different order than a DFS would use; the
    // largest-with-largest pairing gives the same result in any order.
    for (int k = (int)order.size() - 1; k >= 0; --k) {
        int u = order[k];
        Q[u].push(M[u]);  // all children of u have already been merged in
        if (u != 1) Merge(u, par[u]);
    }

    long long ans = 0;
    while (!Q[1].empty()) {
        ans += Q[1].top();
        Q[1].pop();
    }
    printf("%lld\n", ans);
    return 0;
}
细心的读者会发现，Merge 函数其实就是上面链情况代码的堆版本：两者都把两组数各自从大到小一一配对、每对取较大者，元素较多的那一组多出来的部分原样保留。对应的链情况代码如下：
1 2 3 4 5 6
int ans=0; for (registerint i=1;i<=min(cnt1,cnt2);++i){ ans+=max(a[i],b[i]); } if (cnt1<cnt2) for (registerint i=cnt1+1;i<=cnt2;++i) ans+=b[i]; elsefor (registerint i=cnt2+1;i<=cnt1;++i) ans+=a[i];