Commit fd635774 authored by Yuxin Wu's avatar Yuxin Wu

Python 3 compatibility for import statements

parent e02d310c
......@@ -7,7 +7,7 @@ from pkgutil import walk_packages
import os
def global_import(name):
p = __import__(name, globals(), locals())
p = __import__(name, globals(), locals(), level=1)
lst = p.__all__ if '__all__' in dir(p) else dir(p)
for k in lst:
globals()[k] = p.__dict__[k]
......
......@@ -3,9 +3,11 @@
# File: dump.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from .base import Callback
import cv2
import os
import scipy.misc
from scipy.misc import imsave
from .base import Callback
from ..utils import logger
__all__ = ['DumpParamAsImage']
......@@ -39,12 +41,12 @@ class DumpParamAsImage(Callback):
fname = os.path.join(
self.log_dir,
self.prefix + '-ep{:03d}-{}.png'.format(self.epoch_num, idx))
cv2.imwrite(fname, im * self.scale)
imsave(fname, (im * self.scale).astype('uint8'))
else:
im = val
assert im.ndim in [2, 3]
fname = os.path.join(
self.log_dir,
self.prefix + '-ep{:03d}.png'.format(self.epoch_num))
cv2.imwrite(fname, im * self.scale)
imsave(fname, (im * self.scale).astype('uint8'))
......@@ -7,7 +7,7 @@ import tensorflow as tf
import re
import os
import operator
import cPickle as pickle
import pickle
from .base import Callback, PeriodicCallback
from ..utils import *
......
......@@ -6,20 +6,20 @@
from pkgutil import walk_packages
import os
import os.path
import dataset
import imgaug
__SKIP = ['dftools', 'dataset']
from . import dataset
from . import imgaug
def global_import(name):
if name in __SKIP:
return
p = __import__(name, globals(), locals())
p = __import__(name, globals(), locals(), level=1)
lst = p.__all__ if '__all__' in dir(p) else dir(p)
for k in lst:
globals()[k] = p.__dict__[k]
__SKIP = ['dftools', 'dataset']
for _, module_name, _ in walk_packages(
[os.path.dirname(__file__)]):
if not module_name.startswith('_'):
if not module_name.startswith('_') and \
module_name not in __SKIP:
global_import(module_name)
......@@ -5,7 +5,7 @@
import numpy as np
from .base import DataFlow
from imgaug import AugmentorList, Image
from .imgaug import AugmentorList, Image
__all__ = ['BatchData', 'FixedSizeData', 'FakeData', 'MapData',
'AugmentImageComponent']
......@@ -94,7 +94,7 @@ class RepeatedData(DataFlow):
def size(self):
if self.nr == -1:
raise RuntimeError(), "size() is unavailable for infinite dataflow"
raise RuntimeError("size() is unavailable for infinite dataflow")
return self.ds.size() * self.nr
def get_data(self):
......
......@@ -8,7 +8,7 @@ import os
import os.path
def global_import(name):
p = __import__(name, globals(), locals())
p = __import__(name, globals(), locals(), level=1)
lst = p.__all__ if '__all__' in dir(p) else dir(p)
for k in lst:
globals()[k] = p.__dict__[k]
......
......@@ -3,7 +3,7 @@
# File: cifar10.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import os, sys
import cPickle
import pickle
import numpy
from six.moves import urllib
import tarfile
......@@ -41,7 +41,7 @@ def maybe_download_and_extract(dest_directory):
def read_cifar10(filenames):
for fname in filenames:
fo = open(fname, 'rb')
dic = cPickle.load(fo)
dic = pickle.load(fo)
data = dic['data']
label = dic['labels']
fo.close()
......
......@@ -9,7 +9,7 @@ from pkgutil import walk_packages
__all__ = []
def global_import(name):
p = __import__(name, globals(), locals())
p = __import__(name, globals(), locals(), level=1)
lst = p.__all__ if '__all__' in dir(p) else dir(p)
for k in lst:
globals()[k] = p.__dict__[k]
......
......@@ -8,7 +8,7 @@ import os
import os.path
def global_import(name):
p = __import__(name, globals(), locals())
p = __import__(name, globals(), locals(), level=1)
lst = p.__all__ if '__all__' in dir(p) else dir(p)
for k in lst:
globals()[k] = p.__dict__[k]
......
......@@ -82,7 +82,7 @@ def ImageSample(inputs):
sample(template, lyux) * neg_diffy * diffx,
sample(template, uylx) * diffy * neg_diffx], name='sampled')
from _test import TestModel
from ._test import TestModel
class TestSample(TestModel):
def test_sample(self):
import numpy as np
......@@ -140,9 +140,9 @@ if __name__ == '__main__':
out = sess.run(tf.gradients(tf.reduce_sum(output), mapv))
#out = sess.run(output)
print out[0].min()
print out[0].max()
print out[0].sum()
print(out[0].min())
print(out[0].max())
print(out[0].sum())
out = sess.run([output])[0]
im = out[0]
......
......@@ -58,7 +58,7 @@ def FixedUnPooling(x, shape, unpool_mat=None):
input_shape[3]])
return prod
from _test import TestModel
from ._test import TestModel
class TestPool(TestModel):
def test_fixed_unpooling(self):
h, w = 3, 4
......
......@@ -4,16 +4,16 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
from itertools import count, izip
from itertools import count
import argparse
from collections import namedtuple
import numpy as np
from tqdm import tqdm
from utils import *
from utils.modelutils import describe_model
from utils import logger
from dataflow import DataFlow, BatchData
from .utils import *
from .utils.modelutils import describe_model
from .utils import logger
from .dataflow import DataFlow, BatchData
class PredictConfig(object):
def __init__(self, **kwargs):
......@@ -93,7 +93,7 @@ def get_predict_func(config):
assert len(input_map) == len(dp), \
"Graph has {} inputs but dataset only gives {} components!".format(
len(input_map), len(dp))
feed = dict(izip(input_map, dp))
feed = dict(zip(input_map, dp))
if output_var_names is not None:
results = sess.run(output_vars, feed_dict=feed)
return results
......
......@@ -8,17 +8,17 @@ from itertools import count
import copy
import argparse
import re
import tqdm
from models import ModelDesc
from dataflow.common import RepeatedData
from utils import *
from utils.concurrency import EnqueueThread
from callbacks import *
from utils.summary import summary_moving_average
from utils.modelutils import describe_model
from utils import logger
from dataflow import DataFlow
from .models import ModelDesc
from .dataflow.common import RepeatedData
from .utils import *
from .utils.concurrency import EnqueueThread
from .callbacks import *
from .utils.summary import summary_moving_average
from .utils.modelutils import describe_model
from .utils import logger
from .dataflow import DataFlow
class TrainConfig(object):
""" config for training"""
......
......@@ -12,10 +12,10 @@ import tensorflow as tf
import numpy as np
import collections
import logger
from . import logger
def global_import(name):
p = __import__(name, globals(), locals())
p = __import__(name, globals(), None, level=1)
lst = p.__all__ if '__all__' in dir(p) else dir(p)
for k in lst:
globals()[k] = p.__dict__[k]
......
......@@ -5,12 +5,11 @@
import threading
from contextlib import contextmanager
from itertools import izip
import tensorflow as tf
from .utils import expand_dim_if_necessary
from .naming import *
import logger
from . import logger
class StoppableThread(threading.Thread):
def __init__(self):
......@@ -43,7 +42,7 @@ class EnqueueThread(threading.Thread):
for dp in self.dataflow.get_data():
if self.coord.should_stop():
return
feed = dict(izip(self.input_vars, dp))
feed = dict(zip(self.input_vars, dp))
self.sess.run([self.op], feed_dict=feed)
#print '\nExauhsted!!!'
except tf.errors.CancelledError as e:
......
......@@ -21,6 +21,7 @@ class MyFormatter(logging.Formatter):
fmt = date + ' ' + colored('ERR', 'red', attrs=['blink', 'underline']) + ' ' + msg
else:
fmt = date + ' ' + msg
# TODO this doesn't work in Python3
self._fmt = fmt
return super(MyFormatter, self).format(record)
......@@ -31,6 +32,7 @@ def getlogger():
handler = logging.StreamHandler()
handler.setFormatter(MyFormatter(datefmt='%d %H:%M:%S'))
logger.addHandler(handler)
logger.warn("hahah")
return logger
logger = getlogger()
......@@ -53,7 +55,7 @@ def set_logger_dir(dirname):
if os.path.isdir(dirname):
logger.info("Directory {} exists. Please either backup or delete it unless you're continue from a paused task." )
logger.info("Select Action: k (keep) / b (backup) / d (delete):")
act = raw_input().lower()
act = input().lower()
if act == 'b':
from datetime import datetime
backup_name = dirname + datetime.now().strftime('.%d-%H%M%S')
......
......@@ -4,7 +4,8 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
import logger
from . import logger
def describe_model():
""" describe the current model parameters"""
......
......@@ -8,7 +8,7 @@ from abc import abstractmethod, ABCMeta
import numpy as np
import tensorflow as tf
import logger
from . import logger
class SessionInit(object):
__metaclass__ = ABCMeta
......
......@@ -4,8 +4,9 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
import logger
from naming import *
from . import logger
from .naming import *
def create_summary(name, v):
"""
......
......@@ -4,7 +4,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# use user-space protobuf
import sys, os
site = os.path.join(os.environ['HOME'],
'.local/lib/python2.7/site-packages')
sys.path.insert(0, site)
#import sys, os
#site = os.path.join(os.environ['HOME'],
#'.local/lib/python2.7/site-packages')
#sys.path.insert(0, site)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment