Commit 3f14f3a7 authored by Yuxin Wu

Fix deprecated use of PeriodicCallback

parent e49d4fd4
...@@ -299,7 +299,7 @@ def get_config(): ...@@ -299,7 +299,7 @@ def get_config():
return TrainConfig( return TrainConfig(
dataflow=dataset, dataflow=dataset,
optimizer=tf.train.AdamOptimizer(lr), optimizer=tf.train.AdamOptimizer(lr),
callbacks=[PeriodicCallback(ModelSaver(), 3)], callbacks=[PeriodicTrigger(ModelSaver(), every_k_epochs=3)],
model=Model(), model=Model(),
steps_per_epoch=dataset.size(), steps_per_epoch=dataset.size(),
max_epoch=100, max_epoch=100,
......
...@@ -220,7 +220,7 @@ def get_config(): ...@@ -220,7 +220,7 @@ def get_config():
[(80, 2), (100, 3), (120, 4), (140, 5)]), [(80, 2), (100, 3), (120, 4), (140, 5)]),
master, master,
StartProcOrThread(master), StartProcOrThread(master),
PeriodicCallback(Evaluator(EVAL_EPISODE, ['state'], ['policy']), 2), PeriodicTrigger(Evaluator(EVAL_EPISODE, ['state'], ['policy']), every_k_epochs=2),
], ],
session_creator=sesscreate.NewSessionCreator( session_creator=sesscreate.NewSessionCreator(
config=get_default_sess_config(0.5)), config=get_default_sess_config(0.5)),
......
...@@ -73,7 +73,7 @@ class Model(ModelDesc): ...@@ -73,7 +73,7 @@ class Model(ModelDesc):
def _build_graph(self, inputs): def _build_graph(self, inputs):
input, nextinput = inputs input, nextinput = inputs
cell = rnn.MultiRNNCell([rnn.BasicLSTMCell(num_units=param.rnn_size) cell = rnn.MultiRNNCell([rnn.LSTMBlockCell(num_units=param.rnn_size)
for _ in range(param.num_rnn_layer)]) for _ in range(param.num_rnn_layer)])
def get_v(n): def get_v(n):
...@@ -91,7 +91,6 @@ class Model(ModelDesc): ...@@ -91,7 +91,6 @@ class Model(ModelDesc):
input_list = tf.unstack(input_feature, axis=1) # seqlen x (Bxrnnsize) input_list = tf.unstack(input_feature, axis=1) # seqlen x (Bxrnnsize)
# seqlen is 1 in inference. don't need loop_function
outputs, last_state = rnn.static_rnn(cell, input_list, initial, scope='rnnlm') outputs, last_state = rnn.static_rnn(cell, input_list, initial, scope='rnnlm')
self.last_state = tf.identity(last_state, 'last_state') self.last_state = tf.identity(last_state, 'last_state')
......
...@@ -84,7 +84,10 @@ except ImportError: ...@@ -84,7 +84,10 @@ except ImportError:
if __name__ == '__main__': if __name__ == '__main__':
env = GymEnv('Breakout-v0', viz=0.1) import gym_ple, cv2 # noqa
import os
os.environ["SDL_VIDEODRIVER"] = "dummy"
env = GymEnv('FlappyBird-v0', viz=0.1)
num = env.get_action_space().num_actions() num = env.get_action_space().num_actions()
from ..utils import get_rng from ..utils import get_rng
...@@ -93,6 +96,10 @@ if __name__ == '__main__': ...@@ -93,6 +96,10 @@ if __name__ == '__main__':
act = rng.choice(range(num)) act = rng.choice(range(num))
# print act # print act
r, o = env.action(act) r, o = env.action(act)
env.current_state() state = env.current_state()
state = cv2.resize(state[:450], (84, 84))
cv2.imshow("aa", state)
cv2.waitKey(3)
print(state.shape)
if r != 0 or o: if r != 0 or o:
print(r, o) print(r, o)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment