Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
98eb3db5
Commit
98eb3db5
authored
Dec 03, 2018
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
More data types support in BatchData (fix #983)
parent
135d17e2
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
15 additions
and
10 deletions
+15
-10
examples/FasterRCNN/model_frcnn.py
examples/FasterRCNN/model_frcnn.py
+1
-1
examples/FasterRCNN/train.py
examples/FasterRCNN/train.py
+1
-0
tensorpack/dataflow/common.py
tensorpack/dataflow/common.py
+13
-9
No files found.
examples/FasterRCNN/model_frcnn.py
View file @
98eb3db5
...
@@ -214,7 +214,7 @@ def fastrcnn_predictions(boxes, scores):
...
@@ -214,7 +214,7 @@ def fastrcnn_predictions(boxes, scores):
# sorted_selection = tf.contrib.framework.sort(selection, direction='ASCENDING')
# sorted_selection = tf.contrib.framework.sort(selection, direction='ASCENDING')
sorted_selection
=
-
tf
.
nn
.
top_k
(
-
selection
,
k
=
tf
.
size
(
selection
))[
0
]
sorted_selection
=
-
tf
.
nn
.
top_k
(
-
selection
,
k
=
tf
.
size
(
selection
))[
0
]
if
get_tf_version_tuple
()
>=
(
1
,
1
2
):
if
get_tf_version_tuple
()
>=
(
1
,
1
3
):
mask
=
tf
.
sparse
.
SparseTensor
(
indices
=
tf
.
expand_dims
(
sorted_selection
,
1
),
mask
=
tf
.
sparse
.
SparseTensor
(
indices
=
tf
.
expand_dims
(
sorted_selection
,
1
),
values
=
tf
.
ones_like
(
sorted_selection
,
dtype
=
tf
.
bool
),
values
=
tf
.
ones_like
(
sorted_selection
,
dtype
=
tf
.
bool
),
dense_shape
=
output_shape
)
dense_shape
=
output_shape
)
...
...
examples/FasterRCNN/train.py
View file @
98eb3db5
...
@@ -522,6 +522,7 @@ if __name__ == '__main__':
...
@@ -522,6 +522,7 @@ if __name__ == '__main__':
MODEL
=
ResNetFPNModel
()
if
cfg
.
MODE_FPN
else
ResNetC4Model
()
MODEL
=
ResNetFPNModel
()
if
cfg
.
MODE_FPN
else
ResNetC4Model
()
if
args
.
visualize
or
args
.
evaluate
or
args
.
predict
:
if
args
.
visualize
or
args
.
evaluate
or
args
.
predict
:
assert
tf
.
test
.
is_gpu_available
()
assert
args
.
load
assert
args
.
load
finalize_configs
(
is_training
=
False
)
finalize_configs
(
is_training
=
False
)
...
...
tensorpack/dataflow/common.py
View file @
98eb3db5
...
@@ -128,22 +128,26 @@ class BatchData(ProxyDataFlow):
...
@@ -128,22 +128,26 @@ class BatchData(ProxyDataFlow):
result
.
append
(
result
.
append
(
[
x
[
k
]
for
x
in
data_holder
])
[
x
[
k
]
for
x
in
data_holder
])
else
:
else
:
dt
=
data_holder
[
0
][
k
]
data
=
data_holder
[
0
][
k
]
if
type
(
dt
)
in
list
(
six
.
integer_types
)
+
[
bool
]:
if
isinstance
(
data
,
six
.
integer_types
):
tp
=
'int32'
dtype
=
'int32'
elif
type
(
dt
)
==
float
:
elif
type
(
data
)
==
bool
:
tp
=
'float32'
dtype
=
'bool'
elif
type
(
data
)
==
float
:
dtype
=
'float32'
elif
isinstance
(
data
,
(
six
.
binary_type
,
six
.
text_type
)):
dtype
=
'str'
else
:
else
:
try
:
try
:
tp
=
dt
.
dtype
dtype
=
data
.
dtype
except
AttributeError
:
except
AttributeError
:
raise
TypeError
(
"Unsupported type to batch: {}"
.
format
(
type
(
d
t
)))
raise
TypeError
(
"Unsupported type to batch: {}"
.
format
(
type
(
d
ata
)))
try
:
try
:
result
.
append
(
result
.
append
(
np
.
asarray
([
x
[
k
]
for
x
in
data_holder
],
dtype
=
tp
))
np
.
asarray
([
x
[
k
]
for
x
in
data_holder
],
dtype
=
dtype
))
except
Exception
as
e
:
# noqa
except
Exception
as
e
:
# noqa
logger
.
exception
(
"Cannot batch data. Perhaps they are of inconsistent shape?"
)
logger
.
exception
(
"Cannot batch data. Perhaps they are of inconsistent shape?"
)
if
isinstance
(
d
t
,
np
.
ndarray
):
if
isinstance
(
d
ata
,
np
.
ndarray
):
s
=
pprint
.
pformat
([
x
[
k
]
.
shape
for
x
in
data_holder
])
s
=
pprint
.
pformat
([
x
[
k
]
.
shape
for
x
in
data_holder
])
logger
.
error
(
"Shape of all arrays to be batched: "
+
s
)
logger
.
error
(
"Shape of all arrays to be batched: "
+
s
)
try
:
try
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment