Commit f6105fb6 authored by Meet Narendra

Pulled optimizer.py

parent 6e4afc45
from loss import Loss
from feature_maps import FeatureMaps
import torch.optim as optim
from torchvision.utils import save_image
class Optimizer:
    @staticmethod
    def gradient_descent(content_img, style_img, content_img_clone):
        '''
        Function to apply gradient descent on the content image
        @params
        content_img: Original Image
        style_img: Styling Image
        content_img_clone: Copy of Original Image (the tensor being optimized;
                           must be a leaf tensor with requires_grad=True)
        Author: @gaurangathavale
        '''
        epochs = 1000
        learning_rate = 0.001
        alpha = 10   # weight of the content term of the total loss
        beta = 100   # weight of the style term of the total loss
        optimizer = optim.Adam([content_img_clone], lr=learning_rate)
        print(optimizer)
        feature_maps = FeatureMaps()
        # The target feature maps are fixed, so extract them once, outside the loop.
        content_fmaps = feature_maps.get_fmaps(content_img)
        style_fmaps = feature_maps.get_fmaps(style_img)
        for e in range(epochs):
            # Feature maps of the image being optimized. The loss must be a
            # function of content_img_clone, the tensor the optimizer updates.
            # Layer index 21 corresponds to conv4_2 in VGG-19, the usual
            # content-reconstruction layer.
            content_white_noise_fmaps = feature_maps.get_fmaps(content_img_clone, [21])
            # Default layer set here, so these maps line up with style_fmaps.
            style_white_noise_fmaps = feature_maps.get_fmaps(content_img_clone)
            total_loss = Loss.total_loss(alpha, beta,
                                         content_fmaps, content_white_noise_fmaps,
                                         style_fmaps, style_white_noise_fmaps)
            # Clears x.grad for every parameter x in the optimizer. This must be
            # called before total_loss.backward(), otherwise gradients from
            # multiple passes accumulate.
            optimizer.zero_grad()
            # total_loss.backward() computes d(total_loss)/dx for every tensor x
            # that has requires_grad=True
            total_loss.backward()
            # Optimization step / update rule
            optimizer.step()
            if e % 100 == 0:
                print(f"epoch {e}: total_loss = {total_loss.item()}")
        save_image(content_img_clone, "styled.png")
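
For context, alpha and beta weight the two terms of the Gatys et al. total loss, total_loss = alpha * content_loss + beta * style_loss. Below is a minimal, hypothetical driver for this routine; load_image is an assumed helper, not part of this commit. The essential detail is that the optimized tensor must be a leaf with requires_grad=True, since it is the only "parameter" handed to Adam.

# Hypothetical usage sketch — load_image is an assumed helper returning a
# normalized (1, 3, H, W) float tensor; only the clone handling is essential.
content_img = load_image("content.jpg")
style_img = load_image("style.jpg")

# The clone must be a leaf tensor with gradients enabled, otherwise
# optim.Adam([content_img_clone]) has nothing to update.
content_img_clone = content_img.detach().clone().requires_grad_(True)

Optimizer.gradient_descent(content_img, style_img, content_img_clone)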