diff --git a/cmd/chat_template/chat_template.py b/cmd/chat_template/chat_template.py
index 548ae9825..41c1fd45d 100644
--- a/cmd/chat_template/chat_template.py
+++ b/cmd/chat_template/chat_template.py
@@ -7,6 +7,7 @@
 # "fastapi",
 # "uvicorn",
 # "pydantic",
+# "requests",
 # ]
 # ///
 """
@@ -15,9 +16,11 @@ Chat Template Testing Tool
 Test HuggingFace chat templates against Ollama renderers.
 
 Usage:
-    # Run predefined test cases against a model
+    # Run predefined test cases against a HuggingFace model
     uv run cmd/chat_template/chat_template.py --model PrimeIntellect/INTELLECT-3
-    uv run cmd/chat_template/chat_template.py --model allenai/Olmo-3-7B-Think
+
+    # Compare HuggingFace output with Ollama renderer
+    uv run cmd/chat_template/chat_template.py --model PrimeIntellect/INTELLECT-3 --ollama-model intellect3
 
     # Start server for manual curl testing
     uv run cmd/chat_template/chat_template.py --serve
@@ -213,6 +216,7 @@ TEST_CASES = [
         {"role": "user", "content": "Think about this silently."},
         {
             "role": "assistant",
+            "content": "",  # HuggingFace requires content field
             "thinking": "I'm thinking about this but won't respond with visible content.",
         },
         {"role": "user", "content": "What did you think?"},
@@ -256,11 +260,90 @@ def apply_template(
     )
 
 
+def get_ollama_prompt(
+    ollama_model: str,
+    messages: list[dict],
+    tools: list[dict] | None = None,
+    ollama_host: str = "http://localhost:11434",
+) -> str | None:
+    """Get rendered prompt from Ollama using debug_render_only."""
+    import requests
+
+    # Convert messages to Ollama format
+    ollama_messages = []
+    for msg in messages:
+        ollama_msg = {"role": msg["role"]}
+        if "content" in msg:
+            ollama_msg["content"] = msg["content"]
+        if "thinking" in msg:
+            ollama_msg["thinking"] = msg["thinking"]
+        if "tool_calls" in msg:
+            # Convert tool_calls to Ollama format
+            tool_calls = []
+            for tc in msg["tool_calls"]:
+                tool_call = {
+                    "function": {
+                        "name": tc["function"]["name"],
+                        "arguments": tc["function"]["arguments"],
+                    }
+                }
+                if "id" in tc:
+                    tool_call["id"] = tc["id"]
+                tool_calls.append(tool_call)
+            ollama_msg["tool_calls"] = tool_calls
+        if "tool_call_id" in msg:
+            ollama_msg["tool_call_id"] = msg["tool_call_id"]
+        ollama_messages.append(ollama_msg)
+
+    payload = {
+        "model": ollama_model,
+        "messages": ollama_messages,
+        "stream": False,
+        "_debug_render_only": True,
+    }
+
+    if tools:
+        payload["tools"] = tools
+
+    try:
+        resp = requests.post(f"{ollama_host}/api/chat", json=payload, timeout=30)
+        resp.raise_for_status()
+        data = resp.json()
+        # Field name is _debug_info with underscore prefix
+        if "_debug_info" in data and "rendered_template" in data["_debug_info"]:
+            return data["_debug_info"]["rendered_template"]
+        return None
+    except requests.exceptions.ConnectionError:
+        print(f" [ERROR] Cannot connect to Ollama at {ollama_host}", file=sys.stderr)
+        return None
+    except Exception as e:
+        print(f" [ERROR] Ollama request failed: {e}", file=sys.stderr)
+        return None
+
+
+def compute_diff(hf_prompt: str, ollama_prompt: str) -> str:
+    """Compute a unified diff between HuggingFace and Ollama prompts."""
+    import difflib
+
+    hf_lines = hf_prompt.splitlines()
+    ollama_lines = ollama_prompt.splitlines()
+
+    diff = difflib.unified_diff(
+        ollama_lines,
+        hf_lines,
+        fromfile="Ollama",
+        tofile="HuggingFace",
+        lineterm="",
+    )
+    return "\n".join(diff)
+
+
 def print_test_output(
     name: str,
     messages: list[dict],
     tools: list[dict] | None,
-    prompt: str,
+    hf_prompt: str,
+    ollama_prompt: str | None = None,
     as_repr: bool = False,
 ):
"""Print test output in a format suitable for Go test creation and LLM diffing.""" @@ -272,17 +355,61 @@ def print_test_output( if tools: print("\n--- Tools ---") print(json.dumps(tools, indent=2)) - print("\n--- Output Prompt ---") - if as_repr: - print(repr(prompt)) + + if ollama_prompt is not None: + # Comparison mode + if hf_prompt == ollama_prompt: + print("\n--- Result: MATCH ---") + print("\n--- Prompt (both identical) ---") + if as_repr: + print(repr(hf_prompt)) + else: + print(hf_prompt) + else: + print("\n--- Result: MISMATCH ---") + print("\n--- HuggingFace Prompt ---") + if as_repr: + print(repr(hf_prompt)) + else: + print(hf_prompt) + print("\n--- Ollama Prompt ---") + if as_repr: + print(repr(ollama_prompt)) + else: + print(ollama_prompt) + print("\n--- Diff (Ollama -> HuggingFace) ---") + diff = compute_diff(hf_prompt, ollama_prompt) + if diff: + print(diff) + else: + print("(no line-level diff, check whitespace)") else: - print(prompt) + # HuggingFace only mode + print("\n--- HuggingFace Prompt ---") + if as_repr: + print(repr(hf_prompt)) + else: + print(hf_prompt) + print("=" * 60) -def run_tests(model: str, as_repr: bool = False, test_filter: str | None = None): +def run_tests( + model: str, + as_repr: bool = False, + test_filter: str | None = None, + ollama_model: str | None = None, + ollama_host: str = "http://localhost:11434", +): """Run all predefined test cases against a model.""" - print(f"\nRunning tests against: {model}\n") + if ollama_model: + print(f"\nComparing HuggingFace ({model}) vs Ollama ({ollama_model})\n") + else: + print(f"\nRunning tests against: {model}\n") + + matches = 0 + mismatches = 0 + errors = 0 for test_case in TEST_CASES: name = test_case["name"] @@ -294,9 +421,25 @@ def run_tests(model: str, as_repr: bool = False, test_filter: str | None = None) continue try: - prompt = apply_template(model, messages, tools) - print_test_output(name, messages, tools, prompt, as_repr=as_repr) + hf_prompt = apply_template(model, messages, tools) + + ollama_prompt = None + if ollama_model: + ollama_prompt = get_ollama_prompt( + ollama_model, messages, tools, ollama_host + ) + if ollama_prompt is None: + errors += 1 + elif hf_prompt == ollama_prompt: + matches += 1 + else: + mismatches += 1 + + print_test_output( + name, messages, tools, hf_prompt, ollama_prompt, as_repr=as_repr + ) except Exception as e: + errors += 1 print(f"\n{'='*60}") print(f"Test: {name} - FAILED") print(f"--- Input Messages ---") @@ -308,6 +451,18 @@ def run_tests(model: str, as_repr: bool = False, test_filter: str | None = None) print(f"{e}") print("=" * 60) + # Print summary if comparing + if ollama_model: + total = matches + mismatches + errors + print(f"\n{'='*60}") + print("SUMMARY") + print("=" * 60) + print(f" Total: {total}") + print(f" Matches: {matches}") + print(f" Mismatches: {mismatches}") + print(f" Errors: {errors}") + print("=" * 60) + def show_template(model: str): """Show the chat template for a model.""" @@ -397,6 +552,18 @@ def main(): type=str, help="HuggingFace model name (e.g., PrimeIntellect/INTELLECT-3)", ) + parser.add_argument( + "--ollama-model", + "-o", + type=str, + help="Ollama model name to compare against (e.g., qwen3-coder)", + ) + parser.add_argument( + "--ollama-host", + type=str, + default="http://localhost:11434", + help="Ollama server URL (default: http://localhost:11434)", + ) parser.add_argument( "--serve", "-s", @@ -437,12 +604,18 @@ def main(): if args.show_template: show_template(args.model) else: - run_tests(args.model, as_repr=args.repr, 
-                      test_filter=args.filter)
+            run_tests(
+                args.model,
+                as_repr=args.repr,
+                test_filter=args.filter,
+                ollama_model=args.ollama_model,
+                ollama_host=args.ollama_host,
+            )
     else:
         parser.print_help()
         print("\nExample usage:")
         print(" uv run cmd/chat_template/chat_template.py --model PrimeIntellect/INTELLECT-3")
-        print(" uv run cmd/chat_template/chat_template.py --model allenai/Olmo-3-7B-Think")
+        print(" uv run cmd/chat_template/chat_template.py --model Qwen/Qwen3-Coder-480B-A35B-Instruct --ollama-model qwen3-coder")
        print(" uv run cmd/chat_template/chat_template.py --serve")
         sys.exit(1)
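
Note on the comparison path: get_ollama_prompt assumes an Ollama build that honors the "_debug_render_only" flag on /api/chat and reports the rendered prompt under "_debug_info.rendered_template"; both names are taken from the helper in this diff, not from the upstream Ollama API. A quick manual check of that assumption against a local server (the model name "intellect3" is just the example from the usage docstring), mirroring the payload the helper sends:

    curl http://localhost:11434/api/chat -d '{
      "model": "intellect3",
      "messages": [{"role": "user", "content": "Hello"}],
      "stream": false,
      "_debug_render_only": true
    }'

If the flag is supported, the response JSON carries _debug_info.rendered_template; otherwise the field is absent and run_tests counts the case as an error rather than a mismatch.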