Commit a03af97b authored by Meet Narendra's avatar Meet Narendra 💬

Cycle gans minor modifications

parent 938b36dc
......@@ -44,10 +44,10 @@ class Generator(torch.nn.Module):
ResidualBlock(),
ResidualBlock(),
ResidualBlock(),
ResidualBlock(),
ResidualBlock(),
ResidualBlock(),
ResidualBlock(),
#ResidualBlock(),
#ResidualBlock(),
#ResidualBlock(),
#ResidualBlock(),
nn.ConvTranspose2d(256, 128, 3, 2, 1, 1),
nn.InstanceNorm2d(128),
......
......@@ -5,7 +5,6 @@ LOGGER = Logger().logger()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#Author: @meetdoshi
device = torch.device("cpu")
class Loss:
@staticmethod
def adversarial_G():
......@@ -15,14 +14,6 @@ class Loss:
'''
return torch.nn.MSELoss().to(device)
@staticmethod
def adversarial_D():
'''
@params
@return
'''
return torch.nn.MSELoss().to(device)
@staticmethod
def cycle_consistency():
'''
......
......@@ -3,7 +3,6 @@ import torch
from logger import Logger
LOGGER = Logger().logger()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import glob
......
......@@ -5,19 +5,23 @@ from torchvision.utils import save_image
from logger import Logger
LOGGER = Logger().logger()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")
from discriminator import Discriminator
from generator import Generator
from loss import Loss
from preprocess import LoadData
from tqdm import tqdm
from utils import initialize_weights
LOGGER.info("Cuda status: "+str(device))
torch.cuda.empty_cache()
class Train():
def __init__(self,data="dataset/vangogh2photo",pair=False):
def __init__(self,data="dataset/vangogh2photo",pair=False,epochs=200,batch_size=1):
'''
@params
@return
'''
self.epochs = epochs
self.batch_size = batch_size
self.gen_XY = Generator().to(device)
self.gen_YX = Generator().to(device)
self.dis_X = Discriminator().to(device)
......@@ -42,14 +46,16 @@ class Train():
self.losses = {"G": [], "D": [], "C": [], "I": [], "T": []}
self.dataset = LoadData(data=data,pair=pair)
self.dataloader = DataLoader(self.dataset, batch_size=1, shuffle=True, num_workers=4)
self.dataloader = DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True, num_workers=4)
def train(self):
'''
@params
@return
'''
EPOCHS = 200
EPOCHS = self.epochs
batch_size = self.batch_size
for epoch in range(EPOCHS):
'''
Steps:
......@@ -79,10 +85,14 @@ class Train():
'''
adversarial_loss = self.adversarial_loss()
cycle_loss = self.cycle_loss()
identity_loss = self.identity_loss()
size = len(self.dataloader)
for i, data in tqdm(enumerate(self.dataloader), total=size):
torch.cuda.empty_cache()
real_X = data['X'].to(device)
real_Y = data['Y'].to(device)
#print(real_X.shape)
#print(real_Y.shape)
batch_size = real_X.size(0)
real_label = torch.ones(batch_size, 1).to(device)
fake_label = torch.zeros(batch_size, 1).to(device)
......@@ -91,8 +101,15 @@ class Train():
# Training the generator
self.gen_XY_optim.zero_grad()
fake_gen_X = self.gen_XY(real_Y)
fake_gen_Y = self.gen_YX(real_X)
identity_X = self.gen_YX(real_X)
identity_Y = self.gen_XY(real_Y)
loss_iden_X = identity_loss(identity_X,real_X) * 10
loss_iden_Y = identity_loss(identity_Y,real_Y) * 10
fake_gen_X = self.gen_YX(real_Y)
fake_gen_Y = self.gen_XY(real_X)
fake_gen_X_label = self.dis_X(fake_gen_X)
fake_gen_Y_label = self.dis_Y(fake_gen_Y)
......@@ -106,10 +123,10 @@ class Train():
#print(recovered_Y.shape,recovered_X.shape)
loss_cycle_Y2X = cycle_loss(recovered_Y, real_Y)
loss_cycle_X2Y = cycle_loss(recovered_X, real_X)
loss_cycle_Y2X = cycle_loss(recovered_Y, real_Y) * 20
loss_cycle_X2Y = cycle_loss(recovered_X, real_X) * 20
total_loss = loss_gen_Y2X + loss_gen_X2Y + loss_cycle_Y2X + loss_cycle_X2Y
total_loss = loss_gen_Y2X + loss_gen_X2Y + loss_cycle_Y2X + loss_cycle_X2Y + loss_iden_X + loss_iden_Y
#backprop
total_loss.backward()
self.gen_XY_optim.step()
......@@ -149,7 +166,7 @@ class Train():
self.losses["G"].append(total_loss.item())
self.losses["D"].append((loss_dis_X.item() + loss_dis_Y.item()) / 2)
self.losses["C"].append((loss_cycle_Y2X.item() + loss_cycle_X2Y.item()) / 2)
LOGGER.info("Epoch: {} | G: {} | D: {} | C: {}".format(epoch, total_loss.item(), (loss_dis_X.item() + loss_dis_Y.item()) / 2, (loss_cycle_Y2X.item() + loss_cycle_X2Y.item()) / 2))
LOGGER.info("Epoch: {} | i: {} | G: {} | D: {} | C: {}".format(epoch, i, total_loss.item(), (loss_dis_X.item() + loss_dis_Y.item()) / 2, (loss_cycle_Y2X.item() + loss_cycle_X2Y.item()) / 2))
# Save Image
if i % 100 == 0:
......@@ -169,6 +186,9 @@ class Train():
torch.save(self.dis_X.state_dict(), "weights/dis_X.pth")
torch.save(self.dis_Y.state_dict(), "weights/dis_Y.pth")
#Save losses
torch.save(self.losses,"losses.pt")
if __name__ == "__main__":
train = Train()
train.train()
......@@ -4,7 +4,6 @@ from logger import Logger
LOGGER = Logger().logger()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")
#Author: @meetdoshi
def initialize_weights(model):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment