Source code for cooper.state_logger

from collections import OrderedDict
from copy import deepcopy
from typing import List

from .problem import Formulation


[docs]class StateLogger: """ Utility for storing optimization metrics (e.g. loss, multipliers) through training. Args: save_metrics: List of metric names to be stored. Currently supported values are: ``loss``, ``ineq_defect``, ``eq_defect``, ``ineq_multipliers``, ``eq_multipliers``. """ def __init__(self, save_metrics: List[str]): self.logger: OrderedDict = OrderedDict() self.save_metrics = save_metrics
[docs] def store_metrics( self, formulation: Formulation, step_id: int, partial_dict: dict = None, ): """ Store a new screenshot of the metrics. Args: formulation: Formulation from which to take the current metric values. step_id: Identifier for the optimization step. partial_dict: Auxiliary dictionary with other metrics to be logged, but which are not part of the "canonical" options available in ``save_metrics``. Defaults to None. """ aux_dict = {} for metric in self.save_metrics: if metric == "loss": aux_dict[metric] = formulation.cmp.state.loss.item() elif metric == "ineq_defect": aux_dict[metric] = deepcopy(formulation.cmp.state.ineq_defect.data) elif metric == "eq_defect": aux_dict[metric] = deepcopy(formulation.cmp.state.eq_defect.data) elif metric == "ineq_multipliers": aux_dict[metric] = deepcopy(formulation.state()[0].data) elif metric == "eq_multipliers": aux_dict[metric] = deepcopy(formulation.state()[1].data) if partial_dict is not None: aux_dict.update(partial_dict) self.save_metrics = list(set(self.save_metrics + list(partial_dict.keys()))) self.logger[step_id] = aux_dict
[docs] def unpack_stored_metrics(self) -> dict: """ Returns a dictionary containing the stored values separated by metric. """ unpacked_metrics = {} unpacked_metrics["iters"] = list(self.logger.keys()) for metric in self.save_metrics: unpacked_metrics[metric] = [_[metric] for (__, _) in self.logger.items()] return unpacked_metrics