kacho65535の競プロメモ

Atcoderと戯れる予定のブログです

yukicoder:No.1103 Directed Length Sum

めちゃくちゃ教育的で面白い問題だった!

問題概要


yukicoder.me

N 頂点の有向木がある。

この有向木は根以外の各頂点の親からその頂点へ 1 本の有向辺が伸びた構造をしている。

f(a,b) を「頂点 a から頂点 b に移動するとき経由する辺の数」とするとき、


\displaystyle
\sum_{i=1}^{N} \sum_{j=1}^{N} f(i,j)
mod 10^{9}+7 を求めよ。

但し f(a,a)=0, f(a,b)=0 ( a から b へのパスが存在しないとき)とする。

制約

  • 3 \leq N \leq 10^{6}

解法


N が大きいのでまず愚直 O(N^{2}) は通らない。

この有向木は根以外の各頂点の親からその頂点へ 1 本の有向辺が伸びた構造をしている。(意訳)という文言が重要そうである。

なぜなら a から b へのパスが存在しないとき f(a,b)=0 なので、根からとある葉へのパスに異なる頂点 ab が両方とも存在していて a が根に近いときにのみ f(a,b)0 でなくなるという言い換えができるからである。

この言い換えをすることによって何が嬉しいのかというと、前回解説した問題 yukicoder:No.1096 Range Sums - kacho65535の競プロメモ でも用いた考察典型である各要素が何回足されるかを考える問題に帰着できるからである。 しかも、足される回数の求め方も前回の問題と本質的に同じである。

以上の考察から次のことが分かる。

ある有向辺 a \to b が足される回数は、根から a までのパスに含まれる頂点の個数(根や a も含む)を x 、この有向辺を取り除いたときにできる b を含む部分木の頂点数を y として、 x×y 回である。

よってすべての有向辺に対する x×y の総和が答えになる。(辺のコストが 1 なので)

これを求めるには、各 x を根からのBFSで求めた後、木dpで部分木のサイズを記録していけばよい。

コード


#include <bits/stdc++.h>
using namespace std;
using ll = long long;
#define mod 1000000007ll
#define loop(i, n) for (int i = 0; i < n; i++)
#define all(v) v.begin(), v.end()
#define putout(a) cout << a << '\n'
#define Sum(v) accumulate(all(v), 0ll)
void bfs(vector<vector<ll>> &G, vector<ll> &dist, ll s)
{
    queue<ll> que;
    dist[s] = 0;
    que.push(s);
    while (!que.empty())
    {
        ll v = que.front();
        que.pop();
        for (auto next_v : G[v])
        {
            if (dist[next_v] != -1)
                continue;
            dist[next_v] = dist[v] + 1;
            que.push(next_v);
        }
    }
}
vector<ll> sz;
void dfs(vector<vector<ll>> &G, ll v, ll p)
{
    for (auto nv : G[v])
    {
        if (nv == p)
            continue;
        dfs(G, nv, v);
    }
    sz[v] = 1;
    for (auto nv : G[v])
    {
        if (nv == p)
            continue;
        sz[v] += sz[nv];
    }
}
int main()
{
    ll N;
    cin >> N;
    vector<vector<ll>> G(N);
    ll root = 0; //bとして1度も出現しないのが根
    sz.resize(N);
    vector<ll> A(N - 1), B(N - 1), x(N - 1), y(N - 1), used(N, false);
    loop(i, N - 1)
    {
        cin >> A[i] >> B[i];
        A[i]--;
        B[i]--;
        used[B[i]] = true;
        G[A[i]].emplace_back(B[i]);
    }
    loop(i, N) if (!used[i]) root = i;
    vector<ll> dist(N, -1);
    bfs(G, dist, root);
    dfs(G, root, -1);
    loop(i, N - 1) x[i] = dist[B[i]]; //dist[A[i]]+1でもよい
    loop(i, N - 1) y[i] = sz[B[i]];
    ll ans = 0;
    loop(i, N - 1)
    {
        ans += x[i] * y[i] % mod;
        ans %= mod;
    }
    putout(ans);
    return 0;
}