import torch
from cpe2wdl import cpe2wdl

# Feature-plane index (last tensor axis) used for each lower-cased FEN piece letter.
piece_to_index = {
    'p': 5,
    'n': 4,
    'b': 3,
    'r': 2,
    'q': 1,
    'k': 0
}

# Last-axis planes that hold the non-piece features.
_EN_PASSANT_PLANE = 6
_CASTLING_PLANE = 7

# (row, column) flagged on the castling plane for each FEN castling letter.
# Row 0 is the first rank group in the FEN string, row 7 the last.
_CASTLING_SQUARES = {
    'K': (7, 7),
    'Q': (7, 0),
    'k': (0, 7),
    'q': (0, 0),
}


def parse_fen(fen, cpe):
    """Encode a FEN string as an 8x8x8 float16 tensor plus a WDL target.

    The tensor is indexed ``[row, column, plane]``, where row 0 is the first
    rank group of the FEN placement field.

    * Planes 0-5 hold the pieces (see ``piece_to_index``). A piece of the
      side to move is written as ``+1`` and a piece of the other side as
      ``-1``; upper-case letters are treated as belonging to the ``'w'`` side.
    * Plane 6 marks the en passant target square with ``1`` when one is given.
    * Plane 7 marks a corner square with ``1`` for every castling letter
      present (``K``, ``Q``, ``k``, ``q``); other characters are ignored.

    When the active colour is not ``'w'`` the finished tensor is flipped
    along the row axis and ``cpe`` is negated before conversion.

    Args:
        fen: FEN string with at least four whitespace-separated fields
            (placement, active colour, castling rights, en passant target).
        cpe: Evaluation passed on to ``cpe2wdl`` as a float16 tensor.
            NOTE(review): presumably centipawns from White's point of view;
            confirm against ``cpe2wdl``.

    Returns:
        A ``(board_tensor, wdl)`` tuple, where ``wdl`` is whatever
        ``cpe2wdl`` returns for the side-to-move-relative evaluation.

    Raises:
        IndexError: If the FEN has fewer than four fields or the placement
            describes squares outside the 8x8 board.
        KeyError: If the placement contains an unknown piece letter.
    """
    board = torch.zeros((8, 8, 8), dtype=torch.float16)

    fields = fen.split()
    placement = fields[0]
    white_to_move = fields[1] == 'w'
    castling = fields[2]
    en_passant = fields[3]

    mover_sign = 1 if white_to_move else -1

    # Piece planes: walk each rank group, skipping runs of empty squares.
    for row, rank_text in enumerate(placement.split('/')):
        col = 0
        for symbol in rank_text:
            if symbol.isdigit():
                col += int(symbol)
                continue
            if symbol.isupper():
                value = mover_sign
            else:
                value = -mover_sign
            board[row, col, piece_to_index[symbol.lower()]] = value
            col += 1

    # En passant plane: convert algebraic square (e.g. "f6") to row/column.
    if en_passant != '-':
        ep_col = ord(en_passant[0]) - ord('a')
        ep_row = 8 - int(en_passant[1])
        board[ep_row, ep_col, _EN_PASSANT_PLANE] = 1

    # Castling plane: one corner square per available castling right.
    for flag in castling:
        square = _CASTLING_SQUARES.get(flag)
        if square is not None:
            board[square[0], square[1], _CASTLING_PLANE] = 1

    # Express the evaluation from the point of view of the side to move.
    relative_cpe = cpe if white_to_move else -cpe
    wdl = cpe2wdl(torch.tensor(relative_cpe, dtype=torch.float16))

    # Mirror the rows so the encoding is oriented for the side to move.
    if not white_to_move:
        board = torch.flip(board, [0])

    return board, wdl


if __name__ == '__main__':
    fen = "r2qk2r/3n2p1/1pp1p3/3pPpb1/P2P1nBp/1NB4P/1PP2P2/R3QR1K w kq f6"
    # Alternative sample position with the other side to move:
    # fen = "r2qk2r/3n2p1/1pp1pP2/3p2b1/P2P1nBp/1NB4P/1PP2P2/R3QR1K b kq -"
    encoded = parse_fen(fen, 200)
    print(encoded)