Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
e21fc267
Commit
e21fc267
authored
Jan 24, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
use get_extra_fetches() to allow trainer to fetch something more at certain steps.
parent
589a8a35
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
17 additions
and
6 deletions
+17
-6
README.md
README.md
+1
-1
docs/index.rst
docs/index.rst
+2
-0
examples/GAN/GAN.py
examples/GAN/GAN.py
+1
-1
tensorpack/models/batch_norm.py
tensorpack/models/batch_norm.py
+1
-0
tensorpack/train/base.py
tensorpack/train/base.py
+10
-2
tensorpack/train/feedfree.py
tensorpack/train/feedfree.py
+1
-1
tensorpack/train/trainer.py
tensorpack/train/trainer.py
+1
-1
No files found.
README.md
View file @
e21fc267
...
...
@@ -19,7 +19,7 @@ Docs & tutorials should be ready within a month. See some [examples](examples) t
+
[
Asynchronous Advantage Actor-Critic(A3C) with demos on OpenAI Gym
](
examples/A3C-Gym
)
### Unsupervised Learning:
+
[
Several
Generative Adversarial Network(GAN) variants, including DCGAN, Image2Image, InfoGAN
](
examples/GAN
)
+
[
Generative Adversarial Network(GAN) variants, including DCGAN, Image2Image, InfoGAN
](
examples/GAN
)
### Speech / NLP:
+
[
LSTM-CTC for speech recognition
](
examples/CTC-TIMIT
)
...
...
docs/index.rst
View file @
e21fc267
...
...
@@ -2,6 +2,8 @@ Welcome to tensorpack!
======================================
tensorpack is in early development.
All tutorials are drafts for now. You can get an idea from them but the details
might not be correct.
.. toctree::
:maxdepth: 2
...
...
examples/GAN/GAN.py
View file @
e21fc267
...
...
@@ -98,7 +98,7 @@ class GANTrainer(FeedfreeTrainerBase):
self
.
train_op
=
self
.
d_min
def
run_step
(
self
):
ret
=
self
.
sess
.
run
([
self
.
train_op
]
+
self
.
extra_fetches
)
ret
=
self
.
sess
.
run
([
self
.
train_op
]
+
self
.
get_extra_fetches
()
)
return
ret
[
1
:]
...
...
tensorpack/models/batch_norm.py
View file @
e21fc267
...
...
@@ -160,6 +160,7 @@ def BatchNormV2(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
# maintain EMA only in the main training tower
if
ctx
.
is_main_training_tower
:
# TODO a way to use debias in multitower.
update_op1
=
moving_averages
.
assign_moving_average
(
moving_mean
,
batch_mean
,
decay
,
zero_debias
=
False
,
name
=
'mean_ema_op'
)
...
...
tensorpack/train/base.py
View file @
e21fc267
...
...
@@ -41,7 +41,6 @@ class Trainer(object):
summary_writer (tf.summary.FileWriter)
summary_op (tf.Operation): an Op which outputs all summaries.
extra_fetches (list): list of tensors/ops to fetch by :meth:`run_step`.
epoch_num (int): the current epoch number.
step_num (int): the current step number (in an epoch).
"""
...
...
@@ -130,6 +129,15 @@ class Trainer(object):
"""
self
.
add_summary
(
create_scalar_summary
(
name
,
val
))
def
get_extra_fetches
(
self
):
"""
Returns:
list: list of tensors/ops to fetch in each step.
This function should only get called after :meth:`setup()` has finished.
"""
return
self
.
_extra_fetches
def
setup
(
self
):
"""
Setup the trainer and be ready for the main loop.
...
...
@@ -140,7 +148,7 @@ class Trainer(object):
# some final operations that might modify the graph
logger
.
info
(
"Setup callbacks ..."
)
self
.
config
.
callbacks
.
setup_graph
(
weakref
.
proxy
(
self
))
self
.
extra_fetches
=
self
.
config
.
callbacks
.
extra_fetches
()
self
.
_
extra_fetches
=
self
.
config
.
callbacks
.
extra_fetches
()
if
not
hasattr
(
logger
,
'LOG_DIR'
):
raise
RuntimeError
(
"logger directory wasn't set!"
)
...
...
tensorpack/train/feedfree.py
View file @
e21fc267
...
...
@@ -54,7 +54,7 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
def
run_step
(
self
):
""" Simply run ``self.train_op``, which minimizes the cost."""
ret
=
self
.
sess
.
run
([
self
.
train_op
]
+
self
.
extra_fetches
)
ret
=
self
.
sess
.
run
([
self
.
train_op
]
+
self
.
get_extra_fetches
()
)
return
ret
[
1
:]
# if not hasattr(self, 'cnt'):
# self.cnt = 0
...
...
tensorpack/train/trainer.py
View file @
e21fc267
...
...
@@ -72,7 +72,7 @@ class SimpleTrainer(Trainer):
def
run_step
(
self
):
""" Feed data into the graph and run the updates. """
feed
=
self
.
_input_method
.
next_feed
()
ret
=
self
.
sess
.
run
([
self
.
train_op
]
+
self
.
extra_fetches
,
ret
=
self
.
sess
.
run
([
self
.
train_op
]
+
self
.
get_extra_fetches
()
,
feed_dict
=
feed
)
return
ret
[
1
:]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment