Commit eee05770
authored Feb 20, 2017 by Yuxin Wu

use hooks to run step triggers. examples unfixed. (#147)

parent 136174c9
Showing 6 changed files with 39 additions and 43 deletions (+39, -43)
tensorpack/callbacks/base.py      +1   -0
tensorpack/callbacks/group.py     +16  -26
tensorpack/tfutils/common.py      +16  -3
tensorpack/train/base.py          +4   -9
tensorpack/train/feedfree.py      +1   -2
tensorpack/train/trainer.py       +1   -3
tensorpack/callbacks/base.py

@@ -85,6 +85,7 @@ class Callback(object):
             if isinstance(f, (tf.Tensor, tf.Operation)):
                 ret.append(f)
             else:
+                # warn about speed
                 ret.append(get_op_or_tensor_by_name(f))
         return ret
tensorpack/callbacks/group.py

@@ -4,7 +4,6 @@
 import tensorflow as tf
 from contextlib import contextmanager
-from collections import defaultdict
 import time
 import traceback

@@ -15,8 +14,20 @@ from ..utils import logger
 __all__ = ['Callbacks']


+class CallbackHook(tf.train.SessionRunHook):
+    def __init__(self, cb):
+        self.cb = cb
+
+    def before_run(self, _):
+        return tf.train.SessionRunArgs(fetches=self.cb.extra_fetches())
+
+    def after_run(self, _, vals):
+        res = vals.results
+        self.cb.trigger_step(*res)
+
+
 class CallbackTimeLogger(object):
     def __init__(self):
         self.times = []
         self.tot = 0

@@ -90,30 +101,9 @@ class Callbacks(Callback):
             except Exception:
                 traceback.print_exc()

-    def _extra_fetches(self):
-        if self._extra_fetches_cache is not None:  # TODO skip
-            return self._extra_fetches_cache
-        # TODO use dispatch mechanism to avoid duplication
-        self._cbid_to_fetchid = defaultdict(list)
-        ret = []
-        for idx, cb in enumerate(self.cbs):
-            fetch = cb.extra_fetches()
-            if len(fetch) == 0:
-                continue
-            for f in fetch:
-                ret.append(f)
-                self._cbid_to_fetchid[idx].append(len(ret) - 1)
-        self._extra_fetches_cache = ret
-        return ret
-
-    def _trigger_step(self, *args):
-        for idx, cb in enumerate(self.cbs):
-            fid = self._cbid_to_fetchid[idx]
-            if len(fid) == 0:
-                cb.trigger_step()
-            else:
-                data = [args[k] for k in fid]
-                cb.trigger_step(*data)
+    def get_hooks(self):
+        return [CallbackHook(cb) for cb in self.cbs]

     def _trigger_epoch(self):
         tm = CallbackTimeLogger()
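For context, here is a minimal, self-contained sketch (not part of this commit) of the mechanism the new CallbackHook builds on, assuming the TF 1.x tf.train API: MonitoredSession calls every hook's before_run/after_run around each run, so a per-callback hook can request the callback's extra fetches and hand the results to trigger_step. The PrintStepCallback below is hypothetical; the hook mirrors the class added in the hunk above.

import tensorflow as tf

class PrintStepCallback(object):
    # hypothetical callback: fetches the global step and prints it every step
    def __init__(self, step_tensor):
        self._step = step_tensor

    def extra_fetches(self):
        return [self._step]            # tensors to fetch alongside the train op

    def trigger_step(self, step_value):
        print('finished step %d' % step_value)

class CallbackHook(tf.train.SessionRunHook):
    # same shape as the hook introduced in tensorpack/callbacks/group.py above
    def __init__(self, cb):
        self.cb = cb

    def before_run(self, _):
        # piggy-back the callback's fetches onto the session call
        return tf.train.SessionRunArgs(fetches=self.cb.extra_fetches())

    def after_run(self, _, vals):
        # hand the fetched values back to the callback
        self.cb.trigger_step(*vals.results)

global_step = tf.train.create_global_step()
train_op = tf.assign_add(global_step, 1)
hooks = [CallbackHook(PrintStepCallback(global_step))]
with tf.train.MonitoredSession(
        session_creator=tf.train.ChiefSessionCreator(), hooks=hooks) as sess:
    for _ in range(3):
        sess.run(train_op)             # hooks run before/after every call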
tensorpack/tfutils/common.py

@@ -116,11 +116,24 @@ def get_tensors_by_names(names):
 def get_op_or_tensor_by_name(name):
+    """
+    Get either tf.Operation of tf.Tensor from names.
+
+    Args:
+        name (list[str] or str): names of operations or tensors.
+    """
     G = tf.get_default_graph()
-    if len(name) >= 3 and name[-2] == ':':
-        return G.get_tensor_by_name(name)
+
+    def f(n):
+        if len(n) >= 3 and n[-2] == ':':
+            return G.get_tensor_by_name(n)
+        else:
+            return G.get_operation_by_name(n)
+
+    if not isinstance(name, list):
+        return f(name)
     else:
-        return G.get_operation_by_name(name)
+        return map(f, name)

 def get_name_scope_name():
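After this change the helper accepts either a single name or a list of names: anything whose second-to-last character is ':' (i.e. a "name:output_index" string) is looked up as a tf.Tensor, everything else as a tf.Operation. A brief usage sketch with names invented here for illustration; note that map() returns a list on Python 2 but a lazy iterator on Python 3.

import tensorflow as tf
from tensorpack.tfutils.common import get_op_or_tensor_by_name

x = tf.placeholder(tf.float32, name='input')   # defines tensor "input:0"
y = tf.identity(x, name='output')              # defines op "output" and tensor "output:0"

t = get_op_or_tensor_by_name('output:0')       # "name:index" form -> tf.Tensor
op = get_op_or_tensor_by_name('output')        # plain name        -> tf.Operation
both = get_op_or_tensor_by_name(['input:0', 'output'])   # list in -> mapped results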
tensorpack/train/base.py

@@ -72,7 +72,8 @@ class Trainer(object):
         This function should only get called after :meth:`setup()` has finished.
         """
-        return self._extra_fetches
+        # TODO remove this func
+        return []

     def trigger_epoch(self):
         """

@@ -130,7 +131,6 @@ class Trainer(object):
         # some final operations that might modify the graph
         logger.info("Setup callbacks graph ...")
         self.config.callbacks.setup_graph(weakref.proxy(self))
-        self._extra_fetches = self.config.callbacks.extra_fetches()
         logger.info("Setup summaries ...")
         self.summary_writer = tf.summary.FileWriter(logger.LOG_DIR, graph=tf.get_default_graph())

@@ -149,7 +149,7 @@ class Trainer(object):
         self.monitored_sess = tf.train.MonitoredSession(
             session_creator=tf.train.ChiefSessionCreator(
                 scaffold=scaffold, config=self.config.session_config),
-            hooks=None)
+            hooks=self.config.callbacks.get_hooks())
         self.sess = self.monitored_sess._tf_sess()
         self.config.session_init._run_init(self.sess)

@@ -182,12 +182,7 @@ class Trainer(object):
                 for self.local_step in range(self.config.steps_per_epoch):
                     if self.monitored_sess.should_stop():
                         return
-                    fetch_data = self.run_step()  # implemented by subclass
-                    if fetch_data is None:
-                        # old trainer doesn't return fetch data
-                        callbacks.trigger_step()
-                    else:
-                        callbacks.trigger_step(*fetch_data)
+                    self.run_step()  # implemented by subclass
                 logger.info("Epoch {} (global_step {}) finished, time:{:.2f} sec.".format(
                     self.epoch_num, self.global_step, time.time() - start_time))
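The swap from hooks=None to hooks=self.config.callbacks.get_hooks() is what lets the epoch loop above drop its explicit trigger_step bookkeeping: MonitoredSession merges every hook's SessionRunArgs into the single underlying Session.run of each step and routes each hook's results back to it. A small illustrative sketch (the FetchHook class is hypothetical, TF 1.x API):

import tensorflow as tf

class FetchHook(tf.train.SessionRunHook):
    # hypothetical hook: fetches one tensor and reports its value after each run
    def __init__(self, tensor, name):
        self.tensor, self.name = tensor, name

    def before_run(self, _):
        return tf.train.SessionRunArgs(fetches=self.tensor)

    def after_run(self, _, vals):
        # each hook sees only the results of its own fetches
        print('%s fetched %s' % (self.name, vals.results))

a = tf.constant(1, name='a')
b = tf.constant(2, name='b')
train_op = tf.no_op(name='train_op')
hooks = [FetchHook(a, 'hook_a'), FetchHook(b, 'hook_b')]
with tf.train.MonitoredSession(
        session_creator=tf.train.ChiefSessionCreator(), hooks=hooks) as sess:
    sess.run(train_op)    # one underlying Session.run; both hooks get their results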
tensorpack/train/feedfree.py

@@ -63,8 +63,7 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
     def run_step(self):
         """ Simply run ``self.train_op``, which minimizes the cost."""
-        ret = self.sess.run([self.train_op] + self.get_extra_fetches())
-        return ret[1:]
+        self.monitored_sess.run(self.train_op)
         # if not hasattr(self, 'cnt'):
         #     self.cnt = 0
         # else:
tensorpack/train/trainer.py

@@ -87,9 +87,7 @@ class SimpleTrainer(Trainer):
     def run_step(self):
         """ Feed data into the graph and run the updates. """
         feed = self._input_method.next_feed()
-        ret = self.sess.run([self.train_op] + self.get_extra_fetches(),
-                            feed_dict=feed)
-        return ret[1:]
+        self.monitored_sess.run(self.train_op, feed_dict=feed)

     def _setup(self):
         self._input_method._setup(self)