Commit 31236d84 authored by Yuxin Wu's avatar Yuxin Wu

add colorization mode to im2im

parent 743dc730
......@@ -19,16 +19,20 @@ import tensorpack.tfutils.symbolic_functions as symbf
from GAN import GANTrainer, GANModelDesc
"""
To train:
To train Image-to-Image translation model with image pairs:
./Image2Image.py --data /path/to/datadir --mode {AtoB,BtoA}
# datadir should contain jpg images of shpae 2s x s, formed by A and B
# you can download some data from the original authors:
# https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/
# training visualization will appear be in tensorboard
To train colorization:
./Image2Image.py --data /path/to/datadir --mode colorization --batch 4
# datadir should contain colored jpg images
Speed:
On GTX1080 with BATCH=1, the speed is about 9.3it/s (the original torch version is 9.5it/s)
Training visualization will appear be in tensorboard.
To visualize on test set:
./Image2Image.py --sample --data /path/to/test/datadir --mode {AtoB,BtoA} --load model
......@@ -133,10 +137,12 @@ class Model(GANModelDesc):
def split_input(img):
"""
img: an image with shape (s, 2s, 3)
img: an RGB image of shape (s, 2s, 3).
:return: [input, output]
"""
# split the image into left + right pairs
s = img.shape[0]
assert img.shape[1] == 2 * s
input, output = img[:, :s, :], img[:, s:, :]
if args.mode == 'BtoA':
input, output = output, input
......@@ -147,13 +153,32 @@ def split_input(img):
return [input, output]
def colorization_input(img):
assert img.ndim == 3
# create gray + RGB pairs
gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)[:, :, np.newaxis]
return [gray, img]
def get_data():
datadir = args.data
# assume each image is 512x256 split to left and right
imgs = glob.glob(os.path.join(datadir, '*.jpg'))
ds = ImageFromFile(imgs, channel=3, shuffle=True)
ds = MapData(ds, lambda dp: split_input(dp[0]))
augs = [imgaug.Resize(286), imgaug.RandomCrop(256)]
if args.mode == 'colorization':
# colorization mode
ds = MapData(ds, lambda dp: colorization_input(dp[0]))
augs = [imgaug.RandomResize(
xrange=(0.75, 1.5), yrange=(0.75, 1.5),
minimum=(SHAPE, SHAPE),
aspect_ratio_thres=0),
imgaug.RandomCrop(SHAPE)]
else:
# Image-to-Image translation mode
ds = MapData(ds, lambda dp: split_input(dp[0]))
assert SHAPE < 286 # this is the parameter used in the paper
augs = [imgaug.Resize(286), imgaug.RandomCrop(SHAPE)]
ds = AugmentImageComponents(ds, augs, (0, 1))
ds = BatchData(ds, BATCH)
ds = PrefetchData(ds, 100, 1)
......@@ -168,7 +193,7 @@ def get_config():
dataflow=dataset,
optimizer=tf.train.AdamOptimizer(lr, beta1=0.5, epsilon=1e-3),
callbacks=[
PeriodicCallback(ModelSaver(), 3),
PeriodicTrigger(ModelSaver(), every_k_epochs=3),
ScheduledHyperParamSetter('learning_rate', [(200, 1e-4)])
],
model=Model(),
......@@ -200,12 +225,20 @@ if __name__ == '__main__':
parser.add_argument('--load', help='load model')
parser.add_argument('--sample', action='store_true', help='run sampling')
parser.add_argument('--data', help='Image directory')
parser.add_argument('--mode', choices=['AtoB', 'BtoA'], default='AtoB')
parser.add_argument('--mode', choices=['AtoB', 'BtoA', 'colorization'], default='AtoB')
parser.add_argument('-b', '--batch', type=int, default=1)
global args
args = parser.parse_args()
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
assert args.data
BATCH = args.batch
if args.mode == 'colorization':
IN_CH = 1
OUT_CH = 3
if args.sample:
sample(args.data, args.load)
else:
......
......@@ -119,13 +119,13 @@ class RandomResize(ImageAugmentor):
sy = sx
else:
sy = self._rand_range(*self.yrange)
destX = int(max(sx * img.shape[1], self.minimum[0]))
destY = int(max(sy * img.shape[0], self.minimum[1]))
destX = max(sx * img.shape[1], self.minimum[0])
destY = max(sy * img.shape[0], self.minimum[1])
oldr = img.shape[1] * 1.0 / img.shape[0]
newr = destX * 1.0 / destY
diff = abs(newr - oldr) / oldr
if diff <= self.aspect_ratio_thres + 1e-7:
return (destX, destY)
if diff <= self.aspect_ratio_thres + 1e-5:
return (int(destX), int(destY))
cnt += 1
if cnt > 50:
logger.warn("RandomResize failed to augment an image")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment