{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# REINFORCE\n",
    "## Setup\n",
    "### Imports\n",
    "\n",
    "### Imports\n",
    "\n",
    "1. Import `defaultdict`, `Iterable`, `gymnasium` as `gym`, `IPython.display.Video`, `ImageSequenceClip` from `moviepy`, `os`, `torch`, `tqdm` and `Optional`.\n",
    "2. Import\n",
    "    1. `get_seed`, that you wrote in Notebook 0228,\n",
    "    2. `get_mlp`, that you wrote in Notebook 0319,\n",
    "    3. `welch_one_sided`, that you wrote in Notebook 0321 and\n",
    "    4. `AdamW` and `Optimizer` that you wrote in Notebook 0328."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Configuration\n",
    "\n",
    "Create a configuration dictionary with the following keys:\n",
    "- `\"device\"`: `torch.device | int | str`  \n",
    "    The device identifier.\n",
    "- `\"ensemble_shape\"`: `tuple[int]`  \n",
    "    Today, we'll use PBT with population size 16. Thus, make this `(16,)`.\n",
    "- `\"env_id\"`: `str`  \n",
    "    The `gym` environment identifier. You can see on its page:  \n",
    "    https://gymnasium.farama.org/environments/classic_control/cart_pole/  \n",
    "    that in the case of Cart Pole, this is `\"CartPole-v1\"`.\n",
    "- `\"env_kwargs\"`: `dict`  \n",
    "    This dictionary stores extra settings in the environment. In the case of Cart Pole, set this to an empty dictionary.\n",
    "- `\"eval_interval\"`: `int`  \n",
    "    The frequency of evaluations and PBT updates in terms of train steps. Make this `100`.\n",
    "- `\"hyperparameter_raw_init_distributions\"`, `\"hyperparameter_raw_perturb\"`, `\"hyperparameter_transforms\"` : `dict`  \n",
    "    Once again, these three dictionaries are going to determine how the hyperparameters are tuned. We'll tune the following hyperparameters:\n",
    "    1. Epsilon $\\epsilon$.\n",
    "    2. Learning rate $\\eta$.\n",
    "    3. Weight decay $\\lambda$.\n",
    "    4. First moment moving average decay rate $\\beta_1$.\n",
    "    5. Second moment moving average decay rate $\\beta_2$.\n",
    "\n",
    "    Of these, we don't know the required order of magnitude of the first three. Thus it may be good to make them distributed along $10^\\mathscr D$ where $\\mathscr D$ is a normal or uniform distribution. You can try to center the distributions at the recommended values.\n",
    "\n",
    "    We know that the recommended values of the fourth and fifth are $0.9$ and $0.999$. So it may be best to give them a distribution of the form $1-10^\\mathscr D$.\n",
    "- `\"improvement_threshold:`: `float`  \n",
    "    Make this `1e-4`.\n",
    "- `\"seed\"`: `int`  \n",
    "    This is for reproducible experiments. Insert any integer.\n",
    "- `\"steps_num\"`: `int`  \n",
    "    Make this a `10_000`. Let early stopping take care of stopping.\n",
    "- `\"steps_without_improvement`: `int`  \n",
    "    Make this `1000`.\n",
    "- `\"videos_dictionary\"`: `str`\n",
    "    The path to the directory to store videos at. I set this to `videos`. Change it at will.\n",
    "- `\"welch_confidence_level\"`: `float`  \n",
    "    We will exploit based on a one-sided Welch $t$-test with this confidence level. As the PBT paper does not specify this, let's start with `.95`. Feel free to try out various values here!\n",
    "- `\"welch_sample_size\"`: `int`  \n",
    "    We will exploit based on a one-sided Welch $t$-test on the last this many validation metrics of the population members. To follow the PBT paper, make this `10`."
   ]
  },
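  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of the two distribution shapes suggested above, the sketch below samples raw values and transforms them. The centers, ranges and variable names are hypothetical choices made for this example, not values prescribed by this notebook.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "population_size = 16\n",
    "\n",
    "# Hypothetical raw distributions; the numbers are placeholders.\n",
    "raw_learning_rate = torch.distributions.Normal(-3.0, 1.0).sample((population_size,))\n",
    "raw_beta1 = torch.distributions.Uniform(-3.0, -1.0).sample((population_size,))\n",
    "\n",
    "# 10 ** D: always positive, spread over several orders of magnitude.\n",
    "learning_rate = 10 ** raw_learning_rate\n",
    "# 1 - 10 ** D: D = -1 gives 0.9 and D = -3 gives 0.999.\n",
    "beta1 = 1 - 10 ** raw_beta1\n",
    "\n",
    "print(learning_rate.min(), learning_rate.max())\n",
    "print(beta1.min(), beta1.max())\n",
    "```\n",
    "\n",
    "Sampling in the exponent means that an additive perturbation of the raw value later acts as a multiplicative change of the transformed hyperparameter."
   ]
  },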
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Initialize the `torch` pseudo-random number generation as per the configuration dictionary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Making a Video on a Single Environment\n",
    "\n",
    "### Create and Test Single Environment\n",
    "\n",
    "1. Create a single environment as per the configuration dictionary.\n",
    "2. Print its action and observation spaces.\n",
    "3. Reset it (don't forget to `get_seed`) and print the observation you get."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see that:\n",
    "1. The action space is a discrete space of 2 entries and\n",
    "2. The observation space is a continuous space of 4 dimensions.\n",
    "    1. The first 2 attributes you see are the lower and upper boundaries of the 4 components.\n",
    "    2. You can get the shape of observations as the `shape` attribute of the observation space.\n",
    "\n",
    "Print the number of actions and the shape of the observation space."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Getting Logits\n",
    "\n",
    "1. Create an MLP with:\n",
    "    1. Number of input features the number of observation dimensions of the environment.\n",
    "    2. Number of output features the number of possible actions.\n",
    "    3. 2 hidden layers with 128 dimensions each.\n",
    "2. Convert the observation vector to a `torch.Tensor` of device as per the configuration dictionary and datatype `torch.float32`.\n",
    "3. Fill in the function `get_logits` below:\n",
    "    1. Recall that our `Linear` layers expect input tensors to have at least 1 batch dimension. Therefore, you need to give the observation vector a batch dimension of 1.\n",
    "    2. Recall that given an input tensor of shape either `ensemble_shape + batch_shape + (in_features,)` or `batch_shape + (in_features,)`, the MLP will output a tensor of shape `ensemble_shape + batch_shape + (out_features,)`. As in our case, the batch shape is an auxiliary one, let's remove it by indexing or reshaping before we return the logit tensor.\n",
    "4. Use `get_logits` to get a logit tensor from the initial observation. Is its shape what you would expect?"
   ]
  },
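  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The shape bookkeeping around the auxiliary batch dimension can be sketched on plain tensors. Here a stack of zero weight matrices stands in for the ensemble MLP; this stand-in and all variable names are assumptions made for illustration, and your `get_mlp` model may behave differently in detail.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "ensemble_shape = (16,)\n",
    "observation = torch.zeros(4)  # a single observation without batch dimension\n",
    "\n",
    "# Add an auxiliary batch dimension of size 1: (4,) -> (1, 4).\n",
    "batched = observation.unsqueeze(-2)\n",
    "\n",
    "# Stand-in for the ensemble model: one (4, 2) weight matrix per member.\n",
    "weights = torch.zeros(ensemble_shape + (4, 2))\n",
    "logits = batched @ weights  # broadcasts to shape (16, 1, 2)\n",
    "\n",
    "# Remove the auxiliary batch dimension again: (16, 1, 2) -> (16, 2).\n",
    "logits = logits.squeeze(-2)\n",
    "print(logits.shape)  # torch.Size([16, 2])\n",
    "```"
   ]
  },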
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_logits(\n",
    "    config: dict,\n",
    "    observation: torch.Tensor,\n",
    "    policy: torch.nn.Module\n",
    ") -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    In the context of a POMDP\n",
    "    with continuous observation space and finite action space,\n",
    "    given an observation tensor and a policy model,\n",
    "    output the unnormalized next action logits as per the model.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    config : `dict`\n",
    "        Configuration dictionary. Required key-value pair:\n",
    "        `\"ensemble_shape\"` : `tuple[int]`\n",
    "            Ensemble shape.\n",
    "    observation : `torch.Tensor`\n",
    "        A tensor of a single or multiple observations of shape\n",
    "        `batch_shape + (observation_dim,)` or\n",
    "        `ensemble_shape + batch_shape + (observation_dim,)`\n",
    "    policy : `torch.nn.Module`\n",
    "        A policy model that outputs unnormalized next action logits.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The tensor of unnormalized next action logits.\n",
    "    \"\"\"\n",
    "    raise NotImplementedError\n",
    "\n",
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `make_video`\n",
    "\n",
    "Based on the function `run_episode` that you created in Notebook 0305, write the function `make_video` as per the docstring.\n",
    "\n",
    "We make a video following a given ensemble member as the idea is that after each evaluation, we make a video following the best performing policy.\n",
    "\n",
    "After you updated the function, make a video using it by following a member of the randomly initialized policy ensemble.\n",
    "\n",
    "Print the discounted and undiscounted return and play the video using `Video`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_video(\n",
    "    config: dict,\n",
    "    policy: torch.nn.Module,\n",
    "    ensemble_id: Optional[int] = 0,\n",
    "    fps: Optional[int] = None,\n",
    "    video_name=\"test.mp4\",\n",
    ") -> tuple[float, float, str]:\n",
    "    \"\"\"\n",
    "    Given a POMDP with continuous observation space and finite action space\n",
    "    and an ensemble of policies as models that output\n",
    "    unnormalized action logits,\n",
    "    make a video of an episode following an ensemble member.\n",
    "\n",
    "    Here, we follow the policy deterministically,\n",
    "    that is we choose the action with the highest logit.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    config : `dict`\n",
    "        Configuration dictionary. Required key-value pairs:\n",
    "        `\"device\"` : `torch.device | int | str`\n",
    "            The device the policy model is stored on.\n",
    "        `\"discount\"` : `float`\n",
    "            Discount to use when calculating the discounted return.\n",
    "        `\"ensemble_shape\"` : `tuple[int]`\n",
    "            Ensemble shape of the policy model.\n",
    "        `\"env_id\"` : `str`\n",
    "            The ID of the environment in the `gym` registry.\n",
    "        `\"env_kwargs\"` : `dict`\n",
    "            Additional arguments of the environment.\n",
    "        `\"videos_directory\"` : `str`\n",
    "            Path to the directory to save the video to.\n",
    "    policy : `torch.nn.Module`\n",
    "        The policy model.\n",
    "    ensemble_id : `int`, optional\n",
    "        The ID of the ensemble member to follow. Default: 0\n",
    "    fps : `int`, optional\n",
    "        Frames per second in the video.\n",
    "        If not given, we use the default given in the environment,\n",
    "        at the `\"render_fps\"` key of its `metadata` attribute.\n",
    "    video_name : `str`\n",
    "        The of the video to create.\n",
    "        Its extension determines the video format. Default: \"test.mp4\"\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The triple of:\n",
    "    1. The discounted return.\n",
    "    2. The undiscounted return.\n",
    "    3. The path to the video.\n",
    "    \"\"\"\n",
    "    raise NotImplementedError\n",
    "\n",
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Getting Vectorized Episode Data\n",
    "\n",
    "### Create and Test Vector Environment\n",
    "\n",
    "1. Create a vectorized environment with as many environments as the first (and only) dimension in the ensemble shape.\n",
    "2. Print the shape of its observation space.\n",
    "3. Get an initial observation by resetting the environment. Check that the shape of the initial observation is as expected.\n",
    "4. Get logits from the initial observation. Print their shape too."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Using the Gumbel-Max Trick\n",
    "\n",
    "1. Create a standard Gumbel distribution as a `torch.distributions.Gumbel`. Make sure to give the mean and std with device as per the configuration dictionary.\n",
    "2. Get a sample of the sample shape as the logits.\n",
    "3. Add the logits and the sample and take argmax along the last dimension. This is a sample of actions as per the policy ensemble at the initial observation. Print the sample of actions. Are its shape and its values what you expect?"
   ]
  },
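  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why the Gumbel-Max trick samples from the softmax distribution, you can compare empirical action frequencies with `softmax` on toy logits. This is only a sanity-check sketch; the logits and the sample count are made up for the example.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "\n",
    "logits = torch.tensor([1.0, 0.0, -1.0])  # toy unnormalized logits\n",
    "gumbel = torch.distributions.Gumbel(\n",
    "    torch.tensor(0.0), torch.tensor(1.0)  # standard Gumbel: loc 0, scale 1\n",
    ")\n",
    "\n",
    "# Draw many independent noise vectors and take the argmax of each noisy copy.\n",
    "noise = gumbel.sample((100_000,) + logits.shape)\n",
    "actions = (logits + noise).argmax(dim=-1)\n",
    "\n",
    "# The empirical frequencies should be close to softmax(logits).\n",
    "frequencies = torch.bincount(actions, minlength=3) / actions.numel()\n",
    "expected = torch.softmax(logits, dim=-1)\n",
    "print(frequencies)\n",
    "print(expected)\n",
    "```"
   ]
  },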
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `get_episode_data`\n",
    "\n",
    "Time to write the function to generate vectorized episode data. The objective of this is twofold:\n",
    "1. If we supply a Gumbel distribution, `get_episode_data` should output vectorized episode data following sampling actions from the action distributions as per the policy ensemble. This can be used for training steps.\n",
    "2. If we do not supply a Gumbel distribution, `get_episode_data` should output vectorized episode data following the deterministic policies, that is when we take the actions with the highest logits. This can be used for evaluating the policy ensemble members.\n",
    "\n",
    "1. Write the function described below.\n",
    "    1. Create lists to collect actions, observations and rewards step by step.\n",
    "    2. Convert the `np.ndarray` output by the `gym.vector.VectorEnv` to `torch.Tensor`.\n",
    "        1. Use the device as per the configuration dictionary for all three.\n",
    "        2. For actions, use the `torch.int64` datatype. For the others, use `torch.float32`.\n",
    "    3. Don't forget to start the observation list by the initial observation.\n",
    "    4. Just like in the `evaluate_q_values` function you wrote in Notebook 0307, you need to keep track of what episode is still ongoing.\n",
    "        1. First, add the rewards multiplied by the old ongoing mask to the reward list.\n",
    "        2. Then update the ongoing mask.\n",
    "2. Run it twice:\n",
    "    1. once with sampled and \n",
    "    2. once with deterministic actions.\n",
    "3. Note that adding up the rewards along the sequence dimension gives you the undiscounted returns per ensemble member. Compare the means of these."
   ]
  },
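  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The order of the two mask operations matters: the step that ends an episode still earns its reward. A toy sketch with hand-written step data, where the `step_done` flags are a hypothetical stand-in for what the vector environment reports:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "# Toy data: 3 steps, 2 environments; each row holds one step's output.\n",
    "step_rewards = torch.ones(3, 2)\n",
    "step_done = torch.tensor([\n",
    "    [False, False],\n",
    "    [True, False],  # environment 0 ends at step 1\n",
    "    [False, True],  # environment 1 ends at step 2\n",
    "])\n",
    "\n",
    "ongoing = torch.ones(2, dtype=torch.bool)\n",
    "rewards = []\n",
    "for reward, done in zip(step_rewards, step_done):\n",
    "    # Mask with the OLD ongoing flags first ...\n",
    "    rewards.append(reward * ongoing)\n",
    "    # ... and only then mark finished episodes as no longer ongoing.\n",
    "    ongoing = ongoing & ~done\n",
    "\n",
    "rewards = torch.stack(rewards, dim=1)  # shape (environments, steps)\n",
    "print(rewards)\n",
    "# tensor([[1., 1., 0.],\n",
    "#         [1., 1., 1.]])\n",
    "```\n",
    "\n",
    "Environment 0 keeps the reward of its final step 1 and gets 0 afterwards."
   ]
  },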
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_episode_data(\n",
    "    config: dict,\n",
    "    env: gym.vector.VectorEnv,\n",
    "    policy: torch.nn.Module,\n",
    "    gumbel: Optional[torch.distributions.Gumbel] = None\n",
    ") -> dict:\n",
    "    \"\"\"\n",
    "    Given a vectorized POMDP and a policy model,\n",
    "    run one episode per environment in parallel\n",
    "    and output actions, observations and rewards.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    config : `dict`\n",
    "        Configuration dictionary. Required key-value pairs:\n",
    "        `\"device\"` : `torch.device | int | str`\n",
    "            The device the policy model is stored on.\n",
    "        `\"ensemble_shape\"` : `tuple[int]`\n",
    "            Ensemble shape of the policy model.\n",
    "    env : `gym.vector.VectorEnv`\n",
    "        The vectorized environment.\n",
    "    policy : `torch.nn.Module`\n",
    "        The policy model.\n",
    "    gumbel : `torch.distributions.Gumbel`, optional\n",
    "        If given, actions are sampled as per their logits,\n",
    "        using the Gumbel-Max trick. Otherwise, the actions\n",
    "        with the largest logits are taken.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    A dictionary with keys `\"actions\"`, `\"observations\"` and `\"rewards\"`\n",
    "    that stores episode data as tensors:\n",
    "    1. Their first two dimensions are\n",
    "        1. the number of environments and\n",
    "        2. the maximum number of steps in the episodes.\n",
    "    2. The reward where an episode has already ended is 0.\n",
    "        1. The actions and observations at these positions are irrelevant.\n",
    "    \"\"\"\n",
    "    raise NotImplementedError\n",
    "\n",
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Getting the Discounted Return Tensor\n",
    "\n",
    "Given a vector of rewards\n",
    "$$\n",
    "\\mathbf r = \\begin{pmatrix} r_1 & r_2 & \\dotsb & r_T \\end{pmatrix},\n",
    "$$\n",
    "we'd like to create the vector of discounted returns\n",
    "$$\n",
    "\\mathbf g = \\begin{pmatrix} g_0 & g_1 & \\dotsb & g_{T-1}\\end{pmatrix}.\n",
    "$$\n",
    "A vectorized way to go about this is\n",
    "$$\n",
    "\\mathbf g=\\mathbf r\\cdot\\Gamma\n",
    "$$\n",
    "with\n",
    "$$\n",
    "\\Gamma = \\begin{pmatrix}\n",
    "1 & 0 & 0 & 0 & \\dotsb & 0 \\\\\n",
    "\\gamma & 1 & 0 & 0 & \\dotsb & 0 \\\\\n",
    "\\gamma^2 & \\gamma & 1 & 0 & \\dotsb & 0 \\\\\n",
    "\\vdots & \\vdots & \\vdots & \\vdots & \\ddots & \\vdots \\\\\n",
    "\\gamma^{T-1} & \\gamma^{T-2} & \\gamma^{T-3} & \\gamma^{T-4} & \\dotsb & 1\n",
    "\\end{pmatrix},\n",
    "$$\n",
    "where $\\gamma$ is the discount.\n",
    "\n",
    "1. Write the function below.\n",
    "2. From the matrix of rewards in the deterministic action episode data, get the matrix of discounted returns.\n",
    "3. Print the vector of discounted returns for ensemble member 0. Given that in Cart Pole, each step gives a reward of 1, is the vector what you expect?"
   ]
  },
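  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For checking your vectorized implementation, a slow reference based on the recursion $g_t = r_{t+1} + \\gamma g_{t+1}$ can help. The helper name `discounted_returns_loop` is made up for this sketch and is not part of the notebook's API.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "\n",
    "def discounted_returns_loop(rewards: torch.Tensor, discount: float) -> torch.Tensor:\n",
    "    # Walk backwards along the last (sequence) dimension,\n",
    "    # accumulating reward + discount * (return of the next step).\n",
    "    returns = torch.zeros_like(rewards)\n",
    "    running = torch.zeros_like(rewards[..., 0])\n",
    "    for t in reversed(range(rewards.shape[-1])):\n",
    "        running = rewards[..., t] + discount * running\n",
    "        returns[..., t] = running\n",
    "    return returns\n",
    "\n",
    "\n",
    "example = discounted_returns_loop(torch.ones(3), 0.5)\n",
    "print(example)  # tensor([1.7500, 1.5000, 1.0000])\n",
    "```\n",
    "\n",
    "With $T=3$, $\\gamma=0.5$ and all rewards equal to 1, the matrix formula gives the same values: $(1+\\gamma+\\gamma^2,\\ 1+\\gamma,\\ 1) = (1.75,\\ 1.5,\\ 1)$."
   ]
  },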
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_discounted_returns(\n",
    "    config: dict,\n",
    "    rewards: torch.Tensor\n",
    ") -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    Given a reward vector `(r_1, r_2, ..., r_T)`\n",
    "    or a batch of such vectors,\n",
    "    output the corresponding discounted return vector\n",
    "    `(g_0, g_1, ..., g_{T-1})` or the batch of such vectors.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    config : `dict`\n",
    "        Configuration dictionary. Required key-value pairs:\n",
    "        `\"device\"` : `torch.device | int | str`\n",
    "            The device the policy model is stored on.\n",
    "        `\"discount\"` : `float`\n",
    "            Discount value.\n",
    "    rewards : torch.Tensor\n",
    "        Reward tensor.\n",
    "        The last dimension is viewed as the sequence dimension.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The discounted return tensor.\n",
    "    \"\"\"\n",
    "    raise NotImplementedError\n",
    "\n",
    "raise NotImplementedError\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Writing the Training Loop\n",
    "\n",
    "### `pbt_init` and `pbt_update`\n",
    "\n",
    "Time to write the training loop! First of all, let's separate the PBT initialization and update functionalities off the `pbt` function you have in Notebook 0326.\n",
    "\n",
    "1. Write the functions below.\n",
    "2. Create a copy of the configuration dictionary and a log dictionary.\n",
    "3. Run `pbt_init` with them.\n",
    "4. Print the values at the modified configuration dictionary of a raw and transformed hyperparameter.\n",
    "5. Print the log dictionary.\n",
    "4. Create an `AdamW` optimizer for the parameters of the policy model.\n",
    "5. Update the optimizer by the PBT initialized configuration dictionary.\n",
    "6. Print the `config` attribute of the optimizer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pbt_init(\n",
    "    config: dict,\n",
    "    log: dict\n",
    "):\n",
    "    \"\"\"\n",
    "    Initializes Population Based Training.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    config : `dict`\n",
    "        Configuration dictionary. Required key-value pairs:\n",
    "        `\"ensemble_shape\"` : tuple[int]\n",
    "            Ensemble shape. We assume this is a 1-dimensional tuple\n",
    "            with dimensions the population size.\n",
    "        `\"hyperparameter_raw_init_distributions\"` : `dict`\n",
    "            Dictionary that maps tuned hyperparameter names\n",
    "            to `torch.distributions.Distribution` of raw hyperparameter values.\n",
    "        `\"hyperparameter_transforms\"` : `dict`\n",
    "            Dictionary that maps tuned hyperparameter names\n",
    "            to transformations of raw hyperparameter values.\n",
    "    log : `defauldict(list)`\n",
    "        Training log dictionary.\n",
    "\n",
    "    Updates\n",
    "    -------\n",
    "    For each `key in config[\"hyperparameter_raw_init_distributions\"]`:\n",
    "    1. It samples raw hyperparameter values\n",
    "        and updates `config[key + \"_raw\"]` by them.\n",
    "    2. It applies `config[\"hyperparameter_transforms\"][key]`\n",
    "        to the raw hyperparameter values and\n",
    "        1. updates `config[key]` by them and\n",
    "        2. appends them to `log[key]`.\n",
    "    \"\"\"\n",
    "    raise NotImplementedError\n",
    "\n",
    "\n",
    "def pbt_update(\n",
    "    config: dict,\n",
    "    evaluations: torch.Tensor,\n",
    "    log: dict,\n",
    "    parameters: Iterable[torch.nn.Parameter]\n",
    "):\n",
    "    \"\"\"\n",
    "    Performs a Population Based Training update\n",
    "    with exploitation determined by one-sided Welch's t-tests.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    config : `dict`\n",
    "        Configuration dictionary. Required key-value pairs:\n",
    "        `\"ensemble_shape\"` : tuple[int]\n",
    "            Ensemble shape. We assume this is a 1-dimensional tuple\n",
    "            with dimensions the population size.\n",
    "        `\"hyperparameter_raw_perturbs\"` : `dict`\n",
    "            Dictionary that maps tuned hyperparameter names\n",
    "            to `torch.distributions.Distribution` of additive noise.\n",
    "        `\"hyperparameter_transforms\"` : `dict`\n",
    "            Dictionary that maps tuned hyperparameter names\n",
    "            to transformations of raw hyperparameter values.\n",
    "        `\"welch_confidence_level\"` : `float`\n",
    "            The confidence level in Welch's t-test\n",
    "            that is used in determining if a population member\n",
    "            is to be replaced by another member with perturbed hyperparameters.\n",
    "        `\"welch_sample_size\"` : `int`\n",
    "            The last this many validation metrics are used\n",
    "            in Welch's t-test.\n",
    "    evaluations : `torch.Tensor`\n",
    "        Tensor of evaluations. We assume it has shape\n",
    "        `(welch_sample_size, population_size)`.\n",
    "    log : `defauldict(list)`\n",
    "        Training log dictionary.\n",
    "    parameters : `Iterable[torch.nn.Parameter]`\n",
    "        Iterable of parameters to update.\n",
    "\n",
    "    Updates\n",
    "    -------\n",
    "    1. For each population entry, a target index is drawn,\n",
    "        the index of another population entry to compare evaluations with.\n",
    "    2. The entries are then compared with the target entries\n",
    "        via a one-sided Welch's t-test.\n",
    "    3. We get a mask of population entries\n",
    "        such that the hypothesis that the corresponding entry at the target\n",
    "        index has better expected evaluations cannot be rejected.\n",
    "    4. The indices and masks are appended to the\n",
    "        `\"source mask\"` and `\"target indices\"` lists of `log`.\n",
    "\n",
    "    5. For each tuned hyperparameter, name `key`:\n",
    "        we replace the masked entries\n",
    "        by perturbed corresponding target values:\n",
    "        to the appropriate values at `config[key + \"_raw\"]`,\n",
    "        we add noise sampled from\n",
    "        `config[\"hyperparameter_raw_perturbs\"][key]`,\n",
    "        then transform them by\n",
    "        `config[\"hyperparameter_transforms\"][key]`.\n",
    "\n",
    "        We update the appropriate values of\n",
    "        `config[key]` and `config[key + \"_raw\"]`\n",
    "        and append the new hyperparameter values to `log[key]`.\n",
    "\n",
    "    6. For each parameter in `parameters`:\n",
    "        We replace the masked subtensors by the\n",
    "        correponding entries at the target indices.\n",
    "    \"\"\"\n",
    "    raise NotImplementedError\n",
    "\n",
    "\n",
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `reinforce_step`\n",
    "\n",
    "1. Write the function below.\n",
    "    1. Of course, start with `zero_grad`.\n",
    "    1. Get the unnormalized logits at all but the last batch of observations in the episode.\n",
    "    2. Transform them to normalized logits using the `logsumexp` method.\n",
    "    3. Now you have a log tensor of shape `(population_size, step_num, action_num)`. For each `population_id, step_id` pair, you need to get the logit of the action that was chosen. You can use the `gather` method for this.\n",
    "        1. It expects an index tensor of the same number of dimensions as the input tensor, that is 3. So, add an extra dimension 1 to the right of the shape of the action tensor.\n",
    "        2. Now the output will have an extra dimension 1 at the right of its shape, so you have to get rid of that by indexing or reshaping.\n",
    "    4. Using the discounted returns and the logits, you can calculate the loss as per the REINFORCE formula.\n",
    "    6. Backpropagate and take an optimizer step.\n",
    "2. Take a REINFORCE step using the stochastic episode data you generated earlier.\n",
    "3. Generate another deterministic episode data using the updated policy. Print the mean undiscounted return."
   ]
  },
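  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The normalization and `gather` steps can be tried out on toy tensors first. The shapes and values below are invented for illustration; the sketch covers only these two steps, not the loss or the optimizer update.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "# Toy shapes: 2 population members, 3 steps, 2 actions.\n",
    "logits = torch.tensor([\n",
    "    [[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]],\n",
    "    [[3.0, 0.0], [0.0, 0.0], [0.0, 1.0]],\n",
    "])\n",
    "actions = torch.tensor([[0, 1, 1], [0, 0, 1]])\n",
    "\n",
    "# Normalize: after subtracting logsumexp, exp(log_probs) sums to 1 over actions.\n",
    "log_probs = logits - logits.logsumexp(dim=-1, keepdim=True)\n",
    "\n",
    "# The index tensor needs 3 dimensions too, hence unsqueeze;\n",
    "# the trailing dimension of size 1 in the output is squeezed away.\n",
    "chosen = log_probs.gather(-1, actions.unsqueeze(-1)).squeeze(-1)\n",
    "print(chosen.shape)  # torch.Size([2, 3])\n",
    "```\n",
    "\n",
    "For example, `chosen[0, 0]` is `log(0.5)`, since both actions have equal logits at that position."
   ]
  },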
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def reinforce_step(\n",
    "    config: dict,\n",
    "    episode_data: dict,\n",
    "    optimizer: Optimizer,\n",
    "    policy: torch.nn.Module,\n",
    ") -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    Take a REINFORCE step.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    config : `dict`\n",
    "        Configuration dictionary. Required key-value pairs:\n",
    "        `\"device\"` : `torch.device | int | str`\n",
    "            The device the policy model is stored on.\n",
    "        `\"discount\"` : `float`\n",
    "            Discount value.\n",
    "        We also need the hyperparameters listed at `optimizer.keys`.\n",
    "    episode_data : `dict`\n",
    "        Episode data, generated by `get_episode_data`,\n",
    "        with stochastic actions.\n",
    "    optimizer : `Optimizer`\n",
    "        An optimizer that optimizes `policy.parameters()`.\n",
    "    policy : `torch.nn.Module`\n",
    "        A policy model.\n",
    "    \"\"\"\n",
    "    raise NotImplementedError\n",
    "\n",
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `reinforce`\n",
    "\n",
    "Now, for the final touch! We describe the details of the training loop in its docstring.\n",
    "1. The 10 episodes in each evaluation are to be generated sequentially as running much more parallel environments than available CPU cores does not offer a speedup.\n",
    "2. I recommend printing undiscounted return statistics after each evaluation to track progress.\n",
    "\n",
    "1. Write the function below.\n",
    "1. Reinitialize your policy and optimizer.\n",
    "2. Run training! \n",
    "    1. If after an evaluation you see that a population member achieved a mean undiscounted return of 500, feel free to interrupt the process.\n",
    "2. Play the video that was generated after the evaluation with the best score.\n",
    "\n",
    "Happy training!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def reinforce(\n",
    "    config: dict,\n",
    "    env: gym.vector.VectorEnv,\n",
    "    optimizer: Optimizer,\n",
    "    policy: torch.nn.Module,\n",
    "):\n",
    "    \"\"\"\n",
    "    REINFORCE training loop that uses Population-Based Training.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    config : `dict`\n",
    "        Configuration dictionary. Required key-value pairs:\n",
    "        `\"device\"` : `torch.device | int | str`\n",
    "            The device the policy model is stored on.\n",
    "        `\"discount\"` : `float`\n",
    "            Discount value.\n",
    "        `\"ensemble_shape\"` : tuple[int]\n",
    "            Ensemble shape. We assume this is a 1-dimensional tuple\n",
    "            with dimensions the population size.\n",
    "        `\"eval_interval\"` : `int`\n",
    "            Frequency of evaluations, measured in number of training steps.\n",
    "        `\"hyperparameter_raw_init_distributions\"` : `dict`\n",
    "            Dictionary that maps tuned hyperparameter names\n",
    "            to `torch.distributions.Distribution` of raw hyperparameter values.\n",
    "            Required keys:\n",
    "            `\"learning_rate\"`:\n",
    "                The learning rate of stochastic gradient descent.\n",
    "        `\"hyperparameter_raw_perturbs\"` : `dict`\n",
    "            Dictionary that maps tuned hyperparameter names\n",
    "            to `torch.distributions.Distribution` of additive noise.\n",
    "        `\"hyperparameter_transforms\"` : `dict`\n",
    "            Dictionary that maps tuned hyperparameter names\n",
    "            to transformations of raw hyperparameter values.\n",
    "        `\"improvement_threshold\"` : `float`\n",
    "            A new metric score has to be this much better\n",
    "            than the previous best to count as an improvement.\n",
    "        `\"steps_num\"` : `int`\n",
    "            Maximum number of training steps.\n",
    "        `\"videos_directory\"` : `str`\n",
    "            After each evaluation, a video following the best policy\n",
    "            is generated and saved to this directory.\n",
    "        `\"welch_confidence_level\"` : `float`\n",
    "            The confidence level in Welch's t-test\n",
    "            that is used in determining if a population member\n",
    "            is to be replaced by another member with perturbed hyperparameters.\n",
    "        `\"welch_sample_size\"` : `int`\n",
    "            At each evaluation, the population entries are compared\n",
    "            based on the undiscounted returns of this many episodes\n",
    "            with deterministic actions.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    A training log dictionary.\n",
    "    Besides the entries given in `pbt_init` and `pbt_update`,\n",
    "    it collects evaluation results at key `evaluations`.\n",
    "    \"\"\"\n",
    "    raise NotImplementedError\n",
    "\n",
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Dataset References\n",
    "\n",
    "[1] Andrew G. Barto; Richard S. Sutton and Charles W. Anderson: *Neuronlike adaptive elements that can solve difficult learning control problems*, 1983. IEEE Transactions on Systems, Man, and Cybernetics, vol. SMC-13 (5), pp. 834--846, doi: [10.1109/TSMC.1983.6313077](https://www.doi.org/10.1109/TSMC.1983.6313077). http://www.incompleteideas.net/papers/barto-sutton-anderson-83.pdf"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# License\n",
    "\n",
    "This work is licensed under Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International. To view a copy of this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dml",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
