Shashank Suhas / seminar-breakout

Commit 3ab6d2b0, authored Jun 03, 2017 by Yuxin Wu
Parent: 95cb6ba2

    variable scope issues for model saving / predictor

Showing 8 changed files with 88 additions and 48 deletions.
Changed files:

    tensorpack/callbacks/inference_runner.py   +2   -1
    tensorpack/models/regularize.py            +1   -1
    tensorpack/tfutils/tower.py                +18  -12
    tensorpack/train/base.py                   +21  -2
    tensorpack/train/distributed.py            +26  -26
    tensorpack/train/feedfree.py               +1   -1
    tensorpack/train/multigpu.py               +15  -4
    tensorpack/train/predict.py                +4   -1
tensorpack/callbacks/inference_runner.py

@@ -90,7 +90,8 @@ class InferenceRunnerBase(Callback):
         def fn(_):
             in_tensors = self._input_source.get_input_tensors()
             self.trainer.model.build_graph(in_tensors)
-        PredictorTowerBuilder(fn, self._prefix).build(self._predict_tower_id)
+        with tf.variable_scope(self.trainer.vs_name_for_predictor, reuse=True):
+            PredictorTowerBuilder(fn, self._prefix).build(self._predict_tower_id)
         self._hooks = [self._build_hook(inf) for inf in self.infs]
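Note (not part of the commit): the change above relies on TF 1.x variable sharing. Rebuilding a subgraph inside a variable scope opened with reuse=True makes tf.get_variable resolve to the already-created training variables instead of creating fresh ones. A minimal, self-contained sketch with illustrative names:

    import tensorflow as tf

    with tf.variable_scope("tower0"):
        w = tf.get_variable("W", shape=[3])          # created by the training graph

    # rebuild the same subgraph for inference, reusing the existing variable
    with tf.variable_scope("tower0", reuse=True):
        w_again = tf.get_variable("W", shape=[3])

    assert w is w_again                              # same variable object, nothing duplicated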
tensorpack/models/regularize.py

@@ -47,7 +47,7 @@ def regularize_cost(regex, func, name='regularize_cost'):
     for p in params:
         para_name = p.name
         # in replicated mode, only regularize variables inside this tower
-        if ctx.has_own_variables and (not para_name.startswith(ctx.vs_name)):
+        if ctx.has_own_variables and ctx.vs_name and (not para_name.startswith(ctx.vs_name)):
             continue
         if re.search(regex, para_name):
             costs.append(func(p))
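Note (illustrative, not from the commit): the extra ctx.vs_name guard matters because an empty scope name is a degenerate prefix. In Python every string starts with '', so filtering on an empty prefix would never exclude anything. A plain-Python sketch of the prefix check, with hypothetical names:

    params = ["tower0/conv0/W:0", "tower1/conv0/W:0"]

    def in_tower(name, vs_name):
        # mirrors the guard added above
        return (not vs_name) or name.startswith(vs_name)

    print([p for p in params if in_tower(p, "tower1")])   # ['tower1/conv0/W:0']
    print([p for p in params if in_tower(p, "")])         # both names: '' never filters anything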
tensorpack/tfutils/tower.py

@@ -17,13 +17,16 @@ class TowerContext(object):
     def __init__(self, tower_name,
                  device=None, is_training=None,
-                 var_strategy='shared'):
+                 var_strategy='shared', vs_name=None):
         """
         Args:
             tower_name (str): 'tower0', 'towerp0', or ''
             device (str or device function): the device to use. Defaults to either cpu0 or gpu0.
             is_training (bool): if None, automatically determine from tower_name.
             var_strategy (str): either 'shared' or 'replicated'.
+            vs_name (str): the variable scope name to open. Only valid in
+                'replicated' mode. Defaults to be tower_name.
         """
         self._name = tower_name
         if device is None:
@@ -38,6 +41,11 @@ class TowerContext(object):
         self._var_strategy = var_strategy
         if self._var_strategy == 'replicated':
             assert self._name
+            if vs_name is None:
+                self._vs_name = self._name
+        else:
+            assert vs_name is None, "vs_name is only valid in 'replicated' mode!"
+            self._vs_name = ''
 
     @property
     def is_main_training_tower(self):
@@ -62,12 +70,7 @@ class TowerContext(object):
     # variable_scope name
     @property
     def vs_name(self):
-        # do not open new variable scope for the main tower,
-        # just use '', so that Saver & PredictTower know what to do
-        if self.index > 0:
-            return self._name
+        if self.has_own_variables:
+            return self._vs_name
         return ""
 
     @property
     def index(self):
@@ -113,13 +116,16 @@ class TowerContext(object):
         self._ctxs = []
         if len(self._name):
             if self.has_own_variables:
-                if self.vs_name:
+                if len(self.vs_name):
                     self._ctxs.append(tf.variable_scope(self.vs_name))
             else:
-                # use existing variable scope
-                reuse = self.index > 0 or (not self.is_training)
-                self._ctxs.append(tf.variable_scope(
-                    tf.get_variable_scope(), reuse=reuse))
+                if self.is_training:
+                    reuse = self.index > 0
+                    if reuse is True:
+                        self._ctxs.append(tf.name_scope(None))
+                        self._ctxs.append(tf.variable_scope(
+                            tf.get_variable_scope(), reuse=True))
+                # if not training, should handle vs outside (TODO not good)
             self._ctxs.append(tf.name_scope(self._name))
         self._ctxs.append(tf.device(self._device))
         for c in self._ctxs:
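Note (sketch, not tensorpack code): the rewritten __enter__ distinguishes two cases. A tower that owns variables ('replicated' mode) opens a fresh variable scope named vs_name, while a shared-mode tower re-enters the current variable scope with reuse=True so it resolves tower0's variables. The two TF 1.x mechanisms in isolation, with illustrative names:

    import tensorflow as tf

    # replicated style: each tower keeps its own copy under its scope name
    with tf.variable_scope("tower1"):
        w_replica = tf.get_variable("W", shape=[3])      # variable name: tower1/W

    # shared style: re-enter the current (root) scope with reuse=True, so a
    # later tower resolves the variables that tower0 already created
    w0 = tf.get_variable("W", shape=[3])                 # created by "tower0" in the root scope
    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        with tf.name_scope("tower1"):                    # ops get a tower prefix, variables do not
            w_shared = tf.get_variable("W", shape=[3])
    assert w_shared is w0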
tensorpack/train/base.py

@@ -19,6 +19,7 @@ from ..callbacks.monitor import Monitors, TrainingMonitor
 from ..tfutils import get_global_step_value
 from ..tfutils.model_utils import describe_model
 from ..tfutils.sesscreate import ReuseSessionCreator
+from ..tfutils.sessinit import JustCurrentSession
 
 __all__ = ['Trainer', 'StopTraining']
@@ -44,6 +45,7 @@ class Trainer(object):
         local_step (int): the number of steps that have finished in the current epoch.
         global_step (int): the number of steps that have finished.
     """
+    # step attr only available after before_train?
 
     is_chief = True
@@ -124,11 +126,19 @@ class Trainer(object):
         self._callbacks = Callbacks(self._callbacks)
         self._callbacks.setup_graph(weakref.proxy(self))
 
+        if self.is_chief:
+            self.config.session_init._setup_graph()
+
+        # This might finalize the graph (in distributed)
         logger.info("Creating the session ...")
         self._create_session()
-        logger.info("Initializing the session ...")
-        self.config.session_init.init(self.sess)
+
+        if self.is_chief:
+            logger.info("Initializing the session ...")
+            self.config.session_init._run_init(self.sess)
+        else:
+            assert isinstance(self.config.session_init, JustCurrentSession), \
+                "session_init is only valid for chief worker session!"
 
         self.sess.graph.finalize()
         logger.info("Graph Finalized.")
@@ -164,6 +174,8 @@ class Trainer(object):
         self._starting_step = get_global_step_value()
         try:
             self._callbacks.before_train()
+            # refresh global step (might have changed by callbacks) TODO ugly
+            self._starting_step = get_global_step_value()
             for self.epoch_num in range(
                     self.config.starting_epoch, self.config.max_epoch + 1):
                 logger.info("Start Epoch {} ...".format(self.epoch_num))
@@ -190,6 +202,13 @@ class Trainer(object):
             self._callbacks.after_train()
             self.hooked_sess.close()
 
+    @property
+    def vs_name_for_predictor(self):
+        """
+        The variable scope name a predictor should be built in.
+        """
+        return ""
+
     # Predictor related methods: TODO
     def get_predictor(self, input_names, output_names, tower=0):
         """
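Note (simplified sketch, assuming the two-phase intent shown in the diff): session_init is split into a graph phase (_setup_graph, run while ops can still be added) and a run phase (_run_init, run once the session exists), because in the distributed case the graph may be finalized before the session is created. The class below is an illustrative stand-in, not tensorpack's SessionInit:

    import tensorflow as tf

    class SessionInitSketch(object):
        """Stand-in for a session_init object with the two-phase API from the diff."""

        def _setup_graph(self):
            # build init ops while the graph is still mutable
            self._init_op = tf.global_variables_initializer()

        def _run_init(self, sess):
            # run them later; the graph may already be finalized by now
            sess.run(self._init_op)

    v = tf.get_variable("v", shape=[], initializer=tf.ones_initializer())
    init = SessionInitSketch()
    init._setup_graph()
    with tf.Session() as sess:
        sess.graph.finalize()        # adding ops after this point would raise
        init._run_init(sess)
        print(sess.run(v))           # 1.0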
tensorpack/train/distributed.py

@@ -14,18 +14,15 @@ from ..tfutils.common import get_global_step_var, get_op_tensor_name
 __all__ = ['DistributedReplicatedTrainer']
 
-# Note that only trainable vars are shadowed
-PS_SHADOW_VAR_PREFIX = 'ps_var'
+# TODO only trainable model vars are saved
 
-class OverrideToLocalVariableIfNotPsVar(object):
+class OverrideToLocalVariable(object):
     """
     Ensures the created variable
     is in LOCAL_VARIABLES and not GLOBAL_VARIBLES collection.
     """
     def __call__(self, getter, name, *args, **kwargs):
-        if name.startswith(PS_SHADOW_VAR_PREFIX):
-            return getter(*args, **kwargs)
         if 'collections' in kwargs:
             collections = kwargs['collections']
             if not collections:
@@ -103,7 +100,8 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
         """
         ps_var_grads = []
         for grad, var in avg_grads:
-            my_name = PS_SHADOW_VAR_PREFIX + '/' + var.name
+            assert var.name.startswith('tower'), var.name
+            my_name = '/'.join(var.name.split('/')[1:])
             my_name = get_op_tensor_name(my_name)[0]
             new_v = tf.get_variable(my_name, dtype=var.dtype.base_dtype,
                                     initializer=var.initial_value,
@@ -141,26 +139,29 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
             logger.info("Running ps {}".format(self.task_index))
             self.server.join()
             return  # TODO exit and skip mainloop how?
 
-        super(DistributedReplicatedTrainer, self)._setup()
         with tf.device(self.param_server_device):
-            get_global_step_var()
+            gs = get_global_step_var()
+            assert gs.device, gs.device
             self.model.get_optimizer()  # TODO in global scope, not local
+
+        # do this before super.setup because input_source my need global step
+        super(DistributedReplicatedTrainer, self)._setup()
 
         with tf.variable_scope(
                 tf.get_variable_scope(),
-                custom_getter=OverrideToLocalVariableIfNotPsVar()):
+                custom_getter=OverrideToLocalVariable()):
             # Ngpu * Nvar * 2
             grad_list = MultiGPUTrainerBase.build_on_multi_tower(
                 self.config.tower,
                 lambda: self._get_cost_and_grad()[1],
                 devices=self.raw_devices,
-                var_strategy='replicated')
+                var_strategy='replicated',
+                vs_names=None)  # use the default vs names
 
         avg_grads = DistributedReplicatedTrainer._average_grads(grad_list, self.raw_devices)
         with tf.device(self.param_server_device):
             ps_var_grads = DistributedReplicatedTrainer._apply_shadow_vars(avg_grads)
             var_update_ops = self._apply_gradients_and_copy(grad_list, ps_var_grads)
+            self._shadow_vars = [v for (_, v) in ps_var_grads]
             main_fetch = tf.group(*var_update_ops, name='main_fetches')
             self.train_op = self.add_sync_queues_and_barrier(
@@ -180,7 +181,7 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
                 "Cannot set session_creator or session_config for distributed training! "
                 "To use a custom session config, pass it to the tf.train.Server constructor.")
 
-        # TODO use scaffold
+        # TODO use scaffold + monitored session
         class SupervisedSessionCreator(tf.train.SessionCreator):
             def __init__(self, is_chief, target):
                 self.is_chief = is_chief
@@ -239,18 +240,17 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
         local_vars = tf.local_variables()
         local_var_by_name = dict([(strip_port(v.name), v) for v in local_vars])
         post_init_ops = []
-        for v in tf.global_variables():
-            if v.name.startswith(PS_SHADOW_VAR_PREFIX + '/'):
-                prefix = strip_port(
-                    v.name[len(PS_SHADOW_VAR_PREFIX + '/'):])
-                for i in range(self.nr_gpu):
-                    if i == 0:
-                        name = prefix  # no prefix for tower0
-                    else:
-                        name = 'tower%s/%s' % (i, prefix)
-                    if name in local_var_by_name:
-                        copy_to = local_var_by_name[name]
-                        post_init_ops.append(copy_to.assign(v.read_value()))
-                    else:
-                        logger.warn("Global varable {} doesn't match a corresponding local var".format(v.name))
+        for v in self._shadow_vars:
+            vname = strip_port(v.name)
+            for i in range(self.nr_gpu):
+                name = 'tower%s/%s' % (i, vname)
+                if name in local_var_by_name:
+                    copy_to = local_var_by_name[name]
+                    post_init_ops.append(copy_to.assign(v.read_value()))
+                else:
+                    logger.warn("Global variable {} doesn't match a corresponding local var".format(v.name))
         return tf.group(*post_init_ops, name='sync_variables_from_ps')
+
+    @property
+    def vs_name_for_predictor(self):
+        return "tower0"
tensorpack/train/feedfree.py

@@ -71,7 +71,7 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
         ctx = get_current_tower_context()
         if ctx is not None and ctx.has_own_variables and ctx.vs_name:
             # only optimize w.r.t vars in this tower
-            # TODO assumption on the first-tower empty variable scope
+            # TODO use ctx.vars?
             varlist = [v for v in varlist if v.op.name.startswith(ctx.vs_name + '/')]
         grads = opt.compute_gradients(
             cost,
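Note (illustrative): the filtering above restricts the optimizer to the tower's own variables by scope prefix before compute_gradients. A standalone sketch with made-up variable names:

    import tensorflow as tf

    with tf.variable_scope('tower1'):
        w = tf.get_variable('w', shape=[], initializer=tf.ones_initializer())
    b = tf.get_variable('b', shape=[], initializer=tf.zeros_initializer())   # not in the tower

    cost = w * 2.0 + b
    opt = tf.train.GradientDescentOptimizer(0.1)

    # only optimize w.r.t. variables whose op name starts with this tower's vs_name
    varlist = [v for v in tf.trainable_variables() if v.op.name.startswith('tower1' + '/')]
    grads = opt.compute_gradients(cost, var_list=varlist)    # gradient for w only, not b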
tensorpack/train/multigpu.py

@@ -49,13 +49,17 @@ def apply_prefetch_policy(config, use_stage=True):
 class MultiGPUTrainerBase(Trainer):
     """ Base class for multi-gpu training"""
     @staticmethod
-    def build_on_multi_tower(towers, func, devices=None, var_strategy='shared'):
+    def build_on_multi_tower(towers, func, devices=None, var_strategy='shared', vs_names=None):
         """
         Args:
             towers: list of gpu relative ids
             func: a lambda to be called inside each tower
             devices: a list of devices to be used. By default will use GPUs in towers.
-            var_strategy (str):
+            var_strategy (str): 'shared' or 'replicated'
+            vs_names (list[str]): list of variable scope names to use.
 
         Returns:
             List of outputs of ``func``, evaluated on each tower.
@@ -72,13 +76,18 @@ class MultiGPUTrainerBase(Trainer):
         if var_strategy == 'replicated':  # TODO ugly
             logger.info("In replicated mode, UPDATE_OPS from all GPUs will be run.")
             keys_to_freeze.remove(tf.GraphKeys.UPDATE_OPS)
+        else:
+            assert vs_names is None
+        if vs_names is None:
+            vs_names = [None] * len(towers)
+
         for idx, t in enumerate(towers):
             device = devices[idx] if devices is not None else '/gpu:{}'.format(t)
             with TowerContext(
                     'tower{}'.format(idx),
                     device=device, is_training=True,
-                    var_strategy=var_strategy):
+                    var_strategy=var_strategy,
+                    vs_name=vs_names[idx]):
                 if idx == t:
                     logger.info("Building graph for training tower {}...".format(idx))
                 else:
@@ -248,7 +257,9 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
         grad_list = MultiGPUTrainerBase.build_on_multi_tower(
             self.config.tower,
             lambda: self._get_cost_and_grad()[1],
-            var_strategy='replicated')
+            var_strategy='replicated',
+            # use no variable scope for the first tower
+            vs_names=[''] + [None] * (self.config.nr_tower - 1))
         grads = self._allreduce_grads(grad_list)
         train_ops = []
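Note (plain-Python sketch of the new vs_names argument): passing vs_names=[''] + [None] * (nr_tower - 1) asks tower 0 to build in the root variable scope (so checkpoints keep un-prefixed variable names) while the remaining towers fall back to the default 'towerN' scope:

    nr_tower = 4
    vs_names = [''] + [None] * (nr_tower - 1)      # '' = root scope for tower 0

    def resolve(idx, name):
        # mirrors TowerContext's default: vs_name falls back to the tower name
        return 'tower{}'.format(idx) if name is None else name

    print([resolve(i, n) for i, n in enumerate(vs_names)])   # ['', 'tower1', 'tower2', 'tower3']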
tensorpack/train/predict.py

@@ -3,6 +3,7 @@
 # File: predict.py
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
 
+import tensorflow as tf
 from ..predict import (OnlinePredictor,
                        PredictorTowerBuilder)
@@ -19,6 +20,7 @@ class PredictorFactory(object):
         """
         self.model = trainer.model
         self.towers = trainer.config.predict_tower
+        self.vs_name = trainer.vs_name_for_predictor
 
         def fn(_):
             self.model.build_graph(self.model.get_reused_placehdrs())
@@ -34,7 +36,8 @@ class PredictorFactory(object):
         """
         tower = self.towers[tower]
         # just ensure the tower exists. won't rebuild (memoized)
-        self._tower_builder.build(tower)
+        with tf.variable_scope(self.vs_name, reuse=True):
+            self._tower_builder.build(tower)
         placeholder_names = set([k.name for k in self.model.get_inputs_desc()])
         get_tensor_fn = PredictorTowerBuilder.get_tensors_maybe_in_tower
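Note (illustrative): building the predict tower under tf.variable_scope(vs_name, reuse=True) also gives an early failure if the scope name is wrong. With reuse=True, tf.get_variable raises ValueError for a variable that does not already exist, instead of silently creating a second, untrained copy:

    import tensorflow as tf

    with tf.variable_scope("tower0"):
        tf.get_variable("W", shape=[4, 4])               # training variable

    with tf.variable_scope("tower0", reuse=True):        # correct scope: variable is found
        tf.get_variable("W", shape=[4, 4])

    try:
        with tf.variable_scope("towerp0", reuse=True):   # wrong scope
            tf.get_variable("W", shape=[4, 4])
    except ValueError as e:
        print("caught:", e)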