Commit f41e6326 authored by Yuxin Wu's avatar Yuxin Wu

remove colorization from im2im. it's better to put into a separate script

parent e2261920
...@@ -25,10 +25,6 @@ To train Image-to-Image translation model with image pairs: ...@@ -25,10 +25,6 @@ To train Image-to-Image translation model with image pairs:
# you can download some data from the original authors: # you can download some data from the original authors:
# https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/ # https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/
To train colorization:
./Image2Image.py --data /path/to/datadir --mode colorization --batch 4
# datadir should contain colored jpg images
Speed: Speed:
On GTX1080 with BATCH=1, the speed is about 9.3it/s (the original torch version is 9.5it/s) On GTX1080 with BATCH=1, the speed is about 9.3it/s (the original torch version is 9.5it/s)
...@@ -153,33 +149,15 @@ def split_input(img): ...@@ -153,33 +149,15 @@ def split_input(img):
return [input, output] return [input, output]
def colorization_input(img):
assert img.ndim == 3
if min(img.shape[:2]) < SHAPE:
return None # skip the image
# create gray + RGB pairs
gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)[:, :, np.newaxis]
return [gray, img]
def get_data(): def get_data():
datadir = args.data datadir = args.data
imgs = glob.glob(os.path.join(datadir, '*.jpg')) imgs = glob.glob(os.path.join(datadir, '*.jpg'))
ds = ImageFromFile(imgs, channel=3, shuffle=True) ds = ImageFromFile(imgs, channel=3, shuffle=True)
if args.mode == 'colorization': # Image-to-Image translation mode
# colorization mode ds = MapData(ds, lambda dp: split_input(dp[0]))
ds = MapData(ds, lambda dp: colorization_input(dp[0])) assert SHAPE < 286 # this is the parameter used in the paper
augs = [imgaug.RandomResize( augs = [imgaug.Resize(286), imgaug.RandomCrop(SHAPE)]
xrange=(0.75, 1.5), yrange=(0.75, 1.5),
minimum=(SHAPE, SHAPE),
aspect_ratio_thres=0),
imgaug.RandomCrop(SHAPE)]
else:
# Image-to-Image translation mode
ds = MapData(ds, lambda dp: split_input(dp[0]))
assert SHAPE < 286 # this is the parameter used in the paper
augs = [imgaug.Resize(286), imgaug.RandomCrop(SHAPE)]
ds = AugmentImageComponents(ds, augs, (0, 1)) ds = AugmentImageComponents(ds, augs, (0, 1))
ds = BatchData(ds, BATCH) ds = BatchData(ds, BATCH)
ds = PrefetchData(ds, 100, 1) ds = PrefetchData(ds, 100, 1)
...@@ -226,7 +204,7 @@ if __name__ == '__main__': ...@@ -226,7 +204,7 @@ if __name__ == '__main__':
parser.add_argument('--load', help='load model') parser.add_argument('--load', help='load model')
parser.add_argument('--sample', action='store_true', help='run sampling') parser.add_argument('--sample', action='store_true', help='run sampling')
parser.add_argument('--data', help='Image directory') parser.add_argument('--data', help='Image directory')
parser.add_argument('--mode', choices=['AtoB', 'BtoA', 'colorization'], default='AtoB') parser.add_argument('--mode', choices=['AtoB', 'BtoA'], default='AtoB')
parser.add_argument('-b', '--batch', type=int, default=1) parser.add_argument('-b', '--batch', type=int, default=1)
global args global args
args = parser.parse_args() args = parser.parse_args()
...@@ -236,10 +214,6 @@ if __name__ == '__main__': ...@@ -236,10 +214,6 @@ if __name__ == '__main__':
BATCH = args.batch BATCH = args.batch
if args.mode == 'colorization':
IN_CH = 1
OUT_CH = 3
if args.sample: if args.sample:
sample(args.data, args.load) sample(args.data, args.load)
else: else:
......
...@@ -170,7 +170,7 @@ class ModelFromMetaGraph(ModelDesc): ...@@ -170,7 +170,7 @@ class ModelFromMetaGraph(ModelDesc):
tf.train.import_meta_graph(filename) tf.train.import_meta_graph(filename)
all_coll = tf.get_default_graph().get_all_collection_keys() all_coll = tf.get_default_graph().get_all_collection_keys()
for k in [INPUT_VARS_KEY, tf.GraphKeys.TRAINABLE_VARIABLES, for k in [INPUT_VARS_KEY, tf.GraphKeys.TRAINABLE_VARIABLES,
tf.GraphKeys().VARIABLES]: tf.GraphKeys.GLOBAL_VARIABLES]:
assert k in all_coll, \ assert k in all_coll, \
"Collection {} not found in metagraph!".format(k) "Collection {} not found in metagraph!".format(k)
......
...@@ -152,7 +152,7 @@ class ParamRestore(SessionInit): ...@@ -152,7 +152,7 @@ class ParamRestore(SessionInit):
self.prms = {get_op_tensor_name(n)[1]: v for n, v in six.iteritems(param_dict)} self.prms = {get_op_tensor_name(n)[1]: v for n, v in six.iteritems(param_dict)}
def _init(self, sess): def _init(self, sess):
variables = tf.get_collection(tf.GraphKeys().VARIABLES) # TODO variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) # TODO
variable_names = set([get_savename_from_varname(k.name) for k in variables]) variable_names = set([get_savename_from_varname(k.name) for k in variables])
param_names = set(six.iterkeys(self.prms)) param_names = set(six.iterkeys(self.prms))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment