Commit f698a04d authored by Yuxin Wu's avatar Yuxin Wu

bug fix: a line is missing in regularize

parent 8902af90
...@@ -78,7 +78,8 @@ def get_model(inputs, is_training): ...@@ -78,7 +78,8 @@ def get_model(inputs, is_training):
return [prob, nr_wrong], tf.add_n([wd_cost, cost], name='cost') return [prob, nr_wrong], tf.add_n([wd_cost, cost], name='cost')
def get_config(): def get_config():
log_dir = os.path.join('train_log', os.path.basename(__file__)[:-3]) basename = os.path.basename(__file__)
log_dir = os.path.join('train_log', basename[:basename.rfind('.')])
logger.set_logger_dir(log_dir) logger.set_logger_dir(log_dir)
dataset_train = FakeData([(227,227,3), tuple()], 10) dataset_train = FakeData([(227,227,3), tuple()], 10)
......
...@@ -80,7 +80,8 @@ def get_model(inputs, is_training): ...@@ -80,7 +80,8 @@ def get_model(inputs, is_training):
return [prob, nr_wrong], tf.add_n([wd_cost, cost], name='cost') return [prob, nr_wrong], tf.add_n([wd_cost, cost], name='cost')
def get_config(): def get_config():
log_dir = os.path.join('train_log', os.path.basename(__file__)[:-3]) basename = os.path.basename(__file__)
log_dir = os.path.join('train_log', basename[:basename.rfind('.')])
logger.set_logger_dir(log_dir) logger.set_logger_dir(log_dir)
dataset_train = dataset.Cifar10('train') dataset_train = dataset.Cifar10('train')
......
...@@ -89,7 +89,8 @@ def get_model(inputs, is_training): ...@@ -89,7 +89,8 @@ def get_model(inputs, is_training):
return [prob, nr_wrong], tf.add_n([wd_cost, cost], name='cost') return [prob, nr_wrong], tf.add_n([wd_cost, cost], name='cost')
def get_config(): def get_config():
log_dir = os.path.join('train_log', os.path.basename(__file__)[:-3]) basename = os.path.basename(__file__)
log_dir = os.path.join('train_log', basename[:basename.rfind('.')])
logger.set_logger_dir(log_dir) logger.set_logger_dir(log_dir)
IMAGE_SIZE = 28 IMAGE_SIZE = 28
......
...@@ -25,7 +25,8 @@ def regularize_cost(regex, func): ...@@ -25,7 +25,8 @@ def regularize_cost(regex, func):
costs = [] costs = []
for p in params: for p in params:
name = p.name name = p.name
costs.append(func(p)) if re.search(regex, name):
_log_regularizer(name) costs.append(func(p))
_log_regularizer(name)
return tf.add_n(costs) return tf.add_n(costs)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment