Commit c8a9e4e5 authored by Yuxin Wu's avatar Yuxin Wu

get_collection in TowerTensorHandle

parent 2f9e2c0e
......@@ -36,7 +36,6 @@ We refuse toy examples.
Instead of showing you 10 arbitrary networks trained on toy datasets,
[tensorpack examples](examples) faithfully replicate papers and care about reproducing numbers,
demonstrating its flexibility for actual research.
Some highlights:
### Vision:
+ [Train ResNet](examples/ResNet) and [other models](examples/ImageNetModels) on ImageNet.
......@@ -57,7 +56,6 @@ Some highlights:
+ [char-rnn for fun](examples/Char-RNN)
+ [LSTM language model on PennTreebank](examples/PennTreebank)
## Install:
Dependencies:
......
......@@ -98,6 +98,7 @@ class MinSaver(Callback):
reverse (bool): if True, will save the maximum.
filename (str): the name for the saved model.
Defaults to ``min-{monitor_stat}.tfmodel``.
checkpoint_dir (str): the directory containing checkpoints.
Example:
Save the model with minimum validation error to
......@@ -108,9 +109,8 @@ class MinSaver(Callback):
MinSaver('val-error')
Notes:
It assumes that :class:`ModelSaver` is used with
the same ``checkpoint_dir``. And it will save
the model to that directory as well.
It assumes that :class:`ModelSaver` is used with the same ``checkpoint_dir``
and appears earlier in the callback list.
The default for both :class:`ModelSaver` and :class:`MinSaver`
is ``checkpoint_dir=logger.get_logger_dir()``
"""
......
......@@ -333,6 +333,15 @@ class TowerTensorHandle(object):
name_with_vs = name
return get_op_or_tensor_by_name(name_with_vs)
def get_collection(self, name):
"""
Get items from a collection that are added in this tower.
Args:
name (str): the name of the collection
"""
return self._ctx.get_collection_in_tower(name)
@property
def input(self):
"""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment