Shashank Suhas / seminar-breakout / Commits / 585f0837

Commit 585f0837, authored Dec 28, 2015 by ppwwyyxx
    is_training as a bool

Parent: dac78238
Showing 7 changed files with 58 additions and 56 deletions (+58 -56)
example_mnist.py       +19  -16
models/regularize.py    +4   -1
train.py               +10  -12
utils/__init__.py       +1   -3
utils/callback.py      +20  -21
utils/concurrency.py    +4   -2
utils/summary.py        +0   -1
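The thread running through all seven files: get_model() now takes is_training as a plain Python bool argument, instead of reading an IS_TRAINING tensor from the graph and feeding it on every step, so the dropout keep probability is fixed at graph-construction time. A minimal sketch of that convention, with a hypothetical helper name make_keep_prob (not from the repo); it assumes nothing beyond tf.constant:

    import tensorflow as tf

    def make_keep_prob(is_training):
        """Return the dropout keep probability as a graph constant.

        `is_training` is a plain Python bool, so the branch is resolved
        while building the graph; no tf.cond or feed_dict is needed.
        """
        is_training = bool(is_training)
        return tf.constant(0.5 if is_training else 1.0, name='dropout_prob')

    # The training graph bakes in 0.5, the test graph bakes in 1.0.
    train_keep_prob = make_keep_prob(is_training=True)
    test_keep_prob = make_keep_prob(is_training=False)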
example_mnist.py

@@ -24,22 +24,21 @@ from utils.concurrency import *
 from dataflow.dataset import Mnist
 from dataflow import *
 
-def get_model(inputs):
+def get_model(inputs, is_training):
+    # TODO is_training as a python variable
     """
     Args:
         inputs: a list of input variable,
            e.g.: [image_var, label_var] with:
            image_var: bx28x28
            label_var: bx1 integer
+        is_training: a python bool variable
     Returns:
         (outputs, cost)
         outputs: a list of output variable
-        cost: scalar variable
+        cost: the cost to minimize. scalar variable
     """
-    is_training = tf.get_default_graph().get_tensor_by_name(IS_TRAINING_VAR_NAME)
-    keep_prob = control_flow_ops.cond(
-        is_training, lambda: tf.constant(0.5), lambda: tf.constant(1.0), name='dropout_prob')
+    is_training = bool(is_training)
+    keep_prob = tf.constant(0.5 if is_training else 1.0)
     image, label = inputs
     image = tf.expand_dims(image, 3)    # add a single channel

@@ -77,19 +76,22 @@ def get_model(inputs):
                            name='regularize_loss')
     tf.add_to_collection(COST_VARS_KEY, wd_cost)
 
+    add_histogram_summary('.*/W')   # monitor histogram of all W
     # this won't work with multigpu
     #return [prob, nr_wrong], tf.add_n(tf.get_collection(COST_VARS_KEY), name='cost')
     return [prob, nr_wrong], tf.add_n([wd_cost, cost], name='cost')
 
 def get_config():
-    IMAGE_SIZE = 28
     log_dir = os.path.join('train_log', os.path.basename(__file__)[:-3])
     logger.set_logger_dir(log_dir)
 
+    IMAGE_SIZE = 28
     BATCH_SIZE = 128
     dataset_train = BatchData(Mnist('train'), BATCH_SIZE)
-    #dataset_train = FixedSizeData(dataset_train, 20)
     dataset_test = BatchData(Mnist('test'), 256, remainder=True)
+    dataset_train = FixedSizeData(dataset_train, 20)
+    dataset_test = FixedSizeData(dataset_test, 20)
 
     sess_config = tf.ConfigProto()
     sess_config.device_count['GPU'] = 1

@@ -98,14 +100,15 @@ def get_config():
     sess_config.allow_soft_placement = True
 
     # prepare model
-    image_var = tf.placeholder(
-        tf.float32, shape=(None, IMAGE_SIZE, IMAGE_SIZE), name='input')
-    label_var = tf.placeholder(
-        tf.int32, shape=(None,), name='label')
-    input_vars = [image_var, label_var]
+    input_vars = [
+        tf.placeholder(
+            tf.float32, shape=(None, IMAGE_SIZE, IMAGE_SIZE), name='input'),
+        tf.placeholder(
+            tf.int32, shape=(None,), name='label')
+    ]
     input_queue = tf.RandomShuffleQueue(
-        100, 50, ['float32', 'int32'], name='queue')
-    add_histogram_summary('.*/W')   # monitor histogram of all W
+        100, 50, [x.dtype for x in input_vars], name='queue')
 
     global_step_var = tf.get_default_graph().get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
     lr = tf.train.exponential_decay(
         learning_rate=1e-4,
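One small side change in get_config() above: the queue dtypes are now derived from the placeholder list instead of being hard-coded as ['float32', 'int32']. A self-contained sketch of that pattern, assuming the TF1-style graph API this repo targets (written against tf.compat.v1 so it still runs today):

    import tensorflow.compat.v1 as tf
    tf.disable_eager_execution()

    IMAGE_SIZE = 28
    input_vars = [
        tf.placeholder(tf.float32, shape=(None, IMAGE_SIZE, IMAGE_SIZE), name='input'),
        tf.placeholder(tf.int32, shape=(None,), name='label'),
    ]
    # dtypes follow the placeholders automatically: [tf.float32, tf.int32]
    input_queue = tf.RandomShuffleQueue(
        100, 50, [x.dtype for x in input_vars], name='queue')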
models/regularize.py

@@ -10,6 +10,9 @@ from utils import logger
 __all__ = ['regularize_cost']
 
 def regularize_cost(regex, func):
+    """
+    Apply a regularizer on every trainable variable matching the regex
+    """
     G = tf.get_default_graph()
     params = G.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

@@ -17,7 +20,7 @@ def regularize_cost(regex, func):
     for p in params:
         name = p.name
         if re.search(regex, name):
-            logger.info("Weight decay for {}".format(name))
+            logger.info("Apply regularizer for {}".format(name))
             costs.append(func(p))
     return tf.add_n(costs)
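For context, regularize_cost() as shown above sums func(p) over every trainable variable whose name matches the regex. A hedged, self-contained re-implementation of the same idea (not the repo's module, and again written against tf.compat.v1):

    import re
    import tensorflow.compat.v1 as tf
    tf.disable_eager_execution()

    def regularize_cost(regex, func):
        """Apply `func` to every trainable variable matching `regex` and sum the results."""
        costs = []
        for p in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
            if re.search(regex, p.name):
                costs.append(func(p))
        return tf.add_n(costs)

    w = tf.get_variable('fc/W', shape=[4, 4])
    wd_cost = regularize_cost('.*/W', lambda v: 1e-4 * tf.nn.l2_loss(v))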
train.py

@@ -13,9 +13,6 @@ from itertools import count
 import argparse
 
 def prepare():
-    is_training = tf.constant(True, name=IS_TRAINING_OP_NAME)
-    #keep_prob = tf.placeholder(
-        #tf.float32, shape=tuple(), name=DROPOUT_PROB_OP_NAME)
     global_step_var = tf.Variable(
         0, trainable=False, name=GLOBAL_STEP_OP_NAME)

@@ -40,7 +37,7 @@ def start_train(config):
     sess_config = config.get('session_config', None)
     assert isinstance(sess_config, tf.ConfigProto), sess_config.__class__
 
-    # input/output variables
+    # a list of input/output variables
    input_vars = config['inputs']
    input_queue = config['input_queue']
    get_model_func = config['get_model_func']

@@ -49,9 +46,10 @@ def start_train(config):
     enqueue_op = input_queue.enqueue(tuple(input_vars))
     model_inputs = input_queue.dequeue()
+    # set dequeue shape
     for qv, v in zip(model_inputs, input_vars):
         qv.set_shape(v.get_shape())
-    output_vars, cost_var = get_model_func(model_inputs)
+    output_vars, cost_var = get_model_func(model_inputs, is_training=True)
 
     # build graph
     G = tf.get_default_graph()

@@ -84,19 +82,19 @@ def start_train(config):
     # a thread that keeps filling the queue
     th = EnqueueThread(sess, coord, enqueue_op, dataset_train)
     with sess.as_default(), \
-            coordinator_context(
-                sess, coord, th, input_queue):
+            coordinator_guard(
+                sess, coord, th, input_queue):
         callbacks.before_train()
         for epoch in xrange(1, max_epoch):
             with timed_operation('epoch {}'.format(epoch)):
                 for step in xrange(dataset_train.size()):
-                    # TODO eval dequeue to get dp
-                    fetches = [train_op, cost_var] + output_vars
-                    feed = {IS_TRAINING_VAR_NAME: True}
-                    results = sess.run(fetches, feed_dict=feed)
+                    fetches = [train_op, cost_var] + output_vars + model_inputs
+                    results = sess.run(fetches)
                     cost = results[1]
-                    outputs = results[2:]
-                    # TODO trigger_step
+                    outputs = results[2:2 + len(output_vars)]
+                    inputs = results[-len(model_inputs):]
+                    callbacks.trigger_step(inputs, outputs, cost)
                 # note that summary_op will take a data from the queue.
                 callbacks.trigger_epoch()
     sess.close()
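The bookkeeping in the last hunk is easy to get wrong, so here is a plain-Python sketch of how the fetched results are sliced once model_inputs are appended to the fetch list (hypothetical helper split_results, illustration only):

    def split_results(results, num_outputs, num_inputs):
        # results == [train_op result, cost] + outputs + inputs
        cost = results[1]
        outputs = results[2:2 + num_outputs]
        inputs = results[-num_inputs:]
        return cost, outputs, inputs

    # e.g. 2 output vars (prob, nr_wrong) and 2 model inputs (image, label):
    res = [None, 0.42, 'prob', 'nr_wrong', 'image', 'label']
    print(split_results(res, num_outputs=2, num_inputs=2))
    # -> (0.42, ['prob', 'nr_wrong'], ['image', 'label'])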
utils/__init__.py

@@ -57,9 +57,7 @@ def create_test_graph():
             ))
         for v in input_vars:
             Gtest.add_to_collection(INPUT_VARS_KEY, v)
-        is_training = tf.constant(False, name=IS_TRAINING_OP_NAME)
-        output_vars, cost = forward_func(input_vars)
+        output_vars, cost = forward_func(input_vars, is_training=False)
         for v in output_vars:
             Gtest.add_to_collection(OUTPUT_VARS_KEY, v)
         yield Gtest
utils/callback.py

@@ -34,9 +34,9 @@ class Callback(object):
         """
         Callback to be triggered after every step (every backpropagation)
         Args:
-            inputs: the input dict fed into the graph
-            outputs: list of output values after running this dp
-            cost: the cost value after running this dp
+            inputs: the list of input values
+            outputs: list of output values after running this inputs
+            cost: the cost value after running this input
         """
 
     def trigger_epoch(self):

@@ -85,9 +85,7 @@ class SummaryWriter(Callback):
         # check if there is any summary
         if self.summary_op is None:
             return
-        feed = {IS_TRAINING_VAR_NAME: True}
-        summary_str = self.summary_op.eval(feed_dict=feed)
+        summary_str = self.summary_op.eval()
         self.epoch_num += 1
         self.writer.add_summary(summary_str, self.epoch_num)

@@ -102,9 +100,7 @@ class CallbackTimeLogger(object):
         self.times.append((name, time))
 
     def log(self):
-        """
-        log the time of some heavy callbacks
-        """
+        """ log the time of some heavy callbacks """
         if self.tot < 3:
             return
         msgs = []

@@ -162,18 +158,21 @@ class TestCallbacks(Callback):
     def trigger_epoch(self):
         tm = CallbackTimeLogger()
-        with self.graph.as_default():
-            with self.sess.as_default():
-                s = time.time()
-                ckpt = tf.train.get_checkpoint_state(logger.LOG_DIR)
-                if ckpt is None:
-                    from IPython import embed; embed()
-                self.saver.restore(self.sess, ckpt.model_checkpoint_path)
-                tm.add('restore session', time.time() - s)
-                for cb in self.cbs:
-                    s = time.time()
-                    cb.trigger_epoch()
-                    tm.add(type(cb).__name__, time.time() - s)
+        with self.graph.as_default(), self.sess.as_default():
+            s = time.time()
+            ckpt = tf.train.get_checkpoint_state(logger.LOG_DIR)
+            if ckpt is None:
+                logger.error(
+                    "Cannot find a checkpoint state. Do you forget to use PeriodicSaver?")
+                return
+            logger.info("Restore checkpoint from {}".format(ckpt.model_checkpoint_path))
+            self.saver.restore(self.sess, ckpt.model_checkpoint_path)
+            tm.add('restore session', time.time() - s)
+            for cb in self.cbs:
+                s = time.time()
+                cb.trigger_epoch()
+                tm.add(type(cb).__name__, time.time() - s)
         self.writer.flush()
         tm.log()

@@ -198,7 +197,7 @@ class Callbacks(Callback):
         self.test.before_train()
 
     def trigger_step(self, inputs, outputs, cost):
-        self.train.trigger_step()
+        self.train.trigger_step(inputs, outputs, cost)
         # test callback don't have trigger_step
 
     def trigger_epoch(self):
utils/concurrency.py

@@ -45,9 +45,11 @@ class EnqueueThread(threading.Thread):
             logger.exception("Exception in EnqueueThread:")
 
 @contextmanager
-def coordinator_context(sess, coord, thread, queue):
+def coordinator_guard(sess, coord, thread, queue):
     """
-    Context manager to make sure queue is closed and thread is joined
+    Context manager to make sure that:
+        queue is closed
+        thread is joined
     """
     thread.start()
     try:
utils/summary.py

@@ -19,7 +19,6 @@ def create_summary(name, v):
     return s
 
 def add_activation_summary(x, name=None):
-    # TODO dedup
     """
     Summary for an activation tensor x.
     If name is None, use x.name