Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
9e995a8d
Commit
9e995a8d
authored
Oct 17, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Use `TrainLoop` to manage the loop, and delegate properties. Hide `trainer.config`
parent
7efe4939
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
106 additions
and
66 deletions
+106
-66
examples/FasterRCNN/train.py
examples/FasterRCNN/train.py
+2
-2
tensorpack/callbacks/base.py
tensorpack/callbacks/base.py
+1
-1
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+1
-1
tensorpack/callbacks/monitor.py
tensorpack/callbacks/monitor.py
+3
-3
tensorpack/callbacks/steps.py
tensorpack/callbacks/steps.py
+2
-2
tensorpack/callbacks/summary.py
tensorpack/callbacks/summary.py
+1
-1
tensorpack/train/base.py
tensorpack/train/base.py
+85
-45
tensorpack/train/distributed.py
tensorpack/train/distributed.py
+4
-4
tensorpack/train/multigpu.py
tensorpack/train/multigpu.py
+6
-6
tensorpack/train/simple.py
tensorpack/train/simple.py
+1
-1
No files found.
examples/FasterRCNN/train.py
View file @
9e995a8d
...
...
@@ -223,9 +223,9 @@ class EvalCallback(Callback):
self
.
df
=
PrefetchDataZMQ
(
get_eval_dataflow
(),
1
)
EVAL_TIMES
=
5
# eval 5 times during training
interval
=
self
.
trainer
.
config
.
max_epoch
//
(
EVAL_TIMES
+
1
)
interval
=
self
.
trainer
.
max_epoch
//
(
EVAL_TIMES
+
1
)
self
.
epochs_to_eval
=
set
([
interval
*
k
for
k
in
range
(
1
,
EVAL_TIMES
)])
self
.
epochs_to_eval
.
add
(
self
.
trainer
.
config
.
max_epoch
)
self
.
epochs_to_eval
.
add
(
self
.
trainer
.
max_epoch
)
get_tf_nms
()
# just to make sure the nms part of graph is created
def
_eval
(
self
):
...
...
tensorpack/callbacks/base.py
View file @
9e995a8d
...
...
@@ -45,7 +45,7 @@ class Callback(object):
_chief_only
=
True
def
setup_graph
(
self
,
trainer
):
self
.
_steps_per_epoch
=
trainer
.
config
.
steps_per_epoch
self
.
_steps_per_epoch
=
trainer
.
steps_per_epoch
self
.
trainer
=
trainer
self
.
graph
=
tf
.
get_default_graph
()
scope_name
=
type
(
self
)
.
__name__
...
...
tensorpack/callbacks/inference_runner.py
View file @
9e995a8d
...
...
@@ -124,7 +124,7 @@ class InferenceRunner(InferenceRunnerBase):
def
_setup_graph
(
self
):
assert
self
.
trainer
.
model
is
not
None
# Use predict_tower in train config. either gpuid or -1
tower_id
=
self
.
trainer
.
config
.
predict_tower
[
0
]
tower_id
=
self
.
trainer
.
_
config
.
predict_tower
[
0
]
device
=
'/gpu:{}'
.
format
(
tower_id
)
if
tower_id
>=
0
else
'/cpu:0'
input_callbacks
=
self
.
_input_source
.
setup
(
self
.
trainer
.
model
.
get_inputs_desc
())
...
...
tensorpack/callbacks/monitor.py
View file @
9e995a8d
...
...
@@ -256,13 +256,13 @@ class JSONWriter(TrainingMonitor):
pass
else
:
logger
.
info
(
"Found training history from JSON, now starting from epoch number {}."
.
format
(
epoch
))
self
.
trainer
.
config
.
starting_epoch
=
epoch
self
.
trainer
.
starting_epoch
=
epoch
else
:
self
.
_stats
=
[]
self
.
_stat_now
=
{}
self
.
_last_gs
=
-
1
self
.
_total
=
self
.
trainer
.
config
.
steps_per_epoch
self
.
_total
=
self
.
trainer
.
steps_per_epoch
def
_trigger_step
(
self
):
# will do this in trigger_epoch
...
...
@@ -327,7 +327,7 @@ class ScalarPrinter(TrainingMonitor):
def
_setup_graph
(
self
):
self
.
_dic
=
{}
self
.
_total
=
self
.
trainer
.
config
.
steps_per_epoch
self
.
_total
=
self
.
trainer
.
steps_per_epoch
def
_trigger_step
(
self
):
if
self
.
_enable_step
:
...
...
tensorpack/callbacks/steps.py
View file @
9e995a8d
...
...
@@ -67,7 +67,7 @@ class ProgressBar(Callback):
def
_before_train
(
self
):
self
.
_last_updated
=
self
.
local_step
self
.
_total
=
self
.
trainer
.
config
.
steps_per_epoch
self
.
_total
=
self
.
trainer
.
steps_per_epoch
self
.
_tqdm_args
=
get_tqdm_kwargs
(
leave
=
True
)
self
.
_fetches
=
get_op_or_tensor_by_name
(
self
.
_names
)
or
None
...
...
@@ -133,4 +133,4 @@ class MaintainStepCounter(Callback):
def
_after_run
(
self
,
_
,
__
):
# Keep python-side global_step in agreement with TF-side
self
.
trainer
.
_global_step
+=
1
self
.
trainer
.
loop
.
_global_step
+=
1
tensorpack/callbacks/summary.py
View file @
9e995a8d
...
...
@@ -70,7 +70,7 @@ class MergeAllSummaries_RunWithOp(Callback):
self
.
_fetches
=
tf
.
train
.
SessionRunArgs
(
self
.
summary_op
)
else
:
self
.
_fetches
=
None
self
.
_total
=
self
.
trainer
.
config
.
steps_per_epoch
self
.
_total
=
self
.
trainer
.
steps_per_epoch
def
_need_run
(
self
):
if
self
.
local_step
==
self
.
_total
-
1
:
...
...
tensorpack/train/base.py
View file @
9e995a8d
...
...
@@ -30,6 +30,63 @@ class StopTraining(BaseException):
pass
class
TrainLoop
(
object
):
"""
Manage the double for loop.
"""
def
__init__
(
self
):
self
.
_epoch_num
=
0
self
.
_global_step
=
0
self
.
_local_step
=
-
1
def
config
(
self
,
steps_per_epoch
,
starting_epoch
,
max_epoch
):
"""
Configure the loop given the settings.
"""
self
.
starting_epoch
=
starting_epoch
self
.
max_epoch
=
max_epoch
self
.
steps_per_epoch
=
steps_per_epoch
self
.
_epoch_num
=
starting_epoch
-
1
def
update_global_step
(
self
):
"""
Update the Python-side global_step from TF.
This must be called under initialized default session.
"""
self
.
_global_step
=
get_global_step_value
()
@
property
def
epoch_num
(
self
):
"""
The number of the currently ongoing epoch.
An epoch is defined to cover the moment before calling `before_epoch` until after calling `trigger_epoch`.
i.e., in the `trigger_epoch` of epoch 3, `self.epoch_num` is 3.
If you need use `self.epoch_num` in your callback, you'll need to know this.
"""
return
self
.
_epoch_num
@
property
def
global_step
(
self
):
"""
The tensorflow global_step, i.e. how many times ``hooked_sess.run`` has been called.
Note:
1. global_step is incremented **after** each ``hooked_sess.run`` returns from TF runtime.
2. If you make zero or more than one calls to ``hooked_sess.run`` in one
:meth:`run_step`, local_step and global_step may increment at different speed.
"""
return
self
.
_global_step
@
property
def
local_step
(
self
):
"""
The number of (tensorpack) steps that have finished in the current epoch.
"""
return
self
.
_local_step
class
Trainer
(
object
):
""" Base class for a trainer.
...
...
@@ -39,7 +96,6 @@ class Trainer(object):
sess (tf.Session): the current session in use.
hooked_sess (tf.train.MonitoredSession): the session with hooks.
monitors (Monitors): the monitors. Other callbacks can use it for logging.
local_step (int): the number of (tensorpack) steps that have finished in the current epoch.
"""
# step attr only available after before_train?
...
...
@@ -51,33 +107,16 @@ class Trainer(object):
config (TrainConfig): the train config.
"""
assert
isinstance
(
config
,
TrainConfig
),
type
(
config
)
self
.
config
=
config
self
.
_
config
=
config
self
.
model
=
config
.
model
self
.
local_step
=
-
1
self
.
_callbacks
=
[]
self
.
monitors
=
[]
self
.
_epoch_num
=
None
self
.
_global_step
=
0
self
.
loop
=
TrainLoop
()
self
.
loop
.
config
(
config
.
steps_per_epoch
,
config
.
starting_epoch
,
config
.
max_epoch
)
self
.
_setup
()
# subclass will setup the graph and InputSource
@
property
def
epoch_num
(
self
):
"""
The number of the currently ongoing epoch.
An epoch is defined to cover the moment before calling `before_epoch` until after calling `trigger_epoch`.
i.e., in the `trigger_epoch` of epoch 3, `self.epoch_num` is 3.
If you need use `self.epoch_num` in your callback, you'll need to know this.
"""
if
self
.
_epoch_num
is
not
None
:
# has started training
return
self
.
_epoch_num
else
:
return
self
.
config
.
starting_epoch
-
1
def
register_callback
(
self
,
cb
):
"""
Register a callback to the trainer.
...
...
@@ -129,9 +168,9 @@ class Trainer(object):
Setup the trainer and be ready for the main loop.
"""
self
.
register_callback
(
MaintainStepCounter
())
for
cb
in
self
.
config
.
callbacks
:
for
cb
in
self
.
_
config
.
callbacks
:
self
.
register_callback
(
cb
)
for
m
in
self
.
config
.
monitors
:
for
m
in
self
.
_
config
.
monitors
:
self
.
register_monitor
(
m
)
self
.
monitors
=
Monitors
(
self
.
monitors
)
self
.
register_callback
(
self
.
monitors
)
...
...
@@ -148,9 +187,9 @@ class Trainer(object):
if
self
.
is_chief
:
logger
.
info
(
"Initializing the session ..."
)
self
.
config
.
session_init
.
init
(
self
.
sess
)
self
.
_
config
.
session_init
.
init
(
self
.
sess
)
else
:
assert
isinstance
(
self
.
config
.
session_init
,
JustCurrentSession
),
\
assert
isinstance
(
self
.
_
config
.
session_init
,
JustCurrentSession
),
\
"session_init is only valid for chief worker session!"
self
.
sess
.
graph
.
finalize
()
...
...
@@ -162,7 +201,7 @@ class Trainer(object):
and self.hooked_sess (the session with hooks and coordinator)
"""
hooks
=
self
.
_callbacks
.
get_hooks
()
self
.
sess
=
self
.
config
.
session_creator
.
create_session
()
self
.
sess
=
self
.
_
config
.
session_creator
.
create_session
()
self
.
hooked_sess
=
tf
.
train
.
MonitoredSession
(
session_creator
=
ReuseSessionCreator
(
self
.
sess
),
hooks
=
hooks
)
...
...
@@ -176,41 +215,29 @@ class Trainer(object):
"""
pass
@
property
def
global_step
(
self
):
"""
The tensorflow global_step, i.e. how many times ``hooked_sess.run`` has been called.
Note:
1. global_step is incremented **after** each ``hooked_sess.run`` returns from TF runtime.
2. If you make zero or more than one calls to ``hooked_sess.run`` in one
:meth:`run_step`, local_step and global_step may increment at different speed.
"""
return
self
.
_global_step
def
main_loop
(
self
):
"""
Run the main training loop.
"""
with
self
.
sess
.
as_default
():
self
.
_global_step
=
get_global_step_value
()
self
.
loop
.
update_global_step
()
try
:
self
.
_callbacks
.
before_train
()
# refresh global step (might have changed by callbacks) TODO ugly
self
.
_global_step
=
get_global_step_value
()
for
self
.
_epoch_num
in
range
(
self
.
config
.
starting_epoch
,
self
.
config
.
max_epoch
+
1
):
logger
.
info
(
"Start Epoch {} ..."
.
format
(
self
.
_
epoch_num
))
self
.
loop
.
update_global_step
()
for
self
.
loop
.
_epoch_num
in
range
(
self
.
loop
.
starting_epoch
,
self
.
loop
.
max_epoch
+
1
):
logger
.
info
(
"Start Epoch {} ..."
.
format
(
self
.
loop
.
epoch_num
))
start_time
=
time
.
time
()
self
.
_callbacks
.
before_epoch
()
for
self
.
lo
cal_step
in
range
(
self
.
config
.
steps_per_epoch
):
for
self
.
lo
op
.
_local_step
in
range
(
self
.
loop
.
steps_per_epoch
):
if
self
.
hooked_sess
.
should_stop
():
return
self
.
run_step
()
# implemented by subclass
self
.
_callbacks
.
trigger_step
()
self
.
_callbacks
.
after_epoch
()
logger
.
info
(
"Epoch {} (global_step {}) finished, time:{:.2f} sec."
.
format
(
self
.
_epoch_num
,
self
.
global_step
,
time
.
time
()
-
start_time
))
self
.
loop
.
epoch_num
,
self
.
loop
.
global_step
,
time
.
time
()
-
start_time
))
# trigger epoch outside the timing region.
self
.
_callbacks
.
trigger_epoch
()
...
...
@@ -256,6 +283,19 @@ class Trainer(object):
return
""
def
_delegate_attr
(
name
):
"""
Delegate property to self.loop
"""
setattr
(
Trainer
,
name
,
property
(
lambda
self
:
getattr
(
self
.
loop
,
name
)))
for
name
in
[
'global_step'
,
'local_step'
,
'steps_per_epoch'
,
'epoch_num'
,
'starting_epoch'
,
'max_epoch'
]:
_delegate_attr
(
name
)
def
launch_train
(
run_step
,
model
=
None
,
callbacks
=
None
,
extra_callbacks
=
None
,
monitors
=
None
,
session_creator
=
None
,
session_config
=
None
,
session_init
=
None
,
...
...
tensorpack/train/distributed.py
View file @
9e995a8d
...
...
@@ -88,7 +88,7 @@ class DistributedTrainerReplicated(Trainer):
# whether something should be global or local. We now assume
# they should be local.
cbs
=
self
.
_input_source
.
setup
(
self
.
model
.
get_inputs_desc
())
self
.
config
.
callbacks
.
extend
(
cbs
)
self
.
_
config
.
callbacks
.
extend
(
cbs
)
self
.
train_op
,
initial_sync_op
,
model_sync_op
=
self
.
_builder
.
build
(
self
.
_input_source
,
self
.
model
.
build_graph_get_cost
,
self
.
model
.
get_optimizer
)
...
...
@@ -110,14 +110,14 @@ class DistributedTrainerReplicated(Trainer):
self
.
_set_session_creator
()
def
_set_session_creator
(
self
):
old_sess_creator
=
self
.
config
.
session_creator
old_sess_creator
=
self
.
_
config
.
session_creator
if
not
isinstance
(
old_sess_creator
,
NewSessionCreator
)
\
or
self
.
config
.
session_config
is
not
None
:
or
self
.
_
config
.
session_config
is
not
None
:
raise
ValueError
(
"Cannot set session_creator or session_config for distributed training! "
"To use a custom session config, pass it with tf.train.Server."
)
self
.
config
.
session_creator
=
get_distributed_session_creator
(
self
.
server
)
self
.
_
config
.
session_creator
=
get_distributed_session_creator
(
self
.
server
)
@
property
def
vs_name_for_predictor
(
self
):
...
...
tensorpack/train/multigpu.py
View file @
9e995a8d
...
...
@@ -71,10 +71,10 @@ class SyncMultiGPUTrainerParameterServer(Trainer):
callbacks
=
self
.
_input_source
.
setup
(
self
.
model
.
get_inputs_desc
())
self
.
train_op
=
SyncMultiGPUParameterServerBuilder
(
self
.
config
.
tower
,
self
.
_ps_device
)
.
build
(
self
.
_
config
.
tower
,
self
.
_ps_device
)
.
build
(
self
.
_input_source
,
self
.
model
.
build_graph_get_cost
,
self
.
model
.
get_optimizer
)
self
.
config
.
callbacks
.
extend
(
callbacks
)
self
.
_
config
.
callbacks
.
extend
(
callbacks
)
def
SyncMultiGPUTrainer
(
config
):
...
...
@@ -102,13 +102,13 @@ class SyncMultiGPUTrainerReplicated(Trainer):
def
_setup
(
self
):
callbacks
=
self
.
_input_source
.
setup
(
self
.
model
.
get_inputs_desc
())
self
.
train_op
,
post_init_op
=
SyncMultiGPUReplicatedBuilder
(
self
.
config
.
tower
)
.
build
(
self
.
train_op
,
post_init_op
=
SyncMultiGPUReplicatedBuilder
(
self
.
_
config
.
tower
)
.
build
(
self
.
_input_source
,
self
.
model
.
build_graph_get_cost
,
self
.
model
.
get_optimizer
)
cb
=
RunOp
(
lambda
:
post_init_op
,
run_before
=
True
,
run_as_trigger
=
True
,
verbose
=
True
)
self
.
config
.
callbacks
.
extend
(
callbacks
+
[
cb
])
self
.
_
config
.
callbacks
.
extend
(
callbacks
+
[
cb
])
class
AsyncMultiGPUTrainer
(
Trainer
):
...
...
@@ -130,7 +130,7 @@ class AsyncMultiGPUTrainer(Trainer):
callbacks
=
self
.
_input_source
.
setup
(
self
.
model
.
get_inputs_desc
())
self
.
train_op
=
AsyncMultiGPUBuilder
(
self
.
config
.
tower
,
self
.
_scale_gradient
)
.
build
(
self
.
_
config
.
tower
,
self
.
_scale_gradient
)
.
build
(
self
.
_input_source
,
self
.
model
.
build_graph_get_cost
,
self
.
model
.
get_optimizer
)
self
.
config
.
callbacks
.
extend
(
callbacks
)
self
.
_
config
.
callbacks
.
extend
(
callbacks
)
tensorpack/train/simple.py
View file @
9e995a8d
...
...
@@ -44,7 +44,7 @@ class SimpleTrainer(Trainer):
self
.
train_op
=
SimpleBuilder
()
.
build
(
self
.
_input_source
,
self
.
model
.
build_graph_get_cost
,
self
.
model
.
get_optimizer
)
self
.
config
.
callbacks
.
extend
(
cbs
)
self
.
_
config
.
callbacks
.
extend
(
cbs
)
def
QueueInputTrainer
(
config
,
input_queue
=
None
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment