Commit 94a445ad
authored Jan 21, 2017 by Yuxin Wu

[WIP] trigger_step with fetch

parent ab86361f
Showing 6 changed files with 68 additions and 13 deletions (+68, -13)
examples/A3C-Gym/README.md (+1, -1)
examples/README.md (+1, -0)
examples/SpatialTransformer/README.md (+3, -1)
tensorpack/callbacks/base.py (+33, -4)
tensorpack/callbacks/group.py (+26, -3)
tensorpack/tfutils/summary.py (+4, -4)
examples/A3C-Gym/README.md

```diff
@@ -6,7 +6,7 @@ Implemented A3C in [Asynchronous Methods for Deep Reinforcement Learning](http:/
 `./train-atari.py --env Breakout-v0 --gpu 0`
 
-It should run at a speed of 6~10 iteration/s on 1 GPU.
+It should run at a speed of 6~10 iteration/s on 1 GPU plus 12+ CPU cores.
 Training with a significantly slower speed (e.g. on CPU) will give bad performance,
 probably because of async issues.
 The pre-trained models are all trained with 4 GPUs for about 2 days.
```
examples/README.md

```diff
@@ -13,6 +13,7 @@ Training examples with __reproducible__ and meaningful performance.
 + [Fully-convolutional Network for Holistically-Nested Edge Detection(HED)](HED)
++ [Spatial Transformer Networks on MNIST addition](SpatialTransformer)
 + [Visualize Saliency Maps by Guided ReLU](Saliency)
 + [Similarity Learning on MNIST](SimilarityLearning)
 + Load a pre-trained [AlexNet](load-alexnet.py) or [VGG16](load-vgg16.py) model.
 + Load a pre-trained [Convolutional Pose Machines](ConvolutionalPoseMachines/).
```
examples/SpatialTransformer/README.md

````diff
@@ -11,7 +11,9 @@ and warped them separately.
 <p align="center"> <img src="./demo.jpg" width="400"> </p>
 
-Left: input image; Middle: output of the first STN branch (which localizes the second digit); Right: output of the second STN branch.
+* Left: input image.
+* Middle: output of the first STN branch (which localizes the second digit).
+* Right: output of the second STN branch.
 
 To train (takes about 300 epochs to reach 8.8% error):
 ```bash
````
tensorpack/callbacks/base.py

```diff
@@ -5,6 +5,7 @@
 import tensorflow as tf
 from abc import ABCMeta
 import six
+from ..tfutils.common import get_op_or_tensor_by_name
 
 __all__ = ['Callback', 'PeriodicCallback', 'ProxyCallback', 'CallbackFactory']
@@ -49,12 +50,42 @@ class Callback(object):
     def _before_train(self):
         pass
 
-    def trigger_step(self):
+    def trigger_step(self, *args):
         """
-        Callback to be triggered after every step (every backpropagation)
+        Callback to be triggered after every step (every backpropagation).
+
+        Args:
+            args: a list of values corresponding to :meth:`extra_fetches`.
 
         Could be useful to apply some tricks on parameters (clipping, low-rank, etc)
         """
+        self._trigger_step(*args)
+
+    def _trigger_step(self, *args):
+        pass
+
+    def extra_fetches(self):
+        """
+        Returns:
+            list: a list of elements to be fetched in every step and
+                passed to :meth:`trigger_step`. Elements can be
+                Operations/Tensors, or names of Operations/Tensors.
+
+        This function will be called only after the graph is finalized.
+        This function should be a pure function (i.e. no side-effect when called)
+        """
+        fetches = self._extra_fetches()
+        ret = []
+        for f in fetches:
+            if isinstance(f, (tf.Tensor, tf.Operation)):
+                ret.append(f)
+            else:
+                ret.append(get_op_or_tensor_by_name(f))
+        return ret
+
+    def _extra_fetches(self):
+        return []
 
     def trigger_epoch(self):
         """
@@ -110,8 +141,6 @@ class ProxyCallback(Callback):
 class PeriodicCallback(ProxyCallback):
     """
     Wrap a callback so that it is triggered after every ``period`` epochs.
-    Doesn't work for ``trigger_step``.
     """
 
     def __init__(self, cb, period):
```
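Taken together, the new `Callback` hooks let a callback declare per-step fetches once and receive their values together with the training step, instead of issuing its own `session.run`. A minimal sketch of a subclass using the new API (the class name and the `'total_cost:0'` tensor name are hypothetical, for illustration only):

```python
from tensorpack.callbacks.base import Callback


class StepCostMonitor(Callback):
    """Hypothetical callback: watch a per-step scalar via the new hooks."""

    def _extra_fetches(self):
        # Elements may be Tensors/Operations or their names; names are
        # resolved by Callback.extra_fetches() via get_op_or_tensor_by_name,
        # which is only called after the graph is finalized.
        return ['total_cost:0']

    def _trigger_step(self, cost):
        # Called after every step with one value per fetched element,
        # in the order returned by _extra_fetches().
        if cost > 100.0:
            print('unusually large cost: {:.3f}'.format(cost))
```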
tensorpack/callbacks/group.py

```diff
@@ -4,6 +4,7 @@
 import tensorflow as tf
 from contextlib import contextmanager
+from collections import defaultdict
 import time
 
 from .base import Callback
@@ -67,6 +68,7 @@ class Callbacks(Callback):
             raise ValueError(
                 "Callbacks must contain StatPrinter for stat and writer to work properly!")
         self.cbs = cbs
+        self._extra_fetches_cache = None
 
     def _setup_graph(self):
         with tf.name_scope(None):
@@ -81,9 +83,30 @@ class Callbacks(Callback):
         for cb in self.cbs:
             cb.after_train()
 
-    def trigger_step(self):
-        for cb in self.cbs:
-            cb.trigger_step()
+    def _extra_fetches(self):
+        if self._extra_fetches_cache is not None:
+            return self._extra_fetches_cache
+        # TODO use dispatch mechanism to avoid duplication
+        self._cbid_to_fetchid = defaultdict(list)
+        ret = []
+        for idx, cb in enumerate(self.cbs):
+            fetch = cb.extra_fetches()
+            if len(fetch) == 0:
+                continue
+            for f in fetch:
+                ret.append(f)
+                self._cbid_to_fetchid[idx].append(len(ret) - 1)
+        self._extra_fetches_cache = ret
+        return ret
+
+    def _trigger_step(self, *args):
+        for idx, cb in enumerate(self.cbs):
+            fid = self._cbid_to_fetchid[idx]
+            if len(fid) == 0:
+                cb.trigger_step()
+            else:
+                data = [args[k] for k in fid]
+                cb.trigger_step(*data)
 
     def _trigger_epoch(self):
         tm = CallbackTimeLogger()
```
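`Callbacks` flattens every member callback's fetches into one list, records which indices belong to which callback in `_cbid_to_fetchid`, and routes the fetched values back out in `_trigger_step`. The trainer-side wiring is not part of this diff; a hedged sketch of the intended flow (`sess`, `train_op`, `callbacks`, and the loop itself are assumptions):

```python
# Hypothetical trainer loop (not in this commit) showing the intended flow.
extra = callbacks.extra_fetches()      # flat list, cached after the first call
while running:
    results = sess.run([train_op] + extra)
    # Base Callback.trigger_step(*args) forwards to Callbacks._trigger_step,
    # which hands each callback only the values it asked for.
    callbacks.trigger_step(*results[1:])
```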
tensorpack/tfutils/summary.py

```diff
@@ -137,8 +137,8 @@ def summary_moving_average(tensors=None):
     with tf.name_scope(None):
         averager = tf.train.ExponentialMovingAverage(
             0.95, num_updates=get_global_step_var(), name='EMA')
-        avg_maintain_op = averager.apply(tensors)
-        for idx, c in enumerate(tensors):
-            name = re.sub('tower[p0-9]+/', '', c.op.name)
-            tf.summary.scalar(name + '-summary', averager.average(c))
+    avg_maintain_op = averager.apply(tensors)
+    for idx, c in enumerate(tensors):
+        name = re.sub('tower[p0-9]+/', '', c.op.name)
+        tf.summary.scalar(name + '-summary', averager.average(c))
     return avg_maintain_op
```
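For context, the `re.sub` in this hunk strips replicated-tower prefixes so that every tower updates the same moving-average summary name. A quick illustration (the tensor name is made up for the example):

```python
import re

# 'tower[p0-9]+/' matches prefixes such as 'tower0/' or 'towerp0/',
# so tensors from different towers collapse to one summary name.
print(re.sub('tower[p0-9]+/', '', 'tower0/cross_entropy_loss'))
# -> cross_entropy_loss
```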