Commit a397cebc authored by Yuxin Wu's avatar Yuxin Wu

add back build_train_tower for compatibility

parent 3e1f0e53
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
import tensorflow as tf import tensorflow as tf
from six.moves import zip from six.moves import zip
from ..utils import logger
from ..tfutils.gradproc import FilterNoneGrad from ..tfutils.gradproc import FilterNoneGrad
from ..tfutils.tower import TowerContext, get_current_tower_context from ..tfutils.tower import TowerContext, get_current_tower_context
from .input_source import QueueInput, FeedfreeInput from .input_source import QueueInput, FeedfreeInput
...@@ -21,6 +22,13 @@ class FeedfreeTrainerBase(Trainer): ...@@ -21,6 +22,13 @@ class FeedfreeTrainerBase(Trainer):
Expect ``self.data`` to be a :class:`FeedfreeInput`. Expect ``self.data`` to be a :class:`FeedfreeInput`.
""" """
# TODO deprecated
def build_train_tower(self):
logger.warn("build_train_tower() was deprecated! Please build the graph "
"yourself, e.g. by self.model.build_graph(self._input_source)")
with TowerContext('', is_training=True):
self.model.build_graph(self._input_source)
def _setup(self): def _setup(self):
assert isinstance(self._input_source, FeedfreeInput), type(self._input_source) assert isinstance(self._input_source, FeedfreeInput), type(self._input_source)
self._setup_input_source(self._input_source) self._setup_input_source(self._input_source)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment