Commit c4288d29 authored by Meet Narendra

Speeding up training

parent f462bad0
...
@@ -17,8 +17,8 @@ class LoadData(Dataset):
         self.pair = pair
         self.type = type_data
         self.transform = transforms.Compose([
-            transforms.Resize(int(256 * 1.12), Image.BICUBIC),
-            transforms.RandomCrop(256),
+            transforms.Resize(int(64 * 1.12), Image.BICUBIC),
+            transforms.RandomCrop(64),
             transforms.RandomHorizontalFlip(),
             transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
...
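The only functional change in this file is the preprocessing resolution: images are now resized and cropped to 64x64 instead of 256x256, which is presumably where most of the commit's speed-up comes from, since every tensor flowing through the generators and discriminators shrinks by roughly 16x in pixel count. A minimal sketch of the resulting pipeline, assuming PIL RGB images as input; only the Compose block itself comes from the commit, and the example path in the comment is hypothetical:

from PIL import Image
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize(int(64 * 1.12), Image.BICUBIC),  # resize shorter side to ~71 px
    transforms.RandomCrop(64),                          # random 64x64 crop (was 256x256)
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # map [0, 1] to [-1, 1]
])

# Hypothetical usage on a single image from the dataset folder used below:
# img = Image.open("dataset/summer2winter_yosemite/trainA/0001.jpg").convert("RGB")
# x = transform(img)  # float tensor of shape [3, 64, 64], values in [-1, 1]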
 import numpy as np
+import itertools
 import torch
 from torch.utils.data import Dataset, DataLoader
 from torchvision.utils import save_image
@@ -15,13 +16,15 @@ LOGGER.info("Cuda status: "+str(device))
 torch.cuda.empty_cache()
 import os
 class Train():
-    def __init__(self,data="dataset/vangogh2photo",pair=False,epochs=200,batch_size=1):
+    def __init__(self,data="dataset/summer2winter_yosemite",pair=False,epochs=200,batch_size=1):
         '''
         @params
         @return
         '''
         if not os.path.exists("images"):
             os.mkdir("images")
+        if not os.path.exists("weights"):
+            os.mkdir("weights")
         self.epochs = epochs
         self.batch_size = batch_size
@@ -34,7 +37,7 @@ class Train():
         self.dis_X.apply(initialize_weights)
         self.dis_Y.apply(initialize_weights)
-        self.gen_XY_optim = torch.optim.Adam(self.gen_XY.parameters(), lr=0.0002, betas=(0.5, 0.999))
+        self.gen_XY_optim = torch.optim.Adam(itertools.chain(self.gen_XY.parameters(),self.gen_YX.parameters()), lr=0.0002, betas=(0.5, 0.999))
         self.dis_X_optim = torch.optim.Adam(self.dis_X.parameters(), lr=0.0002, betas=(0.5, 0.999))
         self.dis_Y_optim = torch.optim.Adam(self.dis_Y.parameters(), lr=0.0002, betas=(0.5, 0.999))
...
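The change in the training script folds both generators into a single Adam optimizer via itertools.chain, which matches the usual CycleGAN recipe: one combined generator loss is backpropagated through gen_XY and gen_YX together, so only one optimizer step is needed per iteration. Note the commit keeps the old gen_XY_optim attribute name even though it now updates both networks. A minimal, self-contained sketch of that pattern, assuming toy conv layers in place of the project's ResNet-based generators and showing only a cycle-consistency term; everything except the chained optimizer construction is illustrative:

import itertools
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy stand-ins for the two generators; the real models are ResNet-based.
gen_XY = nn.Conv2d(3, 3, kernel_size=3, padding=1)
gen_YX = nn.Conv2d(3, 3, kernel_size=3, padding=1)

# One optimizer over both generators' parameters, as in the commit.
gen_optim = torch.optim.Adam(
    itertools.chain(gen_XY.parameters(), gen_YX.parameters()),
    lr=0.0002, betas=(0.5, 0.999),
)

# Illustrative joint update: a combined generator loss backpropagates
# through both networks, then a single step updates them together.
real_X = torch.randn(1, 3, 64, 64)
real_Y = torch.randn(1, 3, 64, 64)
fake_Y = gen_XY(real_X)   # X -> Y
fake_X = gen_YX(real_Y)   # Y -> X
rec_X = gen_YX(fake_Y)    # X -> Y -> X
rec_Y = gen_XY(fake_X)    # Y -> X -> Y
cycle_loss = F.l1_loss(rec_X, real_X) + F.l1_loss(rec_Y, real_Y)

gen_optim.zero_grad()
cycle_loss.backward()
gen_optim.step()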