Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
ca0f0bd0
You need to sign in or sign up before continuing.
Commit
ca0f0bd0
authored
May 06, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
put device ctx into TowerContext
parent
b75ed18c
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
50 additions
and
39 deletions
+50
-39
tensorpack/predict/base.py
tensorpack/predict/base.py
+2
-2
tensorpack/tfutils/tower.py
tensorpack/tfutils/tower.py
+18
-5
tensorpack/train/feedfree.py
tensorpack/train/feedfree.py
+2
-2
tensorpack/train/input_data.py
tensorpack/train/input_data.py
+23
-27
tensorpack/train/multigpu.py
tensorpack/train/multigpu.py
+5
-3
No files found.
tensorpack/predict/base.py
View file @
ca0f0bd0
...
@@ -186,10 +186,10 @@ class PredictorTowerBuilder(object):
...
@@ -186,10 +186,10 @@ class PredictorTowerBuilder(object):
msg
=
"Building predictor graph {} on gpu={} ..."
.
format
(
towername
,
tower
)
msg
=
"Building predictor graph {} on gpu={} ..."
.
format
(
towername
,
tower
)
logger
.
info
(
msg
)
logger
.
info
(
msg
)
# No matter where this get called, clear any existing name scope.
# No matter where this get called, clear any existing name scope.
device
=
'/gpu:{}'
.
format
(
tower
)
if
tower
>=
0
else
'/cpu:0'
with
tf
.
name_scope
(
None
),
\
with
tf
.
name_scope
(
None
),
\
freeze_collection
(
SUMMARY_BACKUP_KEYS
),
\
freeze_collection
(
SUMMARY_BACKUP_KEYS
),
\
tf
.
device
(
'/gpu:{}'
.
format
(
tower
)
if
tower
>=
0
else
'/cpu:0'
),
\
TowerContext
(
towername
,
device
=
device
,
is_training
=
False
):
TowerContext
(
towername
,
is_training
=
False
):
self
.
_fn
(
tower
)
self
.
_fn
(
tower
)
# useful only when the placeholders don't have tower prefix
# useful only when the placeholders don't have tower prefix
...
...
tensorpack/tfutils/tower.py
View file @
ca0f0bd0
...
@@ -15,13 +15,20 @@ _CurrentTowerContext = None
...
@@ -15,13 +15,20 @@ _CurrentTowerContext = None
class
TowerContext
(
object
):
class
TowerContext
(
object
):
""" A context where the current model is being built in. """
""" A context where the current model is being built in. """
def
__init__
(
self
,
tower_name
,
is_training
=
None
):
def
__init__
(
self
,
tower_name
,
device
=
None
,
is_training
=
None
):
"""
"""
Args:
Args:
tower_name (str): 'tower0', 'towerp0', or ''
tower_name (str): 'tower0', 'towerp0', or ''
device (str): the device to use. Defaults to either cpu0 or gpu0.
is_training (bool): if None, automatically determine from tower_name.
is_training (bool): if None, automatically determine from tower_name.
"""
"""
self
.
_name
=
tower_name
self
.
_name
=
tower_name
if
device
is
None
:
device
=
'/gpu:0'
if
tf
.
test
.
is_gpu_available
()
else
'/cpu:0'
assert
self
.
index
==
int
(
device
[
-
1
]),
\
"Tower name {} and device {} mismatch!"
.
format
(
self
.
_name
,
device
)
self
.
_device
=
device
if
is_training
is
None
:
if
is_training
is
None
:
is_training
=
not
self
.
_name
.
startswith
(
PREDICT_TOWER
)
is_training
=
not
self
.
_name
.
startswith
(
PREDICT_TOWER
)
self
.
_is_training
=
is_training
self
.
_is_training
=
is_training
...
@@ -48,6 +55,10 @@ class TowerContext(object):
...
@@ -48,6 +55,10 @@ class TowerContext(object):
return
0
return
0
return
int
(
self
.
_name
[
-
1
])
return
int
(
self
.
_name
[
-
1
])
@
property
def
device
(
self
):
return
self
.
_device
def
find_tensor_in_main_tower
(
self
,
graph
,
name
):
def
find_tensor_in_main_tower
(
self
,
graph
,
name
):
if
self
.
is_main_tower
:
if
self
.
is_main_tower
:
return
graph
.
get_tensor_by_name
(
name
)
return
graph
.
get_tensor_by_name
(
name
)
...
@@ -79,16 +90,18 @@ class TowerContext(object):
...
@@ -79,16 +90,18 @@ class TowerContext(object):
assert
_CurrentTowerContext
is
None
,
\
assert
_CurrentTowerContext
is
None
,
\
"Nesting TowerContext!"
"Nesting TowerContext!"
_CurrentTowerContext
=
self
_CurrentTowerContext
=
self
# TODO enter name_scope(None) first
if
len
(
self
.
_name
):
if
len
(
self
.
_name
):
self
.
_scope
=
tf
.
name_scope
(
self
.
_name
)
self
.
_scope_ctx
=
tf
.
name_scope
(
self
.
_name
)
return
self
.
_scope
.
__enter__
()
self
.
_scope_ctx
.
__enter__
()
self
.
_device_ctx
=
tf
.
device
(
self
.
_device
)
self
.
_device_ctx
.
__enter__
()
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
global
_CurrentTowerContext
global
_CurrentTowerContext
_CurrentTowerContext
=
None
_CurrentTowerContext
=
None
if
len
(
self
.
_name
):
if
len
(
self
.
_name
):
self
.
_scope
.
__exit__
(
exc_type
,
exc_val
,
exc_tb
)
self
.
_scope_ctx
.
__exit__
(
exc_type
,
exc_val
,
exc_tb
)
self
.
_device_ctx
.
__exit__
(
exc_type
,
exc_val
,
exc_tb
)
return
False
return
False
def
__str__
(
self
):
def
__str__
(
self
):
...
...
tensorpack/train/feedfree.py
View file @
ca0f0bd0
...
@@ -27,8 +27,8 @@ class FeedfreeTrainerBase(Trainer):
...
@@ -27,8 +27,8 @@ class FeedfreeTrainerBase(Trainer):
self
.
_input_tensors
=
self
.
_input_method
.
get_input_tensors
()
self
.
_input_tensors
=
self
.
_input_method
.
get_input_tensors
()
self
.
model
.
build_graph
(
self
.
_input_tensors
)
self
.
model
.
build_graph
(
self
.
_input_tensors
)
ctx
=
get_current_tower_context
()
ctx
=
get_current_tower_context
()
if
ctx
is
None
:
if
ctx
is
None
:
# call without a context, use a default one
with
TowerContext
(
''
):
with
TowerContext
(
''
,
is_training
=
True
):
f
()
f
()
else
:
else
:
assert
ctx
.
is_training
,
ctx
assert
ctx
.
is_training
,
ctx
...
...
tensorpack/train/input_data.py
View file @
ca0f0bd0
...
@@ -17,6 +17,7 @@ from six.moves import range
...
@@ -17,6 +17,7 @@ from six.moves import range
from
..dataflow
import
DataFlow
,
RepeatedData
from
..dataflow
import
DataFlow
,
RepeatedData
from
..tfutils.summary
import
add_moving_summary
from
..tfutils.summary
import
add_moving_summary
from
..tfutils
import
get_op_tensor_name
from
..tfutils
import
get_op_tensor_name
from
..tfutils.tower
import
get_current_tower_context
from
..utils
import
logger
from
..utils
import
logger
from
..utils.argtools
import
memoized
from
..utils.argtools
import
memoized
from
..utils.concurrency
import
ShareSessionThread
from
..utils.concurrency
import
ShareSessionThread
...
@@ -168,13 +169,14 @@ class QueueInput(FeedfreeInput):
...
@@ -168,13 +169,14 @@ class QueueInput(FeedfreeInput):
trainer
.
register_callback
(
StartProcOrThread
(
self
.
thread
))
trainer
.
register_callback
(
StartProcOrThread
(
self
.
thread
))
def
get_input_tensors
(
self
):
def
get_input_tensors
(
self
):
ret
=
self
.
queue
.
dequeue
(
name
=
'input_deque'
)
with
tf
.
device
(
'/cpu:0'
):
if
isinstance
(
ret
,
tf
.
Tensor
):
# only one input
ret
=
self
.
queue
.
dequeue
(
name
=
'input_deque'
)
ret
=
[
ret
]
if
isinstance
(
ret
,
tf
.
Tensor
):
# only one input
assert
len
(
ret
)
==
len
(
self
.
input_placehdrs
)
ret
=
[
ret
]
for
qv
,
v
in
zip
(
ret
,
self
.
input_placehdrs
):
assert
len
(
ret
)
==
len
(
self
.
input_placehdrs
)
qv
.
set_shape
(
v
.
get_shape
())
for
qv
,
v
in
zip
(
ret
,
self
.
input_placehdrs
):
return
ret
qv
.
set_shape
(
v
.
get_shape
())
return
ret
class
BatchQueueInput
(
FeedfreeInput
):
class
BatchQueueInput
(
FeedfreeInput
):
...
@@ -232,15 +234,16 @@ class BatchQueueInput(FeedfreeInput):
...
@@ -232,15 +234,16 @@ class BatchQueueInput(FeedfreeInput):
trainer
.
register_callback
(
StartProcOrThread
(
self
.
thread
))
trainer
.
register_callback
(
StartProcOrThread
(
self
.
thread
))
def
get_input_tensors
(
self
):
def
get_input_tensors
(
self
):
ret
=
self
.
queue
.
dequeue_many
(
self
.
batch_size
,
name
=
'input_deque'
)
with
tf
.
device
(
'/cpu:0'
):
if
isinstance
(
ret
,
tf
.
Tensor
):
# only one input
ret
=
self
.
queue
.
dequeue_many
(
self
.
batch_size
,
name
=
'input_deque'
)
ret
=
[
ret
]
if
isinstance
(
ret
,
tf
.
Tensor
):
# only one input
assert
len
(
ret
)
==
len
(
self
.
input_placehdrs
)
ret
=
[
ret
]
for
qv
,
v
in
zip
(
ret
,
self
.
input_placehdrs
):
assert
len
(
ret
)
==
len
(
self
.
input_placehdrs
)
shp
=
v
.
get_shape
()
.
as_list
()
for
qv
,
v
in
zip
(
ret
,
self
.
input_placehdrs
):
shp
[
0
]
=
self
.
batch_size
shp
=
v
.
get_shape
()
.
as_list
()
qv
.
set_shape
(
shp
)
shp
[
0
]
=
self
.
batch_size
return
ret
qv
.
set_shape
(
shp
)
return
ret
class
DummyConstantInput
(
FeedfreeInput
):
class
DummyConstantInput
(
FeedfreeInput
):
...
@@ -254,7 +257,6 @@ class DummyConstantInput(FeedfreeInput):
...
@@ -254,7 +257,6 @@ class DummyConstantInput(FeedfreeInput):
"""
"""
self
.
shapes
=
shapes
self
.
shapes
=
shapes
logger
.
warn
(
"Using dummy input for debug!"
)
logger
.
warn
(
"Using dummy input for debug!"
)
self
.
_cnt
=
0
def
setup
(
self
,
model
):
def
setup
(
self
,
model
):
self
.
input_placehdrs
=
model
.
get_reused_placehdrs
()
self
.
input_placehdrs
=
model
.
get_reused_placehdrs
()
...
@@ -271,7 +273,6 @@ class DummyConstantInput(FeedfreeInput):
...
@@ -271,7 +273,6 @@ class DummyConstantInput(FeedfreeInput):
# don't share variables
# don't share variables
for
tower
in
range
(
nr_tower
):
for
tower
in
range
(
nr_tower
):
tlist
=
[]
tlist
=
[]
# TODO. keep device info in tower
with
tf
.
device
(
'/gpu:{}'
.
format
(
tower
)):
with
tf
.
device
(
'/gpu:{}'
.
format
(
tower
)):
for
idx
,
p
in
enumerate
(
placehdrs
):
for
idx
,
p
in
enumerate
(
placehdrs
):
tlist
.
append
(
tf
.
get_variable
(
tlist
.
append
(
tf
.
get_variable
(
...
@@ -280,9 +281,8 @@ class DummyConstantInput(FeedfreeInput):
...
@@ -280,9 +281,8 @@ class DummyConstantInput(FeedfreeInput):
self
.
tensors
.
append
(
tlist
)
self
.
tensors
.
append
(
tlist
)
def
get_input_tensors
(
self
):
def
get_input_tensors
(
self
):
# TODO XXX call with tower index
ctx
=
get_current_tower_context
()
ret
=
self
.
tensors
[
self
.
_cnt
]
ret
=
self
.
tensors
[
ctx
.
index
]
self
.
_cnt
+=
1
return
ret
return
ret
...
@@ -359,8 +359,6 @@ class StagingInputWrapper(FeedfreeInput):
...
@@ -359,8 +359,6 @@ class StagingInputWrapper(FeedfreeInput):
self
.
_stage_ops
=
[]
self
.
_stage_ops
=
[]
self
.
_unstage_ops
=
[]
self
.
_unstage_ops
=
[]
self
.
_cnt_unstage
=
0
def
setup
(
self
,
model
):
def
setup
(
self
,
model
):
self
.
_input
.
setup
(
model
)
self
.
_input
.
setup
(
model
)
self
.
setup_staging_areas
()
self
.
setup_staging_areas
()
...
@@ -390,10 +388,8 @@ class StagingInputWrapper(FeedfreeInput):
...
@@ -390,10 +388,8 @@ class StagingInputWrapper(FeedfreeInput):
return
self
.
_input
.
size
()
return
self
.
_input
.
size
()
def
get_input_tensors
(
self
):
def
get_input_tensors
(
self
):
assert
self
.
_cnt_unstage
<
len
(
self
.
_areas
)
ctx
=
get_current_tower_context
()
assert
len
(
self
.
_areas
)
==
len
(
self
.
_devices
)
ret
=
self
.
_unstage_ops
[
ctx
.
index
]
ret
=
self
.
_unstage_ops
[
self
.
_cnt_unstage
]
self
.
_cnt_unstage
+=
1
return
ret
return
ret
@
staticmethod
@
staticmethod
...
...
tensorpack/train/multigpu.py
View file @
ca0f0bd0
...
@@ -40,9 +40,11 @@ class MultiGPUTrainer(Trainer):
...
@@ -40,9 +40,11 @@ class MultiGPUTrainer(Trainer):
ret
=
[]
ret
=
[]
global_scope
=
tf
.
get_variable_scope
()
global_scope
=
tf
.
get_variable_scope
()
for
idx
,
t
in
enumerate
(
towers
):
for
idx
,
t
in
enumerate
(
towers
):
with
tf
.
device
(
'/gpu:{}'
.
format
(
t
)),
\
with
tf
.
variable_scope
(
global_scope
,
reuse
=
idx
>
0
),
\
tf
.
variable_scope
(
global_scope
,
reuse
=
idx
>
0
),
\
TowerContext
(
TowerContext
(
'tower{}'
.
format
(
idx
)):
'tower{}'
.
format
(
idx
),
device
=
'/gpu:{}'
.
format
(
t
),
is_training
=
True
):
logger
.
info
(
"Building graph for training tower {}..."
.
format
(
idx
))
logger
.
info
(
"Building graph for training tower {}..."
.
format
(
idx
))
ret
.
append
(
func
())
ret
.
append
(
func
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment