Shashank Suhas / seminar-breakout / Commits

Commit 93020942, authored Dec 28, 2015 by ppwwyyxx
    queue works with validation! double graph!

Parent: 86370a76
Showing 12 changed files with 292 additions and 89 deletions (+292, -89)
  dataflow/base.py               +1    -0
  dataflow/batch.py              +19   -1
  example_mnist.py               +19   -13
  models/_common.py              +24   -18
  train.py                       +19   -18
  utils/__init__.py              +29   -4
  utils/callback.py              +108  -32
  utils/concurrency.py           +61   -0
  utils/logger.py                +8    -0
  utils/naming.py                +1    -0
  utils/summary.py               +1    -0
  utils/validation_callback.py   +2    -3
dataflow/base.py (view file @ 93020942)

@@ -8,6 +8,7 @@ from abc import abstractmethod
 __all__ = ['DataFlow']

 class DataFlow(object):
+    # TODO private impl
     @abstractmethod
     def get_data(self):
         """
dataflow/batch.py (view file @ 93020942)

@@ -6,7 +6,7 @@
 import numpy as np
 from .base import DataFlow

-__all__ = ['BatchData']
+__all__ = ['BatchData', 'FixedSizeData']

 class BatchData(DataFlow):
     def __init__(self, ds, batch_size, remainder=False):

@@ -46,3 +46,21 @@ class BatchData(DataFlow):
                 np.array([x[k] for x in data_holder],
                          dtype=data_holder[0][k].dtype))
         return tuple(result)
+
+class FixedSizeData(DataFlow):
+    def __init__(self, ds, size):
+        self.ds = ds
+        self._size = size
+
+    def size(self):
+        return self._size
+
+    def get_data(self):
+        cnt = 0
+        while True:
+            for dp in self.ds.get_data():
+                cnt += 1
+                yield dp
+                if cnt == self._size:
+                    return
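Aside: the FixedSizeData addition above is a small generator trick: it restarts the wrapped flow as often as needed and cuts the stream off after exactly `size` datapoints. A minimal self-contained sketch of the same pattern (toy stand-ins, not the repo's imports):

    class ToyDataFlow(object):
        """Stands in for a DataFlow with a short, finite epoch."""
        def get_data(self):
            for i in range(3):
                yield (i,)

    class FixedSizeData(object):
        """Cycles the wrapped flow, stopping after exactly `size` points."""
        def __init__(self, ds, size):
            self.ds = ds
            self._size = size

        def size(self):
            return self._size

        def get_data(self):
            cnt = 0
            while True:
                for dp in self.ds.get_data():
                    cnt += 1
                    yield dp
                    if cnt == self._size:
                        return

    print(list(FixedSizeData(ToyDataFlow(), 7).get_data()))
    # -> [(0,), (1,), (2,), (0,), (1,), (2,), (0,)]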
example_mnist.py (view file @ 93020942)

@@ -18,11 +18,14 @@ from models import *
 from utils import *
 from utils.symbolic_functions import *
 from utils.summary import *
 from utils.callback import *
+from utils.validation_callback import *
+from utils.concurrency import *
 from dataflow.dataset import Mnist
 from dataflow import *

 def get_model(inputs):
+    # TODO is_training as a python variable
     """
     Args:
         inputs: a list of input variable,

@@ -41,12 +44,12 @@ def get_model(inputs):
     image, label = inputs
     image = tf.expand_dims(image, 3)    # add a single channel

-    #conv0 = Conv2D('conv0', image, out_channel=32, kernel_shape=5)
-    #pool0 = MaxPooling('pool0', conv0, 2)
-    #conv1 = Conv2D('conv1', pool0, out_channel=40, kernel_shape=3)
-    #pool1 = MaxPooling('pool1', conv1, 2)
+    conv0 = Conv2D('conv0', image, out_channel=32, kernel_shape=5)
+    pool0 = MaxPooling('pool0', conv0, 2)
+    conv1 = Conv2D('conv1', pool0, out_channel=40, kernel_shape=3)
+    pool1 = MaxPooling('pool1', conv1, 2)

-    fc0 = FullyConnected('fc0', image, 1024)
+    fc0 = FullyConnected('fc0', pool1, 1024)
     fc0 = tf.nn.dropout(fc0, keep_prob)
     # fc will have activation summary by default. disable this for the output layer

@@ -74,15 +77,18 @@ def get_model(inputs):
                       name='regularize_loss')
     tf.add_to_collection(COST_VARS_KEY, wd_cost)

-    return [prob, nr_wrong], tf.add_n(tf.get_collection(COST_VARS_KEY), name='cost')
+    # this won't work with multigpu
+    #return [prob, nr_wrong], tf.add_n(tf.get_collection(COST_VARS_KEY), name='cost')
+    return [prob, nr_wrong], tf.add_n([wd_cost, cost], name='cost')

 def get_config():
     IMAGE_SIZE = 28
-    LOG_DIR = os.path.join('train_log', os.path.basename(__file__)[:-3])
+    log_dir = os.path.join('train_log', os.path.basename(__file__)[:-3])
+    logger.set_logger_dir(log_dir)
     BATCH_SIZE = 128
-    logger.set_file(os.path.join(LOG_DIR, 'training.log'))

     dataset_train = BatchData(Mnist('train'), BATCH_SIZE)
+    #dataset_train = FixedSizeData(dataset_train, 20)
     dataset_test = BatchData(Mnist('test'), 256, remainder=True)

     sess_config = tf.ConfigProto()

@@ -111,11 +117,11 @@ def get_config():
     return dict(
         dataset_train=dataset_train,
         optimizer=tf.train.AdamOptimizer(lr),
-        callbacks=[
-            SummaryWriter(LOG_DIR),
-            #ValidationError(dataset_test, prefix='test'),
-            PeriodicSaver(LOG_DIR),
-        ],
+        callback=Callbacks([
+            SummaryWriter(),
+            PeriodicSaver(),
+            ValidationError(dataset_test, prefix='test'),
+        ]),
         session_config=sess_config,
         inputs=input_vars,
         input_queue=input_queue,
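Note the shape of the config after this commit: one Callbacks object under the key 'callback' (train.py below reads config['callback']), and no more log_dir constructor arguments. A runnable sketch with stub classes, for illustration only:

    class Callback(object):
        pass

    class SummaryWriter(Callback):
        pass   # the real one now reads logger.LOG_DIR instead of taking log_dir

    class PeriodicSaver(Callback):
        pass

    class Callbacks(Callback):
        def __init__(self, cbs):
            self.cbs = cbs

    config = dict(
        callback=Callbacks([SummaryWriter(), PeriodicSaver()]),
        # dataset_train=..., optimizer=..., session_config=...,
        # inputs=..., input_queue=... as in the hunk above
    )
    print(type(config['callback']).__name__)   # Callbacks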
models/_common.py (view file @ 93020942)

@@ -7,6 +7,9 @@ import tensorflow as tf
 from utils.summary import *
 from utils import logger

+# make sure each layer is only logged once
+_layer_logged = set()
+
 def layer_register(summary_activation=False):
     """
     summary_activation: default behavior of whether to summary the output of this layer

@@ -19,26 +22,29 @@ def layer_register(summary_activation=False):
         do_summary = kwargs.pop('summary_activation', summary_activation)
         inputs = args[0]
-        if isinstance(inputs, list):
-            shape_str = ",".join(map(str(x.get_shape().as_list()), inputs))
-        else:
-            shape_str = str(inputs.get_shape().as_list())
-        logger.info("{} input: {}".format(name, shape_str))
         with tf.variable_scope(name) as scope:
             outputs = func(*args, **kwargs)
-            if isinstance(outputs, list):
-                shape_str = ",".join(map(str(x.get_shape().as_list()), outputs))
-                if do_summary:
-                    for x in outputs:
-                        add_activation_summary(x, scope.name)
-            else:
-                shape_str = str(outputs.get_shape().as_list())
-                if do_summary:
-                    add_activation_summary(outputs, scope.name)
-            logger.info("{} output: {}".format(name, shape_str))
+            if name not in _layer_logged:
+                # log shape info and add activation
+                if isinstance(inputs, list):
+                    shape_str = ",".join(map(str(x.get_shape().as_list()), inputs))
+                else:
+                    shape_str = str(inputs.get_shape().as_list())
+                logger.info("{} input: {}".format(name, shape_str))
+                if isinstance(outputs, list):
+                    shape_str = ",".join(map(str(x.get_shape().as_list()), outputs))
+                    if do_summary:
+                        for x in outputs:
+                            add_activation_summary(x, scope.name)
+                else:
+                    shape_str = str(outputs.get_shape().as_list())
+                    if do_summary:
+                        add_activation_summary(outputs, scope.name)
+                logger.info("{} output: {}".format(name, shape_str))
+                _layer_logged.add(name)
             return outputs
         return inner
     return wrapper
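The _layer_logged set is what keeps the new double-graph setup from printing every layer's shapes twice: the decorator logs a layer name only on the first graph that builds it. A stripped-down, runnable sketch of the log-once idea (simplified signature; the real decorator also handles variable scopes and activation summaries):

    _layer_logged = set()

    def layer_register(func):
        def inner(name, *args, **kwargs):
            outputs = func(name, *args, **kwargs)
            if name not in _layer_logged:
                print("{} output: {}".format(name, outputs))  # stands in for logger.info
                _layer_logged.add(name)
            return outputs
        return inner

    @layer_register
    def double(name, x):
        return x * 2

    double('conv0', 3)   # logs "conv0 output: 6"
    double('conv0', 4)   # silent: this layer name was already logged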
train.py (view file @ 93020942)

@@ -6,12 +6,14 @@
 import tensorflow as tf

 from utils import *
+from utils.concurrency import *
+from utils.callback import *
 from utils.summary import *
 from dataflow import DataFlow
 from itertools import count
 import argparse

 def prepare():
-    is_training = tf.placeholder(tf.bool, shape=(), name=IS_TRAINING_OP_NAME)
+    is_training = tf.constant(True, name=IS_TRAINING_OP_NAME)
     #keep_prob = tf.placeholder(
         #tf.float32, shape=tuple(), name=DROPOUT_PROB_OP_NAME)
     global_step_var = tf.Variable(

@@ -32,7 +34,7 @@ def start_train(config):
     assert isinstance(optimizer, tf.train.Optimizer), optimizer.__class__
     # a list of Callback instance
-    callbacks = Callbacks(config.get('callbacks', []))
+    callbacks = config['callback']
     # a tf.ConfigProto instance
     sess_config = config.get('session_config', None)

@@ -53,6 +55,7 @@ def start_train(config):
     # build graph
     G = tf.get_default_graph()
+    G.add_to_collection(FORWARD_FUNC_KEY, get_model_func)
     for v in input_vars:
         G.add_to_collection(INPUT_VARS_KEY, v)
     for v in output_vars:

@@ -74,30 +77,28 @@ def start_train(config):
     train_op = optimizer.apply_gradients(grads, global_step_var)

     sess = tf.Session(config=sess_config)
-    # start training
-    with sess.as_default():
-        sess.run(tf.initialize_all_variables())
+    sess.run(tf.initialize_all_variables())

-        coord = tf.train.Coordinator()
-        # a thread that keeps filling the queue
-        th = EnqueueThread(sess, coord, enqueue_op, dataset_train)
-        with sess.as_default(), \
-                coordinator_context(sess, coord, th, input_queue):
-            callbacks.before_train()
-            for epoch in xrange(1, max_epoch):
-                with timed_operation('epoch {}'.format(epoch)):
+    # start training:
+    with sess.as_default():
+        callbacks.before_train()
+        for epoch in xrange(1, max_epoch):
+            coord = tf.train.Coordinator()
+            # a thread that keeps filling the queue
+            th = EnqueueThread(sess, coord, enqueue_op, dataset_train)
+            with timed_operation('epoch {}'.format(epoch)), \
+                    coordinator_context(sess, coord, th, input_queue):
                 for step in xrange(dataset_train.size()):
+                    # TODO eval dequeue to get dp
                     fetches = [train_op, cost_var] + output_vars
-                    results = sess.run(fetches, feed_dict={IS_TRAINING_VAR_NAME: True})
+                    feed = {IS_TRAINING_VAR_NAME: True}
+                    results = sess.run(fetches, feed_dict=feed)
                     cost = results[1]
                     outputs = results[2:]
-                    print tf.train.global_step(sess, global_step_var), cost
-                    # trigger_step
-                coord.request_stop()
-                # summary will take a data from the queue
+                    # TODO trigger_step
+                # note that summary_op will take a data from the queue.
                 callbacks.trigger_epoch()
-        print "Finish callback"
     sess.close()
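The restructured loop above creates a fresh Coordinator and EnqueueThread per epoch and scopes them with coordinator_context, so the producer thread is always stopped and joined when the epoch block exits. A TF-free, runnable sketch of that per-epoch producer pattern, using a threading.Event in place of tf.train.Coordinator:

    import threading
    from contextlib import contextmanager
    try:
        from queue import Queue        # Python 3
    except ImportError:
        from Queue import Queue        # Python 2, which this repo targets

    class Producer(threading.Thread):
        """Keeps filling the queue, like EnqueueThread."""
        def __init__(self, stop_event, q, data):
            super(Producer, self).__init__()
            self.stop_event, self.q, self.data = stop_event, q, data

        def run(self):
            for dp in self.data:
                if self.stop_event.is_set():
                    return
                self.q.put(dp)

    @contextmanager
    def producer_context(stop_event, thread):
        thread.start()
        try:
            yield
        finally:
            stop_event.set()   # mirrors coord.request_stop()
            thread.join()      # mirrors coord.join([thread])

    for epoch in range(1, 3):
        stop, q = threading.Event(), Queue()
        th = Producer(stop, q, range(5))
        with producer_context(stop, th):
            for step in range(5):
                print(epoch, q.get())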
utils/__init__.py (view file @ 93020942)

@@ -16,11 +16,7 @@ def global_import(name):
     lst = p.__all__ if '__all__' in dir(p) else dir(p)
     for k in lst:
         globals()[k] = p.__dict__[k]

 global_import('naming')
-global_import('callback')
-global_import('validation_callback')

 @contextmanager
 def timed_operation(msg, log_start=False):

@@ -44,3 +40,32 @@ def describe_model():
     msg.append("Total dim={}".format(total))
     logger.info("Model Params: {}".format('\n'.join(msg)))

+# TODO disable shape output in get_model
+@contextmanager
+def create_test_graph():
+    G = tf.get_default_graph()
+    input_vars_train = G.get_collection(INPUT_VARS_KEY)
+    forward_func = G.get_collection(FORWARD_FUNC_KEY)[0]
+    with tf.Graph().as_default() as Gtest:
+        input_vars = []
+        for v in input_vars_train:
+            name = v.name
+            assert name.endswith(':0'), \
+                "I think placeholder variable should all ends with ':0'"
+            name = name[:-2]
+            input_vars.append(tf.placeholder(
+                v.dtype, shape=v.get_shape(), name=name))
+        for v in input_vars:
+            Gtest.add_to_collection(INPUT_VARS_KEY, v)
+        is_training = tf.constant(False, name=IS_TRAINING_OP_NAME)
+        output_vars, cost = forward_func(input_vars)
+        for v in output_vars:
+            Gtest.add_to_collection(OUTPUT_VARS_KEY, v)
+        yield Gtest
+
+@contextmanager
+def create_test_session():
+    with create_test_graph():
+        with tf.Session() as sess:
+            yield sess
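This is the "double graph" of the commit message: the training graph stashes its input placeholders and forward function in collections, and create_test_graph replays the forward function in a fresh graph with is_training hard-coded to False, cloning each placeholder under its original op name (hence the ':0' suffix stripping). A toy, TF-free sketch of the replay idea, with dicts standing in for tf.Graph collections:

    from contextlib import contextmanager

    # stands in for the INPUT_VARS_KEY / FORWARD_FUNC_KEY collections
    train_graph = {'input_names': ['image:0', 'label:0']}

    def forward(input_names, is_training):
        # a real forward_func would rebuild the model ops here
        return ['prob:0', 'nr_wrong:0'], 'cost:0'

    train_graph['forward'] = forward

    @contextmanager
    def create_test_graph():
        test_graph = {}
        # strip the ':0' tensor suffix so the cloned placeholders
        # get the same op names as in the training graph
        names = [n[:-2] for n in train_graph['input_names']]
        test_graph['input_names'] = names
        test_graph['outputs'], _ = train_graph['forward'](names, is_training=False)
        yield test_graph

    with create_test_graph() as g:
        print(g['input_names'], g['outputs'])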
utils/callback.py (view file @ 93020942)

@@ -10,10 +10,16 @@ import os
 import time
 from abc import abstractmethod

+from . import create_test_session
 from .naming import *
 import logger

 class Callback(object):
+    running_graph = 'train'
+    """ The graph that this callback should run on.
+        Either 'train' or 'test'
+    """
     def before_train(self):
         self.graph = tf.get_default_graph()
         self.sess = tf.get_default_session()

@@ -53,20 +59,20 @@ class PeriodicCallback(Callback):
         pass

 class PeriodicSaver(PeriodicCallback):
-    def __init__(self, log_dir, period=1):
+    def __init__(self, period=1):
         super(PeriodicSaver, self).__init__(period)
-        self.path = os.path.join(log_dir, 'model')
+        self.path = os.path.join(logger.LOG_DIR, 'model')

     def _before_train(self):
         self.saver = tf.train.Saver(max_to_keep=99999)

     def _trigger(self):
         self.saver.save(tf.get_default_session(), self.path,
-                        global_step=self.epoch_num, latest_filename='latest')
+                        global_step=self.epoch_num)

 class SummaryWriter(Callback):
-    def __init__(self, log_dir):
-        self.log_dir = log_dir
+    def __init__(self):
+        self.log_dir = logger.LOG_DIR
         self.epoch_num = 0

     def _before_train(self):

@@ -80,52 +86,122 @@ class SummaryWriter(Callback):
         if self.summary_op is None:
             return
-        summary_str = self.summary_op.eval(feed_dict={IS_TRAINING_VAR_NAME: True})
+        feed = {IS_TRAINING_VAR_NAME: True}
+        summary_str = self.summary_op.eval(feed_dict=feed)
         self.epoch_num += 1
         self.writer.add_summary(summary_str, self.epoch_num)

-class Callbacks(Callback):
-    def __init__(self, callbacks):
-        for cb in callbacks:
-            assert isinstance(cb, Callback), cb.__class__
+class CallbackTimeLogger(object):
+    def __init__(self):
+        self.times = []
+        self.tot = 0
+
+    def add(self, name, time):
+        self.tot += time
+        self.times.append((name, time))
+
+    def log(self):
+        """ log the time of some heavy callbacks """
+        if self.tot < 3:
+            return
+        msgs = []
+        for name, t in self.times:
+            if t / self.tot > 0.3 and t > 1:
+                msgs.append("{}:{}".format(name, t))
+        logger.info("Callbacks took {} sec. {}".format(self.tot, ' '.join(msgs)))
+
+class TrainCallbacks(Callback):
+    def __init__(self, callbacks):
+        self.cbs = callbacks
         # put SummaryWriter to the first
-        for idx, cb in enumerate(callbacks):
+        for idx, cb in enumerate(self.cbs):
             if type(cb) == SummaryWriter:
-                callbacks.insert(0, callbacks.pop(idx))
+                self.cbs.insert(0, self.cbs.pop(idx))
                 break
         else:
-            raise RuntimeError("callbacks must contain a SummaryWriter!")
-        self.callbacks = callbacks
+            raise RuntimeError("Callbacks must contain a SummaryWriter!")

     def before_train(self):
-        for cb in self.callbacks:
+        for cb in self.cbs:
             cb.before_train()
         self.writer = tf.get_collection(SUMMARY_WRITER_COLLECTION_KEY)[0]

     def trigger_step(self, inputs, outputs, cost):
-        for cb in self.callbacks:
+        for cb in self.cbs:
             cb.trigger_step(inputs, outputs, cost)

     def trigger_epoch(self):
-        start = time.time()
-        times = []
-        for cb in self.callbacks:
+        tm = CallbackTimeLogger()
+        for cb in self.cbs:
             s = time.time()
             cb.trigger_epoch()
-            times.append(time.time() - s)
+            tm.add(type(cb).__name__, time.time() - s)
         self.writer.flush()
-        tot = time.time() - start
-        # log the time of some heavy callbacks
-        if tot < 3:
-            return
-        msgs = []
-        for idx, t in enumerate(times):
-            if t / tot > 0.3 and t > 1:
-                msgs.append("{}:{}".format(type(self.callbacks[idx]).__name__, t))
-        logger.info("Callbacks took {} sec. {}".format(tot, ' '.join(msgs)))
+        tm.log()
+
+class TestCallbacks(Callback):
+    def __init__(self, callbacks):
+        self.cbs = callbacks
+
+    def before_train(self):
+        self.writer = tf.get_collection(SUMMARY_WRITER_COLLECTION_KEY)[0]
+        with create_test_session() as sess:
+            self.sess = sess
+            self.graph = sess.graph
+            self.saver = tf.train.Saver()
+            tf.add_to_collection(SUMMARY_WRITER_COLLECTION_KEY, self.writer)
+            for cb in self.cbs:
+                cb.before_train()
+
+    def trigger_epoch(self):
+        tm = CallbackTimeLogger()
+        with self.graph.as_default():
+            with self.sess.as_default():
+                s = time.time()
+                ckpt = tf.train.get_checkpoint_state(logger.LOG_DIR)
+                if ckpt is None:
+                    from IPython import embed; embed()
+                self.saver.restore(self.sess, ckpt.model_checkpoint_path)
+                tm.add('restore session', time.time() - s)
+                for cb in self.cbs:
+                    s = time.time()
+                    cb.trigger_epoch()
+                    tm.add(type(cb).__name__, time.time() - s)
+        self.writer.flush()
+        tm.log()
+
+class Callbacks(Callback):
+    def __init__(self, cbs):
+        train_cbs = []
+        test_cbs = []
+        for cb in cbs:
+            assert isinstance(cb, Callback), cb.__class__
+            if cb.running_graph == 'test':
+                test_cbs.append(cb)
+            elif cb.running_graph == 'train':
+                train_cbs.append(cb)
+            else:
+                raise RuntimeError(
+                    "Unknown callback running graph {}!".format(cb.running_graph))
+        self.train = TrainCallbacks(train_cbs)
+        self.test = TestCallbacks(test_cbs)
+
+    def before_train(self):
+        self.train.before_train()
+        self.test.before_train()
+
+    def trigger_step(self, inputs, outputs, cost):
+        self.train.trigger_step()
+        # test callback don't have trigger_step
+
+    def trigger_epoch(self):
+        self.train.trigger_epoch()
+        self.test.trigger_epoch()
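The net effect of the rewrite above: every Callback now declares a running_graph class attribute, and the public Callbacks container just partitions its children into TrainCallbacks and TestCallbacks. A minimal runnable sketch of that dispatch:

    class Callback(object):
        running_graph = 'train'

    class SummaryWriter(Callback):
        pass

    class ValidationError(Callback):
        running_graph = 'test'

    def split_callbacks(cbs):
        train_cbs, test_cbs = [], []
        for cb in cbs:
            assert isinstance(cb, Callback), cb.__class__
            if cb.running_graph == 'test':
                test_cbs.append(cb)
            elif cb.running_graph == 'train':
                train_cbs.append(cb)
            else:
                raise RuntimeError(
                    "Unknown callback running graph {}!".format(cb.running_graph))
        return train_cbs, test_cbs

    train, test = split_callbacks([SummaryWriter(), ValidationError()])
    print([type(c).__name__ for c in train])   # ['SummaryWriter']
    print([type(c).__name__ for c in test])    # ['ValidationError']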
utils/concurrency.py (new file @ 93020942, 0 → 100644)

#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: concurrency.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

import threading
from contextlib import contextmanager
import tensorflow as tf

from .naming import *
import logger

class StoppableThread(threading.Thread):
    def __init__(self):
        super(StoppableThread, self).__init__()
        self._stop = threading.Event()

    def stop(self):
        self._stop.set()

    def stopped(self):
        return self._stop.isSet()

class EnqueueThread(threading.Thread):
    def __init__(self, sess, coord, enqueue_op, dataflow):
        super(EnqueueThread, self).__init__()
        self.sess = sess
        self.coord = coord
        self.input_vars = sess.graph.get_collection(INPUT_VARS_KEY)
        self.dataflow = dataflow
        self.op = enqueue_op

    def run(self):
        try:
            while True:
                for dp in self.dataflow.get_data():
                    if self.coord.should_stop():
                        return
                    feed = dict(zip(self.input_vars, dp))
                    self.sess.run([self.op], feed_dict=feed)
        except tf.errors.CancelledError as e:
            pass
        except Exception:
            logger.exception("Exception in EnqueueThread:")

@contextmanager
def coordinator_context(sess, coord, thread, queue):
    """ Context manager to make sure queue is closed and thread is joined """
    thread.start()
    try:
        yield
    except (KeyboardInterrupt, Exception) as e:
        raise
    finally:
        coord.request_stop()
        sess.run(queue.close(cancel_pending_enqueues=True))
        coord.join([thread])
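For context, this is how the train.py hunk earlier in this commit wires the module together each epoch (condensed excerpt from the diff above, not standalone code):

    coord = tf.train.Coordinator()
    # a thread that keeps filling the queue
    th = EnqueueThread(sess, coord, enqueue_op, dataset_train)
    with timed_operation('epoch {}'.format(epoch)), \
            coordinator_context(sess, coord, th, input_queue):
        for step in xrange(dataset_train.size()):
            ...

On exit, coordinator_context requests a stop, closes the queue with cancel_pending_enqueues=True (which unblocks an EnqueueThread stuck on a full queue, surfacing as the CancelledError handled above), and joins the thread, so a failed epoch cannot leak a blocked producer.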
utils/logger.py (view file @ 93020942)

@@ -50,3 +50,11 @@ def set_file(path):
     hdl = logging.FileHandler(
         filename=path, encoding='utf-8', mode='w')
     logger.addHandler(hdl)
+
+global LOG_DIR
+LOG_DIR = "train_log"
+
+def set_logger_dir(dirname):
+    global LOG_DIR
+    LOG_DIR = dirname
+    set_file(os.path.join(LOG_DIR, 'training.log'))
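set_logger_dir becomes the single source of truth for the log directory: it sets the module-global LOG_DIR and opens training.log there, and PeriodicSaver/SummaryWriter now read logger.LOG_DIR instead of taking log_dir arguments. A toy version of that contract, runnable on its own:

    import os

    LOG_DIR = "train_log"                # module-level default, as in the commit

    def set_file(path):
        print("logging to " + path)      # the real code attaches a logging.FileHandler

    def set_logger_dir(dirname):
        global LOG_DIR
        LOG_DIR = dirname
        set_file(os.path.join(LOG_DIR, 'training.log'))

    set_logger_dir(os.path.join('train_log', 'example_mnist'))
    print(LOG_DIR)                       # consumers like PeriodicSaver read this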
utils/naming.py (view file @ 93020942)

@@ -15,6 +15,7 @@ INPUT_VARS_KEY = 'INPUT_VARIABLES'
 OUTPUT_VARS_KEY = 'OUTPUT_VARIABLES'
 COST_VARS_KEY = 'COST_VARIABLES'         # keep track of each individual cost
 SUMMARY_VARS_KEY = 'SUMMARY_VARIABLES'   # extra variables to summarize during training
+FORWARD_FUNC_KEY = 'FORWARD_FUNCTION'

 # export all upper case variables
 all_local_names = locals().keys()
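The "export all upper case variables" tail of this module is a small reflection trick: it collects the module's upper-case names so that `from utils.naming import *` only pulls in the constants. A self-contained sketch of the same idiom (the exact filter in the repo may differ):

    INPUT_VARS_KEY = 'INPUT_VARIABLES'
    FORWARD_FUNC_KEY = 'FORWARD_FUNCTION'   # the constant this commit adds
    _helper = 42                            # not exported

    all_local_names = list(locals().keys())
    __all__ = [name for name in all_local_names if name.isupper()]
    print(sorted(__all__))   # ['FORWARD_FUNC_KEY', 'INPUT_VARS_KEY']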
utils/summary.py (view file @ 93020942)

@@ -19,6 +19,7 @@ def create_summary(name, v):
     return s

 def add_activation_summary(x, name=None):
+    # TODO dedup
     """
     Summary for an activation tensor x.
     If name is None, use x.name
utils/validation_callback.py (view file @ 93020942)

@@ -11,6 +11,7 @@ from .summary import *
 import logger

 class ValidationError(PeriodicCallback):
+    running_graph = 'test'
     """
     Validate the accuracy for the given wrong and cost variable
     Use under the following setup:

@@ -33,7 +34,6 @@ class ValidationError(PeriodicCallback):
     def _before_train(self):
         self.input_vars = tf.get_collection(INPUT_VARS_KEY)
-        self.is_training_var = self.get_tensor(IS_TRAINING_VAR_NAME)
         self.wrong_var = self.get_tensor(self.wrong_var_name)
         self.cost_var = self.get_tensor(self.cost_var_name)
         self.writer = tf.get_collection(SUMMARY_WRITER_COLLECTION_KEY)[0]

@@ -43,8 +43,7 @@ class ValidationError(PeriodicCallback):
         err_stat = Accuracy()
         cost_sum = 0
         for dp in self.ds.get_data():
-            feed = {self.is_training_var: False}
-            feed.update(dict(zip(self.input_vars, dp)))
+            feed = dict(zip(self.input_vars, dp))
             batch_size = dp[0].shape[0]   # assume batched input
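The removed is_training feed is a direct consequence of the double-graph change: ValidationError now runs in the test graph, where is_training is a constant False (see create_test_graph), so the feed dict only needs the data inputs. Schematically, with placeholder names as toy stand-ins:

    input_vars = ['image:0', 'label:0']      # placeholders, per INPUT_VARS_KEY
    dp = ([[0.1, 0.2]], [1])                 # one batched datapoint

    # before: feed = {is_training_var: False}; feed.update(dict(zip(input_vars, dp)))
    feed = dict(zip(input_vars, dp))         # after: data inputs only
    print(feed)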