mirror of
https://github.com/vcmi/vcmi.git
synced 2024-12-22 22:13:35 +02:00
Fixed potential thread races in Battle AI
This commit is contained in:
parent
f2d870e651
commit
b7efa6c8cc
@ -602,10 +602,10 @@ bool BattleEvaluator::attemptCastingSpell(const CStack * activeStack)
|
|||||||
ps.value = scoreEvaluator.evaluateExchange(*cachedAttack, 0, *targets, innerCache, state);
|
ps.value = scoreEvaluator.evaluateExchange(*cachedAttack, 0, *targets, innerCache, state);
|
||||||
}
|
}
|
||||||
|
|
||||||
for(auto unit : allUnits)
|
for(const auto & unit : allUnits)
|
||||||
{
|
{
|
||||||
auto newHealth = unit->getAvailableHealth();
|
auto newHealth = unit->getAvailableHealth();
|
||||||
auto oldHealth = healthOfStack[unit->unitId()];
|
auto oldHealth = vstd::find_or(healthOfStack, unit->unitId(), 0); // old health value may not exist for newly summoned units
|
||||||
|
|
||||||
if(oldHealth != newHealth)
|
if(oldHealth != newHealth)
|
||||||
{
|
{
|
||||||
@ -732,6 +732,3 @@ void BattleEvaluator::print(const std::string & text) const
|
|||||||
{
|
{
|
||||||
logAi->trace("%s Battle AI[%p]: %s", playerID.toString(), this, text);
|
logAi->trace("%s Battle AI[%p]: %s", playerID.toString(), this, text);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -390,7 +390,7 @@ ReachabilityData BattleExchangeEvaluator::getExchangeUnits(
|
|||||||
const AttackPossibility & ap,
|
const AttackPossibility & ap,
|
||||||
uint8_t turn,
|
uint8_t turn,
|
||||||
PotentialTargets & targets,
|
PotentialTargets & targets,
|
||||||
std::shared_ptr<HypotheticBattle> hb)
|
std::shared_ptr<HypotheticBattle> hb) const
|
||||||
{
|
{
|
||||||
ReachabilityData result;
|
ReachabilityData result;
|
||||||
|
|
||||||
@ -402,7 +402,7 @@ ReachabilityData BattleExchangeEvaluator::getExchangeUnits(
|
|||||||
|
|
||||||
for(auto hex : hexes)
|
for(auto hex : hexes)
|
||||||
{
|
{
|
||||||
vstd::concatenate(allReachableUnits, turn == 0 ? reachabilityMap[hex] : getOneTurnReachableUnits(turn, hex));
|
vstd::concatenate(allReachableUnits, turn == 0 ? reachabilityMap.at(hex) : getOneTurnReachableUnits(turn, hex));
|
||||||
}
|
}
|
||||||
|
|
||||||
vstd::removeDuplicates(allReachableUnits);
|
vstd::removeDuplicates(allReachableUnits);
|
||||||
@ -481,7 +481,7 @@ float BattleExchangeEvaluator::evaluateExchange(
|
|||||||
uint8_t turn,
|
uint8_t turn,
|
||||||
PotentialTargets & targets,
|
PotentialTargets & targets,
|
||||||
DamageCache & damageCache,
|
DamageCache & damageCache,
|
||||||
std::shared_ptr<HypotheticBattle> hb)
|
std::shared_ptr<HypotheticBattle> hb) const
|
||||||
{
|
{
|
||||||
BattleScore score = calculateExchange(ap, turn, targets, damageCache, hb);
|
BattleScore score = calculateExchange(ap, turn, targets, damageCache, hb);
|
||||||
|
|
||||||
@ -502,7 +502,7 @@ BattleScore BattleExchangeEvaluator::calculateExchange(
|
|||||||
uint8_t turn,
|
uint8_t turn,
|
||||||
PotentialTargets & targets,
|
PotentialTargets & targets,
|
||||||
DamageCache & damageCache,
|
DamageCache & damageCache,
|
||||||
std::shared_ptr<HypotheticBattle> hb)
|
std::shared_ptr<HypotheticBattle> hb) const
|
||||||
{
|
{
|
||||||
#if BATTLE_TRACE_LEVEL>=1
|
#if BATTLE_TRACE_LEVEL>=1
|
||||||
logAi->trace("Battle exchange at %d", ap.attack.shooting ? ap.dest.hex : ap.from.hex);
|
logAi->trace("Battle exchange at %d", ap.attack.shooting ? ap.dest.hex : ap.from.hex);
|
||||||
@ -613,7 +613,7 @@ BattleScore BattleExchangeEvaluator::calculateExchange(
|
|||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
auto reachable = exchangeBattle->battleGetUnitsIf([&](const battle::Unit * u) -> bool
|
auto reachable = exchangeBattle->battleGetUnitsIf([this, &exchangeBattle, &attacker](const battle::Unit * u) -> bool
|
||||||
{
|
{
|
||||||
if(u->unitSide() == attacker->unitSide())
|
if(u->unitSide() == attacker->unitSide())
|
||||||
return false;
|
return false;
|
||||||
@ -621,7 +621,7 @@ BattleScore BattleExchangeEvaluator::calculateExchange(
|
|||||||
if(!exchangeBattle->getForUpdate(u->unitId())->alive())
|
if(!exchangeBattle->getForUpdate(u->unitId())->alive())
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
return vstd::contains_if(reachabilityMap[u->getPosition()], [&](const battle::Unit * other) -> bool
|
return vstd::contains_if(reachabilityMap.at(u->getPosition()), [&attacker](const battle::Unit * other) -> bool
|
||||||
{
|
{
|
||||||
return attacker->unitId() == other->unitId();
|
return attacker->unitId() == other->unitId();
|
||||||
});
|
});
|
||||||
@ -732,7 +732,7 @@ void BattleExchangeEvaluator::updateReachabilityMap(std::shared_ptr<HypotheticBa
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<const battle::Unit *> BattleExchangeEvaluator::getOneTurnReachableUnits(uint8_t turn, BattleHex hex)
|
std::vector<const battle::Unit *> BattleExchangeEvaluator::getOneTurnReachableUnits(uint8_t turn, BattleHex hex) const
|
||||||
{
|
{
|
||||||
std::vector<const battle::Unit *> result;
|
std::vector<const battle::Unit *> result;
|
||||||
|
|
||||||
@ -756,13 +756,10 @@ std::vector<const battle::Unit *> BattleExchangeEvaluator::getOneTurnReachableUn
|
|||||||
auto unitSpeed = unit->getMovementRange(turn);
|
auto unitSpeed = unit->getMovementRange(turn);
|
||||||
auto radius = unitSpeed * (turn + 1);
|
auto radius = unitSpeed * (turn + 1);
|
||||||
|
|
||||||
ReachabilityInfo unitReachability = vstd::getOrCompute(
|
auto reachabilityIter = reachabilityCache.find(unit->unitId());
|
||||||
reachabilityCache,
|
assert(reachabilityIter != reachabilityCache.end()); // missing updateReachabilityMap call?
|
||||||
unit->unitId(),
|
|
||||||
[&](ReachabilityInfo & data)
|
ReachabilityInfo unitReachability = reachabilityIter != reachabilityCache.end() ? reachabilityIter->second : turnBattle.getReachability(unit);
|
||||||
{
|
|
||||||
data = turnBattle.getReachability(unit);
|
|
||||||
});
|
|
||||||
|
|
||||||
bool reachable = unitReachability.distances[hex] <= radius;
|
bool reachable = unitReachability.distances[hex] <= radius;
|
||||||
|
|
||||||
|
@ -139,7 +139,7 @@ private:
|
|||||||
uint8_t turn,
|
uint8_t turn,
|
||||||
PotentialTargets & targets,
|
PotentialTargets & targets,
|
||||||
DamageCache & damageCache,
|
DamageCache & damageCache,
|
||||||
std::shared_ptr<HypotheticBattle> hb);
|
std::shared_ptr<HypotheticBattle> hb) const;
|
||||||
|
|
||||||
bool canBeHitThisTurn(const AttackPossibility & ap);
|
bool canBeHitThisTurn(const AttackPossibility & ap);
|
||||||
|
|
||||||
@ -162,16 +162,16 @@ public:
|
|||||||
uint8_t turn,
|
uint8_t turn,
|
||||||
PotentialTargets & targets,
|
PotentialTargets & targets,
|
||||||
DamageCache & damageCache,
|
DamageCache & damageCache,
|
||||||
std::shared_ptr<HypotheticBattle> hb);
|
std::shared_ptr<HypotheticBattle> hb) const;
|
||||||
|
|
||||||
std::vector<const battle::Unit *> getOneTurnReachableUnits(uint8_t turn, BattleHex hex);
|
std::vector<const battle::Unit *> getOneTurnReachableUnits(uint8_t turn, BattleHex hex) const;
|
||||||
void updateReachabilityMap(std::shared_ptr<HypotheticBattle> hb);
|
void updateReachabilityMap(std::shared_ptr<HypotheticBattle> hb);
|
||||||
|
|
||||||
ReachabilityData getExchangeUnits(
|
ReachabilityData getExchangeUnits(
|
||||||
const AttackPossibility & ap,
|
const AttackPossibility & ap,
|
||||||
uint8_t turn,
|
uint8_t turn,
|
||||||
PotentialTargets & targets,
|
PotentialTargets & targets,
|
||||||
std::shared_ptr<HypotheticBattle> hb);
|
std::shared_ptr<HypotheticBattle> hb) const;
|
||||||
|
|
||||||
bool checkPositionBlocksOurStacks(HypotheticBattle & hb, const battle::Unit * unit, BattleHex position);
|
bool checkPositionBlocksOurStacks(HypotheticBattle & hb, const battle::Unit * unit, BattleHex position);
|
||||||
|
|
||||||
|
23
Global.h
23
Global.h
@ -348,6 +348,15 @@ namespace vstd
|
|||||||
return std::find(c.begin(),c.end(),i);
|
return std::find(c.begin(),c.end(),i);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// returns existing value from map, or default value if key does not exists
|
||||||
|
template <typename Map>
|
||||||
|
const typename Map::mapped_type & find_or(const Map& m, const typename Map::key_type& key, const typename Map::mapped_type& defaultValue) {
|
||||||
|
auto it = m.find(key);
|
||||||
|
if (it == m.end())
|
||||||
|
return defaultValue;
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
|
||||||
//returns first key that maps to given value if present, returns success via found if provided
|
//returns first key that maps to given value if present, returns success via found if provided
|
||||||
template <typename Key, typename T>
|
template <typename Key, typename T>
|
||||||
Key findKey(const std::map<Key, T> & map, const T & value, bool * found = nullptr)
|
Key findKey(const std::map<Key, T> & map, const T & value, bool * found = nullptr)
|
||||||
@ -684,20 +693,6 @@ namespace vstd
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
template<class M, class Key, class F>
|
|
||||||
typename M::mapped_type & getOrCompute(M & m, const Key & k, F f)
|
|
||||||
{
|
|
||||||
typedef typename M::mapped_type V;
|
|
||||||
|
|
||||||
std::pair<typename M::iterator, bool> r = m.insert(typename M::value_type(k, V()));
|
|
||||||
V & v = r.first->second;
|
|
||||||
|
|
||||||
if(r.second)
|
|
||||||
f(v);
|
|
||||||
|
|
||||||
return v;
|
|
||||||
}
|
|
||||||
|
|
||||||
//c++20 feature
|
//c++20 feature
|
||||||
template<typename Arithmetic, typename Floating>
|
template<typename Arithmetic, typename Floating>
|
||||||
Arithmetic lerp(const Arithmetic & a, const Arithmetic & b, const Floating & f)
|
Arithmetic lerp(const Arithmetic & a, const Arithmetic & b, const Floating & f)
|
||||||
|
Loading…
Reference in New Issue
Block a user