Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
8644248a
Commit
8644248a
authored
May 28, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
get_predict_func in simpletrainer. before trying a different inference framework
parent
65fb37b9
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
40 additions
and
23 deletions
+40
-23
examples/mnist-convnet.py
examples/mnist-convnet.py
+2
-1
tensorpack/models/model_desc.py
tensorpack/models/model_desc.py
+1
-1
tensorpack/predict/common.py
tensorpack/predict/common.py
+1
-0
tensorpack/predict/concurrency.py
tensorpack/predict/concurrency.py
+13
-16
tensorpack/tfutils/sessinit.py
tensorpack/tfutils/sessinit.py
+1
-1
tensorpack/train/trainer.py
tensorpack/train/trainer.py
+22
-4
No files found.
examples/mnist-convnet.py
View file @
8644248a
...
...
@@ -117,5 +117,6 @@ if __name__ == '__main__':
config
=
get_config
()
if
args
.
load
:
config
.
session_init
=
SaverRestore
(
args
.
load
)
QueueInputTrainer
(
config
)
.
train
()
#QueueInputTrainer(config).train()
SimpleInputTrainer
(
config
)
.
train
()
tensorpack/models/model_desc.py
View file @
8644248a
...
...
@@ -20,7 +20,7 @@ class ModelDesc(object):
def
get_input_vars
(
self
):
"""
Create or return (if already created)
input TF
vars in the graph.
Create or return (if already created)
raw input TF placeholder
vars in the graph.
:returns: the list of raw input vars in the graph
"""
...
...
tensorpack/predict/common.py
View file @
8644248a
...
...
@@ -46,6 +46,7 @@ class PredictConfig(object):
variables can be any computable tensor in the graph.
Predict specific output might not require all input variables.
:param return_input: whether to produce (input, output) pair or just output. default to False.
It's only effective for `DatasetPredictorBase`.
"""
def
assert_type
(
v
,
tp
):
assert
isinstance
(
v
,
tp
),
v
.
__class__
...
...
tensorpack/predict/concurrency.py
View file @
8644248a
...
...
@@ -47,11 +47,8 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
""" A worker process to run predictor on one GPU """
def
__init__
(
self
,
idx
,
gpuid
,
inqueue
,
outqueue
,
config
):
"""
:param idx: index of the worker. the 0th worker will print log.
:param gpuid: id of the GPU to be used. set to -1 to use CPU.
:param inqueue: input queue to get data point. elements are (task_id, dp)
:param outqueue: output queue put result. elements are (task_id, output)
:param config: a `PredictConfig`
"""
super
(
MultiProcessQueuePredictWorker
,
self
)
.
__init__
(
idx
,
gpuid
,
config
)
self
.
inqueue
=
inqueue
...
...
@@ -67,17 +64,17 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
else
:
self
.
outqueue
.
put
((
tid
,
self
.
func
(
dp
)))
class
MultiThreadPredictWorker
(
threading
.
Thread
):
def
__init__
(
self
,
idx
,
gpuid
,
config
):
"""
:param idx: index of the worker. the 0th worker will print log.
:param gpuid: absolute id of the GPU to be used. set to -1 to use CPU.
:param config: a `PredictConfig`
"""
super
(
MultiProcessPredictWorker
,
self
)
.
__init__
()
self
.
idx
=
idx
self
.
gpuid
=
gpuid
self
.
config
=
config
#class CurrentSessionPredictor(
):
#
def __init__(self, idx, gpuid, config):
#
"""
#
:param idx: index of the worker. the 0th worker will print log.
#
:param gpuid: absolute id of the GPU to be used. set to -1 to use CPU.
#
:param config: a `PredictConfig`
#
"""
#
super(MultiProcessPredictWorker, self).__init__()
#
self.idx = idx
#
self.gpuid = gpuid
#
self.config = config
def
run
(
self
):
pass
#
def run(self):
#
pass
tensorpack/tfutils/sessinit.py
View file @
8644248a
...
...
@@ -37,7 +37,7 @@ class SessionInit(object):
class
JustCurrentSession
(
SessionInit
):
""" Just use the current default session. This is a no-op placeholder"""
def
_init
(
self
,
sess
):
logger
.
info
(
"Using the current running session .."
)
pass
class
NewSession
(
SessionInit
):
"""
...
...
tensorpack/train/trainer.py
View file @
8644248a
...
...
@@ -27,9 +27,8 @@ class SimpleTrainer(Trainer):
def
train
(
self
):
model
=
self
.
model
input_vars
=
model
.
get_input_vars
()
self
.
input_vars
=
input_vars
model
.
build_graph
(
input_vars
,
True
)
self
.
input_vars
=
model
.
get_input_vars
()
model
.
build_graph
(
self
.
input_vars
,
True
)
cost_var
=
model
.
get_cost
()
tf
.
add_to_collection
(
MOVING_SUMMARY_VARS_KEY
,
cost_var
)
...
...
@@ -53,6 +52,26 @@ class SimpleTrainer(Trainer):
summary_str
=
self
.
summary_op
.
eval
(
feed_dict
=
feed
)
self
.
_process_summary
(
summary_str
)
def
get_predict_func
(
self
,
input_names
,
output_names
):
input_vars
=
[]
for
n
in
input_names
:
opn
,
varn
=
get_op_var_name
(
n
)
v
=
tf
.
get_default_graph
()
.
get_tensor_by_name
(
varn
)
assert
v
in
self
.
input_vars
input_vars
.
append
(
v
)
output_vars
=
[]
for
n
in
output_names
:
opn
,
varn
=
get_op_var_name
(
n
)
v
=
tf
.
get_default_graph
()
.
get_tensor_by_name
(
varn
)
output_vars
.
append
(
v
)
def
func
(
inputs
):
assert
len
(
inputs
)
==
len
(
input_vars
)
feed
=
dict
(
zip
(
input_vars
,
inputs
))
return
self
.
sess
.
run
(
output_vars
,
feed_dict
=
feed
)
return
func
class
EnqueueThread
(
threading
.
Thread
):
def
__init__
(
self
,
trainer
,
queue
,
enqueue_op
,
raw_input_var
):
super
(
EnqueueThread
,
self
)
.
__init__
()
...
...
@@ -85,7 +104,6 @@ class EnqueueThread(threading.Thread):
finally
:
logger
.
info
(
"Enqueue Thread Exited."
)
class
QueueInputTrainer
(
Trainer
):
"""
Trainer which builds a FIFO queue for input.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment