Commit 21f74399 authored by Yuxin Wu's avatar Yuxin Wu

update docs; type check

parent bbac5cf1
...@@ -18,12 +18,6 @@ TensorFlow itself also changes API and those are not listed here. ...@@ -18,12 +18,6 @@ TensorFlow itself also changes API and those are not listed here.
return [tf.TensorSpec((None, 28, 28, 1), tf.float32, 'image'), return [tf.TensorSpec((None, 28, 28, 1), tf.float32, 'image'),
tf.TensorSpec((None,), tf.int32, 'label')] tf.TensorSpec((None,), tf.int32, 'label')]
``` ```
+ [2018/08/27] msgpack is used for "serialization to disk", because pyarrow
has no compatibility between versions. To use pyarrow instead, `export TENSORPACK_COMPATIBLE_SERIALIZE=pyarrow`.
+ [2018/04/05] <del>msgpack is replaced by pyarrow in favor of its speed. If you want old behavior,
`export TENSORPACK_SERIALIZE=msgpack`.</del>
It's later found that pyarrow is unstable and may lead to crash.
So the default serialization is changed back to msgpack.
+ [2018/03/20] `ModelDesc` starts to use simplified interfaces: + [2018/03/20] `ModelDesc` starts to use simplified interfaces:
+ `_get_inputs()` renamed to `inputs()` and returns `tf.TensorSpec`. + `_get_inputs()` renamed to `inputs()` and returns `tf.TensorSpec`.
+ `build_graph(self, tensor1, tensor2)` returns the cost tensor directly. + `build_graph(self, tensor1, tensor2)` returns the cost tensor directly.
......
...@@ -18,7 +18,7 @@ This is likely the best-performing open source TensorFlow reimplementation of th ...@@ -18,7 +18,7 @@ This is likely the best-performing open source TensorFlow reimplementation of th
## Dependencies ## Dependencies
+ Python 3.3+; OpenCV + Python 3.3+; OpenCV
+ TensorFlow ≥ 1.6 + TensorFlow ≥ 1.6
+ pycocotools: `pip install 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'` + pycocotools: `for i in cython 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'; do pip install $i; done`
+ Pre-trained [ImageNet ResNet model](http://models.tensorpack.com/FasterRCNN/) + Pre-trained [ImageNet ResNet model](http://models.tensorpack.com/FasterRCNN/)
from tensorpack model zoo from tensorpack model zoo
+ [COCO data](http://cocodataset.org/#download). It needs to have the following directory structure: + [COCO data](http://cocodataset.org/#download). It needs to have the following directory structure:
......
...@@ -84,6 +84,8 @@ class ModelDescBase(object): ...@@ -84,6 +84,8 @@ class ModelDescBase(object):
""" """
with tf.Graph().as_default() as G: # create these placeholder in a temporary graph with tf.Graph().as_default() as G: # create these placeholder in a temporary graph
inputs = self.inputs() inputs = self.inputs()
assert isinstance(inputs, (list, tuple)), \
"ModelDesc.inputs() should return a list of tf.TensorSpec objects! Got {} instead.".format(str(inputs))
if isinstance(inputs[0], tf.Tensor): if isinstance(inputs[0], tf.Tensor):
for p in inputs: for p in inputs:
assert "Placeholder" in p.op.type, \ assert "Placeholder" in p.op.type, \
...@@ -157,7 +159,10 @@ class ModelDesc(ModelDescBase): ...@@ -157,7 +159,10 @@ class ModelDesc(ModelDescBase):
Returns: Returns:
a :class:`tf.train.Optimizer` instance. a :class:`tf.train.Optimizer` instance.
""" """
return self.optimizer() ret = self.optimizer()
assert isinstance(ret, tfv1.train.Optimizer), \
"ModelDesc.optimizer() must return a tf.train.Optimizer! Got {} instead.".format(str(ret))
return ret
def optimizer(self): def optimizer(self):
""" """
......
...@@ -47,6 +47,7 @@ def remap_variables(fn): ...@@ -47,6 +47,7 @@ def remap_variables(fn):
Example: Example:
.. code-block:: python .. code-block:: python
from tensorpack.tfutils import varreplace
with varreplace.remap_variables(lambda var: quantize(var)): with varreplace.remap_variables(lambda var: quantize(var)):
x = FullyConnected('fc', x, 1000) # fc/{W,b} will be quantized x = FullyConnected('fc', x, 1000) # fc/{W,b} will be quantized
""" """
...@@ -67,6 +68,7 @@ def freeze_variables(stop_gradient=True, skip_collection=False): ...@@ -67,6 +68,7 @@ def freeze_variables(stop_gradient=True, skip_collection=False):
Example: Example:
.. code-block:: python .. code-block:: python
from tensorpack.tfutils import varreplace
with varreplace.freeze_variable(stop_gradient=False, skip_collection=True): with varreplace.freeze_variable(stop_gradient=False, skip_collection=True):
x = FullyConnected('fc', x, 1000) # fc/* will not be trained x = FullyConnected('fc', x, 1000) # fc/* will not be trained
......
...@@ -253,7 +253,8 @@ class SingleCostTrainer(TowerTrainer): ...@@ -253,7 +253,8 @@ class SingleCostTrainer(TowerTrainer):
def compute_grad_from_inputs(*inputs): def compute_grad_from_inputs(*inputs):
cost = get_cost_fn(*inputs) cost = get_cost_fn(*inputs)
assert isinstance(cost, tf.Tensor), cost assert isinstance(cost, tf.Tensor), \
"Expect the given function to return a cost, but got {} instead".format(str(cost))
assert cost.shape.ndims == 0, "Cost must be a scalar, but found {}!".format(cost) assert cost.shape.ndims == 0, "Cost must be a scalar, but found {}!".format(cost)
if not ctx.is_training: if not ctx.is_training:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment