Source code for cyto_dl.nn.head.mae_head

from cyto_dl.nn.head import BaseHead


class MAEHead(BaseHead):
    """Head that scores a reconstruction against ``batch[self.head_name]`` with squared error."""

    def run_head(
        self,
        backbone_features,
        batch,
        stage,
        n_postprocess=1,
        run_forward=True,
        y_hat=None,
    ):
        """Compute the reconstruction loss and postprocess the images for one batch.

        Args:
            backbone_features: Two-item iterable unpacked as ``(y_hat, mask)``.
            batch: Mapping indexed with ``self.head_name`` (the target) and
                ``self.x_key`` (the input).
            stage: Name of the current stage; when it equals ``"predict"`` the
                ``"target"`` and ``"input"`` entries of the result are ``None``.
            n_postprocess: Forwarded unchanged to ``self._postprocess``.
            run_forward: Must be truthy; see Raises.
            y_hat: Never read. It is replaced by the first item of
                ``backbone_features``.

        Returns:
            A dict with the keys ``"loss"``, ``"pred"``, ``"target"`` and ``"input"``.

        Raises:
            ValueError: If ``run_forward`` is falsy.
        """
        if not run_forward:
            raise ValueError("MAE head is only intended for use during training.")
        y_hat, mask = backbone_features

        target = batch[self.head_name]
        squared_error = (target - y_hat) ** 2
        # Average only over positions selected by the mask; if the mask selects
        # nothing, fall back to the mean over every element.
        # NOTE(review): presumably the mask flags the patches hidden from the
        # encoder - confirm against the backbone that produces it.
        if mask.sum() > 0:
            loss = squared_error[mask.bool()].mean()
        else:
            loss = squared_error.mean()

        pred_img = self._postprocess(y_hat, img_type="prediction", n_postprocess=n_postprocess)

        # NOTE(review): the loss above reads batch[self.head_name] for every stage,
        # including "predict"; only the postprocessed copies are skipped there.
        if stage != "predict":
            target_img = self._postprocess(
                target, img_type="input", n_postprocess=n_postprocess
            )
            input_img = self._postprocess(
                batch[self.x_key], img_type="input", n_postprocess=n_postprocess
            )
        else:
            target_img = None
            input_img = None

        return {
            "loss": loss,
            "pred": pred_img,
            "target": target_img,
            "input": input_img,
        }