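I am trying to split a 1D PyTorch tensor every time a sequence of x consecutive zeros is encountered. If there are further zero elements after this "split", I want to remove them up to the next non-zero value. Currently I do this with a for loop over the zero indices. However, this approach is slow, especially for large tensors that contain many zeros. Do you have any suggestions for how I could improve and optimize this code, possibly using PyTorch-specific functions to improve performance?
Here my tensor has two dims, but the first dim is irrelevant for this task (ignore it).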
import torch

def _split_tensor_gpu(split_flow, consecutive_zeros):
    zero_indices = torch.nonzero(split_flow[:, 1] == 0).view(-1)
    if len(zero_indices) == 0:
        return [split_flow]
    splitted_list = []
    first_index = 0
    zero_counter = 0
    for i in range(1, len(zero_indices)):
        if zero_indices[i] - zero_indices[i - 1] == 1:
            zero_counter += 1
        else:
            zero_counter = 0
        if zero_counter == consecutive_zeros:
            splitted_list.append(split_flow[first_index:zero_indices[i]])
            first_index = zero_indices[i] + 1
        if zero_counter > consecutive_zeros:
            first_index = zero_indices[i] + 1
    if first_index <= len(split_flow) - 1:
        splitted_list.append(split_flow[first_index:])
    return splitted_list
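Solution: based on the first comment, which did most of the job but did not remove the zeros after a split, I modified the function and ended up with the following (it should now do the job):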
def _split_tensor_gpu2(tensor_, consecutive_zeros):
    # step 1: identify zero sequences
    # create a mask of zeros and find the difference between consecutive elements
    is_zero = tensor_[:, 1] == 0
    diff = torch.diff(is_zero.float(), prepend=torch.tensor([0.0], device=tensor_.device))
    # start (inclusive) and end (exclusive) indices of zero sequences
    start_indices = torch.where(diff == 1)[0]
    end_indices = torch.where(diff == -1)[0]
    # a sequence that reaches the end of the tensor has a start but no end;
    # comparing the counts also keeps both tensors aligned when there are no zeros at all
    if len(start_indices) > len(end_indices):
        end_indices = torch.cat([end_indices, tensor_.size(0) * torch.ones(1, dtype=torch.long, device=tensor_.device)])
    # step 2: mark split points
    # find sequences with length > consecutive_zeros
    valid_seqs = (end_indices - start_indices) > consecutive_zeros
    # each chunk keeps the first `consecutive_zeros` zeros of the sequence that ends it
    valid_start_indices = start_indices[valid_seqs] + consecutive_zeros
    valid_end_indices = end_indices[valid_seqs]
    splits = []
    end_idx = 0
    for i in range(len(valid_start_indices)):
        splits.append(tensor_[end_idx:valid_start_indices[i]])
        end_idx = valid_end_indices[i]
    # add the remaining part of the tensor if any
    if end_idx < tensor_.size(0):
        splits.append(tensor_[end_idx:])
    return splits
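To make step 1 of the function above easier to follow, here is a small sketch of the `torch.diff` trick on a toy 1D tensor. The sample values are made up for illustration. The starts are the positions where the zero mask rises from 0 to 1, and the ends are the positions where it falls back to 0.

```python
import torch

# Made-up sample values: zero runs at [1, 4), [5, 6) and [7, 9).
values = torch.tensor([5.0, 0.0, 0.0, 0.0, 7.0, 0.0, 8.0, 0.0, 0.0])

is_zero = values == 0
# +1 where a zero run begins, -1 at the first non-zero element after a run.
diff = torch.diff(is_zero.float(), prepend=torch.tensor([0.0]))

starts = torch.where(diff == 1)[0]
ends = torch.where(diff == -1)[0]

# A run that touches the end of the tensor never produces a -1, so close it manually.
if len(starts) > len(ends):
    ends = torch.cat([ends, torch.tensor([values.numel()])])

run_lengths = ends - starts
print(starts.tolist(), ends.tolist(), run_lengths.tolist())
```

Only the runs whose length exceeds `consecutive_zeros` become split points; shorter runs stay inside their chunk.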
1 Answer
You can use PyTorch's built-in functions for tensor-related operations.
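A minimal sketch of what that could look like, not necessarily the exact code this answer had in mind. The function name `split_on_zero_runs` is hypothetical. It assumes the same conventions as the question: a 2D tensor whose column 1 is inspected, a split whenever a zero run is longer than `consecutive_zeros`, and the first `consecutive_zeros` zeros of that run kept in the preceding chunk. All run detection is done with tensor operations, and the indices are moved to Python once with `.tolist()` instead of indexing a tensor on every loop iteration.

```python
import torch

def split_on_zero_runs(tensor_, consecutive_zeros):
    """Hypothetical helper: split on zero runs longer than `consecutive_zeros` in column 1."""
    n = tensor_.size(0)
    is_zero = (tensor_[:, 1] == 0).long()
    pad = torch.zeros(1, dtype=torch.long, device=tensor_.device)
    # Padding both sides guarantees every run has a start (+1) and an end (-1).
    diff = torch.diff(is_zero, prepend=pad, append=pad)
    starts = torch.where(diff == 1)[0]
    ends = torch.where(diff == -1)[0]

    # Keep only runs that are long enough to trigger a split.
    long_run = (ends - starts) > consecutive_zeros
    cut_from = (starts[long_run] + consecutive_zeros).tolist()  # where each chunk stops
    cut_to = ends[long_run].tolist()                            # where the next chunk begins

    chunk_starts = [0] + cut_to
    chunk_ends = cut_from + [n]
    # Slices are views, so no data is copied; empty chunks are dropped.
    return [tensor_[s:e] for s, e in zip(chunk_starts, chunk_ends) if e > s]


# Usage with made-up data: column 0 is ignored, column 1 drives the split.
col = torch.tensor([5.0, 0.0, 0.0, 0.0, 7.0, 0.0, 8.0, 0.0, 0.0])
x = torch.stack([torch.arange(9.0), col], dim=1)
parts = split_on_zero_runs(x, consecutive_zeros=1)
print([p[:, 1].tolist() for p in parts])
```

The remaining list comprehension iterates over plain Python integers, one per split point rather than one per zero element. Whether this is noticeably faster depends on the tensor size and device, so it is worth timing on your own data.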