Commit d731cf7b authored by Yuxin Wu

small fixes

parent 087af16e
......@@ -18,7 +18,7 @@ from tensorpack.dataflow import *
from tensorpack.dataflow import imgaug
"""
CIFAR10 89% test accuracy after 200 epochs.
CIFAR10 89% test accuracy after 60k step (about 150 epochs)
"""
BATCH_SIZE = 128
......
......@@ -17,7 +17,7 @@ class DumpParamAsImage(Callback):
"""
map_func: map the value of the variable to an image or list of images, default to identity
images should have shape [h, w] or [h, w, c].
scale: a scaling parameter on pixels
scale: a multiplier on pixel values, applied after map_func. default to 255
"""
self.var_name = var_name
self.func = map_func
......@@ -37,16 +37,15 @@ class DumpParamAsImage(Callback):
val = self.func(val)
if isinstance(val, list):
for idx, im in enumerate(val):
assert im.ndim in [2, 3], str(im.ndim)
fname = os.path.join(
self.log_dir,
self.prefix + '-ep{:03d}-{}.png'.format(self.epoch_num, idx))
imsave(fname, (im * self.scale).astype('uint8'))
self._dump_image(im, idx)
else:
im = val
assert im.ndim in [2, 3]
fname = os.path.join(
self.log_dir,
self.prefix + '-ep{:03d}.png'.format(self.epoch_num))
imsave(fname, (im * self.scale).astype('uint8'))
self._dump_image(val)
def _dump_image(self, im, idx=None):
    """
    Write one image to the log directory as a PNG file.

    Args:
        im: an array whose ``ndim`` is 2 or 3. It is multiplied by
            ``self.scale`` and cast to uint8 before being saved.
        idx: index of the image when several images are dumped for the
            same epoch, or None for a single image. Whenever it is given
            (including 0), ``-<idx>`` is appended to the file name.
    """
    assert im.ndim in [2, 3], str(im.ndim)
    # Compare against None explicitly: 0 is a valid index and must still
    # yield a "-0" suffix, otherwise the first image of a list is written
    # without its index in the file name.
    suffix = '' if idx is None else '-' + str(idx)
    fname = os.path.join(
        self.log_dir,
        self.prefix + '-ep{:03d}{}.png'.format(self.epoch_num, suffix))
    imsave(fname, (im * self.scale).astype('uint8'))
......@@ -21,7 +21,9 @@ from .utils import logger
from .dataflow import DataFlow
class TrainConfig(object):
""" config for training"""
"""
Config for training a model with a single loss
"""
def __init__(self, **kwargs):
"""
Args:
......@@ -95,7 +97,7 @@ def scale_grads(grads, multiplier):
def start_train(config):
"""
Start training with the given config
Start training with a config
Args:
config: a TrainConfig instance
"""
......@@ -171,7 +173,9 @@ def start_train(config):
tf.get_default_graph().finalize()
for epoch in xrange(1, config.max_epoch):
with timed_operation('Epoch {}'.format(epoch)):
with timed_operation(
'Epoch {}, global_step={}'.format(
epoch, get_global_step() + config.step_per_epoch)):
for step in tqdm.trange(
config.step_per_epoch,
leave=True, mininterval=0.5,
......
......@@ -31,13 +31,13 @@ def timed_operation(msg, log_start=False):
logger.info('{} finished, time={:.2f}sec.'.format(
msg, time.time() - start))
def get_default_sess_config():
def get_default_sess_config(mem_fraction=0.5):
"""
Return a better config to use as default.
Tensorflow default session config consume too much resources
"""
conf = tf.ConfigProto()
conf.gpu_options.per_process_gpu_memory_fraction = 0.6
conf.gpu_options.per_process_gpu_memory_fraction = mem_fraction
conf.gpu_options.allocator_type = 'BFC'
conf.allow_soft_placement = True
return conf
......
......@@ -7,6 +7,7 @@ import logging
import os, shutil
import os.path
from termcolor import colored
from datetime import datetime
import sys
if not sys.version_info >= (3, 0):
input = raw_input # for compatibility
......@@ -48,7 +49,6 @@ logger = getlogger()
global LOG_FILE, LOG_DIR
def _set_file(path):
if os.path.isfile(path):
from datetime import datetime
backup_name = path + datetime.now().strftime('.%d-%H%M%S')
shutil.move(path, backup_name)
info("Log file '{}' backuped to '{}'".format(path, backup_name))
......@@ -60,14 +60,15 @@ def set_logger_dir(dirname):
global LOG_FILE, LOG_DIR
LOG_DIR = dirname
if os.path.isdir(dirname):
logger.info("Directory {} exists. Please either backup or delete it unless you're continue from a paused task." )
logger.warn("""\
Directory {} exists! Please either backup or delete it \
unless you're resuming from a previous task.""".format(dirname))
logger.info("Select Action: k (keep) / b (backup) / d (delete):")
act = input().lower()
if act == 'b':
from datetime import datetime
backup_name = dirname + datetime.now().strftime('.%d-%H%M%S')
shutil.move(dirname, backup_name)
info("Log directory'{}' backuped to '{}'".format(dirname, backup_name))
info("Directory'{}' backuped to '{}'".format(dirname, backup_name))
elif act == 'd':
shutil.rmtree(dirname)
elif act == 'k':
......@@ -83,8 +84,8 @@ def set_logger_dir(dirname):
for func in ['info', 'warning', 'error', 'critical', 'warn', 'exception', 'debug']:
locals()[func] = getattr(logger, func)
# a SummaryWriter
# a global SummaryWriter
writer = None
# a StatHolder
# a global StatHolder
stat_holder = None
......@@ -25,9 +25,12 @@ def describe_model():
def get_shape_str(tensors):
    """
    Build a printable shape description.

    Args:
        tensors: a single tf.Tensor / tf.Variable, or a list/tuple of them.

    Returns:
        str: the shape list of the tensor as a string; for a list/tuple,
        the shape strings of all elements joined by commas.
    """
    if not isinstance(tensors, (list, tuple)):
        assert isinstance(tensors, (tf.Tensor, tf.Variable)), "Not a tensor: {}".format(type(tensors))
        return str(tensors.get_shape().as_list())

    # Validate every element first so nothing is formatted when any entry is bad.
    for t in tensors:
        assert isinstance(t, (tf.Tensor, tf.Variable)), "Not a tensor: {}".format(type(t))
    return ",".join([str(t.get_shape().as_list()) for t in tensors])
......@@ -5,6 +5,7 @@
# use user-space protobuf
import sys, os
site = os.path.join(os.environ['HOME'],
'.local/lib/python2.7/site-packages')
sys.path.insert(0, site)
if not sys.version_info >= (3, 0):
site = os.path.join(os.environ['HOME'],
'.local/lib/python2.7/site-packages')
sys.path.insert(0, site)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment