Commit 91f3a441 authored by Yuxin Wu

remove prefix from InputDesc.build_placeholder

parent ebf2332b
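Below is a minimal sketch of the caller-side migration this commit implies, assuming tensorpack's `InputDesc`/`PlaceholderInput` API as of this commit; the two example descs are made up for illustration.

```python
import tensorflow as tf
from tensorpack.graph_builder.model_desc import InputDesc
from tensorpack.input_source import PlaceholderInput

inputs_desc = [InputDesc(tf.float32, (None, 28, 28), 'image'),
               InputDesc(tf.int32, (None,), 'label')]

# Old API (removed by this commit): the prefix was applied inside
# build_placeholder(), via PlaceholderInput:
#     inp = PlaceholderInput('tower0/')
#     inp.setup(inputs_desc)

# New API: express the prefix by renaming the descs themselves.
prefixed = [InputDesc(d.type, d.shape, 'tower0/' + d.name) for d in inputs_desc]
inp = PlaceholderInput()
inp.setup(prefixed)
```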
tensorpack/graph_builder/model_desc.py
@@ -39,21 +39,17 @@ class InputDesc(
         self._cached_placeholder = None
         return self
 
-    def build_placeholder(self, prefix=''):
+    def build_placeholder(self):
         """
-        Build a tf.placeholder from the metadata, with an optional prefix.
-
-        Args:
-            prefix(str): the name of the placeholder will be ``prefix + self.name``
+        Build a tf.placeholder from the metadata.
 
         Returns:
             tf.Tensor:
         """
         with tf.name_scope(None):   # clear any name scope it might get called in
             ret = tf.placeholder(
-                self.type, shape=self.shape,
-                name=prefix + self.name)
-            if prefix == '' and self._cached_placeholder is None:
-                self._cached_placeholder = ret  # cached_placeholder only caches the prefix='' case
+                self.type, shape=self.shape, name=self.name)
+            if self._cached_placeholder is None:
+                self._cached_placeholder = ret
             return ret
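A small sketch of the resulting behavior, assuming the tensorpack API at this commit (the desc is illustrative):

```python
import tensorflow as tf
from tensorpack.graph_builder.model_desc import InputDesc

desc = InputDesc(tf.float32, (None, 224, 224, 3), 'image')

with tf.name_scope('some_scope'):
    p = desc.build_placeholder()

# The enclosing name scope is cleared inside build_placeholder(),
# so the placeholder keeps its bare name:
assert p.name == 'image:0'
# The first placeholder built is cached on the desc, as in the hunk above:
assert desc._cached_placeholder is p
```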
@@ -97,8 +93,7 @@ class ModelDescBase(object):
         """
         Build the whole symbolic graph.
         This is supposed to be the "tower function" when used with :class:`TowerTrainer`.
-        By default it will call :meth:`_build_graph`
-        with a list of input tensors.
+        By default it will call :meth:`_build_graph` with a list of input tensors, for backward-compatibility.
 
         Args:
             args ([tf.Tensor]): tensors that matches the list of
@@ -108,10 +103,11 @@ class ModelDescBase(object):
             arg = args[0]
             if isinstance(arg, InputSource):
                 inputs = arg.get_input_tensors()  # remove in the future?
-                log_deprecated("build_graph(InputSource)", "Call with tensors in positional args instead.")
+                log_deprecated("build_graph(InputSource)",
+                               "Call with tensors in positional args instead.", "2018-03-31")
             elif isinstance(arg, (list, tuple)):
                 inputs = arg
-                log_deprecated("build_graph([Tensor])", "Call with positional args instead.")
+                log_deprecated("build_graph([Tensor])", "Call with positional args instead.", "2018-03-31")
             else:
                 inputs = [arg]
         else:
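The warnings above steer callers toward positional tensors. A runnable sketch of the two call styles, assuming a hypothetical minimal `ModelDescBase` subclass:

```python
import tensorflow as tf
from tensorpack.graph_builder.model_desc import ModelDescBase, InputDesc

class MyModel(ModelDescBase):   # hypothetical model for illustration
    def _get_inputs(self):
        return [InputDesc(tf.float32, (None, 4), 'image'),
                InputDesc(tf.int32, (None,), 'label')]

    def _build_graph(self, inputs):
        image, label = inputs   # receives a list under either call style

model = MyModel()
image = tf.placeholder(tf.float32, (None, 4), 'image')
label = tf.placeholder(tf.int32, (None,), 'label')

model.build_graph([image, label])   # still works, but warns until 2018-03-31
model.build_graph(image, label)     # the call style the warning asks for
```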
tensorpack/input_source/input_source.py
@@ -39,15 +39,8 @@ class PlaceholderInput(InputSource):
     """
     Just produce placeholders as input tensors.
     """
-    def __init__(self, prefix=''):
-        """
-        Args:
-            prefix(str): an optional prefix to add to the placeholder.
-        """
-        self._prefix = prefix
-
     def _setup(self, inputs):
-        self._all_placehdrs = [v.build_placeholder(prefix=self._prefix) for v in inputs]
+        self._all_placehdrs = [v.build_placeholder() for v in inputs]
 
     def _get_input_tensors(self):
         return self._all_placehdrs
tensorpack/predict/multigpu.py
@@ -6,6 +6,7 @@
 import tensorflow as tf
 
 from ..utils import logger
 from ..graph_builder.predict import SimplePredictBuilder
+from ..graph_builder.model_desc import InputDesc
 from ..input_source import PlaceholderInput
 from .base import OnlinePredictor
@@ -93,8 +94,11 @@ class DataParallelOfflinePredictor(OnlinePredictor):
             for idx, t in enumerate(towers):
                 tower_name = 'tower' + str(t)
 
-                input = PlaceholderInput(tower_name + '/')
-                input.setup(config.inputs_desc)
+                inputs_desc = [InputDesc(desc.type, desc.shape, tower_name + '/' + desc.name)
+                               for desc in config.inputs_desc]
+                input = PlaceholderInput()
+                input.setup(inputs_desc)
 
                 with tf.variable_scope(tf.get_variable_scope(), reuse=idx > 0):
                     builder = SimplePredictBuilder(ns_name=tower_name, device=t)
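A self-contained sketch of the per-tower pattern above, assuming the same-era tensorpack API (the desc list is illustrative):

```python
import tensorflow as tf
from tensorpack.graph_builder.model_desc import InputDesc
from tensorpack.input_source import PlaceholderInput

base_descs = [InputDesc(tf.float32, (None, 224, 224, 3), 'image')]

for tower_id in [0, 1]:
    tower_name = 'tower%d' % tower_id
    descs = [InputDesc(d.type, d.shape, tower_name + '/' + d.name)
             for d in base_descs]
    inp = PlaceholderInput()
    inp.setup(descs)
    # Each tower gets its own placeholder, e.g. 'tower0/image:0', 'tower1/image:0':
    print([p.name for p in inp.get_input_tensors()])
```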
tensorpack/trainv2/base.py
@@ -93,9 +93,11 @@ class Trainer(object):
     """ Base class for a trainer.
     """
 
     _API_VERSION = 2
 
     is_chief = True
+    """
+    Whether this process is the chief worker in distributed training.
+    Certain callbacks will only be run by chief worker.
+    """
 
     def __init__(self, config=None):
         """
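A hedged sketch of how this attribute is typically consumed, i.e. a callback that acts only on the chief worker; `ChiefOnlySaver` is hypothetical, while `Callback`, `_trigger`, and `self.trainer` are tensorpack's:

```python
from tensorpack.callbacks import Callback

class ChiefOnlySaver(Callback):    # hypothetical callback for illustration
    def _trigger(self):
        if self.trainer.is_chief:  # the attribute documented above
            print('chief worker: would save a checkpoint here')
```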
tensorpack/train/base.py
@@ -40,8 +40,6 @@ class Trainer(object):
         monitors (Monitors): the monitors. Other callbacks can use it for logging.
     """
 
     _API_VERSION = 1
 
     is_chief = True
-    """
-    Whether this process is the chief worker in distributed training.
-    Certain callbacks will only be run by chief worker.
-    """