Loj#2538.「PKUWC 2018」Slay the Spire | wzf2000's blog

Loj#2538.「PKUWC 2018」Slay the Spire

Loj#2538.「PKUWC 2018」Slay the Spire 题解

题意:

  • 略。

题解:

  • 首先可以根据拿到几张强化几张攻击分类。
  • 之后就是计算类似于取 $i$ 张前 $j$ 大的各种方案的和。
  • 这个比较难直接算。
  • 考虑从大到小排序后,计算取了 $i$ 张,最后一张取了第 $j$ 大的总和为 $f[i][j]$ 和 $g[i][j]$。
  • $f$ 的转移:$f[i][j]=w[j]\times \dbinom{j-1}{i-1}+\sum\limits_{k=0}^{j-1}f[i-1][j]$
  • $g$ 的转移:$g[i][j]=w[j]\times \sum\limits_{k=0}^{j-1}g[i-1][j]$
  • 然后取原来的就可以表示为:$F[i][j]=\sum\limits_{k=0}^{n}f[j][k]\times \dbinom{n-k}{i-j}$
  • $G[i][j]=\sum\limits_{k=0}^{n}g[j][k]\times \dbinom{n-k}{i-j}$
  • 然后就可以计算了,时间复杂度 $O(n^2)$。

代码:

#include <bits/stdc++.h>
#define gc getchar()
using namespace std;
typedef long long ll;
const int mod=998244353;
const int N=3009;
int n,m,k,jc[N],inv[N],jc_inv[N],f[N][N],g[N][N],a[N],b[N];
int read()
{
    int x=1;
    char ch;
    while (ch=gc,ch<'0'||ch>'9') if (ch=='-') x=-1;
    int s=ch-'0';
    while (ch=gc,ch>='0'&&ch<='9') s=s*10+ch-'0';
    return s*x;
}
int C(int n,int m)
{
    if (n<m) return 0;
    return (ll)jc[n]*jc_inv[m]%mod*jc_inv[n-m]%mod;
}
int F(int x,int y)
{
    int ret=0;
    for (int i=0;i<=n;i++)
        ret=(ret+(ll)f[y][i]*C(n-i,x-y))%mod;
    return ret;
}
int G(int x,int y)
{
    int ret=0;
    for (int i=0;i<=n;i++)
        ret=(ret+(ll)g[y][i]*C(n-i,x-y))%mod;
    return ret;
}
int main()
{
    jc[0]=1;
    for (int i=1;i<N;i++) jc[i]=(ll)jc[i-1]*i%mod;
    inv[1]=1;
    for (int i=2;i<N;i++) inv[i]=(ll)(mod-mod/i)*inv[mod%i]%mod;
    jc_inv[0]=1;
    for (int i=1;i<N;i++) jc_inv[i]=(ll)jc_inv[i-1]*inv[i]%mod;
    int T=read();
    while (T--)
    {
        n=read(),m=read(),k=read();
        for (int i=1;i<=n;i++) b[i]=read();
        for (int i=1;i<=n;i++) a[i]=read();
        sort(a+1,a+n+1,greater<int>()),sort(b+1,b+n+1,greater<int>());
        memset(f,0,sizeof(f)),memset(g,0,sizeof(g));
        g[0][0]=1;
        for (int i=1;i<=n;i++)
        {
            int tmp1=0,tmp2=0;
            for (int j=0;j<=n;j++)
            {
                f[i][j]=((ll)a[j]*C(j-1,i-1)+tmp1)%mod;
                g[i][j]=(ll)b[j]*tmp2%mod;
                tmp1=(tmp1+f[i-1][j])%mod;
                tmp2=(tmp2+g[i-1][j])%mod;
            }
        }
        int Ans=0;
        for (int i=0;i<k;i++)
            Ans=(Ans+(ll)G(i,i)*F(m-i,k-i))%mod;
        for (int i=k;i<m;i++)
            Ans=(Ans+(ll)G(i,k-1)*F(m-i,1))%mod;
        printf("%d\n",Ans);
    }
    return 0;
}