Commit 938b36dc authored by Meet Narendra's avatar Meet Narendra 💬

Cycle GAN

parent 388a23b3
......@@ -3,3 +3,5 @@
*.csv
*.ipynb
*Logs*
*dataset*
*images*
import torch
#Author: @meetdoshi
#Reference: https://github.com/Lornatang/CycleGAN-PyTorch/blob/master/cyclegan_pytorch/models.py
class Discriminator(torch.nn.Module):
'''
PatchGAN Discriminator with 70x70 overlapping image patches
'''
def __init__(self) -> None:
super().__init__()
def __init__(self):
super(Discriminator,self).__init__()
self.model = torch.nn.Sequential(
torch.nn.Conv2d(3, 64, 4, 2, 1),
torch.nn.Conv2d(3, 64, 4, stride=2, padding=1),
torch.nn.LeakyReLU(0.2, True),
torch.nn.Conv2d(64, 128, 4, 2, 1),
torch.nn.Conv2d(64, 128, 4,stride=2, padding=1),
torch.nn.InstanceNorm2d(128),
torch.nn.LeakyReLU(0.2, True),
torch.nn.Conv2d(128, 256, 4, 2, 1),
torch.nn.Conv2d(128, 256, 4, stride=2, padding=1),
torch.nn.InstanceNorm2d(256),
torch.nn.LeakyReLU(0.2, True),
torch.nn.Conv2d(256, 512, 4, 1, 1),
torch.nn.Conv2d(256, 512, 4,padding=1),
torch.nn.InstanceNorm2d(512),
torch.nn.LeakyReLU(0.2, True),
torch.nn.Conv2d(512, 1, 4, 1, 1),
torch.nn.Conv2d(512, 1, 4, padding=1),
)
def forward(self, x):
return self.model(x)
\ No newline at end of file
x = self.model(x)
x = torch.nn.functional.avg_pool2d(x,x.size()[2:])
x = torch.flatten(x,1)
return x
\ No newline at end of file
#!/bin/bash
mkdir dataset
cd dataset
for FILE in "apple2orange" "summer2winter_yosemite" "horse2zebra" "monet2photo" "cezanne2photo" "ukiyoe2photo" "vangogh2photo" "maps" "cityscapes" "facades" "iphone2dslr_flower"; do
URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/${FILE}.zip
ZIP_FILE=${FILE}.zip
TARGET_DIR=${FILE}
wget ${URL}
unzip ${ZIP_FILE}
rm ${ZIP_FILE}
# Adapt to project expected directory heriarchy
mkdir -p "$TARGET_DIR/train" "$TARGET_DIR/test"
mv "$TARGET_DIR/trainA" "$TARGET_DIR/train/A"
mv "$TARGET_DIR/trainB" "$TARGET_DIR/train/B"
mv "$TARGET_DIR/testA" "$TARGET_DIR/test/A"
mv "$TARGET_DIR/testB" "$TARGET_DIR/test/B"
done
import torch
import torch.nn as nn
#Author: @meetdoshi
#Reference: https://github.com/Lornatang/CycleGAN-PyTorch/blob/master/cyclegan_pytorch/models.py
class ResidualBlock(nn.Module):
def __init__(self) -> None:
super().__init__()
def __init__(self):
super(ResidualBlock,self).__init__()
self.block = nn.Sequential(
nn.ReflectionPad2d(1),
nn.Conv2d(256, 256, 3),
......@@ -22,33 +23,40 @@ class Generator(torch.nn.Module):
https://arxiv.org/pdf/1603.08155.pdf
'''
def __init__(self) -> None:
super().__init__()
def __init__(self):
super(Generator,self).__init__()
self.model = nn.Sequential(
nn.Conv2d(3, 64, 7, 1, 3),
nn.ReflectionPad2d(3),
nn.Conv2d(3, 64, 7),
nn.InstanceNorm2d(64),
nn.ReLU(True),
nn.Conv2d(64, 128, 3, 2, 1),
nn.InstanceNorm2d(128),
nn.ReLU(True),
nn.Conv2d(128, 256, 3, 2, 1),
nn.InstanceNorm2d(256),
nn.ReLU(True),
ResidualBlock(),
ResidualBlock(),
ResidualBlock(),
ResidualBlock(),
ResidualBlock(),
ResidualBlock(),
#ResidualBlock(),
#ResidualBlock(),
#ResidualBlock(),
#ResidualBlock(),
ResidualBlock(),
ResidualBlock(),
ResidualBlock(),
nn.ConvTranspose2d(256, 128, 3, 2, 1, 1),
nn.InstanceNorm2d(128),
nn.ReLU(True),
nn.ConvTranspose2d(128, 64, 3, 2, 1, 1),
nn.InstanceNorm2d(64),
nn.ReLU(True),
nn.ReflectionPad2d(3),
nn.Conv2d(64, 3, 7),
nn.Tanh(),
......
......@@ -5,6 +5,7 @@ LOGGER = Logger().logger()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#Author: @meetdoshi
device = torch.device("cpu")
class Loss:
@staticmethod
def adversarial_G():
......@@ -12,7 +13,7 @@ class Loss:
@params
@return
'''
return torch.nn.BCELoss().to(device)
return torch.nn.MSELoss().to(device)
@staticmethod
def adversarial_D():
......@@ -20,25 +21,16 @@ class Loss:
@params
@return
'''
return torch.nn.BCELoss().to(device)
return torch.nn.MSELoss().to(device)
@staticmethod
def cycle_consistency_forward():
def cycle_consistency():
'''
@params
@return
'''
return torch.nn.L1Loss().to(device)
@staticmethod
def cycle_consistency_backward():
'''
@params
@return
'''
return torch.nn.L1Loss().to(device)
@staticmethod
def identity():
'''
......
......@@ -3,21 +3,38 @@ import torch
from logger import Logger
LOGGER = Logger().logger()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import glob
import os
from PIL import Image
import random
#Author: @meetdoshi
class LoadData(Dataset):
def __init__(self, data, transform=None):
def __init__(self, data, pair=False, type_data="train"):
self.data = data
self.transform = transform
self.pair = pair
self.type = type_data
self.transform = transforms.Compose([
transforms.Resize(int(256 * 1.12), Image.BICUBIC),
transforms.RandomCrop(256),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
#LOGGER.info(data)
#LOGGER.info(os.path.join(self.data,f"{type_data}/A"+ "*.jpg"))
self.X = sorted(glob.glob(os.path.join(self.data, f"{type_data}/A/")+ "*.jpg"))
self.Y = sorted(glob.glob(os.path.join(self.data, f"{type_data}/B/")+ "*.jpg"))
def __len__(self):
return len(self.data)
return max(len(self.X), len(self.Y))
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()
sample = self.data[idx]
if self.transform:
sample = self.transform(sample)
return sample
\ No newline at end of file
if self.pair:
return {'X':self.transform(Image.open(self.X[idx%len(self.X)])), 'Y':self.transform(Image.open(self.Y[idx%len(self.Y)]))}
else:
return {'X':self.transform(Image.open(self.X[idx%len(self.X)])), 'Y':self.transform(Image.open(self.Y[random.randint(0, len(self.Y)-1)%len(self.Y)]))}
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import save_image
from logger import Logger
LOGGER = Logger().logger()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")
from discriminator import Discriminator
from generator import Generator
from loss import Loss
......@@ -10,7 +13,7 @@ from preprocess import LoadData
from tqdm import tqdm
from utils import initialize_weights
class Train():
def __init__(self):
def __init__(self,data="dataset/vangogh2photo",pair=False):
'''
@params
@return
......@@ -33,13 +36,15 @@ class Train():
self.dis_X_scheduler = torch.optim.lr_scheduler.StepLR(self.dis_X_optim, step_size=100, gamma=0.1)
self.dis_Y_scheduler = torch.optim.lr_scheduler.StepLR(self.dis_Y_optim, step_size=100, gamma=0.1)
self.cycle_loss = Loss().cycle_loss.to(device)
self.identity_loss = Loss().identity_loss.to(device)
self.adversarial_loss = Loss().adversarial_loss.to(device)
self.cycle_loss = Loss().cycle_consistency
self.identity_loss = Loss().identity
self.adversarial_loss = Loss().adversarial_G
self.losses = {"G": [], "D": [], "C": [], "I": [], "T": []}
self.dataset = LoadData(data=data,pair=pair)
self.dataloader = DataLoader(self.dataset, batch_size=1, shuffle=True, num_workers=4)
def train():
def train(self):
'''
@params
@return
......@@ -71,4 +76,99 @@ class Train():
17. Save Image
18. Update Learning Rate
'''
\ No newline at end of file
'''
adversarial_loss = self.adversarial_loss()
cycle_loss = self.cycle_loss()
size = len(self.dataloader)
for i, data in tqdm(enumerate(self.dataloader), total=size):
real_X = data['X'].to(device)
real_Y = data['Y'].to(device)
batch_size = real_X.size(0)
real_label = torch.ones(batch_size, 1).to(device)
fake_label = torch.zeros(batch_size, 1).to(device)
#print(real_label.shape)
# Training the generator
self.gen_XY_optim.zero_grad()
fake_gen_X = self.gen_XY(real_Y)
fake_gen_Y = self.gen_YX(real_X)
fake_gen_X_label = self.dis_X(fake_gen_X)
fake_gen_Y_label = self.dis_Y(fake_gen_Y)
#print(fake_gen_X.shape,fake_gen_X_label.shape)
#print(fake_gen_Y_label,real_label)
loss_gen_Y2X = adversarial_loss(fake_gen_X_label, real_label)
loss_gen_X2Y = adversarial_loss(fake_gen_Y_label, real_label)
recovered_Y = self.gen_XY(fake_gen_X)
recovered_X = self.gen_YX(fake_gen_Y)
#print(recovered_Y.shape,recovered_X.shape)
loss_cycle_Y2X = cycle_loss(recovered_Y, real_Y)
loss_cycle_X2Y = cycle_loss(recovered_X, real_X)
total_loss = loss_gen_Y2X + loss_gen_X2Y + loss_cycle_Y2X + loss_cycle_X2Y
#backprop
total_loss.backward()
self.gen_XY_optim.step()
# Training the discriminator
# Discriminator for X
self.dis_X_optim.zero_grad()
real_X_label = self.dis_X(real_X)
loss_dis_real_A = adversarial_loss(real_X_label, real_label)
fake_X_label = self.dis_X(fake_gen_X.detach())
loss_dis_fake_A = adversarial_loss(fake_X_label, fake_label)
loss_dis_X = (loss_dis_real_A + loss_dis_fake_A) / 2
#backprop
loss_dis_X.backward()
self.dis_X_optim.step()
# Discriminator for Y
self.dis_Y_optim.zero_grad()
real_Y_label = self.dis_Y(real_Y)
loss_dis_real_B = adversarial_loss(real_Y_label, real_label)
fake_Y_label = self.dis_Y(fake_gen_Y.detach())
loss_dis_fake_B = adversarial_loss(fake_Y_label, fake_label)
loss_dis_Y = (loss_dis_real_B + loss_dis_fake_B) / 2
#backprop
loss_dis_Y.backward()
self.dis_Y_optim.step()
# Update Logs
self.losses["G"].append(total_loss.item())
self.losses["D"].append((loss_dis_X.item() + loss_dis_Y.item()) / 2)
self.losses["C"].append((loss_cycle_Y2X.item() + loss_cycle_X2Y.item()) / 2)
LOGGER.info("Epoch: {} | G: {} | D: {} | C: {}".format(epoch, total_loss.item(), (loss_dis_X.item() + loss_dis_Y.item()) / 2, (loss_cycle_Y2X.item() + loss_cycle_X2Y.item()) / 2))
# Save Image
if i % 100 == 0:
save_image(fake_gen_X, "images/{}_{}_fake_X.png".format(epoch, i))
save_image(fake_gen_Y, "images/{}_{}_fake_Y.png".format(epoch, i))
save_image(real_X, "images/{}_{}_real_X.png".format(epoch, i))
save_image(real_Y, "images/{}_{}_real_Y.png".format(epoch, i))
# Update Learning Rate
self.gen_XY_scheduler.step()
self.dis_X_scheduler.step()
self.dis_Y_scheduler.step()
# Save weights
torch.save(self.gen_XY.state_dict(), "weights/gen_XY.pth")
torch.save(self.gen_YX.state_dict(), "weights/gen_YX.pth")
torch.save(self.dis_X.state_dict(), "weights/dis_X.pth")
torch.save(self.dis_Y.state_dict(), "weights/dis_Y.pth")
if __name__ == "__main__":
train = Train()
train.train()
......@@ -4,6 +4,8 @@ from logger import Logger
LOGGER = Logger().logger()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")
#Author: @meetdoshi
def initialize_weights(model):
'''
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment