Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
09d1e881
Commit
09d1e881
authored
Jan 09, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix test script
parent
50859d25
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
6 additions
and
10 deletions
+6
-10
example_alexnet.py
example_alexnet.py
+2
-2
scripts/dump_model_params.py
scripts/dump_model_params.py
+0
-5
scripts/imgclassify.py
scripts/imgclassify.py
+1
-1
tensorpack/predict.py
tensorpack/predict.py
+2
-2
tensorpack/utils/sessinit.py
tensorpack/utils/sessinit.py
+1
-0
No files found.
example_alexnet.py
View file @
09d1e881
...
@@ -78,7 +78,7 @@ def get_model(inputs, is_training):
...
@@ -78,7 +78,7 @@ def get_model(inputs, is_training):
def
get_config
():
def
get_config
():
basename
=
os
.
path
.
basename
(
__file__
)
basename
=
os
.
path
.
basename
(
__file__
)
log_dir
=
os
.
path
.
join
(
'train_log'
,
basename
[:
basename
.
rfind
(
'.'
)])
log_dir
=
os
.
path
.
join
(
'train_log'
,
basename
[:
basename
.
rfind
(
'.'
)])
logger
.
set_logger_
dir
(
log_dir
)
logger
.
set_logger_
file
(
os
.
path
.
join
(
log_dir
,
'training.log'
)
)
dataset_train
=
FakeData
([(
227
,
227
,
3
),
tuple
()],
10
)
dataset_train
=
FakeData
([(
227
,
227
,
3
),
tuple
()],
10
)
dataset_train
=
BatchData
(
dataset_train
,
10
)
dataset_train
=
BatchData
(
dataset_train
,
10
)
...
@@ -158,7 +158,7 @@ if __name__ == '__main__':
...
@@ -158,7 +158,7 @@ if __name__ == '__main__':
if
args
.
gpu
:
if
args
.
gpu
:
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
args
.
gpu
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
args
.
gpu
#
start_train(get_config())
start_train
(
get_config
())
# run alexnet with given model (in npy format)
# run alexnet with given model (in npy format)
run_test
(
'alexnet-tuned.npy'
)
run_test
(
'alexnet-tuned.npy'
)
scripts/dump_model_params.py
View file @
09d1e881
...
@@ -11,8 +11,6 @@ import imp
...
@@ -11,8 +11,6 @@ import imp
from
tensorpack.utils
import
*
from
tensorpack.utils
import
*
from
tensorpack.utils
import
sessinit
from
tensorpack.utils
import
sessinit
from
tensorpack.dataflow
import
*
from
tensorpack.dataflow
import
*
from
tensorpack.predict
import
DatasetPredictor
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
dest
=
'config'
)
parser
.
add_argument
(
dest
=
'config'
)
...
@@ -30,6 +28,3 @@ with tf.Graph().as_default() as G:
...
@@ -30,6 +28,3 @@ with tf.Graph().as_default() as G:
init
.
init
(
sess
)
init
.
init
(
sess
)
with
sess
.
as_default
():
with
sess
.
as_default
():
sessinit
.
dump_session_params
(
args
.
output
)
sessinit
.
dump_session_params
(
args
.
output
)
scripts/imgclassify.py
View file @
09d1e881
...
@@ -39,7 +39,7 @@ with tf.Graph().as_default() as G:
...
@@ -39,7 +39,7 @@ with tf.Graph().as_default() as G:
ds
=
ImageFromFile
(
args
.
images
,
3
,
resize
=
(
227
,
227
))
ds
=
ImageFromFile
(
args
.
images
,
3
,
resize
=
(
227
,
227
))
predictor
=
DatasetPredictor
(
config
,
ds
,
batch
=
128
)
predictor
=
DatasetPredictor
(
config
,
ds
,
batch
=
128
)
res
=
predictor
.
get_all_result
()
res
=
predictor
.
get_all_result
()
res
=
[
k
[
1
]
for
k
in
res
]
res
=
[
k
.
output
for
k
in
res
]
if
args
.
output_type
==
'label'
:
if
args
.
output_type
==
'label'
:
for
r
in
res
:
for
r
in
res
:
...
...
tensorpack/predict.py
View file @
09d1e881
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
tensorflow
as
tf
import
tensorflow
as
tf
from
itertools
import
count
from
itertools
import
count
,
izip
import
argparse
import
argparse
from
collections
import
namedtuple
from
collections
import
namedtuple
import
numpy
as
np
import
numpy
as
np
...
@@ -93,7 +93,7 @@ def get_predict_func(config):
...
@@ -93,7 +93,7 @@ def get_predict_func(config):
assert
len
(
input_map
)
==
len
(
dp
),
\
assert
len
(
input_map
)
==
len
(
dp
),
\
"Graph has {} inputs but dataset only gives {} components!"
.
format
(
"Graph has {} inputs but dataset only gives {} components!"
.
format
(
len
(
input_map
),
len
(
dp
))
len
(
input_map
),
len
(
dp
))
feed
=
dict
(
zip
(
input_map
,
dp
))
feed
=
dict
(
i
zip
(
input_map
,
dp
))
if
output_var_names
is
not
None
:
if
output_var_names
is
not
None
:
results
=
sess
.
run
(
output_vars
,
feed_dict
=
feed
)
results
=
sess
.
run
(
output_vars
,
feed_dict
=
feed
)
return
results
return
results
...
...
tensorpack/utils/sessinit.py
View file @
09d1e881
...
@@ -56,6 +56,7 @@ class ParamRestore(SessionInit):
...
@@ -56,6 +56,7 @@ class ParamRestore(SessionInit):
sess
.
run
(
var
.
assign
(
value
))
sess
.
run
(
var
.
assign
(
value
))
def
dump_session_params
(
path
):
def
dump_session_params
(
path
):
""" dump value of all trainable variables to a dict"""
var
=
tf
.
get_collection
(
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
)
var
=
tf
.
get_collection
(
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
)
result
=
{}
result
=
{}
for
v
in
var
:
for
v
in
var
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment