设 $f(n)$ 表示根节点的数字恰好为 $n$ 的方案数,那么只需要枚举左右子树根节点的数字对应的方案数,乘起来再全部加起来即可。即

$$f(n)=\sum\limits_{d|n}f(d)f(\frac{n}{d}) + [\exists k | n]$$

其中 $[\exists k | n]$ 表示:给定的 $m$ 个数中只要有一个 $k$ 整除 $n$ 就取 $1$,否则取 $0$。

接下来,设 $F(n)=\sum\limits_{i=1}^n f(i)$,那么我们要求的答案就是 $F(n)$,试着把上面的式子代入化简一下?

$$
\begin{aligned}
F(n) &= \sum\limits_{i=1}^n f(i) \\
&= \sum\limits_{i=1}^n([\exists k|i] + \sum\limits_{d|i}f(d)f(\frac{i}{d}) ) \\
&=\sum\limits_{i=1}^n\sum\limits_{d|i}f(d)f(\frac{i}{d}) +\sum\limits_{i=1}^n[\exists k|i]
\end{aligned}
$$

上述式子后面这一坨可以 $2^m$ 容斥随便搞搞,接下来看前面这一坨:
中间有一步用 $i=td$ 代换

$$
\begin{aligned}
\sum\limits_{i=1}^n\sum\limits_{d|i}f(d)f(\frac{i}{d})
&=
\sum\limits_{d=1}^n\sum\limits_{d|i,\,i\le n}f(d)f(\frac{i}{d}) \\
&= \sum\limits_{d=1}^nf(d)\sum\limits_{d|i,\,i\le n}f(\frac{i}{d}) \\
&= \sum\limits_{d=1}^nf(d)\sum\limits_{t=1}^{\lfloor n/d \rfloor}f(t) \\
&=
\sum\limits_{d=1}^nf(d)F(\lfloor n/d \rfloor)
\end{aligned}
$$

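注意到数论分块的时候,后面每一块 $\sum\limits_{d=l}^r f(d)$ 也就是 $F(r)-F(l-1)$,直接计算即可。另外代码里 $f(1)=F(1)=0$,所以 $d=1$ 的那一项 $f(1)F(n)$ 以及 $d>\lfloor n/2 \rfloor$ 的那些项(此时 $F(\lfloor n/d \rfloor)=F(1)$)都是 $0$:递归式右边并不会真的用到 $F(n)$ 本身,分块从 $d=2$ 开始枚举、到 $r=n$ 的那一块停下即可。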

总的来说这个差不多就是杜教筛的套路?复杂度的话题解给了个 $\mathcal{O}\left((n\log n)^{\frac{2}{3}}\right)$,嘛我也不太会算,但我这个代码常数倒是超大......

#include <bits/stdc++.h>
// ---- Contest-template shorthands. Most are not used by the code visible in this file. ----

// Shorthand for the 64-bit signed integer type.
#define ll long long
// Expand to `id << 1` and `id << 1 | 1` for whatever variable is named `id` at the use site
// (presumably child indices in a heap-style tree layout — TODO confirm, unused here).
// The expansions are not parenthesised, so wrap them when mixing with other operators.
#define ls id << 1
#define rs id << 1 | 1
// memset over the first (size + 5) elements of `array`, each of `type`.
#define mem(array, value, size, type) memset(array, value, ((size) + 5) * sizeof(type))
// memset over a whole array; relies on sizeof(array), so `array` must be a true array, not a pointer.
#define memarray(array, value) memset(array, value, sizeof(array))
// std::fill over the inclusive index range [begin, end] of a raw array.
#define fillarray(array, value, begin, end) fill((array) + (begin), (array) + (end) + 1, value)
// std::fill over an entire container.
#define fillvector(v, value) fill((v).begin(), (v).end(), value)
// Expands to `push_back(x)`; meant to be written after a `.` on a container.
#define pb(x) push_back(x)
// 2^x as a long long (used below to enumerate subsets: s < st(m)).
#define st(x) (1LL << (x))
#define pii pair<int, int>
#define mp(a, b) make_pair((a), (b))
// Flushes stdout.
#define Flush fflush(stdout)
// First / last element of a container that must literally be named `vec` at the use site.
#define vecfirst (*vec.begin())
#define veclast (*vec.rbegin())
// Expands to the iterator pair `begin, end` of v, for passing to algorithms.
#define vecall(v) (v).begin(), (v).end()
// Sort ascending / descending (by greater<type>) / with a custom comparator.
#define vecupsort(v) (sort((v).begin(), (v).end()))
#define vecdownsort(v, type) (sort(vecall(v), greater<type>()))
#define veccmpsort(v, cmp) (sort(vecall(v), cmp))
using namespace std;
// Template constants. Of these, only `mod` and the two clock variables are
// referenced by the code visible in this file.
const int N = 1000050;
const int inf = 0x3f3f3f3f;
const ll llinf = 0x3f3f3f3f3f3f3f3f;
// Modulus used by add / sub (and hence by every count computed below).
const int mod = 998244353;
const int MOD = 1e9 + 7;
const double PI = acos(-1.0);
// Wall-clock markers: main() sets TIME__START on entry and TIME__END just
// before calling program_end().
clock_t TIME__START, TIME__END;
/// Debug-only epilogue, active only when the `ONLINE` macro is defined:
/// prints the elapsed time between TIME__START and TIME__END and then waits
/// for a key press. Without `ONLINE` the function does nothing.
///
/// NOTE(review): the original printf line was cut off right after
/// "Time used: " (everything from the first '%' was lost). The seconds-based
/// format and argument below are a reconstruction — adjust if the original
/// wording/unit matters.
void program_end()
{
#ifdef ONLINE
    printf("\n\nTime used: %.3f s\n",
           static_cast<double>(TIME__END - TIME__START) / CLOCKS_PER_SEC);
    // `pause` is a Windows shell command; it only keeps a console window open.
    system("pause");
#endif
}
/// Returns a + b, reduced by one `mod` if the sum reaches `mod`.
/// With both inputs in [0, mod) the result is again in [0, mod).
int add(int a, int b)
{
    const int total = a + b;
    if (total >= mod)
        return total - mod;
    return total;
}
/// Returns a - b, lifted by one `mod` if the difference is negative.
/// With both inputs in [0, mod) the result is again in [0, mod).
int sub(int a, int b)
{
    const int diff = a - b;
    if (diff < 0)
        return diff + mod;
    return diff;
}
int mul(long long a, long long b) { return add(a * b
/// Greatest common divisor by the Euclidean algorithm; gcd(a, 0) == a.
/// (The original line was cut off after `gcd(b, a`; the missing `% b` step
/// is restored here, written as a loop instead of a recursion.)
int gcd(int a, int b)
{
    while (b != 0)
    {
        const int r = a % b;
        a = b;
        b = r;
    }
    return a;
}
// Values up to MAXN are precomputed by precalcf(); sum_f() answers any
// argument <= MAXN straight from the table f[].
const int MAXN = 300050;
// vecfac[j] lists the divisors of j (filled once in main()).
vector<int> vecfac[MAXN + 5];

// n is the bound passed to sum_f(); m is the number of entries used in k[]
// and the number of bits in the subset masks indexing sk[].
int n, m;
// k[0..m-1]: presumably the given divisors from the input (the line that
// reads them is truncated in this copy).
// sk[s]: lcm of the k[i] whose bit i is set in mask s (built in solve()).
// f[i]: first the per-value count f(i), then — after precalcf() finishes —
// the prefix sum F(i), both modulo `mod`.
// The array sizes (5 and 20) only fit masks below 20, i.e. m <= 4.
int k[5], sk[20], f[MAXN + 5];
// Memo table for sum_f(): argument -> F(argument) modulo `mod`.
unordered_map<int, int> mapF;
int sum_p(int n)
{
    int ret = 0;
    for (int s = 1; s < st(m); ++s)
    {
        int flag = ((__builtin_popcount(s) & 1) ? 1 : -1);
        int res = n / sk[s];
        res = mul(res, flag);
        ret = add(ret, res);
    }
    return ret;
}
/// Returns the prefix sum F(n) = f(1) + ... + f(n) modulo `mod`.
///
/// Arguments up to MAXN come from the table f[] (which holds prefix sums
/// once precalcf() has run); larger ones are computed once and memoised in
/// mapF using
///     F(n) = sum_p(n) + sum over blocks [lo, hi] of
///            (F(hi) - F(lo - 1)) * F(n / lo),
/// where every d in a block shares the same quotient n / d.
///
/// @param n argument of F (shadows the global `n`).
int sum_f(int n)
{
    if (n <= MAXN)
        return f[n];

    const auto cached = mapF.find(n);
    if (cached != mapF.end())
        return cached->second;

    int total = sum_p(n);
    // Enumeration starts at d = 2: the d = 1 term is skipped.
    int lo = 2;
    while (lo < n)
    {
        const int quotient = n / lo;
        const int hi = n / quotient;
        // The last block (quotient 1, ending at n) is not added.
        if (hi == n)
            break;
        const int blockSum = sub(sum_f(hi), sum_f(lo - 1));
        const int inner = sum_f(quotient);
        total = add(total, mul(inner, blockSum));
        lo = hi + 1;
    }

    mapF[n] = total;
    return total;
}

void precalcf(int T)
{
    memarray(f, 0);
    mapF.clear(), mapF[0] = mapF[1] = 0;
    for (int i = 2, sum = 0; i <= T; ++i)
    {
        for (int j = 0; j < m; ++j)
            f[i] |= (i
        for (auto d : vecfac[i])
            f[i] = add(f[i], mul(f[d], f[i / d]));
        sum = add(sum, f[i]);
        mapF[i] = sum;
    }
    for (int i = 2; i <= T; ++i)
        f[i] = add(f[i], f[i - 1]);
}

inline void solve()
{
    scanf(
    for (int i = 0; i < m; ++i)
        scanf(
    for (int s = 1; s < st(m); ++s)
        sk[s] = 1;
    for (int s = 1; s < st(m); ++s)
        for (int i = 0; i < m; ++i)
            if ((s >> i) & 1)
                sk[s] = sk[s] / gcd(sk[s], k[i]) * k[i];
    precalcf(MAXN);
    int ans = sum_f(n);
    cout << ans << '\n';
}

int main()
{
    TIME__START = clock();
    for (int i = 1; i <= MAXN; ++i)
        for (int j = i; j <= MAXN; j += i)
            vecfac[j].push_back(i);
    for (int i = 1; i <= MAXN; ++i)
        vecupsort(vecfac[i]);
    int Test = 1;
    scanf(
    while (Test--)
    {
        solve();
        // if (Test)
        //     putchar('\n');
    }
    TIME__END = clock();
    program_end();
    return 0;
}