{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d92520d0",
   "metadata": {},
   "source": [
    "# Convex Hull Volume with Transformers\n",
    "\n",
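    "Note that, as we defined it, a transformer is a permutation-equivariant model: permuting its input sequence permutes its output sequence in the same way. We will discuss next time how to imbue it with positional information. For now, we shall train a transformer on our permutation-invariant task of choice, convex hull volume. Pooling the outputs over the sequence dimension at the end turns the equivariant model into an invariant one."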
   ]
  },
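  {
   "cell_type": "markdown",
   "id": "7a1c0e01",
   "metadata": {},
   "source": [
    "As a quick illustration of these two properties, here is a minimal sketch that uses PyTorch's built-in `torch.nn.MultiheadAttention` as a stand-in for the attention block we build below. It is not our implementation, and it has no positional information. Permuting the input points should permute the attention outputs in the same way, and averaging over the sequence dimension should remove the dependence on the order altogether."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a1c0e02",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# PyTorch's built-in attention layer, used here only as a stand-in.\n",
    "attention = torch.nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True)\n",
    "attention.eval()\n",
    "\n",
    "x = torch.randn(1, 10, 32)  # one sequence of 10 embedded points\n",
    "perm = torch.randperm(10)\n",
    "x_perm = x[:, perm]\n",
    "\n",
    "with torch.no_grad():\n",
    "    out, _ = attention(x, x, x)\n",
    "    out_perm, _ = attention(x_perm, x_perm, x_perm)\n",
    "\n",
    "# Equivariance: the outputs are permuted the same way as the inputs.\n",
    "equivariant = torch.allclose(out_perm, out[:, perm], atol=1e-5)\n",
    "# Invariance: mean pooling over the sequence dimension forgets the order.\n",
    "invariant = torch.allclose(out_perm.mean(dim=1), out.mean(dim=1), atol=1e-5)\n",
    "print(equivariant, invariant)"
   ]
  },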
  {
   "cell_type": "markdown",
   "id": "20d8b89a",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "### Imports\n",
    "\n",
    "Import `defaultdict`, `matplotlib.pyplot` as `plt`, `ConvexHull` from `scipy.spatial`, `torch`, and `torch.nn.functional` as `F`.\n",
    "\n",
    "Moreover, import the following (in the solution below, they are all imported from the module `util_0430`):\n",
    "1. The function `get_mse`, that you wrote in Notebook 0319.\n",
    "2. The class `AdamW`, that you wrote in Notebook 0326.\n",
    "3. The function `pbt_init`, that you wrote in Notebook 0328.\n",
    "4. The functions `get_dataloader_random_reshuffle` and `to_ensembled`, that you wrote in Notebook 0416.\n",
    "5. The classes `DictReLU` and `Linear`, and the functions `evaluate_model` and `get_output_by_batches`, that you wrote in Notebook 0423.\n",
    "6. The classes `Dropout` and `LayerNorm`, and the function `train_supervised`, that you wrote in Notebook 0425."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88dc971c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.spatial import ConvexHull\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from util_0430 import (\n",
    "    AdamW,\n",
    "    DictReLU,\n",
    "    Dropout,\n",
    "    evaluate_model,\n",
    "    get_dataloader_random_reshuffle,\n",
    "    get_mse,\n",
    "    get_output_by_batches,\n",
    "    LayerNorm,\n",
    "    Linear,\n",
    "    pbt_init,\n",
    "    to_ensembled,\n",
    "    train_supervised\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3e4e30f1",
   "metadata": {},
   "source": [
    "### Configuration\n",
    "\n",
    "Create a configuration dictionary with the following keys:\n",
    "- `\"device\"`: `torch.device | int | str`  \n",
    "    The device identifier, explained in notebook 091224.\n",
    "- `\"ensemble_shape\"`: `tuple[int]`  \n",
    "    Make this `(16,)`.\n",
    "- `\"hyperparameter_raw_init_distributions\"`, `\"hyperparameter_raw_perturb\"`, `\"hyperparameter_transforms\"`: `dict`  \n",
    "    These three dictionaries determine how the hyperparameters are tuned: how their raw values are initialized, how the raw values are perturbed, and how raw values are transformed into actual hyperparameter values. We'll tune the following hyperparameters:\n",
    "    1. Epsilon $\\epsilon$.\n",
    "    2. Learning rate $\\eta$.\n",
    "    3. Weight decay $\\lambda$.\n",
    "    4. First moment moving average decay rate $\\beta_1$.\n",
    "    5. Second moment moving average decay rate $\\beta_2$.\n",
    "    6. Dropout probability $p$.\n",
    "\n",
    "    Of these, we don't know the required order of magnitude of the first three. Thus it may be good to distribute them as $10^{\\mathscr D}$, where $\\mathscr D$ is a normal or uniform distribution; that is, we sample their base-10 logarithms. You can try to center the distributions at the recommended values.\n",
    "\n",
    "    The recommended values of the fourth and fifth are $0.9 = 1 - 10^{-1}$ and $0.999 = 1 - 10^{-3}$. So it may be best to give them distributions of the form $1-10^{\\mathscr D}$, with $\\mathscr D$ centered around $-1$ and $-3$, respectively.\n",
    "\n",
    "    We know that the dropout probability should be in the unit interval $[0,1]$. Moreover, it may not help if we zero more than half of the neurons. Thus, let's make its raw initial distribution the uniform distribution on $[0, 0.5]$. For the raw perturbation, maybe we can use a normal distribution with center $0$ and std $0.1$. For the transform function, I recommend clipping the values at $0$ and $1$, as they should be probabilities.\n",
    "- `\"improvement_threshold\"`: `float`  \n",
    "    Make this `1e-4`.\n",
    "- `\"minibatch_size\"`: `int`  \n",
    "    Make this `32`.\n",
    "- `\"minibatch_size_eval\"`: `int`  \n",
    "    On my home computer, I can make this `128`.\n",
    "- `\"pbt\"` : `bool`  \n",
    "    Make this `True`.\n",
    "- `\"seed\"`: `int`  \n",
    "    This is for reproducible experiments. Insert any integer.\n",
    "- `\"steps_num\"`: `int`  \n",
    "    Make this `10_001`.\n",
    "- `\"steps_without_improvement\"`: `int`  \n",
    "    Make this `1000`.\n",
    "- `\"valid_interval\"`: `int`  \n",
    "    Make this `100`.\n",
    "- `\"welch_confidence_level\"`: `float`  \n",
    "    We will exploit based on a one-sided Welch $t$-test with this confidence level. Based on my experiments in the setting of Homework 9, maybe you can try `.8`. Feel free to try out various values here!\n",
    "- `\"welch_sample_size\"`: `int`  \n",
    "    We will exploit based on a one-sided Welch $t$-test on the last `welch_sample_size` validation metrics of the population members. To follow the PBT paper, make this `10`."
   ]
  },
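  {
   "cell_type": "markdown",
   "id": "7a1c0e03",
   "metadata": {},
   "source": [
    "To make the exploit criterion concrete, here is a minimal sketch of a one-sided Welch $t$-test using `scipy.stats.ttest_ind` (passing `equal_var=False` selects Welch's test). The function `should_exploit` and the rule of exploiting when the p-value falls below one minus the confidence level are our assumptions, not necessarily how `pbt_init` and `train_supervised` apply the test. The metrics are synthetic."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a1c0e04",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy import stats\n",
    "\n",
    "def should_exploit(metrics_self, metrics_other, confidence_level=0.8):\n",
    "    # Hypothetical helper. One-sided Welch t-test (equal_var=False): the\n",
    "    # alternative hypothesis is that the other member's mean metric is greater\n",
    "    # than ours, as the validation metric is 'higher is better' (e.g. -MSE).\n",
    "    result = stats.ttest_ind(\n",
    "        metrics_other, metrics_self, equal_var=False, alternative='greater'\n",
    "    )\n",
    "    return bool(result.pvalue < 1 - confidence_level)\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "# Synthetic stand-ins for the last 10 validation metrics of two members.\n",
    "weak = rng.normal(-0.20, 0.01, size=10)\n",
    "strong = rng.normal(-0.05, 0.01, size=10)\n",
    "print(should_exploit(weak, strong), should_exploit(strong, weak))"
   ]
  },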
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fb6f44d",
   "metadata": {},
   "outputs": [],
   "source": [
    "config = {\n",
    "    \"device\": \"cuda\",\n",
    "    \"ensemble_shape\": (16,),\n",
    "    \"hyperparameter_raw_init_distributions\": {\n",
    "        \"dropout_p\": torch.distributions.Uniform(\n",
    "            torch.tensor(0, device=\"cuda\", dtype=torch.float32),\n",
    "            torch.tensor(.01, device=\"cuda\", dtype=torch.float32)\n",
    "        ),\n",
    "        \"epsilon\": torch.distributions.Uniform(\n",
    "            torch.tensor(-10, device=\"cuda\", dtype=torch.float32),\n",
    "            torch.tensor(-5, device=\"cuda\", dtype=torch.float32)\n",
    "        ),\n",
    "        \"first_moment_decay\": torch.distributions.Uniform(\n",
    "            torch.tensor(-3, device=\"cuda\", dtype=torch.float32),\n",
    "            torch.tensor(0, device=\"cuda\", dtype=torch.float32)\n",
    "        ),\n",
    "        \"learning_rate\": torch.distributions.Uniform(\n",
    "            torch.tensor(-5, device=\"cuda\", dtype=torch.float32),\n",
    "            torch.tensor(-1, device=\"cuda\", dtype=torch.float32)\n",
    "        ),\n",
    "        \"second_moment_decay\": torch.distributions.Uniform(\n",
    "            torch.tensor(-5, device=\"cuda\", dtype=torch.float32),\n",
    "            torch.tensor(-1, device=\"cuda\", dtype=torch.float32)\n",
    "        ),\n",
    "        \"weight_decay\": torch.distributions.Uniform(\n",
    "            torch.tensor(-5, device=\"cuda\", dtype=torch.float32),\n",
    "            torch.tensor(-1, device=\"cuda\", dtype=torch.float32)\n",
    "        )\n",
    "    },\n",
    "    \"hyperparameter_raw_perturb\": {\n",
    "        \"dropout_p\": torch.distributions.Normal(\n",
    "            torch.tensor(0, device=\"cuda\", dtype=torch.float32),\n",
    "            torch.tensor(.01, device=\"cuda\", dtype=torch.float32)\n",
    "        ),\n",
    "        \"epsilon\": torch.distributions.Normal(\n",
    "            torch.tensor(0, device=\"cuda\", dtype=torch.float32),\n",
    "            torch.tensor(1, device=\"cuda\", dtype=torch.float32)\n",
    "        ),\n",
    "        \"first_moment_decay\": torch.distributions.Normal(\n",
    "            torch.tensor(0, device=\"cuda\", dtype=torch.float32),\n",
    "            torch.tensor(1, device=\"cuda\", dtype=torch.float32)\n",
    "        ),\n",
    "        \"learning_rate\": torch.distributions.Normal(\n",
    "            torch.tensor(0, device=\"cuda\", dtype=torch.float32),\n",
    "            torch.tensor(1, device=\"cuda\", dtype=torch.float32)\n",
    "        ),\n",
    "        \"second_moment_decay\": torch.distributions.Normal(\n",
    "            torch.tensor(0, device=\"cuda\", dtype=torch.float32),\n",
    "            torch.tensor(1, device=\"cuda\", dtype=torch.float32)\n",
    "        ),\n",
    "        \"weight_decay\": torch.distributions.Normal(\n",
    "            torch.tensor(0, device=\"cuda\", dtype=torch.float32),\n",
    "            torch.tensor(1, device=\"cuda\", dtype=torch.float32)\n",
    "        ),\n",
    "    },\n",
    "    \"hyperparameter_transforms\": {\n",
    "        \"dropout_p\": lambda p: p.clip(0,1),\n",
    "        \"epsilon\": lambda log10: 10 ** log10,\n",
    "        \"first_moment_decay\": lambda x: (1 - 10 ** x).clamp(0, 1),\n",
    "        \"learning_rate\": lambda log10: 10 ** log10,\n",
    "        \"second_moment_decay\": lambda x: (1 - 10 ** x).clamp(0, 1),\n",
    "        \"weight_decay\": lambda log10: 10 ** log10,\n",
    "    },\n",
    "    \"improvement_threshold\": 1e-4,\n",
    "    \"minibatch_size\": 1 << 5,\n",
    "    \"minibatch_size_eval\": 1 << 7,\n",
    "    \"pbt\": True,\n",
    "    \"seed\": 1,\n",
    "    \"steps_num\": 10_001,\n",
    "    \"steps_without_improvement\": 1_000,\n",
    "    \"valid_interval\": 100,\n",
    "    \"welch_confidence_level\": .8,\n",
    "    \"welch_sample_size\": 10,\n",
    "}"
   ]
  },
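  {
   "cell_type": "markdown",
   "id": "7a1c0e05",
   "metadata": {},
   "source": [
    "Here is a small sketch of what the raw-to-actual transforms in the configuration above produce. It draws its own raw samples (not the ones `pbt_init` will draw) from the uniform distribution on $[-5, -1)$, as in the learning-rate and second-moment entries. The learning rate $10^x$ is then log-uniform between roughly $10^{-5}$ and $10^{-1}$. The same raw values pushed through $1 - 10^x$ land roughly between $0.9$ and $0.99999$, which brackets the recommended $\\beta_2 = 0.999$. Clipping keeps a perturbed dropout probability inside $[0, 1]$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a1c0e06",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "raw = torch.distributions.Uniform(-5.0, -1.0).sample((1000,))\n",
    "\n",
    "# Log-uniform learning rate (same transform as in the config).\n",
    "learning_rate = 10 ** raw\n",
    "# The 1 - 10^x transform used for the moment decay rates in the config.\n",
    "beta2 = (1 - 10 ** raw).clamp(0, 1)\n",
    "# A perturbed dropout probability may leave [0, 1]; clipping fixes that.\n",
    "dropout_p = torch.tensor([-0.02, 0.3, 1.4]).clip(0, 1)\n",
    "\n",
    "print(learning_rate.min().item(), learning_rate.max().item())\n",
    "print(beta2.min().item(), beta2.max().item())\n",
    "print(dropout_p)"
   ]
  },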
  {
   "cell_type": "markdown",
   "id": "73ad9d6e",
   "metadata": {},
   "source": [
    "Set the `torch` pseudo-random number generation seed as per the configuration dictionary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5274e83c",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(config[\"seed\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "381c2d3a",
   "metadata": {},
   "source": [
    "### Dataset Generation\n",
    "\n",
    "Change `generate_convex_hull_dataset` so that the vertices are stored at key `\"features\"`. Then, generate train, validation and test convex hull volume datasets of 2-dimensional point clouds, with `80_000`, `10_000`, and `10_000` dataset entries, respectively. In each, point clouds should have between `10` and `100` points. Note that in 2 dimensions, the volume of a convex hull is its area.\n",
    "\n",
    "Afterwards, create a train dataloader, get a minibatch out of it, and print the keys and value shapes of the minibatch dictionary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab610367",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_convex_hull_dataset(\n",
    "    ambient_dim: int,\n",
    "    dataset_size: int,\n",
    "    device: int | str | torch.device,\n",
    "    subset_size_max: int,\n",
    "    subset_size_min: int,\n",
    "    std=1.\n",
    ") -> dict:\n",
    "    \"\"\"\n",
    "    Generate a supervised dataset mapping point clouds\n",
    "    to the volumes of their convex hulls.\n",
    "\n",
    "    The number of points in a point cloud\n",
    "    is sampled with uniform distribution from the closed interval\n",
    "    `[subset_size_min, subset_size_max]`\n",
    "    and the coordinates of the points are sampled from\n",
    "    the normal distribution with center `0.` and std `std`.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    ambient_dim : `int`  \n",
    "        The dimension of the Euclidean space\n",
    "        the point clouds are to be finite subsets of.\n",
    "    dataset_size : `int`  \n",
    "        The number of dataset entries to generate.\n",
    "    device : `int | str | torch.device`\n",
    "        Device to store the dataset on.\n",
    "    subset_size_max : `int`  \n",
    "        The maximum number of points in a point cloud.\n",
    "    subset_size_min : `int`  \n",
    "        The minimum number of points in a point cloud.\n",
    "    std : `float`, optional  \n",
    "        The std of the normal distribution\n",
    "        the point coordinates are sampled from.\n",
    "        Default: `1.`\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The dataset, in the form of a dictionary with keys\n",
    "    `\"indptr\"`, `\"features\"` and `\"volume\"`, whose values are\n",
    "    `torch.Tensor`s storing the dataset as follows:\n",
    "    the `i`-th dataset entry has vertices `features[indptr[i]:indptr[i+1]]`\n",
    "    and volume `volume[i]` (for `ambient_dim == 2`, this is the area).\n",
    "\n",
    "    For ease of use with supervised learning algorithms,\n",
    "    the tensor `volume` is unsqueezed and has shape `(dataset_size, 1)`.\n",
    "    \"\"\"\n",
    "    sizes = torch.randint(\n",
    "        subset_size_min,\n",
    "        subset_size_max + 1,\n",
    "        (dataset_size,),\n",
    "        device=device,\n",
    "    )\n",
    "    indptr = torch.empty(\n",
    "        dataset_size + 1,\n",
    "        device=device,\n",
    "        dtype=torch.int64\n",
    "    )\n",
    "    indptr[0] = 0\n",
    "    indptr[1:] = torch.cumsum(sizes, 0)\n",
    "\n",
    "    vertices = torch.normal(\n",
    "        0.,\n",
    "        std,\n",
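    "        (int(indptr[-1]), ambient_dim),\n",
    "        device=device,\n",
    "        dtype=torch.float32\n",
    "    )\n",
    "\n",
    "    # Copy to the CPU once, rather than synchronizing with the\n",
    "    # device for every dataset entry inside the loop.\n",
    "    vertices_cpu = vertices.cpu().numpy()\n",
    "    bounds = indptr.tolist()\n",
    "    volume = torch.tensor(\n",
    "        [\n",
    "            ConvexHull(vertices_cpu[start:end]).volume\n",
    "            for start, end in zip(bounds[:-1], bounds[1:])\n",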
    "        ],\n",
    "        device=device,\n",
    "        dtype=torch.float32\n",
    "    ).unsqueeze(-1)\n",
    "\n",
    "    return {\n",
    "        \"indptr\": indptr,\n",
    "        \"features\": vertices,\n",
    "        \"volume\": volume,\n",
    "    }"
   ]
  },
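  {
   "cell_type": "markdown",
   "id": "7a1c0e07",
   "metadata": {},
   "source": [
    "The `indptr` layout above is the same idea as the row pointers of a compressed sparse row (CSR) matrix: all point clouds are stored back to back, and entry `i` occupies rows `indptr[i]` up to, but excluding, `indptr[i+1]`. Here is a minimal sketch on two hand-made point clouds. Note that for 2-dimensional points, SciPy's `ConvexHull.volume` is the enclosed area, whereas `ConvexHull.area` is the perimeter."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a1c0e08",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from scipy.spatial import ConvexHull\n",
    "\n",
    "# Two toy point clouds with 4 and 3 points, stored back to back.\n",
    "sizes = torch.tensor([4, 3])\n",
    "indptr = torch.zeros(len(sizes) + 1, dtype=torch.int64)\n",
    "indptr[1:] = torch.cumsum(sizes, 0)\n",
    "\n",
    "features = torch.tensor([\n",
    "    [0., 0.], [1., 0.], [0., 1.], [1., 1.],  # unit square\n",
    "    [0., 0.], [2., 0.], [0., 2.],            # right triangle with legs 2\n",
    "])\n",
    "\n",
    "bounds = indptr.tolist()\n",
    "areas = [\n",
    "    ConvexHull(features[start:end].numpy()).volume\n",
    "    for start, end in zip(bounds[:-1], bounds[1:])\n",
    "]\n",
    "print(bounds, areas)"
   ]
  },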
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6e6dc22",
   "metadata": {},
   "outputs": [],
   "source": [
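    "dataset_train = generate_convex_hull_dataset(\n",
    "    ambient_dim=2,\n",
    "    dataset_size=80_000,\n",
    "    device=config[\"device\"],\n",
    "    subset_size_max=100,\n",
    "    subset_size_min=10\n",
    ")\n",
    "dataset_valid = generate_convex_hull_dataset(\n",
    "    ambient_dim=2,\n",
    "    dataset_size=10_000,\n",
    "    device=config[\"device\"],\n",
    "    subset_size_max=100,\n",
    "    subset_size_min=10\n",
    ")\n",
    "dataset_test = generate_convex_hull_dataset(\n",
    "    ambient_dim=2,\n",
    "    dataset_size=10_000,\n",
    "    device=config[\"device\"],\n",
    "    subset_size_max=100,\n",
    "    subset_size_min=10\n",
    ")"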
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "444a7053",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataloader_train = get_dataloader_random_reshuffle(\n",
    "    config,\n",
    "    dataset_train\n",
    ")\n",
    "minibatch = next(dataloader_train)\n",
    "for key, value in minibatch.items():\n",
    "    print(key, value.shape)"
   ]
  },
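  {
   "cell_type": "markdown",
   "id": "7a1c0e09",
   "metadata": {},
   "source": [
    "The minibatches produced by `get_dataloader_random_reshuffle` contain padded features together with a `mask` entry that marks the real (non-padding) points. We don't reproduce that function here. Instead, here is a minimal sketch, under our own assumptions, of how such padding could be built from the `indptr` layout. The helper `pad_by_indptr` is hypothetical and not necessarily how the course's dataloader works."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a1c0e10",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def pad_by_indptr(features, indptr, index):\n",
    "    # Hypothetical helper: pads the dataset entries selected by `index`\n",
    "    # to a common length and returns them with a mask of real positions.\n",
    "    starts = indptr[index]\n",
    "    lengths = indptr[index + 1] - starts\n",
    "    positions = torch.arange(int(lengths.max()))\n",
    "    mask = positions < lengths[:, None]  # (batch, max_len), True = real point\n",
    "    # Clamp so that padding positions still index a valid row; they are zeroed below.\n",
    "    rows = (starts[:, None] + positions).clamp(max=len(features) - 1)\n",
    "    padded = features[rows] * mask[..., None]\n",
    "    return padded, mask\n",
    "\n",
    "features = torch.arange(14, dtype=torch.float32).reshape(7, 2)\n",
    "indptr = torch.tensor([0, 4, 7])\n",
    "padded, mask = pad_by_indptr(features, indptr, torch.tensor([1, 0]))\n",
    "print(padded.shape, mask.tolist())"
   ]
  },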
  {
   "cell_type": "markdown",
   "id": "cb56c3db",
   "metadata": {},
   "source": [
    "## Assembling a Transformer\n",
    "\n",
    "To approximate convex hull volumes, we'll create a small transformer model with embedding dimension 32 and 4 attention heads. As the heads split the embedding dimension evenly, this makes the key dimension of each head 32 / 4 = 8."
   ]
  },
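  {
   "cell_type": "markdown",
   "id": "7a1c0e11",
   "metadata": {},
   "source": [
    "Splitting the embedding dimension across heads amounts to a reshape followed by a transpose. Here is a minimal sketch with plain tensors (no ensemble dimension; the batch and sequence sizes are made up). It shows the per-head layout that `F.scaled_dot_product_attention` expects, and that the inverse operation recovers the original tensor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a1c0e12",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "batch_size, sequence_length = 2, 5             # made-up sizes\n",
    "attention_head_num, embedding_dim = 4, 32      # as in this notebook\n",
    "key_dim = embedding_dim // attention_head_num  # 8\n",
    "\n",
    "generator = torch.Generator().manual_seed(0)\n",
    "x = torch.randn(batch_size, sequence_length, embedding_dim, generator=generator)\n",
    "\n",
    "# Split the last axis into heads, then move the head axis before the sequence axis.\n",
    "per_head = x.reshape(batch_size, sequence_length, attention_head_num, key_dim).transpose(-3, -2)\n",
    "print(per_head.shape)  # (batch, heads, sequence, key_dim)\n",
    "\n",
    "# The inverse, applied after attention, recovers the original layout exactly.\n",
    "merged = per_head.transpose(-3, -2).reshape(batch_size, sequence_length, embedding_dim)\n",
    "print(torch.equal(merged, x))"
   ]
  },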
  {
   "cell_type": "markdown",
   "id": "0f89bf3c",
   "metadata": {},
   "source": [
    "### Affine Embeddings\n",
    "\n",
    "The input has 2-dimensional features, so we first need to map them to 32-dimensional vectors. We can use an affine layer (our `Linear` class) for this; you can call it `embedding`. Once it's created, apply it to the input minibatch, and print the keys and value shapes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc0d500d",
   "metadata": {},
   "outputs": [],
   "source": [
    "embedding = Linear(\n",
    "    config,\n",
    "    2,\n",
    "    32\n",
    ")\n",
    "\n",
    "minibatch = embedding(minibatch)\n",
    "for key, value in minibatch.items():\n",
    "    print(key, value.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1b0746e5",
   "metadata": {},
   "source": [
    "### `MultiHeadSelfAttentionBlock`\n",
    "\n",
    "Let's implement a Pre-LN multi-head self-attention (MHSA) block.\n",
    "\n",
    "1. In initialization, you should create the following layers:\n",
    "    1. `LayerNorm`\n",
    "    2. Four `Linear` layers, with input and output dimensions `embedding_dim`, to give the key, query, value, and output weight matrices. You won't need biases.\n",
    "    3. `Dropout`.\n",
    "\n",
    "2. When implementing the `forward` method, you can follow these steps:\n",
    "    1. Keep an extra reference to the input features, as we'll need them for the skip connection.\n",
    "    2. Apply layer norm.\n",
    "    3. Get key, query and value tensors via their respective weight matrices.\n",
    "        1. Reshape and transpose them to get them in shape `(population_size, minibatch_size, attention_head_num, sequence_length, key_dim)`.\n",
    "        2. Now you can use `F.scaled_dot_product_attention` to calculate the values aggregated by attention. I recommend this function because it dispatches to the most efficient attention implementation available for the given inputs and hardware.\n",
    "            1. Recall that the minibatches include a `\"mask\"` entry of shape `(..., sequence_length)` that says which entries are not padding. This information can be fed to the function via the `attn_mask` keyword argument. The attention mask should have shape `(..., sequence_length, sequence_length)`. As a boolean mask, its entry `[..., i, j]` should be `True` (meaning that entry `i` may attend to entry `j`) if and only if either\n",
    "                1. both `mask[..., i]` and `mask[..., j]` are `True`, or\n",
    "                2. `i == j` (we need this so that no row of the softmax is fully masked, which would make it divide by zero).\n",
    "\n",
    "            To create the attention mask, you can use the bitwise and `&` and or `|` operations and broadcasting.\n",
    "\n",
    "    4. Apply the output weight matrix, and then dropout, to the aggregated value vectors.\n",
    "    5. Add the result to the skip connection. The sum is the new `\"features\"` entry of the minibatch returned by the `forward` method.\n",
    "\n",
    "When you're done, apply `pbt_init` to the configuration dictionary so that the dropout probabilities get initialized. Then create an MHSA block, apply it to the output of the embedding layer, and print the keys and value shapes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "374fe913",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MultiHeadSelfAttentionBlock(torch.nn.Module):\n",
    "    \"\"\"\n",
    "    Pre-LN multi-head self-attention block.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    attention_head_num : `int`\n",
    "        The number of attention heads.\n",
    "    config : `dict`\n",
    "        Configuration dictionary. Required key-value pairs:\n",
    "        `\"device\"` : `str`\n",
    "            The device to store parameters on.\n",
    "        `\"dropout_p\"` : `torch.Tensor`\n",
    "            Dropout probability tensor, of shape `ensemble_shape`.\n",
    "        `\"ensemble_shape\"` : `tuple[int]`\n",
    "            Ensemble shape.      \n",
    "    embedding_dim : `int`\n",
    "        The feature dimension of internal representations.\n",
    "\n",
    "    Calling\n",
    "    -------\n",
    "    Instance calls require one positional argument:\n",
    "    batch : `dict`\n",
    "        The input data dictionary. Required keys:\n",
    "        `\"features\"` : `torch.Tensor`\n",
    "            Tensor of element-level features, of shape\n",
    "            `batch_shape + (sequence_dim, embedding_dim)` or\n",
    "            `ensemble_shape + batch_shape + (sequence_dim, embedding_dim)`\n",
    "        `\"mask\"` : `torch.Tensor`\n",
    "            Mask showing which entries are not padding, of shape\n",
    "            `batch_shape + (sequence_dim,)` or\n",
    "            `ensemble_shape + batch_shape + (sequence_dim,)`\n",
    "    \"\"\"\n",
    "    def __init__(\n",
    "        self,\n",
    "        attention_head_num: int,\n",
    "        config: dict,\n",
    "        embedding_dim: int\n",
    "    ):\n",
    "        super().__init__()\n",
    "\n",
    "        self.attention_head_num = attention_head_num\n",
    "        self.config = config\n",
    "        self.dropout = Dropout(config)\n",
    "        self.layer_norm = LayerNorm(\n",
    "            config,\n",
    "            embedding_dim\n",
    "        )\n",
    "\n",
    "        (\n",
    "            self.key_weights,\n",
    "            self.output_weights,\n",
    "            self.query_weights,\n",
    "            self.value_weights\n",
    "        ) = (\n",
    "            Linear(\n",
    "                config,\n",
    "                embedding_dim,\n",
    "                embedding_dim,\n",
    "                bias=False\n",
    "            )\n",
    "            for _ in range(4)\n",
    "        )\n",
    "\n",
    "\n",
    "    def forward(\n",
    "        self,\n",
    "        batch: dict\n",
    "    ) -> dict:\n",
    "        skip = batch[\"features\"]\n",
    "        batch = self.layer_norm(batch)\n",
    "        residual, mask = (batch[key] for key in (\"features\", \"mask\"))\n",
    "\n",
    "        sequence_dim, embedding_dim = residual.shape[-2:]\n",
    "        key_dim = embedding_dim // self.attention_head_num\n",
    "\n",
    "        key, query, value = (\n",
    "            (\n",
    "                linear(batch)\n",
    "            )[\"features\"].reshape(\n",
    "                residual.shape[:-1] + (self.attention_head_num, key_dim)\n",
    "            ).transpose(-3, -2)\n",
    "            for linear in (\n",
    "                self.key_weights,\n",
    "                self.query_weights,\n",
    "                self.value_weights\n",
    "            )\n",
    "        )\n",
    "\n",
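    "        # Real entries may attend to each other, and every entry may attend\n",
    "        # to itself, so that no row of the softmax is fully masked.\n",
    "        arange = torch.arange(sequence_dim, device=mask.device)\n",
    "        attention_mask = mask[..., None, :] & mask[..., None]\n",
    "        attention_mask |= (arange == arange[:, None])\n",
    "\n",
    "        pooled_values = F.scaled_dot_product_attention(\n",
    "            query,\n",
    "            key,\n",
    "            value,\n",
    "            attn_mask=attention_mask[..., None, :, :]  # broadcast over heads\n",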
    "        )\n",
    "\n",
    "        residual = pooled_values.transpose(-3, -2).reshape(residual.shape)\n",
    "        residual = self.output_weights({\"features\": residual})[\"features\"]\n",
    "        residual = self.dropout({\"features\": residual})[\"features\"]\n",
    "\n",
    "        features = skip + residual\n",
    "\n",
    "        return batch | {\"features\": features}\n",
    "    \n",
    "pbt_init(config, defaultdict(list))\n",
    "mhsa = MultiHeadSelfAttentionBlock(\n",
    "    4,\n",
    "    config,\n",
    "    32\n",
    ")\n",
    "minibatch = mhsa(minibatch)\n",
    "for key, value in minibatch.items():\n",
    "    print(key, value.shape)"
   ]
  },
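  {
   "cell_type": "markdown",
   "id": "7a1c0e13",
   "metadata": {},
   "source": [
    "To see what the attention mask built in `forward` looks like, here is a toy example outside the block. It uses a single sequence of length 4 whose last two entries are padding, with random queries, keys and values of dimension 8. Real entries attend only to each other, while each padding entry attends only to itself. The output of a padding entry is therefore just its own value vector, and changing the padding values does not affect the real entries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a1c0e14",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "mask = torch.tensor([[True, True, False, False]])  # (batch, sequence); last two are padding\n",
    "arange = torch.arange(mask.shape[-1])\n",
    "\n",
    "# Pairs of real entries may attend to each other ...\n",
    "attention_mask = mask[..., None, :] & mask[..., None]\n",
    "# ... and every entry may attend to itself.\n",
    "attention_mask |= arange == arange[:, None]\n",
    "print(attention_mask.int())\n",
    "\n",
    "generator = torch.Generator().manual_seed(0)\n",
    "query, key, value = torch.randn(3, 1, 4, 8, generator=generator).unbind(0)\n",
    "out = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)\n",
    "\n",
    "# Padding entries only see themselves, so they return their own value vectors.\n",
    "print(torch.allclose(out[0, 2:], value[0, 2:]))\n",
    "\n",
    "# Real entries ignore padding: changing padding values leaves them unchanged.\n",
    "value2 = value.clone()\n",
    "value2[0, 2:] += 100.0\n",
    "out2 = F.scaled_dot_product_attention(query, key, value2, attn_mask=attention_mask)\n",
    "print(torch.allclose(out2[0, :2], out[0, :2]))"
   ]
  },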
  {
   "cell_type": "markdown",
   "id": "1f48314e",
   "metadata": {},
   "source": [
    "### `FeedForwardBlock`\n",
    "\n",
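    "One more block to go! This one is pretty straightforward. It is a Pre-LN block that applies layer norm, an affine layer, ReLU, a second affine layer, and dropout, and then adds the result to the skip connection. Here, the hidden dimension equals the embedding dimension. Once again, test the output shape."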
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eeeca17a",
   "metadata": {},
   "outputs": [],
   "source": [
    "class FeedForwardBlock(torch.nn.Module):\n",
    "    \"\"\"\n",
    "    Pre-LN feedforward block.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    config : `dict`\n",
    "        Configuration dictionary. Required key-value pairs:\n",
    "        `\"device\"` : `str`\n",
    "            The device to store parameters on.\n",
    "        `\"dropout_p\"` : `torch.Tensor`\n",
    "            Dropout probability tensor, of shape `ensemble_shape`.\n",
    "        `\"ensemble_shape\"` : `tuple[int]`\n",
    "            Ensemble shape.      \n",
    "    embedding_dim : `int`\n",
    "        The feature dimension of internal representations.\n",
    "\n",
    "    Calling\n",
    "    -------\n",
    "    Instance calls require one positional argument:\n",
    "    batch : `dict`\n",
    "        The input data dictionary. Required key:\n",
    "        `\"features\"` : `torch.Tensor`\n",
    "            Tensor of element-level features, of shape\n",
    "            `batch_shape + (sequence_dim, embedding_dim)` or\n",
    "            `ensemble_shape + batch_shape + (sequence_dim, embedding_dim)`\n",
    "    \"\"\"\n",
    "    def __init__(\n",
    "        self,\n",
    "        config: dict,\n",
    "        embedding_dim: int\n",
    "    ):\n",
    "        super().__init__()\n",
    "\n",
    "        self.residual_f = torch.nn.Sequential(\n",
    "            LayerNorm(\n",
    "                config,\n",
    "                embedding_dim\n",
    "            ),\n",
    "            Linear(\n",
    "                config,\n",
    "                embedding_dim,\n",
    "                embedding_dim,\n",
    "                init_multiplier=2 ** .5\n",
    "            ),\n",
    "            DictReLU(),\n",
    "            Linear(\n",
    "                config,\n",
    "                embedding_dim,\n",
    "                embedding_dim\n",
    "            ),\n",
    "            Dropout(config)\n",
    "        )\n",
    "\n",
    "\n",
    "    def forward(self, batch: dict) -> dict:\n",
    "        skip = batch\n",
    "\n",
    "        residual = self.residual_f(batch)\n",
    "\n",
    "        features = skip[\"features\"] + residual[\"features\"]\n",
    "\n",
    "        return batch | {\"features\": features}\n",
    "    \n",
    "ff = FeedForwardBlock(config, 32)\n",
    "minibatch = ff(minibatch)\n",
    "for key, value in minibatch.items():\n",
    "    print(key, value.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "79bd78c9",
   "metadata": {},
   "source": [
    "### `MeanPool`\n",
    "\n",
    "Write a layer that averages its input along the sequence dimension, using the mask so that padding entries are ignored. Try it out on the output of the feedforward block."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "147685b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MeanPool(torch.nn.Module):\n",
    "    \"\"\"\n",
    "    Mean pooling of sequential data along the sequence dimension,\n",
    "    ignoring padding entries.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    config : `dict`\n",
    "        Configuration dictionary. Required key-value pair:\n",
    "        `\"ensemble_shape\"` : `tuple[int]`\n",
    "            The shape of the ensemble of models\n",
    "            that process the data.\n",
    "\n",
    "    Calling\n",
    "    -------\n",
    "    The expected input is a dictionary of the following\n",
    "    key-tensor pairs:\n",
    "    `\"features\"`\n",
    "        The feature tensor, of shape\n",
    "        `(..., sequence_dim, feature_dim)`\n",
    "    `\"mask\"`\n",
    "        This mask signals which entry is not a padding element.\n",
    "        It should have shape `(..., sequence_dim)`.\n",
    "    \"\"\"\n",
    "    def __init__(self, config: dict):\n",
    "        super().__init__()\n",
    "        self.config = config\n",
    "\n",
    "\n",
    "    def forward(self, batch: dict) -> dict:\n",
    "        features, mask = (\n",
    "            to_ensembled(self.config[\"ensemble_shape\"], batch[key])\n",
    "            for key in (\"features\", \"mask\")\n",
    "        )\n",
    "\n",
    "        pooled = (\n",
    "            (features * mask[..., None]).sum(dim=-2)\n",
    "          / mask[..., None].sum(dim=-2)\n",
    "        )\n",
    "\n",
    "        return batch | {\"features\": pooled}\n",
    "    \n",
    "mean_pool = MeanPool(config)\n",
    "minibatch = mean_pool(minibatch)\n",
    "for key, value in minibatch.items():\n",
    "    print(key, value.shape)"
   ]
  },
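  {
   "cell_type": "markdown",
   "id": "7a1c0e15",
   "metadata": {},
   "source": [
    "Here is a quick check of the masked mean on a toy input, outside the model. The padding row holds arbitrary values, yet it does not influence the result, and reordering the real rows does not change the result either. The helper `masked_mean` just repeats the formula of `MeanPool.forward` without the call to `to_ensembled`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a1c0e16",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "features = torch.tensor([[[1., 2.], [3., 4.], [100., -100.]]])  # last row is padding\n",
    "mask = torch.tensor([[True, True, False]])\n",
    "\n",
    "def masked_mean(features, mask):\n",
    "    # Same formula as in MeanPool.forward, minus the to_ensembled call.\n",
    "    return (features * mask[..., None]).sum(dim=-2) / mask[..., None].sum(dim=-2)\n",
    "\n",
    "pooled = masked_mean(features, mask)\n",
    "swapped = masked_mean(features[:, [1, 0, 2]], mask)  # swap the two real rows\n",
    "print(pooled, torch.equal(pooled, swapped))"
   ]
  },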
  {
   "cell_type": "markdown",
   "id": "3b7b3028",
   "metadata": {},
   "source": [
    "### Stack the Blocks\n",
    "\n",
    "Now, you are ready to create the full model, using `torch.nn.Sequential`. It should include:\n",
    "1. An embedding layer,\n",
    "2. A dropout layer,\n",
    "3. A few pairs of MHSA and feedforward blocks, collectively referred to as *transformer blocks* (the model below uses three),\n",
    "4. A mean pool layer, and\n",
    "5. An affine layer, to map the 32 embedding dimensions to the 1 target dimension.\n",
    "\n",
    "Make the full model, apply it to a new minibatch, and check the output tensor shapes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21fbe760",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = torch.nn.Sequential(\n",
    "    Linear(config, 2, 32),\n",
    "    Dropout(config),\n",
    "    MultiHeadSelfAttentionBlock(4, config, 32),\n",
    "    FeedForwardBlock(config, 32),\n",
    "    MultiHeadSelfAttentionBlock(4, config, 32),\n",
    "    FeedForwardBlock(config, 32),\n",
    "    MultiHeadSelfAttentionBlock(4, config, 32),\n",
    "    FeedForwardBlock(config, 32),\n",
    "    MeanPool(config),\n",
    "    Linear(config, 32, 1)\n",
    ")\n",
    "\n",
    "output = model(next(dataloader_train))\n",
    "for key, value in output.items():\n",
    "    print(key, value.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ca8918ae",
   "metadata": {},
   "source": [
    "## Training\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bf72b5d8",
   "metadata": {},
   "source": [
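    "Make an `AdamW` optimizer for the model, and train it on the dataset. Use the mean squared error as the loss, and its negative as the validation metric, so that higher is better."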
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "352ee337",
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = AdamW(model.parameters())\n",
    "log = train_supervised(\n",
    "    config,\n",
    "    dataset_train,\n",
    "    dataset_valid,\n",
    "    get_mse,\n",
    "    lambda predict, target: -get_mse(predict, target),\n",
    "    model,\n",
    "    optimizer,\n",
    "    target_key=\"volume\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "937a8267",
   "metadata": {},
   "source": [
    "Test the best ensemble member on the test split, and plot the true and predicted convex hull areas, as we did in Notebook 0416."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13f4a1c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A single (non-ensembled) copy of the architecture, with dropout disabled.\n",
    "config_single = config | {\"dropout_p\": 0, \"ensemble_shape\": ()}\n",
    "model = torch.nn.Sequential(\n",
    "    Linear(config_single, 2, 32),\n",
    "    Dropout(config_single),\n",
    "    MultiHeadSelfAttentionBlock(4, config_single, 32),\n",
    "    FeedForwardBlock(config_single, 32),\n",
    "    MultiHeadSelfAttentionBlock(4, config_single, 32),\n",
    "    FeedForwardBlock(config_single, 32),\n",
    "    MultiHeadSelfAttentionBlock(4, config_single, 32),\n",
    "    FeedForwardBlock(config_single, 32),\n",
    "    MeanPool(config_single),\n",
    "    Linear(config_single, 32, 1)\n",
    ")\n",
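    "model.eval()\n",
    "# Load the parameters of the best ensemble member found during training.\n",
    "model.load_state_dict(log[\"best parameters\"])\n",
    "# Sanity check: validation MSE of the restored model.\n",
    "print(evaluate_model(\n",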
    "    config_single,\n",
    "    dataset_valid,\n",
    "    get_mse,\n",
    "    model,\n",
    "    target_key=\"volume\"\n",
    "))\n",
    "test_predict = get_output_by_batches(\n",
    "    config_single,\n",
    "    dataset_test,\n",
    "    model,\n",
    "    1\n",
    ").squeeze(-1)\n",
    "test_target = dataset_test[\"volume\"].squeeze(-1)\n",
    "\n",
    "# Sort the test entries by true area: line = targets, red dots = predictions.\n",
    "argsort = test_target.argsort()\n",
    "plt.plot(test_target[argsort].cpu())\n",
    "plt.scatter(\n",
    "    torch.arange(len(test_target)),\n",
    "    test_predict[argsort].cpu(),\n",
    "    c=\"red\",\n",
    "    s=1\n",
    ")\n",
    "plt.title(f\"Test MSE: {get_mse(test_predict[..., None], test_target[..., None]).cpu().item():.4f}\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f241abc6",
   "metadata": {},
   "source": [
    "## License\n",
    "\n",
    "This work is licensed under Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International. To view a copy of this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dml",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
