Use Dijkstra's algorithm to construct hyp. adv.
This commit is contained in:
parent
847d879287
commit
092d75dbbb
123
src/datasets.rs
123
src/datasets.rs
|
@ -413,9 +413,11 @@ pub fn get_edges(
|
|||
connection: &Connection,
|
||||
dataset: &str,
|
||||
player: PlayerId,
|
||||
) -> sqlite::Result<Vec<(PlayerId, f64)>> {
|
||||
) -> sqlite::Result<Vec<(PlayerId, f64, u64)>> {
|
||||
let query = format!(
|
||||
r#"SELECT iif(:pl = player_B, player_A, player_B) AS id, iif(:pl = player_B, -advantage, advantage) AS advantage
|
||||
r#"SELECT
|
||||
iif(:pl = player_B, player_A, player_B) AS id,
|
||||
iif(:pl = player_B, -advantage, advantage) AS advantage, sets_count
|
||||
FROM "{}_network"
|
||||
WHERE player_A = :pl OR player_B = :pl"#,
|
||||
dataset
|
||||
|
@ -430,6 +432,7 @@ pub fn get_edges(
|
|||
Ok((
|
||||
PlayerId(r_.read::<i64, _>("id") as u64),
|
||||
r_.read::<f64, _>("advantage"),
|
||||
r_.read::<i64, _>("sets_count") as u64,
|
||||
))
|
||||
})
|
||||
.try_collect()
|
||||
|
@ -462,64 +465,82 @@ pub fn either_isolated(
|
|||
pub fn hypothetical_advantage(
|
||||
connection: &Connection,
|
||||
dataset: &str,
|
||||
decay_rate: f64,
|
||||
player1: PlayerId,
|
||||
player2: PlayerId,
|
||||
set_limit: u64,
|
||||
decay_rate: f64,
|
||||
adj_decay_rate: f64,
|
||||
) -> sqlite::Result<f64> {
|
||||
use std::collections::HashSet;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
// Check trivial cases
|
||||
if player1 == player2 || either_isolated(connection, dataset, player1, player2)? {
|
||||
return Ok(0.0);
|
||||
}
|
||||
|
||||
let mut paths: Vec<(Vec<PlayerId>, f64)> = vec![(vec![player1], 0.0)];
|
||||
let mut visited: HashSet<PlayerId> = HashSet::new();
|
||||
let mut queue: HashMap<PlayerId, (f64, f64)> = HashMap::from([(player1, (0.0, 1.0))]);
|
||||
|
||||
let mut final_path: Option<(f64, usize)> = None;
|
||||
while !queue.is_empty() {
|
||||
let visiting = *queue
|
||||
.iter()
|
||||
.max_by(|a, b| a.1 .1.partial_cmp(&b.1 .1).unwrap())
|
||||
.unwrap()
|
||||
.0;
|
||||
let (adv_v, decay_v) = queue.remove(&visiting).unwrap();
|
||||
let connections = get_edges(connection, dataset, visiting)?;
|
||||
|
||||
while final_path.is_none() {
|
||||
if paths.is_empty() {
|
||||
return Ok(0.0);
|
||||
}
|
||||
paths = paths
|
||||
for (id, adv, sets) in connections
|
||||
.into_iter()
|
||||
.map(|(path, old_adv)| {
|
||||
let last = *path.last().unwrap();
|
||||
Ok(get_edges(connection, dataset, last)?
|
||||
.into_iter()
|
||||
.filter_map(|(id, adv)| {
|
||||
if visited.contains(&id) {
|
||||
None
|
||||
} else {
|
||||
.filter(|(id, _, _)| !visited.contains(id))
|
||||
{
|
||||
let advantage = adv_v + adv;
|
||||
|
||||
if id == player2 {
|
||||
final_path = Some((old_adv + adv, path.len() - 1));
|
||||
}
|
||||
visited.insert(id);
|
||||
let mut path_ = path.clone();
|
||||
path_.extend_one(id);
|
||||
Some((path_, old_adv + adv))
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>())
|
||||
})
|
||||
.try_fold(Vec::new(), |mut acc, paths| {
|
||||
acc.extend(paths?);
|
||||
Ok(acc)
|
||||
})?;
|
||||
return Ok(advantage * decay_v);
|
||||
}
|
||||
|
||||
let (final_adv, len) = final_path.unwrap();
|
||||
Ok(final_adv * decay_rate.powi(len as i32))
|
||||
let decay = decay_v
|
||||
* if sets >= set_limit {
|
||||
decay_rate
|
||||
} else {
|
||||
adj_decay_rate
|
||||
};
|
||||
|
||||
if queue
|
||||
.get(&id)
|
||||
.map(|(_, decay_old)| *decay_old < decay)
|
||||
.unwrap_or(true)
|
||||
{
|
||||
queue.insert(id, (advantage, decay));
|
||||
}
|
||||
}
|
||||
|
||||
visited.insert(visiting);
|
||||
}
|
||||
|
||||
// No path found
|
||||
Ok(0.0)
|
||||
}
|
||||
|
||||
pub fn initialize_edge(
|
||||
connection: &Connection,
|
||||
dataset: &str,
|
||||
decay_rate: f64,
|
||||
player1: PlayerId,
|
||||
player2: PlayerId,
|
||||
set_limit: u64,
|
||||
decay_rate: f64,
|
||||
adj_decay_rate: f64,
|
||||
) -> sqlite::Result<f64> {
|
||||
let adv = hypothetical_advantage(connection, dataset, decay_rate, player1, player2)?;
|
||||
let adv = hypothetical_advantage(
|
||||
connection,
|
||||
dataset,
|
||||
player1,
|
||||
player2,
|
||||
set_limit,
|
||||
decay_rate,
|
||||
adj_decay_rate,
|
||||
)?;
|
||||
insert_advantage(connection, dataset, player1, player2, adv)?;
|
||||
Ok(adv)
|
||||
}
|
||||
|
@ -644,15 +665,15 @@ CREATE TABLE IF NOT EXISTS datasets (
|
|||
|
||||
assert_eq!(
|
||||
get_edges(&connection, "test", PlayerId(1))?,
|
||||
[(PlayerId(2), -1.0), (PlayerId(3), 5.0)]
|
||||
[(PlayerId(2), -1.0, 0), (PlayerId(3), 5.0, 0)]
|
||||
);
|
||||
assert_eq!(
|
||||
get_edges(&connection, "test", PlayerId(2))?,
|
||||
[(PlayerId(1), 1.0)]
|
||||
[(PlayerId(1), 1.0, 0)]
|
||||
);
|
||||
assert_eq!(
|
||||
get_edges(&connection, "test", PlayerId(3))?,
|
||||
[(PlayerId(1), -5.0)]
|
||||
[(PlayerId(1), -5.0, 0)]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
@ -670,10 +691,19 @@ CREATE TABLE IF NOT EXISTS datasets (
|
|||
Timestamp(0),
|
||||
)?;
|
||||
|
||||
let metadata = metadata();
|
||||
for i in 1..=num_players {
|
||||
for j in 1..=num_players {
|
||||
assert_eq!(
|
||||
hypothetical_advantage(&connection, "test", 0.5, PlayerId(i), PlayerId(j))?,
|
||||
hypothetical_advantage(
|
||||
&connection,
|
||||
"test",
|
||||
PlayerId(i),
|
||||
PlayerId(j),
|
||||
metadata.set_limit,
|
||||
metadata.decay_rate,
|
||||
metadata.adj_decay_rate
|
||||
)?,
|
||||
0.0
|
||||
);
|
||||
}
|
||||
|
@ -690,8 +720,17 @@ CREATE TABLE IF NOT EXISTS datasets (
|
|||
|
||||
insert_advantage(&connection, "test", PlayerId(1), PlayerId(2), 1.0)?;
|
||||
|
||||
let metadata = metadata();
|
||||
assert_eq!(
|
||||
hypothetical_advantage(&connection, "test", 0.5, PlayerId(1), PlayerId(2))?,
|
||||
hypothetical_advantage(
|
||||
&connection,
|
||||
"test",
|
||||
PlayerId(1),
|
||||
PlayerId(2),
|
||||
metadata.set_limit,
|
||||
metadata.decay_rate,
|
||||
metadata.adj_decay_rate
|
||||
)?,
|
||||
1.0
|
||||
);
|
||||
|
||||
|
|
Loading…
Reference in a new issue