Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
596d8008
Commit
596d8008
authored
Jan 24, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
use a callback to maintain global_step
parent
bbb47815
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
54 additions
and
25 deletions
+54
-25
examples/GAN/GAN.py
examples/GAN/GAN.py
+2
-3
tensorpack/callbacks/steps.py
tensorpack/callbacks/steps.py
+24
-2
tensorpack/tfutils/common.py
tensorpack/tfutils/common.py
+9
-3
tensorpack/train/base.py
tensorpack/train/base.py
+1
-2
tensorpack/train/config.py
tensorpack/train/config.py
+4
-2
tensorpack/train/feedfree.py
tensorpack/train/feedfree.py
+1
-3
tensorpack/train/multigpu.py
tensorpack/train/multigpu.py
+3
-6
tensorpack/train/trainer.py
tensorpack/train/trainer.py
+4
-4
tensorpack/utils/naming.py
tensorpack/utils/naming.py
+6
-0
No files found.
examples/GAN/GAN.py
View file @
596d8008
...
...
@@ -7,7 +7,7 @@ import tensorflow as tf
import
numpy
as
np
import
time
from
tensorpack
import
(
FeedfreeTrainerBase
,
TowerContext
,
get_global_step_var
,
QueueInput
,
ModelDesc
)
QueueInput
,
ModelDesc
)
from
tensorpack.tfutils.summary
import
add_moving_summary
from
tensorpack.tfutils.gradproc
import
apply_grad_processors
,
CheckGradient
from
tensorpack.dataflow
import
DataFlow
...
...
@@ -92,8 +92,7 @@ class GANTrainer(FeedfreeTrainerBase):
self
.
model
.
d_loss
,
var_list
=
self
.
model
.
d_vars
)
grads
=
apply_grad_processors
(
grads
,
self
.
model
.
get_gradient_processor_d
())
self
.
d_min
=
self
.
config
.
optimizer
.
apply_gradients
(
grads
,
get_global_step_var
(),
name
=
'd_op'
)
self
.
d_min
=
self
.
config
.
optimizer
.
apply_gradients
(
grads
,
name
=
'd_op'
)
self
.
train_op
=
self
.
d_min
...
...
tensorpack/callbacks/steps.py
View file @
596d8008
...
...
@@ -11,11 +11,15 @@ from six.moves import zip
import
tqdm
from
..utils
import
logger
,
get_tqdm_kwargs
from
..utils.naming
import
MOVING_SUMMARY_VARS_KEY
from
..utils.naming
import
(
MOVING_SUMMARY_VARS_KEY
,
GLOBAL_STEP_INCR_VAR_NAME
,
LOCAL_STEP_OP_NAME
)
from
..tfutils.common
import
get_op_tensor_name
,
get_global_step_var
from
.base
import
Callback
__all__
=
[
'StepStatPrinter'
,
'SummaryMovingAverage'
,
'ProgressBar'
]
__all__
=
[
'StepStatPrinter'
,
'MaintainStepCounter'
,
'SummaryMovingAverage'
,
'ProgressBar'
]
class
StepStatPrinter
(
Callback
):
...
...
@@ -41,6 +45,24 @@ class StepStatPrinter(Callback):
logger
.
info
(
"{}: {}"
.
format
(
n
,
v
))
class
MaintainStepCounter
(
Callback
):
"""
It maintains the global step in the graph and also creates the local step tensor.
This callback is always enabled by the trainer, and you wouldn't need to
use it.
"""
def
_setup_graph
(
self
):
# ensure it exists
get_global_step_var
()
self
.
gs_incr_var
=
self
.
trainer
.
sess
.
graph
.
get_tensor_by_name
(
GLOBAL_STEP_INCR_VAR_NAME
)
self
.
local_step
=
tf
.
mod
(
self
.
gs_incr_var
,
self
.
trainer
.
config
.
step_per_epoch
,
name
=
LOCAL_STEP_OP_NAME
)
def
_extra_fetches
(
self
):
return
[
self
.
gs_incr_var
.
op
]
class
SummaryMovingAverage
(
Callback
):
""" Maintain the moving average of the tensors
in every step, and summarize them. Enabled by default.
...
...
tensorpack/tfutils/common.py
View file @
596d8008
...
...
@@ -3,12 +3,14 @@
# File: common.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from
..utils.naming
import
GLOBAL_STEP_VAR_NAME
,
GLOBAL_STEP_OP_NAME
import
tensorflow
as
tf
from
copy
import
copy
import
six
from
contextlib
import
contextmanager
from
..utils.naming
import
GLOBAL_STEP_VAR_NAME
,
GLOBAL_STEP_OP_NAME
,
GLOBAL_STEP_INCR_OP_NAME
from
..utils.argtools
import
memoized
__all__
=
[
'get_default_sess_config'
,
'get_global_step_value'
,
'get_global_step_var'
,
...
...
@@ -43,6 +45,7 @@ def get_default_sess_config(mem_fraction=0.99):
return
conf
@
memoized
def
get_global_step_var
():
"""
Returns:
...
...
@@ -54,11 +57,14 @@ def get_global_step_var():
except
KeyError
:
scope
=
tf
.
get_variable_scope
()
assert
scope
.
name
==
''
,
\
"Creating global_step_var under a variable scope would cause problems!"
with
tf
.
variable_scope
(
scope
,
reuse
=
False
):
"The global_step variable should be created under the root variable scope!"
with
tf
.
variable_scope
(
scope
,
reuse
=
False
),
\
tf
.
name_scope
(
None
):
var
=
tf
.
get_variable
(
GLOBAL_STEP_OP_NAME
,
initializer
=
0
,
trainable
=
False
,
dtype
=
tf
.
int32
)
# also create the incr operation
tf
.
assign_add
(
var
,
1
,
name
=
GLOBAL_STEP_INCR_OP_NAME
)
return
var
...
...
tensorpack/train/base.py
View file @
596d8008
...
...
@@ -13,7 +13,7 @@ from .config import TrainConfig
from
..utils
import
logger
from
..utils.timer
import
timed_operation
from
..callbacks
import
StatHolder
from
..tfutils
import
get_global_step_va
r
,
get_global_step_va
lue
from
..tfutils
import
get_global_step_value
from
..tfutils.modelutils
import
describe_model
from
..tfutils.summary
import
create_scalar_summary
...
...
@@ -144,7 +144,6 @@ class Trainer(object):
"""
self
.
_setup
()
describe_model
()
get_global_step_var
()
# ensure such var exists
# some final operations that might modify the graph
logger
.
info
(
"Setup callbacks ..."
)
self
.
config
.
callbacks
.
setup_graph
(
weakref
.
proxy
(
self
))
...
...
tensorpack/train/config.py
View file @
596d8008
...
...
@@ -6,7 +6,7 @@ import tensorflow as tf
from
..callbacks
import
(
Callbacks
,
SummaryMovingAverage
,
StatPrinter
,
ProgressBar
)
StatPrinter
,
ProgressBar
,
MaintainStepCounter
)
from
..dataflow.base
import
DataFlow
from
..models
import
ModelDesc
from
..utils
import
logger
...
...
@@ -42,6 +42,8 @@ class TrainConfig(object):
is only used to provide the defaults. The defaults are
``[SummaryMovingAverage(), ProgressBar(), StatPrinter()]``. The list of
callbacks that will be used in the end are ``callbacks + extra_callbacks``.
Note that ``StatPrinter`` should be the last one to be able to print
stats generated by other callbacks.
session_config (tf.ConfigProto): the config used to instantiate the session.
session_init (SessionInit): how to initialize variables of a session. Defaults to a new session.
starting_epoch (int): The index of the first epoch.
...
...
@@ -83,7 +85,7 @@ class TrainConfig(object):
assert_type
(
callbacks
,
list
)
if
extra_callbacks
is
None
:
extra_callbacks
=
[
SummaryMovingAverage
(),
ProgressBar
(),
StatPrinter
()]
self
.
callbacks
=
callbacks
+
extra_callbacks
self
.
callbacks
=
[
MaintainStepCounter
()]
+
callbacks
+
extra_callbacks
assert_type
(
self
.
callbacks
,
list
)
self
.
callbacks
=
Callbacks
(
self
.
callbacks
)
...
...
tensorpack/train/feedfree.py
View file @
596d8008
...
...
@@ -6,7 +6,6 @@
import
tensorflow
as
tf
from
..utils
import
logger
from
..tfutils
import
get_global_step_var
from
..tfutils.tower
import
TowerContext
from
..tfutils.gradproc
import
apply_grad_processors
from
.input_data
import
QueueInput
,
FeedfreeInput
...
...
@@ -101,8 +100,7 @@ class SimpleFeedfreeTrainer(
cost
,
grads
=
self
.
_get_cost_and_grad
()
grads
=
apply_grad_processors
(
grads
,
self
.
model
.
get_gradient_processor
())
self
.
train_op
=
self
.
config
.
optimizer
.
apply_gradients
(
grads
,
get_global_step_var
(),
name
=
'min_op'
)
self
.
train_op
=
self
.
config
.
optimizer
.
apply_gradients
(
grads
,
name
=
'min_op'
)
# skip training
# self.train_op = tf.group(*self.dequed_inputs)
...
...
tensorpack/train/multigpu.py
View file @
596d8008
...
...
@@ -11,8 +11,7 @@ from six.moves import zip, range
from
..utils
import
logger
from
..utils.naming
import
SUMMARY_BACKUP_KEYS
from
..utils.concurrency
import
LoopThread
from
..tfutils
import
(
backup_collection
,
restore_collection
,
get_global_step_var
,
TowerContext
)
from
..tfutils
import
(
backup_collection
,
restore_collection
,
TowerContext
)
from
..tfutils.gradproc
import
apply_grad_processors
,
ScaleGradient
from
.base
import
Trainer
...
...
@@ -112,8 +111,7 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
grads
=
SyncMultiGPUTrainer
.
_average_grads
(
grad_list
)
grads
=
apply_grad_processors
(
grads
,
self
.
model
.
get_gradient_processor
())
self
.
train_op
=
self
.
config
.
optimizer
.
apply_gradients
(
grads
,
get_global_step_var
(),
name
=
'min_op'
)
self
.
train_op
=
self
.
config
.
optimizer
.
apply_gradients
(
grads
,
name
=
'min_op'
)
class
AsyncMultiGPUTrainer
(
MultiGPUTrainer
,
...
...
@@ -163,8 +161,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
grad_list
=
[
apply_grad_processors
(
g
,
gradprocs
)
for
g
in
grad_list
]
# use grad from the first tower for iteration in main thread
self
.
train_op
=
self
.
config
.
optimizer
.
apply_gradients
(
grad_list
[
0
],
get_global_step_var
(),
name
=
'min_op'
)
self
.
train_op
=
self
.
config
.
optimizer
.
apply_gradients
(
grad_list
[
0
],
name
=
'min_op'
)
self
.
_start_async_threads
(
grad_list
)
...
...
tensorpack/train/trainer.py
View file @
596d8008
...
...
@@ -7,8 +7,9 @@ import tensorflow as tf
from
.base
import
Trainer
from
..utils
import
SUMMARY_BACKUP_KEYS
,
PREDICT_TOWER
from
..tfutils
import
(
get_tensors_by_names
,
freeze_collection
,
get_global_step_var
,
TowerContext
)
from
..tfutils
import
(
get_tensors_by_names
,
freeze_collection
,
TowerContext
)
from
..predict
import
OnlinePredictor
,
build_prediction_graph
from
..tfutils.gradproc
import
apply_grad_processors
from
.input_data
import
FeedInput
...
...
@@ -88,8 +89,7 @@ class SimpleTrainer(Trainer):
grads
=
apply_grad_processors
(
grads
,
self
.
model
.
get_gradient_processor
())
self
.
train_op
=
self
.
config
.
optimizer
.
apply_gradients
(
grads
,
get_global_step_var
(),
name
=
'min_op'
)
self
.
train_op
=
self
.
config
.
optimizer
.
apply_gradients
(
grads
,
name
=
'min_op'
)
def
_trigger_epoch
(
self
):
if
self
.
summary_op
is
not
None
:
...
...
tensorpack/utils/naming.py
View file @
596d8008
...
...
@@ -7,6 +7,12 @@ import tensorflow as tf
GLOBAL_STEP_OP_NAME
=
'global_step'
GLOBAL_STEP_VAR_NAME
=
'global_step:0'
GLOBAL_STEP_INCR_OP_NAME
=
'global_step_incr'
GLOBAL_STEP_INCR_VAR_NAME
=
'global_step_incr:0'
LOCAL_STEP_OP_NAME
=
'local_step'
LOCAL_STEP_VAR_NAME
=
'local_step:0'
# prefix of predict tower
PREDICT_TOWER
=
'towerp'
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment