此问题与tf.cast equivalent in pytorch?不同。
位转换执行按位重新解释(如C++中reinterpret_cast
),而不是“安全”类型转换。
当你想用numpy存储bfloat16Tensor时,这个操作很有用。
x = torch.ones(224, 224, 3, dtype=torch.bfloat16
x_np = bitcast(x, torch.uint8).numpy()
目前numpy本身不支持bfloat16,因此x.numpy()
将引发TypeError: Got unsupported ScalarType BFloat16
1条答案
按热度按时间fykwrbwg1#
使用第二个过载torch.Tensor.view。
它的语义与numpy.ndarray.view非常相似。