Commit a3de7ec7 authored by Yuxin Wu

use modeldesc

parent ba2e7ff0
......@@ -22,58 +22,67 @@ BATCH_SIZE = 10
MIN_AFTER_DEQUEUE = 500
CAPACITY = MIN_AFTER_DEQUEUE + 3 * BATCH_SIZE
def get_model(inputs, is_training):
# img: 227x227x3
is_training = bool(is_training)
keep_prob = tf.constant(0.5 if is_training else 1.0)
image, label = inputs
l = Conv2D('conv1', image, out_channel=96, kernel_shape=11, stride=4, padding='VALID')
l = tf.nn.lrn(l, 2, bias=1.0, alpha=2e-5, beta=0.75, name='norm1')
l = MaxPooling('pool1', l, 3, stride=2, padding='VALID')
l = Conv2D('conv2', l, out_channel=256, kernel_shape=5,
padding='SAME', split=2)
l = tf.nn.lrn(l, 2, bias=1.0, alpha=2e-5, beta=0.75, name='norm2')
l = MaxPooling('pool2', l, 3, stride=2, padding='VALID')
l = Conv2D('conv3', l, out_channel=384, kernel_shape=3,
padding='SAME')
l = Conv2D('conv4', l, out_channel=384, kernel_shape=3,
padding='SAME', split=2)
l = Conv2D('conv5', l, out_channel=256, kernel_shape=3,
padding='SAME', split=2)
l = MaxPooling('pool3', l, 3, stride=2, padding='VALID')
l = FullyConnected('fc6', l, 4096)
l = FullyConnected('fc7', l, out_dim=4096)
# fc will have activation summary by default. disable this for the output layer
logits = FullyConnected('fc8', l, out_dim=1000, summary_activation=False, nl=tf.identity)
prob = tf.nn.softmax(logits, name='output')
y = one_hot(label, 1000)
cost = tf.nn.softmax_cross_entropy_with_logits(logits, y)
cost = tf.reduce_mean(cost, name='cross_entropy_loss')
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)
# compute the number of failed samples, for ValidationError to use at test time
wrong = tf.not_equal(
tf.cast(tf.argmax(prob, 1), tf.int32), label)
wrong = tf.cast(wrong, tf.float32)
nr_wrong = tf.reduce_sum(wrong, name='wrong')
# monitor training error
tf.add_to_collection(
MOVING_SUMMARY_VARS_KEY, tf.reduce_mean(wrong, name='train_error'))
# weight decay on all W of fc layers
wd_cost = tf.mul(1e-4,
regularize_cost('fc.*/W', tf.nn.l2_loss),
name='regularize_loss')
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, wd_cost)
add_param_summary('.*/W') # monitor histogram of all W
return tf.add_n([wd_cost, cost], name='cost')
class Model(ModelDesc):
def _get_input_vars(self):
return [
tf.placeholder(
tf.float32, shape=(None, 227, 227, 3), name='input'),
tf.placeholder(
tf.int32, shape=(None,), name='label')
]
def _get_cost(self, inputs, is_training):
# img: 227x227x3
is_training = bool(is_training)
keep_prob = tf.constant(0.5 if is_training else 1.0)
image, label = inputs
l = Conv2D('conv1', image, out_channel=96, kernel_shape=11, stride=4, padding='VALID')
l = tf.nn.lrn(l, 2, bias=1.0, alpha=2e-5, beta=0.75, name='norm1')
l = MaxPooling('pool1', l, 3, stride=2, padding='VALID')
l = Conv2D('conv2', l, out_channel=256, kernel_shape=5,
padding='SAME', split=2)
l = tf.nn.lrn(l, 2, bias=1.0, alpha=2e-5, beta=0.75, name='norm2')
l = MaxPooling('pool2', l, 3, stride=2, padding='VALID')
l = Conv2D('conv3', l, out_channel=384, kernel_shape=3,
padding='SAME')
l = Conv2D('conv4', l, out_channel=384, kernel_shape=3,
padding='SAME', split=2)
l = Conv2D('conv5', l, out_channel=256, kernel_shape=3,
padding='SAME', split=2)
l = MaxPooling('pool3', l, 3, stride=2, padding='VALID')
l = FullyConnected('fc6', l, 4096)
l = FullyConnected('fc7', l, out_dim=4096)
# fc will have activation summary by default. disable this for the output layer
logits = FullyConnected('fc8', l, out_dim=1000, summary_activation=False, nl=tf.identity)
prob = tf.nn.softmax(logits, name='output')
y = one_hot(label, 1000)
cost = tf.nn.softmax_cross_entropy_with_logits(logits, y)
cost = tf.reduce_mean(cost, name='cross_entropy_loss')
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)
# compute the number of failed samples, for ValidationError to use at test time
wrong = tf.not_equal(
tf.cast(tf.argmax(prob, 1), tf.int32), label)
wrong = tf.cast(wrong, tf.float32)
nr_wrong = tf.reduce_sum(wrong, name='wrong')
# monitor training error
tf.add_to_collection(
MOVING_SUMMARY_VARS_KEY, tf.reduce_mean(wrong, name='train_error'))
# weight decay on all W of fc layers
wd_cost = tf.mul(1e-4,
regularize_cost('fc.*/W', tf.nn.l2_loss),
name='regularize_loss')
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, wd_cost)
add_param_summary('.*/W') # monitor histogram of all W
return tf.add_n([wd_cost, cost], name='cost')
def get_config():
basename = os.path.basename(__file__)
......@@ -87,16 +96,6 @@ def get_config():
sess_config = get_default_sess_config()
sess_config.gpu_options.per_process_gpu_memory_fraction = 0.5
# prepare model
input_vars = [
tf.placeholder(
tf.float32, shape=(None, 227, 227, 3), name='input'),
tf.placeholder(
tf.int32, shape=(None,), name='label')
]
input_queue = tf.RandomShuffleQueue(
10, 3, [x.dtype for x in input_vars], name='queue')
lr = tf.train.exponential_decay(
learning_rate=1e-8,
global_step=get_global_step_var(),
......@@ -115,27 +114,18 @@ def get_config():
#ValidationError(dataset_test, prefix='test'),
]),
session_config=sess_config,
inputs=input_vars,
input_queue=input_queue,
get_model_func=get_model,
model=Model(),
step_per_epoch=step_per_epoch,
session_init=ParamRestore(param_dict),
max_epoch=100,
)
def run_test(path):
input_vars = [
tf.placeholder(
tf.float32, shape=(None, 227, 227, 3), name='input'),
tf.placeholder(
tf.int32, shape=(None,), name='label')
]
param_dict = np.load(path).item()
pred_config = PredictConfig(
inputs=input_vars,
input_dataset_mapping=[input_vars[0]],
get_model_func=get_model,
model=Model(),
input_data_mapping=[0],
session_init=ParamRestore(param_dict),
output_var_names=['output:0'] # output:0 is the probability distribution
)
......
......@@ -27,69 +27,77 @@ BATCH_SIZE = 128
MIN_AFTER_DEQUEUE = int(50000 * 0.4)
CAPACITY = MIN_AFTER_DEQUEUE + 3 * BATCH_SIZE
def get_model(inputs, is_training):
#keep_prob = tf.constant(0.5 if is_training else 0.0)
image, label = inputs
if is_training:
image, label = tf.train.shuffle_batch(
[image, label], BATCH_SIZE, CAPACITY, MIN_AFTER_DEQUEUE,
num_threads=6, enqueue_many=False)
tf.image_summary("train_image", image, 10)
l = Conv2D('conv1', image, out_channel=64, kernel_shape=5, padding='SAME',
W_init=tf.truncated_normal_initializer(stddev=1e-4))
l = MaxPooling('pool1', l, 3, stride=2, padding='SAME')
l = tf.nn.lrn(l, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, name='norm1')
l = Conv2D('conv2', l, out_channel=64, kernel_shape=5, padding='SAME',
W_init=tf.truncated_normal_initializer(stddev=1e-4),
b_init=tf.constant_initializer(0.1))
l = tf.nn.lrn(l, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, name='norm2')
l = MaxPooling('pool2', l, 3, stride=2, padding='SAME')
l = FullyConnected('fc0', l, 384,
W_init=tf.truncated_normal_initializer(stddev=0.04),
b_init=tf.constant_initializer(0.1))
l = FullyConnected('fc1', l, out_dim=192,
W_init=tf.truncated_normal_initializer(stddev=0.04),
b_init=tf.constant_initializer(0.1))
# fc will have activation summary by default. disable for the output layer
logits = FullyConnected('linear', l, out_dim=10, summary_activation=False,
nl=tf.identity,
W_init=tf.truncated_normal_initializer(stddev=1.0/192))
prob = tf.nn.softmax(logits, name='output')
y = one_hot(label, 10)
cost = tf.nn.softmax_cross_entropy_with_logits(logits, y)
cost = tf.reduce_mean(cost, name='cross_entropy_loss')
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)
# compute the number of failed samples, for ValidationError to use at test time
wrong = tf.not_equal(
tf.cast(tf.argmax(prob, 1), tf.int32), label)
wrong = tf.cast(wrong, tf.float32)
nr_wrong = tf.reduce_sum(wrong, name='wrong')
# monitor training error
tf.add_to_collection(
MOVING_SUMMARY_VARS_KEY, tf.reduce_mean(wrong, name='train_error'))
# weight decay on all W of fc layers
wd_cost = tf.mul(0.004,
regularize_cost('fc.*/W', tf.nn.l2_loss),
name='regularize_loss')
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, wd_cost)
add_param_summary('.*') # monitor all variables
return tf.add_n([cost, wd_cost], name='cost')
class Model(ModelDesc):
def _get_input_vars(self):
return [
tf.placeholder(
tf.float32, shape=[None, 24, 24, 3], name='input'),
tf.placeholder(
tf.int32, shape=[None], name='label')
]
def _get_cost(self, input_vars, is_training):
image, label = input_vars
if is_training:
image, label = tf.train.shuffle_batch(
[image, label], BATCH_SIZE, CAPACITY, MIN_AFTER_DEQUEUE,
num_threads=6, enqueue_many=True)
tf.image_summary("train_image", image, 10)
l = Conv2D('conv1', image, out_channel=64, kernel_shape=5, padding='SAME',
W_init=tf.truncated_normal_initializer(stddev=1e-4))
l = MaxPooling('pool1', l, 3, stride=2, padding='SAME')
l = tf.nn.lrn(l, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, name='norm1')
l = Conv2D('conv2', l, out_channel=64, kernel_shape=5, padding='SAME',
W_init=tf.truncated_normal_initializer(stddev=1e-4),
b_init=tf.constant_initializer(0.1))
l = tf.nn.lrn(l, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, name='norm2')
l = MaxPooling('pool2', l, 3, stride=2, padding='SAME')
l = FullyConnected('fc0', l, 384,
W_init=tf.truncated_normal_initializer(stddev=0.04),
b_init=tf.constant_initializer(0.1))
l = FullyConnected('fc1', l, out_dim=192,
W_init=tf.truncated_normal_initializer(stddev=0.04),
b_init=tf.constant_initializer(0.1))
# fc will have activation summary by default. disable for the output layer
logits = FullyConnected('linear', l, out_dim=10, summary_activation=False,
nl=tf.identity,
W_init=tf.truncated_normal_initializer(stddev=1.0/192))
prob = tf.nn.softmax(logits, name='output')
y = one_hot(label, 10)
cost = tf.nn.softmax_cross_entropy_with_logits(logits, y)
cost = tf.reduce_mean(cost, name='cross_entropy_loss')
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)
# compute the number of failed samples, for ValidationError to use at test time
wrong = tf.not_equal(
tf.cast(tf.argmax(prob, 1), tf.int32), label)
wrong = tf.cast(wrong, tf.float32)
nr_wrong = tf.reduce_sum(wrong, name='wrong')
# monitor training error
tf.add_to_collection(
MOVING_SUMMARY_VARS_KEY, tf.reduce_mean(wrong, name='train_error'))
# weight decay on all W of fc layers
wd_cost = tf.mul(0.004,
regularize_cost('fc.*/W', tf.nn.l2_loss),
name='regularize_loss')
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, wd_cost)
add_param_summary('.*') # monitor all variables
return tf.add_n([cost, wd_cost], name='cost')
def get_config():
basename = os.path.basename(__file__)
log_dir = os.path.join('train_log', basename[:basename.rfind('.')])
logger.set_logger_file(os.path.join(log_dir, 'training.log'))
# prepare dataset
dataset_train = dataset.Cifar10('train')
augmentors = [
imgaug.RandomCrop((24, 24)),
......@@ -100,6 +108,7 @@ def get_config():
]
dataset_train = AugmentImageComponent(dataset_train, augmentors)
dataset_train = BatchData(dataset_train, 128)
step_per_epoch = dataset_train.size()
augmentors = [
imgaug.CenterCrop((24, 24)),
......@@ -107,22 +116,11 @@ def get_config():
]
dataset_test = dataset.Cifar10('test')
dataset_test = AugmentImageComponent(dataset_test, augmentors)
dataset_test = BatchData(dataset_test, 128)
step_per_epoch = dataset_train.size()
dataset_test = BatchData(dataset_test, 128, remainder=True)
sess_config = get_default_sess_config()
sess_config.gpu_options.per_process_gpu_memory_fraction = 0.5
# prepare model
input_vars = [
tf.placeholder(
tf.float32, shape=(None, 24, 24, 3), name='input'),
tf.placeholder(
tf.int32, shape=(None,), name='label')
]
input_queue = tf.FIFOQueue(
50, [x.dtype for x in input_vars], name='queue')
lr = tf.train.exponential_decay(
learning_rate=1e-1,
global_step=get_global_step_var(),
......@@ -139,10 +137,7 @@ def get_config():
ValidationError(dataset_test, prefix='test'),
]),
session_config=sess_config,
inputs=input_vars,
input_queue=input_queue,
get_model_func=get_model,
batched_model_input=False,
model=Model(),
step_per_epoch=step_per_epoch,
max_epoch=500,
)
......
......@@ -19,92 +19,79 @@ from tensorpack.callbacks import *
from tensorpack.dataflow import *
BATCH_SIZE = 128
MIN_AFTER_DEQUEUE = 500
CAPACITY = MIN_AFTER_DEQUEUE + 3 * BATCH_SIZE
def get_model(inputs, is_training):
"""
Args:
inputs: a list of input variables for training
e.g.: [image_var, label_var] with:
image_var: bx28x28
label_var: bx1 integer
is_training: a Python bool variable
Returns:
the cost to minimize, as a scalar variable
"""
is_training = bool(is_training)
keep_prob = tf.constant(0.5 if is_training else 1.0)
image, label = inputs
image = tf.expand_dims(image, 3) # add a single channel
l = Conv2D('conv0', image, out_channel=32, kernel_shape=5)
l = MaxPooling('pool0', l, 2)
l = Conv2D('conv1', l, out_channel=40, kernel_shape=3)
l = MaxPooling('pool1', l, 2)
l = FullyConnected('fc0', l, 1024)
l = tf.nn.dropout(l, keep_prob)
# fc will have activation summary by default. disable this for the output layer
logits = FullyConnected('fc1', l, out_dim=10,
summary_activation=False, nl=tf.identity)
prob = tf.nn.softmax(logits, name='prob')
y = one_hot(label, 10)
cost = tf.nn.softmax_cross_entropy_with_logits(logits, y)
cost = tf.reduce_mean(cost, name='cross_entropy_loss')
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)
# compute the number of failed samples, for ValidationError to use at test time
wrong = tf.not_equal(
tf.cast(tf.argmax(prob, 1), tf.int32), label)
wrong = tf.cast(wrong, tf.float32)
nr_wrong = tf.reduce_sum(wrong, name='wrong')
# monitor training error
tf.add_to_collection(
MOVING_SUMMARY_VARS_KEY, tf.reduce_mean(wrong, name='train_error'))
# weight decay on all W of fc layers
wd_cost = tf.mul(1e-4,
regularize_cost('fc.*/W', tf.nn.l2_loss),
name='regularize_loss')
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, wd_cost)
add_param_summary('.*/W') # monitor histogram of all W
return tf.add_n([wd_cost, cost], name='cost')
IMAGE_SIZE = 28
class Model(ModelDesc):
def _get_input_vars(self):
return [
tf.placeholder(
tf.float32, shape=(None, IMAGE_SIZE, IMAGE_SIZE), name='input'),
tf.placeholder(
tf.int32, shape=(None,), name='label')
]
def _get_cost(self, input_vars, is_training):
is_training = bool(is_training)
keep_prob = tf.constant(0.5 if is_training else 1.0)
image, label = input_vars
image = tf.expand_dims(image, 3) # add a single channel
l = Conv2D('conv0', image, out_channel=32, kernel_shape=3)
l = Conv2D('conv1', l, out_channel=32, kernel_shape=3)
l = MaxPooling('pool0', l, 2)
l = Conv2D('conv2', l, out_channel=40, kernel_shape=3)
l = MaxPooling('pool1', l, 2)
l = FullyConnected('fc0', l, 1024)
l = tf.nn.dropout(l, keep_prob)
# fc will have activation summary by default. disable this for the output layer
logits = FullyConnected('fc1', l, out_dim=10,
summary_activation=False, nl=tf.identity)
prob = tf.nn.softmax(logits, name='prob')
y = one_hot(label, 10)
cost = tf.nn.softmax_cross_entropy_with_logits(logits, y)
cost = tf.reduce_mean(cost, name='cross_entropy_loss')
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)
# compute the number of failed samples, for ValidationError to use at test time
wrong = tf.not_equal(
tf.cast(tf.argmax(prob, 1), tf.int32), label)
wrong = tf.cast(wrong, tf.float32)
nr_wrong = tf.reduce_sum(wrong, name='wrong')
# monitor training error
tf.add_to_collection(
MOVING_SUMMARY_VARS_KEY, tf.reduce_mean(wrong, name='train_error'))
# weight decay on all W of fc layers
wd_cost = tf.mul(1e-4,
regularize_cost('fc.*/W', tf.nn.l2_loss),
name='regularize_loss')
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, wd_cost)
add_param_summary('.*/W') # monitor histogram of all W
return tf.add_n([wd_cost, cost], name='cost')
def get_config():
basename = os.path.basename(__file__)
log_dir = os.path.join('train_log', basename[:basename.rfind('.')])
logger.set_logger_file(os.path.join(log_dir, 'training.log'))
IMAGE_SIZE = 28
# prepare dataset
dataset_train = BatchData(dataset.Mnist('train'), 128)
dataset_test = BatchData(dataset.Mnist('test'), 256, remainder=True)
step_per_epoch = dataset_train.size()
#step_per_epoch = 3
#dataset_test = FixedSizeData(dataset_test, 20)
# prepare session
sess_config = get_default_sess_config()
sess_config.gpu_options.per_process_gpu_memory_fraction = 0.5
# prepare model
input_vars = [
tf.placeholder(
tf.float32, shape=(None, IMAGE_SIZE, IMAGE_SIZE), name='input'),
tf.placeholder(
tf.int32, shape=(None,), name='label')
]
input_queue = tf.FIFOQueue(
100, [x.dtype for x in input_vars], name='queue')
lr = tf.train.exponential_decay(
learning_rate=1e-4,
learning_rate=1e-3,
global_step=get_global_step_var(),
decay_steps=dataset_train.size() * 50,
decay_steps=dataset_train.size() * 10,
decay_rate=0.1, staircase=True, name='learning_rate')
tf.scalar_summary('learning_rate', lr)
......@@ -117,9 +104,7 @@ def get_config():
ValidationError(dataset_test, prefix='test'),
]),
session_config=sess_config,
inputs=input_vars,
input_queue=input_queue,
get_model_func=get_model,
model=Model(),
step_per_epoch=step_per_epoch,
max_epoch=100,
)
......
......@@ -28,10 +28,8 @@ class DumpParamAsImage(Callback):
def _before_train(self):
self.var = self.graph.get_tensor_by_name(self.var_name)
self.epoch_num = 0
def _trigger_epoch(self):
self.epoch_num += 1
val = self.sess.run(self.var)
if self.func is not None:
val = self.func(val)
......@@ -40,13 +38,13 @@ class DumpParamAsImage(Callback):
assert im.ndim in [2, 3], str(im.ndim)
fname = os.path.join(
self.log_dir,
self.prefix + '-ep{}-{}.png'.format(self.epoch_num, idx))
self.prefix + '-ep{:03d}-{}.png'.format(self.epoch_num, idx))
cv2.imwrite(fname, im * self.scale)
else:
im = val
assert im.ndim in [2, 3]
fname = os.path.join(
self.log_dir,
self.prefix + '-ep{}.png'.format(self.epoch_num))
self.prefix + '-ep{:03d}.png'.format(self.epoch_num))
cv2.imwrite(fname, im * self.scale)
......@@ -15,23 +15,15 @@ __all__ = ['Callbacks']
@contextmanager
def create_test_graph():
G = tf.get_default_graph()
input_vars_train = G.get_collection(INPUT_VARS_KEY)
forward_func = G.get_collection(FORWARD_FUNC_KEY)[0]
model = G.get_collection(MODEL_KEY)[0]
with tf.Graph().as_default() as Gtest:
# create a global step var in test graph
global_step_var = tf.Variable(
0, trainable=False, name=GLOBAL_STEP_OP_NAME)
input_vars = []
for v in input_vars_train:
name = v.name
assert name.endswith(':0'), "placeholder variable names should all end with ':0'"
name = name[:-2]
input_vars.append(tf.placeholder(
v.dtype, shape=v.get_shape(), name=name
))
for v in input_vars:
Gtest.add_to_collection(INPUT_VARS_KEY, v)
cost = forward_func(input_vars, is_training=False)
new_model = model.__class__()
input_vars = new_model.get_input_vars()
cost = new_model.get_cost(input_vars, is_training=False)
Gtest.add_to_collection(MODEL_KEY, new_model)
yield Gtest
@contextmanager
......@@ -155,15 +147,20 @@ class Callbacks(Callback):
raise ValueError(
"Unknown callback running graph {}!".format(cb.running_graph))
self.train = TrainCallbacks(train_cbs)
self.test = TestCallbacks(test_cbs)
if test_cbs:
self.test = TestCallbacks(test_cbs)
else:
self.test = None
def _before_train(self):
self.train.before_train()
self.test.before_train()
if self.test:
self.test.before_train()
def _after_train(self):
self.train.after_train()
self.test.after_train()
if self.test:
self.test.after_train()
def trigger_step(self):
self.train.trigger_step()
......@@ -172,4 +169,5 @@ class Callbacks(Callback):
def _trigger_epoch(self):
self.train.trigger_epoch()
# TODO test callbacks can be run async?
self.test.trigger_epoch()
if self.test:
self.test.trigger_epoch()
......@@ -26,7 +26,7 @@ class ValidationCallback(PeriodicCallback):
self.cost_var_name = cost_var_name
def _before_train(self):
self.input_vars = tf.get_collection(INPUT_VARS_KEY)
self.input_vars = tf.get_collection(MODEL_KEY)[0].get_input_vars()
self.cost_var = self.get_tensor(self.cost_var_name)
self.writer = tf.get_collection(SUMMARY_WRITER_COLLECTION_KEY)[0]
self._find_output_vars()
......
......@@ -16,7 +16,7 @@ class BatchData(DataFlow):
Group data in ds into batches
ds: a DataFlow instance
remainder: whether to return the remaining data smaller than a batch_size.
if set, might return a data point of a different shape
if set to True, the last batch may have a smaller first dimension
"""
self.ds = ds
if not remainder:
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: model_desc.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from abc import ABCMeta, abstractmethod
import tensorflow as tf
__all__ = ['ModelDesc']
class ModelDesc(object):
__metaclass__ = ABCMeta
def __init__(self):
self.input_vars = None
def get_input_vars(self):
"""
Return the list of input variables in the graph.
Results are cached, to avoid creating duplicate variables.
"""
if self.input_vars is None:
self.input_vars = self._get_input_vars()
for i in self.input_vars:
assert isinstance(i, tf.Tensor), "_get_input_vars() should return a list of tf.Tensor"
return self.input_vars
@abstractmethod
def _get_input_vars(self):
"""
Return the list of input variables in the graph.
"""
def get_input_queue(self):
"""
Return the input queue. Dequeued elements will be fed to self.get_cost.
"""
assert self.input_vars is not None
return tf.FIFOQueue(50, [x.dtype for x in self.input_vars], name='input_queue')
def get_cost(self, input_vars, is_training):
assert len(input_vars) == len(self.input_vars)
assert type(is_training) == bool
return self._get_cost(input_vars, is_training)
@abstractmethod
def _get_cost(self, input_vars, is_training):
"""
Args:
input_vars: a list of input variables in the graph
e.g.: [image_var, label_var] with:
image_var: bx28x28
label_var: bx1 integer
is_training: a Python bool variable
Returns:
the cost to minimize, as a scalar variable
input_vars may differ from self.input_vars
(the inputs may have passed through the queue for faster input),
but the two must have the same length
"""
def get_lr_multipler(self):
"""
Return a dict of {variable_regex: multiplier}
"""
return {}
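The base class above defaults to a plain FIFOQueue; the shuffled queue that the old AlexNet script built by hand (tf.RandomShuffleQueue(10, 3, ...)) could be recovered by overriding get_input_queue() in a subclass. A minimal sketch, not part of this commit; the class name is illustrative and it reuses the AlexNet Model defined earlier:

class ShuffledInputModel(Model):
    """Hypothetical subclass: same inputs and cost, but shuffle the queued inputs."""
    def get_input_queue(self):
        # mirror the removed hand-built queue: capacity=10, min_after_dequeue=3
        return tf.RandomShuffleQueue(
            10, 3, [x.dtype for x in self.get_input_vars()], name='queue')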
......@@ -24,24 +24,23 @@ class PredictConfig(object):
session. default to a session running 1 GPU.
session_init: a tensorpack.utils.sessinit.SessionInit instance to
initialize variables of a session.
inputs: input variables of the graph.
input_dataset_mapping: Decide the mapping from each component in data
input_data_mapping: Decide the mapping from each component in data
to the input tensor, since you may not need all input variables
of the graph to run the graph for prediction (for example
the `label` input is not used if you only need probability
distribution). It should be a list with size=len(one_data_point),
where each element is a tensor which each component of the
data point should be fed into.
If not given, defaults to `inputs`.
distribution).
It should be a list with size=len(one_data_point),
where each element is the index of the input variable that the
corresponding component of the data point should be fed into.
If not given, defaults to range(len(input_vars)).
For example, in an image classification task, the testing
dataset only provides images (no labels). The
arguments should look like:
inputs: [image_var, label_var]
input_dataset_mapping: [image_var]
input_data_mapping: [0]
If this argument is not set, each component of a data point is fed to the input variable with the same index.
get_model_func: a function taking `inputs` and `is_training` and
return a tuple of output list as well as the cost to minimize
model: a ModelDesc instance
output_var_names: a list of names of the output variables to predict; the
variables can be any computable tensors in the graph.
If None, only the cost returned by the model will be calculated.
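With the new interface, run_test() in the AlexNet example above builds its PredictConfig roughly as follows (a sketch restating that script; `path` points to the saved parameter dict):

param_dict = np.load(path).item()
pred_config = PredictConfig(
    model=Model(),                        # the ModelDesc subclass from the script
    input_data_mapping=[0],               # test data only provides images, fed to input var 0
    session_init=ParamRestore(param_dict),
    output_var_names=['output:0'])        # 'output:0' is the probability distribution
predict_func = get_predict_func(pred_config)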
......@@ -53,9 +52,8 @@ class PredictConfig(object):
self.session_config = kwargs.pop('session_config', get_default_sess_config())
assert_type(self.session_config, tf.ConfigProto)
self.session_init = kwargs.pop('session_init')
self.inputs = kwargs.pop('inputs')
self.input_dataset_mapping = kwargs.pop('input_dataset_mapping', None)
self.get_model_func = kwargs.pop('get_model_func')
self.model = kwargs.pop('model')
self.input_data_mapping = kwargs.pop('input_data_mapping', None)
self.output_var_names = kwargs.pop('output_var_names', None)
assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
......@@ -72,9 +70,9 @@ def get_predict_func(config):
output_var_names = config.output_var_names
# input/output variables
input_vars = config.inputs
cost_var = config.get_model_func(input_vars, is_training=False)
input_map = config.input_dataset_mapping
input_vars = config.model.get_input_vars()
cost_var = config.model.get_cost(input_vars, is_training=False)
if config.input_data_mapping is None:
input_map = input_vars
else:
input_map = [input_vars[k] for k in config.input_data_mapping]
......
......@@ -9,6 +9,7 @@ import copy
import argparse
import tqdm
from models import ModelDesc
from utils import *
from utils.concurrency import EnqueueThread
from callbacks import *
......@@ -32,17 +33,7 @@ class TrainConfig(object):
session. default to a session running 1 GPU.
session_init: a tensorpack.utils.sessinit.SessionInit instance to
initialize variables of a session. default to a new session.
inputs: a list of input variables. must match what is returned by
the dataset
input_queue: the queue used for input. default to a FIFO queue
with capacity 5
get_model_func: a function taking `inputs` and `is_training`, and
return the cost to minimize
batched_model_input: boolean. If yes, `get_model_func` expected batched
input in training. Otherwise, expect single data point in
training, so that you may do pre-processing and batch them
later with batch ops. It's suggested that you do all
preprocessing in dataset as that is usually faster.
model: a ModelDesc instance
step_per_epoch: the number of steps (parameter updates) to perform
in each epoch. Defaults to dataset.size().
max_epoch: maximum number of epochs to run training. Defaults to 100.
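Putting it together, the example scripts now build their training config roughly like this (a sketch; the dataset and optimizer keyword names are assumed from the surrounding code, the optimizer choice is illustrative, and the callback list is abbreviated):

config = TrainConfig(
    dataset=dataset_train,                  # assumed keyword, a DataFlow
    optimizer=tf.train.AdamOptimizer(lr),   # assumed keyword; optimizer is illustrative
    callbacks=Callbacks([
        ValidationError(dataset_test, prefix='test'),
    ]),
    session_config=sess_config,
    model=Model(),
    step_per_epoch=step_per_epoch,
    max_epoch=100)
start_train(config)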
......@@ -60,14 +51,8 @@ class TrainConfig(object):
assert_type(self.session_config, tf.ConfigProto)
self.session_init = kwargs.pop('session_init', NewSession())
assert_type(self.session_init, SessionInit)
self.inputs = kwargs.pop('inputs')
[assert_type(i, tf.Tensor) for i in self.inputs]
self.input_queue = kwargs.pop(
'input_queue', tf.FIFOQueue(5, [x.dtype for x in self.inputs], name='input_queue'))
assert_type(self.input_queue, tf.QueueBase)
assert self.input_queue.dtypes == [x.dtype for x in self.inputs]
self.get_model_func = kwargs.pop('get_model_func')
self.batched_model_input = kwargs.pop('batched_model_input', True)
self.model = kwargs.pop('model')
assert_type(self.model, ModelDesc)
self.step_per_epoch = int(kwargs.pop('step_per_epoch', self.dataset.size()))
self.max_epoch = int(kwargs.pop('max_epoch', 100))
assert self.step_per_epoch > 0 and self.max_epoch > 0
......@@ -97,29 +82,22 @@ def start_train(config):
Args:
config: a TrainConfig instance
"""
input_vars = config.inputs
input_queue = config.input_queue
model = config.model
input_vars = model.get_input_vars()
input_queue = model.get_input_queue()
callbacks = config.callbacks
tf.add_to_collection(FORWARD_FUNC_KEY, config.get_model_func)
for v in input_vars:
tf.add_to_collection(INPUT_VARS_KEY, v)
tf.add_to_collection(MODEL_KEY, model)
def get_model_inputs():
model_inputs = input_queue.dequeue()
if isinstance(model_inputs, tf.Tensor):
model_inputs = [model_inputs]
for qv, v in zip(model_inputs, input_vars):
if config.batched_model_input:
qv.set_shape(v.get_shape())
else:
qv.set_shape(v.get_shape().as_list()[1:])
qv.set_shape(v.get_shape())
return model_inputs
if config.batched_model_input:
enqueue_op = input_queue.enqueue(input_vars)
else:
enqueue_op = input_queue.enqueue_many(input_vars)
enqueue_op = input_queue.enqueue(input_vars)
# get gradients to update:
logger.info("Training a model of {} tower".format(config.nr_tower))
......@@ -131,7 +109,7 @@ def start_train(config):
with tf.device('/gpu:{}'.format(i)), \
tf.name_scope('tower{}'.format(i)) as scope:
model_inputs = get_model_inputs()
cost_var = config.get_model_func(model_inputs, is_training=True)
cost_var = model.get_cost(model_inputs, is_training=True)
grads.append(
config.optimizer.compute_gradients(cost_var))
......@@ -145,7 +123,7 @@ def start_train(config):
grads = average_gradients(grads)
else:
model_inputs = get_model_inputs()
cost_var = config.get_model_func(model_inputs, is_training=True)
cost_var = model.get_cost(model_inputs, is_training=True)
grads = config.optimizer.compute_gradients(cost_var)
summary_grads(grads)
check_grads(grads)
......
......@@ -29,7 +29,7 @@ class EnqueueThread(threading.Thread):
super(EnqueueThread, self).__init__()
self.sess = sess
self.coord = coord
self.input_vars = sess.graph.get_collection(INPUT_VARS_KEY)
self.input_vars = sess.graph.get_collection(MODEL_KEY)[0].get_input_vars()
self.dataflow = dataflow
self.op = enqueue_op
self.queue = queue
......
......@@ -8,9 +8,8 @@ GLOBAL_STEP_VAR_NAME = 'global_step:0'
SUMMARY_WRITER_COLLECTION_KEY = 'summary_writer'
INPUT_VARS_KEY = 'INPUT_VARIABLES'
MOVING_SUMMARY_VARS_KEY = 'MOVING_SUMMARY_VARIABLES' # extra variables to summarize during training
FORWARD_FUNC_KEY = 'FORWARD_FUNCTION'
MODEL_KEY = 'MODEL'
# export all upper case variables
all_local_names = locals().keys()
......