Commit a3de7ec7 authored by Yuxin Wu

use modeldesc

parent ba2e7ff0
@@ -22,58 +22,67 @@ BATCH_SIZE = 10
 MIN_AFTER_DEQUEUE = 500
 CAPACITY = MIN_AFTER_DEQUEUE + 3 * BATCH_SIZE
 
-def get_model(inputs, is_training):
+class Model(ModelDesc):
+    def _get_input_vars(self):
+        return [
+            tf.placeholder(
+                tf.float32, shape=(None, 227, 227, 3), name='input'),
+            tf.placeholder(
+                tf.int32, shape=(None,), name='label')
+        ]
+
+    def _get_cost(self, inputs, is_training):
+        # img: 227x227x3
+        is_training = bool(is_training)
+        keep_prob = tf.constant(0.5 if is_training else 1.0)
+
+        image, label = inputs
+
+        l = Conv2D('conv1', image, out_channel=96, kernel_shape=11, stride=4, padding='VALID')
+        l = tf.nn.lrn(l, 2, bias=1.0, alpha=2e-5, beta=0.75, name='norm1')
+        l = MaxPooling('pool1', l, 3, stride=2, padding='VALID')
+
+        l = Conv2D('conv2', l, out_channel=256, kernel_shape=5,
+                   padding='SAME', split=2)
+        l = tf.nn.lrn(l, 2, bias=1.0, alpha=2e-5, beta=0.75, name='norm2')
+        l = MaxPooling('pool2', l, 3, stride=2, padding='VALID')
+
+        l = Conv2D('conv3', l, out_channel=384, kernel_shape=3,
+                   padding='SAME')
+        l = Conv2D('conv4', l, out_channel=384, kernel_shape=3,
+                   padding='SAME', split=2)
+        l = Conv2D('conv5', l, out_channel=256, kernel_shape=3,
+                   padding='SAME', split=2)
+        l = MaxPooling('pool3', l, 3, stride=2, padding='VALID')
+
+        l = FullyConnected('fc6', l, 4096)
+        l = FullyConnected('fc7', l, out_dim=4096)
+        # fc will have activation summary by default. disable this for the output layer
+        logits = FullyConnected('fc8', l, out_dim=1000, summary_activation=False, nl=tf.identity)
+        prob = tf.nn.softmax(logits, name='output')
+
+        y = one_hot(label, 1000)
+        cost = tf.nn.softmax_cross_entropy_with_logits(logits, y)
+        cost = tf.reduce_mean(cost, name='cross_entropy_loss')
+        tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)
+
+        # compute the number of failed samples, for ValidationError to use at test time
+        wrong = tf.not_equal(
+            tf.cast(tf.argmax(prob, 1), tf.int32), label)
+        wrong = tf.cast(wrong, tf.float32)
+        nr_wrong = tf.reduce_sum(wrong, name='wrong')
+        # monitor training error
+        tf.add_to_collection(
+            MOVING_SUMMARY_VARS_KEY, tf.reduce_mean(wrong, name='train_error'))
+
+        # weight decay on all W of fc layers
+        wd_cost = tf.mul(1e-4,
+                         regularize_cost('fc.*/W', tf.nn.l2_loss),
+                         name='regularize_loss')
+        tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, wd_cost)
+
+        add_param_summary('.*/W')   # monitor histogram of all W
+        return tf.add_n([wd_cost, cost], name='cost')
 
 def get_config():
     basename = os.path.basename(__file__)
@@ -87,16 +96,6 @@ def get_config():
     sess_config = get_default_sess_config()
     sess_config.gpu_options.per_process_gpu_memory_fraction = 0.5
 
-    # prepare model
-    input_vars = [
-        tf.placeholder(
-            tf.float32, shape=(None, 227, 227, 3), name='input'),
-        tf.placeholder(
-            tf.int32, shape=(None,), name='label')
-    ]
-    input_queue = tf.RandomShuffleQueue(
-        10, 3, [x.dtype for x in input_vars], name='queue')
-
     lr = tf.train.exponential_decay(
         learning_rate=1e-8,
         global_step=get_global_step_var(),
@@ -115,27 +114,18 @@ def get_config():
             #ValidationError(dataset_test, prefix='test'),
         ]),
         session_config=sess_config,
-        inputs=input_vars,
-        input_queue=input_queue,
-        get_model_func=get_model,
+        model=Model(),
         step_per_epoch=step_per_epoch,
         session_init=ParamRestore(param_dict),
         max_epoch=100,
     )
 
 def run_test(path):
-    input_vars = [
-        tf.placeholder(
-            tf.float32, shape=(None, 227, 227, 3), name='input'),
-        tf.placeholder(
-            tf.int32, shape=(None,), name='label')
-    ]
     param_dict = np.load(path).item()
     pred_config = PredictConfig(
-        inputs=input_vars,
-        input_dataset_mapping=[input_vars[0]],
-        get_model_func=get_model,
+        model=Model(),
+        input_data_mapping=[0],
         session_init=ParamRestore(param_dict),
         output_var_names=['output:0']   # output:0 is the probability distribution
    )
...
@@ -27,69 +27,77 @@ BATCH_SIZE = 128
 MIN_AFTER_DEQUEUE = int(50000 * 0.4)
 CAPACITY = MIN_AFTER_DEQUEUE + 3 * BATCH_SIZE
 
-def get_model(inputs, is_training):
-    #keep_prob = tf.constant(0.5 if is_training else 0.0)
-
-    image, label = inputs
-
-    if is_training:
-        image, label = tf.train.shuffle_batch(
-            [image, label], BATCH_SIZE, CAPACITY, MIN_AFTER_DEQUEUE,
-            num_threads=6, enqueue_many=False)
-        tf.image_summary("train_image", image, 10)
+class Model(ModelDesc):
+    def _get_input_vars(self):
+        return [
+            tf.placeholder(
+                tf.float32, shape=[None, 24, 24, 3], name='input'),
+            tf.placeholder(
+                tf.int32, shape=[None], name='label')
+        ]
+
+    def _get_cost(self, input_vars, is_training):
+        image, label = input_vars
+
+        if is_training:
+            image, label = tf.train.shuffle_batch(
+                [image, label], BATCH_SIZE, CAPACITY, MIN_AFTER_DEQUEUE,
+                num_threads=6, enqueue_many=True)
+            tf.image_summary("train_image", image, 10)
+
+        l = Conv2D('conv1', image, out_channel=64, kernel_shape=5, padding='SAME',
+                   W_init=tf.truncated_normal_initializer(stddev=1e-4))
+        l = MaxPooling('pool1', l, 3, stride=2, padding='SAME')
+        l = tf.nn.lrn(l, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, name='norm1')
+
+        l = Conv2D('conv2', l, out_channel=64, kernel_shape=5, padding='SAME',
+                   W_init=tf.truncated_normal_initializer(stddev=1e-4),
+                   b_init=tf.constant_initializer(0.1))
+        l = tf.nn.lrn(l, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, name='norm2')
+        l = MaxPooling('pool2', l, 3, stride=2, padding='SAME')
+
+        l = FullyConnected('fc0', l, 384,
+                           W_init=tf.truncated_normal_initializer(stddev=0.04),
+                           b_init=tf.constant_initializer(0.1))
+        l = FullyConnected('fc1', l, out_dim=192,
+                           W_init=tf.truncated_normal_initializer(stddev=0.04),
+                           b_init=tf.constant_initializer(0.1))
+        # fc will have activation summary by default. disable for the output layer
+        logits = FullyConnected('linear', l, out_dim=10, summary_activation=False,
+                                nl=tf.identity,
+                                W_init=tf.truncated_normal_initializer(stddev=1.0/192))
+        prob = tf.nn.softmax(logits, name='output')
+
+        y = one_hot(label, 10)
+        cost = tf.nn.softmax_cross_entropy_with_logits(logits, y)
+        cost = tf.reduce_mean(cost, name='cross_entropy_loss')
+        tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)
+
+        # compute the number of failed samples, for ValidationError to use at test time
+        wrong = tf.not_equal(
+            tf.cast(tf.argmax(prob, 1), tf.int32), label)
+        wrong = tf.cast(wrong, tf.float32)
+        nr_wrong = tf.reduce_sum(wrong, name='wrong')
+        # monitor training error
+        tf.add_to_collection(
+            MOVING_SUMMARY_VARS_KEY, tf.reduce_mean(wrong, name='train_error'))
+
+        # weight decay on all W of fc layers
+        wd_cost = tf.mul(0.004,
+                         regularize_cost('fc.*/W', tf.nn.l2_loss),
+                         name='regularize_loss')
+        tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, wd_cost)
+
+        add_param_summary('.*')   # monitor all variables
+        return tf.add_n([cost, wd_cost], name='cost')
 
 def get_config():
     basename = os.path.basename(__file__)
     log_dir = os.path.join('train_log', basename[:basename.rfind('.')])
     logger.set_logger_file(os.path.join(log_dir, 'training.log'))
 
     # prepare dataset
     dataset_train = dataset.Cifar10('train')
     augmentors = [
         imgaug.RandomCrop((24, 24)),
@@ -100,6 +108,7 @@ def get_config():
     ]
     dataset_train = AugmentImageComponent(dataset_train, augmentors)
     dataset_train = BatchData(dataset_train, 128)
+    step_per_epoch = dataset_train.size()
 
     augmentors = [
         imgaug.CenterCrop((24, 24)),
@@ -107,22 +116,11 @@ def get_config():
     ]
     dataset_test = dataset.Cifar10('test')
     dataset_test = AugmentImageComponent(dataset_test, augmentors)
-    dataset_test = BatchData(dataset_test, 128)
-    step_per_epoch = dataset_train.size()
+    dataset_test = BatchData(dataset_test, 128, remainder=True)
 
     sess_config = get_default_sess_config()
     sess_config.gpu_options.per_process_gpu_memory_fraction = 0.5
 
-    # prepare model
-    input_vars = [
-        tf.placeholder(
-            tf.float32, shape=(None, 24, 24, 3), name='input'),
-        tf.placeholder(
-            tf.int32, shape=(None,), name='label')
-    ]
-    input_queue = tf.FIFOQueue(
-        50, [x.dtype for x in input_vars], name='queue')
-
     lr = tf.train.exponential_decay(
         learning_rate=1e-1,
         global_step=get_global_step_var(),
@@ -139,10 +137,7 @@ def get_config():
             ValidationError(dataset_test, prefix='test'),
         ]),
         session_config=sess_config,
-        inputs=input_vars,
-        input_queue=input_queue,
-        get_model_func=get_model,
-        batched_model_input=False,
+        model=Model(),
         step_per_epoch=step_per_epoch,
         max_epoch=500,
    )
...
@@ -19,92 +19,79 @@ from tensorpack.callbacks import *
 from tensorpack.dataflow import *
 
 BATCH_SIZE = 128
-MIN_AFTER_DEQUEUE = 500
-CAPACITY = MIN_AFTER_DEQUEUE + 3 * BATCH_SIZE
+IMAGE_SIZE = 28
 
-def get_model(inputs, is_training):
-    """
-    Args:
-        inputs: a list of input variable for training
-            e.g.: [image_var, label_var] with:
-                image_var: bx28x28
-                label_var: bx1 integer
-        is_training: a python bool variable
-    Returns:
-        the cost to minimize. scalar variable
-    """
-    l = Conv2D('conv0', image, out_channel=32, kernel_shape=5)
-    l = Conv2D('conv1', l, out_channel=40, kernel_shape=3)
+class Model(ModelDesc):
+    def _get_input_vars(self):
+        return [
+            tf.placeholder(
+                tf.float32, shape=(None, IMAGE_SIZE, IMAGE_SIZE), name='input'),
+            tf.placeholder(
+                tf.int32, shape=(None,), name='label')
+        ]
+
+    def _get_cost(self, input_vars, is_training):
+        is_training = bool(is_training)
+        keep_prob = tf.constant(0.5 if is_training else 1.0)
+
+        image, label = input_vars
+        image = tf.expand_dims(image, 3)    # add a single channel
+
+        l = Conv2D('conv0', image, out_channel=32, kernel_shape=3)
+        l = Conv2D('conv1', l, out_channel=32, kernel_shape=3)
+        l = MaxPooling('pool0', l, 2)
+        l = Conv2D('conv2', l, out_channel=40, kernel_shape=3)
+        l = MaxPooling('pool1', l, 2)
+
+        l = FullyConnected('fc0', l, 1024)
+        l = tf.nn.dropout(l, keep_prob)
+
+        # fc will have activation summary by default. disable this for the output layer
+        logits = FullyConnected('fc1', l, out_dim=10,
+                                summary_activation=False, nl=tf.identity)
+        prob = tf.nn.softmax(logits, name='prob')
+
+        y = one_hot(label, 10)
+        cost = tf.nn.softmax_cross_entropy_with_logits(logits, y)
+        cost = tf.reduce_mean(cost, name='cross_entropy_loss')
+        tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)
+
+        # compute the number of failed samples, for ValidationError to use at test time
+        wrong = tf.not_equal(
+            tf.cast(tf.argmax(prob, 1), tf.int32), label)
+        wrong = tf.cast(wrong, tf.float32)
+        nr_wrong = tf.reduce_sum(wrong, name='wrong')
+        # monitor training error
+        tf.add_to_collection(
+            MOVING_SUMMARY_VARS_KEY, tf.reduce_mean(wrong, name='train_error'))
+
+        # weight decay on all W of fc layers
+        wd_cost = tf.mul(1e-4,
+                         regularize_cost('fc.*/W', tf.nn.l2_loss),
+                         name='regularize_loss')
+        tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, wd_cost)
+
+        add_param_summary('.*/W')   # monitor histogram of all W
+        return tf.add_n([wd_cost, cost], name='cost')
 
 def get_config():
     basename = os.path.basename(__file__)
     log_dir = os.path.join('train_log', basename[:basename.rfind('.')])
     logger.set_logger_file(os.path.join(log_dir, 'training.log'))
 
-    IMAGE_SIZE = 28
+    # prepare dataset
     dataset_train = BatchData(dataset.Mnist('train'), 128)
     dataset_test = BatchData(dataset.Mnist('test'), 256, remainder=True)
     step_per_epoch = dataset_train.size()
-    #step_per_epoch = 3
-    #dataset_test = FixedSizeData(dataset_test, 20)
 
-    # prepare session
     sess_config = get_default_sess_config()
     sess_config.gpu_options.per_process_gpu_memory_fraction = 0.5
 
-    # prepare model
-    input_vars = [
-        tf.placeholder(
-            tf.float32, shape=(None, IMAGE_SIZE, IMAGE_SIZE), name='input'),
-        tf.placeholder(
-            tf.int32, shape=(None,), name='label')
-    ]
-    input_queue = tf.FIFOQueue(
-        100, [x.dtype for x in input_vars], name='queue')
-
     lr = tf.train.exponential_decay(
-        learning_rate=1e-4,
+        learning_rate=1e-3,
         global_step=get_global_step_var(),
-        decay_steps=dataset_train.size() * 50,
+        decay_steps=dataset_train.size() * 10,
         decay_rate=0.1, staircase=True, name='learning_rate')
     tf.scalar_summary('learning_rate', lr)
@@ -117,9 +104,7 @@ def get_config():
             ValidationError(dataset_test, prefix='test'),
         ]),
         session_config=sess_config,
-        inputs=input_vars,
-        input_queue=input_queue,
-        get_model_func=get_model,
+        model=Model(),
         step_per_epoch=step_per_epoch,
         max_epoch=100,
    )
...
@@ -28,10 +28,8 @@ class DumpParamAsImage(Callback):
     def _before_train(self):
         self.var = self.graph.get_tensor_by_name(self.var_name)
-        self.epoch_num = 0
 
     def _trigger_epoch(self):
-        self.epoch_num += 1
         val = self.sess.run(self.var)
         if self.func is not None:
             val = self.func(val)
@@ -40,13 +38,13 @@ class DumpParamAsImage(Callback):
                 assert im.ndim in [2, 3], str(im.ndim)
                 fname = os.path.join(
                     self.log_dir,
-                    self.prefix + '-ep{}-{}.png'.format(self.epoch_num, idx))
+                    self.prefix + '-ep{:03d}-{}.png'.format(self.epoch_num, idx))
                 cv2.imwrite(fname, im * self.scale)
         else:
             im = val
             assert im.ndim in [2, 3]
             fname = os.path.join(
                 self.log_dir,
-                self.prefix + '-ep{}.png'.format(self.epoch_num))
+                self.prefix + '-ep{:03d}.png'.format(self.epoch_num))
             cv2.imwrite(fname, im * self.scale)
@@ -15,23 +15,15 @@ __all__ = ['Callbacks']
 @contextmanager
 def create_test_graph():
     G = tf.get_default_graph()
-    input_vars_train = G.get_collection(INPUT_VARS_KEY)
-    forward_func = G.get_collection(FORWARD_FUNC_KEY)[0]
+    model = G.get_collection(MODEL_KEY)[0]
     with tf.Graph().as_default() as Gtest:
         # create a global step var in test graph
         global_step_var = tf.Variable(
             0, trainable=False, name=GLOBAL_STEP_OP_NAME)
-        input_vars = []
-        for v in input_vars_train:
-            name = v.name
-            assert name.endswith(':0'), "I think placeholder variable should all ends with ':0'"
-            name = name[:-2]
-            input_vars.append(tf.placeholder(
-                v.dtype, shape=v.get_shape(), name=name
-            ))
-        for v in input_vars:
-            Gtest.add_to_collection(INPUT_VARS_KEY, v)
-        cost = forward_func(input_vars, is_training=False)
+        new_model = model.__class__()
+        input_vars = new_model.get_input_vars()
+        cost = new_model.get_cost(input_vars, is_training=False)
+        Gtest.add_to_collection(MODEL_KEY, new_model)
         yield Gtest
 
 @contextmanager
@@ -155,15 +147,20 @@ class Callbacks(Callback):
             raise ValueError(
                 "Unknown callback running graph {}!".format(cb.running_graph))
         self.train = TrainCallbacks(train_cbs)
-        self.test = TestCallbacks(test_cbs)
+        if test_cbs:
+            self.test = TestCallbacks(test_cbs)
+        else:
+            self.test = None
 
     def _before_train(self):
         self.train.before_train()
-        self.test.before_train()
+        if self.test:
+            self.test.before_train()
 
     def _after_train(self):
         self.train.after_train()
-        self.test.after_train()
+        if self.test:
+            self.test.after_train()
 
     def trigger_step(self):
         self.train.trigger_step()
@@ -172,4 +169,5 @@ class Callbacks(Callback):
     def _trigger_epoch(self):
         self.train.trigger_epoch()
         # TODO test callbacks can be run async?
-        self.test.trigger_epoch()
+        if self.test:
+            self.test.trigger_epoch()
@@ -26,7 +26,7 @@ class ValidationCallback(PeriodicCallback):
         self.cost_var_name = cost_var_name
 
     def _before_train(self):
-        self.input_vars = tf.get_collection(INPUT_VARS_KEY)
+        self.input_vars = tf.get_collection(MODEL_KEY)[0].get_input_vars()
         self.cost_var = self.get_tensor(self.cost_var_name)
         self.writer = tf.get_collection(SUMMARY_WRITER_COLLECTION_KEY)[0]
         self._find_output_vars()
...
@@ -16,7 +16,7 @@ class BatchData(DataFlow):
         Group data in ds into batches
         ds: a DataFlow instance
         remainder: whether to return the remaining data smaller than a batch_size.
-            if set, might return a data point of a different shape
+            if set True, will possibly return a data point of a smaller 1st dimension
         """
         self.ds = ds
         if not remainder:
...
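To make the reworded `remainder` behaviour concrete, here is a small illustrative sketch; it is not part of this commit and assumes a toy DataFlow with the usual size()/get_data() interface:

from tensorpack.dataflow import *

# Toy flow yielding 10 single-element datapoints (illustrative only).
class ToyFlow(DataFlow):
    def size(self):
        return 10
    def get_data(self):
        for i in range(10):
            yield [i]

# With remainder=True the last, smaller batch (2 datapoints) is still produced,
# with a smaller 1st dimension; with remainder=False the trailing datapoints
# that do not fill a whole batch are dropped.
batched = BatchData(ToyFlow(), 4, remainder=True)
for dp in batched.get_data():
    print(len(dp[0]))   # 4, 4, 2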
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: model_desc.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

from abc import ABCMeta, abstractmethod
import tensorflow as tf

__all__ = ['ModelDesc']

class ModelDesc(object):
    __metaclass__ = ABCMeta

    def __init__(self):
        self.input_vars = None

    def get_input_vars(self):
        """
        return the list of input vars in the graph
        results will be cached, to avoid creating the same variables twice
        """
        if self.input_vars is None:
            self.input_vars = self._get_input_vars()
            for i in self.input_vars:
                assert isinstance(i, tf.Tensor), tf.Tensor.__class__
        return self.input_vars

    @abstractmethod
    def _get_input_vars(self):
        """
        return the list of input vars in the graph
        """

    def get_input_queue(self):
        """
        return the queue for input. the dequeued elements will be fed to self.get_cost
        """
        assert self.input_vars is not None
        return tf.FIFOQueue(50, [x.dtype for x in self.input_vars], name='input_queue')

    def get_cost(self, input_vars, is_training):
        assert len(input_vars) == len(self.input_vars)
        assert type(is_training) == bool
        return self._get_cost(input_vars, is_training)

    @abstractmethod
    def _get_cost(self, input_vars, is_training):
        """
        Args:
            input_vars: a list of input variables in the graph
                e.g.: [image_var, label_var] with:
                    image_var: bx28x28
                    label_var: bx1 integer
            is_training: a python bool variable
        Returns:
            the cost to minimize. scalar variable

        input_vars might be different from self.input_vars
        (inputs might go through the queue for faster input),
        but must have the same length
        """

    def get_lr_multipler(self):
        """
        Return a dict of {variable_regex: multiplier}
        """
        return {}
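For reference, a minimal sketch of the contract this base class defines, following the same conventions as the example scripts above; the class name ToyModel and the tiny network are hypothetical, and the helpers (FullyConnected, one_hot) are the same ones the examples import:

import tensorflow as tf

class ToyModel(ModelDesc):
    def _get_input_vars(self):
        # placeholders; created once and cached by get_input_vars()
        return [
            tf.placeholder(tf.float32, shape=(None, 28, 28), name='input'),
            tf.placeholder(tf.int32, shape=(None,), name='label'),
        ]

    def _get_cost(self, input_vars, is_training):
        # input_vars may be the dequeued queue tensors rather than the raw
        # placeholders, but they arrive in the same order and length
        image, label = input_vars
        l = tf.reshape(image, [-1, 28 * 28])
        logits = FullyConnected('fc', l, out_dim=10,
                                summary_activation=False, nl=tf.identity)
        cost = tf.nn.softmax_cross_entropy_with_logits(logits, one_hot(label, 10))
        return tf.reduce_mean(cost, name='cost')

# Training and prediction now take the model object instead of
# inputs / input_queue / get_model_func:
#   TrainConfig(..., model=ToyModel(), step_per_epoch=..., max_epoch=...)
#   PredictConfig(model=ToyModel(), input_data_mapping=[0], ...)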
@@ -24,24 +24,23 @@ class PredictConfig(object):
             session. default to a session running 1 GPU.
         session_init: a tensorpack.utils.sessinit.SessionInit instance to
             initialize variables of a session.
-        inputs: input variables of the graph.
-        input_dataset_mapping: Decide the mapping from each component in data
+        input_data_mapping: Decide the mapping from each component in data
             to the input tensor, since you may not need all input variables
             of the graph to run the graph for prediction (for example
             the `label` input is not used if you only need probability
-            distribution). It should be a list with size=len(one_data_point),
-            where each element is a tensor which each component of the
-            data point should be fed into.
-            If not given, defaults to `inputs`.
+            distribution).
+            It should be a list with size=len(one_data_point),
+            where each element is an index of the input variables each
+            component of the data point should be fed into.
+            If not given, defaults to range(len(input_vars))
 
             For example, with image classification task, the testing
             dataset only provides datapoints of images (no labels). The
             arguments should look like:
                 inputs: [image_var, label_var]
-                input_dataset_mapping: [image_var]
+                input_data_mapping: [0]
             If this argument is not set, the inputs and the data points won't be aligned.
-        get_model_func: a function taking `inputs` and `is_training` and
-            return a tuple of output list as well as the cost to minimize
+        model: a ModelDesc instance
         output_var_names: a list of names of the output variable to predict, the
             variables can be any computable tensor in the graph.
             if None, will only calculate the cost returned by `get_model_func`.
@@ -53,9 +52,8 @@ class PredictConfig(object):
         self.session_config = kwargs.pop('session_config', get_default_sess_config())
         assert_type(self.session_config, tf.ConfigProto)
         self.session_init = kwargs.pop('session_init')
-        self.inputs = kwargs.pop('inputs')
-        self.input_dataset_mapping = kwargs.pop('input_dataset_mapping', None)
-        self.get_model_func = kwargs.pop('get_model_func')
+        self.model = kwargs.pop('model')
+        self.input_data_mapping = kwargs.pop('input_data_mapping', None)
         self.output_var_names = kwargs.pop('output_var_names', None)
         assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
@@ -72,9 +70,9 @@ def get_predict_func(config):
     output_var_names = config.output_var_names
 
     # input/output variables
-    input_vars = config.inputs
-    cost_var = config.get_model_func(input_vars, is_training=False)
-    input_map = config.input_dataset_mapping
+    input_vars = config.model.get_input_vars()
+    cost_var = config.model.get_cost(input_vars, is_training=False)
+    input_map = [input_vars[k] for k in config.input_data_mapping]
     if input_map is None:
         input_map = input_vars
...
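A hedged usage sketch of the new index-based mapping (values are illustrative; Model is any ModelDesc subclass such as the AlexNet one above, whose input variables are, in order, 'input' and 'label', and param_dict is loaded beforehand as in run_test()):

pred_config = PredictConfig(
    model=Model(),
    input_data_mapping=[0],          # datapoint component 0 -> input variable 0 ('input')
    session_init=ParamRestore(param_dict),
    output_var_names=['output:0'],   # the probability distribution
)
predict_func = get_predict_func(pred_config)
# Each test datapoint is then expected to carry only the image component;
# the 'label' placeholder is never fed, matching the docstring above.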
@@ -9,6 +9,7 @@ import copy
 import argparse
 import tqdm
 
+from models import ModelDesc
 from utils import *
 from utils.concurrency import EnqueueThread
 from callbacks import *
@@ -32,17 +33,7 @@ class TrainConfig(object):
             session. default to a session running 1 GPU.
         session_init: a tensorpack.utils.sessinit.SessionInit instance to
             initialize variables of a session. default to a new session.
-        inputs: a list of input variables. must match what is returned by
-            the dataset
-        input_queue: the queue used for input. default to a FIFO queue
-            with capacity 5
-        get_model_func: a function taking `inputs` and `is_training`, and
-            return the cost to minimize
-        batched_model_input: boolean. If yes, `get_model_func` expected batched
-            input in training. Otherwise, expect single data point in
-            training, so that you may do pre-processing and batch them
-            later with batch ops. It's suggested that you do all
-            preprocessing in dataset as that is usually faster.
+        model: a ModelDesc instance
         step_per_epoch: the number of steps (parameter updates) to perform
             in each epoch. default to dataset.size()
         max_epoch: maximum number of epoch to run training. default to 100
@@ -60,14 +51,8 @@ class TrainConfig(object):
         assert_type(self.session_config, tf.ConfigProto)
         self.session_init = kwargs.pop('session_init', NewSession())
         assert_type(self.session_init, SessionInit)
-        self.inputs = kwargs.pop('inputs')
-        [assert_type(i, tf.Tensor) for i in self.inputs]
-        self.input_queue = kwargs.pop(
-            'input_queue', tf.FIFOQueue(5, [x.dtype for x in self.inputs], name='input_queue'))
-        assert_type(self.input_queue, tf.QueueBase)
-        assert self.input_queue.dtypes == [x.dtype for x in self.inputs]
-        self.get_model_func = kwargs.pop('get_model_func')
-        self.batched_model_input = kwargs.pop('batched_model_input', True)
+        self.model = kwargs.pop('model')
+        assert_type(self.model, ModelDesc)
         self.step_per_epoch = int(kwargs.pop('step_per_epoch', self.dataset.size()))
         self.max_epoch = int(kwargs.pop('max_epoch', 100))
         assert self.step_per_epoch > 0 and self.max_epoch > 0
@@ -97,29 +82,22 @@ def start_train(config):
     Args:
         config: a TrainConfig instance
     """
-    input_vars = config.inputs
-    input_queue = config.input_queue
+    model = config.model
+    input_vars = model.get_input_vars()
+    input_queue = model.get_input_queue()
     callbacks = config.callbacks
 
-    tf.add_to_collection(FORWARD_FUNC_KEY, config.get_model_func)
-    for v in input_vars:
-        tf.add_to_collection(INPUT_VARS_KEY, v)
+    tf.add_to_collection(MODEL_KEY, model)
 
     def get_model_inputs():
         model_inputs = input_queue.dequeue()
         if isinstance(model_inputs, tf.Tensor):
             model_inputs = [model_inputs]
         for qv, v in zip(model_inputs, input_vars):
-            if config.batched_model_input:
-                qv.set_shape(v.get_shape())
-            else:
-                qv.set_shape(v.get_shape().as_list()[1:])
+            qv.set_shape(v.get_shape())
         return model_inputs
 
-    if config.batched_model_input:
-        enqueue_op = input_queue.enqueue(input_vars)
-    else:
-        enqueue_op = input_queue.enqueue_many(input_vars)
+    enqueue_op = input_queue.enqueue(input_vars)
 
     # get gradients to update:
     logger.info("Training a model of {} tower".format(config.nr_tower))
@@ -131,7 +109,7 @@ def start_train(config):
         with tf.device('/gpu:{}'.format(i)), \
                 tf.name_scope('tower{}'.format(i)) as scope:
             model_inputs = get_model_inputs()
-            cost_var = config.get_model_func(model_inputs, is_training=True)
+            cost_var = model.get_cost(model_inputs, is_training=True)
             grads.append(
                 config.optimizer.compute_gradients(cost_var))
@@ -145,7 +123,7 @@ def start_train(config):
         grads = average_gradients(grads)
     else:
         model_inputs = get_model_inputs()
-        cost_var = config.get_model_func(model_inputs, is_training=True)
+        cost_var = model.get_cost(model_inputs, is_training=True)
         grads = config.optimizer.compute_gradients(cost_var)
 
     summary_grads(grads)
     check_grads(grads)
...
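Since the queue is now created by ModelDesc.get_input_queue() (a plain FIFOQueue of capacity 50 by default), an example that previously built its own RandomShuffleQueue, such as the AlexNet script above, can restore that behaviour by overriding the hook. A sketch under that assumption, not code from this commit; CAPACITY and MIN_AFTER_DEQUEUE stand for the module-level constants the examples already define:

class ShuffledQueueModel(Model):    # 'Model' being any ModelDesc subclass from the examples
    def get_input_queue(self):
        # get_input_vars() has already been called by start_train(), so the
        # cached placeholders are available for their dtypes
        assert self.input_vars is not None
        return tf.RandomShuffleQueue(
            CAPACITY, MIN_AFTER_DEQUEUE,
            [x.dtype for x in self.input_vars], name='input_queue')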
@@ -29,7 +29,7 @@ class EnqueueThread(threading.Thread):
         super(EnqueueThread, self).__init__()
         self.sess = sess
         self.coord = coord
-        self.input_vars = sess.graph.get_collection(INPUT_VARS_KEY)
+        self.input_vars = sess.graph.get_collection(MODEL_KEY)[0].get_input_vars()
         self.dataflow = dataflow
         self.op = enqueue_op
         self.queue = queue
...
@@ -8,9 +8,8 @@ GLOBAL_STEP_VAR_NAME = 'global_step:0'
 SUMMARY_WRITER_COLLECTION_KEY = 'summary_writer'
 
-INPUT_VARS_KEY = 'INPUT_VARIABLES'
 MOVING_SUMMARY_VARS_KEY = 'MOVING_SUMMARY_VARIABLES'   # extra variables to summarize during training
-FORWARD_FUNC_KEY = 'FORWARD_FUNCTION'
+MODEL_KEY = 'MODEL'
 
 # export all upper case variables
 all_local_names = locals().keys()
...