From fbf95e46bcc360a6b80d4dcfe22a6535a3c4f3f4 Mon Sep 17 00:00:00 2001 From: alex-bene Date: Tue, 4 Feb 2025 01:38:25 +0200 Subject: [PATCH 1/3] support direct conversion between rot matrix <-> axis angle (faster) --- pytorch3d/transforms/rotation_conversions.py | 45 +++++++++++++++++++- 1 file changed, 43 insertions(+), 2 deletions(-) diff --git a/pytorch3d/transforms/rotation_conversions.py b/pytorch3d/transforms/rotation_conversions.py index a9fcae226..cb28a14db 100644 --- a/pytorch3d/transforms/rotation_conversions.py +++ b/pytorch3d/transforms/rotation_conversions.py @@ -476,7 +476,28 @@ def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor: Returns: Rotation matrices as tensor of shape (..., 3, 3). """ - return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle)) + eps = 1e-6 + shape = axis_angle.shape + device, dtype = axis_angle.device, axis_angle.dtype + + angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) + axis = axis_angle / torch.where(angles.abs() < eps, 1, angles) + + cos_theta = torch.cos(angles)[..., None] + sin_theta = torch.sin(angles)[..., None] + + rx, ry, rz = axis[..., 0], axis[..., 1], axis[..., 2] + zeros = torch.zeros(shape[:-1], dtype=dtype, device=device) + cross_product_matrix = torch.stack( + [zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=-1 + ).view(shape + torch.Size([3])) + + identity = torch.eye(3, dtype=dtype, device=device) + return ( + identity.expand(cross_product_matrix.shape) + + sin_theta * cross_product_matrix + + (1 - cos_theta) * torch.bmm(cross_product_matrix, cross_product_matrix) + ) def matrix_to_axis_angle(matrix: torch.Tensor) -> torch.Tensor: @@ -492,7 +513,27 @@ def matrix_to_axis_angle(matrix: torch.Tensor) -> torch.Tensor: turned anticlockwise in radians around the vector's direction. """ - return quaternion_to_axis_angle(matrix_to_quaternion(matrix)) + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + + omegas = torch.stack( + [ + matrix[..., 2, 1] - matrix[..., 1, 2], + matrix[..., 0, 2] - matrix[..., 2, 0], + matrix[..., 1, 0] - matrix[..., 0, 1], + ], + dim=-1, + ) + norms = torch.norm(omegas, p=2, dim=-1, keepdim=True) + traces = torch.diagonal(matrix, dim1=-2, dim2=-1).sum(-1).unsqueeze(-1) + angles = torch.atan2(norms, traces - 1) + + zeros = torch.zeros(3, dtype=matrix.dtype, device=matrix.device) + omegas = torch.where( + torch.isclose(angles, torch.zeros_like(angles)), zeros, omegas + ) + + return 0.5 * omegas / torch.sinc(angles/torch.pi) def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor: From 1b0940daa11706f99d92b9da14926ab195c5e115 Mon Sep 17 00:00:00 2001 From: alex-bene Date: Tue, 4 Feb 2025 22:43:45 +0200 Subject: [PATCH 2/3] simplify and perf improv of axis_angle <-> quaternion conv; simplify and handle edge case in axis_angle <-> rot matrix conv --- pytorch3d/transforms/rotation_conversions.py | 80 ++++++++++---------- 1 file changed, 41 insertions(+), 39 deletions(-) diff --git a/pytorch3d/transforms/rotation_conversions.py b/pytorch3d/transforms/rotation_conversions.py index cb28a14db..893365f52 100644 --- a/pytorch3d/transforms/rotation_conversions.py +++ b/pytorch3d/transforms/rotation_conversions.py @@ -463,7 +463,7 @@ def quaternion_apply(quaternion: torch.Tensor, point: torch.Tensor) -> torch.Ten return out[..., 1:] -def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor: +def axis_angle_to_matrix(axis_angle: torch.Tensor, fast: bool=False) -> torch.Tensor: """ Convert rotations given as axis/angle to rotation matrices. @@ -472,47 +472,58 @@ def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor: as a tensor of shape (..., 3), where the magnitude is the angle turned anticlockwise in radians around the vector's direction. + fast: Whether to use the new faster implementation (based on the + Rodrigues formula) instead of the original implementation (which + first converted to a quaternion and then back to a rotation matrix). Returns: Rotation matrices as tensor of shape (..., 3, 3). """ - eps = 1e-6 + if not fast: + return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle)) + shape = axis_angle.shape device, dtype = axis_angle.device, axis_angle.dtype - angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) - axis = axis_angle / torch.where(angles.abs() < eps, 1, angles) - - cos_theta = torch.cos(angles)[..., None] - sin_theta = torch.sin(angles)[..., None] + angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True).unsqueeze(-1) - rx, ry, rz = axis[..., 0], axis[..., 1], axis[..., 2] + rx, ry, rz = axis_angle[..., 0], axis_angle[..., 1], axis_angle[..., 2] zeros = torch.zeros(shape[:-1], dtype=dtype, device=device) cross_product_matrix = torch.stack( [zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=-1 ).view(shape + torch.Size([3])) + cross_product_matrix_sqrd = cross_product_matrix @ cross_product_matrix identity = torch.eye(3, dtype=dtype, device=device) + angles_sqrd = angles * angles + angles_sqrd = torch.where(angles_sqrd == 0, 1, angles_sqrd) return ( identity.expand(cross_product_matrix.shape) - + sin_theta * cross_product_matrix - + (1 - cos_theta) * torch.bmm(cross_product_matrix, cross_product_matrix) + + torch.sinc(angles/torch.pi) * cross_product_matrix + + ((1 - torch.cos(angles)) / angles_sqrd) * cross_product_matrix_sqrd ) -def matrix_to_axis_angle(matrix: torch.Tensor) -> torch.Tensor: +def matrix_to_axis_angle(matrix: torch.Tensor, fast: bool=False) -> torch.Tensor: """ Convert rotations given as rotation matrices to axis/angle. Args: matrix: Rotation matrices as tensor of shape (..., 3, 3). + fast: Whether to use the new faster implementation (based on the + Rodrigues formula) instead of the original implementation (which + first converted to a quaternion and then back to a rotation matrix). Returns: Rotations given as a vector in axis angle form, as a tensor of shape (..., 3), where the magnitude is the angle turned anticlockwise in radians around the vector's direction. + """ + if not fast: + return quaternion_to_axis_angle(matrix_to_quaternion(matrix)) + if matrix.size(-1) != 3 or matrix.size(-2) != 3: raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") @@ -533,7 +544,18 @@ def matrix_to_axis_angle(matrix: torch.Tensor) -> torch.Tensor: torch.isclose(angles, torch.zeros_like(angles)), zeros, omegas ) - return 0.5 * omegas / torch.sinc(angles/torch.pi) + near_pi = torch.isclose( + ((traces - 1) / 2).abs(), torch.ones_like(traces) + ).squeeze(-1) + axis_angles = torch.empty_like(omegas) + axis_angles[~near_pi] = ( + 0.5 * omegas[~near_pi] / torch.sinc(angles[~near_pi] / torch.pi) + ) + axis_angles[near_pi] = ( + quaternion_to_axis_angle(matrix_to_quaternion(matrix[near_pi])) + ) + + return axis_angles def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor: @@ -550,22 +572,11 @@ def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor: quaternions with real part first, as tensor of shape (..., 4). """ angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) - half_angles = angles * 0.5 - eps = 1e-6 - small_angles = angles.abs() < eps - sin_half_angles_over_angles = torch.empty_like(angles) - sin_half_angles_over_angles[~small_angles] = ( - torch.sin(half_angles[~small_angles]) / angles[~small_angles] - ) - # for x small, sin(x/2) is about x/2 - (x/2)^3/6 - # so sin(x/2)/x is about 1/2 - (x*x)/48 - sin_half_angles_over_angles[small_angles] = ( - 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + sin_half_angles_over_angles = 0.5 * torch.sinc(0.5 * angles / torch.pi) + return torch.cat( + [torch.cos(0.5 * angles), axis_angle * sin_half_angles_over_angles], + dim=-1 ) - quaternions = torch.cat( - [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1 - ) - return quaternions def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor: @@ -584,18 +595,9 @@ def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor: """ norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True) half_angles = torch.atan2(norms, quaternions[..., :1]) - angles = 2 * half_angles - eps = 1e-6 - small_angles = angles.abs() < eps - sin_half_angles_over_angles = torch.empty_like(angles) - sin_half_angles_over_angles[~small_angles] = ( - torch.sin(half_angles[~small_angles]) / angles[~small_angles] - ) - # for x small, sin(x/2) is about x/2 - (x/2)^3/6 - # so sin(x/2)/x is about 1/2 - (x*x)/48 - sin_half_angles_over_angles[small_angles] = ( - 0.5 - (angles[small_angles] * angles[small_angles]) / 48 - ) + sin_half_angles_over_angles = 0.5 * torch.sinc(half_angles / torch.pi) + # angles/2 are between [-pi/2, pi/2], thus sin_half_angles_over_angles + # can't be zero return quaternions[..., 1:] / sin_half_angles_over_angles From 866ec9772d841e92372ad73486fce742597695f4 Mon Sep 17 00:00:00 2001 From: alex-bene Date: Wed, 5 Feb 2025 17:57:03 +0200 Subject: [PATCH 3/3] =?UTF-8?q?make=20`matrix=5Fto=5Faxis=5Fangle`=20stabl?= =?UTF-8?q?e=20near=20"k=CF=80"=20and=20faster?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pytorch3d/transforms/rotation_conversions.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/pytorch3d/transforms/rotation_conversions.py b/pytorch3d/transforms/rotation_conversions.py index 893365f52..98e66c755 100644 --- a/pytorch3d/transforms/rotation_conversions.py +++ b/pytorch3d/transforms/rotation_conversions.py @@ -547,13 +547,18 @@ def matrix_to_axis_angle(matrix: torch.Tensor, fast: bool=False) -> torch.Tensor near_pi = torch.isclose( ((traces - 1) / 2).abs(), torch.ones_like(traces) ).squeeze(-1) + axis_angles = torch.empty_like(omegas) - axis_angles[~near_pi] = ( - 0.5 * omegas[~near_pi] / torch.sinc(angles[~near_pi] / torch.pi) + axis_angles[~near_pi] = 0.5 * omegas[~near_pi] / torch.sinc( + angles[~near_pi] / torch.pi ) - axis_angles[near_pi] = ( - quaternion_to_axis_angle(matrix_to_quaternion(matrix[near_pi])) + + # this derives from: nnT = (R + 1) / 2 + n = 0.5 * ( + matrix[near_pi][..., 0, :] + + torch.eye(1, 3, dtype=matrix.dtype, device=matrix.device) ) + axis_angles[near_pi] = angles[near_pi] * n / torch.norm(n) return axis_angles