如何将torchTensor转换为numpy?
vddsk6oq1#
从pytorch doc复制:
a = torch.ones(5) print(a)
tensor([1.,1.,1.,1.,1.])
b = a.numpy() print(b)
[1. 1.1.1.1.]以下是与@John的讨论:如果Tensor是(或可以是)GPU上的,或者如果它(或它可以)需要grad,则可以使用
t.detach().cpu().numpy()
我建议只根据需要来丑化你的代码。
g9icjywg2#
你可以尝试以下方法
1. torch.Tensor().numpy() 2. torch.Tensor().cpu().data.numpy() 3. torch.Tensor().cpu().detach().numpy()
zpjtge223#
另一个有用的方法:
a = torch(0.1, device='cuda') a.cpu().data.numpy()
回答array(0.1,dtype=float32)
0dxa2lsx4#
这是fastai core的一个函数:
def to_np(x): "Convert a tensor to a numpy array." return apply(lambda o: o.data.cpu().numpy(), x)
可能使用来自未来PyTorch库的函数是一个不错的选择。如果你查看PyTorch Transformers内部,你会发现以下代码:
preds = logits.detach().cpu().numpy()
那么你可能会问为什么需要detach()方法?当我们想要从AD计算图中分离Tensor时,需要它。仍然注意到CPUTensor和numpy数组是连接的。它们共享相同的存储:
detach()
import torch tensor = torch.zeros(2) numpy_array = tensor.numpy() print('Before edit:') print(tensor) print(numpy_array) tensor[0] = 10 print() print('After edit:') print('Tensor:', tensor) print('Numpy array:', numpy_array)
输出:
Before edit: tensor([0., 0.]) [0. 0.] After edit: Tensor: tensor([10., 0.]) Numpy array: [10. 0.]
第一个元素的值由tensor和numpy数组共享。在tensor中将其更改为10也会在numpy数组中更改它。这就是为什么我们需要小心,因为改变numpy数组也会改变CPUTensor。
b4lqfgs45#
您可能会发现以下两个功能很有用。
iecba09b6#
有时候,如果有“应用”梯度,你首先必须把.detach()函数放在.numpy()函数之前。
.detach()
.numpy()
loss = loss_fn(preds, labels) print(loss.detach().numpy())
s2j5cfk07#
x = torch.tensor([0.1,0.32], device='cuda:0') x.detach().cpu().data.numpy()
7条答案
按热度按时间vddsk6oq1#
从pytorch doc复制:
tensor([1.,1.,1.,1.,1.])
[1. 1.1.1.1.]
以下是与@John的讨论:
如果Tensor是(或可以是)GPU上的,或者如果它(或它可以)需要grad,则可以使用
我建议只根据需要来丑化你的代码。
g9icjywg2#
你可以尝试以下方法
zpjtge223#
另一个有用的方法:
回答
array(0.1,dtype=float32)
0dxa2lsx4#
这是fastai core的一个函数:
可能使用来自未来PyTorch库的函数是一个不错的选择。
如果你查看PyTorch Transformers内部,你会发现以下代码:
那么你可能会问为什么需要
detach()
方法?当我们想要从AD计算图中分离Tensor时,需要它。仍然注意到CPUTensor和numpy数组是连接的。它们共享相同的存储:
输出:
第一个元素的值由tensor和numpy数组共享。在tensor中将其更改为10也会在numpy数组中更改它。
这就是为什么我们需要小心,因为改变numpy数组也会改变CPUTensor。
b4lqfgs45#
您可能会发现以下两个功能很有用。
iecba09b6#
有时候,如果有“应用”梯度,你首先必须把
.detach()
函数放在.numpy()
函数之前。s2j5cfk07#