Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
14a28c01
Commit
14a28c01
authored
Jul 16, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
online predictor in trainer
parent
86ec2d15
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
35 additions
and
37 deletions
+35
-37
tensorpack/predict/base.py
tensorpack/predict/base.py
+27
-24
tensorpack/train/trainer.py
tensorpack/train/trainer.py
+6
-11
tensorpack/utils/rect.py
tensorpack/utils/rect.py
+2
-2
No files found.
tensorpack/predict/base.py
View file @
14a28c01
...
...
@@ -7,18 +7,18 @@ from abc import abstractmethod, ABCMeta, abstractproperty
import
tensorflow
as
tf
from
..tfutils
import
get_vars_by_names
__all__
=
[
'OnlinePredictor'
,
'OfflinePredictor'
]
class
PredictorBase
(
object
):
__metaclass__
=
ABCMeta
@
abstractproperty
def
session
(
self
):
""" return the session the predictor is running on"""
pass
"""
Property:
session
return_input
"""
def
__call__
(
self
,
dp
):
assert
len
(
dp
)
==
len
(
self
.
input_var_names
),
\
"{} != {}"
.
format
(
len
(
dp
),
len
(
self
.
input_var_names
))
output
=
self
.
_do_call
(
dp
)
if
self
.
return_input
:
return
(
dp
,
output
)
...
...
@@ -33,8 +33,23 @@ class PredictorBase(object):
"""
pass
class
OnlinePredictor
(
PredictorBase
):
def
__init__
(
self
,
sess
,
input_vars
,
output_vars
,
return_input
=
False
):
self
.
session
=
sess
self
.
return_input
=
return_input
self
.
input_vars
=
input_vars
self
.
output_vars
=
output_vars
def
_do_call
(
self
,
dp
):
assert
len
(
dp
)
==
len
(
self
.
input_vars
),
\
"{} != {}"
.
format
(
len
(
dp
),
len
(
self
.
input_vars
))
feed
=
dict
(
zip
(
self
.
input_vars
,
dp
))
output
=
self
.
session
.
run
(
self
.
output_vars
,
feed_dict
=
feed
)
return
output
class
OfflinePredictor
(
PredictorBase
):
class
OfflinePredictor
(
OnlinePredictor
):
""" Build a predictor from a given config, in an independent graph"""
def
__init__
(
self
,
config
):
self
.
graph
=
tf
.
Graph
()
...
...
@@ -42,22 +57,10 @@ class OfflinePredictor(PredictorBase):
input_vars
=
config
.
model
.
get_input_vars
()
config
.
model
.
_build_graph
(
input_vars
,
False
)
self
.
input_var_names
=
config
.
input_var_names
self
.
output_var_names
=
config
.
output_var_names
self
.
return_input
=
config
.
return_input
self
.
input_vars
=
get_vars_by_names
(
self
.
input_var_names
)
self
.
output_vars
=
get_vars_by_names
(
self
.
output_var_names
)
input_vars
=
get_vars_by_names
(
config
.
input_var_names
)
output_vars
=
get_vars_by_names
(
config
.
output_var_names
)
sess
=
tf
.
Session
(
config
=
config
.
session_config
)
config
.
session_init
.
init
(
sess
)
self
.
_session
=
sess
@
property
def
session
(
self
):
return
self
.
_session
def
_do_call
(
self
,
dp
):
feed
=
dict
(
zip
(
self
.
input_vars
,
dp
))
output
=
self
.
session
.
run
(
self
.
output_vars
,
feed_dict
=
feed
)
return
output
super
(
OfflinePredictor
,
self
)
.
__init__
(
sess
,
input_vars
,
output_vars
,
config
.
return_input
)
tensorpack/train/trainer.py
View file @
14a28c01
...
...
@@ -8,11 +8,14 @@ import time
from
six.moves
import
zip
from
.base
import
Trainer
from
..dataflow.common
import
RepeatedData
from
..utils
import
*
from
..tfutils.summary
import
summary_moving_average
from
..tfutils.modelutils
import
describe_model
from
..utils
import
*
from
..tfutils
import
*
from
..predict
import
OnlinePredictor
__all__
=
[
'SimpleTrainer'
,
'QueueInputTrainer'
]
...
...
@@ -56,11 +59,7 @@ class SimpleTrainer(Trainer):
for
v
in
input_vars
:
assert
v
in
self
.
input_vars
output_vars
=
get_vars_by_names
(
output_names
)
def
func
(
inputs
):
assert
len
(
inputs
)
==
len
(
input_vars
)
feed
=
dict
(
zip
(
input_vars
,
inputs
))
return
self
.
sess
.
run
(
output_vars
,
feed_dict
=
feed
)
return
func
return
OnlinePredictor
(
self
.
sess
,
input_vars
,
output_vars
)
class
EnqueueThread
(
threading
.
Thread
):
def
__init__
(
self
,
trainer
):
...
...
@@ -218,11 +217,7 @@ class QueueInputTrainer(Trainer):
raw_input_vars
=
get_vars_by_names
(
input_names
)
output_names
=
[
'towerp{}/'
.
format
(
tower
)
+
n
for
n
in
output_names
]
output_vars
=
get_vars_by_names
(
output_names
)
def
func
(
inputs
):
assert
len
(
inputs
)
==
len
(
raw_input_vars
)
feed
=
dict
(
zip
(
raw_input_vars
,
inputs
))
return
self
.
sess
.
run
(
output_vars
,
feed_dict
=
feed
)
return
func
return
OnlinePredictor
(
self
.
sess
,
raw_input_vars
,
output_vars
)
def
get_predict_funcs
(
self
,
input_names
,
output_names
,
n
):
""" return n predicts functions evenly on each predict_tower"""
...
...
tensorpack/utils/rect.py
View file @
14a28c01
...
...
@@ -102,5 +102,5 @@ if __name__ == '__main__':
x
=
Rect
(
2
,
1
,
3
,
3
,
allow_neg
=
True
)
img
=
np
.
random
.
rand
(
3
,
3
)
print
img
print
x
.
roi_zeropad
(
img
)
print
(
img
)
print
(
x
.
roi_zeropad
(
img
)
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment