Commit c831ec0b authored by Meet Narendra's avatar Meet Narendra 💬

Added test.py

parent 061d52d4
......@@ -44,10 +44,6 @@ class Generator(torch.nn.Module):
ResidualBlock(),
ResidualBlock(),
ResidualBlock(),
#ResidualBlock(),
#ResidualBlock(),
#ResidualBlock(),
#ResidualBlock(),
nn.ConvTranspose2d(256, 128, 3, 2, 1, 1),
nn.InstanceNorm2d(128),
......
......@@ -9,7 +9,7 @@ import glob
import os
from PIL import Image
import random
#Reference only for Model and Dataset: https://github.com/Lornatang/CycleGAN-PyTorch/blob/master/cyclegan_pytorch/models.py
#Author: @meetdoshi
class LoadData(Dataset):
def __init__(self, data, pair=False, type_data="train"):
......
import numpy as np
import itertools
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import save_image
from logger import Logger
LOGGER = Logger().logger()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
from discriminator import Discriminator
from generator import Generator
from loss import Loss
from preprocess import LoadData
from tqdm import tqdm
from utils import initialize_weights
LOGGER.info("Cuda status: "+str(device))
torch.cuda.empty_cache()
import os
class Test():
def __init__(self,data="test_images/"):
self.dataset = LoadData(data=data)
self.dataloader = DataLoader(self.dataset, batch_size=1, shuffle=True, num_workers=4)
def test(self):
self.gen_XY = Generator().to(device)
self.gen_YX = Generator().to(device)
self.gen_XY.load_state_dict(torch.load("weights/gen_XY.pth"))
self.gen_YX.load_state_dict(torch.load("weights/gen_YX.pth"))
self.gen_XY.eval()
self.gen_YX.eval()
for i, data in tqdm(enumerate(self.dataloader),total=len(self.dataloader)):
X = data['X'].to(device)
Y = data['Y'].to(device)
fake_Y = self.gen_XY(X)
fake_X = self.gen_YX(Y)
save_image(fake_Y, "test_images/fake_Y_{}.png".format(i))
save_image(fake_X, "test_images/fake_X_{}.png".format(i))
save_image(X, "test_images/X_{}.png".format(i))
save_image(Y, "test_images/Y_{}.png".format(i))
if __name__ == "__main__":
test = Test()
test.test()
\ No newline at end of file
......@@ -15,8 +15,9 @@ from utils import initialize_weights
LOGGER.info("Cuda status: "+str(device))
torch.cuda.empty_cache()
import os
#Reference only for Model and Dataset: https://github.com/Lornatang/CycleGAN-PyTorch/blob/master/cyclegan_pytorch/models.py
class Train():
def __init__(self,data="dataset/summer2winter_yosemite",pair=False,epochs=200,batch_size=1):
def __init__(self,data="dataset/vangogh2photo",pair=False,epochs=200,batch_size=1):
'''
@params
@return
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment