# Note: this is not a working script, only excerpts of the code.
## 1. Preprocessing and batching (using the PyTorch Dataset class)
import numpy as np
import torch.utils.data as data
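def raw_labels_to_sparseTriple(arr):
    # Sort the label indices within each sample.
    arr = [sorted(x) for x in arr]
    if len(arr) == 0:
        raise RuntimeError('empty arr')
    # Concatenate every sample's label indices into one flat array of column indices.
    cols = np.concatenate(arr)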