import numpy as np
import torch
from logger import Logger
# Module-level logger; Loss.content_loss reports computation failures through it.
# NOTE(review): assumes Logger().logger() returns a stdlib-style logging.Logger
# (supports .error(msg, *args)) — confirm against the project's logger module.
LOGGER = Logger().logger()

class Loss:
    """Loss terms for neural style transfer, computed on torch tensors.

    All methods are static; the class only serves as a namespace.
    """

    @staticmethod
    def content_loss(F, P):
        '''
        Compute the content loss between two feature representations at one layer.

        The loss is the mean of the squared element-wise difference,
        ``torch.mean((F - P) ** 2)``.
        @params
        F: torch.Tensor - feature map of one image
        P: torch.Tensor - feature map to compare against; same shape as F
           (or broadcastable with it)
        @returns
        0-dim torch.Tensor holding the mean squared difference
        @raises
        Re-raises any error from the tensor ops (e.g. RuntimeError for
        incompatible shapes) after logging it.
        Author: @meetdoshi
        '''
        try:
            return torch.mean((F - P) ** 2)
        except Exception as e:
            # The message needs a %s placeholder: passing ``e`` as an extra
            # argument without one breaks the logger's %-formatting.
            LOGGER.error("Error computing loss: %s", e)
            # Propagate instead of returning None; a None result would only
            # fail later, and less clearly, when total_loss adds it up.
            raise

    @staticmethod
    def gram_matrix(F):
        '''
        Compute the (unnormalized) Gram matrix of a feature representation at a layer.

        @params
        F: torch.Tensor of shape (1, C, H, W). The batch dimension must be 1,
           because the tensor is reshaped to (C, H*W). Any other batch size
           makes the reshape fail.
        @returns
        torch.Tensor of shape (C, C) equal to G @ G.T, where G is F flattened
        to (C, H*W).
        Author: @himalisaini
        '''
        num_channels = F.shape[1]
        height = F.shape[2]
        width = F.shape[3]
        # reshape (rather than view) also accepts non-contiguous inputs such
        # as transposed/permuted feature maps; it still returns a view when
        # the memory layout allows, so contiguous inputs are not copied.
        features = F.reshape(num_channels, height * width)
        return torch.mm(features, features.t())

    @staticmethod
    def style_loss(F, A):
        '''
        Compute the style loss between two feature representations at one layer.

        The loss is the mean squared difference between the Gram matrices of
        F and A. No 1/(4*C^2*(H*W)^2) normalization is applied. The only
        scaling is the mean over the C*C Gram entries.
        @params
        F: torch.Tensor of shape (1, C, H, W)
        A: torch.Tensor of shape (1, C, H', W'). It must have the same channel
           count C as F; the spatial size may differ because Gram matrices
           are (C, C).
        @returns
        0-dim torch.Tensor
        Author: @soumyagupta
        '''
        style_gram_matrix = Loss.gram_matrix(F)
        target_gram_matrix = Loss.gram_matrix(A)
        return torch.mean((style_gram_matrix - target_gram_matrix) ** 2)

    @staticmethod
    def total_loss(alpha, beta, cont_fmap_real, style_fmap_real,
                   generated_fmaps_content, generated_fmaps_style):
        '''
        Compute the weighted total loss ``alpha * content + beta * style``.

        @params
        alpha: weight applied to the summed content loss
        beta: weight applied to the summed style loss
        cont_fmap_real: iterable of content-image feature maps (presumably one per layer)
        style_fmap_real: iterable of style-image feature maps (presumably one per layer)
        generated_fmaps_content: generated-image feature maps, paired with cont_fmap_real
        generated_fmaps_style: generated-image feature maps, paired with style_fmap_real
        Pairs are formed with zip. If two paired iterables differ in length,
        the extra entries are silently ignored.
        @returns
        (total_loss, content_loss_sum, style_loss_sum). With empty inputs the
        sums are the float 0.0; otherwise they are torch tensors.
        Author: @jiteshg
        '''
        content_sum = sum(
            (Loss.content_loss(cont, gen_cont)
             for cont, gen_cont in zip(cont_fmap_real, generated_fmaps_content)),
            0.0,
        )
        style_sum = sum(
            (Loss.style_loss(sty, gen_style)
             for gen_style, sty in zip(generated_fmaps_style, style_fmap_real)),
            0.0,
        )
        loss_t = alpha * content_sum + beta * style_sum
        return loss_t, content_sum, style_sum