diff --git a/Cargo.toml b/Cargo.toml index f808187..f408efb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,3 +22,4 @@ reqwest = { version = "0.11.12", features = ["json"], default-features = false } crc64 = "2.0.0" anyhow = "1.0" thiserror = "1.0" +arrayvec = "0.7.4" diff --git a/src/engine.rs b/src/engine.rs index 357bda0..9c9e931 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -3,6 +3,7 @@ pub mod board; pub mod endgame; pub mod eval; pub mod hand; +pub mod last_cache; pub mod midgame; pub mod search; pub mod table; diff --git a/src/engine/board.rs b/src/engine/board.rs index 790976e..7776d09 100644 --- a/src/engine/board.rs +++ b/src/engine/board.rs @@ -140,24 +140,24 @@ impl Board { self.flip(pos) != 0 } - pub fn play(&self, pos: usize) -> Result { + pub fn play(&self, pos: usize) -> Option { if pos >= BOARD_SIZE { - return Err(UnmovableError {}.into()); + return None; } if ((self.player >> pos) & 1) != 0 || ((self.opponent >> pos) & 1) != 0 { - return Err(UnmovableError {}.into()); + return None; } let flip_bits = self.flip(pos); if flip_bits == 0 { - return Err(UnmovableError {}.into()); + return None; } - Ok(Board { + Some(Board { player: self.opponent ^ flip_bits, opponent: (self.player ^ flip_bits) | (1u64 << pos), }) } - pub fn play_hand(&self, hand: Hand) -> Result { + pub fn play_hand(&self, hand: Hand) -> Option { match hand { Hand::Play(pos) => self.play(pos), Hand::Pass => self.pass(), @@ -171,14 +171,14 @@ impl Board { } } - pub fn pass(&self) -> Result { + pub fn pass(&self) -> Option { if self.mobility_bits() == 0 { - Ok(Board { + Some(Board { player: self.opponent, opponent: self.player, }) } else { - Err(UnmovableError {}.into()) + None } } @@ -446,9 +446,9 @@ impl BoardWithColor { } } - pub fn play(&self, pos: usize) -> Result { + pub fn play(&self, pos: usize) -> Option { let board = self.board.play(pos)?; - Ok(BoardWithColor { + Some(BoardWithColor { board, is_black: !self.is_black, }) @@ -461,8 +461,8 @@ impl BoardWithColor { } } - pub fn pass(&self) -> Result { - Ok(BoardWithColor { + pub fn pass(&self) -> Option { + Some(BoardWithColor { board: self.board.pass()?, is_black: !self.is_black, }) @@ -558,7 +558,7 @@ impl Iterator for PlayIterator { while self.remain != 0 { let pos = self.remain.trailing_zeros() as usize; self.remain &= self.remain - 1; - if let Ok(next) = self.board.play(pos) { + if let Some(next) = self.board.play(pos) { return Some((next, Hand::Play(pos))); } } diff --git a/src/engine/endgame.rs b/src/engine/endgame.rs index a1287f4..4b11dff 100644 --- a/src/engine/endgame.rs +++ b/src/engine/endgame.rs @@ -5,26 +5,19 @@ use crate::engine::board::*; use crate::engine::hand::*; use crate::engine::search::*; use crate::engine::table::*; +use arrayvec::ArrayVec; use bitintr::Tzcnt; use std::cmp::max; -use std::mem::MaybeUninit; -fn near_leaf(board: Board) -> (i8, SolveStat) { - let bit = board.empty(); - let pos = bit.tzcnt() as usize; - match board.play(pos) { - Ok(next) => (-next.score(), SolveStat::one()), - Err(_) => ( - match board.pass_unchecked().play(pos) { - Ok(next) => next.score(), - Err(_) => board.score(), - }, - SolveStat { - node_count: 2, - st_cut_count: 0, - }, - ), - } +fn near_leaf(solve_obj: &mut SolveObj, board: Board) -> (i8, SolveStat) { + let (score, node_count) = solve_obj.last_cache.solve_last(board); + ( + score, + SolveStat { + node_count, + st_cut_count: 0, + }, + ) } fn naive(solve_obj: &mut SolveObj, board: Board, (mut alpha, beta): (i8, i8), passed: bool) -> (i8, SolveStat) { @@ -68,7 +61,7 @@ fn static_order(solve_obj: &mut SolveObj, board: Board, (mut alpha, beta): (i8, while remain != 0 { let pos = remain.tzcnt() as usize; remain = remain & (remain - 1); - if let Ok(next) = board.play(pos) { + if let Some(next) = board.play(pos) { pass = false; let (child_res, child_stat) = solve_inner(solve_obj, next, (-beta, -alpha), false); res = max(res, -child_res); @@ -112,19 +105,14 @@ fn negascout_impl(solve_obj: &mut SolveObj, next: Board, (alpha, beta): (i8, i8) fn fastest_first(solve_obj: &mut SolveObj, board: Board, (mut alpha, beta): (i8, i8), passed: bool) -> (i8, SolveStat) { const MAX_FFS_NEXT: usize = 20; - let nexts = MaybeUninit::<[(i8, Board); MAX_FFS_NEXT]>::uninit(); - let mut nexts = unsafe { nexts.assume_init() }; - let mut count = 0; + let mut nexts = ArrayVec::<_, MAX_FFS_NEXT>::new(); for (next, _pos) in board.next_iter() { - nexts[count] = (weighted_mobility(&next), next); - count += 1; + nexts.push((weighted_mobility(&next), next)); } - assert!(count <= MAX_FFS_NEXT); - - nexts[0..count].sort_by(|a, b| a.0.cmp(&b.0)); + nexts.sort_by(|a, b| a.0.cmp(&b.0)); let mut res = -(BOARD_SIZE as i8); let mut stat = SolveStat::one(); - for (i, &(_, next)) in nexts[0..count].iter().enumerate() { + for (i, &(_, next)) in nexts.iter().enumerate() { let (child_res, child_stat) = negascout_impl(solve_obj, next, (alpha, beta), i == 0); res = max(res, -child_res); stat.merge(child_stat); @@ -133,7 +121,7 @@ fn fastest_first(solve_obj: &mut SolveObj, board: Board, (mut alpha, beta): (i8, return (res, stat); } } - if count == 0 { + if nexts.is_empty() { if passed { return (board.score(), stat); } else { @@ -190,7 +178,7 @@ pub fn solve_inner( if rem == 0 { (board.score(), SolveStat::zero()) } else if rem == 1 { - near_leaf(board) + near_leaf(solve_obj, board) } else if rem < solve_obj.params.static_ordering_limit { naive(solve_obj, board, (alpha, beta), passed) } else if rem < solve_obj.params.ffs_ordering_limit { diff --git a/src/engine/last_cache.rs b/src/engine/last_cache.rs new file mode 100644 index 0000000..8598a14 --- /dev/null +++ b/src/engine/last_cache.rs @@ -0,0 +1,113 @@ +#[cfg(test)] +mod test; +use crate::engine::bits::*; +use crate::engine::board::*; +use bitintr::Tzcnt; + +pub struct LastCache { + table: [(i8, i8); 4096], + masks: [(u64, u64, u64); BOARD_SIZE], + indices: [(u8, u8); BOARD_SIZE], +} + +impl LastCache { + fn pre_compute(bits: u64, pos: usize) -> (i8, i8) { + let opp_mask = 0xff ^ (1 << pos); + let opp = !bits & opp_mask; + let board_first = Board { + player: bits, + opponent: opp, + }; + let board_second = Board { + player: opp, + opponent: bits, + }; + ( + popcnt(board_first.flip_unchecked(pos)), + popcnt(board_second.flip_unchecked(pos)), + ) + } + + pub fn new() -> LastCache { + let mut table = [(0, 0); 4096]; + for bits in 0..256 { + for pos in 0..8 { + let idx = bits as usize * 8 + pos; + table[idx] = Self::pre_compute(bits, pos); + } + } + let mut masks = [(0, 0, 0); BOARD_SIZE]; + let mut indices = [(0, 0); BOARD_SIZE]; + for pos in 0..BOARD_SIZE { + let row = pos / 8; + let col = pos % 8; + let col_mask = 0x0101010101010101 << col; + let diag1_mask = if row > col { + 0x8040201008040201 << ((row - col) * 8) + } else { + 0x8040201008040201 >> ((col - row) * 8) + }; + let diag1_idx = if row > col { col } else { row }; + let diag2_mask = if row + col > 7 { + 0x0102040810204080 << ((row + col - 7) * 8) + } else { + 0x0102040810204080 >> ((7 - row - col) * 8) + }; + let diag2_idx = if row + col > 7 { 7 - col } else { row }; + masks[pos] = (col_mask, diag1_mask, diag2_mask); + indices[pos] = (diag1_idx as u8, diag2_idx as u8); + } + LastCache { + table, + masks, + indices, + } + } + + unsafe fn solve_last_impl(&self, board: Board) -> (i8, usize) { + let pos = board.empty().tzcnt() as usize; + let row = pos >> 3; + let col = pos & 0b111; + let row_bits = (board.player >> (row * 8)) & 0xff; + let &(col_mask, diag1_mask, diag2_mask) = self.masks.get_unchecked(pos); + let &(diag1_idx, diag2_idx) = self.indices.get_unchecked(pos); + let &row_score = self.table.get_unchecked((row_bits as usize) * 8 + col); + let col_bits = pext(board.player, col_mask); + let &col_score = self.table.get_unchecked((col_bits as usize) * 8 + row); + let diag1_bits = pext(board.player, diag1_mask); + let &diag1_score = self + .table + .get_unchecked((diag1_bits as usize) * 8 + diag1_idx as usize); + let diag2_bits = pext(board.player, diag2_mask); + let &diag2_score = self + .table + .get_unchecked((diag2_bits as usize) * 8 + diag2_idx as usize); + let pcnt = popcnt(board.player); + let ocnt = 63 - pcnt; + let diff_first = row_score.0 + col_score.0 + diag1_score.0 + diag2_score.0; + if diff_first > 0 { + (pcnt - ocnt + 2 * diff_first + 1, 1) + } else { + let diag1_bits_second = pext(board.opponent, diag1_mask); + let &diag1_score_second = self + .table + .get_unchecked((diag1_bits_second as usize) * 8 + diag1_idx as usize); + let diag2_bits_second = pext(board.opponent, diag2_mask); + let &diag2_score_second = self + .table + .get_unchecked((diag2_bits_second as usize) * 8 + diag2_idx as usize); + let diff_second = row_score.1 + col_score.1 + diag1_score_second.0 + diag2_score_second.0; + if diff_second > 0 { + (pcnt - ocnt - 2 * diff_second - 1, 2) + } else if pcnt > ocnt { + (64 - 2 * ocnt, 0) + } else { + (2 * pcnt - 64, 0) + } + } + } + + pub fn solve_last(&self, board: Board) -> (i8, usize) { + unsafe { self.solve_last_impl(board) } + } +} diff --git a/src/engine/last_cache/test.rs b/src/engine/last_cache/test.rs new file mode 100644 index 0000000..083bc51 --- /dev/null +++ b/src/engine/last_cache/test.rs @@ -0,0 +1,32 @@ +use super::*; + +use rand::{Rng, SeedableRng}; + +fn solve_last_naive(board: Board) -> (i8, usize) { + let pos = board.empty().tzcnt() as usize; + match board.play(pos) { + Some(next) => (-next.score(), 1), + None => match board.pass_unchecked().play(pos) { + Some(next) => (next.score(), 2), + None => (board.score(), 0), + }, + } +} + +#[test] +fn test_last_cache() { + // gen data + let mut rng = rand_xoshiro::Xoshiro256StarStar::seed_from_u64(0xDEADBEAF); + const LENGTH: usize = 256; + let last_cache = LastCache::new(); + // last_cache + for _ in 0..LENGTH { + let bit = rng.gen::(); + let pos = rng.gen_range(0..BOARD_SIZE); + let pos_mask = !(1 << pos); + let player = bit & pos_mask; + let opponent = !bit & pos_mask; + let board = Board { player, opponent }; + assert_eq!(last_cache.solve_last(board), solve_last_naive(board)); + } +} diff --git a/src/engine/search.rs b/src/engine/search.rs index d56ef1a..a25fcfe 100644 --- a/src/engine/search.rs +++ b/src/engine/search.rs @@ -4,10 +4,12 @@ use crate::engine::bits::*; use crate::engine::board::*; use crate::engine::eval::*; use crate::engine::hand::*; +use crate::engine::last_cache::*; use crate::engine::midgame::*; use crate::engine::table::*; use crate::engine::think::*; use anyhow::Result; +use arrayvec::ArrayVec; use crc64::Crc64; use reqwest::Client; use serde::{Deserialize, Serialize}; @@ -51,6 +53,7 @@ pub struct SolveObj { pub res_cache: Arc, pub eval_cache: Arc, pub evaluator: Arc, + pub last_cache: Arc, pub params: SearchParams, } @@ -65,6 +68,7 @@ impl SolveObj { res_cache, eval_cache, evaluator, + last_cache: Arc::new(LastCache::new()), params, } } @@ -207,7 +211,8 @@ fn calc_max_depth(rem: i8) -> i8 { } pub fn move_ordering_impl(solve_obj: &mut SolveObj, board: Board, _old_best: Option) -> Vec<(Hand, Board)> { - let mut nexts = Vec::with_capacity(32); + const MAX_NEXT_COUNT: usize = 32; + let mut nexts = ArrayVec::<_, MAX_NEXT_COUNT>::new(); for (next, pos) in board.next_iter() { nexts.push((0, pos, next)); } @@ -215,7 +220,7 @@ pub fn move_ordering_impl(solve_obj: &mut SolveObj, board: Board, _old_best: Opt let rem = popcnt(board.empty()); let max_depth = calc_max_depth(rem); let min_depth = (max_depth - 3).max(0); - let mut tmp = Vec::with_capacity(32); + let mut tmp = ArrayVec::<_, MAX_NEXT_COUNT>::new(); for think_depth in min_depth..=max_depth { tmp.clear(); for &(_score, pos, next) in nexts.iter() { diff --git a/src/main.rs b/src/main.rs index d9b1c1f..bd47684 100644 --- a/src/main.rs +++ b/src/main.rs @@ -64,9 +64,11 @@ struct Stat { fn solve_ffo(name: &str, index: &mut usize, solve_obj: &mut SolveObj, workers: &[String]) -> Vec { let file = File::open(name).unwrap(); let reader = BufReader::new(file); + let mut total_nodes = 0; println!("|No.|empties|result|answer|move|nodes|time|NPS|"); println!("|---:|---:|---:|---:|---:|---:|:--:|---:|"); let mut stats = Vec::new(); + let global_start = Instant::now(); for line in reader.lines() { let line_str = line.unwrap(); let desired: i8 = line_str[71..].split(';').next().unwrap().parse().unwrap(); @@ -105,10 +107,18 @@ fn solve_ffo(name: &str, index: &mut usize, solve_obj: &mut SolveObj, workers: & correct: res == desired, }); *index += 1; + total_nodes += stat.node_count; } Err(_) => println!("Parse error"), } } + let micro_seconds = global_start.elapsed().as_micros(); + let nps = (total_nodes * 1000000) as u128 / micro_seconds; + println!( + "[Total] elapsed: {}us, node count: {}, NPS: {}nodes/sec", + micro_seconds, total_nodes, nps, + ); + stats } diff --git a/src/play.rs b/src/play.rs index 13d89c7..634c9e4 100644 --- a/src/play.rs +++ b/src/play.rs @@ -75,8 +75,8 @@ pub fn play(matches: &ArgMatches) -> Board { match hand { Hand::Pass => board = board.pass_unchecked(), Hand::Play(hand) => match board.play(hand) { - Ok(next) => board = next, - Err(_) => println!("Invalid move"), + Some(next) => board = next, + None => println!("Invalid move"), }, } } @@ -131,8 +131,8 @@ pub fn self_play(matches: &ArgMatches) -> Board { match hand { Hand::Pass => board = board.pass_unchecked(), Hand::Play(hand) => match board.play(hand) { - Ok(next) => board = next, - Err(_) => println!("Invalid move"), + Some(next) => board = next, + None => println!("Invalid move"), }, } } @@ -149,11 +149,11 @@ fn self_play_worker(solve_obj: SolveObj, sub_solver: Arc, initial_rec match hand { Hand::Pass => board = board.pass_unchecked(), Hand::Play(pos) => match board.play(*pos) { - Ok(next) => { + Some(next) => { write!(&mut record_str, "{}", hand).unwrap(); board = next; } - Err(_) => panic!(), + None => panic!(), }, } } @@ -185,11 +185,11 @@ fn self_play_worker(solve_obj: SolveObj, sub_solver: Arc, initial_rec match hand { Hand::Pass => board = board.pass_unchecked(), Hand::Play(pos) => match board.play(pos) { - Ok(next) => { + Some(next) => { write!(&mut record_str, "{}", hand).unwrap(); board = next; } - Err(_) => panic!(), + None => panic!(), }, } } diff --git a/src/record.rs b/src/record.rs index f65af06..90a0328 100644 --- a/src/record.rs +++ b/src/record.rs @@ -28,7 +28,7 @@ impl Record { for i in 0..(l / 2) { let h = Hand::from_str(&record_str[(2 * i)..(2 * i + 2)])?; hands.push(h); - board = board.play_hand(h)?; + board = board.play_hand(h).ok_or(UnmovableError{})?; } let score = if let Some(score) = splitted.get(1) { score.parse().unwrap() @@ -52,7 +52,7 @@ impl Record { }; for &h in &self.hands { res.push((board, h, score)); - board = board.play_hand(h)?; + board = board.play_hand(h).ok_or(UnmovableError{})?; score = -score; } res.push((board, Hand::Pass, score)); diff --git a/src/train.rs b/src/train.rs index 0708735..900636e 100644 --- a/src/train.rs +++ b/src/train.rs @@ -52,15 +52,12 @@ pub fn parse_record(line: &str) -> Vec { pub fn step_by_pos(board: &Board, pos: usize) -> Option { match board.play(pos) { - Ok(next) => Some(next), - Err(_) => { + Some(next) => Some(next), + None => { if !board.mobility().is_empty() { None } else { - match board.pass_unchecked().play(pos) { - Ok(next) => Some(next), - Err(_) => None, - } + board.pass_unchecked().play(pos) } } }