Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
ca0f0bd0
You need to sign in or sign up before continuing.
Commit
ca0f0bd0
authored
May 06, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
put device ctx into TowerContext
parent
b75ed18c
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
50 additions
and
39 deletions
+50
-39
tensorpack/predict/base.py
tensorpack/predict/base.py
+2
-2
tensorpack/tfutils/tower.py
tensorpack/tfutils/tower.py
+18
-5
tensorpack/train/feedfree.py
tensorpack/train/feedfree.py
+2
-2
tensorpack/train/input_data.py
tensorpack/train/input_data.py
+23
-27
tensorpack/train/multigpu.py
tensorpack/train/multigpu.py
+5
-3
No files found.
tensorpack/predict/base.py
View file @
ca0f0bd0
...
...
@@ -186,10 +186,10 @@ class PredictorTowerBuilder(object):
msg
=
"Building predictor graph {} on gpu={} ..."
.
format
(
towername
,
tower
)
logger
.
info
(
msg
)
# No matter where this get called, clear any existing name scope.
device
=
'/gpu:{}'
.
format
(
tower
)
if
tower
>=
0
else
'/cpu:0'
with
tf
.
name_scope
(
None
),
\
freeze_collection
(
SUMMARY_BACKUP_KEYS
),
\
tf
.
device
(
'/gpu:{}'
.
format
(
tower
)
if
tower
>=
0
else
'/cpu:0'
),
\
TowerContext
(
towername
,
is_training
=
False
):
TowerContext
(
towername
,
device
=
device
,
is_training
=
False
):
self
.
_fn
(
tower
)
# useful only when the placeholders don't have tower prefix
...
...
tensorpack/tfutils/tower.py
View file @
ca0f0bd0
...
...
@@ -15,13 +15,20 @@ _CurrentTowerContext = None
class
TowerContext
(
object
):
""" A context where the current model is being built in. """
def
__init__
(
self
,
tower_name
,
is_training
=
None
):
def
__init__
(
self
,
tower_name
,
device
=
None
,
is_training
=
None
):
"""
Args:
tower_name (str): 'tower0', 'towerp0', or ''
device (str): the device to use. Defaults to either cpu0 or gpu0.
is_training (bool): if None, automatically determine from tower_name.
"""
self
.
_name
=
tower_name
if
device
is
None
:
device
=
'/gpu:0'
if
tf
.
test
.
is_gpu_available
()
else
'/cpu:0'
assert
self
.
index
==
int
(
device
[
-
1
]),
\
"Tower name {} and device {} mismatch!"
.
format
(
self
.
_name
,
device
)
self
.
_device
=
device
if
is_training
is
None
:
is_training
=
not
self
.
_name
.
startswith
(
PREDICT_TOWER
)
self
.
_is_training
=
is_training
...
...
@@ -48,6 +55,10 @@ class TowerContext(object):
return
0
return
int
(
self
.
_name
[
-
1
])
@
property
def
device
(
self
):
return
self
.
_device
def
find_tensor_in_main_tower
(
self
,
graph
,
name
):
if
self
.
is_main_tower
:
return
graph
.
get_tensor_by_name
(
name
)
...
...
@@ -79,16 +90,18 @@ class TowerContext(object):
assert
_CurrentTowerContext
is
None
,
\
"Nesting TowerContext!"
_CurrentTowerContext
=
self
# TODO enter name_scope(None) first
if
len
(
self
.
_name
):
self
.
_scope
=
tf
.
name_scope
(
self
.
_name
)
return
self
.
_scope
.
__enter__
()
self
.
_scope_ctx
=
tf
.
name_scope
(
self
.
_name
)
self
.
_scope_ctx
.
__enter__
()
self
.
_device_ctx
=
tf
.
device
(
self
.
_device
)
self
.
_device_ctx
.
__enter__
()
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
global
_CurrentTowerContext
_CurrentTowerContext
=
None
if
len
(
self
.
_name
):
self
.
_scope
.
__exit__
(
exc_type
,
exc_val
,
exc_tb
)
self
.
_scope_ctx
.
__exit__
(
exc_type
,
exc_val
,
exc_tb
)
self
.
_device_ctx
.
__exit__
(
exc_type
,
exc_val
,
exc_tb
)
return
False
def
__str__
(
self
):
...
...
tensorpack/train/feedfree.py
View file @
ca0f0bd0
...
...
@@ -27,8 +27,8 @@ class FeedfreeTrainerBase(Trainer):
self
.
_input_tensors
=
self
.
_input_method
.
get_input_tensors
()
self
.
model
.
build_graph
(
self
.
_input_tensors
)
ctx
=
get_current_tower_context
()
if
ctx
is
None
:
with
TowerContext
(
''
):
if
ctx
is
None
:
# call without a context, use a default one
with
TowerContext
(
''
,
is_training
=
True
):
f
()
else
:
assert
ctx
.
is_training
,
ctx
...
...
tensorpack/train/input_data.py
View file @
ca0f0bd0
...
...
@@ -17,6 +17,7 @@ from six.moves import range
from
..dataflow
import
DataFlow
,
RepeatedData
from
..tfutils.summary
import
add_moving_summary
from
..tfutils
import
get_op_tensor_name
from
..tfutils.tower
import
get_current_tower_context
from
..utils
import
logger
from
..utils.argtools
import
memoized
from
..utils.concurrency
import
ShareSessionThread
...
...
@@ -168,13 +169,14 @@ class QueueInput(FeedfreeInput):
trainer
.
register_callback
(
StartProcOrThread
(
self
.
thread
))
def
get_input_tensors
(
self
):
ret
=
self
.
queue
.
dequeue
(
name
=
'input_deque'
)
if
isinstance
(
ret
,
tf
.
Tensor
):
# only one input
ret
=
[
ret
]
assert
len
(
ret
)
==
len
(
self
.
input_placehdrs
)
for
qv
,
v
in
zip
(
ret
,
self
.
input_placehdrs
):
qv
.
set_shape
(
v
.
get_shape
())
return
ret
with
tf
.
device
(
'/cpu:0'
):
ret
=
self
.
queue
.
dequeue
(
name
=
'input_deque'
)
if
isinstance
(
ret
,
tf
.
Tensor
):
# only one input
ret
=
[
ret
]
assert
len
(
ret
)
==
len
(
self
.
input_placehdrs
)
for
qv
,
v
in
zip
(
ret
,
self
.
input_placehdrs
):
qv
.
set_shape
(
v
.
get_shape
())
return
ret
class
BatchQueueInput
(
FeedfreeInput
):
...
...
@@ -232,15 +234,16 @@ class BatchQueueInput(FeedfreeInput):
trainer
.
register_callback
(
StartProcOrThread
(
self
.
thread
))
def
get_input_tensors
(
self
):
ret
=
self
.
queue
.
dequeue_many
(
self
.
batch_size
,
name
=
'input_deque'
)
if
isinstance
(
ret
,
tf
.
Tensor
):
# only one input
ret
=
[
ret
]
assert
len
(
ret
)
==
len
(
self
.
input_placehdrs
)
for
qv
,
v
in
zip
(
ret
,
self
.
input_placehdrs
):
shp
=
v
.
get_shape
()
.
as_list
()
shp
[
0
]
=
self
.
batch_size
qv
.
set_shape
(
shp
)
return
ret
with
tf
.
device
(
'/cpu:0'
):
ret
=
self
.
queue
.
dequeue_many
(
self
.
batch_size
,
name
=
'input_deque'
)
if
isinstance
(
ret
,
tf
.
Tensor
):
# only one input
ret
=
[
ret
]
assert
len
(
ret
)
==
len
(
self
.
input_placehdrs
)
for
qv
,
v
in
zip
(
ret
,
self
.
input_placehdrs
):
shp
=
v
.
get_shape
()
.
as_list
()
shp
[
0
]
=
self
.
batch_size
qv
.
set_shape
(
shp
)
return
ret
class
DummyConstantInput
(
FeedfreeInput
):
...
...
@@ -254,7 +257,6 @@ class DummyConstantInput(FeedfreeInput):
"""
self
.
shapes
=
shapes
logger
.
warn
(
"Using dummy input for debug!"
)
self
.
_cnt
=
0
def
setup
(
self
,
model
):
self
.
input_placehdrs
=
model
.
get_reused_placehdrs
()
...
...
@@ -271,7 +273,6 @@ class DummyConstantInput(FeedfreeInput):
# don't share variables
for
tower
in
range
(
nr_tower
):
tlist
=
[]
# TODO. keep device info in tower
with
tf
.
device
(
'/gpu:{}'
.
format
(
tower
)):
for
idx
,
p
in
enumerate
(
placehdrs
):
tlist
.
append
(
tf
.
get_variable
(
...
...
@@ -280,9 +281,8 @@ class DummyConstantInput(FeedfreeInput):
self
.
tensors
.
append
(
tlist
)
def
get_input_tensors
(
self
):
# TODO XXX call with tower index
ret
=
self
.
tensors
[
self
.
_cnt
]
self
.
_cnt
+=
1
ctx
=
get_current_tower_context
()
ret
=
self
.
tensors
[
ctx
.
index
]
return
ret
...
...
@@ -359,8 +359,6 @@ class StagingInputWrapper(FeedfreeInput):
self
.
_stage_ops
=
[]
self
.
_unstage_ops
=
[]
self
.
_cnt_unstage
=
0
def
setup
(
self
,
model
):
self
.
_input
.
setup
(
model
)
self
.
setup_staging_areas
()
...
...
@@ -390,10 +388,8 @@ class StagingInputWrapper(FeedfreeInput):
return
self
.
_input
.
size
()
def
get_input_tensors
(
self
):
assert
self
.
_cnt_unstage
<
len
(
self
.
_areas
)
assert
len
(
self
.
_areas
)
==
len
(
self
.
_devices
)
ret
=
self
.
_unstage_ops
[
self
.
_cnt_unstage
]
self
.
_cnt_unstage
+=
1
ctx
=
get_current_tower_context
()
ret
=
self
.
_unstage_ops
[
ctx
.
index
]
return
ret
@
staticmethod
...
...
tensorpack/train/multigpu.py
View file @
ca0f0bd0
...
...
@@ -40,9 +40,11 @@ class MultiGPUTrainer(Trainer):
ret
=
[]
global_scope
=
tf
.
get_variable_scope
()
for
idx
,
t
in
enumerate
(
towers
):
with
tf
.
device
(
'/gpu:{}'
.
format
(
t
)),
\
tf
.
variable_scope
(
global_scope
,
reuse
=
idx
>
0
),
\
TowerContext
(
'tower{}'
.
format
(
idx
)):
with
tf
.
variable_scope
(
global_scope
,
reuse
=
idx
>
0
),
\
TowerContext
(
'tower{}'
.
format
(
idx
),
device
=
'/gpu:{}'
.
format
(
t
),
is_training
=
True
):
logger
.
info
(
"Building graph for training tower {}..."
.
format
(
idx
))
ret
.
append
(
func
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment