Commit 2c8cd32f authored by Yuxin Wu's avatar Yuxin Wu

shuffle imagefromfile

parent 6221d17d
...@@ -5,14 +5,14 @@ ...@@ -5,14 +5,14 @@
import numpy as np import numpy as np
import cv2 import cv2
import copy import copy
from .base import DataFlow, ProxyDataFlow from .base import RNGDataFlow, DataFlow, ProxyDataFlow
from .common import MapDataComponent, MapData from .common import MapDataComponent, MapData
from .imgaug import AugmentorList from .imgaug import AugmentorList
__all__ = ['ImageFromFile', 'AugmentImageComponent', 'AugmentImageComponents'] __all__ = ['ImageFromFile', 'AugmentImageComponent', 'AugmentImageComponents']
class ImageFromFile(DataFlow): class ImageFromFile(RNGDataFlow):
def __init__(self, files, channel=3, resize=None): def __init__(self, files, channel=3, resize=None, shuffle=False):
""" """
Generate rgb images from list of files Generate rgb images from list of files
:param files: list of file paths :param files: list of file paths
...@@ -23,11 +23,14 @@ class ImageFromFile(DataFlow): ...@@ -23,11 +23,14 @@ class ImageFromFile(DataFlow):
self.files = files self.files = files
self.channel = int(channel) self.channel = int(channel)
self.resize = resize self.resize = resize
self.shuffle = shuffle
def size(self): def size(self):
return len(self.files) return len(self.files)
def get_data(self): def get_data(self):
if self.shuffle:
self.rng.shuffle(self.files)
for f in self.files: for f in self.files:
im = cv2.imread( im = cv2.imread(
f, cv2.IMREAD_GRAYSCALE if self.channel == 1 else cv2.IMREAD_COLOR) f, cv2.IMREAD_GRAYSCALE if self.channel == 1 else cv2.IMREAD_COLOR)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment