Source code for cyto_dl.callbacks.layer_freeze

from typing import List, Optional, Union

import torch
from lightning.pytorch.callbacks import Callback


class LayerFreeze(Callback):
    """Callback that freezes parameters of selected submodules by name.

    When ``setup`` runs, each attribute of the LightningModule listed in
    ``modules`` that is a ``torch.nn.Module`` has the ``requires_grad`` flag of
    every one of its parameters rewritten: a parameter is frozen
    (``requires_grad = False``) when its name contains at least one of the
    ``contains`` substrings and none of the ``excludes`` substrings; every
    other parameter in that module gets ``requires_grad = True``.
    """

    def __init__(
        self,
        modules: Optional[Union[str, List[str]]] = None,
        contains: Optional[Union[str, List[str]]] = None,
        excludes: Optional[Union[str, List[str]]] = None,
    ):
        """
        Parameters
        ----------
        modules: Optional[Union[str, List[str]]]
            Name, or list of names, of attributes of the LightningModule to
            search within. ``None`` (the default) means no modules.
        contains: Optional[Union[str, List[str]]]
            List of strings that must be contained in the layer name to freeze
        excludes: Optional[Union[str, List[str]]]
            List of strings that must not be contained in the layer name to freeze
        """
        assert (
            contains is not None or excludes is not None
        ), "One of `contains` or `excludes` must be provided"
        super().__init__()

        # A bare string is wrapped so that iterating `self.modules` yields
        # attribute names rather than single characters, and list inputs are
        # copied so the instance never aliases a caller-owned (or default) list.
        self.modules = self._as_list(modules)
        self.contains = self._as_list(contains)
        self.excludes = self._as_list(excludes)

    @staticmethod
    def _as_list(value) -> List[str]:
        """Normalise ``None`` / a single string / an iterable of strings to a new list."""
        if not value:
            # Covers None, "" and empty sequences.
            return []
        if isinstance(value, str):
            return [value]
        return list(value)

    def _filter(self, mod):
        """Set ``requires_grad`` on every parameter of ``mod`` based on its name.

        Parameters
        ----------
        mod: torch.nn.Module
            Module whose ``named_parameters()`` are visited.
        """
        for name, param in mod.named_parameters():
            matched = any(token in name for token in self.contains)
            excluded = any(token in name for token in self.excludes)
            # An `excludes` hit always wins over a `contains` hit.
            # NOTE(review): with only `excludes` given, `contains` is empty, so
            # nothing is ever frozen and every parameter is set trainable --
            # confirm this is the intended meaning of an excludes-only config.
            requires_grad = excluded or not matched
            if not requires_grad:
                print(f"\tFreezing layer {name}")
            # Assigned unconditionally: parameters that do not match are
            # explicitly marked trainable, not left at their previous value.
            param.requires_grad = requires_grad

    def setup(self, trainer, pl_module, stage):
        """Apply the freeze rules to each configured submodule of ``pl_module``.

        Attributes named in ``self.modules`` that are not ``torch.nn.Module``
        instances are skipped; a name that does not exist on ``pl_module``
        raises ``AttributeError`` from ``getattr``. ``trainer`` and ``stage``
        are not used.
        """
        for module in self.modules:
            submodule = getattr(pl_module, module)
            if isinstance(submodule, torch.nn.Module):
                print(f"Freezing layers in {module}")
                self._filter(submodule)