Commit fe1e88d3 authored by Yuxin Wu's avatar Yuxin Wu

JIT hurt performance

parent 9ca10d35
...@@ -40,6 +40,8 @@ class RunOp(Callback): ...@@ -40,6 +40,8 @@ class RunOp(Callback):
def _setup_graph(self): def _setup_graph(self):
self._op = self.setup_func() self._op = self.setup_func()
if self.run_step:
self._fetch = tf.train.SessionRunArgs(fetches=self._op)
def _before_train(self): def _before_train(self):
if self.run_before: if self.run_before:
...@@ -54,7 +56,7 @@ class RunOp(Callback): ...@@ -54,7 +56,7 @@ class RunOp(Callback):
def _before_run(self, _): def _before_run(self, _):
if self.run_step: if self.run_step:
self._print() self._print()
return [self._op] return self._fetch # faster than return [self._op]
def _print(self): def _print(self):
if self.verbose: if self.verbose:
......
...@@ -44,7 +44,8 @@ def get_default_sess_config(mem_fraction=0.99): ...@@ -44,7 +44,8 @@ def get_default_sess_config(mem_fraction=0.99):
conf.gpu_options.allocator_type = 'BFC' conf.gpu_options.allocator_type = 'BFC'
conf.gpu_options.allow_growth = True conf.gpu_options.allow_growth = True
conf.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1 # Hurt performance in 8xP100 training
# conf.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
return conf return conf
......
...@@ -379,9 +379,10 @@ class DummyConstantInput(TensorInput): ...@@ -379,9 +379,10 @@ class DummyConstantInput(TensorInput):
assert ctx is not None assert ctx is not None
assert len(self.shapes) == len(self.input_placehdrs) assert len(self.shapes) == len(self.input_placehdrs)
for idx, p in enumerate(self.input_placehdrs): for idx, p in enumerate(self.input_placehdrs):
tlist.append(tf.get_variable( tlist.append(tf.constant(
'dummy-{}-{}'.format(p.op.name, ctx.index), shape=self.shapes[idx], 0, dtype=p.dtype,
dtype=p.dtype, trainable=False)) name='dummy-{}-{}'.format(p.op.name, ctx.index),
shape=self.shapes[idx]))
return tlist return tlist
super(DummyConstantInput, self).__init__(fn) super(DummyConstantInput, self).__init__(fn)
......
...@@ -353,7 +353,7 @@ def intensity_to_rgb(intensity, cmap='cubehelix', normalize=False): ...@@ -353,7 +353,7 @@ def intensity_to_rgb(intensity, cmap='cubehelix', normalize=False):
from ..utils.develop import create_dummy_func # noqa from ..utils.develop import create_dummy_func # noqa
try: try:
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
except ImportError: except (ImportError, RuntimeError):
pyplot2img = create_dummy_func('pyplot2img', 'matplotlib') # noqa pyplot2img = create_dummy_func('pyplot2img', 'matplotlib') # noqa
intensity_to_rgb = create_dummy_func('intensity_to_rgb', 'matplotlib') # noqa intensity_to_rgb = create_dummy_func('intensity_to_rgb', 'matplotlib') # noqa
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment