Commit 035b3ae0 authored by Yuxin Wu's avatar Yuxin Wu

use named arguments in PredictConfig.

parent e21fc267
## Breaking API changes.
tensorpack is still in early development, and API changes can happen.
Usually the backward compatibilty is preserved for several month, with a deprecation warning.
If you are an early bird to try out this library, you might need to occasionally update your code.
Here are a list of things that were changed, starting from an early version.
TensorFlow itself also changes API and those are not listed here.
* 2017/01/06. `summary.add_moving_summary` now takes any number of positional arguments instead of a list.
See [commit](https://github.com/ppwwyyxx/tensorpack/commit/bbf41d9e58053f843d0471e6d2d87ff714a79a90) to change your code.
* 2017/01/05. The argument `TrainConfig(dataset=)` is renamed to `TrainConfig(dataflow=)`.
See [commit](https://github.com/ppwwyyxx/tensorpack/commit/651a5aea8f9aacad7147542021dcf106fc824bc2) to change your code.
* 2016/11/06. The inferencer `ClassificationError` now expects the vector tensor returned by
`prediction_incorrect` instead of the "wrong" tensor. See [commit](https://github.com/ppwwyyxx/tensorpack/commit/740e9d8ca146af5a911f68a369dd7348243a2253)
to make changes.
* 2016/10/17. `Conv2D` and `FullyConnect` use `tf.identity` by default instead of `tf.nn.relu`.
See [commit](https://github.com/ppwwyyxx/tensorpack/commit/6eb0bebe60d6f38bcad9ddb3e6091b0b154a09cf).
* 2016/09/01. The method `_build_graph` of `ModelDesc` doesn't takes `is_training` argument anymore.
The `is_training` attribute can be obtained from tower context. See [commit](https://github.com/ppwwyyxx/tensorpack/commit/fc9e45b0208ff09daf454d3bd910c540735b7f83).
* 2016/05/15. The method `_get_cost` of `ModelDesc` is replaced by `_build_graph`. See [commit](https://github.com/ppwwyyxx/tensorpack/commit/e69034b5c9b588db9fb52295b1e63c89e8b42654).
...@@ -21,9 +21,13 @@ __all__ = ['StepStatPrinter', 'SummaryMovingAverage', 'ProgressBar'] ...@@ -21,9 +21,13 @@ __all__ = ['StepStatPrinter', 'SummaryMovingAverage', 'ProgressBar']
class StepStatPrinter(Callback): class StepStatPrinter(Callback):
""" It prints the value of some tensors in each step. """ It prints the value of some tensors in each step.
It's just a demo of how trigger_step works but you should in general use It's just a demo of how trigger_step works but you should in general use
:func:`print_stat` or :func:`tf.Print` instead. """ :func:`symbolic_functions.print_stat` or :func:`tf.Print` instead. """
def __init__(self, names): def __init__(self, names):
"""
Args:
names(list): list of string, the names of the tensor to print.
"""
names = [get_op_tensor_name(n)[1] for n in names] names = [get_op_tensor_name(n)[1] for n in names]
logger.warn("Using print_stat or tf.Print in the graph is much faster than StepStatPrinter!") logger.warn("Using print_stat or tf.Print in the graph is much faster than StepStatPrinter!")
self._names = names self._names = names
...@@ -79,4 +83,4 @@ class ProgressBar(Callback): ...@@ -79,4 +83,4 @@ class ProgressBar(Callback):
self._bar = tqdm.trange(self._total, **self._tqdm_args) self._bar = tqdm.trange(self._total, **self._tqdm_args)
self._bar.update() self._bar.update()
if self.step_num == self._total - 1: if self.step_num == self._total - 1:
self._bar.__exit__() self._bar.close()
...@@ -17,6 +17,8 @@ from ..tfutils.tower import get_current_tower_context ...@@ -17,6 +17,8 @@ from ..tfutils.tower import get_current_tower_context
__all__ = ['ModelDesc', 'InputVar', 'ModelFromMetaGraph'] __all__ = ['ModelDesc', 'InputVar', 'ModelFromMetaGraph']
# TODO "variable" is not a right name to use across this file.
class InputVar(object): class InputVar(object):
""" Store metadata about input placeholders. """ """ Store metadata about input placeholders. """
def __init__(self, type, shape, name, sparse=False): def __init__(self, type, shape, name, sparse=False):
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import six import six
from tensorpack.models import ModelDesc from ..models import ModelDesc
from ..tfutils import get_default_sess_config from ..tfutils import get_default_sess_config
from ..tfutils.sessinit import SessionInit, JustCurrentSession from ..tfutils.sessinit import SessionInit, JustCurrentSession
...@@ -12,12 +12,19 @@ __all__ = ['PredictConfig'] ...@@ -12,12 +12,19 @@ __all__ = ['PredictConfig']
class PredictConfig(object): class PredictConfig(object):
def __init__(self, **kwargs): def __init__(self, model, session_init=None,
session_config=get_default_sess_config(0.4),
input_names=None,
output_names=None,
return_input=False):
""" """
Args: Args:
session_init (SessionInit): how to initialize variables of the session.
model (ModelDesc): the model to use. model (ModelDesc): the model to use.
input_names (list): a list of input tensor names. session_init (SessionInit): how to initialize variables of the session.
Defaults to do nothing.
session_config]
input_names (list): a list of input tensor names. Defaults to all
inputs of the model.
output_names (list): a list of names of the output tensors to predict, the output_names (list): a list of names of the output tensors to predict, the
tensors can be any computable tensor in the graph. tensors can be any computable tensor in the graph.
return_input: same as in :attr:`PredictorBase.return_input`. return_input: same as in :attr:`PredictorBase.return_input`.
...@@ -25,34 +32,29 @@ class PredictConfig(object): ...@@ -25,34 +32,29 @@ class PredictConfig(object):
# TODO use the name "tensor" instead of "variable" # TODO use the name "tensor" instead of "variable"
def assert_type(v, tp): def assert_type(v, tp):
assert isinstance(v, tp), v.__class__ assert isinstance(v, tp), v.__class__
self.model = model
assert_type(self.model, ModelDesc)
# XXX does it work? start with minimal memory, but allow growth. # XXX does it work? start with minimal memory, but allow growth.
# allow_growth doesn't seem to work very well in TF. # allow_growth doesn't seem to work very well in TF.
self.session_config = kwargs.pop('session_config', get_default_sess_config(0.4)) self.session_config = session_config
self.session_init = kwargs.pop('session_init', JustCurrentSession()) if session_init is None:
session_init = JustCurrentSession()
self.session_init = session_init
assert_type(self.session_init, SessionInit) assert_type(self.session_init, SessionInit)
self.model = kwargs.pop('model')
assert_type(self.model, ModelDesc)
# inputs & outputs # inputs & outputs
# TODO add deprecated warning later self.input_names = input_names
self.input_names = kwargs.pop('input_names', None)
if self.input_names is None:
self.input_names = kwargs.pop('input_var_names', None)
if self.input_names is not None:
pass
# logger.warn("[Deprecated] input_var_names is deprecated in PredictConfig. Use input_names instead!")
if self.input_names is None: if self.input_names is None:
# neither options is set, assume all inputs # neither options is set, assume all inputs
raw_vars = self.model.get_input_vars_desc() raw_vars = self.model.get_input_vars_desc()
self.input_names = [k.name for k in raw_vars] self.input_names = [k.name for k in raw_vars]
self.output_names = kwargs.pop('output_names', None) self.output_names = output_names
if self.output_names is None: assert_type(self.output_names, list)
self.output_names = kwargs.pop('output_var_names') assert_type(self.input_names, list)
# logger.warn("[Deprecated] output_var_names is deprecated in PredictConfig. Use output_names instead!")
assert len(self.input_names), self.input_names assert len(self.input_names), self.input_names
for v in self.input_names: for v in self.input_names:
assert_type(v, six.string_types) assert_type(v, six.string_types)
assert len(self.output_names), self.output_names assert len(self.output_names), self.output_names
self.return_input = kwargs.pop('return_input', False) self.return_input = bool(return_input)
assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment