Commit a3d6c93d authored by Yuxin Wu

better handle params

parent 0ca9fef5
@@ -9,6 +9,8 @@ import os, sys
 import argparse
 from collections import Counter
 import operator
+import six
+from six.moves import map, range

 from tensorpack import *
 from tensorpack.models import *
@@ -20,16 +22,26 @@ from tensorpack.callbacks import *
 from tensorflow.models.rnn import rnn_cell
 from tensorflow.models.rnn import seq2seq

-BATCH_SIZE = 128
-RNN_SIZE = 128  # hidden state size
-NUM_RNN_LAYER = 2
-SEQ_LEN = 50
-VOCAB_SIZE = None  # will be initialized by CharRNNData
-CORPUS = 'input.txt'
+if six.PY2:
+    class NS: pass  # this is a hack
+else:
+    import types
+    NS = types.SimpleNamespace  # this is what I wanted..
+param = NS()
+
+# some model hyperparams to set
+param.batch_size = 128
+param.rnn_size = 128
+param.num_rnn_layer = 2
+param.seq_len = 50
+param.grad_clip = 5.
+param.vocab_size = None
+param.softmax_temprature = 1
+param.corpus = 'input.txt'
+# Get corpus to play with at: http://cs.stanford.edu/people/karpathy/char-rnn/

 class CharRNNData(DataFlow):
     def __init__(self, input_file, size):
-        self.seq_length = SEQ_LEN
+        self.seq_length = param.seq_len
         self._size = size
         self.rng = get_rng(self)
@@ -40,8 +52,7 @@ class CharRNNData(DataFlow):
         char_cnt = sorted(counter.items(), key=operator.itemgetter(1), reverse=True)
         self.chars = [x[0] for x in char_cnt]
         self.vocab_size = len(self.chars)
-        global VOCAB_SIZE
-        VOCAB_SIZE = self.vocab_size
+        param.vocab_size = self.vocab_size
         self.lut = LookUpTable(self.chars)
         self.whole_seq = np.array(list(map(self.lut.get_idx, data)), dtype='int32')
@@ -61,50 +72,48 @@ class CharRNNData(DataFlow):
 class Model(ModelDesc):
     def _get_input_vars(self):
-        return [InputVar(tf.int32, (None, SEQ_LEN), 'input'),
-                InputVar(tf.int32, (None, SEQ_LEN), 'nextinput')
-               ]
+        return [InputVar(tf.int32, (None, param.seq_len), 'input'),
+                InputVar(tf.int32, (None, param.seq_len), 'nextinput') ]

     def _get_cost(self, input_vars, is_training):
         input, nextinput = input_vars
-        cell = rnn_cell.BasicLSTMCell(RNN_SIZE)
-        cell = rnn_cell.MultiRNNCell([cell] * NUM_RNN_LAYER)
+        cell = rnn_cell.BasicLSTMCell(num_units=param.rnn_size)
+        cell = rnn_cell.MultiRNNCell([cell] * param.num_rnn_layer)
         self.initial = initial = cell.zero_state(tf.shape(input)[0], tf.float32)

-        embeddingW = tf.get_variable('embedding', [VOCAB_SIZE, RNN_SIZE])
+        embeddingW = tf.get_variable('embedding', [param.vocab_size, param.rnn_size])
         input_feature = tf.nn.embedding_lookup(embeddingW, input)  # B x seqlen x rnnsize

-        input_list = tf.split(1, SEQ_LEN, input_feature)  # seqlen x (Bx1xrnnsize)
+        input_list = tf.split(1, param.seq_len, input_feature)  # seqlen x (Bx1xrnnsize)
         input_list = [tf.squeeze(x, [1]) for x in input_list]
         # seqlen is 1 in inference. don't need loop_function
         outputs, last_state = seq2seq.rnn_decoder(input_list, initial, cell, scope='rnnlm')
         self.last_state = tf.identity(last_state, 'last_state')

         # seqlen x (Bxrnnsize)
-        output = tf.reshape(tf.concat(1, outputs), [-1, RNN_SIZE])  # (seqlenxB) x rnnsize
-        logits = FullyConnected('fc', output, VOCAB_SIZE, nl=tf.identity)
-        self.prob = tf.nn.softmax(logits)
+        output = tf.reshape(tf.concat(1, outputs), [-1, param.rnn_size])  # (seqlenxB) x rnnsize
+        logits = FullyConnected('fc', output, param.vocab_size, nl=tf.identity)
+        self.prob = tf.nn.softmax(logits / param.softmax_temprature)

         xent_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
             logits, symbolic_functions.flatten(nextinput))
         xent_loss = tf.reduce_mean(xent_loss, name='xent_loss')
-        tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, xent_loss)

         summary.add_param_summary([('.*/W', ['histogram'])])  # monitor histogram of all W
         return tf.add_n([xent_loss], name='cost')

     def get_gradient_processor(self):
-        return [MapGradient(lambda grad: tf.clip_by_global_norm([grad], 5.)[0][0])]
+        return [MapGradient(lambda grad: tf.clip_by_global_norm(
+            [grad], param.grad_clip)[0][0])]

 def get_config():
     basename = os.path.basename(__file__)
     logger.set_logger_dir(
         os.path.join('train_log', basename[:basename.rfind('.')]))

-    ds = CharRNNData(CORPUS, 100000)
-    ds = BatchData(ds, 128)
+    ds = CharRNNData(param.corpus, 100000)
+    ds = BatchData(ds, param.batch_size)
     step_per_epoch = ds.size()

     lr = tf.Variable(2e-3, trainable=False, name='learning_rate')
@@ -130,9 +139,8 @@ def sample(path, start, length):
     :param length: a `int`. the length of text to generate
     """
     # initialize vocabulary and sequence length
-    global SEQ_LEN
-    SEQ_LEN = 1
-    ds = CharRNNData(CORPUS, 100000)
+    param.seq_len = 1
+    ds = CharRNNData(param.corpus, 100000)
     model = Model()
     input_vars = model.get_input_vars()
@@ -170,10 +178,12 @@ if __name__ == '__main__':
    parser.add_argument('--load', help='load model')
    subparsers = parser.add_subparsers(title='command', dest='command')
    parser_sample = subparsers.add_parser('sample', help='sample a trained model')
-    parser_sample.add_argument('-n', '--num', type=int, default=300,
-                               help='length of text to generate')
-    parser_sample.add_argument('-s', '--start', required=True, default='The ',
-                               help='initial text sequence')
+    parser_sample.add_argument('-n', '--num', type=int,
+                               default=300, help='length of text to generate')
+    parser_sample.add_argument('-s', '--start',
+                               default='The ', help='initial text sequence')
+    parser_sample.add_argument('-t', '--temperature', type=float,
+                               default=1, help='softmax temperature')
    args = parser.parse_args()
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
@@ -181,6 +191,7 @@ if __name__ == '__main__':
        os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    if args.command == 'sample':
+        param.softmax_temprature = args.temperature
        sample(args.load, args.start, args.num)
        sys.exit()
    else:
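The central change in the char-rnn example is replacing the module-level constants (BATCH_SIZE, SEQ_LEN, VOCAB_SIZE, ...) with a single mutable `param` namespace: `types.SimpleNamespace` on Python 3, a bare class on Python 2. Because every reader looks the values up at call time, later code such as `sample()` or the new `--temperature` flag can override them globally. The sketch below is only an illustration of that pattern, not part of the commit; `describe_model` and the attribute names other than `seq_len` are made up for the example.

import six

if six.PY2:
    class NS(object):
        pass  # plain instances accept arbitrary attributes
else:
    import types
    NS = types.SimpleNamespace  # attribute bag built into Python 3

# hypothetical hyper-parameter bag mirroring the commit's `param` object
param = NS()
param.seq_len = 50
param.softmax_temperature = 1.0

def describe_model():
    # reads the *current* values, so overriding them later (e.g. from argparse)
    # changes what every consumer of `param` sees
    return 'seq_len=%d, temperature=%.1f' % (param.seq_len, param.softmax_temperature)

print(describe_model())            # seq_len=50, temperature=1.0
param.seq_len = 1                  # e.g. switch to one-step sampling
param.softmax_temperature = 0.7    # e.g. value passed via -t/--temperature
print(describe_model())            # seq_len=1, temperature=0.7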
@@ -8,7 +8,7 @@ import numpy as np
import os, sys
import argparse

-import tensorpack as tp
+from tensorpack import *
from tensorpack.models import *
from tensorpack.utils import *
from tensorpack.callbacks import *
@@ -56,7 +56,7 @@ class Model(ModelDesc):
        tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)

        # compute the number of failed samples, for ClassificationError to use at test time
-        wrong = tp.symbolic_functions.prediction_incorrect(logits, label)
+        wrong = symbolic_functions.prediction_incorrect(logits, label)
        nr_wrong = tf.reduce_sum(wrong, name='wrong')
        # monitor training error
        tf.add_to_collection(
...@@ -68,7 +68,7 @@ class Model(ModelDesc): ...@@ -68,7 +68,7 @@ class Model(ModelDesc):
name='regularize_loss') name='regularize_loss')
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, wd_cost) tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, wd_cost)
tp.summary.add_param_summary([('.*/W', ['histogram'])]) # monitor histogram of all W summary.add_param_summary([('.*/W', ['histogram'])]) # monitor histogram of all W
return tf.add_n([wd_cost, cost], name='cost') return tf.add_n([wd_cost, cost], name='cost')
def get_config(): def get_config():
@@ -77,22 +77,18 @@ def get_config():
        os.path.join('train_log', basename[:basename.rfind('.')]))

    # prepare dataset
-    dataset_train = tp.BatchData(tp.dataset.Mnist('train'), 128)
-    dataset_test = tp.BatchData(tp.dataset.Mnist('test'), 256, remainder=True)
+    dataset_train = BatchData(dataset.Mnist('train'), 128)
+    dataset_test = BatchData(dataset.Mnist('test'), 256, remainder=True)
    step_per_epoch = dataset_train.size()

-    # prepare session
-    sess_config = tp.get_default_sess_config()
-    sess_config.gpu_options.per_process_gpu_memory_fraction = 0.5
-
    lr = tf.train.exponential_decay(
        learning_rate=1e-3,
-        global_step=tp.get_global_step_var(),
+        global_step=get_global_step_var(),
        decay_steps=dataset_train.size() * 10,
        decay_rate=0.3, staircase=True, name='learning_rate')
    tf.scalar_summary('learning_rate', lr)

-    return tp.TrainConfig(
+    return TrainConfig(
        dataset=dataset_train,
        optimizer=tf.train.AdamOptimizer(lr),
        callbacks=Callbacks([
@@ -101,7 +97,7 @@ def get_config():
            InferenceRunner(dataset_test,
                [ScalarStats('cost'), ClassificationError() ])
        ]),
-        session_config=sess_config,
+        session_config=get_default_sess_config(0.5),
        model=Model(),
        step_per_epoch=step_per_epoch,
        max_epoch=100,
@@ -121,6 +117,5 @@ if __name__ == '__main__':
    config = get_config()
    if args.load:
        config.session_init = SaverRestore(args.load)
-    #tp.SimpleTrainer(config).train()
-    tp.QueueInputTrainer(config).train()
+    QueueInputTrainer(config).train()
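The MNIST example also stops assembling a session config by hand and simply passes `get_default_sess_config(0.5)` to `TrainConfig`. The helper's exact contents are an assumption here, but the deleted lines suggest it amounts to roughly the following plain-TensorFlow sketch, where `make_sess_config` is a hypothetical stand-in.

import tensorflow as tf

def make_sess_config(mem_fraction=0.5):
    # cap per-process GPU memory at `mem_fraction` and grow allocations lazily;
    # roughly what a helper like get_default_sess_config(0.5) is assumed to build
    conf = tf.ConfigProto(allow_soft_placement=True)
    conf.gpu_options.per_process_gpu_memory_fraction = mem_fraction
    conf.gpu_options.allow_growth = True
    return conf

# usage: sess = tf.Session(config=make_sess_config(0.5))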
@@ -96,7 +96,6 @@ class MapGradient(GradientProcessor):
        ret = []
        for grad, var in grads:
            if re.match(self.regex, var.op.name):
-                logger.info("DEBUG {}".format(var.op.name))
                ret.append((self.func(grad), var))
            else:
                ret.append((grad, var))
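The last hunk only drops a leftover DEBUG log line from MapGradient. For context, the processor applies a function to every gradient whose variable name matches a regex, which is how the char-rnn example clips each gradient to `param.grad_clip` above. A condensed, self-contained sketch of the same idea; the free function `map_gradient` is illustrative, not the library's API.

import re

def map_gradient(grads_and_vars, func, regex='.*'):
    # apply `func` to every gradient whose variable name matches `regex`,
    # leaving the others untouched -- the behaviour MapGradient implements
    ret = []
    for grad, var in grads_and_vars:
        if re.match(regex, var.op.name):
            ret.append((func(grad), var))
        else:
            ret.append((grad, var))
    return ret

# usage mirroring get_gradient_processor() in the char-rnn example:
#   clip = lambda g: tf.clip_by_global_norm([g], 5.)[0][0]
#   grads = map_gradient(opt.compute_gradients(cost), clip)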