06ea1c0a
Commit
06ea1c0a
authored
Jan 07, 2017
by
Yuxin Wu
api docs for tfutils/
parent bbf41d9e
Showing 13 changed files with 291 additions and 128 deletions.

docs/conf.py                                       +1   -2
examples/Atari2600/DQN.py                          +2   -2
examples/SpatialTransformer/mnist-addition.py      +1   -1
tensorpack/tfutils/__init__.py                     +1   -2
tensorpack/tfutils/argscope.py                     +24  -5
tensorpack/tfutils/common.py                       +49  -7
tensorpack/tfutils/gradproc.py                     +39  -17
tensorpack/tfutils/sessinit.py                     +30  -37
tensorpack/tfutils/summary.py                      +12  -5
tensorpack/tfutils/symbolic_functions.py           +49  -28
tensorpack/tfutils/tower.py                        +13  -5
tensorpack/tfutils/varmanip.py                     +59  -15
tensorpack/tfutils/varreplace.py                   +11  -2

docs/conf.py

@@ -67,8 +67,7 @@ extensions = [
     'sphinx.ext.autodoc',
     'sphinx.ext.napoleon',
     #'sphinx.ext.coverage',
-    #'sphinx.ext.mathjax',
-    'sphinx.ext.mathbase',
+    'sphinx.ext.mathjax',
     'sphinx.ext.intersphinx',
     'sphinx.ext.viewcode',
 ]

examples/Atari2600/DQN.py

@@ -132,8 +132,8 @@ class Model(ModelDesc):
         target = reward + (1.0 - tf.cast(isOver, tf.float32)) * GAMMA * tf.stop_gradient(best_v)

-        self.cost = tf.truediv(symbf.huber_loss(target - pred_action_value),
-                               tf.cast(BATCH_SIZE, tf.float32), name='cost')
+        self.cost = tf.reduce_mean(symbf.huber_loss(target - pred_action_value),
+                                   name='cost')

         summary.add_param_summary(('conv.*/W', ['histogram', 'rms']),
                                   ('fc.*/W', ['histogram', 'rms']))   # monitor all W
         add_moving_summary(self.cost)
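
The new cost is equivalent to the old one because huber_loss becomes element-wise in this commit (see symbolic_functions.py below): averaging the per-element loss gives the same value the old code computed as the summed loss divided by BATCH_SIZE. A quick sanity check in plain numpy; the huber() helper here is an illustrative stand-in, not the tensorpack function:

    import numpy as np

    def huber(x, delta=1.0):
        # Mirrors the piecewise formula documented in symbolic_functions.py below.
        absx = np.abs(x)
        return np.where(absx < delta, 0.5 * x ** 2, delta * absx - 0.5 * delta ** 2)

    BATCH_SIZE = 64                                 # hypothetical batch size
    td_error = np.random.randn(BATCH_SIZE)          # stands in for target - pred_action_value
    old_cost = huber(td_error).sum() / BATCH_SIZE   # old: truediv(summed huber_loss, BATCH_SIZE)
    new_cost = huber(td_error).mean()               # new: reduce_mean(element-wise huber_loss)
    assert np.isclose(old_cost, new_cost)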

examples/SpatialTransformer/mnist-addition.py

@@ -88,7 +88,7 @@ class Model(ModelDesc):
     def get_gradient_processor(self):
         return [MapGradient(lambda grad: tf.clip_by_global_norm([grad], 5)[0][0]),
-                ScaleGradient([('STN.*', 0.1)]), SummaryGradient()]
+                ScaleGradient(('STN.*', 0.1)), SummaryGradient()]


 def get_data(isTrain):

tensorpack/tfutils/__init__.py

@@ -34,5 +34,4 @@ for _, module_name, _ in walk_packages(
         continue
     if module_name in _TO_IMPORT:
         _global_import(module_name)
-    if module_name != 'common':
-        __all__.append(module_name)
+__all__.extend(['sessinit', 'gradproc'])

tensorpack/tfutils/argscope.py

@@ -14,13 +14,30 @@ _ArgScopeStack = []

 @contextmanager
-def argscope(layers, **param):
+def argscope(layers, **kwargs):
+    """
+    Args:
+        layers (list or layer): layer or list of layers to apply the arguments.
+
+    Returns:
+        a context where all appearance of these layer will by default have the
+        arguments specified by kwargs.
+
+    Example:
+        .. code-block:: python
+
+            with argscope(Conv2D, kernel_shape=3, nl=tf.nn.relu, out_channel=32):
+                x = Conv2D('conv0', x)
+                x = Conv2D('conv1', x)
+                x = Conv2D('conv2', x, out_channel=64)  # override argscope
+    """
     if not isinstance(layers, list):
         layers = [layers]

     def _check_args_exist(l):
         args = inspect.getargspec(l).args
-        for k, v in six.iteritems(param):
+        for k, v in six.iteritems(kwargs):
             assert k in args, "No argument {} in {}".format(k, l.__name__)

     for l in layers:

@@ -29,7 +46,7 @@ def argscope(layers, **param):
     new_scope = copy.copy(get_arg_scope())
     for l in layers:
-        new_scope[l.__name__].update(param)
+        new_scope[l.__name__].update(kwargs)
     _ArgScopeStack.append(new_scope)
     yield
     del _ArgScopeStack[-1]

@@ -37,8 +54,10 @@ def argscope(layers, **param):
 def get_arg_scope():
     """
-    :returns: the current argscope.
-    An argscope is a dict of dict: dict[layername] = {arg: val}
+    Returns:
+        dict: the current argscope.
+
+        An argscope is a dict of dict: ``dict[layername] = {arg: val}``
     """
     if len(_ArgScopeStack) > 0:
         return _ArgScopeStack[-1]
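
For readers unfamiliar with the pattern, the push/yield/pop logic above is easy to summarize. The following is a simplified, self-contained sketch of the same mechanism; toy_argscope and toy_conv are made-up names for illustration, not tensorpack APIs:

    from contextlib import contextmanager

    _SCOPE_STACK = [{}]          # stack of {layer_name: {arg: val}}, like _ArgScopeStack

    @contextmanager
    def toy_argscope(layer_names, **kwargs):
        new_scope = {name: dict(args) for name, args in _SCOPE_STACK[-1].items()}
        for name in layer_names:
            new_scope.setdefault(name, {}).update(kwargs)
        _SCOPE_STACK.append(new_scope)
        try:
            yield
        finally:
            _SCOPE_STACK.pop()   # the real argscope does `del _ArgScopeStack[-1]`

    def toy_conv(name, **kwargs):
        args = dict(_SCOPE_STACK[-1].get('Conv2D', {}))
        args.update(kwargs)      # explicit keyword arguments override the scope
        return name, args

    with toy_argscope(['Conv2D'], kernel_shape=3, out_channel=32):
        print(toy_conv('conv0'))                  # picks up the scoped defaults
        print(toy_conv('conv2', out_channel=64))  # overrides out_channel, as in the docstring example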

tensorpack/tfutils/common.py

@@ -28,8 +28,10 @@ def get_default_sess_config(mem_fraction=0.99):
     Return a better session config to use as default.
     Tensorflow default session config consume too much resources.

-    :param mem_fraction: fraction of memory to use. default to 0.99
-    :returns: a `tf.ConfigProto` object.
+    Args:
+        mem_fraction(float): fraction of memory to use.
+    Returns:
+        tf.ConfigProto: the config to use.
     """
     conf = tf.ConfigProto()
     conf.gpu_options.per_process_gpu_memory_fraction = mem_fraction

@@ -41,7 +43,11 @@ def get_default_sess_config(mem_fraction=0.99):
 def get_global_step_var():
-    """ :returns: the global_step variable in the current graph. create if not existed"""
+    """
+    Returns:
+        tf.Tensor: the global_step variable in the current graph. create if
+        doesn't exist.
+    """
     try:
         return tf.get_default_graph().get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
     except KeyError:

@@ -56,7 +62,9 @@ def get_global_step_var():
 def get_global_step():
-    """ :returns: global_step value in current graph and session"""
+    """
+    Returns:
+        float: global_step value in current graph and session"""
     return tf.train.global_step(
         tf.get_default_session(),
         get_global_step_var())

@@ -66,8 +74,10 @@ def get_op_tensor_name(name):
     """
     Tensor name is assumed to be ``op_name + ':0'``

-    :param name: an op or a tensor name
-    :returns: (op_name, tensor_name)
+    Args:
+        name(str): name of an op or a tensor
+    Returns:
+        tuple: (op_name, tensor_name)
     """
     if name.endswith(':0'):
         return name[:-2], name

@@ -80,7 +90,10 @@ get_op_var_name = get_op_tensor_name
 def get_tensors_by_names(names):
     """
-    Get a list of tensors in the default graph by a list of names
+    Get a list of tensors in the default graph by a list of names.
+
+    Args:
+        names (list):
     """
     ret = []
     G = tf.get_default_graph()

@@ -94,6 +107,12 @@ get_vars_by_names = get_tensors_by_names
 def backup_collection(keys):
+    """
+    Args:
+        keys (list): list of collection keys to backup
+    Returns:
+        dict: the backup
+    """
     ret = {}
     for k in keys:
         ret[k] = copy(tf.get_collection(k))

@@ -101,22 +120,45 @@ def backup_collection(keys):
 def restore_collection(backup):
+    """
+    Restore from a collection backup.
+
+    Args:
+        backup (dict):
+    """
     for k, v in six.iteritems(backup):
         del tf.get_collection_ref(k)[:]
         tf.get_collection_ref(k).extend(v)


 def clear_collection(keys):
+    """
+    Clear some collections.
+
+    Args:
+        keys(list): list of collection keys.
+    """
     for k in keys:
         del tf.get_collection_ref(k)[:]


 @contextmanager
 def freeze_collection(keys):
+    """
+    Args:
+        keys(list): list of collection keys to freeze.
+    Returns:
+        a context where the collections are in the end restored to its initial state.
+    """
     backup = backup_collection(keys)
     yield
     restore_collection(backup)


 def get_tf_version():
+    """
+    Returns:
+        int:
+    """
     return int(tf.__version__.split('.')[1])
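
The new freeze_collection docstring is terse, so a usage sketch may help. This assumes a TF 1.x-era environment with tensorpack at this revision on the path; the collection key is only an example:

    import tensorflow as tf
    from tensorpack.tfutils.common import freeze_collection

    # Anything added to the frozen collections inside the context is discarded on
    # exit, because the backup taken on entry is restored afterwards.
    with freeze_collection([tf.GraphKeys.UPDATE_OPS]):
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, tf.no_op())

    assert len(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) == 0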

tensorpack/tfutils/gradproc.py

@@ -12,16 +12,17 @@ from ..utils import logger
 from .symbolic_functions import rms
 from .summary import add_moving_summary

-__all__ = ['GradientProcessor', 'SummaryGradient', 'CheckGradient',
-           'ScaleGradient', 'MapGradient', 'apply_grad_processors',
-           'GlobalNormClip']
+__all__ = ['GradientProcessor', 'GlobalNormClip', 'MapGradient',
+           'SummaryGradient', 'CheckGradient', 'ScaleGradient',
+           'apply_grad_processors']


 def apply_grad_processors(grads, gradprocs):
     """
-    :param grads: list of (grad, var).
-    :param gradprocs: list of `GradientProcessor` instances.
-    :returns: list of (grad, var) went through the processors
+    Args:
+        grads (list): list of (grad, var).
+        gradprocs (list): list of :class:`GradientProcessor` instances.
+    Returns:
+        list: list of (grad, var) went through the processors.
     """
     g = []
     for grad, var in grads:

@@ -36,13 +37,18 @@ def apply_grad_processors(grads, gradprocs):
 @six.add_metaclass(ABCMeta)
 class GradientProcessor(object):
+    """ Base class for all gradient processors.
+
+    Subclass should override the ``_process()`` method.
+    """

     def process(self, grads):
         """
         Process the symbolic gradients.

-        :param grads: list of (grad, var)
-        :returns: symbolic gradients with the same type as input
+        Args:
+            grads (list): list of (grad, var).
+        Returns:
+            list: processed gradients, with the same type as input.
         """
         with tf.name_scope(type(self).__name__):
             return self._process(grads)

@@ -53,10 +59,16 @@ class GradientProcessor(object):
 class GlobalNormClip(GradientProcessor):
+    """ Clip by global norm.
+        The global norm is the sum of norm for **all** gradients.
+
+        See :func:`tf.clip_by_global_norm` for more information.
+    """

     def __init__(self, global_norm):
-        """ Clip by global norm
-            Note that the global norm is the sum of norm for **all** gradients
+        """
+        Args:
+            global_norm(float): the threshold to clip with.
         """
         self._norm = global_norm

@@ -75,9 +87,10 @@ class MapGradient(GradientProcessor):
     def __init__(self, func, regex='.*'):
         """
-        :param func: takes a grad or (grad, var) pair and returns a grad. If return None, the
-            gradient is discarded.
-        :param regex: used to match variables. default to match all variables.
+        Args:
+            func: takes a grad or (grad, var) pair and returns a grad. If return None, the
+                gradient is discarded (hence no update to the variable will happen).
+            regex (str): used to match variables. Defaults to match all variables.
         """
         args = inspect.getargspec(func).args
         arg_num = len(args) - inspect.ismethod(func)

@@ -109,7 +122,7 @@ _summaried_gradient = set()
 class SummaryGradient(MapGradient):
     """
-    Summary history and RMS for each gradient variable
+    Summary histogram and RMS for each gradient variable.
     """

     def __init__(self):

@@ -127,6 +140,7 @@ class SummaryGradient(MapGradient):
 class CheckGradient(MapGradient):
     """
     Check for numeric issue.
+    See :func:`tf.check_numerics` for more information.
     """

     def __init__(self):

@@ -141,13 +155,21 @@ class CheckGradient(MapGradient):
 class ScaleGradient(MapGradient):
     """
-    Scale certain gradient by a multiplier
+    Scale certain gradient by a multiplier.
     """

     def __init__(self, multipliers, log=True):
         """
-        :param multipliers: list of (regex, float)
-        :param log: whether to do logging or not
+        Args:
+            multipliers (tuple or list): tuple of (regex, float), or list of tuples.
+            log (bool): whether to do logging or not
+
+        Example:
+            Use double learning rate for all the bias (as in caffe):
+
+            .. code-block:: python
+
+                ScaleGradient(('.*/b', 2))
         """
         if not isinstance(multipliers, list):
             multipliers = [multipliers]
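
Putting the reorganized processors together, a minimal sketch of the intended usage, under the same assumptions (TF 1.x-era, tensorpack at this revision); the variable name 'STN/W' is made up solely so the regex matches something:

    import tensorflow as tf
    from tensorpack.tfutils.gradproc import (apply_grad_processors, MapGradient,
                                             ScaleGradient, SummaryGradient)

    w = tf.get_variable('STN/W', shape=[3], initializer=tf.constant_initializer(1.0))
    loss = tf.reduce_sum(tf.square(w))
    grads = list(zip(tf.gradients(loss, [w]), [w]))   # list of (grad, var) pairs

    grads = apply_grad_processors(grads, [
        MapGradient(lambda grad: tf.clip_by_global_norm([grad], 5)[0][0]),
        ScaleGradient(('STN.*', 0.1)),   # a single (regex, multiplier) tuple now works
        SummaryGradient(),
    ])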

tensorpack/tfutils/sessinit.py

@@ -11,7 +11,8 @@ import six
 from ..utils import logger, PREDICT_TOWER
 from .common import get_op_var_name
-from .varmanip import SessionUpdate, get_savename_from_varname, is_training_name
+from .varmanip import (SessionUpdate, get_savename_from_varname,
+                       is_training_name, get_checkpoint_path)

 __all__ = ['SessionInit', 'NewSession', 'SaverRestore',
            'ParamRestore', 'ChainInit',

@@ -22,12 +23,14 @@ __all__ = ['SessionInit', 'NewSession', 'SaverRestore',
 @six.add_metaclass(ABCMeta)
 class SessionInit(object):
-    """ Base class for utilities to initialize a session"""
+    """ Base class for utilities to initialize a session. """

     def init(self, sess):
-        """ Initialize a session
-
-        :param sess: a `tf.Session`
+        """
+        Initialize a session
+
+        Args:
+            sess (tf.Session): the session
         """
         self._init(sess)

@@ -37,7 +40,7 @@ class SessionInit(object):
 class JustCurrentSession(SessionInit):
-    """
-    Just use the current default session. This is a no-op placeholder"""
+    """ This is a no-op placeholder"""

     def _init(self, sess):
         pass

@@ -45,8 +48,7 @@ class JustCurrentSession(SessionInit):
 class NewSession(SessionInit):
     """
-    Create a new session. All variables will be initialized by their
-    initializer.
+    Initialize global variables by their initializer.
     """

     def _init(self, sess):

@@ -55,32 +57,17 @@ class NewSession(SessionInit):
 class SaverRestore(SessionInit):
     """
-    Restore an old model saved by `ModelSaver`.
+    Restore an old model saved by :class:`ModelSaver`.
     """

     def __init__(self, model_path, prefix=None):
         """
-        :param model_path: a model name (model-xxxx) or a ``checkpoint`` file.
-        :param prefix: add a `prefix/` for every variable in this checkpoint
+        Args:
+            model_path (str): a model name (model-xxxx) or a ``checkpoint`` file.
+            prefix (str): during restore, add a ``prefix/`` for every variable in this checkpoint
         """
-        if os.path.basename(model_path) == model_path:
-            model_path = os.path.join('.', model_path)  # avoid #4921 and #6142
-        if os.path.basename(model_path) == 'checkpoint':
-            model_path = tf.train.latest_checkpoint(os.path.dirname(model_path))
-            # to be consistent with either v1 or v2
-
-        # fix paths if provided a wrong one
-        new_path = model_path
-        if '00000-of-00001' in model_path:
-            new_path = model_path.split('.data')[0]
-        elif model_path.endswith('.index'):
-            new_path = model_path.split('.index')[0]
-        if new_path != model_path:
-            logger.warn(
-                "[SaverRestore] {} is corrected to {} when restoring the model.".format(model_path, new_path))
-            model_path = new_path
-        assert os.path.isfile(model_path) or os.path.isfile(model_path + '.index'), model_path
-        self.set_path(model_path)
+        model_path = get_checkpoint_path(model_path)
+        self.path = model_path
         self.prefix = prefix

     def _init(self, sess):

@@ -94,9 +81,6 @@ class SaverRestore(SessionInit):
             saver = tf.train.Saver(var_list=dic, name=str(id(dic)), write_version=2)
             saver.restore(sess, self.path)

-    def set_path(self, model_path):
-        self.path = model_path
-
     @staticmethod
     def _produce_restore_dict(vars_multimap):
         """

@@ -161,7 +145,8 @@ class ParamRestore(SessionInit):
     def __init__(self, param_dict):
         """
-        :param param_dict: a dict of {name: value}
+        Args:
+            param_dict (dict): a dict of {name: value}
         """
         # use varname (with :0) for consistency
         self.prms = {get_op_var_name(n)[1]: v for n, v in six.iteritems(param_dict)}

@@ -190,12 +175,17 @@ class ParamRestore(SessionInit):
 class ChainInit(SessionInit):
-    """ Init a session by a list of SessionInit instance."""
+    """ Initialize a session by a list of :class:`SessionInit` instance, executed one by one.
+    This can be useful for, e.g., loading several models from different files
+    to form a composition of models.
+    """

     def __init__(self, sess_inits, new_session=True):
         """
-        :params sess_inits: list of `SessionInit` instances.
-        :params new_session: add a `NewSession()` and the beginning, if not there
+        Args:
+            sess_inits (list): list of :class:`SessionInit` instances.
+            new_session (bool): add a ``NewSession()`` at the beginning, if
+                not there.
         """
         if new_session and not isinstance(sess_inits[0], NewSession):
             sess_inits.insert(0, NewSession())

@@ -208,8 +198,11 @@ class ChainInit(SessionInit):
 def get_model_loader(filename):
     """
-    Get a corresponding model loader by looking at the file name
-    :return: either a ParamRestore or SaverRestore
+    Get a corresponding model loader by looking at the file name.
+
+    Returns:
+        SessInit: either a :class:`ParamRestore` (if name ends with 'npy') or
+        :class:`SaverRestore` (otherwise).
     """
     if filename.endswith('.npy'):
         assert os.path.isfile(filename), filename
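
A sketch of how these initializers are meant to compose, assuming tensorpack at this revision; the checkpoint path is hypothetical and must actually exist for get_checkpoint_path (and hence SaverRestore) to accept it:

    from tensorpack.tfutils.sessinit import (ChainInit, NewSession, SaverRestore,
                                             get_model_loader)

    # Initialize all variables first, then overwrite those found in a checkpoint.
    init = ChainInit([NewSession(),
                      SaverRestore('train_log/run1/model-10000')])   # hypothetical path
    # A trainer or predictor would then call: init.init(sess)

    # get_model_loader() picks the right class from the file name:
    # a '.npy' file gives a ParamRestore, anything else a SaverRestore.
    # loader = get_model_loader('converted_weights.npy')              # hypothetical file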

tensorpack/tfutils/summary.py

@@ -102,8 +102,10 @@ def add_param_summary(*summary_lists):
 def add_moving_summary(v, *args):
     """
-    :param v: tensor or list of tensor to summary
-    :param args: tensors to summary
+    Args:
+        v (tf.Tensor or list): tensor or list of tensors to summary. Must have
+            scalar type.
+        args: tensors to summary (support positional arguments)
     """
     ctx = get_current_tower_context()
     if ctx is not None and not ctx.is_main_training_tower:

@@ -119,9 +121,14 @@ def add_moving_summary(v, *args):
 @memoized
 def summary_moving_average(tensors=None):
     """
-    Create a MovingAverage op and add summary for tensors
-    :param tensors: list of tf.Tensor to summary. default to the collection MOVING_SUMMARY_VARS_KEY
-    :returns: a op to maintain these average.
+    Create a MovingAverage Op and add summary Op for all the moving averages.
+    This is called by the trainer.
+
+    Args:
+        tensors(list): list of tf.Tensor to summary. Defaults to the
+            collection ``MOVING_SUMMARY_VARS_KEY``.
+    Returns:
+        tf.Operation: an op to maintain these average.
     """
     if tensors is None:
         tensors = set(tf.get_collection(MOVING_SUMMARY_VARS_KEY))
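
A short usage sketch for add_moving_summary under the same assumptions (TF 1.x-era, tensorpack at this revision); both tensors are scalars, as the new docstring requires:

    import tensorflow as tf
    from tensorpack.tfutils.summary import add_moving_summary

    cost = tf.reduce_mean(tf.square(tf.random_normal([8])), name='cost')
    accuracy = tf.constant(0.9, name='accuracy')
    add_moving_summary(cost, accuracy)   # extra tensors go in as positional args
    # summary_moving_average() is later called by the trainer to build the
    # moving-average op over everything registered this way.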

tensorpack/tfutils/symbolic_functions.py

@@ -8,9 +8,13 @@ import numpy as np
 def prediction_incorrect(logits, label, topk=1, name='incorrect_vector'):
     """
-    :param logits: NxC
-    :param label: N
-    :returns: a float32 vector of length N with 0/1 values. 1 means incorrect prediction
+    Args:
+        logits: (N,C)
+        label: (N,)
+        topk(int): topk
+    Returns:
+        a float32 vector of length N with 0/1 values. 1 means incorrect
+        prediction.
     """
     return tf.cast(tf.logical_not(tf.nn.in_top_k(logits, label, topk)),
                    tf.float32, name=name)

@@ -39,9 +43,11 @@ def class_balanced_cross_entropy(pred, label, name='cross_entropy_loss'):
     as in `Holistically-Nested Edge Detection
     <http://arxiv.org/abs/1504.06375>`_.

-    :param pred: size: b x ANYTHING. the predictions in [0,1].
-    :param label: size: b x ANYTHING. the ground truth in {0,1}.
-    :returns: class-balanced cross entropy loss
+    Args:
+        pred: of shape (b, ...). the predictions in [0,1].
+        label: of the same shape. the ground truth in {0,1}.
+    Returns:
+        class-balanced cross entropy loss.
     """
     z = batch_flatten(pred)
     y = tf.cast(batch_flatten(label), tf.float32)

@@ -59,14 +65,8 @@ def class_balanced_cross_entropy(pred, label, name='cross_entropy_loss'):
 def class_balanced_sigmoid_cross_entropy(logits, label, name='cross_entropy_loss'):
     """
     The class-balanced cross entropy loss,
     as in `Holistically-Nested Edge Detection
     <http://arxiv.org/abs/1504.06375>`_.
-    This is more numerically stable than class_balanced_cross_entropy
-
-    :param logits: size: the logits.
-    :param label: size: the ground truth in {0,1}, of the same shape as logits.
-    :returns: a scalar. class-balanced cross entropy loss
+    This function accepts logits rather than predictions, and is more numerically stable than
+    :func:`class_balanced_cross_entropy`.
     """
     y = tf.cast(label, tf.float32)

@@ -77,17 +77,12 @@ def class_balanced_sigmoid_cross_entropy(logits, label, name='cross_entropy_loss'):
     pos_weight = beta / (1 - beta)
     cost = tf.nn.weighted_cross_entropy_with_logits(logits, y, pos_weight)
     cost = tf.reduce_mean(cost * (1 - beta), name=name)

-    # logstable = tf.log(1 + tf.exp(-tf.abs(z)))
-    # loss_pos = -beta * tf.reduce_mean(-y * (logstable - tf.minimum(0.0, z)))
-    # loss_neg = (1. - beta) * tf.reduce_mean((y - 1.) * (logstable + tf.maximum(z, 0.0)))
-    # cost = tf.sub(loss_pos, loss_neg, name=name)
     return cost


 def print_stat(x, message=None):
-    """ a simple print op.
-        Use it like: x = print_stat(x)
+    """ A simple print Op that might be easier to use than :meth:`tf.Print`.
+        Use it like: ``x = print_stat(x, message='This is x')``.
     """
     if message is None:
         message = x.op.name

@@ -96,6 +91,10 @@ def print_stat(x, message=None):
 def rms(x, name=None):
+    """
+    Returns:
+        root mean square of tensor x.
+    """
     if name is None:
         name = x.op.name + '/rms'
         with tf.name_scope(None):   # name already contains the scope

@@ -104,19 +103,41 @@ def rms(x, name=None):
 def huber_loss(x, delta=1, name='huber_loss'):
+    r"""
+    Huber loss of x.
+
+    .. math::
+
+        y = \begin{cases} \frac{x^2}{2}, & |x| < \delta \\
+        \delta |x| - \frac{\delta^2}{2}, & |x| \ge \delta
+        \end{cases}
+
+    Args:
+        x: the difference vector.
+        delta (float):
+
+    Returns:
+        a tensor of the same shape of x.
+    """
     sqrcost = tf.square(x)
     abscost = tf.abs(x)
-    return tf.reduce_sum(
-        tf.select(abscost < delta,
-                  sqrcost * 0.5,
-                  abscost * delta - 0.5 * delta ** 2),
-        name=name)
+    return tf.select(abscost < delta,
+                     sqrcost * 0.5,
+                     abscost * delta - 0.5 * delta ** 2,
+                     name=name)


 def get_scalar_var(name, init_value, summary=False, trainable=False):
     """
-    get a scalar variable with certain initial value
-    :param summary: summary this variable
+    Get a scalar variable with certain initial value
+
+    Args:
+        name (str): name of the variable.
+        init_value (float): initial value.
+        summary (bool): whether to summary this variable.
+        trainable (bool): trainable or not.
+    Returns:
+        tf.Variable: the variable
     """
     ret = tf.get_variable(name, shape=[],
                           initializer=tf.constant_initializer(init_value),
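
A small sketch of the documented shape conventions, assuming a contemporary TF build (one that still has tf.select) and tensorpack at this revision; the literal values are arbitrary:

    import tensorflow as tf
    from tensorpack.tfutils.symbolic_functions import prediction_incorrect, huber_loss

    logits = tf.constant([[2.0, 1.0, 0.1],
                          [0.2, 0.3, 3.0]])       # shape (N=2, C=3)
    label = tf.constant([0, 2])                   # shape (N,); both predictions correct here
    wrong = prediction_incorrect(logits, label)   # float32 vector of 0/1, shape (N,)

    err = tf.constant([0.3, 2.0])
    per_element = huber_loss(err)                 # now element-wise; callers reduce it themselves
    cost = tf.reduce_mean(per_element, name='cost')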

tensorpack/tfutils/tower.py

@@ -13,9 +13,14 @@ _CurrentTowerContext = None
 class TowerContext(object):
     """ A context where the current model is being built in. """

     def __init__(self, tower_name, is_training=None):
-        """ tower_name: 'tower0', 'towerp0', or '' """
+        """
+        Args:
+            tower_name (str): 'tower0', 'towerp0', or ''
+            is_training (bool): if None, automatically determine from tower_name.
+        """
         self._name = tower_name
         if is_training is None:
             is_training = not self._name.startswith(PREDICT_TOWER)

@@ -39,12 +44,15 @@ class TowerContext(object):
     def get_variable_on_tower(self, *args, **kwargs):
         """
-        Get a variable for this tower specifically, without reusing.
-        Tensorflow doesn't allow reuse=False scope under a
-        reuse=True scope. This method provides a work around.
+        Get a variable for this tower specifically, without reusing, even if
+        it is called under a ``reuse=True`` variable scope.
+
+        Tensorflow doesn't allow us to disable reuse under a
+        ``reuse=True`` scope. This method provides a work around.

         See https://www.tensorflow.org/versions/master/how_tos/variable_scope/index.html#basics-of-tfvariable-scope

-        :param args, kwargs: same as tf.get_variable()
+        Args:
+            args: same as ``tf.get_variable()``.
         """
         with tf.variable_scope(self._name) as scope:
             with tf.variable_scope(scope, reuse=False):
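
A sketch of how the two constructor arguments interact, assuming tensorpack at this revision (where TowerContext is used as a context manager); the tower names follow the convention given in the docstring:

    from tensorpack.tfutils.tower import TowerContext

    with TowerContext('tower0'):                      # no is_training given:
        pass                                          # inferred True (no predict-tower prefix)

    with TowerContext('towerp0', is_training=False):  # an explicit prediction tower
        pass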

tensorpack/tfutils/varmanip.py

@@ -14,17 +14,20 @@ from ..utils.naming import PREDICT_TOWER
 from .common import get_op_tensor_name

 __all__ = ['SessionUpdate', 'dump_session_params', 'dump_chkpt_vars',
-           'get_savename_from_varname', 'is_training_name']
+           'get_savename_from_varname', 'is_training_name', 'get_checkpoint_path']


 def get_savename_from_varname(
         varname, varname_prefix=None,
         savename_prefix=None):
     """
-    :param varname: a variable name in the graph
-    :param varname_prefix: an optional prefix that may need to be removed in varname
-    :param savename_prefix: an optional prefix to append to all savename
-    :returns: the name used to save the variable
+    Args:
+        varname(str): a variable name in the graph
+        varname_prefix(str): an optional prefix that may need to be removed in varname
+        savename_prefix(str): an optional prefix to append to all savename
+    Returns:
+        str: the name used to save the variable
     """
     name = varname
     if PREDICT_TOWER in name:

@@ -46,7 +49,9 @@ class SessionUpdate(object):
     def __init__(self, sess, vars_to_update):
         """
-        :param vars_to_update: a collection of variables to update
+        Args:
+            sess (tf.Session): a session object
+            vars_to_update: a collection of variables to update
         """
         self.sess = sess
         self.assign_ops = defaultdict(list)

@@ -60,8 +65,9 @@ class SessionUpdate(object):
     def update(self, prms):
         """
-        :param prms: dict of {variable name: value}
-            Any name in prms must be in the graph and in vars_to_update.
+        Args:
+            prms(dict): dict of {variable name: value}
+                Any name in prms must be in the graph and in vars_to_update.
         """
         for name, value in six.iteritems(prms):
             assert name in self.assign_ops

@@ -77,8 +83,12 @@ class SessionUpdate(object):
 def dump_session_params(path):
-    """ Dump value of all trainable + to_save variables to a dict and save to `path` as
-    npy format, loadable by ParamRestore
+    """
+    Dump value of all TRAINABLE + MODEL variables to a dict, and save as
+    npy format (loadable by :class:`ParamRestore`).
+
+    Args:
+        path(str): the path to save the parameters.
     """
     var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
     var.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))

@@ -96,10 +106,42 @@ the same name".format(v.name))
     np.save(path, result)


-def dump_chkpt_vars(model_path):
-    """ Dump all variables from a checkpoint to a dict"""
+def get_checkpoint_path(model_path):
+    """
+    Work around TF problems in checkpoint path handling.
+
+    Args:
+        model_path: a user-input path
+    Returns:
+        str: the argument that can be passed to NewCheckpointReader
+    """
+    if os.path.basename(model_path) == model_path:
+        model_path = os.path.join('.', model_path)  # avoid #4921 and #6142
+    if os.path.basename(model_path) == 'checkpoint':
+        model_path = tf.train.latest_checkpoint(os.path.dirname(model_path))
+        # to be consistent with either v1 or v2
+
+    # fix paths if provided a wrong one
+    new_path = model_path
+    if '00000-of-00001' in model_path:
+        new_path = model_path.split('.data')[0]
+    elif model_path.endswith('.index'):
+        new_path = model_path.split('.index')[0]
+    if new_path != model_path:
+        logger.warn(
+            "[SaverRestore] {} is corrected to {} when restoring the model.".format(model_path, new_path))
+        model_path = new_path
+    assert os.path.isfile(model_path) or os.path.isfile(model_path + '.index'), model_path
+    return model_path
+
+
+def dump_chkpt_vars(model_path):
+    """ Dump all variables from a checkpoint to a dict.
+
+    Args:
+        model_path(str): path to a checkpoint.
+    """
+    model_path = get_checkpoint_path(model_path)
     reader = tf.train.NewCheckpointReader(model_path)
     var_names = reader.get_variable_to_shape_map().keys()
     result = {}

@@ -110,8 +152,10 @@ def dump_chkpt_vars(model_path):
 def is_training_name(name):
     """
-    This is only used to improve logging.
-    :returns: guess whether this tensor is something only used in training.
+    This is a hack temporarily used to improve logging. Do not use this function.
+
+    Returns:
+        bool: Guess whether this tensor is something only used in training.
     """
     # TODO: maybe simply check against TRAINABLE_VARIABLES and MODEL_VARIABLES?
     # TODO or use get_slot_names()
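
The filename corrections performed by the new get_checkpoint_path() can be illustrated without TF. This is a stripped-down, illustrative restatement of the suffix handling only; the real function additionally resolves 'checkpoint' files, prepends './', warns via logger, and asserts that the files exist:

    def fix_checkpoint_suffix(model_path):
        # Users often pass the physical file written by a TF-v2 saver; strip the
        # data/index suffix so NewCheckpointReader gets the checkpoint prefix.
        if '00000-of-00001' in model_path:
            return model_path.split('.data')[0]
        if model_path.endswith('.index'):
            return model_path.split('.index')[0]
        return model_path

    print(fix_checkpoint_suffix('train_log/run1/model-10000.index'))
    # -> train_log/run1/model-10000
    print(fix_checkpoint_suffix('train_log/run1/model-10000.data-00000-of-00001'))
    # -> train_log/run1/model-10000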

tensorpack/tfutils/varreplace.py

@@ -14,6 +14,13 @@ _ORIG_GET_VARIABLE = tf.get_variable
 @contextmanager
 def replace_get_variable(fn):
+    """
+    Args:
+        fn: a function taking the same arguments as ``tf.get_variable``.
+    Returns:
+        a context where ``tf.get_variable`` and
+        ``variable_scope.get_variable`` are replaced with ``fn``.
+    """
     old_getv = tf.get_variable
     old_vars_getv = variable_scope.get_variable

@@ -26,8 +33,10 @@ def replace_get_variable(fn):
 def freeze_get_variable():
     """
-    Return a contextmanager, where all variables returned by
-    `get_variable` will have no gradients.
+    Return a context, where all variables (reused or not) returned by
+    ``get_variable`` will have no gradients (surrounded by ``tf.stop_gradient``).
+    But they will still be in ``TRAINABLE_VARIABLES`` collections so they will get
+    saved correctly. This is useful to fix certain variables for fine-tuning.
+
+    Example:
+        .. code-block:: python
...
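
The Example block in the docstring above is truncated in this view. A hedged usage sketch, assuming a TF 1.x-era environment with tensorpack at this revision, of what freeze_get_variable() is for; the variable names are made up:

    import tensorflow as tf
    from tensorpack.tfutils.varreplace import freeze_get_variable

    with freeze_get_variable():
        # get_variable calls in here return tensors wrapped in tf.stop_gradient,
        # so no gradient flows into them, but the underlying variables stay in
        # TRAINABLE_VARIABLES and are still saved as usual.
        frozen_w = tf.get_variable('pretrained_fc/W', shape=[4, 4])

    new_w = tf.get_variable('new_fc/W', shape=[4, 2])   # trained normally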