Skip to content

Instantly share code, notes, and snippets.

@ntabris
Last active April 2, 2020 21:09
Show Gist options
  • Save ntabris/89eb17317f57bbd9af13442db3241730 to your computer and use it in GitHub Desktop.
Converts a "tracking" HDF5 (.h5) file into a SLEAP dataset.
"""
Script to make a labels dataset from a "tracking" h5 file.
> python tracking_h5_to_sleap.py path/to/tracking.h5 path/to/video.mp4
"""
import h5py
import numpy as np
from sleap import Labels, LabeledFrame, PredictedInstance, Skeleton, Video
from sleap.instance import Track
def _to_str(name):
    """Return *name* as ``str``, decoding ``bytes`` (UTF-8) when needed.

    h5py commonly returns string datasets as ``bytes``/``numpy.bytes_``;
    values that are already ``str`` are passed through unchanged. Any other
    type still raises ``AttributeError`` from ``.decode()``, as before.
    """
    return name if isinstance(name, str) else name.decode()


def main(in_path, out_path, video_path, connect_adj_nodes=False):
    """Convert a "tracking" HDF5 file into a SLEAP labels file.

    Args:
        in_path: Path to the tracking HDF5 file. It must contain the datasets
            ``tracks``, ``track_names`` and ``node_names``.
        out_path: Path the resulting SLEAP labels file is written to.
        video_path: Path to the video the tracks belong to.
        connect_adj_nodes: If True, add a skeleton edge between each pair of
            consecutive nodes, in ``node_names`` order.

    Every (frame, track) pair with at least one non-NaN coordinate becomes a
    ``PredictedInstance`` whose point and instance scores are all 1; frames
    without any such pair are left out of the output.
    """
    video = Video.from_filename(video_path)

    with h5py.File(in_path, "r") as f:
        # Every dataset is transposed on read. Presumably the file was written
        # in column-major order (e.g. by MATLAB) -- TODO confirm.
        tracks_matrix = f["tracks"][:].T
        track_names_list = f["track_names"][:].T
        node_names_list = f["node_names"][:].T

    print(tracks_matrix.shape)
    # shape: frames * nodes * 2 * tracks
    frame_count, node_count, _, track_count = tracks_matrix.shape

    tracks = [Track(0, _to_str(track_name)) for track_name in track_names_list]

    skeleton = Skeleton()
    last_node_name = None
    for node_name in node_names_list:
        node_name = _to_str(node_name)
        skeleton.add_node(node_name)
        # Compare against None rather than truthiness so that a node named ""
        # is still connected to its predecessor.
        if connect_adj_nodes and last_node_name is not None:
            skeleton.add_edge(last_node_name, node_name)
        last_node_name = node_name

    # (frames, tracks) mask: True where the track has at least one non-NaN
    # coordinate in that frame. Computed once for the whole matrix instead of
    # once per (frame, track) cell.
    has_points = ~np.all(np.isnan(tracks_matrix), axis=(1, 2))

    frames = []
    for frame_idx in range(frame_count):
        instances = [
            PredictedInstance.from_arrays(
                points=tracks_matrix[frame_idx, ..., track_idx],
                # The tracking file has no confidences, so every point gets
                # score 1 (a fresh array per instance, as before).
                point_confidences=np.ones(node_count),
                skeleton=skeleton,
                track=tracks[track_idx],
                instance_score=1,
            )
            for track_idx in range(track_count)
            if has_points[frame_idx, track_idx]
        ]
        if instances:
            frames.append(
                LabeledFrame(video=video, frame_idx=frame_idx, instances=instances)
            )

    labels = Labels(frames)
    Labels.save_file(labels, filename=out_path)
    print(f"Saved: {out_path}")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("data_path", help="Path to labels json file")
parser.add_argument("video_path", help="Path to video file")
args = parser.parse_args()
out_path = args.data_path.replace(".h5", ".sleap.h5")
main(in_path=args.data_path, video_path=args.video_path, out_path=out_path)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment