树形dp

  • 1、基础树形dp:

\(\qquad\!\!\)例题:

\(\qquad\!\!\)给一棵树以及根的编号,求出每个点的深度和子树大小:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
#include<bits/stdc++.h>
#define in inline
#define re register
const int N=N;
using namespace std;
int n,r;
int cnt,head[N];
struct edge
{
int to,nxt;
};
edge e[N<<1];
int d[N],s[N];
in int qread()
{
int x=0,y=1;
int ch=getchar();
while(ch<'0'||ch>'9')
{
if(ch=='-')
{
y=-1;
}
ch=getchar();
}
while(ch>='0'&&ch<='9')
{
x=(x<<1)+(x<<3)+(ch^48);
ch=getchar();
}
return x*y;
}
in void qwrite(re int x)
{
if(x<0)
{
putchar('-');
qwrite(-x);
}
else
{
if(x>9)
{
qwrite(x/10);
}
putchar(x%10+'0');
}
return ;
}
in void mr(re int u,re int v)
{
e[++cnt].to=v;
e[cnt].nxt=head[u];
head[u]=cnt;
return ;
}
in void dfs(re int u,re int fa)
{
s[u]=1;
for(re int i=head[u];i;i=e[i].nxt)
{
int v=e[i].to;
if(v==fa)
{
continue;
}
d[v]=d[u]+1;
dfs(v,u);
s[u]+=s[v];
}
return ;
}
int main()
{
n=qread();
r=qread();
for(re int i=1;i<n;i++)
{
int u=qread();
int v=qread();
mr(u,v);
mr(v,u);
}
d[r]=1;
dfs(r,0);
for(re int i=1;i<=n;i++)
{
qwrite(d[i]);
putchar(' ');
}
putchar('\n');
for(re int i=1;i<=n;i++)
{
qwrite(s[i]);
putchar(' ');
}
putchar('\n');
return 0;
}
  • 2、递归前转移树形dp:

\(\qquad\!\!\)例题:

\(\qquad\!\!\)给一棵树以及根的编号,点有点权,可正可负,求从根到叶子节点的链上的点权和最大的子链的点权和:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#include<bits/stdc++.h>
#define in inline
#define re register
const int N=N;
using namespace std;
int n,r;
int cnt,head[N];
struct edge
{
int to,nxt;
};
edge e[N<<1];
int val[N],f[N];
int ans;
in int qread()
{
int x=0,y=1;
int ch=getchar();
while(ch<'0'||ch>'9')
{
if(ch=='-')
{
y=-1;
}
ch=getchar();
}
while(ch>='0'&&ch<='9')
{
x=(x<<1)+(x<<3)+(ch^48);
ch=getchar();
}
return x*y;
}
in void qwrite(re int x)
{
if(x<0)
{
putchar('-');
qwrite(-x);
}
else
{
if(x>9)
{
qwrite(x/10);
}
putchar(x%10+'0');
}
return ;
}
in void mr(re int u,re int v)
{
e[++cnt].to=v;
e[cnt].nxt=head[u];
head[u]=cnt;
return ;
}
in void dfs(re int u,re int fa)
{
for(re int i=head[u];i;i=e[i].nxt)
{
int v=e[i].to;
if(v==fa)
{
continue;
}
f[v]=max(f[v],f[u]+a[v]);
ans=max(ans,f[v]);
dfs(v,u);
}
return ;
}
int main()
{
n=qread();
r=qread();
for(re int i=1;i<n;i++)
{
int u=qread();
int v=qread();
}
for(re int i=1;i<=n;i++)
{
val[i]=qread();
f[i]=val[i];
}
ans=val[r];
dfs(r,0);
qwrite(ans);
putchar('\n');
return 0;
}
  • 递归后转移树形dp:

\(\qquad\!\!\)例题:

\(\qquad\!\!\)P1352 没有上司的舞会

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#include<bits/stdc++.h>
#define in inline
#define re register
const int N=6010;
using namespace std;
int n;
int cnt,head[N];
struct edge
{
int to,nxt;
};
edge e[N<<1];
int val[N],f[N][2];
in int qread()
{
int x=0,y=1;
int ch=getchar();
while(ch<'0'||ch>'9')
{
if(ch=='-')
{
y=-1;
}
ch=getchar();
}
while(ch>='0'&&ch<='9')
{
x=(x<<1)+(x<<3)+(ch^48);
ch=getchar();
}
return x*y;
}
in void qwrite(re int x)
{
if(x<0)
{
putchar('-');
qwrite(-x);
}
else
{
if(x>9)
{
qwrite(x/10);
}
putchar(x%10+'0');
}
return ;
}
in void mr(re int u,re int v)
{
e[++cnt].to=v;
e[cnt].nxt=head[u];
head[u]=cnt;
return ;
}
in void dfs(re int u,re int fa)
{
f[u][1]=val[u];
for(re int i=head[u];i;i=e[i].nxt)
{
int v=e[i].to;
if(v==fa)
{
continue;
}
dfs(v,u);
f[u][0]+=max(f[v][1],f[v][0]);
f[u][1]+=f[v][0];
}
return ;
}
int main()
{
n=qread();
for(re int i=1;i<=n;i++)
{
val[i]=qread();
}
for(re int i=1;i<n;i++)
{
int u=qread();
int v=qread();
mr(u,v);
mr(v,u);
}
dfs(1,0);
qwrite(max(f[1][0],f[1][1]));
putchar('\n');
return 0;
}
  • 4、树上01背包:

\(\qquad\!\!\)例题:

\(\qquad\!\!\)P2014 [CTSC1997]选课

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#include<bits/stdc++.h>
#define in inline
#define re register
cosnt int N=310;
using namespace std;
int n,m;
int cnt,head[N];
struct edge
{
int to,nxt;
};
edge e[N<<1];
int val[N],s[N],f[N][N];
in int qread()
{
int x=0,y=1;
int ch=getchar();
while(ch<'0'||ch>'9')
{
if(ch=='-')
{
y=-1;
}
ch=getchar();
}
while(ch>='0'&&ch<='9')
{
x=(x<<1)+(x<<3)+(ch^48);
ch=getchar();
}
return x*y;
}
in void qwrite(re int x)
{
if(x<0)
{
putchar('-');
qwrite(-x);
}
else
{
if(x>9)
{
qwrite(x/10);
}
putchar(x%10+'0');
}
return ;
}
in void mr(re int u,re int v)
{
e[++cnt].to=v;
e[cnt].nxt=head[u];
head[u]=cnt;
return ;
}
in void dfs(re int u,re int fa)
{
memset(f[u],-0x3f,sizeof(f[u]));
s[u]=1;
f[u][1]=val[u];
for(re int i=head[u];i;i=e[i].nxt)
{
int v=e[i].to;
if(v==fa)
{
continue;
}
dfs(v,u);
for(re int j=s[u];j;j--)
{
for(re int k=s[v];~k;k--)
{
f[u][j+k]=max(f[u][j+k],f[u][j]+f[v][k]);
}
}
s[u]+=s[v];
}
return ;
}
int main()
{
n=qread();
m=qread()+1;
for(re int i=1;i<=n;i++)
{
int fa=qread();
val[i]=qread();
mr(fa,i);
mr(i,fa);
}
dfs(0,0);
qwrite(f[0][m]);
putchar('\n');
return 0;
}
  • 5、树上分组背包:

\(\qquad\!\!\)例题:

\(\qquad\!\!\)P1273 有线电视网

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
#include<bits/stdc++.h>
#define in inline
#define re register
const int N=3010;
using namespace std;
int n,m;
int cnt,head[N];
struct edge
{
int to,nxt,val;
};
edge e[N<<1];
int f[N][N],s[N];
in int qread()
{
int x=0,y=1;
int ch=getchar();
while(ch<'0'||ch>'9')
{
if(ch=='-')
{
y=-1;
}
ch=getchar();
}
while(ch>='0'&&ch<='9')
{
x=(x<<1)+(x<<3)+(ch^48);
ch=getchar();
}
return x*y;
}
in void qwrite(re int x)
{
if(x<0)
{
putchar('-');
qwrite(-x);
}
else
{
if(x>9)
{
qwrite(x/10);
}
putchar(x%10+'0');
}
return ;
}
in void mr(re int u,re int v,re int w)
{
e[++cnt].to=v;
e[cnt].nxt=head[u];
e[cnt].val=w;
head[u]=cnt;
return ;
}
in void dfs(re int u,re int fa)
{
for(re int i=head[u];i;i=e[i].nxt)
{
int v=e[i].to;
if(v==fa)
{
continue;
}
dfs(v,u);
for(re int j=s[u];~j;j--)
{
for(re int k=s[v];~k;k--)
{
f[u][j+k]=max(f[u][j+k],f[u][j]+f[v][k]-e[i].val);
}
}
s[u]+=s[v];
}
return ;
}
int main()
{
memset(f,-0x3f,sizeof(f));
n=qread();
m=qread();
for(re int i=1;i<=n-m;i++)
{
int k=qread();
for(re int j=1;j<=k;j++)
{
int v=qread();
int w=qread();
mr(i,v,w);
mr(v,i,w);
}
f[i][0]=0;
}
for(re int i=n-m+1;i<=n;i++)
{
f[i][1]=qread();
f[i][0]=0;
s[i]=1;
}
dfs(1,0);
for(re int i=m;i>=0;i--)
{
if(!i)
{
putchar('0');
putchar('\n');
return 0;
}
if(f[1][i]>=0)
{
qwrite(i);
putchar('\n');
return 0;
}
}
return 0;
}
  • 6、基础换根dp:

\(\qquad\!\!\)例题:

\(\qquad\!\!\)P2986 [USACO10MAR]Great Cow Gathering G

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
#include<bits/stdc++.h>
#define in inline
#define re register
const int N=100010;
#define int long long
using namespace std;
int n;
int cnt,head[N],val[N];
int d[N],s[N];
struct edge
{
int to,nxt,val;
};
edge e[N<<1];
int f[N],ans;
in int qread()
{
int x=0,y=1;
int ch=getchar();
while(ch<'0'||ch>'9')
{
if(ch=='-')
{
y=-1;
}
ch=getchar();
}
while(ch>='0'&&ch<='9')
{
x=(x<<1)+(x<<3)+(ch^48);
ch=getchar();
}
return x*y;
}
in void qwrite(re int x)
{
if(x<0)
{
putchar('-');
qwrite(-x);
}
else
{
if(x>9)
{
qwrite(x/10);
}
putchar(x%10+'0');
}
return ;
}
in void mr(re int u,re int v,re int w)
{
e[++cnt].to=v;
e[cnt].nxt=head[u];
e[cnt].val=w;
head[u]=cnt;
return ;
}
in void dfs1(re int u,re int fa)
{
s[u]=val[u];
for(re int i=head[u];i;i=e[i].nxt)
{
int v=e[i].to;
if(v==fa)
{
continue;
}
d[v]=d[u]+e[i].val;
dfs1(v,u);
s[u]+=s[v];
}
return ;
}
in void dfs2(re int u,re int fa)
{
for(re int i=head[u];i;i=e[i].nxt)
{
int v=e[i].to;
if(v==fa)
{
continue;
}
f[v]=f[u]+(s[1]-(s[v]<<1))*e[i].val;
ans=min(ans,f[v]);
dfs2(v,u);
}
}
signed main()
{
n=qread();
for(re int i=1;i<=n;i++)
{
val[i]=qread();
}
for(re int i=1;i<n;i++)
{
int u=qread();
int v=qread();
int w=qread();
mr(u,v,w);
mr(v,u,w);
}
dfs1(1,0);
for(re int i=1;i<=n;i++)
{
f[1]+=d[i]*val[i];
}
ans=f[1];
dfs2(1,0);
qwrite(ans);
putchar('\n');
return 0;
}

\(\qquad\!\!\)例题:

\(\qquad\!\!\)P1453 城市环路

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
//P1453 城市环路
#include<bits/stdc++.h>
#define in inline
#define re register
const int N=100010;
using namespace std;
int n;
int cnt,head[N];
struct edge
{
int to,nxt;
};
edge e[N<<1];
bool b[N];
int tot;
int rd[N],c[N];
int f[N][2],g[N][2];
int ans;
double k;
in int qread()
//快读
{
int x=0,y=1;
int ch=getchar();
while(ch<'0'||ch>'9')
{
if(ch=='-')
{
y=-1;
}
ch=getchar();
}
while(ch>='0'&&ch<='9')
{
x=(x<<1)+(x<<3)+(ch^48);
ch=getchar();
}
return x*y;
}
in void mr(re int u,re int v)
//建边
{
e[++cnt].to=v;
e[cnt].nxt=head[u];
head[u]=cnt;
return ;
}
in void dfs0(re int u,re int fa)
//找环
{
for(re int i=head[u];i;i=e[i].nxt)
{
int v=e[i].to;
//父节点或已在环上就不算
if(v==fa||b[v])
{
continue;
}
//找到一个点就加入环
if(rd[v]==2)
{
b[v]=1;
c[++tot]=v;
//递归
dfs0(v,u);
//加入一个点后就跳出循环
break;
}
}
return ;
}
in void tppx()
//拓扑排序
{
queue<int> q;
for(re int i=1;i<=n;i++)
{
if(rd[i]==1)
{
q.push(i);
}
}
while(!q.empty())
{
int u=q.front();
q.pop();
for(re int i=head[u];i;i=e[i].nxt)
{
int v=e[i].to;
rd[v]--;
if(rd[v]==1)
{
q.push(v);
}
}
}
for(re int i=1;i<=n;i++)
{
//找到第一个在环上的点加入环并进行 dfs
if(rd[i]==2)
{
b[i]=1;
c[++tot]=i;
dfs0(i,0);
//加入一个点后就跳出循环
break;
}
}
return ;
}
in void dfs(re int u,re int fa)
//树形 dp
{
for(re int i=head[u];i;i=e[i].nxt)
{
int v=e[i].to;
if(v==fa||b[v])
//父节点或已在环上就不算
{
continue;
}
dfs(v,u);
//转移
f[u][0]+=max(f[v][1],f[v][0]);
f[u][1]+=f[v][0];
}
return ;
}
int main()
{
//读入
n=qread();
for(re int i=1;i<=n;i++)
{
f[i][1]=qread();
}
for(re int i=1;i<=n;i++)
{
int u=qread()+1;
int v=qread()+1;
//统计入度
rd[u]++;
rd[v]++;
//建边
mr(u,v);
mr(v,u);
}
//拓扑排序
tppx();
//对于每棵树 树形 dp
for(re int i=1;i<=tot;i++)
{
dfs(c[i],0);
}
//环形 dp
memset(g,-0x3f,sizeof(g));
g[1][0]=f[c[1]][0];
for(re int i=2;i<=tot;i++)
{
g[i][1]=g[i-1][0]+f[c[i]][1];
g[i][0]=max(g[i-1][0],g[i-1][1])+f[c[i]][0];
}
ans=max(g[tot][0],g[tot][1]);
memset(g,-0x3f,sizeof(g));
g[1][1]=f[c[1]][1];
for(re int i=2;i<=tot;i++)
{
g[i][1]=g[i-1][0]+f[c[i]][1];
g[i][0]=max(g[i-1][0],g[i-1][1])+f[c[i]][0];
}
ans=max(ans,g[tot][0]);
//读入 k
scanf("%lf",&k);
//输出答案
printf("%.1lf\n",ans*k);
return 0;
}
  • 9、基环树之开环直接树形dp:

\(\qquad\!\!\)例题:

\(\qquad\!\!\)[ZJOI2008]骑士

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#include<bits/stdc++.h>
#define in inline
#define re register
#define inf 0x3f3f3f3f3f3f3f3f
#define int long long
const int N=1000010;
using namespace std;
int n,a[N],val[N];
int cnt,head[N];
struct edge
{
int to,nxt;
};
edge e[N];
bool vis[N];
int c;
int f[N][2];
int ans;
in int qread()
{
int x=0,y=1;
int ch=getchar();
while(ch<'0'||ch>'9')
{
if(ch=='-')
{
y=-1;
}
ch=getchar();
}
while(ch>='0'&&ch<='9')
{
x=(x<<1)+(x<<3)+(ch^48);
ch=getchar();
}
return x*y;
}
in void qwrite(re int x)
{
if(x<0)
{
putchar('-');
qwrite(-x);
}
else
{
if(x>9)
{
qwrite(x/10);
}
putchar(x%10+'0');
}
return ;
}
in void mr(re int u,re int v)
{
e[++cnt].to=v;
e[cnt].nxt=head[u];
head[u]=cnt;
return ;
}
in void dfs1(re int u)
{
vis[u]=1;
if(vis[a[u]])
{
c=u;
}
else
{
dfs1(a[u]);
}
return ;
}
in void dfs2(re int u)
{
vis[u]=1;
f[u][1]=val[u];
f[u][0]=0;
for(re int i=head[u];i;i=e[i].nxt)
{
int v=e[i].to;
if(v^c)
{
dfs2(v);
f[u][0]+=max(f[v][0],f[v][1]);
f[u][1]+=f[v][0];
}
else
{
f[u][1]=-inf;
}
}
}
signed main()
{
n=qread();
for(re int i=1;i<=n;i++)
{
val[i]=qread();
a[i]=qread();
mr(a[i],i);
}
for(re int i=1;i<=n;i++)
{
if(vis[i])
{
continue;
}
dfs1(i);
dfs2(c);
int maxx=max(f[c][0],f[c][1]);
c=a[c];
dfs2(c);
ans+=max(maxx,max(f[c][0],f[c][1]));
}
qwrite(ans);
putchar('\n');
return 0;
}