Commit b95ea88f authored by Yuxin Wu

deconv2d and leakyrelu

parent 77bcc8b1
......@@ -91,13 +91,14 @@ def Deconv2D(x, out_shape, kernel_shape,
stride2d = shape2d(stride)
stride4d = shape4d(stride)
padding = padding.upper()
filter_shape = kernel_shape + [in_channel, out_channel]
if isinstance(out_shape, int):
out_shape = tf.pack([tf.shape(x)[0],
stride2d[0] * in_shape[0], stride2d[1] * in_shape[1], out_shape])
out_channel = out_shape
shape3 = [stride2d[0] * in_shape[0], stride2d[1] * in_shape[1], out_shape]
else:
out_shape = tf.pack([tf.shape(x)[0]] + out_shape)
out_channel = out_shape[-1]
shape3 = out_shape
filter_shape = kernel_shape + [out_channel, in_channel]
if W_init is None:
W_init = tf.contrib.layers.xavier_initializer_conv2d()
......@@ -107,5 +108,7 @@ def Deconv2D(x, out_shape, kernel_shape,
if use_bias:
b = tf.get_variable('b', [out_channel], initializer=b_init)
out_shape = tf.pack([tf.shape(x)[0]] + shape3)
conv = tf.nn.conv2d_transpose(x, W, out_shape, stride4d, padding=padding)
conv.set_shape(tf.TensorShape([None] + shape3))
return nl(tf.nn.bias_add(conv, b) if use_bias else conv, name='output')
......@@ -10,7 +10,7 @@ from collections import namedtuple
import inspect
from ..utils import logger, INPUT_VARS_KEY
from ..tfutils.common import get_vars_by_names
from ..tfutils.common import get_tensors_by_names
from ..tfutils.gradproc import CheckGradient
from ..tfutils.tower import get_current_tower_context
......@@ -43,7 +43,7 @@ class ModelDesc(object):
def reuse_input_vars(self):
""" Find and return already-defined input_vars in default graph"""
input_var_names = [k.name for k in self._get_input_vars()]
return get_vars_by_names(input_var_names)
return get_tensors_by_names(input_var_names)
def get_input_vars_desc(self):
""" return a list of `InputVar` instance"""
......
......@@ -57,11 +57,12 @@ def LeakyReLU(x, alpha, name=None):
:param input: any tensor.
:param alpha: the negative slope.
"""
alpha = float(alpha)
x = ((1 + alpha) * x + (1 - alpha) * tf.abs(x))
if name is None:
name = 'output'
return tf.mul(x, 0.5, name=name)
return tf.maximum(x, alpha * x, name=name)
#alpha = float(alpha)
#x = ((1 + alpha) * x + (1 - alpha) * tf.abs(x))
#return tf.mul(x, 0.5, name=name)
def BNReLU(x, name=None):
x = BatchNorm('bn', x, use_local_stat=None)
......
......@@ -103,6 +103,7 @@ def is_training_name(name):
:returns: guess whether this tensor is something only used in training.
"""
# TODO: maybe simply check against TRAINABLE_VARIABLES and EXTRA_SAVE_VARS_KEY ?
# TODO or use get_slot_names()
name = get_op_tensor_name(name)[0]
if name.endswith('/Adam') or name.endswith('/Adam_1'):
return True
......
......@@ -143,6 +143,8 @@ class Trainer(object):
self.trigger_epoch()
except StopTraining:
logger.info("Training was stopped.")
except:
raise
finally:
callbacks.after_train()
self.coord.request_stop()
......
......@@ -52,15 +52,15 @@ class EnqueueThread(threading.Thread):
except Exception:
logger.exception("Exception in EnqueueThread:")
finally:
self.coord.request_stop()
try:
self.sess.run(self.close_op)
except RuntimeError: # session already closed
pass
self.coord.request_stop()
logger.info("Enqueue Thread Exited.")
class QueueInputTrainerBase(FeedlessTrainer):
def _build_enque_thread(self, input_queue):
def _build_enque_thread(self, input_queue=None):
""" create a thread that keeps filling the queue """
self.input_vars = self.model.get_input_vars()
if input_queue is None:
......
......@@ -11,7 +11,7 @@ from .base import Trainer
from ..dataflow.common import RepeatedData
from ..utils import logger, SUMMARY_BACKUP_KEYS
from ..tfutils import (get_vars_by_names, freeze_collection,
from ..tfutils import (get_tensors_by_names, freeze_collection,
get_global_step_var, TowerContext)
from ..tfutils.summary import summary_moving_average, add_moving_summary
from ..predict import OnlinePredictor, build_multi_tower_prediction_graph
......@@ -40,9 +40,9 @@ class PredictorFactory(object):
if not self.tower_built:
self._build_predict_tower()
tower = self.towers[tower % len(self.towers)]
raw_input_vars = get_vars_by_names(input_names)
raw_input_vars = get_tensors_by_names(input_names)
output_names = ['towerp{}/'.format(tower) + n for n in output_names]
output_vars = get_vars_by_names(output_names)
output_vars = get_tensors_by_names(output_names)
return OnlinePredictor(self.sess, raw_input_vars, output_vars)
def _build_predict_tower(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment