PG BATTLE 2021 解答例一覧
各問題につきC++、Java、Pythonの解答例を用意しております。
AtCoder社による解答はC++のもので、JavaとPythonの解答例はPG BATTLE運営が用意したものとなりますので、予めご了承ください。
ましゅまろ(難易度1)物理現象グラフィックバトル
C++
#include<bits/stdc++.h>
using namespace std;

// Reads x, y, k and prints y - k / x with 12 digits after the decimal point.
int main(){
    double x_val, y_val, k_val;
    cin >> x_val >> y_val >> k_val;
    const double answer = y_val - k_val / x_val;
    printf("%.12lf\n", answer);
    return 0;
}
Java
import java.util.*;

public class Main {
    // Reads X, Y, K and prints Y - K / X.
    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        double x = scanner.nextDouble();
        double y = scanner.nextDouble();
        double k = scanner.nextDouble();
        double result = y - k / x;
        System.out.println(result);
    }
}
Python
# Read X, Y, K and print Y - K / X.
x, y, k = (int(tok) for tok in input().split())
print(y - k / x)
ましゅまろ(難易度3)ゼロのない整数
C++
#include<bits/stdc++.h>
using namespace std;

// Replaces the first '0' in the digit string X and every digit after it with '1',
// then prints the result (X is printed unchanged when it has no '0').
int main(){
    string x;
    cin >> x;
    const size_t first_zero = x.find('0');
    if(first_zero != string::npos){
        fill(x.begin() + first_zero, x.end(), '1');
    }
    cout << x << '\n';
}
Java
import java.util.*;

public class Main {
    // Replaces the first '0' in X and every digit after it with '1', then prints the result.
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        String x = sc.next();
        int firstZero = x.indexOf('0');
        if (firstZero < 0) {
            System.out.println(x);
            return;
        }
        StringBuilder sb = new StringBuilder(x.substring(0, firstZero));
        for (int i = firstZero; i < x.length(); i++) {
            sb.append('1');
        }
        System.out.println(sb.toString());
    }
}
Python
# Replace the first "0" in X and every digit after it with "1", then print.
X = input()
head, sep, tail = X.partition("0")
if sep:
    print(head + "1" * (1 + len(tail)))
else:
    print(X)
ましゅまろ(難易度3)一点封鎖
C++
#include <atcoder/modint>
#include <bits/stdc++.h>
#define mod 998244353
using namespace std;
using namespace atcoder;
// NOTE(review): relies on atcoder::modint's default modulus matching `mod` (set_mod is never called).
using mint = modint;

// Binomial coefficient C(n, r) as a running product n(n-1).../(1*2*...); 0 when n < r.
mint nCr(int n,int r){
    if(n<r){return 0;}
    mint res=1;
    for(int i=1;i<=r;i++){
        res*=(n-i+1);
        res/=i;
    }
    return res;
}

// Counts right/down lattice paths from (0,0) to (n,m) that avoid the cell (a,b),
// computing it both by grid DP and by a binomial formula, and cross-checking with assert.
int main(){
    int n,m,a,b;
    cin >> n >> m >> a >> b;

    // Solution 1: grid DP; the blocked cell is forced to 0.
    vector<vector<int>> ways(n+1, vector<int>(m+1, 0));
    ways[0][0]=1;
    for(int i=0;i<=n;i++){
        for(int j=0;j<=m;j++){
            if(i==a && j==b){ ways[i][j]=0; continue; }
            // Each term is < mod, so the sum of up to three terms stays within int.
            int cur=ways[i][j];
            if(i>0){ cur+=ways[i-1][j]; }
            if(j>0){ cur+=ways[i][j-1]; }
            ways[i][j]=cur%mod;
        }
    }
    const int ans1=ways[n][m];

    // Solution 2: all paths minus those passing through (a,b), when (a,b) is inside the grid.
    int ans2;
    if(a<=n && b<=m){
        mint through=nCr(a+b,a)*nCr((n-a)+(m-b),n-a);
        ans2=(nCr(n+m,n)-through).val();
    }else{
        ans2=nCr(n+m,n).val();
    }
    assert(ans1==ans2);
    cout << ans1 << '\n';
}
Java
import java.util.*;

public class Main {
    // Counts monotone lattice paths from (0,0) to (N,M) that avoid the point (A,B), mod 998244353:
    // C(N+M, N) minus the paths through (A,B) when that point lies inside the grid.
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int N = sc.nextInt();
        int M = sc.nextInt();
        int A = sc.nextInt();
        int B = sc.nextInt();
        long mod = 998244353;
        // Sized to N+M, the largest factorial any binomial below needs, so the table
        // cannot overflow for large N and M (a fixed-size table could).
        long[] factorials = new long[N + M + 1];
        factorials[0] = 1;
        for (int i = 1; i <= N + M; i++) {
            factorials[i] = factorials[i-1] * i % mod;
        }
        long ans = binomial(factorials, N + M, N, mod);
        if (A <= N && B <= M) {
            // Paths through (A,B): reach it, then continue to (N,M).
            long through = binomial(factorials, A + B, A, mod)
                    * binomial(factorials, N + M - A - B, N - A, mod) % mod;
            // Both operands are in [0, mod), so adding mod keeps the difference non-negative.
            ans = (ans - through + mod) % mod;
        }
        System.out.println(ans);
    }

    // C(n, r) mod `mod` from precomputed factorials; inverses via Fermat's little theorem
    // (mod must be prime).
    static long binomial(long[] fact, int n, int r, long mod) {
        return fact[n] * modPow(fact[r], mod - 2, mod) % mod * modPow(fact[n - r], mod - 2, mod) % mod;
    }

    // x^n mod `mod` by repeated squaring.
    static long modPow(long x, long n, long mod) {
        long ret = 1;
        while (0 < n) {
            if ((n & 1) == 1) {
                ret = ret * x % mod;
            }
            x = x * x % mod;
            n >>= 1;
        }
        return ret;
    }
}
Python
N, M, A, B = map(int, input().split())
mod = 998244353

# factorials[i] = i! mod p for 0 <= i <= N+M
factorials = [1] * (N + M + 1)
for i in range(1, N + M + 1):
    factorials[i] = factorials[i - 1] * i % mod


def binom(n, r):
    """C(n, r) mod p from the factorial table; inverses via Fermat's little theorem (p prime)."""
    return factorials[n] * pow(factorials[r], mod - 2, mod) % mod * pow(factorials[n - r], mod - 2, mod) % mod


# C(N+M, N): the number of monotone lattice paths from (0, 0) to (N, M).
ans = binom(N + M, N)
if A <= N and B <= M:
    # Remove the paths through the blocked point (A, B). Every factor is reduced mod p,
    # and Python's % already yields a non-negative result, so no "+ mod*mod" guard is needed.
    ans = (ans - binom(A + B, A) * binom(N + M - A - B, N - A)) % mod
print(ans)
ましゅまろ(難易度5)無双関数
C++
#include<bits/stdc++.h>
#define mod 998244353
using namespace std;

// Two-pointer sweep: for every right end, keep the longest window [left, right] whose
// values of a are pairwise distinct, and add prefix sum s[right - left] to the answer (mod 998244353).
int main(){
    int n;
    cin >> n;
    vector<int> a(n), s(n);
    for(int i=0;i<n;i++) cin >> a[i];
    for(int i=0;i<n;i++) cin >> s[i];
    // s becomes its own prefix sums modulo mod.
    for(int i=1;i<n;i++) s[i]=(s[i]+s[i-1])%mod;
    int res=0;
    int left=0;
    set<int> window;
    for(int right=0;right<n;right++){
        // Shrink from the left until a[right] no longer occurs in the window.
        while(window.count(a[right])){
            window.erase(a[left]);
            ++left;
        }
        window.insert(a[right]);
        res=(res+s[right-left])%mod;
    }
    cout << res << '\n';
    return 0;
}
Java
import java.util.*;

public class Main {
    // For each right end r, `left` is the smallest start such that A[left..r] has distinct
    // values; the answer sums prefix[r - left] modulo 998244353.
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int N = sc.nextInt();
        int[] A = new int[N];
        for (int i = 0; i < N; i++) A[i] = sc.nextInt();
        int[] S = new int[N];
        for (int i = 0; i < N; i++) S[i] = sc.nextInt();
        final int mod = 998244353;

        long[] prefix = new long[N];
        prefix[0] = S[0];
        for (int i = 1; i < N; i++) prefix[i] = (prefix[i - 1] + S[i]) % mod;

        // value -> index just past its latest occurrence
        Map<Integer, Integer> nextAfterLast = new HashMap<>();
        long total = 0;
        int left = 0;
        for (int right = 0; right < N; right++) {
            left = Math.max(left, nextAfterLast.getOrDefault(A[right], 0));
            nextAfterLast.put(A[right], right + 1);
            total = (total + prefix[right - left]) % mod;
        }
        System.out.println(total);
    }
}
Python
from itertools import accumulate

N = int(input())
A = list(map(int, input().split()))
S = list(map(int, input().split()))
mod = 998244353

prefix_sums = list(accumulate(S))
last_next = {}  # value -> index just past its latest occurrence
ans = 0
left = 0
for right in range(N):
    value = A[right]
    # The window must start after the previous occurrence of this value.
    left = max(left, last_next.get(value, 0))
    last_next[value] = right + 1
    ans = (ans + prefix_sums[right - left]) % mod
print(ans)
せんべい(難易度2)7番勝負
C++
#include<bits/stdc++.h>
using namespace std;

// Probability (in percent) of winning at least 4 of 7 games, each won independently with p%.
// Enumerates all 2^7 win/loss patterns in order and sums those with >= 4 wins.
int main(){
    int p;
    cin >> p;
    const double win=p/100.0;
    const double lose=(100-p)/100.0;
    double total=0;
    for(int mask=0;mask<128;mask++){
        const int wins=__builtin_popcount(mask);
        if(wins<4) continue;
        total+=pow(win,wins)*pow(lose,7-wins);
    }
    printf("%.12lf\n",total*100.0);
    return 0;
}
Java
import java.util.*;

public class Main {
    // Probability (in percent) of at least 4 wins out of 7 games won independently with P%.
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        double p = sc.nextDouble() / 100;
        double total = 0;
        for (int wins = 4; wins <= 7; wins++) {
            total += term(wins, p);
        }
        System.out.println(total * 100);
    }

    // C(7, wins) * p^wins * (1-p)^(7-wins), built incrementally factor by factor.
    private static double term(int wins, double p) {
        double t = 1;
        for (int j = 0; j < wins; j++) {
            t = t * (7 - j) / (j + 1) * p;
        }
        for (int j = 0; j < 7 - wins; j++) {
            t *= 1 - p;
        }
        return t;
    }
}
Python
# Probability (in percent) of at least 4 wins out of 7 games, each won with probability P%.
P = int(input()) / 100
ans = 0
for wins in range(4, 8):
    # C(7, wins) * P^wins, built factor by factor
    term = 1
    for j in range(wins):
        term = term * (7 - j) / (j + 1) * P
    # times (1-P)^(7-wins)
    for _ in range(7 - wins):
        term *= 1 - P
    ans += term
print(ans * 100)
せんべい(難易度3)連結成分数の見積もり
C++
#include<bits/stdc++.h>
using namespace std;

// For each test case (n, m) prints max(1, n - m) and n + 1 - k, where k is the smallest
// integer with k(k-1)/2 >= m.
int main(){
    int t;
    cin >> t;
    for(;t>0;--t){
        long long n,m;
        cin >> n >> m;
        cout << max(1ll,n-m) << ' ';
        // Lower-bound binary search over [1, 1.5e9] for the smallest k with k(k-1)/2 >= m.
        long long lo=1,hi=1500000000;
        while(lo<=hi){
            const long long mid=(lo+hi)/2;
            if(mid*(mid-1)/2>=m) hi=mid-1;
            else lo=mid+1;
        }
        cout << n+1-lo << '\n';
    }
}
Java
import java.util.*;

public class Main {
    // For each test case prints max(1, N - M) and N - k + 1, where k is the smallest
    // integer with k(k-1)/2 >= M.
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int T = sc.nextInt();
        while (T-- > 0) {
            int N = sc.nextInt();
            long M = sc.nextLong();
            long min = Math.max(1, N - M);
            long max = N - smallestK(M) + 1;
            System.out.println(min + " " + max);
        }
    }

    // Smallest k in [1, 1000000001] with k*(k-1)/2 >= m (binary search).
    private static long smallestK(long m) {
        long lo = 1;
        long hi = 1000000001;
        while (lo <= hi) {
            long mid = (lo + hi) / 2;
            if (mid * (mid - 1) / 2 >= m) {
                hi = mid - 1;
            } else {
                lo = mid + 1;
            }
        }
        return lo;
    }
}
Python
def smallest_k(m):
    """Smallest k in [1, 1000000001] with k*(k-1)//2 >= m (binary search)."""
    lo, hi = 1, 1000000001
    while lo <= hi:
        mid = (lo + hi) // 2
        if mid * (mid - 1) // 2 >= m:
            hi = mid - 1
        else:
            lo = mid + 1
    return lo


T = int(input())
for _ in range(T):
    N, M = map(int, input().split())
    print(max(1, N - M), N - smallest_k(M) + 1)
せんべい(難易度4)トーナメント表
C++
#include <algorithm>
#include <iostream>
#include <vector>
using namespace std;

// Reads N and the values x of the M = 2^N players (player numbers 1..M), then fills bracket
// slots 1..M: slot i takes a player whose value is 2^(N - ctz(i)). Players are bucketed by
// log2(x) and handed out slot by slot; if a bucket runs dry, or a value can never match any
// slot, -1 is printed.
int main() {
    int N, M;
    cin >> N;
    M = 1 << N;
    vector<vector<int>> a(N + 1);
    for (int i = 1, x; i <= M; i++) {
        cin >> x;
        // Only powers of two in [1, M] can fill a slot. Rejecting everything else also avoids
        // __builtin_ctz(0) (undefined) and ctz(x) > N (which would index past `a`), and stops
        // non-powers such as 3 from being silently treated as a smaller power of two.
        if (x < 1 || x > M || (x & (x - 1)) != 0) {
            cout << -1 << "\n";
            return 0;
        }
        a[__builtin_ctz(x)].push_back(i);
    }
    // After reversing, a[j] holds the players whose value is 2^(N - j).
    reverse(begin(a), end(a));
    vector<int> ans;
    ans.reserve(M);
    for (int i = 1; i <= M; i++) {
        int j = __builtin_ctz(i);
        if (a[j].empty()) {
            cout << -1 << "\n";
            return 0;
        }
        ans.push_back(a[j].back());
        a[j].pop_back();
    }
    for (int i = 0; i < M; i++) cout << ans[i] << " \n"[i == M - 1];
}
Java
import java.util.*;

public class Main {
    // Assigns players to bracket slots level by level: at level L the 0-indexed slots
    // 2^L - 1, 2^L - 1 + 2^(L+1), ... must receive players whose value is 2^(N - L).
    // Players are consumed in descending order of value; any mismatch prints -1.
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int N = sc.nextInt();
        int size = 1 << N;
        // players[i] = {reported value, 1-based player number}
        int[][] players = new int[size][];
        for (int i = 0; i < size; i++) {
            players[i] = new int[] {sc.nextInt(), i + 1};
        }
        // Largest values first; the object sort is stable, so ties keep input order.
        Arrays.sort(players, Comparator.comparingInt(e -> -e[0]));
        int[] bracket = new int[size];
        int next = 0;
        for (int level = 0; level <= N; level++) {
            int expected = 1 << (N - level);
            int step = 1 << (level + 1);
            for (int pos = (1 << level) - 1; pos < size; pos += step) {
                if (players[next][0] != expected) {
                    System.out.println(-1);
                    return;
                }
                bracket[pos] = players[next][1];
                next++;
            }
        }
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < size; i++) {
            if (i > 0) sb.append(' ');
            sb.append(bracket[i]);
        }
        System.out.println(sb);
    }
}
Python
N = int(input())
A = list(map(int, input().split()))
size = 2 ** N
# [value, 1-based player] pairs, largest first (ties: larger player number first).
order = sorted(([A[i], i + 1] for i in range(size)), reverse=True)
tree = [0] * size
k = 0
for level in range(N + 1):
    # Slots 2^level - 1, 2^level - 1 + 2^(level+1), ... need value 2^(N - level).
    expected = 2 ** (N - level)
    for pos in range(2 ** level - 1, size, 2 ** (level + 1)):
        if order[k][0] != expected:
            print(-1)
            exit()
        tree[pos] = order[k][1]
        k += 1
print(*tree)
せんべい(難易度6)[リ[[スー]バ][ズパ]ル]
C++
#include <algorithm> #include <cassert> #include <iostream> #include <utility> #include <vector> using namespace std; #define rep(i, n) for (int i = 0; i < (n); i++) int main() { int N, M; cin >> N >> M; vector<int> P(N); for (auto& x : P) cin >> x; vector<pair<int, int>> AB(M); for (auto& [a, b] : AB) cin >> a >> b; // 簡単のため 0-indexed の 半開区間にする for (auto& [a, _] : AB) --a; // 区間全体を入れる AB.emplace_back(0, N); // ソート sort(begin(AB), end(AB), [](auto l, auto r) { return l.first == r.first ? l.second > r.second : l.first < r.first; }); assert(AB[0] == make_pair(0, N)); // 包含関係に対応するグラフを構築 vector<vector<int>> g(M + 1); vector<int> par(M + 1, -1); for (int i = 1; i <= M; i++) { auto [a, b] = AB[i]; // 親のノードは? int p = i - 1; // p に完全に含まれるまで親にさかのぼる // さかのぼる回数は償却 O(1) (cf : オイラーツアー) while (true) { auto [pa, pb] = AB[p]; // pa a b pb の順 (b = pb を許容) -> 包含 if (b <= pb) break; // 並び順は pa pb a b なので親にさかのぼる必要あり p = par[p]; } // i を p の子とする g[p].push_back(i); par[i] = p; } // 区間の先頭要素の最小値を記録する配列 vector<int> mn(M + 1, 1 << 30); // reverse するかを表す配列 // (子のmin, 子のid(ブロック単体の場合-1) ) vector<vector<pair<int, int>>> chd(M + 1); // DFSで順番を決定する auto dfs = [&](auto rc, int c) -> void { auto [a, b] = AB[c]; int cid = 0; for (int i = a; i < b;) { int upd = 0; if (cid < (int)g[c].size()) { int d = g[c][cid]; auto [ca, cb] = AB[d]; if (i == ca) { // dを使う。再帰的に最小値を計算 rc(rc, d); // chd[c] に pushして更新作業 chd[c].emplace_back(mn[d], d); cid++; i = cb; upd = 1; } } if (!upd) { chd[c].emplace_back(P[i], -1); i++; } } if (chd[c].back().first < chd[c][0].first) { reverse(begin(chd[c]), end(chd[c])); } mn[c] = chd[c][0].first; return; }; dfs(dfs, 0); vector<int> ans; // DFS で解を構成 auto dfs2 = [&](auto rc, int c) -> void { for (auto& [val, id] : chd[c]) { if (id == -1) { ans.push_back(val); } else { rc(rc, id); } } }; dfs2(dfs2, 0); assert((int)ans.size() == N); for (int i = 0; i < N; i++) cout << ans[i] << " \n"[i == N - 1]; }
Java
import java.util.*;

// Reads a sequence P of length N and M intervals, builds their containment tree, lets every
// interval choose whether to reverse its pieces so that its first element is as small as
// possible, and prints the resulting sequence.
// NOTE(review): dfs/dfs2 recurse once per nesting level; very deep nesting may exceed the
// default JVM thread stack — confirm against the problem's limits.
public class Main {
    static int N;
    static int M;
    static int[] P;
    // Intervals as {start (0-indexed, inclusive), end (exclusive)}, including the whole array.
    static int[][] AB;
    // g.get(c): child intervals of c in increasing order of start.
    static List<List<Integer>> g;
    // chd.get(c): pieces of interval c in output order as {child's min, child id or -1 for a single element}.
    static List<List<Integer[]>> chd;
    static List<String> ans;
    // mn[c]: the first element of interval c after its orientation is chosen.
    static int[] mn;
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        N = Integer.parseInt(sc.next());
        M = Integer.parseInt(sc.next());
        P = new int[N];
        for (int i = 0; i < N; i++) {
            P[i] = Integer.parseInt(sc.next());
        }
        AB = new int[M+1][2];
        for (int i = 0; i < M; i++) {
            // For simplicity, convert to 0-indexed half-open intervals [a, b).
            AB[i][0] = Integer.parseInt(sc.next()) - 1;
            AB[i][1] = Integer.parseInt(sc.next());
        }
        // Add the interval covering the whole array (node 0, the DFS root, after sorting).
        AB[M] = new int[] {0, N};
        // Sort by start ascending; on ties, the longer (enclosing) interval comes first.
        Arrays.sort(AB, (x, y) -> x[0] == y[0] ? y[1] - x[1] : x[0] - y[0]);
        // Build the tree corresponding to the containment relation.
        g = new ArrayList<List<Integer>>();
        int[] par = new int[M+1];
        // Array recording the minimum first element of each interval.
        mn = new int[M+1];
        int init = 1 << 30;
        // Per-interval piece lists used to decide reversal:
        // (child's min, child id — or -1 for a single element)
        chd = new ArrayList<List<Integer[]>>();
        for (int i = 0; i < M+1; i++) {
            g.add(new ArrayList<Integer>());
            par[i] = -1;
            mn[i] = init;
            chd.add(new ArrayList<Integer[]>());
        }
        for (int i = 1; i <= M; i++) {
            int a = AB[i][0];
            int b = AB[i][1];
            // Which node is the parent?
            int p = i - 1;
            // Climb through parents until interval i is fully contained in p.
            // The number of climbing steps is amortized O(1) (cf. Euler tour).
            while (true) {
                int pa = AB[p][0];
                int pb = AB[p][1];
                // Order pa a b pb (b == pb allowed) -> contained.
                if (b <= pb) {
                    break;
                }
                // Otherwise the order is pa pb a b, so we must climb to the parent.
                p = par[p];
            }
            // Make i a child of p.
            g.get(p).add(i);
            par[i] = p;
        }
        // Decide the order of every interval by DFS.
        dfs(0);
        ans = new ArrayList<String>();
        // Build the answer by DFS.
        dfs2(0);
        System.out.println(String.join(" ", ans));
    }
    // Fills chd[c] with c's pieces (recursing into child intervals first), reverses them when the
    // last piece starts with a smaller value than the first, and records mn[c].
    static void dfs(int c) {
        int a = AB[c][0];
        int b = AB[c][1];
        int cid = 0;
        for (int i = a; i < b;) {
            boolean upd = false;
            if (cid < g.get(c).size()) {
                int d = g.get(c).get(cid);
                int ca = AB[d][0];
                int cb = AB[d][1];
                if (i == ca) {
                    // Use child d: compute its minimum recursively.
                    dfs(d);
                    // Add it to chd[c] and skip past it.
                    chd.get(c).add(new Integer[] {mn[d], d});
                    cid++;
                    i = cb;
                    upd = true;
                }
            }
            if (!upd) {
                chd.get(c).add(new Integer[] {P[i], -1});
                i++;
            }
        }
        if (chd.get(c).get(chd.get(c).size()-1)[0] < chd.get(c).get(0)[0]) {
            Collections.reverse(chd.get(c));
        }
        mn[c] = chd.get(c).get(0)[0];
        return;
    }
    // Appends the values of interval c to ans in the chosen order, expanding child intervals.
    static void dfs2(int c) {
        for (Integer[] ele : chd.get(c)) {
            int val = ele[0];
            int id = ele[1];
            if (id == -1) {
                ans.add(String.valueOf(val));
            } else {
                dfs2(id);
            }
        }
    };
}
Python
import sys
# NOTE(review): dfs/dfs2 recurse once per nesting level; a huge recursion limit lets deep
# inputs proceed but may still overflow the C stack — confirm against the problem's limits.
sys.setrecursionlimit(10**9)
N, M = list(map(int, input().split()))
P = list(map(int, input().split()))
AB = [list(map(int, input().split())) for _ in range(M)]
# Convert to 0-indexed half-open intervals [a, b).
for i in range(M):
    AB[i][0] -= 1
# Add the interval covering the whole array; the DFS starts from node 0.
AB.append([0,N])
# Sort by start ascending; on ties, the longer (enclosing) interval comes first.
AB.sort(key = lambda x: (x[0], -x[1]))
# Containment tree: g[p] lists the children of p in increasing order of start.
g = [[] for _ in range(M+1)]
par = [-1] * (M+1)
for i in range(1, M+1):
    a, b = AB[i]
    # Climb from the previous interval through parents until i is contained in p.
    p = i - 1
    while True:
        pa, pb = AB[p]
        if b <= pb:
            break
        p = par[p]
    g[p].append(i)
    par[i] = p
# mn[c]: the first element of interval c after its orientation is chosen.
mn = [1<<30] * (M+1)
# chd[c]: pieces of c in output order as [child's min, child id or -1 for a single element].
chd = [[] for _ in range((M+1))]
def dfs(c):
    """Fill chd[c] (recursing into child intervals), reverse it when the last piece starts
    with a smaller value than the first, and record mn[c]."""
    a, b = AB[c]
    cid = 0
    i = a
    while i < b:
        upd = 0
        if cid < len(g[c]):
            d = g[c][cid]
            ca, cb = AB[d]
            # A child interval starting at i becomes one piece.
            if i == ca:
                dfs(d)
                chd[c].append([mn[d], d])
                cid += 1
                i = cb
                upd = 1
        # Otherwise P[i] is a single-element piece.
        if not upd:
            chd[c].append([P[i], -1])
            i += 1
    if chd[c][-1][0] < chd[c][0][0]:
        chd[c].reverse()
    mn[c] = chd[c][0][0]
    return
dfs(0)
ans = []
def dfs2(c):
    """Append the values of interval c to ans in the chosen order, expanding child intervals."""
    for val, id in chd[c]:
        if id == -1:
            ans.append(val)
        else:
            dfs2(id)
dfs2(0)
print(*ans)
かつおぶし(難易度3)階乗の桁数
C++
#include <cmath>
#include <iostream>
#include <utility>
#include <vector>
using namespace std;
#define rep(i, n) for (int i = 0; i < (n); i++)

// Number of decimal digits of N!: floor(log10(N!)) + 1, with log10(N!) = log10(1) + ... + log10(N).
int main() {
    int N;
    cin >> N;
    double lg = 0;
    rep(i, N) lg += log10(i + 1);
    const int digits = static_cast<int>(lg) + 1;
    cout << digits << "\n";
}
Java
import java.util.*;

public class Main {
    // Digits of N! = floor(1 + log10(1) + ... + log10(N)).
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int N = sc.nextInt();
        double ans = 1;
        int i = 1;
        while (i <= N) {
            ans += Math.log10(i);
            i++;
        }
        System.out.println((int) ans);
    }
}
Python
from math import log10

# Number of digits of N!: floor(log10(1) + ... + log10(N)) + 1.
n = int(input())
total = sum(log10(i) for i in range(1, n + 1))
print(int(total) + 1)
かつおぶし(難易度4)桁と数列
C++
#include <algorithm>
#include <iostream>
#include <utility>
#include <vector>
using namespace std;
#define rep(i, n) for (int i = 0; i < (n); i++)

// Builds N integers where the i-th has exactly d[i] decimal digits and the total is S, or prints
// -1 if impossible. Greedy: start every a[i] at 10^(d[i]-1), then raise each one (left to right)
// towards 10^d[i] - 1 until the remaining S is used up.
int main() {
    long long N, S;
    cin >> N >> S;
    vector<long long> d(N);
    for (auto& x : d) cin >> x;
    // mn[i] = 10^(i-1), the smallest i-digit integer, for 1 <= i <= 19; 10^18 is the largest
    // power of ten that fits in long long.
    constexpr int kMaxDigits = 19;
    vector<long long> mn(kMaxDigits + 1);
    mn[1] = 1;
    for (int i = 2; i <= kMaxDigits; i++) mn[i] = mn[i - 1] * 10;
    // Start every value at its minimum.
    vector<long long> a(N);
    rep(i, N) {
        // Fewer than 1 digit is invalid; 20+ digits exceeds any long long S.
        if (d[i] < 1 || d[i] > kMaxDigits) {
            cout << "-1\n";
            return 0;
        }
        a[i] = mn[d[i]];
        S -= a[i];
        if (S < 0) {
            cout << "-1\n";
            return 0;
        }
    }
    // Raise each value as far as possible. The largest d-digit integer is 10 * mn[d] - 1, so the
    // headroom is 9 * mn[d] - 1 (<= 9e18 - 1, fits in long long). Computing it this way avoids
    // reading mn[d + 1], which lies past the table for the largest digit count.
    rep(i, N) {
        long long dif = 9 * mn[d[i]] - 1;
        dif = min(S, dif);
        a[i] += dif;
        S -= dif;
    }
    // If S is still positive with every value at its maximum, it is impossible.
    if (S > 0) {
        cout << "-1\n";
        return 0;
    }
    rep(i, N) cout << a[i] << " \n"[i == N - 1];
}
Java
import java.util.*;

public class Main {
    // Builds N integers where the i-th has exactly d_i digits and the sum is S, or prints -1.
    // Every value starts at 10^(d_i - 1) and is then raised greedily (left to right) by up to
    // 9 * 10^(d_i - 1) - 1, i.e. to at most 10^d_i - 1.
    // long arithmetic is used throughout: with int, S, the running total and lst[i] * 9 overflow
    // (or fail to parse) once values reach 10 digits.
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int N = sc.nextInt();
        long remaining = sc.nextLong();
        long[] lst = new long[N];
        for (int i = 0; i < N; i++) {
            int digits = sc.nextInt();
            // Fewer than 1 digit is invalid; 20+ digits cannot fit in a long sum.
            if (digits < 1 || digits > 19) {
                System.out.println(-1);
                return;
            }
            lst[i] = pow10(digits - 1);
            // Subtract as we go so the running total can never overflow.
            remaining -= lst[i];
            if (remaining < 0) {
                System.out.println(-1);
                return;
            }
        }
        for (int i = 0; i < N; i++) {
            long add = Math.min(remaining, lst[i] * 9 - 1);
            lst[i] += add;
            remaining -= add;
        }
        if (remaining != 0) {
            System.out.println(-1);
            return;
        }
        StringJoiner ans = new StringJoiner(" ");
        for (long x : lst) {
            ans.add(String.valueOf(x));
        }
        System.out.println(ans.toString());
    }

    // 10^k for 0 <= k <= 18, computed exactly with integer arithmetic.
    private static long pow10(int k) {
        long r = 1;
        for (int i = 0; i < k; i++) {
            r *= 10;
        }
        return r;
    }
}
Python
N, S = map(int, input().split())
d = list(map(int, input().split()))
# Start every value at the smallest number with the required digit count.
ans = [10 ** (k - 1) for k in d]
S -= sum(ans)
if S < 0:
    print(-1)
    exit()
# Raise each value (left to right) towards 10^d - 1 while budget remains.
for i in range(N):
    low = ans[i]
    step = min(S, low * 9 - 1)
    ans[i] = low + step
    S -= step
if S == 0:
    print(*ans)
else:
    print(-1)
かつおぶし(難易度6)ペアなすごろく
C++
#include<bits/stdc++.h>
#define mod 998244353
using namespace std;
// Triangular number x(x+1)/2 ("sankaku" = triangle).
long long llsankaku(long long x){return ((1+x)*x)/2;}
// a^b mod `mod` by repeated squaring.
// NOTE(review): a is not reduced first; the only caller passes a value already in [0, mod).
long long power(long long a,long long b){
    long long x=1,y=a;
    while(b>0){
        if(b&1ll){
            x=(x*y)%mod;
        }
        y=(y*y)%mod;
        b>>=1;
    }
    return x%mod;
}
// Modular inverse via Fermat's little theorem (998244353 is prime).
long long modular_inverse(long long n){
    return power(n,mod-2);
}
// Number of squares simulated.
// NOTE(review): acc[i + r[i]] is written without a bounds check, so MAXN must exceed every
// i + r[i]; presumably guaranteed by the problem's limits — confirm.
#define MAXN 524288
int main(){
    long long n;
    cin >> n;
    vector<long long> r(n);
    for(auto &nx : r){cin >> nx;}
    // acc[j]: amount added to delta at the end of step j (ends a decreasing sequence).
    vector<long long> acc(MAXN,0);
    long long res=0;
    // prob (at the start of step i): probability, mod 998244353, of landing on square i. It is a
    // sum of arithmetic sequences; delta is its per-step change. Initially prob = 1 at square 0,
    // and delta = -1 together with acc[0] = 1 removes that 1 from square 1 onwards.
    long long prob=1,delta=(mod-1);
    acc[0]=1;
    for(long long i=0;i<MAXN;i++){
        if(i<n){
            // Square i spreads prob over squares i+1..i+r[i] with weights r[i], r[i]-1, ..., 1,
            // normalised by r[i](r[i]+1)/2: q = prob / (r(r+1)/2).
            long long q=llsankaku(r[i])%mod;
            q=modular_inverse(q);
            q*=prob;q%=mod;
            // Adding q*(r+1) to prob and -q to delta yields q*r at i+1, decreasing by q per step;
            // acc[i+r[i]] += q stops the decrease after square i+r[i].
            delta+=(mod-q);delta%=mod;
            acc[i+r[i]]+=q;acc[i+r[i]]%=mod;
            prob+=q*(r[i]+1);prob%=mod;
        }
        // Squares at or beyond n move no further: res accumulates prob * (i+1)
        // (presumably the expected 1-indexed final square — confirm with the statement).
        else{res+=prob*(i+1);res%=mod;}
        prob+=delta;prob%=mod;
        delta+=acc[i];delta%=mod;
    }
    cout << res << '\n';
}
Java
import java.util.*;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int N = sc.nextInt();
        int[] R = new int[N];
        for (int i = 0; i < N; i++) {
            R[i] = sc.nextInt();
        }
        long mod = 998244353;
        // Number of squares simulated.
        // NOTE(review): acc[i + R[i]] must stay below max; presumably guaranteed by the limits — confirm.
        int max = 400001;
        // acc[j]: amount added to delta at the end of step j (ends a decreasing sequence).
        long[] acc = new long[max];
        long ans = 0;
        // prob (at the start of step i): probability, mod 998244353, of landing on square i;
        // delta is its per-step change. delta = -1 with acc[0] = 1 removes the initial 1 after square 0.
        long prob = 1;
        long delta = mod - 1;
        acc[0] = 1;
        for (int i = 0; i < max; i++) {
            if (i < N) {
                // q = prob / (R(R+1)/2): square i spreads prob over i+1..i+R[i] with weights R[i], ..., 1.
                long q = (long)R[i] * (R[i]+1) / 2 % mod;
                q = modPow(q, mod-2, mod) * prob % mod;
                // +q*(R+1) on prob and -q on delta give q*R at i+1, decreasing by q per step;
                // acc[i+R[i]] += q stops the decrease after square i+R[i].
                delta = (delta + mod - q) % mod;
                acc[i+R[i]] = (acc[i+R[i]] + q) % mod;
                prob = (prob + q * (R[i]+1)) % mod;
            } else {
                // Squares at or beyond N move no further: accumulate prob * (i+1).
                ans = (ans + prob * (i+1)) % mod;
            }
            prob = (prob + delta) % mod;
            delta = (delta + acc[i]) % mod;
        }
        System.out.println(ans);
    }
    // x^n mod `mod` (Fermat inverse when n = mod - 2 and mod is prime).
    static long modPow(long x, long n, long mod) {
        long ret = 1;
        // Repeated squaring
        while (0 < n) {
            if ((n&1) == 1) {
                ret = ret * x % mod;
            }
            x = x * x % mod;
            n >>= 1;
        }
        return ret;
    }
}
Python
N = int(input())
R = list(map(int, input().split()))
mod = 998244353
# acc[j]: amount added to delta at the end of step j (ends a decreasing sequence).
# NOTE(review): acc[i + R[i]] must stay within 4*10**5; presumably guaranteed by the limits — confirm.
acc = [0]*(4*10**5+1)
res = 0
# prob (at the start of step i): probability, mod 998244353, of landing on square i;
# delta is its per-step change. delta = -1 with acc[0] = 1 removes the initial 1 after square 0.
prob = 1
delta = mod - 1
acc[0] = 1
for i in range(4*10**5+1):
    if i < N:
        # q = prob / (R(R+1)/2): square i spreads prob over i+1..i+R[i] with weights R[i], ..., 1.
        q = R[i] * (R[i]+1) // 2 % mod
        q = pow(q, mod-2, mod) * prob % mod
        # +q*(R+1) on prob and -q on delta give q*R at i+1, decreasing by q per step;
        # acc[i+R[i]] += q stops the decrease after square i+R[i].
        delta = (delta+mod-q) % mod
        acc[i+R[i]] = (acc[i+R[i]]+q) % mod
        prob = (prob+q*(R[i]+1)) % mod
    else:
        # Squares at or beyond N move no further: accumulate prob * (i+1).
        res = (res+prob*(i+1)) % mod
    prob = (prob+delta) % mod
    delta = (delta+acc[i]) % mod
print(res)
かつおぶし(難易度6)コイン投げ
C++
#include <iostream>
#include <string>
#include <vector>
using namespace std;
constexpr int MOD = 998244353;
// Computes, modulo 998244353, E[0] where E[i] = 1 + (E[i+1] + E[fail(i)]) / 2 and E[|S|] = 0,
// i.e. (assuming a fair coin over a two-letter alphabet) the expected number of flips until S
// first appears.
int main() {
    string S;
    cin >> S;
    // A: KMP failure table in its optimized form (when S[i+1] equals the character after the
    // border, A[i+1] jumps on to A[j]); -1 means "no border".
    // dp[i] = E[i] - E[0] (mod MOD), where E[i] is the expected remaining number of flips once
    // the first i characters of S are matched; dp[0] = 0.
    vector<int> A(S.size() + 1, -1), dp(S.size() + 1);
    for (int i = 0, j = -1, k; i < (int)S.size(); i++) {
        // Standard KMP step; ~j is nonzero exactly when j != -1.
        while (~j && S[i] != S[j]) j = A[j];
        // At the last i, S[i + 1] reads S[S.size()], which std::string guarantees to be '\0'.
        A[i + 1] = S[i + 1] == S[++j] ? A[j] : j;
        // Mismatch transition from state i: follow borders while the next character equals S[i];
        // with two letters the first differing one is the other letter, so the flip leads to
        // state k + 1 (state 0 when k ends at -1).
        for (k = A[i]; ~k && S[i] == S[k];) k = A[k];
        // From E[i] = 1 + (E[i+1] + E[k+1]) / 2: E[i+1] = 2E[i] - E[k+1] - 2.
        // dp values lie in [0, MOD), so dp[i] * 2 still fits in int.
        dp[i + 1] = ((dp[i] * 2 - dp[++k] - 2) % MOD + MOD) % MOD;
    }
    // E[|S|] = 0, so E[0] = -dp[|S|].
    cout << (-dp.back() % MOD + MOD) % MOD << "\n";
}
Java
import java.util.*;

// Computes, modulo 998244353, E[0] where E[i] = 1 + (E[i+1] + E[fail(i)]) / 2 and E[|S|] = 0,
// i.e. (assuming a fair coin over a two-letter alphabet) the expected number of flips until S
// first appears.
public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        // "_" is a sentinel so S[i+1] is valid at the last position (assumes '_' never occurs in S).
        char[] S = (sc.next() + "_").toCharArray();
        int len = S.length;
        int mod = 998244353;
        // A: optimized KMP failure table; -1 means "no border".
        int[] A = new int[len];
        Arrays.fill(A, -1);
        // dp[i] = E[i] - E[0] (mod mod); dp[0] = 0.
        int[] dp = new int[len];
        for (int i = 0, j = -1; i < len-1; i++) {
            // Standard KMP step; ~j != 0 exactly when j != -1.
            while ((~j != 0) && (S[i] != S[j])) {
                j = A[j];
            }
            A[i+1] = (S[i+1] == S[++j] ? A[j] : j);
            // Mismatch transition from state i: follow borders while the next character equals
            // S[i]; the flip then leads to state k + 1.
            int k = A[i];
            for (k = A[i]; (~k != 0) && (S[i] == S[k]);) {
                k = A[k];
            }
            // E[i+1] = 2E[i] - E[k+1] - 2; dp values lie in [0, mod), so dp[i] * 2 fits in int.
            dp[i+1] = ((dp[i] * 2 - dp[k+1] - 2) % mod + mod) % mod;
        }
        // E[|S|] = 0, so E[0] = -dp[|S|].
        System.out.println((-dp[len-1] % mod + mod) % mod);
    }
}
Python
# Computes, modulo 998244353, E[0] where E[i] = 1 + (E[i+1] + E[fail(i)]) / 2 and E[|S|] = 0,
# i.e. (assuming a fair coin over a two-letter alphabet) the expected number of flips until S
# first appears.
S = input()
siz = len(S)
# " " is a sentinel so S[i+1] is valid at the last position (assumes " " never occurs in S).
S = S + " "
mod = 998244353
# A: optimized KMP failure table; -1 means "no border".
A = [-1] * (siz + 1)
# dp[i] = E[i] - E[0] (mod p); dp[0] = 0.
dp = [0] * (siz + 1)
j = -1
for i in range(siz):
    # Standard KMP step; ~j is truthy exactly when j != -1.
    while ~j and (S[i] != S[j]):
        j = A[j]
    j += 1
    A[i+1] = (A[j] if S[i+1] == S[j] else j)
    # Mismatch transition from state i: follow borders while the next character equals S[i];
    # the flip then leads to state k + 1.
    k = A[i]
    while ~k and (S[i] == S[k]):
        k = A[k]
    # E[i+1] = 2E[i] - E[k+1] - 2
    dp[i+1] = ((dp[i] * 2 - dp[k+1] - 2) % mod + mod) % mod
# E[|S|] = 0, so E[0] = -dp[|S|].
print((-dp[-1] % mod + mod) % mod)