Shashank Suhas / seminar-breakout / Commits

Commit 585f0837, authored Dec 28, 2015 by ppwwyyxx:
is_training as a bool
parent dac78238
Showing 7 changed files, with 58 additions and 56 deletions (+58 -56):
example_mnist.py        +19  -16
models/regularize.py     +4   -1
train.py                +10  -12
utils/__init__.py        +1   -3
utils/callback.py       +20  -21
utils/concurrency.py     +4   -2
utils/summary.py         +0   -1
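The heart of the change: is_training used to live inside the graph as a tensor (looked up by name, branched on with control_flow_ops.cond, or fed through feed_dict); it is now a plain Python bool passed into the model function, so the decision is made once, at graph-construction time. A minimal sketch of the new calling convention (TF 0.x-era API to match this repo; the toy model body is hypothetical):

    import tensorflow as tf

    def get_model(inputs, is_training):
        is_training = bool(is_training)
        # Decided while the graph is built -- no tf.cond, no feed_dict:
        keep_prob = tf.constant(0.5 if is_training else 1.0)
        return [inputs[0] * keep_prob], tf.reduce_mean(inputs[0])  # toy stand-in

    image = tf.placeholder(tf.float32, shape=(None, 28, 28), name='input')
    train_outputs, train_cost = get_model([image], is_training=True)
    test_outputs, test_cost = get_model([image], is_training=False)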
example_mnist.py

@@ -24,22 +24,21 @@ from utils.concurrency import *
 from dataflow.dataset import Mnist
 from dataflow import *

-def get_model(inputs):
-    # TODO is_training as a python variable
+def get_model(inputs, is_training):
     """
     Args:
         inputs: a list of input variable,
             e.g.: [image_var, label_var] with:
                 image_var: bx28x28
                 label_var: bx1 integer
+        is_training: a python bool variable
     Returns:
         (outputs, cost)
         outputs: a list of output variable
-        cost: scalar variable
+        cost: the cost to minimize. scalar variable
     """
-    is_training = tf.get_default_graph().get_tensor_by_name(IS_TRAINING_VAR_NAME)
-    keep_prob = control_flow_ops.cond(is_training,
-                                      lambda: tf.constant(0.5),
-                                      lambda: tf.constant(1.0),
-                                      name='dropout_prob')
+    is_training = bool(is_training)
+    keep_prob = tf.constant(0.5 if is_training else 1.0)

     image, label = inputs
     image = tf.expand_dims(image, 3)    # add a single channel
@@ -77,19 +76,22 @@ def get_model(inputs):
                           name='regularize_loss')
     tf.add_to_collection(COST_VARS_KEY, wd_cost)

-    add_histogram_summary('.*/W')   # monitor histogram of all W
     # this won't work with multigpu
     #return [prob, nr_wrong], tf.add_n(tf.get_collection(COST_VARS_KEY), name='cost')
     return [prob, nr_wrong], tf.add_n([wd_cost, cost], name='cost')

 def get_config():
-    IMAGE_SIZE = 28
     log_dir = os.path.join('train_log', os.path.basename(__file__)[:-3])
     logger.set_logger_dir(log_dir)

+    IMAGE_SIZE = 28
     BATCH_SIZE = 128
     dataset_train = BatchData(Mnist('train'), BATCH_SIZE)
-    #dataset_train = FixedSizeData(dataset_train, 20)
     dataset_test = BatchData(Mnist('test'), 256, remainder=True)
+    dataset_train = FixedSizeData(dataset_train, 20)
+    dataset_test = FixedSizeData(dataset_test, 20)

     sess_config = tf.ConfigProto()
     sess_config.device_count['GPU'] = 1
@@ -98,14 +100,15 @@ def get_config():
     sess_config.allow_soft_placement = True

     # prepare model
-    image_var = tf.placeholder(
-        tf.float32, shape=(None, IMAGE_SIZE, IMAGE_SIZE), name='input')
-    label_var = tf.placeholder(
-        tf.int32, shape=(None,), name='label')
-    input_vars = [image_var, label_var]
-    input_queue = tf.RandomShuffleQueue(100, 50, ['float32', 'int32'], name='queue')
+    input_vars = [
+        tf.placeholder(tf.float32, shape=(None, IMAGE_SIZE, IMAGE_SIZE), name='input'),
+        tf.placeholder(tf.int32, shape=(None,), name='label')
+    ]
+    input_queue = tf.RandomShuffleQueue(100, 50, [x.dtype for x in input_vars], name='queue')
+
+    add_histogram_summary('.*/W')   # monitor histogram of all W
     global_step_var = tf.get_default_graph().get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
     lr = tf.train.exponential_decay(
         learning_rate=1e-4,
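One detail worth keeping from get_config: the queue's dtypes are now derived from the placeholders ([x.dtype for x in input_vars]) rather than hard-coded as strings, so the queue definition cannot drift out of sync with the inputs. Isolated sketch (capacity 100 and min_after_dequeue 50 are the values used above):

    import tensorflow as tf

    input_vars = [
        tf.placeholder(tf.float32, shape=(None, 28, 28), name='input'),
        tf.placeholder(tf.int32, shape=(None,), name='label'),
    ]
    # dtypes follow the placeholders; adding a third input needs no queue edit
    input_queue = tf.RandomShuffleQueue(
        100, 50, [x.dtype for x in input_vars], name='queue')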
models/regularize.py

@@ -10,6 +10,9 @@ from utils import logger
 __all__ = ['regularize_cost']

 def regularize_cost(regex, func):
+    """
+    Apply a regularizer on every trainable variable matching the regex
+    """
     G = tf.get_default_graph()
     params = G.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

@@ -17,7 +20,7 @@ def regularize_cost(regex, func):
     for p in params:
         name = p.name
         if re.search(regex, name):
-            logger.info("Weight decay for {}".format(name))
+            logger.info("Apply regularizer for {}".format(name))
             costs.append(func(p))
     return tf.add_n(costs)
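regularize_cost pairs a variable-name regex with a regularizer function: every trainable variable whose name matches contributes one term, and the terms are summed with tf.add_n. A usage sketch consistent with the regularize_loss call site visible in example_mnist.py above (the 1e-4 weight and the tf.mul spelling are TF 0.x-era assumptions, not confirmed by this diff):

    import tensorflow as tf
    from models.regularize import regularize_cost  # this repo's module

    # L2-penalize every variable named like 'conv0/W', 'fc1/W', ...
    wd_cost = tf.mul(1e-4,
                     regularize_cost('.*/W', tf.nn.l2_loss),
                     name='regularize_loss')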
train.py

@@ -13,9 +13,6 @@ from itertools import count
 import argparse

 def prepare():
-    is_training = tf.constant(True, name=IS_TRAINING_OP_NAME)
-    #keep_prob = tf.placeholder(
-        #tf.float32, shape=tuple(), name=DROPOUT_PROB_OP_NAME)
     global_step_var = tf.Variable(
         0, trainable=False, name=GLOBAL_STEP_OP_NAME)

@@ -40,7 +37,7 @@ def start_train(config):
     sess_config = config.get('session_config', None)
     assert isinstance(sess_config, tf.ConfigProto), sess_config.__class__

-    # a list of input/output variables
+    # input/output variables
     input_vars = config['inputs']
     input_queue = config['input_queue']
     get_model_func = config['get_model_func']

@@ -49,9 +46,10 @@ def start_train(config):
     enqueue_op = input_queue.enqueue(tuple(input_vars))
     model_inputs = input_queue.dequeue()
     # set dequeue shape
     for qv, v in zip(model_inputs, input_vars):
         qv.set_shape(v.get_shape())
-    output_vars, cost_var = get_model_func(model_inputs)
+    output_vars, cost_var = get_model_func(model_inputs, is_training=True)

     # build graph
     G = tf.get_default_graph()

@@ -84,19 +82,19 @@ def start_train(config):
     # a thread that keeps filling the queue
     th = EnqueueThread(sess, coord, enqueue_op, dataset_train)
     with sess.as_default(), \
-            coordinator_context(sess, coord, th, input_queue):
+            coordinator_guard(sess, coord, th, input_queue):
         callbacks.before_train()
         for epoch in xrange(1, max_epoch):
             with timed_operation('epoch {}'.format(epoch)):
                 for step in xrange(dataset_train.size()):
-                    # TODO eval dequeue to get dp
-                    fetches = [train_op, cost_var] + output_vars
-                    feed = {IS_TRAINING_VAR_NAME: True}
-                    results = sess.run(fetches, feed_dict=feed)
+                    fetches = [train_op, cost_var] + output_vars + model_inputs
+                    results = sess.run(fetches)
                     cost = results[1]
-                    outputs = results[2:]
-                    # TODO trigger_step
+                    outputs = results[2:2 + len(output_vars)]
+                    inputs = results[-len(model_inputs):]
+                    callbacks.trigger_step(inputs, outputs, cost)
                 # note that summary_op will take a data from the queue.
                 callbacks.trigger_epoch()
     sess.close()
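The new inner loop fetches everything in one sess.run -- train op, cost, model outputs, and the dequeued inputs themselves -- then slices the flat result list by position, so trigger_step sees exactly the datapoint the graph consumed (no second dequeue, no feed_dict). The slicing logic in isolation, with stand-in values:

    # results mirrors fetches = [train_op, cost_var] + output_vars + model_inputs
    output_vars = ['prob', 'nr_wrong']          # 2 symbolic outputs
    model_inputs = ['image', 'label']           # 2 dequeued input tensors
    results = [None, 0.42, 'p', 'w', 'img', 'lbl']

    cost = results[1]
    outputs = results[2:2 + len(output_vars)]   # ['p', 'w']
    inputs = results[-len(model_inputs):]       # ['img', 'lbl']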
utils/__init__.py

@@ -57,9 +57,7 @@ def create_test_graph():
             ))
         for v in input_vars:
             Gtest.add_to_collection(INPUT_VARS_KEY, v)
-        is_training = tf.constant(False,
-                                  name=IS_TRAINING_OP_NAME)
-        output_vars, cost = forward_func(input_vars)
+        output_vars, cost = forward_func(input_vars, is_training=False)
         for v in output_vars:
             Gtest.add_to_collection(OUTPUT_VARS_KEY, v)
         yield Gtest
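create_test_graph is the mirror image of training: it rebuilds the model in a separate graph with is_training=False and registers inputs/outputs in collections so test-time callbacks can find them by key. A reduced sketch of that shape (the key names and the input_specs parameter are stand-ins; only the forward_func(..., is_training=False) call is from this diff):

    from contextlib import contextmanager
    import tensorflow as tf

    INPUT_VARS_KEY = 'INPUT_VARIABLES'    # stand-in key name
    OUTPUT_VARS_KEY = 'OUTPUT_VARIABLES'  # stand-in key name

    @contextmanager
    def create_test_graph(forward_func, input_specs):
        with tf.Graph().as_default() as Gtest:
            input_vars = [tf.placeholder(dtype, shape, name=name)
                          for dtype, shape, name in input_specs]
            for v in input_vars:
                Gtest.add_to_collection(INPUT_VARS_KEY, v)
            # the model function decides what is_training=False means
            # (e.g. dropout keep_prob becomes 1.0)
            output_vars, cost = forward_func(input_vars, is_training=False)
            for v in output_vars:
                Gtest.add_to_collection(OUTPUT_VARS_KEY, v)
            yield Gtest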
utils/callback.py

@@ -34,9 +34,9 @@ class Callback(object):
         """
         Callback to be triggered after every step (every backpropagation)
         Args:
-            inputs: the input dict fed into the graph
-            outputs: list of output values after running this dp
-            cost: the cost value after running this dp
+            inputs: the list of input values
+            outputs: list of output values after running this inputs
+            cost: the cost value after running this input
         """

     def trigger_epoch(self):

@@ -85,9 +85,7 @@ class SummaryWriter(Callback):
         # check if there is any summary
         if self.summary_op is None:
             return
-        feed = {IS_TRAINING_VAR_NAME: True}
-        summary_str = self.summary_op.eval(feed_dict=feed)
+        summary_str = self.summary_op.eval()
         self.epoch_num += 1
         self.writer.add_summary(summary_str, self.epoch_num)

@@ -102,9 +100,7 @@ class CallbackTimeLogger(object):
         self.times.append((name, time))

     def log(self):
-        """
-        log the time of some heavy callbacks
-        """
+        """ log the time of some heavy callbacks """
         if self.tot < 3:
             return
         msgs = []

@@ -162,12 +158,15 @@ class TestCallbacks(Callback):
     def trigger_epoch(self):
         tm = CallbackTimeLogger()
-        with self.graph.as_default():
-            with self.sess.as_default():
+        with self.graph.as_default(), self.sess.as_default():
             s = time.time()
             ckpt = tf.train.get_checkpoint_state(logger.LOG_DIR)
             if ckpt is None:
-                from IPython import embed; embed()
+                logger.error(
+                    "Cannot find a checkpoint state. Do you forget to use PeriodicSaver?")
+                return
+            logger.info(
+                "Restore checkpoint from {}".format(ckpt.model_checkpoint_path))
             self.saver.restore(self.sess, ckpt.model_checkpoint_path)
             tm.add('restore session', time.time() - s)
             for cb in self.cbs:

@@ -198,7 +197,7 @@ class Callbacks(Callback):
         self.test.before_train()

     def trigger_step(self, inputs, outputs, cost):
-        self.train.trigger_step()
+        self.train.trigger_step(inputs, outputs, cost)
         # test callback don't have trigger_step

     def trigger_epoch(self):
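The Callbacks container now forwards the per-step data only to training-type callbacks; test-time callbacks stay epoch-granular, as the comment in the diff notes. A reduced sketch of that dispatch (class layout abbreviated from this file):

    class Callbacks(object):
        def __init__(self, train_cbs, test_cbs):
            self.train = train_cbs
            self.test = test_cbs

        def trigger_step(self, inputs, outputs, cost):
            # per-step data goes to training callbacks only;
            # test callbacks don't have trigger_step
            self.train.trigger_step(inputs, outputs, cost)

        def trigger_epoch(self):
            self.train.trigger_epoch()
            self.test.trigger_epoch()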
utils/concurrency.py

@@ -45,9 +45,11 @@ class EnqueueThread(threading.Thread):
             logger.exception("Exception in EnqueueThread:")

 @contextmanager
-def coordinator_context(sess, coord, thread, queue):
+def coordinator_guard(sess, coord, thread, queue):
     """
-    Context manager to make sure queue is closed and thread is joined
+    Context manager to make sure that:
+        queue is closed
+        thread is joined
     """
     thread.start()
     try:
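coordinator_guard (renamed from coordinator_context) exists so the training loop can exit by any path -- normal completion, exception, KeyboardInterrupt -- and still close the queue and join the enqueue thread. A generic sketch of the pattern (only the signature and docstring appear in this hunk; the body below is an assumption about the standard tf.train.Coordinator shutdown sequence):

    from contextlib import contextmanager

    @contextmanager
    def coordinator_guard(sess, coord, thread, queue):
        """Make sure the queue is closed and the thread is joined."""
        thread.start()
        try:
            yield
        finally:
            coord.request_stop()  # signal the enqueue thread to stop
            # unblock any pending enqueue so the thread can exit
            sess.run(queue.close(cancel_pending_enqueues=True))
            coord.join([thread])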
utils/summary.py

@@ -19,7 +19,6 @@ def create_summary(name, v):
     return s

 def add_activation_summary(x, name=None):
-    # TODO dedup
     """
     Summary for an activation tensor x.
     If name is None, use x.name