chectus/chectus_net/fen2vec.py
2024-07-28 14:48:02 +08:00

57 lines
1.5 KiB
Python

import torch
from cpe2wdl import cpe2wdl
# Lowercase FEN piece letter -> feature channel in the board tensor:
# pawn=5, knight=4, bishop=3, rook=2, queen=1, king=0.
piece_to_index = dict(zip('pnbrqk', range(5, -1, -1)))
# (row, col) of the square that records each castling right in channel 7.
# Row 0 is rank 8 and row 7 is rank 1, following FEN rank order.
_CASTLING_SQUARES = {
    'K': (7, 7),
    'Q': (7, 0),
    'k': (0, 7),
    'q': (0, 0),
}


def _place_pieces(board_tensor, piece_placement, current_side):
    """Write a FEN piece-placement field into channels 0-5 of ``board_tensor``.

    Pieces of the side to move are stored as +1 and the opponent's as -1
    (``current_side`` is +1 when White is to move, -1 when Black is).

    Raises:
        ValueError: if the field does not describe exactly 8 ranks of 8 squares.
        KeyError: if an unknown piece letter appears.
    """
    rows = piece_placement.split('/')
    if len(rows) != 8:
        raise ValueError(
            f"FEN piece placement must have 8 ranks, got {len(rows)}: {piece_placement!r}"
        )
    for row_idx, row in enumerate(rows):
        col_idx = 0
        for char in row:
            if char.isdigit():
                col_idx += int(char)
                continue
            if col_idx >= 8:
                raise ValueError(f"FEN rank {row!r} describes more than 8 squares")
            piece_side = current_side if char.isupper() else -current_side
            board_tensor[row_idx, col_idx, piece_to_index[char.lower()]] = piece_side
            col_idx += 1
        # Without this check a short or over-long rank would silently encode
        # a wrong position.
        if col_idx != 8:
            raise ValueError(f"FEN rank {row!r} describes {col_idx} squares, expected 8")


def _mark_en_passant(board_tensor, en_passant_target):
    """Set channel 6 to 1 on the en passant target square (e.g. ``'f6'``); ``'-'`` means none."""
    if en_passant_target == '-':
        return
    col = ord(en_passant_target[0]) - ord('a')
    row = 8 - int(en_passant_target[1])  # rank 8 -> row 0, rank 1 -> row 7
    board_tensor[row, col, 6] = 1


def _mark_castling(board_tensor, castling_rights):
    """Set channel 7 to 1 on the corner square of each castling right in ``KQkq``.

    Any other character (such as ``'-'``) is ignored.
    """
    for char in castling_rights:
        square = _CASTLING_SQUARES.get(char)
        if square is not None:
            row, col = square
            board_tensor[row, col, 7] = 1


def parse_fen(fen, cpe):
    """Encode a FEN position and its evaluation as network inputs/targets.

    Args:
        fen: FEN string. Only the first four fields are read: piece placement,
            active colour, castling rights and en passant target. Any further
            fields (move counters) are ignored.
        cpe: evaluation of the position. It is negated when Black is to move,
            so it is presumably from White's point of view (TODO confirm).

    Returns:
        ``(board_tensor, wdl)``.

        ``board_tensor`` is a float16 tensor of shape (8, 8, 8), indexed
        ``[row, col, channel]``:
        - Channels 0-5 hold the pieces, laid out per ``piece_to_index``, with
          +1 for the side to move and -1 for the opponent.
        - Channel 6 flags the en passant target square.
        - Channel 7 flags the castling corners.

        Row 0 is rank 8 when White is to move. When Black is to move the rank
        axis is flipped, so row 0 is rank 1.

        ``wdl`` is the result of ``cpe2wdl`` applied to the side-to-move
        evaluation.

    Raises:
        ValueError: if the FEN has fewer than four fields or a malformed
            piece-placement field.
        KeyError: if the piece placement contains an unknown piece letter.
    """
    parts = fen.split()
    if len(parts) < 4:
        raise ValueError(f"FEN must have at least 4 fields, got {len(parts)}: {fen!r}")
    piece_placement, active_color, castling_rights, en_passant_target = parts[:4]
    current_side = 1 if active_color == 'w' else -1

    board_tensor = torch.zeros((8, 8, 8), dtype=torch.float16)
    _place_pieces(board_tensor, piece_placement, current_side)
    _mark_en_passant(board_tensor, en_passant_target)
    _mark_castling(board_tensor, castling_rights)

    # Express the evaluation from the side to move's perspective.
    side_cpe = cpe if current_side == 1 else -cpe
    # NOTE(review): float16 cannot represent magnitudes above 65504 (they
    # become inf); confirm cpe values such as mate scores stay in range.
    wdl = cpe2wdl(torch.tensor(side_cpe, dtype=torch.float16))

    if current_side == -1:
        # Flip ranks so the side to move's back rank is always row 7.
        board_tensor = torch.flip(board_tensor, [0])
    return board_tensor, wdl
if __name__ == '__main__':
    # Demo positions:
    # - White to move, with an en passant target on f6.
    # - The same game after exf6 en passant, Black to move. This one
    #   exercises the rank flip.
    example_fens = [
        "r2qk2r/3n2p1/1pp1p3/3pPpb1/P2P1nBp/1NB4P/1PP2P2/R3QR1K w kq f6",
        "r2qk2r/3n2p1/1pp1pP2/3p2b1/P2P1nBp/1NB4P/1PP2P2/R3QR1K b kq -",
    ]
    for fen in example_fens:
        # parse_fen returns a (board tensor, wdl) pair, not a single tensor.
        board_tensor, wdl = parse_fen(fen, 200)
        print(fen)
        print(board_tensor)
        print(wdl)