Commit 0fa11e65 authored by Yuxin Wu's avatar Yuxin Wu

batchdatabyshape

parent c51df958
...@@ -43,8 +43,6 @@ Usage: ...@@ -43,8 +43,6 @@ Usage:
../../scripts/plot-point.py --legend 1,2,3,4,5,final --decay 0.8 ../../scripts/plot-point.py --legend 1,2,3,4,5,final --decay 0.8
""" """
BATCH_SIZE = 1
class Model(ModelDesc): class Model(ModelDesc):
def __init__(self, is_training=True): def __init__(self, is_training=True):
self.isTrain = is_training self.isTrain = is_training
...@@ -163,13 +161,13 @@ def get_data(name): ...@@ -163,13 +161,13 @@ def get_data(name):
imgaug.GaussianNoise(), imgaug.GaussianNoise(),
] ]
ds = AugmentImageComponent(ds, augmentors) ds = AugmentImageComponent(ds, augmentors)
ds = BatchData(ds, BATCH_SIZE, remainder=not isTrain) ds = BatchDataByShape(ds, 8, idx=0)
#if isTrain: if isTrain:
#ds = PrefetchDataZMQ(ds, 3) ds = PrefetchDataZMQ(ds, 1)
return ds return ds
def view_data(): def view_data():
ds = get_data('train') ds = RepeatedData(get_data('train'), -1)
ds.reset_state() ds.reset_state()
for ims, edgemaps in ds.get_data(): for ims, edgemaps in ds.get_data():
for im, edgemap in zip(ims, edgemaps): for im, edgemap in zip(ims, edgemaps):
......
...@@ -24,6 +24,7 @@ with much much fewer lines of code. ...@@ -24,6 +24,7 @@ with much much fewer lines of code.
It reaches 74.5% single-crop validation accuracy, slightly better than the official code, It reaches 74.5% single-crop validation accuracy, slightly better than the official code,
and has the same running speed as well. and has the same running speed as well.
The hyperparameters here are for 8 GPUs, so the effective batch size is 8*64 = 512. The hyperparameters here are for 8 GPUs, so the effective batch size is 8*64 = 512.
With 8 TitanX it runs about 0.45 it/s.
""" """
BATCH_SIZE = 64 BATCH_SIZE = 64
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
from __future__ import division from __future__ import division
import copy import copy
import numpy as np import numpy as np
from collections import deque from collections import deque, defaultdict
from six.moves import range, map from six.moves import range, map
from .base import DataFlow, ProxyDataFlow, RNGDataFlow from .base import DataFlow, ProxyDataFlow, RNGDataFlow
from ..utils import * from ..utils import *
...@@ -13,7 +13,7 @@ from ..utils import * ...@@ -13,7 +13,7 @@ from ..utils import *
__all__ = ['BatchData', 'FixedSizeData', 'MapData', __all__ = ['BatchData', 'FixedSizeData', 'MapData',
'RepeatedData', 'MapDataComponent', 'RandomChooseData', 'RepeatedData', 'MapDataComponent', 'RandomChooseData',
'RandomMixData', 'JoinData', 'ConcatData', 'SelectComponent', 'RandomMixData', 'JoinData', 'ConcatData', 'SelectComponent',
'LocallyShuffleData', 'TestDataSpeed'] 'LocallyShuffleData', 'TestDataSpeed', 'BatchDataByShape']
class TestDataSpeed(ProxyDataFlow): class TestDataSpeed(ProxyDataFlow):
def __init__(self, ds, size=1000): def __init__(self, ds, size=1000):
...@@ -70,7 +70,7 @@ class BatchData(ProxyDataFlow): ...@@ -70,7 +70,7 @@ class BatchData(ProxyDataFlow):
holder.append(data) holder.append(data)
if len(holder) == self.batch_size: if len(holder) == self.batch_size:
yield BatchData._aggregate_batch(holder) yield BatchData._aggregate_batch(holder)
holder = [] del holder[:]
if self.remainder and len(holder) > 0: if self.remainder and len(holder) > 0:
yield BatchData._aggregate_batch(holder) yield BatchData._aggregate_batch(holder)
...@@ -97,6 +97,34 @@ class BatchData(ProxyDataFlow): ...@@ -97,6 +97,34 @@ class BatchData(ProxyDataFlow):
IP.embed(config=IP.terminal.ipapp.load_default_config()) IP.embed(config=IP.terminal.ipapp.load_default_config())
return result return result
class BatchDataByShape(BatchData):
def __init__(self, ds, batch_size, idx):
""" Group datapoint of the same shape together to batches
:param ds: a DataFlow instance. Its component must be either a scalar or a numpy array
:param idx: dp[idx] will be used to group datapoints. Other component
in dp are assumed to have the same shape.
"""
super(BatchDataByShape, self).__init__(ds, batch_size, remainder=False)
self.idx = idx
def size(self):
raise NotImplementedError()
def reset_state(self):
super(BatchDataByShape, self).reset_state()
self.holder = defaultdict(list)
def get_data(self):
for dp in self.ds.get_data():
shp = dp[self.idx].shape
print(shp, len(self.holder))
holder = self.holder[shp]
holder.append(dp)
if len(holder) == self.batch_size:
yield BatchData._aggregate_batch(holder)
del holder[:]
class FixedSizeData(ProxyDataFlow): class FixedSizeData(ProxyDataFlow):
""" Generate data from another DataFlow, but with a fixed epoch size. """ Generate data from another DataFlow, but with a fixed epoch size.
The state of the underlying DataFlow is maintained among each epoch. The state of the underlying DataFlow is maintained among each epoch.
......
...@@ -100,7 +100,7 @@ class ILSVRC12(RNGDataFlow): ...@@ -100,7 +100,7 @@ class ILSVRC12(RNGDataFlow):
If is 'original' then keep the original decompressed dir with list If is 'original' then keep the original decompressed dir with list
of image files (as below). If equals to 'train', use the `train/` dir of image files (as below). If equals to 'train', use the `train/` dir
structure with class name as subdirectories. structure with class name as subdirectories.
:param include_bb: Include the bounding box. Useful in training. :param include_bb: Include the bounding box. Maybe useful in training.
Dir should have the following structure: Dir should have the following structure:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment