Commit c4288d29 authored by Meet Narendra's avatar Meet Narendra 💬

Speeding up training

parent f462bad0
......@@ -17,8 +17,8 @@ class LoadData(Dataset):
self.pair = pair
self.type = type_data
self.transform = transforms.Compose([
transforms.Resize(int(256 * 1.12), Image.BICUBIC),
transforms.RandomCrop(256),
transforms.Resize(int(64 * 1.12), Image.BICUBIC),
transforms.RandomCrop(64),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
......
import numpy as np
import itertools
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import save_image
......@@ -15,13 +16,15 @@ LOGGER.info("Cuda status: "+str(device))
torch.cuda.empty_cache()
import os
class Train():
def __init__(self,data="dataset/vangogh2photo",pair=False,epochs=200,batch_size=1):
def __init__(self,data="dataset/summer2winter_yosemite",pair=False,epochs=200,batch_size=1):
'''
@params
@return
'''
if not os.path.exists("images"):
os.mkdir("images")
if not os.path.exists("weights"):
os.mkdir("weights")
self.epochs = epochs
self.batch_size = batch_size
......@@ -34,7 +37,7 @@ class Train():
self.dis_X.apply(initialize_weights)
self.dis_Y.apply(initialize_weights)
self.gen_XY_optim = torch.optim.Adam(self.gen_XY.parameters(), lr=0.0002, betas=(0.5, 0.999))
self.gen_XY_optim = torch.optim.Adam(itertools.chain(self.gen_XY.parameters(),self.gen_YX.parameters()), lr=0.0002, betas=(0.5, 0.999))
self.dis_X_optim = torch.optim.Adam(self.dis_X.parameters(), lr=0.0002, betas=(0.5, 0.999))
self.dis_Y_optim = torch.optim.Adam(self.dis_Y.parameters(), lr=0.0002, betas=(0.5, 0.999))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment