Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
fe2b4f97
Commit
fe2b4f97
authored
Feb 21, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
code clean-up in predict/
parent
8df83a93
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
104 additions
and
73 deletions
+104
-73
examples/DoReFa-Net/alexnet-dorefa.py
examples/DoReFa-Net/alexnet-dorefa.py
+0
-1
examples/DoReFa-Net/resnet-dorefa.py
examples/DoReFa-Net/resnet-dorefa.py
+0
-1
tensorpack/callbacks/base.py
tensorpack/callbacks/base.py
+2
-3
tensorpack/dataflow/prefetch.py
tensorpack/dataflow/prefetch.py
+1
-0
tensorpack/predict/base.py
tensorpack/predict/base.py
+8
-5
tensorpack/predict/concurrency.py
tensorpack/predict/concurrency.py
+7
-6
tensorpack/predict/config.py
tensorpack/predict/config.py
+19
-9
tensorpack/predict/multigpu.py
tensorpack/predict/multigpu.py
+25
-26
tensorpack/tfutils/sesscreate.py
tensorpack/tfutils/sesscreate.py
+34
-0
tensorpack/tfutils/sessinit.py
tensorpack/tfutils/sessinit.py
+3
-18
tensorpack/train/predict.py
tensorpack/train/predict.py
+5
-4
No files found.
examples/DoReFa-Net/alexnet-dorefa.py
View file @
fe2b4f97
...
@@ -255,7 +255,6 @@ def run_image(model, sess_init, inputs):
...
@@ -255,7 +255,6 @@ def run_image(model, sess_init, inputs):
pred_config
=
PredictConfig
(
pred_config
=
PredictConfig
(
model
=
model
,
model
=
model
,
session_init
=
sess_init
,
session_init
=
sess_init
,
session_config
=
get_default_sess_config
(
0.9
),
input_names
=
[
'input'
],
input_names
=
[
'input'
],
output_names
=
[
'output'
]
output_names
=
[
'output'
]
)
)
...
...
examples/DoReFa-Net/resnet-dorefa.py
View file @
fe2b4f97
...
@@ -125,7 +125,6 @@ def run_image(model, sess_init, inputs):
...
@@ -125,7 +125,6 @@ def run_image(model, sess_init, inputs):
pred_config
=
PredictConfig
(
pred_config
=
PredictConfig
(
model
=
model
,
model
=
model
,
session_init
=
sess_init
,
session_init
=
sess_init
,
session_config
=
get_default_sess_config
(
0.9
),
input_names
=
[
'input'
],
input_names
=
[
'input'
],
output_names
=
[
'output'
]
output_names
=
[
'output'
]
)
)
...
...
tensorpack/callbacks/base.py
View file @
fe2b4f97
...
@@ -5,14 +5,14 @@
...
@@ -5,14 +5,14 @@
import
tensorflow
as
tf
import
tensorflow
as
tf
from
abc
import
ABCMeta
,
abstractmethod
from
abc
import
ABCMeta
,
abstractmethod
import
six
import
six
from
..tfutils.common
import
get_op_or_tensor_by_name
,
get_global_step_value
from
..tfutils.common
import
get_op_or_tensor_by_name
__all__
=
[
'Callback'
,
'ProxyCallback'
,
'CallbackFactory'
,
'Triggerable'
]
__all__
=
[
'Callback'
,
'ProxyCallback'
,
'CallbackFactory'
,
'Triggerable'
]
@
six
.
add_metaclass
(
ABCMeta
)
@
six
.
add_metaclass
(
ABCMeta
)
class
Callback
(
object
):
class
Callback
(
object
):
""" Base class for all callbacks
""" Base class for all callbacks
.
Attributes:
Attributes:
epoch_num(int): the number of the current epoch.
epoch_num(int): the number of the current epoch.
...
@@ -50,7 +50,6 @@ class Callback(object):
...
@@ -50,7 +50,6 @@ class Callback(object):
pass
pass
def
before_train
(
self
):
def
before_train
(
self
):
self
.
_starting_step
=
get_global_step_value
()
self
.
_before_train
()
self
.
_before_train
()
def
_before_train
(
self
):
def
_before_train
(
self
):
...
...
tensorpack/dataflow/prefetch.py
View file @
fe2b4f97
...
@@ -154,6 +154,7 @@ class PrefetchDataZMQ(ProxyDataFlow):
...
@@ -154,6 +154,7 @@ class PrefetchDataZMQ(ProxyDataFlow):
yield
dp
yield
dp
except
zmq
.
ContextTerminated
:
except
zmq
.
ContextTerminated
:
logger
.
info
(
"ContextTerminated in Master Prefetch Process"
)
logger
.
info
(
"ContextTerminated in Master Prefetch Process"
)
return
except
:
except
:
raise
raise
...
...
tensorpack/predict/base.py
View file @
fe2b4f97
...
@@ -88,7 +88,8 @@ class AsyncPredictorBase(PredictorBase):
...
@@ -88,7 +88,8 @@ class AsyncPredictorBase(PredictorBase):
class
OnlinePredictor
(
PredictorBase
):
class
OnlinePredictor
(
PredictorBase
):
""" A predictor which directly use an existing session. """
""" A predictor which directly use an existing session and given tensors.
"""
def
__init__
(
self
,
input_tensors
,
output_tensors
,
def
__init__
(
self
,
input_tensors
,
output_tensors
,
return_input
=
False
,
sess
=
None
):
return_input
=
False
,
sess
=
None
):
...
@@ -131,13 +132,13 @@ class OfflinePredictor(OnlinePredictor):
...
@@ -131,13 +132,13 @@ class OfflinePredictor(OnlinePredictor):
with
TowerContext
(
''
,
False
):
with
TowerContext
(
''
,
False
):
config
.
model
.
build_graph
(
input_placehdrs
)
config
.
model
.
build_graph
(
input_placehdrs
)
input_
va
rs
=
get_tensors_by_names
(
config
.
input_names
)
input_
tenso
rs
=
get_tensors_by_names
(
config
.
input_names
)
output_
va
rs
=
get_tensors_by_names
(
config
.
output_names
)
output_
tenso
rs
=
get_tensors_by_names
(
config
.
output_names
)
sess
=
tf
.
Session
(
config
=
config
.
session_config
)
sess
=
config
.
session_creator
.
create_session
(
)
config
.
session_init
.
init
(
sess
)
config
.
session_init
.
init
(
sess
)
super
(
OfflinePredictor
,
self
)
.
__init__
(
super
(
OfflinePredictor
,
self
)
.
__init__
(
input_
vars
,
output_va
rs
,
config
.
return_input
,
sess
)
input_
tensors
,
output_tenso
rs
,
config
.
return_input
,
sess
)
def
get_predict_func
(
config
):
def
get_predict_func
(
config
):
...
@@ -149,6 +150,8 @@ def get_predict_func(config):
...
@@ -149,6 +150,8 @@ def get_predict_func(config):
def
build_prediction_graph
(
build_tower_fn
,
towers
=
[
0
],
prefix
=
''
):
def
build_prediction_graph
(
build_tower_fn
,
towers
=
[
0
],
prefix
=
''
):
"""
"""
Build graph on each tower.
Args:
Args:
build_tower_fn: a function that will be called inside each tower,
build_tower_fn: a function that will be called inside each tower,
taking tower id as the argument.
taking tower id as the argument.
...
...
tensorpack/predict/concurrency.py
View file @
fe2b4f97
...
@@ -8,7 +8,7 @@ import six
...
@@ -8,7 +8,7 @@ import six
from
six.moves
import
queue
,
range
from
six.moves
import
queue
,
range
import
tensorflow
as
tf
import
tensorflow
as
tf
from
..utils
import
logger
from
..utils
import
logger
,
deprecated
from
..utils.concurrency
import
DIE
,
StoppableThread
,
ShareSessionThread
from
..utils.concurrency
import
DIE
,
StoppableThread
,
ShareSessionThread
from
..tfutils.modelutils
import
describe_model
from
..tfutils.modelutils
import
describe_model
from
.base
import
OnlinePredictor
,
OfflinePredictor
,
AsyncPredictorBase
from
.base
import
OnlinePredictor
,
OfflinePredictor
,
AsyncPredictorBase
...
@@ -27,6 +27,7 @@ class MultiProcessPredictWorker(multiprocessing.Process):
...
@@ -27,6 +27,7 @@ class MultiProcessPredictWorker(multiprocessing.Process):
config (PredictConfig): the config to use.
config (PredictConfig): the config to use.
"""
"""
super
(
MultiProcessPredictWorker
,
self
)
.
__init__
()
super
(
MultiProcessPredictWorker
,
self
)
.
__init__
()
self
.
name
=
"MultiProcessPredictWorker-{}"
.
format
(
idx
)
self
.
idx
=
idx
self
.
idx
=
idx
self
.
config
=
config
self
.
config
=
config
...
@@ -76,6 +77,7 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
...
@@ -76,6 +77,7 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
class
PredictorWorkerThread
(
StoppableThread
,
ShareSessionThread
):
class
PredictorWorkerThread
(
StoppableThread
,
ShareSessionThread
):
def
__init__
(
self
,
queue
,
pred_func
,
id
,
batch_size
=
5
):
def
__init__
(
self
,
queue
,
pred_func
,
id
,
batch_size
=
5
):
super
(
PredictorWorkerThread
,
self
)
.
__init__
()
super
(
PredictorWorkerThread
,
self
)
.
__init__
()
self
.
name
=
"PredictorWorkerThread-{}"
.
format
(
id
)
self
.
queue
=
queue
self
.
queue
=
queue
self
.
func
=
pred_func
self
.
func
=
pred_func
self
.
daemon
=
True
self
.
daemon
=
True
...
@@ -112,22 +114,20 @@ class PredictorWorkerThread(StoppableThread, ShareSessionThread):
...
@@ -112,22 +114,20 @@ class PredictorWorkerThread(StoppableThread, ShareSessionThread):
for
k
in
range
(
nr_input_var
):
for
k
in
range
(
nr_input_var
):
batched
[
k
]
.
append
(
inp
[
k
])
batched
[
k
]
.
append
(
inp
[
k
])
futures
.
append
(
f
)
futures
.
append
(
f
)
cnt
=
1
while
len
(
futures
)
<
self
.
batch_size
:
while
cnt
<
self
.
batch_size
:
try
:
try
:
inp
,
f
=
self
.
queue
.
get_nowait
()
inp
,
f
=
self
.
queue
.
get_nowait
()
for
k
in
range
(
nr_input_var
):
for
k
in
range
(
nr_input_var
):
batched
[
k
]
.
append
(
inp
[
k
])
batched
[
k
]
.
append
(
inp
[
k
])
futures
.
append
(
f
)
futures
.
append
(
f
)
except
queue
.
Empty
:
except
queue
.
Empty
:
break
break
# do not wait
cnt
+=
1
return
batched
,
futures
return
batched
,
futures
class
MultiThreadAsyncPredictor
(
AsyncPredictorBase
):
class
MultiThreadAsyncPredictor
(
AsyncPredictorBase
):
"""
"""
An multithread online async predictor which runs a list of
PredictorBase
.
An multithread online async predictor which runs a list of
OnlinePredictor
.
It would do an extra batching internally.
It would do an extra batching internally.
"""
"""
...
@@ -164,6 +164,7 @@ class MultiThreadAsyncPredictor(AsyncPredictorBase):
...
@@ -164,6 +164,7 @@ class MultiThreadAsyncPredictor(AsyncPredictorBase):
for
t
in
self
.
threads
:
for
t
in
self
.
threads
:
t
.
start
()
t
.
start
()
@
deprecated
(
"Use 'start()' instead!"
,
"2017-03-11"
)
def
run
(
self
):
# temporarily for back-compatibility
def
run
(
self
):
# temporarily for back-compatibility
self
.
start
()
self
.
start
()
...
...
tensorpack/predict/config.py
View file @
fe2b4f97
...
@@ -5,50 +5,60 @@
...
@@ -5,50 +5,60 @@
import
six
import
six
from
..models
import
ModelDesc
from
..models
import
ModelDesc
from
..utils
import
log_deprecated
from
..tfutils
import
get_default_sess_config
from
..tfutils
import
get_default_sess_config
from
..tfutils.sessinit
import
SessionInit
,
JustCurrentSession
from
..tfutils.sessinit
import
SessionInit
,
JustCurrentSession
from
..tfutils.sesscreate
import
NewSession
__all__
=
[
'PredictConfig'
]
__all__
=
[
'PredictConfig'
]
class
PredictConfig
(
object
):
class
PredictConfig
(
object
):
def
__init__
(
self
,
model
,
session_init
=
None
,
def
__init__
(
self
,
model
,
session_config
=
get_default_sess_config
(
0.4
),
session_creator
=
None
,
session_init
=
None
,
session_config
=
None
,
input_names
=
None
,
input_names
=
None
,
output_names
=
None
,
output_names
=
None
,
return_input
=
False
):
return_input
=
False
):
"""
"""
Args:
Args:
model (ModelDesc): the model to use.
model (ModelDesc): the model to use.
session_creator (tf.train.SessionCreator): how to create the
session. Defaults to :class:`sesscreate.NewSession()`.
session_init (SessionInit): how to initialize variables of the session.
session_init (SessionInit): how to initialize variables of the session.
Defaults to do nothing.
Defaults to do nothing.
session_config]
input_names (list): a list of input tensor names. Defaults to all
input_names (list): a list of input tensor names. Defaults to all
inputs of the model.
inputs of the model.
output_names (list): a list of names of the output tensors to predict, the
output_names (list): a list of names of the output tensors to predict, the
tensors can be any computable tensor in the graph.
tensors can be any computable tensor in the graph.
return_input: same as in :attr:`PredictorBase.return_input`.
return_input: same as in :attr:`PredictorBase.return_input`.
"""
"""
# TODO use the name "tensor" instead of "variable"
def
assert_type
(
v
,
tp
):
def
assert_type
(
v
,
tp
):
assert
isinstance
(
v
,
tp
),
v
.
__class__
assert
isinstance
(
v
,
tp
),
v
.
__class__
self
.
model
=
model
self
.
model
=
model
assert_type
(
self
.
model
,
ModelDesc
)
assert_type
(
self
.
model
,
ModelDesc
)
# XXX does it work? start with minimal memory, but allow growth.
# allow_growth doesn't seem to work very well in TF.
self
.
session_config
=
session_config
if
session_init
is
None
:
if
session_init
is
None
:
session_init
=
JustCurrentSession
()
session_init
=
JustCurrentSession
()
self
.
session_init
=
session_init
self
.
session_init
=
session_init
assert_type
(
self
.
session_init
,
SessionInit
)
assert_type
(
self
.
session_init
,
SessionInit
)
if
session_creator
is
None
:
if
session_config
is
not
None
:
log_deprecated
(
"PredictConfig(session_config=)"
,
"Use session_creator instead!"
,
"2017-04-20"
)
self
.
session_creator
=
NewSession
(
config
=
session_config
)
else
:
self
.
session_creator
=
NewSession
(
config
=
get_default_sess_config
(
0.4
))
else
:
self
.
session_creator
=
session_creator
# inputs & outputs
# inputs & outputs
self
.
input_names
=
input_names
self
.
input_names
=
input_names
if
self
.
input_names
is
None
:
if
self
.
input_names
is
None
:
# neither options is set, assume all inputs
# neither options is set, assume all inputs
raw_
va
rs
=
self
.
model
.
get_inputs_desc
()
raw_
tenso
rs
=
self
.
model
.
get_inputs_desc
()
self
.
input_names
=
[
k
.
name
for
k
in
raw_
va
rs
]
self
.
input_names
=
[
k
.
name
for
k
in
raw_
tenso
rs
]
self
.
output_names
=
output_names
self
.
output_names
=
output_names
assert_type
(
self
.
output_names
,
list
)
assert_type
(
self
.
output_names
,
list
)
assert_type
(
self
.
input_names
,
list
)
assert_type
(
self
.
input_names
,
list
)
...
...
tensorpack/predict/multigpu.py
View file @
fe2b4f97
...
@@ -5,8 +5,6 @@
...
@@ -5,8 +5,6 @@
import
tensorflow
as
tf
import
tensorflow
as
tf
from
..utils
import
logger
from
..utils.naming
import
PREDICT_TOWER
from
..tfutils
import
get_tensors_by_names
,
TowerContext
from
..tfutils
import
get_tensors_by_names
,
TowerContext
from
.base
import
OnlinePredictor
,
build_prediction_graph
from
.base
import
OnlinePredictor
,
build_prediction_graph
...
@@ -31,17 +29,17 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
...
@@ -31,17 +29,17 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
config
.
model
.
build_graph
(
config
.
model
.
get_reused_placehdrs
())
config
.
model
.
build_graph
(
config
.
model
.
get_reused_placehdrs
())
build_prediction_graph
(
fn
,
towers
)
build_prediction_graph
(
fn
,
towers
)
self
.
sess
=
tf
.
Session
(
config
=
config
.
session_config
)
self
.
sess
=
config
.
session_creator
.
create_session
(
)
config
.
session_init
.
init
(
self
.
sess
)
config
.
session_init
.
init
(
self
.
sess
)
input_
va
rs
=
get_tensors_by_names
(
config
.
input_names
)
input_
tenso
rs
=
get_tensors_by_names
(
config
.
input_names
)
for
k
in
towers
:
for
k
in
towers
:
output_
va
rs
=
get_tensors_by_names
(
output_
tenso
rs
=
get_tensors_by_names
(
[
'{}{}/'
.
format
(
PREDICT_TOWER
,
k
)
+
n
[
TowerContext
.
get_predict_towre_name
(
''
,
k
)
+
'/'
+
n
for
n
in
config
.
output_names
])
for
n
in
config
.
output_names
])
self
.
predictors
.
append
(
OnlinePredictor
(
self
.
predictors
.
append
(
OnlinePredictor
(
input_
vars
,
output_va
rs
,
config
.
return_input
,
self
.
sess
))
input_
tensors
,
output_tenso
rs
,
config
.
return_input
,
self
.
sess
))
def
_do_call
(
self
,
dp
):
def
_do_call
(
self
,
dp
):
# use the first tower for compatible PredictorBase interface
# use the first tower for compatible PredictorBase interface
...
@@ -57,7 +55,9 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
...
@@ -57,7 +55,9 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
class
DataParallelOfflinePredictor
(
OnlinePredictor
):
class
DataParallelOfflinePredictor
(
OnlinePredictor
):
""" A data-parallel predictor.
""" A data-parallel predictor.
It runs different towers in parallel.
Its input is: [input[0] in tower[0], input[1] in tower[0], ...,
input[0] in tower[1], input[1] in tower[1], ...]
And same for the output.
"""
"""
def
__init__
(
self
,
config
,
towers
):
def
__init__
(
self
,
config
,
towers
):
...
@@ -68,26 +68,25 @@ class DataParallelOfflinePredictor(OnlinePredictor):
...
@@ -68,26 +68,25 @@ class DataParallelOfflinePredictor(OnlinePredictor):
"""
"""
self
.
graph
=
tf
.
Graph
()
self
.
graph
=
tf
.
Graph
()
with
self
.
graph
.
as_default
():
with
self
.
graph
.
as_default
():
sess
=
tf
.
Session
(
config
=
config
.
session_config
)
input_names
=
[]
input_var_names
=
[]
output_tensors
=
[]
output_vars
=
[]
for
idx
,
k
in
enumerate
(
towers
):
def
build_tower
(
k
):
towername
=
PREDICT_TOWER
+
str
(
k
)
towername
=
TowerContext
.
get_predict_tower_name
(
k
)
input_vars
=
config
.
model
.
build_placeholders
(
# inputs (placeholders) for this tower only
prefix
=
towername
+
'-'
)
input_tensors
=
config
.
model
.
build_placeholders
(
prefix
=
towername
+
'/'
)
logger
.
info
(
config
.
model
.
build_graph
(
input_tensors
)
"Building graph for predictor tower {}..."
.
format
(
k
))
with
tf
.
device
(
'/gpu:{}'
.
format
(
k
)
if
k
>=
0
else
'/cpu:0'
),
\
input_names
.
extend
([
t
.
name
for
t
in
input_tensors
])
TowerContext
(
towername
,
is_training
=
False
),
\
output_tensors
.
extend
(
get_tensors_by_names
(
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
True
if
idx
>
0
else
None
):
config
.
model
.
build_graph
(
input_vars
)
input_var_names
.
extend
([
k
.
name
for
k
in
input_vars
])
output_vars
.
extend
(
get_tensors_by_names
(
[
towername
+
'/'
+
n
[
towername
+
'/'
+
n
for
n
in
config
.
output_names
]))
for
n
in
config
.
output_names
]))
input_vars
=
get_tensors_by_names
(
input_var_names
)
build_prediction_graph
(
build_tower
,
towers
)
input_tensors
=
get_tensors_by_names
(
input_names
)
sess
=
config
.
session_creator
.
create_session
()
config
.
session_init
.
init
(
sess
)
config
.
session_init
.
init
(
sess
)
super
(
DataParallelOfflinePredictor
,
self
)
.
__init__
(
super
(
DataParallelOfflinePredictor
,
self
)
.
__init__
(
input_
vars
,
output_va
rs
,
config
.
return_input
,
sess
)
input_
tensors
,
output_tenso
rs
,
config
.
return_input
,
sess
)
tensorpack/tfutils/sesscreate.py
0 → 100644
View file @
fe2b4f97
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: sesscreate.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import
tensorflow
as
tf
__all__
=
[
'NewSession'
,
'ReuseSession'
]
class
NewSession
(
tf
.
train
.
SessionCreator
):
def
__init__
(
self
,
target
=
''
,
graph
=
None
,
config
=
None
):
"""
Args:
target, graph, config: same as :meth:`Session.__init__()`.
"""
self
.
target
=
target
self
.
config
=
config
self
.
graph
=
graph
def
create_session
(
self
):
return
tf
.
Session
(
target
=
self
.
target
,
graph
=
self
.
graph
,
config
=
self
.
config
)
class
ReuseSession
(
tf
.
train
.
SessionCreator
):
def
__init__
(
self
,
sess
):
"""
Args:
sess (tf.Session): the session to reuse
"""
self
.
sess
=
sess
def
create_session
(
self
):
return
self
.
sess
tensorpack/tfutils/sessinit.py
View file @
fe2b4f97
...
@@ -12,13 +12,13 @@ from .common import get_op_tensor_name
...
@@ -12,13 +12,13 @@ from .common import get_op_tensor_name
from
.varmanip
import
(
SessionUpdate
,
get_savename_from_varname
,
from
.varmanip
import
(
SessionUpdate
,
get_savename_from_varname
,
is_training_name
,
get_checkpoint_path
)
is_training_name
,
get_checkpoint_path
)
__all__
=
[
'SessionInit'
,
'
NewSession'
,
'
SaverRestore'
,
'SaverRestoreRelaxed'
,
__all__
=
[
'SessionInit'
,
'SaverRestore'
,
'SaverRestoreRelaxed'
,
'ParamRestore'
,
'ChainInit'
,
'ParamRestore'
,
'ChainInit'
,
'JustCurrentSession'
,
'get_model_loader'
]
'JustCurrentSession'
,
'get_model_loader'
]
class
SessionInit
(
object
):
class
SessionInit
(
object
):
""" Base class for utilities to initialize a session. """
""" Base class for utilities to initialize a
(existing)
session. """
def
init
(
self
,
sess
):
def
init
(
self
,
sess
):
"""
"""
Initialize a session
Initialize a session
...
@@ -44,17 +44,6 @@ class JustCurrentSession(SessionInit):
...
@@ -44,17 +44,6 @@ class JustCurrentSession(SessionInit):
pass
pass
class
NewSession
(
SessionInit
):
"""
Initialize global variables by their initializer.
"""
def
_setup_graph
(
self
):
self
.
op
=
tf
.
global_variables_initializer
()
def
_run_init
(
self
,
sess
):
sess
.
run
(
self
.
op
)
class
CheckpointReaderAdapter
(
object
):
class
CheckpointReaderAdapter
(
object
):
"""
"""
An adapter to work around old checkpoint format, where the keys are op
An adapter to work around old checkpoint format, where the keys are op
...
@@ -207,15 +196,11 @@ class ChainInit(SessionInit):
...
@@ -207,15 +196,11 @@ class ChainInit(SessionInit):
to form a composition of models.
to form a composition of models.
"""
"""
def
__init__
(
self
,
sess_inits
,
new_session
=
True
):
def
__init__
(
self
,
sess_inits
):
"""
"""
Args:
Args:
sess_inits (list[SessionInit]): list of :class:`SessionInit` instances.
sess_inits (list[SessionInit]): list of :class:`SessionInit` instances.
new_session (bool): add a ``NewSession()`` and the beginning, if
not there.
"""
"""
if
new_session
and
not
isinstance
(
sess_inits
[
0
],
NewSession
):
sess_inits
.
insert
(
0
,
NewSession
())
self
.
inits
=
sess_inits
self
.
inits
=
sess_inits
def
_init
(
self
,
sess
):
def
_init
(
self
,
sess
):
...
...
tensorpack/train/predict.py
View file @
fe2b4f97
...
@@ -14,7 +14,7 @@ __all__ = ['PredictorFactory']
...
@@ -14,7 +14,7 @@ __all__ = ['PredictorFactory']
class
PredictorFactory
(
object
):
class
PredictorFactory
(
object
):
""" Make predictors f
or a trainer
"""
""" Make predictors f
rom a trainer.
"""
def
__init__
(
self
,
trainer
):
def
__init__
(
self
,
trainer
):
"""
"""
...
@@ -25,6 +25,7 @@ class PredictorFactory(object):
...
@@ -25,6 +25,7 @@ class PredictorFactory(object):
self
.
towers
=
trainer
.
config
.
predict_tower
self
.
towers
=
trainer
.
config
.
predict_tower
assert
isinstance
(
self
.
towers
,
list
)
assert
isinstance
(
self
.
towers
,
list
)
# TODO sess option
def
get_predictor
(
self
,
input_names
,
output_names
,
tower
):
def
get_predictor
(
self
,
input_names
,
output_names
,
tower
):
"""
"""
Args:
Args:
...
@@ -48,11 +49,11 @@ class PredictorFactory(object):
...
@@ -48,11 +49,11 @@ class PredictorFactory(object):
return
get_name_in_tower
(
name
)
return
get_name_in_tower
(
name
)
input_names
=
map
(
maybe_inside_tower
,
input_names
)
input_names
=
map
(
maybe_inside_tower
,
input_names
)
raw_input_
va
rs
=
get_tensors_by_names
(
input_names
)
raw_input_
tenso
rs
=
get_tensors_by_names
(
input_names
)
output_names
=
map
(
get_name_in_tower
,
output_names
)
output_names
=
map
(
get_name_in_tower
,
output_names
)
output_
va
rs
=
get_tensors_by_names
(
output_names
)
output_
tenso
rs
=
get_tensors_by_names
(
output_names
)
return
OnlinePredictor
(
raw_input_
vars
,
output_va
rs
)
return
OnlinePredictor
(
raw_input_
tensors
,
output_tenso
rs
)
@
memoized
@
memoized
def
_build_predict_tower
(
self
):
def
_build_predict_tower
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment