我正在研究分类问题,我有一个字符串列表作为类标签,我想把它们转换成Tensor。到目前为止,我已经尝试过用numpy模块提供的np.array
函数把字符串列表转换成numpy array
。truth = torch.from_numpy(np.array(truths))
但我得到以下错误。RuntimeError: can't convert a given np.ndarray to a tensor - it has an invalid type. The only supported types are: double, float, int64, int32, and uint8.
有人能提出一个替代方法吗?谢谢
4条答案
按热度按时间yvt65v4c1#
不幸的是,你现在不能。我不认为这是一个好主意,因为它会使PyTorch笨拙。一个流行的解决方案是使用sklearn将其转换为数字类型。
下面是一个简短的示例:
由于您可能需要在真标签和转换标签之间进行转换,因此最好存储变量
le
。hof1towb2#
技巧是首先找出列表中单词的最大长度,然后在第二个循环中用零填充Tensor。注意utf8字符串每个字符占用两个字节。
l3zydbqr3#
如果你不想使用sklearn,另一个解决方案是保留原始列表并创建一个额外的索引列表,你可以用它来引用你的原始值。我特别需要这个,当我必须跟踪我的原始字符串,同时批处理标记化的字符串。
示例如下:
对于我的特定用例,代码如下所示:
7cwmlq894#