{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Ensemble Training\n",
    "\n",
    "## Setup\n",
    "\n",
    "### Imports\n",
    "\n",
    "1. From `collections` import `defaultdict`. This will help us manage all those metric lists in the training loop.\n",
    "2. From `collections.abc` import `Generator` and `Sequence`. These will serve as type hints.\n",
    "3. Import the usual collection of `plt`, `torch`, `F` and `tqdm`.\n",
    "4. Also import the package `scipy`. We'll need this to get confidence interval bounds.\n",
    "5. Put the function `load_preprocessed_dataset` you created in notebook 0219 into a file and import it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "from collections.abc import Generator, Sequence\n",
    "import matplotlib.pyplot as plt\n",
    "import scipy.stats\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import tqdm\n",
    "\n",
    "from util_0221 import load_preprocessed_dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Constants\n",
    "\n",
    "Create a configuration dictionary with the following keys:\n",
    "- `\"dataset_preprocessed_path\"`: `str`  \n",
    "    * If you saved the preprocessed MNIST dataset in notebook 0216,\n",
    "        this should be the path to that file.\n",
    "    * Otherwise, please run notebook 0216 with solutions\n",
    "        so that the preprocessed dataset is saved into `data/mnist.pt`.\n",
    "- `\"device\"`: `torch.device | int | str`  \n",
    "    The device to store tensors on.\n",
    "- `\"ensemble_shape\"`: `tuple[int, ...]`  \n",
    "    We will add the ensemble dimensions to the left of the weight tensor's shape.\n",
    "    That is, in MNIST, with 785 features and 10 labels, the shape\n",
    "    of the weight tensor will be `ensemble_shape + (785, 10)`.\n",
    "    For now, let's create one ensemble dimension of size 10.\n",
    "    That is, set this to `(10,)`. This will make the shape of the weight tensor `(10, 785, 10)`.\n",
    "- `\"learning_rate\"`: `float`  \n",
    "    Set this to `1`.\n",
    "- `\"minibatch_size\"`: `int`  \n",
    "    Set this to `256`.\n",
    "- `\"seed\"`: `int`  \n",
    "    This is for reproducible experiments. Insert any integer.\n",
    "- `\"steps_num\"`: `int`  \n",
    "    Set this to `1000`.\n",
    "- `\"valid_interval\"`: `int`  \n",
    "    The frequency of model evaluation during training,\n",
    "    measured in train steps. Set this to `100`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = {\n",
    "    \"dataset_preprocessed_path\": \"data/mnist.pt\",\n",
    "    \"device\": \"cuda\",\n",
    "    \"ensemble_shape\": (10,),\n",
    "    \"learning_rate\": 1,\n",
    "    \"minibatch_size\": 256,\n",
    "    \"seed\": 1,\n",
    "    \"steps_num\": 1000,\n",
    "    \"valid_interval\": 100\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set `torch.manual_seed` to the value of the `\"seed\"` key in the configuration dictionary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(config[\"seed\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load the preprocessed MNIST dataset with `load_preprocessed_dataset`.\n",
    "Check tensor shapes and devices."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "(\n",
    "    (train_features, train_labels),\n",
    "    (valid_features, valid_labels),\n",
    "    (test_features, test_labels)\n",
    ") = load_preprocessed_dataset(\n",
    "    config\n",
    ")\n",
    "\n",
    "for t in (\n",
    "    train_features, train_labels,\n",
    "    valid_features, valid_labels,\n",
    "    test_features, test_labels\n",
    "):\n",
    "    print(t.shape, t.device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Ensemble Training\n",
    "\n",
    "Now, we shall modify the code from notebook 0219 to allow training ensembles of logistic regression models. Note that the functions we write will allow an arbitrary `ensemble_shape`, not just 1-dimensional ensembles. We generalize this way because\n",
    "\n",
    "1. the code is no more difficult to write this way, and\n",
    "2. we'll need this generality next week for efficient hyperparameter tuning."
   ]
  },
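  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick illustration of why arbitrary ensemble shapes cost nothing extra, here is a minimal sketch (not one of the exercises), assuming a logistic model whose weights have shape `ensemble_shape + (feature_dim, label_num)`. A single `@` broadcasts the shared feature matrix against every ensemble member, even for a 2-dimensional ensemble. The sizes are arbitrary toy values.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "ensemble_shape = (2, 3)  # a 2-dimensional ensemble of 6 models\n",
    "n, feature_dim, label_num = 5, 4, 3\n",
    "\n",
    "features = torch.randn(n, feature_dim)  # shared by all members\n",
    "weights = torch.randn(ensemble_shape + (feature_dim, label_num))\n",
    "\n",
    "# `@` broadcasts the leading (batch) dimensions, so the\n",
    "# (n, feature_dim) matrix is multiplied with every member's weights.\n",
    "logits = features @ weights\n",
    "print(logits.shape)  # torch.Size([2, 3, 5, 3])\n",
    "\n",
    "# Same result as looping over the members one by one.\n",
    "for i in range(ensemble_shape[0]):\n",
    "    for j in range(ensemble_shape[1]):\n",
    "        assert torch.allclose(logits[i, j], features @ weights[i, j], atol=1e-6)\n",
    "```"
   ]
  },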
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Ensemble Dataloader\n",
    "\n",
    "1. You can generate a tensor of any shape with entries drawn independently from the uniform distribution $\\mathscr U([0, 1))$ by `torch.rand`.\n",
    "    1. Pass the shape of the tensor you want as the positional argument and\n",
    "    2. specify the device you want the tensor on by the keyword argument `device`.\n",
    "\n",
    "2. You can get shuffled indices by calling the `argsort` method of this tensor: since the entries are i.i.d. and (with overwhelming probability) distinct, the permutation that sorts them is uniformly random. Use the keyword argument `dim` to perform the operation along the last dimension, so that each ensemble member gets its own independent shuffle.\n",
    "\n",
    "Try this out first: write the function below and print its output with ensemble shape `(2, 4)` and dataset size `5`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_shuffled_indices(\n",
    "    dataset_size: int,\n",
    "    device=\"cpu\",\n",
    "    ensemble_shape=(),\n",
    ") -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    Get a tensor of a batch of shuffles of indices `0,...,dataset_size - 1`.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    dataset_size : int\n",
    "        The size of the dataset the indices of which to shuffle\n",
    "    device : int | str | torch.device, optional\n",
    "        The device to store the resulting tensor on. Default: \"cpu\"\n",
    "    ensemble_shape : tuple[int], optional\n",
    "        The batch shape of the shuffled index tensors. Default: ()\n",
    "    \"\"\"\n",
    "    total_shape = ensemble_shape + (dataset_size,)\n",
    "    uniform = torch.rand(\n",
    "        total_shape,\n",
    "        device=device\n",
    "    )\n",
    "    indices = uniform.argsort(dim=-1)\n",
    "\n",
    "    return indices\n",
    "\n",
    "get_shuffled_indices(5, ensemble_shape=(2, 4))"
   ]
  },
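  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The vectorized argsort trick can be cross-checked against a per-member loop. The sketch below is illustrative only: `shuffled_indices_argsort` and `shuffled_indices_loop` are hypothetical helper names, and the loop version (one `torch.randperm` call per member) is a swapped-in reference, not the method used in this notebook. Both should return a permutation of `0, ..., dataset_size - 1` in every ensemble slot; the argsort version does this without a Python loop over members.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "import torch\n",
    "\n",
    "def shuffled_indices_argsort(dataset_size, ensemble_shape=()):\n",
    "    # Sorting i.i.d. uniform keys gives a uniformly random permutation\n",
    "    # (ties among float keys are possible but extremely unlikely).\n",
    "    keys = torch.rand(ensemble_shape + (dataset_size,))\n",
    "    return keys.argsort(dim=-1)\n",
    "\n",
    "def shuffled_indices_loop(dataset_size, ensemble_shape=()):\n",
    "    # Reference: one `torch.randperm` call per ensemble member.\n",
    "    members = math.prod(ensemble_shape)  # math.prod(()) == 1\n",
    "    perms = [torch.randperm(dataset_size) for _ in range(members)]\n",
    "    return torch.stack(perms).reshape(ensemble_shape + (dataset_size,))\n",
    "\n",
    "torch.manual_seed(0)\n",
    "for fn in (shuffled_indices_argsort, shuffled_indices_loop):\n",
    "    indices = fn(5, ensemble_shape=(2, 4))\n",
    "    # Sorting each shuffle must give back 0, ..., 4 for every member.\n",
    "    assert torch.equal(indices.sort(dim=-1).values, torch.arange(5).expand(2, 4, 5))\n",
    "    print(fn.__name__, tuple(indices.shape))\n",
    "```"
   ]
  },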
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we're ready to yield indices! Recall that in the single model case, the index generator yielded slices of the tensor of shuffled indices. We can do the same with a tensor of batched shuffled indices: just start the index expression with an ellipsis `...`, so that the slice applies to the last dimension while all ensemble dimensions are kept. See:  \n",
    "https://numpy.org/doc/stable/user/basics.indexing.html#dimensional-indexing-tools\n",
    "\n",
    "Write the function below and print the output of 3 iterations on the generator it returns with dataset size `10`, ensemble shape `(3,)` and minibatch size `4`. Don't forget to use `get_shuffled_indices` that you already wrote!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_random_reshuffler(\n",
    "    dataset_size: int,\n",
    "    minibatch_size: int,\n",
    "    device=\"cpu\",\n",
    "    ensemble_shape=()\n",
    ") -> Generator[torch.Tensor]:\n",
    "    \"\"\"\n",
    "    Generate minibatch indices for a random shuffling dataloader.\n",
    "    Supports arbitrary ensemble shapes.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    dataset_size : int\n",
    "        The size of the dataset to yield batches of minibatch indices for.\n",
    "    minibatch_size : int\n",
    "        The minibatch size.\n",
    "    device : int | str | torch.device, optional\n",
    "        The device to store the index tensors on. Default: \"cpu\"\n",
    "    ensemble_shape : tuple[int], optional\n",
    "        The ensemble shape of the minibatch indices. Default: ()\n",
    "    \"\"\"\n",
    "    q, r = divmod(dataset_size, minibatch_size)\n",
    "    minibatch_num = q + min(1, r)\n",
    "    minibatch_index = minibatch_num\n",
    "    while True:\n",
    "        if minibatch_index == minibatch_num:\n",
    "            minibatch_index = 0\n",
    "            shuffled_indices = get_shuffled_indices(\n",
    "                dataset_size,\n",
    "                device=device,\n",
    "                ensemble_shape=ensemble_shape\n",
    "            )\n",
    "\n",
    "        start = minibatch_index * minibatch_size\n",
    "        yield shuffled_indices[..., start:start + minibatch_size]\n",
    "\n",
    "        minibatch_index += 1\n",
    "\n",
    "random_reshuffler = get_random_reshuffler(\n",
    "    10,\n",
    "    4,\n",
    "    ensemble_shape=(3,)\n",
    ")\n",
    "for _ in range(3):\n",
    "    print(next(random_reshuffler))"
   ]
  },
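  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One property worth checking: within an epoch of $\\lceil \\text{dataset\\_size} / \\text{minibatch\\_size} \\rceil$ minibatches, each ensemble member sees every index exactly once, and the last minibatch is smaller when the division leaves a remainder. The generator below, `random_reshuffler`, is a hypothetical condensed re-implementation of the same logic (a `for` loop replaces the manual counter), written only for this check.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "def random_reshuffler(dataset_size, minibatch_size, ensemble_shape=()):\n",
    "    # Same logic as `get_random_reshuffler` above, condensed.\n",
    "    minibatch_num = -(-dataset_size // minibatch_size)  # ceiling division\n",
    "    while True:\n",
    "        shuffled = torch.rand(ensemble_shape + (dataset_size,)).argsort(dim=-1)\n",
    "        for k in range(minibatch_num):\n",
    "            start = k * minibatch_size\n",
    "            yield shuffled[..., start:start + minibatch_size]\n",
    "\n",
    "torch.manual_seed(0)\n",
    "reshuffler = random_reshuffler(10, 4, ensemble_shape=(3,))\n",
    "epoch = [next(reshuffler) for _ in range(3)]  # ceil(10 / 4) = 3 minibatches\n",
    "print([m.shape for m in epoch])\n",
    "# [torch.Size([3, 4]), torch.Size([3, 4]), torch.Size([3, 2])]\n",
    "\n",
    "# Within one epoch, each member sees every index exactly once.\n",
    "covered = torch.cat(epoch, dim=-1).sort(dim=-1).values\n",
    "assert torch.equal(covered, torch.arange(10).expand(3, 10))\n",
    "```"
   ]
  },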
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now you can create a dataloader just by iterating over the output of `get_random_reshuffler` and yielding the pair of features and labels\n",
    "indexed by the indices you get.\n",
    "\n",
    "Write the `get_dataloader_random_reshuffle` function, get a train dataloader and iterate over it 5 times while printing the shape, device and dtype of the tensors you get."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_dataloader_random_reshuffle(\n",
    "    config: dict,\n",
    "    features: torch.Tensor,\n",
    "    labels: torch.Tensor\n",
    ") -> Generator[tuple[torch.Tensor, torch.Tensor]]:\n",
    "    \"\"\"\n",
    "    Given a feature and a label tensor,\n",
    "    creates a random reshuffling (without replacement) dataloader\n",
    "    that yields pairs `minibatch_features, minibatch_labels` indefinitely.\n",
    "    Supports arbitrary ensemble shapes.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    config : dict\n",
    "        Configuration dictionary. Required keys:\n",
    "        ensemble_shape : tuple[int, ...]\n",
    "            The ensemble shape of the outputs.\n",
    "        minibatch_size : int\n",
    "            The size of the minibatches.\n",
    "    features : torch.Tensor\n",
    "        Tensor of dataset features.\n",
    "        We assume that the first dimension is the batch dimension.\n",
    "    labels : torch.Tensor\n",
    "        Tensor of dataset labels.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    A generator of tuples `minibatch_features, minibatch_labels`.\n",
    "    \"\"\"\n",
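    "    for indices in get_random_reshuffler(\n",
    "        len(labels),\n",
    "        config[\"minibatch_size\"],\n",
    "        # Generate the indices on the data's device to avoid host-device copies.\n",
    "        device=labels.device,\n",
    "        ensemble_shape=config[\"ensemble_shape\"]\n",
    "    ):\n",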
    "        yield features[indices], labels[indices]\n",
    "\n",
    "train_dataloader = get_dataloader_random_reshuffle(\n",
    "    config,\n",
    "    train_features,\n",
    "    train_labels\n",
    ")\n",
    "\n",
    "for _ in range(5):\n",
    "    for t in next(train_dataloader):\n",
    "        print(t.shape, t.device, t.dtype)"
   ]
  },
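  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Why does `features[indices]` produce batched minibatches? Indexing the first dimension of a tensor with an integer tensor replaces that dimension with the full shape of the index tensor. A small sketch with random toy data (the sizes are arbitrary assumptions, chosen to mimic MNIST's 785 features):\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "dataset_size, feature_dim = 10, 785\n",
    "features = torch.randn(dataset_size, feature_dim)\n",
    "labels = torch.randint(0, 10, (dataset_size,))\n",
    "\n",
    "# Index tensor of shape `ensemble_shape + (minibatch_size,)` = (3, 4).\n",
    "indices = torch.rand(3, dataset_size).argsort(dim=-1)[..., :4]\n",
    "\n",
    "# The dataset dimension is replaced by the shape of `indices`.\n",
    "minibatch_features = features[indices]\n",
    "minibatch_labels = labels[indices]\n",
    "print(minibatch_features.shape)  # torch.Size([3, 4, 785])\n",
    "print(minibatch_labels.shape)  # torch.Size([3, 4])\n",
    "\n",
    "# Each entry is the dataset row its index points to.\n",
    "assert torch.equal(minibatch_features[1, 2], features[indices[1, 2]])\n",
    "```"
   ]
  },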
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Ensemble Accuracy\n",
    "\n",
    "Recall from the broadcasted matrix multiplication use case that we get the logit tensors with shape $(e_1,\\dotsc,e_s,n,c)$ where\n",
    "1. $(e_1,\\dotsc,e_s)$ is the ensemble shape,\n",
    "2. $n$ is the dataset size or a minibatch size and\n",
    "3. $c$ is the number of labels.\n",
    "\n",
    "We also receive a label tensor of shape either\n",
    "1. $(n,)$ or\n",
    "2. $(e_1,\\dotsc,e_s,n)$.\n",
    "\n",
    "After producing the logits as above, we first need to extract from them the predicted labels of each member of the ensemble. This will be a tensor of shape $(e_1,\\dotsc,e_s,n)$. Similarly to the single model case, you can achieve this using the `argmax` method on the logit tensor. But this time, use the `dim` keyword argument to specify that you want to perform the operation along the last dimension.\n",
    "\n",
    "Now you can perform the elementwise equality operation `==` between the true label and the predicted label tensors. By broadcasting (think about this!) you will get a Boolean tensor of shape $(e_1,\\dotsc,e_s,n)$. After transforming it to a floating point tensor, you can again take the mean, but this time only along the last dimension. You should end up with a tensor of accuracies whose shape is the ensemble shape $(e_1,\\dotsc,e_s)$.\n",
    "\n",
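    "Write the function `get_accuracy` like this. After writing `get_accuracy`, calculate the validation accuracy of a logistic model with\n",
    "1. constant zero weight tensor of shape `ensemble_shape + (feature_dim, label_num)` and optionally\n",
    "2. constant zero bias tensor of shape `ensemble_shape + (1, label_num)` (the dimension of size 1 broadcasts over the dataset dimension).\n",
    "\n",
    "You should get a tensor of accuracies close to random guessing (roughly 0.1): all logits are tied, so `argmax` returns the first index, label 0, and every ensemble member's accuracy equals the proportion of zeros among the validation labels.\n",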
    "\n",
    "Also, get a minibatch from the train dataloader and calculate the accuracy on that.\n",
    "\n",
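    "Note that we test both as they are different cases: train minibatch labels have shape $(e_1,\\dotsc,e_s,n_\\text{minibatch})$, while validation labels have shape $(n_\\text{valid},)$. Both logit tensors carry the ensemble shape: $(e_1,\\dotsc,e_s,n_\\text{minibatch},c)$ and $(e_1,\\dotsc,e_s,n_\\text{valid},c)$, respectively."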
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_accuracy(\n",
    "    logits: torch.Tensor,\n",
    "    labels: torch.Tensor,\n",
    ") -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    Given logits output by a classification model, calculate the accuracy.\n",
    "    Supports model ensembles of arbitrary ensemble shape.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    logits : torch.Tensor\n",
    "        Logit tensor of shape\n",
    "        `ensemble_shape + (dataset_size, label_num)`.\n",
    "    labels : torch.Tensor\n",
    "        Label tensor of shape \n",
    "        `(dataset_size,)` or\n",
    "        `ensemble_shape + (dataset_size,)`.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The tensor of accuracies of shape `ensemble_shape`.\n",
    "    \"\"\"\n",
    "    labels_predict = logits.argmax(dim=-1)\n",
    "    accuracy = (labels == labels_predict).to(torch.float32).mean(dim=-1)\n",
    "\n",
    "    return accuracy\n",
    "\n",
    "minibatch_features, minibatch_labels = next(train_dataloader)\n",
    "\n",
    "# For this example, we don't need to track gradients of weights.\n",
    "weights = torch.zeros(\n",
    "    config[\"ensemble_shape\"] + (valid_features.shape[1], 10),\n",
    "    device=config[\"device\"]\n",
    ")\n",
    "bias = torch.zeros_like(weights[..., 0:1, :])  # shape ensemble_shape + (1, 10)\n",
    "\n",
    "# Thus, we don't need to `detach` weights either.\n",
    "minibatch_logits = minibatch_features @ weights + bias\n",
    "valid_logits = valid_features @ weights + bias\n",
    "\n",
    "print(\n",
    "    get_accuracy(\n",
    "        minibatch_logits,\n",
    "        minibatch_labels,\n",
    "    ),\n",
    "    get_accuracy(\n",
    "        valid_logits,\n",
    "        valid_labels,\n",
    "    ),\n",
    "    sep='\\n'\n",
    ")"
   ]
  },
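  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is a hand-checkable example of the broadcasting in `get_accuracy`, with made-up logits for a 1-dimensional ensemble of two models. Both label layouts are covered: a shared label vector (as for validation) and per-member labels (as for train minibatches).\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "# Ensemble of 2 models, 3 examples, 2 classes: logits of shape (2, 3, 2).\n",
    "logits = torch.tensor([\n",
    "    [[2.0, 1.0], [0.0, 3.0], [1.0, 0.0]],  # member 0 predicts 0, 1, 0\n",
    "    [[0.0, 1.0], [0.0, 3.0], [5.0, 0.0]],  # member 1 predicts 1, 1, 0\n",
    "])\n",
    "shared_labels = torch.tensor([0, 1, 1])  # shape (3,), as for validation\n",
    "member_labels = torch.tensor([\n",
    "    [0, 1, 1],  # shape (2, 3), as for train minibatches\n",
    "    [1, 1, 0],\n",
    "])\n",
    "\n",
    "predictions = logits.argmax(dim=-1)  # shape (2, 3)\n",
    "# Shape (3,) broadcasts against (2, 3) as if it were (1, 3).\n",
    "shared_accuracy = (shared_labels == predictions).float().mean(dim=-1)\n",
    "member_accuracy = (member_labels == predictions).float().mean(dim=-1)\n",
    "print(shared_accuracy)  # tensor([0.6667, 0.3333])\n",
    "print(member_accuracy)  # tensor([0.6667, 1.0000])\n",
    "```"
   ]
  },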
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Ensemble Cross-Entropy\n",
    "\n",
    "Let's write a function that calculates the cross-entropy of input potentially coming from an ensemble.\n",
    "\n",
    "Perusing the documentation:  \n",
    "https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html  \n",
    "we see that\n",
    "1. The logit and the label tensors can have additional batch dimensions, but these come at the end of the shape, not at the beginning: logits of shape $(n, c, d_1,\\dotsc,d_k)$ go with labels of shape $(n, d_1,\\dotsc,d_k)$. To move dimensions, you can use the `movedim` method of the tensors:  \n",
    "https://pytorch.org/docs/stable/generated/torch.Tensor.movedim.html\n",
    "2. The logit and label tensors need to have the same batch shape. So before step 1, you need to account for the case where the logits have the ensemble shape while the labels do not. To that end, you can use the `broadcast_to` method of the labels tensor:  \n",
    "https://pytorch.org/docs/stable/generated/torch.Tensor.broadcast_to.html\n",
    "3. By default, `F.cross_entropy` returns the cross-entropy averaged over all entries of the logit and label tensors. But we want to take the mean only along the dataset dimension, not over the ensemble dimensions.\n",
    "    1. Thus, we can turn off taking the mean by setting the `reduction` keyword argument of `F.cross_entropy` to `\"none\"`.\n",
    "    2. Now, the output will be a tensor of shape $(n, e_1,\\dotsc,e_s)$. To take the mean along the leftmost dimension, you can set the `dim` keyword argument of the `mean` method accordingly.\n",
    "\n",
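    "After writing `get_cross_entropy`, calculate the validation cross-entropy of a logistic model with constant zero weight tensor of shape `ensemble_shape + (feature_dim, label_num)`. You should get a tensor of constant values $\\log(10) \\approx 2.3026$: all logits are equal, so the predicted distribution is uniform over the 10 labels and each example contributes $-\\log(1/10)$.\n",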
    "\n",
    "Also, get a minibatch from the train dataloader and calculate the cross-entropy on that."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_cross_entropy(\n",
    "    logits: torch.Tensor,\n",
    "    labels: torch.Tensor,\n",
    ") -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    Given logits output by a classification model, \n",
    "    calculate the cross-entropy.\n",
    "    Supports model ensembles of arbitrary ensemble shape.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    logits : torch.Tensor\n",
    "        Logit tensor of shape\n",
    "        `ensemble_shape + (dataset_size, label_num)`.\n",
    "    labels : torch.Tensor\n",
    "        Label tensor of shape \n",
    "        `(dataset_size,)` or\n",
    "        `ensemble_shape + (dataset_size,)`.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The tensor of cross-entropies of shape `ensemble_shape`.\n",
    "    \"\"\"\n",
    "    return F.cross_entropy(\n",
    "        logits.movedim((-2, -1), (0, 1)),\n",
    "        labels.broadcast_to(logits.shape[:-1]).movedim(-1, 0),\n",
    "        reduction=\"none\"\n",
    "    ).mean(dim=0)\n",
    "\n",
    "print(\n",
    "    get_cross_entropy(\n",
    "        minibatch_logits,\n",
    "        minibatch_labels,\n",
    "    ),\n",
    "    get_cross_entropy(\n",
    "        valid_logits,\n",
    "        valid_labels,\n",
    "    ),\n",
    "    sep='\\n'\n",
    ")"
   ]
  },
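  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to gain confidence in the `movedim`/`broadcast_to` bookkeeping is to compare against a per-member loop. This sketch restates the same recipe under a hypothetical name, `ensemble_cross_entropy`, so that it runs on its own with random toy data; it also checks the $\\log(10)$ value expected for an all-zero model.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "def ensemble_cross_entropy(logits, labels):\n",
    "    # Same recipe as `get_cross_entropy`: move the dataset and class\n",
    "    # dimensions to the front, keep per-entry losses, average over dim 0.\n",
    "    return F.cross_entropy(\n",
    "        logits.movedim((-2, -1), (0, 1)),\n",
    "        labels.broadcast_to(logits.shape[:-1]).movedim(-1, 0),\n",
    "        reduction=\"none\",\n",
    "    ).mean(dim=0)\n",
    "\n",
    "torch.manual_seed(0)\n",
    "ensemble_shape, n, c = (2, 3), 7, 10\n",
    "logits = torch.randn(ensemble_shape + (n, c))\n",
    "labels = torch.randint(0, c, (n,))  # shared labels, shape (n,)\n",
    "\n",
    "loss = ensemble_cross_entropy(logits, labels)\n",
    "print(loss.shape)  # torch.Size([2, 3])\n",
    "\n",
    "# Reference: the ordinary mean-reduced cross-entropy of each member.\n",
    "for i in range(ensemble_shape[0]):\n",
    "    for j in range(ensemble_shape[1]):\n",
    "        assert torch.isclose(loss[i, j], F.cross_entropy(logits[i, j], labels))\n",
    "\n",
    "# All-zero logits give uniform predictions, hence a loss of log(c).\n",
    "zero_loss = ensemble_cross_entropy(torch.zeros(ensemble_shape + (n, c)), labels)\n",
    "assert torch.allclose(zero_loss, torch.full(ensemble_shape, math.log(c)))\n",
    "```"
   ]
  },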
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Ensemble Training Loop\n",
    "\n",
    "At this point, we have made most of the changes required to write a function that trains an ensemble of models at once. Let's write that training loop!\n",
    "\n",
    "1. To make the evaluation metric lists easier to handle, we shall store them in a `collections.defaultdict(list)`. With this, if you try to access a key of the dictionary that it does not yet have, the key gets the default value of an empty list. So you can append an evaluation metric to a value of the dictionary without prior initialization.\n",
    "2. Note that the metrics we get from `get_accuracy` and `get_cross_entropy` are not scalars, but tensors of shape `ensemble_shape`.\n",
    "    1. This means that when you append these to the lists, you still need to call their `cpu` method, but not the `item` method.\n",
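    "    2. Also, the train loss is going to have this shape. So in order to call its `backward` method, you first need to call its `sum` method. This is harmless: each member's loss depends only on its own slice of the weights, so the gradient of the sum with respect to that slice is exactly the gradient of that member's loss.\n",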
    "    3. To accumulate train accuracies and cross-entropies, instead of dedicated lists, you can accumulate into tensors:\n",
    "        1. At initialization, create zero-valued tensors of shape `ensemble_shape` for this.\n",
    "        2. Also, keep the entry number counter, initialized at 0.\n",
    "        3. At each train step, add to these tensors minibatch size times the minibatch accuracy and cross-entropy, respectively.\n",
    "        4. Also, add the minibatch size to the entry number counter.\n",
    "        5. At evaluation, append to the train accuracy and cross-entropy lists (to be accessed at keys of the output dictionary) the quotients of the accumulator tensors by the entry number count.\n",
    "        6. Afterwards, reset the accumulator tensors and the entry number counter to 0.\n",
    "3. At the end of training:\n",
    "    1. Change the values of the output dictionary that are lists of tensors to the output of `torch.stack` called on them. This will make it easier to get sample means and confidence intervals.\n",
    "    2. Set the `weights` key of the output dictionary to the final weights and, if a bias is used, the `bias` key to the final biases (both detached from the autograd graph).\n",
    "\n",
    "Write `train_logistic_regression` and run it. Print the train and validation accuracy and cross-entropy tensors you got at the last evaluation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_logistic_regression(\n",
    "    config: dict,\n",
    "    label_num: int,\n",
    "    train_dataloader: Generator[tuple[torch.Tensor, torch.Tensor]],\n",
    "    valid_features: torch.Tensor,\n",
    "    valid_labels: torch.Tensor,\n",
    "    use_bias=True\n",
    ") -> dict:\n",
    "    \"\"\"\n",
    "    Train a logistic regression model on a classification task.\n",
    "    Supports model ensembles of arbitrary shape.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    config : dict\n",
    "        Configuration dictionary. Required keys:\n",
    "        ensemble_shape : tuple[int, ...]\n",
    "            The shape of the model ensemble.\n",
    "        learning_rate : float\n",
    "            The learning rate of the SGD optimization.\n",
    "        steps_num : int\n",
    "            The number of training steps to take.\n",
    "        valid_interval : int\n",
    "            The frequency of evaluations,\n",
    "            measured in the number of train steps.\n",
    "    label_num : int\n",
    "        The number of distinct labels in the classification task.\n",
    "    train_dataloader : Generator[tuple[torch.Tensor, torch.Tensor]]\n",
    "        A training minibatch dataloader, that yields pairs of\n",
    "        feature and label tensors indefinitely.\n",
    "        We assume that these have shape\n",
    "        `ensemble_shape + (minibatch_size, feature_dim)`\n",
    "        and `ensemble_shape + (minibatch_size,)`\n",
    "        respectively.\n",
    "    valid_features : torch.Tensor\n",
    "        Validation feature matrix.\n",
    "    valid_labels : torch.Tensor\n",
    "        Validation label vector.\n",
    "    use_bias : bool, optional\n",
    "        Whether to use a bias vector in the logistic regression model.\n",
    "        Default: `True`\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    An output dictionary with the following keys:\n",
    "        training accuracy : torch.Tensor\n",
    "            The tensor of training accuracies, of shape\n",
    "            `(evaluation_num,) + ensemble_shape`.\n",
    "        training cross-entropy : torch.Tensor\n",
    "            The tensor of training cross-entropies, of shape\n",
    "            `(evaluation_num,) + ensemble_shape`.\n",
    "        training steps : list[int]\n",
    "            The list of the number of training steps at each evaluation.\n",
    "        validation accuracy : torch.Tensor\n",
    "            The tensor of validation accuracies, of shape\n",
    "            `(evaluation_num,) + ensemble_shape`.\n",
    "        validation cross-entropy : torch.Tensor\n",
    "            The tensor of validation cross-entropies, of shape\n",
    "            `(evaluation_num,) + ensemble_shape`.\n",
    "        weights : torch.Tensor\n",
    "            The logistic regression weights at the end of training.\n",
    "        bias : torch.Tensor, optional\n",
    "            The logistic regression biases at the end of training, if `use_bias`.\n",
    "    \"\"\"\n",
    "    device = valid_features.device\n",
    "    features_dtype = valid_features.dtype\n",
    "    output = defaultdict(list)\n",
    "\n",
    "    train_accuracies_step = torch.zeros(\n",
    "        config[\"ensemble_shape\"],\n",
    "        device=device,\n",
    "        dtype=features_dtype\n",
    "    )\n",
    "    train_entries = 0\n",
    "    train_losses_step = torch.zeros(\n",
    "        config[\"ensemble_shape\"],\n",
    "        device=device,\n",
    "        dtype=features_dtype\n",
    "    )\n",
    "\n",
    "    progress_bar = tqdm.trange(config[\"steps_num\"])\n",
    "    step_id = 0\n",
    "    weights = torch.zeros(\n",
    "        config[\"ensemble_shape\"] + (valid_features.shape[1], label_num),\n",
    "        device=device,\n",
    "        dtype=features_dtype,\n",
    "        requires_grad=True\n",
    "    )\n",
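    "    parameters = [weights]\n",
    "    if use_bias:\n",
    "        # Shape `ensemble_shape + (1, label_num)` broadcasts over the minibatch.\n",
    "        # The bias must be a trainable leaf tensor registered with the optimizer.\n",
    "        bias = torch.zeros(\n",
    "            config[\"ensemble_shape\"] + (1, label_num),\n",
    "            device=device,\n",
    "            dtype=features_dtype,\n",
    "            requires_grad=True\n",
    "        )\n",
    "        parameters.append(bias)\n",
    "\n",
    "    optimizer = torch.optim.SGD(parameters, lr=config[\"learning_rate\"])\n",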
    "\n",
    "    for minibatch_features, minibatch_labels in train_dataloader:\n",
    "        minibatch_size = minibatch_labels.shape[-1]\n",
    "        optimizer.zero_grad()\n",
    "        logits = minibatch_features @ weights\n",
    "        if use_bias:\n",
    "            logits = logits + bias\n",
    "\n",
    "        train_accuracies_step += get_accuracy(\n",
    "            logits.detach(),\n",
    "            minibatch_labels,\n",
    "        ) * minibatch_size\n",
    "        loss = get_cross_entropy(\n",
    "            logits,\n",
    "            minibatch_labels,\n",
    "        )\n",
    "        loss.sum().backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        train_losses_step += loss.detach() * minibatch_size\n",
    "        train_entries += minibatch_size\n",
    "\n",
    "        progress_bar.update()\n",
    "        step_id += 1\n",
    "        if step_id % config[\"valid_interval\"] == 0:\n",
    "            with torch.no_grad():\n",
    "                logits = valid_features @ weights\n",
    "                if use_bias:\n",
    "                    logits = logits + bias\n",
    "\n",
    "            valid_accuracy = get_accuracy(\n",
    "                logits,\n",
    "                valid_labels,\n",
    "            )\n",
    "            valid_loss = get_cross_entropy(\n",
    "                logits,\n",
    "                valid_labels,\n",
    "            )\n",
    "\n",
    "            output[\"training accuracy\"].append(\n",
    "                (train_accuracies_step / train_entries).cpu()\n",
    "            )\n",
    "            output[\"training cross-entropy\"].append(\n",
    "                (train_losses_step / train_entries).cpu()\n",
    "            )\n",
    "            output[\"training steps\"].append(step_id)\n",
    "            output[\"validation accuracy\"].append(valid_accuracy.cpu())\n",
    "            output[\"validation cross-entropy\"].append(valid_loss.cpu())\n",
    "\n",
    "            train_accuracies_step.zero_()\n",
    "            train_entries = 0\n",
    "            train_losses_step.zero_()\n",
    "\n",
    "        if step_id >= config[\"steps_num\"]:\n",
    "            for key in (\n",
    "                \"training accuracy\",\n",
    "                \"training cross-entropy\",\n",
    "                \"validation accuracy\",\n",
    "                \"validation cross-entropy\"\n",
    "            ):\n",
    "                output[key] = torch.stack(output[key])\n",
    "\n",
    "            output[\"weights\"] = weights.detach()\n",
    "            if use_bias:\n",
    "                output[\"bias\"] = bias.detach()\n",
    "            progress_bar.close()\n",
    "\n",
    "            return output\n",
    "        \n",
    "output = train_logistic_regression(\n",
    "    config,\n",
    "    10,\n",
    "    train_dataloader,\n",
    "    valid_features,\n",
    "    valid_labels\n",
    ")\n",
    "for key in (\n",
    "    \"training accuracy\",\n",
    "    \"training cross-entropy\",\n",
    "    \"validation accuracy\",\n",
    "    \"validation cross-entropy\"\n",
    "):\n",
    "    print(key, output[key][-1])"
   ]
  },
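  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Why is calling `backward` on `loss.sum()` the right thing to do? Each member's loss depends only on its own slice of the weight tensor, so the gradient of the sum with respect to that slice equals the gradient of that member's loss alone. Below is a small numerical check of this claim on random toy data; `member_losses` is a hypothetical helper and the sizes are arbitrary.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "torch.manual_seed(0)\n",
    "ensemble_num, n, d, c = 4, 8, 5, 3\n",
    "features = torch.randn(n, d)\n",
    "labels = torch.randint(0, c, (n,))\n",
    "weights = torch.randn(ensemble_num, d, c, requires_grad=True)\n",
    "\n",
    "def member_losses(w):\n",
    "    # Per-member mean cross-entropy, shape (ensemble_num,).\n",
    "    logits = features @ w  # shape (ensemble_num, n, c)\n",
    "    return torch.stack([F.cross_entropy(logits[k], labels) for k in range(w.shape[0])])\n",
    "\n",
    "# Backpropagate the sum of the member losses once.\n",
    "member_losses(weights).sum().backward()\n",
    "grad_from_sum = weights.grad.clone()\n",
    "\n",
    "# Compare with backpropagating each member's loss on its own.\n",
    "for k in range(ensemble_num):\n",
    "    w = weights.detach().clone().requires_grad_(True)\n",
    "    member_losses(w)[k].backward()\n",
    "    # Member k's slice gets the same gradient as from the summed loss ...\n",
    "    assert torch.allclose(w.grad[k], grad_from_sum[k], atol=1e-6)\n",
    "    # ... and every other member's slice gets exactly zero.\n",
    "    others = torch.cat([w.grad[:k], w.grad[k + 1:]])\n",
    "    assert torch.count_nonzero(others) == 0\n",
    "print(\"summed-loss gradients match per-member gradients\")\n",
    "```"
   ]
  },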
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plotting training curves with confidence bands\n",
    "\n",
    "Finally, let's plot training curves with pointwise confidence bands! That is, at each evaluation we show a confidence interval for the mean metric across the ensemble members.\n",
    "\n",
    "For this, we'll suppose that the ensemble is 1-dimensional, that is, `ensemble_shape = (ensemble_num,)`. We have an ensemble training curve in hand, that is, a tensor of shape `(evals_num, ensemble_num)` where `evals_num` is the number of evaluations taken. We have also selected a confidence level $\\alpha\\in(0,1)$.\n",
    "\n",
    "1. First of all, we want to get the mean metric at each evaluation. This we can get just by calling the `mean` method and specifying in the `dim` keyword argument that the operation should be performed along the last dimension.\n",
    "2. We also want to get confidence intervals at each evaluation. Recall that these are of the form $[\\bar X-cS/\\sqrt n,\\bar X+cS/\\sqrt n]$ where:\n",
    "    1. $\\bar X$ is the sample mean.\n",
    "    2. $c$ is the number such that $\\mathbf P(|T| < c) = \\alpha$ where $T$ is a Student $t$-distribution with $n-1$ degrees of freedom. As this distribution is symmetrical, we have\n",
    "    $$\n",
    "    \\mathbf P(|T|<c) = 1 - \\mathbf P(|T| > c)\n",
    "    = 1 - 2\\mathbf P(T < -c)\n",
    "    $$\n",
    "    \n",
    "    That is, we have $\\mathbf P(T < -c) = \\frac{1-\\alpha}{2}$.\n",
    "    \n",
    "    Note: A typo in an earlier version has been pointed out by Jack Xie.\n",
    "    \n",
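    "    As the function $x\\mapsto\\mathbf P(T<x)$ is the *cumulative distribution function (cdf)*, you can get $c$ from $\\alpha$ by the *inverse cdf*: $c = -F^{-1}\\left(\\frac{1-\\alpha}{2}\\right)$, where $F$ is the cdf of $T$. SciPy calls the inverse cdf the *percent point function*, available as the `ppf` method of its distributions. This is what we need `scipy` for:  \n",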
    "    https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.t.html  \n",
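    "    3. $S^2$ is the sample variance, thus $S$ is the sample standard deviation (std). You can calculate the latter by the `std` method on a tensor. Don't forget to set the `dim`. By default, `std` applies Bessel's correction (it divides by $n-1$), which is the sample variance the $t$-interval calls for.  \n",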
    "    4. $n$ is the number of samples, which in the current case is the ensemble number.\n",
    "3. You can plot a confidence band with `plt.fill_between`.\n",
    "    1. The three positional arguments are the $x$-coordinates, the lower $y$-boundaries and the upper $y$-boundaries. In our case, the first should be the train step numbers at evaluations, and the other two the boundaries of the confidence intervals.\n",
    "    2. The keyword argument `alpha` defines the opacity of the confidence band on a scale $[0, 1]$.\n",
    "    3. I suggest also setting the `color` and `label` keyword arguments.\n",
    "\n",
    "Write the function `line_plot_confidence_band`.\n",
    "\n",
    "Draw the training and validation accuracies on a canvas. Maybe restrict the y values shown with `plt.ylim`. Show and clear the canvas. Then draw the training and validation cross-entropies on another canvas."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def line_plot_confidence_band(\n",
    "    x: Sequence,\n",
    "    y: torch.Tensor,\n",
    "    color=None,\n",
    "    confidence_level=.95,\n",
    "    label=\"\",\n",
    "    opacity=.2\n",
    "):\n",
    "    \"\"\"\n",
    "    Plot training curves from an ensemble with a pointwise confidence band.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    x : Sequence\n",
    "        The sequence of time indicators (e.g. the number of train steps)\n",
    "        when the measurements took place.\n",
    "    y : torch.Tensor\n",
    "        The tensor of measurements of shape `(len(x), ensemble_num)`.\n",
    "    color : str | tuple[float] | None, optional\n",
    "        The color of the plot. Default: `None`\n",
    "    confidence_level : float, optional\n",
    "        The confidence level of the confidence band. Default: 0.95\n",
    "    label : str, optional\n",
    "        The label of the plot. Default: \"\"\n",
    "    opacity : float, optional\n",
    "        The opacity of the confidence band, to be set via the\n",
    "        `alpha` keyword argument of `plt.fill_between`. Default: 0.2\n",
    "    \"\"\"\n",
    "    sample_size = y.shape[1]\n",
    "    student_coefficient = -scipy.stats.t(sample_size - 1).ppf(\n",
    "        (1 - confidence_level) / 2\n",
    "    )\n",
    "    y_mean = y.mean(dim=-1)\n",
    "    y_std = y.std(dim=-1)\n",
    "    \n",
    "    interval_half_length = student_coefficient * y_std / sample_size ** .5\n",
    "    y_low = y_mean - interval_half_length\n",
    "    y_high = y_mean + interval_half_length\n",
    "\n",
    "    plt.fill_between(x, y_low, y_high, alpha=opacity, color=color)\n",
    "    plt.plot(x, y_mean, color=color, label=label)\n",
    "\n",
    "for key, color in (\n",
    "    (\"training accuracy\", \"red\"),\n",
    "    (\"validation accuracy\", \"blue\"),\n",
    "):\n",
    "    line_plot_confidence_band(\n",
    "        output[\"training steps\"],\n",
    "        output[key],\n",
    "        color=color,\n",
    "        label=key\n",
    "    )\n",
    "\n",
    "plt.legend()\n",
    "plt.xlabel(\"Training steps\")\n",
    "plt.ylim(.85, .95)\n",
    "plt.show()\n",
    "plt.close()\n",
    "\n",
    "for key, color in (\n",
    "    (\"training cross-entropy\", \"red\"),\n",
    "    (\"validation cross-entropy\", \"blue\"),\n",
    "):\n",
    "    line_plot_confidence_band(\n",
    "        output[\"training steps\"],\n",
    "        output[key],\n",
    "        color=color,\n",
    "        label=key\n",
    "    )\n",
    "\n",
    "plt.legend()\n",
    "plt.xlabel(\"Training steps\")\n",
    "plt.ylim(.2, .6)\n",
    "plt.show()\n",
    "plt.close()"
   ]
  },
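  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The interval recipe can also be checked in isolation. Below, `confidence_interval` is a hypothetical helper that mirrors the computations inside `line_plot_confidence_band` without the plotting, and the toy tensor is made up. For $n = 10$ members and $\\alpha = 0.95$, the coefficient $c$ comes out at about 2.262, noticeably wider than the normal-distribution value 1.96.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "import scipy.stats\n",
    "import torch\n",
    "\n",
    "def confidence_interval(y, confidence_level=0.95):\n",
    "    # y has shape (evals_num, ensemble_num); each output has shape (evals_num,).\n",
    "    n = y.shape[1]\n",
    "    # c solves P(T < -c) = (1 - confidence_level) / 2 for T ~ Student t(n - 1).\n",
    "    c = float(-scipy.stats.t(n - 1).ppf((1 - confidence_level) / 2))\n",
    "    mean = y.mean(dim=-1)\n",
    "    # `std` divides by n - 1 by default (Bessel's correction).\n",
    "    half_length = c * y.std(dim=-1) / math.sqrt(n)\n",
    "    return mean, mean - half_length, mean + half_length\n",
    "\n",
    "c = -scipy.stats.t(9).ppf(0.025)\n",
    "print(round(float(c), 4))  # 2.2622\n",
    "\n",
    "# Three evaluations of a made-up 10-member ensemble.\n",
    "y = torch.tensor([[0.1 * i for i in range(10)]] * 3)\n",
    "mean, low, high = confidence_interval(y)\n",
    "print(mean, low, high)\n",
    "```"
   ]
  },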
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Datasets\n",
    "\n",
    "## MNIST\n",
    "\n",
    "https://huggingface.co/datasets/ylecun/mnist"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# License\n",
    "\n",
    "This work is licensed under Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International. To view a copy of this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dml",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
