Shashank Suhas / seminar-breakout · Commits

Commit ca0f0bd0
authored May 06, 2017 by Yuxin Wu

    put device ctx into TowerContext

parent b75ed18c

Showing 5 changed files with 50 additions and 39 deletions:

    tensorpack/predict/base.py       +2   -2
    tensorpack/tfutils/tower.py      +18  -5
    tensorpack/train/feedfree.py     +2   -2
    tensorpack/train/input_data.py   +23  -27
    tensorpack/train/multigpu.py     +5   -3
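For orientation, the calling pattern this commit moves to looks roughly like the following (my own sketch, not part of the diff; it assumes tensorpack at this revision and TensorFlow 1.x graph mode, and `build_graph()` is a hypothetical stand-in for a per-tower model-building function):

```python
import tensorflow as tf
from tensorpack.tfutils.tower import TowerContext

def build_graph():
    """Hypothetical per-tower model-building function."""
    pass

# Before this commit: callers pin the device themselves, next to the context.
with tf.device('/gpu:0'), TowerContext('tower0', is_training=True):
    build_graph()

# After this commit: the device travels with the TowerContext.
with TowerContext('tower0', device='/gpu:0', is_training=True):
    build_graph()
```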
tensorpack/predict/base.py

@@ -186,10 +186,10 @@ class PredictorTowerBuilder(object):
         msg = "Building predictor graph {} on gpu={} ...".format(towername, tower)
         logger.info(msg)
         # No matter where this get called, clear any existing name scope.
+        device = '/gpu:{}'.format(tower) if tower >= 0 else '/cpu:0'
         with tf.name_scope(None), \
                 freeze_collection(SUMMARY_BACKUP_KEYS), \
-                tf.device('/gpu:{}'.format(tower) if tower >= 0 else '/cpu:0'), \
-                TowerContext(towername, is_training=False):
+                TowerContext(towername, device=device, is_training=False):
             self._fn(tower)
             # useful only when the placeholders don't have tower prefix
tensorpack/tfutils/tower.py

@@ -15,13 +15,20 @@ _CurrentTowerContext = None
 class TowerContext(object):
     """ A context where the current model is being built in. """
 
-    def __init__(self, tower_name, is_training=None):
+    def __init__(self, tower_name, device=None, is_training=None):
         """
         Args:
             tower_name (str): 'tower0', 'towerp0', or ''
+            device (str): the device to use. Defaults to either cpu0 or gpu0.
             is_training (bool): if None, automatically determine from tower_name.
         """
         self._name = tower_name
+        if device is None:
+            device = '/gpu:0' if tf.test.is_gpu_available() else '/cpu:0'
+        assert self.index == int(device[-1]), \
+            "Tower name {} and device {} mismatch!".format(self._name, device)
+        self._device = device
+
         if is_training is None:
             is_training = not self._name.startswith(PREDICT_TOWER)
         self._is_training = is_training
@@ -48,6 +55,10 @@ class TowerContext(object):
             return 0
         return int(self._name[-1])
 
+    @property
+    def device(self):
+        return self._device
+
     def find_tensor_in_main_tower(self, graph, name):
         if self.is_main_tower:
             return graph.get_tensor_by_name(name)
@@ -79,16 +90,18 @@ class TowerContext(object):
         assert _CurrentTowerContext is None, \
             "Nesting TowerContext!"
         _CurrentTowerContext = self
+        # TODO enter name_scope(None) first
         if len(self._name):
-            self._scope = tf.name_scope(self._name)
-            return self._scope.__enter__()
+            self._scope_ctx = tf.name_scope(self._name)
+            self._scope_ctx.__enter__()
+        self._device_ctx = tf.device(self._device)
+        self._device_ctx.__enter__()
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         global _CurrentTowerContext
         _CurrentTowerContext = None
         if len(self._name):
-            self._scope.__exit__(exc_type, exc_val, exc_tb)
+            self._scope_ctx.__exit__(exc_type, exc_val, exc_tb)
+        self._device_ctx.__exit__(exc_type, exc_val, exc_tb)
         return False
 
     def __str__(self):
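A small sketch of the resulting behavior (again my own illustration, assuming tensorpack at this revision and TensorFlow 1.x graph mode): entering a `TowerContext` now also enters `tf.device()`, the chosen device is exposed as a property, and a device that disagrees with the tower index is rejected up front.

```python
import tensorflow as tf
from tensorpack.tfutils.tower import TowerContext, get_current_tower_context

with TowerContext('tower1', device='/gpu:1', is_training=True):
    ctx = get_current_tower_context()
    assert ctx.device == '/gpu:1' and ctx.index == 1
    # Ops created here get their device set by the context itself; callers no
    # longer need to wrap the block in tf.device() by hand.
    x = tf.zeros([8])

# A mismatched name/device pair now fails at construction time:
#   TowerContext('tower1', device='/gpu:0')
#   -> AssertionError: Tower name tower1 and device /gpu:0 mismatch!
```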
tensorpack/train/feedfree.py

@@ -27,8 +27,8 @@ class FeedfreeTrainerBase(Trainer):
             self._input_tensors = self._input_method.get_input_tensors()
             self.model.build_graph(self._input_tensors)
         ctx = get_current_tower_context()
-        if ctx is None:
-            with TowerContext(''):
+        if ctx is None:     # call without a context, use a default one
+            with TowerContext('', is_training=True):
                 f()
         else:
             assert ctx.is_training, ctx
tensorpack/train/input_data.py

@@ -17,6 +17,7 @@ from six.moves import range
 from ..dataflow import DataFlow, RepeatedData
 from ..tfutils.summary import add_moving_summary
 from ..tfutils import get_op_tensor_name
+from ..tfutils.tower import get_current_tower_context
 from ..utils import logger
 from ..utils.argtools import memoized
 from ..utils.concurrency import ShareSessionThread
@@ -168,13 +169,14 @@ class QueueInput(FeedfreeInput):
         trainer.register_callback(StartProcOrThread(self.thread))
 
     def get_input_tensors(self):
-        ret = self.queue.dequeue(name='input_deque')
-        if isinstance(ret, tf.Tensor):  # only one input
-            ret = [ret]
-        assert len(ret) == len(self.input_placehdrs)
-        for qv, v in zip(ret, self.input_placehdrs):
-            qv.set_shape(v.get_shape())
-        return ret
+        with tf.device('/cpu:0'):
+            ret = self.queue.dequeue(name='input_deque')
+            if isinstance(ret, tf.Tensor):  # only one input
+                ret = [ret]
+            assert len(ret) == len(self.input_placehdrs)
+            for qv, v in zip(ret, self.input_placehdrs):
+                qv.set_shape(v.get_shape())
+            return ret
 
 
 class BatchQueueInput(FeedfreeInput):
@@ -232,15 +234,16 @@ class BatchQueueInput(FeedfreeInput):
         trainer.register_callback(StartProcOrThread(self.thread))
 
     def get_input_tensors(self):
-        ret = self.queue.dequeue_many(self.batch_size, name='input_deque')
-        if isinstance(ret, tf.Tensor):  # only one input
-            ret = [ret]
-        assert len(ret) == len(self.input_placehdrs)
-        for qv, v in zip(ret, self.input_placehdrs):
-            shp = v.get_shape().as_list()
-            shp[0] = self.batch_size
-            qv.set_shape(shp)
-        return ret
+        with tf.device('/cpu:0'):
+            ret = self.queue.dequeue_many(self.batch_size, name='input_deque')
+            if isinstance(ret, tf.Tensor):  # only one input
+                ret = [ret]
+            assert len(ret) == len(self.input_placehdrs)
+            for qv, v in zip(ret, self.input_placehdrs):
+                shp = v.get_shape().as_list()
+                shp[0] = self.batch_size
+                qv.set_shape(shp)
+            return ret
 
 
 class DummyConstantInput(FeedfreeInput):
@@ -254,7 +257,6 @@ class DummyConstantInput(FeedfreeInput):
         """
         self.shapes = shapes
         logger.warn("Using dummy input for debug!")
-        self._cnt = 0
 
     def setup(self, model):
         self.input_placehdrs = model.get_reused_placehdrs()
@@ -271,7 +273,6 @@ class DummyConstantInput(FeedfreeInput):
         # don't share variables
         for tower in range(nr_tower):
             tlist = []
-            # TODO. keep device info in tower
             with tf.device('/gpu:{}'.format(tower)):
                 for idx, p in enumerate(placehdrs):
                     tlist.append(tf.get_variable(
@@ -280,9 +281,8 @@ class DummyConstantInput(FeedfreeInput):
             self.tensors.append(tlist)
 
     def get_input_tensors(self):
-        # TODO XXX call with tower index
-        ret = self.tensors[self._cnt]
-        self._cnt += 1
+        ctx = get_current_tower_context()
+        ret = self.tensors[ctx.index]
         return ret
@@ -359,8 +359,6 @@ class StagingInputWrapper(FeedfreeInput):
         self._stage_ops = []
         self._unstage_ops = []
-        self._cnt_unstage = 0
 
     def setup(self, model):
         self._input.setup(model)
         self.setup_staging_areas()
@@ -390,10 +388,8 @@ class StagingInputWrapper(FeedfreeInput):
         return self._input.size()
 
     def get_input_tensors(self):
-        assert self._cnt_unstage < len(self._areas)
-        assert len(self._areas) == len(self._devices)
-        ret = self._unstage_ops[self._cnt_unstage]
-        self._cnt_unstage += 1
+        ctx = get_current_tower_context()
+        ret = self._unstage_ops[ctx.index]
         return ret
 
     @staticmethod
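The input-source changes above all follow one pattern: instead of keeping a per-instance call counter, `get_input_tensors()` asks the active context which tower is currently being built. A minimal sketch of that pattern with a hypothetical `PerTowerInput` class (not part of the repo):

```python
from tensorpack.tfutils.tower import get_current_tower_context

class PerTowerInput(object):
    """Hypothetical input source holding one tensor list per training tower."""

    def __init__(self, per_tower_tensors):
        self.tensors = per_tower_tensors

    def get_input_tensors(self):
        ctx = get_current_tower_context()
        # The tower index comes from the surrounding TowerContext, so repeated
        # or out-of-order graph builds cannot desynchronize an internal counter.
        return self.tensors[ctx.index]
```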
tensorpack/train/multigpu.py

@@ -40,9 +40,11 @@ class MultiGPUTrainer(Trainer):
         ret = []
         global_scope = tf.get_variable_scope()
         for idx, t in enumerate(towers):
-            with tf.device('/gpu:{}'.format(t)), \
-                    tf.variable_scope(global_scope, reuse=idx > 0), \
-                    TowerContext('tower{}'.format(idx)):
+            with tf.variable_scope(global_scope, reuse=idx > 0), \
+                    TowerContext(
+                        'tower{}'.format(idx),
+                        device='/gpu:{}'.format(t),
+                        is_training=True):
                 logger.info("Building graph for training tower {}...".format(idx))
                 ret.append(func())