Commit 938b36dc authored by Meet Narendra

Cycle GAN

parent 388a23b3
@@ -3,3 +3,5 @@
*.csv
*.ipynb
*Logs*
*dataset*
*images*
import torch
#Author: @meetdoshi
#Reference: https://github.com/Lornatang/CycleGAN-PyTorch/blob/master/cyclegan_pytorch/models.py
class Discriminator(torch.nn.Module):
    '''
    PatchGAN Discriminator with 70x70 overlapping image patches
    '''
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = torch.nn.Sequential(
            torch.nn.Conv2d(3, 64, 4, stride=2, padding=1),
            torch.nn.LeakyReLU(0.2, True),
            torch.nn.Conv2d(64, 128, 4, stride=2, padding=1),
            torch.nn.InstanceNorm2d(128),
            torch.nn.LeakyReLU(0.2, True),
            torch.nn.Conv2d(128, 256, 4, stride=2, padding=1),
            torch.nn.InstanceNorm2d(256),
            torch.nn.LeakyReLU(0.2, True),
            torch.nn.Conv2d(256, 512, 4, padding=1),
            torch.nn.InstanceNorm2d(512),
            torch.nn.LeakyReLU(0.2, True),
            torch.nn.Conv2d(512, 1, 4, padding=1),
        )
    def forward(self, x):
-       return self.model(x)
        x = self.model(x)
        x = torch.nn.functional.avg_pool2d(x, x.size()[2:])
        x = torch.flatten(x, 1)
        return x
\ No newline at end of file
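For reference, a minimal shape check (not part of the commit; it assumes the module is saved as discriminator.py, which is how train.py later imports it, and a 256x256 RGB input):

import torch
from discriminator import Discriminator

d = Discriminator()
x = torch.randn(2, 3, 256, 256)        # two 256x256 RGB images
out = d(x)
# Three stride-2 convs take 256 -> 32, the two stride-1 convs give a ~30x30
# patch map, and avg_pool2d + flatten reduce it to one score per image.
print(out.shape)                       # expected: torch.Size([2, 1])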
#!/bin/bash
mkdir dataset
cd dataset
for FILE in "apple2orange" "summer2winter_yosemite" "horse2zebra" "monet2photo" "cezanne2photo" "ukiyoe2photo" "vangogh2photo" "maps" "cityscapes" "facades" "iphone2dslr_flower"; do
URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/${FILE}.zip
ZIP_FILE=${FILE}.zip
TARGET_DIR=${FILE}
wget ${URL}
unzip ${ZIP_FILE}
rm ${ZIP_FILE}
# Adapt to the project's expected directory hierarchy
mkdir -p "$TARGET_DIR/train" "$TARGET_DIR/test"
mv "$TARGET_DIR/trainA" "$TARGET_DIR/train/A"
mv "$TARGET_DIR/trainB" "$TARGET_DIR/train/B"
mv "$TARGET_DIR/testA" "$TARGET_DIR/test/A"
mv "$TARGET_DIR/testB" "$TARGET_DIR/test/B"
done
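A quick way to confirm the script produced the layout that LoadData (below) expects is to count the images under each split; this is a hedged helper sketch, not part of the commit, and the dataset name is just an example from the list above:

from pathlib import Path

root = Path("dataset/apple2orange")    # any name from the list in the script
for split in ("train", "test"):
    for side in ("A", "B"):
        d = root / split / side
        print(d, "->", len(list(d.glob("*.jpg"))), "images")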
import torch
import torch.nn as nn
#Author: @meetdoshi
#Reference: https://github.com/Lornatang/CycleGAN-PyTorch/blob/master/cyclegan_pytorch/models.py
class ResidualBlock(nn.Module):
    def __init__(self):
        super(ResidualBlock, self).__init__()
        self.block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(256, 256, 3),
@@ -22,33 +23,40 @@ class Generator(torch.nn.Module):
    https://arxiv.org/pdf/1603.08155.pdf
    '''
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
-           nn.Conv2d(3, 64, 7, 1, 3),
            nn.ReflectionPad2d(3),
            nn.Conv2d(3, 64, 7),
            nn.InstanceNorm2d(64),
            nn.ReLU(True),
            nn.Conv2d(64, 128, 3, 2, 1),
            nn.InstanceNorm2d(128),
            nn.ReLU(True),
            nn.Conv2d(128, 256, 3, 2, 1),
            nn.InstanceNorm2d(256),
            nn.ReLU(True),
            ResidualBlock(),
            ResidualBlock(),
            ResidualBlock(),
            ResidualBlock(),
            ResidualBlock(),
            ResidualBlock(),
            ResidualBlock(),
            ResidualBlock(),
            ResidualBlock(),
            nn.ConvTranspose2d(256, 128, 3, 2, 1, 1),
            nn.InstanceNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 3, 2, 1, 1),
            nn.InstanceNorm2d(64),
            nn.ReLU(True),
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, 3, 7),
            nn.Tanh(),
...
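As a rough sanity check (not part of the commit, and it assumes Generator.forward simply applies self.model, whose definition is collapsed in this diff): the two stride-2 downsamplings are undone by the two transposed convolutions, so the output size matches the input size.

import torch
from generator import Generator

g = Generator()
x = torch.randn(1, 3, 256, 256)
y = g(x)           # 256 -> 128 -> 64 (downsample), 9 residual blocks, 64 -> 128 -> 256 (upsample)
print(y.shape)     # expected: torch.Size([1, 3, 256, 256]), values in [-1, 1] from Tanh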
@@ -5,6 +5,7 @@ LOGGER = Logger().logger()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#Author: @meetdoshi
device = torch.device("cpu")
class Loss:
    @staticmethod
    def adversarial_G():
@@ -12,7 +13,7 @@ class Loss:
        @params
        @return
        '''
-       return torch.nn.BCELoss().to(device)
        return torch.nn.MSELoss().to(device)
    @staticmethod
    def adversarial_D():
@@ -20,25 +21,16 @@ class Loss:
        @params
        @return
        '''
-       return torch.nn.BCELoss().to(device)
        return torch.nn.MSELoss().to(device)
    @staticmethod
-   def cycle_consistency_forward():
    def cycle_consistency():
        '''
        @params
        @return
        '''
        return torch.nn.L1Loss().to(device)
-   @staticmethod
-   def cycle_consistency_backward():
-       '''
-       @params
-       @return
-       '''
-       return torch.nn.L1Loss().to(device)
    @staticmethod
    def identity():
        '''
...
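A small illustration (not from the commit) of how these losses are wired up in train.py: the MSE criterion gives a least-squares GAN objective against all-ones / all-zeros labels, and the L1 criterion drives the cycle reconstruction.

import torch
from loss import Loss

adv = Loss.adversarial_G()             # MSELoss
cyc = Loss.cycle_consistency()         # L1Loss

pred_fake = torch.rand(1, 1)           # discriminator score for a generated image
real_label = torch.ones(1, 1)
g_adv = adv(pred_fake, real_label)     # generator pushes fake scores towards 1

x = torch.rand(1, 3, 256, 256)
recovered_x = torch.rand(1, 3, 256, 256)
g_cyc = cyc(recovered_x, x)            # F(G(x)) should reconstruct x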
@@ -3,21 +3,38 @@ import torch
from logger import Logger
LOGGER = Logger().logger()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import glob
import os
from PIL import Image
import random
#Author: @meetdoshi
class LoadData(Dataset):
-   def __init__(self, data, transform=None):
    def __init__(self, data, pair=False, type_data="train"):
        self.data = data
-       self.transform = transform
        self.pair = pair
        self.type = type_data
        self.transform = transforms.Compose([
            transforms.Resize(int(256 * 1.12), Image.BICUBIC),
            transforms.RandomCrop(256),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        #LOGGER.info(data)
        #LOGGER.info(os.path.join(self.data,f"{type_data}/A"+ "*.jpg"))
        self.X = sorted(glob.glob(os.path.join(self.data, f"{type_data}/A/") + "*.jpg"))
        self.Y = sorted(glob.glob(os.path.join(self.data, f"{type_data}/B/") + "*.jpg"))
    def __len__(self):
-       return len(self.data)
        return max(len(self.X), len(self.Y))
    def __getitem__(self, idx):
-       if torch.is_tensor(idx):
-           idx = idx.tolist()
-       sample = self.data[idx]
-       if self.transform:
-           sample = self.transform(sample)
-       return sample
        if self.pair:
            return {'X': self.transform(Image.open(self.X[idx % len(self.X)])), 'Y': self.transform(Image.open(self.Y[idx % len(self.Y)]))}
        else:
            return {'X': self.transform(Image.open(self.X[idx % len(self.X)])), 'Y': self.transform(Image.open(self.Y[random.randint(0, len(self.Y) - 1) % len(self.Y)]))}
\ No newline at end of file
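Usage sketch (not part of the commit; the dataset folder is an example produced by the download script above):

from torch.utils.data import DataLoader
from preprocess import LoadData

ds = LoadData(data="dataset/horse2zebra", pair=False, type_data="train")
loader = DataLoader(ds, batch_size=1, shuffle=True, num_workers=4)
batch = next(iter(loader))
print(batch['X'].shape, batch['Y'].shape)   # torch.Size([1, 3, 256, 256]) each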
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import save_image
from logger import Logger
LOGGER = Logger().logger()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")
from discriminator import Discriminator
from generator import Generator
from loss import Loss
@@ -10,7 +13,7 @@ from preprocess import LoadData
from tqdm import tqdm
from utils import initialize_weights
class Train():
-   def __init__(self):
    def __init__(self, data="dataset/vangogh2photo", pair=False):
        '''
        @params
        @return
@@ -33,13 +36,15 @@ class Train():
        self.dis_X_scheduler = torch.optim.lr_scheduler.StepLR(self.dis_X_optim, step_size=100, gamma=0.1)
        self.dis_Y_scheduler = torch.optim.lr_scheduler.StepLR(self.dis_Y_optim, step_size=100, gamma=0.1)
-       self.cycle_loss = Loss().cycle_loss.to(device)
-       self.identity_loss = Loss().identity_loss.to(device)
-       self.adversarial_loss = Loss().adversarial_loss.to(device)
        self.cycle_loss = Loss().cycle_consistency
        self.identity_loss = Loss().identity
        self.adversarial_loss = Loss().adversarial_G
        self.losses = {"G": [], "D": [], "C": [], "I": [], "T": []}
        self.dataset = LoadData(data=data, pair=pair)
        self.dataloader = DataLoader(self.dataset, batch_size=1, shuffle=True, num_workers=4)
-   def train():
    def train(self):
        '''
        @params
        @return
@@ -72,3 +77,98 @@ class Train():
        18. Update Learning Rate
        '''
        adversarial_loss = self.adversarial_loss()
        cycle_loss = self.cycle_loss()
        size = len(self.dataloader)
        # NOTE: 'epoch', used in the logging and image filenames below, is not defined in this hunk.
        for i, data in tqdm(enumerate(self.dataloader), total=size):
            real_X = data['X'].to(device)
            real_Y = data['Y'].to(device)
            batch_size = real_X.size(0)
            real_label = torch.ones(batch_size, 1).to(device)
            fake_label = torch.zeros(batch_size, 1).to(device)
            #print(real_label.shape)
            # Training the generator
            self.gen_XY_optim.zero_grad()
            fake_gen_X = self.gen_XY(real_Y)
            fake_gen_Y = self.gen_YX(real_X)
            fake_gen_X_label = self.dis_X(fake_gen_X)
            fake_gen_Y_label = self.dis_Y(fake_gen_Y)
            #print(fake_gen_X.shape,fake_gen_X_label.shape)
            #print(fake_gen_Y_label,real_label)
            loss_gen_Y2X = adversarial_loss(fake_gen_X_label, real_label)
            loss_gen_X2Y = adversarial_loss(fake_gen_Y_label, real_label)
            recovered_Y = self.gen_XY(fake_gen_X)
            recovered_X = self.gen_YX(fake_gen_Y)
            #print(recovered_Y.shape,recovered_X.shape)
            loss_cycle_Y2X = cycle_loss(recovered_Y, real_Y)
            loss_cycle_X2Y = cycle_loss(recovered_X, real_X)
            total_loss = loss_gen_Y2X + loss_gen_X2Y + loss_cycle_Y2X + loss_cycle_X2Y
            #backprop
            total_loss.backward()
            self.gen_XY_optim.step()
            # Training the discriminator
            # Discriminator for X
            self.dis_X_optim.zero_grad()
            real_X_label = self.dis_X(real_X)
            loss_dis_real_A = adversarial_loss(real_X_label, real_label)
            fake_X_label = self.dis_X(fake_gen_X.detach())
            loss_dis_fake_A = adversarial_loss(fake_X_label, fake_label)
            loss_dis_X = (loss_dis_real_A + loss_dis_fake_A) / 2
            #backprop
            loss_dis_X.backward()
            self.dis_X_optim.step()
            # Discriminator for Y
            self.dis_Y_optim.zero_grad()
            real_Y_label = self.dis_Y(real_Y)
            loss_dis_real_B = adversarial_loss(real_Y_label, real_label)
            fake_Y_label = self.dis_Y(fake_gen_Y.detach())
            loss_dis_fake_B = adversarial_loss(fake_Y_label, fake_label)
            loss_dis_Y = (loss_dis_real_B + loss_dis_fake_B) / 2
            #backprop
            loss_dis_Y.backward()
            self.dis_Y_optim.step()
            # Update Logs
            self.losses["G"].append(total_loss.item())
            self.losses["D"].append((loss_dis_X.item() + loss_dis_Y.item()) / 2)
            self.losses["C"].append((loss_cycle_Y2X.item() + loss_cycle_X2Y.item()) / 2)
            LOGGER.info("Epoch: {} | G: {} | D: {} | C: {}".format(epoch, total_loss.item(), (loss_dis_X.item() + loss_dis_Y.item()) / 2, (loss_cycle_Y2X.item() + loss_cycle_X2Y.item()) / 2))
            # Save Image
            if i % 100 == 0:
                save_image(fake_gen_X, "images/{}_{}_fake_X.png".format(epoch, i))
                save_image(fake_gen_Y, "images/{}_{}_fake_Y.png".format(epoch, i))
                save_image(real_X, "images/{}_{}_real_X.png".format(epoch, i))
                save_image(real_Y, "images/{}_{}_real_Y.png".format(epoch, i))
        # Update Learning Rate
        self.gen_XY_scheduler.step()
        self.dis_X_scheduler.step()
        self.dis_Y_scheduler.step()
        # Save weights
        torch.save(self.gen_XY.state_dict(), "weights/gen_XY.pth")
        torch.save(self.gen_YX.state_dict(), "weights/gen_YX.pth")
        torch.save(self.dis_X.state_dict(), "weights/dis_X.pth")
        torch.save(self.dis_Y.state_dict(), "weights/dis_Y.pth")
if __name__ == "__main__":
    train = Train()
    train.train()
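After training, the saved generator can be reloaded for inference; a hedged sketch (not part of the commit), assuming the weights/ paths written above:

import torch
from generator import Generator

gen_XY = Generator()
gen_XY.load_state_dict(torch.load("weights/gen_XY.pth", map_location="cpu"))
gen_XY.eval()
with torch.no_grad():
    x = torch.randn(1, 3, 256, 256)    # stand-in for a preprocessed image in [-1, 1]
    fake = gen_XY(x)                   # translated image, same shape as the input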
@@ -4,6 +4,8 @@ from logger import Logger
LOGGER = Logger().logger()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")
#Author: @meetdoshi
def initialize_weights(model):
    '''
...
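The body of initialize_weights is collapsed in this diff; a common CycleGAN-style initialization (Gaussian weights with std 0.02) would look roughly like the sketch below. This is an assumption for illustration, not the committed implementation.

import torch

def initialize_weights_sketch(model):
    # Hypothetical stand-in for utils.initialize_weights (actual body not shown above).
    for m in model.modules():
        if isinstance(m, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
            torch.nn.init.normal_(m.weight, 0.0, 0.02)
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)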