Commit 7e5d3a85 authored by Yuxin Wu

simplify (and fix) some examples

parent d1041a77
@@ -116,7 +116,6 @@ def get_config():
     ds = CharRNNData(param.corpus, 100000)
     ds = BatchData(ds, param.batch_size)
-    steps_per_epoch = ds.size()
     return TrainConfig(
         dataflow=ds,
@@ -125,12 +124,9 @@ def get_config():
             ScheduledHyperParamSetter('learning_rate', [(25, 2e-4)])
         ],
         model=Model(),
-        steps_per_epoch=steps_per_epoch,
         max_epoch=50,
     )
 
-# TODO rewrite using Predictor interface
 def sample(path, start, length):
     """
...
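The change repeated throughout this commit is visible in the first hunk: the old examples computed steps_per_epoch = ds.size() and passed it into TrainConfig explicitly, while the new versions simply drop the argument, presumably because TrainConfig falls back to the dataflow's own size when steps_per_epoch is not given. A minimal sketch of the simplified pattern, with Model and get_data() as placeholders for whatever each example already defines:

# Sketch only: Model and get_data() stand in for the classes/functions each
# example defines; the point is the shape of the TrainConfig call after this commit.
from tensorpack import TrainConfig, BatchData, ModelSaver, ScheduledHyperParamSetter

def get_config():
    ds = BatchData(get_data(), 128)   # any DataFlow
    return TrainConfig(
        dataflow=ds,                  # no steps_per_epoch: assumed to default to ds.size()
        model=Model(),
        callbacks=[
            ModelSaver(),
            ScheduledHyperParamSetter('learning_rate', [(25, 2e-4)]),
        ],
        max_epoch=50,
    )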
@@ -158,7 +158,6 @@ def get_data():
     imgs = glob.glob(os.path.join(datadir, '*.jpg'))
     ds = ImageFromFile(imgs, channel=3, shuffle=True)
-    # Image-to-Image translation mode
     ds = MapData(ds, lambda dp: split_input(dp[0]))
     assert SHAPE < 286  # this is the parameter used in the paper
     augs = [imgaug.Resize(286), imgaug.RandomCrop(SHAPE)]
...
@@ -151,7 +151,7 @@ class Model(GANModelDesc):
         opt = tf.train.AdamOptimizer(lr, beta1=0.5, epsilon=1e-6)
         # generator learns 5 times faster
         return optimizer.apply_grad_processors(
-            opt, [gradproc.ScaleGradient(('.*', 5), log=False)])
+            opt, [gradproc.ScaleGradient(('gen/.*', 5), log=True)])
 
 def get_data():
...
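The hunk above is the "fix" part of the commit: ScaleGradient(('.*', 5), log=False) multiplied every gradient by 5, including the discriminator's, even though the comment says only the generator should learn 5 times faster. Restricting the regex to the generator's variables and enabling logging makes the scaling correct and auditable. A short sketch of the corrected optimizer, assuming the generator's variables live under a gen/ name scope as in tensorpack's GAN examples (the discrim/ scope mentioned below is likewise an assumption):

import tensorflow as tf
from tensorpack.tfutils import optimizer, gradproc

def build_optimizer(lr):
    base_opt = tf.train.AdamOptimizer(lr, beta1=0.5, epsilon=1e-6)
    # Scale gradients of variables matching 'gen/.*' by 5; variables outside that
    # scope (e.g. a discriminator under 'discrim/') keep their original gradients.
    # log=True prints which variables matched, which is how the old '.*' pattern
    # scaling everything would have been noticed.
    return optimizer.apply_grad_processors(
        base_opt, [gradproc.ScaleGradient(('gen/.*', 5), log=True)])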
@@ -156,9 +156,7 @@ def get_data(train_or_test):
 def get_config():
     logger.auto_set_dir()
-    # prepare dataset
     dataset_train = get_data('train')
-    steps_per_epoch = 5000
     dataset_val = get_data('val')
     return TrainConfig(
@@ -175,7 +173,7 @@ def get_config():
         ],
         session_config=get_default_sess_config(0.99),
         model=Model(),
-        steps_per_epoch=steps_per_epoch,
+        steps_per_epoch=5000,
         max_epoch=80,
     )
...
@@ -139,10 +139,7 @@ def get_data(train_or_test):
 def get_config():
     logger.auto_set_dir()
-    # prepare dataset
     dataset_train = get_data('train')
-    steps_per_epoch = dataset_train.size()
     dataset_test = get_data('test')
     return TrainConfig(
@@ -155,7 +152,6 @@ def get_config():
             [(1, 0.1), (82, 0.01), (123, 0.001), (300, 0.0002)])
         ],
         model=Model(n=NUM_UNITS),
-        steps_per_epoch=steps_per_epoch,
         max_epoch=400,
     )
...
@@ -137,15 +137,12 @@ class SoftTripletModel(TripletModel):
 def get_config(model, algorithm_name):
     logger.auto_set_dir()
-    dataset = model.get_data()
-    steps_per_epoch = dataset.size()
     extra_display = ["cost"]
     if not algorithm_name == "cosine":
         extra_display = extra_display + ["loss/pos-dist", "loss/neg-dist"]
     return TrainConfig(
-        dataflow=dataset,
+        dataflow=model.get_data(),
         model=model(),
         callbacks=[
             ModelSaver(),
@@ -155,7 +152,6 @@ def get_config(model, algorithm_name):
             MovingAverageSummary(),
             ProgressBar(extra_display),
             StatPrinter()],
-        steps_per_epoch=steps_per_epoch,
         max_epoch=20,
     )
...
@@ -154,6 +154,7 @@ def get_config():
     steps_per_epoch = dataset_train.size() * 5
     return TrainConfig(
+        model=Model(),
         dataflow=dataset_train,
         callbacks=[
             ModelSaver(),
@@ -162,7 +163,6 @@ def get_config():
             ScheduledHyperParamSetter('learning_rate', [(200, 1e-4)])
         ],
         session_config=get_default_sess_config(0.5),
-        model=Model(),
         steps_per_epoch=steps_per_epoch,
         max_epoch=500,
     )
...
@@ -110,7 +110,6 @@ def get_config(cifar_classnum):
     # prepare dataset
     dataset_train = get_data('train', cifar_classnum)
-    steps_per_epoch = dataset_train.size()
     dataset_test = get_data('test', cifar_classnum)
     sess_config = get_default_sess_config(0.5)
@@ -120,6 +119,7 @@ def get_config(cifar_classnum):
             raise StopTraining()
         return lr * 0.31
     return TrainConfig(
+        model=Model(cifar_classnum),
         dataflow=dataset_train,
         callbacks=[
             ModelSaver(),
@@ -128,8 +128,6 @@ def get_config(cifar_classnum):
                 threshold=0.001, last_k=10),
         ],
         session_config=sess_config,
-        model=Model(cifar_classnum),
-        steps_per_epoch=steps_per_epoch,
         max_epoch=150,
     )
...
@@ -140,6 +140,7 @@ def get_config():
     # get the config which contains everything necessary in a training
     return TrainConfig(
+        model=Model(),
         dataflow=dataset_train,  # the DataFlow instance for training
         callbacks=[
             ModelSaver(),  # save the model after every epoch
@@ -148,7 +149,6 @@ def get_config():
                 # Calculate both the cost and the error for this DataFlow
                 [ScalarStats('cross_entropy_loss'), ClassificationError('incorrect')]),
         ],
-        model=Model(),
         steps_per_epoch=steps_per_epoch,
         max_epoch=100,
     )
...
@@ -96,17 +96,15 @@ def get_data():
 def get_config():
     logger.auto_set_dir()
     data_train, data_test = get_data()
-    steps_per_epoch = data_train.size()
     return TrainConfig(
+        model=Model(),
         dataflow=data_train,
         callbacks=[
             ModelSaver(),
             InferenceRunner(data_test,
                 [ScalarStats('cost'), ClassificationError()])
         ],
-        model=Model(),
-        steps_per_epoch=steps_per_epoch,
         max_epoch=350,
     )
...
@@ -37,7 +37,7 @@ class ImageFromFile(RNGDataFlow):
         for f in self.files:
             im = cv2.imread(f, self.imread_mode)
             if self.channel == 3:
-                im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
+                im = im[:, :, ::-1]
             if self.resize is not None:
                 im = cv2.resize(im, self.resize[::-1])
             if self.channel == 1:
...
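The last hunk simplifies the BGR-to-RGB conversion in ImageFromFile: cv2.imread returns 3-channel images in BGR order, and im[:, :, ::-1] reverses the channel axis, which yields the same values as the cv2.cvtColor call it replaces. One difference worth noting is that the slice is a view rather than a new array, so downstream code that needs contiguous memory may want np.ascontiguousarray. A quick check of the equivalence:

import cv2
import numpy as np

# Stand-in for a cv2.imread result (BGR, uint8); channel reversal and cvtColor agree.
im = (np.random.rand(32, 32, 3) * 255).astype('uint8')
assert np.array_equal(im[:, :, ::-1], cv2.cvtColor(im, cv2.COLOR_BGR2RGB))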