Use Dijkstra's algorithm to construct hyp. adv.

This commit is contained in:
Kiana Sheibani 2023-11-18 20:43:26 -05:00
parent 847d879287
commit 092d75dbbb
Signed by: toki
GPG key ID: 6CB106C25E86A9F7

View file

@ -413,11 +413,13 @@ pub fn get_edges(
connection: &Connection, connection: &Connection,
dataset: &str, dataset: &str,
player: PlayerId, player: PlayerId,
) -> sqlite::Result<Vec<(PlayerId, f64)>> { ) -> sqlite::Result<Vec<(PlayerId, f64, u64)>> {
let query = format!( let query = format!(
r#"SELECT iif(:pl = player_B, player_A, player_B) AS id, iif(:pl = player_B, -advantage, advantage) AS advantage r#"SELECT
FROM "{}_network" iif(:pl = player_B, player_A, player_B) AS id,
WHERE player_A = :pl OR player_B = :pl"#, iif(:pl = player_B, -advantage, advantage) AS advantage, sets_count
FROM "{}_network"
WHERE player_A = :pl OR player_B = :pl"#,
dataset dataset
); );
@ -430,6 +432,7 @@ pub fn get_edges(
Ok(( Ok((
PlayerId(r_.read::<i64, _>("id") as u64), PlayerId(r_.read::<i64, _>("id") as u64),
r_.read::<f64, _>("advantage"), r_.read::<f64, _>("advantage"),
r_.read::<i64, _>("sets_count") as u64,
)) ))
}) })
.try_collect() .try_collect()
@ -462,64 +465,82 @@ pub fn either_isolated(
pub fn hypothetical_advantage( pub fn hypothetical_advantage(
connection: &Connection, connection: &Connection,
dataset: &str, dataset: &str,
decay_rate: f64,
player1: PlayerId, player1: PlayerId,
player2: PlayerId, player2: PlayerId,
set_limit: u64,
decay_rate: f64,
adj_decay_rate: f64,
) -> sqlite::Result<f64> { ) -> sqlite::Result<f64> {
use std::collections::HashSet; use std::collections::{HashMap, HashSet};
// Check trivial cases // Check trivial cases
if player1 == player2 || either_isolated(connection, dataset, player1, player2)? { if player1 == player2 || either_isolated(connection, dataset, player1, player2)? {
return Ok(0.0); return Ok(0.0);
} }
let mut paths: Vec<(Vec<PlayerId>, f64)> = vec![(vec![player1], 0.0)];
let mut visited: HashSet<PlayerId> = HashSet::new(); let mut visited: HashSet<PlayerId> = HashSet::new();
let mut queue: HashMap<PlayerId, (f64, f64)> = HashMap::from([(player1, (0.0, 1.0))]);
let mut final_path: Option<(f64, usize)> = None; while !queue.is_empty() {
let visiting = *queue
.iter()
.max_by(|a, b| a.1 .1.partial_cmp(&b.1 .1).unwrap())
.unwrap()
.0;
let (adv_v, decay_v) = queue.remove(&visiting).unwrap();
let connections = get_edges(connection, dataset, visiting)?;
while final_path.is_none() { for (id, adv, sets) in connections
if paths.is_empty() {
return Ok(0.0);
}
paths = paths
.into_iter() .into_iter()
.map(|(path, old_adv)| { .filter(|(id, _, _)| !visited.contains(id))
let last = *path.last().unwrap(); {
Ok(get_edges(connection, dataset, last)? let advantage = adv_v + adv;
.into_iter()
.filter_map(|(id, adv)| { if id == player2 {
if visited.contains(&id) { return Ok(advantage * decay_v);
None }
} else {
if id == player2 { let decay = decay_v
final_path = Some((old_adv + adv, path.len() - 1)); * if sets >= set_limit {
} decay_rate
visited.insert(id); } else {
let mut path_ = path.clone(); adj_decay_rate
path_.extend_one(id); };
Some((path_, old_adv + adv))
} if queue
}) .get(&id)
.collect::<Vec<_>>()) .map(|(_, decay_old)| *decay_old < decay)
}) .unwrap_or(true)
.try_fold(Vec::new(), |mut acc, paths| { {
acc.extend(paths?); queue.insert(id, (advantage, decay));
Ok(acc) }
})?; }
visited.insert(visiting);
} }
let (final_adv, len) = final_path.unwrap(); // No path found
Ok(final_adv * decay_rate.powi(len as i32)) Ok(0.0)
} }
pub fn initialize_edge( pub fn initialize_edge(
connection: &Connection, connection: &Connection,
dataset: &str, dataset: &str,
decay_rate: f64,
player1: PlayerId, player1: PlayerId,
player2: PlayerId, player2: PlayerId,
set_limit: u64,
decay_rate: f64,
adj_decay_rate: f64,
) -> sqlite::Result<f64> { ) -> sqlite::Result<f64> {
let adv = hypothetical_advantage(connection, dataset, decay_rate, player1, player2)?; let adv = hypothetical_advantage(
connection,
dataset,
player1,
player2,
set_limit,
decay_rate,
adj_decay_rate,
)?;
insert_advantage(connection, dataset, player1, player2, adv)?; insert_advantage(connection, dataset, player1, player2, adv)?;
Ok(adv) Ok(adv)
} }
@ -644,15 +665,15 @@ CREATE TABLE IF NOT EXISTS datasets (
assert_eq!( assert_eq!(
get_edges(&connection, "test", PlayerId(1))?, get_edges(&connection, "test", PlayerId(1))?,
[(PlayerId(2), -1.0), (PlayerId(3), 5.0)] [(PlayerId(2), -1.0, 0), (PlayerId(3), 5.0, 0)]
); );
assert_eq!( assert_eq!(
get_edges(&connection, "test", PlayerId(2))?, get_edges(&connection, "test", PlayerId(2))?,
[(PlayerId(1), 1.0)] [(PlayerId(1), 1.0, 0)]
); );
assert_eq!( assert_eq!(
get_edges(&connection, "test", PlayerId(3))?, get_edges(&connection, "test", PlayerId(3))?,
[(PlayerId(1), -5.0)] [(PlayerId(1), -5.0, 0)]
); );
Ok(()) Ok(())
} }
@ -670,10 +691,19 @@ CREATE TABLE IF NOT EXISTS datasets (
Timestamp(0), Timestamp(0),
)?; )?;
let metadata = metadata();
for i in 1..=num_players { for i in 1..=num_players {
for j in 1..=num_players { for j in 1..=num_players {
assert_eq!( assert_eq!(
hypothetical_advantage(&connection, "test", 0.5, PlayerId(i), PlayerId(j))?, hypothetical_advantage(
&connection,
"test",
PlayerId(i),
PlayerId(j),
metadata.set_limit,
metadata.decay_rate,
metadata.adj_decay_rate
)?,
0.0 0.0
); );
} }
@ -690,8 +720,17 @@ CREATE TABLE IF NOT EXISTS datasets (
insert_advantage(&connection, "test", PlayerId(1), PlayerId(2), 1.0)?; insert_advantage(&connection, "test", PlayerId(1), PlayerId(2), 1.0)?;
let metadata = metadata();
assert_eq!( assert_eq!(
hypothetical_advantage(&connection, "test", 0.5, PlayerId(1), PlayerId(2))?, hypothetical_advantage(
&connection,
"test",
PlayerId(1),
PlayerId(2),
metadata.set_limit,
metadata.decay_rate,
metadata.adj_decay_rate
)?,
1.0 1.0
); );