Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Last cache v2 #17

Merged
merged 5 commits into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ reqwest = { version = "0.11.12", features = ["json"], default-features = false }
crc64 = "2.0.0"
anyhow = "1.0"
thiserror = "1.0"
arrayvec = "0.7.4"
1 change: 1 addition & 0 deletions src/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ pub mod board;
pub mod endgame;
pub mod eval;
pub mod hand;
pub mod last_cache;
pub mod midgame;
pub mod search;
pub mod table;
Expand Down
28 changes: 14 additions & 14 deletions src/engine/board.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,24 +140,24 @@ impl Board {
self.flip(pos) != 0
}

pub fn play(&self, pos: usize) -> Result<Board> {
pub fn play(&self, pos: usize) -> Option<Board> {
if pos >= BOARD_SIZE {
return Err(UnmovableError {}.into());
return None;
}
if ((self.player >> pos) & 1) != 0 || ((self.opponent >> pos) & 1) != 0 {
return Err(UnmovableError {}.into());
return None;
}
let flip_bits = self.flip(pos);
if flip_bits == 0 {
return Err(UnmovableError {}.into());
return None;
}
Ok(Board {
Some(Board {
player: self.opponent ^ flip_bits,
opponent: (self.player ^ flip_bits) | (1u64 << pos),
})
}

pub fn play_hand(&self, hand: Hand) -> Result<Board> {
pub fn play_hand(&self, hand: Hand) -> Option<Board> {
match hand {
Hand::Play(pos) => self.play(pos),
Hand::Pass => self.pass(),
Expand All @@ -171,14 +171,14 @@ impl Board {
}
}

pub fn pass(&self) -> Result<Board> {
pub fn pass(&self) -> Option<Board> {
if self.mobility_bits() == 0 {
Ok(Board {
Some(Board {
player: self.opponent,
opponent: self.player,
})
} else {
Err(UnmovableError {}.into())
None
}
}

Expand Down Expand Up @@ -446,9 +446,9 @@ impl BoardWithColor {
}
}

pub fn play(&self, pos: usize) -> Result<BoardWithColor> {
pub fn play(&self, pos: usize) -> Option<BoardWithColor> {
let board = self.board.play(pos)?;
Ok(BoardWithColor {
Some(BoardWithColor {
board,
is_black: !self.is_black,
})
Expand All @@ -461,8 +461,8 @@ impl BoardWithColor {
}
}

pub fn pass(&self) -> Result<BoardWithColor> {
Ok(BoardWithColor {
pub fn pass(&self) -> Option<BoardWithColor> {
Some(BoardWithColor {
board: self.board.pass()?,
is_black: !self.is_black,
})
Expand Down Expand Up @@ -558,7 +558,7 @@ impl Iterator for PlayIterator {
while self.remain != 0 {
let pos = self.remain.trailing_zeros() as usize;
self.remain &= self.remain - 1;
if let Ok(next) = self.board.play(pos) {
if let Some(next) = self.board.play(pos) {
return Some((next, Hand::Play(pos)));
}
}
Expand Down
46 changes: 17 additions & 29 deletions src/engine/endgame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,19 @@ use crate::engine::board::*;
use crate::engine::hand::*;
use crate::engine::search::*;
use crate::engine::table::*;
use arrayvec::ArrayVec;
use bitintr::Tzcnt;
use std::cmp::max;
use std::mem::MaybeUninit;

fn near_leaf(board: Board) -> (i8, SolveStat) {
let bit = board.empty();
let pos = bit.tzcnt() as usize;
match board.play(pos) {
Ok(next) => (-next.score(), SolveStat::one()),
Err(_) => (
match board.pass_unchecked().play(pos) {
Ok(next) => next.score(),
Err(_) => board.score(),
},
SolveStat {
node_count: 2,
st_cut_count: 0,
},
),
}
fn near_leaf(solve_obj: &mut SolveObj, board: Board) -> (i8, SolveStat) {
let (score, node_count) = solve_obj.last_cache.solve_last(board);
(
score,
SolveStat {
node_count,
st_cut_count: 0,
},
)
}

fn naive(solve_obj: &mut SolveObj, board: Board, (mut alpha, beta): (i8, i8), passed: bool) -> (i8, SolveStat) {
Expand Down Expand Up @@ -68,7 +61,7 @@ fn static_order(solve_obj: &mut SolveObj, board: Board, (mut alpha, beta): (i8,
while remain != 0 {
let pos = remain.tzcnt() as usize;
remain = remain & (remain - 1);
if let Ok(next) = board.play(pos) {
if let Some(next) = board.play(pos) {
pass = false;
let (child_res, child_stat) = solve_inner(solve_obj, next, (-beta, -alpha), false);
res = max(res, -child_res);
Expand Down Expand Up @@ -112,19 +105,14 @@ fn negascout_impl(solve_obj: &mut SolveObj, next: Board, (alpha, beta): (i8, i8)

fn fastest_first(solve_obj: &mut SolveObj, board: Board, (mut alpha, beta): (i8, i8), passed: bool) -> (i8, SolveStat) {
const MAX_FFS_NEXT: usize = 20;
let nexts = MaybeUninit::<[(i8, Board); MAX_FFS_NEXT]>::uninit();
let mut nexts = unsafe { nexts.assume_init() };
let mut count = 0;
let mut nexts = ArrayVec::<_, MAX_FFS_NEXT>::new();
for (next, _pos) in board.next_iter() {
nexts[count] = (weighted_mobility(&next), next);
count += 1;
nexts.push((weighted_mobility(&next), next));
}
assert!(count <= MAX_FFS_NEXT);

nexts[0..count].sort_by(|a, b| a.0.cmp(&b.0));
nexts.sort_by(|a, b| a.0.cmp(&b.0));
let mut res = -(BOARD_SIZE as i8);
let mut stat = SolveStat::one();
for (i, &(_, next)) in nexts[0..count].iter().enumerate() {
for (i, &(_, next)) in nexts.iter().enumerate() {
let (child_res, child_stat) = negascout_impl(solve_obj, next, (alpha, beta), i == 0);
res = max(res, -child_res);
stat.merge(child_stat);
Expand All @@ -133,7 +121,7 @@ fn fastest_first(solve_obj: &mut SolveObj, board: Board, (mut alpha, beta): (i8,
return (res, stat);
}
}
if count == 0 {
if nexts.is_empty() {
if passed {
return (board.score(), stat);
} else {
Expand Down Expand Up @@ -190,7 +178,7 @@ pub fn solve_inner(
if rem == 0 {
(board.score(), SolveStat::zero())
} else if rem == 1 {
near_leaf(board)
near_leaf(solve_obj, board)
} else if rem < solve_obj.params.static_ordering_limit {
naive(solve_obj, board, (alpha, beta), passed)
} else if rem < solve_obj.params.ffs_ordering_limit {
Expand Down
113 changes: 113 additions & 0 deletions src/engine/last_cache.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
#[cfg(test)]
mod test;
use crate::engine::bits::*;
use crate::engine::board::*;
use bitintr::Tzcnt;

pub struct LastCache {
table: [(i8, i8); 4096],
masks: [(u64, u64, u64); BOARD_SIZE],
indices: [(u8, u8); BOARD_SIZE],
}

impl LastCache {
fn pre_compute(bits: u64, pos: usize) -> (i8, i8) {
let opp_mask = 0xff ^ (1 << pos);
let opp = !bits & opp_mask;
let board_first = Board {
player: bits,
opponent: opp,
};
let board_second = Board {
player: opp,
opponent: bits,
};
(
popcnt(board_first.flip_unchecked(pos)),
popcnt(board_second.flip_unchecked(pos)),
)
}

pub fn new() -> LastCache {
let mut table = [(0, 0); 4096];
for bits in 0..256 {
for pos in 0..8 {
let idx = bits as usize * 8 + pos;
table[idx] = Self::pre_compute(bits, pos);
}
}
let mut masks = [(0, 0, 0); BOARD_SIZE];
let mut indices = [(0, 0); BOARD_SIZE];
for pos in 0..BOARD_SIZE {
let row = pos / 8;
let col = pos % 8;
let col_mask = 0x0101010101010101 << col;
let diag1_mask = if row > col {
0x8040201008040201 << ((row - col) * 8)
} else {
0x8040201008040201 >> ((col - row) * 8)
};
let diag1_idx = if row > col { col } else { row };
let diag2_mask = if row + col > 7 {
0x0102040810204080 << ((row + col - 7) * 8)
} else {
0x0102040810204080 >> ((7 - row - col) * 8)
};
let diag2_idx = if row + col > 7 { 7 - col } else { row };
masks[pos] = (col_mask, diag1_mask, diag2_mask);
indices[pos] = (diag1_idx as u8, diag2_idx as u8);
}
LastCache {
table,
masks,
indices,
}
}

unsafe fn solve_last_impl(&self, board: Board) -> (i8, usize) {
let pos = board.empty().tzcnt() as usize;
let row = pos >> 3;
let col = pos & 0b111;
let row_bits = (board.player >> (row * 8)) & 0xff;
let &(col_mask, diag1_mask, diag2_mask) = self.masks.get_unchecked(pos);
let &(diag1_idx, diag2_idx) = self.indices.get_unchecked(pos);
let &row_score = self.table.get_unchecked((row_bits as usize) * 8 + col);
let col_bits = pext(board.player, col_mask);
let &col_score = self.table.get_unchecked((col_bits as usize) * 8 + row);
let diag1_bits = pext(board.player, diag1_mask);
let &diag1_score = self
.table
.get_unchecked((diag1_bits as usize) * 8 + diag1_idx as usize);
let diag2_bits = pext(board.player, diag2_mask);
let &diag2_score = self
.table
.get_unchecked((diag2_bits as usize) * 8 + diag2_idx as usize);
let pcnt = popcnt(board.player);
let ocnt = 63 - pcnt;
let diff_first = row_score.0 + col_score.0 + diag1_score.0 + diag2_score.0;
if diff_first > 0 {
(pcnt - ocnt + 2 * diff_first + 1, 1)
} else {
let diag1_bits_second = pext(board.opponent, diag1_mask);
let &diag1_score_second = self
.table
.get_unchecked((diag1_bits_second as usize) * 8 + diag1_idx as usize);
let diag2_bits_second = pext(board.opponent, diag2_mask);
let &diag2_score_second = self
.table
.get_unchecked((diag2_bits_second as usize) * 8 + diag2_idx as usize);
let diff_second = row_score.1 + col_score.1 + diag1_score_second.0 + diag2_score_second.0;
if diff_second > 0 {
(pcnt - ocnt - 2 * diff_second - 1, 2)
} else if pcnt > ocnt {
(64 - 2 * ocnt, 0)
} else {
(2 * pcnt - 64, 0)
}
}
}

pub fn solve_last(&self, board: Board) -> (i8, usize) {
unsafe { self.solve_last_impl(board) }
}
}
32 changes: 32 additions & 0 deletions src/engine/last_cache/test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
use super::*;

use rand::{Rng, SeedableRng};

fn solve_last_naive(board: Board) -> (i8, usize) {
let pos = board.empty().tzcnt() as usize;
match board.play(pos) {
Some(next) => (-next.score(), 1),
None => match board.pass_unchecked().play(pos) {
Some(next) => (next.score(), 2),
None => (board.score(), 0),
},
}
}

#[test]
fn test_last_cache() {
// gen data
let mut rng = rand_xoshiro::Xoshiro256StarStar::seed_from_u64(0xDEADBEAF);
const LENGTH: usize = 256;
let last_cache = LastCache::new();
// last_cache
for _ in 0..LENGTH {
let bit = rng.gen::<u64>();
let pos = rng.gen_range(0..BOARD_SIZE);
let pos_mask = !(1 << pos);
let player = bit & pos_mask;
let opponent = !bit & pos_mask;
let board = Board { player, opponent };
assert_eq!(last_cache.solve_last(board), solve_last_naive(board));
}
}
9 changes: 7 additions & 2 deletions src/engine/search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@ use crate::engine::bits::*;
use crate::engine::board::*;
use crate::engine::eval::*;
use crate::engine::hand::*;
use crate::engine::last_cache::*;
use crate::engine::midgame::*;
use crate::engine::table::*;
use crate::engine::think::*;
use anyhow::Result;
use arrayvec::ArrayVec;
use crc64::Crc64;
use reqwest::Client;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -51,6 +53,7 @@ pub struct SolveObj {
pub res_cache: Arc<ResCacheTable>,
pub eval_cache: Arc<EvalCacheTable>,
pub evaluator: Arc<Evaluator>,
pub last_cache: Arc<LastCache>,
pub params: SearchParams,
}

Expand All @@ -65,6 +68,7 @@ impl SolveObj {
res_cache,
eval_cache,
evaluator,
last_cache: Arc::new(LastCache::new()),
params,
}
}
Expand Down Expand Up @@ -207,15 +211,16 @@ fn calc_max_depth(rem: i8) -> i8 {
}

pub fn move_ordering_impl(solve_obj: &mut SolveObj, board: Board, _old_best: Option<Hand>) -> Vec<(Hand, Board)> {
let mut nexts = Vec::with_capacity(32);
const MAX_NEXT_COUNT: usize = 32;
let mut nexts = ArrayVec::<_, MAX_NEXT_COUNT>::new();
for (next, pos) in board.next_iter() {
nexts.push((0, pos, next));
}

let rem = popcnt(board.empty());
let max_depth = calc_max_depth(rem);
let min_depth = (max_depth - 3).max(0);
let mut tmp = Vec::with_capacity(32);
let mut tmp = ArrayVec::<_, MAX_NEXT_COUNT>::new();
for think_depth in min_depth..=max_depth {
tmp.clear();
for &(_score, pos, next) in nexts.iter() {
Expand Down
Loading
Loading