Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
61f14083
Commit
61f14083
authored
Feb 25, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
add "register_callback", so custom trainers can have more control over callbacks&hooks
parent
0a508273
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
44 additions
and
22 deletions
+44
-22
tensorpack/predict/config.py
tensorpack/predict/config.py
+4
-4
tensorpack/tfutils/sesscreate.py
tensorpack/tfutils/sesscreate.py
+3
-3
tensorpack/train/base.py
tensorpack/train/base.py
+33
-9
tensorpack/train/config.py
tensorpack/train/config.py
+2
-4
tensorpack/train/input_data.py
tensorpack/train/input_data.py
+2
-2
No files found.
tensorpack/predict/config.py
View file @
61f14083
...
...
@@ -9,7 +9,7 @@ from ..models import ModelDesc
from
..utils.develop
import
log_deprecated
from
..tfutils
import
get_default_sess_config
from
..tfutils.sessinit
import
SessionInit
,
JustCurrentSession
from
..tfutils.sesscreate
import
NewSession
from
..tfutils.sesscreate
import
NewSession
Creator
__all__
=
[
'PredictConfig'
]
...
...
@@ -28,7 +28,7 @@ class PredictConfig(object):
Args:
model (ModelDesc): the model to use.
session_creator (tf.train.SessionCreator): how to create the
session. Defaults to :class:`sesscreate.NewSession()`.
session. Defaults to :class:`sesscreate.NewSession
Creator
()`.
session_init (SessionInit): how to initialize variables of the session.
Defaults to do nothing.
input_names (list): a list of input tensor names. Defaults to all
...
...
@@ -52,9 +52,9 @@ class PredictConfig(object):
if
session_creator
is
None
:
if
session_config
is
not
None
:
log_deprecated
(
"PredictConfig(session_config=)"
,
"Use session_creator instead!"
,
"2017-04-20"
)
self
.
session_creator
=
NewSession
(
config
=
session_config
)
self
.
session_creator
=
NewSession
Creator
(
config
=
session_config
)
else
:
self
.
session_creator
=
NewSession
(
config
=
get_default_sess_config
(
0.4
))
self
.
session_creator
=
NewSession
Creator
(
config
=
get_default_sess_config
(
0.4
))
else
:
self
.
session_creator
=
session_creator
...
...
tensorpack/tfutils/sesscreate.py
View file @
61f14083
...
...
@@ -5,10 +5,10 @@
import
tensorflow
as
tf
__all__
=
[
'NewSession
'
,
'ReuseSession
'
]
__all__
=
[
'NewSession
Creator'
,
'ReuseSessionCreator
'
]
class
NewSession
(
tf
.
train
.
SessionCreator
):
class
NewSession
Creator
(
tf
.
train
.
SessionCreator
):
def
__init__
(
self
,
target
=
''
,
graph
=
None
,
config
=
None
):
"""
Args:
...
...
@@ -22,7 +22,7 @@ class NewSession(tf.train.SessionCreator):
return
tf
.
Session
(
target
=
self
.
target
,
graph
=
self
.
graph
,
config
=
self
.
config
)
class
ReuseSession
(
tf
.
train
.
SessionCreator
):
class
ReuseSession
Creator
(
tf
.
train
.
SessionCreator
):
def
__init__
(
self
,
sess
):
"""
Args:
...
...
tensorpack/train/base.py
View file @
61f14083
...
...
@@ -10,11 +10,14 @@ import six
from
six.moves
import
range
import
tensorflow
as
tf
from
tensorflow.python.training.monitored_session
\
import
_HookedSession
as
HookedSession
from
.predict
import
PredictorFactory
from
.config
import
TrainConfig
from
..utils
import
logger
from
..utils.develop
import
deprecated
,
log_deprecated
from
..callbacks
import
StatHolder
from
..callbacks
import
StatHolder
,
Callback
,
Callbacks
,
MaintainStepCounter
from
..tfutils
import
get_global_step_value
from
..tfutils.modelutils
import
describe_model
from
..tfutils.summary
import
create_scalar_summary
...
...
@@ -57,6 +60,24 @@ class Trainer(object):
self
.
epoch_num
=
self
.
config
.
starting_epoch
-
1
self
.
local_step
=
-
1
self
.
_callbacks
=
[]
self
.
register_callback
(
MaintainStepCounter
())
for
cb
in
self
.
config
.
callbacks
:
self
.
register_callback
(
cb
)
def
register_callback
(
self
,
cb
):
"""
Use this method before :meth:`Trainer._setup` finishes,
to register a callback to the trainer.
The hooks of the registered callback will be bind to the
`self.hooked_sess` session.
"""
assert
isinstance
(
cb
,
Callback
),
cb
assert
not
isinstance
(
self
.
_callbacks
,
Callbacks
),
\
"Cannot register more callbacks after trainer was setup!"
self
.
_callbacks
.
append
(
cb
)
def
train
(
self
):
""" Start training """
self
.
setup
()
...
...
@@ -74,7 +95,7 @@ class Trainer(object):
# trigger subclass
self
.
_trigger_epoch
()
# trigger callbacks
self
.
config
.
callbacks
.
trigger_epoch
()
self
.
_
callbacks
.
trigger_epoch
()
self
.
summary_writer
.
flush
()
def
_trigger_epoch
(
self
):
...
...
@@ -126,7 +147,9 @@ class Trainer(object):
self
.
stat_holder
=
StatHolder
(
logger
.
LOG_DIR
)
logger
.
info
(
"Setup callbacks graph ..."
)
self
.
config
.
callbacks
.
setup_graph
(
weakref
.
proxy
(
self
))
self
.
_callbacks
=
Callbacks
(
self
.
_callbacks
)
self
.
_callbacks
.
setup_graph
(
weakref
.
proxy
(
self
))
self
.
config
.
session_init
.
_setup_graph
()
def
after_init
(
scaffold
,
sess
):
...
...
@@ -140,10 +163,12 @@ class Trainer(object):
self
.
monitored_sess
=
tf
.
train
.
MonitoredSession
(
session_creator
=
tf
.
train
.
ChiefSessionCreator
(
scaffold
=
scaffold
,
config
=
self
.
config
.
session_config
),
hooks
=
self
.
config
.
callbacks
.
get_hooks
())
self
.
hooked_sess
=
self
.
monitored_sess
# just create an alias
hooks
=
None
)
self
.
sess
=
self
.
monitored_sess
.
_tf_sess
()
# expose the underlying session also
hooks
=
self
.
_callbacks
.
get_hooks
()
self
.
hooked_sess
=
HookedSession
(
self
.
sess
,
hooks
)
@
abstractmethod
def
_setup
(
self
):
""" setup Trainer-specific stuff for training"""
...
...
@@ -161,11 +186,10 @@ class Trainer(object):
"""
Run the main training loop.
"""
callbacks
=
self
.
config
.
callbacks
with
self
.
sess
.
as_default
():
self
.
_starting_step
=
get_global_step_value
()
try
:
callbacks
.
before_train
()
self
.
_
callbacks
.
before_train
()
for
self
.
epoch_num
in
range
(
self
.
config
.
starting_epoch
,
self
.
config
.
max_epoch
+
1
):
logger
.
info
(
"Start Epoch {} ..."
.
format
(
self
.
epoch_num
))
...
...
@@ -174,7 +198,7 @@ class Trainer(object):
if
self
.
monitored_sess
.
should_stop
():
return
self
.
run_step
()
# implemented by subclass
callbacks
.
trigger_step
()
self
.
_
callbacks
.
trigger_step
()
logger
.
info
(
"Epoch {} (global_step {}) finished, time:{:.2f} sec."
.
format
(
self
.
epoch_num
,
self
.
global_step
,
time
.
time
()
-
start_time
))
...
...
@@ -186,7 +210,7 @@ class Trainer(object):
except
:
raise
finally
:
callbacks
.
after_train
()
self
.
_
callbacks
.
after_train
()
self
.
summary_writer
.
close
()
self
.
monitored_sess
.
close
()
...
...
tensorpack/train/config.py
View file @
61f14083
...
...
@@ -6,8 +6,7 @@ import tensorflow as tf
from
..callbacks
import
(
Callbacks
,
MovingAverageSummary
,
StatPrinter
,
ProgressBar
,
MergeAllSummaries
,
MaintainStepCounter
)
StatPrinter
,
ProgressBar
,
MergeAllSummaries
)
from
..dataflow.base
import
DataFlow
from
..models
import
ModelDesc
from
..utils
import
logger
...
...
@@ -89,9 +88,8 @@ class TrainConfig(object):
ProgressBar
(),
MergeAllSummaries
(),
StatPrinter
()]
self
.
callbacks
=
[
MaintainStepCounter
()]
+
callbacks
+
extra_callbacks
self
.
callbacks
=
callbacks
+
extra_callbacks
assert_type
(
self
.
callbacks
,
list
)
self
.
callbacks
=
Callbacks
(
self
.
callbacks
)
self
.
model
=
model
assert_type
(
self
.
model
,
ModelDesc
)
...
...
tensorpack/train/input_data.py
View file @
61f14083
...
...
@@ -155,7 +155,7 @@ class QueueInput(FeedfreeInput):
def
setup_training
(
self
,
trainer
):
self
.
setup
(
trainer
.
model
)
trainer
.
config
.
callbacks
.
append
(
StartProcOrThread
(
self
.
thread
))
trainer
.
register_callback
(
StartProcOrThread
(
self
.
thread
))
def
get_input_tensors
(
self
):
ret
=
self
.
queue
.
dequeue
(
name
=
'input_deque'
)
...
...
@@ -219,7 +219,7 @@ class BatchQueueInput(FeedfreeInput):
def
setup_training
(
self
,
trainer
):
self
.
setup
(
trainer
.
model
)
trainer
.
config
.
callbacks
.
append
(
StartProcOrThread
(
self
.
thread
))
trainer
.
register_callback
(
StartProcOrThread
(
self
.
thread
))
def
get_input_tensors
(
self
):
ret
=
self
.
queue
.
dequeue_many
(
self
.
batch_size
,
name
=
'input_deque'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment