diff --git a/src/llm/providers/astraflow.py b/src/llm/providers/astraflow.py index ba23517f..1f7cceab 100644 --- a/src/llm/providers/astraflow.py +++ b/src/llm/providers/astraflow.py @@ -15,6 +15,7 @@ from llm.core.interface import ( RateLimitError, ) from llm.core.types import LLMInput, LLMOutput, ModelInfo, ProviderType, ToolCall +from llm.providers.constants import EMPTY_FILTERED_RESPONSE_ERROR ASTRAFLOW_BASE_URL = "https://api.umodelverse.ai/v1" ASTRAFLOW_CN_BASE_URL = "https://api.modelverse.cn/v1" @@ -55,7 +56,7 @@ class _AstraflowBaseProvider(LLMProvider): env_model = os.environ.get(self.model_env) fallback_model = os.environ.get(self.fallback_model_env) if self.fallback_model_env else None self.default_model = default_model or env_model or fallback_model or DEFAULT_ASTRAFLOW_MODEL - self.client = OpenAI(api_key=self.api_key, base_url=self.base_url) + self.client = OpenAI(api_key=self.api_key, base_url=self.base_url, _enforce_credentials=False) self._models = [ ModelInfo( name=self.default_model, @@ -79,6 +80,8 @@ class _AstraflowBaseProvider(LLMProvider): params["tools"] = [tool.to_openai_tool() for tool in llm_input.tools] response = self.client.chat.completions.create(**params) + if not response.choices or response.choices[0].message is None: + raise ValueError(EMPTY_FILTERED_RESPONSE_ERROR) choice = response.choices[0] tool_calls = None diff --git a/src/llm/providers/constants.py b/src/llm/providers/constants.py new file mode 100644 index 00000000..092cac3a --- /dev/null +++ b/src/llm/providers/constants.py @@ -0,0 +1,3 @@ +"""Shared provider constants.""" + +EMPTY_FILTERED_RESPONSE_ERROR = "LLM returned empty or filtered response" diff --git a/src/llm/providers/openai.py b/src/llm/providers/openai.py index e4e7f895..7461a8f1 100644 --- a/src/llm/providers/openai.py +++ b/src/llm/providers/openai.py @@ -1,114 +1,125 @@ -"""OpenAI provider adapter.""" - -from __future__ import annotations - -import json -import os -from typing import Any - -from openai import OpenAI - -from llm.core.interface import ( - AuthenticationError, - ContextLengthError, - LLMProvider, - RateLimitError, -) -from llm.core.types import LLMInput, LLMOutput, Message, ModelInfo, ProviderType, ToolCall - - -class OpenAIProvider(LLMProvider): - provider_type = ProviderType.OPENAI - - def __init__(self, api_key: str | None = None, base_url: str | None = None) -> None: - self.client = OpenAI(api_key=api_key or os.environ.get("OPENAI_API_KEY"), base_url=base_url) - self._models = [ - ModelInfo( - name="gpt-4o", - provider=ProviderType.OPENAI, - supports_tools=True, - supports_vision=True, - max_tokens=4096, - context_window=128000, - ), - ModelInfo( - name="gpt-4o-mini", - provider=ProviderType.OPENAI, - supports_tools=True, - supports_vision=True, - max_tokens=4096, - context_window=128000, - ), - ModelInfo( - name="gpt-4-turbo", - provider=ProviderType.OPENAI, - supports_tools=True, - supports_vision=True, - max_tokens=4096, - context_window=128000, - ), - ModelInfo( - name="gpt-3.5-turbo", - provider=ProviderType.OPENAI, - supports_tools=True, - supports_vision=False, - max_tokens=4096, - context_window=16385, - ), - ] - - def generate(self, input: LLMInput) -> LLMOutput: - try: - params: dict[str, Any] = { - "model": input.model or "gpt-4o-mini", - "messages": [msg.to_dict() for msg in input.messages], - "temperature": input.temperature, - } - if input.max_tokens: - params["max_tokens"] = input.max_tokens - if input.tools: +"""OpenAI provider adapter.""" + +from __future__ import annotations + +import json +import os +from typing import Any + +from openai import OpenAI + +from llm.core.interface import ( + AuthenticationError, + ContextLengthError, + LLMProvider, + RateLimitError, +) +from llm.core.types import LLMInput, LLMOutput, Message, ModelInfo, ProviderType, ToolCall +from llm.providers.constants import EMPTY_FILTERED_RESPONSE_ERROR + + +class OpenAIProvider(LLMProvider): + provider_type = ProviderType.OPENAI + + def __init__(self, api_key: str | None = None, base_url: str | None = None) -> None: + self.client = OpenAI( + api_key=api_key or os.environ.get("OPENAI_API_KEY"), + base_url=base_url, + _enforce_credentials=False, + ) + self._models = [ + ModelInfo( + name="gpt-4o", + provider=ProviderType.OPENAI, + supports_tools=True, + supports_vision=True, + max_tokens=4096, + context_window=128000, + ), + ModelInfo( + name="gpt-4o-mini", + provider=ProviderType.OPENAI, + supports_tools=True, + supports_vision=True, + max_tokens=4096, + context_window=128000, + ), + ModelInfo( + name="gpt-4-turbo", + provider=ProviderType.OPENAI, + supports_tools=True, + supports_vision=True, + max_tokens=4096, + context_window=128000, + ), + ModelInfo( + name="gpt-3.5-turbo", + provider=ProviderType.OPENAI, + supports_tools=True, + supports_vision=False, + max_tokens=4096, + context_window=16385, + ), + ] + + def generate(self, input: LLMInput) -> LLMOutput: + try: + params: dict[str, Any] = { + "model": input.model or "gpt-4o-mini", + "messages": [msg.to_dict() for msg in input.messages], + "temperature": input.temperature, + } + if input.max_tokens: + params["max_tokens"] = input.max_tokens + if input.tools: params["tools"] = [tool.to_openai_tool() for tool in input.tools] - - response = self.client.chat.completions.create(**params) - choice = response.choices[0] - - tool_calls = None - if choice.message.tool_calls: - tool_calls = [ - ToolCall( - id=tc.id or "", - name=tc.function.name, - arguments={} if not tc.function.arguments else json.loads(tc.function.arguments), - ) - for tc in choice.message.tool_calls - ] - - return LLMOutput( - content=choice.message.content or "", - tool_calls=tool_calls, - model=response.model, - usage={ - "prompt_tokens": response.usage.prompt_tokens, - "completion_tokens": response.usage.completion_tokens, - "total_tokens": response.usage.total_tokens, - }, - stop_reason=choice.finish_reason, - ) - except Exception as e: - msg = str(e) - if "401" in msg or "authentication" in msg.lower(): - raise AuthenticationError(msg, provider=ProviderType.OPENAI) from e - if "429" in msg or "rate_limit" in msg.lower(): - raise RateLimitError(msg, provider=ProviderType.OPENAI) from e - if "context" in msg.lower() and "length" in msg.lower(): - raise ContextLengthError(msg, provider=ProviderType.OPENAI) from e - raise - - def list_models(self) -> list[ModelInfo]: - return self._models.copy() - - def validate_config(self) -> bool: - return bool(self.client.api_key) - - def get_default_model(self) -> str: - return "gpt-4o-mini" + + response = self.client.chat.completions.create(**params) + if not response.choices or response.choices[0].message is None: + raise ValueError(EMPTY_FILTERED_RESPONSE_ERROR) + choice = response.choices[0] + + tool_calls = None + if choice.message.tool_calls: + tool_calls = [ + ToolCall( + id=tc.id or "", + name=tc.function.name, + arguments={} if not tc.function.arguments else json.loads(tc.function.arguments), + ) + for tc in choice.message.tool_calls + ] + + usage = None + if response.usage: + usage = { + "prompt_tokens": response.usage.prompt_tokens, + "completion_tokens": response.usage.completion_tokens, + "total_tokens": response.usage.total_tokens, + } + + return LLMOutput( + content=choice.message.content or "", + tool_calls=tool_calls, + model=response.model, + usage=usage, + stop_reason=choice.finish_reason, + ) + except Exception as e: + msg = str(e) + if "401" in msg or "authentication" in msg.lower(): + raise AuthenticationError(msg, provider=ProviderType.OPENAI) from e + if "429" in msg or "rate_limit" in msg.lower(): + raise RateLimitError(msg, provider=ProviderType.OPENAI) from e + if "context" in msg.lower() and "length" in msg.lower(): + raise ContextLengthError(msg, provider=ProviderType.OPENAI) from e + raise + + def list_models(self) -> list[ModelInfo]: + return self._models.copy() + + def validate_config(self) -> bool: + return bool(self.client.api_key) + + def get_default_model(self) -> str: + return "gpt-4o-mini" diff --git a/tests/test_provider_tools.py b/tests/test_provider_tools.py index e12f7aa6..4c9c76f9 100644 --- a/tests/test_provider_tools.py +++ b/tests/test_provider_tools.py @@ -1,7 +1,10 @@ from types import SimpleNamespace +import pytest + from llm.core.types import LLMInput, Message, Role, ToolDefinition from llm.providers.claude import ClaudeProvider +from llm.providers.constants import EMPTY_FILTERED_RESPONSE_ERROR from llm.providers.openai import OpenAIProvider @@ -14,21 +17,20 @@ def _tool() -> ToolDefinition: class _OpenAICompletions: - def __init__(self) -> None: + def __init__(self, response: SimpleNamespace | None = None) -> None: self.params = None + self.response = response def create(self, **params): self.params = params - return SimpleNamespace( - choices=[SimpleNamespace(message=SimpleNamespace(content="ok", tool_calls=None), finish_reason="stop")], - model=params["model"], - usage=SimpleNamespace(prompt_tokens=1, completion_tokens=1, total_tokens=2), - ) + if self.response: + return self.response + return _openai_response(model=params["model"]) class _OpenAIClient: - def __init__(self) -> None: - self.completions = _OpenAICompletions() + def __init__(self, response: SimpleNamespace | None = None) -> None: + self.completions = _OpenAICompletions(response=response) self.chat = SimpleNamespace(completions=self.completions) @@ -52,6 +54,16 @@ class _AnthropicClient: self.api_key = "test" +def _openai_response(**overrides) -> SimpleNamespace: + defaults = { + "choices": [SimpleNamespace(message=SimpleNamespace(content="ok", tool_calls=None), finish_reason="stop")], + "model": "gpt-4o-mini", + "usage": SimpleNamespace(prompt_tokens=1, completion_tokens=1, total_tokens=2), + } + defaults.update(overrides) + return SimpleNamespace(**defaults) + + def test_openai_provider_serializes_tools_for_chat_completions(): provider = OpenAIProvider(api_key="test") client = _OpenAIClient() @@ -72,6 +84,36 @@ def test_openai_provider_serializes_tools_for_chat_completions(): ] +def test_openai_provider_can_be_constructed_without_credentials(monkeypatch): + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + + provider = OpenAIProvider() + + assert provider.validate_config() is False + + +def test_openai_provider_rejects_empty_or_filtered_responses(): + provider = OpenAIProvider(api_key="test") + + for response in [ + _openai_response(choices=[]), + _openai_response(choices=[SimpleNamespace(message=None, finish_reason="content_filter")]), + ]: + provider.client = _OpenAIClient(response=response) + with pytest.raises(ValueError, match=EMPTY_FILTERED_RESPONSE_ERROR): + provider.generate(LLMInput(messages=[Message(role=Role.USER, content="hi")])) + + +def test_openai_provider_allows_missing_usage(): + provider = OpenAIProvider(api_key="test") + provider.client = _OpenAIClient(response=_openai_response(usage=None)) + + output = provider.generate(LLMInput(messages=[Message(role=Role.USER, content="hi")])) + + assert output.content == "ok" + assert output.usage is None + + def test_claude_provider_serializes_tools_for_messages_api(): provider = ClaudeProvider(api_key="test") client = _AnthropicClient()