Use shortest-path to calculate hypothetical adv.

This is much more efficient than taking an average, at the cost of being
slightly less optimal.
This commit is contained in:
Kiana Sheibani 2023-10-05 22:10:24 -04:00
parent 0771e5c370
commit 1013a1ee17
Signed by: toki
GPG key ID: 6CB106C25E86A9F7
2 changed files with 41 additions and 54 deletions

View file

@ -349,6 +349,7 @@ pub fn hypothetical_advantage(
player1: PlayerId, player1: PlayerId,
player2: PlayerId, player2: PlayerId,
) -> sqlite::Result<f64> { ) -> sqlite::Result<f64> {
use std::collections::HashSet;
// Check trivial cases // Check trivial cases
if player1 == player2 if player1 == player2
|| is_isolated(connection, dataset, player1)? || is_isolated(connection, dataset, player1)?
@ -357,58 +358,43 @@ pub fn hypothetical_advantage(
return Ok(0.0); return Ok(0.0);
} }
let mut paths: Vec<Vec<(Vec<PlayerId>, f64)>> = vec![vec![(vec![player1], 0.0)]]; let mut paths: Vec<(Vec<PlayerId>, f64)> = vec![(vec![player1], 0.0)];
let mut visited: HashSet<PlayerId> = HashSet::new();
for _ in 2..=7 { let mut final_adv: Option<f64> = None;
let new_paths = paths.last().unwrap().iter().cloned().try_fold(
Vec::new(), while final_adv.is_none() {
|mut acc, (path, adv)| { if paths.is_empty() {
acc.extend( return Ok(0.0);
get_edges(connection, dataset, *path.last().unwrap())? }
.into_iter() paths = paths
.filter_map(|(x, next_adv)| { .into_iter()
if path.contains(&x) { .map(|(path, old_adv)| {
None let last = *path.last().unwrap();
} else { Ok(get_edges(connection, dataset, last)?
let mut path = path.clone(); .into_iter()
path.extend_one(x); .filter_map(|(id, adv)| {
Some((path, adv + next_adv)) if visited.contains(&id) {
None
} else {
visited.insert(id);
let mut path_ = path.clone();
path_.extend_one(id);
if id == player2 {
final_adv = Some(old_adv + adv);
} }
}), Some((path_, old_adv + adv))
); }
})
.collect::<Vec<_>>())
})
.try_fold(Vec::new(), |mut acc, paths| {
acc.extend(paths?);
Ok(acc) Ok(acc)
}, })?;
)?;
paths.extend_one(new_paths);
} }
Ok(paths Ok(final_adv.unwrap())
.into_iter()
.skip(1)
.map(|ps| {
let ps_correct = ps
.iter()
.filter_map(|(path, adv)| {
if *path.last().unwrap() == player2 {
Some(adv)
} else {
None
}
})
.collect::<Vec<_>>();
let num_ps = ps_correct.len();
if num_ps == 0 {
return None;
}
Some(ps_correct.into_iter().sum::<f64>() / num_ps as f64)
})
.skip_while(|x| x.is_none())
.enumerate()
.fold((0.0, 0.0), |(total, last), (i, adv)| {
let adv_ = adv.unwrap_or(last);
(total + (0.5_f64.powi((i + 1) as i32) * adv_), adv_)
})
.0)
} }
pub fn initialize_edge( pub fn initialize_edge(
@ -574,8 +560,9 @@ mod tests {
insert_advantage(&connection, "test", PlayerId(1), PlayerId(2), 1.0)?; insert_advantage(&connection, "test", PlayerId(1), PlayerId(2), 1.0)?;
assert!( assert_eq!(
(hypothetical_advantage(&connection, "test", PlayerId(1), PlayerId(2))? - 1.0) < 0.1 hypothetical_advantage(&connection, "test", PlayerId(1), PlayerId(2))?,
1.0
); );
Ok(()) Ok(())

View file

@ -48,23 +48,23 @@ pub fn get_auth_token(config_dir: &Path) -> Option<String> {
// cynic always assumes that IDs are strings. To get around that, we define new // cynic always assumes that IDs are strings. To get around that, we define new
// scalar types that deserialize to u64. // scalar types that deserialize to u64.
#[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] #[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cynic(graphql_type = "ID")] #[cynic(graphql_type = "ID")]
pub struct VideogameId(pub u64); pub struct VideogameId(pub u64);
#[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] #[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cynic(graphql_type = "ID")] #[cynic(graphql_type = "ID")]
pub struct EventId(pub u64); pub struct EventId(pub u64);
#[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] #[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cynic(graphql_type = "ID")] #[cynic(graphql_type = "ID")]
pub struct EntrantId(pub u64); pub struct EntrantId(pub u64);
#[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] #[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cynic(graphql_type = "ID")] #[cynic(graphql_type = "ID")]
pub struct PlayerId(pub u64); pub struct PlayerId(pub u64);
#[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] #[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Timestamp(pub u64); pub struct Timestamp(pub u64);
// Query machinery // Query machinery