Commit 55f2f5da authored by Yuxin Wu

trainer.get_predictor support tower-tensor as input

parent 6640f9bb
...@@ -49,12 +49,12 @@ class Model(GANModelDesc): ...@@ -49,12 +49,12 @@ class Model(GANModelDesc):
l = tf.reshape(l, [-1, 7, 7, 128]) l = tf.reshape(l, [-1, 7, 7, 128])
l = Deconv2D('deconv1', l, [14, 14, 64], 4, 2, nl=BNReLU) l = Deconv2D('deconv1', l, [14, 14, 64], 4, 2, nl=BNReLU)
l = Deconv2D('deconv2', l, [28, 28, 1], 4, 2, nl=tf.identity) l = Deconv2D('deconv2', l, [28, 28, 1], 4, 2, nl=tf.identity)
l = tf.tanh(l, name='gen') l = tf.sigmoid(l, name='gen')
return l return l
def discriminator(self, imgs): def discriminator(self, imgs):
with argscope(Conv2D, nl=tf.identity, kernel_shape=4, stride=2), \ with argscope(Conv2D, nl=tf.identity, kernel_shape=4, stride=2), \
argscope(LeakyReLU, alpha=0.1): argscope(LeakyReLU, alpha=0.2):
l = (LinearWrap(imgs) l = (LinearWrap(imgs)
.Conv2D('conv0', 64) .Conv2D('conv0', 64)
.LeakyReLU() .LeakyReLU()
...@@ -72,7 +72,7 @@ class Model(GANModelDesc): ...@@ -72,7 +72,7 @@ class Model(GANModelDesc):
def _build_graph(self, inputs): def _build_graph(self, inputs):
real_sample = inputs[0] real_sample = inputs[0]
real_sample = tf.expand_dims(real_sample * 2.0 - 1, -1) real_sample = tf.expand_dims(real_sample, -1)
# latent space is cat(10) x uni(1) x uni(1) x noise(NOISE_DIM) # latent space is cat(10) x uni(1) x uni(1) x noise(NOISE_DIM)
self.factors = ProductDistribution("factors", [CategoricalDistribution("cat", 10), self.factors = ProductDistribution("factors", [CategoricalDistribution("cat", 10),
...@@ -93,7 +93,7 @@ class Model(GANModelDesc): ...@@ -93,7 +93,7 @@ class Model(GANModelDesc):
W_init=tf.truncated_normal_initializer(stddev=0.02)): W_init=tf.truncated_normal_initializer(stddev=0.02)):
with tf.variable_scope('gen'): with tf.variable_scope('gen'):
fake_sample = self.generator(z) fake_sample = self.generator(z)
fake_sample_viz = tf.cast((fake_sample + 1) * 128.0, tf.uint8, name='viz') fake_sample_viz = tf.cast((fake_sample) * 255.0, tf.uint8, name='viz')
tf.summary.image('gen', fake_sample_viz, max_outputs=30) tf.summary.image('gen', fake_sample_viz, max_outputs=30)
# may need to investigate how bn stats should be updated across two discrim # may need to investigate how bn stats should be updated across two discrim
...@@ -164,7 +164,7 @@ def get_config(): ...@@ -164,7 +164,7 @@ def get_config():
logger.auto_set_dir() logger.auto_set_dir()
return TrainConfig( return TrainConfig(
dataflow=get_data(), dataflow=get_data(),
callbacks=[ModelSaver()], callbacks=[ModelSaver(keep_freq=0.1)],
session_config=get_default_sess_config(0.5), session_config=get_default_sess_config(0.5),
model=Model(), model=Model(),
steps_per_epoch=500, steps_per_epoch=500,
......
...@@ -41,6 +41,8 @@ class StartProcOrThread(Callback): ...@@ -41,6 +41,8 @@ class StartProcOrThread(Callback):
if not self._stop_at_last: if not self._stop_at_last:
return return
for k in self._procs_threads: for k in self._procs_threads:
if not k.is_alive():
continue
if isinstance(k, mp.Process): if isinstance(k, mp.Process):
logger.info("Stopping {} ...".format(k.name)) logger.info("Stopping {} ...".format(k.name))
k.terminate() k.terminate()
......
...@@ -7,7 +7,7 @@ import tensorflow as tf ...@@ -7,7 +7,7 @@ import tensorflow as tf
from .base import Trainer from .base import Trainer
from ..utils import SUMMARY_BACKUP_KEYS, PREDICT_TOWER from ..utils import SUMMARY_BACKUP_KEYS, PREDICT_TOWER
from ..tfutils import get_tensors_by_names, TowerContext from ..tfutils import get_tensors_by_names, TowerContext, get_op_tensor_name
from ..tfutils.collection import freeze_collection from ..tfutils.collection import freeze_collection
from ..predict import OnlinePredictor, build_prediction_graph from ..predict import OnlinePredictor, build_prediction_graph
from .input_data import FeedInput from .input_data import FeedInput
...@@ -35,8 +35,22 @@ class PredictorFactory(object): ...@@ -35,8 +35,22 @@ class PredictorFactory(object):
if not self.tower_built: if not self.tower_built:
self._build_predict_tower() self._build_predict_tower()
tower = self.towers[tower % len(self.towers)] tower = self.towers[tower % len(self.towers)]
placeholder_names = set([k.name for k in self.model.get_inputs_desc()])
def get_name_in_tower(name):
return PREDICT_TOWER + str(tower) + '/' + name
def maybe_inside_tower(name):
name = get_op_tensor_name(name)[0]
if name in placeholder_names:
return name
else:
return get_name_in_tower(name)
input_names = map(maybe_inside_tower, input_names)
raw_input_vars = get_tensors_by_names(input_names) raw_input_vars = get_tensors_by_names(input_names)
output_names = ['{}{}/'.format(PREDICT_TOWER, tower) + n for n in output_names]
output_names = map(get_name_in_tower, output_names)
output_vars = get_tensors_by_names(output_names) output_vars = get_tensors_by_names(output_names)
return OnlinePredictor(self.sess, raw_input_vars, output_vars) return OnlinePredictor(self.sess, raw_input_vars, output_vars)
......
...@@ -42,7 +42,7 @@ def timed_operation(msg, log_start=False): ...@@ -42,7 +42,7 @@ def timed_operation(msg, log_start=False):
logger.info('Start {} ...'.format(msg)) logger.info('Start {} ...'.format(msg))
start = time.time() start = time.time()
yield yield
logger.info('{} finished, time:{:.2f}sec.'.format( logger.info('{} finished, time:{:.4f}sec.'.format(
msg, time.time() - start)) msg, time.time() - start))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment