Shashank Suhas / seminar-breakout / Commits

Commit 0d894877
Author:  Yuxin Wu
Date:    Feb 26, 2016
Parent:  80622ae7

    gradproc

7 changed files with 86 additions and 23 deletions
example_mnist.py                               +1   -0
tensorpack/callbacks/validation_callback.py    +1   -1
tensorpack/models/model_desc.py                +5   -5
tensorpack/train/base.py                       +5   -0
tensorpack/train/trainer.py                    +2  -17
tensorpack/utils/gradproc.py                  +69   -0
tensorpack/utils/sessinit.py                   +3   -0
example_mnist.py

@@ -14,6 +14,7 @@ from tensorpack.train import TrainConfig, SimpleTrainer
 from tensorpack.models import *
 from tensorpack.utils import *
 from tensorpack.utils.symbolic_functions import *
+from tensorpack.utils.gradproc import *
 from tensorpack.utils.summary import *
 from tensorpack.callbacks import *
 from tensorpack.dataflow import *
tensorpack/callbacks/validation_callback.py

@@ -6,7 +6,7 @@
 import tensorflow as tf
 import itertools
 from tqdm import tqdm
-from abc import ABCMeta
+from abc import ABCMeta, abstractmethod
 from ..utils import *
 from ..utils.stat import *
tensorpack/models/model_desc.py

@@ -7,6 +7,8 @@ from abc import ABCMeta, abstractmethod
 import tensorflow as tf
 from collections import namedtuple
+from ..utils.gradproc import *

 __all__ = ['ModelDesc', 'InputVar']

 InputVar = namedtuple('InputVar', ['type', 'shape', 'name'])

@@ -68,8 +70,6 @@ class ModelDesc(object):
             the cost to minimize. scalar variable
         """

-    def get_lr_multiplier(self):
-        """
-        Return a list of (variable_regex: multiplier)
-        """
-        return []
+    def get_gradient_processor(self):
+        """ Return a list of GradientProcessor. They will be executed in order"""
+        return [SummaryGradient(), CheckGradient()]
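The new get_gradient_processor() hook replaces get_lr_multiplier(): instead of raw (regex, multiplier) pairs, a model now returns a list of processor objects. As a hedged sketch of how a model might use it (the subclass name and regex below are made up, and the other ModelDesc methods are omitted):

    from tensorpack.models import ModelDesc
    from tensorpack.utils.gradproc import SummaryGradient, CheckGradient, ScaleGradient

    class MyModel(ModelDesc):   # hypothetical subclass for illustration
        def get_gradient_processor(self):
            # keep the defaults, plus a 0.1x multiplier on conv0's gradients,
            # roughly what the removed get_lr_multiplier() used to express
            return [SummaryGradient(), CheckGradient(),
                    ScaleGradient([('conv0/.*', 0.1)])]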
tensorpack/train/base.py

@@ -111,3 +111,8 @@ class Trainer(object):
         tf.train.start_queue_runners(
             sess=self.sess, coord=self.coord, daemon=True, start=True)

+    def process_grads(self, grads):
+        procs = self.config.model.get_gradient_processor()
+        for proc in procs:
+            grads = proc.process(grads)
+        return grads
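process_grads() threads the (grad, var) list through each processor in order, so the output of one processor becomes the input of the next. A minimal standalone sketch of the same fold, with a toy processor that is not part of tensorpack:

    class DoubleGradient(object):   # toy processor, for illustration only
        def process(self, grads):
            return [(g * 2, v) for g, v in grads]

    grads = [(1.0, 'w1'), (0.5, 'w2')]   # stand-ins for (grad_tensor, variable)
    for proc in [DoubleGradient(), DoubleGradient()]:
        grads = proc.process(grads)      # same fold as Trainer.process_grads
    # grads is now [(4.0, 'w1'), (2.0, 'w2')]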
tensorpack/train/trainer.py

@@ -15,17 +15,6 @@ from ..utils.summary import summary_moving_average
 __all__ = ['SimpleTrainer', 'QueueInputTrainer', 'start_train']

-def summary_grads(grads):
-    for grad, var in grads:
-        if grad:
-            # TODO also summary RMS and print
-            tf.histogram_summary(var.op.name + '/gradients', grad)
-
-def check_grads(grads):
-    for grad, var in grads:
-        assert grad is not None, "Grad is None for variable {}".format(var.name)
-        tf.Assert(tf.reduce_all(tf.is_finite(var)), [var])
-
 def scale_grads(grads, multiplier):
     ret = []
     for grad, var in grads:

@@ -54,9 +43,7 @@ class SimpleTrainer(Trainer):
         avg_maintain_op = summary_moving_average(cost_var)
         grads = self.config.optimizer.compute_gradients(cost_var)
-        check_grads(grads)
-        grads = scale_grads(grads, model.get_lr_multiplier())
-        summary_grads(grads)
+        grads = self.process_grads(grads)
         self.train_op = tf.group(
             self.config.optimizer.apply_gradients(grads, get_global_step_var()),

@@ -133,9 +120,7 @@ class QueueInputTrainer(Trainer):
         grads = self.config.optimizer.compute_gradients(cost_var)
         avg_maintain_op = summary_moving_average(cost_var)
         # TODO(multigpu) average the cost from each device?
-        check_grads(grads)
-        grads = scale_grads(grads, model.get_lr_multiplier())
-        summary_grads(grads)
+        grads = self.process_grads(grads)
         self.train_op = tf.group(
             self.config.optimizer.apply_gradients(grads, get_global_step_var()),
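With this change both trainers share one gradient path: the removed check_grads/summary_grads calls are covered by the model's default [SummaryGradient(), CheckGradient()] chain. Schematically, assuming cost_var and avg_maintain_op are in scope as in the diff:

    # inside a Trainer subclass, a condensed sketch of the new flow
    grads = self.config.optimizer.compute_gradients(cost_var)   # [(grad, var), ...]
    grads = self.process_grads(grads)   # model-defined GradientProcessor chain
    self.train_op = tf.group(
        self.config.optimizer.apply_gradients(grads, get_global_step_var()),
        avg_maintain_op)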
tensorpack/utils/gradproc.py (new file, 0 → 100644)

#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: gradproc.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

import tensorflow as tf
from abc import ABCMeta, abstractmethod
import re
from . import logger

__all__ = ['GradientProcessor', 'SummaryGradient', 'CheckGradient',
           'ScaleGradient']

class GradientProcessor(object):
    __metaclass__ = ABCMeta

    @abstractmethod
    def process(self, grads):
        """
        Process the symbolic gradients, return symbolic gradients
        grads: list of (grad, var)
        """

class SummaryGradient(GradientProcessor):
    """
    Summary histogram and RMS for each gradient variable
    """
    def process(self, grads):
        for grad, var in grads:
            tf.histogram_summary(var.op.name + '/grad', grad)
            tf.scalar_summary(var.op.name + '/gradRMS',
                              tf.sqrt(tf.reduce_mean(tf.square(grad))))
        return grads

class CheckGradient(GradientProcessor):
    """
    Check for numeric issues
    """
    def process(self, grads):
        for grad, var in grads:
            assert grad is not None, "Grad is None for variable {}".format(var.name)
            # TODO make assert work
            tf.Assert(tf.reduce_all(tf.is_finite(var)), [var])
        return grads

class ScaleGradient(GradientProcessor):
    """
    Scale gradients by a multiplier
    """
    def __init__(self, multipliers):
        """
        multipliers: list of (regex, float)
        """
        self.multipliers = multipliers

    def process(self, grads):
        # TODO use None for zero to speed up?
        ret = []
        for grad, var in grads:
            varname = var.op.name
            for regex, val in self.multipliers:
                if re.search(regex, varname):
                    logger.info("Apply lr multiplier {} for {}".format(val, varname))
                    ret.append((grad * val, var))
                    break
            else:
                ret.append((grad, var))
        return ret
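A quick end-to-end sketch of the new module used outside a trainer, with the same TF 0.x summary API the file itself uses (the variable, cost, and regex below are made up for illustration):

    import tensorflow as tf
    from tensorpack.utils.gradproc import SummaryGradient, CheckGradient, ScaleGradient

    with tf.variable_scope('fc0'):
        W = tf.get_variable('W', shape=[10],
                            initializer=tf.constant_initializer(1.0))
    cost = tf.reduce_sum(tf.square(W))
    grads = tf.train.GradientDescentOptimizer(0.1).compute_gradients(cost)

    # run the chain by hand, exactly as Trainer.process_grads would
    for proc in [SummaryGradient(), CheckGradient(),
                 ScaleGradient([('fc0/.*', 0.1)])]:
        grads = proc.process(grads)   # each processor returns a new (grad, var) list

Note the for/else in ScaleGradient.process: the else branch runs only when no regex matched, so unmatched variables keep their gradients unscaled and the first matching regex wins.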
tensorpack/utils/sessinit.py

@@ -39,6 +39,9 @@ class SaverRestore(SessionInit):
         self.path = model_path

 class ParamRestore(SessionInit):
+    """
+    Restore trainable variables from a dictionary
+    """
     def __init__(self, param_dict):
         self.prms = param_dict
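Only the constructor is visible in this hunk; the actual restore logic presumably lives elsewhere in the class. A hypothetical use, assuming the dictionary maps variable names to numpy arrays (names and shapes here are illustrative):

    import numpy as np
    from tensorpack.utils.sessinit import ParamRestore

    params = {'conv0/W': np.zeros((3, 3, 1, 32), dtype='float32'),
              'conv0/b': np.zeros((32,), dtype='float32')}
    init = ParamRestore(params)   # pass wherever a SessionInit is expected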