Commit 3fa1a499 authored by Yuxin Wu's avatar Yuxin Wu

support sparse placeholder

parent e3e21c61
...@@ -115,7 +115,7 @@ class Model(ModelDesc): ...@@ -115,7 +115,7 @@ class Model(ModelDesc):
output = tf.image.grayscale_to_rgb(output) output = tf.image.grayscale_to_rgb(output)
fake_output = tf.image.grayscale_to_rgb(fake_output) fake_output = tf.image.grayscale_to_rgb(fake_output)
viz = (tf.concat(2, [input, output, fake_output]) + 1.0) * 128.0 viz = (tf.concat(2, [input, output, fake_output]) + 1.0) * 128.0
viz = tf.cast(viz, tf.uint8, name='viz') viz = tf.cast(tf.clip_by_value(viz, 0, 255), tf.uint8, name='viz')
tf.image_summary('gen', viz, max_images=max(30, BATCH)) tf.image_summary('gen', viz, max_images=max(30, BATCH))
all_vars = tf.trainable_variables() all_vars = tf.trainable_variables()
......
...@@ -17,8 +17,12 @@ from ..tfutils.tower import get_current_tower_context ...@@ -17,8 +17,12 @@ from ..tfutils.tower import get_current_tower_context
__all__ = ['ModelDesc', 'InputVar', 'ModelFromMetaGraph' ] __all__ = ['ModelDesc', 'InputVar', 'ModelFromMetaGraph' ]
_InputVar = namedtuple('InputVar', ['type', 'shape', 'name']) _InputVar = namedtuple('InputVar', ['type', 'shape', 'name', 'sparse'])
class InputVar(_InputVar): class InputVar(_InputVar):
def __init__(self, type, shape, name, sparse=False):
super(InputVar, self).__init__(type, shape, name, sparse)
def __new__(cls, type, shape, name, sparse=False):
return super(InputVar, cls).__new__(cls, type, shape, name, sparse)
def dumps(self): def dumps(self):
return pickle.dumps(self) return pickle.dumps(self)
@staticmethod @staticmethod
...@@ -35,11 +39,11 @@ class ModelDesc(object): ...@@ -35,11 +39,11 @@ class ModelDesc(object):
:returns: the list of raw input vars in the graph :returns: the list of raw input vars in the graph
""" """
try: if hasattr(self, 'reuse_input_vars'):
return self._reuse_input_vars() return self.reuse_input_vars
except KeyError: ret = self.get_placeholders()
pass self.reuse_input_vars = ret
return self.get_placeholders() return ret
def get_placeholders(self, prefix=''): def get_placeholders(self, prefix=''):
""" build placeholders with optional prefix, for each InputVar """ build placeholders with optional prefix, for each InputVar
...@@ -49,16 +53,12 @@ class ModelDesc(object): ...@@ -49,16 +53,12 @@ class ModelDesc(object):
tf.add_to_collection(INPUT_VARS_KEY, v.dumps()) tf.add_to_collection(INPUT_VARS_KEY, v.dumps())
ret = [] ret = []
for v in input_vars: for v in input_vars:
ret.append(tf.placeholder( placehdr_f = tf.placeholder if not v.sparse else tf.sparse_placeholder
ret.append(placehdr_f(
v.type, shape=v.shape, v.type, shape=v.shape,
name=prefix + v.name)) name=prefix + v.name))
return ret return ret
def _reuse_input_vars(self):
""" Find and return already-defined input_vars in default graph"""
input_var_names = [k.name for k in self._get_input_vars()]
return get_tensors_by_names(input_var_names)
def get_input_vars_desc(self): def get_input_vars_desc(self):
""" return a list of `InputVar` instance""" """ return a list of `InputVar` instance"""
return self._get_input_vars() return self._get_input_vars()
......
...@@ -99,7 +99,7 @@ def add_moving_summary(v, *args): ...@@ -99,7 +99,7 @@ def add_moving_summary(v, *args):
v = [v] v = [v]
v.extend(args) v.extend(args)
for x in v: for x in v:
assert x.get_shape().ndims == 0 assert x.get_shape().ndims == 0, x.get_shape()
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, x) tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, x)
@memoized @memoized
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment