Commit 4adbaa94 authored by Yuxin Wu's avatar Yuxin Wu

improve docs

parent 43f7ca75
...@@ -80,7 +80,7 @@ def get_config(model, fake=False): ...@@ -80,7 +80,7 @@ def get_config(model, fake=False):
EstimatedTimeLeft(), EstimatedTimeLeft(),
ScheduledHyperParamSetter( ScheduledHyperParamSetter(
'learning_rate', [(30, BASE_LR * 1e-1), (60, BASE_LR * 1e-2), 'learning_rate', [(30, BASE_LR * 1e-1), (60, BASE_LR * 1e-2),
(85, BASE_LR * 1e-3), (95, BASE_LR * 1e-4), (105, BASE_LR * 1e-5)]), (90, BASE_LR * 1e-3), (100, BASE_LR * 1e-4)]),
] ]
if BASE_LR > 0.1: if BASE_LR > 0.1:
callbacks.append( callbacks.append(
...@@ -102,7 +102,7 @@ def get_config(model, fake=False): ...@@ -102,7 +102,7 @@ def get_config(model, fake=False):
dataflow=dataset_train, dataflow=dataset_train,
callbacks=callbacks, callbacks=callbacks,
steps_per_epoch=100 if args.fake else 1280000 // args.batch, steps_per_epoch=100 if args.fake else 1280000 // args.batch,
max_epoch=110, max_epoch=105,
) )
...@@ -115,7 +115,7 @@ if __name__ == '__main__': ...@@ -115,7 +115,7 @@ if __name__ == '__main__':
parser.add_argument('--data_format', help='specify NCHW or NHWC', parser.add_argument('--data_format', help='specify NCHW or NHWC',
type=str, default='NCHW') type=str, default='NCHW')
parser.add_argument('-d', '--depth', help='resnet depth', parser.add_argument('-d', '--depth', help='resnet depth',
type=int, default=18, choices=[18, 34, 50, 101, 152]) type=int, default=50, choices=[18, 34, 50, 101, 152])
parser.add_argument('--eval', action='store_true') parser.add_argument('--eval', action='store_true')
parser.add_argument('--batch', default=256, type=int, parser.add_argument('--batch', default=256, type=int,
help='total batch size. 32 per GPU gives best accuracy, higher values should be similarly good') help='total batch size. 32 per GPU gives best accuracy, higher values should be similarly good')
......
...@@ -78,10 +78,8 @@ class InferenceRunnerBase(Callback): ...@@ -78,10 +78,8 @@ class InferenceRunnerBase(Callback):
try: try:
self._size = input.size() self._size = input.size()
logger.info("InferenceRunner will eval {} iterations".format(input.size()))
except NotImplementedError: except NotImplementedError:
self._size = 0 self._size = 0
logger.warn("InferenceRunner got an input with unknown size! It will iterate until OutOfRangeError!")
self._hooks = [] self._hooks = []
...@@ -95,6 +93,10 @@ class InferenceRunnerBase(Callback): ...@@ -95,6 +93,10 @@ class InferenceRunnerBase(Callback):
def _before_train(self): def _before_train(self):
self._hooked_sess = HookedSession(self.trainer.sess, self._hooks) self._hooked_sess = HookedSession(self.trainer.sess, self._hooks)
self._input_callbacks.before_train() self._input_callbacks.before_train()
if self._size > 0:
logger.info("InferenceRunner will eval {} iterations".format(self._size))
else:
logger.warn("InferenceRunner got an input with unknown size! It will iterate until OutOfRangeError!")
def _after_train(self): def _after_train(self):
self._input_callbacks.after_train() self._input_callbacks.after_train()
......
...@@ -94,4 +94,5 @@ class EstimatedTimeLeft(Callback): ...@@ -94,4 +94,5 @@ class EstimatedTimeLeft(Callback):
average_epoch_time = np.mean(self._times) average_epoch_time = np.mean(self._times)
time_left = (self._max_epoch - self.epoch_num) * average_epoch_time time_left = (self._max_epoch - self.epoch_num) * average_epoch_time
if time_left > 0:
logger.info("Estimated Time Left: " + humanize_time_delta(time_left)) logger.info("Estimated Time Left: " + humanize_time_delta(time_left))
...@@ -114,7 +114,7 @@ def _guess_dir_structure(dir): ...@@ -114,7 +114,7 @@ def _guess_dir_structure(dir):
else: else:
dir_structure = 'original' dir_structure = 'original'
logger.info( logger.info(
"Assuming directory {} has {} structure.".format( "[ILSVRC12] Assuming directory {} has '{}' structure.".format(
dir, dir_structure)) dir, dir_structure))
return dir_structure return dir_structure
......
...@@ -12,7 +12,6 @@ from six.moves import zip, range ...@@ -12,7 +12,6 @@ from six.moves import zip, range
from ..utils import logger from ..utils import logger
from ..tfutils.tower import TowerContext from ..tfutils.tower import TowerContext
from ..tfutils.common import get_tf_version_number
from ..tfutils.gradproc import ScaleGradient from ..tfutils.gradproc import ScaleGradient
from .utils import ( from .utils import (
...@@ -39,16 +38,10 @@ class DataParallelBuilder(GraphBuilder): ...@@ -39,16 +38,10 @@ class DataParallelBuilder(GraphBuilder):
towers(list[int]): list of GPU ids. towers(list[int]): list of GPU ids.
""" """
if len(towers) > 1: if len(towers) > 1:
logger.info("Training a model of {} towers".format(len(towers))) logger.info("[DataParallel] Training a model of {} towers.".format(len(towers)))
DataParallelBuilder._check_tf_version()
self.towers = towers self.towers = towers
@staticmethod
def _check_tf_version():
assert get_tf_version_number() >= 1.1, \
"TF version {} is too old to run multi GPU training!".format(tf.VERSION)
@staticmethod @staticmethod
def _check_grad_list(grad_list): def _check_grad_list(grad_list):
""" """
...@@ -103,7 +96,7 @@ class DataParallelBuilder(GraphBuilder): ...@@ -103,7 +96,7 @@ class DataParallelBuilder(GraphBuilder):
index=idx, index=idx,
vs_name=tower_names[idx] if usevs else ''): vs_name=tower_names[idx] if usevs else ''):
if len(str(device)) < 10: # a device function doesn't have good string description if len(str(device)) < 10: # a device function doesn't have good string description
logger.info("Building graph for training tower {} on device {}...".format(idx, device)) logger.info("Building graph for training tower {} on device {} ...".format(idx, device))
else: else:
logger.info("Building graph for training tower {} ...".format(idx)) logger.info("Building graph for training tower {} ...".format(idx))
......
...@@ -78,8 +78,8 @@ def regularize_cost(regex, func, name='regularize_cost'): ...@@ -78,8 +78,8 @@ def regularize_cost(regex, func, name='regularize_cost'):
return name[prefixlen:] return name[prefixlen:]
return name return name
names = list(map(f, names)) names = list(map(f, names))
logger.info("regularize_cost() found {} tensors.".format(len(names))) logger.info("regularize_cost() applying regularizers on {} tensors.".format(len(names)))
_log_once("Applying regularizer for {}".format(', '.join(names))) _log_once("The following tensors will be regularized: {}".format(', '.join(names)))
return tf.add_n(costs, name=name) return tf.add_n(costs, name=name)
...@@ -106,7 +106,8 @@ def regularize_cost_from_collection(name='regularize_cost'): ...@@ -106,7 +106,8 @@ def regularize_cost_from_collection(name='regularize_cost'):
else: else:
losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
if len(losses) > 0: if len(losses) > 0:
logger.info("regularize_cost_from_collection() found {} tensors in REGULARIZATION_LOSSES.".format(len(losses))) logger.info("regularize_cost_from_collection() applying regularizers on "
"{} tensors in REGULARIZATION_LOSSES.".format(len(losses)))
reg_loss = tf.add_n(losses, name=name) reg_loss = tf.add_n(losses, name=name)
return reg_loss return reg_loss
else: else:
......
...@@ -24,42 +24,38 @@ def humanize_time_delta(sec): ...@@ -24,42 +24,38 @@ def humanize_time_delta(sec):
"""Humanize timedelta given in seconds """Humanize timedelta given in seconds
Args: Args:
sec (float): time difference in seconds. sec (float): time difference in seconds. Must be positive.
Examples: Returns:
str - time difference as a readable string
Several time differences as a human readable string Examples:
.. code-block:: python .. code-block:: python
print humanize_seconds(1) # 1 second print(humanize_time_delta(1)) # 1 second
print humanize_seconds(60 + 1) # 1 minute 1 second print(humanize_time_delta(60 + 1)) # 1 minute 1 second
print humanize_seconds(87.6) # 1 minute 27 seconds print(humanize_time_delta(87.6)) # 1 minute 27 seconds
print humanize_seconds(0.01) # 0.01 seconds print(humanize_time_delta(0.01)) # 0.01 seconds
print humanize_seconds(60 * 60 + 1) # 1 hour 0 minutes 1 second print(humanize_time_delta(60 * 60 + 1)) # 1 hour 1 second
print humanize_seconds(60 * 60 * 24 + 1) # 1 day 0 hours 0 minutes 1 second print(humanize_time_delta(60 * 60 * 24 + 1)) # 1 day 1 second
print humanize_seconds(60 * 60 * 24 + 60 * 2 + 60*60*9+ 3) # 1 day 9 hours 2 minutes 3 seconds print(humanize_time_delta(60 * 60 * 24 + 60 * 2 + 60*60*9 + 3)) # 1 day 9 hours 2 minutes 3 seconds
Returns:
time difference as a readable string
""" """
assert sec >= 0, sec
if sec == 0:
return "0 second"
time = datetime(2000, 1, 1) + timedelta(seconds=int(sec)) time = datetime(2000, 1, 1) + timedelta(seconds=int(sec))
units = ['day', 'hour', 'minute', 'second'] units = ['day', 'hour', 'minute', 'second']
vals = [time.day - 1, time.hour, time.minute, time.second] vals = [int(sec // 86400), time.hour, time.minute, time.second]
if sec < 60: if sec < 60:
vals[-1] = sec vals[-1] = sec
def _format(v, u): def _format(v, u):
return "{:.3g} {}{}".format(v, u, "s" if v > 1 else "") return "{:.3g} {}{}".format(v, u, "s" if v > 1 else "")
required = False
ans = [] ans = []
for v, u in zip(vals, units): for v, u in zip(vals, units):
if not required:
if v > 0: if v > 0:
required = True
ans.append(_format(v, u))
else:
ans.append(_format(v, u)) ans.append(_format(v, u))
return " ".join(ans) return " ".join(ans)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment