{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Classifying Hyperspheres with MLPs\n",
    "\n",
    "Today, we'll encounter the modular Artificial Neural Network interface of `torch`. We'll write code to sample from a hypersphere dataset, a simple MLP training loop and plotting results."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "### Install\n",
    "\n",
    "At the end of this lab, we'll train a model that creates 3-dimensional hidden features. To study them at ease, we will make interactive 3d plots. So that we can do that in a Jupyter notebook, we need the `ipympl` package. You can install it either by\n",
    "```bash\n",
    "pip install ipympl\n",
    "```\n",
    "or\n",
    "```bash\n",
    "mamba install -n {the name of your environment} -c conda-forge ipympl"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Imports\n",
    "\n",
    "1. Import `Iterable`, `mpl`, `plt`, `torch`, `torch.nn.functional` as `F`, `tqdm` and `Optional`.\n",
    "2. Import the function `get_dataloader_random_reshuffle` that you wrote in Notebook 0219 and the functions `get_binary_accuracy` and `get_binary_cross_entropy` that you wrote in Notebook 0228."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections.abc import Iterable\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import tqdm\n",
    "from typing import Optional\n",
    "\n",
    "from util_0312 import (\n",
    "    get_binary_accuracy,\n",
    "    get_binary_cross_entropy,\n",
    "    get_dataloader_random_reshuffle\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Constants\n",
    "\n",
    "Create a configuration dictionary with the following keys:\n",
    "- `\"classes_num\"`: `int`  \n",
    "    The number of hyperspheres to have in the synthetic dataset. Make this a `2` for binary classifiation.\n",
    "- `\"classes_offset\"`: `tuple[tuple[float]]`  \n",
    "    The offsets of the hyperspheres in the synthetic dataset. Make this `((0, 0), (1, 1))` for starters.\n",
    "- `\"classes_scale\"`: `tuple[float]`  \n",
    "    The radii of the hyperspheres in the synthetic dataset. Make this `(0.8, 1.2)`.\n",
    "- `\"dataset_size\"`: `int`  \n",
    "    The number of samples to generate in the hypersphere dataset. Make this `10_000`.\n",
    "- `\"device\"`: `torch.device | int | str`  \n",
    "    The device identifier, explained in notebook 091224.\n",
    "- `\"ensemble_shape\"`: `tuple[int]`  \n",
    "    Make this a `()`, as we'll not use ensembles today. (We need the value as `get_dataloader_random_reshuffle` expects it.)\n",
    "- `\"feature_num\"`: `int`  \n",
    "    The number of features, that is the ambient dimension of the hypersphere dataset. Make this `2`.\n",
    "- `\"learning_rate\"`: `float`  \n",
    "    The learning rate to use in the simple training algorithm we'll use today. Make this `1`.\n",
    "- `\"minibatch_size\"`: `int`  \n",
    "    Make this a `256`.\n",
    "- `\"noise_std\"`: `float`  \n",
    "    The std of the additive noise of the hypersphere dataset. Make this `0.1`.\n",
    "- `\"seed\"`: `int`  \n",
    "    This is for reproducible experiments. Insert any integer.\n",
    "- `\"steps_num\"`: `int`  \n",
    "    Make this a `1000` for today's simple training loop."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = {\n",
    "    \"classes_num\": 2,\n",
    "    \"classes_offset\": ((0, 0), (1, 1)),\n",
    "    \"classes_scale\": (0.8, 1.2),\n",
    "    \"dataset_size\": 10_000,\n",
    "    \"device\": \"cuda\",\n",
    "    \"ensemble_shape\": (),\n",
    "    \"feature_num\": 2,\n",
    "    \"learning_rate\": 1,\n",
    "    \"minibatch_size\": 256,\n",
    "    \"noise_std\": 0.1,\n",
    "    \"seed\": 1,\n",
    "    \"steps_num\": 1000\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set the `torch` pseudo-random number generation seed according to the configuration dictionary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(config[\"seed\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sampling from the Hypersphere Dataset\n",
    "\n",
    "### Sampling from a Hypersphere\n",
    "\n",
    "1. You can sample from the uniform distribution $\\mathscr U(S^{n-1})$ on a unit hypersphere as follows:\n",
    "    1. Take a sample from the standard normal distribution in $n$ dimensions. You can do this by\n",
    "        1. initializing an `empty` tensor of appropriate size, then\n",
    "        2. calling its `normal_` method which fills in entries following a normal distribution and returns the tensor.\n",
    "    2. You have a sample from a standard normal distribution in $n$ dimensions. Divide the vectors by their norms. You can use `torch.linalg.vector_norm` for this.\n",
    "2. You can add additive noise by adding values from a sample from a normal distribution in $n$ dimensions with the given std, gotten the same way as above.\n",
    "\n",
    "Write the function, get a sample as per the configuration dictionary and make a scatter plot of it. Before drawing the plot, set the aspect ratio to `\"equal\"`. For example, you can use `plt.gca().set_aspect`."
   ]
  },
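  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A short justification of step 1, given here as a sketch. The symbols $\\mathbf g$, $Q$, $\\mathbf s$, $\\boldsymbol\\varepsilon$ and $\\sigma$ are notation introduced for this note only. The density of the standard normal distribution in $n$ dimensions depends on $\\mathbf g$ only through its norm:\n",
    "$$\n",
    "p(\\mathbf g)=(2\\pi)^{-n/2}\\exp\\left(-\\tfrac12\\|\\mathbf g\\|^2\\right).\n",
    "$$\n",
    "So for any orthogonal matrix $Q$, the vector $Q\\mathbf g$ has the same distribution as $\\mathbf g$. Consequently, the distribution of $\\mathbf g/\\|\\mathbf g\\|$ is invariant under rotations of $S^{n-1}$, and the uniform distribution is the only probability distribution on the sphere with this property. With noise std $\\sigma$, the sample points then have the form\n",
    "$$\n",
    "\\mathbf s+\\sigma\\boldsymbol\\varepsilon,\\qquad\\mathbf s\\sim\\mathscr U(S^{n-1}),\\quad\\boldsymbol\\varepsilon\\sim\\mathscr N(\\mathbf 0,I_n).\n",
    "$$"
   ]
  },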
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_unit_hypersphere_sample(\n",
    "    config: dict\n",
    ") -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    Gets a sample from the uniform distribution on a unit hypersphere\n",
    "    with additive Gaussian noise.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    config : `dict`\n",
    "        Configuration dictionary. Required key-value pairs:\n",
    "        `\"dataset_size\"` : `int`\n",
    "            The size of the sample to get.\n",
    "        `\"device\"` : `torch.Device | int | str`\n",
    "            The device the output tensor to be stored on.\n",
    "        `\"feature_num\"` : `int`\n",
    "            The ambient dimension of the hypersphere.\n",
    "        `\"noise_std\"` : `float`\n",
    "            The standard deviation of the additive Gaussian noise.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    A tensor of shape `(dataset_size, feature_num)` containing the sample.\n",
    "    \"\"\"\n",
    "    sample = torch.empty(\n",
    "        (\n",
    "            config[\"dataset_size\"],\n",
    "            config[\"feature_num\"]\n",
    "        ),\n",
    "        device=config[\"device\"],\n",
    "        dtype=torch.float32\n",
    "    ).normal_()\n",
    "    sample /= torch.linalg.vector_norm(\n",
    "        sample,\n",
    "        dim=-1,\n",
    "        keepdim=True\n",
    "    )\n",
    "\n",
    "    sample += torch.empty_like(sample).normal_(std=config[\"noise_std\"])\n",
    "\n",
    "    return sample\n",
    "\n",
    "sample = get_unit_hypersphere_sample(config)\n",
    "plt.gca().set_aspect(\"equal\")\n",
    "plt.scatter(*sample.T.cpu(), s=1)\n",
    "plt.show()\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Separating the classes\n",
    "\n",
    "You can finish up the hypersphere classification dataset creator as follows:\n",
    "1. You need to get a label vector. For example, you can take the remainders of the vector of values `0,1,...,dataset_size-1` modulo `classes_num`.\n",
    "2. You can transform the scale and offset tuples in the configuration dictionary to tensors using `torch.asarray`.\n",
    "3. You can index into the scale and offset tensor by the label vector, then after broadcasting you can apply it to the hypersphere sample to get the feature matrix of the dataset.\n",
    "\n",
    "Write the function below. We'll write a plotting function in the next step. For now, do the following to check your results:\n",
    "1. Get a dataset.\n",
    "2. Get unique labels and their counts via `torch.unique`.\n",
    "3. Loop over the unique labels:\n",
    "    1. Make a mask of entries with the given label via the label vector.\n",
    "    2. Mask the feature matrix to get a matrix with column vectors the sample from the hypersphere of the given label.\n",
    "        1. Print the center of the sample.\n",
    "        2. Print the mean and std of the distances from the center."
   ]
  },
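  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In formulas, the dataset described above can be sketched as follows. The notation is introduced for this note only: $K$ is `classes_num`, $r_k$ and $\\mathbf c_k$ are the scale and offset of class $k$, and $\\mathbf s_i+\\sigma\\boldsymbol\\varepsilon_i$ is the $i$-th noisy unit hypersphere sample point. We assume that the scale and offset are applied to the noisy sample, as the steps above suggest:\n",
    "$$\n",
    "y_i=i\\bmod K,\\qquad\\mathbf x_i=r_{y_i}\\left(\\mathbf s_i+\\sigma\\boldsymbol\\varepsilon_i\\right)+\\mathbf c_{y_i}.\n",
    "$$\n",
    "Under this assumption and for small $\\sigma$, the check should print, for class $k$, a center close to $\\mathbf c_k$, a mean distance close to $r_k$ and a distance std close to $r_k\\sigma$. With the configuration above, each label should appear 5000 times, and the mean distance and std should be roughly $0.8$ and $0.08$ for class 0, and roughly $1.2$ and $0.12$ for class 1."
   ]
  },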
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_hyperspheres_dataset(\n",
    "    config: dict\n",
    ") -> tuple[torch.Tensor, torch.Tensor]:\n",
    "    \"\"\"\n",
    "    Gets a sample from the synthetic dataset\n",
    "    of classification of hyperspheres.\n",
    "    The hyperspheres are generated by `get_unit_hypersphere_sample`.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    config : `dict`\n",
    "        Configuration dictionary. Required key-value pairs:\n",
    "        `\"classes_num\"`: `int`  \n",
    "            The number of hyperspheres to have in the synthetic dataset.\n",
    "        `\"classes_offset\"`: `tuple[tuple[float]]`  \n",
    "            The offsets of the hyperspheres in the synthetic dataset.\n",
    "        `\"classes_scale\"`: `tuple[float]`  \n",
    "            The radii of the hyperspheres in the synthetic dataset.\n",
    "        `\"dataset_size\"` : `int`\n",
    "            The size of the sample to get.\n",
    "        `\"device\"` : `torch.Device | int | str`\n",
    "            The device the output tensor to be stored on.\n",
    "        `\"feature_num\"` : `int`\n",
    "            The ambient dimension of the hypersphere.\n",
    "        `\"noise_std\"` : `float`\n",
    "            The standard deviation of the additive Gaussian noise.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    A pair of the feature matrix and label vector\n",
    "    that represents the dataset.\n",
    "    \"\"\"\n",
    "    features = get_unit_hypersphere_sample(config)\n",
    "    labels = torch.arange(\n",
    "        config[\"dataset_size\"],\n",
    "        device=config[\"device\"]\n",
    "    ) % config[\"classes_num\"]\n",
    "\n",
    "    scale, offset = (\n",
    "        torch.asarray(\n",
    "            config[key],\n",
    "            device=config[\"device\"],\n",
    "            dtype=torch.float32\n",
    "        )\n",
    "        for key in (\"classes_scale\", \"classes_offset\")\n",
    "    )\n",
    "\n",
    "    features *= scale[labels][:, None]\n",
    "    features += offset[labels]\n",
    "\n",
    "    return features, labels\n",
    "\n",
    "features, labels = get_hyperspheres_dataset(config)\n",
    "unique, counts = torch.unique(labels, return_counts=True)\n",
    "print(unique.cpu().numpy(), counts.cpu().numpy())\n",
    "for label in range(len(unique)):\n",
    "    mask = labels == label\n",
    "    features_label = features[mask]\n",
    "    center = features_label.mean(dim=0)\n",
    "    displacements = features_label - center\n",
    "    distances = torch.linalg.vector_norm(displacements, dim=-1)\n",
    "    print(center.cpu().numpy(), distances.mean().cpu().numpy(), distances.std().cpu().numpy())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plotting the Dataset\n",
    "\n",
    "Let's plot how the dataset looks like!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Getting Colors from a Colormap\n",
    "\n",
    "First, let's refactor getting colors from a colormap, what we did in Notebook 0226. Write the function below and print some output. It should be a matrix with column vectors of RGBA (Red, Green, Blue, Alpha -- opacity) values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_colors(colors_num: int, colormap=\"viridis\") -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    Gets color RGBA values\n",
    "    that are gotten at evenly spaced points of a colormap.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    colors_num : `int`\n",
    "        The number of colors to return. They are taken at positions\n",
    "        ```\n",
    "        0, 1 / (colors_num - 1),..., 1\n",
    "        ```\n",
    "        of the colormap\n",
    "    colormap : str, optional\n",
    "        The colormap to load from `mpl.colormaps`. Default: `\"viridis\"`.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    A matrix with column vectors the RGBA values of the colors.\n",
    "    \"\"\"\n",
    "    colormap = mpl.colormaps[colormap]\n",
    "    colors = torch.asarray(colormap(torch.linspace(0, 1, colors_num)))\n",
    "\n",
    "    return colors\n",
    "\n",
    "print(get_colors(2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Plotting a 2D Classification Dataset\n",
    "\n",
    "Time to plot the dataset! We'll do this in the $n=2$ case. You can plot the full dataset as one scatter plot if in the `c` keyword argument, you supply an iterable yielding the colors of the vertices one by one.\n",
    "\n",
    "Write the function below and plot the dataset you generated above. Don't forget about the aspect ratio."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_2d_classification_dataset(\n",
    "    config: dict,\n",
    "    features: torch.Tensor,\n",
    "    labels: torch.Tensor,\n",
    "    colormap=\"rainbow\",\n",
    "    colors: Optional[Iterable]=None,\n",
    "):\n",
    "    \"\"\"\n",
    "    Plot a classification dataset with 2-dimensional features.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    config : `dict`\n",
    "        Configuration dictionary. Required key-value pair:\n",
    "        `\"classes_num\"` : `int`\n",
    "            The number of classes in the dataset.\n",
    "    features : `torch.Tensor`\n",
    "        The feature matrix.\n",
    "    labels : `torch.Tensor`\n",
    "        The label vector.\n",
    "    colormap : str, optional\n",
    "        The colormap to get colors from if `colors` is not given.\n",
    "        Default: \"rainbow\".\n",
    "    colors : Iterable, optional\n",
    "        An iterable, yielding the color to be used for each label.\n",
    "        If not given, it is gotten via `get_colors`.\n",
    "    \"\"\"\n",
    "    if colors is None:\n",
    "        colors = get_colors(config[\"classes_num\"], colormap)\n",
    "\n",
    "    features, labels = (t.cpu() for t in (features, labels))\n",
    "\n",
    "    plt.scatter(\n",
    "        *features.T,\n",
    "        c=[colors[i] for i in labels],\n",
    "        s=1\n",
    "    )\n",
    "\n",
    "plt.gca().set_aspect(\"equal\")\n",
    "plot_2d_classification_dataset(config, features, labels)\n",
    "plt.show()\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `torch.nn.Linear`\n",
    "\n",
    "As we saw, a basic component of ANNs is an affine transformation. Confusingly, as a `Module`, these are called `torch.nn.Linear`.\n",
    "1. The two positional arguments declare the input and output dimensions of the transformation.\n",
    "2. The `bias` keyword argument determines if the transformation should have a bias vector. (So, the module can only be properly called linear if this is `False`.) Default: `True`.\n",
    "3. The `device` and `dtype` keyword arguments have the usual role.\n",
    "\n",
    "First, we shall perform logistic regression in this framework. Create the affine transformation of appropriate arguments. As it includes bias, we won't need to add a column of 1s to the feature matrix.\n",
    "\n",
    "`Module`s are callable. Call your affine transformation on the feature matrix. Print the output and its shape."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = torch.nn.Linear(\n",
    "    2,\n",
    "    1,\n",
    "    device=config[\"device\"],\n",
    "    dtype=torch.float32\n",
    ")\n",
    "logits = model(features)\n",
    "print(logits)\n",
    "print(logits.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We see the following:\n",
    "1. The output has the shape we expect. In particular, we can plug it in our `get_binary_accuracy`. Let's do that and print the result.\n",
    "2. The output has gradient information. This is the default behaviour for `Module`s. You can switch this off at evaluation with `torch.no_grad`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    accuracy = get_binary_accuracy(\n",
    "        logits,\n",
    "        labels,\n",
    "    )\n",
    "\n",
    "print(accuracy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This time, the accuracy may not be close to 0.5 (why?)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### State Dictionary\n",
    "\n",
    "Print the output of the `state_dict` method of your `Module`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(model.state_dict())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see that at key `\"weight\"`, we have a $1\\times2$ matrix and at key `bias`, we have a 1-dimensional vector. So the weight matrix here is the transpose of what we're using.\n",
    "\n",
    "Moreover, we can see that the feature matrix and bias vector entries are nonzero at initialization. We'll discuss this matter next Wednesday."
   ]
  },
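  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To spell out the transpose convention, here is the computation in formulas. The symbols $X$, $W$, $\\mathbf b$ and $N$ are notation introduced for this note. For a feature matrix $X\\in\\mathbf R^{N\\times2}$, a weight matrix $W\\in\\mathbf R^{1\\times2}$ stored at key `weight` and a bias vector $\\mathbf b\\in\\mathbf R^1$ stored at key `bias`, the module computes\n",
    "$$\n",
    "\\mathrm{Linear}(X)=XW^T+\\mathbf b\\in\\mathbf R^{N\\times1},\n",
    "$$\n",
    "where $\\mathbf b$ is broadcast over the $N$ rows. In the logistic regression notation $\\mathbf x^T\\mathbf w+b$, this corresponds to $\\mathbf w=W^T\\in\\mathbf R^{2\\times1}$."
   ]
  },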
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Parameters\n",
    "\n",
    "To help with optimization, a `Module` has a `parameters` method, that returns an iteration of the parameters. That is, this output can be directly fed to a `torch.optim.Optimizer`.\n",
    "\n",
    "Print the list of the entries of this iterator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(list(model.parameters()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can see that the entries are \"parameters containing\" the tensors. This means, the tensors are wrapped in a `torch.nn.Parameter`. When constructing a `Module`, this lets the module *register* the parameter, so that for example it appears in the output of `state_dict` and `parameters`. We'll see about this next Wednesday."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training and Evaluating `Module`s\n",
    "\n",
    "### A Simple `Module` Training Loop\n",
    "\n",
    "Next Wednesday, we'll update our supervised ensemble learning training loop. For today, let's make a very simple one.\n",
    "\n",
    "Write the function below, then run it. Print the accuracy you get."
   ]
  },
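  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a reminder of what one training step computes, here is a sketch in formulas. The notation is introduced for this note: $f_\\theta$ is the model with parameters $\\theta$, $B$ is the index set of the minibatch, $\\eta$ is the learning rate and $\\sigma$ is the sigmoid function. We assume that `get_binary_cross_entropy` returns the mean binary cross-entropy computed from logits, as its name suggests:\n",
    "$$\n",
    "z_i=f_\\theta(\\mathbf x_i),\\qquad\n",
    "L_B(\\theta)=-\\frac1{|B|}\\sum_{i\\in B}\\Big(y_i\\log\\sigma(z_i)+(1-y_i)\\log\\big(1-\\sigma(z_i)\\big)\\Big).\n",
    "$$\n",
    "Plain SGD, which is what `torch.optim.SGD` performs with its default arguments, then updates the parameters by\n",
    "$$\n",
    "\\theta\\leftarrow\\theta-\\eta\\nabla_\\theta L_B(\\theta).\n",
    "$$\n",
    "Since `backward` accumulates gradients into the parameters, the stored gradients need to be reset, for example with `optimizer.zero_grad`, before each backward pass."
   ]
  },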
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_binary_classifier_simple(\n",
    "    config: dict,\n",
    "    features: torch.Tensor,\n",
    "    labels: torch.Tensor,\n",
    "    model: torch.nn.Module\n",
    ") -> float:\n",
    "    \"\"\"\n",
    "    Train a model on a binary classification task\n",
    "    via Stochastic Gradient Descent\n",
    "    for a fixed number of training steps,\n",
    "    then return its binary accuracy on the same dataset.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    config : `dict`\n",
    "        Configuration dictionary. Required key-value pairs:\n",
    "        `\"learning_rate\"` : `float`\n",
    "            The learning rate in SGD.\n",
    "        `\"minibatch_size\"` : `int`\n",
    "            The minibatch size to use in the random reshuffling dataloader.\n",
    "        `\"steps_num\"` : `int`\n",
    "            The number of SGD training steps to take.\n",
    "    features : `torch.Tensor`\n",
    "        The feature matrix of the dataset, of shape\n",
    "        `(dataset_size, feature_num)`\n",
    "    labels : `torch.Tensor`\n",
    "        The label vector of the dataset, of shape `(dataset_size,)`.\n",
    "    model : `torch.nn.Module`\n",
    "        The model to optimizer.\n",
    "        We assume it has 1-dimensional output,\n",
    "        the binary classification logits.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The accuracy of the trained model on the dataset.\n",
    "    \"\"\"\n",
    "    float_dtype = features.dtype\n",
    "    dataloader = get_dataloader_random_reshuffle(\n",
    "        config,\n",
    "        features,\n",
    "        labels.to(float_dtype)\n",
    "    )\n",
    "    optimizer = torch.optim.SGD(\n",
    "        params=model.parameters(),\n",
    "        lr=config[\"learning_rate\"]\n",
    "    )\n",
    "    progress_bar = tqdm.trange(config[\"steps_num\"])\n",
    "    for _ in progress_bar:\n",
    "        minibatch_features, minibatch_labels = next(dataloader)\n",
    "        optimizer.zero_grad()\n",
    "        logits = model(minibatch_features)\n",
    "        loss = get_binary_cross_entropy(logits, minibatch_labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "    progress_bar.close()\n",
    "    with torch.no_grad():\n",
    "        logits = model(features)\n",
    "\n",
    "    accuracy = get_binary_accuracy(logits, labels)\n",
    "\n",
    "    return accuracy.cpu().item()\n",
    "\n",
    "accuracy = train_binary_classifier_simple(\n",
    "    config,\n",
    "    features,\n",
    "    labels,\n",
    "    model\n",
    ")\n",
    "print(accuracy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If all went well, you should see an accuracy value of about 75%."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plotting the Decision Boundary\n",
    "\n",
    "Let's plot the decision boundary of our logistic regression model\n",
    "$$\n",
    "\\mathbf x\\mapsto\\sigma(\\mathbf x^T\\mathbf w+b)!\n",
    "$$\n",
    "Recall that this is the hyperplane with equation\n",
    "$$\n",
    "\\mathbf x^T\\mathbf w+b=0.\n",
    "$$\n",
    "\n",
    "In case we have $n=2$ features, this is a line. You can plot a line segment by feeding `plt.plot` the $x$- and $y$-coordinates of its two endpoints.\n",
    "\n",
    "We'll call this function after having called `plot_2d_classification_dataset`. In particular, we'll want to keep the plot $x$- and $y$-limits intact. To that end, you can\n",
    "1. get these limits using `plt.xlim` and `plt.ylim` and\n",
    "2. use these values to decide which segment of the decision boundary line to plot. We want both endpoints of the decision boundary line segment to be on a boundary line segment of the plot.\n",
    "\n",
    "Write the function below. Then plot first the dataset, then the decision boundary line of the model you just trained."
   ]
  },
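  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A small worked example of one way to find the endpoints. The numbers are hypothetical and chosen only for illustration. Take $\\mathbf w=(1,1)$, $b=-1$ and plot limits $[-2,2]\\times[-2,2]$. Writing $(x,y)$ for the two features, the decision boundary is\n",
    "$$\n",
    "x+y-1=0,\\qquad x=-\\frac{w_2y+b}{w_1}=1-y,\\qquad y=-\\frac{w_1x+b}{w_2}=1-x.\n",
    "$$\n",
    "The line meets the horizontal plot boundaries $y=-2$ and $y=2$ at $x=3$ and $x=-1$. Together with the plot limits $x=-2$ and $x=2$, this gives four candidate $x$-values, which sort to\n",
    "$$\n",
    "-2,\\;-1,\\;2,\\;3.\n",
    "$$\n",
    "If the line crosses the plot rectangle, the middle two values bound the visible segment. Here they are $x_0=-1$ and $x_1=2$, with $y_0=1-x_0=2$ and $y_1=1-x_1=-1$. So the endpoints are $(-1,2)$ on the top boundary and $(2,-1)$ on the right boundary."
   ]
  },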
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_decision_boundary(\n",
    "    model: torch.nn.Linear,\n",
    "    color=\"black\"\n",
    "):\n",
    "    \"\"\"\n",
    "    Given a logistic regression model\n",
    "    on a binary classification dataset of 2-dimensional features,\n",
    "    plot the decision boundary line.\n",
    "\n",
    "    We assume that a previous plot has set the plot boundaries\n",
    "    `(x_min, x_max), (y_min_ y_max)`.\n",
    "    We the segment of the decision boundary line such that\n",
    "    both its endpoints are on a plot boundary line segment.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    model : `torch.nn.Linear`\n",
    "        The affine transformation mapping feature vectors to logits.\n",
    "        We assume it has input dimension 2 and output dimension 1.\n",
    "    color : optional\n",
    "        A color identifier for the plot. Default: `\"black\"`.\n",
    "    \"\"\"\n",
    "    state_dict = model.state_dict()\n",
    "    bias, (normal_x, normal_y) = (\n",
    "        state_dict[key][0].cpu()\n",
    "        for key in (\"bias\", \"weight\")\n",
    "    )\n",
    "\n",
    "    (x_min, x_max) = plt.xlim()\n",
    "    (y_min, y_max) = plt.ylim()\n",
    "\n",
    "    # nx x + ny y + b = 0\n",
    "    # x = -(ny y + b) / nx\n",
    "    # y = -(nx x + b) / ny\n",
    "\n",
    "    x_to_y = lambda x: -(normal_x * x + bias) / normal_y\n",
    "    y_to_x = lambda y: -(normal_y * y + bias) / normal_x\n",
    "\n",
    "    if normal_x.abs() < 1e-5:\n",
    "        x0, x1 = x_min, x_max\n",
    "        y0 = x_to_y(0)\n",
    "        y1 = y0\n",
    "    elif normal_y.abs() < 1e-5:\n",
    "        x0 = y_to_x(0)\n",
    "        x1 = x0\n",
    "        y0, y1 = y_min, y_max\n",
    "    else:\n",
    "        xz = (x_min, x_max, y_to_x(y_min), y_to_x(y_max))\n",
    "\n",
    "        x0, x1 = min(xz), max(xz)\n",
    "        y0, y1 = x_to_y(x0), x_to_y(x1)\n",
    "\n",
    "    plt.plot((x0, x1), (y0, y1), color=color)\n",
    "    plt.xlim(x_min, x_max)\n",
    "    plt.ylim(y_min, y_max)\n",
    "\n",
    "plt.gca().set_aspect(\"equal\")\n",
    "plot_2d_classification_dataset(config, features, labels)\n",
    "plot_decision_boundary(model)\n",
    "plt.show()\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you're getting a reasonable result, repeat the above training and plotting steps with `classes_offset` being `((0, 0), (4, -2))` and `((0, 0), (0, 0))`. In each case, either print the accuracy or put it in the plot title."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for classes_offset in (((0, 0), (4, -2)), ((0, 0), (0, 0))):\n",
    "    config[\"classes_offset\"] = classes_offset\n",
    "    features, labels = get_hyperspheres_dataset(config)\n",
    "    model = torch.nn.Linear(\n",
    "        2,\n",
    "        1,\n",
    "        device=config[\"device\"],\n",
    "        dtype=torch.float32\n",
    "    )\n",
    "    accuracy = train_binary_classifier_simple(\n",
    "        config,\n",
    "        features,\n",
    "        labels,\n",
    "        model\n",
    "    )\n",
    "\n",
    "    plt.gca().set_aspect(\"equal\")\n",
    "    plot_2d_classification_dataset(config, features, labels)\n",
    "    plot_decision_boundary(model)\n",
    "    plt.title(f\"Accuracy: {accuracy:.2f}\")\n",
    "    plt.show()\n",
    "    plt.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training an MLP\n",
    "\n",
    "Time to train a MLP on the dataset with 2 concentric (noisy) circles! We'll have\n",
    "1. 1 hidden layer of 3 dimensions and\n",
    "2. ReLU as activation function.\n",
    "\n",
    "That is, our full architecture is as follows:\n",
    "$$\n",
    "\\mathbf R^2 \\xrightarrow{A_0} \\mathbf R^3 \\xrightarrow{\\mathrm{ReLU}}\n",
    "\\mathbf R^3 \\xrightarrow{A_1} \\mathbf R.\n",
    "$$"
   ]
  },
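  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Written out, with $W_0$, $\\mathbf b_0$, $W_1$ and $b_1$ as notation introduced for this note for the weights and biases of $A_0$ and $A_1$, the model is\n",
    "$$\n",
    "f(\\mathbf x)=W_1\\,\\mathrm{ReLU}(W_0\\mathbf x+\\mathbf b_0)+b_1,\\qquad\n",
    "W_0\\in\\mathbf R^{3\\times2},\\ \\mathbf b_0\\in\\mathbf R^3,\\ W_1\\in\\mathbf R^{1\\times3},\\ b_1\\in\\mathbf R,\n",
    "$$\n",
    "where $\\mathrm{ReLU}(t)=\\max(t,0)$ is applied elementwise. This gives\n",
    "$$\n",
    "3\\cdot2+3+1\\cdot3+1=13\n",
    "$$\n",
    "trainable parameters. Without the ReLU, the composite $A_1\\circ A_0$ would again be an affine transformation, so the decision boundary would still be a line."
   ]
  },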
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Using `torch.nn.Sequential`\n",
    "\n",
    "We already know that we can define the affine transformations via `torch.nn.Linear`. To get a composite function as a model, we can use `torch.nn.Sequential`. Its positional arguments should be the `Module`s that we indend to apply in order. For activation, we can use `torch.nn.ReLU`. This gets us a `Module` with no parameters that applies ReLU elementwise to its input.\n",
    "\n",
    "Create the model and print:\n",
    "1. the model and\n",
    "2. the state dictionary of the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "affine0 = torch.nn.Linear(\n",
    "    2,\n",
    "    3,\n",
    "    device=config[\"device\"],\n",
    "    dtype=torch.float32\n",
    ")\n",
    "affine1 = torch.nn.Linear(\n",
    "    3,\n",
    "    1,\n",
    "    device=config[\"device\"],\n",
    "    dtype=torch.float32\n",
    ")\n",
    "model = torch.nn.Sequential(\n",
    "    affine0,\n",
    "    torch.nn.ReLU(),\n",
    "    affine1\n",
    ")\n",
    "\n",
    "print(model)\n",
    "state_dict = model.state_dict()\n",
    "print(state_dict)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see that when printing a model, we get a nice description of its components. In particular, the layers are indexed. Then in the state dictionary, we can see that the parameter names of `Linear`: `\"weight\"` and `\"bias\"` are prefixed by the layer index in the sequential model: `0` and `2`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train the Model\n",
    "\n",
    "By a plot, make sure that you have the dataset with zero class offsets in hand."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.gca().set_aspect(\"equal\")\n",
    "plot_2d_classification_dataset(config, features, labels)\n",
    "plt.show()\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If what you're seeing is not concentric noisy circles, then adjust `classes_offset` in the configuration dictionary and regenerate the dataset."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Train the model and print its accuracy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(config[\"classes_offset\"])\n",
    "accuracy = train_binary_classifier_simple(\n",
    "    config,\n",
    "    features,\n",
    "    labels,\n",
    "    model\n",
    ")\n",
    "print(accuracy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If the accuracy you're seeing is not super good, don't fret! Training such small models is highly dependend on initialization. For now, we can get around this issue by repeating:\n",
    "1. initializining a model and\n",
    "2. training it on the dataset\n",
    "\n",
    "until we get an accuracy value of over 95%"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "while True:\n",
    "    affine0 = torch.nn.Linear(\n",
    "        2,\n",
    "        3,\n",
    "        device=config[\"device\"],\n",
    "        dtype=torch.float32\n",
    "    )\n",
    "    affine1 = torch.nn.Linear(\n",
    "        3,\n",
    "        1,\n",
    "        device=config[\"device\"],\n",
    "        dtype=torch.float32\n",
    "    )\n",
    "    model = torch.nn.Sequential(\n",
    "        affine0,\n",
    "        torch.nn.ReLU(),\n",
    "        affine1\n",
    "    )\n",
    "    accuracy = train_binary_classifier_simple(\n",
    "        config,\n",
    "        features,\n",
    "        labels,\n",
    "        model\n",
    "    )\n",
    "    print(accuracy)\n",
    "\n",
    "    if accuracy > .95:\n",
    "        break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3D Scatter Plots of Hidden Representations\n",
    "\n",
    "Now, we'll make 3D scatter plots of the hidden representations.\n",
    "\n",
    "Let's make them interactive! This way, you can rotate them. To that end, start the cell with the following magic command:\n",
    "```python\n",
    "%matplotlib widget\n",
    "```\n",
    "You only need to do this once in the notebook.\n",
    "\n",
    "With interactive plots, it's important not to call `plt.close` in the same cell. Closing the plot disables interactivity.\n",
    "\n",
    "To make a 3D plot, you can start the plotting procedure as follows:\n",
    "```python\n",
    "fig = plt.figure()\n",
    "ax = fig.add_subplot(projection=\"3d\")\n",
    "```\n",
    "Afterwards, you can make the scatter plot by `ax.scatter`.\n",
    "1. It needs 3 positional arguments for the 3 coordinates of the vertices.\n",
    "2. Just like in the 2-dimensional case, you can give an iterable of colors of the vertices as the `c` keyword argument.\n",
    "3. Make the `marker` keyword argument have value `'.'` for pointlike vertices.\n",
    "\n",
    "Take the first affine transformation and transform the dataset feature matrix with it. Make a scatter plot of the image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib widget\n",
    "colors = get_colors(2, \"rainbow\")\n",
    "\n",
    "with torch.no_grad():\n",
    "    hidden_features_0 = affine0(features).cpu()\n",
    "\n",
    "fig = plt.figure()\n",
    "ax = fig.add_subplot(projection=\"3d\")\n",
    "ax.scatter(*hidden_features_0.T, c=colors[labels.cpu()], marker='.')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see the image of the circles by an affine transformation $\\mathbf R^2\\to\\mathbf R^3$. Apply ReLU to this image and plot the result."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hidden_features_1 = F.relu(hidden_features_0)\n",
    "\n",
    "fig = plt.figure()\n",
    "ax = fig.add_subplot(projection=\"3d\")\n",
    "ax.scatter(\n",
    "    *hidden_features_1.T,\n",
    "    c=colors[labels.cpu()],\n",
    "    marker='.'\n",
    ")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can see that the classes are separable by a hyperplane in this hidden representation."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## License\n",
    "\n",
    "This work is licensed under CC BY-NC-SA 4.0. To view a copy of this license, visit http://creativecommons.org/licenses/by-nc-sa/4.0/"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dml",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
