Commit 6de6f3a4 authored by Meet Narendra's avatar Meet Narendra 💬

get annot

parent 76b62e21
from logger import Logger
import torch
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
LOGGER = Logger().logger()
#Author: @meetdoshi
from Utils import device
from data_utils import PinDataset
class DataLoader():
def __init__(self):
self.pin_dataset = PinDataset(partitions=range(1))
self.pin_dataset.load_meta_data()
self.pidxs = self.pin_dataset.get_pidxs()
return
def get_sample_image(self):
'''
Function to get image,anno
'''
pidx = np.random.choice(self.pidxs)
print(pidx,type(pidx))
image, anno = self.pin_dataset.get_annotation(pin_id=pidx)
return image,anno
def get_all_images(self):
'''
Function to get all images
'''
images = []
annotations = []
for pidx in self.pidxs:
print(pidx)
image,anno = self.pin_dataset.get_annotation(pin_id=pidx)
images.append(image)
annotations.append(anno)
return images,annotations
'''
if __name__ == "__main__":
data_loader = DataLoader()
image,anno = data_loader.get_sample_image()
print(image,type(image))
print(anno)
'''
......@@ -148,7 +148,7 @@ class PinDataset(object):
annos = []
pidxs = {}
for meta_path in self.meta_paths:
annos.extend(np.load(meta_path).tolist())
annos.extend(np.load(meta_path,allow_pickle=True,encoding='latin1').tolist())
print('Meta data loaded, contains %d images.' % len(annos))
for (ind_a, anno) in enumerate(annos):
pidx = self._get_pidx_from_image_name(anno['image_name'])
......@@ -161,7 +161,7 @@ class PinDataset(object):
def get_pidxs(self, max_return_num=None):
"""Return the list of pinterest image idx in the meta data."""
assert self.annos is not None, 'Load meta data first, call load_meta_data()'
pidxs_list = list(self.pidxs.iterkeys())
pidxs_list = list(i for i in self.pidxs)
if max_return_num is not None:
num = min(len(pidxs_list), max_return_num)
else:
......
......@@ -11,7 +11,8 @@ from __future__ import division
from __future__ import print_function
from data_utils import PinDataset
'''
if __name__ == '__main__':
pin_dataset = PinDataset()
pin_dataset.download_images()
'''
......@@ -2,6 +2,7 @@ import numpy as np
import torch
import torch.nn as nn
from logger import Logger
from data_loader import DataLoader
LOGGER = Logger().logger()
class Model():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment