Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
a78d02ac
Commit
a78d02ac
authored
Dec 27, 2015
by
ppwwyyxx
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
use is_training
parent
9bf42054
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
30 additions
and
22 deletions
+30
-22
example_mnist.py
example_mnist.py
+6
-2
train.py
train.py
+7
-14
utils/__init__.py
utils/__init__.py
+1
-2
utils/naming.py
utils/naming.py
+2
-2
utils/summary.py
utils/summary.py
+12
-0
utils/validation_callback.py
utils/validation_callback.py
+2
-2
No files found.
example_mnist.py
View file @
a78d02ac
...
...
@@ -9,6 +9,8 @@ import os
sys
.
path
.
insert
(
0
,
os
.
path
.
expanduser
(
'~/.local/lib/python2.7/site-packages'
))
import
tensorflow
as
tf
from
tensorflow.python.ops
import
control_flow_ops
import
numpy
as
np
import
os
...
...
@@ -31,8 +33,9 @@ def get_model(inputs):
outputs: a list of output variable
cost: scalar variable
"""
# use this variable in dropout! Tensorpack will automatically set it to 1 at test time
keep_prob
=
tf
.
get_default_graph
()
.
get_tensor_by_name
(
DROPOUT_PROB_VAR_NAME
)
is_training
=
tf
.
get_default_graph
()
.
get_tensor_by_name
(
IS_TRAINING_VAR_NAME
)
keep_prob
=
control_flow_ops
.
cond
(
is_training
,
lambda
:
tf
.
constant
(
0.5
),
lambda
:
tf
.
constant
(
1.0
),
name
=
'dropout_prob'
)
image
,
label
=
inputs
image
=
tf
.
expand_dims
(
image
,
3
)
# add a single channel
...
...
@@ -83,6 +86,7 @@ def get_config():
sess_config
=
tf
.
ConfigProto
()
sess_config
.
device_count
[
'GPU'
]
=
1
sess_config
.
gpu_options
.
per_process_gpu_memory_fraction
=
0.5
sess_config
.
gpu_options
.
allocator_type
=
'BFC'
sess_config
.
allow_soft_placement
=
True
...
...
train.py
View file @
a78d02ac
...
...
@@ -10,8 +10,9 @@ from itertools import count
import
argparse
def
prepare
():
keep_prob
=
tf
.
placeholder
(
tf
.
float32
,
shape
=
tuple
(),
name
=
DROPOUT_PROB_OP_NAME
)
is_training
=
tf
.
placeholder
(
tf
.
bool
,
shape
=
(),
name
=
IS_TRAINING_OP_NAME
)
#keep_prob = tf.placeholder(
#tf.float32, shape=tuple(), name=DROPOUT_PROB_OP_NAME)
global_step_var
=
tf
.
Variable
(
0
,
trainable
=
False
,
name
=
GLOBAL_STEP_OP_NAME
)
...
...
@@ -49,19 +50,11 @@ def start_train(config):
G
.
add_to_collection
(
INPUT_VARS_KEY
,
v
)
for
v
in
output_vars
:
G
.
add_to_collection
(
OUTPUT_VARS_KEY
,
v
)
summary
_model
()
describe
_model
()
global_step_var
=
G
.
get_tensor_by_name
(
GLOBAL_STEP_VAR_NAME
)
# add some summary ops to the graph
averager
=
tf
.
train
.
ExponentialMovingAverage
(
0.9
,
num_updates
=
global_step_var
,
name
=
'avg'
)
vars_to_summary
=
[
cost_var
]
+
\
tf
.
get_collection
(
SUMMARY_VARS_KEY
)
+
\
tf
.
get_collection
(
COST_VARS_KEY
)
avg_maintain_op
=
averager
.
apply
(
vars_to_summary
)
for
c
in
vars_to_summary
:
tf
.
scalar_summary
(
c
.
op
.
name
,
averager
.
average
(
c
))
avg_maintain_op
=
summary_moving_average
(
cost_var
)
# maintain average in each step
with
tf
.
control_dependencies
([
avg_maintain_op
]):
...
...
@@ -79,11 +72,11 @@ def start_train(config):
sess
.
run
(
tf
.
initialize_all_variables
())
callbacks
.
before_train
()
keep_prob_var
=
G
.
get_tensor_by_name
(
DROPOUT_PROB
_VAR_NAME
)
is_training
=
G
.
get_tensor_by_name
(
IS_TRAINING
_VAR_NAME
)
for
epoch
in
xrange
(
1
,
max_epoch
):
with
timed_operation
(
'epoch {}'
.
format
(
epoch
)):
for
dp
in
dataset_train
.
get_data
():
feed
=
{
keep_prob_var
:
0.5
}
feed
=
{
is_training
:
True
}
feed
.
update
(
dict
(
zip
(
input_vars
,
dp
)))
results
=
sess
.
run
(
...
...
utils/__init__.py
View file @
a78d02ac
...
...
@@ -31,8 +31,7 @@ def timed_operation(msg, log_start=False):
logger
.
info
(
'finished {}, time={:.2f}sec.'
.
format
(
msg
,
time
.
time
()
-
start
))
def
summary_model
():
def
describe_model
():
train_vars
=
tf
.
get_collection
(
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
)
msg
=
[
""
]
total
=
0
...
...
utils/naming.py
View file @
a78d02ac
...
...
@@ -3,8 +3,8 @@
# File: naming.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
DROPOUT_PROB_OP_NAME
=
'dropout_prob
'
DROPOUT_PROB_VAR_NAME
=
'dropout_prob
:0'
IS_TRAINING_OP_NAME
=
'is_training
'
IS_TRAINING_VAR_NAME
=
'is_training
:0'
GLOBAL_STEP_OP_NAME
=
'global_step'
GLOBAL_STEP_VAR_NAME
=
'global_step:0'
...
...
utils/summary.py
View file @
a78d02ac
...
...
@@ -4,6 +4,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
tensorflow
as
tf
from
.naming
import
*
def
create_summary
(
name
,
v
):
"""
...
...
@@ -42,3 +43,14 @@ def add_histogram_summary(regex):
if
re
.
search
(
regex
,
name
):
tf
.
histogram_summary
(
name
,
p
)
def
summary_moving_average
(
cost_var
):
global_step_var
=
tf
.
get_default_graph
()
.
get_tensor_by_name
(
GLOBAL_STEP_VAR_NAME
)
averager
=
tf
.
train
.
ExponentialMovingAverage
(
0.9
,
num_updates
=
global_step_var
,
name
=
'avg'
)
vars_to_summary
=
[
cost_var
]
+
\
tf
.
get_collection
(
SUMMARY_VARS_KEY
)
+
\
tf
.
get_collection
(
COST_VARS_KEY
)
avg_maintain_op
=
averager
.
apply
(
vars_to_summary
)
for
c
in
vars_to_summary
:
tf
.
scalar_summary
(
c
.
op
.
name
,
averager
.
average
(
c
))
return
avg_maintain_op
utils/validation_callback.py
View file @
a78d02ac
...
...
@@ -33,7 +33,7 @@ class ValidationError(PeriodicCallback):
def
_before_train
(
self
):
self
.
input_vars
=
tf
.
get_collection
(
INPUT_VARS_KEY
)
self
.
dropout_var
=
self
.
get_tensor
(
DROPOUT_PROB
_VAR_NAME
)
self
.
is_training_var
=
self
.
get_tensor
(
IS_TRAINING
_VAR_NAME
)
self
.
wrong_var
=
self
.
get_tensor
(
self
.
wrong_var_name
)
self
.
cost_var
=
self
.
get_tensor
(
self
.
cost_var_name
)
self
.
writer
=
tf
.
get_collection
(
SUMMARY_WRITER_COLLECTION_KEY
)[
0
]
...
...
@@ -43,7 +43,7 @@ class ValidationError(PeriodicCallback):
err_stat
=
Accuracy
()
cost_sum
=
0
for
dp
in
self
.
ds
.
get_data
():
feed
=
{
self
.
dropout_var
:
1.0
}
feed
=
{
self
.
is_training_var
:
False
}
feed
.
update
(
dict
(
zip
(
self
.
input_vars
,
dp
)))
batch_size
=
dp
[
0
]
.
shape
[
0
]
# assume batched input
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment