seminar-breakout / Commits

Commit 118c2a26
authored May 10, 2017 by Yuxin Wu

clean-up some deprecations

parent 01486c39
Showing 8 changed files with 29 additions and 58 deletions (+29, -58)
examples/GAN/README.md            +1  -1
tensorpack/models/model_desc.py   +12 -20
tensorpack/train/config.py        +3  -28
tensorpack/train/feedfree.py      +3  -2
tensorpack/train/input_source.py  +1  -1
tensorpack/train/multigpu.py      +4  -3
tensorpack/train/trainer.py       +1  -1
tests/dev/git-hooks/pre-commit    +4  -2
examples/GAN/README.md  (+1 -1)

```diff
@@ -18,7 +18,7 @@ Reproduce the following GAN-related methods:
 + BEGAN ([BEGAN: Boundary Equilibrium Generative Adversarial Networks](https://arxiv.org/abs/1703.10717))
 
 Please see the __docstring__ in each script for detailed usage and pretrained models.
-
+MultiGPU training is supported.
 
 ## DCGAN.py
```
tensorpack/models/model_desc.py  (+12 -20)

```diff
@@ -10,14 +10,11 @@ import six
 from ..utils import logger
 from ..utils.naming import INPUTS_KEY
-from ..utils.develop import deprecated, log_deprecated
 from ..utils.argtools import memoized
 from ..tfutils.model_utils import apply_slim_collections
 
 __all__ = ['InputDesc', 'InputVar', 'ModelDesc', 'ModelFromMetaGraph']
 
 
+# TODO "variable" is not the right name to use for input here.
 class InputDesc(object):
     """ Store metadata about input placeholders. """
@@ -50,7 +47,8 @@ class InputVar(InputDesc):
 @six.add_metaclass(ABCMeta)
 class ModelDesc(object):
-    """ Base class for a model description """
+    """ Base class for a model description.
+    """
 
     # inputs:
     @memoized
@@ -63,11 +61,6 @@ class ModelDesc(object):
         """
         return self.build_placeholders()
 
-    @deprecated("Use get_reused_placehdrs() instead.", "2017-04-11")
-    def get_input_vars(self):
-        # this wasn't a public API anyway
-        return self.get_reused_placehdrs()
-
     def build_placeholders(self, prefix=''):
         """
         For each InputDesc, create new placeholders with optional prefix and
@@ -76,12 +69,12 @@ class ModelDesc(object):
         Returns:
             list[tf.Tensor]: the list of built placeholders.
         """
-        input_vars = self._get_inputs()
-        for v in input_vars:
+        inputs = self._get_inputs()
+        for v in inputs:
             tf.add_to_collection(INPUTS_KEY, v.dumps())
         ret = []
         with tf.name_scope(None):   # clear any name scope it might get called in
-            for v in input_vars:
+            for v in inputs:
                 placehdr_f = tf.placeholder if not v.sparse else tf.sparse_placeholder
                 ret.append(placehdr_f(
                     v.type, shape=v.shape,
@@ -95,15 +88,11 @@ class ModelDesc(object):
         """
         return self._get_inputs()
 
-    def _get_inputs(self):   # this is a better name than _get_input_vars
+    @abstractmethod
+    def _get_inputs(self):
         """
         :returns: a list of InputDesc
         """
-        log_deprecated("", "_get_input_vars() was renamed to _get_inputs().", "2017-04-11")
-        return self._get_input_vars()
-
-    def _get_input_vars(self):  # keep backward compatibility
-        raise NotImplementedError()
 
     def build_graph(self, model_inputs):
         """
@@ -142,8 +131,8 @@ class ModelDesc(object):
     def get_optimizer(self):
         """
         Return the optimizer used in the task.
-        Used by some of the tensorpack :class:`Trainer` which only uses a single optimizer.
-        You can ignore this method if you use your own trainer with more than one optimizers.
+        Used by some of the tensorpack :class:`Trainer` which assume single optimizer.
+        You can (and should) ignore this method if you use a custom trainer with more than one optimizers.
 
         Users of :class:`ModelDesc` will need to implement `_get_optimizer()`,
         which will only be called once per each model.
@@ -157,6 +146,9 @@ class ModelDesc(object):
         raise NotImplementedError()
 
     def get_gradient_processor(self):
-        return []
+        return self._get_gradient_processor()
+
+    def _get_gradient_processor(self):
+        return []
```
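After this change `_get_inputs()` is enforced as abstract and the optimizer is owned by the model via `_get_optimizer()`. A minimal sketch of what a `ModelDesc` subclass looks like under the post-cleanup API; the model itself is hypothetical, and tensorpack at this commit plus TF 1.x (with top-level re-exports of `InputDesc`/`ModelDesc`) are assumed:

```python
# Hypothetical ModelDesc subclass, sketched against the API as of this commit.
import tensorflow as tf
from tensorpack import InputDesc, ModelDesc


class LinearModel(ModelDesc):
    def _get_inputs(self):
        # Now required by @abstractmethod; one InputDesc per placeholder.
        return [InputDesc(tf.float32, (None, 784), 'input'),
                InputDesc(tf.int32, (None,), 'label')]

    def _build_graph(self, inputs):
        image, label = inputs
        logits = tf.layers.dense(image, 10)
        self.cost = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=logits, labels=label),
            name='cost')

    def _get_optimizer(self):
        # Replaces both TrainConfig(optimizer=...) and the removed
        # config.optimizer property; trainers call model.get_optimizer().
        return tf.train.AdamOptimizer(1e-3)
```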
tensorpack/train/config.py  (+3 -28)

```diff
@@ -2,8 +2,6 @@
 # File: config.py
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
 
-import tensorflow as tf
-
 from ..callbacks import (
     Callbacks, MovingAverageSummary,
     ProgressBar, MergeAllSummaries,
@@ -15,7 +13,6 @@ from ..utils.develop import log_deprecated
 from ..tfutils import (JustCurrentSession,
                        get_default_sess_config, SessionInit)
 from ..tfutils.sesscreate import NewSessionCreator
-from ..tfutils.optimizer import apply_grad_processors
 from .input_source import InputSource
 
 __all__ = ['TrainConfig']
@@ -154,15 +151,9 @@ class TrainConfig(object):
         assert len(set(self.predict_tower)) == len(self.predict_tower), \
             "Cannot have duplicated predict_tower!"
 
-        if 'optimizer' in kwargs:
-            log_deprecated("TrainConfig(optimizer=...)",
-                           "Use ModelDesc._get_optimizer() instead.",
-                           "2017-04-12")
-            self._optimizer = kwargs.pop('optimizer')
-            assert_type(self._optimizer, tf.train.Optimizer)
-        else:
-            self._optimizer = None
+        assert 'optimizer' not in kwargs, \
+            "TrainConfig(optimizer=...) was already deprecated! " \
+            "Use ModelDesc._get_optimizer() instead."
 
         assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
 
     @property
@@ -176,19 +167,3 @@ class TrainConfig(object):
     @property
     def callbacks(self):        # disable setter
         return self._callbacks
-
-    @property
-    def optimizer(self):
-        """ for back-compatibilty only. will remove in the future"""
-        if self._optimizer:
-            opt = self._optimizer
-        else:
-            opt = self.model.get_optimizer()
-        gradproc = self.model.get_gradient_processor()
-        if gradproc:
-            log_deprecated("ModelDesc.get_gradient_processor()",
-                           "Use gradient processor to build an optimizer instead.",
-                           "2017-04-12")
-            opt = apply_grad_processors(opt, gradproc)
-        if not self._optimizer:
-            self._optimizer = opt
-        return opt
```
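The deprecation shim is gone: passing `optimizer=` to `TrainConfig` now fails immediately instead of logging a warning. A before/after sketch, reusing the hypothetical `LinearModel` above (`df` stands for any DataFlow; neither name comes from this diff):

```python
# Sketch of the migration this hunk enforces.
from tensorpack import TrainConfig

# Before: deprecated but tolerated, with a log_deprecated() warning:
#     TrainConfig(dataflow=df, model=LinearModel(),
#                 optimizer=tf.train.AdamOptimizer(1e-3))
# After this commit the same call raises:
#     AssertionError: TrainConfig(optimizer=...) was already deprecated!
#     Use ModelDesc._get_optimizer() instead.
config = TrainConfig(dataflow=df, model=LinearModel())  # optimizer lives in the model
```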
tensorpack/train/feedfree.py  (+3 -2)

```diff
@@ -64,7 +64,7 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
         """ get the cost and gradient"""
         self.build_train_tower()
         cost = self.model.get_cost()    # assume single cost
-        opt = self.config.optimizer  # TODO XXX
+        opt = self.model.get_optimizer()
         # GATE_NONE faster?
         grads = opt.compute_gradients(
             cost,
@@ -96,7 +96,8 @@ class SimpleFeedfreeTrainer(SingleCostFeedfreeTrainer):
         super(SimpleFeedfreeTrainer, self)._setup()
         with TowerContext('', is_training=True):
             cost, grads = self._get_cost_and_grad()
-        self.train_op = self.config.optimizer.apply_gradients(grads, name='min_op')
+        opt = self.model.get_optimizer()
+        self.train_op = opt.apply_gradients(grads, name='min_op')
 
         # skip training
         # self.train_op = tf.group(*self._input_tensors)
```
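Every trainer call site in this commit (here, and in multigpu.py and trainer.py below) makes the same substitution: the deleted `self.config.optimizer` property becomes `self.model.get_optimizer()`, so the model is the single owner of the optimizer. The compute/apply split that `SingleCostFeedfreeTrainer` keeps is plain TF 1.x; a standalone sketch of why the two steps are separated:

```python
# TF 1.x sketch of the compute_gradients/apply_gradients split used above;
# splitting lets the trainer inspect or post-process the gradient list
# (gradient processors, summaries) before building the update op.
import tensorflow as tf

w = tf.Variable([1.0, 2.0], name='w')
cost = tf.reduce_sum(tf.square(w), name='cost')
opt = tf.train.GradientDescentOptimizer(0.1)
grads = opt.compute_gradients(cost)                  # [(gradient, variable), ...]
train_op = opt.apply_gradients(grads, name='min_op')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(train_op)                               # one gradient step on w
```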
tensorpack/train/input_source.py  (+1 -1)

```diff
@@ -364,8 +364,8 @@ class StagingInputWrapper(FeedfreeInput):
             devices: list of devices to be used for each training tower
             nr_stage: number of elements to prefetch
         """
+        assert isinstance(input, FeedfreeInput), input
         self._input = input
-        assert isinstance(input, FeedfreeInput)
         self._devices = devices
         self._nr_stage = nr_stage
         self._areas = []
```
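Besides validating before assignment, the new assert passes `input` itself as the assertion message, so a failure prints the offending object. A small illustration with hypothetical values:

```python
# Hypothetical illustration of "assert isinstance(x, T), x":
# the repr of the bad argument ends up in the traceback.
x = 'a DataFlow, not a FeedfreeInput'
try:
    assert isinstance(x, int), x
except AssertionError as e:
    print(e)   # -> a DataFlow, not a FeedfreeInput
```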
tensorpack/train/multigpu.py  (+4 -3)

```diff
@@ -167,7 +167,7 @@ class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase, SingleCostFeedfree
         grads = SyncMultiGPUTrainerParameterServer._average_grads(grad_list)
         # grads = grad_list[0]
 
-        self.train_op = self.config.optimizer.apply_gradients(grads, name='min_op')
+        self.train_op = self.model.get_optimizer().apply_gradients(grads, name='min_op')
 
 
 def SyncMultiGPUTrainer(config):
@@ -217,7 +217,8 @@ class AsyncMultiGPUTrainer(MultiGPUTrainerBase,
             grad_list = [gradproc.process(gv) for gv in grad_list]
 
         # use grad from the first tower for iteration in main thread
-        self.train_op = self.config.optimizer.apply_gradients(grad_list[0], name='min_op')
+        self._opt = self.model.get_optimizer()
+        self.train_op = self._opt.apply_gradients(grad_list[0], name='min_op')
 
         self._start_async_threads(grad_list)
@@ -227,7 +228,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainerBase,
         self.async_step_counter = itertools.count()
         self.training_threads = []
         for k in range(1, self.config.nr_tower):
-            train_op = self.config.optimizer.apply_gradients(grad_list[k])
+            train_op = self._opt.apply_gradients(grad_list[k])
 
             def f(op=train_op):  # avoid late-binding
                 self.sess.run([op])  # TODO this won't work with StageInput
```
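In the async trainer the optimizer is fetched once, cached as `self._opt`, and reused for every tower's `apply_gradients()`. Sharing one instance matters in TF 1.x because each optimizer object owns its slot variables; a sketch of the pattern:

```python
# TF 1.x sketch: one optimizer instance shared across per-tower update ops,
# so slot variables (e.g. Adam moments) are created once, not once per tower.
import tensorflow as tf

w = tf.Variable(1.0, name='w')
opt = tf.train.AdamOptimizer(1e-3)       # created once, like self._opt
tower_grads = [[(tf.constant(0.1), w)],  # stand-ins for per-tower gradients
               [(tf.constant(0.2), w)]]
train_ops = [opt.apply_gradients(g) for g in tower_grads]  # reuse the same opt
```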
tensorpack/train/trainer.py  (+1 -1)

```diff
@@ -42,5 +42,5 @@ class SimpleTrainer(Trainer):
         model.build_graph(self.inputs)
         cost_var = model.get_cost()
 
-        opt = self.config.optimizer
+        opt = self.model.get_optimizer()
         self.train_op = opt.minimize(cost_var, name='min_op')
```
tests/dev/git-hooks/pre-commit  (+4 -2)

```diff
@@ -4,7 +4,9 @@ flake8 .
 cd examples
 GIT_ARG="--git-dir ../.git --work-tree .."
-# find out modified python files
+# find out modified python files, so that we ignored unstaged files
 MOD=$(git $GIT_ARG status -s | grep -E '\.py$' | grep -E '^ *M|^ *A ' | cut -c 4-)
 # git $GIT_ARG status -s | grep -E '\.py$'
-flake8 $MOD
+if [[ -n $MOD ]]; then
+	flake8 $MOD
+fi
```
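The `[[ -n $MOD ]]` guard matters because `flake8` invoked with no file arguments falls back to checking the whole current directory tree, which the hook's first line has presumably already done; with the guard, a commit that touches no Python files skips the second lint pass entirely.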