Commit 2d856f94 authored by Yuxin Wu

Print something for #218

parent 7ea758cd
@@ -87,12 +87,13 @@ def get_config(fake=False, data_format='NCHW'):
     nr_tower = max(get_nr_gpu(), 1)
     global BATCH_SIZE
     BATCH_SIZE = TOTAL_BATCH_SIZE // nr_tower
-    logger.info("Running on {} towers. Batch size per tower: {}".format(nr_tower, BATCH_SIZE))
     if fake:
+        logger.info("For benchmark, batch size is fixed to 64 per tower.")
         dataset_train = dataset_val = FakeData(
             [[64, 224, 224, 3], [64]], 1000, random=False, dtype='uint8')
     else:
+        logger.info("Running on {} towers. Batch size per tower: {}".format(nr_tower, BATCH_SIZE))
         dataset_train = get_data('train')
         dataset_val = get_data('val')
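For context, a minimal stand-alone sketch of the batch-size bookkeeping in this hunk, with tensorpack's get_nr_gpu(), FakeData and logger replaced by plain Python stand-ins; the split_batch name and the 256 total are illustrative, not taken from the diff:

TOTAL_BATCH_SIZE = 256   # hypothetical total batch size across all GPUs

def split_batch(nr_gpu, fake=False):
    # mirror the hunk: never let the tower count drop to 0 on a CPU-only machine
    nr_tower = max(nr_gpu, 1)
    batch_per_tower = TOTAL_BATCH_SIZE // nr_tower
    if fake:
        # benchmark mode uses fixed-shape fake data, so the real split is not reported
        print("For benchmark, batch size is fixed to 64 per tower.")
        return 64
    print("Running on {} towers. Batch size per tower: {}".format(nr_tower, batch_per_tower))
    return batch_per_tower

split_batch(4)             # -> 64 per tower for a 256 total on 4 GPUs
split_batch(4, fake=True)  # -> only the fixed benchmark message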
@@ -45,6 +45,7 @@ class ModelSaver(Callback):
         self.var_collections = var_collections
         if checkpoint_dir is None:
             checkpoint_dir = logger.LOG_DIR
+        assert checkpoint_dir is not None
         assert tf.gfile.IsDirectory(checkpoint_dir), checkpoint_dir
         self.checkpoint_dir = checkpoint_dir
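The intent of the ModelSaver change, sketched below with os.path instead of tf.gfile, is to fail with a readable error when neither checkpoint_dir nor a logger directory is available, rather than handing None to the directory check. The resolve_checkpoint_dir name and the assert message are illustrative, not from the diff:

import os

LOG_DIR = None  # stand-in for logger.LOG_DIR; set only when a log directory is configured

def resolve_checkpoint_dir(checkpoint_dir=None):
    if checkpoint_dir is None:
        checkpoint_dir = LOG_DIR
    # fail early and clearly if no directory was ever configured
    assert checkpoint_dir is not None, "no checkpoint_dir given and no logger dir set"
    # keep the existing directory check, printing the offending path on failure
    assert os.path.isdir(checkpoint_dir), checkpoint_dir
    return checkpoint_dir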
@@ -150,7 +151,6 @@ class MinSaver(Callback):
         files_to_copy = glob.glob(path + '*')
         for file_to_copy in files_to_copy:
             shutil.copy(file_to_copy, file_to_copy.replace(path, newname))
-        #shutil.copy(path, newname)
         logger.info("Model with {} '{}' saved.".format(
             'maximum' if self.reverse else 'minimum', self.monitor_stat))
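The MinSaver hunk copies every file that shares the checkpoint prefix, since a TensorFlow checkpoint is a group of files (.index, .data-*, and often .meta) rather than a single file. A self-contained sketch of that loop; copy_checkpoint and the example paths are made up:

import glob
import shutil

def copy_checkpoint(path, newname):
    # 'path' is a checkpoint prefix; copy every file that starts with it,
    # swapping the prefix for the new name in each copy's filename
    for file_to_copy in glob.glob(path + '*'):
        shutil.copy(file_to_copy, file_to_copy.replace(path, newname))

# copy_checkpoint('train_log/model-5000', 'train_log/min-val-error') would copy
# model-5000.index -> min-val-error.index, model-5000.data-00000-of-00001 -> ..., etc.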