Commit f462bad0 authored by Meet Narendra's avatar Meet Narendra 💬

Changed variable names

parent a03af97b
...@@ -13,12 +13,15 @@ from tqdm import tqdm ...@@ -13,12 +13,15 @@ from tqdm import tqdm
from utils import initialize_weights from utils import initialize_weights
LOGGER.info("Cuda status: "+str(device)) LOGGER.info("Cuda status: "+str(device))
torch.cuda.empty_cache() torch.cuda.empty_cache()
import os
class Train(): class Train():
def __init__(self,data="dataset/vangogh2photo",pair=False,epochs=200,batch_size=1): def __init__(self,data="dataset/vangogh2photo",pair=False,epochs=200,batch_size=1):
''' '''
@params @params
@return @return
''' '''
if not os.path.exists("images"):
os.mkdir("images")
self.epochs = epochs self.epochs = epochs
self.batch_size = batch_size self.batch_size = batch_size
...@@ -137,12 +140,12 @@ class Train(): ...@@ -137,12 +140,12 @@ class Train():
self.dis_X_optim.zero_grad() self.dis_X_optim.zero_grad()
real_X_label = self.dis_X(real_X) real_X_label = self.dis_X(real_X)
loss_dis_real_A = adversarial_loss(real_X_label, real_label) loss_dis_real_X = adversarial_loss(real_X_label, real_label)
fake_X_label = self.dis_X(fake_gen_X.detach()) fake_X_label = self.dis_X(fake_gen_X.detach())
loss_dis_fake_A = adversarial_loss(fake_X_label, fake_label) loss_dis_fake_X = adversarial_loss(fake_X_label, fake_label)
loss_dis_X = (loss_dis_real_A + loss_dis_fake_A) / 2 loss_dis_X = (loss_dis_real_X + loss_dis_fake_X) / 2
#backprop #backprop
loss_dis_X.backward() loss_dis_X.backward()
self.dis_X_optim.step() self.dis_X_optim.step()
...@@ -151,12 +154,12 @@ class Train(): ...@@ -151,12 +154,12 @@ class Train():
self.dis_Y_optim.zero_grad() self.dis_Y_optim.zero_grad()
real_Y_label = self.dis_Y(real_Y) real_Y_label = self.dis_Y(real_Y)
loss_dis_real_B = adversarial_loss(real_Y_label, real_label) loss_dis_real_Y = adversarial_loss(real_Y_label, real_label)
fake_Y_label = self.dis_Y(fake_gen_Y.detach()) fake_Y_label = self.dis_Y(fake_gen_Y.detach())
loss_dis_fake_B = adversarial_loss(fake_Y_label, fake_label) loss_dis_fake_Y = adversarial_loss(fake_Y_label, fake_label)
loss_dis_Y = (loss_dis_real_B + loss_dis_fake_B) / 2 loss_dis_Y = (loss_dis_real_Y + loss_dis_fake_Y) / 2
#backprop #backprop
loss_dis_Y.backward() loss_dis_Y.backward()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment