Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
6b6947fc
Commit
6b6947fc
authored
Mar 07, 2018
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
gradproc support replicated
parent
99a7d749
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
19 additions
and
7 deletions
+19
-7
README.md
README.md
+2
-2
tensorpack/predict/base.py
tensorpack/predict/base.py
+10
-3
tensorpack/tfutils/gradproc.py
tensorpack/tfutils/gradproc.py
+7
-2
No files found.
README.md
View file @
6b6947fc
...
@@ -24,9 +24,9 @@ It's Yet Another TF wrapper, but different in:
...
@@ -24,9 +24,9 @@ It's Yet Another TF wrapper, but different in:
Tensorpack helps you load large datasets (e.g. ImageNet) in __pure Python__ with autoparallelization.
Tensorpack helps you load large datasets (e.g. ImageNet) in __pure Python__ with autoparallelization.
3.
It's not a model wrapper.
3.
It's not a model wrapper.
+
There are too many symbolic function wrappers.
+
There are too many symbolic function wrappers
in the world
.
Tensorpack includes only a few common models.
Tensorpack includes only a few common models.
You can use any symbolic function library inside tensorpack, including tf
layers/Keras/slim/tflearn/tensorlayer/....
But you can use any symbolic function library inside tensorpack, including tf.
layers/Keras/slim/tflearn/tensorlayer/....
See
[
tutorials
](
http://tensorpack.readthedocs.io/en/latest/tutorial/index.html
)
to know more about these features.
See
[
tutorials
](
http://tensorpack.readthedocs.io/en/latest/tutorial/index.html
)
to know more about these features.
...
...
tensorpack/predict/base.py
View file @
6b6947fc
...
@@ -95,6 +95,8 @@ class OnlinePredictor(PredictorBase):
...
@@ -95,6 +95,8 @@ class OnlinePredictor(PredictorBase):
""" A predictor which directly use an existing session and given tensors.
""" A predictor which directly use an existing session and given tensors.
"""
"""
ACCEPT_OPTIONS
=
False
def
__init__
(
self
,
input_tensors
,
output_tensors
,
def
__init__
(
self
,
input_tensors
,
output_tensors
,
return_input
=
False
,
sess
=
None
):
return_input
=
False
,
sess
=
None
):
"""
"""
...
@@ -115,7 +117,8 @@ class OnlinePredictor(PredictorBase):
...
@@ -115,7 +117,8 @@ class OnlinePredictor(PredictorBase):
if
sess
is
not
None
:
if
sess
is
not
None
:
self
.
_callable
=
sess
.
make_callable
(
self
.
_callable
=
sess
.
make_callable
(
fetches
=
output_tensors
,
fetches
=
output_tensors
,
feed_list
=
input_tensors
)
feed_list
=
input_tensors
,
accept_options
=
self
.
ACCEPT_OPTIONS
)
else
:
else
:
self
.
_callable
=
None
self
.
_callable
=
None
else
:
else
:
...
@@ -131,8 +134,12 @@ class OnlinePredictor(PredictorBase):
...
@@ -131,8 +134,12 @@ class OnlinePredictor(PredictorBase):
if
self
.
_callable
is
None
:
if
self
.
_callable
is
None
:
self
.
_callable
=
self
.
sess
.
make_callable
(
self
.
_callable
=
self
.
sess
.
make_callable
(
fetches
=
self
.
output_tensors
,
fetches
=
self
.
output_tensors
,
feed_list
=
self
.
input_tensors
)
feed_list
=
self
.
input_tensors
,
return
self
.
_callable
(
*
dp
)
accept_options
=
self
.
ACCEPT_OPTIONS
)
# run_metadata = tf.RunMetadata()
# options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
ret
=
self
.
_callable
(
*
dp
)
return
ret
def
_do_call
(
self
,
dp
):
def
_do_call
(
self
,
dp
):
assert
len
(
dp
)
==
len
(
self
.
input_tensors
),
\
assert
len
(
dp
)
==
len
(
self
.
input_tensors
),
\
...
...
tensorpack/tfutils/gradproc.py
View file @
6b6947fc
...
@@ -152,18 +152,23 @@ class SummaryGradient(MapGradient):
...
@@ -152,18 +152,23 @@ class SummaryGradient(MapGradient):
# TODO this is global. not good.
# TODO this is global. not good.
_summaried_gradient
=
set
()
_summaried_gradient
=
set
()
def
__init__
(
self
,
regex
=
'.*'
):
def
__init__
(
self
,
regex
=
'.*'
,
collections
=
None
):
"""
"""
Args:
Args:
regex(str): same as in :class:`MapGradient`.
regex(str): same as in :class:`MapGradient`.
collections (list[str]): list of collection names
"""
"""
super
(
SummaryGradient
,
self
)
.
__init__
(
self
.
_mapper
,
regex
)
super
(
SummaryGradient
,
self
)
.
__init__
(
self
.
_mapper
,
regex
)
self
.
_coll
=
collections
def
_mapper
(
self
,
grad
,
var
):
def
_mapper
(
self
,
grad
,
var
):
name
=
var
.
op
.
name
name
=
var
.
op
.
name
if
re
.
match
(
'tower[0-9]+/'
,
name
):
# replicated training, var may come from different towers
return
grad
if
name
not
in
SummaryGradient
.
_summaried_gradient
:
if
name
not
in
SummaryGradient
.
_summaried_gradient
:
SummaryGradient
.
_summaried_gradient
.
add
(
name
)
SummaryGradient
.
_summaried_gradient
.
add
(
name
)
tf
.
summary
.
histogram
(
name
+
'-grad'
,
grad
)
tf
.
summary
.
histogram
(
name
+
'-grad'
,
grad
,
collections
=
self
.
_coll
)
add_moving_summary
(
rms
(
grad
,
name
=
name
+
'/rms'
))
add_moving_summary
(
rms
(
grad
,
name
=
name
+
'/rms'
))
return
grad
return
grad
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment