[bzoj5109]-[CodePlus 2017]大吉大利,晚上吃鸡!

吃鸡专用(雾)


题目(BZOJ)
题目(LibreOJ)
来说说这题吧
首先肯定要求 $s$ 到 $t$ 的最短路径条数,我们就根据这个来做。
一对 $(A,B)$ 满足条件,则 $s$ 到 $A$ 到 $t$ 的路径条数$+s$ 到 $B$ 到 $t$ 的路径条数$==s$到 $t$ 的路径条数,这个很显然,
于是我们设 $sum[0][i]$ 表示从 $s$ 出发到i的最短路径条数,$sum[1][i]$ 表示从 $t$ 出发到 $i$ 的最短路径条数,
则条件一转化为: $sum[0][A]\times sum[1][A]+sum[0][B]\times sum[1][B]=sum[0][t]$ ,
移一下项:$sum[0][A]\times sum[1][A]=sum[0][t]-sum[0][B]\times sum[1][B]$,
所以我们将 $sum[0][A]\times sum[1][A]$ 作为状态存入一个 map 中,也就相当于把    $sum[0][t]-sm[0][B]*sum[1][B]$ 存入 map 中。
同时,我们再开一个 $g$ 数组,表示 $A$ 与 $B$ 相互可达性。
然后就做完了。
总而言之,算法流程如下:

  1. 正反各一次求一遍最短路,更新 $sum$ 数组;
  2. 预处理满足 $sum[0][i]+sum[1][i]==sum[0][t]$ 的点,放到 $p$ 数组中,并且预处理这个 map;
  3. 对 $p$ 数组中的点正反各一次拓扑排序,更新 $g$ 数组;
  4. 最后做个统计,完毕。

不过这题有 $s$ 和 $t$ 不联通的情况……直接输出 $\frac{n(n-1)}{2}$ 即可。

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=50050;
const ll inf=1e15+5;
int n,m,s,t;
int p[N],cnt;
int sta[N],top,in[N],pos[N];
map <ll,bitset<N> > mp;
bitset <N> g[2][N];
struct edge{
	int vk,wk,nxt;
}e[N<<1];
int head[N],tt;
inline void ad(int u,int v,int w){e[++tt].vk=v;e[tt].wk=w;e[tt].nxt=head[u];head[u]=tt;}
int vis[N];ll dis[2][N];
ll sum[2][N],ans;
queue <int> q;
void sp(int k,int S,int T){
	sum[k][S]=1;
	memset(vis,0,sizeof vis);
	for (int i=1;i<=n;++i) dis[k][i]=inf;
	dis[k][S]=0;vis[S]=1;q.push(S);
	while(!q.empty()){
		int u=q.front();q.pop();vis[u]=0;
		for (int i=head[u];i;i=e[i].nxt){
			int v=e[i].vk;
			if (dis[k][v]>dis[k][u]+e[i].wk){
				sum[k][v]=sum[k][u];
				dis[k][v]=dis[k][u]+e[i].wk;
				if (!vis[v]) vis[v]=1,q.push(v);
			}
			else if (dis[k][v]==dis[k][u]+e[i].wk) sum[k][v]+=sum[k][u];
		}
	}
	if (dis[k][T]==inf) printf("%lld",(ll)n*(n-1)/2),exit(0);
}
void topsort(int k){
	top=0;memset(in,0,sizeof in);
	for (int i=1;i<=cnt;++i){
		int u=p[i];
		for (int j=head[u];j;j=e[j].nxt){
			int v=e[j].vk;
			if (vis[v]&&dis[k][u]+e[j].wk==dis[k][v]) ++in[v];
		}
	}
	for (int i=1;i<=cnt;++i) if (!in[p[i]]) sta[++top]=p[i],pos[p[i]]=top;
	for (int h=1;h<=top;++h){
		int u=sta[h];
		for (int i=head[u];i;i=e[i].nxt){
			int v=e[i].vk;
			if (vis[v]&&dis[k][u]+e[i].wk==dis[k][v]){
				in[v]--;
				if (!in[v]) sta[++top]=v,pos[v]=top;
			}
		}
	}
	for (int i=1;i<=cnt;++i) g[k][p[i]][p[i]-1]=1;
	for (int i=top;i;--i){
		int u=sta[i];
		for (int j=head[u];j;j=e[j].nxt){
			int v=e[j].vk;
			if (vis[v]&&(dis[k][u]+e[j].wk==dis[k][v])&&pos[u]<pos[v])
				g[k][u]|=g[k][v];
		}
	}
}
void init(){
	memset(vis,0,sizeof vis);
	for (int i=1;i<=n;++i)
		if (dis[0][i]+dis[1][i]==dis[0][t])
			p[++cnt]=i,vis[i]=1;
	for (int i=1;i<=cnt;++i) mp[sum[0][p[i]]*sum[1][p[i]]]|=1<<(p[i]-1);
}
void solve(){
	topsort(0);
	topsort(1);
	for (int i=1;i<=cnt;++i){
		int u=p[i];
		ans+=((mp[sum[0][t]-sum[0][u]*sum[1][u]]>>(i-1))&
			   (~g[0][u]>>(i-1))&(~g[1][u]>>(i-1))).count();
	}
	ll tmp=0;
	for (int i=1;i<=cnt;++i)
		if (sum[0][p[i]]*sum[1][p[i]]==sum[0][t]) ++tmp;
	ans+=tmp*(n-cnt);
	printf("%lld",ans);
}
int main(){
	scanf("%d%d%d%d",&n,&m,&s,&t);
	for (int i=1;i<=m;++i){
		int u,v,w;scanf("%d%d%d",&u,&v,&w);
		ad(u,v,w),ad(v,u,w);
	}
	sp(0,s,t);sp(1,t,s);
	init();
	solve();
	return 0;
}

发表评论

您的电子邮箱地址不会被公开。 必填项已用*标注