Commit 224025e3 authored by Yuxin Wu

add infogan models, and some small changes.

parent 55f2f5da
...@@ -19,7 +19,7 @@ Claimed performance in the paper can be reproduced, on several games I've tested ...@@ -19,7 +19,7 @@ Claimed performance in the paper can be reproduced, on several games I've tested
![DQN](curve-breakout.png) ![DQN](curve-breakout.png)
DQN typically took 2 days of training to reach a score of 400 on breakout game (same as the paper). DQN typically took 1.5 days of training to reach a score of 400 on breakout game (same as the paper).
My Batch-A3C implementation only took <2 hours. My Batch-A3C implementation only took <2 hours.
Both were trained on one GPU with an extra GPU for simulation. Both were trained on one GPU with an extra GPU for simulation.
......
...@@ -102,7 +102,7 @@ class ExpReplay(DataFlow, Callback): ...@@ -102,7 +102,7 @@ class ExpReplay(DataFlow, Callback):
# self.mem.append(deepcopy(self.mem[0])) # self.mem.append(deepcopy(self.mem[0]))
# return # return
old_s = self.player.current_state() old_s = self.player.current_state()
if self.rng.rand() <= self.exploration: if self.rng.rand() <= self.exploration or len(self.mem) < 5:
act = self.rng.choice(range(self.num_actions)) act = self.rng.choice(range(self.num_actions))
else: else:
# build a history state # build a history state
......
...@@ -21,6 +21,8 @@ To train: ...@@ -21,6 +21,8 @@ To train:
To visualize: To visualize:
./ConditionalGAN-mnist.py --sample --load path/to/model ./ConditionalGAN-mnist.py --sample --load path/to/model
A pretrained model is at https://drive.google.com/open?id=0B9IPQTvr2BBkLUF2M0RXU1NYSkE
""" """
BATCH = 128 BATCH = 128
......
...@@ -23,6 +23,8 @@ from GAN import GANTrainer, RandomZData, GANModelDesc ...@@ -23,6 +23,8 @@ from GAN import GANTrainer, RandomZData, GANModelDesc
You can also train on other images (just use any directory of jpg files in You can also train on other images (just use any directory of jpg files in
`--data`). But you may need to change the preprocessing steps in `get_data()`. `--data`). But you may need to change the preprocessing steps in `get_data()`.
A pretrained model on CelebA is at https://drive.google.com/open?id=0B9IPQTvr2BBkLUF2M0RXU1NYSkE
""" """
SHAPE = 64 SHAPE = 64
......
...@@ -23,6 +23,8 @@ To train: ...@@ -23,6 +23,8 @@ To train:
To visualize: To visualize:
./InfoGAN-mnist.py --sample --load path/to/model ./InfoGAN-mnist.py --sample --load path/to/model
A pretrained model is at https://drive.google.com/open?id=0B9IPQTvr2BBkLUF2M0RXU1NYSkE
""" """
BATCH = 128 BATCH = 128
......
...@@ -12,14 +12,12 @@ Reproduce the following GAN-related methods: ...@@ -12,14 +12,12 @@ Reproduce the following GAN-related methods:
+ [Wasserstein GAN](https://arxiv.org/abs/1701.07875) + [Wasserstein GAN](https://arxiv.org/abs/1701.07875)
Please see the __docstring__ in each script for detailed usage. Please see the __docstring__ in each script for detailed usage and pretrained models.
## DCGAN-CelebA.py ## DCGAN-CelebA.py
Reproduce DCGAN following the setup in [dcgan.torch](https://github.com/soumith/dcgan.torch). Reproduce DCGAN following the setup in [dcgan.torch](https://github.com/soumith/dcgan.torch).
Play with the [pretrained model](https://drive.google.com/drive/folders/0B9IPQTvr2BBkLUF2M0RXU1NYSkE?usp=sharing) on CelebA face dataset:
+ Generated samples + Generated samples
![sample](demo/CelebA-samples.jpg) ![sample](demo/CelebA-samples.jpg)
......
...@@ -73,7 +73,7 @@ class ResizeShortestEdge(ImageAugmentor): ...@@ -73,7 +73,7 @@ class ResizeShortestEdge(ImageAugmentor):
keeping the aspect ratio. keeping the aspect ratio.
""" """
def __init__(self, size): def __init__(self, size, interp=cv2.INTER_LINEAR):
""" """
Args: Args:
size (int): the size to resize the shortest edge to. size (int): the size to resize the shortest edge to.
...@@ -85,7 +85,7 @@ class ResizeShortestEdge(ImageAugmentor): ...@@ -85,7 +85,7 @@ class ResizeShortestEdge(ImageAugmentor):
h, w = img.shape[:2] h, w = img.shape[:2]
scale = self.size / min(h, w) scale = self.size / min(h, w)
desSize = map(int, [scale * w, scale * h]) desSize = map(int, [scale * w, scale * h])
ret = cv2.resize(img, tuple(desSize), interpolation=cv2.INTER_LINEAR) ret = cv2.resize(img, tuple(desSize), interpolation=self.interp)
if img.ndim == 3 and ret.ndim == 2: if img.ndim == 3 and ret.ndim == 2:
ret = ret[:, :, np.newaxis] ret = ret[:, :, np.newaxis]
return ret return ret
......
...@@ -241,11 +241,11 @@ def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5, ...@@ -241,11 +241,11 @@ def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
if use_local_stat: if use_local_stat:
xn, batch_mean, batch_var = tf.nn.fused_batch_norm(x, gamma, beta, xn, batch_mean, batch_var = tf.nn.fused_batch_norm(x, gamma, beta,
epsilon=epsilon, is_training=True) epsilon=epsilon, is_training=True)
moving_sigma = tf.sqrt(moving_var, 'sigma') inv_sigma = tf.rsqrt(moving_var, 'inv_sigma')
r = tf.stop_gradient(tf.clip_by_value( r = tf.stop_gradient(tf.clip_by_value(
tf.sqrt(batch_var / moving_var), 1.0 / rmax, rmax)) tf.sqrt(batch_var) * inv_sigma, 1.0 / rmax, rmax))
d = tf.stop_gradient(tf.clip_by_value( d = tf.stop_gradient(tf.clip_by_value(
(batch_mean - moving_mean) / moving_sigma, (batch_mean - moving_mean) * inv_sigma,
-dmax, dmax)) -dmax, dmax))
xn = xn * r + d xn = xn * r + d
else: else:
......
...@@ -77,7 +77,7 @@ def apply_slim_collections(cost): ...@@ -77,7 +77,7 @@ def apply_slim_collections(cost):
if ctx is not None and ctx.is_main_training_tower: if ctx is not None and ctx.is_main_training_tower:
non_grad_updates = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) non_grad_updates = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
if non_grad_updates: if non_grad_updates:
logger.info("Applying UPDATE_OPS collection on cost.") logger.info("Applying UPDATE_OPS collection from the first tower on cost.")
with tf.control_dependencies(non_grad_updates): with tf.control_dependencies(non_grad_updates):
cost = tf.identity(cost, name='cost_with_update') cost = tf.identity(cost, name='cost_with_update')
return cost return cost
...@@ -40,6 +40,7 @@ class PredictorFactory(object): ...@@ -40,6 +40,7 @@ class PredictorFactory(object):
def get_name_in_tower(name): def get_name_in_tower(name):
return PREDICT_TOWER + str(tower) + '/' + name return PREDICT_TOWER + str(tower) + '/' + name
def maybe_inside_tower(name): def maybe_inside_tower(name):
name = get_op_tensor_name(name)[0] name = get_op_tensor_name(name)[0]
if name in placeholder_names: if name in placeholder_names:
......
Markdown is supported
0% Try again or attach a new file
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment