Commit b01cd6d9 authored by Yuxin Wu's avatar Yuxin Wu

error for wrong param dtype

parent e2eba742
...@@ -152,8 +152,9 @@ with TowerContext('', is_training=False): ...@@ -152,8 +152,9 @@ with TowerContext('', is_training=False):
It's also very common to change the graph for inference. It's also very common to change the graph for inference.
For example, you may need a different data layout for CPU inference, For example, you may need a different data layout for CPU inference,
or you may need placeholders in the inference graph (which may not even exist in you may need placeholders in the inference graph (which may not even exist in
the training graph). However metagraph is not designed to be easily modified at all. the training graph), you may add/remove static shapes for inference, etc.
However metagraph is not designed to be easily modified at all.
Due to the above reasons, to do inference, it's best to recreate a clean graph (and save it if needed) by yourself. Due to the above reasons, to do inference, it's best to recreate a clean graph (and save it if needed) by yourself.
``` ```
......
...@@ -77,6 +77,10 @@ class GraphVarParam(HyperParam): ...@@ -77,6 +77,10 @@ class GraphVarParam(HyperParam):
def set_value(self, v): def set_value(self, v):
""" Assign the variable a new value. """ """ Assign the variable a new value. """
if not self.var.dtype.is_floating and isinstance(v, float):
raise ValueError(
"HyperParam {} has type '{}'. Cannot update it using float values.".format(
self.name, self.var.dtype))
self.var.load(v) self.var.load(v)
def get_value(self): def get_value(self):
......
...@@ -44,7 +44,7 @@ class Grayscale(ColorSpace): ...@@ -44,7 +44,7 @@ class Grayscale(ColorSpace):
class ToUint8(PhotometricAugmentor): class ToUint8(PhotometricAugmentor):
""" Convert image to uint8. Useful to reduce communication overhead. """ """ Clip and convert image to uint8. Useful to reduce communication overhead. """
def _augment(self, img, _): def _augment(self, img, _):
return np.clip(img, 0, 255).astype(np.uint8) return np.clip(img, 0, 255).astype(np.uint8)
......
...@@ -134,7 +134,7 @@ class FeedfreeInput(InputSource): ...@@ -134,7 +134,7 @@ class FeedfreeInput(InputSource):
class EnqueueThread(ShareSessionThread): class EnqueueThread(ShareSessionThread):
def __init__(self, queue, ds, placehdrs): def __init__(self, queue, ds, placehdrs):
super(EnqueueThread, self).__init__() super(EnqueueThread, self).__init__()
self.name = 'EnqueueThread ' + queue.name self.name = 'EnqueueThread: enqueue dataflow to TF queue "{}"'.format(queue.name)
self.daemon = True self.daemon = True
self.dataflow = ds self.dataflow = ds
self.queue = queue self.queue = queue
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment