Commit 92ccfe0a authored by Yuxin Wu's avatar Yuxin Wu

Fix tensor name for batch norm (#663); plus some docs update.

parent 65a9fcc7
...@@ -26,14 +26,6 @@ tensorpack.utils.fs module ...@@ -26,14 +26,6 @@ tensorpack.utils.fs module
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:
tensorpack.utils.globvars module
--------------------------------
.. automodule:: tensorpack.utils.globvars
:members:
:undoc-members:
:show-inheritance:
tensorpack.utils.loadcaffe module tensorpack.utils.loadcaffe module
--------------------------------- ---------------------------------
......
...@@ -16,11 +16,15 @@ from six.moves import range ...@@ -16,11 +16,15 @@ from six.moves import range
from tensorpack import * from tensorpack import *
from tensorpack.tfutils import summary, optimizer from tensorpack.tfutils import summary, optimizer
from tensorpack.tfutils.gradproc import GlobalNormClip from tensorpack.tfutils.gradproc import GlobalNormClip
from tensorpack.utils.globvars import globalns as param
import tensorflow as tf import tensorflow as tf
rnn = tf.contrib.rnn rnn = tf.contrib.rnn
class _NS: pass # noqa
param = _NS()
# some model hyperparams to set # some model hyperparams to set
param.batch_size = 128 param.batch_size = 128
param.rnn_size = 256 param.rnn_size = 256
......
...@@ -54,7 +54,7 @@ def update_bn_ema(xn, batch_mean, batch_var, ...@@ -54,7 +54,7 @@ def update_bn_ema(xn, batch_mean, batch_var,
else: else:
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_op1) tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_op1)
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_op2) tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_op2)
return xn return tf.identity(xn, name='output')
def reshape_for_bn(param, ndims, chan, data_format): def reshape_for_bn(param, ndims, chan, data_format):
......
...@@ -47,7 +47,6 @@ def get_default_sess_config(mem_fraction=0.99): ...@@ -47,7 +47,6 @@ def get_default_sess_config(mem_fraction=0.99):
# conf.graph_options.rewrite_options.memory_optimization = \ # conf.graph_options.rewrite_options.memory_optimization = \
# rwc.RewriterConfig.HEURISTICS # rwc.RewriterConfig.HEURISTICS
# May hurt performance # May hurt performance
# conf.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1 # conf.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
# conf.graph_options.place_pruned_graph = True # conf.graph_options.place_pruned_graph = True
......
...@@ -2,8 +2,6 @@ ...@@ -2,8 +2,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: interface.py # File: interface.py
import tensorflow as tf
from ..input_source import ( from ..input_source import (
InputSource, FeedInput, QueueInput, StagingInput, DummyConstantInput) InputSource, FeedInput, QueueInput, StagingInput, DummyConstantInput)
from ..utils import logger from ..utils import logger
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment