{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c835247a",
   "metadata": {},
   "source": [
    "# Document Classification with Deep Sets\n",
    "\n",
    "## Setup\n",
    "\n",
    "### Imports\n",
    "\n",
    "Import `defaultdict`, `Callable`, `Generator`, `Iterable`, `datasets`, `math`, `mosestokenizer`, `os`, `torch`, `torch.nn.functional` as `F`, `tqdm` and `Optional`.\n",
    "\n",
    "Moreover, import:\n",
    "1. the function `get_random_reshuffler`, that you wrote in Notebook 0221,\n",
    "2. the functions `get_binary_accuracy`, `get_binary_cross_entropy` and `get_seed`, that you wrote in Notebook 0228,\n",
    "3. the function `get_mlp`, that you wrote in Notebook 0319,\n",
    "4. the classes `AdamW` and `Optimizer`, that you wrote in Notebook 0326,\n",
    "5. the functions `pbt_init` and `pbt_update`, that you wrote in Notebook 0328,\n",
    "6. the function `is_ensembled`, that you wrote in Notebook 0402 and\n",
    "7. the function `get_minibatch`, that you wrote in Notebook 0402, as `get_array_minibatch`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56c24e71",
   "metadata": {},
   "outputs": [],
   "source": [
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "06ee80a0",
   "metadata": {},
   "source": [
    "### Configuration\n",
    "\n",
    "Create a configuration dictionary with the following keys:\n",
    "- `\"dataset_path\"`: `str`  \n",
    "    Get this from the dataset page  \n",
    "    https://huggingface.co/datasets/dair-ai/emotion\n",
    "- `\"device\"`: `torch.device | int | str`  \n",
    "    The device identifier.\n",
    "- `\"ensemble_shape\"`: `tuple[int]`  \n",
    "    Today, we'll use PBT with population size 8. Thus, make this `(8,)`.\n",
    "- `\"hyperparameter_raw_init_distributions\"`, `\"hyperparameter_raw_perturb\"`, `\"hyperparameter_transforms\"` : `dict`  \n",
    "    These three dictionaries are going to determine how the hyperparameters are tuned. We'll tune the following hyperparameters:\n",
    "    1. Epsilon $\\epsilon$.\n",
    "    2. Learning rate $\\eta$.\n",
    "    3. Weight decay $\\lambda$.\n",
    "    4. First moment moving average decay rate $\\beta_1$.\n",
    "    5. Second moment moving average decay rate $\\beta_2$.\n",
    "\n",
    "    Of these, we don't know the required order of magnitude of the first three. Thus it may be good to make them distributed along $10^\\mathscr D$ where $\\mathscr D$ is a normal or uniform distribution. You can try to center the distributions at the recommended values.\n",
    "\n",
    "    We know that the recommended values of the fourth and fifth are $0.9$ and $0.999$. So it may be best to give them a distribution of the form $1-10^\\mathscr D$.\n",
    "- `\"improvement_threshold:`: `float`  \n",
    "    Make this `1e-4`.\n",
    "- `\"minibatch_size\"`: `int`  \n",
    "    Make this a `64`.\n",
    "- `\"minibatch_size_eval\"`: `int`  \n",
    "    1. If you're using CPU, you can make this something like the training minibatch.\n",
    "    2. If you're using GPU, it is best to find the largest value here that does not give you an out of memory error. You can experiment with powers of 2. It may be convenient to use the left bit shift operator `<<`. For example, on my home computer, I could use an evaluation minibatch size of $2^{11}$, that is `1 << 11`.\n",
    "- `\"pbt\"` : `bool`  \n",
    "    Let's make a switch to turn off PBT. You can use this to test if the algorithm is doing any good, or optimization with the initial hyperparameters works just as well. Make this `True`.\n",
    "- `\"seed\"`: `int`  \n",
    "    This is for reproducible experiments. Insert any integer.\n",
    "- `\"steps_num\"`: `int`  \n",
    "    Make this a `1001`.\n",
    "- `\"steps_without_improvement`: `int`  \n",
    "    Make this `1000`.\n",
    "- `\"valid_interval\"`: `int`  \n",
    "    Make this `100`.\n",
    "- `\"welch_confidence_level\"`: `float`  \n",
    "    We will exploit based on a one-sided Welch $t$-test with this confidence level. Based on my experiments in the setting of Homework 9, maybe you can try `.8`. Feel free to try out various values here!\n",
    "- `\"welch_sample_size\"`: `int`  \n",
    "    We will exploit based on a one-sided Welch $t$-test on the last this many validation metrics of the population members. To follow the PBT paper, make this `10`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b99cd1a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "raise NotImplementedError"
   ]
  },
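  {
   "cell_type": "markdown",
   "id": "3f1a9c01",
   "metadata": {},
   "source": [
    "The following cell is a minimal sketch of the $10^{\\mathscr D}$ and $1-10^{\\mathscr D}$ parametrizations, not a prescribed configuration. The hyperparameter names other than `\"learning_rate\"`, the uniform ranges and the perturbation scale are assumptions made for illustration; the ranges are centered at the defaults of PyTorch's AdamW ($\\epsilon=10^{-8}$, $\\eta=10^{-3}$, $\\lambda=10^{-2}$, $\\beta_1=0.9$, $\\beta_2=0.999$). Adapt the names to what your `pbt_init` and `pbt_update` expect."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f1a9c02",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Hypothetical example: raw values are base-10 exponents\n",
    "example_ensemble_shape = (8,)\n",
    "\n",
    "example_raw_init_distributions = {\n",
    "    'epsilon': torch.distributions.Uniform(-10., -6.),       # 10^raw, around 1e-8\n",
    "    'learning_rate': torch.distributions.Uniform(-4., -2.),  # 10^raw, around 1e-3\n",
    "    'weight_decay': torch.distributions.Uniform(-3., -1.),   # 10^raw, around 1e-2\n",
    "    'beta_1': torch.distributions.Uniform(-1.5, -.5),        # 1 - 10^raw, around 0.9\n",
    "    'beta_2': torch.distributions.Uniform(-3.5, -2.5),       # 1 - 10^raw, around 0.999\n",
    "}\n",
    "\n",
    "# Additive noise on the raw values, used when PBT perturbs a member\n",
    "example_raw_perturbs = {\n",
    "    name: torch.distributions.Normal(0., .1)\n",
    "    for name in example_raw_init_distributions\n",
    "}\n",
    "\n",
    "def ten_to_the(raw: torch.Tensor) -> torch.Tensor:\n",
    "    return 10. ** raw\n",
    "\n",
    "def one_minus_ten_to_the(raw: torch.Tensor) -> torch.Tensor:\n",
    "    return 1. - 10. ** raw\n",
    "\n",
    "example_transforms = {\n",
    "    'epsilon': ten_to_the,\n",
    "    'learning_rate': ten_to_the,\n",
    "    'weight_decay': ten_to_the,\n",
    "    'beta_1': one_minus_ten_to_the,\n",
    "    'beta_2': one_minus_ten_to_the,\n",
    "}\n",
    "\n",
    "# One raw value per population member, mapped to actual hyperparameters\n",
    "example_hyperparameters = {\n",
    "    name: example_transforms[name](distribution.sample(example_ensemble_shape))\n",
    "    for name, distribution in example_raw_init_distributions.items()\n",
    "}\n",
    "for name, value in example_hyperparameters.items():\n",
    "    print(name, value)"
   ]
  },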
  {
   "cell_type": "markdown",
   "id": "8ed592ba",
   "metadata": {},
   "source": [
    "Set the `torch` pseudo-random number generation seed as per the configuration dictionary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0efb1e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "98503d85",
   "metadata": {},
   "source": [
    "## Data Preprocessing\n",
    "\n",
    "### Loading the Dataset\n",
    "\n",
    "Just like in Notebook 0305:\n",
    "1. Load the `train` split of the `unsplit` subset of the Emotion dataset [4].\n",
    "2. `filter` the dataset to keep entries with labels 0 (sadness) and 1 (joy) only.\n",
    "3. Make a 90%-10% train-valid split\n",
    "\n",
    "Print split datasets and their first entries to see if all's well."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07a14098",
   "metadata": {},
   "outputs": [],
   "source": [
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2ae25bca",
   "metadata": {},
   "source": [
    "### Storing Sequential Data\n",
    "\n",
    "We will preprocess the data by transforming words to token indices. We will have to store sequences of varying lengths. To save space, we will store sequential data in 2 vectors:\n",
    "1. `token_ids`: this stores token IDs of all documents in order.\n",
    "2. `indptr`: the token IDs of the `i`-th document are `token_ids[indptr[i]:indptr[i+1]]`.\n",
    "\n",
    "Note: The variable name `indptr` stands for *index pointer*. It comes from the naming convention of `scipy` sparse matrices in Compressed Sparse Row format: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.html\n",
    "\n",
    "1. Write the function below.\n",
    "2. Get its output on the validation set with no optional arguments.\n",
    "3. Print the sizes of the tensors and the vocabulary.\n",
    "3. Using the 2 vectors in the output, get the token IDs of the 100-th validation entry.\n",
    "4. Using the ID-to-token list of the output, convert the token IDs you got in part 3. to a document. Print it out."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06b5a328",
   "metadata": {},
   "outputs": [],
   "source": [
    "def tokenize_corpus(\n",
    "    config: dict,\n",
    "    corpus: list[str],\n",
    "    fixed_vocabulary: Optional[bool] = None,\n",
    "    id2token: Optional[list] = None,\n",
    "    token2id: Optional[dict] = None\n",
    ") -> tuple[\n",
    "    torch.Tensor,\n",
    "    torch.Tensor,\n",
    "    list,\n",
    "    dict\n",
    "]:\n",
    "    \"\"\"\n",
    "    Tokenize a corpus using `mosestokenizer`.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    config : `dict`\n",
    "        Configuration dictionary. Required key-value pair:\n",
    "        `\"device\"` : `int | str | torch.device`\n",
    "            Device to store tensors on.\n",
    "    corpus : `list[str]`\n",
    "        The document corpus.\n",
    "    fixed_vocabulary : `bool`, optional\n",
    "        Determines if the vocabulary given in the\n",
    "        `id2token` and `token2id` keyword arguments can be augmented\n",
    "        or tokens not in the vocabulary should be discarded.\n",
    "        If not given, then it will be `id2token is not None`.\n",
    "    id2token : `list[str]`\n",
    "        The mapping from token IDs to tokens.\n",
    "        If `not fixed_vocabulary`, then it is augmented\n",
    "        as new tokens are extracted from the corpus. Default: `[]`.\n",
    "    token2id : `dict`\n",
    "        The mapping from tokens to token IDs.\n",
    "        If `not fixed_vocabulary`, then it is augmented\n",
    "        as new tokens are extracted from the corpus. Default: `{}`.\n",
    "    \n",
    "    Returns\n",
    "    -------\n",
    "    The quadruple of:\n",
    "    1. The vector of token IDs,\n",
    "    2. the vector of document index pointers,\n",
    "    3. the list `id2token` and\n",
    "    4. the dictionary `token2id`.\n",
    "    \"\"\"\n",
    "    raise NotImplementedError\n",
    "\n",
    "raise NotImplementedError"
   ]
  },
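  {
   "cell_type": "markdown",
   "id": "3f1a9c03",
   "metadata": {},
   "source": [
    "As a toy illustration of the `token_ids`/`indptr` layout (not a solution to `tokenize_corpus`), the cell below builds both vectors for a made-up three-document corpus. It uses `str.split` as a stand-in for `mosestokenizer` and always augments the vocabulary, as in the `fixed_vocabulary=False` case."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f1a9c04",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Made-up corpus; str.split stands in for the Moses tokenizer\n",
    "toy_corpus = ['i feel happy', 'so sad today', 'happy happy joy']\n",
    "\n",
    "toy_id2token = []\n",
    "toy_token2id = {}\n",
    "toy_ids = []\n",
    "toy_indptr_list = [0]\n",
    "for document in toy_corpus:\n",
    "    for token in document.split():\n",
    "        if token not in toy_token2id:\n",
    "            # New token: augment the vocabulary\n",
    "            toy_token2id[token] = len(toy_id2token)\n",
    "            toy_id2token.append(token)\n",
    "        toy_ids.append(toy_token2id[token])\n",
    "    # The next document starts where this one ends\n",
    "    toy_indptr_list.append(len(toy_ids))\n",
    "\n",
    "toy_token_ids = torch.tensor(toy_ids)\n",
    "toy_indptr = torch.tensor(toy_indptr_list)\n",
    "\n",
    "# Token IDs of the i-th document: token_ids[indptr[i]:indptr[i + 1]]\n",
    "document_index = 2\n",
    "document_ids = toy_token_ids[\n",
    "    int(toy_indptr[document_index]):int(toy_indptr[document_index + 1])\n",
    "]\n",
    "print(toy_indptr)  # tensor([0, 3, 6, 9])\n",
    "print(' '.join(toy_id2token[token_id] for token_id in document_ids.tolist()))  # happy happy joy"
   ]
  },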
  {
   "cell_type": "markdown",
   "id": "2464b733",
   "metadata": {},
   "source": [
    "Via another call to `tokenize_corpus`, get the training token IDs and training index pointers, while augmenting the vocabulary with new tokens from the training corpus.\n",
    "\n",
    "Once again, print sizes and the text of the 100-th entry as extracted from the tokenized data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e33f8be",
   "metadata": {},
   "outputs": [],
   "source": [
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2a8bcdb7",
   "metadata": {},
   "source": [
    "### Create Dataset Dictionaries\n",
    "\n",
    "As explained in Notebook 0402, for more versatility, we'll make preprocessed datasets, and minibatches thereof, dictionaries. So, make training and validation datasets using the token ids and index pointers you extracted and the labels in the downloaded dataset. The latter should be converted to tensors.\n",
    "\n",
    "Print keys and value shapes of the dictionaries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f2e43c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d354ecd9",
   "metadata": {},
   "source": [
    "## Dataloaders on Sequential Data\n",
    "\n",
    "In Notebook 0402, you created a dataloader that yields minibatches from datasets with an arbitrary number of data tensors. However, all tensors there represent *array* data, that is all entries have the same size.\n",
    "\n",
    "Note: *array* is my own terminology. I you know of more standard terminology for fixed-size data, please tell me."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "78149150",
   "metadata": {},
   "source": [
    "### Getting Array and Sequence Keys\n",
    "\n",
    "Our first task in updating the dataloader to handle sequential features is to recognize array and sequential data keys. We will use rules for this that are in the docstrings of the following functions. Write the functions and print their outputs on the training and validation datasets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d8eca83",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_dataset_size(\n",
    "    config: dict,\n",
    "    dataset: dict,\n",
    "    indptr_key=\"indptr\"\n",
    ") -> int:\n",
    "    \"\"\"\n",
    "    Get the size of the potentially ensembled dataset,\n",
    "    which is defined as follows:\n",
    "\n",
    "    Let us call *batch dimension* of a tensor the dimension\n",
    "        at the entry of its shape\n",
    "        that comes after the ensemble shape entries\n",
    "    1. If the dataset has an index pointer tensor,\n",
    "        then its size is the batch the index pointer tensor minus 1.\n",
    "    2. Otherwise, we take any value of the dataset.\n",
    "        Then the dataset size is the batch dimension of the value tensor.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    config : `dict`\n",
    "        Configuration dictionary. Requires key-value pair:\n",
    "        `\"ensemble_shape\"` : `tuple[int]`\n",
    "            Ensemble shape.\n",
    "    dataset : `dict`\n",
    "        The dataset.\n",
    "    indptr_key : `str`, optional\n",
    "        The key of the index pointer tensor. Default: `indptr`.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The dataset size.\n",
    "    \"\"\"\n",
    "    raise NotImplementedError\n",
    "\n",
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "436462e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_array_sequence_keys(\n",
    "    dataset: dict,\n",
    "    indptr_key=\"indptr\"\n",
    ") -> tuple[tuple[str], tuple[str]]:\n",
    "    \"\"\"\n",
    "    Get the array and sequence keys of the dataset,\n",
    "    which are defined as follows:\n",
    "\n",
    "    Given a key-value pair, we say that the key is\n",
    "    1. an array key, if the value tensor has the dataset size\n",
    "        at the batch dimension and\n",
    "    2. a sequence key, if the value tensor has the total number of tokens\n",
    "        at the batch dimension.\n",
    "\n",
    "    For the definition of batch dimension and dataset size,\n",
    "    see `get_dataset_size`.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    dataset : `dict`\n",
    "        The dataset.\n",
    "    indptr_key : `str`, optional\n",
    "        The key of the index pointer tensor. Default: `indptr`.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The pair of\n",
    "    1. the tuple of array keys and\n",
    "    2. the tuple of sequence keys.\n",
    "    \"\"\"\n",
    "    raise NotImplementedError\n",
    "\n",
    "raise NotImplementedError"
   ]
  },
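  {
   "cell_type": "markdown",
   "id": "3f1a9c05",
   "metadata": {},
   "source": [
    "The cell below is a minimal sketch of the rules in the two docstrings above, applied to a hand-made dataset with ensemble shape `(2,)`. The helper `example_size_and_keys` is hypothetical: unlike the functions above, it takes the number of ensemble dimensions as an argument, reads the total number of tokens from the last entry of the index pointer tensor, and skips the index pointer key itself. If every document had exactly one token, the two rules would coincide; this sketch then treats the key as an array key."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f1a9c06",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def example_size_and_keys(dataset: dict, ensemble_ndim: int, indptr_key='indptr'):\n",
    "    # The batch dimension comes right after the ensemble dimensions\n",
    "    if indptr_key in dataset:\n",
    "        indptr = dataset[indptr_key]\n",
    "        size = indptr.shape[ensemble_ndim] - 1\n",
    "        # Total number of tokens; assumes all ensemble copies agree\n",
    "        token_num = int(indptr.reshape(-1)[-1])\n",
    "    else:\n",
    "        size = next(iter(dataset.values())).shape[ensemble_ndim]\n",
    "        token_num = None\n",
    "    array_keys, sequence_keys = [], []\n",
    "    for key, value in dataset.items():\n",
    "        if key == indptr_key:\n",
    "            continue\n",
    "        if value.shape[ensemble_ndim] == size:\n",
    "            array_keys.append(key)\n",
    "        elif value.shape[ensemble_ndim] == token_num:\n",
    "            sequence_keys.append(key)\n",
    "    return size, tuple(array_keys), tuple(sequence_keys)\n",
    "\n",
    "toy_dataset = {\n",
    "    'indptr': torch.tensor([0, 3, 6, 9]).expand(2, 4),\n",
    "    'token_ids': torch.tensor([0, 1, 2, 3, 4, 5, 2, 2, 6]).expand(2, 9),\n",
    "    'target': torch.tensor([1, 0, 1]).expand(2, 3),\n",
    "}\n",
    "print(example_size_and_keys(toy_dataset, ensemble_ndim=1))  # (3, ('target',), ('token_ids',))"
   ]
  },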
  {
   "cell_type": "markdown",
   "id": "6f392479",
   "metadata": {},
   "source": [
    "### Making Tensors Ensembled\n",
    "\n",
    "Our minibatch getter will assume that the dataset has the same ensemble shape as the indices it gets. Thus, first, we need to write a function that broadcasts the entries in a dataset to have ensemble dimensions as described in the configuration dictionary. You can use as inspiration the function `is_ensembled` you wrote in Notebook 0402.\n",
    "\n",
    "Write the function, apply it to the values of the training and validation datasets and print their shapes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "314ce49e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def to_ensembled(\n",
    "    ensemble_shape: tuple[int],\n",
    "    tensor: torch.Tensor\n",
    ") -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    We say that a tensor is *ensembled*,\n",
    "    if its shape starts by the ensemble shape.\n",
    "\n",
    "    This function converts a tensor to an ensembled tensor.\n",
    "    \"\"\"\n",
    "    raise NotImplementedError\n",
    "\n",
    "raise NotImplementedError"
   ]
  },
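  {
   "cell_type": "markdown",
   "id": "3f1a9c07",
   "metadata": {},
   "source": [
    "One possible sketch of `to_ensembled`, named `example_to_ensembled` here to keep it apart from your own implementation. It assumes that a tensor is ensembled exactly when its shape starts with the ensemble shape, as in the docstring above, and uses `expand`, which returns a broadcasted view without copying memory. Note that under this rule, a non-ensembled tensor whose leading dimensions happen to equal the ensemble shape is indistinguishable from an ensembled one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f1a9c08",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def example_to_ensembled(ensemble_shape: tuple, tensor: torch.Tensor) -> torch.Tensor:\n",
    "    ensemble_shape = tuple(ensemble_shape)\n",
    "    # Already ensembled: the shape starts with the ensemble shape\n",
    "    if tuple(tensor.shape[:len(ensemble_shape)]) == ensemble_shape:\n",
    "        return tensor\n",
    "    # Otherwise, prepend the ensemble dimensions as a broadcasted view\n",
    "    return tensor.expand(ensemble_shape + tuple(tensor.shape))\n",
    "\n",
    "toy_target = torch.tensor([1, 0, 1])\n",
    "toy_target_ensembled = example_to_ensembled((8,), toy_target)\n",
    "print(toy_target_ensembled.shape)  # torch.Size([8, 3])"
   ]
  },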
  {
   "cell_type": "markdown",
   "id": "d7df7e6e",
   "metadata": {},
   "source": [
    "### Getting a Minibatch with Sequential Data\n",
    "\n",
    "Time to get minibatches!\n",
    "\n",
    "Recall that GPUs are happy with batched data. So we can no more use sequences of varying length. This means that we need to use a sequence dimension with value the maximum of the lengths of the sequences in the minibatch. The entries that are appended to the sequences to reach the maximum length in the minibatch are called *padding* entries. To show which entries are not padding entries, we use a mask.\n",
    "\n",
    "As the function is rather complicated, I wrote it for you. You can test it out on either the ensembled training or validation dataset by generating minibatch indices as a tensor of shape `ensemble_shape + (minibatch_size,)` with entries random integers in $[\\mathtt{dataset\\_size}]$.\n",
    "\n",
    "Fix an ensemble and batch index, then loop over the masked token IDs and print the tokens. You should see one full document (although the documents being tweets this may be difficult to gauge)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80e11b36",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_minibatch(\n",
    "    dataset: dict,\n",
    "    indices: torch.Tensor,\n",
    "    indptr_key=\"indptr\",\n",
    ") -> dict:\n",
    "    \"\"\"\n",
    "    Returns the minibatch of an ensembled dataset\n",
    "    given by ensembled indices.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    dataset: `dict`\n",
    "        A dataset given as a dictionary with `torch.Tensor` values.\n",
    "    indices: `torch.Tensor`\n",
    "        An index tensor for the dataset.\n",
    "        Each value in the dataset should have shape prefixed by\n",
    "        the shape of the index tensor.\n",
    "    indptr_key : `str`, optional\n",
    "        If the dataset has sequential entries,\n",
    "        then this is the key of the index pointer tensor. Default: `indptr`.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The minibatch given by the dataset and the index tensor.\n",
    "    The extra key `\"mask\"` is mapped to the mask tensor of\n",
    "    entries that are not padding entries.\n",
    "    \"\"\"\n",
    "    minibatch = {}\n",
    "\n",
    "    dense_keys, sequence_keys = get_array_sequence_keys(\n",
    "        dataset, indptr_key=indptr_key\n",
    "    )\n",
    "\n",
    "    if len(dense_keys) > 0:\n",
    "        minibatch.update(get_array_minibatch(\n",
    "            {\n",
    "                key: dataset[key]\n",
    "                for key in dense_keys\n",
    "            },\n",
    "            indices\n",
    "        ))\n",
    "\n",
    "    if len(sequence_keys) > 0:\n",
    "        indptr_left, indptr_right = (\n",
    "            dataset[indptr_key].gather(\n",
    "                -1,\n",
    "                i\n",
    "            )\n",
    "            for i in (indices, indices + 1)\n",
    "        )\n",
    "        \n",
    "        sizes = indptr_right - indptr_left\n",
    "        sizes_max = sizes.max()\n",
    "        \n",
    "        sequence_indices = (\n",
    "            indptr_left[..., None]\n",
    "          + torch.arange(sizes_max, device=indices.device)\n",
    "        )\n",
    "\n",
    "        mask: torch.Tensor = sequence_indices < indptr_right[..., None]\n",
    "        minibatch[\"mask\"] = mask\n",
    "        sequence_indices[~mask] = 0\n",
    "\n",
    "        minibatch_shape = indices.shape\n",
    "        for key in sequence_keys:\n",
    "            data_raw: torch.Tensor = dataset[key]\n",
    "            data = data_raw.gather(\n",
    "                len(minibatch_shape) - 1,\n",
    "                sequence_indices.reshape(\n",
    "                    minibatch_shape[:-1]\n",
    "                  + (minibatch_shape[-1] * sizes_max,)\n",
    "                  + data_raw.shape[len(minibatch_shape):]\n",
    "                )\n",
    "            ).reshape(\n",
    "                minibatch_shape\n",
    "              + (sizes_max,)\n",
    "              + data_raw.shape[len(minibatch_shape):]\n",
    "            )\n",
    "            minibatch[key] = data\n",
    "\n",
    "    return minibatch\n",
    "\n",
    "raise NotImplementedError"
   ]
  },
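  {
   "cell_type": "markdown",
   "id": "3f1a9c09",
   "metadata": {},
   "source": [
    "To see what the padding and the mask look like, the cell below replays the core of the sequential branch of `get_minibatch` on made-up, non-ensembled tensors: three documents of lengths 3, 1 and 5, and a minibatch consisting of documents 1 and 0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f1a9c10",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "pad_indptr = torch.tensor([0, 3, 4, 9])     # document lengths 3, 1, 5\n",
    "pad_token_ids = torch.arange(100, 109)      # token IDs 100, ..., 108\n",
    "pad_indices = torch.tensor([1, 0])          # minibatch: documents 1 and 0\n",
    "\n",
    "pad_left = pad_indptr[pad_indices]\n",
    "pad_right = pad_indptr[pad_indices + 1]\n",
    "pad_max_length = int((pad_right - pad_left).max())\n",
    "\n",
    "# Positions of the tokens of each document, padded to the longest one\n",
    "pad_positions = pad_left[:, None] + torch.arange(pad_max_length)\n",
    "pad_mask = pad_positions < pad_right[:, None]\n",
    "pad_positions[~pad_mask] = 0  # padding entries may point to any valid position\n",
    "pad_tokens = pad_token_ids[pad_positions]\n",
    "print(pad_mask)\n",
    "print(pad_tokens * pad_mask)  # padding entries zeroed out for display"
   ]
  },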
  {
   "cell_type": "markdown",
   "id": "d99b8e69",
   "metadata": {},
   "source": [
    "### Getting a Dataloader with Sequential Data\n",
    "\n",
    "With the new `get_minibatch` function, you can update the function `get_dataloader_random_reshuffle` you wrote in Notebook 0402 with minimal changes.\n",
    "\n",
    "Write the function, get a training dataloader and print its keys and value shapes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15f77638",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_dataloader_random_reshuffle(\n",
    "    config: dict,\n",
    "    dataset: dict,\n",
    "    indptr_key=\"indptr\",\n",
    "    minibatch_size: Optional[int] = None\n",
    ") -> Generator[dict]:\n",
    "    \"\"\"\n",
    "    Given a dataset as a dictionary with tensor values,\n",
    "    creates a random reshuffling (without replacement) dataloader\n",
    "    that yields minibatch dictionaries indefinitely.\n",
    "    Support arbitrary ensemble shapes and sequential data.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    config : `dict`\n",
    "        Configuration dictionary. Required key-value pair:\n",
    "        ensemble_shape : tuple[int]\n",
    "            The required ensemble shapes of the outputs.\n",
    "    dataset : `dict`\n",
    "        Dataset with `torch.Tensor` values.\n",
    "    indptr_key : `str`, optional\n",
    "        If the dataset has sequential entries,\n",
    "        then this is the key of the index pointer tensor. Default: `indptr`.\n",
    "    minibatch_size : `int`, optional\n",
    "        Minibatch size. If not given, it is `config[\"minibatch_size\"]`.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    A generator of minibatch dictionaries.\n",
    "    Their extra keys `\"mask\"` map to the mask tensors of\n",
    "    entries that are not padding entries.\n",
    "    \"\"\"\n",
    "    raise NotImplementedError\n",
    "\n",
    "raise NotImplementedError"
   ]
  },
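  {
   "cell_type": "markdown",
   "id": "3f1a9c11",
   "metadata": {},
   "source": [
    "The core of a random reshuffling dataloader for an ensemble is drawing an independent permutation for each ensemble member. The cell below sketches this with `torch.rand(...).argsort(dim=-1)`; your `get_random_reshuffler` from Notebook 0221 may do it differently. The generator `example_index_batches` is hypothetical and drops the last incomplete minibatch of each epoch to keep shapes fixed, which is a design choice of this sketch rather than a requirement. Each index tensor it yields has shape `ensemble_shape + (minibatch_size,)`, so it can be passed to `get_minibatch` together with an ensembled dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f1a9c12",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from typing import Generator\n",
    "\n",
    "def example_index_batches(\n",
    "    dataset_size: int,\n",
    "    ensemble_shape: tuple,\n",
    "    minibatch_size: int\n",
    ") -> Generator[torch.Tensor, None, None]:\n",
    "    while True:\n",
    "        # One independent random permutation per ensemble member\n",
    "        permutation = torch.rand(\n",
    "            tuple(ensemble_shape) + (dataset_size,)\n",
    "        ).argsort(dim=-1)\n",
    "        # Consecutive slices of the permutation; the last incomplete one is dropped\n",
    "        for start in range(0, dataset_size - minibatch_size + 1, minibatch_size):\n",
    "            yield permutation[..., start:start + minibatch_size]\n",
    "\n",
    "toy_batches = example_index_batches(10, (2,), 4)\n",
    "batch_a, batch_b, batch_c = (next(toy_batches) for _ in range(3))\n",
    "print(batch_a.shape)  # torch.Size([2, 4])"
   ]
  },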
  {
   "cell_type": "markdown",
   "id": "e02ffef6",
   "metadata": {},
   "source": [
    "## Building the Model\n",
    "\n",
    "### Ensembled Embeddings\n",
    "\n",
    "First of all, let's create a model that maps token IDs to embedding vectors:\n",
    "1. Its parameters should be a tensor of shape `ensemble_shape + (vocabulary_size, embedding_dim)`. You can initialize it with the standard normal distribution.\n",
    "2. When called:\n",
    "    1. It should receive an index tensor of shape `batch_shape` or `ensemble_shape + batch_shape`.\n",
    "    2. We'll want to use the `gather` method of the parameter tensor, with index tensor entries as token indices, to get an embedding tensor of shape `ensemble_shape + batch_shape + (embedding_dim,)`.\n",
    "        1. As the `gather` method will expect an index tensor of shape `ensemble_shape + (batch_dim, embedding_dim)`:\n",
    "            1. Use `to_ensembled`.\n",
    "            2. Flatten `batch_shape`\n",
    "            3. Add an extra 1 to the right of the shape and broadcast.\n",
    "        2. After using `gather`, you should `reshape` the output to the required shape.\n",
    "\n",
    "Write the class, get an instance and print the shape of its output on a training minibatch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5f629c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Embedding(torch.nn.Module):\n",
    "    \"\"\"\n",
    "    Ensemble-ready embedding.\n",
    "\n",
    "    Arguments\n",
    "    ---------\n",
    "    config : `dict`\n",
    "        Configuration dictionary. Required key-value pairs:\n",
    "        `\"device\"` : `str`\n",
    "            The device to store parameters on.\n",
    "        `\"ensemble_shape\"` : `tuple[int]`\n",
    "            The shape of the ensemble of affine transformations\n",
    "            the model represents.\n",
    "    embedding_dim : `int`\n",
    "        The number of embedding dimensions.\n",
    "    vocabulary_size : `int`\n",
    "        The number of vocabulary entries.\n",
    "\n",
    "    Calling\n",
    "    -------\n",
    "    Instance calls require one positional argument:\n",
    "    indices : `torch.Tensor`\n",
    "        The index tensor. It is required to be one of the following shapes:\n",
    "        1. `ensemble_shape + batch_shape`\n",
    "        2. `batch_shape`\n",
    "\n",
    "        Upon a call, the model thinks we're in the first case\n",
    "        if the first `len(ensemble_shape)` many entries of the\n",
    "        shape of the input tensor is `ensemble_shape`.\n",
    "    \"\"\"\n",
    "    def __init__(\n",
    "        self,\n",
    "        config: dict,\n",
    "        embedding_dim: int,\n",
    "        vocabulary_size: int\n",
    "    ):\n",
    "        raise NotImplementedError\n",
    "\n",
    "\n",
    "    def forward(self, indices: torch.Tensor) -> torch.Tensor:\n",
    "        raise NotImplementedError\n",
    "    \n",
    "raise NotImplementedError"
   ]
  },
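  {
   "cell_type": "markdown",
   "id": "3f1a9c13",
   "metadata": {},
   "source": [
    "The `gather` recipe above can be checked on small random tensors. In this sketch, the hypothetical function `example_gather_embedding` takes a standalone weight tensor of shape `ensemble_shape + (vocabulary_size, embedding_dim)` in place of the module's parameter, and an already ensembled index tensor. The last lines compare its output with plain indexing, member by member."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f1a9c14",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def example_gather_embedding(weight: torch.Tensor, indices: torch.Tensor, ensemble_ndim: int) -> torch.Tensor:\n",
    "    ensemble_shape = tuple(weight.shape[:ensemble_ndim])\n",
    "    embedding_dim = weight.shape[-1]\n",
    "    batch_shape = tuple(indices.shape[ensemble_ndim:])\n",
    "    # Flatten the batch shape, append a dimension of size 1 and broadcast it\n",
    "    flat = indices.reshape(ensemble_shape + (-1, 1)).expand(\n",
    "        ensemble_shape + (-1, embedding_dim)\n",
    "    )\n",
    "    # Gather along the vocabulary dimension, then restore the batch shape\n",
    "    return weight.gather(ensemble_ndim, flat).reshape(\n",
    "        ensemble_shape + batch_shape + (embedding_dim,)\n",
    "    )\n",
    "\n",
    "emb_weight = torch.randn(3, 7, 4)           # ensemble of 3, vocabulary of 7, 4 dimensions\n",
    "emb_indices = torch.randint(7, (3, 2, 5))   # batch shape (2, 5)\n",
    "emb_out = example_gather_embedding(emb_weight, emb_indices, ensemble_ndim=1)\n",
    "print(emb_out.shape)  # torch.Size([3, 2, 5, 4])\n",
    "print(all(\n",
    "    torch.equal(emb_out[member], emb_weight[member][emb_indices[member]])\n",
    "    for member in range(3)\n",
    "))  # True"
   ]
  },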
  {
   "cell_type": "markdown",
   "id": "92f51653",
   "metadata": {},
   "source": [
    "Delete the embedding with `del` to release GPU memory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a994b637",
   "metadata": {},
   "outputs": [],
   "source": [
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f9adffea",
   "metadata": {},
   "source": [
    "### Categorical Deep Set\n",
    "\n",
    "Equipped with an embedding, we can create a categorical deep set model. That is, one the input space of which is $L[n]$. It is composed of two models:\n",
    "1. An embedding.\n",
    "2. An MLP with input dimension the embedding dimension.\n",
    "\n",
    "A forward call works as follows:\n",
    "1. We receive a minibatch dictionary with keys, among others, `\"token_ids\"` and `\"mask\"`.\n",
    "2. Using the token IDs, we get embedding vectors.\n",
    "3. Now, we want to take the average of the embedding vectors along the sequence dimension. As we want to avoid padding entries:\n",
    "    1. Multiply the token embedding tensor with the mask tensor (properly reshaped for broadcasting).\n",
    "    2. Sum over the sequence dimension.\n",
    "    3. Divide by the sum of the mask tensor over the sequence dimension.\n",
    "\n",
    "Make a categorical deep set model:\n",
    "1. with the vocabulary size of our corpus,\n",
    "2. 10 embedding dimensions (a little small, compared to the usual 300, but good for a first, test experiment),\n",
    "3. the proper number of output dimensions for binary classification and\n",
    "4. 2 hidden layers of 128 dimensions.\n",
    "\n",
    "Then apply the model on a training minibatch and print the output shape."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0588ee7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "class CategoricalDeepSet(torch.nn.Module):\n",
    "    \"\"\"\n",
    "    Ensemble-ready deep set\n",
    "    with input space subsets of a given finite set, called the *vocabulary*.\n",
    "    It is composed of an embedding and an outgoing MLP.\n",
    "\n",
    "    In a forward call:\n",
    "    1. First, we get the token embeddings.\n",
    "    2. Then, we average the token embeddings over the sequence dimension.\n",
    "    3. Finally, we apply the outgoing MLP.\n",
    "\n",
    "    Arguments\n",
    "    ---------\n",
    "    config : `dict`\n",
    "        Configuration dictionary. Required key-value pairs:\n",
    "        `\"device\"` : `str`\n",
    "            The device to store parameters on.\n",
    "        `\"ensemble_shape\"` : `tuple[int]`\n",
    "            The shape of the ensemble of affine transformations\n",
    "            the model represents.\n",
    "    embedding_dim : `int`\n",
    "        The number of embedding dimensions.\n",
    "    out_features : `int`\n",
    "        The number of output features.\n",
    "    vocabulary_size : `int`\n",
    "        The number of vocabulary entries.\n",
    "    hidden_layer_num : `int`, optional\n",
    "        If `hidden_layer_sizes` is not given, we create an outgoing MLP with\n",
    "        `hidden_layer_num` hidden layers of\n",
    "        `hidden_layer_size` dimensions.\n",
    "    hidden_layer_size : `int`, optional\n",
    "        If `hidden_layer_sizes` is not given, we create an outgoing MLP with\n",
    "        `hidden_layer_num` hidden layers of\n",
    "        `hidden_layer_size` dimensions.\n",
    "    hidden_layer_sizes: `Iterable[int]`, optional\n",
    "        If given, each entry gives a hidden layer with the given size\n",
    "        for the outgoing MLP.\n",
    "\n",
    "    Calling\n",
    "    -------\n",
    "    Instance calls require one positional argument:\n",
    "    batch : `dict`\n",
    "        The input data dictionary. Required keys:\n",
    "        `\"token_ids\"` : `torch.Tensor`\n",
    "            Tensor of token IDs, of shape\n",
    "            `batch_shape + (sequence_dim,)` or\n",
    "            `ensemble_shape + batch_shape + (sequence_dim,)`\n",
    "        `\"mask\"` : `torch.Tensor`\n",
    "            Mask showing which entries are not padding, of the same shape.\n",
    "    \"\"\"\n",
    "    def __init__(\n",
    "        self,\n",
    "        config: dict,\n",
    "        embedding_dim: int,\n",
    "        out_features: int,\n",
    "        vocabulary_size: int,\n",
    "        hidden_layer_num: Optional[int] = None,\n",
    "        hidden_layer_size: Optional[int] = None,\n",
    "        hidden_layer_sizes: Optional[Iterable[int]] = None\n",
    "    ):\n",
    "        raise NotImplementedError\n",
    "\n",
    "\n",
    "    def forward(self, batch: dict) -> torch.Tensor:\n",
    "        raise NotImplementedError\n",
    "    \n",
    "raise NotImplementedError"
   ]
  },
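  {
   "cell_type": "markdown",
   "id": "3f1a9c15",
   "metadata": {},
   "source": [
    "Step 3 of the forward call can be sketched as below, on made-up tensors: embeddings of shape `batch_shape + (sequence_dim, embedding_dim)` and a boolean mask of shape `batch_shape + (sequence_dim,)`. The helper name `example_masked_mean` is hypothetical. Clamping the token count at 1 is an extra choice of this sketch: it returns a zero vector instead of `nan` for a document with no tokens, which can happen, for instance, if a fixed vocabulary discards all of its tokens."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f1a9c16",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def example_masked_mean(embeddings: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:\n",
    "    # Add a trailing axis so the mask broadcasts over the embedding dimension\n",
    "    mask = mask[..., None].to(embeddings.dtype)\n",
    "    # Sum the non-padding embeddings and divide by the number of tokens\n",
    "    token_num = mask.sum(dim=-2).clamp(min=1)\n",
    "    return (embeddings * mask).sum(dim=-2) / token_num\n",
    "\n",
    "mm_embeddings = torch.tensor([[[1., 2.], [3., 4.], [100., 100.]]])  # 1 document, 3 positions\n",
    "mm_mask = torch.tensor([[True, True, False]])                        # the last position is padding\n",
    "print(example_masked_mean(mm_embeddings, mm_mask))  # tensor([[2., 3.]])"
   ]
  },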
  {
   "cell_type": "markdown",
   "id": "bdf31563",
   "metadata": {},
   "source": [
    "### `evaluate_model`\n",
    "\n",
    "Time to update the function `evaluate_model` you wrote in Notebook 0321! When you're done, evaluate the model you just created, on the tokenized validation dataset. This is also a great time to tune `minibatch_size_eval`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12d937d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_model(\n",
    "    config: dict,\n",
    "    dataset: dict,\n",
    "    get_metric: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],\n",
    "    model: torch.nn.Module,\n",
    "    indptr_key=\"indptr\",\n",
    "    target_key=\"target\"\n",
    ") -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    Evaluate a model on a supervised dataset.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    config : `dict`\n",
    "        Configuration dictionary. Required key-value pair:\n",
    "        `\"minibatch_size_eval\"` : `int`\n",
    "            Size of consecutive minibatches to take from the dataset.\n",
    "            To be set according to RAM or GPU memory capacity.\n",
    "    dataset : `dict`\n",
    "        The dataset to evaluate the model on.\n",
    "    get_metric : `Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`\n",
    "        Function to get the metric from a pair of\n",
    "        predicted and target value tensors.\n",
    "    model : `torch.nn.Module`\n",
    "        The model to evaluate.\n",
    "    indptr_key : `str`, optional\n",
    "        If the dataset has sequential entries,\n",
    "        then this is the key of the index pointer tensor.\n",
    "        Default: `\"indptr\"`.\n",
    "    target_key : `str`, optional\n",
    "        The key mapped to the target value tensor in the dataset.\n",
    "        Default: `\"target\"`\n",
    "    \"\"\"\n",
    "    raise NotImplementedError\n",
    "\n",
    "raise NotImplementedError"
   ]
  },
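  {
   "cell_type": "markdown",
   "id": "3f1a9c17",
   "metadata": {},
   "source": [
    "The cell below sketches the evaluation loop over consecutive minibatches of size `minibatch_size_eval`. To stay self-contained, it uses a non-ensembled, array-only toy dataset, a stand-in threshold model and an accuracy metric, all hypothetical; with sequential data, you would build each minibatch with `get_minibatch` from consecutive indices instead of slicing. It assumes the metric is a per-minibatch mean, so each minibatch is weighted by its size. `torch.no_grad()` skips storing activations for backpropagation, which is one reason the evaluation minibatch can be much larger than the training one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f1a9c18",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def example_evaluate(dataset: dict, model, get_metric, minibatch_size_eval: int) -> torch.Tensor:\n",
    "    dataset_size = len(dataset['target'])\n",
    "    metric_sum = 0.\n",
    "    with torch.no_grad():\n",
    "        for start in range(0, dataset_size, minibatch_size_eval):\n",
    "            end = min(start + minibatch_size_eval, dataset_size)\n",
    "            minibatch = {key: value[start:end] for key, value in dataset.items()}\n",
    "            predicted = model(minibatch)\n",
    "            # Weight the minibatch mean by the minibatch size\n",
    "            metric_sum = metric_sum + get_metric(predicted, minibatch['target']) * (end - start)\n",
    "    return metric_sum / dataset_size\n",
    "\n",
    "ev_dataset = {\n",
    "    'x': torch.tensor([-2., -1., 1., 2., 3.]),\n",
    "    'target': torch.tensor([False, True, True, True, False]),\n",
    "}\n",
    "\n",
    "def ev_model(minibatch: dict) -> torch.Tensor:\n",
    "    return minibatch['x'] > 0\n",
    "\n",
    "def ev_metric(predicted: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n",
    "    return (predicted == target).float().mean()\n",
    "\n",
    "print(example_evaluate(ev_dataset, ev_model, ev_metric, minibatch_size_eval=2))  # tensor(0.6000)"
   ]
  },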
  {
   "cell_type": "markdown",
   "id": "c2aefc5b",
   "metadata": {},
   "source": [
    "### `train_supervised`\n",
    "\n",
    "Update the function `pbt` you wrote in Notebook 0326 to use dictionary datasets with potentially sequencial entries. You can also use the functions `pbt_init` and `pbt_update` we refactored in Notebook 0328. When you're done, create an `AdamW` optimizer for the model, train it and print the best validation accuracy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2ee4fda",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_supervised(\n",
    "    config: dict,\n",
    "    dataset_train: dict,\n",
    "    dataset_valid: dict,\n",
    "    get_loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],\n",
    "    get_metric: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],\n",
    "    model: torch.nn.Module,\n",
    "    optimizer: Optimizer,\n",
    "    indptr_key=\"indptr\",\n",
    "    target_key=\"target\"\n",
    ") -> dict:\n",
    "    \"\"\"\n",
    "    Population-based training on a supervised learning task.\n",
    "    Tuned hyperparameters are given by raw values and transformations.\n",
    "    This way, the hyperparameters are perturbed by\n",
    "    additive noise on raw values.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    config : `dict`\n",
    "        Configuration dictionary. Required key-value pairs:\n",
    "        `\"ensemble_shape\"` : `tuple[int]`\n",
    "            Ensemble shape. We assume this is a 1-dimensional tuple\n",
    "            whose only entry is the population size.\n",
    "        `\"hyperparameter_raw_init_distributions\"` : `dict`\n",
    "            Dictionary that maps tuned hyperparameter names\n",
    "            to `torch.distributions.Distribution` of raw hyperparameter values.\n",
    "            Required keys:\n",
    "            `\"learning_rate\"`:\n",
    "                The learning rate of the optimizer.\n",
    "        `\"hyperparameter_raw_perturbs\"` : `dict`\n",
    "            Dictionary that maps tuned hyperparameter names\n",
    "            to `torch.distributions.Distribution` of additive noise.\n",
    "        `\"hyperparameter_transforms\"` : `dict`\n",
    "            Dictionary that maps tuned hyperparameter names\n",
    "            to transformations of raw hyperparameter values.\n",
    "        `\"improvement_threshold\"` : `float`\n",
    "            A new metric score has to be this much better\n",
    "            than the previous best to count as an improvement.\n",
    "        `\"minibatch_size\"` : `int`\n",
    "            Minibatch size to use in a training step.\n",
    "        `\"minibatch_size_eval\"` : `int`\n",
    "            Minibatch size to use in evaluation.\n",
    "            On CPU, should be about the same as `minibatch_size`.\n",
    "            On GPU, should be as big as possible without\n",
    "            incurring an Out of Memory error.\n",
    "        `\"pbt\"` : `bool`\n",
    "            Whether to perform PBT updates at validation steps.\n",
    "            If `False`, the algorithm just samples hyperparameters at start,\n",
    "            then keeps them constant.\n",
    "        `\"steps_num\"` : `int`\n",
    "            Maximum number of training steps.\n",
    "        `\"steps_without_improvement\"` : `int`\n",
    "            If the number of training steps without improvement\n",
    "            exceeds this value, then training is stopped.\n",
    "        `\"valid_interval\"` : `int`\n",
    "            Frequency of evaluations, measured in number of training steps.\n",
    "        `\"welch_confidence_level\"` : `float`\n",
    "            The confidence level in Welch's t-test\n",
    "            that is used in determining if a population member\n",
    "            is to be replaced by another member with perturbed hyperparameters.\n",
    "        `\"welch_sample_size\"` : `int`\n",
    "            The number of most recent validation metrics\n",
    "            used in Welch's t-test.\n",
    "    dataset_train : `dict`\n",
    "        The dataset to train the model on.\n",
    "    dataset_valid : `dict`\n",
    "        The dataset to evaluate the model on.\n",
    "    get_loss : `Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`\n",
    "        A function that maps a pair of predicted and target value tensors\n",
    "        to a tensor of losses per ensemble member.\n",
    "    get_metric : `Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`\n",
    "        A function that maps a pair of predicted and target value tensors\n",
    "        to a tensor of metrics per ensemble member.\n",
    "        We assume a greater metric is better.\n",
    "    model : `torch.nn.Module`\n",
    "        The model ensemble to tune.\n",
    "    optimizer : `Optimizer`\n",
    "        An optimizer that tracks the parameters of `model`.\n",
    "    indptr_key : `str`, optional\n",
    "        If the dataset has sequential entries,\n",
    "        then this is the key of the index pointer tensor.\n",
    "        Default: `\"indptr\"`.\n",
    "    target_key : `str`, optional\n",
    "        The key mapped to the target value tensor in the dataset.\n",
    "        Default: `\"target\"`.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    An output dictionary with the following key-value pairs:\n",
    "        `\"source mask\"` : `torch.Tensor`\n",
    "            The source masks of population members\n",
    "            that were replaced by other members in a PBT update.\n",
    "        `\"target indices\"` : `torch.Tensor`\n",
    "            The indices of the population members\n",
    "            that the masked members were replaced with.\n",
    "        `\"training loss\"` : `torch.Tensor`\n",
    "            The training losses at evaluation steps.\n",
    "        `\"training metric\"` : `torch.Tensor`\n",
    "            The training metrics at evaluation steps.\n",
    "        `\"validation loss\"` : `torch.Tensor`\n",
    "            The validation losses at evaluation steps.\n",
    "        `\"validation metric\"` : `torch.Tensor`\n",
    "            The validation metrics at evaluation steps.\n",
    "\n",
    "        In addition, for each tuned hyperparameter name,\n",
    "        we include a `torch.Tensor` of values per update.\n",
    "    \"\"\"\n",
    "    raise NotImplementedError\n",
    "\n",
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "57009af4",
   "metadata": {},
   "source": [
    "Talk about high accuracy!\n",
    "\n",
    "Delete the model and the optimizer, then train a model with 1 embedding dimension and 0 hidden layers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce45957e",
   "metadata": {},
   "outputs": [],
   "source": [
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2504609e",
   "metadata": {},
   "source": [
    "Still doing great! Moreover, we got a super interpretable model: logistic regression on 1-dimensional word vectors.\n",
    "\n",
    "Using the log dictionary, get the ID of the population member with the best latest validation binary accuracy.\n",
    "\n",
    "Get the word vectors of this population member from the model.\n",
    "\n",
    "Using the vocabulary, print the 100 words with the smallest and the 100 words with the largest word vector values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2645b0c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7edf4491",
   "metadata": {},
   "source": [
    "Makes sense, eh?\n",
    "\n",
    "Note that we trained the models from scratch: they had no linguistic knowledge before our training."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4f5f9987",
   "metadata": {},
   "source": [
    "# Dataset References\n",
    "\n",
    "[1] Elvis Saravia, Hsien-Chi Toby Liu, Yen-Hao Huang, Junlin Wu and Yi-Shin Chen: *CARER: Contextualized Affect Representations for Emotion Recognition*, 2018. Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing, pages 3687--3697. https://aclanthology.org/D18-1404/ https://huggingface.co/datasets/dair-ai/emotion"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6c26d178",
   "metadata": {},
   "source": [
    "# License\n",
    "\n",
    "This work is licensed under CC BY-NC-SA 4.0. To view a copy of this license, visit http://creativecommons.org/licenses/by-nc-sa/4.0/"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dml",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
