Commit c85f68bf authored by Soumya Gupta's avatar Soumya Gupta

norm loss to mean loss

parent 5e7d7f35
......@@ -16,8 +16,8 @@ class FeatureMaps():
super()
try:
self.model = torch.hub.load('pytorch/vision:v0.10.0',arch,pretrained=True)
except:
LOGGER.error("Could not load model")
except Exception as e:
LOGGER.error("Could not load model" + str(e))
return
def get_model(self):
......
......@@ -16,7 +16,7 @@ class Loss:
l2_norm_sq = None
try:
diff = F-P
l2_norm_sq = torch.norm(diff)**2
l2_norm_sq = torch.mean((diff)**2)
except Exception as e:
LOGGER.error("Error computing loss",e)
return l2_norm_sq
......@@ -45,9 +45,9 @@ class Loss:
w = F.shape[3]
style_gram_matrix = Loss.gram_matrix(F)
target_gram_matrix = Loss.gram_matrix(A)
loss_s = torch.norm(style_gram_matrix-target_gram_matrix)**2
constant = 1/(4.0*(num_channels**2)*((h*w)**2))
return constant*loss_s
loss_s = torch.mean((style_gram_matrix-target_gram_matrix)**2)
#constant = 1/(4.0*(num_channels**2)*((h*w)**2))
return loss_s
@staticmethod
def total_loss(alpha,beta,cont_fmap_real,style_fmap_real,content_fmap_gen):
......
......@@ -17,7 +17,7 @@ class Optimizer:
'''
LOGGER.info("Running gradient descent with the following parameters")
epoch = 5000
learning_rate = 0.002
learning_rate = 0.01
alpha = 1
beta = 0.01
LOGGER.info(f"{epoch},{learning_rate},{alpha},{beta}")
......@@ -47,5 +47,5 @@ class Optimizer:
#plt.clf()
#plt.plot(content_img_clone)
if(e%10 == 0):
LOGGER.info(f"Epoch = {e} Total Loss = {total_loss} Style Loss = {total_cont_loss} Content Loss = {total_style_loss}")
LOGGER.info(f"Epoch = {e} Total Loss = {total_loss} content Loss = {total_cont_loss} style Loss = {total_style_loss}")
save_image(content_img_clone,"styled.png")
\ No newline at end of file
......@@ -38,7 +38,7 @@ class Preprocessor:
#loader = transforms.Compose([transforms.ToTensor(),transforms.Resize([224,224]),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225],),])
loader = transforms.Compose([transforms.ToTensor(),transforms.Resize([224,224])])
img = loader(img).unsqueeze(0)
assert img.shape == (1,3,224,224)
#assert img.shape == (1,3,224,224)
return img.to(device,torch.float)
......
1508.06576/styled.png

104 KB | W: | H:

1508.06576/styled.png

138 KB | W: | H:

1508.06576/styled.png
1508.06576/styled.png
1508.06576/styled.png
1508.06576/styled.png
  • 2-up
  • Swipe
  • Onion skin
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment