Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
6607d856
Commit
6607d856
authored
Jul 12, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix unpool unknown shape problem
parent
e51855c5
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
5 additions
and
1 deletion
+5
-1
tensorpack/models/conv2d.py
tensorpack/models/conv2d.py
+1
-0
tensorpack/models/pool.py
tensorpack/models/pool.py
+3
-1
tensorpack/predict/common.py
tensorpack/predict/common.py
+1
-0
No files found.
tensorpack/models/conv2d.py
View file @
6607d856
...
...
@@ -32,6 +32,7 @@ def Conv2D(x, out_channel, kernel_shape,
"""
in_shape
=
x
.
get_shape
()
.
as_list
()
in_channel
=
in_shape
[
-
1
]
assert
in_channel
is
not
None
,
"Input to Conv2D cannot have unknown channel!"
assert
in_channel
%
split
==
0
assert
out_channel
%
split
==
0
...
...
tensorpack/models/pool.py
View file @
6607d856
...
...
@@ -74,7 +74,9 @@ def UnPooling2x2ZeroFilled(x):
return
tf
.
reshape
(
out
,
out_size
)
else
:
sh
=
tf
.
shape
(
x
)
return
tf
.
reshape
(
out
,
[
-
1
,
sh
[
1
]
*
2
,
sh
[
2
]
*
2
,
sh
[
3
]])
ret
=
tf
.
reshape
(
out
,
tf
.
pack
([
-
1
,
sh
[
1
]
*
2
,
sh
[
2
]
*
2
,
sh
[
3
]]))
ret
.
set_shape
([
None
,
None
,
None
,
sh
[
3
]])
return
ret
@
layer_register
()
def
FixedUnPooling
(
x
,
shape
,
unpool_mat
=
None
):
...
...
tensorpack/predict/common.py
View file @
6607d856
...
...
@@ -90,5 +90,6 @@ def get_predict_func(config):
def
run_input
(
dp
):
feed
=
dict
(
zip
(
input_map
,
dp
))
return
sess
.
run
(
output_vars
,
feed_dict
=
feed
)
# XXX hack. so the caller can get access to the session.
run_input
.
session
=
sess
return
run_input
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment