Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
32feff4e
Commit
32feff4e
authored
Apr 16, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix periodic bug
parent
9a4e6d9d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
8 additions
and
19 deletions
+8
-19
examples/mnist_convnet.py
examples/mnist_convnet.py
+0
-3
tensorpack/callbacks/validation_callback.py
tensorpack/callbacks/validation_callback.py
+8
-16
No files found.
examples/mnist_convnet.py
View file @
32feff4e
...
@@ -4,8 +4,6 @@
...
@@ -4,8 +4,6 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
tensorflow
as
tf
import
tensorflow
as
tf
from
tensorflow.python.ops
import
control_flow_ops
import
numpy
as
np
import
numpy
as
np
import
os
,
sys
import
os
,
sys
import
argparse
import
argparse
...
@@ -18,7 +16,6 @@ from tensorpack.tfutils.summary import *
...
@@ -18,7 +16,6 @@ from tensorpack.tfutils.summary import *
from
tensorpack.tfutils
import
*
from
tensorpack.tfutils
import
*
from
tensorpack.callbacks
import
*
from
tensorpack.callbacks
import
*
from
tensorpack.dataflow
import
*
from
tensorpack.dataflow
import
*
from
IPython
import
embed
;
embed
()
"""
"""
MNIST ConvNet example.
MNIST ConvNet example.
...
...
tensorpack/callbacks/validation_callback.py
View file @
32feff4e
...
@@ -11,23 +11,21 @@ from ..utils import *
...
@@ -11,23 +11,21 @@ from ..utils import *
from
..utils.stat
import
*
from
..utils.stat
import
*
from
..tfutils
import
*
from
..tfutils
import
*
from
..tfutils.summary
import
*
from
..tfutils.summary
import
*
from
.base
import
PeriodicCallback
,
Callback
,
TestCallbackType
from
.base
import
Callback
,
TestCallbackType
__all__
=
[
'ClassificationError'
,
'ValidationCallback'
,
'ValidationStatPrinter'
]
__all__
=
[
'ClassificationError'
,
'ValidationCallback'
,
'ValidationStatPrinter'
]
class
ValidationCallback
(
Periodic
Callback
):
class
ValidationCallback
(
Callback
):
"""
"""
Base class for validation callbacks.
Base class for validation callbacks.
"""
"""
type
=
TestCallbackType
()
type
=
TestCallbackType
()
def
__init__
(
self
,
ds
,
prefix
,
period
=
1
):
def
__init__
(
self
,
ds
,
prefix
):
"""
"""
:param ds: validation dataset. must be a `DataFlow` instance.
:param ds: validation dataset. must be a `DataFlow` instance.
:param prefix: name to use for this validation.
:param prefix: name to use for this validation.
:param period: period to perform validation.
"""
"""
super
(
ValidationCallback
,
self
)
.
__init__
(
period
)
self
.
ds
=
ds
self
.
ds
=
ds
self
.
prefix
=
prefix
self
.
prefix
=
prefix
...
@@ -63,23 +61,18 @@ class ValidationCallback(PeriodicCallback):
...
@@ -63,23 +61,18 @@ class ValidationCallback(PeriodicCallback):
yield
(
dp
,
outputs
)
yield
(
dp
,
outputs
)
pbar
.
update
()
pbar
.
update
()
@
abstractmethod
def
_trigger_periodic
(
self
):
""" Implement the actual callback"""
class
ValidationStatPrinter
(
ValidationCallback
):
class
ValidationStatPrinter
(
ValidationCallback
):
"""
"""
Write stat and summary of some Op for a validation dataset.
Write stat and summary of some Op for a validation dataset.
The result of the given Op must be a scalar, and will be averaged for all batches in the validaion set.
The result of the given Op must be a scalar, and will be averaged for all batches in the validaion set.
"""
"""
def
__init__
(
self
,
ds
,
names_to_print
,
prefix
=
'validation'
,
period
=
1
):
def
__init__
(
self
,
ds
,
names_to_print
,
prefix
=
'validation'
):
"""
"""
:param ds: validation dataset. must be a `DataFlow` instance.
:param ds: validation dataset. must be a `DataFlow` instance.
:param names_to_print: names of variables to print
:param names_to_print: names of variables to print
:param prefix: name to use for this validation.
:param prefix: name to use for this validation.
:param period: period to perform validation.
"""
"""
super
(
ValidationStatPrinter
,
self
)
.
__init__
(
ds
,
prefix
,
period
)
super
(
ValidationStatPrinter
,
self
)
.
__init__
(
ds
,
prefix
)
self
.
names
=
names_to_print
self
.
names
=
names_to_print
def
_find_output_vars
(
self
):
def
_find_output_vars
(
self
):
...
@@ -89,7 +82,7 @@ class ValidationStatPrinter(ValidationCallback):
...
@@ -89,7 +82,7 @@ class ValidationStatPrinter(ValidationCallback):
def
_get_output_vars
(
self
):
def
_get_output_vars
(
self
):
return
self
.
vars_to_print
return
self
.
vars_to_print
def
_trigger_
periodic
(
self
):
def
_trigger_
epoch
(
self
):
stats
=
[]
stats
=
[]
for
dp
,
outputs
in
self
.
_run_validation
():
for
dp
,
outputs
in
self
.
_run_validation
():
stats
.
append
(
outputs
)
stats
.
append
(
outputs
)
...
@@ -114,13 +107,12 @@ class ClassificationError(ValidationCallback):
...
@@ -114,13 +107,12 @@ class ClassificationError(ValidationCallback):
In theory, the result could be different from what produced by ValidationStatPrinter.
In theory, the result could be different from what produced by ValidationStatPrinter.
"""
"""
def
__init__
(
self
,
ds
,
prefix
=
'validation'
,
def
__init__
(
self
,
ds
,
prefix
=
'validation'
,
period
=
1
,
wrong_var_name
=
'wrong:0'
):
wrong_var_name
=
'wrong:0'
):
"""
"""
:param ds: a batched `DataFlow` instance
:param ds: a batched `DataFlow` instance
:param wrong_var_name: name of the `wrong` variable
:param wrong_var_name: name of the `wrong` variable
"""
"""
super
(
ClassificationError
,
self
)
.
__init__
(
ds
,
prefix
,
period
)
super
(
ClassificationError
,
self
)
.
__init__
(
ds
,
prefix
)
self
.
wrong_var_name
=
wrong_var_name
self
.
wrong_var_name
=
wrong_var_name
def
_find_output_vars
(
self
):
def
_find_output_vars
(
self
):
...
@@ -129,7 +121,7 @@ class ClassificationError(ValidationCallback):
...
@@ -129,7 +121,7 @@ class ClassificationError(ValidationCallback):
def
_get_output_vars
(
self
):
def
_get_output_vars
(
self
):
return
[
self
.
wrong_var
]
return
[
self
.
wrong_var
]
def
_trigger_
periodic
(
self
):
def
_trigger_
epoch
(
self
):
err_stat
=
Accuracy
()
err_stat
=
Accuracy
()
for
dp
,
outputs
in
self
.
_run_validation
():
for
dp
,
outputs
in
self
.
_run_validation
():
batch_size
=
dp
[
0
]
.
shape
[
0
]
# assume batched input
batch_size
=
dp
[
0
]
.
shape
[
0
]
# assume batched input
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment