Commit 09e99778 authored by ppwwyyxx's avatar ppwwyyxx

sessinit paraemter

parent 8c674cc4
...@@ -13,9 +13,7 @@ from tensorpack.utils import * ...@@ -13,9 +13,7 @@ from tensorpack.utils import *
from tensorpack.utils.symbolic_functions import * from tensorpack.utils.symbolic_functions import *
from tensorpack.utils.summary import * from tensorpack.utils.summary import *
from tensorpack.utils.callback import * from tensorpack.utils.callback import *
from tensorpack.utils.sessinit import *
from tensorpack.utils.validation_callback import * from tensorpack.utils.validation_callback import *
from tensorpack.dataflow.dataset import Cifar10
from tensorpack.dataflow import * from tensorpack.dataflow import *
BATCH_SIZE = 128 BATCH_SIZE = 128
...@@ -81,10 +79,10 @@ def get_config(): ...@@ -81,10 +79,10 @@ def get_config():
logger.set_logger_dir(log_dir) logger.set_logger_dir(log_dir)
import cv2 import cv2
dataset_train = Cifar10('train') dataset_train = dataset.Cifar10('train')
dataset_train = MapData(dataset_train, lambda img: cv2.resize(img, (24, 24))) dataset_train = MapData(dataset_train, lambda img: cv2.resize(img, (24, 24)))
dataset_train = BatchData(dataset_train, 128) dataset_train = BatchData(dataset_train, 128)
dataset_test = Cifar10('test') dataset_test = dataset.Cifar10('test')
dataset_test = MapData(dataset_test, lambda img: cv2.resize(img, (24, 24))) dataset_test = MapData(dataset_test, lambda img: cv2.resize(img, (24, 24)))
dataset_test = BatchData(dataset_test, 128) dataset_test = BatchData(dataset_test, 128)
step_per_epoch = dataset_train.size() step_per_epoch = dataset_train.size()
...@@ -133,7 +131,6 @@ if __name__ == '__main__': ...@@ -133,7 +131,6 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') # nargs='*' in multi mode parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') # nargs='*' in multi mode
parser.add_argument('--load', help='load model') parser.add_argument('--load', help='load model')
global args
args = parser.parse_args() args = parser.parse_args()
if args.gpu: if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
...@@ -141,8 +138,7 @@ if __name__ == '__main__': ...@@ -141,8 +138,7 @@ if __name__ == '__main__':
with tf.Graph().as_default(): with tf.Graph().as_default():
train.prepare() train.prepare()
config = get_config() config = get_config()
if args.load: if args.load:
config['session_init'] = SaverRestore(args.load) config['session_init'] = SaverRestore(args.load)
sess_init = NewSession()
train.start_train(config) train.start_train(config)
...@@ -16,7 +16,6 @@ from tensorpack.utils.symbolic_functions import * ...@@ -16,7 +16,6 @@ from tensorpack.utils.symbolic_functions import *
from tensorpack.utils.summary import * from tensorpack.utils.summary import *
from tensorpack.utils.callback import * from tensorpack.utils.callback import *
from tensorpack.utils.validation_callback import * from tensorpack.utils.validation_callback import *
from tensorpack.dataflow.dataset import Mnist
from tensorpack.dataflow import * from tensorpack.dataflow import *
BATCH_SIZE = 128 BATCH_SIZE = 128
...@@ -95,8 +94,8 @@ def get_config(): ...@@ -95,8 +94,8 @@ def get_config():
IMAGE_SIZE = 28 IMAGE_SIZE = 28
dataset_train = BatchData(Mnist('train'), 128) dataset_train = BatchData(dataset.Mnist('train'), 128)
dataset_test = BatchData(Mnist('test'), 256, remainder=True) dataset_test = BatchData(dataset.Mnist('test'), 256, remainder=True)
step_per_epoch = dataset_train.size() step_per_epoch = dataset_train.size()
#step_per_epoch = 20 #step_per_epoch = 20
#dataset_test = FixedSizeData(dataset_test, 20) #dataset_test = FixedSizeData(dataset_test, 20)
...@@ -144,6 +143,7 @@ if __name__ == '__main__': ...@@ -144,6 +143,7 @@ if __name__ == '__main__':
from tensorpack import train from tensorpack import train
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') # nargs='*' in multi mode parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') # nargs='*' in multi mode
parser.add_argument('--load', help='load model')
args = parser.parse_args() args = parser.parse_args()
if args.gpu: if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
...@@ -151,4 +151,6 @@ if __name__ == '__main__': ...@@ -151,4 +151,6 @@ if __name__ == '__main__':
with tf.Graph().as_default(): with tf.Graph().as_default():
train.prepare() train.prepare()
config = get_config() config = get_config()
if args.load:
config['session_init'] = SaverRestore(args.load)
train.start_train(config) train.start_train(config)
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
from pkgutil import walk_packages from pkgutil import walk_packages
import os import os
import os.path import os.path
import dataset
__SKIP = ['dftools', 'dataset'] __SKIP = ['dftools', 'dataset']
def global_import(name): def global_import(name):
......
...@@ -4,6 +4,8 @@ ...@@ -4,6 +4,8 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf import tensorflow as tf
from functools import wraps
from ..utils.modelutils import * from ..utils.modelutils import *
from ..utils.summary import * from ..utils.summary import *
from ..utils import logger from ..utils import logger
...@@ -21,6 +23,7 @@ def layer_register(summary_activation=False): ...@@ -21,6 +23,7 @@ def layer_register(summary_activation=False):
Can be overriden when creating the layer. Can be overriden when creating the layer.
""" """
def wrapper(func): def wrapper(func):
@wraps(func)
def inner(*args, **kwargs): def inner(*args, **kwargs):
name = args[0] name = args[0]
assert isinstance(name, basestring) assert isinstance(name, basestring)
......
...@@ -7,9 +7,14 @@ import tensorflow as tf ...@@ -7,9 +7,14 @@ import tensorflow as tf
import re import re
from ..utils import logger from ..utils import logger
from ..utils import *
__all__ = ['regularize_cost'] __all__ = ['regularize_cost']
@memoized
def _log_regularizer(name):
logger.info("Apply regularizer for {}".format(name))
def regularize_cost(regex, func): def regularize_cost(regex, func):
""" """
Apply a regularizer on every trainable variable matching the regex Apply a regularizer on every trainable variable matching the regex
...@@ -20,8 +25,7 @@ def regularize_cost(regex, func): ...@@ -20,8 +25,7 @@ def regularize_cost(regex, func):
costs = [] costs = []
for p in params: for p in params:
name = p.name name = p.name
if re.search(regex, name): costs.append(func(p))
logger.info("Apply regularizer for {}".format(name)) _log_regularizer(name)
costs.append(func(p))
return tf.add_n(costs) return tf.add_n(costs)
...@@ -11,7 +11,6 @@ from utils import * ...@@ -11,7 +11,6 @@ from utils import *
from utils.concurrency import EnqueueThread,coordinator_guard from utils.concurrency import EnqueueThread,coordinator_guard
from utils.summary import summary_moving_average from utils.summary import summary_moving_average
from utils.modelutils import describe_model from utils.modelutils import describe_model
from utils.sessinit import NewSession
from utils import logger from utils import logger
from dataflow import DataFlow from dataflow import DataFlow
......
...@@ -8,8 +8,11 @@ import os ...@@ -8,8 +8,11 @@ import os
import time import time
import sys import sys
from contextlib import contextmanager from contextlib import contextmanager
import logger
import tensorflow as tf import tensorflow as tf
import collections
import logger
def global_import(name): def global_import(name):
p = __import__(name, globals(), locals()) p = __import__(name, globals(), locals())
...@@ -17,6 +20,7 @@ def global_import(name): ...@@ -17,6 +20,7 @@ def global_import(name):
for k in lst: for k in lst:
globals()[k] = p.__dict__[k] globals()[k] = p.__dict__[k]
global_import('naming') global_import('naming')
global_import('sessinit')
@contextmanager @contextmanager
def timed_operation(msg, log_start=False): def timed_operation(msg, log_start=False):
...@@ -66,3 +70,31 @@ def get_default_sess_config(): ...@@ -66,3 +70,31 @@ def get_default_sess_config():
conf.allow_soft_placement = True conf.allow_soft_placement = True
return conf return conf
class memoized(object):
'''Decorator. Caches a function's return value each time it is called.
If called later with the same arguments, the cached value is returned
(not reevaluated).
'''
def __init__(self, func):
self.func = func
self.cache = {}
def __call__(self, *args):
if not isinstance(args, collections.Hashable):
# uncacheable. a list, for instance.
# better to not cache than blow up.
return self.func(*args)
if args in self.cache:
return self.cache[args]
else:
value = self.func(*args)
self.cache[args] = value
return value
def __repr__(self):
'''Return the function's docstring.'''
return self.func.__doc__
def __get__(self, obj, objtype):
'''Support instance methods.'''
return functools.partial(self.__call__, obj)
...@@ -7,7 +7,7 @@ from abc import abstractmethod ...@@ -7,7 +7,7 @@ from abc import abstractmethod
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from . import logger import logger
class SessionInit(object): class SessionInit(object):
@abstractmethod @abstractmethod
def init(self, sess): def init(self, sess):
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
import tensorflow as tf import tensorflow as tf
import logger import logger
from .naming import * from naming import *
def create_summary(name, v): def create_summary(name, v):
""" """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment