Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
646a3c6f
Commit
646a3c6f
authored
Dec 27, 2015
by
ppwwyyxx
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
timing for callback
parent
65a3052f
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
50 additions
and
23 deletions
+50
-23
example_mnist.py
example_mnist.py
+2
-1
train.py
train.py
+12
-11
utils/__init__.py
utils/__init__.py
+17
-5
utils/callback.py
utils/callback.py
+17
-0
utils/stat.py
utils/stat.py
+0
-2
utils/summary.py
utils/summary.py
+1
-3
utils/validation_callback.py
utils/validation_callback.py
+1
-1
No files found.
example_mnist.py
View file @
646a3c6f
...
...
@@ -12,9 +12,10 @@ import tensorflow as tf
import
numpy
as
np
import
os
from
utils
import
logger
from
layers
import
*
from
utils
import
*
from
utils.symbolic_functions
import
*
from
utils.summary
import
*
from
dataflow.dataset
import
Mnist
from
dataflow
import
*
...
...
train.py
View file @
646a3c6f
...
...
@@ -77,6 +77,7 @@ def start_train(config):
keep_prob_var
=
G
.
get_tensor_by_name
(
DROPOUT_PROB_VAR_NAME
)
for
epoch
in
xrange
(
1
,
max_epoch
):
with
timed_operation
(
'epoch {}'
.
format
(
epoch
)):
for
dp
in
dataset_train
.
get_data
():
feed
=
{
keep_prob_var
:
0.5
}
feed
.
update
(
dict
(
zip
(
input_vars
,
dp
)))
...
...
utils/__init__.py
View file @
646a3c6f
...
...
@@ -5,7 +5,10 @@
from
pkgutil
import
walk_packages
import
os
import
os.path
import
time
import
sys
from
contextlib
import
contextmanager
import
logger
def
global_import
(
name
):
p
=
__import__
(
name
,
globals
(),
locals
())
...
...
@@ -13,7 +16,16 @@ def global_import(name):
for
k
in
lst
:
globals
()[
k
]
=
p
.
__dict__
[
k
]
for
_
,
module_name
,
_
in
walk_packages
(
[
os
.
path
.
dirname
(
__file__
)]):
if
not
module_name
.
startswith
(
'_'
):
global_import
(
module_name
)
global_import
(
'naming'
)
global_import
(
'callback'
)
global_import
(
'validation_callback'
)
@
contextmanager
def
timed_operation
(
msg
,
log_start
=
False
):
if
log_start
:
logger
.
info
(
'start {} ...'
.
format
(
msg
))
start
=
time
.
time
()
yield
logger
.
info
(
'finished {}, time={:.2f}sec.'
.
format
(
msg
,
time
.
time
()
-
start
))
utils/callback.py
View file @
646a3c6f
...
...
@@ -7,9 +7,11 @@ import tensorflow as tf
import
sys
import
numpy
as
np
import
os
import
time
from
abc
import
abstractmethod
from
.naming
import
*
import
logger
class
Callback
(
object
):
def
before_train
(
self
):
...
...
@@ -107,7 +109,22 @@ class Callbacks(Callback):
cb
.
trigger_step
(
inputs
,
outputs
,
cost
)
def
trigger_epoch
(
self
):
start
=
time
.
time
()
times
=
[]
for
cb
in
self
.
callbacks
:
s
=
time
.
time
()
cb
.
trigger_epoch
()
times
.
append
(
time
.
time
()
-
s
)
self
.
writer
.
flush
()
tot
=
time
.
time
()
-
start
# log the time of some heavy callbacks
if
tot
<
3
:
return
msgs
=
[]
for
idx
,
t
in
enumerate
(
times
):
if
t
/
tot
>
0.3
and
t
>
1
:
msgs
.
append
(
"{}:{}"
.
format
(
type
(
self
.
callbacks
[
idx
])
.
__name__
,
t
))
logger
.
info
(
"Callbacks took {} sec. {}"
.
format
(
tot
,
' '
.
join
(
msgs
)))
utils/stat.py
View file @
646a3c6f
...
...
@@ -4,8 +4,6 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
numpy
as
np
__all__
=
[
'StatCounter'
,
'Accuracy'
]
class
StatCounter
(
object
):
def
__init__
(
self
):
self
.
reset
()
...
...
utils/summary.py
View file @
646a3c6f
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File:
utils
.py
# File:
summary
.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
tensorflow
as
tf
__all__
=
[
'create_summary'
,
'add_histogram_summary'
,
'add_activation_summary'
]
def
create_summary
(
name
,
v
):
"""
Return a tf.Summary object with name and simple value v
...
...
utils/validation_callback.py
View file @
646a3c6f
...
...
@@ -65,5 +65,5 @@ class ValidationError(PeriodicCallback):
cost_avg
),
self
.
epoch_num
)
logger
.
info
(
"{} validation after epoch {}: err={
}, cost={
}"
.
format
(
"{} validation after epoch {}: err={
:.4f}, cost={:.3f
}"
.
format
(
self
.
prefix
,
self
.
epoch_num
,
err_stat
.
accuracy
,
cost_avg
))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment