Commit fd635774 authored by Yuxin Wu's avatar Yuxin Wu

Python3 compatibility on import statements

parent e02d310c
...@@ -7,7 +7,7 @@ from pkgutil import walk_packages ...@@ -7,7 +7,7 @@ from pkgutil import walk_packages
import os import os
def global_import(name): def global_import(name):
p = __import__(name, globals(), locals()) p = __import__(name, globals(), locals(), level=1)
lst = p.__all__ if '__all__' in dir(p) else dir(p) lst = p.__all__ if '__all__' in dir(p) else dir(p)
for k in lst: for k in lst:
globals()[k] = p.__dict__[k] globals()[k] = p.__dict__[k]
......
...@@ -3,9 +3,11 @@ ...@@ -3,9 +3,11 @@
# File: dump.py # File: dump.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
from .base import Callback
import cv2
import os import os
import scipy.misc
from scipy.misc import imsave
from .base import Callback
from ..utils import logger from ..utils import logger
__all__ = ['DumpParamAsImage'] __all__ = ['DumpParamAsImage']
...@@ -39,12 +41,12 @@ class DumpParamAsImage(Callback): ...@@ -39,12 +41,12 @@ class DumpParamAsImage(Callback):
fname = os.path.join( fname = os.path.join(
self.log_dir, self.log_dir,
self.prefix + '-ep{:03d}-{}.png'.format(self.epoch_num, idx)) self.prefix + '-ep{:03d}-{}.png'.format(self.epoch_num, idx))
cv2.imwrite(fname, im * self.scale) imsave(fname, (im * self.scale).astype('uint8'))
else: else:
im = val im = val
assert im.ndim in [2, 3] assert im.ndim in [2, 3]
fname = os.path.join( fname = os.path.join(
self.log_dir, self.log_dir,
self.prefix + '-ep{:03d}.png'.format(self.epoch_num)) self.prefix + '-ep{:03d}.png'.format(self.epoch_num))
cv2.imwrite(fname, im * self.scale) imsave(fname, (im * self.scale).astype('uint8'))
...@@ -7,7 +7,7 @@ import tensorflow as tf ...@@ -7,7 +7,7 @@ import tensorflow as tf
import re import re
import os import os
import operator import operator
import cPickle as pickle import pickle
from .base import Callback, PeriodicCallback from .base import Callback, PeriodicCallback
from ..utils import * from ..utils import *
......
...@@ -6,20 +6,20 @@ ...@@ -6,20 +6,20 @@
from pkgutil import walk_packages from pkgutil import walk_packages
import os import os
import os.path import os.path
import dataset
import imgaug
__SKIP = ['dftools', 'dataset'] from . import dataset
from . import imgaug
def global_import(name): def global_import(name):
if name in __SKIP: p = __import__(name, globals(), locals(), level=1)
return
p = __import__(name, globals(), locals())
lst = p.__all__ if '__all__' in dir(p) else dir(p) lst = p.__all__ if '__all__' in dir(p) else dir(p)
for k in lst: for k in lst:
globals()[k] = p.__dict__[k] globals()[k] = p.__dict__[k]
__SKIP = ['dftools', 'dataset']
for _, module_name, _ in walk_packages( for _, module_name, _ in walk_packages(
[os.path.dirname(__file__)]): [os.path.dirname(__file__)]):
if not module_name.startswith('_'): if not module_name.startswith('_') and \
module_name not in __SKIP:
global_import(module_name) global_import(module_name)
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
import numpy as np import numpy as np
from .base import DataFlow from .base import DataFlow
from imgaug import AugmentorList, Image from .imgaug import AugmentorList, Image
__all__ = ['BatchData', 'FixedSizeData', 'FakeData', 'MapData', __all__ = ['BatchData', 'FixedSizeData', 'FakeData', 'MapData',
'AugmentImageComponent'] 'AugmentImageComponent']
...@@ -94,7 +94,7 @@ class RepeatedData(DataFlow): ...@@ -94,7 +94,7 @@ class RepeatedData(DataFlow):
def size(self): def size(self):
if self.nr == -1: if self.nr == -1:
raise RuntimeError(), "size() is unavailable for infinite dataflow" raise RuntimeError("size() is unavailable for infinite dataflow")
return self.ds.size() * self.nr return self.ds.size() * self.nr
def get_data(self): def get_data(self):
......
...@@ -8,7 +8,7 @@ import os ...@@ -8,7 +8,7 @@ import os
import os.path import os.path
def global_import(name): def global_import(name):
p = __import__(name, globals(), locals()) p = __import__(name, globals(), locals(), level=1)
lst = p.__all__ if '__all__' in dir(p) else dir(p) lst = p.__all__ if '__all__' in dir(p) else dir(p)
for k in lst: for k in lst:
globals()[k] = p.__dict__[k] globals()[k] = p.__dict__[k]
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
# File: cifar10.py # File: cifar10.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import os, sys import os, sys
import cPickle import pickle
import numpy import numpy
from six.moves import urllib from six.moves import urllib
import tarfile import tarfile
...@@ -41,7 +41,7 @@ def maybe_download_and_extract(dest_directory): ...@@ -41,7 +41,7 @@ def maybe_download_and_extract(dest_directory):
def read_cifar10(filenames): def read_cifar10(filenames):
for fname in filenames: for fname in filenames:
fo = open(fname, 'rb') fo = open(fname, 'rb')
dic = cPickle.load(fo) dic = pickle.load(fo)
data = dic['data'] data = dic['data']
label = dic['labels'] label = dic['labels']
fo.close() fo.close()
......
...@@ -9,7 +9,7 @@ from pkgutil import walk_packages ...@@ -9,7 +9,7 @@ from pkgutil import walk_packages
__all__ = [] __all__ = []
def global_import(name): def global_import(name):
p = __import__(name, globals(), locals()) p = __import__(name, globals(), locals(), level=1)
lst = p.__all__ if '__all__' in dir(p) else dir(p) lst = p.__all__ if '__all__' in dir(p) else dir(p)
for k in lst: for k in lst:
globals()[k] = p.__dict__[k] globals()[k] = p.__dict__[k]
......
...@@ -8,7 +8,7 @@ import os ...@@ -8,7 +8,7 @@ import os
import os.path import os.path
def global_import(name): def global_import(name):
p = __import__(name, globals(), locals()) p = __import__(name, globals(), locals(), level=1)
lst = p.__all__ if '__all__' in dir(p) else dir(p) lst = p.__all__ if '__all__' in dir(p) else dir(p)
for k in lst: for k in lst:
globals()[k] = p.__dict__[k] globals()[k] = p.__dict__[k]
......
...@@ -82,7 +82,7 @@ def ImageSample(inputs): ...@@ -82,7 +82,7 @@ def ImageSample(inputs):
sample(template, lyux) * neg_diffy * diffx, sample(template, lyux) * neg_diffy * diffx,
sample(template, uylx) * diffy * neg_diffx], name='sampled') sample(template, uylx) * diffy * neg_diffx], name='sampled')
from _test import TestModel from ._test import TestModel
class TestSample(TestModel): class TestSample(TestModel):
def test_sample(self): def test_sample(self):
import numpy as np import numpy as np
...@@ -140,9 +140,9 @@ if __name__ == '__main__': ...@@ -140,9 +140,9 @@ if __name__ == '__main__':
out = sess.run(tf.gradients(tf.reduce_sum(output), mapv)) out = sess.run(tf.gradients(tf.reduce_sum(output), mapv))
#out = sess.run(output) #out = sess.run(output)
print out[0].min() print(out[0].min())
print out[0].max() print(out[0].max())
print out[0].sum() print(out[0].sum())
out = sess.run([output])[0] out = sess.run([output])[0]
im = out[0] im = out[0]
......
...@@ -58,7 +58,7 @@ def FixedUnPooling(x, shape, unpool_mat=None): ...@@ -58,7 +58,7 @@ def FixedUnPooling(x, shape, unpool_mat=None):
input_shape[3]]) input_shape[3]])
return prod return prod
from _test import TestModel from ._test import TestModel
class TestPool(TestModel): class TestPool(TestModel):
def test_fixed_unpooling(self): def test_fixed_unpooling(self):
h, w = 3, 4 h, w = 3, 4
......
...@@ -4,16 +4,16 @@ ...@@ -4,16 +4,16 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf import tensorflow as tf
from itertools import count, izip from itertools import count
import argparse import argparse
from collections import namedtuple from collections import namedtuple
import numpy as np import numpy as np
from tqdm import tqdm from tqdm import tqdm
from utils import * from .utils import *
from utils.modelutils import describe_model from .utils.modelutils import describe_model
from utils import logger from .utils import logger
from dataflow import DataFlow, BatchData from .dataflow import DataFlow, BatchData
class PredictConfig(object): class PredictConfig(object):
def __init__(self, **kwargs): def __init__(self, **kwargs):
...@@ -93,7 +93,7 @@ def get_predict_func(config): ...@@ -93,7 +93,7 @@ def get_predict_func(config):
assert len(input_map) == len(dp), \ assert len(input_map) == len(dp), \
"Graph has {} inputs but dataset only gives {} components!".format( "Graph has {} inputs but dataset only gives {} components!".format(
len(input_map), len(dp)) len(input_map), len(dp))
feed = dict(izip(input_map, dp)) feed = dict(zip(input_map, dp))
if output_var_names is not None: if output_var_names is not None:
results = sess.run(output_vars, feed_dict=feed) results = sess.run(output_vars, feed_dict=feed)
return results return results
......
...@@ -8,17 +8,17 @@ from itertools import count ...@@ -8,17 +8,17 @@ from itertools import count
import copy import copy
import argparse import argparse
import re import re
import tqdm import tqdm
from models import ModelDesc
from dataflow.common import RepeatedData from .models import ModelDesc
from utils import * from .dataflow.common import RepeatedData
from utils.concurrency import EnqueueThread from .utils import *
from callbacks import * from .utils.concurrency import EnqueueThread
from utils.summary import summary_moving_average from .callbacks import *
from utils.modelutils import describe_model from .utils.summary import summary_moving_average
from utils import logger from .utils.modelutils import describe_model
from dataflow import DataFlow from .utils import logger
from .dataflow import DataFlow
class TrainConfig(object): class TrainConfig(object):
""" config for training""" """ config for training"""
......
...@@ -12,10 +12,10 @@ import tensorflow as tf ...@@ -12,10 +12,10 @@ import tensorflow as tf
import numpy as np import numpy as np
import collections import collections
import logger from . import logger
def global_import(name): def global_import(name):
p = __import__(name, globals(), locals()) p = __import__(name, globals(), None, level=1)
lst = p.__all__ if '__all__' in dir(p) else dir(p) lst = p.__all__ if '__all__' in dir(p) else dir(p)
for k in lst: for k in lst:
globals()[k] = p.__dict__[k] globals()[k] = p.__dict__[k]
......
...@@ -5,12 +5,11 @@ ...@@ -5,12 +5,11 @@
import threading import threading
from contextlib import contextmanager from contextlib import contextmanager
from itertools import izip
import tensorflow as tf import tensorflow as tf
from .utils import expand_dim_if_necessary from .utils import expand_dim_if_necessary
from .naming import * from .naming import *
import logger from . import logger
class StoppableThread(threading.Thread): class StoppableThread(threading.Thread):
def __init__(self): def __init__(self):
...@@ -43,7 +42,7 @@ class EnqueueThread(threading.Thread): ...@@ -43,7 +42,7 @@ class EnqueueThread(threading.Thread):
for dp in self.dataflow.get_data(): for dp in self.dataflow.get_data():
if self.coord.should_stop(): if self.coord.should_stop():
return return
feed = dict(izip(self.input_vars, dp)) feed = dict(zip(self.input_vars, dp))
self.sess.run([self.op], feed_dict=feed) self.sess.run([self.op], feed_dict=feed)
#print '\nExauhsted!!!' #print '\nExauhsted!!!'
except tf.errors.CancelledError as e: except tf.errors.CancelledError as e:
......
...@@ -21,6 +21,7 @@ class MyFormatter(logging.Formatter): ...@@ -21,6 +21,7 @@ class MyFormatter(logging.Formatter):
fmt = date + ' ' + colored('ERR', 'red', attrs=['blink', 'underline']) + ' ' + msg fmt = date + ' ' + colored('ERR', 'red', attrs=['blink', 'underline']) + ' ' + msg
else: else:
fmt = date + ' ' + msg fmt = date + ' ' + msg
# TODO this doesn't work in Python3
self._fmt = fmt self._fmt = fmt
return super(MyFormatter, self).format(record) return super(MyFormatter, self).format(record)
...@@ -31,6 +32,7 @@ def getlogger(): ...@@ -31,6 +32,7 @@ def getlogger():
handler = logging.StreamHandler() handler = logging.StreamHandler()
handler.setFormatter(MyFormatter(datefmt='%d %H:%M:%S')) handler.setFormatter(MyFormatter(datefmt='%d %H:%M:%S'))
logger.addHandler(handler) logger.addHandler(handler)
logger.warn("hahah")
return logger return logger
logger = getlogger() logger = getlogger()
...@@ -53,7 +55,7 @@ def set_logger_dir(dirname): ...@@ -53,7 +55,7 @@ def set_logger_dir(dirname):
if os.path.isdir(dirname): if os.path.isdir(dirname):
logger.info("Directory {} exists. Please either backup or delete it unless you're continue from a paused task." ) logger.info("Directory {} exists. Please either backup or delete it unless you're continue from a paused task." )
logger.info("Select Action: k (keep) / b (backup) / d (delete):") logger.info("Select Action: k (keep) / b (backup) / d (delete):")
act = raw_input().lower() act = input().lower()
if act == 'b': if act == 'b':
from datetime import datetime from datetime import datetime
backup_name = dirname + datetime.now().strftime('.%d-%H%M%S') backup_name = dirname + datetime.now().strftime('.%d-%H%M%S')
......
...@@ -4,7 +4,8 @@ ...@@ -4,7 +4,8 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf import tensorflow as tf
import logger
from . import logger
def describe_model(): def describe_model():
""" describe the current model parameters""" """ describe the current model parameters"""
......
...@@ -8,7 +8,7 @@ from abc import abstractmethod, ABCMeta ...@@ -8,7 +8,7 @@ from abc import abstractmethod, ABCMeta
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
import logger from . import logger
class SessionInit(object): class SessionInit(object):
__metaclass__ = ABCMeta __metaclass__ = ABCMeta
......
...@@ -4,8 +4,9 @@ ...@@ -4,8 +4,9 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf import tensorflow as tf
import logger
from naming import * from . import logger
from .naming import *
def create_summary(name, v): def create_summary(name, v):
""" """
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
# use user-space protobuf # use user-space protobuf
import sys, os #import sys, os
site = os.path.join(os.environ['HOME'], #site = os.path.join(os.environ['HOME'],
'.local/lib/python2.7/site-packages') #'.local/lib/python2.7/site-packages')
sys.path.insert(0, site) #sys.path.insert(0, site)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment