Use Dijkstra's algorithm to construct hyp. adv.

This commit is contained in:
Kiana Sheibani 2023-11-18 20:43:26 -05:00
parent 847d879287
commit 092d75dbbb
Signed by: toki
GPG key ID: 6CB106C25E86A9F7

View file

@ -413,9 +413,11 @@ pub fn get_edges(
connection: &Connection,
dataset: &str,
player: PlayerId,
) -> sqlite::Result<Vec<(PlayerId, f64)>> {
) -> sqlite::Result<Vec<(PlayerId, f64, u64)>> {
let query = format!(
r#"SELECT iif(:pl = player_B, player_A, player_B) AS id, iif(:pl = player_B, -advantage, advantage) AS advantage
r#"SELECT
iif(:pl = player_B, player_A, player_B) AS id,
iif(:pl = player_B, -advantage, advantage) AS advantage, sets_count
FROM "{}_network"
WHERE player_A = :pl OR player_B = :pl"#,
dataset
@ -430,6 +432,7 @@ pub fn get_edges(
Ok((
PlayerId(r_.read::<i64, _>("id") as u64),
r_.read::<f64, _>("advantage"),
r_.read::<i64, _>("sets_count") as u64,
))
})
.try_collect()
@ -462,64 +465,82 @@ pub fn either_isolated(
pub fn hypothetical_advantage(
connection: &Connection,
dataset: &str,
decay_rate: f64,
player1: PlayerId,
player2: PlayerId,
set_limit: u64,
decay_rate: f64,
adj_decay_rate: f64,
) -> sqlite::Result<f64> {
use std::collections::HashSet;
use std::collections::{HashMap, HashSet};
// Check trivial cases
if player1 == player2 || either_isolated(connection, dataset, player1, player2)? {
return Ok(0.0);
}
let mut paths: Vec<(Vec<PlayerId>, f64)> = vec![(vec![player1], 0.0)];
let mut visited: HashSet<PlayerId> = HashSet::new();
let mut queue: HashMap<PlayerId, (f64, f64)> = HashMap::from([(player1, (0.0, 1.0))]);
let mut final_path: Option<(f64, usize)> = None;
while !queue.is_empty() {
let visiting = *queue
.iter()
.max_by(|a, b| a.1 .1.partial_cmp(&b.1 .1).unwrap())
.unwrap()
.0;
let (adv_v, decay_v) = queue.remove(&visiting).unwrap();
let connections = get_edges(connection, dataset, visiting)?;
while final_path.is_none() {
if paths.is_empty() {
return Ok(0.0);
}
paths = paths
for (id, adv, sets) in connections
.into_iter()
.map(|(path, old_adv)| {
let last = *path.last().unwrap();
Ok(get_edges(connection, dataset, last)?
.into_iter()
.filter_map(|(id, adv)| {
if visited.contains(&id) {
None
} else {
.filter(|(id, _, _)| !visited.contains(id))
{
let advantage = adv_v + adv;
if id == player2 {
final_path = Some((old_adv + adv, path.len() - 1));
}
visited.insert(id);
let mut path_ = path.clone();
path_.extend_one(id);
Some((path_, old_adv + adv))
}
})
.collect::<Vec<_>>())
})
.try_fold(Vec::new(), |mut acc, paths| {
acc.extend(paths?);
Ok(acc)
})?;
return Ok(advantage * decay_v);
}
let (final_adv, len) = final_path.unwrap();
Ok(final_adv * decay_rate.powi(len as i32))
let decay = decay_v
* if sets >= set_limit {
decay_rate
} else {
adj_decay_rate
};
if queue
.get(&id)
.map(|(_, decay_old)| *decay_old < decay)
.unwrap_or(true)
{
queue.insert(id, (advantage, decay));
}
}
visited.insert(visiting);
}
// No path found
Ok(0.0)
}
pub fn initialize_edge(
connection: &Connection,
dataset: &str,
decay_rate: f64,
player1: PlayerId,
player2: PlayerId,
set_limit: u64,
decay_rate: f64,
adj_decay_rate: f64,
) -> sqlite::Result<f64> {
let adv = hypothetical_advantage(connection, dataset, decay_rate, player1, player2)?;
let adv = hypothetical_advantage(
connection,
dataset,
player1,
player2,
set_limit,
decay_rate,
adj_decay_rate,
)?;
insert_advantage(connection, dataset, player1, player2, adv)?;
Ok(adv)
}
@ -644,15 +665,15 @@ CREATE TABLE IF NOT EXISTS datasets (
assert_eq!(
get_edges(&connection, "test", PlayerId(1))?,
[(PlayerId(2), -1.0), (PlayerId(3), 5.0)]
[(PlayerId(2), -1.0, 0), (PlayerId(3), 5.0, 0)]
);
assert_eq!(
get_edges(&connection, "test", PlayerId(2))?,
[(PlayerId(1), 1.0)]
[(PlayerId(1), 1.0, 0)]
);
assert_eq!(
get_edges(&connection, "test", PlayerId(3))?,
[(PlayerId(1), -5.0)]
[(PlayerId(1), -5.0, 0)]
);
Ok(())
}
@ -670,10 +691,19 @@ CREATE TABLE IF NOT EXISTS datasets (
Timestamp(0),
)?;
let metadata = metadata();
for i in 1..=num_players {
for j in 1..=num_players {
assert_eq!(
hypothetical_advantage(&connection, "test", 0.5, PlayerId(i), PlayerId(j))?,
hypothetical_advantage(
&connection,
"test",
PlayerId(i),
PlayerId(j),
metadata.set_limit,
metadata.decay_rate,
metadata.adj_decay_rate
)?,
0.0
);
}
@ -690,8 +720,17 @@ CREATE TABLE IF NOT EXISTS datasets (
insert_advantage(&connection, "test", PlayerId(1), PlayerId(2), 1.0)?;
let metadata = metadata();
assert_eq!(
hypothetical_advantage(&connection, "test", 0.5, PlayerId(1), PlayerId(2))?,
hypothetical_advantage(
&connection,
"test",
PlayerId(1),
PlayerId(2),
metadata.set_limit,
metadata.decay_rate,
metadata.adj_decay_rate
)?,
1.0
);