numpy (Jax)改变包含不同形状数组的pytree的形状

vdzxcuhz  于 2022-11-24  发布在  其他
关注(0)|答案(1)|浏览(115)

我有一个pytree,包含有不同形状的数组,例如它包含:

  • 形状为(5, 3, 250, 23)observations
  • 形状为(5, 3, 250)dones

我想重塑我的pytree,使前两个维度合并,这将为我的pytree中的每个对象给予类似(15, 250, ...)的结果。
我通常使用tree_map来处理我的pytree,但这次我很难让它工作,我尝试了:

jax.tree_map(lambda x: jnp.reshape(x, newshape=(15, -1)), my_pytree)

它对dones很有效,但它合并了observations的最后几个维度,得到一个(15, 5750)形状的数组(这里我希望它是(15, 250, 23))。
注意:我不能修改pytree的定义,我必须使用这个结构。

6l7fqoea

6l7fqoea1#

很抱歉发这个帖子,它有点琐碎。我把答案贴出来以防万一:

jax.tree_map(lambda x: jnp.reshape(x, newshape=(15, *x.shape[2:])),my_pytree)

相关问题