scipy AssertionError: batch_size and euler_angles shape mismatch

Posted by d7v8vwbk on 2023-08-05 (category: Other)

I tried to compute Euler angles from quaternions with the following code:

    import torch
    from scipy.spatial.transform import Rotation

    def compute_euler_angles_from_quaternion(quaternions, sequence='xyz'):
        """
        Convert a batch of quaternions to Euler angles.
        Args:
            quaternions: Tensor of shape (batch_size, 4). Batch of quaternions.
            sequence: String specifying the rotation sequence. Default is 'xyz'.
        Returns:
            euler_angles: Tensor of shape (batch_size, 3). Batch of Euler angles.
        """
        batch_size = quaternions.shape[0]
        q = quaternions.detach().cpu().numpy()  # Convert to NumPy array
        rotations = Rotation.from_quat(q)
        euler_angles = rotations.as_euler(sequence, degrees=False)
        euler_angles = torch.tensor(euler_angles, device=quaternions.device)
        euler_angles = euler_angles.view(batch_size, 3)
        return euler_angles

but I got this error:

    RuntimeError: shape '[4, 3]' is invalid for input of size 3


I tried to fix it with the following code, based on some solutions I found here:

    def compute_euler_angles_from_quaternion(quaternions, sequence='xyz'):
        batch_size = quaternions.shape[0]  # Ensure that batch_size is correctly calculated
        print("Batch size:", batch_size)
        q = quaternions.detach().cpu().numpy()  # Convert to NumPy array
        rotations = Rotation.from_quat(q)
        euler_angles = rotations.as_euler(sequence, degrees=False)
        print("Shape of euler_angles:", euler_angles.shape)
        if batch_size == 1:
            euler_angles = euler_angles.reshape(1, -1)  # Reshape single quaternion to (1, 3)
        euler_angles = torch.tensor(euler_angles, device=quaternions.device)
        assert batch_size == euler_angles.shape[0], "Mismatch in batch_size and euler_angles shape"
        euler_angles = euler_angles.view(batch_size, 3)  # Reshape to [batch_size, 3]
        return euler_angles


When I pass in various batch sizes, e.g. 1, 2, or 16, I get this error:

    Batch size: 4
    Shape of euler_angles: (3,)
    Traceback (most recent call last):
      File "test_quat.py", line 147, in <module>
        euler = utils.compute_euler_angles_from_quaternion(
      File "/home/redhwan/2/HPE/quat/utils.py", line 318, in compute_euler_angles_from_quaternion
        assert batch_size == euler_angles.shape[0], "Mismatch in batch_size and euler_angles shape"
    AssertionError: Mismatch in batch_size and euler_angles shape

Answer #1, by t9aqgxwy

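Your code is written to expect a batch dimension. When you pass a single quaternion of shape (4,), quaternions.shape[0] returns 4 (the number of quaternion components, not a batch size) and SciPy returns a single set of angles with shape (3,), which is exactly what your printout shows; it is also why the if batch_size == 1 branch never runs. So either make sure your input has a batch dimension, or, if you will never have one, remove the lines batch_size = quaternions.shape[0] and euler_angles = euler_angles.view(batch_size, 3).
If you want to keep the batch dimension, you can do the following: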
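The "Batch size: 4" / "Shape of euler_angles: (3,)" printout can be reproduced with SciPy and NumPy alone. This is a small illustrative check, not code from the question; the identity quaternion is just a convenient input, written in SciPy's default scalar-last (x, y, z, w) order:

```python
import numpy as np
from scipy.spatial.transform import Rotation

# Illustrative input: the identity rotation, scalar-last (x, y, z, w)
q_single = np.array([0.0, 0.0, 0.0, 1.0])  # shape (4,): no batch dimension
q_batch = np.tile(q_single, (4, 1))        # shape (4, 4): a batch of 4 quaternions

single_angles = Rotation.from_quat(q_single).as_euler('xyz')
batch_angles = Rotation.from_quat(q_batch).as_euler('xyz')

# For the unbatched input, shape[0] counts quaternion components, not quaternions
print(q_single.shape[0], single_angles.shape)  # 4 (3,)
print(q_batch.shape[0], batch_angles.shape)    # 4 (4, 3)
```

Both inputs report shape[0] == 4, but only the batched one produces 4 rows of angles, so any check or reshape based on shape[0] breaks for the unbatched case.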

    import torch
    from scipy.spatial.transform import Rotation

    def compute_euler_angles_from_quaternion(quaternions, sequence='xyz'):
        """
        Convert a batch of quaternions to Euler angles.
        Args:
            quaternions: Tensor of shape (batch_size, 4). Batch of quaternions.
            sequence: String specifying the rotation sequence. Default is 'xyz'.
        Returns:
            euler_angles: Tensor of shape (batch_size, 3). Batch of Euler angles.
        """
        # Add a batch dimension if input quaternions do not have one
        if quaternions.dim() == 1:
            quaternions = quaternions.unsqueeze(0)
        batch_size = quaternions.shape[0]
        q = quaternions.detach().cpu().numpy()  # Convert to NumPy array
        rotations = Rotation.from_quat(q)
        euler_angles = rotations.as_euler(sequence, degrees=False)
        euler_angles = torch.tensor(euler_angles, device=quaternions.device)
        euler_angles = euler_angles.view(batch_size, 3)
        return euler_angles
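If you don't need the result as a torch tensor, the same "add a batch dimension if it is missing" idea can be sketched in NumPy with np.atleast_2d. This is a swapped-in technique, not what the answer uses, and the function name euler_from_quat_np is hypothetical. It assumes scalar-last (x, y, z, w) quaternions, which is SciPy's default for from_quat:

```python
import numpy as np
from scipy.spatial.transform import Rotation

def euler_from_quat_np(quaternions, sequence='xyz'):
    """Hypothetical NumPy-only helper: (4,) or (N, 4) quaternions -> (N, 3) Euler angles in radians."""
    # (4,) becomes (1, 4); an (N, 4) array is returned unchanged
    q = np.atleast_2d(np.asarray(quaternions, dtype=float))
    if q.ndim != 2 or q.shape[1] != 4:
        raise ValueError(f"expected shape (4,) or (N, 4), got {np.shape(quaternions)}")
    # With a 2-D input, SciPy returns one row of angles per quaternion
    return Rotation.from_quat(q).as_euler(sequence, degrees=False)

half = np.sqrt(0.5)
q_z90 = np.array([0.0, 0.0, half, half])  # 90-degree rotation about z, scalar-last

print(euler_from_quat_np(q_z90).shape)                   # (1, 3)
print(euler_from_quat_np(np.stack([q_z90] * 16)).shape)  # (16, 3)
```

Because the batch axis is normalized before SciPy is called, the output always has one row per quaternion, and no separate batch_size variable or reshape is needed.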

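You could also call unsqueeze on quaternions before passing them to the function, so the function itself no longer needs the dimension check and unsqueeze.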
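Moving the unsqueeze to the call site lets the function insist on a 2-D input and fail early with a clear message. A rough NumPy sketch of that split, where q[np.newaxis, :] plays the role of torch's q.unsqueeze(0); the function name euler_from_quat_batched is hypothetical, and scalar-last (x, y, z, w) quaternions are assumed:

```python
import numpy as np
from scipy.spatial.transform import Rotation

def euler_from_quat_batched(quaternions, sequence='xyz'):
    """Hypothetical strict helper: accepts only (N, 4) scalar-last quaternions."""
    if quaternions.ndim != 2 or quaternions.shape[1] != 4:
        # Fail with a clear message instead of a confusing reshape/assert error later
        raise ValueError(f"expected shape (N, 4), got {quaternions.shape}")
    return Rotation.from_quat(quaternions).as_euler(sequence, degrees=False)

q = np.array([0.0, 0.0, 0.0, 1.0])                  # a single quaternion, shape (4,)
angles = euler_from_quat_batched(q[np.newaxis, :])  # caller adds the batch axis
print(angles.shape)  # (1, 3)
```

Calling euler_from_quat_batched(q) without the extra axis raises a ValueError that names the offending shape, rather than the AssertionError or RuntimeError from the question.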

