Commit 031e698d authored by Yuxin Wu's avatar Yuxin Wu

windows support for Image2Image (fix #1413)

parent f1a8acf4
......@@ -4,6 +4,7 @@
# Author: Yuxin Wu
import argparse
import functools
import glob
import numpy as np
import os
......@@ -146,7 +147,7 @@ class Model(GANModelDesc):
return tf.train.AdamOptimizer(lr, beta1=0.5, epsilon=1e-3)
def split_input(dp):
def split_input(mode, dp):
"""
dp: the datapoint. first component is an RGB image of shape (s, 2s, 3).
:return: [input, output]
......@@ -156,7 +157,7 @@ def split_input(dp):
s = img.shape[0]
assert img.shape[1] == 2 * s
input, output = img[:, :s, :], img[:, s:, :]
if args.mode == 'BtoA':
if mode == 'BtoA':
input, output = output, input
if IN_CH == 1:
input = cv2.cvtColor(input, cv2.COLOR_RGB2GRAY)[:, :, np.newaxis]
......@@ -165,12 +166,12 @@ def split_input(dp):
return [input, output]
def get_data():
def get_data(args):
datadir = args.data
imgs = glob.glob(os.path.join(datadir, '*.jpg'))
ds = ImageFromFile(imgs, channel=3, shuffle=True)
ds = MapData(ds, split_input)
ds = MapData(ds, functools.partial(split_input, args.mode))
augs = [imgaug.Resize(286), imgaug.RandomCrop(256)]
ds = AugmentImageComponents(ds, augs, (0, 1))
ds = BatchData(ds, BATCH)
......@@ -217,7 +218,7 @@ if __name__ == '__main__':
else:
logger.auto_set_dir()
data = QueueInput(get_data())
data = QueueInput(get_data(args))
trainer = GANTrainer(data, Model(), get_num_gpu())
trainer.train_with_defaults(
......
......@@ -196,33 +196,35 @@ class AugmentImageComponents(MapData):
else:
self.augs = AugmentorList(augmentors)
self.ds = ds
self._exception_handler = ExceptionHandler(catch_exceptions)
self._copy = copy
self._index = index
self._coords_index = coords_index
super(AugmentImageComponents, self).__init__(ds, self._aug_mapper)
exception_handler = ExceptionHandler(catch_exceptions)
def reset_state(self):
self.ds.reset_state()
self.augs.reset_state()
def func(dp):
def _aug_mapper(self, dp):
dp = copy_mod.copy(dp) # always do a shallow copy, make sure the list is intact
copy_func = copy_mod.deepcopy if copy else lambda x: x # noqa
with exception_handler.catch():
major_image = index[0] # image to be used to get params. TODO better design?
copy_func = copy_mod.deepcopy if self._copy else lambda x: x # noqa
with self._exception_handler.catch():
major_image = self._index[0] # image to be used to get params. TODO better design?
im = copy_func(dp[major_image])
check_dtype(im)
tfms = self.augs.get_transform(im)
dp[major_image] = tfms.apply_image(im)
for idx in index[1:]:
for idx in self._index[1:]:
check_dtype(dp[idx])
dp[idx] = tfms.apply_image(copy_func(dp[idx]))
for idx in coords_index:
for idx in self._coords_index:
coords = copy_func(dp[idx])
validate_coords(coords)
dp[idx] = tfms.apply_coords(coords)
return dp
super(AugmentImageComponents, self).__init__(ds, func)
def reset_state(self):
self.ds.reset_state()
self.augs.reset_state()
try:
import cv2
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment