Use shortest-path to calculate hypothetical adv.

This is much more efficient than taking an average, at the cost of being
slightly less optimal.
This commit is contained in:
Kiana Sheibani 2023-10-05 22:10:24 -04:00
parent 0771e5c370
commit 1013a1ee17
Signed by: toki
GPG key ID: 6CB106C25E86A9F7
2 changed files with 41 additions and 54 deletions

View file

@ -349,6 +349,7 @@ pub fn hypothetical_advantage(
player1: PlayerId,
player2: PlayerId,
) -> sqlite::Result<f64> {
use std::collections::HashSet;
// Check trivial cases
if player1 == player2
|| is_isolated(connection, dataset, player1)?
@ -357,58 +358,43 @@ pub fn hypothetical_advantage(
return Ok(0.0);
}
let mut paths: Vec<Vec<(Vec<PlayerId>, f64)>> = vec![vec![(vec![player1], 0.0)]];
let mut paths: Vec<(Vec<PlayerId>, f64)> = vec![(vec![player1], 0.0)];
let mut visited: HashSet<PlayerId> = HashSet::new();
for _ in 2..=7 {
let new_paths = paths.last().unwrap().iter().cloned().try_fold(
Vec::new(),
|mut acc, (path, adv)| {
acc.extend(
get_edges(connection, dataset, *path.last().unwrap())?
.into_iter()
.filter_map(|(x, next_adv)| {
if path.contains(&x) {
None
} else {
let mut path = path.clone();
path.extend_one(x);
Some((path, adv + next_adv))
let mut final_adv: Option<f64> = None;
while final_adv.is_none() {
if paths.is_empty() {
return Ok(0.0);
}
paths = paths
.into_iter()
.map(|(path, old_adv)| {
let last = *path.last().unwrap();
Ok(get_edges(connection, dataset, last)?
.into_iter()
.filter_map(|(id, adv)| {
if visited.contains(&id) {
None
} else {
visited.insert(id);
let mut path_ = path.clone();
path_.extend_one(id);
if id == player2 {
final_adv = Some(old_adv + adv);
}
}),
);
Some((path_, old_adv + adv))
}
})
.collect::<Vec<_>>())
})
.try_fold(Vec::new(), |mut acc, paths| {
acc.extend(paths?);
Ok(acc)
},
)?;
paths.extend_one(new_paths);
})?;
}
Ok(paths
.into_iter()
.skip(1)
.map(|ps| {
let ps_correct = ps
.iter()
.filter_map(|(path, adv)| {
if *path.last().unwrap() == player2 {
Some(adv)
} else {
None
}
})
.collect::<Vec<_>>();
let num_ps = ps_correct.len();
if num_ps == 0 {
return None;
}
Some(ps_correct.into_iter().sum::<f64>() / num_ps as f64)
})
.skip_while(|x| x.is_none())
.enumerate()
.fold((0.0, 0.0), |(total, last), (i, adv)| {
let adv_ = adv.unwrap_or(last);
(total + (0.5_f64.powi((i + 1) as i32) * adv_), adv_)
})
.0)
Ok(final_adv.unwrap())
}
pub fn initialize_edge(
@ -574,8 +560,9 @@ mod tests {
insert_advantage(&connection, "test", PlayerId(1), PlayerId(2), 1.0)?;
assert!(
(hypothetical_advantage(&connection, "test", PlayerId(1), PlayerId(2))? - 1.0) < 0.1
assert_eq!(
hypothetical_advantage(&connection, "test", PlayerId(1), PlayerId(2))?,
1.0
);
Ok(())

View file

@ -48,23 +48,23 @@ pub fn get_auth_token(config_dir: &Path) -> Option<String> {
// cynic always assumes that IDs are strings. To get around that, we define new
// scalar types that deserialize to u64.
#[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
#[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cynic(graphql_type = "ID")]
pub struct VideogameId(pub u64);
#[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
#[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cynic(graphql_type = "ID")]
pub struct EventId(pub u64);
#[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
#[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cynic(graphql_type = "ID")]
pub struct EntrantId(pub u64);
#[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
#[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cynic(graphql_type = "ID")]
pub struct PlayerId(pub u64);
#[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
#[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Timestamp(pub u64);
// Query machinery