Source code for cyto_dl.utils.checkpoint
import torch
[docs]def load_checkpoint(model, load_params):
if load_params.get("weights_only"):
assert load_params.get(
"ckpt_path"
), "ckpt_path must be provided to with argument weights_only=True"
# load model from state dict to get around trainer.max_epochs limit, useful for resuming model training from existing weights
state_dict = torch.load(load_params["ckpt_path"], map_location="cpu")[
"state_dict"
] # nosec B614
model.load_state_dict(state_dict, strict=load_params.get("strict", True))
# set ckpt_path to None to avoid loading checkpoint again with model.fit/model.test
load_params["ckpt_path"] = None
elif not load_params.get("strict"):
raise ValueError("To use `strict=False`, `weights_only` must be set to True")
return model, load_params