{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Stochastic Gradient Descent\n",
    "\n",
    "## Setup\n",
    "\n",
    "### Imports\n",
    "\n",
    "The class `collections.abc.Generator` can serve as a type hint for generators, the type of object we'll use for dataloaders. Import it.\n",
    "\n",
    "Moreover, import the following:\n",
    "1. `matplotlib.pyplot` as `plt`.\n",
    "2. `torch`\n",
    "3. `torch.nn.functional` as `F`\n",
    "4. `tqdm`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections.abc import Generator\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import tqdm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Constants\n",
    "\n",
    "Create a configuration dictionary with the following keys:\n",
    "- `\"dataset_preprocessed_path\"`: `str`\n",
    "    * If you saved the preprocessed MNIST dataset in notebook 0214,\n",
    "        this should be the path to that file.\n",
    "    * Otherwise, please run notebook 0214 with solutions\n",
    "        so that the preprocessed dataset is saved to `data/mnist.pt`.\n",
    "- `\"device\"`: `torch.device | int | str`\n",
    "    Set this according to your system (`\"cpu\"`, `\"cuda\"`, `\"mps\"`, etc.).\n",
    "- `\"learning_rate\"`: `float`\n",
    "    Make this a `1.`\n",
    "- `\"minibatch_size\"`: `int`\n",
    "    Traditionally, this is a power of two,\n",
    "    because older GPU architectures were faster with such sizes.\n",
    "    On modern hardware, this no longer seems to make a noticeable difference; see e.g.\n",
    "    https://sebastianraschka.com/blog/2022/batch-size-2.html\n",
    "    Make it a `256`.\n",
    "- `\"seed\"`: `int`\n",
    "    This is for reproducible experiments. Insert any integer.\n",
    "- `\"steps_num\"`: `int`\n",
    "    Make this a `10_000`.\n",
    "- `\"valid_interval\"`: `int`\n",
    "    The frequency of model evaluation during training,\n",
    "    measured in train steps. Make this a `100`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = {\n",
    "    \"dataset_preprocessed_path\": \"data/mnist.pt\",\n",
    "    \"device\": \"cpu\",\n",
    "    \"learning_rate\": 1.0,\n",
    "    \"minibatch_size\": 256,\n",
    "    \"seed\": 1,\n",
    "    \"steps_num\": 10_000,\n",
    "    \"valid_interval\": 100\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set the seed of the `torch` pseudorandom number generator as per the configuration dictionary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(config[\"seed\"])"
   ]
  },
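  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Seeding makes pseudorandom draws reproducible. Below is a minimal sketch of this,\n",
    "assuming an arbitrary seed `0` and an arbitrary upper bound `60_000` for illustration.\n",
    "To avoid touching the global generator seeded above, it uses separate\n",
    "`torch.Generator` objects, which `torch.randint` accepts through its `generator` argument.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "# Two independent generators seeded with the same (arbitrary) value\n",
    "first_generator = torch.Generator().manual_seed(0)\n",
    "second_generator = torch.Generator().manual_seed(0)\n",
    "\n",
    "first_draw = torch.randint(0, 60_000, (5,), generator=first_generator)\n",
    "second_draw = torch.randint(0, 60_000, (5,), generator=second_generator)\n",
    "\n",
    "# Identical seeds give identical draws\n",
    "print(torch.equal(first_draw, second_draw))  # True\n",
    "```"
   ]
  },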
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loading the Preprocessed Dataset\n",
    "\n",
    "Let's first write a function that loads the preprocessed MNIST dataset\n",
    "we saved in notebook 0214. Write the function according to its docstring, then use it to load the dataset.\n",
    "Print the shape, dtype and device of all 6 tensors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_preprocessed_dataset(\n",
    "    config: dict\n",
    ") -> tuple[\n",
    "    tuple[torch.Tensor, torch.Tensor],\n",
    "    tuple[torch.Tensor, torch.Tensor],\n",
    "    tuple[torch.Tensor, torch.Tensor]\n",
    "]:\n",
    "    \"\"\"\n",
    "    Loads a dataset that was saved with `torch.save`.\n",
    "    We expect that the object that was saved is a dictionary with keys\n",
    "    `train_features`, `train_labels`,\n",
    "    `valid_features`, `valid_labels`,\n",
    "    `test_features`, `test_labels`\n",
    "    storing the appropriate data in tensors.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    config : dict\n",
    "        Configuration dictionary. Required keys:\n",
    "        dataset_preprocessed_path : str\n",
    "            The path where the preprocessed dataset was saved to.\n",
    "        device : torch.device | int | str\n",
    "            The device to map the tensors to.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The triple of pairs\n",
    "    `(train_features, train_labels),\n",
    "    (valid_features, valid_labels),\n",
    "    (test_features, test_labels)`\n",
    "    \"\"\"\n",
    "    loaded = torch.load(\n",
    "        config[\"dataset_preprocessed_path\"],\n",
    "        weights_only=True\n",
    "    )\n",
    "    (\n",
    "        train_features,\n",
    "        train_labels,\n",
    "        valid_features,\n",
    "        valid_labels,\n",
    "        test_features,\n",
    "        test_labels\n",
    "    ) = (\n",
    "        loaded[key].to(config[\"device\"])\n",
    "        for key in [\n",
    "            \"train_features\",\n",
    "            \"train_labels\",\n",
    "            \"valid_features\",\n",
    "            \"valid_labels\",\n",
    "            \"test_features\",\n",
    "            \"test_labels\"\n",
    "        ]\n",
    "    )\n",
    "\n",
    "    return (\n",
    "        (train_features, train_labels),\n",
    "        (valid_features, valid_labels),\n",
    "        (test_features, test_labels)\n",
    "    )\n",
    "\n",
    "(\n",
    "    (train_features, train_labels),\n",
    "    (valid_features, valid_labels),\n",
    "    (test_features, test_labels)\n",
    ") = load_preprocessed_dataset(config)\n",
    "\n",
    "for t in (\n",
    "    train_features, train_labels,\n",
    "    valid_features, valid_labels,\n",
    "    test_features, test_labels\n",
    "):\n",
    "    print(t.shape, t.dtype, t.device)"
   ]
  },
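  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`load_preprocessed_dataset` expects `torch.save` to have stored a dictionary of six tensors.\n",
    "Below is a minimal round-trip sketch of that format. It uses small random tensors\n",
    "as a stand-in for MNIST; the split sizes, the 784 features, the 10 classes and the temporary\n",
    "file name `toy_mnist.pt` are assumptions for illustration only.\n",
    "\n",
    "```python\n",
    "import os\n",
    "import tempfile\n",
    "\n",
    "import torch\n",
    "\n",
    "keys = [\n",
    "    \"train_features\", \"train_labels\",\n",
    "    \"valid_features\", \"valid_labels\",\n",
    "    \"test_features\", \"test_labels\",\n",
    "]\n",
    "# Hypothetical stand-in data: 8 entries per split, 784 features, 10 classes\n",
    "dataset = {\n",
    "    key: torch.rand(8, 784) if key.endswith(\"features\") else torch.randint(0, 10, (8,))\n",
    "    for key in keys\n",
    "}\n",
    "\n",
    "with tempfile.TemporaryDirectory() as tmp_dir:\n",
    "    path = os.path.join(tmp_dir, \"toy_mnist.pt\")\n",
    "    torch.save(dataset, path)\n",
    "    # weights_only=True restricts unpickling to tensors and basic containers\n",
    "    loaded = torch.load(path, weights_only=True)\n",
    "\n",
    "for key in keys:\n",
    "    print(key, tuple(loaded[key].shape), loaded[key].dtype)\n",
    "```"
   ]
  },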
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## SGD with Random Sampling\n",
    "\n",
    "### Sampling with Replacement from a Dataset\n",
    "\n",
    "We shall first write a training loop for Stochastic Gradient Descent (SGD)\n",
    "with replacement. For this, we first need to see how to sample\n",
    "from the train dataset with replacement.\n",
    "\n",
    "We will draw random samples with replacement in the following two steps:\n",
    "1. Draw an index tensor: a 1-dimensional tensor that\n",
    "    1. contains random integers between 0 (inclusive) and\n",
    "        the size of the train dataset (exclusive), and\n",
    "    2. has shape `(minibatch_size,)`, i.e. its size is given as a 1-tuple.\n",
    "    For this, you can use `torch.randint`.\n",
    "    Try to figure out how from the documentation:\n",
    "    https://pytorch.org/docs/stable/generated/torch.randint.html\n",
    "2. Index into the train feature and label tensors with the index tensor\n",
    "    to get the minibatch features and labels.\n",
    "\n",
    "Get a minibatch and print out\n",
    "the shapes of the minibatch feature and label tensors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "indices = torch.randint(0, len(train_features), (config[\"minibatch_size\"],))\n",
    "minibatch_features, minibatch_labels = (\n",
    "    t[indices]\n",
    "    for t in [train_features, train_labels]\n",
    ")\n",
    "print(minibatch_features.shape, minibatch_labels.shape)"
   ]
  },
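  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because the indices are drawn independently, the same entry can appear more than once\n",
    "in a minibatch, and indexing both tensors with the same index tensor keeps\n",
    "features and labels aligned. Here is a small sketch on a hypothetical toy dataset of 10 entries,\n",
    "where feature row `i` holds the value `10 * i` and has label `i`:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "generator = torch.Generator().manual_seed(0)\n",
    "toy_features = torch.arange(10.0).unsqueeze(1) * 10  # shape (10, 1)\n",
    "toy_labels = torch.arange(10)  # label i belongs to feature row i\n",
    "\n",
    "# Draw 8 indices with replacement from 0, ..., 9\n",
    "indices = torch.randint(0, len(toy_features), (8,), generator=generator)\n",
    "toy_minibatch_features = toy_features[indices]\n",
    "toy_minibatch_labels = toy_labels[indices]\n",
    "\n",
    "print(indices)\n",
    "print(\"distinct entries:\", len(torch.unique(indices)))\n",
    "# Each sampled row still matches its label: feature == 10 * label\n",
    "print(torch.equal(\n",
    "    toy_minibatch_features.squeeze(1),\n",
    "    10.0 * toy_minibatch_labels.float()\n",
    "))  # True\n",
    "```"
   ]
  },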
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Writing a Dataloader\n",
    "\n",
    "Next, we want to write a *dataloader* that outputs minibatches like this.\n",
    "Programmatically, a dataloader is an *iterable*.\n",
    "For us, this means that the pattern we want to use is the following:\n",
    "```python\n",
    "for minibatch_features, minibatch_labels in dataloader:\n",
    "    ...  # take a train step\n",
    "```\n",
    "For more information on iterables, see for example:\n",
    "https://stackoverflow.com/a/9884259\n",
    "\n",
    "Specifically, we'll use a *generator iterator*.\n",
    "We'll write a function that, instead of `return` statements,\n",
    "contains one or more `yield` statements.\n",
    "When you loop over the output of such a function,\n",
    "you get the values of the `yield` statements one by one.\n",
    "For more information on generators, see here:\n",
    "https://exercism.org/tracks/python/concepts/generators\n",
    "\n",
    "For a quick introduction to this, write a function that yields 3, 4 and 5,\n",
    "then loop over its output and print the entries you get."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def three_four_five() -> Generator[int]:\n",
    "    yield 3\n",
    "    yield 4\n",
    "    yield 5\n",
    "\n",
    "for entry in three_four_five():\n",
    "    print(entry)"
   ]
  },
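  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Under the hood, a `for` loop repeatedly calls `next` on the generator until it raises `StopIteration`.\n",
    "Below is a minimal sketch of this protocol, with the example function redefined as\n",
    "`three_four_five_sketch` so that the snippet is self-contained:\n",
    "\n",
    "```python\n",
    "from collections.abc import Generator\n",
    "\n",
    "\n",
    "def three_four_five_sketch() -> Generator[int, None, None]:\n",
    "    yield 3\n",
    "    yield 4\n",
    "    yield 5\n",
    "\n",
    "\n",
    "generator = three_four_five_sketch()\n",
    "print(next(generator))  # 3\n",
    "print(next(generator))  # 4\n",
    "print(next(generator))  # 5\n",
    "try:\n",
    "    next(generator)\n",
    "except StopIteration:\n",
    "    print(\"exhausted\")\n",
    "\n",
    "# The generator is used up now: iterating over it again yields nothing\n",
    "print(list(generator))  # []\n",
    "```\n",
    "\n",
    "This is also why a generator can only be looped over once: to start again, call the function again."
   ]
  },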
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's make a dataloader that yields minibatches indefinitely.\n",
    "This will make it easier to change the stopping condition in the train loop.\n",
    "On the other hand, this means you'll have to make sure to include a\n",
    "stopping condition as otherwise you'll have an infinite loop.\n",
    "\n",
    "1. Write the function whose signature and docstring are given below.\n",
    "2. Iterate over it 10 times\n",
    "    while printing the shapes of the tensors you get.\n",
    "    There are two ways to achieve this:\n",
    "    1. Iterate over the dataloader while incrementing a counter\n",
    "        explicitly. When the counter reaches 10,\n",
    "        use the statement `break`.\n",
    "    2. Iterate over a `range` and in each iteration,\n",
    "        call the function `next` on the dataloader to get a minibatch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_dataloader_random_sampling(\n",
    "    config: dict,\n",
    "    features: torch.Tensor,\n",
    "    labels: torch.Tensor\n",
    ") -> Generator[tuple[torch.Tensor, torch.Tensor]]:\n",
    "    \"\"\"\n",
    "    Given a feature and a label tensor,\n",
    "    creates a random sampling (with replacement) dataloader\n",
    "    that yields pairs `minibatch_features, minibatch_labels` indefinitely.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    config : dict\n",
    "        Configuration dictionary. Required key:\n",
    "        minibatch_size : int\n",
    "            The size of the minibatches.\n",
    "    features : torch.Tensor\n",
    "        Tensor of dataset features.\n",
    "        We assume that the first dimension is the batch dimension\n",
    "    labels : torch.Tensor\n",
    "        Tensor of dataset labels.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    A generator of tuples `minibatch_features, minibatch_labels`.\n",
    "    \"\"\"\n",
    "    dataset_size = len(features)\n",
    "\n",
    "    while True:\n",
    "        minibatch_indices = torch.randint(\n",
    "            0,\n",
    "            dataset_size,\n",
    "            (config[\"minibatch_size\"],)\n",
    "        )\n",
    "\n",
    "        yield features[minibatch_indices], labels[minibatch_indices]\n",
    "\n",
    "dataloader_random_resampling = get_dataloader_random_sampling(\n",
    "    config,\n",
    "    train_features,\n",
    "    train_labels\n",
    ")\n",
    "step_id = 0\n",
    "for minibatch_features, minibatch_labels in dataloader_random_resampling:\n",
    "    print(minibatch_features.shape, minibatch_labels.shape)\n",
    "    step_id += 1\n",
    "    if step_id == 10:\n",
    "        break"
   ]
  },
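  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell above uses the counter-and-`break` approach. The `range`-and-`next` approach,\n",
    "and `itertools.islice` from the standard library as a further alternative, can be sketched\n",
    "on a hypothetical stand-in for an infinite dataloader, `count_up`, which yields 0, 1, 2, and so on forever:\n",
    "\n",
    "```python\n",
    "import itertools\n",
    "from collections.abc import Generator\n",
    "\n",
    "\n",
    "def count_up() -> Generator[int, None, None]:\n",
    "    # Hypothetical stand-in for an infinite dataloader\n",
    "    value = 0\n",
    "    while True:\n",
    "        yield value\n",
    "        value += 1\n",
    "\n",
    "\n",
    "# range + next: exactly 10 calls to next\n",
    "toy_dataloader = count_up()\n",
    "via_next = [next(toy_dataloader) for _ in range(10)]\n",
    "\n",
    "# islice stops after 10 items, so iterating over it terminates\n",
    "via_islice = list(itertools.islice(count_up(), 10))\n",
    "\n",
    "print(via_next)  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]\n",
    "print(via_next == via_islice)  # True\n",
    "```\n",
    "\n",
    "Note that the generator keeps its state between calls: calling `next(toy_dataloader)` once more gives `10`, not `0`."
   ]
  },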
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Calculating Accuracy\n",
    "\n",
    "Besides taking train steps on minibatches, another change we make\n",
    "is to regularly measure accuracy during training. This will give us\n",
    "more detailed information on the training process and the model.\n",
    "\n",
    "First of all, let's write a function that takes logits and labels\n",
    "and outputs the accuracy. Don't forget to:\n",
    "1. suspend gradient tracking when you calculate the logits and\n",
    "2. convert the final result to a Python `float`.\n",
    "\n",
    "Test your function by calculating the validation accuracy of a model\n",
    "with all-zero weights (and, optionally, biases). Is it what you'd expect?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_accuracy(\n",
    "    logits: torch.Tensor,\n",
    "    labels: torch.Tensor\n",
    ") -> float:\n",
    "    \"\"\"\n",
    "    Given unnormalized predicted logits and the true labels,\n",
    "    outputs the accuracy of the predictions.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    logits : torch.Tensor\n",
    "        Unnormalized prediction logits, of shape `(dataset_size, label_num)`.\n",
    "    labels : torch.Tensor\n",
    "        The label vector, of shape `(dataset_size,)` and dtype `torch.int64`.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The accuracy value as a `float`.\n",
    "    \"\"\"\n",
    "    labels_predict = logits.argmax(dim=-1)\n",
    "    accuracy = (labels == labels_predict).to(torch.float32).mean()\n",
    "\n",
    "    return accuracy.cpu().item()\n",
    "\n",
    "valid_size, feature_dim = valid_features.shape\n",
    "weights = torch.zeros(\n",
    "    (feature_dim, 10),\n",
    "    device=config[\"device\"],\n",
    "    dtype=torch.float32\n",
    ")\n",
    "bias = torch.zeros_like(weights[0])\n",
    "valid_logits = valid_features @ weights + bias\n",
    "print(get_accuracy(valid_logits, valid_labels))"
   ]
  },
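  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A worked example of what to expect: with all-zero logits every class ties, and `torch.argmax`\n",
    "returns the index of the first maximal value, here `0`. The accuracy of the zero model is therefore\n",
    "the fraction of entries labeled `0`, which for MNIST should be roughly one in ten.\n",
    "Here is a minimal sketch on hypothetical toy labels:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "toy_labels = torch.tensor([0, 3, 0, 7, 1, 0, 9, 2])  # 3 of 8 labels are 0\n",
    "zero_logits = torch.zeros(len(toy_labels), 10)\n",
    "\n",
    "# All entries tie, so argmax picks the first index, 0, for every row\n",
    "predictions = zero_logits.argmax(dim=-1)\n",
    "accuracy = (predictions == toy_labels).to(torch.float32).mean().item()\n",
    "print(predictions)  # tensor([0, 0, 0, 0, 0, 0, 0, 0])\n",
    "print(accuracy)  # 0.375\n",
    "```"
   ]
  },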
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Calculating Cross-Entropy\n",
    "\n",
    "To calculate cross-entropy, you can use [`torch.nn.functional.cross_entropy`](https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html). Note the following:\n",
    "1. As per the documentation, you can input unnormalized logits. Therefore, there is no need to apply softmax to the output of the affine transformation.\n",
    "2. Since we imported the module `torch.nn.functional` as `F`, you can call the function as `F.cross_entropy`.\n",
    "\n",
    "Print the cross-entropy of a zero-initialized logistic regression model on the validation dataset. Is it what you'd expect?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(F.cross_entropy(valid_logits, valid_labels))"
   ]
  },
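  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Why this value: with all-zero logits, softmax assigns probability 1/K to each of the K classes,\n",
    "so the cross-entropy of every entry, and hence their mean, is -ln(1/K) = ln K.\n",
    "For K = 10 this is about 2.3026. Here is a quick numerical check on hypothetical toy labels:\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "toy_labels = torch.tensor([0, 3, 7, 9, 1])\n",
    "zero_logits = torch.zeros(len(toy_labels), 10)\n",
    "\n",
    "loss = F.cross_entropy(zero_logits, toy_labels).item()\n",
    "print(loss, math.log(10))  # both approximately 2.302585\n",
    "```"
   ]
  },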
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Writing the Training Loop\n",
    "\n",
    "Now for the train loop itself!\n",
    "\n",
    "We will collect evaluation data in 5 lists:\n",
    "1. Training loss (cross-entropy: CE)\n",
    "2. Validation loss (cross-entropy: CE)\n",
    "3. Training accuracy\n",
    "4. Validation accuracy\n",
    "5. Number of steps taken at evaluation.\n",
    "\n",
    "Every `config[\"valid_interval\"]` steps, you can just\n",
    "calculate items 2 and 4 and append them to their respective lists.\n",
    "Also, append the current number of steps to item 5.\n",
    "\n",
    "As for items 1 and 3:\n",
    "\n",
    "6. Keep two separate lists to store train accuracies and losses on individual minibatches.\n",
    "7. At each train step:\n",
    "    1. Calculate the accuracy and append it to the appropriate list in item 6.\n",
    "        Do this before the gradient descent step, as otherwise\n",
    "        you're evaluating the model on the minibatch it was just\n",
    "        optimized on.\n",
    "    2. Append the loss value you calculated for the gradient descent step\n",
    "        to the appropriate list in item 6. Don't forget to\n",
    "        detach the tensor and convert the value to a Python `float`.\n",
    "8. When it is time to store evaluation data, append to the\n",
    "    lists in items 1 and 3 the respective averages\n",
    "    of the entries in the lists in item 6.\n",
    "9. Then call the `clear` method of the lists in item 6 to empty them.\n",
    "\n",
    "Note also that we pass the dataloader to the function as an argument.\n",
    "This is so that when we write the random reshuffling dataloader,\n",
    "we can just plug it into this function\n",
    "(after a change that we will make then).\n",
    "\n",
    "WARNING: Don't forget to `break` the training loop once you reach\n",
    "`config[\"steps_num\"]` steps.\n",
    "\n",
    "If you want to add a tqdm progress bar, there are two options:\n",
    "\n",
    "1. Iterate over the progress bar\n",
    "    and get minibatches by calling `next` on the dataloader.\n",
    "2. Iterate over the dataloader and advance the progress bar manually\n",
    "    by calling its `update` method at each step.\n",
    "\n",
    "Write the function and run it on our dataset split.\n",
    "Print the last values of the accuracy, loss and step lists."
   ]
  },
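  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The bookkeeping in items 6 to 9 can be sketched without any model. In this simplified loop,\n",
    "per-step values are buffered and, every `valid_interval` steps, their average is appended\n",
    "and the buffer is cleared. The per-step accuracies and the interval of 3 are made up for illustration.\n",
    "\n",
    "```python\n",
    "valid_interval = 3\n",
    "# Made-up per-step train accuracies, standing in for minibatch evaluations\n",
    "step_accuracies = [0.1, 0.2, 0.3, 0.5, 0.6, 0.7]\n",
    "\n",
    "train_accuracies = []  # item 3: averaged values\n",
    "train_accuracies_steps = []  # item 6: per-step buffer\n",
    "valid_steps = []  # item 5: step counts at evaluation\n",
    "\n",
    "for step_id, step_accuracy in enumerate(step_accuracies, start=1):\n",
    "    train_accuracies_steps.append(step_accuracy)\n",
    "    if step_id % valid_interval == 0:\n",
    "        # Item 8: average the buffer; item 9: empty it\n",
    "        train_accuracies.append(\n",
    "            sum(train_accuracies_steps) / len(train_accuracies_steps)\n",
    "        )\n",
    "        train_accuracies_steps.clear()\n",
    "        valid_steps.append(step_id)\n",
    "\n",
    "print(valid_steps)  # [3, 6]\n",
    "print(train_accuracies)  # approximately [0.2, 0.6]\n",
    "```"
   ]
  },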
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_logistic_regression(\n",
    "    config: dict,\n",
    "    label_num: int,\n",
    "    train_dataloader: Generator[tuple[torch.Tensor, torch.Tensor]],\n",
    "    valid_features: torch.Tensor,\n",
    "    valid_labels: torch.Tensor,\n",
    "    use_bias=True\n",
    ") -> dict:\n",
    "    \"\"\"\n",
    "    Train a logistic regression model\n",
    "    and record accuracy and cross-entropy values regularly.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    config : dict\n",
    "        Configuration dictionary. Required keys:\n",
    "        learning_rate : float\n",
    "            The learning rate used in SGD.\n",
    "        steps_num : int\n",
    "            The number of SGD steps to take\n",
    "        valid_interval : int\n",
    "            The frequency of evaluations given in training steps.\n",
    "    label_num : int\n",
    "        The number of distinct labels.\n",
    "    train_dataloader : Generator[tuple[torch.Tensor, torch.Tensor]]\n",
    "        A dataloader that outputs `minibatch_features, minibatch_labels`\n",
    "        pairs indefinitely.\n",
    "    valid_features : torch.Tensor\n",
    "        Validation feature matrix.\n",
    "    valid_labels : torch.Tensor\n",
    "        Validation label vector.\n",
    "    use_bias : bool, optional\n",
    "        Whether to use a bias vector in the logistic regression model.\n",
    "        Default: `True`\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    A dictionary with the following keys and values:\n",
    "    train_accuracies : list[float]\n",
    "        A list of accuracies on minibatches\n",
    "        averaged on evaluation intervals.\n",
    "    train_losses : list[float]\n",
    "        A list of cross-entropy values on minibatches\n",
    "        averaged on evaluation intervals.\n",
    "    valid_accuracies : list[float]\n",
    "        A list of validation accuracies measured each evaluation.\n",
    "    valid_losses : list[float]\n",
    "        A list of validation cross-entropies measured each evaluation.\n",
    "    valid_steps : list[int]\n",
    "        The number of train steps taken at the time of each evaluation.\n",
    "    weights : torch.Tensor\n",
    "        The logistic model weights we get at the end of training.\n",
    "    bias : torch.Tensor, optional\n",
    "        The logistic model bias we get at the end of training, if `use_bias`.\n",
    "    \"\"\"\n",
    "    (\n",
    "        train_accuracies_steps,\n",
    "        train_accuracies,\n",
    "        train_losses,\n",
    "        train_losses_steps,\n",
    "        valid_accuracies,\n",
    "        valid_losses,\n",
    "        valid_steps\n",
    "    ) = [], [], [], [], [], [], []\n",
    "    weights = torch.zeros(\n",
    "        (valid_features.shape[1], label_num),\n",
    "        device=valid_features.device,\n",
    "        dtype=valid_features.dtype,\n",
    "        requires_grad=True\n",
    "    )\n",
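    "    parameters = [weights]\n",
    "    if use_bias:\n",
    "        # The bias must also require gradients and be optimized\n",
    "        bias = torch.zeros(\n",
    "            label_num,\n",
    "            device=valid_features.device,\n",
    "            dtype=valid_features.dtype,\n",
    "            requires_grad=True\n",
    "        )\n",
    "        parameters.append(bias)\n",
    "\n",
    "    optimizer = torch.optim.SGD(parameters, lr=config[\"learning_rate\"])\n",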
    "\n",
    "    step_id = 0\n",
    "    progress_bar = tqdm.trange(config[\"steps_num\"])\n",
    "    for minibatch_features, minibatch_labels in train_dataloader:\n",
    "        if step_id and step_id % config[\"valid_interval\"] == 0:\n",
    "            for source, target in (\n",
    "                (train_accuracies_steps, train_accuracies),\n",
    "                (train_losses_steps, train_losses)\n",
    "            ):\n",
    "                target.append(sum(source) / len(source))\n",
    "                source.clear()\n",
    "\n",
    "            with torch.no_grad():\n",
    "                logits = valid_features @ weights\n",
    "                if use_bias:\n",
    "                    logits = logits + bias\n",
    "\n",
    "            valid_accuracies.append(get_accuracy(\n",
    "                logits,\n",
    "                valid_labels\n",
    "            ))\n",
    "            valid_losses.append(F.cross_entropy(\n",
    "                logits,\n",
    "                valid_labels\n",
    "            ).cpu().item())\n",
    "            valid_steps.append(step_id)\n",
    "\n",
    "        # Stop once `steps_num` steps were taken (after the final evaluation)\n",
    "        if step_id == config[\"steps_num\"]:\n",
    "            break\n",
    "\n",
    "        step_id += 1\n",
    "        progress_bar.update()\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        logits = minibatch_features @ weights\n",
    "        if use_bias:\n",
    "            logits = logits + bias\n",
    "\n",
    "        train_accuracies_steps.append(get_accuracy(\n",
    "            logits.detach(),\n",
    "            minibatch_labels\n",
    "        ))\n",
    "        loss = F.cross_entropy(logits, minibatch_labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        train_losses_steps.append(loss.detach().cpu().item())\n",
    "\n",
    "    progress_bar.close()\n",
    "\n",
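    "    result = {\n",
    "        \"train_accuracies\": train_accuracies,\n",
    "        \"train_losses\": train_losses,\n",
    "        \"valid_accuracies\": valid_accuracies,\n",
    "        \"valid_losses\": valid_losses,\n",
    "        \"valid_steps\": valid_steps,\n",
    "        \"weights\": weights.detach()\n",
    "    }\n",
    "    # The docstring promises the bias as well, if it is used\n",
    "    if use_bias:\n",
    "        result[\"bias\"] = bias.detach()\n",
    "\n",
    "    return result\n",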
    "\n",
    "output = train_logistic_regression(\n",
    "    config,\n",
    "    10,\n",
    "    dataloader_random_resampling,\n",
    "    valid_features,\n",
    "    valid_labels\n",
    ")\n",
    "\n",
    "for key in [\n",
    "    \"train_accuracies\",\n",
    "    \"train_losses\",\n",
    "    \"valid_accuracies\",\n",
    "    \"valid_losses\",\n",
    "    \"valid_steps\"\n",
    "]:\n",
    "    print(key, output[key][-1])"
   ]
  },
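  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For plain SGD (no momentum or weight decay, which are the defaults of `torch.optim.SGD`),\n",
    "`optimizer.step()` replaces each registered parameter `w` with `w - lr * w.grad`.\n",
    "Tensors not passed to the optimizer are left untouched even if they have gradients,\n",
    "so a bias that should be trained must require gradients and be passed to the optimizer too.\n",
    "Here is a minimal check on a hypothetical toy affine model (shapes, learning rate and data are arbitrary):\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "generator = torch.Generator().manual_seed(0)\n",
    "features = torch.rand(4, 3, generator=generator)\n",
    "labels = torch.tensor([0, 1, 2, 1])\n",
    "\n",
    "weights = torch.zeros(3, 3, requires_grad=True)\n",
    "bias = torch.zeros(3, requires_grad=True)\n",
    "learning_rate = 0.5\n",
    "\n",
    "# Deliberately register only the weights with the optimizer\n",
    "optimizer = torch.optim.SGD([weights], lr=learning_rate)\n",
    "optimizer.zero_grad()\n",
    "loss = F.cross_entropy(features @ weights + bias, labels)\n",
    "loss.backward()\n",
    "\n",
    "# The plain SGD update rule, computed by hand before the step\n",
    "expected_weights = weights.detach() - learning_rate * weights.grad\n",
    "optimizer.step()\n",
    "\n",
    "print(torch.allclose(weights.detach(), expected_weights))  # True\n",
    "# The bias received a gradient but was never updated\n",
    "print(bias.grad)\n",
    "print(torch.count_nonzero(bias).item())  # 0\n",
    "```"
   ]
  },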
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plot Training Curves\n",
    "\n",
    "1. In each line plot, make the x-axis values the list of step numbers.\n",
    "2. Give each `plt.plot` call a `label` keyword argument\n",
    "    so that we can see which line plot represents which metric.\n",
    "3. To write the labels on the canvas, after drawing all 4 plots,\n",
    "    call `plt.legend`.\n",
    "4. If the cross-entropies make the y-axis values too wide,\n",
    "    you can limit them with `plt.ylim`. See docs here:\n",
    "    https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.ylim.html\n",
    "5. If the colors of the line plots are difficult to distinguish\n",
    "    or you would just prefer other colors, you can change them\n",
    "    with the `color` keyword argument of `plt.plot`.\n",
    "    For options, look for example here:\n",
    "    https://matplotlib.org/stable/users/explain/colors/colors.html\n",
    "6. You can specify what the x-axis values signify with `plt.xlabel`.\n",
    "7. Optionally, give the plot a title with `plt.title`.\n",
    "\n",
    "Wrap all of this in a function for easy reuse."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_output(\n",
    "    output: dict,\n",
    "    title=\"\"\n",
    "):\n",
    "    \"\"\"\n",
    "    Given a training result dictionary output by\n",
    "    `train_logistic_regression`,\n",
    "    make line plots of the metric curves.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    output : dict\n",
    "        The output of `train_logistic_regression`.\n",
    "    title : str, optional\n",
    "        If given, specifies the title of the plot. Default: \"\"\n",
    "    \"\"\"\n",
    "    plt.plot(\n",
    "        output[\"valid_steps\"],\n",
    "        output[\"train_accuracies\"],\n",
    "        label=\"training accuracy\"\n",
    "    )\n",
    "    plt.plot(\n",
    "        output[\"valid_steps\"],\n",
    "        output[\"valid_accuracies\"],\n",
    "        label=\"validation accuracy\"\n",
    "    )\n",
    "    plt.plot(\n",
    "        output[\"valid_steps\"],\n",
    "        output[\"train_losses\"],\n",
    "        label=\"training cross-entropy\"\n",
    "    )\n",
    "    plt.plot(\n",
    "        output[\"valid_steps\"],\n",
    "        output[\"valid_losses\"],\n",
    "        color=\"cyan\",\n",
    "        label=\"validation cross-entropy\"\n",
    "    )\n",
    "\n",
    "    plt.legend()\n",
    "    if title:\n",
    "        plt.title(title)\n",
    "    plt.xlabel(\"Train steps\")\n",
    "    plt.ylim(0, 1)\n",
    "    plt.show()\n",
    "    plt.close()\n",
    "\n",
    "plot_output(\n",
    "    output,\n",
    "    title=\"Logistic regression by SGD with replacement on MNIST\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## SGD with Random Reshuffle\n",
    "\n",
    "Time to create our random reshuffling dataloader!\n",
    "Note that this is the only training component we have to change:\n",
    "we can then just plug it into `train_logistic_regression` (after a small adjustment to the function, which will still accept both of our dataloaders, and other kinds as well).\n",
    "Writing code as plug-and-play components like this\n",
    "is called *modular* programming. Modularity is very important in machine learning,\n",
    "as it keeps experiments tractable."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Index reshuffler\n",
    "\n",
    "We will first create a generator of indices,\n",
    "so that we can test it separately.\n",
    "\n",
    "1. This generator should know the dataset size and the minibatch size.\n",
    "2. Given these, you should first calculate\n",
    "    the number of minibatches per epoch.\n",
    "    For this, I suggest using the function `divmod`.\n",
    "    You can see the documentation here:\n",
    "    https://docs.python.org/3/library/functions.html#divmod\n",
    "3. You can get a random permutation of the numbers 0, ..., dataset size - 1\n",
    "    with the function `torch.randperm`.\n",
    "4. Then you can increment a minibatch index variable\n",
    "    and use it to yield consecutive slices of the permuted indices,\n",
    "    each of length the minibatch size.\n",
    "    1. Slicing past the end of a tensor is not an error, so if the dataset size\n",
    "        is not divisible by the minibatch size, the last slice is simply smaller. That's OK.\n",
    "5. When the minibatch index variable reaches the number of minibatches\n",
    "    per epoch:\n",
    "    1. reset it to zero and\n",
    "    2. get a new permutation of indices.\n",
    "6. You can test the function for example\n",
    "    by printing the results of 6 iterations\n",
    "    with dataset size 10 and minibatch size 4."
   ]
  },
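  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a worked example of step 2, `divmod(10, 4)` is `(2, 2)`: two full minibatches and a remainder of 2,\n",
    "so there are 3 minibatches per epoch, of sizes 4, 4 and 2. For comparison, here is a sketch of a\n",
    "different technique than the one described above: `Tensor.split` cuts a permutation into consecutive\n",
    "chunks of the given size, with a smaller last chunk, so the slicing bookkeeping disappears.\n",
    "`reshuffler_via_split` is a hypothetical name.\n",
    "\n",
    "```python\n",
    "from collections.abc import Generator\n",
    "\n",
    "import torch\n",
    "\n",
    "\n",
    "def reshuffler_via_split(\n",
    "    dataset_size: int,\n",
    "    minibatch_size: int\n",
    ") -> Generator[torch.Tensor, None, None]:\n",
    "    # Hypothetical alternative to the divmod-based version described above\n",
    "    while True:\n",
    "        # One epoch: a fresh permutation cut into consecutive chunks\n",
    "        yield from torch.randperm(dataset_size).split(minibatch_size)\n",
    "\n",
    "\n",
    "q, r = divmod(10, 4)\n",
    "print(q, r, q + min(1, r))  # 2 2 3\n",
    "\n",
    "index_generator = reshuffler_via_split(10, 4)\n",
    "epoch = [next(index_generator) for _ in range(3)]\n",
    "print([len(indices) for indices in epoch])  # [4, 4, 2]\n",
    "# Within one epoch, every index appears exactly once\n",
    "print(torch.cat(epoch).sort().values)  # tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])\n",
    "```"
   ]
  },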
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_random_reshuffler(\n",
    "    dataset_size: int,\n",
    "    minibatch_size: int\n",
    ") -> Generator[torch.Tensor]:\n",
    "    \"\"\"\n",
    "    Yield consecutive slices of `minibatch_size`\n",
    "    from a permutation of the integers `0, ..., dataset_size - 1`.\n",
    "    The last slice can be smaller.\n",
    "    After the last slice was yielded,\n",
    "    reshuffle the integers and start again.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    dataset_size : int\n",
    "        The integers `0, ..., dataset_size-1` are permuted and sliced.\n",
    "    minibatch_size : int\n",
    "        The size of the slices.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    A Generator[torch.Tensor] of index tensors.\n",
    "    \"\"\"\n",
    "    q, r = divmod(dataset_size, minibatch_size)\n",
    "    minibatch_num = q + min(1, r)\n",
    "    minibatch_index = minibatch_num\n",
    "    while True:\n",
    "        if minibatch_index == minibatch_num:\n",
    "            dataset_indices = torch.randperm(dataset_size)\n",
    "            minibatch_index = 0\n",
    "\n",
    "        yield dataset_indices[\n",
    "            minibatch_index * minibatch_size\n",
    "           :(minibatch_index + 1) * minibatch_size\n",
    "        ]\n",
    "\n",
    "        minibatch_index += 1\n",
    "\n",
    "index_generator = get_random_reshuffler(10, 4)\n",
    "for _ in range(6):\n",
    "    print(next(index_generator))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Random Reshuffling Dataloader\n",
    "\n",
    "Now to get a random reshuffling dataloader,\n",
    "you can iterate over the random reshuffler\n",
    "and for each index tensor you get,\n",
    "yield the respective parts of the train features and labels.\n",
    "Write this function, get a dataloader,\n",
    "run 10 iterations and print tensor shapes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_dataloader_random_reshuffling(\n",
    "    config: dict,\n",
    "    features: torch.Tensor,\n",
    "    labels: torch.Tensor\n",
    ") -> Generator[tuple[torch.Tensor, torch.Tensor]]:\n",
    "    \"\"\"\n",
    "    Given a feature and a label tensor,\n",
    "    creates a random reshuffling (without replacement) dataloader\n",
    "    that yields pairs `minibatch_features, minibatch_labels` indefinitely.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    config : dict\n",
    "        Configuration dictionary. Required key:\n",
    "        minibatch_size : int\n",
    "            The size of the minibatches.\n",
    "    features : torch.Tensor\n",
    "        Tensor of dataset features.\n",
    "        We assume that the first dimension is the batch dimension\n",
    "    labels : torch.Tensor\n",
    "        Tensor of dataset labels.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    A generator of tuples `minibatch_features, minibatch_labels`.\n",
    "    \"\"\"\n",
    "    dataset_size = len(features)\n",
    "\n",
    "    minibatch_indices_generator = get_random_reshuffler(\n",
    "        dataset_size,\n",
    "        config[\"minibatch_size\"]\n",
    "    )\n",
    "    for minibatch_indices in minibatch_indices_generator:\n",
    "        yield features[minibatch_indices], labels[minibatch_indices]\n",
    "\n",
    "dataloader_random_reshuffling = get_dataloader_random_reshuffling(\n",
    "    config,\n",
    "    train_features,\n",
    "    train_labels\n",
    ")\n",
    "step_id = 0\n",
    "for minibatch_features, minibatch_labels in dataloader_random_reshuffling:\n",
    "    print(minibatch_features.shape, minibatch_labels.shape)\n",
    "    step_id += 1\n",
    "    if step_id == 10:\n",
    "        break"
   ]
  },
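  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The index generator `get_random_reshuffler` is defined earlier in this notebook.\n",
    "As a hedged, self-contained sketch of what random reshuffling does,\n",
    "the hypothetical helper `random_reshuffler_sketch` below (built on `torch.randperm` and `torch.split`)\n",
    "draws a fresh permutation every epoch and slices it into minibatches,\n",
    "so within one epoch each index appears exactly once and only the last minibatch may be smaller.\n",
    "It illustrates the idea under these assumptions and is not necessarily how `get_random_reshuffler` is implemented.\n",
    "\n",
    "```python\n",
    "from __future__ import annotations\n",
    "\n",
    "from collections.abc import Generator\n",
    "\n",
    "import torch\n",
    "\n",
    "\n",
    "def random_reshuffler_sketch(\n",
    "    dataset_size: int,\n",
    "    minibatch_size: int,\n",
    "    generator: torch.Generator | None = None,\n",
    ") -> Generator[torch.Tensor, None, None]:\n",
    "    \"\"\"Hypothetical helper: yield minibatch index tensors indefinitely,\n",
    "    reshuffling the indices at the start of every epoch.\"\"\"\n",
    "    while True:\n",
    "        # A fresh permutation per epoch: sampling without replacement.\n",
    "        permutation = torch.randperm(dataset_size, generator=generator)\n",
    "        # Chunks of minibatch_size; the last chunk may be shorter.\n",
    "        yield from torch.split(permutation, minibatch_size)\n",
    "\n",
    "\n",
    "# One epoch over 10 items with minibatches of 4 gives sizes 4, 4, 2.\n",
    "rng = torch.Generator().manual_seed(0)\n",
    "reshuffler = random_reshuffler_sketch(10, 4, generator=rng)\n",
    "epoch = [next(reshuffler) for _ in range(3)]\n",
    "print([len(indices) for indices in epoch])  # [4, 4, 2]\n",
    "print(sorted(torch.cat(epoch).tolist()) == list(range(10)))  # True\n",
    "```"
   ]
  },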
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Adjust Training Loop\n",
    "\n",
    "There is one change you need to make to the dataloader:\n",
    "when you average stepwise train accuracies and losses,\n",
    "you need to take into account that not all minibatches have the same size.\n",
    "Thus:\n",
    "1. besides the minibatch-wise accuracy and loss collectors,\n",
    "    also keep a counter of total number of dataset entries\n",
    "    seen since the last evaluation.\n",
    "2. With this:\n",
    "    1. upon a train step, append to the stepwise lists not the\n",
    "        accuracies and the losses, but their multiples\n",
    "        by the number of entries in the minibatch.\n",
    "    2. at the same time, increment the entry counter\n",
    "        by the minibatch size.\n",
    "    3. At evaluation, instead of taking the average of the entries\n",
    "        in the stepwise lists, sum them up, then divide them by the\n",
    "        values of the total entry number counters.\n",
    "    4. Besides clearing the stepwise lists, reset the\n",
    "        entry number counter to 0.\n",
    "\n",
    "Run the updated training loop with the new dataloader.\n",
    "Print the last entries in the output lists and make the plot."
   ]
  },
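  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why the weighting matters, here is a small worked example with made-up numbers\n",
    "(illustrative only, not results from this notebook): two full minibatches of 256 entries\n",
    "and a final partial minibatch of 16 entries.\n",
    "The plain mean of the minibatch accuracies lets the small minibatch count as much as a full one,\n",
    "whereas the size-weighted sum divided by the entry counter equals\n",
    "the fraction of correctly classified entries among all entries seen.\n",
    "\n",
    "```python\n",
    "# Hypothetical per-minibatch accuracies and sizes (illustrative numbers only).\n",
    "minibatch_accuracies = [0.875, 0.90625, 0.5]  # 224/256, 232/256, 8/16\n",
    "minibatch_sizes = [256, 256, 16]\n",
    "\n",
    "# Unweighted mean: every minibatch counts equally, regardless of its size.\n",
    "naive_mean = sum(minibatch_accuracies) / len(minibatch_accuracies)\n",
    "\n",
    "# Weighted mean as described above: accumulate accuracy * size,\n",
    "# keep an entry counter, and divide at evaluation time.\n",
    "weighted_sum = 0.0\n",
    "entry_counter = 0\n",
    "for accuracy, size in zip(minibatch_accuracies, minibatch_sizes):\n",
    "    weighted_sum += accuracy * size\n",
    "    entry_counter += size\n",
    "weighted_mean = weighted_sum / entry_counter\n",
    "\n",
    "print(f\"{naive_mean:.4f}\")  # 0.7604\n",
    "print(f\"{weighted_mean:.4f}\")  # 0.8788, i.e. (224 + 232 + 8) / 528\n",
    "```"
   ]
  },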
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_logistic_regression(\n",
    "    config: dict,\n",
    "    label_num: int,\n",
    "    train_dataloader: Generator[tuple[torch.Tensor, torch.Tensor]],\n",
    "    valid_features: torch.Tensor,\n",
    "    valid_labels: torch.Tensor,\n",
    "    use_bias=True\n",
    ") -> dict:\n",
    "    \"\"\"\n",
    "    Train a logistic regression model\n",
    "    and record accuracy and cross-entropy values regularly.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    config : dict\n",
    "        Configuration dictionary. Required keys:\n",
    "        learning_rate : float\n",
    "            The learning rate used in SGD.\n",
    "        steps_num : int\n",
    "            The number of SGD steps to take\n",
    "        valid_interval : int\n",
    "            The frequency of evaluations given in training steps.\n",
    "    label_num : int\n",
    "        The number of distinct labels.\n",
    "    train_dataloader : Generator[tuple[torch.Tensor, torch.Tensor]]\n",
    "        A dataloader that output `minibatch_features, minibatch_labels`\n",
    "        pairs indefinitely.\n",
    "    valid_features : torch.Tensor\n",
    "        Validation feature matrix.\n",
    "    valid_labels : torch.Tensor\n",
    "        Validation label vector.\n",
    "    use_bias : bool, optional\n",
    "        Whether to use a bias vector in the logistic regression model.\n",
    "        Default: `True`\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    A dictionary with the following keys and values:\n",
    "    train_accuracies : list[float]\n",
    "        A list of accuracies on minibatches\n",
    "        averaged on evaluation intervals.\n",
    "    train_losses : list[float]\n",
    "        A list of cross-entropy values on minibatches\n",
    "        averaged on evaluation intervals.\n",
    "    valid_accuracies : list[float]\n",
    "        A list of validation accuracies measured each evaluation.\n",
    "    valid_losses : list[float]\n",
    "        A list of validation cross-entropies measured each evaluation.\n",
    "    valid_steps : list[int]\n",
    "        The number of train steps taken at the time of each evaluation.\n",
    "    weights : torch.Tensor\n",
    "        The logistic model weights we get at the end of training.\n",
    "    bias : torch.Tensor, optional\n",
    "        The logistic model bias we get at the end of training, if `use_bias`.\n",
    "    \"\"\"\n",
    "    (\n",
    "        train_accuracies_steps,\n",
    "        train_accuracies,\n",
    "        train_losses,\n",
    "        train_losses_steps,\n",
    "        valid_accuracies,\n",
    "        valid_losses,\n",
    "        valid_steps\n",
    "    ) = [], [], [], [], [], [], []\n",
    "    entry_counter_since_evaluation = 0\n",
    "    weights = torch.zeros(\n",
    "        (valid_features.shape[1], label_num),\n",
    "        device=valid_features.device,\n",
    "        dtype=valid_features.dtype,\n",
    "        requires_grad=True\n",
    "    )\n",
    "    if use_bias:\n",
    "        bias = torch.zeros_like(weights[0])\n",
    "\n",
    "    optimizer = torch.optim.SGD([weights], lr=config[\"learning_rate\"])\n",
    "\n",
    "    step_id = 0\n",
    "    progress_bar = tqdm.trange(config[\"steps_num\"])\n",
    "    for minibatch_features, minibatch_labels in train_dataloader:\n",
    "        if step_id and step_id % config[\"valid_interval\"] == 0:\n",
    "            for source, target in (\n",
    "                (train_accuracies_steps, train_accuracies),\n",
    "                (train_losses_steps, train_losses)\n",
    "            ):\n",
    "                target.append(sum(source) / entry_counter_since_evaluation)\n",
    "                source.clear()\n",
    "\n",
    "            entry_counter_since_evaluation = 0\n",
    "\n",
    "            with torch.no_grad():\n",
    "                logits = valid_features @ weights\n",
    "                if use_bias:\n",
    "                    logits = logits + bias\n",
    "\n",
    "            valid_accuracies.append(get_accuracy(\n",
    "                logits,\n",
    "                valid_labels\n",
    "            ))\n",
    "            valid_losses.append(F.cross_entropy(\n",
    "                logits,\n",
    "                valid_labels\n",
    "            ).cpu().item())\n",
    "            valid_steps.append(step_id)\n",
    "\n",
    "        step_id += 1\n",
    "        progress_bar.update()\n",
    "\n",
    "        if step_id > config[\"steps_num\"]:\n",
    "            break\n",
    "\n",
    "        minibatch_size = len(minibatch_features)\n",
    "        entry_counter_since_evaluation += minibatch_size\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        logits = minibatch_features @ weights\n",
    "        if use_bias:\n",
    "            logits = logits + bias\n",
    "\n",
    "        train_accuracies_steps.append(get_accuracy(\n",
    "            logits.detach(),\n",
    "            minibatch_labels\n",
    "        ) * minibatch_size)\n",
    "        loss = F.cross_entropy(logits, minibatch_labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        train_losses_steps.append(loss.detach().cpu().item() * minibatch_size)\n",
    "\n",
    "    progress_bar.close()\n",
    "\n",
    "    return {\n",
    "        \"train_accuracies\": train_accuracies,\n",
    "        \"train_losses\": train_losses,\n",
    "        \"valid_accuracies\": valid_accuracies,\n",
    "        \"valid_losses\": valid_losses,\n",
    "        \"valid_steps\": valid_steps,\n",
    "        \"weights\": weights\n",
    "    }\n",
    "\n",
    "output = train_logistic_regression(\n",
    "    config,\n",
    "    10,\n",
    "    dataloader_random_reshuffling,\n",
    "    valid_features,\n",
    "    valid_labels\n",
    ")\n",
    "\n",
    "for key in [\n",
    "    \"train_accuracies\",\n",
    "    \"train_losses\",\n",
    "    \"valid_accuracies\",\n",
    "    \"valid_losses\",\n",
    "    \"valid_steps\"\n",
    "]:\n",
    "    print(key, output[key][-1])\n",
    "\n",
    "plot_output(\n",
    "    output,\n",
    "    title=\"Logistic regression by SGD without replacement on MNIST\"\n",
    ")"
   ]
  },
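  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A note on the optional bias: `torch.zeros_like` copies the shape, dtype and device of its argument\n",
    "but not its `requires_grad` flag, so a bias created from `weights[0]` only receives gradients\n",
    "if it requests them explicitly, and it is only updated if it is also passed to the optimizer.\n",
    "A minimal, self-contained check on toy tensors (the shapes and values are arbitrary assumptions):\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "weights = torch.zeros((3, 2), requires_grad=True)\n",
    "\n",
    "# zeros_like does not inherit requires_grad from weights[0].\n",
    "bias_untracked = torch.zeros_like(weights[0])\n",
    "print(bias_untracked.requires_grad)  # False\n",
    "\n",
    "# Request gradients explicitly so that the bias can be trained.\n",
    "bias = torch.zeros_like(weights[0], requires_grad=True)\n",
    "print(bias.requires_grad)  # True\n",
    "\n",
    "# Both tensors have to be handed to the optimizer to be updated.\n",
    "optimizer = torch.optim.SGD([weights, bias], lr=1.0)\n",
    "features = torch.ones((4, 3))\n",
    "labels = torch.tensor([0, 0, 0, 1])\n",
    "loss = F.cross_entropy(features @ weights + bias, labels)\n",
    "loss.backward()\n",
    "optimizer.step()\n",
    "# With all-zero logits, one step moves the bias towards the majority class 0.\n",
    "print(bias.detach().tolist())  # approximately [0.25, -0.25]\n",
    "```"
   ]
  },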
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Datasets\n",
    "\n",
    "## MNIST\n",
    "\n",
    "https://huggingface.co/datasets/ylecun/mnist"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# License\n",
    "\n",
    "This work is licensed under Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International. To view a copy of this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dml",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
