70 lines
2.5 KiB
Python
70 lines
2.5 KiB
Python
import torch
|
|
import numpy as np
|
|
from card_model.model import CardModel
|
|
from card_model.config import VOCAB_SIZE, EMBEDDING_DIM, MLP_HIDDEN, NUM_BINS
|
|
import os
|
|
|
|
|
|
|
|
def parse_cards(card_str):
|
|
"""将 'As Ks' 等字符串转换为 0-51 的 ID 列表"""
|
|
if not card_str.strip():
|
|
return []
|
|
cards = []
|
|
for c in card_str.strip().split():
|
|
rank = "23456789TJQKA".index(c[0].upper())
|
|
suit = "cdhs".index(c[1].lower())
|
|
cards.append(rank * 4 + suit)
|
|
return cards
|
|
|
|
def print_histogram(hist):
|
|
"""用 ASCII 打印简单的直方图"""
|
|
print("\n胜率直方图分布 (0% -> 100%):")
|
|
bars = [" "] * 50
|
|
for i, p in enumerate(hist):
|
|
length = int(p * 50) # 缩放以便显示
|
|
print(f"[{i*2:02d}%-{(i+1)*2:02d}%]: {'█' * length} ({p*100:.1f}%)")
|
|
|
|
def main():
|
|
# 1. 加载你刚刚训练好的最优模型
|
|
model_path = os.path.join(os.path.dirname(__file__), "data", "best_card_model.pt")
|
|
model = CardModel(VOCAB_SIZE, EMBEDDING_DIM, MLP_HIDDEN, NUM_BINS)
|
|
model.load_state_dict(torch.load(model_path, map_location="cpu"))
|
|
model.eval()
|
|
print("模型加载成功!你可以输入牌型进行测试 (例如: As Ks)。输入 'q' 退出。")
|
|
|
|
while True:
|
|
hole_str = input("\n请输入你的底牌 (2张, 例如 'As Ks'): ")
|
|
if hole_str.lower() == 'q': break
|
|
board_str = input("请输入公共牌 (0到5张, 例如 'Qs Js 2d' 或直接回车留空): ")
|
|
|
|
try:
|
|
hole_cards = parse_cards(hole_str)
|
|
board_cards = parse_cards(board_str)
|
|
if len(hole_cards) != 2:
|
|
print("底牌必须是 2 张!")
|
|
continue
|
|
|
|
# 补齐公共牌到 5 张
|
|
padded_board = board_cards + [52] * (5 - len(board_cards))
|
|
|
|
x_hole = torch.tensor([hole_cards], dtype=torch.int64)
|
|
x_board = torch.tensor([padded_board], dtype=torch.int64)
|
|
|
|
with torch.no_grad():
|
|
pred_equity, pred_hist = model(x_hole, x_board)
|
|
|
|
eq = pred_equity.item()
|
|
hist = pred_hist[0].numpy()
|
|
|
|
print("-" * 40)
|
|
print(f"你的底牌: {hole_str} | 公共牌: {board_str if board_str else '(翻牌前)'}")
|
|
print(f"=> 预测绝对胜率 (Equity): {eq * 100:.2f}%")
|
|
print_histogram(hist)
|
|
print("-" * 40)
|
|
|
|
except Exception as e:
|
|
print(f"输入格式错误: {e}")
|
|
|
|
if __name__ == "__main__":
|
|
main() |