Created
July 24, 2015 03:43
-
-
Save shrubb/e79c86ccb5ab1e4d36f2 to your computer and use it in GitHub Desktop.
DQN helper
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <deque>
#include <fstream>
#include <iostream>
#include <limits>
#include <string>
#include <vector>

#include <ale_interface.hpp>
#include <glog/logging.h>
#include <gflags/gflags.h>

#include "prettyprint.hpp"
#include "dqn.hpp"
// ---------------------------------------------------------------------------
// Command-line flags (gflags). File-path defaults point at the author's
// local checkout and will usually be overridden on the command line.
// ---------------------------------------------------------------------------

// Execution backend and display.
DEFINE_bool(gpu, false,
            "Use GPU to brew Caffe");
DEFINE_bool(gui, false,
            "Open a GUI window");
DEFINE_bool(show_frame, false,
            "Show the current frame in CUI");

// Game selection.
// NOTE(review): the default game_name ("gopher") does not match the default
// ROM, solver and model below (all tutankham); confirm which game is intended.
DEFINE_string(rom,
              "/home/shrubb/Projects/deephack/games/tutankham.bin",
              "Atari 2600 ROM to play");
DEFINE_string(game_name,
              "gopher",
              "Game name");

// Network definition and trained weights.
DEFINE_string(solver,
              "/home/shrubb/Projects/deephack/Net/Snapshot-tutankham/dqn_tutankham_solver.prototxt",
              "Solver parameter file (*.prototxt)");
DEFINE_string(model,
              "/home/shrubb/Projects/deephack/Net/Snapshot-tutankham/tutankham_iter_140000.caffemodel",
              "Model file to load");

// Learning hyper-parameters.
DEFINE_int32(memory, 500000,
             "Capacity of replay memory");
DEFINE_int32(memory_threshold, 100,
             "Enough amount of transitions to start learning");
DEFINE_int32(explore, 1000000,
             "Number of iterations needed for epsilon to reach 0.1");
DEFINE_double(gamma, 0.95,
              "Discount factor of future rewards (0,1]");
DEFINE_int32(skip_frame, 3,
             "Number of frames skipped");

// Evaluation mode.
DEFINE_bool(evaluate, true,
            "Evaluation mode: only playing a game, no updates");
DEFINE_double(evaluate_with_epsilon, 0.05,
              "Epsilon value to be used in evaluation mode");
// NOTE(review): a number of games declared as a double; DEFINE_int32 looks intended.
DEFINE_double(repeat_games, 30,
              "Number of games played in evaluation mode");
double CalculateEpsilon(const int iter) { | |
if (iter < FLAGS_explore) { | |
return 1.0 - 0.9 * (static_cast<double>(iter) / FLAGS_explore); | |
} else { | |
return 0.1; | |
} | |
} | |
/// Converts a single hexadecimal digit character to its value (0-15).
///
/// Accepts '0'-'9', 'A'-'F' and 'a'-'f'. Any other character yields 0, so
/// malformed input cannot produce an out-of-range nibble.
///
/// @param x  Character to decode.
/// @return   Numeric value of the digit, or 0 if x is not a hex digit.
unsigned char hex(char x) {
    if (x >= '0' && x <= '9') {
        return static_cast<unsigned char>(x - '0');
    }
    if (x >= 'A' && x <= 'F') {
        return static_cast<unsigned char>(10 + (x - 'A'));
    }
    if (x >= 'a' && x <= 'f') {
        return static_cast<unsigned char>(10 + (x - 'a'));
    }
    return 0;
}
// Running reward statistics, written by read_screen() and reported by main().
// Frames parsed since the last logged (non-zero) reward; debug statistic only.
int zero_count{0};
// Sum of all rewards parsed from the server so far.
int total_score{0};
bool read_screen(std::vector<std::vector<unsigned char>> &raw_screen, ALEInterface & interface, int iter){ | |
bool term = false; | |
char a, b; | |
char terminate; | |
char reward[10]; | |
const ALEScreen * screen_const = & interface.getScreen(); | |
ALEScreen * screen = const_cast<ALEScreen*>(screen_const); | |
for (int i = 0; i < 210; ++i) { | |
for (int j = 0; j < 160; ++j) { | |
int scan_res = fscanf(stdin, "%c%c", &a, &b); | |
if (b == 'I') { // DIE | |
term = true; | |
break; | |
} | |
raw_screen[i][j] = hex(a) * (unsigned char)16 + hex(b); | |
*screen->pixel(i, j) = raw_screen[i][j]; | |
} | |
} | |
char temp; | |
int scan_res = fscanf(stdin, "%c", &temp); // : | |
scan_res = fscanf(stdin, "%c", &terminate); | |
scan_res = fscanf(stdin, "%c", &temp); // : | |
std::cin.get(reward, 9, ':'); | |
if (reward[0] != '0') { | |
std::cerr << zero_count << " zeros\n" << reward << std::endl; | |
zero_count = 0; | |
total_score += atoi(reward); | |
} else { | |
zero_count++; | |
} | |
if (terminate == '1') { | |
term = true; | |
} | |
//char temp[70000]; | |
//fgets(temp, 69999, stdin); | |
std::cin.ignore(std::numeric_limits<std::streamsize>::max(), '\n'); | |
if (iter % 25 == 0) { | |
interface.saveScreenPNG(std::string("/home/shrubb/screen") + std::to_string(iter) + std::string(".png")); | |
} | |
return term; | |
} | |
void make_action(ALEInterface ale, Action action){ | |
fprintf(stdout, "%d,18\n", action); | |
fflush(stdout); | |
if(FLAGS_gui) | |
ale.act(action); | |
return; | |
} | |
int main(int argc, char** argv) { | |
/*int test; | |
std::ifstream in("/home/shrubb/test.txt"); | |
for (int i = 0; i < 10; ++i) { | |
test = in.get(); | |
std::cout << test << std::endl; | |
} | |
return 0;*/ | |
//google:: | |
//minloglevel=google::ERROR; | |
gflags::ParseCommandLineFlags(&argc, &argv, true); | |
google::InitGoogleLogging(argv[0]); | |
google::InstallFailureSignalHandler(); | |
google::LogToStderr(); | |
if (FLAGS_gpu) { | |
caffe::Caffe::set_mode(caffe::Caffe::GPU); | |
} else { | |
caffe::Caffe::set_mode(caffe::Caffe::CPU); | |
} | |
//freopen("simple_in", "r", stdin); | |
fprintf(stdout, "team_4,CZELol,%s\n", FLAGS_game_name.c_str()); | |
fflush(stdout); | |
//std::ofstream out("test.txt"); | |
char temp[70000]; | |
fgets(temp, 69999, stdin); | |
//out << temp << std::endl; | |
//out.close(); | |
fprintf(stdout, "0,0,0,1\n"); | |
fflush(stdout); | |
ALEInterface ale(FLAGS_gui); | |
ALEInterface ale2(FLAGS_gui); | |
// Load the ROM file | |
ale.loadROM(FLAGS_rom); | |
ale2.loadROM(FLAGS_rom); | |
// Get the vector of legal actions | |
const auto legal_actions = ale.getMinimalActionSet(); | |
dqn::DQN dqn(legal_actions, FLAGS_solver, FLAGS_memory, FLAGS_gamma); | |
dqn.Initialize(); | |
std::cerr << "Loading " << FLAGS_model << std::endl; | |
dqn.LoadTrainedModel(FLAGS_model); | |
// char test[70000]; | |
// fscanf(stdin, "%s", test); | |
// std::cout << strlen(test); | |
// return 0; | |
std::deque<dqn::FrameDataSp> past_frames; | |
dqn::FrameDataSp current_frame; | |
std::vector<std::vector<unsigned char>> raw_screen(210, std::vector<unsigned char>(160)); | |
bool term; | |
// If there are not past frames enough for DQN input, just select NOOP | |
int frame = 0; | |
for (; frame < 4; ++frame){ | |
term = read_screen(raw_screen, ale2, frame); | |
if (term) | |
goto konec; | |
//std::cout << "Term " << term << std::endl; | |
current_frame = dqn::PreprocessArrayScreen(raw_screen); | |
past_frames.push_back(current_frame); | |
make_action(ale, PLAYER_A_NOOP); | |
} | |
for (; ; ++frame) { | |
term = read_screen(raw_screen, ale2, 4 * frame); | |
if (term) | |
goto konec; | |
std::cerr << "frame " << frame << std::endl; | |
//std::cout << "Term " << term << std::endl; | |
if (frame % 500 == 0) std::cerr << "frame: " << frame << std::endl; | |
current_frame = dqn::PreprocessArrayScreen(raw_screen); | |
past_frames.push_back(current_frame); | |
past_frames.pop_front(); | |
dqn::InputFrames input_frames; | |
std::copy(past_frames.begin(), past_frames.end(), input_frames.begin()); | |
float max_qvalue; | |
const auto action = dqn.SelectAction(input_frames, FLAGS_evaluate_with_epsilon, max_qvalue); | |
auto immediate_score = 0.0; | |
//for (auto i = 0; i < FLAGS_skip_frame + 1 && !ale.game_over(); ++i) { | |
for (auto i = 0; i < FLAGS_skip_frame + 1; ++i) { | |
// Last action is repeated on skipped frames | |
make_action(ale, action); | |
term = read_screen(raw_screen, ale2, 4 * frame + i + 1); | |
//std::cout << "Term " << term << std::endl; | |
if (term) | |
goto konec; | |
} | |
make_action(ale, action); | |
} | |
//out.close(); | |
konec: | |
std::cerr << "total score: " << total_score << std::endl; | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment