Commit 745ad4f0 authored by ppwwyyxx

clean-ups

parent 585f0837
@@ -186,7 +186,7 @@
       same "printed page" as the copyright notice for easier
       identification within third-party archives.
 
-   Copyright {yyyy} {name of copyright owner}
+   Copyright Yuxin Wu
 
    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.
...
@@ -3,7 +3,8 @@ Neural Network Toolbox based on TensorFlow
 ## Features:
-+ Scoped Abstraction of common models.
++ Scoped abstraction of common models.
 + Provide callbacks to control training behavior (as in [Keras](http://keras.io)).
-+ Use `Dataflow` to fine-grained control data preprocessing.
-+ Write a config file, tensorpack will do the rest.
++ Use `Dataflow` to gain fine-grained control over data preprocessing.
++ Automatically use the Queue operator in tensorflow to speed up input.
++ Training and testing graphs are modeled together, automatically.
@@ -8,7 +8,6 @@ from abc import abstractmethod
 __all__ = ['DataFlow']
 
 class DataFlow(object):
-    # TODO private impl
     @abstractmethod
     def get_data(self):
         """
...
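For orientation, a minimal sketch of the DataFlow contract (the subclass name and the yielded datapoint format are assumptions, inferred from how BatchData consumes datapoints below):

# Hypothetical DataFlow that produces datapoints one at a time;
# get_data() is the only method the abstract base requires.
class FakeData(DataFlow):
    def get_data(self):
        for i in range(100):
            yield [i]   # one datapoint: a list of components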
@@ -11,9 +11,10 @@ __all__ = ['BatchData', 'FixedSizeData']
 class BatchData(DataFlow):
     def __init__(self, ds, batch_size, remainder=False):
         """
-        Args:
-            ds: a dataflow
-            remainder: whether to return the remaining data smaller than a batch_size
+        Group data in ds into batches
+        ds: a DataFlow instance
+        remainder: whether to return the remaining data smaller than a batch_size.
+            if set, might return a data point of a different shape
         """
         self.ds = ds
         self.batch_size = batch_size
...
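A minimal usage sketch of BatchData, mirroring the example config in the next hunk (Mnist comes from that example; the loop body is hypothetical):

ds = Mnist('test')                       # yields one datapoint at a time
ds = BatchData(ds, 256, remainder=True)  # stacks 256 datapoints per batch
for batch in ds.get_data():
    pass  # the last batch may be smaller than 256 because remainder=True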
@@ -90,8 +90,8 @@ def get_config():
     dataset_train = BatchData(Mnist('train'), BATCH_SIZE)
     dataset_test = BatchData(Mnist('test'), 256, remainder=True)
-    dataset_train = FixedSizeData(dataset_train, 20)
-    dataset_test = FixedSizeData(dataset_test, 20)
+    #dataset_train = FixedSizeData(dataset_train, 20)
+    #dataset_test = FixedSizeData(dataset_test, 20)
     sess_config = tf.ConfigProto()
     sess_config.device_count['GPU'] = 1
...
@@ -12,7 +12,12 @@ _layer_logged = set()
 def layer_register(summary_activation=False):
     """
-    summary_activation: default behavior of whether to summarize the output of this layer
+    Register a layer.
+    Args:
+        summary_activation:
+            Define the default behavior of whether to
+            summarize the output (activation) of this layer.
+            Can be overridden when creating the layer.
     """
     def wrapper(func):
         def inner(*args, **kwargs):
@@ -26,24 +31,17 @@ def layer_register(summary_activation=False):
             outputs = func(*args, **kwargs)
             if name not in _layer_logged:
                 # log shape info and add activation
-                if isinstance(inputs, list):
-                    shape_str = ",".join(
-                        map(str(x.get_shape().as_list()), inputs))
-                else:
-                    shape_str = str(inputs.get_shape().as_list())
-                logger.info("{} input: {}".format(name, shape_str))
-                if isinstance(outputs, list):
-                    shape_str = ",".join(
-                        map(str(x.get_shape().as_list()), outputs))
-                    if do_summary:
-                        for x in outputs:
-                            add_activation_summary(x, scope.name)
-                else:
-                    shape_str = str(outputs.get_shape().as_list())
-                    if do_summary:
-                        add_activation_summary(outputs, scope.name)
-                logger.info("{} output: {}".format(name, shape_str))
+                logger.info("{} input: {}".format(
+                    name, get_shape_str(inputs)))
+                logger.info("{} output: {}".format(
+                    name, get_shape_str(outputs)))
+
+                if do_summary:
+                    if isinstance(outputs, list):
+                        for x in outputs:
+                            add_activation_summary(x, scope.name)
+                    else:
+                        add_activation_summary(outputs, scope.name)
                 _layer_logged.add(name)
             return outputs
         return inner
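A sketch of the decorator in use; the layer name and body here are hypothetical, only the registration pattern comes from the code above:

@layer_register(summary_activation=True)
def MyLayer(x):
    # hypothetical layer body: any op producing the layer's output
    return tf.identity(x)

On the first use of a given layer name, the wrapped call logs input/output shapes via get_shape_str and, presumably because summary_activation sets the default for the do_summary flag, adds activation summaries for the outputs.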
@@ -63,4 +61,3 @@ def shape2d(a):
 def shape4d(a):
     # for use with tensorflow
     return [1] + shape2d(a) + [1]
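For reference, the intended output, assuming shape2d expands a scalar a into [a, a] as its name suggests (shape2d's body is not shown in this commit, so this behavior is an assumption):

shape2d(3)   # -> [3, 3]            (assumed)
shape4d(3)   # -> [1, 3, 3, 1], an NHWC-style kernel/stride spec for tensorflow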
@@ -5,9 +5,8 @@
 import tensorflow as tf
 from utils import *
-from utils.concurrency import *
-from utils.callback import *
-from utils.summary import *
+from utils.concurrency import EnqueueThread, coordinator_guard
+from utils.summary import summary_moving_average, describe_model
 from dataflow import DataFlow
 from itertools import count
 import argparse
@@ -97,7 +96,6 @@ def start_train(config):
         # note that summary_op will take a data from the queue.
         callbacks.trigger_epoch()
-    sess.close()
 
 def main(get_config_func):
     parser = argparse.ArgumentParser()
...
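Example scripts presumably hand their config factory to this entry point; a hypothetical sketch (the __main__ guard is an assumption; only main and the mnist example's get_config appear in this commit):

if __name__ == '__main__':
    main(get_config)   # get_config as defined in the mnist example above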
@@ -27,20 +27,6 @@ def timed_operation(msg, log_start=False):
     logger.info('finished {}, time={:.2f}sec.'.format(
         msg, time.time() - start))
 
-def describe_model():
-    train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
-    msg = [""]
-    total = 0
-    for v in train_vars:
-        shape = v.get_shape()
-        ele = shape.num_elements()
-        total += ele
-        msg.append("{}: shape={}, dim={}".format(
-            v.name, shape.as_list(), ele))
-    msg.append("Total dim={}".format(total))
-    logger.info("Model Params: {}".format('\n'.join(msg)))
-
-# TODO disable shape output in get_model
 @contextmanager
 def create_test_graph():
     G = tf.get_default_graph()
...
@@ -3,9 +3,6 @@
 # File: naming.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
 
-IS_TRAINING_OP_NAME = 'is_training'
-IS_TRAINING_VAR_NAME = 'is_training:0'
-
 GLOBAL_STEP_OP_NAME = 'global_step'
 GLOBAL_STEP_VAR_NAME = 'global_step:0'
...
@@ -4,6 +4,7 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
 
 import tensorflow as tf
+import logger
 from .naming import *
 
 def create_summary(name, v):
@@ -44,6 +45,10 @@ def add_histogram_summary(regex):
         tf.histogram_summary(name, p)
 
 def summary_moving_average(cost_var):
+    """ Create a MovingAverage op and summary for all variables in
+        COST_VARS_KEY, SUMMARY_VARS_KEY, as well as the argument.
+        Return an op to maintain these averages.
+    """
     global_step_var = tf.get_default_graph().get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
     averager = tf.train.ExponentialMovingAverage(
         0.9, num_updates=global_step_var, name='avg')
@@ -54,3 +59,28 @@ def summary_moving_average(cost_var):
     for c in vars_to_summary:
         tf.scalar_summary(c.op.name, averager.average(c))
     return avg_maintain_op
+
+def describe_model():
+    """ describe the current model parameters"""
+    train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
+    msg = [""]
+    total = 0
+    for v in train_vars:
+        shape = v.get_shape()
+        ele = shape.num_elements()
+        total += ele
+        msg.append("{}: shape={}, dim={}".format(
+            v.name, shape.as_list(), ele))
+    msg.append("Total dim={}".format(total))
+    logger.info("Model Params: {}".format('\n'.join(msg)))
+
+def get_shape_str(tensors):
+    """ return the shape string for a tensor or a list of tensors"""
+    if isinstance(tensors, list):
+        # apply str(...) per tensor; a bare map(str(...), tensors) would
+        # evaluate str(x...) immediately and raise NameError on x
+        shape_str = ",".join(str(x.get_shape().as_list()) for x in tensors)
+    else:
+        shape_str = str(tensors.get_shape().as_list())
+    return shape_str
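A quick illustration of get_shape_str's formatting (the placeholder shape is made up for the example):

x = tf.placeholder(tf.float32, [None, 28, 28])
get_shape_str(x)        # -> "[None, 28, 28]"
get_shape_str([x, x])   # -> "[None, 28, 28],[None, 28, 28]"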