Commit 901727fc authored by Yuxin Wu's avatar Yuxin Wu

DCGAN family change

parent e1cfbef8
......@@ -131,35 +131,27 @@ class Model(GANModelDesc):
return opt
DCGAN.Model = Model
def get_config():
return TrainConfig(
model=Model(),
dataflow=DCGAN.get_data(G.data),
callbacks=[
ModelSaver(),
StatMonitorParamSetter(
'learning_rate', 'measure', lambda x: x * 0.5, 0, 10)
],
steps_per_epoch=500,
max_epoch=400,
)
if __name__ == '__main__':
args = DCGAN.get_args()
if args.sample:
DCGAN.sample(args.load, 'gen/conv4.3/output')
DCGAN.sample(Model(), args.load, 'gen/conv4.3/output')
else:
assert args.data
logger.auto_set_dir()
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
nr_gpu = get_nr_gpu()
config.nr_tower = max(get_nr_gpu(), 1)
config = TrainConfig(
model=Model(),
dataflow=DCGAN.get_data(args.data),
callbacks=[
ModelSaver(),
StatMonitorParamSetter(
'learning_rate', 'measure', lambda x: x * 0.5, 0, 10)
],
steps_per_epoch=500,
max_epoch=400,
session_init=SaverRestore(args.load) if args.load else None,
nr_tower=max(get_nr_gpu(), 1)
)
if config.nr_tower == 1:
GANTrainer(config).train()
else:
......
......@@ -115,7 +115,7 @@ def get_data(datadir):
ds = ImageFromFile(imgs, channel=3, shuffle=True)
ds = AugmentImageComponent(ds, get_augmentors())
ds = BatchData(ds, opt.BATCH)
ds = PrefetchDataZMQ(ds, 1)
ds = PrefetchDataZMQ(ds, 5)
return ds
......@@ -129,10 +129,10 @@ def get_config():
)
def sample(model_path, output_name='gen/gen'):
def sample(model, model_path, output_name='gen/gen'):
pred = PredictConfig(
session_init=get_model_loader(model_path),
model=Model(),
model=model,
input_names=['z'],
output_names=[output_name, 'z'])
pred = SimpleDatasetPredictor(pred, RandomZData((100, opt.Z_DIM)))
......@@ -162,7 +162,7 @@ def get_args():
if __name__ == '__main__':
args = get_args()
if args.sample:
sample(args.load)
sample(Model(), args.load)
else:
assert args.data
logger.auto_set_dir()
......
......@@ -93,7 +93,7 @@ DCGAN.Model = Model
if __name__ == '__main__':
args = DCGAN.get_args()
if args.sample:
DCGAN.sample(args.load)
DCGAN.sample(Model(), args.load)
else:
assert args.data
logger.auto_set_dir()
......
......@@ -54,7 +54,7 @@ if __name__ == '__main__':
args = DCGAN.get_args()
if args.sample:
DCGAN.sample(args.load)
DCGAN.sample(Model(), args.load)
else:
assert args.data
logger.auto_set_dir()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment