Commit a2f4f439 authored by Yuxin Wu

misc fix

parent 0d20032a
@@ -6,3 +6,4 @@ pyzmq
 tornado; python_version < '3.0'
 lmdb
 matplotlib
+scikit-learn
@@ -24,6 +24,7 @@ with tf.Graph().as_default() as G:
     if args.config:
         MODEL = imp.load_source('config_script', args.config).Model
         M = MODEL()
-        M.build_graph(M.get_input_vars())
+        with TowerContext('', is_training=False):
+            M.build_graph(M.get_input_vars())
     else:
         M = ModelFromMetaGraph(args.meta)
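
For context (not part of the commit): tensorpack layers such as BatchNorm and Dropout consult the active TowerContext to decide between training-mode and inference-mode behavior, so a graph exported for prediction must be built inside one. A minimal sketch of the pattern, assuming the tensorpack API at this revision; the import path is a guess:

```python
# Minimal sketch, assuming tensorpack's API at this revision.
from tensorpack.tfutils.tower import TowerContext  # assumed import path

def build_inference_graph(model):
    # is_training=False makes layers like BatchNorm build their
    # inference ops (moving averages instead of batch statistics).
    with TowerContext('', is_training=False):
        model.build_graph(model.get_input_vars())
```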
@@ -298,12 +298,11 @@ Line: {}""".format(repr(args.delimeter), line)
     length_ys = [len(t) for t in data_ys]
     print("Length of each column:", length_ys)
     max_ysize = max(length_ys)
-    print("Size of the longest y column: ", max_ysize)
 
     if nr_x_column:
         data_xs = [data[k] for k in args.x_column_idx]
     else:
-        data_xs = [list(range(max_ysize))]
+        data_xs = [list(range(1, max_ysize+1))]
 
     for idx, data_y in enumerate(data_ys):
         data_ys[idx] = np.asarray(data_y)
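
An illustration only (not from the diff) of what the changed default means: when no x column is given, rows are now numbered from 1 rather than 0.

```python
# Default x values for a 3-row y column, before and after this change.
max_ysize = 3
old_xs = list(range(max_ysize))           # [0, 1, 2] -- before
new_xs = list(range(1, max_ysize + 1))    # [1, 2, 3] -- after
```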
@@ -2,15 +2,15 @@
 # File: format.py
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
-from ..utils import logger, get_rng
+import numpy as np
+from tqdm import tqdm
+from six.moves import range
+from ..utils import logger, get_rng, get_tqdm_kwargs
 from ..utils.timer import timed_operation
 from ..utils.loadcaffe import get_caffe_pb
 from .base import RNGDataFlow
-import random
-from tqdm import tqdm
-from six.moves import range
 
 try:
     import h5py
 except ImportError:
@@ -24,13 +24,21 @@ try:
 except ImportError:
     logger.warn("Error in 'import lmdb'. LMDBData won't be available.")
 else:
-    __all__.extend(['LMDBData', 'CaffeLMDB'])
+    __all__.extend(['LMDBData', 'CaffeLMDB', 'LMDBDataDecoder'])
+
+try:
+    import sklearn.datasets
+except ImportError:
+    logger.warn("Error in 'import sklearn'. SVMLightData won't be available.")
+else:
+    __all__.extend(['SVMLightData'])
 
 """
 Adapters for different data format.
 """
+# TODO lazy load
 
 class HDF5Data(RNGDataFlow):
     """
     Zip data from different paths in an HDF5 file. Will load all data into memory.
@@ -69,11 +77,12 @@ class LMDBData(RNGDataFlow):
         self._shuffle = shuffle
         self._size = self._txn.stat()['entries']
         if shuffle:
+            # get the list of keys either from __keys__ or by iterating
             self.keys = self._txn.get('__keys__')
             if not self.keys:
                 self.keys = []
                 with timed_operation("Loading LMDB keys ...", log_start=True), \
-                        tqdm(total=self._size, ascii=True) as pbar:
+                        tqdm(**get_tqdm_kwargs(total=self._size)) as pbar:
                     for k, _ in self._txn.cursor():
                         if k != '__keys__':
                             self.keys.append(k)
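
The diff only shows the read side of the `__keys__` fast path. For illustration, here is a hedged sketch of the writer-side convention it implies: an LMDB that stores its key list under the reserved `__keys__` entry lets LMDBData skip the full database scan. The serialization of the key list is an assumption; the commit only shows the raw value being fetched.

```python
# Hedged writer-side sketch; the '\x00'-joined serialization is an
# assumption -- the diff only shows the reader fetching the raw value.
import lmdb

def write_lmdb_with_keys(lmdb_dir, items):
    """items: iterable of (key, value) byte-string pairs."""
    db = lmdb.open(lmdb_dir, map_size=1 << 30)
    keys = []
    with db.begin(write=True) as txn:
        for k, v in items:
            txn.put(k, v)
            keys.append(k)
        # reserved entry consumed by LMDBData when shuffle=True
        txn.put(b'__keys__', b'\x00'.join(keys))
    db.close()
```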
@@ -131,3 +140,20 @@ class CaffeLMDB(LMDBDataDecoder):
         super(CaffeLMDB, self).__init__(
             lmdb_dir, decoder=decoder, shuffle=shuffle)
+
+class SVMLightData(RNGDataFlow):
+    """ Read X,y from a svmlight file. """
+    def __init__(self, filename, shuffle=True):
+        self.X, self.y = sklearn.datasets.load_svmlight_file(filename)
+        self.X = np.asarray(self.X.todense())
+        self.shuffle = shuffle
+
+    def size(self):
+        return len(self.y)
+
+    def get_data(self):
+        idxs = np.arange(self.size())
+        if self.shuffle:
+            self.rng.shuffle(idxs)
+        for id in idxs:
+            yield [self.X[id, :], self.y[id]]
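
A usage sketch for the new DataFlow (not from the commit): the file path and toy data below are made up, and `reset_state()` follows the usual DataFlow contract of initializing the RNG before iteration.

```python
# Illustrative usage of SVMLightData as added above.
import numpy as np
import sklearn.datasets

X = np.random.rand(8, 5)
y = np.arange(8) % 2
sklearn.datasets.dump_svmlight_file(X, y, '/tmp/toy.svmlight')

ds = SVMLightData('/tmp/toy.svmlight', shuffle=True)
ds.reset_state()                 # initializes ds.rng (DataFlow contract)
for xs, label in ds.get_data():  # xs: dense feature row, label: scalar
    print(xs.shape, label)       # -> (5,) with label 0.0 or 1.0
```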
@@ -45,8 +45,9 @@ class MultiProcessPredictWorker(multiprocessing.Process):
         if self.idx != 0:
             from tensorpack.models._common import disable_layer_logging
             disable_layer_logging()
-        self.func = OfflinePredictor(self.config)
+        self.predictor = OfflinePredictor(self.config)
         if self.idx == 0:
-            describe_model()
+            with self.predictor.graph.as_default():
+                describe_model()
 
 class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
@@ -70,7 +71,7 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
                 self.outqueue.put((DIE, None))
                 return
             else:
-                self.outqueue.put((tid, self.func(dp)))
+                self.outqueue.put((tid, self.predictor(dp)))
 
 class PredictorWorkerThread(threading.Thread):
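
The `describe_model()` change is about TensorFlow graph scoping: an OfflinePredictor builds its ops in its own tf.Graph, and utilities that enumerate variables see only the current default graph. A standalone TF1-style illustration of the scoping rule (not code from the commit):

```python
# Why introspection must run inside the right graph context (TF1 API).
import tensorflow as tf

g = tf.Graph()
with g.as_default():
    v = tf.Variable(0, name='counter')

# Outside the context, the default graph is a different graph entirely:
assert v.graph is g
assert tf.get_default_graph() is not g

with g.as_default():
    # Anything that walks tf.global_variables() sees `v` only here.
    assert tf.global_variables() == [v]
```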
@@ -66,6 +66,7 @@ class SimpleDatasetPredictor(DatasetPredictorBase):
                 yield res
                 pbar.update()
 
+# TODO allow unordered
 class MultiProcessDatasetPredictor(DatasetPredictorBase):
     def __init__(self, config, dataset, nr_proc, use_gpu=True):
         """
@@ -64,15 +64,14 @@ def class_balanced_sigmoid_cross_entropy(logits, label, name='cross_entropy_loss
     :param label: the ground truth in {0,1}, of the same shape as logits.
     :returns: a scalar. class-balanced cross entropy loss
     """
-    z = batch_flatten(logits)
-    y = tf.cast(batch_flatten(label), tf.float32)
+    y = tf.cast(label, tf.float32)
 
     count_neg = tf.reduce_sum(1. - y)
     count_pos = tf.reduce_sum(y)
     beta = count_neg / (count_neg + count_pos)
 
     pos_weight = beta / (1 - beta)
-    cost = tf.nn.weighted_cross_entropy_with_logits(z, y, pos_weight)
+    cost = tf.nn.weighted_cross_entropy_with_logits(logits, y, pos_weight)
     cost = tf.reduce_mean(cost * (1 - beta), name=name)
 
     #logstable = tf.log(1 + tf.exp(-tf.abs(z)))
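
For the record (not part of the diff), the weighting scheme can be written out. Writing count_neg as n^-, count_pos as n^+, and the per-element cross-entropy terms for positive and negative labels as l^+ and l^-, the pos_weight together with the trailing (1 - beta) factor reduce to the usual class-balanced weights:

```latex
% beta and pos_weight exactly as defined in the function body:
\beta = \frac{n^-}{n^- + n^+}, \qquad
\text{pos\_weight} = \frac{\beta}{1-\beta} = \frac{n^-}{n^+}

% weighted_cross_entropy_with_logits scales only the positive term;
% multiplying the result by (1-\beta) gives weights \beta and (1-\beta):
(1-\beta)\Bigl[\tfrac{\beta}{1-\beta}\, y\, \ell^{+} + (1-y)\, \ell^{-}\Bigr]
  = \beta\, y\, \ell^{+} + (1-\beta)(1-y)\, \ell^{-}
```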