Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix parse record #53

Merged
merged 2 commits into from
Jun 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 16 additions & 40 deletions src/record.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::engine::hand::*;
use anyhow::Result;
use std::fmt::*;
use std::fs::File;
use std::io::{BufRead, BufReader, Read};
use std::io::{BufRead, BufReader};
use std::path::Path;
use std::str::FromStr;
use thiserror::Error;
Expand All @@ -23,8 +23,8 @@ pub struct ScoreIsNotRegistered {}

#[derive(Error, Debug)]
pub enum ParseRecordError {
#[error("Failed to parse hand")]
FailedToParseHand,
#[error("Failed to parse hand :{0}")]
FailedToParseHand(String),
#[error("invalid hand")]
InvalidHand,
}
Expand All @@ -46,11 +46,7 @@ impl Record {
let mut board = self.initial_board;
let mut res = Vec::new();
let final_score = self.final_score.ok_or(ScoreIsNotRegistered {})?;
let mut score = if self.hands.len() % 2 == 0 {
final_score
} else {
-final_score
};
Comment on lines -49 to -53
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

final_score は手番から見たスコアではなく黒番から見たスコアだった

let mut score = final_score;
for &h in &self.hands {
res.push((board, h, score));
board = board.play_hand(h).ok_or(UnmovableError {})?;
Expand Down Expand Up @@ -82,9 +78,10 @@ impl FromStr for Record {
let splitted = record_str.split_ascii_whitespace().collect::<Vec<_>>();
let l = splitted[0].len();
for i in 0..(l / 2) {
let h = splitted[0][(2 * i)..(2 * i + 2)]
let hand_s = &splitted[0][(2 * i)..(2 * i + 2)];
let h = hand_s
.parse::<Hand>()
.or(Err(ParseRecordError::FailedToParseHand))?;
.or(Err(ParseRecordError::FailedToParseHand(hand_s.to_string())))?;
board = match board.play_hand(h) {
Some(next) => next,
None => {
Expand All @@ -103,43 +100,22 @@ impl FromStr for Record {
let score = if let Some(score) = splitted.get(1) {
score.parse().ok()
} else if board.is_gameover() {
Some(board.score() as i16)
let absolute_score = if l % 2 == 0 {
board.score()
} else {
-board.score()
};
Some(absolute_score as i16)
} else {
None
};
Ok(Record::new(Board::initial_state(), &hands, score))
}
}

pub struct LoadRecords<R: Read> {
reader: BufReader<R>,
buffer: String,
remain: usize,
}

impl<R: Read> Iterator for LoadRecords<R> {
type Item = Result<Record>;
fn next(&mut self) -> Option<Self::Item> {
if self.remain > 0 {
self.remain -= 1;
self.reader.read_line(&mut self.buffer).ok()?;
return Some(self.buffer.parse::<Record>().map_err(|e| e.into()));
}
None
}
}

pub fn load_records(path: &Path) -> Result<LoadRecords<File>> {
pub fn load_records(path: &Path) -> Result<impl Iterator<Item = Result<Record, ParseRecordError>>> {
let file = File::open(path)?;
let mut reader = BufReader::new(file);
let mut buffer = String::new();

reader.read_line(&mut buffer)?;
let remain = buffer.trim().parse()?;
let reader = BufReader::new(file);

Ok(LoadRecords {
reader,
buffer,
remain,
})
Ok(reader.lines().map(|line| line.unwrap().parse::<Record>()))
}
38 changes: 26 additions & 12 deletions src/train.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,17 @@ use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::Path;
use std::str;
use std::sync::Arc;
use rand::prelude::*;

pub fn clean_record(matches: &ArgMatches) {
let input_path = matches.get_one::<String>("INPUT").unwrap();
let output_path = matches.get_one::<String>("OUTPUT").unwrap();

let mut result = Vec::new();
for record in load_records(Path::new(input_path)).unwrap() {
if let Ok(record) = record {
if let Ok(_timeline) = record.timeline() {
result.push(record);
}
let Ok(record) = record else { continue; };
if let Ok(_timeline) = record.timeline() {
result.push(record);
}
}

Expand All @@ -47,16 +47,20 @@ pub fn gen_dataset(matches: &ArgMatches) {
.unwrap()
.parse::<usize>()
.unwrap();
let mut rng = rand::thread_rng();

eprintln!("Parse input...");
let mut boards_with_results = Vec::new();
for record in load_records(Path::new(input_path)).unwrap() {
let mut timeline = record.unwrap().timeline().unwrap();
let record = record.unwrap();
let mut timeline = record.timeline().unwrap();
boards_with_results.append(&mut timeline);
}

eprintln!("Total board count = {}", boards_with_results.len());

boards_with_results.shuffle(&mut rng);

eprintln!("Writing to file...");
let out_f = File::create(output_path).unwrap();
let mut writer = BufWriter::new(out_f);
Expand All @@ -71,13 +75,23 @@ pub fn gen_dataset(matches: &ArgMatches) {
if idx >= max_output {
break;
}
if let Hand::Play(pos) = hand {
writeln!(
&mut writer,
"{:016x} {:016x} {} {}",
board.player, board.opponent, score, pos,
)
.unwrap();
match hand {
Hand::Play(pos) => {
writeln!(
&mut writer,
"{:016x} {:016x} {} {}",
board.player, board.opponent, score, pos,
)
.unwrap();
}
Hand::Pass => {
writeln!(
&mut writer,
"{:016x} {:016x} {} ps",
board.player, board.opponent, score,
)
.unwrap();
}
Comment on lines +87 to +94
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

パス時がデータセットに入らないバグを修正

}
}
eprintln!("Finished!");
Expand Down
Loading