Commit 12f78b94 authored by Yuxin Wu's avatar Yuxin Wu

support int in add_moving_summary.

parent c6a07ecd
...@@ -94,6 +94,7 @@ class AugmentImageComponent(MapDataComponent): ...@@ -94,6 +94,7 @@ class AugmentImageComponent(MapDataComponent):
class AugmentImageCoordinates(MapData): class AugmentImageCoordinates(MapData):
""" """
Apply image augmentors on an image and a list of coordinates. Apply image augmentors on an image and a list of coordinates.
Coordinates must be a Nx2 floating point array, each row is (x, y).
""" """
def __init__(self, ds, augmentors, img_index=0, coords_index=1, copy=True): def __init__(self, ds, augmentors, img_index=0, coords_index=1, copy=True):
""" """
...@@ -116,6 +117,9 @@ class AugmentImageCoordinates(MapData): ...@@ -116,6 +117,9 @@ class AugmentImageCoordinates(MapData):
def func(dp): def func(dp):
try: try:
img, coords = dp[img_index], dp[coords_index] img, coords = dp[img_index], dp[coords_index]
assert coords.ndim == 2, coords.ndim
assert coords.shape[1] == 2, coords.shape
assert np.issubdtype(coords.dtype, np.float), coords.dtype
if copy: if copy:
img, coords = copy_mod.deepcopy((img, coords)) img, coords = copy_mod.deepcopy((img, coords))
img, prms = self.augs._augment_return_params(img) img, prms = self.augs._augment_return_params(img)
......
...@@ -30,7 +30,7 @@ class Flip(ImageAugmentor): ...@@ -30,7 +30,7 @@ class Flip(ImageAugmentor):
elif vert: elif vert:
self.code = 0 self.code = 0
else: else:
raise ValueError("Are you kidding?") raise ValueError("At least one of horiz or vert has to be True!")
self.prob = prob self.prob = prob
self._init() self._init()
......
...@@ -197,6 +197,8 @@ def add_moving_summary(*args, **kwargs): ...@@ -197,6 +197,8 @@ def add_moving_summary(*args, **kwargs):
for c in v: for c in v:
name = re.sub('tower[0-9]+/', '', c.op.name) name = re.sub('tower[0-9]+/', '', c.op.name)
with G.colocate_with(c), tf.name_scope(None): with G.colocate_with(c), tf.name_scope(None):
if not c.dtype.is_floating:
c = tf.cast(c, tf.float32)
# assign_moving_average creates variables with op names, therefore clear ns first. # assign_moving_average creates variables with op names, therefore clear ns first.
with _enter_vs_reuse_ns('EMA') as vs: with _enter_vs_reuse_ns('EMA') as vs:
ema_var = tf.get_variable(name, shape=c.shape, dtype=c.dtype, ema_var = tf.get_variable(name, shape=c.shape, dtype=c.dtype,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment