Commit d731cf7b authored by Yuxin Wu's avatar Yuxin Wu

small fixes

parent 087af16e
...@@ -18,7 +18,7 @@ from tensorpack.dataflow import * ...@@ -18,7 +18,7 @@ from tensorpack.dataflow import *
from tensorpack.dataflow import imgaug from tensorpack.dataflow import imgaug
""" """
CIFAR10 89% test accuracy after 200 epochs. CIFAR10 89% test accuracy after 60k step (about 150 epochs)
""" """
BATCH_SIZE = 128 BATCH_SIZE = 128
......
...@@ -17,7 +17,7 @@ class DumpParamAsImage(Callback): ...@@ -17,7 +17,7 @@ class DumpParamAsImage(Callback):
""" """
map_func: map the value of the variable to an image or list of images, default to identity map_func: map the value of the variable to an image or list of images, default to identity
images should have shape [h, w] or [h, w, c]. images should have shape [h, w] or [h, w, c].
scale: a scaling parameter on pixels scale: a multiplier on pixel values, applied after map_func. default to 255
""" """
self.var_name = var_name self.var_name = var_name
self.func = map_func self.func = map_func
...@@ -37,16 +37,15 @@ class DumpParamAsImage(Callback): ...@@ -37,16 +37,15 @@ class DumpParamAsImage(Callback):
val = self.func(val) val = self.func(val)
if isinstance(val, list): if isinstance(val, list):
for idx, im in enumerate(val): for idx, im in enumerate(val):
assert im.ndim in [2, 3], str(im.ndim) self._dump_image(im, idx)
fname = os.path.join(
self.log_dir,
self.prefix + '-ep{:03d}-{}.png'.format(self.epoch_num, idx))
imsave(fname, (im * self.scale).astype('uint8'))
else: else:
im = val self._dump_image(val)
assert im.ndim in [2, 3]
def _dump_image(self, im, idx=None):
assert im.ndim in [2, 3], str(im.ndim)
fname = os.path.join( fname = os.path.join(
self.log_dir, self.log_dir,
self.prefix + '-ep{:03d}.png'.format(self.epoch_num)) self.prefix + '-ep{:03d}{}.png'.format(
self.epoch_num, '-' + str(idx) if idx else ''))
imsave(fname, (im * self.scale).astype('uint8')) imsave(fname, (im * self.scale).astype('uint8'))
...@@ -21,7 +21,9 @@ from .utils import logger ...@@ -21,7 +21,9 @@ from .utils import logger
from .dataflow import DataFlow from .dataflow import DataFlow
class TrainConfig(object): class TrainConfig(object):
""" config for training""" """
Config for training a model with a single loss
"""
def __init__(self, **kwargs): def __init__(self, **kwargs):
""" """
Args: Args:
...@@ -95,7 +97,7 @@ def scale_grads(grads, multiplier): ...@@ -95,7 +97,7 @@ def scale_grads(grads, multiplier):
def start_train(config): def start_train(config):
""" """
Start training with the given config Start training with a config
Args: Args:
config: a TrainConfig instance config: a TrainConfig instance
""" """
...@@ -171,7 +173,9 @@ def start_train(config): ...@@ -171,7 +173,9 @@ def start_train(config):
tf.get_default_graph().finalize() tf.get_default_graph().finalize()
for epoch in xrange(1, config.max_epoch): for epoch in xrange(1, config.max_epoch):
with timed_operation('Epoch {}'.format(epoch)): with timed_operation(
'Epoch {}, global_step={}'.format(
epoch, get_global_step() + config.step_per_epoch)):
for step in tqdm.trange( for step in tqdm.trange(
config.step_per_epoch, config.step_per_epoch,
leave=True, mininterval=0.5, leave=True, mininterval=0.5,
......
...@@ -31,13 +31,13 @@ def timed_operation(msg, log_start=False): ...@@ -31,13 +31,13 @@ def timed_operation(msg, log_start=False):
logger.info('{} finished, time={:.2f}sec.'.format( logger.info('{} finished, time={:.2f}sec.'.format(
msg, time.time() - start)) msg, time.time() - start))
def get_default_sess_config(): def get_default_sess_config(mem_fraction=0.5):
""" """
Return a better config to use as default. Return a better config to use as default.
Tensorflow default session config consumes too many resources Tensorflow default session config consumes too many resources
""" """
conf = tf.ConfigProto() conf = tf.ConfigProto()
conf.gpu_options.per_process_gpu_memory_fraction = 0.6 conf.gpu_options.per_process_gpu_memory_fraction = mem_fraction
conf.gpu_options.allocator_type = 'BFC' conf.gpu_options.allocator_type = 'BFC'
conf.allow_soft_placement = True conf.allow_soft_placement = True
return conf return conf
......
...@@ -7,6 +7,7 @@ import logging ...@@ -7,6 +7,7 @@ import logging
import os, shutil import os, shutil
import os.path import os.path
from termcolor import colored from termcolor import colored
from datetime import datetime
import sys import sys
if not sys.version_info >= (3, 0): if not sys.version_info >= (3, 0):
input = raw_input # for compatibility input = raw_input # for compatibility
...@@ -48,7 +49,6 @@ logger = getlogger() ...@@ -48,7 +49,6 @@ logger = getlogger()
global LOG_FILE, LOG_DIR global LOG_FILE, LOG_DIR
def _set_file(path): def _set_file(path):
if os.path.isfile(path): if os.path.isfile(path):
from datetime import datetime
backup_name = path + datetime.now().strftime('.%d-%H%M%S') backup_name = path + datetime.now().strftime('.%d-%H%M%S')
shutil.move(path, backup_name) shutil.move(path, backup_name)
info("Log file '{}' backed up to '{}'".format(path, backup_name)) info("Log file '{}' backed up to '{}'".format(path, backup_name))
...@@ -60,14 +60,15 @@ def set_logger_dir(dirname): ...@@ -60,14 +60,15 @@ def set_logger_dir(dirname):
global LOG_FILE, LOG_DIR global LOG_FILE, LOG_DIR
LOG_DIR = dirname LOG_DIR = dirname
if os.path.isdir(dirname): if os.path.isdir(dirname):
logger.info("Directory {} exists. Please either back up or delete it unless you're continuing from a paused task." ) logger.warn("""\
Directory {} exists! Please either back up or delete it \
unless you're resuming from a previous task.""".format(dirname))
logger.info("Select Action: k (keep) / b (backup) / d (delete):") logger.info("Select Action: k (keep) / b (backup) / d (delete):")
act = input().lower() act = input().lower()
if act == 'b': if act == 'b':
from datetime import datetime
backup_name = dirname + datetime.now().strftime('.%d-%H%M%S') backup_name = dirname + datetime.now().strftime('.%d-%H%M%S')
shutil.move(dirname, backup_name) shutil.move(dirname, backup_name)
info("Log directory '{}' backed up to '{}'".format(dirname, backup_name)) info("Directory '{}' backed up to '{}'".format(dirname, backup_name))
elif act == 'd': elif act == 'd':
shutil.rmtree(dirname) shutil.rmtree(dirname)
elif act == 'k': elif act == 'k':
...@@ -83,8 +84,8 @@ def set_logger_dir(dirname): ...@@ -83,8 +84,8 @@ def set_logger_dir(dirname):
for func in ['info', 'warning', 'error', 'critical', 'warn', 'exception', 'debug']: for func in ['info', 'warning', 'error', 'critical', 'warn', 'exception', 'debug']:
locals()[func] = getattr(logger, func) locals()[func] = getattr(logger, func)
# a SummaryWriter # a global SummaryWriter
writer = None writer = None
# a StatHolder # a global StatHolder
stat_holder = None stat_holder = None
...@@ -25,9 +25,12 @@ def describe_model(): ...@@ -25,9 +25,12 @@ def describe_model():
def get_shape_str(tensors): def get_shape_str(tensors):
""" return the shape string for a tensor or a list of tensors""" """ return the shape string for a tensor or a list of tensors"""
if isinstance(tensors, (list, tuple)): if isinstance(tensors, (list, tuple)):
for v in tensors:
assert isinstance(v, (tf.Tensor, tf.Variable)), "Not a tensor: {}".format(type(v))
shape_str = ",".join( shape_str = ",".join(
map(lambda x: str(x.get_shape().as_list()), tensors)) map(lambda x: str(x.get_shape().as_list()), tensors))
else: else:
assert isinstance(tensors, (tf.Tensor, tf.Variable)), "Not a tensor: {}".format(type(tensors))
shape_str = str(tensors.get_shape().as_list()) shape_str = str(tensors.get_shape().as_list())
return shape_str return shape_str
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
# use user-space protobuf # use user-space protobuf
import sys, os import sys, os
site = os.path.join(os.environ['HOME'], if not sys.version_info >= (3, 0):
site = os.path.join(os.environ['HOME'],
'.local/lib/python2.7/site-packages') '.local/lib/python2.7/site-packages')
sys.path.insert(0, site) sys.path.insert(0, site)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment