Commit cc8452e5 authored by Yuxin Wu

small fix

parent f1d15364
@@ -28,6 +28,7 @@ This model uses the whole training set instead of a train-val split.
"""
BATCH_SIZE = 128
+NUM_UNITS = None
class Model(ModelDesc):
def __init__(self, n):
@@ -143,7 +144,7 @@ def get_config():
ScheduledHyperParamSetter('learning_rate',
[(1, 0.1), (82, 0.01), (123, 0.001), (300, 0.0002)])
]),
-model=Model(n=18),
+model=Model(n=NUM_UNITS),
step_per_epoch=step_per_epoch,
max_epoch=400,
)
@@ -151,8 +152,12 @@ def get_config():
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
+parser.add_argument('-n', '--num_units',
+                    help='number of units in each stage',
+                    type=int, default=18)
parser.add_argument('--load', help='load model')
args = parser.parse_args()
+NUM_UNITS = args.num_units
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
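The change in the first file replaces the hard-coded `Model(n=18)` with a module-level `NUM_UNITS` that is filled in from the new `-n/--num_units` flag before the training config is built. A minimal sketch of that pattern is below; `Model` and `get_config` here are simplified stand-ins for the example's real classes, not the actual tensorpack code.

```python
# Sketch of the pattern introduced above: a module-level placeholder that is
# set from the command line before get_config() reads it.
# Model and get_config are simplified stand-ins, not the real example code.
import argparse

NUM_UNITS = None  # overwritten in __main__ before get_config() runs


class Model(object):
    def __init__(self, n):
        self.n = n  # number of residual units per stage


def get_config():
    # NUM_UNITS has already been set by the time this is called from __main__
    return {'model': Model(n=NUM_UNITS), 'max_epoch': 400}


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-n', '--num_units', type=int, default=18,
                        help='number of units in each stage')
    args = parser.parse_args()
    NUM_UNITS = args.num_units  # assignment happens at module level, so no `global` is needed
    config = get_config()
```

Keeping the module-level default at `None` makes it clear that the value is expected to come from the command line; the real default (18 units, i.e. ResNet-110 on CIFAR) lives in the argparse definition instead.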
@@ -3,7 +3,7 @@
# File: varmanip.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
-import six
+import six, os
import tensorflow as tf
from collections import defaultdict
import re
@@ -92,6 +92,8 @@ the same name".format(v.name))
def dump_chkpt_vars(model_path):
""" Dump all variables from a checkpoint to a dict"""
+if os.path.basename(model_path) == model_path:
+    model_path = os.path.join('.', model_path)  # avoid #4921
reader = tf.train.NewCheckpointReader(model_path)
var_names = reader.get_variable_to_shape_map().keys()
result = {}
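The `varmanip.py` hunk is truncated right after `result = {}`. Below is a hedged sketch of how the patched `dump_chkpt_vars` plausibly reads, assuming the part not shown in the diff simply fetches every tensor by name. The `# avoid #4921` comment presumably refers to the upstream TensorFlow issue where `NewCheckpointReader` fails on a checkpoint path that has no directory component, which is why a bare filename gets `./` prepended.

```python
import os
import tensorflow as tf  # TF 1.x API, matching the code above


def dump_chkpt_vars(model_path):
    """Dump all variables from a checkpoint to a dict (sketch of the patched function)."""
    if os.path.basename(model_path) == model_path:
        # A bare filename with no directory component can trip up
        # NewCheckpointReader, so anchor it to the current directory.
        model_path = os.path.join('.', model_path)  # avoid #4921
    reader = tf.train.NewCheckpointReader(model_path)
    var_names = reader.get_variable_to_shape_map().keys()
    result = {}
    for name in var_names:
        # Assumed continuation of the truncated diff: read each tensor by name.
        result[name] = reader.get_tensor(name)
    return result
```

The `import six, os` change above exists only to make `os` available for this path check.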