Commit 2d856f94 authored by Yuxin Wu's avatar Yuxin Wu

Print something for #218

parent 7ea758cd
...@@ -87,12 +87,13 @@ def get_config(fake=False, data_format='NCHW'): ...@@ -87,12 +87,13 @@ def get_config(fake=False, data_format='NCHW'):
nr_tower = max(get_nr_gpu(), 1) nr_tower = max(get_nr_gpu(), 1)
global BATCH_SIZE global BATCH_SIZE
BATCH_SIZE = TOTAL_BATCH_SIZE // nr_tower BATCH_SIZE = TOTAL_BATCH_SIZE // nr_tower
logger.info("Running on {} towers. Batch size per tower: {}".format(nr_tower, BATCH_SIZE))
if fake: if fake:
logger.info("For benchmark, batch size is fixed to 64 per tower.")
dataset_train = dataset_val = FakeData( dataset_train = dataset_val = FakeData(
[[64, 224, 224, 3], [64]], 1000, random=False, dtype='uint8') [[64, 224, 224, 3], [64]], 1000, random=False, dtype='uint8')
else: else:
logger.info("Running on {} towers. Batch size per tower: {}".format(nr_tower, BATCH_SIZE))
dataset_train = get_data('train') dataset_train = get_data('train')
dataset_val = get_data('val') dataset_val = get_data('val')
......
...@@ -45,6 +45,7 @@ class ModelSaver(Callback): ...@@ -45,6 +45,7 @@ class ModelSaver(Callback):
self.var_collections = var_collections self.var_collections = var_collections
if checkpoint_dir is None: if checkpoint_dir is None:
checkpoint_dir = logger.LOG_DIR checkpoint_dir = logger.LOG_DIR
assert checkpoint_dir is not None
assert tf.gfile.IsDirectory(checkpoint_dir), checkpoint_dir assert tf.gfile.IsDirectory(checkpoint_dir), checkpoint_dir
self.checkpoint_dir = checkpoint_dir self.checkpoint_dir = checkpoint_dir
...@@ -150,7 +151,6 @@ class MinSaver(Callback): ...@@ -150,7 +151,6 @@ class MinSaver(Callback):
files_to_copy = glob.glob(path + '*') files_to_copy = glob.glob(path + '*')
for file_to_copy in files_to_copy: for file_to_copy in files_to_copy:
shutil.copy(file_to_copy, file_to_copy.replace(path, newname)) shutil.copy(file_to_copy, file_to_copy.replace(path, newname))
#shutil.copy(path, newname)
logger.info("Model with {} '{}' saved.".format( logger.info("Model with {} '{}' saved.".format(
'maximum' if self.reverse else 'minimum', self.monitor_stat)) 'maximum' if self.reverse else 'minimum', self.monitor_stat))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment