from typing import List, Optional, Union
import json
import logging
import os

import torch

from fnet.fnet_ensemble import FnetEnsemble
from fnet.fnet_model import Model
from fnet.utils.general_utils import str_to_class

logger = logging.getLogger(__name__)


def _find_model_checkpoint(path_model_dir: str, checkpoint: str) -> str:
    """Finds path to a specific model checkpoint.

    Parameters
    ----------
    path_model_dir
        Path to model as a directory.
    checkpoint
        String that identifies a model checkpoint.

    Returns
    -------
    str
        Path to saved model file.

    """
    path_cp_dir = os.path.join(path_model_dir, "checkpoints")
    if not os.path.exists(path_cp_dir):
        raise ValueError(f"Model ({path_cp_dir}) has no checkpoints")
    paths_cp = sorted(
        [p.path for p in os.scandir(path_cp_dir) if p.path.endswith(".p")]
    )
    for path_cp in paths_cp:
        if checkpoint in os.path.basename(path_cp):
            return path_cp
    raise ValueError(f"Model checkpoint not found: {checkpoint}")
def load_model(
    path_model: str,
    no_optim: bool = False,
    checkpoint: Optional[str] = None,
    path_options: Optional[str] = None,
) -> Model:
"""Loaded saved FnetModel.
Parameters
----------
path_model
Path to model as a directory or .p file.
no_optim
Set to not the model optimizer.
checkpoint
Optional string that identifies a model checkpoint
path_options
Path to training options json. For legacy saved models where the
FnetModel class/kwargs are not not included in the model save file.
Returns
-------
Model
Loaded model.
"""
    if not os.path.exists(path_model):
        raise ValueError(f"Model path does not exist: {path_model}")
    if os.path.isdir(path_model):
        if checkpoint is None:
            path_model = os.path.join(path_model, "model.p")
            if not os.path.exists(path_model):
                raise ValueError(f"Default model not found: {path_model}")
        if checkpoint is not None:
            path_model = _find_model_checkpoint(path_model, checkpoint)
    state = torch.load(path_model)
    if "fnet_model_class" not in state:
        # Legacy saves do not store the model class/kwargs; recover them from
        # the training options json if one was provided.
        if path_options is not None:
            with open(path_options, "r") as fi:
                train_options = json.load(fi)
            if "fnet_model_class" in train_options:
                state["fnet_model_class"] = train_options["fnet_model_class"]
                state["fnet_model_kwargs"] = train_options["fnet_model_kwargs"]
    fnet_model_class = state.get("fnet_model_class", "fnet.models.Model")
    fnet_model_kwargs = state.get("fnet_model_kwargs", {})
    model = str_to_class(fnet_model_class)(**fnet_model_kwargs)
    model.load_state(state, no_optim)
    return model
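# Usage sketch for load_model (paths below are placeholders):
#
#     # Load the default save ("model.p") from a model directory.
#     model = load_model("runs/seg_model")
#
#     # Load a specific checkpoint from that directory by identifier.
#     model = load_model("runs/seg_model", checkpoint="001000")
#
#     # Load a legacy .p file whose save state lacks fnet_model_class/kwargs,
#     # supplying the training options json that holds them.
#     model = load_model("runs/old_model/model.p",
#                        path_options="runs/old_model/train_options.json")
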
def load_or_init_model(path_model: str, path_options: str):
    """Load a saved model if it exists; otherwise initialize a new model.

    Parameters
    ----------
    path_model
        Path to saved model.
    path_options
        Path to json where model training options are saved.

    Returns
    -------
    FnetModel
        Loaded or new FnetModel instance.

    """
    if not os.path.exists(path_model):
        with open(path_options, "r") as fi:
            train_options = json.load(fi)
        logger.info("Initializing new model!")
        fnet_model_class = train_options["fnet_model_class"]
        fnet_model_kwargs = train_options["fnet_model_kwargs"]
        return str_to_class(fnet_model_class)(**fnet_model_kwargs)
    return load_model(path_model, path_options=path_options)
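# Usage sketch for load_or_init_model. If path_model does not exist yet, the
# training options json must provide "fnet_model_class" and
# "fnet_model_kwargs"; the values below are placeholders (the class string
# mirrors the default used in load_model, and the kwargs depend on the model
# class chosen):
#
#     # runs/new_model/train_options.json
#     # {
#     #     "fnet_model_class": "fnet.models.Model",
#     #     "fnet_model_kwargs": {}
#     # }
#
#     model = load_or_init_model("runs/new_model/model.p",
#                                "runs/new_model/train_options.json")
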
def create_ensemble(paths_model: Union[str, List[str]], path_save_dir: str) -> None:
    """Create and save an ensemble model.

    Parameters
    ----------
    paths_model
        Paths to models or model directories. Paths can be specified as items
        in a list or as a string with paths separated by spaces. Any model
        specified as a directory is assumed to be at 'directory/model.p'.
    path_save_dir
        Model save directory. Model will be saved in path_save_dir as
        'model.p'.

    """
    if isinstance(paths_model, str):
        paths_model = paths_model.split(" ")
    paths_member = []
    for path_model in paths_model:
        path_model = os.path.abspath(path_model)
        if os.path.isdir(path_model):
            path_member = os.path.join(path_model, "model.p")
            if os.path.exists(path_member):
                paths_member.append(path_member)
                continue
            paths_member.extend(
                sorted(
                    [p.path for p in os.scandir(path_model) if p.path.endswith(".p")]
                )
            )
        else:
            paths_member.append(path_model)
    path_save = os.path.join(path_save_dir, "model.p")
    ensemble = FnetEnsemble(paths_model=paths_member)
    ensemble.save(path_save)
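# Usage sketch for create_ensemble (paths are placeholders): the ensemble is
# written to <path_save_dir>/model.p.
#
#     create_ensemble(["runs/seed0", "runs/seed1", "runs/seed2"], "runs/ensemble")
#     # Equivalent call with a space-separated string:
#     create_ensemble("runs/seed0 runs/seed1 runs/seed2", "runs/ensemble")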