Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
e61946b2
Commit
e61946b2
authored
Jul 01, 2018
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Remove SimplePredictBuilder
parent
25b31f68
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
81 additions
and
64 deletions
+81
-64
docs/conf.py
docs/conf.py
+29
-25
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+29
-24
tensorpack/graph_builder/predict.py
tensorpack/graph_builder/predict.py
+2
-0
tensorpack/input_source/input_source_base.py
tensorpack/input_source/input_source_base.py
+1
-0
tensorpack/predict/multigpu.py
tensorpack/predict/multigpu.py
+14
-9
tensorpack/train/tower.py
tensorpack/train/tower.py
+6
-6
No files found.
docs/conf.py
View file @
e61946b2
...
@@ -353,20 +353,12 @@ def process_signature(app, what, name, obj, options, signature,
...
@@ -353,20 +353,12 @@ def process_signature(app, what, name, obj, options, signature,
# signature: arg list
# signature: arg list
return
signature
,
return_annotation
return
signature
,
return_annotation
def
autodoc_skip_member
(
app
,
what
,
name
,
obj
,
skip
,
options
):
# we hide something deliberately
_DEPRECATED_NAMES
=
set
([
if
getattr
(
obj
,
'__HIDE_SPHINX_DOC__'
,
False
):
return
True
if
name
==
'__init__'
:
if
obj
.
__doc__
and
skip
:
# include_init_with_doc doesn't work well for decorated init
# https://github.com/sphinx-doc/sphinx/issues/4258
return
False
# Hide some names that are deprecated or not intended to be used
if
name
in
[
# deprecated stuff:
# deprecated stuff:
'TryResumeTraining'
,
'TryResumeTraining'
,
'QueueInputTrainer'
,
'QueueInputTrainer'
,
'SimplePredictBuilder'
,
# renamed stuff:
# renamed stuff:
'DumpTensor'
,
'DumpTensor'
,
...
@@ -387,7 +379,19 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
...
@@ -387,7 +379,19 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
'aggregate_grads'
,
'aggregate_grads'
,
'allreduce_grads'
,
'allreduce_grads'
,
'PrefetchOnGPUs'
,
'PrefetchOnGPUs'
,
]:
])
def
autodoc_skip_member
(
app
,
what
,
name
,
obj
,
skip
,
options
):
# we hide something deliberately
if
getattr
(
obj
,
'__HIDE_SPHINX_DOC__'
,
False
):
return
True
if
name
==
'__init__'
:
if
obj
.
__doc__
and
skip
:
# include_init_with_doc doesn't work well for decorated init
# https://github.com/sphinx-doc/sphinx/issues/4258
return
False
# Hide some names that are deprecated or not intended to be used
if
name
in
_DEPRECATED_NAMES
:
return
True
return
True
if
name
in
[
'get_data'
,
'size'
,
'reset_state'
]:
if
name
in
[
'get_data'
,
'size'
,
'reset_state'
]:
# skip these methods with empty docstring
# skip these methods with empty docstring
...
...
tensorpack/callbacks/inference_runner.py
View file @
e61946b2
...
@@ -15,10 +15,10 @@ from six.moves import range
...
@@ -15,10 +15,10 @@ from six.moves import range
from
..utils
import
logger
from
..utils
import
logger
from
..utils.utils
import
get_tqdm_kwargs
from
..utils.utils
import
get_tqdm_kwargs
from
..dataflow.base
import
DataFlow
from
..dataflow.base
import
DataFlow
from
..tfutils.tower
import
PredictTowerContext
from
..input_source
import
(
from
..input_source
import
(
InputSource
,
FeedInput
,
QueueInput
,
StagingInput
)
InputSource
,
FeedInput
,
QueueInput
,
StagingInput
)
from
..graph_builder.predict
import
SimplePredictBuilder
from
.base
import
Callback
from
.base
import
Callback
from
.group
import
Callbacks
from
.group
import
Callbacks
...
@@ -28,6 +28,10 @@ __all__ = ['InferenceRunnerBase', 'InferenceRunner',
...
@@ -28,6 +28,10 @@ __all__ = ['InferenceRunnerBase', 'InferenceRunner',
'DataParallelInferenceRunner'
]
'DataParallelInferenceRunner'
]
def
_device_from_int
(
dev
):
return
'/gpu:{}'
.
format
(
dev
)
if
dev
>=
0
else
'/cpu:0'
class
InferencerToHook
(
tf
.
train
.
SessionRunHook
):
class
InferencerToHook
(
tf
.
train
.
SessionRunHook
):
def
__init__
(
self
,
inf
,
fetches
):
def
__init__
(
self
,
inf
,
fetches
):
self
.
_inf
=
inf
self
.
_inf
=
inf
...
@@ -94,9 +98,9 @@ class InferenceRunnerBase(Callback):
...
@@ -94,9 +98,9 @@ class InferenceRunnerBase(Callback):
self
.
_hooked_sess
=
HookedSession
(
self
.
trainer
.
sess
,
self
.
_hooks
)
self
.
_hooked_sess
=
HookedSession
(
self
.
trainer
.
sess
,
self
.
_hooks
)
self
.
_input_callbacks
.
before_train
()
self
.
_input_callbacks
.
before_train
()
if
self
.
_size
>
0
:
if
self
.
_size
>
0
:
logger
.
info
(
"
InferenceRunner w
ill eval {} iterations"
.
format
(
self
.
_size
))
logger
.
info
(
"
[InferenceRunner] W
ill eval {} iterations"
.
format
(
self
.
_size
))
else
:
else
:
logger
.
warn
(
"
InferenceRunner got an input with unknown size! It w
ill iterate until OutOfRangeError!"
)
logger
.
warn
(
"
[InferenceRunner] Got an InputSource with unknown size! W
ill iterate until OutOfRangeError!"
)
def
_after_train
(
self
):
def
_after_train
(
self
):
self
.
_input_callbacks
.
after_train
()
self
.
_input_callbacks
.
after_train
()
...
@@ -122,7 +126,7 @@ class InferenceRunner(InferenceRunnerBase):
...
@@ -122,7 +126,7 @@ class InferenceRunner(InferenceRunnerBase):
assert
isinstance
(
input
,
InputSource
),
input
assert
isinstance
(
input
,
InputSource
),
input
assert
not
isinstance
(
input
,
StagingInput
),
input
assert
not
isinstance
(
input
,
StagingInput
),
input
self
.
_tower_name
=
tower_name
self
.
_tower_name
=
tower_name
self
.
_device
=
device
self
.
_device
=
_device_from_int
(
device
)
super
(
InferenceRunner
,
self
)
.
__init__
(
input
,
infs
)
super
(
InferenceRunner
,
self
)
.
__init__
(
input
,
infs
)
def
_build_hook
(
self
,
inf
):
def
_build_hook
(
self
,
inf
):
...
@@ -131,16 +135,17 @@ class InferenceRunner(InferenceRunnerBase):
...
@@ -131,16 +135,17 @@ class InferenceRunner(InferenceRunnerBase):
return
InferencerToHook
(
inf
,
fetches
)
return
InferencerToHook
(
inf
,
fetches
)
def
_setup_graph
(
self
):
def
_setup_graph
(
self
):
device
=
self
.
_device
assert
self
.
trainer
.
tower_func
is
not
None
,
"You must set tower_func of the trainer to use InferenceRunner!"
assert
self
.
trainer
.
tower_func
is
not
None
,
"You must set tower_func of the trainer to use InferenceRunner!"
input_callbacks
=
self
.
_input_source
.
setup
(
self
.
trainer
.
inputs_desc
)
tower_func
=
self
.
trainer
.
tower_func
input_callbacks
=
self
.
_input_source
.
setup
(
tower_func
.
inputs_desc
)
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
True
):
logger
.
info
(
"[InferenceRunner] Building tower '{}' on device {} ..."
.
format
(
self
.
_tower_name
,
self
.
_device
))
SimplePredictBuilder
(
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
True
),
\
ns_name
=
self
.
_tower_name
,
tf
.
device
(
self
.
_device
),
\
vs_name
=
self
.
trainer
.
_main_tower_vs_name
,
device
=
device
)
.
build
(
PredictTowerContext
(
self
.
_input_source
,
self
.
trainer
.
tower_func
)
self
.
_tower_name
,
vs_name
=
self
.
trainer
.
_main_tower_vs_name
):
self
.
_tower_handle
=
self
.
trainer
.
tower_func
.
towers
[
-
1
]
tower_func
(
*
self
.
_input_source
.
get_input_tensors
())
self
.
_tower_handle
=
tower_func
.
towers
[
-
1
]
for
h
in
[
self
.
_build_hook
(
inf
)
for
inf
in
self
.
infs
]:
for
h
in
[
self
.
_build_hook
(
inf
)
for
inf
in
self
.
infs
]:
self
.
register_hook
(
h
)
self
.
register_hook
(
h
)
...
@@ -178,7 +183,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
...
@@ -178,7 +183,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
It will run the remainder (when the total size of input is not a multiple of #GPU)
It will run the remainder (when the total size of input is not a multiple of #GPU)
sequentially.
sequentially.
"""
"""
def
__init__
(
self
,
input
,
infs
,
gpus
):
def
__init__
(
self
,
input
,
infs
,
gpus
,
tower_name
=
'InferenceTower'
):
"""
"""
Args:
Args:
input (DataFlow or QueueInput)
input (DataFlow or QueueInput)
...
@@ -186,13 +191,14 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
...
@@ -186,13 +191,14 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
"""
"""
if
isinstance
(
gpus
,
int
):
if
isinstance
(
gpus
,
int
):
gpus
=
list
(
range
(
gpus
))
gpus
=
list
(
range
(
gpus
))
self
.
_tower_names
=
[
'InferenceTower{}'
.
format
(
k
)
for
k
in
range
(
len
(
gpus
))]
self
.
_devices
=
[
_device_from_int
(
k
)
for
k
in
gpus
]
self
.
_tower_names
=
[
'{}{}'
.
format
(
tower_name
,
k
)
for
k
in
range
(
len
(
gpus
))]
if
isinstance
(
input
,
DataFlow
):
if
isinstance
(
input
,
DataFlow
):
input
=
QueueInput
(
input
)
input
=
QueueInput
(
input
)
assert
isinstance
(
input
,
QueueInput
),
input
assert
isinstance
(
input
,
QueueInput
),
input
super
(
DataParallelInferenceRunner
,
self
)
.
__init__
(
input
,
infs
)
super
(
DataParallelInferenceRunner
,
self
)
.
__init__
(
input
,
infs
)
assert
self
.
_size
>
0
,
"Input for DataParallelInferenceRunner must have a size!"
assert
self
.
_size
>
0
,
"Input for DataParallelInferenceRunner must have a size!"
self
.
_gpus
=
gpus
self
.
_hooks
=
[]
self
.
_hooks
=
[]
self
.
_hooks_parallel
=
[]
self
.
_hooks_parallel
=
[]
...
@@ -201,15 +207,14 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
...
@@ -201,15 +207,14 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
self
.
_handles
=
[]
self
.
_handles
=
[]
assert
self
.
trainer
.
tower_func
is
not
None
,
"You must set tower_func of the trainer to use InferenceRunner!"
assert
self
.
trainer
.
tower_func
is
not
None
,
"You must set tower_func of the trainer to use InferenceRunner!"
input_callbacks
=
self
.
_input_source
.
setup
(
self
.
trainer
.
inputs_desc
)
tower_func
=
self
.
trainer
.
tower_func
input_callbacks
=
self
.
_input_source
.
setup
(
tower_func
.
inputs_desc
)
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
True
):
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
True
):
for
idx
,
t
in
enumerate
(
self
.
_gpus
):
for
idx
,
dev
in
enumerate
(
self
.
_devices
):
tower_name
=
self
.
_tower_names
[
idx
]
with
tf
.
device
(
dev
),
PredictTowerContext
(
SimplePredictBuilder
(
self
.
_tower_names
[
idx
],
vs_name
=
self
.
trainer
.
_main_tower_vs_name
):
ns_name
=
tower_name
,
tower_func
(
*
self
.
_input_source
.
get_input_tensors
())
vs_name
=
self
.
trainer
.
_main_tower_vs_name
,
device
=
t
)
.
build
(
self
.
_handles
.
append
(
tower_func
.
towers
[
-
1
])
self
.
_input_source
,
self
.
trainer
.
tower_func
)
self
.
_handles
.
append
(
self
.
trainer
.
tower_func
.
towers
[
-
1
])
# setup callbacks and hooks
# setup callbacks and hooks
self
.
_input_callbacks
=
Callbacks
(
input_callbacks
)
self
.
_input_callbacks
=
Callbacks
(
input_callbacks
)
...
@@ -267,7 +272,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
...
@@ -267,7 +272,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
inf
.
before_epoch
()
inf
.
before_epoch
()
total
=
self
.
_size
total
=
self
.
_size
nr_tower
=
len
(
self
.
_
gpu
s
)
nr_tower
=
len
(
self
.
_
device
s
)
self
.
_input_source
.
reset_state
()
self
.
_input_source
.
reset_state
()
with
_inference_context
():
with
_inference_context
():
with
tqdm
.
tqdm
(
total
=
total
,
**
get_tqdm_kwargs
())
as
pbar
:
with
tqdm
.
tqdm
(
total
=
total
,
**
get_tqdm_kwargs
())
as
pbar
:
...
...
tensorpack/graph_builder/predict.py
View file @
e61946b2
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
import
tensorflow
as
tf
import
tensorflow
as
tf
from
..utils
import
logger
from
..utils
import
logger
from
..utils.develop
import
deprecated
from
..tfutils.tower
import
PredictTowerContext
from
..tfutils.tower
import
PredictTowerContext
from
.training
import
GraphBuilder
from
.training
import
GraphBuilder
...
@@ -14,6 +15,7 @@ class SimplePredictBuilder(GraphBuilder):
...
@@ -14,6 +15,7 @@ class SimplePredictBuilder(GraphBuilder):
"""
"""
Single-tower predictor.
Single-tower predictor.
"""
"""
@
deprecated
(
"Please use TowerContext to build it by yourself!"
,
"2018-12-31"
)
def
__init__
(
self
,
ns_name
=
''
,
vs_name
=
''
,
device
=
0
):
def
__init__
(
self
,
ns_name
=
''
,
vs_name
=
''
,
device
=
0
):
"""
"""
Args:
Args:
...
...
tensorpack/input_source/input_source_base.py
View file @
e61946b2
...
@@ -92,6 +92,7 @@ class InputSource(object):
...
@@ -92,6 +92,7 @@ class InputSource(object):
Returns:
Returns:
list[Callback]: extra callbacks needed by this InputSource.
list[Callback]: extra callbacks needed by this InputSource.
callbacks of InputSource cannot use any `trigger*()` method.
"""
"""
self
.
_setup
(
inputs_desc
)
self
.
_setup
(
inputs_desc
)
self
.
_setup_done
=
True
self
.
_setup_done
=
True
...
...
tensorpack/predict/multigpu.py
View file @
e61946b2
...
@@ -4,9 +4,9 @@
...
@@ -4,9 +4,9 @@
import
tensorflow
as
tf
import
tensorflow
as
tf
from
..utils
import
logger
from
..utils
import
logger
from
..graph_builder.predict
import
SimplePredictBuilder
from
..graph_builder.model_desc
import
InputDesc
from
..graph_builder.model_desc
import
InputDesc
from
..input_source
import
PlaceholderInput
from
..input_source
import
PlaceholderInput
from
..tfutils.tower
import
PredictTowerContext
from
.base
import
OnlinePredictor
from
.base
import
OnlinePredictor
__all__
=
[
'MultiTowerOfflinePredictor'
,
__all__
=
[
'MultiTowerOfflinePredictor'
,
...
@@ -14,7 +14,9 @@ __all__ = ['MultiTowerOfflinePredictor',
...
@@ -14,7 +14,9 @@ __all__ = ['MultiTowerOfflinePredictor',
class
MultiTowerOfflinePredictor
(
OnlinePredictor
):
class
MultiTowerOfflinePredictor
(
OnlinePredictor
):
""" A multi-tower multi-GPU predictor. """
""" A multi-tower multi-GPU predictor.
It builds one predictor for each tower.
"""
def
__init__
(
self
,
config
,
towers
):
def
__init__
(
self
,
config
,
towers
):
"""
"""
...
@@ -35,9 +37,10 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
...
@@ -35,9 +37,10 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
for
idx
,
t
in
enumerate
(
towers
):
for
idx
,
t
in
enumerate
(
towers
):
tower_name
=
'tower'
+
str
(
t
)
tower_name
=
'tower'
+
str
(
t
)
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
idx
>
0
):
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
idx
>
0
),
\
builder
=
SimplePredictBuilder
(
ns_name
=
tower_name
,
device
=
t
)
tf
.
device
(
'/gpu:{}'
.
format
(
t
)),
\
builder
.
build
(
input
,
config
.
tower_func
)
PredictTowerContext
(
tower_name
):
config
.
tower_func
(
*
input
.
get_input_tensors
())
handles
.
append
(
config
.
tower_func
.
towers
[
-
1
])
handles
.
append
(
config
.
tower_func
.
towers
[
-
1
])
self
.
sess
=
config
.
session_creator
.
create_session
()
self
.
sess
=
config
.
session_creator
.
create_session
()
...
@@ -73,7 +76,8 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
...
@@ -73,7 +76,8 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
class
DataParallelOfflinePredictor
(
OnlinePredictor
):
class
DataParallelOfflinePredictor
(
OnlinePredictor
):
"""
"""
A data-parallel predictor.
A data-parallel predictor. It builds one predictor that utilizes all GPUs.
Note that it doesn't split/concat inputs/outputs automatically.
Note that it doesn't split/concat inputs/outputs automatically.
Instead, its inputs are:
Instead, its inputs are:
``[input[0] in tower[0], input[1] in tower[0], ..., input[0] in tower[1], input[1] in tower[1], ...]``
``[input[0] in tower[0], input[1] in tower[0], ..., input[0] in tower[1], input[1] in tower[1], ...]``
...
@@ -99,9 +103,10 @@ class DataParallelOfflinePredictor(OnlinePredictor):
...
@@ -99,9 +103,10 @@ class DataParallelOfflinePredictor(OnlinePredictor):
input
=
PlaceholderInput
()
input
=
PlaceholderInput
()
input
.
setup
(
inputs_desc
)
input
.
setup
(
inputs_desc
)
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
idx
>
0
):
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
idx
>
0
),
\
builder
=
SimplePredictBuilder
(
ns_name
=
tower_name
,
device
=
t
)
tf
.
device
(
'/gpu:{}'
.
format
(
t
)),
\
builder
.
build
(
input
,
config
.
tower_func
)
PredictTowerContext
(
tower_name
):
config
.
tower_func
(
*
input
.
get_input_tensors
())
h
=
config
.
tower_func
.
towers
[
-
1
]
h
=
config
.
tower_func
.
towers
[
-
1
]
input_tensors
.
extend
(
h
.
get_tensors
(
config
.
input_names
))
input_tensors
.
extend
(
h
.
get_tensors
(
config
.
input_names
))
output_tensors
.
extend
(
h
.
get_tensors
(
config
.
output_names
))
output_tensors
.
extend
(
h
.
get_tensors
(
config
.
output_names
))
...
...
tensorpack/train/tower.py
View file @
e61946b2
...
@@ -6,11 +6,10 @@ import six
...
@@ -6,11 +6,10 @@ import six
from
abc
import
abstractmethod
,
ABCMeta
from
abc
import
abstractmethod
,
ABCMeta
from
..utils.argtools
import
call_only_once
,
memoized
from
..utils.argtools
import
call_only_once
,
memoized
from
..graph_builder.predict
import
SimplePredictBuilder
from
..input_source
import
PlaceholderInput
from
..input_source
import
PlaceholderInput
from
..predict.base
import
OnlinePredictor
from
..predict.base
import
OnlinePredictor
from
..tfutils.tower
import
TowerFuncWrapper
,
get_current_tower_context
from
..tfutils.tower
import
TowerFuncWrapper
,
get_current_tower_context
,
PredictTowerContext
from
..tfutils.gradproc
import
FilterNoneGrad
from
..tfutils.gradproc
import
FilterNoneGrad
from
.base
import
Trainer
from
.base
import
Trainer
...
@@ -94,6 +93,7 @@ class TowerTrainer(Trainer):
...
@@ -94,6 +93,7 @@ class TowerTrainer(Trainer):
"""
"""
assert
self
.
tower_func
is
not
None
,
"Must set tower_func on the trainer to use get_predictor()!"
assert
self
.
tower_func
is
not
None
,
"Must set tower_func on the trainer to use get_predictor()!"
tower_name
=
'tower-pred-{}'
.
format
(
device
)
if
device
>=
0
else
'tower-pred-cpu'
tower_name
=
'tower-pred-{}'
.
format
(
device
)
if
device
>=
0
else
'tower-pred-cpu'
device
=
'/gpu:{}'
.
format
(
device
)
if
device
>=
0
else
'/cpu:0'
try
:
try
:
tower
=
self
.
tower_func
.
towers
[
tower_name
]
tower
=
self
.
tower_func
.
towers
[
tower_name
]
...
@@ -105,10 +105,10 @@ class TowerTrainer(Trainer):
...
@@ -105,10 +105,10 @@ class TowerTrainer(Trainer):
input
=
PlaceholderInput
()
input
=
PlaceholderInput
()
input
.
setup
(
self
.
inputs_desc
)
input
.
setup
(
self
.
inputs_desc
)
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
True
)
:
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
True
)
,
\
SimplePredictBuilder
(
tf
.
device
(
device
),
PredictTowerContext
(
ns_name
=
tower_name
,
vs_name
=
self
.
_main_tower_vs_name
,
tower_name
,
vs_name
=
self
.
_main_tower_vs_name
):
device
=
device
)
.
build
(
input
,
self
.
tower_func
)
self
.
tower_func
(
*
input
.
get_input_tensors
()
)
tower
=
self
.
tower_func
.
towers
[
tower_name
]
tower
=
self
.
tower_func
.
towers
[
tower_name
]
input_tensors
=
tower
.
get_tensors
(
input_names
)
input_tensors
=
tower
.
get_tensors
(
input_names
)
output_tensors
=
tower
.
get_tensors
(
output_names
)
output_tensors
=
tower
.
get_tensors
(
output_names
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment