diff --git a/src/database.rs b/src/database.rs index e0d48b5..30c199d 100644 --- a/src/database.rs +++ b/src/database.rs @@ -607,7 +607,7 @@ pub fn hypothetical_advantage( decay_rate: f64, adj_decay_rate: f64, ) -> sqlite::Result { - use std::collections::{HashMap, HashSet}; + use std::collections::{HashSet, VecDeque}; // Check trivial cases if player1 == player2 || either_isolated(connection, dataset, player1, player2)? { @@ -615,48 +615,62 @@ pub fn hypothetical_advantage( } let mut visited: HashSet = HashSet::new(); - let mut queue: HashMap = HashMap::from([(player1, (0.0, 1.0))]); + let mut queue: VecDeque<(PlayerId, Vec<(f64, f64)>)> = + VecDeque::from([(player1, Vec::from([(0.0, 1.0)]))]); + + let mut final_paths = Vec::new(); + + while !queue.is_empty() && final_paths.len() < 100 { + let (visiting, paths) = queue.pop_front().unwrap(); - while !queue.is_empty() { - let visiting = *queue - .iter() - .max_by(|a, b| a.1 .1.partial_cmp(&b.1 .1).unwrap()) - .unwrap() - .0; - let (adv_v, decay_v) = queue.remove(&visiting).unwrap(); let connections = get_edges(connection, dataset, visiting)?; for (id, adv, sets) in connections .into_iter() .filter(|(id, _, _)| !visited.contains(id)) { - let advantage = adv_v + adv; + let rf = if id == player2 { + &mut final_paths + } else if let Some(r) = queue.iter_mut().find(|(id_, _)| id == *id_) { + &mut r.1 + } else { + queue.push_back((id, Vec::new())); + &mut queue.back_mut().unwrap().1 + }; - if id == player2 { - return Ok(advantage * decay_v); - } - - let decay = decay_v - * if sets >= set_limit { + if rf.len() < 100 { + let decay = if sets >= set_limit { decay_rate } else { adj_decay_rate }; + let iter = paths.iter().map(|(a, d)| (a + adv, d * decay)); - if queue - .get(&id) - .map(|(_, decay_old)| *decay_old < decay) - .unwrap_or(true) - { - queue.insert(id, (advantage, decay)); + rf.extend(iter); + rf.truncate(100); } } visited.insert(visiting); } - // No path found - Ok(0.0) + let max_decay = final_paths + .iter() + .map(|x| x.1) + .max_by(|d1, d2| d1.partial_cmp(d2).unwrap()); + + if let Some(mdec) = max_decay { + let sum_decay = final_paths.iter().map(|x| x.1).sum::(); + Ok(final_paths + .into_iter() + .map(|(adv, dec)| adv * dec) + .sum::() + / sum_decay + * mdec) + } else { + // No paths found + Ok(0.0) + } } pub fn initialize_edge(