You need to sign in or sign up before continuing.
Commit 08411cb1 authored by Siddharth Aravindan's avatar Siddharth Aravindan Committed by asiddharth

Refactored SARSA libraries, added hand coded agent

parent 384fb648
#include <vector>
#include <HFO.hpp>
#include <cstdlib>
#include <math.h>
#include <fstream>
using namespace std;
using namespace hfo;
/* Before running this program, first Start HFO server:
../bin/HFO --offense-npcs 2 --defense-agents 1 --defense-npcs 1
This is a hand coded defense agent, which can play a 2v2 game againt 2 offense npcs when paired up with a goal keeper
Server Connection Options. See printouts from bin/HFO.*/
feature_set_t features = HIGH_LEVEL_FEATURE_SET;
string config_dir = "../bin/teams/base/config/formations-dt";
int port = 7000;
string server_addr = "localhost";
string team_name = "base_right";
bool goalie = false;
double kickable_dist = 1.504052352;
double open_area_up_limit_x = 0.747311440447;
double open_area_up_limit_y = 0.229619544504;
double open_area_low_limit_x = -0.352161264597;
double open_area_low_limit_y = 0.140736680776;
double tackle_limit = 1.613456553;
double HALF_FIELD_WIDTH = 68 ; // y coordinate -34 to 34 (-34 = bottom 34 = top)
double HALF_FIELD_LENGTH = 52.5; // x coordinate 0 to 52.5 (0 = goalline 52.5 = center)
struct action_with_params {
action_t action;
double param;
};
// Returns a random high-level action
action_t get_random_high_lv_action() {
action_t action_indx = (action_t) ((rand() % 4) + REDUCE_ANGLE_TO_GOAL);
return action_indx;
}
double get_actual_angle( double normalized_angle) {
return normalized_angle * M_PI;
}
double get_dist_normalized (double ref_x, double ref_y, double src_x, double src_y) {
return sqrt(pow(ref_x - src_x,2) + pow((HALF_FIELD_WIDTH/HALF_FIELD_LENGTH)*(ref_y - src_y),2));
}
bool is_kickable(double ball_pos_x, double ball_pos_y, double opp_pos_x, double opp_pos_y) {
return get_dist_normalized(ball_pos_x, ball_pos_y, opp_pos_x, opp_pos_y) < kickable_dist; //#param
}
bool is_in_open_area(double pos_x, double pos_y) {
if (pos_x < open_area_up_limit_x ) {//&& pos_x > open_area_low_limit_x && pos_y < open_area_up_limit_y && pos_y > open_area_low_limit_y ) { //#param
return false;
} else {
return true;
}
}
action_with_params get_defense_action(const std::vector<float>& state_vec, double no_of_opponents, double numTMates) {
int size_of_vec = 10 + 6*numTMates + 3*no_of_opponents;
if (size_of_vec != state_vec.size()) {
std :: cout <<"Invalid Feature Vector / Check the number of teammates/opponents provided";
return {NOOP,0};
}
double agent_posx = state_vec[0];
double agent_posy = state_vec[1];
double agent_orientation = get_actual_angle(state_vec[2]);
double opp1_unum = state_vec[9+6*numTMates+3];
double opp2_unum = state_vec[9+(1*3)+6*numTMates+3];
double ball_pos_x = state_vec[3];
double ball_pos_y = state_vec[4];
double opp1_pos_x = state_vec[9+6*numTMates+1];
double opp1_pos_y = state_vec[9+6*numTMates+2];
double opp2_pos_x = state_vec[9+(1*3)+6*numTMates+3];
double opp2_pos_y = state_vec[9+(1*3)+6*numTMates+3];
double opp1_dist_to_ball = get_dist_normalized(ball_pos_x, ball_pos_y, opp1_pos_x, opp1_pos_y);
double opp1_dist_to_agent = get_dist_normalized(agent_posx, agent_posy, opp1_pos_x, opp1_pos_y);
bool is_kickable_opp1 = is_kickable(ball_pos_x, ball_pos_y,opp1_pos_x, opp1_pos_y);
bool is_in_open_area_opp1 = is_in_open_area(opp1_pos_x, opp1_pos_y);
double opp2_dist_to_ball = get_dist_normalized(ball_pos_x, ball_pos_y, opp2_pos_x, opp2_pos_y);
double opp2_dist_to_agent = get_dist_normalized(agent_posx, agent_posy, opp2_pos_x, opp2_pos_y);
bool is_kickable_opp2 = is_kickable(ball_pos_x, ball_pos_y,opp2_pos_x, opp2_pos_y);
bool is_in_open_area_opp2 = is_in_open_area(opp2_pos_x, opp2_pos_y);
double tackle_limit_nn = tackle_limit;
if (is_in_open_area(opp1_pos_x, opp1_pos_y) && is_in_open_area(opp2_pos_x, opp2_pos_y)) {
//std:: cout << "In open Area" << "\n";
if (is_kickable(ball_pos_x, ball_pos_y, opp1_pos_x, opp1_pos_y) &&
get_dist_normalized(ball_pos_x, ball_pos_y, opp1_pos_x, opp1_pos_y) <
get_dist_normalized(ball_pos_x, ball_pos_y, opp1_pos_x, opp1_pos_y)) {
return {MARK_PLAYER, opp2_unum};
// return {REDUCE_ANGLE_TO_GOAL, 1};
} else if (is_kickable(ball_pos_x, ball_pos_y, opp2_pos_x, opp2_pos_y)) {
return {MARK_PLAYER, opp1_unum};
// return {REDUCE_ANGLE_TO_GOAL, 1};
} else if (get_dist_normalized(ball_pos_x, ball_pos_y, opp1_pos_x, opp1_pos_y) >
get_dist_normalized(ball_pos_x, ball_pos_y, agent_posx, agent_posy) &&
get_dist_normalized(ball_pos_x, ball_pos_y, opp2_pos_x, opp2_pos_y) >
get_dist_normalized(ball_pos_x, ball_pos_y, agent_posx, agent_posy)) {
return {GO_TO_BALL,1};
} else {
return {REDUCE_ANGLE_TO_GOAL, 1};
}
} else {
//std:: cout << "In Penalty Area" << "\n";
if (! is_kickable(ball_pos_x,ball_pos_y,opp1_pos_x, opp1_pos_y) && ! is_kickable(ball_pos_x,ball_pos_y,opp2_pos_x,opp2_pos_y))
{
//std :: cout <<"IN AREA BUT GOTO\n";
if (get_dist_normalized(ball_pos_x, ball_pos_y, opp1_pos_x, opp1_pos_y) > get_dist_normalized(ball_pos_x, ball_pos_y, agent_posx, agent_posy) &&
get_dist_normalized(ball_pos_x, ball_pos_y, opp1_pos_x, opp1_pos_y) > get_dist_normalized(ball_pos_x, ball_pos_y, agent_posx, agent_posy)) {
return {GO_TO_BALL,2};
} else {
return {REDUCE_ANGLE_TO_GOAL, 0};
}
} else if ( get_dist_normalized (agent_posx, agent_posy, opp1_pos_x, opp1_pos_y) < tackle_limit
|| get_dist_normalized (agent_posx, agent_posy, opp2_pos_x, opp2_pos_y) < tackle_limit ) { //#param
/*double turn_angle;
//REVISIT TURN ANGLE.. CONSIDER ACTUAL DIRECTION ALSO..
if (get_dist_normalized(ball_pos_x, ball_pos_y, opp1_pos_x, opp1_pos_y) < get_dist_normalized(ball_pos_x, ball_pos_y, opp1_pos_x, opp1_pos_y)) {
turn_angle = atan2((HALF_FIELD_WIDTH/HALF_FIELD_LENGTH)*(opp1_pos_y), opp1_pos_x) - (agent_orientation*M_PI);
} else {
turn_angle = atan2((HALF_FIELD_WIDTH/HALF_FIELD_LENGTH)*(opp2_pos_y), opp2_pos_x) - (agent_orientation*M_PI);
}
turn_angle = atan2(tan(turn_angle),1);
if (turn_angle > M_PI ) {
turn_angle -= 2*M_PI;
} else if (turn_angle < -M_PI) {
turn_angle += 2*M_PI;
}*/
/*string s = "TACKLE " + std ::to_string (turn_angle) + "\n";
cout << s; */
//return {DASH, turn_angle*180/M_PI};
//return {TACKLE, turn_angle*180/M_PI}; //TACKLE needs power od dir
return {MOVE,0};
} else if ((!is_in_open_area(opp1_pos_x, opp1_pos_y) && is_in_open_area(opp2_pos_x, opp2_pos_y)) ||
(!is_in_open_area(opp2_pos_x, opp2_pos_y) && is_in_open_area(opp1_pos_x, opp1_pos_y))) {
return {REDUCE_ANGLE_TO_GOAL,0};
} else if (!is_in_open_area(opp1_pos_x, opp1_pos_y) && !is_in_open_area(opp2_pos_x, opp2_pos_y)) {
if (get_dist_normalized(ball_pos_x, ball_pos_y, opp1_pos_x, opp1_pos_y) < get_dist_normalized(ball_pos_x, ball_pos_y, opp1_pos_x, opp1_pos_y)) {
return {MARK_PLAYER,opp2_unum};
} else {
return {MARK_PLAYER,opp1_unum};
}
} else {
std :: cout <<"Unknown Condition";
return {NOOP,0};
}
}
}
void read_params() {
std::ifstream fin("params.txt");
double d[6];
int i = 0;
while (true) {
fin >> d[i];
if( fin.eof() ) break;
std::cout << d[i] << std::endl;
i++;
//if (i >=6 ) {
// std::cout << "invalid params" << d[5];
// exit(0);
//}
}
fin.close();
kickable_dist = (d[0]+1)* 0.818175061;
open_area_up_limit_x = d[1];
open_area_up_limit_y = d[2];
open_area_low_limit_x = d[3];
open_area_low_limit_y = d[4];
tackle_limit = (d[5]+1)* 0.818175061;
std :: cout << "kickable dist " << kickable_dist << "tackle_limit" <<tackle_limit;
return;
}
void write_cost(double cost) {
std::ofstream myfile ("cost.txt");
if (myfile.is_open()) {
myfile << cost;
myfile.close();
}
return;
}
int main(int argc, char** argv) {
// Create the HFO environment
//read_params();
HFOEnvironment hfo;
int random = 0;
double numGoals = 0;
double numEpisodes = 5000;
// Connect to the server and request high-level feature set. See
// manual for more information on feature sets.
hfo.connectToServer(features, config_dir, port, server_addr,
team_name, goalie);
for (int episode=0; episode<numEpisodes; episode++) {
status_t status = IN_GAME;
while (status == IN_GAME) {
// Get the vector of state features for the current state
const vector<float>& feature_vec = hfo.getState();
if (random == 0) {
action_with_params a = get_defense_action(feature_vec, 2,1);
// std::cout << a.action << a.param;
if (a.action == hfo :: MARK_PLAYER || a.action == hfo::TACKLE) {
hfo.act(a.action, a.param);
} else if (a.action == hfo :: DASH) {
double power = 100;
hfo.act(a.action, power, a.param);
} else {
hfo.act(a.action);
}
string s = hfo::ActionToString(a.action) + " " +to_string(a.param) + "\n";
// std::cout << s;
} else {
std::cout <<"Randm";
action_t a = get_random_high_lv_action();
if (a == hfo :: MARK_PLAYER) {
hfo.act(NOOP);
} else {
hfo.act(a);
}
}
//hfo.act(hfo::INTERCEPT);
status = hfo.step();
}
if (status==GOAL)
numGoals = numGoals+1;
// Check what the outcome of the episode was
cout << "Episode " << episode << " ended with status: "
<< StatusToString(status) << std::endl;
}
double cost = numGoals/numEpisodes;
hfo.act(QUIT);
//write_cost(cost);
};
# Run this file to create an executable of hand_coded_defense_agent.cpp
g++ -c hand_coded_defense_agent.cpp -I ../src/ -std=c++0x -pthread
g++ -L ../lib/ hand_coded_defense_agent.o -lhfo -pthread -o hand_coded_defense_agent -Wl,-rpath,../lib
#Directories
FA_DIR = ./funcapprox
POLICY_DIR = ./policy
HFO_SRC_DIR = ../../src
HFO_LIB_DIR = ../../lib
#Includes
INCLUDES = -I$(FA_DIR) -I$(POLICY_DIR) -I$(HFO_SRC_DIR)
#Libs
FA_LIB = funcapprox
POLICY_LIB = policyagent
#Flags
CXXFLAGS = -g -Wall -std=c++11 -pthread
LDFLAGS = -l$(FA_LIB) -l$(POLICY_LIB) -lhfo -pthread
LDLIBS = -L$(FA_DIR) -L$(POLICY_DIR) -L$(HFO_LIB_DIR)
LINKEROPTIONS = -Wl,-rpath,$(HFO_LIB_DIR)
#Compiler
CXX = g++
#Sources
SRC = high_level_sarsa_agent.cpp
#Objects
OBJ = $(SRC:.cpp=.o)
#Target
TARGET = high_level_sarsa_agent
#Rules
.PHONY: $(FA_LIB)
all: $(TARGET)
.cpp.o:
$(CXX) $(CXXFLAGS) $(INCLUDES) -c -o $@ $(@F:%.o=%.cpp)
$(FA_LIB):
$(MAKE) -C $(FA_DIR)
$(POLICY_LIB):
$(MAKE) -C $(POLICY_DIR)
$(TARGET): $(FA_LIB) $(POLICY_LIB) $(OBJ)
$(CXX) $(OBJ) $(CXXFLAGS) $(LDLIBS) $(LDFLAGS) -o $(TARGET) $(LINKEROPTIONS)
cleanfa:
$(MAKE) clean -C $(FA_DIR)
cleanpolicy:
$(MAKE) clean -C $(POLICY_DIR)
clean: cleanfa cleanpolicy
rm -f $(TARGET) $(OBJ) *~
#Flags
CXXFLAGS = -g -O3 -Wall
#Compiler
CXX = g++
#Sources
SRCS = FuncApprox.cpp tiles2.cpp CMAC.cpp
#Objects
OBJS = $(SRCS:.cpp=.o)
#Target
TARGET = libfuncapprox.a
#Rules
all: $(TARGET)
.cpp.o:
$(CXX) $(CXXFLAGS) -c -o $@ $(@F:%.o=%.cpp)
$(TARGET): $(OBJS)
ar cq $@ $(OBJS)
clean:
rm -f $(TARGET) $(OBJS) *~
#Directories
FA_DIR = ../funcapprox
#Includes
INCLUDES = -I$(FA_DIR)
#Flags
CXXFLAGS = -g -O3 -Wall
#Compiler
CXX = g++
#Sources
SRCS = PolicyAgent.cpp SarsaAgent.cpp
#Objects
OBJS = $(SRCS:.cpp=.o)
#Target
TARGET = libpolicyagent.a
#Rules
all: $(TARGET)
.cpp.o:
$(CXX) $(CXXFLAGS) $(INCLUDES) -c -o $@ $(@F:%.o=%.cpp)
$(TARGET): $(OBJS)
ar cq $@ $(OBJS)
clean:
rm -f $(TARGET) $(OBJS) *~
#include <iostream>
#include <vector>
#include <HFO.hpp>
#include <cstdlib>
#include <thread>
#include "SarsaAgent.h"
#include "CMAC.h"
#include <unistd.h>
// Before running this program, first Start HFO server:
// $./bin/HFO --offense-agents numAgents
void printUsage() {
std::cout<<"Usage:123 ./high_level_sarsa_agent [Options]"<<std::endl;
std::cout<<"Options:"<<std::endl;
std::cout<<" --numAgents <int> Number of SARSA agents"<<std::endl;
std::cout<<" Default: 0"<<std::endl;
std::cout<<" --numEpisodes <int> Number of episodes to run"<<std::endl;
std::cout<<" Default: 10"<<std::endl;
std::cout<<" --basePort <int> SARSA agent base port"<<std::endl;
std::cout<<" Default: 6000"<<std::endl;
std::cout<<" --learnRate <float> Learning rate of SARSA agents"<<std::endl;
std::cout<<" Range: [0.0, 1.0]"<<std::endl;
std::cout<<" Default: 0.1"<<std::endl;
std::cout<<" --suffix <int> Suffix for weights files"<<std::endl;
std::cout<<" Default: 0"<<std::endl;
std::cout<<" --noOpponent Sets opponent present flag to false"<<std::endl;
std::cout<<" --step Sets the persistent step size"<<std::endl;
std::cout<<" --eps Sets the exploration rate"<<std::endl;
std::cout<<" --numOpponents Sets the number of opponents"<<std::endl;
std::cout<<" --weightId Sets the given Id for weight File"<<std::endl;
std::cout<<" --help Displays this help and exit"<<std::endl;
}
// Returns the reward for SARSA based on current state
double getReward(hfo::status_t status) {
double reward;
if (status==hfo::GOAL) reward = -1;
else if (status==hfo::CAPTURED_BY_DEFENSE) reward = 1;
else if (status==hfo::OUT_OF_BOUNDS) reward = 1;
else reward = 0;
return reward;
}
// Fill state with only the required features from state_vec
void purgeFeatures(double *state, const std::vector<float>& state_vec,
int numTMates, int numOpponents, bool oppPres) {
int stateIndex = 0;
// If no opponents ignore features Distance to Opponent
// and Distance from Teammate i to Opponent are absent
int tmpIndex = oppPres ? (9 + 3 * numTMates) : (9 + 2 * numTMates);
for(int i = 0; i < state_vec.size(); i++) {
// Ignore first six featues
if(i == 5 || i==8) continue;
if(i>9 && i<= 9+numTMates) continue; // Ignore Goal Opening angles, as invalid
if(i<= 9+3*numTMates && i > 9+2*numTMates) continue; // Ignore Pass Opening angles, as invalid
// Ignore Uniform Number of Teammates and opponents
int temp = i-tmpIndex;
if(temp > 0 && (temp % 3 == 0) )continue;
//if (i > 9+6*numTMates) continue;
state[stateIndex] = state_vec[i];
stateIndex++;
}
}
// Convert int to hfo::Action
hfo::action_t toAction(int action, const std::vector<float>& state_vec) {
hfo::action_t a;
switch (action) {
case 0: a = hfo::MOVE; break;
case 1: a = hfo::REDUCE_ANGLE_TO_GOAL; break;
case 2: a = hfo::GO_TO_BALL; break;
case 3: a = hfo::NOOP; break;
case 4: a = hfo::DEFEND_GOAL; break;
default : a = hfo::MARK_PLAYER; break;
}
return a;
}
void offenseAgent(int port, int numTMates, int numOpponents, int numEpi, double learnR,
int suffix, bool oppPres, double eps, int step, std::string weightid) {
// Number of features
int numF = oppPres ? (8 + 3 * numTMates + 2*numOpponents) : (3 + 3 * numTMates);
// Number of actions
int numA = 5 + numOpponents; //DEF_GOAL+MOVE+GTB+NOOP+RATG+MP(unum)
// Other SARSA parameters
eps = 0.01;
double discFac = 1;
double lambda=0.9375;
// Tile coding parameter
double resolution = 0.1;
double range[numF];
double min[numF];
double res[numF];
for(int i = 0; i < numF; i++) {
min[i] = -1;
range[i] = 2;
res[i] = resolution;
}
// Weights file
char *wtFile;
std::string s = "weights_" + std::to_string(port) +
"_" + std::to_string(numTMates + 1) +
"_" + std::to_string(suffix) +
"_" + std::to_string(step) +
"_" + weightid;
wtFile = &s[0u];
CMAC *fa = new CMAC(numF, numA, range, min, res);
SarsaAgent *sa = new SarsaAgent(numF, numA, learnR, eps, lambda, fa, wtFile, wtFile);
hfo::HFOEnvironment hfo;
hfo::status_t status;
hfo::action_t a;
double state[numF];
int action = -1;
double reward;
int no_of_offense = numTMates + 1;
hfo.connectToServer(hfo::HIGH_LEVEL_FEATURE_SET,"../../bin/teams/base/config/formations-dt",port,"localhost","base_right",false,"");
for (int episode=0; episode < numEpi; episode++) {
int count = 0;
status = hfo::IN_GAME;
action = -1;
int count_steps = 0;
double unum = -1;
const std::vector<float>& state_vec = hfo.getState();
int num_steps_per_epi = 0;
while (status == hfo::IN_GAME) {
num_steps_per_epi++;
//std::cout << "::::::"<< hfo::ActionToString(a) <<" "<<count_steps <<std::endl;
if (count_steps != step && action >=0 && (a != hfo :: MARK_PLAYER || unum>0)) {
count_steps ++;
if (a == hfo::MARK_PLAYER) {
hfo.act(a,unum);
//std::cout << "MARKING" << unum <<"\n";
} else {
hfo.act(a);
}
status = hfo.step();
continue;
} else {
count_steps = 0;
}
if(action != -1) {
reward = getReward(status);
sa->update(state, action, reward, discFac);
}
// Fill up state array
purgeFeatures(state, state_vec, numTMates, numOpponents, oppPres);
// Get raw action
action = sa->selectAction(state);
// Get hfo::Action
a = toAction(action, state_vec);
if (a== hfo::MARK_PLAYER) {
unum = state_vec[(state_vec.size()-1 - (action-5)*3)];
hfo.act(a,unum);
} else {
hfo.act(a);
}
std::string s = std::to_string(action);
for (int state_vec_fc=0; state_vec_fc < state_vec.size(); state_vec_fc++) {
s+=std::to_string(state_vec[state_vec_fc]) + ",";
}
s+="UNUM" +std::to_string(unum) +"\n";;
status = hfo.step();
// std::cout <<s;
}
//std :: cout <<":::::::::::::" << num_steps_per_epi<< " "<<step << " "<<"\n";
// End of episode
if(action != -1) {
reward = getReward(status);
sa->update(state, action, reward, discFac);
sa->endEpisode();
}
}
delete sa;
delete fa;
}
int main(int argc, char **argv) {
int numAgents = 0;
int numEpisodes = 10;
int basePort = 6000;
double learnR = 0.1;
int suffix = 0;
bool opponentPresent = true;
int numOpponents = 0;
double eps = 0.01;
int step = 10;
std::string weightid;
for (int i = 0; i<argc; i++) {
std::string param = std::string(argv[i]);
std::cout<<param<<"\n";
}
for(int i = 1; i < argc; i++) {
std::string param = std::string(argv[i]);
if(param == "--numAgents") {
numAgents = atoi(argv[++i]);
}else if(param == "--numEpisodes") {
numEpisodes = atoi(argv[++i]);
}else if(param == "--basePort") {
basePort = atoi(argv[++i]);
}else if(param == "--learnRate") {
learnR = atof(argv[++i]);
if(learnR < 0 || learnR > 1) {
printUsage();
return 0;
}
}else if(param == "--suffix") {
suffix = atoi(argv[++i]);
}else if(param == "--noOpponent") {
opponentPresent = false;
}else if(param=="--eps"){
eps=atoi(argv[++i]);
}else if(param=="--numOpponents"){
numOpponents=atoi(argv[++i]);
}else if(param=="--step"){
step=atoi(argv[++i]);
}else if(param=="--weightId"){
weightid=std::string(argv[++i]);
}else {
printUsage();
return 0;
}
}
int numTeammates = numAgents; //using goalie npc
std::thread agentThreads[numAgents];
for (int agent = 0; agent < numAgents; agent++) {
agentThreads[agent] = std::thread(offenseAgent, basePort + agent,
numTeammates, numOpponents, numEpisodes, learnR,
suffix, opponentPresent, eps, step, weightid);
usleep(500000L);
}
for (int agent = 0; agent < numAgents; agent++) {
agentThreads[agent].join();
}
return 0;
}
#include "SarsaAgent.h" #include "SarsaAgent.h"
SarsaAgent::SarsaAgent(int numFeatures, int numActions, double learningRate, double epsilon, FunctionApproximator *FA, char *loadWeightsFile, char *saveWeightsFile):PolicyAgent(numFeatures, numActions, learningRate, epsilon, FA, loadWeightsFile, saveWeightsFile){ //add lambda as parameter to sarsaagent
SarsaAgent::SarsaAgent(int numFeatures, int numActions, double learningRate, double epsilon, double lambda, FunctionApproximator *FA, char *loadWeightsFile, char *saveWeightsFile):PolicyAgent(numFeatures, numActions, learningRate, epsilon, FA, loadWeightsFile, saveWeightsFile){
this->lambda = lambda;
episodeNumber = 0; episodeNumber = 0;
lastAction = -1; lastAction = -1;
//have memory for lambda
} }
void SarsaAgent::update(double state[], int action, double reward, double discountFactor){ void SarsaAgent::update(double state[], int action, double reward, double discountFactor){
...@@ -34,7 +35,7 @@ void SarsaAgent::update(double state[], int action, double reward, double discou ...@@ -34,7 +35,7 @@ void SarsaAgent::update(double state[], int action, double reward, double discou
FA->updateWeights(delta, learningRate); FA->updateWeights(delta, learningRate);
//Assume gamma, lambda are 0. //Assume gamma, lambda are 0.
FA->decayTraces(0); FA->decayTraces(discountFactor*lambda);//replace 0 with gamma*lambda
for(int i = 0; i < getNumFeatures(); i++){ for(int i = 0; i < getNumFeatures(); i++){
lastState[i] = state[i]; lastState[i] = state[i];
...@@ -59,8 +60,8 @@ void SarsaAgent::endEpisode(){ ...@@ -59,8 +60,8 @@ void SarsaAgent::endEpisode(){
double delta = lastReward - oldQ; double delta = lastReward - oldQ;
FA->updateWeights(delta, learningRate); FA->updateWeights(delta, learningRate);
//Assume lambda is 0. //Assume lambda is 0. this comment looks wrong.
FA->decayTraces(0); FA->decayTraces(0);//remains 0
} }
if(toSaveWeights && (episodeNumber + 1) % 5 == 0){ if(toSaveWeights && (episodeNumber + 1) % 5 == 0){
......
...@@ -12,10 +12,11 @@ class SarsaAgent:public PolicyAgent{ ...@@ -12,10 +12,11 @@ class SarsaAgent:public PolicyAgent{
double lastState[MAX_STATE_VARS]; double lastState[MAX_STATE_VARS];
int lastAction; int lastAction;
double lastReward; double lastReward;
double lambda;
public: public:
SarsaAgent(int numFeatures, int numActions, double learningRate, double epsilon, FunctionApproximator *FA, char *loadWeightsFile, char *saveWeightsFile); SarsaAgent(int numFeatures, int numActions, double learningRate, double epsilon, double lambda, FunctionApproximator *FA, char *loadWeightsFile, char *saveWeightsFile);
int argmaxQ(double state[]); int argmaxQ(double state[]);
double computeQ(double state[], int action); double computeQ(double state[], int action);
......
...@@ -56,7 +56,7 @@ void purgeFeatures(double *state, const std::vector<float>& state_vec, ...@@ -56,7 +56,7 @@ void purgeFeatures(double *state, const std::vector<float>& state_vec,
// Ignore Angle and Uniform Number of Teammates // Ignore Angle and Uniform Number of Teammates
int temp = i-tmpIndex; int temp = i-tmpIndex;
if(temp > 0 && (temp % 3 == 2 || temp % 3 == 0)) continue; if(temp > 0 && (temp % 3 == 2 || temp % 3 == 0)) continue;
if (i > 9+6*numTMates) continue;
state[stateIndex] = state_vec[i]; state[stateIndex] = state_vec[i];
stateIndex++; stateIndex++;
} }
...@@ -107,9 +107,9 @@ void offenseAgent(int port, int numTMates, int numEpi, double learnR, ...@@ -107,9 +107,9 @@ void offenseAgent(int port, int numTMates, int numEpi, double learnR,
"_" + std::to_string(numTMates + 1) + "_" + std::to_string(numTMates + 1) +
"_" + std::to_string(suffix); "_" + std::to_string(suffix);
wtFile = &s[0u]; wtFile = &s[0u];
double lambda = 0;
CMAC *fa = new CMAC(numF, numA, range, min, res); CMAC *fa = new CMAC(numF, numA, range, min, res);
SarsaAgent *sa = new SarsaAgent(numF, numA, learnR, eps, fa, wtFile, wtFile); SarsaAgent *sa = new SarsaAgent(numF, numA, learnR, eps, lambda, fa, wtFile, wtFile);
hfo::HFOEnvironment hfo; hfo::HFOEnvironment hfo;
hfo::status_t status; hfo::status_t status;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment