from loss import Loss
from feature_maps import LOGGER, FeatureMaps
import torch.optim as optim
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import os
# Switch matplotlib to interactive mode as soon as this module is imported.
# NOTE(review): nothing in the Optimizer class below actually plots; confirm
# whether any other code relies on this import-time side effect.
plt.ion()
class Optimizer:
    '''
    Driver that optimizes the pixels of an image tensor with Adam for
    neural style transfer.
    '''

    @staticmethod
    def gradient_descent(content_img, style_img, content_img_clone, *,
                         epoch=4000, learning_rate=0.01, alpha=1, beta=0.01,
                         identifier="content-4_2", output_root="styled_images",
                         save_every=10):
        '''
        Function to apply gradient descent on the content image.

        The pixels of ``content_img_clone`` are updated in place by
        optim.Adam to minimise ``Loss.total_loss`` computed from the feature
        maps of the content image, the style image and the clone itself.

        @params
        content_img: Original Image
        style_img: Styling Image
        content_img_clone: Copy of Original Image. This is the tensor that is
            optimized. It is passed to optim.Adam and differentiated through
            backward(), so it presumably must be a leaf tensor with
            requires_grad=True. TODO confirm at the call site.
        epoch: Number of optimization steps (default 4000).
        learning_rate: Adam learning rate (default 0.01).
        alpha: First weight passed to Loss.total_loss (default 1).
        beta: Second weight passed to Loss.total_loss (default 0.01).
        identifier: Sub-directory of output_root that receives the snapshots.
        output_root: Root directory for snapshots. Created if missing.
        save_every: Log and save a snapshot every this many steps. The final
            step is always logged and saved too.

        @returns
        content_img_clone, after optimization.

        @raises
        ValueError: if save_every < 1.

        Author: @gaurangathavale
        '''
        # Only needed for no_grad(); the module imports torch.optim alone.
        import torch

        if save_every < 1:
            raise ValueError(f"save_every must be >= 1, got {save_every}")

        LOGGER.info("Running gradient descent with the following parameters")
        out_dir = os.path.join(output_root, identifier)
        # makedirs(exist_ok=True) changes two things. Re-running with the
        # same identifier (the default is fixed) no longer raises
        # FileExistsError. A missing output root is created instead of failing.
        os.makedirs(out_dir, exist_ok=True)
        LOGGER.info(f"{epoch},{learning_rate},{alpha},{beta}")

        optimizer = optim.Adam([content_img_clone], lr=learning_rate)
        LOGGER.info("Optimizer = " + str(optimizer))
        feature_maps = FeatureMaps()

        # The target images never change, so their feature maps are computed
        # once instead of on every step. They are built under no_grad():
        # - If the feature extractor's parameters require grad, a graph built
        #   once here would be freed by the first backward().
        # - The second backward() would then raise.
        # The gradient with respect to content_img_clone is unaffected.
        with torch.no_grad():
            content_fmaps = feature_maps.get_fmaps_content(content_img)
            style_fmaps = feature_maps.get_fmaps_style(style_img)

        last_step = epoch - 1
        for e in range(epoch):
            generated_fmaps_content = feature_maps.get_fmaps_content(content_img_clone)
            generated_fmaps_style = feature_maps.get_fmaps_style(content_img_clone)

            total_loss, total_cont_loss, total_style_loss = Loss.total_loss(
                alpha, beta, content_fmaps, style_fmaps,
                generated_fmaps_content, generated_fmaps_style)

            # Clear the gradients from the previous step before backward().
            # backward() accumulates into .grad otherwise.
            optimizer.zero_grad()

            # Computes d(total_loss)/dx for every tensor x with requires_grad=True.
            total_loss.backward()

            # Optimization step: Adam update of content_img_clone.
            optimizer.step()

            # Also snapshot the last step so that the final result is written.
            if e % save_every == 0 or e == last_step:
                LOGGER.info(f"Epoch = {e} Total Loss = {total_loss} content Loss = {total_cont_loss} style Loss = {total_style_loss}")
                name = os.path.join(out_dir, "styled_" + str(e) + ".png")
                save_image(content_img_clone, name)

        return content_img_clone