Commit 7240f877 authored by Yuxin Wu's avatar Yuxin Wu

fix import for 'git pull' users

parent a8cb7e33
......@@ -11,13 +11,14 @@ import pprint
from tensorpack.tfutils.varmanip import get_checkpoint_path
fpath = sys.argv[1]
if __name__ == '__main__':
fpath = sys.argv[1]
if fpath.endswith('.npy'):
params = np.load(fpath, encoding='latin1').item()
dic = {k: v.shape for k, v in six.iteritems(params)}
else:
path = get_checkpoint_path(sys.argv[1])
reader = tf.train.NewCheckpointReader(path)
dic = reader.get_variable_to_shape_map()
pprint.pprint(dic)
if fpath.endswith('.npy'):
params = np.load(fpath, encoding='latin1').item()
dic = {k: v.shape for k, v in six.iteritems(params)}
else:
path = get_checkpoint_path(sys.argv[1])
reader = tf.train.NewCheckpointReader(path)
dic = reader.get_variable_to_shape_map()
pprint.pprint(dic)
......@@ -20,7 +20,11 @@ def _global_import(name):
__all__.append(k)
_CURR_DIR = os.path.dirname(__file__)
for _, module_name, _ in iter_modules(
[os.path.dirname(__file__)]):
[_CURR_DIR]):
srcpath = os.path.join(_CURR_DIR, module_name + '.py')
if not os.path.isfile(srcpath):
continue
if not module_name.startswith('_'):
_global_import(module_name)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment