如何以jax兼容的方式(例如,使用 jax.numpy
)?
def actions(state: tuple[int, ...]) -> list[tuple[int, ...]]:
l = []
iterables = [range(1, i+1) for i in state]
ns = list(range(len(iterables)))
for i, iterable in enumerate(iterables):
for value in iterable:
action = tuple(value if n == i else 0 for n in ns)
l.append(action)
return l
>>> state = (3, 1, 2)
>>> actions(state)
[(1, 0, 0), (2, 0, 0), (3, 0, 0), (0, 1, 0), (0, 0, 1), (0, 0, 2)]
1条答案
按热度按时间yrefmtwq1#
jax和numpy一样,不能有效地操作python容器类型,比如列表和元组,因此实际上没有任何jax兼容的方法来创建具有上面指定的确切签名的函数。
但是如果你对返回值是一个二维数组没问题,你可以这样做,基于
jnp.vstack
:注意,因为输出数组的大小取决于
state
,state
必须是静态量,因此元组是输入的好选项。