Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
bcbbc645
Commit
bcbbc645
authored
Dec 29, 2015
by
ppwwyyxx
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
working on alexnet
parent
87f7e7cb
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
147 additions
and
32 deletions
+147
-32
dataflow/common.py
dataflow/common.py
+1
-0
dataflow/dataset/__init__.py
dataflow/dataset/__init__.py
+0
-1
dataflow/dataset/cifar10.py
dataflow/dataset/cifar10.py
+1
-0
infer.py
infer.py
+74
-0
models/_common.py
models/_common.py
+1
-0
models/conv2d.py
models/conv2d.py
+15
-3
train.py
train.py
+11
-4
utils/modelutils.py
utils/modelutils.py
+44
-0
utils/summary.py
utils/summary.py
+0
-24
No files found.
dataflow/common.py
View file @
bcbbc645
...
@@ -17,6 +17,7 @@ class BatchData(DataFlow):
...
@@ -17,6 +17,7 @@ class BatchData(DataFlow):
if set, might return a data point of a different shape
if set, might return a data point of a different shape
"""
"""
self
.
ds
=
ds
self
.
ds
=
ds
assert
batch_size
<=
ds
.
size
()
self
.
batch_size
=
batch_size
self
.
batch_size
=
batch_size
self
.
remainder
=
remainder
self
.
remainder
=
remainder
...
...
dataflow/dataset/__init__.py
View file @
bcbbc645
...
@@ -8,7 +8,6 @@ import os
...
@@ -8,7 +8,6 @@ import os
import
os.path
import
os.path
def
global_import
(
name
):
def
global_import
(
name
):
print
name
p
=
__import__
(
name
,
globals
(),
locals
())
p
=
__import__
(
name
,
globals
(),
locals
())
lst
=
p
.
__all__
if
'__all__'
in
dir
(
p
)
else
dir
(
p
)
lst
=
p
.
__all__
if
'__all__'
in
dir
(
p
)
else
dir
(
p
)
for
k
in
lst
:
for
k
in
lst
:
...
...
dataflow/dataset/cifar10.py
View file @
bcbbc645
...
@@ -63,6 +63,7 @@ class Cifar10(DataFlow):
...
@@ -63,6 +63,7 @@ class Cifar10(DataFlow):
assert
train_or_test
in
[
'train'
,
'test'
]
assert
train_or_test
in
[
'train'
,
'test'
]
if
dir
is
None
:
if
dir
is
None
:
dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'cifar10_data'
)
dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'cifar10_data'
)
maybe_download_and_extract
(
dir
)
if
train_or_test
==
'train'
:
if
train_or_test
==
'train'
:
self
.
fs
=
[
os
.
path
.
join
(
self
.
fs
=
[
os
.
path
.
join
(
...
...
infer.py
0 → 100644
View file @
bcbbc645
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: infer.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
tensorflow
as
tf
from
itertools
import
count
import
argparse
import
numpy
as
np
from
utils
import
*
from
utils.modelutils
import
describe_model
,
restore_params
from
utils
import
logger
from
dataflow
import
DataFlow
def
start_infer
(
config
):
"""
Args:
config: a tensorpack config dictionary
"""
dataset_test
=
config
[
'dataset_test'
]
assert
isinstance
(
dataset_test
,
DataFlow
),
dataset_test
.
__class__
# a tf.ConfigProto instance
sess_config
=
config
.
get
(
'session_config'
,
None
)
assert
isinstance
(
sess_config
,
tf
.
ConfigProto
),
sess_config
.
__class__
# TODO callback should have trigger_step and trigger_end?
callback
=
config
[
'callback'
]
# restore saved params
params
=
config
.
get
(
'restore_params'
,
{})
# input/output variables
input_vars
=
config
[
'inputs'
]
get_model_func
=
config
[
'get_model_func'
]
output_vars
,
cost_var
=
get_model_func
(
input_vars
,
is_training
=
False
)
# build graph
G
=
tf
.
get_default_graph
()
G
.
add_to_collection
(
FORWARD_FUNC_KEY
,
get_model_func
)
for
v
in
input_vars
:
G
.
add_to_collection
(
INPUT_VARS_KEY
,
v
)
for
v
in
output_vars
:
G
.
add_to_collection
(
OUTPUT_VARS_KEY
,
v
)
describe_model
()
sess
=
tf
.
Session
(
config
=
sess_config
)
sess
.
run
(
tf
.
initialize_all_variables
())
restore_params
(
sess
,
params
)
with
sess
.
as_default
():
with
timed_operation
(
'running one batch'
):
for
dp
in
dataset_test
.
get_data
():
feed
=
dict
(
zip
(
input_vars
,
dp
))
fetches
=
[
cost_var
]
+
output_vars
results
=
sess
.
run
(
fetches
,
feed_dict
=
feed
)
cost
=
results
[
0
]
outputs
=
results
[
1
:]
prob
=
outputs
[
0
]
callback
(
dp
,
outputs
,
cost
)
def
main
(
get_config_func
):
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--gpu'
,
help
=
'comma separated list of GPU(s) to use.'
)
# nargs='*' in multi mode
args
=
parser
.
parse_args
()
if
args
.
gpu
:
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
args
.
gpu
with
tf
.
Graph
()
.
as_default
():
config
=
get_config_func
()
start_infer
(
config
)
models/_common.py
View file @
bcbbc645
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
tensorflow
as
tf
import
tensorflow
as
tf
from
utils.modelutils
import
*
from
utils.summary
import
*
from
utils.summary
import
*
from
utils
import
logger
from
utils
import
logger
...
...
models/conv2d.py
View file @
bcbbc645
...
@@ -12,18 +12,22 @@ __all__ = ['Conv2D']
...
@@ -12,18 +12,22 @@ __all__ = ['Conv2D']
@
layer_register
(
summary_activation
=
True
)
@
layer_register
(
summary_activation
=
True
)
def
Conv2D
(
x
,
out_channel
,
kernel_shape
,
def
Conv2D
(
x
,
out_channel
,
kernel_shape
,
padding
=
'VALID'
,
stride
=
1
,
padding
=
'VALID'
,
stride
=
1
,
W_init
=
None
,
b_init
=
None
,
nl
=
tf
.
nn
.
relu
):
W_init
=
None
,
b_init
=
None
,
nl
=
tf
.
nn
.
relu
,
split
=
1
):
"""
"""
kernel_shape: (h, w) or a int
kernel_shape: (h, w) or a int
stride: (h, w) or a int
stride: (h, w) or a int
padding: 'valid' or 'same'
padding: 'valid' or 'same'
split: split channels. used in alexnet
"""
"""
in_shape
=
x
.
get_shape
()
.
as_list
()
in_shape
=
x
.
get_shape
()
.
as_list
()
in_channel
=
in_shape
[
-
1
]
in_channel
=
in_shape
[
-
1
]
assert
in_channel
%
split
==
0
assert
out_channel
%
split
==
0
kernel_shape
=
shape2d
(
kernel_shape
)
kernel_shape
=
shape2d
(
kernel_shape
)
padding
=
padding
.
upper
()
padding
=
padding
.
upper
()
filter_shape
=
kernel_shape
+
[
in_channel
,
out_channel
]
filter_shape
=
kernel_shape
+
[
in_channel
/
split
,
out_channel
]
stride
=
shape4d
(
stride
)
stride
=
shape4d
(
stride
)
if
W_init
is
None
:
if
W_init
is
None
:
...
@@ -34,6 +38,14 @@ def Conv2D(x, out_channel, kernel_shape,
...
@@ -34,6 +38,14 @@ def Conv2D(x, out_channel, kernel_shape,
W
=
tf
.
get_variable
(
'W'
,
filter_shape
,
initializer
=
W_init
)
# TODO collections
W
=
tf
.
get_variable
(
'W'
,
filter_shape
,
initializer
=
W_init
)
# TODO collections
b
=
tf
.
get_variable
(
'b'
,
[
out_channel
],
initializer
=
b_init
)
b
=
tf
.
get_variable
(
'b'
,
[
out_channel
],
initializer
=
b_init
)
conv
=
tf
.
nn
.
conv2d
(
x
,
W
,
stride
,
padding
)
if
split
==
1
:
conv
=
tf
.
nn
.
conv2d
(
x
,
W
,
stride
,
padding
)
else
:
inputs
=
tf
.
split
(
3
,
split
,
x
)
kernels
=
tf
.
split
(
3
,
split
,
W
)
outputs
=
[
tf
.
nn
.
conv2d
(
i
,
k
,
stride
,
padding
)
for
i
,
k
in
zip
(
inputs
,
kernels
)]
conv
=
tf
.
concat
(
3
,
outputs
)
return
nl
(
tf
.
nn
.
bias_add
(
conv
,
b
))
return
nl
(
tf
.
nn
.
bias_add
(
conv
,
b
))
train.py
View file @
bcbbc645
...
@@ -4,18 +4,20 @@
...
@@ -4,18 +4,20 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
tensorflow
as
tf
import
tensorflow
as
tf
from
itertools
import
count
import
argparse
from
utils
import
*
from
utils
import
*
from
utils.concurrency
import
EnqueueThread
,
coordinator_guard
from
utils.concurrency
import
EnqueueThread
,
coordinator_guard
from
utils.summary
import
summary_moving_average
,
describe_model
from
utils.summary
import
summary_moving_average
from
utils.modelutils
import
restore_params
,
describe_model
from
utils
import
logger
from
dataflow
import
DataFlow
from
dataflow
import
DataFlow
from
itertools
import
count
import
argparse
def
prepare
():
def
prepare
():
global_step_var
=
tf
.
Variable
(
global_step_var
=
tf
.
Variable
(
0
,
trainable
=
False
,
name
=
GLOBAL_STEP_OP_NAME
)
0
,
trainable
=
False
,
name
=
GLOBAL_STEP_OP_NAME
)
def
start_train
(
config
):
def
start_train
(
config
):
"""
"""
Start training with the given config
Start training with the given config
...
@@ -36,6 +38,9 @@ def start_train(config):
...
@@ -36,6 +38,9 @@ def start_train(config):
sess_config
=
config
.
get
(
'session_config'
,
None
)
sess_config
=
config
.
get
(
'session_config'
,
None
)
assert
isinstance
(
sess_config
,
tf
.
ConfigProto
),
sess_config
.
__class__
assert
isinstance
(
sess_config
,
tf
.
ConfigProto
),
sess_config
.
__class__
# restore saved params
params
=
config
.
get
(
'restore_params'
,
{})
# input/output variables
# input/output variables
input_vars
=
config
[
'inputs'
]
input_vars
=
config
[
'inputs'
]
input_queue
=
config
[
'input_queue'
]
input_queue
=
config
[
'input_queue'
]
...
@@ -78,6 +83,8 @@ def start_train(config):
...
@@ -78,6 +83,8 @@ def start_train(config):
sess
=
tf
.
Session
(
config
=
sess_config
)
sess
=
tf
.
Session
(
config
=
sess_config
)
sess
.
run
(
tf
.
initialize_all_variables
())
sess
.
run
(
tf
.
initialize_all_variables
())
restore_params
(
sess
,
params
)
# start training:
# start training:
coord
=
tf
.
train
.
Coordinator
()
coord
=
tf
.
train
.
Coordinator
()
# a thread that keeps filling the queue
# a thread that keeps filling the queue
...
...
utils/modelutils.py
0 → 100644
View file @
bcbbc645
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: modelutils.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
tensorflow
as
tf
import
logger
def
restore_params
(
sess
,
params
):
variables
=
tf
.
get_collection
(
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
)
var_dict
=
dict
([
v
.
name
,
v
]
for
v
in
variables
)
for
name
,
value
in
params
.
iteritems
():
try
:
var
=
var_dict
[
name
]
except
(
ValueError
,
KeyError
):
logger
.
warn
(
"Param {} not found in this graph"
.
format
(
name
))
continue
logger
.
info
(
"Restoring param {}"
.
format
(
name
))
sess
.
run
(
var
.
assign
(
value
))
def
describe_model
():
""" describe the current model parameters"""
train_vars
=
tf
.
get_collection
(
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
)
msg
=
[
""
]
total
=
0
for
v
in
train_vars
:
shape
=
v
.
get_shape
()
ele
=
shape
.
num_elements
()
total
+=
ele
msg
.
append
(
"{}: shape={}, dim={}"
.
format
(
v
.
name
,
shape
.
as_list
(),
ele
))
msg
.
append
(
"Total dim={}"
.
format
(
total
))
logger
.
info
(
"Model Params: {}"
.
format
(
'
\n
'
.
join
(
msg
)))
def
get_shape_str
(
tensors
):
""" return the shape string for a tensor or a list of tensors"""
if
isinstance
(
tensors
,
list
):
shape_str
=
","
.
join
(
map
(
str
(
x
.
get_shape
()
.
as_list
()),
tensors
))
else
:
shape_str
=
str
(
tensors
.
get_shape
()
.
as_list
())
return
shape_str
utils/summary.py
View file @
bcbbc645
...
@@ -60,27 +60,3 @@ def summary_moving_average(cost_var):
...
@@ -60,27 +60,3 @@ def summary_moving_average(cost_var):
tf
.
scalar_summary
(
c
.
op
.
name
,
averager
.
average
(
c
))
tf
.
scalar_summary
(
c
.
op
.
name
,
averager
.
average
(
c
))
return
avg_maintain_op
return
avg_maintain_op
def
describe_model
():
""" describe the current model parameters"""
train_vars
=
tf
.
get_collection
(
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
)
msg
=
[
""
]
total
=
0
for
v
in
train_vars
:
shape
=
v
.
get_shape
()
ele
=
shape
.
num_elements
()
total
+=
ele
msg
.
append
(
"{}: shape={}, dim={}"
.
format
(
v
.
name
,
shape
.
as_list
(),
ele
))
msg
.
append
(
"Total dim={}"
.
format
(
total
))
logger
.
info
(
"Model Params: {}"
.
format
(
'
\n
'
.
join
(
msg
)))
def
get_shape_str
(
tensors
):
""" return the shape string for a tensor or a list of tensors"""
if
isinstance
(
tensors
,
list
):
shape_str
=
","
.
join
(
map
(
str
(
x
.
get_shape
()
.
as_list
()),
tensors
))
else
:
shape_str
=
str
(
tensors
.
get_shape
()
.
as_list
())
return
shape_str
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment