Commit b04603ca authored by Shashank Suhas

Working mod branch

parent 14de35c6
#include "PolicyAgent.h"
PolicyAgent::PolicyAgent(int numFeatures, int numActions, double learningRate, double epsilon, FunctionApproximator *FA, char *loadWeightsFile, char *saveWeightsFile){
// PolicyAgent::PolicyAgent(int numFeatures, int numActions, double learningRate, double epsilon, FunctionApproximator *FA, char *loadWeightsFile, char *saveWeightsFile){
PolicyAgent::PolicyAgent(int numFeatures, int numActions, double learningRate, double epsilon, FunctionApproximator *FA1, FunctionApproximator *FA2, char *loadWeightsFile, char *saveWeightsFile){
this->numFeatures = numFeatures;
this->numActions = numActions;
this->learningRate = learningRate;
this->epsilon = epsilon;
this->FA = FA;
this->FA1 = FA1;
this->FA2 = FA2;
this->FA = FA1;
toLoadWeights = strlen(loadWeightsFile) > 0;
if(toLoadWeights){
strcpy(this->loadWeightsFile, loadWeightsFile);
loadWeights(loadWeightsFile);
// loadWeights(loadWeightsFile);
}
toSaveWeights = strlen(saveWeightsFile) > 0;
if(toSaveWeights){
strcpy(this->saveWeightsFile, saveWeightsFile);
}
}
PolicyAgent::~PolicyAgent(){
}
void PolicyAgent::switchToSecondFA(){
this->FA = this->FA2;
this->idCurrFA = 1;
}
void PolicyAgent::switchToFirstFA(){
this->FA = this->FA1;
this->idCurrFA = 0;
}
int PolicyAgent::getNumFeatures(){
return numFeatures;
}
@@ -34,12 +46,14 @@ int PolicyAgent::getNumActions(){
}
void PolicyAgent::loadWeights(char *fileName){
std::cout << "Loading Weights from " << fileName << std::endl;
FA->read(fileName);
// std::cout << "Loading Weights from " << fileName << std::endl;
// std::cout << "Doing nothing. Check PolicyAgent.cpp." << std::endl;
// FA->read(fileName);
}
void PolicyAgent::saveWeights(char *fileName){
FA->write(fileName);
// std::cout<< "Doing nothing. Check PolicyAgent.cpp." << std::endl;
// FA->write(fileName);
}
int PolicyAgent::argmaxQ(double state[]){
......
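// Note on the change above: PolicyAgent now owns two approximators (FA1, FA2)
// and keeps `FA` pointing at whichever is active; switchToFirstFA() and
// switchToSecondFA() only retarget that pointer, so every path that reads `FA`
// follows the switch automatically. Note also that the bodies of loadWeights()
// and saveWeights() are commented out on this branch, so weights are neither
// restored nor persisted. A compilable sketch of the active-pointer pattern
// (Approx is a hypothetical stand-in, not the project's FunctionApproximator):
#include <cassert>

struct Approx { double bias = 0.0; };          // stand-in function approximator

struct TwoFA {
  Approx *fa1, *fa2;
  Approx *fa;                                  // mirrors PolicyAgent::FA
  int idCurrFA = 0;
  TwoFA(Approx *a, Approx *b) : fa1(a), fa2(b), fa(a) {}
  void switchToFirstFA()  { fa = fa1; idCurrFA = 0; }
  void switchToSecondFA() { fa = fa2; idCurrFA = 1; }
};

int main() {
  Approx a, b;
  TwoFA agent(&a, &b);
  agent.switchToSecondFA();
  assert(agent.fa == &b && agent.idCurrFA == 1); // reads now hit the second FA
  return 0;
}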
@@ -29,14 +29,24 @@ class PolicyAgent{
FunctionApproximator *FA;
FunctionApproximator *FA1;
FunctionApproximator *FA2;
int getNumFeatures();
int getNumActions();
int idCurrFA = 0;
public:
PolicyAgent(int numFeatures, int numActions, double learningRate, double epsilon, FunctionApproximator *FA, char *loadWeightsFile, char *saveWeightsFile);
// PolicyAgent(int numFeatures, int numActions, double learningRate, double epsilon, FunctionApproximator *FA, char *loadWeightsFile, char *saveWeightsFile);
PolicyAgent(int numFeatures, int numActions, double learningRate, double epsilon, FunctionApproximator *FA1, FunctionApproximator *FA2, char *loadWeightsFile, char *saveWeightsFile);
~PolicyAgent();
void switchToFirstFA();
void switchToSecondFA();
virtual int argmaxQ(double state[]);
virtual double computeQ(double state[], int action);
......
#include "SarsaAgent.h"
// Add lambda as a parameter to SarsaAgent
SarsaAgent::SarsaAgent(int numFeatures, int numActions, double learningRate, double epsilon, double lambda, FunctionApproximator *FA, char *loadWeightsFile, char *saveWeightsFile):PolicyAgent(numFeatures, numActions, learningRate, epsilon, FA, loadWeightsFile, saveWeightsFile){
// SarsaAgent::SarsaAgent(int numFeatures, int numActions, double learningRate, double epsilon, double lambda, FunctionApproximator *FA, char *loadWeightsFile, char *saveWeightsFile):PolicyAgent(numFeatures, numActions, learningRate, epsilon, FA, loadWeightsFile, saveWeightsFile){
SarsaAgent::SarsaAgent(int numFeatures, int numActions, double learningRate, double epsilon, double lambda, FunctionApproximator *FA1, FunctionApproximator *FA2, char *loadWeightsFile, char *saveWeightsFile):PolicyAgent(numFeatures, numActions, learningRate, epsilon, FA1, FA2, loadWeightsFile, saveWeightsFile){
std::cout<<"Num actions: \t"<<numActions<<std::endl;
this->lambda = lambda;
episodeNumber = 0;
@@ -18,19 +19,22 @@ void SarsaAgent::update(double state[], int action, double reward, double discountFactor){
}
lastAction = action;
lastReward = reward;
FA->setState(lastState);
lastQ = FA->computeQ(lastAction);
FA->updateTraces(lastAction);
}
else{
FA->setState(lastState);
double oldQ = FA->computeQ(lastAction);
FA->updateTraces(lastAction);
// FA->setState(lastState);
double delta = lastReward - oldQ;
// double oldQ = FA->computeQ(lastAction);
// FA->updateTraces(lastAction);
FA->setState(state);
// double delta = lastReward - oldQ;
double delta = lastReward - lastQ;
//Sarsa update
FA->setState(state);
double newQ = FA->computeQ(action);
// std::cout<<"newQ \t"<<newQ<<std::endl;
delta += discountFactor * newQ;
@@ -46,13 +50,19 @@ void SarsaAgent::update(double state[], int action, double reward, double discountFactor){
}
lastAction = action;
lastReward = reward;
lastQ = newQ;
FA->updateTraces(action);
}
}
void SarsaAgent::copyWeights(SarsaAgent *agent){
dynamic_cast<CMAC*>(FA)->copyWeights(dynamic_cast<CMAC*>(agent->FA));
void SarsaAgent::copyWeights(){
dynamic_cast<CMAC*>(this->FA1)->copyWeights(dynamic_cast<CMAC*>(this->FA2));
}
// void SarsaAgent::copyWeights(SarsaAgent *agent){
// dynamic_cast<CMAC*>(FA)->copyWeights(dynamic_cast<CMAC*>(agent->FA));
// }
void SarsaAgent::endEpisode(){
episodeNumber++;
@@ -62,10 +72,11 @@ void SarsaAgent::endEpisode(){
}
else{
FA->setState(lastState);
double oldQ = FA->computeQ(lastAction);
FA->updateTraces(lastAction);
double delta = lastReward - oldQ;
// FA->setState(lastState);
// double oldQ = FA->computeQ(lastAction);
// FA->updateTraces(lastAction);
// double delta = lastReward - oldQ;
double delta = lastReward - lastQ;
FA->updateWeights(delta, learningRate);
//Assume lambda is 0. This comment looks wrong.
@@ -78,7 +89,6 @@ void SarsaAgent::endEpisode(){
}
lastAction = -1;
}
void SarsaAgent::reset(){
......
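// Note on the update() change above: the Q-value of the previous state-action
// pair is now cached in lastQ when the action is taken, instead of being
// recomputed via FA->computeQ(lastAction) at update time. That matters on this
// branch because the active FA can be switched mid-episode: a late recompute
// could read Q(s,a) from a different approximator than the one that chose the
// action, corrupting the TD error. A compilable sketch of the TD-error
// arithmetic the diff preserves (toy values, no real approximator):
#include <cassert>
#include <cmath>

// delta = r + discountFactor * Q(s',a') - Q(s,a), with Q(s,a) supplied as the
// cached lastQ, exactly as SarsaAgent::update() now computes it. At episode
// end there is no successor, so endEpisode() uses delta = lastReward - lastQ.
double sarsaDelta(double lastQ, double reward, double discountFactor,
                  double newQ) {
  double delta = reward - lastQ;    // matches: delta = lastReward - lastQ
  delta += discountFactor * newQ;   // matches: delta += discountFactor * newQ
  return delta;
}

int main() {
  // Q(s,a) = 0.5, r = 1, gamma = 0.9, Q(s',a') = 0.4  =>  delta = 0.86.
  assert(std::fabs(sarsaDelta(0.5, 1.0, 0.9, 0.4) - 0.86) < 1e-9);
  return 0;
}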
@@ -15,12 +15,15 @@ class SarsaAgent:public PolicyAgent{
int lastAction;
double lastReward;
double lambda;
double lastQ;
public:
SarsaAgent(int numFeatures, int numActions, double learningRate, double epsilon, double lambda, FunctionApproximator *FA, char *loadWeightsFile, char *saveWeightsFile);
// SarsaAgent(int numFeatures, int numActions, double learningRate, double epsilon, double lambda, FunctionApproximator *FA, char *loadWeightsFile, char *saveWeightsFile);
SarsaAgent(int numFeatures, int numActions, double learningRate, double epsilon, double lambda, FunctionApproximator *FA1, FunctionApproximator *FA2, char *loadWeightsFile, char *saveWeightsFile);
void copyWeights(SarsaAgent*);
// void copyWeights(SarsaAgent*);
void copyWeights();
int argmaxQ(double state[]);
double computeQ(double state[], int action);
......
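// Note on the copyWeights() change above: the cross-agent overload
// copyWeights(SarsaAgent*) is replaced by a zero-argument version that copies
// between the agent's own two approximators. Two things worth checking:
// (1) dynamic_cast<CMAC*> yields nullptr if either FA is not actually a CMAC,
// and the result is dereferenced unchecked; (2) by the old call's convention
// (the receiver takes the argument's weights, as in sa3->copyWeights(sa1)),
// FA1->copyWeights(FA2) copies FA2 into FA1 -- if the intent, as in the old
// driver, was to seed the second approximator from the first, the operands may
// need swapping. A compilable sketch of a guarded cast (Base and Derived are
// hypothetical stand-ins, not the project's FunctionApproximator/CMAC):
#include <cassert>

struct Base { virtual ~Base() = default; };
struct Derived : Base {
  int w = 0;
  void copyWeights(const Derived &src) { w = src.w; }  // receiver <- src
};

bool safeCopy(Base *dst, Base *src) {
  auto *d = dynamic_cast<Derived *>(dst);
  auto *s = dynamic_cast<Derived *>(src);
  if (!d || !s) return false;   // not the expected concrete type: no deref
  d->copyWeights(*s);
  return true;
}

int main() {
  Derived a, b;
  a.w = 7;
  assert(safeCopy(&b, &a) && b.w == 7);
  return 0;
}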
@@ -122,12 +122,14 @@ void offenseAgent(int port, int numTMates, int numEpi, double learnR,
// CMAC fa2(numF, numA, range, min, res);
CMAC *fa1 = new CMAC(numF, numA, range, min, res);
CMAC *fa2 = new CMAC(numF, numA, range, min, res);
CMAC *fa = fa1;
// CMAC *fa = fa1;
// SarsaAgent sa1(numF, numA, learnR, eps, lambda, &fa1, filename1.c_str(), filename1.c_str());
// SarsaAgent sa1(numF, numA, learnR, eps, lambda, &fa1, filename2.c_str(), filename2.c_str());
SarsaAgent *sa1 = new SarsaAgent(numF, numA, learnR, eps, lambda, fa1, str1, str1);
SarsaAgent *sa3 = new SarsaAgent(numF, numA, learnR, eps, lambda, fa2, str2, str2);
SarsaAgent *sa = sa1, *sa2 = sa1;
SarsaAgent *sa = new SarsaAgent(numF, numA, learnR, eps, lambda, fa1, fa2, str1, str1);
// SarsaAgent *sa1 = new SarsaAgent(numF, numA, learnR, eps, lambda, fa1, str1, str1);
// SarsaAgent *sa3 = new SarsaAgent(numF, numA, learnR, eps, lambda, fa2, str2, str2);
// SarsaAgent *sa = sa1, *sa2 = sa1;
hfo::HFOEnvironment hfo;
hfo::status_t status;
@@ -138,19 +140,20 @@ void offenseAgent(int port, int numTMates, int numEpi, double learnR,
// bool second_model_active = true;
hfo.connectToServer(hfo::HIGH_LEVEL_FEATURE_SET,"../../bin/teams/base/config/formations-dt",6000,"localhost","base_left",false,"");
for (int episode=0; episode < numEpi; episode++) {
if(episode==6000)
{
// for(int i=0; i<RL_MEMORY_SIZE; i++)
// sa3->FA->weights[i] = sa1->FA->weights[i];
std::cout<<"Copying weights"<<std::endl;
sa3->copyWeights(sa1);
sa2 = sa3;
sa->copyWeights();
}
int count = 0;
status = hfo::IN_GAME;
action = -1;
bool model_changed_flag = false;
int iter_count = -1;
sa->switchToFirstFA();
while (status == hfo::IN_GAME) {
iter_count++;
const std::vector<float>& state_vec = hfo.getState();
@@ -167,7 +170,7 @@ void offenseAgent(int port, int numTMates, int numEpi, double learnR,
// if(state_vec[numTMates] >= 0.2 && model_changed_flag == false)
if(iter_count > 100 && model_changed_flag == false)
{
sa = sa2;
sa->switchToSecondFA();
model_changed_flag = true;
}
@@ -195,15 +198,12 @@ void offenseAgent(int port, int numTMates, int numEpi, double learnR,
if(action != -1) {
reward = getReward(status);
sa->update(state, action, reward, discFac);
sa1->endEpisode();
sa2->endEpisode();
sa = sa1;
sa->endEpisode();
model_changed_flag = false;
}
}
delete sa1, sa2;
delete fa1, fa2;
delete sa; delete fa1; delete fa2; // note: "delete sa, fa1, fa2;" would only delete sa (comma operator)
}
int main(int argc, char **argv) {
......
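// Note on the driver change above: instead of two SarsaAgent objects (sa1/sa3)
// swapped via the sa/sa2 pointers, a single agent now owns both CMACs. Each
// episode starts on FA1 (switchToFirstFA), flips once to FA2 after 100 in-game
// iterations, and at episode 6000 copyWeights() syncs the two CMACs inside the
// agent. A compilable skeleton of that control flow (Agent is a hypothetical
// stand-in for SarsaAgent; the HFO environment calls are elided):
#include <iostream>

struct Agent {
  int activeFA = 0;
  void switchToFirstFA()  { activeFA = 0; }
  void switchToSecondFA() { activeFA = 1; }
  void copyWeights()      { std::cout << "Copying weights\n"; }
  void endEpisode()       {}
};

int main() {
  Agent sa;
  const int numEpi = 3, stepsPerEpisode = 150;   // toy sizes for the sketch
  for (int episode = 0; episode < numEpi; episode++) {
    if (episode == 2) sa.copyWeights();          // stands in for episode 6000
    sa.switchToFirstFA();                        // every episode starts on FA1
    bool switched = false;
    for (int iter = 0; iter < stepsPerEpisode; iter++) {
      if (iter > 100 && !switched) {             // one-way mid-episode handoff
        sa.switchToSecondFA();
        switched = true;
      }
      // ... choose action with the active FA, step HFO, sa.update(...) ...
    }
    sa.endEpisode();
  }
  return 0;
}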