Created
July 23, 2022 21:00
-
-
Save ptrblck/ed837c8f34caf8313332363c6602cdee to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# for https://twitter.com/francoisfleuret/status/1550886362815012865
#
# Demo: each row of U holds at least one -1. marker. For every row, locate the
# left-most marker and overwrite the R consecutive entries starting there with
# the corresponding row of V -- using only vectorized tensor ops, no Python loop.
import torch

# setup
N, Q, R = 5, 20, 10
U = torch.randn(N, Q)
V = torch.arange(N*R).view(N, R).float()

# add -1s to U: one marker per row. The column is drawn from [0, Q-R), so an
# R-wide window starting at the marker always fits inside the row.
U[torch.arange(U.size(0)), torch.randint(0, Q-R, (U.size(0),))] = -1.
# second pass over N//2 randomly chosen rows so that some rows are likely to
# contain more than one marker (duplicates)
idx = torch.randint(0, U.size(0), (U.size(0)//2,))
U[idx, torch.randint(0, Q-R, idx.size())] = -1.
print(U)

# get min indices for U==-1. for each row
# r / c are the row / column coordinates of every marker in U
r, c = (U == -1.).nonzero(as_tuple=True)
idx = torch.zeros(N).long()
# reduce the marker columns per row with "amin"; include_self=False keeps the
# zero initial values of idx out of the reduction
idx.scatter_reduce_(0, r, c, reduce="amin", include_self=False)
print(idx)

# create mask to index U: row i holds the columns idx[i], idx[i]+1, ..., idx[i]+R-1
mask = idx.unsqueeze(1) + torch.arange(R)
print(mask)

# copy V into U: the (N, 1) row indices broadcast against the (N, R) column
# indices, so V[i, j] is written to U[i, mask[i, j]]
U[torch.arange(U.size(0)).unsqueeze(1), mask] = V
print(U)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment