Commit f16643ac authored by ppwwyyxx

testing and loading script

parent dd031661
*.gz
*.npy
train_log
# Byte-compiled / optimized / DLL files
......
......@@ -62,7 +62,7 @@ def get_model(inputs, is_training):
    # fc will have activation summary by default. disable this for the output layer
    logits = FullyConnected('fc1', l, out_dim=10,
                            summary_activation=False, nl=tf.identity)
    prob = tf.nn.softmax(logits, name='output')
    prob = tf.nn.softmax(logits, name='prob')
    y = one_hot(label, 10)
    cost = tf.nn.softmax_cross_entropy_with_logits(logits, y)
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: dump_model_params.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

import argparse
import cv2
import tensorflow as tf
import imp

from tensorpack.utils import *
from tensorpack.utils import sessinit
from tensorpack.dataflow import *
from tensorpack.predict import DatasetPredictor

parser = argparse.ArgumentParser()
parser.add_argument(dest='config')
parser.add_argument(dest='model')
parser.add_argument(dest='output')
args = parser.parse_args()

get_config_func = imp.load_source('config_script', args.config).get_config

with tf.Graph().as_default() as G:
    global_step_var = tf.Variable(
        0, trainable=False, name=GLOBAL_STEP_OP_NAME)
    config = get_config_func()
    config['get_model_func'](config['inputs'], is_training=False)

    init = sessinit.SaverRestore(args.model)
    sess = tf.Session()
    init.init(sess)

    with sess.as_default():
        sessinit.dump_session_params(args.output)
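For reference, a hedged sketch of how this script might be used and its output inspected; the config script, checkpoint, and .npy file names are hypothetical placeholders, not part of this commit:

# hypothetical invocation:
#   ./dump_model_params.py my_config.py train_log/model-10000 params.npy
# the dump is a pickled {variable name: value} dict stored in a .npy file:
import numpy as np
params = np.load('params.npy').item()      # .item() unwraps the pickled dict
for name, value in sorted(params.items()):
    print name, value.shape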
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: imgclassify.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

import argparse
import cv2
import tensorflow as tf
import imp

from tensorpack.utils import *
from tensorpack.utils import sessinit
from tensorpack.dataflow import *
from tensorpack.predict import DatasetPredictor

parser = argparse.ArgumentParser()
parser.add_argument(dest='config')
parser.add_argument(dest='model')
parser.add_argument(dest='images', nargs='+')
parser.add_argument('--output_type', default='label',
                    choices=['label', 'label-prob', 'raw'])
parser.add_argument('--top', default=1, type=int)
args = parser.parse_args()

get_config_func = imp.load_source('config_script', args.config).get_config

with tf.Graph().as_default() as G:
    global_step_var = tf.Variable(
        0, trainable=False, name=GLOBAL_STEP_OP_NAME)
    config = get_config_func()
    config['session_init'] = sessinit.SaverRestore(args.model)
    config['output_var'] = 'output:0'

    ds = ImageFromFile(args.images, 3, resize=(227, 227))
    predictor = DatasetPredictor(config, ds, batch=128)
    res = predictor.get_all_result()

    if args.output_type == 'label':
        for r in res:
            print r.argsort()[-args.top:][::-1]
    elif args.output_type == 'label-prob':
        raise NotImplementedError
    elif args.output_type == 'raw':
        print res
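A small self-contained sketch of the argsort top-k trick used for the 'label' output; the score vector is made up for illustration:

import numpy as np
r = np.array([0.1, 0.7, 0.05, 0.15])     # fake per-class scores
top = 2
print r.argsort()[-top:][::-1]           # -> [1 3]: indices of the two highest scores, descending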
......@@ -17,6 +17,7 @@ class BatchData(DataFlow):
            if set, might return a data point of a different shape
        """
        self.ds = ds
        if not remainder:
            assert batch_size <= ds.size()
        self.batch_size = batch_size
        self.remainder = remainder
......@@ -85,7 +86,6 @@ class FakeData(DataFlow):
        for _ in xrange(self._size):
            yield [np.random.random(k) for k in self.shapes]

class MapData(DataFlow):
    """ Apply a function to the given index in the datapoint"""
    def __init__(self, ds, func, index=0):
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: image.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

import numpy as np
import cv2
from .base import DataFlow

__all__ = ['ImageFromFile']

class ImageFromFile(DataFlow):
    """ Generate RGB images from files. """
    def __init__(self, files, channel, resize=None):
        """ files: a list of file paths
            channel: 1 or 3 channels
            resize: a (w, h) tuple. If given, will force a resize
        """
        self.files = files
        self.channel = int(channel)
        self.resize = resize

    def size(self):
        return len(self.files)

    def get_data(self):
        for f in self.files:
            im = cv2.imread(
                f, cv2.IMREAD_GRAYSCALE if self.channel == 1 else cv2.IMREAD_COLOR)
            if self.channel == 3:
                im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
            if self.resize is not None:
                im = cv2.resize(im, self.resize)
            yield (im,)
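A short usage sketch for the new ImageFromFile dataflow, matching how the scripts above use it; the image file names are hypothetical:

from tensorpack.dataflow import ImageFromFile

ds = ImageFromFile(['cat.jpg', 'dog.jpg'], channel=3, resize=(227, 227))
for (im,) in ds.get_data():
    print im.shape, im.dtype      # expected: (227, 227, 3) uint8 per image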
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: infer.py
# File: predict.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
......@@ -11,7 +11,7 @@ import numpy as np
from utils import *
from utils.modelutils import describe_model
from utils import logger

from dataflow import DataFlow
from dataflow import DataFlow, BatchData

def get_predict_func(config):
    """
......@@ -27,6 +27,10 @@ def get_predict_func(config):
    sess_init = config['session_init']

    # Provide this if only a specific output is needed.
    # By default, all outputs as well as the cost will be evaluated.
    output_var_name = config.get('output_var', None)

    # input/output variables
    input_vars = config['inputs']
    get_model_func = config['get_model_func']
......@@ -38,18 +42,30 @@ def get_predict_func(config):
    sess_init.init(sess)

    def run_input(dp):
        # TODO: what if input and dp are not aligned?
        feed = dict(zip(input_vars, dp))
        results = sess.run(
            [cost_var] + output_vars, feed_dict=feed)
        if output_var_name is not None:
            fetches = tf.get_default_graph().get_tensor_by_name(output_var_name)
            results = sess.run(fetches, feed_dict=feed)
            return results[0]
        else:
            fetches = [cost_var] + output_vars
            results = sess.run(fetches, feed_dict=feed)
            cost = results[0]
            outputs = results[1:]
            return cost, outputs
    return run_input

class DatasetPredictor(object):
    def __init__(self, predict_config, dataset):
    def __init__(self, predict_config, dataset, batch=0):
        """
        A predictor with the given predict_config, run on the given dataset.
        If batch is larger than zero, the dataset will be batched.
        """
        assert isinstance(dataset, DataFlow)
        self.ds = dataset
        if batch > 0:
            self.ds = BatchData(self.ds, batch, remainder=True)
        self.predict_func = get_predict_func(predict_config)

    def get_result(self):
......
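A hedged sketch of how the new options fit together, mirroring the scripts above; the config script path, checkpoint, output tensor name, and image files are placeholders, not part of this commit:

import imp
from tensorpack.utils import sessinit
from tensorpack.dataflow import ImageFromFile
from tensorpack.predict import DatasetPredictor

get_config_func = imp.load_source('config_script', 'my_config.py').get_config   # placeholder config script
config = get_config_func()
config['session_init'] = sessinit.SaverRestore('train_log/model-10000')         # placeholder checkpoint
config['output_var'] = 'prob:0'      # fetch only this tensor; otherwise cost + all outputs are evaluated
ds = ImageFromFile(['a.jpg', 'b.jpg'], channel=3, resize=(227, 227))             # placeholder images
predictor = DatasetPredictor(config, ds, batch=128)   # batch > 0 wraps ds in BatchData(..., remainder=True)
res = predictor.get_all_result()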
......@@ -4,6 +4,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

from abc import abstractmethod
import numpy as np
import tensorflow as tf

from . import logger
......@@ -24,7 +25,7 @@ class SaverRestore(SessionInit):
        saver = tf.train.Saver()
        saver.restore(sess, self.path)
        logger.info(
            "Restore checkpoint from {}".format(ckpt.model_checkpoint_path))
            "Restore checkpoint from {}".format(self.path))

    def set_path(self, model_path):
        self.path = model_path
......@@ -44,3 +45,12 @@ class ParamRestore(SessionInit):
                continue
            logger.info("Restoring param {}".format(name))
            sess.run(var.assign(value))

def dump_session_params(path):
    var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    result = {}
    for v in var:
        result[v.name] = v.eval()
    logger.info("Params to save to {}:".format(path))
    logger.info(str(result.keys()))
    np.save(path, result)
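A hedged sketch of the round trip this enables; it assumes ParamRestore is constructed from the same {variable name: value} dict that its restore loop above iterates over, and 'params.npy' is a placeholder file name:

import numpy as np
from tensorpack.utils import sessinit

params = np.load('params.npy').item()   # dict written by dump_session_params via np.save
init = sessinit.ParamRestore(params)    # assumption: ParamRestore takes the {name: value} dict
# init.init(sess) would then assign each saved value to the variable of the same name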