Training a logistic regression classifier on MNIST under a norm constraint.

Open in Colab

Note

This example illustrates how to use Cooper on a simple machine learning problem that involves using mini-batches of data.

In this example, we consider a simple convex constrained optimization problem: training a Logistic Regression classifier on the MNIST dataset. The model is constrained so that the squared L2 norm of its parameters is at most 1.

Although we employ a simple Logistic Regression model, the same principles can be applied to more complex models, such as deep neural networks.

This example illustrates how Cooper integrates with a typical PyTorch training pipeline, where:

  • models are defined using a torch.nn.Module,

  • steps loop over mini-batches of data,

  • CUDA acceleration is used when a GPU is available.

%%capture
%pip install cooper-optim
import os
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import datasets, transforms

import cooper

# Fix the NumPy and PyTorch RNG seeds so that runs are reproducible.
np.random.seed(0)
torch.manual_seed(0)

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class NormConstrainedLogisticRegression(cooper.ConstrainedMinimizationProblem):
    """Constrained minimization problem for a linear classifier with a norm constraint.

    The single inequality constraint is
    ``sum(weight**2) + sum(bias**2) - constraint_level <= 0``. It is handled with a
    Lagrangian formulation that uses one dense multiplier placed on ``DEVICE``.

    Args:
        constraint_level: Upper bound on the squared L2 norm of the model parameters.
    """

    def __init__(self, constraint_level: float = 1.0):
        super().__init__()
        self.constraint_level = constraint_level

        # One multiplier for the single constraint; it has to live on the model's device.
        lagrange_multiplier = cooper.multipliers.DenseMultiplier(num_constraints=1, device=DEVICE)
        self.norm_constraint = cooper.Constraint(
            multiplier=lagrange_multiplier,
            constraint_type=cooper.ConstraintType.INEQUALITY,
            formulation_type=cooper.formulations.Lagrangian,
        )

    def compute_cmp_state(self, model: torch.nn.Module, inputs: torch.Tensor, targets: torch.Tensor) -> cooper.CMPState:
        """Evaluate the loss, accuracy and norm-constraint violation on one mini-batch.

        Args:
            model: Classifier exposing ``weight`` and ``bias`` attributes (e.g. ``torch.nn.Linear``).
            inputs: Batch of inputs; every dimension after the first is flattened before
                being fed to ``model``.
            targets: Class indices compared against the arg-max of the logits.

        Returns:
            A ``cooper.CMPState`` holding the cross-entropy loss, the observed norm
            constraint, and the batch accuracy under ``misc["accuracy"]``.
        """
        batch_size = inputs.shape[0]
        logits = model(inputs.view(batch_size, -1))
        loss = torch.nn.functional.cross_entropy(logits, targets)

        predicted_labels = logits.argmax(dim=1)
        accuracy = (predicted_labels == targets).float().mean()

        weight_sq_norm = model.weight.pow(2).sum()
        bias_sq_norm = model.bias.pow(2).sum()
        # Cooper expects inequality violations in the form g(x) <= 0.
        violation = (weight_sq_norm + bias_sq_norm) - self.constraint_level

        return cooper.CMPState(
            loss=loss,
            observed_constraints={self.norm_constraint: cooper.ConstraintState(violation=violation)},
            misc={"accuracy": accuracy},
        )


# Training split of MNIST: images become tensors, then are normalized with
# mean 0.1307 and std 0.3081 (the arguments of transforms.Normalize).
data_path = "./data"
data_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)
dataset = datasets.MNIST(data_path, train=True, download=True, transform=data_transforms)

# Pinned host memory only helps when batches are later copied to a GPU.
train_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=128,
    num_workers=2,
    pin_memory=torch.cuda.is_available(),
)

# Logistic regression = a single affine layer from the flattened 28x28 image to 10 class logits.
model = torch.nn.Linear(in_features=28 * 28, out_features=10, bias=True).to(DEVICE)

# The constrained problem owns the norm constraint and its Lagrange multiplier.
cmp = NormConstrainedLogisticRegression(constraint_level=1.0)

# Primal variables (model parameters) are minimized with AMSGrad-Adam ...
primal_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, amsgrad=True)
# ... while the dual variables (multipliers) take gradient-ascent steps (maximize=True).
dual_optimizer = torch.optim.SGD(cmp.dual_parameters(), lr=1e-3, maximize=True)

# Cooper wrapper that updates primal and dual variables simultaneously.
cooper_optimizer = cooper.optim.SimultaneousOptimizer(
    primal_optimizers=primal_optimizer,
    dual_optimizers=dual_optimizer,
    cmp=cmp,
)

# Create a directory to save checkpoints
checkpoint_path = "./checkpoint"
os.makedirs(checkpoint_path, exist_ok=True)
checkpoint_file = os.path.join(checkpoint_path, "checkpoint.pth")

num_epochs = 7
log_every = 3  # record metrics every `log_every` mini-batches

# Resume from the checkpoint if one exists, otherwise start from scratch.
if os.path.isfile(checkpoint_file):
    # `weights_only=True` only unpickles tensors and plain Python containers/primitives;
    # a `collections.defaultdict` is rejected, which is why the metrics are saved as a
    # plain dict in the `torch.save` call below.
    # `map_location` lets a checkpoint written on a GPU machine be resumed on a CPU-only one.
    checkpoint = torch.load(checkpoint_file, map_location=DEVICE, weights_only=True)
    batch_ix = checkpoint["batch_ix"]
    start_epoch = checkpoint["epoch"] + 1
    # Restore the defaultdict behaviour so metrics can keep being appended per key.
    all_metrics = defaultdict(list, checkpoint["all_metrics"])
    model.load_state_dict(checkpoint["model_state_dict"])
    cmp.load_state_dict(checkpoint["cmp_state_dict"])
    cooper_optimizer.load_state_dict(checkpoint["cooper_optimizer_state_dict"])
else:
    batch_ix = 0
    start_epoch = 0
    all_metrics = defaultdict(list)

for epoch_num in range(start_epoch, num_epochs):
    for inputs, targets in train_loader:
        batch_ix += 1

        # Move the batch to the model's device; non_blocking pairs with the loader's pin_memory.
        inputs = inputs.to(DEVICE, non_blocking=True)
        targets = targets.to(DEVICE, non_blocking=True)

        _, cmp_state, primal_lagrangian_store, _ = cooper_optimizer.roll(
            compute_cmp_state_kwargs={"model": model, "inputs": inputs, "targets": targets}
        )

        if batch_ix % log_every == 0:
            all_metrics["batch_ix"].append(batch_ix)
            all_metrics["train_loss"].append(cmp_state.loss.item())
            all_metrics["train_acc"].append(cmp_state.misc["accuracy"].item())

            multiplier_value = primal_lagrangian_store.multiplier_values[cmp.norm_constraint].item()
            all_metrics["multiplier_value"].append(multiplier_value)

            constraint_violation = cmp_state.observed_constraints[cmp.norm_constraint].violation
            all_metrics["constraint_violation"].append(constraint_violation.item())

    # Save a checkpoint at the end of each epoch. Write to a temporary file first and then
    # swap it in with os.replace, so an interrupted save cannot leave a truncated checkpoint.
    tmp_checkpoint_file = checkpoint_file + ".tmp"
    torch.save(
        {
            "batch_ix": batch_ix,
            "epoch": epoch_num,
            # Plain dict (not defaultdict) so the file loads with `weights_only=True`.
            "all_metrics": dict(all_metrics),
            "model_state_dict": model.state_dict(),
            "cmp_state_dict": cmp.state_dict(),
            "cooper_optimizer_state_dict": cooper_optimizer.state_dict(),
        },
        tmp_checkpoint_file,
    )
    os.replace(tmp_checkpoint_file, checkpoint_file)

del batch_ix, all_metrics, model, cmp, cooper_optimizer

# Post-training analysis: read the logged metrics back from the last saved checkpoint.
all_metrics = torch.load(checkpoint_path + "/checkpoint.pth", weights_only=True)["all_metrics"]

fig, (ax0, ax1, ax2, ax3) = plt.subplots(nrows=1, ncols=4, sharex=True, figsize=(18, 4))

# (axis, metric key, panel title), plotted against the logged batch indices.
panel_specs = (
    (ax0, "train_loss", "Training Loss"),
    (ax1, "train_acc", "Training Acc"),
    (ax2, "multiplier_value", "Inequality Multiplier"),
    (ax3, "constraint_violation", "Inequality Defect"),
)
for axis, metric_key, panel_title in panel_specs:
    axis.plot(all_metrics["batch_ix"], np.asarray(all_metrics[metric_key]))
    axis.set_xlabel("Batch")
    axis.set_title(panel_title)

# Reference line at zero: a feasible solution keeps the defect at or below it.
ax3.axhline(0.0, c="gray", alpha=0.5)

plt.show()
0.3%
0.7%
1.0%
1.3%
1.7%
2.0%
2.3%
2.6%
3.0%
3.3%
3.6%
4.0%
4.3%
4.6%
5.0%
5.3%
5.6%
6.0%
6.3%
6.6%
6.9%
7.3%
7.6%
7.9%
8.3%
8.6%
8.9%
9.3%
9.6%
9.9%
10.2%
10.6%
10.9%
11.2%
11.6%
11.9%
12.2%
12.6%
12.9%
13.2%
13.6%
13.9%
14.2%
14.5%
14.9%
15.2%
15.5%
15.9%
16.2%
16.5%
16.9%
17.2%
17.5%
17.9%
18.2%
18.5%
18.8%
19.2%
19.5%
19.8%
20.2%
20.5%
20.8%
21.2%
21.5%
21.8%
22.1%
22.5%
22.8%
23.1%
23.5%
23.8%
24.1%
24.5%
24.8%
25.1%
25.5%
25.8%
26.1%
26.4%
26.8%
27.1%
27.4%
27.8%
28.1%
28.4%
28.8%
29.1%
29.4%
29.8%
30.1%
30.4%
30.7%
31.1%
31.4%
31.7%
32.1%
32.4%
32.7%
33.1%
33.4%
33.7%
34.0%
34.4%
34.7%
35.0%
35.4%
35.7%
36.0%
36.4%
36.7%
37.0%
37.4%
37.7%
38.0%
38.3%
38.7%
39.0%
39.3%
39.7%
40.0%
40.3%
40.7%
41.0%
41.3%
41.7%
42.0%
42.3%
42.6%
43.0%
43.3%
43.6%
44.0%
44.3%
44.6%
45.0%
45.3%
45.6%
45.9%
46.3%
46.6%
46.9%
47.3%
47.6%
47.9%
48.3%
48.6%
48.9%
49.3%
49.6%
49.9%
50.2%
50.6%
50.9%
51.2%
51.6%
51.9%
52.2%
52.6%
52.9%
53.2%
53.6%
53.9%
54.2%
54.5%
54.9%
55.2%
55.5%
55.9%
56.2%
56.5%
56.9%
57.2%
57.5%
57.9%
58.2%
58.5%
58.8%
59.2%
59.5%
59.8%
60.2%
60.5%
60.8%
61.2%
61.5%
61.8%
62.1%
62.5%
62.8%
63.1%
63.5%
63.8%
64.1%
64.5%
64.8%
65.1%
65.5%
65.8%
66.1%
66.4%
66.8%
67.1%
67.4%
67.8%
68.1%
68.4%
68.8%
69.1%
69.4%
69.8%
70.1%
70.4%
70.7%
71.1%
71.4%
71.7%
72.1%
72.4%
72.7%
73.1%
73.4%
73.7%
74.0%
74.4%
74.7%
75.0%
75.4%
75.7%
76.0%
76.4%
76.7%
77.0%
77.4%
77.7%
78.0%
78.3%
78.7%
79.0%
79.3%
79.7%
80.0%
80.3%
80.7%
81.0%
81.3%
81.7%
82.0%
82.3%
82.6%
83.0%
83.3%
83.6%
84.0%
84.3%
84.6%
85.0%
85.3%
85.6%
85.9%
86.3%
86.6%
86.9%
87.3%
87.6%
87.9%
88.3%
88.6%
88.9%
89.3%
89.6%
89.9%
90.2%
90.6%
90.9%
91.2%
91.6%
91.9%
92.2%
92.6%
92.9%
93.2%
93.6%
93.9%
94.2%
94.5%
94.9%
95.2%
95.5%
95.9%
96.2%
96.5%
96.9%
97.2%
97.5%
97.9%
98.2%
98.5%
98.8%
99.2%
99.5%
99.8%
100.0%
100.0%
2.0%
4.0%
6.0%
7.9%
9.9%
11.9%
13.9%
15.9%
17.9%
19.9%
21.9%
23.8%
25.8%
27.8%
29.8%
31.8%
33.8%
35.8%
37.8%
39.7%
41.7%
43.7%
45.7%
47.7%
49.7%
51.7%
53.7%
55.6%
57.6%
59.6%
61.6%
63.6%
65.6%
67.6%
69.6%
71.5%
73.5%
75.5%
77.5%
79.5%
81.5%
83.5%
85.5%
87.4%
89.4%
91.4%
93.4%
95.4%
97.4%
99.4%
100.0%
100.0%
---------------------------------------------------------------------------
UnpicklingError                           Traceback (most recent call last)
Cell In[2], line 128
    125 del batch_ix, all_metrics, model, cmp, cooper_optimizer
    127 # Post-training analysis and plotting
--> 128 all_metrics = torch.load(checkpoint_path + "/checkpoint.pth", weights_only=True)["all_metrics"]
    130 fig, (ax0, ax1, ax2, ax3) = plt.subplots(nrows=1, ncols=4, sharex=True, figsize=(18, 4))
    132 ax0.plot(all_metrics["batch_ix"], all_metrics["train_loss"])

File ~/checkouts/readthedocs.org/user_builds/cooper/envs/latest/lib/python3.10/site-packages/torch/serialization.py:1470, in load(f, map_location, pickle_module, weights_only, mmap, **pickle_load_args)
   1462                 return _load(
   1463                     opened_zipfile,
   1464                     map_location,
   (...)
   1467                     **pickle_load_args,
   1468                 )
   1469             except pickle.UnpicklingError as e:
-> 1470                 raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
   1471         return _load(
   1472             opened_zipfile,
   1473             map_location,
   (...)
   1476             **pickle_load_args,
   1477         )
   1478 if mmap:

UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, do those steps only if you trust the source of the checkpoint. 
	(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
	(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
	WeightsUnpickler error: Unsupported global: GLOBAL collections.defaultdict was not an allowed global by default. Please use `torch.serialization.add_safe_globals([defaultdict])` or the `torch.serialization.safe_globals([defaultdict])` context manager to allowlist this global if you trust this class/function.

Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.