Commit e02d310c authored by Yuxin Wu's avatar Yuxin Wu

ask about existing log_dir

parent a4d51a2d
...@@ -86,8 +86,8 @@ class Model(ModelDesc): ...@@ -86,8 +86,8 @@ class Model(ModelDesc):
def get_config(): def get_config():
basename = os.path.basename(__file__) basename = os.path.basename(__file__)
log_dir = os.path.join('train_log', basename[:basename.rfind('.')]) logger.set_logger_dir(
logger.set_logger_file(os.path.join(log_dir, 'training.log')) os.path.join('train_log', basename[:basename.rfind('.')]))
dataset_train = FakeData([(227,227,3), tuple()], 10) dataset_train = FakeData([(227,227,3), tuple()], 10)
dataset_train = BatchData(dataset_train, 10) dataset_train = BatchData(dataset_train, 10)
......
...@@ -96,8 +96,8 @@ class Model(ModelDesc): ...@@ -96,8 +96,8 @@ class Model(ModelDesc):
def get_config(): def get_config():
basename = os.path.basename(__file__) basename = os.path.basename(__file__)
log_dir = os.path.join('train_log', basename[:basename.rfind('.')]) logger.set_logger_dir(
logger.set_logger_file(os.path.join(log_dir, 'training.log')) os.path.join('train_log', basename[:basename.rfind('.')]))
# prepare dataset # prepare dataset
dataset_train = dataset.Cifar10('train') dataset_train = dataset.Cifar10('train')
......
...@@ -85,8 +85,8 @@ class Model(ModelDesc): ...@@ -85,8 +85,8 @@ class Model(ModelDesc):
def get_config(): def get_config():
basename = os.path.basename(__file__) basename = os.path.basename(__file__)
log_dir = os.path.join('train_log', basename[:basename.rfind('.')]) logger.set_logger_dir(
logger.set_logger_file(os.path.join(log_dir, 'training.log')) os.path.join('train_log', basename[:basename.rfind('.')]))
# prepare dataset # prepare dataset
dataset_train = BatchData(dataset.Mnist('train'), 128) dataset_train = BatchData(dataset.Mnist('train'), 128)
......
...@@ -53,6 +53,8 @@ class SummaryWriter(Callback): ...@@ -53,6 +53,8 @@ class SummaryWriter(Callback):
""" print_tag : a list of regex to match scalar summary to print """ print_tag : a list of regex to match scalar summary to print
if None, will print all scalar tags if None, will print all scalar tags
""" """
if not hasattr(logger, 'LOG_DIR'):
raise RuntimeError("Please use logger.set_logger_dir at the beginning of your script.")
self.log_dir = logger.LOG_DIR self.log_dir = logger.LOG_DIR
logger.stat_holder = StatHolder(self.log_dir, print_tag) logger.stat_holder = StatHolder(self.log_dir, print_tag)
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import logging import logging
import os import os, shutil
import os.path import os.path
from termcolor import colored from termcolor import colored
from .utils import mkdir_p from .utils import mkdir_p
...@@ -35,33 +35,44 @@ def getlogger(): ...@@ -35,33 +35,44 @@ def getlogger():
logger = getlogger() logger = getlogger()
for func in ['info', 'warning', 'error', 'critical', 'warn', 'exception', 'debug']: # logger file and directory:
locals()[func] = getattr(logger, func) global LOG_FILE, LOG_DIR
def _set_file(path):
def set_file(path):
if os.path.isfile(path): if os.path.isfile(path):
from datetime import datetime from datetime import datetime
backup_name = path + datetime.now().strftime('.%d-%H%M%S') backup_name = path + datetime.now().strftime('.%d-%H%M%S')
import shutil
shutil.move(path, backup_name) shutil.move(path, backup_name)
info("Log file '{}' backuped to '{}'".format(path, backup_name)) info("Log file '{}' backuped to '{}'".format(path, backup_name))
dirname = os.path.dirname(path)
if not os.path.isdir(dirname):
os.makedirs(dirname)
hdl = logging.FileHandler( hdl = logging.FileHandler(
filename=path, encoding='utf-8', mode='w') filename=path, encoding='utf-8', mode='w')
logger.addHandler(hdl) logger.addHandler(hdl)
global LOG_FILE def set_logger_dir(dirname):
LOG_FILE = "train_log/log.log"
def set_logger_file(filename):
global LOG_FILE, LOG_DIR global LOG_FILE, LOG_DIR
LOG_FILE = filename LOG_DIR = dirname
LOG_DIR = os.path.dirname(LOG_FILE) if os.path.isdir(dirname):
mkdir_p(os.path.dirname(LOG_FILE)) logger.info("Directory {} exists. Please either backup or delete it unless you're continue from a paused task." )
set_file(LOG_FILE) logger.info("Select Action: k (keep) / b (backup) / d (delete):")
act = raw_input().lower()
if act == 'b':
from datetime import datetime
backup_name = dirname + datetime.now().strftime('.%d-%H%M%S')
shutil.move(dirname, backup_name)
info("Log directory'{}' backuped to '{}'".format(dirname, backup_name))
elif act == 'd':
shutil.rmtree(dirname)
elif act == 'k':
pass
else:
raise ValueError("Unknown action: {}".format(act))
mkdir_p(dirname)
LOG_FILE = os.path.join(dirname, 'log.log')
_set_file(LOG_FILE)
# global logger:
# export logger functions
for func in ['info', 'warning', 'error', 'critical', 'warn', 'exception', 'debug']:
locals()[func] = getattr(logger, func)
# a SummaryWriter # a SummaryWriter
writer = None writer = None
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment