Commit 45de1076 authored by Meet Narendra

First run of gradient descent

parent f6105fb6
@@ -3,4 +3,6 @@
 *.csv
 *.ipynb
 *Logs*
 *.log
\ No newline at end of file
+*test*
+*.ipynb*
 import torch
+import torch.nn as NN
 from logger import Logger
 LOGGER = Logger().logger()
 LOGGER.info("Started Feature Maps")
-device=torch.device( "cuda" if (torch.cude.is_available()) else 'cpu')
+device=torch.device( "cuda" if (torch.cuda.is_available()) else 'cpu')
+LOGGER.info("Running the model cuda_available = "+str(torch.cuda.is_available()))
 #Author: @meetdoshi
-class FeatureMaps:
+class FeatureMaps():
     def __init__(self,arch="vgg19"):
         '''
         Init function
         @params
         arch: str {vgg11,vgg13,vgg16,vgg19,vgg19bn}
         '''
+        super()
         try:
             self.model = torch.hub.load('pytorch/vision:v0.10.0',arch,pretrained=True)
         except:
@@ -49,7 +52,7 @@ class FeatureMaps:
             fmaps.append(img)
             layer_num+=1
         return fmaps
+'''
 if __name__ == "__main__":
     fmap = FeatureMaps()
     model = fmap.get_model()
@@ -58,3 +61,4 @@ if __name__ == "__main__":
     print(len(weights))
     for weight in weights:
         print(type(weight),weight.shape)
+'''
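Most of the body of get_fmaps sits outside the visible hunks, so the following is only a minimal sketch of how VGG-19 feature maps are typically collected for this kind of style transfer; the class name FeatureMapsSketch and the chosen layer indices are illustrative assumptions, not the committed code.

# Minimal sketch, not the committed get_fmaps(): walk the VGG-19 feature
# stack and keep the activations at a few conv layers, which are the usual
# inputs to Gatys-style content/style losses.
import torchvision.models as models

class FeatureMapsSketch:
    LAYERS = [0, 5, 10, 19, 28]  # conv1_1, conv2_1, conv3_1, conv4_1, conv5_1

    def __init__(self):
        self.model = models.vgg19(pretrained=True).features.eval()
        for p in self.model.parameters():
            p.requires_grad_(False)  # only the generated image gets gradients

    def get_fmaps(self, img):
        fmaps = []
        x = img
        for layer_num, layer in enumerate(self.model):
            x = layer(x)
            if layer_num in self.LAYERS:
                fmaps.append(x)
        return fmaps

Feeding the preprocessed (1, 3, 512, 512) tensors through this returns one activation per selected layer, which is the per-layer list that Loss.total_loss zips over.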
@@ -16,10 +16,10 @@ class Loss:
         l2_norm_sq = None
         try:
             diff = F-P
-            l2_norm_sq = np.sum(diff**2)
+            l2_norm_sq = torch.norm(diff)**2
         except Exception as e:
             LOGGER.error("Error computing loss",e)
-        return l2_norm_sq/2.0
+        return l2_norm_sq
     @staticmethod
     def gram_matrix(F):
@@ -40,21 +40,22 @@ class Loss:
         @params
         Author: @soumyagupta
         '''
-        num_channels = F[1]
-        h = F[2]
-        w = F[3]
+        num_channels = F.shape[1]
+        h = F.shape[2]
+        w = F.shape[3]
         style_gram_matrix = Loss.gram_matrix(F)
         target_gram_matrix = Loss.gram_matrix(A)
-        loss_s = np.sum((style_gram_matrix-target_gram_matrix)**2)
+        loss_s = torch.norm(style_gram_matrix-target_gram_matrix)**2
         constant = 1/(4.0*(num_channels**2)*((h*w)**2))
         return constant*loss_s
     @staticmethod
-    def total_loss(alpha,beta,cont_fmap_real,cont_fmap_noise,style_fmap_real,style_fmap_noise):
+    def total_loss(alpha,beta,cont_fmap_real,style_fmap_real,content_fmap_gen):
         '''
         Function which computes total loss and returns it
        @params
         Author: @jiteshg
         '''
-        loss_t = alpha*Loss.content_loss(cont_fmap_real,cont_fmap_noise) + beta*Loss.style_loss(style_fmap_real,style_fmap_noise)
+        for gen,cont,sty in zip(content_fmap_gen,cont_fmap_real,style_fmap_real):
+            loss_t = alpha*Loss.content_loss(cont,gen) + beta*Loss.style_loss(sty,gen)
         return loss_t
\ No newline at end of file
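Two things worth noting in the committed loss module: total_loss rebinds loss_t on every pass through the zip, so only the last layer's losses survive, and content_loss now returns the plain squared norm without the 1/2 factor the previous version had. Below is a minimal sketch of the standard Gatys et al. formulation with per-layer accumulation, assuming each feature map is a (1, C, H, W) tensor; these free functions are illustrative, not the committed Loss class (whose gram_matrix body is outside the visible hunks).

# Minimal sketch of the Gatys et al. losses with per-layer accumulation.
import torch

def content_loss(F, P):
    # 1/2 * sum((F - P)^2); the commit returns the squared norm without the 1/2
    return 0.5 * torch.sum((F - P) ** 2)

def gram_matrix(F):
    # flatten the spatial dims and take inner products between channel maps
    _, c, h, w = F.shape
    feats = F.reshape(c, h * w)
    return feats @ feats.t()

def style_loss(F, A):
    _, c, h, w = F.shape
    G, S = gram_matrix(F), gram_matrix(A)
    return torch.sum((G - S) ** 2) / (4.0 * (c ** 2) * ((h * w) ** 2))

def total_loss(alpha, beta, content_real, style_real, generated):
    # accumulate over layers; the committed loop overwrites loss_t each
    # iteration, so only the last layer ends up contributing
    loss_t = 0.0
    for gen, cont, sty in zip(generated, content_real, style_real):
        loss_t = loss_t + alpha * content_loss(cont, gen) + beta * style_loss(sty, gen)
    return loss_t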
 from loss import Loss
-from feature_maps import FeatureMaps
+from feature_maps import LOGGER, FeatureMaps
 import torch.optim as optim
 from torchvision.utils import save_image
+import matplotlib.pyplot as plt
+plt.ion()
 class Optimizer:
     @staticmethod
     def gradient_descent(content_img, style_img, content_img_clone):
@@ -14,24 +15,25 @@ class Optimizer:
         content_img_clone: Copy of Original Image
         Author: @gaurangathavale
         '''
-        epoch = 1000
+        LOGGER.info("Running gradient descent with the following parameters")
+        epoch = 5000
         learning_rate = 0.001
         alpha = 10
         beta = 100
+        LOGGER.info(f"{epoch},{learning_rate},{alpha},{beta}")
         optimizer=optim.Adam([content_img_clone],lr=learning_rate)
-        print(optimizer)
+        LOGGER.info("Optimizer = " + str(optimizer))
+        #fig = plt.figure()
+        #ax = fig.add_subplot(111)
+        feature_maps = FeatureMaps()
         for e in range(epoch):
-            feature_maps = FeatureMaps()
             content_fmaps = feature_maps.get_fmaps(content_img)
             style_fmaps = feature_maps.get_fmaps(style_img)
-            # content_clone_fmaps = feature_maps.get_fmaps(content_img_clone)
-            content_white_noise_fmaps = feature_maps.get_fmaps(content_img, [21])
-            style_white_noise_fmaps = feature_maps.get_fmaps(style_img, [21])
-            total_loss = Loss.total_loss(alpha, beta, content_fmaps, content_white_noise_fmaps, style_fmaps, style_white_noise_fmaps)
+            content_generated_fmaps = feature_maps.get_fmaps(content_img_clone)
+            total_loss = Loss.total_loss(alpha, beta, content_fmaps, style_fmaps, content_generated_fmaps)
             # clears x.grad for every parameter x in the optimizer.
             # It’s important to call this before total_loss.backward(), otherwise it will accumulate the gradients from multiple passes.
@@ -42,8 +44,8 @@ class Optimizer:
             # Optimization Step / Update Rule
             optimizer.step()
-            if(not (e%100)):
-                print(total_loss)
+            #plt.clf()
+            #plt.plot(content_img_clone)
+            if(e%10):
+                LOGGER.info(f"Epoch = {e} Total Loss = {total_loss}")
         save_image(content_img_clone,"styled.png")
\ No newline at end of file
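A minimal sketch of the same gradient-descent loop follows, under a couple of assumptions (gradient_descent_sketch, loss_fn and log_every are illustrative names, not the committed API): the content and style feature maps do not change between epochs, so they can be computed once outside the loop, and the logging check is e % log_every == 0, whereas the committed if(e%10) fires on the 90% of epochs that are not multiples of 10.

# Minimal sketch of the optimization loop: Adam updates the cloned content
# image directly while the feature extractor stays frozen.
import torch
import torch.optim as optim
from torchvision.utils import save_image

def gradient_descent_sketch(feature_maps, loss_fn, content_img, style_img,
                            generated_img, epochs=5000, lr=0.001,
                            alpha=10, beta=100, log_every=10):
    optimizer = optim.Adam([generated_img], lr=lr)
    with torch.no_grad():
        # the fixed images never change, so their features are computed once
        content_fmaps = feature_maps.get_fmaps(content_img)
        style_fmaps = feature_maps.get_fmaps(style_img)
    for e in range(epochs):
        generated_fmaps = feature_maps.get_fmaps(generated_img)
        total = loss_fn(alpha, beta, content_fmaps, style_fmaps, generated_fmaps)
        optimizer.zero_grad()  # clear gradients accumulated by the previous pass
        total.backward()
        optimizer.step()
        if e % log_every == 0:
            print(f"Epoch = {e} Total Loss = {total.item()}")
    save_image(generated_img.detach(), "styled.png")

With the sketches above, gradient_descent_sketch(FeatureMapsSketch(), total_loss, content_img, style_img, content_img.clone().requires_grad_(True)) mirrors the committed defaults (5000 epochs, learning rate 0.001, alpha 10, beta 100).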
@@ -5,7 +5,7 @@ import torchvision.transforms as transforms
 from PIL import Image
 import numpy as np
 LOGGER = Logger().logger()
-device=torch.device( "cuda" if (torch.cude.is_available()) else 'cpu')
+device=torch.device( "cuda" if (torch.cuda.is_available()) else 'cpu')
 #Author: @meetdoshi
 class Preprocessor:
     @staticmethod
@@ -35,9 +35,9 @@ class Preprocessor:
         @params
         img: 3d numpy array
         '''
-        loader = transforms.Compose([transforms.ToTensor(),transforms.Resize([224,224]),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225],),])
+        loader = transforms.Compose([transforms.ToTensor(),transforms.Resize([512,512]),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225],),])
         img = loader(img).unsqueeze(0)
-        assert img.shape == (1,3,224,224)
+        assert img.shape == (1,3,512,512)
         return img.to(device,torch.float)
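The image-loading half of the preprocessing is outside the visible hunks, so the sketch below assumes a plain PIL.Image.open path (an assumption, not the committed code); it also shows the inverse of the ImageNet normalization, which save_image does not undo on its own, so the styled output keeps natural colors if it is de-normalized before saving.

# Sketch of the 512x512 preprocessing and its inverse normalization.
import torch
from PIL import Image
import torchvision.transforms as transforms

MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

def load_image(path, size=512):
    img = Image.open(path).convert("RGB")  # assumed loading step
    loader = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize([size, size]),
        transforms.Normalize(mean=MEAN, std=STD),
    ])
    return loader(img).unsqueeze(0)  # shape (1, 3, size, size)

def denormalize(batch):
    # invert Normalize so the styled image can be saved with natural colors
    mean = torch.tensor(MEAN).view(3, 1, 1)
    std = torch.tensor(STD).view(3, 1, 1)
    return (batch.detach().cpu() * std + mean).clamp(0, 1)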
 import os
+import warnings
+from optimizer import Optimizer
+from loss import Loss
+from preprocess import Preprocessor
+from feature_maps import FeatureMaps
 import numpy as np
 import time
 import torch
 import argparse
+import torchvision.models as models
+import torch.optim as optim
+from torchvision.utils import save_image
+warnings.filterwarnings('ignore')
 from logger import Logger
 LOGGER = Logger().logger()
 LOGGER.info("Started Style Transfer")
 class StyleTransfer:
     '''
     Style Transfer Base Class
@@ -21,6 +30,22 @@ class StyleTransfer:
         @params: None
         Author: @gaurangathavale
         '''
+        device = torch.device( "cuda" if (torch.cuda.is_available()) else 'cpu')
+        content_img_path = 'test/content.jpg'
+        style_img_path = 'test/style.jpg'
+        content_img = Preprocessor.process(content_img_path)
+        style_img = Preprocessor.process(style_img_path)
+        content_img_clone = content_img.clone().requires_grad_(True)
+        Optimizer.gradient_descent(content_img, style_img, content_img_clone)
+if __name__ == "__main__":
+    stf = StyleTransfer()
+    stf.pipeline()
\ No newline at end of file
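For reference, a hedged sketch of what the pipeline wires together, with an optional white-noise initialization of the generated image (suggested by the removed *_white_noise_fmaps names, but not part of this commit); the paths and the init switch are illustrative.

# Sketch of the pipeline entry point: preprocess both images, choose the
# tensor to optimize, and hand everything to the gradient-descent routine.
import torch
from preprocess import Preprocessor
from optimizer import Optimizer

def run(content_path="test/content.jpg", style_path="test/style.jpg",
        init="content"):
    content_img = Preprocessor.process(content_path)
    style_img = Preprocessor.process(style_path)
    if init == "content":
        generated = content_img.clone().requires_grad_(True)
    else:
        # white-noise start, as in the original Gatys formulation
        generated = torch.randn_like(content_img).requires_grad_(True)
    Optimizer.gradient_descent(content_img, style_img, generated)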