Commit f698a04d authored by Yuxin Wu's avatar Yuxin Wu

bug fix: a line is missing in regularize

parent 8902af90
......@@ -78,7 +78,8 @@ def get_model(inputs, is_training):
return [prob, nr_wrong], tf.add_n([wd_cost, cost], name='cost')
def get_config():
log_dir = os.path.join('train_log', os.path.basename(__file__)[:-3])
basename = os.path.basename(__file__)
log_dir = os.path.join('train_log', basename[:basename.rfind('.')])
logger.set_logger_dir(log_dir)
dataset_train = FakeData([(227,227,3), tuple()], 10)
......
......@@ -80,7 +80,8 @@ def get_model(inputs, is_training):
return [prob, nr_wrong], tf.add_n([wd_cost, cost], name='cost')
def get_config():
log_dir = os.path.join('train_log', os.path.basename(__file__)[:-3])
basename = os.path.basename(__file__)
log_dir = os.path.join('train_log', basename[:basename.rfind('.')])
logger.set_logger_dir(log_dir)
dataset_train = dataset.Cifar10('train')
......
......@@ -89,7 +89,8 @@ def get_model(inputs, is_training):
return [prob, nr_wrong], tf.add_n([wd_cost, cost], name='cost')
def get_config():
log_dir = os.path.join('train_log', os.path.basename(__file__)[:-3])
basename = os.path.basename(__file__)
log_dir = os.path.join('train_log', basename[:basename.rfind('.')])
logger.set_logger_dir(log_dir)
IMAGE_SIZE = 28
......
......@@ -25,6 +25,7 @@ def regularize_cost(regex, func):
costs = []
for p in params:
name = p.name
if re.search(regex, name):
costs.append(func(p))
_log_regularizer(name)
return tf.add_n(costs)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment