Commit 6e080435 authored by Yuxin Wu's avatar Yuxin Wu

small fixes and imports

parent 2634a254
......@@ -5,20 +5,13 @@
import numpy # avoid https://github.com/tensorflow/tensorflow/issues/2034
import cv2 # avoid https://github.com/tensorflow/tensorflow/issues/1924
from . import models
from . import train
from . import utils
from . import tfutils
from . import callbacks
from . import dataflow
from .train import *
from .models import *
from .utils import *
from .tfutils import *
from .callbacks import *
from .dataflow import *
from .predict import *
from tensorpack.train import *
from tensorpack.models import *
from tensorpack.utils import *
from tensorpack.tfutils import *
from tensorpack.callbacks import *
from tensorpack.dataflow import *
from tensorpack.predict import *
if int(numpy.__version__.split('.')[1]) < 9:
logger.warn("Numpy < 1.9 could be extremely slow on some tasks.")
......@@ -7,7 +7,7 @@ import os, shutil
import re
from .base import Callback
from ..utils import *
from ..utils import logger
from ..tfutils.varmanip import get_savename_from_varname
__all__ = ['ModelSaver', 'MinSaver', 'MaxSaver']
......
......@@ -393,7 +393,6 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
yield v
return
def SelectComponent(ds, idxs):
"""
:param ds: a :mod:`DataFlow` instance
......
......@@ -65,7 +65,8 @@ class DataFromList(RNGDataFlow):
for k in self.lst:
yield k
else:
idxs = self.rng.shuffle(np.arange(len(self.lst)))
idxs = np.arange(len(self.lst))
self.rng.shuffle(idxs)
for k in idxs:
yield self.lst[k]
......
......@@ -10,11 +10,11 @@ import tqdm
import tensorflow as tf
from .config import TrainConfig
from ..utils import *
from ..utils.timer import *
from ..utils import logger, get_tqdm_kwargs
from ..utils.timer import timed_operation
from ..utils.concurrency import start_proc_mask_signal
from ..callbacks import StatHolder
from ..tfutils import *
from ..tfutils import get_global_step, get_global_step_var
from ..tfutils.summary import create_summary
__all__ = ['Trainer']
......@@ -105,11 +105,12 @@ class Trainer(object):
def main_loop(self):
# some final operations that might modify the graph
self._init_summary()
get_global_step_var() # ensure there is such var, before finalizing the graph
logger.info("Setup callbacks ...")
callbacks = self.config.callbacks
callbacks.setup_graph(self) # TODO use weakref instead?
self._init_summary()
logger.info("Initializing graph variables ...")
self.sess.run(tf.initialize_all_variables())
self.config.session_init.init(self.sess)
......
......@@ -6,8 +6,9 @@ import tensorflow as tf
from ..callbacks import Callbacks
from ..models import ModelDesc
from ..utils import *
from ..tfutils import *
from ..utils import logger
from ..tfutils import (JustCurrentSession,
get_default_sess_config, SessionInit)
from ..dataflow import DataFlow
__all__ = ['TrainConfig']
......
......@@ -8,11 +8,12 @@ import itertools, re
from six.moves import zip, range
from ..models import TowerContext
from ..utils import *
from ..utils import logger
from ..utils.concurrency import LoopThread
from ..tfutils.summary import summary_moving_average
from ..tfutils.modelutils import describe_model
from ..tfutils import *
from ..tfutils import (backup_collection, restore_collection,
get_global_step_var)
from .trainer import QueueInputTrainer
......
......@@ -12,8 +12,9 @@ from .base import Trainer
from ..dataflow.common import RepeatedData
from ..models import TowerContext
from ..utils import *
from ..tfutils import *
from ..utils import logger, SUMMARY_BACKUP_KEYS
from ..tfutils import (get_vars_by_names, freeze_collection,
get_global_step_var)
from ..tfutils.summary import summary_moving_average, add_moving_summary
from ..tfutils.modelutils import describe_model
from ..predict import OnlinePredictor, build_multi_tower_prediction_graph
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment