# This work is licensed under Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International. 
# To view a copy of this license, visit
# https://creativecommons.org/licenses/by-nc-sa/4.0/

from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Callable, Generator, Iterable, Sequence
import datasets
import gymnasium as gym
import math
import matplotlib.pyplot as plt
from moviepy import ImageSequenceClip
import os
from sklearn.decomposition import TruncatedSVD
from sklearn.feature_extraction.text import TfidfVectorizer
import scipy
from scipy.spatial import ConvexHull
import torch
import torch.nn.functional as F
import tqdm
from typing import Optional


class Conv2D(torch.nn.Module):
    """
    Ensemble-ready, two-dimensional convolution layer.

    All ensemble members are evaluated by a single grouped convolution:
    the ensemble dimensions are folded into the channel dimension and
    the number of groups equals the number of ensemble members.

    Arguments
    ---------
    config : `dict`
        Configuration dictionary. Required key-value pairs:
        `"device"` : `str`
            The device to store parameters on.
        `"ensemble_shape"` : `tuple[int]`
            The shape of the ensemble of convolutions
            the model represents.
    in_channels : `int`
        The number of input channels.
    kernel_shape : `tuple[int]`
        The kernel shape, as a `(height, width)` pair.
    out_channels : `int`
        The number of output channels.
    bias : `bool`, optional
        Whether to include bias along the output channels.
        Default: `True`.
    dilation : `int | tuple[int]`, optional
        The spacing between kernel elements, in all directions,
        or per direction. Default: `1`.
    init_multiplier : `float`, optional
        The kernel is initialized from the centered normal distribution
        with standard deviation `out_channels ** -.5` times this value.
        Default: `1.`.
    padding : `int | str | tuple[int]`, optional
        The padding in all directions or per direction.
        Alternatively, `"valid"` is the same as `0`,
        and `"same"` pads the input so the output has the same shape
        as the input.
        Default: `0`.
    stride : `int | tuple[int]`, optional
        The stride in all directions or per direction.
        Default: `1`.

    Calling
    -------
    Instance calls require one positional argument:
    batch : `dict`
        The input data dictionary. Required key:
        `"features"` : `torch.Tensor`
            Tensor of features, of shape
            `ensemble_shape + (minibatch_size, in_channels, height, width)`.
            The forward pass locates the minibatch dimension at index
            `len(ensemble_shape)`, so the ensemble dimensions must be
            present and the batch shape must be one-dimensional.

    Returns a copy of `batch` where `"features"` is replaced by the
    convolved tensor of shape
    `ensemble_shape + (minibatch_size, out_channels, out_height, out_width)`.
    """
    def __init__(
        self,
        config: dict,
        in_channels: int,
        kernel_shape: tuple[int],
        out_channels: int,
        bias=True,
        dilation=1,
        init_multiplier=1.,
        padding=0,
        stride=1
    ):
        super().__init__()

        self.dilation = dilation
        self.ensemble_shape = config["ensemble_shape"]
        self.in_channels = in_channels
        self.kernel_shape = kernel_shape
        self.out_channels = out_channels
        self.padding = padding
        self.stride = stride

        height, width = kernel_shape
        self.weight = torch.nn.Parameter(torch.empty(
            self.ensemble_shape
          + (
                out_channels,
                in_channels,
                height,
                width
            ),
            device=config["device"],
            dtype=torch.float32
        ).normal_(std=out_channels ** -.5) * init_multiplier)

        if bias:
            self.bias = torch.nn.Parameter(torch.zeros(
                self.ensemble_shape + (out_channels,),
                device=config["device"],
                dtype=torch.float32
            ))
        else:
            self.bias = None


    def forward(self, batch: dict) -> dict:
        ensemble_dim = len(self.ensemble_shape)

        # One convolution group per ensemble member. The member count is
        # the product of the ensemble shape (`math.prod(()) == 1`),
        # not the sum of its entries.
        group_num = max(1, math.prod(self.ensemble_shape))

        # ensemble_shape + (minibatch_size, in_channels, input_height, input_width)
        features: torch.Tensor = batch["features"]
        features = (
            features
           .movedim(ensemble_dim, 0)
           .flatten(1, ensemble_dim + 1)
        ) # (minibatch_size, in_channels_total, input_height, input_width)

        # The bias is optional, so it can only be flattened when present.
        bias = None
        if self.bias is not None:
            bias = self.bias.flatten(end_dim=ensemble_dim)

        features = F.conv2d(
            features,
            self.weight.flatten(end_dim=ensemble_dim),
            bias=bias,
            dilation=self.dilation,
            groups=group_num,
            padding=self.padding,
            stride=self.stride
        ) # (minibatch_size, out_channels_total, output_height, output_width)

        # Split the channel dimension back into ensemble and output channels,
        # then move the minibatch dimension behind the ensemble dimensions.
        features = (
            features
           .unflatten(1, self.ensemble_shape + (self.out_channels,))
           .movedim(0, ensemble_dim)
        )

        return batch | {"features": features}
    

class DeepSet(torch.nn.Module):
    """
    Ensemble-ready deep set: an element-wise embedding followed by
    masked mean pooling and an outgoing MLP.

    A forward call
    1. embeds the input tensor with the `embedding` model,
    2. averages the embedding vectors over the sequence dimension,
       counting only the entries selected by the mask, and
    3. feeds the pooled vectors to the outgoing MLP.

    Arguments
    ---------
    config : `dict`
        Configuration dictionary. Required key-value pairs:
        `"device"` : `str`
            The device to store parameters on.
        `"ensemble_shape"` : `tuple[int]`
            The shape of the ensemble of models
            the deep set represents.
    embedding : `torch.nn.Module`
        The model that transforms the input to element-wise
        embedding vectors.
    embedding_dim : `int`
        The number of dimensions of embedding vectors
        that the `embedding` model outputs.
    input_key : `str`
        The key mapped to the input tensor in the batch dictionary.
    out_features : `int`
        The number of output features.
    hidden_layer_num : `int`, optional
        Forwarded to `get_mlp`. If `hidden_layer_sizes` is not given,
        the number of hidden layers of the outgoing MLP.
    hidden_layer_size : `int`, optional
        Forwarded to `get_mlp`. If `hidden_layer_sizes` is not given,
        the size of each hidden layer of the outgoing MLP.
    hidden_layer_sizes: `Iterable[int]`, optional
        Forwarded to `get_mlp`. If given, each entry gives a hidden layer
        with the given size for the outgoing MLP.

    Calling
    -------
    Instance calls require one positional argument:
    batch : `dict`
        The input data dictionary. Required keys:
        the value of `input_key` : `torch.Tensor`
            The tensor passed to the embedding model, of shape
            `batch_shape + (sequence_dim,)` or
            `ensemble_shape + batch_shape + (sequence_dim,)`,
            plus additional dimensions that the embedding model may require.
        `"mask"` : `torch.Tensor`
            Mask showing which entries are not padding, of shape
            `batch_shape + (sequence_dim,)` or
            `ensemble_shape + batch_shape + (sequence_dim,)`
    """
    def __init__(
        self,
        config: dict,
        embedding: torch.nn.Module,
        embedding_dim: int,
        input_key: str,
        out_features: int,
        hidden_layer_num: Optional[int] = None,
        hidden_layer_size: Optional[int] = None,
        hidden_layer_sizes: Optional[Iterable[int]] = None
    ):
        super().__init__()

        self.config = config
        self.embedding = embedding
        self.input_key = input_key

        # The outgoing MLP maps pooled embeddings to the output features.
        hidden_layer_options = {
            "hidden_layer_num": hidden_layer_num,
            "hidden_layer_size": hidden_layer_size,
            "hidden_layer_sizes": hidden_layer_sizes,
        }
        self.mlp = get_mlp(
            config,
            embedding_dim,
            out_features,
            **hidden_layer_options
        )


    def forward(self, batch: dict) -> torch.Tensor:
        element_embeddings = self.embedding(batch[self.input_key])

        # Trailing singleton dimension lets the mask broadcast
        # over the embedding dimension.
        weights = to_ensembled(
            self.config["ensemble_shape"],
            batch["mask"]
        )[..., None]

        # Masked mean over the sequence dimension (second to last).
        masked_total = (element_embeddings * weights).sum(dim=-2)
        valid_count = weights.sum(dim=-2)

        return self.mlp(masked_total / valid_count)
    

class DictReLU(torch.nn.Module):
    """
    ReLU activation for dictionary batches.

    Calling
    -------
    Instance calls require one positional argument:
    batch : `dict`
        The input data dictionary. The tensor at the `"features"` key
        is rectified elementwise.

    Returns a new dictionary with the rectified tensor at `"features"`;
    all other entries are carried over and the input dictionary
    is not modified.
    """
    def forward(self, batch: dict) -> dict:
        rectified = F.relu(batch["features"])
        replacement = {"features": rectified}

        return batch | replacement
    

class Dropout(torch.nn.Module):
    """
    Ensemble-ready dropout layer.

    In training mode, the features are divided by
    `1 - dropout_p + 1e-4` and then each entry is zeroed
    independently with probability `dropout_p`.
    In evaluation mode, the batch is returned unchanged.

    Arguments
    ---------
    config : `dict`
        Configuration dictionary. Required key-value pairs:
        `"dropout_p"` : `float | torch.Tensor`
            Dropout probability. Either a scalar shared by all ensemble
            members, or a tensor with one entry per ensemble member,
            of shape `ensemble_shape` (a flattened tensor with the same
            number of entries is accepted too).
        `"ensemble_shape"` : `tuple[int]`
            The shape of the ensemble of models the layer serves.

    Calling
    -------
    Instance calls require one positional argument:
    batch : `dict`
        The input data dictionary. Required key:
        `"features"` : `torch.Tensor`
            Tensor of features.
    """

    def __init__(self, config: dict):
        super().__init__()

        self.config = config


    def forward(self, batch: dict) -> dict:
        if not self.training:
            return batch

        ensemble_shape = self.config["ensemble_shape"]
        ensemble_dim = len(ensemble_shape)
        features = batch["features"]

        features = to_ensembled(self.config["ensemble_shape"], features)

        dropout_p = torch.as_tensor(
            self.config["dropout_p"],
            device=features.device
        )
        if dropout_p.ndim > 0:
            # Align the per-member probabilities with the ensemble
            # dimensions of the features. `reshape` is used because
            # `unflatten(-1, ...)` would only split the last dimension,
            # which fails for multi-dimensional ensemble shapes.
            dropout_p = dropout_p.reshape(
                tuple(ensemble_shape)
              + (1,) * (len(features.shape) - ensemble_dim)
            )

        # Rescale the surviving entries; the small constant keeps
        # the divisor positive when `dropout_p` is 1.
        features = features / (1 - dropout_p + 1e-4)

        sample = torch.rand(features.shape, device=features.device)
        mask = sample > dropout_p

        features = features * mask

        return batch | {"features": features}
    

class LayerNorm(torch.nn.Module):
    """
    Ensemble-ready layer normalization layer

    Arguments
    ---------
    config : `dict`
        Configuration dictionary. Required key-value pairs:
        `"device"` : `str`
            The device to store parameters on.
        `"ensemble_shape"` : `tuple[int]`
            The shape of the ensemble of affine transformations
            the model represents.
    normalized_shape : `int | tuple[int]`
        The part of the shape of the incoming tensors
        that are to be normalized together with batch dimensions.
        We view the following as batch dimensions:
        ```
        range(
            len(ensemble_shape),
            -len(normalized_shape) - normalized_offset
        )
        ```
        If an integer, we view it as a single-element tuple.
    bias : `bool`, optional
        If `elementwise_affine`, whether to include offset
        in the learned transformation. Default: `True`.
    elementwise_affine : `bool`, optional
        Whether to include learnable scale. If this and `bias`,
        then we also include learnable offset. These will be tensors
        of shape `ensemble_shape + normalized_shape` that are
        broadcast to the incoming feature tensors appropriately.
        Default: `True`.
    epsilon : `float`, optional
        Small positive value, to be included in the divisor when we
        divide by the variance, for numerical stability. Default: `1e-5`.
    normalized_offset : `int`, optional
        We get `normalized_shape` out of an incoming feature tensor
        at dimensions
        ```
        range(
            -len(normalized_shape) - normalized_offset,
            -normalized_offset
        )
        ```
        Default: `0`.

    Calling
    -------
    Instance calls require one positional argument:
    batch : `dict`
        The input data dictionary. Required key:
        `"features"` : `torch.Tensor`
            Tensor of features.
    """
    def __init__(
        self,
        config: dict,
        normalized_shape: int | tuple[int],
        bias=True,
        elementwise_affine=True,
        epsilon=1e-5,
        normalized_offset=0
    ):
        super().__init__()

        if hasattr(normalized_shape, "__int__"):
            self.normalized_shape = (normalized_shape,)
        else:
            self.normalized_shape = normalized_shape

        self.ensemble_shape = config["ensemble_shape"]
        self.epsilon = epsilon
        self.normalized_offset = normalized_offset

        if elementwise_affine:
            self.scale = torch.nn.Parameter(torch.ones(
                self.ensemble_shape + self.normalized_shape + (1,) * normalized_offset,
                device=config["device"],
                dtype=torch.float32
            ))
            if bias:
                self.bias = torch.nn.Parameter(torch.zeros_like(self.scale))
            else:
                self.bias = None

        else:
            self.bias, self.scale = None, None


    def forward(self, batch: dict) -> dict:
        features: torch.Tensor = batch["features"]

        ensemble_dim = len(self.ensemble_shape)
        features = to_ensembled(self.ensemble_shape, features)

        normalized_dim = len(self.normalized_shape)

        batch_dim = len(features.shape) - ensemble_dim - normalized_dim - self.normalized_offset
        normalized_range = tuple(range(
            ensemble_dim,
            ensemble_dim + batch_dim
        )) + tuple(range(
            -normalized_dim - self.normalized_offset,
            -self.normalized_offset
        ))

        features = features - features.mean(dim=normalized_range, keepdim=True)
        features = features / features.std(dim=normalized_range, keepdim=True)

        if self.scale is not None:
            scale = self.scale.unflatten(
                ensemble_dim,
                (1,) * batch_dim + self.normalized_shape[:1]
            )

            features = features * scale

            if self.bias is not None:
                bias = self.bias.unflatten(
                    ensemble_dim,
                    (1,) * batch_dim + self.normalized_shape[:1]
                )
                features = features + bias

        return batch | {"features": features}


class Linear(torch.nn.Module):
    """
    Ensemble-ready affine transformation `y = x^T W + b`.

    Arguments
    ---------
    config : `dict`
        Configuration dictionary. Required key-value pairs:
        `"device"` : `str`
            The device to store parameters on.
        `"ensemble_shape"` : `tuple[int]`
            The shape of the ensemble of affine transformations
            the model represents.
    in_features : `int`
        The number of input features
    out_features : `int`
        The number of output features.
    bias : `bool`, optional
        Whether the model should include bias. Default: `True`.
    feature_dim_index: `int`, optional
        The index of the feature dimension. Default: `-1`,
    init_multiplier : `float`, optional
        The weight parameter values are initialized following
        a normal distribution with center 0 and std
        `out_features ** -.5` times this value. Default: `1.`

    Calling
    -------
    Instance calls require one positional argument:
    batch : `dict`
        The input data dictionary. Required key:
        `"features"` : `torch.Tensor`
            Tensor of features. The feature dimension is determined by
            `feature_dim_index`. If the leading dimensions equal
            `ensemble_shape`, they are treated as ensemble dimensions;
            otherwise the input is shared by all ensemble members.

    Returns a copy of `batch` where `"features"` is replaced by the
    transformed tensor, with the output features at `feature_dim_index`.
    """
    def __init__(
        self,
        config: dict,
        in_features: int,
        out_features: int,
        bias=True,
        feature_dim_index=-1,
        init_multiplier=1.
    ):
        super().__init__()

        self.feature_dim_index = feature_dim_index

        if bias:
            self.bias = torch.nn.Parameter(torch.zeros(
                config["ensemble_shape"] + (out_features,),
                device=config["device"],
                dtype=torch.float32
            ))
        else:
            self.bias = None

        self.weight = torch.nn.Parameter(torch.empty(
            config["ensemble_shape"] + (in_features, out_features),
            device=config["device"],
            dtype=torch.float32
        ).normal_(std=out_features ** -.5) * init_multiplier)


    def forward(
        self,
        batch: dict
    ) -> dict:
        features: torch.Tensor = batch["features"]

        # Work with the feature dimension in the last position.
        features = features.movedim(
            self.feature_dim_index,
            -1
        )

        ensemble_shape = self.weight.shape[:-2]
        ensemble_dim = len(ensemble_shape)
        # Whether the input already carries the ensemble dimensions.
        ensemble_input = features.shape[:ensemble_dim] == ensemble_shape
        batch_dim = len(features.shape) - 1 - ensemble_dim * ensemble_input

        # (*e, *b, i) @ (*e, *b[:-1], i, o)
        weight = self.weight.reshape(
            ensemble_shape
          + (1,) * (batch_dim - 1)
          + self.weight.shape[-2:]
        )
        features = features @ weight

        if self.bias is not None:
            # (*e, *b, o) + (*e, *b, o)
            bias = self.bias.reshape(
                ensemble_shape
              + (1,) * batch_dim
              + self.bias.shape[-1:]
            )
            features = features + bias

        # Restore the feature dimension to its original position,
        # whether or not a bias was applied.
        features = features.movedim(
            -1,
            self.feature_dim_index
        )

        return batch | {"features": features}
    

class Optimizer(ABC):
    """
    Abstract base class of optimizers that can train model ensembles
    where each ensemble member may have its own hyperparameters.

    The base class implements the weight decay part of the parameter
    update. Subclasses implement `_get_parameter_update` and may
    override `_update_state` to maintain optimizer state tensors.

    Arguments
    ---------
    parameters : `Iterable[torch.nn.Parameter]`
        An iterable of `torch.nn.Parameter` to track.
        In a simple case of optimizing a single `model: torch.nn.Module`,
        this can be `model.parameters()`.
    config : `dict`, optional
        If given, the `update_config` method is called on it
        to initialize hyperparameters. Default: `None`.

    Class attributes
    ----------------
    keys : `tuple[str]`
        The collection of the hyperparameter keys to track
        in the configuration dictionary.

        We expect the hyperparameter values to be either
        `float` or `torch.Tensor`. In the latter case,
        we expect the shape to be a prefix of the shape of the parameters.
        The hyperparameter shapes are regarded as ensemble shapes.

        Required keys:
        `"learning_rate"`
        `"weight_decay"`

    Instance attributes
    -------------------
    config : `dict`
        The hyperparameter dictionary.
    parameters : `list[torch.nn.Parameter]`
        The list of tracked parameters.
    step_id : `int`
        Train step counter.
    """
    keys = ("learning_rate", "weight_decay")

    def __init__(
        self,
        parameters: Iterable[torch.nn.Parameter],
        config=None
    ):
        self.config = {}
        self.parameters = [*parameters]
        self.step_id = 0

        if config is not None:
            self.update_config(config)


    def get_parameters(self) -> Iterable[torch.Tensor]:
        """
        Return an iterator over the tracked parameters.
        Subclasses with state tensors extend it to cover those as well.
        """
        return iter(self.parameters)


    def get_hyperparameter(
        self,
        key: str,
        parameter: torch.Tensor
    ) -> torch.Tensor:
        """
        Look up the hyperparameter stored at `key`, convert it
        to a tensor on the `device` and with the `dtype` of `parameter`,
        and append dimensions of size 1 to its shape until it has
        as many dimensions as `parameter`, so the two broadcast.
        """
        value = torch.asarray(
            self.config[key],
            device=parameter.device,
            dtype=parameter.dtype
        )
        missing_dim = len(parameter.shape) - len(value.shape)

        return value.reshape(value.shape + (1,) * missing_dim)


    def step(self):
        """
        Perform one training step: increase the step counter, then
        for each tracked parameter update the optimizer state and
        apply the parameter update in-place, without tracking gradients.
        Assumes that backpropagation has already occurred by
        a call to the `backward` method of the loss tensor.
        """
        self.step_id += 1
        with torch.no_grad():
            for parameter_id, parameter in enumerate(self.parameters):
                self._update_parameter(parameter, parameter_id)


    def update_config(self, config: dict):
        """
        Copy the value of each tracked hyperparameter key
        from `config: dict`. Raises `KeyError` if a key is missing.
        """
        for key in self.keys:
            self.config[key] = config[key]


    def zero_grad(self):
        """
        Reset the `grad` attribute of each tracked parameter to `None`.
        """
        for parameter in self.parameters:
            parameter.grad = None


    def _apply_parameter_update(
        self,
        parameter: torch.nn.Parameter,
        parameter_update: torch.Tensor
    ):
        # In-place addition, so the tracked parameter itself changes.
        parameter.add_(parameter_update)


    @abstractmethod
    def _get_parameter_update(
        self,
        parameter: torch.nn.Parameter,
        parameter_id: int
    ) -> torch.Tensor:
        # Weight decay term shared by subclasses via `super()`:
        # zero if weight decay is disabled, otherwise
        # `-learning_rate * weight_decay * parameter`.
        if self.config["weight_decay"] is None:
            return torch.zeros_like(parameter)

        learning_rate = self.get_hyperparameter("learning_rate", parameter)
        weight_decay = self.get_hyperparameter("weight_decay", parameter)

        return -(learning_rate * weight_decay * parameter)


    def _update_state(
        self,
        parameter: torch.nn.Parameter,
        parameter_id: int
    ):
        # Stateless by default; subclasses override this hook.
        pass


    def _update_parameter(
        self,
        parameter: torch.nn.Parameter,
        parameter_id: int
    ):
        # The state is refreshed first, so the update can rely on it.
        self._update_state(parameter, parameter_id)
        self._apply_parameter_update(
            parameter,
            self._get_parameter_update(parameter, parameter_id)
        )
    

class AdamW(Optimizer):
    """
    Adam optimizer with optional weight decay.
    Can optimize model ensembles
    with training defined by hyperparameter ensembles.

    The weight decay term comes from the base class update and is
    added to the Adam step, so it is not part of the moment estimates.

    Arguments
    ---------
    parameters : `Iterable[torch.nn.Parameter]`
        An iterable of `torch.nn.Parameter` to track.
        In a simple case of optimizing a single `model: torch.nn.Module`,
        this can be `model.parameters()`.
    config : `dict`, optional
        If given, the `update_config` method is called on it
        to initialize hyperparameters. Default: `None`.

    Class attributes
    ----------------
    keys : `tuple[str]`
        The collection of the hyperparameter keys to track
        in the configuration dictionary.

        We expect the hyperparameter values to be either
        `float` or `torch.Tensor`. In the latter case,
        we expect the shape to be a prefix of the shape of the parameters.
        The hyperparameter shapes are regarded as ensemble shapes.

        Required keys:
        `"epsilon"`,
        `"first_moment_decay"`,
        `"learning_rate"`
        `"second_moment_decay"`,
        `"weight_decay"`

    Instance attributes
    -------------------
    first_moments : `list[torch.Tensor]`
        Running averages of the gradients, one per tracked parameter.
    second_moments : `list[torch.Tensor]`
        Running averages of the squared gradients,
        one per tracked parameter.
    """
    keys = (
        "epsilon",
        "first_moment_decay",
        "learning_rate",
        "second_moment_decay",
        "weight_decay"
    )
    def __init__(
        self,
        parameters: Iterable[torch.nn.Parameter],
        config=None
    ):
        super().__init__(parameters, config)
        self.first_moments = [
            torch.zeros_like(parameter)
            for parameter in self.parameters
        ]
        self.second_moments = [
            torch.zeros_like(parameter)
            for parameter in self.parameters
        ]


    def get_parameters(self) -> Iterable[torch.Tensor]:
        """
        Iterate over the tracked parameters, then the first moment
        and finally the second moment state tensors.
        """
        yield from self.parameters
        yield from self.first_moments
        yield from self.second_moments


    def _get_parameter_update(
        self,
        parameter: torch.nn.Parameter,
        parameter_id: int
    ) -> torch.Tensor:
        # Start from the weight decay term of the base class.
        parameter_update = super()._get_parameter_update(
            parameter,
            parameter_id
        )

        epsilon = self.get_hyperparameter(
            "epsilon",
            parameter
        )
        first_moment = self.first_moments[parameter_id]
        first_moment_decay = self.get_hyperparameter(
            "first_moment_decay",
            parameter
        )
        learning_rate = self.get_hyperparameter(
            "learning_rate",
            parameter
        )
        second_moment = self.second_moments[parameter_id]
        second_moment_decay = self.get_hyperparameter(
            "second_moment_decay",
            parameter
        )

        # The moments start at zero, so they are rescaled
        # by `1 - decay ** step_id` to remove the initialization bias.
        first_moment_debiased = (
            first_moment
          / (1 - first_moment_decay ** self.step_id)
        )
        second_moment_debiased = (
            second_moment
          / (1 - second_moment_decay ** self.step_id)
        )

        parameter_update -= (
            learning_rate
          * first_moment_debiased
          / (
                second_moment_debiased.sqrt()
              + epsilon
            )
        )

        return parameter_update


    def _update_state(
        self,
        parameter: torch.nn.Parameter,
        parameter_id: int
    ):
        first_moment = self.first_moments[parameter_id]
        first_moment_decay = self.get_hyperparameter(
            "first_moment_decay",
            parameter
        )
        second_moment = self.second_moments[parameter_id]
        second_moment_decay = self.get_hyperparameter(
            "second_moment_decay",
            parameter
        )

        # Ellipsis assignment writes in-place like `[:]`, but also works
        # for zero-dimensional (scalar) parameters, which cannot be sliced.
        first_moment[...] = (
            first_moment_decay
          * first_moment
          + (1 - first_moment_decay)
          * parameter.grad
        )
        second_moment[...] = (
            second_moment_decay
          * second_moment
          + (1 - second_moment_decay)
          * parameter.grad.square()
        )
        
        
class Pool2D(torch.nn.Module):
    """
    Ensemble-ready two-dimensional mean pool operation

    Arguments
    ---------
    config : `dict`
        Configuration dictionary. Required key-value pairs:
        `"device"` : `str`
            The device to store parameters on.
            Not read by this layer, as it has no parameters.
        `"ensemble_shape"` : `tuple[int]`
            The shape of the ensemble of models the layer serves.
    kernel_shape : `int | tuple[int]`, optional
        The kernel shape.
        If given, we pool along the kernel displacements.
        Otherwise, we pool along all the two sequential dimensions.
    padding : `int | tuple[int]`, optional
        The padding in all directions or per direction.
        It is used if `kernel_shape` is given.
        Default: `0`.
    stride : `int | tuple[int]`, optional
        The stride in all directions or per direction.
        It is used if `kernel_shape` is given.
        Default: `1`.

    Calling
    -------
    Instance calls require one positional argument:
    batch : `dict`
        The input data dictionary. Required key:
        `"features"` : `torch.Tensor`
            Tensor of features, of shape
            `batch_shape + (in_channels, height, width)` or
            `ensemble_shape + batch_shape + (in_channels, height, width)`

    Returns a copy of `batch` where `"features"` is replaced by the
    pooled tensor. The leading dimensions of the input are preserved.
    If `kernel_shape` is not given, the height and width dimensions
    are averaged away.
    """
    def __init__(
        self,
        config: dict,
        kernel_shape: Optional[tuple[int]] = None,
        padding=0,
        stride=1
    ):
        super().__init__()

        self.ensemble_shape = config["ensemble_shape"]
        self.kernel_shape = kernel_shape
        self.padding = padding
        self.stride = stride


    def forward(self, batch: dict) -> dict:
        # leading_shape + (channels, input_height, input_width)
        features: torch.Tensor = batch["features"]

        if self.kernel_shape is None:
            return batch | {"features": features.mean(dim=(-2, -1))}

        # Mean pooling acts on each channel independently, so all leading
        # (ensemble and batch) dimensions can be folded into one minibatch
        # dimension. This covers both ensembled and non-ensembled input
        # with any number of batch dimensions.
        leading_shape = tuple(features.shape[:-3])
        image_shape = tuple(features.shape[-3:])

        features = features.reshape(
            (math.prod(leading_shape),) + image_shape
        ) # (minibatch_size_total, channels, input_height, input_width)
        features = F.avg_pool2d(
            features,
            self.kernel_shape,
            padding=self.padding,
            stride=self.stride
        ) # (minibatch_size_total, channels, output_height, output_width)
        features = features.reshape(
            leading_shape + tuple(features.shape[-3:])
        )

        return batch | {"features": features}


def evaluate_model(
    config: dict,
    dataset: dict,
    get_metric: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    model: torch.nn.Module,
    indptr_key="indptr",
    out_features: Optional[int] = None,
    target_key="target"
) -> torch.Tensor:
    """
    Compute a metric of a model on a supervised dataset.

    The predictions are obtained with `get_output_by_batches`
    and compared to the target tensor of the dataset by `get_metric`.

    Parameters
    ----------
    config : `dict`
        Configuration dictionary, forwarded to `get_output_by_batches`.
        Required key-value pair:
        `"minibatch_size_eval"` : `int`
            Size of consecutive minibatches to take from the dataset.
            To be set according to RAM or GPU memory capacity.
    dataset : `dict`
        The dataset to evaluate the model on.
    get_metric : `Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`
        Function to get the metric from a pair of
        predicted and target value tensors, in this order.
    model : `torch.nn.Module`
        The model to evaluate.
    indptr_key : `str`, optional
        If the dataset has sequential entries,
        then this is the key of the index pointer tensor.
        Default: `"indptr"`.
    out_features: `int`, optional
        The number of output features in the predict tensors.
        By default, it is the last dimension of the target tensor.
    target_key : `str`, optional
        The key mapped to the target value tensor in the dataset.
        Default: `"target"`

    Returns
    -------
    The value returned by `get_metric`.
    """
    target_values = dataset[target_key]

    # Fall back to the width of the target tensor.
    feature_num = (
        target_values.shape[-1]
        if out_features is None
        else out_features
    )

    predictions = get_output_by_batches(
        config,
        dataset,
        model,
        feature_num,
        indptr_key=indptr_key
    )

    return get_metric(predictions, target_values)





def get_accuracy(
    logits: torch.Tensor,
    labels: torch.Tensor,
) -> torch.Tensor:
    """
    Calculate the accuracy of a classification model from its logits.
    The predicted label is the index of the largest logit.
    Supports model ensembles of arbitrary ensemble shape.

    Parameters
    ----------
    logits : torch.Tensor
        Logit tensor of shape
        `ensemble_shape + (dataset_size, label_num)`.
    labels : torch.Tensor
        Label tensor of shape
        `(dataset_size,)` or
        `ensemble_shape + (dataset_size,)`.

    Returns
    -------
    The tensor of accuracies of shape `ensemble_shape`.
    """
    predicted_labels = torch.argmax(logits, dim=-1)
    hits = labels == predicted_labels

    # The mean of the 0/1 hit indicators over the dataset dimension.
    return hits.float().mean(dim=-1)


def get_array_sequence_keys(
    dataset: dict,
    indptr_key="indptr"
) -> tuple[tuple[str], tuple[str]]:
    """
    Split the keys of the dataset into array and sequence keys.

    Given a key-value pair, we say that the key is
    1. an array key, if the value tensor has the dataset size
        at the batch dimension and
    2. a sequence key, if the value tensor has the total number of tokens
        at the batch dimension.

    The batch dimension is the one right after the ensemble dimensions,
    which are all but the last dimension of the index pointer tensor.
    The dataset size is the length of the index pointer minus one,
    and the total number of tokens is its maximum entry.
    If the dataset has no index pointer, every key is an array key.

    Parameters
    ----------
    dataset : `dict`
        The dataset.
    indptr_key : `str`, optional
        The key of the index pointer tensor. Default: `indptr`.

    Returns
    -------
    The pair of
    1. the tuple of array keys and
    2. the tuple of sequence keys.

    Raises
    ------
    ValueError
        If the dataset size equals the total number of tokens,
        as then the two kinds of keys cannot be told apart.
    """
    if indptr_key not in dataset:
        return tuple(dataset), ()

    indptr = dataset[indptr_key]
    ensemble_dim = len(indptr.shape[:-1])
    entry_num = indptr.shape[-1] - 1
    token_num = indptr.max()

    if entry_num == token_num:
        raise ValueError(
            f"The number of dataset entries equals the maximum number of tokens per ensemble member: {entry_num}"
        )

    array_keys = []
    sequence_keys = []
    for key, value in dataset.items():
        if key == indptr_key:
            continue

        batch_size = value.shape[ensemble_dim]
        if batch_size == entry_num:
            array_keys.append(key)
        if batch_size == token_num:
            sequence_keys.append(key)

    return tuple(array_keys), tuple(sequence_keys)


def get_binary_accuracy(
    logits: torch.Tensor,
    labels: torch.Tensor
) -> torch.Tensor:
    """
    Compute the accuracy of binary predictions given by logits.
    Works with arbitrary ensemble shapes.

    An entry is predicted positive exactly when its logit is positive.

    Parameters
    ----------
    logits : torch.Tensor
        The logit tensor. We assume it has shape
        `ensemble_shape + (dataset_size, 1)`.
    labels : torch.Tensor
        The tensor of true labels. We assume it has shape
        `(dataset_size,)` or `ensemble_shape + (dataset_size,)`.
        Nonzero values count as positive.

    Returns
    -------
    The tensor of binary accuracies per ensemble member
    of shape `ensemble_shape`.
    """
    predicted = logits.select(-1, 0).gt(0)
    actual = labels.bool()
    hits = predicted.eq(actual)

    return hits.float().mean(dim=-1)


def get_binary_cross_entropy(
    logits: torch.Tensor,
    labels: torch.Tensor
) -> torch.Tensor:
    """
    Get the binary cross-entropy between a label and a logit tensor.
    It can handle arbitrary ensemble shapes.

    Parameters
    ----------
    logits : torch.Tensor
        The logit tensor. We assume it has shape
        `ensemble_shape + (dataset_size, 1)`.
    labels : torch.Tensor
        The tensor of true labels. We assume it has shape
        `(dataset_size,)` or `ensemble_shape + (dataset_size,)`.
        Integer and boolean labels are converted to the dtype of `logits`.

    Returns
    -------
    The tensor of binary cross-entropies per ensemble member
    of shape `ensemble_shape`.
    """
    # Drop the singleton feature dimension of the logits.
    logit_values = logits[..., 0]

    # The loss requires floating point targets; class labels are often
    # stored as integers or booleans, so convert those here.
    if not labels.is_floating_point():
        labels = labels.to(logit_values.dtype)

    targets = labels.broadcast_to(logit_values.shape)

    return F.binary_cross_entropy_with_logits(
        logit_values,
        targets,
        reduction="none"
    ).mean(dim=-1)


def get_cross_entropy(
    logits: torch.Tensor,
    labels: torch.Tensor,
) -> torch.Tensor:
    """
    Compute the mean cross-entropy of classification logits.
    Supports model ensembles of arbitrary ensemble shape.

    Parameters
    ----------
    logits : torch.Tensor
        Logit tensor of shape
        `ensemble_shape + (dataset_size, label_num)`.
    labels : torch.Tensor
        Label tensor of shape
        `(dataset_size,)` or
        `ensemble_shape + (dataset_size,)`.

    Returns
    -------
    The tensor of cross-entropies of shape `ensemble_shape`,
    averaged over the dataset dimension.
    """
    # `F.cross_entropy` expects the batch and class dimensions first,
    # so the ensemble dimensions are moved behind them.
    logits_batch_first = logits.movedim((-2, -1), (0, 1))

    ensembled_labels = labels.broadcast_to(logits.shape[:-1])
    labels_batch_first = ensembled_labels.movedim(-1, 0)

    per_entry = F.cross_entropy(
        logits_batch_first,
        labels_batch_first,
        reduction="none"
    )

    return per_entry.mean(dim=0)


def get_dataloader_random_reshuffle(
    config: dict,
    dataset: dict,
    indptr_key="indptr",
    minibatch_size: Optional[int] = None
) -> Generator[dict]:
    """
    Create a random reshuffling (without replacement) dataloader
    over a dataset given as a dictionary with tensor values.
    The dataloader yields minibatch dictionaries indefinitely.
    Supports arbitrary ensemble shapes and sequential data.

    Parameters
    ----------
    config : `dict`
        Configuration dictionary. Required key-value pairs:
        `"device"` : `str`
            The device to create the minibatch indices on.
        `"ensemble_shape"` : `tuple[int]`
            The required ensemble shape of the outputs.
        `"minibatch_size"` : `int`
            Only read if the `minibatch_size` argument is not given.
    dataset : `dict`
        Dataset with `torch.Tensor` values.
        Values are converted to ensembled tensors by `to_ensembled`.
    indptr_key : `str`, optional
        If the dataset has sequential entries,
        then this is the key of the index pointer tensor. Default: `indptr`.
    minibatch_size : `int`, optional
        Minibatch size. If not given, it is `config["minibatch_size"]`.

    Returns
    -------
    A generator of minibatch dictionaries, as built by `get_minibatch`.
    """
    batch_size = (
        config["minibatch_size"]
        if minibatch_size is None
        else minibatch_size
    )

    ensembled_dataset = {}
    for key, value in dataset.items():
        ensembled_dataset[key] = to_ensembled(
            config["ensemble_shape"], value
        )

    entry_num = get_dataset_size(
        config, ensembled_dataset, indptr_key=indptr_key
    )

    index_batches = get_random_reshuffler(
        entry_num,
        batch_size,
        device=config["device"],
        ensemble_shape=config["ensemble_shape"]
    )

    for index_batch in index_batches:
        yield get_minibatch(
            ensembled_dataset, index_batch, indptr_key=indptr_key
        )


def get_dataset_size(
    config: dict,
    dataset: dict,
    indptr_key="indptr"
) -> int:
    """
    Get the size of the potentially ensembled dataset.

    We call *batch dimension* of a tensor the dimension
    that comes right after its ensemble dimensions.
    1. If the dataset has an index pointer tensor,
        then the dataset size is the last dimension
        of the index pointer tensor minus 1.
    2. Otherwise, the dataset size is the batch dimension
        of the first value tensor of the dataset.

    Parameters
    ----------
    config : `dict`
        Configuration dictionary. Requires key-value pair:
        `"ensemble_shape"` : `tuple[int]`
            Ensemble shape. Only read if the dataset
            has no index pointer tensor.
    dataset : `dict`
        The dataset.
    indptr_key : `str`, optional
        The key of the index pointer tensor. Default: `indptr`.

    Returns
    -------
    The dataset size.
    """
    if indptr_key in dataset:
        # An index pointer tensor has one more entry than the dataset.
        return dataset[indptr_key].shape[-1] - 1

    batch_dim = len(config["ensemble_shape"])
    first_value = next(iter(dataset.values()))

    return first_value.shape[batch_dim]


def get_minibatch(
    dataset: dict,
    indices: torch.Tensor,
    indptr_key="indptr",
) -> dict:
    """
    Select the minibatch of an ensembled dataset
    given by ensembled indices.

    Parameters
    ----------
    dataset: `dict`
        A dataset given as a dictionary with `torch.Tensor` values.
    indices: `torch.Tensor`
        An index tensor for the dataset. Its last dimension indexes
        the entries of the minibatch, its leading dimensions
        should agree with the leading dimensions of the dataset values.
    indptr_key : `str`, optional
        If the dataset has sequential entries,
        then this is the key of the index pointer tensor. Default: `indptr`.

    Returns
    -------
    The minibatch given by the dataset and the index tensor.
    Array keys map to the gathered entries.
    If the dataset has sequence keys, each maps to a tensor of shape
    `indices.shape + (longest_sequence,) + feature_shape`
    padded to the longest sequence of the minibatch,
    and the extra key `"mask"` is mapped to the mask tensor of
    positions that are not padding positions.
    The index pointer tensor itself is not part of the minibatch.
    """
    array_keys, sequence_keys = get_array_sequence_keys(
        dataset, indptr_key=indptr_key
    )

    index_shape = indices.shape
    index_rank = len(index_shape)
    # The entries are gathered along the last dimension of `indices`.
    batch_dim = index_rank - 1

    minibatch = {}

    for key in array_keys:
        values = dataset[key]
        # `gather` needs an index of the same rank as the values, so
        # the indices are repeated along the trailing feature dimensions.
        gather_index = indices.reshape(
            index_shape + (1,) * (len(values.shape) - index_rank)
        ).expand(
            index_shape + values.shape[index_rank:]
        )
        minibatch[key] = values.gather(batch_dim, gather_index)

    if len(sequence_keys) == 0:
        return minibatch

    # Token ranges [start, stop) of the selected entries.
    indptr = dataset[indptr_key]
    starts = indptr.gather(-1, indices)
    stops = indptr.gather(-1, indices + 1)

    longest = (stops - starts).max()

    token_indices = (
        starts[..., None]
      + torch.arange(longest, device=indices.device)
    )

    mask: torch.Tensor = token_indices < stops[..., None]
    minibatch["mask"] = mask
    # Padding positions point to token 0 so that they stay valid indices.
    token_indices[~mask] = 0

    # Entries and their tokens are flattened into one dimension for `gather`.
    flat_shape = index_shape[:-1] + (index_shape[-1] * longest,)
    for key in sequence_keys:
        tokens: torch.Tensor = dataset[key]
        feature_shape = tokens.shape[index_rank:]

        flat_index = token_indices.reshape(
            flat_shape + (1,) * len(feature_shape)
        ).expand(
            flat_shape + feature_shape
        )
        minibatch[key] = tokens.gather(batch_dim, flat_index).reshape(
            index_shape + (longest,) + feature_shape
        )

    return minibatch


def get_mlp(
    config: dict,
    in_features: int,
    out_features: int,
    hidden_layer_num: Optional[int] = None,
    hidden_layer_size: Optional[int] = None,
    hidden_layer_sizes: Optional[Iterable[int]] = None,
) -> torch.nn.Sequential:
    """
    Build an MLP with ReLU activation functions.
    Can create a model ensemble.

    Parameters
    ----------
    config : `dict`
        Configuration dictionary, passed on to each `Linear` layer.
    in_features : `int`
        The number of input features.
    out_features : `int`
        The number of output features.
    hidden_layer_num : `int`, optional
        Only used if `hidden_layer_sizes` is not given:
        the number of hidden layers.
    hidden_layer_size : `int`, optional
        Only used if `hidden_layer_sizes` is not given:
        the common size of the hidden layers.
    hidden_layer_sizes: `Iterable[int]`, optional
        If given, each entry gives a hidden layer with the given size.

    Returns
    -------
    The sequential model of `Linear` layers, each but the last one
    followed by a ReLU. The hidden `Linear` layers are created with
    `init_multiplier=2 ** .5`, the output layer with the default.
    """
    if hidden_layer_sizes is None:
        hidden_layer_sizes = (hidden_layer_size,) * hidden_layer_num

    # Consecutive pairs of widths give the shapes of the hidden layers.
    widths = [in_features, *hidden_layer_sizes]

    modules = []
    for fan_in, fan_out in zip(widths, widths[1:]):
        modules.append(
            Linear(config, fan_in, fan_out, init_multiplier=2 ** .5)
        )
        modules.append(torch.nn.ReLU())

    modules.append(Linear(config, widths[-1], out_features))

    return torch.nn.Sequential(*modules)


def get_mse(
    predict: torch.Tensor,
    target: torch.Tensor
) -> torch.Tensor:
    """
    Compute the MSE between predictions and targets.
    Compatible with ensembles.

    Squared errors are summed over the last (value) dimension,
    then averaged over the batch dimension.

    Parameters
    ----------
    predict : `torch.Tensor`
        Predicted values. The expected shape is
        `ensemble_shape + (batch_size, values_dim)`.
    target : `torch.Tensor`
        Target values. The expected shape is either
        `ensemble_shape + (batch_size, values_dim)` or
        `(batch_size, values_dim)`.

    Returns
    -------
    The tensor of MSE values, of shape `ensemble_shape`.
    """
    squared_errors = F.mse_loss(
        predict,
        target.broadcast_to(predict.shape),
        reduction="none"
    )
    per_entry = squared_errors.sum(dim=-1)

    return per_entry.mean(dim=-1)


def get_output_by_batches(
    config: dict,
    dataset: dict,
    model: torch.nn.Module,
    out_features: int,
    indptr_key="indptr"
) -> torch.Tensor:
    """
    Evaluate a model on a full dataset by consecutive minibatches
    and collect the outputs in a single tensor.
    The evaluation is done without gradient tracking.

    Parameters
    ----------
    config : `dict`
        Configuration dictionary. Required key-value pairs:
        `"device"` : `str`
            The device to store the indices and the output on.
        `"ensemble_shape"` : `tuple[int]`
            Ensemble shape.
        `"minibatch_size_eval"` : `int`
            Size of consecutive minibatches to take from the dataset.
            To be set according to RAM or GPU memory capacity.
    dataset : `dict`
        The dataset to evaluate the model on.
    model : `torch.nn.Module`
        The model to evaluate. It is called on minibatch dictionaries,
        and its output is read at the `"features"` key.
    out_features : `int`
        The number of output features of the model.
    indptr_key : `str`, optional
        If the dataset has sequential entries,
        then this is the key of the index pointer tensor.
        Default: `"indptr"`.

    Returns
    -------
    The `torch.float32` tensor of model outputs of shape
    `ensemble_shape + (dataset_size, out_features)`.
    """
    ensemble_shape = config["ensemble_shape"]

    ensembled_dataset = {}
    for key, value in dataset.items():
        ensembled_dataset[key] = to_ensembled(ensemble_shape, value)

    dataset_size = get_dataset_size(
        config, ensembled_dataset, indptr_key=indptr_key
    )
    minibatch_size = config["minibatch_size_eval"]
    minibatch_num = math.ceil(dataset_size / minibatch_size)

    output = torch.empty(
        ensemble_shape + (dataset_size, out_features),
        device=config["device"],
        dtype=torch.float32
    )

    with torch.no_grad():
        for minibatch_id in range(minibatch_num):
            start = minibatch_id * minibatch_size
            # The last minibatch may be shorter than the others.
            stop = min(start + minibatch_size, dataset_size)

            entry_ids = torch.arange(
                start, stop, device=config["device"]
            )
            # Every ensemble member evaluates the same entries.
            entry_ids = entry_ids.broadcast_to(
                ensemble_shape + (entry_ids.shape[0],)
            )

            minibatch = get_minibatch(
                ensembled_dataset,
                entry_ids,
                indptr_key=indptr_key
            )
            output[..., start:stop, :] = model(minibatch)["features"]

    return output


def get_random_reshuffler(
    dataset_size: int,
    minibatch_size: int,
    device="cpu",
    ensemble_shape=()
) -> Generator[torch.Tensor]:
    """
    Yield minibatch indices for a random reshuffling dataloader,
    indefinitely. Supports arbitrary ensemble shapes.

    A fresh shuffle of the dataset indices is drawn by
    `get_shuffled_indices` each time the previous one is used up.
    The last minibatch of a shuffle is shorter than the others
    when the minibatch size does not divide the dataset size.

    Parameters
    ----------
    dataset_size : int
        The size of the dataset to yield batches of minibatch indices for.
    minibatch_size : int
        The minibatch size.
    device : int | str | torch.device, optional
        The device to store the index tensors on. Default: "cpu"
    ensemble_shape : tuple[int], optional
        The ensemble shape of the minibatch indices. Default: ()

    Returns
    -------
    A generator of index tensors, slices of the shuffled indices
    along their last dimension.
    """
    full_batches, remainder = divmod(dataset_size, minibatch_size)
    batches_per_shuffle = full_batches + min(1, remainder)

    # Starting at the end of a shuffle forces a draw at the first step.
    cursor = batches_per_shuffle
    while True:
        if cursor == batches_per_shuffle:
            permutation = get_shuffled_indices(
                dataset_size,
                device=device,
                ensemble_shape=ensemble_shape
            )
            cursor = 0

        start = cursor * minibatch_size
        yield permutation[..., start:start + minibatch_size]

        cursor += 1


def get_seed(
    upper=1 << 31
) -> int:
    """
    Draw a random integer from the `torch` PRNG,
    to be used as seed in a stochastic function.

    Parameters
    ----------
    upper : int, optional
        Exclusive upper bound of the interval to generate integers from.
        The inclusive lower bound is 0. Default: 1 << 31.

    Returns
    -------
    A random integer.
    """
    drawn = torch.randint(high=upper, size=())

    return drawn.item()


def get_shuffled_indices(
    dataset_size: int,
    device="cpu",
    ensemble_shape=(),
) -> torch.Tensor:
    """
    Get a tensor of a batch of shuffles of indices `0,...,dataset_size - 1`.

    Parameters
    ----------
    dataset_size : int
        The size of the dataset the indices of which to shuffle
    device : int | str | torch.device, optional
        The device to store the resulting tensor on. Default: "cpu"
    ensemble_shape : tuple[int], optional
        The batch shape of the shuffled index tensors. Default: ()

    Returns
    -------
    The index tensor of shape `ensemble_shape + (dataset_size,)`,
    each slice along the last dimension being a shuffle.
    """
    # Sorting independent uniform samples gives a random permutation.
    sort_keys = torch.rand(
        ensemble_shape + (dataset_size,),
        device=device
    )

    return torch.argsort(sort_keys, dim=-1)


def is_ensembled(
    ensemble_shape: tuple[int],
    tensor: torch.Tensor
) -> bool:
    """
    Check if `tensor` is *ensembled*, that is if its shape
    is prefixed by `ensemble_shape`: the first `len(ensemble_shape)`
    entries of its shape are the entries of `ensemble_shape`.

    Parameters
    ----------
    ensemble_shape : `tuple[int]`
        The ensemble shape.
    tensor : `torch.Tensor`
        The tensor to check.

    Returns
    -------
    Whether the tensor is ensembled.
    """
    prefix_length = len(ensemble_shape)
    leading_shape = tensor.shape[:prefix_length]

    return leading_shape == ensemble_shape


def line_plot_confidence_band(
    x: Sequence,
    y: torch.Tensor,
    color=None,
    confidence_level=.95,
    label="",
    opacity=.2
):
    """
    Draw the mean training curve of an ensemble
    together with a pointwise confidence band.

    The band is the confidence interval of the mean based on
    Student's t-distribution, computed separately at each time.

    Parameters
    ----------
    x : Sequence
        The sequence of time indicators (eg. number of train steps)
        when the measurements took place.
    y : torch.Tensor
        The tensor of measurements of shape `(len(x), ensemble_num)`.
    color : str | tuple[float] | None, optional
        The color of the plot. Default: `None`
    confidence_level : float, optional
        The confidence level of the confidence band. Default: 0.95
    label : str, optional
        The label of the plot. Default: ""
    opacity : float, optional
        The opacity of the confidence band, to be set via the
        `alpha` keyword argument of `plt.fill_between`. Default: 0.2
    """
    ensemble_num = y.shape[1]

    # Negating the lower tail quantile gives the positive critical value.
    lower_tail = (1 - confidence_level) / 2
    critical_value = -scipy.stats.t(ensemble_num - 1).ppf(lower_tail)

    center = y.mean(dim=-1)
    spread = y.std(dim=-1)
    half_width = critical_value * spread / ensemble_num ** .5

    plt.fill_between(
        x,
        center - half_width,
        center + half_width,
        alpha=opacity,
        color=color
    )
    plt.plot(x, center, color=color, label=label)


def load_preprocessed_dataset(
    config: dict
) -> tuple[
    tuple[torch.Tensor, torch.Tensor],
    tuple[torch.Tensor, torch.Tensor],
    tuple[torch.Tensor, torch.Tensor]
]:
    """
    Load a dataset that was saved with `torch.save`.
    We expect that the object that was saved is a dictionary with keys
    `train_features`, `train_labels`,
    `valid_features`, `valid_labels`,
    `test_features`, `test_labels`
    storing the appropriate data in tensors.

    Parameters
    ----------
    config : dict
        Configuration dictionary. Required keys:
        dataset_preprocessed_path : str
            The path where the preprocessed dataset was saved to.
        device : torch.device | int | str
            The device to move the tensors to.

    Returns
    -------
    The triple of pairs
    `(train_features, train_labels),
    (valid_features, valid_labels),
    (test_features, test_labels)`
    """
    loaded = torch.load(
        config["dataset_preprocessed_path"],
        weights_only=True
    )

    # One (features, labels) pair of keys per dataset split.
    key_pairs = (
        ("train_features", "train_labels"),
        ("valid_features", "valid_labels"),
        ("test_features", "test_labels")
    )

    return tuple(
        tuple(loaded[key].to(config["device"]) for key in key_pair)
        for key_pair in key_pairs
    )


def lsa(
    config: dict,
    training_dataset: datasets.Dataset,
    validation_datasets: Iterable[datasets.Dataset] = ()
) -> Generator[tuple[torch.Tensor, torch.Tensor]]:
    """
    Latent semantic analysis: fit a composite of a `TfidfVectorizer`
    and a `TruncatedSVD` on the corpus at the `"text"` key
    of the training dataset.
    Then use this composite to transform the training corpus
    and the optional validation corpora to feature matrices.
    Also yields the labels in the datasets as tensors.

    Parameters
    ----------
    config : dict
        Configuration dictionary. Required keys:
        "device" : torch.device
            The device to store feature matrices and label vectors on.
        "labels_dtype" : torch.dtype
            The datatype of label vectors.
        "n_components": int
            The number of dimensions to reduce the feature dimensions to
            with truncated SVD.
    training_dataset : datasets.Dataset
        The training dataset. Required keys:
        "text" : Iterable[str]
            The dataset corpus
        "label" : Iterable[int]
            The dataset labels
    validation_datasets : Iterable[datasets.Dataset], optional
        An iterable of additional datasets,
        of the same structure as `training_dataset`.
        Default: `()`.

    Returns
    -------
    A generator of pairs of feature matrices and label vectors.
    The first pair is the training data.
    Then the optional validation data follows.
    Feature matrices have datatype `torch.float32`.
    """
    def to_tensors(reduced_features, dataset):
        # Move the reduced features and the dataset labels to tensors.
        features = torch.asarray(
            reduced_features,
            device=config["device"],
            dtype=torch.float32
        )
        labels = dataset.with_format(
            "torch",
            device=config["device"]
        )["label"].to(config["labels_dtype"])

        return features, labels

    vectorizer = TfidfVectorizer()
    train_term_matrix = vectorizer.fit_transform(training_dataset["text"])

    # The SVD is seeded from the `torch` PRNG.
    reducer = TruncatedSVD(
        n_components=config["n_components"],
        random_state=get_seed()
    )
    train_reduced = reducer.fit_transform(train_term_matrix)

    yield to_tensors(train_reduced, training_dataset)

    # Validation corpora are only transformed, never fitted on.
    for validation_dataset in validation_datasets:
        valid_term_matrix = vectorizer.transform(validation_dataset["text"])
        valid_reduced = reducer.transform(valid_term_matrix)

        yield to_tensors(valid_reduced, validation_dataset)


def normalize_features(
    train_features: torch.Tensor,
    additional_features=(),
    verbose=False
):
    """
    Normalize feature tensors in place by
    1. subtracting the total mean of the training features, then
    2. dividing by the total std of the offset training features.

    Optionally, apply the same transformation to additional feature tensors,
    eg. validation and test feature tensors.

    Parameters
    ----------
    train_features : `torch.Tensor`
        Training feature tensor. Modified in place.
    additional_features : `Iterable[torch.Tensor]`, optional
        Iterable of additional features to apply the transformation to.
        The tensors are modified in place. One-shot iterables
        such as generators are supported. Default: `()`.
    verbose : `bool`, optional
        Whether to print the total mean and std
        gotten for the transformation.

    Returns
    -------
    `None`. All tensors are updated in place.
    """
    # The additional tensors are visited twice below, so a one-shot
    # iterable has to be materialized: otherwise it would be exhausted
    # by the shift and its tensors would never be scaled.
    additional_features = tuple(additional_features)

    sample_mean = train_features.mean()
    train_features -= sample_mean
    for features in additional_features:
        features -= sample_mean

    sample_std = train_features.std()
    train_features /= sample_std
    for features in additional_features:
        features /= sample_std

    if verbose:
        print(
            "Training feature tensor statistics before normalization:",
            f"mean {sample_mean.cpu().item():.4f}",
            f"std {sample_std.cpu().item():.4f}",
            flush=True
        )


def run_episode(
    config: dict,
    env: gym.Env,
    gif_name: Optional[str] = None,
    policy: Optional[Callable[[int], int]]=None,
) -> float:
    """
    Run an episode in a `gym.Env`
    with discrete observation and action spaces,
    following a policy.

    Optionally, make a gif video of the gameplay.

    Parameters
    ----------
    config : dict
        Configuration dictionary. Required key-value pairs:
        discount : float
            The discount factor of the return.
        gif_fps : int
            Frames per second in the gif. Only read if `gif_name` is given.
        videos_directory : str
            If `gif_name` is given, the created movie will be saved
            to this directory. The directory is created if missing.
    env : gym.Env
        The environment to get an episode in.
        It is reset with a seed drawn by `get_seed`.
    gif_name : str, optional
        If given, a gif movie is saved to this filename
        in `config["videos_directory"]`.
    policy : Callable[[int], int], optional
        The policy to get an episode with. Default: random policy.

    Returns
    -------
    The discounted return of the episode.
    """
    if policy is None:
        def policy(observation):
            return env.action_space.sample()

    record = gif_name is not None

    episode_return = 0
    frames = []
    step_id = 0
    observation, _ = env.reset(seed=get_seed())
    if record:
        os.makedirs(config["videos_directory"], exist_ok=True)
        frames.append(env.render())

    while True:
        action = policy(observation)
        # `Env.step` returns
        # `(observation, reward, terminated, truncated, info)`.
        observation, reward, terminated, truncated, _ = env.step(action)
        episode_return += reward * config["discount"] ** step_id
        if record:
            frames.append(env.render())

        # The episode is over both when a terminal state is reached
        # and when the environment cuts it short (eg. by a time limit).
        if terminated or truncated:
            break

        step_id += 1

    if record:
        # https://stackoverflow.com/a/64796174
        clip = ImageSequenceClip(frames, fps=config["gif_fps"])
        gif_path = os.path.join(config["videos_directory"], gif_name)
        clip.write_gif(gif_path, fps=config["gif_fps"])

    return episode_return


def pbt_init(
    config: dict,
    log: dict
):
    """
    Initialize Population Based Training
    by sampling the tuned hyperparameters.

    Parameters
    ----------
    config : `dict`
        Configuration dictionary. Required key-value pairs:
        `"ensemble_shape"` : tuple[int]
            Ensemble shape. We assume this is a 1-dimensional tuple
            with dimensions the population size.
        `"hyperparameter_raw_init_distributions"` : `dict`
            Dictionary that maps tuned hyperparameter names
            to `torch.distributions.Distribution` of raw hyperparameter values.
        `"hyperparameter_transforms"` : `dict`
            Dictionary that maps tuned hyperparameter names
            to transformations of raw hyperparameter values.
    log : `defaultdict(list)`
        Training log dictionary.

    Updates
    -------
    For each `key in config["hyperparameter_raw_init_distributions"]`:
    1. It samples raw hyperparameter values of shape prefixed by
        `config["ensemble_shape"]` and stores them at `config[key + "_raw"]`.
    2. It applies `config["hyperparameter_transforms"][key]`
        to the raw hyperparameter values and
        1. stores the result at `config[key]` and
        2. appends it to `log[key]`.
    """
    init_distributions = config["hyperparameter_raw_init_distributions"]

    for name in init_distributions:
        raw_values = init_distributions[name].sample(
            config["ensemble_shape"]
        )
        config[name + "_raw"] = raw_values

        transform = config["hyperparameter_transforms"][name]
        values = transform(raw_values)

        config[name] = values
        log[name].append(values)


def pbt_update(
    config: dict,
    evaluations: torch.Tensor,
    log: dict,
    parameters: Iterable[torch.nn.Parameter]
):
    """
    Performs a Population Based Training update
    with exploitation determined by one-sided Welch's t-tests.

    Parameters
    ----------
    config : `dict`
        Configuration dictionary. Required key-value pairs:
        `"ensemble_shape"` : tuple[int]
            Ensemble shape. We assume this is a 1-dimensional tuple
            with dimensions the population size.
        `"hyperparameter_raw_perturbs"` : `dict`
            Dictionary that maps tuned hyperparameter names
            to `torch.distributions.Distribution` of additive noise.
            For backward compatibility, the dictionary is read from the
            `"hyperparameter_raw_perturb"` key if this key is missing.
        `"hyperparameter_transforms"` : `dict`
            Dictionary that maps tuned hyperparameter names
            to transformations of raw hyperparameter values.
        `"welch_confidence_level"` : `float`
            The confidence level in Welch's t-test
            that is used in determining if a population member
            is to be replaced by another member with perturbed hyperparameters.
    evaluations : `torch.Tensor`
        Tensor of evaluations. We assume it has shape
        `(welch_sample_size, population_size)`.
    log : `defaultdict(list)`
        Training log dictionary.
    parameters : `Iterable[torch.nn.Parameter]`
        Iterable of parameters to update. We assume that
        their first dimension is the population dimension.

    Updates
    -------
    1. For each population entry, a target index is drawn,
        the index of another population entry to compare evaluations with.
    2. The entries are then compared with the target entries
        by `welch_one_sided`, which gives the mask
        of population entries to replace.
    3. The masks and the indices are appended to the
        `"source mask"` and `"target indices"` lists of `log`.

    If at least one entry is to be replaced:

    4. For each parameter in `parameters`:
        We replace the masked subtensors by the
        corresponding entries at the target indices.
        This is done in place, without gradient tracking.

    5. For each tuned hyperparameter, name `key`:
        we replace the masked entries
        by perturbed corresponding target values:
        to the appropriate values at `config[key + "_raw"]`,
        we add noise sampled from the perturbation distribution of `key`,
        then transform them by
        `config["hyperparameter_transforms"][key]`.

        We update `config[key + "_raw"]` in place, replace `config[key]`
        and append the new hyperparameter values to `log[key]`.
    """
    population_size = config["ensemble_shape"][0]

    target_indices = torch.randint(
        device=evaluations.device,
        high=population_size,
        size=(population_size,)
    )
    source_mask = welch_one_sided(
        evaluations,
        evaluations[:, target_indices],
        confidence_level=config["welch_confidence_level"]
    )
    log["source mask"].append(source_mask)
    log["target indices"].append(target_indices)

    if not source_mask.any():
        return

    # For each entry to replace, the index of the entry to copy from.
    replacement_indices = target_indices[source_mask]

    # In-place updates of leaf tensors that require gradients
    # are only allowed without gradient tracking.
    with torch.no_grad():
        for parameter in parameters:
            parameter[source_mask] = parameter[replacement_indices]

    transforms = config["hyperparameter_transforms"]
    if not transforms:
        return

    # The documented key is `"hyperparameter_raw_perturbs"`,
    # while earlier versions read `"hyperparameter_raw_perturb"`:
    # accept both so that existing configurations keep working.
    if "hyperparameter_raw_perturbs" in config:
        perturbs = config["hyperparameter_raw_perturbs"]
    else:
        perturbs = config["hyperparameter_raw_perturb"]

    replaced_num = int(source_mask.sum())

    for name, transform in transforms.items():
        value_raw: torch.Tensor = config[name + "_raw"]

        additive_noise = perturbs[name].sample((replaced_num,))
        # Indexing by an index tensor copies,
        # so the in-place assignment below is safe.
        perturbed_values = value_raw[replacement_indices] + additive_noise
        value_raw[source_mask] = perturbed_values

        value = transform(value_raw)
        config[name] = value
        log[name].append(value)


def to_ensembled(
    ensemble_shape: tuple[int],
    tensor: torch.Tensor
) -> torch.Tensor:
    """
    Convert a tensor to an *ensembled* tensor,
    that is one with shape starting by the ensemble shape.

    Parameters
    ----------
    ensemble_shape : `tuple[int]`
        The ensemble shape.
    tensor : `torch.Tensor`
        The tensor to convert.

    Returns
    -------
    The tensor itself if it is ensembled according to `is_ensembled`.
    Otherwise, the tensor broadcast to the shape
    `ensemble_shape + tensor.shape`.
    """
    if not is_ensembled(ensemble_shape, tensor):
        tensor = tensor.broadcast_to(ensemble_shape + tensor.shape)

    return tensor


def train_logistic_regression(
    config: dict,
    get_loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    get_metric: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    out_features: int,
    train_dataloader: Generator[tuple[torch.Tensor, torch.Tensor]],
    valid_features: torch.Tensor,
    valid_labels: torch.Tensor,
    loss_name="loss",
    metric_name="metric",
    use_bias=True
) -> dict:
    """
    Train a logistic regression model on a classification task.
    Support model ensembles of arbitrary shape.

    Training is plain gradient descent on zero-initialized parameters.
    Every `valid_interval` steps the model is evaluated on the validation
    data, and per ensemble member the parameters with the best validation
    metric so far are kept. Termination is only checked at evaluations,
    so training may run past `steps_num` up to the next evaluation.

    Parameters
    ----------
    config : dict
        Configuration dictionary. Required keys:
        ensemble_shape : tuple[int]
            The shape of the model ensemble.
        improvement_threshold : float
            Making the best validation score this much better
            counts as an improvement.
        learning_rate : float | torch.Tensor
            The learning rate of the SGD optimization.
            If a tensor, then it should have shape
            broadcastable to `ensemble_shape`.
            In that case, the members of the ensemble are trained with
            different learning rates.
        steps_num : int
            The maximum number of training steps to take.
        steps_without_improvement : int
            The maximum number of training steps without improvement to take.
        valid_interval : int
            The frequency of evaluations,
            measured in the number of train steps.
    get_loss : Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
        Maps a pair of logit and label tensors
        to a tensor of losses per ensemble member.
    get_metric : Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
        Maps a pair of logit and label tensors
        to a tensor of metrics per ensemble member.
        A greater metric is considered better.
    out_features : int
        The number of output features.
        When training a binary logistic regression model, this should be 1.
        Otherwise, this should be
        the number of distinct labels in the classification task.
    train_dataloader : Generator[tuple[torch.Tensor, torch.Tensor]]
        A training minibatch dataloader, that yields pairs of
        feature and label tensors indefinitely.
        We assume that these have shape
        `ensemble_shape + (minibatch_size, feature_dim)`
        and `ensemble_shape + (minibatch_size,)`
        respectively.
    valid_features : torch.Tensor
        Validation feature matrix.
    valid_labels : torch.Tensor
        Validation label vector.
    loss_name : str, optional
        The name of the loss values in the output dictionary.
        Default: "loss"
    metric_name : str, optional
        The name of the metric values in the output dictionary.
        Default: "metric"
    use_bias : bool, optional
        Whether to use a bias vector in the logistic regression model.
        Default: `True`

    Returns
    -------
    An output dictionary with the following keys:
        best scores : torch.Tensor
            The best validation metric per each ensemble member
        best weights : torch.Tensor
            The logistic regression weights
            that were the best per each ensemble member.
            Members that never improved keep the initial (zero) weights.
        training {metric_name} : torch.Tensor
            The tensor of training metric values, of shape
            `(evaluation_num,) + ensemble_shape`.
        training {loss_name} : torch.Tensor
            The tensor of training loss values, of shape
            `(evaluation_num,) + ensemble_shape`.
        training steps : list[int]
            The list of the number of training steps at each evaluation.
        validation {metric_name} : torch.Tensor
            The tensor of validation metric values, of shape
            `(evaluation_num,) + ensemble_shape`.
        validation {loss_name} : torch.Tensor
            The tensor of validation loss values, of shape
            `(evaluation_num,) + ensemble_shape`.
        best bias : torch.Tensor, optional
            The logistic regression biases
            that were the best per each ensemble member, if used.
            Also available under the legacy key `best_bias`.
    """
    device = valid_features.device
    features_dtype = valid_features.dtype
    ensemble_shape = config["ensemble_shape"]
    output = defaultdict(list)
    history_keys = (
        f"training {metric_name}",
        f"training {loss_name}",
        f"validation {metric_name}",
        f"validation {loss_name}"
    )

    # log(0) = -inf, so any finite first validation score is an improvement.
    best_scores = torch.zeros(
        ensemble_shape,
        device=device,
        dtype=features_dtype
    ).log()
    steps_without_improvement = 0

    if isinstance(config["learning_rate"], torch.Tensor):
        # Two trailing singleton dimensions let per-member learning rates
        # broadcast against the parameter matrices.
        learning_rate = config["learning_rate"][..., None, None]
    else:
        learning_rate = config["learning_rate"]

    # Running sums over the minibatches since the last evaluation.
    train_accuracies_step = torch.zeros(
        ensemble_shape,
        device=device,
        dtype=features_dtype
    )
    train_entries = 0
    train_losses_step = torch.zeros(
        ensemble_shape,
        device=device,
        dtype=features_dtype
    )

    progress_bar = tqdm.trange(config["steps_num"])
    step_id = 0
    weights = torch.zeros(
        ensemble_shape + (valid_features.shape[1], out_features),
        device=device,
        dtype=features_dtype,
        requires_grad=True
    )

    # Snapshots start from the initial parameters instead of
    # uninitialized memory, so members that never improve
    # still return well-defined values.
    best_weights = weights.detach().clone()

    if use_bias:
        bias = torch.zeros_like(weights[..., 0:1, :], requires_grad=True)
        best_bias = bias.detach().clone()

    for minibatch_features, minibatch_labels in train_dataloader:
        minibatch_size = minibatch_labels.shape[-1]
        weights.grad = None
        if use_bias:
            bias.grad = None

        logits = minibatch_features @ weights
        if use_bias:
            logits = logits + bias

        train_accuracies_step += get_metric(
            logits.detach(),
            minibatch_labels
        ) * minibatch_size
        loss = get_loss(
            logits,
            minibatch_labels
        )
        # Summing over the ensemble keeps the members' gradients independent.
        loss.sum().backward()
        with torch.no_grad():
            weights -= learning_rate * weights.grad
            if use_bias:
                bias -= learning_rate * bias.grad

        train_losses_step += loss.detach() * minibatch_size
        train_entries += minibatch_size

        progress_bar.update()
        step_id += 1
        if step_id % config["valid_interval"] != 0:
            continue

        with torch.no_grad():
            logits = valid_features @ weights
            if use_bias:
                logits = logits + bias

            valid_accuracy = get_metric(
                logits,
                valid_labels
            )
            valid_loss = get_loss(
                logits,
                valid_labels
            )

        output[f"training {metric_name}"].append(
            (train_accuracies_step / train_entries)
        )
        output[f"training {loss_name}"].append(
            (train_losses_step / train_entries)
        )
        output["training steps"].append(step_id)
        output[f"validation {metric_name}"].append(valid_accuracy)
        output[f"validation {loss_name}"].append(valid_loss)

        train_accuracies_step.zero_()
        train_entries = 0
        train_losses_step.zero_()

        improvement = valid_accuracy - best_scores
        improvement_mask = improvement > config["improvement_threshold"]

        if improvement_mask.any():
            # Snapshot without autograd tracking, so the stored tensors
            # do not acquire a graph that references the live parameters.
            with torch.no_grad():
                best_scores[improvement_mask] \
                    = valid_accuracy[improvement_mask]
                best_weights[improvement_mask] = weights[improvement_mask]
                if use_bias:
                    best_bias[improvement_mask] = bias[improvement_mask]
            steps_without_improvement = 0
        else:
            steps_without_improvement += config["valid_interval"]

        if (
            step_id >= config["steps_num"]
            or (
                steps_without_improvement
                >= config["steps_without_improvement"]
            )
        ):
            break

    # Reached on regular termination and also if the dataloader runs out,
    # so the caller always receives the output dictionary.
    progress_bar.close()

    for key in history_keys:
        if output[key]:
            output[key] = torch.stack(output[key]).cpu()
        else:
            # No evaluation took place: return an empty history.
            output[key] = torch.empty(
                (0,) + tuple(ensemble_shape),
                dtype=features_dtype
            )

    output["best scores"] = best_scores
    output["best weights"] = best_weights
    if use_bias:
        # "best_bias" is the key historically emitted;
        # "best bias" is the documented one. Both are provided.
        output["best_bias"] = best_bias
        output["best bias"] = best_bias

    return output
            

def train_supervised(
    config: dict,
    dataset_train: dict,
    dataset_valid: dict,
    get_loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    get_metric: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    model: torch.nn.Module,
    optimizer: Optimizer,
    out_features: Optional[int] = None,
    target_key="target"
) -> dict:
    """
    Population-based training on a supervised learning task.
    Tuned hyperparameters are given by raw values and transformations.
    This way, the hyperparameters are perturbed by
    additive noise on raw values.

    The model is evaluated every `valid_interval` steps,
    starting with step 0, that is before the first training step.

    Parameters
    ----------
    config : `dict`
        Configuration dictionary. Required key-value pairs:
        `"ensemble_shape"` : tuple[int]
            Ensemble shape. We assume this is a 1-dimensional tuple
            with dimensions the population size.
        `"hyperparameter_raw_init_distributions"` : `dict`
            Dictionary that maps tuned hyperparameter names
            to `torch.distributions.Distribution` of raw hyperparameter values.
            Required keys:
            `"learning_rate"`:
                The learning rate of stochastic gradient descent.
        `"hyperparameter_raw_perturb"` : `dict`
            Dictionary that maps tuned hyperparameter names
            to `torch.distributions.Distribution` of additive noise.
        `"hyperparameter_transforms"` : `dict`
            Dictionary that maps tuned hyperparameter names
            to transformations of raw hyperparameter values.
        `"improvement_threshold"` : `float`
            A new metric score has to be this much better
            than the previous best to count as an improvement.
        `"minibatch_size"` : `int`
            Minibatch size to use in a training step.
        `"minibatch_size_eval"` : `int`
            Minibatch size to use in evaluation.
            On CPU, should be about the same as `minibatch_size`.
            On GPU, should be as big as possible without
            incurring an Out of Memory error.
        `"pbt"` : `bool`
            Whether to use PBT updates in validations.
            If `False`, the algorithm just samples hyperparameters at start,
            then keeps them constant.
        `"steps_num"` : `int`
            Maximum number of training steps.
        `"steps_without_improvement"` : `int`
            If the number of training steps without improvement
            exceeds this value, then training is stopped.
        `"valid_interval"` : `int`
            Frequency of evaluations, measured in number of training steps.
        `"welch_confidence_level"` : `float`
            The confidence level in Welch's t-test
            that is used in determining if a population member
            is to be replaced by another member with perturbed hyperparameters.
        `"welch_sample_size"` : `int`
            The last this many validation metrics are used
            in Welch's t-test.
    dataset_train : `dict`
        The dataset to train the model on.
    dataset_valid : `dict`
        The dataset to evaluate the model on.
    get_loss : `Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`
        A function that maps a pair of predicted and target value tensors
        to a tensor of losses per ensemble member.
    get_metric : `Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`
        A function that maps a pair of predicted and target value tensors
        to a tensor of metrics per ensemble member.
        We assume a greater metric is better.
    model : `torch.nn.Module`
        The model ensemble to tune.
    optimizer : `Optimizer`
        An optimizer that tracks the parameters of `model`.
    out_features: `int`, optional
        The number of output features in the predict tensors.
        It is passed on to `evaluate_model`.
    target_key : `str`, optional
        The key mapped to the target value tensor in the dataset.
        Default: `"target"`

    Returns
    -------
    An output dictionary with the following key-value pairs:
        `"best parameters"` : `dict`
            The state dictionary of the model with the best metric
            encountered during training. Absent if no evaluation
            ever counted as an improvement.
        `"source mask"` : `torch.Tensor`
            The masks of population members
            that were replaced by other members in a PBT update.
        `"target indices"` : `torch.Tensor`
            The indices of the population members
            that the masked members were replaced with.
        `"validation metric"` : `torch.Tensor`
            The validation metrics at evaluation steps.

        In addition, for each tuned hyperparameter name,
        we include a `torch.Tensor` of values per update.

    Raises
    ------
    ValueError
        If the ensemble shape is not 1-dimensional.
    """
    ensemble_shape = config["ensemble_shape"]
    if len(ensemble_shape) != 1:
        raise ValueError(f"The number of dimensions in the ensemble shape should be 1 for the  population size, but it is {len(ensemble_shape)}")

    # Tuned hyperparameter values are written to a local copy,
    # so the caller's dictionary is not rebound key by key.
    config_local = dict(config)
    log = defaultdict(list)

    pbt_init(config_local, log)

    optimizer.update_config(config_local)
    update_model(config_local, model)

    # Starts as a Python float and becomes a 0-dimensional tensor
    # after the first improvement.
    best_valid_metric = -torch.inf
    progress_bar = tqdm.trange(config["steps_num"])
    steps_without_improvement = 0
    train_dataloader = get_dataloader_random_reshuffle(
        config,
        dataset_train
    )

    for step_id in progress_bar:
        if step_id % config["valid_interval"] == 0:
            model.eval()
            with torch.no_grad():
                # NaN metrics are mapped to -inf, so they never count as best.
                validation_metric = evaluate_model(
                    config,
                    dataset_valid,
                    get_metric,
                    model,
                    out_features=out_features,
                    target_key=target_key
                ).nan_to_num(-torch.inf)
                log["validation metric"].append(validation_metric)
                print(
                    f"validation metric {validation_metric.max().cpu().item():.4f}"
                )

                best_last_metric, best_last_metric_id \
                    = log["validation metric"][-1].max(dim=-1)
                print(
                    f"Best last metric {best_last_metric.cpu().item():.2f}",
                    flush=True
                )
                if (
                    best_valid_metric + config["improvement_threshold"]
                ) < best_last_metric:
                    print(
                        "New best metric",
                        flush=True
                    )
                    best_valid_metric = best_last_metric
                    steps_without_improvement = 0
                    # Keep the slice of every state entry
                    # that belongs to the best population member.
                    log["best parameters"] = {
                        key: value[best_last_metric_id].clone()
                        for key, value in model.state_dict().items()
                    }
                else:
                    # `float` accepts both the initial Python float
                    # and a 0-dimensional tensor on any device.
                    print(
                        f"Best metric {float(best_valid_metric):.2f}",
                        flush=True
                    )
                    steps_without_improvement += config["valid_interval"]
                    if steps_without_improvement > config[
                        "steps_without_improvement"
                    ]:
                        break

                if config["pbt"] and (len(log["validation metric"]) >= config[
                    "welch_sample_size"
                ]):
                    evaluations = torch.stack(
                        log["validation metric"][-config["welch_sample_size"]:]
                    )
                    pbt_update(
                        config_local, evaluations, log, optimizer.get_parameters()
                    )

                    # Propagate the possibly changed hyperparameters.
                    update_model(config_local, model)
                    optimizer.update_config(config_local)

        model.train()

        minibatch = next(train_dataloader)
        optimizer.zero_grad()

        predict = model(minibatch)["features"]
        target = minibatch[target_key]

        # Summing over the population keeps the members' gradients independent.
        loss = get_loss(predict, target).sum()
        loss.backward()
        optimizer.step()

    progress_bar.close()
    # Per-update lists become tensors; "best parameters" is a dict
    # and is left as it is.
    for key, value in log.items():
        if isinstance(value, list):
            log[key] = torch.stack(value)

    return log


def update_model(
    config: dict,
    model: torch.nn.Module
):
    """
    Push configuration values into a model.

    Every module yielded by `model.modules()` that has a `config`
    attribute gets that attribute updated in place
    with the entries of the given dictionary.
    Modules without a `config` attribute are left untouched.

    Parameters
    ----------
    config : `dict`
        The configuration entries to write.
    model : `torch.nn.Module`
        The model whose modules are updated.
    """
    configurable_modules = (
        submodule
        for submodule in model.modules()
        if hasattr(submodule, "config")
    )
    for submodule in configurable_modules:
        submodule.config.update(config)
            

def welch_one_sided(
    source: torch.Tensor,
    target: torch.Tensor,
    confidence_level=.95
) -> torch.Tensor:
    """
    Performs a one-sided Welch's t-test that checks whether
    the expected value of the random variable
    the target tensor collects samples of
    is larger than the expected value of the random variable
    the source tensor collects samples of.

    In the tensors, dimensions after the first
    are considered batch dimensions.

    If both samples of a batch entry have zero variance,
    the t statistic is undefined. In that case we use its limit:
    the entry is `True` exactly when the target mean
    is strictly greater than the source mean.

    Parameters
    ----------
    source : `torch.Tensor`
        Source sample, of shape `(sample_size,) + batch_shape`.
    target : `torch.Tensor`
        Target sample, of shape `(sample_size,) + batch_shape`.
    confidence_level : `float`, optional
        Confidence level of the test. Default: `.95`.

    Returns
    -------
    A Boolean tensor of shape `batch_shape` that is `True`
    where the cumulative distribution function of Student's t,
    evaluated at the Welch statistic, exceeds `confidence_level`,
    that is where the target mean is significantly greater
    than the source mean. It is `False` elsewhere, including
    where the statistic is NaN.
    """
    # Imported explicitly, as `import scipy` does not load
    # the `scipy.stats` submodule on every SciPy version.
    from scipy import stats

    sample_num = len(source)
    source_sample_mean = source.mean(dim=0)
    target_sample_mean = target.mean(dim=0)
    source_sample_var = source.var(dim=0)
    target_sample_var = target.var(dim=0)

    mean_gap = target_sample_mean - source_sample_mean
    var_sum = source_sample_var + target_sample_var

    # Variances are nonnegative, so their sum vanishes
    # exactly when both of them do.
    degenerate = var_sum == 0
    ones = torch.ones_like(var_sum)
    # Placeholder values keep the arithmetic below finite
    # at degenerate entries; their results are discarded at the end.
    safe_var_sum = torch.where(degenerate, ones, var_sum)
    safe_square_sum = torch.where(
        degenerate,
        ones,
        source_sample_var ** 2 + target_sample_var ** 2
    )

    statistic = mean_gap * (sample_num / safe_var_sum).sqrt()

    # Welch-Satterthwaite degrees of freedom for equal sample sizes.
    degrees_of_freedom = (
        safe_var_sum.square()
      * (sample_num - 1)
      / safe_square_sum
    )

    p = stats.t(
        degrees_of_freedom.detach().cpu().numpy()
    ).cdf(
        statistic.detach().cpu().numpy()
    )

    significant = torch.asarray(
        p > confidence_level,
        device=source.device
    )

    return torch.where(degenerate, mean_gap > 0, significant)