[BJOI2018]治疗之雨

|

一道高消好题,能够加深你对高消的理解。

传送门

对于这种概率题,一个套路是直接设答案推式子。

设$f[i]$表示英雄有i滴血时期望的生存回合,设$p[i]$为掉i滴血的概率,可以看出$p[i]$是和当前血量无关的。
$$
f[i]=1+\sum_{j=1}^{i}f[i-j]\times p[j]\times\frac{m}{m+1}+f[i-j+1]\times p[j+1]\times\frac{1}{m+1}
$$
意思是在不被加血的情况下掉j滴血和加血的情况下掉j+1滴血。

同时,f[i]还可以得到f[i+1],即在i<n的情况下加了一滴血并不受伤:
$$
f[i]+=f[i+1]\times\frac{p[0]}{m+1}(i<n)
$$
我们来看p:
$$
p[i]=C_{k}^{i}\times(\frac{1}{m+1})^i\times(\frac{m}{m+1})^{k-i}
$$
把组合数展开:
$$
p[i]=\frac{k!}{i!\times (k-i)!}\times(\frac{1}{m+1})^i\times(\frac{m}{m+1})^{k-i}
$$
这个东西可以在预处理$\frac{1}{i!}$后n^2算,但有o(n)的做法:
$$
p[i]=\frac{k!}{(i-1)!\times(k-i-1)!}\times(\frac{1}{m+1})^{i-1}\times(\frac{m}{m+1})^{k-i+1}\times\frac{m\times(m+1)}{m+1}\times\frac{k-i+1}{i}
$$

$$
p[i]=p[i-1]\times\frac{1}{m}\times\frac{k-i+1}{i}
$$
初始化$p[0]=(\frac{m}{m+1})^k$

这样我们就可以根据前面f[i]的方程列出一个高斯消元的矩阵,得到一个$O(n^3)$的做法,获得70分的好成绩。

这个矩阵由于前面f[i]只能从不比他大的和f[i+1]转移到所以是这样:

$a[1][1]$ $a[1][2]$ 0 0 -1
$a[2][1]$ $a[2][2]$ $a[2][3]$ 0 -1
$a[3][1]$ $a[3][2]$ $a[3][3]$ $a[3][4]$ -1
$a[4][1]$ $a[4][2]$ $a[4][3]$ $a[4][4]$ -1

除了最后一行以外,其他行对角线右边只有一个元素,这样我们在对第i行消元的时候可以只消掉第i列,最后一列只会剩下一项,由此可以解出前面所有的答案。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
#include<cstdio>
#include<cstring>
#include<iostream>
#define re register
#define ll long long
#define mod 1000000007
#ifdef ONLINE_JUDGE
char ss[1<<17],*A=ss,*B=ss;
inline char gc(){if(A==B){B=(A=ss)+fread(ss,1,1<<17,stdin);if(A==B)return EOF;}return*A++;}
template<class T>inline void read(T&x){
static char c;static int y;
for(c=gc(),x=0,y=1;c<48||57<c;c=gc())if(c=='-')y=-1;
for(;48<=c&&c<=57;c=gc())x=((x+(x<<2))<<1)+(c^'0');
x*=y;
}
#else
void read(ll&x){scanf("%lld",&x);}
#endif
using namespace std;
ll dft[1501],a[1501][1501],rec[1501];
ll p[1501],n,u,m,k,t;
inline ll ksm(ll a,ll b){
ll res=1;
while (b){
if (b&1) res=(res*a)%mod;
a=a*a%mod; b>>=1;
}
return res;
}
inline ll C(ll a,ll b){
ll res=1;
for (re int i=0; i<a; i++) res=(res*(b-i))%mod;
return res*dft[a]%mod;
}
inline void gauss(){
for (re int i=1; i<n; i++){
// re int bj=i;
// for (re int j=i+1; j<=n; j++)
// if (a[j][i]){bj=j; break;}
// for (re int j=i; j<=n+1; j++) swap(a[i][j],a[bj][j]);
ll inv=ksm(a[i][i],mod-2); a[i][n+1]=(a[i][n+1]*inv)%mod;
a[i][i+1]=(a[i][i+1]*inv)%mod; a[i][i]=1;
for (re int j=i+1; j<=n; j++){
a[j][n+1]=(a[j][n+1]+mod-a[i][n+1]*a[j][i]%mod)%mod;
a[j][i+1]=(a[j][i+1]+mod-a[i][i+1]*a[j][i]%mod)%mod;
a[j][i]=0;
}
// for (re int j=i; j<=n+1; j++) a[i][j]=a[i][j]*inv%mod;
// for (re int j=1; j<=n; j++){
// if (j==i) continue;
// ll r=a[j][i];
// for (re int k=i; k<=n+1; k++) a[j][k]=(a[j][k]+mod-a[i][k]*r%mod)%mod;
// a[j][i]=0;
// }
}
// for (re int i=1; i<=n; i++){
// for (re int j=1; j<=n+1; j++) printf("%lld ",a[i][j]);
// puts("");
// }
ll inv=ksm(a[n][n],mod-2); a[n][n+1]=(a[n][n+1]*inv)%mod; a[n][n]=1;
for (re int i=n-1; i>=u; i--){
a[i][n+1]=(a[i][n+1]+mod-a[i+1][n+1]*a[i][i+1]%mod)%mod;
// a[i][i+1]=a[i][n+1]*ksm(a[i][i],mod-2)%mod;
a[i][i]=1;
}
}
int main(){
read(t);
dft[0]=dft[1]=rec[1]=1;
for (re int i=2; i<=1500; i++) rec[i]=(mod-mod/i)*rec[mod%i]%mod;
// for (re int i=2; i<=1500; i++) dft[i]=dft[i-1]*rec[i]%mod;
while (t--){
read(n),read(u),read(m),read(k);
if (!k){puts("-1"); continue;}
if (!m&&(k==1)){puts("-1"); continue;}
if (!m){
re int ans=0;
while (u>0){
if (u<n) u++;
u-=k; ans++;
}
printf("%d\n",ans); continue;
}
ll ny=ksm(m+1,mod-2),nym=ksm(m,mod-2);
memset(p,0,sizeof(p));
memset(a,0,sizeof(a));
p[0]=ksm(m*ny%mod,k);
for (re int i=1; i<=min(n,k); i++)
p[i]=p[i-1]*nym%mod*rec[i]%mod*(k-i+1+mod)%mod;
// for (re int i=0; i<=n; i++)
// p[i]=(C(i,k)*ksm(ny,i)%mod)*ksm(m*ny%mod,k-i)%mod;
for (re int i=1; i<=n-1; i++){
for (re int j=0; j<=i-1; j++)
a[i][i-j]=(p[j]*m%mod+p[j+1])%mod*ny%mod;
a[i][i+1]=p[0]*ny%mod;
}
for (re int i=1; i<=n; i++) a[n][i]=p[n-i];
for (re int i=1; i<=n; i++) a[i][i]=(a[i][i]-1+mod)%mod;
for (re int i=1; i<=n; i++) a[i][n+1]=mod-1;
gauss();
printf("%lld\n",a[u][n+1]);
}
return 0;
}
文章目录
,
载入天数...载入时分秒... 字数统计:75.7k