Commit 65a3052f authored by ppwwyyxx's avatar ppwwyyxx

logger

parent 20887a79
......@@ -9,20 +9,23 @@ import gzip
import numpy
from six.moves import urllib
from utils import logger
__all__ = ['Mnist']
SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
def maybe_download(filename, work_directory):
"""Download the data from Yann's website, unless it's already here."""
if not os.path.exists(work_directory):
os.mkdir(work_directory)
filepath = os.path.join(work_directory, filename)
if not os.path.exists(filepath):
filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath)
statinfo = os.stat(filepath)
print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
return filepath
"""Download the data from Yann's website, unless it's already here."""
if not os.path.exists(work_directory):
os.mkdir(work_directory)
filepath = os.path.join(work_directory, filename)
if not os.path.exists(filepath):
logger.info("Downloading mnist data...")
filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath)
statinfo = os.stat(filepath)
logger.info('Successfully downloaded to ' + filename)
return filepath
def _read32(bytestream):
dt = numpy.dtype(numpy.uint32).newbyteorder('>')
......@@ -30,7 +33,6 @@ def _read32(bytestream):
def extract_images(filename):
"""Extract the images into a 4D uint8 numpy array [index, y, x, depth]."""
print('Extracting', filename)
with gzip.open(filename) as bytestream:
magic = _read32(bytestream)
if magic != 2051:
......@@ -47,7 +49,6 @@ def extract_images(filename):
def extract_labels(filename):
"""Extract the labels into a 1D uint8 numpy array [index]."""
print('Extracting', filename)
with gzip.open(filename) as bytestream:
magic = _read32(bytestream)
if magic != 2049:
......
......@@ -3,15 +3,16 @@
# File: example_mnist.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# prefer protobuf in user-namespace
# use user-space protobuf
import sys
import os
sys.path.insert(0, os.path.expanduser('~/.local/lib/python2.7/site-packages'))
import tensorflow as tf
import numpy as np
import os
from utils import logger
from layers import *
from utils import *
from dataflow.dataset import Mnist
......@@ -77,6 +78,7 @@ def get_config():
IMAGE_SIZE = 28
LOG_DIR = 'train_log'
BATCH_SIZE = 128
logger.set_file(os.path.join(LOG_DIR, 'training.log'))
dataset_train = BatchData(Mnist('train'), BATCH_SIZE)
dataset_test = BatchData(Mnist('test'), 256, remainder=True)
......
......@@ -5,6 +5,7 @@
import tensorflow as tf
import re
from utils import logger
__all__ = ['regularize_cost']
......@@ -16,7 +17,7 @@ def regularize_cost(regex, func):
for p in params:
name = p.name
if re.search(regex, name):
print("Weight decay for {}".format(name))
logger.info("Weight decay for {}".format(name))
costs.append(func(p))
return tf.add_n(costs)
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: logger.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import logging
import os
from termcolor import colored
__all__ = []
class MyFormatter(logging.Formatter):
def format(self, record):
date = colored('[%(asctime)s %(lineno)d@%(filename)s:%(name)s]', 'green')
msg = '%(message)s'
if record.levelno == logging.WARNING:
fmt = date + ' ' + colored('WRN', 'red', attrs=['blink']) + ' ' + msg
elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
fmt = date + ' ' + colored('ERR', 'red', attrs=['blink', 'underline']) + ' ' + msg
else:
fmt = date + ' ' + msg
self._fmt = fmt
return super(MyFormatter, self).format(record)
def getlogger():
logger = logging.getLogger('tensorpack')
logger.propagate = False
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler.setFormatter(MyFormatter(datefmt='%d %H:%M:%S'))
logger.addHandler(handler)
return logger
logger = getlogger()
for func in ['info', 'warning', 'error', 'critical', 'warn']:
locals()[func] = getattr(logger, func)
def set_file(path):
if os.path.isfile(path):
warn("File \"{}\" exists! backup? (y/n)".format(path))
resp = raw_input()
if resp in ['y', 'Y']:
from datetime import datetime
backup_name = path + datetime.now().strftime('.%d-%H%M%S')
import shutil
shutil.move(path, backup_name)
info("Log '{}' moved to '{}'".format(path, backup_name))
hdl = logging.FileHandler(
filename=path, encoding='utf-8', mode='w')
logger.addHandler(hdl)
......@@ -8,6 +8,7 @@ from .stat import *
from .callback import PeriodicCallback, Callback
from .naming import *
from .summary import *
import logger
class ValidationError(PeriodicCallback):
"""
......@@ -63,5 +64,6 @@ class ValidationError(PeriodicCallback):
create_summary('{}_cost'.format(self.prefix),
cost_avg),
self.epoch_num)
print "{} validation after epoch {}: err={}, cost={}".format(
self.prefix, self.epoch_num, err_stat.accuracy, cost_avg)
logger.info(
"{} validation after epoch {}: err={}, cost={}".format(
self.prefix, self.epoch_num, err_stat.accuracy, cost_avg))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment