Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
f16643ac
Commit
f16643ac
authored
Dec 29, 2015
by
ppwwyyxx
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
testing and loading script
parent
dd031661
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
157 additions
and
12 deletions
+157
-12
.gitignore
.gitignore
+1
-0
example_mnist.py
example_mnist.py
+1
-1
scripts/dump_model_params.py
scripts/dump_model_params.py
+37
-0
scripts/imgclassify.py
scripts/imgclassify.py
+46
-0
tensorpack/dataflow/common.py
tensorpack/dataflow/common.py
+2
-2
tensorpack/dataflow/image.py
tensorpack/dataflow/image.py
+35
-0
tensorpack/predict.py
tensorpack/predict.py
+24
-8
tensorpack/utils/sessinit.py
tensorpack/utils/sessinit.py
+11
-1
No files found.
.gitignore
View file @
f16643ac
*.gz
*.npy
train_log
# Byte-compiled / optimized / DLL files
...
...
example_mnist.py
View file @
f16643ac
...
...
@@ -62,7 +62,7 @@ def get_model(inputs, is_training):
# fc will have activation summary by default. disable this for the output layer
logits
=
FullyConnected
(
'fc1'
,
l
,
out_dim
=
10
,
summary_activation
=
False
,
nl
=
tf
.
identity
)
prob
=
tf
.
nn
.
softmax
(
logits
,
name
=
'
output
'
)
prob
=
tf
.
nn
.
softmax
(
logits
,
name
=
'
prob
'
)
y
=
one_hot
(
label
,
10
)
cost
=
tf
.
nn
.
softmax_cross_entropy_with_logits
(
logits
,
y
)
...
...
scripts/dump_model_params.py
0 → 100755
View file @
f16643ac
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: dump_model_params.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
argparse
import
cv2
import
tensorflow
as
tf
import
imp
from
tensorpack.utils
import
*
from
tensorpack.utils
import
sessinit
from
tensorpack.dataflow
import
*
from
tensorpack.predict
import
DatasetPredictor
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
dest
=
'config'
)
parser
.
add_argument
(
dest
=
'model'
)
parser
.
add_argument
(
dest
=
'output'
)
args
=
parser
.
parse_args
()
get_config_func
=
imp
.
load_source
(
'config_script'
,
args
.
config
)
.
get_config
with
tf
.
Graph
()
.
as_default
()
as
G
:
global_step_var
=
tf
.
Variable
(
0
,
trainable
=
False
,
name
=
GLOBAL_STEP_OP_NAME
)
config
=
get_config_func
()
config
[
'get_model_func'
](
config
[
'inputs'
],
is_training
=
False
)
init
=
sessinit
.
SaverRestore
(
args
.
model
)
sess
=
tf
.
Session
()
init
.
init
(
sess
)
with
sess
.
as_default
():
sessinit
.
dump_session_params
(
args
.
output
)
scripts/imgclassify.py
0 → 100755
View file @
f16643ac
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: imgclassify.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
argparse
import
cv2
import
tensorflow
as
tf
import
imp
from
tensorpack.utils
import
*
from
tensorpack.utils
import
sessinit
from
tensorpack.dataflow
import
*
from
tensorpack.predict
import
DatasetPredictor
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
dest
=
'config'
)
parser
.
add_argument
(
dest
=
'model'
)
parser
.
add_argument
(
dest
=
'images'
,
nargs
=
'+'
)
parser
.
add_argument
(
'--output_type'
,
default
=
'label'
,
choices
=
[
'label'
,
'label-prob'
,
'raw'
])
parser
.
add_argument
(
'--top'
,
default
=
1
,
type
=
int
)
args
=
parser
.
parse_args
()
get_config_func
=
imp
.
load_source
(
'config_script'
,
args
.
config
)
.
get_config
with
tf
.
Graph
()
.
as_default
()
as
G
:
global_step_var
=
tf
.
Variable
(
0
,
trainable
=
False
,
name
=
GLOBAL_STEP_OP_NAME
)
config
=
get_config_func
()
config
[
'session_init'
]
=
sessinit
.
SaverRestore
(
args
.
model
)
config
[
'output_var'
]
=
'output:0'
ds
=
ImageFromFile
(
args
.
images
,
3
,
resize
=
(
227
,
227
))
predictor
=
DatasetPredictor
(
config
,
ds
,
batch
=
128
)
res
=
predictor
.
get_all_result
()
if
args
.
output_type
==
'label'
:
for
r
in
res
:
print
r
.
argsort
()[
-
top
:][::
-
1
]
elif
args
.
output_type
==
'label_prob'
:
raise
NotImplementedError
elif
args
.
output_type
==
'raw'
:
print
res
tensorpack/dataflow/common.py
View file @
f16643ac
...
...
@@ -17,6 +17,7 @@ class BatchData(DataFlow):
if set, might return a data point of a different shape
"""
self
.
ds
=
ds
if
not
remainder
:
assert
batch_size
<=
ds
.
size
()
self
.
batch_size
=
batch_size
self
.
remainder
=
remainder
...
...
@@ -85,7 +86,6 @@ class FakeData(DataFlow):
for
_
in
xrange
(
self
.
_size
):
yield
[
np
.
random
.
random
(
k
)
for
k
in
self
.
shapes
]
class
MapData
(
DataFlow
):
""" Apply a function to the given index in the datapoint"""
def
__init__
(
self
,
ds
,
func
,
index
=
0
):
...
...
tensorpack/dataflow/image.py
0 → 100644
View file @
f16643ac
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: image.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
numpy
as
np
import
cv2
from
.base
import
DataFlow
__all__
=
[
'ImageFromFile'
]
class
ImageFromFile
(
DataFlow
):
""" generate rgb images from files """
def
__init__
(
self
,
files
,
channel
,
resize
=
None
):
""" files: list of file path
channel: 1 or 3 channel
resize: a (w, h) tuple. If given, will force a resize
"""
self
.
files
=
files
self
.
channel
=
int
(
channel
)
self
.
resize
=
resize
def
size
(
self
):
return
len
(
self
.
files
)
def
get_data
(
self
):
for
f
in
self
.
files
:
im
=
cv2
.
imread
(
f
,
cv2
.
IMREAD_GRAYSCALE
if
self
.
channel
==
1
else
cv2
.
IMREAD_COLOR
)
if
self
.
channel
==
3
:
im
=
cv2
.
cvtColor
(
im
,
cv2
.
COLOR_BGR2RGB
)
if
self
.
resize
is
not
None
:
im
=
cv2
.
resize
(
im
,
self
.
resize
)
yield
(
im
,)
tensorpack/predict.py
View file @
f16643ac
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File:
infer
.py
# File:
predict
.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
tensorflow
as
tf
...
...
@@ -11,7 +11,7 @@ import numpy as np
from
utils
import
*
from
utils.modelutils
import
describe_model
from
utils
import
logger
from
dataflow
import
DataFlow
from
dataflow
import
DataFlow
,
BatchData
def
get_predict_func
(
config
):
"""
...
...
@@ -27,6 +27,10 @@ def get_predict_func(config):
sess_init
=
config
[
'session_init'
]
# Provide this if only specific output is needed.
# by default will evaluate all outputs as well as cost
output_var_name
=
config
.
get
(
'output_var'
,
None
)
# input/output variables
input_vars
=
config
[
'inputs'
]
get_model_func
=
config
[
'get_model_func'
]
...
...
@@ -38,18 +42,30 @@ def get_predict_func(config):
sess_init
.
init
(
sess
)
def
run_input
(
dp
):
# TODO if input and dp not aligned?
feed
=
dict
(
zip
(
input_vars
,
dp
))
results
=
sess
.
run
(
[
cost_var
]
+
output_vars
,
feed_dict
=
feed
)
if
output_var_name
is
not
None
:
fetches
=
tf
.
get_default_graph
()
.
get_tensor_by_name
(
output_var_name
)
results
=
sess
.
run
(
fetches
,
feed_dict
=
feed
)
return
results
[
0
]
else
:
fetches
=
[
cost_var
]
+
output_vars
results
=
sess
.
run
(
fetches
,
feed_dict
=
feed
)
cost
=
results
[
0
]
outputs
=
results
[
1
:]
return
cost
,
outputs
return
run_input
class
DatasetPredictor
(
object
):
def
__init__
(
self
,
predict_config
,
dataset
):
def
__init__
(
self
,
predict_config
,
dataset
,
batch
=
0
):
"""
A predictor with the given predict_config, run on the given dataset
if batch is larger than zero, the dataset will be batched
"""
assert
isinstance
(
dataset
,
DataFlow
)
self
.
ds
=
dataset
if
batch
>
0
:
self
.
ds
=
BatchData
(
self
.
ds
,
batch
,
remainder
=
True
)
self
.
predict_func
=
get_predict_func
(
predict_config
)
def
get_result
(
self
):
...
...
tensorpack/utils/sessinit.py
View file @
f16643ac
...
...
@@ -4,6 +4,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from
abc
import
abstractmethod
import
numpy
as
np
import
tensorflow
as
tf
from
.
import
logger
...
...
@@ -24,7 +25,7 @@ class SaverRestore(SessionInit):
saver
=
tf
.
train
.
Saver
()
saver
.
restore
(
sess
,
self
.
path
)
logger
.
info
(
"Restore checkpoint from {}"
.
format
(
ckpt
.
model_checkpoint_
path
))
"Restore checkpoint from {}"
.
format
(
self
.
path
))
def
set_path
(
self
,
model_path
):
self
.
path
=
model_path
...
...
@@ -44,3 +45,12 @@ class ParamRestore(SessionInit):
continue
logger
.
info
(
"Restoring param {}"
.
format
(
name
))
sess
.
run
(
var
.
assign
(
value
))
def
dump_session_params
(
path
):
var
=
tf
.
get_collection
(
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
)
result
=
{}
for
v
in
var
:
result
[
v
.
name
]
=
v
.
eval
()
logger
.
info
(
"Params to save to {}:"
.
format
(
path
))
logger
.
info
(
str
(
result
.
keys
()))
np
.
save
(
path
,
result
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment