# Source code for fnet.fnet_ensemble

from typing import Union, List
import logging
import os

import numpy as np
import torch

from fnet.fnet_model import Model
from fnet.utils.general_utils import str_to_class

# Module-level logger, named after this module so that log records can be
# filtered per module by the application's logging configuration.
logger = logging.getLogger(__name__)

def _load_model(path_model: str) -> Model:
    """Rebuild a saved model from its checkpoint file.

    Parameters
    ----------
    path_model
        Path to a checkpoint file readable by ``torch.load``.

    Returns
    -------
    Model
        Instance of the class named by the checkpoint's
        ``"fnet_model_class"`` entry, constructed with the checkpoint's
        ``"fnet_model_kwargs"`` and then handed the checkpoint through
        ``load_state(..., no_optim=True)``.

    """
    # NOTE: torch.load deserializes via pickle; only point this at trusted
    # files.
    checkpoint = torch.load(path_model)
    # Read both entries up front so a malformed checkpoint fails before any
    # class lookup is attempted.
    class_name, init_kwargs = (
        checkpoint["fnet_model_class"],
        checkpoint["fnet_model_kwargs"],
    )
    model_cls = str_to_class(class_name)
    instance = model_cls(**init_kwargs)
    instance.load_state(checkpoint, no_optim=True)
    return instance

class FnetEnsemble(Model):
    """Ensemble of FnetModels.

    Member models are not kept in memory: each one is loaded from disk when
    a prediction is requested, and the ensemble output is the mean of the
    members' predictions.

    Parameters
    ----------
    paths_model
        Path to a directory of saved models or a list of paths to saved
        models. When a directory is given, every entry in it whose path ends
        in ".p" (case-insensitive) is used, in sorted order.

    Attributes
    ----------
    paths_model : List[str]
        Paths to saved models in the ensemble.
    gpu_ids : Union[int, List[int]]
        GPU(s) used for prediction tasks. Initialised to ``-1`` and replaced
        by a list once ``to_gpu`` is called.

    """

    def __init__(self, paths_model: Union[str, List[str]]) -> None:
        # NOTE(review): Model.__init__ is not invoked here, presumably
        # because the ensemble owns no network of its own - confirm.
        if isinstance(paths_model, str):
            assert os.path.isdir(paths_model), (
                f"not a directory: {paths_model}"
            )
            paths_model = sorted(
                entry.path
                for entry in os.scandir(os.path.abspath(paths_model))
                if entry.path.lower().endswith(".p")
            )
        assert len(paths_model) > 0, "ensemble needs at least one model"
        self.paths_model = paths_model
        # NOTE(review): -1 presumably selects the CPU; confirm against
        # Model.to_gpu, which receives this value in predict().
        self.gpu_ids = -1

    def __str__(self):
        """Return a header line followed by one line per member model path."""
        lines = [f"{len(self.paths_model)}-model ensemble:"]
        lines.extend(self.paths_model)
        return os.linesep.join(lines)

    def to_gpu(self, gpu_ids: Union[int, list]) -> None:
        """Move network to specified GPU(s).

        Only records the selection; it is applied to each member model when
        that model is loaded in ``predict``.

        Parameters
        ----------
        gpu_ids
            GPU(s) on which to perform training or prediction.

        """
        if isinstance(gpu_ids, int):
            gpu_ids = [gpu_ids]
        self.gpu_ids = gpu_ids

    def predict(
        self, x: Union[torch.Tensor, np.ndarray], tta: bool = False
    ) -> torch.Tensor:
        """Performs model prediction.

        Each member model is loaded from disk in turn, moved to the
        configured GPU(s) and asked for a prediction; the predictions are
        averaged.

        Parameters
        ----------
        x
            Batched input.
        tta
            Set to True to use test-time augmentation.

        Returns
        -------
        torch.Tensor
            Model prediction.

        """
        y_hat_mean = None
        for path_model in self.paths_model:
            model = _load_model(path_model)
            model.to_gpu(self.gpu_ids)
            y_hat = model.predict(x=x, tta=tta)
            if y_hat_mean is None:
                # The accumulator is sized from the first prediction.
                y_hat_mean = torch.zeros(*y_hat.size())
            y_hat_mean += y_hat
        return y_hat_mean / len(self.paths_model)

    # Override
    def save(self, path_save: str):
        """Saves model to disk.

        Only the ensemble's class name and constructor arguments (the member
        model paths) are written; the member models themselves are not
        copied.

        Parameters
        ----------
        path_save
            Filename to which model is saved.

        """
        state = {
            "fnet_model_class": (self.__module__ + "." + self.__class__.__qualname__),
            "fnet_model_kwargs": {"paths_model": self.paths_model},
        }
        dirname = os.path.dirname(path_save)
        # A bare filename has an empty dirname, and os.makedirs("") raises;
        # exist_ok avoids a race between the existence check and creation.
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        torch.save(state, path_save)
        logger.info(f"Ensemble model saved to: {path_save}")

    # Override
    def load_state(self, state: dict, no_optim: bool = False):
        """Do nothing; both arguments are ignored.

        Member models are loaded from ``paths_model`` inside ``predict``.
        """
        return