seminar-breakout (Shashank Suhas)

Commit 10cc1962
Authored Oct 17, 2017 by Yuxin Wu
extract common utilities out of train/

Parent: a64d25cf

Showing 6 changed files with 90 additions and 68 deletions (+90 -68)
docs/conf.py                         +2   -1
tensorpack/callbacks/steps.py        +41  -3
tensorpack/tfutils/__init__.py       +2   -1
tensorpack/tfutils/distributed.py    +41  -0
tensorpack/train/base.py             +2   -36
tensorpack/train/distributed.py      +2   -27
docs/conf.py

@@ -366,7 +366,8 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
             'dump_chkpt_vars',
             'VisualQA',
             'huber_loss',
-            'DumpTensor']:
+            'DumpTensor',
+            'StepTensorPrinter']:
         return True
     if name in ['get_data', 'size', 'reset_state']:
tensorpack/callbacks/steps.py

@@ -10,14 +10,15 @@ import tqdm
 from ..utils import logger
 from ..utils.utils import get_tqdm_kwargs
+from ..utils.naming import GLOBAL_STEP_INCR_OP_NAME
-from ..tfutils.common import (get_op_tensor_name, get_op_or_tensor_by_name)
+from ..tfutils.common import (get_op_tensor_name, get_op_or_tensor_by_name, get_global_step_var)
 from .base import Callback

-__all__ = ['StepTensorPrinter', 'ProgressBar']
+__all__ = ['TensorPrinter', 'StepTensorPrinter', 'ProgressBar']


-class StepTensorPrinter(Callback):
+class TensorPrinter(Callback):
     """ Prints the value of some tensors in each step.
         It's an example of how ``before_run/after_run`` works.
     """

@@ -44,6 +45,9 @@ class StepTensorPrinter(Callback):
             logger.info("{}: {}".format(n, v))


+StepTensorPrinter = TensorPrinter
+
+
 class ProgressBar(Callback):
     """ A progress bar based on tqdm. Enabled by default. """

@@ -96,3 +100,37 @@ class ProgressBar(Callback):
     def _after_train(self):
         if self._bar:       # training may get killed before the first step
             self._bar.close()
+
+
+class MaintainStepCounter(Callback):
+    """
+    It maintains the global step in the graph, making sure it's increased by one.
+    This callback is used by the trainer, you don't need to worry about it.
+    """
+
+    _chief_only = False
+    """
+    In distributed training, we let each worker maintain its local global_step.
+    """
+
+    def _setup_graph(self):
+        # ensure it exists
+        gs_var = get_global_step_var()
+        with tf.name_scope(None):
+            with self.graph.colocate_with(gs_var):
+                self.gs_incr_op = tf.assign_add(
+                    gs_var, 1,
+                    name=GLOBAL_STEP_INCR_OP_NAME).op
+        self._fetches = tf.train.SessionRunArgs(self.gs_incr_op)
+
+    def _before_train(self):
+        if self.global_step != 0:
+            logger.info("Start training with global_step={}".format(self.global_step))
+
+    def _before_run(self, _):
+        # always increase global_step when hooked_sess.run is called
+        return self._fetches
+
+    def _after_run(self, _, __):
+        # Keep python-side global_step in agreement with TF-side
+        self.trainer._global_step += 1
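Note on the rename above: the module-level alias `StepTensorPrinter = TensorPrinter` keeps the old name usable. A minimal sketch of what that means for user code; the tensor name 'total_cost' and the TrainConfig usage are hypothetical examples, not taken from this diff:

    from tensorpack.callbacks.steps import TensorPrinter, StepTensorPrinter

    # The alias is the same class object, so existing code keeps working unchanged.
    assert StepTensorPrinter is TensorPrinter

    # Hypothetical usage: print the value of a named tensor at every step.
    # cb = TensorPrinter(['total_cost'])          # 'total_cost' is a made-up name
    # config = TrainConfig(..., callbacks=[cb])   # assumed tensorpack API of this era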
tensorpack/tfutils/__init__.py

@@ -35,4 +35,5 @@ for _, module_name, _ in iter_modules(
     if module_name in _TO_IMPORT:
         _global_import(module_name)  # import the content to tfutils.*

-__all__.extend(['sessinit', 'summary', 'optimizer', 'sesscreate', 'gradproc', 'varreplace', 'symbolic_functions'])
+__all__.extend(['sessinit', 'summary', 'optimizer', 'sesscreate', 'gradproc',
+                'varreplace', 'symbolic_functions', 'distributed'])
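Adding 'distributed' to `__all__` only registers the new submodule in the package's public surface; its functions keep their fully-qualified path. A small check of the intended effect, assuming this revision of tensorpack is installed:

    import tensorpack.tfutils as tfutils

    # The submodule name is now listed in the package's public API.
    assert 'distributed' in tfutils.__all__

    # The helper itself is imported from its own module path.
    from tensorpack.tfutils.distributed import get_distributed_session_creator
    print(get_distributed_session_creator)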
tensorpack/tfutils/distributed.py  (new file, mode 100644)

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: distributed.py

import tensorflow as tf


def get_distributed_session_creator(server):
    """
    Args:
        server (tf.train.Server):

    Returns:
        tf.train.SessionCreator
    """

    server_def = server.server_def
    is_chief = (server_def.job_name == 'worker') and (server_def.task_index == 0)

    init_op = tf.global_variables_initializer()
    local_init_op = tf.local_variables_initializer()
    ready_op = tf.report_uninitialized_variables()
    sm = tf.train.SessionManager(
        local_init_op=local_init_op,
        ready_op=ready_op, graph=tf.get_default_graph())

    # to debug wrong variable collection
    # print("GLOBAL:")
    # print(tf.global_variables())
    # print("LOCAL:")
    # print(tf.local_variables())

    class _Creator(tf.train.SessionCreator):
        def create_session(self):
            if is_chief:
                return sm.prepare_session(master=server.target, init_op=init_op)
            else:
                return sm.wait_for_session(master=server.target)

    return _Creator()
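The new helper wraps the SessionManager logic that previously lived inside the trainer. A minimal usage sketch, not part of this commit; the cluster addresses and the job/task choices below are hypothetical:

    import tensorflow as tf
    from tensorpack.tfutils.distributed import get_distributed_session_creator

    # Hypothetical one-ps/two-worker cluster; replace the addresses with real hosts.
    cluster = tf.train.ClusterSpec({
        'ps': ['host0:2222'],
        'worker': ['host1:2222', 'host2:2222'],
    })
    server = tf.train.Server(cluster, job_name='worker', task_index=0)

    # The chief (worker 0) runs init_op via SessionManager.prepare_session();
    # every other worker blocks in wait_for_session() until the chief is ready.
    creator = get_distributed_session_creator(server)
    sess = creator.create_session()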
tensorpack/train/base.py

@@ -10,15 +10,15 @@ import tensorflow as tf

 from .config import TrainConfig
 from ..utils import logger
-from ..utils.naming import GLOBAL_STEP_INCR_OP_NAME
 from ..callbacks import Callback, Callbacks
 from ..callbacks.monitor import Monitors, TrainingMonitor
-from ..tfutils import get_global_step_value, get_global_step_var
+from ..tfutils import get_global_step_value
 from ..tfutils.model_utils import describe_trainable_vars
 from ..tfutils.sesscreate import ReuseSessionCreator
 from ..tfutils.sessinit import JustCurrentSession
 from ..graph_builder.predictor_factory import PredictorFactory
+from ..callbacks.steps import MaintainStepCounter

 __all__ = ['Trainer', 'StopTraining']

@@ -30,40 +30,6 @@ class StopTraining(BaseException):
     pass


-class MaintainStepCounter(Callback):
-    """
-    It maintains the global step in the graph, making sure it's increased by one.
-    This callback is always enabled by the trainer, you don't need to worry about it.
-    """
-    chief_only = False
-    """
-    In distributed training, we let each worker maintain its local global_step.
-    """
-
-    def _setup_graph(self):
-        # ensure it exists
-        gs_var = get_global_step_var()
-        with tf.name_scope(None):
-            with self.graph.colocate_with(gs_var):
-                self.gs_incr_op = tf.assign_add(
-                    gs_var, 1,
-                    name=GLOBAL_STEP_INCR_OP_NAME).op
-        self._fetches = tf.train.SessionRunArgs(self.gs_incr_op)
-
-    def _before_train(self):
-        if self.global_step != 0:
-            logger.info("Start training with global_step={}".format(self.global_step))
-
-    def _before_run(self, _):
-        # always increase global_step when hooked_sess.run is called
-        return self._fetches
-
-    def _after_run(self, _, __):
-        # Keep python-side global_step in agreement with TF-side
-        self.trainer._global_step += 1
-
-
 class Trainer(object):
     """ Base class for a trainer.
tensorpack/train/distributed.py

@@ -2,13 +2,13 @@
 # -*- coding: utf-8 -*-
 # File: distributed.py

 import tensorflow as tf
 import os

 from ..utils import logger
 from ..callbacks import RunOp
 from ..tfutils.sesscreate import NewSessionCreator
-from ..tfutils import get_global_step_var
+from ..tfutils.distributed import get_distributed_session_creator
 from ..graph_builder.distributed import DistributedReplicatedBuilder
 from ..graph_builder.utils import override_to_local_variable

@@ -63,9 +63,6 @@ class DistributedTrainerReplicated(Trainer):
         if self.job_name == 'worker':
             # ps doesn't build any graph
             self._builder = DistributedReplicatedBuilder(config.tower, server)
-            self.is_chief = self._builder.is_chief
-        else:
-            self.is_chief = False

         logger.info("Distributed training on cluster:\n" + str(server.server_def.cluster))

         self._input_source = config.data

@@ -117,29 +114,7 @@ class DistributedTrainerReplicated(Trainer):
                 "Cannot set session_creator or session_config for distributed training! "
                 "To use a custom session config, pass it with tf.train.Server.")

-        init_op = tf.global_variables_initializer()
-        local_init_op = tf.local_variables_initializer()
-        ready_op = tf.report_uninitialized_variables()
-        sm = tf.train.SessionManager(
-            local_init_op=local_init_op,
-            ready_op=ready_op, graph=tf.get_default_graph())
-
-        # to debug wrong variable collection
-        # print("GLOBAL:")
-        # print(tf.global_variables())
-        # print("LOCAL:")
-        # print(tf.local_variables())
-
-        def _create_session():
-            if self.is_chief:
-                return sm.prepare_session(master=self.server.target, init_op=init_op)
-            else:
-                return sm.wait_for_session(master=self.server.target)
-
-        class _Creator(tf.train.SessionCreator):
-            def create_session(self):
-                return _create_session()
-
-        self.config.session_creator = _Creator()
+        self.config.session_creator = get_distributed_session_creator(self.server)

     @property
     def vs_name_for_predictor(self):
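After this change the trainer simply assigns the helper's return value to `config.session_creator`. How that object is ultimately consumed is outside this diff; as a rough sketch, a `tf.train.SessionCreator` is the kind of object `tf.train.MonitoredSession` accepts, so a driver loop could look like the following (the function and its arguments are hypothetical, not tensorpack's own training loop):

    import tensorflow as tf

    def run_with_creator(creator, train_op, max_steps=100):
        """Drive a training op through a MonitoredSession built from a SessionCreator.

        Sketch only; assumes a global step variable exists in the graph so that
        StopAtStepHook can terminate the loop.
        """
        hooks = [tf.train.StopAtStepHook(last_step=max_steps)]
        with tf.train.MonitoredSession(session_creator=creator, hooks=hooks) as sess:
            while not sess.should_stop():
                sess.run(train_op)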