Commit 7e5d3a85 authored by Yuxin Wu

simplify (and fix) some examples

parent d1041a77
@@ -116,7 +116,6 @@ def get_config():
ds = CharRNNData(param.corpus, 100000)
ds = BatchData(ds, param.batch_size)
steps_per_epoch = ds.size()
return TrainConfig(
dataflow=ds,
@@ -125,12 +124,9 @@ def get_config():
ScheduledHyperParamSetter('learning_rate', [(25, 2e-4)])
],
model=Model(),
steps_per_epoch=steps_per_epoch,
max_epoch=50,
)
# TODO rewrite using Predictor interface
def sample(path, start, length):
"""
......
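The hunks above (and several below) drop the explicit `steps_per_epoch = ds.size()` line: when `steps_per_epoch` is left out, tensorpack's `TrainConfig` defaults it to the size of the training dataflow, so spelling it out added nothing. A minimal sketch of that default, using `FakeData` purely for illustration:

    from tensorpack.dataflow import FakeData, BatchData

    ds = BatchData(FakeData([[100]], size=1000), 32)
    print(ds.size())   # 1000 // 32 = 31 batches per epoch

    # TrainConfig(dataflow=ds, model=..., max_epoch=50) with no steps_per_epoch
    # runs exactly ds.size() steps per epoch, which is why the explicit
    # steps_per_epoch=ds.size() could be removed.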
@@ -158,7 +158,6 @@ def get_data():
imgs = glob.glob(os.path.join(datadir, '*.jpg'))
ds = ImageFromFile(imgs, channel=3, shuffle=True)
# Image-to-Image translation mode
ds = MapData(ds, lambda dp: split_input(dp[0]))
assert SHAPE < 286 # this is the parameter used in the paper
augs = [imgaug.Resize(286), imgaug.RandomCrop(SHAPE)]
......
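`split_input` above turns one side-by-side training image into an (input, target) pair before the Resize/RandomCrop augmentors run; a rough sketch of what such a mapping does (the helper in the repo may differ in detail):

    import numpy as np

    def split_input(img):
        # img holds input and target side by side (H x 2W x 3);
        # cut it down the middle into the two halves of the pair
        s = img.shape[1] // 2
        return [img[:, :s, :], img[:, s:, :]]

    # applied to the first component of each datapoint:
    #   ds = MapData(ds, lambda dp: split_input(dp[0]))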
@@ -151,7 +151,7 @@ class Model(GANModelDesc):
opt = tf.train.AdamOptimizer(lr, beta1=0.5, epsilon=1e-6)
# generator learns 5 times faster
return optimizer.apply_grad_processors(
opt, [gradproc.ScaleGradient(('.*', 5), log=False)])
opt, [gradproc.ScaleGradient(('gen/.*', 5), log=True)])
def get_data():
......
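The regex passed to `ScaleGradient` decides which variables receive the 5x gradient scale: `'.*'` matched every variable (so the discriminator was scaled too, contrary to the comment), while `'gen/.*'` limits the boost to variables under the generator's scope. A sketch of the pattern, assuming tensorpack's `gradproc.ScaleGradient` and `apply_grad_processors` behave as in recent versions:

    import tensorflow as tf
    from tensorpack.tfutils import optimizer, gradproc

    def build_optimizer(lr):
        opt = tf.train.AdamOptimizer(lr, beta1=0.5, epsilon=1e-6)
        # the rule is (variable-name regex, multiplier); variables that match
        # no rule keep their gradients unchanged
        return optimizer.apply_grad_processors(
            opt, [gradproc.ScaleGradient(('gen/.*', 5), log=True)])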
@@ -156,9 +156,7 @@ def get_data(train_or_test):
def get_config():
logger.auto_set_dir()
# prepare dataset
dataset_train = get_data('train')
steps_per_epoch = 5000
dataset_val = get_data('val')
return TrainConfig(
@@ -175,7 +173,7 @@ def get_config():
],
session_config=get_default_sess_config(0.99),
model=Model(),
steps_per_epoch=steps_per_epoch,
steps_per_epoch=5000,
max_epoch=80,
)
......
@@ -139,10 +139,7 @@ def get_data(train_or_test):
def get_config():
logger.auto_set_dir()
# prepare dataset
dataset_train = get_data('train')
steps_per_epoch = dataset_train.size()
dataset_test = get_data('test')
return TrainConfig(
@@ -155,7 +152,6 @@ def get_config():
[(1, 0.1), (82, 0.01), (123, 0.001), (300, 0.0002)])
],
model=Model(n=NUM_UNITS),
steps_per_epoch=steps_per_epoch,
max_epoch=400,
)
......
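For reference, the `ScheduledHyperParamSetter` kept in the hunk above declares a piecewise-constant schedule: the hyperparameter named 'learning_rate' is set to each value once the corresponding epoch is reached. A small sketch with the same numbers (the callback is tensorpack's; the import path assumes a recent layout):

    from tensorpack.callbacks import ScheduledHyperParamSetter

    # epoch 1 -> 0.1, epoch 82 -> 0.01, epoch 123 -> 0.001, epoch 300 -> 0.0002
    lr_schedule = ScheduledHyperParamSetter(
        'learning_rate',
        [(1, 0.1), (82, 0.01), (123, 0.001), (300, 0.0002)])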
@@ -137,15 +137,12 @@ class SoftTripletModel(TripletModel):
def get_config(model, algorithm_name):
logger.auto_set_dir()
dataset = model.get_data()
steps_per_epoch = dataset.size()
extra_display = ["cost"]
if not algorithm_name == "cosine":
extra_display = extra_display + ["loss/pos-dist", "loss/neg-dist"]
return TrainConfig(
dataflow=dataset,
dataflow=model.get_data(),
model=model(),
callbacks=[
ModelSaver(),
@@ -155,7 +152,6 @@ def get_config(model, algorithm_name):
MovingAverageSummary(),
ProgressBar(extra_display),
StatPrinter()],
steps_per_epoch=steps_per_epoch,
max_epoch=20,
)
......
@@ -154,6 +154,7 @@ def get_config():
steps_per_epoch = dataset_train.size() * 5
return TrainConfig(
model=Model(),
dataflow=dataset_train,
callbacks=[
ModelSaver(),
@@ -162,7 +163,6 @@ def get_config():
ScheduledHyperParamSetter('learning_rate', [(200, 1e-4)])
],
session_config=get_default_sess_config(0.5),
model=Model(),
steps_per_epoch=steps_per_epoch,
max_epoch=500,
)
......
@@ -110,7 +110,6 @@ def get_config(cifar_classnum):
# prepare dataset
dataset_train = get_data('train', cifar_classnum)
steps_per_epoch = dataset_train.size()
dataset_test = get_data('test', cifar_classnum)
sess_config = get_default_sess_config(0.5)
@@ -120,6 +119,7 @@ def get_config(cifar_classnum):
raise StopTraining()
return lr * 0.31
return TrainConfig(
model=Model(cifar_classnum),
dataflow=dataset_train,
callbacks=[
ModelSaver(),
@@ -128,8 +128,6 @@ def get_config(cifar_classnum):
threshold=0.001, last_k=10),
],
session_config=sess_config,
model=Model(cifar_classnum),
steps_per_epoch=steps_per_epoch,
max_epoch=150,
)
......
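The `raise StopTraining()` / `return lr * 0.31` fragment and the `threshold=0.001, last_k=10` arguments above belong to a monitor-driven decay: when the watched statistic stops improving, the learning rate is passed through a decay function until training is aborted. A rough sketch of that pattern, assuming tensorpack's `StatMonitorParamSetter`; the statistic name and the 3e-5 cut-off are illustrative, not taken from the diff:

    from tensorpack.callbacks import StatMonitorParamSetter
    from tensorpack.train import StopTraining

    def lr_decay(lr):
        # once the learning rate is tiny, stop; otherwise decay it
        if lr < 3e-5:              # cut-off chosen for illustration
            raise StopTraining()
        return lr * 0.31

    # if 'validation_error' (assumed name) improved by less than `threshold`
    # over the last `last_k` epochs, set learning_rate = lr_decay(current value)
    decay = StatMonitorParamSetter('learning_rate', 'validation_error',
                                   lr_decay, threshold=0.001, last_k=10)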
@@ -140,6 +140,7 @@ def get_config():
# get the config which contains everything necessary in a training
return TrainConfig(
model=Model(),
dataflow=dataset_train, # the DataFlow instance for training
callbacks=[
ModelSaver(), # save the model after every epoch
@@ -148,7 +149,6 @@ def get_config():
# Calculate both the cost and the error for this DataFlow
[ScalarStats('cross_entropy_loss'), ClassificationError('incorrect')]),
],
model=Model(),
steps_per_epoch=steps_per_epoch,
max_epoch=100,
)
......
@@ -96,17 +96,15 @@ def get_data():
def get_config():
logger.auto_set_dir()
data_train, data_test = get_data()
steps_per_epoch = data_train.size()
return TrainConfig(
model=Model(),
dataflow=data_train,
callbacks=[
ModelSaver(),
InferenceRunner(data_test,
[ScalarStats('cost'), ClassificationError()])
],
model=Model(),
steps_per_epoch=steps_per_epoch,
max_epoch=350,
)
......
@@ -37,7 +37,7 @@ class ImageFromFile(RNGDataFlow):
for f in self.files:
im = cv2.imread(f, self.imread_mode)
if self.channel == 3:
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
im = im[:, :, ::-1]
if self.resize is not None:
im = cv2.resize(im, self.resize[::-1])
if self.channel == 1:
......
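The last hunk replaces `cv2.cvtColor(im, cv2.COLOR_BGR2RGB)` with `im[:, :, ::-1]`; for a 3-channel image the two are equivalent, since going from BGR to RGB is just a reversal of the channel axis. A quick check (illustrative only, not from the repo):

    import cv2
    import numpy as np

    im = np.random.randint(0, 256, size=(4, 5, 3), dtype=np.uint8)  # fake BGR image
    assert np.array_equal(cv2.cvtColor(im, cv2.COLOR_BGR2RGB), im[:, :, ::-1])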