Effect of batch size on simulation time per movement during MotorNet training
# %%
import os
import time
import sys
import json
import numpy as np
import torch as th
import matplotlib.pyplot as plt
import motornet as mn

print('All packages imported.')
print('pytorch version: ' + th.__version__)
print('numpy version: ' + np.__version__)
print('motornet version: ' + mn.__version__)
# %%
# two-joint planar arm driven by rigid-tendon Hill-type muscles,
# performing random-target reaches with a 1 s maximum episode duration
effector = mn.effector.RigidTendonArm26(muscle=mn.muscle.RigidTendonHillMuscle())
env = mn.environment.RandomTargetReach(effector=effector, max_ep_duration=1.)
# %%
class Policy(th.nn.Module):
    """GRU policy network mapping observations to muscle activations."""

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, device):
        super().__init__()
        self.device = device
        self.hidden_dim = hidden_dim
        self.n_layers = 1
        self.gru = th.nn.GRU(input_dim, hidden_dim, 1, batch_first=True)
        self.fc = th.nn.Linear(hidden_dim, output_dim)
        self.sigmoid = th.nn.Sigmoid()

        # the default initialization in torch isn't ideal
        for name, param in self.named_parameters():
            if name == "gru.weight_ih_l0":
                th.nn.init.xavier_uniform_(param)
            elif name == "gru.weight_hh_l0":
                th.nn.init.orthogonal_(param)
            elif name == "gru.bias_ih_l0":
                th.nn.init.zeros_(param)
            elif name == "gru.bias_hh_l0":
                th.nn.init.zeros_(param)
            elif name == "fc.weight":
                th.nn.init.xavier_uniform_(param)
            elif name == "fc.bias":
                th.nn.init.constant_(param, -5.)
            else:
                raise ValueError
        self.to(device)

    def forward(self, x, h0):
        # add a singleton time dimension, step the GRU once, then squeeze it back out
        y, h = self.gru(x[:, None, :], h0)
        u = self.sigmoid(self.fc(y)).squeeze(dim=1)
        return u, h

    def init_hidden(self, batch_size):
        # zero-initialized hidden state with the same dtype/device as the parameters
        weight = next(self.parameters()).data
        hidden = weight.new(self.n_layers, batch_size, self.hidden_dim).zero_().to(self.device)
        return hidden


device = th.device("cpu")
policy = Policy(env.observation_space.shape[0], 128, env.n_muscles, device=device)
optimizer = th.optim.Adam(policy.parameters(), lr=10**-3)
# %%
def l1(x, y):
    """L1 loss"""
    return th.mean(th.sum(th.abs(x - y), dim=-1))
# %%
BS = np.array([8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384])  # batch sizes to test
BSt = np.zeros(np.shape(BS))  # total wall-clock time (ms) for each batch size
n_batch = 10  # number of training batches per batch size

for i, batch_size in enumerate(BS):
    bt0 = time.time()
    losses = []
    for batch in range(n_batch):
        # initialize batch
        h = policy.init_hidden(batch_size=batch_size)
        obs, info = env.reset(options={"batch_size": batch_size})
        terminated = False

        # initial positions and targets
        xy = [info["states"]["fingertip"][:, None, :]]
        tg = [info["goal"][:, None, :]]

        # simulate whole episode
        while not terminated:  # will run until `max_ep_duration` is reached
            action, h = policy(obs, h)
            obs, reward, terminated, truncated, info = env.step(action=action)
            xy.append(info["states"]["fingertip"][:, None, :])  # trajectories
            tg.append(info["goal"][:, None, :])  # targets

        # concatenate into (batch_size, n_timesteps, 2) tensors
        xy = th.cat(xy, axis=1)
        tg = th.cat(tg, axis=1)
        loss = l1(xy, tg)  # L1 loss on position

        # backward pass & update weights
        optimizer.zero_grad()
        loss.backward()
        th.nn.utils.clip_grad_norm_(policy.parameters(), max_norm=1.)  # important!
        optimizer.step()
        losses.append(loss.item())

    BSt[i] = (time.time() - bt0) * 1000
    print(f"{n_batch} x batch_size of {batch_size}: {BSt[i]:.0f} ms total, {BSt[i]/batch_size/n_batch:.1f} ms per movement")
# %%
plt.semilogy(BS, BSt/BS/n_batch, 'o-')
plt.xlabel('Batch Size')
plt.ylabel('Time per movement (ms)')
plt.title('Effect of Batch Size on Simulation Time')
plt.show()

# %%
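# optional quick check: `losses` holds the per-batch L1 loss from the last
# batch size tested above (BS[-1] = 16384), so plotting it gives a rough
# sanity check that the policy is still improving during the timing runs
# (illustrative only; not part of the timing measurement)
plt.plot(losses, 'o-')
plt.xlabel('Batch')
plt.ylabel('L1 loss')
plt.title(f'Training loss at batch size {BS[-1]}')
plt.show()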