Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
1cc356ea
Commit
1cc356ea
authored
Jan 04, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
validation with base routine
parent
ed4e5106
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
60 additions
and
26 deletions
+60
-26
example_mnist.py
example_mnist.py
+2
-2
tensorpack/callbacks/validation_callback.py
tensorpack/callbacks/validation_callback.py
+58
-24
No files found.
example_mnist.py
View file @
1cc356ea
...
...
@@ -88,14 +88,14 @@ def get_model(inputs, is_training):
def
get_config
():
basename
=
os
.
path
.
basename
(
__file__
)
log_dir
=
os
.
path
.
join
(
'train_log'
,
basename
[:
basename
.
rfind
(
'.'
)])
logger
.
set_logger_
dir
(
log_dir
)
logger
.
set_logger_
file
(
os
.
path
.
join
(
log_dir
,
'training.log'
)
)
IMAGE_SIZE
=
28
dataset_train
=
BatchData
(
dataset
.
Mnist
(
'train'
),
128
)
dataset_test
=
BatchData
(
dataset
.
Mnist
(
'test'
),
256
,
remainder
=
True
)
step_per_epoch
=
dataset_train
.
size
()
#
step_per_epoch = 30
step_per_epoch
=
30
#dataset_test = FixedSizeData(dataset_test, 20)
sess_config
=
get_default_sess_config
()
...
...
tensorpack/callbacks/validation_callback.py
View file @
1cc356ea
...
...
@@ -12,40 +12,43 @@ from ..utils.stat import *
from
..utils.summary
import
*
from
.base
import
PeriodicCallback
,
Callback
__all__
=
[
'ValidationError'
]
__all__
=
[
'ValidationError'
,
'ValidationCallback'
]
class
Validation
Error
(
PeriodicCallback
):
class
Validation
Callback
(
PeriodicCallback
):
running_graph
=
'test'
"""
Validate the accuracy for the given wrong and cost variable
Use under the following setup:
wrong_var: integer, number of failed samples in this batch
ds: batched dataset
Basic routine for validation callbacks.
"""
def
__init__
(
self
,
ds
,
prefix
,
period
=
1
,
wrong_var_name
=
'wrong:0'
,
cost_var_name
=
'cost:0'
):
super
(
ValidationError
,
self
)
.
__init__
(
period
)
def
__init__
(
self
,
ds
,
prefix
,
period
,
cost_var_name
=
'cost:0'
):
super
(
ValidationCallback
,
self
)
.
__init__
(
period
)
self
.
ds
=
ds
self
.
prefix
=
prefix
self
.
wrong_var_name
=
wrong_var_name
self
.
cost_var_name
=
cost_var_name
def
get_tensor
(
self
,
name
):
return
self
.
graph
.
get_tensor_by_name
(
name
)
def
_before_train
(
self
):
self
.
input_vars
=
tf
.
get_collection
(
INPUT_VARS_KEY
)
self
.
wrong_var
=
self
.
get_tensor
(
self
.
wrong_var_name
)
self
.
cost_var
=
self
.
get_tensor
(
self
.
cost_var_name
)
self
.
writer
=
tf
.
get_collection
(
SUMMARY_WRITER_COLLECTION_KEY
)[
0
]
self
.
_find_output_vars
()
def
_trigger
(
self
):
def
get_tensor
(
self
,
name
):
return
self
.
graph
.
get_tensor_by_name
(
name
)
def
_find_output_vars
(
self
):
pass
def
_get_output_vars
(
self
):
return
[]
def
_run_validation
(
self
):
"""
Generator to return inputs and outputs
"""
cnt
=
0
err_stat
=
Accuracy
()
cost_sum
=
0
output_vars
=
self
.
_get_output_vars
()
output_vars
.
append
(
self
.
cost_var
)
with
tqdm
(
total
=
self
.
ds
.
size
())
as
pbar
:
for
dp
in
self
.
ds
.
get_data
():
feed
=
dict
(
itertools
.
izip
(
self
.
input_vars
,
dp
))
...
...
@@ -53,17 +56,48 @@ class ValidationError(PeriodicCallback):
batch_size
=
dp
[
0
]
.
shape
[
0
]
# assume batched input
cnt
+=
batch_size
wrong
,
cost
=
self
.
sess
.
run
(
[
self
.
wrong_var
,
self
.
cost_var
],
feed_dict
=
feed
)
err_stat
.
feed
(
wrong
,
batch_size
)
outputs
=
self
.
sess
.
run
(
output_vars
,
feed_dict
=
feed
)
cost
=
outputs
[
-
1
]
# each batch might not have the same size in validation
cost_sum
+=
cost
*
batch_size
yield
(
dp
,
outputs
[:
-
1
])
pbar
.
update
()
cost_avg
=
cost_sum
/
cnt
self
.
writer
.
add_summary
(
create_summary
(
'{}_error'
.
format
(
self
.
prefix
),
err_stat
.
accuracy
),
self
.
global_step
)
self
.
writer
.
add_summary
(
create_summary
(
'{}_cost'
.
format
(
self
.
prefix
),
cost_avg
),
self
.
global_step
)
logger
.
info
(
"{}_cost: {:.4f}"
.
format
(
self
.
prefix
,
cost_avg
))
class
ValidationError
(
ValidationCallback
):
running_graph
=
'test'
"""
Validate the accuracy for the given wrong and cost variable
Use under the following setup:
wrong_var: integer, number of failed samples in this batch
ds: batched dataset
"""
def
__init__
(
self
,
ds
,
prefix
,
period
=
1
,
wrong_var_name
=
'wrong:0'
,
cost_var_name
=
'cost:0'
):
super
(
ValidationError
,
self
)
.
__init__
(
ds
,
prefix
,
period
,
cost_var_name
)
self
.
wrong_var_name
=
wrong_var_name
def
_find_output_vars
(
self
):
self
.
wrong_var
=
self
.
get_tensor
(
self
.
wrong_var_name
)
def
_get_output_vars
(
self
):
return
[
self
.
wrong_var
]
def
_trigger
(
self
):
err_stat
=
Accuracy
()
for
dp
,
outputs
in
self
.
_run_validation
():
batch_size
=
dp
[
0
]
.
shape
[
0
]
# assume batched input
wrong
=
outputs
[
0
]
err_stat
.
feed
(
wrong
,
batch_size
)
self
.
writer
.
add_summary
(
create_summary
(
'{}_error'
.
format
(
self
.
prefix
),
err_stat
.
accuracy
),
self
.
global_step
)
logger
.
info
(
"{}_error: {:.4f}"
.
format
(
self
.
prefix
,
err_stat
.
accuracy
))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment