|
3 | 3 | import json |
4 | 4 | from typing import Annotated, Any, Literal |
5 | 5 |
|
6 | | -from pydantic import BeforeValidator, Field, model_validator |
| 6 | +from pydantic import BeforeValidator, Field, field_validator, model_validator |
7 | 7 | from pydantic_settings import BaseSettings, NoDecode, SettingsConfigDict |
8 | 8 |
|
9 | 9 | from opencode_a2a import __version__ |
| 10 | +from opencode_a2a.protocol_versions import ( |
| 11 | + normalize_protocol_version, |
| 12 | + normalize_protocol_versions, |
| 13 | +) |
10 | 14 | from opencode_a2a.sandbox_policy import SandboxPolicy |
11 | 15 |
|
12 | 16 | SandboxMode = Literal[ |
@@ -97,7 +101,11 @@ class Settings(BaseSettings): |
97 | 101 | a2a_title: str = Field(default="OpenCode A2A", alias="A2A_TITLE") |
98 | 102 | a2a_description: str = Field(default="OpenCode A2A runtime", alias="A2A_DESCRIPTION") |
99 | 103 | a2a_version: str = Field(default=__version__, alias="A2A_VERSION") |
100 | | - a2a_protocol_version: str = Field(default="0.3.0", alias="A2A_PROTOCOL_VERSION") |
| 104 | + a2a_protocol_version: str = Field(default="0.3", alias="A2A_PROTOCOL_VERSION") |
| 105 | + a2a_supported_protocol_versions: DeclaredStringList = Field( |
| 106 | + default=("0.3", "1.0"), |
| 107 | + alias="A2A_SUPPORTED_PROTOCOL_VERSIONS", |
| 108 | + ) |
101 | 109 | a2a_log_level: str = Field(default="WARNING", alias="A2A_LOG_LEVEL") |
102 | 110 | a2a_log_payloads: bool = Field(default=False, alias="A2A_LOG_PAYLOADS") |
103 | 111 | a2a_log_body_limit: int = Field(default=0, alias="A2A_LOG_BODY_LIMIT") |
@@ -180,6 +188,10 @@ class Settings(BaseSettings): |
180 | 188 | ) |
181 | 189 | a2a_client_bearer_token: str | None = Field(default=None, alias="A2A_CLIENT_BEARER_TOKEN") |
182 | 190 | a2a_client_basic_auth: str | None = Field(default=None, alias="A2A_CLIENT_BASIC_AUTH") |
| 191 | + a2a_client_protocol_version: str | None = Field( |
| 192 | + default=None, |
| 193 | + alias="A2A_CLIENT_PROTOCOL_VERSION", |
| 194 | + ) |
183 | 195 | a2a_client_cache_ttl_seconds: float = Field( |
184 | 196 | default=900.0, |
185 | 197 | ge=0.0, |
@@ -212,4 +224,37 @@ def _validate_sandbox_policy(self) -> Settings: |
212 | 224 | raise ValueError( |
213 | 225 | "A2A_TASK_STORE_DATABASE_URL is required when A2A_TASK_STORE_BACKEND=database" |
214 | 226 | ) |
| 227 | + if self.a2a_protocol_version not in self.a2a_supported_protocol_versions: |
| 228 | + supported_display = ", ".join(self.a2a_supported_protocol_versions) |
| 229 | + raise ValueError( |
| 230 | + "A2A_PROTOCOL_VERSION must be present in A2A_SUPPORTED_PROTOCOL_VERSIONS. " |
| 231 | + f"Declared supported versions: {supported_display}" |
| 232 | + ) |
215 | 233 | return self |
| 234 | + |
| 235 | + @field_validator("a2a_protocol_version", mode="before") |
| 236 | + @classmethod |
| 237 | + def _normalize_a2a_protocol_version(cls, value: Any) -> str: |
| 238 | + if not isinstance(value, str): |
| 239 | + raise TypeError("A2A_PROTOCOL_VERSION must be a string.") |
| 240 | + return normalize_protocol_version(value) |
| 241 | + |
| 242 | + @field_validator("a2a_client_protocol_version", mode="before") |
| 243 | + @classmethod |
| 244 | + def _normalize_a2a_client_protocol_version(cls, value: Any) -> str | None: |
| 245 | + if value is None: |
| 246 | + return None |
| 247 | + if not isinstance(value, str): |
| 248 | + raise TypeError("A2A_CLIENT_PROTOCOL_VERSION must be a string.") |
| 249 | + normalized = value.strip() |
| 250 | + if not normalized: |
| 251 | + return None |
| 252 | + return normalize_protocol_version(normalized) |
| 253 | + |
| 254 | + @field_validator("a2a_supported_protocol_versions") |
| 255 | + @classmethod |
| 256 | + def _normalize_supported_protocol_versions( |
| 257 | + cls, |
| 258 | + value: tuple[str, ...], |
| 259 | + ) -> tuple[str, ...]: |
| 260 | + return normalize_protocol_versions(value) |
0 commit comments