Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
2aa760b1
Commit
2aa760b1
authored
Sep 20, 2019
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix shape comparison when reusing placeholders. fix #1329
parent
eb25cd7f
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
9 additions
and
5 deletions
+9
-5
docs/tutorial/inference.md
docs/tutorial/inference.md
+1
-1
tensorpack/input_source/input_source_base.py
tensorpack/input_source/input_source_base.py
+4
-1
tensorpack/train/tower.py
tensorpack/train/tower.py
+2
-1
tensorpack/train/trainers.py
tensorpack/train/trainers.py
+2
-2
No files found.
docs/tutorial/inference.md
View file @
2aa760b1
...
...
@@ -6,7 +6,7 @@
There are two ways to do inference during training.
1.
The easiest way is to write a callback, and use
[
self.trainer.get_predictor()
](
../modules/
modules/
train.html#tensorpack.train.TowerTrainer.get_predictor
)
[
self.trainer.get_predictor()
](
../modules/train.html#tensorpack.train.TowerTrainer.get_predictor
)
to get a callable under inference mode.
See
[
Write a Callback
](
extend/callback.html
)
.
...
...
tensorpack/input_source/input_source_base.py
View file @
2aa760b1
...
...
@@ -34,9 +34,12 @@ def build_or_reuse_placeholder(tensor_spec):
assert
"Placeholder"
in
tensor
.
op
.
type
,
"Tensor {} exists but is not a placeholder!"
.
format
(
name
)
assert
tensor_spec
.
is_compatible_with
(
tensor
),
\
"Tensor {} exists but is not compatible with the signature!"
.
format
(
tensor
)
if
tensor
.
shape
==
tensor_spec
.
shape
:
if
tensor
.
shape
.
as_list
()
==
tensor_spec
.
shape
.
as_list
()
:
# It might be desirable to use a placeholder of a different shape in some tower
# (e.g., a less specific shape)
# Comparing `tensor.shape` directly doesn't work, because
# tensorflow thinks `tf.Dimension(None)` and `tf.Dimension(None)` are not equal.
return
tensor
except
KeyError
:
pass
...
...
tensorpack/train/tower.py
View file @
2aa760b1
...
...
@@ -46,7 +46,8 @@ class TowerTrainer(Trainer):
def
tower_func
(
self
):
"""
A :class:`TowerFunc` instance.
See [tutorial on tower function](http://tensorpack.readthedocs.io/tutorial/trainer.html#tower-trainer)
See `tutorial on tower function
<http://tensorpack.readthedocs.io/tutorial/trainer.html#tower-trainer>`_
for more information.
"""
return
self
.
_tower_func
...
...
tensorpack/train/trainers.py
View file @
2aa760b1
...
...
@@ -388,8 +388,6 @@ class HorovodTrainer(SingleCostTrainer):
certain numerical issues in practice.
"""
BROADCAST_EVERY_EPOCH
=
True
def
__init__
(
self
,
average
=
True
,
compression
=
None
):
"""
Args:
...
...
@@ -415,6 +413,8 @@ class HorovodTrainer(SingleCostTrainer):
logger
.
info
(
"[HorovodTrainer] local rank={}"
.
format
(
self
.
_local_rank
))
super
(
HorovodTrainer
,
self
)
.
__init__
()
self
.
BROADCAST_EVERY_EPOCH
=
True
def
mpi_enabled
(
self
):
"""
Returns:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment