diff --git a/src/database.rs b/src/database.rs index 5e5533b..d5f3a46 100644 --- a/src/database.rs +++ b/src/database.rs @@ -402,6 +402,32 @@ pub fn get_player_set_counts( )) } +pub fn get_matchup_set_counts( + connection: &Connection, + dataset: &str, + player1: PlayerId, + player2: PlayerId, +) -> sqlite::Result<(u64, u64)> { + if player1 == player2 { + return Ok((0, 0)); + } + + let query = format!( + r#"SELECT iif(:a > :b, sets_count_B, sets_count_A) sets_count_A, iif(:a > :b, sets_count_A, sets_count_B) sets_count_B + FROM "{}_network" WHERE player_A = min(:a, :b) AND player_B = max(:a, :b)"#, + dataset + ); + + let mut statement = connection.prepare(&query)?; + statement.bind((":a", player1.0 as i64))?; + statement.bind((":b", player2.0 as i64))?; + statement.next()?; + Ok(( + statement.read::("sets_count_A")? as u64, + statement.read::("sets_count_B")? as u64, + )) +} + pub fn set_player_data( connection: &Connection, dataset: &str, @@ -484,13 +510,10 @@ pub fn adjust_advantages( adjust1: f64, adjust2: f64, decay_rate: f64, - adj_decay_rate: f64, - set_limit: u64, ) -> sqlite::Result<()> { let query1 = format!( r#"UPDATE "{}_network" -SET advantage = advantage + iif(:pl = player_A, -:v, :v) - * iif(sets_count >= :sl, :d, :da) +SET advantage = advantage + iif(:pl = player_A, -:v, :v) * :d WHERE (player_A = :pl AND player_B != :plo) OR (player_B = :pl AND player_A != :plo)"#, dataset @@ -508,18 +531,14 @@ WHERE player_A = min(:a, :b) AND player_B = max(:a, :b)"#, statement.bind((":pl", player1.0 as i64))?; statement.bind((":plo", player2.0 as i64))?; statement.bind((":v", adjust1))?; - statement.bind((":sl", set_limit as i64))?; statement.bind((":d", decay_rate))?; - statement.bind((":da", adj_decay_rate))?; statement.into_iter().try_for_each(|x| x.map(|_| ()))?; statement = connection.prepare(&query1)?; statement.bind((":pl", player2.0 as i64))?; statement.bind((":plo", player1.0 as i64))?; statement.bind((":v", adjust2))?; - statement.bind((":sl", set_limit as i64))?; statement.bind((":d", decay_rate))?; - statement.bind((":da", adj_decay_rate))?; statement.into_iter().try_for_each(|x| x.map(|_| ()))?; statement = connection.prepare(&query2)?; diff --git a/src/sync.rs b/src/sync.rs index 7639abc..e8a9398 100644 --- a/src/sync.rs +++ b/src/sync.rs @@ -265,6 +265,13 @@ fn update_from_set( &results.id, )?; + let (sets1, sets2) = get_matchup_set_counts(connection, dataset, player1, player2)?; + let decay_rate = if sets1 + sets2 >= metadata.set_limit { + metadata.decay_rate + } else { + metadata.adj_decay_rate + }; + adjust_advantages( connection, dataset, @@ -274,9 +281,7 @@ fn update_from_set( results.winner, adjust1, adjust2, - metadata.decay_rate, - metadata.adj_decay_rate, - metadata.set_limit, + decay_rate, ) } @@ -323,7 +328,7 @@ pub fn sync_dataset( #[cfg(test)] mod tests { use super::*; - use crate::datasets::tests::*; + use crate::database::tests::*; #[test] fn glicko_single() -> sqlite::Result<()> {