Commit 2cfefc90, authored May 30, 2017 by Yuxin Wu

callbacks chief_only

parent d713bcd2

Showing 6 changed files with 31 additions and 5 deletions (+31 -5)
tensorpack/callbacks/base.py      +11 -0
tensorpack/callbacks/graph.py      +3 -0
tensorpack/callbacks/steps.py      +2 -0
tensorpack/train/base.py          +11 -3
tensorpack/train/distributed.py    +1 -1
tensorpack/train/input_source.py   +3 -1
tensorpack/callbacks/base.py

@@ -36,6 +36,8 @@ class Callback(object):
     .. automethod:: _after_train
     """
 
+    _chief_only = True
+
     def setup_graph(self, trainer):
         self._steps_per_epoch = trainer.config.steps_per_epoch
         self.trainer = trainer

@@ -162,6 +164,15 @@ class Callback(object):
     def local_step(self):
         return self.trainer.local_step
 
+    @property
+    def chief_only(self):
+        """
+        Only run this callback on chief training process.
+
+        Returns: bool
+        """
+        return self._chief_only
+
     def __str__(self):
         return type(self).__name__
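(For context: a minimal sketch, not part of this commit, of how a Callback subclass would use the new class attribute and read-only property. The class name and method body below are hypothetical.)

from tensorpack.callbacks.base import Callback

class PerWorkerStats(Callback):
    """Hypothetical callback intended to run on every training process."""
    # Override the class-level default (True) introduced in this commit.
    _chief_only = False

    def _before_train(self):
        # Per-worker setup would go here.
        pass

cb = PerWorkerStats()
print(cb.chief_only)   # False, read through the new read-only property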
tensorpack/callbacks/graph.py

@@ -55,6 +55,9 @@ class RunUpdateOps(RunOp):
     """
     Run ops from the collection UPDATE_OPS every step
     """
+
+    _chief_only = False
+
     def __init__(self, collection=tf.GraphKeys.UPDATE_OPS):
         def f():
             ops = tf.get_collection(collection)
tensorpack/callbacks/steps.py

@@ -81,6 +81,8 @@ class MaintainStepCounter(Callback):
 
 class ProgressBar(Callback):
     """ A progress bar based on tqdm. Enabled by default. """
+    _chief_only = False
+
     def __init__(self, names=[]):
         """
         Args:
tensorpack/train/base.py

@@ -47,6 +47,8 @@ class Trainer(object):
         global_step (int): the number of steps that have finished.
     """
 
+    is_chief = True
+
     def __init__(self, config):
         """
         Args:

@@ -79,14 +81,20 @@ class Trainer(object):
         assert isinstance(cb, Callback), cb
         assert not isinstance(self._callbacks, Callbacks), \
             "Cannot register more callbacks after trainer was setup!"
-        self._callbacks.append(cb)
+        if not self.is_chief and cb.chief_only:
+            logger.warn("Callback {} is chief-only, skipped.".format(str(cb)))
+        else:
+            self._callbacks.append(cb)
 
     def register_monitor(self, mon):
         assert isinstance(mon, TrainingMonitor), mon
         assert not isinstance(self.monitors, Monitors), \
             "Cannot register more monitors after trainer was setup!"
-        self.monitors.append(mon)
-        self.register_callback(mon)
+        if not self.is_chief and mon.chief_only:
+            logger.warn("Callback {} is chief-only, skipped.".format(str(mon)))
+        else:
+            self.monitors.append(mon)
+            self.register_callback(mon)
 
     def train(self):
         """ Start training """
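(The branching added to register_callback and register_monitor can be summarised by the standalone sketch below; it is plain Python with illustrative names, not tensorpack code.)

import logging
logger = logging.getLogger("sketch")

class FakeCallback(object):
    chief_only = True            # mirrors the Callback default

def register_callback(callbacks, cb, is_chief):
    # Non-chief processes silently drop chief-only callbacks.
    if not is_chief and cb.chief_only:
        logger.warning("Callback %s is chief-only, skipped.", type(cb).__name__)
    else:
        callbacks.append(cb)

cbs = []
register_callback(cbs, FakeCallback(), is_chief=True)    # kept on the chief
register_callback(cbs, FakeCallback(), is_chief=False)   # skipped on a worker
assert len(cbs) == 1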
tensorpack/train/distributed.py

@@ -55,6 +55,7 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
         self.task_index = task_index
         self.cluster = cluster
         self._input_source = config.data
+        self.is_chief = (self.task_index == 0 and self.job_name == 'worker')
         super(DistributedReplicatedTrainer, self).__init__(config)
 
         worker_prefix = '/job:worker/task:%s' % self.task_index

@@ -144,7 +145,6 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
     def setup(self):
         with tf.device(self.param_server_device):
             gs = get_global_step_var()
 
-        self.is_chief = (self.task_index == 0 and self.job_name == 'worker')
         assert isinstance(self._input_source, FeedfreeInput), type(self._input_source)
         self._input_source.setup_training(self)
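(Moving the is_chief assignment from setup() into __init__, ahead of the super().__init__(config) call, presumably guarantees the flag is set before the base Trainer starts registering callbacks, which now consult it. Below is a self-contained sketch of that ordering constraint, with illustrative names only.)

class BaseTrainer(object):
    is_chief = True                     # class default, as in train/base.py

    def __init__(self):
        # The base constructor may already consult is_chief while
        # registering callbacks, so subclasses must set it before calling up.
        print("registering callbacks, is_chief = {}".format(self.is_chief))

class WorkerTrainer(BaseTrainer):
    def __init__(self, task_index, job_name):
        self.is_chief = (task_index == 0 and job_name == 'worker')
        super(WorkerTrainer, self).__init__()    # now sees the correct value

WorkerTrainer(task_index=1, job_name='worker')   # prints: ... is_chief = False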
tensorpack/train/input_source.py

@@ -241,7 +241,9 @@ class QueueInput(FeedfreeInput):
     def setup_training(self, trainer):
         super(QueueInput, self).setup_training(trainer)
-        trainer.register_callback(StartProcOrThread(self.thread))
+        cb = StartProcOrThread(self.thread)
+        cb._chief_only = False
+        trainer.register_callback(cb)
 
     def get_input_tensors(self):
         with tf.device('/cpu:0'):
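(A minimal sketch of the per-instance opt-out used above: setting _chief_only on a single object leaves the class default of True intact for other instances. DummyCallback is a hypothetical subclass for illustration.)

from tensorpack.callbacks.base import Callback

class DummyCallback(Callback):
    pass

a = DummyCallback()
b = DummyCallback()
a._chief_only = False          # only this instance should run on every worker

assert a.chief_only is False   # instance attribute shadows the class default
assert b.chief_only is True    # other instances keep the chief-only default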