Commit 0fa11e65 authored by Yuxin Wu's avatar Yuxin Wu

batchdatabyshape

parent c51df958
......@@ -43,8 +43,6 @@ Usage:
../../scripts/plot-point.py --legend 1,2,3,4,5,final --decay 0.8
"""
BATCH_SIZE = 1
class Model(ModelDesc):
def __init__(self, is_training=True):
self.isTrain = is_training
......@@ -163,13 +161,13 @@ def get_data(name):
imgaug.GaussianNoise(),
]
ds = AugmentImageComponent(ds, augmentors)
ds = BatchData(ds, BATCH_SIZE, remainder=not isTrain)
#if isTrain:
#ds = PrefetchDataZMQ(ds, 3)
ds = BatchDataByShape(ds, 8, idx=0)
if isTrain:
ds = PrefetchDataZMQ(ds, 1)
return ds
def view_data():
ds = get_data('train')
ds = RepeatedData(get_data('train'), -1)
ds.reset_state()
for ims, edgemaps in ds.get_data():
for im, edgemap in zip(ims, edgemaps):
......
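The switch from BatchData (with BATCH_SIZE = 1) to BatchDataByShape matters here because the HED example trains on full-size BSDS images, which come in both landscape and portrait orientation and therefore do not all share one shape; stacking mixed shapes into a single numpy batch fails. A minimal sketch of the failure mode and of the shape-grouping idea (the toy arrays below are illustrative, not taken from the example):

import numpy as np

# Two "images" of different shapes, as BSDS provides both 321x481 and 481x321 images.
a = np.zeros((321, 481, 3), dtype='float32')
b = np.zeros((481, 321, 3), dtype='float32')

# Fixed-size batching must stack each component into one array, which fails for mixed shapes:
try:
    np.array([a, b], dtype='float32')
except ValueError as err:
    print("cannot batch mixed shapes:", err)

# Grouping by shape first (the idea behind BatchDataByShape) keeps every batch uniform:
groups = {}
for img in (a, b, a, b):
    groups.setdefault(img.shape, []).append(img)
print([np.array(g).shape for g in groups.values()])  # [(2, 321, 481, 3), (2, 481, 321, 3)]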
......@@ -24,6 +24,7 @@ with much much fewer lines of code.
It reaches 74.5% single-crop validation accuracy, slightly better than the official code,
and has the same running speed as well.
The hyperparameters here are for 8 GPUs, so the effective batch size is 8*64 = 512.
With 8 TitanX it runs about 0.45 it/s.
"""
BATCH_SIZE = 64
......
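For reference, at an effective batch size of 512, 0.45 it/s corresponds to roughly 0.45 * 512 ≈ 230 images/s across the 8 GPUs.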
......@@ -5,7 +5,7 @@
from __future__ import division
import copy
import numpy as np
from collections import deque
from collections import deque, defaultdict
from six.moves import range, map
from .base import DataFlow, ProxyDataFlow, RNGDataFlow
from ..utils import *
......@@ -13,7 +13,7 @@ from ..utils import *
__all__ = ['BatchData', 'FixedSizeData', 'MapData',
'RepeatedData', 'MapDataComponent', 'RandomChooseData',
'RandomMixData', 'JoinData', 'ConcatData', 'SelectComponent',
'LocallyShuffleData', 'TestDataSpeed']
'LocallyShuffleData', 'TestDataSpeed', 'BatchDataByShape']
class TestDataSpeed(ProxyDataFlow):
def __init__(self, ds, size=1000):
......@@ -70,7 +70,7 @@ class BatchData(ProxyDataFlow):
holder.append(data)
if len(holder) == self.batch_size:
yield BatchData._aggregate_batch(holder)
holder = []
del holder[:]
if self.remainder and len(holder) > 0:
yield BatchData._aggregate_batch(holder)
......@@ -97,6 +97,34 @@ class BatchData(ProxyDataFlow):
IP.embed(config=IP.terminal.ipapp.load_default_config())
return result
class BatchDataByShape(BatchData):
def __init__(self, ds, batch_size, idx):
""" Group datapoint of the same shape together to batches
:param ds: a DataFlow instance. Its component must be either a scalar or a numpy array
:param idx: dp[idx] will be used to group datapoints. Other component
in dp are assumed to have the same shape.
"""
super(BatchDataByShape, self).__init__(ds, batch_size, remainder=False)
self.idx = idx
def size(self):
raise NotImplementedError()
def reset_state(self):
super(BatchDataByShape, self).reset_state()
self.holder = defaultdict(list)
def get_data(self):
for dp in self.ds.get_data():
shp = dp[self.idx].shape
holder = self.holder[shp]
holder.append(dp)
if len(holder) == self.batch_size:
yield BatchData._aggregate_batch(holder)
del holder[:]
class FixedSizeData(ProxyDataFlow):
""" Generate data from another DataFlow, but with a fixed epoch size.
The state of the underlying DataFlow is maintained among each epoch.
......
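A usage sketch of the new class (the ToyImages flow and its shapes are made up for illustration; the import path is assumed from the package layout): datapoints are kept in per-shape holders and a batch is emitted whenever one holder reaches batch_size, which is also why size() is left unimplemented and why leftover datapoints that never fill a batch for their shape are simply dropped (remainder is hardcoded to False and there is no end-of-epoch flush).

import numpy as np
from tensorpack.dataflow import DataFlow, BatchDataByShape  # assumed import path

class ToyImages(DataFlow):
    """Yields [image, label] datapoints whose image shapes alternate."""
    def get_data(self):
        shapes = [(32, 48, 3), (48, 32, 3)]
        for i in range(8):
            yield [np.zeros(shapes[i % 2], dtype='float32'), i]

ds = BatchDataByShape(ToyImages(), batch_size=2, idx=0)
ds.reset_state()                    # required: initializes the per-shape holders
for images, labels in ds.get_data():
    print(images.shape, labels)     # each batch stacks images of one shape only,
                                    # alternating (2, 32, 48, 3) and (2, 48, 32, 3)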
......@@ -100,7 +100,7 @@ class ILSVRC12(RNGDataFlow):
If 'original', keep the original decompressed dir with the list
of image files (as below). If 'train', use the `train/` dir
structure with class names as subdirectories.
:param include_bb: Include the bounding box. Useful in training.
:param include_bb: Include the bounding box. Maybe useful in training.
Dir should have the following structure:
......