Commit c926f916 authored by Himali saini's avatar Himali saini

added 4_2 as content layer

parent e1e608b2
......@@ -37,7 +37,23 @@ class FeatureMaps():
LOGGER.error("Could not fetch layer "+str(layer))
return weights
def get_fmaps(self,img,layer=[0,5,10,19,28]):
def get_fmaps_content(self,img,layer=[21]):
'''
Function which will pass the image through the model and get the respective fmaps
@params
img: numpy image f64
layer: list
'''
fmaps = []
layer_num = 0
for layer_i in self.model.features:
img = layer_i(img)
if layer_num in layer:
fmaps.append(img)
layer_num+=1
return fmaps
def get_fmaps_style(self,img,layer=[0,5,10,19,28]):
'''
Function which will pass the image through the model and get the respective fmaps
@params
......
......@@ -2,7 +2,7 @@ import imageio
import os
fnames = []
newNameFolder = ''
newNameFolder = 'content-4_2'
path = 'styled_images/'+newNameFolder
for img in os.listdir(path):
......
......@@ -11,7 +11,7 @@ class Logger:
_formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
def __new__(cls,*args,**kwargs):
if not cls._instance:
identifier = ''
identifier = 'content-4_2'
if not os.path.isdir("Logs/"):
os.mkdir("Logs/")
logHandler = logging.FileHandler("Logs/style_transfer_"+identifier+".log")
......
......@@ -50,7 +50,7 @@ class Loss:
return loss_s
@staticmethod
def total_loss(alpha,beta,cont_fmap_real,style_fmap_real,content_fmap_gen):
def total_loss(alpha,beta,cont_fmap_real,style_fmap_real,generated_fmaps_content,generated_fmaps_style):
'''
Function which computes total loss and returns it
@params
......@@ -59,10 +59,14 @@ class Loss:
loss_t = 0.0
a = 0.0
b = 0.0
for gen,cont,sty in zip(content_fmap_gen,cont_fmap_real,style_fmap_real):
loss_cont = Loss.content_loss(cont,gen)
loss_style = Loss.style_loss(sty,gen)
a+= loss_cont
b+= loss_style
loss_t += alpha*loss_cont + beta*loss_style
for cont,gen_cont in zip(cont_fmap_real,generated_fmaps_content):
loss_cont = Loss.content_loss(cont,gen_cont)
a+= loss_cont
for gen_style,sty in zip(generated_fmaps_style,style_fmap_real):
loss_style = Loss.style_loss(sty,gen_style)
b+= loss_style
loss_t += alpha*a + beta*b
return loss_t,a,b
\ No newline at end of file
......@@ -3,6 +3,7 @@ from feature_maps import LOGGER, FeatureMaps
import torch.optim as optim
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import os
plt.ion()
class Optimizer:
@staticmethod
......@@ -16,10 +17,12 @@ class Optimizer:
Author: @gaurangathavale
'''
LOGGER.info("Running gradient descent with the following parameters")
epoch = 5000
epoch = 4000
learning_rate = 0.01
alpha = 1
beta = 0.01
identifier = "content-4_2"
os.mkdir("styled_images/"+identifier)
LOGGER.info(f"{epoch},{learning_rate},{alpha},{beta}")
optimizer=optim.Adam([content_img_clone],lr=learning_rate)
......@@ -29,11 +32,12 @@ class Optimizer:
feature_maps = FeatureMaps()
for e in range(epoch):
content_fmaps = feature_maps.get_fmaps(content_img)
style_fmaps = feature_maps.get_fmaps(style_img)
content_generated_fmaps = feature_maps.get_fmaps(content_img_clone)
content_fmaps = feature_maps.get_fmaps_content(content_img)
style_fmaps = feature_maps.get_fmaps_style(style_img)
generated_fmaps_content = feature_maps.get_fmaps_content(content_img_clone)
generated_fmaps_style = feature_maps.get_fmaps_style(content_img_clone)
total_loss,total_cont_loss,total_style_loss = Loss.total_loss(alpha, beta, content_fmaps, style_fmaps, content_generated_fmaps)
total_loss,total_cont_loss,total_style_loss = Loss.total_loss(alpha, beta, content_fmaps, style_fmaps,generated_fmaps_content,generated_fmaps_style)
# clears x.grad for every parameter x in the optimizer.
# It’s important to call this before total_loss.backward(), otherwise it will accumulate the gradients from multiple passes.
......@@ -48,8 +52,5 @@ class Optimizer:
#plt.plot(content_img_clone)
if(e%10 == 0):
LOGGER.info(f"Epoch = {e} Total Loss = {total_loss} content Loss = {total_cont_loss} style Loss = {total_style_loss}")
identifier = ""
name = "styled_images/"+identifier+"/styled_" + str(e) +".png"
save_image(content_img_clone,name)
save_image(content_img_clone,name)
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment