fix: resolve git conflicts in LLM abstraction layer

- Fix gui() function import in __init__.py (use cli.selector)
- Fix prompt builder system message merging logic
- Add default max_tokens for Anthropic API in claude.py
- Fix openai tool_call arguments parsing with json.loads
- Fix test_builder.py PromptConfig import and assertions
This commit is contained in:
Anish
2026-04-12 07:10:54 +00:00
6 changed files with 501 additions and 495 deletions

View File

@@ -28,15 +28,9 @@ dev = [
"pytest>=8.0", "pytest>=8.0",
"pytest-asyncio>=0.23", "pytest-asyncio>=0.23",
"pytest-cov>=4.1", "pytest-cov>=4.1",
"pytest-mock>=3.12",
"ruff>=0.4", "ruff>=0.4",
"mypy>=1.10", "mypy>=1.10",
"ruff>=0.4",
]
test = [
"pytest>=8.0",
"pytest-asyncio>=0.23",
"pytest-cov>=4.1",
"pytest-mock>=3.12",
] ]
[project.urls] [project.urls]

View File

@@ -28,6 +28,6 @@ __all__ = [
def gui() -> None: def gui() -> None:
from llm.gui.selector import main from llm.cli.selector import main
main() main()

View File

@@ -39,6 +39,7 @@ class PromptBuilder:
if messages[0].role == Role.SYSTEM: if messages[0].role == Role.SYSTEM:
system_parts.insert(0, messages[0].content) system_parts.insert(0, messages[0].content)
result.insert(0, Message(role=Role.SYSTEM, content="\n\n".join(system_parts)))
result.extend(messages[1:]) result.extend(messages[1:])
else: else:
if system_parts: if system_parts:

View File

@@ -57,6 +57,8 @@ class ClaudeProvider(LLMProvider):
} }
if input.max_tokens: if input.max_tokens:
params["max_tokens"] = input.max_tokens params["max_tokens"] = input.max_tokens
else:
params["max_tokens"] = 8192 # required by Anthropic API
if input.tools: if input.tools:
params["tools"] = [tool.to_dict() for tool in input.tools] params["tools"] = [tool.to_dict() for tool in input.tools]

View File

@@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import json
import os import os
from typing import Any from typing import Any
@@ -77,7 +78,7 @@ class OpenAIProvider(LLMProvider):
ToolCall( ToolCall(
id=tc.id or "", id=tc.id or "",
name=tc.function.name, name=tc.function.name,
arguments={} if tc.function.arguments == "" else tc.function.arguments, arguments={} if not tc.function.arguments else json.loads(tc.function.arguments),
) )
for tc in choice.message.tool_calls for tc in choice.message.tool_calls
] ]

View File

@@ -1,6 +1,7 @@
import pytest import pytest
from llm.core.types import LLMInput, Message, Role, ToolDefinition from llm.core.types import LLMInput, Message, Role, ToolDefinition
from llm.prompt import PromptBuilder, adapt_messages_for_provider from llm.prompt import PromptBuilder, adapt_messages_for_provider
from llm.prompt.builder import PromptConfig
class TestPromptBuilder: class TestPromptBuilder:
@@ -31,6 +32,13 @@ class TestPromptBuilder:
assert len(result) == 2 assert len(result) == 2
assert "pirate" in result[0].content assert "pirate" in result[0].content
def test_build_adds_system_from_config(self):
messages = [Message(role=Role.USER, content="Hello")]
builder = PromptBuilder(config=PromptConfig(system_template="You are a pirate."))
result = builder.build(messages)
assert len(result) == 2
assert "pirate" in result[0].content
def test_build_with_tools(self): def test_build_with_tools(self):
messages = [Message(role=Role.USER, content="Search for something")] messages = [Message(role=Role.USER, content="Search for something")]
tools = [ tools = [