Use shortest-path to calculate hypothetical adv.
This is much more efficient than taking an average, at the cost of being slightly less optimal.
This commit is contained in:
parent
0771e5c370
commit
1013a1ee17
|
@ -349,6 +349,7 @@ pub fn hypothetical_advantage(
|
|||
player1: PlayerId,
|
||||
player2: PlayerId,
|
||||
) -> sqlite::Result<f64> {
|
||||
use std::collections::HashSet;
|
||||
// Check trivial cases
|
||||
if player1 == player2
|
||||
|| is_isolated(connection, dataset, player1)?
|
||||
|
@ -357,58 +358,43 @@ pub fn hypothetical_advantage(
|
|||
return Ok(0.0);
|
||||
}
|
||||
|
||||
let mut paths: Vec<Vec<(Vec<PlayerId>, f64)>> = vec![vec![(vec![player1], 0.0)]];
|
||||
let mut paths: Vec<(Vec<PlayerId>, f64)> = vec![(vec![player1], 0.0)];
|
||||
let mut visited: HashSet<PlayerId> = HashSet::new();
|
||||
|
||||
for _ in 2..=7 {
|
||||
let new_paths = paths.last().unwrap().iter().cloned().try_fold(
|
||||
Vec::new(),
|
||||
|mut acc, (path, adv)| {
|
||||
acc.extend(
|
||||
get_edges(connection, dataset, *path.last().unwrap())?
|
||||
let mut final_adv: Option<f64> = None;
|
||||
|
||||
while final_adv.is_none() {
|
||||
if paths.is_empty() {
|
||||
return Ok(0.0);
|
||||
}
|
||||
paths = paths
|
||||
.into_iter()
|
||||
.filter_map(|(x, next_adv)| {
|
||||
if path.contains(&x) {
|
||||
.map(|(path, old_adv)| {
|
||||
let last = *path.last().unwrap();
|
||||
Ok(get_edges(connection, dataset, last)?
|
||||
.into_iter()
|
||||
.filter_map(|(id, adv)| {
|
||||
if visited.contains(&id) {
|
||||
None
|
||||
} else {
|
||||
let mut path = path.clone();
|
||||
path.extend_one(x);
|
||||
Some((path, adv + next_adv))
|
||||
visited.insert(id);
|
||||
let mut path_ = path.clone();
|
||||
path_.extend_one(id);
|
||||
if id == player2 {
|
||||
final_adv = Some(old_adv + adv);
|
||||
}
|
||||
}),
|
||||
);
|
||||
Some((path_, old_adv + adv))
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>())
|
||||
})
|
||||
.try_fold(Vec::new(), |mut acc, paths| {
|
||||
acc.extend(paths?);
|
||||
Ok(acc)
|
||||
},
|
||||
)?;
|
||||
paths.extend_one(new_paths);
|
||||
})?;
|
||||
}
|
||||
|
||||
Ok(paths
|
||||
.into_iter()
|
||||
.skip(1)
|
||||
.map(|ps| {
|
||||
let ps_correct = ps
|
||||
.iter()
|
||||
.filter_map(|(path, adv)| {
|
||||
if *path.last().unwrap() == player2 {
|
||||
Some(adv)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let num_ps = ps_correct.len();
|
||||
if num_ps == 0 {
|
||||
return None;
|
||||
}
|
||||
Some(ps_correct.into_iter().sum::<f64>() / num_ps as f64)
|
||||
})
|
||||
.skip_while(|x| x.is_none())
|
||||
.enumerate()
|
||||
.fold((0.0, 0.0), |(total, last), (i, adv)| {
|
||||
let adv_ = adv.unwrap_or(last);
|
||||
(total + (0.5_f64.powi((i + 1) as i32) * adv_), adv_)
|
||||
})
|
||||
.0)
|
||||
Ok(final_adv.unwrap())
|
||||
}
|
||||
|
||||
pub fn initialize_edge(
|
||||
|
@ -574,8 +560,9 @@ mod tests {
|
|||
|
||||
insert_advantage(&connection, "test", PlayerId(1), PlayerId(2), 1.0)?;
|
||||
|
||||
assert!(
|
||||
(hypothetical_advantage(&connection, "test", PlayerId(1), PlayerId(2))? - 1.0) < 0.1
|
||||
assert_eq!(
|
||||
hypothetical_advantage(&connection, "test", PlayerId(1), PlayerId(2))?,
|
||||
1.0
|
||||
);
|
||||
|
||||
Ok(())
|
||||
|
|
|
@ -48,23 +48,23 @@ pub fn get_auth_token(config_dir: &Path) -> Option<String> {
|
|||
// cynic always assumes that IDs are strings. To get around that, we define new
|
||||
// scalar types that deserialize to u64.
|
||||
|
||||
#[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
|
||||
#[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
#[cynic(graphql_type = "ID")]
|
||||
pub struct VideogameId(pub u64);
|
||||
|
||||
#[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
|
||||
#[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
#[cynic(graphql_type = "ID")]
|
||||
pub struct EventId(pub u64);
|
||||
|
||||
#[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
|
||||
#[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
#[cynic(graphql_type = "ID")]
|
||||
pub struct EntrantId(pub u64);
|
||||
|
||||
#[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
|
||||
#[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
#[cynic(graphql_type = "ID")]
|
||||
pub struct PlayerId(pub u64);
|
||||
|
||||
#[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
|
||||
#[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
pub struct Timestamp(pub u64);
|
||||
|
||||
// Query machinery
|
||||
|
|
Loading…
Reference in a new issue