Commit c926f916 authored by Himali saini's avatar Himali saini

added 4_2 as content layer

parent e1e608b2
...@@ -37,7 +37,23 @@ class FeatureMaps(): ...@@ -37,7 +37,23 @@ class FeatureMaps():
LOGGER.error("Could not fetch layer "+str(layer)) LOGGER.error("Could not fetch layer "+str(layer))
return weights return weights
def get_fmaps(self,img,layer=[0,5,10,19,28]): def get_fmaps_content(self,img,layer=[21]):
'''
Function which will pass the image through the model and get the respective fmaps
@params
img: numpy image f64
layer: list
'''
fmaps = []
layer_num = 0
for layer_i in self.model.features:
img = layer_i(img)
if layer_num in layer:
fmaps.append(img)
layer_num+=1
return fmaps
def get_fmaps_style(self,img,layer=[0,5,10,19,28]):
''' '''
Function which will pass the image through the model and get the respective fmaps Function which will pass the image through the model and get the respective fmaps
@params @params
......
...@@ -2,7 +2,7 @@ import imageio ...@@ -2,7 +2,7 @@ import imageio
import os import os
fnames = [] fnames = []
newNameFolder = '' newNameFolder = 'content-4_2'
path = 'styled_images/'+newNameFolder path = 'styled_images/'+newNameFolder
for img in os.listdir(path): for img in os.listdir(path):
......
...@@ -11,7 +11,7 @@ class Logger: ...@@ -11,7 +11,7 @@ class Logger:
_formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s') _formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
def __new__(cls,*args,**kwargs): def __new__(cls,*args,**kwargs):
if not cls._instance: if not cls._instance:
identifier = '' identifier = 'content-4_2'
if not os.path.isdir("Logs/"): if not os.path.isdir("Logs/"):
os.mkdir("Logs/") os.mkdir("Logs/")
logHandler = logging.FileHandler("Logs/style_transfer_"+identifier+".log") logHandler = logging.FileHandler("Logs/style_transfer_"+identifier+".log")
......
...@@ -50,7 +50,7 @@ class Loss: ...@@ -50,7 +50,7 @@ class Loss:
return loss_s return loss_s
@staticmethod @staticmethod
def total_loss(alpha,beta,cont_fmap_real,style_fmap_real,content_fmap_gen): def total_loss(alpha,beta,cont_fmap_real,style_fmap_real,generated_fmaps_content,generated_fmaps_style):
''' '''
Function which computes total loss and returns it Function which computes total loss and returns it
@params @params
...@@ -59,10 +59,14 @@ class Loss: ...@@ -59,10 +59,14 @@ class Loss:
loss_t = 0.0 loss_t = 0.0
a = 0.0 a = 0.0
b = 0.0 b = 0.0
for gen,cont,sty in zip(content_fmap_gen,cont_fmap_real,style_fmap_real):
loss_cont = Loss.content_loss(cont,gen) for cont,gen_cont in zip(cont_fmap_real,generated_fmaps_content):
loss_style = Loss.style_loss(sty,gen) loss_cont = Loss.content_loss(cont,gen_cont)
a+= loss_cont a+= loss_cont
for gen_style,sty in zip(generated_fmaps_style,style_fmap_real):
loss_style = Loss.style_loss(sty,gen_style)
b+= loss_style b+= loss_style
loss_t += alpha*loss_cont + beta*loss_style
loss_t += alpha*a + beta*b
return loss_t,a,b return loss_t,a,b
\ No newline at end of file
...@@ -3,6 +3,7 @@ from feature_maps import LOGGER, FeatureMaps ...@@ -3,6 +3,7 @@ from feature_maps import LOGGER, FeatureMaps
import torch.optim as optim import torch.optim as optim
from torchvision.utils import save_image from torchvision.utils import save_image
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import os
plt.ion() plt.ion()
class Optimizer: class Optimizer:
@staticmethod @staticmethod
...@@ -16,10 +17,12 @@ class Optimizer: ...@@ -16,10 +17,12 @@ class Optimizer:
Author: @gaurangathavale Author: @gaurangathavale
''' '''
LOGGER.info("Running gradient descent with the following parameters") LOGGER.info("Running gradient descent with the following parameters")
epoch = 5000 epoch = 4000
learning_rate = 0.01 learning_rate = 0.01
alpha = 1 alpha = 1
beta = 0.01 beta = 0.01
identifier = "content-4_2"
os.mkdir("styled_images/"+identifier)
LOGGER.info(f"{epoch},{learning_rate},{alpha},{beta}") LOGGER.info(f"{epoch},{learning_rate},{alpha},{beta}")
optimizer=optim.Adam([content_img_clone],lr=learning_rate) optimizer=optim.Adam([content_img_clone],lr=learning_rate)
...@@ -29,11 +32,12 @@ class Optimizer: ...@@ -29,11 +32,12 @@ class Optimizer:
feature_maps = FeatureMaps() feature_maps = FeatureMaps()
for e in range(epoch): for e in range(epoch):
content_fmaps = feature_maps.get_fmaps(content_img) content_fmaps = feature_maps.get_fmaps_content(content_img)
style_fmaps = feature_maps.get_fmaps(style_img) style_fmaps = feature_maps.get_fmaps_style(style_img)
content_generated_fmaps = feature_maps.get_fmaps(content_img_clone) generated_fmaps_content = feature_maps.get_fmaps_content(content_img_clone)
generated_fmaps_style = feature_maps.get_fmaps_style(content_img_clone)
total_loss,total_cont_loss,total_style_loss = Loss.total_loss(alpha, beta, content_fmaps, style_fmaps, content_generated_fmaps) total_loss,total_cont_loss,total_style_loss = Loss.total_loss(alpha, beta, content_fmaps, style_fmaps,generated_fmaps_content,generated_fmaps_style)
# clears x.grad for every parameter x in the optimizer. # clears x.grad for every parameter x in the optimizer.
# It’s important to call this before total_loss.backward(), otherwise it will accumulate the gradients from multiple passes. # It’s important to call this before total_loss.backward(), otherwise it will accumulate the gradients from multiple passes.
...@@ -48,8 +52,5 @@ class Optimizer: ...@@ -48,8 +52,5 @@ class Optimizer:
#plt.plot(content_img_clone) #plt.plot(content_img_clone)
if(e%10 == 0): if(e%10 == 0):
LOGGER.info(f"Epoch = {e} Total Loss = {total_loss} content Loss = {total_cont_loss} style Loss = {total_style_loss}") LOGGER.info(f"Epoch = {e} Total Loss = {total_loss} content Loss = {total_cont_loss} style Loss = {total_style_loss}")
identifier = ""
name = "styled_images/"+identifier+"/styled_" + str(e) +".png" name = "styled_images/"+identifier+"/styled_" + str(e) +".png"
save_image(content_img_clone,name) save_image(content_img_clone,name)
\ No newline at end of file
save_image(content_img_clone,name)
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment