import sys
from typing import Dict
import torch
import torch.nn as nn
from monai.inferers import sliding_window_inference
from torchmetrics import MeanMetric
from cyto_dl.models.im2im.multi_task import MultiTaskIm2Im
# Default loss trackers used by GAN when the caller supplies no "metrics": a separate
# MeanMetric instance for the discriminator loss, the generator loss and the overall loss
# in each of the train / val / test stages.
_DEFAULT_METRICS = {
    metric_key: MeanMetric()
    for metric_key in (
        "train/loss/discriminator_loss",
        "val/loss/discriminator_loss",
        "test/loss/discriminator_loss",
        "train/loss/generator_loss",
        "val/loss/generator_loss",
        "test/loss/generator_loss",
        "train/loss",
        "val/loss",
        "test/loss",
    )
}
class GAN(MultiTaskIm2Im):
    """Basic GAN model.

    A single-head image-to-image generator (``backbone`` plus one task head) paired with a
    ``discriminator``. Optimization is manual: during training, ``model_step`` steps the
    generator optimizer and then the discriminator optimizer on every batch.
    """

    def __init__(
        self,
        *,
        backbone: nn.Module,
        task_heads: Dict[str, nn.Module],
        discriminator: nn.Module,
        x_key: str,
        save_dir="./",
        save_images_every_n_epochs=1,
        automatic_optimization: bool = False,
        # NOTE: the `{}` default is kept on purpose. It is never mutated in this class, and
        # `_inference_forward` unpacks `self.hparams.inference_args` with `**`, so a `None`
        # default could break inference if hyperparameters are captured from this signature.
        inference_args: Dict = {},
        compile: bool = False,
        **base_kwargs,
    ):
        """
        Parameters
        ----------
        backbone: nn.Module
            backbone network, parameters are shared between task heads
        task_heads: Dict
            task-specific heads. Exactly one head must be provided.
        discriminator
            discriminator network
        x_key: str
            key of input image in batch
        save_dir="./"
            directory to save images during training and validation
        save_images_every_n_epochs=1
            Frequency to save out images during training. Not read directly in this
            constructor (presumably consumed through the saved hyperparameters).
        automatic_optimization: bool = False
            Accepted for configuration compatibility only; the model always sets
            `self.automatic_optimization = False` because `model_step` steps the optimizers
            manually.
        inference_args: Dict = {}
            Arguments passed to monai's [sliding window inferer](https://docs.monai.io/en/stable/inferers.html#sliding-window-inference)
        compile: bool = False
            Whether to compile the discriminator using torch.compile. Ignored on Windows.
        **base_kwargs:
            Additional arguments passed to the parent class constructor. A `metrics` entry,
            if present, replaces the default loss metrics.

        Raises
        ------
        AssertionError
            If the number of task heads is not exactly one.
        """
        metrics = base_kwargs.pop("metrics", _DEFAULT_METRICS)
        super().__init__(
            metrics=metrics, backbone=backbone, task_heads=task_heads, x_key=x_key, **base_kwargs
        )
        # The generator and discriminator are stepped by hand in `model_step`, so automatic
        # optimization is disabled regardless of the `automatic_optimization` argument.
        self.automatic_optimization = False

        # Only an explicit `True` enables compilation, and never on Windows.
        if compile is True and not sys.platform.startswith("win"):
            self.discriminator = torch.compile(discriminator)
        else:
            self.discriminator = discriminator

        assert len(self.task_heads.keys()) == 1, "Only single-head GANs are supported currently."
        self.inference_heads = list(self.task_heads.keys())
        for k, head in self.task_heads.items():
            head.update_params({"head_name": k, "x_key": x_key, "save_dir": save_dir})

    def _train_forward(self, batch, stage, save_image, run_heads):
        """Run the backbone once and feed its output to every head in `run_heads`.

        During training we are only dealing with patches, so we can calculate per-patch loss,
        metrics, postprocessing etc. The discriminator is handed to each head.

        Returns a dict mapping each task name in `run_heads` to that head's `run_head` result.
        """
        z = self.backbone(batch[self.hparams.x_key])
        return {
            task: self.task_heads[task].run_head(
                z, batch, stage, save_image, discriminator=self.discriminator
            )
            for task in run_heads
        }

    def _inference_forward(self, batch, stage, save_image, run_heads):
        """During inference, we need to calculate per-fov loss/metrics/postprocessing.

        To avoid storing and passing to each head the intermediate results of the backbone, we
        need to run backbone + taskheads patch by patch, then do saving/postprocessing/etc on
        the entire fov.

        The sliding-window prediction runs under `torch.no_grad()`. Each head then receives the
        stitched prediction through `y_hat` with `run_forward=False`; the discriminator is only
        passed to the head when `stage == "test"`.

        Returns a dict mapping each head name in `self.task_heads` to its `run_head` result.
        """
        with torch.no_grad():
            raw_pred_images = sliding_window_inference(
                inputs=batch[self.hparams.x_key],
                predictor=self.forward,
                run_heads=run_heads,
                **self.hparams.inference_args,
            )
        # NOTE(review): this iterates every head in `self.task_heads`, not `run_heads`; the
        # constructor restricts the model to a single head.
        return {
            head_name: head.run_head(
                None,
                batch,
                stage,
                save_image,
                discriminator=self.discriminator if stage == "test" else None,
                run_forward=False,
                y_hat=raw_pred_images[head_name],
            )
            for head_name, head in self.task_heads.items()
        }

    def _extract_loss(self, outs, loss_type):
        """Collect `loss_type` (e.g. "loss_D" or "loss_G") from every head and sum them.

        Each head's value is keyed as "<head_name>_<loss_type>" before being handed to
        `self._sum_losses`, whose result is returned unchanged.
        """
        loss = {
            f"{head_name}_{loss_type}": head_result[loss_type]
            for head_name, head_result in outs.items()
        }
        return self._sum_losses(loss)

    def model_step(self, stage, batch, batch_idx):
        """Run one forward pass and, when `stage == "train"`, one manual optimization step.

        Returns a dict holding the discriminator losses (keys prefixed "discriminator_"), the
        generator losses (keys prefixed "generator_"), a "loss" entry equal to
        "generator_loss", and, when images are postprocessed for this batch, the per-head
        "pred", "target" and "input" entries.
        """
        run_heads, _ = self._get_run_heads(batch, stage, batch_idx)
        n_postprocess = self.get_n_postprocess_image(batch, batch_idx, stage)
        batch = self._to_tensor(batch)
        outs = self.run_forward(batch, stage, n_postprocess, run_heads)

        loss_D = self._extract_loss(outs, "loss_D")
        loss_G = self._extract_loss(outs, "loss_G")

        if stage == "train":
            g_opt, d_opt = self.optimizers()

            # Generator update first, then discriminator update.
            g_opt.zero_grad()
            self.manual_backward(loss_G["loss"])
            g_opt.step()

            d_opt.zero_grad()
            self.manual_backward(loss_D["loss"])
            d_opt.step()

        results = {f"discriminator_{key}": loss for key, loss in loss_D.items()}
        results.update({f"generator_{key}": loss for key, loss in loss_G.items()})
        # The generator loss is reported as the overall "loss".
        results["loss"] = results["generator_loss"]

        if n_postprocess > 0:
            # add postprocessed images to return dict
            for k in ("pred", "target", "input"):
                results[k] = self.get_per_head(outs, k)

        self.compute_metrics(results, None, None, stage)
        return results

    def predict_step(self, batch, batch_idx):
        """Run prediction on a batch.

        Returns a tuple `(io_map, outs)`. `outs` is the result of `run_forward`, or `None`
        when `_get_run_heads` reports no heads to run for this batch.
        """
        stage = "predict"
        run_heads, io_map = self._get_run_heads(batch, stage, batch_idx)
        outs = None
        if len(run_heads) > 0:
            n_postprocess = self.get_n_postprocess_image(batch, batch_idx, stage)
            batch = self._to_tensor(batch)
            outs = self.run_forward(batch, stage, n_postprocess, run_heads)
        return io_map, outs