# record.py
# GitHub gist by @haixuanTao (haixuanTao/067b255f411a463c02d40852ead4dff0),
# created November 13, 2024 04:57.
import h5py
import os
from dora import Node
import numpy as np
# Make sure the output directory for recorded episodes exists.
# exist_ok=True makes this a single call instead of a separate
# exists()-then-create check, which could race with another process
# creating the directory in between.
os.makedirs("data", exist_ok=True)
def save_data(data_dict, dataset_path, data_size):
    """Write one recorded episode to ``<dataset_path>.hdf5``.

    The file is created (overwriting any existing one) with this layout:

    * attributes ``sim`` and ``compress``, both ``False``
    * ``observations/images/{cam_high,cam_left_wrist,cam_right_wrist}``:
      one variable-length ``uint8`` entry per step
    * ``observations/qpos`` and ``action``: shape ``(data_size, 14)``
    * ``base_action``: shape ``(data_size, 2)``

    Args:
        data_dict: Mapping from dataset path to the values to store. Each key
            is used to look up a dataset created above, so it must name one
            of them, and its value is assigned to the whole dataset.
        dataset_path: Output path without the ``.hdf5`` extension.
        data_size: Number of steps; first dimension of every dataset.
    """
    target_file = dataset_path + ".hdf5"
    with h5py.File(target_file, "w", rdcc_nbytes=1024**2 * 2) as root:
        root.attrs["sim"] = False
        root.attrs["compress"] = False

        obs = root.create_group("observations")
        # Camera entries are stored as variable-length byte arrays, so the
        # per-step entries do not need to share a length.
        frame_dtype = h5py.vlen_dtype(np.dtype("uint8"))
        image = obs.create_group("images")
        for camera in ("cam_high", "cam_left_wrist", "cam_right_wrist"):
            image.create_dataset(camera, (data_size,), dtype=frame_dtype)

        obs.create_dataset("qpos", (data_size, 14))
        root.create_dataset("action", (data_size, 14))
        root.create_dataset("base_action", (data_size, 2))

        # Copy every buffered series from data_dict into its dataset.
        for name, array in data_dict.items():
            print(name)
            parent = image if "images" in name else root
            parent[name][...] = array
def _new_episode_buffers():
    """Return an empty episode buffer: one list per HDF5 dataset path."""
    return {
        "/observations/qpos": [],
        "/observations/images/cam_high": [],
        "/observations/images/cam_left_wrist": [],
        "/observations/images/cam_right_wrist": [],
        "/action": [],
        "/base_action": [],
    }


# Samples accumulated for the episode currently being recorded.
data_dict = _new_episode_buffers()

node = Node()

# A sample is appended each time this input arrives (once all inputs are known).
LEAD_CAMERA = "/observations/images/cam_high"

# Most recent value received for each input id. "/base_action" is never
# updated by this loop, so the constant [0.0, 0.0] is recorded for every step.
tmp_dict = {
    "/base_action": [0.0, 0.0],
}

# Every key that must be present in tmp_dict before a sample can be built.
# Checking the keys explicitly (instead of comparing len(tmp_dict) to 6) keeps
# recording working even if an unexpected "image"/"qpos" input id shows up.
_REQUIRED_KEYS = (
    "/base_action",
    "/observations/images/cam_high",
    "/observations/images/cam_left_wrist",
    "/observations/images/cam_right_wrist",
    "/observations/qpos_left",
    "/observations/qpos_right",
)

# Index of the next episode file to write (data/episode_<i>.hdf5).
i = 0

for event in node:
    if event["type"] != "INPUT":
        continue

    event_id = event["id"]

    if "save" in event_id:
        char = event["value"][0].as_py()
        if char == "p":
            data_size = len(data_dict["/observations/qpos"])
            if data_size == 0:
                # Nothing recorded since the last save: do not write an empty
                # episode file or consume an episode index.
                print("No samples recorded, skipping save")
            else:
                save_data(data_dict, f"data/episode_{i}", data_size)
                # Start a fresh episode.
                data_dict = _new_episode_buffers()
                i += 1
    elif "image" in event_id or "qpos" in event_id:
        tmp_dict[event_id] = event["value"].to_numpy()

    # Only the lead camera triggers a sample, and only once every input has
    # been received at least once.
    if event_id != LEAD_CAMERA:
        continue
    if not all(key in tmp_dict for key in _REQUIRED_KEYS):
        continue

    joint_state = np.concatenate(
        [
            tmp_dict["/observations/qpos_left"],
            tmp_dict["/observations/qpos_right"],
        ]
    )
    data_dict["/observations/qpos"].append(joint_state)
    # The action is recorded as a copy of the observed joint positions.
    data_dict["/action"].append(joint_state)
    data_dict["/base_action"].append(tmp_dict["/base_action"])
    for camera_key in (
        "/observations/images/cam_high",
        "/observations/images/cam_left_wrist",
        "/observations/images/cam_right_wrist",
    ):
        data_dict[camera_key].append(tmp_dict[camera_key])
# (GitHub page footer: sign up or sign in to comment on this gist.)