Commit c100cae2 authored by Yuxin Wu's avatar Yuxin Wu

globvars in char-rnn

parent 6f28c94e
...@@ -26,7 +26,7 @@ To visualize: ...@@ -26,7 +26,7 @@ To visualize:
""" """
SHAPE = 256 SHAPE = 256
BATCH = 16 BATCH = 4
IN_CH = 3 IN_CH = 3
OUT_CH = 3 OUT_CH = 3
LAMBDA = 100 LAMBDA = 100
...@@ -108,7 +108,7 @@ class Model(ModelDesc): ...@@ -108,7 +108,7 @@ class Model(ModelDesc):
self.g_loss = tf.add(self.g_loss, LAMBDA * errL1, name='total_g_loss') self.g_loss = tf.add(self.g_loss, LAMBDA * errL1, name='total_g_loss')
add_moving_summary(errL1, self.g_loss) add_moving_summary(errL1, self.g_loss)
# visualization # tensorboard visualization
if IN_CH == 1: if IN_CH == 1:
input = tf.image.grayscale_to_rgb(input) input = tf.image.grayscale_to_rgb(input)
if OUT_CH == 1: if OUT_CH == 1:
...@@ -186,7 +186,7 @@ if __name__ == '__main__': ...@@ -186,7 +186,7 @@ if __name__ == '__main__':
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
parser.add_argument('--load', help='load model') parser.add_argument('--load', help='load model')
parser.add_argument('--sample', action='store_true', help='run sampling') parser.add_argument('--sample', action='store_true', help='run sampling')
parser.add_argument('--data', help='A directory of images') parser.add_argument('--data', help='A directory of 512x256 images')
parser.add_argument('--mode', choices=['AtoB', 'BtoA'], default='AtoB') parser.add_argument('--mode', choices=['AtoB', 'BtoA'], default='AtoB')
global args global args
args = parser.parse_args() args = parser.parse_args()
......
...@@ -15,16 +15,11 @@ from six.moves import map, range ...@@ -15,16 +15,11 @@ from six.moves import map, range
from tensorpack import * from tensorpack import *
from tensorpack.tfutils.gradproc import * from tensorpack.tfutils.gradproc import *
from tensorpack.utils.lut import LookUpTable from tensorpack.utils.lut import LookUpTable
from tensorpack.utils.globvars import globalns as param
from tensorflow.python.ops import rnn_cell from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import rnn from tensorflow.python.ops import rnn
if six.PY2:
class NS: pass # this is a hack
else:
import types
NS = types.SimpleNamespace # this is what I wanted..
param = NS()
# some model hyperparams to set # some model hyperparams to set
param.batch_size = 128 param.batch_size = 128
param.rnn_size = 256 param.rnn_size = 256
...@@ -118,9 +113,7 @@ def get_config(): ...@@ -118,9 +113,7 @@ def get_config():
dataset=ds, dataset=ds,
optimizer=tf.train.AdamOptimizer(lr), optimizer=tf.train.AdamOptimizer(lr),
callbacks=Callbacks([ callbacks=Callbacks([
StatPrinter(), StatPrinter(), ModelSaver(),
ModelSaver(),
#HumanHyperParamSetter('learning_rate', 'hyper.txt')
ScheduledHyperParamSetter('learning_rate', [(25, 2e-4)]) ScheduledHyperParamSetter('learning_rate', [(25, 2e-4)])
]), ]),
model=Model(), model=Model(),
...@@ -128,6 +121,7 @@ def get_config(): ...@@ -128,6 +121,7 @@ def get_config():
max_epoch=50, max_epoch=50,
) )
# TODO rewrite using Predictor interface
def sample(path, start, length): def sample(path, start, length):
""" """
:param path: path to the model :param path: path to the model
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: globvars.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import six
__all__ = ['globalns']
if six.PY2:
class NS: pass
else:
import types
NS = types.SimpleNamespace
globalns = NS()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment