Commit 31236d84 authored by Yuxin Wu's avatar Yuxin Wu

add colorization mode to im2im

parent 743dc730
...@@ -19,16 +19,20 @@ import tensorpack.tfutils.symbolic_functions as symbf ...@@ -19,16 +19,20 @@ import tensorpack.tfutils.symbolic_functions as symbf
from GAN import GANTrainer, GANModelDesc from GAN import GANTrainer, GANModelDesc
""" """
To train: To train Image-to-Image translation model with image pairs:
./Image2Image.py --data /path/to/datadir --mode {AtoB,BtoA} ./Image2Image.py --data /path/to/datadir --mode {AtoB,BtoA}
# datadir should contain jpg images of shape 2s x s, formed by A and B # datadir should contain jpg images of shape 2s x s, formed by A and B
# you can download some data from the original authors: # you can download some data from the original authors:
# https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/ # https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/
# training visualization will appear in tensorboard
To train colorization:
./Image2Image.py --data /path/to/datadir --mode colorization --batch 4
# datadir should contain colored jpg images
Speed: Speed:
On GTX1080 with BATCH=1, the speed is about 9.3it/s (the original torch version is 9.5it/s) On GTX1080 with BATCH=1, the speed is about 9.3it/s (the original torch version is 9.5it/s)
Training visualization will appear in tensorboard.
To visualize on test set: To visualize on test set:
./Image2Image.py --sample --data /path/to/test/datadir --mode {AtoB,BtoA} --load model ./Image2Image.py --sample --data /path/to/test/datadir --mode {AtoB,BtoA} --load model
...@@ -133,10 +137,12 @@ class Model(GANModelDesc): ...@@ -133,10 +137,12 @@ class Model(GANModelDesc):
def split_input(img): def split_input(img):
""" """
img: an image with shape (s, 2s, 3) img: an RGB image of shape (s, 2s, 3).
:return: [input, output] :return: [input, output]
""" """
# split the image into left + right pairs
s = img.shape[0] s = img.shape[0]
assert img.shape[1] == 2 * s
input, output = img[:, :s, :], img[:, s:, :] input, output = img[:, :s, :], img[:, s:, :]
if args.mode == 'BtoA': if args.mode == 'BtoA':
input, output = output, input input, output = output, input
...@@ -147,13 +153,32 @@ def split_input(img): ...@@ -147,13 +153,32 @@ def split_input(img):
return [input, output] return [input, output]
def colorization_input(img):
    """
    Build a [grayscale, color] pair from a single image.

    img: a 3-dimensional image array; it is converted with
        cv2.COLOR_RGB2GRAY, so it is presumably in RGB channel order
        (verify against the dataflow that feeds this function).
    :return: [gray, img], where gray is the grayscale version of img with
        a trailing channel axis appended, and img is returned unchanged.
    """
    assert img.ndim == 3
    # The grayscale conversion is indexed as 2-D below, so re-add a
    # channel axis to keep the result 3-dimensional like the color image.
    flat_gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    gray = flat_gray[:, :, np.newaxis]
    pair = [gray, img]
    return pair
def get_data(): def get_data():
datadir = args.data datadir = args.data
# assume each image is 512x256 split to left and right # assume each image is 512x256 split to left and right
imgs = glob.glob(os.path.join(datadir, '*.jpg')) imgs = glob.glob(os.path.join(datadir, '*.jpg'))
ds = ImageFromFile(imgs, channel=3, shuffle=True) ds = ImageFromFile(imgs, channel=3, shuffle=True)
ds = MapData(ds, lambda dp: split_input(dp[0]))
augs = [imgaug.Resize(286), imgaug.RandomCrop(256)] if args.mode == 'colorization':
# colorization mode
ds = MapData(ds, lambda dp: colorization_input(dp[0]))
augs = [imgaug.RandomResize(
xrange=(0.75, 1.5), yrange=(0.75, 1.5),
minimum=(SHAPE, SHAPE),
aspect_ratio_thres=0),
imgaug.RandomCrop(SHAPE)]
else:
# Image-to-Image translation mode
ds = MapData(ds, lambda dp: split_input(dp[0]))
assert SHAPE < 286 # this is the parameter used in the paper
augs = [imgaug.Resize(286), imgaug.RandomCrop(SHAPE)]
ds = AugmentImageComponents(ds, augs, (0, 1)) ds = AugmentImageComponents(ds, augs, (0, 1))
ds = BatchData(ds, BATCH) ds = BatchData(ds, BATCH)
ds = PrefetchData(ds, 100, 1) ds = PrefetchData(ds, 100, 1)
...@@ -168,7 +193,7 @@ def get_config(): ...@@ -168,7 +193,7 @@ def get_config():
dataflow=dataset, dataflow=dataset,
optimizer=tf.train.AdamOptimizer(lr, beta1=0.5, epsilon=1e-3), optimizer=tf.train.AdamOptimizer(lr, beta1=0.5, epsilon=1e-3),
callbacks=[ callbacks=[
PeriodicCallback(ModelSaver(), 3), PeriodicTrigger(ModelSaver(), every_k_epochs=3),
ScheduledHyperParamSetter('learning_rate', [(200, 1e-4)]) ScheduledHyperParamSetter('learning_rate', [(200, 1e-4)])
], ],
model=Model(), model=Model(),
...@@ -200,12 +225,20 @@ if __name__ == '__main__': ...@@ -200,12 +225,20 @@ if __name__ == '__main__':
parser.add_argument('--load', help='load model') parser.add_argument('--load', help='load model')
parser.add_argument('--sample', action='store_true', help='run sampling') parser.add_argument('--sample', action='store_true', help='run sampling')
parser.add_argument('--data', help='Image directory') parser.add_argument('--data', help='Image directory')
parser.add_argument('--mode', choices=['AtoB', 'BtoA'], default='AtoB') parser.add_argument('--mode', choices=['AtoB', 'BtoA', 'colorization'], default='AtoB')
parser.add_argument('-b', '--batch', type=int, default=1)
global args global args
args = parser.parse_args() args = parser.parse_args()
if args.gpu: if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
assert args.data assert args.data
BATCH = args.batch
if args.mode == 'colorization':
IN_CH = 1
OUT_CH = 3
if args.sample: if args.sample:
sample(args.data, args.load) sample(args.data, args.load)
else: else:
......
...@@ -119,13 +119,13 @@ class RandomResize(ImageAugmentor): ...@@ -119,13 +119,13 @@ class RandomResize(ImageAugmentor):
sy = sx sy = sx
else: else:
sy = self._rand_range(*self.yrange) sy = self._rand_range(*self.yrange)
destX = int(max(sx * img.shape[1], self.minimum[0])) destX = max(sx * img.shape[1], self.minimum[0])
destY = int(max(sy * img.shape[0], self.minimum[1])) destY = max(sy * img.shape[0], self.minimum[1])
oldr = img.shape[1] * 1.0 / img.shape[0] oldr = img.shape[1] * 1.0 / img.shape[0]
newr = destX * 1.0 / destY newr = destX * 1.0 / destY
diff = abs(newr - oldr) / oldr diff = abs(newr - oldr) / oldr
if diff <= self.aspect_ratio_thres + 1e-7: if diff <= self.aspect_ratio_thres + 1e-5:
return (destX, destY) return (int(destX), int(destY))
cnt += 1 cnt += 1
if cnt > 50: if cnt > 50:
logger.warn("RandomResize failed to augment an image") logger.warn("RandomResize failed to augment an image")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment