1
0
mirror of https://github.com/vcmi/vcmi.git synced 2025-03-19 21:10:12 +02:00

Merge pull request #5250 from IvanSavenko/optimize_ai

BattleAI optimizations
This commit is contained in:
Ivan Savenko 2025-01-12 14:20:30 +02:00 committed by GitHub
commit d935a19504
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
40 changed files with 277 additions and 158 deletions

View File

@ -92,8 +92,8 @@ void DamageCache::buildDamageCache(std::shared_ptr<HypotheticBattle> hb, BattleS
return u->isValidTarget();
});
std::vector<const battle::Unit *> ourUnits;
std::vector<const battle::Unit *> enemyUnits;
battle::Units ourUnits;
battle::Units enemyUnits;
for(auto stack : stacks)
{
@ -346,9 +346,9 @@ AttackPossibility AttackPossibility::evaluate(
if (!attackInfo.shooting)
ap.attackerState->setPosition(hex);
std::vector<const battle::Unit *> defenderUnits;
std::vector<const battle::Unit *> retaliatedUnits = {attacker};
std::vector<const battle::Unit *> affectedUnits;
battle::Units defenderUnits;
battle::Units retaliatedUnits = {attacker};
battle::Units affectedUnits;
if (attackInfo.shooting)
defenderUnits = state->getAttackedBattleUnits(attacker, defender, defHex, true, hex, defender->getPosition());
@ -384,7 +384,9 @@ AttackPossibility AttackPossibility::evaluate(
affectedUnits = defenderUnits;
vstd::concatenate(affectedUnits, retaliatedUnits);
#if BATTLE_TRACE_LEVEL>=1
logAi->trace("Attacked battle units count %d, %d->%d", affectedUnits.size(), hex, defHex);
#endif
std::map<uint32_t, std::shared_ptr<battle::CUnitState>> defenderStates;

View File

@ -756,7 +756,9 @@ bool BattleEvaluator::attemptCastingSpell(const CStack * activeStack)
auto updatedAttack = AttackPossibility::evaluate(updatedBai, cachedAttack.ap->from, innerCache, state);
stackActionScore = scoreEvaluator.evaluateExchange(updatedAttack, cachedAttack.turn, *targets, innerCache, state);
BattleExchangeEvaluator innerEvaluator(scoreEvaluator);
stackActionScore = innerEvaluator.evaluateExchange(updatedAttack, cachedAttack.turn, *targets, innerCache, state);
}
for(const auto & unit : allUnits)
{

View File

@ -11,6 +11,7 @@
#include "BattleExchangeVariant.h"
#include "BattleEvaluator.h"
#include "../../lib/CStack.h"
#include "tbb/parallel_for.h"
AttackerValue::AttackerValue()
: value(0),
@ -470,10 +471,10 @@ MoveTarget BattleExchangeEvaluator::findMoveTowardsUnreachable(
return result;
}
std::vector<const battle::Unit *> BattleExchangeEvaluator::getAdjacentUnits(const battle::Unit * blockerUnit) const
battle::Units BattleExchangeEvaluator::getAdjacentUnits(const battle::Unit * blockerUnit) const
{
std::queue<const battle::Unit *> queue;
std::vector<const battle::Unit *> checkedStacks;
battle::Units checkedStacks;
queue.push(blockerUnit);
@ -505,7 +506,7 @@ ReachabilityData BattleExchangeEvaluator::getExchangeUnits(
uint8_t turn,
PotentialTargets & targets,
std::shared_ptr<HypotheticBattle> hb,
std::vector<const battle::Unit *> additionalUnits) const
battle::Units additionalUnits) const
{
ReachabilityData result;
@ -514,18 +515,18 @@ ReachabilityData BattleExchangeEvaluator::getExchangeUnits(
if(!ap.attack.shooting)
hexes.insert(ap.from);
std::vector<const battle::Unit *> allReachableUnits = additionalUnits;
battle::Units allReachableUnits = additionalUnits;
for(auto hex : hexes)
{
vstd::concatenate(allReachableUnits, turn == 0 ? reachabilityMap.at(hex) : getOneTurnReachableUnits(turn, hex));
vstd::concatenate(allReachableUnits, getOneTurnReachableUnits(turn, hex));
}
if(!ap.attack.attacker->isTurret())
{
for(auto hex : ap.attack.attacker->getHexes())
{
auto unitsReachingAttacker = turn == 0 ? reachabilityMap.at(hex) : getOneTurnReachableUnits(turn, hex);
auto unitsReachingAttacker = getOneTurnReachableUnits(turn, hex);
for(auto unit : unitsReachingAttacker)
{
if(unit->unitSide() != ap.attack.attacker->unitSide())
@ -635,7 +636,7 @@ BattleScore BattleExchangeEvaluator::calculateExchange(
PotentialTargets & targets,
DamageCache & damageCache,
std::shared_ptr<HypotheticBattle> hb,
std::vector<const battle::Unit *> additionalUnits) const
battle::Units additionalUnits) const
{
#if BATTLE_TRACE_LEVEL>=1
logAi->trace("Battle exchange at %d", ap.attack.shooting ? ap.dest.hex : ap.from.hex);
@ -648,8 +649,8 @@ BattleScore BattleExchangeEvaluator::calculateExchange(
return BattleScore(EvaluationResult::INEFFECTIVE_SCORE, 0);
}
std::vector<const battle::Unit *> ourStacks;
std::vector<const battle::Unit *> enemyStacks;
battle::Units ourStacks;
battle::Units enemyStacks;
if(hb->battleGetUnitByID(ap.attack.defender->unitId())->alive())
enemyStacks.push_back(ap.attack.defender);
@ -799,7 +800,9 @@ BattleScore BattleExchangeEvaluator::calculateExchange(
if(!u->getPosition().isValid())
return false; // e.g. tower shooters
return vstd::contains_if(reachabilityMap.at(u->getPosition()), [&attacker](const battle::Unit * other) -> bool
const auto & reachableUnits = getOneTurnReachableUnits(0, u->getPosition());
return vstd::contains_if(reachableUnits, [&attacker](const battle::Unit * other) -> bool
{
return attacker->unitId() == other->unitId();
});
@ -886,7 +889,7 @@ bool BattleExchangeEvaluator::canBeHitThisTurn(const AttackPossibility & ap)
{
for(auto pos : ap.attack.attacker->getSurroundingHexes())
{
for(auto u : reachabilityMap[pos])
for(auto u : getOneTurnReachableUnits(0, pos))
{
if(u->unitSide() != ap.attack.attacker->unitSide())
{
@ -898,6 +901,22 @@ bool BattleExchangeEvaluator::canBeHitThisTurn(const AttackPossibility & ap)
return false;
}
void ReachabilityMapCache::update(const std::vector<battle::Units> & turnOrder, std::shared_ptr<HypotheticBattle> hb)
{
for(auto turn : turnOrder)
{
for(auto u : turn)
{
if(!vstd::contains(unitReachabilityMap, u->unitId()))
{
unitReachabilityMap[u->unitId()] = hb->getReachability(u);
}
}
}
hexReachabilityPerTurn.clear();
}
void BattleExchangeEvaluator::updateReachabilityMap(std::shared_ptr<HypotheticBattle> hb)
{
const int TURN_DEPTH = 2;
@ -905,26 +924,25 @@ void BattleExchangeEvaluator::updateReachabilityMap(std::shared_ptr<HypotheticBa
turnOrder.clear();
hb->battleGetTurnOrder(turnOrder, std::numeric_limits<int>::max(), TURN_DEPTH);
for(auto turn : turnOrder)
{
for(auto u : turn)
{
if(!vstd::contains(reachabilityCache, u->unitId()))
{
reachabilityCache[u->unitId()] = hb->getReachability(u);
}
}
}
for(BattleHex hex = BattleHex::TOP_LEFT; hex.isValid(); ++hex)
{
reachabilityMap[hex] = getOneTurnReachableUnits(0, hex);
}
reachabilityMap.update(turnOrder, hb);
}
std::vector<const battle::Unit *> BattleExchangeEvaluator::getOneTurnReachableUnits(uint8_t turn, BattleHex hex) const
const battle::Units & ReachabilityMapCache::getOneTurnReachableUnits(std::shared_ptr<CBattleInfoCallback> cb, std::shared_ptr<Environment> env, const std::vector<battle::Units> & turnOrder, uint8_t turn, BattleHex hex)
{
std::vector<const battle::Unit *> result;
auto & turnData = hexReachabilityPerTurn[turn];
if (!turnData.isValid[hex.toInt()])
{
turnData.hexes[hex.toInt()] = computeOneTurnReachableUnits(cb, env, turnOrder, turn, hex);
turnData.isValid.set(hex.toInt());
}
return turnData.hexes[hex.toInt()];
}
battle::Units ReachabilityMapCache::computeOneTurnReachableUnits(std::shared_ptr<CBattleInfoCallback> cb, std::shared_ptr<Environment> env, const std::vector<battle::Units> & turnOrder, uint8_t turn, BattleHex hex)
{
battle::Units result;
for(int i = 0; i < turnOrder.size(); i++, turn++)
{
@ -946,10 +964,10 @@ std::vector<const battle::Unit *> BattleExchangeEvaluator::getOneTurnReachableUn
auto unitSpeed = unit->getMovementRange(turn);
auto radius = unitSpeed * (turn + 1);
auto reachabilityIter = reachabilityCache.find(unit->unitId());
assert(reachabilityIter != reachabilityCache.end()); // missing updateReachabilityMap call?
auto reachabilityIter = unitReachabilityMap.find(unit->unitId());
assert(reachabilityIter != unitReachabilityMap.end()); // missing updateReachabilityMap call?
ReachabilityInfo unitReachability = reachabilityIter != reachabilityCache.end() ? reachabilityIter->second : turnBattle.getReachability(unit);
ReachabilityInfo unitReachability = reachabilityIter != unitReachabilityMap.end() ? reachabilityIter->second : turnBattle.getReachability(unit);
bool reachable = unitReachability.distances.at(hex.toInt()) <= radius;
@ -978,6 +996,11 @@ std::vector<const battle::Unit *> BattleExchangeEvaluator::getOneTurnReachableUn
return result;
}
const battle::Units & BattleExchangeEvaluator::getOneTurnReachableUnits(uint8_t turn, BattleHex hex) const
{
return reachabilityMap.getOneTurnReachableUnits(cb, env, turnOrder, turn, hex);
}
// avoid blocking path for stronger stack by weaker stack
bool BattleExchangeEvaluator::checkPositionBlocksOurStacks(HypotheticBattle & hb, const battle::Unit * activeUnit, BattleHex position)
{
@ -1029,9 +1052,11 @@ bool BattleExchangeEvaluator::checkPositionBlocksOurStacks(HypotheticBattle & hb
}
}
if(!reachable && std::count(reachabilityMap[hex].begin(), reachabilityMap[hex].end(), unit) > 1)
if(!reachable)
{
blockingScore += ratio * (enemyUnit ? BLOCKING_OWN_ATTACK_PENALTY : BLOCKING_OWN_MOVE_PENALTY);
auto reachableUnits = getOneTurnReachableUnits(0, hex);
if (std::count(reachableUnits.begin(), reachableUnits.end(), unit) > 1)
blockingScore += ratio * (enemyUnit ? BLOCKING_OWN_ATTACK_PENALTY : BLOCKING_OWN_MOVE_PENALTY);
}
}
}

View File

@ -112,24 +112,40 @@ private:
struct ReachabilityData
{
std::map<int, std::vector<const battle::Unit *>> units;
std::map<int, battle::Units> units;
// shooters which are within mellee attack and mellee units
std::vector<const battle::Unit *> melleeAccessible;
battle::Units melleeAccessible;
// far shooters
std::vector<const battle::Unit *> shooters;
battle::Units shooters;
std::set<uint32_t> enemyUnitsReachingAttacker;
};
class ReachabilityMapCache
{
struct PerTurnData{
std::bitset<GameConstants::BFIELD_SIZE> isValid;
std::array<battle::Units, GameConstants::BFIELD_SIZE> hexes;
};
std::map<uint32_t, ReachabilityInfo> unitReachabilityMap; // unit ID -> reachability
std::map<uint32_t, PerTurnData> hexReachabilityPerTurn;
//const ReachabilityInfo & update();
battle::Units computeOneTurnReachableUnits(std::shared_ptr<CBattleInfoCallback> cb, std::shared_ptr<Environment> env, const std::vector<battle::Units> & turnOrder, uint8_t turn, BattleHex hex);
public:
const battle::Units & getOneTurnReachableUnits(std::shared_ptr<CBattleInfoCallback> cb, std::shared_ptr<Environment> env, const std::vector<battle::Units> & turnOrder, uint8_t turn, BattleHex hex);
void update(const std::vector<battle::Units> & turnOrder, std::shared_ptr<HypotheticBattle> hb);
};
class BattleExchangeEvaluator
{
private:
std::shared_ptr<CBattleInfoCallback> cb;
std::shared_ptr<Environment> env;
std::map<uint32_t, ReachabilityInfo> reachabilityCache;
std::map<BattleHex, std::vector<const battle::Unit *>> reachabilityMap;
mutable ReachabilityMapCache reachabilityMap;
std::vector<battle::Units> turnOrder;
float negativeEffectMultiplier;
int simulationTurnsCount;
@ -142,7 +158,7 @@ private:
PotentialTargets & targets,
DamageCache & damageCache,
std::shared_ptr<HypotheticBattle> hb,
std::vector<const battle::Unit *> additionalUnits = {}) const;
battle::Units additionalUnits = {}) const;
bool canBeHitThisTurn(const AttackPossibility & ap);
@ -169,7 +185,7 @@ public:
DamageCache & damageCache,
std::shared_ptr<HypotheticBattle> hb) const;
std::vector<const battle::Unit *> getOneTurnReachableUnits(uint8_t turn, BattleHex hex) const;
const battle::Units & getOneTurnReachableUnits(uint8_t turn, BattleHex hex) const;
void updateReachabilityMap(std::shared_ptr<HypotheticBattle> hb);
ReachabilityData getExchangeUnits(
@ -177,7 +193,7 @@ public:
uint8_t turn,
PotentialTargets & targets,
std::shared_ptr<HypotheticBattle> hb,
std::vector<const battle::Unit *> additionalUnits = {}) const;
battle::Units additionalUnits = {}) const;
bool checkPositionBlocksOurStacks(HypotheticBattle & hb, const battle::Unit * unit, BattleHex position);
@ -187,7 +203,7 @@ public:
DamageCache & damageCache,
std::shared_ptr<HypotheticBattle> hb);
std::vector<const battle::Unit *> getAdjacentUnits(const battle::Unit * unit) const;
battle::Units getAdjacentUnits(const battle::Unit * unit) const;
float getPositiveEffectMultiplier() const { return 1; }
float getNegativeEffectMultiplier() const { return negativeEffectMultiplier; }

View File

@ -14,7 +14,7 @@ class PotentialTargets
{
public:
std::vector<AttackPossibility> possibleAttacks;
std::vector<const battle::Unit *> unreachableEnemies;
battle::Units unreachableEnemies;
PotentialTargets(){};
PotentialTargets(

View File

@ -670,15 +670,15 @@ namespace vstd
return false;
}
template<typename T>
void removeDuplicates(std::vector<T> &vec)
template <typename Container>
void removeDuplicates(Container &vec)
{
std::sort(vec.begin(), vec.end());
vec.erase(std::unique(vec.begin(), vec.end()), vec.end());
}
template <typename T>
void concatenate(std::vector<T> &dest, const std::vector<T> &src)
template <typename Container>
void concatenate(Container &dest, const Container &src)
{
dest.reserve(dest.size() + src.size());
dest.insert(dest.end(), src.begin(), src.end());

View File

@ -22,6 +22,7 @@ class SpellSchool;
namespace battle
{
class Unit;
using Units = boost::container::small_vector<const Unit *, 4>;
}
namespace spells
@ -65,7 +66,7 @@ public:
virtual void getCasterName(MetaString & text) const = 0;
///full default text
virtual void getCastDescription(const Spell * spell, const std::vector<const battle::Unit *> & attacked, MetaString & text) const = 0;
virtual void getCastDescription(const Spell * spell, const battle::Units & attacked, MetaString & text) const = 0;
virtual void spendMana(ServerCallback * server, const int32_t spellCost) const = 0;

View File

@ -93,7 +93,9 @@ public:
h & static_cast<CArtifactSet&>(*this);
h & _armyObj;
h & experience;
BONUS_TREE_DESERIALIZATION_FIX
if(!h.saving)
deserializationFix();
}
void serializeJson(JsonSerializeFormat & handler);

View File

@ -35,6 +35,7 @@ CStack::CStack(const CStackInstance * Base, const PlayerColor & O, int I, Battle
side(Side)
{
health.init(); //???
doubleWideCached = battle::CUnitState::doubleWide();
}
CStack::CStack():
@ -55,6 +56,7 @@ CStack::CStack(const CStackBasicDescriptor * stack, const PlayerColor & O, int I
side(Side)
{
health.init(); //???
doubleWideCached = battle::CUnitState::doubleWide();
}
void CStack::localInit(BattleInfo * battleInfo)
@ -296,7 +298,7 @@ BattleHexArray CStack::meleeAttackHexes(const battle::Unit * attacker, const bat
bool CStack::isMeleeAttackPossible(const battle::Unit * attacker, const battle::Unit * defender, BattleHex attackerPos, BattleHex defenderPos)
{
if(defender->hasBonusOfType(BonusType::INVINCIBLE))
if(defender->isInvincible())
return false;
return !meleeAttackHexes(attacker, defender, attackerPos, defenderPos).empty();
@ -404,4 +406,30 @@ void CStack::spendMana(ServerCallback * server, const int spellCost) const
server->apply(ssp);
}
void CStack::postDeserialize(const CArmedInstance * army, const SlotID & extSlot)
{
if(extSlot == SlotID::COMMANDER_SLOT_PLACEHOLDER)
{
const auto * hero = dynamic_cast<const CGHeroInstance *>(army);
assert(hero);
base = hero->commander;
}
else if(slot == SlotID::SUMMONED_SLOT_PLACEHOLDER || slot == SlotID::ARROW_TOWERS_SLOT || slot == SlotID::WAR_MACHINES_SLOT)
{
//no external slot possible, so no base stack
base = nullptr;
}
else if(!army || extSlot == SlotID() || !army->hasStackAtSlot(extSlot))
{
base = nullptr;
logGlobal->warn("%s doesn't have a base stack!", typeID.toEntity(VLC)->getNameSingularTranslated());
}
else
{
base = &army->getStack(extSlot);
}
doubleWideCached = battle::CUnitState::doubleWide();
}
VCMI_LIB_NAMESPACE_END

View File

@ -23,7 +23,7 @@ struct BattleStackAttacked;
class BattleInfo;
//Represents STACK_BATTLE nodes
class DLL_LINKAGE CStack : public CBonusSystemNode, public battle::CUnitState, public battle::IUnitEnvironment
class DLL_LINKAGE CStack final : public CBonusSystemNode, public battle::CUnitState, public battle::IUnitEnvironment
{
private:
ui32 ID = -1; //unique ID of stack
@ -36,6 +36,9 @@ private:
SlotID slot; //slot - position in garrison (may be 255 for neutrals/called creatures)
bool doubleWideCached = false;
void postDeserialize(const CArmedInstance * army, const SlotID & extSlot);
public:
const CStackInstance * base = nullptr; //garrison slot from which stack originates (nullptr for war machines, summoned cres, etc)
@ -77,6 +80,7 @@ public:
BattleSide unitSide() const override;
PlayerColor unitOwner() const override;
SlotID unitSlot() const override;
bool doubleWide() const override { return doubleWideCached;};
std::string getDescription() const override;
@ -119,26 +123,7 @@ public:
h & army;
h & extSlot;
if(extSlot == SlotID::COMMANDER_SLOT_PLACEHOLDER)
{
const auto * hero = dynamic_cast<const CGHeroInstance *>(army);
assert(hero);
base = hero->commander;
}
else if(slot == SlotID::SUMMONED_SLOT_PLACEHOLDER || slot == SlotID::ARROW_TOWERS_SLOT || slot == SlotID::WAR_MACHINES_SLOT)
{
//no external slot possible, so no base stack
base = nullptr;
}
else if(!army || extSlot == SlotID() || !army->hasStackAtSlot(extSlot))
{
base = nullptr;
logGlobal->warn("%s doesn't have a base stack!", typeID.toEntity(VLC)->getNameSingularTranslated());
}
else
{
base = &army->getStack(extSlot);
}
postDeserialize(army, extSlot);
}
}
@ -146,4 +131,4 @@ private:
const BattleInfo * battle; //do not serialize
};
VCMI_LIB_NAMESPACE_END
VCMI_LIB_NAMESPACE_END

View File

@ -113,11 +113,11 @@ public:
}
void clear() noexcept;
inline void erase(size_type index) noexcept
inline void erase(BattleHex target) noexcept
{
assert(index < totalSize);
internalStorage[index] = BattleHex::INVALID;
presenceFlags[index] = 0;
assert(contains(target));
vstd::erase(internalStorage, target);
presenceFlags[target.toInt()] = 0;
}
void erase(iterator first, iterator last) noexcept;
inline void pop_back() noexcept
@ -160,17 +160,23 @@ public:
/// get (precomputed) all possible surrounding tiles
static const BattleHexArray & getAllNeighbouringTiles(BattleHex hex) noexcept
{
assert(hex.isValid());
static const BattleHexArray invalid;
return allNeighbouringTiles[hex.toInt()];
if (hex.isValid())
return allNeighbouringTiles[hex.toInt()];
else
return invalid;
}
/// get (precomputed) only valid and available surrounding tiles
static const BattleHexArray & getNeighbouringTiles(BattleHex hex) noexcept
{
assert(hex.isValid());
static const BattleHexArray invalid;
return neighbouringTiles[hex.toInt()];
if (hex.isValid())
return neighbouringTiles[hex.toInt()];
else
return invalid;
}
/// get (precomputed) only valid and available surrounding tiles for double wide creatures

View File

@ -27,7 +27,7 @@ BattleStateInfoForRetreat::BattleStateInfoForRetreat():
{
}
uint64_t getFightingStrength(const std::vector<const battle::Unit *> & stacks, const CGHeroInstance * hero = nullptr)
uint64_t getFightingStrength(const battle::Units & stacks, const CGHeroInstance * hero = nullptr)
{
uint64_t result = 0;

View File

@ -16,6 +16,7 @@ VCMI_LIB_NAMESPACE_BEGIN
namespace battle
{
class Unit;
using Units = boost::container::small_vector<const Unit *, 4>;
}
class CGHeroInstance;
@ -27,8 +28,8 @@ public:
bool canSurrender;
bool isLastTurnBeforeDie;
BattleSide ourSide;
std::vector<const battle::Unit *> ourStacks;
std::vector<const battle::Unit *> enemyStacks;
battle::Units ourStacks;
battle::Units enemyStacks;
const CGHeroInstance * ourHero;
const CGHeroInstance * enemyHero;
int turnsSkippedByDefense;

View File

@ -383,11 +383,9 @@ battle::Units CBattleInfoCallback::battleAliveUnits(BattleSide side) const
using namespace battle;
//T is battle::Unit descendant
template <typename T>
const T * takeOneUnit(std::vector<const T*> & allUnits, const int turn, BattleSide & sideThatLastMoved, int phase)
static const battle::Unit * takeOneUnit(battle::Units & allUnits, const int turn, BattleSide & sideThatLastMoved, int phase)
{
const T * returnedUnit = nullptr;
const battle::Unit * returnedUnit = nullptr;
size_t currentUnitIndex = 0;
for(size_t i = 0; i < allUnits.size(); i++)
@ -677,7 +675,7 @@ bool CBattleInfoCallback::battleCanAttack(const battle::Unit * stack, const batt
if (!stack || !target)
return false;
if(target->hasBonusOfType(BonusType::INVINCIBLE))
if(target->isInvincible())
return false;
if(!battleMatchOwner(stack, target))
@ -746,7 +744,7 @@ bool CBattleInfoCallback::battleCanShoot(const battle::Unit * attacker, BattleHe
if(!defender)
return false;
if(defender->hasBonusOfType(BonusType::INVINCIBLE))
if(defender->isInvincible())
return false;
}
@ -812,7 +810,7 @@ DamageEstimation CBattleInfoCallback::battleEstimateDamage(const BattleAttackInf
if (!bai.defender->ableToRetaliate())
return ret;
if (bai.attacker->hasBonusOfType(BonusType::BLOCKS_RETALIATION) || bai.attacker->hasBonusOfType(BonusType::INVINCIBLE))
if (bai.attacker->hasBonusOfType(BonusType::BLOCKS_RETALIATION) || bai.attacker->isInvincible())
return ret;
//TODO: rewrite using boost::numeric::interval
@ -1168,7 +1166,7 @@ std::pair<const battle::Unit *, BattleHex> CBattleInfoCallback::getNearestStack(
std::vector<DistStack> stackPairs;
std::vector<const battle::Unit *> possible = battleGetUnitsIf([=](const battle::Unit * unit)
battle::Units possible = battleGetUnitsIf([=](const battle::Unit * unit)
{
return unit->isValidTarget(false) && unit != closest;
});
@ -1355,14 +1353,9 @@ AttackableTiles CBattleInfoCallback::getPotentiallyAttackableHexes(
if(attacker->hasBonusOfType(BonusType::WIDE_BREATH))
{
BattleHexArray hexes = destinationTile.getNeighbouringTiles();
for(int i = 0; i < hexes.size(); i++)
{
if(hexes.at(i) == attackOriginHex)
{
hexes.erase(i);
i = 0;
}
}
if (hexes.contains(attackOriginHex))
hexes.erase(attackOriginHex);
for(BattleHex tile : hexes)
{
//friendly stacks can also be damaged by Dragon Breath
@ -1436,7 +1429,7 @@ AttackableTiles CBattleInfoCallback::getPotentiallyShootableHexes(const battle::
return at;
}
std::vector<const battle::Unit*> CBattleInfoCallback::getAttackedBattleUnits(
battle::Units CBattleInfoCallback::getAttackedBattleUnits(
const battle::Unit * attacker,
const battle::Unit * defender,
BattleHex destinationTile,
@ -1444,7 +1437,7 @@ std::vector<const battle::Unit*> CBattleInfoCallback::getAttackedBattleUnits(
BattleHex attackerPos,
BattleHex defenderPos) const
{
std::vector<const battle::Unit*> units;
battle::Units units;
RETURN_IF_NOT_BATTLE(units);
if(attackerPos == BattleHex::INVALID)
@ -1716,18 +1709,22 @@ bool CBattleInfoCallback::battleIsUnitBlocked(const battle::Unit * unit) const
return false;
}
std::set<const battle::Unit *> CBattleInfoCallback::battleAdjacentUnits(const battle::Unit * unit) const
battle::Units CBattleInfoCallback::battleAdjacentUnits(const battle::Unit * unit) const
{
std::set<const battle::Unit *> ret;
RETURN_IF_NOT_BATTLE(ret);
RETURN_IF_NOT_BATTLE({});
for(auto hex : unit->getSurroundingHexes())
const auto & hexes = unit->getSurroundingHexes();
const auto & units = battleGetUnitsIf([=](const battle::Unit * unit)
{
if(const auto * neighbour = battleGetUnitByPos(hex, true))
ret.insert(neighbour);
}
const auto & unitHexes = unit->getHexes();
for (const auto & hex : unitHexes)
if (hexes.contains(hex))
return true;
return false;
});
return ret;
return units;
}
SpellID CBattleInfoCallback::getRandomBeneficialSpell(vstd::RNG & rand, const battle::Unit * caster, const battle::Unit * subject) const

View File

@ -95,7 +95,7 @@ public:
bool battleCanShoot(const battle::Unit * attacker, BattleHex dest) const; //determines if stack with given ID shoot at the selected destination
bool battleCanShoot(const battle::Unit * attacker) const; //determines if stack with given ID shoot in principle
bool battleIsUnitBlocked(const battle::Unit * unit) const; //returns true if there is neighboring enemy stack
std::set<const battle::Unit *> battleAdjacentUnits(const battle::Unit * unit) const;
battle::Units battleAdjacentUnits(const battle::Unit * unit) const;
DamageEstimation calculateDmgRange(const BattleAttackInfo & info) const;
@ -147,7 +147,7 @@ public:
AttackableTiles getPotentiallyShootableHexes(const battle::Unit* attacker, BattleHex destinationTile, BattleHex attackerPos) const;
std::vector<const battle::Unit *> getAttackedBattleUnits(
battle::Units getAttackedBattleUnits(
const battle::Unit* attacker,
const battle::Unit * defender,
BattleHex destinationTile,
@ -173,4 +173,4 @@ protected:
BattleHexArray getStoppers(BattleSide whichSidePerspective) const; //get hexes with stopping obstacles (quicksands)
};
VCMI_LIB_NAMESPACE_END
VCMI_LIB_NAMESPACE_END

View File

@ -464,7 +464,7 @@ void CUnitState::getCasterName(MetaString & text) const
addNameReplacement(text, true);
}
void CUnitState::getCastDescription(const spells::Spell * spell, const std::vector<const Unit *> & attacked, MetaString & text) const
void CUnitState::getCastDescription(const spells::Spell * spell, const battle::Units & attacked, MetaString & text) const
{
text.appendLocalString(EMetaText::GENERAL_TXT, 565);//The %s casts %s
//todo: use text 566 for single creature
@ -700,6 +700,11 @@ bool CUnitState::isHypnotized() const
return bonusCache.getBonusValue(UnitBonusValuesProxy::HYPNOTIZED);
}
bool CUnitState::isInvincible() const
{
return bonusCache.getBonusValue(UnitBonusValuesProxy::INVINCIBLE);
}
int CUnitState::getTotalAttacks(bool ranged) const
{
return 1 + (ranged ?

View File

@ -183,7 +183,7 @@ public:
PlayerColor getCasterOwner() const override;
const CGHeroInstance * getHeroCaster() const override;
void getCasterName(MetaString & text) const override;
void getCastDescription(const spells::Spell * spell, const std::vector<const Unit *> & attacked, MetaString & text) const override;
void getCastDescription(const spells::Spell * spell, const battle::Units & attacked, MetaString & text) const override;
int32_t manaLimit() const override;
bool ableToRetaliate() const override;
@ -193,6 +193,7 @@ public:
bool isValidTarget(bool allowDead = false) const override;
bool isHypnotized() const override;
bool isInvincible() const override;
bool isClone() const override;
bool hasClone() const override;
@ -269,7 +270,7 @@ private:
void reset();
};
class DLL_LINKAGE CUnitStateDetached : public CUnitState
class DLL_LINKAGE CUnitStateDetached final : public CUnitState
{
public:
explicit CUnitStateDetached(const IUnitInfo * unit_, const IBonusBearer * bonus_);

View File

@ -27,7 +27,7 @@ namespace battle
{
class IUnitInfo;
class Unit;
using Units = std::vector<const Unit *>;
using Units = boost::container::small_vector<const Unit *, 4>;
using UnitFilter = std::function<bool(const Unit *)>;
}

View File

@ -107,25 +107,51 @@ const BattleHexArray & Unit::getHexes(BattleHex assumedPos) const
return getHexes(assumedPos, doubleWide(), unitSide());
}
BattleHexArray::ArrayOfBattleHexArrays Unit::precomputeUnitHexes(BattleSide side, bool twoHex)
{
BattleHexArray::ArrayOfBattleHexArrays result;
for (BattleHex assumedPos = 0; assumedPos < GameConstants::BFIELD_SIZE; ++assumedPos)
{
BattleHexArray hexes;
hexes.insert(assumedPos);
if(twoHex)
hexes.insert(occupiedHex(assumedPos, twoHex, side));
result[assumedPos.toInt()] = std::move(hexes);
}
return result;
}
const BattleHexArray & Unit::getHexes(BattleHex assumedPos, bool twoHex, BattleSide side)
{
static BattleHexArray::ArrayOfBattleHexArrays precomputed[4];
int index = side == BattleSide::ATTACKER ? 0 : 2;
static const std::array<BattleHexArray::ArrayOfBattleHexArrays, 4> precomputed = {
precomputeUnitHexes(BattleSide::ATTACKER, false),
precomputeUnitHexes(BattleSide::ATTACKER, true),
precomputeUnitHexes(BattleSide::DEFENDER, false),
precomputeUnitHexes(BattleSide::DEFENDER, true),
};
if(!precomputed[index + twoHex][assumedPos.toInt()].empty())
static const std::array<BattleHexArray, 5> invalidHexes = {
BattleHexArray({BattleHex( 0)}),
BattleHexArray({BattleHex(-1)}),
BattleHexArray({BattleHex(-2)}),
BattleHexArray({BattleHex(-3)}),
BattleHexArray({BattleHex(-4)})
};
if (assumedPos.isValid())
{
int index = side == BattleSide::ATTACKER ? 0 : 2;
return precomputed[index + twoHex][assumedPos.toInt()];
// first run, compute
BattleHexArray hexes;
hexes.insert(assumedPos);
if(twoHex)
hexes.insert(occupiedHex(assumedPos, twoHex, side));
precomputed[index + twoHex][assumedPos.toInt()] = std::move(hexes);
return precomputed[index + twoHex][assumedPos.toInt()];
}
else
{
// Towers and such
return invalidHexes.at(-assumedPos.toInt());
}
}
BattleHex Unit::occupiedHex() const

View File

@ -64,6 +64,8 @@ class CUnitState;
class DLL_LINKAGE Unit : public IUnitInfo, public spells::Caster, public virtual IBonusBearer, public ACreature
{
static BattleHexArray::ArrayOfBattleHexArrays precomputeUnitHexes(BattleSide side, bool twoHex);
public:
virtual ~Unit();
@ -85,6 +87,7 @@ public:
virtual bool isValidTarget(bool allowDead = false) const = 0; //non-turret non-ghost stacks (can be attacked or be object of magic effect)
virtual bool isHypnotized() const = 0;
virtual bool isInvincible() const = 0;
virtual bool isClone() const = 0;
virtual bool hasClone() const = 0;

View File

@ -203,6 +203,7 @@ const UnitBonusValuesProxy::SelectorsArray * UnitBonusValuesProxy::generateSelec
Selector::type()(BonusType::FORGETFULL),//FORGETFULL,
Selector::type()(BonusType::FREE_SHOOTING).Or(Selector::type()(BonusType::SIEGE_WEAPON)),//HAS_FREE_SHOOTING,
Selector::type()(BonusType::STACK_HEALTH),//STACK_HEALTH,
Selector::type()(BonusType::INVINCIBLE),//INVINCIBLE,
Selector::type()(BonusType::NONE).And(Selector::source(BonusSource::SPELL_EFFECT, BonusSourceID(SpellID(SpellID::CLONE))))
};

View File

@ -116,6 +116,7 @@ public:
FORGETFULL,
HAS_FREE_SHOOTING,
STACK_HEALTH,
INVINCIBLE,
CLONE_MARKER,

View File

@ -856,7 +856,7 @@ void CGHeroInstance::getCasterName(MetaString & text) const
text.replaceRawString(getNameTranslated());
}
void CGHeroInstance::getCastDescription(const spells::Spell * spell, const std::vector<const battle::Unit *> & attacked, MetaString & text) const
void CGHeroInstance::getCastDescription(const spells::Spell * spell, const battle::Units & attacked, MetaString & text) const
{
const bool singleTarget = attacked.size() == 1;
const int textIndex = singleTarget ? 195 : 196;

View File

@ -309,7 +309,7 @@ public:
const CGHeroInstance * getHeroCaster() const override;
void getCasterName(MetaString & text) const override;
void getCastDescription(const spells::Spell * spell, const std::vector<const battle::Unit *> & attacked, MetaString & text) const override;
void getCastDescription(const spells::Spell * spell, const battle::Units & attacked, MetaString & text) const override;
void spendMana(ServerCallback * server, const int spellCost) const override;
void attachToBoat(CGBoat* newBoat);

View File

@ -1503,6 +1503,11 @@ void NewObject::applyGs(CGameState *gs)
gs->map->addBlockVisTiles(newObject);
gs->map->calculateGuardingGreaturePositions();
// attach newly spawned wandering monster to global bonus system node
auto newArmy = dynamic_cast<CArmedInstance*>(newObject);
if (newArmy)
newArmy->whatShouldBeAttached().attachTo(newArmy->whereShouldBeAttached(gs));
logGlobal->debug("Added object id=%d; name=%s", newObject->id, newObject->getObjectName());
}

View File

@ -49,7 +49,7 @@ int32_t AbilityCaster::getEffectLevel(const Spell * spell) const
return getSpellSchoolLevel(spell);
}
void AbilityCaster::getCastDescription(const Spell * spell, const std::vector<const battle::Unit*> & attacked, MetaString & text) const
void AbilityCaster::getCastDescription(const Spell * spell, const battle::Units & attacked, MetaString & text) const
{
//do nothing
}

View File

@ -25,7 +25,7 @@ public:
int32_t getSpellSchoolLevel(const Spell * spell, SpellSchool * outSelectedSchool = nullptr) const override;
int32_t getEffectLevel(const Spell * spell) const override;
void getCastDescription(const Spell * spell, const std::vector<const battle::Unit *> & attacked, MetaString & text) const override;
void getCastDescription(const Spell * spell, const battle::Units & attacked, MetaString & text) const override;
void spendMana(ServerCallback * server, const int32_t spellCost) const override;
private:

View File

@ -231,7 +231,7 @@ bool BattleSpellMechanics::canBeCastAt(const Target & target, Problem & problem)
if(mainTarget && mainTarget == caster)
return false; // can't cast on self
if(mainTarget && mainTarget->hasBonusOfType(BonusType::INVINCIBLE) && !getSpell()->getPositiveness())
if(mainTarget && mainTarget->isInvincible() && !getSpell()->getPositiveness())
return false;
}
else if(getSpell()->canCastOnlyOnSelf())
@ -259,7 +259,7 @@ std::vector<const CStack *> BattleSpellMechanics::getAffectedStacks(const Target
for(const Destination & dest : all)
{
if(dest.unitValue && !dest.unitValue->hasBonusOfType(BonusType::INVINCIBLE))
if(dest.unitValue && !dest.unitValue->isInvincible())
{
//FIXME: remove and return battle::Unit
stacks.insert(battle()->battleGetStackByID(dest.unitValue->unitId(), false));
@ -473,7 +473,7 @@ std::set<const battle::Unit *> BattleSpellMechanics::collectTargets() const
return result;
}
void BattleSpellMechanics::doRemoveEffects(ServerCallback * server, const std::vector<const battle::Unit *> & targets, const CSelector & selector)
void BattleSpellMechanics::doRemoveEffects(ServerCallback * server, const battle::Units & targets, const CSelector & selector)
{
SetStackEffect sse;
sse.battleID = battle()->getBattle()->getBattleID();

View File

@ -18,6 +18,11 @@ VCMI_LIB_NAMESPACE_BEGIN
struct BattleSpellCast;
namespace battle
{
using Units = boost::container::small_vector<const Unit *, 4>;
}
namespace spells
{
@ -66,14 +71,14 @@ private:
std::shared_ptr<effects::Effects> effects;
std::shared_ptr<IReceptiveCheck> targetCondition;
std::vector<const battle::Unit *> affectedUnits;
battle::Units affectedUnits;
effects::Effects::EffectsToApply effectsToApply;
void beforeCast(BattleSpellCast & sc, vstd::RNG & rng, const Target & target);
std::set<const battle::Unit *> collectTargets() const;
void doRemoveEffects(ServerCallback * server, const std::vector<const battle::Unit *> & targets, const CSelector & selector);
void doRemoveEffects(ServerCallback * server, const battle::Units & targets, const CSelector & selector);
BattleHexArray spellRangeInHexes(BattleHex centralHex) const;

View File

@ -57,7 +57,7 @@ void BonusCaster::getCasterName(MetaString & text) const
}
}
void BonusCaster::getCastDescription(const Spell * spell, const std::vector<const battle::Unit*> & attacked, MetaString & text) const
void BonusCaster::getCastDescription(const Spell * spell, const battle::Units & attacked, MetaString & text) const
{
const bool singleTarget = attacked.size() == 1;
const int textIndex = singleTarget ? 195 : 196;

View File

@ -26,7 +26,7 @@ public:
virtual ~BonusCaster();
void getCasterName(MetaString & text) const override;
void getCastDescription(const Spell * spell, const std::vector<const battle::Unit *> & attacked, MetaString & text) const override;
void getCastDescription(const Spell * spell, const battle::Units & attacked, MetaString & text) const override;
void spendMana(ServerCallback * server, const int spellCost) const override;
private:

View File

@ -434,7 +434,7 @@ int64_t CSpell::adjustRawDamage(const spells::Caster * caster, const battle::Uni
}
//invincible
if(bearer->hasBonusOfType(BonusType::INVINCIBLE))
if(affectedCreature->isInvincible())
ret = 0;
}
ret = caster->getSpellBonus(this, ret, affectedCreature);

View File

@ -75,7 +75,7 @@ void SilentCaster::getCasterName(MetaString & text) const
logGlobal->debug("Unexpected call to SilentCaster::getCasterName");
}
void SilentCaster::getCastDescription(const Spell * spell, const std::vector<const battle::Unit *> & attacked, MetaString & text) const
void SilentCaster::getCastDescription(const Spell * spell, const battle::Units & attacked, MetaString & text) const
{
//do nothing
}

View File

@ -25,7 +25,7 @@ public:
SilentCaster(PlayerColor owner_, const Caster * caster);
void getCasterName(MetaString & text) const override;
void getCastDescription(const Spell * spell, const std::vector<const battle::Unit *> & attacked, MetaString & text) const override;
void getCastDescription(const Spell * spell, const battle::Units & attacked, MetaString & text) const override;
void spendMana(ServerCallback * server, const int spellCost) const override;
PlayerColor getCasterOwner() const override;
int32_t manaLimit() const override;

View File

@ -106,7 +106,7 @@ void ProxyCaster::getCasterName(MetaString & text) const
actualCaster->getCasterName(text);
}
void ProxyCaster::getCastDescription(const Spell * spell, const std::vector<const battle::Unit*> & attacked, MetaString & text) const
void ProxyCaster::getCastDescription(const Spell * spell, const battle::Units & attacked, MetaString & text) const
{
if(actualCaster)
actualCaster->getCastDescription(spell, attacked, text);

View File

@ -33,7 +33,7 @@ public:
int64_t getEffectValue(const Spell * spell) const override;
PlayerColor getCasterOwner() const override;
void getCasterName(MetaString & text) const override;
void getCastDescription(const Spell * spell, const std::vector<const battle::Unit *> & attacked, MetaString & text) const override;
void getCastDescription(const Spell * spell, const battle::Units & attacked, MetaString & text) const override;
void spendMana(ServerCallback * server, const int32_t spellCost) const override;
const CGHeroInstance * getHeroCaster() const override;
int32_t manaLimit() const override;

View File

@ -223,7 +223,8 @@ EffectTarget UnitEffect::transformTargetByChain(const Mechanics * m, const Targe
effectTarget.emplace_back();
for(auto hex : battle::Unit::getHexes(unit->getPosition(), unit->doubleWide(), unit->unitSide()))
possibleHexes.erase(hex.toInt());
if (possibleHexes.contains(hex))
possibleHexes.erase(hex);
if(possibleHexes.empty())
break;
@ -278,4 +279,4 @@ void UnitEffect::serializeJsonEffect(JsonSerializeFormat & handler)
}
}
VCMI_LIB_NAMESPACE_END
VCMI_LIB_NAMESPACE_END

View File

@ -276,7 +276,7 @@ bool BattleActionProcessor::doAttackAction(const CBattleInfoCallback & battle, c
for (int i = 0; i < totalAttacks; ++i)
{
//first strike
if(i == 0 && firstStrike && retaliation && !stack->hasBonusOfType(BonusType::BLOCKS_RETALIATION) && !stack->hasBonusOfType(BonusType::INVINCIBLE))
if(i == 0 && firstStrike && retaliation && !stack->hasBonusOfType(BonusType::BLOCKS_RETALIATION) && !stack->isInvincible())
{
makeAttack(battle, destinationStack, stack, 0, stack->getPosition(), true, false, true);
}
@ -303,7 +303,7 @@ bool BattleActionProcessor::doAttackAction(const CBattleInfoCallback & battle, c
//we check retaliation twice, so if it unblocked during attack it will work only on next attack
if(stack->alive()
&& !stack->hasBonusOfType(BonusType::BLOCKS_RETALIATION)
&& !stack->hasBonusOfType(BonusType::INVINCIBLE)
&& !stack->isInvincible()
&& (i == 0 && !firstStrike)
&& retaliation && destinationStack->ableToRetaliate())
{

View File

@ -76,6 +76,11 @@ public:
return hasBonusOfType(BonusType::HYPNOTIZED);
}
bool isInvincible() const override
{
return hasBonusOfType(BonusType::INVINCIBLE);
}
void redirectBonusesToFake()
{
ON_CALL(*this, getAllBonuses(_, _, _)).WillByDefault(Invoke(&bonusFake, &BonusBearerMock::getAllBonuses));

View File

@ -28,7 +28,7 @@ public:
MOCK_CONST_METHOD1(getEffectValue, int64_t(const spells::Spell *));
MOCK_CONST_METHOD0(getCasterOwner, PlayerColor());
MOCK_CONST_METHOD1(getCasterName, void(MetaString &));
MOCK_CONST_METHOD3(getCastDescription, void(const spells::Spell *, const std::vector<const battle::Unit *> &, MetaString &));
MOCK_CONST_METHOD3(getCastDescription, void(const spells::Spell *, const battle::Units &, MetaString &));
MOCK_CONST_METHOD2(spendMana, void(ServerCallback *, const int32_t));
MOCK_CONST_METHOD0(manaLimit, int32_t());
MOCK_CONST_METHOD0(getHeroCaster, CGHeroInstance*());
@ -58,6 +58,7 @@ public:
MOCK_CONST_METHOD1(isValidTarget, bool(bool));
MOCK_CONST_METHOD0(isHypnotized, bool());
MOCK_CONST_METHOD0(isInvincible, bool());
MOCK_CONST_METHOD0(isClone, bool());
MOCK_CONST_METHOD0(hasClone, bool());
MOCK_CONST_METHOD0(canCast, bool());