Shashank Suhas / seminar-breakout · Commits

Commit 20887a79 authored Dec 27, 2015 by ppwwyyxx
add prepare
parent 28599036
Showing 2 changed files with 54 additions and 50 deletions
example_mnist.py  (+46 -45)
train.py          (+8 -5)
example_mnist.py
@@ -17,9 +17,6 @@ from utils import *
 from dataflow.dataset import Mnist
 from dataflow import *
 
-IMAGE_SIZE = 28
-LOG_DIR = 'train_log'
-
 def get_model(inputs):
     """
     Args:
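(IMAGE_SIZE and LOG_DIR are not dropped here: the last hunk of this file moves them inside the new get_config(), so importing the module no longer defines training-only globals.)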
@@ -33,11 +30,11 @@ def get_model(inputs):
         cost: scalar variable
     """
     # use this variable in dropout! Tensorpack will automatically set it to 1 at test time
-    keep_prob = tf.placeholder(tf.float32, shape=tuple(), name=DROPOUT_PROB_OP_NAME)
+    keep_prob = tf.get_default_graph().get_tensor_by_name(DROPOUT_PROB_VAR_NAME)
 
     image, label = inputs
-    image = tf.reshape(image, [-1, IMAGE_SIZE, IMAGE_SIZE, 1])
+    image = tf.expand_dims(image, 3)
 
     conv0 = Conv2D('conv0', image, out_channel=32, kernel_shape=5,
                    padding='valid')
     pool0 = MaxPooling('pool0', conv0, 2)
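Two things change in this hunk. get_model no longer creates the dropout placeholder; it looks up, by name, the one that train.prepare() (added in this same commit, see train.py below) now creates. And for an (N, 28, 28) input, tf.expand_dims(image, 3) is equivalent to tf.reshape(image, [-1, IMAGE_SIZE, IMAGE_SIZE, 1]): both append a channel axis, but expand_dims does not need the IMAGE_SIZE constant. A minimal sketch of the create-then-look-up pattern, assuming a placeholder name of 'dropout_prob' (the real DROPOUT_PROB_OP_NAME / DROPOUT_PROB_VAR_NAME constants live in utils and are not shown in this diff):

import tensorflow as tf

DROPOUT_PROB_OP_NAME = 'dropout_prob'                 # hypothetical value, for illustration
DROPOUT_PROB_VAR_NAME = DROPOUT_PROB_OP_NAME + ':0'   # tensor name = op name + output index

with tf.Graph().as_default() as G:
    # what prepare() does: create the placeholder once, under a well-known name
    tf.placeholder(tf.float32, shape=tuple(), name=DROPOUT_PROB_OP_NAME)
    # what get_model() now does: fetch it by name instead of creating it
    keep_prob = G.get_tensor_by_name(DROPOUT_PROB_VAR_NAME)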
@@ -76,48 +73,52 @@ def get_model(inputs):
     return [prob, nr_wrong], tf.add_n(tf.get_collection(COST_VARS_KEY), name='cost')
 
-def main(argv=None):
-    BATCH_SIZE = 128
-    with tf.Graph().as_default():
-        dataset_train = BatchData(Mnist('train'), BATCH_SIZE)
-        dataset_test = BatchData(Mnist('test'), 256, remainder=True)
-
-        sess_config = tf.ConfigProto()
-        sess_config.device_count['GPU'] = 1
-
-        # prepare model
-        image_var = tf.placeholder(tf.float32, shape=(None, IMAGE_SIZE, IMAGE_SIZE), name='input')
-        label_var = tf.placeholder(tf.int32, shape=(None,), name='label')
-        input_vars = [image_var, label_var]
-        output_vars, cost_var = get_model(input_vars)
-        add_histogram_summary('.*/W')   # monitor histogram of all W
-
-        global_step_var = tf.Variable(0, trainable=False, name=GLOBAL_STEP_OP_NAME)
-        lr = tf.train.exponential_decay(
-            learning_rate=1e-4,
-            global_step=global_step_var,
-            decay_steps=dataset_train.size() * 50,
-            decay_rate=0.1, staircase=True, name='learning_rate')
-        tf.scalar_summary('learning_rate', lr)
-
-        config = dict(
-            dataset_train=dataset_train,
-            optimizer=tf.train.AdamOptimizer(lr),
-            callbacks=[
-                ValidationError(dataset_test, prefix='test'),
-                PeriodicSaver(LOG_DIR, period=1),
-                SummaryWriter(LOG_DIR),
-            ],
-            session_config=sess_config,
-            inputs=input_vars,
-            outputs=output_vars,
-            cost=cost_var,
-            max_epoch=100,
-        )
-
-        from train import start_train
-        start_train(config)
+def get_config():
+    IMAGE_SIZE = 28
+    LOG_DIR = 'train_log'
+    BATCH_SIZE = 128
+
+    dataset_train = BatchData(Mnist('train'), BATCH_SIZE)
+    dataset_test = BatchData(Mnist('test'), 256, remainder=True)
+
+    sess_config = tf.ConfigProto()
+    sess_config.device_count['GPU'] = 1
+
+    # prepare model
+    image_var = tf.placeholder(tf.float32, shape=(None, IMAGE_SIZE, IMAGE_SIZE), name='input')
+    label_var = tf.placeholder(tf.int32, shape=(None,), name='label')
+    input_vars = [image_var, label_var]
+    output_vars, cost_var = get_model(input_vars)
+    add_histogram_summary('.*/W')   # monitor histogram of all W
+
+    global_step_var = tf.get_default_graph().get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
+    lr = tf.train.exponential_decay(
+        learning_rate=1e-4,
+        global_step=global_step_var,
+        decay_steps=dataset_train.size() * 50,
+        decay_rate=0.1, staircase=True, name='learning_rate')
+    tf.scalar_summary('learning_rate', lr)
+
+    return dict(
+        dataset_train=dataset_train,
+        optimizer=tf.train.AdamOptimizer(lr),
+        callbacks=[
+            SummaryWriter(LOG_DIR),
+            ValidationError(dataset_test, prefix='test'),
+            PeriodicSaver(LOG_DIR),
+        ],
+        session_config=sess_config,
+        inputs=input_vars,
+        outputs=output_vars,
+        cost=cost_var,
+        max_epoch=100,
+    )
+
+def main(argv=None):
+    with tf.Graph().as_default():
+        from train import prepare, start_train
+        prepare()
+        config = get_config()
+        start_train(config)
 
 if __name__ == '__main__':
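Read as a whole, the example's entry point now reduces to a fixed three-step sequence. A restatement of the new main with the intent of each call annotated (all names are from the diff above; the ordering matters because prepare() must create its nodes in the same default graph that the rest of the setup uses):

def main(argv=None):
    with tf.Graph().as_default():
        from train import prepare, start_train
        prepare()              # create well-known graph nodes: dropout placeholder, global step
        config = get_config()  # build dataflow + model; looks those nodes up by name
        start_train(config)    # generic training loop, driven entirely by the config dict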
train.py
@@ -7,6 +7,13 @@ import tensorflow as tf
 from utils import *
 from itertools import count
 
+def prepare():
+    keep_prob = tf.placeholder(tf.float32, shape=tuple(), name=DROPOUT_PROB_OP_NAME)
+    global_step_var = tf.Variable(0, trainable=False, name=GLOBAL_STEP_OP_NAME)
+
 def start_train(config):
     """
     Start training with the given config
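Note that prepare() returns nothing: its only job is the side effect of registering two nodes, the dropout-probability placeholder and the global-step variable, under fixed names. Since tf.placeholder and tf.Variable attach to whatever graph is the default at call time, callers must invoke prepare() inside the same `with tf.Graph().as_default():` block as the rest of the setup, which is exactly how the new main in example_mnist.py orders the calls.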
@@ -40,11 +47,7 @@ def start_train(config):
     for v in output_vars:
         G.add_to_collection(OUTPUT_VARS_KEY, v)
-    try:
-        global_step_var = G.get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
-    except KeyError:
-        # not created
-        global_step_var = tf.Variable(0, trainable=False, name=GLOBAL_STEP_OP_NAME)
+    global_step_var = G.get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
 
     # add some summary ops to the graph
     averager = tf.train.ExponentialMovingAverage(
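With prepare() guaranteed to have created the global-step variable, the defensive try/except in start_train disappears: the lookup is unconditional and now fails loudly if a caller skips prepare(). A sketch of the resulting contract, assuming GLOBAL_STEP_OP_NAME = 'global_step' (the real constant lives in utils and is not shown in this diff):

import tensorflow as tf

GLOBAL_STEP_OP_NAME = 'global_step'               # hypothetical value, for illustration
GLOBAL_STEP_VAR_NAME = GLOBAL_STEP_OP_NAME + ':0'

with tf.Graph().as_default() as G:
    # prepare() side: create the variable under the well-known name
    tf.Variable(0, trainable=False, name=GLOBAL_STEP_OP_NAME)
    # start_train() side: raises KeyError if prepare() was never called
    global_step_var = G.get_tensor_by_name(GLOBAL_STEP_VAR_NAME)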