Commit 27746043 authored by Yuxin Wu's avatar Yuxin Wu

simplify code & loader

parent 425d3a30
...@@ -25,15 +25,9 @@ class RandomCrop(ImageAugmentor): ...@@ -25,15 +25,9 @@ class RandomCrop(ImageAugmentor):
assert orig_shape[0] >= self.crop_shape[0] \ assert orig_shape[0] >= self.crop_shape[0] \
and orig_shape[1] >= self.crop_shape[1], orig_shape and orig_shape[1] >= self.crop_shape[1], orig_shape
diffh = orig_shape[0] - self.crop_shape[0] diffh = orig_shape[0] - self.crop_shape[0]
if diffh == 0: h0 = 0 if diffh == 0 else self.rng.randint(diffh)
h0 = 0
else:
h0 = self.rng.randint(diffh)
diffw = orig_shape[1] - self.crop_shape[1] diffw = orig_shape[1] - self.crop_shape[1]
if diffw == 0: w0 = 0 if diffw == 0 else self.rng.randint(diffw)
w0 = 0
else:
w0 = self.rng.randint(diffw)
return (h0, w0) return (h0, w0)
def _augment(self, img, param): def _augment(self, img, param):
......
...@@ -6,6 +6,7 @@ import os ...@@ -6,6 +6,7 @@ import os
from abc import abstractmethod, ABCMeta from abc import abstractmethod, ABCMeta
from collections import defaultdict from collections import defaultdict
import re import re
import numpy as np
import tensorflow as tf import tensorflow as tf
import six import six
...@@ -15,7 +16,7 @@ from .varmanip import SessionUpdate, get_savename_from_varname ...@@ -15,7 +16,7 @@ from .varmanip import SessionUpdate, get_savename_from_varname
__all__ = ['SessionInit', 'NewSession', 'SaverRestore', __all__ = ['SessionInit', 'NewSession', 'SaverRestore',
'ParamRestore', 'ChainInit', 'ParamRestore', 'ChainInit',
'JustCurrentSession'] 'JustCurrentSession', 'get_model_loader']
# TODO they initialize_all at the beginning by default. # TODO they initialize_all at the beginning by default.
...@@ -179,3 +180,10 @@ class ChainInit(SessionInit): ...@@ -179,3 +180,10 @@ class ChainInit(SessionInit):
def _init(self, sess): def _init(self, sess):
for i in self.inits: for i in self.inits:
i.init(sess) i.init(sess)
def get_model_loader(filename):
if filename.endswith('.npy'):
return ParamRestore(np.load(filename, encoding='latin1').item())
else:
return SaverRestore(filename)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment