From 5c4ce61d57c564e9107258818ebec9fc3febe59f Mon Sep 17 00:00:00 2001 From: Ivan Savenko Date: Fri, 23 May 2025 14:00:47 +0300 Subject: [PATCH] Fix handling of double-wide creatures by BattleAI One of recently added sanity checks was causing crashes during BattleAI decision-making. Actual reason turned out to be due to invalid requests generated by BattleAI when attempting to attack enemy unit from behind with double- wide unit. This change should make BattleAI correctly estimate such attacks --- AI/BattleAI/BattleExchangeVariant.cpp | 20 +++--- lib/battle/CBattleInfoCallback.cpp | 2 +- lib/battle/Unit.cpp | 32 +++++---- lib/battle/Unit.h | 2 + test/battle/CBattleInfoCallbackTest.cpp | 92 ++++++++++++++++++++++++- 5 files changed, 122 insertions(+), 26 deletions(-) diff --git a/AI/BattleAI/BattleExchangeVariant.cpp b/AI/BattleAI/BattleExchangeVariant.cpp index bfd5416fb..c7693ab44 100644 --- a/AI/BattleAI/BattleExchangeVariant.cpp +++ b/AI/BattleAI/BattleExchangeVariant.cpp @@ -385,7 +385,7 @@ MoveTarget BattleExchangeEvaluator::findMoveTowardsUnreachable( } auto turnsToReach = (distance - 1) / speed + 1; - const BattleHexArray & hexes = enemy->getSurroundingHexes(); + const BattleHexArray & hexes = enemy->getAttackableHexes(activeStack); auto enemySpeed = enemy->getMovementRange(); auto speedRatio = speed / static_cast(enemySpeed); auto multiplier = (speedRatio > 1 ? 1 : speedRatio) * penaltyMultiplier; @@ -416,8 +416,7 @@ MoveTarget BattleExchangeEvaluator::findMoveTowardsUnreachable( logAi->trace("New high score"); #endif - for(BattleHex enemyHex : enemy->getAttackableHexes(activeStack)) - { + BattleHex enemyHex = hex; while(!flying && dists.distances[enemyHex.toInt()] > speed && dists.predecessors.at(enemyHex.toInt()).isValid()) { enemyHex = dists.predecessors.at(enemyHex.toInt()); @@ -425,14 +424,17 @@ MoveTarget BattleExchangeEvaluator::findMoveTowardsUnreachable( if(dists.accessibility[enemyHex.toInt()] == EAccessibility::ALIVE_STACK) { auto defenderToBypass = hb->battleGetUnitByPos(enemyHex); - - if(defenderToBypass) + assert(defenderToBypass != nullptr); + auto attackHex = dists.predecessors[enemyHex.toInt()]; + + if(defenderToBypass && + defenderToBypass != enemy && + vstd::contains(defenderToBypass->getAttackableHexes(activeStack), attackHex)) { #if BATTLE_TRACE_LEVEL >= 1 logAi->trace("Found target to bypass at %d", enemyHex.toInt()); #endif - - auto attackHex = dists.predecessors[enemyHex.toInt()]; + auto baiBypass = BattleAttackInfo(activeStack, defenderToBypass, 0, cb->battleCanShoot(activeStack)); auto attackBypass = AttackPossibility::evaluate(baiBypass, attackHex, damageCache, hb); @@ -461,9 +463,7 @@ MoveTarget BattleExchangeEvaluator::findMoveTowardsUnreachable( } } - result.positions.insert(enemyHex); - } - + result.positions.insert(enemyHex); result.cachedAttack = attack; result.turnsToReach = turnsToReach; } diff --git a/lib/battle/CBattleInfoCallback.cpp b/lib/battle/CBattleInfoCallback.cpp index 58669a37a..559f3fb5c 100644 --- a/lib/battle/CBattleInfoCallback.cpp +++ b/lib/battle/CBattleInfoCallback.cpp @@ -1338,7 +1338,7 @@ AttackableTiles CBattleInfoCallback::getPotentiallyAttackableHexes( attackOriginHex = attacker->occupiedHex(attackOriginHex); if (!vstd::contains(defender->getSurroundingHexes(defenderPos), attackOriginHex)) - throw std::runtime_error("!!!"); + throw std::runtime_error("Atempt to attack from invalid position!"); auto attackDirection = BattleHex::mutualPosition(attackOriginHex, defenderPos); diff --git a/lib/battle/Unit.cpp b/lib/battle/Unit.cpp index 7af51830f..6661397a2 100644 --- a/lib/battle/Unit.cpp +++ b/lib/battle/Unit.cpp @@ -68,22 +68,28 @@ const BattleHexArray & Unit::getSurroundingHexes(const BattleHex & position, boo BattleHexArray Unit::getAttackableHexes(const Unit * attacker) const { - const BattleHexArray & defenderHexes = getHexes(); - - BattleHexArray targetableHexes; - - for(const auto & defenderHex : defenderHexes) + if (!attacker->doubleWide()) { - auto hexes = battle::Unit::getHexes(defenderHex); - - if(hexes.size() == 2 && BattleHex::getDistance(hexes.front(), hexes.back()) != 1) - hexes.pop_back(); - - for(const auto & hex : hexes) - targetableHexes.insert(hex.getNeighbouringTiles()); + return getSurroundingHexes(); } + else + { + BattleHexArray result; - return targetableHexes; + for (const auto & attackOrigin : getSurroundingHexes()) + { + if (!coversPos(attacker->occupiedHex(attackOrigin)) && attackOrigin.isAvailable()) + result.insert(attackOrigin); + + bool isAttacker = attacker->unitSide() == BattleSide::ATTACKER; + BattleHex::EDir headDirection = isAttacker ? BattleHex::RIGHT : BattleHex::LEFT; + BattleHex headHex = attackOrigin.cloneInDirection(headDirection); + + if (!coversPos(headHex) && headHex.isAvailable()) + result.insert(headHex); + } + return result; + } } bool Unit::coversPos(const BattleHex & pos) const diff --git a/lib/battle/Unit.h b/lib/battle/Unit.h index dae622397..37affcc0f 100644 --- a/lib/battle/Unit.h +++ b/lib/battle/Unit.h @@ -132,6 +132,8 @@ public: virtual std::string getDescription() const; const BattleHexArray & getSurroundingHexes(const BattleHex & assumedPosition = BattleHex::INVALID) const; // get six or 8 surrounding hexes depending on creature size + + /// Returns list of hexes from which attacker can attack this unit BattleHexArray getAttackableHexes(const Unit * attacker) const; static const BattleHexArray & getSurroundingHexes(const BattleHex & position, bool twoHex, BattleSide side); diff --git a/test/battle/CBattleInfoCallbackTest.cpp b/test/battle/CBattleInfoCallbackTest.cpp index 263d382c5..89096931c 100644 --- a/test/battle/CBattleInfoCallbackTest.cpp +++ b/test/battle/CBattleInfoCallbackTest.cpp @@ -262,6 +262,96 @@ public: } }; +///// getAttackableHexes tests + +TEST_F(AttackableHexesTest, getAttackableHexes_SingleWideAttacker_SingleWideDefender) +{ + UnitFake & attacker = addRegularMelee(60, BattleSide::ATTACKER); + UnitFake & defender = addRegularMelee(90, BattleSide::DEFENDER); + + static const BattleHexArray expectedDef = + { + 72, + 73, + 89, + 91, + 106, + 107 + }; + + auto attackable = defender.getAttackableHexes(&attacker); + attackable.sort([](const auto & l, const auto & r) { return l < r; }); + EXPECT_EQ(expectedDef, attackable); +} + +TEST_F(AttackableHexesTest, getAttackableHexes_SingleWideAttacker_DoubleWideDefender) +{ + UnitFake & attacker = addRegularMelee(60, BattleSide::ATTACKER); + UnitFake & defender = addDragon(90, BattleSide::DEFENDER); + + static const BattleHexArray expectedDef = + { + 72, + 73, + 74, + 89, + 92, + 106, + 107, + 108 + }; + + auto attackable = defender.getAttackableHexes(&attacker); + attackable.sort([](const auto & l, const auto & r) { return l < r; }); + EXPECT_EQ(expectedDef, attackable); +} + +TEST_F(AttackableHexesTest, getAttackableHexes_DoubleWideAttacker_SingleWideDefender) +{ + UnitFake & attacker = addDragon(60, BattleSide::ATTACKER); + UnitFake & defender = addRegularMelee(90, BattleSide::DEFENDER); + + static const BattleHexArray expectedDef = + { + 72, + 73, + 74, + 89, + 92, + 106, + 107, + 108 + }; + + auto attackable = defender.getAttackableHexes(&attacker); + attackable.sort([](const auto & l, const auto & r) { return l < r; }); + EXPECT_EQ(expectedDef, attackable); +} + +TEST_F(AttackableHexesTest, getAttackableHexes_DoubleWideAttacker_DoubleWideDefender) +{ + UnitFake & attacker = addDragon(60, BattleSide::ATTACKER); + UnitFake & defender = addDragon(90, BattleSide::DEFENDER); + + static const BattleHexArray expectedDef = + { + 72, + 73, + 74, + 75, + 89, + 93, + 106, + 107, + 108, + 109 + }; + + auto attackable = defender.getAttackableHexes(&attacker); + attackable.sort([](const auto & l, const auto & r) { return l < r; }); + EXPECT_EQ(expectedDef, attackable); +} + //// CERBERI 3-HEADED ATTACKS TEST_F(AttackableHexesTest, CerberiAttackerRight) @@ -276,7 +366,6 @@ TEST_F(AttackableHexesTest, CerberiAttackerRight) auto attacked = getAttackedUnits(attacker, defender, defender.getPosition()); - EXPECT_TRUE(vstd::contains(attacked, &defender)); EXPECT_TRUE(vstd::contains(attacked, &right)); EXPECT_TRUE(vstd::contains(attacked, &left)); } @@ -356,7 +445,6 @@ TEST_F(AttackableHexesTest, DragonRightRegular_RightHorithontalBreath) auto attacked = getAttackedUnits(attacker, defender, defender.getPosition()); - EXPECT_TRUE(vstd::contains(attacked, &defender)); EXPECT_TRUE(vstd::contains(attacked, &next)); }