Commit c04c1ef8 authored Jun 01, 2017 by Yuxin Wu
pass Server to trainer
parent b0677681
Showing 3 changed files with 17 additions and 20 deletions
tensorpack/train/base.py          +1   -0
tensorpack/train/distributed.py   +15  -20
tensorpack/train/input_source.py  +1   -0
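The net effect of this commit: the caller, not the trainer, now owns the tf.train.Server. Instead of passing job_name, task_index and a cluster to DistributedReplicatedTrainer (which then built its own Server inside _setup), the caller builds the Server and passes it in; the trainer recovers the ClusterSpec, job_name and task_index from server.server_def. A minimal caller-side sketch of the new usage; the host addresses and the TrainConfig object `config` are placeholders assumed to exist, not part of this commit:

import tensorflow as tf
from tensorpack.train.distributed import DistributedReplicatedTrainer

# Describe the cluster once and build the Server for this process.
cluster = tf.train.ClusterSpec({
    'ps':     ['host0:2222'],
    'worker': ['host1:2222', 'host2:2222'],
})
server = tf.train.Server(cluster, job_name='worker', task_index=0)

# `config` is a tensorpack TrainConfig (model, data, callbacks, ...) defined elsewhere;
# the trainer reads job_name, task_index and the cluster from server.server_def.
trainer = DistributedReplicatedTrainer(config, server)
trainer.train()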
tensorpack/train/base.py
@@ -177,6 +177,7 @@ class Trainer(object):
                     # trigger epoch outside the timing region.
                     self._trigger_epoch()
+                    self._callbacks.trigger_epoch()
                 logger.info("Training has finished!")
             except (StopTraining, tf.errors.OutOfRangeError):
                 logger.info("Training was stopped.")
             except KeyboardInterrupt:
tensorpack/train/distributed.py
@@ -49,11 +49,14 @@ class OverrideToLocalVariableIfNotPsVar(object):
 class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
-    def __init__(self, config, job_name, task_index, cluster):
-        assert job_name in ['ps', 'worker'], job_name
-        self.job_name = job_name
-        self.task_index = task_index
-        self.cluster = cluster
+    def __init__(self, config, server):
+        self.server = server
+        server_def = server.server_def
+        self.cluster = tf.train.ClusterSpec(server_def.cluster)
+        self.job_name = server_def.job_name
+        self.task_index = server_def.task_index
+        assert self.job_name in ['ps', 'worker'], job_name
         self._input_source = config.data
         self.is_chief = (self.task_index == 0 and self.job_name == 'worker')
         super(DistributedReplicatedTrainer, self).__init__(config)
@@ -76,9 +79,10 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
         if self.nr_gpu > 1:
             assert tf.test.is_gpu_available()
+            # TODO staging doesn't work with dummy (require context)
             # seem to only improve on >1 GPUs
-            if not isinstance(self._input_source, StagingInputWrapper):
-                self._input_source = StagingInputWrapper(self._input_source, self.raw_devices)
+            # if not isinstance(self._input_source, StagingInputWrapper):
+            #     self._input_source = StagingInputWrapper(self._input_source, self.raw_devices)

     @staticmethod
     def _average_grads(tower_grads, devices):
@@ -96,7 +100,7 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
             return tower_grads[0]
         new_tower_grads = []
         with tf.name_scope('AvgGrad'):
-            for i, grad_and_vars in enumerate(zip(*grad_list)):
+            for i, grad_and_vars in enumerate(zip(*tower_grads)):
                 # Ngpu * 2
                 with tf.device(devices[i % nr_device]):
                     v = grad_and_vars[0][1]
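For context on the hunk above, which renames the loop input from grad_list to the tower_grads parameter: _average_grads averages, per variable, the gradients computed on each tower, placing each reduction on one of the given devices round-robin. A self-contained sketch of that pattern, assuming the usual list-of-(grad, var)-lists layout; this is an illustration, not the file's exact code:

import tensorflow as tf

def average_grads_sketch(tower_grads, devices):
    # tower_grads: one list of (gradient, variable) pairs per tower, all in the same order.
    nr_device = len(devices)
    if nr_device == 1:
        return tower_grads[0]
    new_tower_grads = []
    with tf.name_scope('AvgGrad'):
        for i, grad_and_vars in enumerate(zip(*tower_grads)):
            v = grad_and_vars[0][1]                  # the variable shared by all towers
            with tf.device(devices[i % nr_device]):  # round-robin placement of the reduction
                avg = tf.multiply(
                    tf.add_n([g for g, _ in grad_and_vars]), 1.0 / nr_device)
            new_tower_grads.append((avg, v))
    return new_tower_grads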
@@ -150,18 +154,12 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
             return var_update_ops

     def _setup(self):
-        conf = get_default_sess_config()
-        self.server = tf.train.Server(
-            self.cluster, job_name=self.job_name, task_index=self.task_index,
-            config=conf   # TODO sessconfig
-        )
         if self.job_name == 'ps':
             logger.info("Running ps {}".format(self.task_index))
             self.server.join()
-            return
-        opt = self.model.get_optimizer()    # in global scope, not local
+            return  # TODO exit and skip mainloop how?

         super(DistributedReplicatedTrainer, self)._setup()
         with tf.variable_scope(
                 tf.get_variable_scope(),
                 custom_getter=OverrideToLocalVariableIfNotPsVar()):
@@ -185,9 +183,6 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
         with tf.device(self.param_server_device):
             gs = get_global_step_var()
+            opt = self.model.get_optimizer()    # in global scope, not local
-        assert isinstance(self._input_source, FeedfreeInput), type(self._input_source)
-        self._input_source.setup_training(self)
-        self._setup()
-        self.monitors = Monitors(self.monitors)
tensorpack/train/input_source.py
@@ -367,6 +367,7 @@ class DummyConstantInput(TensorInput):
         def fn():
             tlist = []
             ctx = get_current_tower_context()
+            assert ctx is not None
             assert len(self.shapes) == len(self.input_placehdrs)
             for idx, p in enumerate(self.input_placehdrs):
                 tlist.append(tf.get_variable(