Commit aaf4cc78 authored by Yuxin Wu's avatar Yuxin Wu

update atari/common

parent 0485c1de
......@@ -75,14 +75,15 @@ def eval_model_multithread(cfg, nr_eval):
logger.info("Average Score: {}; Max Score: {}".format(mean, max))
class Evaluator(Callback):
def __init__(self, nr_eval, output_name):
def __init__(self, nr_eval, input_names, output_names):
self.eval_episode = nr_eval
self.output_name = output_name
self.input_names = input_names
self.output_names = output_names
def _before_train(self):
NR_PROC = min(multiprocessing.cpu_count() // 2, 8)
self.pred_funcs = [self.trainer.get_predict_func(
['state'], [self.output_name])] * NR_PROC
self.input_names, self.output_names)] * NR_PROC
def _trigger_epoch(self):
t = time.time()
......
......@@ -15,7 +15,7 @@ try:
from scipy.io import loadmat
__all__ = ['BSDS500']
except ImportError:
logger.error("Cannot import scipy. BSDS500 dataset won't be available!")
logger.warn("Cannot import scipy. BSDS500 dataset won't be available!")
__all__ = []
DATA_URL = "http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz"
......
......@@ -15,7 +15,7 @@ try:
import scipy.io
__all__ = ['SVHNDigit']
except ImportError:
logger.error("Cannot import scipy. SVHNDigit dataset won't be available!")
logger.warn("Cannot import scipy. SVHNDigit dataset won't be available!")
__all__ = []
SVHN_URL = "http://ufldl.stanford.edu/housenumbers/"
......
......@@ -122,7 +122,9 @@ class Trainer(object):
for step in tqdm.trange(
self.config.step_per_epoch,
leave=True, mininterval=0.5,
dynamic_ncols=True, ascii=True):
smoothing=0.5,
dynamic_ncols=True,
ascii=True):
#bar_format='{l_bar}{bar}|{n_fmt}/{total_fmt} [{elapsed}<{remaining},{rate_noinv_fmt}]'):
if self.coord.should_stop():
return
......
......@@ -16,7 +16,7 @@ __all__ = []
class MyFormatter(logging.Formatter):
def format(self, record):
date = colored('[%(asctime)s %(lineno)d@%(filename)s:%(name)s]', 'green')
date = colored('[%(asctime)s @%(filename)s:%(lineno)d]', 'green')
msg = '%(message)s'
if record.levelno == logging.WARNING:
fmt = date + ' ' + colored('WRN', 'red', attrs=['blink']) + ' ' + msg
......@@ -27,25 +27,22 @@ class MyFormatter(logging.Formatter):
if hasattr(self, '_style'):
# Python3 compatibilty
self._style._fmt = fmt
self._fmt = fmt
else:
self._fmt = fmt
self._fmt = fmt
return super(MyFormatter, self).format(record)
def getlogger():
logger = logging.getLogger('tensorpack')
logger.propagate = False
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(MyFormatter(datefmt='%d %H:%M:%S'))
logger.addHandler(handler)
return logger
logger = getlogger()
def get_time_str():
return datetime.now().strftime('%m%d-%H%M%S')
logger = getlogger()
# logger file and directory:
global LOG_FILE, LOG_DIR
def _set_file(path):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment