Commit f461ed2e authored by Yuxin Wu's avatar Yuxin Wu

use more memory for alexnet & vgg

parent aaf4cc78
...@@ -10,7 +10,7 @@ import os ...@@ -10,7 +10,7 @@ import os
import argparse import argparse
import cPickle as pkl import cPickle as pkl
from tensorpack.train import TrainConfig, start_train from tensorpack.train import TrainConfig
from tensorpack.predict import PredictConfig, get_predict_func from tensorpack.predict import PredictConfig, get_predict_func
from tensorpack.models import * from tensorpack.models import *
from tensorpack.utils import * from tensorpack.utils import *
...@@ -73,6 +73,7 @@ def run_test(path, input): ...@@ -73,6 +73,7 @@ def run_test(path, input):
model=Model(), model=Model(),
input_data_mapping=[0], input_data_mapping=[0],
session_init=ParamRestore(param_dict), session_init=ParamRestore(param_dict),
session_config=get_default_sess_config(0.9),
output_var_names=['output:0'] # output:0 is the probability distribution output_var_names=['output:0'] # output:0 is the probability distribution
) )
predict_func = get_predict_func(pred_config) predict_func = get_predict_func(pred_config)
......
...@@ -10,7 +10,7 @@ import os ...@@ -10,7 +10,7 @@ import os
import argparse import argparse
import cPickle as pkl import cPickle as pkl
from tensorpack.train import TrainConfig, start_train from tensorpack.train import TrainConfig
from tensorpack.predict import PredictConfig, get_predict_func from tensorpack.predict import PredictConfig, get_predict_func
from tensorpack.models import * from tensorpack.models import *
from tensorpack.utils import * from tensorpack.utils import *
...@@ -82,6 +82,7 @@ def run_test(path, input): ...@@ -82,6 +82,7 @@ def run_test(path, input):
model=Model(), model=Model(),
input_data_mapping=[0], input_data_mapping=[0],
session_init=ParamRestore(param_dict), session_init=ParamRestore(param_dict),
session_config=get_default_sess_config(0.9),
output_var_names=['output:0'] # output:0 is the probability distribution output_var_names=['output:0'] # output:0 is the probability distribution
) )
predict_func = get_predict_func(pred_config) predict_func = get_predict_func(pred_config)
......
...@@ -50,7 +50,8 @@ class PredictConfig(object): ...@@ -50,7 +50,8 @@ class PredictConfig(object):
""" """
def assert_type(v, tp): def assert_type(v, tp):
assert isinstance(v, tp), v.__class__ assert isinstance(v, tp), v.__class__
self.session_config = kwargs.pop('session_config', None) self.session_config = kwargs.pop('session_config',
get_default_sess_config(0.3))
self.session_init = kwargs.pop('session_init') self.session_init = kwargs.pop('session_init')
self.model = kwargs.pop('model') self.model = kwargs.pop('model')
self.input_data_mapping = kwargs.pop('input_data_mapping', None) self.input_data_mapping = kwargs.pop('input_data_mapping', None)
...@@ -80,7 +81,7 @@ def get_predict_func(config): ...@@ -80,7 +81,7 @@ def get_predict_func(config):
for n in output_var_names] for n in output_var_names]
# XXX does it work? start with minimal memory, but allow growth # XXX does it work? start with minimal memory, but allow growth
sess = tf.Session(config=get_default_sess_config(0.3)) sess = tf.Session(config=config.session_config)
config.session_init.init(sess) config.session_init.init(sess)
def run_input(dp): def run_input(dp):
......
...@@ -35,7 +35,7 @@ def getlogger(): ...@@ -35,7 +35,7 @@ def getlogger():
logger.propagate = False logger.propagate = False
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
handler = logging.StreamHandler(sys.stdout) handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(MyFormatter(datefmt='%d %H:%M:%S')) handler.setFormatter(MyFormatter(datefmt='%m%d %H:%M:%S'))
logger.addHandler(handler) logger.addHandler(handler)
return logger return logger
logger = getlogger() logger = getlogger()
...@@ -52,7 +52,7 @@ def _set_file(path): ...@@ -52,7 +52,7 @@ def _set_file(path):
info("Log file '{}' backuped to '{}'".format(path, backup_name)) info("Log file '{}' backuped to '{}'".format(path, backup_name))
hdl = logging.FileHandler( hdl = logging.FileHandler(
filename=path, encoding='utf-8', mode='w') filename=path, encoding='utf-8', mode='w')
hdl.setFormatter(MyFormatter(datefmt='%d %H:%M:%S')) hdl.setFormatter(MyFormatter(datefmt='%m%d %H:%M:%S'))
logger.addHandler(hdl) logger.addHandler(hdl)
def set_logger_dir(dirname, action=None): def set_logger_dir(dirname, action=None):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment