{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Construction and Training of MLP Ensembles\n",
    "\n",
    "Today, we'll train MLPs on Abalone, to see how much better they can get than least squares."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "### Imports\n",
    "\n",
    "Import `Callable`, `Iterable`, `plt`, `torch`, `torch.nn.functional` as `F`, `tqdm` and `Optional`.\n",
    "\n",
    "Moreover, import the functions `get_dataloader_random_reshuffle` and `line_plot_confidence_band` that you created in Notebooks 0221 and 0219, respectively."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections.abc import (\n",
    "    Callable,\n",
    "    Iterable\n",
    ")\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import tqdm\n",
    "from typing import Optional\n",
    "\n",
    "from util_0312 import (\n",
    "    get_dataloader_random_reshuffle,\n",
    "    line_plot_confidence_band\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Constants\n",
    "\n",
    "Create a configuration dictionary with the following keys:\n",
    "- `\"dataset_preprocessed_path\"`: `str`  \n",
    "    The path to the preprocessed dataset that we'll load. Make this the path to `abalone_preprocessed.pt` that we used previously. [Follow this link](https://www.renyi.hu/~zsamboki/teaching/dml-spring-2025/lab_notebooks/data/abalone_preprocessed.pt) if you need to download it again.\n",
    "- `\"device\"`: `torch.device | int | str`  \n",
    "    The device identifier.\n",
    "- `\"ensemble_shape\"`: `tuple[int]`  \n",
    "    Today, we'll try out 9 learning rates, 8 times each. Thus, make the ensemble shape `(9, 8)`.\n",
    "- `\"improvement_threshold\"`: `float`  \n",
    "    A metric has to be this much better than the previous best, to count as improvement. Make this `1e-4`.\n",
    "- `\"learning_rate\"`: `float | torch.Tensor`  \n",
    "    The learning rate or rates to use. Let's try $10^i$ for $i=-5, -4.5,\\dotsc,-1$. Make the learning rate tensor broadcastable to the ensemble shape.\n",
    "- `\"minibatch_size\"`: `int`  \n",
    "    Make this `256`.\n",
    "- `\"seed\"`: `int`  \n",
    "    This is for reproducible experiments. Insert any integer.\n",
    "- `\"steps_num\"`: `int`  \n",
    "    Today, we'll train models for a fixed number of `10_000` training steps.\n",
    "- `\"valid_interval\"`: `int`  \n",
    "    Make this `100`. Thus, each training of `10_000` steps will get us 100 validation results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = {\n",
    "    \"dataset_preprocessed_path\": \"data/abalone_preprocessed.pt\",\n",
    "    \"device\": \"cuda\",\n",
    "    \"ensemble_shape\": (9, 8),\n",
    "    \"improvement_threshold\": 1e-4,\n",
    "    \"learning_rate\": torch.logspace(\n",
    "        -5,\n",
    "        -1,\n",
    "        9,\n",
    "        device=\"cuda\"\n",
    "    )[:, None],\n",
    "    \"minibatch_size\": 256,\n",
    "    \"seed\": 1,\n",
    "    \"steps_num\": 10_000,\n",
    "    \"valid_interval\": 100\n",
    "}"
   ]
  },
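  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check, the following sketch rebuilds the learning rate grid on the CPU and confirms that it broadcasts to the ensemble shape. The names `ensemble_shape` and `learning_rate` are local stand-ins for the configuration entries, introduced here only for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "ensemble_shape = (9, 8)\n",
    "# 9 exponents from -5 to -1 in steps of 0.5, one per ensemble row\n",
    "learning_rate = torch.logspace(-5, -1, ensemble_shape[0])[:, None]\n",
    "\n",
    "print(learning_rate.shape)\n",
    "print(torch.log10(learning_rate).flatten())\n",
    "# Broadcasting repeats each learning rate across the 8 repetitions\n",
    "print(torch.broadcast_to(learning_rate, ensemble_shape).shape)"
   ]
  },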
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set the `torch` pseudo-random number generation seed as per the configuration dictionary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(config[\"seed\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load and Preprocess Dataset\n",
    "\n",
    "Load the preprocessed dataset via `torch.load`. You'll get a dictionary. Print its keys as a list."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loaded = torch.load(config[\"dataset_preprocessed_path\"], weights_only=True)\n",
    "print(list(loaded))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. Assign the train and test features and targets to variables. Move the tensors to the device specified in the configuration dictionary.\n",
    "2. We'll make a more generic MSE loss today that is compatible with regression on vector-valued targets. Thus, add an extra dimension of size 1 to the right of the shape of the train and test targets.\n",
    "3. Print the shapes of the feature and target matrices."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "(\n",
    "    train_features,\n",
    "    train_values,\n",
    "    test_features,\n",
    "    test_values\n",
    ") = (\n",
    "    loaded[key].to(config[\"device\"])\n",
    "    for key in (\n",
    "        \"features_train\",\n",
    "        \"targets_train\",\n",
    "        \"features_test\",\n",
    "        \"targets_test\"\n",
    "    )\n",
    ")\n",
    "\n",
    "train_values, test_values = (\n",
    "    t[:, None]\n",
    "    for t in (train_values, test_values)\n",
    ")\n",
    "\n",
    "for t in (\n",
    "    train_features,\n",
    "    train_values,\n",
    "    test_features,\n",
    "    test_values\n",
    "):\n",
    "    print(t.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's normalize the train features, that is, make the componentwise sample mean 0 and the componentwise sample std 1:\n",
    "1. Calculate the componentwise sample mean and subtract it from the train features.\n",
    "2. Calculate the componentwise sample std and divide the train features by it.\n",
    "3. Transform the test features the same way:\n",
    "    1. subtract the train sample mean and\n",
    "    2. divide by the train sample std.\n",
    "4. Print the componentwise mean and std of the transformed train features."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_mean = train_features.mean(dim=0)\n",
    "train_features -= sample_mean\n",
    "test_features -= sample_mean\n",
    "\n",
    "sample_std = train_features.std(dim=0)\n",
    "train_features /= sample_std\n",
    "test_features /= sample_std\n",
    "\n",
    "print(train_features.mean(dim=0), train_features.std(dim=0))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's split a validation set off the train set. Once you're done, print the shapes of the train and validation feature and target matrices to see if all's well."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_val_indices = torch.randperm(len(train_features))\n",
    "(\n",
    "    (train_features, train_values),\n",
    "    (valid_features, valid_values)\n",
    ") = (\n",
    "    (train_features[indices], train_values[indices])\n",
    "    for indices in (\n",
    "        train_val_indices[:-len(test_features)],\n",
    "        train_val_indices[-len(test_features):]\n",
    "    )\n",
    ")\n",
    "\n",
    "for t in (\n",
    "    train_features,\n",
    "    train_values,\n",
    "    valid_features,\n",
    "    valid_values\n",
    "):\n",
    "    print(t.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Make a train dataloader. Get a minibatch and print the shapes of the feature and value matrices."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataloader = get_dataloader_random_reshuffle(\n",
    "    config,\n",
    "    train_features,\n",
    "    train_values\n",
    ")\n",
    "for t in next(train_dataloader):\n",
    "    print(t.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Constructing MLPs\n",
    "\n",
    "### Writing an Ensemble-Ready `Linear` Class\n",
    "\n",
    "The module `torch.nn.Linear` is unable to handle vectorized ensembles. Moreover, we'd like to initialize the modules ourselves. Finally, it is good to know how to write a custom module.\n",
    "\n",
    "Follow this link for a quick introduction to or refresher on Python classes:  \n",
    "https://exercism.org/tracks/python/concepts/classes\n",
    "\n",
    "First of all, make the class you define a child of `torch.nn.Module`.\n",
    "\n",
    "In what follows, I'll describe the two methods you need to define."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### `__init__`\n",
    "\n",
    "The `__init__` method of each `Module` has to start with a call to `Module.__init__`. You can do this with `super().__init__()`.\n",
    "\n",
    "Now, we'll need to give the weight and optionally the bias parameters. That is, we need to create attributes that are tensors wrapped in `torch.nn.Parameter`. The latter ensures that the parameters are registered and thus for example they show up in the `parameters` and the `state_dict`.\n",
    "\n",
    "We'll want to initialize the weight parameter using a normal distribution. As we saw in Notebook 101424, you can do this by creating a tensor with `torch.empty`, then calling its `normal_` method."
   ]
  },
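  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what wrapping a tensor in `torch.nn.Parameter` changes, here is a minimal sketch. `ToyScale` is a hypothetical toy module that is not part of the exercise; it has one registered parameter and one plain tensor attribute, and only the former shows up in the state dictionary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "class ToyScale(torch.nn.Module):\n",
    "    # Hypothetical toy module, only here to illustrate registration\n",
    "    def __init__(self, size: int):\n",
    "        super().__init__()\n",
    "        # Wrapped in Parameter: registered and saved in the state dict\n",
    "        self.weight = torch.nn.Parameter(\n",
    "            torch.empty(size).normal_(std=size ** -.5)\n",
    "        )\n",
    "        # Plain tensor attribute: not registered anywhere\n",
    "        self.offset = torch.zeros(size)\n",
    "\n",
    "    def forward(self, features: torch.Tensor) -> torch.Tensor:\n",
    "        return features * self.weight + self.offset\n",
    "\n",
    "\n",
    "toy = ToyScale(4)\n",
    "print(list(toy.state_dict()))\n",
    "print([name for name, _ in toy.named_parameters()])"
   ]
  },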
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### `forward`\n",
    "\n",
    "`Module`s are `Callable`s. That is, they have a `__call__` method, by way of which you can call instances as functions. However, for technical reasons, in the case of a `Module`, you should not write the `__call__` method itself, but the `forward` method. Apart from this difference, just like in the case of `__call__`, the positional and keyword arguments after `self` become the positional and keyword arguments of the call.\n",
    "\n",
    "We expect input `features: torch.Tensor` of two possible shapes:\n",
    "1. `ensemble_shape + batch_shape + (in_features,)`. In this case, for each tuple of ensemble indices, the input is processed by the parameters of that ensemble member.\n",
    "2. `batch_shape + (in_features,)`. In this case, the input is broadcast to the above case.\n",
    "\n",
    "To decide which is the case:\n",
    "1. You can get `ensemble_shape` as all but the last two entries of the shape of the weight parameter.\n",
    "2. The number of entries in `ensemble_shape` is the ensemble dimension.\n",
    "3. Then you can check whether the first ensemble-dimension-many entries of the shape of the input tensor are equal to `ensemble_shape`.\n",
    "4. You can use this information to determine the number of batch dimensions from the number of dimensions of the shape of the input tensor.\n",
    "\n",
    "Recall that the weight parameter has shape `ensemble_shape + (in_features, out_features)`. Now, locally, you need to create a view of this tensor that is reshaped to the concatenation of:\n",
    "1. `ensemble_shape`,\n",
    "2. batch dimension minus 1 many 1's for broadcasting over the batch shape and\n",
    "3. `(in_features, out_features)`.\n",
    "\n",
    "With this, you can perform the batched matrix product of the input tensor and the reshaped weight parameter. If there is no bias, you can return the result.\n",
    "\n",
    "Otherwise, locally, you need to create a view of the bias parameter that is reshaped to the concatenation of:\n",
    "1. `ensemble_shape`,\n",
    "2. batch dimension many 1's for broadcasting over the batch shape and\n",
    "3. `(out_features,)`.\n",
    "\n",
    "With this, you can add to the previous result the reshaped bias parameter. Then return the result."
   ]
  },
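  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The shape bookkeeping described above can be tried out on toy tensors before writing the module. The following is a sketch under assumed toy sizes (an ensemble of shape `(3, 2)`, 5 input and 4 output features); the local names are illustrative and the tensors are plain CPU tensors rather than module parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "ensemble_shape = (3, 2)\n",
    "in_features, out_features = 5, 4\n",
    "weight = torch.randn(ensemble_shape + (in_features, out_features))\n",
    "bias = torch.zeros(ensemble_shape + (out_features,))\n",
    "\n",
    "for features in (\n",
    "    # one minibatch shared by all ensemble members\n",
    "    torch.randn(7, in_features),\n",
    "    # one minibatch per ensemble member\n",
    "    torch.randn(ensemble_shape + (7, in_features)),\n",
    "    # per-member input with a two-dimensional batch shape\n",
    "    torch.randn(ensemble_shape + (6, 7, in_features))\n",
    "):\n",
    "    ensemble_dim = len(ensemble_shape)\n",
    "    ensemble_input = tuple(features.shape[:ensemble_dim]) == ensemble_shape\n",
    "    batch_dim = features.ndim - 1 - ensemble_dim * ensemble_input\n",
    "\n",
    "    # Singleton axes let the parameters broadcast over the batch shape\n",
    "    weight_view = weight.reshape(\n",
    "        ensemble_shape + (1,) * (batch_dim - 1) + weight.shape[-2:]\n",
    "    )\n",
    "    bias_view = bias.reshape(\n",
    "        ensemble_shape + (1,) * batch_dim + bias.shape[-1:]\n",
    "    )\n",
    "    output = features @ weight_view + bias_view\n",
    "    print(tuple(features.shape), \"->\", tuple(output.shape))"
   ]
  },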
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Write the `Linear` module as described. Create an instance of it with as many input features as there are in the dataset and one output feature. Then print the shape of the output of applying the model to:\n",
    "1. a training minibatch feature tensor and\n",
    "2. the validation feature matrix.\n",
    "\n",
    "Also, print a list of the keys of the state dictionary of the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Linear(torch.nn.Module):\n",
    "    \"\"\"\n",
    "    Ensemble-ready affine transformation `y = x^T W + b`.\n",
    "\n",
    "    Arguments\n",
    "    ---------\n",
    "    config : `dict`\n",
    "        Configuration dictionary. Required key-value pairs:\n",
    "        `\"device\"` : `str`\n",
    "            The device to store parameters on.\n",
    "        `\"ensemble_shape\"` : `tuple[int]`\n",
    "            The shape of the ensemble of affine transformations\n",
    "            the model represents.\n",
    "    in_features : `int`\n",
    "        The number of input features.\n",
    "    out_features : `int`\n",
    "        The number of output features.\n",
    "    bias : `bool`, optional\n",
    "        Whether the model should include bias. Default: `True`.\n",
    "    init_multiplier : `float`, optional\n",
    "        The weight parameter values are initialized following\n",
    "        a normal distribution with center 0 and std\n",
    "        `in_features ** -.5` times this value. Default: `1.`\n",
    "\n",
    "    Calling\n",
    "    -------\n",
    "    Instance calls require one positional argument:\n",
    "    features : `torch.Tensor`\n",
    "        The input tensor. It is required to be one of the following shapes:\n",
    "        1. `ensemble_shape + batch_shape + (in_features,)`\n",
    "        2. `batch_shape + (in_features,)`\n",
    "\n",
    "        Upon a call, the model assumes we're in the first case\n",
    "        if the first `len(ensemble_shape)` many entries of the\n",
    "        shape of the input tensor are `ensemble_shape`.\n",
    "    \"\"\"\n",
    "    def __init__(\n",
    "        self,\n",
    "        config: dict,\n",
    "        in_features: int,\n",
    "        out_features: int,\n",
    "        bias=True,\n",
    "        init_multiplier=1.\n",
    "    ):\n",
    "        super().__init__()\n",
    "\n",
    "        if bias:\n",
    "            self.bias = torch.nn.Parameter(torch.zeros(\n",
    "                config[\"ensemble_shape\"] + (out_features,),\n",
    "                device=config[\"device\"],\n",
    "                dtype=torch.float32\n",
    "            ))\n",
    "        else:\n",
    "            self.bias = None\n",
    "\n",
    "        self.weight = torch.nn.Parameter(torch.empty(\n",
    "            config[\"ensemble_shape\"] + (in_features, out_features),\n",
    "            device=config[\"device\"],\n",
    "            dtype=torch.float32\n",
    "        ).normal_(std=in_features ** -.5) * init_multiplier)\n",
    "\n",
    "\n",
    "    def forward(\n",
    "        self,\n",
    "        features: torch.Tensor\n",
    "    ) -> torch.Tensor:\n",
    "        ensemble_shape = self.weight.shape[:-2]\n",
    "        ensemble_dim = len(ensemble_shape)\n",
    "        ensemble_input = features.shape[:ensemble_dim] == ensemble_shape\n",
    "        batch_dim = len(features.shape) - 1 - ensemble_dim * ensemble_input\n",
    "        \n",
    "        # (*e, *b, i) @ (*e, *b[:-1], i, o)\n",
    "        weight = self.weight.reshape(\n",
    "            ensemble_shape\n",
    "          + (1,) * (batch_dim - 1)\n",
    "          + self.weight.shape[-2:]\n",
    "        )\n",
    "        features = features @ weight\n",
    "\n",
    "        if self.bias is None:\n",
    "            return features\n",
    "        \n",
    "        # (*e, *b, o) + (*e, *b, o)\n",
    "        bias = self.bias.reshape(\n",
    "            ensemble_shape\n",
    "          + (1,) * batch_dim\n",
    "          + self.bias.shape[-1:]\n",
    "        )\n",
    "        features = features + bias\n",
    "\n",
    "        return features\n",
    "    \n",
    "model = Linear(config, train_features.shape[-1], 1)\n",
    "for t in (next(train_dataloader)[0], valid_features):\n",
    "    print(model(t).shape)\n",
    "\n",
    "print(list(model.state_dict()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### ReLU MLP Factory Function\n",
    "\n",
    "In order to conveniently initialize MLPs, we write a so-called *factory function*. Each call will return an MLP.\n",
    "\n",
    "We will use `torch.nn.Sequential` to create an MLP. Therefore, we'll make a list of layers that we'll then unpack as positional arguments to `Sequential`.\n",
    "\n",
    "When you're creating the affine transformations that precede ReLUs, make `init_multiplier` $\\sqrt2$ for Kaiming initialization.\n",
    "\n",
    "Create an MLP of 2 hidden layers of 32 dimensions, with input dimensions like in the dataset and 1 output dimension. Print the shape of its output when you give it\n",
    "1. a minibatch feature tensor and\n",
    "2. the validation feature matrix."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_mlp(\n",
    "    config: dict,\n",
    "    in_features: int,\n",
    "    out_features: int,\n",
    "    hidden_layer_num: Optional[int] = None,\n",
    "    hidden_layer_size: Optional[int] = None,\n",
    "    hidden_layer_sizes: Optional[Iterable[int]] = None,\n",
    ") -> torch.nn.Sequential:\n",
    "    \"\"\"\n",
    "    Creates an MLP with ReLU activation functions.\n",
    "    Can create a model ensemble.\n",
    "\n",
    "    config : `dict`\n",
    "        Configuration dictionary. Required key-value pairs:\n",
    "        `\"device\"` : `str`\n",
    "            The device to store parameters on.\n",
    "        `\"ensemble_shape\"` : `tuple[int]`\n",
    "            The shape of the ensemble of affine transformations\n",
    "            the model represents.\n",
    "    in_features : `int`\n",
    "        The number of input features.\n",
    "    out_features : `int`\n",
    "        The number of output features.\n",
    "    hidden_layer_num : `int`, optional\n",
    "        If `hidden_layer_sizes` is not given, we create an MLP with\n",
    "        `hidden_layer_num` hidden layers of\n",
    "        `hidden_layer_size` dimensions.\n",
    "    hidden_layer_size : `int`, optional\n",
    "        If `hidden_layer_sizes` is not given, we create an MLP with\n",
    "        `hidden_layer_num` hidden layers of\n",
    "        `hidden_layer_size` dimensions.\n",
    "    hidden_layer_sizes: `Iterable[int]`, optional\n",
    "        If given, each entry gives a hidden layer with the given size.\n",
    "    \"\"\"\n",
    "    if hidden_layer_sizes is None:\n",
    "        hidden_layer_sizes = (hidden_layer_size,) * hidden_layer_num\n",
    "\n",
    "    layers = []\n",
    "    layer_in_size = in_features\n",
    "    for layer_out_size in hidden_layer_sizes:\n",
    "        layers.extend([\n",
    "            Linear(\n",
    "                config,\n",
    "                layer_in_size,\n",
    "                layer_out_size,\n",
    "                init_multiplier=2 ** .5\n",
    "            ),\n",
    "            torch.nn.ReLU()\n",
    "        ])\n",
    "        layer_in_size = layer_out_size\n",
    "    \n",
    "    layers.append(Linear(\n",
    "        config,\n",
    "        layer_in_size,\n",
    "        out_features\n",
    "    ))\n",
    "\n",
    "    return torch.nn.Sequential(*layers)\n",
    "\n",
    "model = get_mlp(config, train_features.shape[-1], 1, 2, 32)\n",
    "for t in (next(train_dataloader)[0], valid_features):\n",
    "    print(model(t).shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## A Simple MLP Training Loop\n",
    "\n",
    "### Ensemble-Ready MSE\n",
    "\n",
    "Let's first write a function that can calculate MSE of ensembles of predicted and true values.\n",
    "\n",
    "After you've written the function, print the shape of the MSE of your randomly initialized model on a training minibatch and on the validation dataset."
   ]
  },
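  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before writing the function, it may help to see the intended reduction on toy tensors. This sketch assumes an ensemble of shape `(3, 2)`, a batch of 5 and a single target dimension; it sums the squared errors over the value dimension and averages over the batch, leaving one loss per ensemble member."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# ensemble_shape (3, 2), batch size 5, values_dim 1\n",
    "predict = torch.zeros(3, 2, 5, 1)\n",
    "# targets without ensemble dimensions are shared by all members\n",
    "target = torch.ones(5, 1)\n",
    "\n",
    "# The subtraction broadcasts the target to shape (3, 2, 5, 1)\n",
    "squared_error = (predict - target) ** 2\n",
    "# Sum over the value dimension, then average over the batch\n",
    "mse = squared_error.sum(dim=-1).mean(dim=-1)\n",
    "print(mse.shape)"
   ]
  },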
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_mse(\n",
    "    predict: torch.Tensor,\n",
    "    target: torch.Tensor\n",
    ") -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    Calculates the MSE between two tensors. Compatible with ensembles.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    predict : `torch.Tensor`\n",
    "        Predicted values. The expected shape is\n",
    "        `ensemble_shape + (batch_size, values_dim)`.\n",
    "    target : `torch.Tensor`\n",
    "        Target values. The expected shape is either\n",
    "        `ensemble_shape + (batch_size, values_dim)` or\n",
    "        `(batch_size, values_dim)`.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The tensor of MSE values, of shape `ensemble_shape`.\n",
    "    \"\"\"\n",
    "    target = target.broadcast_to(predict.shape)\n",
    "    mse = F.mse_loss(predict, target, reduction=\"none\")\n",
    "    return mse.sum(dim=-1).mean(dim=-1)\n",
    "\n",
    "for features, values in (\n",
    "    next(train_dataloader),\n",
    "    (valid_features, valid_values)\n",
    "):\n",
    "    predict = model(features)\n",
    "    print(get_mse(predict, values).shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Writing the Training Loop\n",
    "\n",
    "Let's write the training loop! We still don't want to make it as full-featured as our logistic regression training loop, since next time we'll introduce a very important hyperparameter optimization method that we'll want to incorporate.\n",
    "\n",
    "For now, let's just make `10_000` training steps. In each step, you'll have to perform SGD by hand, as we are using a tensor of different learning rate values.\n",
    "1. At the start of each training step, you can loop over the `parameters` of a `Module` and set each `grad` to `None`.\n",
    "2. When you update the parameters, you can once again loop over the `parameters`. To make sure the learning rate tensor broadcasts correctly, reshape it so that its shape has enough 1's on the right.\n",
    "\n",
    "Moreover, make training and validation loss lists. At each evaluation, append to these lists the MSE values on the full training and validation sets.\n",
    "\n",
    "Write the training function, then run it with the MLP you created. Print the best average losses by learning rate."
   ]
  },
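  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The reshaping of the learning rate tensor in the update step can be sketched in isolation. The parameter shapes below are assumed stand-ins for a weight and a bias of an ensemble of shape `(9, 8)`, and the gradients are set to 1 so that the effect of each learning rate is easy to read off."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "ensemble_shape = (9, 8)\n",
    "# One learning rate per ensemble row, shape (9, 1)\n",
    "learning_rate = torch.logspace(-5, -1, 9)[:, None]\n",
    "\n",
    "# Stand-ins for a weight and a bias parameter with their gradients\n",
    "for parameter_shape in (ensemble_shape + (8, 32), ensemble_shape + (32,)):\n",
    "    parameter = torch.zeros(parameter_shape)\n",
    "    grad = torch.ones(parameter_shape)\n",
    "\n",
    "    # Pad the learning rate shape with 1's on the right\n",
    "    # until it has as many dimensions as the parameter\n",
    "    learning_rate_view = learning_rate.reshape(\n",
    "        learning_rate.shape\n",
    "      + (parameter.ndim - learning_rate.ndim) * (1,)\n",
    "    )\n",
    "    parameter -= learning_rate_view * grad\n",
    "    print(tuple(learning_rate_view.shape), tuple(parameter.shape))"
   ]
  },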
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_supervised(\n",
    "    config: dict,\n",
    "    get_loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],\n",
    "    model: torch.nn.Module,\n",
    "    train_features: torch.Tensor,\n",
    "    train_values: torch.Tensor,\n",
    "    valid_features: torch.Tensor,\n",
    "    valid_values: torch.Tensor\n",
    ") -> tuple[torch.Tensor, torch.Tensor]:\n",
    "    \"\"\"\n",
    "    Train a model on a supervised dataset. Supports ensembles.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    config : dict\n",
    "        Configuration dictionary. Required keys:\n",
    "        ensemble_shape : tuple[int]\n",
    "            The shape of the model ensemble.\n",
    "        learning_rate : float | torch.Tensor\n",
    "            The learning rate of the SGD optimization.\n",
    "            If a tensor, then it should have shape\n",
    "            broadcastable to `ensemble_shape`.\n",
    "            In that case, the members of the ensemble are trained with\n",
    "            different learning rates.\n",
    "        minibatch_size : int\n",
    "            The minibatch size of the reshuffling dataloader to use.\n",
    "        steps_num : int\n",
    "            The number of training steps to take.\n",
    "        valid_interval : int\n",
    "            The frequency of evaluations,\n",
    "            measured in the number of train steps.\n",
    "    get_loss : Callable[[torch.Tensor, torch.Tensor], torch.Tensor]\n",
    "        Function that calculates the loss values of an ensemble\n",
    "        from a predicted and a target value tensor.\n",
    "    model : torch.nn.Module\n",
    "        The model to train.\n",
    "    train_features : torch.Tensor\n",
    "        Training feature matrix.\n",
    "    train_values : torch.Tensor\n",
    "        Training value vector.\n",
    "    valid_features : torch.Tensor\n",
    "        Validation feature matrix.\n",
    "    valid_values : torch.Tensor\n",
    "        Validation value vector.\n",
    "    \"\"\"\n",
    "    learning_rate = torch.asarray(\n",
    "        config[\"learning_rate\"],\n",
    "        device=valid_features.device,\n",
    "        dtype=valid_features.dtype\n",
    "    )\n",
    "    progress_bar = tqdm.trange(config[\"steps_num\"])\n",
    "    train_dataloader = get_dataloader_random_reshuffle(\n",
    "        config,\n",
    "        train_features,\n",
    "        train_values\n",
    "    )\n",
    "    train_losses = []\n",
    "    valid_losses = []\n",
    "\n",
    "    for step_id in progress_bar:\n",
    "        minibatch_features, minibatch_values = next(train_dataloader)\n",
    "        for parameter in model.parameters():\n",
    "            parameter.grad = None\n",
    "\n",
    "        predict = model(minibatch_features)\n",
    "        loss = get_loss(predict, minibatch_values).sum()\n",
    "        loss.backward()\n",
    "        with torch.no_grad():\n",
    "            for parameter in model.parameters():\n",
    "                parameter -= learning_rate.reshape(\n",
    "                    learning_rate.shape\n",
    "                  + (len(parameter.shape) - len(learning_rate.shape))\n",
    "                  * (1,)\n",
    "                ) * parameter.grad\n",
    "        \n",
    "        if step_id % config[\"valid_interval\"] == 0:\n",
    "            with torch.no_grad():\n",
    "                for features, values, losses in (\n",
    "                    (train_features, train_values, train_losses),\n",
    "                    (valid_features, valid_values, valid_losses)\n",
    "                ):\n",
    "                    predict = model(features)\n",
    "                    loss = get_loss(predict, values)\n",
    "                    losses.append(loss)\n",
    "\n",
    "    return tuple((\n",
    "        torch.stack(losses)\n",
    "        for losses in (train_losses, valid_losses)\n",
    "    ))\n",
    "\n",
    "train_losses, valid_losses = train_supervised(\n",
    "    config,\n",
    "    get_mse,\n",
    "    model,\n",
    "    train_features,\n",
    "    train_values,\n",
    "    valid_features,\n",
    "    valid_values\n",
    ")\n",
    "\n",
    "for losses in (train_losses, valid_losses):\n",
    "    print(losses.mean(dim=-1).min(dim=0)[0].cpu().numpy())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You may notice that the best train losses are lower than the best validation losses. Let's plot line plots with confidence bands of the training curves with the best learning rate! Set `plt.ylim` for greater visibility."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for (color, curve, label) in (\n",
    "    (\"blue\", train_losses[:, 4], \"training\"),\n",
    "    (\"red\", valid_losses[:, 4], \"validation\")\n",
    "):\n",
    "    line_plot_confidence_band(\n",
    "        torch.arange(0, 10000, 100),\n",
    "        curve.cpu(),\n",
    "        color=color,\n",
    "        label=label\n",
    "    )\n",
    "\n",
    "plt.ylim(3, 6)\n",
    "plt.show()\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We see that although the training loss keeps decreasing, the validation loss does not. This means that the model *overfits*: it starts memorizing the training set instead of generalizing. One can counteract this tendency with *regularization* methods. We'll introduce such a method soon."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Datasets\n",
    "\n",
    "## Abalone\n",
    "\n",
    "https://archive.ics.uci.edu/dataset/1/abalone  \n",
    "This dataset is licensed under a [Creative Commons Attribution 4.0 International (CC BY 4.0)](https://creativecommons.org/licenses/by/4.0/legalcode) license."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# License\n",
    "\n",
    "This work is licensed under Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International. To view a copy of this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dml",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
