Commit 035b3ae0 authored by Yuxin Wu's avatar Yuxin Wu

use named arguments in PredictConfig.

parent e21fc267
## Breaking API changes.
tensorpack is still in early development, and API changes can happen.
Usually the backward compatibilty is preserved for several month, with a deprecation warning.
If you are an early bird to try out this library, you might need to occasionally update your code.
Here are a list of things that were changed, starting from an early version.
TensorFlow itself also changes API and those are not listed here.
* 2017/01/06. `summary.add_moving_summary` now takes any number of positional arguments instead of a list.
See [commit](https://github.com/ppwwyyxx/tensorpack/commit/bbf41d9e58053f843d0471e6d2d87ff714a79a90) to change your code.
* 2017/01/05. The argument `TrainConfig(dataset=)` is renamed to `TrainConfig(dataflow=)`.
See [commit](https://github.com/ppwwyyxx/tensorpack/commit/651a5aea8f9aacad7147542021dcf106fc824bc2) to change your code.
* 2016/11/06. The inferencer `ClassificationError` now expects the vector tensor returned by
`prediction_incorrect` instead of the "wrong" tensor. See [commit](https://github.com/ppwwyyxx/tensorpack/commit/740e9d8ca146af5a911f68a369dd7348243a2253)
to make changes.
* 2016/10/17. `Conv2D` and `FullyConnect` use `tf.identity` by default instead of `tf.nn.relu`.
See [commit](https://github.com/ppwwyyxx/tensorpack/commit/6eb0bebe60d6f38bcad9ddb3e6091b0b154a09cf).
* 2016/09/01. The method `_build_graph` of `ModelDesc` doesn't takes `is_training` argument anymore.
The `is_training` attribute can be obtained from tower context. See [commit](https://github.com/ppwwyyxx/tensorpack/commit/fc9e45b0208ff09daf454d3bd910c540735b7f83).
* 2016/05/15. The method `_get_cost` of `ModelDesc` is replaced by `_build_graph`. See [commit](https://github.com/ppwwyyxx/tensorpack/commit/e69034b5c9b588db9fb52295b1e63c89e8b42654).
......@@ -21,9 +21,13 @@ __all__ = ['StepStatPrinter', 'SummaryMovingAverage', 'ProgressBar']
class StepStatPrinter(Callback):
""" It prints the value of some tensors in each step.
It's just a demo of how trigger_step works but you should in general use
:func:`print_stat` or :func:`tf.Print` instead. """
:func:`symbolic_functions.print_stat` or :func:`tf.Print` instead. """
def __init__(self, names):
"""
Args:
names(list): list of string, the names of the tensor to print.
"""
names = [get_op_tensor_name(n)[1] for n in names]
logger.warn("Using print_stat or tf.Print in the graph is much faster than StepStatPrinter!")
self._names = names
......@@ -79,4 +83,4 @@ class ProgressBar(Callback):
self._bar = tqdm.trange(self._total, **self._tqdm_args)
self._bar.update()
if self.step_num == self._total - 1:
self._bar.__exit__()
self._bar.close()
......@@ -17,6 +17,8 @@ from ..tfutils.tower import get_current_tower_context
__all__ = ['ModelDesc', 'InputVar', 'ModelFromMetaGraph']
# TODO "variable" is not a right name to use across this file.
class InputVar(object):
""" Store metadata about input placeholders. """
def __init__(self, type, shape, name, sparse=False):
......
......@@ -4,7 +4,7 @@
import six
from tensorpack.models import ModelDesc
from ..models import ModelDesc
from ..tfutils import get_default_sess_config
from ..tfutils.sessinit import SessionInit, JustCurrentSession
......@@ -12,12 +12,19 @@ __all__ = ['PredictConfig']
class PredictConfig(object):
def __init__(self, **kwargs):
def __init__(self, model, session_init=None,
session_config=get_default_sess_config(0.4),
input_names=None,
output_names=None,
return_input=False):
"""
Args:
session_init (SessionInit): how to initialize variables of the session.
model (ModelDesc): the model to use.
input_names (list): a list of input tensor names.
session_init (SessionInit): how to initialize variables of the session.
Defaults to do nothing.
session_config]
input_names (list): a list of input tensor names. Defaults to all
inputs of the model.
output_names (list): a list of names of the output tensors to predict, the
tensors can be any computable tensor in the graph.
return_input: same as in :attr:`PredictorBase.return_input`.
......@@ -25,34 +32,29 @@ class PredictConfig(object):
# TODO use the name "tensor" instead of "variable"
def assert_type(v, tp):
assert isinstance(v, tp), v.__class__
self.model = model
assert_type(self.model, ModelDesc)
# XXX does it work? start with minimal memory, but allow growth.
# allow_growth doesn't seem to work very well in TF.
self.session_config = kwargs.pop('session_config', get_default_sess_config(0.4))
self.session_init = kwargs.pop('session_init', JustCurrentSession())
self.session_config = session_config
if session_init is None:
session_init = JustCurrentSession()
self.session_init = session_init
assert_type(self.session_init, SessionInit)
self.model = kwargs.pop('model')
assert_type(self.model, ModelDesc)
# inputs & outputs
# TODO add deprecated warning later
self.input_names = kwargs.pop('input_names', None)
if self.input_names is None:
self.input_names = kwargs.pop('input_var_names', None)
if self.input_names is not None:
pass
# logger.warn("[Deprecated] input_var_names is deprecated in PredictConfig. Use input_names instead!")
self.input_names = input_names
if self.input_names is None:
# neither options is set, assume all inputs
raw_vars = self.model.get_input_vars_desc()
self.input_names = [k.name for k in raw_vars]
self.output_names = kwargs.pop('output_names', None)
if self.output_names is None:
self.output_names = kwargs.pop('output_var_names')
# logger.warn("[Deprecated] output_var_names is deprecated in PredictConfig. Use output_names instead!")
self.output_names = output_names
assert_type(self.output_names, list)
assert_type(self.input_names, list)
assert len(self.input_names), self.input_names
for v in self.input_names:
assert_type(v, six.string_types)
assert len(self.output_names), self.output_names
self.return_input = kwargs.pop('return_input', False)
assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
self.return_input = bool(return_input)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment