1
0
mirror of https://github.com/vcmi/vcmi.git synced 2025-03-19 21:10:12 +02:00

[programming challenge, SSN] Insignificant reorganization.

This commit is contained in:
Michał W. Urbańczyk 2012-05-30 17:32:05 +00:00
parent 4a7891139e
commit 0e19dd1a79

View File

@ -16,6 +16,36 @@ std::string servername;
std::string runnername;
extern DLL_EXPORT LibClasses * VLC;
typedef std::map<int, CArtifactInstance*> TArtSet;
namespace Utilities
{
std::string addQuotesIfNeeded(const std::string &s)
{
if(s.find_first_of(' ') != std::string::npos)
return "\"" + s + "\"";
return s;
}
void prog_help()
{
std::cout << "If run without args, then StupidAI will be run on b1.json.\n";
}
string toString(int i)
{
return boost::lexical_cast<string>(i);
}
string describeBonus(const Bonus &b)
{
return "+" + toString(b.val) + "_to_" + bonusTypeToString(b.type)+"_sub"+toString(b.subtype);
}
}
using namespace Utilities;
struct Example
{
//ANN input
@ -35,7 +65,6 @@ struct Example
{}
inline bool operator<(const Example & rhs) const
{
if (k<rhs.k)
@ -64,7 +93,28 @@ struct Example
}
};
vector<string> getFileNames(const string &dirname = "./examples/", const std::string &ext = "example")
class Framework
{
static CArtifactInstance *generateArtWithBonus(const Bonus &b);
static DuelParameters generateDuel(const ArmyDescriptor &ad); //generates simple duel where both sides have given army
static void runCommand(const std::string &command, const std::string &name, const std::string &logsDir = "");
static double playBattle(const DuelParameters &dp);
static double cmpArtSets(DuelParameters dp, TArtSet setL, TArtSet setR);
static double rateArt(const DuelParameters dp, CArtifactInstance * inst); //rates given artifact
static int theLastN();
static vector<string> getFileNames(const string &dirname = "./examples/", const std::string &ext = "example");
static vector<ArmyDescriptor> learningArmies();
static vector<Bonus> learningBonuses();
public:
Framework();
~Framework();
static void buildLearningSet();
static vector<Example> loadExamples(bool printInfo = true);
};
vector<string> Framework::getFileNames(const string &dirname, const std::string &ext)
{
vector<string> ret;
if(!fs::exists(dirname))
@ -87,7 +137,7 @@ vector<string> getFileNames(const string &dirname = "./examples/", const std::st
return ret;
}
vector<Example> loadExamples(bool printInfo = true)
vector<Example> Framework::loadExamples(bool printInfo)
{
std::vector<Example> examples;
BOOST_FOREACH(auto fname, getFileNames("./examples/", "example"))
@ -110,481 +160,7 @@ vector<Example> loadExamples(bool printInfo = true)
return examples;
}
bool matchExample(const Example &ex, int i, int j, int k)
{
return ex.i == i && ex.j == j && ex.k == k;
}
//generates simple duel where both sides have given army
DuelParameters generateDuel(const ArmyDescriptor &ad)
{
DuelParameters dp;
dp.bfieldType = 1;
dp.terType = 1;
auto &side = dp.sides[0];
side.heroId = 0;
side.heroPrimSkills.resize(4,0);
BOOST_FOREACH(auto &stack, ad)
{
side.stacks[stack.first] = DuelParameters::SideSettings::StackSettings(stack.second.type->idNumber, stack.second.count);
}
dp.sides[1] = side;
dp.sides[1].heroId = 1;
return dp;
}
std::vector<ArmyDescriptor> learningArmies()
{
std::vector<ArmyDescriptor> ret;
//armia zlozona ze stworow z malymi HP-kami
ArmyDescriptor lowHP;
lowHP[0] = CStackBasicDescriptor(1, 9); //halabardier
lowHP[1] = CStackBasicDescriptor(14, 20); //centaur
lowHP[2] = CStackBasicDescriptor(139, 123); //chlop
lowHP[3] = CStackBasicDescriptor(70, 30); //troglodyta
lowHP[4] = CStackBasicDescriptor(42, 50); //imp
//armia zlozona z poteznaych stworow
ArmyDescriptor highHP;
highHP[0] = CStackBasicDescriptor(13, 17); //archaniol
highHP[1] = CStackBasicDescriptor(132, 8); //azure dragon
highHP[2] = CStackBasicDescriptor(133, 10); //crystal dragon
highHP[3] = CStackBasicDescriptor(83, 22); //black dragon
//armia zlozona z tygodniowego przyrostu w zamku
auto &castleTown = VLC->townh->towns[0];
ArmyDescriptor castleNormal;
for(int i = 0; i < 7; i++)
{
auto &cre = VLC->creh->creatures[castleTown.basicCreatures[i]];
castleNormal[i] = CStackBasicDescriptor(cre.get(), cre->growth);
}
castleNormal[5].type = VLC->creh->creatures[52]; //replace cavaliers with Efreeti -> stupid ai sometimes blocks with two-hex walkers
//armia zlozona z tygodniowego ulepszonego przyrostu w ramparcie
auto &rampartTown = VLC->townh->towns[1];
ArmyDescriptor rampartUpgraded;
for(int i = 0; i < 7; i++)
{
auto &cre = VLC->creh->creatures[rampartTown.upgradedCreatures[i]];
rampartUpgraded[i] = CStackBasicDescriptor(cre.get(), cre->growth);
}
rampartUpgraded[5].type = VLC->creh->creatures[52]; //replace unicorn with Efreeti -> stupid ai sometimes blocks with two-hex walkers
//armia zlozona z samych strzelcow
ArmyDescriptor shooters;
shooters[0] = CStackBasicDescriptor(35, 17); //arcymag
shooters[1] = CStackBasicDescriptor(41, 1); //titan
shooters[2] = CStackBasicDescriptor(3, 70); //kusznik
shooters[3] = CStackBasicDescriptor(89, 50); //ulepszony ork
ret.push_back(lowHP);
ret.push_back(highHP);
ret.push_back(castleNormal);
ret.push_back(rampartUpgraded);
ret.push_back(shooters);
return ret;
}
std::vector<Bonus> learningBonuses()
{
std::vector<Bonus> ret;
Bonus b;
b.type = Bonus::PRIMARY_SKILL;
b.subtype = PrimarySkill::ATTACK;
ret.push_back(b);
b.subtype = PrimarySkill::DEFENSE;
ret.push_back(b);
b.type = Bonus::STACK_HEALTH;
b.subtype = 0;
ret.push_back(b);
b.type = Bonus::STACKS_SPEED;
ret.push_back(b);
b.type = Bonus::BLOCKS_RETALIATION;
ret.push_back(b);
b.type = Bonus::ADDITIONAL_RETALIATION;
ret.push_back(b);
b.type = Bonus::ADDITIONAL_ATTACK;
ret.push_back(b);
b.type = Bonus::CREATURE_DAMAGE;
ret.push_back(b);
b.type = Bonus::ALWAYS_MAXIMUM_DAMAGE;
ret.push_back(b);
b.type = Bonus::NO_DISTANCE_PENALTY;
ret.push_back(b);
return ret;
}
std::string addQuotesIfNeeded(const std::string &s)
{
if(s.find_first_of(' ') != std::string::npos)
return "\"" + s + "\"";
return s;
}
void prog_help()
{
std::cout << "If run without args, then StupidAI will be run on b1.json.\n";
}
void runCommand(const std::string &command, const std::string &name, const std::string &logsDir = "")
{
static std::string commands[100000];
static int i = 0;
std::string &cmd = commands[i++];
if(logsDir.size() && name.size())
{
std::string directionLogs = logsDir + "/" + name + ".txt";
cmd = command + " > " + addQuotesIfNeeded(directionLogs);
}
else
cmd = command;
boost::thread tt(boost::bind(std::system, cmd.c_str()));
}
double playBattle(const DuelParameters &dp)
{
string battleFileName = "pliczek.ssnb";
{
CSaveFile out(battleFileName);
out << dp;
}
std::string serverCommand = servername + " " + addQuotesIfNeeded(battleFileName) + " " + addQuotesIfNeeded(leftAI) + " " + addQuotesIfNeeded(rightAI) + " " + addQuotesIfNeeded(results) + " " + addQuotesIfNeeded(logsDir) + " " + (withVisualization ? " v" : "");
std::string runnerCommand = runnername + " " + addQuotesIfNeeded(logsDir);
std::cout <<"Server command: " << serverCommand << std::endl << "Runner command: " << runnerCommand << std::endl;
int code = 0;
boost::thread t([&]
{
code = std::system(serverCommand.c_str());
});
runCommand(runnerCommand, "first_runner", logsDir);
runCommand(runnerCommand, "second_runner", logsDir);
runCommand(runnerCommand, "third_runner", logsDir);
if(withVisualization)
{
//boost::this_thread::sleep(boost::posix_time::millisec(500)); //FIXME
boost::thread tttt(boost::bind(std::system, "VCMI_Client.exe -battle"));
}
//boost::this_thread::sleep(boost::posix_time::seconds(5));
t.join();
return code / 1000000.0;
}
typedef std::map<int, CArtifactInstance*> TArtSet;
double cmpArtSets(DuelParameters dp, TArtSet setL, TArtSet setR)
{
dp.sides[0].artifacts = setL;
dp.sides[1].artifacts = setR;
auto battleOutcome = playBattle(dp);
return battleOutcome;
}
CArtifactInstance *generateArtWithBonus(const Bonus &b)
{
std::vector<CArtifactInstance*> ret;
static CArtifact *nowy = NULL;
if(!nowy)
{
nowy = new CArtifact();
nowy->description = "Cudowny miecz Towa gwarantuje zwyciestwo";
nowy->name = "Cudowny miecz";
nowy->constituentOf = nowy->constituents = NULL;
nowy->possibleSlots.push_back(Arts::LEFT_HAND);
}
CArtifactInstance *artinst = new CArtifactInstance(nowy);
artinst->addNewBonus(new Bonus(b));
return artinst;
}
std::vector<CArtifactInstance*> genArts(const std::vector<Bonus> & bonusesToGive)
{
std::vector<CArtifactInstance*> ret;
BOOST_FOREACH(auto b, bonusesToGive)
{
ret.push_back(generateArtWithBonus(b));
}
// auto bonuses = artinst->getBonuses([](const Bonus *){ return true; });
// BOOST_FOREACH(Bonus *b, *bonuses)
// {
// std::cout << format("%s (%d) value:%d, description: %s\n") % bonusTypeToString(b->type) % b->subtype % b->val % b->Description();
// }
return ret;
}
//rates given artifact
double rateArt(const DuelParameters dp, CArtifactInstance * inst)
{
TArtSet setL, setR;
setL[inst->artType->possibleSlots[0]] = inst;
double resultLR = cmpArtSets(dp, setL, setR),
resultRL = cmpArtSets(dp, setR, setL),
resultsBase = cmpArtSets(dp, TArtSet(), TArtSet());
//lewa strona z art 0.9
//bez artefaktow -0.41
//prawa strona z art. -0.926
double LRgain = resultLR - resultsBase,
RLgain = resultsBase - resultRL;
return LRgain+RLgain;
}
const unsigned int num_input = 18;
double * genSSNinput(const DuelParameters::SideSettings & dp, CArtifactInstance * art, si32 bfieldType, si32 terType)
{
double * ret = new double[num_input];
double * cur = ret;
//general description
*(cur++) = bfieldType/30.0;
*(cur++) = terType/12.0;
//creature & hero description
*(cur++) = dp.heroId/200.0;
for(int k=0; k<4; ++k)
*(cur++) = dp.heroPrimSkills[k]/20.0;
//weighted average of statistics
auto avg = [&](std::function<int(CCreature *)> getter) -> double
{
double ret = 0.0;
int div = 0;
for(int i=0; i<7; ++i)
{
auto & cstack = dp.stacks[i];
if(cstack.count > 0)
{
ret += getter(VLC->creh->creatures[cstack.type]) * cstack.count;
div+=cstack.count;
}
}
return ret/div;
};
*(cur++) = avg([](CCreature * c){return c->attack;})/50.0;
*(cur++) = avg([](CCreature * c){return c->defence;})/50.0;
*(cur++) = avg([](CCreature * c){return c->speed;})/15.0;
*(cur++) = avg([](CCreature * c){return c->hitPoints;})/1000.0;
//bonus description
auto & blist = art->getBonusList();
*(cur++) = blist[0]->type/100.0;
*(cur++) = blist[0]->subtype/10.0;
*(cur++) = blist[0]->val/100.0;;
*(cur++) = art->Attack()/10.0;
*(cur++) = art->Defense()/10.0;
*(cur++) = blist.valOfBonuses(Selector::type(Bonus::STACKS_SPEED))/5.0;
*(cur++) = blist.valOfBonuses(Selector::type(Bonus::STACK_HEALTH))/10.0;
return ret;
}
//returns how good the artifact is for the neural network
double runSSN(FANN::neural_net & net, const DuelParameters dp, CArtifactInstance * inst)
{
double * input = genSSNinput(dp.sides[0], inst, dp.bfieldType, dp.terType);
double * out = net.run(input);
double ret = *out;
free(out);
return ret;
}
int ANNCallback(FANN::neural_net &net, FANN::training_data &train,
unsigned int max_epochs, unsigned int epochs_between_reports,
float desired_error, unsigned int epochs, void *user_data)
{
//cout << "Epochs " << setw(8) << epochs << ". "
// << "Current Error: " << left << net.get_MSE() << right << endl;
return 0;
}
void learnSSN(FANN::neural_net & net, const std::vector<Example> & input)
{
FANN::training_data td;
double ** inputs = new double *[input.size()];
double ** outputs = new double *[input.size()];
for(int i=0; i<input.size(); ++i)
{
const auto & ci = input[i];
inputs[i] = genSSNinput(ci.dp.sides[0], ci.art, ci.dp.bfieldType, ci.dp.terType);
outputs[i] = new double;
*(outputs[i]) = ci.value;
}
td.set_train_data(input.size(), num_input, inputs, 1, outputs);
net.set_callback(ANNCallback, NULL);
net.train_on_data(td, 1000, 1000, 0.01);
}
void initNet(FANN::neural_net & ret)
{
const float learning_rate = 0.7f;
const unsigned int num_layers = 3;
const unsigned int num_hidden = 30;
const unsigned int num_output = 1;
const float desired_error = 0.001f;
const unsigned int max_iterations = 300000;
const unsigned int iterations_between_reports = 1000;
ret.create_standard(num_layers, num_input, num_hidden, num_output);
ret.set_learning_rate(learning_rate);
ret.set_activation_steepness_hidden(0.9);
ret.set_activation_steepness_output(1.0);
ret.set_activation_function_hidden(FANN::SIGMOID_SYMMETRIC_STEPWISE);
ret.set_activation_function_output(FANN::SIGMOID_SYMMETRIC_STEPWISE);
ret.randomize_weights(0.0, 1.0);
}
void SSNRun()
{
std::vector<std::pair<CArtifactInstance *, double> > artNotes;
TArtSet setL, setR;
FANN::neural_net network;
initNet(network);
// for(int i=0; i<availableArts.size(); ++i)
// {
// artNotes.push_back(std::make_pair(availableArts[i], runSSN(network, availableArts[i])));
// }
// boost::range::sort(artNotes,
// [](const std::pair<CArtifactInstance *, double> & a1, const std::pair<CArtifactInstance *, double> & a2)
// {return a1.second > a2.second;});
//
// //pick best arts into setL
// BOOST_FOREACH(auto & ap, artNotes)
// {
// auto art = ap.first;
// BOOST_FOREACH(auto slot, art->artType->possibleSlots)
// {
// if(setL.find(slot) != setL.end())
// {
// setL[slot] = art;
// break;
// }
// }
// }
//duels to test on
std::vector<DuelParameters> dps;
for(int k = 0; k<10; ++k)
{
DuelParameters dp;
dps.push_back(dp);
}
std::vector<Bonus> btt; //bonuses to test on
for(int i=0; i<5; ++i)
{
Bonus b;
b.additionalInfo = -1;
b.duration = Bonus::PERMANENT;
b.source = Bonus::ARTIFACT;
b.sid = 0;
b.turnsRemain = 0xda;
b.valType = Bonus::ADDITIVE_VALUE;
b.effectRange = Bonus::NO_LIMIT;
b.type = Bonus::PRIMARY_SKILL;
b.subtype = PrimarySkill::ATTACK;
b.val = 5 * i + 1;
btt.push_back(b);
b.subtype = PrimarySkill::DEFENSE;
btt.push_back(b);
b.type = Bonus::STACKS_SPEED;
b.subtype = 0;
btt.push_back(b);
b.type = Bonus::STACK_HEALTH;
btt.push_back(b);
}
auto arts = genArts(btt);
//evaluate
std::vector<Example> setups;
std::ofstream desOuts("desiredOuts.dat");
for(int i=0; i<dps.size(); ++i)
{
for(int j=0; j<arts.size(); ++j)
{
setups.push_back(Example(dps[i], arts[j], rateArt(dps[i], arts[i])));
desOuts << (*setups.rbegin()).value << " ";
}
desOuts << std::endl;
}
learnSSN(network, setups);
network.save("network_config_file.net");
}
string toString(int i)
{
return boost::lexical_cast<string>(i);
}
string describeBonus(const Bonus &b)
{
return "+" + toString(b.val) + "_to_" + bonusTypeToString(b.type)+"_sub"+toString(b.subtype);
}
int theLastN()
int Framework::theLastN()
{
auto fnames = getFileNames();
if(!fnames.size())
@ -598,7 +174,7 @@ int theLastN()
return boost::lexical_cast<int>(fs::basename(fnames.back()));
}
void buildLearningSet()
void Framework::buildLearningSet()
{
vector<Example> examples = loadExamples();
range::sort(examples);
@ -664,7 +240,373 @@ void buildLearningSet()
tlog0 << "Set of learning/testing examples is complete and ready!\n";
}
vector<ArmyDescriptor> Framework::learningArmies()
{
vector<ArmyDescriptor> ret;
//armia zlozona ze stworow z malymi HP-kami
ArmyDescriptor lowHP;
lowHP[0] = CStackBasicDescriptor(1, 9); //halabardier
lowHP[1] = CStackBasicDescriptor(14, 20); //centaur
lowHP[2] = CStackBasicDescriptor(139, 123); //chlop
lowHP[3] = CStackBasicDescriptor(70, 30); //troglodyta
lowHP[4] = CStackBasicDescriptor(42, 50); //imp
//armia zlozona z poteznaych stworow
ArmyDescriptor highHP;
highHP[0] = CStackBasicDescriptor(13, 17); //archaniol
highHP[1] = CStackBasicDescriptor(132, 8); //azure dragon
highHP[2] = CStackBasicDescriptor(133, 10); //crystal dragon
highHP[3] = CStackBasicDescriptor(83, 22); //black dragon
//armia zlozona z tygodniowego przyrostu w zamku
auto &castleTown = VLC->townh->towns[0];
ArmyDescriptor castleNormal;
for(int i = 0; i < 7; i++)
{
auto &cre = VLC->creh->creatures[castleTown.basicCreatures[i]];
castleNormal[i] = CStackBasicDescriptor(cre.get(), cre->growth);
}
castleNormal[5].type = VLC->creh->creatures[52]; //replace cavaliers with Efreeti -> stupid ai sometimes blocks with two-hex walkers
//armia zlozona z tygodniowego ulepszonego przyrostu w ramparcie
auto &rampartTown = VLC->townh->towns[1];
ArmyDescriptor rampartUpgraded;
for(int i = 0; i < 7; i++)
{
auto &cre = VLC->creh->creatures[rampartTown.upgradedCreatures[i]];
rampartUpgraded[i] = CStackBasicDescriptor(cre.get(), cre->growth);
}
rampartUpgraded[5].type = VLC->creh->creatures[52]; //replace unicorn with Efreeti -> stupid ai sometimes blocks with two-hex walkers
//armia zlozona z samych strzelcow
ArmyDescriptor shooters;
shooters[0] = CStackBasicDescriptor(35, 17); //arcymag
shooters[1] = CStackBasicDescriptor(41, 1); //titan
shooters[2] = CStackBasicDescriptor(3, 70); //kusznik
shooters[3] = CStackBasicDescriptor(89, 50); //ulepszony ork
ret.push_back(lowHP);
ret.push_back(highHP);
ret.push_back(castleNormal);
ret.push_back(rampartUpgraded);
ret.push_back(shooters);
return ret;
}
vector<Bonus> Framework::learningBonuses()
{
vector<Bonus> ret;
Bonus b;
b.type = Bonus::PRIMARY_SKILL;
b.subtype = PrimarySkill::ATTACK;
ret.push_back(b);
b.subtype = PrimarySkill::DEFENSE;
ret.push_back(b);
b.type = Bonus::STACK_HEALTH;
b.subtype = 0;
ret.push_back(b);
b.type = Bonus::STACKS_SPEED;
ret.push_back(b);
b.type = Bonus::BLOCKS_RETALIATION;
ret.push_back(b);
b.type = Bonus::ADDITIONAL_RETALIATION;
ret.push_back(b);
b.type = Bonus::ADDITIONAL_ATTACK;
ret.push_back(b);
b.type = Bonus::CREATURE_DAMAGE;
ret.push_back(b);
b.type = Bonus::ALWAYS_MAXIMUM_DAMAGE;
ret.push_back(b);
b.type = Bonus::NO_DISTANCE_PENALTY;
ret.push_back(b);
return ret;
}
double Framework::rateArt(const DuelParameters dp, CArtifactInstance * inst)
{
TArtSet setL, setR;
setL[inst->artType->possibleSlots[0]] = inst;
double resultLR = cmpArtSets(dp, setL, setR),
resultRL = cmpArtSets(dp, setR, setL),
resultsBase = cmpArtSets(dp, TArtSet(), TArtSet());
//lewa strona z art 0.9
//bez artefaktow -0.41
//prawa strona z art. -0.926
double LRgain = resultLR - resultsBase,
RLgain = resultsBase - resultRL;
return LRgain+RLgain;
}
double Framework::cmpArtSets(DuelParameters dp, TArtSet setL, TArtSet setR)
{
dp.sides[0].artifacts = setL;
dp.sides[1].artifacts = setR;
auto battleOutcome = playBattle(dp);
return battleOutcome;
}
double Framework::playBattle(const DuelParameters &dp)
{
string battleFileName = "pliczek.ssnb";
{
CSaveFile out(battleFileName);
out << dp;
}
std::string serverCommand = servername + " " + addQuotesIfNeeded(battleFileName) + " " + addQuotesIfNeeded(leftAI) + " " + addQuotesIfNeeded(rightAI) + " " + addQuotesIfNeeded(results) + " " + addQuotesIfNeeded(logsDir) + " " + (withVisualization ? " v" : "");
std::string runnerCommand = runnername + " " + addQuotesIfNeeded(logsDir);
std::cout <<"Server command: " << serverCommand << std::endl << "Runner command: " << runnerCommand << std::endl;
int code = 0;
boost::thread t([&]
{
code = std::system(serverCommand.c_str());
});
runCommand(runnerCommand, "first_runner", logsDir);
runCommand(runnerCommand, "second_runner", logsDir);
runCommand(runnerCommand, "third_runner", logsDir);
if(withVisualization)
{
//boost::this_thread::sleep(boost::posix_time::millisec(500)); //FIXME
boost::thread tttt(boost::bind(std::system, "VCMI_Client.exe -battle"));
}
//boost::this_thread::sleep(boost::posix_time::seconds(5));
t.join();
return code / 1000000.0;
}
void Framework::runCommand(const std::string &command, const std::string &name, const std::string &logsDir /*= ""*/)
{
static std::string commands[100000];
static int i = 0;
std::string &cmd = commands[i++];
if(logsDir.size() && name.size())
{
std::string directionLogs = logsDir + "/" + name + ".txt";
cmd = command + " > " + addQuotesIfNeeded(directionLogs);
}
else
cmd = command;
boost::thread tt(boost::bind(std::system, cmd.c_str()));
}
DuelParameters Framework::generateDuel(const ArmyDescriptor &ad)
{
DuelParameters dp;
dp.bfieldType = 1;
dp.terType = 1;
auto &side = dp.sides[0];
side.heroId = 0;
side.heroPrimSkills.resize(4,0);
BOOST_FOREACH(auto &stack, ad)
{
side.stacks[stack.first] = DuelParameters::SideSettings::StackSettings(stack.second.type->idNumber, stack.second.count);
}
dp.sides[1] = side;
dp.sides[1].heroId = 1;
return dp;
}
CArtifactInstance * Framework::generateArtWithBonus(const Bonus &b)
{
std::vector<CArtifactInstance*> ret;
static CArtifact *nowy = NULL;
if(!nowy)
{
nowy = new CArtifact();
nowy->description = "Cudowny miecz Towa gwarantuje zwyciestwo";
nowy->name = "Cudowny miecz";
nowy->constituentOf = nowy->constituents = NULL;
nowy->possibleSlots.push_back(Arts::LEFT_HAND);
}
CArtifactInstance *artinst = new CArtifactInstance(nowy);
artinst->addNewBonus(new Bonus(b));
return artinst;
}
class SSN
{
FANN::neural_net net;
void init();
static int ANNCallback(FANN::neural_net &net, FANN::training_data &train, unsigned int max_epochs, unsigned int epochs_between_reports, float desired_error, unsigned int epochs, void *user_data);
static double * genSSNinput(const DuelParameters::SideSettings & dp, CArtifactInstance * art, si32 bfieldType, si32 terType);
static const unsigned int num_input = 18;
public:
SSN();
~SSN();
void learn(const std::vector<Example> & input);
double run(const DuelParameters &dp, CArtifactInstance * inst);
void save(const std::string &filename);
};
SSN::SSN()
{
init();
}
void SSN::init()
{
const float learning_rate = 0.7f;
const unsigned int num_layers = 3;
const unsigned int num_hidden = 30;
const unsigned int num_output = 1;
const float desired_error = 0.001f;
const unsigned int max_iterations = 300000;
const unsigned int iterations_between_reports = 1000;
net.create_standard(num_layers, num_input, num_hidden, num_output);
net.set_learning_rate(learning_rate);
net.set_activation_steepness_hidden(0.9);
net.set_activation_steepness_output(1.0);
net.set_activation_function_hidden(FANN::SIGMOID_SYMMETRIC_STEPWISE);
net.set_activation_function_output(FANN::SIGMOID_SYMMETRIC_STEPWISE);
net.randomize_weights(0.0, 1.0);
}
double SSN::run(const DuelParameters &dp, CArtifactInstance * inst)
{
double * input = genSSNinput(dp.sides[0], inst, dp.bfieldType, dp.terType);
double * out = net.run(input);
double ret = *out;
free(out);
return ret;
}
int SSN::ANNCallback(FANN::neural_net &net, FANN::training_data &train, unsigned int max_epochs, unsigned int epochs_between_reports, float desired_error, unsigned int epochs, void *user_data)
{
//cout << "Epochs " << setw(8) << epochs << ". "
// << "Current Error: " << left << net.get_MSE() << right << endl;
return 0;
}
void SSN::learn(const std::vector<Example> & input)
{
//FIXME - sypie przy destrukcji
//FANN::training_data td;
FANN::training_data &td = *new FANN::training_data;
double ** inputs = new double *[input.size()];
double ** outputs = new double *[input.size()];
for(int i=0; i<input.size(); ++i)
{
const auto & ci = input[i];
inputs[i] = genSSNinput(ci.dp.sides[0], ci.art, ci.dp.bfieldType, ci.dp.terType);
outputs[i] = new double;
*(outputs[i]) = ci.value;
}
td.set_train_data(input.size(), num_input, inputs, 1, outputs);
net.set_callback(ANNCallback, NULL);
net.train_on_data(td, 1000, 1000, 0.01);
}
double * SSN::genSSNinput(const DuelParameters::SideSettings & dp, CArtifactInstance * art, si32 bfieldType, si32 terType)
{
double * ret = new double[num_input];
double * cur = ret;
//general description
*(cur++) = bfieldType/30.0;
*(cur++) = terType/12.0;
//creature & hero description
*(cur++) = dp.heroId/200.0;
for(int k=0; k<4; ++k)
*(cur++) = dp.heroPrimSkills[k]/20.0;
//weighted average of statistics
auto avg = [&](std::function<int(CCreature *)> getter) -> double
{
double ret = 0.0;
int div = 0;
for(int i=0; i<7; ++i)
{
auto & cstack = dp.stacks[i];
if(cstack.count > 0)
{
ret += getter(VLC->creh->creatures[cstack.type]) * cstack.count;
div+=cstack.count;
}
}
return ret/div;
};
*(cur++) = avg([](CCreature * c){return c->attack;})/50.0;
*(cur++) = avg([](CCreature * c){return c->defence;})/50.0;
*(cur++) = avg([](CCreature * c){return c->speed;})/15.0;
*(cur++) = avg([](CCreature * c){return c->hitPoints;})/1000.0;
//bonus description
auto & blist = art->getBonusList();
*(cur++) = blist[0]->type/100.0;
*(cur++) = blist[0]->subtype/10.0;
*(cur++) = blist[0]->val/100.0;;
*(cur++) = art->Attack()/10.0;
*(cur++) = art->Defense()/10.0;
*(cur++) = blist.valOfBonuses(Selector::type(Bonus::STACKS_SPEED))/5.0;
*(cur++) = blist.valOfBonuses(Selector::type(Bonus::STACK_HEALTH))/10.0;
return ret;
}
void SSN::save(const std::string &filename)
{
net.save(filename);
}
SSN::~SSN()
{
}
void SSNRun()
{
//buildLearningSet();
auto examples = Framework::loadExamples(false);
SSN network;
network.learn(examples);
network.save("network_config_file.net");
}
int main(int argc, char **argv)
{
@ -736,9 +678,6 @@ int main(int argc, char **argv)
VLC = new LibClasses();
VLC->init();
//buildLearningSet();
SSNRun();
return EXIT_SUCCESS;