Commit 890df78f authored by Yuxin Wu's avatar Yuxin Wu

small fix in RandomApplyAug & A3C

parent a0247332
...@@ -80,12 +80,12 @@ class MySimulatorWorker(SimulatorProcess): ...@@ -80,12 +80,12 @@ class MySimulatorWorker(SimulatorProcess):
class Model(ModelDesc): class Model(ModelDesc):
def _get_inputs(self): def _get_inputs(self):
assert NUM_ACTIONS is not None assert NUM_ACTIONS is not None
return [InputDesc(tf.float32, (None,) + IMAGE_SHAPE3, 'state'), return [InputDesc(tf.uint8, (None,) + IMAGE_SHAPE3, 'state'),
InputDesc(tf.int64, (None,), 'action'), InputDesc(tf.int64, (None,), 'action'),
InputDesc(tf.float32, (None,), 'futurereward')] InputDesc(tf.float32, (None,), 'futurereward')]
def _get_NN_prediction(self, image): def _get_NN_prediction(self, image):
image = image / 255.0 image = tf.cast(image, tf.float32) / 255.0
with argscope(Conv2D, nl=tf.nn.relu): with argscope(Conv2D, nl=tf.nn.relu):
l = Conv2D('conv0', image, out_channel=32, kernel_shape=5) l = Conv2D('conv0', image, out_channel=32, kernel_shape=5)
l = MaxPooling('pool0', l, 2) l = MaxPooling('pool0', l, 2)
...@@ -220,7 +220,7 @@ def get_config(): ...@@ -220,7 +220,7 @@ def get_config():
dataflow=dataflow, dataflow=dataflow,
callbacks=[ callbacks=[
ModelSaver(), ModelSaver(),
ScheduledHyperParamSetter('learning_rate', [(80, 0.0003), (120, 0.0001)]), ScheduledHyperParamSetter('learning_rate', [(20, 0.0003), (120, 0.0001)]),
ScheduledHyperParamSetter('entropy_beta', [(80, 0.005)]), ScheduledHyperParamSetter('entropy_beta', [(80, 0.005)]),
ScheduledHyperParamSetter('explore_factor', ScheduledHyperParamSetter('explore_factor',
[(80, 2), (100, 3), (120, 4), (140, 5)]), [(80, 2), (100, 3), (120, 4), (140, 5)]),
...@@ -230,7 +230,7 @@ def get_config(): ...@@ -230,7 +230,7 @@ def get_config():
StartProcOrThread(master), StartProcOrThread(master),
PeriodicTrigger(Evaluator( PeriodicTrigger(Evaluator(
EVAL_EPISODE, ['state'], ['policy'], get_player), EVAL_EPISODE, ['state'], ['policy'], get_player),
every_k_epochs=2), every_k_epochs=3),
], ],
session_creator=sesscreate.NewSessionCreator( session_creator=sesscreate.NewSessionCreator(
config=get_default_sess_config(0.5)), config=get_default_sess_config(0.5)),
...@@ -264,7 +264,7 @@ if __name__ == '__main__': ...@@ -264,7 +264,7 @@ if __name__ == '__main__':
if args.task != 'train': if args.task != 'train':
cfg = PredictConfig( cfg = PredictConfig(
model=Model(), model=Model(),
session_init=SaverRestore(args.load), session_init=get_model_loader(args.load),
input_names=['state'], input_names=['state'],
output_names=['policy']) output_names=['policy'])
if args.task == 'play': if args.task == 'play':
...@@ -296,7 +296,7 @@ if __name__ == '__main__': ...@@ -296,7 +296,7 @@ if __name__ == '__main__':
trainer = QueueInputTrainer trainer = QueueInputTrainer
config = get_config() config = get_config()
if args.load: if args.load:
config.session_init = SaverRestore(args.load) config.session_init = get_model_loader(args.load)
config.tower = train_tower config.tower = train_tower
config.predict_tower = predict_tower config.predict_tower = predict_tower
trainer(config).train() trainer(config).train()
...@@ -120,4 +120,4 @@ def play_n_episodes(player, predfunc, nr): ...@@ -120,4 +120,4 @@ def play_n_episodes(player, predfunc, nr):
if k != 0: if k != 0:
player.restart_episode() player.restart_episode()
score = play_one_episode(player, predfunc) score = play_one_episode(player, predfunc)
print("Score:", score) print("{}/{}, score=", k, nr, score)
...@@ -109,6 +109,9 @@ class AugmentImageComponents(MapData): ...@@ -109,6 +109,9 @@ class AugmentImageComponents(MapData):
to keep the original images not modified. to keep the original images not modified.
Turn it off to save time when you know it's OK. Turn it off to save time when you know it's OK.
""" """
if isinstance(augmentors, AugmentorList):
self.augs = augmentors
else:
self.augs = AugmentorList(augmentors) self.augs = AugmentorList(augmentors)
self.ds = ds self.ds = ds
self._nr_error = 0 self._nr_error = 0
......
...@@ -38,6 +38,14 @@ class RandomApplyAug(ImageAugmentor): ...@@ -38,6 +38,14 @@ class RandomApplyAug(ImageAugmentor):
else: else:
return (False, None) return (False, None)
def _augment_return_params(self, img):
p = self.rng.rand()
if p < self.prob:
img, prms = self.aug._augment_return_params(img)
return img, (True, prms)
else:
return img, (False, None)
def reset_state(self): def reset_state(self):
super(RandomApplyAug, self).reset_state() super(RandomApplyAug, self).reset_state()
self.aug.reset_state() self.aug.reset_state()
......
...@@ -31,7 +31,7 @@ def FullyConnected(x, out_dim, ...@@ -31,7 +31,7 @@ def FullyConnected(x, out_dim,
Variable Names: Variable Names:
* ``W``: weights * ``W``: weights of shape [in_dim, out_dim]
* ``b``: bias * ``b``: bias
""" """
x = symbf.batch_flatten(x) x = symbf.batch_flatten(x)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment