Shashank Suhas / seminar-breakout · Commits

Commit 59160bcf, authored May 26, 2017 by Yuxin Wu (parent ddebb23c)

    make zmq/dummy input subclass of TensorInput

Showing 1 changed file with 40 additions and 54 deletions.
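The commit message summarizes the refactor: rather than each input source implementing its own get_input_tensors(), DummyConstantInput and ZMQInput now build a callable that produces the input tensors and hand it to TensorInput's constructor. As a rough sketch of the pattern the diff below converges on (illustrative only, not part of the commit; the RandomInput name, the random_uniform op, and the import path are assumptions inferred from the file being tensorpack/train/input_source.py), a new tensor-backed input source would look like:

    import tensorflow as tf
    # Import path assumed from the file shown in this commit
    # (tensorpack/train/input_source.py); adjust if the package re-exports it.
    from tensorpack.train.input_source import TensorInput

    class RandomInput(TensorInput):
        """Hypothetical input source (not part of this commit), following the
        same pattern as the refactored DummyConstantInput below."""
        def __init__(self, shapes):
            self.shapes = shapes

            def fn():
                # TensorInput calls this to obtain the list of input tensors
                # for the current tower.
                return [tf.random_uniform(s) for s in self.shapes]

            super(RandomInput, self).__init__(fn)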
tensorpack/train/input_source.py

@@ -324,46 +324,6 @@ class BatchQueueInput(FeedfreeInput):
         return ret
 
 
-class DummyConstantInput(FeedfreeInput):
-    """ Input with some random tensor placed on GPU.
-        Useful for debugging performance issues """
-    def __init__(self, shapes):
-        """
-        Args:
-            shapes (list[list]): a list of fully-sepcified shapes.
-        """
-        self.shapes = shapes
-        logger.warn("Using dummy input for debug!")
-
-    def setup(self, model):
-        self.input_placehdrs = model.get_reused_placehdrs()
-
-    def setup_training(self, trainer):
-        super(DummyConstantInput, self).setup_training(trainer)
-        nr_tower = trainer.config.nr_tower
-        placehdrs = self.input_placehdrs
-        assert len(self.shapes) == len(placehdrs)
-        self.tensors = []
-        # don't share variables
-        for tower in range(nr_tower):
-            tlist = []
-            with tf.device('/gpu:{}'.format(tower)):
-                for idx, p in enumerate(placehdrs):
-                    tlist.append(tf.get_variable(
-                        'dummy-{}-{}'.format(p.op.name, tower),
-                        shape=self.shapes[idx], dtype=p.dtype,
-                        trainable=False))
-            self.tensors.append(tlist)
-
-    def get_input_tensors(self):
-        ctx = get_current_tower_context()
-        ret = self.tensors[ctx.index]
-        return ret
-
-
 class TensorInput(FeedfreeInput):
     """ Input from a list of tensors, e.g. a TF data reading pipeline. """
 
@@ -391,27 +351,53 @@ class TensorInput(FeedfreeInput):
         return self.get_tensor_fn()
 
 
-class ZMQInput(FeedfreeInput):
+class DummyConstantInput(TensorInput):
+    """ Input with some random tensor placed on GPU.
+        Useful for debugging performance issues """
+    def __init__(self, shapes):
+        """
+        Args:
+            shapes (list[list]): a list of fully-sepcified shapes.
+        """
+        self.shapes = shapes
+        logger.warn("Using dummy input for debug!")
+
+        def fn():
+            tlist = []
+            ctx = get_current_tower_context()
+            assert len(self.shapes) == len(self.input_placehdrs)
+            for idx, p in enumerate(self.input_placehdrs):
+                tlist.append(tf.get_variable(
+                    'dummy-{}-{}'.format(p.op.name, ctx.index),
+                    shape=self.shapes[idx], dtype=p.dtype,
+                    trainable=False))
+            return tlist
+        super(DummyConstantInput, self).__init__(fn)
+
+    def setup(self, model):
+        self.input_placehdrs = model.get_reused_placehdrs()
+
+
+# TODO doesn't support remapping
+class ZMQInput(TensorInput):
     def __init__(self, endpoint):
         self._endpoint = endpoint
 
-    def size(self):
-        raise NotImplementedError()
+        from tensorpack.user_ops import zmq_recv
+
+        def fn():
+            ret = zmq_recv(self._endpoint,
+                           [x.dtype for x in self.input_placehdrs])
+            if isinstance(ret, tf.Tensor):
+                ret = [ret]
+            assert len(ret) == len(self.input_placehdrs)
+            for qv, v in zip(ret, self.input_placehdrs):
+                qv.set_shape(v.get_shape())
+            return ret
+        super(ZMQInput, self).__init__(fn)
 
     def setup(self, model):
         self.input_placehdrs = model.get_reused_placehdrs()
         assert len(self.input_placehdrs) > 0, \
-            "ZMQInput has to be used with input placeholders!"
-
-    def get_input_tensors(self):
-        from tensorpack.user_ops import zmq_recv
-        ret = zmq_recv(self._endpoint,
-                       [x.dtype for x in self.input_placehdrs])
-        if isinstance(ret, tf.Tensor):
-            ret = [ret]
-        assert len(ret) == len(self.input_placehdrs)
-        for qv, v in zip(ret, self.input_placehdrs):
-            qv.set_shape(v.get_shape())
-        return ret
+            "ZMQInput has to be used with InputDesc!"
 
 
 class StagingInputWrapper(FeedfreeInput):
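The net effect of the refactor: the old DummyConstantInput pre-built one set of dummy variables per tower in setup_training(), looping over trainer.config.nr_tower under tf.device('/gpu:{}'), while the new version creates them lazily inside fn() for whichever tower context is current, which is what lets both it and ZMQInput reuse TensorInput's tensor-fetching path. A hedged usage sketch of the two classes after this commit (the import path, shapes, and endpoint are illustrative; wiring the resulting input source into a trainer is outside this diff):

    # Import path assumed from the file shown in this commit.
    from tensorpack.train.input_source import DummyConstantInput, ZMQInput

    # Dummy tensors placed on GPU with fully-specified shapes, e.g. to check
    # training speed without a real data pipeline (shapes are illustrative).
    dummy_input = DummyConstantInput([[64, 224, 224, 3], [64]])

    # Receive input tensors over ZeroMQ via the zmq_recv user op; a producer
    # must push to the same endpoint (endpoint value is illustrative).
    zmq_input = ZMQInput('ipc:///tmp/train-pipe')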