树链剖分
例题:
给定一棵顶点带权的树,有如下的操作。
-
修改某条路径上所有点的权值
-
询问某条路径上所有点的权值和
分析:
如果树是链,那么就可以用区更区查的线段树解决。
那么树该怎么办?我们用树链剖分。
1.什么是树链剖分
将树上的点按照某种方式组织起来,剖分成若干的链,每条链相当于一个序列。
这样,操作的路径可以拆分成某几条链,也就是若干条完整的序列或者某条序列上的一段区间。
此时,就可以用线段树等处理序列上的区间操作的数据结构解决问题。
因此,如何恰当的把树剖分成若干条链,是问题的核心。
2.剖分思想
轻重边剖分是最常用的剖分思想。
我们用size[u]是以u为根的子树结点个数。令v是u儿子中size最大的一个儿子,则称<u,v>
是重边,v是u的重儿子。反之,u到其余儿子的边为轻边。
那么轻重边有什么性质呢?
-
如果v是u的儿子,且<u,v>是轻边,则size[v]<=2size[u] ;
-
从根结点到某一点v的路径上的轻边个数不多于logn,其中n为树的结点数;
-
称某条路径为重路径,当且仅当它全部由重边组成。那么对于每个点到根的路径上都有不超过logn条重路径满足它不是任何一个重路径的真子路径;
所以轻重边剖分的好处也就自然而然了。
可以知道,一个点在且仅在一条重路径上,且一条重路径一定是一条从根节点方向向叶节点方向延申的深度递增的路径。
操作需要处理的路径<u,v>,只需要分别处理u,v两点分别到LCA(u,v)的路径即可。
一条路径可以分解成最多logn条重路径和最多logn条轻边。对于每个重路径,就相当于一个序列,用线段树维护;对于轻边,则可以直接跳过,因为轻边的两个端点一定在某两条重路径上。
综上所述,总体的时间复杂度是O(nlogn)。
3.代码实现
在实现代码之前,您需要熟练掌握线段树的内容。
由上述,我们可以知道,要维护的东西很多。
struct Vert{
int wei;//权重
int fa;//父亲
int dep;//深度
int size;//子树结点数
int son;//重儿子
int top;//所在重路径顶部结点
int seg;//所在线段树的下标
}vert[N];
其中,wei
是直接读入的,不需要维护。
第一遍dfs
时,我们需要求出fa,dep,size,son
。这是一个基本的树上dfs
。
void dfs1(int u,int fa){
vert[u].size=1;
vert[u].fa=fa;
vert[i].dep=vert[fa].dep+1;
int son=0;
for(int i=head[u];i;i=edge[i].next){
if(edge[i].v==fa) continue;
dfs1(edge[i].v,u);
vert[u].size+=vert[edge[i].v].size;
if(vert[edge[i].v].size>vert[son].size) son=edge[i].v;
}
vert[u].son=son;
}
现在要进行第二遍dfs
,需要同时求出top
和seg
。首先,因为需要维护线段树,所以我们建立一个线段树的结构体。每次加进一个线段树,都可以用一个函数add_seg_rev()
解决。
struct Tree{
int l,r;
int sum,mx;//sum表示和,mx表示最大值
int rev;//线段树第x个位置对应的树结点
}seg[N<<2];
void add_seg_rev(int x){
seg_num++;
vert[x].seg=seg_num;
seg[seg_num].rev=x;
}
做完这些准备操作,就可以进行第二遍dfs
了。
void dfs2(int u){
if(vert[u].son>0){
add_seg_rev(vert[u].son);
vert[vert[u].son].top=vert[u].top;
dfs2(vert[u].son);
}
for(int i=head[u];i;i=edge[i].next){
if(vert[edge[i].v].top) continue;
add_seg_rev(edge[i].v);
vert[edge[i].v].top=edge[i].v;
dfs2(edge[i].v);
}
}
需要注意的是,在dfs2
之前,还要将根节点插入线段树中。
add_seg_rev(1);
vert[1].top=1;
在两边dfs
之后,我们已经处理好了vert
结构体中的所有内容,并且将线段树顺序排好了,现在我们来建树。
void buildtree(int id,int l,int r){
seg[id].l=l;
seg[id].r=r;
if(l==r){
int wei=vert[seg[l].rev].wei;
seg[id].mx=wei;
seg[id].sum=wei;
return ;
}
int mid=(l+r)>>1;
buildtree(id<<1,l,mid);
buildtree(id<<1|1,mid+1,r);
seg[id].sum=seg[id<<1].sum+seg[id<<1|1].sum;
seg[id].mx=max(seg[id<<1].mx,seg[id<<1|1].mx);
}
然后,就是线段树的维护和查询操作了。我们以P2590 树的统计为例,这里是单点更新,区间查询。我们需要同时维护区间和和区间最大值。
struct Result{//用来记录查询的答案,捆绑返回
int ok;
int sum;
int mx;
};
void update(int id,int pos,int val){
if((pos<seg[id].l)||(pos>seg[id].r)) return ;
if((seg[id].l==seg[id].r)&&(seg[id].l==pos)){
seg[id].sum=val;
seg[id].mx=val;
return ;
}
update(id<<1,pos,val);
update(id<<1|1,pos,val);
seg[id].sum=seg[id<<1].sum+seg[id<<1|1].sum;
seg[id].mx=max(seg[id<<1].mx,seg[id<<1|1].mx);
}
Result merge(Result res1,Result res2){
Result res;
if(res1.ok&&res2.ok){
res.ok=true;
res.sum=res1.sum+res2.sum;
res.mx=max(res1.mx,res2.mx);
}
else if(res1.ok) res=res1;
else if(res2.ok) res=res2;
else res.ok=false;
return res;
}
Result query(int id,int le,int ri){
int l1=max(seg[id].l,le);
int r1=min(seg[id].r,ri);
int overlap=r1-l1+1;
if(overlap<=0) return (Result){false,0,0};
if(overlap>=seg[id].r-seg[id].l+1){
Result res;
res.ok=true;
res.sum=seg[id].sum;
res.mx=seg[id].mx;
return res;
}
Result res1=query(id<<1,le,ri);
Result res2=query(id<<1|1,le,ri);
Result resu=merge(res1,res2);
return resu;
}
最终,就是最后的树链查询了。我们需要做到的是将其分为若干个序列或序列的子序列,再将每一个在线段树上的位置查询一下就可以。
Result ask(int x,int y){
Result res;
res.ok=false;
int tx=vert[x].top;
int ty=vert[y].top;
while(tx!=ty){
if(vert[tx].dep<vert[ty].dep)
swap(x,y),swap(tx,ty);
Result res1=query(1,vert[tx].seg,vert[x].seg);
res=merge(res,res1);
x=vert[tx].fa;
tx=vert[x].top;
//将x跳到fa[top[x]],同时更新这段区间的内容
}
if(vert[x].dep>vert[y].dep) swap(x,y);
Result res1=query(1,vert[x].seg,vert[y].seg);
res=merge(res,res1);
//最后,x和y都在同一个重路径上了,直接用线段树查询即可。
return res;
}
主函数(查询)内容:
for(int i=1;i<=m;i++){
string opt;
int x,y;
cin>>opt;
scanf("%lld%lld",&x,&y);
if(opt[0]=='C'){
update(1,vert[x].seg,y);
}
else if(opt[1]=='M'){
Result res=ask(x,y);
printf("%lld\n",res.mx);
}
else{
Result res=ask(x,y);
printf("%lld\n",res.sum);
}
}
4.完整代码
以下是P2590 树的统计的完整代码。
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=30005;
int n,m,s,MOD;
struct Vert{
int wei;//权重
int fa;//父亲
int dep;//深度
int size;//子树结点数
int son;//重儿子
int top;//所在重路径顶部结点
int seg;//所在线段树的下标
}vert[N];
struct Edge{
int u,v;
int next;
}edge[N<<1];
int head[N],tot=0;
struct Tree{
int l,r;
int sum,mx;
int rev;//线段树第x个位置对应的树结点
}seg[N<<2];
int seg_num;
struct Result{
bool ok;
int sum;
int mx;
};
void add_edge(int u,int v){
tot++;
edge[tot].u=u;
edge[tot].v=v;
edge[tot].next=head[u];
head[u]=tot;
}
void add_seg_rev(int x){
seg_num++;
vert[x].seg=seg_num;
seg[seg_num].rev=x;
}
void dfs1(int u,int fa){
vert[u].size=1;
vert[u].fa=fa;
vert[u].dep=vert[fa].dep+1;
int son=0;
for(int i=head[u];i;i=edge[i].next){
if(edge[i].v==fa) continue;
dfs1(edge[i].v,u);
vert[u].size+=vert[edge[i].v].size;
if(vert[edge[i].v].size>vert[son].size) son=edge[i].v;
}
vert[u].son=son;
}
void dfs2(int u){
if(vert[u].son>0){
add_seg_rev(vert[u].son);
vert[vert[u].son].top=vert[u].top;
dfs2(vert[u].son);
}
for(int i=head[u];i;i=edge[i].next){
if(vert[edge[i].v].top) continue;
add_seg_rev(edge[i].v);
vert[edge[i].v].top=edge[i].v;
dfs2(edge[i].v);
}
}
void buildtree(int id,int l,int r){
seg[id].l=l;
seg[id].r=r;
if(l==r){
int wei=vert[seg[l].rev].wei;
seg[id].mx=wei;
seg[id].sum=wei;
return ;
}
int mid=(l+r)>>1;
buildtree(id<<1,l,mid);
buildtree(id<<1|1,mid+1,r);
seg[id].sum=seg[id<<1].sum+seg[id<<1|1].sum;
seg[id].mx=max(seg[id<<1].mx,seg[id<<1|1].mx);
}
void update(int id,int pos,int val){
if((pos<seg[id].l)||(pos>seg[id].r)) return ;
if((seg[id].l==seg[id].r)&&(seg[id].l==pos)){
seg[id].sum=val;
seg[id].mx=val;
return ;
}
update(id<<1,pos,val);
update(id<<1|1,pos,val);
seg[id].sum=seg[id<<1].sum+seg[id<<1|1].sum;
seg[id].mx=max(seg[id<<1].mx,seg[id<<1|1].mx);
}
Result merge(Result res1,Result res2){
Result res;
if(res1.ok&&res2.ok){
res.ok=true;
res.sum=res1.sum+res2.sum;
res.mx=max(res1.mx,res2.mx);
}
else if(res1.ok) res=res1;
else if(res2.ok) res=res2;
else res.ok=false;
return res;
}
Result query(int id,int le,int ri){
int l1=max(seg[id].l,le);
int r1=min(seg[id].r,ri);
int overlap=r1-l1+1;
if(overlap<=0) return (Result){false,0,0};
if(overlap>=seg[id].r-seg[id].l+1){
Result res;
res.ok=true;
res.sum=seg[id].sum;
res.mx=seg[id].mx;
return res;
}
Result res1=query(id<<1,le,ri);
Result res2=query(id<<1|1,le,ri);
Result resu=merge(res1,res2);
return resu;
}
Result ask(int x,int y){
Result res;
res.ok=false;
int tx=vert[x].top;
int ty=vert[y].top;
while(tx!=ty){
if(vert[tx].dep<vert[ty].dep)
swap(x,y),swap(tx,ty);
Result res1=query(1,vert[tx].seg,vert[x].seg);
res=merge(res,res1);
x=vert[tx].fa;
tx=vert[x].top;
}
if(vert[x].dep>vert[y].dep) swap(x,y);
Result res1=query(1,vert[x].seg,vert[y].seg);
res=merge(res,res1);
return res;
}
signed main(){
scanf("%lld",&n);
for(int i=1;i<n;i++){
int u,v;
scanf("%lld%lld",&u,&v);
add_edge(u,v);
add_edge(v,u);
}
for(int i=1;i<=n;i++) scanf("%lld",&vert[i].wei);
dfs1(1,0);
add_seg_rev(1);
vert[1].top=1;
dfs2(1);
buildtree(1,1,n);
scanf("%lld",&m);
for(int i=1;i<=m;i++){
string opt;
int x,y;
cin>>opt;
scanf("%lld%lld",&x,&y);
if(opt[0]=='C'){
update(1,vert[x].seg,y);
}
else if(opt[1]=='M'){
Result res=ask(x,y);
printf("%lld\n",res.mx);
}
else{
Result res=ask(x,y);
printf("%lld\n",res.sum);
}
}
exit(0);
}
Stay hungry, stay foolish.