Commit 0a0101d0 authored by Yuxin Wu's avatar Yuxin Wu

move caffe pb outside

parent d66d7761
...@@ -9,7 +9,7 @@ import numpy as np ...@@ -9,7 +9,7 @@ import numpy as np
from six.moves import range from six.moves import range
from ...utils import logger, get_rng, get_dataset_dir, memoized from ...utils import logger, get_rng, get_dataset_dir, memoized
from ...utils.timer import timed_operation from ...utils.loadcaffe import get_caffe_pb
from ...utils.fs import mkdir_p, download from ...utils.fs import mkdir_p, download
from ..base import DataFlow from ..base import DataFlow
...@@ -19,7 +19,6 @@ __all__ = ['ILSVRCMeta', 'ILSVRC12'] ...@@ -19,7 +19,6 @@ __all__ = ['ILSVRCMeta', 'ILSVRC12']
def log_once(s): logger.warn(s) def log_once(s): logger.warn(s)
CAFFE_ILSVRC12_URL = "http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz" CAFFE_ILSVRC12_URL = "http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz"
CAFFE_PROTO_URL = "https://github.com/BVLC/caffe/raw/master/src/caffe/proto/caffe.proto"
# TODO move caffe_pb outside # TODO move caffe_pb outside
class ILSVRCMeta(object): class ILSVRCMeta(object):
...@@ -31,8 +30,9 @@ class ILSVRCMeta(object): ...@@ -31,8 +30,9 @@ class ILSVRCMeta(object):
dir = get_dataset_dir('ilsvrc_metadata') dir = get_dataset_dir('ilsvrc_metadata')
self.dir = dir self.dir = dir
mkdir_p(self.dir) mkdir_p(self.dir)
self.caffe_pb_file = os.path.join(self.dir, 'caffe_pb2.py') self.caffepb = get_caffe_pb()
if not os.path.isfile(self.caffe_pb_file): f = os.path.join(self.dir, 'synsets.txt')
if not os.path.isfile(f):
self._download_caffe_meta() self._download_caffe_meta()
def get_synset_words_1000(self): def get_synset_words_1000(self):
...@@ -48,11 +48,6 @@ class ILSVRCMeta(object): ...@@ -48,11 +48,6 @@ class ILSVRCMeta(object):
fpath = download(CAFFE_ILSVRC12_URL, self.dir) fpath = download(CAFFE_ILSVRC12_URL, self.dir)
tarfile.open(fpath, 'r:gz').extractall(self.dir) tarfile.open(fpath, 'r:gz').extractall(self.dir)
proto_path = download(CAFFE_PROTO_URL, self.dir)
ret = os.system('cd {} && protoc caffe.proto --python_out .'.format(self.dir))
assert ret == 0, \
"caffe proto compilation failed! Did you install protoc?"
def get_image_list(self, name): def get_image_list(self, name):
""" """
:param name: 'train' or 'val' or 'test' :param name: 'train' or 'val' or 'test'
...@@ -73,9 +68,7 @@ class ILSVRCMeta(object): ...@@ -73,9 +68,7 @@ class ILSVRCMeta(object):
:param size: return image size in [h, w]. default to (256, 256) :param size: return image size in [h, w]. default to (256, 256)
:returns: per-pixel mean as an array of shape (h, w, 3) in range [0, 255] :returns: per-pixel mean as an array of shape (h, w, 3) in range [0, 255]
""" """
import imp obj = self.caffepb.BlobProto()
caffepb = imp.load_source('caffepb', self.caffe_pb_file)
obj = caffepb.BlobProto()
mean_file = os.path.join(self.dir, 'imagenet_mean.binaryproto') mean_file = os.path.join(self.dir, 'imagenet_mean.binaryproto')
with open(mean_file, 'rb') as f: with open(mean_file, 'rb') as f:
......
...@@ -2,7 +2,9 @@ ...@@ -2,7 +2,9 @@
# File: format.py # File: format.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from ..utils import logger from ..utils import logger, get_rng
from ..utils.timer import timed_operation
from ..utils.loadcaffe import get_caffe_pb
from .base import DataFlow from .base import DataFlow
import random import random
...@@ -98,9 +100,7 @@ class CaffeLMDB(LMDBData): ...@@ -98,9 +100,7 @@ class CaffeLMDB(LMDBData):
""" """
super(CaffeLMDB, self).__init__(lmdb_dir, shuffle) super(CaffeLMDB, self).__init__(lmdb_dir, shuffle)
import imp self.cpb = get_caffe_pb()
meta = ILSVRCMeta()
self.cpb = imp.load_source('cpb', meta.caffe_pb_file)
def get_data(self): def get_data(self):
datum = self.cpb.Datum() datum = self.cpb.Datum()
......
...@@ -11,9 +11,14 @@ import os ...@@ -11,9 +11,14 @@ import os
from six.moves import zip from six.moves import zip
from .utils import change_env from .utils import change_env, get_dataset_dir
from .fs import download
from . import logger from . import logger
__all__ = ['load_caffe']
CAFFE_PROTO_URL = "https://github.com/BVLC/caffe/raw/master/src/caffe/proto/caffe.proto"
def get_processor(): def get_processor():
ret = {} ret = {}
def process_conv(layer_name, param, input_data_shape): def process_conv(layer_name, param, input_data_shape):
...@@ -68,6 +73,17 @@ def load_caffe(model_desc, model_file): ...@@ -68,6 +73,17 @@ def load_caffe(model_desc, model_file):
" ".join(sorted(param_dict.keys()))) " ".join(sorted(param_dict.keys())))
return param_dict return param_dict
def get_caffe_pb():
dir = get_dataset_dir('caffe')
caffe_pb_file = os.path.join(dir, 'caffe_pb2.py')
if not os.path.isfile(caffe_pb_file):
proto_path = download(CAFFE_PROTO_URL, dir)
ret = os.system('cd {} && protoc caffe.proto --python_out .'.format(dir))
assert ret == 0, \
"caffe proto compilation failed! Did you install protoc?"
import imp
return imp.load_source('caffepb', caffe_pb_file)
if __name__ == '__main__': if __name__ == '__main__':
import argparse import argparse
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment