fix: harden openai-compatible provider responses

This commit is contained in:
Affaan Mustafa
2026-05-18 01:04:28 -04:00
parent cc62e89152
commit eb0d893948
4 changed files with 181 additions and 126 deletions

View File

@@ -15,6 +15,7 @@ from llm.core.interface import (
RateLimitError, RateLimitError,
) )
from llm.core.types import LLMInput, LLMOutput, ModelInfo, ProviderType, ToolCall from llm.core.types import LLMInput, LLMOutput, ModelInfo, ProviderType, ToolCall
from llm.providers.constants import EMPTY_FILTERED_RESPONSE_ERROR
ASTRAFLOW_BASE_URL = "https://api.umodelverse.ai/v1" ASTRAFLOW_BASE_URL = "https://api.umodelverse.ai/v1"
ASTRAFLOW_CN_BASE_URL = "https://api.modelverse.cn/v1" ASTRAFLOW_CN_BASE_URL = "https://api.modelverse.cn/v1"
@@ -55,7 +56,7 @@ class _AstraflowBaseProvider(LLMProvider):
env_model = os.environ.get(self.model_env) env_model = os.environ.get(self.model_env)
fallback_model = os.environ.get(self.fallback_model_env) if self.fallback_model_env else None fallback_model = os.environ.get(self.fallback_model_env) if self.fallback_model_env else None
self.default_model = default_model or env_model or fallback_model or DEFAULT_ASTRAFLOW_MODEL self.default_model = default_model or env_model or fallback_model or DEFAULT_ASTRAFLOW_MODEL
self.client = OpenAI(api_key=self.api_key, base_url=self.base_url) self.client = OpenAI(api_key=self.api_key, base_url=self.base_url, _enforce_credentials=False)
self._models = [ self._models = [
ModelInfo( ModelInfo(
name=self.default_model, name=self.default_model,
@@ -80,7 +81,7 @@ class _AstraflowBaseProvider(LLMProvider):
response = self.client.chat.completions.create(**params) response = self.client.chat.completions.create(**params)
if not response.choices or response.choices[0].message is None: if not response.choices or response.choices[0].message is None:
raise ValueError("LLM returned empty or filtered response") raise ValueError(EMPTY_FILTERED_RESPONSE_ERROR)
choice = response.choices[0] choice = response.choices[0]
tool_calls = None tool_calls = None

View File

@@ -0,0 +1,3 @@
"""Shared provider constants."""
EMPTY_FILTERED_RESPONSE_ERROR = "LLM returned empty or filtered response"

View File

@@ -1,116 +1,125 @@
"""OpenAI provider adapter.""" """OpenAI provider adapter."""
from __future__ import annotations from __future__ import annotations
import json import json
import os import os
from typing import Any from typing import Any
from openai import OpenAI from openai import OpenAI
from llm.core.interface import ( from llm.core.interface import (
AuthenticationError, AuthenticationError,
ContextLengthError, ContextLengthError,
LLMProvider, LLMProvider,
RateLimitError, RateLimitError,
) )
from llm.core.types import LLMInput, LLMOutput, Message, ModelInfo, ProviderType, ToolCall from llm.core.types import LLMInput, LLMOutput, Message, ModelInfo, ProviderType, ToolCall
from llm.providers.constants import EMPTY_FILTERED_RESPONSE_ERROR
class OpenAIProvider(LLMProvider):
provider_type = ProviderType.OPENAI class OpenAIProvider(LLMProvider):
provider_type = ProviderType.OPENAI
def __init__(self, api_key: str | None = None, base_url: str | None = None) -> None:
self.client = OpenAI(api_key=api_key or os.environ.get("OPENAI_API_KEY"), base_url=base_url) def __init__(self, api_key: str | None = None, base_url: str | None = None) -> None:
self._models = [ self.client = OpenAI(
ModelInfo( api_key=api_key or os.environ.get("OPENAI_API_KEY"),
name="gpt-4o", base_url=base_url,
provider=ProviderType.OPENAI, _enforce_credentials=False,
supports_tools=True, )
supports_vision=True, self._models = [
max_tokens=4096, ModelInfo(
context_window=128000, name="gpt-4o",
), provider=ProviderType.OPENAI,
ModelInfo( supports_tools=True,
name="gpt-4o-mini", supports_vision=True,
provider=ProviderType.OPENAI, max_tokens=4096,
supports_tools=True, context_window=128000,
supports_vision=True, ),
max_tokens=4096, ModelInfo(
context_window=128000, name="gpt-4o-mini",
), provider=ProviderType.OPENAI,
ModelInfo( supports_tools=True,
name="gpt-4-turbo", supports_vision=True,
provider=ProviderType.OPENAI, max_tokens=4096,
supports_tools=True, context_window=128000,
supports_vision=True, ),
max_tokens=4096, ModelInfo(
context_window=128000, name="gpt-4-turbo",
), provider=ProviderType.OPENAI,
ModelInfo( supports_tools=True,
name="gpt-3.5-turbo", supports_vision=True,
provider=ProviderType.OPENAI, max_tokens=4096,
supports_tools=True, context_window=128000,
supports_vision=False, ),
max_tokens=4096, ModelInfo(
context_window=16385, name="gpt-3.5-turbo",
), provider=ProviderType.OPENAI,
] supports_tools=True,
supports_vision=False,
def generate(self, input: LLMInput) -> LLMOutput: max_tokens=4096,
try: context_window=16385,
params: dict[str, Any] = { ),
"model": input.model or "gpt-4o-mini", ]
"messages": [msg.to_dict() for msg in input.messages],
"temperature": input.temperature, def generate(self, input: LLMInput) -> LLMOutput:
} try:
if input.max_tokens: params: dict[str, Any] = {
params["max_tokens"] = input.max_tokens "model": input.model or "gpt-4o-mini",
if input.tools: "messages": [msg.to_dict() for msg in input.messages],
params["tools"] = [tool.to_openai_tool() for tool in input.tools] "temperature": input.temperature,
}
response = self.client.chat.completions.create(**params) if input.max_tokens:
if not response.choices or response.choices[0].message is None: params["max_tokens"] = input.max_tokens
raise ValueError("LLM returned empty or filtered response") if input.tools:
choice = response.choices[0] params["tools"] = [tool.to_openai_tool() for tool in input.tools]
tool_calls = None response = self.client.chat.completions.create(**params)
if choice.message.tool_calls: if not response.choices or response.choices[0].message is None:
tool_calls = [ raise ValueError(EMPTY_FILTERED_RESPONSE_ERROR)
ToolCall( choice = response.choices[0]
id=tc.id or "",
name=tc.function.name, tool_calls = None
arguments={} if not tc.function.arguments else json.loads(tc.function.arguments), if choice.message.tool_calls:
) tool_calls = [
for tc in choice.message.tool_calls ToolCall(
] id=tc.id or "",
name=tc.function.name,
return LLMOutput( arguments={} if not tc.function.arguments else json.loads(tc.function.arguments),
content=choice.message.content or "", )
tool_calls=tool_calls, for tc in choice.message.tool_calls
model=response.model, ]
usage={
"prompt_tokens": response.usage.prompt_tokens, usage = None
"completion_tokens": response.usage.completion_tokens, if response.usage:
"total_tokens": response.usage.total_tokens, usage = {
}, "prompt_tokens": response.usage.prompt_tokens,
stop_reason=choice.finish_reason, "completion_tokens": response.usage.completion_tokens,
) "total_tokens": response.usage.total_tokens,
except Exception as e: }
msg = str(e)
if "401" in msg or "authentication" in msg.lower(): return LLMOutput(
raise AuthenticationError(msg, provider=ProviderType.OPENAI) from e content=choice.message.content or "",
if "429" in msg or "rate_limit" in msg.lower(): tool_calls=tool_calls,
raise RateLimitError(msg, provider=ProviderType.OPENAI) from e model=response.model,
if "context" in msg.lower() and "length" in msg.lower(): usage=usage,
raise ContextLengthError(msg, provider=ProviderType.OPENAI) from e stop_reason=choice.finish_reason,
raise )
except Exception as e:
def list_models(self) -> list[ModelInfo]: msg = str(e)
return self._models.copy() if "401" in msg or "authentication" in msg.lower():
raise AuthenticationError(msg, provider=ProviderType.OPENAI) from e
def validate_config(self) -> bool: if "429" in msg or "rate_limit" in msg.lower():
return bool(self.client.api_key) raise RateLimitError(msg, provider=ProviderType.OPENAI) from e
if "context" in msg.lower() and "length" in msg.lower():
def get_default_model(self) -> str: raise ContextLengthError(msg, provider=ProviderType.OPENAI) from e
return "gpt-4o-mini" raise
def list_models(self) -> list[ModelInfo]:
return self._models.copy()
def validate_config(self) -> bool:
return bool(self.client.api_key)
def get_default_model(self) -> str:
return "gpt-4o-mini"

View File

@@ -1,7 +1,10 @@
from types import SimpleNamespace from types import SimpleNamespace
import pytest
from llm.core.types import LLMInput, Message, Role, ToolDefinition from llm.core.types import LLMInput, Message, Role, ToolDefinition
from llm.providers.claude import ClaudeProvider from llm.providers.claude import ClaudeProvider
from llm.providers.constants import EMPTY_FILTERED_RESPONSE_ERROR
from llm.providers.openai import OpenAIProvider from llm.providers.openai import OpenAIProvider
@@ -14,21 +17,20 @@ def _tool() -> ToolDefinition:
class _OpenAICompletions: class _OpenAICompletions:
def __init__(self) -> None: def __init__(self, response: SimpleNamespace | None = None) -> None:
self.params = None self.params = None
self.response = response
def create(self, **params): def create(self, **params):
self.params = params self.params = params
return SimpleNamespace( if self.response:
choices=[SimpleNamespace(message=SimpleNamespace(content="ok", tool_calls=None), finish_reason="stop")], return self.response
model=params["model"], return _openai_response(model=params["model"])
usage=SimpleNamespace(prompt_tokens=1, completion_tokens=1, total_tokens=2),
)
class _OpenAIClient: class _OpenAIClient:
def __init__(self) -> None: def __init__(self, response: SimpleNamespace | None = None) -> None:
self.completions = _OpenAICompletions() self.completions = _OpenAICompletions(response=response)
self.chat = SimpleNamespace(completions=self.completions) self.chat = SimpleNamespace(completions=self.completions)
@@ -52,6 +54,16 @@ class _AnthropicClient:
self.api_key = "test" self.api_key = "test"
def _openai_response(**overrides) -> SimpleNamespace:
defaults = {
"choices": [SimpleNamespace(message=SimpleNamespace(content="ok", tool_calls=None), finish_reason="stop")],
"model": "gpt-4o-mini",
"usage": SimpleNamespace(prompt_tokens=1, completion_tokens=1, total_tokens=2),
}
defaults.update(overrides)
return SimpleNamespace(**defaults)
def test_openai_provider_serializes_tools_for_chat_completions(): def test_openai_provider_serializes_tools_for_chat_completions():
provider = OpenAIProvider(api_key="test") provider = OpenAIProvider(api_key="test")
client = _OpenAIClient() client = _OpenAIClient()
@@ -72,6 +84,36 @@ def test_openai_provider_serializes_tools_for_chat_completions():
] ]
def test_openai_provider_can_be_constructed_without_credentials(monkeypatch):
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
provider = OpenAIProvider()
assert provider.validate_config() is False
def test_openai_provider_rejects_empty_or_filtered_responses():
provider = OpenAIProvider(api_key="test")
for response in [
_openai_response(choices=[]),
_openai_response(choices=[SimpleNamespace(message=None, finish_reason="content_filter")]),
]:
provider.client = _OpenAIClient(response=response)
with pytest.raises(ValueError, match=EMPTY_FILTERED_RESPONSE_ERROR):
provider.generate(LLMInput(messages=[Message(role=Role.USER, content="hi")]))
def test_openai_provider_allows_missing_usage():
provider = OpenAIProvider(api_key="test")
provider.client = _OpenAIClient(response=_openai_response(usage=None))
output = provider.generate(LLMInput(messages=[Message(role=Role.USER, content="hi")]))
assert output.content == "ok"
assert output.usage is None
def test_claude_provider_serializes_tools_for_messages_api(): def test_claude_provider_serializes_tools_for_messages_api():
provider = ClaudeProvider(api_key="test") provider = ClaudeProvider(api_key="test")
client = _AnthropicClient() client = _AnthropicClient()