Commit 50d95344 authored by Yuxin Wu's avatar Yuxin Wu

setup.py with git_version

parent 3356b8de
......@@ -201,7 +201,6 @@ def fastrcnn_predictions(boxes, scores):
Returns: n boolean, the selection
"""
prob, box = X
output_shape = tf.shape(prob)
# filter by score threshold
ids = tf.reshape(tf.where(prob > cfg.TEST.RESULT_SCORE_THRESH), [-1])
prob = tf.gather(prob, ids)
......@@ -213,11 +212,19 @@ def fastrcnn_predictions(boxes, scores):
# sort available in TF>1.4.0
# sorted_selection = tf.contrib.framework.sort(selection, direction='ASCENDING')
sorted_selection = -tf.nn.top_k(-selection, k=tf.size(selection))[0]
mask = tf.sparse_to_dense(
sparse_indices=sorted_selection,
output_shape=output_shape,
sparse_values=True,
default_value=False)
if get_tf_version_tuple() >= (1, 12):
mask = tf.sparse.SparseTensor(indices=sorted_selection,
values=tf.ones_like(sorted_selection, dtype=tf.bool),
dense_shape=tf.shape(prob))
mask = tf.sparse.to_dense(mask, default_value=False)
else:
# deprecated by TF
mask = tf.sparse_to_dense(
sparse_indices=sorted_selection,
output_shape=tf.shape(prob),
sparse_values=True,
default_value=False)
return mask
# TF bug in version 1.11, 1.12: https://github.com/tensorflow/tensorflow/issues/22750
......
......@@ -10,12 +10,32 @@ this_directory = path.abspath(path.dirname(__file__))
# setup metainfo
libinfo_py = path.join(this_directory, 'tensorpack', 'libinfo.py')
last_line = open(libinfo_py, "rb").readlines()[-1].strip()
exec(last_line)
libinfo_content = open(libinfo_py, "r").readlines()
version_line = [l.strip() for l in libinfo_content if l.startswith('__version__')][0]
exec(version_line) # produce __version__
with open(path.join(this_directory, 'README.md'), 'rb') as f:
long_description = f.read().decode('utf-8')
def add_git_version():
def get_git_version():
from subprocess import check_output
try:
return check_output("git describe --tags --long --dirty".split()).decode('utf-8').strip()
except:
return __version__ # noqa
newlibinfo_content = [l for l in libinfo_content if not l.startswith('__git_version__')]
newlibinfo_content.append('__git_version__ = "{}"'.format(get_git_version()))
with open(libinfo_py, "w") as f:
f.write("".join(newlibinfo_content))
add_git_version()
setup(
name='tensorpack',
version=__version__, # noqa
......
......@@ -14,7 +14,7 @@ from six.moves import range, zip
import threading
from .input_source_base import InputSource
from ..dataflow import DataFlow, MapData, RepeatedData, DataFlowTerminated
from ..dataflow import DataFlow, MapData, RepeatedData
from ..tfutils.summary import add_moving_summary
from ..tfutils.common import get_op_tensor_name
from ..tfutils.tower import get_current_tower_context
......@@ -159,8 +159,9 @@ class EnqueueThread(ShareSessionThread):
feed = _make_feeds(self.placehdrs, dp)
# _, sz = sess.run([self.op, self._sz], feed_dict=feed)
self.op.run(feed_dict=feed)
except (tf.errors.CancelledError, tf.errors.OutOfRangeError, DataFlowTerminated):
except (tf.errors.CancelledError, tf.errors.OutOfRangeError):
pass
# logger.exception("Exception in {}:".format(self.name))
except Exception as e:
if isinstance(e, RuntimeError) and 'closed Session' in str(e):
pass
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment