Commit deaabc90 authored by Yuxin Wu's avatar Yuxin Wu

gym submission

parent 7e91eb48
...@@ -3,29 +3,6 @@
The following guide introduces some core concepts of TensorPack. In contrast to several other libraries, TensorPack consists of several modules that let you build complex deep learning algorithms and train models with high accuracy and high speed.
### Training
Given TensorFlow's optimizers, this library provides several training protocols, including efficient multi-GPU setups. There is support for training on a single GPU, training on one machine with multiple GPUs (synchronous or asynchronous updates), training Generative Adversarial Networks, and reinforcement learning.
You only need to configure your training protocol, for example:
````python
config = TrainConfig(
    dataflow=my_dataflow,
    optimizer=tf.train.AdamOptimizer(lr),
    callbacks=Callbacks([ModelSaver(), ...]),
    model=Model())
# start training
SimpleTrainer(config).train()
````
Switching between single-GPU and multi-GPU training is as easy as replacing the last line:
````python
# start multi-GPU training
SyncMultiGPUTrainer(config).train()
````
### Callbacks
The use of callbacks adds the flexibility to execute code during training. These callbacks are triggered on several events, such as after each step or at the end of a training epoch.
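The event-driven scheme above can be sketched in plain Python (a conceptual sketch only, not tensorpack's actual `Callback` API — the class and hook names here are illustrative assumptions):

```python
class Callback:
    """Base class: hooks that a trainer invokes at fixed points in training."""
    def before_train(self):
        pass

    def trigger_step(self, step):
        pass

    def trigger_epoch(self, epoch):
        pass


class EpochRecorder(Callback):
    """Example callback: record which epochs have finished."""
    def __init__(self):
        self.finished_epochs = []

    def trigger_epoch(self, epoch):
        self.finished_epochs.append(epoch)


def train(callbacks, num_epochs=2, steps_per_epoch=3):
    """A training loop that dispatches to every callback at each event."""
    for cb in callbacks:
        cb.before_train()
    for epoch in range(1, num_epochs + 1):
        for step in range(steps_per_epoch):
            # ... run one training step here ...
            for cb in callbacks:
                cb.trigger_step(step)
        for cb in callbacks:
            cb.trigger_epoch(epoch)


recorder = EpochRecorder()
train([recorder])
```

The trainer stays generic; all per-event behavior (saving models, logging, adjusting hyperparameters) lives in small callback objects.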
# Trainers
## Trainer
Training is just "running something again and again".
Tensorpack base trainer implements the logic of *running the iteration*,
and other trainers implement what the iteration is.
Most neural network training tasks are single-cost optimization.
Tensorpack provides some trainer implementations for such tasks.
These trainers will by default optimize `ModelDesc.cost`,
so you can use them as long as you set `self.cost` in `ModelDesc._build_graph()`.
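The division of labor described above can be sketched in plain Python (a conceptual sketch under assumed names, not tensorpack's real class hierarchy): the base trainer owns the loop, and a subclass defines what one iteration of single-cost minimization does.

```python
class Trainer:
    """Base trainer: implements the logic of *running the iteration*."""
    def train(self, num_steps):
        for _ in range(num_steps):
            self.run_step()

    def run_step(self):
        # subclasses define what one iteration is
        raise NotImplementedError


class SingleCostTrainer(Trainer):
    """Minimizes a scalar cost by gradient descent on one parameter."""
    def __init__(self, grad_fn, x0, lr=0.1):
        self.grad_fn = grad_fn  # gradient of the cost w.r.t. x
        self.x = x0
        self.lr = lr

    def run_step(self):
        # one iteration = one gradient-descent update
        self.x -= self.lr * self.grad_fn(self.x)


# cost(x) = x^2, gradient = 2x; the minimum is at x = 0
t = SingleCostTrainer(grad_fn=lambda x: 2 * x, x0=5.0)
t.train(100)
```

The base class never needs to know what an iteration does, which is why the same loop can serve single-cost, multi-GPU, or GAN-style trainers.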
The existing trainers were implemented with a TensorFlow queue to prefetch and buffer
training data, which is significantly faster than
a naive `sess.run(..., feed_dict={...})` you might use.
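The benefit of queue-based prefetching can be sketched with Python's standard library (a conceptual sketch only; the real trainers use a TensorFlow queue, and `consume` here is an illustrative stand-in for a training loop): a producer thread buffers data ahead of the consumer so a training step rarely waits on data loading.

```python
import queue
import threading


def producer(dataflow, q):
    """Load datapoints ahead of time into a bounded buffer."""
    for datapoint in dataflow:
        q.put(datapoint)
    q.put(None)  # sentinel: end of data


def consume(dataflow, buffer_size=8):
    """Consume datapoints from the buffer instead of loading them inline."""
    q = queue.Queue(maxsize=buffer_size)
    threading.Thread(target=producer, args=(dataflow, q), daemon=True).start()
    results = []
    while True:
        datapoint = q.get()  # usually already buffered by the producer
        if datapoint is None:
            break
        results.append(datapoint * 2)  # stand-in for one training step
    return results


out = consume(range(5))
```

With `feed_dict`, loading and computing alternate serially; with a buffer, they overlap, which is where the speedup comes from.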
There are also multi-GPU trainers which include the logic of data-parallel multi-GPU training,
with either synchronous or asynchronous updates. You can enable multi-GPU training
by changing just one line.
To use trainers, pass a `TrainConfig` to configure them:
````python
config = TrainConfig(
    dataflow=my_dataflow,
    optimizer=tf.train.AdamOptimizer(0.01),
    callbacks=[...],
    model=MyModel()
)
# start training:
# SimpleTrainer(config).train()
# start training with queue prefetch:
# QueueInputTrainer(config).train()
# start multi-GPU training with synchronous update:
SyncMultiGPUTrainer(config).train()
````
Trainers just run iterations, so there is no limit on where the data comes from
or what to do in an iteration.
For example, the [GAN trainer](../examples/GAN/GAN.py) minimizes
two cost functions alternately.
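The alternating scheme can be sketched in plain Python (conceptual only; see the linked GAN.py for the real TensorFlow version, and the two costs here are toy assumptions): each iteration runs one gradient step on each cost in turn.

```python
def alternating_train(num_iters, lr=0.1):
    """Alternately minimize two coupled costs, one step each per iteration.

    Toy costs: cost_a(x, y) = (x - y)**2, minimized w.r.t. x;
               cost_b(y)    = y**2,       minimized w.r.t. y.
    """
    x, y = 4.0, 2.0
    for _ in range(num_iters):
        x -= lr * 2 * (x - y)  # gradient step on the first cost
        y -= lr * 2 * y        # gradient step on the second cost
    return x, y


x, y = alternating_train(200)
```

This is the same "run something again and again" loop; only the definition of one iteration changed to two updates instead of one.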
Some trainers take data from a TensorFlow reading pipeline instead of a DataFlow
(see the [PTB example](../examples/PennTreebank)).
## Develop trainers
The existing trainers should be enough for single-cost optimization tasks. If you
want to do something inside the trainer, consider writing it as a callback, or
submit an issue to see if there is a better solution than creating new trainers.
For other tasks, you might need a new trainer.
The [GAN trainer](../examples/GAN/GAN.py) is one example of how to implement
new trainers.
More details to come.
...@@ -41,7 +41,6 @@ class TransitionExperience(object):
@six.add_metaclass(ABCMeta)
class SimulatorProcessBase(mp.Process):
    def __init__(self, idx):
        super(SimulatorProcessBase, self).__init__()
        self.idx = int(idx)
...
[flake8]
max-line-length = 120
ignore = F403,F401,F405,F841
exclude = private
...@@ -48,7 +48,6 @@ class GymEnv(RLEnvironment):
        self._ob = self.gymenv.reset()
    def finish_episode(self):
        self.gymenv.close()
        self.stats['score'].append(self.rwd_counter.sum)
    def current_state(self):
...