Source code for fnet.transforms

import logging
from typing import Optional

import numpy as np
import scipy
import scipy.ndimage


logger = logging.getLogger(__name__)


class Normalize:
    """Callable wrapper around the module-level ``normalize`` function.

    Parameters:
        per_dim: forwarded unchanged to ``normalize`` on every call.
    """

    def __init__(self, per_dim=None):
        self.per_dim = per_dim

    def __call__(self, x):
        # Delegate entirely; this class only stores the configuration.
        return normalize(x, per_dim=self.per_dim)

    def __repr__(self):
        return f"Normalize({self.per_dim})"
class ToFloat:
    """Callable that returns its input array cast to ``np.float32``.

    The input must provide an ``astype`` method (e.g. a numpy array).
    """

    def __call__(self, x):
        return x.astype(dtype=np.float32)

    def __repr__(self):
        return "ToFloat()"
def normalize(img, per_dim=None):
    """Subtract the mean and scale to unit standard deviation.

    Parameters:
        img: numpy array to normalize. It is not modified; a float32 copy is
            returned.
        per_dim: if None, the mean/std are computed over the whole array.
            Otherwise each index along axis ``per_dim`` is normalized
            independently, using statistics over all other axes. Negative
            values count from the last axis, as in numpy.

    Returns:
        float32 array with the same shape as ``img``. A constant region
        (std == 0) yields NaN/inf values, as division by zero is not guarded.
    """
    if per_dim is not None and per_dim < 0:
        # Without this, a negative axis never equals any index below and the
        # array would silently be normalized globally.
        per_dim += img.ndim
    axis = tuple(i for i in range(img.ndim) if i != per_dim)
    result = img.astype(np.float32)
    # keepdims=True keeps reduced axes as size-1 so the statistics broadcast
    # back against `result`.
    result -= np.mean(result, axis=axis, keepdims=True)
    result /= np.std(result, axis=axis, keepdims=True)
    return result
def do_nothing(img):
    """Return a double-precision (``np.float64``) copy of ``img``.

    Despite the name this is not an identity transform: it casts the input.
    ``np.float`` (an alias of the builtin ``float``, i.e. float64) was
    removed in NumPy 1.24, so ``np.float64`` is used explicitly.
    """
    return img.astype(np.float64)
class Propper:
    """Padder + Cropper: delegate to one of them, chosen by ``action``.

    Parameters:
        action: "+" or "pad" builds a ``Padder``; "-" or "crop" builds a
            ``Cropper``. Any other value raises NotImplementedError.
        **kwargs: forwarded to the chosen class's constructor.
    """

    def __init__(self, action="-", **kwargs):
        self.action = action
        if action in ("+", "pad"):
            transformer_cls = Padder
        elif action in ("-", "crop"):
            transformer_cls = Cropper
        else:
            raise NotImplementedError
        self.transformer = transformer_cls(**kwargs)

    def __repr__(self):
        return repr(self.transformer)

    def __call__(self, x_in):
        return self.transformer(x_in)

    def undo_last(self, x_in):
        """Invert the most recent call, delegating to the wrapped transformer."""
        return self.transformer.undo_last(x_in)
class Padder(object):
    """Pad arrays with ``numpy.pad``, remembering the padding so it can be undone.

    Parameters:
        padding: '+', int, or sequence.
            '+': pad each dimension up to the next multiple of ``by``; the
                extra pixel (if the total is odd) goes on the right.
            int: pad both sides of every dimension by this value.
            sequence: one entry ('+' or int) per dimension.
        by: multiple used with the '+' padding option.
        mode: str passed to the ``numpy.pad`` function.
    """

    def __init__(self, padding="+", by=16, mode="constant"):
        self.padding = padding
        self.by = by
        self.mode = mode
        # Cache: input shape -> list of (before, after) pad widths.
        self.pads = {}
        # Details of the most recent __call__, used by undo_last.
        self.last_pad = None

    def __repr__(self):
        return "Padder{}".format((self.padding, self.by, self.mode))

    def _calc_pad_width(self, shape_in):
        """Return a list with one ``(before, after)`` pad amount per dimension."""
        if isinstance(self.padding, (str, int, np.integer)):
            paddings = (self.padding,) * len(shape_in)
        else:
            paddings = self.padding
        pad_width = []
        for i, size in enumerate(shape_in):
            padding = paddings[i]
            if isinstance(padding, (int, np.integer)):
                pad_width.append((int(padding),) * 2)
            elif padding == "+":
                padding_total = int(np.ceil(1.0 * size / self.by) * self.by) - size
                pad_left = padding_total // 2
                pad_right = padding_total - pad_left
                pad_width.append((pad_left, pad_right))
            else:
                # Previously such values were silently skipped and only
                # caught later by an assert on the list length.
                raise NotImplementedError(f"Unsupported padding: {padding!r}")
        return pad_width

    def undo_last(self, x_in):
        """Crops input so its dimensions matches dimensions of last input to __call__."""
        assert x_in.shape == self.last_pad["shape_out"]
        # Explicit end indices: slice(a, -b) would be empty when b == 0.
        # A tuple is required: NumPy >= 1.23 rejects a list of slices as an index.
        slices = tuple(
            slice(before, size - after)
            for (before, after), size in zip(self.last_pad["pad_width"], x_in.shape)
        )
        return x_in[slices].copy()

    def __call__(self, x_in):
        shape_in = x_in.shape
        pad_width = self.pads.get(shape_in)
        if pad_width is None:
            pad_width = self._calc_pad_width(shape_in)
        x_out = np.pad(x_in, pad_width, mode=self.mode)
        # Cache only after np.pad succeeded, as the original did.
        self.pads.setdefault(shape_in, pad_width)
        self.last_pad = {
            "shape_in": shape_in,
            "pad_width": pad_width,
            "shape_out": x_out.shape,
        }
        return x_out
class Cropper(object):
    """Crop input array to given shape, remembering the crop so it can be undone.

    Parameters:
        cropping: '-', int, None, or a sequence of these (one per dimension).
            '-': crop the dimension down to a multiple of ``by``.
            int: reduce the dimension's size by this many pixels.
            None: leave the dimension uncropped.
        by: multiple used with '-'. It is also the step by which dimensions
            are shrunk when enforcing ``n_max_pixels``.
        offset: 'mid', int, or a sequence of these. 'mid' centers the crop;
            an int is the start index of the crop in that dimension.
        n_max_pixels: if not None, the last two dimensions are alternately
            shrunk by ``by`` until the crop has at most this many elements.
        dims_no_crop: int or sequence of dimension indices that are never cropped.
    """

    def __init__(
        self, cropping="-", by=16, offset="mid", n_max_pixels=9732096, dims_no_crop=None
    ):
        self.cropping = cropping
        self.offset = offset
        self.by = by
        self.n_max_pixels = n_max_pixels
        self.dims_no_crop = (
            [dims_no_crop] if isinstance(dims_no_crop, (int, np.integer)) else dims_no_crop
        )
        # Cache keyed by input shape. Each entry holds "shape_crop",
        # "offsets_crop" and "slices" for that shape.
        self.crops = {}
        # Details of the most recent __call__, used by undo_last.
        self.last_crop = None

    def __repr__(self):
        return "Cropper{}".format(
            (self.cropping, self.by, self.offset, self.n_max_pixels, self.dims_no_crop)
        )

    @staticmethod
    def _per_dim(value, ndim):
        """Broadcast a scalar setting (None, str or integer) to one entry per dimension."""
        if value is None or isinstance(value, (str, int, np.integer)):
            return (value,) * ndim
        return value

    def _adjust_shape_crop(self, shape_crop):
        """Shrink ``shape_crop`` by ``by`` until its product is <= ``n_max_pixels``.

        Reductions alternate between the last two dimensions (or the only one
        for 1-D input). Dimension 0 is not shrunk once it is <= 64, and no
        dimension is shrunk below 0.

        Raises:
            ValueError: if the limit cannot be reached because no dimension
                can be shrunk further. The original code looped forever or
                produced negative sizes in this case.
        """
        shape_crop_new = list(shape_crop)
        order_dim_reduce = list(range(len(shape_crop)))[-2:]
        while np.prod(shape_crop_new) > self.n_max_pixels:
            reduced = False
            for dim in order_dim_reduce:
                if np.prod(shape_crop_new) <= self.n_max_pixels:
                    break
                size = shape_crop_new[dim]
                if (dim == 0 and size <= 64) or self.by <= 0 or size - self.by < 0:
                    continue
                shape_crop_new[dim] = size - self.by
                reduced = True
            if not reduced:
                raise ValueError(
                    f"Cannot reduce crop shape {tuple(shape_crop)} to at most "
                    f"{self.n_max_pixels} pixels"
                )
        return tuple(shape_crop_new)

    def _calc_shape_crop(self, shape_in):
        croppings = self._per_dim(self.cropping, len(shape_in))
        no_crop = self.dims_no_crop if self.dims_no_crop is not None else ()
        shape_crop = []
        for i, size in enumerate(shape_in):
            cropping = croppings[i]
            if cropping is None or i in no_crop:
                shape_crop.append(size)
            elif isinstance(cropping, (int, np.integer)):
                shape_crop.append(size - cropping)
            elif cropping == "-":
                shape_crop.append(size // self.by * self.by)
            else:
                raise NotImplementedError
        if self.n_max_pixels is not None:
            shape_crop = self._adjust_shape_crop(shape_crop)
        self.crops[shape_in]["shape_crop"] = shape_crop
        return shape_crop

    def _calc_offsets_crop(self, shape_in, shape_crop):
        offsets = self._per_dim(self.offset, len(shape_in))
        offsets_crop = []
        for i, (size, size_crop) in enumerate(zip(shape_in, shape_crop)):
            offset = (size - size_crop) // 2 if offsets[i] == "mid" else offsets[i]
            # A negative start would silently wrap around when slicing.
            if offset < 0 or offset + size_crop > size:
                logger.error(
                    f"Cannot crop outsize image dimensions ({offset}:{offset + shape_crop[i]} for dim {i})"
                )
                raise AttributeError
            offsets_crop.append(offset)
        self.crops[shape_in]["offsets_crop"] = offsets_crop
        return offsets_crop

    def _calc_slices(self, shape_in):
        shape_crop = self._calc_shape_crop(shape_in)
        offsets_crop = self._calc_offsets_crop(shape_in, shape_crop)
        # A tuple is required: NumPy >= 1.23 rejects a list of slices as an index.
        slices = tuple(
            slice(start, start + length)
            for start, length in zip(offsets_crop, shape_crop)
        )
        self.crops[shape_in]["slices"] = slices
        return slices

    def __call__(self, x_in):
        shape_in = x_in.shape
        if shape_in in self.crops:
            slices = self.crops[shape_in]["slices"]
        else:
            self.crops[shape_in] = {}
            try:
                slices = self._calc_slices(shape_in)
            except Exception:
                # Drop the partial entry so later calls don't hit a KeyError.
                del self.crops[shape_in]
                raise
        x_out = x_in[tuple(slices)].copy()
        self.last_crop = {
            "shape_in": shape_in,
            "slices": slices,
            "shape_out": x_out.shape,
        }
        return x_out

    def undo_last(self, x_in):
        """Pads input with zeros so its dimensions matches dimensions of last input to __call__."""
        assert x_in.shape == self.last_crop["shape_out"]
        shape_out = self.last_crop["shape_in"]
        slices = tuple(self.last_crop["slices"])
        x_out = np.zeros(shape_out, dtype=x_in.dtype)
        x_out[slices] = x_in
        return x_out
class Resizer(object):
    """Resize arrays with ``scipy.ndimage.zoom`` (``mode="nearest"``).

    Parameters:
        factors: zoom factor(s) passed to ``scipy.ndimage.zoom``. When
            ``per_dim`` is None they apply to the whole array. Otherwise they
            apply to each sub-array obtained by indexing along ``per_dim``,
            so they should not include a factor for that axis.
        per_dim: if not None, the array is split along this axis, each slice
            is zoomed independently, and the results are restacked along the
            same axis. The length of that axis is unchanged. Negative values
            count from the last axis.
    """

    def __init__(self, factors, per_dim=None):
        self.factors = factors
        self.per_dim = per_dim

    def __call__(self, x):
        if self.per_dim is None:
            return scipy.ndimage.zoom(x, self.factors, mode="nearest")
        # Iterating the moved-to-front axis yields x[..., idx, ...] slices and
        # handles negative per_dim correctly.
        ars_resized = [
            scipy.ndimage.zoom(sub, self.factors, mode="nearest")
            for sub in np.moveaxis(x, self.per_dim, 0)
        ]
        return np.stack(ars_resized, axis=self.per_dim)

    def __repr__(self):
        return "Resizer({:s}, {})".format(str(self.factors), self.per_dim)
class Capper:
    """Clamp array values, returning a copy.

    Values above ``hi`` are replaced by ``hi`` first, then values below
    ``low`` by ``low``. Either bound may be None to skip it. The input array
    is not modified.
    """

    def __init__(self, low=None, hi=None):
        self._low = low
        self._hi = hi

    def __call__(self, ar):
        capped = ar.copy()
        upper, lower = self._hi, self._low
        # Order matters when low > hi: the lower bound is applied last.
        if upper is not None:
            capped[capped > upper] = upper
        if lower is not None:
            capped[capped < lower] = lower
        return capped

    def __repr__(self):
        return f"Capper({self._low}, {self._hi})"
def flip_y(ar: np.ndarray) -> np.ndarray:
    """Reverse ``ar`` along its second-to-last axis (Y, for arrays ending in YX).

    Parameters
    ----------
    ar
        Array with at least two dimensions.

    Returns
    -------
    np.ndarray
        Flipped array, as returned by ``np.flip``.
    """
    y_axis = -2
    return np.flip(ar, y_axis)
def flip_x(ar: np.ndarray) -> np.ndarray:
    """Reverse ``ar`` along its last axis (X, for arrays ending in YX).

    Parameters
    ----------
    ar
        Array with at least one dimension.

    Returns
    -------
    np.ndarray
        Flipped array, as returned by ``np.flip``.
    """
    x_axis = -1
    return np.flip(ar, x_axis)
def norm_around_center(ar: np.ndarray, z_center: Optional[int] = None) -> np.ndarray:
    """Returns normalized version of input array.

    The mean and standard deviation are computed over the sub-array of 32
    consecutive z-slices (first axis) centered on ``z_center``. The window is
    shifted, with a logged warning, to stay inside the array. The whole
    array is then shifted and scaled by those statistics.

    Parameters
    ----------
    ar
        Input 3d array to be normalized (z is the first axis).
    z_center
        Z-index of cell centers. Defaults to the middle z-slice.

    Returns
    -------
    np.ndarray
        Normalized array, dtype = float32.

    Raises
    ------
    ValueError
        If ``ar`` is not 3d or is shorter than 32 along the first axis.
    """
    if ar.ndim != 3:
        raise ValueError("Input array must be 3d")
    chunk_zlen = 32
    if ar.shape[0] < chunk_zlen:
        raise ValueError("Input array must be at least length 32 in first dimension")
    if z_center is None:
        z_center = ar.shape[0] // 2
    z_start = z_center - chunk_zlen // 2
    if z_start < 0:
        z_start = 0
        # Logger.warn is a deprecated alias of Logger.warning.
        logger.warning("Warning: z_start set to %s", z_start)
    if (z_start + chunk_zlen) > ar.shape[0]:
        z_start = ar.shape[0] - chunk_zlen
        logger.warning("Warning: z_start set to %s", z_start)
    chunk = ar[z_start : z_start + chunk_zlen, :, :]
    ar = ar - chunk.mean()
    ar = ar / chunk.std()
    return ar.astype(np.float32)