1
0
mirror of https://github.com/vcmi/vcmi.git synced 2025-01-12 02:28:11 +02:00

BattleAI: spellcast fixes and floating point score

This commit is contained in:
Andrii Danylchenko 2023-08-26 13:06:41 +03:00
parent dc88f14e0b
commit 5f13a0bbda
10 changed files with 241 additions and 119 deletions

View File

@ -22,7 +22,7 @@ void DamageCache::cacheDamage(const battle::Unit * attacker, const battle::Unit
{ {
auto damage = averageDmg(hb->battleEstimateDamage(attacker, defender, 0).damage); auto damage = averageDmg(hb->battleEstimateDamage(attacker, defender, 0).damage);
damageCache[attacker->unitId()][defender->unitId()] = damage / attacker->getCount(); damageCache[attacker->unitId()][defender->unitId()] = static_cast<float>(damage) / attacker->getCount();
} }
@ -98,18 +98,18 @@ AttackPossibility::AttackPossibility(BattleHex from, BattleHex dest, const Battl
{ {
} }
int64_t AttackPossibility::damageDiff() const float AttackPossibility::damageDiff() const
{ {
return defenderDamageReduce - attackerDamageReduce - collateralDamageReduce + shootersBlockedDmg; return defenderDamageReduce - attackerDamageReduce - collateralDamageReduce + shootersBlockedDmg;
} }
int64_t AttackPossibility::damageDiff(float positiveEffectMultiplier, float negativeEffectMultiplier) const float AttackPossibility::damageDiff(float positiveEffectMultiplier, float negativeEffectMultiplier) const
{ {
return positiveEffectMultiplier * (defenderDamageReduce + shootersBlockedDmg) return positiveEffectMultiplier * (defenderDamageReduce + shootersBlockedDmg)
- negativeEffectMultiplier * (attackerDamageReduce + collateralDamageReduce); - negativeEffectMultiplier * (attackerDamageReduce + collateralDamageReduce);
} }
int64_t AttackPossibility::attackValue() const float AttackPossibility::attackValue() const
{ {
return damageDiff(); return damageDiff();
} }
@ -119,7 +119,7 @@ int64_t AttackPossibility::attackValue() const
/// Half bounty for kill, half for making damage equal to enemy health /// Half bounty for kill, half for making damage equal to enemy health
/// Bounty - the killed creature average damage calculated against attacker /// Bounty - the killed creature average damage calculated against attacker
/// </summary> /// </summary>
int64_t AttackPossibility::calculateDamageReduce( float AttackPossibility::calculateDamageReduce(
const battle::Unit * attacker, const battle::Unit * attacker,
const battle::Unit * defender, const battle::Unit * defender,
uint64_t damageDealt, uint64_t damageDealt,
@ -163,7 +163,7 @@ int64_t AttackPossibility::calculateDamageReduce(
auto firstUnitHealthRatio = firstUnitHpLeft == 0 ? 1 : static_cast<float>(firstUnitHpLeft) / maxHealth; auto firstUnitHealthRatio = firstUnitHpLeft == 0 ? 1 : static_cast<float>(firstUnitHpLeft) / maxHealth;
auto firstUnitKillValue = (1 - firstUnitHealthRatio) * (1 - firstUnitHealthRatio); auto firstUnitKillValue = (1 - firstUnitHealthRatio) * (1 - firstUnitHealthRatio);
return (int64_t)(damagePerEnemy * (enemiesKilled + firstUnitKillValue * HEALTH_BOUNTY)); return damagePerEnemy * (enemiesKilled + firstUnitKillValue * HEALTH_BOUNTY);
} }
int64_t AttackPossibility::evaluateBlockedShootersDmg( int64_t AttackPossibility::evaluateBlockedShootersDmg(
@ -270,7 +270,8 @@ AttackPossibility AttackPossibility::evaluate(
for(int i = 0; i < totalAttacks; i++) for(int i = 0; i < totalAttacks; i++)
{ {
int64_t damageDealt, damageReceived, defenderDamageReduce, attackerDamageReduce; int64_t damageDealt, damageReceived;
float defenderDamageReduce, attackerDamageReduce;
DamageEstimation retaliation; DamageEstimation retaliation;
auto attackDmg = state->battleEstimateDamage(ap.attack, &retaliation); auto attackDmg = state->battleEstimateDamage(ap.attack, &retaliation);

View File

@ -46,16 +46,16 @@ public:
std::vector<std::shared_ptr<battle::CUnitState>> affectedUnits; std::vector<std::shared_ptr<battle::CUnitState>> affectedUnits;
int64_t defenderDamageReduce = 0; float defenderDamageReduce = 0;
int64_t attackerDamageReduce = 0; //usually by counter-attack float attackerDamageReduce = 0; //usually by counter-attack
int64_t collateralDamageReduce = 0; // friendly fire (usually by two-hex attacks) float collateralDamageReduce = 0; // friendly fire (usually by two-hex attacks)
int64_t shootersBlockedDmg = 0; int64_t shootersBlockedDmg = 0;
AttackPossibility(BattleHex from, BattleHex dest, const BattleAttackInfo & attack_); AttackPossibility(BattleHex from, BattleHex dest, const BattleAttackInfo & attack_);
int64_t damageDiff() const; float damageDiff() const;
int64_t attackValue() const; float attackValue() const;
int64_t damageDiff(float positiveEffectMultiplier, float negativeEffectMultiplier) const; float damageDiff(float positiveEffectMultiplier, float negativeEffectMultiplier) const;
static AttackPossibility evaluate( static AttackPossibility evaluate(
const BattleAttackInfo & attackInfo, const BattleAttackInfo & attackInfo,
@ -63,7 +63,7 @@ public:
DamageCache & damageCache, DamageCache & damageCache,
std::shared_ptr<CBattleInfoCallback> state); std::shared_ptr<CBattleInfoCallback> state);
static int64_t calculateDamageReduce( static float calculateDamageReduce(
const battle::Unit * attacker, const battle::Unit * attacker,
const battle::Unit * defender, const battle::Unit * defender,
uint64_t damageDealt, uint64_t damageDealt,

View File

@ -81,7 +81,28 @@ void CBattleAI::yourTacticPhase(int distance)
cb->battleMakeTacticAction(BattleAction::makeEndOFTacticPhase(cb->battleGetTacticsSide())); cb->battleMakeTacticAction(BattleAction::makeEndOFTacticPhase(cb->battleGetTacticsSide()));
} }
void CBattleAI::activeStack( const CStack * stack ) float getStrengthRatio(std::shared_ptr<CBattleCallback> cb, int side)
{
auto stacks = cb->battleGetAllStacks();
auto our = 0, enemy = 0;
for(auto stack : stacks)
{
auto creature = stack->creatureId().toCreature();
if(!creature)
continue;
if(stack->unitSide() == side)
our += stack->getCount() * creature->getAIValue();
else
enemy += stack->getCount() * creature->getAIValue();
}
return enemy == 0 ? 1.0f : static_cast<float>(our) / enemy;
}
void CBattleAI::activeStack(const CStack * stack )
{ {
LOG_TRACE_PARAMS(logAi, "stack: %s", stack->nodeName()); LOG_TRACE_PARAMS(logAi, "stack: %s", stack->nodeName());
@ -110,7 +131,11 @@ void CBattleAI::activeStack( const CStack * stack )
return; return;
} }
BattleEvaluator evaluator(env, cb, stack, playerID, side, strengthRatio); #if BATTLE_TRACE_LEVEL>=1
logAi->trace("Build evaluator and targets");
#endif
BattleEvaluator evaluator(env, cb, stack, playerID, side, getStrengthRatio(cb, side));
result = evaluator.selectStackAction(stack); result = evaluator.selectStackAction(stack);
@ -207,10 +232,6 @@ void CBattleAI::battleStart(const CCreatureSet *army1, const CCreatureSet *army2
{ {
LOG_TRACE(logAi); LOG_TRACE(logAi);
side = Side; side = Side;
strengthRatio = static_cast<float>(army1->getArmyStrength()) / static_cast<float>(army2->getArmyStrength());
if(side == 1)
strengthRatio = 1 / strengthRatio;
skipCastUntilNextBattle = false; skipCastUntilNextBattle = false;
} }

View File

@ -62,7 +62,6 @@ class CBattleAI : public CBattleGameInterface
bool wasWaitingForRealize; bool wasWaitingForRealize;
bool wasUnlockingGs; bool wasUnlockingGs;
int movesSkippedByDefense; int movesSkippedByDefense;
float strengthRatio;
bool skipCastUntilNextBattle; bool skipCastUntilNextBattle;
public: public:

View File

@ -100,11 +100,14 @@ std::optional<PossibleSpellcast> BattleEvaluator::findBestCreatureSpell(const CS
BattleAction BattleEvaluator::selectStackAction(const CStack * stack) BattleAction BattleEvaluator::selectStackAction(const CStack * stack)
{ {
#if BATTLE_TRACE_LEVEL >= 1
logAi->trace("Select stack action");
#endif
//evaluate casting spell for spellcasting stack //evaluate casting spell for spellcasting stack
std::optional<PossibleSpellcast> bestSpellcast = findBestCreatureSpell(stack); std::optional<PossibleSpellcast> bestSpellcast = findBestCreatureSpell(stack);
auto moveTarget = scoreEvaluator.findMoveTowardsUnreachable(stack, *targets, damageCache, hb); auto moveTarget = scoreEvaluator.findMoveTowardsUnreachable(stack, *targets, damageCache, hb);
auto score = EvaluationResult::INEFFECTIVE_SCORE; float score = EvaluationResult::INEFFECTIVE_SCORE;
if(targets->possibleAttacks.empty() && bestSpellcast.has_value()) if(targets->possibleAttacks.empty() && bestSpellcast.has_value())
{ {
@ -136,7 +139,7 @@ BattleAction BattleEvaluator::selectStackAction(const CStack * stack)
{ {
score = evaluationResult.score; score = evaluationResult.score;
logAi->debug("BattleAI: %s -> %s x %d, from %d curpos %d dist %d speed %d: +%lld -%lld = %lld", logAi->debug("BattleAI: %s -> %s x %d, from %d curpos %d dist %d speed %d: +%2f -%2f = %2f",
bestAttack.attackerState->unitType()->getJsonKey(), bestAttack.attackerState->unitType()->getJsonKey(),
bestAttack.affectedUnits[0]->unitType()->getJsonKey(), bestAttack.affectedUnits[0]->unitType()->getJsonKey(),
(int)bestAttack.affectedUnits[0]->getCount(), (int)bestAttack.affectedUnits[0]->getCount(),
@ -145,7 +148,8 @@ BattleAction BattleEvaluator::selectStackAction(const CStack * stack)
bestAttack.attack.chargeDistance, bestAttack.attack.chargeDistance,
bestAttack.attack.attacker->speed(0, true), bestAttack.attack.attacker->speed(0, true),
bestAttack.defenderDamageReduce, bestAttack.defenderDamageReduce,
bestAttack.attackerDamageReduce, bestAttack.attackValue() bestAttack.attackerDamageReduce,
bestAttack.attackValue()
); );
if (moveTarget.scorePerTurn <= score) if (moveTarget.scorePerTurn <= score)
@ -513,11 +517,20 @@ bool BattleEvaluator::attemptCastingSpell(const CStack * activeStack)
CStopWatch timer; CStopWatch timer;
#if BATTLE_TRACE_LEVEL >= 1
tbb::blocked_range<size_t> r(0, possibleCasts.size());
#else
tbb::parallel_for(tbb::blocked_range<size_t>(0, possibleCasts.size()), [&](const tbb::blocked_range<size_t> & r) tbb::parallel_for(tbb::blocked_range<size_t>(0, possibleCasts.size()), [&](const tbb::blocked_range<size_t> & r)
{ {
#endif
for(auto i = r.begin(); i != r.end(); i++) for(auto i = r.begin(); i != r.end(); i++)
{ {
auto & ps = possibleCasts[i]; auto & ps = possibleCasts[i];
#if BATTLE_TRACE_LEVEL >= 1
logAi->trace("Evaluating %s", ps.spell->getNameTranslated());
#endif
auto state = std::make_shared<HypotheticBattle>(env.get(), cb); auto state = std::make_shared<HypotheticBattle>(env.get(), cb);
spells::BattleCast cast(state.get(), hero, spells::Mode::HERO, ps.spell); spells::BattleCast cast(state.get(), hero, spells::Mode::HERO, ps.spell);
@ -531,12 +544,17 @@ bool BattleEvaluator::attemptCastingSpell(const CStack * activeStack)
return !original || u->speed() != original->speed(); return !original || u->speed() != original->speed();
}); });
DamageCache innerCache(&damageCache); DamageCache safeCopy = damageCache;
DamageCache innerCache(&safeCopy);
innerCache.buildDamageCache(state, side); innerCache.buildDamageCache(state, side);
if(needFullEval || !cachedAttack) if(needFullEval || !cachedAttack)
{ {
PotentialTargets innerTargets(activeStack, damageCache, state); #if BATTLE_TRACE_LEVEL >= 1
logAi->trace("Full evaluation is started due to stack speed affected.");
#endif
PotentialTargets innerTargets(activeStack, innerCache, state);
BattleExchangeEvaluator innerEvaluator(state, env, strengthRatio); BattleExchangeEvaluator innerEvaluator(state, env, strengthRatio);
if(!innerTargets.possibleAttacks.empty()) if(!innerTargets.possibleAttacks.empty())
@ -586,14 +604,27 @@ bool BattleEvaluator::attemptCastingSpell(const CStack * activeStack)
} }
else else
ps.value -= dpsReduce * scoreEvaluator.getNegativeEffectMultiplier(); ps.value -= dpsReduce * scoreEvaluator.getNegativeEffectMultiplier();
#if BATTLE_TRACE_LEVEL >= 1
logAi->trace(
"Spell affects %s (%d), dps: %2f",
unit->creatureId().toCreature()->getNameSingularTranslated(),
unit->getCount(),
dpsReduce);
#endif
} }
} }
#if BATTLE_TRACE_LEVEL >= 1
logAi->trace("Total score: %2f", ps.value);
#endif
} }
#if BATTLE_TRACE_LEVEL == 0
}); });
#endif
LOGFL("Evaluation took %d ms", timer.getDiff()); LOGFL("Evaluation took %d ms", timer.getDiff());
auto pscValue = [](const PossibleSpellcast &ps) -> int64_t auto pscValue = [](const PossibleSpellcast &ps) -> float
{ {
return ps.value; return ps.value;
}; };

View File

@ -33,7 +33,7 @@ class BattleEvaluator
std::optional<AttackPossibility> cachedAttack; std::optional<AttackPossibility> cachedAttack;
PlayerColor playerID; PlayerColor playerID;
int side; int side;
int64_t cachedScore; float cachedScore;
DamageCache damageCache; DamageCache damageCache;
float strengthRatio; float strengthRatio;

View File

@ -25,40 +25,107 @@ MoveTarget::MoveTarget()
turnsToRich = 1; turnsToRich = 1;
} }
int64_t BattleExchangeVariant::trackAttack(const AttackPossibility & ap, HypotheticBattle & state) float BattleExchangeVariant::trackAttack(
const AttackPossibility & ap,
std::shared_ptr<HypotheticBattle> hb,
DamageCache & damageCache)
{ {
auto attacker = hb->getForUpdate(ap.attack.attacker->unitId());
const std::string cachingStringBlocksRetaliation = "type_BLOCKS_RETALIATION";
static const auto selectorBlocksRetaliation = Selector::type()(BonusType::BLOCKS_RETALIATION);
const bool counterAttacksBlocked = attacker->hasBonus(selectorBlocksRetaliation, cachingStringBlocksRetaliation);
float attackValue = 0;
auto affectedUnits = ap.affectedUnits; auto affectedUnits = ap.affectedUnits;
affectedUnits.push_back(ap.attackerState); affectedUnits.push_back(ap.attackerState);
for(auto affectedUnit : affectedUnits) for(auto affectedUnit : affectedUnits)
{ {
auto unitToUpdate = state.getForUpdate(affectedUnit->unitId()); auto unitToUpdate = hb->getForUpdate(affectedUnit->unitId());
unitToUpdate->health = affectedUnit->health; if(unitToUpdate->unitSide() == attacker->unitSide())
unitToUpdate->shots = affectedUnit->shots; {
unitToUpdate->counterAttacks = affectedUnit->counterAttacks; if(unitToUpdate->unitId() == attacker->unitId())
unitToUpdate->movedThisRound = affectedUnit->movedThisRound; {
} auto defender = hb->getForUpdate(ap.attack.defender->unitId());
auto attackValue = ap.damageDiff(positiveEffectMultiplier, negativeEffectMultiplier); if(!defender->alive() || counterAttacksBlocked || ap.attack.shooting || !defender->ableToRetaliate())
continue;
dpsScore += attackValue; auto retaliationDamage = damageCache.getDamage(defender.get(), unitToUpdate.get(), hb);
auto attackerDamageReduce = AttackPossibility::calculateDamageReduce(defender.get(), unitToUpdate.get(), retaliationDamage, damageCache, hb);
attackValue -= attackerDamageReduce;
dpsScore -= attackerDamageReduce * negativeEffectMultiplier;
attackerValue[unitToUpdate->unitId()].isRetalitated = true;
unitToUpdate->damage(retaliationDamage);
defender->afterAttack(false, true);
#if BATTLE_TRACE_LEVEL>=1 #if BATTLE_TRACE_LEVEL>=1
logAi->trace( logAi->trace(
"%s -> %s, ap attack, %s, dps: %lld, score: %lld", "%s -> %s, ap retalitation, %s, dps: %2f, score: %2f",
ap.attack.attacker->getDescription(), defender->getDescription(),
ap.attack.defender->getDescription(), unitToUpdate->getDescription(),
ap.attack.shooting ? "shot" : "mellee", ap.attack.shooting ? "shot" : "mellee",
ap.damageDealt, retaliationDamage,
attackValue); attackerDamageReduce);
#endif #endif
}
else
{
auto collateralDamage = damageCache.getDamage(attacker.get(), unitToUpdate.get(), hb);
auto collateralDamageReduce = AttackPossibility::calculateDamageReduce(attacker.get(), unitToUpdate.get(), collateralDamage, damageCache, hb);
attackValue -= collateralDamageReduce;
dpsScore -= collateralDamageReduce * negativeEffectMultiplier;
unitToUpdate->damage(collateralDamage);
#if BATTLE_TRACE_LEVEL>=1
logAi->trace(
"%s -> %s, ap collateral, %s, dps: %2f, score: %2f",
attacker->getDescription(),
unitToUpdate->getDescription(),
ap.attack.shooting ? "shot" : "mellee",
collateralDamage,
collateralDamageReduce);
#endif
}
}
else
{
int64_t attackDamage = damageCache.getDamage(attacker.get(), unitToUpdate.get(), hb);
float defenderDamageReduce = AttackPossibility::calculateDamageReduce(attacker.get(), unitToUpdate.get(), attackDamage, damageCache, hb);
attackValue += defenderDamageReduce;
dpsScore += defenderDamageReduce * positiveEffectMultiplier;
attackerValue[attacker->unitId()].value += defenderDamageReduce;
unitToUpdate->damage(attackDamage);
#if BATTLE_TRACE_LEVEL>=1
logAi->trace(
"%s -> %s, ap attack, %s, dps: %2f, score: %2f",
attacker->getDescription(),
unitToUpdate->getDescription(),
ap.attack.shooting ? "shot" : "mellee",
attackDamage,
defenderDamageReduce);
#endif
}
}
attackValue += ap.shootersBlockedDmg;
dpsScore += ap.shootersBlockedDmg * positiveEffectMultiplier;
attacker->afterAttack(ap.attack.shooting, false);
return attackValue; return attackValue;
} }
int64_t BattleExchangeVariant::trackAttack( float BattleExchangeVariant::trackAttack(
std::shared_ptr<StackWithBonuses> attacker, std::shared_ptr<StackWithBonuses> attacker,
std::shared_ptr<StackWithBonuses> defender, std::shared_ptr<StackWithBonuses> defender,
bool shooting, bool shooting,
@ -71,23 +138,15 @@ int64_t BattleExchangeVariant::trackAttack(
static const auto selectorBlocksRetaliation = Selector::type()(BonusType::BLOCKS_RETALIATION); static const auto selectorBlocksRetaliation = Selector::type()(BonusType::BLOCKS_RETALIATION);
const bool counterAttacksBlocked = attacker->hasBonus(selectorBlocksRetaliation, cachingStringBlocksRetaliation); const bool counterAttacksBlocked = attacker->hasBonus(selectorBlocksRetaliation, cachingStringBlocksRetaliation);
// FIXME: provide distance info for Jousting bonus
BattleAttackInfo bai(attacker.get(), defender.get(), 0, shooting);
if(shooting)
{
bai.attackerPos.setXY(8, 5);
}
int64_t attackDamage = damageCache.getDamage(attacker.get(), defender.get(), hb); int64_t attackDamage = damageCache.getDamage(attacker.get(), defender.get(), hb);
int64_t defenderDamageReduce = AttackPossibility::calculateDamageReduce(attacker.get(), defender.get(), attackDamage, damageCache, hb); float defenderDamageReduce = AttackPossibility::calculateDamageReduce(attacker.get(), defender.get(), attackDamage, damageCache, hb);
int64_t attackerDamageReduce = 0; float attackerDamageReduce = 0;
if(!evaluateOnly) if(!evaluateOnly)
{ {
#if BATTLE_TRACE_LEVEL>=1 #if BATTLE_TRACE_LEVEL>=1
logAi->trace( logAi->trace(
"%s -> %s, normal attack, %s, dps: %lld, %lld", "%s -> %s, normal attack, %s, dps: %lld, %2f",
attacker->getDescription(), attacker->getDescription(),
defender->getDescription(), defender->getDescription(),
shooting ? "shot" : "mellee", shooting ? "shot" : "mellee",
@ -107,36 +166,33 @@ int64_t BattleExchangeVariant::trackAttack(
attacker->afterAttack(shooting, false); attacker->afterAttack(shooting, false);
} }
if(defender->alive() && defender->ableToRetaliate() && !counterAttacksBlocked && !shooting) if(!evaluateOnly && defender->alive() && defender->ableToRetaliate() && !counterAttacksBlocked && !shooting)
{ {
auto retaliationDamage = damageCache.getDamage(defender.get(), attacker.get(), hb); auto retaliationDamage = damageCache.getDamage(defender.get(), attacker.get(), hb);
attackerDamageReduce = AttackPossibility::calculateDamageReduce(defender.get(), attacker.get(), retaliationDamage, damageCache, hb); attackerDamageReduce = AttackPossibility::calculateDamageReduce(defender.get(), attacker.get(), retaliationDamage, damageCache, hb);
if(!evaluateOnly)
{
#if BATTLE_TRACE_LEVEL>=1 #if BATTLE_TRACE_LEVEL>=1
logAi->trace( logAi->trace(
"%s -> %s, retaliation, dps: %lld, %lld", "%s -> %s, retaliation, dps: %lld, %2f",
defender->getDescription(), defender->getDescription(),
attacker->getDescription(), attacker->getDescription(),
retaliationDamage, retaliationDamage,
attackerDamageReduce); attackerDamageReduce);
#endif #endif
if(isOurAttack) if(isOurAttack)
{ {
dpsScore -= attackerDamageReduce * negativeEffectMultiplier; dpsScore -= attackerDamageReduce * negativeEffectMultiplier;
attackerValue[attacker->unitId()].isRetalitated = true; attackerValue[attacker->unitId()].isRetalitated = true;
}
else
{
dpsScore += attackerDamageReduce * positiveEffectMultiplier;
attackerValue[defender->unitId()].value += attackerDamageReduce;
}
attacker->damage(retaliationDamage);
defender->afterAttack(false, true);
} }
else
{
dpsScore += attackerDamageReduce * positiveEffectMultiplier;
attackerValue[defender->unitId()].value += attackerDamageReduce;
}
attacker->damage(retaliationDamage);
defender->afterAttack(false, true);
} }
auto score = defenderDamageReduce - attackerDamageReduce; auto score = defenderDamageReduce - attackerDamageReduce;
@ -144,7 +200,7 @@ int64_t BattleExchangeVariant::trackAttack(
#if BATTLE_TRACE_LEVEL>=1 #if BATTLE_TRACE_LEVEL>=1
if(!score) if(!score)
{ {
logAi->trace("Attack has zero score d:%lld a:%lld", defenderDamageReduce, attackerDamageReduce); logAi->trace("Attack has zero score d:%2f a:%2f", defenderDamageReduce, attackerDamageReduce);
} }
#endif #endif
@ -159,33 +215,22 @@ EvaluationResult BattleExchangeEvaluator::findBestTarget(
{ {
EvaluationResult result(targets.bestAction()); EvaluationResult result(targets.bestAction());
updateReachabilityMap(hb);
for(auto & ap : targets.possibleAttacks)
{
int64_t score = calculateExchange(ap, targets, damageCache, hb);
if(score > result.score)
{
result.score = score;
result.bestAttack = ap;
}
}
if(!activeStack->waited()) if(!activeStack->waited())
{ {
#if BATTLE_TRACE_LEVEL>=1 #if BATTLE_TRACE_LEVEL>=1
logAi->trace("Evaluating waited attack for %s", activeStack->getDescription()); logAi->trace("Evaluating waited attack for %s", activeStack->getDescription());
#endif #endif
hb->getForUpdate(activeStack->unitId())->waiting = true; auto hbWaited = std::make_shared<HypotheticBattle>(env.get(), hb);
hb->getForUpdate(activeStack->unitId())->waitedThisTurn = true;
updateReachabilityMap(hb); hbWaited->getForUpdate(activeStack->unitId())->waiting = true;
hbWaited->getForUpdate(activeStack->unitId())->waitedThisTurn = true;
updateReachabilityMap(hbWaited);
for(auto & ap : targets.possibleAttacks) for(auto & ap : targets.possibleAttacks)
{ {
int64_t score = calculateExchange(ap, targets, damageCache, hb); float score = calculateExchange(ap, targets, damageCache, hbWaited);
if(score > result.score) if(score > result.score)
{ {
@ -196,6 +241,24 @@ EvaluationResult BattleExchangeEvaluator::findBestTarget(
} }
} }
#if BATTLE_TRACE_LEVEL>=1
logAi->trace("Evaluating normal attack for %s", activeStack->getDescription());
#endif
updateReachabilityMap(hb);
for(auto & ap : targets.possibleAttacks)
{
float score = calculateExchange(ap, targets, damageCache, hb);
if(score >= result.score)
{
result.score = score;
result.bestAttack = ap;
result.wait = false;
}
}
return result; return result;
} }
@ -361,14 +424,14 @@ std::vector<const battle::Unit *> BattleExchangeEvaluator::getExchangeUnits(
return exchangeUnits; return exchangeUnits;
} }
int64_t BattleExchangeEvaluator::calculateExchange( float BattleExchangeEvaluator::calculateExchange(
const AttackPossibility & ap, const AttackPossibility & ap,
PotentialTargets & targets, PotentialTargets & targets,
DamageCache & damageCache, DamageCache & damageCache,
std::shared_ptr<HypotheticBattle> hb) std::shared_ptr<HypotheticBattle> hb)
{ {
#if BATTLE_TRACE_LEVEL>=1 #if BATTLE_TRACE_LEVEL>=1
logAi->trace("Battle exchange at %lld", ap.attack.shooting ? ap.dest : ap.from); logAi->trace("Battle exchange at %d", ap.attack.shooting ? ap.dest.hex : ap.from.hex);
#endif #endif
if(cb->battleGetMySide() == BattlePerspective::LEFT_SIDE if(cb->battleGetMySide() == BattlePerspective::LEFT_SIDE
@ -439,7 +502,7 @@ int64_t BattleExchangeEvaluator::calculateExchange(
if(!isOur || !exchangeBattle->battleGetUnitByID(targetUnit->unitId())->alive()) if(!isOur || !exchangeBattle->battleGetUnitByID(targetUnit->unitId())->alive())
{ {
auto estimateAttack = [&](const battle::Unit * u) -> int64_t auto estimateAttack = [&](const battle::Unit * u) -> float
{ {
auto stackWithBonuses = exchangeBattle->getForUpdate(u->unitId()); auto stackWithBonuses = exchangeBattle->getForUpdate(u->unitId());
auto score = v.trackAttack( auto score = v.trackAttack(
@ -452,7 +515,7 @@ int64_t BattleExchangeEvaluator::calculateExchange(
true); true);
#if BATTLE_TRACE_LEVEL>=1 #if BATTLE_TRACE_LEVEL>=1
logAi->trace("Best target selector %s->%s score = %lld", attacker->getDescription(), u->getDescription(), score); logAi->trace("Best target selector %s->%s score = %2f", attacker->getDescription(), u->getDescription(), score);
#endif #endif
return score; return score;
@ -497,9 +560,10 @@ int64_t BattleExchangeEvaluator::calculateExchange(
auto shooting = exchangeBattle->battleCanShoot(attacker.get()); auto shooting = exchangeBattle->battleCanShoot(attacker.get());
const int totalAttacks = attacker->getTotalAttacks(shooting); const int totalAttacks = attacker->getTotalAttacks(shooting);
if(canUseAp && activeUnit->unitId() == ap.attack.attacker->unitId() && targetUnit->unitId() == ap.attack.defender->unitId()) if(canUseAp && activeUnit->unitId() == ap.attack.attacker->unitId()
&& targetUnit->unitId() == ap.attack.defender->unitId())
{ {
v.trackAttack(ap, *exchangeBattle); v.trackAttack(ap, exchangeBattle, damageCache);
} }
else else
{ {
@ -530,7 +594,7 @@ int64_t BattleExchangeEvaluator::calculateExchange(
v.adjustPositions(melleeAttackers, ap, reachabilityMap); v.adjustPositions(melleeAttackers, ap, reachabilityMap);
#if BATTLE_TRACE_LEVEL>=1 #if BATTLE_TRACE_LEVEL>=1
logAi->trace("Exchange score: %lld", v.getScore()); logAi->trace("Exchange score: %2f", v.getScore());
#endif #endif
return v.getScore(); return v.getScore();
@ -560,7 +624,7 @@ void BattleExchangeVariant::adjustPositions(
vstd::erase_if_present(hexes, ap.attack.attacker->occupiedHex(ap.attack.attackerPos)); vstd::erase_if_present(hexes, ap.attack.attacker->occupiedHex(ap.attack.attackerPos));
} }
int64_t notRealizedDamage = 0; float notRealizedDamage = 0;
for(auto unit : attackers) for(auto unit : attackers)
{ {
@ -576,7 +640,7 @@ void BattleExchangeVariant::adjustPositions(
continue; continue;
} }
auto desiredPosition = vstd::minElementByFun(hexes, [&](BattleHex h) -> int64_t auto desiredPosition = vstd::minElementByFun(hexes, [&](BattleHex h) -> float
{ {
auto score = vstd::contains(reachabilityMap[h], unit) auto score = vstd::contains(reachabilityMap[h], unit)
? reachabilityMap[h].size() ? reachabilityMap[h].size()

View File

@ -16,7 +16,7 @@
struct AttackerValue struct AttackerValue
{ {
int64_t value; float value;
bool isRetalitated; bool isRetalitated;
BattleHex position; BattleHex position;
@ -25,8 +25,8 @@ struct AttackerValue
struct MoveTarget struct MoveTarget
{ {
int64_t score; float score;
int64_t scorePerTurn; float scorePerTurn;
std::vector<BattleHex> positions; std::vector<BattleHex> positions;
std::optional<AttackPossibility> cachedAttack; std::optional<AttackPossibility> cachedAttack;
uint8_t turnsToRich; uint8_t turnsToRich;
@ -36,12 +36,12 @@ struct MoveTarget
struct EvaluationResult struct EvaluationResult
{ {
static const int64_t INEFFECTIVE_SCORE = -1000000; static const int64_t INEFFECTIVE_SCORE = -10000;
AttackPossibility bestAttack; AttackPossibility bestAttack;
MoveTarget bestMove; MoveTarget bestMove;
bool wait; bool wait;
int64_t score; float score;
bool defend; bool defend;
EvaluationResult(const AttackPossibility & ap) EvaluationResult(const AttackPossibility & ap)
@ -62,9 +62,12 @@ public:
BattleExchangeVariant(float positiveEffectMultiplier, float negativeEffectMultiplier) BattleExchangeVariant(float positiveEffectMultiplier, float negativeEffectMultiplier)
: dpsScore(0), positiveEffectMultiplier(positiveEffectMultiplier), negativeEffectMultiplier(negativeEffectMultiplier) {} : dpsScore(0), positiveEffectMultiplier(positiveEffectMultiplier), negativeEffectMultiplier(negativeEffectMultiplier) {}
int64_t trackAttack(const AttackPossibility & ap, HypotheticBattle & state); float trackAttack(
const AttackPossibility & ap,
std::shared_ptr<HypotheticBattle> hb,
DamageCache & damageCache);
int64_t trackAttack( float trackAttack(
std::shared_ptr<StackWithBonuses> attacker, std::shared_ptr<StackWithBonuses> attacker,
std::shared_ptr<StackWithBonuses> defender, std::shared_ptr<StackWithBonuses> defender,
bool shooting, bool shooting,
@ -73,7 +76,7 @@ public:
std::shared_ptr<HypotheticBattle> hb, std::shared_ptr<HypotheticBattle> hb,
bool evaluateOnly = false); bool evaluateOnly = false);
int64_t getScore() const { return dpsScore; } float getScore() const { return dpsScore; }
void adjustPositions( void adjustPositions(
std::vector<const battle::Unit *> attackers, std::vector<const battle::Unit *> attackers,
@ -83,7 +86,7 @@ public:
private: private:
float positiveEffectMultiplier; float positiveEffectMultiplier;
float negativeEffectMultiplier; float negativeEffectMultiplier;
int64_t dpsScore; float dpsScore;
std::map<uint32_t, AttackerValue> attackerValue; std::map<uint32_t, AttackerValue> attackerValue;
}; };
@ -110,7 +113,7 @@ public:
DamageCache & damageCache, DamageCache & damageCache,
std::shared_ptr<HypotheticBattle> hb); std::shared_ptr<HypotheticBattle> hb);
int64_t calculateExchange( float calculateExchange(
const AttackPossibility & ap, const AttackPossibility & ap,
PotentialTargets & targets, PotentialTargets & targets,
DamageCache & damageCache, DamageCache & damageCache,

View File

@ -27,7 +27,7 @@ public:
const CSpell * spell; const CSpell * spell;
spells::Target dest; spells::Target dest;
int64_t value; float value;
PossibleSpellcast(); PossibleSpellcast();
virtual ~PossibleSpellcast(); virtual ~PossibleSpellcast();

View File

@ -45,7 +45,8 @@ StackWithBonuses::StackWithBonuses(const HypotheticBattle * Owner, const battle:
id(Stack->unitId()), id(Stack->unitId()),
side(Stack->unitSide()), side(Stack->unitSide()),
player(Stack->unitOwner()), player(Stack->unitOwner()),
slot(Stack->unitSlot()) slot(Stack->unitSlot()),
treeVersionLocal(0)
{ {
localInit(Owner); localInit(Owner);
@ -61,7 +62,8 @@ StackWithBonuses::StackWithBonuses(const HypotheticBattle * Owner, const battle:
id(Stack->unitId()), id(Stack->unitId()),
side(Stack->unitSide()), side(Stack->unitSide()),
player(Stack->unitOwner()), player(Stack->unitOwner()),
slot(Stack->unitSlot()) slot(Stack->unitSlot()),
treeVersionLocal(0)
{ {
localInit(Owner); localInit(Owner);
@ -76,7 +78,8 @@ StackWithBonuses::StackWithBonuses(const HypotheticBattle * Owner, const battle:
baseAmount(info.count), baseAmount(info.count),
id(info.id), id(info.id),
side(info.side), side(info.side),
slot(SlotID::SUMMONED_SLOT_PLACEHOLDER) slot(SlotID::SUMMONED_SLOT_PLACEHOLDER),
treeVersionLocal(0)
{ {
type = info.type.toCreature(); type = info.type.toCreature();
origBearer = type; origBearer = type;