Commit 10cc1962 authored Oct 17, 2017 by Yuxin Wu
extract common utilities out of train/
parent a64d25cf
Showing 6 changed files with 90 additions and 68 deletions (+90, -68)
docs/conf.py                        +2   -1
tensorpack/callbacks/steps.py       +41  -3
tensorpack/tfutils/__init__.py      +2   -1
tensorpack/tfutils/distributed.py   +41  -0
tensorpack/train/base.py            +2   -36
tensorpack/train/distributed.py     +2   -27
docs/conf.py
@@ -366,7 +366,8 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
                 'dump_chkpt_vars',
                 'VisualQA',
                 'huber_loss',
-                'DumpTensor']:
+                'DumpTensor',
+                'StepTensorPrinter']:
         return True
     if name in ['get_data', 'size', 'reset_state']:
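
The hunk above only extends the documentation skip list (the renamed TensorPrinter keeps StepTensorPrinter as an alias, so the alias is hidden from the generated docs). For readers unfamiliar with the hook, this is Sphinx's autodoc-skip-member event; a minimal, hypothetical registration in conf.py looks roughly like this (the SKIP_NAMES list is illustrative, not the project's actual list):

SKIP_NAMES = ['DumpTensor', 'StepTensorPrinter']

def autodoc_skip_member(app, what, name, obj, skip, options):
    # Returning True tells autodoc to omit the member from the docs;
    # returning None falls back to autodoc's default decision.
    if name in SKIP_NAMES:
        return True
    return None

def setup(app):
    # hook the handler into Sphinx's autodoc-skip-member event
    app.connect('autodoc-skip-member', autodoc_skip_member)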
tensorpack/callbacks/steps.py
@@ -10,14 +10,15 @@ import tqdm
 from ..utils import logger
 from ..utils.utils import get_tqdm_kwargs
+from ..utils.naming import GLOBAL_STEP_INCR_OP_NAME
 from ..tfutils.common import (
-    get_op_tensor_name, get_op_or_tensor_by_name)
+    get_op_tensor_name, get_op_or_tensor_by_name, get_global_step_var)
 from .base import Callback

-__all__ = ['StepTensorPrinter', 'ProgressBar']
+__all__ = ['TensorPrinter', 'StepTensorPrinter', 'ProgressBar']


-class StepTensorPrinter(Callback):
+class TensorPrinter(Callback):
     """ Prints the value of some tensors in each step.
     It's an example of how ``before_run/after_run`` works.
     """

@@ -44,6 +45,9 @@ class StepTensorPrinter(Callback):
         logger.info("{}: {}".format(n, v))


+StepTensorPrinter = TensorPrinter
+
+
 class ProgressBar(Callback):
     """ A progress bar based on tqdm. Enabled by default. """

@@ -96,3 +100,37 @@ class ProgressBar(Callback):
     def _after_train(self):
         if self._bar:       # training may get killed before the first step
             self._bar.close()
+
+
+class MaintainStepCounter(Callback):
+    """
+    It maintains the global step in the graph, making sure it's increased by one.
+    This callback is used by the trainer, you don't need to worry about it.
+    """
+
+    _chief_only = False
+    """
+    In distributed training, we let each worker maintain its local global_step.
+    """
+
+    def _setup_graph(self):
+        # ensure it exists
+        gs_var = get_global_step_var()
+        with tf.name_scope(None):
+            with self.graph.colocate_with(gs_var):
+                self.gs_incr_op = tf.assign_add(
+                    gs_var, 1, name=GLOBAL_STEP_INCR_OP_NAME).op
+        self._fetches = tf.train.SessionRunArgs(self.gs_incr_op)
+
+    def _before_train(self):
+        if self.global_step != 0:
+            logger.info("Start training with global_step={}".format(self.global_step))
+
+    def _before_run(self, _):
+        # always increase global_step when hooked_sess.run is called
+        return self._fetches
+
+    def _after_run(self, _, __):
+        # Keep python-side global_step in agreement with TF-side
+        self.trainer._global_step += 1
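
Two things happen in this file: StepTensorPrinter is renamed to TensorPrinter (the old name stays as an alias), and MaintainStepCounter moves here from train/base.py. The callback is the standard before_run/after_run pattern: fetch a global-step increment op on every hooked run, then bump a python-side counter afterwards. As a rough standalone equivalent, sketched for illustration only with a plain tf.train.SessionRunHook (names here are not tensorpack's):

import tensorflow as tf

class GlobalStepIncrementHook(tf.train.SessionRunHook):
    """Illustrative sketch, not tensorpack code."""

    def begin(self):
        # make sure the global step variable exists, then build an increment op
        gs_var = tf.train.get_or_create_global_step()
        self._incr_op = tf.assign_add(gs_var, 1, name='global_step_incr').op

    def before_run(self, run_context):
        # ask the hooked session to also run the increment op on every step
        return tf.train.SessionRunArgs(self._incr_op)

    def after_run(self, run_context, run_values):
        # python-side bookkeeping would go here, mirroring
        # `self.trainer._global_step += 1` in the callback above
        pass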
tensorpack/tfutils/__init__.py
@@ -35,4 +35,5 @@ for _, module_name, _ in iter_modules(
     if module_name in _TO_IMPORT:
         _global_import(module_name)   # import the content to tfutils.*
 __all__.extend(['sessinit', 'summary', 'optimizer',
-                'sesscreate', 'gradproc', 'varreplace', 'symbolic_functions'])
+                'sesscreate', 'gradproc', 'varreplace', 'symbolic_functions',
+                'distributed'])
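
The only change here is exporting the new distributed module from the tfutils package. The `_global_import` helper itself is not part of this hunk; as a hedged sketch, the usual pattern behind such a helper in a package `__init__.py` (assuming `__all__` is already defined above it) is roughly:

from importlib import import_module

def _global_import(name):
    # Illustrative sketch only, not the actual tensorpack helper:
    # import <package>.<name> and lift its public names into the
    # package namespace.
    mod = import_module('.' + name, __package__)
    public = getattr(mod, '__all__',
                     [k for k in dir(mod) if not k.startswith('_')])
    for sym in public:
        globals()[sym] = getattr(mod, sym)
        __all__.append(sym)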
tensorpack/tfutils/distributed.py
new file mode 100644
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# File: distributed.py
+
+import tensorflow as tf
+
+
+def get_distributed_session_creator(server):
+    """
+    Args:
+        server (tf.train.Server):
+
+    Returns:
+        tf.train.SessionCreator
+    """
+    server_def = server.server_def
+    is_chief = (server_def.job_name == 'worker') and (server_def.task_index == 0)
+
+    init_op = tf.global_variables_initializer()
+    local_init_op = tf.local_variables_initializer()
+    ready_op = tf.report_uninitialized_variables()
+    sm = tf.train.SessionManager(
+        local_init_op=local_init_op,
+        ready_op=ready_op, graph=tf.get_default_graph())
+
+    # to debug wrong variable collection
+    # print("GLOBAL:")
+    # print(tf.global_variables())
+    # print("LOCAL:")
+    # print(tf.local_variables())
+
+    class _Creator(tf.train.SessionCreator):
+        def create_session(self):
+            if is_chief:
+                return sm.prepare_session(master=server.target, init_op=init_op)
+            else:
+                return sm.wait_for_session(master=server.target)
+
+    return _Creator()
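
This new module is the session-creation logic extracted from the distributed trainer (see train/distributed.py below): the chief prepares the session through tf.train.SessionManager and every other worker waits for it. A hedged usage sketch, with an illustrative two-worker cluster spec (job names and ports are placeholders):

import tensorflow as tf
from tensorpack.tfutils.distributed import get_distributed_session_creator

# illustrative cluster layout
cluster = tf.train.ClusterSpec({
    'ps': ['localhost:2222'],
    'worker': ['localhost:2223', 'localhost:2224'],
})
# task_index 0 of the 'worker' job is what the creator above treats as chief
server = tf.train.Server(cluster, job_name='worker', task_index=0)

# ... build the graph for this worker here ...

creator = get_distributed_session_creator(server)
sess = creator.create_session()   # chief runs init_op; other workers wait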
tensorpack/train/base.py
@@ -10,15 +10,15 @@ import tensorflow as tf
 from .config import TrainConfig
 from ..utils import logger
-from ..utils.naming import GLOBAL_STEP_INCR_OP_NAME
 from ..callbacks import Callback, Callbacks
 from ..callbacks.monitor import Monitors, TrainingMonitor
-from ..tfutils import get_global_step_value, get_global_step_var
+from ..tfutils import get_global_step_value
 from ..tfutils.model_utils import describe_trainable_vars
 from ..tfutils.sesscreate import ReuseSessionCreator
 from ..tfutils.sessinit import JustCurrentSession
 from ..graph_builder.predictor_factory import PredictorFactory
+from ..callbacks.steps import MaintainStepCounter

 __all__ = ['Trainer', 'StopTraining']

@@ -30,40 +30,6 @@ class StopTraining(BaseException):
     pass


-class MaintainStepCounter(Callback):
-    """
-    It maintains the global step in the graph, making sure it's increased by one.
-    This callback is always enabled by the trainer, you don't need to worry about it.
-    """
-
-    chief_only = False
-    """
-    In distributed training, we let each worker maintain its local global_step.
-    """
-
-    def _setup_graph(self):
-        # ensure it exists
-        gs_var = get_global_step_var()
-        with tf.name_scope(None):
-            with self.graph.colocate_with(gs_var):
-                self.gs_incr_op = tf.assign_add(
-                    gs_var, 1, name=GLOBAL_STEP_INCR_OP_NAME).op
-        self._fetches = tf.train.SessionRunArgs(self.gs_incr_op)
-
-    def _before_train(self):
-        if self.global_step != 0:
-            logger.info("Start training with global_step={}".format(self.global_step))
-
-    def _before_run(self, _):
-        # always increase global_step when hooked_sess.run is called
-        return self._fetches
-
-    def _after_run(self, _, __):
-        # Keep python-side global_step in agreement with TF-side
-        self.trainer._global_step += 1
-
-
 class Trainer(object):
     """ Base class for a trainer.
tensorpack/train/distributed.py
@@ -2,13 +2,13 @@
 # -*- coding: utf-8 -*-
 # File: distributed.py

-import tensorflow as tf
 import os

 from ..utils import logger
 from ..callbacks import RunOp
 from ..tfutils.sesscreate import NewSessionCreator
 from ..tfutils import get_global_step_var
+from ..tfutils.distributed import get_distributed_session_creator

 from ..graph_builder.distributed import DistributedReplicatedBuilder
 from ..graph_builder.utils import override_to_local_variable

@@ -63,9 +63,6 @@ class DistributedTrainerReplicated(Trainer):
         if self.job_name == 'worker':
             # ps doesn't build any graph
             self._builder = DistributedReplicatedBuilder(config.tower, server)
-            self.is_chief = self._builder.is_chief
-        else:
-            self.is_chief = False
         logger.info("Distributed training on cluster:\n" + str(server.server_def.cluster))

         self._input_source = config.data

@@ -117,29 +114,7 @@ class DistributedTrainerReplicated(Trainer):
                 "Cannot set session_creator or session_config for distributed training! "
                 "To use a custom session config, pass it with tf.train.Server.")

-        init_op = tf.global_variables_initializer()
-        local_init_op = tf.local_variables_initializer()
-        ready_op = tf.report_uninitialized_variables()
-        sm = tf.train.SessionManager(
-            local_init_op=local_init_op,
-            ready_op=ready_op, graph=tf.get_default_graph())
-
-        # to debug wrong variable collection
-        # print("GLOBAL:")
-        # print(tf.global_variables())
-        # print("LOCAL:")
-        # print(tf.local_variables())
-
-        def _create_session():
-            if self.is_chief:
-                return sm.prepare_session(master=self.server.target, init_op=init_op)
-            else:
-                return sm.wait_for_session(master=self.server.target)
-
-        class _Creator(tf.train.SessionCreator):
-            def create_session(self):
-                return _create_session()
-
-        self.config.session_creator = _Creator()
+        self.config.session_creator = get_distributed_session_creator(self.server)

     @property
     def vs_name_for_predictor(self):