Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
818e3faf
Commit
818e3faf
authored
May 28, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
clean-ups in callbacks
parent
0d20cb3d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
30 additions
and
29 deletions
+30
-29
tensorpack/callbacks/group.py
tensorpack/callbacks/group.py
+29
-27
tensorpack/callbacks/inference.py
tensorpack/callbacks/inference.py
+1
-2
No files found.
tensorpack/callbacks/group.py
View file @
818e3faf
...
@@ -12,6 +12,7 @@ from ..utils import *
...
@@ -12,6 +12,7 @@ from ..utils import *
__all__
=
[
'Callbacks'
]
__all__
=
[
'Callbacks'
]
# --- Test-Callback related stuff seems not very useful.
@
contextmanager
@
contextmanager
def
create_test_graph
(
trainer
):
def
create_test_graph
(
trainer
):
model
=
trainer
.
model
model
=
trainer
.
model
...
@@ -31,33 +32,6 @@ def create_test_session(trainer):
...
@@ -31,33 +32,6 @@ def create_test_session(trainer):
with
tf
.
Session
()
as
sess
:
with
tf
.
Session
()
as
sess
:
yield
sess
yield
sess
class
CallbackTimeLogger
(
object
):
def
__init__
(
self
):
self
.
times
=
[]
self
.
tot
=
0
def
add
(
self
,
name
,
time
):
self
.
tot
+=
time
self
.
times
.
append
((
name
,
time
))
@
contextmanager
def
timed_callback
(
self
,
name
):
s
=
time
.
time
()
yield
self
.
add
(
name
,
time
.
time
()
-
s
)
def
log
(
self
):
""" log the time of some heavy callbacks """
if
self
.
tot
<
3
:
return
msgs
=
[]
for
name
,
t
in
self
.
times
:
if
t
/
self
.
tot
>
0.3
and
t
>
1
:
msgs
.
append
(
"{}:{:.3f}sec"
.
format
(
name
,
t
))
logger
.
info
(
"Callbacks took {:.3f} sec in total. {}"
.
format
(
self
.
tot
,
'; '
.
join
(
msgs
)))
class
TestCallbackContext
(
object
):
class
TestCallbackContext
(
object
):
"""
"""
A class holding the context needed for running TestCallback
A class holding the context needed for running TestCallback
...
@@ -91,6 +65,34 @@ class TestCallbackContext(object):
...
@@ -91,6 +65,34 @@ class TestCallbackContext(object):
def
test_context
(
self
):
def
test_context
(
self
):
with
self
.
graph
.
as_default
(),
self
.
sess
.
as_default
():
with
self
.
graph
.
as_default
(),
self
.
sess
.
as_default
():
yield
yield
# ---
class
CallbackTimeLogger
(
object
):
def
__init__
(
self
):
self
.
times
=
[]
self
.
tot
=
0
def
add
(
self
,
name
,
time
):
self
.
tot
+=
time
self
.
times
.
append
((
name
,
time
))
@
contextmanager
def
timed_callback
(
self
,
name
):
s
=
time
.
time
()
yield
self
.
add
(
name
,
time
.
time
()
-
s
)
def
log
(
self
):
""" log the time of some heavy callbacks """
if
self
.
tot
<
3
:
return
msgs
=
[]
for
name
,
t
in
self
.
times
:
if
t
/
self
.
tot
>
0.3
and
t
>
1
:
msgs
.
append
(
"{}:{:.3f}sec"
.
format
(
name
,
t
))
logger
.
info
(
"Callbacks took {:.3f} sec in total. {}"
.
format
(
self
.
tot
,
'; '
.
join
(
msgs
)))
class
Callbacks
(
Callback
):
class
Callbacks
(
Callback
):
"""
"""
...
...
tensorpack/callbacks/inference.py
View file @
818e3faf
...
@@ -13,7 +13,7 @@ from ..utils import *
...
@@ -13,7 +13,7 @@ from ..utils import *
from
..utils.stat
import
*
from
..utils.stat
import
*
from
..tfutils
import
*
from
..tfutils
import
*
from
..tfutils.summary
import
*
from
..tfutils.summary
import
*
from
.base
import
Callback
,
TestCallbackType
from
.base
import
Callback
__all__
=
[
'InferenceRunner'
,
'ClassificationError'
,
__all__
=
[
'InferenceRunner'
,
'ClassificationError'
,
'ScalarStats'
,
'Inferencer'
,
'BinaryClassificationStats'
]
'ScalarStats'
,
'Inferencer'
,
'BinaryClassificationStats'
]
...
@@ -63,7 +63,6 @@ class InferenceRunner(Callback):
...
@@ -63,7 +63,6 @@ class InferenceRunner(Callback):
"""
"""
A callback that runs different kinds of inferencer.
A callback that runs different kinds of inferencer.
"""
"""
#type = TestCallbackType()
def
__init__
(
self
,
ds
,
vcs
):
def
__init__
(
self
,
ds
,
vcs
):
"""
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment