CODE FESTIVAL 2017 qual A,E: Modern Painting

http://code-festival-2017-quala.contest.atcoder.jp/tasks/code_festival_2017_quala_e

まず最初に縦の向きに人を動かすとします。すると領域は二つに分断されます。そのうち一つについて横に進む人がX人、上から下に進む人がY人、下から上に進む人がZ人だとします。
するとその領域の塗り方はC(X+Y+Z-1,X-1)になります。なんでかは次の図を見てもらえればわかるかな…。

f:id:omochangram:20171007035342j:plain

このように一部反転させて、横に進む人が通過した領域を囲うようにして移動することを考えると、座標(0,0)から座標(X,Y+Z)までの、(Y座標)=Yの時には必ず一回は右に行くような移動方法の個数が領域の塗り方の個数と一致します。Y座標に初めてついたときのX座標をxとして[x,x+1)の区間を切り取れば、結局座標(X-1,Y+Z)へ移動する方法の個数を求めればいいことになり、C(X+Y+Z-1,X-1)となります。

これが求められたら後は累積和チックに求めれば計算量はO(N+M)となり十分高速です。

コンビネーションの複雑な式が出そうになったら、意味を考えて経路の問題に帰着させることを考えるといいかも?

上から下に進む人と右から左に進む人しかいなかったら簡単にできるなぁとは思いましたが、一回縦に分断した領域についてこんな簡単な式になるとは思いませんでした。

ll fac[MAX_N], inv[MAX_N], fiv[MAX_N]; //fiv:inv(fac(i))
ll pow2[MAX_N];

void C_init(int n) {
	fac[0] = fac[1] = 1; inv[1] = 1;
	fiv[0] = fiv[1] = 1;
	pow2[0] = 1; pow2[1] = 2;
	rep(i, 2, n + 1) {
		fac[i] = fac[i - 1] * i % mod;
		inv[i] = mod - inv[mod % i] * (mod / i) % mod;
		fiv[i] = fiv[i - 1] * inv[i] % mod;
		pow2[i] = pow2[i - 1] * 2 % mod;
	}
}

ll getC(int a, int b) { //assume a >= b
	if(a < b || a < 0 || b < 0) return 0;
	return fac[a] * fiv[b] % mod * fiv[a - b] % mod;
}

int N, M;
string A, B, C, D;
ll bitcnt[MAX_N];
ll sum[MAX_N];

ll solve2(int y, int z, int x) {
	if(x == 0) {
		if(y == 0 && z == 0) return 1;
		else return 0;
	}
	else return getC(x + y + z - 1, x - 1);
}

ll solve_sub() {
	int ccnt = accumulate(all(C), 0) - '0' * M;
	int dcnt = accumulate(all(D), 0) - '0' * M;
	int acnt = 0, bcnt = 0;
	bitcnt[0] = 0;
	rep(i, 0, N) bitcnt[i + 1] = bitcnt[i] + ((A[i] == '1' && B[i] == '1') ? 1 : 0);
	sum[N] = 0;
	rer(i, N, 0) {
		if(A[i] == '0' && B[i] == '0') sum[i] = sum[i + 1];
		else {
			sum[i] = sum[i + 1] + solve2(acnt, bcnt, dcnt) * pow2[bitcnt[i + 1]] % mod;
			sum[i] %= mod;
		}
		if(A[i] == '1') acnt++;
		if(B[i] == '1') bcnt++;
	}
	acnt = 0; bcnt = 0;
	acnt = 0; bcnt = 0;
	ll res = 0;
	ll mul = 1;
	rep(i, 0, N) {
		if(A[i] == '0' && B[i] == '0') continue;
		ll a = solve2(acnt, bcnt, ccnt);
		ADD(res, a * mul % mod * sum[i] % mod);
		if(A[i] == '1') acnt++;
		if(B[i] == '1') bcnt++;
		if(A[i] == '1' && B[i] == '1') MUL(mul, inv[2]);
	}
	return res;
}



void solve() {
	C_init(600010);
	cin >> N >> M;
	cin >> A >> B >> C >> D;
	int a = accumulate(all(A), 0) - '0' * N;
	int b = accumulate(all(B), 0) - '0' * N;
	int c = accumulate(all(C), 0) - '0' * M;
	int d = accumulate(all(D), 0) - '0' * M;
	if(a == 0 && b == 0 && c == 0 && d == 0) {
		cout << 1 << "\n";
		return;
	}
	ll ans = 0;
	if(a != 0 || b != 0) ADD(ans, solve_sub());
	if(c != 0 || d != 0) {
		swap(N, M);
		swap(A, C);
		swap(B, D);
		ADD(ans, solve_sub());
	}
	cout << ans << "\n";
}

100post目だやったぁ。