Shashank Suhas / seminar-breakout · Commits

Commit 89bcdd10
authored Apr 10, 2017 by Yuxin Wu
add local_step as callback property. remove some legacy code.
parent 762e4dcc
Showing 6 changed files with 21 additions and 37 deletions.
tensorpack/callbacks/base.py      +5   -0
tensorpack/callbacks/steps.py     +11  -11
tensorpack/callbacks/summary.py   +1   -1
tensorpack/callbacks/trigger.py   +2   -2
tensorpack/train/base.py          +2   -12
tensorpack/train/config.py        +0   -11
tensorpack/callbacks/base.py
@@ -17,6 +17,7 @@ class Callback(object):
     Attributes:
         epoch_num(int): the number of the current epoch.
         global_step(int): the number of global steps that have finished.
+        local_step(int): the local steps within the current epoch.
         trainer(Trainer): the trainer.
         graph(tf.Graph): the graph.
@@ -157,6 +158,10 @@ class Callback(object):
     def global_step(self):
         return self.trainer.global_step

+    @property
+    def local_step(self):
+        return self.trainer.local_step
+
     def __str__(self):
         return type(self).__name__
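For a sense of what the new property enables, here is a minimal sketch of a user-defined callback reading both counters. The class name and the log message are illustrative; `local_step`, `global_step`, and the `_trigger_step` hook come from the code in this commit.

    from tensorpack.callbacks import Callback
    from tensorpack.utils import logger

    class StepLogger(Callback):
        """Hypothetical callback: log the step counters exposed by the Callback base class."""
        def _trigger_step(self):
            # local_step counts steps within the current epoch;
            # global_step counts all finished steps across epochs.
            logger.info("local_step={}, global_step={}".format(
                self.local_step, self.global_step))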
tensorpack/callbacks/steps.py
@@ -49,9 +49,8 @@ class StepTensorPrinter(Callback):
 class MaintainStepCounter(Callback):
     """
-    It maintains the global step in the graph and also creates the local step tensor.
-    This callback is always enabled by the trainer, and you wouldn't need to
-    use it.
+    It maintains the global step in the graph, making sure it's increased by one in every `run_step` call.
+    This callback is always enabled by the trainer, and you wouldn't need to use it.
     """

     def _setup_graph(self):
         # ensure it exists
@@ -69,12 +68,12 @@ class MaintainStepCounter(Callback):
         gs_val = get_global_step_value()
         if gs_val != 0:
             logger.info("Start training with global_step={}".format(gs_val))
-        self._last_updated = self.trainer.local_step
+        self._last_updated = self.local_step

     def _before_run(self, _):
         # increase global_step, when trainer.local_step changed
-        if self.trainer.local_step != self._last_updated:
-            self._last_updated = self.trainer.local_step
+        if self.local_step != self._last_updated:
+            self._last_updated = self.local_step
             return self._fetches
         else:
             return None
@@ -93,7 +92,7 @@ class ProgressBar(Callback):
         self._tags = [get_op_tensor_name(n)[0].split("/")[-1] for n in names]

     def _before_train(self):
-        self._last_updated = self.trainer.local_step
+        self._last_updated = self.local_step
         self._total = self.trainer.config.steps_per_epoch
         self._tqdm_args = get_tqdm_kwargs(leave=True)
@@ -104,10 +103,11 @@ class ProgressBar(Callback):
         self._tqdm_args['bar_format'] = self._tqdm_args['bar_format'] + "{postfix} "

     def _before_run(self, _):
-        if self.trainer.local_step != self._last_updated:
-            self._last_updated = self.trainer.local_step
+        # update progress bar when local step changed (one step is finished)
+        if self.local_step != self._last_updated:
+            self._last_updated = self.local_step

-            if self.trainer.local_step == 0:
+            if self.local_step == 0:
                 self._bar = tqdm.trange(self._total, **self._tqdm_args)
             return self._fetches
@@ -121,7 +121,7 @@ class ProgressBar(Callback):
     def _trigger_step(self):
         self._bar.update()
-        if self.trainer.local_step == self._total - 1:
+        if self.local_step == self._total - 1:
             self._bar.close()

     def _after_train(self):
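The recurring pattern in this file (compare `self.local_step` against a cached value and only return fetches when a new step has actually started) can be reused in custom callbacks. Below is a sketch under two assumptions not shown in the diff: that `_before_run` may return a `tf.train.SessionRunArgs` (the diff only shows `return self._fetches`, not how `_fetches` is built), and that the global step tensor already exists, which `MaintainStepCounter` is responsible for.

    import tensorflow as tf
    from tensorpack.callbacks import Callback

    class OncePerStep(Callback):
        """Hypothetical callback that adds an extra fetch only once per training step."""

        def _setup_graph(self):
            # Assumption: the global step tensor already exists;
            # MaintainStepCounter (always enabled by the trainer) ensures that.
            self._fetches = tf.train.SessionRunArgs(tf.train.get_global_step())

        def _before_train(self):
            self._last_updated = self.local_step

        def _before_run(self, _):
            # Same guard as MaintainStepCounter/ProgressBar above: only fetch
            # when trainer.local_step has advanced since the last call.
            if self.local_step != self._last_updated:
                self._last_updated = self.local_step
                return self._fetches
            return None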
tensorpack/callbacks/summary.py
@@ -60,7 +60,7 @@ class MergeAllSummaries(Callback):
     def _before_run(self, ctx):
         if self._run_alone:
             return None
-        if self.trainer.local_step == self._total - 1:
+        if self.local_step == self._total - 1:
             return self._fetches
         return None
tensorpack/callbacks/trigger.py
@@ -12,7 +12,7 @@ __all__ = ['PeriodicTrigger', 'PeriodicCallback']
 class PeriodicTrigger(ProxyCallback):
     """
-    Schedule to trigger a callback every k steps or every k epochs by its ``_trigger()`` method.
+    Schedule to trigger a callback every k steps or every k epochs by its ``trigger()`` method.
     """

     def __init__(self, triggerable, every_k_steps=None, every_k_epochs=None):
         """
@@ -37,7 +37,7 @@ class PeriodicTrigger(ProxyCallback):
             return
         # trigger_step is triggered after run_step, so
         # local_step + 1 is the number of step that have finished
-        if (self.trainer.local_step + 1) % self._step_k == 0:
+        if (self.local_step + 1) % self._step_k == 0:
             self.cb.trigger()

     def _trigger_epoch(self):
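For reference, a short usage sketch of the constructor signature shown above. ModelSaver is only an example of a wrappable callback (and assumes a logger directory has been set up); the step count is arbitrary.

    from tensorpack.callbacks import ModelSaver, PeriodicTrigger

    # Fire the wrapped callback's trigger() every 500 steps, i.e. whenever
    # (local_step + 1) % 500 == 0 per the check above; every_k_epochs works analogously.
    cb = PeriodicTrigger(ModelSaver(), every_k_steps=500)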
tensorpack/train/base.py
@@ -16,13 +16,13 @@ from .predict import PredictorFactory
 from .config import TrainConfig
 from .monitor import Monitors, TrainingMonitor
 from ..utils import logger
-from ..utils.develop import deprecated, log_deprecated
+from ..utils.develop import deprecated
 from ..callbacks import Callback, Callbacks, MaintainStepCounter
 from ..tfutils import get_global_step_value
 from ..tfutils.model_utils import describe_model
 from ..tfutils.sesscreate import ReuseSessionCreator

-__all__ = ['Trainer', 'StopTraining', 'MultiPredictorTowerTrainer']
+__all__ = ['Trainer', 'StopTraining']


 class StopTraining(BaseException):
@@ -211,13 +211,3 @@ class Trainer(object):
     @deprecated("Use get_predictors instead!", "2017-05-20")
     def get_predict_funcs(self, input_names, output_names, n):
         return self.get_predictors(input_names, output_names, n)
-
-    @deprecated("Don't need to call it any more!", "2017-03-20")
-    def _setup_predictor_factory(self):
-        pass
-
-
-# back-compat
-class MultiPredictorTowerTrainer(Trainer):
-    def __init__(self, *args, **kwargs):
-        log_deprecated("MultiPredictorTowerTrainer", "Just remove it instead.", "2017-03-21")
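The deprecation notice kept above points callers from get_predict_funcs to get_predictors, which takes the same arguments. A hedged before/after sketch, assuming an existing Trainer instance named trainer and using placeholder tensor names:

    # Deprecated form (removal scheduled after 2017-05-20):
    #     funcs = trainer.get_predict_funcs(['input'], ['output'], 2)
    # Replacement named by the decorator message, same arguments:
    funcs = trainer.get_predictors(['input'], ['output'], 2)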
tensorpack/train/config.py
@@ -162,17 +162,6 @@ class TrainConfig(object):
         assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))

-    def set_tower(self, nr_tower=None, tower=None):
-        log_deprecated("config.set_tower", "Set config.tower or config.nr_tower directly.", "2017-03-15")
-        assert nr_tower is None or tower is None, "Cannot set both nr_tower and tower!"
-        if nr_tower:
-            tower = list(range(nr_tower))
-        else:
-            if isinstance(tower, int):
-                tower = list(range(tower))
-        self.tower = tower
-        assert isinstance(self.tower, list)
-
     @property
     def nr_tower(self):
         return len(self.tower)
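The removed set_tower helper only normalized its argument into a list of tower indices, and its own deprecation message ("Set config.tower or config.nr_tower directly.") names the migration path. A small self-contained sketch of the equivalence it implemented:

    # set_tower(nr_tower=k) amounted to assigning tower = list(range(k)):
    nr_tower = 2
    tower = list(range(nr_tower))
    assert tower == [0, 1]
    # So callers now set config.tower (or config.nr_tower) on the TrainConfig
    # directly, as the removed deprecation message instructed.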