Commit 5a163e7c authored by Patrick Wieschollek's avatar Patrick Wieschollek Committed by Yuxin Wu

Add name to logger.autodir for different runs and simplify (#301)

* Add name to logger.autodir for different runs and simplify
similarity learning (incl. removing tf.FLAGS in favor of argparse)

* remove 'required' from mnist-embedding.py

* next fix

* c'mon travis

* Update logger.py

* fix linting..
parent 723b89b8
......@@ -9,7 +9,7 @@ import os
from tensorpack import *
import tensorpack.tfutils.symbolic_functions as symbf
from tensorpack.tfutils.summary import add_moving_summary
import argparse
import tensorflow as tf
import tensorflow.contrib.slim as slim
......@@ -25,13 +25,6 @@ except ImportError:
MATPLOTLIB_AVAIBLABLE = False
FLAGS = tf.flags.FLAGS
tf.flags.DEFINE_string('load', "", 'load model')
tf.flags.DEFINE_integer('gpu', 0, 'used gpu')
tf.flags.DEFINE_string('algorithm', "siamese", 'algorithm')
tf.flags.DEFINE_boolean('visualize', False, 'show embedding')
class EmbeddingModel(ModelDesc):
def embed(self, x, nfeatures=2):
"""Embed all given tensors into an nfeatures-dim space. """
......@@ -141,9 +134,6 @@ class SoftTripletModel(TripletModel):
def get_config(model, algorithm_name):
logger.set_logger_dir(
os.path.join('train_log',
'mnist-embeddings-{}'.format(algorithm_name)))
extra_display = ["cost"]
if not algorithm_name == "cosine":
......@@ -166,7 +156,7 @@ def get_config(model, algorithm_name):
)
def visualize(model_path, model):
def visualize(model_path, model, algo_name):
if not MATPLOTLIB_AVAIBLABLE:
logger.error("visualize requires matplotlib package ...")
return
......@@ -213,27 +203,32 @@ def visualize(model_path, model):
plt.axis([ax_min[0], ax_max[0], ax_min[1], ax_max[1]])
plt.xticks([]), plt.yticks([])
algo_name = FLAGS.algorithm
plt.title('Embedding using %s-loss' % algo_name)
plt.savefig('%s.jpg' % algo_name)
if __name__ == '__main__':
unknown = FLAGS._parse_flags()
assert len(unknown) == 0, "Invalid argument!"
assert FLAGS.algorithm in ["siamese", "cosine", "triplet", "softtriplet"]
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
parser.add_argument('--load', help='load model')
parser.add_argument('-a', '--algorithm', help='used algorithm', type=str,
choices=["siamese", "cosine", "triplet", "softtriplet"])
parser.add_argument('--visualize', help='export embeddings into an image', action='store_true')
args = parser.parse_args()
ALGO_CONFIGS = {"siamese": SiameseModel,
"cosine": CosineModel,
"triplet": TripletModel,
"softtriplet": SoftTripletModel}
with change_gpu(FLAGS.gpu):
if FLAGS.visualize:
visualize(FLAGS.load, ALGO_CONFIGS[FLAGS.algorithm])
logger.auto_set_dir(name=args.algorithm)
with change_gpu(args.gpu):
if args.visualize:
visualize(args.load, ALGO_CONFIGS[args.algorithm], args.algorithm)
else:
config = get_config(ALGO_CONFIGS[FLAGS.algorithm], FLAGS.algorithm)
if FLAGS.load:
config.session_init = SaverRestore(FLAGS.load)
config = get_config(ALGO_CONFIGS[args.algorithm], args.algorithm)
if args.load:
config.session_init = SaverRestore(args.load)
else:
SimpleTrainer(config).train()
......@@ -123,13 +123,13 @@ def disable_logger():
globals()[func] = lambda x: None
def auto_set_dir(action=None):
def auto_set_dir(action=None, name=None):
"""
Set log directory to a subdir inside "train_log", with the name being
the main python file currently running"""
Use :func:`logger.set_logger_dir` to set log directory to
"./train_log/{scriptname}:{name}". "scriptname" is the name of the main python file currently running"""
mod = sys.modules['__main__']
basename = os.path.basename(mod.__file__)
set_logger_dir(
os.path.join('train_log',
basename[:basename.rfind('.')]),
action=action)
auto_dirname = os.path.join('train_log', basename[:basename.rfind('.')])
if name:
auto_dirname += ':%s' % name
set_logger_dir(auto_dirname, action=action)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment