Commit a4d4eafc authored by Yuxin Wu

misc small change

parent ed702d1d
...@@ -252,6 +252,8 @@ class TowerFuncWrapper(object): ...@@ -252,6 +252,8 @@ class TowerFuncWrapper(object):
each time the function is called. each time the function is called.
:class:`TowerTrainer` needs this so that it knows how to build a predictor. :class:`TowerTrainer` needs this so that it knows how to build a predictor.
Conceptually, this class is roughly equivalent to `tf.function` with input signature, introduced in TF 2.0.
""" """
def __init__(self, tower_fn, inputs_desc): def __init__(self, tower_fn, inputs_desc):
......
...@@ -46,7 +46,8 @@ class TrainLoop(object): ...@@ -46,7 +46,8 @@ class TrainLoop(object):
self.starting_epoch = int(starting_epoch) self.starting_epoch = int(starting_epoch)
self.max_epoch = int(max_epoch) self.max_epoch = int(max_epoch)
self.steps_per_epoch = int(steps_per_epoch) self.steps_per_epoch = int(steps_per_epoch)
assert self.steps_per_epoch > 0 and self.max_epoch > 0 # Allow empty epoch (no steps), if we want to run the callbacks only.
assert self.steps_per_epoch >= 0 and self.max_epoch >= 0
self._epoch_num = starting_epoch - 1 self._epoch_num = starting_epoch - 1
......
...@@ -375,6 +375,7 @@ class HorovodTrainer(SingleCostTrainer): ...@@ -375,6 +375,7 @@ class HorovodTrainer(SingleCostTrainer):
hvd.init() hvd.init()
self.is_chief = hvd.rank() == 0 self.is_chief = hvd.rank() == 0
self._local_rank = hvd.local_rank() self._local_rank = hvd.local_rank()
self._rank = hvd.rank()
self._average = average self._average = average
logger.info("[HorovodTrainer] local rank={}".format(self._local_rank)) logger.info("[HorovodTrainer] local rank={}".format(self._local_rank))
super(HorovodTrainer, self).__init__() super(HorovodTrainer, self).__init__()
...@@ -435,7 +436,10 @@ class HorovodTrainer(SingleCostTrainer): ...@@ -435,7 +436,10 @@ class HorovodTrainer(SingleCostTrainer):
# TODO: # TODO:
# 1. a allgather helper to concat strings # 1. a allgather helper to concat strings
# 2. check variables on each rank match each other, print warnings, and broadcast the common set. # 2. check variables on each rank match each other, print warnings, and broadcast the common set.
logger.info("Broadcasting initialized variables ...") if self.is_chief:
logger.info("Broadcasting initialized variables ...")
else:
logger.info("Rank {} waiting for initialization broadcasting ...".format(self._rank))
self.sess.run(self._broadcast_op) self.sess.run(self._broadcast_op)
......
...@@ -179,8 +179,7 @@ def _pick_tqdm_interval(file): ...@@ -179,8 +179,7 @@ def _pick_tqdm_interval(file):
return 15 return 15
if 'OMPI_COMM_WORLD_SIZE' in os.environ: if 'OMPI_COMM_WORLD_SIZE' in os.environ:
if int(os.environ['OMPI_COMM_WORLD_SIZE']) > 8: return 60
return 60
# If not a tty, don't refresh progress bar that often # If not a tty, don't refresh progress bar that often
return 180 return 180
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment