Commit 45de1076 authored by Meet Narendra's avatar Meet Narendra 💬

First run of gradient descent

parent f6105fb6
......@@ -4,3 +4,5 @@
*.ipynb
*Logs*
*.log
*test*
*.ipynb*
import torch
import torch.nn as NN
from logger import Logger
LOGGER = Logger().logger()
LOGGER.info("Started Feature Maps")
device=torch.device( "cuda" if (torch.cude.is_available()) else 'cpu')
device=torch.device( "cuda" if (torch.cuda.is_available()) else 'cpu')
LOGGER.info("Running the model cuda_available = "+str(torch.cuda.is_available()))
#Author: @meetdoshi
class FeatureMaps:
class FeatureMaps():
def __init__(self,arch="vgg19"):
'''
Init function
@params
arch: str {vgg11,vgg13,vgg16,vgg19,vgg19bn}
'''
super()
try:
self.model = torch.hub.load('pytorch/vision:v0.10.0',arch,pretrained=True)
except:
......@@ -49,7 +52,7 @@ class FeatureMaps:
fmaps.append(img)
layer_num+=1
return fmaps
'''
if __name__ == "__main__":
fmap = FeatureMaps()
model = fmap.get_model()
......@@ -58,3 +61,4 @@ if __name__ == "__main__":
print(len(weights))
for weight in weights:
print(type(weight),weight.shape)
'''
......@@ -16,10 +16,10 @@ class Loss:
l2_norm_sq = None
try:
diff = F-P
l2_norm_sq = np.sum(diff**2)
l2_norm_sq = torch.norm(diff)**2
except Exception as e:
LOGGER.error("Error computing loss",e)
return l2_norm_sq/2.0
return l2_norm_sq
@staticmethod
def gram_matrix(F):
......@@ -40,21 +40,22 @@ class Loss:
@params
Author: @soumyagupta
'''
num_channels = F[1]
h = F[2]
w = F[3]
num_channels = F.shape[1]
h = F.shape[2]
w = F.shape[3]
style_gram_matrix = Loss.gram_matrix(F)
target_gram_matrix = Loss.gram_matrix(A)
loss_s = np.sum((style_gram_matrix-target_gram_matrix)**2)
loss_s = torch.norm(style_gram_matrix-target_gram_matrix)**2
constant = 1/(4.0*(num_channels**2)*((h*w)**2))
return constant*loss_s
@staticmethod
def total_loss(alpha,beta,cont_fmap_real,cont_fmap_noise,style_fmap_real,style_fmap_noise):
def total_loss(alpha,beta,cont_fmap_real,style_fmap_real,content_fmap_gen):
'''
Function which computes total loss and returns it
@params
Author: @jiteshg
'''
loss_t = alpha*Loss.content_loss(cont_fmap_real,cont_fmap_noise) + beta*Loss.style_loss(style_fmap_real,style_fmap_noise)
for gen,cont,sty in zip(content_fmap_gen,cont_fmap_real,style_fmap_real):
loss_t = alpha*Loss.content_loss(cont,gen) + beta*Loss.style_loss(sty,gen)
return loss_t
\ No newline at end of file
from loss import Loss
from feature_maps import FeatureMaps
from feature_maps import LOGGER, FeatureMaps
import torch.optim as optim
from torchvision.utils import save_image
import matplotlib.pyplot as plt
plt.ion()
class Optimizer:
@staticmethod
def gradient_descent(content_img, style_img, content_img_clone):
......@@ -14,24 +15,25 @@ class Optimizer:
content_img_clone: Copy of Original Image
Author: @gaurangathavale
'''
epoch = 1000
LOGGER.info("Running gradient descent with the following parameters")
epoch = 5000
learning_rate = 0.001
alpha = 10
beta = 100
LOGGER.info(f"{epoch},{learning_rate},{alpha},{beta}")
optimizer=optim.Adam([content_img_clone],lr=learning_rate)
print(optimizer)
LOGGER.info("Optimizer = " + str(optimizer))
#fig = plt.figure()
#ax = fig.add_subplot(111)
feature_maps = FeatureMaps()
for e in range(epoch):
feature_maps = FeatureMaps()
content_fmaps = feature_maps.get_fmaps(content_img)
style_fmaps = feature_maps.get_fmaps(style_img)
# content_clone_fmaps = feature_maps.get_fmaps(content_img_clone)
content_white_noise_fmaps = feature_maps.get_fmaps(content_img, [21])
style_white_noise_fmaps = feature_maps.get_fmaps(style_img, [21])
content_generated_fmaps = feature_maps.get_fmaps(content_img_clone)
total_loss = Loss.total_loss(alpha, beta, content_fmaps, content_white_noise_fmaps, style_fmaps, style_white_noise_fmaps)
total_loss = Loss.total_loss(alpha, beta, content_fmaps, style_fmaps, content_generated_fmaps)
# clears x.grad for every parameter x in the optimizer.
# It’s important to call this before total_loss.backward(), otherwise it will accumulate the gradients from multiple passes.
......@@ -42,8 +44,8 @@ class Optimizer:
# Optimization Step / Update Rule
optimizer.step()
if(not (e%100)):
print(total_loss)
#plt.clf()
#plt.plot(content_img_clone)
if(e%10):
LOGGER.info(f"Epoch = {e} Total Loss = {total_loss}")
save_image(content_img_clone,"styled.png")
\ No newline at end of file
......@@ -5,7 +5,7 @@ import torchvision.transforms as transforms
from PIL import Image
import numpy as np
LOGGER = Logger().logger()
device=torch.device( "cuda" if (torch.cude.is_available()) else 'cpu')
device=torch.device( "cuda" if (torch.cuda.is_available()) else 'cpu')
#Author: @meetdoshi
class Preprocessor:
@staticmethod
......@@ -35,9 +35,9 @@ class Preprocessor:
@params
img: 3d numpy array
'''
loader = transforms.Compose([transforms.ToTensor(),transforms.Resize([224,224]),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225],),])
loader = transforms.Compose([transforms.ToTensor(),transforms.Resize([512,512]),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225],),])
img = loader(img).unsqueeze(0)
assert img.shape == (1,3,224,224)
assert img.shape == (1,3,512,512)
return img.to(device,torch.float)
......
import os
import warnings
from optimizer import Optimizer
from loss import Loss
from preprocess import Preprocessor
from feature_maps import FeatureMaps
import numpy as np
import time
import torch
import argparse
import torchvision.models as models
import torch.optim as optim
from torchvision.utils import save_image
warnings.filterwarnings('ignore')
from logger import Logger
LOGGER = Logger().logger()
LOGGER.info("Started Style Transfer")
class StyleTransfer:
'''
Style Transfer Base Class
......@@ -22,5 +31,21 @@ class StyleTransfer:
Author: @gaurangathavale
'''
device = torch.device( "cuda" if (torch.cuda.is_available()) else 'cpu')
content_img_path = 'test/content.jpg'
style_img_path = 'test/style.jpg'
content_img = Preprocessor.process(content_img_path)
style_img = Preprocessor.process(style_img_path)
content_img_clone = content_img.clone().requires_grad_(True)
Optimizer.gradient_descent(content_img, style_img, content_img_clone)
if __name__ == "__main__":
stf = StyleTransfer()
stf.pipeline()
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment