35 changes: 35 additions & 0 deletions README.md
@@ -6,6 +6,41 @@
![license badge](https://img.shields.io/github/license/cohere-ai/cohere-python)
[![fern shield](https://img.shields.io/badge/%F0%9F%8C%BF-SDK%20generated%20by%20Fern-brightgreen)](https://github.com/fern-api/fern)

---

## ⚠️ Custom Modifications (Internal Fork)

**This is a modified version of the Cohere Python SDK with the following changes:**

### Async Client Migration (httpx → aiohttp)
- **Date:** February 2026
- **Reason:** Resolves `httpx.ConnectError: All connection attempts failed` errors raised by the httpx-based async transport
- **Scope:** Async clients only (`AsyncClient`, `AsyncClientV2`); sync clients are unchanged

### Modified Files:
- `src/cohere/core/http_client.py` - AsyncHttpClient migrated to aiohttp
- `src/cohere/core/client_wrapper.py` - AsyncClientWrapper updated
- `src/cohere/base_client.py` - AsyncBaseCohere initialization
- `src/cohere/core/http_response.py` - AsyncHttpResponse compatibility
- `src/cohere/core/http_sse/_api.py` - SSE streaming with aiohttp
- `src/cohere/core/http_sse/_exceptions.py` - Exception compatibility
- `src/cohere/core/file.py` - FormData support for aiohttp
- `pyproject.toml` - Added aiohttp dependency

### Testing:
- All async operations verified working (see `test_async_client.py`)
- 8/8 tests passing: chat, streaming, SSE, embed, concurrent requests, error handling (a sketch of the concurrency check appears below)
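
A rough sketch of the kind of concurrent-request check the suite runs; it assumes the v1 `chat(message=...)` call, and the helper name, prompt, and API-key placeholder are illustrative, not copied from `test_async_client.py`:

```python
import asyncio

import cohere


async def run_concurrency_check(n: int = 20) -> None:
    # Fire n chat requests at once; the aiohttp backend reuses pooled connections
    # instead of opening a fresh socket per request.
    co = cohere.AsyncClient(api_key="YOUR_API_KEY")
    tasks = [co.chat(message=f"Reply with the number {i}.") for i in range(n)]
    results = await asyncio.gather(*tasks, return_exceptions=True)
    failures = [r for r in results if isinstance(r, Exception)]
    print(f"{n - len(failures)}/{n} requests succeeded")


asyncio.run(run_concurrency_check())
```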

### Important Notes:
- **Fern-generated code modified:** Changes will be overwritten if Fern regenerates the SDK
- **Version pinned:** Stay on the 5.20.5 base until the migration is upstreamed
- **Backward compatible:** Sync clients (`Client`, `ClientV2`) continue using httpx
- **Production ready:** All async functionality tested and working

**To use:** Install with `uv sync` in this directory.
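
A minimal usage sketch follows. It assumes the standard v2 chat API; the model name, prompt, and `YOUR_API_KEY` placeholder are illustrative. The `aiohttp_session` argument is the parameter added by this fork; omit it and the client creates its own pooled session.

```python
import asyncio

import aiohttp
import cohere


async def main() -> None:
    # Optional: supply a preconfigured aiohttp session (this fork's new parameter).
    session = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=300))

    co = cohere.AsyncClientV2(api_key="YOUR_API_KEY", aiohttp_session=session)
    try:
        resp = await co.chat(
            model="command-r-plus",  # illustrative model name
            messages=[{"role": "user", "content": "Say hello in one sentence."}],
        )
        print(resp.message.content[0].text)
    finally:
        await session.close()


asyncio.run(main())
```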

---

The Cohere Python SDK allows access to Cohere models across many different platforms: the Cohere platform, AWS (Bedrock, SageMaker), Azure, GCP, and Oracle OCI. For a full list of supported platforms and snippets, please take a look at the [SDK support docs page](https://docs.cohere.com/docs/cohere-works-everywhere).

## Documentation
1 change: 1 addition & 0 deletions pyproject.toml
@@ -39,6 +39,7 @@ Repository = 'https://github.com/cohere-ai/cohere-python'

[tool.poetry.dependencies]
python = "^3.9"
aiohttp = "^3.9.0"
fastavro = "^1.9.4"
httpx = ">=0.21.2"
pydantic = ">= 1.9.2"
38 changes: 26 additions & 12 deletions src/cohere/base_client.py
@@ -5,6 +5,7 @@
import os
import typing

import aiohttp
import httpx
from .core.api_error import ApiError
from .core.client_wrapper import AsyncClientWrapper, SyncClientWrapper
@@ -1592,7 +1593,8 @@ class AsyncBaseCohere:
The timeout to be used, in seconds, for requests. By default the timeout is 300 seconds, unless a custom httpx client is used, in which case this default is not enforced.

follow_redirects : typing.Optional[bool]
Whether the default httpx client follows redirects or not, this is irrelevant if a custom httpx client is passed in.
Whether the async client follows HTTP redirects. Defaults to True. Passed as allow_redirects
on each request; does NOT affect TCP connection reuse (connections are always pooled).

httpx_client : typing.Optional[httpx.AsyncClient]
The httpx client to use for making requests, a preconfigured client is used by default, however this is useful should you want to pass in any custom httpx configuration.
@@ -1617,24 +1619,32 @@ def __init__(
headers: typing.Optional[typing.Dict[str, str]] = None,
timeout: typing.Optional[float] = None,
follow_redirects: typing.Optional[bool] = True,
httpx_client: typing.Optional[httpx.AsyncClient] = None,
aiohttp_session: typing.Optional[aiohttp.ClientSession] = None,
httpx_client: typing.Optional[httpx.AsyncClient] = None, # Deprecated, kept for compatibility
):
_defaulted_timeout = (
timeout if timeout is not None else 300 if httpx_client is None else httpx_client.timeout.read
)
_defaulted_timeout = timeout if timeout is not None else 300
if token is None:
raise ApiError(body="The client must be instantiated be either passing in token or setting CO_API_KEY")

# Create aiohttp session if not provided.
# NOTE: force_close is intentionally NOT derived from follow_redirects.
# force_close controls TCP connection reuse (keep-alive pooling); setting it True
# causes every request to open and close a fresh TCP socket, exhausting the ephemeral
# port range (TIME_WAIT) when making thousands of concurrent calls.
# Redirect behaviour is handled per-request via allow_redirects instead.
if aiohttp_session is None:
timeout_config = aiohttp.ClientTimeout(total=_defaulted_timeout)
connector = aiohttp.TCPConnector()
aiohttp_session = aiohttp.ClientSession(timeout=timeout_config, connector=connector)

self._client_wrapper = AsyncClientWrapper(
base_url=_get_base_url(base_url=base_url, environment=environment),
client_name=client_name,
token=token,
headers=headers,
httpx_client=httpx_client
if httpx_client is not None
else httpx.AsyncClient(timeout=_defaulted_timeout, follow_redirects=follow_redirects)
if follow_redirects is not None
else httpx.AsyncClient(timeout=_defaulted_timeout),
aiohttp_session=aiohttp_session,
timeout=_defaulted_timeout,
follow_redirects=follow_redirects if follow_redirects is not None else True,
)
self._raw_client = AsyncRawBaseCohere(client_wrapper=self._client_wrapper)
self._v2: typing.Optional[AsyncV2Client] = None
@@ -1964,7 +1974,9 @@ async def main() -> None:
request_options=request_options,
) as r:
async for _chunk in r.data:
yield _chunk
# Skip None chunks (e.g., from [DONE] markers in SSE streams)
if _chunk is not None:
yield _chunk

async def chat(
self,
@@ -2427,7 +2439,9 @@ async def main() -> None:
request_options=request_options,
) as r:
async for _chunk in r.data:
yield _chunk
# Skip None chunks (e.g., from [DONE] markers in SSE streams)
if _chunk is not None:
yield _chunk

async def generate(
self,
7 changes: 5 additions & 2 deletions src/cohere/client.py
@@ -5,6 +5,7 @@
from tokenizers import Tokenizer # type: ignore
import logging

import aiohttp
import httpx

from cohere.types.detokenize_response import DetokenizeResponse
@@ -331,7 +332,8 @@ def __init__(
environment: ClientEnvironment = ClientEnvironment.PRODUCTION,
client_name: typing.Optional[str] = None,
timeout: typing.Optional[float] = None,
httpx_client: typing.Optional[httpx.AsyncClient] = None,
aiohttp_session: typing.Optional["aiohttp.ClientSession"] = None,
httpx_client: typing.Optional[httpx.AsyncClient] = None, # Deprecated
thread_pool_executor: ThreadPoolExecutor = ThreadPoolExecutor(64),
log_warning_experimental_features: bool = True,
):
@@ -349,6 +351,7 @@ def __init__(
client_name=client_name,
token=api_key,
timeout=timeout,
aiohttp_session=aiohttp_session,
httpx_client=httpx_client,
)

@@ -365,7 +368,7 @@ async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc_value, traceback):
await self._client_wrapper.httpx_client.httpx_client.aclose()
await self._client_wrapper.httpx_client.aiohttp_session.close()

wait = async_wait

5 changes: 4 additions & 1 deletion src/cohere/client_v2.py
@@ -2,6 +2,7 @@
import typing
from concurrent.futures import ThreadPoolExecutor

import aiohttp
import httpx
from .client import AsyncClient, Client
from .environment import ClientEnvironment
@@ -71,7 +72,8 @@ def __init__(
environment: ClientEnvironment = ClientEnvironment.PRODUCTION,
client_name: typing.Optional[str] = None,
timeout: typing.Optional[float] = None,
httpx_client: typing.Optional[httpx.AsyncClient] = None,
aiohttp_session: typing.Optional["aiohttp.ClientSession"] = None,
httpx_client: typing.Optional[httpx.AsyncClient] = None, # Deprecated
thread_pool_executor: ThreadPoolExecutor = ThreadPoolExecutor(64),
log_warning_experimental_features: bool = True,
):
@@ -82,6 +84,7 @@ def __init__(
environment=environment,
client_name=client_name,
timeout=timeout,
aiohttp_session=aiohttp_session,
httpx_client=httpx_client,
thread_pool_executor=thread_pool_executor,
log_warning_experimental_features=log_warning_experimental_features,
7 changes: 5 additions & 2 deletions src/cohere/core/client_wrapper.py
@@ -2,6 +2,7 @@

import typing

import aiohttp
import httpx
from .http_client import AsyncHttpClient, HttpClient

@@ -85,16 +86,18 @@ def __init__(
base_url: str,
timeout: typing.Optional[float] = None,
async_token: typing.Optional[typing.Callable[[], typing.Awaitable[str]]] = None,
httpx_client: httpx.AsyncClient,
aiohttp_session: aiohttp.ClientSession,
follow_redirects: bool = True,
):
super().__init__(client_name=client_name, token=token, headers=headers, base_url=base_url, timeout=timeout)
self._async_token = async_token
self.httpx_client = AsyncHttpClient(
httpx_client=httpx_client,
aiohttp_session=aiohttp_session,
base_headers=self.get_headers,
base_timeout=self.get_timeout,
base_url=self.get_base_url,
async_base_headers=self.async_get_headers,
follow_redirects=follow_redirects,
)

async def async_get_headers(self) -> typing.Dict[str, str]:
57 changes: 56 additions & 1 deletion src/cohere/core/file.py
@@ -1,6 +1,12 @@
# This file was auto-generated by Fern from our API Definition.

from typing import IO, Dict, List, Mapping, Optional, Tuple, Union, cast
from typing import IO, Any, Dict, List, Mapping, Optional, Tuple, Union, cast

try:
    import aiohttp
    HAS_AIOHTTP = True
except ImportError:
    HAS_AIOHTTP = False

# File typing inspired by the flexibility of types within the httpx library
# https://github.com/encode/httpx/blob/master/httpx/_types.py
@@ -65,3 +71,52 @@ def with_content_type(*, file: File, default_content_type: str) -> File:
        else:
            raise ValueError(f"Unexpected tuple length: {len(file)}")
    return (None, file, default_content_type)


def build_aiohttp_form_data(
    files: Dict[str, Union[File, List[File]]],
    data: Optional[Any] = None,
) -> "aiohttp.FormData":
    """
    Convert file dict to aiohttp FormData format.
    Similar to convert_file_dict_to_httpx_tuples but for aiohttp.
    """
    if not HAS_AIOHTTP:
        raise ImportError("aiohttp is required for async file uploads")

    form = aiohttp.FormData()

    # Add regular data fields first
    if data is not None and isinstance(data, dict):
        for key, value in data.items():
            if value is not None:
                form.add_field(key, str(value))

    # Add file fields
    for key, file_like in files.items():
        if isinstance(file_like, list):
            for file_item in file_like:
                _add_file_to_form(form, key, file_item)
        else:
            _add_file_to_form(form, key, file_like)

    return form


def _add_file_to_form(form: "aiohttp.FormData", name: str, file: File) -> None:
    """Helper to add a single file to aiohttp FormData"""
    if isinstance(file, tuple):
        if len(file) == 2:
            filename, content = file
            form.add_field(name, content, filename=filename)
        elif len(file) == 3:
            filename, content, content_type = file
            form.add_field(name, content, filename=filename, content_type=content_type)
        elif len(file) == 4:
            filename, content, content_type, headers = file
            # aiohttp FormData doesn't support custom headers per field easily
            # Use content_type and filename
            form.add_field(name, content, filename=filename, content_type=content_type)
    else:
        # Simple file content
        form.add_field(name, file)
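
# Illustrative usage (not part of the generated file): the FormData built above is what the
# aiohttp-backed AsyncHttpClient submits for multipart uploads, roughly (`session` and `url`
# are placeholders):
#
#     form = build_aiohttp_form_data({"file": ("data.csv", b"col_a,col_b\n1,2\n", "text/csv")})
#     async with session.post(url, data=form) as resp:
#         body = await resp.json()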