Commit 5a0fe4e1 authored by Yuxin Wu's avatar Yuxin Wu

remove some deprecations

parent 8a221643
......@@ -24,13 +24,14 @@ os.environ['TENSORPACK_DOC_BUILDING'] = '1'
ON_RTD = (os.environ.get('READTHEDOCS') == 'True')
MOCK_MODULES = ['scipy', 'tabulate',
'sklearn.datasets', 'sklearn',
'scipy.misc', 'h5py', 'nltk',
'cv2', 'scipy.io', 'dill', 'zmq', 'subprocess32', 'lmdb',
'tornado.concurrent', 'tornado',
MOCK_MODULES = ['tabulate', 'h5py',
'cv2', 'zmq', 'subprocess32', 'lmdb',
'sklearn', 'sklearn.datasets',
'scipy', 'scipy.misc', 'scipy.io',
'tornado', 'tornado.concurrent',
'horovod', 'horovod.tensorflow',
'msgpack', 'msgpack_numpy',
'gym', 'functools32', 'horovod', 'horovod.tensorflow']
'functools32']
for mod_name in MOCK_MODULES:
sys.modules[mod_name] = mock.Mock(name=mod_name)
sys.modules['cv2'].__version__ = '3.2.1' # fake version
......@@ -364,16 +365,10 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
# hide deprecated stuff
if name in [
'MultiGPUTrainerBase',
'FeedfreeInferenceRunner',
'replace_get_variable',
'remap_get_variable',
'freeze_get_variable',
'predictor_factory',
'get_predictors',
'RandomCropAroundBox',
'GaussianDeform',
'dump_chkpt_vars',
'VisualQA',
'DumpTensor',
'StagingInputWrapper',
'StepTensorPrinter',
......
......@@ -14,7 +14,6 @@ from six.moves import range
from ..utils import logger
from ..utils.utils import get_tqdm_kwargs
from ..utils.develop import deprecated
from ..dataflow.base import DataFlow
from ..input_source import (
......@@ -25,7 +24,7 @@ from .base import Callback
from .group import Callbacks
from .inference import Inferencer
__all__ = ['InferenceRunner', 'FeedfreeInferenceRunner',
__all__ = ['InferenceRunner',
'DataParallelInferenceRunner']
......@@ -165,11 +164,6 @@ class InferenceRunner(InferenceRunnerBase):
inf.trigger_epoch()
@deprecated("Just use InferenceRunner since it now accepts TensorInput!", "2017-11-11")
def FeedfreeInferenceRunner(*args, **kwargs):
return InferenceRunner(*args, **kwargs)
class DataParallelInferenceRunner(InferenceRunnerBase):
"""
Inference with data-parallel support on multiple GPUs.
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: visualqa.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from ..base import DataFlow
from ...utils.timer import timed_operation
from ...utils import logger
from six.moves import zip, map
from collections import Counter
import json
__all__ = []
def read_json(fname):
f = open(fname)
ret = json.load(f)
f.close()
return ret
class VisualQA(DataFlow):
"""
`Visual QA <http://visualqa.org/>`_ dataset.
It simply reads q/a json file and produce q/a pairs in their original format.
"""
def __init__(self, question_file, annotation_file):
logger.warn("dataset.VisualQA is deprecated!")
with timed_operation('Reading VQA JSON file'):
qobj, aobj = list(map(read_json, [question_file, annotation_file]))
self.task_type = qobj['task_type']
self.questions = qobj['questions']
self._size = len(self.questions)
self.anno = aobj['annotations']
assert len(self.anno) == len(self.questions), \
"{}!={}".format(len(self.anno), len(self.questions))
self._clean()
def _clean(self):
for a in self.anno:
for aa in a['answers']:
del aa['answer_id']
def size(self):
return self._size
def get_data(self):
for q, a in zip(self.questions, self.anno):
assert q['question_id'] == a['question_id']
yield [q, a]
def get_common_answer(self, n):
""" Get the n most common answers (could be phrases)
n=3000 ~= thresh 4
"""
cnt = Counter()
for anno in self.anno:
cnt[anno['multiple_choice_answer'].lower()] += 1
return [k[0] for k in cnt.most_common(n)]
def get_common_question_words(self, n):
""" Get the n most common words in questions
n=4600 ~= thresh 6
"""
from nltk.tokenize import word_tokenize # will need to download 'punckt'
cnt = Counter()
for q in self.questions:
cnt.update(word_tokenize(q['question'].lower()))
del cnt['?'] # probably don't need this
ret = cnt.most_common(n)
return [k[0] for k in ret]
if __name__ == '__main__':
vqa = VisualQA('/home/wyx/data/VQA/MultipleChoice_mscoco_train2014_questions.json',
'/home/wyx/data/VQA/mscoco_train2014_annotations.json')
for k in vqa.get_data():
print(json.dumps(k))
break
vqa.get_common_answer(100)
......@@ -6,11 +6,7 @@
import tensorflow as tf
from contextlib import contextmanager
from ..utils.develop import deprecated
__all__ = ['replace_get_variable',
'freeze_variables', 'freeze_get_variable', 'remap_get_variable',
'remap_variables']
__all__ = ['freeze_variables', 'remap_variables']
@contextmanager
......@@ -20,19 +16,6 @@ def custom_getter_scope(custom_getter):
yield
@deprecated("Use custom_getter_scope instead.", "2017-11-06")
def replace_get_variable(fn):
"""
Args:
fn: a function compatible with ``tf.get_variable``.
Returns:
a context with a custom getter
"""
def getter(_, *args, **kwargs):
return fn(*args, **kwargs)
return custom_getter_scope(getter)
def remap_variables(fn):
"""
Use fn to map the output of any variable getter.
......@@ -74,13 +57,3 @@ def freeze_variables():
v = tf.stop_gradient(v)
return v
return custom_getter_scope(custom_getter)
@deprecated("Renamed to remap_variables", "2017-11-06")
def remap_get_variable():
return remap_variables()
@deprecated("Renamed to freeze_variables", "2017-11-06")
def freeze_get_variable():
return freeze_variables()
......@@ -256,6 +256,7 @@ class HorovodTrainer(SingleCostTrainer):
cb.chief_only = False
return [cb]
@HIDE_DOC
def initialize(self, session_creator, session_init):
if not isinstance(session_creator, NewSessionCreator):
raise ValueError(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment