Commit 5ccaea83 authored by Yuxin Wu's avatar Yuxin Wu

Merge branch 'new-infer'

parents 8644248a 818e3faf
...@@ -8,14 +8,9 @@ import argparse ...@@ -8,14 +8,9 @@ import argparse
import numpy as np import numpy as np
import os import os
from tensorpack.train import TrainConfig, QueueInputTrainer from tensorpack import *
from tensorpack.models import * import tensorpack.tfutils.symbolic_functions as symbf
from tensorpack.callbacks import *
from tensorpack.utils import *
from tensorpack.tfutils import *
from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import * from tensorpack.tfutils.summary import *
from tensorpack.dataflow import *
""" """
A small convnet model for cifar 10 or cifar100 dataset. A small convnet model for cifar 10 or cifar100 dataset.
...@@ -65,7 +60,7 @@ class Model(ModelDesc): ...@@ -65,7 +60,7 @@ class Model(ModelDesc):
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost) tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)
# compute the number of failed samples, for ClassificationError to use at test time # compute the number of failed samples, for ClassificationError to use at test time
wrong = prediction_incorrect(logits, label) wrong = symbf.prediction_incorrect(logits, label)
nr_wrong = tf.reduce_sum(wrong, name='wrong') nr_wrong = tf.reduce_sum(wrong, name='wrong')
# monitor training error # monitor training error
tf.add_to_collection( tf.add_to_collection(
...@@ -161,4 +156,5 @@ if __name__ == '__main__': ...@@ -161,4 +156,5 @@ if __name__ == '__main__':
config.session_init = SaverRestore(args.load) config.session_init = SaverRestore(args.load)
if args.gpu: if args.gpu:
config.nr_tower = len(args.gpu.split(',')) config.nr_tower = len(args.gpu.split(','))
QueueInputTrainer(config).train() #QueueInputTrainer(config).train()
SimpleTrainer(config).train()
...@@ -9,9 +9,6 @@ import os, sys ...@@ -9,9 +9,6 @@ import os, sys
import argparse import argparse
from tensorpack import * from tensorpack import *
from tensorpack.models import *
from tensorpack.utils import *
from tensorpack.callbacks import *
""" """
MNIST ConvNet example. MNIST ConvNet example.
...@@ -117,6 +114,5 @@ if __name__ == '__main__': ...@@ -117,6 +114,5 @@ if __name__ == '__main__':
config = get_config() config = get_config()
if args.load: if args.load:
config.session_init = SaverRestore(args.load) config.session_init = SaverRestore(args.load)
#QueueInputTrainer(config).train() QueueInputTrainer(config).train()
SimpleInputTrainer(config).train()
...@@ -12,6 +12,7 @@ from ..utils import * ...@@ -12,6 +12,7 @@ from ..utils import *
__all__ = ['Callbacks'] __all__ = ['Callbacks']
# --- Test-Callback related stuff seems not very useful.
@contextmanager @contextmanager
def create_test_graph(trainer): def create_test_graph(trainer):
model = trainer.model model = trainer.model
...@@ -31,33 +32,6 @@ def create_test_session(trainer): ...@@ -31,33 +32,6 @@ def create_test_session(trainer):
with tf.Session() as sess: with tf.Session() as sess:
yield sess yield sess
class CallbackTimeLogger(object):
def __init__(self):
self.times = []
self.tot = 0
def add(self, name, time):
self.tot += time
self.times.append((name, time))
@contextmanager
def timed_callback(self, name):
s = time.time()
yield
self.add(name, time.time() - s)
def log(self):
""" log the time of some heavy callbacks """
if self.tot < 3:
return
msgs = []
for name, t in self.times:
if t / self.tot > 0.3 and t > 1:
msgs.append("{}:{:.3f}sec".format(name, t))
logger.info(
"Callbacks took {:.3f} sec in total. {}".format(
self.tot, '; '.join(msgs)))
class TestCallbackContext(object): class TestCallbackContext(object):
""" """
A class holding the context needed for running TestCallback A class holding the context needed for running TestCallback
...@@ -91,6 +65,34 @@ class TestCallbackContext(object): ...@@ -91,6 +65,34 @@ class TestCallbackContext(object):
def test_context(self): def test_context(self):
with self.graph.as_default(), self.sess.as_default(): with self.graph.as_default(), self.sess.as_default():
yield yield
# ---
class CallbackTimeLogger(object):
def __init__(self):
self.times = []
self.tot = 0
def add(self, name, time):
self.tot += time
self.times.append((name, time))
@contextmanager
def timed_callback(self, name):
s = time.time()
yield
self.add(name, time.time() - s)
def log(self):
""" log the time of some heavy callbacks """
if self.tot < 3:
return
msgs = []
for name, t in self.times:
if t / self.tot > 0.3 and t > 1:
msgs.append("{}:{:.3f}sec".format(name, t))
logger.info(
"Callbacks took {:.3f} sec in total. {}".format(
self.tot, '; '.join(msgs)))
class Callbacks(Callback): class Callbacks(Callback):
""" """
......
...@@ -13,7 +13,7 @@ from ..utils import * ...@@ -13,7 +13,7 @@ from ..utils import *
from ..utils.stat import * from ..utils.stat import *
from ..tfutils import * from ..tfutils import *
from ..tfutils.summary import * from ..tfutils.summary import *
from .base import Callback, TestCallbackType from .base import Callback
__all__ = ['InferenceRunner', 'ClassificationError', __all__ = ['InferenceRunner', 'ClassificationError',
'ScalarStats', 'Inferencer', 'BinaryClassificationStats'] 'ScalarStats', 'Inferencer', 'BinaryClassificationStats']
...@@ -63,7 +63,6 @@ class InferenceRunner(Callback): ...@@ -63,7 +63,6 @@ class InferenceRunner(Callback):
""" """
A callback that runs different kinds of inferencer. A callback that runs different kinds of inferencer.
""" """
type = TestCallbackType()
def __init__(self, ds, vcs): def __init__(self, ds, vcs):
""" """
...@@ -82,12 +81,15 @@ class InferenceRunner(Callback): ...@@ -82,12 +81,15 @@ class InferenceRunner(Callback):
def _before_train(self): def _before_train(self):
self.input_vars = self.trainer.model.reuse_input_vars() self.input_vars = self.trainer.model.reuse_input_vars()
self._find_output_tensors() self._find_output_tensors()
input_names = [x.name for x in self.input_vars]
self.pred_func = self.trainer.get_predict_func(
input_names, self.output_tensors)
for v in self.vcs: for v in self.vcs:
v.trainer = self.trainer v.trainer = self.trainer
def _find_output_tensors(self): def _find_output_tensors(self):
self.output_tensors = [] self.output_tensors = [] # list of names
self.vc_to_vars = [] self.vc_to_vars = [] # list of list of (var_name: output_idx)
for vc in self.vcs: for vc in self.vcs:
vc_vars = vc._get_output_tensors() vc_vars = vc._get_output_tensors()
def find_oid(var): def find_oid(var):
...@@ -99,12 +101,6 @@ class InferenceRunner(Callback): ...@@ -99,12 +101,6 @@ class InferenceRunner(Callback):
vc_vars = [(var, find_oid(var)) for var in vc_vars] vc_vars = [(var, find_oid(var)) for var in vc_vars]
self.vc_to_vars.append(vc_vars) self.vc_to_vars.append(vc_vars)
# convert name to tensors
def get_tensor(name):
_, varname = get_op_var_name(name)
return self.graph.get_tensor_by_name(varname)
self.output_tensors = list(map(get_tensor, self.output_tensors))
def _trigger_epoch(self): def _trigger_epoch(self):
for vc in self.vcs: for vc in self.vcs:
vc.before_inference() vc.before_inference()
...@@ -112,8 +108,9 @@ class InferenceRunner(Callback): ...@@ -112,8 +108,9 @@ class InferenceRunner(Callback):
sess = tf.get_default_session() sess = tf.get_default_session()
with tqdm(total=self.ds.size(), ascii=True) as pbar: with tqdm(total=self.ds.size(), ascii=True) as pbar:
for dp in self.ds.get_data(): for dp in self.ds.get_data():
feed = dict(zip(self.input_vars, dp)) # TODO custom dp mapping? #feed = dict(zip(self.input_vars, dp)) # TODO custom dp mapping?
outputs = sess.run(self.output_tensors, feed_dict=feed) #outputs = sess.run(self.output_tensors, feed_dict=feed)
outputs = self.pred_func(dp)
for vc, varsmap in zip(self.vcs, self.vc_to_vars): for vc, varsmap in zip(self.vcs, self.vc_to_vars):
vc_output = [outputs[k[1]] for k in varsmap] vc_output = [outputs[k[1]] for k in varsmap]
vc.datapoint(dp, vc_output) vc.datapoint(dp, vc_output)
......
...@@ -9,7 +9,9 @@ import tensorflow as tf ...@@ -9,7 +9,9 @@ import tensorflow as tf
__all__ = ['get_default_sess_config', __all__ = ['get_default_sess_config',
'get_global_step', 'get_global_step',
'get_global_step_var', 'get_global_step_var',
'get_op_var_name'] 'get_op_var_name',
'get_vars_by_names'
]
def get_default_sess_config(mem_fraction=0.9): def get_default_sess_config(mem_fraction=0.9):
""" """
...@@ -53,3 +55,14 @@ def get_op_var_name(name): ...@@ -53,3 +55,14 @@ def get_op_var_name(name):
return name[:-2], name return name[:-2], name
else: else:
return name, name + ':0' return name, name + ':0'
def get_vars_by_names(names):
"""
Get a list of variables in the default graph by a list of names
"""
ret = []
G = tf.get_default_graph()
for n in names:
opn, varn = get_op_var_name(n)
ret.append(G.get_tensor_by_name(varn))
return ret
...@@ -50,6 +50,11 @@ class Trainer(object): ...@@ -50,6 +50,11 @@ class Trainer(object):
""" run an iteration""" """ run an iteration"""
pass pass
@abstractmethod
def get_predict_func(self, input_names, output_names):
""" return a predict function"""
pass
def trigger_epoch(self): def trigger_epoch(self):
self._trigger_epoch() self._trigger_epoch()
self.config.callbacks.trigger_epoch() self.config.callbacks.trigger_epoch()
......
...@@ -53,25 +53,16 @@ class SimpleTrainer(Trainer): ...@@ -53,25 +53,16 @@ class SimpleTrainer(Trainer):
self._process_summary(summary_str) self._process_summary(summary_str)
def get_predict_func(self, input_names, output_names): def get_predict_func(self, input_names, output_names):
input_vars = [] input_vars = get_vars_by_names(input_names)
for n in input_names: for v in input_vars:
opn, varn = get_op_var_name(n)
v = tf.get_default_graph().get_tensor_by_name(varn)
assert v in self.input_vars assert v in self.input_vars
input_vars.append(v) output_vars = get_vars_by_names(output_names)
output_vars = []
for n in output_names:
opn, varn = get_op_var_name(n)
v = tf.get_default_graph().get_tensor_by_name(varn)
output_vars.append(v)
def func(inputs): def func(inputs):
assert len(inputs) == len(input_vars) assert len(inputs) == len(input_vars)
feed = dict(zip(input_vars, inputs)) feed = dict(zip(input_vars, inputs))
return self.sess.run(output_vars, feed_dict=feed) return self.sess.run(output_vars, feed_dict=feed)
return func return func
class EnqueueThread(threading.Thread): class EnqueueThread(threading.Thread):
def __init__(self, trainer, queue, enqueue_op, raw_input_var): def __init__(self, trainer, queue, enqueue_op, raw_input_var):
super(EnqueueThread, self).__init__() super(EnqueueThread, self).__init__()
...@@ -126,6 +117,7 @@ class QueueInputTrainer(Trainer): ...@@ -126,6 +117,7 @@ class QueueInputTrainer(Trainer):
self.async = async self.async = async
if self.async: if self.async:
assert self.config.nr_tower > 1 assert self.config.nr_tower > 1
self._dequed_inputs = []
@staticmethod @staticmethod
def _average_grads(tower_grads): def _average_grads(tower_grads):
...@@ -148,6 +140,7 @@ class QueueInputTrainer(Trainer): ...@@ -148,6 +140,7 @@ class QueueInputTrainer(Trainer):
assert len(ret) == len(self.input_vars) assert len(ret) == len(self.input_vars)
for qv, v in zip(ret, self.input_vars): for qv, v in zip(ret, self.input_vars):
qv.set_shape(v.get_shape()) qv.set_shape(v.get_shape())
self._dequed_inputs.append(ret)
return ret return ret
def _single_tower_grad(self): def _single_tower_grad(self):
...@@ -248,6 +241,27 @@ class QueueInputTrainer(Trainer): ...@@ -248,6 +241,27 @@ class QueueInputTrainer(Trainer):
summary_str = self.summary_op.eval() summary_str = self.summary_op.eval()
self._process_summary(summary_str) self._process_summary(summary_str)
def get_predict_func(self, input_names, output_names):
raw_input_vars = get_vars_by_names(input_names)
input_var_idxs = [self.input_vars.index(v) for v in raw_input_vars]
if self.config.nr_tower == 1:
dequed = self._dequed_inputs[0]
input_vars = [dequed[k] for k in input_var_idxs]
output_vars = get_vars_by_names(output_names)
else:
# TODO naive impl: use the first tower only
dequed = self._dequed_inputs[0]
input_vars = [dequed[k] for k in input_var_idxs]
output_names = ['tower0/' + n for n in output_names]
output_vars = get_vars_by_names(output_names)
def func(inputs):
assert len(inputs) == len(input_vars)
feed = dict(zip(input_vars, inputs))
return self.sess.run(output_vars, feed_dict=feed)
return func
def start_train(config): def start_train(config):
tr = QueueInputTrainer(config) tr = QueueInputTrainer(config)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment