Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
2870347c
Commit
2870347c
authored
Nov 30, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Register a list of callbacks at a time
parent
712ea325
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
30 additions
and
17 deletions
+30
-17
examples/GAN/GAN.py
examples/GAN/GAN.py
+21
-13
tensorpack/train/base.py
tensorpack/train/base.py
+8
-1
tensorpack/train/tower.py
tensorpack/train/tower.py
+1
-3
No files found.
examples/GAN/GAN.py
View file @
2870347c
...
...
@@ -68,15 +68,24 @@ class GANTrainer(TowerTrainer):
super
(
GANTrainer
,
self
)
.
__init__
()
assert
isinstance
(
model
,
GANModelDesc
),
model
inputs_desc
=
model
.
get_inputs_desc
()
# Setup input
cbs
=
input
.
setup
(
inputs_desc
)
self
.
register_callback
(
cbs
)
# we need to set towerfunc because it's a TowerTrainer,
# and only TowerTrainer supports automatic graph creation for inference during training.
"""
We need to set tower_func because it's a TowerTrainer,
and only TowerTrainer supports automatic graph creation for inference during training.
If we don't care about inference during training, using tower_func is
not needed. Just calling model.build_graph directly is OK.
"""
# Build the graph
self
.
tower_func
=
TowerFuncWrapper
(
model
.
build_graph
,
inputs_desc
)
with
TowerContext
(
''
,
is_training
=
True
):
self
.
tower_func
(
*
input
.
get_input_tensors
())
opt
=
model
.
get_optimizer
()
# Define the training iteration
# by default, run one d_min after one g_min
with
tf
.
name_scope
(
'optimize'
):
g_min
=
opt
.
minimize
(
model
.
g_loss
,
var_list
=
model
.
g_vars
,
name
=
'g_op'
)
...
...
@@ -84,9 +93,6 @@ class GANTrainer(TowerTrainer):
d_min
=
opt
.
minimize
(
model
.
d_loss
,
var_list
=
model
.
d_vars
,
name
=
'd_op'
)
self
.
train_op
=
d_min
for
cb
in
cbs
:
self
.
register_callback
(
cb
)
class
SeparateGANTrainer
(
TowerTrainer
):
""" A GAN trainer which runs two optimization ops with a certain ratio."""
...
...
@@ -101,7 +107,11 @@ class SeparateGANTrainer(TowerTrainer):
self
.
_g_period
=
int
(
g_period
)
assert
min
(
d_period
,
g_period
)
==
1
# Setup input
cbs
=
input
.
setup
(
model
.
get_inputs_desc
())
self
.
register_callback
(
cbs
)
# Build the graph
self
.
tower_func
=
TowerFuncWrapper
(
model
.
build_graph
,
model
.
get_inputs_desc
())
with
TowerContext
(
''
,
is_training
=
True
):
self
.
tower_func
(
*
input
.
get_input_tensors
())
...
...
@@ -113,10 +123,8 @@ class SeparateGANTrainer(TowerTrainer):
self
.
g_min
=
opt
.
minimize
(
model
.
g_loss
,
var_list
=
model
.
g_vars
,
name
=
'g_min'
)
for
cb
in
cbs
:
self
.
register_callback
(
cb
)
def
run_step
(
self
):
# Define the training iteration
if
self
.
global_step
%
(
self
.
_d_period
)
==
0
:
self
.
hooked_sess
.
run
(
self
.
d_min
)
if
self
.
global_step
%
(
self
.
_g_period
)
==
0
:
...
...
@@ -132,11 +140,12 @@ class MultiGPUGANTrainer(TowerTrainer):
assert
nr_gpu
>
1
raw_devices
=
[
'/gpu:{}'
.
format
(
k
)
for
k
in
range
(
nr_gpu
)]
#
s
etup input
#
S
etup input
input
=
StagingInput
(
input
,
list
(
range
(
nr_gpu
)))
cbs
=
input
.
setup
(
model
.
get_inputs_desc
())
self
.
register_callback
(
cbs
)
#
build the graph
#
Build the graph with multi-gpu replication
def
get_cost
(
*
inputs
):
model
.
build_graph
(
*
inputs
)
return
[
model
.
d_loss
,
model
.
g_loss
]
...
...
@@ -146,7 +155,7 @@ class MultiGPUGANTrainer(TowerTrainer):
list
(
range
(
nr_gpu
)),
lambda
:
self
.
tower_func
(
*
input
.
get_input_tensors
()),
devices
)
#
simply average the cost. It might get
faster to average the gradients
#
Simply average the cost here. It might be
faster to average the gradients
with
tf
.
name_scope
(
'optimize'
):
d_loss
=
tf
.
add_n
([
x
[
0
]
for
x
in
cost_list
])
*
(
1.0
/
nr_gpu
)
g_loss
=
tf
.
add_n
([
x
[
1
]
for
x
in
cost_list
])
*
(
1.0
/
nr_gpu
)
...
...
@@ -158,9 +167,8 @@ class MultiGPUGANTrainer(TowerTrainer):
with
tf
.
control_dependencies
([
g_min
]):
d_min
=
opt
.
minimize
(
d_loss
,
var_list
=
model
.
d_vars
,
colocate_gradients_with_ops
=
True
,
name
=
'd_op'
)
# Define the training iteration
self
.
train_op
=
d_min
for
cb
in
cbs
:
self
.
register_callback
(
cb
)
class
RandomZData
(
DataFlow
):
...
...
tensorpack/train/base.py
View file @
2870347c
...
...
@@ -135,9 +135,16 @@ class Trainer(object):
def
_register_callback
(
self
,
cb
):
"""
Register
a callback
to the trainer.
Register
callbacks
to the trainer.
It can only be called before :meth:`Trainer.train()`.
Args:
cb (Callback or [Callback]): a callback or a list of callbacks
"""
if
isinstance
(
cb
,
(
list
,
tuple
)):
for
x
in
cb
:
self
.
_register_callback
(
x
)
return
assert
isinstance
(
cb
,
Callback
),
cb
assert
not
isinstance
(
self
.
_callbacks
,
Callbacks
),
\
"Cannot register more callbacks after trainer was setup!"
...
...
tensorpack/train/tower.py
View file @
2870347c
...
...
@@ -145,9 +145,7 @@ class SingleCostTrainer(TowerTrainer):
# TODO setup may want to register monitor as well??
input_callbacks
=
self
.
_setup_input
(
inputs_desc
,
input
)
train_callbacks
=
self
.
_setup_graph
(
input
,
get_cost_fn
,
get_opt_fn
)
internal_callbacks
=
input_callbacks
+
train_callbacks
for
cb
in
internal_callbacks
:
self
.
register_callback
(
cb
)
self
.
register_callback
(
input_callbacks
+
train_callbacks
)
@
abstractmethod
def
_setup_graph
(
self
,
input
,
get_cost_fn
,
get_opt_fn
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment