Source code for cyto_dl.models.im2im.utils.postprocessing.dict_to_im
from typing import Dict
import numpy
import numpy as np
import torch
[docs]def detach(img: torch.Tensor) -> np.ndarray:
"""Convert CUDA tensor to numpy array on cpu."""
img = img.detach().cpu()
if img.dtype == torch.bfloat16:
img = img.half()
img = img.numpy()
return img
[docs]class DictToIm:
"""Convert dictionary with image values to multichannel image."""
def __init__(self, keys, allow_missing_keys: bool = False):
"""
Parameters
----------
keys: Union[str, List[str]]
keys from dictionary to concatenate into multichannel image
allow_missing_keys: bool = False
whether to raise error if specified key is missing
"""
self.keys = keys
self.allow_missing_keys = allow_missing_keys
def __call__(self, input_dict: Dict[str, torch.Tensor]) -> np.ndarray:
output_img = []
for key in self.keys:
if key in input_dict:
im = detach(input_dict[key]).astype(np.uint8)
output_img.append(im)
elif not self.allow_missing_keys:
raise KeyError(
f"key `{key}` not available. Available keys are {input_dict.keys()}"
)
return np.stack(output_img)