Commit 6a0bba68 authored by Yuxin Wu

Move ModelDesc inside train/

parent 71c879bc
@@ -418,6 +418,7 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
     # Hide some names that are deprecated or not intended to be used
     if name in _DEPRECATED_NAMES:
         return True
     if name in ['__iter__', '__len__', 'reset_state', 'get_data', 'size']:
         # skip these methods with empty docstring
         if not obj.__doc__ and inspect.isfunction(obj):
......
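For readers unfamiliar with this hook: Sphinx exposes an `autodoc-skip-member` event, and a handler like the one above decides whether a member gets documented. A minimal sketch of how such a handler is wired up — the `setup` function and the contents of `_DEPRECATED_NAMES` below are illustrative assumptions, not part of this commit:

```python
# Hypothetical minimal conf.py excerpt mirroring the handler in the hunk above.
_DEPRECATED_NAMES = {'get_inputs_desc', 'inputs_desc'}

def autodoc_skip_member(app, what, name, obj, skip, options):
    # Returning True tells autodoc to hide this member from the docs.
    if name in _DEPRECATED_NAMES:
        return True
    return skip  # otherwise defer to Sphinx's default decision

def setup(app):
    # Sphinx calls the handler once for every documented member.
    app.connect('autodoc-skip-member', autodoc_skip_member)
```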
@@ -2,7 +2,7 @@ tensorpack.graph_builder package
 ================================
 These are some useful functions if you need to write your own trainers.
-Note that they may not be well maintained.
+Otherwise you probably don't need to use them.
 .. automodule:: tensorpack.graph_builder
     :members:
......
@@ -32,7 +32,9 @@ The `official TensorFlow benchmark <https://github.com/tensorflow/benchmarks/tre
 which seems to suggest that you cannot have **performance and ease-of-use together**.
 However you can have them both in tensorpack.
-Tensorpack uses TensorFlow efficiently, and hides performance details under its APIs.
+Tensorpack
+`uses TensorFlow efficiently <https://github.com/tensorpack/benchmarks/>`_,
+and hides performance details under its APIs.
 You no longer need to write
 data prefetch, multi-GPU replication, device placement, variables synchronization -- anything that's unrelated to the model itself.
 You still need to understand graph and learn to write models with TF, but performance is all taken care of by tensorpack.
@@ -48,11 +50,11 @@ A High Level Glance
   They will eventually be wrapped under the same ``InputSource`` interface and go through prefetching.
 * You can use any TF-based symbolic function library to define a model, including
-  a small set of functions within tensorpack. ``ModelDesc`` is an interface to connect the model with the
-  ``InputSource`` interface.
+  a small set of functions within tensorpack. ``ModelDesc`` is an interface to connect
+  the model with the trainers, but you can also use trainers without ``ModelDesc``.
 * Tensorpack trainers manage the training loops for you.
-  They also include data parallel logic for multi-GPU or distributed training.
+  They also include data parallel logic for multi-GPU and distributed training.
   At the same time, you have the power of customization through callbacks.
 * Callbacks are like ``tf.train.SessionRunHook``, or plugins. During training,
......
@@ -3,13 +3,13 @@
 There are many other data loading solutions for deep learning.
 Here we explain why you may want to use Tensorpack DataFlow for your own good:
-it's easy, and fast (enough).
+**it's easy, and fast (enough)**.
 Note that this article may contain subjective opinions and we're happy to hear different voices.

 ### How Fast Do You Actually Need?
-Your data pipeline **only has to be fast enough**.
+Your data pipeline **only needs to be fast enough**.
 In practice, you should always first make sure your data pipeline runs
 asynchronously with your training.
@@ -20,7 +20,7 @@ interface.
 Once you make sure the data pipeline runs async with your training,
 the data pipeline only needs to be as fast as the training.
 **Getting faster brings no gains** to overall throughput.
-It only has to be fast enough.
+It only needs to be fast enough.
 If you have used other data loading libraries, you may doubt
 how easy it is to make data pipeline fast enough with pure Python.
@@ -86,11 +86,10 @@ On the other hand, DataFlow is:
 1. **Easy**: Any Python function that produces data can be made a DataFlow and
    used for training. No need for intermediate format when you don't.
-1. **Flexible**: Since it is in pure Python, you still have the choice to use
-   a different data format when you need.
-   And we have provided tools to easily
-   [serialize a DataFlow](../../modules/dataflow.html#tensorpack.dataflow.LMDBSerializer)
-   to a single-file binary format when you need.
+1. **Flexible**: Since it is in pure Python, you can use any data format.
+   When you need, you can still easily serialize your dataflow to a single-file
+   format with
+   [a few lines of code](../../modules/dataflow.html#tensorpack.dataflow.LMDBSerializer).

 ### Alternative Data Loading Solutions:
......
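As a concrete illustration of the "easy" claim above: any plain Python generator can be wrapped into a DataFlow. A minimal sketch — `DataFromGenerator` is an existing tensorpack class, while the data itself is made up:

```python
import numpy as np
from tensorpack.dataflow import DataFromGenerator

def random_datapoints():
    # Any Python generator yielding datapoints (lists of components) will do.
    while True:
        yield [np.random.rand(28, 28), np.random.randint(10)]

df = DataFromGenerator(random_datapoints)
df.reset_state()            # a DataFlow must be reset once before iteration
first = next(iter(df))      # [image, label]
```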
@@ -19,6 +19,6 @@ if STATICA_HACK:
     from tensorpack.tfutils import *
     from tensorpack.train import *
-    from tensorpack.graph_builder import InputDesc, ModelDesc, ModelDescBase
+    from tensorpack.graph_builder import InputDesc  # kept for BC
     from tensorpack.input_source import *
     from tensorpack.predict import *
@@ -224,7 +224,7 @@ def setup_keras_trainer(
 class KerasModel(object):
     def __init__(self, get_model, input_signature=None, target_signature=None,
-                 input=None, trainer=None, inputs_desc=None, targets_desc=None):
+                 input=None, trainer=None):
         """
         Args:
             get_model (input1, input2, ... -> keras.Model):
@@ -234,12 +234,7 @@ class KerasModel(object):
             target_signature ([tf.TensorSpec]): required. The signature for the targets tensors.
             input (InputSource | DataFlow): the InputSource or DataFlow where the input data comes from.
             trainer (Trainer): the default will check the number of available GPUs and use them all.
-            inputs_desc, targets_desc: deprecated names for `input_signature` and `target_signature`
         """
-        if inputs_desc is not None:
-            input_signature = inputs_desc
-        if targets_desc is not None:
-            target_signature = targets_desc
         self.get_model = get_model
         assert callable(get_model), get_model
         self.input_signature = input_signature
......
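After this change the deprecated `inputs_desc`/`targets_desc` keyword arguments are gone and callers pass the signatures directly. A hedged sketch of the updated constructor call, assuming a toy model function and tensorpack's `FakeData` dataflow for random inputs (the model itself is made up):

```python
import tensorflow as tf
from tensorpack.contrib.keras import KerasModel
from tensorpack.dataflow import FakeData

def get_model(images):
    # Per the docstring above: receives input tensors, returns a keras.Model.
    inp = tf.keras.layers.Input(tensor=images)
    x = tf.keras.layers.Flatten()(inp)
    x = tf.keras.layers.Dense(10, activation='softmax')(x)
    return tf.keras.models.Model(inputs=inp, outputs=x)

model = KerasModel(
    get_model,
    input_signature=[tf.TensorSpec([None, 28, 28, 1], tf.float32, 'images')],
    target_signature=[tf.TensorSpec([None, 10], tf.float32, 'labels')],
    input=FakeData([[32, 28, 28, 1], [32, 10]], 100))  # random data, illustration only
```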
@@ -33,6 +33,14 @@ class LMDBSerializer():
     are serialized datapoints.
     You will need to ``pip install lmdb`` to use it.
+
+    Example:
+
+    .. code-block:: python
+
+        LMDBSerializer.save(my_df, "output.lmdb")
+
+        new_df = LMDBSerializer.load("output.lmdb", shuffle=True)
     """

     @staticmethod
     def save(df, path, write_frequency=5000):
......
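The added docstring example, expanded into a self-contained round trip; the dataset contents are made up, and `pip install lmdb` is required, as the docstring notes:

```python
from tensorpack.dataflow import DataFromList, LMDBSerializer

df = DataFromList([[i, i * i] for i in range(100)], shuffle=False)
LMDBSerializer.save(df, "output.lmdb")                  # writes every datapoint once
new_df = LMDBSerializer.load("output.lmdb", shuffle=True)
new_df.reset_state()
print(next(iter(new_df)))                               # one datapoint, e.g. [3, 9]
```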
@@ -10,6 +10,8 @@ if STATICA_HACK:
     from .distributed import *
     from .utils import *

+from .model_desc import InputDesc, ModelDesc, ModelDescBase  # kept for BC # noqa
+
 from pkgutil import iter_modules
 import os
 import os.path
......
@@ -5,16 +5,11 @@
 from collections import namedtuple

 import tensorflow as tf

-from ..utils.develop import log_deprecated, HIDE_DOC
-from ..utils.argtools import memoized_method
-from ..tfutils.common import get_op_tensor_name
-from ..tfutils.tower import get_current_tower_context
-from ..compat import backport_tensor_spec, tfv1
-
-TensorSpec = backport_tensor_spec()
-
-__all__ = ['InputDesc', 'ModelDesc', 'ModelDescBase']
+from ..utils.develop import log_deprecated
+from ..train.model_desc import ModelDesc, ModelDescBase  # kept for BC # noqa
+
+__all__ = ['InputDesc']


 class InputDesc(
@@ -39,116 +34,3 @@ class InputDesc(
         log_deprecated("InputDesc", "Use tf.TensorSpec instead!", "2020-03-01")
         assert isinstance(type, tf.DType), type
         return tf.TensorSpec(shape=shape, dtype=type, name=name)
-
-
-class ModelDescBase(object):
-    """
-    Base class for a model description.
-    """
-
-    @HIDE_DOC
-    def get_inputs_desc(self):
-        log_deprecated("ModelDesc.get_inputs_desc", "Use get_input_signature instead!", "2020-03-01")
-        return self.get_input_signature()
-
-    @memoized_method
-    def get_input_signature(self):
-        """
-        Returns:
-            A list of :class:`tf.TensorSpec`, which describes the inputs of this model.
-            The result is cached for each instance of :class:`ModelDescBase`.
-        """
-        with tf.Graph().as_default() as G:   # create these placeholder in a temporary graph
-            inputs = self.inputs()
-            assert isinstance(inputs, (list, tuple)), \
-                "ModelDesc.inputs() should return a list of tf.TensorSpec objects! Got {} instead.".format(str(inputs))
-            if isinstance(inputs[0], tf.Tensor):
-                for p in inputs:
-                    assert "Placeholder" in p.op.type, \
-                        "inputs() have to return TensorSpec or placeholders! Found {} instead.".format(p)
-                    assert p.graph == G, "Placeholders returned by inputs() should be created inside inputs()!"
-        return [TensorSpec(shape=p.shape, dtype=p.dtype, name=get_op_tensor_name(p.name)[0]) for p in inputs]
-
-    @property
-    def input_names(self):
-        """
-        list[str]: the names of all the inputs.
-        """
-        return [k.name for k in self.get_input_signature()]
-
-    def inputs(self):
-        """
-        Returns a list of :class:`tf.TensorSpec` or placeholders.
-        A subclass is expected to implement this method.
-
-        If returning placeholders,
-        the placeholders **have to** be created inside this method.
-        Don't return placeholders created in other places.
-        Also, you should never call this method by yourself.
-
-        Returns:
-            list[tf.TensorSpec or tf.placeholder]. To be converted to :class:`tf.TensorSpec`.
-        """
-        raise NotImplementedError()
-
-    def build_graph(self, *args):
-        """
-        Build the whole symbolic graph.
-        This is supposed to be part of the "tower function" when used with :class:`TowerTrainer`.
-
-        A subclass is expected to implement this method.
-
-        Args:
-            args ([tf.Tensor]): tensors that matches the list of inputs defined by ``inputs()``.
-
-        Returns:
-            In general it returns nothing, but a subclass
-            may require it to return necessary information to build the trainer.
-            For example, `SingleCostTrainer` expect this method to return the cost tensor.
-        """
-        raise NotImplementedError()
-
-    @property
-    def training(self):
-        """
-        bool: whether the caller is under a training context or not.
-        """
-        return get_current_tower_context().is_training
-
-
-class ModelDesc(ModelDescBase):
-    """
-    A ModelDesc with **single cost** and **single optimizer**.
-    It has the following constraints in addition to :class:`ModelDescBase`:
-
-    1. :meth:`build_graph(...)` method should return a cost when called under a training context.
-       The cost will be the final cost to be optimized by the optimizer.
-       Therefore it should include necessary regularization.
-
-    2. Subclass is expected to implement :meth:`optimizer()` method.
-    """
-
-    @memoized_method
-    def get_optimizer(self):
-        """
-        Return the memoized optimizer returned by `optimizer()`.
-
-        Users of :class:`ModelDesc` will need to implement `optimizer()`,
-        which will only be called once per each model.
-
-        Returns:
-            a :class:`tf.train.Optimizer` instance.
-        """
-        ret = self.optimizer()
-        assert isinstance(ret, tfv1.train.Optimizer), \
-            "ModelDesc.optimizer() must return a tf.train.Optimizer! Got {} instead.".format(str(ret))
-        return ret
-
-    def optimizer(self):
-        """
-        Returns a `tf.train.Optimizer` instance.
-        A subclass is expected to implement this method.
-        """
-        raise NotImplementedError()
@@ -5,12 +5,13 @@
 import six

 from ..compat import tfv1 as tf
-from ..graph_builder import ModelDescBase
+from ..train.model_desc import ModelDescBase
 from ..tfutils import get_default_sess_config
 from ..tfutils.sessinit import JustCurrentSession, SessionInit
 from ..tfutils.sesscreate import NewSessionCreator
 from ..tfutils.tower import TowerFunc
 from ..utils import logger
+from ..utils.develop import log_deprecated

 __all__ = ['PredictConfig']
@@ -77,7 +78,7 @@ class PredictConfig(object):
                 name, tp.__name__, v.__class__.__name__)
         if inputs_desc is not None:
-            # TODO warn deprecated or not?
+            log_deprecated("PredictConfig(inputs_desc)", "Use input_signature instead!", "2020-03-01")
             assert input_signature is None, "Cannot set both inputs_desc and input_signature!"
             input_signature = inputs_desc
......
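With this change, passing `inputs_desc` still works until 2020-03-01 but logs a warning; new code should use `input_signature`. A hedged sketch of a migrated call — `my_tower_fn` and the tensor names are hypothetical, and the exact set of accepted `PredictConfig` kwargs is assumed from the surrounding code:

```python
import tensorflow as tf
from tensorpack.predict import PredictConfig, OfflinePredictor

# Before this commit: PredictConfig(tower_func=my_tower_fn, inputs_desc=[spec], ...)
# After (the old spelling still works, with a dated deprecation warning):
config = PredictConfig(
    tower_func=my_tower_fn,   # hypothetical function building the graph from inputs
    input_signature=[tf.TensorSpec([None, 224, 224, 3], tf.float32, 'image')],
    input_names=['image'],
    output_names=['prob'])
predictor = OfflinePredictor(config)
```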
@@ -10,7 +10,7 @@ from six.moves import zip
 from ..compat import tfv1 as tf
 from ..utils import logger
 from ..utils.argtools import call_only_once
-from ..utils.develop import HIDE_DOC
+from ..utils.develop import HIDE_DOC, log_deprecated
 from ..utils.naming import MOVING_SUMMARY_OPS_KEY
 from .collection import CollectionGuard
 from .common import get_op_or_tensor_by_name, get_op_tensor_name
@@ -309,7 +309,7 @@ class TowerFunc(object):
     @property
     def inputs_desc(self):
-        # TODO mark deprecated
+        log_deprecated("TowerFunc.inputs_desc", "Use .input_signature instead", "2020-03-01")
         return self._input_signature
......
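The pattern applied here (and again in `TowerTrainer` below) is a small deprecation shim: keep the old property, log a dated warning, forward to the new name. A generic sketch of the same pattern — the `Thing` class is illustrative, only `log_deprecated` is the real tensorpack helper:

```python
from tensorpack.utils.develop import log_deprecated

class Thing(object):
    """Illustrative class using the same shim as TowerFunc above."""
    def __init__(self, input_signature):
        self._input_signature = input_signature

    @property
    def input_signature(self):
        return self._input_signature

    @property
    def inputs_desc(self):
        # Old name: warn with the planned end-of-life date, then behave as before.
        log_deprecated("Thing.inputs_desc", "Use .input_signature instead", "2020-03-01")
        return self._input_signature
```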
@@ -7,12 +7,13 @@ import tensorflow as tf
 from ..callbacks import (
     JSONWriter, MergeAllSummaries, MovingAverageSummary, ProgressBar, RunUpdateOps, ScalarPrinter, TFEventWriter)
 from ..dataflow.base import DataFlow
-from ..graph_builder.model_desc import ModelDescBase
 from ..input_source import InputSource
 from ..tfutils.sesscreate import NewSessionCreator
 from ..tfutils.sessinit import SaverRestore, SessionInit
 from ..utils import logger
+from .model_desc import ModelDescBase

 __all__ = ['TrainConfig', 'AutoResumeTrainConfig', 'DEFAULT_CALLBACKS', 'DEFAULT_MONITORS']
......
+# -*- coding: utf-8 -*-
+# File: model_desc.py
+
+import tensorflow as tf
+
+from ..utils.develop import log_deprecated, HIDE_DOC
+from ..utils.argtools import memoized_method
+from ..tfutils.common import get_op_tensor_name
+from ..tfutils.tower import get_current_tower_context
+from ..compat import backport_tensor_spec, tfv1
+
+TensorSpec = backport_tensor_spec()
+
+__all__ = ['ModelDesc', 'ModelDescBase']
+
+
+class ModelDescBase(object):
+    """
+    Base class for a model description.
+
+    It is used for the simple training interface described in
+    `Training Interface Tutorial <https://tensorpack.readthedocs.io/tutorial/training-interface.html>`_.
+
+    A subclass is expected to implement :meth:`inputs` and :meth:`build_graph`, as they
+    together define a tower function.
+    """
+
+    @HIDE_DOC
+    def get_inputs_desc(self):
+        log_deprecated("ModelDesc.get_inputs_desc", "Use get_input_signature instead!", "2020-03-01")
+        return self.get_input_signature()
+
+    @memoized_method
+    def get_input_signature(self):
+        """
+        Returns:
+            A list of :class:`tf.TensorSpec`, which describes the inputs of this model.
+            The result is cached for each instance of :class:`ModelDescBase`.
+        """
+        with tf.Graph().as_default() as G:   # create these placeholders in a temporary graph
+            inputs = self.inputs()
+            assert isinstance(inputs, (list, tuple)), \
+                "ModelDesc.inputs() should return a list of tf.TensorSpec objects! Got {} instead.".format(str(inputs))
+            if isinstance(inputs[0], tf.Tensor):
+                for p in inputs:
+                    assert "Placeholder" in p.op.type, \
+                        "inputs() have to return TensorSpec or placeholders! Found {} instead.".format(p)
+                    assert p.graph == G, "Placeholders returned by inputs() should be created inside inputs()!"
+        return [TensorSpec(shape=p.shape, dtype=p.dtype, name=get_op_tensor_name(p.name)[0]) for p in inputs]
+
+    @property
+    def input_names(self):
+        """
+        list[str]: the names of all the inputs.
+        """
+        return [k.name for k in self.get_input_signature()]
+
+    def inputs(self):
+        """
+        A subclass is expected to implement this method.
+
+        If returning placeholders,
+        the placeholders **have to** be created inside this method.
+        Don't return placeholders created in other places.
+        Also, users should never call this method themselves.
+
+        Returns:
+            list[tf.TensorSpec or tf.placeholder].
+        """
+        raise NotImplementedError()
+
+    def build_graph(self, *args):
+        """
+        A subclass is expected to implement this method.
+
+        Build the whole symbolic graph.
+        This is supposed to be part of the "tower function" when used with :class:`TowerTrainer`.
+
+        Args:
+            args ([tf.Tensor]): tensors that match the list of inputs defined by ``inputs()``.
+
+        Returns:
+            In general it returns nothing, but a subclass
+            may require it to return necessary information to build the trainer.
+            For example, `SingleCostTrainer` expects this method to return the cost tensor.
+        """
+        raise NotImplementedError()
+
+    @property
+    def training(self):
+        """
+        bool: whether the caller is under a training context or not.
+        """
+        return get_current_tower_context().is_training
+
+
+class ModelDesc(ModelDescBase):
+    """
+    A subclass of :class:`ModelDescBase` with the assumption of
+    **single cost** and **single optimizer** training.
+
+    It has the following constraints in addition to :class:`ModelDescBase`:
+
+    1. `build_graph(...)` method should return a cost tensor when called under a training context.
+       The cost will be the final cost to be optimized by the optimizer.
+       Therefore it should include necessary regularization.
+
+    2. Subclass is expected to implement :meth:`optimizer()` method.
+    """
+
+    @memoized_method
+    def get_optimizer(self):
+        """
+        Return the memoized optimizer returned by `optimizer()`.
+
+        Users of :class:`ModelDesc` will need to implement `optimizer()`,
+        which will only be called once per model.
+
+        Returns:
+            a :class:`tf.train.Optimizer` instance.
+        """
+        ret = self.optimizer()
+        assert isinstance(ret, tfv1.train.Optimizer), \
+            "ModelDesc.optimizer() must return a tf.train.Optimizer! Got {} instead.".format(str(ret))
+        return ret
+
+    def optimizer(self):
+        """
+        A subclass is expected to implement this method.
+
+        Returns:
+            a `tf.train.Optimizer` instance.
+        """
+        raise NotImplementedError()
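To make the moved interface concrete, here is a minimal `ModelDesc` subclass wiring the three methods together, as the docstrings above describe; the shapes, layers, and learning rate are invented for illustration:

```python
import tensorflow as tf
from tensorpack import ModelDesc

class MyModel(ModelDesc):
    def inputs(self):
        # TensorSpecs describing the model inputs (placeholders would also work).
        return [tf.TensorSpec([None, 28, 28], tf.float32, 'image'),
                tf.TensorSpec([None], tf.int32, 'label')]

    def build_graph(self, image, label):
        # The tower function body: args match inputs() by position.
        logits = tf.layers.dense(tf.layers.flatten(image), 10)
        cost = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=label, logits=logits))
        return cost   # under a training context, the returned cost gets optimized

    def optimizer(self):
        return tf.train.GradientDescentOptimizer(1e-2)
```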
@@ -12,7 +12,7 @@ from ..tfutils.gradproc import FilterNoneGrad
 from ..tfutils.tower import PredictTowerContext, TowerFunc, get_current_tower_context
 from ..utils import logger
 from ..utils.argtools import call_only_once, memoized
-from ..utils.develop import HIDE_DOC
+from ..utils.develop import HIDE_DOC, log_deprecated
 from .base import Trainer

 __all__ = ['SingleCostTrainer', 'TowerTrainer']
@@ -22,11 +22,12 @@ class TowerTrainer(Trainer):
     """
     Base trainers for models that can be built by calling a tower function under a :class:`TowerContext`.

-    This is required by some features that replicates the model
-    automatically, e.g. creating a predictor.
+    The assumption of a tower function is required by some features that replicate the model
+    automatically. For example, TowerTrainer can create a predictor for you automatically,
+    by calling the tower function.

-    To use features of :class:`TowerTrainer`, set `tower_func` and use it to build the graph.
-    Note that `tower_func` can only be set once per instance.
+    To use :class:`TowerTrainer`, set `tower_func` and use it to build the graph.
+    Note that `tower_func` can only be set once per instance of `TowerTrainer`.
     """

     _tower_func = None
@@ -56,25 +57,22 @@ class TowerTrainer(Trainer):
     @property
     def inputs_desc(self):
-        # TODO mark deprecated
+        log_deprecated("TowerTrainer.inputs_desc", "Use .input_signature instead!", "2020-03-01")
         return self.input_signature

     @property
     def input_signature(self):
         """
-        Returns:
-            list[tf.TensorSpec]: metainfo about the inputs to the tower.
+        list[tf.TensorSpec]: metainfo about the inputs to the tower.
         """
         return self.tower_func.input_signature

     @property
     def towers(self):
         """
-        Returns:
-            a :class:`TowerTensorHandles` object, to
-            access the tower handles by either indices or names.
+        TowerTensorHandles: used to access the tower handles by either indices or names.

-        It is accessbile only after the graph is set up.
+        This property is accessible only after the graph is set up.
         With :meth:`towers`, you can then access many attributes of each tower:

         Example:
@@ -91,7 +89,8 @@ class TowerTrainer(Trainer):
         This method will build the trainer's tower function under ``TowerContext(is_training=False)``,
         and returns a callable predictor with input placeholders & output tensors in this tower.

-        This method handles the common case of inference with the same tower function.
+        This method handles the common case where you run inference with the same tower function
+        you provide to the trainer.
         If you want to do inference with a different tower function, you can always build the tower by yourself,
         under a "reuse" variable scope and a `TowerContext(is_training=False)`.
@@ -205,7 +204,7 @@ class SingleCostTrainer(TowerTrainer):
         Args:
             input_signature ([TensorSpec]): list of TensorSpec that describe the inputs
-            input (InputSource):
+            input (InputSource): an InputSource which has to match the input signature
             get_cost_fn ([tf.Tensor] -> tf.Tensor): callable, takes some input tensors and return a cost tensor.
             get_opt_fn (-> tf.train.Optimizer): callable which returns an
                 optimizer. Will only be called once.
......
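Finally, the predictor workflow these docstrings describe, as a short hedged sketch: it assumes `trainer` is a `TowerTrainer` whose graph is already built, and that the tower defines tensors named 'image' and 'prob' (hypothetical names):

```python
# Reuses the trainer's own tower function under TowerContext(is_training=False).
predictor = trainer.get_predictor(['image'], ['prob'])  # input names, output names
prob, = predictor(batch_of_images)                      # batch_of_images: numpy array
```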