{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "661d2ea8",
   "metadata": {},
   "source": [
    "# Implementing Dropout and Layer Normalization\n",
    "\n",
    "## Setup\n",
    "\n",
    "### Imports\n",
    "\n",
    "Import `ABC`, `abstractmethod`, `defaultdict`, `Callable`, `Iterable`, `datasets`, `itertools`, `Optional`, `os`, `torch` and `tqdm`.\n",
    "\n",
    "Moreover, import the following:\n",
    "1. The functions `get_accuracy` and `get_cross_entropy`, which you wrote in Notebook 0221.\n",
    "2. The function `normalize_features`, which you wrote in Notebook 0321.\n",
    "3. The classes `AdamW` and `Optimizer`, which you wrote in Notebook 0326.\n",
    "4. The functions `pbt_init` and `pbt_update`, which you wrote in Notebook 0328.\n",
    "5. The functions `get_dataloader_random_reshuffle` and `to_ensembled`, which you wrote in Notebook 0416.\n",
    "6. The classes `Conv2D`, `DictReLU`, `Linear`, and `Pool2D`, and the function `evaluate_model`, which you wrote in Notebook 0423."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f184e86",
   "metadata": {},
   "outputs": [],
   "source": [
    "from abc import (\n",
    "    ABC,\n",
    "    abstractmethod\n",
    ")\n",
    "from collections import defaultdict\n",
    "from collections.abc import (\n",
    "    Callable,\n",
    "    Iterable\n",
    ")\n",
    "import datasets\n",
    "import itertools\n",
    "import os\n",
    "import torch\n",
    "import tqdm\n",
    "from typing import Optional\n",
    "\n",
    "from util_0425 import (\n",
    "    AdamW,\n",
    "    Conv2D,\n",
    "    DictReLU,\n",
    "    evaluate_model,\n",
    "    get_accuracy,\n",
    "    get_cross_entropy,\n",
    "    get_dataloader_random_reshuffle,\n",
    "    Linear,\n",
    "    normalize_features,\n",
    "    Optimizer,\n",
    "    pbt_init,\n",
    "    pbt_update,\n",
    "    Pool2D,\n",
    "    to_ensembled\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f79c4fc7",
   "metadata": {},
   "source": [
    "### Configuration\n",
    "\n",
    "Create a configuration dictionary with the following keys:\n",
    "- `\"dataset_path\"`: `str`  \n",
    "    Make this the ID of CIFAR-10 [10], to be found at https://huggingface.co/datasets/uoft-cs/cifar10.\n",
    "- `\"dataset_preprocessed_path\"` : `str`  \n",
    "    Below, I'm going to suggest saving the preprocessed dataset at this location.\n",
    "- `\"device\"`: `torch.device | int | str`  \n",
    "    The device identifier.\n",
    "- `\"ensemble_shape\"`: `tuple[int]`  \n",
    "    Make this `(16,)`.\n",
    "- `\"hyperparameter_raw_init_distributions\"`, `\"hyperparameter_raw_perturb\"`, `\"hyperparameter_transforms\"` : `dict`  \n",
    "    These three dictionaries are going to determine how the hyperparameters are tuned. We'll tune the following hyperparameters:\n",
    "    1. Epsilon $\\epsilon$.\n",
    "    2. Learning rate $\\eta$.\n",
    "    3. Weight decay $\\lambda$.\n",
    "    4. First moment moving average decay rate $\\beta_1$.\n",
    "    5. Second moment moving average decay rate $\\beta_2$.\n",
    "    6. Dropout probability $p$.\n",
    "\n",
    "    Of these, we don't know the appropriate order of magnitude of the first three. Thus it may be good to distribute them as $10^{\\mathscr D}$, where $\\mathscr D$ is a normal or uniform distribution. You can try to center the distributions at the recommended values.\n",
    "\n",
    "    We know that the recommended values of the fourth and fifth are $0.9$ and $0.999$. Since $0.9 = 1-10^{-1}$ and $0.999 = 1-10^{-3}$, it may be best to give them a distribution of the form $1-10^{\\mathscr D}$.\n",
    "\n",
    "    We know that the dropout probability should be in the unit interval $[0,1]$. Moreover, it may not help to zero more than half of the neurons. Thus, let's make its raw initial distribution the uniform distribution on $[0, 0.5]$. For the raw perturbation, maybe we can use a normal distribution with mean $0$ and standard deviation $0.1$. For the transform function, I recommend clipping the values to $[0, 1]$, as they should be probabilities.\n",
    "- `\"improvement_threshold\"`: `float`  \n",
    "    Make this `1e-4`.\n",
    "- `\"minibatch_size\"`: `int`  \n",
    "    Make this `64`.\n",
    "- `\"minibatch_size_eval\"`: `int`  \n",
    "    On my home computer, I can make this `128`.\n",
    "- `\"pbt\"` : `bool`  \n",
    "    Make this `True`.\n",
    "- `\"seed\"`: `int`  \n",
    "    This is for reproducible experiments. Insert any integer.\n",
    "- `\"steps_num\"`: `int`  \n",
    "    Make this `10_001`.\n",
    "- `\"steps_without_improvement\"`: `int`  \n",
    "    Make this `10_000`.\n",
    "- `\"valid_interval\"`: `int`  \n",
    "    Make this `1000`.\n",
    "- `\"welch_confidence_level\"`: `float`  \n",
    "    We will exploit based on a one-sided Welch $t$-test with this confidence level. Based on my experiments in the setting of Homework 9, maybe you can try `.8`. Feel free to try out various values here!\n",
    "- `\"welch_sample_size\"`: `int`  \n",
    "    We will exploit based on a one-sided Welch $t$-test on this many most recent validation metrics of each population member. To follow the PBT paper, make this `10`."
   ]
  },
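  {
   "cell_type": "markdown",
   "id": "7c2e9a41",
   "metadata": {},
   "source": [
    "A quick way to sanity-check these raw distributions and transforms is to sample raw values and map them through $10^{\\mathscr D}$ and $1-10^{\\mathscr D}$. The sketch below is a minimal CPU-only illustration, not part of the pipeline: `raw_init` and `transforms` are illustrative names, only three of the six hyperparameters are shown, and the ranges mirror the configuration cell that follows.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Raw values D are sampled; each hyperparameter is a transform of D\n",
    "raw_init = {\n",
    "    \"learning_rate\": torch.distributions.Uniform(-5.0, -1.0),\n",
    "    \"second_moment_decay\": torch.distributions.Uniform(-5.0, -1.0),\n",
    "    \"dropout_p\": torch.distributions.Uniform(0.0, 0.5),\n",
    "}\n",
    "transforms = {\n",
    "    \"learning_rate\": lambda d: 10 ** d,                          # 10^D\n",
    "    \"second_moment_decay\": lambda d: (1 - 10 ** d).clamp(0, 1),  # 1 - 10^D\n",
    "    \"dropout_p\": lambda p: p.clip(0, 1),                         # already a probability\n",
    "}\n",
    "\n",
    "population_size = 16\n",
    "raw = {name: dist.sample((population_size,)) for name, dist in raw_init.items()}\n",
    "values = {name: transforms[name](raw[name]) for name in raw}\n",
    "\n",
    "# Additive noise on the raw learning rate becomes a multiplicative factor 10^noise\n",
    "noise = torch.distributions.Normal(0.0, 1.0).sample((population_size,))\n",
    "perturbed_lr = transforms[\"learning_rate\"](raw[\"learning_rate\"] + noise)\n",
    "\n",
    "for name, value in values.items():\n",
    "    print(name, value.min().item(), value.max().item())\n",
    "```\n",
    "\n",
    "Since $10^{D + \\delta} = 10^{D} \\cdot 10^{\\delta}$ for a raw value $D$, additive noise $\\delta$ on the raw value rescales the learning rate multiplicatively, which suits hyperparameters whose order of magnitude is unknown."
   ]
  },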
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92500ed8",
   "metadata": {},
   "outputs": [],
   "source": [
    "config = {\n",
    "    \"dataset_path\": \"uoft-cs/cifar10\",\n",
    "    \"dataset_preprocessed_path\": \"data/cifar10.pt\",\n",
    "    \"device\": \"cuda\",\n",
    "    \"ensemble_shape\": (16,),\n",
    "    \"hyperparameter_raw_init_distributions\": {\n",
    "        \"dropout_p\": torch.distributions.Uniform(\n",
    "            torch.tensor(0, device=\"cuda\", dtype=torch.float32),\n",
    "            torch.tensor(.5, device=\"cuda\", dtype=torch.float32)\n",
    "        ),\n",
    "        \"epsilon\": torch.distributions.Uniform(\n",
    "            torch.tensor(-10, device=\"cuda\", dtype=torch.float32),\n",
    "            torch.tensor(-5, device=\"cuda\", dtype=torch.float32)\n",
    "        ),\n",
    "        \"first_moment_decay\": torch.distributions.Uniform(\n",
    "            torch.tensor(-3, device=\"cuda\", dtype=torch.float32),\n",
    "            torch.tensor(0, device=\"cuda\", dtype=torch.float32)\n",
    "        ),\n",
    "        \"learning_rate\": torch.distributions.Uniform(\n",
    "            torch.tensor(-5, device=\"cuda\", dtype=torch.float32),\n",
    "            torch.tensor(-1, device=\"cuda\", dtype=torch.float32)\n",
    "        ),\n",
    "        \"second_moment_decay\": torch.distributions.Uniform(\n",
    "            torch.tensor(-5, device=\"cuda\", dtype=torch.float32),\n",
    "            torch.tensor(-1, device=\"cuda\", dtype=torch.float32)\n",
    "        ),\n",
    "        \"weight_decay\": torch.distributions.Uniform(\n",
    "            torch.tensor(-5, device=\"cuda\", dtype=torch.float32),\n",
    "            torch.tensor(-1, device=\"cuda\", dtype=torch.float32)\n",
    "        )\n",
    "    },\n",
    "    \"hyperparameter_raw_perturb\": {\n",
    "        \"dropout_p\": torch.distributions.Normal(\n",
    "            torch.tensor(0, device=\"cuda\", dtype=torch.float32),\n",
    "            torch.tensor(.1, device=\"cuda\", dtype=torch.float32)\n",
    "        ),\n",
    "        \"epsilon\": torch.distributions.Normal(\n",
    "            torch.tensor(0, device=\"cuda\", dtype=torch.float32),\n",
    "            torch.tensor(1, device=\"cuda\", dtype=torch.float32)\n",
    "        ),\n",
    "        \"first_moment_decay\": torch.distributions.Normal(\n",
    "            torch.tensor(0, device=\"cuda\", dtype=torch.float32),\n",
    "            torch.tensor(1, device=\"cuda\", dtype=torch.float32)\n",
    "        ),\n",
    "        \"learning_rate\": torch.distributions.Normal(\n",
    "            torch.tensor(0, device=\"cuda\", dtype=torch.float32),\n",
    "            torch.tensor(1, device=\"cuda\", dtype=torch.float32)\n",
    "        ),\n",
    "        \"second_moment_decay\": torch.distributions.Normal(\n",
    "            torch.tensor(0, device=\"cuda\", dtype=torch.float32),\n",
    "            torch.tensor(1, device=\"cuda\", dtype=torch.float32)\n",
    "        ),\n",
    "        \"weight_decay\": torch.distributions.Normal(\n",
    "            torch.tensor(0, device=\"cuda\", dtype=torch.float32),\n",
    "            torch.tensor(1, device=\"cuda\", dtype=torch.float32)\n",
    "        ),\n",
    "    },\n",
    "    \"hyperparameter_transforms\": {\n",
    "        \"dropout_p\": lambda p: p.clip(0,1),\n",
    "        \"epsilon\": lambda log10: 10 ** log10,\n",
    "        \"first_moment_decay\": lambda x: (1 - 10 ** x).clamp(0, 1),\n",
    "        \"learning_rate\": lambda log10: 10 ** log10,\n",
    "        \"second_moment_decay\": lambda x: (1 - 10 ** x).clamp(0, 1),\n",
    "        \"weight_decay\": lambda log10: 10 ** log10,\n",
    "    },\n",
    "    \"improvement_threshold\": 1e-4,\n",
    "    \"minibatch_size\": 64,\n",
    "    \"minibatch_size_eval\": 1 << 7,\n",
    "    \"pbt\": True,\n",
    "    \"seed\": 0,\n",
    "    \"steps_num\": 10_001,\n",
    "    \"steps_without_improvement\": 10_000,\n",
    "    \"valid_interval\": 1000,\n",
    "    \"welch_confidence_level\": .8,\n",
    "    \"welch_sample_size\": 10,\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "af797e3a",
   "metadata": {},
   "source": [
    "Set the `torch` pseudo-random number generation seed, as per the configuration dictionary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "885d1a14",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(config[\"seed\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cf1a2ed1",
   "metadata": {},
   "source": [
    "## Load and Preprocess Dataset\n",
    "\n",
    "Just like in Notebook 0423, load and preprocess CIFAR-10.\n",
    "\n",
    "Actually, I suggest first checking whether the path stored under `\"dataset_preprocessed_path\"` exists. If not, load and preprocess the dataset, then save it to this location.\n",
    "\n",
    "In either case, you can then load the preprocessed dataset from this location."
   ]
  },
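  {
   "cell_type": "markdown",
   "id": "b5d80f3e",
   "metadata": {},
   "source": [
    "The check-then-cache pattern described above can be sketched as a small helper. `load_or_preprocess` and the toy `preprocess` function are hypothetical names; the sketch writes to a temporary directory and counts how often preprocessing actually runs.\n",
    "\n",
    "```python\n",
    "import os\n",
    "import tempfile\n",
    "\n",
    "import torch\n",
    "\n",
    "\n",
    "def load_or_preprocess(path: str, preprocess) -> dict:\n",
    "    \"\"\"Hypothetical helper: preprocess once and cache at `path`, then always load from `path`.\"\"\"\n",
    "    if not os.path.exists(path):\n",
    "        tensors = preprocess()\n",
    "        # torch.save does not create missing directories; dirname is \"\" for a bare filename\n",
    "        os.makedirs(os.path.dirname(path) or \".\", exist_ok=True)\n",
    "        torch.save(tensors, path)\n",
    "    # weights_only=True restricts unpickling to tensors and plain containers\n",
    "    return torch.load(path, weights_only=True)\n",
    "\n",
    "\n",
    "calls = []\n",
    "\n",
    "\n",
    "def preprocess() -> dict:\n",
    "    # Toy stand-in for the expensive CIFAR-10 preprocessing\n",
    "    calls.append(1)\n",
    "    return {\"train_features\": torch.randn(4, 3), \"train_labels\": torch.arange(4)}\n",
    "\n",
    "\n",
    "with tempfile.TemporaryDirectory() as tmp:\n",
    "    path = os.path.join(tmp, \"data\", \"toy.pt\")\n",
    "    first = load_or_preprocess(path, preprocess)\n",
    "    second = load_or_preprocess(path, preprocess)  # served from the cache\n",
    "\n",
    "print(len(calls))  # 1\n",
    "```\n",
    "\n",
    "Creating the parent directory first matters because `torch.save` does not create missing directories."
   ]
  },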
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "842bf118",
   "metadata": {},
   "outputs": [],
   "source": [
    "if not os.path.exists(config[\"dataset_preprocessed_path\"]):\n",
    "    dataset = datasets.load_dataset(\n",
    "        config[\"dataset_path\"]\n",
    "    ).with_format(\n",
    "        \"torch\",\n",
    "        device=config[\"device\"]\n",
    "    )\n",
    "    train, test = (\n",
    "        dataset[key]\n",
    "        for key in [\"train\", \"test\"]\n",
    "    )\n",
    "    train_valid = train.train_test_split(\n",
    "        seed=config[\"seed\"],\n",
    "        test_size=len(test),\n",
    "    )\n",
    "    train, valid = (\n",
    "        train_valid[key]\n",
    "        for key in [\"train\", \"test\"]\n",
    "    )\n",
    "\n",
    "    (\n",
    "        train_features,\n",
    "        valid_features,\n",
    "        test_features\n",
    "    ) = (\n",
    "        dataset[\"img\"].to(torch.float32)\n",
    "        for dataset in (train, valid, test)\n",
    "    )\n",
    "\n",
    "    print(train_features.std())\n",
    "    \n",
    "    normalize_features(\n",
    "        train_features,\n",
    "        (valid_features, test_features)\n",
    "    )\n",
    "\n",
    "    print(train_features.std())\n",
    "\n",
    "    print(train[\"label\"].dtype)\n",
    "\n",
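    "    # torch.save does not create missing directories\n",
    "    os.makedirs(\n",
    "        os.path.dirname(config[\"dataset_preprocessed_path\"]) or \".\",\n",
    "        exist_ok=True\n",
    "    )\n",
    "    torch.save(\n",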
    "        {\n",
    "            \"train_features\": train_features,\n",
    "            \"train_labels\": train[\"label\"],\n",
    "            \"valid_features\": valid_features,\n",
    "            \"valid_labels\": valid[\"label\"],\n",
    "            \"test_features\": test_features,\n",
    "            \"test_labels\": test[\"label\"],\n",
    "        },\n",
    "        config[\"dataset_preprocessed_path\"]\n",
    "    )\n",
    "\n",
    "loaded = torch.load(\n",
    "    config[\"dataset_preprocessed_path\"],\n",
    "    weights_only=True\n",
    ")\n",
    "(\n",
    "    train_features,\n",
    "    train_labels,\n",
    "    valid_features,\n",
    "    valid_labels,\n",
    "    test_features,\n",
    "    test_labels\n",
    ") = (\n",
    "    loaded[key]\n",
    "    for key in (\n",
    "        \"train_features\",\n",
    "        \"train_labels\",\n",
    "        \"valid_features\",\n",
    "        \"valid_labels\",\n",
    "        \"test_features\",\n",
    "        \"test_labels\"\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "626f0549",
   "metadata": {},
   "source": [
    "### Get Flattened Datasets\n",
    "\n",
    "For quicker testing, we'll first add dropout and layer normalization to an MLP. Since an MLP takes vectors as input, first create training and validation datasets with flattened features. Check that the resulting feature tensors are 2-dimensional."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71f7031d",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_train = {\n",
    "    \"features\": train_features.flatten(1),\n",
    "    \"label\": train_labels\n",
    "}\n",
    "dataset_valid = {\n",
    "    \"features\": valid_features.flatten(1),\n",
    "    \"label\": valid_labels\n",
    "}\n",
    "\n",
    "for d in (dataset_train, dataset_valid):\n",
    "    for key in (\"features\", \"label\"):\n",
    "        print(key, d[key].shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "240455e5",
   "metadata": {},
   "source": [
    "## Training Code Updates\n",
    "\n",
    "### `update_model`\n",
    "\n",
    "Note that we have a new hyperparameter `dropout_p`, which determines the probability that a given feature entry is dropped by dropout. In our setup, hyperparameters are tracked in a dictionary. So far, only the optimizer had to be updated when this dictionary changed. Now, the model needs these updates too. To this end, implement the function below. You can iterate over the submodules of a `torch.nn.Module` with its `modules` method. For each submodule that has a `config` attribute, call the `update` method of that dictionary to send the hyperparameter updates to the model."
   ]
  },
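  {
   "cell_type": "markdown",
   "id": "2f9b7d05",
   "metadata": {},
   "source": [
    "A toy illustration of the `modules`/`config` mechanism. The layer `Scale` is hypothetical and exists only for this sketch; each instance owns a separate dictionary, so the update has to reach every layer individually.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "\n",
    "class Scale(torch.nn.Module):\n",
    "    \"\"\"Hypothetical toy layer that reads its multiplier from a config dict on every call.\"\"\"\n",
    "\n",
    "    def __init__(self, config: dict):\n",
    "        super().__init__()\n",
    "        self.config = config\n",
    "\n",
    "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
    "        return x * self.config[\"scale\"]\n",
    "\n",
    "\n",
    "def update_model(config: dict, model: torch.nn.Module):\n",
    "    # modules() yields the model itself and all of its submodules, recursively\n",
    "    for module in model.modules():\n",
    "        if hasattr(module, \"config\"):\n",
    "            module.config.update(config)\n",
    "\n",
    "\n",
    "# Two layers with separate dictionaries; ReLU and Sequential have no `config`\n",
    "model = torch.nn.Sequential(Scale({\"scale\": 1.0}), torch.nn.ReLU(), Scale({\"scale\": 1.0}))\n",
    "update_model({\"scale\": 3.0}, model)\n",
    "print(model(torch.ones(2)))  # tensor([9., 9.])\n",
    "```\n",
    "\n",
    "In this notebook, the layers are constructed from the same `config` object, so a layer that stores the dictionary it receives shares it with the others, and `update_model` writes the values of `config_local` into that shared dictionary."
   ]
  },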
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "595f2b91",
   "metadata": {},
   "outputs": [],
   "source": [
    "def update_model(\n",
    "    config: dict,\n",
    "    model: torch.nn.Module\n",
    "):\n",
    "    \"\"\"\n",
    "    Update the configuration dictionary of a model.\n",
    "    We iterate over its submodules, and for each one with a `config`\n",
    "    attribute, we update that dictionary with the given `config`.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    config : `dict`\n",
    "        The updated configuration dictionary.\n",
    "    model : `torch.nn.Module`\n",
    "        The model to update.\n",
    "    \"\"\"\n",
    "    for module in model.modules():\n",
    "        if hasattr(module, \"config\"):\n",
    "            module.config.update(config)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23abc965",
   "metadata": {},
   "source": [
    "### `model.train()` and `model.eval()`\n",
    "\n",
    "As discussed in the lecture, some layers, such as dropout and batch normalization, behave differently during training and evaluation. You can switch between these modes by calling the `train` and `eval` methods of the model before training and evaluation steps, respectively.\n",
    "\n",
    "Make this change to the function `train_supervised` you wrote in Notebook 0423. Moreover, still in `train_supervised`, call `update_model` after each call to `pbt_init` or `pbt_update`."
   ]
  },
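  {
   "cell_type": "markdown",
   "id": "c83e51a2",
   "metadata": {},
   "source": [
    "The built-in `torch.nn.Dropout` shows the two modes side by side. Calling `train()` or `eval()` on a module sets its `training` flag and that of all its submodules, and layers such as dropout read this flag in `forward`.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "layer = torch.nn.Dropout(p=0.5)\n",
    "x = torch.ones(8)\n",
    "\n",
    "# Training mode: entries are zeroed at random, survivors are scaled by 1 / (1 - p) = 2\n",
    "layer.train()\n",
    "y_train = layer(x)\n",
    "\n",
    "# Evaluation mode: the layer is the identity\n",
    "layer.eval()\n",
    "y_eval = layer(x)\n",
    "\n",
    "print(y_train)         # each entry is 0. or 2.\n",
    "print(y_eval)          # tensor([1., 1., 1., 1., 1., 1., 1., 1.])\n",
    "print(layer.training)  # False\n",
    "```"
   ]
  },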
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ceb99cec",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_supervised(\n",
    "    config: dict,\n",
    "    dataset_train: dict,\n",
    "    dataset_valid: dict,\n",
    "    get_loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],\n",
    "    get_metric: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],\n",
    "    model: torch.nn.Module,\n",
    "    optimizer: Optimizer,\n",
    "    out_features: Optional[int] = None,\n",
    "    target_key=\"target\"\n",
    ") -> dict:\n",
    "    \"\"\"\n",
    "    Population-based training on a supervised learning task.\n",
    "    Tuned hyperparameters are given by raw values and transformations.\n",
    "    This way, the hyperparameters are perturbed by\n",
    "    additive noise on raw values.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    config : `dict`\n",
    "        Configuration dictionary. Required key-value pairs:\n",
    "        `\"ensemble_shape\"` : tuple[int]\n",
    "            Ensemble shape. We assume this is a 1-dimensional tuple\n",
    "            whose only entry is the population size.\n",
    "        `\"hyperparameter_raw_init_distributions\"` : `dict`\n",
    "            Dictionary that maps tuned hyperparameter names\n",
    "            to `torch.distributions.Distribution` of raw hyperparameter values.\n",
    "            Required keys:\n",
    "            `\"learning_rate\"`:\n",
    "                The learning rate of stochastic gradient descent.\n",
    "        `\"hyperparameter_raw_perturb\"` : `dict`\n",
    "            Dictionary that maps tuned hyperparameter names\n",
    "            to `torch.distributions.Distribution` of additive noise.\n",
    "        `\"hyperparameter_transforms\"` : `dict`\n",
    "            Dictionary that maps tuned hyperparameter names\n",
    "            to transformations of raw hyperparameter values.\n",
    "        `\"improvement_threshold\"` : `float`\n",
    "            A new metric score has to be this much better\n",
    "            than the previous best to count as an improvement.\n",
    "        `\"minibatch_size\"` : `int`\n",
    "            Minibatch size to use in a training step.\n",
    "        `\"minibatch_size_eval\"` : `int`\n",
    "            Minibatch size to use in evaluation.\n",
    "            On CPU, should be about the same as `minibatch_size`.\n",
    "            On GPU, should be as big as possible without\n",
    "            incurring an Out of Memory error.\n",
    "        `\"pbt\"` : `bool`\n",
    "            Whether to use PBT updates in validations.\n",
    "            If `False`, the algorithm just samples hyperparameters at start,\n",
    "            then keeps them constant.\n",
    "        `\"steps_num\"` : `int`\n",
    "            Maximum number of training steps.\n",
    "        `\"steps_without_improvement\"` : `int`\n",
    "            If the number of training steps without improvement\n",
    "            exceeds this value, then training is stopped.\n",
    "        `\"valid_interval\"` : `int`\n",
    "            Frequency of evaluations, measured in number of training steps.\n",
    "        `\"welch_confidence_level\"` : `float`\n",
    "            The confidence level in Welch's t-test\n",
    "            that is used in determining if a population member\n",
    "            is to be replaced by another member with perturbed hyperparameters.\n",
    "        `\"welch_sample_size\"` : `int`\n",
    "            The last this many validation metrics are used\n",
    "            in Welch's t-test.\n",
    "    dataset_train : `dict`\n",
    "        The dataset to train the model on.\n",
    "    dataset_valid : `dict`\n",
    "        The dataset to evaluate the model on.\n",
    "    `get_loss` : `Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`\n",
    "        A function that maps a pair of predicted and target value tensors\n",
    "        to a tensor of losses per ensemble member.\n",
    "    `get_metric` : `Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`\n",
    "        A function that maps a pair of predicted and target value tensors\n",
    "        to a tensor of metrics per ensemble member.\n",
    "        We assume a greater metric is better.\n",
    "    `model` : `torch.nn.Module`\n",
    "        The model ensemble to tune.\n",
    "    `optimizer` : `Optimizer`\n",
    "        An optimizer that tracks the parameters of `model`.\n",
    "    out_features : `int`, optional\n",
    "        The number of output features in the predict tensors.\n",
    "        By default, it is the last dimension of the target tensor.\n",
    "    target_key : `str`, optional\n",
    "        The key mapped to the target value tensor in the dataset.\n",
    "        Default: `\"target\"`\n",
    "        \n",
    "    Returns\n",
    "    -------\n",
    "    An output dictionary with the following key-value pairs:\n",
    "        `\"best parameters\"` : `dict`  \n",
    "            The state dictionary of the model with the best metric\n",
    "            encountered during training.\n",
    "        `\"source mask\"` : `torch.Tensor`\n",
    "            Masks of the population members\n",
    "            that were replaced by other members in a PBT update.\n",
    "        `\"target indices\"` : `torch.Tensor`\n",
    "            The indices of the population members\n",
    "            that the masked members were replaced with.\n",
    "        `\"validation metric\"` : `torch.Tensor`\n",
    "            The validation metrics at evaluation steps.\n",
    "\n",
    "        In addition, for each tuned hyperparameter name,\n",
    "        we include a `torch.Tensor` of values per update.\n",
    "    \"\"\"\n",
    "    ensemble_shape = config[\"ensemble_shape\"]\n",
    "    if len(ensemble_shape) != 1:\n",
    "        raise ValueError(f\"The number of dimensions in the ensemble shape should be 1 for the population size, but it is {len(ensemble_shape)}\")\n",
    "\n",
    "    population_size = ensemble_shape[0]\n",
    "    config_local = dict(config)\n",
    "    log = defaultdict(list)\n",
    "\n",
    "    pbt_init(config_local, log)\n",
    "\n",
    "    optimizer.update_config(config_local)\n",
    "    update_model(config_local, model)\n",
    "\n",
    "    best_valid_metric = -torch.inf\n",
    "    progress_bar = tqdm.trange(config[\"steps_num\"])\n",
    "    steps_without_improvement = 0\n",
    "    train_dataloader = get_dataloader_random_reshuffle(\n",
    "        config,\n",
    "        dataset_train\n",
    "    )\n",
    "\n",
    "    for step_id in progress_bar:        \n",
    "        if step_id % config[\"valid_interval\"] == 0:\n",
    "            model.eval()\n",
    "            with torch.no_grad():\n",
    "                validation_metric = evaluate_model(\n",
    "                    config,\n",
    "                    dataset_valid,\n",
    "                    get_metric,\n",
    "                    model,\n",
    "                    out_features=out_features,\n",
    "                    target_key=target_key\n",
    "                ).nan_to_num(-torch.inf)\n",
    "                log[\"validation metric\"].append(validation_metric)\n",
    "                print(\n",
    "                    f\"validation metric {validation_metric.max().cpu().item():.4f}\"\n",
    "                )\n",
    "\n",
    "                best_last_metric, best_last_metric_id \\\n",
    "                    = log[\"validation metric\"][-1].max(dim=-1)\n",
    "                print(\n",
    "                    f\"Best last metric {best_last_metric.cpu().item():.2f}\",\n",
    "                    flush=True\n",
    "                )\n",
    "                if (\n",
    "                    best_valid_metric + config[\"improvement_threshold\"]\n",
    "                ) < best_last_metric:\n",
    "                    print(\n",
    "                        f\"New best metric\",\n",
    "                        flush=True\n",
    "                    )\n",
    "                    best_valid_metric = best_last_metric\n",
    "                    steps_without_improvement = 0\n",
    "                    log[\"best parameters\"] = {\n",
    "                        key: value[best_last_metric_id].clone()\n",
    "                        for key, value in model.state_dict().items()\n",
    "                    }\n",
    "                else:\n",
    "                    print(\n",
    "                        f\"Best metric {best_valid_metric.cpu().item():.2f}\",\n",
    "                        flush=True\n",
    "                    )\n",
    "                    steps_without_improvement += config[\"valid_interval\"]\n",
    "                    if steps_without_improvement > config[\n",
    "                        \"steps_without_improvement\"\n",
    "                    ]:\n",
    "                        break\n",
    "\n",
    "                if config[\"pbt\"] and (len(log[\"validation metric\"]) >= config[\n",
    "                    \"welch_sample_size\"\n",
    "                ]):\n",
    "                    evaluations = torch.stack(\n",
    "                        log[\"validation metric\"][-config[\"welch_sample_size\"]:]\n",
    "                    )\n",
    "                    pbt_update(\n",
    "                        config_local, evaluations, log, optimizer.get_parameters()\n",
    "                    )\n",
    "\n",
    "                    update_model(config_local, model)\n",
    "                    optimizer.update_config(config_local)\n",
    "\n",
    "        model.train()\n",
    "\n",
    "        minibatch = next(train_dataloader)\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        predict = model(minibatch)[\"features\"]\n",
    "        target = minibatch[target_key]\n",
    "\n",
    "        loss = get_loss(predict, target).sum()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "\n",
    "    progress_bar.close()\n",
    "    for key, value in log.items():\n",
    "        if isinstance(value, list):\n",
    "            log[key] = torch.stack(value)\n",
    "\n",
    "    return log"
   ]
  },
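  {
   "cell_type": "markdown",
   "id": "9a6f2d4b",
   "metadata": {},
   "source": [
    "The early-stopping logic in the validation block above can be replayed on a list of scalar metrics. `early_stopping_index` is a hypothetical helper that mirrors that bookkeeping for a single best metric; it ignores PBT and the model state.\n",
    "\n",
    "```python\n",
    "def early_stopping_index(metrics, valid_interval, improvement_threshold, patience):\n",
    "    \"\"\"Hypothetical helper replaying the early-stopping bookkeeping of train_supervised.\n",
    "\n",
    "    Returns the index of the validation at which training would stop, or None.\n",
    "    \"\"\"\n",
    "    best = float(\"-inf\")\n",
    "    steps_without_improvement = 0\n",
    "    for i, metric in enumerate(metrics):\n",
    "        if best + improvement_threshold < metric:\n",
    "            # Improvement by more than the threshold: record it and reset the counter\n",
    "            best = metric\n",
    "            steps_without_improvement = 0\n",
    "        else:\n",
    "            steps_without_improvement += valid_interval\n",
    "            if steps_without_improvement > patience:\n",
    "                return i\n",
    "    return None\n",
    "\n",
    "\n",
    "# The gain at index 2 is below the threshold, so it counts as no improvement\n",
    "print(early_stopping_index([0.30, 0.40, 0.40005, 0.39, 0.395], 1000, 1e-4, 2000))  # 4\n",
    "```\n",
    "\n",
    "With the configuration above (`\"valid_interval\"` of `1000`, `\"steps_without_improvement\"` of `10_000`, `\"steps_num\"` of `10_001`), there are eleven validations, so the counter appears unable to exceed `10_000` after a first improving validation, and such a run uses all training steps."
   ]
  },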
  {
   "cell_type": "markdown",
   "id": "ffa2c5f1",
   "metadata": {},
   "source": [
    "### Train an MLP\n",
    "\n",
    "Create an MLP as a `torch.nn.Sequential` of `Linear` and `DictReLU` layers and an `AdamW` optimizer to optimize its parameters. Train it via `train_supervised`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "240b8607",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = torch.nn.Sequential(\n",
    "    Linear(\n",
    "        config,\n",
    "        dataset_train[\"features\"].shape[1],\n",
    "        256,\n",
    "        init_multiplier=2 ** .5\n",
    "    ),\n",
    "    DictReLU(),\n",
    "    Linear(\n",
    "        config,\n",
    "        256,\n",
    "        256,\n",
    "        init_multiplier=2 ** .5\n",
    "    ),\n",
    "    DictReLU(),\n",
    "    Linear(\n",
    "        config,\n",
    "        256,\n",
    "        256,\n",
    "        init_multiplier=2 ** .5\n",
    "    ),\n",
    "    DictReLU(),\n",
    "    Linear(\n",
    "        config,\n",
    "        256,\n",
    "        128,\n",
    "        init_multiplier=2 ** .5\n",
    "    ),\n",
    "    DictReLU(),\n",
    "    Linear(\n",
    "        config,\n",
    "        128,\n",
    "        128,\n",
    "        init_multiplier=2 ** .5\n",
    "    ),\n",
    "    DictReLU(),\n",
    "    Linear(\n",
    "        config,\n",
    "        128,\n",
    "        10\n",
    "    )\n",
    ")\n",
    "\n",
    "optimizer = AdamW(model.parameters())\n",
    "\n",
    "log = train_supervised(\n",
    "    config,\n",
    "    dataset_train,\n",
    "    dataset_valid,\n",
    "    get_cross_entropy,\n",
    "    get_accuracy,\n",
    "    model,\n",
    "    optimizer,\n",
    "    out_features=10,\n",
    "    target_key=\"label\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aee13798",
   "metadata": {},
   "source": [
    "Call `del` on the model and the optimizer to delete them and release memory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b3e2db1",
   "metadata": {},
   "outputs": [],
   "source": [
    "del model, optimizer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "860090bf",
   "metadata": {},
   "source": [
    "## Implementing Dropout\n",
    "\n",
    "The built-in `Dropout` layer of `torch` records the dropout probability at the attribute `p`: [click here](https://github.com/pytorch/pytorch/blob/78953ee1223391df5c162ac6d7e3eb70294a722e/torch/nn/modules/dropout.py#L35) to access the source code.\n",
    "\n",
    "We would like to tune the dropout probability through our PBT machinery, which stores hyperparameters in a configuration dictionary. Thus, we'll write our own `Dropout` layer.\n",
    "\n",
    "Note that the `\"dropout_p\"` entry of the configuration dictionary holds a tensor of dropout probabilities of shape `ensemble_shape`. So, upon receiving a feature tensor (in our present setting, as the `\"features\"` entry of a data dictionary):\n",
    "1. Check whether the layer is in training mode, via its `training` attribute. If not, just return the input batch dictionary.\n",
    "2. Apply the `to_ensembled` function to the feature tensor, to make sure that it includes ensemble dimensions.\n",
    "3. Reshape the dropout probabilities for broadcasting, by appending dimensions of size 1 to the right of their shape until they have as many dimensions as the feature tensor.\n",
    "4. Multiply the feature tensor by $\\frac{1}{1-p+\\epsilon}$, where $p$ are the dropout probabilities and $\\epsilon$ is a small number added for numerical stability. Since an entry survives with probability $1-p$, this keeps its expected value (almost) unchanged.\n",
    "5. Get a sample from the uniform distribution on the unit interval, of the same shape as the feature tensor.\n",
    "6. Get a mask that is true where the sample is larger than the dropout probability.\n",
    "7. Multiply the features by the mask to apply dropout.\n",
    "\n",
    "Write the layer, then create an MLP with a dropout layer in front of each affine transformation. Train this one too."
   ]
  },
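  {
   "cell_type": "markdown",
   "id": "5e07b8c3",
   "metadata": {},
   "source": [
    "Steps 3 to 7 can be sketched as a standalone CPU function. `ensemble_dropout` is a hypothetical name; it assumes the features already carry the ensemble dimensions (as if `to_ensembled` had been applied) and omits the training-mode check. The usage lines check empirically that each member drops roughly a fraction $p$ of the entries while the mean stays close to its original value.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "\n",
    "def ensemble_dropout(features: torch.Tensor, dropout_p: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:\n",
    "    \"\"\"Hypothetical sketch of steps 3-7.\n",
    "\n",
    "    features: shape ensemble_shape + (...), dropout_p: shape ensemble_shape.\n",
    "    \"\"\"\n",
    "    extra_dims = features.dim() - dropout_p.dim()\n",
    "    # Step 3: append size-1 dimensions so p broadcasts over the non-ensemble dimensions\n",
    "    p = dropout_p.reshape(tuple(dropout_p.shape) + (1,) * extra_dims)\n",
    "    # Step 4: rescale so the expected value of each entry is (almost) unchanged\n",
    "    features = features / (1 - p + eps)\n",
    "    # Steps 5-6: keep an entry where a uniform sample exceeds its member's p\n",
    "    mask = torch.rand_like(features) > p\n",
    "    # Step 7: zero out the dropped entries\n",
    "    return features * mask\n",
    "\n",
    "\n",
    "torch.manual_seed(0)\n",
    "dropout_p = torch.tensor([0.0, 0.25, 0.5])  # one probability per ensemble member\n",
    "features = torch.ones(3, 1000, 64)          # (ensemble, batch, features)\n",
    "out = ensemble_dropout(features, dropout_p)\n",
    "\n",
    "print((out == 0).float().mean(dim=(1, 2)))  # dropped fraction per member, close to p\n",
    "print(out.mean(dim=(1, 2)))                 # close to 1 for every member\n",
    "```"
   ]
  },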
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e98a6e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Dropout(torch.nn.Module):\n",
    "    \"\"\"\n",
    "    Ensemble-ready dropout layer.\n",
    "\n",
    "    Arguments\n",
    "    ---------\n",
    "    config : `dict`\n",
    "        Configuration dictionary. Required key-value pairs:\n",
    "        `\"dropout_p\"` : `torch.Tensor`\n",
    "            Dropout probability tensor, of shape `ensemble_shape`.\n",
    "        `\"ensemble_shape\"` : `tuple[int]`\n",
    "            The shape of the ensemble of affine transformations\n",
    "            the model represents.\n",
    "\n",
    "    Calling\n",
    "    -------\n",
    "    Instance calls require one positional argument:\n",
    "    batch : `dict`\n",
    "        The input data dictionary. Required key:\n",
    "        `\"features\"` : `torch.Tensor`\n",
    "            Tensor of features.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, config: dict):\n",
    "        super().__init__()\n",
    "\n",
    "        self.config = config\n",
    "\n",
    "\n",
    "    def forward(self, batch: dict) -> dict:\n",
    "        # Dropout is only active in training mode; in evaluation mode the features pass through unchanged\n",
    "        if not self.training:\n",
    "            return batch\n",
    "\n",
    "        ensemble_shape = self.config[\"ensemble_shape\"]\n",
    "        ensemble_dim = len(ensemble_shape)\n",
    "        features = to_ensembled(ensemble_shape, batch[\"features\"])\n",
    "\n",
    "        # Give the per-member probabilities one singleton dimension per\n",
    "        # non-ensemble dimension of the features, so that they broadcast\n",
    "        dropout_p = self.config[\"dropout_p\"].reshape(\n",
    "            ensemble_shape + (1,) * (len(features.shape) - ensemble_dim)\n",
    "        )\n",
    "\n",
    "        # Rescale so that the expected value of each feature is (almost) unchanged\n",
    "        features = features / (1 - dropout_p + 1e-4)\n",
    "\n",
    "        # Keep each entry with probability 1 - p\n",
    "        sample = torch.rand(features.shape, device=features.device)\n",
    "        mask = sample > dropout_p\n",
    "\n",
    "        features = features * mask\n",
    "\n",
    "        return batch | {\"features\": features}"
   ]
  },
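  {
   "cell_type": "markdown",
   "id": "9b47d0a2",
   "metadata": {},
   "source": [
    "The only ensemble-specific part of the layer is giving `dropout_p` the right shape. A minimal sketch, assuming an ensemble of shape `(2, 3)`, made-up probabilities, and features that already carry the ensemble dimensions in front (which is what `to_ensembled` is used for above): appending one singleton dimension per non-ensemble dimension lets every member broadcast its own probability over all of its features.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "ensemble_shape = (2, 3)\n",
    "# Hypothetical per-member dropout probabilities, of shape ensemble_shape\n",
    "dropout_p = torch.tensor([[0.0, 0.5, 0.9], [0.1, 0.2, 0.3]])\n",
    "# Features of shape ensemble_shape + (batch, features)\n",
    "features = torch.ones(ensemble_shape + (1000, 64))\n",
    "\n",
    "# One singleton dimension for every non-ensemble dimension of the features\n",
    "extra_dim = features.dim() - len(ensemble_shape)\n",
    "p = dropout_p.reshape(ensemble_shape + (1,) * extra_dim)\n",
    "print(p.shape)  # torch.Size([2, 3, 1, 1])\n",
    "\n",
    "torch.manual_seed(0)\n",
    "mask = torch.rand(features.shape) > p\n",
    "# Empirical drop rate of each ensemble member, close to dropout_p\n",
    "drop_rate = 1 - mask.float().mean(dim=(-2, -1))\n",
    "print(drop_rate)\n",
    "```\n",
    "\n",
    "Reshaping the whole probability tensor works for any number of ensemble dimensions, whereas `unflatten(-1, ...)` only splits the last dimension, so it yields this shape only when the ensemble is one-dimensional."
   ]
  },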
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b5630c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = torch.nn.Sequential(\n",
    "    Dropout(config),\n",
    "    Linear(\n",
    "        config,\n",
    "        dataset_train[\"features\"].shape[1],\n",
    "        256,\n",
    "        init_multiplier=2 ** .5\n",
    "    ),\n",
    "    DictReLU(),\n",
    "    Dropout(config),\n",
    "    Linear(\n",
    "        config,\n",
    "        256,\n",
    "        256,\n",
    "        init_multiplier=2 ** .5\n",
    "    ),\n",
    "    DictReLU(),\n",
    "    Dropout(config),\n",
    "    Linear(\n",
    "        config,\n",
    "        256,\n",
    "        256,\n",
    "        init_multiplier=2 ** .5\n",
    "    ),\n",
    "    DictReLU(),\n",
    "    Dropout(config),\n",
    "    Linear(\n",
    "        config,\n",
    "        256,\n",
    "        128,\n",
    "        init_multiplier=2 ** .5\n",
    "    ),\n",
    "    DictReLU(),\n",
    "    Dropout(config),\n",
    "    Linear(\n",
    "        config,\n",
    "        128,\n",
    "        128,\n",
    "        init_multiplier=2 ** .5\n",
    "    ),\n",
    "    DictReLU(),\n",
    "    Dropout(config),\n",
    "    Linear(\n",
    "        config,\n",
    "        128,\n",
    "        10\n",
    "    )\n",
    ")\n",
    "\n",
    "optimizer = AdamW(model.parameters())\n",
    "\n",
    "log = train_supervised(\n",
    "    config,\n",
    "    dataset_train,\n",
    "    dataset_valid,\n",
    "    get_cross_entropy,\n",
    "    get_accuracy,\n",
    "    model,\n",
    "    optimizer,\n",
    "    out_features=10,\n",
    "    target_key=\"label\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "614bee26",
   "metadata": {},
   "source": [
    "Let's again delete the model and the optimizer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fca5bfb",
   "metadata": {},
   "outputs": [],
   "source": [
    "del model, optimizer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4cc2a348",
   "metadata": {},
   "source": [
    "## Implementing Layer Normalization\n",
    "\n",
    "Now for layer normalization! Once again, the built-in `LayerNorm` is not up to the task, as we can't train different scale and offset tensors per ensemble member.\n",
    "\n",
    "Moreover, we include a keyword argument that the built-in version lacks: `normalized_offset`. We need it when the dimensions we want to normalize along are not the last ones in the shape. For example, after adding layer normalization layers to the MLP, we'll do the same for the CNN, but there we'll normalize along the channel dimension only, which comes before the two spatial dimensions in the `torch` image convention (batch, channels, height, width).\n",
    "\n",
    "Write the new layer as per the docstrings. Then add layer normalization layers in front of the dropout layers in the MLP and train it."
   ]
  },
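  {
   "cell_type": "markdown",
   "id": "c58e1f36",
   "metadata": {},
   "source": [
    "To see which dimensions `normalized_offset` selects, here is a small pure-Python sketch of the index arithmetic from the `LayerNorm` docstring. The helper `split_dims` is hypothetical and only lists dimension indices; it does not touch any tensors. For image features of shape `ensemble_shape + (batch, channels, height, width)`, a single normalized dimension with `normalized_offset=2` picks out the channel dimension alone.\n",
    "\n",
    "```python\n",
    "def split_dims(ndim, ensemble_dim, normalized_dim, normalized_offset=0):\n",
    "    # Returns (batch dims, normalized dims, trailing dims) as non-negative indices\n",
    "    normalized_start = ndim - normalized_dim - normalized_offset\n",
    "    batch = tuple(range(ensemble_dim, normalized_start))\n",
    "    normalized = tuple(range(normalized_start, ndim - normalized_offset))\n",
    "    trailing = tuple(range(ndim - normalized_offset, ndim))\n",
    "    return batch, normalized, trailing\n",
    "\n",
    "# MLP features: one ensemble dim + (batch, features)\n",
    "print(split_dims(3, 1, 1))  # ((1,), (2,), ())\n",
    "# CNN features: one ensemble dim + (batch, channels, height, width)\n",
    "print(split_dims(5, 1, 1, normalized_offset=2))  # ((1,), (2,), (3, 4))\n",
    "```"
   ]
  },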
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20c0def9",
   "metadata": {},
   "outputs": [],
   "source": [
    "class LayerNorm(torch.nn.Module):\n",
    "    \"\"\"\n",
    "    Ensemble-ready layer normalization layer.\n",
    "\n",
    "    Arguments\n",
    "    ---------\n",
    "    config : `dict`\n",
    "        Configuration dictionary. Required key-value pairs:\n",
    "        `\"device\"` : `str`\n",
    "            The device to store parameters on.\n",
    "        `\"ensemble_shape\"` : `tuple[int]`\n",
    "            The shape of the ensemble of affine transformations\n",
    "            the model represents.\n",
    "    normalized_shape : `int | tuple[int]`\n",
    "        The part of the shape of the incoming tensors\n",
    "        over which the mean and variance are computed.\n",
    "        Statistics are computed separately for every index\n",
    "        along the batch dimensions, which we take to be\n",
    "        ```\n",
    "        range(\n",
    "            len(ensemble_shape),\n",
    "            -len(normalized_shape) - normalized_offset\n",
    "        )\n",
    "        ```\n",
    "        and along the last `normalized_offset` dimensions.\n",
    "        If an integer, we view it as a single-element tuple.\n",
    "    bias : `bool`, optional\n",
    "        If `elementwise_affine`, whether to include offset\n",
    "        in the learned transformation. Default: `True`.\n",
    "    elementwise_affine : `bool`, optional\n",
    "        Whether to include learnable scale. If this and `bias`,\n",
    "        then we also include learnable offset. These will be tensors\n",
    "        of shape\n",
    "        `ensemble_shape + normalized_shape + (1,) * normalized_offset`\n",
    "        that are broadcast to the incoming feature tensors appropriately.\n",
    "        Default: `True`.\n",
    "    epsilon : `float`, optional\n",
    "        Small positive value added to the variance before taking\n",
    "        its square root, for numerical stability. Default: `1e-5`.\n",
    "    normalized_offset : `int`, optional\n",
    "        We get `normalized_shape` out of an incoming feature tensor\n",
    "        at dimensions\n",
    "        ```\n",
    "        range(\n",
    "            -len(normalized_shape) - normalized_offset,\n",
    "            -normalized_offset\n",
    "        )\n",
    "        ```\n",
    "        Default: `0`.\n",
    "\n",
    "    Calling\n",
    "    -------\n",
    "    Instance calls require one positional argument:\n",
    "    batch : `dict`\n",
    "        The input data dictionary. Required key:\n",
    "        `\"features\"` : `torch.Tensor`\n",
    "            Tensor of features.\n",
    "    \"\"\"\n",
    "    def __init__(\n",
    "        self,\n",
    "        config: dict,\n",
    "        normalized_shape: int | tuple[int],\n",
    "        bias=True,\n",
    "        elementwise_affine=True,\n",
    "        epsilon=1e-5,\n",
    "        normalized_offset=0\n",
    "    ):\n",
    "        super().__init__()\n",
    "\n",
    "        if hasattr(normalized_shape, \"__int__\"):\n",
    "            self.normalized_shape = (normalized_shape,)\n",
    "        else:\n",
    "            self.normalized_shape = tuple(normalized_shape)\n",
    "\n",
    "        self.ensemble_shape = config[\"ensemble_shape\"]\n",
    "        self.epsilon = epsilon\n",
    "        self.normalized_offset = normalized_offset\n",
    "\n",
    "        if elementwise_affine:\n",
    "            self.scale = torch.nn.Parameter(torch.ones(\n",
    "                self.ensemble_shape + self.normalized_shape + (1,) * normalized_offset,\n",
    "                device=config[\"device\"],\n",
    "                dtype=torch.float32\n",
    "            ))\n",
    "            if bias:\n",
    "                self.bias = torch.nn.Parameter(torch.zeros_like(self.scale))\n",
    "            else:\n",
    "                self.bias = None\n",
    "\n",
    "        else:\n",
    "            self.bias, self.scale = None, None\n",
    "\n",
    "\n",
    "    def forward(self, batch: dict) -> dict:\n",
    "        features: torch.Tensor = batch[\"features\"]\n",
    "\n",
    "        ensemble_dim = len(self.ensemble_shape)\n",
    "        features = to_ensembled(self.ensemble_shape, features)\n",
    "\n",
    "        normalized_dim = len(self.normalized_shape)\n",
    "\n",
    "        batch_dim = len(features.shape) - ensemble_dim - normalized_dim - self.normalized_offset\n",
    "        # Per-sample statistics: reduce over the normalized dimensions only\n",
    "        normalized_range = tuple(range(\n",
    "            -normalized_dim - self.normalized_offset,\n",
    "            -self.normalized_offset\n",
    "        ))\n",
    "\n",
    "        mean = features.mean(dim=normalized_range, keepdim=True)\n",
    "        # Biased variance, with epsilon inside the square root as in `torch.nn.LayerNorm`\n",
    "        var = features.var(dim=normalized_range, correction=0, keepdim=True)\n",
    "        features = (features - mean) / (var + self.epsilon).sqrt()\n",
    "\n",
    "        if self.scale is not None:\n",
    "            scale = self.scale.unflatten(\n",
    "                ensemble_dim,\n",
    "                (1,) * batch_dim + self.normalized_shape[:1]\n",
    "            )\n",
    "\n",
    "            features = features * scale\n",
    "\n",
    "            if self.bias is not None:\n",
    "                bias = self.bias.unflatten(\n",
    "                    ensemble_dim,\n",
    "                    (1,) * batch_dim + self.normalized_shape[:1]\n",
    "                )\n",
    "                features = features + bias\n",
    "\n",
    "        return batch | {\"features\": features}"
   ]
  },
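  {
   "cell_type": "markdown",
   "id": "e0d4b7a9",
   "metadata": {},
   "source": [
    "One way to test the normalization arithmetic is to compare it with `torch.nn.functional.layer_norm`, which normalizes each sample over its last dimensions using the biased variance and adds `eps` to the variance before the square root. A minimal sketch, assuming per-sample statistics over the normalized dimensions only, no learnable scale or offset, and made-up shapes; the helper `normalize` is hypothetical. For the image case, moving the channel dimension last turns channel-only normalization (`normalized_offset=2`) into an ordinary call to the built-in.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "torch.manual_seed(0)\n",
    "epsilon = 1e-5\n",
    "\n",
    "def normalize(features, dims, epsilon):\n",
    "    # Subtract the mean, divide by sqrt(biased variance + epsilon) along dims\n",
    "    mean = features.mean(dim=dims, keepdim=True)\n",
    "    var = features.var(dim=dims, correction=0, keepdim=True)\n",
    "    return (features - mean) / torch.sqrt(var + epsilon)\n",
    "\n",
    "# MLP case: ensemble (4,) + (batch 8, features 16), normalize the last dim\n",
    "x = torch.randn(4, 8, 16) * 3 + 1\n",
    "ours = normalize(x, (-1,), epsilon)\n",
    "ref = F.layer_norm(x, (16,), eps=epsilon)\n",
    "print(torch.allclose(ours, ref, atol=1e-5))  # True\n",
    "\n",
    "# CNN case: ensemble (4,) + (batch 8, channels 3, 5, 5), normalized_offset=2\n",
    "img = torch.randn(4, 8, 3, 5, 5)\n",
    "ours_img = normalize(img, (-3,), epsilon)\n",
    "ref_img = F.layer_norm(img.movedim(-3, -1), (3,), eps=epsilon).movedim(-1, -3)\n",
    "print(torch.allclose(ours_img, ref_img, atol=1e-5))  # True\n",
    "```"
   ]
  },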
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6bd6dd4",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = torch.nn.Sequential(\n",
    "    LayerNorm(\n",
    "        config,\n",
    "        dataset_train[\"features\"].shape[1]\n",
    "    ),\n",
    "    Dropout(config),\n",
    "    Linear(\n",
    "        config,\n",
    "        dataset_train[\"features\"].shape[1],\n",
    "        256,\n",
    "        init_multiplier=2 ** .5\n",
    "    ),\n",
    "    DictReLU(),\n",
    "    LayerNorm(\n",
    "        config,\n",
    "        256\n",
    "    ),\n",
    "    Dropout(config),\n",
    "    Linear(\n",
    "        config,\n",
    "        256,\n",
    "        256,\n",
    "        init_multiplier=2 ** .5\n",
    "    ),\n",
    "    DictReLU(),\n",
    "    LayerNorm(\n",
    "        config,\n",
    "        256\n",
    "    ),\n",
    "    Dropout(config),\n",
    "    Linear(\n",
    "        config,\n",
    "        256,\n",
    "        256,\n",
    "        init_multiplier=2 ** .5\n",
    "    ),\n",
    "    DictReLU(),\n",
    "    LayerNorm(\n",
    "        config,\n",
    "        256\n",
    "    ),\n",
    "    Dropout(config),\n",
    "    Linear(\n",
    "        config,\n",
    "        256,\n",
    "        128,\n",
    "        init_multiplier=2 ** .5\n",
    "    ),\n",
    "    DictReLU(),\n",
    "    LayerNorm(\n",
    "        config,\n",
    "        128\n",
    "    ),\n",
    "    Dropout(config),\n",
    "    Linear(\n",
    "        config,\n",
    "        128,\n",
    "        128,\n",
    "        init_multiplier=2 ** .5\n",
    "    ),\n",
    "    DictReLU(),\n",
    "    LayerNorm(\n",
    "        config,\n",
    "        128\n",
    "    ),\n",
    "    Dropout(config),\n",
    "    Linear(\n",
    "        config,\n",
    "        128,\n",
    "        10\n",
    "    )\n",
    ")\n",
    "\n",
    "optimizer = AdamW(model.parameters())\n",
    "\n",
    "log = train_supervised(\n",
    "    config,\n",
    "    dataset_train,\n",
    "    dataset_valid,\n",
    "    get_cross_entropy,\n",
    "    get_accuracy,\n",
    "    model,\n",
    "    optimizer,\n",
    "    out_features=10,\n",
    "    target_key=\"label\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f082b403",
   "metadata": {},
   "source": [
    "Delete the model and the optimizer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f91d7de",
   "metadata": {},
   "outputs": [],
   "source": [
    "del model, optimizer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e806991",
   "metadata": {},
   "source": [
    "## Try it on a CNN!\n",
    "\n",
    "Take the CNN you used last time and give it layer normalization layers (set `normalized_offset` so that you only normalize along the channel dimension) and dropout layers, similarly to how you did for the MLP. Train it on the training and validation datasets with the unflattened features."
   ]
  },
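  {
   "cell_type": "markdown",
   "id": "71ac3e5d",
   "metadata": {},
   "source": [
    "With `normalized_offset=2`, the learnable scale and offset of each `LayerNorm` carry two trailing singleton dimensions, so that once `LayerNorm.forward` inserts a singleton for the batch dimension, they line up with the channel dimension of an image batch. A minimal shape-only sketch, assuming a hypothetical ensemble of shape `(4,)`, 16 channels, and 8 images of 32 by 32 pixels:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "ensemble_shape = (4,)\n",
    "channels, normalized_offset = 16, 2\n",
    "# Scale as created in LayerNorm.__init__: ensemble_shape + (channels,) + (1, 1)\n",
    "scale = torch.ones(ensemble_shape + (channels,) + (1,) * normalized_offset)\n",
    "# Feature tensor: ensemble_shape + (batch, channels, height, width)\n",
    "features = torch.randn(ensemble_shape + (8, channels, 32, 32))\n",
    "\n",
    "batch_dim = features.dim() - len(ensemble_shape) - 1 - normalized_offset\n",
    "# Insert one singleton per batch dimension in front of the channel dimension\n",
    "scale = scale.unflatten(len(ensemble_shape), (1,) * batch_dim + (channels,))\n",
    "print(scale.shape)  # torch.Size([4, 1, 16, 1, 1])\n",
    "print((features * scale).shape)  # torch.Size([4, 8, 16, 32, 32])\n",
    "```"
   ]
  },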
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d254075c",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = torch.nn.Sequential(\n",
    "    LayerNorm(\n",
    "        config,\n",
    "        3,\n",
    "        normalized_offset=2,\n",
    "    ),\n",
    "    Dropout(config),\n",
    "    Conv2D(\n",
    "        config,\n",
    "        3,\n",
    "        (3,3),\n",
    "        16,\n",
    "        init_multiplier=2 ** .5\n",
    "    ),\n",
    "    DictReLU(),\n",
    "    LayerNorm(\n",
    "        config,\n",
    "        16,\n",
    "        normalized_offset=2,\n",
    "    ),\n",
    "    Dropout(config),\n",
    "    Conv2D(\n",
    "        config,\n",
    "        16,\n",
    "        (3,3),\n",
    "        32,\n",
    "        # init_multiplier=2 ** .5\n",
    "    ),\n",
    "    Pool2D(\n",
    "        config,\n",
    "        kernel_shape=(3,3),\n",
    "        stride=2\n",
    "    ),\n",
    "    LayerNorm(\n",
    "        config,\n",
    "        32,\n",
    "        normalized_offset=2,\n",
    "    ),\n",
    "    Dropout(config),\n",
    "    Conv2D(\n",
    "        config,\n",
    "        32,\n",
    "        (3,3),\n",
    "        64,\n",
    "        init_multiplier=2 ** .5\n",
    "    ),\n",
    "    DictReLU(),\n",
    "    LayerNorm(\n",
    "        config,\n",
    "        64,\n",
    "        normalized_offset=2,\n",
    "    ),\n",
    "    Dropout(config),\n",
    "    Conv2D(\n",
    "        config,\n",
    "        64,\n",
    "        (3,3),\n",
    "        128\n",
    "    ),\n",
    "    Pool2D(\n",
    "        config\n",
    "    ),\n",
    "    LayerNorm(\n",
    "        config,\n",
    "        128\n",
    "    ),\n",
    "    Dropout(config),\n",
    "    Linear(\n",
    "        config,\n",
    "        128,\n",
    "        128,\n",
    "        init_multiplier=2 ** .5\n",
    "    ),\n",
    "    DictReLU(),\n",
    "    LayerNorm(\n",
    "        config,\n",
    "        128\n",
    "    ),\n",
    "    Dropout(config),\n",
    "    Linear(\n",
    "        config,\n",
    "        128,\n",
    "        10\n",
    "    )\n",
    ")\n",
    "\n",
    "optimizer = AdamW(model.parameters())\n",
    "\n",
    "dataset_train = {\n",
    "    \"features\": train_features,\n",
    "    \"label\": train_labels\n",
    "}\n",
    "dataset_valid = {\n",
    "    \"features\": valid_features,\n",
    "    \"label\": valid_labels\n",
    "}\n",
    "\n",
    "log = train_supervised(\n",
    "    config,\n",
    "    dataset_train,\n",
    "    dataset_valid,\n",
    "    get_cross_entropy,\n",
    "    get_accuracy,\n",
    "    model,\n",
    "    optimizer,\n",
    "    out_features=10,\n",
    "    target_key=\"label\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c7e92286",
   "metadata": {},
   "source": [
    "## Dataset References\n",
    "\n",
    "[5] Alex Krizhevsky: *Learning Multiple Layers of Features from Tiny Images*. 2009. https://www.cs.toronto.edu/~kriz/cifar.html"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "19ea7254",
   "metadata": {},
   "source": [
    "## License\n",
    "\n",
    "This work is licensed under Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International. To view a copy of this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dml",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
