Commit 05352860 authored by Yuxin Wu

doc update and inputsdesc name check (#564)

parent b936df1d
 Bug Reports/Feature Requests/Usage Questions Only:

-Bug reports or other problems with code: PLEASE always include
+Any unexpected problems with code: PLEASE always include
 1. What you did. (command you run and changes you made if using examples; post or describe your code if not)
 2. What you observed, e.g. as much as logs possible.
 3. What you expected, if not obvious.
...
@@ -42,16 +42,21 @@ l = func(l, *args, **kwargs)
 l = FullyConnected('fc1', l, 10, nl=tf.identity)
 ```
-### Access Internal Variables:
-Access the variables like this:
+### Access Relevant Tensors
+The variables inside the layer will be named `name/W`, `name/b`, etc.
+See the API documentation of each layer for details.
+When building the graph, you can access the variables like this:
 ```python
 l = Conv2D('conv1', l, 32, 3)
 print(l.variables.W)
 print(l.variables.b)
 ```
-The names are documented in API documentation.
-Note that this method doesn't work with LinearWrap, and cannot access the variables created by an activation function.
+But note that this is a hacky way and may not work with future versions of TensorFlow.
+Also this method doesn't work with LinearWrap, and cannot access the variables created by an activation function.
+The output of a layer is usually named `name/output` unless documented differently in the API.
+You can always print a tensor to see its name.

 ### Use Models outside Tensorpack
...
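As a minimal sketch of the naming convention this doc change describes (assuming TF 1.x graph mode and a graph that already contains a tensorpack layer built under the name `conv1`), the same tensors can also be fetched by name:

```python
import tensorflow as tf

# Assumes a graph already containing a tensorpack layer named 'conv1'.
g = tf.get_default_graph()
kernel = g.get_tensor_by_name('conv1/W:0')       # variable named name/W
output = g.get_tensor_by_name('conv1/output:0')  # layer output, per the name/output convention above
print(kernel.name, output.name)
```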
@@ -35,6 +35,8 @@ class InputDesc(
         """
         shape = tuple(shape)    # has to be tuple for "self" to be hashable
         assert isinstance(type, tf.DType), type
+        if any(k in name for k in [':', '/', ' ']):
+            raise ValueError("Invalid InputDesc name: '{}'".format(name))
         self = super(InputDesc, cls).__new__(cls, type, shape, name)
         self._cached_placeholder = None
         return self
...
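A quick sketch of what the new check rejects, assuming `InputDesc` is importable from the top-level `tensorpack` package as in this era of the codebase:

```python
import tensorflow as tf
from tensorpack import InputDesc  # import path assumed for this version

# A plain name passes the new validation.
ok = InputDesc(tf.float32, (None, 28, 28), 'input')

# Names containing ':', '/' or spaces now raise ValueError.
try:
    InputDesc(tf.float32, (None, 28, 28), 'input:0')
except ValueError as e:
    print(e)  # Invalid InputDesc name: 'input:0'
```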
@@ -170,12 +170,11 @@ def get_current_tower_context():
 class TowerFuncWrapper(object):
     """
-    A wrapper around a function which builds one tower (one replicate of the model).
+    A wrapper around a tower function (function which builds one tower, i.e. one replicate of the model).
     It keeps track of the name scope, variable scope and input/output tensors
     each time the function is called.

-    :class:`TowerTrainer` needs this option to be set, so that
-    it knows how to build a predictor.
+    :class:`TowerTrainer` needs this so that it knows how to build a predictor.
     """

     def __init__(self, tower_fn, inputs_desc):
@@ -186,11 +185,13 @@ class TowerFuncWrapper(object):
             inputs_desc ([InputDesc]): use this to figure out the right name for the input tensors.
         """
         assert callable(tower_fn), tower_fn
-        if not isinstance(tower_fn, TowerFuncWrapper):
-            self._tower_fn = tower_fn
-            self._inputs_desc = inputs_desc
+        inputs_desc_names = [k.name for k in inputs_desc]
+        assert len(set(inputs_desc_names)) == len(inputs_desc_names), \
+            "Duplicated names in inputs_desc! " + str(inputs_desc_names)
+        self._tower_fn = tower_fn
+        self._inputs_desc = inputs_desc
         self._handles = []

     def __new__(cls, tower_fn, inputs_desc):
         # to avoid double-wrapping a function
...
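A hedged usage sketch of the tightened constructor, assuming `TowerFuncWrapper` lives in `tensorpack.tfutils.tower` (the file this hunk appears to edit) and `InputDesc` is importable as above; `my_tower_fn` is a hypothetical minimal tower function:

```python
import tensorflow as tf
from tensorpack import InputDesc
from tensorpack.tfutils.tower import TowerFuncWrapper  # module path assumed

def my_tower_fn(image):
    # Builds one replicate of the model; hypothetical minimal body.
    return tf.identity(image, name='output')

# Distinct names are fine; duplicated names now trip the new assertion.
good = TowerFuncWrapper(my_tower_fn, [InputDesc(tf.float32, (None, 28, 28), 'image')])
try:
    TowerFuncWrapper(my_tower_fn, [InputDesc(tf.float32, (None, 28, 28), 'image'),
                                   InputDesc(tf.float32, (None, 10), 'image')])
except AssertionError as e:
    print(e)  # Duplicated names in inputs_desc! ['image', 'image']
```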