Created
May 23, 2024 13:56
-
-
Save soraxas/31e1a0eae2d11a5005c60627a3f080fc to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import abc | |
import os.path | |
import git | |
import numpy as np | |
import torch | |
from torch.utils.data import Dataset | |
from mpd.datasets.normalization import DatasetNormalizer | |
from mpd.utils.loading import load_params_from_yaml | |
from torch_robotics import environments, robots | |
from torch_robotics.environments import EnvDense2DExtraObjects | |
from torch_robotics.environments.env_simple_2d_extra_objects import EnvSimple2DExtraObjects | |
from torch_robotics.tasks.tasks import PlanningTask | |
from torch_robotics.visualizers.planning_visualizer import PlanningVisualizer | |
# Dataset root: the ``data_trajectories`` folder at the top of the git working
# tree that contains the current working directory. This is resolved once, at
# import time. NOTE(review): GitPython presumably raises if the process is not
# started from inside a git checkout -- confirm if this module is imported elsewhere.
repo = git.Repo(os.curdir, search_parent_directories=True)
dataset_base_dir = os.path.join(repo.working_dir, 'data_trajectories')
import pickle, io | |
# NOTE(review): scratch scaffolding -- a bare attribute holder bound to the
# module-level name ``self``. Nothing in this file's visible code reads it
# (every method below takes its own ``self`` parameter); it looks like a
# leftover from interactive debugging and is a candidate for deletion.
class A:
    pass


self = A()
setattr(self, 'base_dir', 'data')
class UnpicklerCpu(pickle.Unpickler):
    """Unpickler that reloads embedded torch storages onto the CPU.

    Lookups of ``torch.storage._load_from_bytes`` are redirected to a loader
    that calls ``torch.load(..., map_location='cpu')``. Storages pickled on a
    GPU machine can therefore presumably be read on a CPU-only one. Every other
    global is resolved by the standard ``pickle.Unpickler.find_class``.

    NOTE: unpickling can execute arbitrary code from the file; only use this
    on trusted, locally generated data.
    """

    def find_class(self, module, name):
        """Resolve ``module.name``, substituting a CPU loader for torch storages."""
        is_storage_loader = (module, name) == ('torch.storage', '_load_from_bytes')
        if not is_storage_loader:
            return super().find_class(module, name)

        def _load_storage_on_cpu(raw_bytes):
            return torch.load(io.BytesIO(raw_bytes), map_location='cpu')

        return _load_storage_on_cpu
class TrajectoryDatasetBase(Dataset, abc.ABC):
    """Torch dataset of planned robot trajectories and their start/goal tasks.

    Data lives in ``os.path.join(dataset_base_dir, dataset_subdir)``. The
    ``0/args.yaml`` and ``0/metadata.yaml`` files there describe the planning
    setup (environment id, robot id, task arguments). Every
    ``results_data_dict.pickle`` found below the directory contributes the
    trajectories stored under its ``'trajs_iters_free'`` key.

    After construction ``self.fields`` holds:

    * ``'traj'`` -- stacked trajectories indexed
      ``[trajectory, support point, state]`` (positions only unless
      ``include_velocity``);
    * ``'task'`` -- first and last positions of each trajectory, concatenated
      on the last axis;
    * ``'traj_normalized'`` / ``'task_normalized'`` -- the same after
      ``DatasetNormalizer``.

    Subclasses must implement :meth:`get_hard_conditions`, which
    :meth:`__getitem__` calls for every sample.
    """

    def __init__(self,
                 dataset_subdir=None,
                 include_velocity=False,
                 normalizer='LimitsNormalizer',
                 use_extra_objects=False,
                 obstacle_cutoff_margin=None,
                 tensor_args=None,
                 **kwargs):
        """Load the dataset and build its environment, robot and planning task.

        Args:
            dataset_subdir: sub-directory of ``dataset_base_dir`` that holds the
                data. Required, despite the ``None`` default.
            include_velocity: if True the ``'traj'`` field keeps the loaded
                trajectories unchanged; otherwise only
                ``self.robot.get_position(...)`` of them is stored.
            normalizer: normalizer name forwarded to ``DatasetNormalizer``.
            use_extra_objects: if True instantiate ``<env_id>ExtraObjects``
                from ``torch_robotics.environments`` instead of ``<env_id>``.
            obstacle_cutoff_margin: if not None, overrides
                ``obstacle_cutoff_margin`` in the loaded args before the task
                is built.
            tensor_args: forwarded to the env, robot and task constructors.
            **kwargs: accepted and ignored.

        Raises:
            TypeError: if ``dataset_subdir`` is None.
        """
        if dataset_subdir is None:
            # os.path.join(..., None) would fail with an opaque TypeError; name the missing argument.
            raise TypeError(f'{type(self).__name__} requires a dataset_subdir argument')

        self.tensor_args = tensor_args
        self.dataset_subdir = dataset_subdir
        self.base_dir = os.path.join(dataset_base_dir, self.dataset_subdir)

        # Planner args and metadata are read from the '0' sub-directory only;
        # presumably every task in the dataset shares them -- TODO confirm.
        self.args = load_params_from_yaml(os.path.join(self.base_dir, '0', 'args.yaml'))
        self.metadata = load_params_from_yaml(os.path.join(self.base_dir, '0', 'metadata.yaml'))
        if obstacle_cutoff_margin is not None:
            self.args['obstacle_cutoff_margin'] = obstacle_cutoff_margin

        # -------------------------------- Load env, robot, task ---------------------------------
        # Environment
        env_id = self.metadata['env_id']
        env_class = getattr(environments, env_id + 'ExtraObjects' if use_extra_objects else env_id)
        self.env = env_class(tensor_args=tensor_args)

        # Robot
        robot_class = getattr(robots, self.metadata['robot_id'])
        self.robot = robot_class(tensor_args=tensor_args)

        # Task
        self.task = PlanningTask(env=self.env, robot=self.robot, tensor_args=tensor_args, **self.args)
        self.planner_visualizer = PlanningVisualizer(task=self.task)

        # -------------------------------- Load trajectories ---------------------------------
        self.threshold_start_goal_pos = self.args['threshold_start_goal_pos']

        self.field_key_traj = 'traj'
        self.field_key_task = 'task'
        self.fields = {}

        # load data
        self.include_velocity = include_velocity
        self.map_task_id_to_trajectories_id = {}
        self.map_trajectory_id_to_task_id = {}
        self.load_trajectories()

        # dimensions
        b, h, d = self.dataset_shape = self.fields[self.field_key_traj].shape
        self.n_trajs = b
        self.n_support_points = h
        self.state_dim = d  # state dimension used for the diffusion model
        self.trajectory_dim = (self.n_support_points, d)

        # normalize the data (for the diffusion model)
        self.normalizer = DatasetNormalizer(self.fields, normalizer=normalizer)
        self.normalizer_keys = [self.field_key_traj, self.field_key_task]
        self.normalize_all_data(*self.normalizer_keys)

    def load_trajectories(self):
        """Collect trajectories from every ``results_data_dict.pickle`` under ``self.base_dir``.

        Each pickle whose ``'trajs_iters_free'`` entry is not None becomes one
        task. Its trajectories receive consecutive global indices.

        Populates:
            self.map_task_id_to_trajectories_id: task id -> ``np.ndarray`` of
                global trajectory indices.
            self.map_trajectory_id_to_task_id: global trajectory index -> task id.
            self.fields['traj']: all trajectories concatenated on dim 0
                (positions only unless ``self.include_velocity``).
            self.fields['task']: first and last position of each trajectory,
                concatenated on the last axis.

        Raises:
            RuntimeError: if no usable trajectories are found.
        """
        trajs_free_l = []
        task_id = 0
        n_trajs = 0
        # NOTE(review): os.walk yields directories in filesystem order (unsorted),
        # so task ids / trajectory order may differ between machines.
        for current_dir, _subdirs, files in os.walk(self.base_dir, topdown=True):
            if 'results_data_dict.pickle' not in files:
                continue
            # NOTE: unpickling executes code from the file -- only load trusted datasets.
            with open(os.path.join(current_dir, 'results_data_dict.pickle'), 'rb') as f:
                trajs_free_tmp = UnpicklerCpu(f).load()['trajs_iters_free']
            if trajs_free_tmp is None:
                continue
            if len(trajs_free_tmp.shape) == 4:
                # Extra leading axis (presumably optimization iterations -- TODO confirm): keep index 0.
                trajs_free_tmp = trajs_free_tmp[0]

            n_new = len(trajs_free_tmp)
            trajectories_idx = n_trajs + np.arange(n_new)
            self.map_task_id_to_trajectories_id[task_id] = trajectories_idx
            self.map_trajectory_id_to_task_id.update(dict.fromkeys(trajectories_idx, task_id))
            task_id += 1
            n_trajs += n_new
            trajs_free_l.append(trajs_free_tmp)

        if not trajs_free_l:
            # torch.cat([]) would fail with an unhelpful message; report where we looked.
            raise RuntimeError(
                f'No trajectories found: no results_data_dict.pickle with non-empty '
                f"'trajs_iters_free' under {self.base_dir!r}")

        trajs_free = torch.cat(trajs_free_l)
        trajs_free_pos = self.robot.get_position(trajs_free)
        self.fields[self.field_key_traj] = trajs_free if self.include_velocity else trajs_free_pos

        # task: start and goal state positions [n_trajectories, 2 * state_dim]
        task = torch.cat((trajs_free_pos[..., 0, :], trajs_free_pos[..., -1, :]), dim=-1)
        self.fields[self.field_key_task] = task

    def normalize_all_data(self, *keys):
        """Store ``self.normalizer(self.fields[key], key)`` as ``self.fields[key + '_normalized']``."""
        for key in keys:
            self.fields[f'{key}_normalized'] = self.normalizer(self.fields[key], key)

    def render(self, task_id=3,
               render_joint_trajectories=False,
               render_robot_trajectories=False,
               **kwargs):
        """Plot all trajectories belonging to ``task_id``.

        The start/goal states are taken from the first trajectory of the task.

        Returns:
            ``(fig1, axs1, fig2, axs2)``: the joint-space plot and the robot
            trajectory plot; each pair is ``None`` unless requested.
        """
        # -------------------------------- Visualize ---------------------------------
        idxs = self.map_task_id_to_trajectories_id[task_id]
        pos_trajs = self.robot.get_position(self.fields[self.field_key_traj][idxs])
        start_state_pos = pos_trajs[0][0]
        goal_state_pos = pos_trajs[0][-1]

        fig1, axs1, fig2, axs2 = [None] * 4

        if render_joint_trajectories:
            fig1, axs1 = self.planner_visualizer.plot_joint_space_state_trajectories(
                trajs=pos_trajs,
                pos_start_state=start_state_pos, pos_goal_state=goal_state_pos,
                vel_start_state=torch.zeros_like(start_state_pos), vel_goal_state=torch.zeros_like(goal_state_pos),
            )

        if render_robot_trajectories:
            fig2, axs2 = self.planner_visualizer.render_robot_trajectories(
                trajs=pos_trajs, start_state=start_state_pos, goal_state=goal_state_pos,
            )

        return fig1, axs1, fig2, axs2

    def __repr__(self):
        return ('TrajectoryDataset\n'
                f'n_trajs: {self.n_trajs}\n'
                f'trajectory_dim: {self.trajectory_dim}\n')

    def __len__(self):
        return self.n_trajs

    def __getitem__(self, index):
        """Return one normalized sample.

        Returns:
            dict with ``'traj_normalized'``, ``'task_normalized'`` and
            ``'hard_conds'`` (from :meth:`get_hard_conditions`, computed on the
            normalized trajectory with ``horizon=len(traj_normalized)``).
        """
        field_traj_normalized = f'{self.field_key_traj}_normalized'
        field_task_normalized = f'{self.field_key_task}_normalized'
        traj_normalized = self.fields[field_traj_normalized][index]
        task_normalized = self.fields[field_task_normalized][index]
        data = {
            field_traj_normalized: traj_normalized,
            field_task_normalized: task_normalized,
        }

        # build hard conditions
        hard_conds = self.get_hard_conditions(traj_normalized, horizon=len(traj_normalized))
        data.update({'hard_conds': hard_conds})

        return data

    def get_hard_conditions(self, traj, horizon=None, normalize=False):
        """Return ``{time index: state}`` constraints for ``traj``; implemented by subclasses."""
        raise NotImplementedError(f'{type(self).__name__} must implement get_hard_conditions')

    def get_unnormalized(self, index):
        """Not implemented for this dataset.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(f'{type(self).__name__}.get_unnormalized is not implemented')

    def unnormalize(self, x, key):
        """Undo the normalization of ``x`` for field ``key``."""
        return self.normalizer.unnormalize(x, key)

    def normalize(self, x, key):
        """Normalize ``x`` with the statistics of field ``key``."""
        return self.normalizer.normalize(x, key)

    def unnormalize_trajectories(self, x):
        return self.unnormalize(x, self.field_key_traj)

    def normalize_trajectories(self, x):
        return self.normalize(x, self.field_key_traj)

    def unnormalize_tasks(self, x):
        return self.unnormalize(x, self.field_key_task)

    def normalize_tasks(self, x):
        return self.normalize(x, self.field_key_task)
class TrajectoryDataset(TrajectoryDatasetBase):
    """Concrete trajectory dataset whose hard conditions fix both endpoints.

    Only :meth:`get_hard_conditions` is specialised. Loading and normalization
    are inherited unchanged from ``TrajectoryDatasetBase``.
    """

    def __init__(self, **kwargs):
        # Keyword-only pass-through to the base-class constructor.
        super().__init__(**kwargs)

    def get_hard_conditions(self, traj, horizon=None, normalize=False):
        """Build the endpoint constraints for one trajectory.

        Args:
            traj: trajectory indexed by support point first; ``traj[0]`` and
                ``traj[-1]`` are its first and last states.
            horizon: number of support points; defaults to
                ``self.n_support_points``.
            normalize: if True, pass both states through the trajectory
                normalizer.

        Returns:
            ``{0: start_state, horizon - 1: goal_state}``. Each state is the
            robot position from ``self.robot.get_position``. When
            ``self.include_velocity`` is set, an all-zero block of the same
            shape is appended on the last axis, so the trajectory starts and
            ends at rest.
        """
        endpoint_positions = (self.robot.get_position(traj[0]), self.robot.get_position(traj[-1]))

        if self.include_velocity:
            # State = [position, velocity]; pin both endpoint velocities to zero.
            endpoints = tuple(torch.cat((p, torch.zeros_like(p)), dim=-1) for p in endpoint_positions)
        else:
            endpoints = endpoint_positions

        if normalize:
            endpoints = tuple(self.normalizer.normalize(s, key=self.field_key_traj) for s in endpoints)

        start, goal = endpoints
        last_index = (self.n_support_points if horizon is None else horizon) - 1
        return {0: start, last_index: goal}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment