Commit d10cb1af authored by Yuxin Wu's avatar Yuxin Wu

small updates

parent 34a5a809
...@@ -117,8 +117,10 @@ def sample(model_path): ...@@ -117,8 +117,10 @@ def sample(model_path):
input_names=['zc'], input_names=['zc'],
output_names=['gen/gen'])) output_names=['gen/gen']))
eye = [k for k in np.eye(10)] eye = []
inputs = np.asarray(eye * 10) for k in np.eye(10):
eye = eye + [k] * 10
inputs = np.asarray(eye)
while True: while True:
o = pred([inputs]) o = pred([inputs])
o = (o[0] + 1) * 128.0 o = (o[0] + 1) * 128.0
......
...@@ -33,4 +33,4 @@ It requires the datasets released by the original authors. ...@@ -33,4 +33,4 @@ It requires the datasets released by the original authors.
Reproduce a mnist experiement in InfoGAN. Reproduce a mnist experiement in InfoGAN.
By assuming 10 latent variables corresponding to a categorical distribution and maximizing mutual information, By assuming 10 latent variables corresponding to a categorical distribution and maximizing mutual information,
the GAN learns to map the 10 variables to 10 digits in an unsupervised fashion. the network unsupervisedly learns to map the 10 variables to 10 digits.
### code and models for my Gym submissions on Atari games ### Code and models for my Gym submissions on Atari games
Use A3C in [Asynchronous Methods for Deep Reinforcement Learning](http://arxiv.org/abs/1602.01783). Implemented A3C in [Asynchronous Methods for Deep Reinforcement Learning](http://arxiv.org/abs/1602.01783).
### To train on an Atari game: ### To train on an Atari game:
......
...@@ -146,6 +146,7 @@ class SimulatorMaster(threading.Thread): ...@@ -146,6 +146,7 @@ class SimulatorMaster(threading.Thread):
while True: while True:
msg = loads(self.c2s_socket.recv(copy=False).bytes) msg = loads(self.c2s_socket.recv(copy=False).bytes)
ident, state, reward, isOver = msg ident, state, reward, isOver = msg
# TODO check history and warn about dead client
client = self.clients[ident] client = self.clients[ident]
# check if reward&isOver is valid # check if reward&isOver is valid
......
...@@ -207,7 +207,7 @@ class HyperParamSetterWithFunc(HyperParamSetter): ...@@ -207,7 +207,7 @@ class HyperParamSetterWithFunc(HyperParamSetter):
"""Set hyperparameter by a func """Set hyperparameter by a func
new_value = f(epoch_num, old_value) new_value = f(epoch_num, old_value)
""" """
super(StatMonitorParamSetter, self).__init__(param) super(HyperParamSetterWithFunc, self).__init__(param)
self.f = func self.f = func
def _get_value_to_set(self): def _get_value_to_set(self):
......
...@@ -10,7 +10,7 @@ from six.moves import range ...@@ -10,7 +10,7 @@ from six.moves import range
import numpy as np import numpy as np
__all__ = ['RandomCrop', 'CenterCrop', 'FixedCrop', __all__ = ['RandomCrop', 'CenterCrop', 'FixedCrop',
'RandomCropRandomShape', 'perturb_BB'] 'RandomCropRandomShape', 'perturb_BB', 'RandomCropAroundBox']
class RandomCrop(ImageAugmentor): class RandomCrop(ImageAugmentor):
""" Randomly crop the image into a smaller one """ """ Randomly crop the image into a smaller one """
...@@ -109,7 +109,7 @@ def perturb_BB(image_shape, bb, max_pertub_pixel, ...@@ -109,7 +109,7 @@ def perturb_BB(image_shape, bb, max_pertub_pixel,
return bb return bb
class RandomCropRandomShape(ImageAugmentor): class RandomCropAroundBox(ImageAugmentor):
""" """
Crop a box around a bounding box Crop a box around a bounding box
""" """
...@@ -118,7 +118,7 @@ class RandomCropRandomShape(ImageAugmentor): ...@@ -118,7 +118,7 @@ class RandomCropRandomShape(ImageAugmentor):
:param perturb_ratio: perturb distance will be in [0, perturb_ratio * sqrt(w * h)] :param perturb_ratio: perturb distance will be in [0, perturb_ratio * sqrt(w * h)]
:param max_aspect_ratio_diff: keep aspect ratio within the range :param max_aspect_ratio_diff: keep aspect ratio within the range
""" """
super(RandomCropRandomShape, self).__init__() super(RandomCropAroundBox, self).__init__()
self._init(locals()) self._init(locals())
def _get_augment_params(self, img): def _get_augment_params(self, img):
...@@ -135,5 +135,33 @@ class RandomCropRandomShape(ImageAugmentor): ...@@ -135,5 +135,33 @@ class RandomCropRandomShape(ImageAugmentor):
def _fprop_coord(self, coord, param): def _fprop_coord(self, coord, param):
raise NotImplementedError() raise NotImplementedError()
class RandomCropRandomShape(ImageAugmentor):
def __init__(self, wmin, hmin,
wmax=None, hmax=None,
max_aspect_ratio=None):
"""
Randomly crop a box of shape (h, w), sampled from [min, max](inclusive).
If max is None, will use the input image shape.
max_aspect_ratio is the upper bound of max(w,h)/min(w,h)
"""
if max_aspect_ratio is None:
max_aspect_ratio = 9999999
self._init(locals())
def _get_augment_params(self, img):
hmax = self.hmax or img.shape[0]
wmax = self.wmax or img.shape[1]
h = self.rng.randint(self.hmin, hmax+1)
w = self.rng.randint(self.wmin, wmax+1)
diffh = img.shape[0] - h
y0 = 0 if diffh == 0 else self.rng.randint(diffh)
diffw = img.shape[1] - w
x0 = 0 if diffw == 0 else self.rng.randint(diffw)
return (y0,x0,h,w)
def _augment(self, img, param):
y0, x0, h, w = param
return img[y0:y0+h,x0:x0+w]
if __name__ == '__main__': if __name__ == '__main__':
print(perturb_BB([100, 100], Rect(3, 3, 50, 50), 50)) print(perturb_BB([100, 100], Rect(3, 3, 50, 50), 50))
...@@ -91,18 +91,22 @@ class AsyncMultiGPUTrainer(QueueInputTrainerBase, ...@@ -91,18 +91,22 @@ class AsyncMultiGPUTrainer(QueueInputTrainerBase,
MultiGPUTrainer, MultiGPUTrainer,
SingleCostFeedlessTrainer, SingleCostFeedlessTrainer,
MultiPredictorTowerTrainer): MultiPredictorTowerTrainer):
def __init__(self, config, input_queue=None, predict_tower=None): def __init__(self, config,
input_queue=None,
predict_tower=None,
average_gradient=True):
super(AsyncMultiGPUTrainer, self).__init__(config) super(AsyncMultiGPUTrainer, self).__init__(config)
self._setup_predictor_factory(predict_tower) self._setup_predictor_factory(predict_tower)
self._build_enque_thread(input_queue) self._build_enque_thread(input_queue)
self.average_gradient = average_gradient
def _setup(self): def _setup(self):
grad_list = MultiGPUTrainer._multi_tower_grads( grad_list = MultiGPUTrainer._multi_tower_grads(
self.config.tower, lambda: self._get_cost_and_grad()[1]) self.config.tower, lambda: self._get_cost_and_grad()[1])
gradprocs = self.model.get_gradient_processor() gradprocs = self.model.get_gradient_processor()
# pretend to average the grads, in order to make async and if self.average_gradient and self.config.nr_tower > 1:
# sync have consistent effective learning rate # pretend to average the grads, in order to make async and
if self.config.nr_tower > 1: # sync have consistent effective learning rate
gradprocs.insert(0, ScaleGradient(('.*', 1.0 / self.config.nr_tower), log=False)) gradprocs.insert(0, ScaleGradient(('.*', 1.0 / self.config.nr_tower), log=False))
grad_list = [apply_grad_processors(g, gradprocs) for g in grad_list] grad_list = [apply_grad_processors(g, gradprocs) for g in grad_list]
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: debug.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import sys
__all__ = ['enable_call_trace']
def enable_call_trace():
def tracer(frame, event, arg):
if event == 'call':
co = frame.f_code
func_name = co.co_name
if func_name == 'write' or func_name == 'print':
# ignore write() calls from print statements
return
func_line_no = frame.f_lineno
func_filename = co.co_filename
caller = frame.f_back
if caller:
caller_line_no = caller.f_lineno
caller_filename = caller.f_code.co_filename
print 'Call to `%s` on line %s:%s from %s:%s' % \
(func_name, func_filename, func_line_no,
caller_filename, caller_line_no)
return
sys.settrace(tracer)
if __name__ == '__main__':
enable_call_trace()
def b(a):
print 2
def a():
print 1
b(1)
a()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment