Skip to content

Instantly share code, notes, and snippets.

@KuRRe8
Last active May 3, 2025 23:12
Show Gist options
  • Save KuRRe8/758735e4b13a319414a951a468276f35 to your computer and use it in GitHub Desktop.
Save KuRRe8/758735e4b13a319414a951a468276f35 to your computer and use it in GitHub Desktop.
Autoplay Connect 4 on Board Game Arena (boardgamearena.com)

Autoplay Connect 4

Connect 4 is a well-known game that receives a lot of attention in the AI community. It is played on a 7x6 board (7 columns, 6 rows), where two players take turns dropping colored discs into columns. The objective is to connect four of one's own discs in a row, either horizontally, vertically, or diagonally.

How to compile and use

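In WSL, build the solver with `g++ -std=c++17 -O3 main.cc`. app.py expects the resulting binary at /mnt/d/_work/connect-four-solver/a.out and a score table at /mnt/d/_work/connect-four-solver/table; edit these paths in app.py to match your setup. Then, on Windows, run `python app.py`:

1. Use command 1 to open the browser.
2. Log in manually (command 2 is not implemented yet).
3. Start a Connect 4 game.
4. Use command 8 to play in an endless loop.

Press Ctrl+C to stop the loop and 0 to exit.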

Thanks

The biran0083 repository, which in turn is inspired by Pascal Pons's work.

from selenium import webdriver
from selenium.webdriver.common.by import By
import time
import numpy as np
import subprocess
from selenium.webdriver.common.action_chains import ActionChains
import signal
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
# Board state, column-major: gameboard[col, row] with col 0..6 (game columns 1-7) and
# row 0..5 (game rows 1-6). In the game the row numbers grow from top to bottom, so row 0
# is the top row. Cell values: 0 = empty, 1 = red, 2 = yellow.
gameboard = np.zeros((7, 6), dtype=int)
driver = None  # Selenium WebDriver; stays None until open_browser() starts the browser.
# Column (1-7) suggested by the solver for the next move, or None when there is no usable
# suggestion. It must be bound to a value: a bare annotation (`last_suggestion: int`) creates
# no variable, so reading it before the first successful query raised NameError.
last_suggestion = None
# Set to True by the SIGINT handler to ask the auto-play loop (cmd8) to stop.
stop_flag = False
def signal_handler(signum, frame):
    """SIGINT (Ctrl+C) handler: request that the auto-play loop stop.

    Installing this handler replaces Python's default KeyboardInterrupt, so Ctrl+C only sets
    ``stop_flag``; the message therefore names the key actually pressed (Ctrl+C, not Ctrl+D)
    and says what actually stops.
    """
    global stop_flag
    print("CTRL+C detected. Stopping the auto-play loop...")
    stop_flag = True


signal.signal(signal.SIGINT, signal_handler)
def reset_gameboard():
    """Replace the global board with a fresh all-empty 7x6 (column-major) integer array."""
    global gameboard
    empty_board = np.full((7, 6), 0, dtype=int)
    gameboard = empty_board
def open_browser():
    """Start an Edge WebDriver session (only once) and navigate to Board Game Arena."""
    global driver
    if driver is not None:
        # A session already exists; do not start a second browser.
        print("浏览器已启动,无需重复打开。")
        return
    driver = webdriver.Edge(keep_alive=True)
    driver.get('https://boardgamearena.com/')
    print("浏览器已启动并打开页面。")
def close_browser():
    """Quit the running WebDriver session, if any, and forget it."""
    global driver
    if driver is None:
        print("浏览器未启动,无需关闭。")
        return
    driver.quit()
    driver = None
    print("浏览器已关闭。")
def login():
    """Placeholder for automated login; always raises NotImplementedError."""
    raise NotImplementedError("登录功能尚未实现。")
def update_gameboard():
    """Re-read the global board from the disc elements on the Board Game Arena page.

    Cell (col, row) is read from the element with id ``disc_<col+1><row+1>``. A class
    containing ``disccolor_ff0000`` marks red (1) and ``disccolor_ffff00`` marks yellow (2).
    Cells whose element is missing or unreadable stay empty. Afterwards the board is printed
    (one board row per line) and the element at a fixed XPath near the top of the page is
    clicked. NOTE(review): that click presumably dismisses an overlay or refocuses the page;
    confirm.
    """
    global gameboard, driver
    reset_gameboard()
    if driver is None:
        print("浏览器未启动,请先打开浏览器。")
        return
    colour_markers = (("disccolor_ff0000", 1), ("disccolor_ffff00", 2))
    try:
        for col in range(7):
            for row in range(6):
                cell_id = f"disc_{col+1}{row+1}"
                try:
                    css_classes = driver.find_element(By.ID, cell_id).get_attribute("class")
                    for marker, player in colour_markers:
                        if marker in css_classes:
                            gameboard[col, row] = player
                            break
                except Exception:
                    # No disc element for this cell (or no class attribute): leave it empty.
                    continue
        print("游戏棋盘已更新:")
        print(gameboard.T)
        driver.find_element(By.XPATH, "/html/body/div[3]/div[1]/div[1]").click()
    except Exception as e:
        print(f"更新游戏棋盘失败: {e}")
def print_gameboard():
    """Print the global board top row first: X = red (1), O = yellow (2), _ = empty.

    The last line shows the column numbers 1-7.
    """
    global gameboard
    print("当前游戏棋盘:")
    symbols = {1: "X", 2: "O"}
    # Row 0 is the top of the board (game rows are numbered top to bottom), so rows must be
    # walked 0..5 in order; range(-1, 5) printed the bottom row first, then rows 0..4.
    for row in range(6):
        print(" ".join(symbols.get(int(gameboard[col, row]), "_") for col in range(7)))
    print("1 2 3 4 5 6 7")  # column numbers
def query_external_program():
    """Ask the external solver (run through WSL) for the best column for the side to move.

    The board is encoded in the solver's bitboard format: 7 bits per column, and bit 0 of a
    column is its bottom cell. ``pos`` holds the stones of the player to move, and ``mask``
    holds all stones. The code assumes red (value 1) always moves first.

    On success ``last_suggestion`` is set to the suggested column (1-7). On any failure it is
    reset to None: a missing program, no output, an unparsable line, or an out-of-range column.
    This way apply_move() can never replay a suggestion from an earlier turn.
    """
    global gameboard, last_suggestion
    last_suggestion = None
    moves = int(np.sum(gameboard == 1) + np.sum(gameboard == 2))
    cur_player = 1 if moves % 2 == 0 else 2
    pos = 0   # stones of the player to move
    mask = 0  # every occupied cell
    for col in range(7):
        for row in range(6):
            # gameboard row 0 is the top of a column, while the solver's bit 0 is the bottom.
            bit = 1 << (col * 7 + (5 - row))
            if gameboard[col, row] == cur_player:
                pos |= bit
            if gameboard[col, row] != 0:
                mask |= bit
    print(f"Player {cur_player}'s position (pos): {bin(pos)}")
    print(f"Mask of occupied positions (mask): {bin(mask)}")
    cmd = [
        "wsl",
        "/mnt/d/_work/connect-four-solver/a.out",
        "suggest",
        "-l", "/mnt/d/_work/connect-four-solver/table",
        "-t", "8",
        "-d", "8",
        "-bp", str(pos),
        "-bm", str(mask),
        "-bo", str(moves),
    ]
    try:
        # run() drains stdout and stderr together. Reading only stdout while stderr is also
        # piped can deadlock once the unread stderr pipe fills up.
        result = subprocess.run(cmd, capture_output=True, text=True)
    except OSError as e:
        print(f"调用外部程序失败: {e}")
        return
    output_lines = result.stdout.splitlines()
    if not output_lines:
        print(f"外部程序没有输出 (返回码 {result.returncode}): {result.stderr.strip()}")
        return
    # The solver's last stdout line is the suggested 1-based column.
    last_line = output_lines[-1]
    print(last_line)
    try:
        suggestion = int(last_line.split()[0])
    except (ValueError, IndexError) as e:
        print(f"提取建议的列号失败: {e}")
        return
    if not 1 <= suggestion <= 7:
        print(f"提取建议的列号失败: 列号超出范围 {suggestion}")
        return
    last_suggestion = suggestion
    print(f"建议的列号: {last_suggestion}")
def apply_move():
    """Play ``last_suggestion`` by clicking the top cell of that column on the BGA board.

    After the click it waits 2 s and clicks the element at a fixed XPath near the top of the
    page. The suggestion is consumed (reset to None) as soon as the click is sent, so a later
    failed solver query cannot make the same column be played again. When no suggestion is
    available, nothing is clicked.
    """
    global driver, last_suggestion
    if driver is None:
        print("浏览器未启动,请先打开浏览器。")
        return
    try:
        if last_suggestion is None:
            print("没有可用的建议列号,跳过移动。")
            return
        element_id = f"square_{last_suggestion}_1"
        element = driver.find_element(By.ID, element_id)
        # Use ActionChains to simulate a real mouse move + click on the cell.
        ActionChains(driver).move_to_element(element).click().perform()
        last_suggestion = None
        print(f"已模拟点击元素: {element_id}")
        time.sleep(2)
        driver.find_element(By.XPATH, "/html/body/div[3]/div[1]/div[1]").click()
    except Exception as e:
        print(f"应用移动失败: {e}")
def search_text_in_page(driver, target_text):
    """Return True if ``target_text`` occurs in the page body's text.

    Returns False when the text is absent or when reading the page raises; each outcome is
    logged.
    """
    try:
        body_text = driver.find_element("tag name", "body").text
        if target_text not in body_text:
            print(f"未找到目标文本: {target_text}")
            return False
        print(f"找到了目标文本: {target_text}")
        return True
    except Exception as e:
        print(f"搜索页面文本时发生异常: {e}")
        return False
def cmd9():
    """Play one turn: scrape the board, ask the solver for a column, then click it."""
    print("执行组合命令:更新游戏棋盘、调用外部程序、应用移动")
    for step in (update_gameboard, query_external_program, apply_move):
        step()
def cmd8():
    """Auto-play loop for Board Game Arena Connect 4.

    In a game (the element ``pagemaintitletext`` exists):
    - When the title names us ("你"), play the turn via cmd9().
    - When the title says the game is over ("游戏结束"), click ``createNew_btn``.

    Out of a game:
    - Try to start matchmaking.
    - Wait while "正在搜寻玩家" is shown.
    - Accept the match (``ags_start_game_accept``).

    Runs until ``stop_flag`` becomes True, which the SIGINT handler sets on Ctrl+C.
    """
    global stop_flag
    print("执行自动检测循环以自动玩游戏")
    stop_flag = False
    if driver is None:
        # Without a browser every lookup below fails and the loop would spin forever.
        print("浏览器未启动,请先打开浏览器。")
        return
    while not stop_flag:
        try:
            # We are inside a game page whenever the main title element exists.
            try:
                driver.find_element(By.ID, "pagemaintitletext")
                ingame = True
            except Exception:
                ingame = False
            if ingame:
                # In-game: poll the title bar until the game ends or polling fails.
                while not stop_flag:
                    try:
                        element = driver.find_element(By.ID, "pagemaintitletext")
                        if "游戏结束" in element.text:
                            playagainbutton = driver.find_element(By.ID, "createNew_btn")
                            playagainbutton.click()
                            time.sleep(3)
                            break
                        span = element.find_element(By.TAG_NAME, "span")
                        if span.text == "你":
                            print("检测到 '你',执行 cmd9。")
                            cmd9()
                        else:
                            # Opponent's turn (or another prompt): back off before polling
                            # again instead of hammering the WebDriver in a tight loop.
                            time.sleep(1)
                    except Exception as e:
                        print(f"游戏中处理异常: {e}")
                        break
            else:
                # Best effort: a finished game may still be shown with a "play again" button.
                try:
                    element = driver.find_element(By.ID, "pagemaintitletext")
                    print('找到了游戏标题')
                    if "游戏结束" in element.text:
                        playagainbutton = driver.find_element(By.ID, "createNew_btn")
                        playagainbutton.click()
                        print('点击了再来一次')
                        time.sleep(3)
                except Exception:
                    pass
                # Best effort: start matchmaking (start button at a fixed XPath), then wait
                # while the page says it is searching for players.
                try:
                    startbtn = driver.find_element(By.XPATH, "/html/body/div[2]/div[5]/div/div[1]/div/div[3]/div/div/div/div[4]/div[2]/div/a")
                    startbtn.click()
                    print('点击了开始游戏(匹配)')
                    time.sleep(0.5)
                    count = 0
                    while search_text_in_page(driver=driver, target_text="正在搜寻玩家"):
                        if stop_flag:
                            return
                        time.sleep(1)
                        count = count + 1
                        print(f"等待匹配{count}秒")
                except Exception:
                    pass
                # Accept the match, then wait up to 18 s for the game page to appear.
                try:
                    WebDriverWait(driver, 5).until(
                        EC.element_to_be_clickable((By.ID, "ags_start_game_accept"))
                    ).click()
                    print('点击了确认按钮')
                    WebDriverWait(driver, 18).until(
                        EC.presence_of_element_located((By.ID, "pagemaintitletext"))
                    )
                    print('在18秒内等待到了别人开始')
                except Exception:
                    print('等待确认按钮或者等待他人超时')
        except Exception as e:
            print(f"cmd8 主循环异常,{e}")
        # Pace the outer loop.
        time.sleep(1)
def main():
    """Interactive command menu: read one command per line and dispatch it.

    Stops on '0' or at end of input. The browser is always closed on the way out, even
    when a command raises.
    """
    print("按键监听已启动:")
    print("1 - 打开浏览器并访问页面")
    print("2 - 登录操作")
    print("3 - 更新游戏棋盘")
    print("4 - 打印游戏棋盘")
    print("5 - 调用外部程序")
    print("7 - 应用移动")
    print("8 - 自动检测循环")
    print("9 - 执行组合命令")
    print("0 - 退出程序")
    try:
        while True:
            try:
                user_input = input("请输入指令 (1/2/3/4/5/7/8/9/0): ")
            except EOFError:
                # stdin was closed (Ctrl+D / Ctrl+Z): leave exactly as if '0' had been entered.
                user_input = '0'
            if user_input == '1':
                open_browser()
            elif user_input == '2':
                # login() is not implemented yet; report that instead of silently ignoring '2'.
                try:
                    login()
                except NotImplementedError as e:
                    print(e)
            elif user_input == '3':
                update_gameboard()
            elif user_input == '4':
                print_gameboard()
            elif user_input == '5':
                query_external_program()
            elif user_input == '7':
                apply_move()
            elif user_input == '8':
                cmd8()
            elif user_input == '9':
                cmd9()
            elif user_input == '0':
                print("退出程序。")
                break
            else:
                print("无效指令,请重新输入。")
    finally:
        close_browser()


if __name__ == "__main__":
    main()
#include <unistd.h>

#include <algorithm>
#include <array>
#include <cassert>
#include <chrono>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
// Board geometry.
constexpr int HEIGHT{6};
constexpr int WIDTH{7};
// Widest score range; used as the "nothing known" bounds of the transposition table.
constexpr int MAX_SCORE{(WIDTH * HEIGHT + 1) / 2 - 3};
constexpr int MIN_SCORE{3 - WIDTH * HEIGHT / 2};

// Bitboard layout used throughout: column c owns HEIGHT + 1 consecutive bits starting at
// c * (HEIGHT + 1); bit c * (HEIGHT + 1) + r is row r counted from the bottom, and the top
// bit of every column is not a board cell.

// Mirrors a key left-to-right: column c's (HEIGHT + 1)-bit group moves to column
// WIDTH - 1 - c. Bits beyond the WIDTH column groups are dropped.
int64_t mirrow(int64_t key) {
    constexpr int kColumnBits = HEIGHT + 1;
    const int64_t column_bits = (int64_t{1} << kColumnBits) - 1;
    int64_t mirrored = 0;
    for (int col = 0; col < WIDTH; ++col) {
        const int64_t group = (key >> (col * kColumnBits)) & column_bits;
        mirrored |= group << ((WIDTH - 1 - col) * kColumnBits);
    }
    return mirrored;
}

// One bit per column: the bottom cell of each column.
constexpr int64_t bottom_mask() {
    int64_t bits = 0;
    for (int col = WIDTH - 1; col >= 0; --col) {
        bits = (bits << (HEIGHT + 1)) | 1;
    }
    return bits;
}

// Every board cell: HEIGHT low bits in each column group (the group's top bit stays clear).
constexpr int64_t full_mask() {
    return bottom_mask() * ((int64_t{1} << HEIGHT) - 1);
}

constexpr int64_t BOTTOM{bottom_mask()};
constexpr int64_t FULL{full_mask()};
// Transposition table of [low, hi] score bounds keyed by BitBoard::key().
// Entries whose depth (move count) is <= pinThresh live in an unbounded hash map and are never
// evicted. All other entries share a fixed-size array: the slot is key % size, a newer entry
// overwrites the slot, and only the low 32 bits of the key are kept to recognise it.
class Table {
public:
    // size: number of slots in the replaceable array.
    // pinThresh: entries with depth <= pinThresh are pinned in the hash map (-1 pins nothing).
    Table(size_t size, int pinThresh): pinThresh(pinThresh), k(size), l(size), h(size) {
        clear();
    }
    explicit Table(int pinThresh): Table((1 << 23) + 9, pinThresh) {}

    // Resets every array slot to key 0 with the "unknown" bounds [MIN_SCORE, MAX_SCORE].
    // Filling l/h (not only k) matters: the empty board has key 0, so it matches an untouched
    // slot. With these bounds that match is indistinguishable from a miss, instead of
    // reporting an exact score of 0. The pinned hash map is left as is.
    void clear() {
        std::fill(k.begin(), k.end(), 0);
        std::fill(l.begin(), l.end(), static_cast<int8_t>(MIN_SCORE));
        std::fill(h.begin(), h.end(), static_cast<int8_t>(MAX_SCORE));
    }

    // Records bounds [low, hi] for `key`; `depth` selects the pinned map or the array.
    void put(int64_t key, int8_t low, int8_t hi, int depth) {
        if (depth <= pinThresh) {
            full_cache[key] = std::make_pair(low, hi);
        } else {
            size_t idx = key % k.size();
            k[idx] = (int32_t) key;
            l[idx] = low;
            h[idx] = hi;
        }
    }

    // Stored [low, hi] bounds for `key`, or [MIN_SCORE, MAX_SCORE] when nothing is known.
    std::pair<int, int> get(int64_t key, int depth) {
        if (depth <= pinThresh) {
            auto it = full_cache.find(key);
            if (it != full_cache.end()) {
                return it->second;
            }
        } else {
            size_t idx = key % k.size();
            if (k[idx] == (int32_t) key) {
                return {l[idx], h[idx]};
            }
        }
        return {MIN_SCORE, MAX_SCORE};
    }

    // Loads whitespace-separated "key score" pairs (the format the search command prints) as
    // exact scores, together with each key's left-right mirror (same score). Entries are
    // stored at depth 0, so they are pinned only when pinThresh >= 0. A file that cannot be
    // opened is reported on stderr and otherwise ignored.
    void load(const std::string& fname) {
        std::ifstream f(fname, std::ifstream::in);
        if (!f) {
            std::cerr << fname << " could not be opened; continuing without it" << std::endl;
            return;
        }
        int64_t key;
        int score;
        while (f >> key >> score) {
            put(key, score, score, 0);
            put(mirrow(key), score, score, 0);
        }
        std::cerr << fname << " loaded" << std::endl;
    }

private:
    const int pinThresh;
    std::unordered_map<int64_t, std::pair<int,int>> full_cache;  // pinned entries
    std::vector<int32_t> k;  // low 32 bits of the key stored in each slot
    std::vector<int8_t> l;   // lower score bound per slot
    std::vector<int8_t> h;   // upper score bound per slot
};
// Shift helpers for the column-major bitboard.
// UP/DOWN move bits by i rows within a column (i bits).
// LEFT/RIGHT move bits by i whole columns (i * (HEIGHT + 1) bits).
// The diagonal helpers combine both.
// Every parameter is parenthesized so that expression arguments keep their meaning
// (e.g. UP(a | b, 1) shifts the whole a | b, not just b).
#define UP(pos, i) ((pos) << (i))
#define DOWN(pos, i) ((pos) >> (i))
#define LEFT(pos, i) ((pos) >> ((i) * (HEIGHT + 1)))
#define RIGHT(pos, i) ((pos) << ((i) * (HEIGHT + 1)))
#define UP_LEFT(pos, i) UP(LEFT(pos, i), i)
#define DOWN_RIGHT(pos, i) DOWN(RIGHT(pos, i), i)
#define UP_RIGHT(pos, i) UP(RIGHT(pos, i), i)
#define DOWN_LEFT(pos, i) DOWN(LEFT(pos, i), i)
// Empty cells where the side owning `pos` would complete four in a row (vertically,
// horizontally, or on either diagonal). `mask` holds all occupied cells. A returned cell need
// not be playable yet; callers intersect with the legal moves when that matters.
int64_t get_winning_moves(int64_t pos, int64_t mask) {
    // Work in uint64_t. Shifting the signed pos left (by up to 24 bits, e.g. UP_RIGHT(pos, 3))
    // can push set bits past bit 63, which is undefined behaviour for signed integers in C++17.
    // Unsigned shifts simply drop those bits, and anything off the board is cleared by the
    // final FULL mask. This assumes pos never has a column's top (non-cell) bit set.
    const uint64_t p = static_cast<uint64_t>(pos);
    // vertical: three of ours directly below
    uint64_t res = UP(p, 1) & UP(p, 2) & UP(p, 3);
    // horizontal: every placement of the empty cell within a run of four
    res |= LEFT(p, 1) & LEFT(p, 2) & LEFT(p, 3);
    res |= RIGHT(p, 1) & LEFT(p, 1) & LEFT(p, 2);
    res |= RIGHT(p, 2) & RIGHT(p, 1) & LEFT(p, 1);
    res |= RIGHT(p, 3) & RIGHT(p, 2) & RIGHT(p, 1);
    // diagonal going down to the right
    res |= UP_LEFT(p, 1) & UP_LEFT(p, 2) & UP_LEFT(p, 3);
    res |= DOWN_RIGHT(p, 1) & UP_LEFT(p, 1) & UP_LEFT(p, 2);
    res |= DOWN_RIGHT(p, 2) & DOWN_RIGHT(p, 1) & UP_LEFT(p, 1);
    res |= DOWN_RIGHT(p, 3) & DOWN_RIGHT(p, 2) & DOWN_RIGHT(p, 1);
    // diagonal going up to the right
    res |= UP_RIGHT(p, 1) & UP_RIGHT(p, 2) & UP_RIGHT(p, 3);
    res |= DOWN_LEFT(p, 1) & UP_RIGHT(p, 1) & UP_RIGHT(p, 2);
    res |= DOWN_LEFT(p, 2) & DOWN_LEFT(p, 1) & UP_RIGHT(p, 1);
    res |= DOWN_LEFT(p, 3) & DOWN_LEFT(p, 2) & DOWN_LEFT(p, 1);
    // keep only empty board cells
    return static_cast<int64_t>(res & static_cast<uint64_t>(FULL ^ mask));
}
// All HEIGHT board cells of column `col` (0-based).
constexpr int64_t get_column_mask(int col) {
    const int64_t one_column = (int64_t{1} << HEIGHT) - 1;
    return one_column << (col * (HEIGHT + 1));
}
// The bottom cell of column `col` (0-based).
constexpr int64_t get_bottom_mask(int col) {
    const int shift = col * (HEIGHT + 1);
    return int64_t{1} << shift;
}
// Number of distinct cells where `pos`'s owner would complete a four (see get_winning_moves).
int count_winning_moves(int64_t pos, int64_t mask) {
    int count = 0;
    // Each iteration clears the lowest set bit, so the loop runs once per threat cell.
    for (int64_t threats = get_winning_moves(pos, mask); threats != 0; threats &= threats - 1) {
        ++count;
    }
    return count;
}
class BitBoard {
public:
BitBoard(int64_t pos = 0, int64_t mask = 0, int moves = 0):pos(pos), mask(mask), moves(moves) {}
char get_next_move_type() {
return moves % 2 == 0 ? 'X' : 'O';
}
// seq: sequence of chars, each char is a move.
// '1' -> col 0
// '2' -> col 1
// .etc
BitBoard(const std::string& seq): pos(0), mask(0), moves(0) {
for (int i = 0; i < seq.length(); i++) {
int col = seq[i] - '1';
pos ^= mask;
mask |= mask + get_bottom_mask(col);
moves++;
}
}
int64_t key() {
return pos + mask;
}
BitBoard make_move(int64_t move) {
assert(move);
return BitBoard(pos ^ mask, mask | move, moves + 1);
}
int64_t get_legal_moves() {
return (mask + BOTTOM) & FULL;
}
int64_t get_non_losing_moves() {
int64_t oppo_winning_moves = get_winning_moves(pos ^ mask, mask);
int64_t legal_moves = get_legal_moves();
int64_t forced_moves = legal_moves & oppo_winning_moves;
if (forced_moves) {
if (forced_moves & (forced_moves - 1)) {
// more than 1 forced moves
return 0;
}
legal_moves = forced_moves;
}
return legal_moves & ~(oppo_winning_moves >> 1);
}
bool canWinWithOneMove() {
return get_winning_moves(pos, mask) & get_legal_moves();
}
bool is_winning_move(int64_t move) {
return move & get_winning_moves(pos, mask);
}
void sort_moves(int64_t* res, int n) {
std::array<int, WIDTH> score;
for (int i = 0; i < n; i++) {
score[i] = count_winning_moves(pos | res[i], mask);
}
for (int i = 1; i < n; i++) {
int64_t t = res[i];
int s = score[i];
int j = i;
while (j && score[j-1] < s) {
res[j] = res[j - 1];
score[j] = score[j - 1];
j--;
}
res[j] = t;
score[j] = s;
}
}
void print() {
for (int i = HEIGHT - 1; i >= 0; i--) {
for (int j = 0; j < WIDTH; j++) {
int64_t t = 1LL << ((HEIGHT + 1) * j + i);
if (mask & t) {
if ((bool)(pos & t) == (moves % 2 == 0)) {
std::cout << " X";
} else {
std::cout << " O";
}
} else {
std::cout << " -";
}
}
std::cout << std::endl;
}
for (int j = 0; j < WIDTH; j++) {
std::cout << " " << j + 1;
}
std::cout << std::endl;
}
int64_t pos;
int64_t mask;
int moves;
};
// Connect 4 solver: alpha-beta negamax over BitBoard positions with a transposition table of
// [low, hi] score bounds, driven by null-window searches in solve().
// Score convention visible in solve(): a win by the player to move with their next disc scores
// (WIDTH * HEIGHT - moves + 1) / 2, a draw scores 0, and positive scores favour the player to move.
class Solver {
public:
    Solver(Table& table): table(table), nodeCount(0) {}
    // Resets the visited-node counter; the transposition table is not touched.
    void reset() {
        nodeCount = 0;
    }
    // Negamax of position `a` within the window [alpha, beta]. Returns the exact score when it
    // lies strictly inside the window, otherwise a bound on the side where it fell. Every result
    // is recorded in the table as [low, hi] bounds.
    // NOTE(review): like Pascal Pons' solver, this appears to assume the side to move cannot win
    // immediately (solve() checks it at the root, and children are reached only through
    // get_non_losing_moves). Confirm before calling it from anywhere else.
    int negamax(BitBoard a, int alpha, int beta) {
        nodeCount++;
        int64_t moves = a.get_non_losing_moves();
        // Every move lets the opponent win with their next disc.
        if (!moves) {
            return -(HEIGHT * WIDTH - a.moves) / 2;
        }
        // At most two cells left and no immediate loss: scored as a draw.
        if (a.moves >= WIDTH * HEIGHT - 2) {
            return 0;
        }
        auto key = a.key();
        auto [low, hi] = table.get(key, a.moves);
        // Tighten the stored bounds with the best and worst scores still reachable from here.
        hi = std::min(hi, (HEIGHT * WIDTH - a.moves - 1) / 2);
        low = std::max(low, -(HEIGHT * WIDTH - a.moves - 2) / 2);
        if (low == hi) {
            return low;
        }
        if (low >= beta) {
            return low;
        }
        if (hi <= alpha) {
            return hi;
        }
        alpha = std::max(alpha, low);
        beta = std::min(hi, beta);
        int score = MIN_SCORE;
        // Narrowed window start, used below to classify the result before storing it.
        int alpha0 = alpha;
        std::array<int64_t, WIDTH> cand;
        int n = 0;
        // Collect candidate moves centre column first, then alternately one step left and one
        // step right (3, 2, 4, 1, 5, 0, 6 for WIDTH == 7).
        for (int i = 0; i < WIDTH; i++) {
            int idx = WIDTH / 2 + (1 - 2 * (i % 2)) * (i + 1) / 2;
            int64_t move = get_column_mask(idx) & moves;
            if (move) {
                cand[n++] = move;
            }
        }
        // Most threat-creating moves first; ties keep the centre-first order.
        a.sort_moves(cand.data(), n);
        for (int i = 0; i < n; i++) {
            BitBoard b = a.make_move(cand[i]);
            score =std::max(score, -negamax(b, -beta, -alpha));
            if (score >= beta) {
                break;
            }
            if (score > alpha) {
                alpha = score;
            }
        }
        alpha = alpha0;
        // Store an exact score, an upper bound (fail-low) or a lower bound (fail-high).
        if (score > alpha && score < beta) {
            table.put(key, score, score, a.moves);
        } else if (score <= alpha) {
            table.put(key, low, score, a.moves);
        } else {
            table.put(key, score, hi, a.moves);
        }
        return score;
    }
    // Exact score of `b` for the player to move. Immediate wins are answered directly;
    // otherwise the score is bracketed in [min, max] and narrowed with (med, med + 1) searches.
    int solve(BitBoard b) {
        if (b.canWinWithOneMove()) {
            return (WIDTH * HEIGHT - b.moves + 1) / 2;
        }
        int min = -(WIDTH * HEIGHT - b.moves) / 2;
        int max = (WIDTH * HEIGHT + 1 - b.moves) / 2;
        while(min < max) { // iteratively narrow the [min, max] bracket
            int med = min + (max - min)/2;
            // Bias probes toward 0 (halfway to min or max), as in Pascal Pons' solver.
            if(med <= 0 && min/2 < med) med = min/2;
            else if(med >= 0 && max/2 > med) med = max/2;
            int r = negamax(b, med, med + 1); // null window: only tells whether the score is above med
            if(r <= med) max = r;
            else min = r;
        }
        return min;
    }
    int64_t nodeCount;  // positions visited by negamax since the last reset()
    Table& table;
};
// Walks the positions reachable from a start position through non-losing moves, down to
// `depth` discs, and prints "<key> <score>" once per distinct position (the format that
// Table::load reads back).
class Searcher {
public:
    Searcher(Solver& solver, int depth): solver(solver), depth(depth) {}

    // Does nothing when the start position already has more than `depth` discs.
    void search(BitBoard b) {
        if (b.moves > depth) {
            return;
        }
        dfs(b);
    }

private:
    void dfs(BitBoard b) {
        const auto key = b.key();
        if (printed.count(key) == 0) {
            std::cout << key << " " << solver.solve(b) << std::endl;
            printed.insert(key);
        }
        if (b.moves < depth) {
            // Visit each non-losing move, lowest bit (leftmost column) first.
            for (auto remaining = b.get_non_losing_moves(); remaining != 0;) {
                const int64_t lowest = remaining & -remaining;
                dfs(b.make_move(lowest));
                remaining ^= lowest;
            }
        }
    }

    std::unordered_set<int64_t> printed;  // keys already emitted
    Solver& solver;
    const int depth;
};
// Cell where a disc dropped into column `col` (0-based) would land, or 0 when the column is
// full or `col` is outside [0, WIDTH).
int64_t column_to_move(BitBoard& b, int col) {
    if (col < 0 || col >= WIDTH) {
        // Out of range: get_column_mask would shift by a negative or oversized amount.
        return 0;
    }
    return get_column_mask(col) & (b.mask + BOTTOM);
}
// 0-based index of the column containing the highest set bit of `move`; -1 when `move` is 0.
int move_to_column(int64_t move) {
    int col = -1;
    for (; move != 0; move >>= HEIGHT + 1) {
        ++col;
    }
    return col;
}
// A player that chooses the next move for a position.
class Agent {
public:
    virtual ~Agent() = default;
    // Returns the chosen move as a cell bitmask (normally a single bit, as in column_to_move).
    virtual int64_t get_move(BitBoard& board) = 0;
    // Name shown in prompts and result messages.
    virtual std::string getName() = 0;
};
class Human: public Agent {
public:
virtual int64_t get_move(BitBoard& b) override {
std::cout << getName() << "> ";
int col;
std::cin >> col;
return column_to_move(b, col - 1);
}
virtual std::string getName() override {
return "Human";
}
};
// Solver-backed player.
class AI : public Agent {
public:
    AI(Solver& solver, const std::string& name="AI"): Agent(), solver(solver), name(name) {}

    // Chooses, in order of preference:
    // 1. an immediately winning move;
    // 2. otherwise the non-losing move with the highest solved score (all candidates are
    //    printed as "column:score");
    // 3. otherwise, when every move loses, the first legal move (leftmost playable column).
    // The chosen 1-based column is printed last.
    virtual int64_t get_move(BitBoard& b) override {
        const int64_t winning_moves = get_winning_moves(b.pos, b.mask) & b.get_legal_moves();
        int64_t res = 0;
        if (winning_moves) {
            res = winning_moves & -winning_moves;
        } else {
            auto moveScores = get_move_scores(b);
            if (moveScores.empty()) {
                std::cout << std::endl;
                // Every move loses: just play the first legal one.
                auto moves = b.get_legal_moves();
                res = moves & -moves;
            } else {
                std::cout << getName() << "> " << std::flush;
                // Seed with the first candidate so res always holds a real move, whatever the
                // scores are; the first move reaching the maximum score wins ties.
                res = moveScores.front().first;
                int max = moveScores.front().second;
                for (const auto& [move, score] : moveScores) {
                    if (score > max) {
                        max = score;
                        res = move;
                    }
                    std::cout << move_to_column(move) + 1 << ":" << score << " ";
                }
                std::cout << std::endl;
            }
        }
        std::cout << getName() << "> " << move_to_column(res) + 1 << std::endl;
        return res;
    }
    virtual std::string getName() override {
        return name;
    }
private:
    // Exact score (from the mover's point of view) of every non-losing move, lowest bit first.
    std::vector<std::pair<int64_t,int>> get_move_scores(BitBoard& b) {
        int64_t moves = b.get_non_losing_moves();
        std::vector<std::pair<int64_t,int>> res;
        while (moves) {
            int64_t move = moves & -moves;
            res.emplace_back(move, -solver.solve(b.make_move(move)));
            moves -= move;
        }
        return res;
    }
    Solver& solver;
    std::string name;
};
// Prints the command-line usage, one line per entry, to stdout and ends the process with exit(-1).
void printHelpAndExit() {
    static const char* const kUsage[] = {
        " ./main solve # solve game states, take input from stdin.",
        " ./main search [-s <starting-moves>] [-d <depth>] # compute and print score table up to given depth (default 8)",
        " ./main play # play with solver with text UI",
        " ./main suggest [-s <starting-moves>] [-d <depth>] # suggest the best move for the next player",
        " Optional flags:",
        " -l <score-table-file> # load score table",
        " -t <pin-score-depth-thresh> pin score into cache if depth <= this threshold",
        " -a1 [ai|human] agent 1 when playing, default: human",
        " -a2 [ai|human] agent 2 when playing, default: ai",
        " -bp <bitboard-pos> # specify BitBoard position",
        " -bm <bitboard-mask> # specify BitBoard mask",
        " -bo <bitboard-moves> # specify BitBoard moves",
    };
    for (const char* line : kUsage) {
        std::cout << line << std::endl;
    }
    exit(-1);
}
// Parsed command-line options. The defaults apply when a flag is absent.
struct Args {
    std::string cmd;                  // argv[1]: solve | search | play | suggest
    std::string startingMoves;        // -s
    int depth{8};                     // -d
    std::string scoreTableFile;       // -l
    int pinScoreDepthThreshold{-1};   // -t
    bool weakSolver{false};           // never set by parseArgs
    std::string agent1{"human"};      // -a1
    std::string agent2{"ai"};         // -a2
    int64_t bitBoardPos{0};           // -bp
    int64_t bitBoardMask{0};          // -bm
    int bitBoardMoves{0};             // -bo
};
// Parses `argv` into Args. argv[1] is the command, followed by "-X <value>" flags (see
// printHelpAndExit). A malformed invocation prints the usage and exits: no command, a
// non-flag argument, an unknown flag, or a flag missing its value.
Args parseArgs(int argc, char** argv) {
    Args args;
    if (argc <= 1) {
        printHelpAndExit();
    }
    args.cmd = argv[1];
    // Returns the value following the flag at argv[i] and advances i past it. The bounds check
    // keeps a trailing flag from reading argv[argc], which is a null pointer.
    auto next_value = [argc, argv](int& i) -> const char* {
        if (i + 1 >= argc) {
            printHelpAndExit();
        }
        return argv[++i];
    };
    for (int i = 2; i < argc; i++) {
        const char* s = argv[i];
        if (s[0] != '-') {
            printHelpAndExit();
        }
        switch (s[1]) {
            case 's':
                args.startingMoves = next_value(i);
                break;
            case 'd':
                args.depth = std::atoi(next_value(i));
                break;
            case 'l':
                args.scoreTableFile = next_value(i);
                break;
            case 't':
                args.pinScoreDepthThreshold = std::atoi(next_value(i));
                break;
            case 'a':
                if (s[2] == '1') {
                    args.agent1 = next_value(i);
                } else if (s[2] == '2') {
                    args.agent2 = next_value(i);
                } else {
                    printHelpAndExit();
                }
                break;
            case 'b':
                if (s[2] == 'p') {
                    args.bitBoardPos = std::stoll(next_value(i));
                } else if (s[2] == 'm') {
                    args.bitBoardMask = std::stoll(next_value(i));
                } else if (s[2] == 'o') {
                    args.bitBoardMoves = std::atoi(next_value(i));
                } else {
                    printHelpAndExit();
                }
                break;
            default:
                printHelpAndExit();
                break;
        }
    }
    return args;
}
// Clears the terminal and moves the cursor to the top-left corner (ANSI escape sequences).
void clearScreen() {
    constexpr const char* kClearAndHome = "\033[2J\033[1;1H";
    std::cout << kClearAndHome;
}
class GameRunner {
public:
void play(BitBoard& b, std::array<std::unique_ptr<Agent>, 2>& agents) {
int move;
int turn = 0;
while (1) {
clearScreen();
b.print();
std::cout << b.get_next_move_type() << " playing" << std::endl;
int64_t move = agents[turn]->get_move(b);
if (b.is_winning_move(move)) {
clearScreen();
b.make_move(move).print();
std::cout << agents[turn]->getName() << " " << b.get_next_move_type() << " wins" << std::endl;
break;
}
b = b.make_move(move);
turn = 1 - turn;
sleep(1);
if (b.moves == HEIGHT * WIDTH) {
std::cout << "draw" << std::endl;
break;
}
}
}
int play(BitBoard& b, std::unique_ptr<Agent>& agent) {
int64_t move = agent->get_move(b);
return move_to_column(move) + 1;
}
};
// Creates the agent named on the command line: "human" reads moves from stdin, and any other
// name yields the solver-backed AI (displayed as "ai").
std::unique_ptr<Agent> makeAgent(Solver& solver, const std::string& name) {
    if (name != "human") {
        return std::make_unique<AI>(solver, "ai");
    }
    return std::make_unique<Human>();
}
// Entry point. Parses the arguments, builds the transposition table (optionally preloading a
// score table), then runs one of the solve / search / play / suggest commands.
int main(int argc, char** argv) {
    Args args = parseArgs(argc, argv);
    Table table(args.pinScoreDepthThreshold);
    if (!args.scoreTableFile.empty()) {
        table.load(args.scoreTableFile);
    }
    Solver solver(table);
    if (args.cmd == "solve") {
        // Each stdin record is "<move-sequence> <expected-score>"; stop at the first mismatch.
        std::string s;
        int score;
        int testId = 1;
        std::chrono::duration<double> totalDuration = std::chrono::duration<double>::zero();
        while (std::cin >> s >> score) {
            BitBoard b(s);
            b.print();
            solver.reset();
            auto st = std::chrono::high_resolution_clock::now();
            int v = solver.solve(b);
            std::chrono::duration<double> duration = std::chrono::high_resolution_clock::now() - st;
            totalDuration += duration;
            if (score == v) {
                std::cout << "test " << testId++ << " pass " << duration.count() << "s node_count " << solver.nodeCount << std::endl;
            } else {
                std::cout << "test " << testId++ << " fail: " << v << "!=" << score << std::endl;
                break;
            }
        }
        std::cout << "total duration: " << totalDuration.count() << "s" << std::endl;
    } else if (args.cmd == "search") {
        BitBoard b(args.startingMoves);
        Searcher searcher(solver, args.depth);
        searcher.search(b);
    } else if (args.cmd == "play") {
        BitBoard b(args.startingMoves);
        std::array<std::unique_ptr<Agent>, 2> agents = {
            makeAgent(solver, args.agent1),
            makeAgent(solver, args.agent2),
        };
        GameRunner().play(b, agents);
    } else if (args.cmd == "suggest") {
        // The position comes from -bp/-bm/-bo. The last stdout line is the suggested 1-based
        // column, which is the line the Python driver parses.
        BitBoard b(args.bitBoardPos, args.bitBoardMask, args.bitBoardMoves);
        auto agent = makeAgent(solver, "ai");
        int bestMoveColumn = GameRunner().play(b, agent);
        std::cout << bestMoveColumn << std::endl;
        return 0;
    } else {
        printHelpAndExit();
    }
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment