Commit 388a23b3 — authored by Meet Narendra.

Framework for CycleGANs.

Parent commit: 6cece00c
import torch
#Author: @meetdoshi
class Discriminator(torch.nn.Module):
    '''
    70x70 PatchGAN discriminator: maps a 3-channel image to a grid of
    real/fake scores, one score per overlapping receptive-field patch.
    '''

    def __init__(self) -> None:
        super().__init__()
        # (in_channels, out_channels, stride, apply_instance_norm) per stage.
        stages = [
            (3, 64, 2, False),   # first stage has no normalisation
            (64, 128, 2, True),
            (128, 256, 2, True),
            (256, 512, 1, True),
        ]
        layers = []
        for in_ch, out_ch, stride, use_norm in stages:
            layers.append(torch.nn.Conv2d(in_ch, out_ch, 4, stride, 1))
            if use_norm:
                layers.append(torch.nn.InstanceNorm2d(out_ch))
            layers.append(torch.nn.LeakyReLU(0.2, True))
        # Final 1-channel projection producing the patch score map.
        layers.append(torch.nn.Conv2d(512, 1, 4, 1, 1))
        self.model = torch.nn.Sequential(*layers)

    def forward(self, x):
        '''Return the patch-wise score map for the input batch x.'''
        return self.model(x)
\ No newline at end of file
import torch
import torch.nn as nn
#Author: @meetdoshi
class ResidualBlock(nn.Module):
    '''
    Residual block for the 256-channel bottleneck: two reflection-padded
    3x3 convolutions with instance norm, added back onto the input.
    '''

    def __init__(self) -> None:
        super().__init__()
        channels = 256
        self.block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, 3),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, 3),
            nn.InstanceNorm2d(channels),
        )

    def forward(self, x):
        '''Identity shortcut: output = x + F(x).'''
        residual = self.block(x)
        return x + residual
class Generator(torch.nn.Module):
    '''
    ResNet-style image-to-image generator (Johnson et al.,
    https://arxiv.org/pdf/1603.08155.pdf): downsample by 4x, transform
    through residual blocks, then upsample back to the input resolution.
    '''

    def __init__(self) -> None:
        super().__init__()
        encoder = [
            nn.Conv2d(3, 64, 7, 1, 3),
            nn.InstanceNorm2d(64),
            nn.ReLU(True),
            nn.Conv2d(64, 128, 3, 2, 1),
            nn.InstanceNorm2d(128),
            nn.ReLU(True),
            nn.Conv2d(128, 256, 3, 2, 1),
            nn.InstanceNorm2d(256),
            nn.ReLU(True),
        ]
        # 5 residual blocks at the bottleneck (the paper uses 9 for 256x256
        # inputs; 4 more were present but commented out in the original).
        transformer = [ResidualBlock() for _ in range(5)]
        decoder = [
            nn.ConvTranspose2d(256, 128, 3, 2, 1, 1),
            nn.InstanceNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 3, 2, 1, 1),
            nn.InstanceNorm2d(64),
            nn.ReLU(True),
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, 3, 7),
            nn.Tanh(),  # outputs in [-1, 1]
        ]
        self.model = nn.Sequential(*encoder, *transformer, *decoder)

    def forward(self, x):
        '''Translate an image batch to the target domain.'''
        return self.model(x)
import torch
class ImageBuffer:
    '''
    History buffer of previously generated images (Shrivastava et al. 2017),
    as used by CycleGAN: discriminators are trained on a pool of past
    generator outputs rather than only the latest batch, which stabilises
    training.  Was an empty stub; now implemented.

    Once the pool is full, each incoming image either swaps with (and
    returns) a randomly stored image with probability 0.5, or is returned
    unchanged.
    '''

    def __init__(self, max_size: int = 50) -> None:
        '''
        @params max_size: pool capacity (50 follows the CycleGAN paper)
        @return None
        '''
        self.max_size = max_size
        self.buffer = []

    def push_and_pop(self, images):
        '''
        Insert a batch of images and return an equally sized batch drawn
        from the history.

        @params images: tensor of shape (N, C, H, W)
        @return tensor of shape (N, C, H, W) mixing new and stored images
        '''
        out = []
        for img in images:
            img = img.unsqueeze(0)
            if len(self.buffer) < self.max_size:
                # Pool not yet full: store the image and pass it through.
                self.buffer.append(img)
                out.append(img)
            elif torch.rand(1).item() > 0.5:
                # Swap with a randomly chosen stored image.
                idx = int(torch.randint(0, self.max_size, (1,)).item())
                out.append(self.buffer[idx].clone())
                self.buffer[idx] = img
            else:
                out.append(img)
        return torch.cat(out, dim=0)
\ No newline at end of file
import numpy as np
import torch
from logger import Logger
LOGGER = Logger().logger()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#Author: @meetdoshi
class Loss:
    '''
    Factory for the CycleGAN loss functions (Zhu et al. 2017).

    Each static factory returns a fresh loss module moved to the
    module-level `device`.
    '''

    @staticmethod
    def adversarial_G():
        '''
        Adversarial loss for a generator (BCE against "real" labels).
        @params none
        @return torch.nn.BCELoss on `device`
        '''
        return torch.nn.BCELoss().to(device)

    @staticmethod
    def adversarial_D():
        '''
        Adversarial loss for a discriminator.
        @params none
        @return torch.nn.BCELoss on `device`
        '''
        return torch.nn.BCELoss().to(device)

    @staticmethod
    def cycle_consistency_forward():
        '''
        Forward cycle loss ||F(G(x)) - x||_1.
        @params none
        @return torch.nn.L1Loss on `device`
        '''
        return torch.nn.L1Loss().to(device)

    @staticmethod
    def cycle_consistency_backward():
        '''
        Backward cycle loss ||G(F(y)) - y||_1.
        @params none
        @return torch.nn.L1Loss on `device`
        '''
        return torch.nn.L1Loss().to(device)

    @staticmethod
    def identity():
        '''
        Identity-mapping loss ||G(y) - y||_1.
        @params none
        @return torch.nn.L1Loss on `device`
        '''
        return torch.nn.L1Loss().to(device)

    @staticmethod
    def total_loss(adv_g=0.0, adv_d=0.0, cycle=0.0, lamda=10):
        '''
        Combine already-evaluated loss values into the CycleGAN objective.

        Fix: the previous version summed the loss *modules* themselves and
        called a non-existent `Loss.cycle_consistency()`, which raised at
        runtime.  It now combines scalar/tensor loss values.

        @params adv_g, adv_d: evaluated adversarial loss values;
                cycle: evaluated cycle-consistency loss value;
                lamda: cycle-loss weight (10 per the paper)
        @return adv_g + adv_d + lamda * cycle
        '''
        return adv_g + adv_d + lamda * cycle
\ No newline at end of file
import torch
#Author: @meetdoshi
class Optimizer():
    '''Placeholder namespace for optimizer helpers (not yet implemented).'''

    @staticmethod
    def gradient_descent():
        '''Stub — intended to build a gradient-descent optimizer; currently a no-op.'''
        return None
\ No newline at end of file
import numpy as np
import torch
from logger import Logger
LOGGER = Logger().logger()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
from torch.utils.data import Dataset, DataLoader
#Author: @meetdoshi
class LoadData(Dataset):
    '''
    Thin Dataset wrapper around an indexable collection with an optional
    per-sample transform.
    '''

    def __init__(self, data, transform=None):
        # data: any indexable sequence of samples.
        # transform: callable applied to each sample on access, or None.
        self.data = data
        self.transform = transform

    def __len__(self):
        '''Number of samples.'''
        return len(self.data)

    def __getitem__(self, idx):
        '''Fetch (and optionally transform) the sample at `idx`.'''
        if torch.is_tensor(idx):
            idx = idx.tolist()
        item = self.data[idx]
        return self.transform(item) if self.transform else item
\ No newline at end of file
import numpy as np
import torch
from logger import Logger
LOGGER = Logger().logger()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
from discriminator import Discriminator
from generator import Generator
from loss import Loss
from preprocess import LoadData
from tqdm import tqdm
from utils import initialize_weights
class Train():
    '''
    Wires up CycleGAN training: two generators (X->Y and Y->X), two
    per-domain discriminators, their optimizers/schedulers and losses.
    '''

    def __init__(self):
        '''
        @params none
        @return none
        '''
        # Generators for both translation directions, one discriminator per domain.
        self.gen_XY = Generator().to(device)
        self.gen_YX = Generator().to(device)
        self.dis_X = Discriminator().to(device)
        self.dis_Y = Discriminator().to(device)
        self.gen_XY.apply(initialize_weights)
        self.gen_YX.apply(initialize_weights)
        self.dis_X.apply(initialize_weights)
        self.dis_Y.apply(initialize_weights)
        # Adam with the CycleGAN paper's settings (lr=2e-4, beta1=0.5).
        self.gen_XY_optim = torch.optim.Adam(self.gen_XY.parameters(), lr=0.0002, betas=(0.5, 0.999))
        # Fix: gen_YX previously had no optimizer/scheduler, so its weights
        # would never have been updated.
        self.gen_YX_optim = torch.optim.Adam(self.gen_YX.parameters(), lr=0.0002, betas=(0.5, 0.999))
        self.dis_X_optim = torch.optim.Adam(self.dis_X.parameters(), lr=0.0002, betas=(0.5, 0.999))
        self.dis_Y_optim = torch.optim.Adam(self.dis_Y.parameters(), lr=0.0002, betas=(0.5, 0.999))
        self.gen_XY_scheduler = torch.optim.lr_scheduler.StepLR(self.gen_XY_optim, step_size=100, gamma=0.1)
        self.gen_YX_scheduler = torch.optim.lr_scheduler.StepLR(self.gen_YX_optim, step_size=100, gamma=0.1)
        self.dis_X_scheduler = torch.optim.lr_scheduler.StepLR(self.dis_X_optim, step_size=100, gamma=0.1)
        self.dis_Y_scheduler = torch.optim.lr_scheduler.StepLR(self.dis_Y_optim, step_size=100, gamma=0.1)
        # Fix: Loss exposes factory static methods, not attributes — the old
        # `Loss().cycle_loss` raised AttributeError.  The factories already
        # move the modules to `device`.
        self.cycle_loss = Loss.cycle_consistency_forward()
        self.identity_loss = Loss.identity()
        self.adversarial_loss = Loss.adversarial_G()
        # Per-epoch history: Generator, Discriminator, Cycle, Identity, Total.
        self.losses = {"G": [], "D": [], "C": [], "I": [], "T": []}

    def train(self):
        '''
        Run the training loop (skeleton — per-step logic not yet implemented).
        Fix: was declared without `self`, so `instance.train()` raised TypeError.

        @params none
        @return none
        '''
        EPOCHS = 200
        for epoch in range(EPOCHS):
            '''
            Steps:
            For Generator
            1. Load Images
            2. Reset Gradients
            3. Generate Fake Images
            4. Calculate Adversarial Loss
            5. Generate Original Images from Fake Images
            6. Calculate Cycle Consistency Loss
            7. Total Loss
            8. Backpropagate
            9. Update Weights
            For Discriminator
            10. Reset Gradients
            11. Calculate Adversarial Loss between labels for real images and generated images
            12. Total loss = Average of both losses
            13. Backpropagate
            14. Update Weights
            15. Do same for other discriminator
            16. Update Logs
            17. Save Image
            18. Update Learning Rate
            '''
\ No newline at end of file
import numpy as np
import torch
from logger import Logger
LOGGER = Logger().logger()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#Author: @meetdoshi
def initialize_weights(model):
    '''
    DCGAN-style weight initialisation, applied in place over all submodules:
    conv and linear weights ~ N(0, 0.02), batch-norm scale ~ N(1, 0.02),
    batch-norm and linear biases zeroed (conv biases are left untouched).

    @params model: torch.nn.Module to initialise
    @return None
    '''
    for layer in model.modules():
        if isinstance(layer, torch.nn.BatchNorm2d):
            torch.nn.init.normal_(layer.weight.data, 1.0, 0.02)
            torch.nn.init.constant_(layer.bias.data, 0)
        elif isinstance(layer, (torch.nn.Conv2d, torch.nn.Linear)):
            torch.nn.init.normal_(layer.weight.data, 0.0, 0.02)
            if isinstance(layer, torch.nn.Linear):
                torch.nn.init.constant_(layer.bias.data, 0)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment