Shashank Suhas / seminar-breakout / Commits / ea72115e

Commit ea72115e, authored Feb 19, 2016 by Yuxin Wu

hack shape for batch normalization

parent 9952c6c6
Showing 3 changed files with 53 additions and 4 deletions (+53 −4):

requirements.txt                 (+2 −2)
tensorpack/dataflow/common.py    (+41 −2)
tensorpack/models/batch_norm.py  (+10 −0)
requirements.txt @ ea72115e

termcolor
numpy
protobuf~=3.0.0a1
pillow
scipy
tqdm
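These are the project's pinned dependencies; assuming a standard pip setup, they can be installed straight from the file, for example:

pip install -r requirements.txt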
tensorpack/dataflow/common.py @ ea72115e

@@ -8,6 +8,7 @@ from .base import DataFlow
from .imgaug import AugmentorList, Image

__all__ = ['BatchData', 'FixedSizeData', 'FakeData', 'MapData',
           'MapDataComponent', 'RandomChooseData', 'AugmentImageComponent']

class BatchData(DataFlow):

@@ -124,6 +125,19 @@ class FakeData(DataFlow):
            yield [np.random.random(k) for k in self.shapes]

class MapData(DataFlow):
    """ Map a function to the datapoint"""
    def __init__(self, ds, func):
        self.ds = ds
        self.func = func

    def size(self):
        return self.ds.size()

    def get_data(self):
        for dp in self.ds.get_data():
            yield self.func(dp)

class MapDataComponent(DataFlow):
    """ Apply a function to the given index in the datapoint"""
    def __init__(self, ds, func, index=0):
        self.ds = ds

@@ -138,6 +152,31 @@ class MapData(DataFlow):
            dp[self.index] = self.func(dp[self.index])
            yield dp

class RandomChooseData(DataFlow):
    """
    Randomly choose from several dataflow. Stop producing when any of its dataflow stops.
    """
    def __init__(self, df_lists):
        """
        df_lists: list of dataflow, or list of (dataflow, probability) tuple
        """
        if isinstance(df_lists[0], (tuple, list)):
            assert sum([v[1] for v in df_lists]) == 1.0
            self.df_lists = df_lists
        else:
            prob = 1.0 / len(df_lists)
            self.df_lists = [(k, prob) for k in df_lists]

    def get_data(self):
        itrs = [v[0].get_data() for v in self.df_lists]
        probs = np.array([v[1] for v in self.df_lists])
        try:
            while True:
                itr = np.random.choice(itrs, p=probs)
                yield next(itr)
        except StopIteration:
            return

def AugmentImageComponent(ds, augmentors, index=0):
    """
    Augment the image in each data point

@@ -146,9 +185,9 @@ def AugmentImageComponent(ds, augmentors, index=0):
        augmentors: a list of ImageAugmentor instance
        index: the index of image in each data point. default to be 0
    """
-    # TODO reset rng at the beginning of each get_data
+    # TODO reset rng at the beginning of each get_data
     aug = AugmentorList(augmentors)
-    return MapData(
+    return MapDataComponent(
         ds, lambda img: aug.augment(Image(img)).arr, index)
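The new classes above share one pattern: each wraps another dataflow and exposes size()/get_data() over it. As a rough illustration of the weighted sampling that RandomChooseData.get_data() performs, here is a minimal standalone sketch; random_choose and the toy generators a/b are hypothetical names for this example, not part of tensorpack.

import numpy as np

def random_choose(generators, probs):
    """Repeatedly pick one generator at random (with the given
    probabilities) and yield its next element; stop as soon as any
    generator is exhausted, mirroring RandomChooseData above."""
    itrs = [iter(g) for g in generators]
    probs = np.asarray(probs, dtype=float)
    assert abs(probs.sum() - 1.0) < 1e-8
    while True:
        idx = np.random.choice(len(itrs), p=probs)
        try:
            yield next(itrs[idx])
        except StopIteration:
            return

# Toy usage: mix two finite streams, favouring the first 70/30.
a = (('a', i) for i in range(5))
b = (('b', i) for i in range(5))
for dp in random_choose([a, b], [0.7, 0.3]):
    print(dp)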
tensorpack/models/batch_norm.py @ ea72115e

@@ -4,6 +4,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

import tensorflow as tf
from copy import copy
from ._common import layer_register

@@ -37,7 +38,16 @@ def BatchNorm(x, is_training, gamma_init=1.0):
    beta = tf.get_variable('beta', [n_out])
    gamma = tf.get_variable('gamma', [n_out],
                            initializer=tf.constant_initializer(gamma_init))

    # XXX hack to clear shape. see tensorflow#1162
    if shape[0] is not None:
        x = tf.tile(x, tf.pack([1, 1, 1, 1]))
        hack_shape = copy(shape)
        hack_shape[0] = None
        x.set_shape(hack_shape)

    batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], name='moments')
    print batch_mean
    ema = tf.train.ExponentialMovingAverage(decay=0.999)
    ema_apply_op = ema.apply([batch_mean, batch_var])
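The shape hack in the second hunk tiles x by [1, 1, 1, 1] and then re-sets its static shape with the batch dimension cleared to None, apparently working around a tf.nn.moments shape issue of the time (tensorflow#1162). The statistics themselves are a per-channel mean and variance over the N, H, W axes plus an exponential moving average; a minimal NumPy sketch of that computation follows, assuming NHWC inputs. batch_norm_stats is a hypothetical helper for illustration, not tensorpack code.

import numpy as np

def batch_norm_stats(x, ema_mean, ema_var, decay=0.999, eps=1e-5):
    # Per-channel statistics over the N, H, W axes of an NHWC batch,
    # the same reduction as tf.nn.moments(x, [0, 1, 2]) in the diff above.
    batch_mean = x.mean(axis=(0, 1, 2))
    batch_var = x.var(axis=(0, 1, 2))
    # Running estimates, analogous to what
    # tf.train.ExponentialMovingAverage(decay=0.999) maintains.
    ema_mean = decay * ema_mean + (1.0 - decay) * batch_mean
    ema_var = decay * ema_var + (1.0 - decay) * batch_var
    # Normalise with the current batch statistics (training mode).
    x_hat = (x - batch_mean) / np.sqrt(batch_var + eps)
    return x_hat, ema_mean, ema_var

x = np.random.randn(8, 4, 4, 3)            # N=8, H=W=4, C=3
x_hat, m, v = batch_norm_stats(x, np.zeros(3), np.ones(3))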