Commit b31dd61a authored by Yuxin Wu's avatar Yuxin Wu

fix BGR for CAM

parent 57a81ade
......@@ -9,6 +9,7 @@ import numpy as np
import argparse
from tensorpack import *
from tensorpack.utils import viz
from tensorpack.utils.argtools import memoized
"""
......
......@@ -12,9 +12,11 @@ import multiprocessing
import tensorflow as tf
from tensorflow.contrib.layers import variance_scaling_initializer
from tensorpack import *
from tensorpack.dataflow import dataset
from tensorpack.tfutils import optimizer
from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *
from tensorpack.utils import viz
from imagenet_resnet_utils import (
fbresnet_augmentor, resnet_basicblock, resnet_bottleneck, resnet_group,
......@@ -33,7 +35,7 @@ class Model(ModelDesc):
def _build_graph(self, inputs):
image, label = inputs
image = image_preprocess(image, bgr=False)
image = image_preprocess(image, bgr=True)
image = tf.transpose(image, [0, 3, 1, 2])
cfg = {
......@@ -85,13 +87,14 @@ def get_data(train_or_test):
ds = AugmentImageComponent(ds, augmentors, copy=False)
if isTrain:
ds = PrefetchDataZMQ(ds, min(20, multiprocessing.cpu_count()))
ds = PrefetchDataZMQ(ds, min(25, multiprocessing.cpu_count()))
ds = BatchData(ds, BATCH_SIZE, remainder=not isTrain)
return ds
def get_config():
nr_gpu = get_nr_gpu()
global BATCH_SIZE
BATCH_SIZE = TOTAL_BATCH_SIZE // nr_gpu
dataset_train = get_data('train')
......@@ -102,14 +105,15 @@ def get_config():
dataflow=dataset_train,
callbacks=[
ModelSaver(),
InferenceRunner(dataset_val, [
PeriodicTrigger(InferenceRunner(dataset_val, [
ClassificationError('wrong-top1', 'val-error-top1'),
ClassificationError('wrong-top5', 'val-error-top5')]),
every_k_epochs=2),
ScheduledHyperParamSetter('learning_rate',
[(30, 1e-2), (60, 1e-3), (85, 1e-4), (95, 1e-5)]),
[(30, 1e-2), (55, 1e-3), (75, 1e-4), (95, 1e-5)]),
],
steps_per_epoch=5000,
max_epoch=110,
max_epoch=105,
nr_tower=nr_gpu
)
......@@ -151,7 +155,7 @@ def viz_cam(model_file, data_dir):
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.', required=True)
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
parser.add_argument('--data', help='ILSVRC dataset dir')
parser.add_argument('--depth', type=int, default=18)
parser.add_argument('--load', help='load model')
......@@ -159,7 +163,8 @@ if __name__ == '__main__':
args = parser.parse_args()
DEPTH = args.depth
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
if args.cam:
BATCH_SIZE = 128 # something that can run on one gpu
......@@ -170,4 +175,4 @@ if __name__ == '__main__':
config = get_config()
if args.load:
config.session_init = get_model_loader(args.load)
SyncMultiGPUTrainer(config).train()
SyncMultiGPUTrainerParameterServer(config).train()
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: tf_func.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
from .base import ProxyDataFlow
""" This file was deprecated """
__all__ = []
class TFFuncMapper(ProxyDataFlow):
def __init__(self, ds,
get_placeholders, symbf, apply_symbf_on_dp, device='/cpu:0'):
"""
:param get_placeholders: a function returning the placeholders
:param symbf: a symbolic function taking the placeholders
:param apply_symbf_on_dp: apply the above function to datapoint
"""
super(TFFuncMapper, self).__init__(ds)
self.get_placeholders = get_placeholders
self.symbf = symbf
self.apply_symbf_on_dp = apply_symbf_on_dp
self.device = device
def reset_state(self):
super(TFFuncMapper, self).reset_state()
self.graph = tf.Graph()
with self.graph.as_default(), \
tf.device(self.device):
self.placeholders = self.get_placeholders()
self.output_vars = self.symbf(self.placeholders)
self.sess = tf.Session()
def run_func(vals):
return self.sess.run(self.output_vars,
feed_dict=dict(zip(self.placeholders, vals)))
self.run_func = run_func
def get_data(self):
for dp in self.ds.get_data():
dp = self.apply_symbf_on_dp(dp, self.run_func)
if dp:
yield dp
if __name__ == '__main__':
from .raw import FakeData
ds = FakeData([[224, 224, 3]], 100000, random=False)
def tf_aug(v):
v = v[0]
v = tf.image.random_brightness(v, 0.1)
v = tf.image.random_contrast(v, 0.8, 1.2)
v = tf.image.random_flip_left_right(v)
return v
ds = TFFuncMapper(ds,
lambda: [tf.placeholder(tf.float32, [224, 224, 3], name='img')],
tf_aug,
lambda dp, f: [f([dp[0]])[0]]
)
# from .prefetch import PrefetchDataZMQ
# from .image import AugmentImageComponent
# from . import imgaug
# ds = AugmentImageComponent(ds,
# [imgaug.Brightness(0.1, clip=False),
# imgaug.Contrast((0.8, 1.2), clip=False),
# imgaug.Flip(horiz=True)
# ])
# ds = PrefetchDataZMQ(ds, 4)
ds.reset_state()
import tqdm
itr = ds.get_data()
for k in tqdm.trange(100000):
next(itr)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment