Commit 27746043 authored by Yuxin Wu's avatar Yuxin Wu

simplify code & loader

parent 425d3a30
......@@ -25,15 +25,9 @@ class RandomCrop(ImageAugmentor):
assert orig_shape[0] >= self.crop_shape[0] \
and orig_shape[1] >= self.crop_shape[1], orig_shape
diffh = orig_shape[0] - self.crop_shape[0]
if diffh == 0:
h0 = 0
else:
h0 = self.rng.randint(diffh)
h0 = 0 if diffh == 0 else self.rng.randint(diffh)
diffw = orig_shape[1] - self.crop_shape[1]
if diffw == 0:
w0 = 0
else:
w0 = self.rng.randint(diffw)
w0 = 0 if diffw == 0 else self.rng.randint(diffw)
return (h0, w0)
def _augment(self, img, param):
......
......@@ -6,6 +6,7 @@ import os
from abc import abstractmethod, ABCMeta
from collections import defaultdict
import re
import numpy as np
import tensorflow as tf
import six
......@@ -15,7 +16,7 @@ from .varmanip import SessionUpdate, get_savename_from_varname
__all__ = ['SessionInit', 'NewSession', 'SaverRestore',
'ParamRestore', 'ChainInit',
'JustCurrentSession']
'JustCurrentSession', 'get_model_loader']
# TODO they initialize_all at the beginning by default.
......@@ -179,3 +180,10 @@ class ChainInit(SessionInit):
def _init(self, sess):
for i in self.inits:
i.init(sess)
def get_model_loader(filename):
if filename.endswith('.npy'):
return ParamRestore(np.load(filename, encoding='latin1').item())
else:
return SaverRestore(filename)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment