Example

import torch
import torch.nn as nn

from gradient_metrics import GradientMetricCollector
from gradient_metrics.metrics import Max, Min, PNorm


# Define some model
class MyNet(nn.Module):
    def __init__(self, image_size=32) -> None:
        """This is a model which predicts one of 10 classes.

        Args:
            image_size (int, optional): Input size. Should be a power of 2.
                Defaults to 32.
        """
        super().__init__()

        self.features = nn.Sequential(
            nn.Conv2d(
                in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1
            ),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(
                in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1
            ),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(
                in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1
            ),
            nn.ReLU(),
            nn.Flatten(),
        )

        self.fc = nn.Sequential(
            nn.Linear(in_features=(image_size // 4) ** 2 * 64, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=10),
        )

    def forward(self, x):
        out = self.features(x)
        return self.fc(out)


if __name__ == "__main__":
    # Create some dummy data
    x = torch.randn((10, 3, 32, 32))

    # Initialize the model
    net = MyNet()

    # Initialize the GradientMetricCollector
    # In this example we want to gather metrics from the whole model,
    # the feature extractor and the fully connected part
    # We use Max, Min and PNorm with p=2
    mcollector = GradientMetricCollector(
        [
            # Extract metrics from the whole network
            Max(net),
            Min(net),
            PNorm(net, p=2),
            # Extract metrics from the feature extraction part
            Max(net.features),
            Min(net.features),
            PNorm(net.features, p=2),
            # Extract metrics from the fully connected part
            Max(net.fc),
            Min(net.fc),
            PNorm(net.fc, p=2),
        ]
    )

    # predict the dummy data
    out = net(x)

    # create pseudo labels
    y_pred = out.argmax(1).clone().detach()

    # Compute a sample-wise loss with the pseudo labels
    # For this we use the binary-cross-entropy loss function
    crit = nn.CrossEntropyLoss(reduction="none")
    loss = crit(out, y_pred)

    # gather gradient metrics
    grad_metrics = mcollector(loss, retain_graph=True)

    # We will get an output shape of (10, 9)
    # 3 metrics over the whole network parameters
    # 3 metrics over the feature part
    # 3 metrics over the fully connected part
    print(f"Shape of the gradient metric output: {grad_metrics.shape}")

    # Now let's say we want to have the minimum and maximum of the absolute gradient
    # values in the first convolution's kernel. We can achieve that by using a
    # `grad_transform`:
    mcollector = GradientMetricCollector(
        [
            Min(net.features[0].weight, grad_transform=lambda grad: grad.abs()),
            Max(net.features[0].weight, grad_transform=lambda grad: grad.abs()),
        ]
    )

    grad_metrics = mcollector(loss)

    print(f"Value range of absolute gradient values:\n{grad_metrics}")