Commit e02d310c authored by Yuxin Wu's avatar Yuxin Wu

ask about existing log_dir

parent a4d51a2d
......@@ -86,8 +86,8 @@ class Model(ModelDesc):
def get_config():
basename = os.path.basename(__file__)
log_dir = os.path.join('train_log', basename[:basename.rfind('.')])
logger.set_logger_file(os.path.join(log_dir, 'training.log'))
logger.set_logger_dir(
os.path.join('train_log', basename[:basename.rfind('.')]))
dataset_train = FakeData([(227,227,3), tuple()], 10)
dataset_train = BatchData(dataset_train, 10)
......
......@@ -96,8 +96,8 @@ class Model(ModelDesc):
def get_config():
basename = os.path.basename(__file__)
log_dir = os.path.join('train_log', basename[:basename.rfind('.')])
logger.set_logger_file(os.path.join(log_dir, 'training.log'))
logger.set_logger_dir(
os.path.join('train_log', basename[:basename.rfind('.')]))
# prepare dataset
dataset_train = dataset.Cifar10('train')
......
......@@ -85,8 +85,8 @@ class Model(ModelDesc):
def get_config():
basename = os.path.basename(__file__)
log_dir = os.path.join('train_log', basename[:basename.rfind('.')])
logger.set_logger_file(os.path.join(log_dir, 'training.log'))
logger.set_logger_dir(
os.path.join('train_log', basename[:basename.rfind('.')]))
# prepare dataset
dataset_train = BatchData(dataset.Mnist('train'), 128)
......
......@@ -53,6 +53,8 @@ class SummaryWriter(Callback):
""" print_tag : a list of regex to match scalar summary to print
if None, will print all scalar tags
"""
if not hasattr(logger, 'LOG_DIR'):
raise RuntimeError("Please use logger.set_logger_dir at the beginning of your script.")
self.log_dir = logger.LOG_DIR
logger.stat_holder = StatHolder(self.log_dir, print_tag)
......
......@@ -4,7 +4,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import logging
import os
import os, shutil
import os.path
from termcolor import colored
from .utils import mkdir_p
......@@ -35,33 +35,44 @@ def getlogger():
logger = getlogger()
for func in ['info', 'warning', 'error', 'critical', 'warn', 'exception', 'debug']:
locals()[func] = getattr(logger, func)
def set_file(path):
# logger file and directory:
global LOG_FILE, LOG_DIR
def _set_file(path):
if os.path.isfile(path):
from datetime import datetime
backup_name = path + datetime.now().strftime('.%d-%H%M%S')
import shutil
shutil.move(path, backup_name)
info("Log file '{}' backuped to '{}'".format(path, backup_name))
dirname = os.path.dirname(path)
if not os.path.isdir(dirname):
os.makedirs(dirname)
hdl = logging.FileHandler(
filename=path, encoding='utf-8', mode='w')
logger.addHandler(hdl)
global LOG_FILE
LOG_FILE = "train_log/log.log"
def set_logger_file(filename):
def set_logger_dir(dirname):
global LOG_FILE, LOG_DIR
LOG_FILE = filename
LOG_DIR = os.path.dirname(LOG_FILE)
mkdir_p(os.path.dirname(LOG_FILE))
set_file(LOG_FILE)
LOG_DIR = dirname
if os.path.isdir(dirname):
logger.info("Directory {} exists. Please either backup or delete it unless you're continue from a paused task." )
logger.info("Select Action: k (keep) / b (backup) / d (delete):")
act = raw_input().lower()
if act == 'b':
from datetime import datetime
backup_name = dirname + datetime.now().strftime('.%d-%H%M%S')
shutil.move(dirname, backup_name)
info("Log directory'{}' backuped to '{}'".format(dirname, backup_name))
elif act == 'd':
shutil.rmtree(dirname)
elif act == 'k':
pass
else:
raise ValueError("Unknown action: {}".format(act))
mkdir_p(dirname)
LOG_FILE = os.path.join(dirname, 'log.log')
_set_file(LOG_FILE)
# global logger:
# export logger functions
for func in ['info', 'warning', 'error', 'critical', 'warn', 'exception', 'debug']:
locals()[func] = getattr(logger, func)
# a SummaryWriter
writer = None
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment