Commit 879995d9 authored by Yuxin Wu

fix rnn sample(). fix gym breaking changes.

parent 14b3578a
@@ -6,18 +6,18 @@ you'll need to subclass `ModelDesc` and implement several methods:
 ```python
 class MyModel(ModelDesc):
-    def _get_input_vars(self):
+    def _get_inputs(self):
         return [InputVar(...), InputVar(...)]
-    def _build_graph(self, input_tensors):
+    def _build_graph(self, inputs):
         # build the graph
 ```
-Basically, `_get_input_vars` should define the metainfo of the input
+Basically, `_get_inputs` should define the metainfo of the input
 of the model. It should match what is produced by the data you're training with.
 `_build_graph` should add tensors/operations to the graph, where
 the argument `input_tensors` is the list of input tensors matching the return value of
-`_get_input_vars`.
+`_get_inputs`.
 You can use any symbolic functions in `_build_graph`, including TensorFlow core library
 functions, TensorFlow slim layers, or functions in other packages such as tflearn, tensorlayer.
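To make the renamed interface concrete, a minimal subclass might look like the sketch below. This is illustrative only, written against the tensorpack API as of this commit; the input names, shapes, and layer choices are made up for the example:

```python
import tensorflow as tf
from tensorpack import ModelDesc, InputVar, FullyConnected

class MyModel(ModelDesc):
    def _get_inputs(self):
        # metainfo of each input: dtype, shape, name.
        # This must match what your dataflow produces.
        return [InputVar(tf.float32, (None, 28, 28), 'input'),
                InputVar(tf.int32, (None,), 'label')]

    def _build_graph(self, inputs):
        # `inputs` are tensors, in the same order as _get_inputs()
        image, label = inputs
        feature = tf.reshape(image, [-1, 28 * 28])
        logits = FullyConnected('fc', feature, 10, nl=tf.identity)
        self.cost = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=logits, labels=label),
            name='cost')
```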
@@ -17,6 +17,7 @@ from tensorpack import *
 from tensorpack.tfutils.gradproc import *
 from tensorpack.utils.lut import LookUpTable
 from tensorpack.utils.globvars import globalns as param
+rnn = tf.contrib.rnn
 # some model hyperparams to set
 param.batch_size = 128
@@ -67,10 +68,18 @@ class Model(ModelDesc):
     def _build_graph(self, inputs):
         input, nextinput = inputs
-        cell = tf.contrib.rnn.BasicLSTMCell(num_units=param.rnn_size)
-        cell = tf.contrib.rnn.MultiRNNCell([cell] * param.num_rnn_layer)
+        cell = rnn.BasicLSTMCell(num_units=param.rnn_size)
+        cell = rnn.MultiRNNCell([cell] * param.num_rnn_layer)
-        self.initial = initial = cell.zero_state(tf.shape(input)[0], tf.float32)
+        def get_v(n):
+            ret = tf.get_variable(n + '_unused', [param.batch_size, param.rnn_size],
+                                  trainable=False,
+                                  initializer=tf.constant_initializer())
+            ret = symbolic_functions.shapeless_placeholder(ret, 0, name=n)
+            return ret
+        self.initial = initial = \
+            (rnn.LSTMStateTuple(get_v('c0'), get_v('h0')),
+             rnn.LSTMStateTuple(get_v('c1'), get_v('h1')))
         embeddingW = tf.get_variable('embedding', [param.vocab_size, param.rnn_size])
         input_feature = tf.nn.embedding_lookup(embeddingW, input)  # B x seqlen x rnnsize
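The `get_v` trick above exists because the sampling predictor must feed the RNN state by tensor name ('c0', 'h0', 'c1', 'h1') with a batch size unknown at graph-build time. A rough sketch of what `symbolic_functions.shapeless_placeholder` plausibly does, inferred from this usage rather than from the tensorpack source:

```python
# Inferred sketch, not the actual tensorpack implementation: take a
# default tensor, forget the static size of one axis, and expose the
# result as a named placeholder-with-default so callers can feed it
# by name with any size along that axis.
import tensorflow as tf

def shapeless_placeholder_sketch(x, axis, name):
    shape = x.get_shape().as_list()
    shape[axis] = None  # drop the static size of `axis`
    return tf.placeholder_with_default(x, shape=shape, name=name)
```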
@@ -78,13 +87,13 @@ class Model(ModelDesc):
         input_list = tf.unstack(input_feature, axis=1)  # seqlen x (Bxrnnsize)
         # seqlen is 1 in inference. don't need loop_function
-        outputs, last_state = tf.contrib.rnn.static_rnn(cell, input_list, initial, scope='rnnlm')
+        outputs, last_state = rnn.static_rnn(cell, input_list, initial, scope='rnnlm')
         self.last_state = tf.identity(last_state, 'last_state')
         # seqlen x (Bxrnnsize)
         output = tf.reshape(tf.concat(outputs, 1), [-1, param.rnn_size])  # (Bxseqlen) x rnnsize
         logits = FullyConnected('fc', output, param.vocab_size, nl=tf.identity)
-        self.prob = tf.nn.softmax(logits / param.softmax_temprature)
+        self.prob = tf.nn.softmax(logits / param.softmax_temprature, name='prob')
         xent_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
             logits=logits, labels=symbolic_functions.flatten(nextinput))
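Dividing the logits by `param.softmax_temprature` before the softmax controls how adventurous the sampling is. A standalone numpy illustration (the logit values here are made up):

```python
# Softmax temperature: dividing logits by a temperature < 1 sharpens
# the distribution; a temperature > 1 flattens it toward uniform.
import numpy as np

def softmax(logits, temperature=1.0):
    z = np.asarray(logits, dtype='float64') / temperature
    e = np.exp(z - z.max())  # subtract max for numerical stability
    return e / e.sum()

print(softmax([2.0, 1.0, 0.1], temperature=0.5))  # peaked
print(softmax([2.0, 1.0, 0.1], temperature=2.0))  # flatter
```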
@@ -130,19 +139,17 @@ def sample(path, start, length):
     param.seq_len = 1
     ds = CharRNNData(param.corpus, 100000)
-    model = Model()
-    inputs = model.get_reuse_placehdrs()
-    model.build_graph(inputs, False)
-    sess = tf.Session()
-    tfutils.SaverRestore(path).init(sess)
+    pred = OfflinePredictor(PredictConfig(
+        model=Model(),
+        session_init=SaverRestore(path),
+        input_names=['input', 'c0', 'h0', 'c1', 'h1'],
+        output_names=['prob', 'last_state']))
-    dummy_input = np.zeros((1, 1), dtype='int32')
-    with sess.as_default():
     # feed the starting sentence
-        state = model.initial.eval({inputs[0]: dummy_input})
+    initial = np.zeros((1, param.rnn_size))
     for c in start[:-1]:
         x = np.array([[ds.lut.get_idx(c)]], dtype='int32')
-        state = model.last_state.eval({inputs[0]: x, model.initial: state})
+        _, state = pred(x, initial, initial, initial, initial)
     def pick(prob):
         t = np.cumsum(prob)
@@ -154,8 +161,7 @@ def sample(path, start, length):
     c = start[-1]
     for k in range(length):
         x = np.array([[ds.lut.get_idx(c)]], dtype='int32')
-        [prob, state] = sess.run([model.prob, model.last_state],
-                                 {inputs[0]: x, model.initial: state})
+        prob, state = pred(x, state[0, 0], state[0, 1], state[1, 0], state[1, 1])
         c = ds.lut.get_obj(pick(prob[0]))
         ret += c
     print(ret)
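Since `last_state` is `tf.identity` applied to the nested pair of `LSTMStateTuple`s, the predictor returns it packed as a single `(2, 2, batch, rnn_size)` array, which is why the next call indexes `state[0, 0]` (c0), `state[0, 1]` (h0), and so on. The `pick()` helper samples an index from the output distribution via the inverse-CDF trick; its body is only partially visible above, so the completion below (using `np.searchsorted`) is an assumption, with demo values made up:

```python
# Self-contained sketch of the inverse-CDF sampling pick() performs.
import numpy as np

def pick(prob):
    t = np.cumsum(prob)   # empirical CDF of the distribution
    s = np.sum(prob)      # normalizer (prob need not sum to exactly 1)
    return int(np.searchsorted(t, np.random.rand(1) * s))

prob = np.array([0.1, 0.6, 0.3])
counts = np.bincount([pick(prob) for _ in range(10000)], minlength=3)
print(counts / counts.sum())  # approaches [0.1, 0.6, 0.3]
```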
@@ -34,7 +34,7 @@ class GymEnv(RLEnvironment):
         self.gymenv = gym.make(name)
         if dumpdir:
             mkdir_p(dumpdir)
-            self.gymenv.monitor.start(dumpdir)
+            self.gymenv = gym.wrappers.Monitor(self.gymenv, dumpdir)
         self.use_dir = dumpdir
         self.reset_stat()
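This is the gym breaking change the commit message refers to: monitoring moved from a method on the env object to a wrapper class. Shown standalone (assuming a gym version from this era that ships `gym.wrappers.Monitor`; the env name and dump directory are arbitrary):

```python
import gym
import gym.wrappers

env = gym.make('CartPole-v0')
# Old, removed API: env.monitor.start('/tmp/gym-dump')
env = gym.wrappers.Monitor(env, '/tmp/gym-dump')  # new wrapper-based API
env.reset()
```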
@@ -75,6 +75,7 @@ class GymEnv(RLEnvironment):
 try:
     import gym
+    import gym.wrappers
     # TODO
     # gym.undo_logger_setup()
     # https://github.com/openai/gym/pull/199