Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
86370a76
Commit
86370a76
authored
Dec 27, 2015
by
ppwwyyxx
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
checkpoint
parent
a78d02ac
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
43 additions
and
30 deletions
+43
-30
example_mnist.py
example_mnist.py
+15
-12
train.py
train.py
+25
-13
utils/callback.py
utils/callback.py
+2
-4
utils/logger.py
utils/logger.py
+1
-1
No files found.
example_mnist.py
View file @
86370a76
...
@@ -18,6 +18,7 @@ from models import *
...
@@ -18,6 +18,7 @@ from models import *
from
utils
import
*
from
utils
import
*
from
utils.symbolic_functions
import
*
from
utils.symbolic_functions
import
*
from
utils.summary
import
*
from
utils.summary
import
*
from
utils.concurrency
import
*
from
dataflow.dataset
import
Mnist
from
dataflow.dataset
import
Mnist
from
dataflow
import
*
from
dataflow
import
*
...
@@ -40,12 +41,12 @@ def get_model(inputs):
...
@@ -40,12 +41,12 @@ def get_model(inputs):
image
,
label
=
inputs
image
,
label
=
inputs
image
=
tf
.
expand_dims
(
image
,
3
)
# add a single channel
image
=
tf
.
expand_dims
(
image
,
3
)
# add a single channel
conv0
=
Conv2D
(
'conv0'
,
image
,
out_channel
=
32
,
kernel_shape
=
5
)
#
conv0 = Conv2D('conv0', image, out_channel=32, kernel_shape=5)
pool0
=
MaxPooling
(
'pool0'
,
conv0
,
2
)
#
pool0 = MaxPooling('pool0', conv0, 2)
conv1
=
Conv2D
(
'conv1'
,
pool0
,
out_channel
=
40
,
kernel_shape
=
3
)
#
conv1 = Conv2D('conv1', pool0, out_channel=40, kernel_shape=3)
pool1
=
MaxPooling
(
'pool1'
,
conv1
,
2
)
#
pool1 = MaxPooling('pool1', conv1, 2)
fc0
=
FullyConnected
(
'fc0'
,
pool1
,
1024
)
fc0
=
FullyConnected
(
'fc0'
,
image
,
1024
)
fc0
=
tf
.
nn
.
dropout
(
fc0
,
keep_prob
)
fc0
=
tf
.
nn
.
dropout
(
fc0
,
keep_prob
)
# fc will have activation summary by default. disable this for the output layer
# fc will have activation summary by default. disable this for the output layer
...
@@ -91,12 +92,14 @@ def get_config():
...
@@ -91,12 +92,14 @@ def get_config():
sess_config
.
allow_soft_placement
=
True
sess_config
.
allow_soft_placement
=
True
# prepare model
# prepare model
image_var
=
tf
.
placeholder
(
tf
.
float32
,
shape
=
(
None
,
IMAGE_SIZE
,
IMAGE_SIZE
),
name
=
'input'
)
image_var
=
tf
.
placeholder
(
label_var
=
tf
.
placeholder
(
tf
.
int32
,
shape
=
(
None
,),
name
=
'label'
)
tf
.
float32
,
shape
=
(
None
,
IMAGE_SIZE
,
IMAGE_SIZE
),
name
=
'input'
)
label_var
=
tf
.
placeholder
(
tf
.
int32
,
shape
=
(
None
,),
name
=
'label'
)
input_vars
=
[
image_var
,
label_var
]
input_vars
=
[
image_var
,
label_var
]
output_vars
,
cost_var
=
get_model
(
input_vars
)
input_queue
=
tf
.
RandomShuffleQueue
(
100
,
50
,
[
'float32'
,
'int32'
],
name
=
'queue'
)
add_histogram_summary
(
'.*/W'
)
# monitor histogram of all W
add_histogram_summary
(
'.*/W'
)
# monitor histogram of all W
global_step_var
=
tf
.
get_default_graph
()
.
get_tensor_by_name
(
GLOBAL_STEP_VAR_NAME
)
global_step_var
=
tf
.
get_default_graph
()
.
get_tensor_by_name
(
GLOBAL_STEP_VAR_NAME
)
lr
=
tf
.
train
.
exponential_decay
(
lr
=
tf
.
train
.
exponential_decay
(
learning_rate
=
1e-4
,
learning_rate
=
1e-4
,
...
@@ -110,13 +113,13 @@ def get_config():
...
@@ -110,13 +113,13 @@ def get_config():
optimizer
=
tf
.
train
.
AdamOptimizer
(
lr
),
optimizer
=
tf
.
train
.
AdamOptimizer
(
lr
),
callbacks
=
[
callbacks
=
[
SummaryWriter
(
LOG_DIR
),
SummaryWriter
(
LOG_DIR
),
ValidationError
(
dataset_test
,
prefix
=
'test'
),
#
ValidationError(dataset_test, prefix='test'),
PeriodicSaver
(
LOG_DIR
),
PeriodicSaver
(
LOG_DIR
),
],
],
session_config
=
sess_config
,
session_config
=
sess_config
,
inputs
=
input_vars
,
inputs
=
input_vars
,
outputs
=
output_vars
,
input_queue
=
input_queue
,
cost
=
cost_var
,
get_model_func
=
get_model
,
max_epoch
=
100
,
max_epoch
=
100
,
)
)
...
...
train.py
View file @
86370a76
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
import
tensorflow
as
tf
import
tensorflow
as
tf
from
utils
import
*
from
utils
import
*
from
utils.concurrency
import
*
from
dataflow
import
DataFlow
from
dataflow
import
DataFlow
from
itertools
import
count
from
itertools
import
count
import
argparse
import
argparse
...
@@ -39,11 +40,17 @@ def start_train(config):
...
@@ -39,11 +40,17 @@ def start_train(config):
# a list of input/output variables
# a list of input/output variables
input_vars
=
config
[
'inputs'
]
input_vars
=
config
[
'inputs'
]
output_vars
=
config
[
'outputs
'
]
input_queue
=
config
[
'input_queue
'
]
cost_var
=
config
[
'cost
'
]
get_model_func
=
config
[
'get_model_func
'
]
max_epoch
=
int
(
config
[
'max_epoch'
])
max_epoch
=
int
(
config
[
'max_epoch'
])
enqueue_op
=
input_queue
.
enqueue
(
tuple
(
input_vars
))
model_inputs
=
input_queue
.
dequeue
()
for
qv
,
v
in
zip
(
model_inputs
,
input_vars
):
qv
.
set_shape
(
v
.
get_shape
())
output_vars
,
cost_var
=
get_model_func
(
model_inputs
)
# build graph
# build graph
G
=
tf
.
get_default_graph
()
G
=
tf
.
get_default_graph
()
for
v
in
input_vars
:
for
v
in
input_vars
:
...
@@ -71,21 +78,26 @@ def start_train(config):
...
@@ -71,21 +78,26 @@ def start_train(config):
with
sess
.
as_default
():
with
sess
.
as_default
():
sess
.
run
(
tf
.
initialize_all_variables
())
sess
.
run
(
tf
.
initialize_all_variables
())
callbacks
.
before_train
()
callbacks
.
before_train
()
is_training
=
G
.
get_tensor_by_name
(
IS_TRAINING_VAR_NAME
)
for
epoch
in
xrange
(
1
,
max_epoch
):
for
epoch
in
xrange
(
1
,
max_epoch
):
with
timed_operation
(
'epoch {}'
.
format
(
epoch
)):
coord
=
tf
.
train
.
Coordinator
()
for
dp
in
dataset_train
.
get_data
():
th
=
EnqueueThread
(
sess
,
coord
,
enqueue_op
,
dataset_train
)
feed
=
{
is_training
:
True
}
with
timed_operation
(
'epoch {}'
.
format
(
epoch
)),
\
feed
.
update
(
dict
(
zip
(
input_vars
,
dp
)))
coordinator_context
(
sess
,
coord
,
th
,
input_queue
):
results
=
sess
.
run
(
for
step
in
xrange
(
dataset_train
.
size
()):
[
train_op
,
cost_var
]
+
output_vars
,
feed_dict
=
feed
)
# TODO eval dequeue to get dp
fetches
=
[
train_op
,
cost_var
]
+
output_vars
results
=
sess
.
run
(
fetches
,
feed_dict
=
{
IS_TRAINING_VAR_NAME
:
True
})
cost
=
results
[
1
]
cost
=
results
[
1
]
outputs
=
results
[
2
:]
outputs
=
results
[
2
:]
callbacks
.
trigger_step
(
feed
,
outputs
,
cost
)
print
tf
.
train
.
global_step
(
sess
,
global_step_var
),
cost
# trigger_step
coord
.
request_stop
()
# summary will take a data from the queue
callbacks
.
trigger_epoch
()
callbacks
.
trigger_epoch
()
print
"Finish callback"
sess
.
close
()
sess
.
close
()
def
main
(
get_config_func
):
def
main
(
get_config_func
):
...
...
utils/callback.py
View file @
86370a76
...
@@ -75,15 +75,13 @@ class SummaryWriter(Callback):
...
@@ -75,15 +75,13 @@ class SummaryWriter(Callback):
tf
.
add_to_collection
(
SUMMARY_WRITER_COLLECTION_KEY
,
self
.
writer
)
tf
.
add_to_collection
(
SUMMARY_WRITER_COLLECTION_KEY
,
self
.
writer
)
self
.
summary_op
=
tf
.
merge_all_summaries
()
self
.
summary_op
=
tf
.
merge_all_summaries
()
def
trigger_step
(
self
,
inputs
,
outputs
,
cost
):
self
.
last_dp
=
inputs
def
trigger_epoch
(
self
):
def
trigger_epoch
(
self
):
# check if there is any summary
# check if there is any summary
if
self
.
summary_op
is
None
:
if
self
.
summary_op
is
None
:
return
return
summary_str
=
self
.
summary_op
.
eval
(
self
.
last_dp
)
summary_str
=
self
.
summary_op
.
eval
(
feed_dict
=
{
IS_TRAINING_VAR_NAME
:
True
})
self
.
epoch_num
+=
1
self
.
epoch_num
+=
1
self
.
writer
.
add_summary
(
summary_str
,
self
.
epoch_num
)
self
.
writer
.
add_summary
(
summary_str
,
self
.
epoch_num
)
...
...
utils/logger.py
View file @
86370a76
...
@@ -34,7 +34,7 @@ def getlogger():
...
@@ -34,7 +34,7 @@ def getlogger():
logger
=
getlogger
()
logger
=
getlogger
()
for
func
in
[
'info'
,
'warning'
,
'error'
,
'critical'
,
'warn'
]:
for
func
in
[
'info'
,
'warning'
,
'error'
,
'critical'
,
'warn'
,
'exception'
,
'debug'
]:
locals
()[
func
]
=
getattr
(
logger
,
func
)
locals
()[
func
]
=
getattr
(
logger
,
func
)
def
set_file
(
path
):
def
set_file
(
path
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment