Commit fd9edc3b authored by Yuxin Wu

Remove unnecessary var_strategy checks in trainers.

parent 9b707d91
@@ -8,7 +8,7 @@ or write an issue to see if there is a better solution than creating new trainer
 For certain tasks, you do need a new trainer.
 Trainers just run __some__ iterations, so there is no limit in where the data come from or what to do in an iteration.
-The existing common trainers do two things:
+The existing common trainers all do two things:
 1. Setup the graph and input pipeline, from `TrainConfig`.
 2. Minimize `model.cost` in each iteration.
@@ -16,11 +16,12 @@ But you can customize it by using the base `Trainer` class.
 * To customize the graph:
-  Create the graph, add any tensors and ops either before creating the trainer or inside `Trainer.__init__`.
+  Add any tensors and ops you like, either before creating the trainer or inside `Trainer.__init__`.
+  In this case you don't need to set model/data in `TrainConfig` any more.
 * Two ways to customize the iteration:
   1. Set `Trainer.train_op`. This op will be run by default.
-  2. Subclass `Trainer` and override the `run_step()` method. This way you can run more ops in one iteration.
+  2. Subclass `Trainer` and override the `run_step()` method. This way you can do something more than running an op.
 There are several different [GAN trainers](../../examples/GAN/GAN.py) for reference.
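
The second route in that list (overriding `run_step()`) is the pattern used by the linked GAN trainers. A minimal sketch, assuming the base `Trainer` exposes its monitored session as `self.hooked_sess`; `self._extra_op` is a hypothetical op assumed to have been created during graph setup and is not introduced by this commit:

```python
from tensorpack.train import Trainer

class MyTrainer(Trainer):
    """Sketch: run an extra op alongside the usual train_op in each iteration."""

    def run_step(self):
        # `self.train_op` is what the default implementation runs;
        # `self._extra_op` is an illustrative op assumed to exist already.
        self.hooked_sess.run([self.train_op, self._extra_op])
```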
@@ -213,11 +213,13 @@ def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
     xn = layer.apply(x, training=ctx.is_training, scope=tf.get_variable_scope())

     if ctx.has_own_variables:
-        # only apply update in this case
+        # Only apply update in this case.
+        # Add these EMA to model_variables so that they will be synced
+        # properly by replicated trainers.
         for v in layer.non_trainable_variables:
             add_model_variable(v)
     else:
-        # don't need update if we are sharing variables from an old tower
+        # Don't need update if we are sharing variables from an existing tower
        restore_collection(coll_bk)

     if ndims == 2:
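
The new comment in this hunk is the reason for the change: `add_model_variable` puts the batch-renorm EMA statistics into the `MODEL_VARIABLES` collection, which is what replicated trainers use to find non-trainable state that must be kept in sync across towers. A hedged illustration of that effect (the variable name and shape are made up for the example):

```python
import tensorflow as tf
from tensorflow.contrib.framework import add_model_variable

# An illustrative non-trainable statistic, similar to the EMAs that
# tf.layers.BatchNormalization creates internally.
ema_mean = tf.get_variable('mean/EMA', shape=[64], trainable=False,
                           initializer=tf.zeros_initializer())
add_model_variable(ema_mean)

# It is now listed under MODEL_VARIABLES, where a replicated trainer
# can pick it up and synchronize it together with the trainable weights.
assert ema_mean in tf.get_collection(tf.GraphKeys.MODEL_VARIABLES)
```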
@@ -199,7 +199,6 @@ class DistributedTrainerReplicated(MultiGPUTrainerBase):
             lambda: MultiGPUTrainerBase._build_graph_get_grads(
                 self.model, self._input_source),
             devices=self.raw_devices,
-            var_strategy='replicated',
             vs_names=[True] * self.config.nr_tower)  # open vs at each tower
         MultiGPUTrainerBase._check_grad_list(grad_list)
@@ -50,14 +50,13 @@ class MultiGPUTrainerBase(FeedfreeTrainerBase):
     @staticmethod
     def build_on_multi_tower(
             towers, func,
-            devices=None, var_strategy='shared',
+            devices=None,
             use_vs=None):
         """
         Args:
             towers: list of gpu relative ids
             func: a lambda to be called inside each tower
             devices: a list of devices to be used. By default will use GPUs in ``towers``.
-            var_strategy (str): 'shared' or 'replicated'
             use_vs (list[bool]): list of use_vs to passed to TowerContext

         Returns:
@@ -73,11 +72,6 @@ class MultiGPUTrainerBase(FeedfreeTrainerBase):
         tower_names = ['tower{}'.format(idx) for idx in range(len(towers))]

         keys_to_freeze = TOWER_FREEZE_KEYS[:]
-        if var_strategy == 'replicated':  # TODO ugly
-            logger.info("In replicated mode, UPDATE_OPS from all GPUs will be run.")
-            keys_to_freeze.remove(tf.GraphKeys.UPDATE_OPS)
-        else:
-            assert use_vs is None
         if use_vs is None:
             use_vs = [False] * len(towers)
         assert len(use_vs) == len(towers)
@@ -308,7 +302,6 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase):
         grad_list = MultiGPUTrainerBase.build_on_multi_tower(
             tower,
             lambda: MultiGPUTrainerBase._build_graph_get_grads(model, input),
-            var_strategy='replicated',
             # use no variable scope for the first tower
             use_vs=[False] + [True] * (len(tower) - 1))
         grads = SyncMultiGPUTrainerReplicated._allreduce_grads(grad_list)
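
With `var_strategy` gone, the variable strategy is expressed entirely through `use_vs`: omit it for shared variables, or pass a per-tower list to replicate them, exactly as the `SyncMultiGPUTrainerReplicated` hunk above does. A hedged sketch of the two call patterns (gradient handling and the rest of the trainer wiring omitted):

```python
# Shared variables (the old default var_strategy='shared'):
# every tower reuses the same variable scope.
grad_list = MultiGPUTrainerBase.build_on_multi_tower(
    tower,
    lambda: MultiGPUTrainerBase._build_graph_get_grads(model, input))

# Replicated variables (the old var_strategy='replicated'):
# each tower after the first opens its own variable scope,
# so every GPU keeps its own copy of the variables.
grad_list = MultiGPUTrainerBase.build_on_multi_tower(
    tower,
    lambda: MultiGPUTrainerBase._build_graph_get_grads(model, input),
    use_vs=[False] + [True] * (len(tower) - 1))
```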
@@ -15,7 +15,7 @@ MOVING_SUMMARY_OPS_KEY = 'MOVING_SUMMARY_OPS'
 SUMMARY_BACKUP_KEYS = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_OPS_KEY]
-TOWER_FREEZE_KEYS = SUMMARY_BACKUP_KEYS + [tf.GraphKeys.UPDATE_OPS]
+TOWER_FREEZE_KEYS = SUMMARY_BACKUP_KEYS

 # export all upper case variables
 all_local_names = locals().keys()
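
This last hunk is what makes the removed special case in `build_on_multi_tower` unnecessary: `tf.GraphKeys.UPDATE_OPS` is no longer one of the collections frozen around each tower, so the update ops created inside every tower are kept unconditionally, instead of only in replicated mode via the deleted `keys_to_freeze.remove(...)` branch. A minimal sketch, not tensorpack's actual implementation, of the backup/restore pattern these keys control:

```python
import tensorflow as tf

def backup_collection(keys):
    # Snapshot the listed graph collections before building a tower.
    return {k: list(tf.get_collection(k)) for k in keys}

def restore_collection(backup):
    # Put the snapshot back, discarding anything a tower added in between.
    for k, items in backup.items():
        coll = tf.get_collection_ref(k)
        del coll[:]
        coll.extend(items)

# Collections listed in TOWER_FREEZE_KEYS are "frozen" this way around each
# tower. Since UPDATE_OPS is no longer in that list, the per-tower update ops
# (e.g. the BatchRenorm EMA updates above) now stay in the collection instead
# of being discarded.
```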