Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
a6aa8bce
Commit
a6aa8bce
authored
Aug 07, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Define global_step to be the number of hooked_sess.run calls
parent
b747c068
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
59 additions
and
70 deletions
+59
-70
examples/GAN/WGAN.py
examples/GAN/WGAN.py
+2
-3
tensorpack/callbacks/base.py
tensorpack/callbacks/base.py
+3
-3
tensorpack/callbacks/steps.py
tensorpack/callbacks/steps.py
+2
-38
tensorpack/train/base.py
tensorpack/train/base.py
+50
-25
tensorpack/train/feedfree.py
tensorpack/train/feedfree.py
+2
-1
No files found.
examples/GAN/WGAN.py
View file @
a6aa8bce
...
@@ -83,7 +83,6 @@ if __name__ == '__main__':
...
@@ -83,7 +83,6 @@ if __name__ == '__main__':
max_epoch
=
200
,
max_epoch
=
200
,
session_init
=
SaverRestore
(
args
.
load
)
if
args
.
load
else
None
session_init
=
SaverRestore
(
args
.
load
)
if
args
.
load
else
None
)
)
"""
# The original code uses a different schedule, but this seems to work well.
The original code uses a different schedule, but this seems to work well.
# Train 1 D after 2 G
"""
SeparateGANTrainer
(
config
,
d_period
=
3
)
.
train
()
SeparateGANTrainer
(
config
,
d_period
=
3
)
.
train
()
tensorpack/callbacks/base.py
View file @
a6aa8bce
...
@@ -19,9 +19,9 @@ class Callback(object):
...
@@ -19,9 +19,9 @@ class Callback(object):
for more detailed explanation of the callback methods.
for more detailed explanation of the callback methods.
Attributes:
Attributes:
epoch_num(int): t
he number of the current epoch.
epoch_num(int): t
rainer.epoch_num
global_step(int): t
he number of global steps that have finished or is currently running.
global_step(int): t
rainer.global_step
local_step(int): t
he local steps within the current epoch.
local_step(int): t
rainer.local_step
trainer(Trainer): the trainer.
trainer(Trainer): the trainer.
graph(tf.Graph): the graph.
graph(tf.Graph): the graph.
...
...
tensorpack/callbacks/steps.py
View file @
a6aa8bce
...
@@ -10,14 +10,11 @@ import tqdm
...
@@ -10,14 +10,11 @@ import tqdm
from
..utils
import
logger
from
..utils
import
logger
from
..utils.utils
import
get_tqdm_kwargs
from
..utils.utils
import
get_tqdm_kwargs
from
..utils.naming
import
GLOBAL_STEP_INCR_OP_NAME
from
..tfutils.common
import
(
from
..tfutils.common
import
(
get_op_tensor_name
,
get_global_step_var
,
get_op_tensor_name
,
get_op_or_tensor_by_name
)
get_global_step_value
,
get_op_or_tensor_by_name
)
from
.base
import
Callback
from
.base
import
Callback
__all__
=
[
'StepTensorPrinter'
,
'MaintainStepCounter'
,
__all__
=
[
'StepTensorPrinter'
,
'ProgressBar'
]
'ProgressBar'
]
class
StepTensorPrinter
(
Callback
):
class
StepTensorPrinter
(
Callback
):
...
@@ -47,39 +44,6 @@ class StepTensorPrinter(Callback):
...
@@ -47,39 +44,6 @@ class StepTensorPrinter(Callback):
logger
.
info
(
"{}: {}"
.
format
(
n
,
v
))
logger
.
info
(
"{}: {}"
.
format
(
n
,
v
))
class
MaintainStepCounter
(
Callback
):
"""
It maintains the global step in the graph, making sure it's increased by one in every `run_step` call.
This callback is always enabled by the trainer, and you wouldn't need to use it.
"""
def
_setup_graph
(
self
):
# ensure it exists
gs_var
=
get_global_step_var
()
with
tf
.
name_scope
(
None
):
with
tf
.
device
(
gs_var
.
device
):
self
.
gs_incr_op
=
tf
.
assign_add
(
gs_var
,
1
,
name
=
GLOBAL_STEP_INCR_OP_NAME
)
.
op
# tf.mod(
# self.gs_incr_var, self.trainer.config.steps_per_epoch,
# name=LOCAL_STEP_OP_NAME)
self
.
_fetches
=
tf
.
train
.
SessionRunArgs
(
self
.
gs_incr_op
)
def
_before_train
(
self
):
gs_val
=
get_global_step_value
()
if
gs_val
!=
0
:
logger
.
info
(
"Start training with global_step={}"
.
format
(
gs_val
))
self
.
_last_updated
=
self
.
local_step
def
_before_run
(
self
,
_
):
# increase global_step, when trainer.local_step changed
if
self
.
local_step
!=
self
.
_last_updated
:
self
.
_last_updated
=
self
.
local_step
return
self
.
_fetches
else
:
return
None
class
ProgressBar
(
Callback
):
class
ProgressBar
(
Callback
):
""" A progress bar based on tqdm. Enabled by default. """
""" A progress bar based on tqdm. Enabled by default. """
...
...
tensorpack/train/base.py
View file @
a6aa8bce
...
@@ -8,17 +8,19 @@ from six.moves import range
...
@@ -8,17 +8,19 @@ from six.moves import range
import
tensorflow
as
tf
import
tensorflow
as
tf
from
..graph_builder.predictor_factory
import
PredictorFactory
from
.config
import
TrainConfig
from
.config
import
TrainConfig
from
..utils
import
logger
from
..utils
import
logger
from
..utils.naming
import
GLOBAL_STEP_INCR_OP_NAME
from
..utils.develop
import
deprecated
from
..utils.develop
import
deprecated
from
..callbacks
import
Callback
,
Callbacks
,
MaintainStepCounter
from
..callbacks
import
Callback
,
Callbacks
from
..callbacks.monitor
import
Monitors
,
TrainingMonitor
from
..callbacks.monitor
import
Monitors
,
TrainingMonitor
from
..tfutils
import
get_global_step_value
from
..tfutils
import
get_global_step_value
,
get_global_step_var
from
..tfutils.model_utils
import
describe_trainable_vars
from
..tfutils.model_utils
import
describe_trainable_vars
from
..tfutils.sesscreate
import
ReuseSessionCreator
from
..tfutils.sesscreate
import
ReuseSessionCreator
from
..tfutils.sessinit
import
JustCurrentSession
from
..tfutils.sessinit
import
JustCurrentSession
from
..graph_builder.predictor_factory
import
PredictorFactory
__all__
=
[
'Trainer'
,
'StopTraining'
]
__all__
=
[
'Trainer'
,
'StopTraining'
]
...
@@ -29,6 +31,34 @@ class StopTraining(BaseException):
...
@@ -29,6 +31,34 @@ class StopTraining(BaseException):
pass
pass
class
MaintainStepCounter
(
Callback
):
"""
It maintains the global step in the graph, making sure it's increased by one.
This callback is always enabled by the trainer, and you wouldn't need to use it.
"""
def
_setup_graph
(
self
):
# ensure it exists
gs_var
=
get_global_step_var
()
with
tf
.
name_scope
(
None
):
with
tf
.
device
(
gs_var
.
device
):
self
.
gs_incr_op
=
tf
.
assign_add
(
gs_var
,
1
,
name
=
GLOBAL_STEP_INCR_OP_NAME
)
.
op
self
.
_fetches
=
tf
.
train
.
SessionRunArgs
(
self
.
gs_incr_op
)
def
_before_train
(
self
):
if
self
.
global_step
!=
0
:
logger
.
info
(
"Start training with global_step={}"
.
format
(
self
.
global_step
))
def
_before_run
(
self
,
_
):
# always increase global_step when hooked_sess.run is called
return
self
.
_fetches
def
_after_run
(
self
,
_
,
__
):
# Keep python-side global_step in agreement with TF-side
self
.
trainer
.
_global_step
+=
1
class
Trainer
(
object
):
class
Trainer
(
object
):
""" Base class for a trainer.
""" Base class for a trainer.
...
@@ -38,7 +68,7 @@ class Trainer(object):
...
@@ -38,7 +68,7 @@ class Trainer(object):
sess (tf.Session): the current session in use.
sess (tf.Session): the current session in use.
hooked_sess (tf.MonitoredSession): the session with hooks.
hooked_sess (tf.MonitoredSession): the session with hooks.
monitors (Monitors): the monitors. Callbacks can use it for logging.
monitors (Monitors): the monitors. Callbacks can use it for logging.
local_step (int): the number of steps that have finished in the current epoch.
local_step (int): the number of
(tensorpack)
steps that have finished in the current epoch.
"""
"""
# step attr only available after before_train?
# step attr only available after before_train?
...
@@ -58,6 +88,7 @@ class Trainer(object):
...
@@ -58,6 +88,7 @@ class Trainer(object):
self
.
_callbacks
=
[]
self
.
_callbacks
=
[]
self
.
monitors
=
[]
self
.
monitors
=
[]
self
.
_epoch_num
=
None
self
.
_epoch_num
=
None
self
.
_global_step
=
0
self
.
_setup
()
# subclass will setup the graph and InputSource
self
.
_setup
()
# subclass will setup the graph and InputSource
...
@@ -102,24 +133,18 @@ class Trainer(object):
...
@@ -102,24 +133,18 @@ class Trainer(object):
def
run_step
(
self
):
def
run_step
(
self
):
"""
"""
Defines what to do in one iteration
, by
default is:
Defines what to do in one iteration
. The
default is:
``self.hooked_sess.run(self.train_op)``.
``self.hooked_sess.run(self.train_op)``.
The behavior can be changed by either defining what is ``train_op``,
The behavior can be changed by either defining what is ``train_op``,
or overriding this method.
or overriding this method.
"""
"""
assert
hasattr
(
self
,
'train_op'
),
\
if
not
hasattr
(
self
,
'train_op'
):
"Please either set `Trainer.train_op` or provide an implementation "
\
raise
NotImplementedError
(
"of Trainer.run_step()!"
"Please either set `Trainer.train_op` or provide an implementation "
"of Trainer.run_step()!"
)
self
.
hooked_sess
.
run
(
self
.
train_op
)
self
.
hooked_sess
.
run
(
self
.
train_op
)
def
_setup_input_source
(
self
,
input_source
):
"""
Setup InputSource on this trainer.
"""
cbs
=
input_source
.
setup
(
self
.
model
.
get_inputs_desc
())
self
.
config
.
callbacks
.
extend
(
cbs
)
def
setup
(
self
):
def
setup
(
self
):
"""
"""
Setup the trainer and be ready for the main loop.
Setup the trainer and be ready for the main loop.
...
@@ -175,25 +200,25 @@ class Trainer(object):
...
@@ -175,25 +200,25 @@ class Trainer(object):
@
property
@
property
def
global_step
(
self
):
def
global_step
(
self
):
"""
"""
The number of steps that have finished or is currently running.
The tensorflow global_step, i.e. how many times `hooked_sess.run` has been called.
Note:
1. global_step is incremented **after** each `hooked_sess.run` returns from TF runtime.
2. If you make zero or more than one calls to `hooked_sess.run` in one
`run_step`, local_step and global_step may increment at different speed.
"""
"""
try
:
return
self
.
_global_step
return
self
.
_starting_step
+
\
self
.
config
.
steps_per_epoch
*
(
self
.
epoch_num
-
self
.
config
.
starting_epoch
)
+
\
self
.
local_step
+
1
# +1: the ongoing step
except
AttributeError
:
return
get_global_step_value
()
def
main_loop
(
self
):
def
main_loop
(
self
):
"""
"""
Run the main training loop.
Run the main training loop.
"""
"""
with
self
.
sess
.
as_default
():
with
self
.
sess
.
as_default
():
self
.
_
starting
_step
=
get_global_step_value
()
self
.
_
global
_step
=
get_global_step_value
()
try
:
try
:
self
.
_callbacks
.
before_train
()
self
.
_callbacks
.
before_train
()
# refresh global step (might have changed by callbacks) TODO ugly
# refresh global step (might have changed by callbacks) TODO ugly
self
.
_
starting
_step
=
get_global_step_value
()
self
.
_
global
_step
=
get_global_step_value
()
for
self
.
_epoch_num
in
range
(
for
self
.
_epoch_num
in
range
(
self
.
config
.
starting_epoch
,
self
.
config
.
max_epoch
+
1
):
self
.
config
.
starting_epoch
,
self
.
config
.
max_epoch
+
1
):
logger
.
info
(
"Start Epoch {} ..."
.
format
(
self
.
_epoch_num
))
logger
.
info
(
"Start Epoch {} ..."
.
format
(
self
.
_epoch_num
))
...
@@ -221,7 +246,7 @@ class Trainer(object):
...
@@ -221,7 +246,7 @@ class Trainer(object):
self
.
_callbacks
.
after_train
()
self
.
_callbacks
.
after_train
()
self
.
hooked_sess
.
close
()
self
.
hooked_sess
.
close
()
# Predictor related methods:
# Predictor related methods
. They actually should not be part of a trainer
:
@
property
@
property
def
vs_name_for_predictor
(
self
):
def
vs_name_for_predictor
(
self
):
"""
"""
...
...
tensorpack/train/feedfree.py
View file @
a6aa8bce
...
@@ -28,7 +28,8 @@ class FeedfreeTrainerBase(Trainer):
...
@@ -28,7 +28,8 @@ class FeedfreeTrainerBase(Trainer):
def
_setup
(
self
):
def
_setup
(
self
):
assert
isinstance
(
self
.
_input_source
,
FeedfreeInput
),
type
(
self
.
_input_source
)
assert
isinstance
(
self
.
_input_source
,
FeedfreeInput
),
type
(
self
.
_input_source
)
self
.
_setup_input_source
(
self
.
_input_source
)
cbs
=
self
.
_setup_input_source
.
setup
(
self
.
model
.
get_inputs_desc
())
self
.
config
.
callbacks
.
extend
(
cbs
)
# deprecated
# deprecated
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment