我有一个pytree,包含有不同形状的数组,例如它包含:
- 形状为
(5, 3, 250, 23)
的observations
- 形状为
(5, 3, 250)
的dones
我想重塑我的pytree,使前两个维度合并,这将为我的pytree中的每个对象给予类似(15, 250, ...)
的结果。
我通常使用tree_map
来处理我的pytree,但这次我很难让它工作,我尝试了:
jax.tree_map(lambda x: jnp.reshape(x, newshape=(15, -1)), my_pytree)
它对dones
很有效,但它合并了observations
的最后几个维度,得到一个(15, 5750)
形状的数组(这里我希望它是(15, 250, 23)
)。
注意:我不能修改pytree的定义,我必须使用这个结构。
1条答案
按热度按时间6l7fqoea1#
很抱歉发这个帖子,它有点琐碎。我把答案贴出来以防万一: