Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use const array instead of lazy_static #26

Merged
merged 1 commit into from
Feb 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ version = "0.2.0"
authors = ["prime <prime@kmc.gr.jp>"]

[dependencies]
lazy_static = "1.4"
rand = { version = "0.8", features = ["small_rng"] }
rand_xoshiro = "0.6"
futures = { version = "0.3", features = ["std"] }
Expand Down
33 changes: 17 additions & 16 deletions src/engine/bits.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use core::arch::x86_64::{_pdep_u64, _pext_u64};
use lazy_static::lazy_static;

pub fn popcnt(x: u64) -> i8 {
x.count_ones() as i8
Expand Down Expand Up @@ -53,23 +52,25 @@ pub fn pdep(x: u64, mask: u64) -> u64 {
unsafe { _pdep_u64(x, mask) }
}

lazy_static! {
pub static ref BASE3: [usize; 256] = {
let mut res = [0usize; 256];
for x in 0..256 {
let mut pow3 = 1;
let mut sum = 0;
for i in 0..8 {
if ((x >> i) & 1) == 1 {
sum += pow3;
}
pow3 *= 3;
pub const BASE3: [usize; 256] = {
let mut res = [0usize; 256];
let mut x = 0;
while x < 256 {
let mut pow3 = 1;
let mut sum = 0;
let mut i = 0;
while i < 8 {
if ((x >> i) & 1) == 1 {
sum += pow3;
}
res[x] = sum;
pow3 *= 3;
i += 1;
}
res
};
}
res[x] = sum;
x += 1;
}
res
};

#[cfg(test)]
mod tests {
Expand Down
135 changes: 100 additions & 35 deletions src/engine/board.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use crate::engine::hand::*;
use anyhow::Result;
use clap::ArgMatches;
use core::arch::x86_64::*;
use lazy_static::lazy_static;
use std::cmp::min;
use std::fmt;
use std::io::{BufWriter, Write};
Expand Down Expand Up @@ -33,7 +32,7 @@ pub struct PlayIterator {

pub const BOARD_SIZE: usize = 64;

#[cfg(all(target_feature = "avx512cd", target_feature="avx512vl"))]
#[cfg(all(target_feature = "avx512cd", target_feature = "avx512vl"))]
unsafe fn smart_upper_bit(x: __m256i) -> __m256i {
let y = _mm256_lzcnt_epi64(x);
_mm256_srlv_epi64(_mm256_set1_epi64x(0x8000_0000_0000_0000u64 as i64), y)
Expand All @@ -48,6 +47,14 @@ unsafe fn smart_upper_bit(mut x: __m256i) -> __m256i {
_mm256_andnot_si256(lowers, x)
}

const fn smart_upper_bit_scalar(mut x: u64, lane: usize) -> u64 {
x |= x >> [8, 1, 7, 9][lane];
x |= x >> [16, 2, 14, 18][lane];
x |= x >> [32, 4, 28, 36][lane];
let lowers = x >> [8, 1, 7, 9][lane];
!lowers & x
}

#[allow(dead_code)]
unsafe fn upper_bit(mut x: __m256i) -> __m256i {
x = _mm256_or_si256(x, _mm256_srli_epi64(x, 1));
Expand Down Expand Up @@ -125,10 +132,51 @@ impl Board {
reduce_or(flipped)
}

#[cfg(all(target_feature = "avx2"))]
pub fn flip_unchecked(&self, pos: usize) -> u64 {
unsafe { self.flip_simd(pos) }
}

pub const fn flip_naive(&self, pos: usize) -> u64 {
let o_mask = 0x7E7E_7E7E_7E7E_7E7Eu64;
let om = [
self.opponent,
self.opponent & o_mask,
self.opponent & o_mask,
self.opponent & o_mask,
];
let mask1 = [
0x0080808080808080u64,
0x7f00000000000000u64,
0x0102040810204000u64,
0x0040201008040201u64,
];
let mask2 = [
0x0101010101010100u64,
0x00000000000000feu64,
0x0002040810204080u64,
0x8040201008040200u64,
];
let mut flipped = 0;
let mut i = 0;
while i < 4 {
let mask = mask1[i] >> (63 - pos);
let outflank = smart_upper_bit_scalar(!om[i] & mask, i) & self.player;
flipped |= (outflank.wrapping_neg() << 1) & mask;
let mask = mask2[i] << pos;
let outflank = !((!om[i] & mask).wrapping_sub(1)) & mask & self.player;
flipped |= !((if outflank == 0 {
0xFFFF_FFFF_FFFF_FFFFu64
} else {
0
})
.wrapping_sub(outflank))
& mask;
i += 1;
}
flipped
}

pub fn flip(&self, pos: usize) -> u64 {
if ((self.empty() >> pos) & 1) == 0 {
0
Expand All @@ -137,6 +185,14 @@ impl Board {
}
}

pub const fn flip_const(&self, pos: usize) -> u64 {
if ((self.empty() >> pos) & 1) == 0 {
0
} else {
self.flip_naive(pos)
}
}

pub fn is_movable(&self, pos: usize) -> bool {
if pos >= BOARD_SIZE {
return false;
Expand Down Expand Up @@ -171,7 +227,7 @@ impl Board {
}
}

pub fn pass_unchecked(&self) -> Board {
pub const fn pass_unchecked(&self) -> Board {
Board {
player: self.opponent,
opponent: self.player,
Expand All @@ -189,7 +245,7 @@ impl Board {
}
}

pub fn empty(&self) -> u64 {
pub const fn empty(&self) -> u64 {
!(self.player | self.opponent)
}

Expand Down Expand Up @@ -605,17 +661,19 @@ pub fn weighted_mobility(board: &Board) -> i8 {
popcnt(b) + popcnt(b & corner)
}

fn stable_bits_8(board: Board, passed: bool, memo: &mut [Option<u64>]) -> u64 {
const fn stable_bits_8(board: Board, passed: bool, memo: &[Option<u64>]) -> u64 {
let index = BASE3[board.player as usize] + 2 * BASE3[board.opponent as usize];
if let Some(res) = memo[index] {
return res;
}
let mut res = 0xFF;
for pos in 0..8 {
let mut pos = 0;
while pos < 8 {
if ((board.empty() >> pos) & 1) != 1 {
pos += 1;
continue;
}
let flip = board.flip(pos);
let flip = board.flip_const(pos);
let pos_bit = 1 << pos;
let next = Board {
player: board.opponent ^ flip,
Expand All @@ -624,11 +682,12 @@ fn stable_bits_8(board: Board, passed: bool, memo: &mut [Option<u64>]) -> u64 {
res &= !flip;
res &= !pos_bit;
res &= stable_bits_8(next, false, memo);
pos += 1;
}
if !passed {
let next = board.pass_unchecked();
res &= stable_bits_8(next, true, memo);
memo[index] = Some(res);
//memo[index] = Some(res);
}
res
}
Expand All @@ -643,32 +702,38 @@ pub fn parse_board(matches: &ArgMatches) {
println!("{}", board);
}

lazy_static! {
static ref STABLE: [u64; 6561] = {
let mut memo = [None; 6561];
for i in 0..6561 {
let mut me = 0;
let mut op = 0;
let mut tmp = i;
for j in 0..8 {
let state = tmp % 3;
match state {
1 => me |= 1 << j,
2 => op |= 1 << j,
_ => (),
}
tmp /= 3;
const STABLE: [u64; 6561] = {
let mut memo = [None; 6561];
let mut ri = 0;
while ri < 6561 {
let i = 6561 - ri - 1;
let mut me = 0;
let mut op = 0;
let mut tmp = i;
let mut j = 0;
while j < 8 {
let state = tmp % 3;
match state {
1 => me |= 1 << j,
2 => op |= 1 << j,
_ => (),
}
let board = Board {
player: me,
opponent: op,
};
stable_bits_8(board, false, &mut memo);
}
let mut res = [0; 6561];
for i in 0..6561 {
res[i] = memo[i].unwrap() & 0xFF;
tmp /= 3;
j += 1;
}
res
};
}
let board = Board {
player: me,
opponent: op,
};
let res = stable_bits_8(board, false, &memo);
memo[i] = Some(res);
ri += 1;
}
let mut res = [0; 6561];
let mut i = 0;
while i < 6561 {
res[i] = memo[i].unwrap() & 0xFF;
i += 1;
}
res
};
6 changes: 2 additions & 4 deletions src/engine/board/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,7 @@ impl From<NaiveBoard> for Board {
_ => (),
}
}
Board {
player,
opponent,
}
Board { player, opponent }
}
}

Expand Down Expand Up @@ -202,6 +199,7 @@ fn test_ops() {
assert_eq!(board, Board::from(naive_board.clone()));
for i in 0..BOARD_SIZE {
assert_eq!(board.flip(i), naive_board.flip(i));
assert_eq!(board.flip_const(i), naive_board.flip(i));
assert_eq!(board.is_movable(i), naive_board.is_movable(i));
if board.is_movable(i) {
assert_eq!(
Expand Down
1 change: 1 addition & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#![feature(const_option)]
#![feature(test)]
mod book;
mod engine;
Expand Down
Loading