Commit 6de6f3a4 authored by Meet Narendra's avatar Meet Narendra 💬

get annot

parent 76b62e21
from logger import Logger
import torch
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
LOGGER = Logger().logger()
#Author: @meetdoshi
from Utils import device
from data_utils import PinDataset
class DataLoader():
def __init__(self):
self.pin_dataset = PinDataset(partitions=range(1))
self.pin_dataset.load_meta_data()
self.pidxs = self.pin_dataset.get_pidxs()
return
def get_sample_image(self):
'''
Function to get image,anno
'''
pidx = np.random.choice(self.pidxs)
print(pidx,type(pidx))
image, anno = self.pin_dataset.get_annotation(pin_id=pidx)
return image,anno
def get_all_images(self):
'''
Function to get all images
'''
images = []
annotations = []
for pidx in self.pidxs:
print(pidx)
image,anno = self.pin_dataset.get_annotation(pin_id=pidx)
images.append(image)
annotations.append(anno)
return images,annotations
'''
if __name__ == "__main__":
data_loader = DataLoader()
image,anno = data_loader.get_sample_image()
print(image,type(image))
print(anno)
'''
...@@ -148,7 +148,7 @@ class PinDataset(object): ...@@ -148,7 +148,7 @@ class PinDataset(object):
annos = [] annos = []
pidxs = {} pidxs = {}
for meta_path in self.meta_paths: for meta_path in self.meta_paths:
annos.extend(np.load(meta_path).tolist()) annos.extend(np.load(meta_path,allow_pickle=True,encoding='latin1').tolist())
print('Meta data loaded, contains %d images.' % len(annos)) print('Meta data loaded, contains %d images.' % len(annos))
for (ind_a, anno) in enumerate(annos): for (ind_a, anno) in enumerate(annos):
pidx = self._get_pidx_from_image_name(anno['image_name']) pidx = self._get_pidx_from_image_name(anno['image_name'])
...@@ -161,7 +161,7 @@ class PinDataset(object): ...@@ -161,7 +161,7 @@ class PinDataset(object):
def get_pidxs(self, max_return_num=None): def get_pidxs(self, max_return_num=None):
"""Return the list of pinterest image idx in the meta data.""" """Return the list of pinterest image idx in the meta data."""
assert self.annos is not None, 'Load meta data first, call load_meta_data()' assert self.annos is not None, 'Load meta data first, call load_meta_data()'
pidxs_list = list(self.pidxs.iterkeys()) pidxs_list = list(i for i in self.pidxs)
if max_return_num is not None: if max_return_num is not None:
num = min(len(pidxs_list), max_return_num) num = min(len(pidxs_list), max_return_num)
else: else:
......
...@@ -11,7 +11,8 @@ from __future__ import division ...@@ -11,7 +11,8 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from data_utils import PinDataset from data_utils import PinDataset
'''
if __name__ == '__main__': if __name__ == '__main__':
pin_dataset = PinDataset() pin_dataset = PinDataset()
pin_dataset.download_images() pin_dataset.download_images()
'''
...@@ -2,6 +2,7 @@ import numpy as np ...@@ -2,6 +2,7 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from logger import Logger from logger import Logger
from data_loader import DataLoader
LOGGER = Logger().logger() LOGGER = Logger().logger()
class Model(): class Model():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment