Source code for cyto_dl.utils.rotation

import math
from typing import Tuple

import torch
import torch.nn.functional as F


class RotationModule:
    """Rotate a batch of images by a rotation derived from a pose (or given explicitly).

    Parameters
    ----------
    group: str
        Either "so2" or "so3"; selects how `compute_rotation_matrix` turns a
        pose into a rotation matrix.
    spatial_dims: int
        Number of spatial dimensions of the images (2 or 3).
    pad_value: float
        Value used for samples that fall outside of the input image.
    eps: float
        Precision forwarded to the SO(3) rotation matrix computation.

    Raises
    ------
    ValueError
        If `group` is not one of "so2" or "so3".
    """

    def __init__(self, group, spatial_dims, pad_value=0.0, eps=1e-6):
        if group not in ("so2", "so3"):
            raise ValueError(f"Group must be one of ('so2', 'so3'), got {group}")

        self.group = group
        self.spatial_dims = spatial_dims
        self.eps = eps
        self.pad_value = pad_value

    def compute_rotation_matrix(self, pose):
        """Return the (batch of) rotation matrix for `pose`, according to `self.group`."""
        if self.group == "so2":
            return get_rotation_matrix_so2(pose, self.spatial_dims)
        return get_rotation_matrix_so3(pose, eps=self.eps)

    def __call__(self, img, pose, R=None):
        """Resample `img` on a grid transformed by the rotation matrix.

        Parameters
        ----------
        img: torch.Tensor
            Batch of images, passed to `F.affine_grid` / `F.grid_sample`.
        pose: torch.Tensor
            Pose used to compute the rotation matrix; ignored when `R` is given.
        R: torch.Tensor, optional
            Precomputed batch of rotation matrices. It is not modified.

        Returns
        -------
        torch.Tensor
            The resampled images, with out-of-bounds samples equal to
            `self.pad_value`.
        """
        if R is None:
            R = self.compute_rotation_matrix(pose)
        else:
            # The sign corrections below are applied in place; work on a copy
            # so that the caller's matrix is left untouched.
            R = R.clone()

        if self.spatial_dims == 3:
            # assume data comes in -Z -Y X format, to match `escnn`.
            # the rotation matrix comes in XYZ form, so I transform it
            # to X -Y -Z
            R[:, -2:, :] = R[:, -2:, :] * -1
            R[:, :, -2:] = R[:, :, -2:] * -1
        elif self.spatial_dims == 2:
            # affine_grid and grid_sample rotate the grid, not the signal.
            # to rotate the signal by alpha, we want to rotate the grid
            # by -alpha, so we make that correction here.
            R[:, -1, :] = R[:, -1, :] * -1
            R[:, :, -1] = R[:, :, -1] * -1

        # add a displacement vector of zeros to the rotation matrix
        disp = img.new_zeros(len(img), self.spatial_dims, 1)
        A = torch.cat((R, disp), dim=2)

        grid = F.affine_grid(A, img.size(), align_corners=False).type_as(img)

        # grid_sample pads with zeros, so shift the image such that the
        # padding ends up equal to `pad_value` after shifting back.
        y = F.grid_sample(img - self.pad_value, grid, align_corners=False)
        return y + self.pad_value
def get_rotation_matrix_so2(pose, spatial_dims):
    """Build a (batch of) SO(2) rotation matrix, embedded in 2d or 3d.

    For any `spatial_dims` other than 2, a 3x3 matrix is returned in which the
    rotation acts on the first two coordinates and the third is left fixed.

    Parameters
    ----------
    pose: torch.Tensor
        A (batch of) equivariant 2d vector(s), of the form
        (cos(theta), sin(theta)). An input that is not 2-dimensional is
        treated as a single, unbatched pose.

    spatial_dims: int
        Indicates whether it's 2d or 3d

    Returns
    -------
    The (batch of) rotation matrix; unbatched if `pose` was unbatched.
    """
    unbatched = pose.dim() != 2
    batch = pose.unsqueeze(0) if unbatched else pose

    cos = batch[:, 0]
    sin = batch[:, 1]

    if spatial_dims == 2:
        rows = (
            torch.stack((cos, -sin), dim=1),
            torch.stack((sin, cos), dim=1),
        )
    else:
        # constant entries, matching the dtype and device of the pose
        zero = batch.new_zeros(len(batch))
        one = batch.new_ones(len(batch))
        rows = (
            torch.stack((cos, -sin, zero), dim=1),
            torch.stack((sin, cos, zero), dim=1),
            torch.stack((zero, zero, one), dim=1),
        )

    matrix = torch.stack(rows, dim=1)
    return matrix.squeeze(0) if unbatched else matrix
def get_rotation_matrix_so3(z, eps=1e-6):
    """Computes a (batch of) rotation matrix of the SO(3) group.

    The two input vectors are orthonormalized (Gram-Schmidt) and completed
    with their cross product; the three resulting vectors are stacked as the
    rows of the returned matrix.

    Parameters
    ----------
    z: torch.Tensor
        A batch of pairs of equivariant vectors, from which a rotation matrix
        is inferred. Indexed as `z[:, 0, :]` and `z[:, 1, :]`.

    eps: float
        Precision; added to the norms to avoid division by zero.

    Returns
    -------
    The (batch of) rotation matrix
    """
    # produce unit vector
    v1 = z[:, 0, :]
    u1 = v1 / (v1.norm(dim=1, keepdim=True) + eps)

    # produce a second unit vector, orthogonal to the first one
    v2 = z[:, 1, :]
    v2 = v2 - (v2 * u1).sum(1, keepdim=True) * u1
    u2 = v2 / (v2.norm(dim=1, keepdim=True) + eps)

    # produce a third orthogonal vector, as the cross product of the first two.
    # `dim` must be explicit: without it, torch.cross uses the first dimension
    # of size 3, which is the batch dimension when the batch size is 3.
    u3 = torch.cross(u1, u2, dim=1)

    rot = torch.stack([u1, u2, u3], dim=1)
    return rot