Modify hypothetical advantage calc. for stability
This commit is contained in:
parent
26c2813b09
commit
d39c280310
|
@ -607,7 +607,7 @@ pub fn hypothetical_advantage(
|
|||
decay_rate: f64,
|
||||
adj_decay_rate: f64,
|
||||
) -> sqlite::Result<f64> {
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::collections::{HashSet, VecDeque};
|
||||
|
||||
// Check trivial cases
|
||||
if player1 == player2 || either_isolated(connection, dataset, player1, player2)? {
|
||||
|
@ -615,48 +615,62 @@ pub fn hypothetical_advantage(
|
|||
}
|
||||
|
||||
let mut visited: HashSet<PlayerId> = HashSet::new();
|
||||
let mut queue: HashMap<PlayerId, (f64, f64)> = HashMap::from([(player1, (0.0, 1.0))]);
|
||||
let mut queue: VecDeque<(PlayerId, Vec<(f64, f64)>)> =
|
||||
VecDeque::from([(player1, Vec::from([(0.0, 1.0)]))]);
|
||||
|
||||
let mut final_paths = Vec::new();
|
||||
|
||||
while !queue.is_empty() && final_paths.len() < 100 {
|
||||
let (visiting, paths) = queue.pop_front().unwrap();
|
||||
|
||||
while !queue.is_empty() {
|
||||
let visiting = *queue
|
||||
.iter()
|
||||
.max_by(|a, b| a.1 .1.partial_cmp(&b.1 .1).unwrap())
|
||||
.unwrap()
|
||||
.0;
|
||||
let (adv_v, decay_v) = queue.remove(&visiting).unwrap();
|
||||
let connections = get_edges(connection, dataset, visiting)?;
|
||||
|
||||
for (id, adv, sets) in connections
|
||||
.into_iter()
|
||||
.filter(|(id, _, _)| !visited.contains(id))
|
||||
{
|
||||
let advantage = adv_v + adv;
|
||||
let rf = if id == player2 {
|
||||
&mut final_paths
|
||||
} else if let Some(r) = queue.iter_mut().find(|(id_, _)| id == *id_) {
|
||||
&mut r.1
|
||||
} else {
|
||||
queue.push_back((id, Vec::new()));
|
||||
&mut queue.back_mut().unwrap().1
|
||||
};
|
||||
|
||||
if id == player2 {
|
||||
return Ok(advantage * decay_v);
|
||||
}
|
||||
|
||||
let decay = decay_v
|
||||
* if sets >= set_limit {
|
||||
if rf.len() < 100 {
|
||||
let decay = if sets >= set_limit {
|
||||
decay_rate
|
||||
} else {
|
||||
adj_decay_rate
|
||||
};
|
||||
let iter = paths.iter().map(|(a, d)| (a + adv, d * decay));
|
||||
|
||||
if queue
|
||||
.get(&id)
|
||||
.map(|(_, decay_old)| *decay_old < decay)
|
||||
.unwrap_or(true)
|
||||
{
|
||||
queue.insert(id, (advantage, decay));
|
||||
rf.extend(iter);
|
||||
rf.truncate(100);
|
||||
}
|
||||
}
|
||||
|
||||
visited.insert(visiting);
|
||||
}
|
||||
|
||||
// No path found
|
||||
Ok(0.0)
|
||||
let max_decay = final_paths
|
||||
.iter()
|
||||
.map(|x| x.1)
|
||||
.max_by(|d1, d2| d1.partial_cmp(d2).unwrap());
|
||||
|
||||
if let Some(mdec) = max_decay {
|
||||
let sum_decay = final_paths.iter().map(|x| x.1).sum::<f64>();
|
||||
Ok(final_paths
|
||||
.into_iter()
|
||||
.map(|(adv, dec)| adv * dec)
|
||||
.sum::<f64>()
|
||||
/ sum_decay
|
||||
* mdec)
|
||||
} else {
|
||||
// No paths found
|
||||
Ok(0.0)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn initialize_edge(
|
||||
|
|
Loading…
Reference in a new issue