Commit 5a0fe4e1 authored by Yuxin Wu's avatar Yuxin Wu

remove some deprecations

parent 8a221643
...@@ -24,13 +24,14 @@ os.environ['TENSORPACK_DOC_BUILDING'] = '1' ...@@ -24,13 +24,14 @@ os.environ['TENSORPACK_DOC_BUILDING'] = '1'
ON_RTD = (os.environ.get('READTHEDOCS') == 'True') ON_RTD = (os.environ.get('READTHEDOCS') == 'True')
MOCK_MODULES = ['scipy', 'tabulate', MOCK_MODULES = ['tabulate', 'h5py',
'sklearn.datasets', 'sklearn', 'cv2', 'zmq', 'subprocess32', 'lmdb',
'scipy.misc', 'h5py', 'nltk', 'sklearn', 'sklearn.datasets',
'cv2', 'scipy.io', 'dill', 'zmq', 'subprocess32', 'lmdb', 'scipy', 'scipy.misc', 'scipy.io',
'tornado.concurrent', 'tornado', 'tornado', 'tornado.concurrent',
'horovod', 'horovod.tensorflow',
'msgpack', 'msgpack_numpy', 'msgpack', 'msgpack_numpy',
'gym', 'functools32', 'horovod', 'horovod.tensorflow'] 'functools32']
for mod_name in MOCK_MODULES: for mod_name in MOCK_MODULES:
sys.modules[mod_name] = mock.Mock(name=mod_name) sys.modules[mod_name] = mock.Mock(name=mod_name)
sys.modules['cv2'].__version__ = '3.2.1' # fake version sys.modules['cv2'].__version__ = '3.2.1' # fake version
...@@ -364,16 +365,10 @@ def autodoc_skip_member(app, what, name, obj, skip, options): ...@@ -364,16 +365,10 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
# hide deprecated stuff # hide deprecated stuff
if name in [ if name in [
'MultiGPUTrainerBase', 'MultiGPUTrainerBase',
'FeedfreeInferenceRunner',
'replace_get_variable',
'remap_get_variable',
'freeze_get_variable',
'predictor_factory',
'get_predictors', 'get_predictors',
'RandomCropAroundBox', 'RandomCropAroundBox',
'GaussianDeform', 'GaussianDeform',
'dump_chkpt_vars', 'dump_chkpt_vars',
'VisualQA',
'DumpTensor', 'DumpTensor',
'StagingInputWrapper', 'StagingInputWrapper',
'StepTensorPrinter', 'StepTensorPrinter',
......
...@@ -14,7 +14,6 @@ from six.moves import range ...@@ -14,7 +14,6 @@ from six.moves import range
from ..utils import logger from ..utils import logger
from ..utils.utils import get_tqdm_kwargs from ..utils.utils import get_tqdm_kwargs
from ..utils.develop import deprecated
from ..dataflow.base import DataFlow from ..dataflow.base import DataFlow
from ..input_source import ( from ..input_source import (
...@@ -25,7 +24,7 @@ from .base import Callback ...@@ -25,7 +24,7 @@ from .base import Callback
from .group import Callbacks from .group import Callbacks
from .inference import Inferencer from .inference import Inferencer
__all__ = ['InferenceRunner', 'FeedfreeInferenceRunner', __all__ = ['InferenceRunner',
'DataParallelInferenceRunner'] 'DataParallelInferenceRunner']
...@@ -165,11 +164,6 @@ class InferenceRunner(InferenceRunnerBase): ...@@ -165,11 +164,6 @@ class InferenceRunner(InferenceRunnerBase):
inf.trigger_epoch() inf.trigger_epoch()
@deprecated("Just use InferenceRunner since it now accepts TensorInput!", "2017-11-11")
def FeedfreeInferenceRunner(*args, **kwargs):
return InferenceRunner(*args, **kwargs)
class DataParallelInferenceRunner(InferenceRunnerBase): class DataParallelInferenceRunner(InferenceRunnerBase):
""" """
Inference with data-parallel support on multiple GPUs. Inference with data-parallel support on multiple GPUs.
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: visualqa.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from ..base import DataFlow
from ...utils.timer import timed_operation
from ...utils import logger
from six.moves import zip, map
from collections import Counter
import json
__all__ = []
def read_json(fname):
f = open(fname)
ret = json.load(f)
f.close()
return ret
class VisualQA(DataFlow):
"""
`Visual QA <http://visualqa.org/>`_ dataset.
It simply reads q/a json file and produce q/a pairs in their original format.
"""
def __init__(self, question_file, annotation_file):
logger.warn("dataset.VisualQA is deprecated!")
with timed_operation('Reading VQA JSON file'):
qobj, aobj = list(map(read_json, [question_file, annotation_file]))
self.task_type = qobj['task_type']
self.questions = qobj['questions']
self._size = len(self.questions)
self.anno = aobj['annotations']
assert len(self.anno) == len(self.questions), \
"{}!={}".format(len(self.anno), len(self.questions))
self._clean()
def _clean(self):
for a in self.anno:
for aa in a['answers']:
del aa['answer_id']
def size(self):
return self._size
def get_data(self):
for q, a in zip(self.questions, self.anno):
assert q['question_id'] == a['question_id']
yield [q, a]
def get_common_answer(self, n):
""" Get the n most common answers (could be phrases)
n=3000 ~= thresh 4
"""
cnt = Counter()
for anno in self.anno:
cnt[anno['multiple_choice_answer'].lower()] += 1
return [k[0] for k in cnt.most_common(n)]
def get_common_question_words(self, n):
""" Get the n most common words in questions
n=4600 ~= thresh 6
"""
from nltk.tokenize import word_tokenize # will need to download 'punckt'
cnt = Counter()
for q in self.questions:
cnt.update(word_tokenize(q['question'].lower()))
del cnt['?'] # probably don't need this
ret = cnt.most_common(n)
return [k[0] for k in ret]
if __name__ == '__main__':
vqa = VisualQA('/home/wyx/data/VQA/MultipleChoice_mscoco_train2014_questions.json',
'/home/wyx/data/VQA/mscoco_train2014_annotations.json')
for k in vqa.get_data():
print(json.dumps(k))
break
vqa.get_common_answer(100)
...@@ -6,11 +6,7 @@ ...@@ -6,11 +6,7 @@
import tensorflow as tf import tensorflow as tf
from contextlib import contextmanager from contextlib import contextmanager
from ..utils.develop import deprecated __all__ = ['freeze_variables', 'remap_variables']
__all__ = ['replace_get_variable',
'freeze_variables', 'freeze_get_variable', 'remap_get_variable',
'remap_variables']
@contextmanager @contextmanager
...@@ -20,19 +16,6 @@ def custom_getter_scope(custom_getter): ...@@ -20,19 +16,6 @@ def custom_getter_scope(custom_getter):
yield yield
@deprecated("Use custom_getter_scope instead.", "2017-11-06")
def replace_get_variable(fn):
"""
Args:
fn: a function compatible with ``tf.get_variable``.
Returns:
a context with a custom getter
"""
def getter(_, *args, **kwargs):
return fn(*args, **kwargs)
return custom_getter_scope(getter)
def remap_variables(fn): def remap_variables(fn):
""" """
Use fn to map the output of any variable getter. Use fn to map the output of any variable getter.
...@@ -74,13 +57,3 @@ def freeze_variables(): ...@@ -74,13 +57,3 @@ def freeze_variables():
v = tf.stop_gradient(v) v = tf.stop_gradient(v)
return v return v
return custom_getter_scope(custom_getter) return custom_getter_scope(custom_getter)
@deprecated("Renamed to remap_variables", "2017-11-06")
def remap_get_variable():
return remap_variables()
@deprecated("Renamed to freeze_variables", "2017-11-06")
def freeze_get_variable():
return freeze_variables()
...@@ -256,6 +256,7 @@ class HorovodTrainer(SingleCostTrainer): ...@@ -256,6 +256,7 @@ class HorovodTrainer(SingleCostTrainer):
cb.chief_only = False cb.chief_only = False
return [cb] return [cb]
@HIDE_DOC
def initialize(self, session_creator, session_init): def initialize(self, session_creator, session_init):
if not isinstance(session_creator, NewSessionCreator): if not isinstance(session_creator, NewSessionCreator):
raise ValueError( raise ValueError(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment