Commit b04603ca authored by Shashank Suhas

Working mod branch

parent 14de35c6
#include "PolicyAgent.h" #include "PolicyAgent.h"
PolicyAgent::PolicyAgent(int numFeatures, int numActions, double learningRate, double epsilon, FunctionApproximator *FA, char *loadWeightsFile, char *saveWeightsFile){ // PolicyAgent::PolicyAgent(int numFeatures, int numActions, double learningRate, double epsilon, FunctionApproximator *FA, char *loadWeightsFile, char *saveWeightsFile){
PolicyAgent::PolicyAgent(int numFeatures, int numActions, double learningRate, double epsilon, FunctionApproximator *FA1, FunctionApproximator *FA2, char *loadWeightsFile, char *saveWeightsFile){
this->numFeatures = numFeatures; this->numFeatures = numFeatures;
this->numActions = numActions; this->numActions = numActions;
this->learningRate = learningRate; this->learningRate = learningRate;
this->epsilon = epsilon; this->epsilon = epsilon;
this->FA = FA; this->FA1 = FA1;
this->FA2 = FA2;
this->FA = FA1;
toLoadWeights = strlen(loadWeightsFile) > 0; toLoadWeights = strlen(loadWeightsFile) > 0;
if(toLoadWeights){ if(toLoadWeights){
strcpy(this->loadWeightsFile, loadWeightsFile); strcpy(this->loadWeightsFile, loadWeightsFile);
loadWeights(loadWeightsFile); // loadWeights(loadWeightsFile);
} }
toSaveWeights = strlen(saveWeightsFile) > 0; toSaveWeights = strlen(saveWeightsFile) > 0;
if(toSaveWeights){ if(toSaveWeights){
strcpy(this->saveWeightsFile, saveWeightsFile); strcpy(this->saveWeightsFile, saveWeightsFile);
} }
} }
PolicyAgent::~PolicyAgent(){ PolicyAgent::~PolicyAgent(){
} }
void PolicyAgent::switchToSecondFA(){
this->FA = this->FA2;
this->idCurrFA = 1;
}
void PolicyAgent::switchToFirstFA(){
this->FA = this->FA1;
this->idCurrFA = 0;
}
int PolicyAgent::getNumFeatures(){ int PolicyAgent::getNumFeatures(){
return numFeatures; return numFeatures;
} }
...@@ -34,12 +46,14 @@ int PolicyAgent::getNumActions(){ ...@@ -34,12 +46,14 @@ int PolicyAgent::getNumActions(){
} }
void PolicyAgent::loadWeights(char *fileName){ void PolicyAgent::loadWeights(char *fileName){
std::cout << "Loading Weights from " << fileName << std::endl; // std::cout << "Loading Weights from " << fileName << std::endl;
FA->read(fileName); // std::cout << "Doing nothing. Check PolicyAgent.cpp." << std::endl;
// FA->read(fileName);
} }
void PolicyAgent::saveWeights(char *fileName){ void PolicyAgent::saveWeights(char *fileName){
FA->write(fileName); // std::cout<< "Doing nothing. Check PolicyAgent.cpp." << std::endl;
// FA->write(fileName);
} }
int PolicyAgent::argmaxQ(double state[]){ int PolicyAgent::argmaxQ(double state[]){
......
@@ -29,14 +29,24 @@ class PolicyAgent{
   FunctionApproximator *FA;
+  FunctionApproximator *FA1;
+  FunctionApproximator *FA2;
   int getNumFeatures();
   int getNumActions();
+  int idCurrFA = 0;
 public:
-  PolicyAgent(int numFeatures, int numActions, double learningRate, double epsilon, FunctionApproximator *FA, char *loadWeightsFile, char *saveWeightsFile);
+  // PolicyAgent(int numFeatures, int numActions, double learningRate, double epsilon, FunctionApproximator *FA, char *loadWeightsFile, char *saveWeightsFile);
+  PolicyAgent(int numFeatures, int numActions, double learningRate, double epsilon, FunctionApproximator *FA1, FunctionApproximator *FA2, char *loadWeightsFile, char *saveWeightsFile);
   ~PolicyAgent();
+  void switchToFirstFA();
+  void switchToSecondFA();
   virtual int argmaxQ(double state[]);
   virtual double computeQ(double state[], int action);
...
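The changes above thread two function approximators through PolicyAgent and keep a third pointer, FA, as the currently active one; switchToFirstFA and switchToSecondFA simply retarget that pointer and record which approximator is live in idCurrFA. A minimal standalone sketch of this active-pointer pattern (the Approximator and TwoFAAgent names below are illustrative stand-ins, not types from this repository):

```cpp
#include <iostream>

// Illustrative stand-in for FunctionApproximator/CMAC.
struct Approximator {
  const char *name;
};

class TwoFAAgent {
  Approximator *fa1, *fa2;
  Approximator *fa;   // active approximator, mirrors PolicyAgent::FA
  int idCurrFA = 0;   // 0 = first, 1 = second, mirrors PolicyAgent::idCurrFA
public:
  TwoFAAgent(Approximator *a, Approximator *b) : fa1(a), fa2(b), fa(a) {}
  void switchToFirstFA()  { fa = fa1; idCurrFA = 0; }
  void switchToSecondFA() { fa = fa2; idCurrFA = 1; }
  const char *activeName() const { return fa->name; }
};

int main() {
  Approximator early{"early-game"};
  Approximator late{"late-game"};
  TwoFAAgent agent(&early, &late);
  agent.switchToSecondFA();                // e.g. after 100 in-game steps, as offenseAgent does below
  std::cout << agent.activeName() << "\n"; // prints: late-game
}
```

Every call made through the active pointer then lands on whichever approximator is currently selected, which is how the FA->computeQ and FA->updateTraces calls in SarsaAgent below pick up a switch without any further plumbing.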
#include "SarsaAgent.h" #include "SarsaAgent.h"
//add lambda as parameter to sarsaagent //add lambda as parameter to sarsaagent
SarsaAgent::SarsaAgent(int numFeatures, int numActions, double learningRate, double epsilon, double lambda, FunctionApproximator *FA, char *loadWeightsFile, char *saveWeightsFile):PolicyAgent(numFeatures, numActions, learningRate, epsilon, FA, loadWeightsFile, saveWeightsFile){ // SarsaAgent::SarsaAgent(int numFeatures, int numActions, double learningRate, double epsilon, double lambda, FunctionApproximator *FA, char *loadWeightsFile, char *saveWeightsFile):PolicyAgent(numFeatures, numActions, learningRate, epsilon, FA, loadWeightsFile, saveWeightsFile){
SarsaAgent::SarsaAgent(int numFeatures, int numActions, double learningRate, double epsilon, double lambda, FunctionApproximator *FA1, FunctionApproximator *FA2, char *loadWeightsFile, char *saveWeightsFile):PolicyAgent(numFeatures, numActions, learningRate, epsilon, FA1, FA2, loadWeightsFile, saveWeightsFile){
std::cout<<"Num actions: \t"<<numActions<<std::endl; std::cout<<"Num actions: \t"<<numActions<<std::endl;
this->lambda = lambda; this->lambda = lambda;
episodeNumber = 0; episodeNumber = 0;
...@@ -18,19 +19,22 @@ void SarsaAgent::update(double state[], int action, double reward, double discou ...@@ -18,19 +19,22 @@ void SarsaAgent::update(double state[], int action, double reward, double discou
} }
lastAction = action; lastAction = action;
lastReward = reward; lastReward = reward;
FA->setState(lastState);
lastQ = FA->computeQ(lastAction);
FA->updateTraces(lastAction);
} }
else{ else{
FA->setState(lastState); // FA->setState(lastState);
double oldQ = FA->computeQ(lastAction);
FA->updateTraces(lastAction);
double delta = lastReward - oldQ; // double oldQ = FA->computeQ(lastAction);
// FA->updateTraces(lastAction);
FA->setState(state); // double delta = lastReward - oldQ;
double delta = lastReward - lastQ;
//Sarsa update //Sarsa update
FA->setState(state);
double newQ = FA->computeQ(action); double newQ = FA->computeQ(action);
// std::cout<<"newQ \t"<<newQ<<std::endl; // std::cout<<"newQ \t"<<newQ<<std::endl;
delta += discountFactor * newQ; delta += discountFactor * newQ;
...@@ -46,13 +50,19 @@ void SarsaAgent::update(double state[], int action, double reward, double discou ...@@ -46,13 +50,19 @@ void SarsaAgent::update(double state[], int action, double reward, double discou
} }
lastAction = action; lastAction = action;
lastReward = reward; lastReward = reward;
lastQ = newQ;
FA->updateTraces(action);
} }
} }
void SarsaAgent::copyWeights(SarsaAgent *agent){ void SarsaAgent::copyWeights(){
dynamic_cast<CMAC*>(FA)->copyWeights(dynamic_cast<CMAC*>(agent->FA)); dynamic_cast<CMAC*>(this->FA1)->copyWeights(dynamic_cast<CMAC*>(this->FA2));
} }
// void SarsaAgent::copyWeights(SarsaAgent *agent){
// dynamic_cast<CMAC*>(FA)->copyWeights(dynamic_cast<CMAC*>(agent->FA));
// }
 void SarsaAgent::endEpisode(){
   episodeNumber++;
@@ -62,10 +72,11 @@ void SarsaAgent::endEpisode(){
   }
   else{
-    FA->setState(lastState);
-    double oldQ = FA->computeQ(lastAction);
-    FA->updateTraces(lastAction);
-    double delta = lastReward - oldQ;
+    // FA->setState(lastState);
+    // double oldQ = FA->computeQ(lastAction);
+    // FA->updateTraces(lastAction);
+    // double delta = lastReward - oldQ;
+    double delta = lastReward - lastQ;
     FA->updateWeights(delta, learningRate);
     //Assume lambda is 0. this comment looks wrong.
@@ -78,7 +89,6 @@ void SarsaAgent::endEpisode(){
   }
   lastAction = -1;
 }
 void SarsaAgent::reset(){
...
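The substantive change in SarsaAgent::update and endEpisode is that Q(s, a) for the previous step is now cached in lastQ at the moment the action is taken, instead of being recomputed from FA at update time. Because the active approximator can be switched between steps (offenseAgent below calls switchToSecondFA mid-episode), recomputing the old Q would evaluate it against a different approximator than the one that produced it; caching keeps the SARSA temporal-difference error, delta = r + discountFactor * Q(s', a') - Q(s, a), internally consistent. A compact sketch of the cached variant, with a qValue callback standing in for the FA->setState/FA->computeQ pair and eligibility traces omitted (names here are illustrative):

```cpp
#include <functional>
#include <vector>

// Sketch only: qValue and applyDelta stand in for the FunctionApproximator calls.
struct CachedSarsaStep {
  double discountFactor;
  double lastQ = 0.0;   // Q(s, a) cached at action time, like SarsaAgent::lastQ
  bool hasLast = false;
  std::function<double(const std::vector<double>&, int)> qValue;
  std::function<void(double)> applyDelta;  // stands in for FA->updateWeights(delta, learningRate)

  void step(const std::vector<double> &state, int action, double lastReward) {
    double newQ = qValue(state, action);   // Q(s', a') under the *current* FA
    if (hasLast) {
      // TD error uses the cached Q(s, a); the old approximator is never re-queried.
      double delta = lastReward - lastQ + discountFactor * newQ;
      applyDelta(delta);
    }
    lastQ = newQ;                          // mirrors "lastQ = newQ;" in the diff
    hasLast = true;
  }
};
```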
@@ -15,12 +15,15 @@ class SarsaAgent:public PolicyAgent{
   int lastAction;
   double lastReward;
   double lambda;
+  double lastQ;
 public:
-  SarsaAgent(int numFeatures, int numActions, double learningRate, double epsilon, double lambda, FunctionApproximator *FA, char *loadWeightsFile, char *saveWeightsFile);
+  // SarsaAgent(int numFeatures, int numActions, double learningRate, double epsilon, double lambda, FunctionApproximator *FA, char *loadWeightsFile, char *saveWeightsFile);
+  SarsaAgent(int numFeatures, int numActions, double learningRate, double epsilon, double lambda, FunctionApproximator *FA1, FunctionApproximator *FA2, char *loadWeightsFile, char *saveWeightsFile);
-  void copyWeights(SarsaAgent*);
+  // void copyWeights(SarsaAgent*);
+  void copyWeights();
   int argmaxQ(double state[]);
   double computeQ(double state[], int action);
...
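One detail worth flagging on the new copyWeights(): the old call site was sa3->copyWeights(sa1), and the commented-out loop (sa3->FA->weights[i] = sa1->FA->weights[i]) shows the receiver acting as the destination. Under that same convention, FA1->copyWeights(FA2) in the new body copies the second approximator's weights into the first; whether that is the intended direction depends on CMAC::copyWeights, whose definition is outside this diff. A toy sketch of the receiver-as-destination convention (MiniCMAC is hypothetical, not the repository's CMAC):

```cpp
// Hypothetical illustration of the receiver-as-destination convention.
struct MiniCMAC {
  double weights[4] = {0.0, 0.0, 0.0, 0.0};
  // Copies weights *from* src into this object.
  void copyWeights(const MiniCMAC *src) {
    for (int i = 0; i < 4; i++) weights[i] = src->weights[i];
  }
};
```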
@@ -122,12 +122,14 @@ void offenseAgent(int port, int numTMates, int numEpi, double learnR,
   // CMAC fa2(numF, numA, range, min, res);
   CMAC *fa1 = new CMAC(numF, numA, range, min, res);
   CMAC *fa2 = new CMAC(numF, numA, range, min, res);
-  CMAC *fa = fa1;
+  // CMAC *fa = fa1;
   // SarsaAgent sa1(numF, numA, learnR, eps, lambda, &fa1, filename1.c_str(), filename1.c_str());
   // SarsaAgent sa1(numF, numA, learnR, eps, lambda, &fa1, filename2.c_str(), filename2.c_str());
-  SarsaAgent *sa1 = new SarsaAgent(numF, numA, learnR, eps, lambda, fa1, str1, str1);
-  SarsaAgent *sa3 = new SarsaAgent(numF, numA, learnR, eps, lambda, fa2, str2, str2);
-  SarsaAgent *sa = sa1, *sa2 = sa1;
+  SarsaAgent *sa = new SarsaAgent(numF, numA, learnR, eps, lambda, fa1, fa2, str1, str1);
+  // SarsaAgent *sa1 = new SarsaAgent(numF, numA, learnR, eps, lambda, fa1, str1, str1);
+  // SarsaAgent *sa3 = new SarsaAgent(numF, numA, learnR, eps, lambda, fa2, str2, str2);
+  // SarsaAgent *sa = sa1, *sa2 = sa1;
   hfo::HFOEnvironment hfo;
   hfo::status_t status;
@@ -138,19 +140,20 @@ void offenseAgent(int port, int numTMates, int numEpi, double learnR,
   // bool second_model_active = true;
   hfo.connectToServer(hfo::HIGH_LEVEL_FEATURE_SET,"../../bin/teams/base/config/formations-dt",6000,"localhost","base_left",false,"");
   for (int episode=0; episode < numEpi; episode++) {
     if(episode==6000)
     {
       // for(int i=0; i<RL_MEMORY_SIZE; i++)
       //   sa3->FA->weights[i] = sa1->FA->weights[i];
       std::cout<<"Copying weights"<<std::endl;
-      sa3->copyWeights(sa1);
-      sa2 = sa3;
+      sa->copyWeights();
     }
     int count = 0;
     status = hfo::IN_GAME;
     action = -1;
     bool model_changed_flag = false;
     int iter_count = -1;
+    sa->switchToFirstFA();
     while (status == hfo::IN_GAME) {
       iter_count++;
       const std::vector<float>& state_vec = hfo.getState();
@@ -167,7 +170,7 @@ void offenseAgent(int port, int numTMates, int numEpi, double learnR,
       // if(state_vec[numTMates] >= 0.2 && model_changed_flag == false)
       if(iter_count > 100 && model_changed_flag == false)
       {
-        sa = sa2;
+        sa->switchToSecondFA();
         model_changed_flag = true;
       }
@@ -195,15 +198,12 @@ void offenseAgent(int port, int numTMates, int numEpi, double learnR,
     if(action != -1) {
       reward = getReward(status);
       sa->update(state, action, reward, discFac);
-      sa1->endEpisode();
-      sa2->endEpisode();
-      sa = sa1;
+      sa->endEpisode();
       model_changed_flag = false;
     }
   }
-  delete sa1, sa2;
-  delete fa1, fa2;
+  delete sa; delete fa1; delete fa2; // "delete sa, fa1, fa2;" would invoke the comma operator and free only sa
 }
 int main(int argc, char **argv) {
...
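Taken together, offenseAgent now drives a single SarsaAgent on a fixed approximator schedule: every episode starts on the first FA, control hands over to the second FA once iter_count passes 100, and at episode 6000 the two weight sets are synchronized once via copyWeights(). A runnable skeleton of just that schedule, with the HFO environment, state, and reward plumbing stubbed out (numEpi and stepsPerEpisode values are illustrative; only 100 and 6000 come from the diff):

```cpp
#include <iostream>

// Stub exposing only the SarsaAgent calls the schedule uses.
struct AgentStub {
  void copyWeights()      { std::cout << "Copying weights\n"; }
  void switchToFirstFA()  {}
  void switchToSecondFA() {}
  void endEpisode()       {}
};

int main() {
  AgentStub agent;
  AgentStub *sa = &agent;
  const int numEpi = 8000;
  const int stepsPerEpisode = 200;            // the real loop runs until status != hfo::IN_GAME
  for (int episode = 0; episode < numEpi; episode++) {
    if (episode == 6000) sa->copyWeights();   // one-time weight sync, as in the diff
    bool model_changed_flag = false;
    sa->switchToFirstFA();                    // each episode begins on the first FA
    for (int iter_count = 0; iter_count <= stepsPerEpisode; iter_count++) {
      if (iter_count > 100 && !model_changed_flag) {
        sa->switchToSecondFA();               // second FA takes over late in the episode
        model_changed_flag = true;
      }
      // act / observe / sa->update(...) happen here in the real offenseAgent
    }
    sa->endEpisode();
  }
}
```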