Commit 224025e3 authored by Yuxin Wu

add infogan models, and some small changes.

parent 55f2f5da
......@@ -19,7 +19,7 @@ Claimed performance in the paper can be reproduced, on several games I've tested
![DQN](curve-breakout.png)
DQN typically took 2 days of training to reach a score of 400 on breakout game (same as the paper).
DQN typically took 1.5 days of training to reach a score of 400 on breakout game (same as the paper).
My Batch-A3C implementation only took <2 hours.
Both were trained on one GPU with an extra GPU for simulation.
......
......@@ -102,7 +102,7 @@ class ExpReplay(DataFlow, Callback):
# self.mem.append(deepcopy(self.mem[0]))
# return
old_s = self.player.current_state()
if self.rng.rand() <= self.exploration:
if self.rng.rand() <= self.exploration or len(self.mem) < 5:
act = self.rng.choice(range(self.num_actions))
else:
# build a history state
......
......@@ -21,6 +21,8 @@ To train:
To visualize:
./ConditionalGAN-mnist.py --sample --load path/to/model
A pretrained model is at https://drive.google.com/open?id=0B9IPQTvr2BBkLUF2M0RXU1NYSkE
"""
BATCH = 128
......
......@@ -23,6 +23,8 @@ from GAN import GANTrainer, RandomZData, GANModelDesc
You can also train on other images (just use any directory of jpg files in
`--data`). But you may need to change the preprocessing steps in `get_data()`.
A pretrained model on CelebA is at https://drive.google.com/open?id=0B9IPQTvr2BBkLUF2M0RXU1NYSkE
"""
SHAPE = 64
......
......@@ -23,6 +23,8 @@ To train:
To visualize:
./InfoGAN-mnist.py --sample --load path/to/model
A pretrained model is at https://drive.google.com/open?id=0B9IPQTvr2BBkLUF2M0RXU1NYSkE
"""
BATCH = 128
......
......@@ -12,14 +12,12 @@ Reproduce the following GAN-related methods:
+ [Wasserstein GAN](https://arxiv.org/abs/1701.07875)
Please see the __docstring__ in each script for detailed usage.
Please see the __docstring__ in each script for detailed usage and pretrained models.
## DCGAN-CelebA.py
Reproduce DCGAN following the setup in [dcgan.torch](https://github.com/soumith/dcgan.torch).
Play with the [pretrained model](https://drive.google.com/drive/folders/0B9IPQTvr2BBkLUF2M0RXU1NYSkE?usp=sharing) on CelebA face dataset:
+ Generated samples
![sample](demo/CelebA-samples.jpg)
......
......@@ -73,7 +73,7 @@ class ResizeShortestEdge(ImageAugmentor):
keeping the aspect ratio.
"""
def __init__(self, size):
def __init__(self, size, interp=cv2.INTER_LINEAR):
"""
Args:
size (int): the size to resize the shortest edge to.
......@@ -85,7 +85,7 @@ class ResizeShortestEdge(ImageAugmentor):
h, w = img.shape[:2]
scale = self.size / min(h, w)
desSize = map(int, [scale * w, scale * h])
ret = cv2.resize(img, tuple(desSize), interpolation=cv2.INTER_LINEAR)
ret = cv2.resize(img, tuple(desSize), interpolation=self.interp)
if img.ndim == 3 and ret.ndim == 2:
ret = ret[:, :, np.newaxis]
return ret
......
......@@ -241,11 +241,11 @@ def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
if use_local_stat:
xn, batch_mean, batch_var = tf.nn.fused_batch_norm(x, gamma, beta,
epsilon=epsilon, is_training=True)
moving_sigma = tf.sqrt(moving_var, 'sigma')
inv_sigma = tf.rsqrt(moving_var, 'inv_sigma')
r = tf.stop_gradient(tf.clip_by_value(
tf.sqrt(batch_var / moving_var), 1.0 / rmax, rmax))
tf.sqrt(batch_var) * inv_sigma, 1.0 / rmax, rmax))
d = tf.stop_gradient(tf.clip_by_value(
(batch_mean - moving_mean) / moving_sigma,
(batch_mean - moving_mean) * inv_sigma,
-dmax, dmax))
xn = xn * r + d
else:
......
......@@ -77,7 +77,7 @@ def apply_slim_collections(cost):
if ctx is not None and ctx.is_main_training_tower:
non_grad_updates = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
if non_grad_updates:
logger.info("Applying UPDATE_OPS collection on cost.")
logger.info("Applying UPDATE_OPS collection from the first tower on cost.")
with tf.control_dependencies(non_grad_updates):
cost = tf.identity(cost, name='cost_with_update')
return cost
......@@ -40,6 +40,7 @@ class PredictorFactory(object):
def get_name_in_tower(name):
return PREDICT_TOWER + str(tower) + '/' + name
def maybe_inside_tower(name):
name = get_op_tensor_name(name)[0]
if name in placeholder_names:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment