Finding a maximum-entropy (discrete) distribution using the Lagrangian approach.

Open in Colab

Here we consider a simple convex optimization problem to illustrate how to use Cooper. This example is inspired by this StackExchange question:

I am trying to solve the following problem using PyTorch: given a 6-sided die whose average roll is known to be 4.5, what is the maximum entropy distribution for the faces?

Formally, we want to solve the following optimization problem:

\[\begin{split} \begin{aligned} \max_{p} \quad & -\sum_{i=1}^6 p_i \log p_i \\ \text{s.t.} \quad & \sum_{i=1}^6 i \, p_i = 4.5 \\ & \sum_{i=1}^6 p_i = 1 \\ & p_i \geq 0 \quad \forall i \end{aligned} \end{split}\]

where \(p\) is the probability distribution over the faces of the die.

This example uses the \(\nu\)PI algorithm to improve the training dynamics of the dual variables. For a detailed explanation of \(\nu\)PI, see the paper *On PI Controllers for Updating Lagrange Multipliers in Constrained Optimization* (ICML 2024).

%%capture
%pip install cooper-optim
import matplotlib.pyplot as plt
import numpy as np
import torch

import cooper

# Fix the random seeds so repeated runs of this notebook are reproducible.
np.random.seed(0)
torch.manual_seed(0)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class MaximumEntropy(cooper.ConstrainedMinimizationProblem):
    """Maximum-entropy distribution over the faces of a die with a prescribed mean.

    The distribution is parametrized by its log-probabilities ``log_probs``. The
    problem is posed as minimizing the *negative* entropy subject to two equality
    constraints, each handled with its own Lagrange multiplier:

    * normalization: ``sum_i p_i = 1``
    * mean: ``sum_i i * p_i = target_mean`` (faces are numbered ``1..n``)

    Non-negativity ``p_i >= 0`` is not imposed as a constraint. Since
    ``p = exp(log_probs)``, every probability is strictly positive by construction.
    """

    def __init__(self, target_mean: float) -> None:
        """Create the two equality constraints and their multipliers.

        Args:
            target_mean: Required expected value of the rolled face.
        """
        super().__init__()
        self.target_mean = target_mean

        # One scalar multiplier per equality constraint, allocated on DEVICE.
        mean_multiplier = cooper.multipliers.DenseMultiplier(num_constraints=1, device=DEVICE)
        sum_multiplier = cooper.multipliers.DenseMultiplier(num_constraints=1, device=DEVICE)

        self.mean_constraint = cooper.Constraint(
            constraint_type=cooper.ConstraintType.EQUALITY,
            formulation_type=cooper.formulations.Lagrangian,
            multiplier=mean_multiplier,
        )
        self.sum_constraint = cooper.Constraint(
            constraint_type=cooper.ConstraintType.EQUALITY,
            formulation_type=cooper.formulations.Lagrangian,
            multiplier=sum_multiplier,
        )

        # No constraint (and no projection) is needed for p_i >= 0:
        # compute_cmp_state maps log_probs through exp, which is always positive.

    def compute_cmp_state(self, log_probs: torch.Tensor) -> cooper.CMPState:
        """Evaluate the objective and both constraint violations at ``log_probs``.

        Args:
            log_probs: Log-probabilities, one entry per face (expected to be 1-D).

        Returns:
            A CMPState with these fields:

            * ``loss`` is the negative entropy, so minimizing it maximizes entropy.
            * ``observed_constraints`` maps the sum constraint to ``sum(p) - 1`` and
              the mean constraint to ``E[face] - target_mean``.
        """
        probs = torch.exp(log_probs)
        entropy = -torch.sum(probs * log_probs)

        # Face values 1..n, created on the input's device and dtype so the product
        # below never mixes devices, whatever device the caller uses.
        faces = torch.arange(1, len(probs) + 1, device=log_probs.device, dtype=probs.dtype)
        mean = torch.sum(probs * faces)

        # Equality constraints: proper normalization and the target mean.
        sum_constraint_violation = cooper.ConstraintState(violation=torch.sum(probs) - 1)
        mean_constraint_violation = cooper.ConstraintState(violation=mean - self.target_mean)

        observed_constraints = {
            self.sum_constraint: sum_constraint_violation,
            self.mean_constraint: mean_constraint_violation,
        }

        # Flip loss sign since we want to *maximize* the entropy
        return cooper.CMPState(loss=-entropy, observed_constraints=observed_constraints)


# Define the problem with the constraints
cmp = MaximumEntropy(target_mean=4.5)

# Primal parameters: log-probabilities of the six faces, initialized to the uniform
# distribution. Working in log space keeps every probability strictly positive.
log_probs = torch.nn.Parameter(torch.log(torch.ones(6, device=DEVICE) / 6))
primal_optimizer = torch.optim.SGD([log_probs], lr=3e-2)

# We employ the nuPI algorithm for updating the dual variables. maximize=True because
# the Lagrangian is maximized with respect to the multipliers.
dual_optimizer = cooper.optim.nuPI(cmp.dual_parameters(), lr=1e-2, Kp=10, maximize=True)

cooper_optimizer = cooper.optim.SimultaneousOptimizer(
    primal_optimizers=primal_optimizer, dual_optimizers=dual_optimizer, cmp=cmp
)

# Record the violations of BOTH constraints, in the same order as the plot labels
# ("Sum Constraint", "Mean Constraint"). Previously only the mean constraint was stored,
# so the two-label violation plot received a single series.
tracked_constraints = (cmp.sum_constraint, cmp.mean_constraint)

state_history = {}
for i in range(3000):
    _, cmp_state, primal_lagrangian_store, _ = cooper_optimizer.roll(
        compute_cmp_state_kwargs={"log_probs": log_probs}
    )

    observed_violation = torch.stack(
        [cmp_state.observed_constraints[constraint].violation for constraint in tracked_constraints]
    )
    observed_multiplier = list(primal_lagrangian_store.observed_multiplier_values())
    state_history[i] = {
        # The CMP loss is the negative entropy; flip the sign to log the entropy itself.
        "loss": -cmp_state.loss.item(),
        "multipliers": torch.stack(observed_multiplier).detach(),
        "violation": observed_violation.detach(),
    }


# Reference (theoretical) optimum for a mean of 4.5, rounded to four significant digits,
# and the entropy it attains; used as the target line in the objective plot.
optimal_prob = torch.tensor([0.05435, 0.07877, 0.1142, 0.1654, 0.2398, 0.3475])
optimal_entropy = -(optimal_prob * optimal_prob.log()).sum()

# Generate plots: multipliers, constraint violations, and the objective (entropy) per iteration.
iters, loss_hist, multipliers_hist, violation_hist = zip(
    *[(k, v["loss"], v["multipliers"], v["violation"]) for k, v in state_history.items()]
)

_, (ax0, ax1, ax2) = plt.subplots(1, 3, figsize=(20, 4))

ax0.plot(iters, torch.stack(multipliers_hist).squeeze().cpu())
ax0.axhline(0.0, c="gray", alpha=0.35)
ax0.set_title("Multipliers")
ax0.set_xlabel("Iteration")

ax1.plot(iters, torch.stack(violation_hist).squeeze().cpu(), label=["Sum Constraint", "Mean Constraint"])
ax1.legend()
# Both constraints are equalities: violations may have either sign, and a feasible
# solution drives each of them to zero.
ax1.axhline(0.0, c="gray", alpha=0.35)
ax1.set_title("Constraint Violations")
ax1.set_xlabel("Iteration")

ax2.plot(iters, loss_hist)
# Dashed reference line at the entropy of the theoretical optimum; convert the 0-dim
# tensor to a plain float for matplotlib.
ax2.axhline(float(optimal_entropy), c="gray", alpha=0.35, linestyle="dashed")
ax2.set_title("Objective")
ax2.set_xlabel("Iteration")

plt.show()
/tmp/ipykernel_1257/2878445315.py:95: MatplotlibDeprecationWarning: Passing label as a length 2 sequence when plotting a single dataset is deprecated in Matplotlib 3.9 and will error in 3.11.  To keep the current behavior, cast the sequence to string before passing.
  ax1.plot(iters, torch.stack(violation_hist).squeeze().cpu(), label=["Sum Constraint", "Mean Constraint"])
../_images/37bca80a4cd07c60aef476d146a19263c97f925a3f2fcc93c71a9c987a150daa.png