Last active
January 5, 2017 22:02
-
-
Save Tostino/4ba74e5366c95ae403c165eddb24f5ed to your computer and use it in GitHub Desktop.
rl4j doom
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.deeplearning4j.gym.StepReply; | |
import org.deeplearning4j.rl4j.mdp.MDP; | |
import org.deeplearning4j.rl4j.space.ArrayObservationSpace; | |
import org.deeplearning4j.rl4j.space.DiscreteSpace; | |
import org.deeplearning4j.rl4j.space.Encodable; | |
import org.deeplearning4j.rl4j.space.ObservationSpace; | |
import org.nd4j.linalg.factory.Nd4j; | |
import org.slf4j.Logger; | |
import org.slf4j.LoggerFactory; | |
import oshi.SystemInfo; | |
import oshi.hardware.GlobalMemory; | |
import oshi.util.FormatUtil; | |
import vizdoom.*; | |
import vizdoom.Button; | |
import javax.imageio.ImageIO; | |
import java.awt.*; | |
import java.awt.color.ColorSpace; | |
import java.awt.image.*; | |
import java.io.*; | |
import java.nio.ByteBuffer; | |
import java.nio.ByteOrder; | |
import java.nio.IntBuffer; | |
import java.nio.file.Files; | |
import java.nio.file.Path; | |
import java.nio.file.Paths; | |
import java.util.ArrayList; | |
import java.util.List; | |
/** | |
* @author rubenfiszel ([email protected]) on 7/28/16. | |
* <p> | |
* Mother abstract class for all VizDoom scenarios | |
* <p> | |
* is mostly configured by | |
* <p> | |
* String scenario; name of the scenario | |
* double livingReward; additional reward at each step for living | |
* double deathPenalty; negative reward when ded | |
* int doomSkill; skill of the ennemy | |
* int timeout; number of step after which simulation time out | |
* int startTime; number of internal tics before the simulation starts (useful to draw weapon by example) | |
* List<Button> buttons; the list of inputs one can press for a given scenario (noop is automatically added) | |
*/ | |
abstract public class VizDoom implements MDP<VizDoom.MdpGameScreen, Integer, DiscreteSpace> | |
{ | |
final public static String DOOM_ROOT = "./vizdoom"; | |
final protected Logger log = LoggerFactory.getLogger("Vizdoom"); | |
final protected GlobalMemory memory = new SystemInfo().getHardware().getMemory(); | |
final protected List<int[]> actions; | |
final protected DiscreteSpace discreteSpace; | |
final protected ObservationSpace<MdpGameScreen> observationSpace; | |
final protected boolean render; | |
protected DoomGame game; | |
protected double scaleFactor = 1; | |
public VizDoom() | |
{ | |
this(true); | |
} | |
public VizDoom(boolean render) | |
{ | |
this.render = render; | |
actions = new ArrayList<int[]>(); | |
game = new DoomGame(); | |
setupGame(); | |
discreteSpace = new DiscreteSpace(getConfiguration().getButtons().size() + 1); | |
observationSpace = new ArrayObservationSpace<>(new int[]{game.getScreenHeight(), game.getScreenWidth(), 3}); | |
} | |
public boolean isRender() | |
{ | |
return render; | |
} | |
public void setScaleFactor(final double scaleFactor) | |
{ | |
this.scaleFactor = scaleFactor; | |
} | |
public void setupGame() | |
{ | |
Configuration conf = getConfiguration(); | |
game.setViZDoomPath(DOOM_ROOT + "/vizdoom"); | |
game.setDoomGamePath(DOOM_ROOT + "/scenarios/freedoom2.wad"); | |
game.setDoomScenarioPath(DOOM_ROOT + "/scenarios/" + conf.getScenario() + ".wad"); | |
game.setDoomMap("map01"); | |
game.setScreenFormat(ScreenFormat.RGB24); | |
game.setScreenResolution(ScreenResolution.RES_800X600); | |
// Sets other rendering options | |
game.setRenderHud(false); | |
game.setRenderCrosshair(false); | |
game.setRenderWeapon(true); | |
game.setRenderDecals(false); | |
game.setRenderParticles(false); | |
GameVariable[] gameVar = new GameVariable[]{ | |
GameVariable.KILLCOUNT, | |
GameVariable.ITEMCOUNT, | |
GameVariable.SECRETCOUNT, | |
GameVariable.FRAGCOUNT, | |
GameVariable.HEALTH, | |
GameVariable.ARMOR, | |
GameVariable.DEAD, | |
GameVariable.ON_GROUND, | |
GameVariable.ATTACK_READY, | |
GameVariable.ALTATTACK_READY, | |
GameVariable.SELECTED_WEAPON, | |
GameVariable.SELECTED_WEAPON_AMMO, | |
GameVariable.AMMO1, | |
GameVariable.AMMO2, | |
GameVariable.AMMO3, | |
GameVariable.AMMO4, | |
GameVariable.AMMO5, | |
GameVariable.AMMO6, | |
GameVariable.AMMO7, | |
GameVariable.AMMO8, | |
GameVariable.AMMO9, | |
GameVariable.AMMO0 | |
}; | |
// Adds game variables that will be included in state. | |
for (int i = 0; i < gameVar.length; i++) | |
{ | |
game.addAvailableGameVariable(gameVar[i]); | |
} | |
// Causes episodes to finish after timeout tics | |
game.setEpisodeTimeout(conf.getTimeout()); | |
game.setEpisodeStartTime(conf.getStartTime()); | |
game.setWindowVisible(render); | |
game.setSoundEnabled(false); | |
game.setMode(Mode.PLAYER); | |
game.setLivingReward(conf.getLivingReward()); | |
// Adds buttons that will be allowed. | |
List<Button> buttons = conf.getButtons(); | |
int size = buttons.size(); | |
actions.add(new int[size + 1]); | |
for (int i = 0; i < size; i++) | |
{ | |
game.addAvailableButton(buttons.get(i)); | |
int[] action = new int[size + 1]; | |
action[i] = 1; | |
actions.add(action); | |
} | |
game.setDeathPenalty(conf.getDeathPenalty()); | |
game.setDoomSkill(conf.getDoomSkill()); | |
game.init(); | |
} | |
public boolean isDone() | |
{ | |
return game.isEpisodeFinished(); | |
} | |
public MdpGameScreen reset() | |
{ | |
log.info("free Memory: " + FormatUtil.formatBytes(memory.getAvailable()) + "/" | |
+ FormatUtil.formatBytes(memory.getTotal())); | |
game.newEpisode(); | |
int[] screen_buffer = convertScreenBuffer(game.getState().screenBuffer); | |
return new MdpGameScreen(screen_buffer); | |
} | |
public void close() | |
{ | |
game.close(); | |
} | |
public int[] convertScreenBuffer(byte[] buffer) | |
{ | |
IntBuffer intBuf = | |
ByteBuffer.wrap(game.getState().screenBuffer) | |
.order(ByteOrder.BIG_ENDIAN) | |
.asIntBuffer(); | |
int[] initial_array = new int[intBuf.remaining()]; | |
intBuf.get(initial_array); | |
int height_ratio = game.getScreenHeight() / game.getScreenWidth(); | |
int width_ratio = game.getScreenWidth() / game.getScreenHeight(); | |
if (height_ratio / width_ratio == 1) | |
{ | |
return initial_array; | |
} | |
else | |
{ | |
// FIXME: here I need help | |
// do some scaling somehow | |
try | |
{ | |
DataBufferByte buf = new DataBufferByte(buffer, buffer.length); | |
ColorModel cm = new ComponentColorModel(ColorSpace.getInstance(ColorSpace.CS_sRGB), new int[]{8, 8, 8}, false, false, Transparency.OPAQUE, DataBuffer.TYPE_BYTE); | |
BufferedImage i = new BufferedImage(cm, Raster.createInterleavedRaster(buf, 800, 600, 800 * 3, 3, new int[]{0, 1, 2}, null), false, null); | |
File f = new File("E:\\test.png"); | |
ImageIO.write(i, "png", f); | |
} | |
catch (IOException e) | |
{ | |
e.printStackTrace(); | |
} | |
// return scaled value | |
return initial_array; | |
} | |
} | |
public StepReply<MdpGameScreen> step(Integer action) | |
{ | |
double r = game.makeAction(actions.get(action)) * scaleFactor; | |
log.info(game.getEpisodeTime() + " " + r + " " + action + " "); | |
int[] screen_buffer = convertScreenBuffer(game.getState().screenBuffer); | |
return new StepReply(new MdpGameScreen(screen_buffer), r, game.isEpisodeFinished(), null); | |
} | |
public ObservationSpace<MdpGameScreen> getObservationSpace() | |
{ | |
return observationSpace; | |
} | |
public DiscreteSpace getActionSpace() | |
{ | |
return discreteSpace; | |
} | |
public abstract Configuration getConfiguration(); | |
public abstract VizDoom newInstance(); | |
public static class Configuration | |
{ | |
String scenario; | |
double livingReward; | |
double deathPenalty; | |
int doomSkill; | |
int timeout; | |
int startTime; | |
List<Button> buttons; | |
public Configuration(final String scenario, final double livingReward, final double deathPenalty, final int doomSkill, final int timeout, final int startTime, final List<Button> buttons) | |
{ | |
this.scenario = scenario; | |
this.livingReward = livingReward; | |
this.deathPenalty = deathPenalty; | |
this.doomSkill = doomSkill; | |
this.timeout = timeout; | |
this.startTime = startTime; | |
this.buttons = buttons; | |
} | |
public String getScenario() | |
{ | |
return scenario; | |
} | |
public double getLivingReward() | |
{ | |
return livingReward; | |
} | |
public double getDeathPenalty() | |
{ | |
return deathPenalty; | |
} | |
public int getDoomSkill() | |
{ | |
return doomSkill; | |
} | |
public int getTimeout() | |
{ | |
return timeout; | |
} | |
public int getStartTime() | |
{ | |
return startTime; | |
} | |
public List<Button> getButtons() | |
{ | |
return buttons; | |
} | |
} | |
public static class MdpGameScreen implements Encodable | |
{ | |
double[] array; | |
public MdpGameScreen(int[] screen) | |
{ | |
array = new double[screen.length]; | |
for (int i = 0; i < screen.length; i++) | |
{ | |
array[i] = screen[i]; | |
} | |
} | |
public double[] toArray() | |
{ | |
return array; | |
} | |
} | |
} |
Author
Tostino
commented
Jan 5, 2017
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment