Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
96255c9a
Commit
96255c9a
authored
Feb 05, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
model multiple inputs as a list
parent
b506eb0a
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
22 additions
and
17 deletions
+22
-17
example_cifar10.py
example_cifar10.py
+7
-7
tensorpack/dataflow/dataset/cifar10.py
tensorpack/dataflow/dataset/cifar10.py
+3
-1
tensorpack/dataflow/dataset/mnist.py
tensorpack/dataflow/dataset/mnist.py
+1
-1
tensorpack/models/conv2d.py
tensorpack/models/conv2d.py
+1
-1
tensorpack/models/fc.py
tensorpack/models/fc.py
+1
-1
tensorpack/models/image_sample.py
tensorpack/models/image_sample.py
+7
-4
tensorpack/utils/modelutils.py
tensorpack/utils/modelutils.py
+2
-2
No files found.
example_cifar10.py
View file @
96255c9a
...
...
@@ -92,18 +92,18 @@ def get_config():
dataset_train
=
dataset
.
Cifar10
(
'train'
)
augmentors
=
[
RandomCrop
((
24
,
24
)),
Flip
(
horiz
=
True
),
BrightnessAdd
(
63
),
Contrast
((
0.2
,
1.8
)),
MeanVarianceNormalize
(
all_channel
=
True
)
imgaug
.
RandomCrop
((
24
,
24
)),
imgaug
.
Flip
(
horiz
=
True
),
imgaug
.
BrightnessAdd
(
63
),
imgaug
.
Contrast
((
0.2
,
1.8
)),
imgaug
.
MeanVarianceNormalize
(
all_channel
=
True
)
]
dataset_train
=
AugmentImageComponent
(
dataset_train
,
augmentors
)
dataset_train
=
BatchData
(
dataset_train
,
128
)
augmentors
=
[
CenterCrop
((
24
,
24
)),
MeanVarianceNormalize
(
all_channel
=
True
)
imgaug
.
CenterCrop
((
24
,
24
)),
imgaug
.
MeanVarianceNormalize
(
all_channel
=
True
)
]
dataset_test
=
dataset
.
Cifar10
(
'test'
)
dataset_test
=
AugmentImageComponent
(
dataset_test
,
augmentors
)
...
...
tensorpack/dataflow/dataset/cifar10.py
View file @
96255c9a
...
...
@@ -7,6 +7,7 @@ import cPickle
import
numpy
from
six.moves
import
urllib
import
tarfile
import
logging
from
...utils
import
logger
from
..base
import
DataFlow
...
...
@@ -24,10 +25,11 @@ def maybe_download_and_extract(dest_directory):
filename
=
DATA_URL
.
split
(
'/'
)[
-
1
]
filepath
=
os
.
path
.
join
(
dest_directory
,
filename
)
if
os
.
path
.
isdir
(
os
.
path
.
join
(
dest_directory
,
'cifar-10-batches-py'
)):
logger
.
info
(
"Found cifar10 data in {}."
.
format
(
dest_directory
))
return
else
:
def
_progress
(
count
,
block_size
,
total_size
):
sys
.
stdout
.
write
(
'
\r
>> Downloading
%
s
%.1
f
%%
'
%
(
file
name
,
sys
.
stdout
.
write
(
'
\r
>> Downloading
%
s
%.1
f
%%
'
%
(
file
path
,
float
(
count
*
block_size
)
/
float
(
total_size
)
*
100.0
))
sys
.
stdout
.
flush
()
filepath
,
_
=
urllib
.
request
.
urlretrieve
(
DATA_URL
,
filepath
,
reporthook
=
_progress
)
...
...
tensorpack/dataflow/dataset/mnist.py
View file @
96255c9a
...
...
@@ -24,7 +24,7 @@ def maybe_download(filename, work_directory):
os
.
mkdir
(
work_directory
)
filepath
=
os
.
path
.
join
(
work_directory
,
filename
)
if
not
os
.
path
.
exists
(
filepath
):
logger
.
info
(
"Downloading mnist data
..."
)
logger
.
info
(
"Downloading mnist data
to {}..."
.
format
(
filepath
)
)
filepath
,
_
=
urllib
.
request
.
urlretrieve
(
SOURCE_URL
+
filename
,
filepath
)
statinfo
=
os
.
stat
(
filepath
)
logger
.
info
(
'Successfully downloaded to '
+
filename
)
...
...
tensorpack/models/conv2d.py
View file @
96255c9a
...
...
@@ -31,7 +31,7 @@ def Conv2D(x, out_channel, kernel_shape,
stride
=
shape4d
(
stride
)
if
W_init
is
None
:
W_init
=
tf
.
truncated_normal_initializer
(
stddev
=
4
e-2
)
W_init
=
tf
.
truncated_normal_initializer
(
stddev
=
1
e-2
)
if
b_init
is
None
:
b_init
=
tf
.
constant_initializer
()
...
...
tensorpack/models/fc.py
View file @
96255c9a
...
...
@@ -17,7 +17,7 @@ def FullyConnected(x, out_dim, W_init=None, b_init=None, nl=tf.nn.relu):
in_dim
=
x
.
get_shape
()
.
as_list
()[
1
]
if
W_init
is
None
:
W_init
=
tf
.
truncated_normal_initializer
(
stddev
=
1
.0
/
math
.
sqrt
(
float
(
in_dim
)))
W_init
=
tf
.
truncated_normal_initializer
(
stddev
=
1
/
math
.
sqrt
(
float
(
in_dim
)))
if
b_init
is
None
:
b_init
=
tf
.
constant_initializer
(
0.0
)
...
...
tensorpack/models/image_sample.py
View file @
96255c9a
...
...
@@ -39,21 +39,24 @@ def sample(img, coords):
return
sampled
@
layer_register
()
def
ImageSample
(
template
,
mapping
):
def
ImageSample
(
inputs
):
"""
Sample the template image, using the given coordinate, by bilinear interpolation.
inputs: list of [template, mapping]
template: bxhxwxc
mapping: bxh2xw2x2 (y, x) real-value coordinates
Return: bxh2xw2xc
"""
template
,
mapping
=
inputs
assert
template
.
get_shape
()
.
ndims
==
4
and
mapping
.
get_shape
()
.
ndims
==
4
mapping
=
tf
.
maximum
(
mapping
,
0.0
)
tf
.
check_numerics
(
mapping
,
"mapping"
)
lcoor
=
tf
.
cast
(
mapping
,
tf
.
int32
)
# floor
ucoor
=
lcoor
+
1
# has to cast to int32 and then cast back
# XXX tf.floor have gradient 1 w.r.t input, bug or feature?
# tf.floor have gradient 1 w.r.t input
# TODO bug fixed in #951
diff
=
mapping
-
tf
.
cast
(
lcoor
,
tf
.
float32
)
neg_diff
=
1.0
-
diff
#bxh2xw2x2
...
...
@@ -128,7 +131,7 @@ if __name__ == '__main__':
mapping
[
0
,
y
,
x
,:]
=
np
.
array
([
y
-
diff
+
0.4
,
x
-
diff
+
0.5
])
mapv
=
tf
.
Variable
(
mapping
)
output
=
ImageSample
(
'sample'
,
imv
,
mapv
)
output
=
ImageSample
(
'sample'
,
[
imv
,
mapv
]
)
sess
=
tf
.
Session
()
sess
.
run
(
tf
.
initialize_all_variables
())
...
...
tensorpack/utils/modelutils.py
View file @
96255c9a
...
...
@@ -23,9 +23,9 @@ def describe_model():
def
get_shape_str
(
tensors
):
""" return the shape string for a tensor or a list of tensors"""
if
isinstance
(
tensors
,
list
):
if
isinstance
(
tensors
,
(
list
,
tuple
)
):
shape_str
=
","
.
join
(
map
(
str
(
x
.
get_shape
()
.
as_list
()),
tensors
))
map
(
lambda
x
:
str
(
x
.
get_shape
()
.
as_list
()),
tensors
))
else
:
shape_str
=
str
(
tensors
.
get_shape
()
.
as_list
())
return
shape_str
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment