Commit a4d4eafc authored by Yuxin Wu

misc small change

parent ed702d1d
......@@ -252,6 +252,8 @@ class TowerFuncWrapper(object):
each time the function is called.
:class:`TowerTrainer` needs this so that it knows how to build a predictor.
Conceptually, this class is roughly equivalent to `tf.function` with input signature, introduced in TF 2.0.
"""
def __init__(self, tower_fn, inputs_desc):
......
......@@ -46,7 +46,8 @@ class TrainLoop(object):
self.starting_epoch = int(starting_epoch)
self.max_epoch = int(max_epoch)
self.steps_per_epoch = int(steps_per_epoch)
assert self.steps_per_epoch > 0 and self.max_epoch > 0
# Allow empty epoch (no steps), if we want to run the callbacks only.
assert self.steps_per_epoch >= 0 and self.max_epoch >= 0
self._epoch_num = starting_epoch - 1
......
......@@ -375,6 +375,7 @@ class HorovodTrainer(SingleCostTrainer):
hvd.init()
self.is_chief = hvd.rank() == 0
self._local_rank = hvd.local_rank()
self._rank = hvd.rank()
self._average = average
logger.info("[HorovodTrainer] local rank={}".format(self._local_rank))
super(HorovodTrainer, self).__init__()
......@@ -435,7 +436,10 @@ class HorovodTrainer(SingleCostTrainer):
# TODO:
# 1. a allgather helper to concat strings
# 2. check variables on each rank match each other, print warnings, and broadcast the common set.
if self.is_chief:
logger.info("Broadcasting initialized variables ...")
else:
logger.info("Rank {} waiting for initialization broadcasting ...".format(self._rank))
self.sess.run(self._broadcast_op)
......
......@@ -179,7 +179,6 @@ def _pick_tqdm_interval(file):
return 15
if 'OMPI_COMM_WORLD_SIZE' in os.environ:
if int(os.environ['OMPI_COMM_WORLD_SIZE']) > 8:
return 60
# If not a tty, don't refresh progress bar that often
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment