// Last active: May 28, 2020 03:29
// Save Shikugawa/3079dc7be8c50710c4e5bc10bb8c181c to your computer and use it in GitHub Desktop.
// 強化学習のテンプレ(オセロ+ミニマックスAI)
//   (English: "Reinforcement-learning template (Othello + minimax AI)")
// This file contains hidden or bidirectional Unicode text that may be interpreted or compiled
// differently than what appears below. To review, open the file in an editor that reveals
// hidden Unicode characters. Learn more about bidirectional Unicode characters.
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <limits>
#include <optional>
#include <ostream>
#include <set>
#include <vector>
class MinMaxPlayer;

/// Othello (Reversi) game state on an 8x8 board.
///
/// Tracks the board contents, the side to move and the list of legal moves
/// for that side. A move is applied with `update()`, which flips the
/// captured discs, hands the turn to the opponent and hands it straight back
/// if the opponent has no legal move (a pass).
class Othello {
public:
  /// Content of a single board cell.
  enum class CellState {
    EMPTY,
    BLACK,
    WHITE,
  };

  /// Board coordinate; `x` is the column and `y` is the row.
  struct Position {
    size_t x;
    size_t y;

    bool operator==(const Position &q) const { return x == q.x && y == q.y; }

    /// True when both coordinates lie inside the board.
    bool validate() const { return x < Board::CELL_NUM && y < Board::CELL_NUM; }
  };

  /// Side to move.
  enum class Turn {
    BLACK, WHITE
  };

  /// Creates a game in the standard start position with `default_turn` to move.
  Othello(Turn default_turn) : current_turn_(default_turn) {
    current_availables_ = availables(default_turn);
  }

  /// Plays `pos` for the side to move.
  ///
  /// The call is ignored when the game is over, when `pos` is off the board
  /// or when `pos` is not one of the current legal moves. Otherwise the
  /// captured discs are flipped, the disc is placed and the turn passes to
  /// the opponent -- unless the opponent has no legal move, in which case
  /// the same side moves again.
  void update(const Position &pos) {
    if (gameend_ || !pos.validate()) {
      return;
    }
    // Reject anything that is not a legal move: writing such a disc would
    // corrupt both the board and the disc counters.
    if (std::find(current_availables_.begin(), current_availables_.end(),
                  pos) == current_availables_.end()) {
      return;
    }

    const Turn mover = current_turn_;
    const CellState self = toCellState(mover);
    // The flips must be computed while `pos` is still empty; convertables()
    // reports nothing for an occupied cell.
    for (const auto &captured : convertables(pos, self)) {
      board_.put(captured, self);
    }
    board_.put(pos, self);

    current_turn_ = opponentOf(mover);
    current_availables_ = availables(current_turn_);
    if (current_availables_.empty()) {
      // Opponent has to pass; the mover plays again (if it can).
      current_turn_ = mover;
      current_availables_ = availables(current_turn_);
    }
    gameend_ = isEnd();
  }

  /// Returns every empty cell where `turn` would flip at least one disc,
  /// scanned row by row.
  std::vector<Position> availables(Turn turn) const {
    std::vector<Position> available_cells;
    const CellState state = toCellState(turn);
    for (size_t row = 0; row < Board::CELL_NUM; ++row) {
      for (size_t col = 0; col < Board::CELL_NUM; ++col) {
        const Position cell{col, row};
        if (board_.get(cell) != CellState::EMPTY) {
          continue;
        }
        if (!convertables(cell, state).empty()) {
          available_cells.emplace_back(cell);
        }
      }
    }
    return available_cells;
  }

  /// True when the board is full or the side to move has no legal move.
  /// After `update()` the latter means neither side can move, because a
  /// single-sided pass has already been resolved there.
  bool isEnd() const {
    const size_t cell_count =
        static_cast<size_t>(Board::CELL_NUM) * Board::CELL_NUM;
    return board_.black_ + board_.white_ == cell_count ||
           current_availables_.empty();
  }

  /// Side that moves next.
  Turn currentTurn() const { return current_turn_; }

  /// Legal moves for the side that moves next.
  const std::vector<Position> &currentAvailables() const {
    return current_availables_;
  }

private:
  /// The eight directions a capture line can run in.
  enum class Direction {
    LEFT,
    RIGHT,
    UP,
    DOWN,
    LEFT_UP,
    RIGHT_UP,
    LEFT_DOWN,
    RIGHT_DOWN,
  };

  /// Disc colour played by `turn`.
  static CellState toCellState(Turn turn) {
    return turn == Turn::BLACK ? CellState::BLACK : CellState::WHITE;
  }

  /// The other side.
  static Turn opponentOf(Turn turn) {
    return turn == Turn::BLACK ? Turn::WHITE : Turn::BLACK;
  }

  /// Returns all discs that would be flipped if `self` were placed on `pos`.
  /// Empty when `pos` is off the board, already occupied or captures nothing.
  std::vector<Position> convertables(const Position &pos,
                                     CellState self) const {
    if (!pos.validate() || board_.get(pos) != CellState::EMPTY) {
      return std::vector<Position>();
    }
    // Fixed table instead of a heap-allocated std::set built on every call.
    constexpr Direction kAllDirections[] = {
        Direction::LEFT,      Direction::RIGHT,    Direction::UP,
        Direction::DOWN,      Direction::LEFT_UP,  Direction::RIGHT_UP,
        Direction::LEFT_DOWN, Direction::RIGHT_DOWN};
    std::vector<Position> convert_positions;
    for (const Direction direction : kAllDirections) {
      const auto captured = shouldConverts(pos, direction, self);
      if (captured) {
        convert_positions.insert(convert_positions.end(), captured->begin(),
                                 captured->end());
      }
    }
    return convert_positions;
  }

  /// Walks from `current_pos` in direction `dir` and collects the discs that
  /// are not `self` until a `self` disc closes the line.
  ///
  /// Returns the collected run (possibly empty when a `self` disc is directly
  /// adjacent), or `std::nullopt` when the line hits an empty cell or the
  /// board edge before reaching a `self` disc.
  std::optional<std::vector<Position>>
  shouldConverts(const Position &current_pos, Direction dir,
                 CellState self) const {
    std::vector<Position> run;
    for (Position cursor = nextPosition(current_pos, dir);
         cursor.validate() && board_.get(cursor) != CellState::EMPTY;
         cursor = nextPosition(cursor, dir)) {
      if (board_.get(cursor) == self) {
        return run;
      }
      run.emplace_back(cursor);
    }
    return std::nullopt;
  }

  /// Neighbour of `pos` in direction `dir`.
  ///
  /// Coordinates are unsigned, so stepping left/up from 0 wraps around to a
  /// huge value; that is intentional and is rejected by `validate()`.
  Position nextPosition(const Position &pos, Direction dir) const {
    switch (dir) {
    case Direction::UP:
      return Position{pos.x, pos.y - 1};
    case Direction::DOWN:
      return Position{pos.x, pos.y + 1};
    case Direction::RIGHT:
      return Position{pos.x + 1, pos.y};
    case Direction::LEFT:
      return Position{pos.x - 1, pos.y};
    case Direction::LEFT_UP:
      return Position{pos.x - 1, pos.y - 1};
    case Direction::RIGHT_DOWN:
      return Position{pos.x + 1, pos.y + 1};
    case Direction::LEFT_DOWN:
      return Position{pos.x - 1, pos.y + 1};
    case Direction::RIGHT_UP:
      return Position{pos.x + 1, pos.y - 1};
    }
    // Unreachable for valid enumerators; return an off-board position rather
    // than falling off the end of a non-void function.
    return Position{Board::CELL_NUM, Board::CELL_NUM};
  }

  Turn current_turn_;

  /// Cell storage plus running disc counters.
  struct Board {
    static constexpr uint8_t CELL_NUM = 8;
    using BoardType = std::array<std::array<CellState, CELL_NUM>, CELL_NUM>;

    /// Sets up the four centre discs of the start position.
    Board() {
      for (auto &row : board_) {
        row.fill(CellState::EMPTY);
      }
      board_[3][3] = CellState::BLACK;
      board_[4][4] = CellState::BLACK;
      board_[3][4] = CellState::WHITE;
      board_[4][3] = CellState::WHITE;
    }

    /// Content of the cell at `pos` (which must be on the board).
    CellState get(const Position &pos) const {
      assert(pos.validate());
      return board_[pos.y][pos.x];
    }

    /// Stores a `state` disc at `pos` and keeps `black_` / `white_` in sync.
    /// Re-putting the colour a cell already holds is a no-op, so the
    /// counters cannot be inflated.
    void put(const Position &pos, CellState state) {
      assert(state != CellState::EMPTY);
      assert(pos.validate());
      CellState &cell = board_[pos.y][pos.x];
      if (cell == state) {
        return;
      }
      if (cell == CellState::WHITE) {
        --white_;
      } else if (cell == CellState::BLACK) {
        --black_;
      }
      cell = state;
      if (state == CellState::BLACK) {
        ++black_;
      } else {
        ++white_;
      }
    }

    size_t black_{2}; // number of black discs on the board
    size_t white_{2}; // number of white discs on the board

  private:
    BoardType board_;
  };

  Board board_;
  bool gameend_{false}; // was left uninitialised and then read in update()
  std::vector<Position> current_availables_;

  friend std::ostream &operator<<(std::ostream &, const Othello &);
  friend class MinMaxPlayer;
};
std::ostream &operator<<(std::ostream &os, const Othello &othello) { | |
os << "black: " << othello.board_.black_ << std::endl; | |
os << "white: " << othello.board_.white_ << std::endl; | |
for (auto i = -1; i < Othello::Board::CELL_NUM; ++i) { | |
if (i == -1) { | |
os << " "; | |
} else { | |
os << i << " "; | |
} | |
} | |
os << std::endl; | |
size_t col_num = 0; | |
for (size_t i = 0; i < Othello::Board::CELL_NUM; ++i) { | |
os << col_num << " "; | |
++col_num; | |
for (size_t j = 0; j < Othello::Board::CELL_NUM; ++j) { | |
auto position = Othello::Position{j, i}; | |
if (std::find(othello.current_availables_.begin(), | |
othello.current_availables_.end(), | |
position) != othello.current_availables_.end()) { | |
os << "× "; | |
} else { | |
const auto state = othello.board_.get(position); | |
if (state == Othello::CellState::EMPTY) { | |
os << " "; | |
continue; | |
} | |
const auto symbol = othello.board_.get(Othello::Position{j, i}) == | |
Othello::CellState::WHITE | |
? "◯" | |
: "●"; | |
os << symbol << " "; | |
} | |
} | |
os << std::endl; | |
} | |
return os; | |
} | |
class MinMaxPlayer { | |
public: | |
MinMaxPlayer(Othello::Turn player_role) : player_role_(player_role) {} | |
Othello::Position selectPosition(Othello game) { | |
uint32_t score = 0; | |
Othello::Position pos; | |
for (auto pos_ : game.availables(player_role_)) { | |
game.board_.put(static_cast<const Othello::Position>(pos_), | |
player_role_ == Othello::Turn::BLACK ? Othello::CellState::BLACK | |
: Othello::CellState::WHITE); | |
auto next = player_role_ == Othello::Turn::BLACK ? Othello::Turn::WHITE | |
: Othello::Turn::BLACK; | |
const auto current_score = evaluation(game, next); | |
if (current_score > score) { | |
score = current_score; | |
pos = pos_; | |
} | |
} | |
return pos; | |
} | |
uint32_t evaluation(Othello game, Othello::Turn current_role, int current_depth = 0) { | |
if (current_depth == max_depth_) { | |
return current_role == Othello::Turn::WHITE ? game.board_.white_ : game.board_.black_; | |
} | |
const auto available_cells = game.availables(current_role); | |
uint32_t min_score = std::numeric_limits<uint32_t>::max(); | |
uint32_t max_score = 0; | |
for (const auto &available_cell : available_cells) { | |
auto child_game = game; | |
auto cell_state = | |
current_role == Othello::Turn::BLACK ? Othello::CellState::BLACK : Othello::CellState::WHITE; | |
child_game.board_.put(available_cell, cell_state); | |
auto evaluation_score = evaluation(child_game, current_role == Othello::Turn::BLACK ? Othello::Turn::WHITE | |
: Othello::Turn::BLACK, | |
1 + current_depth); | |
if (current_role == player_role_ && max_score < evaluation_score) { | |
max_score = evaluation_score; | |
continue; | |
} | |
if (current_role != player_role_ && min_score > evaluation_score) { | |
min_score = evaluation_score; | |
continue; | |
} | |
} | |
return current_role == player_role_ ? max_score : min_score; | |
} | |
private: | |
size_t max_depth_{5}; // 5手先読み | |
Othello::Turn player_role_; | |
}; | |
int main() { | |
Othello::Turn player_turn(Othello::Turn::BLACK); | |
Othello::Turn enemy_turn(Othello::Turn::WHITE); | |
Othello game(player_turn); | |
MinMaxPlayer enemy(enemy_turn); | |
while (!game.isEnd()) { | |
std::cout << game << std::endl; | |
Othello::Position pos; | |
if (game.currentTurn() == player_turn) { | |
size_t x, y; | |
std::cout << "x: "; | |
std::cin >> x; | |
std::cout << "y: "; | |
std::cin >> y; | |
std::cout << std::endl; | |
pos = Othello::Position{x, y}; | |
} else { | |
pos = enemy.selectPosition(game); | |
} | |
if (!pos.validate() || std::find(game.currentAvailables().begin(), | |
game.currentAvailables().end(), | |
pos) == game.currentAvailables().end()) { | |
continue; | |
} | |
game.update(pos); | |
} | |
} |
// Sign up for free to join this conversation on GitHub.
// Already have an account? Sign in to comment.