Commit 57624ce2 authored by Matthew Hausknecht

Added example sarsa agent.

parent 618437ff
# Builds the example high-level SARSA agent: first the function-approximator
# and policy libraries in their sub-directories, then the agent itself,
# linked against both libraries and libhfo.
#Directories
FA_DIR = ./funcapprox
POLICY_DIR = ./policy
HFO_SRC_DIR = ../../src
HFO_LIB_DIR = ../../lib
#Includes
INCLUDES = -I$(FA_DIR) -I$(POLICY_DIR) -I$(HFO_SRC_DIR)
#Libs
FA_LIB = funcapprox
POLICY_LIB = policyagent
#Flags
CXXFLAGS = -g -Wall -std=c++11 -pthread
LDFLAGS = -l$(FA_LIB) -l$(POLICY_LIB) -lhfo -pthread
LDLIBS = -L$(FA_DIR) -L$(POLICY_DIR) -L$(HFO_LIB_DIR)
LINKEROPTIONS = -Wl,-rpath,$(HFO_LIB_DIR)
#Compiler
CXX = g++
#Sources
SRC = high_level_sarsa_agent.cpp
#Objects
OBJ = $(SRC:.cpp=.o)
#Target
TARGET = high_level_sarsa_agent
#Rules
# None of these targets creates a file with its own name, so all of them are
# phony; otherwise a stray file called e.g. "clean" or "policyagent" would
# silently disable the rule.
.PHONY: all clean cleanfa cleanpolicy $(FA_LIB) $(POLICY_LIB)
all: $(TARGET)
.cpp.o:
	$(CXX) $(CXXFLAGS) $(INCLUDES) -c -o $@ $(@F:%.o=%.cpp)
# Delegate the library builds to the Makefiles in their own directories.
$(FA_LIB):
	$(MAKE) -C $(FA_DIR)
$(POLICY_LIB):
	$(MAKE) -C $(POLICY_DIR)
$(TARGET): $(FA_LIB) $(POLICY_LIB) $(OBJ)
	$(CXX) $(OBJ) $(CXXFLAGS) $(LDLIBS) $(LDFLAGS) -o $(TARGET) $(LINKEROPTIONS)
cleanfa:
	$(MAKE) clean -C $(FA_DIR)
cleanpolicy:
	$(MAKE) clean -C $(POLICY_DIR)
clean: cleanfa cleanpolicy
	rm -f $(TARGET) $(OBJ) *~
#include "CMAC.h"
// Number of tilings CMAC::loadTiles() requests for each state variable
// (one tiling group per variable).
#define TILINGS_PER_GROUP 32
// Construct a CMAC for numF state variables and numA actions.
// r, m and res must each hold at least numF entries: the range, minimum value
// and tile resolution of every state variable. All weights and traces start
// at zero and a fresh collision table is allocated.
CMAC::CMAC(int numF, int numA, double r[], double m[], double res[]):FunctionApproximator(numF,numA){
    for(int i = 0; i < numF; i++){
        ranges[i] = r[i];
        minValues[i] = m[i];
        resolutions[i] = res[i];
    }
    minimumTrace = 0.01;
    numNonzeroTraces = 0;
    // No tiles exist until setState() runs loadTiles(); give the counter a
    // defined value so computeQ()/updateWeights() never read garbage.
    numTilings = 0;
    for(int i = 0; i < RL_MEMORY_SIZE; i++){
        weights[i] = 0;
        traces[i] = 0;
    }
    srand((unsigned int)0);
    int tmp[2] = {0, 0};
    float tmpf[2] = {0.0f, 0.0f};
    colTab = new collision_table( RL_MEMORY_SIZE, 1 );
    GetTiles(tmp, 1, 1, tmpf, 0);// A dummy call to set the hashing table
}
// Range stored for state variable i at construction time.
double CMAC::getRange(int i){
    const double stored = ranges[i];
    return stored;
}
// Minimum value stored for state variable i at construction time.
double CMAC::getMinValue(int i){
    const double stored = minValues[i];
    return stored;
}
// Tile resolution stored for state variable i; loadTiles() divides the
// state value by it before tiling.
double CMAC::getResolution(int i){
    const double stored = resolutions[i];
    return stored;
}
void CMAC::setState(double s[]){
FunctionApproximator::setState(s);
loadTiles();
}
// Move every weight that currently has a nonzero trace by
// delta * alpha / numTilings, scaled by that weight's trace.
// Does nothing while no tilings are loaded (avoids a division by zero).
void CMAC::updateWeights(double delta, double alpha){
    if(numTilings <= 0){
        return;
    }
    const double step = delta * alpha / numTilings;
    for(int i = 0; i < numNonzeroTraces; i++){
        const int f = nonzeroTraces[i];
        // Valid indices are 0 .. RL_MEMORY_SIZE-1; report and skip anything else
        // instead of writing outside the arrays.
        if(f >= RL_MEMORY_SIZE || f < 0){
            std::cerr << "f is too big or too small!!" << f << "\n";
            continue;
        }
        weights[f] += step * traces[f];
    }
}
// Decays all the (nonzero) traces by decayRate, removing those that fall
// below minimumTrace. Entries with an out-of-range feature index are
// reported and skipped rather than dereferenced.
void CMAC::decayTraces(double decayRate){
    // Walk the list downwards: clearExistentTrace() moves the last entry into
    // the freed slot, so slots above loc have already been visited.
    for(int loc = numNonzeroTraces - 1; loc >= 0; loc--){
        const int f = nonzeroTraces[loc];
        if(f >= RL_MEMORY_SIZE || f < 0){
            std::cerr << "DecayTraces: f out of range " << f << "\n";
            continue;
        }
        traces[f] *= decayRate;
        if(traces[f] < minimumTrace){
            clearExistentTrace(f, loc);
        }
    }
}
// Clear any trace for feature f. An out-of-range f is reported and ignored.
void CMAC::clearTrace(int f){
    if(f >= RL_MEMORY_SIZE || f < 0){
        std::cerr << "ClearTrace: f out of range " << f << "\n";
        return;
    }
    if(traces[f] != 0){
        clearExistentTrace(f, nonzeroTracesInverse[f]);
    }
}
// Clear the trace for feature f at location loc in the list of nonzero traces.
// The list stays dense: its last entry is moved into slot loc. Feature indices
// outside 0 .. RL_MEMORY_SIZE-1 are reported and never used to index an array.
void CMAC::clearExistentTrace(int f, int loc){
    if(f >= RL_MEMORY_SIZE || f < 0){
        std::cerr << "ClearExistentTrace: f out of range " << f << "\n";
    }
    else{
        traces[f] = 0.0;
    }
    numNonzeroTraces--;
    const int moved = nonzeroTraces[numNonzeroTraces];
    nonzeroTraces[loc] = moved;
    if(moved >= 0 && moved < RL_MEMORY_SIZE){
        nonzeroTracesInverse[moved] = loc;
    }
}
// Set the trace for feature f to the given value, which must be positive.
// A feature that has no live trace yet is appended to the nonzero-trace list,
// raising minimumTrace first if the list is full. An out-of-range f is
// reported and ignored.
void CMAC::setTrace(int f, double newTraceValue){
    if(f >= RL_MEMORY_SIZE || f < 0){
        std::cerr << "SetTraces: f out of range " << f << "\n";
        return;
    }
    if(traces[f] >= minimumTrace){
        traces[f] = newTraceValue;// trace already exists
    }
    else{
        while(numNonzeroTraces >= RL_MAX_NONZERO_TRACES){
            increaseMinTrace();// ensure room for new trace
        }
        traces[f] = newTraceValue;
        nonzeroTraces[numNonzeroTraces] = f;
        nonzeroTracesInverse[f] = numNonzeroTraces;
        numNonzeroTraces++;
    }
}
// Set the trace for feature f to the given value, which must be positive
void CMAC::updateTrace(int f, double deltaTraceValue){
setTrace(f, traces[f] + deltaTraceValue);
}
// Frees slots in the nonzero-trace list: raises minimumTrace by 10% and
// drops every trace that is now below the threshold.
void CMAC::increaseMinTrace(){
    minimumTrace *= 1.1;
    std::cerr << "Changing minimum_trace to " << minimumTrace << std::endl;
    // Iterate from the back: clearExistentTrace() refills a freed slot with
    // the last list entry, which has already been examined.
    int slot = numNonzeroTraces;
    while(slot-- > 0){
        const int feature = nonzeroTraces[slot];
        if(!(traces[feature] < minimumTrace)){
            continue;
        }
        clearExistentTrace(feature, slot);
    }
}
// Load the weight table from the start of fileName, then let the collision
// table restore itself from the same file at the offset where the weights end.
// If the file cannot be opened or holds fewer than RL_MEMORY_SIZE weights, an
// error is printed and the collision table is left untouched (the weights may
// be partially overwritten after a short read).
void CMAC::read(char *fileName){
    std::fstream file;
    file.open(fileName, std::ios::in | std::ios::binary);
    if(!file.is_open()){
        std::cerr << "CMAC::read: cannot open " << fileName << "\n";
        return;
    }
    file.read(reinterpret_cast<char *>(weights), RL_MEMORY_SIZE * sizeof(double));
    if(!file){
        // tellg() would return -1 here, which must not be used as an offset.
        std::cerr << "CMAC::read: could not read all weights from " << fileName << "\n";
        file.close();
        return;
    }
    unsigned long pos = file.tellg();
    file.close();
    colTab->restore(fileName, pos);
}
// Write the weight table to the start of fileName (truncating it), then let
// the collision table save itself to the same file at the offset where the
// weights end. If the file cannot be opened or written, an error is printed
// and the collision table is not saved.
void CMAC::write(char *fileName){
    std::fstream file;
    file.open(fileName, std::ios::out | std::ios::binary);
    if(!file.is_open()){
        std::cerr << "CMAC::write: cannot open " << fileName << "\n";
        return;
    }
    file.write(reinterpret_cast<char *>(weights), RL_MEMORY_SIZE * sizeof(double));
    if(!file){
        // tellp() would return -1 here, which must not be used as an offset.
        std::cerr << "CMAC::write: could not write all weights to " << fileName << "\n";
        file.close();
        return;
    }
    unsigned long pos = file.tellp();
    file.close();
    colTab->save(fileName, pos);
}
void CMAC::reset(){
for (int i = 0; i < RL_MEMORY_SIZE; i++){
weights[i] = 0;
traces[i] = 0;
}
}
// Recompute the active tiles of every action for the current state.
// One tiling group (TILINGS_PER_GROUP tilings) is used per state variable;
// the action index and the variable index are passed to the tile hash.
// Stops, with an error message, before a group would overflow the tiles array.
void CMAC::loadTiles(){
    const int tilingsPerGroup = TILINGS_PER_GROUP; /* num tilings per tiling group */
    numTilings = 0;
    /* These are the 'tiling groups' -- play here with representations */
    /* One tiling for each state variable */
    for(int v = 0; v < getNumFeatures(); v++){
        // Check capacity before writing, not after the array has been overrun.
        if(numTilings + tilingsPerGroup > RL_MAX_NUM_TILINGS){
            std::cerr << "TOO MANY TILINGS! " << numTilings + tilingsPerGroup << "\n";
            break;
        }
        for(int a = 0; a < getNumActions(); a++){
            GetTiles1(&(tiles[a][numTilings]), tilingsPerGroup, colTab, state[v] / getResolution(v), a , v);
        }
        numTilings += tilingsPerGroup;
    }
}
// Q-value of `action` in the current state: the sum of the weights of the
// tiles that are active for that action.
double CMAC::computeQ(int action){
    const int *activeTiles = tiles[action];
    double total = 0;
    for(const int *tile = activeTiles, *stop = activeTiles + numTilings; tile < stop; ++tile){
        total += weights[*tile];
    }
    return total;
}
// Remove the trace of every tile that is active for `action`.
void CMAC::clearTraces(int action){
    int tiling = 0;
    while(tiling < numTilings){
        clearTrace(tiles[action][tiling]);
        ++tiling;
    }
}
// Replacing traces: set the trace of every tile active for `action` to 1.
void CMAC::updateTraces(int action){
    int tiling = 0;
    while(tiling < numTilings){
        setTrace(tiles[action][tiling], 1.0);
        ++tiling;
    }
}
// Part of the FunctionApproximator interface; CMAC does not implement it
// and always reports zero weights.
int CMAC::getNumWeights(){
    const int notImplemented = 0;
    return notImplemented;
}
// Part of the FunctionApproximator interface; CMAC does not implement it,
// so w is left untouched.
void CMAC::getWeights(double w[]){
    (void)w;
}
// Part of the FunctionApproximator interface; CMAC does not implement it,
// so w is ignored.
void CMAC::setWeights(double w[]){
    (void)w;
}
#ifndef CMAC_H
#define CMAC_H
#include <cmath>
#include "FuncApprox.h"
#include "tiles2.h"
// Length of the weights, traces and nonzeroTracesInverse arrays of CMAC.
#define RL_MEMORY_SIZE 1048576
// Capacity of CMAC's nonzeroTraces list.
#define RL_MAX_NONZERO_TRACES 100000
// Per-action capacity of CMAC's tiles array.
#define RL_MAX_NUM_TILINGS 6000
class CMAC: public FunctionApproximator{
protected:
int tiles[MAX_ACTIONS][RL_MAX_NUM_TILINGS];
double minimumTrace;
int nonzeroTraces[RL_MAX_NONZERO_TRACES];
int numNonzeroTraces;
int nonzeroTracesInverse[RL_MEMORY_SIZE];
double ranges[MAX_STATE_VARS];
double minValues[MAX_STATE_VARS];
double resolutions[MAX_STATE_VARS];
double weights[RL_MEMORY_SIZE];
double traces [RL_MEMORY_SIZE];
int numTilings;
collision_table *colTab;
void clearTrace(int f);
void clearExistentTrace(int f, int loc);
void setTrace(int f, double newTraceValue);
void updateTrace(int f, double deltaTraceValue);
void increaseMinTrace();
void reset();
void loadTiles();
double getRange(int i);
double getMinValue(int i);
double getResolution(int i);
public:
CMAC(int numF, int numA, double r[], double m[], double res[]);
void setState(double s[]);
void updateWeights(double delta, double alpha);
void decayTraces(double decayRate);
void read (char *fileName);
void write(char *fileName);
//Not implemented by CMAC
int getNumWeights();
void getWeights(double w[]);
void setWeights(double w[]);
double computeQ(int action);
void clearTraces(int action);
void updateTraces(int action);
};
#endif
#include "FuncApprox.h"
// Record how many state features and actions this approximator handles.
FunctionApproximator::FunctionApproximator(int numF, int numA)
    : numFeatures(numF), numActions(numA){
}
// Copy the first numFeatures entries of s into the internal state buffer.
void FunctionApproximator::setState(double s[]){
    const int count = numFeatures;
    int idx = 0;
    while(idx < count){
        state[idx] = s[idx];
        ++idx;
    }
}
int FunctionApproximator::getNumFeatures(){
return numFeatures;
}
int FunctionApproximator::getNumActions(){
return numActions;
}
// Index of the action with the largest Q-value in the current state.
// Values within 1e-4 of the current best count as ties; the k-th tie replaces
// the incumbent with probability 1/(k+1), drawn with drand48().
int FunctionApproximator::argMaxQ(){
    const double tolerance = 1.0e-4;
    const int actionCount = getNumActions();
    int winner = 0;
    double winnerQ = computeQ(winner);
    int tiesSeen = 0;
    for(int candidate = 1; candidate < actionCount; ++candidate){
        const double candidateQ = computeQ(candidate);
        const bool tied = fabs(candidateQ - winnerQ) < tolerance;
        if(!tied){
            if(candidateQ > winnerQ){
                winner = candidate;
                winnerQ = candidateQ;
                tiesSeen = 0;
            }
            continue;
        }
        ++tiesSeen;
        if(drand48() < (1.0 / (tiesSeen + 1))){
            winner = candidate;
            winnerQ = candidateQ;
        }
    }
    return winner;
}
// Largest Q-value over all actions in the current state. Runs the same scan
// as argMaxQ(), including the random tie-breaking draws, and returns the
// value of the action that scan settles on.
double FunctionApproximator::bestQ(){
    const double tolerance = 1.0e-4;
    const int actionCount = getNumActions();
    double winnerQ = computeQ(0);
    int tiesSeen = 0;
    for(int candidate = 1; candidate < actionCount; ++candidate){
        const double candidateQ = computeQ(candidate);
        const bool tied = fabs(candidateQ - winnerQ) < tolerance;
        if(!tied){
            if(candidateQ > winnerQ){
                winnerQ = candidateQ;
                tiesSeen = 0;
            }
            continue;
        }
        ++tiesSeen;
        if(drand48() < (1.0 / (tiesSeen + 1))){
            winnerQ = candidateQ;
        }
    }
    return winnerQ;
}
#ifndef FUNC_APPROX
#define FUNC_APPROX
#include <stdlib.h>
#include <math.h>
// Capacity of the fixed-size state buffer in FunctionApproximator.
#define MAX_STATE_VARS 100
// Upper bound on the number of actions; not referenced in this header
// (CMAC.h uses it to size its per-action tile table).
#define MAX_ACTIONS 10
// Abstract base class for action-value (Q) function approximators.
// It stores the current state and the feature/action counts; subclasses
// supply the Q computation, the weight/trace updates and persistence.
class FunctionApproximator{
protected:
// Counts given to the constructor.
int numFeatures, numActions;
// Current state; setState() copies numFeatures values into it.
double state[MAX_STATE_VARS];
int getNumFeatures();
int getNumActions();
public:
FunctionApproximator(int numF, int numA);
virtual ~FunctionApproximator(){}
// Copies the first numFeatures entries of s into the state buffer.
virtual void setState(double s[]);
// Q-value of `action` in the current state.
virtual double computeQ(int action) = 0;
// Action with the largest Q-value (ties broken randomly) and that value.
virtual int argMaxQ();
virtual double bestQ();
// Learning step and eligibility-trace maintenance; semantics are defined
// by the subclass.
virtual void updateWeights(double delta, double alpha) = 0;
virtual void clearTraces(int action) = 0;
virtual void decayTraces(double decayRate) = 0;
virtual void updateTraces(int action) = 0;
// Load / store the approximator from / to a file.
virtual void read (char *fileName) = 0;
virtual void write(char *fileName) = 0;
// Direct weight access.
virtual int getNumWeights() = 0;
virtual void getWeights(double w[]) = 0;
virtual void setWeights(double w[]) = 0;
virtual void reset() = 0;
};
#endif
# Builds the static library libfuncapprox.a from the function-approximator sources.
#Flags
CXXFLAGS = -g -O3 -Wall
#Compiler
CXX = g++
#Sources
SRCS = FuncApprox.cpp tiles2.cpp CMAC.cpp
#Objects
OBJS = $(SRCS:.cpp=.o)
#Target
TARGET = libfuncapprox.a
#Rules
.PHONY: all clean
all: $(TARGET)
.cpp.o:
	$(CXX) $(CXXFLAGS) -c -o $@ $(@F:%.o=%.cpp)
# 'r' replaces members that already exist in the archive ('q' would append a
# second copy and leave the stale one first), 'c' creates the archive quietly,
# 's' writes the symbol index.
$(TARGET): $(OBJS)
	ar rcs $@ $(OBJS)
clean:
	rm -f $(TARGET) $(OBJS) *~
This diff is collapsed.
/*
This is Version 2.0 of Rich Sutton's Tile Coding Software
available from his website at:
http://www.richsutton.com
*/
#ifndef _TILES2_H_
#define _TILES2_H_
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <fcntl.h>
#include <unistd.h>
#define MAX_NUM_VARS 20 // Maximum number of variables in a grid-tiling
#define MAX_NUM_COORDS 100 // Maximum number of hashing coordinates
// 2^31 - 1; NOTE(review): its use is in tiles2.cpp, which is not visible here.
#define MaxLONGINT 2147483647
// General tile-coding entry point that hashes into a table of memory_size
// tiles. Writes num_tilings tile indices into tiles[] for the given float
// and integer variables (implementation in tiles2.cpp, not visible here).
void GetTiles(
int tiles[], // provided array contains returned tiles (tile indices)
int num_tilings, // number of tile indices to be returned in tiles
int memory_size, // total number of possible tiles
float floats[], // array of floating point variables
int num_floats, // number of floating point variables
int ints[], // array of integer variables
int num_ints); // number of integer variables
// Collision table handed to the GetTiles*/hash overloads that take a
// `collision_table *`. CMAC constructs one as collision_table(RL_MEMORY_SIZE, 1)
// and persists it with save()/restore().
// NOTE(review): member semantics are defined in tiles2.cpp, which is not
// visible here; the notes below are inferred from the names only.
class collision_table {
public:
collision_table(int,int);
~collision_table();
long m; // presumably the table size -- confirm
long *data; // presumably the table storage -- confirm
int safe; // presumably the safety level given as second constructor argument -- confirm
long calls;
long clearhits;
long collisions;
void reset();
int usage();
// Called by CMAC::write / CMAC::read with the weights file name and the
// byte offset at which the weight table ends.
void save(char*, unsigned long);
void restore(char*, unsigned long);
};
// Same as the memory_size variant above, but takes a collision_table in
// place of a plain table size.
void GetTiles(
int tiles[], // provided array contains returned tiles (tile indices)
int num_tilings, // number of tile indices to be returned in tiles
collision_table *ctable, // collision table used instead of a memory size
float floats[], // array of floating point variables
int num_floats, // number of floating point variables
int ints[], // array of integer variables
int num_ints); // number of integer variables
// Hash functions over an array of num_ints integer coordinates.
// NOTE(review): implemented in tiles2.cpp (not visible here); presumably
// hash_UNH maps into the range [0, m) and hash() resolves collisions through
// ctable -- confirm against the implementation.
int hash_UNH(int *ints, int num_ints, long m, int increment);
int hash(int *ints, int num_ints, collision_table *ctable);
// Convenience overloads. Each shape comes in two flavours: one taking a plain
// memory size and one taking a collision_table. GetTiles takes an array of nf
// floats, GetTiles1 a single float and GetTiles2 two floats; h1..h3 are extra
// integer variables passed to the hash.
// no ints
void GetTiles(int tiles[],int nt,int memory,float floats[],int nf);
void GetTiles(int tiles[],int nt,collision_table *ct,float floats[],int nf);
// one int
void GetTiles(int tiles[],int nt,int memory,float floats[],int nf,int h1);
void GetTiles(int tiles[],int nt,collision_table *ct,float floats[],int nf,int h1);
// two ints
void GetTiles(int tiles[],int nt,int memory,float floats[],int nf,int h1,int h2);
void GetTiles(int tiles[],int nt,collision_table *ct,float floats[],int nf,int h1,int h2);
// three ints
void GetTiles(int tiles[],int nt,int memory,float floats[],int nf,int h1,int h2,int h3);
void GetTiles(int tiles[],int nt,collision_table *ct,float floats[],int nf,int h1,int h2,int h3);
// one float, no ints
void GetTiles1(int tiles[],int nt,int memory,float f1);
void GetTiles1(int tiles[],int nt,collision_table *ct,float f1);
// one float, one int
void GetTiles1(int tiles[],int nt,int memory,float f1,int h1);
void GetTiles1(int tiles[],int nt,collision_table *ct,float f1,int h1);
// one float, two ints
void GetTiles1(int tiles[],int nt,int memory,float f1,int h1,int h2);
void GetTiles1(int tiles[],int nt,collision_table *ct,float f1,int h1,int h2);
// one float, three ints
void GetTiles1(int tiles[],int nt,int memory,float f1,int h1,int h2,int h3);
void GetTiles1(int tiles[],int nt,collision_table *ct,float f1,int h1,int h2,int h3);
// two floats, no ints
void GetTiles2(int tiles[],int nt,int memory,float f1,float f2);
void GetTiles2(int tiles[],int nt,collision_table *ct,float f1,float f2);
// two floats, one int
void GetTiles2(int tiles[],int nt,int memory,float f1,float f2,int h1);
void GetTiles2(int tiles[],int nt,collision_table *ct,float f1,float f2,int h1);
// two floats, two ints
void GetTiles2(int tiles[],int nt,int memory,float f1,float f2,int h1,int h2);
void GetTiles2(int tiles[],int nt,collision_table *ct,float f1,float f2,int h1,int h2);
// two floats, three ints
void GetTiles2(int tiles[],int nt,int memory,float f1,float f2,int h1,int h2,int h3);
void GetTiles2(int tiles[],int nt,collision_table *ct,float f1,float f2,int h1,int h2,int h3);
#endif
#include <iostream>
#include <vector>
#include <HFO.hpp>
#include <cstdlib>
#include <thread>
#include "SarsaAgent.h"
#include "CMAC.h"
#include <unistd.h>
// Before running this program, first Start HFO server:
// $./bin/HFO --offense-agents numAgents
// Print the command-line help, covering every option that main() parses
// (including --eps).
void printUsage() {
    std::cout<<"Usage: ./high_level_sarsa_agent [Options]"<<std::endl;
    std::cout<<"Options:"<<std::endl;
    std::cout<<" --numAgents <int> Number of SARSA agents"<<std::endl;
    std::cout<<" Default: 1"<<std::endl;
    std::cout<<" --numEpisodes <int> Number of episodes to run"<<std::endl;
    std::cout<<" Default: 10"<<std::endl;
    std::cout<<" --basePort <int> SARSA agent base port"<<std::endl;
    std::cout<<" Default: 6000"<<std::endl;
    std::cout<<" --learnRate <float> Learning rate of SARSA agents"<<std::endl;
    std::cout<<" Range: [0.0, 1.0]"<<std::endl;
    std::cout<<" Default: 0.1"<<std::endl;
    std::cout<<" --eps <float> Exploration rate (epsilon) of SARSA agents"<<std::endl;
    std::cout<<" Default: 0.01"<<std::endl;
    std::cout<<" --suffix <int> Suffix for weights files"<<std::endl;
    std::cout<<" Default: 0"<<std::endl;
    std::cout<<" --noOpponent Sets opponent present flag to false"<<std::endl;
    std::cout<<" --help Displays this help and exit"<<std::endl;
}
// Maps an HFO status to the SARSA reward: +1 for a goal, -1 when the ball is
// captured by the defense or goes out of bounds, 0 for anything else.
double getReward(int status) {
    if (status==hfo::GOAL) {
        return 1;
    }
    if (status==hfo::CAPTURED_BY_DEFENSE || status==hfo::OUT_OF_BOUNDS) {
        return -1;
    }
    return 0;
}
// Fill `state` with only the required features from state_vec, in order.
//  - indices 0..5 and every index above 9 + 6*numTMates are dropped
//    (per the original author: the first six features and the opponent features);
//  - without an opponent, index 9 and indices 10+numTMates .. 9+2*numTMates are
//    dropped (distance to opponent / teammate-to-opponent distances);
//  - above index 9 + 3*numTMates only the first value of each group of three
//    is kept (the teammates' angle and uniform number are ignored).
// `state` must have room for 4 + 4*numTMates values (3 + 3*numTMates without
// an opponent).
void purgeFeatures(double *state, const std::vector<float>& state_vec,
                   int numTMates, bool oppPres) {
    const int numValues = static_cast<int>(state_vec.size());
    const int lastKept = 9 + 6 * numTMates;
    const int tmpIndex = 9 + 3 * numTMates;
    int stateIndex = 0;
    for (int i = 6; i < numValues && i <= lastKept; i++) {
        const bool opponentOnly =
            (i == 9) || (i > 9 + numTMates && i <= 9 + 2 * numTMates);
        if (!oppPres && opponentOnly) continue;
        // Offset into the per-teammate triples; only the first entry of each is used.
        const int temp = i - tmpIndex;
        if (temp > 0 && temp % 3 != 1) continue;
        state[stateIndex++] = state_vec[i];
    }
}
// Convert the agent's integer action to an hfo::action_t:
// 0 -> SHOOT, 1 -> DRIBBLE, anything else -> PASS.
// The pass target is not chosen here; the caller supplies it to hfo.act().
hfo::action_t toAction(int action, const std::vector<float>& state_vec) {
    (void)state_vec; // kept for interface compatibility; not needed to pick the action type
    switch(action) {
        case 0:
            return hfo::SHOOT;
        case 1:
            return hfo::DRIBBLE;
        default:
            return hfo::PASS;
    }
}
// Thread body of one SARSA offense agent.
//   port      - used only to build the weights file name
//   numTMates - number of teammates (sets the feature and action counts)
//   numEpi    - number of episodes to play
//   learnR    - SARSA learning rate
//   suffix    - suffix of the weights file name
//   oppPres   - whether opponent-related features are part of the state
//   eps       - exploration rate handed to the SARSA agent
// The agent only learns while it has the ball; otherwise it plays hfo::MOVE.
void offenseAgent(int port, int numTMates, int numEpi, double learnR,
                  int suffix, bool oppPres, double eps) {
    // Number of features
    const int numF = oppPres ? (4 + 4 * numTMates) : (3 + 3 * numTMates);
    // Number of actions
    const int numA = 2 + numTMates;
    const double discFac = 1;
    // Tile coding parameter
    const double resolution = 0.1;
    // std::vector instead of variable-length arrays, which are not standard C++.
    std::vector<double> range(numF, 2.0);
    std::vector<double> min(numF, -1.0);
    std::vector<double> res(numF, resolution);
    // Weights file
    std::string s = "weights_" + std::to_string(port) +
                    "_" + std::to_string(numTMates + 1) +
                    "_" + std::to_string(suffix);
    char *wtFile = &s[0u];
    CMAC *fa = new CMAC(numF, numA, range.data(), min.data(), res.data());
    SarsaAgent *sa = new SarsaAgent(numF, numA, learnR, eps, fa, wtFile, wtFile);
    hfo::HFOEnvironment hfo;
    hfo::status_t status;
    hfo::action_t a;
    std::vector<double> state(numF, 0.0);
    int action = -1;
    // NOTE(review): the server port is hard-coded to 6000 and `port` is not
    // used for the connection -- confirm this is intended.
    hfo.connectToServer(hfo::HIGH_LEVEL_FEATURE_SET,"../../bin/teams/base/config/formations-dt",6000,"localhost","base_left",false,"");
    for (int episode=0; episode < numEpi; episode++) {
        status = hfo::IN_GAME;
        action = -1;
        while (status == hfo::IN_GAME) {
            const std::vector<float>& state_vec = hfo.getState();
            // If has ball
            if(state_vec[5] == 1) {
                // Report the previous (state, action) and the reward observed
                // since; the agent performs its SARSA update one step late.
                if(action != -1) {
                    const double reward = getReward(status);
                    sa->update(state.data(), action, reward, discFac);
                }
                // Fill up state array
                purgeFeatures(state.data(), state_vec, numTMates, oppPres);
                // Get raw action
                action = sa->selectAction(state.data());
                // Get hfo::Action
                a = toAction(action, state_vec);
            } else {
                a = hfo::MOVE;
            }
            if (a== hfo::PASS) {
                // Actions 2.. select a teammate; pass the matching feature as target.
                hfo.act(a,state_vec[(9+6*numTMates) - (action-2)*3]);
            } else {
                hfo.act(a);
            }
            status = hfo.step();
        }
        // End of episode
        if(action != -1) {
            const double reward = getReward(status);
            sa->update(state.data(), action, reward, discFac);
            sa->endEpisode();
        }
    }
    delete sa;
    delete fa;
}
// Parse the command line, then run one offenseAgent thread per agent.
// Any unknown option, missing option value, out-of-range learning rate or
// agent count below 1 prints the usage text and exits.
int main(int argc, char **argv) {
    int numAgents = 1;
    int numEpisodes = 10;
    int basePort = 6000;
    double learnR = 0.1;
    int suffix = 0;
    bool opponentPresent = true;
    double eps = 0.01;
    for(int i = 1; i < argc; i++) {
        const std::string param = std::string(argv[i]);
        if(param == "--noOpponent") {
            opponentPresent = false;
            continue;
        }
        const bool takesValue = param == "--numAgents" || param == "--numEpisodes" ||
                                param == "--basePort" || param == "--learnRate" ||
                                param == "--suffix" || param == "--eps";
        // Also covers --help, unknown options and an option given without its value.
        if(!takesValue || i + 1 >= argc) {
            printUsage();
            return 0;
        }
        const char *value = argv[++i];
        if(param == "--numAgents") {
            numAgents = atoi(value);
        }else if(param == "--numEpisodes") {
            numEpisodes = atoi(value);
        }else if(param == "--basePort") {
            basePort = atoi(value);
        }else if(param == "--learnRate") {
            learnR = atof(value);
            if(learnR < 0 || learnR > 1) {
                printUsage();
                return 0;
            }
        }else if(param == "--suffix") {
            suffix = atoi(value);
        }else {
            // --eps is a fraction, so it must be parsed as a float
            // (atoi would turn e.g. 0.05 into 0).
            eps = atof(value);
        }
    }
    if(numAgents < 1) {
        printUsage();
        return 0;
    }
    const int numTeammates = numAgents - 1;
    std::vector<std::thread> agentThreads;
    agentThreads.reserve(numAgents);
    for (int agent = 0; agent < numAgents; agent++) {
        agentThreads.emplace_back(offenseAgent, basePort + agent,
                                  numTeammates, numEpisodes, learnR,
                                  suffix, opponentPresent, eps);
        // Stagger the agents' start-up.
        usleep(500000L);
    }
    for (std::thread &agentThread : agentThreads) {
        agentThread.join();
    }
    return 0;
}
# Builds the static library libpolicyagent.a from the policy-agent sources.
#Directories
FA_DIR = ../funcapprox
#Includes
INCLUDES = -I$(FA_DIR)
#Flags
CXXFLAGS = -g -O3 -Wall
#Compiler
CXX = g++
#Sources
SRCS = PolicyAgent.cpp SarsaAgent.cpp
#Objects
OBJS = $(SRCS:.cpp=.o)
#Target
TARGET = libpolicyagent.a
#Rules
.PHONY: all clean
all: $(TARGET)
.cpp.o:
	$(CXX) $(CXXFLAGS) $(INCLUDES) -c -o $@ $(@F:%.o=%.cpp)
# 'r' replaces members that already exist in the archive ('q' would append a
# second copy and leave the stale one first), 'c' creates the archive quietly,
# 's' writes the symbol index.
$(TARGET): $(OBJS)
	ar rcs $@ $(OBJS)
clean:
	rm -f $(TARGET) $(OBJS) *~
#include "PolicyAgent.h"
// Store the learning parameters and the function approximator.
// A non-empty loadWeightsFile is loaded into FA immediately; a non-empty
// saveWeightsFile is remembered for later saves. NULL names are treated as
// empty, and names that do not fit the internal buffers are truncated (with a
// warning) instead of overflowing them.
PolicyAgent::PolicyAgent(int numFeatures, int numActions, double learningRate, double epsilon, FunctionApproximator *FA, char *loadWeightsFile, char *saveWeightsFile){
    this->numFeatures = numFeatures;
    this->numActions = numActions;
    this->learningRate = learningRate;
    this->epsilon = epsilon;
    this->FA = FA;
    // Start from empty, terminated names so the buffers are never read uninitialised.
    this->loadWeightsFile[0] = '\0';
    this->saveWeightsFile[0] = '\0';
    toLoadWeights = loadWeightsFile != NULL && strlen(loadWeightsFile) > 0;
    if(toLoadWeights){
        if(strlen(loadWeightsFile) >= sizeof(this->loadWeightsFile)){
            std::cerr << "PolicyAgent: load file name too long, truncated: " << loadWeightsFile << "\n";
        }
        strncpy(this->loadWeightsFile, loadWeightsFile, sizeof(this->loadWeightsFile) - 1);
        this->loadWeightsFile[sizeof(this->loadWeightsFile) - 1] = '\0';
        loadWeights(loadWeightsFile);
    }
    toSaveWeights = saveWeightsFile != NULL && strlen(saveWeightsFile) > 0;
    if(toSaveWeights){
        if(strlen(saveWeightsFile) >= sizeof(this->saveWeightsFile)){
            std::cerr << "PolicyAgent: save file name too long, truncated: " << saveWeightsFile << "\n";
        }
        strncpy(this->saveWeightsFile, saveWeightsFile, sizeof(this->saveWeightsFile) - 1);
        this->saveWeightsFile[sizeof(this->saveWeightsFile) - 1] = '\0';
    }
}
// Nothing to release: the function approximator is not deleted here.
PolicyAgent::~PolicyAgent(){}
int PolicyAgent::getNumFeatures(){
return numFeatures;
}
int PolicyAgent::getNumActions(){
return numActions;
}
void PolicyAgent::loadWeights(char *fileName){
std::cout << "Loading Weights from " << fileName << std::endl;
FA->read(fileName);
}
void PolicyAgent::saveWeights(char *fileName){
FA->write(fileName);
}
// Default implementation: ignores the state and returns a uniformly random
// action index drawn with drand48().
int PolicyAgent::argmaxQ(double state[]){
    const int actionCount = getNumActions();
    const int pick = (int)(drand48() * actionCount);
    return pick % actionCount;
}
// Default implementation: every state/action pair has value 0.
double PolicyAgent::computeQ(double state[], int action){
    const double defaultValue = 0;
    return defaultValue;
}
#ifndef POLICY_AGENT
#define POLICY_AGENT
#include <cstring>
#include <fstream>
#include <iostream>
#include "FuncApprox.h"
// Same values as the definitions in FuncApprox.h, which this header includes.
#define MAX_STATE_VARS 100
#define MAX_ACTIONS 10
// Abstract base class for learning agents that choose among numActions
// actions using a FunctionApproximator. Subclasses implement action
// selection and the learning update.
class PolicyAgent{
private:
    int numFeatures;
    int numActions;
protected:
    double learningRate;
    double epsilon;
    bool toLoadWeights;         // true when a non-empty load file name was given
    char loadWeightsFile[256];
    bool toSaveWeights;         // true when a non-empty save file name was given
    char saveWeightsFile[256];
    FunctionApproximator *FA;   // not deleted by this class
    int getNumFeatures();
    int getNumActions();
public:
    PolicyAgent(int numFeatures, int numActions, double learningRate, double epsilon, FunctionApproximator *FA, char *loadWeightsFile, char *saveWeightsFile);
    // Virtual so that deleting a derived agent through a PolicyAgent pointer
    // is well defined.
    virtual ~PolicyAgent();
    virtual int argmaxQ(double state[]);
    virtual double computeQ(double state[], int action);
    virtual int selectAction(double state[]) = 0;
    virtual void update(double state[], int action, double reward, double discountFactor) = 0;
    virtual void endEpisode() = 0;
    virtual void reset() = 0;
    virtual void loadWeights(char *filename);
    virtual void saveWeights(char *filename);
};
#endif
#include "SarsaAgent.h"
// Forward all parameters to PolicyAgent and start with no episode played
// and no pending (state, action) pair.
SarsaAgent::SarsaAgent(int numFeatures, int numActions, double learningRate, double epsilon, FunctionApproximator *FA, char *loadWeightsFile, char *saveWeightsFile)
    : PolicyAgent(numFeatures, numActions, learningRate, epsilon, FA, loadWeightsFile, saveWeightsFile),
      episodeNumber(0),
      lastAction(-1){
}
// One SARSA step. If a previous (state, action, reward) triple is pending, its
// Q-value is moved towards lastReward + discountFactor * Q(state, action).
// In every case the arguments are then stored as the new pending triple.
void SarsaAgent::update(double state[], int action, double reward, double discountFactor){
    const bool havePending = (lastAction != -1);
    if(havePending){
        FA->setState(lastState);
        const double previousQ = FA->computeQ(lastAction);
        FA->updateTraces(lastAction);
        double tdError = lastReward - previousQ;
        FA->setState(state);
        //Sarsa update
        const double nextQ = FA->computeQ(action);
        tdError += discountFactor * nextQ;
        FA->updateWeights(tdError, learningRate);
        //Traces are decayed with rate 0, i.e. lambda is taken to be 0.
        FA->decayTraces(0);
    }
    const int featureCount = getNumFeatures();
    for(int k = 0; k < featureCount; ++k){
        lastState[k] = state[k];
    }
    lastAction = action;
    lastReward = reward;
}
// Close the episode: apply the terminal update for the pending
// (state, action, reward) triple (no bootstrap term), save the weights when
// the schedule says so, and clear the pending action.
void SarsaAgent::endEpisode(){
    ++episodeNumber;
    //Usually there is a pending action; this is only a safety check.
    if(lastAction == -1){
        return;
    }
    FA->setState(lastState);
    const double terminalQ = FA->computeQ(lastAction);
    FA->updateTraces(lastAction);
    const double tdError = lastReward - terminalQ;
    FA->updateWeights(tdError, learningRate);
    //Traces are decayed with rate 0, i.e. lambda is taken to be 0.
    FA->decayTraces(0);
    // NOTE(review): with the +1 this fires after episodes 4, 9, 14, ...
    // rather than 5, 10, 15 -- confirm the schedule is intended.
    const bool saveNow = toSaveWeights && (episodeNumber + 1) % 5 == 0;
    if(saveNow){
        saveWeights(saveWeightsFile);
        std::cout << "Saving weights to " << saveWeightsFile << std::endl;
    }
    lastAction = -1;
}
void SarsaAgent::reset(){
lastAction = -1;
}
// Epsilon-greedy selection: with probability epsilon a uniformly random
// action, otherwise the greedy action for `state`.
int SarsaAgent::selectAction(double state[]){
    const bool explore = drand48() < epsilon;
    if(!explore){
        return argmaxQ(state);
    }
    return (int)(drand48() * getNumActions()) % getNumActions();
}
// Greedy action for `state`. Loads the state into the function approximator
// (overwriting FA's current state) and scans the actions; values within 1e-4
// of the current best count as ties, and the k-th tie replaces the incumbent
// with probability 1/(k+1), drawn with drand48().
// Q-values are evaluated on the fly, so no variable-length array (a
// non-standard extension) is needed.
int SarsaAgent::argmaxQ(double state[]){
    FA->setState(state);
    const double EPS = 1.0e-4;
    const int actionCount = getNumActions();
    int bestAction = 0;
    double bestValue = FA->computeQ(bestAction);
    int numTies = 0;
    for (int a = 1; a < actionCount; a++){
        const double value = FA->computeQ(a);
        if(fabs(value - bestValue) < EPS){
            numTies++;
            if(drand48() < (1.0 / (numTies + 1))){
                bestValue = value;
                bestAction = a;
            }
        }
        else if (value > bestValue){
            bestValue = value;
            bestAction = a;
            numTies = 0;
        }
    }
    return bestAction;
}
//Q-value of `action` in `state`.
//Be careful: this overwrites the state held by FA.
double SarsaAgent::computeQ(double state[], int action){
    FA->setState(state);
    return FA->computeQ(action);
}
#ifndef SARSA_AGENT
#define SARSA_AGENT
#include "PolicyAgent.h"
#include "FuncApprox.h"
// SARSA learning agent on top of PolicyAgent. It keeps the previous
// (state, action, reward) triple and updates the function approximator one
// step late, once the following state/action pair is known.
class SarsaAgent:public PolicyAgent{
private:
// Episodes finished so far; used to decide when to save the weights.
int episodeNumber;
// Pending triple from the previous update() call; lastAction == -1 means none.
double lastState[MAX_STATE_VARS];
int lastAction;
double lastReward;
public:
SarsaAgent(int numFeatures, int numActions, double learningRate, double epsilon, FunctionApproximator *FA, char *loadWeightsFile, char *saveWeightsFile);
// Greedy action / Q-value for a state; both overwrite the state held by FA.
int argmaxQ(double state[]);
double computeQ(double state[], int action);
// Epsilon-greedy action selection.
int selectAction(double state[]);
// Learn from the pending triple, then store the arguments as the new one.
void update(double state[], int action, double reward, double discountFactor);
// Terminal update for the pending triple; periodically saves the weights.
void endEpisode();
// Drop the pending triple without learning from it.
void reset();
};
#endif
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment