PG BATTLE 2023 解答例一覧

作問者解答はAtCoder社による解答になります。

JavaおよびPythonによる解答例はPG BATTLE運営が用意したものが含まれますので、ご了承ください。

ましゅまろ(Marshmallow)

難易度1(Level1)
コイントス(Coin Toss)

難易度2(Level2)
微分(Derivative)

難易度3(Level3)
2進数と10進数(Binary Number and Decimal Number)

難易度5(Level5)
ダンス(Shall We Dance?)

ましゅまろ(難易度1)コイントス
Marshmallow(Level1)/ Coin Toss

Python(作問者解答)

N=int(input())
print(f"{pow(0.5,N):.10f}")

Java

import java.util.*;
import java.math.BigDecimal;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);

        int N = sc.nextInt();
        System.out.println(BigDecimal.valueOf(Math.pow(0.5, N)).toPlainString());
    }
}

ましゅまろ(難易度2)微分
Marshmallow(Level2)/ Derivative

C++(作問者解答)

#include <iostream>
#include <string>

using namespace std;
typedef long long ll;

int main() {
  string s;
  cin >> s;
  string a = "", b = "";
  int f = 0;
  for (char c : s) {
    if (c == 'x') f = 1;
    if (f != 0) f++;
    if (f == 0) a += c;
    if (f >= 4) b += c;
  }
  ll x = stoll(a), y = stoll(b);
  x *= y; y--;
  cout << x << "x^" << y << '\n';
}

Java

import java.util.*;
public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);

        char[] S = sc.next().toCharArray();
        String strA = "";
        String strB = "";
        boolean isA = true;
        for (char c : S) {
            if (c == 'x' || c == '^') {
                isA = false;
                continue;
            }

            if (isA) {
                strA += c;
            } else {
                strB += c;
            }
        }

        long A = Long.parseLong(strA);
        long B = Long.parseLong(strB);
        System.out.println(A*B + "x^" + (B-1));
    }
}

Python

from re import match

S = input()
A, B = map(int, S.split("x^"))

print(f"{A * B}x^{B - 1}")

ましゅまろ(難易度3)2進数と10進数
Marshmallow(Level3)/ Binary Number and Decimal Number

Python(作問者解答)

import decimal
decimal.getcontext().prec = 10000
a, b = input(), input()
A = decimal.Decimal(a)
B = decimal.Decimal(0)
for c in b:
    B = B * 2 + (1 if c == '1' else 0)
print('>' if A > B else '=' if A == B else '<')

Java

import java.util.*;
import java.math.BigInteger;
public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        BigInteger A = sc.nextBigInteger();
        String B = sc.next();

        int len = B.length()-1;
        BigInteger tmp = new BigInteger("1");
        BigInteger sum = new BigInteger("0");
        for(int i=len; i>=0; i--){
            if (B.charAt(i) == '1') {
                sum = sum.add(tmp);
            }
            tmp = tmp.multiply(new BigInteger("2"));
        }

        switch (A.compareTo(sum)) {
        case 0:
            System.out.println("=");
            break;
        case 1:
            System.out.println(">");
            break;
        default:
            System.out.println("<");
        }
    }
}

ましゅまろ(難易度5)ダンス
Marshmallow(Level5)/ Shall We Dance?

C(作問者解答)

#include<stdio.h>
#define MOD 998244353
#define ll long long
char s[110];
ll dp[110][110];
ll choose[110][110];
int main(){
    int n;
    scanf("%d %s",&n,s);
    
    choose[0][0]=1;
    for(int i=1;i<=n;i++){
        choose[i][0]=1;
        for(int j=1;j<=i;j++)choose[i][j]=(choose[i-1][j-1]+choose[i-1][j])%MOD;
    }
    for(int i=0;i<=n;i++)dp[i][i]=1;
    
    for(int w=2;w<=n;w+=2){
        for(int l=0,r=w;r<=n;l++,r++){
            for(int m=l+1;m<r;m+=2)if(s[l]!=s[m]){
                dp[l][r]=(dp[l][r]+dp[l+1][m]*dp[m+1][r]%MOD*choose[(r-l)/2][(r-(m+1))/2])%MOD;
            }
        }
    }
    printf("%lld\n",dp[0][n]);
}

Java

import java.util.*;
public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);

        int N = sc.nextInt();
        char[] S = sc.next().toCharArray();
        long mod = 998244353;

        long[][] binominal = new long[N + 1][N + 1];
        binominal[0][0] = 1;
        for (int i = 1; i <= N; i++) {
            binominal[i][0] = 1;
            for (int j = 1; j <= i; j++) {
                binominal[i][j] = (binominal[i - 1][j - 1] + binominal[i - 1][j]) % mod;
            }
        }

        long[][] dp = new long[N + 1][N + 1];
        for (int i = 1; i <= N; i++) {
            dp[i][i] = 1;
        }

        for (int w = 2; w <= N; w += 2) {
            for (int l = 0, r = w; r <= N; l++, r++) {
                for (int m = l + 1; m < r; m += 2) {
                    if (S[l] != S[m]) {
                        dp[l][r] = (dp[l][r] + dp[l + 1][m] * dp[m + 1][r] % mod * binominal[(r - l) / 2][(r - (m + 1)) / 2]) % mod;
                    }
                }
            }
        }

        System.out.println(dp[0][N]);
    }
}

Python

N = int(input())
S = input()
mod = 998244353

binominal = [[0] * (N + 1) for _ in range(N + 1)]
binominal[0][0] = 1
for i in range(1, N + 1):
    binominal[i][0] = 1
    for j in range(1, i + 1):
        binominal[i][j] = (binominal[i - 1][j - 1] + binominal[i - 1][j]) % mod

dp = [[0] * (N + 1) for _ in range(N + 1)]
for i in range(N + 1):
    dp[i][i] = 1

for w in range(2, N + 1, 2):
    for r in range(w, N + 1):
        l = r - w
        for m in range(l + 1, r, 2):
            if S[l] != S[m]:
                dp[l][r] = (dp[l][r] + dp[l + 1][m] * dp[m + 1][r] * binominal[(r - l) // 2][(r - (m + 1)) // 2]) % mod

print(dp[0][N])

せんべい(Senbei)

難易度2(Level2)
積の符号(Sign)

難易度3(Level3)
ABCの個数(ABC counting)

難易度4(Level4)
距離 K(K swap)

難易度6(Level6)
トリオ(Trio)

せんべい(難易度2)積の符号
Senbei(Level2)/ Sign

C++(作問者解答)

#include <iostream>
using namespace std;
int main() {
  int N;
  cin >> N;
  int sgn = 0;
  for (int i = 0; i < N; i++) {
    int x;
    cin >> x;
    if (x == 0) {
      cout << "0\n";
      exit(0);
    }
    if (x < 0) sgn ^= 1;
  }
  cout << (sgn ? "-" : "+") << "\n";
}

Java

import java.util.*;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);

        int N = sc.nextInt();
        int[] A = new int[N];
        for (int i = 0; i < N; i++) {
            A[i] = sc.nextInt();
        }

        int p = 1;
        for (int a : A) {
            if (a == 0) {
                System.out.println(0);
                return;
            }
            p *= a / Math.abs(a);
        }

        System.out.println(p > 0 ? "+" : "-");
    }
}

Python

N = int(input())
A = list(map(int, input().split()))

P = 1
for a in A:
    if a == 0:
        print(0)
        exit()
    P *= a // abs(a)

if P > 0:
    print("+")
else:
    print("-")

せんべい(難易度3)ABCの個数
Senbei(Level3)/ ABC counting

Python(作問者解答)

S=input()

ans=0
for i in range(len(S)-2):
  q_cnt=0
  flag=True
  for c,d in zip(S[i:i+3],"ABC"):
    if c=="?":
      q_cnt+=1
    elif c!=d:
      flag=False
  if flag:
    ans+=3**(3-q_cnt)
print(f"{ans/27:.15f}")

Java

import java.util.*;
public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);

        char[] S = sc.next().toCharArray();
        char[] abc = {'A', 'B', 'C'};
        double ans = 0;
        for (int i = 0; i < S.length - 2; i++) {
            int cnt = 0;
            boolean flg = true;
            for (int j = 0; j < 3; j++) {
                if (S[i + j] == '?') {
                    cnt++;
                    continue;
                }

                if (S[i + j] == abc[j]) {
                    continue;
                }

                flg = false;
                break;
            }

            if (flg) {
                ans += Math.pow(3, -cnt);
            }
        }

        System.out.println(ans);
    }
}

せんべい(難易度4)距離 K
Senbei(Level4)/ K swap

C++(作問者解答)

#include <iostream>
#include <map>
#include <vector>
using namespace std;
#include "atcoder/modint.hpp"
using mint = atcoder::modint998244353;
struct Binomial {
  vector<mint> fac, invfac, inv;
  Binomial(int n) : fac(n + 1), invfac(n + 1), inv(n + 1) {
    fac[0] = invfac[0] = inv[0] = 1;
    for (int i = 1; i <= n; i++) fac[i] = fac[i - 1] * i;
    invfac[n] = fac[n].inv();
    for (int i = n - 1; i >= 0; i--) {
      invfac[i] = invfac[i + 1] * (i + 1);
      inv[i + 1] = invfac[i + 1] * fac[i];
    }
  }
} C{303030};
int main() {
  int N, K;
  cin >> N >> K;
  vector<int> A(N);
  for (auto& x : A) cin >> x;
  mint ans = 1;
  for (int i = 0; i < N; i++) {
    if (i == K) break;
    map<int, int> mp;
    int all = 0;
    for (int j = i; j < N; j += K) {
      mp[A[j]]++, all++;
    }
    ans *= C.fac[all];
    for (auto& [_, val] : mp) ans *= C.invfac[val];
  }
  cout << ans.val() << "\n";
}

Java

import java.util.*;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);

        int N = sc.nextInt();
        int K = sc.nextInt();
        int[] A = new int[N];
        for (int i = 0; i < N; i++) {
            A[i] = sc.nextInt();
        }

        if (K == 0) {
            System.out.println(1);
            return;
        }

        long mod = 998244353;
        long[] factorials = new long[N + 1];
        factorials[0] = 1;
        for (int i = 1; i <= N; i++) {
            factorials[i] = factorials[i - 1] * i;
            factorials[i] %= mod;
        }

        boolean[] checked = new boolean[N];
        long ans = 1;
        for (int i = 0; i < N; i++) {
            if (checked[i]) continue;

            int totalCnt = 0;
            HashMap<Integer, Integer> counter = new HashMap<Integer, Integer>();
            for (int j = i; j < N; j += K) {
                totalCnt++;
                counter.put(A[j], counter.getOrDefault(A[j], 0) + 1);
                checked[j] = true;
            }

            long prod = 1;
            for (int cnt : counter.values()) {
                prod *= factorials[cnt];
                prod %= mod;
            };

            ans *= factorials[totalCnt] * modPow(prod, mod - 2, mod) % mod;
            ans %= mod;
        }

        System.out.println(ans);
    }

    static long modPow(long x, long p, long mod) {
        long ret = 1;
        while (0 < p) {
            if ((p & 1) == 1) {
                ret = ret * x % mod;
            }
            x = x * x % mod;
            p >>= 1;
        }

        return ret;
    }
}

Python

from collections import Counter

N, K = list(map(int, input().split()))
A = list(map(int, input().split()))

if K == 0:
    print(1)
    exit()

mod = 998244353
factorials = [0] * (N + 1)
factorials[0] = 1
for i in range(1, N + 1):
    factorials[i] = factorials[i - 1] * i
    factorials[i] %= mod

checked = [0] * N
ans = 1
for i in range(N):
    if checked[i]:
        continue

    total_cnt = 0
    counter = Counter()
    for j in range(i, N, K):
        total_cnt += 1
        counter[A[j]] += 1
        checked[j] = 1

    prod = 1
    for cnt in counter.values():
        prod *= factorials[cnt]
        prod %= mod
    ans *= factorials[total_cnt] * pow(prod, mod - 2, mod) % mod
    ans %= mod

print(ans)

せんべい(難易度6)トリオ
Senbei(Level6)/ Trio

C++(作問者解答)

※2023/10/31 修正。詳細はこちら

#include <bits/stdc++.h>
#include <atcoder/modint>
using namespace std;
using mint = atcoder::modint;
int main() {
    int N, M;
    cin >> N >> M;
    mint::set_mod(M);
    vector dp(N + 1, vector(N + 1, mint(0)));
    dp[N][0] = 1;
    array<array<mint, 4>, 4> pre = {};
    for (int i = 0; i <= 3; ++i) {
        for (int j = 0; j <= 3; ++j) {
            int k = 3 - i - j;
            int c = 6;
            for (int t = 1; t <= i; ++t) c /= t;
            for (int t = 1; t <= j; ++t) c /= t;
            for (int t = 1; t <= k; ++t) c /= t;
            pre[i][j] = c;
        }
    }
    for (int row = 0; row < N; ++row) {
        vector ndp(N + 1, vector(N + 1, mint(0)));
        for (int zero = 0; zero <= N; ++zero) {
            for (int three = 0; three <= N; ++three) {
                int one = 2 * (N - zero - three) - 3 * (row - three);
                int two = N - zero - one - three;
                if (one < 0 or two < 0 or dp[zero][three].val() == 0) {
                    continue;
                }
                for (int i = 0; i <= 3; ++i) {
                    for (int j = 0; j <= 3 - i; ++j) {
                        int k = 3 - i - j;
                        if (zero < i or one < j or two < k) {
                            continue;
                        }
                        mint coeff = pre[i][j];
                        for (int t = 0; t < j; ++t) {
                            coeff *= one - t;
                        }
                        for (int t = 0; t < k; ++t) {
                            coeff *= two - t;
                        }
                        ndp[zero - i][three + k] += dp[zero][three] * coeff;
                    }
                }
            }
        }
        dp = move(ndp);
    }
    cout << dp[0][N].val() << '\n';
    return 0;
}

かつおぶし(Katsuobushi)

難易度3(Level3)
3次方程式(Cubic Equation)

難易度4(Level4)
完全二分木の切断(Cut Perfect Binary Tree)

難易度5(Level5)
部分列(Subsequence Number)

難易度6(Level6)
二項数列(Binominal Sequence)

かつおぶし(難易度3)3次方程式
Katsuobushi(Level3)/ Cubic Equation

C++(作問者解答)

#include <iomanip>
#include <iostream>
using namespace std;
int main() {
  double A, B, C, D, L, R;
  cin >> A >> B >> C >> D >> L >> R;
  auto f = [&](double x) { return A * x * x * x + B * x * x + C * x + D; };
  double pos = L, neg = R;
  if (f(pos) < f(neg)) swap(pos, neg);
  for (int _ = 0; _ < 100; _++) {
    double m = (pos + neg) / 2;
    (f(m) > 0 ? pos : neg) = m;
  }
  double ans = (pos + neg) / 2;
  cout << fixed << setprecision(15) << ans << endl;
}

Java

import java.util.*;

public class Main {
    static int A, B, C, D;

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);

        A = sc.nextInt();
        B = sc.nextInt();
        C = sc.nextInt();
        D = sc.nextInt();
        int L = sc.nextInt();
        int R = sc.nextInt();

        System.out.println(f(L) <= f(R) ? bisect(L, R) : bisect(R, L));
    }

    static double bisect(double nega, double posi) {
        double th = Math.pow(0.1, 9);
        while (Math.abs(posi - nega) > th) {
            double mid = (posi + nega) / 2;
            if (f(mid) > 0) {
                posi = mid;
            } else {
                nega = mid;
            }
        }
        return posi;
    }

    static double f(double x) {
        return A * Math.pow(x, 3) + B * Math.pow(x, 2) + C * x + D;
    }
}

Python

(実行環境をPyPyにして提出する必要があります)

A, B, C, D, L, R = list(map(int, input().split()))

def f(x):
    return A * (x**3) + B * (x**2) + C * x + D

def bisect(nega, posi):
    th = 10**9
    while abs(posi - nega) > 1 / th:
        mid = (posi + nega) / 2
        if f(mid) > 0:
            posi = mid
        else:
            nega = mid
    return posi

if f(L) <= f(R):
    print(bisect(L, R))
else:
    print(bisect(R, L))

かつおぶし(難易度4)完全二分木の切断
Katsuobushi(Level4)/ Cut Perfect Binary Tree

C++(作問者解答)

#include<bits/stdc++.h>
#include<atcoder/modint>
using namespace std;
using namespace atcoder;
using mint = modint998244353;
int main() {
    int n;
    string x;
    cin >> n >> x;
    
    int pos = -1;
    for (int i = 0; i < n - 1; i++) {
        if (x[i] != x[i + 1]) {
            if (pos != -1) {
                cout << 0 << endl;
                return 0;
            }
            pos = i;
        }
    }
    assert(pos != -1);
    
    cout << mint(2).pow(pos + 1).val() << endl;
}

Java

import java.util.*;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);

        int N = sc.nextInt();
        char[] X = sc.next().toCharArray();
        long mod = 998244353;

        int pos = -1;
        for (int i = 0; i < N - 1; i++) {
            if (X[i] != X[i + 1]) {
                if (pos != -1) {
                    System.out.println(0);
                    return;
                }
                pos = i;
            }
        }

        System.out.println(modPow(2, pos + 1, mod));
    }

    static long modPow(long x, long p, long mod) {
        long ret = 1;
        while (0 < p) {
            if ((p & 1) == 1) {
                ret = ret * x % mod;
            }
            x = x * x % mod;
            p >>= 1;
        }

        return ret;
    }
}

Python

N = int(input())
X = input()
mod = 998244353

if X[0] == "1":
    X = X.translate(str.maketrans({"0": "1", "1": "0"}))

zero_cnt = X.count("0")
if "0" * zero_cnt + "1" * (N - zero_cnt) == X:
    print(pow(2, zero_cnt, mod))
else:
    print(0)

かつおぶし(難易度5)部分列
Katsuobushi(Level5)/ Subsequence Number

C++(作問者解答)

#include <iostream>
#include <string>
#include <vector>
#include <algorithm>
#define rep(x, s, t) for(ll x = (s); (x) <= (t); (x)++)
using namespace std;
typedef long long ll;
ll mod = 998244353;
ll n, m;
string s;
ll dp[100005][105];
int main(void)
{
  ios::sync_with_stdio(0);
  cin.tie(0);
  cin >> n >> m;
  cin >> s;
  dp[0][0] = 1;
  rep(i, 0, n-1) rep(j, 0, m-1){
    dp[i+1][j] += dp[i][j], dp[i+1][j] %= mod;
    dp[i+1][(j*10+s[i]-'0')%m] += dp[i][j], dp[i+1][(j*10+s[i]-'0')%m] %= mod;
  }
  ll ans = dp[n][0];
  ans += mod - 1, ans %= mod;
  cout << ans << endl;
  return 0;
}

Java

import java.util.*;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);

        int N = sc.nextInt();
        int M = sc.nextInt();
        char[] S = sc.next().toCharArray();
        long mod = 998244353;

        long[][] dp = new long[N + 1][M];
        dp[0][0] = 1;
        for (int i = 0; i < N; i++) {
            int s = Character.getNumericValue(S[i]);
            for (int j = 0; j < M; j++) {
                dp[i + 1][j] += dp[i][j];
                dp[i + 1][j] %= mod;

                int k = (j * 10 + s) % M;
                dp[i + 1][k] += dp[i][j];
                dp[i + 1][k] %= mod;
            }
        }

        System.out.println((dp[N][0] - 1 + mod) % mod);
    }
}

Python

(実行環境をPyPyにして提出する必要があります)

N, M = list(map(int, input().split()))
S = input()
mod = 998244353

dp = [[0] * M for _ in range(N + 1)]
dp[0][0] = 1
for i in range(N):
    s = int(S[i])
    for j in range(M):
        dp[i + 1][j] += dp[i][j]
        dp[i + 1][j] %= mod

        k = (j * 10 + s) % M
        dp[i + 1][k] += dp[i][j]
        dp[i + 1][k] %= mod

print((dp[N][0] - 1) % mod)

かつおぶし(難易度6)二項数列
Katsuobushi(Level6)/ Binominal Sequence

C++(作問者解答)

#include <bits/stdc++.h>
using namespace std;
#define N 500010
#define ll long long
#define MOD (ll)998244353
#define rep(i, n) for(ll i = 0; i < n; ++i)
#define eb emplace_back
#define pb push_back
#define all(c) (c).begin(), (c).end()
#define vi vector<int>
#define pii pair<int,int>
#define pll pair<ll,ll>
ll k[N];
ll r[N];
ll r2[N];
bool b[30];
int main(void) {
    ll n, m;
    ll ans, s, p;
    ll a, x;
    x = MOD - 2;
    rep(i, 30) {
        if (x % 2 == 1)b[i] = true;
        else b[i] = false;
        x /= 2;
    }
    k[0] = (ll)1;
    r[0] = (ll)1;
    rep(i, N - 1) {
        k[i + 1] = (k[i] * (i + 1)) % MOD;
        r[i + 1] = (ll)1;
        x = k[i + 1];
        rep(j, 30) {
            if (b[j]) {
                r[i + 1] = (r[i + 1] * x) % MOD;
            }
            x = (x*x) % MOD;
        }
    }
    r2[0] = 1;
    r2[1] = (MOD + 1) / 2;
    rep(i, N - 2)r2[i + 2] = (r2[i + 1] * r2[1]) % MOD;
    cin >> n >> m;
    m++;
    ans = 0;
    p = 1;
    s = 0;
    rep(i, n) {
        cin >> a;
        m--;
        if (m <= s) {
            x = (r[m] * r[s - m]) % MOD;
            x = (x*k[s]) % MOD;
            x = (x*r2[s]) % MOD;
            p = (p + MOD - x) % MOD;
        }
        //cout << s << " " << m << " " << p << endl;
        rep(j, a) {
            s++;
            if (m <= s) {
                x = (r[m - 1] * r[s - m]) % MOD;
                x = (x*k[s - 1]) % MOD;
                x = (x*r2[s]) % MOD;
                p = (p + MOD - x) % MOD;
            }
            //cout << s << " " << m << " " << p << endl;
        }
        ans = (ans + p) % MOD;
        //cout << ans << endl;
    }
    cout << ans << endl;
    return 0;
}

Java

import java.util.*;
import java.util.function.BiFunction;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);

        int N = sc.nextInt();
        int M = sc.nextInt();

        int[] A = new int[N];
        for (int i = 0; i < N; i++) {
            A[i] = sc.nextInt();
        }
        long mod = 998244353;
        int ma = 5 * (int)Math.pow(10, 5) + 1;

        long[] fac = new long[ma];
        fac[0] = 1;
        long[] invFac = new long[ma];
        invFac[0] = 1;
        for (int i = 1; i < ma; i++) {
            fac[i] = fac[i - 1] * i % mod;
            invFac[i] = invFac[i - 1] * modPow(i, mod - 2, mod) % mod;
        }

        long[] inv2Ai = new long[ma];
        inv2Ai[0] = 1;
        inv2Ai[1] = (mod + 1) / 2;
        for (int i = 0; i < ma - 2; i++) {
            inv2Ai[i + 2] = (inv2Ai[i + 1] * inv2Ai[1]) % mod;
        }

        BiFunction<Integer, Integer, Long> nCr = (n, r) -> {
            long x = fac[n];
            x *= invFac[r] * invFac[n - r] % mod;
            return x % mod;
        };

        long p = 1;
        int s = 0;
        int m = M;
        long ans = 0;
        for (int i = 0; i < N; i++) {
            long a = A[i];
            if (m <= s) {
                long x = (nCr.apply(s, m) * inv2Ai[s]) % mod;
                p = (p + mod - x) % mod;
            }

            for (int j = 0; j < a; j++) {
                s += 1;
                if (m <= s) {
                    long x = (nCr.apply(s - 1, m - 1) * inv2Ai[s]) % mod;
                    p = (p + mod - x) % mod;
                }
            }

            ans = (ans + p) % mod;
            m -= 1;
        }

        System.out.println(ans);
    }

    static long modPow(long x, long p, long mod) {
        long ret = 1;
        while (0 < p) {
            if ((p & 1) == 1) {
                ret = ret * x % mod;
            }
            x = x * x % mod;
            p >>= 1;
        }

        return ret;
    }
}

Python

(実行環境をPyPyにして提出する必要があります)

N, M = list(map(int, input().split()))
A = list(map(int, input().split()))
mod = 998244353
ma = 5 * 10**5 + 1

fac = [0] * ma
fac[0] = 1
inv_fac = [0] * ma
inv_fac[0] = 1
for i in range(1, ma):
    fac[i] = fac[i - 1] * i % mod
    inv_fac[i] = inv_fac[i - 1] * pow(i, -1, mod) % mod

inv_2_ai = [0] * ma
inv_2_ai[0] = 1
inv_2_ai[1] = (mod + 1) // 2
for i in range(ma - 2):
    inv_2_ai[i + 2] = (inv_2_ai[i + 1] * inv_2_ai[1]) % mod
# same as: inv_2_ai = [pow(pow(2, i, mod), mod - 2, mod) for i in range(ma + 1)]

def nCr(n, r):
    # n! / (r!(n-r)!)
    x = fac[n]
    x *= inv_fac[r] * inv_fac[n - r] % mod
    return x % mod

p = 1
s = 0
m = M
ans = 0
for i in range(N):
    a = A[i]
    if m <= s:
        x = (nCr(s, m) * inv_2_ai[s]) % mod
        p = (p + mod - x) % mod
    for j in range(a):
        s += 1
        if m <= s:
            x = (nCr(s - 1, m - 1) * inv_2_ai[s]) % mod
            p = (p + mod - x) % mod

    ans = (ans + p) % mod
    m -= 1

print(ans)