Source code for cyto_dl.image.transforms.merge
from copy import deepcopy
from typing import Union
from monai.transforms import Transform
[docs]class Merged(Transform):
"""Use mask to merge two images."""
def __init__(
self,
mask_key: str,
image_keys: Union[list, str],
base_image_key: str,
output_name: str,
):
"""
Parameters
----------
mask_key: str
key for mask
image_keys: Union[list, str]
keys to merge
base_image_key: str
name of image to serve as base image
"""
super().__init__()
self.mask_key = mask_key
if len(image_keys) != 2:
raise ValueError(f"image_keys must be a list of length 2. Got {image_keys}")
self.image_keys = list(image_keys)
self.base_image_key = base_image_key
self.output_name = output_name
def __call__(self, input_dict):
"""
Parameters
----------
input_dict: Dict[str, torch.Tensor]
dict of CZYX tensors/metadata/paths
"""
if self.base_image_key not in input_dict:
raise KeyError(
f"key `{self.base_image_key}` not available. Available keys are {input_dict.keys()}"
)
base_image_name = input_dict[self.base_image_key]
if base_image_name not in self.image_keys:
raise KeyError(
f"Base image name `{base_image_name}` must match provided image keys `{self.image_keys}`"
)
if self.mask_key not in input_dict or input_dict[self.mask_key] is None:
# no merging mask, return original dict
input_dict[self.output_name] = deepcopy(input_dict[base_image_name])
# remove mask key if it exists
input_dict.pop(self.mask_key, None)
return input_dict
mask = input_dict[self.mask_key].astype(bool)
# From polygoan loader, 1 is everything outside of the polygon, 0 inside the polygon.
# For merging we want to inver this
mask = ~mask
for key in self.image_keys:
if key not in input_dict.keys():
raise KeyError(
f"key `{key}` not available. Available keys are {input_dict.keys()}"
)
base_image = input_dict[base_image_name]
merge_image = input_dict[
self.image_keys[0] if self.image_keys[1] == base_image_name else self.image_keys[1]
]
input_dict[self.output_name] = (base_image * ~mask) + (merge_image * mask)
return input_dict