Commit 0d980d2c authored by Himali saini's avatar Himali saini

content loss analysis

parent 5e7d7f35
......@@ -16,8 +16,8 @@ class FeatureMaps():
super()
try:
self.model = torch.hub.load('pytorch/vision:v0.10.0',arch,pretrained=True)
except:
LOGGER.error("Could not load model")
except Exception as e:
LOGGER.error(f"Could not load model {e}")
return
def get_model(self):
......@@ -36,8 +36,13 @@ class FeatureMaps():
except:
LOGGER.error("Could not fetch layer "+str(layer))
return weights
def get_fmaps(self,img,layer=[0,5,10,19,28]):
'''
0, 5, 10, 19, 28 - 1_1 2_1 3_1 4_1 5_1
21 - 4_2
7 ,12, 21 ,30 - 2_2 3_2 4_2 5_2
7,10,12,21,28,30 - mix
'''
def get_fmaps(self,img,layer=[7,10,12,21,28,300]):
Please register or sign in to reply
'''
Function which will pass the image through the model and get the respective fmaps
@params
......
import imageio
import os
fnames = []
for img in os.listdir("styled_mix"):
fnames.append(os.path.join("styled_mix/", img))
fnames.sort()
with imageio.get_writer('content-mixed.gif', mode='I',duration = 0.9) as writer:
for fname in fnames:
image = imageio.imread(fname)
writer.append_data(image)
\ No newline at end of file
import logging
import os
from xml.dom.minidom import Identified
#Author: @meetdoshi
class Logger:
'''
......@@ -10,11 +11,12 @@ class Logger:
_formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
def __new__(cls,*args,**kwargs):
if not cls._instance:
os.system("rm -rf Logs/")
os.mkdir("Logs/")
logHandler = logging.FileHandler("Logs/style_transfer.log")
identifier = 'content-reconstruction-mixed'
if not os.path.isdir("Logs/"):
os.mkdir("Logs/")
logHandler = logging.FileHandler("Logs/style_transfer_"+identifier+".log")
logHandler.setFormatter(cls._formatter)
cls._logHandler = logging.getLogger("Logs/style_transfer.log")
cls._logHandler = logging.getLogger("Logs/style_transfer_"+identifier+".log")
cls._logHandler.setLevel(logging.INFO)
cls._logHandler.addHandler(logHandler)
cls._instance = super(Logger, cls).__new__(cls,*args,**kwargs)
......
......@@ -17,9 +17,9 @@ class Optimizer:
'''
LOGGER.info("Running gradient descent with the following parameters")
epoch = 5000
learning_rate = 0.002
learning_rate = 0.1
alpha = 1
beta = 0.01
beta = 0
LOGGER.info(f"{epoch},{learning_rate},{alpha},{beta}")
optimizer=optim.Adam([content_img_clone],lr=learning_rate)
......@@ -46,6 +46,9 @@ class Optimizer:
optimizer.step()
#plt.clf()
#plt.plot(content_img_clone)
if(e%100==0):
learning_rate = max(learning_rate/2,0.001)
if(e%10 == 0):
LOGGER.info(f"Epoch = {e} Total Loss = {total_loss} Style Loss = {total_cont_loss} Content Loss = {total_style_loss}")
save_image(content_img_clone,"styled.png")
\ No newline at end of file
LOGGER.info(f"Epoch = {e} Learning Rate = {learning_rate} Total Loss = {total_loss} Style Loss = {total_style_loss} Content Loss = {total_cont_loss}")
name = "styled_mix/styled_" + str(e) +".png"
save_image(content_img_clone,name)
......@@ -36,7 +36,10 @@ class Preprocessor:
img: 3d numpy array
'''
#loader = transforms.Compose([transforms.ToTensor(),transforms.Resize([224,224]),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225],),])
loader = transforms.Compose([transforms.ToTensor(),transforms.Resize([224,224])])
loader = transforms.Compose([transforms.ToTensor(),transforms.Resize([224,224]),transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
),])
img = loader(img).unsqueeze(0)
assert img.shape == (1,3,224,224)
return img.to(device,torch.float)
......
......@@ -33,13 +33,15 @@ class StyleTransfer:
device = torch.device( "cuda" if (torch.cuda.is_available()) else 'cpu')
content_img_path = 'test/content.jpg'
content_img_path = 'test/content.jpeg'
random_img_path = 'test/random.jpeg'
style_img_path = 'test/style.jpg'
content_img = Preprocessor.process(content_img_path)
style_img = Preprocessor.process(style_img_path)
random_img = Preprocessor.process(random_img_path)
content_img_clone = content_img.clone().requires_grad_(True)
content_img_clone = random_img.clone().requires_grad_(True)
Optimizer.gradient_descent(content_img, style_img, content_img_clone)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment