Commit 031e698d authored by Yuxin Wu's avatar Yuxin Wu

windows support for Image2Image (fix #1413)

parent f1a8acf4
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
# Author: Yuxin Wu # Author: Yuxin Wu
import argparse import argparse
import functools
import glob import glob
import numpy as np import numpy as np
import os import os
...@@ -146,7 +147,7 @@ class Model(GANModelDesc): ...@@ -146,7 +147,7 @@ class Model(GANModelDesc):
return tf.train.AdamOptimizer(lr, beta1=0.5, epsilon=1e-3) return tf.train.AdamOptimizer(lr, beta1=0.5, epsilon=1e-3)
def split_input(dp): def split_input(mode, dp):
""" """
dp: the datapoint. first component is an RGB image of shape (s, 2s, 3). dp: the datapoint. first component is an RGB image of shape (s, 2s, 3).
:return: [input, output] :return: [input, output]
...@@ -156,7 +157,7 @@ def split_input(dp): ...@@ -156,7 +157,7 @@ def split_input(dp):
s = img.shape[0] s = img.shape[0]
assert img.shape[1] == 2 * s assert img.shape[1] == 2 * s
input, output = img[:, :s, :], img[:, s:, :] input, output = img[:, :s, :], img[:, s:, :]
if args.mode == 'BtoA': if mode == 'BtoA':
input, output = output, input input, output = output, input
if IN_CH == 1: if IN_CH == 1:
input = cv2.cvtColor(input, cv2.COLOR_RGB2GRAY)[:, :, np.newaxis] input = cv2.cvtColor(input, cv2.COLOR_RGB2GRAY)[:, :, np.newaxis]
...@@ -165,12 +166,12 @@ def split_input(dp): ...@@ -165,12 +166,12 @@ def split_input(dp):
return [input, output] return [input, output]
def get_data(): def get_data(args):
datadir = args.data datadir = args.data
imgs = glob.glob(os.path.join(datadir, '*.jpg')) imgs = glob.glob(os.path.join(datadir, '*.jpg'))
ds = ImageFromFile(imgs, channel=3, shuffle=True) ds = ImageFromFile(imgs, channel=3, shuffle=True)
ds = MapData(ds, split_input) ds = MapData(ds, functools.partial(split_input, args.mode))
augs = [imgaug.Resize(286), imgaug.RandomCrop(256)] augs = [imgaug.Resize(286), imgaug.RandomCrop(256)]
ds = AugmentImageComponents(ds, augs, (0, 1)) ds = AugmentImageComponents(ds, augs, (0, 1))
ds = BatchData(ds, BATCH) ds = BatchData(ds, BATCH)
...@@ -217,7 +218,7 @@ if __name__ == '__main__': ...@@ -217,7 +218,7 @@ if __name__ == '__main__':
else: else:
logger.auto_set_dir() logger.auto_set_dir()
data = QueueInput(get_data()) data = QueueInput(get_data(args))
trainer = GANTrainer(data, Model(), get_num_gpu()) trainer = GANTrainer(data, Model(), get_num_gpu())
trainer.train_with_defaults( trainer.train_with_defaults(
......
...@@ -196,33 +196,35 @@ class AugmentImageComponents(MapData): ...@@ -196,33 +196,35 @@ class AugmentImageComponents(MapData):
else: else:
self.augs = AugmentorList(augmentors) self.augs = AugmentorList(augmentors)
self.ds = ds self.ds = ds
self._exception_handler = ExceptionHandler(catch_exceptions)
self._copy = copy
self._index = index
self._coords_index = coords_index
exception_handler = ExceptionHandler(catch_exceptions) super(AugmentImageComponents, self).__init__(ds, self._aug_mapper)
def func(dp):
dp = copy_mod.copy(dp) # always do a shallow copy, make sure the list is intact
copy_func = copy_mod.deepcopy if copy else lambda x: x # noqa
with exception_handler.catch():
major_image = index[0] # image to be used to get params. TODO better design?
im = copy_func(dp[major_image])
check_dtype(im)
tfms = self.augs.get_transform(im)
dp[major_image] = tfms.apply_image(im)
for idx in index[1:]:
check_dtype(dp[idx])
dp[idx] = tfms.apply_image(copy_func(dp[idx]))
for idx in coords_index:
coords = copy_func(dp[idx])
validate_coords(coords)
dp[idx] = tfms.apply_coords(coords)
return dp
super(AugmentImageComponents, self).__init__(ds, func)
def reset_state(self): def reset_state(self):
self.ds.reset_state() self.ds.reset_state()
self.augs.reset_state() self.augs.reset_state()
def _aug_mapper(self, dp):
dp = copy_mod.copy(dp) # always do a shallow copy, make sure the list is intact
copy_func = copy_mod.deepcopy if self._copy else lambda x: x # noqa
with self._exception_handler.catch():
major_image = self._index[0] # image to be used to get params. TODO better design?
im = copy_func(dp[major_image])
check_dtype(im)
tfms = self.augs.get_transform(im)
dp[major_image] = tfms.apply_image(im)
for idx in self._index[1:]:
check_dtype(dp[idx])
dp[idx] = tfms.apply_image(copy_func(dp[idx]))
for idx in self._coords_index:
coords = copy_func(dp[idx])
validate_coords(coords)
dp[idx] = tfms.apply_coords(coords)
return dp
try: try:
import cv2 import cv2
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment