seminar-breakout / Commits

Commit 0d20cb3d authored May 28, 2016 by Yuxin Wu

new inference framework based on trainer.get_predict_func

parent 8644248a
Showing 6 changed files with 62 additions and 40 deletions (+62, -40):

examples/cifar-convnet.py          +5  -9
examples/mnist-convnet.py          +1  -5
tensorpack/callbacks/inference.py  +9  -11
tensorpack/tfutils/common.py       +16 -3
tensorpack/train/base.py           +5  -0
tensorpack/train/trainer.py        +26 -12
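The commit replaces the ad-hoc feed_dict construction inside InferenceRunner with a predict function obtained from the trainer. In outline, callers name the input and output tensors they care about, the trainer resolves those names against its own graph (and, for the queue-based trainer, against the dequeued tensors of a tower), and returns a callable. A minimal sketch of the call pattern, with placeholder tensor names:

    # Sketch of the new usage pattern; 'input' and 'output' are placeholder names.
    pred_func = trainer.get_predict_func(['input'], ['output'])  # build once
    for dp in dataset.get_data():     # dp is a datapoint: a list of input values
        outputs = pred_func(dp)       # list of outputs, aligned with output_names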
examples/cifar-convnet.py

@@ -8,14 +8,9 @@ import argparse
 import numpy as np
 import os
-from tensorpack.train import TrainConfig, QueueInputTrainer
-from tensorpack.models import *
-from tensorpack.callbacks import *
-from tensorpack.utils import *
-from tensorpack.tfutils import *
-from tensorpack.tfutils.symbolic_functions import *
+from tensorpack import *
+import tensorpack.tfutils.symbolic_functions as symbf
 from tensorpack.tfutils.summary import *
 from tensorpack.dataflow import *

 """
 A small convnet model for cifar 10 or cifar100 dataset.

@@ -65,7 +60,7 @@ class Model(ModelDesc):
         tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)

         # compute the number of failed samples, for ClassificationError to use at test time
-        wrong = prediction_incorrect(logits, label)
+        wrong = symbf.prediction_incorrect(logits, label)
         nr_wrong = tf.reduce_sum(wrong, name='wrong')
         # monitor training error
         tf.add_to_collection(

@@ -161,4 +156,5 @@ if __name__ == '__main__':
         config.session_init = SaverRestore(args.load)
     if args.gpu:
         config.nr_tower = len(args.gpu.split(','))
-    QueueInputTrainer(config).train()
+    #QueueInputTrainer(config).train()
+    SimpleTrainer(config).train()
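For readers unfamiliar with the helper, prediction_incorrect comes from tensorpack.tfutils.symbolic_functions; the snippet below is an assumed rough equivalent of what the two changed lines compute, written for illustration rather than copied from the repository:

    # Assumed behaviour (illustrative only): one 0/1 float per sample marking a wrong
    # top-1 prediction, so the reduce_sum counts failed samples in the batch.
    wrong = tf.cast(tf.logical_not(tf.nn.in_top_k(logits, label, 1)), tf.float32)
    nr_wrong = tf.reduce_sum(wrong, name='wrong')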
examples/mnist-convnet.py

@@ -9,9 +9,6 @@ import os, sys
 import argparse
 from tensorpack import *
-from tensorpack.models import *
-from tensorpack.utils import *
-from tensorpack.callbacks import *

 """
 MNIST ConvNet example.

@@ -117,6 +114,5 @@ if __name__ == '__main__':
     config = get_config()
     if args.load:
         config.session_init = SaverRestore(args.load)
-    #QueueInputTrainer(config).train()
-    SimpleInputTrainer(config).train()
+    QueueInputTrainer(config).train()
tensorpack/callbacks/inference.py

@@ -63,7 +63,7 @@ class InferenceRunner(Callback):
     """
     A callback that runs different kinds of inferencer.
     """
-    type = TestCallbackType()
+    # type = TestCallbackType()

     def __init__(self, ds, vcs):
         """

@@ -82,12 +82,15 @@ class InferenceRunner(Callback):
     def _before_train(self):
         self.input_vars = self.trainer.model.reuse_input_vars()
         self._find_output_tensors()
+        input_names = [x.name for x in self.input_vars]
+        self.pred_func = self.trainer.get_predict_func(
+            input_names, self.output_tensors)
         for v in self.vcs:
             v.trainer = self.trainer

     def _find_output_tensors(self):
-        self.output_tensors = []
-        self.vc_to_vars = []
+        self.output_tensors = []  # list of names
+        self.vc_to_vars = []  # list of list of (var_name: output_idx)
         for vc in self.vcs:
             vc_vars = vc._get_output_tensors()
             def find_oid(var):

@@ -99,12 +102,6 @@ class InferenceRunner(Callback):
             vc_vars = [(var, find_oid(var)) for var in vc_vars]
             self.vc_to_vars.append(vc_vars)

-        # convert name to tensors
-        def get_tensor(name):
-            _, varname = get_op_var_name(name)
-            return self.graph.get_tensor_by_name(varname)
-        self.output_tensors = list(map(get_tensor, self.output_tensors))

     def _trigger_epoch(self):
         for vc in self.vcs:
             vc.before_inference()

@@ -112,8 +109,9 @@ class InferenceRunner(Callback):
         sess = tf.get_default_session()
         with tqdm(total=self.ds.size(), ascii=True) as pbar:
             for dp in self.ds.get_data():
-                feed = dict(zip(self.input_vars, dp))  # TODO custom dp mapping?
-                outputs = sess.run(self.output_tensors, feed_dict=feed)
+                #feed = dict(zip(self.input_vars, dp)) # TODO custom dp mapping?
+                #outputs = sess.run(self.output_tensors, feed_dict=feed)
+                outputs = self.pred_func(dp)
                 for vc, varsmap in zip(self.vcs, self.vc_to_vars):
                     vc_output = [outputs[k[1]] for k in varsmap]
                     vc.datapoint(dp, vc_output)
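For context, this is roughly how an InferenceRunner is attached to a training run in this snapshot of tensorpack; the surrounding Model, dataflows, and optimizer come from the example scripts and are only sketched here:

    # Illustrative wiring (assumed; dataset_test and the inferencer list are placeholders).
    config = TrainConfig(
        dataset=dataset_train,
        optimizer=tf.train.AdamOptimizer(1e-3),
        callbacks=Callbacks([
            StatPrinter(),
            InferenceRunner(dataset_test, [ClassificationError()]),  # uses trainer.get_predict_func
        ]),
        model=Model(),
        max_epoch=100,
    )

With this change the runner no longer resolves output tensors itself: it hands the collected names to the trainer in _before_train and calls the returned pred_func once per datapoint in _trigger_epoch.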
tensorpack/tfutils/common.py

@@ -9,7 +9,9 @@ import tensorflow as tf
 __all__ = ['get_default_sess_config',
            'get_global_step', 'get_global_step_var',
-           'get_op_var_name']
+           'get_op_var_name',
+           'get_vars_by_names']

 def get_default_sess_config(mem_fraction=0.9):
     """

@@ -53,3 +55,14 @@ def get_op_var_name(name):
         return name[:-2], name
     else:
         return name, name + ':0'
+
+def get_vars_by_names(names):
+    """
+    Get a list of variables in the default graph by a list of names
+    """
+    ret = []
+    G = tf.get_default_graph()
+    for n in names:
+        opn, varn = get_op_var_name(n)
+        ret.append(G.get_tensor_by_name(varn))
+    return ret
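A quick illustration of the new helper; the op names below are hypothetical and not defined in this diff:

    # get_op_var_name appends ':0' when given a bare op name, so both forms work.
    logits_t = get_vars_by_names(['logits'])[0]                # tensor named 'logits:0'
    in_t, out_t = get_vars_by_names(['input:0', 'output:0'])   # explicit tensor names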
tensorpack/train/base.py

@@ -50,6 +50,11 @@ class Trainer(object):
         """ run an iteration"""
         pass

+    @abstractmethod
+    def get_predict_func(self, input_names, output_names):
+        """ return a predict function"""
+        pass
+
     def trigger_epoch(self):
         self._trigger_epoch()
         self.config.callbacks.trigger_epoch()
tensorpack/train/trainer.py

@@ -53,25 +53,16 @@ class SimpleTrainer(Trainer):
             self._process_summary(summary_str)

     def get_predict_func(self, input_names, output_names):
-        input_vars = []
-        for n in input_names:
-            opn, varn = get_op_var_name(n)
-            v = tf.get_default_graph().get_tensor_by_name(varn)
+        input_vars = get_vars_by_names(input_names)
+        for v in input_vars:
             assert v in self.input_vars
-            input_vars.append(v)
-        output_vars = []
-        for n in output_names:
-            opn, varn = get_op_var_name(n)
-            v = tf.get_default_graph().get_tensor_by_name(varn)
-            output_vars.append(v)
+        output_vars = get_vars_by_names(output_names)

         def func(inputs):
             assert len(inputs) == len(input_vars)
             feed = dict(zip(input_vars, inputs))
             return self.sess.run(output_vars, feed_dict=feed)
         return func

 class EnqueueThread(threading.Thread):
     def __init__(self, trainer, queue, enqueue_op, raw_input_var):
         super(EnqueueThread, self).__init__()

@@ -126,6 +117,7 @@ class QueueInputTrainer(Trainer):
         self.async = async
         if self.async:
             assert self.config.nr_tower > 1
+        self._dequed_inputs = []

     @staticmethod
     def _average_grads(tower_grads):

@@ -148,6 +140,7 @@ class QueueInputTrainer(Trainer):
         assert len(ret) == len(self.input_vars)
         for qv, v in zip(ret, self.input_vars):
             qv.set_shape(v.get_shape())
+        self._dequed_inputs.append(ret)
         return ret

     def _single_tower_grad(self):

@@ -248,6 +241,27 @@ class QueueInputTrainer(Trainer):
                 summary_str = self.summary_op.eval()
                 self._process_summary(summary_str)

+    def get_predict_func(self, input_names, output_names):
+        raw_input_vars = get_vars_by_names(input_names)
+        input_var_idxs = [self.input_vars.index(v) for v in raw_input_vars]
+        if self.config.nr_tower == 1:
+            dequed = self._dequed_inputs[0]
+            input_vars = [dequed[k] for k in input_var_idxs]
+            output_vars = get_vars_by_names(output_names)
+        else:
+            # TODO naive impl: use the first tower only
+            dequed = self._dequed_inputs[0]
+            input_vars = [dequed[k] for k in input_var_idxs]
+            output_names = ['tower0/' + n for n in output_names]
+            output_vars = get_vars_by_names(output_names)
+
+        def func(inputs):
+            assert len(inputs) == len(input_vars)
+            feed = dict(zip(input_vars, inputs))
+            return self.sess.run(output_vars, feed_dict=feed)
+        return func
+
 def start_train(config):
     tr = QueueInputTrainer(config)
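Putting the pieces together, a hedged sketch of how the two implementations are exercised once the training graph has been built; variable and tensor names are placeholders:

    # Sketch only; 'input' and 'output' are placeholder tensor names.
    trainer = QueueInputTrainer(config)
    # ... after the graph has been set up (e.g. inside a callback's _before_train) ...
    pred = trainer.get_predict_func(['input'], ['output'])
    outputs = pred([input_batch])   # one array per requested output name

SimpleTrainer feeds the model's own placeholders directly, while QueueInputTrainer substitutes the dequeued tensors of the first tower and, with more than one tower, looks the outputs up under the 'tower0/' name scope.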