Commit dfacc168 authored by Yuxin Wu

clean imports

parent bcf8dbfe
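This commit replaces wildcard imports across callbacks, models, predict, tfutils, and train with explicit ones, and moves TowerContext and get_current_tower_context out of models/model_desc.py into a new tensorpack/tfutils/tower.py. The recurring pattern, shown on one representative pair of lines from the diff:

    # before: a star import hides which tfutils names the module actually uses
    from ..tfutils import *

    # after: the dependency is explicit; everything unused can be dropped
    from ..tfutils.common import get_vars_by_names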
@@ -11,8 +11,10 @@ try:
     gym.undo_logger_setup()
     # https://github.com/openai/gym/pull/199
     # not sure if it causes other problems
+    __all__ = ['GymEnv']
 except ImportError:
     logger.warn("Cannot import gym. GymEnv won't be available.")
+    __all__ = []
 
 import threading
@@ -20,7 +22,6 @@
 from ..utils.fs import *
 from ..utils.stat import *
 from .envbase import RLEnvironment, DiscreteActionSpace
-__all__ = ['GymEnv']
 
 _ALE_LOCK = threading.Lock()
......
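These two hunks also change when the module exports anything: __all__ is now assigned inside the try/except instead of unconditionally below the imports, so the effect (illustrative, not part of the diff) is:

    # with gym installed:    __all__ == ['GymEnv']  -> star-imports expose GymEnv
    # without gym installed: __all__ == []          -> star-imports expose nothing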
@@ -8,8 +8,6 @@ import os
 import time
 from abc import abstractmethod, ABCMeta
 
-from ..utils import *
-
 __all__ = ['Callback', 'PeriodicCallback']
 
 class Callback(object):
......
@@ -7,8 +7,8 @@ from contextlib import contextmanager
 import time
 
 from .base import Callback
-from .stat import *
-from ..utils import *
+from .stat import StatPrinter
+from ..utils import logger
 
 __all__ = ['Callbacks']
......
@@ -11,10 +11,9 @@ import six
 from six.moves import zip, map
 
 from ..dataflow import DataFlow
-from ..utils import *
-from ..utils.stat import *
-from ..tfutils import *
-from ..tfutils.summary import *
+from ..utils import get_tqdm_kwargs, logger
+from ..utils.stat import RatioCounter, BinaryStatistics
+from ..tfutils import get_op_tensor_name
 from .base import Callback
 
 __all__ = ['InferenceRunner', 'ClassificationError',
......
@@ -8,7 +8,7 @@ import operator
 import json
 
 from .base import Callback
-from ..utils import *
+from ..utils import logger
 
 __all__ = ['StatHolder', 'StatPrinter', 'SendStat']
......
@@ -7,9 +7,9 @@ from functools import wraps
 import six
 import copy, os
 
-from ..tfutils import *
-from ..tfutils.modelutils import *
-from ..tfutils.summary import *
+from ..tfutils.argscope import get_arg_scope
+from ..tfutils.modelutils import get_shape_str
+from ..tfutils.summary import add_activation_summary
 from ..utils import logger
 
 # make sure each layer is only logged once
......
@@ -7,7 +7,7 @@ import tensorflow as tf
 from copy import copy
 import re
 
-from .model_desc import get_current_tower_context
+from ..tfutils.tower import get_current_tower_context
 from ..utils import logger, EXTRA_SAVE_VARS_KEY
 from ._common import layer_register
......
@@ -6,7 +6,7 @@
 import numpy as np
 import tensorflow as tf
 import math
-from ._common import *
+from ._common import layer_register, shape2d, shape4d
 from ..utils import map_arg, logger
 
 __all__ = ['Conv2D']
......
@@ -7,7 +7,7 @@ import tensorflow as tf
 import math
 
 from ._common import layer_register
-from ..tfutils.symbolic_functions import *
+from ..tfutils import symbolic_functions as symbf
 
 __all__ = ['FullyConnected']
@@ -26,7 +26,7 @@ def FullyConnected(x, out_dim,
     :param use_bias: whether to use bias. a boolean default to True
     :returns: a 2D tensor
     """
-    x = batch_flatten(x)
+    x = symbf.batch_flatten(x)
     in_dim = x.get_shape().as_list()[1]
     if W_init is None:
......
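The second hunk reaches batch_flatten through the symbf alias instead of relying on a star import. For orientation, a minimal sketch of what a batch_flatten helper is expected to do here (an illustration under that assumption, not the library's source): collapse every axis except the batch axis so the layer gets a 2D tensor.

    import numpy as np
    import tensorflow as tf

    def batch_flatten_sketch(x):
        # flatten all dimensions of x except the first (batch) dimension
        shape = x.get_shape().as_list()[1:]
        if None not in shape:
            return tf.reshape(x, [-1, int(np.prod(shape))])
        # fall back to a dynamic reshape when the static shape is unknown
        return tf.reshape(x, tf.pack([tf.shape(x)[0], -1]))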
@@ -10,85 +10,13 @@ from collections import namedtuple
 import inspect
 
 from ..utils import logger, INPUT_VARS_KEY
-from ..tfutils import *
+from ..tfutils.common import get_vars_by_names
+from ..tfutils.gradproc import CheckGradient
 
-__all__ = ['ModelDesc', 'InputVar', 'ModelFromMetaGraph',
-           'get_current_tower_context', 'TowerContext']
+__all__ = ['ModelDesc', 'InputVar', 'ModelFromMetaGraph']
 
 InputVar = namedtuple('InputVar', ['type', 'shape', 'name'])
 
-_CurrentTowerContext = None
-
-class TowerContext(object):
-    def __init__(self, tower_name, is_training=None):
-        """ tower_name: 'tower0', 'towerp0', or '' """
-        self._name = tower_name
-        if is_training is None:
-            is_training = not self._name.startswith('towerp')
-        self._is_training = is_training
-
-    @property
-    def is_main_training_tower(self):
-        return self.is_training and (self._name == '' or self._name == 'tower0')
-
-    @property
-    def is_main_tower(self):
-        return self._name == '' or self._name == 'tower0'
-
-    @property
-    def is_training(self):
-        return self._is_training
-
-    @property
-    def name(self):
-        return self._name
-
-    def get_variable_on_tower(self, *args, **kwargs):
-        """
-        Get a variable for this tower specifically, without reusing.
-        Tensorflow doesn't allow reuse=False scope under a
-        reuse=True scope. This method provides a work around.
-        See https://www.tensorflow.org/versions/master/how_tos/variable_scope/index.html#basics-of-tfvariable-scope
-        :param args, kwargs: same as tf.get_variable()
-        """
-        with tf.variable_scope(self._name) as scope:
-            with tf.variable_scope(scope, reuse=False):
-                scope = tf.get_variable_scope()
-                assert scope.reuse == False
-                return tf.get_variable(*args, **kwargs)
-
-    def find_tensor_in_main_tower(self, graph, name):
-        if self.is_main_tower:
-            return graph.get_tensor_by_name(name)
-        if name.startswith('towerp'):
-            newname = re.sub('towerp[0-9]+/', '', name)
-            try:
-                return graph.get_tensor_by_name(newname)
-            except KeyError:
-                newname = re.sub('towerp[0-9]+/', 'tower0/', name)
-                return graph.get_tensor_by_name(newname)
-
-    def __enter__(self):
-        global _CurrentTowerContext
-        assert _CurrentTowerContext is None, \
-            "Nesting TowerContext!"
-        _CurrentTowerContext = self
-        if len(self._name):
-            self._scope = tf.name_scope(self._name)
-            return self._scope.__enter__()
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        global _CurrentTowerContext
-        _CurrentTowerContext = None
-        if len(self._name):
-            self._scope.__exit__(exc_type, exc_val, exc_tb)
-        return False
-
-def get_current_tower_context():
-    global _CurrentTowerContext
-    return _CurrentTowerContext
-
 class ModelDesc(object):
     """ Base class for a model description """
     __metaclass__ = ABCMeta
......
@@ -6,7 +6,7 @@
 import tensorflow as tf
 from copy import copy
 
-from ._common import *
+from ._common import layer_register
 from .batch_norm import BatchNorm
 
 __all__ = ['Maxout', 'PReLU', 'LeakyReLU', 'BNReLU']
......
@@ -5,8 +5,8 @@
 import tensorflow as tf
 import numpy
 
-from ._common import *
-from ..tfutils.symbolic_functions import *
+from ._common import layer_register, shape2d, shape4d
+from ..tfutils import symbolic_functions as symbf
 
 __all__ = ['MaxPooling', 'FixedUnPooling', 'AvgPooling', 'GlobalAvgPooling',
            'BilinearUpSample']
@@ -105,9 +105,9 @@ def FixedUnPooling(x, shape, unpool_mat=None):
         assert unpool_mat.get_shape().as_list() == list(shape)
 
     # perform a tensor-matrix kronecker product
-    fx = flatten(tf.transpose(x, [0, 3, 1, 2]))
+    fx = symbf.flatten(tf.transpose(x, [0, 3, 1, 2]))
     fx = tf.expand_dims(fx, -1)  # (bchw)x1
-    mat = tf.expand_dims(flatten(unpool_mat), 0)  # 1x(shxsw)
+    mat = tf.expand_dims(symbf.flatten(unpool_mat), 0)  # 1x(shxsw)
     prod = tf.matmul(fx, mat)  # (bchw) x (shxsw)
     prod = tf.reshape(prod, tf.pack(
         [-1, input_shape[3], input_shape[1], input_shape[2], shape[0], shape[1]]))
......
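The second hunk only reroutes flatten through symbf, but the surrounding code is worth a note: unpooling with a fixed unpool_mat is exactly a per-channel Kronecker product between the input and the unpooling matrix. A small numpy check of that identity (illustrative only, not part of the library):

    import numpy as np

    x = np.arange(1, 5).reshape(2, 2)    # one channel of the input map
    unpool_mat = np.array([[1, 0],
                           [0, 0]])      # keep each value in the top-left corner
    out = np.kron(x, unpool_mat)         # each input pixel lands in its own 2x2 cell
    print(out)
    # [[1 0 2 0]
    #  [0 0 0 0]
    #  [3 0 4 0]
    #  [0 0 0 0]]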
@@ -6,8 +6,8 @@ import tensorflow as tf
 import re
 
 from ..utils import logger
-from ..utils.utils import *
-from .model_desc import get_current_tower_context
+from ..utils.utils import memoized
+from ..tfutils.tower import get_current_tower_context
 from ._common import layer_register
 
 __all__ = ['regularize_cost', 'l2_regularizer', 'l1_regularizer', 'Dropout']
......
@@ -7,9 +7,8 @@ from abc import abstractmethod, ABCMeta, abstractproperty
 import tensorflow as tf
 import six
 
-from ..models import TowerContext
 from ..utils import logger
-from ..tfutils import get_vars_by_names
+from ..tfutils import get_vars_by_names, TowerContext
 
 __all__ = ['OnlinePredictor', 'OfflinePredictor',
            'AsyncPredictorBase',
......
@@ -9,7 +9,8 @@ from six.moves import zip
 
 from tensorpack.models import ModelDesc
 from ..utils import logger
-from ..tfutils import *
+from ..tfutils import get_default_sess_config
+from ..tfutils.sessinit import SessionInit, JustCurrentSession
 from .base import OfflinePredictor
 
 import multiprocessing
......
@@ -9,12 +9,9 @@ import time
 import six
 from six.moves import queue, range, zip
 
 from ..utils.concurrency import DIE
 from ..tfutils.modelutils import describe_model
 from ..utils import logger
-from ..utils.timer import *
-from ..tfutils import *
-
 from .base import *
......
@@ -17,4 +17,5 @@ _global_import('sessinit')
 _global_import('common')
 _global_import('gradproc')
 _global_import('argscope')
+_global_import('tower')
......
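Registering 'tower' here is what makes TowerContext importable as from ..tfutils import TowerContext in the predictor and trainer hunks below. _global_import itself is not part of this diff; a plausible sketch of such a re-export helper (an assumption, not the actual implementation):

    def _global_import(name):
        # import a sibling submodule and re-export its public names
        p = __import__(name, globals(), None, level=1)
        lst = p.__all__ if '__all__' in dir(p) else dir(p)
        for k in lst:
            globals()[k] = p.__dict__[k]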
@@ -6,6 +6,8 @@ import tensorflow as tf
 
 from ..utils import logger
 
+__all__ = ['describe_model', 'get_shape_str']
+
 def describe_model():
     """ print a description of the current model parameters """
     train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
......
@@ -7,6 +7,7 @@ import tensorflow as tf
 import re
 
 from ..utils import *
+from .tower import get_current_tower_context
 from . import get_global_step_var
 from .symbolic_functions import rms
@@ -28,6 +29,8 @@ def add_activation_summary(x, name=None):
     Add summary to graph for an activation tensor x.
     If name is None, use x.name.
     """
+    if not get_current_tower_context().is_main_training_tower:
+        return
     ndim = x.get_shape().ndims
     assert ndim >= 2, \
         "Summary a scalar with histogram? Maybe use scalar instead. FIXME!"
@@ -46,6 +49,8 @@ def add_param_summary(summary_lists):
     :param summary_lists: list of (regex, [list of summary type to perform]).
            Type can be 'mean', 'scalar', 'histogram', 'sparsity', 'rms'
     """
+    if not get_current_tower_context().is_main_training_tower:
+        return
     def perform(var, action):
         ndim = var.get_shape().ndims
         name = var.name.replace(':0', '')
@@ -84,6 +89,8 @@ def add_moving_summary(v, *args):
     :param v: tensor or list of tensor to summary
     :param args: tensors to summary
     """
+    if not get_current_tower_context().is_main_training_tower:
+        return
     if not isinstance(v, list):
         v = [v]
     v.extend(args)
......
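With these three guards, summary ops are only created in the main training tower; in replica towers the calls return immediately, so multi-GPU training does not emit duplicate summaries. Note the guards assume a surrounding TowerContext: outside one, get_current_tower_context() returns None and the attribute access would fail, which is why the last hunk of this commit moves add_moving_summary inside the trainer's with TowerContext('') block. Illustratively:

    with TowerContext('tower1'):        # a non-main replica tower
        add_moving_summary(cost_var)    # no-op: is_main_training_tower is False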
new file: tensorpack/tfutils/tower.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: tower.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import re
import tensorflow as tf

__all__ = ['get_current_tower_context', 'TowerContext']

_CurrentTowerContext = None


class TowerContext(object):
    def __init__(self, tower_name, is_training=None):
        """ tower_name: 'tower0', 'towerp0', or '' """
        self._name = tower_name
        if is_training is None:
            is_training = not self._name.startswith('towerp')
        self._is_training = is_training

    @property
    def is_main_training_tower(self):
        return self.is_training and (self._name == '' or self._name == 'tower0')

    @property
    def is_main_tower(self):
        return self._name == '' or self._name == 'tower0'

    @property
    def is_training(self):
        return self._is_training

    @property
    def name(self):
        return self._name

    def get_variable_on_tower(self, *args, **kwargs):
        """
        Get a variable for this tower specifically, without reusing.
        Tensorflow doesn't allow a reuse=False scope under a
        reuse=True scope. This method provides a workaround.
        See https://www.tensorflow.org/versions/master/how_tos/variable_scope/index.html#basics-of-tfvariable-scope
        :param args, kwargs: same as tf.get_variable()
        """
        with tf.variable_scope(self._name) as scope:
            with tf.variable_scope(scope, reuse=False):
                scope = tf.get_variable_scope()
                assert scope.reuse == False
                return tf.get_variable(*args, **kwargs)

    def find_tensor_in_main_tower(self, graph, name):
        if self.is_main_tower:
            return graph.get_tensor_by_name(name)
        if name.startswith('towerp'):
            newname = re.sub('towerp[0-9]+/', '', name)
            try:
                return graph.get_tensor_by_name(newname)
            except KeyError:
                newname = re.sub('towerp[0-9]+/', 'tower0/', name)
                return graph.get_tensor_by_name(newname)

    def __enter__(self):
        global _CurrentTowerContext
        assert _CurrentTowerContext is None, \
            "Nesting TowerContext!"
        _CurrentTowerContext = self
        if len(self._name):
            self._scope = tf.name_scope(self._name)
            return self._scope.__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        global _CurrentTowerContext
        _CurrentTowerContext = None
        if len(self._name):
            self._scope.__exit__(exc_type, exc_val, exc_tb)
        return False


def get_current_tower_context():
    global _CurrentTowerContext
    return _CurrentTowerContext
......
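A usage sketch of the relocated context, mirroring how the trainers below use it (illustrative only):

    with TowerContext('tower0', is_training=True):
        ctx = get_current_tower_context()
        assert ctx.is_main_training_tower
        # ops built here land under the 'tower0/' name scope, and the
        # summary helpers above are live because this is the main tower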
@@ -7,14 +7,13 @@ import tensorflow as tf
 import itertools, re
 from six.moves import zip, range
 
-from ..models import TowerContext
 from ..utils import logger
 from ..utils.naming import *
 from ..utils.concurrency import LoopThread
 from ..tfutils.summary import summary_moving_average
 from ..tfutils.modelutils import describe_model
 from ..tfutils import (backup_collection, restore_collection,
-        get_global_step_var)
+        get_global_step_var, TowerContext)
 
 from .trainer import QueueInputTrainer
......
@@ -11,10 +11,9 @@ from .base import Trainer
 from ..dataflow.common import RepeatedData
 
-from ..models import TowerContext
 from ..utils import logger, SUMMARY_BACKUP_KEYS
 from ..tfutils import (get_vars_by_names, freeze_collection,
-        get_global_step_var)
+        get_global_step_var, TowerContext)
 from ..tfutils.summary import summary_moving_average, add_moving_summary
 from ..tfutils.modelutils import describe_model
 from ..predict import OnlinePredictor, build_multi_tower_prediction_graph
@@ -67,7 +66,7 @@ class SimpleTrainer(Trainer):
         with TowerContext(''):
             model.build_graph(self.input_vars)
             cost_var = model.get_cost()  # TODO assert scalar
-        add_moving_summary(cost_var)
+            add_moving_summary(cost_var)
         grads = self.config.optimizer.compute_gradients(cost_var)
         grads = self.process_grads(grads)
......