PyTorch: splitting a tensor at runs of consecutive zeros

dzhpxtsq, posted 2024-01-09 in Other

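I am trying to split a 1D PyTorch tensor every time a sequence of x consecutive zeros occurs. If more zeros follow that split point, I want to drop them up to the next non-zero value. At the moment I do this with a for loop over the indices of the zeros. This is slow, especially for large tensors that contain many zeros. Do you have any suggestions for optimizing this code, perhaps with PyTorch-specific functions?
My tensor here has two dims, but the first one is irrelevant for this task (ignore it); the code only looks at `split_flow[:, 1]`.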

    import torch

    def _split_tensor_gpu(split_flow, consecutive_zeros):
        # positions where column 1 is zero
        zero_indices = torch.nonzero(split_flow[:, 1] == 0).view(-1)
        if len(zero_indices) == 0:
            return [split_flow]
        splitted_list = []
        first_index = 0
        zero_counter = 0
        for i in range(1, len(zero_indices)):
            # count how many zeros in a row directly follow the first zero of a run
            if zero_indices[i] - zero_indices[i - 1] == 1:
                zero_counter += 1
            else:
                zero_counter = 0
            if zero_counter == consecutive_zeros:
                splitted_list.append(split_flow[first_index:zero_indices[i]])
                first_index = zero_indices[i] + 1
            # skip the zeros that follow the split point
            if zero_counter > consecutive_zeros:
                first_index = zero_indices[i] + 1
        if first_index <= len(split_flow) - 1:
            splitted_list.append(split_flow[first_index:])
        return splitted_list

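Solution: the first comment did most of the work but did not remove the zeros after the split, so I modified the function and ended up with the following (it should now do the job):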

    def _split_tensor_gpu2(tensor_, consecutive_zeros):
        # step 1: identify zero sequences
        # create a mask of zeros and find the difference between consecutive elements
        is_zero = tensor_[:, 1] == 0
        diff = torch.diff(is_zero.float(), prepend=torch.tensor([0.0], device=tensor_.device))
        # start and (exclusive) end indices of zero sequences
        start_indices = torch.where(diff == 1)[0]
        end_indices = torch.where(diff == -1)[0]
        # a sequence that reaches the end of the tensor has a start but no end
        if len(end_indices) < len(start_indices):
            end_indices = torch.cat([end_indices, tensor_.size(0) * torch.ones(1, dtype=torch.long, device=tensor_.device)])
        # step 2: mark split points
        # find sequences longer than consecutive_zeros
        valid_seqs = (end_indices - start_indices) > consecutive_zeros
        # keep the first consecutive_zeros zeros of each sequence in the preceding chunk
        valid_start_indices = start_indices[valid_seqs] + consecutive_zeros
        valid_end_indices = end_indices[valid_seqs]
        # step 3: split the tensor and drop the remaining zeros of each sequence
        splits = []
        end_idx = 0
        for i in range(len(valid_start_indices)):
            splits.append(tensor_[end_idx:valid_start_indices[i]])
            end_idx = valid_end_indices[i]
        # add the remaining part of the tensor if any
        if end_idx < tensor_.size(0):
            splits.append(tensor_[end_idx:])
        return splits
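One further tweak that may help on a GPU: inside the loop, each `valid_start_indices[i]` is a 0-dim tensor, so every slice can force a device-to-host synchronization. The sketch below is a variant of the same idea, not the original code. It pads the zero mask on both sides so every run gets a start and an end, then copies the (usually short) index tensors to Python lists once with `.tolist()` before slicing. The name `split_after_zero_runs` and the example input are made up for illustration.

```python
import torch


def split_after_zero_runs(tensor_, consecutive_zeros):
    """Split `tensor_` (shape [N, 2]) at runs of zeros in column 1.

    A run must be longer than `consecutive_zeros` to trigger a split. The
    first `consecutive_zeros` zeros stay at the end of the preceding chunk;
    the remaining zeros of the run are dropped.
    """
    is_zero = (tensor_[:, 1] == 0).to(torch.int64)
    # Pad with a 0 on both sides so every run has a start (+1) and an end (-1).
    pad = torch.zeros(1, dtype=torch.int64, device=tensor_.device)
    diff = torch.diff(is_zero, prepend=pad, append=pad)
    starts = torch.where(diff == 1)[0]
    ends = torch.where(diff == -1)[0]  # exclusive end of each run

    long_runs = (ends - starts) > consecutive_zeros
    # One host transfer per index tensor instead of one per loop iteration.
    cut_starts = (starts[long_runs] + consecutive_zeros).tolist()
    cut_ends = ends[long_runs].tolist()

    chunks, begin = [], 0
    for cut_start, cut_end in zip(cut_starts, cut_ends):
        chunks.append(tensor_[begin:cut_start])
        begin = cut_end
    if begin < tensor_.size(0):
        chunks.append(tensor_[begin:])
    return chunks


# Column 0 is a placeholder index; column 1 holds the values tested for zeros.
values = torch.tensor([1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1], dtype=torch.float32)
flow = torch.stack([torch.arange(len(values), dtype=torch.float32), values], dim=1)
chunks = split_after_zero_runs(flow, consecutive_zeros=3)
print([c[:, 1].tolist() for c in chunks])
```

With `consecutive_zeros=3`, the run of three zeros is too short to trigger a split, while the run of five zeros does: three of its zeros stay in the first chunk and the other two are dropped, leaving two chunks of 9 rows and 1 row.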

Answer #1, by b0zn9rqh

You can use PyTorch's built-in functions for the tensor operations:

    import torch

    def _split_tensor_gpu(tensor, consecutive_zeros):
        # step 1: identify zero sequences
        # create a mask of zeros and find the difference between consecutive elements
        is_zero = tensor[:, 1] == 0
        diff = torch.diff(is_zero.float(), prepend=torch.tensor([0.0], device=tensor.device))
        # start and (exclusive) end indices of zero sequences
        start_indices = torch.where(diff == 1)[0]
        end_indices = torch.where(diff == -1)[0]
        # a sequence that reaches the end of the tensor has a start but no end
        if len(end_indices) < len(start_indices):
            end_indices = torch.cat([end_indices, tensor.size(0) * torch.ones(1, dtype=torch.long, device=tensor.device)])
        # step 2: mark split points
        # find sequences with length >= consecutive_zeros
        valid_seqs = (end_indices - start_indices) >= consecutive_zeros
        valid_start_indices = start_indices[valid_seqs]
        valid_end_indices = end_indices[valid_seqs]
        # step 3: split the tensor
        # split the tensor at the end of each valid sequence
        splits = []
        start_idx = 0
        for end_idx in valid_end_indices:
            splits.append(tensor[start_idx:end_idx])
            start_idx = end_idx
        # add the remaining part of the tensor if any
        if start_idx < tensor.size(0):
            splits.append(tensor[start_idx:])
        return splits

    # Example usage
    # column 0 is a placeholder index; column 1 holds the values tested for zeros
    values = torch.tensor([1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1], dtype=torch.float32)
    tensor = torch.stack([torch.arange(len(values), dtype=torch.float32), values], dim=1)
    consecutive_zeros = 3
    split_tensors = _split_tensor_gpu(tensor, consecutive_zeros)
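The Python loop left in this answer only slices at the run ends, which `torch.tensor_split` can do in a single call. Below is a minimal sketch of that idea, assuming the same two-column layout as in the question. `split_at_zero_run_ends` is a hypothetical name, and like the answer it keeps the zeros inside the chunks instead of dropping them.

```python
import torch


def split_at_zero_run_ends(tensor, consecutive_zeros):
    """Cut `tensor` (shape [N, 2]) right after each run of at least
    `consecutive_zeros` zeros in column 1. No rows are dropped."""
    is_zero = (tensor[:, 1] == 0).to(torch.int64)
    # Pad with a 0 on both sides so every run has a start (+1) and an end (-1).
    pad = torch.zeros(1, dtype=torch.int64, device=tensor.device)
    diff = torch.diff(is_zero, prepend=pad, append=pad)
    starts = torch.where(diff == 1)[0]
    ends = torch.where(diff == -1)[0]  # exclusive end of each run

    # Row indices at which to cut, copied to the host in one transfer.
    cut_points = ends[(ends - starts) >= consecutive_zeros].tolist()
    if not cut_points:
        return [tensor]
    pieces = torch.tensor_split(tensor, cut_points, dim=0)
    # A run that reaches the last row yields an empty final piece; drop it.
    return [p for p in pieces if p.size(0) > 0]


values = torch.tensor([1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1], dtype=torch.float32)
tensor = torch.stack([torch.arange(len(values), dtype=torch.float32), values], dim=1)
pieces = split_at_zero_run_ends(tensor, 3)
print([p[:, 1].tolist() for p in pieces])
```

For this input the cuts fall after rows 3 and 10, giving pieces of 4, 7 and 1 rows. The pieces are views of the original tensor, so no data is copied.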
