Commit 9f1af4c8 authored by Yuxin Wu

fix argscope bug

parent e90acf27
@@ -49,7 +49,7 @@ class HyperParamSetter(Callback):
         """
         ret = self._get_current_value()
         if ret is not None and ret != self.last_value:
-            logger.info("{} at epoch {} is changed to {}".format(
+            logger.info("{} at epoch {} will change to {}".format(
                 self.op_name, self.epoch_num, ret))
         self.last_value = ret
         return ret
...
@@ -5,6 +5,7 @@
 import tensorflow as tf
 from functools import wraps
 import six
+import copy
 
 from ..tfutils import *
 from ..tfutils.modelutils import *
@@ -34,7 +35,7 @@ def layer_register(summary_activation=False, log_shape=True):
         inputs = args[0]
         # update from current argument scope
-        actual_args = get_arg_scope()[func.__name__]
+        actual_args = copy.copy(get_arg_scope()[func.__name__])
         actual_args.update(kwargs)
         with tf.variable_scope(name) as scope:
...
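For context, this is the aliasing bug the commit message refers to: `layer_register` merged per-call kwargs directly into the dict stored in the argscope, mutating the shared scope itself. A minimal self-contained sketch; `get_arg_scope` and the `Conv2D` entry here are illustrative stand-ins, not tensorpack's actual implementation:

```python
import copy

# Illustrative stand-in for tensorpack's argscope storage.
_arg_scope = {'Conv2D': {'use_bias': True}}

def get_arg_scope():
    return _arg_scope

# Buggy version: `actual_args` aliases the dict inside the scope, so a
# per-call override leaks into the scope and silently affects later calls.
actual_args = get_arg_scope()['Conv2D']
actual_args.update({'use_bias': False})
assert get_arg_scope()['Conv2D']['use_bias'] is False   # scope was polluted

# Fixed version (this commit): copy first, then merge per-call kwargs
# into the copy only.
_arg_scope = {'Conv2D': {'use_bias': True}}             # reset for the demo
actual_args = copy.copy(get_arg_scope()['Conv2D'])
actual_args.update({'use_bias': False})
assert get_arg_scope()['Conv2D']['use_bias'] is True    # scope stays intact
```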
@@ -62,12 +62,14 @@ def LeakyReLU(x, alpha, name=None):
     return tf.mul(x, 0.5, name=name)
 
-def BNReLU(is_training):
+def BNReLU(is_training, **kwargs):
     """
+    :param is_training: boolean
+    :param kwargs: args for BatchNorm
     :returns: an activation function that performs BN + ReLU (a too common combination)
     """
     def BNReLU(x, name=None):
-        x = BatchNorm('bn', x, is_training)
+        x = BatchNorm('bn', x, is_training, **kwargs)
         x = tf.nn.relu(x, name=name)
         return x
     return BNReLU
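
With the new `**kwargs`, callers can tune BatchNorm through the returned activation function. A hedged usage sketch, assuming tensorpack's `BatchNorm` accepts a `decay` keyword and a `Conv2D` layer takes an `nl` activation argument, as the library did around this time:

```python
# Hypothetical usage inside a tensorpack model definition; `decay` is
# assumed to be a BatchNorm option and is forwarded through the closure.
nl = BNReLU(is_training=True, decay=0.9)

# Each call now runs BatchNorm('bn', x, True, decay=0.9) followed by ReLU.
l = Conv2D('conv0', image, out_channel=64, kernel_shape=3, nl=nl)
```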
@@ -71,7 +71,7 @@ class Trainer(object):
         # some final operations that might modify the graph
         logger.info("Preparing for training...")
         self._init_summary()
-        get_global_step_var()
+        get_global_step_var()   # ensure there is such var, before finalizing the graph
         callbacks = self.config.callbacks
         callbacks.before_train(self)
         self.config.session_init.init(self.sess)
...
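The new comment on `get_global_step_var()` matters because TensorFlow graphs can be finalized, after which no new ops or variables may be created; creating the global-step variable early avoids that failure. A minimal sketch of the failure mode in plain TensorFlow (TF 0.x/1.x API, matching this era of the code):

```python
import tensorflow as tf

g = tf.Graph()
with g.as_default():
    tf.Variable(0, name='global_step')   # fine: graph is still mutable
g.finalize()

with g.as_default():
    try:
        tf.Variable(1, name='too_late')  # raises: graph is finalized
    except RuntimeError as e:
        print(e)
```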