Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/cohere/aws_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ def _event_hook(request: httpx.Request) -> None:
)
request.url = URL(url)
request.headers["host"] = request.url.host
headers["host"] = request.url.host

if endpoint == "rerank":
body["api_version"] = get_api_version(version=api_version)
Expand Down
41 changes: 33 additions & 8 deletions src/cohere/manually_maintained/cohere_aws/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,29 @@ class Client:
def __init__(
self,
aws_region: typing.Optional[str] = None,
mode: Mode = Mode.SAGEMAKER,
):
"""
By default we assume region configured in AWS CLI (`aws configure get region`). You can change the region with
`aws configure set region us-west-2` or override it with `region_name` parameter.
"""
self._client = lazy_boto3().client("sagemaker-runtime", region_name=aws_region)
self._service_client = lazy_boto3().client("sagemaker", region_name=aws_region)
self.mode = mode
if os.environ.get('AWS_DEFAULT_REGION') is None:
os.environ['AWS_DEFAULT_REGION'] = aws_region
self._sess = lazy_sagemaker().Session(sagemaker_client=self._service_client)
self.mode = Mode.SAGEMAKER

if self.mode == Mode.SAGEMAKER:
self._client = lazy_boto3().client("sagemaker-runtime", region_name=aws_region)
self._service_client = lazy_boto3().client("sagemaker", region_name=aws_region)
self._sess = lazy_sagemaker().Session(sagemaker_client=self._service_client)
elif self.mode == Mode.BEDROCK:
self._client = lazy_boto3().client("bedrock-runtime", region_name=aws_region)
self._service_client = lazy_boto3().client("bedrock", region_name=aws_region)
self._sess = None
self._endpoint_name = None

def _require_sagemaker(self) -> None:
if self.mode != Mode.SAGEMAKER:
raise CohereError("This method is only supported in SageMaker mode.")

def _does_endpoint_exist(self, endpoint_name: str) -> bool:
try:
Expand All @@ -50,6 +60,7 @@ def connect_to_endpoint(self, endpoint_name: str) -> None:
Raises:
CohereError: Connection to the endpoint failed.
"""
self._require_sagemaker()
if not self._does_endpoint_exist(endpoint_name):
raise CohereError(f"Endpoint {endpoint_name} does not exist.")
self._endpoint_name = endpoint_name
Expand Down Expand Up @@ -137,6 +148,7 @@ def create_endpoint(
will be used to get the role. This should work when one uses the client inside SageMaker. If this errors
out, the default role "ServiceRoleSagemaker" will be used, which generally works outside of SageMaker.
"""
self._require_sagemaker()
# First, check if endpoint already exists
if self._does_endpoint_exist(endpoint_name):
if recreate:
Expand Down Expand Up @@ -550,11 +562,15 @@ def embed(
variant: Optional[str] = None,
input_type: Optional[str] = None,
model_id: Optional[str] = None,
) -> Embeddings:
output_dimension: Optional[int] = None,
embedding_types: Optional[List[str]] = None,
) -> Union[Embeddings, Dict[str, List]]:
json_params = {
'texts': texts,
'truncate': truncate,
"input_type": input_type
"input_type": input_type,
"output_dimension": output_dimension,
"embedding_types": embedding_types,
}
for key, value in list(json_params.items()):
if value is None:
Expand Down Expand Up @@ -591,7 +607,10 @@ def _sagemaker_embed(self, json_params: Dict[str, Any], variant: str):
# ValidationError, e.g. when variant is bad
raise CohereError(str(e))

return Embeddings(response['embeddings'])
embeddings = response['embeddings']
if isinstance(embeddings, dict):
return embeddings
return Embeddings(embeddings)

def _bedrock_embed(self, json_params: Dict[str, Any], model_id: str):
if not model_id:
Expand All @@ -612,7 +631,10 @@ def _bedrock_embed(self, json_params: Dict[str, Any], model_id: str):
# ValidationError, e.g. when variant is bad
raise CohereError(str(e))

return Embeddings(response['embeddings'])
embeddings = response['embeddings']
if isinstance(embeddings, dict):
return embeddings
return Embeddings(embeddings)


def rerank(self,
Expand Down Expand Up @@ -805,6 +827,7 @@ def export_finetune(
This should work when one uses the client inside SageMaker. If this errors out,
the default role "ServiceRoleSagemaker" will be used, which generally works outside SageMaker.
"""
self._require_sagemaker()
if name == "model":
raise ValueError("name cannot be 'model'")

Expand Down Expand Up @@ -948,6 +971,7 @@ def summarize(
additional_command: Optional[str] = "",
variant: Optional[str] = None
) -> Summary:
self._require_sagemaker()

if self._endpoint_name is None:
raise CohereError("No endpoint connected. "
Expand Down Expand Up @@ -989,6 +1013,7 @@ def summarize(


def delete_endpoint(self) -> None:
self._require_sagemaker()
if self._endpoint_name is None:
raise CohereError("No endpoint connected.")
try:
Expand Down
252 changes: 252 additions & 0 deletions tests/test_aws_client_unit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,252 @@
"""
Unit tests (mocked, no AWS credentials needed) for AWS client fixes.

Covers:
- Fix 1: SigV4 signing uses the correct host header after URL rewrite
- Fix 2: cohere_aws.Client conditionally initializes based on mode
- Fix 3: embed() accepts and passes output_dimension and embedding_types
"""

import inspect
import json
import os
import unittest
from unittest.mock import MagicMock, patch

import httpx

from cohere.manually_maintained.cohere_aws.mode import Mode


class TestSigV4HostHeader(unittest.TestCase):
"""Fix 1: The headers dict passed to AWSRequest for SigV4 signing must
contain the rewritten Bedrock/SageMaker host, not the stale api.cohere.com."""

def test_sigv4_signs_with_correct_host(self) -> None:
captured_aws_request_kwargs: dict = {}

mock_aws_request_cls = MagicMock()

def capture_aws_request(**kwargs): # type: ignore
captured_aws_request_kwargs.update(kwargs)
mock_req = MagicMock()
mock_req.prepare.return_value = MagicMock(
headers={"host": "bedrock-runtime.us-east-1.amazonaws.com"}
)
return mock_req

mock_aws_request_cls.side_effect = capture_aws_request

mock_botocore = MagicMock()
mock_botocore.awsrequest.AWSRequest = mock_aws_request_cls
mock_botocore.auth.SigV4Auth.return_value = MagicMock()

mock_boto3 = MagicMock()
mock_session = MagicMock()
mock_session.region_name = "us-east-1"
mock_session.get_credentials.return_value = MagicMock()
mock_boto3.Session.return_value = mock_session

with patch("cohere.aws_client.lazy_botocore", return_value=mock_botocore), \
patch("cohere.aws_client.lazy_boto3", return_value=mock_boto3):

from cohere.aws_client import map_request_to_bedrock

hook = map_request_to_bedrock(service="bedrock", aws_region="us-east-1")

request = httpx.Request(
method="POST",
url="https://api.cohere.com/v1/chat",
headers={"connection": "keep-alive"},
json={"model": "cohere.command-r-plus-v1:0", "message": "hello"},
)

self.assertEqual(request.url.host, "api.cohere.com")

hook(request)

self.assertIn("bedrock-runtime.us-east-1.amazonaws.com", str(request.url))

signed_headers = captured_aws_request_kwargs["headers"]
self.assertEqual(
signed_headers["host"],
"bedrock-runtime.us-east-1.amazonaws.com",
)


class TestModeConditionalInit(unittest.TestCase):
"""Fix 2: cohere_aws.Client should initialize different boto3 clients
depending on mode, and default to SAGEMAKER for backwards compat."""

def test_sagemaker_mode_creates_sagemaker_clients(self) -> None:
mock_boto3 = MagicMock()
mock_sagemaker = MagicMock()

with patch("cohere.manually_maintained.cohere_aws.client.lazy_boto3", return_value=mock_boto3), \
patch("cohere.manually_maintained.cohere_aws.client.lazy_sagemaker", return_value=mock_sagemaker), \
patch.dict(os.environ, {"AWS_DEFAULT_REGION": "us-east-1"}):

from cohere.manually_maintained.cohere_aws.client import Client

client = Client(aws_region="us-east-1")

self.assertEqual(client.mode, Mode.SAGEMAKER)

service_names = [c[0][0] for c in mock_boto3.client.call_args_list]
self.assertIn("sagemaker-runtime", service_names)
self.assertIn("sagemaker", service_names)
self.assertNotIn("bedrock-runtime", service_names)
self.assertNotIn("bedrock", service_names)

mock_sagemaker.Session.assert_called_once()

def test_bedrock_mode_creates_bedrock_clients(self) -> None:
mock_boto3 = MagicMock()
mock_sagemaker = MagicMock()

with patch("cohere.manually_maintained.cohere_aws.client.lazy_boto3", return_value=mock_boto3), \
patch("cohere.manually_maintained.cohere_aws.client.lazy_sagemaker", return_value=mock_sagemaker), \
patch.dict(os.environ, {"AWS_DEFAULT_REGION": "us-west-2"}):

from cohere.manually_maintained.cohere_aws.client import Client

client = Client(aws_region="us-west-2", mode=Mode.BEDROCK)

self.assertEqual(client.mode, Mode.BEDROCK)

service_names = [c[0][0] for c in mock_boto3.client.call_args_list]
self.assertIn("bedrock-runtime", service_names)
self.assertIn("bedrock", service_names)
self.assertNotIn("sagemaker-runtime", service_names)
self.assertNotIn("sagemaker", service_names)

mock_sagemaker.Session.assert_not_called()

def test_default_mode_is_sagemaker(self) -> None:
from cohere.manually_maintained.cohere_aws.client import Client

sig = inspect.signature(Client.__init__)
self.assertEqual(sig.parameters["mode"].default, Mode.SAGEMAKER)


class TestEmbedV4Params(unittest.TestCase):
"""Fix 3: embed() should accept output_dimension and embedding_types,
pass them through to the request body, and strip them when None."""

@staticmethod
def _make_bedrock_client(): # type: ignore
mock_boto3 = MagicMock()
mock_botocore = MagicMock()
captured_body: dict = {}

def fake_invoke_model(**kwargs): # type: ignore
captured_body.update(json.loads(kwargs["body"]))
mock_body = MagicMock()
mock_body.read.return_value = json.dumps({"embeddings": [[0.1, 0.2]]}).encode()
return {"body": mock_body}

mock_bedrock_client = MagicMock()
mock_bedrock_client.invoke_model.side_effect = fake_invoke_model

def fake_boto3_client(service_name, **kwargs): # type: ignore
if service_name == "bedrock-runtime":
return mock_bedrock_client
return MagicMock()

mock_boto3.client.side_effect = fake_boto3_client
return mock_boto3, mock_botocore, captured_body

def test_embed_accepts_new_params(self) -> None:
from cohere.manually_maintained.cohere_aws.client import Client

sig = inspect.signature(Client.embed)
self.assertIn("output_dimension", sig.parameters)
self.assertIn("embedding_types", sig.parameters)
self.assertIsNone(sig.parameters["output_dimension"].default)
self.assertIsNone(sig.parameters["embedding_types"].default)

def test_embed_passes_params_to_bedrock(self) -> None:
mock_boto3, mock_botocore, captured_body = self._make_bedrock_client()

with patch("cohere.manually_maintained.cohere_aws.client.lazy_boto3", return_value=mock_boto3), \
patch("cohere.manually_maintained.cohere_aws.client.lazy_botocore", return_value=mock_botocore), \
patch("cohere.manually_maintained.cohere_aws.client.lazy_sagemaker", return_value=MagicMock()), \
patch.dict(os.environ, {"AWS_DEFAULT_REGION": "us-east-1"}):

from cohere.manually_maintained.cohere_aws.client import Client

client = Client(aws_region="us-east-1", mode=Mode.BEDROCK)
client.embed(
texts=["hello world"],
input_type="search_document",
model_id="cohere.embed-english-v3",
output_dimension=256,
embedding_types=["float", "int8"],
)

self.assertEqual(captured_body["output_dimension"], 256)
self.assertEqual(captured_body["embedding_types"], ["float", "int8"])

def test_embed_omits_none_params(self) -> None:
mock_boto3, mock_botocore, captured_body = self._make_bedrock_client()

with patch("cohere.manually_maintained.cohere_aws.client.lazy_boto3", return_value=mock_boto3), \
patch("cohere.manually_maintained.cohere_aws.client.lazy_botocore", return_value=mock_botocore), \
patch("cohere.manually_maintained.cohere_aws.client.lazy_sagemaker", return_value=MagicMock()), \
patch.dict(os.environ, {"AWS_DEFAULT_REGION": "us-east-1"}):

from cohere.manually_maintained.cohere_aws.client import Client

client = Client(aws_region="us-east-1", mode=Mode.BEDROCK)
client.embed(
texts=["hello world"],
input_type="search_document",
model_id="cohere.embed-english-v3",
)

self.assertNotIn("output_dimension", captured_body)
self.assertNotIn("embedding_types", captured_body)

def test_embed_with_embedding_types_returns_dict(self) -> None:
"""When embedding_types is specified, the API returns embeddings as a dict.
The client should return that dict rather than wrapping it in Embeddings."""
mock_boto3 = MagicMock()
mock_botocore = MagicMock()

by_type_embeddings = {"float": [[0.1, 0.2]], "int8": [[1, 2]]}

def fake_invoke_model(**kwargs): # type: ignore
mock_body = MagicMock()
mock_body.read.return_value = json.dumps({
"embeddings": by_type_embeddings,
"response_type": "embeddings_by_type",
}).encode()
return {"body": mock_body}

mock_bedrock_client = MagicMock()
mock_bedrock_client.invoke_model.side_effect = fake_invoke_model

def fake_boto3_client(service_name, **kwargs): # type: ignore
if service_name == "bedrock-runtime":
return mock_bedrock_client
return MagicMock()

mock_boto3.client.side_effect = fake_boto3_client

with patch("cohere.manually_maintained.cohere_aws.client.lazy_boto3", return_value=mock_boto3), \
patch("cohere.manually_maintained.cohere_aws.client.lazy_botocore", return_value=mock_botocore), \
patch("cohere.manually_maintained.cohere_aws.client.lazy_sagemaker", return_value=MagicMock()), \
patch.dict(os.environ, {"AWS_DEFAULT_REGION": "us-east-1"}):

from cohere.manually_maintained.cohere_aws.client import Client

client = Client(aws_region="us-east-1", mode=Mode.BEDROCK)
result = client.embed(
texts=["hello world"],
input_type="search_document",
model_id="cohere.embed-english-v3",
embedding_types=["float", "int8"],
)

self.assertIsInstance(result, dict)
self.assertEqual(result, by_type_embeddings)
Loading