Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
dd1ac6b0
Commit
dd1ac6b0
authored
Dec 28, 2015
by
ppwwyyxx
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
step_per_epoch & compatible feeding
parent
745ad4f0
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
58 additions
and
11 deletions
+58
-11
example_mnist.py
example_mnist.py
+17
-3
train.py
train.py
+10
-3
utils/concurrency.py
utils/concurrency.py
+12
-5
utils/utils.py
utils/utils.py
+19
-0
No files found.
example_mnist.py
View file @
dd1ac6b0
...
@@ -24,6 +24,10 @@ from utils.concurrency import *
...
@@ -24,6 +24,10 @@ from utils.concurrency import *
from
dataflow.dataset
import
Mnist
from
dataflow.dataset
import
Mnist
from
dataflow
import
*
from
dataflow
import
*
BATCH_SIZE
=
128
MIN_AFTER_DEQUEUE
=
500
CAPACITY
=
MIN_AFTER_DEQUEUE
+
3
*
BATCH_SIZE
def
get_model
(
inputs
,
is_training
):
def
get_model
(
inputs
,
is_training
):
"""
"""
Args:
Args:
...
@@ -43,6 +47,15 @@ def get_model(inputs, is_training):
...
@@ -43,6 +47,15 @@ def get_model(inputs, is_training):
image
,
label
=
inputs
image
,
label
=
inputs
image
=
tf
.
expand_dims
(
image
,
3
)
# add a single channel
image
=
tf
.
expand_dims
(
image
,
3
)
# add a single channel
if
is_training
:
# augmentations
image
,
label
=
tf
.
train
.
slice_input_producer
(
[
image
,
label
],
name
=
'slice_queue'
)
image
=
tf
.
image
.
random_brightness
(
image
,
0.1
)
image
,
label
=
tf
.
train
.
shuffle_batch
(
[
image
,
label
],
BATCH_SIZE
,
CAPACITY
,
MIN_AFTER_DEQUEUE
,
num_threads
=
2
,
enqueue_many
=
False
)
conv0
=
Conv2D
(
'conv0'
,
image
,
out_channel
=
32
,
kernel_shape
=
5
)
conv0
=
Conv2D
(
'conv0'
,
image
,
out_channel
=
32
,
kernel_shape
=
5
)
pool0
=
MaxPooling
(
'pool0'
,
conv0
,
2
)
pool0
=
MaxPooling
(
'pool0'
,
conv0
,
2
)
conv1
=
Conv2D
(
'conv1'
,
pool0
,
out_channel
=
40
,
kernel_shape
=
3
)
conv1
=
Conv2D
(
'conv1'
,
pool0
,
out_channel
=
40
,
kernel_shape
=
3
)
...
@@ -86,11 +99,11 @@ def get_config():
...
@@ -86,11 +99,11 @@ def get_config():
logger
.
set_logger_dir
(
log_dir
)
logger
.
set_logger_dir
(
log_dir
)
IMAGE_SIZE
=
28
IMAGE_SIZE
=
28
BATCH_SIZE
=
128
dataset_train
=
BatchData
(
Mnist
(
'train'
),
BATCH_SIZE
)
dataset_train
=
Mnist
(
'train'
)
dataset_test
=
BatchData
(
Mnist
(
'test'
),
256
,
remainder
=
True
)
dataset_test
=
BatchData
(
Mnist
(
'test'
),
256
,
remainder
=
True
)
#dataset_train = FixedSizeData(dataset_train, 20)
step_per_epoch
=
dataset_train
.
size
()
/
BATCH_SIZE
#step_per_epoch = 20
#dataset_test = FixedSizeData(dataset_test, 20)
#dataset_test = FixedSizeData(dataset_test, 20)
sess_config
=
tf
.
ConfigProto
()
sess_config
=
tf
.
ConfigProto
()
...
@@ -129,6 +142,7 @@ def get_config():
...
@@ -129,6 +142,7 @@ def get_config():
inputs
=
input_vars
,
inputs
=
input_vars
,
input_queue
=
input_queue
,
input_queue
=
input_queue
,
get_model_func
=
get_model
,
get_model_func
=
get_model
,
step_per_epoch
=
step_per_epoch
,
max_epoch
=
100
,
max_epoch
=
100
,
)
)
...
...
train.py
View file @
dd1ac6b0
...
@@ -41,7 +41,9 @@ def start_train(config):
...
@@ -41,7 +41,9 @@ def start_train(config):
input_queue
=
config
[
'input_queue'
]
input_queue
=
config
[
'input_queue'
]
get_model_func
=
config
[
'get_model_func'
]
get_model_func
=
config
[
'get_model_func'
]
step_per_epoch
=
int
(
config
[
'step_per_epoch'
])
max_epoch
=
int
(
config
[
'max_epoch'
])
max_epoch
=
int
(
config
[
'max_epoch'
])
assert
step_per_epoch
>
0
and
max_epoch
>
0
enqueue_op
=
input_queue
.
enqueue
(
tuple
(
input_vars
))
enqueue_op
=
input_queue
.
enqueue
(
tuple
(
input_vars
))
model_inputs
=
input_queue
.
dequeue
()
model_inputs
=
input_queue
.
dequeue
()
...
@@ -79,14 +81,19 @@ def start_train(config):
...
@@ -79,14 +81,19 @@ def start_train(config):
# start training:
# start training:
coord
=
tf
.
train
.
Coordinator
()
coord
=
tf
.
train
.
Coordinator
()
# a thread that keeps filling the queue
# a thread that keeps filling the queue
th
=
EnqueueThread
(
sess
,
coord
,
enqueue_op
,
dataset_train
)
input_th
=
EnqueueThread
(
sess
,
coord
,
enqueue_op
,
dataset_train
)
model_th
=
tf
.
train
.
start_queue_runners
(
sess
=
sess
,
coord
=
coord
,
daemon
=
True
,
start
=
False
)
with
sess
.
as_default
(),
\
with
sess
.
as_default
(),
\
coordinator_guard
(
coordinator_guard
(
sess
,
coord
,
th
,
input_queue
):
sess
,
coord
,
[
input_th
]
+
model_
th
,
input_queue
):
callbacks
.
before_train
()
callbacks
.
before_train
()
for
epoch
in
xrange
(
1
,
max_epoch
):
for
epoch
in
xrange
(
1
,
max_epoch
):
with
timed_operation
(
'epoch {}'
.
format
(
epoch
)):
with
timed_operation
(
'epoch {}'
.
format
(
epoch
)):
for
step
in
xrange
(
dataset_train
.
size
()):
for
step
in
xrange
(
step_per_epoch
):
if
coord
.
should_stop
():
return
fetches
=
[
train_op
,
cost_var
]
+
output_vars
+
model_inputs
fetches
=
[
train_op
,
cost_var
]
+
output_vars
+
model_inputs
results
=
sess
.
run
(
fetches
)
results
=
sess
.
run
(
fetches
)
cost
=
results
[
1
]
cost
=
results
[
1
]
...
...
utils/concurrency.py
View file @
dd1ac6b0
...
@@ -5,8 +5,10 @@
...
@@ -5,8 +5,10 @@
import
threading
import
threading
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
itertools
import
izip
import
tensorflow
as
tf
import
tensorflow
as
tf
from
.utils
import
expand_dim_if_necessary
from
.naming
import
*
from
.naming
import
*
import
logger
import
logger
...
@@ -37,21 +39,26 @@ class EnqueueThread(threading.Thread):
...
@@ -37,21 +39,26 @@ class EnqueueThread(threading.Thread):
for
dp
in
self
.
dataflow
.
get_data
():
for
dp
in
self
.
dataflow
.
get_data
():
if
self
.
coord
.
should_stop
():
if
self
.
coord
.
should_stop
():
return
return
feed
=
dict
(
zip
(
self
.
input_vars
,
dp
))
feed
=
{}
for
var
,
data
in
izip
(
self
.
input_vars
,
dp
):
data
=
expand_dim_if_necessary
(
var
,
data
)
feed
[
var
]
=
data
self
.
sess
.
run
([
self
.
op
],
feed_dict
=
feed
)
self
.
sess
.
run
([
self
.
op
],
feed_dict
=
feed
)
except
tf
.
errors
.
CancelledError
as
e
:
except
tf
.
errors
.
CancelledError
as
e
:
pass
pass
except
Exception
:
except
Exception
:
logger
.
exception
(
"Exception in EnqueueThread:"
)
logger
.
exception
(
"Exception in EnqueueThread:"
)
self
.
coord
.
request_stop
()
@
contextmanager
@
contextmanager
def
coordinator_guard
(
sess
,
coord
,
thread
,
queue
):
def
coordinator_guard
(
sess
,
coord
,
thread
s
,
queue
):
"""
"""
Context manager to make sure that:
Context manager to make sure that:
queue is closed
queue is closed
thread
is
joined
thread
s are
joined
"""
"""
thread
.
start
()
for
th
in
threads
:
th
.
start
()
try
:
try
:
yield
yield
except
(
KeyboardInterrupt
,
Exception
)
as
e
:
except
(
KeyboardInterrupt
,
Exception
)
as
e
:
...
@@ -60,4 +67,4 @@ def coordinator_guard(sess, coord, thread, queue):
...
@@ -60,4 +67,4 @@ def coordinator_guard(sess, coord, thread, queue):
coord
.
request_stop
()
coord
.
request_stop
()
sess
.
run
(
sess
.
run
(
queue
.
close
(
cancel_pending_enqueues
=
True
))
queue
.
close
(
cancel_pending_enqueues
=
True
))
coord
.
join
(
[
thread
]
)
coord
.
join
(
threads
)
utils/utils.py
0 → 100644
View file @
dd1ac6b0
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: utils.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
def
expand_dim_if_necessary
(
var
,
dp
):
"""
Args:
var: a tensor
dp: a numpy array
Return a reshaped version of dp, if that makes it match the valid dimension of var
"""
shape
=
var
.
get_shape
()
.
as_list
()
valid_shape
=
[
k
for
k
in
shape
if
k
]
if
dp
.
shape
==
tuple
(
valid_shape
):
new_shape
=
[
k
if
k
else
1
for
k
in
shape
]
dp
=
dp
.
reshape
(
new_shape
)
return
dp
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment