PG BATTLE 2021 解答例一覧

各問題につきC++、Java、Pythonの解答例を用意しております。
AtCoder社による解答はC++のもので、JavaとPythonの解答例はPG BATTLE運営が用意したものとなりますので、予めご了承ください。

ましゅまろ

難易度1
物理現象グラフィックバトル

難易度3
ゼロのない整数

難易度3
一点封鎖

難易度5
無双関数

ましゅまろ(難易度1)物理現象グラフィックバトル

C++

#include<bits/stdc++.h>
using namespace std;
// Reads X, Y and K, then prints Y - K / X with 12 digits after the decimal point.
int main() {
  double divisor = 0, base = 0, numerator = 0;
  cin >> divisor >> base >> numerator;
  const double answer = base - numerator / divisor;
  printf("%.12lf\n", answer);
  return 0;
}

Java

import java.util.*;
 
public class Main {
    /** Reads X, Y, K as doubles and prints Y - K / X. */
    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);
        double x = in.nextDouble();
        double y = in.nextDouble();
        double k = in.nextDouble();

        double answer = y - k / x;
        System.out.println(answer);
    }
}

Python

# Read X, Y, K and print Y - K / X (true division, so the result is a float).
x, y, k = (int(tok) for tok in input().split())
print(y - k / x)

ましゅまろ(難易度3)ゼロのない整数

C++

#include<bits/stdc++.h>
using namespace std;
// Prints X with its first '0' and every digit after it replaced by '1'.
// X is printed unchanged when it contains no '0'.
int main(){
  string x;
  cin >> x;
  const size_t firstZero = x.find('0');
  if (firstZero != string::npos) {
    fill(x.begin() + firstZero, x.end(), '1');
  }
  cout << x << '\n';
}

Java

import java.util.*;

public class Main {
    /** Prints X with the first '0' and everything after it replaced by '1'. */
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);

        String x = sc.next();
        int zeroAt = x.indexOf('0');
        StringBuilder out = new StringBuilder(x.length());
        if (zeroAt < 0) {
            out.append(x);
        } else {
            out.append(x, 0, zeroAt);
            for (int i = zeroAt; i < x.length(); i++) {
                out.append('1');
            }
        }

        System.out.println(out);
    }
}

Python

# Replace the first "0" and every character after it with "1".
X = input()
head, zero, _ = X.partition("0")
print(head + "1" * (len(X) - len(head)) if zero else X)

ましゅまろ(難易度3)一点封鎖

C++

#include <atcoder/modint>
#include <bits/stdc++.h>
#define mod 998244353
using namespace std;
using namespace atcoder;
using mint = modint;
// Binomial coefficient C(n, r) as a mint; returns 0 when n < r.
// Built incrementally via C(n, k) = C(n, k-1) * (n - k + 1) / k.
mint nCr(int n,int r){
  if (r > n) return 0;
  mint acc = 1;
  for (int k = 1; k <= r; k++) {
    acc *= (n - k + 1);
    acc /= k;
  }
  return acc;
}
int main(){
  int n,m,a,b;
  cin >> n >> m >> a >> b;
  int ans1,ans2;
  //Solution 1
  vector<vector<int>> dp(n+1,vector<int>(m+1,0));
  dp[0][0]=1;
  for(int i=0;i<=n;i++){
    for(int j=0;j<=m;j++){
      if(i==a && j==b){dp[i][j]=0;continue;}
      if(i!=0){dp[i][j]+=dp[i-1][j];}
      if(j!=0){dp[i][j]+=dp[i][j-1];}
      dp[i][j]%=mod;
    }
  }
  ans1=dp[n][m];
  
  //Solution 2
  if(a<=n && b<=m){
    mint res=nCr(n+m,n),sub=nCr(a+b,a);
    sub*=nCr((n-a)+(m-b),(n-a));
    res-=sub;
    ans2=res.val();
  }
  else{ans2=nCr(n+m,n).val();}
  
  assert(ans1==ans2);
  cout << ans1 << '\n';
}

Java

import java.util.*;

public class Main {
    /**
     * Counts monotone lattice paths from (0, 0) to (N, M) that avoid the point
     * (A, B), modulo 998244353: C(N+M, N) minus the number of paths through
     * (A, B) when that point lies inside the grid.
     */
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);

        int N = sc.nextInt();
        int M = sc.nextInt();
        int A = sc.nextInt();
        int B = sc.nextInt();
        long mod = 998244353;

        // Every factorial index used below is at most N + M, so size the table
        // from the input; a fixed 4001-entry table fails whenever N + M > 4000.
        long[] factorials = new long[N + M + 1];
        factorials[0] = 1;
        for (int i = 1; i <= N + M; i++) {
            factorials[i] = factorials[i-1] * i % mod;
        }

        long ans = binom(factorials, N + M, N, mod);
        if (A <= N && B <= M) {
            // Paths through (A, B): reach (A, B), then continue to (N, M).
            long ab1 = binom(factorials, A + B, A, mod);
            long ab2 = binom(factorials, (N - A) + (M - B), N - A, mod);
            ans = ((ans - ab1 * ab2 % mod) % mod + mod) % mod;
        }
        System.out.println(ans);
    }

    /** C(n, k) mod p from precomputed factorials; inverses via Fermat's little theorem (p prime). */
    static long binom(long[] fact, int n, int k, long mod) {
        long invK = modPow(fact[k], mod - 2, mod);
        long invRest = modPow(fact[n - k], mod - 2, mod);
        return fact[n] * invK % mod * invRest % mod;
    }

    /** x^n modulo mod by repeated squaring (intended for n >= 0). */
    static long modPow(long x, long n, long mod) {
        long ret = 1;
        // Repeated squaring
        while (0 < n) {
            if ((n & 1) == 1) {
                ret = ret * x % mod;
            }
            x = x * x % mod;
            n >>= 1;
        }
        return ret;
    }
}

Python

# Lattice paths (0,0) -> (N,M) avoiding (A,B), modulo 998244353.
N, M, A, B = map(int, input().split())
mod = 998244353

fact = [1] * (N + M + 1)
for i in range(1, N + M + 1):
    fact[i] = fact[i - 1] * i % mod


def binom(n, k):
    """C(n, k) modulo `mod`, using Fermat inverses of the factorials."""
    return fact[n] * pow(fact[k], mod - 2, mod) * pow(fact[n - k], mod - 2, mod) % mod


ans = binom(N + M, N)
if A <= N and B <= M:
    # Remove the paths that pass through the blocked point.
    ans = (ans - binom(A + B, A) * binom(N + M - A - B, N - A)) % mod

print(ans)

ましゅまろ(難易度5)無双関数

C++

#include<bits/stdc++.h>
#define mod 998244353
using namespace std;
// Reads N, A[0..N-1] and S[0..N-1]. For each right end i, finds the smallest
// left end p such that a[p..i] has no repeated value, and adds
// S[0] + ... + S[i-p] to the answer, modulo 998244353.
int main(){
  int n;
  cin >> n;
  vector<int> a(n), pre(n);
  for (int i = 0; i < n; i++) cin >> a[i];
  for (int i = 0; i < n; i++) cin >> pre[i];
  // Turn pre into prefix sums of S (reduced modulo `mod` from index 1 on).
  for (int i = 1; i < n; i++) pre[i] = (pre[i] + pre[i - 1]) % mod;

  long long total = 0;
  set<int> window;  // distinct values of the current window a[left..i]
  int left = 0;
  for (int i = 0; i < n; i++) {
    // Shrink from the left until a[i] is no longer a duplicate.
    while (window.count(a[i])) window.erase(a[left++]);
    window.insert(a[i]);
    total = (total + pre[i - left]) % mod;
  }
  cout << total << '\n';
  return 0;
}

Java

import java.util.*;

public class Main {
    /**
     * For each right end r, finds the leftmost l such that A[l..r] has no
     * repeated value, and adds S[0] + ... + S[r-l] (mod 998244353).
     */
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);

        int N = sc.nextInt();
        int[] A = new int[N];
        for (int i = 0; i < N; i++) A[i] = sc.nextInt();
        int[] S = new int[N];
        for (int i = 0; i < N; i++) S[i] = sc.nextInt();
        final int mod = 998244353;

        // prefix[k] = S[0] + ... + S[k] (reduced modulo mod from index 1 on).
        long[] prefix = new long[N];
        prefix[0] = S[0];
        for (int i = 1; i < N; i++) {
            prefix[i] = (prefix[i - 1] + S[i]) % mod;
        }

        // For each value: one past the index of its latest occurrence.
        Map<Integer, Integer> nextAfter = new HashMap<>();
        long ans = 0;
        int left = 0;
        for (int r = 0; r < N; r++) {
            Integer previous = nextAfter.put(A[r], r + 1);
            if (previous != null) {
                left = Math.max(left, previous);
            }
            ans = (ans + prefix[r - left]) % mod;
        }

        System.out.println(ans);
    }
}

Python

from itertools import accumulate

# Sliding window of distinct values; each right end contributes a prefix sum of S.
N = int(input())
A = list(map(int, input().split()))
S = list(map(int, input().split()))
mod = 998244353

prefix_sums = list(accumulate(S))
last_end = {}  # value -> index just past its latest occurrence
ans = 0
left = 0
for right in range(N):
    value = A[right]
    left = max(left, last_end.get(value, 0))
    last_end[value] = right + 1
    ans = (ans + prefix_sums[right - left]) % mod

print(ans)

せんべい

難易度2
7番勝負

難易度3
連結成分数の見積もり

難易度4
トーナメント表

難易度6
[リ[[スー]バ][ズパ]ル]

せんべい(難易度2)7番勝負

C++

#include<bits/stdc++.h>
using namespace std;
// Reads a win percentage p and prints (as a percentage) the probability of
// winning at least 4 of 7 independent games, by enumerating all 2^7 outcomes.
int main(){
  int p;
  cin >> p;
  const double win = p / 100.0;
  const double lose = (100 - p) / 100.0;
  double total = 0;
  for (int mask = 0; mask < 128; mask++) {
    const int wins = __builtin_popcount(mask);
    if (wins < 4) continue;
    total += pow(win, wins) * pow(lose, 7 - wins);
  }
  printf("%.12lf\n", 100.0 * total);
  return 0;
}

Java

import java.util.*;

public class Main {
    /** Prints 100 * P(at least 4 wins out of 7) for a per-game win rate P/100. */
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);

        double p = sc.nextDouble() / 100;
        double ans = 0;
        for (int wins = 4; wins <= 7; wins++) {
            ans += term(wins, p);
        }
        System.out.println(ans * 100);
    }

    /** C(7, wins) * p^wins * (1-p)^(7-wins), built up one factor at a time. */
    private static double term(int wins, double p) {
        double t = 1;
        int j = 0;
        while (j < wins) {
            t *= 7 - j;
            t /= j + 1;
            t *= p;
            j++;
        }
        for (int losses = 7 - wins; losses > 0; losses--) {
            t *= 1 - p;
        }
        return t;
    }
}

Python

# Probability (in percent) of winning at least 4 of 7 games with win rate P.
P = int(input()) / 100


def term(wins):
    """C(7, wins) * P**wins * (1 - P)**(7 - wins), accumulated factor by factor."""
    t = 1
    for j in range(wins):
        t *= 7 - j
        t /= j + 1
        t *= P
    for _ in range(7 - wins):
        t *= 1 - P
    return t


ans = 0
for wins in range(4, 8):
    ans += term(wins)

print(ans * 100)

せんべい(難易度3)連結成分数の見積もり

C++

#include<bits/stdc++.h>
using namespace std;
// For each test case (N, M), prints max(1, N - M) and N + 1 - k, where k is the
// smallest integer in [1, 1.5e9] with k(k-1)/2 >= M. These are presumably the
// minimum and maximum number of connected components of a graph with N vertices
// and M edges.
int main(){
  int t;
  cin >> t;
  for (; t > 0; t--) {
    long long n, m;
    cin >> n >> m;
    cout << max(1ll, n - m) << ' ';
    // Binary search for the first k with k*(k-1)/2 >= m.
    long long lo = 1, hi = 1500000000;
    while (lo <= hi) {
      const long long mid = (lo + hi) / 2;
      if (mid * (mid - 1) / 2 < m) {
        lo = mid + 1;
      } else {
        hi = mid - 1;
      }
    }
    cout << n + 1 - lo << '\n';
  }
}

Java

import java.util.*;

public class Main {
    /** For each (N, M) prints max(1, N - M) and N - k + 1, with k from minVertices(M). */
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);

        int T = sc.nextInt();
        while (T-- > 0) {
            int N = sc.nextInt();
            long M = sc.nextLong();

            long min = Math.max(1, N - M);
            long max = N - minVertices(M) + 1;

            System.out.println(min + " " + max);
        }
    }

    /** Smallest k in [1, 1000000001] with k*(k-1)/2 >= m (binary search). */
    private static long minVertices(long m) {
        long lo = 1;
        long hi = 1000000001;
        while (lo <= hi) {
            long mid = (lo + hi) / 2;
            if (mid * (mid - 1) / 2 < m) {
                lo = mid + 1;
            } else {
                hi = mid - 1;
            }
        }
        return lo;
    }
}

Python

def min_vertices(m):
    """Smallest k in [1, 10**9 + 1] with k*(k-1)//2 >= m (binary search)."""
    lo, hi = 1, 10**9 + 1
    while lo <= hi:
        mid = (lo + hi) // 2
        if mid * (mid - 1) // 2 < m:
            lo = mid + 1
        else:
            hi = mid - 1
    return lo


T = int(input())
for _ in range(T):
    N, M = map(int, input().split())
    print(max(1, N - M), N - min_vertices(M) + 1)

せんべい(難易度4)トーナメント表

C++

#include <algorithm>
#include <iostream>
#include <vector>
using namespace std;
// Reads N and the values of the 2^N players. Prints an arrangement of player
// numbers (1-indexed) in which position i (1-indexed) holds a player whose value
// is 2^(N - ctz(i)), or -1 if no such arrangement exists.
int main() {
  int N, M;
  cin >> N;
  M = 1 << N;
  // a[k]: players whose value is 2^k, in input order.
  vector<vector<int>> a(N + 1);
  for (int i = 1, x; i <= M; i++) {
    cin >> x;
    // Only powers of two in [1, 2^N] can fill a position. Reject anything else;
    // this also avoids __builtin_ctz(0) (undefined) and indexing past a[N].
    if (x <= 0 || (x & (x - 1)) != 0 || __builtin_ctz(x) > N) {
      cout << -1 << "\n";
      return 0;
    }
    a[__builtin_ctz(x)].push_back(i);
  }
  // After reversing, a[j] holds the players whose value is 2^(N - j).
  reverse(begin(a), end(a));
  vector<int> ans;
  for (int i = 1; i <= M; i++) {
    // Position i needs a player from a[ctz(i)]; fail if none is left.
    int j = __builtin_ctz(i);
    if (a[j].empty()) {
      cout << -1 << "\n";
      return 0;
    }
    ans.push_back(a[j].back());
    a[j].pop_back();
  }
  for (int i = 0; i < M; i++) cout << ans[i] << " \n"[i == M - 1];
}

Java

import java.util.*;

public class Main {
    /**
     * Sorts players by value (descending, ties in input order) and fills the
     * bracket level by level. At level L, positions 2^L - 1, 2^L - 1 + 2^(L+1), ...
     * (0-indexed) each require a player of value 2^(N-L). Prints -1 on mismatch.
     */
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);

        int N = sc.nextInt();
        int size = pow(2, N);
        int[][] players = new int[size][];
        for (int i = 0; i < size; i++) {
            players[i] = new int[] {sc.nextInt(), i + 1};
        }
        Arrays.sort(players, (a, b) -> Integer.compare(-a[0], -b[0]));

        int[] tree = new int[size];
        int next = 0;
        for (int level = 0; level <= N; level++) {
            int expected = pow(2, N - level);
            int step = pow(2, level + 1);
            for (int pos = pow(2, level) - 1; pos < size; pos += step) {
                if (players[next][0] != expected) {
                    System.out.println(-1);
                    return;
                }
                tree[pos] = players[next][1];
                next++;
            }
        }

        StringBuilder out = new StringBuilder();
        for (int i = 0; i < size; i++) {
            if (i > 0) {
                out.append(' ');
            }
            out.append(tree[i]);
        }
        System.out.println(out.toString());
    }

    /** x^n as an int, via Math.pow. */
    static int pow(int x, int n) {
        return (int) Math.pow(x, n);
    }
}

Python

# Fill the bracket level by level with players sorted by value (descending).
N = int(input())
A = list(map(int, input().split()))

size = 2 ** N
order = sorted(([A[i], i + 1] for i in range(size)), reverse=True)

tree = [0] * size
k = 0
for level in range(N + 1):
    expected = 2 ** (N - level)
    for pos in range(2 ** level - 1, size, 2 ** (level + 1)):
        if order[k][0] != expected:
            print(-1)
            exit()
        tree[pos] = order[k][1]
        k += 1

print(*tree)

せんべい(難易度6)[リ[[スー]バ][ズパ]ル]

C++

#include <algorithm>
#include <cassert>
#include <iostream>
#include <utility>
#include <vector>
using namespace std;
#define rep(i, n) for (int i = 0; i < (n); i++)
// Reads N, a sequence P of length N and M intervals [A, B] (1-indexed,
// inclusive). The intervals plus the whole range are arranged into a
// containment tree. Each interval is then split into pieces (single elements and
// child intervals). Bottom-up, a piece list is reversed when its last key is
// smaller than its first. The resulting sequence is printed.
// NOTE(review): the tree construction assumes the intervals are pairwise nested
// or disjoint (laminar) -- TODO confirm against the problem constraints.
// NOTE(review): dfs/dfs2 recurse once per nesting level; very deep nesting may
// exhaust the call stack.
int main() {
  int N, M;
  cin >> N >> M;
  vector<int> P(N);
  for (auto& x : P) cin >> x;
  vector<pair<int, int>> AB(M);
  for (auto& [a, b] : AB) cin >> a >> b;
  // For simplicity, convert to 0-indexed half-open intervals [a, b).
  for (auto& [a, _] : AB) --a;
  // Add the interval covering the whole sequence (it becomes the root).
  AB.emplace_back(0, N);
  // Sort by start ascending, and by end descending for equal starts, so every
  // interval comes after the intervals that contain it.
  sort(begin(AB), end(AB), [](auto l, auto r) {
    return l.first == r.first ? l.second > r.second : l.first < r.first;
  });
  assert(AB[0] == make_pair(0, N));
  // Build the graph of containment relations: par[i] is the parent of i, and
  // g[p] lists p's children in increasing start order.
  vector<vector<int>> g(M + 1);
  vector<int> par(M + 1, -1);
  for (int i = 1; i <= M; i++) {
    auto [a, b] = AB[i];
    // Which node is the parent?
    int p = i - 1;
    // Climb through ancestors of i-1 until p fully contains interval i.
    // The number of climbs is amortized O(1) (cf. Euler tour).
    while (true) {
      auto [pa, pb] = AB[p];
      // Order is pa a b pb (b == pb allowed) -> contained.
      if (b <= pb) break;
      // Order is pa pb a b, so we must climb to the parent.
      p = par[p];
    }
    // Make i a child of p.
    g[p].push_back(i);
    par[i] = p;
  }
  // Records, for each interval, the smallest achievable first element.
  vector<int> mn(M + 1, 1 << 30);
  // Ordered pieces of each interval after deciding whether to reverse it:
  // (child's min, child's id (-1 for a single element))
  vector<vector<pair<int, int>>> chd(M + 1);
  // Decide the order with a DFS.
  auto dfs = [&](auto rc, int c) -> void {
    auto [a, b] = AB[c];
    int cid = 0;
    // Walk positions a..b-1; a child interval starting here becomes one piece.
    for (int i = a; i < b;) {
      int upd = 0;
      if (cid < (int)g[c].size()) {
        int d = g[c][cid];
        auto [ca, cb] = AB[d];
        if (i == ca) {
          // Use d: compute its minimum recursively.
          rc(rc, d);
          // Push onto chd[c] and advance past the child.
          chd[c].emplace_back(mn[d], d);
          cid++;
          i = cb;
          upd = 1;
        }
      }
      if (!upd) {
        chd[c].emplace_back(P[i], -1);
        i++;
      }
    }
    // Reverse when the last piece's key is smaller than the first piece's key.
    if (chd[c].back().first < chd[c][0].first) {
      reverse(begin(chd[c]), end(chd[c]));
    }
    mn[c] = chd[c][0].first;
    return;
  };
  dfs(dfs, 0);
  vector<int> ans;
  // Build the answer with a second DFS, expanding child intervals in place.
  auto dfs2 = [&](auto rc, int c) -> void {
    for (auto& [val, id] : chd[c]) {
      if (id == -1) {
        ans.push_back(val);
      } else {
        rc(rc, id);
      }
    }
  };
  dfs2(dfs2, 0);
  assert((int)ans.size() == N);
  for (int i = 0; i < N; i++) cout << ans[i] << " \n"[i == N - 1];
}

Java

import java.util.*;

public class Main {
    static int N;
    static int M;
    static int[] P;
    // AB[i] = {a, b}: interval i as a 0-indexed half-open range [a, b).
    static int[][] AB;
    // g.get(p): children of interval p in the containment tree.
    static List<List<Integer>> g;
    // chd.get(c): ordered pieces of interval c as {key, child id (-1 = single element)}.
    static List<List<Integer[]>> chd;
    static List<String> ans;
    // mn[c]: smallest achievable first element of interval c.
    static int[] mn;

    /**
     * Reads N, M, the sequence P and M intervals [A, B]. The intervals plus the
     * whole range are arranged into a containment tree. Each interval's pieces are
     * oriented bottom-up so that the smaller end key comes first, and the
     * resulting sequence is printed.
     * NOTE(review): assumes the intervals are pairwise nested or disjoint -- TODO confirm.
     * NOTE(review): dfs/dfs2 recurse once per nesting level; very deep nesting
     * may cause a StackOverflowError.
     */
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);

        N = Integer.parseInt(sc.next());
        M = Integer.parseInt(sc.next());
        P = new int[N];
        for (int i = 0; i < N; i++) {
            P[i] = Integer.parseInt(sc.next());
        }
        AB = new int[M+1][2];
        for (int i = 0; i < M; i++) {
            // For simplicity, convert to 0-indexed half-open intervals.
            AB[i][0] = Integer.parseInt(sc.next()) - 1;
            AB[i][1] = Integer.parseInt(sc.next());
        }
        // Add the interval covering the whole sequence (the root).
        AB[M] = new int[] {0, N};

        // Sort by start ascending, end descending: containers come first.
        Arrays.sort(AB, (x, y) -> x[0] == y[0] ? y[1] - x[1] : x[0] - y[0]);

        // Build the graph of containment relations.
        g = new ArrayList<List<Integer>>();
        int[] par = new int[M+1];

        // Records, for each interval, the smallest achievable first element.
        mn = new int[M+1];
        int init = 1 << 30;

        // Ordered pieces of each interval after deciding whether to reverse it:
        // (child's min, child's id (-1 for a single element))
        chd = new ArrayList<List<Integer[]>>();

        for (int i = 0; i < M+1; i++) {
            g.add(new ArrayList<Integer>());
            par[i] = -1;
            mn[i] = init;
            chd.add(new ArrayList<Integer[]>());
        }

        for (int i = 1; i <= M; i++) {
            int a = AB[i][0];
            int b = AB[i][1];

            // Which node is the parent?
            int p = i - 1;

            // Climb through ancestors until p fully contains interval i.
            // The number of climbs is amortized O(1) (cf. Euler tour).
            while (true) {
                int pa = AB[p][0];
                int pb = AB[p][1];
                // Order is pa a b pb (b == pb allowed) -> contained.
                if (b <= pb) {
                    break;
                }
                // Order is pa pb a b, so we must climb to the parent.
                p = par[p];
            }
            // Make i a child of p.
            g.get(p).add(i);
            par[i] = p;
        }

        // Decide the order with a DFS.
        dfs(0);

        ans = new ArrayList<String>();
        // Build the answer with a second DFS.
        dfs2(0);

        System.out.println(String.join(" ", ans));
    }

    /** Builds chd[c] (recursing into children) and orients it so the smaller end key comes first. */
    static void dfs(int c) {
        int a = AB[c][0];
        int b = AB[c][1];
        int cid = 0;
        for (int i = a; i < b;) {
            boolean upd = false;
            if (cid < g.get(c).size()) {
                int d = g.get(c).get(cid);
                int ca = AB[d][0];
                int cb = AB[d][1];
                if (i == ca) {
                    // Use d: compute its minimum recursively.
                    dfs(d);
                    // Add to chd[c] and advance past the child.
                    chd.get(c).add(new Integer[] {mn[d], d});
                    cid++;
                    i = cb;
                    upd = true;
                }
            }
            if (!upd) {
                chd.get(c).add(new Integer[] {P[i], -1});
                i++;
            }
        }
        // Reverse when the last piece's key is smaller than the first piece's key.
        if (chd.get(c).get(chd.get(c).size()-1)[0] < chd.get(c).get(0)[0]) {
            Collections.reverse(chd.get(c));
        }
        mn[c] = chd.get(c).get(0)[0];
        return;
    }

    /** Appends the elements of interval c to ans in the chosen order, expanding children in place. */
    static void dfs2(int c) {
        for (Integer[] ele : chd.get(c)) {
            int val = ele[0];
            int id = ele[1];
            if (id == -1) {
                ans.add(String.valueOf(val));
            } else {
                dfs2(id);
            }
        }
    };
}

Python

import sys
# dfs/dfs2 below recurse once per level of interval nesting, so lift the limit.
# NOTE(review): raising the limit does not enlarge the C stack; deeply nested
# input may still crash CPython. Running in a thread with a larger
# threading.stack_size (or an iterative DFS) may be needed -- TODO confirm.
sys.setrecursionlimit(10**9)

N, M = list(map(int, input().split()))
P = list(map(int, input().split()))

# Intervals [A, B] (1-indexed, inclusive) -> 0-indexed half-open [a, b).
AB = [list(map(int, input().split())) for _ in range(M)]
for i in range(M):
    AB[i][0] -= 1

# Add the whole range as the root. Sort by start ascending and end descending,
# so every interval comes after the intervals that contain it.
AB.append([0,N])
AB.sort(key = lambda x: (x[0], -x[1]))

# Containment tree: climb from interval i-1 through its ancestors until one
# contains interval i (b <= pb; its start is already <= a by the sort order).
# The number of climbs is amortized O(1).
# NOTE(review): assumes intervals are pairwise nested or disjoint -- TODO confirm.
g = [[] for _ in range(M+1)]
par = [-1] * (M+1)
for i in range(1, M+1):
    a, b = AB[i]
    p = i - 1
    while True:
        pa, pb = AB[p]
        if b <= pb:
            break
        p = par[p]
    g[p].append(i)
    par[i] = p

# mn[c]: smallest achievable first element of interval c.
# chd[c]: ordered pieces of interval c as [key, child id]. A single element is
# [P[i], -1]; a child interval d is [mn[d], d].
mn = [1<<30] * (M+1)
chd = [[] for _ in range((M+1))]

def dfs(c):
    """Build chd[c] (recursing into children) and orient it so the smaller end key comes first."""
    a, b = AB[c]
    cid = 0
    i = a
    while i < b:
        upd = 0
        if cid < len(g[c]):
            d = g[c][cid]
            ca, cb = AB[d]
            # The next child interval starts here: treat it as one piece.
            if i == ca:
                dfs(d)
                chd[c].append([mn[d], d])
                cid += 1
                i = cb
                upd = 1
        if not upd:
            chd[c].append([P[i], -1])
            i += 1
    # Reverse when the last piece's key is smaller than the first piece's key.
    if chd[c][-1][0] < chd[c][0][0]:
        chd[c].reverse()
    mn[c] = chd[c][0][0]
    return
dfs(0)

ans = []
def dfs2(c):
    """Append the elements of interval c to ans in the chosen order."""
    for val, id in chd[c]:
        if id == -1:
            ans.append(val)
        else:
            dfs2(id)
dfs2(0)

print(*ans)

かつおぶし

難易度3
階乗の桁数

難易度4
桁と数列

難易度6
ペアなすごろく

難易度6
コイン投げ

かつおぶし(難易度3)階乗の桁数

C++

#include <cmath>
#include <iostream>
#include <utility>
#include <vector>
using namespace std;
#define rep(i, n) for (int i = 0; i < (n); i++)
// Prints the number of decimal digits of N!, i.e. floor(log10(N!)) + 1. The
// value log10(N!) is accumulated as log10(1) + log10(2) + ... + log10(N).
// NOTE(review): floating-point; relies on log10(N!) not being extremely close
// to an integer for the N in range.
int main() {
  int N;
  cin >> N;
  double logSum = 0;
  int i = 0;
  while (++i <= N) logSum += log10(i);
  cout << static_cast<int>(logSum) + 1 << "\n";
}

Java

import java.util.*;

public class Main {
    /** Prints the number of decimal digits of N!, computed as floor(1 + sum log10(i)). */
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);

        int N = sc.nextInt();

        // Start from 1 so that truncating yields floor(log10(N!)) + 1.
        double digits = 1;
        int i = 1;
        while (i <= N) {
            digits += Math.log10(i++);
        }

        System.out.println((int) digits);
    }
}

Python

from math import log10
# Number of decimal digits of N! = floor(log10(N!)) + 1.
N = int(input())
log_factorial = sum(log10(i) for i in range(1, N + 1))
print(int(log_factorial) + 1)

かつおぶし(難易度4)桁と数列

C++

#include <iostream>
#include <utility>
#include <vector>
using namespace std;
#define rep(i, n) for (int i = 0; i < (n); i++)
// Reads N, S and digit counts d[0..N-1]. Prints N integers whose i-th element
// has exactly d[i] digits and whose sum is S, or -1 if that is impossible.
// Every element starts at its minimum 10^(d-1); elements are then raised in order
// toward their maximum 10^d - 1 until S is used up.
// Assumes 1 <= d[i] <= 18.
int main() {
  long long N, S;
  cin >> N >> S;
  vector<long long> d(N);
  for (auto& x : d) cin >> x;
  // mn[i] = smallest i-digit integer = 10^(i-1), for i = 1..19 (mn[19] = 10^18).
  // mn[d + 1] is read below, so the table must reach one index past the largest
  // supported digit count. (A 16-entry table read out of bounds for d = 15.)
  vector<long long> mn(20);
  mn[1] = 1;
  for (int i = 2; i <= 19; i++) mn[i] = mn[i - 1] * 10;
  // Set every element to its minimum value.
  vector<long long> a(N);
  for (int i = 0; i < N; i++) {
    a[i] = mn[d[i]];
    S -= a[i];
    // Even the minimums exceed S: impossible.
    if (S < 0) {
      cout << "-1\n";
      return 0;
    }
  }
  // Make each element as large as possible.
  for (int i = 0; i < N; i++) {
    long long dif = mn[d[i] + 1] - 1 - a[i];
    dif = min(S, dif);
    a[i] += dif;
    S -= dif;
  }
  // If S is still left after maximizing every element, it is impossible.
  if (S > 0) {
    cout << "-1\n";
    return 0;
  }
  for (int i = 0; i < N; i++) cout << a[i] << " \n"[i == N - 1];
}

Java

import java.util.*;

public class Main {
    /**
     * Reads N, S and N digit counts. Prints N integers with those digit counts
     * summing to S, or -1 if impossible. Each element starts at 10^(d-1) and is
     * raised in order toward 10^d - 1. Assumes 1 <= d <= 18.
     */
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);

        int N = sc.nextInt();
        // long throughout: with int, 10^(d-1) * 9 overflows for d = 10, and the
        // running total can wrap around for large N.
        long S = sc.nextLong();
        long total = 0;
        long[] lst = new long[N];
        for (int i = 0; i < N; i++) {
            int digits = sc.nextInt();
            // 10^(digits-1) computed exactly (no floating-point Math.pow).
            long smallest = 1;
            for (int k = 1; k < digits; k++) {
                smallest *= 10;
            }
            lst[i] = smallest;
            // Once total exceeds S the answer is -1; stop adding so it cannot overflow.
            if (total <= S) {
                total += smallest;
            }
        }

        if (S < total) {
            System.out.println(-1);
            return;
        } else {
            S -= total;
        }

        for (int i = 0; i < N; i++) {
            // Room to grow: (10^d - 1) - 10^(d-1) = 9 * 10^(d-1) - 1.
            long add = Math.min(S, lst[i] * 9 - 1);
            lst[i] += add;
            S -= add;
        }

        if (S == 0) {
            StringJoiner ans = new StringJoiner(" ");
            Arrays.stream(lst).forEach(x -> ans.add(String.valueOf(x)));
            System.out.println(ans.toString());
        } else {
            System.out.println(-1);
        }
    }
}

Python

N, S = map(int, input().split())
d = list(map(int, input().split()))

# Begin with the smallest number having each requested digit count.
ans = [10 ** (digits - 1) for digits in d]
rest = S - sum(ans)
if rest < 0:
    print(-1)
    exit()

# Raise elements in order toward their maximum (at most 9 * 10**(d-1) - 1 more each).
for i in range(N):
    room = ans[i] * 9 - 1
    step = min(rest, room)
    ans[i] += step
    rest -= step

if rest != 0:
    print(-1)
else:
    print(*ans)

かつおぶし(難易度6)ペアなすごろく

C++

#include<bits/stdc++.h>
#define mod 998244353
using namespace std;
// Triangular number 1 + 2 + ... + x = x(x+1)/2.
long long llsankaku(long long x){
  const long long product = x * (x + 1);
  return product / 2;
}
// a^b modulo `mod` by binary exponentiation (intended for b >= 0).
long long power(long long a,long long b){
  long long result = 1;
  for (long long base = a; b > 0; b >>= 1) {
    if (b & 1ll) result = (result * base) % mod;
    base = (base * base) % mod;
  }
  return result % mod;
}
// Inverse of n modulo the prime `mod`, by Fermat's little theorem: n^(mod-2).
long long modular_inverse(long long n){
  const long long exponent = mod - 2;
  return power(n, exponent);
}
// Length of the position table; every index i + r[i] (i < n) must stay below it.
// NOTE(review): presumably chosen to exceed n + max(r) under the problem
// constraints -- TODO confirm.
#define MAXN 524288
// Reads n and r[0..n-1] and propagates landing probabilities over positions
// 0, 1, ..., MAXN-1 (modulo 998244353). From a position i < n with r = r[i],
// the probability p of landing on i is spread over positions i+1 .. i+r with
// weights r, r-1, ..., 1, divided by r(r+1)/2. Positions i >= n do not
// propagate. The program prints the sum over those positions of
// (i + 1) * P(land on i). That appears to be the expected value of
// (final position + 1).
//
// The piecewise-linear distribution is maintained incrementally:
//   prob  - probability of landing on the current position i,
//   delta - running first difference applied when moving to i+1,
//   acc   - corrections to delta scheduled at later positions.
int main(){
  long long n;
  cin >> n;
  vector<long long> r(n);
  for(auto &nx : r){cin >> nx;}
  vector<long long> acc(MAXN,0);
  long long res=0;
  // Start with probability 1 at position 0. delta = -1 together with
  // acc[0] = 1 cancels that point mass once position 0 has been processed.
  long long prob=1,delta=(mod-1);
  acc[0]=1;
  for(long long i=0;i<MAXN;i++){
    if(i<n){
      // q = prob / (r(r+1)/2): the unit weight of the triangular spread.
      long long q=llsankaku(r[i])%mod;
      q=modular_inverse(q);
      q*=prob;q%=mod;
      // Slope of -q from position i+1; acc cancels it after position i+r[i].
      delta+=(mod-q);delta%=mod;
      acc[i+r[i]]+=q;acc[i+r[i]]%=mod;
      // Combined with the -q in delta, position i+1 receives q*r[i].
      prob+=q*(r[i]+1);prob%=mod;
    }
    else{res+=prob*(i+1);res%=mod;}
    // Advance to position i+1.
    prob+=delta;prob%=mod;
    delta+=acc[i];delta%=mod;
  }
  cout << res << '\n';
}

Java

import java.util.*;

public class Main {
    /**
     * Propagates landing probabilities over positions 0 .. max-1 (mod 998244353).
     * From a position i < N with r = R[i], probability p spreads over positions
     * i+1 .. i+r with weights r, r-1, ..., 1, divided by r(r+1)/2. Positions
     * i >= N do not propagate. The program prints the sum of (i + 1) * P(land on i)
     * over those positions, which appears to be E[final position + 1].
     */
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);

        int N = sc.nextInt();
        int[] R = new int[N];
        for (int i = 0; i < N; i++) {
            R[i] = sc.nextInt();
        }
        long mod = 998244353;

        // NOTE(review): max = 400001 presumably exceeds i + R[i] for all i < N
        // under the problem constraints -- TODO confirm.
        int max = 400001;
        // acc: corrections to delta scheduled at later positions.
        long[] acc = new long[max];
        long ans = 0;
        // prob: probability of landing on the current position; starts as 1 at 0.
        long prob = 1;
        // delta: running first difference of prob. The -1 together with
        // acc[0] = 1 cancels the initial point mass after position 0.
        long delta = mod - 1;
        acc[0] = 1;
        for (int i = 0; i < max; i++) {
            if (i < N) {
                // (long) cast before multiplying avoids int overflow.
                long q = (long)R[i] * (R[i]+1) / 2 % mod;
                // q = prob / (R[i](R[i]+1)/2), via a Fermat inverse.
                q = modPow(q, mod-2, mod) * prob % mod;
                // Slope of -q from i+1; acc cancels it after position i+R[i].
                delta = (delta + mod - q) % mod;
                acc[i+R[i]] = (acc[i+R[i]] + q) % mod;
                // Combined with the -q in delta, position i+1 receives q*R[i].
                prob = (prob + q * (R[i]+1)) % mod;
            } else {
                ans = (ans + prob * (i+1)) % mod;
            }
            // Advance to position i+1.
            prob = (prob + delta) % mod;
            delta = (delta + acc[i]) % mod;
        }
        System.out.println(ans);
    }

    /** x^n modulo mod by repeated squaring (intended for n >= 0). */
    static long modPow(long x, long n, long mod) {
        long ret = 1;
        // Repeated squaring
        while (0 < n) {
            if ((n&1) == 1) {
                ret = ret * x % mod;
            }
            x = x * x % mod;
            n >>= 1;
        }
        return ret;
    }
}

Python

N = int(input())
R = list(map(int, input().split()))
mod = 998244353

# Landing probabilities over positions 0, 1, 2, ... are propagated left to
# right (mod p). From a position i < N with r = R[i], probability p spreads
# over positions i+1 .. i+r with weights r, r-1, ..., 1, divided by r(r+1)/2.
#   prob  : probability of landing on the current position i
#   delta : running first difference of prob
#   acc   : corrections to delta scheduled at later positions
# NOTE(review): the table length 4*10**5+1 presumably covers i + R[i] for all
# i < N under the constraints -- TODO confirm.
acc = [0]*(4*10**5+1)
res = 0
# Probability 1 at position 0. delta = -1 and acc[0] = 1 cancel that point
# mass after position 0 is processed.
prob = 1
delta = mod - 1
acc[0] = 1
for i in range(4*10**5+1):
    if i < N:
        # q = prob / (R[i] * (R[i] + 1) / 2), using a Fermat inverse.
        q = R[i] * (R[i]+1) // 2 % mod
        q = pow(q, mod-2, mod) * prob % mod
        # Slope of -q from i+1; acc cancels it after position i+R[i].
        delta = (delta+mod-q) % mod
        acc[i+R[i]] = (acc[i+R[i]]+q) % mod
        # Combined with the -q in delta, position i+1 receives q*R[i].
        prob = (prob+q*(R[i]+1)) % mod
    else:
        # Positions >= N do not propagate: accumulate (i + 1) * P(land on i).
        res = (res+prob*(i+1)) % mod
    # Advance to position i+1.
    prob = (prob+delta) % mod
    delta = (delta+acc[i]) % mod

print(res)

かつおぶし(難易度6)コイン投げ

C++

#include <iostream>
#include <string>
#include <vector>
using namespace std;
constexpr int MOD = 998244353;
// Reads a string S and prints -dp[|S|] modulo MOD.
// The recurrence dp[i+1] = 2*dp[i] - dp[k+1] - 2 is, with E = -dp,
// E[i+1] = 2*E[i] + 2 - E[k+1]. That appears to be the expected number of fair
// two-outcome tosses until the prefix S[0..i] first appears, where k+1 is the
// KMP state reached from state i on the non-matching outcome.
// NOTE(review): this interpretation assumes S uses exactly two symbols -- TODO confirm.
int main() {
  string S;
  cin >> S;
  // A: "optimized" KMP failure links (-1 = none). A[i] is a border of S[0..i)
  // whose following character differs from S[i]. dp: see the recurrence above.
  vector<int> A(S.size() + 1, -1), dp(S.size() + 1);
  for (int i = 0, j = -1, k; i < (int)S.size(); i++) {
    // Standard KMP step; ~j is zero (false) only when j == -1.
    while (~j && S[i] != S[j]) j = A[j];
    // On the last iteration S[i + 1] reads the terminating '\0'.
    A[i + 1] = S[i + 1] == S[++j] ? A[j] : j;
    // k: longest border of S[0..i) followed by a character other than S[i].
    for (k = A[i]; ~k && S[i] == S[k];) k = A[k];
    // dp values are < MOD, so dp[i] * 2 < 2^31 and int arithmetic is safe.
    dp[i + 1] = ((dp[i] * 2 - dp[++k] - 2) % MOD + MOD) % MOD;
  }
  cout << (-dp.back() % MOD + MOD) % MOD << "\n";
}

Java

import java.util.*;

public class Main {
    /**
     * Prints -dp[|S|] mod 998244353. The recurrence dp[i+1] = 2*dp[i] - dp[k+1] - 2
     * is, with E = -dp, E[i+1] = 2*E[i] + 2 - E[k+1]. That appears to be the
     * expected number of fair two-outcome tosses until S[0..i] first appears.
     * NOTE(review): assumes S uses exactly two symbols -- TODO confirm.
     */
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);

        // Append a sentinel so S[i+1] is valid on the last iteration.
        // NOTE(review): assumes '_' never occurs in the input -- TODO confirm.
        char[] S = (sc.next() + "_").toCharArray();
        int len = S.length;
        int mod = 998244353;
        // A: "optimized" KMP failure links (-1 = none).
        int[] A = new int[len];
        Arrays.fill(A, -1);

        int[] dp = new int[len];
        for (int i = 0, j = -1; i < len-1; i++) {
            // Standard KMP step; (~j != 0) is equivalent to (j != -1).
            while ((~j != 0) && (S[i] != S[j])) {
                j = A[j];
            }

            A[i+1] = (S[i+1] == S[++j] ? A[j] : j);
            // k: longest border of S[0..i) followed by a character other than S[i].
            // (k is assigned A[i] both here and in the loop header; harmless.)
            int k = A[i];
            for (k = A[i]; (~k != 0) && (S[i] == S[k]);) {
                k = A[k];
            }
            // dp values are < mod, so dp[i] * 2 < 2^31 and int arithmetic is safe.
            dp[i+1] = ((dp[i] * 2 - dp[k+1] - 2) % mod + mod) % mod;
        }

        // len - 1 == |S| because of the sentinel.
        System.out.println((-dp[len-1] % mod + mod) % mod);
    }
}

Python

S = input()
siz = len(S)
# Append a sentinel so S[i+1] is valid on the last iteration.
# NOTE(review): assumes ' ' never occurs in the input -- TODO confirm.
S = S + " "
mod = 998244353

# A: "optimized" KMP failure links (-1 = none).
# dp: dp[i+1] = 2*dp[i] - dp[k+1] - 2 (mod p). With E = -dp this is
# E[i+1] = 2*E[i] + 2 - E[k+1], which appears to be the expected number of
# fair two-outcome tosses until S[:i+1] first appears (two-symbol S assumed).
A = [-1] * (siz + 1)
dp = [0] * (siz + 1)
j = -1
for i in range(siz):
    # Standard KMP step; ~j is falsy only when j == -1.
    while ~j and (S[i] != S[j]):
        j = A[j]

    j += 1
    A[i+1] = (A[j] if S[i+1] == S[j] else j)
    # k: longest border of S[:i] followed by a character other than S[i].
    k = A[i]
    while ~k and (S[i] == S[k]):
        k = A[k]
    dp[i+1] = ((dp[i] * 2 - dp[k+1] - 2) % mod + mod) % mod

# Answer: -dp[|S|] modulo mod.
print((-dp[-1] % mod + mod) % mod)