Commit b315a1a7 authored by Yuxin Wu's avatar Yuxin Wu

update resnet readme, add basic vqa

parent f67dc2a0
......@@ -59,3 +59,4 @@ docs/_build/
# PyBuilder
target/
*.dat
......@@ -4,4 +4,7 @@
Implements the paper "Deep Residual Learning for Image Recognition", [http://arxiv.org/abs/1512.03385](http://arxiv.org/abs/1512.03385)
with the variants proposed in "Identity Mappings in Deep Residual Networks", [https://arxiv.org/abs/1603.05027](https://arxiv.org/abs/1603.05027).
The train error shown here is a moving average of the error rate of each batch in training.
The validation error here is computed on the test set.
![cifar10](https://github.com/ppwwyyxx/tensorpack/raw/master/examples/ResNet/cifar10-resnet.png)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: visualqa.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from ..base import DataFlow
from six.moves import zip, map
from collections import Counter
import json
__all__ = ['VisualQA']
# TODO shuffle
class VisualQA(DataFlow):
    """
    Visual QA dataset. See http://visualqa.org/
    Simply read q/a json file and produce q/a pairs in their original format.
    """
    def __init__(self, question_file, annotation_file):
        """
        Args:
            question_file: path to a VQA questions json file.
            annotation_file: path to the matching VQA annotations json file.

        Raises:
            AssertionError: if the two files contain a different number of entries.
        """
        # use context managers so the file handles are closed promptly
        # (the original json.load(open(...)) leaked them)
        with open(question_file) as f:
            qobj = json.load(f)
        self.task_type = qobj['task_type']
        self.questions = qobj['questions']
        self._size = len(self.questions)

        with open(annotation_file) as f:
            aobj = json.load(f)
        self.anno = aobj['annotations']
        assert len(self.anno) == len(self.questions), \
            "{}!={}".format(len(self.anno), len(self.questions))
        self._clean()

    def _clean(self):
        # strip per-answer ids; downstream consumers only need the answer text
        for a in self.anno:
            for aa in a['answers']:
                del aa['answer_id']

    def size(self):
        """Number of question/annotation pairs."""
        return self._size

    def get_data(self):
        """Yield ``[question, annotation]`` pairs, in original json format."""
        for q, a in zip(self.questions, self.anno):
            # questions and annotations are expected to be aligned by index
            assert q['question_id'] == a['question_id']
            yield [q, a]

    def get_common_answer(self, n):
        """ Get the n most common answers (could be phrases) """
        cnt = Counter()
        for anno in self.anno:
            cnt[anno['multiple_choice_answer']] += 1
        return [k[0] for k in cnt.most_common(n)]

    def get_common_question_words(self, n):
        """
        Get the n most common words in questions
        """
        from nltk.tokenize import word_tokenize  # will need to download 'punkt'
        cnt = Counter()
        for q in self.questions:
            cnt.update(word_tokenize(q['question'].lower()))
        # '?' is a token, not a word; Counter.__delitem__ is a no-op if absent
        del cnt['?']
        ret = cnt.most_common(n)
        return [k[0] for k in ret]
if __name__ == '__main__':
    # Smoke test: load the train split and exercise the readers.
    vqa = VisualQA('/home/wyx/data/VQA/MultipleChoice_mscoco_train2014_questions.json',
                   '/home/wyx/data/VQA/mscoco_train2014_annotations.json')
    # pull a single pair to check that iteration works
    for pair in vqa.get_data():
        break
    vqa.get_common_question_words(100)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment