Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
61f14083
Commit
61f14083
authored
Feb 25, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
add "register_callback", so custom trainers can have more control over callbacks&hooks
parent
0a508273
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
44 additions
and
22 deletions
+44
-22
tensorpack/predict/config.py
tensorpack/predict/config.py
+4
-4
tensorpack/tfutils/sesscreate.py
tensorpack/tfutils/sesscreate.py
+3
-3
tensorpack/train/base.py
tensorpack/train/base.py
+33
-9
tensorpack/train/config.py
tensorpack/train/config.py
+2
-4
tensorpack/train/input_data.py
tensorpack/train/input_data.py
+2
-2
No files found.
tensorpack/predict/config.py
View file @
61f14083
...
@@ -9,7 +9,7 @@ from ..models import ModelDesc
...
@@ -9,7 +9,7 @@ from ..models import ModelDesc
from
..utils.develop
import
log_deprecated
from
..utils.develop
import
log_deprecated
from
..tfutils
import
get_default_sess_config
from
..tfutils
import
get_default_sess_config
from
..tfutils.sessinit
import
SessionInit
,
JustCurrentSession
from
..tfutils.sessinit
import
SessionInit
,
JustCurrentSession
from
..tfutils.sesscreate
import
NewSession
from
..tfutils.sesscreate
import
NewSession
Creator
__all__
=
[
'PredictConfig'
]
__all__
=
[
'PredictConfig'
]
...
@@ -28,7 +28,7 @@ class PredictConfig(object):
...
@@ -28,7 +28,7 @@ class PredictConfig(object):
Args:
Args:
model (ModelDesc): the model to use.
model (ModelDesc): the model to use.
session_creator (tf.train.SessionCreator): how to create the
session_creator (tf.train.SessionCreator): how to create the
session. Defaults to :class:`sesscreate.NewSession()`.
session. Defaults to :class:`sesscreate.NewSession
Creator
()`.
session_init (SessionInit): how to initialize variables of the session.
session_init (SessionInit): how to initialize variables of the session.
Defaults to do nothing.
Defaults to do nothing.
input_names (list): a list of input tensor names. Defaults to all
input_names (list): a list of input tensor names. Defaults to all
...
@@ -52,9 +52,9 @@ class PredictConfig(object):
...
@@ -52,9 +52,9 @@ class PredictConfig(object):
if
session_creator
is
None
:
if
session_creator
is
None
:
if
session_config
is
not
None
:
if
session_config
is
not
None
:
log_deprecated
(
"PredictConfig(session_config=)"
,
"Use session_creator instead!"
,
"2017-04-20"
)
log_deprecated
(
"PredictConfig(session_config=)"
,
"Use session_creator instead!"
,
"2017-04-20"
)
self
.
session_creator
=
NewSession
(
config
=
session_config
)
self
.
session_creator
=
NewSession
Creator
(
config
=
session_config
)
else
:
else
:
self
.
session_creator
=
NewSession
(
config
=
get_default_sess_config
(
0.4
))
self
.
session_creator
=
NewSession
Creator
(
config
=
get_default_sess_config
(
0.4
))
else
:
else
:
self
.
session_creator
=
session_creator
self
.
session_creator
=
session_creator
...
...
tensorpack/tfutils/sesscreate.py
View file @
61f14083
...
@@ -5,10 +5,10 @@
...
@@ -5,10 +5,10 @@
import
tensorflow
as
tf
import
tensorflow
as
tf
__all__
=
[
'NewSession
'
,
'ReuseSession
'
]
__all__
=
[
'NewSession
Creator'
,
'ReuseSessionCreator
'
]
class
NewSession
(
tf
.
train
.
SessionCreator
):
class
NewSession
Creator
(
tf
.
train
.
SessionCreator
):
def
__init__
(
self
,
target
=
''
,
graph
=
None
,
config
=
None
):
def
__init__
(
self
,
target
=
''
,
graph
=
None
,
config
=
None
):
"""
"""
Args:
Args:
...
@@ -22,7 +22,7 @@ class NewSession(tf.train.SessionCreator):
...
@@ -22,7 +22,7 @@ class NewSession(tf.train.SessionCreator):
return
tf
.
Session
(
target
=
self
.
target
,
graph
=
self
.
graph
,
config
=
self
.
config
)
return
tf
.
Session
(
target
=
self
.
target
,
graph
=
self
.
graph
,
config
=
self
.
config
)
class
ReuseSession
(
tf
.
train
.
SessionCreator
):
class
ReuseSession
Creator
(
tf
.
train
.
SessionCreator
):
def
__init__
(
self
,
sess
):
def
__init__
(
self
,
sess
):
"""
"""
Args:
Args:
...
...
tensorpack/train/base.py
View file @
61f14083
...
@@ -10,11 +10,14 @@ import six
...
@@ -10,11 +10,14 @@ import six
from
six.moves
import
range
from
six.moves
import
range
import
tensorflow
as
tf
import
tensorflow
as
tf
from
tensorflow.python.training.monitored_session
\
import
_HookedSession
as
HookedSession
from
.predict
import
PredictorFactory
from
.predict
import
PredictorFactory
from
.config
import
TrainConfig
from
.config
import
TrainConfig
from
..utils
import
logger
from
..utils
import
logger
from
..utils.develop
import
deprecated
,
log_deprecated
from
..utils.develop
import
deprecated
,
log_deprecated
from
..callbacks
import
StatHolder
from
..callbacks
import
StatHolder
,
Callback
,
Callbacks
,
MaintainStepCounter
from
..tfutils
import
get_global_step_value
from
..tfutils
import
get_global_step_value
from
..tfutils.modelutils
import
describe_model
from
..tfutils.modelutils
import
describe_model
from
..tfutils.summary
import
create_scalar_summary
from
..tfutils.summary
import
create_scalar_summary
...
@@ -57,6 +60,24 @@ class Trainer(object):
...
@@ -57,6 +60,24 @@ class Trainer(object):
self
.
epoch_num
=
self
.
config
.
starting_epoch
-
1
self
.
epoch_num
=
self
.
config
.
starting_epoch
-
1
self
.
local_step
=
-
1
self
.
local_step
=
-
1
self
.
_callbacks
=
[]
self
.
register_callback
(
MaintainStepCounter
())
for
cb
in
self
.
config
.
callbacks
:
self
.
register_callback
(
cb
)
def
register_callback
(
self
,
cb
):
"""
Use this method before :meth:`Trainer._setup` finishes,
to register a callback to the trainer.
The hooks of the registered callback will be bind to the
`self.hooked_sess` session.
"""
assert
isinstance
(
cb
,
Callback
),
cb
assert
not
isinstance
(
self
.
_callbacks
,
Callbacks
),
\
"Cannot register more callbacks after trainer was setup!"
self
.
_callbacks
.
append
(
cb
)
def
train
(
self
):
def
train
(
self
):
""" Start training """
""" Start training """
self
.
setup
()
self
.
setup
()
...
@@ -74,7 +95,7 @@ class Trainer(object):
...
@@ -74,7 +95,7 @@ class Trainer(object):
# trigger subclass
# trigger subclass
self
.
_trigger_epoch
()
self
.
_trigger_epoch
()
# trigger callbacks
# trigger callbacks
self
.
config
.
callbacks
.
trigger_epoch
()
self
.
_
callbacks
.
trigger_epoch
()
self
.
summary_writer
.
flush
()
self
.
summary_writer
.
flush
()
def
_trigger_epoch
(
self
):
def
_trigger_epoch
(
self
):
...
@@ -126,7 +147,9 @@ class Trainer(object):
...
@@ -126,7 +147,9 @@ class Trainer(object):
self
.
stat_holder
=
StatHolder
(
logger
.
LOG_DIR
)
self
.
stat_holder
=
StatHolder
(
logger
.
LOG_DIR
)
logger
.
info
(
"Setup callbacks graph ..."
)
logger
.
info
(
"Setup callbacks graph ..."
)
self
.
config
.
callbacks
.
setup_graph
(
weakref
.
proxy
(
self
))
self
.
_callbacks
=
Callbacks
(
self
.
_callbacks
)
self
.
_callbacks
.
setup_graph
(
weakref
.
proxy
(
self
))
self
.
config
.
session_init
.
_setup_graph
()
self
.
config
.
session_init
.
_setup_graph
()
def
after_init
(
scaffold
,
sess
):
def
after_init
(
scaffold
,
sess
):
...
@@ -140,10 +163,12 @@ class Trainer(object):
...
@@ -140,10 +163,12 @@ class Trainer(object):
self
.
monitored_sess
=
tf
.
train
.
MonitoredSession
(
self
.
monitored_sess
=
tf
.
train
.
MonitoredSession
(
session_creator
=
tf
.
train
.
ChiefSessionCreator
(
session_creator
=
tf
.
train
.
ChiefSessionCreator
(
scaffold
=
scaffold
,
config
=
self
.
config
.
session_config
),
scaffold
=
scaffold
,
config
=
self
.
config
.
session_config
),
hooks
=
self
.
config
.
callbacks
.
get_hooks
())
hooks
=
None
)
self
.
hooked_sess
=
self
.
monitored_sess
# just create an alias
self
.
sess
=
self
.
monitored_sess
.
_tf_sess
()
# expose the underlying session also
self
.
sess
=
self
.
monitored_sess
.
_tf_sess
()
# expose the underlying session also
hooks
=
self
.
_callbacks
.
get_hooks
()
self
.
hooked_sess
=
HookedSession
(
self
.
sess
,
hooks
)
@
abstractmethod
@
abstractmethod
def
_setup
(
self
):
def
_setup
(
self
):
""" setup Trainer-specific stuff for training"""
""" setup Trainer-specific stuff for training"""
...
@@ -161,11 +186,10 @@ class Trainer(object):
...
@@ -161,11 +186,10 @@ class Trainer(object):
"""
"""
Run the main training loop.
Run the main training loop.
"""
"""
callbacks
=
self
.
config
.
callbacks
with
self
.
sess
.
as_default
():
with
self
.
sess
.
as_default
():
self
.
_starting_step
=
get_global_step_value
()
self
.
_starting_step
=
get_global_step_value
()
try
:
try
:
callbacks
.
before_train
()
self
.
_
callbacks
.
before_train
()
for
self
.
epoch_num
in
range
(
for
self
.
epoch_num
in
range
(
self
.
config
.
starting_epoch
,
self
.
config
.
max_epoch
+
1
):
self
.
config
.
starting_epoch
,
self
.
config
.
max_epoch
+
1
):
logger
.
info
(
"Start Epoch {} ..."
.
format
(
self
.
epoch_num
))
logger
.
info
(
"Start Epoch {} ..."
.
format
(
self
.
epoch_num
))
...
@@ -174,7 +198,7 @@ class Trainer(object):
...
@@ -174,7 +198,7 @@ class Trainer(object):
if
self
.
monitored_sess
.
should_stop
():
if
self
.
monitored_sess
.
should_stop
():
return
return
self
.
run_step
()
# implemented by subclass
self
.
run_step
()
# implemented by subclass
callbacks
.
trigger_step
()
self
.
_
callbacks
.
trigger_step
()
logger
.
info
(
"Epoch {} (global_step {}) finished, time:{:.2f} sec."
.
format
(
logger
.
info
(
"Epoch {} (global_step {}) finished, time:{:.2f} sec."
.
format
(
self
.
epoch_num
,
self
.
global_step
,
time
.
time
()
-
start_time
))
self
.
epoch_num
,
self
.
global_step
,
time
.
time
()
-
start_time
))
...
@@ -186,7 +210,7 @@ class Trainer(object):
...
@@ -186,7 +210,7 @@ class Trainer(object):
except
:
except
:
raise
raise
finally
:
finally
:
callbacks
.
after_train
()
self
.
_
callbacks
.
after_train
()
self
.
summary_writer
.
close
()
self
.
summary_writer
.
close
()
self
.
monitored_sess
.
close
()
self
.
monitored_sess
.
close
()
...
...
tensorpack/train/config.py
View file @
61f14083
...
@@ -6,8 +6,7 @@ import tensorflow as tf
...
@@ -6,8 +6,7 @@ import tensorflow as tf
from
..callbacks
import
(
from
..callbacks
import
(
Callbacks
,
MovingAverageSummary
,
Callbacks
,
MovingAverageSummary
,
StatPrinter
,
ProgressBar
,
MergeAllSummaries
,
StatPrinter
,
ProgressBar
,
MergeAllSummaries
)
MaintainStepCounter
)
from
..dataflow.base
import
DataFlow
from
..dataflow.base
import
DataFlow
from
..models
import
ModelDesc
from
..models
import
ModelDesc
from
..utils
import
logger
from
..utils
import
logger
...
@@ -89,9 +88,8 @@ class TrainConfig(object):
...
@@ -89,9 +88,8 @@ class TrainConfig(object):
ProgressBar
(),
ProgressBar
(),
MergeAllSummaries
(),
MergeAllSummaries
(),
StatPrinter
()]
StatPrinter
()]
self
.
callbacks
=
[
MaintainStepCounter
()]
+
callbacks
+
extra_callbacks
self
.
callbacks
=
callbacks
+
extra_callbacks
assert_type
(
self
.
callbacks
,
list
)
assert_type
(
self
.
callbacks
,
list
)
self
.
callbacks
=
Callbacks
(
self
.
callbacks
)
self
.
model
=
model
self
.
model
=
model
assert_type
(
self
.
model
,
ModelDesc
)
assert_type
(
self
.
model
,
ModelDesc
)
...
...
tensorpack/train/input_data.py
View file @
61f14083
...
@@ -155,7 +155,7 @@ class QueueInput(FeedfreeInput):
...
@@ -155,7 +155,7 @@ class QueueInput(FeedfreeInput):
def
setup_training
(
self
,
trainer
):
def
setup_training
(
self
,
trainer
):
self
.
setup
(
trainer
.
model
)
self
.
setup
(
trainer
.
model
)
trainer
.
config
.
callbacks
.
append
(
StartProcOrThread
(
self
.
thread
))
trainer
.
register_callback
(
StartProcOrThread
(
self
.
thread
))
def
get_input_tensors
(
self
):
def
get_input_tensors
(
self
):
ret
=
self
.
queue
.
dequeue
(
name
=
'input_deque'
)
ret
=
self
.
queue
.
dequeue
(
name
=
'input_deque'
)
...
@@ -219,7 +219,7 @@ class BatchQueueInput(FeedfreeInput):
...
@@ -219,7 +219,7 @@ class BatchQueueInput(FeedfreeInput):
def
setup_training
(
self
,
trainer
):
def
setup_training
(
self
,
trainer
):
self
.
setup
(
trainer
.
model
)
self
.
setup
(
trainer
.
model
)
trainer
.
config
.
callbacks
.
append
(
StartProcOrThread
(
self
.
thread
))
trainer
.
register_callback
(
StartProcOrThread
(
self
.
thread
))
def
get_input_tensors
(
self
):
def
get_input_tensors
(
self
):
ret
=
self
.
queue
.
dequeue_many
(
self
.
batch_size
,
name
=
'input_deque'
)
ret
=
self
.
queue
.
dequeue_many
(
self
.
batch_size
,
name
=
'input_deque'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment