{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup\n",
    "\n",
    "## Imports\n",
    "\n",
    "Import some old friends: `defaultdict`, `Generator`, `datasets`, `matplotlib` as `mpl`, `matplotlib.pyplot` as `plt`, `torch`, `torch.nn.Functional` as `F` and `tqdm`\n",
    "\n",
    "and some new ones:\n",
    "1. From `collections.abc`, also import `Callable` and `Iterable`. These are type hints for objects with the respective functionalities.\n",
    "2. We will load the tf-idf and truncated SVD algorithms from the scikit-learn library:  \n",
    "https://scikit-learn.org/stable/index.html  \n",
    "This is a collection of machine learning tools. As a Python module, it's called `sklearn`.\n",
    "    1. From `sklearn.decomposition` import `TruncatedSVD`.\n",
    "    2. From `sklearn.feature_extraction.text` import `TfidfVectorizer`.\n",
    "\n",
    "Moreover, import some functions you created in Notebook 0221:\n",
    "1. `get_dataloader_random_reshuffle`\n",
    "2. `line_plot_confidence_band`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "from collections.abc import Callable, Generator, Iterable\n",
    "import datasets\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.decomposition import TruncatedSVD\n",
    "from sklearn.feature_extraction.text import TfidfVectorizer\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import tqdm\n",
    "\n",
    "from util_0226 import (\n",
    "    get_dataloader_random_reshuffle,\n",
    "    line_plot_confidence_band\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Constants\n",
    "\n",
    "Create a configuration dictionary with the following keys:\n",
    "- `\"dataset_path\"`: `str`  \n",
    "    Get this from the dataset page  \n",
    "    https://huggingface.co/datasets/stanfordnlp/imdb  \n",
    "    We'll look on the page some more soon.\n",
    "- `\"device\"`: `torch.device | int | str`  \n",
    "    The device identifier, explained in notebook 091224.\n",
    "- `\"ensemble_shape\"`: `tuple[int]`  \n",
    "    Make this a `(10,)`, as for now, we'll train one ensemble with the same hyperparameters.\n",
    "- `\"improvement_threshold:`: `float`  \n",
    "    Make this `1e-4`.\n",
    "- `\"labels_dtype\"`: `torch.dtype`  \n",
    "    The datatype we use for label tensors. With multiclass classification, this was `torch.int64`. For binary classification, which is the case here, we want `torch.float32`.\n",
    "- `\"minibatch_size\"`: `int`  \n",
    "    Make this a `256`.\n",
    "- `\"n_components\"`: `int`  \n",
    "    The number of feature dimensions to find with truncated SVD. Let's start with `100`; it will be homework to check out other options.\n",
    "- `\"seed\"`: `int`  \n",
    "    This is for reproducible experiments. Insert any integer.\n",
    "- `\"steps_num\"`: `int`  \n",
    "    Make this a `100_000`. Let early stopping take care of stopping.\n",
    "- `\"steps_without_improvement`: `int`  \n",
    "    Make this `1000`.\n",
    "- `\"valid_interval\"` : `int`  \n",
    "    Make this a `10`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = {\n",
    "    \"dataset_path\": \"stanfordnlp/imdb\",\n",
    "    \"device\": \"cpu\",\n",
    "    \"ensemble_shape\": (10,),\n",
    "    \"improvement_threshold\": 1e-4,\n",
    "    \"labels_dtype\": torch.float32,\n",
    "    \"learning_rate\": 1,\n",
    "    \"n_components\": 100,\n",
    "    \"minibatch_size\": 256,\n",
    "    \"seed\": 1,\n",
    "    \"steps_num\": 100_000,\n",
    "    \"steps_without_improvement\": 1000,\n",
    "    \"valid_interval\": 10\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set the seed of `torch` pseudo-random generation to the seed in the configuration dictionary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(config[\"seed\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generating Random Integers\n",
    "\n",
    "In this notebook, we will need to provide randomized functions outside of `torch` seeds. For this, we can use the `torch.randint` function.\n",
    "1. As positional argument, give an upper bound for the integers to generate. By default, I like to use `1 << 31`. This is $2^{31}$, expressed via the [left bit shift operation](https://docs.python.org/3/reference/expressions.html#shifting-operations).\n",
    "2. Via setting the `size` keyword argument to `()`, you can specify that you want to receive a scalar value (0-dimensional shape).\n",
    "3. Feed the output of `torch.randint` to the `int` function to receive a Python integer and not a `torch` integer tensor.\n",
    "\n",
    "Write the function below and print its output."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_seed(\n",
    "    upper=1 << 31\n",
    ") -> int:\n",
    "    \"\"\"\n",
    "    Generates a random integer by the `torch` PRNG,\n",
    "    to be used as seed in a stochastic function.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    upper : int, optional\n",
    "        Exclusive upper bound of the interval to generate integers from.\n",
    "        Default: 1 << 31.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    A random integer.\n",
    "    \"\"\"\n",
    "    return int(torch.randint(upper, size=()))\n",
    "\n",
    "print(get_seed())"
   ]
  },
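  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The snippet below is a minimal sketch, not part of the exercise, of why drawing seeds from the `torch` PRNG keeps experiments reproducible: after resetting the `torch` seed, the same sequence of integers comes out again. The helper name `draw_seed` and the seed value `0` are assumptions made for this illustration; `draw_seed` simply mirrors `get_seed` above.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "\n",
    "def draw_seed(upper: int = 1 << 31) -> int:\n",
    "    # Hypothetical stand-in for `get_seed`:\n",
    "    # a scalar random integer from [0, upper), converted to a Python int.\n",
    "    return int(torch.randint(upper, size=()))\n",
    "\n",
    "\n",
    "# The left bit shift by 31 is the same as raising 2 to the 31st power.\n",
    "print((1 << 31) == 2 ** 31)\n",
    "\n",
    "# Draw three seeds, reset the `torch` seed, then draw three seeds again.\n",
    "torch.manual_seed(0)\n",
    "first = [draw_seed() for _ in range(3)]\n",
    "torch.manual_seed(0)\n",
    "second = [draw_seed() for _ in range(3)]\n",
    "\n",
    "# Both runs give the same sequence of seeds.\n",
    "print(first == second)\n",
    "```\n",
    "\n",
    "Because every seed is derived from the single `torch` seed, fixing `config[\"seed\"]` should also fix the randomized steps that receive their seeds this way."
   ]
  },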
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Preprocessing\n",
    "\n",
    "## Perusing the Dataset Page\n",
    "\n",
    "Let's gather information from the dataset page:  \n",
    "https://huggingface.co/datasets/stanfordnlp/imdb\n",
    "\n",
    "1. We can see that there are 3 splits: `\"train\"`, `\"test\"` and `\"unsupervised\"`.\n",
    "    1. The first two have binary labels, with 25 000 entries each.\n",
    "    2. The third has 50 000 entries with inputs only. This is for pretraining methods to get better feature vectors; we'll see more of that later.\n",
    "2. The dataset has 2 columns: `\"text\"`, which are the inputs and `\"label`\", that gives the labels. Fortunately, here, the labels are already the integers 0 and 1, so we don't have to transform a categorical variable as in the abalone dataset."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loading and Splitting the Training Set\n",
    "\n",
    "Load the training split. As both the training and test splits have 25 000 entries, we shouldn't split as many entries off as validation set as there are in the test set. Instead, make a 90%-10% training-validation split on the training set. From now on, we'll call this 22 500 entry set the training set and the remaining 2500 entry set the validation set.\n",
    "\n",
    "As seed in the splitting function, use an output of `get_seed`.\n",
    "\n",
    "Print the training and validation splits and the first entry in each."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train = datasets.load_dataset(\n",
    "    config[\"dataset_path\"],\n",
    "    split=\"train\"\n",
    ")\n",
    "train_valid = train.train_test_split(.1, seed=get_seed())\n",
    "train, valid = (train_valid[key] for key in (\"train\", \"test\"))\n",
    "print(train)\n",
    "print(train[0])\n",
    "print(valid)\n",
    "print(valid[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Latent Semantic Analysis (LSA)\n",
    "\n",
    "Now we shall perform LSA, the two-stage operation of tf-idf and truncated SVD. Both operations are conducted by `sklearn` *transformers*, not to be confused by the neural network architecture. These transformers are objects with the following methods:\n",
    "1. `fit`: This fits the transformer to the data it receives. For example:\n",
    "    1. A `TfidfVectorizer` builds a vocabulary from the corpus it receives and calculates the inverse document frequencies (idf).\n",
    "    2. A `TruncatedSVD` calculates the dimension reduction matrix $V'$.\n",
    "2. `transform`: With this, the transformer transforms the data it receives using the state it got into via fitting.\n",
    "    1. A `TfidfVectorizer` transforms documents to tf-idf vectors where\n",
    "        1. it uses the vocabulary and idf values it got in fitting and\n",
    "        2. it uses the term frequences (tf) of the documents.\n",
    "    2. A `TruncatedSVD` applies the dimension reduction operator $V'$ on the matrix it receives.\n",
    "3. `fit_transform`: This means apply `fit` then `transform` on the same data.\n",
    "\n",
    "It's important to know about these as with each transformer:\n",
    "1. we want to apply `fit_transform` to the training corpus, then\n",
    "2. we want to apply `transform` to the validation corpus.\n",
    "\n",
    "Note that the `fit_transform` and `transform` methods of `TfidfVectorizer` and `TruncatedSVD` output `numpy.ndarray` objects. These are the basic array objects of `numpy`:  \n",
    "https://numpy.org/  \n",
    "the premier CPU-bound array manipulation library. It does not have GPU support or automatic differentiation, so for our purposes this is a precursor to `torch`.\n",
    "\n",
    "To transform the `numpy.ndarray`s to `torch.Tensor`s, you can use the function `torch.asarray`. Give it the following:\n",
    "1. The array as first position argument.\n",
    "2. The `device` and `dtype` keyword arguments, if needed.\n",
    "\n",
    "As the `TruncatedSVD` transformer uses a fast randomized SVD solver, for reproducibility please set its `random_state` keyword argument to an output of `get_seed`.\n",
    "\n",
    "Write the following function, apply it to the training and validation datasets and print out the shape, device and dtype of the tensors you get."
   ]
  },
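  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before writing the full function, the fit-on-training, transform-on-validation pattern described above can be tried on a toy corpus. This is only a sketch: the two small corpora, the `toy_` variable names, `n_components=2` and `random_state=0` are made up for illustration and are not part of the exercise.\n",
    "\n",
    "```python\n",
    "from sklearn.decomposition import TruncatedSVD\n",
    "from sklearn.feature_extraction.text import TfidfVectorizer\n",
    "\n",
    "# Toy corpora, made up for this illustration.\n",
    "toy_train_corpus = [\n",
    "    \"great movie\",\n",
    "    \"terrible movie\",\n",
    "    \"great acting\",\n",
    "    \"terrible plot\"\n",
    "]\n",
    "toy_valid_corpus = [\"great plot\", \"unseen words only\"]\n",
    "\n",
    "# The vocabulary and idf values come from the training corpus only.\n",
    "toy_tf_idf = TfidfVectorizer()\n",
    "toy_train_tfidf = toy_tf_idf.fit_transform(toy_train_corpus)\n",
    "toy_valid_tfidf = toy_tf_idf.transform(toy_valid_corpus)\n",
    "\n",
    "# The dimension reduction is also fitted on the training data only.\n",
    "toy_svd = TruncatedSVD(n_components=2, random_state=0)\n",
    "toy_train_features = toy_svd.fit_transform(toy_train_tfidf)\n",
    "toy_valid_features = toy_svd.transform(toy_valid_tfidf)\n",
    "\n",
    "print(sorted(toy_tf_idf.vocabulary_))\n",
    "print(toy_train_tfidf.shape, toy_valid_tfidf.shape)\n",
    "print(toy_train_features.shape, toy_valid_features.shape)\n",
    "```\n",
    "\n",
    "The second validation document contains only words outside the fitted vocabulary, so its tf-idf row has no nonzero entries. This is the price of not fitting on validation data, and it is the behavior we want."
   ]
  },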
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def lsa(\n",
    "    config: dict,\n",
    "    training_dataset: datasets.Dataset,\n",
    "    validation_datasets: Iterable[datasets.Dataset] = ()\n",
    ") -> Generator[tuple[torch.Tensor, torch.Tensor]]:\n",
    "    \"\"\"\n",
    "    Fit a composite of a `TfidfVectorizer` and a `TruncatedSVD`\n",
    "    on the corpus at the `\"text\"` key of the training dataset.\n",
    "    Then use this composite to transform the training corpus\n",
    "    and the optional validation corpora to feature matrices.\n",
    "    Also returns the labels in the datasets as tensors.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    config : dict\n",
    "        Configuration dictionary. Required keys:\n",
    "        \"device\" : torch.device\n",
    "            The device to store feature matrices and label vectors on.\n",
    "        \"labels_dtype\" : torch.dtype\n",
    "            The datatype of label vectors.\n",
    "        \"n_components\": int\n",
    "            The number of dimensions to reduce the feature dimensions to\n",
    "            with truncated SVD.\n",
    "    training_dataset : datasets.Dataset\n",
    "        The training dataset. Required keys:\n",
    "        \"text\" : Iterable[str]\n",
    "            The dataset corpus\n",
    "        \"label\" : Iterable[int]\n",
    "            The dataset labels\n",
    "    validation_datasets : Iterable[datasets.Dataset], optional\n",
    "        An iterable of additional datasets,\n",
    "        of the same structure as `training_dataset`.\n",
    "        Default: `()`.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    A generator of pairs of feature matrices and label vectors.\n",
    "    The first pair is the training data.\n",
    "    Then the optional validation data follows.\n",
    "    \"\"\"\n",
    "    tf_idf = TfidfVectorizer()\n",
    "    train_features = tf_idf.fit_transform(training_dataset[\"text\"])\n",
    "\n",
    "    truncated_svd = TruncatedSVD(\n",
    "        n_components=config[\"n_components\"],\n",
    "        random_state=get_seed()\n",
    "    )\n",
    "    train_features = truncated_svd.fit_transform(train_features)\n",
    "\n",
    "    train_features = torch.asarray(\n",
    "        train_features,\n",
    "        device=config[\"device\"],\n",
    "        dtype=torch.float32\n",
    "    )\n",
    "    train_labels = training_dataset.with_format(\n",
    "        \"torch\",\n",
    "        device=config[\"device\"]\n",
    "    )[\"label\"].to(config[\"labels_dtype\"])\n",
    "    \n",
    "    yield train_features, train_labels\n",
    "\n",
    "    for validation_dataset in validation_datasets:\n",
    "        valid_features = tf_idf.transform(validation_dataset[\"text\"])\n",
    "        valid_features = truncated_svd.transform(valid_features)\n",
    "        valid_features = torch.asarray(\n",
    "            valid_features,\n",
    "            device=config[\"device\"],\n",
    "            dtype=torch.float32\n",
    "        )\n",
    "        \n",
    "        valid_labels = validation_dataset.with_format(\n",
    "            \"torch\",\n",
    "            device=config[\"device\"]\n",
    "        )[\"label\"].to(config[\"labels_dtype\"])\n",
    "\n",
    "        yield (valid_features, valid_labels)\n",
    "\n",
    "(train_features, train_labels), (valid_features, valid_labels) = lsa(\n",
    "    config,\n",
    "    train,\n",
    "    [valid]\n",
    ")\n",
    "\n",
    "for t in (\n",
    "    train_features,\n",
    "    train_labels,\n",
    "    valid_features,\n",
    "    valid_labels\n",
    "):\n",
    "    print(t.shape, t.device, t.dtype)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Get a training dataloader and print the shape of the tensors of a minibatch it yields."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataloader = get_dataloader_random_reshuffle(\n",
    "    config,\n",
    "    train_features,\n",
    "    train_labels\n",
    ")\n",
    "for t in next(train_dataloader):\n",
    "    print(t.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training\n",
    "\n",
    "Recall that this time, our logistic regression model will output a single logit value per entry, that of the positive label. That is:\n",
    "1. the weight tensor will have shape `ensemble_shape + (feature_dim, 1)`,\n",
    "1. optionally, the bias tensor will have shape `ensemble_shape + (1, 1)`, and\n",
    "2. thus a logit tensor will have shape `ensemble_shape + (dataset_size, 1)`."
   ]
  },
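  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The shapes listed above can be checked with a short sketch. The sizes below (`10` ensemble members, `32` entries, `100` features) are arbitrary values chosen for illustration, and the feature matrix deliberately has no ensemble dimensions, as is the case for validation features.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "# Arbitrary sizes, chosen for illustration only.\n",
    "ensemble_shape = (10,)\n",
    "dataset_size, feature_dim = 32, 100\n",
    "\n",
    "features = torch.randn(dataset_size, feature_dim)\n",
    "weights = torch.zeros(ensemble_shape + (feature_dim, 1))\n",
    "bias = torch.zeros(ensemble_shape + (1, 1))\n",
    "\n",
    "# Matrix multiplication broadcasts over the leading ensemble dimensions,\n",
    "# and the bias broadcasts over the dataset dimension.\n",
    "logits = features @ weights + bias\n",
    "print(logits.shape)  # torch.Size([10, 32, 1])\n",
    "```\n",
    "\n",
    "Keeping the trailing dimension of size 1 lets the same matrix product serve both the binary case and the multiclass case, where the last dimension would be the number of labels."
   ]
  },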
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Binary Accuracy\n",
    "\n",
    "1. To get a tensor of predicted positive labels:\n",
    "    1. get rid of the last dimension of the logit tensor and\n",
    "    2. check which logit entry is positive.\n",
    "2. The label tensor may not have ensemble dimensions, for example if we're looking at the validation labels. Therefore, use the `broadcast_to` method to extend the tensor to have validation dimensions. Then transform it to a Boolean tensor.\n",
    "3. Calculate the accuracy from the two Boolean tensors at hand.\n",
    "\n",
    "Write the binary accuracy function. Print the binary accuracy of an all 0 logit tensor with respect to a train minibatch label set and the validation labels. Are they what you expect?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_binary_accuracy(\n",
    "    logits: torch.Tensor,\n",
    "    labels: torch.Tensor\n",
    ") -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    Get the binary accuracy between a label and a logit tensor.\n",
    "    It can handle arbitrary ensemble shapes.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    logits : torch.Tensor\n",
    "        The logit tensor. We assume it has shape\n",
    "        `ensemble_shape + (dataset_size, 1)`.\n",
    "    labels : torch.Tensor\n",
    "        The tensor of true labels. We assume it has shape\n",
    "        `(dataset_size,)` or `ensemble_shape + (dataset_size,)`.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The tensor of binary accuracies per ensemble member\n",
    "    of shape `ensemble_shape`.\n",
    "    \"\"\"\n",
    "    predict_positives = logits[..., 0] > 0\n",
    "    true_positives = labels.to(torch.bool)\n",
    "\n",
    "    return (\n",
    "        predict_positives == true_positives\n",
    "    ).to(torch.float32).mean(dim=-1)\n",
    "\n",
    "\n",
    "for labels in (valid_labels, next(train_dataloader)[1]):\n",
    "    print(get_binary_accuracy(\n",
    "        torch.zeros_like(labels)[..., None],\n",
    "        labels\n",
    "    ))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Binary Cross-Entropy\n",
    "\n",
    "We'll use as loss function [`F.binary_cross_entropy_with_logit`](https://pytorch.org/docs/stable/generated/torch.nn.functional.binary_cross_entropy_with_logits.htm). This calculates binary cross-entropy between a logit and a label tensor. That is, you don't have to convert the logits to probabilities beforehand. Besides convenience, this approach is numerically stabler. Please note that as the `F.binary_cross_entropy_with_logit` function does not broadcast, you need to use the `broadcast_to` method of the `labels` tensor to broadcast it to the required shape.\n",
    "\n",
    "Write the binary cross-entropy function. Print the binary cross-entropy of an all 0 logit tensor with respect to a train minibatch label set and the validation labels. Are they what you expect?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_binary_cross_entropy(\n",
    "    logits: torch.Tensor,\n",
    "    labels: torch.Tensor\n",
    ") -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    Get the binary cross-entropy between a label and a logit tensor.\n",
    "    It can handle arbitrary ensemble shapes.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    logits : torch.Tensor\n",
    "        The logit tensor. We assume it has shape\n",
    "        `ensemble_shape + (dataset_size,)`.\n",
    "    labels : torch.Tensor\n",
    "        The tensor of true labels. We assume it has shape\n",
    "        `(dataset_size,)` or `ensemble_shape + (dataset_size, 1)`.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The tensor of binary cross-entropies per ensemble member\n",
    "    of shape `ensemble_shape`.\n",
    "    \"\"\"\n",
    "\n",
    "    return F.binary_cross_entropy_with_logits(\n",
    "        logits[..., 0],\n",
    "        labels.broadcast_to(logits.shape[:-1]),\n",
    "        reduction=\"none\"\n",
    "    ).mean(dim=-1)\n",
    "\n",
    "\n",
    "for labels in (valid_labels, next(train_dataloader)[1]):\n",
    "    print(get_binary_cross_entropy(\n",
    "        torch.zeros_like(labels)[..., None],\n",
    "        labels,\n",
    "    ))"
   ]
  },
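  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check on the two questions above, here is a small sketch with a hand-made label vector. The labels and the ensemble size `3` are made up for illustration. With all 0 logits, every predicted probability is $\\sigma(0) = 1/2$, so the binary cross-entropy is $-\\log(1/2) = \\log 2 \\approx 0.6931$ whatever the labels are, while the accuracy equals the fraction of negative labels, as no logit is strictly positive.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# Hand-made labels: 2 negatives and 3 positives.\n",
    "labels = torch.tensor([0., 1., 1., 0., 1.])\n",
    "# All 0 logits for an ensemble of 3 members.\n",
    "logits = torch.zeros(3, 5, 1)\n",
    "\n",
    "loss = F.binary_cross_entropy_with_logits(\n",
    "    logits[..., 0],\n",
    "    labels.broadcast_to(logits.shape[:-1]),\n",
    "    reduction=\"none\"\n",
    ").mean(dim=-1)\n",
    "accuracy = (\n",
    "    (logits[..., 0] > 0) == labels.to(torch.bool)\n",
    ").to(torch.float32).mean(dim=-1)\n",
    "\n",
    "print(loss, math.log(2))  # each loss entry is log(2)\n",
    "print(accuracy)  # each entry is 2 / 5, the fraction of negative labels\n",
    "```\n",
    "\n",
    "On the roughly balanced IMDB labels, we would therefore expect an accuracy near one half and a binary cross-entropy of about 0.6931 before any training."
   ]
  },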
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Update the Training Loop\n",
    "\n",
    "Make the following changes to the training loop function you achieved in Notebook 0226:\n",
    "1. Give two more positional arguments: `get_loss` and `get_metrics`.\n",
    "    1. We expect these functions to have as inputs:\n",
    "        2. a logit tensor and\n",
    "        1. a label tensor, and\n",
    "\n",
    "    2. they return a loss and metric tensor of shape `ensemble_shape`, respectively.\n",
    "    \n",
    "    Replace the use of `F.cross_entropy` and `get_accuracy` to these.\n",
    "2. Note that the positional argument `label_num` is used to determine the number of output feature dimensions. Therefore, in case we intend to use binary classifiation, we want to make it a 1. To reflect this, we rename the positional argument to `out_features`.\n",
    "2. Add additional keyword arguments\n",
    "    1. `loss_name=\"loss\"` and\n",
    "    2. `metric_name=\"metric\"`.\n",
    "\n",
    "    Then in the update dictionary, replace the keys of the output dictionary that have `\"cross-entropy\"` and `\"accuracy\"` in their names to the appropriate keys using the values of these keyword arguments.\n",
    "\n",
    "Run training on the IMDB dataset with\n",
    "1. `get_binary_cross_entropy` as loss function and\n",
    "2. `get_binary_accuracy` as metric function.\n",
    "\n",
    "Draw plots with confidence bands of training and validation binary cross-entropy and binary accuracy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_logistic_regression(\n",
    "    config: dict,\n",
    "    get_loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],\n",
    "    get_metric: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],\n",
    "    out_features: int,\n",
    "    train_dataloader: Generator[tuple[torch.Tensor, torch.Tensor]],\n",
    "    valid_features: torch.Tensor,\n",
    "    valid_labels: torch.Tensor,\n",
    "    loss_name=\"loss\",\n",
    "    metric_name=\"metric\",\n",
    "    use_bias=True\n",
    ") -> dict:\n",
    "    \"\"\"\n",
    "    Train a logistic regression model on a classification task.\n",
    "    Support model ensembles of arbitrary shape.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    config : dict\n",
    "        Configuration dictionary. Required keys:\n",
    "        ensemble_shape : tuple[int]\n",
    "            The shape of the model ensemble.\n",
    "        improvement_threshold : float\n",
    "            Making the best validation score this much better\n",
    "            counts as an improvement.\n",
    "        learning_rate : float | torch.Tensor\n",
    "            The learning rate of the SGD optimization.\n",
    "            If a tensor, then it should have shape\n",
    "            broadcastable to `ensemble_shape`.\n",
    "            In that case, the members of the ensemble are trained with\n",
    "            different learning rates.\n",
    "        steps_num : int\n",
    "            The maximum number of training steps to take.\n",
    "        steps_without_improvement : int\n",
    "            The maximum number of training steps without improvement to take.\n",
    "        valid_interval : int\n",
    "            The frequency of evaluations,\n",
    "            measured in the number of train steps.\n",
    "    out_features : int\n",
    "        The number of output features.\n",
    "        When training a binary logistic regression model, this should be 1.\n",
    "        Otherwise, this should be\n",
    "        the number of distinct labels in the classification task.\n",
    "    train_dataloader : Generator[tuple[torch.Tensor, torch.Tensor]]\n",
    "        A training minibatch dataloader, that yields pairs of\n",
    "        feature and label tensors indefinitely.\n",
    "        We assume that these have shape\n",
    "        `ensemble_shape + (minibatch_size, feature_dim)`\n",
    "        and `ensemble_shape + (minibatch_size,)`\n",
    "        respectively.\n",
    "    valid_features : torch.Tensor\n",
    "        Validation feature matrix.\n",
    "    valid_labels : torch.Tensor\n",
    "        Validation label vector.\n",
    "    loss_name : str, optional\n",
    "        The name of the loss values in the output dictionary.\n",
    "        Default: \"loss\"\n",
    "    metric_name : str, optional\n",
    "        The name of the metric values in the output dictionary.\n",
    "        Default: \"metric\"\n",
    "    use_bias : bool, optional\n",
    "        Whether to use a bias vector in the logistic regression model.\n",
    "        Default: `True`\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    An output dictionary with the following keys:\n",
    "        best scores : torch.Tensor\n",
    "            The best validation accuracy per each ensemble member\n",
    "        best weights : torch.Tensor\n",
    "            The logistic regression weights\n",
    "            that were the best per each ensemble member.\n",
    "        training {metric_name} : torch.Tensor\n",
    "            The tensor of training metric values, of shape\n",
    "            `(evaluation_num,) + ensemble_shape`.\n",
    "        training {loss_name} : torch.Tensor\n",
    "            The tensor of training loss values, of shape\n",
    "            `(evaluation_num,) + ensemble_shape`.\n",
    "        training steps : list[int]\n",
    "            The list of the number of training steps at each evaluation.\n",
    "        validation {metric_name} : torch.Tensor\n",
    "            The tensor of validation metric values, of shape\n",
    "            `(evaluation_num,) + ensemble_shape`.\n",
    "        validation {loss_name} : torch.Tensor\n",
    "            The tensor of validation loss values, of shape\n",
    "            `(evaluation_num,) + ensemble_shape`.\n",
    "        best bias : torch.Tensor, optional\n",
    "            The logistic regression biases\n",
    "            that were the best per each ensemble member, if used.\n",
    "    \"\"\"\n",
    "    device = valid_features.device\n",
    "    features_dtype = valid_features.dtype\n",
    "    output = defaultdict(list)\n",
    "\n",
    "    best_scores = torch.zeros(\n",
    "        config[\"ensemble_shape\"],\n",
    "        device=device,\n",
    "        dtype=features_dtype\n",
    "    ).log()\n",
    "    steps_without_improvement = 0\n",
    "\n",
    "    if isinstance(config[\"learning_rate\"], torch.Tensor):\n",
    "        learning_rate = config[\"learning_rate\"][..., None, None]\n",
    "    else:\n",
    "        learning_rate = config[\"learning_rate\"]\n",
    "\n",
    "    train_accuracies_step = torch.zeros(\n",
    "        config[\"ensemble_shape\"],\n",
    "        device=device,\n",
    "        dtype=features_dtype\n",
    "    )\n",
    "    train_entries = 0\n",
    "    train_losses_step = torch.zeros(\n",
    "        config[\"ensemble_shape\"],\n",
    "        device=device,\n",
    "        dtype=features_dtype\n",
    "    )\n",
    "\n",
    "    progress_bar = tqdm.trange(config[\"steps_num\"])\n",
    "    step_id = 0\n",
    "    weights = torch.zeros(\n",
    "        config[\"ensemble_shape\"] + (valid_features.shape[1], out_features),\n",
    "        device=device,\n",
    "        dtype=features_dtype,\n",
    "        requires_grad=True\n",
    "    )\n",
    "\n",
    "    best_weights = torch.empty_like(weights, requires_grad=False)\n",
    "\n",
    "    if use_bias:\n",
    "        bias = torch.zeros_like(weights[..., 0:1, :], requires_grad=True)\n",
    "        best_bias = torch.empty_like(bias, requires_grad=False)\n",
    "\n",
    "    for minibatch_features, minibatch_labels in train_dataloader:\n",
    "        minibatch_size = minibatch_labels.shape[-1]\n",
    "        weights.grad = None\n",
    "        if use_bias:\n",
    "            bias.grad = None\n",
    "\n",
    "        logits = minibatch_features @ weights\n",
    "        if use_bias:\n",
    "            logits = logits + bias\n",
    "\n",
    "        train_accuracies_step += get_metric(\n",
    "            logits.detach(),\n",
    "            minibatch_labels\n",
    "        ) * minibatch_size\n",
    "        loss = get_loss(\n",
    "            logits,\n",
    "            minibatch_labels\n",
    "        )\n",
    "        loss.sum().backward()\n",
    "        with torch.no_grad():\n",
    "            weights -= learning_rate * weights.grad\n",
    "            if use_bias:\n",
    "                bias -= learning_rate * bias.grad\n",
    "\n",
    "        train_losses_step += loss.detach() * minibatch_size\n",
    "        train_entries += minibatch_size\n",
    "\n",
    "        progress_bar.update()\n",
    "        step_id += 1\n",
    "        if step_id % config[\"valid_interval\"] == 0:\n",
    "            with torch.no_grad():\n",
    "                logits = valid_features @ weights\n",
    "                if use_bias:\n",
    "                    logits = logits + bias\n",
    "\n",
    "            valid_accuracy = get_metric(\n",
    "                logits,\n",
    "                valid_labels\n",
    "            )\n",
    "\n",
    "            valid_loss = get_loss(\n",
    "                logits,\n",
    "                valid_labels\n",
    "            )\n",
    "\n",
    "            output[f\"training {metric_name}\"].append(\n",
    "                (train_accuracies_step / train_entries)\n",
    "            )\n",
    "            output[f\"training {loss_name}\"].append(\n",
    "                (train_losses_step / train_entries)\n",
    "            )\n",
    "            output[\"training steps\"].append(step_id)\n",
    "            output[f\"validation {metric_name}\"].append(valid_accuracy)\n",
    "            output[f\"validation {loss_name}\"].append(valid_loss)\n",
    "\n",
    "            train_accuracies_step.zero_()\n",
    "            train_entries = 0\n",
    "            train_losses_step.zero_()\n",
    "\n",
    "            improvement = valid_accuracy - best_scores\n",
    "            improvement_mask = improvement > config[\"improvement_threshold\"]\n",
    "\n",
    "            if improvement_mask.any():\n",
    "                best_scores[improvement_mask] \\\n",
    "                    = valid_accuracy[improvement_mask]\n",
    "                best_weights[improvement_mask] = weights[improvement_mask]\n",
    "                steps_without_improvement = 0\n",
    "            else:\n",
    "                steps_without_improvement += config[\"valid_interval\"]\n",
    "\n",
    "            if (\n",
    "                step_id >= config[\"steps_num\"]\n",
    "             or (\n",
    "                    steps_without_improvement\n",
    "                 >= config[\"steps_without_improvement\"]\n",
    "                )  \n",
    "            ):\n",
    "                for key in (\n",
    "                    f\"training {metric_name}\",\n",
    "                    f\"training {loss_name}\",\n",
    "                    f\"validation {metric_name}\",\n",
    "                    f\"validation {loss_name}\"\n",
    "                ):\n",
    "                    output[key] = torch.stack(output[key]).cpu()\n",
    "\n",
    "                output[\"best scores\"] = best_scores\n",
    "                output[\"best weights\"] = best_weights\n",
    "                if use_bias:\n",
    "                    output[\"best_bias\"] = best_bias\n",
    "                progress_bar.close()\n",
    "\n",
    "                return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "output = train_logistic_regression(\n",
    "    config,\n",
    "    get_binary_cross_entropy,\n",
    "    get_binary_accuracy,\n",
    "    1,\n",
    "    train_dataloader,\n",
    "    valid_features,\n",
    "    valid_labels,\n",
    "    \"binary cross-entropy\",\n",
    "    \"binary accuracy\"\n",
    ")\n",
    "\n",
    "for key in (\n",
    "    \"training binary accuracy\",\n",
    "    \"training binary cross-entropy\",\n",
    "    \"validation binary accuracy\",\n",
    "    \"validation binary cross-entropy\"\n",
    "):\n",
    "    line_plot_confidence_band(\n",
    "        output[\"training steps\"],\n",
    "        output[key],\n",
    "        label=key\n",
    "    )\n",
    "\n",
    "plt.legend()\n",
    "plt.title(\"Logistic Regression and Latent Semantic Analysis on IMDB\")\n",
    "plt.xlabel(\"Train steps\")\n",
    "plt.show()\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Datasets\n",
    "\n",
    "## IMDB\n",
    "\n",
    "http://ai.stanford.edu/~amaas/data/sentiment/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# References\n",
    "\n",
    "[1] Karen Spärck Jones: *A statistical interpretation of term specificity and its application in retrieval*. Journal of Documentation, Volume 28, Number 1, 1972, pp. 11-21. [doi:10.1108/eb026526](https://doi.org/10.1108%2Feb026526), [link to paper](https://www.cl.cam.ac.uk/archive/ksj21/ksjdigipapers/jdoc72.pdf)\n",
    "\n",
    "[2] Nathan Halko, Per-Gunnar Martinsson and Joel A. Tropp: *Finding Structure with Randomness: Probabilistic Algorithms for Constructing Approximate Matrix Decompositions*. SIAM Review, Volume 53, Number 2, 2011, pp. 217-288. [doi:10.1137/090771806](https://doi.org/10.1137/090771806), [link to paper](https://arxiv.org/abs/0909.4061)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# License\n",
    "\n",
    "This work is licensed under Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International. To view a copy of this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dml",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
