Commit f77da926 authored by Meet Narendra's avatar Meet Narendra 💬

Merge branch 'DEV' into 'master'

Merging dev to master

See merge request !3
parents 7c0e64ce 061d52d4
......@@ -3,3 +3,6 @@
*.csv
*.ipynb
*Logs*
*dataset*
*images*
*weights*
import torch
#Author: @meetdoshi
#Reference: https://github.com/Lornatang/CycleGAN-PyTorch/blob/master/cyclegan_pytorch/models.py
class Discriminator(torch.nn.Module):
'''
PatchGAN Discriminator with 70x70 overlapping image patches
'''
def __init__(self):
super(Discriminator,self).__init__()
self.model = torch.nn.Sequential(
torch.nn.Conv2d(3, 64, 4, stride=2, padding=1),
torch.nn.LeakyReLU(0.2, True),
torch.nn.Conv2d(64, 128, 4,stride=2, padding=1),
torch.nn.InstanceNorm2d(128),
torch.nn.LeakyReLU(0.2, True),
torch.nn.Conv2d(128, 256, 4, stride=2, padding=1),
torch.nn.InstanceNorm2d(256),
torch.nn.LeakyReLU(0.2, True),
torch.nn.Conv2d(256, 512, 4,padding=1),
torch.nn.InstanceNorm2d(512),
torch.nn.LeakyReLU(0.2, True),
torch.nn.Conv2d(512, 1, 4, padding=1),
)
def forward(self, x):
x = self.model(x)
x = torch.nn.functional.avg_pool2d(x,x.size()[2:])
x = torch.flatten(x,1)
return x
\ No newline at end of file
#!/bin/bash
mkdir dataset
cd dataset
for FILE in "apple2orange" "summer2winter_yosemite" "horse2zebra" "monet2photo" "cezanne2photo" "ukiyoe2photo" "vangogh2photo" "maps" "cityscapes" "facades" "iphone2dslr_flower"; do
URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/${FILE}.zip
ZIP_FILE=${FILE}.zip
TARGET_DIR=${FILE}
wget ${URL}
unzip ${ZIP_FILE}
rm ${ZIP_FILE}
# Adapt to project expected directory heriarchy
mkdir -p "$TARGET_DIR/train" "$TARGET_DIR/test"
mv "$TARGET_DIR/trainA" "$TARGET_DIR/train/A"
mv "$TARGET_DIR/trainB" "$TARGET_DIR/train/B"
mv "$TARGET_DIR/testA" "$TARGET_DIR/test/A"
mv "$TARGET_DIR/testB" "$TARGET_DIR/test/B"
done
import torch
import torch.nn as nn
#Author: @meetdoshi
#Reference: https://github.com/Lornatang/CycleGAN-PyTorch/blob/master/cyclegan_pytorch/models.py
class ResidualBlock(nn.Module):
def __init__(self):
super(ResidualBlock,self).__init__()
self.block = nn.Sequential(
nn.ReflectionPad2d(1),
nn.Conv2d(256, 256, 3),
nn.InstanceNorm2d(256),
nn.ReLU(inplace=True),
nn.ReflectionPad2d(1),
nn.Conv2d(256, 256, 3),
nn.InstanceNorm2d(256),
)
def forward(self, x):
return x + self.block(x)
class Generator(torch.nn.Module):
'''
https://arxiv.org/pdf/1603.08155.pdf
'''
def __init__(self):
super(Generator,self).__init__()
self.model = nn.Sequential(
nn.ReflectionPad2d(3),
nn.Conv2d(3, 64, 7),
nn.InstanceNorm2d(64),
nn.ReLU(True),
nn.Conv2d(64, 128, 3, 2, 1),
nn.InstanceNorm2d(128),
nn.ReLU(True),
nn.Conv2d(128, 256, 3, 2, 1),
nn.InstanceNorm2d(256),
nn.ReLU(True),
ResidualBlock(),
ResidualBlock(),
ResidualBlock(),
ResidualBlock(),
ResidualBlock(),
#ResidualBlock(),
#ResidualBlock(),
#ResidualBlock(),
#ResidualBlock(),
nn.ConvTranspose2d(256, 128, 3, 2, 1, 1),
nn.InstanceNorm2d(128),
nn.ReLU(True),
nn.ConvTranspose2d(128, 64, 3, 2, 1, 1),
nn.InstanceNorm2d(64),
nn.ReLU(True),
nn.ReflectionPad2d(3),
nn.Conv2d(64, 3, 7),
nn.Tanh(),
)
def forward(self, x):
return self.model(x)
import torch
class ImageBuffer:
def __init__(self) -> None:
pass
\ No newline at end of file
......@@ -2,10 +2,38 @@ import numpy as np
import torch
from logger import Logger
LOGGER = Logger().logger()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#Author: @meetdoshi
class Loss:
@staticmethod
def adversarial_G():
'''
L_gan(G,Dy,X,Y) =
@params
@return
'''
return torch.nn.MSELoss().to(device)
@staticmethod
def cycle_consistency():
'''
@params
@return
'''
return torch.nn.L1Loss().to(device)
@staticmethod
def identity():
'''
@params
@return
'''
return torch.nn.L1Loss().to(device)
@staticmethod
def total_loss(lamda=10):
'''
@params
@return
'''
return Loss.adversarial_G() + Loss.adversarial_D() + lamda*Loss.cycle_consistency()
\ No newline at end of file
import torch
#Author: @meetdoshi
class Optimizer():
@staticmethod
def gradient_descent():
pass
\ No newline at end of file
import numpy as np
import torch
from logger import Logger
LOGGER = Logger().logger()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import glob
import os
from PIL import Image
import random
#Author: @meetdoshi
class LoadData(Dataset):
def __init__(self, data, pair=False, type_data="train"):
self.data = data
self.pair = pair
self.type = type_data
self.transform = transforms.Compose([
transforms.Resize(int(256 * 1.12), Image.BICUBIC),
transforms.RandomCrop(256),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
#LOGGER.info(data)
#LOGGER.info(os.path.join(self.data,f"{type_data}/A"+ "*.jpg"))
self.X = sorted(glob.glob(os.path.join(self.data, f"{type_data}/A/")+ "*.jpg"))
self.Y = sorted(glob.glob(os.path.join(self.data, f"{type_data}/B/")+ "*.jpg"))
def __len__(self):
return max(len(self.X), len(self.Y))
def __getitem__(self, idx):
if self.pair:
return {'X':self.transform(Image.open(self.X[idx%len(self.X)])), 'Y':self.transform(Image.open(self.Y[idx%len(self.Y)]))}
else:
return {'X':self.transform(Image.open(self.X[idx%len(self.X)])), 'Y':self.transform(Image.open(self.Y[random.randint(0, len(self.Y)-1)%len(self.Y)]))}
import numpy as np
import itertools
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import save_image
from logger import Logger
LOGGER = Logger().logger()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
from discriminator import Discriminator
from generator import Generator
from loss import Loss
from preprocess import LoadData
from tqdm import tqdm
from utils import initialize_weights
LOGGER.info("Cuda status: "+str(device))
torch.cuda.empty_cache()
import os
class Train():
def __init__(self,data="dataset/summer2winter_yosemite",pair=False,epochs=200,batch_size=1):
'''
@params
@return
'''
if not os.path.exists("images"):
os.mkdir("images")
if not os.path.exists("weights"):
os.mkdir("weights")
self.epochs = epochs
self.batch_size = batch_size
self.gen_XY = Generator().to(device)
self.gen_YX = Generator().to(device)
self.dis_X = Discriminator().to(device)
self.dis_Y = Discriminator().to(device)
self.gen_XY.apply(initialize_weights)
self.gen_YX.apply(initialize_weights)
self.dis_X.apply(initialize_weights)
self.dis_Y.apply(initialize_weights)
self.gen_XY_optim = torch.optim.Adam(itertools.chain(self.gen_XY.parameters(),self.gen_YX.parameters()), lr=0.0002, betas=(0.5, 0.999))
self.dis_X_optim = torch.optim.Adam(self.dis_X.parameters(), lr=0.0002, betas=(0.5, 0.999))
self.dis_Y_optim = torch.optim.Adam(self.dis_Y.parameters(), lr=0.0002, betas=(0.5, 0.999))
self.gen_XY_scheduler = torch.optim.lr_scheduler.StepLR(self.gen_XY_optim, step_size=100, gamma=0.1)
self.dis_X_scheduler = torch.optim.lr_scheduler.StepLR(self.dis_X_optim, step_size=100, gamma=0.1)
self.dis_Y_scheduler = torch.optim.lr_scheduler.StepLR(self.dis_Y_optim, step_size=100, gamma=0.1)
self.cycle_loss = Loss().cycle_consistency
self.identity_loss = Loss().identity
self.adversarial_loss = Loss().adversarial_G
self.losses = {"G": [], "D": [], "C": [], "I": [], "T": []}
self.dataset = LoadData(data=data,pair=pair)
self.dataloader = DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True, num_workers=4)
def train(self):
'''
@params
@return
'''
EPOCHS = self.epochs
batch_size = self.batch_size
for epoch in range(EPOCHS):
'''
Steps:
For Generator
1. Load Images
2. Reset Gradients
3. Generate Fake Images
4. Calculate Advetisarial Loss
5. Generate Original Images from Fake Images
6. Calculate Cycle Consistency Loss
7. Total Loss
8. Backpropagate
9. Update Weights
For Discriminator
10. Reset Gradients
11. Calculate Adversarial Loss between labels for real images and generated images
12. Total loss = Average of both losses
13. Backpropagate
14. Update Weights
15. Do same for other discriminator
16. Update Logs
17. Save Image
18. Update Learning Rate
'''
adversarial_loss = self.adversarial_loss()
cycle_loss = self.cycle_loss()
identity_loss = self.identity_loss()
size = len(self.dataloader)
for i, data in tqdm(enumerate(self.dataloader), total=size):
torch.cuda.empty_cache()
real_X = data['X'].to(device)
real_Y = data['Y'].to(device)
#print(real_X.shape)
#print(real_Y.shape)
batch_size = real_X.size(0)
real_label = torch.ones(batch_size, 1).to(device)
fake_label = torch.zeros(batch_size, 1).to(device)
#print(real_label.shape)
# Training the generator
self.gen_XY_optim.zero_grad()
identity_X = self.gen_YX(real_X)
identity_Y = self.gen_XY(real_Y)
loss_iden_X = identity_loss(identity_X,real_X) * 10
loss_iden_Y = identity_loss(identity_Y,real_Y) * 10
fake_gen_X = self.gen_YX(real_Y)
fake_gen_Y = self.gen_XY(real_X)
fake_gen_X_label = self.dis_X(fake_gen_X)
fake_gen_Y_label = self.dis_Y(fake_gen_Y)
#print(fake_gen_X.shape,fake_gen_X_label.shape)
#print(fake_gen_Y_label,real_label)
loss_gen_Y2X = adversarial_loss(fake_gen_X_label, real_label)
loss_gen_X2Y = adversarial_loss(fake_gen_Y_label, real_label)
recovered_Y = self.gen_XY(fake_gen_X)
recovered_X = self.gen_YX(fake_gen_Y)
#print(recovered_Y.shape,recovered_X.shape)
loss_cycle_Y2X = cycle_loss(recovered_Y, real_Y) * 20
loss_cycle_X2Y = cycle_loss(recovered_X, real_X) * 20
total_loss = loss_gen_Y2X + loss_gen_X2Y + loss_cycle_Y2X + loss_cycle_X2Y + loss_iden_X + loss_iden_Y
#backprop
total_loss.backward()
self.gen_XY_optim.step()
# Training the discriminator
# Discriminator for X
self.dis_X_optim.zero_grad()
real_X_label = self.dis_X(real_X)
loss_dis_real_X = adversarial_loss(real_X_label, real_label)
fake_X_label = self.dis_X(fake_gen_X.detach())
loss_dis_fake_X = adversarial_loss(fake_X_label, fake_label)
loss_dis_X = (loss_dis_real_X + loss_dis_fake_X) / 2
#backprop
loss_dis_X.backward()
self.dis_X_optim.step()
# Discriminator for Y
self.dis_Y_optim.zero_grad()
real_Y_label = self.dis_Y(real_Y)
loss_dis_real_Y = adversarial_loss(real_Y_label, real_label)
fake_Y_label = self.dis_Y(fake_gen_Y.detach())
loss_dis_fake_Y = adversarial_loss(fake_Y_label, fake_label)
loss_dis_Y = (loss_dis_real_Y + loss_dis_fake_Y) / 2
#backprop
loss_dis_Y.backward()
self.dis_Y_optim.step()
# Update Logs
self.losses["G"].append(total_loss.item())
self.losses["D"].append((loss_dis_X.item() + loss_dis_Y.item()) / 2)
self.losses["C"].append((loss_cycle_Y2X.item() + loss_cycle_X2Y.item()) / 2)
LOGGER.info("Epoch: {} | i: {} | G: {} | D: {} | C: {}".format(epoch, i, total_loss.item(), (loss_dis_X.item() + loss_dis_Y.item()) / 2, (loss_cycle_Y2X.item() + loss_cycle_X2Y.item()) / 2))
# Save Image
if i % 100 == 0:
save_image(fake_gen_X, "images/{}_{}_fake_X.png".format(epoch, i))
save_image(fake_gen_Y, "images/{}_{}_fake_Y.png".format(epoch, i))
save_image(real_X, "images/{}_{}_real_X.png".format(epoch, i))
save_image(real_Y, "images/{}_{}_real_Y.png".format(epoch, i))
# Update Learning Rate
self.gen_XY_scheduler.step()
self.dis_X_scheduler.step()
self.dis_Y_scheduler.step()
# Save weights
torch.save(self.gen_XY.state_dict(), "weights/gen_XY.pth")
torch.save(self.gen_YX.state_dict(), "weights/gen_YX.pth")
torch.save(self.dis_X.state_dict(), "weights/dis_X.pth")
torch.save(self.dis_Y.state_dict(), "weights/dis_Y.pth")
#Save losses
torch.save(self.losses,"losses.pt")
if __name__ == "__main__":
train = Train()
train.train()
import numpy as np
import torch
from logger import Logger
LOGGER = Logger().logger()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#Author: @meetdoshi
def initialize_weights(model):
'''
@params
@return
'''
for m in model.modules():
if isinstance(m, torch.nn.Conv2d):
torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
elif isinstance(m, torch.nn.BatchNorm2d):
torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
torch.nn.init.constant_(m.bias.data, 0)
elif isinstance(m, torch.nn.Linear):
torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
torch.nn.init.constant_(m.bias.data, 0)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment