import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import matplotlib.pyplot as plt
import numpy as np

# Width of the timestep embedding vector. It sizes the `nn.Embedding` used for
# timesteps and is added to the data width (3) to form the input size of the
# first layer built by `get_model`.
t_emb_dim = 8


def get_model():
    """Build the fully-connected network used as the denoiser.

    The input width is ``3 + t_emb_dim`` (data features concatenated with the
    timestep embedding) and the output width is 3. Hidden widths are
    64 -> 128 -> 64 -> 32, with a ReLU after the second and fourth linear
    layers only. Edit ``widths`` / ``relu_after`` to change model capacity.

    Returns:
        nn.Sequential: the freshly initialised network.
    """
    widths = [3 + t_emb_dim, 64, 128, 64, 32, 3]
    # Zero-based positions (among the linear layers) that are followed by ReLU.
    relu_after = {1, 3}

    layers = []
    for position, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
        layers.append(nn.Linear(fan_in, fan_out))
        if position in relu_after:
            layers.append(nn.ReLU())
    return nn.Sequential(*layers)


def zeros_like(x, a, t):
    """Return the schedule value(s) ``a[t]`` as a ``(x.shape[0], 1)`` column.

    Despite the name, the result is not all zeros: a zero column with one row
    per sample of ``x`` is created and ``a[t]`` is added to it, so the output
    can be broadcast against ``x`` along the feature axis.

    Args:
        x: Batch tensor; only its first dimension and device are used.
        a: 1-D tensor (or sequence) of per-timestep values.
        t: Timestep index. Either a scalar (Python int, NumPy integer or
            0-dim tensor), in which case the same value is used for every
            row, or a 1-D integer tensor with one index per row.

    Returns:
        Tensor of shape ``(x.shape[0], 1)`` located on ``x``'s device.
    """
    a = torch.as_tensor(a)
    # Index on the device where the schedule lives, then move the (small)
    # result next to `x`; this avoids cross-device indexing errors.
    index = t.to(a.device) if isinstance(t, torch.Tensor) else t
    coeff = a[index]
    if coeff.dim() > 0:
        # One value per sample: make it a column so it broadcasts over features.
        coeff = coeff[:, None]
    base = torch.zeros((x.shape[0], 1), device=x.device)
    return base + coeff.to(x.device)


class LitDiffusionModel(pl.LightningModule):
    """DDPM-style diffusion model that learns to predict the injected noise.

    A learnable timestep embedding is concatenated with the noisy sample and
    fed to the MLP returned by `get_model`; the network output is trained to
    match the Gaussian noise that was mixed into the clean sample.

    NOTE: `get_model` hard-codes 3 data features for both its input and its
    output, so `n_dim` values other than 3 are not supported by that network.
    """

    def __init__(self, n_dim=3, n_steps=200, lbeta=1e-5, ubeta=1e-2, scheduler_type="linear"):
        """Create the networks and the noise schedule.

        Args:
            n_dim: Number of features per data sample.
            n_steps: Number of diffusion timesteps.
            lbeta: Lower bound used to build the beta schedule.
            ubeta: Upper bound used to build the beta schedule.
            scheduler_type: One of "linear", "sigmoid", "cubic", "cosine",
                "scaled_linear".

        If more hyperparameters are added here, also expose them through the
        training script's argument parser and check they reach `hparams.yaml`.
        """
        super().__init__()
        # Records all constructor arguments so they are stored with checkpoints.
        self.save_hyperparameters()

        # Learnable lookup table: one embedding vector per timestep.
        self.time_embed = nn.Embedding(n_steps, t_emb_dim)
        self.model = get_model()

        self.n_steps = n_steps
        self.n_dim = n_dim
        self.scheduler_type = scheduler_type

        self.init_alpha_beta_schedule(lbeta, ubeta)

    @staticmethod
    def _coeff(values, t, x):
        """Return ``values[t]`` as a column tensor on ``x``'s device.

        ``t`` may be a Python/NumPy integer or an integer tensor. The result
        has shape ``(1, 1)`` for a scalar ``t`` and ``(len(t), 1)`` otherwise,
        so it broadcasts against a ``(batch, n_dim)`` tensor.
        """
        index = torch.as_tensor(t, dtype=torch.long, device=values.device).reshape(-1)
        return values[index].reshape(-1, 1).to(x.device)

    def forward(self, x, t):
        """Predict the noise contained in ``x`` at timestep(s) ``t``.

        Args:
            x: Tensor of shape ``(batch, n_dim)``.
            t: Either a scalar timestep shared by the whole batch or an
                integer tensor with one timestep per sample.

        Returns:
            Tensor produced by the MLP for the concatenation of ``x`` and the
            timestep embedding.
        """
        if not isinstance(t, torch.Tensor):
            # Same timestep for every row, created directly on x's device.
            t = torch.full((x.size(0),), int(t), dtype=torch.long, device=x.device)
        else:
            t = t.to(x.device)
            if t.dim() == 0:
                t = t.expand(x.size(0))
        t_emb = self.time_embed(t)
        return self.model(torch.cat((x, t_emb), dim=1).float())

    def init_alpha_beta_schedule(self, lbeta, ubeta):
        """Build the beta schedule and the cumulative product of ``1 - beta``.

        Sets ``self.betas`` and ``self.alphas_cum`` (each of length
        ``n_steps``). Both are registered as non-persistent buffers: they
        follow the module across devices but are not written to the
        ``state_dict``, so checkpoints keep the same keys as before.

        Raises:
            ValueError: if ``self.scheduler_type`` is not a known schedule.
        """
        n_steps = self.n_steps
        kind = self.scheduler_type

        if kind == "linear":
            betas = torch.linspace(lbeta, ubeta, n_steps)
        elif kind == "sigmoid":
            ramp = torch.sigmoid(torch.linspace(-6, 6, n_steps))
            betas = ramp * (ubeta - lbeta) + lbeta
        elif kind == "cubic":
            # Interpolate linearly in cube-root space so the schedule really
            # spans [lbeta, ubeta]; multiplying the bounds by 0.33 before
            # cubing collapsed every beta to roughly 1e-8 or less.
            betas = torch.linspace(lbeta ** (1 / 3), ubeta ** (1 / 3), n_steps) ** 3
        elif kind == "cosine":
            grid = torch.linspace(-1, 1, n_steps)
            betas = (torch.cos(grid * np.pi / 2) + 0.008) / 1.008 * (ubeta - lbeta) + lbeta
        elif kind == "scaled_linear":
            betas = torch.linspace(lbeta ** 0.5, ubeta ** 0.5, n_steps) ** 2
        else:
            raise ValueError(f"Unknown scheduler_type: {kind!r}")

        self.register_buffer("betas", betas, persistent=False)
        # alpha-bar_t = prod_{s<=t} (1 - beta_s). Stored WITHOUT a square
        # root: every consumer below applies sqrt(alpha-bar) or
        # sqrt(1 - alpha-bar) itself, so a root here would be applied twice.
        self.register_buffer("alphas_cum", torch.cumprod(1 - betas, 0), persistent=False)

    def q_sample(self, x, t):
        """Draw a noised sample x_t from the forward process given clean ``x``.

        Computes ``sqrt(alpha-bar_t) * x + sqrt(1 - alpha-bar_t) * eps`` with
        ``eps`` standard normal, i.e. the same mixing used in `training_step`.
        """
        mean_scale = self._coeff(torch.sqrt(self.alphas_cum), t, x)
        noise_scale = self._coeff(torch.sqrt(1 - self.alphas_cum), t, x)
        return mean_scale * x + noise_scale * torch.randn_like(x)

    def p_sample(self, x, t):
        """Perform one reverse (denoising) step from ``x`` at timestep ``t``.

        The mean is ``(x - beta_t / sqrt(1 - alpha-bar_t) * eps_theta) /
        sqrt(1 - beta_t)`` and the added noise has standard deviation
        ``sqrt(beta_t)``. No noise is added when ``t == 0`` (the last reverse
        step), following Algorithm 2 of the DDPM paper.
        """
        beta_t = self._coeff(self.betas, t, x)
        eps_scale = beta_t / self._coeff(torch.sqrt(1 - self.alphas_cum), t, x)
        eps_theta = self.forward(x, t)
        mean = (x - eps_scale * eps_theta) / torch.sqrt(1 - beta_t)

        z = torch.randn_like(x)
        # 1 where t > 0, 0 where t == 0: the final step returns the mean only.
        add_noise = (torch.as_tensor(t, device=x.device).reshape(-1, 1) > 0).to(mean.dtype)
        return mean + add_noise * torch.sqrt(beta_t) * z

    def training_step(self, batch, batch_idx):
        """Compute the noise-prediction loss for one batch.

        A random timestep is drawn per sample, the batch is noised with the
        closed-form forward process, and the mean squared error between the
        true noise and the network prediction is returned so that Lightning
        can back-propagate it.

        Args:
            batch: Tensor of shape ``(n_samples, n_dim)`` of clean data.
            batch_idx: Index of the batch (unused).

        References:
            DDPM paper: https://arxiv.org/abs/2006.11239
        """
        t = torch.randint(0, self.n_steps, size=(batch.shape[0],), device=batch.device)
        signal_scale = self._coeff(torch.sqrt(self.alphas_cum), t, batch)
        noise_scale = self._coeff(torch.sqrt(1 - self.alphas_cum), t, batch)
        noise = torch.randn_like(batch)
        noisy = batch * signal_scale + noise * noise_scale
        predicted_noise = self.forward(noisy, t)
        return (noise - predicted_noise).square().mean()

    def sample(self, n_samples, progress=False, return_intermediate=False):
        """Generate samples by running the full reverse diffusion chain.

        Args:
            n_samples: Number of samples to generate.
            progress: If True, print the current reverse step while sampling.
            return_intermediate: If True, also return every state of the chain.

        Returns:
            If ``return_intermediate`` is False, a tensor of shape
            ``(n_samples, n_dim)``. Otherwise a tuple ``(final, states)``
            where ``states`` is a list holding the initial noise followed by
            the output of each of the ``n_steps`` reverse steps
            (``n_steps + 1`` tensors in total).
        """
        # Sampling is inference only: skip autograd so no graph is kept
        # across the n_steps reverse steps.
        with torch.no_grad():
            x_sample = torch.randn((n_samples, self.n_dim), device=self.device)
            samples = [x_sample]
            for t in range(self.n_steps - 1, -1, -1):
                if progress:
                    print(f"\rsampling step {self.n_steps - t}/{self.n_steps}", end="", flush=True)
                x_sample = self.p_sample(x_sample, t)
                if return_intermediate:
                    samples.append(x_sample)
            if progress:
                print()

        if return_intermediate:
            return x_sample, samples
        return x_sample

    def configure_optimizers(self):
        """Return the optimizer used by Lightning for training.

        All module parameters are optimised, i.e. both the denoising MLP and
        the timestep embedding table (passing only ``self.model.parameters()``
        would leave the embedding at its random initialisation).
        """
        return torch.optim.Adam(self.parameters(), lr=1e-3)
