Commit 7d9582a1 authored by Yuxin Wu

update for ilsvrc

parent 17d8feb5
@@ -10,7 +10,7 @@ import imp
 import tqdm
 import os
 from tensorpack.utils import logger
-from tensorpack.utils.utils import mkdir_p
+from tensorpack.utils.fs import mkdir_p
 parser = argparse.ArgumentParser()
...
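Aside (not part of the diff): mkdir_p now lives in tensorpack.utils.fs rather than tensorpack.utils.utils. A minimal usage sketch, assuming the helper keeps the usual mkdir -p behavior (creates parent directories, no error if the path already exists); the path below is a made-up example.

```python
from tensorpack.utils.fs import mkdir_p   # new location after this commit

out_dir = 'train_log/ilsvrc-example'      # hypothetical output directory
mkdir_p(out_dir)                          # safe to call even if out_dir already exists
```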
@@ -4,8 +4,8 @@
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
 import os
 import tarfile
+import cv2
 import numpy as np
-import scipy.ndimage as scimg
 from ...utils import logger, get_rng
 from ..base import DataFlow
@@ -61,9 +61,10 @@ class ILSVRCMeta(object):
                 ret.append((name, int(cls)))
         return ret
-    def get_per_pixel_mean(self):
+    def get_per_pixel_mean(self, size=None):
         """
-        :returns per-pixel mean as an array of shape (3, 256, 256) in range [0, 255]
+        :param size: return image size in [h, w]. default to (256, 256)
+        :returns per-pixel mean as an array of shape (h, w, 3) in range [0, 255]
         """
         import imp
         caffepb = imp.load_source('caffepb', self.caffe_pb_file)
@@ -73,6 +74,9 @@ class ILSVRCMeta(object):
         with open(mean_file) as f:
             obj.ParseFromString(f.read())
         arr = np.array(obj.data).reshape((3, 256, 256))
+        arr = np.transpose(arr, [1,2,0])
+        if size is not None:
+            arr = cv2.resize(arr, size[::-1])
         return arr
 class ILSVRC12(DataFlow):
@@ -106,9 +110,10 @@ class ILSVRC12(DataFlow):
             self.rng.shuffle(idxs)
         for k in idxs:
             tp = self.imglist[k]
-            fname = os.path.join(self.dir, self.name, tp[0])
-            im = scimg.imread(fname)
-            if len(im.shape) == 2:
+            fname = os.path.join(self.dir, self.name, tp[0]).strip()
+            im = cv2.imread(fname)
+            assert im is not None, fname
+            if im.ndim == 2:
                 im = np.expand_dims(im, 2).repeat(3,2)
             yield [im, tp[1]]
...
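Aside (not part of the diff): with this change the per-pixel mean comes back in (h, w, 3) layout and can be resized, and images are read with cv2 (uint8 BGR) instead of scipy.ndimage. A rough usage sketch under those assumptions; the import path, constructor arguments, dataset location, and the 224x224 target size are illustrative guesses, not taken from the commit.

```python
import cv2
from tensorpack.dataflow.dataset import ILSVRCMeta, ILSVRC12   # assumed import path

meta = ILSVRCMeta()
# The mean is now (h, w, 3); passing size=[h, w] resizes it via cv2.
mean = meta.get_per_pixel_mean(size=(224, 224))

ds = ILSVRC12('/path/to/ILSVRC12', 'train')     # hypothetical dataset location
for im, label in ds.get_data():
    # Images come from cv2.imread, i.e. uint8 BGR with shape (h, w, 3).
    im = cv2.resize(im, (224, 224)).astype('float32') - mean
    break
```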
@@ -69,7 +69,7 @@ class AugmentorList(ImageAugmentor):
         self.augs = augmentors
     def _augment(self, img):
-        assert img.arr.ndim in [2, 3]
+        assert img.arr.ndim in [2, 3], img.arr.ndim
         img.arr = img.arr.astype('float32')
         for aug in self.augs:
             aug.augment(img)
...
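Aside (not part of the diff): appending the value to the assertion makes the failure message self-describing. A tiny standalone illustration with a made-up array:

```python
import numpy as np

arr = np.zeros((8, 8, 3, 1))        # malformed 4-D input, for illustration only
# With the message argument the traceback now reports the offending rank:
#   AssertionError: 4
assert arr.ndim in [2, 3], arr.ndim
```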
@@ -86,11 +86,12 @@ class QueueInputTrainer(Trainer):
     @staticmethod
     def _average_grads(tower_grads):
         ret = []
-        for grad_and_vars in zip(*tower_grads):
-            grad = tf.add_n([x[0] for x in grad_and_vars]) / float(len(tower_grads))
-            v = grad_and_vars[0][1]
-            ret.append((grad, v))
-        return ret
+        with tf.device('/gpu:0'):
+            for grad_and_vars in zip(*tower_grads):
+                grad = tf.add_n([x[0] for x in grad_and_vars]) / float(len(tower_grads))
+                v = grad_and_vars[0][1]
+                ret.append((grad, v))
+        return ret
     def train(self):
         model = self.model
@@ -121,7 +122,8 @@ class QueueInputTrainer(Trainer):
                 if i == 0:
                     cost_var_t0 = cost_var
                 grad_list.append(
-                    self.config.optimizer.compute_gradients(cost_var))
+                    self.config.optimizer.compute_gradients(cost_var,
+                        gate_gradients=0))
                 if i == 0:
                     tf.get_variable_scope().reuse_variables()
...
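Aside (not part of the diff): the gradient averaging is now pinned to a single device so the tf.add_n ops are not scattered across towers, and compute_gradients is called with gate_gradients=0 (tf.train.Optimizer.GATE_NONE), which lets gradients of independent ops be computed as soon as they are ready. The same averaging pattern, restated as a standalone sketch against the graph-mode TF API this commit targets:

```python
import tensorflow as tf

def average_grads(tower_grads):
    """tower_grads: one [(grad, var), ...] list per GPU tower (towers share variables)."""
    ret = []
    # Pin the averaging ops to one device instead of wherever each tower lives.
    with tf.device('/gpu:0'):
        for grad_and_vars in zip(*tower_grads):
            # Every tower holds the same variable, so average only the gradients.
            grad = tf.add_n([g for g, _ in grad_and_vars]) / float(len(tower_grads))
            ret.append((grad, grad_and_vars[0][1]))
    return ret
```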