Shashank Suhas / seminar-breakout · Commit a673974c
authored Oct 18, 2017 by Yuxin Wu

initial commit of new trainer interface (#318)

Parent: 82187086

Showing 5 changed files with 552 additions and 11 deletions.
tensorpack/train/config.py       +16  -11
tensorpack/trainv2/__init__.py   +31  -0
tensorpack/trainv2/base.py       +242 -0
tensorpack/trainv2/interface.py  +81  -0
tensorpack/trainv2/trainers.py   +182 -0
tensorpack/train/config.py
```diff
@@ -83,17 +83,10 @@ class TrainConfig(object):
         if callbacks is None:
             callbacks = []
         assert_type(callbacks, list)
-        if extra_callbacks is None:
-            extra_callbacks = [
-                MovingAverageSummary(),
-                ProgressBar(),
-                MergeAllSummaries(),
-                RunUpdateOps()]
-        self._callbacks = callbacks + extra_callbacks
-
-        if monitors is None:
-            monitors = [TFEventWriter(), JSONWriter(), ScalarPrinter()]
-        self.monitors = monitors
+        self._callbacks = callbacks + \
+            (extra_callbacks or TrainConfig.DEFAULT_EXTRA_CALLBACKS())
+        self.monitors = monitors or TrainConfig.DEFAULT_MONITORS()
         if session_init is None:
             session_init = JustCurrentSession()
@@ -155,3 +148,15 @@ class TrainConfig(object):
     @property
     def callbacks(self):        # disable setter
         return self._callbacks
+
+    @staticmethod
+    def DEFAULT_EXTRA_CALLBACKS():
+        return [
+            MovingAverageSummary(),
+            ProgressBar(),
+            MergeAllSummaries(),
+            RunUpdateOps()]
+
+    @staticmethod
+    def DEFAULT_MONITORS():
+        return [TFEventWriter(), JSONWriter(), ScalarPrinter()]
```
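Net effect of this change: the defaults for `extra_callbacks` and `monitors` are unchanged, but they now live in overridable staticmethods instead of being inlined in `__init__`, and `None` arguments fall through to them via `or`. A minimal sketch of the new behavior; `my_model` and `my_dataflow` are assumed placeholders:

```python
from tensorpack.train.config import TrainConfig

config = TrainConfig(
    model=my_model,          # assumption: a ModelDesc instance
    dataflow=my_dataflow,    # assumption: a DataFlow instance
    callbacks=[],            # user callbacks only
    extra_callbacks=None,    # falls back to TrainConfig.DEFAULT_EXTRA_CALLBACKS()
    monitors=None,           # falls back to TrainConfig.DEFAULT_MONITORS()
)
# The defaults are now also inspectable/reusable on their own:
print(TrainConfig.DEFAULT_MONITORS())  # TFEventWriter, JSONWriter, ScalarPrinter instances
```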
tensorpack/trainv2/__init__.py (new file, mode 100644)
```python
# -*- coding: UTF-8 -*-
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

from pkgutil import iter_modules
import os
import os.path

__all__ = []


def global_import(name):
    p = __import__(name, globals(), locals(), level=1)
    lst = p.__all__ if '__all__' in dir(p) else []
    del globals()[name]
    for k in lst:
        globals()[k] = p.__dict__[k]
        __all__.append(k)


_CURR_DIR = os.path.dirname(__file__)
_SKIP = []
for _, module_name, _ in iter_modules([_CURR_DIR]):
    srcpath = os.path.join(_CURR_DIR, module_name + '.py')
    if not os.path.isfile(srcpath):
        continue
    if module_name.startswith('_'):
        continue
    if module_name not in _SKIP:
        global_import(module_name)
```
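This is the usual tensorpack package auto-import: every non-private `.py` module in the directory is imported, and each name in its `__all__` is re-exported from the package namespace. For example (assuming the package is importable), `SimpleTrainer` from `trainers.py` below becomes accessible at the package level:

```python
# trainers.py lists SimpleTrainer in its __all__, so the loop above
# re-exports it from the package itself:
from tensorpack.trainv2 import SimpleTrainer
```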
tensorpack/trainv2/base.py (new file, mode 100644)
```python
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: base.py

import tensorflow as tf
import weakref
import time
from six.moves import range
import six
from abc import abstractmethod, ABCMeta

from ..utils import logger
from ..callbacks import Callback, Callbacks
from ..callbacks.monitor import Monitors, TrainingMonitor
from ..tfutils.model_utils import describe_trainable_vars
from ..tfutils.sessinit import JustCurrentSession
from ..tfutils.sesscreate import ReuseSessionCreator
from ..callbacks.steps import MaintainStepCounter

from ..train.base import StopTraining, TrainLoop

__all__ = ['Trainer', 'SingleCostTrainer']


class Trainer(object):
    """ Base class for a trainer.
    """

    is_chief = True

    def __init__(self):
        self._callbacks = []
        self.loop = TrainLoop()
        self._monitors = []  # Clarify the type. Don't change from list to monitors.

    def _register_callback(self, cb):
        """
        Register a callback to the trainer.
        It can only be called before :meth:`Trainer.train` gets called.
        """
        assert isinstance(cb, Callback), cb
        assert not isinstance(self._callbacks, Callbacks), \
            "Cannot register more callbacks after trainer was setup!"
        if not self.is_chief and cb.chief_only:
            logger.warn("Callback {} is chief-only, skipped.".format(str(cb)))
        else:
            self._callbacks.append(cb)

    def _register_monitor(self, mon):
        """
        Register a monitor to the trainer.
        It can only be called before :meth:`Trainer.train` gets called.
        """
        assert isinstance(mon, TrainingMonitor), mon
        assert not isinstance(self._monitors, Monitors), \
            "Cannot register more monitors after trainer was setup!"
        if not self.is_chief and mon.chief_only:
            logger.warn("Monitor {} is chief-only, skipped.".format(str(mon)))
        else:
            self._register_callback(mon)

    def run_step(self):
        """
        Defines what to do in one iteration. The default is:
        ``self.hooked_sess.run(self.train_op)``.

        The behavior can be changed by either defining what is ``train_op``,
        or overriding this method.
        """
        if not hasattr(self, 'train_op'):
            raise NotImplementedError(
                "Please either set `Trainer.train_op` or provide an implementation "
                "of Trainer.run_step()!")
        self.hooked_sess.run(self.train_op)

    def setup_callbacks(self, callbacks, monitors):
        """
        Setup callbacks and monitors. Must be called after the main graph is built.
        """
        describe_trainable_vars()   # TODO weird

        self._register_callback(MaintainStepCounter())
        for cb in callbacks:
            self._register_callback(cb)
        for m in monitors:
            self._register_monitor(m)
        self.monitors = Monitors(monitors)
        self._register_callback(self.monitors)  # monitors is also a callback

        # some final operations that might modify the graph
        logger.info("Setup callbacks graph ...")
        self._callbacks = Callbacks(self._callbacks)
        self._callbacks.setup_graph(weakref.proxy(self))

    def initialize(self, session_creator, session_init):
        """
        Initialize self.sess and self.hooked_sess.
        Must be called after callbacks are setup.
        """
        logger.info("Creating the session ...")

        hooks = self._callbacks.get_hooks()
        self.sess = session_creator.create_session()
        self.hooked_sess = tf.train.MonitoredSession(
            session_creator=ReuseSessionCreator(self.sess), hooks=hooks)

        if self.is_chief:
            logger.info("Initializing the session ...")
            session_init.init(self.sess)
        else:
            assert isinstance(session_init, JustCurrentSession), \
                "session_init is only valid for chief worker session!"

        self.sess.graph.finalize()
        logger.info("Graph Finalized.")

    def _create_session(self):
        """
        Setup self.sess (the raw tf.Session)
        and self.hooked_sess (the session with hooks and coordinator)
        """

    def main_loop(self, steps_per_epoch, starting_epoch=1, max_epoch=99999):
        """
        Run the main training loop.
        """
        with self.sess.as_default():
            self.loop.config(steps_per_epoch, starting_epoch, max_epoch)
            self.loop.update_global_step()
            try:
                self._callbacks.before_train()
                # refresh global step (might have changed by callbacks) TODO ugly
                # what if gs is changed later?
                self.loop.update_global_step()
                for self.loop._epoch_num in range(
                        self.loop.starting_epoch, self.loop.max_epoch + 1):
                    logger.info("Start Epoch {} ...".format(self.loop.epoch_num))
                    start_time = time.time()
                    self._callbacks.before_epoch()
                    for self.loop._local_step in range(self.loop.steps_per_epoch):
                        if self.hooked_sess.should_stop():
                            return
                        self.run_step()  # implemented by subclass
                        self._callbacks.trigger_step()
                    self._callbacks.after_epoch()
                    logger.info("Epoch {} (global_step {}) finished, time:{:.2f} sec.".format(
                        self.loop.epoch_num, self.loop.global_step,
                        time.time() - start_time))

                    # trigger epoch outside the timing region.
                    self._callbacks.trigger_epoch()
                logger.info("Training has finished!")
            except (StopTraining, tf.errors.OutOfRangeError):
                logger.info("Training was stopped.")
            except KeyboardInterrupt:
                logger.info("Detected Ctrl-C and exiting main loop.")
            except:
                raise
            finally:
                self._callbacks.after_train()
                self.hooked_sess.close()

    def train(self,
              callbacks, monitors,
              session_creator, session_init,
              steps_per_epoch, starting_epoch, max_epoch):
        """
        Implemented by:

        .. code-block:: python

            self.setup_callbacks(callbacks, monitors)
            self.initialize(session_creator, session_init)
            self.main_loop(steps_per_epoch, starting_epoch, max_epoch)

        You can call those methods by yourself to have better control on details if needed.
        """
        self.setup_callbacks(callbacks, monitors)
        self.initialize(session_creator, session_init)
        self.main_loop(steps_per_epoch, starting_epoch, max_epoch)


def _get_property(name):
    """
    Delegate property to self.loop
    """
    ret = property(lambda self: getattr(self.loop, name))
    if six.PY3:     # __doc__ is readonly in Py2
        try:
            ret.__doc__ = getattr(TrainLoop, name).__doc__
        except AttributeError:
            pass
    return ret


for name in ['global_step', 'local_step', 'steps_per_epoch',
             'epoch_num', 'starting_epoch', 'max_epoch']:
    setattr(Trainer, name, _get_property(name))
```
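Two things worth noting from the class above: `run_step()` falls back to `self.hooked_sess.run(self.train_op)`, so the smallest useful subclass only has to provide a `train_op`; and `train()` is documented as exactly three calls that can also be driven manually. A hedged sketch, where everything outside the `Trainer` API (the class name, `my_train_op`, and the argument variables) is an assumption:

```python
from tensorpack.trainv2.base import Trainer


class ConstantOpTrainer(Trainer):
    """Hypothetical minimal trainer: the graph is built elsewhere and hands
    us a train_op; the default run_step() will execute it every iteration."""
    def __init__(self, train_op):
        super(ConstantOpTrainer, self).__init__()
        self.train_op = train_op


trainer = ConstantOpTrainer(my_train_op)        # my_train_op: an assumed tf op
# Equivalent to trainer.train(...), but with control between the phases:
trainer.setup_callbacks(callbacks, monitors)    # after the graph is built
trainer.initialize(session_creator, session_init)
trainer.main_loop(steps_per_epoch, starting_epoch=1, max_epoch=99999)
```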
```python
@six.add_metaclass(ABCMeta)
class SingleCostTrainer(Trainer):
    """
    Base class for single-cost trainer.

    Single-cost trainer has a :meth:`setup_graph` method which takes
    (inputs_desc, input, get_cost_fn, get_opt_fn), and builds the training
    operations from them.

    To use a SingleCostTrainer object, call `trainer.setup_graph(...); trainer.train(...)`.
    """

    def train(self,
              callbacks, monitors,
              session_creator, session_init,
              steps_per_epoch, starting_epoch, max_epoch):
        callbacks = callbacks + self._internal_callbacks
        Trainer.train(
            self,
            callbacks, monitors,
            session_creator, session_init,
            steps_per_epoch, starting_epoch, max_epoch)

    def setup_graph(self, inputs_desc, input, get_cost_fn, get_opt_fn):
        """
        Build the main training graph. Defaults to do nothing.
        You can either override it in subclasses, or build the graph outside
        the trainer.

        Returns:
            [Callback]: a (possibly empty) list of callbacks needed for training.
            These callbacks will be automatically added when you call `train()`.
            So you can usually ignore the return value.
        """
        assert not input.setup_done()
        input_callbacks = input.setup(inputs_desc)
        train_callbacks = self._setup_graph(input, get_cost_fn, get_opt_fn)
        self._internal_callbacks = input_callbacks + train_callbacks
        return self._internal_callbacks

    @abstractmethod
    def _setup_graph(self, input, get_cost_fn, get_opt_fn):
        pass
```
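Per the docstring, a `SingleCostTrainer` is driven in two steps: `setup_graph` wires the input source and collects the internal callbacks, then `train` prepends them to the user callbacks. A sketch of the call sequence; every name here is assumed from context:

```python
# trainer: any concrete SingleCostTrainer subclass (see trainers.py below)
trainer.setup_graph(inputs_desc, input_source, get_cost_fn, get_opt_fn)
trainer.train(callbacks, monitors,
              session_creator, session_init,
              steps_per_epoch, starting_epoch, max_epoch)
```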
tensorpack/trainv2/interface.py (new file, mode 100644)
```python
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: interface.py

import tensorflow as tf

from ..input_source import (
    FeedInput, QueueInput, StagingInputWrapper, DummyConstantInput)
from ..train.config import TrainConfig
from .base import SingleCostTrainer
from .trainers import SimpleTrainer, DistributedTrainerReplicated

__all__ = ['launch_train_with_config', 'TrainConfig']


def _maybe_gpu_prefetch(input, towers, gpu_prefetch):
    # seem to only improve on >1 GPUs
    if len(towers) > 1 and gpu_prefetch:
        assert tf.test.is_gpu_available()

        if not isinstance(input, (StagingInputWrapper, DummyConstantInput)):
            input = StagingInputWrapper(input, towers)
    return input


def launch_train_with_config(config, trainer):
    """
    To mimic the old training interface, with a trainer and a config.

    Args:
        config (TrainConfig):
        trainer (Trainer): an instance of the new trainer

    Examples:

    .. code-block:: python

        # with the old trainer:
        SyncMultiGPUTrainerParameterServer(config, ps_device='gpu').train()
        # with the new trainer:
        launch_train_with_config(
            config, SyncMultiGPUTrainerParameterServer(towers, ps_device='gpu'))
    """
    assert isinstance(trainer, SingleCostTrainer), trainer
    assert isinstance(config, TrainConfig), config
    assert config.model is not None
    assert config.dataflow is not None or config.data is not None

    model = config.model
    inputs_desc = model.get_inputs_desc()
    input = config.data

    # some check & input wrappers to mimic same behavior of the old trainer interface
    if input is None:
        if type(trainer) == SimpleTrainer:
            input = FeedInput(config.dataflow)
        else:
            input = QueueInput(config.dataflow)
    if config.nr_tower > 1:
        assert not isinstance(trainer, SimpleTrainer)
        input = _maybe_gpu_prefetch(input, config.tower, True)

    if isinstance(trainer, DistributedTrainerReplicated) and \
            config.session_config is not None:
        raise ValueError(
            "Cannot set session_config for distributed training! "
            "To use a custom session config, pass it to tf.train.Server.")

    trainer.setup_graph(
        inputs_desc, input,
        model.build_graph_get_cost, model.get_optimizer)
    trainer.train(
        config.callbacks, config.monitors,
        config.session_creator, config.session_init,
        config.steps_per_epoch, config.starting_epoch, config.max_epoch)
```
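So the migration suggested by the docstring is: move the tower/device arguments from the config into the trainer's constructor, then hand both objects to `launch_train_with_config`. A sketch, where `my_config` is an assumed `TrainConfig` with `model` and `dataflow` set (as the asserts above require):

```python
from tensorpack.trainv2.interface import launch_train_with_config
from tensorpack.trainv2.trainers import SimpleTrainer

# With config.data unset, SimpleTrainer gets FeedInput(config.dataflow);
# any other trainer would get QueueInput(config.dataflow).
launch_train_with_config(my_config, SimpleTrainer())
```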
tensorpack/trainv2/trainers.py (new file, mode 100644)
```python
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: trainers.py

import os

from ..callbacks.graph import RunOp
from ..tfutils.sesscreate import NewSessionCreator
from ..graph_builder.training import (
    SimpleBuilder,
    SyncMultiGPUParameterServerBuilder,
    SyncMultiGPUReplicatedBuilder,
    AsyncMultiGPUBuilder,
    DistributedReplicatedBuilder)
from ..graph_builder.utils import override_to_local_variable
from ..utils import logger
from ..tfutils import get_global_step_var
from ..tfutils.distributed import get_distributed_session_creator
from ..input_source import QueueInput

from .base import Trainer, SingleCostTrainer

__all__ = ['SimpleTrainer',
           'QueueInputTrainer',
           'SyncMultiGPUTrainerReplicated',
           'SyncMultiGPUTrainerParameterServer',
           'AsyncMultiGPUTrainer',
           'DistributedTrainerReplicated']


class SimpleTrainer(SingleCostTrainer):
    """
    Single-GPU single-cost single-tower trainer.
    """
    def _setup_graph(self, input, get_cost_fn, get_opt_fn):
        self.train_op = SimpleBuilder().build(input, get_cost_fn, get_opt_fn)
        return []


# Only works for type check
class QueueInputTrainer(SimpleTrainer):
    def _setup_graph(self, input, get_cost_fn, get_opt_fn):
        assert isinstance(input, QueueInput)
        return super(QueueInputTrainer, self)._setup_graph(
            input, get_cost_fn, get_opt_fn)
```
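`QueueInputTrainer` adds nothing but the `isinstance` check, exactly as the comment says: it exists so that a non-queue input fails fast rather than silently training through feeds. A sketch with an assumed dataflow and assumed graph functions:

```python
from tensorpack.input_source import QueueInput

trainer = QueueInputTrainer()
# A FeedInput here would trip the assert in _setup_graph:
trainer.setup_graph(inputs_desc, QueueInput(my_dataflow),
                    get_cost_fn, get_opt_fn)
```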
```python
class SyncMultiGPUTrainerParameterServer(SingleCostTrainer):

    __doc__ = SyncMultiGPUParameterServerBuilder.__doc__

    def __init__(self, towers, ps_device='gpu'):
        """
        Args:
            towers ([int]): list of GPU ids.
            ps_device: either 'gpu' or 'cpu', where variables are stored.
                Setting to 'cpu' might help when #gpu>=4
        """
        self._builder = SyncMultiGPUParameterServerBuilder(towers, ps_device)
        super(SyncMultiGPUTrainerParameterServer, self).__init__()

    def _setup_graph(self, input, get_cost_fn, get_opt_fn):
        self.train_op = self._builder.build(input, get_cost_fn, get_opt_fn)
        return []


class AsyncMultiGPUTrainer(SingleCostTrainer):

    __doc__ = AsyncMultiGPUBuilder.__doc__

    def __init__(self, towers, scale_gradient=True):
        """
        Args:
            towers ([int]): list of GPU ids.
            scale_gradient (bool): if True, will scale each gradient by ``1.0/nr_gpu``.
        """
        self._builder = AsyncMultiGPUBuilder(towers, scale_gradient)
        super(AsyncMultiGPUTrainer, self).__init__()

    def _setup_graph(self, input, get_cost_fn, get_opt_fn):
        self.train_op = self._builder.build(input, get_cost_fn, get_opt_fn)
        return []


class SyncMultiGPUTrainerReplicated(SingleCostTrainer):

    __doc__ = SyncMultiGPUReplicatedBuilder.__doc__

    def __init__(self, towers):
        """
        Args:
            towers ([int]): list of GPU ids.
        """
        self._builder = SyncMultiGPUReplicatedBuilder(towers)
        super(SyncMultiGPUTrainerReplicated, self).__init__()

    def _setup_graph(self, input, get_cost_fn, get_opt_fn):
        self.train_op, post_init_op = self._builder.build(
            input, get_cost_fn, get_opt_fn)
        cb = RunOp(post_init_op, run_before=True, run_as_trigger=True, verbose=True)
        return [cb]


class DistributedTrainerReplicated(SingleCostTrainer):

    __doc__ = DistributedReplicatedBuilder.__doc__

    def __init__(self, towers, server):
        """
        Args:
            towers (list[int]): list of GPU ids.
            server (tf.train.Server): the server with ps and workers.
                The job_name must be 'worker' because 'ps' job doesn't need to
                build any graph.
        """
        self.server = server
        self.job_name = server.server_def.job_name
        assert self.job_name in ['ps', 'worker'], self.job_name

        if self.job_name == 'worker':
            # ps doesn't build any graph
            self._builder = DistributedReplicatedBuilder(towers, server)
            self.is_chief = self._builder.is_chief
        else:
            self.is_chief = False
        logger.info("Distributed training on cluster:\n" + str(server.server_def.cluster))

    def train(self,
              inputs_desc, input, get_cost_fn, get_opt_fn,
              callbacks, monitors,
              session_creator, session_init,
              steps_per_epoch, starting_epoch, max_epoch):
        if self.job_name == 'ps':
            logger.info("Running ps {}".format(self.server.server_def.task_index))
            logger.info("Kill me with 'kill {}'".format(os.getpid()))
            self.server.join()  # this will never return tensorflow#4713
            return

        with override_to_local_variable():
            get_global_step_var()  # gs should be local
            # input source may create variable (queue size summary)
            # TODO This is not good because we don't know from here
            # whether something should be global or local. We now assume
            # they should be local.
            input_callbacks = input.setup(inputs_desc)
        train_callbacks = self._setup_graph(input, get_cost_fn, get_opt_fn)
        Trainer.train(
            self,
            callbacks + input_callbacks + train_callbacks, monitors,
            session_creator, session_init,
            steps_per_epoch, starting_epoch, max_epoch)

    def _setup_graph(self, input, get_cost_fn, get_opt_fn):
        self.train_op, initial_sync_op, model_sync_op = self._builder.build(
            input, get_cost_fn, get_opt_fn)

        callbacks = []
        # initial local_vars syncing
        cb = RunOp(lambda: initial_sync_op,
                   run_before=True, run_as_trigger=False, verbose=True)
        cb.chief_only = False
        callbacks.append(cb)

        # model_variables syncing
        if model_sync_op:
            cb = RunOp(lambda: model_sync_op,
                       run_before=False, run_as_trigger=True, verbose=True)
            logger.warn("For efficiency, local MODEL_VARIABLES are only synced to PS once "
                        "every epoch. Be careful if you save the model more frequently than this.")
            callbacks.append(cb)
        return callbacks

    def initialize(self, session_creator, session_init):
        if not isinstance(session_creator, NewSessionCreator):
            raise ValueError(
                "Cannot set session_creator for distributed training! "
                "To use a custom session config, pass it to tf.train.Server.")
        super(DistributedTrainerReplicated, self).initialize(
            get_distributed_session_creator(), session_init)
```
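Taken together, these trainers differ only in how `_setup_graph` builds the `train_op` (and which synchronization callbacks they return), so switching strategies is a one-line change at construction time. A sketch with assumed GPU ids and config:

```python
from tensorpack.trainv2.interface import launch_train_with_config
from tensorpack.trainv2.trainers import (
    SyncMultiGPUTrainerParameterServer, SyncMultiGPUTrainerReplicated)

trainer = SyncMultiGPUTrainerParameterServer([0, 1], ps_device='cpu')
# or, for the replicated strategy:
# trainer = SyncMultiGPUTrainerReplicated([0, 1])
launch_train_with_config(my_config, trainer)  # my_config: an assumed TrainConfig
```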