Project: seminar-breakout (Shashank Suhas), Commits

Commit dac78238, authored Dec 28, 2015 by ppwwyyxx
Merge branch 'queue-thread'
Parents: a78d02ac, 93020942

Showing 12 changed files with 307 additions and 91 deletions (+307, -91)
Changed files:
dataflow/base.py               +1    -0
dataflow/batch.py              +19   -1
example_mnist.py               +22   -13
models/_common.py              +24   -18
train.py                       +30   -17
utils/__init__.py              +29   -4
utils/callback.py              +108  -34
utils/concurrency.py           +61   -0
utils/logger.py                +9    -1
utils/naming.py                +1    -0
utils/summary.py               +1    -0
utils/validation_callback.py   +2    -3
dataflow/base.py
@@ -8,6 +8,7 @@ from abc import abstractmethod

 __all__ = ['DataFlow']

 class DataFlow(object):
+    # TODO private impl
     @abstractmethod
     def get_data(self):
         """
dataflow/batch.py
@@ -6,7 +6,7 @@
 import numpy as np
 from .base import DataFlow

-__all__ = ['BatchData']
+__all__ = ['BatchData', 'FixedSizeData']

 class BatchData(DataFlow):
     def __init__(self, ds, batch_size, remainder=False):
@@ -46,3 +46,21 @@ class BatchData(DataFlow):
                 np.array([x[k] for x in data_holder],
                          dtype=data_holder[0][k].dtype))
         return tuple(result)
+
+class FixedSizeData(DataFlow):
+    def __init__(self, ds, size):
+        self.ds = ds
+        self._size = size
+
+    def size(self):
+        return self._size
+
+    def get_data(self):
+        cnt = 0
+        while True:
+            for dp in self.ds.get_data():
+                cnt += 1
+                yield dp
+                if cnt == self._size:
+                    return
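
The new FixedSizeData wrapper caps a DataFlow at a fixed number of datapoints per pass through get_data(), cycling over the wrapped source as needed (see the commented-out line in example_mnist.py's get_config()). A rough usage sketch, not part of the commit, assuming the dataflow package from this repository is importable:

    # Sketch: make one "epoch" exactly 20 batches of MNIST, regardless of dataset size.
    from dataflow.dataset import Mnist
    from dataflow.batch import BatchData, FixedSizeData

    ds = BatchData(Mnist('train'), 128)   # batches of 128 examples
    ds = FixedSizeData(ds, 20)            # truncate each pass to 20 datapoints
    assert ds.size() == 20
    for dp in ds.get_data():              # yields at most 20 (images, labels) pairs
        images, labels = dp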
example_mnist.py
@@ -18,10 +18,14 @@ from models import *
 from utils import *
 from utils.symbolic_functions import *
 from utils.summary import *
+from utils.callback import *
+from utils.validation_callback import *
+from utils.concurrency import *
 from dataflow.dataset import Mnist
 from dataflow import *

 def get_model(inputs):
+    # TODO is_training as a python variable
     """
     Args:
         inputs: a list of input variable,
@@ -73,15 +77,18 @@ def get_model(inputs):
                           name='regularize_loss')
     tf.add_to_collection(COST_VARS_KEY, wd_cost)

-    return [prob, nr_wrong], tf.add_n(tf.get_collection(COST_VARS_KEY), name='cost')
+    # this won't work with multigpu
+    #return [prob, nr_wrong], tf.add_n(tf.get_collection(COST_VARS_KEY), name='cost')
+    return [prob, nr_wrong], tf.add_n([wd_cost, cost], name='cost')

 def get_config():
     IMAGE_SIZE = 28
-    LOG_DIR = os.path.join('train_log', os.path.basename(__file__)[:-3])
+    log_dir = os.path.join('train_log', os.path.basename(__file__)[:-3])
+    logger.set_logger_dir(log_dir)
     BATCH_SIZE = 128
-    logger.set_file(os.path.join(LOG_DIR, 'training.log'))

     dataset_train = BatchData(Mnist('train'), BATCH_SIZE)
+    #dataset_train = FixedSizeData(dataset_train, 20)
     dataset_test = BatchData(Mnist('test'), 256, remainder=True)
     sess_config = tf.ConfigProto()
@@ -91,12 +98,14 @@ def get_config():
     sess_config.allow_soft_placement = True

     # prepare model
     image_var = tf.placeholder(tf.float32, shape=(None, IMAGE_SIZE, IMAGE_SIZE), name='input')
     label_var = tf.placeholder(tf.int32, shape=(None,), name='label')
     input_vars = [image_var, label_var]
-    output_vars, cost_var = get_model(input_vars)
-    add_histogram_summary('.*/W')   # monitor histogram of all W
+    input_queue = tf.RandomShuffleQueue(100, 50, ['float32', 'int32'], name='queue')
+    add_histogram_summary('.*/W')   # monitor histogram of all W

     global_step_var = tf.get_default_graph().get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
     lr = tf.train.exponential_decay(
         learning_rate=1e-4,
@@ -108,15 +117,15 @@ def get_config():
     return dict(
         dataset_train=dataset_train,
         optimizer=tf.train.AdamOptimizer(lr),
-        callbacks=[
-            SummaryWriter(LOG_DIR),
+        callback=Callbacks([
+            SummaryWriter(),
+            PeriodicSaver(),
             ValidationError(dataset_test, prefix='test'),
-            PeriodicSaver(LOG_DIR),
-        ],
+        ]),
         session_config=sess_config,
         inputs=input_vars,
         outputs=output_vars,
         cost=cost_var,
+        input_queue=input_queue,
+        get_model_func=get_model,
         max_epoch=100,
     )
models/_common.py
@@ -7,6 +7,9 @@ import tensorflow as tf
 from utils.summary import *
 from utils import logger

+# make sure each layer is only logged once
+_layer_logged = set()
+
 def layer_register(summary_activation=False):
     """
     summary_activation: default behavior of whether to summary the output of this layer
@@ -19,26 +22,29 @@ def layer_register(summary_activation=False):
             do_summary = kwargs.pop('summary_activation', summary_activation)
             inputs = args[0]
-            if isinstance(inputs, list):
-                shape_str = ",".join(map(str(x.get_shape().as_list()), inputs))
-            else:
-                shape_str = str(inputs.get_shape().as_list())
-            logger.info("{} input: {}".format(name, shape_str))
             with tf.variable_scope(name) as scope:
                 outputs = func(*args, **kwargs)
-                if isinstance(outputs, list):
-                    shape_str = ",".join(map(str(x.get_shape().as_list()), outputs))
-                    if do_summary:
-                        for x in outputs:
-                            add_activation_summary(x, scope.name)
-                else:
-                    shape_str = str(outputs.get_shape().as_list())
-                    if do_summary:
-                        add_activation_summary(outputs, scope.name)
-                logger.info("{} output: {}".format(name, shape_str))
+                if name not in _layer_logged:
+                    # log shape info and add activation
+                    if isinstance(inputs, list):
+                        shape_str = ",".join(map(str(x.get_shape().as_list()), inputs))
+                    else:
+                        shape_str = str(inputs.get_shape().as_list())
+                    logger.info("{} input: {}".format(name, shape_str))
+                    if isinstance(outputs, list):
+                        shape_str = ",".join(map(str(x.get_shape().as_list()), outputs))
+                        if do_summary:
+                            for x in outputs:
+                                add_activation_summary(x, scope.name)
+                    else:
+                        shape_str = str(outputs.get_shape().as_list())
+                        if do_summary:
+                            add_activation_summary(outputs, scope.name)
+                    logger.info("{} output: {}".format(name, shape_str))
+                    _layer_logged.add(name)
                 return outputs
         return inner
     return wrapper
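
The new _layer_logged set means a registered layer function can now be built more than once (for example once in the training graph and again in the test graph created in utils/__init__.py) while its shape logging and activation summaries are only emitted on the first build. A minimal standalone sketch of that "log once per name" idiom; the names and the print call below are illustrative, not repository code:

    _layer_logged = set()

    def log_once(name, msg):
        # emit the message only the first time this layer name is seen
        if name not in _layer_logged:
            print("{}: {}".format(name, msg))
            _layer_logged.add(name)

    log_once('conv0', 'input: [None, 28, 28]')   # printed
    log_once('conv0', 'input: [None, 28, 28]')   # suppressed on later graph builds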
train.py
@@ -5,12 +5,15 @@
 import tensorflow as tf
 from utils import *
+from utils.concurrency import *
+from utils.callback import *
+from utils.summary import *
 from dataflow import DataFlow
 from itertools import count
 import argparse

 def prepare():
-    is_training = tf.placeholder(tf.bool, shape=(), name=IS_TRAINING_OP_NAME)
+    is_training = tf.constant(True, name=IS_TRAINING_OP_NAME)
     #keep_prob = tf.placeholder(
         #tf.float32, shape=tuple(), name=DROPOUT_PROB_OP_NAME)
     global_step_var = tf.Variable(
@@ -31,7 +34,7 @@ def start_train(config):
     assert isinstance(optimizer, tf.train.Optimizer), optimizer.__class__
     # a list of Callback instance
-    callbacks = Callbacks(config.get('callbacks', []))
+    callbacks = config['callback']
     # a tf.ConfigProto instance
     sess_config = config.get('session_config', None)
@@ -39,13 +42,20 @@ def start_train(config):
     # a list of input/output variables
     input_vars = config['inputs']
     output_vars = config['outputs']
     cost_var = config['cost']
+    input_queue = config['input_queue']
+    get_model_func = config['get_model_func']
     max_epoch = int(config['max_epoch'])

+    enqueue_op = input_queue.enqueue(tuple(input_vars))
+    model_inputs = input_queue.dequeue()
+    for qv, v in zip(model_inputs, input_vars):
+        qv.set_shape(v.get_shape())
+    output_vars, cost_var = get_model_func(model_inputs)
+
     # build graph
     G = tf.get_default_graph()
     G.add_to_collection(FORWARD_FUNC_KEY, get_model_func)
     for v in input_vars:
         G.add_to_collection(INPUT_VARS_KEY, v)
     for v in output_vars:
@@ -67,24 +77,27 @@ def start_train(config):
     train_op = optimizer.apply_gradients(grads, global_step_var)

     sess = tf.Session(config=sess_config)
-    # start training
-    with sess.as_default():
-        sess.run(tf.initialize_all_variables())
+    sess.run(tf.initialize_all_variables())
+
+    # start training:
+    coord = tf.train.Coordinator()
+    # a thread that keeps filling the queue
+    th = EnqueueThread(sess, coord, enqueue_op, dataset_train)
+    with sess.as_default(), \
+            coordinator_context(sess, coord, th, input_queue):
         callbacks.before_train()
-        is_training = G.get_tensor_by_name(IS_TRAINING_VAR_NAME)
         for epoch in xrange(1, max_epoch):
             with timed_operation('epoch {}'.format(epoch)):
-                for dp in dataset_train.get_data():
-                    feed = {is_training: True}
-                    feed.update(dict(zip(input_vars, dp)))
-                    results = sess.run([train_op, cost_var] + output_vars, feed_dict=feed)
+                for step in xrange(dataset_train.size()):
+                    # TODO eval dequeue to get dp
+                    fetches = [train_op, cost_var] + output_vars
+                    feed = {IS_TRAINING_VAR_NAME: True}
+                    results = sess.run(fetches, feed_dict=feed)
                     cost = results[1]
                     outputs = results[2:]
-                    callbacks.trigger_step(feed, outputs, cost)
+                    # TODO trigger_step
                 # note that summary_op will take a data from the queue.
                 callbacks.trigger_epoch()
         sess.close()
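
The core of the queue-thread change is that the training graph no longer consumes the feed placeholders directly: a background thread enqueues datapoints through the placeholders, and the model is built on the dequeued tensors, whose static shapes must be restored with set_shape. A minimal standalone sketch of that wiring, not part of the commit and written against the TensorFlow 0.x API this code targets; variable names mirror example_mnist.py:

    import tensorflow as tf

    # Placeholders are only used by the enqueue op; the model reads from the queue.
    image_var = tf.placeholder(tf.float32, shape=(None, 28, 28), name='input')
    label_var = tf.placeholder(tf.int32, shape=(None,), name='label')
    queue = tf.RandomShuffleQueue(100, 50, ['float32', 'int32'], name='queue')

    enqueue_op = queue.enqueue((image_var, label_var))
    model_inputs = queue.dequeue()
    for qv, v in zip(model_inputs, [image_var, label_var]):
        qv.set_shape(v.get_shape())   # dequeue loses static shape info; restore it

    # model_inputs now stand in for the placeholders when the model function is called.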
utils/__init__.py
@@ -16,11 +16,7 @@ def global_import(name):
     lst = p.__all__ if '__all__' in dir(p) else dir(p)
     for k in lst:
         globals()[k] = p.__dict__[k]

 global_import('naming')
-global_import('callback')
-global_import('validation_callback')

 @contextmanager
 def timed_operation(msg, log_start=False):
@@ -44,3 +40,32 @@ def describe_model():
     msg.append("Total dim={}".format(total))
     logger.info("Model Params: {}".format('\n'.join(msg)))

+# TODO disable shape output in get_model
+@contextmanager
+def create_test_graph():
+    G = tf.get_default_graph()
+    input_vars_train = G.get_collection(INPUT_VARS_KEY)
+    forward_func = G.get_collection(FORWARD_FUNC_KEY)[0]
+    with tf.Graph().as_default() as Gtest:
+        input_vars = []
+        for v in input_vars_train:
+            name = v.name
+            assert name.endswith(':0'), \
+                "I think placeholder variable should all ends with ':0'"
+            name = name[:-2]
+            input_vars.append(tf.placeholder(v.dtype, shape=v.get_shape(), name=name))
+        for v in input_vars:
+            Gtest.add_to_collection(INPUT_VARS_KEY, v)
+        is_training = tf.constant(False, name=IS_TRAINING_OP_NAME)
+        output_vars, cost = forward_func(input_vars)
+        for v in output_vars:
+            Gtest.add_to_collection(OUTPUT_VARS_KEY, v)
+        yield Gtest
+
+@contextmanager
+def create_test_session():
+    with create_test_graph():
+        with tf.Session() as sess:
+            yield sess
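
create_test_graph() rebuilds the registered forward function in a fresh graph, with fresh placeholders cloned from the training inputs and is_training fixed to False; create_test_session() wraps that graph in its own session. A rough usage sketch, not part of the commit, assuming a training graph has already been built and registered its inputs and forward function in the collections above:

    # Sketch: obtain an inference-side graph/session for test-time callbacks.
    from utils import create_test_session
    from utils.naming import OUTPUT_VARS_KEY

    with create_test_session() as sess:
        test_graph = sess.graph
        output_vars = test_graph.get_collection(OUTPUT_VARS_KEY)
        # Variables still need to be restored from a checkpoint before use,
        # which is what TestCallbacks in utils/callback.py does each epoch.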
utils/callback.py
@@ -10,10 +10,16 @@ import os
 import time
 from abc import abstractmethod

+from . import create_test_session
 from .naming import *
 import logger

 class Callback(object):
+    running_graph = 'train'
+    """ The graph that this callback should run on.
+        Either 'train' or 'test'
+    """
     def before_train(self):
         self.graph = tf.get_default_graph()
         self.sess = tf.get_default_session()
@@ -53,20 +59,20 @@ class PeriodicCallback(Callback):
         pass

 class PeriodicSaver(PeriodicCallback):
-    def __init__(self, log_dir, period=1):
+    def __init__(self, period=1):
         super(PeriodicSaver, self).__init__(period)
-        self.path = os.path.join(log_dir, 'model')
+        self.path = os.path.join(logger.LOG_DIR, 'model')

     def _before_train(self):
         self.saver = tf.train.Saver(max_to_keep=99999)

     def _trigger(self):
-        self.saver.save(tf.get_default_session(), self.path,
-                        global_step=self.epoch_num, latest_filename='latest')
+        self.saver.save(tf.get_default_session(), self.path,
+                        global_step=self.epoch_num)

 class SummaryWriter(Callback):
-    def __init__(self, log_dir):
-        self.log_dir = log_dir
+    def __init__(self):
+        self.log_dir = logger.LOG_DIR
         self.epoch_num = 0

     def _before_train(self):
@@ -75,59 +81,127 @@ class SummaryWriter(Callback):
         tf.add_to_collection(SUMMARY_WRITER_COLLECTION_KEY, self.writer)
         self.summary_op = tf.merge_all_summaries()

-    def trigger_step(self, inputs, outputs, cost):
-        self.last_dp = inputs
-
     def trigger_epoch(self):
         # check if there is any summary
         if self.summary_op is None:
             return
-        summary_str = self.summary_op.eval(self.last_dp)
+        feed = {IS_TRAINING_VAR_NAME: True}
+        summary_str = self.summary_op.eval(feed_dict=feed)
         self.epoch_num += 1
         self.writer.add_summary(summary_str, self.epoch_num)

-class Callbacks(Callback):
-    def __init__(self, callbacks):
-        for cb in callbacks:
-            assert isinstance(cb, Callback), cb.__class__
+class CallbackTimeLogger(object):
+    def __init__(self):
+        self.times = []
+        self.tot = 0
+
+    def add(self, name, time):
+        self.tot += time
+        self.times.append((name, time))
+
+    def log(self):
+        """
+        log the time of some heavy callbacks
+        """
+        if self.tot < 3:
+            return
+        msgs = []
+        for name, t in self.times:
+            if t / self.tot > 0.3 and t > 1:
+                msgs.append("{}:{}".format(name, t))
+        logger.info("Callbacks took {} sec. {}".format(self.tot, ' '.join(msgs)))
+
+class TrainCallbacks(Callback):
+    def __init__(self, callbacks):
+        self.cbs = callbacks
         # put SummaryWriter to the first
-        for idx, cb in enumerate(callbacks):
+        for idx, cb in enumerate(self.cbs):
             if type(cb) == SummaryWriter:
-                callbacks.insert(0, callbacks.pop(idx))
+                self.cbs.insert(0, self.cbs.pop(idx))
                 break
         else:
-            raise RuntimeError("callbacks must contain a SummaryWriter!")
-        self.callbacks = callbacks
+            raise RuntimeError("Callbacks must contain a SummaryWriter!")

     def before_train(self):
-        for cb in self.callbacks:
+        for cb in self.cbs:
             cb.before_train()
         self.writer = tf.get_collection(SUMMARY_WRITER_COLLECTION_KEY)[0]

     def trigger_step(self, inputs, outputs, cost):
-        for cb in self.callbacks:
+        for cb in self.cbs:
             cb.trigger_step(inputs, outputs, cost)

     def trigger_epoch(self):
-        start = time.time()
-        times = []
-        for cb in self.callbacks:
+        tm = CallbackTimeLogger()
+        for cb in self.cbs:
             s = time.time()
             cb.trigger_epoch()
-            times.append(time.time() - s)
+            tm.add(type(cb).__name__, time.time() - s)
         self.writer.flush()
-        tot = time.time() - start
-
-        # log the time of some heavy callbacks
-        if tot < 3:
-            return
-        msgs = []
-        for idx, t in enumerate(times):
-            if t / tot > 0.3 and t > 1:
-                msgs.append("{}:{}".format(type(self.callbacks[idx]).__name__, t))
-        logger.info("Callbacks took {} sec. {}".format(tot, ' '.join(msgs)))
+        tm.log()
+
+class TestCallbacks(Callback):
+    def __init__(self, callbacks):
+        self.cbs = callbacks
+
+    def before_train(self):
+        self.writer = tf.get_collection(SUMMARY_WRITER_COLLECTION_KEY)[0]
+        with create_test_session() as sess:
+            self.sess = sess
+            self.graph = sess.graph
+
+            self.saver = tf.train.Saver()
+            tf.add_to_collection(SUMMARY_WRITER_COLLECTION_KEY, self.writer)
+            for cb in self.cbs:
+                cb.before_train()
+
+    def trigger_epoch(self):
+        tm = CallbackTimeLogger()
+        with self.graph.as_default():
+            with self.sess.as_default():
+                s = time.time()
+                ckpt = tf.train.get_checkpoint_state(logger.LOG_DIR)
+                if ckpt is None:
+                    from IPython import embed; embed()
+                self.saver.restore(self.sess, ckpt.model_checkpoint_path)
+                tm.add('restore session', time.time() - s)
+                for cb in self.cbs:
+                    s = time.time()
+                    cb.trigger_epoch()
+                    tm.add(type(cb).__name__, time.time() - s)
+        self.writer.flush()
+        tm.log()
+
+class Callbacks(Callback):
+    def __init__(self, cbs):
+        train_cbs = []
+        test_cbs = []
+        for cb in cbs:
+            assert isinstance(cb, Callback), cb.__class__
+            if cb.running_graph == 'test':
+                test_cbs.append(cb)
+            elif cb.running_graph == 'train':
+                train_cbs.append(cb)
+            else:
+                raise RuntimeError(
+                    "Unknown callback running graph {}!".format(cb.running_graph))
+        self.train = TrainCallbacks(train_cbs)
+        self.test = TestCallbacks(test_cbs)
+
+    def before_train(self):
+        self.train.before_train()
+        self.test.before_train()
+
+    def trigger_step(self, inputs, outputs, cost):
+        self.train.trigger_step()
+        # test callback don't have trigger_step
+
+    def trigger_epoch(self):
+        self.train.trigger_epoch()
+        self.test.trigger_epoch()
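
Callbacks now partitions its children by their running_graph attribute: 'train' callbacks run in the training graph and session, while 'test' callbacks (such as ValidationError) run in a separate test graph restored from the latest checkpoint each epoch. A rough construction sketch, not part of the commit, mirroring the get_config() change in example_mnist.py; dataset_test is assumed to be a validation DataFlow as defined there:

    # Sketch: the 'callback' entry expected by the new train.py.
    from utils.callback import Callbacks, SummaryWriter, PeriodicSaver
    from utils.validation_callback import ValidationError

    cbs = Callbacks([
        SummaryWriter(),                                # running_graph == 'train'
        PeriodicSaver(),                                # running_graph == 'train'
        ValidationError(dataset_test, prefix='test'),   # running_graph == 'test'
    ])
    # cbs.trigger_epoch() runs the train callbacks first, then restores the latest
    # checkpoint into the test graph and runs the test callbacks.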
utils/concurrency.py  (new file, 0 → 100644, +61 lines)

#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: concurrency.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

import threading
from contextlib import contextmanager
import tensorflow as tf

from .naming import *
import logger

class StoppableThread(threading.Thread):
    def __init__(self):
        super(StoppableThread, self).__init__()
        self._stop = threading.Event()

    def stop(self):
        self._stop.set()

    def stopped(self):
        return self._stop.isSet()

class EnqueueThread(threading.Thread):
    def __init__(self, sess, coord, enqueue_op, dataflow):
        super(EnqueueThread, self).__init__()
        self.sess = sess
        self.coord = coord
        self.input_vars = sess.graph.get_collection(INPUT_VARS_KEY)
        self.dataflow = dataflow
        self.op = enqueue_op

    def run(self):
        try:
            while True:
                for dp in self.dataflow.get_data():
                    if self.coord.should_stop():
                        return
                    feed = dict(zip(self.input_vars, dp))
                    self.sess.run([self.op], feed_dict=feed)
        except tf.errors.CancelledError as e:
            pass
        except Exception:
            logger.exception("Exception in EnqueueThread:")

@contextmanager
def coordinator_context(sess, coord, thread, queue):
    """
    Context manager to make sure queue is closed and thread is joined
    """
    thread.start()
    try:
        yield
    except (KeyboardInterrupt, Exception) as e:
        raise
    finally:
        coord.request_stop()
        sess.run(queue.close(cancel_pending_enqueues=True))
        coord.join([thread])
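
EnqueueThread keeps pushing datapoints from a DataFlow into the input queue until the coordinator asks it to stop; coordinator_context guarantees that on exit, whether normal or via an exception, the coordinator is stopped, the queue is closed with pending enqueues cancelled, and the thread is joined. A condensed usage sketch, not part of the commit, following the pattern of start_train() in train.py (sess, enqueue_op, input_queue, dataset_train, fetches and max_epoch are assumed to be set up as there):

    # Sketch: tie the enqueue thread's lifetime to the training loop.
    coord = tf.train.Coordinator()
    th = EnqueueThread(sess, coord, enqueue_op, dataset_train)
    with coordinator_context(sess, coord, th, input_queue):
        for epoch in xrange(1, max_epoch):
            for step in xrange(dataset_train.size()):
                sess.run(fetches)   # each step dequeues one datapoint from the queue
    # on exit: coord.request_stop(), queue closed, enqueue thread joined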
utils/logger.py
@@ -34,7 +34,7 @@ def getlogger():
 logger = getlogger()

-for func in ['info', 'warning', 'error', 'critical', 'warn']:
+for func in ['info', 'warning', 'error', 'critical', 'warn', 'exception', 'debug']:
     locals()[func] = getattr(logger, func)

 def set_file(path):
@@ -50,3 +50,11 @@ def set_file(path):
     hdl = logging.FileHandler(
         filename=path, encoding='utf-8', mode='w')
     logger.addHandler(hdl)
+
+global LOG_DIR
+LOG_DIR = "train_log"
+
+def set_logger_dir(dirname):
+    global LOG_DIR
+    LOG_DIR = dirname
+    set_file(os.path.join(LOG_DIR, 'training.log'))
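
The logger module now owns a global LOG_DIR, so callers set the log directory once and callbacks such as PeriodicSaver and SummaryWriter pick it up via logger.LOG_DIR instead of taking a log_dir argument. A minimal sketch, not part of the commit, of what example_mnist.py's get_config() now does:

    # Sketch: configure logging once; callbacks read logger.LOG_DIR themselves.
    import os
    from utils import logger

    log_dir = os.path.join('train_log', 'example_mnist')
    logger.set_logger_dir(log_dir)   # also opens <log_dir>/training.log
    logger.info("logging to {}".format(logger.LOG_DIR))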
utils/naming.py
@@ -15,6 +15,7 @@ INPUT_VARS_KEY = 'INPUT_VARIABLES'
 OUTPUT_VARS_KEY = 'OUTPUT_VARIABLES'
 COST_VARS_KEY = 'COST_VARIABLES'         # keep track of each individual cost
 SUMMARY_VARS_KEY = 'SUMMARY_VARIABLES'   # extra variables to summarize during training
+FORWARD_FUNC_KEY = 'FORWARD_FUNCTION'

 # export all upper case variables
 all_local_names = locals().keys()
utils/summary.py
@@ -19,6 +19,7 @@ def create_summary(name, v):
     return s

 def add_activation_summary(x, name=None):
+    # TODO dedup
     """
     Summary for an activation tensor x.
     If name is None, use x.name
utils/validation_callback.py
@@ -11,6 +11,7 @@ from .summary import *
 import logger

 class ValidationError(PeriodicCallback):
+    running_graph = 'test'
     """
     Validate the accuracy for the given wrong and cost variable
     Use under the following setup:
@@ -33,7 +34,6 @@ class ValidationError(PeriodicCallback):
     def _before_train(self):
         self.input_vars = tf.get_collection(INPUT_VARS_KEY)
-        self.is_training_var = self.get_tensor(IS_TRAINING_VAR_NAME)
         self.wrong_var = self.get_tensor(self.wrong_var_name)
         self.cost_var = self.get_tensor(self.cost_var_name)
         self.writer = tf.get_collection(SUMMARY_WRITER_COLLECTION_KEY)[0]
@@ -43,8 +43,7 @@ class ValidationError(PeriodicCallback):
         err_stat = Accuracy()
         cost_sum = 0
         for dp in self.ds.get_data():
-            feed = {self.is_training_var: False}
-            feed.update(dict(zip(self.input_vars, dp)))
+            feed = dict(zip(self.input_vars, dp))
             batch_size = dp[0].shape[0]   # assume batched input