Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
712ea325
Commit
712ea325
authored
Nov 30, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Use getter and setter for trainer.tower_func, instead of `set_tower_func`.
parent
a988fc18
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
58 additions
and
23 deletions
+58
-23
docs/conf.py
docs/conf.py
+1
-0
examples/GAN/GAN.py
examples/GAN/GAN.py
+6
-9
tensorpack/train/tower.py
tensorpack/train/tower.py
+22
-10
tensorpack/utils/argtools.py
tensorpack/utils/argtools.py
+29
-4
No files found.
docs/conf.py
View file @
712ea325
...
...
@@ -372,6 +372,7 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
'DumpTensor'
,
'StagingInputWrapper'
,
'StepTensorPrinter'
,
'set_tower_func'
,
'guided_relu'
,
'saliency_map'
,
'get_scalar_var'
,
'prediction_incorrect'
,
'huber_loss'
,
...
...
examples/GAN/GAN.py
View file @
712ea325
...
...
@@ -72,9 +72,9 @@ class GANTrainer(TowerTrainer):
# we need to set tower_func because it's a TowerTrainer,
# and only TowerTrainer supports automatic graph creation for inference during training.
tower_func
=
TowerFuncWrapper
(
model
.
build_graph
,
inputs_desc
)
self
.
tower_func
=
TowerFuncWrapper
(
model
.
build_graph
,
inputs_desc
)
with
TowerContext
(
''
,
is_training
=
True
):
tower_func
(
*
input
.
get_input_tensors
())
self
.
tower_func
(
*
input
.
get_input_tensors
())
opt
=
model
.
get_optimizer
()
# by default, run one d_min after one g_min
...
...
@@ -83,7 +83,6 @@ class GANTrainer(TowerTrainer):
with
tf
.
control_dependencies
([
g_min
]):
d_min
=
opt
.
minimize
(
model
.
d_loss
,
var_list
=
model
.
d_vars
,
name
=
'd_op'
)
self
.
train_op
=
d_min
self
.
set_tower_func
(
tower_func
)
for
cb
in
cbs
:
self
.
register_callback
(
cb
)
...
...
@@ -103,9 +102,9 @@ class SeparateGANTrainer(TowerTrainer):
assert
min
(
d_period
,
g_period
)
==
1
cbs
=
input
.
setup
(
model
.
get_inputs_desc
())
tower_func
=
TowerFuncWrapper
(
model
.
build_graph
,
model
.
get_inputs_desc
())
self
.
tower_func
=
TowerFuncWrapper
(
model
.
build_graph
,
model
.
get_inputs_desc
())
with
TowerContext
(
''
,
is_training
=
True
):
tower_func
(
*
input
.
get_input_tensors
())
self
.
tower_func
(
*
input
.
get_input_tensors
())
opt
=
model
.
get_optimizer
()
with
tf
.
name_scope
(
'optimize'
):
...
...
@@ -114,7 +113,6 @@ class SeparateGANTrainer(TowerTrainer):
self
.
g_min
=
opt
.
minimize
(
model
.
g_loss
,
var_list
=
model
.
g_vars
,
name
=
'g_min'
)
self
.
set_tower_func
(
tower_func
)
for
cb
in
cbs
:
self
.
register_callback
(
cb
)
...
...
@@ -142,11 +140,11 @@ class MultiGPUGANTrainer(TowerTrainer):
def
get_cost
(
*
inputs
):
model
.
build_graph
(
*
inputs
)
return
[
model
.
d_loss
,
model
.
g_loss
]
tower_func
=
TowerFuncWrapper
(
get_cost
,
model
.
get_inputs_desc
())
self
.
tower_func
=
TowerFuncWrapper
(
get_cost
,
model
.
get_inputs_desc
())
devices
=
[
LeastLoadedDeviceSetter
(
d
,
raw_devices
)
for
d
in
raw_devices
]
cost_list
=
DataParallelBuilder
.
build_on_towers
(
list
(
range
(
nr_gpu
)),
lambda
:
tower_func
(
*
input
.
get_input_tensors
()),
lambda
:
self
.
tower_func
(
*
input
.
get_input_tensors
()),
devices
)
# simply average the cost. It might get faster to average the gradients
with
tf
.
name_scope
(
'optimize'
):
...
...
@@ -161,7 +159,6 @@ class MultiGPUGANTrainer(TowerTrainer):
d_min
=
opt
.
minimize
(
d_loss
,
var_list
=
model
.
d_vars
,
colocate_gradients_with_ops
=
True
,
name
=
'd_op'
)
self
.
train_op
=
d_min
self
.
set_tower_func
(
tower_func
)
for
cb
in
cbs
:
self
.
register_callback
(
cb
)
...
...
tensorpack/train/tower.py
View file @
712ea325
...
...
@@ -7,6 +7,7 @@ import six
from
abc
import
abstractmethod
,
ABCMeta
from
..utils.argtools
import
call_only_once
,
memoized
from
..utils.develop
import
deprecated
from
..graph_builder.predict
import
SimplePredictBuilder
from
..input_source
import
PlaceholderInput
from
..predict.base
import
OnlinePredictor
...
...
@@ -25,22 +26,33 @@ class TowerTrainer(Trainer):
This is required by some features that replicates the model
automatically, e.g. creating a predictor.
"""
tower_func
=
None
"""
A :class:`TowerFuncWrapper` instance.
A callable which takes some input tensors and builds one replicate of the model.
To use features of :class:`TowerTrainer`, set `tower_func` and use it to build the graph.
Note that `tower_func` can only be set once per instance.
"""
_tower_func
=
None
@
call_only_once
def
_set_tower_func
(
self
,
tower_func
):
assert
isinstance
(
tower_func
,
TowerFuncWrapper
),
tower_func
self
.
_tower_func
=
tower_func
@
deprecated
(
"Just use tower_func = xxx instead!"
)
def
set_tower_func
(
self
,
tower_func
):
self
.
_set_tower_func
(
tower_func
)
@
property
def
tower_func
(
self
):
"""
Args:
tower_func (TowerFuncWrapper)
A
:class:`TowerFuncWrapper` instance.
A callable which takes some input tensors and builds one replicate of the model.
"""
assert
isinstance
(
tower_func
,
TowerFuncWrapper
),
tower_func
self
.
tower_func
=
tower_func
return
self
.
_tower_func
@
tower_func
.
setter
def
tower_func
(
self
,
val
):
self
.
_set_tower_func
(
val
)
@
property
def
inputs_desc
(
self
):
...
...
@@ -128,7 +140,7 @@ class SingleCostTrainer(TowerTrainer):
"""
get_cost_fn
=
TowerFuncWrapper
(
get_cost_fn
,
inputs_desc
)
get_opt_fn
=
memoized
(
get_opt_fn
)
self
.
set_tower_func
(
get_cost_fn
)
self
.
tower_func
=
get_cost_fn
# TODO setup may want to register monitor as well??
input_callbacks
=
self
.
_setup_input
(
inputs_desc
,
input
)
...
...
tensorpack/utils/argtools.py
View file @
712ea325
...
...
@@ -147,19 +147,25 @@ _FUNC_CALLED = set()
def
call_only_once
(
func
):
"""
Decorate a method of a class, so that this method can only
Decorate a method or property of a class, so that this method can only
be called once for every instance.
Calling it more than once will result in exception.
"""
@
functools
.
wraps
(
func
)
def
wrapper
(
*
args
,
**
kwargs
):
self
=
args
[
0
]
assert
hasattr
(
self
,
func
.
__name__
),
"call_only_once can only be used on method!"
# cannot use hasattr here, because hasattr tries to getattr, which
# fails if func is a property
assert
func
.
__name__
in
dir
(
self
),
"call_only_once can only be used on method or property!"
cls
=
type
(
self
)
# cannot use ismethod(), because decorated method becomes a function
is_method
=
inspect
.
isfunction
(
getattr
(
cls
,
func
.
__name__
))
key
=
(
self
,
func
)
assert
key
not
in
_FUNC_CALLED
,
\
"Method {}.{} can only be called once per object!"
.
format
(
type
(
self
)
.
__name__
,
func
.
__name__
)
"{} {}.{} can only be called once per object!"
.
format
(
'Method'
if
is_method
else
'Property'
,
cls
.
__name__
,
func
.
__name__
)
_FUNC_CALLED
.
add
(
key
)
return
func
(
*
args
,
**
kwargs
)
...
...
@@ -169,13 +175,32 @@ def call_only_once(func):
if
__name__
==
'__main__'
:
class
A
():
def
__init__
(
self
):
self
.
_p
=
0
@
call_only_once
def
f
(
self
,
x
):
print
(
x
)
@
property
def
p
(
self
):
return
self
.
_p
@
p
.
setter
@
call_only_once
def
p
(
self
,
val
):
self
.
_p
=
val
a
=
A
()
a
.
f
(
1
)
b
=
A
()
b
.
f
(
2
)
b
.
f
(
1
)
print
(
b
.
p
)
print
(
b
.
p
)
b
.
p
=
2
print
(
b
.
p
)
b
.
p
=
3
print
(
b
.
p
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment