Commit 879995d9 authored by Yuxin Wu

fix rnn sample(). fix gym breaking changes.

parent 14b3578a
@@ -6,18 +6,18 @@ you'll need to subclass `ModelDesc` and implement several methods:
 ```python
 class MyModel(ModelDesc):
-    def _get_input_vars(self):
+    def _get_inputs(self):
         return [InputVar(...), InputVar(...)]
-    def _build_graph(self, input_tensors):
+    def _build_graph(self, inputs):
         # build the graph
 ```
-Basically, `_get_input_vars` should define the metainfo of the input
+Basically, `_get_inputs` should define the metainfo of the input
 of the model. It should match what is produced by the data you're training with.
 `_build_graph` should add tensors/operations to the graph, where
 the argument `input_tensors` is the list of input tensors matching the return value of
-`_get_input_vars`.
+`_get_inputs`.
 You can use any symbolic functions in `_build_graph`, including TensorFlow core library
 functions, TensorFlow slim layers, or functions in other packages such as tflearn, tensorlayer.
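For context, here is how a model could look under the renamed interface. This is a minimal sketch only, assuming the tensorpack symbols referenced above (`ModelDesc`, `InputVar`, `FullyConnected`); the class name, input names, shapes, and the `self.cost` convention are illustrative and not part of this commit.

```python
import tensorflow as tf
from tensorpack import ModelDesc, InputVar, FullyConnected


class ToyModel(ModelDesc):
    """Hypothetical two-input model; names, dtypes and shapes are illustrative only."""

    def _get_inputs(self):
        # metainfo of the inputs: must match what the training data produces (dtype, shape, name)
        return [InputVar(tf.float32, (None, 10), 'feature'),
                InputVar(tf.float32, (None, 1), 'label')]

    def _build_graph(self, inputs):
        # `inputs` arrive in the same order as the list returned by _get_inputs()
        feature, label = inputs
        pred = FullyConnected('fc', feature, 1, nl=tf.identity)
        # tensorpack examples of this era expose the training cost as self.cost
        self.cost = tf.reduce_mean(tf.squared_difference(pred, label), name='cost')
```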
@@ -17,6 +17,7 @@ from tensorpack import *
 from tensorpack.tfutils.gradproc import *
 from tensorpack.utils.lut import LookUpTable
 from tensorpack.utils.globvars import globalns as param
+rnn = tf.contrib.rnn
 
 # some model hyperparams to set
 param.batch_size = 128
@@ -67,10 +68,18 @@ class Model(ModelDesc):
     def _build_graph(self, inputs):
         input, nextinput = inputs
-        cell = tf.contrib.rnn.BasicLSTMCell(num_units=param.rnn_size)
-        cell = tf.contrib.rnn.MultiRNNCell([cell] * param.num_rnn_layer)
-        self.initial = initial = cell.zero_state(tf.shape(input)[0], tf.float32)
+        cell = rnn.BasicLSTMCell(num_units=param.rnn_size)
+        cell = rnn.MultiRNNCell([cell] * param.num_rnn_layer)
+
+        def get_v(n):
+            ret = tf.get_variable(n + '_unused', [param.batch_size, param.rnn_size],
+                                  trainable=False,
+                                  initializer=tf.constant_initializer())
+            ret = symbolic_functions.shapeless_placeholder(ret, 0, name=n)
+            return ret
+        self.initial = initial = \
+            (rnn.LSTMStateTuple(get_v('c0'), get_v('h0')),
+             rnn.LSTMStateTuple(get_v('c1'), get_v('h1')))
 
         embeddingW = tf.get_variable('embedding', [param.vocab_size, param.rnn_size])
         input_feature = tf.nn.embedding_lookup(embeddingW, input)  # B x seqlen x rnnsize

@@ -78,13 +87,13 @@ class Model(ModelDesc):
         input_list = tf.unstack(input_feature, axis=1)  # seqlen x (Bxrnnsize)
         # seqlen is 1 in inference. don't need loop_function
-        outputs, last_state = tf.contrib.rnn.static_rnn(cell, input_list, initial, scope='rnnlm')
+        outputs, last_state = rnn.static_rnn(cell, input_list, initial, scope='rnnlm')
         self.last_state = tf.identity(last_state, 'last_state')
 
         # seqlen x (Bxrnnsize)
         output = tf.reshape(tf.concat(outputs, 1), [-1, param.rnn_size])  # (Bxseqlen) x rnnsize
         logits = FullyConnected('fc', output, param.vocab_size, nl=tf.identity)
-        self.prob = tf.nn.softmax(logits / param.softmax_temprature)
+        self.prob = tf.nn.softmax(logits / param.softmax_temprature, name='prob')
 
         xent_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
             logits=logits, labels=symbolic_functions.flatten(nextinput))
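The `get_v`/`shapeless_placeholder` change above exists so that the RNN's initial state becomes a set of named, batch-size-agnostic tensors (`c0`, `h0`, `c1`, `h1`) that `sample()` can feed at inference time. Below is a rough standalone sketch of the same idea in plain TF 1.x, using ordinary placeholders instead of tensorpack's `shapeless_placeholder`; the sizes and names are illustrative, not the project's code.

```python
import tensorflow as tf

rnn = tf.contrib.rnn  # TF 1.x layout, as in the diff above

# Illustrative sizes; the example itself uses param.rnn_size / param.num_rnn_layer.
RNN_SIZE, NUM_LAYERS, SEQ_LEN = 256, 2, 1

cell = rnn.MultiRNNCell([rnn.BasicLSTMCell(RNN_SIZE) for _ in range(NUM_LAYERS)])


def state_placeholder(name):
    # batch dimension left unknown, so inference can feed batch size 1
    return tf.placeholder(tf.float32, [None, RNN_SIZE], name=name)


# one (c, h) pair per layer, addressable by name when running the graph
initial = tuple(rnn.LSTMStateTuple(state_placeholder('c%d' % i),
                                   state_placeholder('h%d' % i))
                for i in range(NUM_LAYERS))

inputs = tf.placeholder(tf.float32, [None, SEQ_LEN, RNN_SIZE], name='input')
input_list = tf.unstack(inputs, axis=1)  # seqlen x (B x rnnsize)
outputs, last_state = rnn.static_rnn(cell, input_list, initial, scope='rnnlm')
```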
@@ -130,35 +139,32 @@ def sample(path, start, length):
     param.seq_len = 1
     ds = CharRNNData(param.corpus, 100000)
-    model = Model()
-    inputs = model.get_reuse_placehdrs()
-    model.build_graph(inputs, False)
-    sess = tf.Session()
-    tfutils.SaverRestore(path).init(sess)
-
-    dummy_input = np.zeros((1, 1), dtype='int32')
-    with sess.as_default():
-        # feed the starting sentence
-        state = model.initial.eval({inputs[0]: dummy_input})
-        for c in start[:-1]:
-            x = np.array([[ds.lut.get_idx(c)]], dtype='int32')
-            state = model.last_state.eval({inputs[0]: x, model.initial: state})
-
-        def pick(prob):
-            t = np.cumsum(prob)
-            s = np.sum(prob)
-            return(int(np.searchsorted(t, np.random.rand(1) * s)))
-
-        # generate more
-        ret = start
-        c = start[-1]
-        for k in range(length):
-            x = np.array([[ds.lut.get_idx(c)]], dtype='int32')
-            [prob, state] = sess.run([model.prob, model.last_state],
-                                     {inputs[0]: x, model.initial: state})
-            c = ds.lut.get_obj(pick(prob[0]))
-            ret += c
-        print(ret)
+    pred = OfflinePredictor(PredictConfig(
+        model=Model(),
+        session_init=SaverRestore(path),
+        input_names=['input', 'c0', 'h0', 'c1', 'h1'],
+        output_names=['prob', 'last_state']))
+
+    # feed the starting sentence
+    initial = np.zeros((1, param.rnn_size))
+    for c in start[:-1]:
+        x = np.array([[ds.lut.get_idx(c)]], dtype='int32')
+        _, state = pred(x, initial, initial, initial, initial)
+
+    def pick(prob):
+        t = np.cumsum(prob)
+        s = np.sum(prob)
+        return(int(np.searchsorted(t, np.random.rand(1) * s)))
+
+    # generate more
+    ret = start
+    c = start[-1]
+    for k in range(length):
+        x = np.array([[ds.lut.get_idx(c)]], dtype='int32')
+        prob, state = pred(x, state[0, 0], state[0, 1], state[1, 0], state[1, 1])
+        c = ds.lut.get_obj(pick(prob[0]))
+        ret += c
+    print(ret)
 
 
 if __name__ == '__main__':
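The `pick` helper above draws an index in proportion to a (not necessarily normalized) probability vector via its cumulative sum. A self-contained NumPy sketch of the same routine, with a quick sanity check; the check itself is illustrative and not part of the commit.

```python
import numpy as np


def pick(prob):
    """Draw index i with probability prob[i] / sum(prob), by inverse-CDF sampling."""
    t = np.cumsum(prob)   # unnormalized CDF
    s = np.sum(prob)      # normalizer (prob need not sum to 1)
    return int(np.searchsorted(t, np.random.rand(1) * s))


if __name__ == '__main__':
    prob = np.array([0.1, 0.6, 0.3])
    counts = np.bincount([pick(prob) for _ in range(10000)], minlength=3)
    print(counts / counts.sum())  # roughly [0.1, 0.6, 0.3]
```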
@@ -34,7 +34,7 @@ class GymEnv(RLEnvironment):
         self.gymenv = gym.make(name)
         if dumpdir:
             mkdir_p(dumpdir)
-            self.gymenv.monitor.start(dumpdir)
+            self.gymenv = gym.wrappers.Monitor(self.gymenv, dumpdir)
         self.use_dir = dumpdir
 
         self.reset_stat()

@@ -75,6 +75,7 @@ class GymEnv(RLEnvironment):
 try:
     import gym
+    import gym.wrappers
     # TODO
     # gym.undo_logger_setup()
     # https://github.com/openai/gym/pull/199
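The gym change above tracks the removal of `env.monitor.start()` in favor of wrapping the environment with `gym.wrappers.Monitor`. A minimal standalone sketch of the same migration for gym versions of that era; the environment id and dump directory are placeholders.

```python
import gym
import gym.wrappers

env = gym.make('CartPole-v0')          # placeholder environment id
# old API (removed by gym's breaking change):
#   env.monitor.start('/tmp/gym-dump')
# new API: wrap the env instead; video_callable=False skips video recording
# so the sketch also runs headless
env = gym.wrappers.Monitor(env, '/tmp/gym-dump', video_callable=False)

ob = env.reset()
for _ in range(100):
    ob, reward, done, info = env.step(env.action_space.sample())
    if done:
        ob = env.reset()
env.close()  # flushes the monitor's stats to the dump directory
```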