diff --git a/Cargo.toml b/Cargo.toml index f408efb..1a11170 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,3 +23,4 @@ crc64 = "2.0.0" anyhow = "1.0" thiserror = "1.0" arrayvec = "0.7.4" +num_cpus = "1.16.0" diff --git a/src/engine/midgame.rs b/src/engine/midgame.rs index 4f74a0f..ea2f42f 100644 --- a/src/engine/midgame.rs +++ b/src/engine/midgame.rs @@ -8,7 +8,12 @@ use futures::channel::mpsc; use futures::channel::mpsc::UnboundedSender; use futures::future::{BoxFuture, FutureExt}; use futures::StreamExt; +use num_cpus; use std::cmp::max; +use std::collections::HashSet; +use std::collections::VecDeque; +use std::sync::{Arc, Mutex}; +use std::thread; struct YBWCContext { tx: UnboundedSender<((i8, Option), SolveStat)>, @@ -225,3 +230,201 @@ where } .boxed() } + +fn simplified_abdada_body<'a>( + solve_obj: &mut SolveObj, + sub_solver: &SubSolver, + board: Board, + (mut alpha, beta): (i8, i8), + passed: bool, + depth: i8, + cs_hash: &'a Arc>>, +) -> (i8, Option, SolveStat) { + let v = move_ordering_impl(solve_obj, board, None); + let mut stat = SolveStat::one(); + if v.is_empty() { + if passed { + return (board.score(), Some(Hand::Pass), stat); + } else { + let (child_res, _child_best, child_stat) = simplified_abdada_intro( + solve_obj, + sub_solver, + board.pass_unchecked(), + (-beta, -alpha), + true, + depth, + cs_hash, + ); + stat.merge(child_stat); + return (-child_res, Some(Hand::Pass), stat); + } + } + let mut q = VecDeque::with_capacity(v.len()); + for (pos, next) in v { + q.push_back((pos, next, false)); + } + let mut res = -(BOARD_SIZE as i8); + let mut best = None; + let mut is_first = true; + let mut stat = SolveStat::one(); + while let Some((pos, next, deffered)) = q.pop_front() { + if is_first { + start_search(next, &cs_hash); + let (cres, _chand, cstat) = simplified_abdada_intro( + solve_obj, + sub_solver, + next, + (-beta, -alpha), + false, + depth + 1, + cs_hash, + ); + finish_search(next, &cs_hash); + stat.merge(cstat); + res = -cres; + best = Some(pos); + alpha = max(alpha, res); + if alpha >= beta { + return (res, best, stat); + } + is_first = false; + continue; + } + if !deffered && defer_search(next, &cs_hash) { + q.push_back((pos, next, true)); + continue; + } + // NWS + start_search(next, &cs_hash); + let (cres, _chand, cstat) = simplified_abdada_intro( + solve_obj, + sub_solver, + next, + (-alpha - 1, -alpha), + false, + depth + 1, + cs_hash, + ); + finish_search(next, &cs_hash); + stat.merge(cstat); + let mut tmp = -cres; + if alpha < tmp && tmp < beta { + let (cres, _chand, cstat) = simplified_abdada_intro( + solve_obj, + sub_solver, + next, + (-beta, -tmp), + false, + depth + 1, + cs_hash, + ); + stat.merge(cstat); + tmp = -cres; + } + if tmp >= beta { + return (tmp, Some(pos), stat); + } + if tmp > res { + best = Some(pos); + res = tmp; + } + } + (res, best, stat) +} + +fn simplified_abdada_intro<'a>( + solve_obj: &mut SolveObj, + sub_solver: &SubSolver, + board: Board, + (mut alpha, mut beta): (i8, i8), + passed: bool, + depth: i8, + cs_hash: &'a Arc>>, +) -> (i8, Option, SolveStat) { + let rem = popcnt(board.empty()); + if depth >= solve_obj.params.ybwc_depth_limit || rem < solve_obj.params.ybwc_empties_limit { + let (res, stat) = solve_inner(solve_obj, board, (alpha, beta), passed); + return (res, None, stat); + } + match stability_cut(board, (alpha, beta)) { + CutType::NoCut => (), + CutType::MoreThanBeta(v) => return (v, None, SolveStat::one_stcut()), + CutType::LessThanAlpha(v) => return (v, None, SolveStat::one_stcut()), + } + let (lower, upper, _old_best) = match lookup_table(solve_obj, board, (&mut alpha, &mut beta)) { + CacheLookupResult::Cut(v) => return (v, None, SolveStat::zero()), + CacheLookupResult::NoCut(l, u, b) => (l, u, b), + }; + let (res, best, stat) = simplified_abdada_body( + solve_obj, + sub_solver, + board, + (alpha, beta), + passed, + depth, + cs_hash, + ); + if rem >= solve_obj.params.res_cache_limit { + update_table( + solve_obj.res_cache.clone(), + solve_obj.cache_gen, + board, + res, + best, + (alpha, beta), + (lower, upper), + ); + } + (res, best, stat) +} + +fn start_search(board: Board, cs_hash: &Arc>>) { + cs_hash.lock().unwrap().insert(board); +} + +fn finish_search(board: Board, cs_hash: &Arc>>) { + cs_hash.lock().unwrap().remove(&board); +} + +fn defer_search(board: Board, cs_hash: &Arc>>) -> bool { + cs_hash.lock().unwrap().contains(&board) +} + +pub fn simplified_abdada( + solve_obj: &mut SolveObj, + sub_solver: &SubSolver, + board: Board, + (alpha, beta): (i8, i8), + passed: bool, + depth: i8, +) -> (i8, Option, SolveStat) { + thread::scope(|s| { + let mut handles = Vec::new(); + let cs_hash = Arc::new(Mutex::new(HashSet::new())); + for _ in 0..num_cpus::get() { + let mut solve_obj = solve_obj.clone(); + let cs_hash = cs_hash.clone(); + handles.push(s.spawn(move || { + simplified_abdada_intro( + &mut solve_obj, + sub_solver, + board, + (alpha, beta), + passed, + depth, + &cs_hash, + ) + })); + } + let mut stat = SolveStat::zero(); + let mut res = -(BOARD_SIZE as i8); + let mut best = None; + for h in handles { + let (tres, tbest, tstat) = h.join().unwrap(); + stat.merge(tstat); + res = tres; + best = tbest; + } + (res, best, stat) + }) +} diff --git a/src/engine/search.rs b/src/engine/search.rs index 1265280..4155f3a 100644 --- a/src/engine/search.rs +++ b/src/engine/search.rs @@ -17,7 +17,6 @@ use std::cmp::{max, min}; use std::io::Write; use std::mem::swap; use std::sync::Arc; -use tokio::runtime::Runtime; use tokio::sync::Semaphore; #[derive(Debug, Serialize, Deserialize)] @@ -277,11 +276,8 @@ pub fn solve( passed: bool, depth: i8, ) -> (i8, Option, SolveStat) { - let rt = Runtime::new().unwrap(); - rt.block_on(async move { - let sub_solver = SubSolver::new(worker_urls); - solve_outer(solve_obj, &sub_solver, board, (alpha, beta), passed, depth).await - }) + let sub_solver = SubSolver::new(worker_urls); + simplified_abdada(solve_obj, &sub_solver, board, (alpha, beta), passed, depth) } pub async fn solve_with_move(board: Board, solve_obj: &mut SolveObj, sub_solver: &Arc) -> Hand {