Commit 6e080435 authored by Yuxin Wu's avatar Yuxin Wu

small fixes and imports

parent 2634a254
...@@ -5,20 +5,13 @@ ...@@ -5,20 +5,13 @@
import numpy # avoid https://github.com/tensorflow/tensorflow/issues/2034 import numpy # avoid https://github.com/tensorflow/tensorflow/issues/2034
import cv2 # avoid https://github.com/tensorflow/tensorflow/issues/1924 import cv2 # avoid https://github.com/tensorflow/tensorflow/issues/1924
from . import models from tensorpack.train import *
from . import train from tensorpack.models import *
from . import utils from tensorpack.utils import *
from . import tfutils from tensorpack.tfutils import *
from . import callbacks from tensorpack.callbacks import *
from . import dataflow from tensorpack.dataflow import *
from tensorpack.predict import *
from .train import *
from .models import *
from .utils import *
from .tfutils import *
from .callbacks import *
from .dataflow import *
from .predict import *
if int(numpy.__version__.split('.')[1]) < 9: if int(numpy.__version__.split('.')[1]) < 9:
logger.warn("Numpy < 1.9 could be extremely slow on some tasks.") logger.warn("Numpy < 1.9 could be extremely slow on some tasks.")
...@@ -7,7 +7,7 @@ import os, shutil ...@@ -7,7 +7,7 @@ import os, shutil
import re import re
from .base import Callback from .base import Callback
from ..utils import * from ..utils import logger
from ..tfutils.varmanip import get_savename_from_varname from ..tfutils.varmanip import get_savename_from_varname
__all__ = ['ModelSaver', 'MinSaver', 'MaxSaver'] __all__ = ['ModelSaver', 'MinSaver', 'MaxSaver']
......
...@@ -393,7 +393,6 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow): ...@@ -393,7 +393,6 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
yield v yield v
return return
def SelectComponent(ds, idxs): def SelectComponent(ds, idxs):
""" """
:param ds: a :mod:`DataFlow` instance :param ds: a :mod:`DataFlow` instance
......
...@@ -65,7 +65,8 @@ class DataFromList(RNGDataFlow): ...@@ -65,7 +65,8 @@ class DataFromList(RNGDataFlow):
for k in self.lst: for k in self.lst:
yield k yield k
else: else:
idxs = self.rng.shuffle(np.arange(len(self.lst))) idxs = np.arange(len(self.lst))
self.rng.shuffle(idxs)
for k in idxs: for k in idxs:
yield self.lst[k] yield self.lst[k]
......
...@@ -10,11 +10,11 @@ import tqdm ...@@ -10,11 +10,11 @@ import tqdm
import tensorflow as tf import tensorflow as tf
from .config import TrainConfig from .config import TrainConfig
from ..utils import * from ..utils import logger, get_tqdm_kwargs
from ..utils.timer import * from ..utils.timer import timed_operation
from ..utils.concurrency import start_proc_mask_signal from ..utils.concurrency import start_proc_mask_signal
from ..callbacks import StatHolder from ..callbacks import StatHolder
from ..tfutils import * from ..tfutils import get_global_step, get_global_step_var
from ..tfutils.summary import create_summary from ..tfutils.summary import create_summary
__all__ = ['Trainer'] __all__ = ['Trainer']
...@@ -105,11 +105,12 @@ class Trainer(object): ...@@ -105,11 +105,12 @@ class Trainer(object):
def main_loop(self): def main_loop(self):
# some final operations that might modify the graph # some final operations that might modify the graph
self._init_summary()
get_global_step_var() # ensure there is such var, before finalizing the graph get_global_step_var() # ensure there is such var, before finalizing the graph
logger.info("Setup callbacks ...") logger.info("Setup callbacks ...")
callbacks = self.config.callbacks callbacks = self.config.callbacks
callbacks.setup_graph(self) # TODO use weakref instead? callbacks.setup_graph(self) # TODO use weakref instead?
self._init_summary()
logger.info("Initializing graph variables ...") logger.info("Initializing graph variables ...")
self.sess.run(tf.initialize_all_variables()) self.sess.run(tf.initialize_all_variables())
self.config.session_init.init(self.sess) self.config.session_init.init(self.sess)
......
...@@ -6,8 +6,9 @@ import tensorflow as tf ...@@ -6,8 +6,9 @@ import tensorflow as tf
from ..callbacks import Callbacks from ..callbacks import Callbacks
from ..models import ModelDesc from ..models import ModelDesc
from ..utils import * from ..utils import logger
from ..tfutils import * from ..tfutils import (JustCurrentSession,
get_default_sess_config, SessionInit)
from ..dataflow import DataFlow from ..dataflow import DataFlow
__all__ = ['TrainConfig'] __all__ = ['TrainConfig']
......
...@@ -8,11 +8,12 @@ import itertools, re ...@@ -8,11 +8,12 @@ import itertools, re
from six.moves import zip, range from six.moves import zip, range
from ..models import TowerContext from ..models import TowerContext
from ..utils import * from ..utils import logger
from ..utils.concurrency import LoopThread from ..utils.concurrency import LoopThread
from ..tfutils.summary import summary_moving_average from ..tfutils.summary import summary_moving_average
from ..tfutils.modelutils import describe_model from ..tfutils.modelutils import describe_model
from ..tfutils import * from ..tfutils import (backup_collection, restore_collection,
get_global_step_var)
from .trainer import QueueInputTrainer from .trainer import QueueInputTrainer
......
...@@ -12,8 +12,9 @@ from .base import Trainer ...@@ -12,8 +12,9 @@ from .base import Trainer
from ..dataflow.common import RepeatedData from ..dataflow.common import RepeatedData
from ..models import TowerContext from ..models import TowerContext
from ..utils import * from ..utils import logger, SUMMARY_BACKUP_KEYS
from ..tfutils import * from ..tfutils import (get_vars_by_names, freeze_collection,
get_global_step_var)
from ..tfutils.summary import summary_moving_average, add_moving_summary from ..tfutils.summary import summary_moving_average, add_moving_summary
from ..tfutils.modelutils import describe_model from ..tfutils.modelutils import describe_model
from ..predict import OnlinePredictor, build_multi_tower_prediction_graph from ..predict import OnlinePredictor, build_multi_tower_prediction_graph
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment