Modify hypothetical advantage calc. for stability

This commit is contained in:
Kiana Sheibani 2024-01-01 22:20:55 -05:00
parent 26c2813b09
commit d39c280310
Signed by: toki
GPG key ID: 6CB106C25E86A9F7

View file

@ -607,7 +607,7 @@ pub fn hypothetical_advantage(
decay_rate: f64, decay_rate: f64,
adj_decay_rate: f64, adj_decay_rate: f64,
) -> sqlite::Result<f64> { ) -> sqlite::Result<f64> {
use std::collections::{HashMap, HashSet}; use std::collections::{HashSet, VecDeque};
// Check trivial cases // Check trivial cases
if player1 == player2 || either_isolated(connection, dataset, player1, player2)? { if player1 == player2 || either_isolated(connection, dataset, player1, player2)? {
@ -615,49 +615,63 @@ pub fn hypothetical_advantage(
} }
let mut visited: HashSet<PlayerId> = HashSet::new(); let mut visited: HashSet<PlayerId> = HashSet::new();
let mut queue: HashMap<PlayerId, (f64, f64)> = HashMap::from([(player1, (0.0, 1.0))]); let mut queue: VecDeque<(PlayerId, Vec<(f64, f64)>)> =
VecDeque::from([(player1, Vec::from([(0.0, 1.0)]))]);
let mut final_paths = Vec::new();
while !queue.is_empty() && final_paths.len() < 100 {
let (visiting, paths) = queue.pop_front().unwrap();
while !queue.is_empty() {
let visiting = *queue
.iter()
.max_by(|a, b| a.1 .1.partial_cmp(&b.1 .1).unwrap())
.unwrap()
.0;
let (adv_v, decay_v) = queue.remove(&visiting).unwrap();
let connections = get_edges(connection, dataset, visiting)?; let connections = get_edges(connection, dataset, visiting)?;
for (id, adv, sets) in connections for (id, adv, sets) in connections
.into_iter() .into_iter()
.filter(|(id, _, _)| !visited.contains(id)) .filter(|(id, _, _)| !visited.contains(id))
{ {
let advantage = adv_v + adv; let rf = if id == player2 {
&mut final_paths
} else if let Some(r) = queue.iter_mut().find(|(id_, _)| id == *id_) {
&mut r.1
} else {
queue.push_back((id, Vec::new()));
&mut queue.back_mut().unwrap().1
};
if id == player2 { if rf.len() < 100 {
return Ok(advantage * decay_v); let decay = if sets >= set_limit {
}
let decay = decay_v
* if sets >= set_limit {
decay_rate decay_rate
} else { } else {
adj_decay_rate adj_decay_rate
}; };
let iter = paths.iter().map(|(a, d)| (a + adv, d * decay));
if queue rf.extend(iter);
.get(&id) rf.truncate(100);
.map(|(_, decay_old)| *decay_old < decay)
.unwrap_or(true)
{
queue.insert(id, (advantage, decay));
} }
} }
visited.insert(visiting); visited.insert(visiting);
} }
// No path found let max_decay = final_paths
.iter()
.map(|x| x.1)
.max_by(|d1, d2| d1.partial_cmp(d2).unwrap());
if let Some(mdec) = max_decay {
let sum_decay = final_paths.iter().map(|x| x.1).sum::<f64>();
Ok(final_paths
.into_iter()
.map(|(adv, dec)| adv * dec)
.sum::<f64>()
/ sum_decay
* mdec)
} else {
// No paths found
Ok(0.0) Ok(0.0)
} }
}
pub fn initialize_edge( pub fn initialize_edge(
connection: &Connection, connection: &Connection,