import torch import numpy as np from card_model.model import CardModel from card_model.config import VOCAB_SIZE, EMBEDDING_DIM, MLP_HIDDEN, NUM_BINS import os def parse_cards(card_str): """将 'As Ks' 等字符串转换为 0-51 的 ID 列表""" if not card_str.strip(): return [] cards = [] for c in card_str.strip().split(): rank = "23456789TJQKA".index(c[0].upper()) suit = "cdhs".index(c[1].lower()) cards.append(rank * 4 + suit) return cards def print_histogram(hist): """用 ASCII 打印简单的直方图""" print("\n胜率直方图分布 (0% -> 100%):") bars = [" "] * 50 for i, p in enumerate(hist): length = int(p * 50) # 缩放以便显示 print(f"[{i*2:02d}%-{(i+1)*2:02d}%]: {'█' * length} ({p*100:.1f}%)") def main(): # 1. 加载你刚刚训练好的最优模型 model_path = os.path.join(os.path.dirname(__file__), "data", "best_card_model.pt") model = CardModel(VOCAB_SIZE, EMBEDDING_DIM, MLP_HIDDEN, NUM_BINS) model.load_state_dict(torch.load(model_path, map_location="cpu")) model.eval() print("模型加载成功!你可以输入牌型进行测试 (例如: As Ks)。输入 'q' 退出。") while True: hole_str = input("\n请输入你的底牌 (2张, 例如 'As Ks'): ") if hole_str.lower() == 'q': break board_str = input("请输入公共牌 (0到5张, 例如 'Qs Js 2d' 或直接回车留空): ") try: hole_cards = parse_cards(hole_str) board_cards = parse_cards(board_str) if len(hole_cards) != 2: print("底牌必须是 2 张!") continue # 补齐公共牌到 5 张 padded_board = board_cards + [52] * (5 - len(board_cards)) x_hole = torch.tensor([hole_cards], dtype=torch.int64) x_board = torch.tensor([padded_board], dtype=torch.int64) with torch.no_grad(): pred_equity, pred_hist = model(x_hole, x_board) eq = pred_equity.item() hist = pred_hist[0].numpy() print("-" * 40) print(f"你的底牌: {hole_str} | 公共牌: {board_str if board_str else '(翻牌前)'}") print(f"=> 预测绝对胜率 (Equity): {eq * 100:.2f}%") print_histogram(hist) print("-" * 40) except Exception as e: print(f"输入格式错误: {e}") if __name__ == "__main__": main()