pstopia Notes for Problem Solving Contest

[Medium] RequiredSubstrings

Topcoder SRM 519 Div1

Problem

평범한 DP를 생각하려고 하면 words[] 끼리 서로 겹치는 부분이 존재할 때, 상태를 어떻게 정의해야할지 막막하다. 패턴이 여러개 뒤섞여 등장하는 경우, Aho-Corasick 알고리즘의 DFA를 이용하면 깔끔하게 상태를 정의할 수 있다. 먼저 words[] 를 갖고 DFA를 만든다. 그리고 여기에 DP를 끼얹는다. meet[u] = DFA상의 노드u에서 매칭되는 패턴의 집합 이건 노드u에서 output link 를 따라가면 쉽게 구할 수 있다. D[i][u][s] = 길이i 이고 현재 DFA상에서 노드 u에 위치한 상태. 지금까지 등장한 패턴의 집합이 s. 일 때, 가능한 문자열의 수 D[i+1][v][s ∪ meet[v]] += D[i][u][s] (DFA상에 u->v 인 간선이 존재할 때) 답은 sum( D[L][u][s] , u는 DFA상의 모든 노드 , s는 정확히 C개를 포함한 모든 집합 )
#include <vector>
#include <queue>
#include <algorithm>
using namespace std;
struct AhoCorasick {
const int alphabet;
struct node {
node(int nxt) : next(nxt) {}
vector<int> next, report;
int back = 0, output_link = 0, flag = 0;
};
int maxid = 0;
vector<node> dfa;
AhoCorasick(int alphabet) : alphabet(alphabet), dfa(1, node(alphabet)) {}
template<typename InIt, typename Fn> void add(int id, InIt first, InIt last, Fn func) {
int cur = 0;
for (; first != last; ++first) {
auto s = func(*first);
if (auto next = dfa[cur].next[s]) cur = next;
else {
cur = dfa[cur].next[s] = (int)dfa.size();
dfa.emplace_back(alphabet);
}
}
dfa[cur].report.push_back(id);
maxid = max(maxid, id);
}
void build() {
queue<int> q;
vector<char> visit(dfa.size());
visit[0] = 1;
q.push(0);
while (!q.empty()) {
auto cur = q.front(); q.pop();
dfa[cur].output_link = dfa[cur].back;
if (dfa[dfa[cur].back].report.empty())
dfa[cur].output_link = dfa[dfa[cur].back].output_link;
for (int s = 0; s < alphabet; s++) {
auto &next = dfa[cur].next[s];
if (next == 0) next = dfa[dfa[cur].back].next[s];
if (visit[next]) continue;
if (cur) dfa[next].back = dfa[dfa[cur].back].next[s];
visit[next] = 1;
q.push(next);
}
}
for (int i = 0; i < dfa.size(); ++i) {
for (int p = i; p; p = dfa[p].output_link)
for (auto id : dfa[p].report) dfa[i].flag |= 1 << id;
}
}
};
const int MOD = 1000000009;
int d[51][350][1 << 6];
class RequiredSubstrings {
public:
int solve(vector<string> words, int C, int L) {
AhoCorasick ac(26);
for (int i = 0; i < words.size(); ++i) {
ac.add(i, words[i].begin(), words[i].end(), [](char c) { return c - 'a'; });
}
ac.build();
d[0][0][0] = 1;
for (int i = 0; i < L; ++i) {
for (int j = 0; j < ac.dfa.size(); ++j) {
for (int nxt : ac.dfa[j].next) {
for (int s = 0; s < (1 << words.size()); ++s) {
d[i + 1][nxt][s | ac.dfa[nxt].flag] += d[i][j][s];
d[i + 1][nxt][s | ac.dfa[nxt].flag] %= MOD;
}
}
}
}
int ans = 0;
for (int i = 0; i < ac.dfa.size(); ++i) {
for (int s = 0; s < (1 << words.size()); ++s) {
int bitcnt = 0;
for (int k = 0; k < words.size(); ++k) {
bitcnt += (s & (1 << k)) ? 1 : 0;
}
if (bitcnt == C) {
ans = (ans + d[L][i][s]) % MOD;
}
}
}
return ans;
}
};