1
0
mirror of https://github.com/vcmi/vcmi.git synced 2025-03-17 20:58:07 +02:00

[programming challenge, SSN] Framework for storing and generating learning examples.

This commit is contained in:
Michał W. Urbańczyk 2012-05-28 11:03:20 +00:00
parent 867d01dc34
commit 3d16f0a081
2 changed files with 362 additions and 34 deletions

View File

@ -32,6 +32,9 @@
#include <boost/thread.hpp> #include <boost/thread.hpp>
#include <boost/bind.hpp> #include <boost/bind.hpp>
#include <boost/program_options.hpp> #include <boost/program_options.hpp>
#include <boost/filesystem.hpp>
#include <boost/algorithm/string/predicate.hpp>
#include <string>
using boost::format; using boost::format;
using boost::str; using boost::str;

View File

@ -1,8 +1,9 @@
//#include "../global.h" //#include "../global.h"
#include "StdInc.h" #include "StdInc.h"
#include "../lib/VCMI_Lib.h" #include "../lib/VCMI_Lib.h"
#include "boost/tuple/tuple.hpp"
namespace po = boost::program_options; namespace po = boost::program_options;
namespace fs = boost::filesystem;
using namespace std;
//FANN //FANN
@ -15,6 +16,221 @@ std::string servername;
std::string runnername; std::string runnername;
extern DLL_EXPORT LibClasses * VLC; extern DLL_EXPORT LibClasses * VLC;
// A single learning example: the network input (duel setup + artifact),
// the expected output, and bookkeeping used to recognise the example later.
struct Example
{
	//ANN input
	DuelParameters dp;
	CArtifactInstance *art;
	//ANN expected output
	double value;
	//other
	std::string description;
	int i, j, k; //helper values for identification (buildLearningSet stores army index, bonus index and bonus value here)

	// Every scalar member gets a defined value, so a default-constructed
	// example can be safely compared, printed or serialized before being filled.
	Example()
		: art(NULL), value(0.0), i(0), j(0), k(0)
	{}

	// Builds an example from the network input and its expected output.
	// The identification helpers start at zero; callers set them when needed.
	Example(const DuelParameters &Dp, CArtifactInstance *Art, double Value)
		: dp(Dp), art(Art), value(Value), i(0), j(0), k(0)
	{}

	// Strict weak ordering on the identification triple: k is the most
	// significant key, then j, then i.
	bool operator<(const Example & rhs) const
	{
		if (k != rhs.k)
			return k < rhs.k;
		if (j != rhs.j)
			return j < rhs.j;
		return i < rhs.i;
	}

	// Two examples are equal when their identification triples match;
	// the payload (dp, art, value, description) is not compared.
	bool operator==(const Example &rhs) const
	{
		return rhs.i == i && rhs.j == j && rhs.k == k;
	}

	// Streams all members through the handler in a fixed order.
	template <typename Handler> void serialize(Handler &h, const int version)
	{
		h & dp & art & value & description & i & j & k;
	}
};
// Collects the paths of all regular files in a directory whose file name
// ends with the given suffix.
//   dirname - directory to scan; when it does not exist a message is logged
//             and the directory is created before scanning
//   ext     - suffix the file name has to end with
// Returns the path strings of the matching files in iteration order.
vector<string> getFileNames(const string &dirname = "./examples/", const std::string &ext = "example")
{
	const bool dirPresent = fs::exists(dirname);
	if(!dirPresent)
	{
		tlog1 << "Cannot find " << dirname << " directory! Will attempt creating it.\n";
		fs::create_directory(dirname);
	}

	vector<string> matching;
	fs::path scanned(dirname);
	fs::directory_iterator finish;
	fs::directory_iterator entry(scanned);
	while(entry != finish)
	{
		const bool wanted = fs::is_regular_file(entry->status())
			&& boost::ends_with(entry->path().filename(), ext);
		if(wanted)
			matching.push_back(entry->path().string());

		++entry;
	}
	return matching;
}
// Reads every stored example found by getFileNames("./examples/", "example").
//   printInfo - when true, logs the number of examples and one line per example
// Returns the loaded examples in the order the files were listed.
vector<Example> loadExamples(bool printInfo = true)
{
	std::vector<Example> loaded;
	BOOST_FOREACH(auto path, getFileNames("./examples/", "example"))
	{
		CLoadFile source(path);
		Example fromDisk;
		source >> fromDisk;
		loaded.push_back(fromDisk);
	}

	if(!printInfo)
		return loaded;

	tlog0 << "Found " << loaded.size() << " examples.\n";
	BOOST_FOREACH(auto &item, loaded)
	{
		tlog0 << format("Battle on army %d for bonus %d of value %d has resultdiff %lf\n") % item.i % item.j % item.k % item.value;
	}
	return loaded;
}
// Tells whether the example carries exactly the identification triple (i, j, k).
bool matchExample(const Example &ex, int i, int j, int k)
{
	if(ex.i != i)
		return false;
	if(ex.j != j)
		return false;
	return ex.k == k;
}
// Generates a simple duel in which both sides field the given army.
// Side 0 is configured first (hero 0, four zeroed primary skills, one stack
// per army slot); side 1 is a copy of it that only differs in the hero id.
DuelParameters generateDuel(const ArmyDescriptor &ad)
{
	DuelParameters duel;
	duel.bfieldType = 1;
	duel.terType = 1;

	auto &first = duel.sides[0];
	first.heroId = 0;
	first.heroPrimSkills.resize(4,0);
	BOOST_FOREACH(auto &slot, ad)
	{
		const auto &descr = slot.second;
		first.stacks[slot.first] = DuelParameters::SideSettings::StackSettings(descr.type->idNumber, descr.count);
	}

	// Mirror the first side, then give the copy its own hero.
	auto &second = duel.sides[1];
	second = first;
	second.heroId = 1;

	return duel;
}
std::vector<ArmyDescriptor> learningArmies()
{
std::vector<ArmyDescriptor> ret;
//armia zlozona ze stworow z malymi HP-kami
ArmyDescriptor lowHP;
lowHP[0] = CStackBasicDescriptor(1, 9); //halabardier
lowHP[1] = CStackBasicDescriptor(14, 20); //centaur
lowHP[2] = CStackBasicDescriptor(139, 123); //chlop
lowHP[3] = CStackBasicDescriptor(70, 30); //troglodyta
lowHP[4] = CStackBasicDescriptor(42, 50); //imp
//armia zlozona z poteznaych stworow
ArmyDescriptor highHP;
highHP[0] = CStackBasicDescriptor(13, 17); //archaniol
highHP[1] = CStackBasicDescriptor(132, 8); //azure dragon
highHP[2] = CStackBasicDescriptor(133, 10); //crystal dragon
highHP[3] = CStackBasicDescriptor(83, 22); //black dragon
//armia zlozona z tygodniowego przyrostu w zamku
auto &castleTown = VLC->townh->towns[0];
ArmyDescriptor castleNormal;
for(int i = 0; i < 7; i++)
{
auto &cre = VLC->creh->creatures[castleTown.basicCreatures[i]];
castleNormal[i] = CStackBasicDescriptor(cre.get(), cre->growth);
}
castleNormal[5].type = VLC->creh->creatures[52]; //replace cavaliers with Efreeti -> stupid ai sometimes blocks with two-hex walkers
//armia zlozona z tygodniowego ulepszonego przyrostu w ramparcie
auto &rampartTown = VLC->townh->towns[1];
ArmyDescriptor rampartUpgraded;
for(int i = 0; i < 7; i++)
{
auto &cre = VLC->creh->creatures[rampartTown.upgradedCreatures[i]];
rampartUpgraded[i] = CStackBasicDescriptor(cre.get(), cre->growth);
}
rampartUpgraded[5].type = VLC->creh->creatures[52]; //replace unicorn with Efreeti -> stupid ai sometimes blocks with two-hex walkers
//armia zlozona z samych strzelcow
ArmyDescriptor shooters;
shooters[0] = CStackBasicDescriptor(35, 17); //arcymag
shooters[1] = CStackBasicDescriptor(41, 1); //titan
shooters[2] = CStackBasicDescriptor(3, 70); //kusznik
shooters[3] = CStackBasicDescriptor(89, 50); //ulepszony ork
ret.push_back(lowHP);
ret.push_back(highHP);
ret.push_back(castleNormal);
ret.push_back(rampartUpgraded);
ret.push_back(shooters);
return ret;
}
std::vector<Bonus> learningBonuses()
{
std::vector<Bonus> ret;
Bonus b;
b.type = Bonus::PRIMARY_SKILL;
b.subtype = PrimarySkill::ATTACK;
ret.push_back(b);
b.subtype = PrimarySkill::DEFENSE;
ret.push_back(b);
b.type = Bonus::STACK_HEALTH;
b.subtype = 0;
ret.push_back(b);
b.type = Bonus::STACKS_SPEED;
ret.push_back(b);
b.type = Bonus::BLOCKS_RETALIATION;
ret.push_back(b);
b.type = Bonus::ADDITIONAL_RETALIATION;
ret.push_back(b);
b.type = Bonus::ADDITIONAL_ATTACK;
ret.push_back(b);
b.type = Bonus::CREATURE_DAMAGE;
ret.push_back(b);
b.type = Bonus::ALWAYS_MAXIMUM_DAMAGE;
ret.push_back(b);
b.type = Bonus::NO_DISTANCE_PENALTY;
ret.push_back(b);
return ret;
}
std::string addQuotesIfNeeded(const std::string &s) std::string addQuotesIfNeeded(const std::string &s)
{ {
if(s.find_first_of(' ') != std::string::npos) if(s.find_first_of(' ') != std::string::npos)
@ -30,7 +246,7 @@ void prog_help()
void runCommand(const std::string &command, const std::string &name, const std::string &logsDir = "") void runCommand(const std::string &command, const std::string &name, const std::string &logsDir = "")
{ {
static std::string commands[100]; static std::string commands[100000];
static int i = 0; static int i = 0;
std::string &cmd = commands[i++]; std::string &cmd = commands[i++];
if(logsDir.size() && name.size()) if(logsDir.size() && name.size())
@ -46,13 +262,14 @@ void runCommand(const std::string &command, const std::string &name, const std::
double playBattle(const DuelParameters &dp) double playBattle(const DuelParameters &dp)
{ {
string battleFileName = "pliczek.ssnb";
{ {
CSaveFile out("pliczek.ssnb"); CSaveFile out(battleFileName);
out << dp; out << dp;
} }
std::string serverCommand = servername + " " + addQuotesIfNeeded(battle) + " " + addQuotesIfNeeded(leftAI) + " " + addQuotesIfNeeded(rightAI) + " " + addQuotesIfNeeded(results) + " " + addQuotesIfNeeded(logsDir) + " " + (withVisualization ? " v" : ""); std::string serverCommand = servername + " " + addQuotesIfNeeded(battleFileName) + " " + addQuotesIfNeeded(leftAI) + " " + addQuotesIfNeeded(rightAI) + " " + addQuotesIfNeeded(results) + " " + addQuotesIfNeeded(logsDir) + " " + (withVisualization ? " v" : "");
std::string runnerCommand = runnername + " " + addQuotesIfNeeded(logsDir); std::string runnerCommand = runnername + " " + addQuotesIfNeeded(logsDir);
std::cout <<"Server command: " << serverCommand << std::endl << "Runner command: " << runnerCommand << std::endl; std::cout <<"Server command: " << serverCommand << std::endl << "Runner command: " << runnerCommand << std::endl;
@ -81,10 +298,6 @@ typedef std::map<int, CArtifactInstance*> TArtSet;
double cmpArtSets(DuelParameters dp, TArtSet setL, TArtSet setR) double cmpArtSets(DuelParameters dp, TArtSet setL, TArtSet setR)
{ {
//lewa strona z art 0.9
//bez artefaktow -0.41
//prawa strona z art. -0.926
dp.sides[0].artifacts = setL; dp.sides[0].artifacts = setL;
dp.sides[1].artifacts = setR; dp.sides[1].artifacts = setR;
@ -92,23 +305,32 @@ double cmpArtSets(DuelParameters dp, TArtSet setL, TArtSet setR)
return battleOutcome; return battleOutcome;
} }
std::vector<CArtifactInstance*> genArts(const std::vector<Bonus> & bonusesToGive) CArtifactInstance *generateArtWithBonus(const Bonus &b)
{ {
std::vector<CArtifactInstance*> ret; std::vector<CArtifactInstance*> ret;
CArtifact *nowy = new CArtifact(); static CArtifact *nowy = NULL;
if(!nowy)
{
nowy = new CArtifact();
nowy->description = "Cudowny miecz Towa gwarantuje zwyciestwo"; nowy->description = "Cudowny miecz Towa gwarantuje zwyciestwo";
nowy->name = "Cudowny miecz"; nowy->name = "Cudowny miecz";
nowy->constituentOf = nowy->constituents = NULL; nowy->constituentOf = nowy->constituents = NULL;
nowy->possibleSlots.push_back(Arts::LEFT_HAND); nowy->possibleSlots.push_back(Arts::LEFT_HAND);
}
CArtifactInstance *artinst = new CArtifactInstance(nowy);
artinst->addNewBonus(new Bonus(b));
return artinst;
}
std::vector<CArtifactInstance*> genArts(const std::vector<Bonus> & bonusesToGive)
{
std::vector<CArtifactInstance*> ret;
BOOST_FOREACH(auto b, bonusesToGive) BOOST_FOREACH(auto b, bonusesToGive)
{ {
CArtifactInstance *artinst = new CArtifactInstance(nowy); ret.push_back(generateArtWithBonus(b));
auto &arts = VLC->arth->artifacts;
artinst->addNewBonus(new Bonus(b));
ret.push_back(artinst);
} }
// auto bonuses = artinst->getBonuses([](const Bonus *){ return true; }); // auto bonuses = artinst->getBonuses([](const Bonus *){ return true; });
@ -130,8 +352,14 @@ double rateArt(const DuelParameters dp, CArtifactInstance * inst)
resultRL = cmpArtSets(dp, setR, setL), resultRL = cmpArtSets(dp, setR, setL),
resultsBase = cmpArtSets(dp, TArtSet(), TArtSet()); resultsBase = cmpArtSets(dp, TArtSet(), TArtSet());
//lewa strona z art 0.9
//bez artefaktow -0.41
//prawa strona z art. -0.926
double LRgain = resultLR - resultsBase, double LRgain = resultLR - resultsBase,
RLgain = resultRL - resultsBase; RLgain = resultsBase - resultRL;
return LRgain+RLgain; return LRgain+RLgain;
} }
@ -214,7 +442,7 @@ int ANNCallback(FANN::neural_net &net, FANN::training_data &train,
return 0; return 0;
} }
void learnSSN(FANN::neural_net & net, const std::vector<boost::tuple<DuelParameters, CArtifactInstance *, double> > & input) void learnSSN(FANN::neural_net & net, const std::vector<Example> & input)
{ {
FANN::training_data td; FANN::training_data td;
@ -222,9 +450,9 @@ void learnSSN(FANN::neural_net & net, const std::vector<boost::tuple<DuelParamet
double ** outputs = new double *[input.size()]; double ** outputs = new double *[input.size()];
for(int i=0; i<input.size(); ++i) for(int i=0; i<input.size(); ++i)
{ {
inputs[i] = genSSNinput(input[i].get<0>(), input[i].get<1>()); inputs[i] = genSSNinput(input[i].dp, input[i].art);
outputs[i] = new double; outputs[i] = new double;
*(outputs[i]) = input[i].get<2>(); *(outputs[i]) = input[i].value;
} }
td.set_train_data(input.size(), num_input, inputs, 1, outputs); td.set_train_data(input.size(), num_input, inputs, 1, outputs);
net.set_callback(ANNCallback, NULL); net.set_callback(ANNCallback, NULL);
@ -245,7 +473,7 @@ void initNet(FANN::neural_net & ret)
ret.set_learning_rate(learning_rate); ret.set_learning_rate(learning_rate);
ret.set_activation_steepness_hidden(1.0); ret.set_activation_steepness_hidden(0.9);
ret.set_activation_steepness_output(1.0); ret.set_activation_steepness_output(1.0);
ret.set_activation_function_hidden(FANN::SIGMOID_SYMMETRIC_STEPWISE); ret.set_activation_function_hidden(FANN::SIGMOID_SYMMETRIC_STEPWISE);
@ -286,20 +514,14 @@ void SSNRun()
// } // }
//duels to test on //duels to test on
std::vector<DuelParameters> dps; std::vector<DuelParameters> dps;
for(int k = 0; k<10; ++k) for(int k = 0; k<10; ++k)
{ {
DuelParameters dp; DuelParameters dp;
dp.bfieldType = 1;
dp.terType = 1;
auto &side = dp.sides[0];
side.heroId = 0;
side.heroPrimSkills.resize(4,0);
side.stacks[0] = DuelParameters::SideSettings::StackSettings(10+k*3, rand()%30);
dp.sides[1] = side;
dp.sides[1].heroId = 1;
dps.push_back(dp); dps.push_back(dp);
} }
@ -307,6 +529,14 @@ void SSNRun()
for(int i=0; i<5; ++i) for(int i=0; i<5; ++i)
{ {
Bonus b; Bonus b;
b.additionalInfo = -1;
b.duration = Bonus::PERMANENT;
b.source = Bonus::ARTIFACT;
b.sid = 0;
b.turnsRemain = 0xda;
b.valType = Bonus::ADDITIVE_VALUE;
b.effectRange = Bonus::NO_LIMIT;
b.type = Bonus::PRIMARY_SKILL; b.type = Bonus::PRIMARY_SKILL;
b.subtype = PrimarySkill::ATTACK; b.subtype = PrimarySkill::ATTACK;
b.val = 5 * i + 1; b.val = 5 * i + 1;
@ -327,7 +557,7 @@ void SSNRun()
auto arts = genArts(btt); auto arts = genArts(btt);
//evaluate //evaluate
std::vector<boost::tuple<DuelParameters, CArtifactInstance *, double> > setups; std::vector<Example> setups;
std::ofstream desOuts("desiredOuts.dat"); std::ofstream desOuts("desiredOuts.dat");
@ -335,8 +565,8 @@ void SSNRun()
{ {
for(int j=0; j<arts.size(); ++j) for(int j=0; j<arts.size(); ++j)
{ {
setups.push_back(boost::make_tuple(dps[i], arts[j], rateArt(dps[i], arts[i]))); setups.push_back(Example(dps[i], arts[j], rateArt(dps[i], arts[i])));
desOuts << (*setups.rbegin()).get<2>() << " "; desOuts << (*setups.rbegin()).value << " ";
} }
desOuts << std::endl; desOuts << std::endl;
} }
@ -345,6 +575,98 @@ void SSNRun()
network.save("network_config_file.net"); network.save("network_config_file.net");
} }
// Converts an integer to its textual form via boost::lexical_cast.
string toString(int i)
{
	string asText = boost::lexical_cast<string>(i);
	return asText;
}
// Builds a short label for a bonus: "+<val>_to_<type name>_sub<subtype>".
string describeBonus(const Bonus &b)
{
	string label = "+";
	label += toString(b.val);
	label += "_to_";
	label += bonusTypeToString(b.type);
	label += "_sub";
	label += toString(b.subtype);
	return label;
}
int theLastN()
{
auto fnames = getFileNames();
if(!fnames.size())
return -1;
range::sort(fnames, [](const std::string &a, const std::string &b)
{
return boost::lexical_cast<int>(fs::basename(a)) < boost::lexical_cast<int>(fs::basename(b));
});
return boost::lexical_cast<int>(fs::basename(fnames.back()));
}
void buildLearningSet()
{
vector<Example> examples = loadExamples();
range::sort(examples);
int startExamplesFrom = 0;
ofstream learningLog("log.txt", std::ios::app);
int n = theLastN()+1;
auto armies = learningArmies();
auto bonuese = learningBonuses();
for(int i = 0; i < armies.size(); i++)
{
string army = "army" + toString(i);
for(int j = 0; j < bonuese.size(); j++)
{
Bonus b = bonuese[j];
string bonusStr = "bonus" + toString(j) + describeBonus(b);
for(int k = 0; k < 10; k++)
{
int nHere = n++;
// if(nHere < startExamplesFrom)
// continue;
//
tlog2 << "n="<<nHere<<std::endl;
b.val = k;
Example ex;
ex.i = i;
ex.j = j;
ex.k = k;
ex.art = generateArtWithBonus(b);
ex.dp = generateDuel(armies[i]);
ex.description = army + "\t" + describeBonus(b) + "\t";
if(vstd::contains(examples, ex))
{
string msg = str(format("n=%d \tarmy %d \tbonus %d \tresult %lf \t Bonus#%s#") % nHere % i %j % ex.value % describeBonus(b));
tlog0 << "Already present example, skipping " << msg;
continue;
}
ex.value = rateArt(ex.dp, ex.art);
CSaveFile output("./examples/" + toString(nHere) + ".example");
output << ex;
time_t rawtime;
struct tm * timeinfo;
time ( &rawtime );
timeinfo = localtime ( &rawtime );
string msg = str(format("n=%d \tarmy %d \tbonus %d \tresult %lf \t Bonus#%s# \tdate: %s") % nHere % i %j % ex.value % describeBonus(b) % asctime(timeinfo));
learningLog << msg << flush;
tlog0 << msg;
}
}
}
tlog0 << "Set of learning/testing examples is complete and ready!\n";
}
int main(int argc, char **argv) int main(int argc, char **argv)
{ {
std::cout << "VCMI Odpalarka\nMy path: " << argv[0] << std::endl; std::cout << "VCMI Odpalarka\nMy path: " << argv[0] << std::endl;
@ -415,6 +737,9 @@ int main(int argc, char **argv)
VLC = new LibClasses(); VLC = new LibClasses();
VLC->init(); VLC->init();
buildLearningSet();
SSNRun(); SSNRun();
return EXIT_SUCCESS; return EXIT_SUCCESS;