Commit b6a775f4 authored by Yuxin Wu's avatar Yuxin Wu

python3 compat

parent 3f743301
...@@ -21,7 +21,7 @@ import os ...@@ -21,7 +21,7 @@ import os
sys.path.insert(0, os.path.abspath('../')) sys.path.insert(0, os.path.abspath('../'))
import mock import mock
MOCK_MODULES = ['numpy', 'scipy', 'tensorflow', 'scipy.misc', 'h5py', 'nltk'] MOCK_MODULES = ['numpy', 'scipy', 'tensorflow', 'scipy.misc', 'h5py', 'nltk', 'cv2']
for mod_name in MOCK_MODULES: for mod_name in MOCK_MODULES:
sys.modules[mod_name] = mock.Mock() sys.modules[mod_name] = mock.Mock()
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
import tensorflow as tf import tensorflow as tf
from tqdm import tqdm from tqdm import tqdm
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from six.moves import zip from six.moves import zip, map
from ..dataflow import DataFlow from ..dataflow import DataFlow
from ..utils import * from ..utils import *
...@@ -102,7 +102,7 @@ class InferenceRunner(Callback): ...@@ -102,7 +102,7 @@ class InferenceRunner(Callback):
def get_tensor(name): def get_tensor(name):
_, varname = get_op_var_name(name) _, varname = get_op_var_name(name)
return self.graph.get_tensor_by_name(varname) return self.graph.get_tensor_by_name(varname)
self.output_tensors = map(get_tensor, self.output_tensors) self.output_tensors = list(map(get_tensor, self.output_tensors))
def _trigger_epoch(self): def _trigger_epoch(self):
for vc in self.vcs: for vc in self.vcs:
......
...@@ -2,8 +2,9 @@ ...@@ -2,8 +2,9 @@
# File: common.py # File: common.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import numpy as np from __future__ import division
import copy import copy
import numpy as np
from six.moves import range from six.moves import range
from .base import DataFlow, ProxyDataFlow from .base import DataFlow, ProxyDataFlow
from ..utils import * from ..utils import *
...@@ -30,7 +31,7 @@ class BatchData(ProxyDataFlow): ...@@ -30,7 +31,7 @@ class BatchData(ProxyDataFlow):
def size(self): def size(self):
ds_size = self.ds.size() ds_size = self.ds.size()
div = ds_size / self.batch_size div = ds_size // self.batch_size
rem = ds_size % self.batch_size rem = ds_size % self.batch_size
if rem == 0: if rem == 0:
return div return div
......
...@@ -73,7 +73,7 @@ if __name__ == '__main__': ...@@ -73,7 +73,7 @@ if __name__ == '__main__':
vqa = VisualQA('/home/wyx/data/VQA/MultipleChoice_mscoco_train2014_questions.json', vqa = VisualQA('/home/wyx/data/VQA/MultipleChoice_mscoco_train2014_questions.json',
'/home/wyx/data/VQA/mscoco_train2014_annotations.json') '/home/wyx/data/VQA/mscoco_train2014_annotations.json')
for k in vqa.get_data(): for k in vqa.get_data():
print json.dumps(k) print(json.dumps(k))
break break
# vqa.get_common_question_words(100) # vqa.get_common_question_words(100)
vqa.get_common_answer(100) vqa.get_common_answer(100)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment