Commit 1be49dc9 authored by Yuxin Wu's avatar Yuxin Wu

fix bug of last commit

parent d2939cf8
......@@ -5,9 +5,9 @@
from .base import Trainer
from ..tfutils.tower import TowerContext
from ..utils import logger
from ..input_source import FeedInput, QueueInput
from ..graph_builder.training import SimpleBuilder
__all__ = ['SimpleTrainer', 'QueueInputTrainer']
......@@ -42,8 +42,12 @@ class SimpleTrainer(Trainer):
def _setup(self):
cbs = self._input_source.setup(self.model.get_inputs_desc())
self.train_op = SimpleBuilder().build(
self._input_source, self.model.build_graph_get_cost, self.model.get_optimizer)
with TowerContext('', is_training=True):
grads = self.model.build_graph_get_grads(
*self._input_source.get_input_tensors())
opt = self.model.get_optimizer()
self.train_op = opt.apply_gradients(grads, name='min_op')
self._config.callbacks.extend(cbs)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment