Loj#2542.「PKUWC 2018」随机游走 | wzf2000's blog

Loj#2542.「PKUWC 2018」随机游走

Loj#2542.「PKUWC 2018」随机游走 题解

题意:

  • $n$ 个节点的树,$Q$ 次询问从 $x$ 随机游走完一个集合的点期望步数。
  • $n\le 18,Q\le 5000$。

题解:

  • 听说这题题解挺多的。
  • 按我这种写法其实可以出到 $n\le 20,Q\le 100000$。
  • 以下方法完全来自于 $\rm wanglichao1121$ 的题解,然而他好像忘记上传了,所以我直接丢上来了。
  • (所以看到的可以去这里膜一膜)

随机

  • 不妨设 $X_i$ 为代表 $i$ 号点第一次被走到的时间的一个随机变量(不是期望!!!而且不是独立的)。
  • 那么就可以列出一个一看就大有可为的柿子:
  • 使用陈老师高超的 $\min\max$ 技巧容斥一波,先抛结论:
  • 冷静观察一下,发现 $\max$ 变成了 $\min$。
  • 这也正是 $\min\max$ 容斥的核心,为了不影响思维连贯性,这个感性证明将放到最后。
    • $\rm Q$:为什么要枚举一层 $T$ 来让 $\max$ 改为 $\min$ 呢?
    • $\rm A$:回到原题中会发现若干个点到达时间最小值的期望就是游走时碰到第一个点就停的期望步数一看就好做多了啊!也就是说我们只要算出这个东西就可以 $O(2^k)$ 询问啦
  • 然后一波 $O(2^nn)$ 子集变幻一下就能预处理出所有答案,做到 $O(1)$ 询问。
    • $\rm Q$:所以这个东西咋求啊?
    • $\rm A$:这就是下一套理论了。

游走

  • 首先枚举点集。
  • 考虑现在剩下的问题。
  • 你在树上游走,遇到一个特殊点就结束,问期望步数。
  • 我们可以列出方程:
  • 于是我们已经有了一个暴力高斯消元的 $O(2^nn^3)$ 的做法了。
  • 然而考虑到树的独特结构,对于一般结点(非特殊点且非叶子结点)的那个方程,我们还能进行改写:
  • 如果考虑在树上跑 $\rm dfs$,每次返回当前节点关于它父亲节点的一次函数,就可以 $O(n)$ 跑出答案啦!

证明

  • 你要的 $\min\max$ 感性证明在这O(∩_∩)O~。
  • 先重写一下要证的柿子:
  • 其中 $X_i$ 是一堆随机变量,这里并没有要求独立。
  • 我们不妨假设这些随机变量其实对应随机的集合。容斥套容斥!
  • 对于每一个 $X_i$ 的任意取值,我们构造形如 $\{1,2,3,…,X_i\}$ 的集合 $Y_i$ 。
  • 这么规定就让 $\min$ 和 $\max$ 很容易变形。
  • 容易发现:
  • 多元素的 $\min$ 和 $\max$ 同理,而且可以发现这里的对应也就意味着对应集合的规模为对应变量的大小。
  • 那么我们观察我们要证的东西,就等价于:
  • 这不就是经典容斥的柿子套上一层期望的线性性吗!

代码:

#include <bits/stdc++.h>
#define gc getchar()
using namespace std;
typedef long long ll;
const int mod=998244353;
const int N=20;
int n,m,st,first[N],number,Ans[N],sp[N],f[1<<N];
struct edge
{
    int to,next;
    void add(int x,int y)
    {
        to=y,next=first[x],first[x]=number;
    }
}e[N<<1];
int ksm(int x,int y,int ret=1)
{
    for (;y;y>>=1,x=(ll)x*x%mod)
        if (y&1) ret=(ll)ret*x%mod;
    return ret;
}
struct node
{
    int a,b;
    node(int a=0,int b=0):a(a),b(b){};
    friend node operator +(const node &A,const node &B)
    {
        return node((A.a+B.a)%mod,(A.b+B.b)%mod);
    }
    friend node operator /(const node &A,const int &B)
    {
        int x=ksm(B,mod-2);
        return node((ll)A.a*x%mod,(ll)A.b*x%mod);
    }
}dp[N];
int read()
{
    int x=1;
    char ch;
    while (ch=gc,ch<'0'||ch>'9') if (ch=='-') x=-1;
    int s=ch-'0';
    while (ch=gc,ch>='0'&&ch<='9') s=s*10+ch-'0';
    return s*x;
}
void dfs(int x,int last)
{
    int d=(last>0);
    dp[x]=node(0,0);
    if (sp[x]) return;
    for (int i=first[x];i;i=e[i].next)
        if (e[i].to!=last)
        {
            dfs(e[i].to,x);
            dp[x]=dp[x]+dp[e[i].to];
            d++;
        }
    if (d==1&&last)
    {
        dp[x]=node(1,1);
        return;
    }
    dp[x]=dp[x]/d;
    dp[x]=node(ksm((ll)d*(mod+1-dp[x].a)%mod,mod-2),(ll)(dp[x].b+1)*ksm(mod+1-dp[x].a,mod-2)%mod);
}
void dft(int *a,int n)
{
    for (int i=0;i<n;i++)
        for (int j=0;j<(1<<n);j++)
            if (j>>i&1) a[j]=(a[j]+a[j^(1<<i)])%mod;
}
int main()
{
    n=read(),m=read(),st=read();
    for (int i=1;i<n;i++)
    {
        int x=read(),y=read();
        e[++number].add(x,y),e[++number].add(y,x);
    }
    for (int i=1;i<(1<<n);i++)
    {
        int cnt=-1;
        for (int j=1;j<=n;j++)
            if (i>>(j-1)&1) sp[j]=1,cnt++;
            else sp[j]=0;
        dfs(st,0);
        f[i]=cnt&1?(mod-dp[st].b)%mod:dp[st].b;
    }
    dft(f,n);
    while (m--)
    {
        int k=read(),now=0;
        while (k--) now|=1<<(read()-1);
        printf("%d\n",f[now]);
    }
    return 0;
}