Commit aee9f1bc authored by Yuxin Wu's avatar Yuxin Wu

Fix when return cost not used in inference

parent b1850e0b
...@@ -50,11 +50,13 @@ class ImageFromFile(RNGDataFlow): ...@@ -50,11 +50,13 @@ class ImageFromFile(RNGDataFlow):
Args: Args:
files (list): list of file paths. files (list): list of file paths.
channel (int): 1 or 3. Will convert grayscale to RGB images if channel==3. channel (int): 1 or 3. Will convert grayscale to RGB images if channel==3.
Will produce (h, w, 1) array if channel==1.
resize (tuple): int or (h, w) tuple. If given, resize the image. resize (tuple): int or (h, w) tuple. If given, resize the image.
""" """
assert len(files), "No image files given to ImageFromFile!" assert len(files), "No image files given to ImageFromFile!"
self.files = files self.files = files
self.channel = int(channel) self.channel = int(channel)
assert self.channel in [1, 3], self.channel
self.imread_mode = cv2.IMREAD_GRAYSCALE if self.channel == 1 else cv2.IMREAD_COLOR self.imread_mode = cv2.IMREAD_GRAYSCALE if self.channel == 1 else cv2.IMREAD_COLOR
if resize is not None: if resize is not None:
resize = shape2d(resize) resize = shape2d(resize)
......
...@@ -225,6 +225,8 @@ class ModelDesc(ModelDescBase): ...@@ -225,6 +225,8 @@ class ModelDesc(ModelDescBase):
Used by trainers to get the final cost for optimization. Used by trainers to get the final cost for optimization.
""" """
ret = self.build_graph(*inputs) ret = self.build_graph(*inputs)
if not get_current_tower_context().is_training:
return None # this is the tower function, could be called for inference
if isinstance(ret, tf.Tensor): # the preferred way if isinstance(ret, tf.Tensor): # the preferred way
assert ret.shape.ndims == 0, "Cost must be a scalar, but found a tensor of shape {}!".format(ret.shape) assert ret.shape.ndims == 0, "Cost must be a scalar, but found a tensor of shape {}!".format(ret.shape)
_check_unused_regularization() _check_unused_regularization()
...@@ -243,6 +245,9 @@ class ModelDesc(ModelDescBase): ...@@ -243,6 +245,9 @@ class ModelDesc(ModelDescBase):
ctx = get_current_tower_context() ctx = get_current_tower_context()
cost = self._build_graph_get_cost(*inputs) cost = self._build_graph_get_cost(*inputs)
if not ctx.is_training:
return None # this is the tower function, could be called for inference
if ctx.has_own_variables: if ctx.has_own_variables:
varlist = ctx.get_collection_in_tower(tf.GraphKeys.TRAINABLE_VARIABLES) varlist = ctx.get_collection_in_tower(tf.GraphKeys.TRAINABLE_VARIABLES)
else: else:
......
...@@ -190,6 +190,8 @@ class SingleCostTrainer(TowerTrainer): ...@@ -190,6 +190,8 @@ class SingleCostTrainer(TowerTrainer):
def get_grad_fn(): def get_grad_fn():
ctx = get_current_tower_context() ctx = get_current_tower_context()
cost = get_cost_fn(*input.get_input_tensors()) cost = get_cost_fn(*input.get_input_tensors())
if not ctx.is_training:
return None # this is the tower function, could be called for inference
if ctx.has_own_variables: if ctx.has_own_variables:
varlist = ctx.get_collection_in_tower(tf.GraphKeys.TRAINABLE_VARIABLES) varlist = ctx.get_collection_in_tower(tf.GraphKeys.TRAINABLE_VARIABLES)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment