题目(BZOJ)
题目(LibreOJ)
来说说这题吧
首先肯定要求 $s$ 到 $t$ 的最短路径条数,我们就根据这个来做。
一对 $(A,B)$ 满足条件,则 $s$ 到 $A$ 到 $t$ 的路径条数$+s$ 到 $B$ 到 $t$ 的路径条数$==s$到 $t$ 的路径条数,这个很显然,
于是我们设 $sum[0][i]$ 表示从 $s$ 出发到i的最短路径条数,$sum[1][i]$ 表示从 $t$ 出发到 $i$ 的最短路径条数,
则条件一转化为: $sum[0][A]\times sum[1][A]+sum[0][B]\times sum[1][B]=sum[0][t]$ ,
移一下项:$sum[0][A]\times sum[1][A]=sum[0][t]-sum[0][B]\times sum[1][B]$,
所以我们将 $sum[0][A]\times sum[1][A]$ 作为状态存入一个 map 中,也就相当于把 $sum[0][t]-sm[0][B]*sum[1][B]$ 存入 map 中。
同时,我们再开一个 $g$ 数组,表示 $A$ 与 $B$ 相互可达性。
然后就做完了。
总而言之,算法流程如下:
- 正反各一次求一遍最短路,更新 $sum$ 数组;
- 预处理满足 $sum[0][i]+sum[1][i]==sum[0][t]$ 的点,放到 $p$ 数组中,并且预处理这个 map;
- 对 $p$ 数组中的点正反各一次拓扑排序,更新 $g$ 数组;
- 最后做个统计,完毕。
不过这题有 $s$ 和 $t$ 不联通的情况......直接输出 $\frac{n(n-1)}{2}$ 即可。
#include <bits/stdc++.h> using namespace std; typedef long long ll; const int N=50050; const ll inf=1e15+5; int n,m,s,t; int p[N],cnt; int sta[N],top,in[N],pos[N]; map <ll,bitset<N> > mp; bitset <N> g[2][N]; struct edge{ int vk,wk,nxt; }e[N<<1]; int head[N],tt; inline void ad(int u,int v,int w){e[++tt].vk=v;e[tt].wk=w;e[tt].nxt=head[u];head[u]=tt;} int vis[N];ll dis[2][N]; ll sum[2][N],ans; queue <int> q; void sp(int k,int S,int T){ sum[k][S]=1; memset(vis,0,sizeof vis); for (int i=1;i<=n;++i) dis[k][i]=inf; dis[k][S]=0;vis[S]=1;q.push(S); while(!q.empty()){ int u=q.front();q.pop();vis[u]=0; for (int i=head[u];i;i=e[i].nxt){ int v=e[i].vk; if (dis[k][v]>dis[k][u]+e[i].wk){ sum[k][v]=sum[k][u]; dis[k][v]=dis[k][u]+e[i].wk; if (!vis[v]) vis[v]=1,q.push(v); } else if (dis[k][v]==dis[k][u]+e[i].wk) sum[k][v]+=sum[k][u]; } } if (dis[k][T]==inf) printf( } void topsort(int k){ top=0;memset(in,0,sizeof in); for (int i=1;i<=cnt;++i){ int u=p[i]; for (int j=head[u];j;j=e[j].nxt){ int v=e[j].vk; if (vis[v]&&dis[k][u]+e[j].wk==dis[k][v]) ++in[v]; } } for (int i=1;i<=cnt;++i) if (!in[p[i]]) sta[++top]=p[i],pos[p[i]]=top; for (int h=1;h<=top;++h){ int u=sta[h]; for (int i=head[u];i;i=e[i].nxt){ int v=e[i].vk; if (vis[v]&&dis[k][u]+e[i].wk==dis[k][v]){ in[v]--; if (!in[v]) sta[++top]=v,pos[v]=top; } } } for (int i=1;i<=cnt;++i) g[k][p[i]][p[i]-1]=1; for (int i=top;i;--i){ int u=sta[i]; for (int j=head[u];j;j=e[j].nxt){ int v=e[j].vk; if (vis[v]&&(dis[k][u]+e[j].wk==dis[k][v])&&pos[u]<pos[v]) g[k][u]|=g[k][v]; } } } void init(){ memset(vis,0,sizeof vis); for (int i=1;i<=n;++i) if (dis[0][i]+dis[1][i]==dis[0][t]) p[++cnt]=i,vis[i]=1; for (int i=1;i<=cnt;++i) mp[sum[0][p[i]]*sum[1][p[i]]]|=1<<(p[i]-1); } void solve(){ topsort(0); topsort(1); for (int i=1;i<=cnt;++i){ int u=p[i]; ans+=((mp[sum[0][t]-sum[0][u]*sum[1][u]]>>(i-1))& (~g[0][u]>>(i-1))&(~g[1][u]>>(i-1))).count(); } ll tmp=0; for (int i=1;i<=cnt;++i) if (sum[0][p[i]]*sum[1][p[i]]==sum[0][t]) ++tmp; ans+=tmp*(n-cnt); printf( } int main(){ scanf( for (int i=1;i<=m;++i){ int u,v,w;scanf( ad(u,v,w),ad(v,u,w); } sp(0,s,t);sp(1,t,s); init(); solve(); return 0; }
Comments NOTHING