Commit bcbbc645 authored by ppwwyyxx

working on alexnet

parent 87f7e7cb
@@ -17,6 +17,7 @@ class BatchData(DataFlow):
             if set, might return a data point of a different shape
         """
         self.ds = ds
+        assert batch_size <= ds.size()
         self.batch_size = batch_size
         self.remainder = remainder
...
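Note: the new assertion exists because BatchData can only ever fill a batch when the wrapped DataFlow yields at least batch_size datapoints. A minimal sketch of the batching loop it protects (assumed behavior, not the exact tensorpack implementation):

# Sketch: group batch_size consecutive datapoints; the assert above
# guarantees at least one full batch can be formed.
def batched(ds, batch_size, remainder=False):
    holder = []
    for dp in ds.get_data():
        holder.append(dp)
        if len(holder) == batch_size:
            yield holder
            holder = []
    if remainder and len(holder) > 0:   # optionally yield the short last batch
        yield holder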
@@ -8,7 +8,6 @@ import os
 import os.path

 def global_import(name):
-    print name
     p = __import__(name, globals(), locals())
     lst = p.__all__ if '__all__' in dir(p) else dir(p)
     for k in lst:
...
@@ -63,6 +63,7 @@ class Cifar10(DataFlow):
         assert train_or_test in ['train', 'test']
         if dir is None:
             dir = os.path.join(os.path.dirname(__file__), 'cifar10_data')
+        maybe_download_and_extract(dir)
         if train_or_test == 'train':
             self.fs = [os.path.join(
...
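maybe_download_and_extract itself is not shown in this diff; a hedged sketch of what such a helper typically does (the CIFAR-10 URL is the standard one, but the body below is an assumption):

# Hypothetical sketch of maybe_download_and_extract (Python 2, matching the repo).
import os
import tarfile
import urllib

DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'

def maybe_download_and_extract(dest_dir):
    if not os.path.isdir(dest_dir):
        os.makedirs(dest_dir)
    fpath = os.path.join(dest_dir, DATA_URL.split('/')[-1])
    if not os.path.isfile(fpath):
        urllib.urlretrieve(DATA_URL, fpath)      # download only once
    tarfile.open(fpath, 'r:gz').extractall(dest_dir)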
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: infer.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

import os
import argparse
from itertools import count
import numpy as np
import tensorflow as tf

from utils import *
from utils.modelutils import describe_model, restore_params
from utils import logger
from dataflow import DataFlow

def start_infer(config):
    """
    Args:
        config: a tensorpack config dictionary
    """
    dataset_test = config['dataset_test']
    assert isinstance(dataset_test, DataFlow), dataset_test.__class__

    # a tf.ConfigProto instance
    sess_config = config.get('session_config', None)
    assert isinstance(sess_config, tf.ConfigProto), sess_config.__class__

    # TODO callback should have trigger_step and trigger_end?
    callback = config['callback']

    # restore saved params
    params = config.get('restore_params', {})

    # input/output variables
    input_vars = config['inputs']
    get_model_func = config['get_model_func']
    output_vars, cost_var = get_model_func(input_vars, is_training=False)

    # build graph
    G = tf.get_default_graph()
    G.add_to_collection(FORWARD_FUNC_KEY, get_model_func)
    for v in input_vars:
        G.add_to_collection(INPUT_VARS_KEY, v)
    for v in output_vars:
        G.add_to_collection(OUTPUT_VARS_KEY, v)

    describe_model()

    sess = tf.Session(config=sess_config)
    sess.run(tf.initialize_all_variables())
    restore_params(sess, params)

    with sess.as_default():
        with timed_operation('running one batch'):
            for dp in dataset_test.get_data():
                feed = dict(zip(input_vars, dp))
                fetches = [cost_var] + output_vars
                results = sess.run(fetches, feed_dict=feed)
                cost = results[0]
                outputs = results[1:]
                prob = outputs[0]   # first output, e.g. predicted probabilities (unused for now)
                callback(dp, outputs, cost)

def main(get_config_func):
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')  # nargs='*' in multi mode
    args = parser.parse_args()
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    with tf.Graph().as_default():
        config = get_config_func()
        start_infer(config)
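A hedged usage sketch for main/start_infer: a driver script supplies a function that builds the config dictionary with the keys read above (my_model, my_dataset and input_placeholders below are placeholders, not part of this commit):

# Hypothetical driver script for infer.py.
import tensorflow as tf
from infer import main

def get_config():
    return {
        'dataset_test': my_dataset,            # a DataFlow yielding test datapoints
        'session_config': tf.ConfigProto(),
        'callback': lambda dp, outputs, cost: None,   # invoked once per batch
        'restore_params': {},                  # {variable_name: value} to restore
        'inputs': input_placeholders,          # list of input tf.placeholder
        'get_model_func': my_model,            # returns (output_vars, cost_var)
    }

if __name__ == '__main__':
    main(get_config)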
@@ -4,6 +4,7 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
 import tensorflow as tf

+from utils.modelutils import *
 from utils.summary import *
 from utils import logger
...
@@ -12,18 +12,22 @@ __all__ = ['Conv2D']
 @layer_register(summary_activation=True)
 def Conv2D(x, out_channel, kernel_shape,
            padding='VALID', stride=1,
-           W_init=None, b_init=None, nl=tf.nn.relu):
+           W_init=None, b_init=None,
+           nl=tf.nn.relu, split=1):
     """
     kernel_shape: (h, w) or a int
     stride: (h, w) or a int
     padding: 'valid' or 'same'
+    split: split channels. used in alexnet
     """
     in_shape = x.get_shape().as_list()
     in_channel = in_shape[-1]
+    assert in_channel % split == 0
+    assert out_channel % split == 0
     kernel_shape = shape2d(kernel_shape)
     padding = padding.upper()
-    filter_shape = kernel_shape + [in_channel, out_channel]
+    filter_shape = kernel_shape + [in_channel / split, out_channel]
     stride = shape4d(stride)
     if W_init is None:
@@ -34,6 +38,14 @@ def Conv2D(x, out_channel, kernel_shape,
     W = tf.get_variable('W', filter_shape, initializer=W_init)  # TODO collections
     b = tf.get_variable('b', [out_channel], initializer=b_init)
-    conv = tf.nn.conv2d(x, W, stride, padding)
+    if split == 1:
+        conv = tf.nn.conv2d(x, W, stride, padding)
+    else:
+        inputs = tf.split(3, split, x)
+        kernels = tf.split(3, split, W)
+        outputs = [tf.nn.conv2d(i, k, stride, padding)
+                   for i, k in zip(inputs, kernels)]
+        conv = tf.concat(3, outputs)
     return nl(tf.nn.bias_add(conv, b))
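For reference, split implements AlexNet-style grouped convolution: input and output channels are partitioned into split groups, each group is convolved independently with its own slice of W, and the group outputs are concatenated along the channel axis. A hedged call sketch, assuming the decorated signature stays as written above:

# Hypothetical: AlexNet's conv2 with 2 groups. With 96 input channels, each
# group convolves 48 of them; concatenating the group outputs gives 256 channels.
out = Conv2D(x, out_channel=256, kernel_shape=5,
             padding='SAME', stride=1, split=2)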
@@ -4,18 +4,20 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
 import tensorflow as tf
+from itertools import count
+import argparse

 from utils import *
 from utils.concurrency import EnqueueThread,coordinator_guard
-from utils.summary import summary_moving_average, describe_model
+from utils.summary import summary_moving_average
+from utils.modelutils import restore_params, describe_model
+from utils import logger
 from dataflow import DataFlow
-from itertools import count
-import argparse

 def prepare():
     global_step_var = tf.Variable(
         0, trainable=False, name=GLOBAL_STEP_OP_NAME)

 def start_train(config):
     """
     Start training with the given config
@@ -36,6 +38,9 @@ def start_train(config):
     sess_config = config.get('session_config', None)
     assert isinstance(sess_config, tf.ConfigProto), sess_config.__class__

+    # restore saved params
+    params = config.get('restore_params', {})
+
     # input/output variables
     input_vars = config['inputs']
     input_queue = config['input_queue']
@@ -78,6 +83,8 @@ def start_train(config):
     sess = tf.Session(config=sess_config)
     sess.run(tf.initialize_all_variables())
+    restore_params(sess, params)
+
     # start training:
     coord = tf.train.Coordinator()
     # a thread that keeps filling the queue
...
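The new 'restore_params' entry takes a dict mapping TF variable names to values, which restore_params assigns after variable initialization. A hedged sketch of filling it from a saved checkpoint (the .npy dict format is an assumption; the key name comes from the code above):

# Hypothetical: feed pre-trained weights into the training config.
import numpy as np
params = np.load('alexnet_params.npy').item()   # assumed {var_name: np.ndarray}
config['restore_params'] = params               # consumed by restore_params(sess, params)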
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: modelutils.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

import tensorflow as tf
import logger

def restore_params(sess, params):
    """ assign values from `params` (a dict of variable name -> value) in the session """
    variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    var_dict = dict([(v.name, v) for v in variables])
    for name, value in params.iteritems():
        try:
            var = var_dict[name]
        except KeyError:
            logger.warn("Param {} not found in this graph".format(name))
            continue
        logger.info("Restoring param {}".format(name))
        sess.run(var.assign(value))

def describe_model():
    """ describe the current model parameters """
    train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    msg = [""]
    total = 0
    for v in train_vars:
        shape = v.get_shape()
        ele = shape.num_elements()
        total += ele
        msg.append("{}: shape={}, dim={}".format(
            v.name, shape.as_list(), ele))
    msg.append("Total dim={}".format(total))
    logger.info("Model Params: {}".format('\n'.join(msg)))

def get_shape_str(tensors):
    """ return the shape string for a tensor or a list of tensors """
    if isinstance(tensors, list):
        # format each tensor's shape, then join them
        shape_str = ",".join(
            str(x.get_shape().as_list()) for x in tensors)
    else:
        shape_str = str(tensors.get_shape().as_list())
    return shape_str
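A hedged usage sketch for the two helpers; note that tf.Variable.name carries a ':0' output suffix, so keys in the params dict must include it:

# Hypothetical usage of restore_params / describe_model.
import numpy as np
import tensorflow as tf

with tf.Graph().as_default():
    W = tf.get_variable('conv1/W', [5, 5, 3, 96])   # variable name is 'conv1/W:0'
    sess = tf.Session()
    sess.run(tf.initialize_all_variables())
    restore_params(sess, {'conv1/W:0': np.zeros((5, 5, 3, 96), dtype='float32')})
    describe_model()   # logs each trainable variable's shape and the total dim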
@@ -60,27 +60,3 @@ def summary_moving_average(cost_var):
         tf.scalar_summary(c.op.name, averager.average(c))
     return avg_maintain_op
-
-def describe_model():
-    """ describe the current model parameters"""
-    train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
-    msg = [""]
-    total = 0
-    for v in train_vars:
-        shape = v.get_shape()
-        ele = shape.num_elements()
-        total += ele
-        msg.append("{}: shape={}, dim={}".format(
-            v.name, shape.as_list(), ele))
-    msg.append("Total dim={}".format(total))
-    logger.info("Model Params: {}".format('\n'.join(msg)))
-
-def get_shape_str(tensors):
-    """ return the shape string for a tensor or a list of tensors"""
-    if isinstance(tensors, list):
-        shape_str = ",".join(
-            map(str(x.get_shape().as_list()), tensors))
-    else:
-        shape_str = str(tensors.get_shape().as_list())
-    return shape_str