Commit f8b54d8e authored by Yuxin Wu's avatar Yuxin Wu

add loadcaffe

parent ef1b20f9
......@@ -6,7 +6,7 @@ import tensorflow as tf
import re
from ..utils import logger
from ..utils import *
from ..utils.utils import *
__all__ = ['regularize_cost', 'l2_regularizer', 'l1_regularizer']
......
......@@ -4,14 +4,8 @@
from pkgutil import walk_packages
import os
import time
import sys
from contextlib import contextmanager
import tensorflow as tf
import numpy as np
import collections
from . import logger
def global_import(name):
p = __import__(name, globals(), None, level=1)
......@@ -20,16 +14,9 @@ def global_import(name):
globals()[k] = p.__dict__[k]
global_import('naming')
global_import('sessinit')
global_import('utils')
@contextmanager
def timed_operation(msg, log_start=False):
if log_start:
logger.info('start {} ...'.format(msg))
start = time.time()
yield
logger.info('{} finished, time={:.2f}sec.'.format(
msg, time.time() - start))
# TODO move this utils to another file
def get_default_sess_config(mem_fraction=0.5):
"""
Return a better config to use as default.
......@@ -41,35 +28,6 @@ def get_default_sess_config(mem_fraction=0.5):
conf.allow_soft_placement = True
return conf
class memoized(object):
'''Decorator. Caches a function's return value each time it is called.
If called later with the same arguments, the cached value is returned
(not reevaluated).
'''
def __init__(self, func):
self.func = func
self.cache = {}
def __call__(self, *args):
if not isinstance(args, collections.Hashable):
# uncacheable. a list, for instance.
# better to not cache than blow up.
return self.func(*args)
if args in self.cache:
return self.cache[args]
else:
value = self.func(*args)
self.cache[args] = value
return value
def __repr__(self):
'''Return the function's docstring.'''
return self.func.__doc__
def __get__(self, obj, objtype):
'''Support instance methods.'''
return functools.partial(self.__call__, obj)
def get_global_step_var():
""" get global_step variable in the current graph"""
try:
......@@ -84,7 +42,3 @@ def get_global_step():
return tf.train.global_step(
tf.get_default_session(),
get_global_step_var())
def get_rng(self):
seed = (id(self) + os.getpid()) % 4294967295
return np.random.RandomState(seed)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: fs.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import os
def mkdir_p(dirname):
assert dirname is not None
if dirname == '':
return
try:
os.makedirs(dirname)
except OSError as e:
if e.errno != 17:
raise e
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: loadcaffe.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from collections import namedtuple, defaultdict
from abc import abstractmethod
import os
from six.moves import zip
from .utils import change_env
from . import logger
def get_processor():
ret = {}
def process_conv(layer_name, param):
assert len(param) == 2
# caffe: ch_out, ch_in, h, w
return {layer_name + '/W': param[0].data.transpose(2,3,1,0),
layer_name + '/b': param[1].data}
ret['Convolution'] = process_conv
# XXX caffe has an 'transpose' option for fc/W
def process_fc(layer_name, param):
assert len(param) == 2
return {layer_name + '/W': param[0].data.transpose(),
layer_name + '/b': param[1].data}
ret['InnerProduct'] = process_fc
return ret
def load_caffe(model_desc, model_file):
"""
return a dict of params
"""
param_dict = {}
param_processors = get_processor()
with change_env('GLOG_minloglevel', '2'):
import caffe
net = caffe.Net(model_desc, model_file, caffe.TEST)
layer_names = net._layer_names
for layername, layer in zip(layer_names, net.layers):
if layer.type in param_processors:
param_dict.update(param_processors[layer.type](layername, layer.blobs))
else:
assert len(layer.blobs) == 0, len(layer.blobs)
logger.info("Model loaded from caffe. Params: " + \
" ".join(sorted(param_dict.keys())))
return param_dict
if __name__ == '__main__':
ret = load_caffe('/home/wyx/Work/DL/caffe/models/VGG/VGG_ILSVRC_16_layers_deploy.prototxt',
'/home/wyx/Work/DL/caffe/models/VGG/VGG_ILSVRC_16_layers.caffemodel')
......@@ -10,7 +10,7 @@ from datetime import datetime
from six.moves import input
import sys
from .utils import mkdir_p
from .fs import mkdir_p
__all__ = []
......
# -*- coding: UTF-8 -*-
# File: utils.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import os
import os, sys
from contextlib import contextmanager
import time
import collections
from . import logger
__all__ = ['timed_operation', 'change_env', 'get_rng', 'memoized']
#def expand_dim_if_necessary(var, dp):
# """
# Args:
......@@ -17,13 +24,54 @@ import os
# dp = dp.reshape(new_shape)
# return dp
@contextmanager
def timed_operation(msg, log_start=False):
if log_start:
logger.info('start {} ...'.format(msg))
start = time.time()
yield
logger.info('{} finished, time={:.2f}sec.'.format(
msg, time.time() - start))
@contextmanager
def change_env(name, val):
oldval = os.environ.get(name, None)
os.environ[name] = val
yield
if oldval is None:
del os.environ[name]
else:
os.environ[name] = oldval
class memoized(object):
'''Decorator. Caches a function's return value each time it is called.
If called later with the same arguments, the cached value is returned
(not reevaluated).
'''
def __init__(self, func):
self.func = func
self.cache = {}
def __call__(self, *args):
if not isinstance(args, collections.Hashable):
# uncacheable. a list, for instance.
# better to not cache than blow up.
return self.func(*args)
if args in self.cache:
return self.cache[args]
else:
value = self.func(*args)
self.cache[args] = value
return value
def __repr__(self):
'''Return the function's docstring.'''
return self.func.__doc__
def __get__(self, obj, objtype):
'''Support instance methods.'''
return functools.partial(self.__call__, obj)
def mkdir_p(dirname):
assert dirname is not None
if dirname == '':
return
try:
os.makedirs(dirname)
except OSError as e:
if e.errno != 17:
raise e
def get_rng(self):
seed = (id(self) + os.getpid()) % 4294967295
return np.random.RandomState(seed)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment