{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Implementing PBT\n",
    "\n",
    "## Setup\n",
    "\n",
    "### Imports\n",
    "\n",
    "Import `defaultdict`, `Callable`, `math`, `plt`, `scipy`, `torch`, `torch.nn.functional` as `F` and `tqdm`.\n",
    "\n",
    "Moreover, import the functions:\n",
    "1. `load_preprocessed_dataset`,  that you wrote in Notebook 0219,\n",
    "2. `get_accuracy`, `get_cross_entropy`, `get_dataloader_random_reshuffle`, and `line_plot_confidence_band`, that you wrote in Notebook 0221 and \n",
    "3. `get_mlp`, that you wrote in Notebook 0319."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Configuration\n",
    "\n",
    "Create a configuration dictionary with the following keys:\n",
    "- `\"dataset_preprocessed_path\"`: `str`  \n",
    "    The path to the preprocessed dataset that we'll load. Make this the path to your preprocessed MNIST.\n",
    "- `\"device\"`: `torch.device | int | str`  \n",
    "    The device identifier.\n",
    "- `\"ensemble_shape\"`: `tuple[int]`  \n",
    "    Today, we'll use PBT with population size 16. Thus, make this `(16,)`.\n",
    "- `\"hyperparameter_raw_init_distributions\"`: `dict`  \n",
    "    We'll make this a dictionary of tuned hyperparameter name - initial distribution of raw values.\n",
    "    1. We'll make the distributions `torch.distributions.Distibution`. Then we can expect the distributions have a `sample` method. This method has one positional argument that says a sample of what shape it should output.\n",
    "    2. Today, we'll only tune the learning rate. We'll give the raw hyperparameter the uniform distribution $\\mathscr U([-5, 0])$.\n",
    "- `\"hyperparameter_raw_perturb\"` : `dict`  \n",
    "    This will be a dictionary of tuned hyperparameter name - additive noise distribution. We expect these values too to be `torch.distributions.Distribution`. Today, we make the raw learning rate additive noise a normal distribution of center 0 and std 2. Feel free to try out various values here!\n",
    "- `\"hyperparameter_transforms\"` : `dict`  \n",
    "    This will be a dictionary of tuned hyperparameter name - transformation. In case of learning rate, use as value the function $x\\mapsto10^x$.\n",
    "\n",
    "    Make sure that the last 3 dictionaries have the same keys.\n",
    "- `\"improvement_threshold:`: `float`  \n",
    "    Make this `1e-4`.\n",
    "- `\"minibatch_size\"`: `int`  \n",
    "    Make this a `128`.\n",
    "- `\"minibatch_size_eval\"`: `int`  \n",
    "    As our models are growing bigger, instead of evaluating on the entire dataset, we'll do that in batches too.\n",
    "    1. If you're using CPU, you can make this something like the training minibatch.\n",
    "    2. If you're using GPU, it is best to find the largest value here that does not give you an out of memory error. You can experiment with powers of 2. It may be convenient to use the left bit shift operator `<<`. For example, on my home computer, I could use an evaluation minibatch size of $2^{14}$, that is `1 << 14`.\n",
    "- `\"pbt\"` : `bool`  \n",
    "    Let's make a switch to turn off PBT. You can use this to test if the algorithm is doing any good, or optimization with the initial hyperparameters works just as well. Make this `True`.\n",
    "- `\"seed\"`: `int`  \n",
    "    This is for reproducible experiments. Insert any integer.\n",
    "- `\"steps_num\"`: `int`  \n",
    "    Make this a `100_000`. Let early stopping take care of stopping.\n",
    "- `\"steps_without_improvement`: `int`  \n",
    "    Make this `1000`.\n",
    "- `\"valid_interval\"`: `int`  \n",
    "    Make this `100`.\n",
    "- `\"welch_confidence_level\"`: `float`  \n",
    "    We will exploit based on a one-sided Welch $t$-test with this confidence level. As the PBT paper does not specify this, let's start with `.95`. Feel free to try out various values here!\n",
    "- `\"welch_sample_size\"`: `int`  \n",
    "    We will exploit based on a one-sided Welch $t$-test on the last this many validation metrics of the population members. To follow the PBT paper, make this `10`."
   ]
  },
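  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, here is a minimal sketch of such a configuration dictionary. The dataset path is a hypothetical placeholder, the seed `0` is an arbitrary choice, and the device and evaluation minibatch size are picked by a simple availability check, which is an assumption to adjust to your machine.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "config = {\n",
    "    \"dataset_preprocessed_path\": \"path/to/preprocessed_mnist\",  # hypothetical placeholder\n",
    "    \"device\": \"cuda\" if torch.cuda.is_available() else \"cpu\",\n",
    "    \"ensemble_shape\": (16,),\n",
    "    \"hyperparameter_raw_init_distributions\": {\n",
    "        \"learning_rate\": torch.distributions.Uniform(-5., 0.),\n",
    "    },\n",
    "    \"hyperparameter_raw_perturb\": {\n",
    "        \"learning_rate\": torch.distributions.Normal(0., 2.),\n",
    "    },\n",
    "    \"hyperparameter_transforms\": {\n",
    "        \"learning_rate\": lambda x: 10 ** x,  # raw value -> learning rate\n",
    "    },\n",
    "    \"improvement_threshold\": 1e-4,\n",
    "    \"minibatch_size\": 128,\n",
    "    # Assumed values: find the largest one that fits on your GPU.\n",
    "    \"minibatch_size_eval\": 1 << 14 if torch.cuda.is_available() else 128,\n",
    "    \"pbt\": True,\n",
    "    \"seed\": 0,  # any integer\n",
    "    \"steps_num\": 100_000,\n",
    "    \"steps_without_improvement\": 1000,\n",
    "    \"valid_interval\": 100,\n",
    "    \"welch_confidence_level\": .95,\n",
    "    \"welch_sample_size\": 10,\n",
    "}\n",
    "\n",
    "# The three hyperparameter dictionaries must share their keys.\n",
    "assert (\n",
    "    config[\"hyperparameter_raw_init_distributions\"].keys()\n",
    "    == config[\"hyperparameter_raw_perturb\"].keys()\n",
    "    == config[\"hyperparameter_transforms\"].keys()\n",
    ")\n",
    "```\n",
    "\n",
    "Keeping the tuned hyperparameters in three parallel dictionaries lets the training loop handle any number of them by iterating over the keys, as long as the keys agree."
   ]
  },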
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set the `torch` pseudo-random number generation seed as per the configuration dictionary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As you perform the following steps, print all intermediate results, preferably with text saying what is the number that we're seeing.\n",
    "1. Generate initial raw learning rate hyperparameters of the ensemble shape.\n",
    "2. Transform the initial raw learning rate.\n",
    "3. Generate noise.\n",
    "4. Add the noise to the initial raw learning rate.\n",
    "5. Transform the perturbed raw learning rate."
   ]
  },
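  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The five steps above can be sketched as follows. To keep the sketch standalone, the distributions and the transformation suggested in the configuration are defined as separate variables instead of being read from the configuration dictionary, and the seed `0` is arbitrary.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "torch.manual_seed(0)  # arbitrary seed for this illustration\n",
    "\n",
    "init_distribution = torch.distributions.Uniform(-5., 0.)\n",
    "noise_distribution = torch.distributions.Normal(0., 2.)\n",
    "ensemble_shape = (16,)\n",
    "\n",
    "\n",
    "def transform(x: torch.Tensor) -> torch.Tensor:\n",
    "    \"\"\"Map raw values to learning rates.\"\"\"\n",
    "    return 10 ** x\n",
    "\n",
    "\n",
    "raw = init_distribution.sample(ensemble_shape)\n",
    "print(\"Initial raw learning rates:\", raw)\n",
    "learning_rate = transform(raw)\n",
    "print(\"Initial learning rates:\", learning_rate)\n",
    "noise = noise_distribution.sample(ensemble_shape)\n",
    "print(\"Noise:\", noise)\n",
    "raw_perturbed = raw + noise\n",
    "print(\"Perturbed raw learning rates:\", raw_perturbed)\n",
    "learning_rate_perturbed = transform(raw_perturbed)\n",
    "print(\"Perturbed learning rates:\", learning_rate_perturbed)\n",
    "```\n",
    "\n",
    "With a noise std of 2 on the raw, that is, $\\log_{10}$ scale, a single perturbation can change a learning rate by several orders of magnitude."
   ]
  },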
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loading and Preprocessing the Dataset\n",
    "\n",
    "Use `load_preprocessed_dataset` to get the training, validation and test feature matrices and label vectors.\n",
    "\n",
    "Recall that for logistic regression, optionally, we added extra columns of 1s to the feature matrices. Now, as the bias vectors are included in the \"linear\" layers, we'll not need this. Thus, if you added a constant feature in your preprocessed dataset, make slices of the feature matrices that drop the last columns.\n",
    "\n",
    "Print the feature matrix shapes. The last dimensions should be 784."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Normalization\n",
    "\n",
    "Let's normalize the dataset by subtracting the *total* mean and dividing by the *total* std.\n",
    "\n",
    "1. Write the function defined below.\n",
    "2. Apply it to the training, validation and test feature matrices.\n",
    "3. Print the mean and std of the training feature matrix."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def normalize_features(\n",
    "    train_features: torch.Tensor,\n",
    "    additional_features=(),\n",
    "    verbose=False\n",
    "):\n",
    "    \"\"\"\n",
    "    Normalize feature tensors by\n",
    "    1. subtracting the total mean of the training features, then\n",
    "    2. dividing by the total std of the offset training features.\n",
    "\n",
    "    Optionally, apply the same transformation to additional feature tensors,\n",
    "    eg. validation and test feature tensors.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    train_features : `torch.Tensor`\n",
    "        Training feature tensor.\n",
    "    additional_features : `Iterable[torch.Tensor]`, optional\n",
    "        Iterable of additional features to apply the transformation to.\n",
    "        Default: `()`.\n",
    "    verbose : `bool`, optional\n",
    "        Whether to print the total mean and std\n",
    "        gotten for the transformation.\n",
    "    \"\"\"\n",
    "    raise NotImplementedError\n",
    "\n",
    "raise NotImplementedError"
   ]
  },
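  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For comparison, here is a minimal sketch of the normalization. It is named `normalize_features_sketch` to avoid clashing with your own function, and it returns a tuple with the normalized training features first, followed by the normalized additional features. This return value is an assumption, as the stub above does not specify one.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "\n",
    "def normalize_features_sketch(\n",
    "    train_features: torch.Tensor,\n",
    "    additional_features=(),\n",
    "    verbose=False\n",
    ") -> tuple[torch.Tensor, ...]:\n",
    "    mean = train_features.mean()  # total mean: one scalar over all entries\n",
    "    train_offset = train_features - mean\n",
    "    std = train_offset.std()  # total std of the offset training features\n",
    "    if verbose:\n",
    "        print(f\"Total mean: {mean.item():.4f}, total std: {std.item():.4f}\")\n",
    "    # The validation and test features reuse the training statistics.\n",
    "    return (train_offset / std,) + tuple(\n",
    "        (features - mean) / std for features in additional_features\n",
    "    )\n",
    "```\n",
    "\n",
    "Using the training statistics for all splits keeps the validation and test sets from influencing the preprocessing."
   ]
  },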
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Implementation\n",
    "\n",
    "### Evaluation in Batches\n",
    "\n",
    "Let's write the function `evaluate_model`, that, given a metric function and feature and value tensors, evaluates a model, using batches so that we don't get memory overflow for larger models.\n",
    "\n",
    "1. Just like the version of  `train_logistic_regression` that you wrote in Notebook 0226, we want to accumulate minibatch metrics in a tensor. You can initialize the accumulator as the integer `0`. This way, you don't have to know the shape of the metric tensors that you'll accumulate.\n",
    "2. You can get the number of minibatches you'll iterate over as `math.ceil` of dataset size divided by evaluation minibatch size.\n",
    "3. Now, using the number of minibatches as upper limit, you can iterate over minibatch indices.\n",
    "4. Given a minibatch index, you can use it to get the minibatch features and values as slices of the dataset features and values.\n",
    "6. The minibatch metric you get is an average over the minibatch entry dimension. Thus, multiply it up by the size of the minibatch before adding it to the accumulator.\n",
    "7. When you finished iterating over the minibatches, return the quotient of the accumulated metrics by the dataset size.\n",
    "\n",
    "Write the function. Then initialize an MLP ensemble with 3 hidden layers of width 128 each and the appropriate number of output dimensions for digit classification. Evaluate your model on the training dataset. If you're using a GPU, this is a great time to adjust `minibatch_size_eval`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_model(\n",
    "    config: dict,\n",
    "    features: torch.Tensor,\n",
    "    get_metric: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],\n",
    "    model: torch.nn.Module,\n",
    "    values: torch.Tensor\n",
    ") -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    Evaluate a model on a supervised dataset.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    config : `dict`\n",
    "        Configuration dictionary. Required key-value pair:\n",
    "        `\"minibatch_size_eval\"` : `int`\n",
    "            Size of consecutive minibatches to take from the dataset.\n",
    "            To be set according to RAM or GPU memory capacity.\n",
    "    features : `torch.Tensor`\n",
    "        Feature tensor.\n",
    "    get_metric : `Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`\n",
    "        Function to get the metric from a pair of\n",
    "        predicted and target value tensors.\n",
    "    model : `torch.nn.Module`\n",
    "        The model to evaluate.\n",
    "    values : `torch.Tensor`\n",
    "        Target value tensor.\n",
    "    \"\"\"\n",
    "    raise NotImplementedError\n",
    "\n",
    "raise NotImplementedError"
   ]
  },
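  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of these steps follows, under the name `evaluate_model_sketch`. Wrapping the loop in `torch.no_grad()` is an addition not mentioned above, assuming that no gradients are needed during evaluation; it saves memory.\n",
    "\n",
    "```python\n",
    "import math\n",
    "from typing import Callable\n",
    "\n",
    "import torch\n",
    "\n",
    "\n",
    "def evaluate_model_sketch(\n",
    "    config: dict,\n",
    "    features: torch.Tensor,\n",
    "    get_metric: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],\n",
    "    model: torch.nn.Module,\n",
    "    values: torch.Tensor\n",
    ") -> torch.Tensor:\n",
    "    minibatch_size = config[\"minibatch_size_eval\"]\n",
    "    dataset_size = len(features)\n",
    "    metric_sum = 0  # an integer broadcasts against any metric shape\n",
    "    with torch.no_grad():\n",
    "        for minibatch_index in range(math.ceil(dataset_size / minibatch_size)):\n",
    "            minibatch = slice(\n",
    "                minibatch_index * minibatch_size,\n",
    "                (minibatch_index + 1) * minibatch_size\n",
    "            )\n",
    "            minibatch_features = features[minibatch]\n",
    "            minibatch_metric = get_metric(\n",
    "                model(minibatch_features), values[minibatch]\n",
    "            )\n",
    "            # Undo the minibatch average; the last minibatch may be smaller.\n",
    "            metric_sum = metric_sum + minibatch_metric * len(minibatch_features)\n",
    "    return metric_sum / dataset_size\n",
    "```\n",
    "\n",
    "Slicing past the end of a tensor simply returns the remaining entries, so the last, possibly smaller, minibatch needs no special handling."
   ]
  },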
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### One-Sided Welch $t$-Test\n",
    "\n",
    "To decide if a population member should be replaced by another, we'll perform a one-sided Welch $t$-test between their last `welch_sample_size` metrics. In general, the test works as follows:\n",
    "\n",
    "We have sets of samples $\\{x_i^{(k)}:i=1,\\dotsc,N_k\\}$ from random variables $X^{(k)}$ for $k=1,2$. We assume that the random variables have normal distributions, but we let them have different mean and std. We want to test the null hypothesis that $\\mathbf E X^{(2)} > \\mathbf E X^{(1)}$ given the samples.\n",
    "\n",
    "From now on, we'll assume $N_1=N_2=:N$ as that will be our use case. For more general formulas, see eg. [the wikipedia page](https://en.wikipedia.org/wiki/Welch's_t-test).\n",
    "\n",
    "1. Start by getting:\n",
    "    1. the sample means $\\bar x^{(k)}=\\frac{1}{N}\\sum_{i=1}^Nx_i^{(k)}$ and\n",
    "    2. the unbiased sample variances $(s^{(k)})^2=\\frac{1}{N-1}\\sum_{i=1}^N(x_i^{(k)}-\\bar x^{(k)})^2$.\n",
    "2. Then we can get the required values:\n",
    "    1. the $t$-statistic \n",
    "    $t=\\sqrt N\\frac\n",
    "    {\\bar x^{(2)}-\\bar x^{(1)}}\n",
    "    {\\sqrt{(s^{(1)})^2+(s^{(2)})^2}}$ and\n",
    "    2. the approximate degrees of freedom\n",
    "    $\\nu=(N-1)\\frac\n",
    "    {((s^{(1)})^2+(s^{(2)})^2)^2}\n",
    "    {(s^{(1)})^4+(s^{(2)})^4}$.\n",
    "3. Now the probability $p$ that $\\mathbf E X^{(2)} > \\mathbf E X^{(1)}$ can be approximated by the cumulative distribution function value $\\mathbf P(T\\le t)$ where $T$ follows a Student distribution of $\\nu$ degrees of freedom. Recall that this value can be calculated by the `cdf` method of `scipy.stats.t`.\n",
    "\n",
    "Write the function below. Then repeat the following:\n",
    "1. Make source and target tensors samples of shape `(10, 10)` from the standard normal distribution.\n",
    "2. Apply the function to these tensors to get a test mask.\n",
    "\n",
    "until you get a test mask that is not all `False`. Given that, make a box plot (see Notebook 0226) of the samples, with using the mask values as labels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def welch_one_sided(\n",
    "    source: torch.Tensor,\n",
    "    target: torch.Tensor,\n",
    "    confidence_level=.95\n",
    ") -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    Performs Welch's t-test with null hypothesis: the expected value\n",
    "    of the random variable the target tensor collects samples of\n",
    "    is larger then the expected value\n",
    "    of the random variable the source tensor collects samples of.\n",
    "\n",
    "    In the tensors, dimensions after the first \n",
    "    are considered batch dimensions.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    source : `torch.Tensor`\n",
    "        Source sample, of shape `(sample_size,) + batch_shape`.\n",
    "    target : `torch.Tensor`\n",
    "        Target sample, of shape `(sample_size,) + batch_shape`.\n",
    "    confidence_level : `float`, optional\n",
    "        Confidence level of the test. Default: `.95`.\n",
    "    Returns\n",
    "    -------\n",
    "    A Boolean tensor of shape `batch_shape` that is `False`\n",
    "    where the null hypothesis is rejected.\n",
    "    \"\"\"\n",
    "    raise NotImplementedError\n",
    "\n",
    "raise NotImplementedError"
   ]
  },
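  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of the test for equal sample sizes follows, named `welch_one_sided_sketch`. It moves the statistics to NumPy for `scipy.stats.t.cdf`, so it is not differentiable, which is fine for deciding on replacements. It returns `True` where the target sample has a significantly larger mean than the source sample.\n",
    "\n",
    "```python\n",
    "import scipy.stats\n",
    "import torch\n",
    "\n",
    "\n",
    "def welch_one_sided_sketch(\n",
    "    source: torch.Tensor,\n",
    "    target: torch.Tensor,\n",
    "    confidence_level=.95\n",
    ") -> torch.Tensor:\n",
    "    sample_size = source.shape[0]\n",
    "    # Unbiased sample variances along the sample dimension.\n",
    "    source_var = source.var(dim=0)\n",
    "    target_var = target.var(dim=0)\n",
    "    var_sum = source_var + target_var\n",
    "    t = sample_size ** .5 * (target.mean(dim=0) - source.mean(dim=0)) / var_sum.sqrt()\n",
    "    dof = (sample_size - 1) * var_sum ** 2 / (source_var ** 2 + target_var ** 2)\n",
    "    # P(T <= t) for T following a Student t-distribution with dof degrees of freedom.\n",
    "    cdf = scipy.stats.t.cdf(t.detach().cpu().numpy(), dof.detach().cpu().numpy())\n",
    "    # If both samples are constant, dof is NaN, and NaN compares as False.\n",
    "    return torch.as_tensor(cdf, device=source.device) > confidence_level\n",
    "```\n",
    "\n",
    "For example, `welch_one_sided_sketch(torch.randn(10, 10), torch.randn(10, 10))` usually returns a mostly `False` mask, since both samples come from the same distribution."
   ]
  },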
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training Loop\n",
    "\n",
    "Time to write the PBT training loop! Start out from the function `train_supervised` that you wrote in Notebook 0319."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Initialization\n",
    "\n",
    "1. We'll assume the ensemble shape is 1-dimensional, the population dimension. Thus, to help the user -- this could be you a month later -- I recommend to raise an error if the given ensemble shape is not 1-dimensional.\n",
    "2. To make the function flexible enough to tune various hyperparameters later, we'll store and update all hyperparameters in the configuration dictionary. To avoid changing the received configuration dictionary, it is recommended to create a local copy.\n",
    "3. As for logging results, we'll use another dictionary, that we'll refer to as the *output* dictionary. It is recommended to make this a `defaultdict(list)` for ease of use.\n",
    "4. Since the population members interact, we'll keep one global best metric score. Initialize this best score and the number of steps without improvement."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Initialization of Hyperparameters\n",
    "\n",
    "As you iterate over `name, distribution` pairs from `hyperparameter_raw_init_distributions`:\n",
    "1. Get the raw hyperparameters by `distribution.sample`.\n",
    "2. Save the raw hyperparameters in the local configuration dictionary. We'll need these for perturbations. Make the name differ from `name`. For example, you can use `name + \"_raw\"`.\n",
    "3. Transform the raw hyperparameters via the transformation in `hyperparameter_transforms` at key `name`.\n",
    "4. Save the transformed hyperparameters to the local configuration dictionary at `name`.\n",
    "5. Also append the hyperparameters to the output dictionary at `name`."
   ]
  },
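  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of this initialization follows, also covering the ensemble shape check, the local copy and the output dictionary. The standalone helper `init_hyperparameters_sketch` is hypothetical; in your `pbt` function, this logic would sit at the beginning. The `\"_raw\"` suffix follows the example above.\n",
    "\n",
    "```python\n",
    "from collections import defaultdict\n",
    "\n",
    "import torch\n",
    "\n",
    "\n",
    "def init_hyperparameters_sketch(config: dict) -> tuple[dict, defaultdict]:\n",
    "    if len(config[\"ensemble_shape\"]) != 1:\n",
    "        raise ValueError(\n",
    "            f\"Expected a 1-dimensional ensemble shape, got {config['ensemble_shape']}\"\n",
    "        )\n",
    "    config = config.copy()  # shallow local copy; the caller's dictionary is unchanged\n",
    "    output = defaultdict(list)\n",
    "    for name, distribution in config[\"hyperparameter_raw_init_distributions\"].items():\n",
    "        raw = distribution.sample(config[\"ensemble_shape\"])\n",
    "        config[name + \"_raw\"] = raw\n",
    "        value = config[\"hyperparameter_transforms\"][name](raw)\n",
    "        config[name] = value\n",
    "        output[name].append(value)\n",
    "    return config, output\n",
    "```\n",
    "\n",
    "A shallow copy suffices here, since the function only adds or reassigns keys of the copy and never modifies the caller's values in place."
   ]
  },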
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Training Steps\n",
    "\n",
    "The only change you need to make to `train_supervised` at training steps is that the learning rate is now not a local variable, but a value of the local configuration dictionary."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Validation\n",
    "\n",
    "At validation, first there are some changes unrelated to PBT:\n",
    "1. You don't only want to record losses, but also metrics.\n",
    "2. Calculate both losses and metrics using `evaluate_model`.\n",
    "3. Store the results at appropriate keys of the output dictionary.\n",
    "4. Update the best metric and the steps without improvement counter.\n",
    "5. If the number of steps without improvement exceeds the maximum, stop training.\n",
    "\n",
    "If the number of evaluations is at least `welch_sample_size`, then we can initiate a PBT update:"
   ]
  },
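  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The bookkeeping in steps 4 and 5 can be sketched with a hypothetical helper, `update_early_stopping`. It assumes that the global best metric starts at `float(\"-inf\")`, that a greater metric is better, as stated in the docstring of `pbt` below, and that the counter is measured in training steps, so it grows by `valid_interval` at each validation without improvement.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "\n",
    "def update_early_stopping(\n",
    "    best_metric: float,\n",
    "    steps_without_improvement: int,\n",
    "    validation_metric: torch.Tensor,\n",
    "    config: dict\n",
    ") -> tuple[float, int]:\n",
    "    \"\"\"Return the updated global best metric and steps without improvement.\"\"\"\n",
    "    current_best = validation_metric.max().item()  # best member of the population\n",
    "    if current_best > best_metric + config[\"improvement_threshold\"]:\n",
    "        return current_best, 0\n",
    "    return best_metric, steps_without_improvement + config[\"valid_interval\"]\n",
    "```\n",
    "\n",
    "Training then stops once the returned counter exceeds `config[\"steps_without_improvement\"]`."
   ]
  },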
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Exploitation with Welch's $t$-Test\n",
    "\n",
    "You can stack the last `welch_sample_size` entries in the validation metric list of the output dictionary to get a tensor of validation metrics of shape `(welch_sample_size, population_size)`.\n",
    "\n",
    "1. For each population member, draw a replacement index.\n",
    "2. Now you can create the replacement mask by `welch_one_sided`.\n",
    "3. Append the replacement indices and the replacement masks to appropriate lists in the output dictionary."
   ]
  },
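  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of these exploitation steps follows, as a hypothetical helper `exploit`. The test is passed in as the argument `test` to keep the sketch self-contained; in the training loop, this would be `welch_one_sided` with the configured confidence level, e.g. via `functools.partial`. Drawing replacement indices uniformly at random with `torch.randint` is an assumption; a member may draw itself, in which case Welch's test compares identical samples and does not trigger a replacement.\n",
    "\n",
    "```python\n",
    "from typing import Callable\n",
    "\n",
    "import torch\n",
    "\n",
    "\n",
    "def exploit(\n",
    "    validation_metrics: list[torch.Tensor],\n",
    "    welch_sample_size: int,\n",
    "    test: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]\n",
    ") -> tuple[torch.Tensor, torch.Tensor]:\n",
    "    # Shape (welch_sample_size, population_size).\n",
    "    metrics = torch.stack(validation_metrics[-welch_sample_size:])\n",
    "    population_size = metrics.shape[1]\n",
    "    replacement_indices = torch.randint(population_size, (population_size,))\n",
    "    # Each member is the source, the member at its replacement index the target.\n",
    "    replacement_mask = test(metrics, metrics[:, replacement_indices])\n",
    "    return replacement_indices, replacement_mask\n",
    "```"
   ]
  },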
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Exploration by Perturbations\n",
    "\n",
    "If the replacement mask is nonempty:\n",
    "1. Loop over the model parameters and update them using the replacement indices and mask.\n",
    "2. Loop over the name, transform pairs of `hyperparameter_transforms`:\n",
    "    1. Get the raw hyperparameter values from the local configuration dictionary.\n",
    "    2. Take the values at masked replacement indices.\n",
    "    3. Add additive noise, sampled from the appropriate distribution in `hyperparameter_raw_perturb`.\n",
    "    4. Apply the transform.\n",
    "    5. Update the raw and transformed values in the local configuration dictionary.\n",
    "    6. Append the raw and transformed values to the appropriate lists in the output dictionary."
   ]
  },
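  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of the exploration step follows, as a hypothetical helper `explore` that updates `config` in place, so pass it the local copy. It assumes that every model parameter has the population dimension first (check this against your `get_mlp`), that raw values are stored at `name + \"_raw\"` as suggested above, and that the output dictionary is a `defaultdict(list)`; the output key `name + \"_raw\"` is also an assumption.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "\n",
    "def explore(\n",
    "    config: dict,\n",
    "    model: torch.nn.Module,\n",
    "    replacement_indices: torch.Tensor,\n",
    "    replacement_mask: torch.Tensor,\n",
    "    output: dict\n",
    ") -> None:\n",
    "    if not replacement_mask.any():\n",
    "        return\n",
    "    sources = replacement_indices[replacement_mask]\n",
    "    with torch.no_grad():\n",
    "        for parameter in model.parameters():\n",
    "            # The right-hand side is a copy, so chained replacements use old values.\n",
    "            parameter[replacement_mask] = parameter[sources]\n",
    "    for name, transform in config[\"hyperparameter_transforms\"].items():\n",
    "        raw = config[name + \"_raw\"].clone()\n",
    "        noise = config[\"hyperparameter_raw_perturb\"][name].sample(sources.shape)\n",
    "        raw[replacement_mask] = raw[sources] + noise\n",
    "        config[name + \"_raw\"] = raw\n",
    "        config[name] = transform(raw)\n",
    "        output[name + \"_raw\"].append(raw)\n",
    "        output[name].append(config[name])\n",
    "```"
   ]
  },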
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Output\n",
    "\n",
    "At the end of training, replace all values of the output dictionary that are lists by tensors that are formed by stacking their entries."
   ]
  },
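  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of this final step follows, as a hypothetical helper `finalize_output`. Since `torch.stack` raises an error on an empty list, for example when no PBT update happened, the sketch leaves empty lists unchanged; how to handle that case is up to you.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "\n",
    "def finalize_output(output: dict) -> dict:\n",
    "    \"\"\"Replace every nonempty list value by the stack of its entries.\"\"\"\n",
    "    return {\n",
    "        key: torch.stack(value) if isinstance(value, list) and value else value\n",
    "        for key, value in output.items()\n",
    "    }\n",
    "```"
   ]
  },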
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Write the training function and run it. Print the maximum of validation metric tensor in the output."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pbt(\n",
    "    config: dict,\n",
    "    get_loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],\n",
    "    get_metric: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],\n",
    "    model: torch.nn.Module,\n",
    "    train_features: torch.Tensor,\n",
    "    train_values: torch.Tensor,\n",
    "    valid_features: torch.Tensor,\n",
    "    valid_values: torch.Tensor,\n",
    ") -> tuple[torch.Tensor, torch.Tensor]:\n",
    "    \"\"\"\n",
    "    Population-based training on a supervised learning task.\n",
    "    Tuned hyperparameters are given by raw values and transformations.\n",
    "    This way, the hyperparameters are perturbed by\n",
    "    additive noise on raw values.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    config : `dict`\n",
    "        Configuration dictionary. Required key-value pairs:\n",
    "        `\"ensemble_shape\"` : tuple[int]\n",
    "            Ensemble shape. We assume this is a 1-dimensional tuple\n",
    "            with dimensions the population size.\n",
    "        `\"hyperparameter_raw_init_distributions\"` : `dict`\n",
    "            Dictionary that maps tuned hyperparameter names\n",
    "            to `torch.distributions.Distribution` of raw hyperparameter values.\n",
    "            Required keys:\n",
    "            `\"learning_rate\"`:\n",
    "                The learning rate of stochastic gradient descent.\n",
    "        `\"hyperparameter_raw_perturbs\"` : `dict`\n",
    "            Dictionary that maps tuned hyperparameter names\n",
    "            to `torch.distributions.Distribution` of additive noise.\n",
    "        `\"hyperparameter_transforms\"` : `dict`\n",
    "            Dictionary that maps tuned hyperparameter names\n",
    "            to transformations of raw hyperparameter values.\n",
    "        `\"improvement_threshold\"` : `float`\n",
    "            A new metric score has to be this much better\n",
    "            than the previous best to count as an improvement.\n",
    "        `\"minibatch_size\"` : `int`\n",
    "            Minibatch size to use in a training step.\n",
    "        `\"minibatch_size_eval\"` : `int`\n",
    "            Minibatch size to use in evaluation.\n",
    "            On CPU, should be about the same as `minibatch_size`.\n",
    "            On GPU, should be as big as possible without\n",
    "            incurring an Out of Memory error.\n",
    "        `\"pbt\"` : `bool`\n",
    "            Whether to use PBT updates in validations.\n",
    "            If `False`, the algorithm just samples hyperparameters at start,\n",
    "            then keeps them constant.\n",
    "        `\"steps_num\"` : `int`\n",
    "            Maximum number of training steps.\n",
    "        `\"steps_without_improvement`\" : `int`\n",
    "            If the number of training steps without improvement\n",
    "            exceeds this value, then training is stopped.\n",
    "        `\"valid_interval\"` : `int`\n",
    "            Frequency of evaluations, measured in number of training steps.\n",
    "        `\"welch_confidence_level\"` : `float`\n",
    "            The confidence level in Welch's t-test\n",
    "            that is used in determining if a population member\n",
    "            is to be replaced by another member with perturbed hyperparameters.\n",
    "        `\"welch_sample_size\"` : `int`\n",
    "            The last this many validation metrics are used\n",
    "            in Welch's t-test.\n",
    "    `get_loss` : `Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`\n",
    "        A function that maps a pair of predicted and target value tensors\n",
    "        to a tensor of losses per ensemble member.\n",
    "    `get_metric` : `Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`\n",
    "        A function that maps a pair of predicted and target value tensors\n",
    "        to a tensor of metrics per ensemble member.\n",
    "        We assume a greater metric is better.\n",
    "    `model` : `torch.nn.Module`\n",
    "        The model ensemble to tune.\n",
    "    `train_features` : `torch.Tensor`\n",
    "        Training feature tensor.\n",
    "    `train_values` : `torch.Tensor`\n",
    "        Training value tensor.\n",
    "    `valid_features` : `torch.Tensor`\n",
    "        Validation feature tensor.\n",
    "    `valid_values` : `torch.Tensor`\n",
    "        Validation value tensor.\n",
    "        \n",
    "    Returns\n",
    "    -------\n",
    "    An output dictionary with the following key-value pairs:\n",
    "        `\"source mask\"` : `torch.Tensor`\n",
    "            The source masks of population members\n",
    "            that were replace by other members in a PBT update\n",
    "        `\"target indices\"` : `torch.Tensor`\n",
    "            The indices of population members\n",
    "            that the member where the source mask is to were replaced with.\n",
    "        `\"training loss\"` : `torch.Tensor`\n",
    "            The training losses at evaluation steps.\n",
    "        `\"training metric\"` : `torch.Tensor`\n",
    "            The training metrics at evaluation steps.\n",
    "        `\"validation loss\"` : `torch.Tensor`\n",
    "            The validation losses at evaluation steps.\n",
    "        `\"validation metric\"` : `torch.Tensor`\n",
    "            The validation metrics at evaluation steps.\n",
    "\n",
    "        In addition, for each tuned hyperparameter name,\n",
    "        we include a `torch.Tensor` of values per update.\n",
    "    \"\"\"\n",
    "    raise NotImplementedError\n",
    "\n",
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Recall that with logistic regression, our best accuracy was about 92%.\n",
    "\n",
    "Make a `line_plot_confidence_band` of log10 of the transpose of the entry of the output dictionary that stores the learning rates."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can see how the method found a learning rate schedule."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, reinitialize the model and run `pbt` with `steps_num` set to `10_000` and `pbt` set to `False` in the configuration dictionary. Again, print the maximum validation metric.  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You may have gotten not that much worse of a result with this, or even better. But note that we only had one hyperparameter to tune. Later on, the number of these can grow, so we'll get to appreciate PBT more."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Datasets\n",
    "\n",
    "### MNIST\n",
    "\n",
    "https://huggingface.co/datasets/ylecun/mnist"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## License\n",
    "\n",
    "This work is licensed under Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International. To view a copy of this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dml",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
