Commit b95ea88f authored by Yuxin Wu

deconv2d and leakyrelu

parent 77bcc8b1
@@ -91,13 +91,14 @@ def Deconv2D(x, out_shape, kernel_shape,
     stride2d = shape2d(stride)
     stride4d = shape4d(stride)
     padding = padding.upper()
-    filter_shape = kernel_shape + [in_channel, out_channel]
 
     if isinstance(out_shape, int):
-        out_shape = tf.pack([tf.shape(x)[0],
-            stride2d[0] * in_shape[0], stride2d[1] * in_shape[1], out_shape])
+        out_channel = out_shape
+        shape3 = [stride2d[0] * in_shape[0], stride2d[1] * in_shape[1], out_shape]
     else:
-        out_shape = tf.pack([tf.shape(x)[0]] + out_shape)
+        out_channel = out_shape[-1]
+        shape3 = out_shape
+    filter_shape = kernel_shape + [out_channel, in_channel]
 
     if W_init is None:
         W_init = tf.contrib.layers.xavier_initializer_conv2d()
@@ -107,5 +108,7 @@ def Deconv2D(x, out_shape, kernel_shape,
     if use_bias:
         b = tf.get_variable('b', [out_channel], initializer=b_init)
 
+    out_shape = tf.pack([tf.shape(x)[0]] + shape3)
     conv = tf.nn.conv2d_transpose(x, W, out_shape, stride4d, padding=padding)
+    conv.set_shape(tf.TensorShape([None] + shape3))
     return nl(tf.nn.bias_add(conv, b) if use_bias else conv, name='output')
...
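For context on what this hunk fixes: tf.nn.conv2d_transpose expects its filter in [height, width, out_channel, in_channel] order, the reverse of conv2d's [height, width, in_channel, out_channel], and it loses static shape information when output_shape is passed as a dynamic tensor. A minimal standalone sketch of the corrected logic, using tf.stack (the later name of tf.pack) and illustrative shapes:

    import tensorflow as tf

    x = tf.placeholder(tf.float32, [None, 16, 16, 8])    # NHWC input, in_channel = 8
    in_shape = [16, 16]
    out_channel = 32
    stride2d = [2, 2]
    shape3 = [stride2d[0] * in_shape[0], stride2d[1] * in_shape[1], out_channel]

    # conv2d_transpose takes the filter as [h, w, out_channel, in_channel],
    # the reverse of conv2d's [h, w, in_channel, out_channel] -- the ordering
    # this commit fixes.
    W = tf.get_variable('W', [4, 4, out_channel, 8])

    out_shape = tf.stack([tf.shape(x)[0]] + shape3)      # dynamic batch size
    conv = tf.nn.conv2d_transpose(x, W, out_shape, [1, 2, 2, 1], padding='SAME')
    print(conv.get_shape())            # (?, ?, ?, ?) -- static shape was lost
    conv.set_shape([None] + shape3)    # restore it, as the commit does
    print(conv.get_shape())            # (?, 32, 32, 32)

Without the set_shape call, every downstream layer would see an all-unknown shape, because the dynamic out_shape tensor carries no static information.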
@@ -10,7 +10,7 @@ from collections import namedtuple
 import inspect
 
 from ..utils import logger, INPUT_VARS_KEY
-from ..tfutils.common import get_vars_by_names
+from ..tfutils.common import get_tensors_by_names
 from ..tfutils.gradproc import CheckGradient
 from ..tfutils.tower import get_current_tower_context
@@ -43,7 +43,7 @@ class ModelDesc(object):
     def reuse_input_vars(self):
         """ Find and return already-defined input_vars in default graph"""
         input_var_names = [k.name for k in self._get_input_vars()]
-        return get_vars_by_names(input_var_names)
+        return get_tensors_by_names(input_var_names)
 
     def get_input_vars_desc(self):
         """ return a list of `InputVar` instance"""
...
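The rename get_vars_by_names -> get_tensors_by_names reflects that the helper looks up graph tensors by name, not tf.Variable objects. The real implementation lives in ..tfutils.common and is not shown in this diff; a plausible sketch of what such a helper does (my reconstruction, not the library's code):

    import tensorflow as tf

    def get_tensors_by_names(names):
        # Look each name up in the default graph. A bare op name like
        # 'input' refers to the op's first output tensor 'input:0'.
        g = tf.get_default_graph()
        ret = []
        for n in names:
            if ':' not in n:
                n = n + ':0'
            ret.append(g.get_tensor_by_name(n))
        return ret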
@@ -57,11 +57,12 @@ def LeakyReLU(x, alpha, name=None):
     :param input: any tensor.
     :param alpha: the negative slope.
     """
-    alpha = float(alpha)
-    x = ((1 + alpha) * x + (1 - alpha) * tf.abs(x))
     if name is None:
         name = 'output'
-    return tf.mul(x, 0.5, name=name)
+    return tf.maximum(x, alpha * x, name=name)
+    #alpha = float(alpha)
+    #x = ((1 + alpha) * x + (1 - alpha) * tf.abs(x))
+    #return tf.mul(x, 0.5, name=name)
 
 def BNReLU(x, name=None):
     x = BatchNorm('bn', x, use_local_stat=None)
...
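The commented-out lines and the new tf.maximum form compute the same function: for 0 <= alpha <= 1, max(x, alpha*x) = ((1+alpha)*x + (1-alpha)*|x|) / 2 elementwise (check x >= 0, where both sides give x, and x < 0, where both give alpha*x). A quick NumPy sanity check:

    import numpy as np

    x = np.random.randn(1000).astype('float32')
    alpha = 0.2
    old = 0.5 * ((1 + alpha) * x + (1 - alpha) * np.abs(x))   # abs-based form
    new = np.maximum(x, alpha * x)                            # tf.maximum form
    assert np.allclose(old, new)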
@@ -103,6 +103,7 @@ def is_training_name(name):
     :returns: guess whether this tensor is something only used in training.
     """
     # TODO: maybe simply check against TRAINABLE_VARIABLES and EXTRA_SAVE_VARS_KEY ?
+    # TODO or use get_slot_names()
     name = get_op_tensor_name(name)[0]
     if name.endswith('/Adam') or name.endswith('/Adam_1'):
         return True
...
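Regarding the new get_slot_names() TODO: the hard-coded '/Adam' and '/Adam_1' suffixes are the names of the slot variables AdamOptimizer creates per trainable variable. A hedged sketch of the suggested alternative, using the TF1 optimizer API:

    import tensorflow as tf

    w = tf.get_variable('w', [3])
    opt = tf.train.AdamOptimizer(1e-3)
    train_op = opt.minimize(tf.reduce_sum(w ** 2))

    print(opt.get_slot_names())            # ['m', 'v'] for Adam
    print(opt.get_slot(w, 'm').op.name)    # 'w/Adam'   -- the suffix matched above
    print(opt.get_slot(w, 'v').op.name)    # 'w/Adam_1'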
@@ -143,6 +143,8 @@ class Trainer(object):
                 self.trigger_epoch()
         except StopTraining:
             logger.info("Training was stopped.")
+        except:
+            raise
         finally:
             callbacks.after_train()
             self.coord.request_stop()
...
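Note that the added bare except/raise clause does not change behavior: finally already runs while an exception propagates, and a bare raise preserves the original traceback. It only documents that anything other than StopTraining should propagate after cleanup. A small self-contained demonstration:

    def run():
        try:
            raise ValueError('boom')
        except KeyboardInterrupt:
            print('stopped')
        except:
            raise                          # unchanged: the error still propagates
        finally:
            print('cleanup runs either way')

    try:
        run()
    except ValueError as e:
        print('caught:', e)                # caught: boom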
@@ -52,15 +52,15 @@ class EnqueueThread(threading.Thread):
         except Exception:
             logger.exception("Exception in EnqueueThread:")
         finally:
+            self.coord.request_stop()
             try:
                 self.sess.run(self.close_op)
             except RuntimeError:    # session already closed
                 pass
-            self.coord.request_stop()
             logger.info("Enqueue Thread Exited.")
 
 class QueueInputTrainerBase(FeedlessTrainer):
-    def _build_enque_thread(self, input_queue):
+    def _build_enque_thread(self, input_queue=None):
         """ create a thread that keeps filling the queue """
         self.input_vars = self.model.get_input_vars()
         if input_queue is None:
...
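The reordering in the finally block matters for shutdown: calling coord.request_stop() before closing the queue lets threads blocked on queue ops observe the stop signal before pending enqueues are cancelled. A sketch of the pattern; the names queue, placeholder, and enqueue_loop here are illustrative assumptions, not the trainer's actual attributes:

    import tensorflow as tf

    coord = tf.train.Coordinator()
    queue = tf.FIFOQueue(50, dtypes=[tf.float32])
    placeholder = tf.placeholder(tf.float32)
    enqueue_op = queue.enqueue([placeholder])
    close_op = queue.close(cancel_pending_enqueues=True)

    def enqueue_loop(sess):
        try:
            while not coord.should_stop():
                sess.run(enqueue_op, feed_dict={placeholder: 1.0})
        finally:
            coord.request_stop()      # signal all threads first...
            try:
                sess.run(close_op)    # ...then cancel pending enqueues
            except RuntimeError:      # session already closed
                pass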
@@ -11,7 +11,7 @@ from .base import Trainer
 from ..dataflow.common import RepeatedData
 from ..utils import logger, SUMMARY_BACKUP_KEYS
-from ..tfutils import (get_vars_by_names, freeze_collection,
+from ..tfutils import (get_tensors_by_names, freeze_collection,
         get_global_step_var, TowerContext)
 from ..tfutils.summary import summary_moving_average, add_moving_summary
 from ..predict import OnlinePredictor, build_multi_tower_prediction_graph
@@ -40,9 +40,9 @@ class PredictorFactory(object):
         if not self.tower_built:
             self._build_predict_tower()
         tower = self.towers[tower % len(self.towers)]
-        raw_input_vars = get_vars_by_names(input_names)
+        raw_input_vars = get_tensors_by_names(input_names)
         output_names = ['towerp{}/'.format(tower) + n for n in output_names]
-        output_vars = get_vars_by_names(output_names)
+        output_vars = get_tensors_by_names(output_names)
         return OnlinePredictor(self.sess, raw_input_vars, output_vars)
 
     def _build_predict_tower(self):
...
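For context on the prefixing above: prediction towers are built under name scopes like towerp0, so a user-facing output name such as output resolves to the graph tensor towerp0/output:0 in tower 0, while input placeholders live outside the tower scopes and need no prefix. A toy illustration:

    import tensorflow as tf

    # Ops defined inside the name scope pick up the 'towerp0/' prefix
    # automatically, which is what the name-based lookup relies on.
    with tf.name_scope('towerp0'):
        out = tf.identity(tf.constant(1.0), name='output')
    print(out.name)    # 'towerp0/output:0'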