Commit f8b54d8e authored by Yuxin Wu's avatar Yuxin Wu

add loadcaffe

parent ef1b20f9
...@@ -6,7 +6,7 @@ import tensorflow as tf ...@@ -6,7 +6,7 @@ import tensorflow as tf
import re import re
from ..utils import logger from ..utils import logger
from ..utils import * from ..utils.utils import *
__all__ = ['regularize_cost', 'l2_regularizer', 'l1_regularizer'] __all__ = ['regularize_cost', 'l2_regularizer', 'l1_regularizer']
......
...@@ -4,14 +4,8 @@ ...@@ -4,14 +4,8 @@
from pkgutil import walk_packages from pkgutil import walk_packages
import os import os
import time
import sys
from contextlib import contextmanager
import tensorflow as tf import tensorflow as tf
import numpy as np import numpy as np
import collections
from . import logger
def global_import(name): def global_import(name):
p = __import__(name, globals(), None, level=1) p = __import__(name, globals(), None, level=1)
...@@ -20,16 +14,9 @@ def global_import(name): ...@@ -20,16 +14,9 @@ def global_import(name):
globals()[k] = p.__dict__[k] globals()[k] = p.__dict__[k]
global_import('naming') global_import('naming')
global_import('sessinit') global_import('sessinit')
global_import('utils')
@contextmanager # TODO move this utils to another file
def timed_operation(msg, log_start=False):
if log_start:
logger.info('start {} ...'.format(msg))
start = time.time()
yield
logger.info('{} finished, time={:.2f}sec.'.format(
msg, time.time() - start))
def get_default_sess_config(mem_fraction=0.5): def get_default_sess_config(mem_fraction=0.5):
""" """
Return a better config to use as default. Return a better config to use as default.
...@@ -41,35 +28,6 @@ def get_default_sess_config(mem_fraction=0.5): ...@@ -41,35 +28,6 @@ def get_default_sess_config(mem_fraction=0.5):
conf.allow_soft_placement = True conf.allow_soft_placement = True
return conf return conf
class memoized(object):
'''Decorator. Caches a function's return value each time it is called.
If called later with the same arguments, the cached value is returned
(not reevaluated).
'''
def __init__(self, func):
self.func = func
self.cache = {}
def __call__(self, *args):
if not isinstance(args, collections.Hashable):
# uncacheable. a list, for instance.
# better to not cache than blow up.
return self.func(*args)
if args in self.cache:
return self.cache[args]
else:
value = self.func(*args)
self.cache[args] = value
return value
def __repr__(self):
'''Return the function's docstring.'''
return self.func.__doc__
def __get__(self, obj, objtype):
'''Support instance methods.'''
return functools.partial(self.__call__, obj)
def get_global_step_var(): def get_global_step_var():
""" get global_step variable in the current graph""" """ get global_step variable in the current graph"""
try: try:
...@@ -84,7 +42,3 @@ def get_global_step(): ...@@ -84,7 +42,3 @@ def get_global_step():
return tf.train.global_step( return tf.train.global_step(
tf.get_default_session(), tf.get_default_session(),
get_global_step_var()) get_global_step_var())
def get_rng(self):
seed = (id(self) + os.getpid()) % 4294967295
return np.random.RandomState(seed)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: fs.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import os
def mkdir_p(dirname):
assert dirname is not None
if dirname == '':
return
try:
os.makedirs(dirname)
except OSError as e:
if e.errno != 17:
raise e
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: loadcaffe.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from collections import namedtuple, defaultdict
from abc import abstractmethod
import os
from six.moves import zip
from .utils import change_env
from . import logger
def get_processor():
ret = {}
def process_conv(layer_name, param):
assert len(param) == 2
# caffe: ch_out, ch_in, h, w
return {layer_name + '/W': param[0].data.transpose(2,3,1,0),
layer_name + '/b': param[1].data}
ret['Convolution'] = process_conv
# XXX caffe has an 'transpose' option for fc/W
def process_fc(layer_name, param):
assert len(param) == 2
return {layer_name + '/W': param[0].data.transpose(),
layer_name + '/b': param[1].data}
ret['InnerProduct'] = process_fc
return ret
def load_caffe(model_desc, model_file):
"""
return a dict of params
"""
param_dict = {}
param_processors = get_processor()
with change_env('GLOG_minloglevel', '2'):
import caffe
net = caffe.Net(model_desc, model_file, caffe.TEST)
layer_names = net._layer_names
for layername, layer in zip(layer_names, net.layers):
if layer.type in param_processors:
param_dict.update(param_processors[layer.type](layername, layer.blobs))
else:
assert len(layer.blobs) == 0, len(layer.blobs)
logger.info("Model loaded from caffe. Params: " + \
" ".join(sorted(param_dict.keys())))
return param_dict
if __name__ == '__main__':
ret = load_caffe('/home/wyx/Work/DL/caffe/models/VGG/VGG_ILSVRC_16_layers_deploy.prototxt',
'/home/wyx/Work/DL/caffe/models/VGG/VGG_ILSVRC_16_layers.caffemodel')
...@@ -10,7 +10,7 @@ from datetime import datetime ...@@ -10,7 +10,7 @@ from datetime import datetime
from six.moves import input from six.moves import input
import sys import sys
from .utils import mkdir_p from .fs import mkdir_p
__all__ = [] __all__ = []
......
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: utils.py # File: utils.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import os
import os, sys
from contextlib import contextmanager
import time
import collections
from . import logger
__all__ = ['timed_operation', 'change_env', 'get_rng', 'memoized']
#def expand_dim_if_necessary(var, dp): #def expand_dim_if_necessary(var, dp):
# """ # """
# Args: # Args:
...@@ -17,13 +24,54 @@ import os ...@@ -17,13 +24,54 @@ import os
# dp = dp.reshape(new_shape) # dp = dp.reshape(new_shape)
# return dp # return dp
@contextmanager
def timed_operation(msg, log_start=False):
if log_start:
logger.info('start {} ...'.format(msg))
start = time.time()
yield
logger.info('{} finished, time={:.2f}sec.'.format(
msg, time.time() - start))
@contextmanager
def change_env(name, val):
oldval = os.environ.get(name, None)
os.environ[name] = val
yield
if oldval is None:
del os.environ[name]
else:
os.environ[name] = oldval
class memoized(object):
'''Decorator. Caches a function's return value each time it is called.
If called later with the same arguments, the cached value is returned
(not reevaluated).
'''
def __init__(self, func):
self.func = func
self.cache = {}
def __call__(self, *args):
if not isinstance(args, collections.Hashable):
# uncacheable. a list, for instance.
# better to not cache than blow up.
return self.func(*args)
if args in self.cache:
return self.cache[args]
else:
value = self.func(*args)
self.cache[args] = value
return value
def __repr__(self):
'''Return the function's docstring.'''
return self.func.__doc__
def __get__(self, obj, objtype):
'''Support instance methods.'''
return functools.partial(self.__call__, obj)
def mkdir_p(dirname): def get_rng(self):
assert dirname is not None seed = (id(self) + os.getpid()) % 4294967295
if dirname == '': return np.random.RandomState(seed)
return
try:
os.makedirs(dirname)
except OSError as e:
if e.errno != 17:
raise e
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment