Commit 5a163e7c authored by Patrick Wieschollek, committed by Yuxin Wu

Add name to logger.autodir for different runs and simplify (#301)

* Add name to logger.autodir for different runs and simplify
similarity learning (incl. removing tf.FLAGS in favor of argparse)

* remove 'required' from mnist-embedding.py

* next fix

* c'mon travis

* Update logger.py

* fix linting..
parent 723b89b8
@@ -9,7 +9,7 @@ import os
 from tensorpack import *
 import tensorpack.tfutils.symbolic_functions as symbf
 from tensorpack.tfutils.summary import add_moving_summary
+import argparse
 import tensorflow as tf
 import tensorflow.contrib.slim as slim
@@ -25,13 +25,6 @@ except ImportError:
     MATPLOTLIB_AVAIBLABLE = False

-FLAGS = tf.flags.FLAGS
-tf.flags.DEFINE_string('load', "", 'load model')
-tf.flags.DEFINE_integer('gpu', 0, 'used gpu')
-tf.flags.DEFINE_string('algorithm', "siamese", 'algorithm')
-tf.flags.DEFINE_boolean('visualize', False, 'show embedding')
-
 class EmbeddingModel(ModelDesc):
     def embed(self, x, nfeatures=2):
         """Embed all given tensors into an nfeatures-dim space. """
@@ -141,9 +134,6 @@ class SoftTripletModel(TripletModel):
 def get_config(model, algorithm_name):
-    logger.set_logger_dir(
-        os.path.join('train_log',
-                     'mnist-embeddings-{}'.format(algorithm_name)))
     extra_display = ["cost"]
     if not algorithm_name == "cosine":
@@ -166,7 +156,7 @@ def get_config(model, algorithm_name):
     )

-def visualize(model_path, model):
+def visualize(model_path, model, algo_name):
     if not MATPLOTLIB_AVAIBLABLE:
         logger.error("visualize requires matplotlib package ...")
         return
@@ -213,27 +203,32 @@ def visualize(model_path, model):
     plt.axis([ax_min[0], ax_max[0], ax_min[1], ax_max[1]])
     plt.xticks([]), plt.yticks([])
-    algo_name = FLAGS.algorithm
     plt.title('Embedding using %s-loss' % algo_name)
     plt.savefig('%s.jpg' % algo_name)

 if __name__ == '__main__':
-    unknown = FLAGS._parse_flags()
-    assert len(unknown) == 0, "Invalid argument!"
-    assert FLAGS.algorithm in ["siamese", "cosine", "triplet", "softtriplet"]
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
+    parser.add_argument('--load', help='load model')
+    parser.add_argument('-a', '--algorithm', help='used algorithm', type=str,
+                        choices=["siamese", "cosine", "triplet", "softtriplet"])
+    parser.add_argument('--visualize', help='export embeddings into an image', action='store_true')
+    args = parser.parse_args()

     ALGO_CONFIGS = {"siamese": SiameseModel,
                     "cosine": CosineModel,
                     "triplet": TripletModel,
                     "softtriplet": SoftTripletModel}

-    with change_gpu(FLAGS.gpu):
-        if FLAGS.visualize:
-            visualize(FLAGS.load, ALGO_CONFIGS[FLAGS.algorithm])
+    logger.auto_set_dir(name=args.algorithm)
+
+    with change_gpu(args.gpu):
+        if args.visualize:
+            visualize(args.load, ALGO_CONFIGS[args.algorithm], args.algorithm)
         else:
-            config = get_config(ALGO_CONFIGS[FLAGS.algorithm], FLAGS.algorithm)
-            if FLAGS.load:
-                config.session_init = SaverRestore(FLAGS.load)
+            config = get_config(ALGO_CONFIGS[args.algorithm], args.algorithm)
+            if args.load:
+                config.session_init = SaverRestore(args.load)
             else:
                 SimpleTrainer(config).train()
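In short, the example script drops tf.flags: the four tf.flags.DEFINE_* declarations become an argparse parser, and the parser's choices list replaces the old manual assert on FLAGS.algorithm. A minimal, self-contained sketch of how the new command-line handling behaves (the script name in the comment is taken from the commit message; the parsed argument list is illustrative, not from the diff):

# Sketch of the argparse-based CLI introduced above; parses a hypothetical
# argument list instead of sys.argv so it can run standalone.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
parser.add_argument('--load', help='load model')
parser.add_argument('-a', '--algorithm', help='used algorithm', type=str,
                    choices=["siamese", "cosine", "triplet", "softtriplet"])
parser.add_argument('--visualize', help='export embeddings into an image', action='store_true')

# e.g. `python mnist-embedding.py -a triplet --gpu 0`
args = parser.parse_args(['-a', 'triplet', '--gpu', '0'])
print(args.algorithm, args.gpu, args.visualize)   # triplet 0 False
# An unsupported algorithm now fails at parse time (argparse exits with an error)
# instead of relying on the old assert against FLAGS.algorithm.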
@@ -123,13 +123,13 @@ def disable_logger():
         globals()[func] = lambda x: None

-def auto_set_dir(action=None):
+def auto_set_dir(action=None, name=None):
     """
-    Set log directory to a subdir inside "train_log", with the name being
-    the main python file currently running"""
+    Use :func:`logger.set_logger_dir` to set log directory to
+    "./train_log/{scriptname}:{name}". "scriptname" is the name of the main python file currently running"""
     mod = sys.modules['__main__']
     basename = os.path.basename(mod.__file__)
-    set_logger_dir(
-        os.path.join('train_log',
-                     basename[:basename.rfind('.')]),
-        action=action)
+    auto_dirname = os.path.join('train_log', basename[:basename.rfind('.')])
+    if name:
+        auto_dirname += ':%s' % name
+    set_logger_dir(auto_dirname, action=action)
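The logger.py change is what lets different runs of the same script land in different directories: the directory is derived from the script's basename, with an optional ":{name}" suffix. A standalone sketch of that naming rule (re-implemented here for illustration only; tensorpack's real auto_set_dir also calls set_logger_dir(..., action=action) and reads the script name from sys.modules['__main__']):

# Illustrative re-implementation of the directory-naming rule from the patched auto_set_dir.
import os

def auto_dirname(script_file, name=None):
    # "train_log/<script basename without extension>", plus ":<name>" when a name is given
    basename = os.path.basename(script_file)
    dirname = os.path.join('train_log', basename[:basename.rfind('.')])
    if name:
        dirname += ':%s' % name
    return dirname

print(auto_dirname('mnist-embedding.py'))              # train_log/mnist-embedding
print(auto_dirname('mnist-embedding.py', 'triplet'))   # train_log/mnist-embedding:triplet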