Commit 9f1af4c8 authored by Yuxin Wu's avatar Yuxin Wu

fix argscope bug

parent e90acf27
......@@ -49,7 +49,7 @@ class HyperParamSetter(Callback):
"""
ret = self._get_current_value()
if ret is not None and ret != self.last_value:
logger.info("{} at epoch {} is changed to {}".format(
logger.info("{} at epoch {} will change to {}".format(
self.op_name, self.epoch_num, ret))
self.last_value = ret
return ret
......
......@@ -5,6 +5,7 @@
import tensorflow as tf
from functools import wraps
import six
import copy
from ..tfutils import *
from ..tfutils.modelutils import *
......@@ -34,7 +35,7 @@ def layer_register(summary_activation=False, log_shape=True):
inputs = args[0]
# update from current argument scope
actual_args = get_arg_scope()[func.__name__]
actual_args = copy.copy(get_arg_scope()[func.__name__])
actual_args.update(kwargs)
with tf.variable_scope(name) as scope:
......
......@@ -62,12 +62,14 @@ def LeakyReLU(x, alpha, name=None):
return tf.mul(x, 0.5, name=name)
def BNReLU(is_training):
def BNReLU(is_training, **kwargs):
"""
:param is_traning: boolean
:param kwargs: args for BatchNorm
:returns: a activation function that performs BN + ReLU (a too common combination)
"""
def BNReLU(x, name=None):
x = BatchNorm('bn', x, is_training)
x = BatchNorm('bn', x, is_training, **kwargs)
x = tf.nn.relu(x, name=name)
return x
return BNReLU
......@@ -71,7 +71,7 @@ class Trainer(object):
# some final operations that might modify the graph
logger.info("Preparing for training...")
self._init_summary()
get_global_step_var()
get_global_step_var() # ensure there is such var, before finalizing the graph
callbacks = self.config.callbacks
callbacks.before_train(self)
self.config.session_init.init(self.sess)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment