Commit c85f68bf authored by Soumya Gupta's avatar Soumya Gupta

norm loss to mean loss

parent 5e7d7f35
...@@ -16,8 +16,8 @@ class FeatureMaps(): ...@@ -16,8 +16,8 @@ class FeatureMaps():
super() super()
try: try:
self.model = torch.hub.load('pytorch/vision:v0.10.0',arch,pretrained=True) self.model = torch.hub.load('pytorch/vision:v0.10.0',arch,pretrained=True)
except: except Exception as e:
LOGGER.error("Could not load model") LOGGER.error("Could not load model" + str(e))
return return
def get_model(self): def get_model(self):
......
...@@ -16,7 +16,7 @@ class Loss: ...@@ -16,7 +16,7 @@ class Loss:
l2_norm_sq = None l2_norm_sq = None
try: try:
diff = F-P diff = F-P
l2_norm_sq = torch.norm(diff)**2 l2_norm_sq = torch.mean((diff)**2)
except Exception as e: except Exception as e:
LOGGER.error("Error computing loss",e) LOGGER.error("Error computing loss",e)
return l2_norm_sq return l2_norm_sq
...@@ -45,9 +45,9 @@ class Loss: ...@@ -45,9 +45,9 @@ class Loss:
w = F.shape[3] w = F.shape[3]
style_gram_matrix = Loss.gram_matrix(F) style_gram_matrix = Loss.gram_matrix(F)
target_gram_matrix = Loss.gram_matrix(A) target_gram_matrix = Loss.gram_matrix(A)
loss_s = torch.norm(style_gram_matrix-target_gram_matrix)**2 loss_s = torch.mean((style_gram_matrix-target_gram_matrix)**2)
constant = 1/(4.0*(num_channels**2)*((h*w)**2)) #constant = 1/(4.0*(num_channels**2)*((h*w)**2))
return constant*loss_s return loss_s
@staticmethod @staticmethod
def total_loss(alpha,beta,cont_fmap_real,style_fmap_real,content_fmap_gen): def total_loss(alpha,beta,cont_fmap_real,style_fmap_real,content_fmap_gen):
......
...@@ -17,7 +17,7 @@ class Optimizer: ...@@ -17,7 +17,7 @@ class Optimizer:
''' '''
LOGGER.info("Running gradient descent with the following parameters") LOGGER.info("Running gradient descent with the following parameters")
epoch = 5000 epoch = 5000
learning_rate = 0.002 learning_rate = 0.01
alpha = 1 alpha = 1
beta = 0.01 beta = 0.01
LOGGER.info(f"{epoch},{learning_rate},{alpha},{beta}") LOGGER.info(f"{epoch},{learning_rate},{alpha},{beta}")
...@@ -47,5 +47,5 @@ class Optimizer: ...@@ -47,5 +47,5 @@ class Optimizer:
#plt.clf() #plt.clf()
#plt.plot(content_img_clone) #plt.plot(content_img_clone)
if(e%10 == 0): if(e%10 == 0):
LOGGER.info(f"Epoch = {e} Total Loss = {total_loss} Style Loss = {total_cont_loss} Content Loss = {total_style_loss}") LOGGER.info(f"Epoch = {e} Total Loss = {total_loss} content Loss = {total_cont_loss} style Loss = {total_style_loss}")
save_image(content_img_clone,"styled.png") save_image(content_img_clone,"styled.png")
\ No newline at end of file
...@@ -38,7 +38,7 @@ class Preprocessor: ...@@ -38,7 +38,7 @@ class Preprocessor:
#loader = transforms.Compose([transforms.ToTensor(),transforms.Resize([224,224]),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225],),]) #loader = transforms.Compose([transforms.ToTensor(),transforms.Resize([224,224]),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225],),])
loader = transforms.Compose([transforms.ToTensor(),transforms.Resize([224,224])]) loader = transforms.Compose([transforms.ToTensor(),transforms.Resize([224,224])])
img = loader(img).unsqueeze(0) img = loader(img).unsqueeze(0)
assert img.shape == (1,3,224,224) #assert img.shape == (1,3,224,224)
return img.to(device,torch.float) return img.to(device,torch.float)
......
1508.06576/styled.png

104 KB | W: | H:

1508.06576/styled.png

138 KB | W: | H:

1508.06576/styled.png
1508.06576/styled.png
1508.06576/styled.png
1508.06576/styled.png
  • 2-up
  • Swipe
  • Onion skin
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment