tensorflow 堆叠、洗牌并恢复原始Tensor

5us2dqdw  于 2022-12-13  发布在  其他
关注(0)|答案(1)|浏览(154)

我有以下两个tensorflow 占位符:

Tensor("Placeholder:0", shape=(32, 2048), dtype=float32)
Tensor("Placeholder:1", shape=(64, 2048), dtype=float32)

我们将它们命名为ab。我想先对它们进行stack,然后随机地对shuffle进行shuffle。然后,我想通过网络传递它们。最后,我想在stackshuffle之前取回ab
"我所做的一切"
我理解stacking和random shuffle。因此,请指导我如何堆叠它们,洗牌它们,并最终恢复原始索引。

up9lanfz

up9lanfz1#

你可以在级联矩阵上创建一个shuffle index,这样我们就知道被打乱的元素去了哪里,然后我们可以用索引的argsort把它们按顺序放在一起。
输入:

a = tf.random.normal(shape=(32, 2048), dtype=tf.float32)
b = tf.random.normal(shape=(64, 2048), dtype=tf.float32)

堆叠阵列:

c = tf.concat([a,b], axis=0)

随机洗牌:

indices = tf.range(start=0, limit=tf.shape(c)[0], dtype=tf.int32)
shuffled_indices = tf.random.shuffle(indices) #shuffled index will tell where each element of c went.
shuffled_c = tf.gather(c, shuffled_indices)

取回c, a, b

getback_c = tf.gather(shuffled_c, tf.argsort(shuffled_indices))
a_1, b_1 = getback_c[:tf.shape(a)[0]], getback_c[tf.shape(a)[0]:]

验证值是否相同:

np.testing.assert_allclose(a.numpy(), a_1.numpy())
np.testing.assert_allclose(b.numpy(), b_1.numpy())

相关问题