Commit 6262f719 authored by Yuxin Wu's avatar Yuxin Wu

docs / api cleanup

parent 7bdaf8ec
......@@ -4,7 +4,7 @@
Most neural network training tasks are single-cost optimization.
Tensorpack provides some trainer implementations for such tasks.
These trainers will take care of step 1 (define the graph), with the following arguments:
These trainers will take care help you define the graph, with the following arguments:
1. Some `tf.TensorSpec`, the signature of the input.
2. An `InputSource`, where the input come from. See [Input Pipeline](input-source.html).
......
......@@ -2,7 +2,7 @@
# Performance Tuning
__We do not know why your training is slow__
(and most of the times it's not due to issues in tensorpack),
(and most of the times it's not due to tensorpack),
unless we can reproduce the slowness with your instsructions.
Tensorpack is designed to be high-performance, as can be seen in the [benchmarks](https://github.com/tensorpack/benchmarks).
......
......@@ -65,7 +65,7 @@ See [config.py](config.py) for details about how to correctly set `BACKBONE.WEIG
### Inference:
To predict on an image (needs DISPLAY to show the outputs):
To predict on given images (needs DISPLAY to show the outputs):
```
./predict.py --predict input1.jpg input2.jpg --load /path/to/Trained-Model-Checkpoint --config SAME-AS-TRAINING
```
......
......@@ -5,6 +5,7 @@
from collections import namedtuple
import tensorflow as tf
from ..utils.develop import log_deprecated, HIDE_DOC
from ..utils.argtools import memoized_method
from ..tfutils.common import get_op_tensor_name
from ..tfutils.tower import get_current_tower_context
......@@ -74,9 +75,9 @@ class ModelDescBase(object):
Base class for a model description.
"""
@memoized_method
@HIDE_DOC
def get_inputs_desc(self):
# TODO mark deprecated
log_deprecated("ModelDesc.get_inputs_desc", "Use get_input_signature instead!", "2020-03-01")
return self.get_input_signature()
@memoized_method
......@@ -100,8 +101,7 @@ class ModelDescBase(object):
@property
def input_names(self):
"""
Returns:
[str]: the names of all the inputs.
list[str]: the names of all the inputs.
"""
return [k.name for k in self.get_input_signature()]
......@@ -111,7 +111,7 @@ class ModelDescBase(object):
A subclass is expected to implement this method.
If returning placeholders,
the placeholders __have to__ be created inside this method.
the placeholders **have to** be created inside this method.
Don't return placeholders created in other places.
Also, you should never call this method by yourself.
......@@ -141,7 +141,6 @@ class ModelDescBase(object):
@property
def training(self):
"""
Returns:
bool: whether the caller is under a training context or not.
"""
return get_current_tower_context().is_training
......
......@@ -116,8 +116,8 @@ class DataParallelBuilder(GraphBuilder):
ret.append(func())
return ret
@HIDE_DOC
@staticmethod
@HIDE_DOC
def build_on_towers(*args, **kwargs):
return DataParallelBuilder.call_for_each_tower(*args, **kwargs)
......
......@@ -47,38 +47,35 @@ class BaseTowerContext(object):
@abstractproperty
def is_main_training_tower(self):
"""
Whether this tower is the main (i.e., the first) training tower.
bool: Whether this tower is the main (i.e., the first) training tower.
"""
pass
@abstractproperty
def has_own_variables(self):
"""
Whether this tower is supposed to have its own trainable variables.
bool: Whether this tower is supposed to have its own trainable variables.
"""
pass
@property
def name(self):
"""
Returns:
str - The name scope of the tower.
str: The name scope of the tower.
"""
return self._name
@property
def vs_name(self):
"""
Returns:
str - The variable scope of the tower.
str: The variable scope of the tower.
"""
return self._vs_name
@property
def ns_name(self):
"""
Returns:
str - The name scope of the tower.
str: The name scope of the tower.
"""
return self._name
......@@ -157,10 +154,15 @@ class BaseTowerContext(object):
return "TowerContext(name={}, is_training={})".format(
self._name, self._is_training)
@property
def is_training(self):
"""
bool: whether the context is training or not
"""
return self._is_training
class TrainTowerContext(BaseTowerContext):
is_training = True
class TrainTowerContext(BaseTowerContext):
def __init__(self, ns_name, vs_name='', index=0, total=1):
"""
......@@ -169,6 +171,7 @@ class TrainTowerContext(BaseTowerContext):
total (int): total number of towers to be built.
"""
super(TrainTowerContext, self).__init__(ns_name, vs_name)
self._is_training = True
self.index = int(index)
self.total = int(total)
......@@ -196,11 +199,9 @@ class TrainTowerContext(BaseTowerContext):
class PredictTowerContext(BaseTowerContext):
is_training = False
def __init__(self, ns_name, vs_name=''):
super(PredictTowerContext, self).__init__(ns_name, vs_name)
self._is_training = False
self._initial_vs_reuse = tf.get_variable_scope().reuse
......@@ -249,7 +250,8 @@ def TowerContext(tower_name, is_training, vs_name=''):
class TowerFunc(object):
"""
A tower function (see
[tutorial on tower function](http://tensorpack.readthedocs.io/tutorial/trainer.html#tower-trainer)).
`tutorial on tower function
<http://tensorpack.readthedocs.io/tutorial/extend/trainer.html#tower-trainer>`_)
It keeps track of the name scope, variable scope and input/output tensors
each time the function is called.
......@@ -296,8 +298,7 @@ class TowerFunc(object):
@property
def towers(self):
"""
Returns:
a :class:`TowerTensorHandles` object, that can
TowerTensorHandles: a :class:`TowerTensorHandles` object, that can
access the tower handles by either indices or names.
"""
return TowerTensorHandles(self._handles)
......@@ -366,7 +367,7 @@ class TowerTensorHandle(object):
"""
@HIDE_DOC
def __init__(self, ctx, input, output, input_signature=None):
def __init__(self, ctx, inputs, outputs, input_signature=None):
self._ctx = ctx
self._extra_tensor_names = {}
......@@ -374,8 +375,12 @@ class TowerTensorHandle(object):
assert len(input_signature) == len(input)
self._extra_tensor_names = {
get_op_tensor_name(x.name)[1]: y for x, y in zip(input_signature, input)}
self._input = input
self._output = output
self._inputs = inputs
self._outputs = outputs
# TODO: deprecated. Remove them later
self.input = inputs
self.output = outputs
@property
def vs_name(self):
......@@ -465,18 +470,18 @@ class TowerTensorHandle(object):
return self._ctx.get_collection_in_tower(key)
@property
def input(self):
def inputs(self):
"""
The list of input tensors used to build the tower.
list[Tensor]: The list of input tensors used to build the tower.
"""
return self._input
return self._inputs
@property
def output(self):
def outputs(self):
"""
The output returned by the tower function.
list[Tensor]: The outputs returned by the tower function.
"""
return self._output
return self._outputs
@property
def is_training(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment