Commit aaf4cc78 authored by Yuxin Wu's avatar Yuxin Wu

update atari/common

parent 0485c1de
...@@ -75,14 +75,15 @@ def eval_model_multithread(cfg, nr_eval): ...@@ -75,14 +75,15 @@ def eval_model_multithread(cfg, nr_eval):
logger.info("Average Score: {}; Max Score: {}".format(mean, max)) logger.info("Average Score: {}; Max Score: {}".format(mean, max))
class Evaluator(Callback): class Evaluator(Callback):
def __init__(self, nr_eval, output_name): def __init__(self, nr_eval, input_names, output_names):
self.eval_episode = nr_eval self.eval_episode = nr_eval
self.output_name = output_name self.input_names = input_names
self.output_names = output_names
def _before_train(self): def _before_train(self):
NR_PROC = min(multiprocessing.cpu_count() // 2, 8) NR_PROC = min(multiprocessing.cpu_count() // 2, 8)
self.pred_funcs = [self.trainer.get_predict_func( self.pred_funcs = [self.trainer.get_predict_func(
['state'], [self.output_name])] * NR_PROC self.input_names, self.output_names)] * NR_PROC
def _trigger_epoch(self): def _trigger_epoch(self):
t = time.time() t = time.time()
......
...@@ -15,7 +15,7 @@ try: ...@@ -15,7 +15,7 @@ try:
from scipy.io import loadmat from scipy.io import loadmat
__all__ = ['BSDS500'] __all__ = ['BSDS500']
except ImportError: except ImportError:
logger.error("Cannot import scipy. BSDS500 dataset won't be available!") logger.warn("Cannot import scipy. BSDS500 dataset won't be available!")
__all__ = [] __all__ = []
DATA_URL = "http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz" DATA_URL = "http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz"
......
...@@ -15,7 +15,7 @@ try: ...@@ -15,7 +15,7 @@ try:
import scipy.io import scipy.io
__all__ = ['SVHNDigit'] __all__ = ['SVHNDigit']
except ImportError: except ImportError:
logger.error("Cannot import scipy. SVHNDigit dataset won't be available!") logger.warn("Cannot import scipy. SVHNDigit dataset won't be available!")
__all__ = [] __all__ = []
SVHN_URL = "http://ufldl.stanford.edu/housenumbers/" SVHN_URL = "http://ufldl.stanford.edu/housenumbers/"
......
...@@ -122,7 +122,9 @@ class Trainer(object): ...@@ -122,7 +122,9 @@ class Trainer(object):
for step in tqdm.trange( for step in tqdm.trange(
self.config.step_per_epoch, self.config.step_per_epoch,
leave=True, mininterval=0.5, leave=True, mininterval=0.5,
dynamic_ncols=True, ascii=True): smoothing=0.5,
dynamic_ncols=True,
ascii=True):
#bar_format='{l_bar}{bar}|{n_fmt}/{total_fmt} [{elapsed}<{remaining},{rate_noinv_fmt}]'): #bar_format='{l_bar}{bar}|{n_fmt}/{total_fmt} [{elapsed}<{remaining},{rate_noinv_fmt}]'):
if self.coord.should_stop(): if self.coord.should_stop():
return return
......
...@@ -16,7 +16,7 @@ __all__ = [] ...@@ -16,7 +16,7 @@ __all__ = []
class MyFormatter(logging.Formatter): class MyFormatter(logging.Formatter):
def format(self, record): def format(self, record):
date = colored('[%(asctime)s %(lineno)d@%(filename)s:%(name)s]', 'green') date = colored('[%(asctime)s @%(filename)s:%(lineno)d]', 'green')
msg = '%(message)s' msg = '%(message)s'
if record.levelno == logging.WARNING: if record.levelno == logging.WARNING:
fmt = date + ' ' + colored('WRN', 'red', attrs=['blink']) + ' ' + msg fmt = date + ' ' + colored('WRN', 'red', attrs=['blink']) + ' ' + msg
...@@ -28,24 +28,21 @@ class MyFormatter(logging.Formatter): ...@@ -28,24 +28,21 @@ class MyFormatter(logging.Formatter):
# Python3 compatibilty # Python3 compatibilty
self._style._fmt = fmt self._style._fmt = fmt
self._fmt = fmt self._fmt = fmt
else:
self._fmt = fmt
return super(MyFormatter, self).format(record) return super(MyFormatter, self).format(record)
def getlogger(): def getlogger():
logger = logging.getLogger('tensorpack') logger = logging.getLogger('tensorpack')
logger.propagate = False logger.propagate = False
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
handler = logging.StreamHandler() handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(MyFormatter(datefmt='%d %H:%M:%S')) handler.setFormatter(MyFormatter(datefmt='%d %H:%M:%S'))
logger.addHandler(handler) logger.addHandler(handler)
return logger return logger
logger = getlogger()
def get_time_str(): def get_time_str():
return datetime.now().strftime('%m%d-%H%M%S') return datetime.now().strftime('%m%d-%H%M%S')
logger = getlogger()
# logger file and directory: # logger file and directory:
global LOG_FILE, LOG_DIR global LOG_FILE, LOG_DIR
def _set_file(path): def _set_file(path):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment