Skip to content

Instantly share code, notes, and snippets.

@chenyukang
Last active October 8, 2019 11:01
Show Gist options
  • Save chenyukang/8265615 to your computer and use it in GitHub Desktop.
Save chenyukang/8265615 to your computer and use it in GitHub Desktop.
DFA construction for a simple regular-expression matching problem ('.' and '*' operators).
#include <iostream>
#include <string>
#include <vector>
#include <set>
#include <stdio.h>
#include <assert.h>
using namespace std;
// Kind of pattern atom a State represents.
enum OpType {
    ZERO_PLUS_ONE = 0,  // "x*": zero or more repetitions of the state's value
    ANY_ONE = 1,        // ".": exactly one arbitrary character
    MUST_ONE = 2        // literal: exactly one occurrence of the state's value
};
// One node of the pattern automaton.
// `next` holds the states that may consume the following input character;
// valid() decides whether this state accepts a given character.
struct State {
    OpType type;          // which atom kind this state represents
    int id;               // identifier supplied by the builder
    char value;           // character to match; '.' means "any" for ZERO_PLUS_ONE
    bool end;             // set by the builder when the pattern may finish here
    State* prev;          // predecessor; points to this state when built with NULL
    vector<State*> next;  // outgoing edges

    // @param t  atom kind
    // @param i  identifier
    // @param v  character this state matches
    // @param p  predecessor state, or NULL for a root state
    State(OpType t, int i, char v, State *p) :
        type(t), id(i), value(v), end(false), prev(p) {
        // A repeatable atom may consume another copy of itself: self loop.
        if (type == ZERO_PLUS_ONE)
            next.push_back(this);
        if (p == NULL)
            prev = this;
    }

    // Adds an edge to `n`. A ZERO_PLUS_ONE state may match zero times, so the
    // edge is also added to its predecessor (recursively through a run of
    // starred states), which lets the optional state be skipped.
    void add(State* n) {
        next.push_back(n);
        // prev == this only for a root state; propagating to it would
        // recurse forever, and a root has no predecessor to skip from.
        if (type == ZERO_PLUS_ONE && prev != NULL && prev != this)
            prev->add(n);
    }

    // Returns whether this state may consume input character `val`.
    bool valid(char val) const {
        switch (type) {
        case ZERO_PLUS_ONE:
            return value == val || value == '.';
        case MUST_ONE:
            return value == val;
        case ANY_ONE:
            return true;
        }
        return false;
    }
};
class Solution {
private:
int Num;
public:
State* construct_dfa(const char* pattern) {
if(pattern == NULL) return NULL;
const char* p = pattern;
State* start = new State(ANY_ONE, Num, '.', NULL);
State* cur = start;
State* next = NULL;
char prev = '.';
Num = 1;
while(*p && *p != '\0') {
if(*(p+1) != '*') {
OpType type;
char value;
if(*p == '*') {
type = ZERO_PLUS_ONE;
value = prev;
} else {
value = *p;
type = *p == '.'? ANY_ONE : MUST_ONE;
}
next = new State(type, Num, value, cur);
prev = *p, p++;
} else {
next = new State(ZERO_PLUS_ONE, Num, *p, cur);
prev = '*', p+=2;
}
cur->add(next);
cur = next;
Num++;
}
cur->end = true;
while(cur->type == ZERO_PLUS_ONE) {
cur = cur->prev;
cur->end = true;
}
return start;
}
void visit(State* dfa, set<State*>& s) {
s.insert(dfa);
for(int i=0; i<dfa->next.size(); i++) {
State* n = dfa->next[i];
if(n != NULL && s.find(n) == s.end())
visit(n, s);
}
}
void delete_dfa(State* dfa) {
set<State*> s;
visit(dfa, s);
for(set<State*>::iterator it = s.begin(); it != s.end(); ++it)
delete (*it);
}
bool match(const char* str, State* dfa) {
const char* s = str;
State** curs = (State**)malloc(sizeof(State*) * Num + 1);
State** nexts = (State**)malloc(sizeof(State*) * Num + 1);
int visited[Num+1];
int curNum, nextNum;
bool res = false;
curNum = nextNum = 0;
nexts[nextNum++] = dfa;
int step = 0;
while( s && *s != '\0') {
swap(curs, nexts);
swap(curNum, nextNum);
nextNum = 0;
memset(visited, 0, sizeof(visited));
for(int k=0; k<curNum; k++) {
const vector<State*> adj = curs[k]->next;
for(int k=0; k<adj.size(); k++) {
int id = adj[k]->id;
if(visited[id] == 0 && adj[k]->valid(*s)) {
nexts[nextNum++] = adj[k];
visited[id] = 1;
}
}
}
if(nextNum == 0) break;
s++;
}
for(int k=0; k<nextNum; k++) {
if(nexts[k]->end) {
res = true;
break;
}
}
free(curs);
free(nexts);
return res;
}
bool isMatch(const char *s, const char *p) {
Num = 0;
State* dfa = construct_dfa(p);
bool res = match(s, dfa);
delete_dfa(dfa);
return res;
}
//recursive version
bool isMatch_iter(const char *s, const char *p) {
assert(s && p);
if (*p == '\0') return *s == '\0';
// next char is not '*': must match current character
if (*(p+1) != '*') {
assert(*p != '*');
return ((*p == *s) || (*p == '.' && *s != '\0')) && isMatch_iter(s+1, p+1);
}
// next char is '*'
while ((*p == *s) || (*p == '.' && *s != '\0')) {
if (isMatch_iter(s, p+2)) return true;
s++;
}
return isMatch_iter(s, p+2);
}
};
// Benchmark driver. Builds the text "a" * len and the pattern
// "a*" * len + "a" * len (which matches the text), then prints (1 or 0) the
// result of Solution::isMatch when v == 1, or of the recursive
// Solution::isMatch_iter otherwise.
// @param v    matcher selector (1 = state graph, anything else = recursive)
// @param len  repetition count; defaults to 1000, negative values become 0
void test_perf(int v, int len = 1000) {
    if (len < 0) len = 0;
    const string text(static_cast<size_t>(len), 'a');
    // Append instead of repeatedly prepending: linear instead of quadratic.
    // Every piece is identical, so the resulting pattern is the same.
    string pattern;
    pattern.reserve(3 * static_cast<size_t>(len));
    for (int k = 0; k < len; k++)
        pattern += "a*";
    pattern += text;
    Solution p;
    if (v == 1)
        std::cout << p.isMatch(text.c_str(), pattern.c_str()) << std::endl;
    else
        std::cout << p.isMatch_iter(text.c_str(), pattern.c_str()) << std::endl;
}
void test(const char* s, const char* p) {
Solution x;
printf("%s vs %s -> ", s, p);
if(x.isMatch(s, p)) {
std::cout << "True" << std::endl;
} else {
std::cout << "False" << std::endl;
}
}
// Runs the hand-written cases through test(), printing one verdict line per
// (string, pattern) pair in table order. Always returns 0.
int test_all() {
    static const char* const cases[][2] = {
        {"aaaaaaaaaaaaab", "a*a*a*a*a*a*a*a*a*a*a*a*b"},
        {"aa", "a**"},
        {"aaba", "ab*a*c*a"},  // expected: False
        {"aa", "a*"},
        {"aaa", ".a"},
        {"aaa", "a.a"},
        {"a", "ab*"},
        {"a", "ab*b*b*"},
        {"ab", ".*c"},
        {"ab", ".*c*"},
        {"aa", "a"},
        {"aa", "aa"},
        {"aaa", "aa"},
        {"aa", ".*"},
        {"ab", ".*"},
        {"abcde", ".*"},
        {"aab", "c*a*b*"},
        {"bbbba", ".*a*a"},
        {"bbbba", "bbbbaa"},
    };
    const size_t count = sizeof(cases) / sizeof(cases[0]);
    for (size_t i = 0; i < count; ++i)
        test(cases[i][0], cases[i][1]);
    return 0;
}
int main() {
    // Functional regression suite; disabled by default.
    // test_all();

    // Benchmark the state-graph matcher (selector 1 = Solution::isMatch).
    const int use_state_graph = 1;
    test_perf(use_state_graph);
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment