Commit cc8452e5 authored by Yuxin Wu

small fix

parent f1d15364
...@@ -28,6 +28,7 @@ This model uses the whole training set instead of a train-val split. ...@@ -28,6 +28,7 @@ This model uses the whole training set instead of a train-val split.
""" """
BATCH_SIZE = 128 BATCH_SIZE = 128
NUM_UNITS = None
class Model(ModelDesc): class Model(ModelDesc):
def __init__(self, n): def __init__(self, n):
...@@ -143,7 +144,7 @@ def get_config(): ...@@ -143,7 +144,7 @@ def get_config():
ScheduledHyperParamSetter('learning_rate', ScheduledHyperParamSetter('learning_rate',
[(1, 0.1), (82, 0.01), (123, 0.001), (300, 0.0002)]) [(1, 0.1), (82, 0.01), (123, 0.001), (300, 0.0002)])
]), ]),
model=Model(n=18), model=Model(n=NUM_UNITS),
step_per_epoch=step_per_epoch, step_per_epoch=step_per_epoch,
max_epoch=400, max_epoch=400,
) )
...@@ -151,8 +152,12 @@ def get_config(): ...@@ -151,8 +152,12 @@ def get_config():
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
parser.add_argument('-n', '--num_units',
help='number of units in each stage',
type=int, default=18)
parser.add_argument('--load', help='load model') parser.add_argument('--load', help='load model')
args = parser.parse_args() args = parser.parse_args()
NUM_UNITS = args.num_units
if args.gpu: if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
# File: varmanip.py # File: varmanip.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import six import six, os
import tensorflow as tf import tensorflow as tf
from collections import defaultdict from collections import defaultdict
import re import re
...@@ -92,6 +92,8 @@ the same name".format(v.name)) ...@@ -92,6 +92,8 @@ the same name".format(v.name))
def dump_chkpt_vars(model_path): def dump_chkpt_vars(model_path):
""" Dump all variables from a checkpoint to a dict""" """ Dump all variables from a checkpoint to a dict"""
if os.path.basename(model_path) == model_path:
model_path = os.path.join('.', model_path) # avoid #4921
reader = tf.train.NewCheckpointReader(model_path) reader = tf.train.NewCheckpointReader(model_path)
var_names = reader.get_variable_to_shape_map().keys() var_names = reader.get_variable_to_shape_map().keys()
result = {} result = {}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment