Skip to content

Instantly share code, notes, and snippets.

@Flock1
Created May 6, 2020 13:55
Show Gist options
  • Save Flock1/f0c04fbf404ebaf9b5495352f5a37c34 to your computer and use it in GitHub Desktop.
Save Flock1/f0c04fbf404ebaf9b5495352f5a37c34 to your computer and use it in GitHub Desktop.
class ActorCritic(nn.Module):
    """Actor-critic network with independent policy and value heads.

    Both heads are two-layer MLPs (Linear -> ReLU -> Linear) fed with the
    same input. The actor head produces a categorical distribution over
    ``num_outputs`` discrete actions; the critic head produces one value
    per input row.

    Args:
        num_inputs: Size of the feature vector fed to both heads.
        num_outputs: Number of discrete actions.
        hidden_size: Width of the hidden layer in each head.
        std: Fill value for the ``log_std`` parameter.
    """

    def __init__(self, num_inputs, num_outputs, hidden_size, std=0.0):
        super().__init__()

        # Value head: features -> single scalar.
        self.critic = nn.Sequential(
            nn.Linear(num_inputs, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )

        # Policy head: features -> one logit per action.
        self.actor = nn.Sequential(
            nn.Linear(num_inputs, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_outputs),
        )

        # Not read by forward(); kept so the module's state-dict keys stay
        # the same and existing checkpoints continue to load.
        self.log_std = nn.Parameter(torch.ones(1, num_outputs) * std)

        # init_weights is not defined in this class; it must be available
        # at module level when the network is constructed.
        self.apply(init_weights)

    def forward(self, x):
        """Run both heads on ``x``.

        Args:
            x: Input tensor whose last dimension is ``num_inputs``; either
                a single unbatched vector or a batch of vectors.

        Returns:
            A ``(dist, value)`` tuple: a ``Categorical`` over the actions
            and the critic output for ``x``.
        """
        value = self.critic(x)

        # Normalise over the action axis explicitly; calling softmax
        # without ``dim`` is deprecated and guesses the axis from the rank.
        mu = F.softmax(self.actor(x), dim=-1)

        # An unbatched input yields a 1-D probability vector; promote it to
        # a batch of one. Testing the rank (rather than a specific length)
        # works for any number of actions and never mistakes a genuine
        # batch for a single sample.
        if mu.dim() == 1:
            mu = mu.unsqueeze(0)

        return Categorical(mu), value
# Build the policy network and restore its trained weights for inference.
# `proj` and `command` are not defined in this snippet; they must already
# exist at this point (input size comes from proj.shape[0], the number of
# actions from len(command)).
num_inputs = proj.shape[0]
num_outputs = len(command)
hidden_size = 64

model_2 = ActorCritic(num_inputs, num_outputs, hidden_size)

# map_location='cpu' lets a checkpoint that was saved from a GPU load on a
# CPU-only machine; load_state_dict then copies the values into the model's
# own parameters.
model_2.load_state_dict(
    torch.load('model_March_6_2_5_actions.pth', map_location='cpu')
)

# Switch to evaluation mode before using the model for inference.
model_2.eval()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment