fix: harden openai-compatible provider responses

This commit is contained in:
Affaan Mustafa
2026-05-18 01:04:28 -04:00
parent cc62e89152
commit eb0d893948
4 changed files with 181 additions and 126 deletions

View File

@@ -15,6 +15,7 @@ from llm.core.interface import (
RateLimitError, RateLimitError,
) )
from llm.core.types import LLMInput, LLMOutput, ModelInfo, ProviderType, ToolCall from llm.core.types import LLMInput, LLMOutput, ModelInfo, ProviderType, ToolCall
from llm.providers.constants import EMPTY_FILTERED_RESPONSE_ERROR
ASTRAFLOW_BASE_URL = "https://api.umodelverse.ai/v1" ASTRAFLOW_BASE_URL = "https://api.umodelverse.ai/v1"
ASTRAFLOW_CN_BASE_URL = "https://api.modelverse.cn/v1" ASTRAFLOW_CN_BASE_URL = "https://api.modelverse.cn/v1"
@@ -55,7 +56,7 @@ class _AstraflowBaseProvider(LLMProvider):
env_model = os.environ.get(self.model_env) env_model = os.environ.get(self.model_env)
fallback_model = os.environ.get(self.fallback_model_env) if self.fallback_model_env else None fallback_model = os.environ.get(self.fallback_model_env) if self.fallback_model_env else None
self.default_model = default_model or env_model or fallback_model or DEFAULT_ASTRAFLOW_MODEL self.default_model = default_model or env_model or fallback_model or DEFAULT_ASTRAFLOW_MODEL
self.client = OpenAI(api_key=self.api_key, base_url=self.base_url) self.client = OpenAI(api_key=self.api_key, base_url=self.base_url, _enforce_credentials=False)
self._models = [ self._models = [
ModelInfo( ModelInfo(
name=self.default_model, name=self.default_model,
@@ -80,7 +81,7 @@ class _AstraflowBaseProvider(LLMProvider):
response = self.client.chat.completions.create(**params) response = self.client.chat.completions.create(**params)
if not response.choices or response.choices[0].message is None: if not response.choices or response.choices[0].message is None:
raise ValueError("LLM returned empty or filtered response") raise ValueError(EMPTY_FILTERED_RESPONSE_ERROR)
choice = response.choices[0] choice = response.choices[0]
tool_calls = None tool_calls = None

View File

@@ -0,0 +1,3 @@
"""Shared provider constants."""
EMPTY_FILTERED_RESPONSE_ERROR = "LLM returned empty or filtered response"

View File

@@ -15,13 +15,18 @@ from llm.core.interface import (
RateLimitError, RateLimitError,
) )
from llm.core.types import LLMInput, LLMOutput, Message, ModelInfo, ProviderType, ToolCall from llm.core.types import LLMInput, LLMOutput, Message, ModelInfo, ProviderType, ToolCall
from llm.providers.constants import EMPTY_FILTERED_RESPONSE_ERROR
class OpenAIProvider(LLMProvider): class OpenAIProvider(LLMProvider):
provider_type = ProviderType.OPENAI provider_type = ProviderType.OPENAI
def __init__(self, api_key: str | None = None, base_url: str | None = None) -> None: def __init__(self, api_key: str | None = None, base_url: str | None = None) -> None:
self.client = OpenAI(api_key=api_key or os.environ.get("OPENAI_API_KEY"), base_url=base_url) self.client = OpenAI(
api_key=api_key or os.environ.get("OPENAI_API_KEY"),
base_url=base_url,
_enforce_credentials=False,
)
self._models = [ self._models = [
ModelInfo( ModelInfo(
name="gpt-4o", name="gpt-4o",
@@ -71,7 +76,7 @@ class OpenAIProvider(LLMProvider):
response = self.client.chat.completions.create(**params) response = self.client.chat.completions.create(**params)
if not response.choices or response.choices[0].message is None: if not response.choices or response.choices[0].message is None:
raise ValueError("LLM returned empty or filtered response") raise ValueError(EMPTY_FILTERED_RESPONSE_ERROR)
choice = response.choices[0] choice = response.choices[0]
tool_calls = None tool_calls = None
@@ -85,15 +90,19 @@ class OpenAIProvider(LLMProvider):
for tc in choice.message.tool_calls for tc in choice.message.tool_calls
] ]
return LLMOutput( usage = None
content=choice.message.content or "", if response.usage:
tool_calls=tool_calls,
model=response.model,
usage = { usage = {
"prompt_tokens": response.usage.prompt_tokens, "prompt_tokens": response.usage.prompt_tokens,
"completion_tokens": response.usage.completion_tokens, "completion_tokens": response.usage.completion_tokens,
"total_tokens": response.usage.total_tokens, "total_tokens": response.usage.total_tokens,
}, }
return LLMOutput(
content=choice.message.content or "",
tool_calls=tool_calls,
model=response.model,
usage=usage,
stop_reason=choice.finish_reason, stop_reason=choice.finish_reason,
) )
except Exception as e: except Exception as e:

View File

@@ -1,7 +1,10 @@
from types import SimpleNamespace from types import SimpleNamespace
import pytest
from llm.core.types import LLMInput, Message, Role, ToolDefinition from llm.core.types import LLMInput, Message, Role, ToolDefinition
from llm.providers.claude import ClaudeProvider from llm.providers.claude import ClaudeProvider
from llm.providers.constants import EMPTY_FILTERED_RESPONSE_ERROR
from llm.providers.openai import OpenAIProvider from llm.providers.openai import OpenAIProvider
@@ -14,21 +17,20 @@ def _tool() -> ToolDefinition:
class _OpenAICompletions: class _OpenAICompletions:
def __init__(self) -> None: def __init__(self, response: SimpleNamespace | None = None) -> None:
self.params = None self.params = None
self.response = response
def create(self, **params): def create(self, **params):
self.params = params self.params = params
return SimpleNamespace( if self.response:
choices=[SimpleNamespace(message=SimpleNamespace(content="ok", tool_calls=None), finish_reason="stop")], return self.response
model=params["model"], return _openai_response(model=params["model"])
usage=SimpleNamespace(prompt_tokens=1, completion_tokens=1, total_tokens=2),
)
class _OpenAIClient: class _OpenAIClient:
def __init__(self) -> None: def __init__(self, response: SimpleNamespace | None = None) -> None:
self.completions = _OpenAICompletions() self.completions = _OpenAICompletions(response=response)
self.chat = SimpleNamespace(completions=self.completions) self.chat = SimpleNamespace(completions=self.completions)
@@ -52,6 +54,16 @@ class _AnthropicClient:
self.api_key = "test" self.api_key = "test"
def _openai_response(**overrides) -> SimpleNamespace:
defaults = {
"choices": [SimpleNamespace(message=SimpleNamespace(content="ok", tool_calls=None), finish_reason="stop")],
"model": "gpt-4o-mini",
"usage": SimpleNamespace(prompt_tokens=1, completion_tokens=1, total_tokens=2),
}
defaults.update(overrides)
return SimpleNamespace(**defaults)
def test_openai_provider_serializes_tools_for_chat_completions(): def test_openai_provider_serializes_tools_for_chat_completions():
provider = OpenAIProvider(api_key="test") provider = OpenAIProvider(api_key="test")
client = _OpenAIClient() client = _OpenAIClient()
@@ -72,6 +84,36 @@ def test_openai_provider_serializes_tools_for_chat_completions():
] ]
def test_openai_provider_can_be_constructed_without_credentials(monkeypatch):
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
provider = OpenAIProvider()
assert provider.validate_config() is False
def test_openai_provider_rejects_empty_or_filtered_responses():
provider = OpenAIProvider(api_key="test")
for response in [
_openai_response(choices=[]),
_openai_response(choices=[SimpleNamespace(message=None, finish_reason="content_filter")]),
]:
provider.client = _OpenAIClient(response=response)
with pytest.raises(ValueError, match=EMPTY_FILTERED_RESPONSE_ERROR):
provider.generate(LLMInput(messages=[Message(role=Role.USER, content="hi")]))
def test_openai_provider_allows_missing_usage():
provider = OpenAIProvider(api_key="test")
provider.client = _OpenAIClient(response=_openai_response(usage=None))
output = provider.generate(LLMInput(messages=[Message(role=Role.USER, content="hi")]))
assert output.content == "ok"
assert output.usage is None
def test_claude_provider_serializes_tools_for_messages_api(): def test_claude_provider_serializes_tools_for_messages_api():
provider = ClaudeProvider(api_key="test") provider = ClaudeProvider(api_key="test")
client = _AnthropicClient() client = _AnthropicClient()