diff --git a/AI/BattleAI/BattleExchangeVariant.cpp b/AI/BattleAI/BattleExchangeVariant.cpp index cdb65a697..5e71bc4a9 100644 --- a/AI/BattleAI/BattleExchangeVariant.cpp +++ b/AI/BattleAI/BattleExchangeVariant.cpp @@ -519,14 +519,14 @@ ReachabilityData BattleExchangeEvaluator::getExchangeUnits( for(auto hex : hexes) { - vstd::concatenate(allReachableUnits, turn == 0 ? reachabilityMap.at(hex.toInt()) : getOneTurnReachableUnits(turn, hex)); + vstd::concatenate(allReachableUnits, getOneTurnReachableUnits(turn, hex)); } if(!ap.attack.attacker->isTurret()) { for(auto hex : ap.attack.attacker->getHexes()) { - auto unitsReachingAttacker = turn == 0 ? reachabilityMap.at(hex.toInt()) : getOneTurnReachableUnits(turn, hex); + auto unitsReachingAttacker = getOneTurnReachableUnits(turn, hex); for(auto unit : unitsReachingAttacker) { if(unit->unitSide() != ap.attack.attacker->unitSide()) @@ -800,7 +800,9 @@ BattleScore BattleExchangeEvaluator::calculateExchange( if(!u->getPosition().isValid()) return false; // e.g. tower shooters - return vstd::contains_if(reachabilityMap.at(u->getPosition().toInt()), [&attacker](const battle::Unit * other) -> bool + const auto & reachableUnits = getOneTurnReachableUnits(0, u->getPosition()); + + return vstd::contains_if(reachableUnits, [&attacker](const battle::Unit * other) -> bool { return attacker->unitId() == other->unitId(); }); @@ -887,7 +889,7 @@ bool BattleExchangeEvaluator::canBeHitThisTurn(const AttackPossibility & ap) { for(auto pos : ap.attack.attacker->getSurroundingHexes()) { - for(auto u : reachabilityMap.at(pos.toInt())) + for(auto u : getOneTurnReachableUnits(0, pos)) { if(u->unitSide() != ap.attack.attacker->unitSide()) { @@ -899,6 +901,22 @@ bool BattleExchangeEvaluator::canBeHitThisTurn(const AttackPossibility & ap) return false; } +void ReachabilityMapCache::update(const std::vector & turnOrder, std::shared_ptr hb) +{ + for(auto turn : turnOrder) + { + for(auto u : turn) + { + if(!vstd::contains(unitReachabilityMap, u->unitId())) + { + unitReachabilityMap[u->unitId()] = hb->getReachability(u); + } + } + } + + hexReachabilityPerTurn.clear(); +} + void BattleExchangeEvaluator::updateReachabilityMap(std::shared_ptr hb) { const int TURN_DEPTH = 2; @@ -906,28 +924,25 @@ void BattleExchangeEvaluator::updateReachabilityMap(std::shared_ptrbattleGetTurnOrder(turnOrder, std::numeric_limits::max(), TURN_DEPTH); - - for(auto turn : turnOrder) - { - for(auto u : turn) - { - if(!vstd::contains(reachabilityCache, u->unitId())) - { - reachabilityCache[u->unitId()] = hb->getReachability(u); - } - } - } - - tbb::parallel_for(tbb::blocked_range(0, reachabilityMap.size()), [&](const tbb::blocked_range & r) - { - for(auto i = r.begin(); i != r.end(); i++) - reachabilityMap[i] = getOneTurnReachableUnits(0, BattleHex(i)); - }); + reachabilityMap.update(turnOrder, hb); } -std::vector BattleExchangeEvaluator::getOneTurnReachableUnits(uint8_t turn, BattleHex hex) const +const battle::Units & ReachabilityMapCache::getOneTurnReachableUnits(std::shared_ptr cb, std::shared_ptr env, const std::vector & turnOrder, uint8_t turn, BattleHex hex) { - std::vector result; + auto & turnData = hexReachabilityPerTurn[turn]; + + if (!turnData.isValid[hex.toInt()]) + { + turnData.hexes[hex.toInt()] = computeOneTurnReachableUnits(cb, env, turnOrder, turn, hex); + turnData.isValid.set(hex.toInt()); + } + + return turnData.hexes[hex.toInt()]; +} + +battle::Units ReachabilityMapCache::computeOneTurnReachableUnits(std::shared_ptr cb, std::shared_ptr env, const std::vector & turnOrder, uint8_t turn, BattleHex hex) +{ + battle::Units result; for(int i = 0; i < turnOrder.size(); i++, turn++) { @@ -949,10 +964,10 @@ std::vector BattleExchangeEvaluator::getOneTurnReachableUn auto unitSpeed = unit->getMovementRange(turn); auto radius = unitSpeed * (turn + 1); - auto reachabilityIter = reachabilityCache.find(unit->unitId()); - assert(reachabilityIter != reachabilityCache.end()); // missing updateReachabilityMap call? + auto reachabilityIter = unitReachabilityMap.find(unit->unitId()); + assert(reachabilityIter != unitReachabilityMap.end()); // missing updateReachabilityMap call? - ReachabilityInfo unitReachability = reachabilityIter != reachabilityCache.end() ? reachabilityIter->second : turnBattle.getReachability(unit); + ReachabilityInfo unitReachability = reachabilityIter != unitReachabilityMap.end() ? reachabilityIter->second : turnBattle.getReachability(unit); bool reachable = unitReachability.distances.at(hex.toInt()) <= radius; @@ -981,6 +996,11 @@ std::vector BattleExchangeEvaluator::getOneTurnReachableUn return result; } +const battle::Units & BattleExchangeEvaluator::getOneTurnReachableUnits(uint8_t turn, BattleHex hex) const +{ + return reachabilityMap.getOneTurnReachableUnits(cb, env, turnOrder, turn, hex); +} + // avoid blocking path for stronger stack by weaker stack bool BattleExchangeEvaluator::checkPositionBlocksOurStacks(HypotheticBattle & hb, const battle::Unit * activeUnit, BattleHex position) { @@ -1032,9 +1052,11 @@ bool BattleExchangeEvaluator::checkPositionBlocksOurStacks(HypotheticBattle & hb } } - if(!reachable && std::count(reachabilityMap[hex.toInt()].begin(), reachabilityMap[hex.toInt()].end(), unit) > 1) + if(!reachable) { - blockingScore += ratio * (enemyUnit ? BLOCKING_OWN_ATTACK_PENALTY : BLOCKING_OWN_MOVE_PENALTY); + auto reachableUnits = getOneTurnReachableUnits(0, hex); + if (std::count(reachableUnits.begin(), reachableUnits.end(), unit) > 1) + blockingScore += ratio * (enemyUnit ? BLOCKING_OWN_ATTACK_PENALTY : BLOCKING_OWN_MOVE_PENALTY); } } } diff --git a/AI/BattleAI/BattleExchangeVariant.h b/AI/BattleAI/BattleExchangeVariant.h index 8518a51df..de8577d5a 100644 --- a/AI/BattleAI/BattleExchangeVariant.h +++ b/AI/BattleAI/BattleExchangeVariant.h @@ -123,13 +123,29 @@ struct ReachabilityData std::set enemyUnitsReachingAttacker; }; +class ReachabilityMapCache +{ + struct PerTurnData{ + std::bitset isValid; + std::array hexes; + }; + + std::map unitReachabilityMap; // unit ID -> reachability + std::map hexReachabilityPerTurn; + + //const ReachabilityInfo & update(); + battle::Units computeOneTurnReachableUnits(std::shared_ptr cb, std::shared_ptr env, const std::vector & turnOrder, uint8_t turn, BattleHex hex); +public: + const battle::Units & getOneTurnReachableUnits(std::shared_ptr cb, std::shared_ptr env, const std::vector & turnOrder, uint8_t turn, BattleHex hex); + void update(const std::vector & turnOrder, std::shared_ptr hb); +}; + class BattleExchangeEvaluator { private: std::shared_ptr cb; std::shared_ptr env; - std::map reachabilityCache; - std::array, GameConstants::BFIELD_SIZE> reachabilityMap; + mutable ReachabilityMapCache reachabilityMap; std::vector turnOrder; float negativeEffectMultiplier; int simulationTurnsCount; @@ -169,7 +185,7 @@ public: DamageCache & damageCache, std::shared_ptr hb) const; - std::vector getOneTurnReachableUnits(uint8_t turn, BattleHex hex) const; + const battle::Units & getOneTurnReachableUnits(uint8_t turn, BattleHex hex) const; void updateReachabilityMap(std::shared_ptr hb); ReachabilityData getExchangeUnits(