Commit fdad5c4f authored by Yuxin Wu's avatar Yuxin Wu

update readme

parent 35c5a4e8
...@@ -3,7 +3,11 @@ Neural Network Toolbox on TensorFlow ...@@ -3,7 +3,11 @@ Neural Network Toolbox on TensorFlow
In development but usable. API might change a bit. In development but usable. API might change a bit.
See some interesting [examples](https://github.com/ppwwyyxx/tensorpack/tree/master/examples) to learn. See some interesting [examples](https://github.com/ppwwyyxx/tensorpack/tree/master/examples) to learn about the framework:
+ [Double-DQN for playing Atari games](https://github.com/ppwwyyxx/tensorpack/tree/master/examples/Atari2600)
+ [ResNet for Cifar10 classification](https://github.com/ppwwyyxx/tensorpack/tree/master/examples/ResNet)
+ [char-rnn language model](https://github.com/ppwwyyxx/tensorpack/tree/master/examples/char-rnn)
## Features: ## Features:
......
...@@ -13,11 +13,8 @@ import six ...@@ -13,11 +13,8 @@ import six
from six.moves import map, range from six.moves import map, range
from tensorpack import * from tensorpack import *
from tensorpack.models import *
from tensorpack.utils import *
from tensorpack.tfutils.gradproc import * from tensorpack.tfutils.gradproc import *
from tensorpack.utils.lut import LookUpTable from tensorpack.utils.lut import LookUpTable
from tensorpack.callbacks import *
from tensorflow.python.ops import rnn_cell from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import rnn from tensorflow.python.ops import rnn
...@@ -190,8 +187,6 @@ if __name__ == '__main__': ...@@ -190,8 +187,6 @@ if __name__ == '__main__':
args = parser.parse_args() args = parser.parse_args()
if args.gpu: if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
else:
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
if args.command == 'sample': if args.command == 'sample':
param.softmax_temprature = args.temperature param.softmax_temprature = args.temperature
...@@ -199,9 +194,8 @@ if __name__ == '__main__': ...@@ -199,9 +194,8 @@ if __name__ == '__main__':
sample(args.load, args.start, args.num) sample(args.load, args.start, args.num)
sys.exit() sys.exit()
else: else:
with tf.Graph().as_default(): config = get_config()
config = get_config() if args.load:
if args.load: config.session_init = SaverRestore(args.load)
config.session_init = SaverRestore(args.load) QueueInputTrainer(config).train()
QueueInputTrainer(config).train()
...@@ -70,8 +70,9 @@ class memoized(object): ...@@ -70,8 +70,9 @@ class memoized(object):
'''Support instance methods.''' '''Support instance methods.'''
return functools.partial(self.__call__, obj) return functools.partial(self.__call__, obj)
def get_rng(self): def get_rng(obj=None):
seed = (id(self) + os.getpid() + """ obj: some object to use to generate random seed"""
seed = (id(obj) + os.getpid() +
int(datetime.now().strftime("%Y%m%d%H%M%S%f"))) % 4294967295 int(datetime.now().strftime("%Y%m%d%H%M%S%f"))) % 4294967295
return np.random.RandomState(seed) return np.random.RandomState(seed)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment