mirror of
https://github.com/vcmi/vcmi.git
synced 2026-05-22 09:55:17 +02:00
TBB for battle AI spellcast an fixes
This commit is contained in:
+61
-86
@@ -13,6 +13,7 @@
|
||||
|
||||
#include "StackWithBonuses.h"
|
||||
#include "EnemyInfo.h"
|
||||
#include "tbb/parallel_for.h"
|
||||
#include "../../lib/CStopWatch.h"
|
||||
#include "../../lib/CThreadHelper.h"
|
||||
#include "../../lib/mapObjects/CGTownInstance.h"
|
||||
@@ -704,96 +705,70 @@ void BattleEvaluator::attemptCastingSpell(const CStack * activeStack)
|
||||
}
|
||||
}
|
||||
|
||||
struct ScriptsCache
|
||||
{
|
||||
//todo: re-implement scripts context cache
|
||||
};
|
||||
|
||||
auto evaluateSpellcast = [&] (PossibleSpellcast * ps, std::shared_ptr<ScriptsCache>)
|
||||
{
|
||||
auto state = std::make_shared<HypotheticBattle>(env.get(), cb);
|
||||
|
||||
spells::BattleCast cast(state.get(), hero, spells::Mode::HERO, ps->spell);
|
||||
cast.castEval(state->getServerCallback(), ps->dest);
|
||||
|
||||
auto allUnits = state->battleGetUnitsIf([](const battle::Unit * u) -> bool{ return true; });
|
||||
|
||||
auto needFullEval = vstd::contains_if(allUnits, [&](const battle::Unit * u) -> bool
|
||||
{
|
||||
auto original = cb->battleGetUnitByID(u->unitId());
|
||||
return !original || u->speed() != original->speed();
|
||||
});
|
||||
|
||||
DamageCache innerCache(&damageCache);
|
||||
innerCache.buildDamageCache(state, side);
|
||||
|
||||
if(needFullEval || !cachedAttack)
|
||||
{
|
||||
PotentialTargets innerTargets(activeStack, damageCache, state);
|
||||
BattleExchangeEvaluator innerEvaluator(state, env);
|
||||
|
||||
if(!innerTargets.possibleAttacks.empty())
|
||||
{
|
||||
innerEvaluator.updateReachabilityMap(state);
|
||||
|
||||
auto newStackAction = innerEvaluator.findBestTarget(activeStack, innerTargets, innerCache, state);
|
||||
|
||||
ps->value = newStackAction.score;
|
||||
}
|
||||
else
|
||||
{
|
||||
ps->value = 0;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
ps->value = scoreEvaluator.calculateExchange(*cachedAttack, *targets, innerCache, state);
|
||||
}
|
||||
|
||||
for(auto unit : allUnits)
|
||||
{
|
||||
auto newHealth = unit->getAvailableHealth();
|
||||
auto oldHealth = healthOfStack[unit->unitId()];
|
||||
|
||||
if(oldHealth != newHealth)
|
||||
{
|
||||
auto damage = std::abs(oldHealth - newHealth);
|
||||
auto originalDefender = cb->battleGetUnitByID(unit->unitId());
|
||||
auto dpsReduce = AttackPossibility::calculateDamageReduce(nullptr, originalDefender ? originalDefender : unit, damage, innerCache, state);
|
||||
auto ourUnit = unit->unitSide() == side ? 1 : -1;
|
||||
auto goodEffect = newHealth > oldHealth ? 1 : -1;
|
||||
|
||||
ps->value += ourUnit * goodEffect * dpsReduce;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
using EvalRunner = ThreadPool<ScriptsCache>;
|
||||
|
||||
EvalRunner::Tasks tasks;
|
||||
|
||||
for(PossibleSpellcast & psc : possibleCasts)
|
||||
tasks.push_back(std::bind(evaluateSpellcast, &psc, _1));
|
||||
|
||||
uint32_t threadCount = boost::thread::hardware_concurrency();
|
||||
|
||||
if(threadCount == 0)
|
||||
{
|
||||
logGlobal->warn("No information of CPU cores available");
|
||||
threadCount = 1;
|
||||
}
|
||||
|
||||
CStopWatch timer;
|
||||
|
||||
std::vector<std::shared_ptr<ScriptsCache>> scriptsPool;
|
||||
tbb::parallel_for(tbb::blocked_range<size_t>(0, possibleCasts.size()), [&](const tbb::blocked_range<size_t> & r)
|
||||
{
|
||||
for(auto i = r.begin(); i != r.end(); i++)
|
||||
{
|
||||
auto & ps = possibleCasts[i];
|
||||
auto state = std::make_shared<HypotheticBattle>(env.get(), cb);
|
||||
|
||||
for(uint32_t idx = 0; idx < threadCount; idx++)
|
||||
{
|
||||
scriptsPool.emplace_back();
|
||||
}
|
||||
spells::BattleCast cast(state.get(), hero, spells::Mode::HERO, ps.spell);
|
||||
cast.castEval(state->getServerCallback(), ps.dest);
|
||||
|
||||
EvalRunner runner(&tasks, scriptsPool);
|
||||
runner.run();
|
||||
auto allUnits = state->battleGetUnitsIf([](const battle::Unit * u) -> bool { return true; });
|
||||
|
||||
auto needFullEval = vstd::contains_if(allUnits, [&](const battle::Unit * u) -> bool
|
||||
{
|
||||
auto original = cb->battleGetUnitByID(u->unitId());
|
||||
return !original || u->speed() != original->speed();
|
||||
});
|
||||
|
||||
DamageCache innerCache(&damageCache);
|
||||
innerCache.buildDamageCache(state, side);
|
||||
|
||||
if(needFullEval || !cachedAttack)
|
||||
{
|
||||
PotentialTargets innerTargets(activeStack, damageCache, state);
|
||||
BattleExchangeEvaluator innerEvaluator(state, env);
|
||||
|
||||
if(!innerTargets.possibleAttacks.empty())
|
||||
{
|
||||
innerEvaluator.updateReachabilityMap(state);
|
||||
|
||||
auto newStackAction = innerEvaluator.findBestTarget(activeStack, innerTargets, innerCache, state);
|
||||
|
||||
ps.value = newStackAction.score;
|
||||
}
|
||||
else
|
||||
{
|
||||
ps.value = 0;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
ps.value = scoreEvaluator.calculateExchange(*cachedAttack, *targets, innerCache, state);
|
||||
}
|
||||
|
||||
for(auto unit : allUnits)
|
||||
{
|
||||
auto newHealth = unit->getAvailableHealth();
|
||||
auto oldHealth = healthOfStack[unit->unitId()];
|
||||
|
||||
if(oldHealth != newHealth)
|
||||
{
|
||||
auto damage = std::abs(oldHealth - newHealth);
|
||||
auto originalDefender = cb->battleGetUnitByID(unit->unitId());
|
||||
auto dpsReduce = AttackPossibility::calculateDamageReduce(nullptr, originalDefender ? originalDefender : unit, damage, innerCache, state);
|
||||
auto ourUnit = unit->unitSide() == side ? 1 : -1;
|
||||
auto goodEffect = newHealth > oldHealth ? 1 : -1;
|
||||
|
||||
ps.value += ourUnit * goodEffect * dpsReduce;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
LOGFL("Evaluation took %d ms", timer.getDiff());
|
||||
|
||||
|
||||
Reference in New Issue
Block a user