From 1013a1ee17bbe90b444cf7c5a729f15a01364459 Mon Sep 17 00:00:00 2001 From: Kiana Sheibani Date: Thu, 5 Oct 2023 22:10:24 -0400 Subject: [PATCH] Use shortest-path to calculate hypothetical adv. This is much more efficient than taking an average, at the cost of being slightly less optimal. --- src/datasets.rs | 85 +++++++++++++++++++++---------------------------- src/queries.rs | 10 +++--- 2 files changed, 41 insertions(+), 54 deletions(-) diff --git a/src/datasets.rs b/src/datasets.rs index c24ab8c..baeb8ad 100644 --- a/src/datasets.rs +++ b/src/datasets.rs @@ -349,6 +349,7 @@ pub fn hypothetical_advantage( player1: PlayerId, player2: PlayerId, ) -> sqlite::Result { + use std::collections::HashSet; // Check trivial cases if player1 == player2 || is_isolated(connection, dataset, player1)? @@ -357,58 +358,43 @@ pub fn hypothetical_advantage( return Ok(0.0); } - let mut paths: Vec, f64)>> = vec![vec![(vec![player1], 0.0)]]; + let mut paths: Vec<(Vec, f64)> = vec![(vec![player1], 0.0)]; + let mut visited: HashSet = HashSet::new(); - for _ in 2..=7 { - let new_paths = paths.last().unwrap().iter().cloned().try_fold( - Vec::new(), - |mut acc, (path, adv)| { - acc.extend( - get_edges(connection, dataset, *path.last().unwrap())? - .into_iter() - .filter_map(|(x, next_adv)| { - if path.contains(&x) { - None - } else { - let mut path = path.clone(); - path.extend_one(x); - Some((path, adv + next_adv)) + let mut final_adv: Option = None; + + while final_adv.is_none() { + if paths.is_empty() { + return Ok(0.0); + } + paths = paths + .into_iter() + .map(|(path, old_adv)| { + let last = *path.last().unwrap(); + Ok(get_edges(connection, dataset, last)? + .into_iter() + .filter_map(|(id, adv)| { + if visited.contains(&id) { + None + } else { + visited.insert(id); + let mut path_ = path.clone(); + path_.extend_one(id); + if id == player2 { + final_adv = Some(old_adv + adv); } - }), - ); + Some((path_, old_adv + adv)) + } + }) + .collect::>()) + }) + .try_fold(Vec::new(), |mut acc, paths| { + acc.extend(paths?); Ok(acc) - }, - )?; - paths.extend_one(new_paths); + })?; } - Ok(paths - .into_iter() - .skip(1) - .map(|ps| { - let ps_correct = ps - .iter() - .filter_map(|(path, adv)| { - if *path.last().unwrap() == player2 { - Some(adv) - } else { - None - } - }) - .collect::>(); - let num_ps = ps_correct.len(); - if num_ps == 0 { - return None; - } - Some(ps_correct.into_iter().sum::() / num_ps as f64) - }) - .skip_while(|x| x.is_none()) - .enumerate() - .fold((0.0, 0.0), |(total, last), (i, adv)| { - let adv_ = adv.unwrap_or(last); - (total + (0.5_f64.powi((i + 1) as i32) * adv_), adv_) - }) - .0) + Ok(final_adv.unwrap()) } pub fn initialize_edge( @@ -574,8 +560,9 @@ mod tests { insert_advantage(&connection, "test", PlayerId(1), PlayerId(2), 1.0)?; - assert!( - (hypothetical_advantage(&connection, "test", PlayerId(1), PlayerId(2))? - 1.0) < 0.1 + assert_eq!( + hypothetical_advantage(&connection, "test", PlayerId(1), PlayerId(2))?, + 1.0 ); Ok(()) diff --git a/src/queries.rs b/src/queries.rs index bbc56ce..cc74152 100644 --- a/src/queries.rs +++ b/src/queries.rs @@ -48,23 +48,23 @@ pub fn get_auth_token(config_dir: &Path) -> Option { // cynic always assumes that IDs are strings. To get around that, we define new // scalar types that deserialize to u64. -#[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] +#[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] #[cynic(graphql_type = "ID")] pub struct VideogameId(pub u64); -#[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] +#[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] #[cynic(graphql_type = "ID")] pub struct EventId(pub u64); -#[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] +#[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] #[cynic(graphql_type = "ID")] pub struct EntrantId(pub u64); -#[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] +#[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] #[cynic(graphql_type = "ID")] pub struct PlayerId(pub u64); -#[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] +#[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct Timestamp(pub u64); // Query machinery