Commit dfacc168 authored by Yuxin Wu

clean imports

parent bcf8dbfe
......@@ -11,8 +11,10 @@ try:
gym.undo_logger_setup()
# https://github.com/openai/gym/pull/199
# not sure whether it causes other problems
__all__ = ['GymEnv']
except ImportError:
logger.warn("Cannot import gym. GymEnv won't be available.")
__all__ = []
import threading
......@@ -20,7 +22,6 @@ from ..utils.fs import *
from ..utils.stat import *
from .envbase import RLEnvironment, DiscreteActionSpace
__all__ = ['GymEnv']
_ALE_LOCK = threading.Lock()
......
......@@ -8,8 +8,6 @@ import os
import time
from abc import abstractmethod, ABCMeta
from ..utils import *
__all__ = ['Callback', 'PeriodicCallback']
class Callback(object):
......
......@@ -7,8 +7,8 @@ from contextlib import contextmanager
import time
from .base import Callback
from .stat import *
from ..utils import *
from .stat import StatPrinter
from ..utils import logger
__all__ = ['Callbacks']
......
......@@ -11,10 +11,9 @@ import six
from six.moves import zip, map
from ..dataflow import DataFlow
from ..utils import *
from ..utils.stat import *
from ..tfutils import *
from ..tfutils.summary import *
from ..utils import get_tqdm_kwargs, logger
from ..utils.stat import RatioCounter, BinaryStatistics
from ..tfutils import get_op_tensor_name
from .base import Callback
__all__ = ['InferenceRunner', 'ClassificationError',
......
......@@ -8,7 +8,7 @@ import operator
import json
from .base import Callback
from ..utils import *
from ..utils import logger
__all__ = ['StatHolder', 'StatPrinter', 'SendStat']
......
......@@ -7,9 +7,9 @@ from functools import wraps
import six
import copy, os
from ..tfutils import *
from ..tfutils.modelutils import *
from ..tfutils.summary import *
from ..tfutils.argscope import get_arg_scope
from ..tfutils.modelutils import get_shape_str
from ..tfutils.summary import add_activation_summary
from ..utils import logger
# make sure each layer is only logged once
......
......@@ -7,7 +7,7 @@ import tensorflow as tf
from copy import copy
import re
from .model_desc import get_current_tower_context
from ..tfutils.tower import get_current_tower_context
from ..utils import logger, EXTRA_SAVE_VARS_KEY
from ._common import layer_register
......
......@@ -6,7 +6,7 @@
import numpy as np
import tensorflow as tf
import math
from ._common import *
from ._common import layer_register, shape2d, shape4d
from ..utils import map_arg, logger
__all__ = ['Conv2D']
......
......@@ -7,7 +7,7 @@ import tensorflow as tf
import math
from ._common import layer_register
from ..tfutils.symbolic_functions import *
from ..tfutils import symbolic_functions as symbf
__all__ = ['FullyConnected']
......@@ -26,7 +26,7 @@ def FullyConnected(x, out_dim,
:param use_bias: whether to use bias. a boolean default to True
:returns: a 2D tensor
"""
x = batch_flatten(x)
x = symbf.batch_flatten(x)
in_dim = x.get_shape().as_list()[1]
if W_init is None:
......
......@@ -10,85 +10,13 @@ from collections import namedtuple
import inspect
from ..utils import logger, INPUT_VARS_KEY
from ..tfutils import *
from ..tfutils.common import get_vars_by_names
from ..tfutils.gradproc import CheckGradient
__all__ = ['ModelDesc', 'InputVar', 'ModelFromMetaGraph',
'get_current_tower_context', 'TowerContext']
__all__ = ['ModelDesc', 'InputVar', 'ModelFromMetaGraph' ]
InputVar = namedtuple('InputVar', ['type', 'shape', 'name'])
_CurrentTowerContext = None


class TowerContext(object):
    """
    Describes one tower by its name and training flag. Used as a context
    manager, it records itself as the module-wide current tower for the
    duration of the ``with`` block.
    """

    def __init__(self, tower_name, is_training=None):
        """
        :param tower_name: 'tower0', 'towerp0', or ''
        :param is_training: training flag of the tower. When None, the tower
            counts as training unless its name starts with 'towerp'.
        """
        self._name = tower_name
        if is_training is not None:
            self._is_training = is_training
        else:
            self._is_training = not self._name.startswith('towerp')

    @property
    def is_main_training_tower(self):
        """ The tower is a training tower and is named '' or 'tower0'. """
        return self.is_training and self.is_main_tower

    @property
    def is_main_tower(self):
        """ The tower is named '' or 'tower0'. """
        return self._name in ('', 'tower0')

    @property
    def is_training(self):
        """ The training flag fixed at construction time. """
        return self._is_training

    @property
    def name(self):
        """ The tower name given at construction time. """
        return self._name

    def get_variable_on_tower(self, *args, **kwargs):
        """
        Get a variable for this tower specifically, without reusing.
        Tensorflow doesn't allow reuse=False scope under a
        reuse=True scope. This method provides a work around.
        See https://www.tensorflow.org/versions/master/how_tos/variable_scope/index.html#basics-of-tfvariable-scope

        :param args, kwargs: same as tf.get_variable()
        """
        with tf.variable_scope(self._name) as tower_scope:
            # re-open the captured scope object with reuse switched off
            with tf.variable_scope(tower_scope, reuse=False):
                active_scope = tf.get_variable_scope()
                assert active_scope.reuse == False
                return tf.get_variable(*args, **kwargs)

    def find_tensor_in_main_tower(self, graph, name):
        """
        Look up the tensor of the main tower that corresponds to `name`.

        :param graph: object providing get_tensor_by_name()
        :param name: tensor name, possibly prefixed with 'towerpN/'
        :returns: the tensor found in `graph`; None when this is not a main
            tower and `name` does not start with 'towerp'
        """
        if self.is_main_tower:
            return graph.get_tensor_by_name(name)
        if not name.startswith('towerp'):
            return None
        pattern = 'towerp[0-9]+/'
        try:
            # first try the name with the tower prefix removed ...
            return graph.get_tensor_by_name(re.sub(pattern, '', name))
        except KeyError:
            # ... then the same name under 'tower0/'
            return graph.get_tensor_by_name(re.sub(pattern, 'tower0/', name))

    def __enter__(self):
        """
        Register this context as the current one; for a non-empty tower name
        also enter a tf.name_scope of that name and return its result.
        """
        global _CurrentTowerContext
        assert _CurrentTowerContext is None, "Nesting TowerContext!"
        _CurrentTowerContext = self
        if len(self._name) > 0:
            self._scope = tf.name_scope(self._name)
            return self._scope.__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        """ Clear the current context, leave the name scope if one was entered. """
        global _CurrentTowerContext
        _CurrentTowerContext = None
        if len(self._name) > 0:
            self._scope.__exit__(exc_type, exc_val, exc_tb)
        # never swallow an exception raised inside the with-block
        return False


def get_current_tower_context():
    """ :returns: the TowerContext currently entered, or None. """
    return _CurrentTowerContext
class ModelDesc(object):
""" Base class for a model description """
__metaclass__ = ABCMeta
......
......@@ -6,7 +6,7 @@
import tensorflow as tf
from copy import copy
from ._common import *
from ._common import layer_register
from .batch_norm import BatchNorm
__all__ = ['Maxout', 'PReLU', 'LeakyReLU', 'BNReLU']
......
......@@ -5,8 +5,8 @@
import tensorflow as tf
import numpy
from ._common import *
from ..tfutils.symbolic_functions import *
from ._common import layer_register, shape2d, shape4d
from ..tfutils import symbolic_functions as symbf
__all__ = ['MaxPooling', 'FixedUnPooling', 'AvgPooling', 'GlobalAvgPooling',
'BilinearUpSample']
......@@ -105,9 +105,9 @@ def FixedUnPooling(x, shape, unpool_mat=None):
assert unpool_mat.get_shape().as_list() == list(shape)
# perform a tensor-matrix kronecker product
fx = flatten(tf.transpose(x, [0, 3, 1, 2]))
fx = symbf.flatten(tf.transpose(x, [0, 3, 1, 2]))
fx = tf.expand_dims(fx, -1) # (bchw)x1
mat = tf.expand_dims(flatten(unpool_mat), 0) #1x(shxsw)
mat = tf.expand_dims(symbf.flatten(unpool_mat), 0) #1x(shxsw)
prod = tf.matmul(fx, mat) #(bchw) x(shxsw)
prod = tf.reshape(prod, tf.pack(
[-1, input_shape[3], input_shape[1], input_shape[2], shape[0], shape[1]]))
......
......@@ -6,8 +6,8 @@ import tensorflow as tf
import re
from ..utils import logger
from ..utils.utils import *
from .model_desc import get_current_tower_context
from ..utils.utils import memoized
from ..tfutils.tower import get_current_tower_context
from ._common import layer_register
__all__ = ['regularize_cost', 'l2_regularizer', 'l1_regularizer', 'Dropout']
......
......@@ -7,9 +7,8 @@ from abc import abstractmethod, ABCMeta, abstractproperty
import tensorflow as tf
import six
from ..models import TowerContext
from ..utils import logger
from ..tfutils import get_vars_by_names
from ..tfutils import get_vars_by_names, TowerContext
__all__ = ['OnlinePredictor', 'OfflinePredictor',
'AsyncPredictorBase',
......
......@@ -9,7 +9,8 @@ from six.moves import zip
from tensorpack.models import ModelDesc
from ..utils import logger
from ..tfutils import *
from ..tfutils import get_default_sess_config
from ..tfutils.sessinit import SessionInit, JustCurrentSession
from .base import OfflinePredictor
import multiprocessing
......
......@@ -9,12 +9,9 @@ import time
import six
from six.moves import queue, range, zip
from ..utils.concurrency import DIE
from ..tfutils.modelutils import describe_model
from ..utils import logger
from ..utils.timer import *
from ..tfutils import *
from .base import *
......
......@@ -17,4 +17,5 @@ _global_import('sessinit')
_global_import('common')
_global_import('gradproc')
_global_import('argscope')
_global_import('tower')
......@@ -6,6 +6,8 @@ import tensorflow as tf
from ..utils import logger
__all__ = ['describe_model', 'get_shape_str']
def describe_model():
""" print a description of the current model parameters """
train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
......
......@@ -7,6 +7,7 @@ import tensorflow as tf
import re
from ..utils import *
from .tower import get_current_tower_context
from . import get_global_step_var
from .symbolic_functions import rms
......@@ -28,6 +29,8 @@ def add_activation_summary(x, name=None):
Add summary to graph for an activation tensor x.
If name is None, use x.name.
"""
if not get_current_tower_context().is_main_training_tower:
return
ndim = x.get_shape().ndims
assert ndim >= 2, \
"Summary a scalar with histogram? Maybe use scalar instead. FIXME!"
......@@ -46,6 +49,8 @@ def add_param_summary(summary_lists):
:param summary_lists: list of (regex, [list of summary type to perform]).
Type can be 'mean', 'scalar', 'histogram', 'sparsity', 'rms'
"""
if not get_current_tower_context().is_main_training_tower:
return
def perform(var, action):
ndim = var.get_shape().ndims
name = var.name.replace(':0', '')
......@@ -84,6 +89,8 @@ def add_moving_summary(v, *args):
:param v: tensor or list of tensor to summary
:param args: tensors to summary
"""
if not get_current_tower_context().is_main_training_tower:
return
if not isinstance(v, list):
v = [v]
v.extend(args)
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: tower.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
__all__ = ['get_current_tower_context', 'TowerContext']
_CurrentTowerContext = None


class TowerContext(object):
    """
    Holds the name and training flag of one tower. When used in a ``with``
    statement it becomes the module-wide current tower, retrievable through
    get_current_tower_context(), until the block is left.
    """

    def __init__(self, tower_name, is_training=None):
        """
        :param tower_name: 'tower0', 'towerp0', or ''
        :param is_training: training flag of the tower. If None, it is
            derived from the name: names starting with 'towerp' are
            treated as non-training, everything else as training.
        """
        self._name = tower_name
        if is_training is None:
            is_training = not self._name.startswith('towerp')
        self._is_training = is_training

    @property
    def is_main_training_tower(self):
        """ True for a training tower whose name is '' or 'tower0'. """
        return self.is_training and self.is_main_tower

    @property
    def is_main_tower(self):
        """ True when the tower name is '' or 'tower0'. """
        return self._name == '' or self._name == 'tower0'

    @property
    def is_training(self):
        """ The training flag decided in the constructor. """
        return self._is_training

    @property
    def name(self):
        """ The tower name passed to the constructor. """
        return self._name

    def get_variable_on_tower(self, *args, **kwargs):
        """
        Get a variable for this tower specifically, without reusing.
        Tensorflow doesn't allow reuse=False scope under a
        reuse=True scope. This method provides a work around.
        See https://www.tensorflow.org/versions/master/how_tos/variable_scope/index.html#basics-of-tfvariable-scope

        :param args, kwargs: same as tf.get_variable()
        :returns: the result of tf.get_variable(*args, **kwargs)
        """
        with tf.variable_scope(self._name) as scope:
            # re-enter the captured scope object with reuse turned off
            with tf.variable_scope(scope, reuse=False):
                scope = tf.get_variable_scope()
                assert scope.reuse == False
                return tf.get_variable(*args, **kwargs)

    def find_tensor_in_main_tower(self, graph, name):
        """
        Find the tensor of the main tower that corresponds to `name`.

        On a main tower `name` is looked up unchanged. Otherwise, when
        `name` starts with 'towerp', the 'towerpN/' prefix is stripped and,
        if that lookup raises KeyError, replaced by 'tower0/'.

        :param graph: object providing get_tensor_by_name()
        :param name: name of the tensor to look for
        :returns: the tensor found in `graph`, or None if this is not a main
            tower and `name` does not start with 'towerp'
        :raises KeyError: propagated from the last lookup attempted
        """
        # This module only imports tensorflow at file level, so bring `re`
        # into scope here; without it this method fails with NameError.
        import re

        if self.is_main_tower:
            return graph.get_tensor_by_name(name)
        if name.startswith('towerp'):
            newname = re.sub('towerp[0-9]+/', '', name)
            try:
                return graph.get_tensor_by_name(newname)
            except KeyError:
                newname = re.sub('towerp[0-9]+/', 'tower0/', name)
                return graph.get_tensor_by_name(newname)
        return None

    def __enter__(self):
        """
        Make this the current tower context. For a non-empty tower name a
        tf.name_scope of that name is entered as well.

        :returns: the result of entering the name scope, or None for an
            empty tower name
        """
        global _CurrentTowerContext
        assert _CurrentTowerContext is None, \
            "Nesting TowerContext!"
        entered = None
        if len(self._name):
            self._scope = tf.name_scope(self._name)
            entered = self._scope.__enter__()
        # Register only after the scope was entered successfully: if entering
        # raised, __exit__ would never run and a stale global would make
        # every later TowerContext fail the nesting assert above.
        _CurrentTowerContext = self
        return entered

    def __exit__(self, exc_type, exc_val, exc_tb):
        """ Unregister the context and leave the name scope, if any. """
        global _CurrentTowerContext
        _CurrentTowerContext = None
        if len(self._name):
            self._scope.__exit__(exc_type, exc_val, exc_tb)
        # do not suppress exceptions raised inside the with-block
        return False


def get_current_tower_context():
    """ :returns: the TowerContext currently entered, or None if there is none. """
    return _CurrentTowerContext
......@@ -7,14 +7,13 @@ import tensorflow as tf
import itertools, re
from six.moves import zip, range
from ..models import TowerContext
from ..utils import logger
from ..utils.naming import *
from ..utils.concurrency import LoopThread
from ..tfutils.summary import summary_moving_average
from ..tfutils.modelutils import describe_model
from ..tfutils import (backup_collection, restore_collection,
get_global_step_var)
get_global_step_var, TowerContext)
from .trainer import QueueInputTrainer
......
......@@ -11,10 +11,9 @@ from .base import Trainer
from ..dataflow.common import RepeatedData
from ..models import TowerContext
from ..utils import logger, SUMMARY_BACKUP_KEYS
from ..tfutils import (get_vars_by_names, freeze_collection,
get_global_step_var)
get_global_step_var, TowerContext)
from ..tfutils.summary import summary_moving_average, add_moving_summary
from ..tfutils.modelutils import describe_model
from ..predict import OnlinePredictor, build_multi_tower_prediction_graph
......@@ -67,7 +66,7 @@ class SimpleTrainer(Trainer):
with TowerContext(''):
model.build_graph(self.input_vars)
cost_var = model.get_cost() # TODO assert scalar
add_moving_summary(cost_var)
add_moving_summary(cost_var)
grads = self.config.optimizer.compute_gradients(cost_var)
grads = self.process_grads(grads)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment