Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
a4d51a2d
You need to sign in or sign up before continuing.
Commit
a4d51a2d
authored
Feb 15, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
stat holder and summary writer
parent
51c58dfa
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
100 additions
and
56 deletions
+100
-56
example_mnist.py
example_mnist.py
+1
-1
tensorpack/callbacks/common.py
tensorpack/callbacks/common.py
+3
-37
tensorpack/callbacks/group.py
tensorpack/callbacks/group.py
+7
-9
tensorpack/callbacks/summary.py
tensorpack/callbacks/summary.py
+75
-0
tensorpack/callbacks/validation_callback.py
tensorpack/callbacks/validation_callback.py
+4
-5
tensorpack/utils/logger.py
tensorpack/utils/logger.py
+7
-0
tensorpack/utils/naming.py
tensorpack/utils/naming.py
+3
-4
No files found.
example_mnist.py
View file @
a4d51a2d
...
...
@@ -92,7 +92,7 @@ def get_config():
dataset_train
=
BatchData
(
dataset
.
Mnist
(
'train'
),
128
)
dataset_test
=
BatchData
(
dataset
.
Mnist
(
'test'
),
256
,
remainder
=
True
)
step_per_epoch
=
dataset_train
.
size
()
step_per_epoch
=
3
#step_per_epoch = 20
# prepare session
sess_config
=
get_default_sess_config
()
...
...
tensorpack/callbacks/common.py
View file @
a4d51a2d
...
...
@@ -10,7 +10,7 @@ import re
from
.base
import
Callback
,
PeriodicCallback
from
..utils
import
*
__all__
=
[
'PeriodicSaver'
,
'SummaryWriter'
]
__all__
=
[
'PeriodicSaver'
]
class
PeriodicSaver
(
PeriodicCallback
):
def
__init__
(
self
,
period
=
1
,
keep_recent
=
10
,
keep_freq
=
0.5
):
...
...
@@ -30,39 +30,5 @@ class PeriodicSaver(PeriodicCallback):
self
.
path
,
global_step
=
self
.
global_step
)
class
SummaryWriter
(
Callback
):
def
__init__
(
self
,
print_tag
=
None
):
""" if None, print all scalar summary"""
self
.
log_dir
=
logger
.
LOG_DIR
self
.
print_tag
=
print_tag
def
_before_train
(
self
):
self
.
writer
=
tf
.
train
.
SummaryWriter
(
self
.
log_dir
,
graph_def
=
self
.
sess
.
graph_def
)
tf
.
add_to_collection
(
SUMMARY_WRITER_COLLECTION_KEY
,
self
.
writer
)
self
.
summary_op
=
tf
.
merge_all_summaries
()
self
.
epoch_num
=
0
def
_trigger_epoch
(
self
):
self
.
epoch_num
+=
1
# check if there is any summary to write
if
self
.
summary_op
is
None
:
return
summary_str
=
self
.
summary_op
.
eval
()
summary
=
tf
.
Summary
.
FromString
(
summary_str
)
printed_tag
=
set
()
for
val
in
summary
.
value
:
if
val
.
WhichOneof
(
'value'
)
==
'simple_value'
:
val
.
tag
=
re
.
sub
(
'tower[0-9]*/'
,
''
,
val
.
tag
)
if
self
.
print_tag
is
None
or
val
.
tag
in
self
.
print_tag
:
logger
.
info
(
'{}: {:.4f}'
.
format
(
val
.
tag
,
val
.
simple_value
))
printed_tag
.
add
(
val
.
tag
)
self
.
writer
.
add_summary
(
summary
,
get_global_step
())
if
self
.
print_tag
is
not
None
and
self
.
epoch_num
==
1
:
if
len
(
printed_tag
)
!=
len
(
self
.
print_tag
):
logger
.
warn
(
"Tags to print not found in Summary Writer: {}"
.
format
(
", "
.
join
([
k
for
k
in
self
.
print_tag
if
k
not
in
printed_tag
])))
def
_after_train
(
self
):
self
.
writer
.
close
()
class
MinSaver
(
Callback
):
pass
tensorpack/callbacks/group.py
View file @
a4d51a2d
...
...
@@ -7,7 +7,7 @@ import tensorflow as tf
from
contextlib
import
contextmanager
from
.base
import
Callback
from
.
common
import
*
from
.
summary
import
*
from
..utils
import
*
__all__
=
[
'Callbacks'
]
...
...
@@ -57,18 +57,18 @@ class CallbackTimeLogger(object):
class
TrainCallbacks
(
Callback
):
def
__init__
(
self
,
callbacks
):
self
.
cbs
=
callbacks
# put SummaryWriter to the first
for
idx
,
cb
in
enumerate
(
self
.
cbs
):
# put SummaryWriter to the beginning
if
type
(
cb
)
==
SummaryWriter
:
self
.
cbs
.
insert
(
0
,
self
.
cbs
.
pop
(
idx
))
break
else
:
raise
ValueError
(
"Callbacks must contain a SummaryWriter!"
)
logger
.
warn
(
"SummaryWriter must be used! Insert a default one automatically."
)
self
.
cbs
.
insert
(
0
,
SummaryWriter
())
def
_before_train
(
self
):
for
cb
in
self
.
cbs
:
cb
.
before_train
()
self
.
writer
=
tf
.
get_collection
(
SUMMARY_WRITER_COLLECTION_KEY
)[
0
]
def
_after_train
(
self
):
for
cb
in
self
.
cbs
:
...
...
@@ -84,7 +84,6 @@ class TrainCallbacks(Callback):
s
=
time
.
time
()
cb
.
trigger_epoch
()
tm
.
add
(
type
(
cb
)
.
__name__
,
time
.
time
()
-
s
)
self
.
writer
.
flush
()
tm
.
log
()
class
TestCallbacks
(
Callback
):
...
...
@@ -97,13 +96,11 @@ class TestCallbacks(Callback):
self
.
cbs
=
callbacks
def
_before_train
(
self
):
self
.
writer
=
tf
.
get_collection
(
SUMMARY_WRITER_COLLECTION_KEY
)[
0
]
with
create_test_session
()
as
sess
:
self
.
sess
=
sess
self
.
graph
=
sess
.
graph
self
.
saver
=
tf
.
train
.
Saver
()
tf
.
add_to_collection
(
SUMMARY_WRITER_COLLECTION_KEY
,
self
.
writer
)
for
cb
in
self
.
cbs
:
cb
.
before_train
()
...
...
@@ -130,7 +127,6 @@ class TestCallbacks(Callback):
s
=
time
.
time
()
cb
.
trigger_epoch
()
tm
.
add
(
type
(
cb
)
.
__name__
,
time
.
time
()
-
s
)
self
.
writer
.
flush
()
tm
.
log
()
class
Callbacks
(
Callback
):
...
...
@@ -161,6 +157,7 @@ class Callbacks(Callback):
self
.
train
.
after_train
()
if
self
.
test
:
self
.
test
.
after_train
()
logger
.
writer
.
close
()
def
trigger_step
(
self
):
self
.
train
.
trigger_step
()
...
...
@@ -168,6 +165,7 @@ class Callbacks(Callback):
def
_trigger_epoch
(
self
):
self
.
train
.
trigger_epoch
()
# TODO test callbacks can be run async?
if
self
.
test
:
self
.
test
.
trigger_epoch
()
logger
.
writer
.
flush
()
logger
.
stat_holder
.
finalize
()
tensorpack/callbacks/summary.py
0 → 100644
View file @
a4d51a2d
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
# File: summary.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import
tensorflow
as
tf
import
re
import
os
import
operator
import
cPickle
as
pickle
from
.base
import
Callback
,
PeriodicCallback
from
..utils
import
*
__all__
=
[
'SummaryWriter'
]
class
StatHolder
(
object
):
def
__init__
(
self
,
log_dir
,
print_tag
=
None
):
self
.
print_tag
=
None
if
print_tag
is
None
else
set
(
print_tag
)
self
.
stat_now
=
{}
self
.
log_dir
=
log_dir
self
.
filename
=
os
.
path
.
join
(
log_dir
,
'stat.pkl'
)
if
os
.
path
.
isfile
(
self
.
filename
):
logger
.
info
(
"Loading stats from {}..."
.
format
(
self
.
filename
))
with
open
(
self
.
filename
)
as
f
:
self
.
stat_history
=
pickle
.
load
(
f
)
else
:
self
.
stat_history
=
[]
def
add_stat
(
self
,
k
,
v
):
self
.
stat_now
[
k
]
=
v
def
finalize
(
self
):
self
.
_print_stat
()
self
.
stat_history
.
append
(
self
.
stat_now
)
self
.
stat_now
=
{}
self
.
_write_stat
()
def
_print_stat
(
self
):
for
k
,
v
in
sorted
(
self
.
stat_now
.
items
(),
key
=
operator
.
itemgetter
(
0
)):
if
self
.
print_tag
is
None
or
k
in
self
.
print_tag
:
logger
.
info
(
'{}: {:.4f}'
.
format
(
k
,
v
))
def
_write_stat
(
self
):
tmp_filename
=
self
.
filename
+
'.tmp'
with
open
(
tmp_filename
,
'wb'
)
as
f
:
pickle
.
dump
(
self
.
stat_history
,
f
)
os
.
rename
(
tmp_filename
,
self
.
filename
)
class
SummaryWriter
(
Callback
):
def
__init__
(
self
,
print_tag
=
None
):
""" print_tag : a list of regex to match scalar summary to print
if None, will print all scalar tags
"""
self
.
log_dir
=
logger
.
LOG_DIR
logger
.
stat_holder
=
StatHolder
(
self
.
log_dir
,
print_tag
)
def
_before_train
(
self
):
logger
.
writer
=
tf
.
train
.
SummaryWriter
(
self
.
log_dir
,
graph_def
=
self
.
sess
.
graph_def
)
self
.
summary_op
=
tf
.
merge_all_summaries
()
def
_trigger_epoch
(
self
):
# check if there is any summary to write
if
self
.
summary_op
is
None
:
return
summary_str
=
self
.
summary_op
.
eval
()
summary
=
tf
.
Summary
.
FromString
(
summary_str
)
for
val
in
summary
.
value
:
if
val
.
WhichOneof
(
'value'
)
==
'simple_value'
:
val
.
tag
=
re
.
sub
(
'tower[0-9]*/'
,
''
,
val
.
tag
)
logger
.
stat_holder
.
add_stat
(
val
.
tag
,
val
.
simple_value
)
logger
.
writer
.
add_summary
(
summary
,
self
.
global_step
)
tensorpack/callbacks/validation_callback.py
View file @
a4d51a2d
...
...
@@ -28,7 +28,6 @@ class ValidationCallback(PeriodicCallback):
def
_before_train
(
self
):
self
.
input_vars
=
tf
.
get_collection
(
MODEL_KEY
)[
0
]
.
get_input_vars
()
self
.
cost_var
=
self
.
get_tensor
(
self
.
cost_var_name
)
self
.
writer
=
tf
.
get_collection
(
SUMMARY_WRITER_COLLECTION_KEY
)[
0
]
self
.
_find_output_vars
()
def
get_tensor
(
self
,
name
):
...
...
@@ -64,9 +63,9 @@ class ValidationCallback(PeriodicCallback):
pbar
.
update
()
cost_avg
=
cost_sum
/
cnt
self
.
writer
.
add_summary
(
create_summary
(
logger
.
writer
.
add_summary
(
create_summary
(
'{}_cost'
.
format
(
self
.
prefix
),
cost_avg
),
self
.
global_step
)
logger
.
info
(
"{}_cost: {:.4f}"
.
format
(
self
.
prefix
,
cost_avg
)
)
logger
.
stat_holder
.
add_stat
(
"{}_cost"
.
format
(
self
.
prefix
),
cost_avg
)
def
_trigger_periodic
(
self
):
for
dp
,
outputs
in
self
.
_run_validation
():
...
...
@@ -102,6 +101,6 @@ class ValidationError(ValidationCallback):
wrong
=
outputs
[
0
]
err_stat
.
feed
(
wrong
,
batch_size
)
self
.
writer
.
add_summary
(
create_summary
(
logger
.
writer
.
add_summary
(
create_summary
(
'{}_error'
.
format
(
self
.
prefix
),
err_stat
.
accuracy
),
self
.
global_step
)
logger
.
info
(
"{}_error: {:.4f}"
.
format
(
self
.
prefix
,
err_stat
.
accuracy
)
)
logger
.
stat_holder
.
add_stat
(
"{}_error"
.
format
(
self
.
prefix
),
err_stat
.
accuracy
)
tensorpack/utils/logger.py
View file @
a4d51a2d
...
...
@@ -61,3 +61,10 @@ def set_logger_file(filename):
mkdir_p
(
os
.
path
.
dirname
(
LOG_FILE
))
set_file
(
LOG_FILE
)
# global logger:
# a SummaryWriter
writer
=
None
# a StatHolder
stat_holder
=
None
tensorpack/utils/naming.py
View file @
a4d51a2d
...
...
@@ -6,11 +6,10 @@
GLOBAL_STEP_OP_NAME
=
'global_step'
GLOBAL_STEP_VAR_NAME
=
'global_step:0'
SUMMARY_WRITER_COLLECTION_KEY
=
'summary_writer'
MOVING_SUMMARY_VARS_KEY
=
'MOVING_SUMMARY_VARIABLES'
# extra variables to summarize during training
# extra variables to summarize during training in a moving-average way
MOVING_SUMMARY_VARS_KEY
=
'MOVING_SUMMARY_VARIABLES'
MODEL_KEY
=
'MODEL'
# export all upper case variables
all_local_names
=
locals
()
.
keys
()
__all__
=
[
x
for
x
in
all_local_names
if
x
.
upper
()
==
x
]
__all__
=
[
x
for
x
in
all_local_names
if
x
.
isupper
()
]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment