Commit de9025d6 authored Mar 23, 2020 by Yuxin Wu

update docs about MapAndBatch

parent 22582cc7
Showing 2 changed files with 14 additions and 10 deletions (+14 -10)

tensorpack/dataflow/parallel_map.py    +13 -8
tensorpack/models/batch_norm.py        +1 -2
tensorpack/dataflow/parallel_map.py
@@ -354,13 +354,15 @@ class MultiProcessMapAndBatchDataZMQ(_MultiProcessZMQDataFlow):
             enable_death_signal(_warn=self.identity == b'0')
             ctx = zmq.Context()

+            # recv jobs
             socket = ctx.socket(zmq.PULL)
             socket.setsockopt(zmq.IDENTITY, self.identity)
-            socket.set_hwm(self.hwm)
+            socket.set_hwm(self.hwm * self.batch_size)
             socket.connect(self.input_pipe)

+            # send results
             out_socket = ctx.socket(zmq.PUSH)
-            out_socket.set_hwm(max(self.hwm // self.batch_size, 5))
+            out_socket.set_hwm(max(self.hwm, 5))
             out_socket.connect(self.result_pipe)

             batch = []
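After this hunk, the worker's hwm is counted in batches rather than datapoints: the PULL socket that receives raw datapoints scales it up by batch_size, while the PUSH socket that emits whole batches uses it directly. A minimal pyzmq sketch of that pattern, with illustrative values and placeholder pipe addresses (this is not tensorpack code):

import zmq

hwm, batch_size = 128, 32          # illustrative values, not defaults from this commit
ctx = zmq.Context()

# receive individual datapoints, so the high-water mark is hwm * batch_size datapoints
pull = ctx.socket(zmq.PULL)
pull.setsockopt(zmq.IDENTITY, b'0')              # worker id, as in the diff above
pull.set_hwm(hwm * batch_size)
pull.connect("ipc:///tmp/example-input-pipe")    # placeholder address

# send assembled batches, so the high-water mark is counted in batches
push = ctx.socket(zmq.PUSH)
push.set_hwm(max(hwm, 5))
push.connect("ipc:///tmp/example-result-pipe")   # placeholder address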
@@ -374,7 +376,7 @@ class MultiProcessMapAndBatchDataZMQ(_MultiProcessZMQDataFlow):
                     out_socket.send(dumps(dp), copy=False)
                     del batch[:]

-    def __init__(self, ds, num_proc, map_func, batch_size, buffer_size=1024):
+    def __init__(self, ds, num_proc, map_func, batch_size, buffer_size=None):
         """
         Args:
             ds (DataFlow): the dataflow to map
@@ -382,15 +384,18 @@ class MultiProcessMapAndBatchDataZMQ(_MultiProcessZMQDataFlow):
             map_func (callable): datapoint -> datapoint | None. Return None to
                 discard/skip the datapoint.
             batch_size (int): batch size
-            buffer_size (int): number of datapoints in the buffer
+            buffer_size (int): number of datapoints (not batched) in the buffer.
+                Defaults to batch_size * 10
         """
         super(MultiProcessMapAndBatchDataZMQ, self).__init__()
-        assert batch_size < buffer_size
         self.ds = ds
         self.num_proc = num_proc
         self.map_func = map_func
-        self.buffer_size = buffer_size
         self.batch_size = batch_size
+        assert self.batch_size < buffer_size
+        if buffer_size is None:
+            buffer_size = batch_size * 10
+        self.buffer_size = buffer_size

     def reset_state(self):
         _MultiProcessZMQDataFlow.reset_state(self)
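With buffer_size now defaulting to None, the constructor derives it as batch_size * 10. A usage sketch under assumed inputs (FakeData and the identity map_func are illustrative stand-ins, not part of this commit):

from tensorpack.dataflow import FakeData
from tensorpack.dataflow.parallel_map import MultiProcessMapAndBatchDataZMQ

ds = FakeData([[224, 224, 3]], size=1000)   # dummy source dataflow for illustration
ds = MultiProcessMapAndBatchDataZMQ(
    ds,
    num_proc=4,
    map_func=lambda dp: dp,                 # identity mapping, stands in for real preprocessing
    batch_size=32)                          # buffer_size omitted -> defaults to 32 * 10 = 320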
@@ -401,13 +406,13 @@ class MultiProcessMapAndBatchDataZMQ(_MultiProcessZMQDataFlow):
         self.context = zmq.Context()
         self.socket = self.context.socket(zmq.PULL)
-        self.socket.set_hwm(max(5, self.buffer_size * 2 // self.batch_size))
+        self.socket.set_hwm(max(5, self.buffer_size // self.batch_size))
         _bind_guard(self.socket, result_pipe)

         dispatcher = MultiProcessMapAndBatchDataZMQ._Dispatcher(self.ds, job_pipe, self.buffer_size)

         self._proc_ids = [u'{}'.format(k).encode('utf-8') for k in range(self.num_proc)]
-        worker_hwm = max(3, self.buffer_size * 2 // self.num_proc // self.batch_size)
+        worker_hwm = max(3, self.buffer_size // self.num_proc // self.batch_size)
         self._procs = [MultiProcessMapAndBatchDataZMQ._Worker(
             self._proc_ids[k], self.map_func, job_pipe, result_pipe, worker_hwm, self.batch_size)
             for k in range(self.num_proc)]
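The high-water marks in reset_state drop the old "* 2" factor and pair with the batch-aware worker hwm from the first hunk. With the assumed values below (not taken from the commit), the numbers work out as follows:

batch_size, num_proc = 32, 4                 # illustrative values
buffer_size = batch_size * 10                # new default: 320 datapoints

recv_hwm = max(5, buffer_size // batch_size)                 # 320 // 32 = 10 batches on the result socket
worker_hwm = max(3, buffer_size // num_proc // batch_size)   # 320 // 4 // 32 = 2, clamped to 3
print(recv_hwm, worker_hwm)                  # 10 3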
tensorpack/models/batch_norm.py
@@ -287,8 +287,7 @@ def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
             center=center, scale=scale,
             beta_initializer=beta_initializer,
             gamma_initializer=gamma_initializer,
-            # https://github.com/tensorflow/tensorflow/issues/10857#issuecomment-410185429
-            fused=(ndims == 4 and axis in [1, 3]),
+            fused=(ndims == 4 and axis in [1, 3] and not freeze_bn_backward),
             _reuse=tf.get_variable_scope().reuse)
         use_fp16 = inputs.dtype == tf.float16
         if use_fp16:
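The functional part of this file's change is the extra "and not freeze_bn_backward" term: the fused batch-norm kernel is only requested when that flag is off. A toy evaluation with made-up values for the surrounding variables (ndims, axis, and freeze_bn_backward are arguments/locals of BatchNorm in this file):

# made-up values, for illustration only
ndims, axis, freeze_bn_backward = 4, 3, True

fused = (ndims == 4 and axis in [1, 3] and not freeze_bn_backward)
print(fused)   # False: with freeze_bn_backward set, the non-fused implementation is chosen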