Use Dijkstra's algorithm to construct hyp. adv.
This commit is contained in:
parent
847d879287
commit
092d75dbbb
123
src/datasets.rs
123
src/datasets.rs
|
@ -413,9 +413,11 @@ pub fn get_edges(
|
||||||
connection: &Connection,
|
connection: &Connection,
|
||||||
dataset: &str,
|
dataset: &str,
|
||||||
player: PlayerId,
|
player: PlayerId,
|
||||||
) -> sqlite::Result<Vec<(PlayerId, f64)>> {
|
) -> sqlite::Result<Vec<(PlayerId, f64, u64)>> {
|
||||||
let query = format!(
|
let query = format!(
|
||||||
r#"SELECT iif(:pl = player_B, player_A, player_B) AS id, iif(:pl = player_B, -advantage, advantage) AS advantage
|
r#"SELECT
|
||||||
|
iif(:pl = player_B, player_A, player_B) AS id,
|
||||||
|
iif(:pl = player_B, -advantage, advantage) AS advantage, sets_count
|
||||||
FROM "{}_network"
|
FROM "{}_network"
|
||||||
WHERE player_A = :pl OR player_B = :pl"#,
|
WHERE player_A = :pl OR player_B = :pl"#,
|
||||||
dataset
|
dataset
|
||||||
|
@ -430,6 +432,7 @@ pub fn get_edges(
|
||||||
Ok((
|
Ok((
|
||||||
PlayerId(r_.read::<i64, _>("id") as u64),
|
PlayerId(r_.read::<i64, _>("id") as u64),
|
||||||
r_.read::<f64, _>("advantage"),
|
r_.read::<f64, _>("advantage"),
|
||||||
|
r_.read::<i64, _>("sets_count") as u64,
|
||||||
))
|
))
|
||||||
})
|
})
|
||||||
.try_collect()
|
.try_collect()
|
||||||
|
@ -462,64 +465,82 @@ pub fn either_isolated(
|
||||||
pub fn hypothetical_advantage(
|
pub fn hypothetical_advantage(
|
||||||
connection: &Connection,
|
connection: &Connection,
|
||||||
dataset: &str,
|
dataset: &str,
|
||||||
decay_rate: f64,
|
|
||||||
player1: PlayerId,
|
player1: PlayerId,
|
||||||
player2: PlayerId,
|
player2: PlayerId,
|
||||||
|
set_limit: u64,
|
||||||
|
decay_rate: f64,
|
||||||
|
adj_decay_rate: f64,
|
||||||
) -> sqlite::Result<f64> {
|
) -> sqlite::Result<f64> {
|
||||||
use std::collections::HashSet;
|
use std::collections::{HashMap, HashSet};
|
||||||
|
|
||||||
// Check trivial cases
|
// Check trivial cases
|
||||||
if player1 == player2 || either_isolated(connection, dataset, player1, player2)? {
|
if player1 == player2 || either_isolated(connection, dataset, player1, player2)? {
|
||||||
return Ok(0.0);
|
return Ok(0.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut paths: Vec<(Vec<PlayerId>, f64)> = vec![(vec![player1], 0.0)];
|
|
||||||
let mut visited: HashSet<PlayerId> = HashSet::new();
|
let mut visited: HashSet<PlayerId> = HashSet::new();
|
||||||
|
let mut queue: HashMap<PlayerId, (f64, f64)> = HashMap::from([(player1, (0.0, 1.0))]);
|
||||||
|
|
||||||
let mut final_path: Option<(f64, usize)> = None;
|
while !queue.is_empty() {
|
||||||
|
let visiting = *queue
|
||||||
|
.iter()
|
||||||
|
.max_by(|a, b| a.1 .1.partial_cmp(&b.1 .1).unwrap())
|
||||||
|
.unwrap()
|
||||||
|
.0;
|
||||||
|
let (adv_v, decay_v) = queue.remove(&visiting).unwrap();
|
||||||
|
let connections = get_edges(connection, dataset, visiting)?;
|
||||||
|
|
||||||
while final_path.is_none() {
|
for (id, adv, sets) in connections
|
||||||
if paths.is_empty() {
|
|
||||||
return Ok(0.0);
|
|
||||||
}
|
|
||||||
paths = paths
|
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|(path, old_adv)| {
|
.filter(|(id, _, _)| !visited.contains(id))
|
||||||
let last = *path.last().unwrap();
|
{
|
||||||
Ok(get_edges(connection, dataset, last)?
|
let advantage = adv_v + adv;
|
||||||
.into_iter()
|
|
||||||
.filter_map(|(id, adv)| {
|
|
||||||
if visited.contains(&id) {
|
|
||||||
None
|
|
||||||
} else {
|
|
||||||
if id == player2 {
|
if id == player2 {
|
||||||
final_path = Some((old_adv + adv, path.len() - 1));
|
return Ok(advantage * decay_v);
|
||||||
}
|
|
||||||
visited.insert(id);
|
|
||||||
let mut path_ = path.clone();
|
|
||||||
path_.extend_one(id);
|
|
||||||
Some((path_, old_adv + adv))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect::<Vec<_>>())
|
|
||||||
})
|
|
||||||
.try_fold(Vec::new(), |mut acc, paths| {
|
|
||||||
acc.extend(paths?);
|
|
||||||
Ok(acc)
|
|
||||||
})?;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let (final_adv, len) = final_path.unwrap();
|
let decay = decay_v
|
||||||
Ok(final_adv * decay_rate.powi(len as i32))
|
* if sets >= set_limit {
|
||||||
|
decay_rate
|
||||||
|
} else {
|
||||||
|
adj_decay_rate
|
||||||
|
};
|
||||||
|
|
||||||
|
if queue
|
||||||
|
.get(&id)
|
||||||
|
.map(|(_, decay_old)| *decay_old < decay)
|
||||||
|
.unwrap_or(true)
|
||||||
|
{
|
||||||
|
queue.insert(id, (advantage, decay));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
visited.insert(visiting);
|
||||||
|
}
|
||||||
|
|
||||||
|
// No path found
|
||||||
|
Ok(0.0)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn initialize_edge(
|
pub fn initialize_edge(
|
||||||
connection: &Connection,
|
connection: &Connection,
|
||||||
dataset: &str,
|
dataset: &str,
|
||||||
decay_rate: f64,
|
|
||||||
player1: PlayerId,
|
player1: PlayerId,
|
||||||
player2: PlayerId,
|
player2: PlayerId,
|
||||||
|
set_limit: u64,
|
||||||
|
decay_rate: f64,
|
||||||
|
adj_decay_rate: f64,
|
||||||
) -> sqlite::Result<f64> {
|
) -> sqlite::Result<f64> {
|
||||||
let adv = hypothetical_advantage(connection, dataset, decay_rate, player1, player2)?;
|
let adv = hypothetical_advantage(
|
||||||
|
connection,
|
||||||
|
dataset,
|
||||||
|
player1,
|
||||||
|
player2,
|
||||||
|
set_limit,
|
||||||
|
decay_rate,
|
||||||
|
adj_decay_rate,
|
||||||
|
)?;
|
||||||
insert_advantage(connection, dataset, player1, player2, adv)?;
|
insert_advantage(connection, dataset, player1, player2, adv)?;
|
||||||
Ok(adv)
|
Ok(adv)
|
||||||
}
|
}
|
||||||
|
@ -644,15 +665,15 @@ CREATE TABLE IF NOT EXISTS datasets (
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
get_edges(&connection, "test", PlayerId(1))?,
|
get_edges(&connection, "test", PlayerId(1))?,
|
||||||
[(PlayerId(2), -1.0), (PlayerId(3), 5.0)]
|
[(PlayerId(2), -1.0, 0), (PlayerId(3), 5.0, 0)]
|
||||||
);
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
get_edges(&connection, "test", PlayerId(2))?,
|
get_edges(&connection, "test", PlayerId(2))?,
|
||||||
[(PlayerId(1), 1.0)]
|
[(PlayerId(1), 1.0, 0)]
|
||||||
);
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
get_edges(&connection, "test", PlayerId(3))?,
|
get_edges(&connection, "test", PlayerId(3))?,
|
||||||
[(PlayerId(1), -5.0)]
|
[(PlayerId(1), -5.0, 0)]
|
||||||
);
|
);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -670,10 +691,19 @@ CREATE TABLE IF NOT EXISTS datasets (
|
||||||
Timestamp(0),
|
Timestamp(0),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
|
let metadata = metadata();
|
||||||
for i in 1..=num_players {
|
for i in 1..=num_players {
|
||||||
for j in 1..=num_players {
|
for j in 1..=num_players {
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
hypothetical_advantage(&connection, "test", 0.5, PlayerId(i), PlayerId(j))?,
|
hypothetical_advantage(
|
||||||
|
&connection,
|
||||||
|
"test",
|
||||||
|
PlayerId(i),
|
||||||
|
PlayerId(j),
|
||||||
|
metadata.set_limit,
|
||||||
|
metadata.decay_rate,
|
||||||
|
metadata.adj_decay_rate
|
||||||
|
)?,
|
||||||
0.0
|
0.0
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -690,8 +720,17 @@ CREATE TABLE IF NOT EXISTS datasets (
|
||||||
|
|
||||||
insert_advantage(&connection, "test", PlayerId(1), PlayerId(2), 1.0)?;
|
insert_advantage(&connection, "test", PlayerId(1), PlayerId(2), 1.0)?;
|
||||||
|
|
||||||
|
let metadata = metadata();
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
hypothetical_advantage(&connection, "test", 0.5, PlayerId(1), PlayerId(2))?,
|
hypothetical_advantage(
|
||||||
|
&connection,
|
||||||
|
"test",
|
||||||
|
PlayerId(1),
|
||||||
|
PlayerId(2),
|
||||||
|
metadata.set_limit,
|
||||||
|
metadata.decay_rate,
|
||||||
|
metadata.adj_decay_rate
|
||||||
|
)?,
|
||||||
1.0
|
1.0
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue