Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
39afd64d
Commit
39afd64d
authored
Nov 17, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Reset dataflow before before_train, to avoid forking session (#494)
parent
c0a81d51
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
14 additions
and
9 deletions
+14
-9
examples/FasterRCNN/utils/box_ops.py
examples/FasterRCNN/utils/box_ops.py
+2
-0
tensorpack/input_source/input_source.py
tensorpack/input_source/input_source.py
+12
-9
No files found.
examples/FasterRCNN/utils/box_ops.py
View file @
39afd64d
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
# File: box_ops.py
# File: box_ops.py
import
os
import
tensorflow
as
tf
import
tensorflow
as
tf
from
tensorpack.tfutils.scope_utils
import
under_name_scope
from
tensorpack.tfutils.scope_utils
import
under_name_scope
from
tensorpack.tfutils
import
get_default_sess_config
from
tensorpack.tfutils
import
get_default_sess_config
...
@@ -74,6 +75,7 @@ def get_iou_callable():
...
@@ -74,6 +75,7 @@ def get_iou_callable():
"""
"""
Get a pairwise box iou callable.
Get a pairwise box iou callable.
"""
"""
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
''
with
tf
.
Graph
()
.
as_default
(),
tf
.
device
(
'/cpu:0'
):
with
tf
.
Graph
()
.
as_default
(),
tf
.
device
(
'/cpu:0'
):
A
=
tf
.
placeholder
(
tf
.
float32
,
shape
=
[
None
,
4
])
A
=
tf
.
placeholder
(
tf
.
float32
,
shape
=
[
None
,
4
])
B
=
tf
.
placeholder
(
tf
.
float32
,
shape
=
[
None
,
4
])
B
=
tf
.
placeholder
(
tf
.
float32
,
shape
=
[
None
,
4
])
...
...
tensorpack/input_source/input_source.py
View file @
39afd64d
...
@@ -21,7 +21,7 @@ from ..tfutils.tower import get_current_tower_context
...
@@ -21,7 +21,7 @@ from ..tfutils.tower import get_current_tower_context
from
..utils
import
logger
from
..utils
import
logger
from
..utils.concurrency
import
ShareSessionThread
from
..utils.concurrency
import
ShareSessionThread
from
..utils.develop
import
log_deprecated
from
..utils.develop
import
log_deprecated
from
..callbacks.base
import
Callback
from
..callbacks.base
import
Callback
,
CallbackFactory
from
..callbacks.graph
import
RunOp
from
..callbacks.graph
import
RunOp
__all__
=
[
'PlaceholderInput'
,
'FeedInput'
,
__all__
=
[
'PlaceholderInput'
,
'FeedInput'
,
...
@@ -33,6 +33,10 @@ __all__ = ['PlaceholderInput', 'FeedInput',
...
@@ -33,6 +33,10 @@ __all__ = ['PlaceholderInput', 'FeedInput',
'StagingInput'
]
'StagingInput'
]
def
_get_reset_callback
(
df
):
return
CallbackFactory
(
setup_graph
=
lambda
_
:
df
.
reset_state
())
class
PlaceholderInput
(
InputSource
):
class
PlaceholderInput
(
InputSource
):
"""
"""
Just produce placeholders as input tensors.
Just produce placeholders as input tensors.
...
@@ -99,7 +103,7 @@ class FeedInput(InputSource):
...
@@ -99,7 +103,7 @@ class FeedInput(InputSource):
self
.
_cb
.
_reset
()
self
.
_cb
.
_reset
()
def
_get_callbacks
(
self
):
def
_get_callbacks
(
self
):
return
[
self
.
_cb
]
return
[
self
.
_cb
,
_get_reset_callback
(
self
.
_iter_ds
)
]
class
FeedfreeInput
(
InputSource
):
class
FeedfreeInput
(
InputSource
):
...
@@ -116,10 +120,8 @@ class EnqueueThread(ShareSessionThread):
...
@@ -116,10 +120,8 @@ class EnqueueThread(ShareSessionThread):
super
(
EnqueueThread
,
self
)
.
__init__
()
super
(
EnqueueThread
,
self
)
.
__init__
()
self
.
name
=
'EnqueueThread '
+
queue
.
name
self
.
name
=
'EnqueueThread '
+
queue
.
name
self
.
daemon
=
True
self
.
daemon
=
True
self
.
dataflow
=
ds
self
.
dataflow
=
RepeatedData
(
ds
,
-
1
)
self
.
queue
=
queue
self
.
queue
=
queue
self
.
placehdrs
=
placehdrs
self
.
placehdrs
=
placehdrs
self
.
op
=
self
.
queue
.
enqueue
(
self
.
placehdrs
)
self
.
op
=
self
.
queue
.
enqueue
(
self
.
placehdrs
)
...
@@ -130,7 +132,7 @@ class EnqueueThread(ShareSessionThread):
...
@@ -130,7 +132,7 @@ class EnqueueThread(ShareSessionThread):
def
run
(
self
):
def
run
(
self
):
with
self
.
default_sess
():
with
self
.
default_sess
():
try
:
try
:
self
.
reset_dataflow
()
self
.
_itr
=
self
.
dataflow
.
get_data
()
while
True
:
while
True
:
# pausable loop
# pausable loop
self
.
_lock
.
acquire
()
self
.
_lock
.
acquire
()
...
@@ -182,6 +184,7 @@ class QueueInput(FeedfreeInput):
...
@@ -182,6 +184,7 @@ class QueueInput(FeedfreeInput):
assert
isinstance
(
ds
,
DataFlow
),
ds
assert
isinstance
(
ds
,
DataFlow
),
ds
self
.
queue
=
queue
self
.
queue
=
queue
self
.
ds
=
ds
self
.
ds
=
ds
self
.
_inf_ds
=
RepeatedData
(
ds
,
-
1
)
def
_size
(
self
):
def
_size
(
self
):
return
self
.
ds
.
size
()
return
self
.
ds
.
size
()
...
@@ -196,7 +199,7 @@ class QueueInput(FeedfreeInput):
...
@@ -196,7 +199,7 @@ class QueueInput(FeedfreeInput):
50
,
[
x
.
dtype
for
x
in
self
.
_input_placehdrs
],
50
,
[
x
.
dtype
for
x
in
self
.
_input_placehdrs
],
name
=
'input_queue'
)
name
=
'input_queue'
)
logger
.
info
(
"Setting up the queue '{}' for CPU prefetching ..."
.
format
(
self
.
queue
.
name
))
logger
.
info
(
"Setting up the queue '{}' for CPU prefetching ..."
.
format
(
self
.
queue
.
name
))
self
.
thread
=
EnqueueThread
(
self
.
queue
,
self
.
ds
,
self
.
_input_placehdrs
)
self
.
thread
=
EnqueueThread
(
self
.
queue
,
self
.
_inf_
ds
,
self
.
_input_placehdrs
)
self
.
_dequeue_op
=
self
.
queue
.
dequeue
(
name
=
'dequeue_for_reset'
)
self
.
_dequeue_op
=
self
.
queue
.
dequeue
(
name
=
'dequeue_for_reset'
)
...
@@ -236,7 +239,7 @@ class QueueInput(FeedfreeInput):
...
@@ -236,7 +239,7 @@ class QueueInput(FeedfreeInput):
from
..callbacks.concurrency
import
StartProcOrThread
from
..callbacks.concurrency
import
StartProcOrThread
cb
=
StartProcOrThread
(
self
.
thread
)
cb
=
StartProcOrThread
(
self
.
thread
)
cb
.
chief_only
=
False
cb
.
chief_only
=
False
return
[
cb
,
self
.
_create_ema_callback
()]
return
[
cb
,
self
.
_create_ema_callback
()
,
_get_reset_callback
(
self
.
_inf_ds
)
]
def
_get_input_tensors
(
self
):
def
_get_input_tensors
(
self
):
with
tf
.
device
(
'/cpu:0'
),
self
.
cached_name_scope
():
with
tf
.
device
(
'/cpu:0'
),
self
.
cached_name_scope
():
...
@@ -299,7 +302,7 @@ class BatchQueueInput(QueueInput):
...
@@ -299,7 +302,7 @@ class BatchQueueInput(QueueInput):
for
shp
in
self
.
queue
.
shapes
:
for
shp
in
self
.
queue
.
shapes
:
assert
shp
.
is_fully_defined
(),
shape_err
assert
shp
.
is_fully_defined
(),
shape_err
self
.
thread
=
EnqueueThread
(
self
.
queue
,
self
.
ds
,
placehdrs_nobatch
)
self
.
thread
=
EnqueueThread
(
self
.
queue
,
self
.
_inf_
ds
,
placehdrs_nobatch
)
def
_get_input_tensors
(
self
):
def
_get_input_tensors
(
self
):
with
tf
.
device
(
'/cpu:0'
),
self
.
cached_name_scope
():
with
tf
.
device
(
'/cpu:0'
),
self
.
cached_name_scope
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment