Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
a51e2de4
Commit
a51e2de4
authored
Jan 25, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
a missing part of the last commit.
parent
e3045eda
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
39 additions
and
61 deletions
+39
-61
tensorpack/callbacks/steps.py
tensorpack/callbacks/steps.py
+15
-46
tensorpack/tfutils/common.py
tensorpack/tfutils/common.py
+18
-13
tensorpack/train/config.py
tensorpack/train/config.py
+6
-2
No files found.
tensorpack/callbacks/steps.py
View file @
a51e2de4
...
...
@@ -6,23 +6,20 @@
""" Some common step callbacks. """
import
tensorflow
as
tf
import
re
from
six.moves
import
zip
import
tqdm
from
..utils
import
logger
,
get_tqdm_kwargs
from
..utils.naming
import
(
MOVING_SUMMARY_VARS_KEY
,
GLOBAL_STEP_INCR_VAR_NAME
,
from
..utils.naming
import
(
GLOBAL_STEP_INCR_OP_NAME
,
LOCAL_STEP_OP_NAME
)
from
..tfutils.common
import
get_op_tensor_name
,
get_global_step_var
,
get_global_step_value
from
.base
import
Callback
__all__
=
[
'Step
Stat
Printer'
,
'MaintainStepCounter'
,
'
SummaryMovingAverage'
,
'
ProgressBar'
]
__all__
=
[
'Step
Tensor
Printer'
,
'MaintainStepCounter'
,
'ProgressBar'
]
class
Step
Stat
Printer
(
Callback
):
class
Step
Tensor
Printer
(
Callback
):
""" It prints the value of some tensors in each step.
It's just a demo of how trigger_step works but you should in general use
:func:`symbolic_functions.print_stat` or :func:`tf.Print` instead. """
...
...
@@ -30,10 +27,10 @@ class StepStatPrinter(Callback):
def
__init__
(
self
,
names
):
"""
Args:
names(list): list of string, the names of the tensor to print.
names(list): list of string, the names of the tensor
s
to print.
"""
names
=
[
get_op_tensor_name
(
n
)[
1
]
for
n
in
names
]
logger
.
warn
(
"Using print_stat or tf.Print in the graph is much faster than Step
Stat
Printer!"
)
logger
.
warn
(
"Using print_stat or tf.Print in the graph is much faster than Step
Tensor
Printer!"
)
self
.
_names
=
names
def
_extra_fetches
(
self
):
...
...
@@ -53,8 +50,11 @@ class MaintainStepCounter(Callback):
"""
def
_setup_graph
(
self
):
# ensure it exists
get_global_step_var
()
self
.
gs_incr_var
=
self
.
trainer
.
sess
.
graph
.
get_tensor_by_name
(
GLOBAL_STEP_INCR_VAR_NAME
)
gs_var
=
get_global_step_var
()
with
tf
.
name_scope
(
None
):
self
.
gs_incr_var
=
tf
.
assign_add
(
gs_var
,
1
,
name
=
GLOBAL_STEP_INCR_OP_NAME
)
self
.
local_step
=
tf
.
mod
(
self
.
gs_incr_var
,
self
.
trainer
.
config
.
step_per_epoch
,
name
=
LOCAL_STEP_OP_NAME
)
...
...
@@ -68,37 +68,6 @@ class MaintainStepCounter(Callback):
return
[
self
.
gs_incr_var
.
op
]
class
SummaryMovingAverage
(
Callback
):
""" Maintain the moving average of the tensors
in every step, and summarize them. Enabled by default.
"""
def
__init__
(
self
,
collection
=
MOVING_SUMMARY_VARS_KEY
,
decay
=
0.95
):
"""
Args:
collection(str): the collection of tensors to summarize. The
default would work with :func:`add_moving_summary`.
decay(float): the decay of the moving average.
"""
self
.
_collection
=
collection
self
.
_decay
=
decay
def
_setup_graph
(
self
):
tensors
=
set
(
tf
.
get_collection
(
self
.
_collection
))
# TODO will produce tower0/xxx. not elegant
with
tf
.
name_scope
(
None
):
averager
=
tf
.
train
.
ExponentialMovingAverage
(
self
.
_decay
,
num_updates
=
get_global_step_var
(),
name
=
'EMA'
)
avg_maintain_op
=
averager
.
apply
(
tensors
)
for
idx
,
c
in
enumerate
(
tensors
):
name
=
re
.
sub
(
'tower[p0-9]+/'
,
''
,
c
.
op
.
name
)
tf
.
summary
.
scalar
(
name
+
'-summary'
,
averager
.
average
(
c
))
self
.
ema_op
=
avg_maintain_op
def
_extra_fetches
(
self
):
return
[
self
.
ema_op
]
class
ProgressBar
(
Callback
):
""" A progress bar based on tqdm. Enabled by default. """
def
_before_train
(
self
):
...
...
tensorpack/tfutils/common.py
View file @
a51e2de4
...
...
@@ -5,17 +5,23 @@
import
tensorflow
as
tf
from
..utils.naming
import
GLOBAL_STEP_VAR_NAME
,
GLOBAL_STEP_OP_NAME
,
GLOBAL_STEP_INCR_OP_NAME
from
..utils.naming
import
(
GLOBAL_STEP_VAR_NAME
,
GLOBAL_STEP_OP_NAME
,
LOCAL_STEP_VAR_NAME
)
from
..utils
import
logger
from
..utils.argtools
import
memoized
__all__
=
[
'get_default_sess_config'
,
'get_global_step_value'
,
'get_global_step_var'
,
'get_local_step_var'
,
'get_op_tensor_name'
,
'get_tensors_by_names'
,
'get_op_or_tensor_by_name'
,
'get_tf_version'
,
'get_name_scope_name'
'get_name_scope_name'
,
]
...
...
@@ -56,8 +62,6 @@ def get_global_step_var():
var
=
tf
.
get_variable
(
GLOBAL_STEP_OP_NAME
,
initializer
=
0
,
trainable
=
False
,
dtype
=
tf
.
int32
)
# also create the incr operation
tf
.
assign_add
(
var
,
1
,
name
=
GLOBAL_STEP_INCR_OP_NAME
)
return
var
...
...
@@ -70,6 +74,15 @@ def get_global_step_value():
get_global_step_var
())
@
memoized
def
get_local_step_var
():
try
:
return
tf
.
get_default_graph
()
.
get_tensor_by_name
(
LOCAL_STEP_VAR_NAME
)
except
KeyError
:
logger
.
warn
(
"get_local_step_var() is only available to use in callbacks!"
)
raise
def
get_op_tensor_name
(
name
):
"""
Will automatically determine if ``name`` is a tensor name (ends with ':x')
...
...
@@ -110,14 +123,6 @@ def get_op_or_tensor_by_name(name):
return
G
.
get_operation_by_name
(
name
)
def
get_tf_version
():
"""
Returns:
int:
"""
return
int
(
tf
.
__version__
.
split
(
'.'
)[
1
])
def
get_name_scope_name
():
"""
Returns:
...
...
tensorpack/train/config.py
View file @
a51e2de4
...
...
@@ -6,7 +6,8 @@ import tensorflow as tf
from
..callbacks
import
(
Callbacks
,
SummaryMovingAverage
,
StatPrinter
,
ProgressBar
,
MaintainStepCounter
)
StatPrinter
,
ProgressBar
,
MaintainStepCounter
)
from
..dataflow.base
import
DataFlow
from
..models
import
ModelDesc
from
..utils
import
logger
...
...
@@ -84,7 +85,10 @@ class TrainConfig(object):
callbacks
=
callbacks
.
cbs
[:
-
1
]
# the last one is StatPrinter()
assert_type
(
callbacks
,
list
)
if
extra_callbacks
is
None
:
extra_callbacks
=
[
SummaryMovingAverage
(),
ProgressBar
(),
StatPrinter
()]
extra_callbacks
=
[
SummaryMovingAverage
(),
ProgressBar
(),
StatPrinter
()]
self
.
callbacks
=
[
MaintainStepCounter
()]
+
callbacks
+
extra_callbacks
assert_type
(
self
.
callbacks
,
list
)
self
.
callbacks
=
Callbacks
(
self
.
callbacks
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment