Use shortest-path to calculate hypothetical adv.
This is much more efficient than taking an average, at the cost of being slightly less optimal.
This commit is contained in:
parent
0771e5c370
commit
1013a1ee17
|
@ -349,6 +349,7 @@ pub fn hypothetical_advantage(
|
||||||
player1: PlayerId,
|
player1: PlayerId,
|
||||||
player2: PlayerId,
|
player2: PlayerId,
|
||||||
) -> sqlite::Result<f64> {
|
) -> sqlite::Result<f64> {
|
||||||
|
use std::collections::HashSet;
|
||||||
// Check trivial cases
|
// Check trivial cases
|
||||||
if player1 == player2
|
if player1 == player2
|
||||||
|| is_isolated(connection, dataset, player1)?
|
|| is_isolated(connection, dataset, player1)?
|
||||||
|
@ -357,58 +358,43 @@ pub fn hypothetical_advantage(
|
||||||
return Ok(0.0);
|
return Ok(0.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut paths: Vec<Vec<(Vec<PlayerId>, f64)>> = vec![vec![(vec![player1], 0.0)]];
|
let mut paths: Vec<(Vec<PlayerId>, f64)> = vec![(vec![player1], 0.0)];
|
||||||
|
let mut visited: HashSet<PlayerId> = HashSet::new();
|
||||||
|
|
||||||
for _ in 2..=7 {
|
let mut final_adv: Option<f64> = None;
|
||||||
let new_paths = paths.last().unwrap().iter().cloned().try_fold(
|
|
||||||
Vec::new(),
|
while final_adv.is_none() {
|
||||||
|mut acc, (path, adv)| {
|
if paths.is_empty() {
|
||||||
acc.extend(
|
return Ok(0.0);
|
||||||
get_edges(connection, dataset, *path.last().unwrap())?
|
}
|
||||||
|
paths = paths
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.filter_map(|(x, next_adv)| {
|
.map(|(path, old_adv)| {
|
||||||
if path.contains(&x) {
|
let last = *path.last().unwrap();
|
||||||
|
Ok(get_edges(connection, dataset, last)?
|
||||||
|
.into_iter()
|
||||||
|
.filter_map(|(id, adv)| {
|
||||||
|
if visited.contains(&id) {
|
||||||
None
|
None
|
||||||
} else {
|
} else {
|
||||||
let mut path = path.clone();
|
visited.insert(id);
|
||||||
path.extend_one(x);
|
let mut path_ = path.clone();
|
||||||
Some((path, adv + next_adv))
|
path_.extend_one(id);
|
||||||
|
if id == player2 {
|
||||||
|
final_adv = Some(old_adv + adv);
|
||||||
}
|
}
|
||||||
}),
|
Some((path_, old_adv + adv))
|
||||||
);
|
}
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>())
|
||||||
|
})
|
||||||
|
.try_fold(Vec::new(), |mut acc, paths| {
|
||||||
|
acc.extend(paths?);
|
||||||
Ok(acc)
|
Ok(acc)
|
||||||
},
|
})?;
|
||||||
)?;
|
|
||||||
paths.extend_one(new_paths);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(paths
|
Ok(final_adv.unwrap())
|
||||||
.into_iter()
|
|
||||||
.skip(1)
|
|
||||||
.map(|ps| {
|
|
||||||
let ps_correct = ps
|
|
||||||
.iter()
|
|
||||||
.filter_map(|(path, adv)| {
|
|
||||||
if *path.last().unwrap() == player2 {
|
|
||||||
Some(adv)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
let num_ps = ps_correct.len();
|
|
||||||
if num_ps == 0 {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
Some(ps_correct.into_iter().sum::<f64>() / num_ps as f64)
|
|
||||||
})
|
|
||||||
.skip_while(|x| x.is_none())
|
|
||||||
.enumerate()
|
|
||||||
.fold((0.0, 0.0), |(total, last), (i, adv)| {
|
|
||||||
let adv_ = adv.unwrap_or(last);
|
|
||||||
(total + (0.5_f64.powi((i + 1) as i32) * adv_), adv_)
|
|
||||||
})
|
|
||||||
.0)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn initialize_edge(
|
pub fn initialize_edge(
|
||||||
|
@ -574,8 +560,9 @@ mod tests {
|
||||||
|
|
||||||
insert_advantage(&connection, "test", PlayerId(1), PlayerId(2), 1.0)?;
|
insert_advantage(&connection, "test", PlayerId(1), PlayerId(2), 1.0)?;
|
||||||
|
|
||||||
assert!(
|
assert_eq!(
|
||||||
(hypothetical_advantage(&connection, "test", PlayerId(1), PlayerId(2))? - 1.0) < 0.1
|
hypothetical_advantage(&connection, "test", PlayerId(1), PlayerId(2))?,
|
||||||
|
1.0
|
||||||
);
|
);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|
|
@ -48,23 +48,23 @@ pub fn get_auth_token(config_dir: &Path) -> Option<String> {
|
||||||
// cynic always assumes that IDs are strings. To get around that, we define new
|
// cynic always assumes that IDs are strings. To get around that, we define new
|
||||||
// scalar types that deserialize to u64.
|
// scalar types that deserialize to u64.
|
||||||
|
|
||||||
#[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
|
#[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||||
#[cynic(graphql_type = "ID")]
|
#[cynic(graphql_type = "ID")]
|
||||||
pub struct VideogameId(pub u64);
|
pub struct VideogameId(pub u64);
|
||||||
|
|
||||||
#[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
|
#[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||||
#[cynic(graphql_type = "ID")]
|
#[cynic(graphql_type = "ID")]
|
||||||
pub struct EventId(pub u64);
|
pub struct EventId(pub u64);
|
||||||
|
|
||||||
#[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
|
#[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||||
#[cynic(graphql_type = "ID")]
|
#[cynic(graphql_type = "ID")]
|
||||||
pub struct EntrantId(pub u64);
|
pub struct EntrantId(pub u64);
|
||||||
|
|
||||||
#[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
|
#[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||||
#[cynic(graphql_type = "ID")]
|
#[cynic(graphql_type = "ID")]
|
||||||
pub struct PlayerId(pub u64);
|
pub struct PlayerId(pub u64);
|
||||||
|
|
||||||
#[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
|
#[derive(cynic::Scalar, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||||
pub struct Timestamp(pub u64);
|
pub struct Timestamp(pub u64);
|
||||||
|
|
||||||
// Query machinery
|
// Query machinery
|
||||||
|
|
Loading…
Reference in a new issue