Commit 9dfa23ae authored by Gaurang-Athavale

tried gradient descent

parent 107e376e
import torch

from logger import Logger

LOGGER = Logger().logger()
LOGGER.info("Started Feature Maps")


#Author: @meetdoshi
class FeatureMaps:
    def __init__(self, arch="vgg19"):
        '''
        Init function
        @params
        arch: str {vgg11,vgg13,vgg16,vgg19,vgg19_bn}
        '''
        try:
            self.model = torch.hub.load('pytorch/vision:v0.10.0', arch, pretrained=True)
        except Exception as e:
            LOGGER.error("Could not load model: %s", e)
            return

    def get_model(self):
        return self.model

    def get_layers(self, layers=[]):
        '''
        Function to extract layer weights
        @params
        layers: list of indices into model.features
        '''
        weights = []
        for layer in layers:
            try:
                weights.append(self.model.features[layer].weight)
            except Exception as e:
                LOGGER.error("Could not fetch layer %s: %s", layer, e)
        return weights

    def get_fmaps(self, img, layer=[0, 5, 10, 19, 28]):
        '''
        Function which passes the image through the model and collects the feature maps at the requested layers
        @params
        img: torch tensor of shape (1, 3, H, W), as produced by the preprocessor
        layer: list of indices into model.features
        '''
        fmaps = []
        layer_num = 0
        for layer_i in self.model.features:
            img = layer_i(img)
            if layer_num in layer:
                fmaps.append(img)
            layer_num += 1
        return fmaps


if __name__ == "__main__":
    fmap = FeatureMaps()
    model = fmap.get_model()
    print(model.features)
    weights = fmap.get_layers([4, 2, 6])
    print(len(weights))
    for weight in weights:
        print(type(weight), weight.shape)
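    # Sketch, not in the original commit: get_fmaps expects a preprocessed tensor of
    # shape (1, 3, 224, 224); a random tensor stands in for a real image here.
    dummy = torch.rand(1, 3, 224, 224)
    for f in fmap.get_fmaps(dummy):
        print(f.shape)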
import logging
import os


#Author: @meetdoshi
class Logger:
    '''
    Singleton logger class
    '''
    _instance = None
    _logHandler = None
    _formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')

    def __new__(cls, *args, **kwargs):
        if not cls._instance:
            # Start every run with a fresh log directory
            os.system("rm -rf Logs/")
            os.mkdir("Logs/")
            logHandler = logging.FileHandler("Logs/style_transfer.log")
            logHandler.setFormatter(cls._formatter)
            cls._logHandler = logging.getLogger("Logs/style_transfer.log")
            cls._logHandler.setLevel(logging.INFO)
            cls._logHandler.addHandler(logHandler)
            # object.__new__ does not accept extra arguments in Python 3
            cls._instance = super(Logger, cls).__new__(cls)
        return cls._instance

    def logger(self):
        return self._logHandler


'''
#Demo use
if __name__ == "__main__":
    a = Logger()
    b = Logger()
    print(a is b)
    INFO = a.logger()
    ERROR = b.logger()
    INFO.info("TEST")
    ERROR.info("ERROR")
'''
import numpy as np
import torch

from logger import Logger

LOGGER = Logger().logger()


class Loss:
    @staticmethod
    def content_loss(F, P):
        '''
        Function to compute content loss between two feature representations at a particular layer
        @params
        F: list of 4D torch feature maps of the generated image
        P: list of 4D torch feature maps of the content image
        Author: @meetdoshi
        '''
        l2_norm_sq = None
        for i in range(len(F)):
            try:
                diff = F[i] - P[i]
                l2_norm_sq = torch.mean(diff**2)
            except Exception as e:
                LOGGER.error("Error computing content loss: %s", e)
        return l2_norm_sq

    @staticmethod
    def gram_matrix(F):
        '''
        Function to compute the gram matrix of a feature representation at a layer
        @params
        F: 4D torch tensor of shape (1, num_channels, height, width)
        Author: @himalisaini
        '''
        shape_mat = F.shape
        num_channels = shape_mat[1]
        height = shape_mat[2]
        width = shape_mat[3]
        F_flat = F.view(num_channels, height * width)
        return torch.mm(F_flat, F_flat.t())

    @staticmethod
    def style_loss(F, A):
        '''
        Function to compute style loss between two feature representations at multiple layers
        @params
        F: list of 4D torch feature maps of the generated image
        A: list of 4D torch feature maps of the style image
        Author: @soumyagupta
        '''
        loss_s = 0.0
        for i in range(len(F)):
            num_channels = F[i].shape[1]
            h = F[i].shape[2]
            w = F[i].shape[3]
            style_gram_matrix = Loss.gram_matrix(F[i])
            target_gram_matrix = Loss.gram_matrix(A[i])
            constant = 1 / (4.0 * (num_channels**2) * ((h * w)**2))
            loss_s = loss_s + constant * torch.sum((style_gram_matrix - target_gram_matrix)**2)
        return loss_s

    @staticmethod
    def total_loss(alpha, beta, cont_fmap_real, cont_fmap_noise, style_fmap_real, style_fmap_noise):
        '''
        Function which computes the total loss as a weighted sum of content and style loss
        @params
        alpha: weight of the content loss
        beta: weight of the style loss
        Author: @jiteshg
        '''
        content_loss = Loss.content_loss(cont_fmap_real, cont_fmap_noise)
        style_loss = Loss.style_loss(style_fmap_real, style_fmap_noise)
        loss_t = alpha * content_loss + beta * style_loss
        return loss_t
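

#Demo use (sketch, not in the original commit): exercises the loss terms on random
#tensors shaped like VGG feature maps, (1, num_channels, height, width).
if __name__ == "__main__":
    F = [torch.rand(1, 64, 56, 56)]
    P = [torch.rand(1, 64, 56, 56)]
    print(Loss.content_loss(F, P))
    print(Loss.gram_matrix(F[0]).shape)  # expected: torch.Size([64, 64])
    print(Loss.style_loss(F, P))
    print(Loss.total_loss(10, 100, F, P, F, P))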
from loss import Loss
from feature_maps import FeatureMaps
import torch.optim as optim
from torchvision.utils import save_image


class Optimizer:
    @staticmethod
    def gradient_descent(content_img, style_img, content_img_clone):
        '''
        Function to apply gradient descent on the content image
        @params
        content_img: Original Image
        style_img: Styling Image
        content_img_clone: Copy of Original Image (the tensor being optimised, requires_grad=True)
        Author: @gaurangathavale
        '''
        epoch = 1000
        learning_rate = 0.001
        alpha = 10
        beta = 100
        optimizer = optim.Adam([content_img_clone], lr=learning_rate)
        print(optimizer)
        # Build the feature extractor once instead of reloading VGG on every epoch
        feature_maps = FeatureMaps()
        for e in range(epoch):
            # Content is compared at a single deep layer, style at several layers.
            # The generated image (content_img_clone) is the tensor that must be passed
            # through the network, otherwise no gradient reaches the parameter being optimised.
            content_fmaps = feature_maps.get_fmaps(content_img, [21])
            style_fmaps = feature_maps.get_fmaps(style_img)
            content_clone_fmaps = feature_maps.get_fmaps(content_img_clone, [21])
            style_clone_fmaps = feature_maps.get_fmaps(content_img_clone)
            total_loss = Loss.total_loss(alpha, beta, content_fmaps, content_clone_fmaps, style_fmaps, style_clone_fmaps)
            # clears x.grad for every parameter x in the optimizer.
            # It is important to call this before total_loss.backward(), otherwise gradients
            # from multiple passes would accumulate.
            optimizer.zero_grad()
            # total_loss.backward() computes dtotal_loss/dx for every parameter x which has requires_grad=True
            total_loss.backward()
            # Optimization Step / Update Rule
            optimizer.step()
            if not (e % 100):
                print(total_loss)
                save_image(content_img_clone, "styled.png")
from logger import Logger
from torchvision import transforms
from PIL import Image
import numpy as np

LOGGER = Logger().logger()


#Author: @meetdoshi
class Preprocessor:
    @staticmethod
    def load_image(path):
        '''
        Function to load image
        @params
        path: os.path
        '''
        img = Image.open(path)
        return img

    @staticmethod
    def subtract_mean(img):
        '''
        Function to subtract mean values of the BGR channels computed over the whole ImageNet dataset
        @params
        img: 3d numpy array of shape (H, W, 3) in BGR order
        '''
        mean = np.reshape([103.939, 116.779, 123.68], (1, 1, 3))  # b,g,r
        return img - mean

    @staticmethod
    def reshape_img(img):
        '''
        Function to reshape the image to 224x224xnum_of_channels and normalize it
        @params
        img: PIL image
        '''
        loader = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize([224, 224]),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        img = loader(img).unsqueeze(0)
        return img

    @staticmethod
    def process(path):
        '''
        Function to preprocess the image
        @params
        path: os.path
        '''
        img = Preprocessor.load_image(path)
        # reshape_img already performs mean/std normalization on the tensor, so the raw
        # BGR mean subtraction (which expects an HxWx3 numpy array) is not applied here.
        img = Preprocessor.reshape_img(img)
        return img


if __name__ == "__main__":
    prec = Preprocessor()
    img = np.zeros(shape=(4, 4, 3))
    print(img.shape)
    for i in range(img.shape[2]):
        print(img[:, :, i])
    img = prec.subtract_mean(img)
    for i in range(img.shape[2]):
        print(img[:, :, i])
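    # Sketch, not in the original commit: full preprocessing of an image file into the
    # (1, 3, 224, 224) tensor consumed by FeatureMaps. Uncomment if the content image
    # used by the pipeline is present in the working directory.
    # tensor_img = Preprocessor.process('Nikola-Tesla.jpg')
    # print(tensor_img.shape)  # expected: torch.Size([1, 3, 224, 224])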
import os
from optimizer import Optimizer
from loss import Loss
from preprocess import Preprocessor
from feature_maps import FeatureMaps
import numpy as np
import time
import torch
import argparse
import torchvision.models as models
import torch.optim as optim
from torchvision.utils import save_image
from logger import Logger

LOGGER = Logger().logger()
LOGGER.info("Started Style Transfer")


class StyleTransfer:
    '''
    Style Transfer Base Class
    '''
    def __init__(self) -> None:
        pass

    @staticmethod
    def pipeline():
        '''
        Pipeline for style transfer
        @params: None
        Author: @gaurangathavale
        '''
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        content_img_path = 'Nikola-Tesla.jpg'
        style_img_path = 'style-Image.jpg'
        content_img = Preprocessor.process(content_img_path)
        style_img = Preprocessor.process(style_img_path)
        content_img_clone = content_img.clone().requires_grad_(True)
        Optimizer.gradient_descent(content_img, style_img, content_img_clone)
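

# Entry point, not in the original commit: a minimal way to run the pipeline, assuming
# 'Nikola-Tesla.jpg' and 'style-Image.jpg' are present in the working directory.
if __name__ == "__main__":
    StyleTransfer.pipeline()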