1
0
mirror of https://github.com/vcmi/vcmi.git synced 2025-03-21 21:17:49 +02:00

[programming challenge, SSN] neural network's code seems to be more or less complete; not tested though

This commit is contained in:
mateuszb 2012-05-21 20:11:02 +00:00
parent 2f7ba07050
commit cc9823bd73

View File

@ -91,7 +91,7 @@ double cmpArtSets(DuelParameters dp, TArtSet setL, TArtSet setR)
return battleOutcome; return battleOutcome;
} }
std::vector<CArtifactInstance*> genArts() std::vector<CArtifactInstance*> genArts(const std::vector<Bonus> & bonusesToGive)
{ {
std::vector<CArtifactInstance*> ret; std::vector<CArtifactInstance*> ret;
@ -101,31 +101,41 @@ std::vector<CArtifactInstance*> genArts()
nowy->constituentOf = nowy->constituents = NULL; nowy->constituentOf = nowy->constituents = NULL;
nowy->possibleSlots.push_back(Arts::LEFT_HAND); nowy->possibleSlots.push_back(Arts::LEFT_HAND);
CArtifactInstance *artinst = new CArtifactInstance(nowy);
auto &arts = VLC->arth->artifacts;
CArtifactInstance *inny = new CArtifactInstance(VLC->arth->artifacts[15]);
artinst->addNewBonus(new Bonus(Bonus::PERMANENT, Bonus::PRIMARY_SKILL, Bonus::ARTIFACT_INSTANCE, +25, nowy->id, PrimarySkill::ATTACK)); BOOST_FOREACH(auto b, bonusesToGive)
artinst->addNewBonus(new Bonus(Bonus::PERMANENT, Bonus::PRIMARY_SKILL, Bonus::ARTIFACT_INSTANCE, +25, nowy->id, PrimarySkill::DEFENSE));
auto bonuses = artinst->getBonuses([](const Bonus *){ return true; });
BOOST_FOREACH(Bonus *b, *bonuses)
{ {
std::cout << format("%s (%d) value:%d, description: %s\n") % bonusTypeToString(b->type) % b->subtype % b->val % b->Description(); CArtifactInstance *artinst = new CArtifactInstance(nowy);
auto &arts = VLC->arth->artifacts;
artinst->addNewBonus(new Bonus(b));
ret.push_back(artinst);
} }
// auto bonuses = artinst->getBonuses([](const Bonus *){ return true; });
// BOOST_FOREACH(Bonus *b, *bonuses)
// {
// std::cout << format("%s (%d) value:%d, description: %s\n") % bonusTypeToString(b->type) % b->subtype % b->val % b->Description();
// }
return ret; return ret;
} }
//returns how good the artifact is for the neural network //returns how good the artifact is for the neural network
double runSSN(FANN::neural_net & net, CArtifactInstance * inst) double runSSN(FANN::neural_net & net, const DuelParameters dp, CArtifactInstance * inst)
{ {
TArtSet setL, setR;
setL[inst->artType->possibleSlots[0]] = inst;
return 0.0; double resultLR = cmpArtSets(dp, setL, setR),
resultRL = cmpArtSets(dp, setR, setL),
resultsBase = cmpArtSets(dp, TArtSet(), TArtSet());
double LRgain = resultLR - resultsBase,
RLgain = resultRL - resultsBase;
return LRgain+RLgain;
} }
const unsigned int num_input = 2; const unsigned int num_input = 16;
double * genSSNinput(const DuelParameters & dp, CArtifactInstance * art) double * genSSNinput(const DuelParameters & dp, CArtifactInstance * art)
{ {
@ -170,15 +180,31 @@ double * genSSNinput(const DuelParameters & dp, CArtifactInstance * art)
} }
//bonus description //bonus description
auto & blist = art->getBonusList();
*(cur++) = art->Attack();
*(cur++) = art->Defense();
*(cur++) = blist.valOfBonuses(Selector::type(Bonus::STACKS_SPEED));
*(cur++) = blist.valOfBonuses(Selector::type(Bonus::STACK_HEALTH));
return ret; return ret;
} }
void learnSSN(FANN::neural_net & net, const DuelParameters & dp, CArtifactInstance * art, double desiredVal) void learnSSN(FANN::neural_net & net, const std::vector<std::pair<DuelParameters, CArtifactInstance *> > & input)
{ {
double * input = genSSNinput(dp, art); FANN::training_data td;
net.train(input, &desiredVal);
delete input; double ** inputs = new double *[input.size()];
double ** outputs = new double *[input.size()];
for(int i=0; i<input.size(); ++i)
{
inputs[i] = genSSNinput(input[i].first, input[i].second);
outputs[i] = new double;
*(outputs[i]) = runSSN(net, input[i].first, input[i].second);
}
td.set_train_data(input.size(), num_input, inputs, 1, outputs);
net.train_epoch(td);
} }
void initNet(FANN::neural_net & ret) void initNet(FANN::neural_net & ret)
@ -206,7 +232,6 @@ void initNet(FANN::neural_net & ret)
void SSNRun() void SSNRun()
{ {
auto availableArts = genArts();
std::vector<std::pair<CArtifactInstance *, double> > artNotes; std::vector<std::pair<CArtifactInstance *, double> > artNotes;
TArtSet setL, setR; TArtSet setL, setR;
@ -214,27 +239,27 @@ void SSNRun()
FANN::neural_net network; FANN::neural_net network;
initNet(network); initNet(network);
for(int i=0; i<availableArts.size(); ++i) // for(int i=0; i<availableArts.size(); ++i)
{ // {
artNotes.push_back(std::make_pair(availableArts[i], runSSN(network, availableArts[i]))); // artNotes.push_back(std::make_pair(availableArts[i], runSSN(network, availableArts[i])));
} // }
boost::range::sort(artNotes, // boost::range::sort(artNotes,
[](const std::pair<CArtifactInstance *, double> & a1, const std::pair<CArtifactInstance *, double> & a2) // [](const std::pair<CArtifactInstance *, double> & a1, const std::pair<CArtifactInstance *, double> & a2)
{return a1.second > a2.second;}); // {return a1.second > a2.second;});
//
//pick best arts into setL // //pick best arts into setL
BOOST_FOREACH(auto & ap, artNotes) // BOOST_FOREACH(auto & ap, artNotes)
{ // {
auto art = ap.first; // auto art = ap.first;
BOOST_FOREACH(auto slot, art->artType->possibleSlots) // BOOST_FOREACH(auto slot, art->artType->possibleSlots)
{ // {
if(setL.find(slot) != setL.end()) // if(setL.find(slot) != setL.end())
{ // {
setL[slot] = art; // setL[slot] = art;
break; // break;
} // }
} // }
} // }
//duels to test on //duels to test on
@ -274,17 +299,18 @@ void SSNRun()
} }
auto arts = genArts(btt);
//evaluate //evaluate
std::vector<std::pair<DuelParameters, CArtifactInstance *> > setups;
for(int i=0; i<dps.size(); ++i) for(int i=0; i<dps.size(); ++i)
{ {
auto & dp = dps[i]; for(int j=0; j<arts.size(); ++j)
double resultLR = cmpArtSets(dp, setL, setR), {
resultRL = cmpArtSets(dp, setR, setL), setups.push_back(std::make_pair(dps[i], arts[j]));
resultsBase = cmpArtSets(dp, TArtSet(), TArtSet()); }
double LRgain = resultLR - resultsBase,
RLgain = resultRL - resultsBase;
} }
learnSSN(network, setups);
} }
int main(int argc, char **argv) int main(int argc, char **argv)