Commit fe1e88d3 authored by Yuxin Wu

JIT hurt performance

parent 9ca10d35
......@@ -40,6 +40,8 @@ class RunOp(Callback):
def _setup_graph(self):
self._op = self.setup_func()
if self.run_step:
self._fetch = tf.train.SessionRunArgs(fetches=self._op)
def _before_train(self):
if self.run_before:
......@@ -54,7 +56,7 @@ class RunOp(Callback):
def _before_run(self, _):
if self.run_step:
self._print()
return [self._op]
return self._fetch # faster than return [self._op]
def _print(self):
if self.verbose:
......
......@@ -44,7 +44,8 @@ def get_default_sess_config(mem_fraction=0.99):
conf.gpu_options.allocator_type = 'BFC'
conf.gpu_options.allow_growth = True
conf.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
# Hurt performance in 8xP100 training
# conf.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
return conf
......
......@@ -379,9 +379,10 @@ class DummyConstantInput(TensorInput):
assert ctx is not None
assert len(self.shapes) == len(self.input_placehdrs)
for idx, p in enumerate(self.input_placehdrs):
tlist.append(tf.get_variable(
'dummy-{}-{}'.format(p.op.name, ctx.index), shape=self.shapes[idx],
dtype=p.dtype, trainable=False))
tlist.append(tf.constant(
0, dtype=p.dtype,
name='dummy-{}-{}'.format(p.op.name, ctx.index),
shape=self.shapes[idx]))
return tlist
super(DummyConstantInput, self).__init__(fn)
......
......@@ -353,7 +353,7 @@ def intensity_to_rgb(intensity, cmap='cubehelix', normalize=False):
from ..utils.develop import create_dummy_func # noqa
try:
import matplotlib.pyplot as plt
except ImportError:
except (ImportError, RuntimeError):
pyplot2img = create_dummy_func('pyplot2img', 'matplotlib') # noqa
intensity_to_rgb = create_dummy_func('intensity_to_rgb', 'matplotlib') # noqa
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment