This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class AtariNet(object): | |
# ... | |
# ... | |
def _build(self): | |
# ... | |
# ... | |
# convolutional layers for minimap features | |
self.minimap_conv1 = tf.layers.conv2d( | |
inputs=self.minimap_processed, | |
filters=16, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def preprocess_spatial_features(features, screen=True): | |
"""Embed categorical spatial features, log transform numeric features.""" | |
# ... | |
# ... | |
preprocess_ops = [] | |
for index, (feature_type, scale) in enumerate(feature_specs): | |
layer = transposed[:, :, :, index] | |
if feature_type == sc2_features.FeatureType.CATEGORICAL: | |
# one-hot encode in channel dimension -> 1x1 convolution |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class AtariNet(object): | |
# ... | |
# ... | |
# ... | |
def _build(self): | |
# ... | |
# ... | |
self.screen_features = tf.placeholder( | |
tf.int32, | |
[None, len(SCREEN_FEATURES), *self.screen_dimensions], |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class DQNMoveOnly(base_agent.BaseAgent): | |
# ... | |
# ... | |
# ... | |
# ... | |
def _update_target_network(self): | |
online_vars = tf.get_collection( | |
tf.GraphKeys.TRAINABLE_VARIABLES, 'DQN') | |
target_vars = tf.get_collection( | |
tf.GraphKeys.TRAINABLE_VARIABLES, 'DQNTarget') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import matplotlib.pyplot as plt | |
import seaborn as sns | |
import numpy as np | |
from sklearn.datasets import load_iris | |
data = load_iris() | |
fig, axes = plt.subplots(nrows=2, ncols=2) | |
fig.subplots_adjust(hspace=0.5) | |
fig.suptitle('Distributions of Iris Features') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import matplotlib.pyplot as plt | |
import numpy as np | |
from sklearn.datasets import load_iris | |
data = load_iris() | |
fig, axes = plt.subplots(nrows=2, ncols=2) | |
fig.subplots_adjust(hspace=0.5) | |
fig.suptitle('Distributions of Iris Features') | |
for ax, feature, name in zip(axes.flatten(), data.data.T, data.feature_names): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import matplotlib.pyplot as plt | |
from sklearn.datasets import load_iris | |
from sklearn.svm import SVC | |
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV | |
X, y = load_iris(return_X_y=True) | |
np.random.seed(41) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from utils import * | |
import ast | |
row_units = [cross(r, cols) for r in rows] | |
column_units = [cross(rows, c) for c in cols] | |
square_units = [cross(rs, cs) for rs in ('ABC','DEF','GHI') for cs in ('123','456','789')] | |
# TODO: Update the unit list to add the new diagonal units | |
left_diag_units = [[r+c for (r, c) in zip(rows, cols)]] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
precession_newt = function() { | |
celestial_bodies = rbind(planets, dwarfplanets, asteroids) | |
celestial_bodies = celestial_bodies[order(celestial_bodies$distance), ] | |
celestial_bodies = celestial_bodies[complete.cases(celestial_bodies), ] | |
precession = c(rep(0, nrow(celestial_bodies))) | |
mass_sun = 1.989 * 10**30 | |