Commit f461ed2e authored by Yuxin Wu's avatar Yuxin Wu

use more memory for alexnet & vgg

parent aaf4cc78
......@@ -10,7 +10,7 @@ import os
import argparse
import cPickle as pkl
from tensorpack.train import TrainConfig, start_train
from tensorpack.train import TrainConfig
from tensorpack.predict import PredictConfig, get_predict_func
from tensorpack.models import *
from tensorpack.utils import *
......@@ -73,6 +73,7 @@ def run_test(path, input):
model=Model(),
input_data_mapping=[0],
session_init=ParamRestore(param_dict),
session_config=get_default_sess_config(0.9),
output_var_names=['output:0'] # output:0 is the probability distribution
)
predict_func = get_predict_func(pred_config)
......
......@@ -10,7 +10,7 @@ import os
import argparse
import cPickle as pkl
from tensorpack.train import TrainConfig, start_train
from tensorpack.train import TrainConfig
from tensorpack.predict import PredictConfig, get_predict_func
from tensorpack.models import *
from tensorpack.utils import *
......@@ -82,6 +82,7 @@ def run_test(path, input):
model=Model(),
input_data_mapping=[0],
session_init=ParamRestore(param_dict),
session_config=get_default_sess_config(0.9),
output_var_names=['output:0'] # output:0 is the probability distribution
)
predict_func = get_predict_func(pred_config)
......
......@@ -50,7 +50,8 @@ class PredictConfig(object):
"""
def assert_type(v, tp):
assert isinstance(v, tp), v.__class__
self.session_config = kwargs.pop('session_config', None)
self.session_config = kwargs.pop('session_config',
get_default_sess_config(0.3))
self.session_init = kwargs.pop('session_init')
self.model = kwargs.pop('model')
self.input_data_mapping = kwargs.pop('input_data_mapping', None)
......@@ -80,7 +81,7 @@ def get_predict_func(config):
for n in output_var_names]
# XXX does it work? start with minimal memory, but allow growth
sess = tf.Session(config=get_default_sess_config(0.3))
sess = tf.Session(config=config.session_config)
config.session_init.init(sess)
def run_input(dp):
......
......@@ -35,7 +35,7 @@ def getlogger():
logger.propagate = False
logger.setLevel(logging.INFO)
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(MyFormatter(datefmt='%d %H:%M:%S'))
handler.setFormatter(MyFormatter(datefmt='%m%d %H:%M:%S'))
logger.addHandler(handler)
return logger
logger = getlogger()
......@@ -52,7 +52,7 @@ def _set_file(path):
info("Log file '{}' backuped to '{}'".format(path, backup_name))
hdl = logging.FileHandler(
filename=path, encoding='utf-8', mode='w')
hdl.setFormatter(MyFormatter(datefmt='%d %H:%M:%S'))
hdl.setFormatter(MyFormatter(datefmt='%m%d %H:%M:%S'))
logger.addHandler(hdl)
def set_logger_dir(dirname, action=None):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment