python 如何从TensorflowTensor中移除元素列表

sd2nnvve  于 2022-12-21  发布在  Python
关注(0)|答案(3)|浏览(370)

对于以下Tensor:

<tf.Tensor: shape=(2, 10, 6), dtype=int64, numpy=
 array([[[  3,  16,  43,  10,   7, 431],
         [  3,   2,   6,   5,   7,   2],
         [  3,  37,   5,   7,   2,  12],
         [  3,   2,  11,   5,   7,   2],
         [  3,   2,   6,  18,  14, 195],
         [  3,   2,   6,   5,   7, 195],
         [  3,   2,   6,   5,   7,   9],
         [  3,   2,  11,   7,   2,  12],
         [  3,  16,  52,  92, 177, 923],
         [  3,   9,  43,  10,   7,   9]],
 
        [[  3,   2,  22, 495, 230,   4],
         [  3,   2,  22,   5, 102, 122],
         [  3,   2,  22,   5, 102, 230],
         [  3,   2,  22,   5,  70, 908],
         [  3,   2,  22,   5,  70, 450],
         [  3,   2,  22,   5,  70, 122],
         [  3,   2,  22,   5,  70, 122],
         [  3,   2,  22,   5,  70, 230],
         [  3,   2,  22,  70,  34, 470],
         [  3,   2,  22, 855, 450,   4]]], dtype=int64)>)

我想删除Tensor中的最后一个列表[ 3, 2, 22, 855, 450, 4]。我尝试了tf.unstack,但没有成功。

yzckvree

yzckvree1#

您还可以简单地使用tf.ragged.boolean_mask来排除不需要的行:

import tensorflow as tf

x = tf.constant([[[  3,  16,  43,  10,   7, 431],
         [  3,   2,   6,   5,   7,   2],
         [  3,  37,   5,   7,   2,  12],
         [  3,   2,  11,   5,   7,   2],
         [  3,   2,   6,  18,  14, 195],
         [  3,   2,   6,   5,   7, 195],
         [  3,   2,   6,   5,   7,   9],
         [  3,   2,  11,   7,   2,  12],
         [  3,  16,  52,  92, 177, 923],
         [  3,   9,  43,  10,   7,   9]],
 
        [[  3,   2,  22, 495, 230,   4],
         [  3,   2,  22,   5, 102, 122],
         [  3,   2,  22,   5, 102, 230],
         [  3,   2,  22,   5,  70, 908],
         [  3,   2,  22,   5,  70, 450],
         [  3,   2,  22,   5,  70, 122],
         [  3,   2,  22,   5,  70, 122],
         [  3,   2,  22,   5,  70, 230],
         [  3,   2,  22,  70,  34, 470],
         [  3,   2,  22, 855, 450,   4]]])
x_shape = tf.shape(x)
remove = tf.constant([3, 2, 22, 855, 450, 4])

mask = tf.reduce_all(tf.equal(x, remove), axis=-1)
x = tf.ragged.boolean_mask(x, ~mask)
print(x)
<tf.RaggedTensor [[[3, 16, 43, 10, 7, 431],
  [3, 2, 6, 5, 7, 2],
  [3, 37, 5, 7, 2, 12],
  [3, 2, 11, 5, 7, 2],
  [3, 2, 6, 18, 14, 195],
  [3, 2, 6, 5, 7, 195],
  [3, 2, 6, 5, 7, 9],
  [3, 2, 11, 7, 2, 12],
  [3, 16, 52, 92, 177, 923],
  [3, 9, 43, 10, 7, 9]]     , [[3, 2, 22, 495, 230, 4],
                               [3, 2, 22, 5, 102, 122],
                               [3, 2, 22, 5, 102, 230],
                               [3, 2, 22, 5, 70, 908],
                               [3, 2, 22, 5, 70, 450],
                               [3, 2, 22, 5, 70, 122],
                               [3, 2, 22, 5, 70, 122],
                               [3, 2, 22, 5, 70, 230],
                               [3, 2, 22, 70, 34, 470]]]>
ttp71kqs

ttp71kqs2#

您可以尝试以下操作从Tensor中删除最后一个列表:

sliced_tensor = tf.slice(tensor, [0, 0, 0], [2, 9, 6])
qzlgjiam

qzlgjiam3#

试试这个

new_tensor = tf.slice(tensor, [0,0,0], [2,9,6], [1,1,1])

相关问题