Donkey2 custom train script, updated to work when records have been removed
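The record-filtering logic in the script keys off element 7 of the categorical angle array. For reference, here is a minimal sketch of the 15-bin one-hot encoding that dk.utils.linear_bin is assumed to produce, where index 7 corresponds to a steering angle of roughly zero. The helper below is an illustration of that assumption, not the library code itself:

import numpy as np

def linear_bin_sketch(angle):
    # Map a steering angle in [-1, 1] onto 15 buckets; bucket 7 is the
    # (approximately) zero-angle center. Illustrative only -- the real
    # encoding lives in dk.utils.linear_bin.
    arr = np.zeros(15)
    arr[int(round((angle + 1) / (2 / 14)))] = 1
    return arr

# linear_bin_sketch(0.0)[7] == 1.0, which is the "0 bucket" test used in the script below.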
def custom_train(cfg, tub_names, model_name):
    '''
    use the specified data in tub_names to train an artificial neural network
    saves the output trained model as model_name
    '''
    # Note: relies on `dk` (donkeycar) and `gather_tubs` being available at
    # module scope, as they are in Donkey2's manage.py / train script.
    import os
    import random
    import glob
    import json

    import numpy as np
    from PIL import Image
    from sklearn.model_selection import train_test_split
    from sklearn.utils import shuffle

    images = []
    angles = []
    throttles = []

    tubs = gather_tubs(cfg, tub_names)

    for tub in tubs:
        # glob the record files directly so gaps in record numbering
        # (deleted records) do not break training
        record_paths = glob.glob(os.path.join(tub.path, 'record_*.json'))

        for record_path in record_paths:
            with open(record_path, 'r') as fp:
                json_data = json.load(fp)

            user_angle = dk.utils.linear_bin(json_data['user/angle'])
            user_throttle = float(json_data["user/throttle"])
            image_filename = json_data["cam/image_array"]
            image_path = os.path.join(tub.path, image_filename)

            if user_angle[7] != 1.0:
                # if the categorical angle is not in the 0 bucket, always include it
                images.append(image_path)
                angles.append(user_angle)
                throttles.append(user_throttle)
            elif random.randint(0, 9) < 2:
                # drop a percentage of records where the categorical angle is in the 0 bucket;
                # increase the number in the conditional above to include more records
                # (< 2 = 20% of 0-angle records included, < 3 = 30%, etc.)
                images.append(image_path)
                angles.append(user_angle)
                throttles.append(user_throttle)

    # shuffle and split the data
    train_images, val_images, train_angles, val_angles, train_throttles, val_throttles = train_test_split(
        images, angles, throttles, shuffle=True, test_size=(1 - cfg.TRAIN_TEST_SPLIT))

    def generator(images, angles, throttles, batch_size=cfg.BATCH_SIZE):
        num_records = len(images)

        while True:
            # shuffle again for good measure; sklearn's shuffle returns copies
            # rather than shuffling in place, so rebind the names
            images, angles, throttles = shuffle(images, angles, throttles)

            for offset in range(0, num_records, batch_size):
                batch_images = images[offset:offset + batch_size]
                batch_angles = angles[offset:offset + batch_size]
                batch_throttles = throttles[offset:offset + batch_size]

                augmented_images = []
                augmented_angles = []
                augmented_throttles = []

                for image_path, angle, throttle in zip(batch_images, batch_angles, batch_throttles):
                    image = np.array(Image.open(image_path))
                    augmented_images.append(image)
                    augmented_angles.append(angle)
                    augmented_throttles.append(throttle)

                    if angle[7] != 1.0:
                        # augment the data set with flipped versions of the nonzero angle records:
                        # mirror the image left-right and reverse the angle bins to match
                        augmented_images.append(np.fliplr(image))
                        augmented_angles.append(np.flip(angle, axis=0))
                        augmented_throttles.append(throttle)

                augmented_images, augmented_angles, augmented_throttles = shuffle(
                    np.array(augmented_images),
                    np.array(augmented_angles),
                    np.array(augmented_throttles))

                X = [augmented_images]
                y = [augmented_angles, augmented_throttles]
                yield X, y

    train_gen = generator(train_images, train_angles, train_throttles)
    val_gen = generator(val_images, val_angles, val_throttles)

    kl = dk.parts.KerasCategorical()
    model_path = os.path.expanduser(model_name)

    total_train = len(train_images)
    total_val = len(val_images)
    print('train: %d, validation: %d' % (total_train, total_val))

    steps_per_epoch = total_train // cfg.BATCH_SIZE
    print('steps_per_epoch', steps_per_epoch)

    kl.train(train_gen,
             val_gen,
             saved_model_path=model_path,
             steps=steps_per_epoch,
             train_split=cfg.TRAIN_TEST_SPLIT)
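For context, a hedged sketch of how this function might be invoked from a Donkey2 manage.py that already defines gather_tubs and imports donkeycar as dk. The tub and model paths below are placeholders, and the config loading is based on the standard Donkey2 template rather than anything in this gist:

# Hypothetical wiring (names and paths are assumptions, not from the gist):
import os
import donkeycar as dk

if __name__ == '__main__':
    # load the vehicle config that defines TRAIN_TEST_SPLIT and BATCH_SIZE
    cfg = dk.load_config(config_path=os.path.join(os.path.dirname(__file__), 'config.py'))

    # train on one or more tubs and write the model to the given path
    custom_train(cfg, tub_names='data/tub_1', model_name='models/custom_pilot')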