Source code for cyto_dl.callbacks.grad_monitor
import numpy as np
import torch
from lightning.pytorch.callbacks import Callback
class GradientLoggingCallback(Callback):
    """Log the average gradient norm of each parameter group at the end of every training epoch."""

    def __init__(self, grouping_level: int = 3):
        """`grouping_level` is the number of leading name components used to group parameters; -1 logs each parameter under its full name."""
        super().__init__()
        self.grouping_level = grouping_level

    def _get_group(self, name):
        # grouping_level == -1 means no grouping: use the full parameter name
        if self.grouping_level == -1:
            return name
        return ".".join(name.split(".")[: self.grouping_level])
    def on_train_epoch_end(self, trainer, pl_module):
        # Initialize a dictionary to store the gradient norms per parameter group
        group_names = {self._get_group(name) for name, _ in pl_module.named_parameters()}
        norms = {name: [] for name in group_names}
        # Iterate over the model parameters
        for name, parameter in pl_module.named_parameters():
            # Only parameters that received a gradient contribute
            if parameter.grad is not None:
                group_name = self._get_group(name)
                # Store the norm in the group's list
                norms[group_name].append(torch.norm(parameter.grad).item())
        # Calculate the average norm for each group
        average_norms = {name: np.mean(norms[name]) for name in norms if len(norms[name]) > 0}
        # Log the average norms
        trainer.logger.log_metrics(average_norms, step=trainer.global_step)
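
A minimal usage sketch: the callback is passed to a Lightning ``Trainer`` via its ``callbacks`` list, and a logger must be configured so ``trainer.logger.log_metrics`` has somewhere to write. The module ``MyLightningModule`` and the data module ``my_datamodule`` below are hypothetical placeholders, not part of cyto_dl.

from lightning.pytorch import Trainer

from cyto_dl.callbacks.grad_monitor import GradientLoggingCallback

# Group gradient norms by the first two components of each parameter name,
# e.g. "encoder.layer1" rather than one metric per tensor.
grad_monitor = GradientLoggingCallback(grouping_level=2)

# MyLightningModule and my_datamodule are hypothetical placeholders.
trainer = Trainer(max_epochs=10, callbacks=[grad_monitor])
trainer.fit(MyLightningModule(), datamodule=my_datamodule)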