{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Implementing Stateful Optimizers\n",
    "\n",
    "## Setup\n",
    "\n",
    "### Imports\n",
    "\n",
    "From `abc` import `ABC` and `abstractmethod`. We'll write an abstract `Optimizer` base class; then we can make the optimizers its children.\n",
    "\n",
    "Import `defaultdict`, `Callable`, `Iterable`, `plt`, `torch`, `torch.nn.functional` as `F`and `tqdm`.\n",
    "\n",
    "Moreover, import the functions:\n",
    "1. `load_preprocessed_dataset`,  that you wrote in Notebook 0219,\n",
    "2. `get_dataloader_random_reshuffle` and `line_plot_confidence_band`, that you wrote in Notebook 0221, \n",
    "3. `get_mlp`, that you wrote in Notebook 0319 and\n",
    "4. `evaluate_model`, `get_welch_one_sided` and `normalize_features` that you wrote in Notebook 0321."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Configuration\n",
    "\n",
    "Create a configuration dictionary with the following keys:\n",
    "- `\"dataset_preprocessed_path\"`: `str`  \n",
    "    The path to the preprocessed dataset that we'll load. Make this the path to your preprocessed MNIST.\n",
    "- `\"device\"`: `torch.device | int | str`  \n",
    "    The device identifier.\n",
    "- `\"ensemble_shape\"`: `tuple[int]`  \n",
    "    Use a population of size 64 on GPU or 8 on CPU.\n",
    "- `\"hyperparameter_raw_init_distributions\"`, `\"hyperparameter_raw_perturb\"`, `\"hyperparameter_transforms\"` : `dict`  \n",
    "    Once again, these three dictionaries are going to determine how the hyperparameters are tuned. We'll tune the following hyperparameters:\n",
    "    1. Epsilon $\\epsilon$.\n",
    "    2. Learning rate $\\eta$.\n",
    "    3. Weight decay $\\lambda$.\n",
    "    4. First moment moving average decay rate $\\beta_1$.\n",
    "    5. Second moment moving average decay rate $\\beta_2$.\n",
    "\n",
    "    Of these, we don't know the required order of magnitude of the first three. Thus it may be good to make them distributed along $10^\\mathscr D$ where $\\mathscr D$ is a normal or uniform distribution. You can try to center the distributions at the recommended values.\n",
    "\n",
    "    We know that the recommended values of the fourth and fifth are $0.9$ and $0.999$. So it may be best to give them a distribution of the form $1-10^\\mathscr D$.\n",
    "- `\"improvement_threshold:`: `float`  \n",
    "    Make this `1e-4`.\n",
    "- `\"minibatch_size\"`: `int`  \n",
    "    Make this a `128`.\n",
    "- `\"minibatch_size_eval\"`: `int`  \n",
    "    As our models are growing bigger, instead of evaluating on the entire dataset, we'll to that in batches too.\n",
    "    1. If you're using CPU, you can make this something like the training minibatch.\n",
    "    2. If you're using GPU, it is best to find the largest value here that does not give you an out of memory error. You can experiment with powers of 2. It may be convenient to use the left bit shift operator `<<`. For example, on my home computer, I could use an evaluation minibatch size of $2^{14}$, that is `1 << 14`.\n",
    "- `\"pbt\"` : `bool`  \n",
    "    Let's make a switch to turn off PBT. You can use this to test if the algorithm is doing any good, or optimization with the initial hyperparameters works just as well. Make this `True`.\n",
    "- `\"seed\"`: `int`  \n",
    "    This is for reproducible experiments. Insert any integer.\n",
    "- `\"steps_num\"`: `int`  \n",
    "    Make this a `100_000`. Let early stopping take care of stopping.\n",
    "- `\"steps_without_improvement`: `int`  \n",
    "    Make this `10_000` on GPU or `1000` on CPU.\n",
    "- `\"valid_interval\"`: `int`  \n",
    "    Make this `1000` on GPU or `100` on CPU.\n",
    "- `\"welch_confidence_level\"`: `float`  \n",
    "    We will exploit based on a one-sided Welch $t$-test with this confidence level. As the PBT paper does not specify this, let's start with `.95`. Feel free to try out various values here!\n",
    "- `\"welch_sample_size\"`: `int`  \n",
    "    We will exploit based on a one-sided Welch $t$-test on the last this many validation metrics of the population members. To follow the PBT paper, make this `10`."
   ]
  },
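  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two distribution shapes suggested above can be sketched as follows. This is only a minimal sketch: the locations and scales of the normal distributions, the ensemble size and the names `raw_init` and `transforms` are illustrative assumptions, not recommended values. Only the hyperparameter keys come from this notebook.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "# Illustrative raw distributions; locations and scales are assumptions.\n",
    "raw_init = {\n",
    "    'learning_rate': torch.distributions.Normal(-3., 1.),\n",
    "    'first_moment_decay': torch.distributions.Normal(-1., .5),\n",
    "}\n",
    "\n",
    "# Transforms from raw values to hyperparameter values.\n",
    "transforms = {\n",
    "    'learning_rate': lambda raw: 10 ** raw,           # 10^D\n",
    "    'first_moment_decay': lambda raw: 1 - 10 ** raw,  # 1 - 10^D\n",
    "}\n",
    "\n",
    "ensemble_shape = (8,)\n",
    "hyperparameters = {\n",
    "    key: transforms[key](raw_init[key].sample(ensemble_shape))\n",
    "    for key in raw_init\n",
    "}\n",
    "\n",
    "# A raw value of -1 maps to 1 - 10^-1 = 0.9, the recommended beta_1.\n",
    "beta_1 = transforms['first_moment_decay'](torch.tensor(-1.))\n",
    "```\n",
    "\n",
    "Sampling on the raw scale keeps the learning rate positive and the decay rate below 1 by construction."
   ]
  },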
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set the `torch` pseudo-random number generation seed as per the configuration dictionary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loading the Dataset\n",
    "\n",
    "Just like in the previous notebook:\n",
    "1. Load the preprocessed MNIST dataset.\n",
    "2. Remove the extra column of 1's if you have it.\n",
    "3. Normalize features by the total mean and std of the training features.\n",
    "\n",
    "I also recommend saving the preprocessed dataset you ended up with, so that next time you can load that at once."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stateful Optimizer Classes\n",
    "\n",
    "Recall that `torch` has optimizer implementations with the following pattern:\n",
    "1. Initialize the optimizer with an iterable of model parameters as first positional argument. In a simple case, where we optimize one `model: torch.nn.Module`, this iterable can be `model.parameters()`.\n",
    "2. In a training step:\n",
    "    1. Call `optimizer.zero_grad` to set all parameter gradients to `None` or zero.\n",
    "    2. Calculate the training step loss tensor `loss`.\n",
    "    3. Call `loss.backward` to backpropagate gradients.\n",
    "    4. Call `optimizer.step` to update parameters.\n",
    "\n",
    "Now the built-in optimizers in `torch` don't have ensemble support. So we'll write our own optimizer classes. These will also have the above API. \n",
    "\n",
    "On the other hand, they will not be subclasses of `torch.optim.optimizer.Optimizer`, that is they will not be ready to be used in distributed training, that is multi-GPU training. If you need that in the future, please refer to the official repository:  \n",
    "https://github.com/pytorch/pytorch/tree/main/torch/optim"
   ]
  },
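  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The pattern above can be sketched with the built-in `torch.optim.SGD` on a toy regression problem. The model, the random data and the learning rate are illustrative assumptions, and there is no ensemble dimension here.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "model = torch.nn.Linear(4, 1)\n",
    "# 1. The optimizer tracks the model parameters.\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=0.05)\n",
    "\n",
    "features = torch.randn(16, 4)\n",
    "targets = torch.randn(16, 1)\n",
    "\n",
    "losses = []\n",
    "for _ in range(5):\n",
    "    # 2.1 Reset the gradients.\n",
    "    optimizer.zero_grad()\n",
    "    # 2.2 Calculate the loss.\n",
    "    loss = torch.nn.functional.mse_loss(model(features), targets)\n",
    "    # 2.3 Backpropagate gradients.\n",
    "    loss.backward()\n",
    "    # 2.4 Update the parameters.\n",
    "    optimizer.step()\n",
    "    losses.append(loss.item())\n",
    "```\n",
    "\n",
    "The optimizer classes in this notebook are meant to be driven by the same four calls."
   ]
  },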
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `Optimizer` Abstract Base Class\n",
    "\n",
    "We'll first write an *Abstract Base Class (ABC)*. This pattern is helpful as we'll only have to write the fixed parts once, then we can write the changing parts in its descendants corresponding to various optimizer algorithms.\n",
    "\n",
    "Note that the class is a child of `ABC`. This raises an error, if you want to initialize an instance of a descendant class where not all abstract methods have been defined.\n",
    "\n",
    "You can decorate a method by `abstractmethod`, to declare it an abstract method, that is to be defined by a descendant class.\n",
    "\n",
    "Fill in the ABC below:\n",
    "1. First of all, fill in the `keys` class attribute. Here and in each descendant class, it will hold the keys of the hyperparameters in the configuration dictionary that the optimizer uses.\n",
    "2. I describe the methods below."
   ]
  },
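  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal illustration of how `ABC` and `abstractmethod` behave, consider the toy hierarchy below. The classes `Shape`, `Incomplete` and `Square` are hypothetical and unrelated to the optimizers. Note that an abstract method may still have a body, which children can reach through `super()`.\n",
    "\n",
    "```python\n",
    "from abc import ABC, abstractmethod\n",
    "\n",
    "\n",
    "class Shape(ABC):\n",
    "    @abstractmethod\n",
    "    def area(self) -> float:\n",
    "        # An abstract method may have a body that children can reuse.\n",
    "        return 0.\n",
    "\n",
    "\n",
    "class Incomplete(Shape):\n",
    "    pass  # Does not define `area`.\n",
    "\n",
    "\n",
    "class Square(Shape):\n",
    "    def __init__(self, side: float):\n",
    "        self.side = side\n",
    "\n",
    "    def area(self) -> float:\n",
    "        # Reuse the base implementation through `super()`.\n",
    "        return super().area() + self.side ** 2\n",
    "\n",
    "\n",
    "try:\n",
    "    Incomplete()\n",
    "    error_raised = False\n",
    "except TypeError:\n",
    "    # Instantiating a class with undefined abstract methods fails.\n",
    "    error_raised = True\n",
    "\n",
    "square_area = Square(2.).area()\n",
    "```"
   ]
  },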
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### `__init__`\n",
    "\n",
    "1. Assign to the instance variable `parameters` a list of the entries that `parameters` iterates over.\n",
    "2. Assign to the instance variable `config` an empty configuration dictionary. Then fill it with the method `update_config` if `config` is not `None`.\n",
    "3. Finally, assign to the instance variable `step_id` 0. This will be the train step counter."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### `get_parameters`\n",
    "\n",
    "Recall that during a PBT update, we copy the parameters of some good population members over the parameters over some bad ones. When using stateful optimizers, we'll also need to do this for the optimizer state tensors. To this end, this method returns an iterable over both model and optimizer parameters. In the base, stateless optimizer case, the method should just return an iterable over the model parameters, recorded in the `parameters` attribute."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### `get_hyperparameter`\n",
    "\n",
    "To make the hyperparameter broadcastable to the parameter, you can take a similar approach to how the learning rate is reshaped in the train step in the `train_supervised` function you wrote in Notebook 0321."
   ]
  },
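  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The reshaping idea can be sketched on standalone tensors as follows. The shapes and the names `extra_dims` and `broadcastable` are assumptions for illustration; in the method, the value would come from the configuration dictionary instead.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "ensemble_shape = (8,)\n",
    "# A hypothetical ensemble weight: 8 members, each a 784 x 128 matrix.\n",
    "parameter = torch.zeros(ensemble_shape + (784, 128))\n",
    "learning_rate = torch.rand(ensemble_shape)\n",
    "\n",
    "# Append dimensions of 1 until the number of dimensions matches.\n",
    "extra_dims = parameter.ndim - learning_rate.ndim\n",
    "broadcastable = learning_rate.to(parameter).reshape(\n",
    "    tuple(learning_rate.shape) + (1,) * extra_dims\n",
    ")\n",
    "\n",
    "# Each ensemble member is now scaled by its own learning rate.\n",
    "scaled = broadcastable * torch.ones_like(parameter)\n",
    "```\n",
    "\n",
    "Calling `.to(parameter)` moves the hyperparameter to the `device` and `dtype` of the parameter in one step."
   ]
  },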
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### `step`\n",
    "\n",
    "First, increment the train step counter.\n",
    "\n",
    "Then, in a `torch.no_grad` context, iterate over pairs of parameter and parameter indices and call the `_update_parameter` method on them.\n",
    "\n",
    "In Python, by convention, we prefix a method by an underscore `_` to express that we don't expect the method to be called from outside code in a descendant of the class. In a more rigorous language such as C++ or Rust, these would be private methods. In Python, you are free to break the rules."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### `_apply_parameter_update`\n",
    "\n",
    "As discussed in the lecture, this should add the update to the parameter in-place."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### `_get_parameter_update`\n",
    "\n",
    "In this abstract method, we return the optional weight decay term:\n",
    "1. A tensor of zeros like the parameter, if the `weight_decay` hyperparameter is `None` and\n",
    "2. Minus the product of the values of the `learning_rate` and `weight_decay` hyperparameter times the parameter otherwise.\n",
    "\n",
    "This way, you can get the weight decay term in descendants by a call to `super()._get_parameter_update`."
   ]
  },
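  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a small numeric illustration of the weight decay term, here is a sketch with made-up scalar hyperparameter values. In the method itself, the values would be fetched with `get_hyperparameter` rather than hard-coded.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "parameter = torch.tensor([[1., -2.], [3., 4.]])\n",
    "learning_rate = 0.1  # Illustrative value.\n",
    "weight_decay = 0.01  # Illustrative value.\n",
    "\n",
    "# Minus learning rate times weight decay times the parameter.\n",
    "decay_term = -learning_rate * weight_decay * parameter\n",
    "\n",
    "# Without weight decay, the term is a tensor of zeros.\n",
    "no_decay_term = torch.zeros_like(parameter)\n",
    "\n",
    "# Adding the term in-place shrinks each entry towards zero.\n",
    "updated = parameter.clone()\n",
    "updated.add_(decay_term)\n",
    "```"
   ]
  },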
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### `_update_parameter`\n",
    "\n",
    "Call the following methods in order:\n",
    "1. `_update_state`\n",
    "2. `_get_parameter_update`\n",
    "3. `_apply_parameter_update`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Optimizer(ABC):\n",
    "    \"\"\"\n",
    "    Optimizer base class.\n",
    "    Can optimize model ensembles\n",
    "    with training defined by hyperparameter ensembles.\n",
    "\n",
    "    Arguments\n",
    "    ---------\n",
    "    parameters : `Iterable[torch.nn.Parameter]`\n",
    "        An iterable of `torch.nn.Parameter` to track.\n",
    "        In a simple case of optimizing a single `model: torch.nn.Module`,\n",
    "        this can be `model.parameters()`.\n",
    "    config : `dict`, optional\n",
    "        If given, the `update_config` method is called on it\n",
    "        to initialize hyperparameters. Default: `None`.\n",
    "\n",
    "    Class attributes\n",
    "    ----------------\n",
    "    keys : `tuple[str]`\n",
    "        The collection of the hyperparameter keys to track\n",
    "        in the configuration dictionary.\n",
    "\n",
    "        We expect the hyperparameter values to be either\n",
    "        `float` or `torch.Tensor`. In the latter case,\n",
    "        we expect the shape to be a prefix of the shape of the parameters.\n",
    "        The hyperparameter shapes are regarded as ensemble shapes.\n",
    "\n",
    "        Required keys:\n",
    "        `\"learning_rate\"`\n",
    "        `\"weight_decay\"`\n",
    "\n",
    "    Instance attributes\n",
    "    -------------------\n",
    "    config : `dict`\n",
    "        The hyperparameter dictionary.\n",
    "    parameters : `list[torch.nn.Parameter]`\n",
    "        The list of tracked parameters.\n",
    "    step_id : `int`\n",
    "        Train step counter.\n",
    "    \"\"\"\n",
    "    raise NotImplementedError\n",
    "    def __init__(\n",
    "        self,\n",
    "        parameters: Iterable[torch.nn.Parameter],\n",
    "        config=None\n",
    "    ):\n",
    "        raise NotImplementedError\n",
    "    \n",
    "\n",
    "    def get_parameters(self) -> Iterable[torch.Tensor]:\n",
    "        \"\"\"\n",
    "        Get an iterable over tracked parameters\n",
    "        and optimizer state tensors.\n",
    "        \"\"\"\n",
    "        raise NotImplementedError\n",
    "\n",
    "\n",
    "    def get_hyperparameter(\n",
    "        self,\n",
    "        key: str,\n",
    "        parameter: torch.Tensor\n",
    "    ) -> torch.Tensor:\n",
    "        \"\"\"\n",
    "        Take the hyperparameter with name `key`,\n",
    "        transform it to `torch.Tensor` with the same\n",
    "        `device` and `dtype` as `parameter`\n",
    "        and reshape it to be broadcastable\n",
    "        to `parameter` by postfixing to its shape\n",
    "        an appropriate number of dimensions of 1.\n",
    "        \"\"\"        \n",
    "        raise NotImplementedError\n",
    "\n",
    "\n",
    "    def step(self):\n",
    "        \"\"\"\n",
    "        Update optimizer state, then apply parameter updates in-place.\n",
    "        Assumes that backpropagation has already occurred by\n",
    "        a call to the `backward` method of the loss tensor.\n",
    "        \"\"\"\n",
    "        raise NotImplementedError\n",
    "\n",
    "\n",
    "    def update_config(self, config: dict):\n",
    "        \"\"\"\n",
    "        Update hyperparameters by the values in `config: dict`.\n",
    "        \"\"\"\n",
    "        raise NotImplementedError\n",
    "\n",
    "\n",
    "    def zero_grad(self):\n",
    "        \"\"\"\n",
    "        Make the `grad` attribute of each tracked parameter `None`.\n",
    "        \"\"\"\n",
    "        raise NotImplementedError\n",
    "\n",
    "    def _apply_parameter_update(\n",
    "        self,\n",
    "        parameter: torch.nn.Parameter,\n",
    "        parameter_update: torch.Tensor\n",
    "    ):\n",
    "        raise NotImplementedError\n",
    "\n",
    "\n",
    "    @abstractmethod\n",
    "    def _get_parameter_update(\n",
    "        self,\n",
    "        parameter: torch.nn.Parameter,\n",
    "        parameter_id: int\n",
    "    ) -> torch.Tensor:\n",
    "        raise NotImplementedError\n",
    "\n",
    "\n",
    "    def _update_state(\n",
    "        self,\n",
    "        parameter: torch.nn.Parameter,\n",
    "        parameter_id: int\n",
    "    ):\n",
    "        raise NotImplementedError\n",
    "\n",
    "\n",
    "    def _update_parameter(\n",
    "        self,\n",
    "        parameter: torch.nn.Parameter,\n",
    "        parameter_id: int\n",
    "    ):\n",
    "        raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `SGD` Optimizer\n",
    "\n",
    "Now you can easily get an SGD optimizer by making a child of `Optimizer` and setting `_get_parameter_update` accordingly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SGD(Optimizer):\n",
    "    \"\"\"\n",
    "    Stochastic Gradient Descent optimizer with optionally weight decay.\n",
    "    Can optimize model ensembles\n",
    "    with training defined by hyperparameter ensembles.\n",
    "\n",
    "    Arguments\n",
    "    ---------\n",
    "    parameters : `Iterable[torch.nn.Parameter]`\n",
    "        An iterable of `torch.nn.Parameter` to track.\n",
    "        In a simple case of optimizing a single `model: torch.nn.Module`,\n",
    "        this can be `model.parameters()`.\n",
    "    config : `dict`, optional\n",
    "        If given, the `update_config` method is called on it\n",
    "        to initialize hyperparameters. Default: `None`.\n",
    "\n",
    "    Class attributes\n",
    "    ----------------\n",
    "    keys : `tuple[str]`\n",
    "        The collection of the hyperparameter keys to track\n",
    "        in the configuration dictionary.\n",
    "\n",
    "        We expect the hyperparameter values to be either\n",
    "        `float` or `torch.Tensor`. In the latter case,\n",
    "        we expect the shape to be a prefix of the shape of the parameters.\n",
    "        The hyperparameter shapes are regarded as ensemble shapes.\n",
    "\n",
    "        Required keys:\n",
    "        `\"learning_rate\"`\n",
    "        `\"weight_decay\"`\n",
    "    \"\"\"\n",
    "    def _get_parameter_update(\n",
    "        self,\n",
    "        parameter: torch.nn.Parameter,\n",
    "        parameter_id: int\n",
    "    ) -> torch.Tensor:\n",
    "        raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Update PBT\n",
    "\n",
    "Update the `pbt` function you wrote in Notebook 0321 as follows:\n",
    "1. It should take an `Optimizer` as positional argument that tracks the parameters of the `torch.nn.Module`.\n",
    "2. After the hyperparameters have been initialized, send them to the `Optimizer` by its `update_config` method.\n",
    "3. In a training step, use the `zero_grad` and `step` methods of the `Optimizer` instead of the hard-coded SGD operations.\n",
    "3. In a PBT update:\n",
    "    1. Update the parameters yielded by not `model.parameters()`, but `optimizer.get_parameters()`\n",
    "    2. At the end of the update, send the hyperparameters to the `Optimizer` by its `update_config` method.\n",
    "\n",
    "Afterwards:\n",
    "1. Initialize an MLP with 3 hidden layers of 128 dimensions each\n",
    "2. Initialize an `SGD` optimizer to optimize the MLP\n",
    "3. Run `pbt` with these and the dataset you loaded.\n",
    "4. Print the best validation metric you get\n",
    "5. Plot learning rate and weight decay log10 schedules with confidence bands."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pbt(\n",
    "    config: dict,\n",
    "    get_loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],\n",
    "    get_metric: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],\n",
    "    model: torch.nn.Module,\n",
    "    optimizer: Optimizer,\n",
    "    train_features: torch.Tensor,\n",
    "    train_values: torch.Tensor,\n",
    "    valid_features: torch.Tensor,\n",
    "    valid_values: torch.Tensor,\n",
    ") -> tuple[torch.Tensor, torch.Tensor]:\n",
    "    \"\"\"\n",
    "    Population-based training on a supervised learning task.\n",
    "    Tuned hyperparameters are given by raw values and transformations.\n",
    "    This way, the hyperparameters are perturbed by\n",
    "    additive noise on raw values.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    config : `dict`\n",
    "        Configuration dictionary. Required key-value pairs:\n",
    "        `\"ensemble_shape\"` : tuple[int]\n",
    "            Ensemble shape. We assume this is a 1-dimensional tuple\n",
    "            with dimensions the population size.\n",
    "        `\"hyperparameter_raw_init_distributions\"` : `dict`\n",
    "            Dictionary that maps tuned hyperparameter names\n",
    "            to `torch.distributions.Distribution` of raw hyperparameter values.\n",
    "            Required keys:\n",
    "            `\"learning_rate\"`:\n",
    "                The learning rate of stochastic gradient descent.\n",
    "        `\"hyperparameter_raw_perturbs\"` : `dict`\n",
    "            Dictionary that maps tuned hyperparameter names\n",
    "            to `torch.distributions.Distribution` of additive noise.\n",
    "        `\"hyperparameter_transforms\"` : `dict`\n",
    "            Dictionary that maps tuned hyperparameter names\n",
    "            to transformations of raw hyperparameter values.\n",
    "        `\"improvement_threshold\"` : `float`\n",
    "            A new metric score has to be this much better\n",
    "            than the previous best to count as an improvement.\n",
    "        `\"minibatch_size\"` : `int`\n",
    "            Minibatch size to use in a training step.\n",
    "        `\"minibatch_size_eval\"` : `int`\n",
    "            Minibatch size to use in evaluation.\n",
    "            On CPU, should be about the same as `minibatch_size`.\n",
    "            On GPU, should be as big as possible without\n",
    "            incurring an Out of Memory error.\n",
    "        `\"pbt\"` : `bool`\n",
    "            Whether to use PBT updates in validations.\n",
    "            If `False`, the algorithm just samples hyperparameters at start,\n",
    "            then keeps them constant.\n",
    "        `\"steps_num\"` : `int`\n",
    "            Maximum number of training steps.\n",
    "        `\"steps_without_improvement`\" : `int`\n",
    "            If the number of training steps without improvement\n",
    "            exceeds this value, then training is stopped.\n",
    "        `\"valid_interval\"` : `int`\n",
    "            Frequency of evaluations, measured in number of training steps.\n",
    "        `\"welch_confidence_level\"` : `float`\n",
    "            The confidence level in Welch's t-test\n",
    "            that is used in determining if a population member\n",
    "            is to be replaced by another member with perturbed hyperparameters.\n",
    "        `\"welch_sample_size\"` : `int`\n",
    "            The last this many validation metrics are used\n",
    "            in Welch's t-test.\n",
    "    `get_loss` : `Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`\n",
    "        A function that maps a pair of predicted and target value tensors\n",
    "        to a tensor of losses per ensemble member.\n",
    "    `get_metric` : `Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`\n",
    "        A function that maps a pair of predicted and target value tensors\n",
    "        to a tensor of metrics per ensemble member.\n",
    "        We assume a greater metric is better.\n",
    "    `model` : `torch.nn.Module`\n",
    "        The model ensemble to tune.\n",
    "    `optimizer` : `Optimizer`\n",
    "        An optimizer that tracks the parameters of `model`.\n",
    "    `train_features` : `torch.Tensor`\n",
    "        Training feature tensor.\n",
    "    `train_values` : `torch.Tensor`\n",
    "        Training value tensor.\n",
    "    `valid_features` : `torch.Tensor`\n",
    "        Validation feature tensor.\n",
    "    `valid_values` : `torch.Tensor`\n",
    "        Validation value tensor.\n",
    "        \n",
    "    Returns\n",
    "    -------\n",
    "    An output dictionary with the following key-value pairs:\n",
    "        `\"source mask\"` : `torch.Tensor`\n",
    "            The source masks of population members\n",
    "            that were replace by other members in a PBT update\n",
    "        `\"target indices\"` : `torch.Tensor`\n",
    "            The indices of population members\n",
    "            that the member where the source mask is to were replaced with.\n",
    "        `\"training loss\"` : `torch.Tensor`\n",
    "            The training losses at evaluation steps.\n",
    "        `\"training metric\"` : `torch.Tensor`\n",
    "            The training metrics at evaluation steps.\n",
    "        `\"validation loss\"` : `torch.Tensor`\n",
    "            The validation losses at evaluation steps.\n",
    "        `\"validation metric\"` : `torch.Tensor`\n",
    "            The validation metrics at evaluation steps.\n",
    "\n",
    "        In addition, for each tuned hyperparameter name,\n",
    "        we include a `torch.Tensor` of values per update.\n",
    "    \"\"\"\n",
    "    raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `AdamW` `Optimizer`\n",
    "\n",
    "Write an `AdamW` `Optimizer`, as described in the lecture:\n",
    "1. Add to the `keys` class attribute the keys of the additional hyperparameters `epsilon`, `first_moment_decay` and `second_moment_decay`.\n",
    "2. At initialization, after calling `super().__init__`, initialize state tensors. These should be lists of zero tensors, the same number and shape as the parameters. This is why `_update_parameter` has arguments `parameter` and `parameter_id`. The latter is the index of the parameter in the parameter list.\n",
    "3. Write the `_update_state` and `_get_parameter_update` methods accordingly.\n",
    "4. The `get_parameters` method should now return an iterator not only on `parameters`, but also the state parameter collections `first_moments` and `second_moments`. To this end, you can use either\n",
    "    1. the function `itertools.chain` or\n",
    "    2. consecutive `yield from` statements on all 3 parameter collections.\n",
    "\n",
    "Then try it out similarly as you tried out `SGD`. Don't forget to reinitialize the model! In the hyperparameter schedules, take 1-log10 of the momentum moving average decay rates."
   ]
  },
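  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two ways of iterating over several collections in turn can be sketched as follows. Plain strings stand in for the tensors here, and the function names `get_all_chain` and `get_all_yield` are hypothetical.\n",
    "\n",
    "```python\n",
    "import itertools\n",
    "\n",
    "# Strings stand in for parameter and state tensors.\n",
    "parameters = ['weight', 'bias']\n",
    "first_moments = ['m_weight', 'm_bias']\n",
    "second_moments = ['v_weight', 'v_bias']\n",
    "\n",
    "\n",
    "def get_all_chain():\n",
    "    # Option 1: chain the collections lazily.\n",
    "    return itertools.chain(parameters, first_moments, second_moments)\n",
    "\n",
    "\n",
    "def get_all_yield():\n",
    "    # Option 2: a generator delegating to each collection in turn.\n",
    "    yield from parameters\n",
    "    yield from first_moments\n",
    "    yield from second_moments\n",
    "\n",
    "\n",
    "from_chain = list(get_all_chain())\n",
    "from_yield = list(get_all_yield())\n",
    "```\n",
    "\n",
    "Both variants yield the same items in the same order, so the choice is mostly a matter of taste."
   ]
  },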
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class AdamW(Optimizer):\n",
    "    \"\"\"\n",
    "    Adam optimizer with optionally weight decay.\n",
    "    Can optimize model ensembles\n",
    "    with training defined by hyperparameter ensembles.\n",
    "\n",
    "    Arguments\n",
    "    ---------\n",
    "    parameters : `Iterable[torch.nn.Parameter]`\n",
    "        An iterable of `torch.nn.Parameter` to track.\n",
    "        In a simple case of optimizing a single `model: torch.nn.Module`,\n",
    "        this can be `model.parameters()`.\n",
    "    config : `dict`, optional\n",
    "        If given, the `update_config` method is called on it\n",
    "        to initialize hyperparameters. Default: `None`.\n",
    "\n",
    "    Class attributes\n",
    "    ----------------\n",
    "    keys : `tuple[str]`\n",
    "        The collection of the hyperparameter keys to track\n",
    "        in the configuration dictionary.\n",
    "\n",
    "        We expect the hyperparameter values to be either\n",
    "        `float` or `torch.Tensor`. In the latter case,\n",
    "        we expect the shape to be a prefix of the shape of the parameters.\n",
    "        The hyperparameter shapes are regarded as ensemble shapes.\n",
    "\n",
    "        Required keys:\n",
    "        `\"epsilon\"`,\n",
    "        `\"first_moment_decay\"`,\n",
    "        `\"learning_rate\"`\n",
    "        `\"second_moment_decay\"`,\n",
    "        `\"weight_decay\"`\n",
    "    \"\"\"\n",
    "    raise NotImplementedError\n",
    "    def __init__(\n",
    "        self,\n",
    "        parameters: Iterable[torch.nn.Parameter],\n",
    "        config=None\n",
    "    ):\n",
    "        raise NotImplementedError\n",
    "\n",
    "\n",
    "    def get_parameters(self) -> Iterable[torch.Tensor]:\n",
    "        raise NotImplementedError\n",
    "\n",
    "\n",
    "    def _get_parameter_update(\n",
    "        self,\n",
    "        parameter: torch.nn.Parameter,\n",
    "        parameter_id: int\n",
    "    ) -> torch.Tensor:\n",
    "        raise NotImplementedError\n",
    "\n",
    "\n",
    "    def _update_state(\n",
    "        self,\n",
    "        parameter: torch.nn.Parameter,\n",
    "        parameter_id: int\n",
    "    ):\n",
    "        raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Datasets\n",
    "\n",
    "### MNIST\n",
    "\n",
    "https://huggingface.co/datasets/ylecun/mnist"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## License\n",
    "\n",
    "This work is licensed under Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International. To view a copy of this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dml",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
