Commit f1fdb42e authored by Yuxin Wu's avatar Yuxin Wu

smart --load, to handle SaverV2 format

parent 412acd12
......@@ -198,10 +198,10 @@ if __name__ == '__main__':
args = parser.parse_args()
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
assert args.data
if args.sample:
sample(args.data, args.load)
else:
assert args.data
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
......
......@@ -58,10 +58,21 @@ class SaverRestore(SessionInit):
:param prefix: add a `prefix/` for every variable in this checkpoint
"""
if os.path.basename(model_path) == model_path:
model_path = os.path.join('.', model_path) # avoid #4921
model_path = os.path.join('.', model_path) # avoid #4921 and #6142
if os.path.basename(model_path) == 'checkpoint':
model_path = tf.train.latest_checkpoint(os.path.dirname(model_path))
# to be consistent with either v1 or v2
# fix paths if provided a wrong one
new_path = model_path
if '00000-of-00001' in model_path:
new_path = model_path.split('.data')[0]
elif model_path.endswith('.index'):
new_path = model_path.split('.index')[0]
if new_path != model_path:
logger.warn(
"[SaverRestore] {} is corrected to {} when restoring the model.".format(model_path, new_path))
model_path = new_path
assert os.path.isfile(model_path) or os.path.isfile(model_path + '.index'), model_path
self.set_path(model_path)
self.prefix = prefix
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment