Commit 0a0101d0 authored by Yuxin Wu's avatar Yuxin Wu

move caffe pb outside

parent d66d7761
......@@ -9,7 +9,7 @@ import numpy as np
from six.moves import range
from ...utils import logger, get_rng, get_dataset_dir, memoized
from ...utils.timer import timed_operation
from ...utils.loadcaffe import get_caffe_pb
from ...utils.fs import mkdir_p, download
from ..base import DataFlow
......@@ -19,7 +19,6 @@ __all__ = ['ILSVRCMeta', 'ILSVRC12']
def log_once(s): logger.warn(s)
CAFFE_ILSVRC12_URL = "http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz"
CAFFE_PROTO_URL = "https://github.com/BVLC/caffe/raw/master/src/caffe/proto/caffe.proto"
# TODO move caffe_pb outside
class ILSVRCMeta(object):
......@@ -31,8 +30,9 @@ class ILSVRCMeta(object):
dir = get_dataset_dir('ilsvrc_metadata')
self.dir = dir
mkdir_p(self.dir)
self.caffe_pb_file = os.path.join(self.dir, 'caffe_pb2.py')
if not os.path.isfile(self.caffe_pb_file):
self.caffepb = get_caffe_pb()
f = os.path.join(self.dir, 'synsets.txt')
if not os.path.isfile(f):
self._download_caffe_meta()
def get_synset_words_1000(self):
......@@ -48,11 +48,6 @@ class ILSVRCMeta(object):
fpath = download(CAFFE_ILSVRC12_URL, self.dir)
tarfile.open(fpath, 'r:gz').extractall(self.dir)
proto_path = download(CAFFE_PROTO_URL, self.dir)
ret = os.system('cd {} && protoc caffe.proto --python_out .'.format(self.dir))
assert ret == 0, \
"caffe proto compilation failed! Did you install protoc?"
def get_image_list(self, name):
"""
:param name: 'train' or 'val' or 'test'
......@@ -73,9 +68,7 @@ class ILSVRCMeta(object):
:param size: return image size in [h, w]. default to (256, 256)
:returns: per-pixel mean as an array of shape (h, w, 3) in range [0, 255]
"""
import imp
caffepb = imp.load_source('caffepb', self.caffe_pb_file)
obj = caffepb.BlobProto()
obj = self.caffepb.BlobProto()
mean_file = os.path.join(self.dir, 'imagenet_mean.binaryproto')
with open(mean_file, 'rb') as f:
......
......@@ -2,7 +2,9 @@
# File: format.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from ..utils import logger
from ..utils import logger, get_rng
from ..utils.timer import timed_operation
from ..utils.loadcaffe import get_caffe_pb
from .base import DataFlow
import random
......@@ -98,9 +100,7 @@ class CaffeLMDB(LMDBData):
"""
super(CaffeLMDB, self).__init__(lmdb_dir, shuffle)
import imp
meta = ILSVRCMeta()
self.cpb = imp.load_source('cpb', meta.caffe_pb_file)
self.cpb = get_caffe_pb()
def get_data(self):
datum = self.cpb.Datum()
......
......@@ -11,9 +11,14 @@ import os
from six.moves import zip
from .utils import change_env
from .utils import change_env, get_dataset_dir
from .fs import download
from . import logger
__all__ = ['load_caffe']
CAFFE_PROTO_URL = "https://github.com/BVLC/caffe/raw/master/src/caffe/proto/caffe.proto"
def get_processor():
ret = {}
def process_conv(layer_name, param, input_data_shape):
......@@ -68,6 +73,17 @@ def load_caffe(model_desc, model_file):
" ".join(sorted(param_dict.keys())))
return param_dict
def get_caffe_pb():
dir = get_dataset_dir('caffe')
caffe_pb_file = os.path.join(dir, 'caffe_pb2.py')
if not os.path.isfile(caffe_pb_file):
proto_path = download(CAFFE_PROTO_URL, dir)
ret = os.system('cd {} && protoc caffe.proto --python_out .'.format(dir))
assert ret == 0, \
"caffe proto compilation failed! Did you install protoc?"
import imp
return imp.load_source('caffepb', caffe_pb_file)
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment