1
0
mirror of https://github.com/vcmi/vcmi.git synced 2025-03-31 22:05:10 +02:00

[programming challenge, SSN] enemy's army is no longer part of ANN's input

This commit is contained in:
mateuszb 2012-05-28 20:31:32 +00:00
parent 3d16f0a081
commit 4a7891139e
2 changed files with 32 additions and 32 deletions

View File

@ -91,6 +91,7 @@
<Optimization>Disabled</Optimization> <Optimization>Disabled</Optimization>
<PrecompiledHeader>Use</PrecompiledHeader> <PrecompiledHeader>Use</PrecompiledHeader>
<PrecompiledHeaderFile>StdInc.h</PrecompiledHeaderFile> <PrecompiledHeaderFile>StdInc.h</PrecompiledHeaderFile>
<AdditionalOptions>/Zm150 %(AdditionalOptions)</AdditionalOptions>
</ClCompile> </ClCompile>
<Link> <Link>
<GenerateDebugInformation>true</GenerateDebugInformation> <GenerateDebugInformation>true</GenerateDebugInformation>

View File

@ -364,49 +364,47 @@ double rateArt(const DuelParameters dp, CArtifactInstance * inst)
} }
const unsigned int num_input = 27; const unsigned int num_input = 18;
double * genSSNinput(const DuelParameters & dp, CArtifactInstance * art) double * genSSNinput(const DuelParameters::SideSettings & dp, CArtifactInstance * art, si32 bfieldType, si32 terType)
{ {
double * ret = new double[num_input]; double * ret = new double[num_input];
double * cur = ret; double * cur = ret;
//general description //general description
*(cur++) = dp.bfieldType/30.0; *(cur++) = bfieldType/30.0;
*(cur++) = dp.terType/12.0; *(cur++) = terType/12.0;
//creature & hero description //creature & hero description
for(int i=0; i<2; ++i)
*(cur++) = dp.heroId/200.0;
for(int k=0; k<4; ++k)
*(cur++) = dp.heroPrimSkills[k]/20.0;
//weighted average of statistics
auto avg = [&](std::function<int(CCreature *)> getter) -> double
{ {
auto & side = dp.sides[0]; double ret = 0.0;
*(cur++) = side.heroId/200.0; int div = 0;
for(int k=0; k<4; ++k) for(int i=0; i<7; ++i)
*(cur++) = side.heroPrimSkills[k]/20.0;
//weighted average of statistics
auto avg = [&](std::function<int(CCreature *)> getter) -> double
{ {
double ret = 0.0; auto & cstack = dp.stacks[i];
int div = 0; if(cstack.count > 0)
for(int i=0; i<7; ++i)
{ {
auto & cstack = side.stacks[i]; ret += getter(VLC->creh->creatures[cstack.type]) * cstack.count;
if(cstack.count > 0) div+=cstack.count;
{
ret += getter(VLC->creh->creatures[cstack.type]) * cstack.count;
div+=cstack.count;
}
} }
return ret/div; }
}; return ret/div;
};
*(cur++) = avg([](CCreature * c){return c->attack;})/50.0;
*(cur++) = avg([](CCreature * c){return c->defence;})/50.0; *(cur++) = avg([](CCreature * c){return c->attack;})/50.0;
*(cur++) = avg([](CCreature * c){return c->speed;})/15.0; *(cur++) = avg([](CCreature * c){return c->defence;})/50.0;
*(cur++) = avg([](CCreature * c){return c->hitPoints;})/1000.0; *(cur++) = avg([](CCreature * c){return c->speed;})/15.0;
} *(cur++) = avg([](CCreature * c){return c->hitPoints;})/1000.0;
//bonus description //bonus description
auto & blist = art->getBonusList(); auto & blist = art->getBonusList();
@ -425,7 +423,7 @@ double * genSSNinput(const DuelParameters & dp, CArtifactInstance * art)
//returns how good the artifact is for the neural network //returns how good the artifact is for the neural network
double runSSN(FANN::neural_net & net, const DuelParameters dp, CArtifactInstance * inst) double runSSN(FANN::neural_net & net, const DuelParameters dp, CArtifactInstance * inst)
{ {
double * input = genSSNinput(dp, inst); double * input = genSSNinput(dp.sides[0], inst, dp.bfieldType, dp.terType);
double * out = net.run(input); double * out = net.run(input);
double ret = *out; double ret = *out;
free(out); free(out);
@ -450,9 +448,10 @@ void learnSSN(FANN::neural_net & net, const std::vector<Example> & input)
double ** outputs = new double *[input.size()]; double ** outputs = new double *[input.size()];
for(int i=0; i<input.size(); ++i) for(int i=0; i<input.size(); ++i)
{ {
inputs[i] = genSSNinput(input[i].dp, input[i].art); const auto & ci = input[i];
inputs[i] = genSSNinput(ci.dp.sides[0], ci.art, ci.dp.bfieldType, ci.dp.terType);
outputs[i] = new double; outputs[i] = new double;
*(outputs[i]) = input[i].value; *(outputs[i]) = ci.value;
} }
td.set_train_data(input.size(), num_input, inputs, 1, outputs); td.set_train_data(input.size(), num_input, inputs, 1, outputs);
net.set_callback(ANNCallback, NULL); net.set_callback(ANNCallback, NULL);
@ -738,7 +737,7 @@ int main(int argc, char **argv)
VLC->init(); VLC->init();
buildLearningSet(); //buildLearningSet();
SSNRun(); SSNRun();