Created
October 11, 2025 05:17
-
-
Save insaneyilin/f04b497b899c70918298538f1a3a4b7f to your computer and use it in GitHub Desktop.
Get the last valid position indices for each sequence sample
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_last_valid_indices(valid_mask):
    """Get the last valid index along the sequence dimension for each sample.

    Args:
        valid_mask (torch.Tensor): [B, seq_len], True if valid. Non-bool
            masks are interpreted as ``valid_mask != 0``.

    Returns:
        torch.Tensor: [B,] int64 tensor holding the index of the last valid
        position of each sample, or -1 for samples with no valid position
        (including the case ``seq_len == 0``).
    """
    valid_mask = valid_mask.bool()
    batch_size = valid_mask.size(0)
    seq_len = valid_mask.size(1)
    if seq_len == 0:
        # argmax over an empty dimension raises; no position can be valid.
        return torch.full(
            (batch_size,), -1, dtype=torch.long, device=valid_mask.device)

    # Reverse the mask so the last valid position becomes the first True.
    reversed_mask = torch.flip(valid_mask, dims=[1])
    # argmax returns the index of the FIRST maximal value, i.e. the first True
    # in the reversed mask. Cast because argmax does not accept bool tensors.
    first_true_in_reversed = torch.argmax(reversed_mask.int(), dim=1)
    # Map the reversed position back to an index in the original order.
    last_valid_indices = (seq_len - 1) - first_true_in_reversed

    # For an all-False row argmax returns 0, which would wrongly map to
    # seq_len - 1, so rows without any valid position are marked with -1.
    has_valid_obs = valid_mask.any(dim=1)  # [B]
    return last_valid_indices.masked_fill(~has_valid_obs, -1)
def get_last_valid_indices_v2(valid_mask):
    """Get the last valid indices for each sample via a masked max over positions.

    Args:
        valid_mask (torch.Tensor): [B, seq_len], True if valid. Non-bool
            masks are interpreted as ``valid_mask != 0``.

    Returns:
        torch.Tensor: [B,] int64 tensor holding the index of the last valid
        position of each sample, or -1 for samples with no valid position
        (including the case ``seq_len == 0``).
    """
    valid_mask = valid_mask.bool()
    batch_size = valid_mask.size(0)
    seq_len = valid_mask.size(1)
    if seq_len == 0:
        # max over an empty dimension raises; no position can be valid.
        return torch.full(
            (batch_size,), -1, dtype=torch.long, device=valid_mask.device)

    # Position indices 0..seq_len-1, broadcast across the batch.
    positions = torch.arange(seq_len, device=valid_mask.device)
    # Replace invalid positions by -1 so they never win the max; a row with
    # no valid position therefore reduces to -1.
    masked_positions = positions.expand_as(valid_mask).masked_fill(~valid_mask, -1)
    # The largest surviving position is the last valid index.
    return masked_positions.max(dim=1).values
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment