Commit d869aec8 authored by Yuxin Wu

replicated trainer. (didn't work for inference)

parent 3f05b530
@@ -38,10 +38,6 @@ class RunOp(Callback):
     def _setup_graph(self):
         self._op = self.setup_func()
 
-    def _before_run(self, _):
-        if self.run_step:
-            return [self._op]
-
     def _before_train(self):
         if self.run_before:
             self._op.run()
@@ -50,6 +46,10 @@ class RunOp(Callback):
         if self.run_as_trigger:
             self._op.run()
 
+    def _before_run(self, _):
+        if self.run_step:
+            return [self._op]
+
 
 class RunUpdateOps(RunOp):
     """
...
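For context: tensorpack's _before_run mirrors tf.train.SessionRunHook.before_run, and any fetches it returns are merged into the same session.run() call as the training step. The hunk above only moves the method later in the class; behavior is unchanged. A minimal sketch of the same per-step pattern as a plain TF hook (hypothetical class name, not tensorpack code):

import tensorflow as tf

class RunOpEveryStep(tf.train.SessionRunHook):
    """Hypothetical hook: run a given op inside every training step."""
    def __init__(self, op):
        self._op = op

    def before_run(self, run_context):
        # Fetches returned here are executed in the same session.run()
        # as the training step, so no separate self._op.run() is needed.
        return tf.train.SessionRunArgs(fetches=[self._op])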
@@ -223,7 +223,7 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
             x, moving_mean, moving_var, beta, gamma, epsilon)
 
     # maintain EMA only on one GPU.
-    if ctx.is_main_training_tower:
+    if ctx.is_main_training_tower or ctx.has_own_variables:
         ret = update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay)
     else:
         ret = tf.identity(xn, name='output')
...
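This condition change matters for the replicated trainer: when a tower owns its own copy of the variables (ctx.has_own_variables), it must also maintain the EMA of its own moving_mean/moving_var instead of leaving that to the main tower. A hedged sketch of what an update_bn_ema-style helper typically does, assuming the standard EMA update (not necessarily tensorpack's exact code):

import tensorflow as tf
from tensorflow.python.training import moving_averages

def update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay):
    # moving_stat <- decay * moving_stat + (1 - decay) * batch_stat
    update_mean = moving_averages.assign_moving_average(
        moving_mean, batch_mean, decay, zero_debias=False, name='mean_ema_op')
    update_var = moving_averages.assign_moving_average(
        moving_var, batch_var, decay, zero_debias=False, name='var_ema_op')
    # Tie the EMA updates to the output so they run whenever xn is computed.
    with tf.control_dependencies([update_mean, update_var]):
        return tf.identity(xn, name='output')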
@@ -66,8 +66,14 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
         cost = self.model.get_cost()   # assume single cost
         opt = self.model.get_optimizer()
         # GATE_NONE faster?
+        varlist = tf.trainable_variables()
+        ctx = get_current_tower_context()
+        if ctx is not None and ctx.has_own_variables:
+            # only optimize w.r.t vars in this tower
+            varlist = [v for v in varlist if v.op.name.startswith(ctx.name + '/')]
         grads = opt.compute_gradients(
             cost,
+            var_list=varlist,
             gate_gradients=tf.train.Optimizer.GATE_NONE,
             colocate_gradients_with_ops=True)
         return cost, grads
...
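The filter exists because, under variable replication, each tower's cost only depends on that tower's variable copies; differentiating against all trainable variables would yield None gradients for every other tower's copies. A small illustration of why the name-prefix filter works (hypothetical scope names):

import tensorflow as tf

with tf.variable_scope('tower0'):
    w0 = tf.get_variable('w', shape=[3])
with tf.variable_scope('tower1'):
    w1 = tf.get_variable('w', shape=[3])

cost = tf.reduce_sum(w0 * w0)  # tower0's cost touches only tower0's variables

varlist = [v for v in tf.trainable_variables()
           if v.op.name.startswith('tower0/')]  # -> [w0], skips tower1/w
opt = tf.train.GradientDescentOptimizer(0.1)
grads = opt.compute_gradients(cost, var_list=varlist)  # no None gradients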
@@ -155,6 +155,7 @@ class QueueInput(FeedfreeInput):
     # TODO use input data mapping. not all placeholders are needed
 
     def setup(self, model):
+        logger.info("Setting up the queue for CPU prefetching ...")
         self.input_placehdrs = model.get_reused_placehdrs()
         assert len(self.input_placehdrs) > 0, \
             "QueueInput has to be used with input placeholders!"
@@ -200,6 +201,7 @@ class BatchQueueInput(FeedfreeInput):
         return self.ds.size() // self.batch_size
 
     def setup(self, model):
+        logger.info("Setting up the queue for CPU prefetching ...")
         self.input_placehdrs = model.get_reused_placehdrs()
         assert len(self.input_placehdrs) > 0, \
             "BatchQueueInput has to be used with input placeholders!"
@@ -385,6 +387,7 @@ class StagingInputWrapper(FeedfreeInput):
             self.get_stage_op(), self.get_unstage_op(), self._nr_stage))
 
     def setup_staging_areas(self):
+        logger.info("Setting up the StageAreas for GPU prefetching ...")
         for idx, device in enumerate(self._devices):
             with tf.device(device):
                 inputs = self._input.get_input_tensors()
...
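StagingInputWrapper keeps one StagingArea per GPU so that the next batch is copied to the device while the current step computes. A hedged sketch of the pattern (assumed helper name, not the exact tensorpack code):

import tensorflow as tf

def build_staging(inputs, device):
    """Stage `inputs` onto `device`; get() returns a batch staged earlier."""
    with tf.device(device):
        stage = tf.contrib.staging.StagingArea(
            dtypes=[x.dtype for x in inputs],
            shapes=[x.get_shape() for x in inputs])
        stage_op = stage.put(inputs)  # run every step to refill the pipeline
        staged = stage.get()          # consumed by the tower on this device
    return stage_op, staged

The training loop then runs the stage op alongside the train op each step, after a few warm-up puts to fill the pipeline (the _nr_stage above).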
This diff is collapsed.
@@ -2,6 +2,9 @@
 # File: trainer.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
 
+from six.moves import zip
+
 from .base import Trainer
 from ..utils import logger
...
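The bulk of the replicated trainer lives in the collapsed diff above; the fragment here only adds six.moves.zip, which is handy for iterating over per-tower gradient lists in lockstep. As a hedged sketch of the replicated idea (illustrative only, not the collapsed code), assuming every tower produced its (grad, var) list in the same order: average the gradients across towers and apply the average to each tower's own variable copy, which keeps identically initialized copies in sync:

from six.moves import zip
import tensorflow as tf

def average_grads(all_grads):
    """all_grads: one [(grad, var), ...] list per tower, in matching order."""
    averaged = []
    for grad_and_vars in zip(*all_grads):
        grads = [g for g, _ in grad_and_vars]
        avg = tf.multiply(tf.add_n(grads), 1.0 / len(grads))
        # Every tower applies the same averaged gradient to its own copy,
        # so the copies stay identical after identical initialization.
        for _, v in grad_and_vars:
            averaged.append((avg, v))
    return averaged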