diff --git a/src/datasets.rs b/src/datasets.rs index d726159..06f19b7 100644 --- a/src/datasets.rs +++ b/src/datasets.rs @@ -413,11 +413,13 @@ pub fn get_edges( connection: &Connection, dataset: &str, player: PlayerId, -) -> sqlite::Result> { +) -> sqlite::Result> { let query = format!( - r#"SELECT iif(:pl = player_B, player_A, player_B) AS id, iif(:pl = player_B, -advantage, advantage) AS advantage - FROM "{}_network" - WHERE player_A = :pl OR player_B = :pl"#, + r#"SELECT + iif(:pl = player_B, player_A, player_B) AS id, + iif(:pl = player_B, -advantage, advantage) AS advantage, sets_count + FROM "{}_network" + WHERE player_A = :pl OR player_B = :pl"#, dataset ); @@ -430,6 +432,7 @@ pub fn get_edges( Ok(( PlayerId(r_.read::("id") as u64), r_.read::("advantage"), + r_.read::("sets_count") as u64, )) }) .try_collect() @@ -462,64 +465,82 @@ pub fn either_isolated( pub fn hypothetical_advantage( connection: &Connection, dataset: &str, - decay_rate: f64, player1: PlayerId, player2: PlayerId, + set_limit: u64, + decay_rate: f64, + adj_decay_rate: f64, ) -> sqlite::Result { - use std::collections::HashSet; + use std::collections::{HashMap, HashSet}; + // Check trivial cases if player1 == player2 || either_isolated(connection, dataset, player1, player2)? { return Ok(0.0); } - let mut paths: Vec<(Vec, f64)> = vec![(vec![player1], 0.0)]; let mut visited: HashSet = HashSet::new(); + let mut queue: HashMap = HashMap::from([(player1, (0.0, 1.0))]); - let mut final_path: Option<(f64, usize)> = None; + while !queue.is_empty() { + let visiting = *queue + .iter() + .max_by(|a, b| a.1 .1.partial_cmp(&b.1 .1).unwrap()) + .unwrap() + .0; + let (adv_v, decay_v) = queue.remove(&visiting).unwrap(); + let connections = get_edges(connection, dataset, visiting)?; - while final_path.is_none() { - if paths.is_empty() { - return Ok(0.0); - } - paths = paths + for (id, adv, sets) in connections .into_iter() - .map(|(path, old_adv)| { - let last = *path.last().unwrap(); - Ok(get_edges(connection, dataset, last)? - .into_iter() - .filter_map(|(id, adv)| { - if visited.contains(&id) { - None - } else { - if id == player2 { - final_path = Some((old_adv + adv, path.len() - 1)); - } - visited.insert(id); - let mut path_ = path.clone(); - path_.extend_one(id); - Some((path_, old_adv + adv)) - } - }) - .collect::>()) - }) - .try_fold(Vec::new(), |mut acc, paths| { - acc.extend(paths?); - Ok(acc) - })?; + .filter(|(id, _, _)| !visited.contains(id)) + { + let advantage = adv_v + adv; + + if id == player2 { + return Ok(advantage * decay_v); + } + + let decay = decay_v + * if sets >= set_limit { + decay_rate + } else { + adj_decay_rate + }; + + if queue + .get(&id) + .map(|(_, decay_old)| *decay_old < decay) + .unwrap_or(true) + { + queue.insert(id, (advantage, decay)); + } + } + + visited.insert(visiting); } - let (final_adv, len) = final_path.unwrap(); - Ok(final_adv * decay_rate.powi(len as i32)) + // No path found + Ok(0.0) } pub fn initialize_edge( connection: &Connection, dataset: &str, - decay_rate: f64, player1: PlayerId, player2: PlayerId, + set_limit: u64, + decay_rate: f64, + adj_decay_rate: f64, ) -> sqlite::Result { - let adv = hypothetical_advantage(connection, dataset, decay_rate, player1, player2)?; + let adv = hypothetical_advantage( + connection, + dataset, + player1, + player2, + set_limit, + decay_rate, + adj_decay_rate, + )?; insert_advantage(connection, dataset, player1, player2, adv)?; Ok(adv) } @@ -644,15 +665,15 @@ CREATE TABLE IF NOT EXISTS datasets ( assert_eq!( get_edges(&connection, "test", PlayerId(1))?, - [(PlayerId(2), -1.0), (PlayerId(3), 5.0)] + [(PlayerId(2), -1.0, 0), (PlayerId(3), 5.0, 0)] ); assert_eq!( get_edges(&connection, "test", PlayerId(2))?, - [(PlayerId(1), 1.0)] + [(PlayerId(1), 1.0, 0)] ); assert_eq!( get_edges(&connection, "test", PlayerId(3))?, - [(PlayerId(1), -5.0)] + [(PlayerId(1), -5.0, 0)] ); Ok(()) } @@ -670,10 +691,19 @@ CREATE TABLE IF NOT EXISTS datasets ( Timestamp(0), )?; + let metadata = metadata(); for i in 1..=num_players { for j in 1..=num_players { assert_eq!( - hypothetical_advantage(&connection, "test", 0.5, PlayerId(i), PlayerId(j))?, + hypothetical_advantage( + &connection, + "test", + PlayerId(i), + PlayerId(j), + metadata.set_limit, + metadata.decay_rate, + metadata.adj_decay_rate + )?, 0.0 ); } @@ -690,8 +720,17 @@ CREATE TABLE IF NOT EXISTS datasets ( insert_advantage(&connection, "test", PlayerId(1), PlayerId(2), 1.0)?; + let metadata = metadata(); assert_eq!( - hypothetical_advantage(&connection, "test", 0.5, PlayerId(1), PlayerId(2))?, + hypothetical_advantage( + &connection, + "test", + PlayerId(1), + PlayerId(2), + metadata.set_limit, + metadata.decay_rate, + metadata.adj_decay_rate + )?, 1.0 );