Commit 09e99778 authored by ppwwyyxx

sessinit parameter

parent 8c674cc4
......@@ -13,9 +13,7 @@ from tensorpack.utils import *
from tensorpack.utils.symbolic_functions import *
from tensorpack.utils.summary import *
from tensorpack.utils.callback import *
from tensorpack.utils.sessinit import *
from tensorpack.utils.validation_callback import *
from tensorpack.dataflow.dataset import Cifar10
from tensorpack.dataflow import *
BATCH_SIZE = 128
......@@ -81,10 +79,10 @@ def get_config():
logger.set_logger_dir(log_dir)
import cv2
dataset_train = Cifar10('train')
dataset_train = dataset.Cifar10('train')
dataset_train = MapData(dataset_train, lambda img: cv2.resize(img, (24, 24)))
dataset_train = BatchData(dataset_train, 128)
dataset_test = Cifar10('test')
dataset_test = dataset.Cifar10('test')
dataset_test = MapData(dataset_test, lambda img: cv2.resize(img, (24, 24)))
dataset_test = BatchData(dataset_test, 128)
step_per_epoch = dataset_train.size()
......@@ -133,7 +131,6 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') # nargs='*' in multi mode
parser.add_argument('--load', help='load model')
global args
args = parser.parse_args()
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
......@@ -141,8 +138,7 @@ if __name__ == '__main__':
with tf.Graph().as_default():
train.prepare()
config = get_config()
if args.load:
config['session_init'] = SaverRestore(args.load)
sess_init = NewSession()
train.start_train(config)
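For reference, a minimal standalone sketch of how the new session_init parameter is meant to be chosen, assuming config is the plain dict returned by get_config() above and that SaverRestore and NewSession are the two SessionInit implementations exported by tensorpack.utils.sessinit:

import argparse
from tensorpack.utils.sessinit import SaverRestore, NewSession

parser = argparse.ArgumentParser()
parser.add_argument('--load', help='path of a checkpoint to restore')
args = parser.parse_args()

config = get_config()   # dict-style config, as in the script above
if args.load:
    # resume or finetune: restore variables from the given checkpoint
    config['session_init'] = SaverRestore(args.load)
else:
    # train from scratch: freshly initialized variables
    config['session_init'] = NewSession()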
......@@ -16,7 +16,6 @@ from tensorpack.utils.symbolic_functions import *
from tensorpack.utils.summary import *
from tensorpack.utils.callback import *
from tensorpack.utils.validation_callback import *
from tensorpack.dataflow.dataset import Mnist
from tensorpack.dataflow import *
BATCH_SIZE = 128
......@@ -95,8 +94,8 @@ def get_config():
IMAGE_SIZE = 28
dataset_train = BatchData(Mnist('train'), 128)
dataset_test = BatchData(Mnist('test'), 256, remainder=True)
dataset_train = BatchData(dataset.Mnist('train'), 128)
dataset_test = BatchData(dataset.Mnist('test'), 256, remainder=True)
step_per_epoch = dataset_train.size()
#step_per_epoch = 20
#dataset_test = FixedSizeData(dataset_test, 20)
......@@ -144,6 +143,7 @@ if __name__ == '__main__':
from tensorpack import train
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') # nargs='*' in multi mode
parser.add_argument('--load', help='load model')
args = parser.parse_args()
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
......@@ -151,4 +151,6 @@ if __name__ == '__main__':
with tf.Graph().as_default():
train.prepare()
config = get_config()
if args.load:
config['session_init'] = SaverRestore(args.load)
train.start_train(config)
......@@ -6,6 +6,7 @@
from pkgutil import walk_packages
import os
import os.path
import dataset
__SKIP = ['dftools', 'dataset']
def global_import(name):
......
......@@ -4,6 +4,8 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
from functools import wraps
from ..utils.modelutils import *
from ..utils.summary import *
from ..utils import logger
......@@ -21,6 +23,7 @@ def layer_register(summary_activation=False):
Can be overridden when creating the layer.
"""
def wrapper(func):
@wraps(func)
def inner(*args, **kwargs):
name = args[0]
assert isinstance(name, basestring)
......
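The only functional change in this hunk is the added @wraps(func). A standalone illustration (a toy register decorator, not tensorpack's layer_register) of what it buys:

from functools import wraps

def register(func):
    @wraps(func)                 # copy func.__name__ and __doc__ onto the wrapper
    def inner(*args, **kwargs):
        return func(*args, **kwargs)
    return inner

@register
def Conv2D(name, x):
    """Dummy layer, used only for this illustration."""
    return x

print(Conv2D.__name__)           # 'Conv2D'; without @wraps it would be 'inner'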
......@@ -7,9 +7,14 @@ import tensorflow as tf
import re
from ..utils import logger
from ..utils import *
__all__ = ['regularize_cost']
@memoized
def _log_regularizer(name):
logger.info("Apply regularizer for {}".format(name))
def regularize_cost(regex, func):
"""
Apply a regularizer on every trainable variable matching the regex
......@@ -20,8 +25,7 @@ def regularize_cost(regex, func):
costs = []
for p in params:
name = p.name
if re.search(regex, name):
logger.info("Apply regularizer for {}".format(name))
costs.append(func(p))
costs.append(func(p))
_log_regularizer(name)
return tf.add_n(costs)
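A hedged usage sketch of regularize_cost as shown above (the regex and decay factor are made up for illustration): an L2 penalty on every trainable variable whose name contains '/W', with the memoized _log_regularizer ensuring each variable is logged only once even when the function runs several times:

import tensorflow as tf

# hypothetical weight decay over conv/fc kernels named '.../W'
wd_cost = regularize_cost('.*/W', lambda p: 1e-4 * tf.nn.l2_loss(p))
# wd_cost can then be added to the model's training cost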
......@@ -11,7 +11,6 @@ from utils import *
from utils.concurrency import EnqueueThread,coordinator_guard
from utils.summary import summary_moving_average
from utils.modelutils import describe_model
from utils.sessinit import NewSession
from utils import logger
from dataflow import DataFlow
......
......@@ -8,8 +8,11 @@ import os
import time
import sys
from contextlib import contextmanager
import logger
import tensorflow as tf
import collections
import functools   # memoized.__get__ below uses functools.partial
import logger
def global_import(name):
p = __import__(name, globals(), locals())
......@@ -17,6 +20,7 @@ def global_import(name):
for k in lst:
globals()[k] = p.__dict__[k]
global_import('naming')
global_import('sessinit')
@contextmanager
def timed_operation(msg, log_start=False):
......@@ -66,3 +70,31 @@ def get_default_sess_config():
conf.allow_soft_placement = True
return conf
class memoized(object):
'''Decorator. Caches a function's return value each time it is called.
If called later with the same arguments, the cached value is returned
(not reevaluated).
'''
def __init__(self, func):
self.func = func
self.cache = {}
def __call__(self, *args):
if not isinstance(args, collections.Hashable):
# uncacheable. a list, for instance.
# better to not cache than blow up.
return self.func(*args)
if args in self.cache:
return self.cache[args]
else:
value = self.func(*args)
self.cache[args] = value
return value
def __repr__(self):
'''Return the function's docstring.'''
return self.func.__doc__
def __get__(self, obj, objtype):
'''Support instance methods.'''
return functools.partial(self.__call__, obj)
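A small standalone usage sketch for the memoized decorator added above:

@memoized
def fib(n):
    '''Naive Fibonacci; the cache makes the recursion linear instead of exponential.'''
    return n if n < 2 else fib(n - 1) + fib(n - 2)

fib(30)   # computed once per distinct argument
fib(30)   # same argument again: served from self.cache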
......@@ -7,7 +7,7 @@ from abc import abstractmethod
import numpy as np
import tensorflow as tf
from . import logger
import logger
class SessionInit(object):
@abstractmethod
def init(self, sess):
......
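For context, SessionInit is the interface behind the new session_init parameter. A minimal sketch of a concrete subclass, assuming init(sess) is the only required hook (the shipped NewSession and SaverRestore may differ in detail):

class FreshInit(SessionInit):
    """Sketch only: initialize all variables in a brand-new session."""
    def init(self, sess):
        sess.run(tf.initialize_all_variables())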
......@@ -5,7 +5,7 @@
import tensorflow as tf
import logger
from .naming import *
from naming import *
def create_summary(name, v):
"""
......