Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
1f0670e5
Commit
1f0670e5
authored
Dec 31, 2015
by
ppwwyyxx
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix scripts on new config
parent
c5df0501
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
16 additions
and
14 deletions
+16
-14
example_alexnet.py
example_alexnet.py
+4
-4
scripts/dump_model_params.py
scripts/dump_model_params.py
+2
-3
scripts/imgclassify.py
scripts/imgclassify.py
+10
-7
No files found.
example_alexnet.py
View file @
1f0670e5
...
...
@@ -83,7 +83,7 @@ def get_config():
dataset_train
=
FakeData
([(
227
,
227
,
3
),
tuple
()],
10
)
dataset_train
=
BatchData
(
dataset_train
,
10
)
step_per_epoch
=
3
step_per_epoch
=
1
sess_config
=
get_default_sess_config
()
sess_config
.
gpu_options
.
per_process_gpu_memory_fraction
=
0.5
...
...
@@ -105,12 +105,12 @@ def get_config():
decay_rate
=
0.1
,
staircase
=
True
,
name
=
'learning_rate'
)
tf
.
scalar_summary
(
'learning_rate'
,
lr
)
param_dict
=
np
.
load
(
'alexnet
1
.npy'
)
.
item
()
param_dict
=
np
.
load
(
'alexnet.npy'
)
.
item
()
return
TrainConfig
(
dataset
=
dataset_train
,
optimizer
=
tf
.
train
.
AdamOptimizer
(
lr
),
callback
=
Callbacks
([
callback
s
=
Callbacks
([
SummaryWriter
(),
PeriodicSaver
(),
#ValidationError(dataset_test, prefix='test'),
...
...
@@ -162,4 +162,4 @@ if __name__ == '__main__':
#start_train(get_config())
# run alexnet with given model (in npy format)
run_test
(
'alexnet.npy'
)
run_test
(
'alexnet
-tuned
.npy'
)
scripts/dump_model_params.py
View file @
1f0670e5
...
...
@@ -23,10 +23,9 @@ args = parser.parse_args()
get_config_func
=
imp
.
load_source
(
'config_script'
,
args
.
config
)
.
get_config
with
tf
.
Graph
()
.
as_default
()
as
G
:
global_step_var
=
tf
.
Variable
(
0
,
trainable
=
False
,
name
=
GLOBAL_STEP_OP_NAME
)
global_step_var
=
get_global_step_var
()
config
=
get_config_func
()
config
[
'get_model_func'
](
config
[
'inputs'
]
,
is_training
=
False
)
config
.
get_model_func
(
config
.
inputs
,
is_training
=
False
)
init
=
sessinit
.
SaverRestore
(
args
.
model
)
sess
=
tf
.
Session
()
init
.
init
(
sess
)
...
...
scripts/imgclassify.py
View file @
1f0670e5
...
...
@@ -12,7 +12,7 @@ import imp
from
tensorpack.utils
import
*
from
tensorpack.utils
import
sessinit
from
tensorpack.dataflow
import
*
from
tensorpack.predict
import
DatasetPredictor
from
tensorpack.predict
import
PredictConfig
,
DatasetPredictor
parser
=
argparse
.
ArgumentParser
()
...
...
@@ -27,11 +27,14 @@ args = parser.parse_args()
get_config_func
=
imp
.
load_source
(
'config_script'
,
args
.
config
)
.
get_config
with
tf
.
Graph
()
.
as_default
()
as
G
:
global_step_var
=
tf
.
Variable
(
0
,
trainable
=
False
,
name
=
GLOBAL_STEP_OP_NAME
)
config
=
get_config_func
()
config
[
'session_init'
]
=
sessinit
.
SaverRestore
(
args
.
model
)
config
[
'output_var'
]
=
'output:0'
train_config
=
get_config_func
()
config
=
PredictConfig
(
inputs
=
train_config
.
inputs
,
input_dataset_mapping
=
[
train_config
.
inputs
[
0
]],
# assume first component is image
get_model_func
=
train_config
.
get_model_func
,
session_init
=
sessinit
.
SaverRestore
(
args
.
model
),
output_var_names
=
[
'output:0'
]
)
ds
=
ImageFromFile
(
args
.
images
,
3
,
resize
=
(
227
,
227
))
predictor
=
DatasetPredictor
(
config
,
ds
,
batch
=
128
)
...
...
@@ -39,7 +42,7 @@ with tf.Graph().as_default() as G:
if
args
.
output_type
==
'label'
:
for
r
in
res
:
print
r
.
argsort
()[
-
top
:][
::
-
1
]
print
r
[
0
]
.
argsort
(
axis
=
1
)[:,
-
args
.
top
:][:,
::
-
1
]
elif
args
.
output_type
==
'label_prob'
:
raise
NotImplementedError
elif
args
.
output_type
==
'raw'
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment