import sys
import warnings
from pathlib import Path
from typing import Dict, List, Union
import torch
import torch.nn as nn
from monai.data.meta_tensor import MetaTensor
from monai.inferers import sliding_window_inference
from torchmetrics import MeanMetric
from cyto_dl.models.base_model import BaseModel
warnings.simplefilter("once", UserWarning)
[docs]class MultiTaskIm2Im(BaseModel):
def __init__(
self,
*,
backbone: nn.Module,
task_heads: Dict,
x_key: str,
save_dir="./",
save_images_every_n_epochs=1,
inference_args: Dict = {},
inference_heads: Union[List, None] = None,
compile=False,
**base_kwargs,
):
"""
Parameters
----------
backbone: nn.Module
backbone network, parameters are shared between task heads
task_heads: Dict
task-specific heads
x_key: str
key of input image in batch
save_dir="./"
directory to save images during training and validation
save_images_every_n_epochs=1
Frequency to save out images during training
inference_args: Dict = {}
Arguments passed to monai's [sliding window inferer](https://docs.monai.io/en/stable/inferers.html#sliding-window-inference)
inference_heads: Union[List, None] = None
Optional list of heads to run during inference. Defaults to running all heads.
compile: False
Whether to compile the model using torch.compile
**base_kwargs:
Additional arguments passed to BaseModel
"""
_DEFAULT_METRICS = {
"train/loss": MeanMetric(),
"val/loss": MeanMetric(),
"test/loss": MeanMetric(),
}
for head in task_heads.keys():
assert head not in (
"loss",
"pred",
"target",
"input",
), "Task head name cannot be 'loss', 'pred', 'target', or 'input'"
_DEFAULT_METRICS.update(
{
f"train/loss/{head}": MeanMetric(),
f"val/loss/{head}": MeanMetric(),
f"test/loss/{head}": MeanMetric(),
}
)
metrics = base_kwargs.pop("metrics", _DEFAULT_METRICS)
super().__init__(metrics=metrics, **base_kwargs)
self.automatic_optimization = True
if compile is True and not sys.platform.startswith("win"):
self.backbone = torch.compile(backbone)
self.task_heads = torch.nn.ModuleDict(
{k: torch.compile(v) for k, v in task_heads.items()}
)
else:
self.backbone = backbone
self.task_heads = torch.nn.ModuleDict(task_heads)
self.inference_heads = inference_heads or list(self.task_heads.keys())
for k, head in self.task_heads.items():
head.update_params({"head_name": k, "x_key": x_key, "save_dir": save_dir})
def _train_forward(self, batch, stage, n_postprocess, run_heads):
"""During training we are only dealing with patches,so we can calculate per-patch loss,
metrics, postprocessing etc."""
z = self.backbone(batch[self.hparams.x_key])
return {
task: self.task_heads[task].run_head(z, batch, stage, n_postprocess)
for task in run_heads
}
[docs] def forward(self, x, run_heads):
z = self.backbone(x)
return {task: self.task_heads[task](z) for task in run_heads}
def _inference_forward(self, batch, stage, n_postprocess, run_heads):
"""During inference, we need to calculate per-fov loss/metrics/postprocessing.
To avoid storing and passing to each head the intermediate results of the backbone, we need
to run backbone + taskheads patch by patch, then do saving/postprocessing/etc on the entire
fov.
"""
with torch.no_grad():
raw_pred_images = sliding_window_inference(
inputs=batch[self.hparams.x_key],
predictor=self.forward,
run_heads=run_heads,
**self.hparams.inference_args,
)
return {
head_name: self.task_heads[head_name].run_head(
None,
batch,
stage,
n_postprocess,
run_forward=False,
y_hat=raw_pred_images[head_name],
)
for head_name in run_heads
}
[docs] def run_forward(self, batch, stage, n_postprocess, run_heads):
if stage in ("train", "val"):
return self._train_forward(batch, stage, n_postprocess, run_heads)
return self._inference_forward(batch, stage, n_postprocess, run_heads)
[docs] def get_n_postprocess_image(self, batch, batch_idx, stage):
# save first batch every n epochs during train/val
if (
stage in ("train", "val")
and batch_idx
== (self.current_epoch + 1) % self.hparams.save_images_every_n_epochs
== 0
):
return 1
# postprocess all images in batch for predict/test
elif stage in ("predict", "test"):
return batch[self.hparams.x_key].shape[0]
return 0
def _sum_losses(self, losses):
losses["loss"] = torch.sum(torch.stack(list(losses.values())))
return losses
def _get_unrun_heads(self, io_map):
"""Returns heads that don't have outputs yet."""
updated_run_heads = []
# check that all output files exist for each head
for head, head_io_map in io_map.items():
for fn in head_io_map["output"]:
if not Path(fn).exists():
updated_run_heads.append(head)
break
return updated_run_heads
def _combine_io_maps(self, io_maps):
"""Aggregate io_maps from per-head to per-input image."""
io_map = {}
# create input-> per head output mapping
for head, head_io_map in io_maps.items():
for in_file, out_file in zip(head_io_map["input"], head_io_map["output"]):
if in_file not in io_map:
io_map[in_file] = {}
io_map[in_file][head] = out_file
return io_map
def _get_run_heads(self, batch, stage, batch_idx):
"""Get heads that are either specified as inference heads and don't have outputs yet or all
heads."""
run_heads = self.inference_heads
if stage in ("train", "val", "test"):
run_heads = [key for key in self.task_heads.keys() if key in batch]
return run_heads, None
filenames = batch[self.hparams.x_key].meta.get("filename_or_obj", None)
if filenames is None:
warnings.warn(
'Batch MetaTensors must have "filename_or_obj" to be saved out. Returning array prediction instead...',
UserWarning,
)
return run_heads, None
# IO_map is only generated for prediction
io_map = {h: self.task_heads[h].generate_io_map(filenames) for h in run_heads}
# only run heads that don't have outputs yet for prediction
run_heads = self._get_unrun_heads(io_map)
io_map = self._combine_io_maps(io_map)
return run_heads, io_map
def _to_tensor(self, batch):
"""Convert monai metatensors to tensors."""
for k, v in batch.items():
if isinstance(v, MetaTensor):
batch[k] = v.as_tensor()
return batch
[docs] def get_per_head(self, outs, key):
return {head_name: head_result[key] for head_name, head_result in outs.items()}
[docs] def model_step(self, stage, batch, batch_idx):
run_heads, _ = self._get_run_heads(batch, stage, batch_idx)
n_postprocess = self.get_n_postprocess_image(batch, batch_idx, stage)
batch = self._to_tensor(batch)
outs = self.run_forward(batch, stage, n_postprocess, run_heads)
# aggregate losses across heads
results = self.get_per_head(outs, "loss")
results = self._sum_losses(results)
if n_postprocess > 0:
# add postprocessed images to return dict
for k in ("pred", "target", "input"):
results[k] = self.get_per_head(outs, k)
self.compute_metrics(results, None, None, stage)
return results
[docs] def training_step(self, batch, batch_idx):
return self.model_step("train", batch, batch_idx)
[docs] def validation_step(self, batch, batch_idx):
return self.model_step("val", batch, batch_idx)
[docs] def test_step(self, batch, batch_idx):
return self.model_step("test", batch, batch_idx)
[docs] def predict_step(self, batch, batch_idx):
stage = "predict"
run_heads, io_map = self._get_run_heads(batch, stage, batch_idx)
outs = None
if len(run_heads) > 0:
n_postprocess = self.get_n_postprocess_image(batch, batch_idx, stage)
batch = self._to_tensor(batch)
outs = self.run_forward(batch, stage, n_postprocess, run_heads)
return io_map, outs