ShinriiTin/notes/DP_of_DP

DP套DP学习笔记

通过一个外层的DP来计算使得内层的DP方程(子DP)最终结果为特定值的输入数

例题

HDU 4899 Hero meet devil

题目大意:给出一个只由\(ACGT\)组成的字符串\(S\),(\(|S|\le15\)),对于每个\(0\le i \le|S|\),求有多少个不同的只由\(ACGT\)组成的长度为\(1\le m\le1000\)的字符串\(T\),使得\(LCS(S,T)=i\)

题解

子DP的方程为\(f[i][j] = \begin{cases} f[i-1][j-1]+1 &, S[i]=T[j]\\ \max(f[i-1][j],f[i][j-1]) &,S[i]\not=T[j] \end{cases}\)

\(f[i][j]\)\(f[i+1][j]\)至多只会相差\(1\),将第一维状态差分一下,得到一个长度为\(|S|\)的01串,这个01串记录了子DP中\(f[0,\cdots,|S|][j]\)的所有状态,因此,我们可以通过这个01串来进行外层的DP。

设计外层DP为\(g[i][set]\)表示考虑\(T\)的前\(i\)个位置,\(f[0,\cdots,|S|][i]\)的状态为\(set\)的方案数。

我们可以预处理出子DP每个状态对应每种转移后能得到的状态值(建出关于子DP的自动机),然后外层DP的转移就是利用这些预处理的信息来实现的。

子DP的状态数为\(2^{|S|}\),转移数为\(4\times2^{|S|}\),令\(n=|S|\),则建立子DP的自动机的时间复杂度为\(O(4\times2^n\times n)\),空间复杂度为\(O(4\times2^n)\)

外层DP的状态数为\(m\times2^n\),转移数为\(4\times2^n\times m\),时间复杂度为\(O(4\times2^n\times m)\),采用滚动数组,空间复杂度为\(O(2^n)\)

因此总的时间复杂度为\(4\times2^n\times(n+m)\),空间复杂度为\(O(4\times2^n)\)

HDU 5079 Square

题目大意:给出一个\(n\times n(n\le8)\)的棋盘,棋盘上每个位置最开始都是黑色的,有一些位置损坏了,你可以选择一些没有损坏的格子,染成白色,问对于\(0\le i \le n\),有多少种不同的染色方案,使得染色后棋盘的最大白色子正方形的边长等于\(i\)

题解

根据子DP的不同,这道题的做法也有所不同。

方法1

\(f[i][j]\)表示以位置\((i,j)\)作为右下角的最大白色正方形,则有转移方程如下(业界良心陈老师的WC讲稿上将min错写为了max): \[ f[i][j]=\begin{cases} \min(f[i-1][j],f[i][j-1],f[i-1][j-1])+1&,(i,j)\text{ is white}\\ 0 &,(i,j) \text{ is black}\\ \end{cases} \] 然后考虑如何记录这个子DP的所有状态。我们考虑一行一行的\(DP\),把每一行的\(f\)和全局最大的\(f\)一起状压成一个\(n+1\)位的\(n+1\)进制数。

\(n=8\)时最多有\(13089\)种不同的状态(开始按陈老师讲稿上说的一格一格转移,但是每一行第一列都要特殊处理,写起来很麻烦,状态也挺多的,就放弃了)。

子DP的状态数最多有\(13089\)种,每种状态枚举\(2^n\)种不同的转移。

外层DP的状态数有\(13089\times n\)个,转移有至多\(13089\times n\times2^n\)种。

#include <bits/stdc++.h>

#define pb push_back

struct sta{
	int n,a[9];	
	inline void decode(int n,int s){
		this->n=n;
		for(int i=0;i<=n;s/=(n+1),++i)a[i]=s%(n+1);	
	}
	inline int code(){
		int s=0;
		for(int i=n;~i;--i)s=s*(n+1)+a[i];
		return s;	
	}
	inline sta trans(int set){
		sta res; res.n=n,res.a[0]=a[0];
		for(int i=1;i<=n;++i){
			if((set>>(i-1))&1){
				res.a[i]=i>1?std::min(res.a[i-1],std::min(a[i-1],a[i]))+1:1;
			}
			else res.a[i]=0;
			if(res.a[i]>res.a[0])res.a[0]=res.a[i];
		}
		return res;
	}
};

std::unordered_map<int,int>vis;

int cnt;

std::vector<int>vec,tmp;

std::queue<int>q;

inline bool check_new(int sta){
	if(vis.count(sta))return 0;
	vis[sta]=cnt++,vec.pb(sta);
	return 1;
}

inline void push_queue(int sta){
	if(check_new(sta))q.push(sta);
}

inline int pop_queue(){
	int sta=q.front();
	return q.pop(),sta;	
}

const int max_S=13100,mod=1e9+7;

struct DFA{
	std::vector<int>vec;
	int m,trans[max_S][1<<8];
}F[9];

inline void init(int n){
	vis.clear();
	cnt=0,vec.clear();
	while(!q.empty())q.pop();
	push_queue(0);
	for(int i=1;i<=n;++i){
		tmp.clear();
		while(!q.empty())tmp.pb(pop_queue());
		for(auto&s:tmp){
			sta cur; cur.decode(n,s);
			for(int set=0,trans;set<(1<<n);++set){
				trans=cur.trans(set).code();
				push_queue(trans);
				F[n].trans[vis[s]][set]=vis[trans];
			}
		}
	}
	F[n].m=cnt,F[n].vec=vec;
//	printf("%d\n",cnt);
}

int T,n,m,s,cur,f[2][max_S],ans[10];

char str[10];

std::vector<int>trans;

int main(){
	for(int i=1;i<=8;++i)init(i);
	scanf("%d\n",&T);
	while(T--){
		scanf("%d",&n);
		m=F[n].m,s=1<<n;
		memset(f[cur=0],0,sizeof(int)*m);
		f[0][0]=1;
		for(int i=1;i<=n;++i){
			scanf("%s",str);
			memset(f[cur^=1],0,sizeof(int)*m);
			trans.clear();
			for(int set=0;set<s;++set){
				bool flag=1;
				for(int j=0;j<n;++j)
					if(set>>j&1)flag&= str[j]=='o';
				if(flag)trans.pb(set);
			}
			for(int j=0,tmp;j<m;++j)if((tmp=f[cur^1][j])>0){
				for(auto&set:trans){
					int&res=f[cur][F[n].trans[j][set]];
					if((res+=tmp)>=mod)res-=mod;
				}
			}
		}
		for(int i=0;i<=n;++i)ans[i]=0;
		for(int i=0,tmp;i<m;++i)if((tmp=f[cur][i])>0){
			int&res=ans[F[n].vec[i]%(n+1)];
			if((res+=tmp)>=mod)res-=mod;
		}
		for(int i=0;i<=n;++i)printf("%d\n",ans[i]);
	}
	return 0;
}

方法2

还是考虑方法1中的子DP。其实我们不需要多花一个位置来记录全局最优值,只需要换一种思路,做\(n+1\)次DP套DP,第\(i\)次限制只能由子DP值小于\(i\)的状态来做子DP,最后将\(n+1\)次计算的结果差分一下就得到了总的答案。

因此,我们只需要用一个\(n\)位的\(n+1\)进制数来记录一行中每个位置的\(f\)值,这样对于\(n\le8\)的情况,最多就只有\(3528\)种不同的状态。

#include <bits/stdc++.h>

#define pb push_back

struct sta{
	int n,a[8];	
	inline void decode(int n,int s){
		this->n=n;
		for(int i=0;i<n;s/=(n+1),++i)a[i]=s%(n+1);	
	}
	inline int code(){
		int s=0;
		for(int i=n-1;~i;--i)s=s*(n+1)+a[i];
		return s;	
	}
	inline sta trans(int set){
		sta res; res.n=n;
		for(int i=0;i<n;++i){
			if(set>>i&1){
				res.a[i]=i?std::min(res.a[i-1],std::min(a[i-1],a[i]))+1:1;
			}
			else res.a[i]=0;
		}
		return res;
	}
};

std::unordered_map<int,int>vis;

int cnt;

std::vector<int>vec,tmp;

std::queue<int>q;

inline bool check_new(int sta){
	if(vis.count(sta))return 0;
	vis[sta]=cnt++,vec.pb(sta);
	return 1;
}

inline void push_queue(int sta){
	if(check_new(sta))q.push(sta);
}

inline int pop_queue(){
	int sta=q.front();
	return q.pop(),sta;	
}

const int max_S=4000,mod=1e9+7;

struct DFA{
	int m,limit[9],trans[max_S][1<<8];
}F[9];

inline void init(int n){
	static int Trans[max_S][1<<8];
	static int x[max_S],id[max_S],max[max_S],siz[10];
	vis.clear();
	cnt=0,vec.clear();
	while(!q.empty())q.pop();
	push_queue(0);
	for(int i=1;i<=n;++i){
		tmp.clear();
		while(!q.empty())tmp.pb(pop_queue());
		for(auto&s:tmp){
			sta cur; cur.decode(n,s);
			for(int set=0,trans;set<(1<<n);++set){
				trans=cur.trans(set).code();
				push_queue(trans);
				Trans[vis[s]][set]=vis[trans];
			}
		}
	}
//	printf("%d\n",cnt);
	for(int i=0;i<=n;++i)siz[i]=0;	
	for(int i=0;i<cnt;++i){
		sta cur; cur.decode(n,vec[i]);
		max[i]=0;
		for(int j=0;j<n;++j)
			if(cur.a[j]>max[i])max[i]=cur.a[j];
		++siz[max[i]];
	}
	for(int i=1;i<=n;++i)siz[i]+=siz[i-1];
	for(int i=0;i<=n;++i)F[n].limit[i]=siz[i];
	for(int i=cnt-1;~i;--i)x[id[i]=--siz[max[i]]]=i;
	F[n].m=cnt;
	for(int i=0;i<cnt;++i)
		for(int set=0;set<(1<<n);++set){
			F[n].trans[i][set]=id[Trans[x[i]][set]];
		}
}

int T,n,f[2][max_S],ans[10];

char str[10][10];

std::vector<int>trans;

inline int solve(int n,int k){
	int cur=0,m=F[n].limit[k],s=1<<n;
	memset(f[0],0,sizeof(int)*m);
	f[0][0]=1;
	for(int i=1;i<=n;++i){
		memset(f[cur^=1],0,sizeof(int)*m);	
		std::vector<int>trans;
		for(int set=0;set<s;++set){
			bool flag=1;
			for(int j=0;j<n;++j)if(set>>j&1)flag&=str[i][j]=='o';
			if(flag)trans.pb(set);
		}
		for(int j=0,tmp;j<m;++j)if((tmp=f[cur^1][j])>0){
			for(auto&set:trans){
				if(F[n].trans[j][set]>=m)continue;
				int&res=f[cur][F[n].trans[j][set]];
				if((res+=tmp)>=mod)res-=mod;	
			}
		}
	}
	int res=0;
	for(int j=0,tmp;j<m;++j)if((tmp=f[cur][j])>0){
		if((res+=tmp)>=mod)res-=mod;
	}
	return res;
}

int main(){
	for(int i=1;i<=8;++i)init(i);
	scanf("%d\n",&T);
	while(T--){
		scanf("%d",&n);
		for(int i=1;i<=n;++i)scanf("%s",str[i]);
		for(int i=0;i<=n;++i)ans[i]=solve(n,i);
		for(int i=n;i;--i)if((ans[i]-=ans[i-1])<0)ans[i]+=mod;
		for(int i=0;i<=n;++i)printf("%d\n",ans[i]);
	}
	return 0;
}

方法3

方法2的基础上,我们注意到每一行前\(i-2\)个位置的DP值其实是不需要记录的,因为以它们为右下角不可能产生大小为\(i\)的正方形,并且后面的位置的DP值也与它们无关。

这样我们只需要记录\(n-i+2\)个位置的信息即可,当\(i\)很小的时候可以直接暴力。这样的状态数量更少,运行效率更高。