Commit c100cae2 authored by Yuxin Wu's avatar Yuxin Wu

globvars in char-rnn

parent 6f28c94e
......@@ -26,7 +26,7 @@ To visualize:
"""
SHAPE = 256
BATCH = 16
BATCH = 4
IN_CH = 3
OUT_CH = 3
LAMBDA = 100
......@@ -108,7 +108,7 @@ class Model(ModelDesc):
self.g_loss = tf.add(self.g_loss, LAMBDA * errL1, name='total_g_loss')
add_moving_summary(errL1, self.g_loss)
# visualization
# tensorboard visualization
if IN_CH == 1:
input = tf.image.grayscale_to_rgb(input)
if OUT_CH == 1:
......@@ -186,7 +186,7 @@ if __name__ == '__main__':
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
parser.add_argument('--load', help='load model')
parser.add_argument('--sample', action='store_true', help='run sampling')
parser.add_argument('--data', help='A directory of images')
parser.add_argument('--data', help='A directory of 512x256 images')
parser.add_argument('--mode', choices=['AtoB', 'BtoA'], default='AtoB')
global args
args = parser.parse_args()
......
......@@ -15,16 +15,11 @@ from six.moves import map, range
from tensorpack import *
from tensorpack.tfutils.gradproc import *
from tensorpack.utils.lut import LookUpTable
from tensorpack.utils.globvars import globalns as param
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import rnn
if six.PY2:
class NS: pass # this is a hack
else:
import types
NS = types.SimpleNamespace # this is what I wanted..
param = NS()
# some model hyperparams to set
param.batch_size = 128
param.rnn_size = 256
......@@ -118,9 +113,7 @@ def get_config():
dataset=ds,
optimizer=tf.train.AdamOptimizer(lr),
callbacks=Callbacks([
StatPrinter(),
ModelSaver(),
#HumanHyperParamSetter('learning_rate', 'hyper.txt')
StatPrinter(), ModelSaver(),
ScheduledHyperParamSetter('learning_rate', [(25, 2e-4)])
]),
model=Model(),
......@@ -128,6 +121,7 @@ def get_config():
max_epoch=50,
)
# TODO rewrite using Predictor interface
def sample(path, start, length):
"""
:param path: path to the model
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: globvars.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import six
__all__ = ['globalns']
if six.PY2:
class NS: pass
else:
import types
NS = types.SimpleNamespace
globalns = NS()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment