seminar-breakout

Commit 93020942, authored Dec 28, 2015 by ppwwyyxx
queue works with validation! double graph!
parent 86370a76
Showing 12 changed files with 292 additions and 89 deletions (+292 -89)
dataflow/base.py              +1   -0
dataflow/batch.py             +19  -1
example_mnist.py              +19  -13
models/_common.py             +24  -18
train.py                      +19  -18
utils/__init__.py             +29  -4
utils/callback.py             +108 -32
utils/concurrency.py          +61  -0
utils/logger.py               +8   -0
utils/naming.py               +1   -0
utils/summary.py              +1   -0
utils/validation_callback.py  +2   -3
dataflow/base.py

@@ -8,6 +8,7 @@ from abc import abstractmethod
 __all__ = ['DataFlow']

 class DataFlow(object):
+    # TODO private impl
     @abstractmethod
     def get_data(self):
         """
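For reference, a DataFlow only has to implement get_data() as a generator of datapoints, plus in practice a size() method that the training loop and the wrappers below rely on. A minimal sketch of a conforming flow; RandomData and its contents are made up for illustration, assuming the dataflow package layout above:

# sketch of the DataFlow interface; RandomData is a made-up example
import numpy as np
from dataflow.base import DataFlow

class RandomData(DataFlow):
    """ yield `size` datapoints, each an (image, label) tuple """
    def __init__(self, size=100):
        self._size = size

    def size(self):
        return self._size

    def get_data(self):
        for _ in range(self._size):
            yield (np.random.rand(28, 28), np.random.randint(10))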
dataflow/batch.py

@@ -6,7 +6,7 @@
 import numpy as np
 from .base import DataFlow

-__all__ = ['BatchData']
+__all__ = ['BatchData', 'FixedSizeData']

 class BatchData(DataFlow):
     def __init__(self, ds, batch_size, remainder=False):
@@ -46,3 +46,21 @@ class BatchData(DataFlow):
                 np.array([x[k] for x in data_holder],
                          dtype=data_holder[0][k].dtype))
         return tuple(result)
+
+class FixedSizeData(DataFlow):
+    def __init__(self, ds, size):
+        self.ds = ds
+        self._size = size
+
+    def size(self):
+        return self._size
+
+    def get_data(self):
+        cnt = 0
+        while True:
+            for dp in self.ds.get_data():
+                cnt += 1
+                yield dp
+                if cnt == self._size:
+                    return
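The new FixedSizeData wraps any (possibly infinite) DataFlow so that it reports a fixed size and stops after exactly that many datapoints; the commented-out FixedSizeData(dataset_train, 20) line later in example_mnist.py uses it to shrink an epoch for quick debugging. A usage sketch, assuming the dataflow package above:

# sketch: cap a flow at 20 batches per epoch, handy for smoke tests
from dataflow.batch import BatchData, FixedSizeData
from dataflow.dataset import Mnist

small_train = FixedSizeData(BatchData(Mnist('train'), 128), 20)
assert small_train.size() == 20
for dp in small_train.get_data():   # yields exactly 20 batches, then returns
    pass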
example_mnist.py

@@ -18,11 +18,14 @@ from models import *
 from utils import *
 from utils.symbolic_functions import *
 from utils.summary import *
+from utils.callback import *
+from utils.validation_callback import *
 from utils.concurrency import *
 from dataflow.dataset import Mnist
 from dataflow import *

 def get_model(inputs):
+    # TODO is_training as a python variable
     """
     Args:
         inputs: a list of input variable,
@@ -41,12 +44,12 @@ def get_model(inputs):
     image, label = inputs
     image = tf.expand_dims(image, 3)    # add a single channel

-    #conv0 = Conv2D('conv0', image, out_channel=32, kernel_shape=5)
-    #pool0 = MaxPooling('pool0', conv0, 2)
-    #conv1 = Conv2D('conv1', pool0, out_channel=40, kernel_shape=3)
-    #pool1 = MaxPooling('pool1', conv1, 2)
+    conv0 = Conv2D('conv0', image, out_channel=32, kernel_shape=5)
+    pool0 = MaxPooling('pool0', conv0, 2)
+    conv1 = Conv2D('conv1', pool0, out_channel=40, kernel_shape=3)
+    pool1 = MaxPooling('pool1', conv1, 2)

-    fc0 = FullyConnected('fc0', image, 1024)
+    fc0 = FullyConnected('fc0', pool1, 1024)
     fc0 = tf.nn.dropout(fc0, keep_prob)

     # fc will have activation summary by default. disable this for the output layer
@@ -74,15 +77,18 @@ def get_model(inputs):
                           name='regularize_loss')
     tf.add_to_collection(COST_VARS_KEY, wd_cost)

-    return [prob, nr_wrong], tf.add_n(tf.get_collection(COST_VARS_KEY), name='cost')
+    # this won't work with multigpu
+    #return [prob, nr_wrong], tf.add_n(tf.get_collection(COST_VARS_KEY), name='cost')
+    return [prob, nr_wrong], tf.add_n([wd_cost, cost], name='cost')

 def get_config():
     IMAGE_SIZE = 28
-    LOG_DIR = os.path.join('train_log', os.path.basename(__file__)[:-3])
+    log_dir = os.path.join('train_log', os.path.basename(__file__)[:-3])
+    logger.set_logger_dir(log_dir)
     BATCH_SIZE = 128
-    logger.set_file(os.path.join(LOG_DIR, 'training.log'))
     dataset_train = BatchData(Mnist('train'), BATCH_SIZE)
+    #dataset_train = FixedSizeData(dataset_train, 20)
     dataset_test = BatchData(Mnist('test'), 256, remainder=True)
     sess_config = tf.ConfigProto()
@@ -111,11 +117,11 @@ def get_config():
     return dict(
         dataset_train=dataset_train,
         optimizer=tf.train.AdamOptimizer(lr),
-        callbacks=[
-            SummaryWriter(LOG_DIR),
-            #ValidationError(dataset_test, prefix='test'),
-            PeriodicSaver(LOG_DIR),
-        ],
+        callback=Callbacks([
+            SummaryWriter(),
+            PeriodicSaver(),
+            ValidationError(dataset_test, prefix='test'),
+        ]),
         session_config=sess_config,
         inputs=input_vars,
         input_queue=input_queue,
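The old collection-based return is flagged "this won't work with multigpu", presumably because COST_VARS_KEY is one graph-wide collection: if get_model were replayed once per GPU tower, each replay would append its own wd_cost and cost, and add_n over the whole collection would sum other towers' losses too. Summing the explicit [wd_cost, cost] pair keeps each tower self-contained. A TF-free sketch of the trap; COSTS stands in for tf.get_collection(COST_VARS_KEY):

# illustration only: a shared collection double-counts across towers
COSTS = []                      # one graph-wide collection

def build_tower(tower_id):
    wd_cost = 0.1               # stand-ins for the real loss tensors
    cost = 1.0
    COSTS.extend([wd_cost, cost])
    return sum(COSTS)           # ~ tf.add_n(tf.get_collection(...))

print(build_tower(0))   # 1.1  -- fine with a single tower
print(build_tower(1))   # 2.2  -- second tower silently sums tower 0's costs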
models/_common.py

@@ -7,6 +7,9 @@ import tensorflow as tf
 from utils.summary import *
 from utils import logger

+# make sure each layer is only logged once
+_layer_logged = set()
+
 def layer_register(summary_activation=False):
     """
     summary_activation: default behavior of whether to summary the output of this layer
@@ -19,26 +22,29 @@ def layer_register(summary_activation=False):
             do_summary = kwargs.pop(
                 'summary_activation', summary_activation)
             inputs = args[0]
-            if isinstance(inputs, list):
-                shape_str = ",".join(
-                    map(str(x.get_shape().as_list()), inputs))
-            else:
-                shape_str = str(inputs.get_shape().as_list())
-            logger.info("{} input: {}".format(name, shape_str))
             with tf.variable_scope(name) as scope:
                 outputs = func(*args, **kwargs)
-                if isinstance(outputs, list):
-                    shape_str = ",".join(
-                        map(str(x.get_shape().as_list()), outputs))
-                    if do_summary:
-                        for x in outputs:
-                            add_activation_summary(x, scope.name)
-                else:
-                    shape_str = str(outputs.get_shape().as_list())
-                    if do_summary:
-                        add_activation_summary(outputs, scope.name)
-                logger.info("{} output: {}".format(name, shape_str))
+                if name not in _layer_logged:
+                    # log shape info and add activation
+                    if isinstance(inputs, list):
+                        shape_str = ",".join(
+                            map(str(x.get_shape().as_list()), inputs))
+                    else:
+                        shape_str = str(inputs.get_shape().as_list())
+                    logger.info("{} input: {}".format(name, shape_str))
+
+                    if isinstance(outputs, list):
+                        shape_str = ",".join(
+                            map(str(x.get_shape().as_list()), outputs))
+                        if do_summary:
+                            for x in outputs:
+                                add_activation_summary(x, scope.name)
+                    else:
+                        shape_str = str(outputs.get_shape().as_list())
+                        if do_summary:
+                            add_activation_summary(outputs, scope.name)
+                    logger.info("{} output: {}".format(name, shape_str))
+                    _layer_logged.add(name)
                 return outputs
         return inner
     return wrapper
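The new _layer_logged set matters because this commit builds the model function twice, once in the training graph and once in the test graph (see create_test_graph in utils/__init__.py below); without the guard, every layer would log its shapes and re-register activation summaries on the second build. A TF-free sketch of the log-once pattern; register and linear are made-up stand-ins for layer_register and the real layers:

# sketch of the log-once pattern used by layer_register above
_layer_logged = set()

def register(func):
    def inner(name, *args, **kwargs):
        out = func(name, *args, **kwargs)
        if name not in _layer_logged:          # only the first build logs
            print("{} output: {}".format(name, out))
            _layer_logged.add(name)
        return out
    return inner

@register
def linear(name, x, scale):
    return x * scale

linear('fc0', 2, 3)   # logs "fc0 output: 6"
linear('fc0', 2, 3)   # second (test-graph) build: silent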
train.py

@@ -6,12 +6,14 @@
 import tensorflow as tf
 from utils import *
 from utils.concurrency import *
+from utils.callback import *
+from utils.summary import *
 from dataflow import DataFlow
 from itertools import count
 import argparse

 def prepare():
-    is_training = tf.placeholder(tf.bool, shape=(), name=IS_TRAINING_OP_NAME)
+    is_training = tf.constant(True, name=IS_TRAINING_OP_NAME)
     #keep_prob = tf.placeholder(
         #tf.float32, shape=tuple(), name=DROPOUT_PROB_OP_NAME)
     global_step_var = tf.Variable(
@@ -32,7 +34,7 @@ def start_train(config):
     assert isinstance(optimizer, tf.train.Optimizer), optimizer.__class__

     # a list of Callback instance
-    callbacks = Callbacks(config.get('callbacks', []))
+    callbacks = config['callback']
     # a tf.ConfigProto instance
     sess_config = config.get('session_config', None)
@@ -53,6 +55,7 @@ def start_train(config):
     # build graph
     G = tf.get_default_graph()
+    G.add_to_collection(FORWARD_FUNC_KEY, get_model_func)
     for v in input_vars:
         G.add_to_collection(INPUT_VARS_KEY, v)
     for v in output_vars:
@@ -74,30 +77,28 @@ def start_train(config):
     train_op = optimizer.apply_gradients(grads, global_step_var)

     sess = tf.Session(config=sess_config)
-    # start training
-    sess.run(tf.initialize_all_variables())
-
-    coord = tf.train.Coordinator()
-    # a thread that keeps filling the queue
-    th = EnqueueThread(sess, coord, enqueue_op, dataset_train)
-    with sess.as_default(), \
-            coordinator_context(sess, coord, th, input_queue):
+    with sess.as_default():
+        sess.run(tf.initialize_all_variables())
+
+        # start training:
         callbacks.before_train()
         for epoch in xrange(1, max_epoch):
-            with timed_operation('epoch {}'.format(epoch)):
+            coord = tf.train.Coordinator()
+            th = EnqueueThread(sess, coord, enqueue_op, dataset_train)
+            with timed_operation('epoch {}'.format(epoch)), \
+                    coordinator_context(sess, coord, th, input_queue):
                 for step in xrange(dataset_train.size()):
                     # TODO eval dequeue to get dp
                     fetches = [train_op, cost_var] + output_vars
-                    results = sess.run(fetches,
-                        feed_dict={IS_TRAINING_VAR_NAME: True})
+                    feed = {IS_TRAINING_VAR_NAME: True}
+                    results = sess.run(fetches, feed_dict=feed)
                     cost = results[1]
                     outputs = results[2:]
-                    # TODO trigger_step
-            # note that summary_op will take a data from the queue.
-            callbacks.trigger_epoch()
+                    print tf.train.global_step(sess, global_step_var), cost
+                    # trigger_step
+                coord.request_stop()
+                # summary will take a data from the queue
+                callbacks.trigger_epoch()
+            print "Finish callback"
     sess.close()

 def main(get_config_func):
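A Coordinator and EnqueueThread are now created per epoch rather than once, since end-of-epoch validation drains the input queue and coordinator_context closes it with cancel_pending_enqueues=True, so the next epoch needs a fresh producer. A TF-free sketch of that start, consume, stop-and-join-per-epoch shape; Producer and the queue are stand-ins for EnqueueThread and the TF input queue:

# sketch of the per-epoch producer lifecycle used above
import threading
try:
    import queue                 # Python 3
except ImportError:
    import Queue as queue        # Python 2, matching this codebase

class Producer(threading.Thread):
    """ stand-in for EnqueueThread: fill the queue until told to stop """
    def __init__(self, q, stop_ev):
        super(Producer, self).__init__()
        self.q, self.stop_ev = q, stop_ev
    def run(self):
        i = 0
        while not self.stop_ev.is_set():
            try:
                self.q.put(i, timeout=0.1)   # never block forever on a full queue
                i += 1
            except queue.Full:
                pass

for epoch in range(1, 3):                    # a fresh producer every epoch
    q, stop_ev = queue.Queue(maxsize=10), threading.Event()
    th = Producer(q, stop_ev)
    th.start()
    try:
        for step in range(5):                # the steps of one epoch
            q.get()
    finally:                                 # what coordinator_context guarantees
        stop_ev.set()
        th.join()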
utils/__init__.py

@@ -16,11 +16,7 @@ def global_import(name):
     lst = p.__all__ if '__all__' in dir(p) else dir(p)
     for k in lst:
         globals()[k] = p.__dict__[k]

 global_import('naming')
-global_import('callback')
-global_import('validation_callback')

 @contextmanager
 def timed_operation(msg, log_start=False):
@@ -44,3 +40,32 @@ def describe_model():
     msg.append("Total dim={}".format(total))
     logger.info("Model Params: {}".format('\n'.join(msg)))

+# TODO disable shape output in get_model
+@contextmanager
+def create_test_graph():
+    G = tf.get_default_graph()
+    input_vars_train = G.get_collection(INPUT_VARS_KEY)
+    forward_func = G.get_collection(FORWARD_FUNC_KEY)[0]
+    with tf.Graph().as_default() as Gtest:
+        input_vars = []
+        for v in input_vars_train:
+            name = v.name
+            assert name.endswith(':0'), "I think placeholder variable should all ends with ':0'"
+            name = name[:-2]
+            input_vars.append(tf.placeholder(
+                v.dtype, shape=v.get_shape(), name=name))
+        for v in input_vars:
+            Gtest.add_to_collection(INPUT_VARS_KEY, v)
+        is_training = tf.constant(False, name=IS_TRAINING_OP_NAME)
+        output_vars, cost = forward_func(input_vars)
+        for v in output_vars:
+            Gtest.add_to_collection(OUTPUT_VARS_KEY, v)
+        yield Gtest
+
+@contextmanager
+def create_test_session():
+    with create_test_graph():
+        with tf.Session() as sess:
+            yield sess
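This is the "double graph" of the commit message: training stores its model-building function under FORWARD_FUNC_KEY, and create_test_graph replays it in a brand-new tf.Graph with fresh placeholders (same names, minus the ':0' tensor suffix) and is_training pinned to a False constant, so no feed is needed at validation time. A usage sketch, assuming train.py has already populated the collections in the default graph:

# sketch: run inference in the rebuilt test graph
from utils import create_test_session
from utils.naming import INPUT_VARS_KEY, OUTPUT_VARS_KEY

with create_test_session() as sess:
    g = sess.graph
    input_vars = g.get_collection(INPUT_VARS_KEY)    # fresh placeholders
    output_vars = g.get_collection(OUTPUT_VARS_KEY)
    # feed a datapoint dp (a tuple matching input_vars) and fetch outputs:
    # outputs = sess.run(output_vars, feed_dict=dict(zip(input_vars, dp)))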
utils/callback.py

@@ -10,10 +10,16 @@ import os
 import time
 from abc import abstractmethod
+from . import create_test_session
 from .naming import *
 import logger

 class Callback(object):
+    running_graph = 'train'
+    """ The graph that this callback should run on.
+        Either 'train' or 'test'
+    """
     def before_train(self):
         self.graph = tf.get_default_graph()
         self.sess = tf.get_default_session()
@@ -53,20 +59,20 @@ class PeriodicCallback(Callback):
         pass

 class PeriodicSaver(PeriodicCallback):
-    def __init__(self, log_dir, period=1):
+    def __init__(self, period=1):
         super(PeriodicSaver, self).__init__(period)
-        self.path = os.path.join(log_dir, 'model')
+        self.path = os.path.join(logger.LOG_DIR, 'model')

     def _before_train(self):
         self.saver = tf.train.Saver(max_to_keep=99999)

     def _trigger(self):
         self.saver.save(tf.get_default_session(), self.path,
-                        global_step=self.epoch_num, latest_filename='latest')
+                        global_step=self.epoch_num)

 class SummaryWriter(Callback):
-    def __init__(self, log_dir):
-        self.log_dir = log_dir
+    def __init__(self):
+        self.log_dir = logger.LOG_DIR
         self.epoch_num = 0

     def _before_train(self):
@@ -80,52 +86,122 @@ class SummaryWriter(Callback):
         if self.summary_op is None:
             return
-        summary_str = self.summary_op.eval(
-            feed_dict={IS_TRAINING_VAR_NAME: True})
+        feed = {IS_TRAINING_VAR_NAME: True}
+        summary_str = self.summary_op.eval(feed_dict=feed)
         self.epoch_num += 1
         self.writer.add_summary(summary_str, self.epoch_num)

-class Callbacks(Callback):
-    def __init__(self, callbacks):
-        for cb in callbacks:
-            assert isinstance(cb, Callback), cb.__class__
+class CallbackTimeLogger(object):
+    def __init__(self):
+        self.times = []
+        self.tot = 0
+
+    def add(self, name, time):
+        self.tot += time
+        self.times.append((name, time))
+
+    def log(self):
+        """ log the time of some heavy callbacks """
+        if self.tot < 3:
+            return
+        msgs = []
+        for name, t in self.times:
+            if t / self.tot > 0.3 and t > 1:
+                msgs.append("{}:{}".format(name, t))
+        logger.info("Callbacks took {} sec. {}".format(self.tot, ' '.join(msgs)))
+
+class TrainCallbacks(Callback):
+    def __init__(self, callbacks):
+        self.cbs = callbacks
         # put SummaryWriter to the first
-        for idx, cb in enumerate(callbacks):
+        for idx, cb in enumerate(self.cbs):
             if type(cb) == SummaryWriter:
-                callbacks.insert(0, callbacks.pop(idx))
+                self.cbs.insert(0, self.cbs.pop(idx))
                 break
         else:
-            raise RuntimeError("callbacks must contain a SummaryWriter!")
-        self.callbacks = callbacks
+            raise RuntimeError("Callbacks must contain a SummaryWriter!")

     def before_train(self):
-        for cb in self.callbacks:
+        for cb in self.cbs:
             cb.before_train()
         self.writer = tf.get_collection(SUMMARY_WRITER_COLLECTION_KEY)[0]

     def trigger_step(self, inputs, outputs, cost):
-        for cb in self.callbacks:
+        for cb in self.cbs:
             cb.trigger_step(inputs, outputs, cost)

     def trigger_epoch(self):
-        start = time.time()
-        times = []
-        for cb in self.callbacks:
+        tm = CallbackTimeLogger()
+        for cb in self.cbs:
             s = time.time()
             cb.trigger_epoch()
-            times.append(time.time() - s)
+            tm.add(type(cb).__name__, time.time() - s)
         self.writer.flush()
-        tot = time.time() - start
-
-        # log the time of some heavy callbacks
-        if tot < 3:
-            return
-        msgs = []
-        for idx, t in enumerate(times):
-            if t / tot > 0.3 and t > 1:
-                msgs.append("{}:{}".format(
-                    type(self.callbacks[idx]).__name__, t))
-        logger.info("Callbacks took {} sec. {}".format(tot, ' '.join(msgs)))
+        tm.log()
+
+class TestCallbacks(Callback):
+    def __init__(self, callbacks):
+        self.cbs = callbacks
+
+    def before_train(self):
+        self.writer = tf.get_collection(SUMMARY_WRITER_COLLECTION_KEY)[0]
+        with create_test_session() as sess:
+            self.sess = sess
+            self.graph = sess.graph
+
+            self.saver = tf.train.Saver()
+            tf.add_to_collection(SUMMARY_WRITER_COLLECTION_KEY, self.writer)
+            for cb in self.cbs:
+                cb.before_train()
+
+    def trigger_epoch(self):
+        tm = CallbackTimeLogger()
+        with self.graph.as_default():
+            with self.sess.as_default():
+                s = time.time()
+                ckpt = tf.train.get_checkpoint_state(logger.LOG_DIR)
+                if ckpt is None:
+                    from IPython import embed; embed()
+                self.saver.restore(self.sess, ckpt.model_checkpoint_path)
+                tm.add('restore session', time.time() - s)
+                for cb in self.cbs:
+                    s = time.time()
+                    cb.trigger_epoch()
+                    tm.add(type(cb).__name__, time.time() - s)
+        self.writer.flush()
+        tm.log()
+
+class Callbacks(Callback):
+    def __init__(self, cbs):
+        train_cbs = []
+        test_cbs = []
+        for cb in cbs:
+            assert isinstance(cb, Callback), cb.__class__
+            if cb.running_graph == 'test':
+                test_cbs.append(cb)
+            elif cb.running_graph == 'train':
+                train_cbs.append(cb)
+            else:
+                raise RuntimeError("Unknown callback running graph {}!".format(cb.running_graph))
+        self.train = TrainCallbacks(train_cbs)
+        self.test = TestCallbacks(test_cbs)
+
+    def before_train(self):
+        self.train.before_train()
+        self.test.before_train()
+
+    def trigger_step(self, inputs, outputs, cost):
+        self.train.trigger_step()
+        # test callback don't have trigger_step
+
+    def trigger_epoch(self):
+        self.train.trigger_epoch()
+        self.test.trigger_epoch()
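Callbacks is now a dispatcher: it splits the list by each callback's running_graph attribute, runs the 'train' group in the training session, and runs the 'test' group inside the rebuilt test graph after restoring the latest checkpoint (which is why a PeriodicSaver must have run first). A sketch of declaring a test-graph callback; MyValidation is made up for illustration:

# sketch of the running_graph dispatch, assuming utils/callback.py above
from utils.callback import Callback, Callbacks, SummaryWriter, PeriodicSaver

class MyValidation(Callback):
    running_graph = 'test'     # run me in the rebuilt test graph
    def trigger_epoch(self):
        pass                   # e.g. sweep a validation DataFlow here

cbs = Callbacks([
    SummaryWriter(),           # 'train': moved to the front automatically
    PeriodicSaver(),           # 'train': must save before the test group,
    MyValidation(),            # 'test':  ...which restores its checkpoint
])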
utils/concurrency.py (new file, mode 100644)

#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: concurrency.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

import threading
from contextlib import contextmanager
import tensorflow as tf

from .naming import *
import logger

class StoppableThread(threading.Thread):
    def __init__(self):
        super(StoppableThread, self).__init__()
        self._stop = threading.Event()

    def stop(self):
        self._stop.set()

    def stopped(self):
        return self._stop.isSet()

class EnqueueThread(threading.Thread):
    def __init__(self, sess, coord, enqueue_op, dataflow):
        super(EnqueueThread, self).__init__()
        self.sess = sess
        self.coord = coord
        self.input_vars = sess.graph.get_collection(INPUT_VARS_KEY)
        self.dataflow = dataflow
        self.op = enqueue_op

    def run(self):
        try:
            while True:
                for dp in self.dataflow.get_data():
                    if self.coord.should_stop():
                        return
                    feed = dict(zip(self.input_vars, dp))
                    self.sess.run([self.op], feed_dict=feed)
        except tf.errors.CancelledError as e:
            pass
        except Exception:
            logger.exception("Exception in EnqueueThread:")

@contextmanager
def coordinator_context(sess, coord, thread, queue):
    """ Context manager to make sure queue is closed and thread is joined """
    thread.start()
    try:
        yield
    except (KeyboardInterrupt, Exception) as e:
        raise
    finally:
        coord.request_stop()
        sess.run(queue.close(cancel_pending_enqueues=True))
        coord.join([thread])
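EnqueueThread pumps datapoints from a DataFlow into a TF queue through feed_dict, and coordinator_context guarantees the shutdown order that makes that safe: request_stop, close the queue with cancel_pending_enqueues=True (raising the CancelledError the thread deliberately swallows), then join. A minimal wiring sketch; the single scalar input, the capacity, and the Floats flow are all made up for illustration:

# sketch: wiring EnqueueThread + coordinator_context to a FIFOQueue
import tensorflow as tf
from utils.naming import INPUT_VARS_KEY
from utils.concurrency import EnqueueThread, coordinator_context

class Floats(object):                         # minimal DataFlow stand-in
    def get_data(self):
        i = 0
        while True:
            yield (float(i),)                 # one value per input var
            i += 1

x = tf.placeholder(tf.float32, shape=(), name='x')
tf.add_to_collection(INPUT_VARS_KEY, x)       # EnqueueThread feeds these
queue = tf.FIFOQueue(capacity=50, dtypes=[tf.float32])
enqueue_op = queue.enqueue([x])
dequeued = queue.dequeue()                    # the model would consume this

sess = tf.Session()
coord = tf.train.Coordinator()
th = EnqueueThread(sess, coord, enqueue_op, Floats())
with coordinator_context(sess, coord, th, queue):
    for _ in range(10):
        print(sess.run(dequeued))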
utils/logger.py

@@ -50,3 +50,11 @@ def set_file(path):
     hdl = logging.FileHandler(
         filename=path, encoding='utf-8', mode='w')
     logger.addHandler(hdl)
+
+global LOG_DIR
+LOG_DIR = "train_log"
+def set_logger_dir(dirname):
+    global LOG_DIR
+    LOG_DIR = dirname
+    set_file(os.path.join(LOG_DIR, 'training.log'))
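With LOG_DIR owned by the logger module, SummaryWriter and PeriodicSaver no longer take a log_dir argument; one set_logger_dir call in get_config fixes the directory for everything, including the training.log file. A usage sketch, assuming the utils package above:

# sketch: one set_logger_dir call replaces passing log_dir around
from utils import logger
from utils.callback import SummaryWriter, PeriodicSaver

logger.set_logger_dir('train_log/example_mnist')   # also opens training.log
summary_cb = SummaryWriter()     # reads logger.LOG_DIR internally
saver_cb = PeriodicSaver()       # saves to <LOG_DIR>/model-<epoch>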
utils/naming.py

@@ -15,6 +15,7 @@ INPUT_VARS_KEY = 'INPUT_VARIABLES'
 OUTPUT_VARS_KEY = 'OUTPUT_VARIABLES'
 COST_VARS_KEY = 'COST_VARIABLES'        # keep track of each individual cost
 SUMMARY_VARS_KEY = 'SUMMARY_VARIABLES'  # extra variables to summarize during training
+FORWARD_FUNC_KEY = 'FORWARD_FUNCTION'

 # export all upper case variables
 all_local_names = locals().keys()
utils/summary.py

@@ -19,6 +19,7 @@ def create_summary(name, v):
     return s

 def add_activation_summary(x, name=None):
+    # TODO dedup
     """
     Summary for an activation tensor x.
     If name is None, use x.name
utils/validation_callback.py

@@ -11,6 +11,7 @@ from .summary import *
 import logger

 class ValidationError(PeriodicCallback):
+    running_graph = 'test'
     """
     Validate the accuracy for the given wrong and cost variable
     Use under the following setup:
@@ -33,7 +34,6 @@ class ValidationError(PeriodicCallback):
     def _before_train(self):
         self.input_vars = tf.get_collection(INPUT_VARS_KEY)
-        self.is_training_var = self.get_tensor(IS_TRAINING_VAR_NAME)
         self.wrong_var = self.get_tensor(self.wrong_var_name)
         self.cost_var = self.get_tensor(self.cost_var_name)
         self.writer = tf.get_collection(SUMMARY_WRITER_COLLECTION_KEY)[0]
@@ -43,8 +43,7 @@ class ValidationError(PeriodicCallback):
         err_stat = Accuracy()
         cost_sum = 0
         for dp in self.ds.get_data():
-            feed = {self.is_training_var: False}
-            feed.update(dict(zip(self.input_vars, dp)))
+            feed = dict(zip(self.input_vars, dp))

             batch_size = dp[0].shape[0]    # assume batched input