Last active: March 13, 2020 19:46
-
-
Save edraizen/8fd362e759132e154c4f54efe1709aae to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import numpy as np | |
import sparseconvnet as scn | |
from more_itertools import pairwise | |
from torch.nn.parallel._functions import Scatter | |
class Batch(object):
    """A sparse point-cloud batch that can be split across several GPUs.

    ``self.data`` holds ``[indices, data]``.  ``indices`` has ``dim + 1``
    columns: ``dim`` coordinate columns followed by the sample number in the
    last column.  ``truth`` is sliced by the same point rows as ``indices``
    when the batch is scattered (assumes one truth row per point -- TODO
    confirm against callers).

    ``scatter`` hands out whole samples, so the rows of each sample must be
    contiguous in ``indices`` / ``data`` / ``truth``.

    Args:
        indices: coordinate + sample-number tensor, ``(n_points, dim + 1)``.
        data: per-point feature tensor with ``n_points`` rows.
        truth: target tensor, sliced by point rows in ``scatter``.
        chunk_sizes: optional list of per-chunk point counts.
        dim: number of spatial (coordinate) columns in ``indices``.
    """

    def __init__(self, indices, data, truth, chunk_sizes=None, dim=3):
        assert isinstance(indices, torch.Tensor), "indices must be tensor"
        assert isinstance(data, torch.Tensor), "data must be tensor"
        assert isinstance(truth, torch.Tensor), "truth must be tensor"
        assert indices.size()[0]==data.size()[0], "indices and data must have same length"
        assert indices.size()[1]==dim+1, "indices must have sample number in last col"
        self.data = [indices, data]
        self.truth = truth
        self.chunk_sizes = chunk_sizes
        # Kept so scatter() can build sub-batches with the same spatial
        # dimensionality; scatter's own ``dim`` argument is the split axis.
        self.dim = dim

    @staticmethod
    def _split_samples(sample_ids, n_parts):
        """Partition contiguous runs of ``sample_ids`` into balanced chunks.

        Args:
            sample_ids: 1-D tensor holding the sample number of every row.
            n_parts: maximum number of chunks (number of target devices).

        Returns:
            ``(point_bounds, sample_bounds, run_of_point)``:

            - ``point_bounds``: row boundaries of each chunk, so chunk ``i``
              is rows ``point_bounds[i]:point_bounds[i + 1]``.
            - ``sample_bounds``: the matching boundaries in sample (run)
              units.
            - ``run_of_point``: the 0-based run number of every row.

            The number of chunks is ``min(n_parts, n_samples)``, so no chunk
            is empty.

        Raises:
            ValueError: if ``n_parts < 1`` or ``sample_ids`` is empty.
        """
        if n_parts < 1:
            raise ValueError("scatter needs at least one target device")
        if sample_ids.numel() == 0:
            raise ValueError("cannot scatter an empty batch")
        # unique_consecutive yields run lengths in row order.  Plain unique()
        # does not guarantee this, and the row slicing below depends on it.
        _, run_of_point, pts_per_sample = torch.unique_consecutive(
            sample_ids, return_inverse=True, return_counts=True)
        n_samples = pts_per_sample.numel()
        n_chunks = min(n_parts, n_samples)
        # Even split in the style of numpy.array_split.  Chunk sizes differ
        # by at most one sample.
        sample_bounds = [(i * n_samples) // n_chunks for i in range(n_chunks + 1)]
        point_offsets = [0] + pts_per_sample.cumsum(0).tolist()
        point_bounds = [point_offsets[b] for b in sample_bounds]
        return point_bounds, sample_bounds, run_of_point

    def scatter(self, target_gpus, dim=0):
        """Split the batch into per-device sub-batches of whole samples.

        Presumably used as a custom scatter hook for data-parallel training.

        Only the feature tensor is moved to the GPUs (via ``Scatter``).  The
        indices and truth of each chunk are row slices that stay on their
        original device.  Sample numbers are renumbered to ``0..k-1`` inside
        every chunk.  Updates ``self.chunk_sizes`` with the per-chunk point
        counts.

        Args:
            target_gpus: sequence of target devices.  If there are fewer
                samples than devices, only the first ``n_samples`` are used.
            dim: axis along which ``Scatter`` splits the features.  Chunk
                sizes count rows, so this is expected to be 0.

        Returns:
            A list of ``Batch`` objects, one per device used.

        Raises:
            ValueError: if ``target_gpus`` is empty or the batch has no points.
        """
        indices, features = self.data
        point_bounds, sample_bounds, run_of_point = self._split_samples(
            indices[:, -1], len(target_gpus))
        ranges = list(zip(point_bounds[:-1], point_bounds[1:]))
        # Scatter/comm.scatter expects one *int* size per device.
        self.chunk_sizes = [stop - start for start, stop in ranges]
        devices = target_gpus[:len(ranges)]
        scattered = Scatter.apply(devices, self.chunk_sizes, dim, features)

        chunks = []
        for (start, stop), first_sample, x in zip(ranges, sample_bounds, scattered):
            chunk_indices = indices[start:stop].clone()
            # Renumber samples 0..k-1 within the chunk.  This works even when
            # the original sample numbers are not 0-based or contiguous.
            chunk_indices[:, -1] = run_of_point[start:stop] - first_sample
            chunks.append(Batch(chunk_indices, x, self.truth[start:stop],
                                self.chunk_sizes, self.dim))
        return chunks
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment