Commit 25b31f68
authored Jul 01, 2018 by Yuxin Wu
clean-up trainv1
parent 0b2f3c11
Showing 10 changed files with 9 additions and 668 deletions (+9 / -668)
examples/FasterRCNN/README.md        +5 / -1
tensorpack/train/base.py             +4 / -18
tensorpack/trainv1/__init__.py       +0 / -31
tensorpack/trainv1/base.py           +0 / -292
tensorpack/trainv1/config.py         +0 / -7
tensorpack/trainv1/distributed.py    +0 / -98
tensorpack/trainv1/interface.py      +0 / -7
tensorpack/trainv1/multigpu.py       +0 / -140
tensorpack/trainv1/simple.py         +0 / -68
tensorpack/trainv1/utility.py        +0 / -6
examples/FasterRCNN/README.md (view file @ 25b31f68)

```diff
 # Faster-RCNN / Mask-RCNN on COCO

-This example provides a minimal (only 1.6k lines) and faithful implementation of the following papers:
+This example provides a minimal (<2k lines) and faithful implementation of the following papers:

 + [Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks](https://arxiv.org/abs/1506.01497)
 + [Feature Pyramid Networks for Object Detection](https://arxiv.org/abs/1612.03144)
 + [Mask R-CNN](https://arxiv.org/abs/1703.06870)
+
+with the support of:
++ Multi-GPU / distributed training
++ [Cross-GPU BatchNorm](https://arxiv.org/abs/1711.07240)

 ## Dependencies
 + Python 3; TensorFlow >= 1.6 (1.4 or 1.5 can run but may crash due to a TF bug);
 + [pycocotools](https://github.com/pdollar/coco/tree/master/PythonAPI/pycocotools), OpenCV.
 ...
```
tensorpack/train/base.py (view file @ 25b31f68)

```diff
@@ -101,11 +101,6 @@ class Trainer(object):
     """

     def __init__(self):
-        """
-        config is only for compatibility reasons in case you're
-        using custom trainers with old-style API.
-        You should never use config.
-        """
         self._callbacks = []
         self.loop = TrainLoop()

@@ -310,22 +305,13 @@ class Trainer(object):
             session_creator, session_init,
             steps_per_epoch, starting_epoch, max_epoch)

-    # create the old trainer when called with TrainConfig
     def __new__(cls, *args, **kwargs):
         if (len(args) > 0 and isinstance(args[0], TrainConfig)) \
                 or 'config' in kwargs:
-            name = cls.__name__
-            try:
-                import tensorpack.trainv1 as old_train_mod    # noqa
-                old_trainer = getattr(old_train_mod, name)
-            except AttributeError:
-                # custom trainer. has to live with it
-                return super(Trainer, cls).__new__(cls)
-            else:
-                logger.warn("You're calling new trainers with old trainer API!")
-                logger.warn("Now it returns the old trainer for you, please switch to use new trainers soon!")
-                logger.warn("See https://github.com/tensorpack/tensorpack/issues/458 for more information.")
-                return old_trainer(*args, **kwargs)
+            logger.error("You're calling new trainers with old trainer API!")
+            logger.error("See https://github.com/tensorpack/tensorpack/issues/458 for more information.")
+            import sys
+            sys.exit(1)
         else:
             return super(Trainer, cls).__new__(cls)
```
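With this change, passing a `TrainConfig` (or a `config=` keyword) to a new-style trainer no longer falls back to `tensorpack.trainv1`; it logs an error and exits. A minimal sketch of the old and new calling conventions, assuming tensorpack's `TrainConfig`, `SimpleTrainer` and `launch_train_with_config` as of this commit (`MyModel` and `my_dataflow` are placeholders):

```python
from tensorpack import TrainConfig, SimpleTrainer, launch_train_with_config

# MyModel / my_dataflow are placeholders for a ModelDesc and a DataFlow.
config = TrainConfig(model=MyModel(), dataflow=my_dataflow, max_epoch=10)

# Old (trainv1) convention: after this commit it logs an error and calls sys.exit(1):
#   trainer = SimpleTrainer(config)
#   trainer.train()

# New convention: construct the trainer without a config and launch with it:
launch_train_with_config(config, SimpleTrainer())
```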
tensorpack/trainv1/__init__.py (deleted, 100644 → 0; view file @ 0b2f3c11)

```python
# -*- coding: utf-8 -*-
# File: __init__.py

from pkgutil import iter_modules
import os
import os.path

__all__ = []


def global_import(name):
    p = __import__(name, globals(), locals(), level=1)
    lst = p.__all__ if '__all__' in dir(p) else []
    del globals()[name]
    for k in lst:
        globals()[k] = p.__dict__[k]
        __all__.append(k)


_CURR_DIR = os.path.dirname(__file__)
_SKIP = ['utility']
for _, module_name, _ in iter_modules(
        [_CURR_DIR]):
    srcpath = os.path.join(_CURR_DIR, module_name + '.py')
    if not os.path.isfile(srcpath):
        continue
    if module_name.startswith('_'):
        continue
    if module_name not in _SKIP:
        global_import(module_name)
```
tensorpack/trainv1/base.py (deleted, 100644 → 0; view file @ 0b2f3c11)

```python
# -*- coding: utf-8 -*-
# File: base.py

import time
import weakref
import six
from six.moves import range
import tensorflow as tf

from .config import TrainConfig
from ..utils import logger
from ..utils.develop import log_deprecated
from ..callbacks import Callback, Callbacks
from ..callbacks.monitor import Monitors, TrainingMonitor
from ..tfutils.model_utils import describe_trainable_vars
from ..tfutils.sesscreate import ReuseSessionCreator
from ..tfutils.sessinit import JustCurrentSession
from ..tfutils.tower import TowerFuncWrapper
from ..input_source import PlaceholderInput
from ..graph_builder.predict import SimplePredictBuilder
from ..predict.base import OnlinePredictor
from ..callbacks.steps import MaintainStepCounter
from ..train.base import StopTraining, TrainLoop

__all__ = ['Trainer', 'StopTraining']


class Trainer(object):
    """ Base class for a trainer.

    Attributes:
        config (TrainConfig): the config used in this trainer.
        model (ModelDesc): alias for ``config.model``.
        sess (tf.Session): the current session in use.
        hooked_sess (tf.train.MonitoredSession): the session with hooks.
        monitors (Monitors): the monitors. Other callbacks can use it for logging.
    """

    is_chief = True
    """
    Whether this process is the chief worker in distributed training.
    Only chief worker will run some callbacks.
    """

    def __init__(self, config):
        """
        Args:
            config (TrainConfig): the train config.
        """
        assert isinstance(config, TrainConfig), type(config)
        config._deprecated_parsing()
        self._config = config
        self.model = config.model
        if self.model is not None:

            def f(*inputs):
                self.model.build_graph(*inputs)

            """
            Only to mimic new trainer interafce on inference.
            """
            self.inputs_desc = self.model.get_inputs_desc()
            self.tower_func = TowerFuncWrapper(f, self.inputs_desc)

        self._callbacks = []
        self._monitors = []
        self.loop = TrainLoop()
        self.loop.config(config.steps_per_epoch, config.starting_epoch, config.max_epoch)

        self._setup()   # subclass will setup the graph and InputSource

    def register_callback(self, cb):
        """
        Register a callback to the trainer.
        It can only be called before :meth:`Trainer.train` gets called.
        """
        assert isinstance(cb, Callback), cb
        assert not isinstance(self._callbacks, Callbacks), \
            "Cannot register more callbacks after trainer was setup!"
        if not self.is_chief and cb.chief_only:
            logger.warn("Callback {} is chief-only, skipped.".format(str(cb)))
        else:
            self._callbacks.append(cb)

    def register_monitor(self, mon):
        """
        Register a monitor to the trainer.
        It can only be called before :meth:`Trainer.train` gets called.
        """
        assert isinstance(mon, TrainingMonitor), mon
        assert not isinstance(self._monitors, Monitors), \
            "Cannot register more monitors after trainer was setup!"
        if not self.is_chief and mon.chief_only:
            logger.warn("Monitor {} is chief-only, skipped.".format(str(mon)))
        else:
            self._monitors.append(mon)
            self.register_callback(mon)

    @property
    def monitors(self):
        assert isinstance(self._monitors, Monitors), "Monitors haven't been setup!"
        return self._monitors

    def train(self):
        """ Start training """
        self.setup()
        self.main_loop()

    def run_step(self):
        """
        Defines what to do in one iteration. The default is:
        ``self.hooked_sess.run(self.train_op)``.

        The behavior can be changed by either defining what is ``train_op``,
        or overriding this method.
        """
        if not hasattr(self, 'train_op'):
            raise NotImplementedError(
                "Please either set `Trainer.train_op` or provide an implementation "
                "of Trainer.run_step()!")
        self.hooked_sess.run(self.train_op)

    def setup(self):
        """
        Setup the trainer and be ready for the main loop.
        """
        self.register_callback(MaintainStepCounter())
        for cb in self._config.callbacks:
            self.register_callback(cb)
        for m in self._config.monitors:
            self.register_monitor(m)
        self._monitors = Monitors(self._monitors)
        self.register_callback(self._monitors)

        describe_trainable_vars()

        # some final operations that might modify the graph
        logger.info("Setup callbacks graph ...")
        self._callbacks = Callbacks(self._callbacks)
        self._callbacks.setup_graph(weakref.proxy(self))

        self._config.session_init._setup_graph()

        logger.info("Creating the session ...")
        self._create_session()

        if self.is_chief:
            logger.info("Initializing the session ...")
            self._config.session_init._run_init(self.sess)
        else:
            if not isinstance(self._config.session_init, JustCurrentSession):
                logger.warn("This is not a chief worker, 'session_init' was ignored!")

        self.sess.graph.finalize()
        logger.info("Graph Finalized.")

    def _create_session(self):
        """
        Setup self.sess (the raw tf.Session)
        and self.hooked_sess (the session with hooks and coordinator)
        """
        hooks = self._callbacks.get_hooks()
        self.sess = self._config.session_creator.create_session()
        self.hooked_sess = tf.train.MonitoredSession(
            session_creator=ReuseSessionCreator(self.sess), hooks=hooks)

    def _setup(self):
        """
        Build the entire graph for training.
        Responsible for setup InputSource as well (including registering InputSource callbacks)

        Since this method will get called in constructor only,
        you can simply leave it empty and build your graph outside the trainer.
        """
        pass

    def main_loop(self):
        """
        Run the main training loop.
        """
        with self.sess.as_default():
            self.loop.update_global_step()
            try:
                self._callbacks.before_train()
                # refresh global step (might have changed by callbacks) TODO ugly
                self.loop.update_global_step()
                for self.loop._epoch_num in range(
                        self.loop.starting_epoch, self.loop.max_epoch + 1):
                    logger.info("Start Epoch {} ...".format(self.loop.epoch_num))
                    start_time = time.time()
                    self._callbacks.before_epoch()
                    for self.loop._local_step in range(self.loop.steps_per_epoch):
                        if self.hooked_sess.should_stop():
                            return
                        self.run_step()  # implemented by subclass
                        self._callbacks.trigger_step()
                    self._callbacks.after_epoch()
                    logger.info("Epoch {} (global_step {}) finished, time:{:.2f} sec.".format(
                        self.loop.epoch_num, self.loop.global_step, time.time() - start_time))

                    # trigger epoch outside the timing region.
                    self._callbacks.trigger_epoch()
                logger.info("Training has finished!")
            except (StopTraining, tf.errors.OutOfRangeError):
                logger.info("Training was stopped.")
            except KeyboardInterrupt:
                logger.info("Detected Ctrl-C and exiting main loop.")
                raise
            finally:
                self._callbacks.after_train()
                self.hooked_sess.close()

    def get_predictor(self, input_names, output_names, tower=0):
        """
        Returns a callable predictor built under ``TowerContext(is_training=False)``.

        Args:
            input_names (list), output_names(list): list of names
            tower (int): build the predictor on device '/gpu:{tower}' or use -1 for '/cpu:0'.

        Returns:
            an :class:`OnlinePredictor`.
        """
        device = tower
        assert self.tower_func is not None, "Must set tower_func on the trainer to use get_predictor()!"
        tower_name = 'tower-pred-{}'.format(device) if device >= 0 else 'tower-pred-cpu'

        try:
            tower = self.tower_func.towers[tower_name]
        except KeyError:
            input = PlaceholderInput()
            input.setup(self.inputs_desc)

            with tf.variable_scope(tf.get_variable_scope(), reuse=True):
                SimplePredictBuilder(
                    ns_name=tower_name, vs_name=self._main_tower_vs_name,
                    device=device).build(input, self.tower_func)
            tower = self.tower_func.towers[tower_name]
        input_tensors = tower.get_tensors(input_names)
        output_tensors = tower.get_tensors(output_names)
        return OnlinePredictor(input_tensors, output_tensors)

    @property
    def _main_tower_vs_name(self):
        # The vs name a predictor should be built under.
        # for internal use only. Should let graphbuilder return it.
        return ""

    @property
    def config(self):
        log_deprecated(
            "Trainer.config",
            "It is supposed to be private! Most of its attributes can be accessed by other means.",
            "2017-12-31")
        return self._config

    # create new trainer when not called with TrainConfig
    def __new__(cls, *args, **kwargs):
        if (len(args) > 0 and isinstance(args[0], TrainConfig)) \
                or 'config' in kwargs:
            return super(Trainer, cls).__new__(cls)
        else:
            import tensorpack.train as new_train
            name = cls.__name__
            new_trainer = getattr(new_train, name)
            logger.warn("You're calling old trainers with new trainer API!")
            logger.warn("Now it returns the new trainer for you, please `export TENSORPACK_TRAIN_API=v2`"
                        " to import new trainers automatically.")
            logger.warn("You can also ignore this warning and wait for new API to become the default.")
            return new_trainer(*args, **kwargs)


def _get_property(name):
    """
    Delegate property to self.loop
    """
    ret = property(
        lambda self: getattr(self.loop, name))
    if six.PY3:     # __doc__ is readonly in Py2
        try:
            ret.__doc__ = getattr(TrainLoop, name).__doc__
        except AttributeError:
            pass
    return ret


for name in ['global_step', 'local_step', 'steps_per_epoch',
             'epoch_num', 'starting_epoch', 'max_epoch']:
    setattr(Trainer, name, _get_property(name))
```
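The `get_predictor` docstring above describes building an inference predictor from the trained graph. A hedged usage sketch (the tensor names `'input'` and `'prob'`, the `trainer` object and `image_batch` are illustrative placeholders, not from this commit):

```python
# Illustrative only: assumes `trainer` is a (now removed) trainv1 Trainer whose
# graph defines tensors named 'input' and 'prob', and that setup() has already run.
predictor = trainer.get_predictor(input_names=['input'], output_names=['prob'], tower=0)
prob_batch = predictor(image_batch)[0]   # OnlinePredictor is callable on a datapoint
```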
tensorpack/trainv1/config.py (deleted, 100644 → 0; view file @ 0b2f3c11)

```python
# -*- coding: utf-8 -*-
# File: config.py

__all__ = ['TrainConfig']

from ..train.config import TrainConfig
```
tensorpack/trainv1/distributed.py (deleted, 100644 → 0; view file @ 0b2f3c11)

```python
# -*- coding: utf-8 -*-
# File: distributed.py

import os

from ..utils import logger
from ..callbacks import RunOp
from ..tfutils.sesscreate import NewSessionCreator
from ..tfutils import get_global_step_var
from ..tfutils.distributed import get_distributed_session_creator
from ..graph_builder.distributed import DistributedReplicatedBuilder
from ..graph_builder.utils import override_to_local_variable
from .base import Trainer

__all__ = ['DistributedTrainerReplicated']


class DistributedTrainerReplicated(Trainer):

    __doc__ = DistributedReplicatedBuilder.__doc__

    def __init__(self, config, server):
        """
        Args:
            config(TrainConfig): Must contain 'model' and 'data'.
            server(tf.train.Server): the server object with ps and workers
        """
        assert config.data is not None and config.model is not None

        self.server = server
        self.job_name = server.server_def.job_name
        assert self.job_name in ['ps', 'worker'], self.job_name

        if self.job_name == 'worker':
            # ps doesn't build any graph
            self._builder = DistributedReplicatedBuilder(config.tower, server)
            self.is_chief = self._builder.is_chief
        else:
            self.is_chief = False
        logger.info("Distributed training on cluster:\n" + str(server.server_def.cluster))

        self._input_source = config.data

        super(DistributedTrainerReplicated, self).__init__(config)

    def _setup(self):
        if self.job_name == 'ps':
            logger.info("Running ps {}".format(self.server.server_def.task_index))
            logger.info("Kill me with 'kill {}'".format(os.getpid()))
            self.server.join()  # this will never return tensorflow#4713
            return

        with override_to_local_variable():
            get_global_step_var()  # gs should be local

            # input source may create variable (queue size summary)
            # TODO This is not good because we don't know from here
            # whether something should be global or local. We now assume
            # they should be local.
            cbs = self._input_source.setup(self.model.get_inputs_desc())

        self._config.callbacks.extend(cbs)

        self.train_op, initial_sync_op, model_sync_op = self._builder.build(
            lambda: self.model._build_graph_get_grads(
                *self._input_source.get_input_tensors()),
            self.model.get_optimizer)

        # initial local_vars syncing
        cb = RunOp(lambda: initial_sync_op,
                   run_before=True, run_as_trigger=False, verbose=True)
        cb.chief_only = False
        self.register_callback(cb)

        # model_variables syncing
        if model_sync_op:
            cb = RunOp(lambda: model_sync_op,
                       run_before=False, run_as_trigger=True, verbose=True)
            logger.warn("For efficiency, local MODEL_VARIABLES are only synced to PS once "
                        "every epoch. Be careful if you save the model more frequently than this.")
            self.register_callback(cb)

        self._set_session_creator()

    def _set_session_creator(self):
        old_sess_creator = self._config.session_creator
        if not isinstance(old_sess_creator, NewSessionCreator) \
                or old_sess_creator.user_provided_config:
            raise ValueError(
                "Cannot set session_creator or session_config for distributed training! "
                "To use a custom session config, pass it to tf.train.Server.")

        self._config.session_creator = get_distributed_session_creator(self.server)

    @property
    def _main_tower_vs_name(self):
        return "tower0"
```
tensorpack/trainv1/interface.py (deleted, 100644 → 0; view file @ 0b2f3c11)

```python
# -*- coding: utf-8 -*-
# File: interface.py

__all__ = ['launch_train_with_config']

from ..train.interface import launch_train_with_config
```
tensorpack/trainv1/multigpu.py (deleted, 100644 → 0; view file @ 0b2f3c11)

```python
# -*- coding: utf-8 -*-
# File: multigpu.py

import tensorflow as tf

from ..callbacks.graph import RunOp
from ..utils.develop import log_deprecated
from ..input_source import QueueInput, StagingInput, DummyConstantInput
from ..graph_builder.training import (
    SyncMultiGPUParameterServerBuilder,
    SyncMultiGPUReplicatedBuilder,
    AsyncMultiGPUBuilder,
    DataParallelBuilder)
from .base import Trainer

__all__ = ['MultiGPUTrainerBase',
           'SyncMultiGPUTrainerReplicated',
           'SyncMultiGPUTrainerParameterServer',
           'AsyncMultiGPUTrainer',
           'SyncMultiGPUTrainer']


class MultiGPUTrainerBase(Trainer):
    """
    For backward compatibility only
    """
    def build_on_multi_tower(towers, func, devices=None, use_vs=None):
        log_deprecated("MultiGPUTrainerBase.build_on_multitower",
                       "Please use DataParallelBuilder.build_on_towers", "2018-01-31")
        return DataParallelBuilder.build_on_towers(towers, func, devices, use_vs)


def apply_prefetch_policy(config, gpu_prefetch=True):
    assert (config.data is not None or config.dataflow is not None) and config.model is not None
    if config.data is None and config.dataflow is not None:
        # always use Queue prefetch
        config.data = QueueInput(config.dataflow)
        config.dataflow = None
    if len(config.tower) > 1 and gpu_prefetch:
        assert tf.test.is_gpu_available()

        # seem to only improve on >1 GPUs
        if not isinstance(config.data, (StagingInput, DummyConstantInput)):
            config.data = StagingInput(config.data)


class SyncMultiGPUTrainerParameterServer(Trainer):

    __doc__ = SyncMultiGPUParameterServerBuilder.__doc__

    def __init__(self, config, ps_device='gpu', gpu_prefetch=True):
        """
        Args:
            config(TrainConfig): Must contain 'model' and either one of 'data' or 'dataflow'.
            ps_device: either 'gpu' or 'cpu', where variables are stored. Setting to 'cpu' might help when #gpu>=4
            gpu_prefetch(bool): whether to prefetch the data to each GPU. Usually improve performance.
        """
        apply_prefetch_policy(config, gpu_prefetch)
        self._input_source = config.data
        assert ps_device in ['gpu', 'cpu'], ps_device
        self._ps_device = ps_device
        super(SyncMultiGPUTrainerParameterServer, self).__init__(config)

    def _setup(self):
        callbacks = self._input_source.setup(self.model.get_inputs_desc())

        self.train_op = SyncMultiGPUParameterServerBuilder(
            self._config.tower, self._ps_device).build(
                lambda: self.model._build_graph_get_grads(
                    *self._input_source.get_input_tensors()),
                self.model.get_optimizer)

        self._config.callbacks.extend(callbacks)


def SyncMultiGPUTrainer(config):
    """
    Alias for ``SyncMultiGPUTrainerParameterServer(config, ps_device='gpu')``,
    as this is the most commonly used synchronous multigpu trainer (but may
    not be more efficient than the other).
    """
    return SyncMultiGPUTrainerParameterServer(config, ps_device='gpu')


class SyncMultiGPUTrainerReplicated(Trainer):

    __doc__ = SyncMultiGPUReplicatedBuilder.__doc__

    def __init__(self, config, gpu_prefetch=True):
        """
        Args:
            config, gpu_prefetch: same as in :class:`SyncMultiGPUTrainerParameterServer`
        """
        apply_prefetch_policy(config, gpu_prefetch)
        self._input_source = config.data
        super(SyncMultiGPUTrainerReplicated, self).__init__(config)

    def _setup(self):
        callbacks = self._input_source.setup(self.model.get_inputs_desc())

        self.train_op, post_init_op = SyncMultiGPUReplicatedBuilder(
            self._config.tower).build(
                lambda: self.model._build_graph_get_grads(
                    *self._input_source.get_input_tensors()),
                self.model.get_optimizer)

        cb = RunOp(
            lambda: post_init_op,
            run_before=True, run_as_trigger=True, verbose=True)
        self._config.callbacks.extend(callbacks + [cb])


class AsyncMultiGPUTrainer(Trainer):

    __doc__ = AsyncMultiGPUBuilder.__doc__

    def __init__(self, config, scale_gradient=True):
        """
        Args:
            config(TrainConfig): Must contain 'model' and either one of 'data' or 'dataflow'.
            scale_gradient (bool): if True, will scale each gradient by ``1.0/nr_gpu``.
        """
        apply_prefetch_policy(config)
        self._input_source = config.data
        self._scale_gradient = scale_gradient
        super(AsyncMultiGPUTrainer, self).__init__(config)

    def _setup(self):
        callbacks = self._input_source.setup(self.model.get_inputs_desc())

        self.train_op = AsyncMultiGPUBuilder(
            self._config.tower, self._scale_gradient).build(
                lambda: self.model._build_graph_get_grads(
                    *self._input_source.get_input_tensors()),
                self.model.get_optimizer)

        self._config.callbacks.extend(callbacks)
```
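As the `SyncMultiGPUTrainer` docstring above states, it is simply the parameter-server trainer with `ps_device='gpu'`. A hedged usage sketch under the old trainv1 API (the contents of `config` are placeholders):

```python
# Illustrative only: `config` is a TrainConfig with 'model', 'dataflow' (or 'data')
# and a multi-GPU tower list, e.g. tower=[0, 1].
trainer = SyncMultiGPUTrainer(config)   # same as SyncMultiGPUTrainerParameterServer(config, ps_device='gpu')
trainer.train()

# Keeping variables on CPU instead (sometimes helps with >= 4 GPUs):
#   trainer = SyncMultiGPUTrainerParameterServer(config, ps_device='cpu')
```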
tensorpack/trainv1/simple.py (deleted, 100644 → 0; view file @ 0b2f3c11)

```python
# -*- coding: utf-8 -*-
# File: simple.py

from .base import Trainer
from ..tfutils.tower import TowerContext
from ..utils import logger
from ..input_source import FeedInput, QueueInput

__all__ = ['SimpleTrainer', 'QueueInputTrainer']


class SimpleTrainer(Trainer):
    """ A naive single-tower single-cost demo trainer.
        It simply builds one tower and minimize `model.cost`.
        It supports both InputSource and DataFlow.

        When DataFlow is given instead of InputSource, the InputSource to be
        used will be ``FeedInput(df)`` (no prefetch).
    """
    def __init__(self, config):
        """
        Args:
            config (TrainConfig): Must contain 'model' and either one of 'data' or 'dataflow'.
        """
        assert len(config.tower) == 1, \
            "Got nr_tower={}, but doesn't support multigpu!" \
            " Use Sync/AsyncMultiGPUTrainer instead.".format(len(config.tower))
        assert (config.data is not None or config.dataflow is not None) and config.model is not None
        if config.dataflow is None:
            self._input_source = config.data
        else:
            self._input_source = FeedInput(config.dataflow)
            logger.warn("FeedInput is slow (and this is the default of SimpleTrainer). "
                        "Consider QueueInput or other InputSource instead.")
        super(SimpleTrainer, self).__init__(config)

    def _setup(self):
        cbs = self._input_source.setup(self.model.get_inputs_desc())
        with TowerContext('', is_training=True):
            grads = self.model._build_graph_get_grads(
                *self._input_source.get_input_tensors())
        opt = self.model.get_optimizer()
        self.train_op = opt.apply_gradients(grads, name='min_op')
        self._config.callbacks.extend(cbs)


def QueueInputTrainer(config, input_queue=None):
    """
    A wrapper trainer which automatically wraps ``config.dataflow`` by a :class:`QueueInput`.
    It is an equivalent of ``SimpleTrainer(config)`` with ``config.data = QueueInput(dataflow)``.

    Args:
        config (TrainConfig): Must contain 'model' and 'dataflow'.
        input_queue (tf.QueueBase): an input queue. Defaults to the :class:`QueueInput` default.
    """
    assert (config.data is not None or config.dataflow is not None) and config.model is not None
    if config.data is not None:
        assert isinstance(config.data, QueueInput), config.data
    else:
        config.data = QueueInput(config.dataflow, input_queue)
    config.dataflow = None
    return SimpleTrainer(config)
```
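`QueueInputTrainer` above is just `SimpleTrainer` with the dataflow wrapped in a `QueueInput`; a sketch of that equivalence (the `config` is a placeholder `TrainConfig` with 'model' and 'dataflow'):

```python
# Illustrative only.
trainer = QueueInputTrainer(config)

# Roughly equivalent to wrapping the dataflow yourself:
#   config.data = QueueInput(config.dataflow)
#   config.dataflow = None
#   trainer = SimpleTrainer(config)
```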
tensorpack/trainv1/utility.py (deleted, 100644 → 0; view file @ 0b2f3c11)

```python
# -*- coding: utf-8 -*-
# File: utility.py

# for backwards-compatibility
from ..graph_builder.utils import (    # noqa
    override_to_local_variable, LeastLoadedDeviceSetter)
```