import numpy as np
import torch
from logger import Logger
# Module-level logger obtained from the project's Logger helper; used by
# Loss.content_loss to report failures.
LOGGER = Logger().logger()

class Loss:
    '''
    Static helpers computing content loss, Gram matrices, style loss and
    their weighted total.
    '''

    @staticmethod
    def content_loss(F, P):
        '''
        Function to compute content loss between two feature representations at a particular layer:
        0.5 * sum((F - P) ** 2)
        @params
        F: numpy array or torch tensor (feature representation)
        P: numpy array or torch tensor broadcast-compatible with F
        @returns
        Scalar loss: a 0-dim torch tensor for torch inputs (gradients flow
        through it), otherwise a numpy scalar.
        @raises
        Whatever the subtraction / reduction raises (e.g. a shape mismatch),
        re-raised after being logged.
        Author: @meetdoshi
        '''
        try:
            squared = (F - P) ** 2
            # Reduce with torch for tensors so the result stays on the autograd
            # graph; np.sum would push torch tensors through numpy.
            if isinstance(squared, torch.Tensor):
                l2_norm_sq = torch.sum(squared)
            else:
                l2_norm_sq = np.sum(squared)
        except Exception as e:
            # Log and re-raise: swallowing the error used to make the function
            # fail later with an unrelated TypeError on None / 2.0.
            LOGGER.error("Error computing loss: %s", e)
            raise
        return l2_norm_sq / 2.0

    @staticmethod
    def gram_matrix(F):
        '''
        Function to compute the gram matrix of a feature representation at a layer
        @params
        F: torch tensor of shape (1, C, H, W). The batch dimension must be 1,
           because the features are flattened to a (C, H*W) matrix.
        @returns
        (C, C) torch tensor G with G[i, j] = dot(channel i, channel j) taken
        over all spatial positions.
        Author: @himalisaini
        '''
        num_channels, height, width = F.shape[1], F.shape[2], F.shape[3]
        # reshape (unlike view) also accepts non-contiguous inputs, e.g. after
        # a transpose/permute, and is a no-copy view when F is contiguous.
        flat = F.reshape(num_channels, height * width)
        return torch.mm(flat, flat.t())

    @staticmethod
    def style_loss(F, A):
        '''
        Function to compute style loss between two feature representations at a single layer:
        1 / (4 * C^2 * (H*W)^2) * sum((gram(F) - gram(A)) ** 2)
        @params
        F: torch tensor of shape (1, C, H, W); C, H and W for the normalisation come from F
        A: torch tensor of shape (1, C, H', W') with the same channel count C as F
        @returns
        0-dim torch tensor
        Author: @soumyagupta
        '''
        # Dimensions must come from F.shape; indexing F itself (F[1]) would
        # select a batch element, not a size.
        num_channels, h, w = F.shape[1], F.shape[2], F.shape[3]
        style_gram_matrix = Loss.gram_matrix(F)
        target_gram_matrix = Loss.gram_matrix(A)
        # torch.sum keeps the loss differentiable w.r.t. the feature maps.
        loss_s = torch.sum((style_gram_matrix - target_gram_matrix) ** 2)
        constant = 1 / (4.0 * (num_channels ** 2) * ((h * w) ** 2))
        return constant * loss_s

    @staticmethod
    def total_loss(alpha, beta, cont_fmap_real, cont_fmap_noise, style_fmap_real, style_fmap_noise):
        '''
        Function which computes total loss and returns it:
        alpha * content_loss(cont_fmap_real, cont_fmap_noise)
            + beta * style_loss(style_fmap_real, style_fmap_noise)
        @params
        alpha: weight of the content term
        beta: weight of the style term
        cont_fmap_real, cont_fmap_noise: feature maps passed to content_loss
        style_fmap_real, style_fmap_noise: feature maps passed to style_loss
        Author: @jiteshg
        '''
        content_term = Loss.content_loss(cont_fmap_real, cont_fmap_noise)
        style_term = Loss.style_loss(style_fmap_real, style_fmap_noise)
        return alpha * content_term + beta * style_term