Commit 08411cb1 authored by Siddharth Aravindan's avatar Siddharth Aravindan Committed by asiddharth

Refactored SARSA libraries, added hand coded agent

parent 384fb648
#include <vector>
#include <HFO.hpp>
#include <cstdlib>
#include <math.h>
#include <fstream>
using namespace std;
using namespace hfo;
/* Before running this program, first Start HFO server:
../bin/HFO --offense-npcs 2 --defense-agents 1 --defense-npcs 1
This is a hand coded defense agent, which can play a 2v2 game againt 2 offense npcs when paired up with a goal keeper
Server Connection Options. See printouts from bin/HFO.*/
feature_set_t features = HIGH_LEVEL_FEATURE_SET;
string config_dir = "../bin/teams/base/config/formations-dt";
int port = 7000;
string server_addr = "localhost";
string team_name = "base_right";
bool goalie = false;
double kickable_dist = 1.504052352;
double open_area_up_limit_x = 0.747311440447;
double open_area_up_limit_y = 0.229619544504;
double open_area_low_limit_x = -0.352161264597;
double open_area_low_limit_y = 0.140736680776;
double tackle_limit = 1.613456553;
double HALF_FIELD_WIDTH = 68 ; // y coordinate -34 to 34 (-34 = bottom 34 = top)
double HALF_FIELD_LENGTH = 52.5; // x coordinate 0 to 52.5 (0 = goalline 52.5 = center)
struct action_with_params {
action_t action;
double param;
};
// Returns a random high-level action
action_t get_random_high_lv_action() {
action_t action_indx = (action_t) ((rand() % 4) + REDUCE_ANGLE_TO_GOAL);
return action_indx;
}
double get_actual_angle( double normalized_angle) {
return normalized_angle * M_PI;
}
double get_dist_normalized (double ref_x, double ref_y, double src_x, double src_y) {
return sqrt(pow(ref_x - src_x,2) + pow((HALF_FIELD_WIDTH/HALF_FIELD_LENGTH)*(ref_y - src_y),2));
}
bool is_kickable(double ball_pos_x, double ball_pos_y, double opp_pos_x, double opp_pos_y) {
return get_dist_normalized(ball_pos_x, ball_pos_y, opp_pos_x, opp_pos_y) < kickable_dist; //#param
}
bool is_in_open_area(double pos_x, double pos_y) {
if (pos_x < open_area_up_limit_x ) {//&& pos_x > open_area_low_limit_x && pos_y < open_area_up_limit_y && pos_y > open_area_low_limit_y ) { //#param
return false;
} else {
return true;
}
}
action_with_params get_defense_action(const std::vector<float>& state_vec, double no_of_opponents, double numTMates) {
int size_of_vec = 10 + 6*numTMates + 3*no_of_opponents;
if (size_of_vec != state_vec.size()) {
std :: cout <<"Invalid Feature Vector / Check the number of teammates/opponents provided";
return {NOOP,0};
}
double agent_posx = state_vec[0];
double agent_posy = state_vec[1];
double agent_orientation = get_actual_angle(state_vec[2]);
double opp1_unum = state_vec[9+6*numTMates+3];
double opp2_unum = state_vec[9+(1*3)+6*numTMates+3];
double ball_pos_x = state_vec[3];
double ball_pos_y = state_vec[4];
double opp1_pos_x = state_vec[9+6*numTMates+1];
double opp1_pos_y = state_vec[9+6*numTMates+2];
double opp2_pos_x = state_vec[9+(1*3)+6*numTMates+3];
double opp2_pos_y = state_vec[9+(1*3)+6*numTMates+3];
double opp1_dist_to_ball = get_dist_normalized(ball_pos_x, ball_pos_y, opp1_pos_x, opp1_pos_y);
double opp1_dist_to_agent = get_dist_normalized(agent_posx, agent_posy, opp1_pos_x, opp1_pos_y);
bool is_kickable_opp1 = is_kickable(ball_pos_x, ball_pos_y,opp1_pos_x, opp1_pos_y);
bool is_in_open_area_opp1 = is_in_open_area(opp1_pos_x, opp1_pos_y);
double opp2_dist_to_ball = get_dist_normalized(ball_pos_x, ball_pos_y, opp2_pos_x, opp2_pos_y);
double opp2_dist_to_agent = get_dist_normalized(agent_posx, agent_posy, opp2_pos_x, opp2_pos_y);
bool is_kickable_opp2 = is_kickable(ball_pos_x, ball_pos_y,opp2_pos_x, opp2_pos_y);
bool is_in_open_area_opp2 = is_in_open_area(opp2_pos_x, opp2_pos_y);
double tackle_limit_nn = tackle_limit;
if (is_in_open_area(opp1_pos_x, opp1_pos_y) && is_in_open_area(opp2_pos_x, opp2_pos_y)) {
//std:: cout << "In open Area" << "\n";
if (is_kickable(ball_pos_x, ball_pos_y, opp1_pos_x, opp1_pos_y) &&
get_dist_normalized(ball_pos_x, ball_pos_y, opp1_pos_x, opp1_pos_y) <
get_dist_normalized(ball_pos_x, ball_pos_y, opp1_pos_x, opp1_pos_y)) {
return {MARK_PLAYER, opp2_unum};
// return {REDUCE_ANGLE_TO_GOAL, 1};
} else if (is_kickable(ball_pos_x, ball_pos_y, opp2_pos_x, opp2_pos_y)) {
return {MARK_PLAYER, opp1_unum};
// return {REDUCE_ANGLE_TO_GOAL, 1};
} else if (get_dist_normalized(ball_pos_x, ball_pos_y, opp1_pos_x, opp1_pos_y) >
get_dist_normalized(ball_pos_x, ball_pos_y, agent_posx, agent_posy) &&
get_dist_normalized(ball_pos_x, ball_pos_y, opp2_pos_x, opp2_pos_y) >
get_dist_normalized(ball_pos_x, ball_pos_y, agent_posx, agent_posy)) {
return {GO_TO_BALL,1};
} else {
return {REDUCE_ANGLE_TO_GOAL, 1};
}
} else {
//std:: cout << "In Penalty Area" << "\n";
if (! is_kickable(ball_pos_x,ball_pos_y,opp1_pos_x, opp1_pos_y) && ! is_kickable(ball_pos_x,ball_pos_y,opp2_pos_x,opp2_pos_y))
{
//std :: cout <<"IN AREA BUT GOTO\n";
if (get_dist_normalized(ball_pos_x, ball_pos_y, opp1_pos_x, opp1_pos_y) > get_dist_normalized(ball_pos_x, ball_pos_y, agent_posx, agent_posy) &&
get_dist_normalized(ball_pos_x, ball_pos_y, opp1_pos_x, opp1_pos_y) > get_dist_normalized(ball_pos_x, ball_pos_y, agent_posx, agent_posy)) {
return {GO_TO_BALL,2};
} else {
return {REDUCE_ANGLE_TO_GOAL, 0};
}
} else if ( get_dist_normalized (agent_posx, agent_posy, opp1_pos_x, opp1_pos_y) < tackle_limit
|| get_dist_normalized (agent_posx, agent_posy, opp2_pos_x, opp2_pos_y) < tackle_limit ) { //#param
/*double turn_angle;
//REVISIT TURN ANGLE.. CONSIDER ACTUAL DIRECTION ALSO..
if (get_dist_normalized(ball_pos_x, ball_pos_y, opp1_pos_x, opp1_pos_y) < get_dist_normalized(ball_pos_x, ball_pos_y, opp1_pos_x, opp1_pos_y)) {
turn_angle = atan2((HALF_FIELD_WIDTH/HALF_FIELD_LENGTH)*(opp1_pos_y), opp1_pos_x) - (agent_orientation*M_PI);
} else {
turn_angle = atan2((HALF_FIELD_WIDTH/HALF_FIELD_LENGTH)*(opp2_pos_y), opp2_pos_x) - (agent_orientation*M_PI);
}
turn_angle = atan2(tan(turn_angle),1);
if (turn_angle > M_PI ) {
turn_angle -= 2*M_PI;
} else if (turn_angle < -M_PI) {
turn_angle += 2*M_PI;
}*/
/*string s = "TACKLE " + std ::to_string (turn_angle) + "\n";
cout << s; */
//return {DASH, turn_angle*180/M_PI};
//return {TACKLE, turn_angle*180/M_PI}; //TACKLE needs power od dir
return {MOVE,0};
} else if ((!is_in_open_area(opp1_pos_x, opp1_pos_y) && is_in_open_area(opp2_pos_x, opp2_pos_y)) ||
(!is_in_open_area(opp2_pos_x, opp2_pos_y) && is_in_open_area(opp1_pos_x, opp1_pos_y))) {
return {REDUCE_ANGLE_TO_GOAL,0};
} else if (!is_in_open_area(opp1_pos_x, opp1_pos_y) && !is_in_open_area(opp2_pos_x, opp2_pos_y)) {
if (get_dist_normalized(ball_pos_x, ball_pos_y, opp1_pos_x, opp1_pos_y) < get_dist_normalized(ball_pos_x, ball_pos_y, opp1_pos_x, opp1_pos_y)) {
return {MARK_PLAYER,opp2_unum};
} else {
return {MARK_PLAYER,opp1_unum};
}
} else {
std :: cout <<"Unknown Condition";
return {NOOP,0};
}
}
}
void read_params() {
std::ifstream fin("params.txt");
double d[6];
int i = 0;
while (true) {
fin >> d[i];
if( fin.eof() ) break;
std::cout << d[i] << std::endl;
i++;
//if (i >=6 ) {
// std::cout << "invalid params" << d[5];
// exit(0);
//}
}
fin.close();
kickable_dist = (d[0]+1)* 0.818175061;
open_area_up_limit_x = d[1];
open_area_up_limit_y = d[2];
open_area_low_limit_x = d[3];
open_area_low_limit_y = d[4];
tackle_limit = (d[5]+1)* 0.818175061;
std :: cout << "kickable dist " << kickable_dist << "tackle_limit" <<tackle_limit;
return;
}
void write_cost(double cost) {
std::ofstream myfile ("cost.txt");
if (myfile.is_open()) {
myfile << cost;
myfile.close();
}
return;
}
int main(int argc, char** argv) {
// Create the HFO environment
//read_params();
HFOEnvironment hfo;
int random = 0;
double numGoals = 0;
double numEpisodes = 5000;
// Connect to the server and request high-level feature set. See
// manual for more information on feature sets.
hfo.connectToServer(features, config_dir, port, server_addr,
team_name, goalie);
for (int episode=0; episode<numEpisodes; episode++) {
status_t status = IN_GAME;
while (status == IN_GAME) {
// Get the vector of state features for the current state
const vector<float>& feature_vec = hfo.getState();
if (random == 0) {
action_with_params a = get_defense_action(feature_vec, 2,1);
// std::cout << a.action << a.param;
if (a.action == hfo :: MARK_PLAYER || a.action == hfo::TACKLE) {
hfo.act(a.action, a.param);
} else if (a.action == hfo :: DASH) {
double power = 100;
hfo.act(a.action, power, a.param);
} else {
hfo.act(a.action);
}
string s = hfo::ActionToString(a.action) + " " +to_string(a.param) + "\n";
// std::cout << s;
} else {
std::cout <<"Randm";
action_t a = get_random_high_lv_action();
if (a == hfo :: MARK_PLAYER) {
hfo.act(NOOP);
} else {
hfo.act(a);
}
}
//hfo.act(hfo::INTERCEPT);
status = hfo.step();
}
if (status==GOAL)
numGoals = numGoals+1;
// Check what the outcome of the episode was
cout << "Episode " << episode << " ended with status: "
<< StatusToString(status) << std::endl;
}
double cost = numGoals/numEpisodes;
hfo.act(QUIT);
//write_cost(cost);
};
# Run this file to create an executable of hand_coded_defense_agent.cpp
g++ -c hand_coded_defense_agent.cpp -I ../src/ -std=c++0x -pthread
g++ -L ../lib/ hand_coded_defense_agent.o -lhfo -pthread -o hand_coded_defense_agent -Wl,-rpath,../lib
#Directories
FA_DIR = ./funcapprox
POLICY_DIR = ./policy
HFO_SRC_DIR = ../../src
HFO_LIB_DIR = ../../lib
#Includes
INCLUDES = -I$(FA_DIR) -I$(POLICY_DIR) -I$(HFO_SRC_DIR)
#Libs
FA_LIB = funcapprox
POLICY_LIB = policyagent
#Flags
CXXFLAGS = -g -Wall -std=c++11 -pthread
LDFLAGS = -l$(FA_LIB) -l$(POLICY_LIB) -lhfo -pthread
LDLIBS = -L$(FA_DIR) -L$(POLICY_DIR) -L$(HFO_LIB_DIR)
LINKEROPTIONS = -Wl,-rpath,$(HFO_LIB_DIR)
#Compiler
CXX = g++
#Sources
SRC = high_level_sarsa_agent.cpp
#Objects
OBJ = $(SRC:.cpp=.o)
#Target
TARGET = high_level_sarsa_agent
#Rules
.PHONY: $(FA_LIB)
all: $(TARGET)
.cpp.o:
$(CXX) $(CXXFLAGS) $(INCLUDES) -c -o $@ $(@F:%.o=%.cpp)
$(FA_LIB):
$(MAKE) -C $(FA_DIR)
$(POLICY_LIB):
$(MAKE) -C $(POLICY_DIR)
$(TARGET): $(FA_LIB) $(POLICY_LIB) $(OBJ)
$(CXX) $(OBJ) $(CXXFLAGS) $(LDLIBS) $(LDFLAGS) -o $(TARGET) $(LINKEROPTIONS)
cleanfa:
$(MAKE) clean -C $(FA_DIR)
cleanpolicy:
$(MAKE) clean -C $(POLICY_DIR)
clean: cleanfa cleanpolicy
rm -f $(TARGET) $(OBJ) *~
#Flags
CXXFLAGS = -g -O3 -Wall
#Compiler
CXX = g++
#Sources
SRCS = FuncApprox.cpp tiles2.cpp CMAC.cpp
#Objects
OBJS = $(SRCS:.cpp=.o)
#Target
TARGET = libfuncapprox.a
#Rules
all: $(TARGET)
.cpp.o:
$(CXX) $(CXXFLAGS) -c -o $@ $(@F:%.o=%.cpp)
$(TARGET): $(OBJS)
ar cq $@ $(OBJS)
clean:
rm -f $(TARGET) $(OBJS) *~
#Directories
FA_DIR = ../funcapprox
#Includes
INCLUDES = -I$(FA_DIR)
#Flags
CXXFLAGS = -g -O3 -Wall
#Compiler
CXX = g++
#Sources
SRCS = PolicyAgent.cpp SarsaAgent.cpp
#Objects
OBJS = $(SRCS:.cpp=.o)
#Target
TARGET = libpolicyagent.a
#Rules
all: $(TARGET)
.cpp.o:
$(CXX) $(CXXFLAGS) $(INCLUDES) -c -o $@ $(@F:%.o=%.cpp)
$(TARGET): $(OBJS)
ar cq $@ $(OBJS)
clean:
rm -f $(TARGET) $(OBJS) *~
#include <iostream>
#include <vector>
#include <HFO.hpp>
#include <cstdlib>
#include <thread>
#include "SarsaAgent.h"
#include "CMAC.h"
#include <unistd.h>
// Before running this program, first Start HFO server:
// $./bin/HFO --offense-agents numAgents
void printUsage() {
std::cout<<"Usage:123 ./high_level_sarsa_agent [Options]"<<std::endl;
std::cout<<"Options:"<<std::endl;
std::cout<<" --numAgents <int> Number of SARSA agents"<<std::endl;
std::cout<<" Default: 0"<<std::endl;
std::cout<<" --numEpisodes <int> Number of episodes to run"<<std::endl;
std::cout<<" Default: 10"<<std::endl;
std::cout<<" --basePort <int> SARSA agent base port"<<std::endl;
std::cout<<" Default: 6000"<<std::endl;
std::cout<<" --learnRate <float> Learning rate of SARSA agents"<<std::endl;
std::cout<<" Range: [0.0, 1.0]"<<std::endl;
std::cout<<" Default: 0.1"<<std::endl;
std::cout<<" --suffix <int> Suffix for weights files"<<std::endl;
std::cout<<" Default: 0"<<std::endl;
std::cout<<" --noOpponent Sets opponent present flag to false"<<std::endl;
std::cout<<" --step Sets the persistent step size"<<std::endl;
std::cout<<" --eps Sets the exploration rate"<<std::endl;
std::cout<<" --numOpponents Sets the number of opponents"<<std::endl;
std::cout<<" --weightId Sets the given Id for weight File"<<std::endl;
std::cout<<" --help Displays this help and exit"<<std::endl;
}
// Returns the reward for SARSA based on current state
double getReward(hfo::status_t status) {
double reward;
if (status==hfo::GOAL) reward = -1;
else if (status==hfo::CAPTURED_BY_DEFENSE) reward = 1;
else if (status==hfo::OUT_OF_BOUNDS) reward = 1;
else reward = 0;
return reward;
}
// Fill state with only the required features from state_vec
void purgeFeatures(double *state, const std::vector<float>& state_vec,
int numTMates, int numOpponents, bool oppPres) {
int stateIndex = 0;
// If no opponents ignore features Distance to Opponent
// and Distance from Teammate i to Opponent are absent
int tmpIndex = oppPres ? (9 + 3 * numTMates) : (9 + 2 * numTMates);
for(int i = 0; i < state_vec.size(); i++) {
// Ignore first six featues
if(i == 5 || i==8) continue;
if(i>9 && i<= 9+numTMates) continue; // Ignore Goal Opening angles, as invalid
if(i<= 9+3*numTMates && i > 9+2*numTMates) continue; // Ignore Pass Opening angles, as invalid
// Ignore Uniform Number of Teammates and opponents
int temp = i-tmpIndex;
if(temp > 0 && (temp % 3 == 0) )continue;
//if (i > 9+6*numTMates) continue;
state[stateIndex] = state_vec[i];
stateIndex++;
}
}
// Convert int to hfo::Action
hfo::action_t toAction(int action, const std::vector<float>& state_vec) {
hfo::action_t a;
switch (action) {
case 0: a = hfo::MOVE; break;
case 1: a = hfo::REDUCE_ANGLE_TO_GOAL; break;
case 2: a = hfo::GO_TO_BALL; break;
case 3: a = hfo::NOOP; break;
case 4: a = hfo::DEFEND_GOAL; break;
default : a = hfo::MARK_PLAYER; break;
}
return a;
}
void offenseAgent(int port, int numTMates, int numOpponents, int numEpi, double learnR,
int suffix, bool oppPres, double eps, int step, std::string weightid) {
// Number of features
int numF = oppPres ? (8 + 3 * numTMates + 2*numOpponents) : (3 + 3 * numTMates);
// Number of actions
int numA = 5 + numOpponents; //DEF_GOAL+MOVE+GTB+NOOP+RATG+MP(unum)
// Other SARSA parameters
eps = 0.01;
double discFac = 1;
double lambda=0.9375;
// Tile coding parameter
double resolution = 0.1;
double range[numF];
double min[numF];
double res[numF];
for(int i = 0; i < numF; i++) {
min[i] = -1;
range[i] = 2;
res[i] = resolution;
}
// Weights file
char *wtFile;
std::string s = "weights_" + std::to_string(port) +
"_" + std::to_string(numTMates + 1) +
"_" + std::to_string(suffix) +
"_" + std::to_string(step) +
"_" + weightid;
wtFile = &s[0u];
CMAC *fa = new CMAC(numF, numA, range, min, res);
SarsaAgent *sa = new SarsaAgent(numF, numA, learnR, eps, lambda, fa, wtFile, wtFile);
hfo::HFOEnvironment hfo;
hfo::status_t status;
hfo::action_t a;
double state[numF];
int action = -1;
double reward;
int no_of_offense = numTMates + 1;
hfo.connectToServer(hfo::HIGH_LEVEL_FEATURE_SET,"../../bin/teams/base/config/formations-dt",port,"localhost","base_right",false,"");
for (int episode=0; episode < numEpi; episode++) {
int count = 0;
status = hfo::IN_GAME;
action = -1;
int count_steps = 0;
double unum = -1;
const std::vector<float>& state_vec = hfo.getState();
int num_steps_per_epi = 0;
while (status == hfo::IN_GAME) {
num_steps_per_epi++;
//std::cout << "::::::"<< hfo::ActionToString(a) <<" "<<count_steps <<std::endl;
if (count_steps != step && action >=0 && (a != hfo :: MARK_PLAYER || unum>0)) {
count_steps ++;
if (a == hfo::MARK_PLAYER) {
hfo.act(a,unum);
//std::cout << "MARKING" << unum <<"\n";
} else {
hfo.act(a);
}
status = hfo.step();
continue;
} else {
count_steps = 0;
}
if(action != -1) {
reward = getReward(status);
sa->update(state, action, reward, discFac);
}
// Fill up state array
purgeFeatures(state, state_vec, numTMates, numOpponents, oppPres);
// Get raw action
action = sa->selectAction(state);
// Get hfo::Action
a = toAction(action, state_vec);
if (a== hfo::MARK_PLAYER) {
unum = state_vec[(state_vec.size()-1 - (action-5)*3)];
hfo.act(a,unum);
} else {
hfo.act(a);
}
std::string s = std::to_string(action);
for (int state_vec_fc=0; state_vec_fc < state_vec.size(); state_vec_fc++) {
s+=std::to_string(state_vec[state_vec_fc]) + ",";
}
s+="UNUM" +std::to_string(unum) +"\n";;
status = hfo.step();
// std::cout <<s;
}
//std :: cout <<":::::::::::::" << num_steps_per_epi<< " "<<step << " "<<"\n";
// End of episode
if(action != -1) {
reward = getReward(status);
sa->update(state, action, reward, discFac);
sa->endEpisode();
}
}
delete sa;
delete fa;
}
int main(int argc, char **argv) {
int numAgents = 0;
int numEpisodes = 10;
int basePort = 6000;
double learnR = 0.1;
int suffix = 0;
bool opponentPresent = true;
int numOpponents = 0;
double eps = 0.01;
int step = 10;
std::string weightid;
for (int i = 0; i<argc; i++) {
std::string param = std::string(argv[i]);
std::cout<<param<<"\n";
}
for(int i = 1; i < argc; i++) {
std::string param = std::string(argv[i]);
if(param == "--numAgents") {
numAgents = atoi(argv[++i]);
}else if(param == "--numEpisodes") {
numEpisodes = atoi(argv[++i]);
}else if(param == "--basePort") {
basePort = atoi(argv[++i]);
}else if(param == "--learnRate") {
learnR = atof(argv[++i]);
if(learnR < 0 || learnR > 1) {
printUsage();
return 0;
}
}else if(param == "--suffix") {
suffix = atoi(argv[++i]);
}else if(param == "--noOpponent") {
opponentPresent = false;
}else if(param=="--eps"){
eps=atoi(argv[++i]);
}else if(param=="--numOpponents"){
numOpponents=atoi(argv[++i]);
}else if(param=="--step"){
step=atoi(argv[++i]);
}else if(param=="--weightId"){
weightid=std::string(argv[++i]);
}else {
printUsage();
return 0;
}
}
int numTeammates = numAgents; //using goalie npc
std::thread agentThreads[numAgents];
for (int agent = 0; agent < numAgents; agent++) {
agentThreads[agent] = std::thread(offenseAgent, basePort + agent,
numTeammates, numOpponents, numEpisodes, learnR,
suffix, opponentPresent, eps, step, weightid);
usleep(500000L);
}
for (int agent = 0; agent < numAgents; agent++) {
agentThreads[agent].join();
}
return 0;
}
#include "SarsaAgent.h" #include "SarsaAgent.h"
SarsaAgent::SarsaAgent(int numFeatures, int numActions, double learningRate, double epsilon, FunctionApproximator *FA, char *loadWeightsFile, char *saveWeightsFile):PolicyAgent(numFeatures, numActions, learningRate, epsilon, FA, loadWeightsFile, saveWeightsFile){ //add lambda as parameter to sarsaagent
SarsaAgent::SarsaAgent(int numFeatures, int numActions, double learningRate, double epsilon, double lambda, FunctionApproximator *FA, char *loadWeightsFile, char *saveWeightsFile):PolicyAgent(numFeatures, numActions, learningRate, epsilon, FA, loadWeightsFile, saveWeightsFile){
this->lambda = lambda;
episodeNumber = 0; episodeNumber = 0;
lastAction = -1; lastAction = -1;
//have memory for lambda
} }
void SarsaAgent::update(double state[], int action, double reward, double discountFactor){ void SarsaAgent::update(double state[], int action, double reward, double discountFactor){
...@@ -34,7 +35,7 @@ void SarsaAgent::update(double state[], int action, double reward, double discou ...@@ -34,7 +35,7 @@ void SarsaAgent::update(double state[], int action, double reward, double discou
FA->updateWeights(delta, learningRate); FA->updateWeights(delta, learningRate);
//Assume gamma, lambda are 0. //Assume gamma, lambda are 0.
FA->decayTraces(0); FA->decayTraces(discountFactor*lambda);//replace 0 with gamma*lambda
for(int i = 0; i < getNumFeatures(); i++){ for(int i = 0; i < getNumFeatures(); i++){
lastState[i] = state[i]; lastState[i] = state[i];
...@@ -59,8 +60,8 @@ void SarsaAgent::endEpisode(){ ...@@ -59,8 +60,8 @@ void SarsaAgent::endEpisode(){
double delta = lastReward - oldQ; double delta = lastReward - oldQ;
FA->updateWeights(delta, learningRate); FA->updateWeights(delta, learningRate);
//Assume lambda is 0. //Assume lambda is 0. this comment looks wrong.
FA->decayTraces(0); FA->decayTraces(0);//remains 0
} }
if(toSaveWeights && (episodeNumber + 1) % 5 == 0){ if(toSaveWeights && (episodeNumber + 1) % 5 == 0){
......
...@@ -12,10 +12,11 @@ class SarsaAgent:public PolicyAgent{ ...@@ -12,10 +12,11 @@ class SarsaAgent:public PolicyAgent{
double lastState[MAX_STATE_VARS]; double lastState[MAX_STATE_VARS];
int lastAction; int lastAction;
double lastReward; double lastReward;
double lambda;
public: public:
SarsaAgent(int numFeatures, int numActions, double learningRate, double epsilon, FunctionApproximator *FA, char *loadWeightsFile, char *saveWeightsFile); SarsaAgent(int numFeatures, int numActions, double learningRate, double epsilon, double lambda, FunctionApproximator *FA, char *loadWeightsFile, char *saveWeightsFile);
int argmaxQ(double state[]); int argmaxQ(double state[]);
double computeQ(double state[], int action); double computeQ(double state[], int action);
......
...@@ -56,7 +56,7 @@ void purgeFeatures(double *state, const std::vector<float>& state_vec, ...@@ -56,7 +56,7 @@ void purgeFeatures(double *state, const std::vector<float>& state_vec,
// Ignore Angle and Uniform Number of Teammates // Ignore Angle and Uniform Number of Teammates
int temp = i-tmpIndex; int temp = i-tmpIndex;
if(temp > 0 && (temp % 3 == 2 || temp % 3 == 0)) continue; if(temp > 0 && (temp % 3 == 2 || temp % 3 == 0)) continue;
if (i > 9+6*numTMates) continue;
state[stateIndex] = state_vec[i]; state[stateIndex] = state_vec[i];
stateIndex++; stateIndex++;
} }
...@@ -107,9 +107,9 @@ void offenseAgent(int port, int numTMates, int numEpi, double learnR, ...@@ -107,9 +107,9 @@ void offenseAgent(int port, int numTMates, int numEpi, double learnR,
"_" + std::to_string(numTMates + 1) + "_" + std::to_string(numTMates + 1) +
"_" + std::to_string(suffix); "_" + std::to_string(suffix);
wtFile = &s[0u]; wtFile = &s[0u];
double lambda = 0;
CMAC *fa = new CMAC(numF, numA, range, min, res); CMAC *fa = new CMAC(numF, numA, range, min, res);
SarsaAgent *sa = new SarsaAgent(numF, numA, learnR, eps, fa, wtFile, wtFile); SarsaAgent *sa = new SarsaAgent(numF, numA, learnR, eps, lambda, fa, wtFile, wtFile);
hfo::HFOEnvironment hfo; hfo::HFOEnvironment hfo;
hfo::status_t status; hfo::status_t status;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment