Shashank Suhas / seminar-breakout / Commits

Commit f0573ed2, authored Jun 04, 2017 by Yuxin Wu

    Merge branch 'distributed' (#144)

parents a3674b47 930481f2
Showing 18 changed files with 467 additions and 58 deletions (+467 / -58):

    tensorpack/callbacks/base.py               +15    -0
    tensorpack/callbacks/graph.py              +14    -1
    tensorpack/callbacks/inference_runner.py    +2    -1
    tensorpack/callbacks/param.py               +1    -1
    tensorpack/callbacks/saver.py               +1    -0
    tensorpack/callbacks/steps.py               +7    -4
    tensorpack/models/common.py                 +1    -1
    tensorpack/models/regularize.py             +1    -1
    tensorpack/tfutils/common.py                +3    -1
    tensorpack/tfutils/summary.py               +5    -3
    tensorpack/tfutils/tower.py                +20   -12
    tensorpack/tfutils/varmanip.py              +3    -2
    tensorpack/train/base.py                   +41   -15
    tensorpack/train/distributed.py           +321    -0   (new file)
    tensorpack/train/feedfree.py                +6    -7
    tensorpack/train/input_source.py            +4    -1
    tensorpack/train/multigpu.py               +18    -7
    tensorpack/train/predict.py                 +4    -1
tensorpack/callbacks/base.py

@@ -36,6 +36,8 @@ class Callback(object):
     .. automethod:: _after_train
     """

+    _chief_only = True
+
     def setup_graph(self, trainer):
         self._steps_per_epoch = trainer.config.steps_per_epoch
         self.trainer = trainer

@@ -162,6 +164,19 @@ class Callback(object):
     def local_step(self):
         return self.trainer.local_step

+    @property
+    def chief_only(self):
+        """
+        Only run this callback on chief training process.
+
+        Returns: bool
+        """
+        return self._chief_only
+
+    @chief_only.setter
+    def chief_only(self, v):
+        self._chief_only = v
+
     def __str__(self):
         return type(self).__name__
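The new `_chief_only` flag is what the trainer (see tensorpack/train/base.py later in this diff) consults when deciding whether to skip a callback on non-chief workers. A minimal sketch of how a callback could opt out of chief-only behaviour; `EveryWorkerLogger` is a hypothetical example, not part of this commit, and the import path is assumed from the file layout shown above:

    # Sketch only: a callback that should run on every worker, not just the chief.
    from tensorpack.callbacks.base import Callback

    class EveryWorkerLogger(Callback):
        _chief_only = False        # class-level default flips the new flag

        def _trigger(self):
            print("triggered on this worker")

    cb = EveryWorkerLogger()
    cb.chief_only = True           # the new property/setter can also flip it per instance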
tensorpack/callbacks/graph.py

@@ -17,13 +17,15 @@ class RunOp(Callback):
     """ Run an Op. """

-    def __init__(self, setup_func, run_before=True, run_as_trigger=True, run_step=False):
+    def __init__(self, setup_func,
+                 run_before=True, run_as_trigger=True, run_step=False, verbose=False):
         """
         Args:
             setup_func: a function that returns the Op in the graph
             run_before (bool): run the Op before training
             run_as_trigger (bool): run the Op on every trigger
             run_step (bool): run the Op every step (along with training)
+            verbose (bool): pring logs when the op is run.

         Examples:
             The `DQN Example ...

@@ -34,27 +36,38 @@ class RunOp(Callback):
         self.run_before = run_before
         self.run_as_trigger = run_as_trigger
         self.run_step = run_step
+        self.verbose = verbose

     def _setup_graph(self):
         self._op = self.setup_func()

     def _before_train(self):
         if self.run_before:
+            self._print()
             self._op.run()

     def _trigger(self):
         if self.run_as_trigger:
+            self._print()
             self._op.run()

     def _before_run(self, _):
         if self.run_step:
+            self._print()
             return [self._op]

+    def _print(self):
+        if self.verbose:
+            logger.info("Running Op {} ...".format(self._op.name))
+

 class RunUpdateOps(RunOp):
     """
     Run ops from the collection UPDATE_OPS every step
     """

+    _chief_only = False
+
     def __init__(self, collection=tf.GraphKeys.UPDATE_OPS):
         def f():
             ops = tf.get_collection(collection)
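As a usage note, the new `verbose` flag only adds a logger line each time the Op runs; the replicated trainer uses it later in this diff to make its variable-sync op visible in the logs. A minimal sketch, with the import path assumed from the file layout and a throwaway counter variable:

    # Sketch only: RunOp with the new verbose=True argument.
    import tensorflow as tf
    from tensorpack.callbacks.graph import RunOp

    counter = tf.get_variable('my_counter', initializer=0, trainable=False)
    cb = RunOp(lambda: tf.assign_add(counter, 1),
               run_before=True, run_as_trigger=True,
               verbose=True)    # logs "Running Op ..." whenever the op is run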
tensorpack/callbacks/inference_runner.py

@@ -90,6 +90,7 @@ class InferenceRunnerBase(Callback):
         def fn(_):
             in_tensors = self._input_source.get_input_tensors()
             self.trainer.model.build_graph(in_tensors)
-        PredictorTowerBuilder(fn, self._prefix).build(self._predict_tower_id)
+        with tf.variable_scope(self.trainer.vs_name_for_predictor, reuse=True):
+            PredictorTowerBuilder(fn, self._prefix).build(self._predict_tower_id)

         self._hooks = [self._build_hook(inf) for inf in self.infs]
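The point of wrapping the predictor tower in tf.variable_scope(..., reuse=True) is that the inference graph should reuse the variables the trainer already created under its predictor scope, rather than creating fresh ones. The underlying TF 1.x reuse pattern, as a generic sketch with illustrative names:

    # Generic variable-scope reuse: the second build pass returns the same variable.
    import tensorflow as tf

    with tf.variable_scope('tower0'):
        w = tf.get_variable('w', shape=[3])

    with tf.variable_scope('tower0', reuse=True):
        w_again = tf.get_variable('w', shape=[3])

    assert w is w_again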
tensorpack/callbacks/param.py

@@ -72,7 +72,7 @@ class GraphVarParam(HyperParam):
                 self.var = v
                 break
         else:
-            raise ValueError("{} is not a VARIABLE in the graph!".format(self.var_name))
+            raise ValueError("{} is not a GLOBAL_VARIABLE in the graph!".format(self.var_name))

     def set_value(self, v):
         """ Assign the variable a new value. """
tensorpack/callbacks/saver.py

@@ -43,6 +43,7 @@ class ModelSaver(Callback):
         vars = []
         for key in self.var_collections:
             vars.extend(tf.get_collection(key))
+        vars = list(set(vars))
         self.path = os.path.join(self.checkpoint_dir, 'model')
         if get_tf_version_number() <= 1.1:
             self.saver = tf.train.Saver(
tensorpack/callbacks/steps.py

@@ -55,13 +55,14 @@ class MaintainStepCounter(Callback):
         # ensure it exists
         gs_var = get_global_step_var()
-        with tf.name_scope(None):
-            self.gs_incr_var = tf.assign_add(
+        with tf.device(gs_var.device):
+            self.gs_incr_op = tf.assign_add(
                 gs_var, 1,
-                name=GLOBAL_STEP_INCR_OP_NAME)
+                name=GLOBAL_STEP_INCR_OP_NAME).op
             # tf.mod(
             #     self.gs_incr_var, self.trainer.config.steps_per_epoch,
             #     name=LOCAL_STEP_OP_NAME)
-        self._fetches = tf.train.SessionRunArgs(self.gs_incr_var)
+        self._fetches = tf.train.SessionRunArgs(self.gs_incr_op)

     def _before_train(self):
         gs_val = get_global_step_value()

@@ -81,6 +82,8 @@ class MaintainStepCounter(Callback):

 class ProgressBar(Callback):
     """ A progress bar based on tqdm. Enabled by default. """

+    _chief_only = False
+
     def __init__(self, names=[]):
         """
         Args:
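Placing the increment on gs_var.device matters once the global step is hosted on a parameter server: the assign_add then executes where the variable lives instead of on each worker's default device, and the callback now fetches the resulting op rather than the assigned tensor. A generic TF 1.x sketch of the colocation idea, using a plain variable in place of tensorpack's global-step helper:

    # Generic sketch: colocate a counter update with the variable it updates.
    import tensorflow as tf

    gs_var = tf.get_variable('global_step', initializer=0, trainable=False)
    with tf.device(gs_var.device):               # run the update where the variable lives
        gs_incr_op = tf.assign_add(gs_var, 1, name='global_step_incr').op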
tensorpack/models/common.py

@@ -136,7 +136,7 @@ def layer_register(
             # log shape info and add activation
             logger.info("{} output: {}".format(
                 scope.name, get_shape_str(outputs)))
-            _LAYER_LOGGED.add(scope.name)
+            _LAYER_LOGGED.add(scope_name)
         else:
             # run the actual function
             outputs = func(*args, **actual_args)
tensorpack/models/regularize.py

@@ -47,7 +47,7 @@ def regularize_cost(regex, func, name='regularize_cost'):
     for p in params:
         para_name = p.name
         # in replicated mode, only regularize variables inside this tower
-        if ctx.has_own_variables and (not para_name.startswith(ctx.vs_name)):
+        if ctx.has_own_variables and ctx.vs_name and (not para_name.startswith(ctx.vs_name)):
             continue
         if re.search(regex, para_name):
             costs.append(func(p))
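The extra ctx.vs_name guard makes the first replicated tower, whose variable scope name is the empty string, take an explicit short-circuit instead of relying on the prefix test, since every name starts with the empty string:

    # Every string starts with '', so the prefix test alone can never
    # exclude a variable when vs_name == '':
    assert "tower1/conv0/W".startswith("")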
tensorpack/tfutils/common.py

@@ -39,9 +39,11 @@ def get_default_sess_config(mem_fraction=0.99):
     conf.inter_op_parallelism_threads = 0

     conf.gpu_options.per_process_gpu_memory_fraction = mem_fraction
+    if get_tf_version_number() >= 1.2:
+        conf.gpu_options.force_gpu_compatible = True
+
     conf.gpu_options.allocator_type = 'BFC'
     conf.gpu_options.allow_growth = True
-    # force gpu compatible?

     conf.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
     return conf
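For reference, force_gpu_compatible makes TensorFlow allocate CPU tensors in CUDA pinned memory to speed up host-to-device copies, and the option only exists in TF >= 1.2, hence the version guard. Typical use of the helper, as a sketch:

    # Sketch: build a session from tensorpack's default config,
    # capping the per-process GPU memory fraction at 50%.
    import tensorflow as tf
    from tensorpack.tfutils.common import get_default_sess_config

    sess = tf.Session(config=get_default_sess_config(mem_fraction=0.5))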
tensorpack/tfutils/summary.py

@@ -154,11 +154,13 @@ def add_moving_summary(v, *args, **kwargs):
     for x in v:
         assert isinstance(x, tf.Tensor), x
         assert x.get_shape().ndims == 0, x.get_shape()
-    # TODO will produce tower0/xxx?
+    # TODO will produce variable tower0/xxx?
+    # TODO not saved under distributed
     # TODO use zero_debias
-    with tf.name_scope(None):
+    gs = get_global_step_var()
+    with tf.name_scope(None), tf.device(gs.device):
         averager = tf.train.ExponentialMovingAverage(
-            decay, num_updates=get_global_step_var(), name='EMA')
+            decay, num_updates=gs, name='EMA')
         avg_maintain_op = averager.apply(v)
     for c in v:
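Fetching the global-step variable once and pinning the averager to its device keeps the EMA state next to the step counter, which is what a parameter-server setup expects. The underlying TF primitive, as a generic sketch rather than the tensorpack wrapper, with a plain variable standing in for the global step:

    # Generic TF 1.x moving average: num_updates makes the effective decay
    # ramp up from a small value early in training.
    import tensorflow as tf

    gs = tf.get_variable('global_step', initializer=0, trainable=False)
    loss = tf.get_variable('loss_stat', initializer=0.0, trainable=False)

    with tf.device(gs.device):
        averager = tf.train.ExponentialMovingAverage(0.95, num_updates=gs, name='EMA')
        maintain_op = averager.apply([loss])    # creates and updates the shadow variable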
tensorpack/tfutils/tower.py

@@ -17,13 +17,16 @@ class TowerContext(object):
     def __init__(self, tower_name, device=None, is_training=None,
-                 var_strategy='shared'):
+                 var_strategy='shared',
+                 vs_name=None):
         """
         Args:
             tower_name (str): 'tower0', 'towerp0', or ''
             device (str or device function): the device to use. Defaults to either cpu0 or gpu0.
             is_training (bool): if None, automatically determine from tower_name.
             var_strategy (str): either 'shared' or 'replicated'.
+            vs_name (str): the variable scope name to open. Only valid in
+                'replicated' mode. Defaults to be tower_name.
         """
         self._name = tower_name
         if device is None:

@@ -38,6 +41,13 @@ class TowerContext(object):
         self._var_strategy = var_strategy
         if self._var_strategy == 'replicated':
             assert self._name
+            if vs_name is None:
+                self._vs_name = self._name
+            else:
+                self._vs_name = vs_name
+        else:
+            assert vs_name is None, "vs_name is only valid in 'replicated' mode!"
+            self._vs_name = ''

     @property
     def is_main_training_tower(self):

@@ -62,12 +72,7 @@ class TowerContext(object):
     # variable_scope name
     @property
     def vs_name(self):
-        if self.has_own_variables:
-            # do not open new variable scope for the main tower,
-            # just use '', so that Saver & PredictTower know what to do
-            if self.index > 0:
-                return self._name
-        return ""
+        return self._vs_name

     @property
     def index(self):

@@ -113,13 +118,16 @@ class TowerContext(object):
         self._ctxs = []
         if len(self._name):
             if self.has_own_variables:
-                if self.vs_name:
+                if len(self.vs_name):
                     self._ctxs.append(tf.variable_scope(self.vs_name))
             else:
-                # use existing variable scope
-                reuse = self.index > 0 or (not self.is_training)
-                self._ctxs.append(tf.variable_scope(
-                    tf.get_variable_scope(), reuse=reuse))
+                if self.is_training:
+                    reuse = self.index > 0
+                    if reuse is True:
+                        self._ctxs.append(tf.name_scope(None))
+                        self._ctxs.append(tf.variable_scope(
+                            tf.get_variable_scope(), reuse=True))
+                # if not training, should handle vs outside (TODO not good)
             self._ctxs.append(tf.name_scope(self._name))
         self._ctxs.append(tf.device(self._device))
         for c in self._ctxs:
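A sketch of how the new vs_name argument composes with var_strategy='replicated'; scope names and devices here are illustrative, and the pattern mirrors the vs_names list that multigpu.py (below) passes for the replicated trainer:

    # Sketch only: tower0 keeps its variables under the root scope (''),
    # tower1 gets its own copy under 'tower1/'.
    import tensorflow as tf
    from tensorpack.tfutils.tower import TowerContext

    def build():
        return tf.get_variable('W', shape=[3, 3])

    with TowerContext('tower0', device='/gpu:0', is_training=True,
                      var_strategy='replicated', vs_name=''):
        w0 = build()      # variable name: "W"

    with TowerContext('tower1', device='/gpu:1', is_training=True,
                      var_strategy='replicated'):    # vs_name defaults to 'tower1'
        w1 = build()      # variable name: "tower1/W"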
tensorpack/tfutils/varmanip.py

@@ -160,7 +160,7 @@ def get_checkpoint_path(model_path):
         new_path = model_path.split('.index')[0]
     if new_path != model_path:
         logger.warn(
-            "[SaverRestore] {} is corrected to {} when restoring the model.".format(model_path, new_path))
+            "Checkpoint path {} is auto-corrected to {}.".format(model_path, new_path))
         model_path = new_path
     assert os.path.isfile(model_path) or os.path.isfile(model_path + '.index'), model_path
     return model_path

@@ -183,7 +183,8 @@ def dump_chkpt_vars(model_path):
 def is_training_name(name):
     """
-    This is a hack temporarily used to improve logging. Do not use this function.
+    Guess if a name belongs to a training-only variables.
+    Only used internally to avoid too many logging. Do not use it.

     Returns:
         bool: Guess whether this tensor is something only used in training.
tensorpack/train/base.py

@@ -9,8 +9,6 @@ import six
 from six.moves import range

 import tensorflow as tf
-from tensorflow.python.training.monitored_session \
-    import _HookedSession as HookedSession

 from .predict import PredictorFactory
 from .config import TrainConfig

@@ -21,6 +19,7 @@ from ..callbacks.monitor import Monitors, TrainingMonitor
 from ..tfutils import get_global_step_value
 from ..tfutils.model_utils import describe_model
 from ..tfutils.sesscreate import ReuseSessionCreator
+from ..tfutils.sessinit import JustCurrentSession

 __all__ = ['Trainer', 'StopTraining']

@@ -46,6 +45,9 @@ class Trainer(object):
         local_step (int): the number of steps that have finished in the current epoch.
         global_step (int): the number of steps that have finished.
     """
     # step attr only available after before_train?

+    is_chief = True
+
     def __init__(self, config):
         """

@@ -79,12 +81,18 @@ class Trainer(object):
         assert isinstance(cb, Callback), cb
         assert not isinstance(self._callbacks, Callbacks), \
             "Cannot register more callbacks after trainer was setup!"
-        self._callbacks.append(cb)
+        if not self.is_chief and cb.chief_only:
+            logger.warn("Callback {} is chief-only, skipped.".format(str(cb)))
+        else:
+            self._callbacks.append(cb)

     def register_monitor(self, mon):
         assert isinstance(mon, TrainingMonitor), mon
         assert not isinstance(self.monitors, Monitors), \
             "Cannot register more monitors after trainer was setup!"
-        self.monitors.append(mon)
-        self.register_callback(mon)
+        if not self.is_chief and mon.chief_only:
+            logger.warn("Callback {} is chief-only, skipped.".format(str(mon)))
+        else:
+            self.monitors.append(mon)
+            self.register_callback(mon)

@@ -110,6 +118,7 @@ class Trainer(object):
         self.monitors = Monitors(self.monitors)
         self.register_callback(self.monitors)

+        # TODO cache per graph, avoid describing all towers
         describe_model()

         # some final operations that might modify the graph

@@ -117,21 +126,28 @@ class Trainer(object):
         self._callbacks = Callbacks(self._callbacks)
         self._callbacks.setup_graph(weakref.proxy(self))

         # create session
-        logger.info("Creating the session ...")
-        self.sess = self.config.session_creator.create_session()
-        self._monitored_sess = tf.train.MonitoredSession(
-            session_creator=ReuseSessionCreator(self.sess), hooks=None)
-        logger.info("Initializing the session ...")
-        # init session
-        self.config.session_init.init(self.sess)
+        self._create_session()
+
+        if self.is_chief:
+            logger.info("Initializing the session ...")
+            # init session
+            self.config.session_init.init(self.sess)
+        else:
+            assert isinstance(self.config.session_init, JustCurrentSession), \
+                "session_init is only valid for chief worker session!"

         self.sess.graph.finalize()
         logger.info("Graph Finalized.")

     def _create_session(self):
         """
         Setup self.sess (the raw tf.Session)
         and self.hooked_sess (the session with hooks and coordinator)
         """
         hooks = self._callbacks.get_hooks()
-        self.hooked_sess = HookedSession(self.sess, hooks)
+        self.sess = self.config.session_creator.create_session()
+        self.hooked_sess = tf.train.MonitoredSession(
+            session_creator=ReuseSessionCreator(self.sess), hooks=hooks)

     @abstractmethod
     def _setup(self):

@@ -154,12 +170,14 @@ class Trainer(object):
         self._starting_step = get_global_step_value()
         try:
             self._callbacks.before_train()
+            # refresh global step (might have changed by callbacks) TODO ugly
+            self._starting_step = get_global_step_value()
             for self.epoch_num in range(
                     self.config.starting_epoch, self.config.max_epoch + 1):
                 logger.info("Start Epoch {} ...".format(self.epoch_num))
                 start_time = time.time()
                 for self.local_step in range(self.config.steps_per_epoch):
-                    if self._monitored_sess.should_stop():
+                    if self.hooked_sess.should_stop():
                         return
                     self.run_step()  # implemented by subclass
                     self._callbacks.trigger_step()

@@ -169,6 +187,7 @@ class Trainer(object):
                 # trigger epoch outside the timing region.
                 self._trigger_epoch()
                 self._callbacks.trigger_epoch()
+            logger.info("Training has finished!")
         except (StopTraining, tf.errors.OutOfRangeError):
             logger.info("Training was stopped.")
         except KeyboardInterrupt:

@@ -177,7 +196,14 @@ class Trainer(object):
             raise
         finally:
             self._callbacks.after_train()
-            self._monitored_sess.close()
+            self.hooked_sess.close()

+    @property
+    def vs_name_for_predictor(self):
+        """
+        The variable scope name a predictor should be built in.
+        """
+        return ""
+
     # Predictor related methods: TODO
     def get_predictor(self, input_names, output_names, tower=0):
tensorpack/train/distributed.py   (new file, 0 → 100644, +321 lines)

    This diff is collapsed in the page view and its contents are not shown here.
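Since the 321 new lines are collapsed, only the file's existence is visible on this page. For orientation, distributed TensorFlow training of this era is organised around a cluster spec and one tf.train.Server per process, with one worker acting as the chief; the following is a generic TF 1.x sketch with placeholder host:port values, not code taken from the new file:

    # Generic between-graph replication boilerplate (not from distributed.py).
    import tensorflow as tf

    cluster = tf.train.ClusterSpec({
        'ps':     ['host0:2222'],
        'worker': ['host1:2222', 'host2:2222'],
    })
    task_index = 0
    server = tf.train.Server(cluster, job_name='worker', task_index=task_index)
    is_chief = (task_index == 0)    # worker 0 initializes variables for everyone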
tensorpack/train/feedfree.py

@@ -4,6 +4,7 @@
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>

 import tensorflow as tf
+from six.moves import zip

 from ..tfutils.tower import TowerContext, get_current_tower_context
 from .input_source import QueueInput, FeedfreeInput

@@ -64,20 +65,18 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
         """ get the cost and gradient"""
         self.build_train_tower()
         cost = self.model.get_cost()    # assume single cost
+        # opt may be created under first-tower variable scope (which is '')
         opt = self.model.get_optimizer()
-        # GATE_NONE faster?
         varlist = tf.trainable_variables()
         ctx = get_current_tower_context()
         if ctx is not None and ctx.has_own_variables and ctx.vs_name:
             # only optimize w.r.t vars in this tower
-            # TODO assumption on the first-tower empty variable scope
+            # TODO use ctx.vars?
             varlist = [v for v in varlist if v.op.name.startswith(ctx.vs_name + '/')]
-        grads = opt.compute_gradients(
-            cost, var_list=varlist,
-            gate_gradients=tf.train.Optimizer.GATE_NONE,
-            colocate_gradients_with_ops=True)
+        grads = tf.gradients(
+            cost, varlist,
+            gate_gradients=False, colocate_gradients_with_ops=True)
+        grads = list(zip(grads, varlist))
         return cost, grads
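Mechanically, tf.gradients followed by zip produces the same list of (gradient, variable) pairs that opt.compute_gradients would return for the same variable list; the difference is that the optimizer is no longer asked to do it. A small equivalence sketch:

    # Sketch: both formulations below yield a list of (grad, var) pairs for x.
    import tensorflow as tf

    x = tf.get_variable('x', initializer=1.0)
    cost = tf.square(x)
    opt = tf.train.GradientDescentOptimizer(0.1)

    pairs_a = opt.compute_gradients(cost, var_list=[x])
    grads = tf.gradients(cost, [x], colocate_gradients_with_ops=True)
    pairs_b = list(zip(grads, [x]))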
tensorpack/train/input_source.py

@@ -241,7 +241,9 @@ class QueueInput(FeedfreeInput):
     def setup_training(self, trainer):
         super(QueueInput, self).setup_training(trainer)
-        trainer.register_callback(StartProcOrThread(self.thread))
+        cb = StartProcOrThread(self.thread)
+        cb.chief_only = False
+        trainer.register_callback(cb)

     def get_input_tensors(self):
         with tf.device('/cpu:0'):

@@ -365,6 +367,7 @@ class DummyConstantInput(TensorInput):
         def fn():
             tlist = []
             ctx = get_current_tower_context()
             assert ctx is not None
+            assert len(self.shapes) == len(self.input_placehdrs)
             for idx, p in enumerate(self.input_placehdrs):
                 tlist.append(tf.get_variable(
tensorpack/train/multigpu.py

@@ -49,13 +49,17 @@ def apply_prefetch_policy(config, use_stage=True):
 class MultiGPUTrainerBase(Trainer):
     """ Base class for multi-gpu training"""
     @staticmethod
-    def build_on_multi_tower(towers, func, devices=None, var_strategy='shared'):
+    def build_on_multi_tower(towers, func, devices=None,
+                             var_strategy='shared', vs_names=None):
         """
         Args:
             towers: list of gpu relative ids
             func: a lambda to be called inside each tower
             devices: a list of devices to be used. By default will use GPUs in towers.
-            var_strategy (str):
+            var_strategy (str): 'shared' or 'replicated'
+            vs_names (list[str]): list of variable scope names to use.

         Returns:
             List of outputs of ``func``, evaluated on each tower.

@@ -70,15 +74,20 @@ class MultiGPUTrainerBase(Trainer):
         keys_to_freeze = TOWER_FREEZE_KEYS[:]
         if var_strategy == 'replicated':
             # TODO ugly
-            logger.info("UPDATE_OPS from all GPUs will be kept in the collection.")
+            logger.info("In replicated mode, UPDATE_OPS from all GPUs will be run.")
             keys_to_freeze.remove(tf.GraphKeys.UPDATE_OPS)
+        else:
+            assert vs_names is None
+        if vs_names is None:
+            vs_names = [None] * len(towers)

         for idx, t in enumerate(towers):
             device = devices[idx] if devices is not None else '/gpu:{}'.format(t)
             with TowerContext(
                     'tower{}'.format(idx),
                     device=device, is_training=True,
-                    var_strategy=var_strategy):
+                    var_strategy=var_strategy,
+                    vs_name=vs_names[idx]):
                 if idx == t:
                     logger.info("Building graph for training tower {}...".format(idx))
                 else:

@@ -248,7 +257,9 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
         grad_list = MultiGPUTrainerBase.build_on_multi_tower(
             self.config.tower,
             lambda: self._get_cost_and_grad()[1],
-            var_strategy='replicated')
+            var_strategy='replicated',
+            # use no variable scope for the first tower
+            vs_names=[''] + [None] * (self.config.nr_tower - 1))
         grads = self._allreduce_grads(grad_list)

         train_ops = []

@@ -261,7 +272,7 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
         self.train_op = tf.group(*train_ops, name='train_op')
         self.register_callback(RunOp(
             SyncMultiGPUTrainerReplicated.get_post_init_ops,
-            run_before=True, run_as_trigger=True))
+            run_before=True, run_as_trigger=True, verbose=True))

 # Adopt from https://github.com/tensorflow/benchmarks/blob/master/scripts/tf_cnn_benchmarks/variable_mgr.py

@@ -279,7 +290,7 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
                 split_name = split_name[1:]
             copy_from = var_by_name['/'.join(split_name)]
             post_init_ops.append(v.assign(copy_from.read_value()))
-        return tf.group(*post_init_ops, name='init_sync_vars')
+        return tf.group(*post_init_ops, name='sync_variables_from_tower0')


 class AsyncMultiGPUTrainer(MultiGPUTrainerBase,
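A sketch of how the replicated trainer now drives the extended build_on_multi_tower; build_one_tower is a stand-in for the real per-tower graph builder, and calling the staticmethod outside a full trainer setup is only meant to illustrate the argument shapes:

    # Sketch only: first tower under the root scope, remaining towers under their own.
    import tensorflow as tf
    from tensorpack.train.multigpu import MultiGPUTrainerBase

    def build_one_tower():
        # stand-in for the per-tower model/gradient builder
        return tf.get_variable('W', shape=[3])

    towers = [0, 1]
    outputs = MultiGPUTrainerBase.build_on_multi_tower(
        towers, build_one_tower,
        var_strategy='replicated',
        vs_names=[''] + [None] * (len(towers) - 1))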
tensorpack/train/predict.py

@@ -3,6 +3,7 @@
 # File: predict.py
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>

+import tensorflow as tf

 from ..predict import (OnlinePredictor, PredictorTowerBuilder)

@@ -19,6 +20,7 @@ class PredictorFactory(object):
         """
         self.model = trainer.model
         self.towers = trainer.config.predict_tower
+        self.vs_name = trainer.vs_name_for_predictor

         def fn(_):
             self.model.build_graph(self.model.get_reused_placehdrs())

@@ -34,6 +36,7 @@ class PredictorFactory(object):
         """
         tower = self.towers[tower]
         # just ensure the tower exists. won't rebuild (memoized)
-        self._tower_builder.build(tower)
+        with tf.variable_scope(self.vs_name, reuse=True):
+            self._tower_builder.build(tower)

         placeholder_names = set([k.name for k in self.model.get_inputs_desc()])