Commit 4adbaa94 authored by Yuxin Wu

improve docs

parent 43f7ca75
......@@ -80,7 +80,7 @@ def get_config(model, fake=False):
EstimatedTimeLeft(),
ScheduledHyperParamSetter(
'learning_rate', [(30, BASE_LR * 1e-1), (60, BASE_LR * 1e-2),
(85, BASE_LR * 1e-3), (95, BASE_LR * 1e-4), (105, BASE_LR * 1e-5)]),
(90, BASE_LR * 1e-3), (100, BASE_LR * 1e-4)]),
]
if BASE_LR > 0.1:
callbacks.append(
......@@ -102,7 +102,7 @@ def get_config(model, fake=False):
dataflow=dataset_train,
callbacks=callbacks,
steps_per_epoch=100 if args.fake else 1280000 // args.batch,
max_epoch=110,
max_epoch=105,
)
......@@ -115,7 +115,7 @@ if __name__ == '__main__':
parser.add_argument('--data_format', help='specify NCHW or NHWC',
type=str, default='NCHW')
parser.add_argument('-d', '--depth', help='resnet depth',
type=int, default=18, choices=[18, 34, 50, 101, 152])
type=int, default=50, choices=[18, 34, 50, 101, 152])
parser.add_argument('--eval', action='store_true')
parser.add_argument('--batch', default=256, type=int,
help='total batch size. 32 per GPU gives best accuracy, higher values should be similarly good')
......
......@@ -78,10 +78,8 @@ class InferenceRunnerBase(Callback):
try:
self._size = input.size()
logger.info("InferenceRunner will eval {} iterations".format(input.size()))
except NotImplementedError:
self._size = 0
logger.warn("InferenceRunner got an input with unknown size! It will iterate until OutOfRangeError!")
self._hooks = []
......@@ -95,6 +93,10 @@ class InferenceRunnerBase(Callback):
def _before_train(self):
self._hooked_sess = HookedSession(self.trainer.sess, self._hooks)
self._input_callbacks.before_train()
if self._size > 0:
logger.info("InferenceRunner will eval {} iterations".format(self._size))
else:
logger.warn("InferenceRunner got an input with unknown size! It will iterate until OutOfRangeError!")
def _after_train(self):
self._input_callbacks.after_train()
......
......@@ -94,4 +94,5 @@ class EstimatedTimeLeft(Callback):
average_epoch_time = np.mean(self._times)
time_left = (self._max_epoch - self.epoch_num) * average_epoch_time
if time_left > 0:
logger.info("Estimated Time Left: " + humanize_time_delta(time_left))
......@@ -114,7 +114,7 @@ def _guess_dir_structure(dir):
else:
dir_structure = 'original'
logger.info(
"Assuming directory {} has {} structure.".format(
"[ILSVRC12] Assuming directory {} has '{}' structure.".format(
dir, dir_structure))
return dir_structure
......
......@@ -12,7 +12,6 @@ from six.moves import zip, range
from ..utils import logger
from ..tfutils.tower import TowerContext
from ..tfutils.common import get_tf_version_number
from ..tfutils.gradproc import ScaleGradient
from .utils import (
......@@ -39,16 +38,10 @@ class DataParallelBuilder(GraphBuilder):
towers(list[int]): list of GPU ids.
"""
if len(towers) > 1:
logger.info("Training a model of {} towers".format(len(towers)))
DataParallelBuilder._check_tf_version()
logger.info("[DataParallel] Training a model of {} towers.".format(len(towers)))
self.towers = towers
@staticmethod
def _check_tf_version():
assert get_tf_version_number() >= 1.1, \
"TF version {} is too old to run multi GPU training!".format(tf.VERSION)
@staticmethod
def _check_grad_list(grad_list):
"""
......@@ -103,7 +96,7 @@ class DataParallelBuilder(GraphBuilder):
index=idx,
vs_name=tower_names[idx] if usevs else ''):
if len(str(device)) < 10: # a device function doesn't have good string description
logger.info("Building graph for training tower {} on device {}...".format(idx, device))
logger.info("Building graph for training tower {} on device {} ...".format(idx, device))
else:
logger.info("Building graph for training tower {} ...".format(idx))
......
......@@ -78,8 +78,8 @@ def regularize_cost(regex, func, name='regularize_cost'):
return name[prefixlen:]
return name
names = list(map(f, names))
logger.info("regularize_cost() found {} tensors.".format(len(names)))
_log_once("Applying regularizer for {}".format(', '.join(names)))
logger.info("regularize_cost() applying regularizers on {} tensors.".format(len(names)))
_log_once("The following tensors will be regularized: {}".format(', '.join(names)))
return tf.add_n(costs, name=name)
......@@ -106,7 +106,8 @@ def regularize_cost_from_collection(name='regularize_cost'):
else:
losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
if len(losses) > 0:
logger.info("regularize_cost_from_collection() found {} tensors in REGULARIZATION_LOSSES.".format(len(losses)))
logger.info("regularize_cost_from_collection() applying regularizers on "
"{} tensors in REGULARIZATION_LOSSES.".format(len(losses)))
reg_loss = tf.add_n(losses, name=name)
return reg_loss
else:
......
......@@ -24,42 +24,38 @@ def humanize_time_delta(sec):
"""Humanize timedelta given in seconds
Args:
sec (float): time difference in seconds.
sec (float): time difference in seconds. Must be positive.
Examples:
Returns:
str - time difference as a readable string
Several time differences as a human readable string
Examples:
.. code-block:: python
print humanize_seconds(1) # 1 second
print humanize_seconds(60 + 1) # 1 minute 1 second
print humanize_seconds(87.6) # 1 minute 27 seconds
print humanize_seconds(0.01) # 0.01 seconds
print humanize_seconds(60 * 60 + 1) # 1 hour 0 minutes 1 second
print humanize_seconds(60 * 60 * 24 + 1) # 1 day 0 hours 0 minutes 1 second
print humanize_seconds(60 * 60 * 24 + 60 * 2 + 60*60*9+ 3) # 1 day 9 hours 2 minutes 3 seconds
Returns:
time difference as a readable string
print(humanize_time_delta(1)) # 1 second
print(humanize_time_delta(60 + 1)) # 1 minute 1 second
print(humanize_time_delta(87.6)) # 1 minute 27 seconds
print(humanize_time_delta(0.01)) # 0.01 seconds
print(humanize_time_delta(60 * 60 + 1)) # 1 hour 1 second
print(humanize_time_delta(60 * 60 * 24 + 1)) # 1 day 1 second
print(humanize_time_delta(60 * 60 * 24 + 60 * 2 + 60*60*9 + 3)) # 1 day 9 hours 2 minutes 3 seconds
"""
assert sec >= 0, sec
if sec == 0:
return "0 second"
time = datetime(2000, 1, 1) + timedelta(seconds=int(sec))
units = ['day', 'hour', 'minute', 'second']
vals = [time.day - 1, time.hour, time.minute, time.second]
vals = [int(sec // 86400), time.hour, time.minute, time.second]
if sec < 60:
vals[-1] = sec
def _format(v, u):
return "{:.3g} {}{}".format(v, u, "s" if v > 1 else "")
required = False
ans = []
for v, u in zip(vals, units):
if not required:
if v > 0:
required = True
ans.append(_format(v, u))
else:
ans.append(_format(v, u))
return " ".join(ans)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment