Commit a03af97b authored by Meet Narendra's avatar Meet Narendra 💬

Cycle gans minor modifications

parent 938b36dc
...@@ -44,10 +44,10 @@ class Generator(torch.nn.Module): ...@@ -44,10 +44,10 @@ class Generator(torch.nn.Module):
ResidualBlock(), ResidualBlock(),
ResidualBlock(), ResidualBlock(),
ResidualBlock(), ResidualBlock(),
ResidualBlock(), #ResidualBlock(),
ResidualBlock(), #ResidualBlock(),
ResidualBlock(), #ResidualBlock(),
ResidualBlock(), #ResidualBlock(),
nn.ConvTranspose2d(256, 128, 3, 2, 1, 1), nn.ConvTranspose2d(256, 128, 3, 2, 1, 1),
nn.InstanceNorm2d(128), nn.InstanceNorm2d(128),
......
...@@ -5,7 +5,6 @@ LOGGER = Logger().logger() ...@@ -5,7 +5,6 @@ LOGGER = Logger().logger()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#Author: @meetdoshi #Author: @meetdoshi
device = torch.device("cpu")
class Loss: class Loss:
@staticmethod @staticmethod
def adversarial_G(): def adversarial_G():
...@@ -15,14 +14,6 @@ class Loss: ...@@ -15,14 +14,6 @@ class Loss:
''' '''
return torch.nn.MSELoss().to(device) return torch.nn.MSELoss().to(device)
@staticmethod
def adversarial_D():
'''
@params
@return
'''
return torch.nn.MSELoss().to(device)
@staticmethod @staticmethod
def cycle_consistency(): def cycle_consistency():
''' '''
......
...@@ -3,7 +3,6 @@ import torch ...@@ -3,7 +3,6 @@ import torch
from logger import Logger from logger import Logger
LOGGER = Logger().logger() LOGGER = Logger().logger()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")
from torch.utils.data import Dataset, DataLoader from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils from torchvision import transforms, utils
import glob import glob
......
...@@ -5,19 +5,23 @@ from torchvision.utils import save_image ...@@ -5,19 +5,23 @@ from torchvision.utils import save_image
from logger import Logger from logger import Logger
LOGGER = Logger().logger() LOGGER = Logger().logger()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")
from discriminator import Discriminator from discriminator import Discriminator
from generator import Generator from generator import Generator
from loss import Loss from loss import Loss
from preprocess import LoadData from preprocess import LoadData
from tqdm import tqdm from tqdm import tqdm
from utils import initialize_weights from utils import initialize_weights
LOGGER.info("Cuda status: "+str(device))
torch.cuda.empty_cache()
class Train(): class Train():
def __init__(self,data="dataset/vangogh2photo",pair=False): def __init__(self,data="dataset/vangogh2photo",pair=False,epochs=200,batch_size=1):
''' '''
@params @params
@return @return
''' '''
self.epochs = epochs
self.batch_size = batch_size
self.gen_XY = Generator().to(device) self.gen_XY = Generator().to(device)
self.gen_YX = Generator().to(device) self.gen_YX = Generator().to(device)
self.dis_X = Discriminator().to(device) self.dis_X = Discriminator().to(device)
...@@ -42,14 +46,16 @@ class Train(): ...@@ -42,14 +46,16 @@ class Train():
self.losses = {"G": [], "D": [], "C": [], "I": [], "T": []} self.losses = {"G": [], "D": [], "C": [], "I": [], "T": []}
self.dataset = LoadData(data=data,pair=pair) self.dataset = LoadData(data=data,pair=pair)
self.dataloader = DataLoader(self.dataset, batch_size=1, shuffle=True, num_workers=4) self.dataloader = DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True, num_workers=4)
def train(self): def train(self):
''' '''
@params @params
@return @return
''' '''
EPOCHS = 200 EPOCHS = self.epochs
batch_size = self.batch_size
for epoch in range(EPOCHS): for epoch in range(EPOCHS):
''' '''
Steps: Steps:
...@@ -79,10 +85,14 @@ class Train(): ...@@ -79,10 +85,14 @@ class Train():
''' '''
adversarial_loss = self.adversarial_loss() adversarial_loss = self.adversarial_loss()
cycle_loss = self.cycle_loss() cycle_loss = self.cycle_loss()
identity_loss = self.identity_loss()
size = len(self.dataloader) size = len(self.dataloader)
for i, data in tqdm(enumerate(self.dataloader), total=size): for i, data in tqdm(enumerate(self.dataloader), total=size):
torch.cuda.empty_cache()
real_X = data['X'].to(device) real_X = data['X'].to(device)
real_Y = data['Y'].to(device) real_Y = data['Y'].to(device)
#print(real_X.shape)
#print(real_Y.shape)
batch_size = real_X.size(0) batch_size = real_X.size(0)
real_label = torch.ones(batch_size, 1).to(device) real_label = torch.ones(batch_size, 1).to(device)
fake_label = torch.zeros(batch_size, 1).to(device) fake_label = torch.zeros(batch_size, 1).to(device)
...@@ -91,8 +101,15 @@ class Train(): ...@@ -91,8 +101,15 @@ class Train():
# Training the generator # Training the generator
self.gen_XY_optim.zero_grad() self.gen_XY_optim.zero_grad()
fake_gen_X = self.gen_XY(real_Y) identity_X = self.gen_YX(real_X)
fake_gen_Y = self.gen_YX(real_X) identity_Y = self.gen_XY(real_Y)
loss_iden_X = identity_loss(identity_X,real_X) * 10
loss_iden_Y = identity_loss(identity_Y,real_Y) * 10
fake_gen_X = self.gen_YX(real_Y)
fake_gen_Y = self.gen_XY(real_X)
fake_gen_X_label = self.dis_X(fake_gen_X) fake_gen_X_label = self.dis_X(fake_gen_X)
fake_gen_Y_label = self.dis_Y(fake_gen_Y) fake_gen_Y_label = self.dis_Y(fake_gen_Y)
...@@ -106,10 +123,10 @@ class Train(): ...@@ -106,10 +123,10 @@ class Train():
#print(recovered_Y.shape,recovered_X.shape) #print(recovered_Y.shape,recovered_X.shape)
loss_cycle_Y2X = cycle_loss(recovered_Y, real_Y) loss_cycle_Y2X = cycle_loss(recovered_Y, real_Y) * 20
loss_cycle_X2Y = cycle_loss(recovered_X, real_X) loss_cycle_X2Y = cycle_loss(recovered_X, real_X) * 20
total_loss = loss_gen_Y2X + loss_gen_X2Y + loss_cycle_Y2X + loss_cycle_X2Y total_loss = loss_gen_Y2X + loss_gen_X2Y + loss_cycle_Y2X + loss_cycle_X2Y + loss_iden_X + loss_iden_Y
#backprop #backprop
total_loss.backward() total_loss.backward()
self.gen_XY_optim.step() self.gen_XY_optim.step()
...@@ -149,7 +166,7 @@ class Train(): ...@@ -149,7 +166,7 @@ class Train():
self.losses["G"].append(total_loss.item()) self.losses["G"].append(total_loss.item())
self.losses["D"].append((loss_dis_X.item() + loss_dis_Y.item()) / 2) self.losses["D"].append((loss_dis_X.item() + loss_dis_Y.item()) / 2)
self.losses["C"].append((loss_cycle_Y2X.item() + loss_cycle_X2Y.item()) / 2) self.losses["C"].append((loss_cycle_Y2X.item() + loss_cycle_X2Y.item()) / 2)
LOGGER.info("Epoch: {} | G: {} | D: {} | C: {}".format(epoch, total_loss.item(), (loss_dis_X.item() + loss_dis_Y.item()) / 2, (loss_cycle_Y2X.item() + loss_cycle_X2Y.item()) / 2)) LOGGER.info("Epoch: {} | i: {} | G: {} | D: {} | C: {}".format(epoch, i, total_loss.item(), (loss_dis_X.item() + loss_dis_Y.item()) / 2, (loss_cycle_Y2X.item() + loss_cycle_X2Y.item()) / 2))
# Save Image # Save Image
if i % 100 == 0: if i % 100 == 0:
...@@ -169,6 +186,9 @@ class Train(): ...@@ -169,6 +186,9 @@ class Train():
torch.save(self.dis_X.state_dict(), "weights/dis_X.pth") torch.save(self.dis_X.state_dict(), "weights/dis_X.pth")
torch.save(self.dis_Y.state_dict(), "weights/dis_Y.pth") torch.save(self.dis_Y.state_dict(), "weights/dis_Y.pth")
#Save losses
torch.save(self.losses,"losses.pt")
if __name__ == "__main__": if __name__ == "__main__":
train = Train() train = Train()
train.train() train.train()
...@@ -4,7 +4,6 @@ from logger import Logger ...@@ -4,7 +4,6 @@ from logger import Logger
LOGGER = Logger().logger() LOGGER = Logger().logger()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")
#Author: @meetdoshi #Author: @meetdoshi
def initialize_weights(model): def initialize_weights(model):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment