Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
d5fe531d
Commit
d5fe531d
authored
Apr 22, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
lut & update vqa
parent
b315a1a7
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
37 additions
and
17 deletions
+37
-17
scripts/dump_train_config.py
scripts/dump_train_config.py
+1
-0
tensorpack/dataflow/common.py
tensorpack/dataflow/common.py
+1
-1
tensorpack/dataflow/dataset/visualqa.py
tensorpack/dataflow/dataset/visualqa.py
+19
-14
tensorpack/dataflow/prefetch.py
tensorpack/dataflow/prefetch.py
+1
-0
tensorpack/utils/lut.py
tensorpack/utils/lut.py
+13
-0
tensorpack/utils/utils.py
tensorpack/utils/utils.py
+2
-2
No files found.
scripts/dump_train_config.py
View file @
d5fe531d
...
...
@@ -49,6 +49,7 @@ with tqdm.tqdm(total=NR_DP_TEST, leave=True, unit='data points') as pbar:
if
idx
>
NR_DP_TEST
:
break
pbar
.
update
()
from
IPython
import
embed
;
embed
()
tensorpack/dataflow/common.py
View file @
d5fe531d
...
...
@@ -16,7 +16,7 @@ class BatchData(ProxyDataFlow):
"""
Group data in `ds` into batches.
:param ds: a DataFlow instance
:param ds: a DataFlow instance
. Its component must be either a scalar or a numpy array
:param remainder: whether to return the remaining data smaller than a batch_size.
If set True, will possibly return a data point of a smaller 1st dimension.
Otherwise, all generated data are guaranteed to have the same size.
...
...
tensorpack/dataflow/dataset/visualqa.py
View file @
d5fe531d
...
...
@@ -4,6 +4,7 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from
..base
import
DataFlow
from
...utils
import
*
from
six.moves
import
zip
,
map
from
collections
import
Counter
import
json
...
...
@@ -17,16 +18,17 @@ class VisualQA(DataFlow):
Simply read q/a json file and produce q/a pairs in their original format.
"""
def
__init__
(
self
,
question_file
,
annotation_file
):
qobj
=
json
.
load
(
open
(
question_file
))
self
.
task_type
=
qobj
[
'task_type'
]
self
.
questions
=
qobj
[
'questions'
]
self
.
_size
=
len
(
self
.
questions
)
with
timed_operation
(
'Reading VQA JSON file'
):
qobj
=
json
.
load
(
open
(
question_file
))
self
.
task_type
=
qobj
[
'task_type'
]
self
.
questions
=
qobj
[
'questions'
]
self
.
_size
=
len
(
self
.
questions
)
aobj
=
json
.
load
(
open
(
annotation_file
))
self
.
anno
=
aobj
[
'annotations'
]
assert
len
(
self
.
anno
)
==
len
(
self
.
questions
),
\
"{}!={}"
.
format
(
len
(
self
.
anno
),
len
(
self
.
questions
))
self
.
_clean
()
aobj
=
json
.
load
(
open
(
annotation_file
))
self
.
anno
=
aobj
[
'annotations'
]
assert
len
(
self
.
anno
)
==
len
(
self
.
questions
),
\
"{}!={}"
.
format
(
len
(
self
.
anno
),
len
(
self
.
questions
))
self
.
_clean
()
def
_clean
(
self
):
for
a
in
self
.
anno
:
...
...
@@ -42,15 +44,17 @@ class VisualQA(DataFlow):
yield
[
q
,
a
]
def
get_common_answer
(
self
,
n
):
""" Get the n most common answers (could be phrases) """
""" Get the n most common answers (could be phrases)
n=3000 ~= thresh 4
"""
cnt
=
Counter
()
for
anno
in
self
.
anno
:
cnt
[
anno
[
'multiple_choice_answer'
]]
+=
1
return
[
k
[
0
]
for
k
in
cnt
.
most_common
(
n
)]
def
get_common_question_words
(
self
,
n
):
"""
Get the n most common words in questions
"""
Get the n most common words in questions
n=4600 ~= thresh 6
"""
from
nltk.tokenize
import
word_tokenize
# will need to download 'punkt'
cnt
=
Counter
()
...
...
@@ -64,7 +68,8 @@ if __name__ == '__main__':
vqa
=
VisualQA
(
'/home/wyx/data/VQA/MultipleChoice_mscoco_train2014_questions.json'
,
'/home/wyx/data/VQA/mscoco_train2014_annotations.json'
)
for
k
in
vqa
.
get_data
():
#
print json.dumps(k)
print
json
.
dumps
(
k
)
break
vqa
.
get_common_question_words
(
100
)
# vqa.get_common_question_words(100)
vqa
.
get_common_answer
(
100
)
#from IPython import embed; embed()
tensorpack/dataflow/prefetch.py
View file @
d5fe531d
...
...
@@ -7,6 +7,7 @@ import multiprocessing
from
six.moves
import
range
from
.base
import
ProxyDataFlow
from
..utils.concurrency
import
ensure_procs_terminate
from
..utils
import
logger
__all__
=
[
'PrefetchData'
]
...
...
tensorpack/utils/lut.py
0 → 100644
View file @
d5fe531d
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: lut.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import
six
__all__
=
[
'LookUpTable'
]
class LookUpTable(object):
    """
    Two-way table between integer positions and the objects given at construction.

    ``idx2obj`` maps each position in the input sequence to the object found there;
    ``obj2idx`` is the inverse mapping, from object back to position.
    Objects are used as dictionary keys, so they must be hashable, and an
    object that occurs more than once keeps only a single entry in ``obj2idx``.
    """

    def __init__(self, objlist):
        """
        :param objlist: an iterable of objects; positions are assigned by ``enumerate``,
            starting from 0.
        """
        # Forward direction: position -> object.
        forward = {}
        for position, item in enumerate(objlist):
            forward[position] = item
        self.idx2obj = forward

        # Reverse direction, derived from the forward table rather than from
        # `objlist` again, so that `objlist` is only iterated once.
        backward = {}
        for position, item in six.iteritems(forward):
            backward[item] = position
        self.obj2idx = backward
tensorpack/utils/utils.py
View file @
d5fe531d
...
...
@@ -11,8 +11,8 @@ import numpy as np
from
.
import
logger
__all__
=
[
'timed_operation'
,
'change_env'
,
'get_rng'
,
'memoized'
,
'get_nr_gpu'
]
__all__
=
[
'timed_operation'
,
'change_env'
,
'get_rng'
,
'memoized'
,
'get_nr_gpu'
]
#def expand_dim_if_necessary(var, dp):
# """
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment