Commit f9356946 authored by Yuxin Wu

command line tools to profile a graph

parent 54e391c0
@@ -9,20 +9,26 @@ from tensorpack.tfutils.varmanip import dump_chkpt_vars
 from tensorpack.utils import logger
 import argparse
 
-parser = argparse.ArgumentParser()
-parser.add_argument('checkpoint')
-parser.add_argument('--dump', help='dump to an npy file')
-parser.add_argument('--shell', action='store_true', help='start a shell with the params')
-args = parser.parse_args()
-
-if args.checkpoint.endswith('.npy'):
-    params = np.load(args.checkpoint).item()
-else:
-    params = dump_chkpt_vars(args.checkpoint)
-logger.info("Variables in the checkpoint:")
-logger.info(str(params.keys()))
-if args.dump:
-    np.save(args.dump, params)
-if args.shell:
-    import IPython as IP
-    IP.embed(config=IP.terminal.ipapp.load_default_config())
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('model')
+    parser.add_argument('--dump', help='dump to an npy file')
+    parser.add_argument('--shell', action='store_true', help='start a shell with the params')
+    args = parser.parse_args()
+
+    if args.model.endswith('.npy'):
+        params = np.load(args.model).item()
+    else:
+        params = dump_chkpt_vars(args.model)
+    logger.info("Variables in the model:")
+    logger.info(str(params.keys()))
+
+    if args.dump:
+        assert args.dump.endswith('.npy'), args.dump
+        np.save(args.dump, params)
+
+    if args.shell:
+        # params is a dict. play with it
+        import IPython as IP
+        IP.embed(config=IP.terminal.ipapp.load_default_config())
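For reference, the file written by --dump holds a plain dict of variable name to numpy array and can be read back with np.load. A minimal sketch, assuming a hypothetical file name 'params.npy' (not part of the commit):

    # Minimal sketch: read back a dict dumped by the script above.
    # 'params.npy' is a hypothetical file name.
    import numpy as np

    params = np.load('params.npy').item()     # .item() unwraps the stored dict
    for name, value in params.items():
        print(name, value.shape)              # variable name -> array shape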
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: checkpoint-prof.py
import tensorflow as tf
import numpy as np
from tensorpack import get_default_sess_config, get_op_tensor_name
from tensorpack.utils import logger
from tensorpack.tfutils.sessinit import get_model_loader
import argparse
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', help='model file')
    parser.add_argument('--meta', help='metagraph proto file. Will be used to load the graph', required=True)
    parser.add_argument('-i', '--input', nargs='+', help='list of input tensors with their shapes.')
    parser.add_argument('-o', '--output', nargs='+', help='list of output tensors')
    parser.add_argument('--warmup', help='warmup iterations', type=int, default=5)
    parser.add_argument('--print-flops', action='store_true')
    parser.add_argument('--print-params', action='store_true')
    parser.add_argument('--print-timing', action='store_true')
    args = parser.parse_args()

    tf.train.import_meta_graph(args.meta)
    G = tf.get_default_graph()
    with tf.Session(config=get_default_sess_config()) as sess:
        init = get_model_loader(args.model)
        init.init(sess)

        # Build the feed dict from "tensor_name=shape" specs, using random data.
        feed = {}
        for inp in args.input:
            inp = inp.split('=')
            name = get_op_tensor_name(inp[0].strip())[1]
            shape = list(map(int, inp[1].strip().split(',')))
            tensor = G.get_tensor_by_name(name)
            logger.info("Feeding shape ({}) to tensor {}".format(','.join(map(str, shape)), name))
            feed[tensor] = np.random.rand(*shape)

        fetches = []
        for name in args.output:
            name = get_op_tensor_name(name)[1]
            fetches.append(G.get_tensor_by_name(name))
        logger.info("Fetching tensors: {}".format(', '.join([k.name for k in fetches])))

        # Warm-up runs, then one profiled run with full tracing enabled.
        for _ in range(args.warmup):
            sess.run(fetches, feed_dict=feed)

        opt = tf.RunOptions()
        opt.trace_level = tf.RunOptions.FULL_TRACE
        meta = tf.RunMetadata()
        sess.run(fetches, feed_dict=feed, options=opt, run_metadata=meta)

        if args.print_flops:
            tf.contrib.tfprof.model_analyzer.print_model_analysis(
                G, run_meta=meta,
                tfprof_options=tf.contrib.tfprof.model_analyzer.FLOAT_OPS_OPTIONS)

        if args.print_params:
            tf.contrib.tfprof.model_analyzer.print_model_analysis(
                G, run_meta=meta,
                tfprof_options=tf.contrib.tfprof.model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS)

        if args.print_timing:
            tf.contrib.tfprof.model_analyzer.print_model_analysis(
                G, run_meta=meta,
                tfprof_options=tf.contrib.tfprof.model_analyzer.PRINT_ALL_TIMING_MEMORY)
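The core of the new tool is the standard FULL_TRACE pattern: a few warm-up iterations, then one run that records RunMetadata, which is handed to tfprof for FLOPs/params/timing analysis. Below is a self-contained sketch of that same pattern on a toy graph; TF 1.x is assumed, and the dense layer, shapes, and names are illustrative only, not part of the commit.

    # Sketch of the FULL_TRACE + tfprof pattern used by checkpoint-prof.py,
    # on a toy graph (illustrative names/shapes; TF 1.x assumed).
    import numpy as np
    import tensorflow as tf

    x = tf.placeholder(tf.float32, [None, 32], name='x')
    y = tf.layers.dense(x, 10, name='fc')    # toy graph to profile

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        opt = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        meta = tf.RunMetadata()
        sess.run(y, feed_dict={x: np.random.rand(64, 32)},
                 options=opt, run_metadata=meta)
        # Same call the script issues for --print-timing:
        tf.contrib.tfprof.model_analyzer.print_model_analysis(
            tf.get_default_graph(), run_meta=meta,
            tfprof_options=tf.contrib.tfprof.model_analyzer.PRINT_ALL_TIMING_MEMORY)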
@@ -8,7 +8,7 @@ import argparse
 import tensorflow as tf
 import imp
 
-from tensorpack import TowerContext, logger, ModelFromMetaGraph
+from tensorpack import TowerContext, logger
 from tensorpack.tfutils import sessinit, varmanip
 
 parser = argparse.ArgumentParser()
@@ -28,7 +28,7 @@ with tf.Graph().as_default() as G:
         with TowerContext('', is_training=False):
             M.build_graph(M.get_reused_placehdrs())
     else:
-        M = ModelFromMetaGraph(args.meta)
+        tf.train.import_meta_graph(args.meta)
 
     # loading...
     if args.model.endswith('.npy'):
...
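The change above drops the ModelFromMetaGraph wrapper in favor of importing the metagraph directly. A minimal sketch of that pattern; the .meta path below is hypothetical:

    import tensorflow as tf

    with tf.Graph().as_default() as G:
        # Recreate the graph structure from a saved metagraph
        # instead of rebuilding it from a model script.
        tf.train.import_meta_graph('/path/to/model.meta')   # hypothetical path
        print(len(G.get_operations()), "ops imported")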