PG BATTLE 2021 解答例一覧
各問題につきC++、Java、Pythonの解答例を用意しております。
AtCoder社による解答はC++のもので、JavaとPythonの解答例はPG BATTLE運営が用意したものとなりますので、予めご了承ください。
ましゅまろ(難易度1)物理現象グラフィックバトル
C++
#include<bits/stdc++.h>
using namespace std;

// Reads three reals x, y, k and prints y - k / x with 12 fractional digits.
int main() {
    double x_val = 0.0;
    double y_val = 0.0;
    double k_val = 0.0;
    cin >> x_val >> y_val >> k_val;
    const double answer = y_val - k_val / x_val;
    printf("%.12lf\n", answer);
    return 0;
}
Java
import java.util.*;

/** Reads X, Y, K and prints Y - K / X. */
public class Main {
    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);
        double x = in.nextDouble();
        double y = in.nextDouble();
        double k = in.nextDouble();
        double result = y - k / x;
        System.out.println(result);
    }
}
Python
# Reads integers X, Y, K and prints Y - K / X (true division, so a float).
# The two statements were fused onto one line, which is a SyntaxError.
X, Y, K = map(int, input().split())
print(Y - K / X)
ましゅまろ(難易度3)ゼロのない整数
C++
#include<bits/stdc++.h>
using namespace std;

// Reads a digit string X. The first '0' and every digit after it are replaced
// by '1'; the digits before it are kept. Prints the resulting string.
int main() {
    string digits;
    cin >> digits;
    const size_t first_zero = digits.find('0');
    if (first_zero != string::npos) {
        fill(digits.begin() + first_zero, digits.end(), '1');
    }
    cout << digits << '\n';
}
Java
import java.util.*;

/** Replaces the first '0' of the input digit string and everything after it with '1'. */
public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        String x = sc.next();
        StringBuilder out = new StringBuilder(x);
        int firstZero = x.indexOf('0');
        if (firstZero >= 0) {
            for (int pos = firstZero; pos < out.length(); pos++) {
                out.setCharAt(pos, '1');
            }
        }
        System.out.println(out.toString());
    }
}
Python
# Replace the first '0' and every digit after it with '1'; keep the prefix.
X = input()
head, zero, _ = X.partition("0")
print(head + "1" * (len(X) - len(head)) if zero else X)
ましゅまろ(難易度3)一点封鎖
C++
#include <atcoder/modint>
#include <bits/stdc++.h>
#define mod 998244353
using namespace std;
using namespace atcoder;

// Use ACL's fixed-modulus type. Plain `modint` is the dynamic-modulus type,
// whose modulus equals 998244353 only because that is its default value.
// Naming the modulus in the type keeps it visibly in sync with `mod` (used by
// the DP below) and uses the faster compile-time-modulus arithmetic.
using mint = modint998244353;

// Binomial coefficient C(n, r) via the multiplicative formula
// prod_{i<r} (n - i) / (i + 1). Returns 0 when r > n.
mint nCr(int n, int r) {
    if (n < r) { return 0; }
    mint res = 1;
    for (int i = 0; i < r; i++) {
        res *= (n - i);
        res /= (i + 1);
    }
    return res;
}

// Reads n, m, a, b and prints the number of monotone lattice paths from (0,0)
// to (n,m) that avoid the point (a,b), modulo 998244353. The count is computed
// in two independent ways, which are cross-checked with assert.
int main() {
    int n, m, a, b;
    cin >> n >> m >> a >> b;
    int ans1, ans2;
    // Solution 1: grid DP. dp[i][j] = number of paths reaching (i,j), with the
    // blocked point forced to 0. Both summands are below mod, so their sum
    // (< 2 * 998244353 < 2^31) fits in int before the reduction.
    vector<vector<int>> dp(n + 1, vector<int>(m + 1, 0));
    dp[0][0] = 1;
    for (int i = 0; i <= n; i++) {
        for (int j = 0; j <= m; j++) {
            if (i == a && j == b) { dp[i][j] = 0; continue; }
            if (i != 0) { dp[i][j] += dp[i - 1][j]; }
            if (j != 0) { dp[i][j] += dp[i][j - 1]; }
            dp[i][j] %= mod;
        }
    }
    ans1 = dp[n][m];
    // Solution 2: all C(n+m, n) paths, minus the paths through (a,b) when that
    // point lies inside the grid: C(a+b, a) * C((n-a)+(m-b), n-a).
    if (a <= n && b <= m) {
        mint res = nCr(n + m, n), sub = nCr(a + b, a);
        sub *= nCr((n - a) + (m - b), (n - a));
        res -= sub;
        ans2 = res.val();
    } else {
        ans2 = nCr(n + m, n).val();
    }
    assert(ans1 == ans2);
    cout << ans1 << '\n';
}
Java
import java.util.*;

/**
 * Prints, modulo 998244353, C(N+M, N) minus C(A+B, A) * C(N+M-A-B, N-A), where
 * the subtraction only happens when A <= N and B <= M. This is the count of
 * monotone lattice paths to (N, M) that avoid (A, B).
 * Factorials are tabulated up to 4000, so N + M <= 4000 is assumed
 * (NOTE(review): confirm against the constraints).
 */
public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int N = sc.nextInt();
        int M = sc.nextInt();
        int A = sc.nextInt();
        int B = sc.nextInt();
        long mod = 998244353;
        long[] fact = new long[4001];
        fact[0] = 1;
        for (int i = 1; i < fact.length; i++) {
            fact[i] = fact[i - 1] * i % mod;
        }
        long total = binom(N + M, N, fact, mod);
        if (A <= N && B <= M) {
            long through = binom(A + B, A, fact, mod) * binom(N + M - A - B, N - A, fact, mod) % mod;
            total = (total - through + mod) % mod;
        }
        System.out.println(total);
    }

    /** n! / (r! (n-r)!) mod m, using Fermat inverses (m must be prime). */
    static long binom(int n, int r, long[] fact, long m) {
        long invR = modPow(fact[r], m - 2, m);
        long invRest = modPow(fact[n - r], m - 2, m);
        return fact[n] * invR % m * invRest % m;
    }

    /** base^exp mod m by binary exponentiation (exp >= 0, 0 <= base < m). */
    static long modPow(long base, long exp, long m) {
        long result = 1;
        for (long b = base, e = exp; e > 0; e >>= 1, b = b * b % m) {
            if ((e & 1) == 1) {
                result = result * b % m;
            }
        }
        return result;
    }
}
Python
# Counts monotone lattice paths from (0,0) to (N,M) avoiding (A,B), modulo
# 998244353: all C(N+M, N) paths minus the ones through (A,B) (only when that
# point lies inside the grid). Indentation restored: the loop and `if` bodies
# were flush-left, which is an IndentationError.
N, M, A, B = map(int, input().split())
mod = 998244353
factorials = [1] * (N+M+1)
for i in range(1, N+M+1):
    factorials[i] = factorials[i-1] * i % mod
# Fermat's little theorem: x^(mod-2) is the inverse of x modulo the prime mod.
ans = factorials[N+M] * pow(factorials[N], mod-2, mod) * pow(factorials[M], mod-2, mod) % mod
if A <= N and B <= M:
    # Paths through (A,B): C(A+B, A) * C(N+M-A-B, N-A).
    ab1 = factorials[A+B] * pow(factorials[A], mod-2, mod) * pow(factorials[B], mod-2, mod)
    ab2 = factorials[N+M-A-B] * pow(factorials[N-A], mod-2, mod) * pow(factorials[M-B], mod-2, mod)
    # Python's % is never negative, so no offset is needed.
    ans = (ans - ab1*ab2) % mod
print(ans)
ましゅまろ(難易度5)無双関数
C++
#include<bits/stdc++.h>
#define mod 998244353
using namespace std;

// Reads n, the values a[0..n-1] and the weights s[0..n-1]. For each right end,
// the longest window ending there with pairwise-distinct a-values is tracked
// with two pointers. If that window has length len, the prefix sum
// s[0] + ... + s[len-1] (mod 998244353) is added to the answer.
int main() {
    int n;
    cin >> n;
    vector<int> a(n), s(n);
    for (int i = 0; i < n; ++i) cin >> a[i];
    for (int i = 0; i < n; ++i) cin >> s[i];
    // In-place prefix sums, reduced modulo mod at every step.
    partial_sum(s.begin(), s.end(), s.begin(),
                [](int acc, int v) { return (acc + v) % mod; });
    set<int> window;
    int answer = 0;
    int left = 0;
    for (int right = 0; right < n; ++right) {
        while (window.count(a[right])) {
            window.erase(a[left++]);
        }
        window.insert(a[right]);
        answer = (answer + s[right - left]) % mod;
    }
    cout << answer << '\n';
    return 0;
}
Java
import java.util.*;

/**
 * For every right end r, keeps the longest window [left, r] whose A-values are
 * pairwise distinct. Adds the prefix sum S[0] + ... + S[r - left] to the answer
 * (mod 998244353).
 */
public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int N = sc.nextInt();
        int[] A = new int[N];
        for (int i = 0; i < N; i++) A[i] = sc.nextInt();
        int[] S = new int[N];
        for (int i = 0; i < N; i++) S[i] = sc.nextInt();
        final int mod = 998244353;
        long[] prefix = new long[N];
        prefix[0] = S[0];
        for (int i = 1; i < N; i++) {
            prefix[i] = (prefix[i - 1] + S[i]) % mod;
        }
        // value -> (index of its latest occurrence) + 1, the smallest left end excluding it
        Map<Integer, Integer> nextAfterLast = new HashMap<>();
        long total = 0;
        int left = 0;
        for (int right = 0; right < N; right++) {
            Integer prev = nextAfterLast.put(A[right], right + 1);
            if (prev != null && prev > left) {
                left = prev;
            }
            total = (total + prefix[right - left]) % mod;
        }
        System.out.println(total);
    }
}
Python
# For every right end r, maintain the longest window [l, r] whose A-values are
# pairwise distinct, and add the prefix sum S[0] + ... + S[r - l] to the answer
# (mod 998244353). Indentation restored: the loop body was flush-left.
from itertools import accumulate
N = int(input())
A = list(map(int, input().split()))
S = list(map(int, input().split()))
mod = 998244353
# indexes[v] = (index of the latest occurrence of v) + 1, i.e. the smallest
# left end that excludes that occurrence.
indexes = {}
prefix_sums = list(accumulate(S))
ans = 0
l = 0
for r in range(N):
    if A[r] in indexes:
        l = max(l, indexes[A[r]])
    indexes[A[r]] = r + 1
    ans += prefix_sums[r-l]
    ans %= mod
print(ans)
せんべい(難易度2)7番勝負
C++
#include<bits/stdc++.h>
using namespace std;

// Reads the single-game win percentage p. Prints, as a percentage, the
// probability of winning at least 4 of 7 independent games. All 2^7 win/loss
// patterns are enumerated and those with 4 or more wins are summed.
int main() {
    int p;
    cin >> p;
    const double win = p / 100.0;
    const double lose = (100 - p) / 100.0;
    double total = 0;
    for (int mask = 0; mask < 128; mask++) {
        const int wins = __builtin_popcount(mask);
        if (wins < 4) continue;
        total += pow(win, wins) * pow(lose, 7 - wins);
    }
    printf("%.12lf\n", 100.0 * total);
    return 0;
}
Java
import java.util.*;

/**
 * Prints 100 * sum_{i=4..7} C(7, i) p^i (1-p)^(7-i), the probability in percent
 * of winning at least 4 of 7 games, with p = input / 100.
 */
public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        double win = sc.nextDouble() / 100;
        double total = 0;
        for (int wins = 4; wins <= 7; wins++) {
            total += term(win, wins);
        }
        System.out.println(total * 100);
    }

    /** C(7, wins) * win^wins * (1 - win)^(7 - wins), built up factor by factor. */
    private static double term(double win, int wins) {
        double t = 1;
        int j = 0;
        while (j < wins) {
            t *= 7 - j;
            t /= j + 1;
            t *= win;
            j++;
        }
        for (int losses = 7 - wins; losses > 0; losses--) {
            t *= 1 - win;
        }
        return t;
    }
}
Python
# Probability, in percent, of winning at least 4 of 7 games when each game is
# won with probability P/100:  sum_{i=4..7} C(7, i) * p^i * (1 - p)^(7 - i).
# Indentation restored (loop bodies were flush-left) and stray ';' removed.
P = int(input()) / 100
ans = 0
for i in range(4, 8):
    # Build C(7, i) * p^i incrementally, then multiply in the 7 - i losses.
    tmp = 1
    for j in range(i):
        tmp *= 7 - j
        tmp /= j + 1
        tmp *= P
    for j in range(7-i):
        tmp *= 1 - P
    ans += tmp
print(ans * 100)
せんべい(難易度3)連結成分数の見積もり
C++
#include<bits/stdc++.h>
using namespace std;

// For each of t queries (n vertices, m edges), prints the minimum and maximum
// possible number of connected components.
int main() {
    int t;
    cin >> t;
    for (; t > 0; --t) {
        long long n, m;
        cin >> n >> m;
        // Minimum: each edge merges at most two components; at least one remains.
        cout << max(1ll, n - m) << ' ';
        // Maximum: place all m edges among the fewest vertices, i.e. the
        // smallest k with k*(k-1)/2 >= m; the other n - k vertices are isolated.
        long long lo = 1, hi = 1500000000;
        while (lo <= hi) {
            const long long mid = lo + (hi - lo) / 2;
            if (mid * (mid - 1) / 2 < m) {
                lo = mid + 1;
            } else {
                hi = mid - 1;
            }
        }
        cout << n + 1 - lo << '\n';
    }
}
Java
import java.util.*;

/** For each query (N vertices, M edges) prints the min and max number of connected components. */
public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int T = sc.nextInt();
        while (T-- > 0) {
            int N = sc.nextInt();
            long M = sc.nextLong();
            long fewest = Math.max(1, N - M);
            long most = N - minVerticesFor(M) + 1;
            System.out.println(fewest + " " + most);
        }
    }

    /** Smallest k in [1, 1000000001] with k*(k-1)/2 >= m, found by binary search. */
    private static long minVerticesFor(long m) {
        long lo = 1;
        long hi = 1000000001;
        while (lo <= hi) {
            long mid = (lo + hi) / 2;
            if (mid * (mid - 1) / 2 < m) {
                lo = mid + 1;
            } else {
                hi = mid - 1;
            }
        }
        return lo;
    }
}
Python
# For each query (N vertices, M edges) print the minimum and maximum number of
# connected components. Indentation restored: the loop bodies were flush-left.
T = int(input())
for _ in range(T):
    N, M = map(int, input().split())
    # Fewest components: every edge merges at most two components.
    mi = max(1, N-M)
    # Most components: pack all M edges into the smallest vertex set able to
    # hold them, i.e. the smallest l with l*(l-1)/2 >= M (binary search).
    l = 1
    r = 1000000001
    while l <= r:
        mid = (l + r) // 2
        if mid * (mid - 1) // 2 < M:
            l = mid + 1
        else:
            r = mid - 1
    ma = N - l + 1
    print(mi, ma)
せんべい(難易度4)トーナメント表
C++
#include <algorithm>
#include <iostream>
#include <vector>
using namespace std;

// Reads N and the values of the 2^N players. Only each value's number of
// trailing zero bits is used (values are presumably powers of two <= 2^N —
// NOTE(review): confirm). Slot i (1-based) of the bracket takes a player from
// bucket N - ctz(i). Prints the arrangement, or -1 if a bucket runs out.
int main() {
    int N;
    cin >> N;
    const int M = 1 << N;
    vector<vector<int>> bucket(N + 1);
    for (int id = 1; id <= M; id++) {
        int x;
        cin >> x;
        bucket[__builtin_ctz(x)].push_back(id);
    }
    vector<int> order;
    order.reserve(M);
    for (int slot = 1; slot <= M; slot++) {
        vector<int>& pool = bucket[N - __builtin_ctz(slot)];
        if (pool.empty()) {
            cout << -1 << "\n";
            return 0;
        }
        order.push_back(pool.back());
        pool.pop_back();
    }
    for (int i = 0; i < M; i++) {
        cout << order[i] << (i + 1 == M ? '\n' : ' ');
    }
}
Java
import java.util.*;

/**
 * Players are sorted by value, largest first; the sort is stable, so ties keep
 * input order. Level i of the bracket covers slots 2^i - 1, stepping by 2^(i+1),
 * and is filled in that order with players whose value must be exactly 2^(N-i).
 * Prints the player number per slot, or -1 on the first mismatch.
 */
public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int N = sc.nextInt();
        int size = 1 << N;
        int[][] players = new int[size][];
        for (int i = 0; i < size; i++) {
            int value = sc.nextInt();
            players[i] = new int[] {value, i + 1};
        }
        Arrays.sort(players, (p, q) -> Integer.compare(q[0], p[0]));
        int[] tree = new int[size];
        int next = 0;
        for (int level = 0; level <= N; level++) {
            int expected = 1 << (N - level);
            int step = 1 << (level + 1);
            for (int slot = (1 << level) - 1; slot < size; slot += step) {
                int[] player = players[next++];
                if (player[0] != expected) {
                    System.out.println(-1);
                    return;
                }
                tree[slot] = player[1];
            }
        }
        StringBuilder out = new StringBuilder();
        for (int i = 0; i < size; i++) {
            if (i > 0) {
                out.append(' ');
            }
            out.append(tree[i]);
        }
        System.out.println(out);
    }
}
Python
# Indentation restored: the nested loop / if bodies were flush-left.
N = int(input())
A = list(map(int, input().split()))
# Players sorted by value, largest first (ties: larger player number first).
lst = [[A[i], i+1] for i in range(2**N)]
lst.sort(reverse=True)
tree = [0] * 2**N
k = 0
# Level i occupies slots 2^i - 1, 2^i - 1 + 2^(i+1), ...; every player placed
# there must have value exactly 2^(N-i), otherwise the answer is -1.
for i in range(N+1):
    for j in range(2**i-1, 2**N, 2**(i+1)):
        if 2**(N-i) != lst[k][0]:
            print(-1)
            exit()
        tree[j] = lst[k][1]
        k += 1
print(*tree)
せんべい(難易度6)[リ[[スー]バ][ズパ]ル]
C++
#include <algorithm>
#include <cassert>
#include <iostream>
#include <utility>
#include <vector>
using namespace std;
#define rep(i, n) for (int i = 0; i < (n); i++)
// Reads N, M, a sequence P of length N and M intervals [A, B] (1-indexed,
// inclusive). The intervals, plus the whole range, are arranged into a
// containment tree. For each interval, its direct items (child intervals and
// single elements) are either kept in order or reversed, whichever puts the
// smaller leading value first. The resulting sequence is printed.
// NOTE(review): presumably the task asks for the lexicographically smallest
// sequence obtainable by reversing the given intervals, and guarantees that
// the intervals are pairwise nested or disjoint (the parent search relies on
// this) — confirm against the problem statement.
int main() {
int N, M;
cin >> N >> M;
vector<int> P(N);
for (auto& x : P) cin >> x;
vector<pair<int, int>> AB(M);
for (auto& [a, b] : AB) cin >> a >> b;
// For simplicity, convert to 0-indexed half-open intervals [a, b).
for (auto& [a, _] : AB) --a;
// Add the whole range [0, N) as the root interval.
AB.emplace_back(0, N);
// Sort by left end ascending, ties by right end descending, so every interval
// comes after the intervals that contain it.
sort(begin(AB), end(AB), [](auto l, auto r) {
return l.first == r.first ? l.second > r.second : l.first < r.first;
});
assert(AB[0] == make_pair(0, N));
// Build the containment tree: g[p] lists the children of p (in increasing
// left-end order), par[i] is the parent of i.
vector<vector<int>> g(M + 1);
vector<int> par(M + 1, -1);
for (int i = 1; i <= M; i++) {
auto [a, b] = AB[i];
// Which node is the parent?
int p = i - 1;
// Climb from the previous interval until reaching one that fully contains i.
// The number of climbing steps is amortized O(1) (cf. Euler tour).
while (true) {
auto [pa, pb] = AB[p];
// Order is pa a b pb (b == pb allowed) -> i is contained in p.
if (b <= pb) break;
// Otherwise the order is pa pb a b, so move up to the parent.
p = par[p];
}
// Make i a child of p.
g[p].push_back(i);
par[i] = p;
}
// mn[c]: leading (first) value of interval c once its orientation is chosen;
// 1 << 30 is only the initial sentinel.
vector<int> mn(M + 1, 1 << 30);
// chd[c]: direct items of interval c in output order (reversed if chosen),
// as (child's leading value, child id), where id is -1 for a single element
// whose value is stored in .first.
vector<vector<pair<int, int>>> chd(M + 1);
// DFS that decides the order of every interval bottom-up.
auto dfs = [&](auto rc, int c) -> void {
auto [a, b] = AB[c];
int cid = 0;
for (int i = a; i < b;) {
int upd = 0;
if (cid < (int)g[c].size()) {
int d = g[c][cid];
auto [ca, cb] = AB[d];
if (i == ca) {
// Use child d here; compute its leading value recursively.
rc(rc, d);
// Push it into chd[c] and skip past the child interval.
chd[c].emplace_back(mn[d], d);
cid++;
i = cb;
upd = 1;
}
}
if (!upd) {
chd[c].emplace_back(P[i], -1);
i++;
}
}
// Reverse when that yields a smaller leading value.
if (chd[c].back().first < chd[c][0].first) {
reverse(begin(chd[c]), end(chd[c]));
}
mn[c] = chd[c][0].first;
return;
};
dfs(dfs, 0);
vector<int> ans;
// DFS that emits the answer by expanding the items in their chosen order.
auto dfs2 = [&](auto rc, int c) -> void {
for (auto& [val, id] : chd[c]) {
if (id == -1) {
ans.push_back(val);
} else {
rc(rc, id);
}
}
};
dfs2(dfs2, 0);
assert((int)ans.size() == N);
for (int i = 0; i < N; i++) cout << ans[i] << " \n"[i == N - 1];
}
Java
import java.util.*;
/**
 * Arranges P by orienting the given intervals (Java port of the reference solution).
 *
 * Reads N, M, the sequence P (length N) and M intervals [A, B] (1-indexed,
 * inclusive). The intervals, plus the whole range, are arranged into a
 * containment tree. Each interval's direct items (child intervals and single
 * elements) are kept in order or reversed, whichever puts the smaller leading
 * value first. The resulting sequence is printed.
 * NOTE(review): presumably the task asks for the lexicographically smallest
 * sequence reachable by reversing intervals, and guarantees the intervals are
 * pairwise nested or disjoint (the parent search relies on this) — confirm
 * against the problem statement.
 */
public class Main {
static int N;
static int M;
static int[] P;
// Intervals as 0-indexed half-open [AB[i][0], AB[i][1]); after sorting, index 0 is the whole range.
static int[][] AB;
// g.get(c): child intervals of c, in increasing left-end order.
static List<List<Integer>> g;
// chd.get(c): direct items of c in output order, as {leading value, child id or -1 for a single element}.
static List<List<Integer[]>> chd;
// Output tokens collected by dfs2.
static List<String> ans;
// mn[c]: leading (first) value of interval c once its orientation is chosen.
static int[] mn;
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
N = Integer.parseInt(sc.next());
M = Integer.parseInt(sc.next());
P = new int[N];
for (int i = 0; i < N; i++) {
P[i] = Integer.parseInt(sc.next());
}
AB = new int[M+1][2];
for (int i = 0; i < M; i++) {
// For simplicity, convert to 0-indexed half-open intervals.
AB[i][0] = Integer.parseInt(sc.next()) - 1;
AB[i][1] = Integer.parseInt(sc.next());
}
// Add the whole range as the root interval.
AB[M] = new int[] {0, N};
// Sort by left end ascending, ties by right end descending (containers first).
Arrays.sort(AB, (x, y) -> x[0] == y[0] ? y[1] - x[1] : x[0] - y[0]);
// Build the containment tree.
g = new ArrayList<List<Integer>>();
int[] par = new int[M+1];
// Leading value of each interval after orientation (1 << 30 is the initial sentinel).
mn = new int[M+1];
int init = 1 << 30;
// Per-interval item lists deciding the order (and whether it is reversed):
// {child's leading value, child id (-1 for a single element)}.
chd = new ArrayList<List<Integer[]>>();
for (int i = 0; i < M+1; i++) {
g.add(new ArrayList<Integer>());
par[i] = -1;
mn[i] = init;
chd.add(new ArrayList<Integer[]>());
}
for (int i = 1; i <= M; i++) {
int a = AB[i][0];
int b = AB[i][1];
// Which node is the parent?
int p = i - 1;
// Climb from the previous interval until reaching one that fully contains i.
// The number of climbing steps is amortized O(1) (cf. Euler tour).
while (true) {
int pa = AB[p][0];
int pb = AB[p][1];
// Order is pa a b pb (b == pb allowed) -> contained.
if (b <= pb) {
break;
}
// Otherwise the order is pa pb a b, so move up to the parent.
p = par[p];
}
// Make i a child of p.
g.get(p).add(i);
par[i] = p;
}
// DFS that decides the order of every interval.
dfs(0);
ans = new ArrayList<String>();
// DFS that constructs the answer.
dfs2(0);
System.out.println(String.join(" ", ans));
}
/** Fills chd.get(c) with c's direct items, reverses it if that gives a smaller leading value, and sets mn[c]. */
static void dfs(int c) {
int a = AB[c][0];
int b = AB[c][1];
int cid = 0;
for (int i = a; i < b;) {
boolean upd = false;
if (cid < g.get(c).size()) {
int d = g.get(c).get(cid);
int ca = AB[d][0];
int cb = AB[d][1];
if (i == ca) {
// Use child d here; compute its leading value recursively.
dfs(d);
// Add it to chd[c] and skip past the child interval.
chd.get(c).add(new Integer[] {mn[d], d});
cid++;
i = cb;
upd = true;
}
}
if (!upd) {
chd.get(c).add(new Integer[] {P[i], -1});
i++;
}
}
if (chd.get(c).get(chd.get(c).size()-1)[0] < chd.get(c).get(0)[0]) {
Collections.reverse(chd.get(c));
}
mn[c] = chd.get(c).get(0)[0];
return;
}
/** Appends the elements of interval c to ans in their chosen order. */
static void dfs2(int c) {
for (Integer[] ele : chd.get(c)) {
int val = ele[0];
int id = ele[1];
if (id == -1) {
ans.add(String.valueOf(val));
} else {
dfs2(id);
}
}
};
}
Python
# Arranges P by orienting the given intervals. Indentation restored: every loop,
# `if` and `def` body was flush-left, which is an IndentationError. The logic is
# unchanged.
# NOTE(review): presumably the task asks for the lexicographically smallest
# sequence reachable by reversing intervals, with the intervals pairwise nested
# or disjoint (the parent search relies on this) — confirm against the statement.
import sys
# NOTE(review): dfs/dfs2 recurse once per nesting level. This limit only lifts
# Python's own check, so very deep nesting may still exhaust the C stack.
sys.setrecursionlimit(10**9)
N, M = list(map(int, input().split()))
P = list(map(int, input().split()))
AB = [list(map(int, input().split())) for _ in range(M)]
# Convert to 0-indexed half-open intervals [a, b).
for i in range(M):
    AB[i][0] -= 1
# Add the whole range as the root interval.
AB.append([0,N])
# Left end ascending, right end descending: containers precede their contents.
AB.sort(key = lambda x: (x[0], -x[1]))
# Containment tree: g[p] = children of p (increasing left end), par[i] = parent.
g = [[] for _ in range(M+1)]
par = [-1] * (M+1)
for i in range(1, M+1):
    a, b = AB[i]
    # Climb from the previous interval until one fully contains [a, b).
    p = i - 1
    while True:
        pa, pb = AB[p]
        if b <= pb:
            break
        p = par[p]
    g[p].append(i)
    par[i] = p
# mn[c]: leading value of interval c once its orientation is chosen.
mn = [1<<30] * (M+1)
# chd[c]: direct items of c in output order, as [leading value, child id or -1].
chd = [[] for _ in range((M+1))]
def dfs(c):
    """Collect c's direct items, reverse them if that gives a smaller leading value, set mn[c]."""
    a, b = AB[c]
    cid = 0
    i = a
    while i < b:
        upd = 0
        if cid < len(g[c]):
            d = g[c][cid]
            ca, cb = AB[d]
            if i == ca:
                # Child interval d starts here: orient it recursively, then skip it.
                dfs(d)
                chd[c].append([mn[d], d])
                cid += 1
                i = cb
                upd = 1
        if not upd:
            chd[c].append([P[i], -1])
            i += 1
    if chd[c][-1][0] < chd[c][0][0]:
        chd[c].reverse()
    mn[c] = chd[c][0][0]
    return
dfs(0)
ans = []
def dfs2(c):
    """Append the elements of interval c to ans in their chosen order."""
    for val, id in chd[c]:
        if id == -1:
            ans.append(val)
        else:
            dfs2(id)
dfs2(0)
print(*ans)
かつおぶし(難易度3)階乗の桁数
C++
#include <cmath>
#include <iostream>
#include <utility>
#include <vector>
using namespace std;
#define rep(i, n) for (int i = 0; i < (n); i++)

// Prints the number of decimal digits of N!, i.e. floor(log10(N!)) + 1, where
// log10(N!) is accumulated as log10(1) + log10(2) + ... + log10(N).
int main() {
    int N;
    cin >> N;
    double log_sum = 0;
    int k = 1;
    while (k <= N) {
        log_sum += log10(k);
        ++k;
    }
    const int digits = static_cast<int>(log_sum) + 1;
    cout << digits << "\n";
}
Java
import java.util.*;

/** Prints the number of decimal digits of N!, computed as floor(1 + sum log10(k)). */
public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int N = sc.nextInt();
        double digits = 1; // the "+1" of floor(log10(N!)) + 1, added up front
        int k = 1;
        while (k <= N) {
            digits += Math.log10(k);
            k++;
        }
        System.out.println((int) digits);
    }
}
Python
# Number of decimal digits of N! = floor(log10(1) + ... + log10(N)) + 1.
# The import and the print were fused onto one line, which is a SyntaxError.
from math import log10

print(int(sum(map(log10, range(1, int(input()) + 1)))) + 1)
かつおぶし(難易度4)桁と数列
C++
#include <iostream>
#include <utility>
#include <vector>
using namespace std;
#define rep(i, n) for (int i = 0; i < (n); i++)

// Reads N, S and digit counts d[0..N-1]. Prints N integers a[i], where a[i]
// has exactly d[i] decimal digits and the a[i] sum to S, or -1 if impossible.
// Every a[i] starts at its minimum 10^(d[i]-1); the remaining sum is then
// assigned greedily to a[0], a[1], ..., each up to its maximum 10^d[i] - 1.
int main() {
    long long N, S;
    cin >> N >> S;
    vector<long long> d(N);
    for (auto& x : d) cin >> x;
    // mn[i] = 10^(i-1), the smallest i-digit integer; mn[i + 1] - 1 is the
    // largest one, so the table is read at index d[i] + 1. It is filled up to
    // mn[19] = 10^18 (the largest power of ten in long long), supporting
    // d[i] <= 18. The former size of 16 read mn[16] out of bounds for d[i] = 15.
    vector<long long> mn(20);
    mn[1] = 1;
    for (int i = 2; i < (int)mn.size(); i++) mn[i] = mn[i - 1] * 10;
    // Set every element to its minimum. Fail as soon as S is exceeded; checking
    // per element also keeps S from overflowing.
    vector<long long> a(N);
    rep(i, N) {
        a[i] = mn[d[i]];
        S -= a[i];
        if (S < 0) {
            cout << "-1\n";
            exit(0);
        }
    }
    // Raise each element as far as possible.
    rep(i, N) {
        long long dif = mn[d[i] + 1] - 1 - a[i];
        dif = min(S, dif);
        a[i] += dif;
        S -= dif;
    }
    // If S remains even with every element at its maximum, it is impossible.
    if (S > 0) {
        cout << "-1\n";
        exit(0);
    }
    rep(i, N) cout << a[i] << " \n"[i == N - 1];
}
Java
import java.util.*;

/**
 * Reads N, S and digit counts d_i. Prints N integers that have exactly d_i
 * decimal digits each and sum to S, or -1 if impossible.
 *
 * All arithmetic is done in long. S and the constructed numbers can exceed the
 * int range (any value with 11 or more digits does), and the previous int
 * version silently overflowed. As in the C++ reference solution, the minima are
 * subtracted from S one at a time, so an infeasible input is rejected before any
 * running total can overflow.
 */
public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int N = sc.nextInt();
        long S = sc.nextLong();
        long[] lst = new long[N];
        for (int i = 0; i < N; i++) {
            // Smallest number with d_i digits: 10^(d_i - 1).
            lst[i] = powerOfTen(sc.nextInt() - 1);
        }
        for (int i = 0; i < N; i++) {
            S -= lst[i];
            if (S < 0) {
                System.out.println(-1);
                return;
            }
        }
        // Raise each number greedily toward its maximum 10^d_i - 1 = 10 * lst[i] - 1,
        // i.e. by at most 9 * lst[i] - 1.
        for (int i = 0; i < N; i++) {
            long add = Math.min(S, lst[i] * 9 - 1);
            lst[i] += add;
            S -= add;
        }
        if (S == 0) {
            StringJoiner ans = new StringJoiner(" ");
            Arrays.stream(lst).forEach(x -> ans.add(String.valueOf(x)));
            System.out.println(ans.toString());
        } else {
            System.out.println(-1);
        }
    }

    /** Returns 10^e as a long; exact for 0 <= e <= 18. */
    private static long powerOfTen(int e) {
        long result = 1;
        for (int k = 0; k < e; k++) {
            result *= 10;
        }
        return result;
    }
}
Python
# Build N integers with d_i digits each that sum to S, or print -1.
# Indentation restored: the `if`/`else`/`for` bodies were flush-left.
N, S = map(int, input().split())
d = list(map(int, input().split()))
# Start every number at its minimum 10^(d_i - 1).
ans = [10**(di-1) for di in d]
total = sum(ans)
if S < total:
    print(-1)
    exit()
else:
    S -= total
# Raise each number greedily up to its maximum 10^d_i - 1 (i.e. by 9*min - 1).
for i in range(N):
    add = min(S, ans[i]*9-1)
    ans[i] += add
    S -= add
if S:
    print(-1)
else:
    print(*ans)
かつおぶし(難易度6)ペアなすごろく
C++
#include<bits/stdc++.h>
#define mod 998244353
using namespace std;
// Triangular number 1 + 2 + ... + x = x(x+1)/2 (not reduced modulo mod).
long long llsankaku(long long x){return ((1+x)*x)/2;}
// a^b modulo mod by binary exponentiation (b >= 0). The base is not reduced
// first, so callers are expected to pass 0 <= a < mod.
long long power(long long a,long long b){
long long x=1,y=a;
while(b>0){
if(b&1ll){
x=(x*y)%mod;
}
y=(y*y)%mod;
b>>=1;
}
return x%mod;
}
// Inverse of n modulo the prime mod via Fermat's little theorem (n must not be
// a multiple of mod).
long long modular_inverse(long long n){
return power(n,mod-2);
}
// Number of squares scanned. Every index i + r[i] written into acc must be
// below it. NOTE(review): 524288 = 2^19 presumably covers n + max r[i] under
// the constraints — confirm.
#define MAXN 524288
// Reads n and r[0..n-1] and prints sum over squares i >= n of prob_i * (i + 1),
// modulo 998244353.
// prob_i is propagated left to right. Each square i < n passes
// q = prob_i / (r[i](r[i]+1)/2) forward, so squares i+1, ..., i+r[i] receive
// q*r[i], q*(r[i]-1), ..., q*1. This costs O(1) per square with a second-order
// difference scheme: `delta` is the current step prob_{i+1} - prob_i, and
// acc[t] is a correction added to delta after square t. The initial
// prob = 1, delta = -1, acc[0] = 1 encode a value of 1 on square 0 only.
// NOTE(review): prob_i looks like the probability of landing on square i, and
// the answer like the expected (final square + 1) — confirm against the statement.
int main(){
long long n;
cin >> n;
vector<long long> r(n);
for(auto &nx : r){cin >> nx;}
vector<long long> acc(MAXN,0);
long long res=0;
long long prob=1,delta=(mod-1);
acc[0]=1;
for(long long i=0;i<MAXN;i++){
if(i<n){
// q = prob_i / T(r[i]) modulo mod.
long long q=llsankaku(r[i])%mod;
q=modular_inverse(q);
q*=prob;q%=mod;
// Start a decreasing ramp: step -q from now on, cancelled at square i + r[i].
delta+=(mod-q);delta%=mod;
acc[i+r[i]]+=q;acc[i+r[i]]%=mod;
// Offset so that square i+1 receives q*r[i] after the delta step below.
prob+=q*(r[i]+1);prob%=mod;
}
// Squares i >= n are not propagated further; their value is accumulated.
else{res+=prob*(i+1);res%=mod;}
// Advance to square i + 1.
prob+=delta;prob%=mod;
delta+=acc[i];delta%=mod;
}
cout << res << '\n';
}
Java
import java.util.*;
/**
 * Reads N and R[0..N-1] and prints, modulo 998244353, the sum over squares
 * i >= N of prob_i * (i + 1).
 *
 * prob_i is propagated left to right. Each square i < N passes
 * q = prob_i / (R[i](R[i]+1)/2) forward, so squares i+1, ..., i+R[i] receive
 * q*R[i], q*(R[i]-1), ..., q*1. This costs O(1) per square with a second-order
 * difference scheme: `delta` is the current step of prob, and acc[t] is a
 * correction added to delta after square t.
 * NOTE(review): prob_i looks like the probability of landing on square i, and
 * the answer like the expected (final square + 1) — confirm against the statement.
 */
public class Main {
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int N = sc.nextInt();
int[] R = new int[N];
for (int i = 0; i < N; i++) {
R[i] = sc.nextInt();
}
long mod = 998244353;
// Number of squares scanned; every index i + R[i] written into acc must be
// below it. NOTE(review): assumes i + R[i] <= 400000 for all i — confirm constraints.
int max = 400001;
long[] acc = new long[max];
long ans = 0;
// prob = 1, delta = -1 (mod - 1), acc[0] = 1 encode a value of 1 on square 0 only.
long prob = 1;
long delta = mod - 1;
acc[0] = 1;
for (int i = 0; i < max; i++) {
if (i < N) {
// q = prob_i / T(R[i]) modulo mod (inverse via Fermat).
long q = (long)R[i] * (R[i]+1) / 2 % mod;
q = modPow(q, mod-2, mod) * prob % mod;
// Start a decreasing ramp of step -q, cancelled at square i + R[i].
delta = (delta + mod - q) % mod;
acc[i+R[i]] = (acc[i+R[i]] + q) % mod;
// Offset so that square i+1 receives q*R[i] after the delta step below.
prob = (prob + q * (R[i]+1)) % mod;
} else {
// Squares i >= N are not propagated further; accumulate their value.
ans = (ans + prob * (i+1)) % mod;
}
// Advance to square i + 1.
prob = (prob + delta) % mod;
delta = (delta + acc[i]) % mod;
}
System.out.println(ans);
}
/** x^n mod mod by binary exponentiation (n >= 0). */
static long modPow(long x, long n, long mod) {
long ret = 1;
// Repeated squaring
while (0 < n) {
if ((n&1) == 1) {
ret = ret * x % mod;
}
x = x * x % mod;
n >>= 1;
}
return ret;
}
}
Python
# Prints, modulo 998244353, the sum over squares i >= N of prob_i * (i + 1).
# Each square i < N passes q = prob_i / (R[i](R[i]+1)/2) forward, so squares
# i+1..i+R[i] receive q*R[i], ..., q*1. This uses a second-order difference
# scheme: `delta` is the step of prob, and acc[t] corrects delta after square t.
# NOTE(review): prob_i looks like the probability of landing on square i —
# confirm against the problem statement.
# Indentation restored: the loop and if/else bodies were flush-left.
N = int(input())
R = list(map(int, input().split()))
mod = 998244353
# NOTE(review): assumes every i + R[i] <= 4*10**5 — confirm against constraints.
acc = [0]*(4*10**5+1)
res = 0
# prob = 1, delta = -1, acc[0] = 1 encode a value of 1 on square 0 only.
prob = 1
delta = mod - 1
acc[0] = 1
for i in range(4*10**5+1):
    if i < N:
        q = R[i] * (R[i]+1) // 2 % mod
        q = pow(q, mod-2, mod) * prob % mod
        # Decreasing ramp of step -q, cancelled at square i + R[i].
        delta = (delta+mod-q) % mod
        acc[i+R[i]] = (acc[i+R[i]]+q) % mod
        prob = (prob+q*(R[i]+1)) % mod
    else:
        # Squares i >= N are not propagated further; accumulate their value.
        res = (res+prob*(i+1)) % mod
    prob = (prob+delta) % mod
    delta = (delta+acc[i]) % mod
print(res)
かつおぶし(難易度6)コイン投げ
C++
#include <iostream>
#include <string>
#include <vector>
using namespace std;
constexpr int MOD = 998244353;
// Reads a string S, presumably over a two-letter coin alphabet (NOTE(review):
// confirm), and prints E mod 998244353. E appears to be the expected number of
// fair coin flips until S first appears; hand-checked: "H" -> 2, "HH" -> 6.
//
// A[i]:  KMP failure link for prefix length i, in the optimized form (when the
//        next character equals the one after the border, the link is skipped
//        on to A[j]); A[0] = -1.
// dp[i]: minus the expected flips to first complete S[0..i), modulo MOD;
//        dp[0] = 0. The update encodes E[i+1] = 2*E[i] + 2 - E[k+1], where
//        k+1 is the state reached when the flip after prefix i differs from S[i].
int main() {
string S;
cin >> S;
vector<int> A(S.size() + 1, -1), dp(S.size() + 1);
for (int i = 0, j = -1, k; i < (int)S.size(); i++) {
// `~j` is nonzero exactly when j != -1.
while (~j && S[i] != S[j]) j = A[j];
// On the last iteration S[i + 1] is S[S.size()] == '\0', acting as a sentinel.
A[i + 1] = S[i + 1] == S[++j] ? A[j] : j;
// Follow links from A[i] while the next character equals S[i]; afterwards
// k + 1 is the state reached by reading the character other than S[i].
for (k = A[i]; ~k && S[i] == S[k];) k = A[k];
dp[i + 1] = ((dp[i] * 2 - dp[++k] - 2) % MOD + MOD) % MOD;
}
// The answer is E = -dp[|S|], normalized into [0, MOD).
cout << (-dp.back() % MOD + MOD) % MOD << "\n";
}
Java
import java.util.*;
/**
 * Reads a string S, presumably over a two-letter coin alphabet (NOTE(review):
 * confirm), and prints E mod 998244353. E appears to be the expected number of
 * fair coin flips until S first appears; hand-checked: "H" -> 2, "HH" -> 6.
 *
 * A[i]:  optimized KMP failure link for prefix length i (A[0] = -1).
 * dp[i]: minus the expected flips to first complete S[0..i), modulo mod;
 *        the update encodes E[i+1] = 2*E[i] + 2 - E[k+1].
 */
public class Main {
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
// "_" is a sentinel so that S[i+1] is valid on the last iteration.
char[] S = (sc.next() + "_").toCharArray();
int len = S.length;
int mod = 998244353;
int[] A = new int[len];
Arrays.fill(A, -1);
int[] dp = new int[len];
for (int i = 0, j = -1; i < len-1; i++) {
// (~j != 0) is equivalent to j != -1.
while ((~j != 0) && (S[i] != S[j])) {
j = A[j];
}
A[i+1] = (S[i+1] == S[++j] ? A[j] : j);
// Follow links from A[i] while the next character equals S[i]; afterwards
// k + 1 is the state reached by reading the character other than S[i].
int k = A[i];
for (k = A[i]; (~k != 0) && (S[i] == S[k]);) {
k = A[k];
}
dp[i+1] = ((dp[i] * 2 - dp[k+1] - 2) % mod + mod) % mod;
}
// The answer is -dp[|S|], normalized into [0, mod).
System.out.println((-dp[len-1] % mod + mod) % mod);
}
}
Python
# Prints E mod 998244353, where E appears to be the expected number of fair
# coin flips until S first appears (hand-checked: "H" -> 2, "HH" -> 6).
# A[i] is the optimized KMP failure link of prefix length i. dp[i] is minus the
# expected flips to complete S[0..i), and the update encodes
# E[i+1] = 2*E[i] + 2 - E[k+1].
# Indentation restored: the loop bodies were flush-left.
S = input()
siz = len(S)
# Sentinel so that S[i+1] is valid on the last iteration.
S = S + " "
mod = 998244353
A = [-1] * (siz + 1)
dp = [0] * (siz + 1)
j = -1
for i in range(siz):
    # `~j` is 0 (falsy) exactly when j == -1.
    while ~j and (S[i] != S[j]):
        j = A[j]
    j += 1
    A[i+1] = (A[j] if S[i+1] == S[j] else j)
    # Follow links while the next character equals S[i]; k + 1 is then the
    # state reached by reading the character other than S[i].
    k = A[i]
    while ~k and (S[i] == S[k]):
        k = A[k]
    dp[i+1] = ((dp[i] * 2 - dp[k+1] - 2) % mod + mod) % mod
print((-dp[-1] % mod + mod) % mod)