Commit b7766fc1 authored by Yuxin Wu's avatar Yuxin Wu

fix swig name. add some df

parent 3efce3ae
...@@ -9,7 +9,8 @@ from .base import DataFlow, ProxyDataFlow ...@@ -9,7 +9,8 @@ from .base import DataFlow, ProxyDataFlow
from ..utils import * from ..utils import *
__all__ = ['BatchData', 'FixedSizeData', 'FakeData', 'MapData', __all__ = ['BatchData', 'FixedSizeData', 'FakeData', 'MapData',
'MapDataComponent', 'RandomChooseData', 'RandomMixData', 'JoinData'] 'MapDataComponent', 'RandomChooseData', 'RandomMixData',
'JoinData', 'ConcatData', 'SelectComponent']
class BatchData(ProxyDataFlow): class BatchData(ProxyDataFlow):
def __init__(self, ds, batch_size, remainder=False): def __init__(self, ds, batch_size, remainder=False):
...@@ -249,7 +250,7 @@ class RandomMixData(DataFlow): ...@@ -249,7 +250,7 @@ class RandomMixData(DataFlow):
for k in idxs: for k in idxs:
yield next(itrs[k]) yield next(itrs[k])
class JoinData(DataFlow): class ConcatData(DataFlow):
""" """
Concatenate several dataflows. Concatenate several dataflows.
""" """
...@@ -271,21 +272,48 @@ class JoinData(DataFlow): ...@@ -271,21 +272,48 @@ class JoinData(DataFlow):
for dp in d.get_data(): for dp in d.get_data():
yield dp yield dp
class SelectComponent(ProxyDataFlow): class JoinData(DataFlow):
""" """
Select component from a datapoint. Join the components from each DataFlow.
e.g.: df1: [dp1, dp2]
df2: [dp3, dp4]
join: [dp1, dp2, dp3, dp4]
""" """
def __init__(self, ds, idxs): def __init__(self, df_lists):
""" """
:param ds: a :mod:`DataFlow` instance :param df_lists: list of :mod:`DataFlow` instances
:param idxs: a list of datapoint component index of the original dataflow
""" """
super(SelectComponent, self).__init__(ds) self.df_lists = df_lists
self.idxs = idxs self._size = self.df_lists[0].size()
for d in self.df_lists:
assert d.size() == self._size, \
"All DataFlow must have the same size! {} != {}".format(d.size(), self._size)
def reset_state(self):
for d in self.df_lists:
d.reset_state()
def size(self):
return self._size
def get_data(self): def get_data(self):
for dp in self.ds.get_data(): itrs = [k.get_data() for k in self.df_lists]
newdp = [] try:
for idx in self.idxs: while True:
newdp.append(dp[idx]) dp = []
yield newdp for itr in itrs:
dp.extend(next(itr))
yield dp
except StopIteration:
pass
finally:
for itr in itrs:
del itr
def SelectComponent(ds, idxs):
"""
:param ds: a :mod:`DataFlow` instance
:param idxs: a list of datapoint component index of the original dataflow
"""
return MapData(ds, lambda dp: [dp[i] for i in idxs])
...@@ -79,7 +79,8 @@ def get_predict_func(config): ...@@ -79,7 +79,8 @@ def get_predict_func(config):
# check output_var_names against output_vars # check output_var_names against output_vars
if output_var_names is not None: if output_var_names is not None:
output_vars = [tf.get_default_graph().get_tensor_by_name(n) for n in output_var_names] output_vars = [tf.get_default_graph().get_tensor_by_name(get_op_var_name(n)[1])
for n in output_var_names]
else: else:
output_vars = [] output_vars = []
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
from ..utils.naming import * from ..utils.naming import *
import tensorflow as tf import tensorflow as tf
def get_default_sess_config(mem_fraction=0.5): def get_default_sess_config(mem_fraction=0.99):
""" """
Return a better session config to use as default. Return a better session config to use as default.
Tensorflow default session config consume too much resources. Tensorflow default session config consume too much resources.
......
...@@ -85,7 +85,7 @@ class SaverRestore(SessionInit): ...@@ -85,7 +85,7 @@ class SaverRestore(SessionInit):
@staticmethod @staticmethod
def _read_checkpoint_vars(model_path): def _read_checkpoint_vars(model_path):
reader = tf.train.NewCheckpointReader(model_path) reader = tf.train.NewCheckpointReader(model_path)
return set(reader.GetVariableToShapeMap().keys()) return set(reader.get_variable_to_shape_map().keys())
@staticmethod @staticmethod
def _get_vars_to_restore_multimap(vars_available): def _get_vars_to_restore_multimap(vars_available):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment