Commit b7766fc1 authored by Yuxin Wu

fix swig name. add some df

parent 3efce3ae
@@ -9,7 +9,8 @@ from .base import DataFlow, ProxyDataFlow
 from ..utils import *

 __all__ = ['BatchData', 'FixedSizeData', 'FakeData', 'MapData',
-           'MapDataComponent', 'RandomChooseData', 'RandomMixData', 'JoinData']
+           'MapDataComponent', 'RandomChooseData', 'RandomMixData',
+           'JoinData', 'ConcatData', 'SelectComponent']

 class BatchData(ProxyDataFlow):
     def __init__(self, ds, batch_size, remainder=False):
@@ -249,7 +250,7 @@ class RandomMixData(DataFlow):
         for k in idxs:
             yield next(itrs[k])

-class JoinData(DataFlow):
+class ConcatData(DataFlow):
     """
     Concatenate several dataflows.
     """
@@ -271,21 +272,48 @@ class JoinData(DataFlow):
             for dp in d.get_data():
                 yield dp

-class SelectComponent(ProxyDataFlow):
+class JoinData(DataFlow):
     """
-    Select component from a datapoint.
+    Join the components from each DataFlow.
+    e.g.: df1: [dp1, dp2]
+          df2: [dp3, dp4]
+          join: [dp1, dp2, dp3, dp4]
     """
-    def __init__(self, ds, idxs):
+    def __init__(self, df_lists):
         """
-        :param ds: a :mod:`DataFlow` instance
-        :param idxs: a list of datapoint component index of the original dataflow
+        :param df_lists: list of :mod:`DataFlow` instances
         """
-        super(SelectComponent, self).__init__(ds)
-        self.idxs = idxs
+        self.df_lists = df_lists
+        self._size = self.df_lists[0].size()
+        for d in self.df_lists:
+            assert d.size() == self._size, \
+                "All DataFlow must have the same size! {} != {}".format(d.size(), self._size)
+
+    def reset_state(self):
+        for d in self.df_lists:
+            d.reset_state()
+
+    def size(self):
+        return self._size

     def get_data(self):
-        for dp in self.ds.get_data():
-            newdp = []
-            for idx in self.idxs:
-                newdp.append(dp[idx])
-            yield newdp
+        itrs = [k.get_data() for k in self.df_lists]
+        try:
+            while True:
+                dp = []
+                for itr in itrs:
+                    dp.extend(next(itr))
+                yield dp
+        except StopIteration:
+            pass
+        finally:
+            for itr in itrs:
+                del itr
+
+
+def SelectComponent(ds, idxs):
+    """
+    :param ds: a :mod:`DataFlow` instance
+    :param idxs: a list of datapoint component index of the original dataflow
+    """
+    return MapData(ds, lambda dp: [dp[i] for i in idxs])
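Taken together, the dataflow changes read as follows: the old JoinData (sequential concatenation) is now ConcatData, the new JoinData zips equal-sized DataFlows by concatenating the components of their datapoints, and SelectComponent is reduced to a one-line wrapper around MapData. A hedged usage sketch, assuming the FakeData(shapes, size) constructor from this module and the tensorpack.dataflow import path (both assumptions, not shown in this diff):

# Illustrative only: shapes, sizes, and the import path are assumptions.
from tensorpack.dataflow import FakeData, ConcatData, JoinData, SelectComponent

df1 = FakeData([[4, 4], [2]], 100)   # 100 datapoints, 2 components each
df2 = FakeData([[8]], 100)           # 100 datapoints, 1 component each

concat = ConcatData([df1, df2])             # 200 datapoints: all of df1, then all of df2
joined = JoinData([df1, df2])               # 100 datapoints, 2 + 1 = 3 components; sizes must match
selected = SelectComponent(joined, [0, 2])  # keep only components 0 and 2 of each datapoint

selected.reset_state()
for dp in selected.get_data():
    pass  # dp == [df1_component0, df2_component0]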
@@ -79,7 +79,8 @@ def get_predict_func(config):
     # check output_var_names against output_vars
     if output_var_names is not None:
-        output_vars = [tf.get_default_graph().get_tensor_by_name(n) for n in output_var_names]
+        output_vars = [tf.get_default_graph().get_tensor_by_name(get_op_var_name(n)[1])
+                       for n in output_var_names]
     else:
         output_vars = []
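The predict change means callers can pass either an op name ('prob') or a tensor name ('prob:0') in output_var_names; get_op_var_name normalizes to the tensor name before the graph lookup. A minimal sketch of that naming convention, assuming the helper returns an (op_name, tensor_name) pair (the actual tensorpack helper may differ):

def get_op_var_name(name):
    # Accept either 'scope/prob' or 'scope/prob:0' and return both forms.
    if name.endswith(':0'):
        return name[:-2], name
    return name, name + ':0'

assert get_op_var_name('prob') == ('prob', 'prob:0')
assert get_op_var_name('prob:0') == ('prob', 'prob:0')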
@@ -6,7 +6,7 @@
 from ..utils.naming import *
 import tensorflow as tf

-def get_default_sess_config(mem_fraction=0.5):
+def get_default_sess_config(mem_fraction=0.99):
     """
     Return a better session config to use as default.
     Tensorflow default session config consume too much resources.
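Raising mem_fraction from 0.5 to 0.99 lets a single process claim nearly all GPU memory by default instead of half. For reference, a sketch of what a helper like this typically returns, using standard tf.ConfigProto fields (the actual function body is not shown in this hunk and may differ):

import tensorflow as tf

def get_default_sess_config(mem_fraction=0.99):
    # Cap the fraction of GPU memory the process may grab, and let ops
    # without a GPU kernel fall back to CPU.
    conf = tf.ConfigProto()
    conf.gpu_options.per_process_gpu_memory_fraction = mem_fraction
    conf.allow_soft_placement = True
    return conf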
@@ -85,7 +85,7 @@ class SaverRestore(SessionInit):
     @staticmethod
     def _read_checkpoint_vars(model_path):
         reader = tf.train.NewCheckpointReader(model_path)
-        return set(reader.GetVariableToShapeMap().keys())
+        return set(reader.get_variable_to_shape_map().keys())

     @staticmethod
     def _get_vars_to_restore_multimap(vars_available):
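This is the "fix swig name" part of the commit: the SWIG-generated CheckpointReader method is exposed under the snake_case name get_variable_to_shape_map rather than GetVariableToShapeMap. A standalone illustration of listing the variables stored in a checkpoint (the checkpoint path is hypothetical):

import tensorflow as tf

reader = tf.train.NewCheckpointReader('/path/to/model-10000')
var_shapes = reader.get_variable_to_shape_map()  # dict: variable name -> shape list
print(sorted(var_shapes.keys()))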