Files
new/card_model/test_inference.py
e2hang ed2fadb625 What
2026-04-20 20:25:35 +08:00

70 lines
2.5 KiB
Python

import torch
import numpy as np
from card_model.model import CardModel
from card_model.config import VOCAB_SIZE, EMBEDDING_DIM, MLP_HIDDEN, NUM_BINS
import os
def parse_cards(card_str):
"""'As Ks' 等字符串转换为 0-51 的 ID 列表"""
if not card_str.strip():
return []
cards = []
for c in card_str.strip().split():
rank = "23456789TJQKA".index(c[0].upper())
suit = "cdhs".index(c[1].lower())
cards.append(rank * 4 + suit)
return cards
def print_histogram(hist):
"""用 ASCII 打印简单的直方图"""
print("\n胜率直方图分布 (0% -> 100%):")
bars = [" "] * 50
for i, p in enumerate(hist):
length = int(p * 50) # 缩放以便显示
print(f"[{i*2:02d}%-{(i+1)*2:02d}%]: {'' * length} ({p*100:.1f}%)")
def main():
# 1. 加载你刚刚训练好的最优模型
model_path = os.path.join(os.path.dirname(__file__), "data", "best_card_model.pt")
model = CardModel(VOCAB_SIZE, EMBEDDING_DIM, MLP_HIDDEN, NUM_BINS)
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval()
print("模型加载成功!你可以输入牌型进行测试 (例如: As Ks)。输入 'q' 退出。")
while True:
hole_str = input("\n请输入你的底牌 (2张, 例如 'As Ks'): ")
if hole_str.lower() == 'q': break
board_str = input("请输入公共牌 (0到5张, 例如 'Qs Js 2d' 或直接回车留空): ")
try:
hole_cards = parse_cards(hole_str)
board_cards = parse_cards(board_str)
if len(hole_cards) != 2:
print("底牌必须是 2 张!")
continue
# 补齐公共牌到 5 张
padded_board = board_cards + [52] * (5 - len(board_cards))
x_hole = torch.tensor([hole_cards], dtype=torch.int64)
x_board = torch.tensor([padded_board], dtype=torch.int64)
with torch.no_grad():
pred_equity, pred_hist = model(x_hole, x_board)
eq = pred_equity.item()
hist = pred_hist[0].numpy()
print("-" * 40)
print(f"你的底牌: {hole_str} | 公共牌: {board_str if board_str else '(翻牌前)'}")
print(f"=> 预测绝对胜率 (Equity): {eq * 100:.2f}%")
print_histogram(hist)
print("-" * 40)
except Exception as e:
print(f"输入格式错误: {e}")
if __name__ == "__main__":
main()