Commit 50d95344 authored by Yuxin Wu's avatar Yuxin Wu

setup.py with git_version

parent 3356b8de
...@@ -201,7 +201,6 @@ def fastrcnn_predictions(boxes, scores): ...@@ -201,7 +201,6 @@ def fastrcnn_predictions(boxes, scores):
Returns: n boolean, the selection Returns: n boolean, the selection
""" """
prob, box = X prob, box = X
output_shape = tf.shape(prob)
# filter by score threshold # filter by score threshold
ids = tf.reshape(tf.where(prob > cfg.TEST.RESULT_SCORE_THRESH), [-1]) ids = tf.reshape(tf.where(prob > cfg.TEST.RESULT_SCORE_THRESH), [-1])
prob = tf.gather(prob, ids) prob = tf.gather(prob, ids)
...@@ -213,9 +212,17 @@ def fastrcnn_predictions(boxes, scores): ...@@ -213,9 +212,17 @@ def fastrcnn_predictions(boxes, scores):
# sort available in TF>1.4.0 # sort available in TF>1.4.0
# sorted_selection = tf.contrib.framework.sort(selection, direction='ASCENDING') # sorted_selection = tf.contrib.framework.sort(selection, direction='ASCENDING')
sorted_selection = -tf.nn.top_k(-selection, k=tf.size(selection))[0] sorted_selection = -tf.nn.top_k(-selection, k=tf.size(selection))[0]
if get_tf_version_tuple() >= (1, 12):
mask = tf.sparse.SparseTensor(indices=sorted_selection,
values=tf.ones_like(sorted_selection, dtype=tf.bool),
dense_shape=tf.shape(prob))
mask = tf.sparse.to_dense(mask, default_value=False)
else:
# deprecated by TF
mask = tf.sparse_to_dense( mask = tf.sparse_to_dense(
sparse_indices=sorted_selection, sparse_indices=sorted_selection,
output_shape=output_shape, output_shape=tf.shape(prob),
sparse_values=True, sparse_values=True,
default_value=False) default_value=False)
return mask return mask
......
...@@ -10,12 +10,32 @@ this_directory = path.abspath(path.dirname(__file__)) ...@@ -10,12 +10,32 @@ this_directory = path.abspath(path.dirname(__file__))
# setup metainfo # setup metainfo
libinfo_py = path.join(this_directory, 'tensorpack', 'libinfo.py') libinfo_py = path.join(this_directory, 'tensorpack', 'libinfo.py')
last_line = open(libinfo_py, "rb").readlines()[-1].strip() libinfo_content = open(libinfo_py, "r").readlines()
exec(last_line) version_line = [l.strip() for l in libinfo_content if l.startswith('__version__')][0]
exec(version_line) # produce __version__
with open(path.join(this_directory, 'README.md'), 'rb') as f: with open(path.join(this_directory, 'README.md'), 'rb') as f:
long_description = f.read().decode('utf-8') long_description = f.read().decode('utf-8')
def add_git_version():
def get_git_version():
from subprocess import check_output
try:
return check_output("git describe --tags --long --dirty".split()).decode('utf-8').strip()
except:
return __version__ # noqa
newlibinfo_content = [l for l in libinfo_content if not l.startswith('__git_version__')]
newlibinfo_content.append('__git_version__ = "{}"'.format(get_git_version()))
with open(libinfo_py, "w") as f:
f.write("".join(newlibinfo_content))
add_git_version()
setup( setup(
name='tensorpack', name='tensorpack',
version=__version__, # noqa version=__version__, # noqa
......
...@@ -14,7 +14,7 @@ from six.moves import range, zip ...@@ -14,7 +14,7 @@ from six.moves import range, zip
import threading import threading
from .input_source_base import InputSource from .input_source_base import InputSource
from ..dataflow import DataFlow, MapData, RepeatedData, DataFlowTerminated from ..dataflow import DataFlow, MapData, RepeatedData
from ..tfutils.summary import add_moving_summary from ..tfutils.summary import add_moving_summary
from ..tfutils.common import get_op_tensor_name from ..tfutils.common import get_op_tensor_name
from ..tfutils.tower import get_current_tower_context from ..tfutils.tower import get_current_tower_context
...@@ -159,8 +159,9 @@ class EnqueueThread(ShareSessionThread): ...@@ -159,8 +159,9 @@ class EnqueueThread(ShareSessionThread):
feed = _make_feeds(self.placehdrs, dp) feed = _make_feeds(self.placehdrs, dp)
# _, sz = sess.run([self.op, self._sz], feed_dict=feed) # _, sz = sess.run([self.op, self._sz], feed_dict=feed)
self.op.run(feed_dict=feed) self.op.run(feed_dict=feed)
except (tf.errors.CancelledError, tf.errors.OutOfRangeError, DataFlowTerminated): except (tf.errors.CancelledError, tf.errors.OutOfRangeError):
pass pass
# logger.exception("Exception in {}:".format(self.name))
except Exception as e: except Exception as e:
if isinstance(e, RuntimeError) and 'closed Session' in str(e): if isinstance(e, RuntimeError) and 'closed Session' in str(e):
pass pass
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment