Commit 890df78f authored by Yuxin Wu

small fix in RandomApplyAug & A3C

parent a0247332
@@ -80,12 +80,12 @@ class MySimulatorWorker(SimulatorProcess):
 class Model(ModelDesc):
     def _get_inputs(self):
         assert NUM_ACTIONS is not None
-        return [InputDesc(tf.float32, (None,) + IMAGE_SHAPE3, 'state'),
+        return [InputDesc(tf.uint8, (None,) + IMAGE_SHAPE3, 'state'),
                 InputDesc(tf.int64, (None,), 'action'),
                 InputDesc(tf.float32, (None,), 'futurereward')]

     def _get_NN_prediction(self, image):
-        image = image / 255.0
+        image = tf.cast(image, tf.float32) / 255.0
         with argscope(Conv2D, nl=tf.nn.relu):
             l = Conv2D('conv0', image, out_channel=32, kernel_shape=5)
             l = MaxPooling('pool0', l, 2)
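Note: the hunk above switches the 'state' input to uint8 and moves the float conversion into the graph, so frames cross the Python/TF boundary as raw bytes. A minimal standalone sketch of the same pattern (TF 1.x API; the 84x84x12 shape is only an assumed stand-in for IMAGE_SHAPE3):

    import tensorflow as tf

    # feed frames as uint8 (4x less host-to-device traffic than float32),
    # then cast and normalize inside the graph
    state = tf.placeholder(tf.uint8, (None, 84, 84, 12), name='state')
    image = tf.cast(state, tf.float32) / 255.0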
@@ -220,7 +220,7 @@ def get_config():
         dataflow=dataflow,
         callbacks=[
             ModelSaver(),
-            ScheduledHyperParamSetter('learning_rate', [(80, 0.0003), (120, 0.0001)]),
+            ScheduledHyperParamSetter('learning_rate', [(20, 0.0003), (120, 0.0001)]),
             ScheduledHyperParamSetter('entropy_beta', [(80, 0.005)]),
             ScheduledHyperParamSetter('explore_factor',
                                       [(80, 2), (100, 3), (120, 4), (140, 5)]),
@@ -230,7 +230,7 @@ def get_config():
             StartProcOrThread(master),
             PeriodicTrigger(Evaluator(
                 EVAL_EPISODE, ['state'], ['policy'], get_player),
-                every_k_epochs=2),
+                every_k_epochs=3),
         ],
         session_creator=sesscreate.NewSessionCreator(
             config=get_default_sess_config(0.5)),
@@ -264,7 +264,7 @@ if __name__ == '__main__':
     if args.task != 'train':
         cfg = PredictConfig(
             model=Model(),
-            session_init=SaverRestore(args.load),
+            session_init=get_model_loader(args.load),
             input_names=['state'],
             output_names=['policy'])
         if args.task == 'play':
@@ -296,7 +296,7 @@ if __name__ == '__main__':
         trainer = QueueInputTrainer
     config = get_config()
     if args.load:
-        config.session_init = SaverRestore(args.load)
+        config.session_init = get_model_loader(args.load)
     config.tower = train_tower
     config.predict_tower = predict_tower
     trainer(config).train()
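Note: get_model_loader replaces the hard-coded SaverRestore so that --load accepts more than one weight format; as far as I know it inspects the file name and returns a checkpoint-based or a dict-based session initializer accordingly. A hedged usage sketch (file paths are hypothetical, import path per the tensorpack version in use):

    from tensorpack.tfutils.sessinit import get_model_loader

    # both should be accepted by the same flag:
    init = get_model_loader('train_log/train-atari/model-10000')   # TF checkpoint
    # init = get_model_loader('A3C-Breakout.npy')                  # .npy weight dump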
@@ -120,4 +120,4 @@ def play_n_episodes(player, predfunc, nr):
         if k != 0:
             player.restart_episode()
         score = play_one_episode(player, predfunc)
-        print("Score:", score)
+        print("{}/{}, score={}".format(k, nr, score))
@@ -109,7 +109,10 @@ class AugmentImageComponents(MapData):
                 to keep the original images not modified.
                 Turn it off to save time when you know it's OK.
         """
-        self.augs = AugmentorList(augmentors)
+        if isinstance(augmentors, AugmentorList):
+            self.augs = augmentors
+        else:
+            self.augs = AugmentorList(augmentors)
         self.ds = ds
         self._nr_error = 0
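Note: with the isinstance check above, AugmentImageComponents takes either a plain list of augmentors or an already-built AugmentorList (e.g. one shared with another dataflow) without double-wrapping it. A small hedged sketch; FakeData is only used as a stand-in source and exact import paths may differ by tensorpack version:

    from tensorpack.dataflow import AugmentImageComponents, FakeData, imgaug
    from tensorpack.dataflow.imgaug import AugmentorList

    ds = FakeData([[64, 64, 3], [64, 64, 3]], size=100)
    shared_augs = AugmentorList([imgaug.Flip(horiz=True), imgaug.Brightness(30)])
    # both call forms are accepted now; a prebuilt AugmentorList is reused as-is
    ds_a = AugmentImageComponents(ds, shared_augs, index=(0, 1))
    ds_b = AugmentImageComponents(ds, [imgaug.Flip(horiz=True)], index=(0, 1))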
@@ -38,6 +38,14 @@ class RandomApplyAug(ImageAugmentor):
         else:
             return (False, None)

+    def _augment_return_params(self, img):
+        p = self.rng.rand()
+        if p < self.prob:
+            img, prms = self.aug._augment_return_params(img)
+            return img, (True, prms)
+        else:
+            return img, (False, None)
+
     def reset_state(self):
         super(RandomApplyAug, self).reset_state()
         self.aug.reset_state()
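Note: the new _augment_return_params mirrors _get_augment_params above: it rolls the dice once, applies the wrapped augmentor when the roll is below prob, and records (applied, inner_params) so the same decision can be replayed. A hedged sketch of calling it directly (img is a placeholder array; in normal use the dataflow drives these methods):

    import numpy as np
    from tensorpack.dataflow import imgaug

    img = np.random.randint(0, 255, (84, 84, 3), dtype=np.uint8)  # placeholder image
    aug = imgaug.RandomApplyAug(imgaug.Brightness(30), prob=0.5)
    aug.reset_state()
    out, (applied, inner_prms) = aug._augment_return_params(img)
    # `applied` mirrors the (True, prms) / (False, None) tuples returned above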
@@ -31,7 +31,7 @@ def FullyConnected(x, out_dim,
     Variable Names:

-    * ``W``: weights
+    * ``W``: weights of shape [in_dim, out_dim]
     * ``b``: bias
     """
     x = symbf.batch_flatten(x)
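Note: the docstring fix pins down that W is created after the input is flattened. A hedged usage sketch (the placeholder shape is illustrative; import path per the tensorpack version in use):

    import tensorflow as tf
    from tensorpack import FullyConnected

    # the layer flattens everything but the batch axis, so here W gets shape
    # [7 * 7 * 64, 512] and b gets shape [512], matching the docstring above
    feat = tf.placeholder(tf.float32, (None, 7, 7, 64), name='feat')
    out = FullyConnected('fc0', feat, out_dim=512, nl=tf.nn.relu)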
@@ -99,9 +99,9 @@ class PredictorWorkerThread(StoppableThread, ShareSessionThread):
            #     self.id, len(futures), self.queue.qsize())
            # debug, for speed testing
            # if not hasattr(self, 'xxx'):
-                # self.xxx = outputs = self.func(batched)
+            #     self.xxx = outputs = self.func(batched)
            # else:
-                # outputs = [[self.xxx[0][0]] * len(batched[0]), [self.xxx[1][0]] * len(batched[0])]
+            #     outputs = [[self.xxx[0][0]] * len(batched[0]), [self.xxx[1][0]] * len(batched[0])]

            for idx, f in enumerate(futures):
                f.set_result([k[idx] for k in outputs])