Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
712ea325
Commit
712ea325
authored
Nov 30, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Use getter and setter for trainer.tower_func, instead of `set_tower_func`.
parent
a988fc18
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
58 additions
and
23 deletions
+58
-23
docs/conf.py
docs/conf.py
+1
-0
examples/GAN/GAN.py
examples/GAN/GAN.py
+6
-9
tensorpack/train/tower.py
tensorpack/train/tower.py
+22
-10
tensorpack/utils/argtools.py
tensorpack/utils/argtools.py
+29
-4
No files found.
docs/conf.py
View file @
712ea325
...
@@ -372,6 +372,7 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
...
@@ -372,6 +372,7 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
'DumpTensor'
,
'DumpTensor'
,
'StagingInputWrapper'
,
'StagingInputWrapper'
,
'StepTensorPrinter'
,
'StepTensorPrinter'
,
'set_tower_func'
,
'guided_relu'
,
'saliency_map'
,
'get_scalar_var'
,
'guided_relu'
,
'saliency_map'
,
'get_scalar_var'
,
'prediction_incorrect'
,
'huber_loss'
,
'prediction_incorrect'
,
'huber_loss'
,
...
...
examples/GAN/GAN.py
View file @
712ea325
...
@@ -72,9 +72,9 @@ class GANTrainer(TowerTrainer):
...
@@ -72,9 +72,9 @@ class GANTrainer(TowerTrainer):
# we need to set towerfunc because it's a TowerTrainer,
# we need to set towerfunc because it's a TowerTrainer,
# and only TowerTrainer supports automatic graph creation for inference during training.
# and only TowerTrainer supports automatic graph creation for inference during training.
tower_func
=
TowerFuncWrapper
(
model
.
build_graph
,
inputs_desc
)
self
.
tower_func
=
TowerFuncWrapper
(
model
.
build_graph
,
inputs_desc
)
with
TowerContext
(
''
,
is_training
=
True
):
with
TowerContext
(
''
,
is_training
=
True
):
tower_func
(
*
input
.
get_input_tensors
())
self
.
tower_func
(
*
input
.
get_input_tensors
())
opt
=
model
.
get_optimizer
()
opt
=
model
.
get_optimizer
()
# by default, run one d_min after one g_min
# by default, run one d_min after one g_min
...
@@ -83,7 +83,6 @@ class GANTrainer(TowerTrainer):
...
@@ -83,7 +83,6 @@ class GANTrainer(TowerTrainer):
with
tf
.
control_dependencies
([
g_min
]):
with
tf
.
control_dependencies
([
g_min
]):
d_min
=
opt
.
minimize
(
model
.
d_loss
,
var_list
=
model
.
d_vars
,
name
=
'd_op'
)
d_min
=
opt
.
minimize
(
model
.
d_loss
,
var_list
=
model
.
d_vars
,
name
=
'd_op'
)
self
.
train_op
=
d_min
self
.
train_op
=
d_min
self
.
set_tower_func
(
tower_func
)
for
cb
in
cbs
:
for
cb
in
cbs
:
self
.
register_callback
(
cb
)
self
.
register_callback
(
cb
)
...
@@ -103,9 +102,9 @@ class SeparateGANTrainer(TowerTrainer):
...
@@ -103,9 +102,9 @@ class SeparateGANTrainer(TowerTrainer):
assert
min
(
d_period
,
g_period
)
==
1
assert
min
(
d_period
,
g_period
)
==
1
cbs
=
input
.
setup
(
model
.
get_inputs_desc
())
cbs
=
input
.
setup
(
model
.
get_inputs_desc
())
tower_func
=
TowerFuncWrapper
(
model
.
build_graph
,
model
.
get_inputs_desc
())
self
.
tower_func
=
TowerFuncWrapper
(
model
.
build_graph
,
model
.
get_inputs_desc
())
with
TowerContext
(
''
,
is_training
=
True
):
with
TowerContext
(
''
,
is_training
=
True
):
tower_func
(
*
input
.
get_input_tensors
())
self
.
tower_func
(
*
input
.
get_input_tensors
())
opt
=
model
.
get_optimizer
()
opt
=
model
.
get_optimizer
()
with
tf
.
name_scope
(
'optimize'
):
with
tf
.
name_scope
(
'optimize'
):
...
@@ -114,7 +113,6 @@ class SeparateGANTrainer(TowerTrainer):
...
@@ -114,7 +113,6 @@ class SeparateGANTrainer(TowerTrainer):
self
.
g_min
=
opt
.
minimize
(
self
.
g_min
=
opt
.
minimize
(
model
.
g_loss
,
var_list
=
model
.
g_vars
,
name
=
'g_min'
)
model
.
g_loss
,
var_list
=
model
.
g_vars
,
name
=
'g_min'
)
self
.
set_tower_func
(
tower_func
)
for
cb
in
cbs
:
for
cb
in
cbs
:
self
.
register_callback
(
cb
)
self
.
register_callback
(
cb
)
...
@@ -142,11 +140,11 @@ class MultiGPUGANTrainer(TowerTrainer):
...
@@ -142,11 +140,11 @@ class MultiGPUGANTrainer(TowerTrainer):
def
get_cost
(
*
inputs
):
def
get_cost
(
*
inputs
):
model
.
build_graph
(
*
inputs
)
model
.
build_graph
(
*
inputs
)
return
[
model
.
d_loss
,
model
.
g_loss
]
return
[
model
.
d_loss
,
model
.
g_loss
]
tower_func
=
TowerFuncWrapper
(
get_cost
,
model
.
get_inputs_desc
())
self
.
tower_func
=
TowerFuncWrapper
(
get_cost
,
model
.
get_inputs_desc
())
devices
=
[
LeastLoadedDeviceSetter
(
d
,
raw_devices
)
for
d
in
raw_devices
]
devices
=
[
LeastLoadedDeviceSetter
(
d
,
raw_devices
)
for
d
in
raw_devices
]
cost_list
=
DataParallelBuilder
.
build_on_towers
(
cost_list
=
DataParallelBuilder
.
build_on_towers
(
list
(
range
(
nr_gpu
)),
list
(
range
(
nr_gpu
)),
lambda
:
tower_func
(
*
input
.
get_input_tensors
()),
lambda
:
self
.
tower_func
(
*
input
.
get_input_tensors
()),
devices
)
devices
)
# simply average the cost. It might get faster to average the gradients
# simply average the cost. It might get faster to average the gradients
with
tf
.
name_scope
(
'optimize'
):
with
tf
.
name_scope
(
'optimize'
):
...
@@ -161,7 +159,6 @@ class MultiGPUGANTrainer(TowerTrainer):
...
@@ -161,7 +159,6 @@ class MultiGPUGANTrainer(TowerTrainer):
d_min
=
opt
.
minimize
(
d_loss
,
var_list
=
model
.
d_vars
,
d_min
=
opt
.
minimize
(
d_loss
,
var_list
=
model
.
d_vars
,
colocate_gradients_with_ops
=
True
,
name
=
'd_op'
)
colocate_gradients_with_ops
=
True
,
name
=
'd_op'
)
self
.
train_op
=
d_min
self
.
train_op
=
d_min
self
.
set_tower_func
(
tower_func
)
for
cb
in
cbs
:
for
cb
in
cbs
:
self
.
register_callback
(
cb
)
self
.
register_callback
(
cb
)
...
...
tensorpack/train/tower.py
View file @
712ea325
...
@@ -7,6 +7,7 @@ import six
...
@@ -7,6 +7,7 @@ import six
from
abc
import
abstractmethod
,
ABCMeta
from
abc
import
abstractmethod
,
ABCMeta
from
..utils.argtools
import
call_only_once
,
memoized
from
..utils.argtools
import
call_only_once
,
memoized
from
..utils.develop
import
deprecated
from
..graph_builder.predict
import
SimplePredictBuilder
from
..graph_builder.predict
import
SimplePredictBuilder
from
..input_source
import
PlaceholderInput
from
..input_source
import
PlaceholderInput
from
..predict.base
import
OnlinePredictor
from
..predict.base
import
OnlinePredictor
...
@@ -25,22 +26,33 @@ class TowerTrainer(Trainer):
...
@@ -25,22 +26,33 @@ class TowerTrainer(Trainer):
This is required by some features that replicates the model
This is required by some features that replicates the model
automatically, e.g. creating a predictor.
automatically, e.g. creating a predictor.
"""
tower_func
=
None
To use features of :class:`TowerTrainer`, set `tower_func` and use it to build the graph.
"""
Note that `tower_func` can only be set once per instance.
A :class:`TowerFuncWrapper` instance.
A callable which takes some input tensors and builds one replicate of the model.
"""
"""
_tower_func
=
None
@
call_only_once
@
call_only_once
def
_set_tower_func
(
self
,
tower_func
):
assert
isinstance
(
tower_func
,
TowerFuncWrapper
),
tower_func
self
.
_tower_func
=
tower_func
@
deprecated
(
"Just use tower_func = xxx instead!"
)
def
set_tower_func
(
self
,
tower_func
):
def
set_tower_func
(
self
,
tower_func
):
self
.
_set_tower_func
(
tower_func
)
@
property
def
tower_func
(
self
):
"""
"""
A
rgs:
A
:class:`TowerFuncWrapper` instance.
tower_func (TowerFuncWrapper)
A callable which takes some input tensors and builds one replicate of the model.
"""
"""
assert
isinstance
(
tower_func
,
TowerFuncWrapper
),
tower_func
return
self
.
_tower_func
self
.
tower_func
=
tower_func
@
tower_func
.
setter
def
tower_func
(
self
,
val
):
self
.
_set_tower_func
(
val
)
@
property
@
property
def
inputs_desc
(
self
):
def
inputs_desc
(
self
):
...
@@ -128,7 +140,7 @@ class SingleCostTrainer(TowerTrainer):
...
@@ -128,7 +140,7 @@ class SingleCostTrainer(TowerTrainer):
"""
"""
get_cost_fn
=
TowerFuncWrapper
(
get_cost_fn
,
inputs_desc
)
get_cost_fn
=
TowerFuncWrapper
(
get_cost_fn
,
inputs_desc
)
get_opt_fn
=
memoized
(
get_opt_fn
)
get_opt_fn
=
memoized
(
get_opt_fn
)
self
.
set_tower_func
(
get_cost_fn
)
self
.
tower_func
=
get_cost_fn
# TODO setup may want to register monitor as well??
# TODO setup may want to register monitor as well??
input_callbacks
=
self
.
_setup_input
(
inputs_desc
,
input
)
input_callbacks
=
self
.
_setup_input
(
inputs_desc
,
input
)
...
...
tensorpack/utils/argtools.py
View file @
712ea325
...
@@ -147,19 +147,25 @@ _FUNC_CALLED = set()
...
@@ -147,19 +147,25 @@ _FUNC_CALLED = set()
def
call_only_once
(
func
):
def
call_only_once
(
func
):
"""
"""
Decorate a method of a class, so that this method can only
Decorate a method o
r property o
f a class, so that this method can only
be called once for every instance.
be called once for every instance.
Calling it more than once will result in exception.
Calling it more than once will result in exception.
"""
"""
@
functools
.
wraps
(
func
)
@
functools
.
wraps
(
func
)
def
wrapper
(
*
args
,
**
kwargs
):
def
wrapper
(
*
args
,
**
kwargs
):
self
=
args
[
0
]
self
=
args
[
0
]
assert
hasattr
(
self
,
func
.
__name__
),
"call_only_once can only be used on method!"
# cannot use hasattr here, because hasattr tries to getattr, which
# fails if func is a property
assert
func
.
__name__
in
dir
(
self
),
"call_only_once can only be used on method or property!"
cls
=
type
(
self
)
# cannot use ismethod(), because decorated method becomes a function
is_method
=
inspect
.
isfunction
(
getattr
(
cls
,
func
.
__name__
))
key
=
(
self
,
func
)
key
=
(
self
,
func
)
assert
key
not
in
_FUNC_CALLED
,
\
assert
key
not
in
_FUNC_CALLED
,
\
"Method {}.{} can only be called once per object!"
.
format
(
"{} {}.{} can only be called once per object!"
.
format
(
type
(
self
)
.
__name__
,
func
.
__name__
)
'Method'
if
is_method
else
'Property'
,
cls
.
__name__
,
func
.
__name__
)
_FUNC_CALLED
.
add
(
key
)
_FUNC_CALLED
.
add
(
key
)
return
func
(
*
args
,
**
kwargs
)
return
func
(
*
args
,
**
kwargs
)
...
@@ -169,13 +175,32 @@ def call_only_once(func):
...
@@ -169,13 +175,32 @@ def call_only_once(func):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
class
A
():
class
A
():
def
__init__
(
self
):
self
.
_p
=
0
@
call_only_once
@
call_only_once
def
f
(
self
,
x
):
def
f
(
self
,
x
):
print
(
x
)
print
(
x
)
@
property
def
p
(
self
):
return
self
.
_p
@
p
.
setter
@
call_only_once
def
p
(
self
,
val
):
self
.
_p
=
val
a
=
A
()
a
=
A
()
a
.
f
(
1
)
a
.
f
(
1
)
b
=
A
()
b
=
A
()
b
.
f
(
2
)
b
.
f
(
2
)
b
.
f
(
1
)
b
.
f
(
1
)
print
(
b
.
p
)
print
(
b
.
p
)
b
.
p
=
2
print
(
b
.
p
)
b
.
p
=
3
print
(
b
.
p
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment