{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "23aa1f0d",
   "metadata": {},
   "source": [
    "# Approximating Convex Hull Volume with Deep Sets\n",
    "\n",
    "## Setup\n",
    "\n",
    "### Imports\n",
    "\n",
    "Import `defaultdict`, `Callable`, `Generator`, `Iterable`, `datasets`, `math`, `matplotlib.pyplot` as `plt`, `os`, `torch`, `torch.nn.functional` as `F`, `tqdm` and `Optional`.\n",
    "\n",
    "Moreover, import the class `ConvexHull` from `scipy.spatial`. We'll use it to generate the targets, which are the volumes of the convex hulls of point clouds.\n",
    "\n",
    "Finally, import:\n",
    "1. the function `get_random_reshuffler` that you wrote in Notebook 0221,\n",
    "2. the functions `get_mlp` and `get_mse` that you wrote in Notebook 0319,\n",
    "3. the classes `AdamW` and `Optimizer` that you wrote in Notebook 0326,\n",
    "4. the functions `pbt_init` and `pbt_update` that you wrote in Notebook 0328,\n",
    "5. the function `is_ensembled` that you wrote in Notebook 0402,\n",
    "6. the function `get_minibatch` that you wrote in Notebook 0402, imported as `get_array_minibatch`, and\n",
    "7. the functions `get_array_sequence_keys`, `get_dataset_size`, and `to_ensembled` that you wrote in Notebook 0409."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ede705e",
   "metadata": {},
   "outputs": [],
   "source": [
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "06070a5a",
   "metadata": {},
   "source": [
    "### Configuration\n",
    "\n",
    "Create a configuration dictionary with the following keys:\n",
    "- `\"device\"`: `torch.device | int | str`  \n",
    "    The device identifier.\n",
    "- `\"ensemble_shape\"`: `tuple[int]`  \n",
    "    Today, we'll use PBT with population size 8. Thus, make this `(8,)`.\n",
    "- `\"hyperparameter_raw_init_distributions\"`, `\"hyperparameter_raw_perturb\"`, `\"hyperparameter_transforms\"` : `dict`  \n",
    "    These three dictionaries are going to determine how the hyperparameters are tuned. We'll tune the following hyperparameters:\n",
    "    1. Epsilon $\\epsilon$.\n",
    "    2. Learning rate $\\eta$.\n",
    "    3. Weight decay $\\lambda$.\n",
    "    4. First moment moving average decay rate $\\beta_1$.\n",
    "    5. Second moment moving average decay rate $\\beta_2$.\n",
    "\n",
    "    We don't know the appropriate order of magnitude of the first three. Thus it may be good to give them distributions of the form $10^{\\mathscr D}$, where $\\mathscr D$ is a normal or uniform distribution. You can try to center $\\mathscr D$ at the base-10 logarithms of the recommended values.\n",
    "\n",
    "    We know that the recommended values of the fourth and fifth are $0.9=1-10^{-1}$ and $0.999=1-10^{-3}$. So it may be best to give them distributions of the form $1-10^{\\mathscr D}$, with $\\mathscr D$ centered at $-1$ and $-3$, respectively.\n",
    "- `\"improvement_threshold\"`: `float`  \n",
    "    Make this `1e-4`.\n",
    "- `\"minibatch_size\"`: `int`  \n",
    "    Make this `64`.\n",
    "- `\"minibatch_size_eval\"`: `int`  \n",
    "    1. If you're using CPU, you can make this about the same as the training minibatch size.\n",
    "    2. If you're using GPU, it is best to find the largest value that does not cause an out-of-memory error. You can experiment with powers of 2, which the left bit shift operator `<<` makes convenient to write. For example, on my home computer, I could use an evaluation minibatch size of $2^{11}$, that is `1 << 11`.\n",
    "- `\"pbt\"` : `bool`  \n",
    "    Let's add a switch to turn off PBT. You can use it to test whether the algorithm does any good, or whether optimization with the initial hyperparameters works just as well. Make this `True`.\n",
    "- `\"seed\"`: `int`  \n",
    "    This is for reproducible experiments. Insert any integer.\n",
    "- `\"steps_num\"`: `int`  \n",
    "    Make this `100_001`.\n",
    "- `\"steps_without_improvement\"`: `int`  \n",
    "    Make this `10_000`.\n",
    "- `\"valid_interval\"`: `int`  \n",
    "    Make this `1000`.\n",
    "- `\"welch_confidence_level\"`: `float`  \n",
    "    We will exploit based on a one-sided Welch $t$-test with this confidence level. Based on my experiments in the setting of Homework 9, maybe you can try `.8`. Feel free to try out various values here!\n",
    "- `\"welch_sample_size\"`: `int`  \n",
    "    We will exploit based on a one-sided Welch $t$-test on the last `welch_sample_size` validation metrics of the population members. To follow the PBT paper, make this `10`."
   ]
  },
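  {
   "cell_type": "markdown",
   "id": "7c2e91a4",
   "metadata": {},
   "source": [
    "Below is a minimal sketch of such a configuration dictionary. It is one possible choice, not the intended answer.\n",
    "- **Hyperparameter keys.** It assumes the tuned hyperparameters are keyed as `\"epsilon\"`, `\"learning_rate\"`, `\"weight_decay\"`, `\"beta_1\"` and `\"beta_2\"`. Only `\"learning_rate\"` is confirmed by the `train_supervised` docstring. The other names are hypothetical and must match what your `AdamW` expects.\n",
    "- **Initial distributions.** Raw values are drawn from uniform distributions centered at the base-10 logarithms of the `torch.optim.AdamW` defaults: $\\epsilon=10^{-8}$, $\\eta=10^{-3}$, $\\lambda=10^{-2}$, $1-\\beta_1=10^{-1}$ and $1-\\beta_2=10^{-3}$. The half-widths and the perturbation scale `0.1` are guesses.\n",
    "- **Perturbation key.** The key follows the list above. The `train_supervised` docstring spells it `\"hyperparameter_raw_perturbs\"`, so use whichever spelling your `pbt_init` and `pbt_update` expect.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch.distributions import Normal, Uniform\n",
    "\n",
    "# (center, half-width) of the raw uniform init distributions.\n",
    "# Key names other than \"learning_rate\" are hypothetical.\n",
    "raw_ranges = {\n",
    "    \"epsilon\": (-8., 1.),        # epsilon = 10 ** raw\n",
    "    \"learning_rate\": (-3., 1.),  # eta = 10 ** raw\n",
    "    \"weight_decay\": (-2., 1.),   # lambda = 10 ** raw\n",
    "    \"beta_1\": (-1., .5),         # beta_1 = 1 - 10 ** raw\n",
    "    \"beta_2\": (-3., .5),         # beta_2 = 1 - 10 ** raw\n",
    "}\n",
    "\n",
    "\n",
    "def power_of_ten(raw):\n",
    "    return 10 ** raw\n",
    "\n",
    "\n",
    "def one_minus_power_of_ten(raw):\n",
    "    return 1 - 10 ** raw\n",
    "\n",
    "\n",
    "config = {\n",
    "    \"device\": \"cuda\" if torch.cuda.is_available() else \"cpu\",\n",
    "    \"ensemble_shape\": (8,),\n",
    "    \"hyperparameter_raw_init_distributions\": {\n",
    "        name: Uniform(center - half_width, center + half_width)\n",
    "        for name, (center, half_width) in raw_ranges.items()\n",
    "    },\n",
    "    # Additive Gaussian noise on raw values; the scale is a guess.\n",
    "    \"hyperparameter_raw_perturb\": {\n",
    "        name: Normal(0., .1) for name in raw_ranges\n",
    "    },\n",
    "    \"hyperparameter_transforms\": {\n",
    "        \"epsilon\": power_of_ten,\n",
    "        \"learning_rate\": power_of_ten,\n",
    "        \"weight_decay\": power_of_ten,\n",
    "        \"beta_1\": one_minus_power_of_ten,\n",
    "        \"beta_2\": one_minus_power_of_ten,\n",
    "    },\n",
    "    \"improvement_threshold\": 1e-4,\n",
    "    \"minibatch_size\": 64,\n",
    "    \"minibatch_size_eval\": 1 << 11,  # GPU-sized guess; use ~64 on CPU\n",
    "    \"pbt\": True,\n",
    "    \"seed\": 0,\n",
    "    \"steps_num\": 100_001,\n",
    "    \"steps_without_improvement\": 10_000,\n",
    "    \"valid_interval\": 1000,\n",
    "    \"welch_confidence_level\": .8,\n",
    "    \"welch_sample_size\": 10,\n",
    "}\n",
    "```"
   ]
  },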
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afd8d9bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "14f6b767",
   "metadata": {},
   "source": [
    "Set the `torch` pseudo-random number generation seed as per the configuration dictionary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64eb21cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd931038",
   "metadata": {},
   "source": [
    "## Dataset Generation\n",
    "\n",
    "We will generate a supervised dataset as follows:\n",
    "1. The inputs are finite sequences $X\\in L\\mathbf R^d$ of $d$-dimensional vectors.\n",
    "2. The targets are the volumes, or areas in case $d=2$, of the convex hulls of the point clouds determined by the sequences $X$.\n",
    "\n",
    "We will refer to $d$ as the *ambient dimension* of the point sets."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d3c14c44",
   "metadata": {},
   "source": [
    "### Sampling point clouds\n",
    "\n",
    "First of all, sample a 20-point, 2-dimensional set using a standard normal distribution. After setting the aspect ratio to `\"equal\"`, make a scatter plot of your point set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d93f638",
   "metadata": {},
   "outputs": [],
   "source": [
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e167e208",
   "metadata": {},
   "source": [
    "### Getting Convex Hulls\n",
    "\n",
    "Now you can call [`scipy.spatial.ConvexHull`](https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.ConvexHull.html) on the point set to get a `ConvexHull` object.\n",
    "\n",
    "Its `simplices` attribute is an index matrix of shape `(facet_num, ambient_dim)`, where each row gives the vertex indices of a *facet* of the convex hull, which is a $(d-1)$-simplex. In particular, the facets are line segments in case $d=2$.\n",
    "\n",
    "Make the same aspect ratio adjustment and scatter plot. Afterwards, iterate through the rows of this index array and make line plots of the facets. Then show the plot.\n",
    "\n",
    "Finally, you can access the volume (or area, in case $d=2$) of the convex hull at the `volume` attribute. Print it out. Beware that in case $d=2$, the `area` attribute holds the perimeter, not the area."
   ]
  },
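  {
   "cell_type": "markdown",
   "id": "b83f0d16",
   "metadata": {},
   "source": [
    "One way to carry out these steps is sketched below. The helper name `plot_convex_hull` is hypothetical. It assumes a CPU tensor of shape `(point_num, 2)` and returns the `ConvexHull` object, so that its `volume` can be read off afterwards.\n",
    "\n",
    "```python\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "from scipy.spatial import ConvexHull\n",
    "\n",
    "\n",
    "def plot_convex_hull(points: torch.Tensor, ax=None) -> ConvexHull:\n",
    "    \"\"\"Scatter-plot a 2D point cloud and draw the facets of its convex hull.\"\"\"\n",
    "    if ax is None:\n",
    "        ax = plt.gca()\n",
    "    xy = points.numpy()\n",
    "    hull = ConvexHull(xy)\n",
    "    ax.set_aspect(\"equal\")\n",
    "    ax.scatter(xy[:, 0], xy[:, 1])\n",
    "    # Each row of hull.simplices holds the two vertex indices\n",
    "    # of a facet, which is a line segment when d = 2.\n",
    "    for simplex in hull.simplices:\n",
    "        ax.plot(xy[simplex, 0], xy[simplex, 1], color=\"black\")\n",
    "    return hull\n",
    "\n",
    "\n",
    "hull = plot_convex_hull(torch.randn(20, 2))\n",
    "plt.show()\n",
    "print(hull.volume)  # for d = 2, this is the area of the hull\n",
    "```"
   ]
  },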
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "541b1afc",
   "metadata": {},
   "outputs": [],
   "source": [
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1149416",
   "metadata": {},
   "source": [
    "### Generate the Dataset with Sparse Storage\n",
    "\n",
    "Time to generate full datasets! As discussed in Notebook 0409, to store variable-length sequential data efficiently, we\n",
    "1. store all the point cloud vertices in one matrix `vertices` of shape `(cumulated_size, ambient_dim)`, and\n",
    "2. store the point cloud boundary indices in another vector `indptr` of shape `(dataset_size + 1,)` such that the vertices of the `i`th point cloud are `vertices[indptr[i]:indptr[i+1]]`.\n",
    "\n",
    "To generate such a dataset, you can proceed as follows:\n",
    "1. You need to generate `dataset_size` dataset entry lengths between `subset_size_min` and `subset_size_max`. You can do this with `torch.randint`. Note that its upper bound is exclusive, so pass `subset_size_max + 1` to include `subset_size_max`.\n",
    "2. Given the dataset entry lengths, you can use `torch.cumsum` to create the tensor `indptr`. Since the length of the `i`th entry should be `indptr[i+1] - indptr[i]`, the cumulative sums go into `indptr[1:]`, while `indptr[0]` is `0`.\n",
    "3. In particular, the last entry of `indptr` is the sum of the lengths of all the dataset entries. Therefore, you can take it as `cumulated_size` when sampling the cumulated vertex tensor `vertices` with `torch.normal`.\n",
    "4. Finally, collect the convex hull volumes of the point clouds `vertices[indptr[i]:indptr[i+1]]` using `ConvexHull.volume`, as above. Note that the volume tensor should be unsqueezed to have shape `(dataset_size, 1)`.\n",
    "\n",
    "Once you're done:\n",
    "1. generate a dataset with 5 entries,\n",
    "2. iterate through the entries,\n",
    "3. plot the vertices with convex hulls as above,\n",
    "4. recalculate the volumes, and\n",
    "5. check if the recalculated volumes agree with the ones in the dataset."
   ]
  },
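  {
   "cell_type": "markdown",
   "id": "e5a9c2f7",
   "metadata": {},
   "source": [
    "The four steps above can be sketched as follows. This is one possible implementation of the docstring below, not the reference solution. The `_sketch` suffix marks the name as hypothetical. It samples on the CPU, since SciPy needs NumPy arrays anyway, and moves the result to `device` at the end.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from scipy.spatial import ConvexHull\n",
    "\n",
    "\n",
    "def generate_convex_hull_dataset_sketch(\n",
    "    ambient_dim: int,\n",
    "    dataset_size: int,\n",
    "    device,\n",
    "    subset_size_max: int,\n",
    "    subset_size_min: int,\n",
    "    std=1.,\n",
    ") -> dict:\n",
    "    # 1. Entry lengths, uniform on [subset_size_min, subset_size_max];\n",
    "    #    the upper bound of torch.randint is exclusive, hence the + 1.\n",
    "    sizes = torch.randint(subset_size_min, subset_size_max + 1, (dataset_size,))\n",
    "    # 2. indptr starts at 0, so that indptr[i + 1] - indptr[i] == sizes[i].\n",
    "    indptr = torch.zeros(dataset_size + 1, dtype=torch.int64)\n",
    "    indptr[1:] = sizes.cumsum(dim=0)\n",
    "    # 3. The last entry of indptr is the cumulated size.\n",
    "    vertices = torch.normal(0., std, size=(int(indptr[-1]), ambient_dim))\n",
    "    # 4. Volumes (areas if ambient_dim == 2), unsqueezed to (dataset_size, 1).\n",
    "    volume = torch.tensor([\n",
    "        ConvexHull(vertices[indptr[i]:indptr[i + 1]].numpy()).volume\n",
    "        for i in range(dataset_size)\n",
    "    ], dtype=vertices.dtype)[:, None]\n",
    "    return {\n",
    "        \"indptr\": indptr.to(device),\n",
    "        \"vertices\": vertices.to(device),\n",
    "        \"volume\": volume.to(device),\n",
    "    }\n",
    "```\n",
    "\n",
    "For example, `generate_convex_hull_dataset_sketch(ambient_dim=2, dataset_size=5, device=\"cpu\", subset_size_max=100, subset_size_min=10)` returns a 5-entry dataset. `ConvexHull` needs at least `ambient_dim + 1` points that do not all lie in a common hyperplane, so `subset_size_min` should be at least `ambient_dim + 1`."
   ]
  },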
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2119530b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_convex_hull_dataset(\n",
    "    ambient_dim: int,\n",
    "    dataset_size: int,\n",
    "    device: int | str | torch.device,\n",
    "    subset_size_max: int,\n",
    "    subset_size_min: int,\n",
    "    std=1.\n",
    ") -> dict:\n",
    "    \"\"\"\n",
    "    Generate a supervised dataset mapping point clouds\n",
    "    to the volumes of their convex hulls.\n",
    "\n",
    "    The number of points in a point cloud\n",
    "    is sampled with uniform distribution from the closed interval\n",
    "    `[subset_size_min, subset_size_max]`\n",
    "    and the coordinates of the points are sampled from\n",
    "    the normal distribution with center `0.` and std `std`.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    ambient_dim : `int`  \n",
    "        The dimension of the Euclidean space\n",
    "        the point clouds are to be finite subsets of.\n",
    "    dataset_size : `int`  \n",
    "        The number of dataset entries to generate.\n",
    "    device : `int | str | torch.device`\n",
    "        Device to store the dataset on.\n",
    "    subset_size_max : `int`  \n",
    "        The maximum number of points in a point cloud.\n",
    "    subset_size_min : `int`  \n",
    "        The minimum number of points in a point cloud.\n",
    "    std : `float`, optional  \n",
    "        The std of the normal distribution\n",
    "        the point coordinates are sampled from.\n",
    "        Default: `1.`\n",
    "\n",
    "    Returns \n",
    "    -------\n",
    "    The dataset, in the form of a dictionary with\n",
    "    `torch.Tensor`-valued keys `\"indptr\"`, `\"vertices\"` and `\"volume\"`,\n",
    "    which store the dataset as follows:\n",
    "    The `i`-th dataset entry has vertices `vertices[indptr[i]:indptr[i+1]]`\n",
    "    and volume `volume[i]`.\n",
    "\n",
    "    For ease of use with supervised learning algorithms,\n",
    "    the tensor `volume` is unsqueezed and has shape `(dataset_size, 1)`.\n",
    "    \"\"\"\n",
    "    raise NotImplementedError\n",
    "\n",
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1813f98a",
   "metadata": {},
   "source": [
    "Generate training, validation and test datasets with `80_000`, `10_000`, and `10_000` entries, respectively. In each, point clouds should have between `10` and `100` points."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d402983",
   "metadata": {},
   "outputs": [],
   "source": [
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b3c4a27b",
   "metadata": {},
   "source": [
    "## Training\n",
    "\n",
    "### Fixed `get_minibatch`\n",
    "\n",
    "The function `get_minibatch` I put in Notebook 0409 had a mistake that gave an error when applied to a dataset with a sequential input from $L\\mathscr X_0$ where $\\mathscr X_0\\subseteq\\mathbf R^d$. I provide the fixed function below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "477a13e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_minibatch(\n",
    "    dataset: dict,\n",
    "    indices: torch.Tensor,\n",
    "    indptr_key=\"indptr\",\n",
    ") -> dict:\n",
    "    \"\"\"\n",
    "    Returns the minibatch of an ensembled dataset\n",
    "    given by ensembled indices.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    dataset: `dict`\n",
    "        A dataset given as a dictionary with `torch.Tensor` values.\n",
    "    indices: `torch.Tensor`\n",
    "        An index tensor for the dataset.\n",
    "        Each value in the dataset should have shape prefixed by\n",
    "        the shape of the index tensor.\n",
    "    indptr_key : `str`, optional\n",
    "        If the dataset has sequential entries,\n",
    "        then this is the key of the index pointer tensor. Default: `indptr`.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The minibatch given by the dataset and the index tensor.\n",
    "    The extra key `\"mask\"` is mapped to the mask tensor of\n",
    "    entries that are not padding entries.\n",
    "    \"\"\"\n",
    "    minibatch = {}\n",
    "\n",
    "    dense_keys, sequence_keys = get_array_sequence_keys(\n",
    "        dataset, indptr_key=indptr_key\n",
    "    )\n",
    "\n",
    "    if len(dense_keys) > 0:\n",
    "        minibatch.update({\n",
    "            key: dataset[key].gather(\n",
    "                len(indices.shape) - 1,\n",
    "                indices.reshape(\n",
    "                    indices.shape + (1,) * (\n",
    "                        len(dataset[key].shape) - len(indices.shape)\n",
    "                    )\n",
    "                ).expand(\n",
    "                    indices.shape + dataset[key].shape[len(indices.shape):]\n",
    "                )\n",
    "            )\n",
    "            for key in dense_keys\n",
    "        })\n",
    "\n",
    "    if len(sequence_keys) > 0:\n",
    "        indptr_left, indptr_right = (\n",
    "            dataset[indptr_key].gather(\n",
    "                -1,\n",
    "                i\n",
    "            )\n",
    "            for i in (indices, indices + 1)\n",
    "        )\n",
    "        \n",
    "        sizes = indptr_right - indptr_left\n",
    "        sizes_max = sizes.max()\n",
    "        \n",
    "        sequence_indices = (\n",
    "            indptr_left[..., None]\n",
    "          + torch.arange(sizes_max, device=indices.device)\n",
    "        )\n",
    "\n",
    "        mask: torch.Tensor = sequence_indices < indptr_right[..., None]\n",
    "        minibatch[\"mask\"] = mask\n",
    "        sequence_indices[~mask] = 0\n",
    "\n",
    "        minibatch_shape = indices.shape\n",
    "        for key in sequence_keys:\n",
    "            data_raw: torch.Tensor = dataset[key]\n",
    "\n",
    "            feature_dims = data_raw.shape[len(minibatch_shape):]\n",
    "\n",
    "            data = data_raw.gather(\n",
    "                len(minibatch_shape) - 1,\n",
    "                sequence_indices.reshape(\n",
    "                    minibatch_shape[:-1]\n",
    "                  + (minibatch_shape[-1] * sizes_max,)\n",
    "                  + (1,) * len(feature_dims)\n",
    "                ).expand(\n",
    "                    minibatch_shape[:-1]\n",
    "                  + (minibatch_shape[-1] * sizes_max,)\n",
    "                  + feature_dims\n",
    "                )\n",
    "            ).reshape(\n",
    "                minibatch_shape\n",
    "              + (sizes_max,)\n",
    "              + data_raw.shape[len(minibatch_shape):]\n",
    "            )\n",
    "            minibatch[key] = data\n",
    "\n",
    "    return minibatch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad70ad19",
   "metadata": {},
   "source": [
    "Since we modified `get_minibatch`, we need to redefine `get_dataloader_random_reshuffle`. You can copy the code you wrote for Notebook 0409 here."
   ]
  },
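  {
   "cell_type": "markdown",
   "id": "4d61b0e8",
   "metadata": {},
   "source": [
    "If you no longer have your Notebook 0409 code at hand, the index side of random reshuffling for an ensemble can be sketched as below. This hypothetical helper does not use `get_random_reshuffler` from Notebook 0221. Instead, it draws an independent permutation per ensemble member and epoch by sorting uniform noise, and it drops the last incomplete minibatch. Each yielded tensor has shape `ensemble_shape + (minibatch_size,)`, so it could be passed as `indices` to `get_minibatch` on an ensembled dataset.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from collections.abc import Generator\n",
    "\n",
    "\n",
    "def random_reshuffle_indices(\n",
    "    ensemble_shape: tuple,\n",
    "    dataset_size: int,\n",
    "    minibatch_size: int,\n",
    "    device=\"cpu\",\n",
    ") -> Generator[torch.Tensor, None, None]:\n",
    "    \"\"\"Yield ensembled minibatch index tensors indefinitely.\"\"\"\n",
    "    while True:\n",
    "        # Sorting i.i.d. uniform noise yields a uniformly random permutation,\n",
    "        # independently for every ensemble member.\n",
    "        permutations = torch.rand(\n",
    "            ensemble_shape + (dataset_size,), device=device\n",
    "        ).argsort(dim=-1)\n",
    "        # Drop the last incomplete minibatch so that all shapes stay fixed.\n",
    "        for start in range(0, dataset_size - minibatch_size + 1, minibatch_size):\n",
    "            yield permutations[..., start:start + minibatch_size]\n",
    "```"
   ]
  },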
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4a4b359",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_dataloader_random_reshuffle(\n",
    "    config: dict,\n",
    "    dataset: dict,\n",
    "    indptr_key=\"indptr\",\n",
    "    minibatch_size: Optional[int] = None\n",
    ") -> Generator[dict, None, None]:\n",
    "    \"\"\"\n",
    "    Given a dataset as a dictionary with tensor values,\n",
    "    creates a random reshuffling (without replacement) dataloader\n",
    "    that yields minibatch dictionaries indefinitely.\n",
    "    Supports arbitrary ensemble shapes and sequential data.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    config : `dict`\n",
    "        Configuration dictionary. Required key-value pair:\n",
    "        `\"ensemble_shape\"` : `tuple[int]`\n",
    "            The required ensemble shape of the outputs.\n",
    "    dataset : `dict`\n",
    "        Dataset with `torch.Tensor` values.\n",
    "    indptr_key : `str`, optional\n",
    "        If the dataset has sequential entries,\n",
    "        then this is the key of the index pointer tensor. Default: `indptr`.\n",
    "    minibatch_size : `int`, optional\n",
    "        Minibatch size. If not given, it is `config[\"minibatch_size\"]`.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    A generator of minibatch dictionaries.\n",
    "    Their extra keys `\"mask\"` map to the mask tensors of\n",
    "    entries that are not padding entries.\n",
    "    \"\"\"\n",
    "    raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc41a348",
   "metadata": {},
   "source": [
    "### `DeepSet` with arbitrary `Embedding`\n",
    "\n",
    "Recall that the module `CategoricalDeepSet` had a hardcoded `Embedding` in it, which mapped discrete entries $x\\in\\mathscr X_0\\cong[n]$ to processed feature vectors $g(x)\\in\\mathbf R^\\ell$.\n",
    "\n",
    "Let's refactor the deep set module and make the embedding an arbitrary submodule! This way, we can use as embedding an MLP $\\mathbf R^d\\xrightarrow g\\mathbf R^\\ell$, as needed in the present case of point clouds.\n",
    "\n",
    "Once you're done, create the following models:\n",
    "1. An MLP of input dimension 2, 3 hidden layers of width 8, and output dimension 32.\n",
    "2. A `DeepSet` with embedding the MLP created above, 3 hidden layers of width 128, and output dimension 1.\n",
    "\n",
    "Afterwards, get a minibatch from a train dataloader on the train dataset, and get the output of the deep set on the minibatch. Is its shape what you expect?"
   ]
  },
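  {
   "cell_type": "markdown",
   "id": "9a3c7f25",
   "metadata": {},
   "source": [
    "The key step of the forward pass is step 2, averaging only over non-padding entries. The sketch below makes three simplifications:\n",
    "- It is not ensembled: one set of parameters is shared across all leading dimensions.\n",
    "- It uses `torch.nn.Linear` and `torch.nn.Sequential` in place of `get_mlp`.\n",
    "- It assumes the embedding acts on the last input dimension.\n",
    "\n",
    "The names `masked_mean` and `DeepSetSketch` are hypothetical.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "\n",
    "def masked_mean(embedded: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    Average `embedded` of shape `(..., sequence_dim, embedding_dim)` over the\n",
    "    sequence dimension, ignoring the entries where `mask` is `False`.\n",
    "    \"\"\"\n",
    "    weights = mask.to(embedded.dtype)[..., None]\n",
    "    # clamp_min avoids division by zero for rows that are all padding.\n",
    "    return (embedded * weights).sum(dim=-2) / weights.sum(dim=-2).clamp_min(1.)\n",
    "\n",
    "\n",
    "class DeepSetSketch(torch.nn.Module):\n",
    "    \"\"\"Hypothetical, non-ensembled deep set: embed, masked average, MLP head.\"\"\"\n",
    "    def __init__(\n",
    "        self,\n",
    "        embedding: torch.nn.Module,\n",
    "        embedding_dim: int,\n",
    "        input_key: str,\n",
    "        out_features: int,\n",
    "        hidden_size: int = 128,\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.embedding = embedding\n",
    "        self.input_key = input_key\n",
    "        self.head = torch.nn.Sequential(\n",
    "            torch.nn.Linear(embedding_dim, hidden_size),\n",
    "            torch.nn.ReLU(),\n",
    "            torch.nn.Linear(hidden_size, out_features),\n",
    "        )\n",
    "\n",
    "    def forward(self, batch: dict) -> torch.Tensor:\n",
    "        embedded = self.embedding(batch[self.input_key])\n",
    "        return self.head(masked_mean(embedded, batch[\"mask\"]))\n",
    "```\n",
    "\n",
    "For instance, `DeepSetSketch(torch.nn.Linear(2, 32), 32, \"vertices\", 1)` maps a minibatch with `\"vertices\"` of shape `(8, 64, sequence_dim, 2)` and `\"mask\"` of shape `(8, 64, sequence_dim)` to outputs of shape `(8, 64, 1)`."
   ]
  },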
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9ba1520",
   "metadata": {},
   "outputs": [],
   "source": [
    "class DeepSet(torch.nn.Module):\n",
    "    \"\"\"\n",
    "    Ensemble-ready deep set.\n",
    "\n",
    "    It is composed of an embedding and an outgoing MLP.\n",
    "\n",
    "    In a forward call:\n",
    "    1. First, we get the embedding vectors.\n",
    "    2. Then, we average the embedding vectors over the sequence dimension.\n",
    "    3. Finally, we apply the outgoing MLP.\n",
    "\n",
    "    Arguments\n",
    "    ---------\n",
    "    config : `dict`\n",
    "        Configuration dictionary. Required key-value pairs:\n",
    "        `\"device\"` : `str`\n",
    "            The device to store parameters on.\n",
    "        `\"ensemble_shape\"` : `tuple[int]`\n",
    "            The shape of the ensemble of models\n",
    "            the module represents.\n",
    "    embedding : `torch.nn.Module`\n",
    "        The model that transforms the input to element-wise\n",
    "        embedding vectors.\n",
    "    embedding_dim : `int`\n",
    "        The number of dimensions of embedding vectors\n",
    "        that the `embedding` model outputs.\n",
    "    input_key : `str`\n",
    "        The key mapped to the input tensor in the dataset.\n",
    "    out_features : `int`\n",
    "        The number of output features.\n",
    "    hidden_layer_num : `int`, optional\n",
    "        If `hidden_layer_sizes` is not given, we create an outgoing MLP with\n",
    "        `hidden_layer_num` hidden layers of\n",
    "        `hidden_layer_size` dimensions.\n",
    "    hidden_layer_size : `int`, optional\n",
    "        If `hidden_layer_sizes` is not given, we create an outgoing MLP with\n",
    "        `hidden_layer_num` hidden layers of\n",
    "        `hidden_layer_size` dimensions.\n",
    "    hidden_layer_sizes: `Iterable[int]`, optional\n",
    "        If given, each entry gives a hidden layer with the given size\n",
    "        for the outgoing MLP.\n",
    "\n",
    "    Calling\n",
    "    -------\n",
    "    Instance calls require one positional argument:\n",
    "    batch : `dict`\n",
    "        The input data dictionary. Required keys:\n",
    "        `input_key` : `torch.Tensor`\n",
    "            The input tensor, stored at the key given by `input_key`,\n",
    "            of shape `batch_shape + (sequence_dim,)` or\n",
    "            `ensemble_shape + batch_shape + (sequence_dim,)`,\n",
    "            plus any trailing dimensions the embedding model requires,\n",
    "            such as `(ambient_dim,)` for point clouds.\n",
    "        `\"mask\"` : `torch.Tensor`\n",
    "            Mask showing which entries are not padding, of shape\n",
    "            `batch_shape + (sequence_dim,)` or\n",
    "            `ensemble_shape + batch_shape + (sequence_dim,)`\n",
    "    \"\"\"\n",
    "    def __init__(\n",
    "        self,\n",
    "        config: dict,\n",
    "        embedding: torch.nn.Module,\n",
    "        embedding_dim: int,\n",
    "        input_key: str,\n",
    "        out_features: int,\n",
    "        hidden_layer_num: Optional[int] = None,\n",
    "        hidden_layer_size: Optional[int] = None,\n",
    "        hidden_layer_sizes: Optional[Iterable[int]] = None\n",
    "    ):\n",
    "        raise NotImplementedError\n",
    "\n",
    "\n",
    "    def forward(self, batch: dict) -> torch.Tensor:\n",
    "        raise NotImplementedError\n",
    "    \n",
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a736cd67",
   "metadata": {},
   "source": [
    "### `get_output_by_batches`\n",
    "\n",
    "In Notebook 0409, you wrote the function `evaluate_model` that evaluated a model by getting predictions in minibatches. Let's refactor out the part of it that collects model predictions by minibatches. This gives us a function that outputs the predicted targets on an entire dataset, which will be useful when we plot the predicted targets.\n",
    "\n",
    "Write the function, then use it to get the predicted target tensor on the entire validation dataset. Print its shape. Is it what you would expect?"
   ]
  },
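  {
   "cell_type": "markdown",
   "id": "c0e4d8b1",
   "metadata": {},
   "source": [
    "The batching logic can be sketched independently of the model, as below. `collect_by_batches` and its `get_output` argument are hypothetical. `get_output` stands for any function that maps a 1-dimensional index tensor to outputs of shape `ensemble_shape + (len(indices), out_features)`. Gradient tracking is disabled to save memory during evaluation.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from collections.abc import Callable\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def collect_by_batches(\n",
    "    dataset_size: int,\n",
    "    minibatch_size_eval: int,\n",
    "    get_output: Callable[[torch.Tensor], torch.Tensor],\n",
    ") -> torch.Tensor:\n",
    "    \"\"\"Concatenate outputs over consecutive minibatches along dimension -2.\"\"\"\n",
    "    outputs = [\n",
    "        get_output(torch.arange(start, min(start + minibatch_size_eval, dataset_size)))\n",
    "        for start in range(0, dataset_size, minibatch_size_eval)\n",
    "    ]\n",
    "    return torch.cat(outputs, dim=-2)\n",
    "```\n",
    "\n",
    "In the notebook, `get_output` would presumably call `model` on `get_minibatch(dataset, indices.expand(config[\"ensemble_shape\"] + indices.shape))`, assuming an ensembled dataset. Unlike the training dataloader, this walks through the dataset in order and keeps the last, possibly smaller, minibatch, so every entry is evaluated exactly once."
   ]
  },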
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43d53ddb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_output_by_batches(\n",
    "    config: dict,\n",
    "    dataset: dict,\n",
    "    model: torch.nn.Module,\n",
    "    out_features: int,\n",
    "    indptr_key=\"indptr\"\n",
    ") -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    Get the output of a model in a single tensor for a full dataset,\n",
    "    but collected via evaluation by minibatches.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    config : `dict`\n",
    "        Configuration dictionary. Required key-value pair:\n",
    "        `\"minibatch_size_eval\"` : `int`\n",
    "            Size of consecutive minibatches to take from the dataset.\n",
    "            To be set according to RAM or GPU memory capacity.\n",
    "    dataset : `dict`\n",
    "        The dataset to evaluate the model on.\n",
    "    model : `torch.nn.Module`\n",
    "        The model to evaluate.\n",
    "    out_features : `int`  \n",
    "        The number of output features of the model.\n",
    "    indptr_key : `str`, optional\n",
    "        If the dataset has sequential entries,\n",
    "        then this is the key of the index pointer tensor.\n",
    "        Default: `\"indptr\"`.\n",
    "\n",
    "    \"\"\"\n",
    "    raise NotImplementedError\n",
    "\n",
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d1614d6",
   "metadata": {},
   "source": [
    "### `evaluate_model`\n",
    "\n",
    "Refactor the function `evaluate_model` you wrote in Notebook 0409 so that it uses `get_output_by_batches`. This does not change the function's behavior, but it makes the code more maintainable.\n",
    "\n",
    "Evaluate the model by computing its MSE on the validation dataset. Note that our targets are stored at the key `\"volume\"`, not at the default `\"target\"`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7ccb797",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_model(\n",
    "    config: dict,\n",
    "    dataset: dict,\n",
    "    get_metric: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],\n",
    "    model: torch.nn.Module,\n",
    "    indptr_key=\"indptr\",\n",
    "    target_key=\"target\"\n",
    ") -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    Evaluate a model on a supervised dataset.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    config : `dict`\n",
    "        Configuration dictionary. Required key-value pair:\n",
    "        `\"minibatch_size_eval\"` : `int`\n",
    "            Size of consecutive minibatches to take from the dataset.\n",
    "            To be set according to RAM or GPU memory capacity.\n",
    "    dataset : `dict`\n",
    "        The dataset to evaluate the model on.\n",
    "    get_metric : `Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`\n",
    "        Function to get the metric from a pair of\n",
    "        predicted and target value tensors.\n",
    "    model : `torch.nn.Module`\n",
    "        The model to evaluate.\n",
    "    indptr_key : `str`, optional\n",
    "        If the dataset has sequential entries,\n",
    "        then this is the key of the index pointer tensor.\n",
    "        Default: `\"indptr\"`.\n",
    "    target_key : `str`, optional\n",
    "        The key mapped to the target value tensor in the dataset.\n",
    "        Default: `\"target\"`\n",
    "    \"\"\"\n",
    "    raise NotImplementedError\n",
    "\n",
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "451b245c",
   "metadata": {},
   "source": [
    "### `train_supervised`\n",
    "\n",
    "Make the following changes to the function `train_supervised` you wrote in Notebook 0409:\n",
    "1. Some hyperparameter configurations make training numerically unstable, which produces NaN (Not a Number) values. To stop these values from propagating, we convert them to the worst possible losses and metrics in PBT updates, so that population members with NaN values get replaced by those without. To this end, when you calculate the training and validation losses and metrics at the beginning of the evaluation phase of the function, apply the `nan_to_num` method to the loss and metric tensors to convert NaN values to $\\infty$ and $-\\infty$, respectively.\n",
    "2. Where you check whether the best of the last metrics is better than the all-time best, also record the ID of the population member with the best last metric. If the best last metric is better than the all-time best, then, when you update the all-time best, also record the parameters of the best model at the `\"best parameters\"` key of the log dictionary. You can do this by looping over the key-value pairs of the model's `state_dict` and taking the subtensors of the parameters given by the ID of the best population member.\n",
    "\n",
    "Create an `AdamW` optimizer for the model parameters. Run training with the negative of `get_mse` as the `get_metric` function, and assign the returned log dictionary to a variable."
   ]
  },
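  {
   "cell_type": "markdown",
   "id": "f27a6e93",
   "metadata": {},
   "source": [
    "Both changes can be sketched on toy tensors. `EnsembledAffine` is a hypothetical stand-in for an ensembled model whose parameters have the population dimension first, and the losses are made up. The `.clone()` matters: the tensors returned by `state_dict` share storage with the parameters. Without it, later optimizer steps would silently overwrite the logged best parameters.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "\n",
    "class EnsembledAffine(torch.nn.Module):\n",
    "    \"\"\"Hypothetical toy ensemble of affine maps, population dimension first.\"\"\"\n",
    "    def __init__(self, population_size: int, in_features: int, out_features: int):\n",
    "        super().__init__()\n",
    "        self.weight = torch.nn.Parameter(\n",
    "            torch.randn(population_size, in_features, out_features)\n",
    "        )\n",
    "        self.bias = torch.nn.Parameter(torch.zeros(population_size, 1, out_features))\n",
    "\n",
    "\n",
    "model = EnsembledAffine(population_size=3, in_features=2, out_features=1)\n",
    "\n",
    "# Made-up validation losses per population member; member 1 diverged.\n",
    "loss = torch.tensor([.5, float(\"nan\"), .2])\n",
    "metric = -loss  # greater is better; NaN stays NaN\n",
    "# NaN becomes the worst possible value: inf for losses, -inf for metrics.\n",
    "loss = loss.nan_to_num(nan=float(\"inf\"))\n",
    "metric = metric.nan_to_num(nan=-float(\"inf\"))\n",
    "\n",
    "best_id = int(metric.argmax())\n",
    "best_parameters = {\n",
    "    # clone, because state_dict tensors share storage with the parameters\n",
    "    key: value[best_id].clone()\n",
    "    for key, value in model.state_dict().items()\n",
    "}\n",
    "```\n",
    "\n",
    "With only `nan` given, `nan_to_num` still clamps genuine infinities to the largest finite values of the dtype. These are still worse than any ordinary finite value."
   ]
  },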
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "404f1bff",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_supervised(\n",
    "    config: dict,\n",
    "    dataset_train: dict,\n",
    "    dataset_valid: dict,\n",
    "    get_loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],\n",
    "    get_metric: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],\n",
    "    model: torch.nn.Module,\n",
    "    optimizer: Optimizer,\n",
    "    indptr_key=\"indptr\",\n",
    "    target_key=\"target\"\n",
    ") -> dict:\n",
    "    \"\"\"\n",
    "    Population-based training on a supervised learning task.\n",
    "    Tuned hyperparameters are given by raw values and transformations.\n",
    "    This way, the hyperparameters are perturbed by\n",
    "    additive noise on raw values.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    config : `dict`\n",
    "        Configuration dictionary. Required key-value pairs:\n",
    "        `\"ensemble_shape\"` : `tuple[int]`\n",
    "            Ensemble shape. We assume this is a 1-element tuple\n",
    "            whose only entry is the population size.\n",
    "        `\"hyperparameter_raw_init_distributions\"` : `dict`\n",
    "            Dictionary that maps tuned hyperparameter names\n",
    "            to `torch.distributions.Distribution` of raw hyperparameter values.\n",
    "            Required keys:\n",
    "            `\"learning_rate\"`:\n",
    "                The learning rate of stochastic gradient descent.\n",
    "        `\"hyperparameter_raw_perturbs\"` : `dict`\n",
    "            Dictionary that maps tuned hyperparameter names\n",
    "            to `torch.distributions.Distribution` of additive noise.\n",
    "        `\"hyperparameter_transforms\"` : `dict`\n",
    "            Dictionary that maps tuned hyperparameter names\n",
    "            to transformations of raw hyperparameter values.\n",
    "        `\"improvement_threshold\"` : `float`\n",
    "            A new metric score has to be this much better\n",
    "            than the previous best to count as an improvement.\n",
    "        `\"minibatch_size\"` : `int`\n",
    "            Minibatch size to use in a training step.\n",
    "        `\"minibatch_size_eval\"` : `int`\n",
    "            Minibatch size to use in evaluation.\n",
    "            On CPU, should be about the same as `minibatch_size`.\n",
    "            On GPU, should be as big as possible without\n",
    "            incurring an Out of Memory error.\n",
    "        `\"pbt\"` : `bool`\n",
    "            Whether to use PBT updates in validations.\n",
    "            If `False`, the algorithm just samples hyperparameters at start,\n",
    "            then keeps them constant.\n",
    "        `\"steps_num\"` : `int`\n",
    "            Maximum number of training steps.\n",
    "        `\"steps_without_improvement\"` : `int`\n",
    "            If the number of training steps without improvement\n",
    "            exceeds this value, then training is stopped.\n",
    "        `\"valid_interval\"` : `int`\n",
    "            Frequency of evaluations, measured in number of training steps.\n",
    "        `\"welch_confidence_level\"` : `float`\n",
    "            The confidence level in Welch's t-test\n",
    "            that is used in determining if a population member\n",
    "            is to be replaced by another member with perturbed hyperparameters.\n",
    "        `\"welch_sample_size\"` : `int`\n",
    "            The last this many validation metrics are used\n",
    "            in Welch's t-test.\n",
    "    dataset_train : `dict`\n",
    "        The dataset to train the model on.\n",
    "    dataset_valid : `dict`\n",
    "        The dataset to evaluate the model on.\n",
    "    `get_loss` : `Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`\n",
    "        A function that maps a pair of predicted and target value tensors\n",
    "        to a tensor of losses per ensemble member.\n",
    "    `get_metric` : `Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`\n",
    "        A function that maps a pair of predicted and target value tensors\n",
    "        to a tensor of metrics per ensemble member.\n",
    "        We assume a greater metric is better.\n",
    "    `model` : `torch.nn.Module`\n",
    "        The model ensemble to tune.\n",
    "    `optimizer` : `Optimizer`\n",
    "        An optimizer that tracks the parameters of `model`.\n",
    "    indptr_key : `str`, optional\n",
    "        If the dataset has sequential entries,\n",
    "        then this is the key of the index pointer tensor.\n",
    "        Default: `\"indptr\"`.\n",
    "    target_key : `str`, optional\n",
    "        The key mapped to the target value tensor in the dataset.\n",
    "        Default: `\"target\"`\n",
    "        \n",
    "    Returns\n",
    "    -------\n",
    "    An output dictionary with the following key-value pairs:\n",
    "        `\"best parameters\"` : `dict`  \n",
    "            The state dictionary of the model with the best metric\n",
    "            encountered during training.\n",
    "        `\"source mask\"` : `torch.Tensor`\n",
    "            The source masks of population members\n",
    "            that were replace by other members in a PBT update\n",
    "        `\"target indices\"` : `torch.Tensor`\n",
    "            The indices of population members\n",
    "            that the member where the source mask is to were replaced with.\n",
    "        `\"training loss\"` : `torch.Tensor`\n",
    "            The training losses at evaluation steps.\n",
    "        `\"training metric\"` : `torch.Tensor`\n",
    "            The training metrics at evaluation steps.\n",
    "        `\"validation loss\"` : `torch.Tensor`\n",
    "            The validation losses at evaluation steps.\n",
    "        `\"validation metric\"` : `torch.Tensor`\n",
    "            The validation metrics at evaluation steps.\n",
    "\n",
    "        In addition, for each tuned hyperparameter name,\n",
    "        we include a `torch.Tensor` of values per update.\n",
    "    \"\"\"\n",
    "    raise NotImplementedError\n",
    "\n",
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eafc0c9a",
   "metadata": {},
   "source": [
    "## Evaluation on Test Set\n",
    "\n",
    "Finally, let's evaluate the best model on the test set. To that end:\n",
    "1. Create a new configuration dictionary with one change to the original: change `ensemble_shape` to `()`.\n",
    "2. Create a new MLP and `DeepSet` just like before, besides using the new configuration dictionary.\n",
    "3. Use the `load_state_dict` method of the new model, to load the best parameters at key `\"best parameters\"` of the training log dictionary.\n",
    "4. Use `evaluate_model` to evaluate the model on the validation dataset. You should see the same validation score as the best score during training.\n",
    "5. Use `get_output_by_batches` to get the predicted target on the full test dataset.\n",
    "6. Just like in Notebook 0207, get an argsort of the true targets, and sort the true and predicted targets when making a line and scatter plot of them, respectively.\n",
    "7. Also, print out the MSE on the test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73e1beab",
   "metadata": {},
   "outputs": [],
   "source": [
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8da8ed25",
   "metadata": {},
   "source": [
    "# License\n",
    "\n",
    "This work is licensed under CC BY-NC-SA 4.0. To view a copy of this license, visit http://creativecommons.org/licenses/by-nc-sa/4.0/"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dml",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
