Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
be4759be
Commit
be4759be
authored
Jan 15, 2018
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
misc update
parent
58c2779f
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
23 additions
and
9 deletions
+23
-9
tensorpack/dataflow/parallel_map.py
tensorpack/dataflow/parallel_map.py
+2
-2
tensorpack/graph_builder/utils.py
tensorpack/graph_builder/utils.py
+16
-4
tensorpack/input_source/input_source.py
tensorpack/input_source/input_source.py
+3
-1
tensorpack/utils/concurrency.py
tensorpack/utils/concurrency.py
+2
-2
No files found.
tensorpack/dataflow/parallel_map.py
View file @
be4759be
...
@@ -20,8 +20,7 @@ from .parallel import (
...
@@ -20,8 +20,7 @@ from .parallel import (
__all__
=
[
'ThreadedMapData'
,
'MultiThreadMapData'
,
__all__
=
[
'ThreadedMapData'
,
'MultiThreadMapData'
,
'MultiProcessMapData'
,
'MultiProcessMapDataZMQ'
,
'MultiProcessMapData'
,
'MultiProcessMapDataZMQ'
]
'MultiProcessMapDataComponentSharedArray'
]
class
_ParallelMapData
(
ProxyDataFlow
):
class
_ParallelMapData
(
ProxyDataFlow
):
...
@@ -302,6 +301,7 @@ def _pool_map(data):
...
@@ -302,6 +301,7 @@ def _pool_map(data):
return
WORKER_ID
return
WORKER_ID
# TODO shutdown pool, improve speed.
class
MultiProcessMapDataComponentSharedArray
(
DataFlow
):
class
MultiProcessMapDataComponentSharedArray
(
DataFlow
):
"""
"""
Similar to :class:`MapDataComponent`, but perform IPC by shared memory,
Similar to :class:`MapDataComponent`, but perform IPC by shared memory,
...
...
tensorpack/graph_builder/utils.py
View file @
be4759be
...
@@ -7,6 +7,8 @@ from contextlib import contextmanager
...
@@ -7,6 +7,8 @@ from contextlib import contextmanager
import
operator
import
operator
import
tensorflow
as
tf
import
tensorflow
as
tf
from
..tfutils.common
import
get_tf_version_number
__all__
=
[
'LeastLoadedDeviceSetter'
,
__all__
=
[
'LeastLoadedDeviceSetter'
,
'OverrideCachingDevice'
,
'OverrideCachingDevice'
,
...
@@ -41,12 +43,22 @@ def override_to_local_variable(enable=True):
...
@@ -41,12 +43,22 @@ def override_to_local_variable(enable=True):
return
getter
(
name
,
*
args
,
**
kwargs
)
return
getter
(
name
,
*
args
,
**
kwargs
)
orig_vs
=
tf
.
get_variable_scope
()
orig_vs
=
tf
.
get_variable_scope
()
# TODO TF1.5 has https://github.com/tensorflow/tensorflow/pull/14390
if
get_tf_version_number
()
>=
1.5
:
with
tf
.
variable_scope
(
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
tf
.
get_variable_scope
(),
custom_getter
=
custom_getter
):
custom_getter
=
custom_getter
,
with
tf
.
name_scope
(
orig_vs
.
original_name_scop
e
):
auxiliary_name_scope
=
Fals
e
):
yield
yield
else
:
if
get_tf_version_number
()
>=
1.2
:
ns
=
tf
.
get_default_graph
()
.
get_name_scope
()
else
:
ns
=
tf
.
get_variable_scope
()
.
original_name_scope
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
custom_getter
=
custom_getter
):
with
tf
.
name_scope
(
ns
+
'/'
):
yield
else
:
else
:
yield
yield
...
...
tensorpack/input_source/input_source.py
View file @
be4759be
...
@@ -118,9 +118,10 @@ class EnqueueThread(ShareSessionThread):
...
@@ -118,9 +118,10 @@ class EnqueueThread(ShareSessionThread):
self
.
close_op
=
self
.
queue
.
close
(
cancel_pending_enqueues
=
True
)
self
.
close_op
=
self
.
queue
.
close
(
cancel_pending_enqueues
=
True
)
self
.
_lock
=
threading
.
Lock
()
self
.
_lock
=
threading
.
Lock
()
# self._size = queue.size()
def
run
(
self
):
def
run
(
self
):
with
self
.
default_sess
():
with
self
.
default_sess
()
as
sess
:
try
:
try
:
self
.
reinitialize_dataflow
()
self
.
reinitialize_dataflow
()
while
True
:
while
True
:
...
@@ -130,6 +131,7 @@ class EnqueueThread(ShareSessionThread):
...
@@ -130,6 +131,7 @@ class EnqueueThread(ShareSessionThread):
dp
=
next
(
self
.
_itr
)
dp
=
next
(
self
.
_itr
)
feed
=
dict
(
zip
(
self
.
placehdrs
,
dp
))
feed
=
dict
(
zip
(
self
.
placehdrs
,
dp
))
# _, sz = sess.run([self.op, self._sz], feed_dict=feed)
self
.
op
.
run
(
feed_dict
=
feed
)
self
.
op
.
run
(
feed_dict
=
feed
)
except
(
tf
.
errors
.
CancelledError
,
tf
.
errors
.
OutOfRangeError
,
DataFlowTerminated
):
except
(
tf
.
errors
.
CancelledError
,
tf
.
errors
.
OutOfRangeError
,
DataFlowTerminated
):
pass
pass
...
...
tensorpack/utils/concurrency.py
View file @
be4759be
...
@@ -124,10 +124,10 @@ class ShareSessionThread(threading.Thread):
...
@@ -124,10 +124,10 @@ class ShareSessionThread(threading.Thread):
def
default_sess
(
self
):
def
default_sess
(
self
):
if
self
.
_sess
:
if
self
.
_sess
:
with
self
.
_sess
.
as_default
():
with
self
.
_sess
.
as_default
():
yield
yield
self
.
_sess
else
:
else
:
logger
.
warn
(
"ShareSessionThread {} wasn't under a default session!"
.
format
(
self
.
name
))
logger
.
warn
(
"ShareSessionThread {} wasn't under a default session!"
.
format
(
self
.
name
))
yield
yield
None
def
start
(
self
):
def
start
(
self
):
import
tensorflow
as
tf
import
tensorflow
as
tf
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment