这题实际上比较套路......
注意到 $n\le 50000$,大概率是想让我们用个 $\mathcal{O}(n\sqrt{n})$ 的复杂度。想到根号复杂度的话就很有可能是个根号分治。这题也不例外。
我们按照 $k$ 的大小分成两类:假设一个界为 $B=200$,当 $k\ge B$ 的时候就可以直接暴力模拟这个跳的过程,当 $k<B$ 的时候就可以考虑用什么东西来维护。
先说简单的部分:暴力向上跳。我们可以将路径 $x\to y$ 拆成 $x\to lca,lca \to y$。这样的话,我们需要距离 $x$ 为 $k$ 的祖先是哪一个(也就是 $x$ 的 $k$ 级祖先)。这一部分我们可以用长链剖分来搞,虽然慢了点但还是能过。这样的话,就可以将路径看成 $x$ 到 $lca$,$y$ 到 $lca$。
注意一个细节:$y$ 需要先往上跳一小段再暴力跳到 $lca$,因为题目中有说 “距离小于 $k$ 的话就直接一步跳到 $y$”。这样的话,因为 $y$ 往上跳与实际路径方向相反,所以需要先处理这一小段。
之后再说 $k<B$ 的部分。设 $s(i,j)$ 表示从 $i$ 开始每次跳 $j$ 步一直跳到根节点为止的路径上的点权和(当然也有可能跳不到根节点,但这个其实无所谓,因为这个时候再跳 $j$ 步的话已经跳出去了)。这一步有一个递推式子:设 $f(i,j)$ 表示 $i$ 的 $j$ 级祖先是哪个点,那么显然我们就有 $s(i,j)=s(f(i,j),j)+a_i$($a_i$ 为点权),注意 $f(i,j)$ 可能不存在,这个时候就是 $s(i,j)=a_i$。
之后,我们可以算出 $s(x,k)$ 并减去在 $lca(x,y)$ 上方的部分贡献,同理也可以算出 $s(y,k)$ 并减去在 $lca(x,y)$ 上方的部分贡献。
综上,这个题就做完了。
但是这题细节是真的超多......
#include <bits/stdc++.h> #define ll long long #define ls id << 1 #define rs id << 1 | 1 #define mem(array, value, size, type) memset(array, value, ((size) + 5) * sizeof(type)) #define memarray(array, value) memset(array, value, sizeof(array)) #define pb(x) push_back(x) #define st(x) (1LL << (x)) #define pii pair<int, int> #define mp(a, b) make_pair((a), (b)) #define Flush fflush(stdout) #define vecfirst (*vec.begin()) #define veclast (*vec.rbegin()) #define vecall(v) (v).begin(), (v).end() #define vecupsort(v) (sort((v).begin(), (v).end())) #define vecdownsort(v, type) (sort(vecall(v), greater<type>())) #define veccmpsort(v, cmp) (sort(vecall(v), cmp)) using namespace std; const int N = 50050; const int inf = 0x3f3f3f3f; const ll llinf = 0x3f3f3f3f3f3f3f3f; const int mod = 998244353; const int MOD = 1e9 + 7; const double PI = acos(-1.0); clock_t TIME__START, TIME__END; void program_end() { #ifdef ONLINE printf("\nTime used: system("pause"); #endif } const int M = 201; int s[N][M], a[N]; int n; int b[N], d[N], f[N][16], son[N], dep[N], top[N], g[N]; vector<int> e[N], u[N], v[N]; void dfs(int x, int fa) { dep[x] = d[x] = d[fa] + 1; f[x][0] = fa; for (int i = 1; (1 << i) <= d[x]; i++) f[x][i] = f[f[x][i - 1]][i - 1]; for (auto y : e[x]) { if (y == fa) continue; dfs(y, x); if (dep[y] > dep[x]) dep[x] = dep[y], son[x] = y; } } inline int getlca(int a, int b) { if (d[a] > d[b]) swap(a, b); for (int i = 15; i >= 0; i--) if (d[a] <= d[b] - (1 << i)) b = f[b][i]; if (a == b) return a; for (int i = 15; i >= 0; i--) { if (f[a][i] == f[b][i]) continue; else a = f[a][i], b = f[b][i]; } return f[a][0]; } inline int getdis(int x, int y) { return d[x] + d[y] - 2 * d[getlca(x, y)]; } void dfs(int x, int p, int fa) { top[x] = p; if (x == p) { for (int i = 0, o = x; i <= dep[x] - d[x]; ++i) u[x].pb(o), o = f[o][0]; for (int i = 0, o = x; i <= dep[x] - d[x]; ++i) v[x].pb(o), o = son[o]; } if (son[x]) dfs(son[x], p, f[son[x]][0]); for (auto y : e[x]) if (y != son[x] && y != fa) dfs(y, y, x); } inline int ask(int x, int k) { if (!k) return x; x = f[x][g[k]], k -= 1 << g[k], k -= d[x] - d[top[x]], x = top[x]; return k >= 0 ? (k < (int)u[x].size() ? u[x][k] : -1) : (k < (int)v[x].size() ? v[x][-k] : -1); } void dfs2(int x, int fa) { for (int j = 1; j < M; ++j) { if (ask(x, j) == -1) s[x][j] = a[x]; else s[x][j] = s[ask(x, j)][j] + a[x]; } for (auto vv : e[x]) { if (vv == fa) continue; dfs2(vv, x); } } inline int queryBig(int x, int y, int k) { int resx = 0, resy = 0; int lca = getlca(x, y), D = getdis(x, y); int dx = getdis(x, lca), dy = getdis(y, lca); if (lca == x) { int ry = D if (ry) resy += a[y]; y = ask(y, ry); while (~y && d[y] >= d[lca]) resy += a[y], y = ask(y, k); return resy; } while (~x && d[x] > d[lca]) resx += a[x], x = ask(x, k); if (lca == y) return resx + a[y]; int ry = D if (ry) resy += a[y]; y = ask(y, ry); while (~y && d[y] > d[lca]) resy += a[y], y = ask(y, k); if (dx resx += a[lca]; return resx + resy; } inline int querySmall(int x, int y, int k) { int resx = 0, resy = 0; int lca = getlca(x, y), D = getdis(x, y); int ry = D if (ry) resy += a[y]; if (lca == x) { y = ask(y, ry); int lcay = ask(lca, k); resy += (y == -1 ? 0 : s[y][k]) - (lcay == -1 ? 0 : s[lcay][k]); return resy; } if (lca == y) { int lcax = ask(lca, k - D resx += s[x][k] - (lcax == -1 ? 0 : s[lcax][k]); return resx; } int dx = getdis(x, lca); int lcax = ask(lca, k - dx resx += s[x][k] - (lcax == -1 ? 0 : s[lcax][k]); y = ask(y, ry); int lcay = ask(lca, dx resy += (y == -1 ? 0 : s[y][k]) - (lcay == -1 ? 0 : s[lcay][k]); return resx + resy; } inline void solve() { scanf( g[0] = -1; for (int i = 1; i <= n; ++i) g[i] = g[i >> 1] + 1, scanf( for (int i = 1; i < n; ++i) { int x, y; scanf( e[x].push_back(y), e[y].push_back(x); } dfs(1, 0); dfs(1, 1, 0); dfs2(1, 0); for (int i = 1; i <= n; ++i) scanf( int uu, vv, k; for (int i = 2; i <= n; ++i) { scanf( int ans = 0; if (k >= M) ans = queryBig(b[i - 1], b[i], k); else ans = querySmall(b[i - 1], b[i], k); printf( } } int main() { TIME__START = clock(); int Test = 1; // scanf( while (Test--) solve(); TIME__END = clock(); program_end(); return 0; }
Comments NOTHING