{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ebf9c0f5",
   "metadata": {},
   "source": [
    "# Implementing a CNN\n",
    "\n",
    "Now, we shall undertake implementing a CNN for image classification.\n",
    "\n",
    "1. Although `pytorch` has built-in convolution and pooling layers, those don't support ensembles, so we'll have to make our own.\n",
    "2. I give the indices relevant to convolution above. But indexing with these would result in copying tensors, thus wasting memory and compute time. However, the build-in convolution and pooling operations do are effective, so we can use those."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bf411eed",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "### Imports\n",
    "\n",
    "Import `defaultdict`, `Callable`, `datasets`, `math`, `torch`, `torch.nn.functional` as `F`, `tqdm` and `Optional`.\n",
    "\n",
    "Moreover, import:\n",
    "1. The functions `get_accuracy` and `get_cross_entropy`, that you wrote in Notebook 0221,\n",
    "1. The function `normalize_features`, that you wrote in Notebook 0321,\n",
    "1. The classes `AdamW` and `Optimizer`, that you wrote in Notebook 0326.\n",
    "3. The function `get_dataset_size`, that you wrote in Notebook 0409.\n",
    "3. The functions `get_dataloader_random_reshuffle` `get_minibatch`, and `to_ensembled`, that you wrote in Notebook 0416."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0dadd32",
   "metadata": {},
   "outputs": [],
   "source": [
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b72a3041",
   "metadata": {},
   "source": [
    "### Configuration\n",
    "\n",
    "Create a configuration dictionary with the following keys:\n",
    "- `\"dataset_path\"`: `str`  \n",
    "    As we achieved 98% accuracy on MNIST with an MLP and AdamW in Notebook 0326, we'll take a step up and train on CIFAR-10: https://huggingface.co/datasets/uoft-cs/cifar10 [10], that we used previously in Homework 2.\n",
    "- `\"device\"`: `torch.device | int | str`  \n",
    "    The device identifier.\n",
    "- `\"ensemble_shape\"`: `tuple[int]`  \n",
    "    Make this `(16,)`.\n",
    "- `\"hyperparameter_raw_init_distributions\"`, `\"hyperparameter_raw_perturb\"`, `\"hyperparameter_transforms\"` : `dict`  \n",
    "    These three dictionaries are going to determine how the hyperparameters are tuned. We'll tune the following hyperparameters:\n",
    "    1. Epsilon $\\epsilon$.\n",
    "    2. Learning rate $\\eta$.\n",
    "    3. Weight decay $\\lambda$.\n",
    "    4. First moment moving average decay rate $\\beta_1$.\n",
    "    5. Second moment moving average decay rate $\\beta_2$.\n",
    "\n",
    "    Of these, we don't know the required order of magnitude of the first three. Thus it may be good to make them distributed along $10^\\mathscr D$ where $\\mathscr D$ is a normal or uniform distribution. You can try to center the distributions at the recommended values.\n",
    "\n",
    "    We know that the recommended values of the fourth and fifth are $0.9$ and $0.999$. So it may be best to give them a distribution of the form $1-10^\\mathscr D$.\n",
    "- `\"improvement_threshold:`: `float`  \n",
    "    Make this `1e-4`.\n",
    "- `\"minibatch_size\"`: `int`  \n",
    "    Make this a `64`.\n",
    "- `\"minibatch_size_eval\"`: `int`  \n",
    "    On my home computer, I can make this `128`.\n",
    "- `\"pbt\"` : `bool`  \n",
    "    Make this `True`.\n",
    "- `\"seed\"`: `int`  \n",
    "    This is for reproducible experiments. Insert any integer.\n",
    "- `\"steps_num\"`: `int`  \n",
    "    Make this `10_001`.\n",
    "- `\"steps_without_improvement`: `int`  \n",
    "    Make this `10_000`.\n",
    "- `\"valid_interval\"`: `int`  \n",
    "    Make this `1000`.\n",
    "- `\"welch_confidence_level\"`: `float`  \n",
    "    We will exploit based on a one-sided Welch $t$-test with this confidence level. Based on my experiments in the setting of Homework 9, maybe you can try `.8`. Feel free to try out various values here!\n",
    "- `\"welch_sample_size\"`: `int`  \n",
    "    We will exploit based on a one-sided Welch $t$-test on the last this many validation metrics of the population members. To follow the PBT paper, make this `10`."
   ]
  },
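  {
   "cell_type": "markdown",
   "id": "4e1b7a20",
   "metadata": {},
   "source": [
    "The two distribution shapes described above can be sketched as follows. This is only an illustration under assumptions, not the reference configuration: the hyperparameter names, the dictionary names `raw_init` and `transforms`, and the centers and scales of the normal distributions are all made up for the example.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "# Raw values live on a log10 scale. Centers and scales are illustrative.\n",
    "raw_init = {\n",
    "    'learning_rate': torch.distributions.Normal(-3., 1.),\n",
    "    'first_moment_decay': torch.distributions.Normal(-1., .5),\n",
    "}\n",
    "\n",
    "# Transforms map raw values to actual hyperparameter values.\n",
    "transforms = {\n",
    "    'learning_rate': lambda raw: 10 ** raw,           # 10^D, always positive\n",
    "    'first_moment_decay': lambda raw: 1 - 10 ** raw,  # 1 - 10^D, always below 1\n",
    "}\n",
    "\n",
    "torch.manual_seed(0)\n",
    "ensemble_shape = (16,)\n",
    "\n",
    "# One hyperparameter value per ensemble member.\n",
    "hyperparameters = {\n",
    "    name: transforms[name](raw_init[name].sample(ensemble_shape))\n",
    "    for name in raw_init\n",
    "}\n",
    "```\n",
    "\n",
    "Working on the raw scale means that additive noise on a raw value acts multiplicatively on the learning rate, and on the distance from 1 for the decay rate."
   ]
  },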
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47fa2a3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d0974b31",
   "metadata": {},
   "source": [
    "Set the `torch` pseudo-random number generation seed as per the configuration dictionary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37fe9a06",
   "metadata": {},
   "outputs": [],
   "source": [
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7d15ed55",
   "metadata": {},
   "source": [
    "## Dataset\n",
    "\n",
    "### Load the Dataset\n",
    "\n",
    "Load the dataset specified by the configuration dictionary. Make a train-valid split of its `\"train\"` split. Print tensor shapes and datatypes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35775427",
   "metadata": {},
   "outputs": [],
   "source": [
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f8884c41",
   "metadata": {},
   "source": [
    "Note that the shapes of the images are `(batch dimension, channel dimension, vertical dimension, horizontal dimension)`. Usually, I prefer feature dimensions (here: channel) after spatial dimensions. But this time, to be more compatible with `torch` conventions, we'll keep the original order.\n",
    "\n",
    "Still, convert the feature tensors to floating points and normalize them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ede63da2",
   "metadata": {},
   "outputs": [],
   "source": [
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0de04aab",
   "metadata": {},
   "source": [
    "Make training and validation datasets as dictionaries with appropriate `\"features\"` and `\"label\"` keys. Create a training dataloader and print minibatch entry shapes, to see if the machinery is working as expected."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8bfe1c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8fd9998b",
   "metadata": {},
   "source": [
    "## New Layers\n",
    "\n",
    "Let's implement the new, 2d convolution and 2d mean pool layers. For composability in the setting, let's make them output dictionaries. We'll also reimplement the linear and the ReLU layers this way.\n",
    "\n",
    "In all the layers we implement today, in the minibatch dictionary it receives, it should read and update the value at the `\"features\"` key. You can use the copy update operator `|` for this."
   ]
  },
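  {
   "cell_type": "markdown",
   "id": "9c2d5f13",
   "metadata": {},
   "source": [
    "As a small illustration of the `|` operator on a minibatch dictionary (the toy tensors and the names `batch` and `updated` are assumptions made for this example; the operator requires Python 3.9 or later):\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "batch = {\n",
    "    'features': torch.tensor([-1., 2.]),\n",
    "    'label': torch.tensor([0, 1]),\n",
    "}\n",
    "\n",
    "# `|` builds a new dictionary. Keys on the right override keys on the left,\n",
    "# and the input dictionary itself is left untouched.\n",
    "updated = batch | {'features': torch.relu(batch['features'])}\n",
    "```\n",
    "\n",
    "All other entries, such as the labels, are carried over unchanged, which is what lets layers of this kind be chained in `torch.nn.Sequential`."
   ]
  },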
  {
   "cell_type": "markdown",
   "id": "d2cc5d0c",
   "metadata": {},
   "source": [
    "### `Conv2D`\n",
    "\n",
    "Let's write our own convolution layer! As we'll want to train an ensemble, we can't use the built-in module `torch.nn.Module.Conv2D`. However, we can make do with the function `F.conv2d` as it supports blocked connections via the `groups` keyword argument. [Follow this link](https://pytorch.org/docs/stable/generated/torch.nn.functional.conv2d.html) for the documentation. You can implement the ensemble-ready two-dimensional convolution layer using this function and some tensor shape transformations.\n",
    "\n",
    "When you're finished, make a 2D convolution layer with kernel shape `(3,3)` and check the shape of its output features on a training minibatch."
   ]
  },
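  {
   "cell_type": "markdown",
   "id": "6a80b3e7",
   "metadata": {},
   "source": [
    "To see why `groups` helps, here is a minimal sketch of the underlying shape trick, checked against one ordinary convolution per ensemble member. It is not the full layer: there is no bias, padding, stride or configuration dictionary, and the sizes and variable names are assumptions chosen for the example.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "torch.manual_seed(0)\n",
    "ensemble, batch, c_in, c_out, height, width = 2, 5, 3, 4, 8, 8\n",
    "features = torch.randn(ensemble, batch, c_in, height, width)\n",
    "weight = torch.randn(ensemble, c_out, c_in, 3, 3)\n",
    "\n",
    "# Fold the ensemble dimension into the channel dimension.\n",
    "folded = features.movedim(0, 1).reshape(batch, ensemble * c_in, height, width)\n",
    "\n",
    "# With groups=ensemble, each block of c_in input channels is connected\n",
    "# only to its own block of c_out output channels.\n",
    "out = F.conv2d(\n",
    "    folded,\n",
    "    weight.reshape(ensemble * c_out, c_in, 3, 3),\n",
    "    groups=ensemble\n",
    ")\n",
    "\n",
    "# Unfold back to (ensemble, batch, c_out, new height, new width).\n",
    "out = out.reshape(batch, ensemble, c_out, *out.shape[-2:]).movedim(1, 0)\n",
    "\n",
    "# Reference: one ordinary convolution per ensemble member.\n",
    "reference = torch.stack([\n",
    "    F.conv2d(features[i], weight[i]) for i in range(ensemble)\n",
    "])\n",
    "```\n",
    "\n",
    "Here `out` and `reference` agree up to floating point error. A full layer would also have to handle inputs without an ensemble dimension and batch shapes with more than one dimension."
   ]
  },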
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43b55e7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Conv2D(torch.nn.Module):\n",
    "    \"\"\"\n",
    "    Ensemble-ready, two-dimensional convolution layer\n",
    "\n",
    "    Arguments\n",
    "    ---------\n",
    "    config : `dict`\n",
    "        Configuration dictionary. Required key-value pairs:\n",
    "        `\"device\"` : `str`\n",
    "            The device to store parameters on.\n",
    "        `\"ensemble_shape\"` : `tuple[int]`\n",
    "            The shape of the ensemble of affine transformations\n",
    "            the model represents.\n",
    "    in_channels : `int`\n",
    "        The number of input channels.\n",
    "    kernel_shape : `tuple[int]`\n",
    "        The kernel shape.\n",
    "    out_channels : `int`\n",
    "        The number of output channels.\n",
    "    bias : `bool`, optional\n",
    "        Whether to include bias along the output channels.\n",
    "    dilation : `int | tuple[int]`, optional\n",
    "        The spacing between kernel elements, in all directions,\n",
    "        or per direction. Default: `1`.\n",
    "    init_multiplier : `float`, optional\n",
    "        We initialize linear maps with Glorot normal initialization,\n",
    "        that is using the centered normal distribution\n",
    "        with standard deviation `out_channels ** -.5` times this value.\n",
    "        Default: `1.`.\n",
    "    padding : `int | str | tuple[int]`, optional\n",
    "        The stride in all directions or per direction.\n",
    "        Alternatively, `\"valid\"` is the same as `0`,\n",
    "        and `\"same\"` pads the input so the output has the same shape\n",
    "        as the input.\n",
    "        Default: `0`.\n",
    "    stride : `int | tuple[int]`, optional\n",
    "        The stride in all directions or per direction.\n",
    "        Default: `1`.\n",
    "\n",
    "    Calling\n",
    "    -------\n",
    "    Instance calls require one positional argument:\n",
    "    batch : `dict`\n",
    "        The input data dictionary. Required key:\n",
    "        `\"features\"` : `torch.Tensor`\n",
    "            Tensor of features, of shape\n",
    "            `batch_shape + (in_channels, height, width)` or\n",
    "            `ensemble_shape + batch_shape + (in_channels, height, width)`\n",
    "    \"\"\"\n",
    "    def __init__(\n",
    "        self,\n",
    "        config: dict,\n",
    "        in_channels: int,\n",
    "        kernel_shape: tuple[int],\n",
    "        out_channels: int,\n",
    "        bias=True,\n",
    "        dilation=1,\n",
    "        init_multiplier=1.,\n",
    "        padding=0,\n",
    "        stride=1\n",
    "    ):\n",
    "        raise NotImplementedError\n",
    "\n",
    "\n",
    "    def forward(self, batch: dict) -> torch.Tensor:\n",
    "        raise NotImplementedError\n",
    "    \n",
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3990dc7f",
   "metadata": {},
   "source": [
    "### `Pool2D`\n",
    "\n",
    "Let's now write a mean pool operation! It can be created similarly to `Conv2D`, this time using `F.avg_pool2d` in case we are pooling along the kernel displacements.\n",
    "\n",
    "When finished test output shapes giving either\n",
    "1. `kernel_shape` or\n",
    "2. `sequence_dim_num`."
   ]
  },
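  {
   "cell_type": "markdown",
   "id": "1f47c9d2",
   "metadata": {},
   "source": [
    "The two pooling modes can be sketched on a toy tensor as follows. This is a hedged illustration rather than the layer itself: the tensor sizes, the kernel size and stride of 2, and the variable names are assumptions for the example.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# Shape: (ensemble, batch, channels, height, width).\n",
    "features = torch.arange(\n",
    "    2 * 3 * 1 * 4 * 4, dtype=torch.float32\n",
    ").reshape(2, 3, 1, 4, 4)\n",
    "\n",
    "# Kernel pooling: F.avg_pool2d accepts (N, C, H, W) inputs,\n",
    "# so fold all leading dimensions into one and unfold afterwards.\n",
    "lead_shape = tuple(features.shape[:-3])\n",
    "pooled = F.avg_pool2d(features.flatten(0, -4), kernel_size=2, stride=2)\n",
    "pooled = pooled.reshape(lead_shape + tuple(pooled.shape[-3:]))\n",
    "\n",
    "# Total pooling: average over both spatial dimensions.\n",
    "total = features.mean(dim=(-2, -1))\n",
    "```\n",
    "\n",
    "Pooling has no parameters, so ensemble and batch dimensions can be treated alike. After total pooling the channel dimension is the last one, which is convenient for a following MLP."
   ]
  },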
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0669ed9",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Pool2D(torch.nn.Module):\n",
    "    \"\"\"\n",
    "    Ensemble-ready two-dimensional mean pool operation\n",
    "\n",
    "    Arguments\n",
    "    ---------\n",
    "    config : `dict`\n",
    "        Configuration dictionary. Required key-value pairs:\n",
    "        `\"device\"` : `str`\n",
    "            The device to store parameters on.\n",
    "        `\"ensemble_shape\"` : `tuple[int]`\n",
    "            The shape of the ensemble of affine transformations\n",
    "            the model represents.\n",
    "    kernel_shape : `int | tuple[int]`, optional\n",
    "        The kernel shape.\n",
    "        If given, we pool along the kernel displacements.\n",
    "        Otherwise, we pool along all the two sequential dimensions.\n",
    "    padding : `int | tuple[int]`, optional\n",
    "        The padding in all directions or per direction.\n",
    "        It is used if `kernel_shape` is given.\n",
    "        Default: `0`.\n",
    "    stride : `int | tuple[int]`, optional\n",
    "        The stride in all directions or per direction.\n",
    "        It is used if `kernel_shape` is given.\n",
    "        Default: `1`.\n",
    "\n",
    "    Calling\n",
    "    -------\n",
    "    Instance calls require one positional argument:\n",
    "    batch : `dict`\n",
    "        The input data dictionary. Required key:\n",
    "        `\"features\"` : `torch.Tensor`\n",
    "            Tensor of features, of shape\n",
    "            `batch_shape + (in_channels, height, width)` or\n",
    "            `ensemble_shape + batch_shape + (in_channels, height, width)`\n",
    "    \"\"\"\n",
    "    def __init__(\n",
    "        self,\n",
    "        config: dict,\n",
    "        kernel_shape: Optional[tuple[int]] = None,\n",
    "        padding=0,\n",
    "        stride=1\n",
    "    ):\n",
    "        raise NotImplementedError\n",
    "\n",
    "\n",
    "    def forward(self, batch: dict) -> dict:        \n",
    "        raise NotImplementedError\n",
    "    \n",
    "\n",
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0d730a59",
   "metadata": {},
   "source": [
    "### `Linear`\n",
    "\n",
    "So that we can write a full CNN using `torch.nn.Sequential`, take the class `Linear` that you wrote in Notebook 0319 and rewrite it so that\n",
    "1. it takes in and outputs dictionaries and\n",
    "2. you can tell it that the feature dimension is not the last one.\n",
    "\n",
    "Make one and try it out on a minibatch."
   ]
  },
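  {
   "cell_type": "markdown",
   "id": "b5e60a48",
   "metadata": {},
   "source": [
    "One possible way to handle a feature dimension that is not the last one is to move it to the end, apply the linear map, and move it back. The sketch below shows only this idea, without ensembles or bias. The sizes and variable names are assumptions, and other approaches such as `torch.einsum` work as well.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "features = torch.randn(5, 3, 8, 8)  # batch, channels, height, width\n",
    "weight = torch.randn(3, 4)          # in_features, out_features\n",
    "feature_dim_index = -3              # the channel dimension\n",
    "\n",
    "# Move the feature dimension last, apply the map, then move it back.\n",
    "out = (\n",
    "    features.movedim(feature_dim_index, -1) @ weight\n",
    ").movedim(-1, feature_dim_index)\n",
    "\n",
    "# Reference computation with explicit index names.\n",
    "reference = torch.einsum('bchw,co->bohw', features, weight)\n",
    "```\n",
    "\n",
    "`movedim` returns a view, so the only new tensor of full size is the result of the matrix product."
   ]
  },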
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a66c072",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Linear(torch.nn.Module):\n",
    "    \"\"\"\n",
    "    Ensemble-ready affine transformation `y = x^T W + b`.\n",
    "\n",
    "    Arguments\n",
    "    ---------\n",
    "    config : `dict`\n",
    "        Configuration dictionary. Required key-value pairs:\n",
    "        `\"device\"` : `str`\n",
    "            The device to store parameters on.\n",
    "        `\"ensemble_shape\"` : `tuple[int]`\n",
    "            The shape of the ensemble of affine transformations\n",
    "            the model represents.\n",
    "    in_features : `int`\n",
    "        The number of input features\n",
    "    out_features : `int`\n",
    "        The number of output features.\n",
    "    bias : `bool`, optional\n",
    "        Whether the model should include bias. Default: `True`.\n",
    "    feature_dim_index: `int`, optional\n",
    "        The index of the feature dimension. Default: `-1`,\n",
    "    init_multiplier : `float`, optional\n",
    "        The weight parameter values are initialized following\n",
    "        a normal distribution with center 0 and std\n",
    "        `in_features ** -.5` times this value. Default: `1.`\n",
    "\n",
    "    Calling\n",
    "    -------\n",
    "    Instance calls require one positional argument:\n",
    "    batch : `dict`\n",
    "        The input data dictionary. Required key:\n",
    "        `\"features\"` : `torch.Tensor`\n",
    "            Tensor of features. The feature dimension is determined by\n",
    "            `feature_dim_index`.\n",
    "    \"\"\"\n",
    "    def __init__(\n",
    "        self,\n",
    "        config: dict,\n",
    "        in_features: int,\n",
    "        out_features: int,\n",
    "        bias=True,\n",
    "        feature_dim_index=-1,\n",
    "        init_multiplier=1.\n",
    "    ):\n",
    "        raise NotImplementedError\n",
    "\n",
    "\n",
    "    def forward(\n",
    "        self,\n",
    "        batch: dict\n",
    "    ) -> dict:\n",
    "        raise NotImplementedError\n",
    "    \n",
    "\n",
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "52b13a91",
   "metadata": {},
   "source": [
    "### `DictReLU`\n",
    "\n",
    "Finally, create a `DictReLU` layer, that applies ReLU to the tensor at the `\"features\"` key of its input.\n",
    "\n",
    "Test it on a minibatch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "222dbe0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "class DictReLU(torch.nn.Module):\n",
    "    \"\"\"\n",
    "    Applies ReLU elementwise to the feature tensor in a dictionary.\n",
    "\n",
    "    Calling\n",
    "    -------\n",
    "    Instance calls require one positional argument:\n",
    "    batch : `dict`\n",
    "        The input data dictionary. ReLU is applied to tensor\n",
    "        at the `\"features\"` key.\n",
    "    \"\"\"\n",
    "    def forward(self, batch: dict) -> dict:\n",
    "\n",
    "        raise NotImplementedError\n",
    "    \n",
    "\n",
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b88fc528",
   "metadata": {},
   "source": [
    "### Make a CNN\n",
    "\n",
    "Time to make a full CNN! There's a lot of create freedom here. The general idea is to alternate the following:\n",
    "1. Process local data by one or more `Conv2D` layers, with ReLUs in between.\n",
    "2. Aggregate data by a `Pool2D` layer of stride > 1.\n",
    "\n",
    "Rules of thumb:\n",
    "1. As you aggregate data you can have more channels.\n",
    "2. Keep kernel sizes small odd numbers larger than 1.\n",
    "3. The last `Pool2D` layer should be a total pooling one, with `kernel_shape` not given. It should be followed by a small MLP.\n",
    "\n",
    "Make a model, then check the shape of its output on a minibatch. The output should have shape appropriate for logits."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd763da9",
   "metadata": {},
   "outputs": [],
   "source": [
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26652068",
   "metadata": {},
   "source": [
    "## Train a CNN\n",
    "\n",
    "Time to train this thing! You can use `train_supervised`, that you wrote in Notebook 0416. Note that that function expects the model to output prediction tensors, not dictionaries. So, compose the `model` with a function that, given a dictionary, outputs its value at key `\"features\"`. "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "08a5355d",
   "metadata": {},
   "source": [
    "Adapt the function `get_output_by_batches`, that you wrote in Notebook 0416, to the fact that now `model` outputs a dictionary that has the predict tensors at its `\"features\"` key."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d76a54f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_output_by_batches(\n",
    "    config: dict,\n",
    "    dataset: dict,\n",
    "    model: torch.nn.Module,\n",
    "    out_features: int,\n",
    "    indptr_key=\"indptr\"\n",
    ") -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    Get the output of a model in a single tensor for a full dataset,\n",
    "    but collected via evaluation by minibatches.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    config : `dict`\n",
    "        Configuration dictionary. Required key-value pair:\n",
    "        `\"minibatch_size_eval\"` : `int`\n",
    "            Size of consecutive minibatches to take from the dataset.\n",
    "            To be set according to RAM or GPU memory capacity.\n",
    "    dataset : `dict`\n",
    "        The dataset to evaluate the model on.\n",
    "    model : `torch.nn.Module`\n",
    "        The model to evaluate.\n",
    "    out_features : `int`  \n",
    "        The number of output features of the model.\n",
    "    indptr_key : `str`, optional\n",
    "        If the dataset has sequential entries,\n",
    "        then this is the key of the index pointer tensor.\n",
    "        Default: `\"indptr\"`.\n",
    "\n",
    "    \"\"\"\n",
    "    raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "abe82c12",
   "metadata": {},
   "source": [
    "Make the following small change to `evaluate_model`, that you wrote in Notebook 0416: give it an optional keyword argument `out_features`. With this, we can override the number of output features in predict tensors, that is the last dimension of the target tensor in the dataset by default. We need it as in case of classification, the target tensor is a label tensor, and the output features should be the number of possible labels, not the number entries in the label tensor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab659f23",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_model(\n",
    "    config: dict,\n",
    "    dataset: dict,\n",
    "    get_metric: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],\n",
    "    model: torch.nn.Module,\n",
    "    indptr_key=\"indptr\",\n",
    "    out_features: Optional[int] = None,\n",
    "    target_key=\"target\"\n",
    ") -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    Evaluate a model on a supervised dataset.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    config : `dict`\n",
    "        Configuration dictionary. Required key-value pair:\n",
    "        `\"minibatch_size_eval\"` : `int`\n",
    "            Size of consecutive minibatches to take from the dataset.\n",
    "            To be set according to RAM or GPU memory capacity.\n",
    "    dataset : `dict`\n",
    "        The dataset to evaluate the model on.\n",
    "    get_metric : `Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`\n",
    "        Function to get the metric from a pair of\n",
    "        predicted and target value tensors.\n",
    "    model : `torch.nn.Module`\n",
    "        The model to evaluate.\n",
    "    indptr_key : `str`, optional\n",
    "        If the dataset has sequential entries,\n",
    "        then this is the key of the index pointer tensor.\n",
    "        Default: `\"indptr\"`.\n",
    "    out_features: `int`, optional\n",
    "        The number of output features in the predict tensors.\n",
    "        By default, it is the last dimension of the target tensor.\n",
    "    target_key : `str`, optional\n",
    "        The key mapped to the target value tensor in the dataset.\n",
    "        Default: `\"target\"`\n",
    "    \"\"\"\n",
    "    raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "68ebd3d0",
   "metadata": {},
   "source": [
    "Make the following changes to `train_supervised`, that you wrote in Notebook 0423:\n",
    "1. Extend this also with an `out_features` keyword argument so that, in a classification task, you can give it the number of classes manually.\n",
    "2. To save time, make it only calculate the validation metric at evaluation.\n",
    "3. Adapt it to the fact that now the model outputs a dictionary that has the predict tensor at its `\"features\"` key."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff8ebcc3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_supervised(\n",
    "    config: dict,\n",
    "    dataset_train: dict,\n",
    "    dataset_valid: dict,\n",
    "    get_loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],\n",
    "    get_metric: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],\n",
    "    model: torch.nn.Module,\n",
    "    optimizer: Optimizer,\n",
    "    out_features: Optional[int] = None,\n",
    "    target_key=\"target\"\n",
    ") -> dict:\n",
    "    \"\"\"\n",
    "    Population-based training on a supervised learning task.\n",
    "    Tuned hyperparameters are given by raw values and transformations.\n",
    "    This way, the hyperparameters are perturbed by\n",
    "    additive noise on raw values.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    config : `dict`\n",
    "        Configuration dictionary. Required key-value pairs:\n",
    "        `\"ensemble_shape\"` : tuple[int]\n",
    "            Ensemble shape. We assume this is a 1-dimensional tuple\n",
    "            with dimensions the population size.\n",
    "        `\"hyperparameter_raw_init_distributions\"` : `dict`\n",
    "            Dictionary that maps tuned hyperparameter names\n",
    "            to `torch.distributions.Distribution` of raw hyperparameter values.\n",
    "            Required keys:\n",
    "            `\"learning_rate\"`:\n",
    "                The learning rate of stochastic gradient descent.\n",
    "        `\"hyperparameter_raw_perturbs\"` : `dict`\n",
    "            Dictionary that maps tuned hyperparameter names\n",
    "            to `torch.distributions.Distribution` of additive noise.\n",
    "        `\"hyperparameter_transforms\"` : `dict`\n",
    "            Dictionary that maps tuned hyperparameter names\n",
    "            to transformations of raw hyperparameter values.\n",
    "        `\"improvement_threshold\"` : `float`\n",
    "            A new metric score has to be this much better\n",
    "            than the previous best to count as an improvement.\n",
    "        `\"minibatch_size\"` : `int`\n",
    "            Minibatch size to use in a training step.\n",
    "        `\"minibatch_size_eval\"` : `int`\n",
    "            Minibatch size to use in evaluation.\n",
    "            On CPU, should be about the same as `minibatch_size`.\n",
    "            On GPU, should be as big as possible without\n",
    "            incurring an Out of Memory error.\n",
    "        `\"pbt\"` : `bool`\n",
    "            Whether to use PBT updates in validations.\n",
    "            If `False`, the algorithm just samples hyperparameters at start,\n",
    "            then keeps them constant.\n",
    "        `\"steps_num\"` : `int`\n",
    "            Maximum number of training steps.\n",
    "        `\"steps_without_improvement`\" : `int`\n",
    "            If the number of training steps without improvement\n",
    "            exceeds this value, then training is stopped.\n",
    "        `\"valid_interval\"` : `int`\n",
    "            Frequency of evaluations, measured in number of training steps.\n",
    "        `\"welch_confidence_level\"` : `float`\n",
    "            The confidence level in Welch's t-test\n",
    "            that is used in determining if a population member\n",
    "            is to be replaced by another member with perturbed hyperparameters.\n",
    "        `\"welch_sample_size\"` : `int`\n",
    "            The last this many validation metrics are used\n",
    "            in Welch's t-test.\n",
    "    dataset_train : `dict`\n",
    "        The dataset to train the model on.\n",
    "    dataset_valid : `dict`\n",
    "        The dataset to evaluate the model on.\n",
    "    `get_loss` : `Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`\n",
    "        A function that maps a pair of predicted and target value tensors\n",
    "        to a tensor of losses per ensemble member.\n",
    "    `get_metric` : `Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`\n",
    "        A function that maps a pair of predicted and target value tensors\n",
    "        to a tensor of metrics per ensemble member.\n",
    "        We assume a greater metric is better.\n",
    "    `model` : `torch.nn.Module`\n",
    "        The model ensemble to tune.\n",
    "    `optimizer` : `Optimizer`\n",
    "        An optimizer that tracks the parameters of `model`.\n",
    "    indptr_key : `str`, optional\n",
    "        If the dataset has sequential entries,\n",
    "        then this is the key of the index pointer tensor.\n",
    "        Default: `\"indptr\"`.\n",
    "    out_features: `int`, optional\n",
    "        The number of output features in the predict tensors.\n",
    "        By default, it is the last dimension of the target tensor.\n",
    "    target_key : `str`, optional\n",
    "        The key mapped to the target value tensor in the dataset.\n",
    "        Default: `\"target\"`\n",
    "        \n",
    "    Returns\n",
    "    -------\n",
    "    An output dictionary with the following key-value pairs:\n",
    "        `\"best parameters\"` : `dict`  \n",
    "            The state dictionary of the model with the best metric\n",
    "            encountered during training.\n",
    "        `\"source mask\"` : `torch.Tensor`\n",
    "            The source masks of population members\n",
    "            that were replace by other members in a PBT update\n",
    "        `\"target indices\"` : `torch.Tensor`\n",
    "            The indices of population members\n",
    "            that the member where the source mask is to were replaced with.\n",
    "        `\"validation metric\"` : `torch.Tensor`\n",
    "            The validation metrics at evaluation steps.\n",
    "\n",
    "        In addition, for each tuned hyperparameter name,\n",
    "        we include a `torch.Tensor` of values per update.\n",
    "    \"\"\"\n",
    "    raise NotImplementedError\n",
    "\n",
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "97036bf4",
   "metadata": {},
   "source": [
    "## Dataset Reference\n",
    "\n",
    "[1] Alex Krizhevsky: *Learning Multiple Layers of Features from Tiny Images*. 2009. https://www.cs.toronto.edu/~kriz/cifar.html"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1cf38c4",
   "metadata": {},
   "source": [
    "# License\n",
    "\n",
    "This work is licensed under Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International. To view a copy of this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bea5abbb",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dml",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
