Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
1e7fa5f9
Commit
1e7fa5f9
authored
Jul 11, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
make InputSource subclass methods private
parent
ef9e27a6
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
95 additions
and
54 deletions
+95
-54
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+2
-1
tensorpack/train/input_source.py
tensorpack/train/input_source.py
+93
-53
No files found.
tensorpack/callbacks/inference_runner.py
View file @
1e7fa5f9
...
@@ -270,7 +270,8 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
...
@@ -270,7 +270,8 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
total
-=
nr_tower
total
-=
nr_tower
# take care of the rest
# take care of the rest
while
total
>
0
:
while
total
>
0
:
feed
=
self
.
_input_source
.
next_feed
(
cnt
=
1
)
# TODO XXX doesn't support remap
feed
=
self
.
_input_source
.
_next_feed
(
cnt
=
1
)
self
.
_hooked_sess
.
run
(
fetches
=
[],
feed_dict
=
feed
)
self
.
_hooked_sess
.
run
(
fetches
=
[],
feed_dict
=
feed
)
pbar
.
update
(
1
)
pbar
.
update
(
1
)
total
-=
1
total
-=
1
...
...
tensorpack/train/input_source.py
View file @
1e7fa5f9
...
@@ -36,7 +36,6 @@ __all__ = ['InputSource',
...
@@ -36,7 +36,6 @@ __all__ = ['InputSource',
class
InputSource
(
object
):
class
InputSource
(
object
):
""" Base class for the abstract InputSource. """
""" Base class for the abstract InputSource. """
@
abstractmethod
def
get_input_tensors
(
self
):
def
get_input_tensors
(
self
):
"""
"""
Returns:
Returns:
...
@@ -44,12 +43,20 @@ class InputSource(object):
...
@@ -44,12 +43,20 @@ class InputSource(object):
used as input of :func:`build_graph`.
used as input of :func:`build_graph`.
For non-placeholder tensors, should always create and return new tensors when called.
For non-placeholder tensors, should always create and return new tensors when called.
"""
"""
return
self
.
_get_input_tensors
()
@
abstractmethod
def
_get_input_tensors
(
self
):
pass
def
setup
(
self
,
inputs_desc
):
def
setup
(
self
,
inputs_desc
):
"""
"""
Args:
Args:
inputs_desc (list[InputDesc]): list of input desc
inputs_desc (list[InputDesc]): list of input desc
"""
"""
self
.
_setup
(
inputs_desc
)
def
_setup
(
self
,
inputs_desc
):
pass
pass
def
get_callbacks
(
self
):
def
get_callbacks
(
self
):
...
@@ -57,21 +64,31 @@ class InputSource(object):
...
@@ -57,21 +64,31 @@ class InputSource(object):
Returns:
Returns:
list[Callback]: extra callbacks required by this InputSource.
list[Callback]: extra callbacks required by this InputSource.
"""
"""
return
self
.
_get_callbacks
()
def
_get_callbacks
(
self
):
return
[]
return
[]
@
abstractmethod
def
reset_state
(
self
):
def
reset_state
(
self
):
"""
"""
Semantics of this method has not been well defined.
Semantics of this method has not been well defined.
"""
"""
pass
# TODO
self
.
_reset_state
()
@
abstractmethod
@
abstractmethod
def
_reset_state
(
self
):
pass
def
next_feed
(
self
):
def
next_feed
(
self
):
"""
"""
Returns:
Returns:
a feed_dict of {Tensor: data}, to be used to run the steps
a feed_dict of {Tensor: data}, to be used to run the steps
"""
"""
return
self
.
_next_feed
()
@
abstractmethod
def
_next_feed
(
self
):
pass
pass
def
size
(
self
):
def
size
(
self
):
...
@@ -79,7 +96,10 @@ class InputSource(object):
...
@@ -79,7 +96,10 @@ class InputSource(object):
Returns:
Returns:
int: epoch size of the InputSource
int: epoch size of the InputSource
"""
"""
return
NotImplementedError
()
return
self
.
_size
()
def
_size
(
self
):
raise
NotImplementedError
()
class
ProxyInputSource
(
InputSource
):
class
ProxyInputSource
(
InputSource
):
...
@@ -90,22 +110,22 @@ class ProxyInputSource(InputSource):
...
@@ -90,22 +110,22 @@ class ProxyInputSource(InputSource):
assert
isinstance
(
input
,
InputSource
),
input
assert
isinstance
(
input
,
InputSource
),
input
self
.
_input
=
input
self
.
_input
=
input
def
get_input_tensors
(
self
):
def
_
get_input_tensors
(
self
):
return
self
.
_input
.
get_input_tensors
()
return
self
.
_input
.
get_input_tensors
()
def
setup
(
self
,
inputs_desc
):
def
_
setup
(
self
,
inputs_desc
):
self
.
_input
.
setup
(
inputs_desc
)
self
.
_input
.
setup
(
inputs_desc
)
def
get_callbacks
(
self
):
def
_
get_callbacks
(
self
):
return
self
.
_input
.
get_callbacks
()
return
self
.
_input
.
get_callbacks
()
def
size
(
self
):
def
_
size
(
self
):
return
self
.
_input
.
size
()
return
self
.
_input
.
size
()
def
next_feed
(
self
):
def
_
next_feed
(
self
):
return
self
.
_input
.
next_feed
()
return
self
.
_input
.
next_feed
()
def
reset_state
(
self
):
def
_
reset_state
(
self
):
self
.
_input
.
reset_state
()
self
.
_input
.
reset_state
()
...
@@ -119,22 +139,22 @@ class FeedInput(InputSource):
...
@@ -119,22 +139,22 @@ class FeedInput(InputSource):
assert
isinstance
(
ds
,
DataFlow
),
ds
assert
isinstance
(
ds
,
DataFlow
),
ds
self
.
ds
=
ds
self
.
ds
=
ds
def
size
(
self
):
def
_
size
(
self
):
return
self
.
ds
.
size
()
return
self
.
ds
.
size
()
def
setup
(
self
,
inputs
):
def
_
setup
(
self
,
inputs
):
self
.
_all_placehdrs
=
[
v
.
build_placeholder_reuse
()
for
v
in
inputs
]
self
.
_all_placehdrs
=
[
v
.
build_placeholder_reuse
()
for
v
in
inputs
]
self
.
reset_state
()
self
.
reset_state
()
def
reset_state
(
self
):
def
_
reset_state
(
self
):
rds
=
RepeatedData
(
self
.
ds
,
-
1
)
rds
=
RepeatedData
(
self
.
ds
,
-
1
)
rds
.
reset_state
()
rds
.
reset_state
()
self
.
data_producer
=
rds
.
get_data
()
self
.
data_producer
=
rds
.
get_data
()
def
get_input_tensors
(
self
):
def
_
get_input_tensors
(
self
):
return
self
.
_all_placehdrs
return
self
.
_all_placehdrs
def
next_feed
(
self
):
def
_
next_feed
(
self
):
dp
=
next
(
self
.
data_producer
)
dp
=
next
(
self
.
data_producer
)
assert
len
(
dp
)
==
len
(
self
.
_all_placehdrs
),
"[FeedInput] datapoints and inputs are of different length!"
assert
len
(
dp
)
==
len
(
self
.
_all_placehdrs
),
"[FeedInput] datapoints and inputs are of different length!"
return
dict
(
zip
(
self
.
_all_placehdrs
,
dp
))
return
dict
(
zip
(
self
.
_all_placehdrs
,
dp
))
...
@@ -149,7 +169,7 @@ class DataParallelFeedInput(FeedInput):
...
@@ -149,7 +169,7 @@ class DataParallelFeedInput(FeedInput):
self
.
_tower_names
=
tower_names
self
.
_tower_names
=
tower_names
self
.
_nr_tower
=
len
(
tower_names
)
self
.
_nr_tower
=
len
(
tower_names
)
def
setup
(
self
,
inputs
):
def
_
setup
(
self
,
inputs
):
self
.
_placehdrs_per_tower
=
[]
self
.
_placehdrs_per_tower
=
[]
for
tname
in
self
.
_tower_names
:
for
tname
in
self
.
_tower_names
:
# build a list of placeholders for each tower
# build a list of placeholders for each tower
...
@@ -157,12 +177,12 @@ class DataParallelFeedInput(FeedInput):
...
@@ -157,12 +177,12 @@ class DataParallelFeedInput(FeedInput):
[
v
.
build_placeholder
(
prefix
=
tname
+
'/'
)
for
v
in
inputs
])
[
v
.
build_placeholder
(
prefix
=
tname
+
'/'
)
for
v
in
inputs
])
self
.
reset_state
()
self
.
reset_state
()
def
get_input_tensors
(
self
):
def
_
get_input_tensors
(
self
):
# return placeholders for each tower
# return placeholders for each tower
ctx
=
get_current_tower_context
()
ctx
=
get_current_tower_context
()
return
self
.
_placehdrs_per_tower
[
ctx
.
index
]
return
self
.
_placehdrs_per_tower
[
ctx
.
index
]
def
next_feed
(
self
,
cnt
=
None
):
def
_
next_feed
(
self
,
cnt
=
None
):
"""
"""
Args:
Args:
cnt: how many towers to feed to. Defaults to the total number of towers
cnt: how many towers to feed to. Defaults to the total number of towers
...
@@ -181,10 +201,10 @@ class FeedfreeInput(InputSource):
...
@@ -181,10 +201,10 @@ class FeedfreeInput(InputSource):
""" Abstract base for input without feed,
""" Abstract base for input without feed,
e.g. by queue or other operations. """
e.g. by queue or other operations. """
def
reset_state
(
self
):
def
_
reset_state
(
self
):
pass
pass
def
next_feed
(
self
):
def
_
next_feed
(
self
):
return
{}
return
{}
...
@@ -244,10 +264,10 @@ class QueueInput(FeedfreeInput):
...
@@ -244,10 +264,10 @@ class QueueInput(FeedfreeInput):
self
.
queue
=
queue
self
.
queue
=
queue
self
.
ds
=
ds
self
.
ds
=
ds
def
size
(
self
):
def
_
size
(
self
):
return
self
.
ds
.
size
()
return
self
.
ds
.
size
()
def
setup
(
self
,
inputs
):
def
_
setup
(
self
,
inputs
):
logger
.
info
(
"Setting up the queue for CPU prefetching ..."
)
logger
.
info
(
"Setting up the queue for CPU prefetching ..."
)
self
.
_input_placehdrs
=
[
v
.
build_placeholder_reuse
()
for
v
in
inputs
]
self
.
_input_placehdrs
=
[
v
.
build_placeholder_reuse
()
for
v
in
inputs
]
assert
len
(
self
.
_input_placehdrs
)
>
0
,
\
assert
len
(
self
.
_input_placehdrs
)
>
0
,
\
...
@@ -258,12 +278,12 @@ class QueueInput(FeedfreeInput):
...
@@ -258,12 +278,12 @@ class QueueInput(FeedfreeInput):
name
=
'input_queue'
)
name
=
'input_queue'
)
self
.
thread
=
EnqueueThread
(
self
.
queue
,
self
.
ds
,
self
.
_input_placehdrs
)
self
.
thread
=
EnqueueThread
(
self
.
queue
,
self
.
ds
,
self
.
_input_placehdrs
)
def
get_callbacks
(
self
):
def
_
get_callbacks
(
self
):
cb
=
StartProcOrThread
(
self
.
thread
)
cb
=
StartProcOrThread
(
self
.
thread
)
cb
.
chief_only
=
False
cb
.
chief_only
=
False
return
[
cb
]
return
[
cb
]
def
get_input_tensors
(
self
):
def
_
get_input_tensors
(
self
):
with
tf
.
device
(
'/cpu:0'
):
with
tf
.
device
(
'/cpu:0'
):
ret
=
self
.
queue
.
dequeue
(
name
=
'input_deque'
)
ret
=
self
.
queue
.
dequeue
(
name
=
'input_deque'
)
if
isinstance
(
ret
,
tf
.
Tensor
):
# only one input
if
isinstance
(
ret
,
tf
.
Tensor
):
# only one input
...
@@ -291,10 +311,10 @@ class BatchQueueInput(QueueInput):
...
@@ -291,10 +311,10 @@ class BatchQueueInput(QueueInput):
super
(
BatchQueueInput
,
self
)
.
__init__
(
ds
,
queue
)
super
(
BatchQueueInput
,
self
)
.
__init__
(
ds
,
queue
)
self
.
batch_size
=
int
(
batch_size
)
self
.
batch_size
=
int
(
batch_size
)
def
size
(
self
):
def
_
size
(
self
):
return
self
.
ds
.
size
()
//
self
.
batch_size
return
self
.
ds
.
size
()
//
self
.
batch_size
def
setup
(
self
,
inputs
):
def
_
setup
(
self
,
inputs
):
logger
.
info
(
"Setting up the queue for CPU prefetching ..."
)
logger
.
info
(
"Setting up the queue for CPU prefetching ..."
)
self
.
input_placehdrs
=
[
v
.
build_placeholder_reuse
()
for
v
in
inputs
]
self
.
input_placehdrs
=
[
v
.
build_placeholder_reuse
()
for
v
in
inputs
]
assert
len
(
self
.
input_placehdrs
)
>
0
,
\
assert
len
(
self
.
input_placehdrs
)
>
0
,
\
...
@@ -325,7 +345,7 @@ class BatchQueueInput(QueueInput):
...
@@ -325,7 +345,7 @@ class BatchQueueInput(QueueInput):
self
.
thread
=
EnqueueThread
(
self
.
queue
,
self
.
ds
,
placehdrs_nobatch
)
self
.
thread
=
EnqueueThread
(
self
.
queue
,
self
.
ds
,
placehdrs_nobatch
)
def
get_input_tensors
(
self
):
def
_
get_input_tensors
(
self
):
with
tf
.
device
(
'/cpu:0'
):
with
tf
.
device
(
'/cpu:0'
):
ret
=
self
.
queue
.
dequeue_many
(
self
.
batch_size
,
name
=
'input_deque'
)
ret
=
self
.
queue
.
dequeue_many
(
self
.
batch_size
,
name
=
'input_deque'
)
if
isinstance
(
ret
,
tf
.
Tensor
):
# only one input
if
isinstance
(
ret
,
tf
.
Tensor
):
# only one input
...
@@ -338,6 +358,7 @@ class BatchQueueInput(QueueInput):
...
@@ -338,6 +358,7 @@ class BatchQueueInput(QueueInput):
return
ret
return
ret
# TODO tensor inputs can be drained? look at the new dataset API.
class
TensorInput
(
FeedfreeInput
):
class
TensorInput
(
FeedfreeInput
):
""" Input from a list of tensors, e.g. a TF data reading pipeline. """
""" Input from a list of tensors, e.g. a TF data reading pipeline. """
...
@@ -352,15 +373,20 @@ class TensorInput(FeedfreeInput):
...
@@ -352,15 +373,20 @@ class TensorInput(FeedfreeInput):
if
size
is
not
None
:
if
size
is
not
None
:
size
=
int
(
size
)
size
=
int
(
size
)
assert
size
>
0
assert
size
>
0
self
.
_size
=
size
self
.
_
fixed_
size
=
size
def
size
(
self
):
def
_setup
(
self
,
inputs_desc
):
if
self
.
_size
is
None
:
self
.
_desc
=
inputs_desc
def
_size
(
self
):
if
self
.
_fixed_size
is
None
:
raise
NotImplementedError
(
"size of TensorInput is undefined!"
)
raise
NotImplementedError
(
"size of TensorInput is undefined!"
)
return
self
.
_size
return
self
.
_
fixed_
size
def
get_input_tensors
(
self
):
def
_get_input_tensors
(
self
):
return
self
.
get_tensor_fn
()
ret
=
self
.
get_tensor_fn
()
assert
len
(
ret
)
==
len
(
self
.
_desc
),
"{} != {}"
.
format
(
len
(
ret
),
len
(
self
.
_desc
))
return
ret
class
DummyConstantInput
(
TensorInput
):
class
DummyConstantInput
(
TensorInput
):
...
@@ -387,7 +413,7 @@ class DummyConstantInput(TensorInput):
...
@@ -387,7 +413,7 @@ class DummyConstantInput(TensorInput):
return
tlist
return
tlist
super
(
DummyConstantInput
,
self
)
.
__init__
(
fn
)
super
(
DummyConstantInput
,
self
)
.
__init__
(
fn
)
def
setup
(
self
,
inputs
):
def
_
setup
(
self
,
inputs
):
self
.
inputs_desc
=
inputs
self
.
inputs_desc
=
inputs
...
@@ -410,8 +436,8 @@ class ZMQInput(TensorInput):
...
@@ -410,8 +436,8 @@ class ZMQInput(TensorInput):
return
ret
return
ret
super
(
ZMQInput
,
self
)
.
__init__
(
fn
)
super
(
ZMQInput
,
self
)
.
__init__
(
fn
)
def
setup
(
self
,
inputs
):
def
_setup
(
self
,
inputs_desc
):
self
.
inputs_desc
=
inputs
self
.
inputs_desc
=
inputs
_desc
assert
len
(
self
.
inputs_desc
)
>
0
,
\
assert
len
(
self
.
inputs_desc
)
>
0
,
\
"ZMQInput has to be used with InputDesc!"
"ZMQInput has to be used with InputDesc!"
...
@@ -454,19 +480,19 @@ class StagingInputWrapper(FeedfreeInput):
...
@@ -454,19 +480,19 @@ class StagingInputWrapper(FeedfreeInput):
self
.
_stage_ops
=
[]
self
.
_stage_ops
=
[]
self
.
_unstage_ops
=
[]
self
.
_unstage_ops
=
[]
def
setup
(
self
,
inputs
):
def
_
setup
(
self
,
inputs
):
self
.
_input
.
setup
(
inputs
)
self
.
_input
.
setup
(
inputs
)
self
.
setup_staging_areas
()
self
.
_
setup_staging_areas
()
def
get_callbacks
(
self
):
def
_
get_callbacks
(
self
):
cbs
=
self
.
_input
.
get_callbacks
()
cbs
=
self
.
_input
.
get_callbacks
()
cbs
.
append
(
cbs
.
append
(
StagingInputWrapper
.
StagingCallback
(
StagingInputWrapper
.
StagingCallback
(
self
.
get_stage_op
(),
self
.
get_unstage_op
(),
self
.
_nr_stage
))
self
.
_get_stage_op
(),
self
.
_
get_unstage_op
(),
self
.
_nr_stage
))
return
cbs
return
cbs
def
setup_staging_areas
(
self
):
def
_
setup_staging_areas
(
self
):
logger
.
info
(
"Setting up StagingArea for GPU prefetching ..."
)
logger
.
info
(
"Setting up StagingArea for GPU prefetching ..."
)
for
idx
,
device
in
enumerate
(
self
.
_devices
):
for
idx
,
device
in
enumerate
(
self
.
_devices
):
with
tf
.
device
(
device
):
with
tf
.
device
(
device
):
...
@@ -482,22 +508,18 @@ class StagingInputWrapper(FeedfreeInput):
...
@@ -482,22 +508,18 @@ class StagingInputWrapper(FeedfreeInput):
vout
.
set_shape
(
vin
.
get_shape
())
vout
.
set_shape
(
vin
.
get_shape
())
self
.
_unstage_ops
.
append
(
outputs
)
self
.
_unstage_ops
.
append
(
outputs
)
def
size
(
self
):
def
_
size
(
self
):
return
self
.
_input
.
size
()
return
self
.
_input
.
size
()
def
get_input_tensors
(
self
):
def
_
get_input_tensors
(
self
):
ctx
=
get_current_tower_context
()
ctx
=
get_current_tower_context
()
ret
=
self
.
_unstage_ops
[
ctx
.
index
]
ret
=
self
.
_unstage_ops
[
ctx
.
index
]
return
ret
return
ret
@
staticmethod
def
_get_stage_op
(
self
):
def
get_staging_name
(
idx
):
return
'StagingArea{}'
.
format
(
idx
)
def
get_stage_op
(
self
):
return
tf
.
group
(
*
self
.
_stage_ops
)
return
tf
.
group
(
*
self
.
_stage_ops
)
def
get_unstage_op
(
self
):
def
_
get_unstage_op
(
self
):
all_outputs
=
list
(
chain
.
from_iterable
(
self
.
_unstage_ops
))
all_outputs
=
list
(
chain
.
from_iterable
(
self
.
_unstage_ops
))
return
tf
.
group
(
*
all_outputs
)
return
tf
.
group
(
*
all_outputs
)
...
@@ -514,18 +536,36 @@ def remap_input_source(input, names):
...
@@ -514,18 +536,36 @@ def remap_input_source(input, names):
input(InputSource): a :class:`InputSource`, whose tensors will get mapped.
input(InputSource): a :class:`InputSource`, whose tensors will get mapped.
names(list[str]): list of input names corresponding to the tensors
names(list[str]): list of input names corresponding to the tensors
produced by ``input``.
produced by ``input``.
Returns:
InputSource:
Examples:
.. code-block:: python
input1 = QueueInput(ds)
# assume ds produces 'image' and 'label', but the graph takes more
# inputs for some reasons, or takes inputs of a different order:
inputs_desc = [InputDesc(tf.float32, (None,10), 'score'),
InputDesc(tf.float32, (None,20,20,3), 'label'),
InputDesc(tf.int32, (None,), 'image') ]
input2 = remap_input_source(input1, ['image', 'label'])
input2.setup(inputs_desc)
# now, input2.get_input_tensors() will return a placeholder for 'score',
# plus the tensors returned by input1.get_input_tensors()
"""
"""
def
__init__
(
self
,
input
,
names
):
def
__init__
(
self
,
input
,
names
):
ProxyInputSource
.
__init__
(
self
,
input
)
ProxyInputSource
.
__init__
(
self
,
input
)
assert
isinstance
(
names
,
(
list
,
tuple
)),
names
assert
isinstance
(
names
,
(
list
,
tuple
)),
names
self
.
_names
=
tuple
(
names
)
self
.
_names
=
tuple
(
names
)
def
setup
(
self
,
inputs
):
def
_
setup
(
self
,
inputs
):
self
.
_all_placehdrs
=
[
v
.
build_placeholder_reuse
()
for
v
in
inputs
]
self
.
_all_placehdrs
=
[
v
.
build_placeholder_reuse
()
for
v
in
inputs
]
inputs_subset
=
get_sublist_by_names
(
inputs
,
self
.
_names
)
inputs_subset
=
get_sublist_by_names
(
inputs
,
self
.
_names
)
self
.
_input
.
setup
(
inputs_subset
)
self
.
_input
.
setup
(
inputs_subset
)
def
get_input_tensors
(
self
):
def
_
get_input_tensors
(
self
):
ret
=
self
.
_input
.
get_input_tensors
()
ret
=
self
.
_input
.
get_input_tensors
()
assert
len
(
ret
)
==
len
(
self
.
_names
)
assert
len
(
ret
)
==
len
(
self
.
_names
)
return
get_tensors_inputs
(
return
get_tensors_inputs
(
...
@@ -535,6 +575,6 @@ def remap_input_source(input, names):
...
@@ -535,6 +575,6 @@ def remap_input_source(input, names):
# inherit oldcls so that type check in various places would work
# inherit oldcls so that type check in various places would work
cls
=
type
(
'Remapped'
+
oldcls
.
__name__
,
(
ProxyInputSource
,
oldcls
),
{
cls
=
type
(
'Remapped'
+
oldcls
.
__name__
,
(
ProxyInputSource
,
oldcls
),
{
'__init__'
:
__init__
,
'__init__'
:
__init__
,
'
setup'
:
setup
,
'
_setup'
:
_
setup
,
'
get_input_tensors'
:
get_input_tensors
})
'
_get_input_tensors'
:
_
get_input_tensors
})
return
cls
(
input
,
names
)
return
cls
(
input
,
names
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment