Commit d869aec8 authored by Yuxin Wu

replicated trainer. (didn't work for inference)

parent 3f05b530
......@@ -38,10 +38,6 @@ class RunOp(Callback):
    def _setup_graph(self):
        self._op = self.setup_func()

    def _before_run(self, _):
        if self.run_step:
            return [self._op]

    def _before_train(self):
        if self.run_before:
            self._op.run()
......@@ -50,6 +46,10 @@ class RunOp(Callback):
        if self.run_as_trigger:
            self._op.run()

    def _before_run(self, _):
        if self.run_step:
            return [self._op]


class RunUpdateOps(RunOp):
    """
......
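For context, `_before_run` returns extra fetches so that an op runs together with every training step. The snippet below is a minimal sketch of that pattern using a plain `tf.train.SessionRunHook` rather than tensorpack's `Callback` class; the class name and constructor arguments are illustrative, not part of tensorpack.

import tensorflow as tf

class RunOpEveryStep(tf.train.SessionRunHook):
    """Illustrative hook: run a given op alongside every session.run call."""
    def __init__(self, op, run_step=True):
        self._op = op
        self._run_step = run_step

    def before_run(self, run_context):
        # Returning the op as an extra fetch makes it execute together with
        # the training step, similar to RunOp._before_run when run_step is True.
        if self._run_step:
            return tf.train.SessionRunArgs(fetches=[self._op])
        return None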
......@@ -223,7 +223,7 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
            x, moving_mean, moving_var, beta, gamma, epsilon)

    # maintain EMA only on one GPU.
    if ctx.is_main_training_tower:
    if ctx.is_main_training_tower or ctx.has_own_variables:
        ret = update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay)
    else:
        ret = tf.identity(xn, name='output')
......
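The changed condition widens EMA maintenance: previously only the main training tower updated the moving statistics; now any tower that owns its own variables (the replicated-trainer case) updates its own copy as well. Below is a hedged sketch of what an `update_bn_ema`-style helper typically does; it illustrates the EMA-update technique and is not a copy of tensorpack's implementation.

import tensorflow as tf
from tensorflow.python.training import moving_averages

def update_bn_ema_sketch(xn, batch_mean, batch_var, moving_mean, moving_var, decay):
    # Attach moving-average updates to the normalized output so they run
    # whenever the output is used (illustrative stand-in, not tensorpack code).
    update_mean = moving_averages.assign_moving_average(
        moving_mean, batch_mean, decay, zero_debias=False)
    update_var = moving_averages.assign_moving_average(
        moving_var, batch_var, decay, zero_debias=False)
    with tf.control_dependencies([update_mean, update_var]):
        return tf.identity(xn, name='output')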
......@@ -66,8 +66,14 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
        cost = self.model.get_cost()  # assume single cost
        opt = self.model.get_optimizer()
        # GATE_NONE faster?
        varlist = tf.trainable_variables()
        ctx = get_current_tower_context()
        if ctx is not None and ctx.has_own_variables:
            # only optimize w.r.t vars in this tower
            varlist = [v for v in varlist if v.op.name.startswith(ctx.name + '/')]
        grads = opt.compute_gradients(
            cost,
            var_list=varlist,
            gate_gradients=tf.train.Optimizer.GATE_NONE,
            colocate_gradients_with_ops=True)
        return cost, grads
......
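The added lines make each tower compute gradients only for the variables it owns when running in replicated mode, identified by the tower's name-scope prefix. As a self-contained illustration of that step, a hypothetical helper (the function name is not in the diff) might look like:

import tensorflow as tf

def grads_for_tower(opt, cost, tower_name, has_own_variables):
    # With shared variables every tower optimizes the full trainable set;
    # with replicated (per-tower) variables, keep only this tower's copies,
    # identified by their name-scope prefix, as in the diff above.
    varlist = tf.trainable_variables()
    if has_own_variables:
        varlist = [v for v in varlist if v.op.name.startswith(tower_name + '/')]
    return opt.compute_gradients(
        cost, var_list=varlist,
        gate_gradients=tf.train.Optimizer.GATE_NONE,
        colocate_gradients_with_ops=True)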
......@@ -155,6 +155,7 @@ class QueueInput(FeedfreeInput):
    # TODO use input data mapping. not all placeholders are needed
    def setup(self, model):
        logger.info("Setting up the queue for CPU prefetching ...")
        self.input_placehdrs = model.get_reused_placehdrs()
        assert len(self.input_placehdrs) > 0, \
            "QueueInput has to be used with input placeholders!"
......@@ -200,6 +201,7 @@ class BatchQueueInput(FeedfreeInput):
        return self.ds.size() // self.batch_size

    def setup(self, model):
        logger.info("Setting up the queue for CPU prefetching ...")
        self.input_placehdrs = model.get_reused_placehdrs()
        assert len(self.input_placehdrs) > 0, \
            "BatchQueueInput has to be used with input placeholders!"
......@@ -385,6 +387,7 @@ class StagingInputWrapper(FeedfreeInput):
            self.get_stage_op(), self.get_unstage_op(), self._nr_stage))

    def setup_staging_areas(self):
        logger.info("Setting up the StageAreas for GPU prefetching ...")
        for idx, device in enumerate(self._devices):
            with tf.device(device):
                inputs = self._input.get_input_tensors()
......
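StagingInputWrapper builds one staging area per GPU so that the next batch is already resident on the device while the current step runs. The snippet below is a hedged sketch of that StagingArea pattern under the assumption that each tower's input tensors are available per device; it is not tensorpack's actual code.

import tensorflow as tf
from tensorflow.python.ops import data_flow_ops

def build_staging(inputs_per_device, devices):
    # One staging area per GPU hides host-to-device copy latency by keeping
    # a batch "in flight" on each device (illustrative sketch).
    stage_ops, unstage_outputs = [], []
    for device, inputs in zip(devices, inputs_per_device):
        with tf.device(device):
            area = data_flow_ops.StagingArea(
                dtypes=[x.dtype for x in inputs],
                shapes=[x.get_shape() for x in inputs])
            stage_ops.append(area.put(inputs))   # enqueue onto the device
            unstage_outputs.append(area.get())   # dequeue for this tower
    return tf.group(*stage_ops), unstage_outputs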
This diff is collapsed.
......@@ -2,6 +2,9 @@
# File: trainer.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from six.moves import zip
from .base import Trainer
from ..utils import logger
......
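The commit message describes a replicated trainer, whose main diff is collapsed above. As general background only (not tensorpack's implementation), the replicated data-parallel pattern has each tower compute gradients for its own variable copies and then averages corresponding gradients so every replica applies the same update; a minimal sketch:

import tensorflow as tf

def average_grads_across_towers(all_tower_grads):
    # all_tower_grads: one [(grad, var), ...] list per tower, same ordering.
    # Returns one list per tower with averaged gradients paired with each
    # tower's own variable copy (generic sketch, not tensorpack code).
    averaged = []
    for grads_and_vars in zip(*all_tower_grads):
        grads = [g for g, _ in grads_and_vars]
        mean_grad = tf.multiply(tf.add_n(grads), 1.0 / len(grads))
        averaged.append([(mean_grad, v) for _, v in grads_and_vars])
    return [list(t) for t in zip(*averaged)]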