Commit 08411cb1 authored by Siddharth Aravindan's avatar Siddharth Aravindan Committed by asiddharth

Refactored SARSA libraries, added hand coded agent

parent 384fb648
This diff is collapsed.
# Run this file to create an executable of hand_coded_defense_agent.cpp
# Compile step: pick up the HFO headers from ../src; C++0x and pthreads
# are required by the HFO client library.
g++ -c hand_coded_defense_agent.cpp -I ../src/ -std=c++0x -pthread
# Link step: link against libhfo in ../lib and bake that path into the
# binary via -rpath so it runs without setting LD_LIBRARY_PATH.
g++ -L ../lib/ hand_coded_defense_agent.o -lhfo -pthread -o hand_coded_defense_agent -Wl,-rpath,../lib
#Directories
FA_DIR = ./funcapprox
POLICY_DIR = ./policy
HFO_SRC_DIR = ../../src
HFO_LIB_DIR = ../../lib
#Includes
INCLUDES = -I$(FA_DIR) -I$(POLICY_DIR) -I$(HFO_SRC_DIR)
#Libs
FA_LIB = funcapprox
POLICY_LIB = policyagent
#Flags
CXXFLAGS = -g -Wall -std=c++11 -pthread
LDFLAGS = -l$(FA_LIB) -l$(POLICY_LIB) -lhfo -pthread
LDLIBS = -L$(FA_DIR) -L$(POLICY_DIR) -L$(HFO_LIB_DIR)
LINKEROPTIONS = -Wl,-rpath,$(HFO_LIB_DIR)
#Compiler
CXX = g++
#Sources
SRC = high_level_sarsa_agent.cpp
#Objects
OBJ = $(SRC:.cpp=.o)
#Target
TARGET = high_level_sarsa_agent
#Rules
# Mark every non-file target as phony so a stray file with one of these
# names never shadows its rule (the original only marked $(FA_LIB),
# leaving $(POLICY_LIB), all and the clean targets unprotected).
.PHONY: all clean cleanfa cleanpolicy $(FA_LIB) $(POLICY_LIB)
all: $(TARGET)
.cpp.o:
	$(CXX) $(CXXFLAGS) $(INCLUDES) -c -o $@ $(@F:%.o=%.cpp)
# Build the function-approximator static library in its subdirectory.
$(FA_LIB):
	$(MAKE) -C $(FA_DIR)
# Build the policy-agent static library in its subdirectory.
$(POLICY_LIB):
	$(MAKE) -C $(POLICY_DIR)
$(TARGET): $(FA_LIB) $(POLICY_LIB) $(OBJ)
	$(CXX) $(OBJ) $(CXXFLAGS) $(LDLIBS) $(LDFLAGS) -o $(TARGET) $(LINKEROPTIONS)
cleanfa:
	$(MAKE) clean -C $(FA_DIR)
cleanpolicy:
	$(MAKE) clean -C $(POLICY_DIR)
clean: cleanfa cleanpolicy
	rm -f $(TARGET) $(OBJ) *~
#Flags
CXXFLAGS = -g -O3 -Wall
#Compiler
CXX = g++
#Sources
SRCS = FuncApprox.cpp tiles2.cpp CMAC.cpp
#Objects
OBJS = $(SRCS:.cpp=.o)
#Target
TARGET = libfuncapprox.a
#Rules
.PHONY: all clean
all: $(TARGET)
.cpp.o:
	$(CXX) $(CXXFLAGS) -c -o $@ $(@F:%.o=%.cpp)
# "ar rcs" replaces existing members and writes a symbol index; the
# previous "ar cq" blindly appended, so incremental rebuilds could
# leave stale duplicate objects inside the archive.
$(TARGET): $(OBJS)
	ar rcs $@ $(OBJS)
clean:
	rm -f $(TARGET) $(OBJS) *~
#Directories
FA_DIR = ../funcapprox
#Includes
INCLUDES = -I$(FA_DIR)
#Flags
CXXFLAGS = -g -O3 -Wall
#Compiler
CXX = g++
#Sources
SRCS = PolicyAgent.cpp SarsaAgent.cpp
#Objects
OBJS = $(SRCS:.cpp=.o)
#Target
TARGET = libpolicyagent.a
#Rules
.PHONY: all clean
all: $(TARGET)
.cpp.o:
	$(CXX) $(CXXFLAGS) $(INCLUDES) -c -o $@ $(@F:%.o=%.cpp)
# "ar rcs" replaces existing members and writes a symbol index; the
# previous "ar cq" blindly appended, so incremental rebuilds could
# leave stale duplicate objects inside the archive.
$(TARGET): $(OBJS)
	ar rcs $@ $(OBJS)
clean:
	rm -f $(TARGET) $(OBJS) *~
#include <iostream>
#include <vector>
#include <HFO.hpp>
#include <cstdlib>
#include <thread>
#include "SarsaAgent.h"
#include "CMAC.h"
#include <unistd.h>
// Before running this program, first Start HFO server:
// $./bin/HFO --offense-agents numAgents
// Prints the command-line usage text for this agent to stdout.
// Fix: the usage line contained a debug artifact ("Usage:123 ...").
void printUsage() {
  std::cout<<"Usage: ./high_level_sarsa_agent [Options]"<<std::endl;
  std::cout<<"Options:"<<std::endl;
  std::cout<<"  --numAgents <int>        Number of SARSA agents"<<std::endl;
  std::cout<<"                           Default: 0"<<std::endl;
  std::cout<<"  --numEpisodes <int>      Number of episodes to run"<<std::endl;
  std::cout<<"                           Default: 10"<<std::endl;
  std::cout<<"  --basePort <int>         SARSA agent base port"<<std::endl;
  std::cout<<"                           Default: 6000"<<std::endl;
  std::cout<<"  --learnRate <float>      Learning rate of SARSA agents"<<std::endl;
  std::cout<<"                           Range: [0.0, 1.0]"<<std::endl;
  std::cout<<"                           Default: 0.1"<<std::endl;
  std::cout<<"  --suffix <int>           Suffix for weights files"<<std::endl;
  std::cout<<"                           Default: 0"<<std::endl;
  std::cout<<"  --noOpponent             Sets opponent present flag to false"<<std::endl;
  std::cout<<"  --step                   Sets the persistent step size"<<std::endl;
  std::cout<<"  --eps                    Sets the exploration rate"<<std::endl;
  std::cout<<"  --numOpponents           Sets the number of opponents"<<std::endl;
  std::cout<<"  --weightId               Sets the given Id for weight File"<<std::endl;
  std::cout<<"  --help                   Displays this help and exit"<<std::endl;
}
// Maps an episode-termination status to the defensive SARSA reward:
// +1 when the defense succeeds (ball captured or out of bounds),
// -1 when the offense scores a goal, 0 for any other status.
double getReward(hfo::status_t status) {
  switch (status) {
    case hfo::GOAL:
      return -1;
    case hfo::CAPTURED_BY_DEFENSE:
    case hfo::OUT_OF_BOUNDS:
      return 1;
    default:
      return 0;
  }
}
// Copies the subset of high-level features used for learning from
// state_vec into the fixed-size array `state`.
//
// Entries skipped:
//  - raw indices 5 and 8 (unused scalar features)
//  - goal-opening angles (indices 10 .. 9+numTMates), invalid here
//  - pass-opening angles (indices 9+2*numTMates+1 .. 9+3*numTMates)
//  - every third feature past tmpIndex (uniform numbers of
//    teammates/opponents)
//
// The caller must size `state` to hold every surviving feature (see
// the numF computation in offenseAgent). numOpponents is currently
// unused here but kept for interface stability.
void purgeFeatures(double *state, const std::vector<float>& state_vec,
                   int numTMates, int numOpponents, bool oppPres) {
  int stateIndex = 0;
  // With no opponents, the "distance to opponent" and "distance from
  // teammate i to opponent" features are absent, which shifts where
  // the uniform-number features begin.
  int tmpIndex = oppPres ? (9 + 3 * numTMates) : (9 + 2 * numTMates);
  // Cache the size as int to avoid the signed/unsigned comparison the
  // original loop performed on every iteration.
  const int numFeatures = static_cast<int>(state_vec.size());
  for (int i = 0; i < numFeatures; i++) {
    // Skip the two unused scalar features.
    if (i == 5 || i == 8) continue;
    if (i > 9 && i <= 9 + numTMates) continue; // Goal-opening angles: invalid
    if (i <= 9 + 3 * numTMates && i > 9 + 2 * numTMates) continue; // Pass-opening angles: invalid
    // Skip uniform numbers of teammates and opponents.
    int temp = i - tmpIndex;
    if (temp > 0 && (temp % 3 == 0)) continue;
    state[stateIndex] = state_vec[i];
    stateIndex++;
  }
}
// Translates a raw SARSA action index into an HFO action. Indices 0-4
// map to fixed defensive actions; any other index selects MARK_PLAYER
// (the caller resolves which opponent to mark). state_vec is unused
// here but retained for interface compatibility with callers.
hfo::action_t toAction(int action, const std::vector<float>& state_vec) {
  (void)state_vec; // kept for interface compatibility only
  static const hfo::action_t kFixedActions[] = {
    hfo::MOVE,
    hfo::REDUCE_ANGLE_TO_GOAL,
    hfo::GO_TO_BALL,
    hfo::NOOP,
    hfo::DEFEND_GOAL
  };
  if (action >= 0 && action <= 4) {
    return kFixedActions[action];
  }
  return hfo::MARK_PLAYER;
}
// Runs a single SARSA learning agent: connects to the HFO server on
// `port`, plays `numEpi` episodes on team "base_right", and updates a
// CMAC-backed SARSA learner after each decision point.
//
// port          server port for this agent
// numTMates     number of teammates (drives feature/weight-file naming)
// numOpponents  number of opponents (adds one MARK_PLAYER action each)
// numEpi        episodes to play
// learnR        SARSA learning rate
// suffix        numeric suffix for the weights file name
// oppPres       whether opponents are present (changes feature count)
// eps           exploration rate -- NOTE(review): overwritten to 0.01
//               below, so the value passed in is currently ignored
// step          number of frames an action is persisted before re-deciding
// weightid      id string appended to the weights file name
void offenseAgent(int port, int numTMates, int numOpponents, int numEpi, double learnR,
int suffix, bool oppPres, double eps, int step, std::string weightid) {
// Number of features kept by purgeFeatures() for this configuration.
int numF = oppPres ? (8 + 3 * numTMates + 2*numOpponents) : (3 + 3 * numTMates);
// Number of actions
int numA = 5 + numOpponents; //DEF_GOAL+MOVE+GTB+NOOP+RATG+MP(unum)
// Other SARSA parameters
// NOTE(review): this clobbers the command-line eps -- confirm intended.
eps = 0.01;
double discFac = 1;
double lambda=0.9375;
// Tile coding parameter
double resolution = 0.1;
// Per-feature tile-coding config; all features assumed to lie in [-1, 1].
double range[numF];
double min[numF];
double res[numF];
for(int i = 0; i < numF; i++) {
min[i] = -1;
range[i] = 2;
res[i] = resolution;
}
// Weights file: name encodes port, team size, suffix, step and id so
// concurrently running agents do not clobber each other's files.
char *wtFile;
std::string s = "weights_" + std::to_string(port) +
"_" + std::to_string(numTMates + 1) +
"_" + std::to_string(suffix) +
"_" + std::to_string(step) +
"_" + weightid;
// Non-owning pointer into s; s outlives every use (whole function scope).
wtFile = &s[0u];
CMAC *fa = new CMAC(numF, numA, range, min, res);
// The same file is used both to load and to save weights.
SarsaAgent *sa = new SarsaAgent(numF, numA, learnR, eps, lambda, fa, wtFile, wtFile);
hfo::HFOEnvironment hfo;
hfo::status_t status;
hfo::action_t a;
double state[numF];
int action = -1;
double reward;
int no_of_offense = numTMates + 1;
hfo.connectToServer(hfo::HIGH_LEVEL_FEATURE_SET,"../../bin/teams/base/config/formations-dt",port,"localhost","base_right",false,"");
for (int episode=0; episode < numEpi; episode++) {
int count = 0;
status = hfo::IN_GAME;
action = -1;
int count_steps = 0;
double unum = -1;
// Reference into the environment's state buffer; contents are
// refreshed by hfo.step() -- presumably stays valid for the episode
// (TODO confirm against HFOEnvironment docs).
const std::vector<float>& state_vec = hfo.getState();
int num_steps_per_epi = 0;
while (status == hfo::IN_GAME) {
num_steps_per_epi++;
//std::cout << "::::::"<< hfo::ActionToString(a) <<" "<<count_steps <<std::endl;
// Persist the previously chosen action for up to `step` frames before
// asking SARSA again; a MARK_PLAYER action is only repeated once a
// valid uniform number is known (unum > 0).
if (count_steps != step && action >=0 && (a != hfo :: MARK_PLAYER || unum>0)) {
count_steps ++;
if (a == hfo::MARK_PLAYER) {
hfo.act(a,unum);
//std::cout << "MARKING" << unum <<"\n";
} else {
hfo.act(a);
}
status = hfo.step();
continue;
} else {
count_steps = 0;
}
// SARSA update for the previous (state, action) pair, rewarded by the
// status the environment reached since that action was taken.
if(action != -1) {
reward = getReward(status);
sa->update(state, action, reward, discFac);
}
// Fill up state array
purgeFeatures(state, state_vec, numTMates, numOpponents, oppPres);
// Get raw action
action = sa->selectAction(state);
// Get hfo::Action
a = toAction(action, state_vec);
if (a== hfo::MARK_PLAYER) {
// Actions >= 5 each mark a specific opponent; recover that
// opponent's uniform number from the tail of the raw feature vector.
unum = state_vec[(state_vec.size()-1 - (action-5)*3)];
hfo.act(a,unum);
} else {
hfo.act(a);
}
// Debug trace of the chosen action and raw state (printing disabled below).
std::string s = std::to_string(action);
for (int state_vec_fc=0; state_vec_fc < state_vec.size(); state_vec_fc++) {
s+=std::to_string(state_vec[state_vec_fc]) + ",";
}
s+="UNUM" +std::to_string(unum) +"\n";;
status = hfo.step();
// std::cout <<s;
}
//std :: cout <<":::::::::::::" << num_steps_per_epi<< " "<<step << " "<<"\n";
// End of episode: apply the terminal reward, then close the episode
// (SarsaAgent::endEpisode also periodically saves the weights file).
if(action != -1) {
reward = getReward(status);
sa->update(state, action, reward, discFac);
sa->endEpisode();
}
}
delete sa;
delete fa;
}
// Entry point: parses the command-line options, then launches one SARSA
// agent thread per requested agent (staggered so each can connect to
// the HFO server in turn) and joins them all. Prints usage and exits
// on any malformed argument.
int main(int argc, char **argv) {
  // Defaults; see printUsage for the meaning of each option.
  int numAgents = 0;
  int numEpisodes = 10;
  int basePort = 6000;
  double learnR = 0.1;
  int suffix = 0;
  bool opponentPresent = true;
  int numOpponents = 0;
  double eps = 0.01;
  int step = 10;
  std::string weightid;
  // Echo the raw invocation, useful when launching many agents.
  for (int i = 0; i < argc; i++) {
    std::string param = std::string(argv[i]);
    std::cout << param << "\n";
  }
  for (int i = 1; i < argc; i++) {
    std::string param = std::string(argv[i]);
    // Returns the value following the current flag; the original code
    // dereferenced argv[argc] (NULL) when a flag had no value.
    auto nextArg = [&]() -> const char* {
      if (i + 1 >= argc) {
        printUsage();
        std::exit(0);
      }
      return argv[++i];
    };
    if (param == "--numAgents") {
      numAgents = atoi(nextArg());
    } else if (param == "--numEpisodes") {
      numEpisodes = atoi(nextArg());
    } else if (param == "--basePort") {
      basePort = atoi(nextArg());
    } else if (param == "--learnRate") {
      learnR = atof(nextArg());
      if (learnR < 0 || learnR > 1) {
        printUsage();
        return 0;
      }
    } else if (param == "--suffix") {
      suffix = atoi(nextArg());
    } else if (param == "--noOpponent") {
      opponentPresent = false;
    } else if (param == "--eps") {
      // BUG FIX: eps is a double; atoi() truncated values like 0.05 to 0.
      eps = atof(nextArg());
    } else if (param == "--numOpponents") {
      numOpponents = atoi(nextArg());
    } else if (param == "--step") {
      step = atoi(nextArg());
    } else if (param == "--weightId") {
      weightid = std::string(nextArg());
    } else {
      printUsage();
      return 0;
    }
  }
  int numTeammates = numAgents; //using goalie npc
  // std::vector instead of the original variable-length array, which is
  // a non-standard extension in C++.
  std::vector<std::thread> agentThreads;
  agentThreads.reserve(numAgents);
  for (int agent = 0; agent < numAgents; agent++) {
    agentThreads.emplace_back(offenseAgent, basePort + agent,
                              numTeammates, numOpponents, numEpisodes, learnR,
                              suffix, opponentPresent, eps, step, weightid);
    // Stagger startup so agents register with the server in order.
    usleep(500000L);
  }
  for (auto& t : agentThreads) {
    t.join();
  }
  return 0;
}
#include "SarsaAgent.h" #include "SarsaAgent.h"
SarsaAgent::SarsaAgent(int numFeatures, int numActions, double learningRate, double epsilon, FunctionApproximator *FA, char *loadWeightsFile, char *saveWeightsFile):PolicyAgent(numFeatures, numActions, learningRate, epsilon, FA, loadWeightsFile, saveWeightsFile){ //add lambda as parameter to sarsaagent
SarsaAgent::SarsaAgent(int numFeatures, int numActions, double learningRate, double epsilon, double lambda, FunctionApproximator *FA, char *loadWeightsFile, char *saveWeightsFile):PolicyAgent(numFeatures, numActions, learningRate, epsilon, FA, loadWeightsFile, saveWeightsFile){
this->lambda = lambda;
episodeNumber = 0; episodeNumber = 0;
lastAction = -1; lastAction = -1;
//have memory for lambda
} }
void SarsaAgent::update(double state[], int action, double reward, double discountFactor){ void SarsaAgent::update(double state[], int action, double reward, double discountFactor){
...@@ -34,7 +35,7 @@ void SarsaAgent::update(double state[], int action, double reward, double discou ...@@ -34,7 +35,7 @@ void SarsaAgent::update(double state[], int action, double reward, double discou
FA->updateWeights(delta, learningRate); FA->updateWeights(delta, learningRate);
//Assume gamma, lambda are 0. //Assume gamma, lambda are 0.
FA->decayTraces(0); FA->decayTraces(discountFactor*lambda);//replace 0 with gamma*lambda
for(int i = 0; i < getNumFeatures(); i++){ for(int i = 0; i < getNumFeatures(); i++){
lastState[i] = state[i]; lastState[i] = state[i];
...@@ -59,8 +60,8 @@ void SarsaAgent::endEpisode(){ ...@@ -59,8 +60,8 @@ void SarsaAgent::endEpisode(){
double delta = lastReward - oldQ; double delta = lastReward - oldQ;
FA->updateWeights(delta, learningRate); FA->updateWeights(delta, learningRate);
//Assume lambda is 0. //Assume lambda is 0. this comment looks wrong.
FA->decayTraces(0); FA->decayTraces(0);//remains 0
} }
if(toSaveWeights && (episodeNumber + 1) % 5 == 0){ if(toSaveWeights && (episodeNumber + 1) % 5 == 0){
......
...@@ -12,10 +12,11 @@ class SarsaAgent:public PolicyAgent{ ...@@ -12,10 +12,11 @@ class SarsaAgent:public PolicyAgent{
double lastState[MAX_STATE_VARS]; double lastState[MAX_STATE_VARS];
int lastAction; int lastAction;
double lastReward; double lastReward;
double lambda;
public: public:
SarsaAgent(int numFeatures, int numActions, double learningRate, double epsilon, FunctionApproximator *FA, char *loadWeightsFile, char *saveWeightsFile); SarsaAgent(int numFeatures, int numActions, double learningRate, double epsilon, double lambda, FunctionApproximator *FA, char *loadWeightsFile, char *saveWeightsFile);
int argmaxQ(double state[]); int argmaxQ(double state[]);
double computeQ(double state[], int action); double computeQ(double state[], int action);
......
...@@ -56,7 +56,7 @@ void purgeFeatures(double *state, const std::vector<float>& state_vec, ...@@ -56,7 +56,7 @@ void purgeFeatures(double *state, const std::vector<float>& state_vec,
// Ignore Angle and Uniform Number of Teammates // Ignore Angle and Uniform Number of Teammates
int temp = i-tmpIndex; int temp = i-tmpIndex;
if(temp > 0 && (temp % 3 == 2 || temp % 3 == 0)) continue; if(temp > 0 && (temp % 3 == 2 || temp % 3 == 0)) continue;
if (i > 9+6*numTMates) continue;
state[stateIndex] = state_vec[i]; state[stateIndex] = state_vec[i];
stateIndex++; stateIndex++;
} }
...@@ -107,9 +107,9 @@ void offenseAgent(int port, int numTMates, int numEpi, double learnR, ...@@ -107,9 +107,9 @@ void offenseAgent(int port, int numTMates, int numEpi, double learnR,
"_" + std::to_string(numTMates + 1) + "_" + std::to_string(numTMates + 1) +
"_" + std::to_string(suffix); "_" + std::to_string(suffix);
wtFile = &s[0u]; wtFile = &s[0u];
double lambda = 0;
CMAC *fa = new CMAC(numF, numA, range, min, res); CMAC *fa = new CMAC(numF, numA, range, min, res);
SarsaAgent *sa = new SarsaAgent(numF, numA, learnR, eps, fa, wtFile, wtFile); SarsaAgent *sa = new SarsaAgent(numF, numA, learnR, eps, lambda, fa, wtFile, wtFile);
hfo::HFOEnvironment hfo; hfo::HFOEnvironment hfo;
hfo::status_t status; hfo::status_t status;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment