{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Lab: Dataloader from Ensembled Dataset\n",
    "\n",
    "For time constraints, we will not code a full PPO algorithm today. Instead, we will implement just one important piece, which will have more general purpose:\n",
    "\n",
    "If you generate data in a POMDP for example, you end up with a couple of distinct tensors such as: observations, actions, rewards and termination signals. As these can have different shapes and datatypes, you cannot represent them as one tensor.\n",
    "\n",
    "Thus, a convenient approach is to represent the dataset as a dictionary with tensor values. Then a minibatch should be a dictionary with the same keys, but sampled values.\n",
    "\n",
    "Moreover, given a policy ensemble for example, we have to deal with datasets the values of which are ensembled themselves. We will update our dataloader to include these features."
   ]
  },
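  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is a minimal, non-ensembled sketch of this representation. The keys `\"observations\"`, `\"actions\"`, `\"rewards\"` and `\"dones\"`, as well as their shapes and values, are hypothetical and only serve as an illustration. Because one shared index tensor is applied to every value, the sampled entries stay aligned across keys, and the minibatch has the same keys as the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Hypothetical toy dataset of 5 transitions; keys, shapes and values are illustrative only.\n",
    "toy_dataset = {\n",
    "    \"observations\": torch.arange(15, dtype=torch.float32).reshape(5, 3),  # shape (5, 3)\n",
    "    \"actions\": torch.tensor([0, 1, 1, 0, 1]),  # shape (5,)\n",
    "    \"rewards\": torch.tensor([0.0, 0.0, 1.0, 0.0, 1.0]),  # shape (5,)\n",
    "    \"dones\": torch.tensor([False, False, True, False, True])  # shape (5,)\n",
    "}\n",
    "\n",
    "# One shared index tensor selects the same transitions from every value.\n",
    "toy_indices = torch.tensor([4, 0, 2])\n",
    "toy_minibatch = {\n",
    "    key: value[toy_indices]\n",
    "    for key, value in toy_dataset.items()\n",
    "}\n",
    "\n",
    "for key, value in toy_minibatch.items():\n",
    "    print(key, tuple(value.shape), value.dtype)"
   ]
  },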
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "### Imports\n",
    "\n",
    "Import `Generator`, `torch`, `Optional` and the function `get_random_reshuffler`, that you wrote in Notebook 0221."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections.abc import Generator\n",
    "import torch\n",
    "from typing import Optional\n",
    "\n",
    "from util_0328 import get_random_reshuffler"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Configurations\n",
    "\n",
    "Create a configuration dictionary with the following keys:\n",
    "- `\"dataset_size\"`: `int`  \n",
    "    The size per ensemble entry of the synthetic dataset we'll generate. Make this `20`.\n",
    "- `\"ensemble_shape\"`: `tuple[int]`  \n",
    "    Make this `(4,)`.\n",
    "- `\"minibatch_size\"`: `int`  \n",
    "    Make this `3`.\n",
    "- `\"seed\"`: `int`  \n",
    "    This is for reproducible experiments. Insert any integer.\n",
    "- `\"upper\"`: `int`\n",
    "    The entries in the `\"numbers\"` tensor of our synthetic dataset will be random integers between 0 (inclusive) and this number (exclusive). Make it `256`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = {\n",
    "    \"dataset_size\": 20,\n",
    "    \"ensemble_shape\": (4,),\n",
    "    \"minibatch_size\": 3,\n",
    "    \"seed\": 0,\n",
    "    \"upper\": 256\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set the `torch` pseudo-random number generator as per the configuration dictionary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(config[\"seed\"])"
   ]
  },
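  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Seeding makes the pseudo-random draws reproducible: re-seeding with the same value replays the same sequence. The check below illustrates this with separate, locally created `torch.Generator` objects (an illustrative choice, so that the check does not advance the global generator seeded above); `torch.randint` accepts such an object through its `generator` keyword argument. The helper name `draw_with_seed` is hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "def draw_with_seed(seed: int) -> torch.Tensor:\n",
    "    # A fresh local generator, independent of the global one seeded above.\n",
    "    generator = torch.Generator().manual_seed(seed)\n",
    "    return torch.randint(high=256, size=(2, 3), generator=generator)\n",
    "\n",
    "\n",
    "same_seed_equal = torch.equal(draw_with_seed(0), draw_with_seed(0))\n",
    "print(same_seed_equal)  # True"
   ]
  },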
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset Generation\n",
    "\n",
    "1. Write the following two functions, that generate a synthetic dataset and test that it is as intended. The latter function will also be used to check that the minibatches are correct.\n",
    "2. Generate a dataset.\n",
    "3. Run the test on it. It should pass.\n",
    "4. Change an entry in the dataset.\n",
    "5. Run the test on it. It should fail.\n",
    "6. Redo steps 2 and 3."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_dataset(\n",
    "    config: dict\n",
    ") -> dict:\n",
    "    \"\"\"\n",
    "    Generates a synthetic dataset to be used for testing our new dataloader.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    config : `dict`\n",
    "        Configuration dictionary. Required key-value pairs:\n",
    "        `\"dataset_size\"` : `int`\n",
    "            The size of the dataset per ensemble member.\n",
    "        `\"ensemble_shape\"` : `tuple[int]`\n",
    "            Ensemble shape. We expect it to be 1-dimensional.\n",
    "        `\"upper\"` : `int`\n",
    "            The entries in the `\"numbers\"` tensors of the dataset\n",
    "            are random integers between `0` (inclusive)\n",
    "            and this value (exclusive).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    A dictionary with the following key-value pairs:\n",
    "    `\"numbers\"` : `torch.Tensor`\n",
    "        An integer tensor of shape `ensemble_shape + (dataset_size, 2)`\n",
    "        of entries as specified in `\"upper\"`.\n",
    "    `\"plus\"` : `torch.Tensor`\n",
    "        An integer tensor of shape `ensemble_shape + (dataset_size,)`\n",
    "        where the entry of index pair `(i,j)` is\n",
    "        ```\n",
    "        i * (numbers[i,j,0] + numbers[i,j,1])\n",
    "        ```\n",
    "    `\"minus\"` : `torch.Tensor`\n",
    "        An integer tensor of shape `ensemble_shape + (dataset_size,)`\n",
    "        where the entry of index pair `(i,j)` is\n",
    "        ```\n",
    "        i * (numbers[i,j,0] - numbers[i,j,1])\n",
    "        ```\n",
    "    \"\"\"\n",
    "    ensemble_shape = config[\"ensemble_shape\"]\n",
    "    ensemble_dim = len(ensemble_shape)\n",
    "    if ensemble_dim != 1:\n",
    "        raise ValueError(f\"The number of dimensions in `config['ensemble_shape']` should be 1 but it is {ensemble_dim}\")\n",
    "    \n",
    "    dataset_size = config[\"dataset_size\"]\n",
    "    ensemble_num = ensemble_shape[0]\n",
    "\n",
    "    numbers = torch.randint(\n",
    "        high=config[\"upper\"],\n",
    "        size=(ensemble_num, dataset_size, 2)\n",
    "    )\n",
    "    plus = (\n",
    "        torch.arange(ensemble_num)[:, None]\n",
    "      * (\n",
    "            numbers[..., 0]\n",
    "          + numbers[..., 1]  \n",
    "        )\n",
    "    )\n",
    "    minus = (\n",
    "        torch.arange(ensemble_num)[:, None]\n",
    "      * (\n",
    "            numbers[..., 0]\n",
    "          - numbers[..., 1]  \n",
    "        )\n",
    "    )\n",
    "\n",
    "    return {\n",
    "        \"numbers\": numbers,\n",
    "        \"plus\": plus,\n",
    "        \"minus\": minus\n",
    "    }\n",
    "\n",
    "\n",
    "def test_dataset(\n",
    "    config: dict,\n",
    "    dataset: dict\n",
    "):\n",
    "    \"\"\"\n",
    "    Tests if `dataset` is a valid output of `get_dataset`.\n",
    "    \"\"\"\n",
    "    ensemble_shape = config[\"ensemble_shape\"]\n",
    "    ensemble_dim = len(ensemble_shape)\n",
    "    if ensemble_dim != 1:\n",
    "        raise ValueError(f\"The number of dimensions in `config['ensemble_shape']` should be 1 but it is {ensemble_dim}\")\n",
    "    \n",
    "    dataset_size = config[\"dataset_size\"]\n",
    "    ensemble_num = ensemble_shape[0]\n",
    "    upper = config[\"upper\"]\n",
    "\n",
    "    numbers, plus, minus = (\n",
    "        dataset[key]\n",
    "        for key in [\"numbers\", \"plus\", \"minus\"]\n",
    "    )\n",
    "        \n",
    "    numbers_invalid = (numbers < 0) | (numbers >= upper)\n",
    "    if torch.any(numbers_invalid):\n",
    "        i, j, k = torch.nonzero(numbers_invalid)[0]\n",
    "        raise ValueError(f\"numbers[{i},{j},{k}] should be between 0 (inclusive) and {upper} exclusive, but it is {numbers[i,j,k]}\")\n",
    "    \n",
    "    plus_invalid = (\n",
    "        torch.arange(ensemble_num)[:, None]\n",
    "      * (\n",
    "            numbers[..., 0]\n",
    "          + numbers[..., 1]  \n",
    "        )\n",
    "    ) != plus\n",
    "    if torch.any(plus_invalid):\n",
    "        i, j = torch.nonzero(plus_invalid)[0]\n",
    "        raise ValueError(f\"plus[{i},{j}] should be {i} * (numbers[{i},{j},0] + numbers[{i},{j},1]) = {i * (numbers[i,j,0] + numbers[i,j,1])}, but it is {plus[i,j]}\")\n",
    "    \n",
    "    minus_invalid = (\n",
    "        torch.arange(ensemble_num)[:, None]\n",
    "      * (\n",
    "            numbers[..., 0]\n",
    "          - numbers[..., 1]  \n",
    "        )\n",
    "    ) != minus\n",
    "    if torch.any(minus_invalid):\n",
    "        i, j = torch.nonzero(minus_invalid)[0]\n",
    "        raise ValueError(f\"minus[{i},{j}] should be {i} * (numbers[{i},{j},0] - numbers[{i},{j},1]) = {i * (numbers[i,j,0] - numbers[i,j,1])}, but it is {minus[i,j]}\")\n",
    "    \n",
    "dataset = get_dataset(config)\n",
    "test_dataset(config, dataset)\n",
    "dataset[\"plus\"][0] = 1\n",
    "test_dataset(config, dataset)"
   ]
  },
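  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what the broadcasting in `get_dataset` computes, here is a hand-checkable example with 2 ensemble members and 2 samples each; the numbers are made up. `torch.arange(ensemble_num)[:, None]` has shape `(ensemble_num, 1)`, so it multiplies every sample of member `i` by `i`, and all `\"plus\"` and `\"minus\"` entries of member `0` are `0`. The last lines locate the first mismatch with `torch.nonzero(mask)[0]`, as `test_dataset` does."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Hypothetical numbers of shape (ensemble_num, dataset_size, 2) = (2, 2, 2).\n",
    "example_numbers = torch.tensor([\n",
    "    [[3, 5], [7, 1]],\n",
    "    [[2, 4], [9, 6]]\n",
    "])\n",
    "member_index = torch.arange(2)[:, None]  # shape (2, 1), broadcasts over the samples\n",
    "\n",
    "example_plus = member_index * (example_numbers[..., 0] + example_numbers[..., 1])\n",
    "example_minus = member_index * (example_numbers[..., 0] - example_numbers[..., 1])\n",
    "print(example_plus.tolist())  # [[0, 0], [6, 15]]\n",
    "print(example_minus.tolist())  # [[0, 0], [-2, 3]]\n",
    "\n",
    "# Corrupt one entry and locate the first mismatch, as test_dataset does.\n",
    "corrupted_plus = example_plus.clone()\n",
    "corrupted_plus[1, 0] = 0\n",
    "first_bad = torch.nonzero(corrupted_plus != example_plus)[0]\n",
    "print(first_bad.tolist())  # [1, 0]"
   ]
  },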
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = get_dataset(config)\n",
    "test_dataset(config, dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sampling minibatches from an ensembled dataset\n",
    "\n",
    "### Review `get_random_reshuffler`\n",
    "\n",
    "Recall that the `get_random_reshuffler` function outputs a generator that yields indices that have ensemble and dataset dimensions.\n",
    "\n",
    "1. Get a random reshuffler of\n",
    "    1. dataset size,\n",
    "    2. minibatch size and\n",
    "    3. ensemble shape\n",
    "\n",
    "    as per the configuration dictionary.\n",
    "2. Get an output and print it. Is it what you expect it to be?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "random_reshuffler = get_random_reshuffler(\n",
    "    config[\"dataset_size\"],\n",
    "    config[\"minibatch_size\"],\n",
    "    ensemble_shape=config[\"ensemble_shape\"]\n",
    ")\n",
    "indices = next(random_reshuffler)\n",
    "print(indices)"
   ]
  },
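  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, here is one way such a reshuffler could work. This is a hypothetical sketch, `sketch_random_reshuffler`, not the `get_random_reshuffler` from `util_0328`, whose exact behavior (for example, how it handles an incomplete last minibatch) may differ. In every epoch, it draws an independent permutation of the dataset indices for each ensemble member by sorting uniform random noise, and then yields consecutive slices of length `minibatch_size`, so each yielded index tensor has shape `ensemble_shape + (minibatch_size,)`. It uses a local `torch.Generator` so as not to advance the global one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections.abc import Generator\n",
    "from typing import Optional\n",
    "\n",
    "import torch\n",
    "\n",
    "\n",
    "def sketch_random_reshuffler(\n",
    "    dataset_size: int,\n",
    "    minibatch_size: int,\n",
    "    ensemble_shape: tuple[int, ...] = (),\n",
    "    generator: Optional[torch.Generator] = None\n",
    ") -> Generator[torch.Tensor, None, None]:\n",
    "    \"\"\"\n",
    "    Hypothetical sketch, not the `util_0328` implementation.\n",
    "    Yields index tensors of shape `ensemble_shape + (minibatch_size,)`\n",
    "    indefinitely. Within each epoch, indices are drawn without replacement,\n",
    "    independently for every ensemble member. An incomplete last minibatch\n",
    "    of an epoch is dropped.\n",
    "    \"\"\"\n",
    "    while True:\n",
    "        # Sorting i.i.d. uniform noise gives one random permutation per member.\n",
    "        noise = torch.rand(ensemble_shape + (dataset_size,), generator=generator)\n",
    "        permutations = noise.argsort(dim=-1)\n",
    "        for start in range(0, dataset_size - minibatch_size + 1, minibatch_size):\n",
    "            yield permutations[..., start:start + minibatch_size]\n",
    "\n",
    "\n",
    "sketch_reshuffler = sketch_random_reshuffler(\n",
    "    20, 3, ensemble_shape=(4,), generator=torch.Generator().manual_seed(0)\n",
    ")\n",
    "sketch_indices = next(sketch_reshuffler)\n",
    "print(sketch_indices.shape)  # torch.Size([4, 3])"
   ]
  },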
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `get_minibatch`\n",
    "\n",
    "Let us next write the function that given\n",
    "1. an ensembled dataset and\n",
    "2. ensembled indices,\n",
    "\n",
    "returns the appropriate minibatch as a dictionary.\n",
    "\n",
    "Note that the shapes of dataset values can have more dimensions than the shape of the index tensor. In this case, you need to\n",
    "1. reshape the index tensor so that its shape has an appropriate number of 1's on the right and\n",
    "2. expand the reshaped index tensor so that the extra 1's get turned to the dimensions in the shape of the dataset value tensor.\n",
    "\n",
    "After you transformed the index tensor like this, you can use the `gather` method of the dataset value tensors.\n",
    "\n",
    "After you wrote the function, call it on your dataset and the index tensor you just generated. Print the keys and the shapes of the values of the dictionary you get."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_minibatch(\n",
    "    dataset: dict,\n",
    "    indices: torch.Tensor\n",
    ") -> dict:\n",
    "    \"\"\"\n",
    "    Returns the minibatch of an ensembled dataset\n",
    "    given by ensembled indices.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    dataset: `dict`\n",
    "        A dataset given as a dictionary with `torch.Tensor` values.\n",
    "    indices: `torch.Tensor`\n",
    "        An index tensor for the dataset.\n",
    "        Each value in the dataset should have shape prefixed by\n",
    "        the shape of the index tensor.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The minibatch given by the dataset and the index tensor.\n",
    "    \"\"\"\n",
    "    return {\n",
    "        key: value.gather(\n",
    "            indices.dim() - 1,\n",
    "            indices.unflatten(\n",
    "                -1,\n",
    "                (-1,) + (1,) * (value.dim() - indices.dim())\n",
    "            ).expand(\n",
    "                indices.shape + value.shape[indices.dim():]\n",
    "            )\n",
    "        )\n",
    "        for key, value in dataset.items()\n",
    "    }\n",
    "\n",
    "minibatch = get_minibatch(dataset, indices)\n",
    "print(minibatch)"
   ]
  },
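  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the reshape, expand and gather steps concrete, the cell below traces them for a single hypothetical dataset value with one extra trailing dimension (2 ensemble members, 4 samples, 2 features per sample). It then compares the result with the advanced-indexing expression `example_value[torch.arange(2)[:, None], example_indices]`, which gives the same minibatch but, as written, only handles a 1-dimensional ensemble shape."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Hypothetical value of shape ensemble_shape + (dataset_size, 2) = (2, 4, 2).\n",
    "example_value = torch.arange(16).reshape(2, 4, 2)\n",
    "# Hypothetical indices of shape ensemble_shape + (minibatch_size,) = (2, 3).\n",
    "example_indices = torch.tensor([[3, 0, 1], [2, 2, 0]])\n",
    "\n",
    "# Step 1: append one trailing 1 per extra dimension: (2, 3) -> (2, 3, 1).\n",
    "extra_dims = example_value.dim() - example_indices.dim()\n",
    "reshaped = example_indices.unflatten(-1, (-1,) + (1,) * extra_dims)\n",
    "# Step 2: expand the trailing 1's to the extra dimensions: (2, 3, 1) -> (2, 3, 2).\n",
    "expanded = reshaped.expand(\n",
    "    example_indices.shape + example_value.shape[example_indices.dim():]\n",
    ")\n",
    "# Step 3: gather along the dataset dimension, which is dimension 1 here.\n",
    "gathered = example_value.gather(example_indices.dim() - 1, expanded)\n",
    "\n",
    "# The same minibatch via advanced indexing (written for a 1-D ensemble shape only).\n",
    "indexed = example_value[torch.arange(2)[:, None], example_indices]\n",
    "print(gathered.shape, torch.equal(gathered, indexed))  # torch.Size([2, 3, 2]) True\n",
    "print(gathered[1].tolist())  # [[12, 13], [12, 13], [8, 9]]"
   ]
  },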
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `is_ensembled`\n",
    "\n",
    "Our new dataloader should be able to handle both ensembled and not ensembled datasets. Recall that we view a tensor as ensembled if it is prefixed by the ensemble shape. Let's write a function that checks if this is the case. Test it on a minibatch value and the first entry of a dataset value."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def is_ensembled(\n",
    "    ensemble_shape: tuple[int],\n",
    "    tensor: torch.Tensor\n",
    ") -> bool:\n",
    "    \"\"\"\n",
    "    We view `tensor` as *ensembled* if it is prefixed by `ensemble_shape`,\n",
    "    that is its slice of the first `len(ensemble_shape)` entries\n",
    "    is `ensemble_shape`.\n",
    "\n",
    "    This function checks this condition.\n",
    "    \"\"\"\n",
    "    return tensor.shape[:len(ensemble_shape)] == ensemble_shape\n",
    "\n",
    "print(\n",
    "    is_ensembled(config[\"ensemble_shape\"], minibatch[\"numbers\"]),\n",
    "    is_ensembled(config[\"ensemble_shape\"], dataset[\"numbers\"][0])\n",
    ")"
   ]
  },
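  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because `is_ensembled` only compares shapes, it cannot distinguish an ensembled tensor from a non-ensembled one whose leading dimension happens to equal the ensemble size. The hypothetical case below has ensemble shape `(4,)` and a non-ensembled value with 4 samples: it is classified as ensembled and would therefore not be broadcast by the dataloader. With this notebook's configuration (dataset size `20`, ensemble shape `(4,)`), the non-ensembled values are not affected. The helper `shape_is_prefixed` is a hypothetical copy of the same check, so that the cell runs on its own."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "def shape_is_prefixed(ensemble_shape: tuple[int, ...], tensor: torch.Tensor) -> bool:\n",
    "    # The same shape-prefix check as `is_ensembled`, redefined so the cell is self-contained.\n",
    "    return tensor.shape[:len(ensemble_shape)] == ensemble_shape\n",
    "\n",
    "\n",
    "ensemble_shape_example = (4,)\n",
    "ensembled_value = torch.zeros(4, 20, 2)  # ensembled: 4 members, 20 samples each\n",
    "small_dataset_value = torch.zeros(4, 2)  # NOT ensembled: 4 samples, 2 features\n",
    "large_dataset_value = torch.zeros(20, 2)  # NOT ensembled: 20 samples, 2 features\n",
    "\n",
    "print(shape_is_prefixed(ensemble_shape_example, ensembled_value))  # True\n",
    "print(shape_is_prefixed(ensemble_shape_example, small_dataset_value))  # True, a false positive\n",
    "print(shape_is_prefixed(ensemble_shape_example, large_dataset_value))  # False"
   ]
  },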
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `get_dataloader_random_reshuffle`\n",
    "\n",
    "Time to write the new dataloader! To be able to use `get_minibatch`, you should create a new dataset with the same keys but values broadcast to ensembled tensors. You can perform this operation with the use of `is_ensemble`.\n",
    "\n",
    "Once you wrote the function, create a dataloader and get a minibatch. Check that the shapes of its values are correct and call `test_dataset` on it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_dataloader_random_reshuffle(\n",
    "    config: dict,\n",
    "    dataset: dict,\n",
    "    minibatch_size: Optional[int] = None\n",
    ") -> Generator[dict]:\n",
    "    \"\"\"\n",
    "    Given a dataset as a dictionary with tensor values,\n",
    "    creates a random reshuffling (without replacement) dataloader\n",
    "    that yields pairs `minibatch_features, minibatch_labels` indefinitely.\n",
    "    Support arbitrary ensemble shapes.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    config : `dict`\n",
    "        Configuration dictionary. Required key-value pair:\n",
    "        ensemble_shape : tuple[int]\n",
    "            The required ensemble shapes of the outputs.\n",
    "    dataset : `dict`\n",
    "        Dataset with `torch.Tensor` values.\n",
    "    minibatch_size : `int`, optional\n",
    "        Minibatch size. If not given, it is `config[\"minibatch_size\"]`.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    A generator of minibatch dictionaries.\n",
    "    \"\"\"\n",
    "    ensemble_shape = config[\"ensemble_shape\"]\n",
    "    ensemble_dim = len(ensemble_shape)\n",
    "    if minibatch_size is None:\n",
    "        minibatch_size = config[\"minibatch_size\"]\n",
    "\n",
    "    dataset = {\n",
    "        key: value.broadcast_to(\n",
    "            ensemble_shape * (not is_ensembled(ensemble_shape, value))\n",
    "          + value.shape\n",
    "        )\n",
    "        for key, value in dataset.items()\n",
    "    }\n",
    "    value = next(iter(dataset.values()))\n",
    "    dataset_size = value.shape[ensemble_dim]\n",
    "\n",
    "    random_reshuffler = get_random_reshuffler(\n",
    "        dataset_size,\n",
    "        minibatch_size,\n",
    "        device=value.device,\n",
    "        ensemble_shape=config[\"ensemble_shape\"]\n",
    "    )\n",
    "\n",
    "    for indices in random_reshuffler:\n",
    "        yield get_minibatch(dataset, indices)\n",
    "\n",
    "dataloader = get_dataloader_random_reshuffle(\n",
    "    config,\n",
    "    dataset\n",
    ")\n",
    "minibatch = next(dataloader)\n",
    "for key, value in minibatch.items():\n",
    "    print(key, value.shape)\n",
    "\n",
    "test_dataset(config, minibatch)"
   ]
  },
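  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Broadcasting a non-ensembled value with `broadcast_to`, as `get_dataloader_random_reshuffle` does, does not copy any data: the result is a view whose new ensemble dimension has stride `0`, so all ensemble members read the same memory (and in-place writes to such a view should be avoided). The hypothetical check below confirms this for a value of 20 entries and ensemble shape `(4,)`. Since `shared_value[j] == j`, gathering per-member indices from the broadcast value returns the indices themselves."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "ensemble_shape_demo = (4,)\n",
    "shared_value = torch.arange(20)  # hypothetical non-ensembled value, shape (20,)\n",
    "\n",
    "# broadcast_to returns a view: the new ensemble dimension has stride 0.\n",
    "broadcast_value = shared_value.broadcast_to(ensemble_shape_demo + tuple(shared_value.shape))\n",
    "print(broadcast_value.shape)  # torch.Size([4, 20])\n",
    "print(broadcast_value.stride())  # (0, 1)\n",
    "print(broadcast_value.data_ptr() == shared_value.data_ptr())  # True: same storage\n",
    "\n",
    "# Every member reads the same data, so gathering per-member indices\n",
    "# from the broadcast value returns the indices themselves here.\n",
    "member_indices = torch.tensor([[0, 1], [5, 6], [10, 11], [18, 19]])\n",
    "print(torch.equal(broadcast_value.gather(1, member_indices), member_indices))  # True"
   ]
  },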
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# License\n",
    "\n",
    "This work is licensed under Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International. To view a copy of this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dml",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
