Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
174c3fc9
Commit
174c3fc9
authored
Apr 19, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
single-pass inference
parent
76fe1b6b
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
206 additions
and
142 deletions
+206
-142
examples/ResNet/cifar10_resnet.py
examples/ResNet/cifar10_resnet.py
+2
-1
examples/ResNet/svhn_resnet.py
examples/ResNet/svhn_resnet.py
+2
-1
examples/cifar10_convnet.py
examples/cifar10_convnet.py
+2
-2
examples/mnist_convnet.py
examples/mnist_convnet.py
+3
-3
examples/svhn_digit_convnet.py
examples/svhn_digit_convnet.py
+2
-1
tensorpack/callbacks/inference.py
tensorpack/callbacks/inference.py
+195
-0
tensorpack/callbacks/validation_callback.py
tensorpack/callbacks/validation_callback.py
+0
-133
tensorpack/tfutils/common.py
tensorpack/tfutils/common.py
+0
-1
No files found.
examples/ResNet/cifar10_resnet.py
View file @
174c3fc9
...
@@ -162,7 +162,8 @@ def get_config():
...
@@ -162,7 +162,8 @@ def get_config():
callbacks
=
Callbacks
([
callbacks
=
Callbacks
([
StatPrinter
(),
StatPrinter
(),
ModelSaver
(),
ModelSaver
(),
ClassificationError
(
dataset_test
,
prefix
=
'validation'
),
InferenceRunner
(
dataset_test
,
[
ScalarStats
(
'cost'
),
ClassificationError
()]),
ScheduledHyperParamSetter
(
'learning_rate'
,
ScheduledHyperParamSetter
(
'learning_rate'
,
[(
1
,
0.1
),
(
82
,
0.01
),
(
123
,
0.001
),
(
300
,
0.0002
)])
[(
1
,
0.1
),
(
82
,
0.01
),
(
123
,
0.001
),
(
300
,
0.0002
)])
]),
]),
...
...
examples/ResNet/svhn_resnet.py
View file @
174c3fc9
...
@@ -168,7 +168,8 @@ def get_config():
...
@@ -168,7 +168,8 @@ def get_config():
callbacks
=
Callbacks
([
callbacks
=
Callbacks
([
StatPrinter
(),
StatPrinter
(),
ModelSaver
(),
ModelSaver
(),
ClassificationError
(
dataset_test
,
prefix
=
'validation'
),
InferenceRunner
(
dataset_test
,
[
ScalarStats
(
'cost'
),
ClassificationError
()
]),
ScheduledHyperParamSetter
(
'learning_rate'
,
ScheduledHyperParamSetter
(
'learning_rate'
,
[(
1
,
0.1
),
(
20
,
0.01
),
(
33
,
0.001
),
(
60
,
0.0001
)])
[(
1
,
0.1
),
(
20
,
0.01
),
(
33
,
0.001
),
(
60
,
0.0001
)])
]),
]),
...
...
examples/cifar10_convnet.py
View file @
174c3fc9
...
@@ -124,12 +124,12 @@ def get_config():
...
@@ -124,12 +124,12 @@ def get_config():
callbacks
=
Callbacks
([
callbacks
=
Callbacks
([
StatPrinter
(),
StatPrinter
(),
ModelSaver
(),
ModelSaver
(),
ClassificationError
(
dataset_test
,
prefix
=
'test'
),
InferenceRunner
(
dataset_test
,
ClassificationError
())
]),
]),
session_config
=
sess_config
,
session_config
=
sess_config
,
model
=
Model
(),
model
=
Model
(),
step_per_epoch
=
step_per_epoch
,
step_per_epoch
=
step_per_epoch
,
max_epoch
=
2
0
,
max_epoch
=
30
0
,
)
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
examples/mnist_convnet.py
View file @
174c3fc9
...
@@ -72,7 +72,7 @@ class Model(ModelDesc):
...
@@ -72,7 +72,7 @@ class Model(ModelDesc):
name
=
'regularize_loss'
)
name
=
'regularize_loss'
)
tf
.
add_to_collection
(
MOVING_SUMMARY_VARS_KEY
,
wd_cost
)
tf
.
add_to_collection
(
MOVING_SUMMARY_VARS_KEY
,
wd_cost
)
add_param_summary
([(
'.*/W'
,
[
'histogram'
,
'sparsity'
])])
# monitor histogram of all W
add_param_summary
([(
'.*/W'
,
[
'histogram'
])])
# monitor histogram of all W
return
tf
.
add_n
([
wd_cost
,
cost
],
name
=
'cost'
)
return
tf
.
add_n
([
wd_cost
,
cost
],
name
=
'cost'
)
def
get_config
():
def
get_config
():
...
@@ -102,8 +102,8 @@ def get_config():
...
@@ -102,8 +102,8 @@ def get_config():
callbacks
=
Callbacks
([
callbacks
=
Callbacks
([
StatPrinter
(),
StatPrinter
(),
ModelSaver
(),
ModelSaver
(),
ValidationStatPrinter
(
dataset_test
,
[
'cost:0'
])
,
InferenceRunner
(
dataset_test
,
ClassificationError
(
dataset_test
,
prefix
=
'validation'
),
[
ScalarStats
(
'cost'
),
ClassificationError
()
])
]),
]),
session_config
=
sess_config
,
session_config
=
sess_config
,
model
=
Model
(),
model
=
Model
(),
...
...
examples/svhn_digit_convnet.py
View file @
174c3fc9
...
@@ -109,7 +109,8 @@ def get_config():
...
@@ -109,7 +109,8 @@ def get_config():
callbacks
=
Callbacks
([
callbacks
=
Callbacks
([
StatPrinter
(),
StatPrinter
(),
ModelSaver
(),
ModelSaver
(),
ClassificationError
(
test
,
prefix
=
'test'
),
InferenceRunner
(
dataset_test
,
[
ScalarStats
(
'cost'
),
ClassificationError
()])
]),
]),
session_config
=
sess_config
,
session_config
=
sess_config
,
model
=
Model
(),
model
=
Model
(),
...
...
tensorpack/callbacks/inference.py
0 → 100644
View file @
174c3fc9
# -*- coding: UTF-8 -*-
# File: inference.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
tensorflow
as
tf
from
tqdm
import
tqdm
from
abc
import
ABCMeta
,
abstractmethod
from
six.moves
import
zip
from
..dataflow
import
DataFlow
from
..utils
import
*
from
..utils.stat
import
*
from
..tfutils
import
*
from
..tfutils.summary
import
*
from
.base
import
Callback
,
TestCallbackType
__all__
=
[
'InferenceRunner'
,
'ClassificationError'
,
'ScalarStats'
,
'Inferencer'
]
class
Inferencer
(
object
):
__metaclass__
=
ABCMeta
def
before_inference
(
self
):
"""
Called before a new round of inference starts.
"""
self
.
_before_inference
()
def
_before_inference
(
self
):
pass
def
datapoint
(
self
,
dp
,
output
):
"""
Called after complete running every data point
"""
self
.
_datapoint
(
dp
,
output
)
@
abstractmethod
def
_datapoint
(
self
,
dp
,
output
):
pass
def
after_inference
(
self
):
"""
Called after a round of inference ends.
"""
self
.
_after_inference
()
def
_after_inference
(
self
):
pass
def
get_output_tensors
(
self
):
"""
Return a list of tensor names needed for this inference
"""
return
self
.
_get_output_vars
()
@
abstractmethod
def
_get_output_tensors
(
self
):
pass
class
InferenceRunner
(
Callback
):
"""
A callback that runs different kinds of inferencer.
"""
type
=
TestCallbackType
()
def
__init__
(
self
,
ds
,
vcs
):
"""
:param ds: inference dataset. a `DataFlow` instance.
:param vcs: a list of `Inferencer` instance.
"""
assert
isinstance
(
ds
,
DataFlow
),
type
(
ds
)
self
.
ds
=
ds
if
not
isinstance
(
vcs
,
list
):
self
.
vcs
=
[
vcs
]
else
:
self
.
vcs
=
vcs
for
v
in
self
.
vcs
:
assert
isinstance
(
v
,
Inferencer
),
str
(
v
)
def
_before_train
(
self
):
self
.
input_vars
=
self
.
trainer
.
model
.
reuse_input_vars
()
self
.
_find_output_tensors
()
for
v
in
self
.
vcs
:
v
.
trainer
=
self
.
trainer
def
_find_output_tensors
(
self
):
self
.
output_tensors
=
[]
self
.
vc_to_vars
=
[]
for
vc
in
self
.
vcs
:
vc_vars
=
vc
.
_get_output_tensors
()
def
find_oid
(
var
):
if
var
in
self
.
output_tensors
:
return
self
.
output_tensors
.
index
(
var
)
else
:
self
.
output_tensors
.
append
(
var
)
return
len
(
self
.
output_tensors
)
-
1
vc_vars
=
[(
var
,
find_oid
(
var
))
for
var
in
vc_vars
]
self
.
vc_to_vars
.
append
(
vc_vars
)
# convert name to tensors
def
get_tensor
(
name
):
_
,
varname
=
get_op_var_name
(
name
)
return
self
.
graph
.
get_tensor_by_name
(
varname
)
self
.
output_tensors
=
map
(
get_tensor
,
self
.
output_tensors
)
def
_trigger_epoch
(
self
):
for
vc
in
self
.
vcs
:
vc
.
before_inference
()
sess
=
tf
.
get_default_session
()
with
tqdm
(
total
=
self
.
ds
.
size
(),
ascii
=
True
)
as
pbar
:
for
dp
in
self
.
ds
.
get_data
():
feed
=
dict
(
zip
(
self
.
input_vars
,
dp
))
# TODO custom dp mapping?
outputs
=
sess
.
run
(
self
.
output_tensors
,
feed_dict
=
feed
)
for
vc
,
varsmap
in
zip
(
self
.
vcs
,
self
.
vc_to_vars
):
vc_output
=
[
outputs
[
k
[
1
]]
for
k
in
varsmap
]
vc
.
datapoint
(
dp
,
vc_output
)
pbar
.
update
()
for
vc
in
self
.
vcs
:
vc
.
after_inference
()
class
ScalarStats
(
Inferencer
):
"""
Write stat and summary of some scalar tensor.
The output of the given Ops must be a scalar.
The value will be averaged over all data points in the dataset.
"""
def
__init__
(
self
,
names_to_print
,
prefix
=
'validation'
):
"""
:param names_to_print: list of names of tensors, or just a name
:param prefix: an optional prefix for logging
"""
if
not
isinstance
(
names_to_print
,
list
):
self
.
names
=
[
names_to_print
]
else
:
self
.
names
=
names_to_print
self
.
prefix
=
prefix
def
_get_output_tensors
(
self
):
return
self
.
names
def
_before_inference
(
self
):
self
.
stats
=
[]
def
_datapoint
(
self
,
dp
,
output
):
self
.
stats
.
append
(
output
)
def
_after_inference
(
self
):
self
.
stats
=
np
.
mean
(
self
.
stats
,
axis
=
0
)
assert
len
(
self
.
stats
)
==
len
(
self
.
names
)
for
stat
,
name
in
zip
(
self
.
stats
,
self
.
names
):
opname
,
_
=
get_op_var_name
(
name
)
name
=
'{}_{}'
.
format
(
self
.
prefix
,
opname
)
if
self
.
prefix
else
opname
self
.
trainer
.
summary_writer
.
add_summary
(
create_summary
(
name
,
stat
),
get_global_step
())
self
.
trainer
.
stat_holder
.
add_stat
(
name
,
stat
)
class
ClassificationError
(
Inferencer
):
"""
Validate the accuracy from a `wrong` variable
The `wrong` variable is supposed to be an integer equal to the number of failed samples in this batch
This callback produce the "true" error,
taking account of the fact that batches might not have the same size in
testing (because the size of test set might not be a multiple of batch size).
In theory, the result could be different from what produced by ValidationStatPrinter.
"""
def
__init__
(
self
,
wrong_var_name
=
'wrong:0'
,
prefix
=
'validation'
):
"""
:param wrong_var_name: name of the `wrong` variable
:param prefix: an optional prefix for logging
"""
self
.
wrong_var_name
=
wrong_var_name
self
.
prefix
=
prefix
def
_get_output_tensors
(
self
):
return
[
self
.
wrong_var_name
]
def
_before_inference
(
self
):
self
.
err_stat
=
Accuracy
()
def
_datapoint
(
self
,
dp
,
outputs
):
batch_size
=
dp
[
0
]
.
shape
[
0
]
# assume batched input
wrong
=
int
(
outputs
[
0
])
self
.
err_stat
.
feed
(
wrong
,
batch_size
)
def
_after_inference
(
self
):
self
.
trainer
.
summary_writer
.
add_summary
(
create_summary
(
'{}_error'
.
format
(
self
.
prefix
),
self
.
err_stat
.
accuracy
),
get_global_step
())
self
.
trainer
.
stat_holder
.
add_stat
(
"{}_error"
.
format
(
self
.
prefix
),
self
.
err_stat
.
accuracy
)
tensorpack/callbacks/validation_callback.py
deleted
100644 → 0
View file @
76fe1b6b
# -*- coding: UTF-8 -*-
# File: validation_callback.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
tensorflow
as
tf
from
tqdm
import
tqdm
from
abc
import
ABCMeta
,
abstractmethod
from
six.moves
import
zip
from
..utils
import
*
from
..utils.stat
import
*
from
..tfutils
import
*
from
..tfutils.summary
import
*
from
.base
import
Callback
,
TestCallbackType
__all__
=
[
'ClassificationError'
,
'ValidationCallback'
,
'ValidationStatPrinter'
]
class
ValidationCallback
(
Callback
):
"""
Base class for validation callbacks.
"""
type
=
TestCallbackType
()
def
__init__
(
self
,
ds
,
prefix
):
"""
:param ds: validation dataset. must be a `DataFlow` instance.
:param prefix: name to use for this validation.
"""
self
.
ds
=
ds
self
.
prefix
=
prefix
def
_before_train
(
self
):
self
.
input_vars
=
self
.
trainer
.
model
.
reuse_input_vars
()
self
.
_find_output_vars
()
def
get_tensor
(
self
,
name
):
"""
Get tensor from graph.
"""
return
self
.
graph
.
get_tensor_by_name
(
name
)
@
abstractmethod
def
_find_output_vars
(
self
):
""" prepare output variables. Will be called in before_train"""
@
abstractmethod
def
_get_output_vars
(
self
):
""" return a list of output vars to eval"""
def
_run_validation
(
self
):
"""
Eval the vars, generate inputs and outputs
"""
output_vars
=
self
.
_get_output_vars
()
sess
=
tf
.
get_default_session
()
with
tqdm
(
total
=
self
.
ds
.
size
(),
ascii
=
True
)
as
pbar
:
for
dp
in
self
.
ds
.
get_data
():
feed
=
dict
(
zip
(
self
.
input_vars
,
dp
))
batch_size
=
dp
[
0
]
.
shape
[
0
]
# assume batched input
outputs
=
sess
.
run
(
output_vars
,
feed_dict
=
feed
)
yield
(
dp
,
outputs
)
pbar
.
update
()
class
ValidationStatPrinter
(
ValidationCallback
):
"""
Write stat and summary of some Op for a validation dataset.
The result of the given Op must be a scalar, and will be averaged for all batches in the validaion set.
"""
def
__init__
(
self
,
ds
,
names_to_print
,
prefix
=
'validation'
):
"""
:param ds: validation dataset. must be a `DataFlow` instance.
:param names_to_print: names of variables to print
:param prefix: name to use for this validation.
"""
super
(
ValidationStatPrinter
,
self
)
.
__init__
(
ds
,
prefix
)
self
.
names
=
names_to_print
def
_find_output_vars
(
self
):
self
.
vars_to_print
=
[
self
.
get_tensor
(
get_op_var_name
(
n
)[
1
])
for
n
in
self
.
names
]
def
_get_output_vars
(
self
):
return
self
.
vars_to_print
def
_trigger_epoch
(
self
):
stats
=
[]
for
dp
,
outputs
in
self
.
_run_validation
():
stats
.
append
(
outputs
)
stats
=
np
.
mean
(
stats
,
axis
=
0
)
assert
len
(
stats
)
==
len
(
self
.
vars_to_print
)
for
stat
,
var
in
zip
(
stats
,
self
.
vars_to_print
):
name
=
var
.
name
.
replace
(
':0'
,
''
)
self
.
trainer
.
summary_writer
.
add_summary
(
create_summary
(
'{}_{}'
.
format
(
self
.
prefix
,
name
),
stat
),
self
.
global_step
)
self
.
trainer
.
stat_holder
.
add_stat
(
"{}_{}"
.
format
(
self
.
prefix
,
name
),
stat
)
class
ClassificationError
(
ValidationCallback
):
"""
Validate the accuracy from a `wrong` variable
The `wrong` variable is supposed to be an integer equal to the number of failed samples in this batch
This callback produce the "true" error,
taking account of the fact that batches might not have the same size in
testing (because the size of test set might not be a multiple of batch size).
In theory, the result could be different from what produced by ValidationStatPrinter.
"""
def
__init__
(
self
,
ds
,
prefix
=
'validation'
,
wrong_var_name
=
'wrong:0'
):
"""
:param ds: a batched `DataFlow` instance
:param wrong_var_name: name of the `wrong` variable
"""
super
(
ClassificationError
,
self
)
.
__init__
(
ds
,
prefix
)
self
.
wrong_var_name
=
wrong_var_name
def
_find_output_vars
(
self
):
self
.
wrong_var
=
self
.
get_tensor
(
self
.
wrong_var_name
)
def
_get_output_vars
(
self
):
return
[
self
.
wrong_var
]
def
_trigger_epoch
(
self
):
err_stat
=
Accuracy
()
for
dp
,
outputs
in
self
.
_run_validation
():
batch_size
=
dp
[
0
]
.
shape
[
0
]
# assume batched input
wrong
=
outputs
[
0
]
err_stat
.
feed
(
wrong
,
batch_size
)
self
.
trainer
.
summary_writer
.
add_summary
(
create_summary
(
'{}_error'
.
format
(
self
.
prefix
),
err_stat
.
accuracy
),
self
.
global_step
)
self
.
trainer
.
stat_holder
.
add_stat
(
"{}_error"
.
format
(
self
.
prefix
),
err_stat
.
accuracy
)
tensorpack/tfutils/common.py
View file @
174c3fc9
...
@@ -36,7 +36,6 @@ def get_global_step():
...
@@ -36,7 +36,6 @@ def get_global_step():
tf
.
get_default_session
(),
tf
.
get_default_session
(),
get_global_step_var
())
get_global_step_var
())
def
get_op_var_name
(
name
):
def
get_op_var_name
(
name
):
"""
"""
Variable name is assumed to be ``op_name + ':0'``
Variable name is assumed to be ``op_name + ':0'``
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment