Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
a4d51a2d
Commit
a4d51a2d
authored
Feb 15, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
stat holder and summary writer
parent
51c58dfa
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
100 additions
and
56 deletions
+100
-56
example_mnist.py
example_mnist.py
+1
-1
tensorpack/callbacks/common.py
tensorpack/callbacks/common.py
+3
-37
tensorpack/callbacks/group.py
tensorpack/callbacks/group.py
+7
-9
tensorpack/callbacks/summary.py
tensorpack/callbacks/summary.py
+75
-0
tensorpack/callbacks/validation_callback.py
tensorpack/callbacks/validation_callback.py
+4
-5
tensorpack/utils/logger.py
tensorpack/utils/logger.py
+7
-0
tensorpack/utils/naming.py
tensorpack/utils/naming.py
+3
-4
No files found.
example_mnist.py
View file @
a4d51a2d
...
...
@@ -92,7 +92,7 @@ def get_config():
dataset_train
=
BatchData
(
dataset
.
Mnist
(
'train'
),
128
)
dataset_test
=
BatchData
(
dataset
.
Mnist
(
'test'
),
256
,
remainder
=
True
)
step_per_epoch
=
dataset_train
.
size
()
step_per_epoch
=
3
#step_per_epoch = 20
# prepare session
sess_config
=
get_default_sess_config
()
...
...
tensorpack/callbacks/common.py
View file @
a4d51a2d
...
...
@@ -10,7 +10,7 @@ import re
from
.base
import
Callback
,
PeriodicCallback
from
..utils
import
*
__all__
=
[
'PeriodicSaver'
,
'SummaryWriter'
]
__all__
=
[
'PeriodicSaver'
]
class
PeriodicSaver
(
PeriodicCallback
):
def
__init__
(
self
,
period
=
1
,
keep_recent
=
10
,
keep_freq
=
0.5
):
...
...
@@ -30,39 +30,5 @@ class PeriodicSaver(PeriodicCallback):
self
.
path
,
global_step
=
self
.
global_step
)
class
SummaryWriter
(
Callback
):
def
__init__
(
self
,
print_tag
=
None
):
""" if None, print all scalar summary"""
self
.
log_dir
=
logger
.
LOG_DIR
self
.
print_tag
=
print_tag
def
_before_train
(
self
):
self
.
writer
=
tf
.
train
.
SummaryWriter
(
self
.
log_dir
,
graph_def
=
self
.
sess
.
graph_def
)
tf
.
add_to_collection
(
SUMMARY_WRITER_COLLECTION_KEY
,
self
.
writer
)
self
.
summary_op
=
tf
.
merge_all_summaries
()
self
.
epoch_num
=
0
def
_trigger_epoch
(
self
):
self
.
epoch_num
+=
1
# check if there is any summary to write
if
self
.
summary_op
is
None
:
return
summary_str
=
self
.
summary_op
.
eval
()
summary
=
tf
.
Summary
.
FromString
(
summary_str
)
printed_tag
=
set
()
for
val
in
summary
.
value
:
if
val
.
WhichOneof
(
'value'
)
==
'simple_value'
:
val
.
tag
=
re
.
sub
(
'tower[0-9]*/'
,
''
,
val
.
tag
)
if
self
.
print_tag
is
None
or
val
.
tag
in
self
.
print_tag
:
logger
.
info
(
'{}: {:.4f}'
.
format
(
val
.
tag
,
val
.
simple_value
))
printed_tag
.
add
(
val
.
tag
)
self
.
writer
.
add_summary
(
summary
,
get_global_step
())
if
self
.
print_tag
is
not
None
and
self
.
epoch_num
==
1
:
if
len
(
printed_tag
)
!=
len
(
self
.
print_tag
):
logger
.
warn
(
"Tags to print not found in Summary Writer: {}"
.
format
(
", "
.
join
([
k
for
k
in
self
.
print_tag
if
k
not
in
printed_tag
])))
def
_after_train
(
self
):
self
.
writer
.
close
()
class
MinSaver
(
Callback
):
pass
tensorpack/callbacks/group.py
View file @
a4d51a2d
...
...
@@ -7,7 +7,7 @@ import tensorflow as tf
from
contextlib
import
contextmanager
from
.base
import
Callback
from
.
common
import
*
from
.
summary
import
*
from
..utils
import
*
__all__
=
[
'Callbacks'
]
...
...
@@ -57,18 +57,18 @@ class CallbackTimeLogger(object):
class
TrainCallbacks
(
Callback
):
def
__init__
(
self
,
callbacks
):
self
.
cbs
=
callbacks
# put SummaryWriter to the first
for
idx
,
cb
in
enumerate
(
self
.
cbs
):
# put SummaryWriter to the beginning
if
type
(
cb
)
==
SummaryWriter
:
self
.
cbs
.
insert
(
0
,
self
.
cbs
.
pop
(
idx
))
break
else
:
raise
ValueError
(
"Callbacks must contain a SummaryWriter!"
)
logger
.
warn
(
"SummaryWriter must be used! Insert a default one automatically."
)
self
.
cbs
.
insert
(
0
,
SummaryWriter
())
def
_before_train
(
self
):
for
cb
in
self
.
cbs
:
cb
.
before_train
()
self
.
writer
=
tf
.
get_collection
(
SUMMARY_WRITER_COLLECTION_KEY
)[
0
]
def
_after_train
(
self
):
for
cb
in
self
.
cbs
:
...
...
@@ -84,7 +84,6 @@ class TrainCallbacks(Callback):
s
=
time
.
time
()
cb
.
trigger_epoch
()
tm
.
add
(
type
(
cb
)
.
__name__
,
time
.
time
()
-
s
)
self
.
writer
.
flush
()
tm
.
log
()
class
TestCallbacks
(
Callback
):
...
...
@@ -97,13 +96,11 @@ class TestCallbacks(Callback):
self
.
cbs
=
callbacks
def
_before_train
(
self
):
self
.
writer
=
tf
.
get_collection
(
SUMMARY_WRITER_COLLECTION_KEY
)[
0
]
with
create_test_session
()
as
sess
:
self
.
sess
=
sess
self
.
graph
=
sess
.
graph
self
.
saver
=
tf
.
train
.
Saver
()
tf
.
add_to_collection
(
SUMMARY_WRITER_COLLECTION_KEY
,
self
.
writer
)
for
cb
in
self
.
cbs
:
cb
.
before_train
()
...
...
@@ -130,7 +127,6 @@ class TestCallbacks(Callback):
s
=
time
.
time
()
cb
.
trigger_epoch
()
tm
.
add
(
type
(
cb
)
.
__name__
,
time
.
time
()
-
s
)
self
.
writer
.
flush
()
tm
.
log
()
class
Callbacks
(
Callback
):
...
...
@@ -161,6 +157,7 @@ class Callbacks(Callback):
self
.
train
.
after_train
()
if
self
.
test
:
self
.
test
.
after_train
()
logger
.
writer
.
close
()
def
trigger_step
(
self
):
self
.
train
.
trigger_step
()
...
...
@@ -168,6 +165,7 @@ class Callbacks(Callback):
def
_trigger_epoch
(
self
):
self
.
train
.
trigger_epoch
()
# TODO test callbacks can be run async?
if
self
.
test
:
self
.
test
.
trigger_epoch
()
logger
.
writer
.
flush
()
logger
.
stat_holder
.
finalize
()
tensorpack/callbacks/summary.py
0 → 100644
View file @
a4d51a2d
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
# File: summary.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import
tensorflow
as
tf
import
re
import
os
import
operator
import
cPickle
as
pickle
from
.base
import
Callback
,
PeriodicCallback
from
..utils
import
*
__all__
=
[
'SummaryWriter'
]
class
StatHolder
(
object
):
def
__init__
(
self
,
log_dir
,
print_tag
=
None
):
self
.
print_tag
=
None
if
print_tag
is
None
else
set
(
print_tag
)
self
.
stat_now
=
{}
self
.
log_dir
=
log_dir
self
.
filename
=
os
.
path
.
join
(
log_dir
,
'stat.pkl'
)
if
os
.
path
.
isfile
(
self
.
filename
):
logger
.
info
(
"Loading stats from {}..."
.
format
(
self
.
filename
))
with
open
(
self
.
filename
)
as
f
:
self
.
stat_history
=
pickle
.
load
(
f
)
else
:
self
.
stat_history
=
[]
def
add_stat
(
self
,
k
,
v
):
self
.
stat_now
[
k
]
=
v
def
finalize
(
self
):
self
.
_print_stat
()
self
.
stat_history
.
append
(
self
.
stat_now
)
self
.
stat_now
=
{}
self
.
_write_stat
()
def
_print_stat
(
self
):
for
k
,
v
in
sorted
(
self
.
stat_now
.
items
(),
key
=
operator
.
itemgetter
(
0
)):
if
self
.
print_tag
is
None
or
k
in
self
.
print_tag
:
logger
.
info
(
'{}: {:.4f}'
.
format
(
k
,
v
))
def
_write_stat
(
self
):
tmp_filename
=
self
.
filename
+
'.tmp'
with
open
(
tmp_filename
,
'wb'
)
as
f
:
pickle
.
dump
(
self
.
stat_history
,
f
)
os
.
rename
(
tmp_filename
,
self
.
filename
)
class
SummaryWriter
(
Callback
):
def
__init__
(
self
,
print_tag
=
None
):
""" print_tag : a list of regex to match scalar summary to print
if None, will print all scalar tags
"""
self
.
log_dir
=
logger
.
LOG_DIR
logger
.
stat_holder
=
StatHolder
(
self
.
log_dir
,
print_tag
)
def
_before_train
(
self
):
logger
.
writer
=
tf
.
train
.
SummaryWriter
(
self
.
log_dir
,
graph_def
=
self
.
sess
.
graph_def
)
self
.
summary_op
=
tf
.
merge_all_summaries
()
def
_trigger_epoch
(
self
):
# check if there is any summary to write
if
self
.
summary_op
is
None
:
return
summary_str
=
self
.
summary_op
.
eval
()
summary
=
tf
.
Summary
.
FromString
(
summary_str
)
for
val
in
summary
.
value
:
if
val
.
WhichOneof
(
'value'
)
==
'simple_value'
:
val
.
tag
=
re
.
sub
(
'tower[0-9]*/'
,
''
,
val
.
tag
)
logger
.
stat_holder
.
add_stat
(
val
.
tag
,
val
.
simple_value
)
logger
.
writer
.
add_summary
(
summary
,
self
.
global_step
)
tensorpack/callbacks/validation_callback.py
View file @
a4d51a2d
...
...
@@ -28,7 +28,6 @@ class ValidationCallback(PeriodicCallback):
def
_before_train
(
self
):
self
.
input_vars
=
tf
.
get_collection
(
MODEL_KEY
)[
0
]
.
get_input_vars
()
self
.
cost_var
=
self
.
get_tensor
(
self
.
cost_var_name
)
self
.
writer
=
tf
.
get_collection
(
SUMMARY_WRITER_COLLECTION_KEY
)[
0
]
self
.
_find_output_vars
()
def
get_tensor
(
self
,
name
):
...
...
@@ -64,9 +63,9 @@ class ValidationCallback(PeriodicCallback):
pbar
.
update
()
cost_avg
=
cost_sum
/
cnt
self
.
writer
.
add_summary
(
create_summary
(
logger
.
writer
.
add_summary
(
create_summary
(
'{}_cost'
.
format
(
self
.
prefix
),
cost_avg
),
self
.
global_step
)
logger
.
info
(
"{}_cost: {:.4f}"
.
format
(
self
.
prefix
,
cost_avg
)
)
logger
.
stat_holder
.
add_stat
(
"{}_cost"
.
format
(
self
.
prefix
),
cost_avg
)
def
_trigger_periodic
(
self
):
for
dp
,
outputs
in
self
.
_run_validation
():
...
...
@@ -102,6 +101,6 @@ class ValidationError(ValidationCallback):
wrong
=
outputs
[
0
]
err_stat
.
feed
(
wrong
,
batch_size
)
self
.
writer
.
add_summary
(
create_summary
(
logger
.
writer
.
add_summary
(
create_summary
(
'{}_error'
.
format
(
self
.
prefix
),
err_stat
.
accuracy
),
self
.
global_step
)
logger
.
info
(
"{}_error: {:.4f}"
.
format
(
self
.
prefix
,
err_stat
.
accuracy
)
)
logger
.
stat_holder
.
add_stat
(
"{}_error"
.
format
(
self
.
prefix
),
err_stat
.
accuracy
)
tensorpack/utils/logger.py
View file @
a4d51a2d
...
...
@@ -61,3 +61,10 @@ def set_logger_file(filename):
mkdir_p
(
os
.
path
.
dirname
(
LOG_FILE
))
set_file
(
LOG_FILE
)
# global logger:
# a SummaryWriter
writer
=
None
# a StatHolder
stat_holder
=
None
tensorpack/utils/naming.py
View file @
a4d51a2d
...
...
@@ -6,11 +6,10 @@
GLOBAL_STEP_OP_NAME
=
'global_step'
GLOBAL_STEP_VAR_NAME
=
'global_step:0'
SUMMARY_WRITER_COLLECTION_KEY
=
'summary_writer'
MOVING_SUMMARY_VARS_KEY
=
'MOVING_SUMMARY_VARIABLES'
# extra variables to summarize during training
# extra variables to summarize during training in a moving-average way
MOVING_SUMMARY_VARS_KEY
=
'MOVING_SUMMARY_VARIABLES'
MODEL_KEY
=
'MODEL'
# export all upper case variables
all_local_names
=
locals
()
.
keys
()
__all__
=
[
x
for
x
in
all_local_names
if
x
.
upper
()
==
x
]
__all__
=
[
x
for
x
in
all_local_names
if
x
.
isupper
()
]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment