Commit 94a445ad authored by Yuxin Wu's avatar Yuxin Wu

[WIP] trigger_step with fetch

parent ab86361f
......@@ -6,7 +6,7 @@ Implemented A3C in [Asynchronous Methods for Deep Reinforcement Learning](http:/
`./train-atari.py --env Breakout-v0 --gpu 0`
It should run at a speed of 6~10 iteration/s on 1 GPU.
It should run at a speed of 6~10 iteration/s on 1 GPU plus 12+ CPU cores.
Training with a significant slower speed (e.g. on CPU) will give bad performance,
probably because of async issues.
The pre-trained models are all trained with 4 GPUs for about 2 days.
......
......@@ -13,6 +13,7 @@ Training examples with __reproducible__ and meaningful performance.
+ [Fully-convolutional Network for Holistically-Nested Edge Detection(HED)](HED)
+ [Spatial Transformer Networks on MNIST addition](SpatialTransformer)
+ [Visualize Saliency Maps by Guided ReLU](Saliency)
+ [Similarity Learning on MNIST](SimilarityLearning)
+ Load a pre-trained [AlexNet](load-alexnet.py) or [VGG16](load-vgg16.py) model.
+ Load a pre-trained [Convolutional Pose Machines](ConvolutionalPoseMachines/).
......
......@@ -11,7 +11,9 @@ and warped them separately.
<p align="center"> <img src="./demo.jpg" width="400"> </p>
Left: input image; Middle: output of the first STN branch (which localizes the second digit); Right: output of the second STN branch.
* Left: input image.
* Middle: output of the first STN branch (which localizes the second digit).
* Right: output of the second STN branch.
To train (takes about 300 epochs to reach 8.8% error):
```bash
......
......@@ -5,6 +5,7 @@
import tensorflow as tf
from abc import ABCMeta
import six
from ..tfutils.common import get_op_or_tensor_by_name
__all__ = ['Callback', 'PeriodicCallback', 'ProxyCallback', 'CallbackFactory']
......@@ -49,12 +50,42 @@ class Callback(object):
def _before_train(self):
pass
def trigger_step(self):
def trigger_step(self, *args):
"""
Callback to be triggered after every step (every backpropagation)
Callback to be triggered after every step (every backpropagation).
Args:
args: a list of values corresponding to :meth:`extra_fetches`.
Could be useful to apply some tricks on parameters (clipping, low-rank, etc)
"""
self._trigger_step(*args)
def _trigger_step(self, *args):
pass
def extra_fetches(self):
"""
Returns:
list: a list of elements to be fetched in every step and
passed to :meth:`trigger_step`. Elements can be
Operations/Tensors, or names of Operations/Tensors.
This function will be called only after the graph is finalized.
This function should be a pure function (i.e. no side-effect when called)
"""
fetches = self._extra_fetches()
ret = []
for f in fetches:
if isinstance(f, (tf.Tensor, tf.Operation)):
ret.append(f)
else:
ret.append(get_op_or_tensor_by_name(f))
return ret
def _extra_fetches(self):
return []
def trigger_epoch(self):
"""
......@@ -110,8 +141,6 @@ class ProxyCallback(Callback):
class PeriodicCallback(ProxyCallback):
"""
Wrap a callback so that it is triggered after every ``period`` epochs.
Doesn't work for ``trigger_step``.
"""
def __init__(self, cb, period):
......
......@@ -4,6 +4,7 @@
import tensorflow as tf
from contextlib import contextmanager
from collections import defaultdict
import time
from .base import Callback
......@@ -67,6 +68,7 @@ class Callbacks(Callback):
raise ValueError("Callbacks must contain StatPrinter for stat and writer to work properly!")
self.cbs = cbs
self._extra_fetches_cache = None
def _setup_graph(self):
with tf.name_scope(None):
......@@ -81,9 +83,30 @@ class Callbacks(Callback):
for cb in self.cbs:
cb.after_train()
def trigger_step(self):
for cb in self.cbs:
cb.trigger_step()
def _extra_fetches(self):
if self._extra_fetches_cache is not None:
return self._extra_fetches_cache
# TODO use dispatch mechanism to avoid duplication
self._cbid_to_fetchid = defaultdict(list)
ret = []
for idx, cb in enumerate(self.cbs):
fetch = cb.extra_fetches()
if len(fetch) == 0:
continue
for f in fetch:
ret.append(f)
self._cbid_to_fetchid[idx].append(len(ret)-1)
self._extra_fetches_cache = ret
return ret
def _trigger_step(self, *args):
for idx, cb in enumerate(self.cbs):
fid = self._cbid_to_fetchid[idx]
if len(fid) == 0:
cb.trigger_step()
else:
data = [args[k] for k in fid]
cb.trigger_step(*data)
def _trigger_epoch(self):
tm = CallbackTimeLogger()
......
......@@ -137,8 +137,8 @@ def summary_moving_average(tensors=None):
with tf.name_scope(None):
averager = tf.train.ExponentialMovingAverage(
0.95, num_updates=get_global_step_var(), name='EMA')
avg_maintain_op = averager.apply(tensors)
for idx, c in enumerate(tensors):
name = re.sub('tower[p0-9]+/', '', c.op.name)
tf.summary.scalar(name + '-summary', averager.average(c))
avg_maintain_op = averager.apply(tensors)
for idx, c in enumerate(tensors):
name = re.sub('tower[p0-9]+/', '', c.op.name)
tf.summary.scalar(name + '-summary', averager.average(c))
return avg_maintain_op
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment