Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions sdk/cosmos/azure-cosmos/pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@ markers =
cosmosEmulator: marks tests as depending in Cosmos DB Emulator.
cosmosLong: marks tests to be run on a Cosmos DB live account.
cosmosQuery: marks tests running queries on Cosmos DB live account.
cosmosAADLong: marks AAD tests for the standard live-account lane.
cosmosAADSplit: marks AAD tests for partition split scenarios.
cosmosAADMultiRegion: marks AAD tests for multi-region scenarios.
cosmosAADCircuitBreaker: marks AAD tests for circuit-breaker scenarios.
cosmosAADQuery: marks AAD tests for query-focused scenarios.
cosmosAADPerPartitionAutomaticFailover: marks AAD tests for per-partition automatic failover scenarios.
cosmosSplit: marks test where there are partition splits on CosmosDB live account.
cosmosMultiRegion: marks tests running on a Cosmos DB live account with multi-region and multi-write enabled.
cosmosCircuitBreaker: marks tests running on Cosmos DB live account with per partition circuit breaker enabled and multi-write enabled.
Expand Down
26 changes: 15 additions & 11 deletions sdk/cosmos/azure-cosmos/tests/test_aad.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@

import azure.cosmos.cosmos_client as cosmos_client
import test_config
from azure.cosmos import DatabaseProxy, ContainerProxy, exceptions
from azure.cosmos import DatabaseProxy, ContainerProxy
from azure.core.exceptions import HttpResponseError



def _remove_padding(encoded_string):
while encoded_string.endswith("="):
encoded_string = encoded_string[0:len(encoded_string) - 1]
Expand All @@ -35,7 +37,7 @@ def get_test_item(num):

class CosmosEmulatorCredential(object):
def get_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
# type: (*str, **object) -> AccessToken
"""Request an access token for the emulator. Based on Azure Core's Access Token Credential.

This method is called automatically by Azure SDK clients.
Expand Down Expand Up @@ -93,10 +95,16 @@ class TestAAD(unittest.TestCase):
configs = test_config.TestConfig
host = configs.host
masterKey = configs.masterKey
credential = CosmosEmulatorCredential() if configs.is_emulator else configs.credential
# Emulator-only credential used by this class.
credential = CosmosEmulatorCredential()
_skip_scope_tests_on_non_emulator = pytest.mark.skipif(
not configs.is_emulator,
reason="Scope capture tests are emulator-specific (localhost audience)."
)

@classmethod
def setUpClass(cls):
# Emulator-only path: always use the emulator credential.
cls.client = cosmos_client.CosmosClient(cls.host, cls.credential)
cls.database = cls.client.get_database_client(cls.configs.TEST_DATABASE_ID)
cls.container = cls.database.get_container_client(cls.configs.TEST_SINGLE_PARTITION_CONTAINER_ID)
Expand All @@ -110,14 +118,6 @@ def test_aad_credentials(self):
print("Query result: " + str(query_results[0]))
self.container.delete_item(item='Item_0', partition_key='pk')

# Attempting to do management operations will return a 403 Forbidden exception
try:
self.client.delete_database(self.configs.TEST_DATABASE_ID)
except exceptions.CosmosHttpResponseError as e:
assert e.status_code == 403
print("403 error assertion success")


def _run_with_scope_capture(self, credential_cls, action, *args, **kwargs):
scopes_captured = []
original_get_token = credential_cls.get_token
Expand All @@ -133,6 +133,7 @@ def capturing_get_token(self, *scopes, **kwargs):
credential_cls.get_token = original_get_token
return scopes_captured, result

@_skip_scope_tests_on_non_emulator
def test_override_scope_no_fallback(self):
"""When override scope is provided, only that scope is used and no fallback occurs."""
override_scope = "https://my.custom.scope/.default"
Expand All @@ -156,6 +157,7 @@ def action(scopes_captured):
except Exception:
pass

@_skip_scope_tests_on_non_emulator
def test_override_scope_auth_error_no_fallback(self):
"""When override scope is provided and auth fails, no fallback to other scopes occurs."""
override_scope = "https://my.custom.scope/.default"
Expand All @@ -180,6 +182,7 @@ def action(scopes_captured):
finally:
del os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"]

@_skip_scope_tests_on_non_emulator
def test_account_scope_only(self):
"""When account scope is provided, only that scope is used."""
account_scope = "https://localhost/.default"
Expand All @@ -203,6 +206,7 @@ def action(scopes_captured):
except Exception:
pass

@_skip_scope_tests_on_non_emulator
def test_account_scope_fallback_on_error(self):
"""When account scope is provided and auth fails, fallback to default scope occurs."""
account_scope = "https://localhost/.default"
Expand Down
32 changes: 12 additions & 20 deletions sdk/cosmos/azure-cosmos/tests/test_aad_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
from azure.core.credentials import AccessToken

import test_config
from azure.cosmos import exceptions
from azure.cosmos.aio import CosmosClient, DatabaseProxy, ContainerProxy
from azure.core.exceptions import HttpResponseError



def _remove_padding(encoded_string):
while encoded_string.endswith("="):
encoded_string = encoded_string[0:len(encoded_string) - 1]
Expand All @@ -35,7 +36,7 @@ def get_test_item(num):

class CosmosEmulatorCredential(object):
async def get_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
# type: (*str, **object) -> AccessToken
"""Request an access token for the emulator. Based on Azure Core's Access Token Credential.

This method is called automatically by Azure SDK clients.
Expand Down Expand Up @@ -93,16 +94,11 @@ class TestAADAsync(unittest.IsolatedAsyncioTestCase):
configs = test_config.TestConfig
host = configs.host
masterKey = configs.masterKey
credential = CosmosEmulatorCredential() if configs.is_emulator else configs.credential_async

@classmethod
def setUpClass(cls):
if (cls.credential == '[YOUR_KEY_HERE]' or
cls.host == '[YOUR_ENDPOINT_HERE]'):
raise Exception(
"You must specify your Azure Cosmos account values for "
"'masterKey' and 'host' at the top of this class to run the "
"tests.")
credential = CosmosEmulatorCredential()
_skip_scope_tests_on_non_emulator = pytest.mark.skipif(
not configs.is_emulator,
reason="Scope capture tests are emulator-specific (localhost audience)."
)

async def asyncSetUp(self):
self.client = CosmosClient(self.host, self.credential)
Expand All @@ -113,8 +109,6 @@ async def asyncTearDown(self):
await self.client.close()

async def test_aad_credentials_async(self):
# Do any R/W data operations with your authorized AAD client

print("Container info: " + str(await self.container.read()))
await self.container.create_item(get_test_item(0))
print("Point read result: " + str(await self.container.read_item(item='Item_0', partition_key='pk')))
Expand All @@ -123,12 +117,6 @@ async def test_aad_credentials_async(self):
print("Query result: " + str(query_results[0]))
await self.container.delete_item(item='Item_0', partition_key='pk')

# Attempting to do management operations will return a 403 Forbidden exception
try:
await self.client.delete_database(self.configs.TEST_DATABASE_ID)
except exceptions.CosmosHttpResponseError as e:
assert e.status_code == 403
print("403 error assertion success")

async def _run_with_scope_capture_async(self, credential_cls, action):
scopes_captured = []
Expand All @@ -146,6 +134,7 @@ async def capturing_get_token(self, *scopes, **kwargs):
finally:
credential_cls.get_token = orig_get_token

@_skip_scope_tests_on_non_emulator
async def test_override_scope_no_fallback_async(self):
"""When override scope is provided, only that scope is used and no fallback occurs."""
override_scope = "https://my.custom.scope/.default"
Expand All @@ -172,6 +161,7 @@ async def action(scopes_captured):
except Exception:
pass

@_skip_scope_tests_on_non_emulator
async def test_override_scope_no_fallback_on_error_async(self):
"""When override scope is provided and auth fails, no fallback occurs."""
override_scope = "https://my.custom.scope/.default"
Expand Down Expand Up @@ -205,6 +195,7 @@ async def action(scopes_captured):
except Exception:
pass

@_skip_scope_tests_on_non_emulator
async def test_account_scope_only_async(self):
"""When account scope is provided, only that scope is used."""
account_scope = "https://localhost/.default"
Expand All @@ -230,6 +221,7 @@ async def action(scopes_captured):
except Exception:
pass

@_skip_scope_tests_on_non_emulator
async def test_account_scope_fallback_on_error_async(self):
"""When account scope is provided and auth fails, fallback to default scope occurs."""
account_scope = "https://localhost/.default"
Expand Down
81 changes: 75 additions & 6 deletions sdk/cosmos/azure-cosmos/tests/test_aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,47 +14,57 @@
from azure.cosmos.partition_key import PartitionKey


print("[AAD DEBUG] imported test_aggregate.py", flush=True)


class _config:
is_aad_mode = test_config.TestConfig.data_auth_mode == "aad"
host = test_config.TestConfig.host
master_key = test_config.TestConfig.masterKey
connection_policy = test_config.TestConfig.connectionPolicy
PARTITION_KEY = 'key'
UNIQUE_PARTITION_KEY = 'uniquePartitionKey'
FIELD = 'field'
DOCUMENTS_COUNT = 400
DOCS_WITH_SAME_PARTITION_KEY = 200
# Keep key-auth query coverage unchanged; trim only AAD runs to stay under CI timeout.
DOCUMENTS_COUNT = 120 if is_aad_mode else 400
DOCS_WITH_SAME_PARTITION_KEY = 60 if is_aad_mode else 200
docs_with_numeric_id = 0
sum = 0


@pytest.mark.cosmosQuery
class TestAggregateQuery(unittest.TestCase):
client: cosmos_client.CosmosClient = None
key_client: cosmos_client.CosmosClient = None

@classmethod
def setUpClass(cls):
print("[AAD DEBUG] TestAggregateQuery.setUpClass start", flush=True)
cls._all_tests = []
cls._setup()
cls._generate_test_configs()
print("[AAD DEBUG] TestAggregateQuery.setUpClass end", flush=True)

@classmethod
def tearDownClass(cls) -> None:
try:
cls.created_db.delete_container(cls.created_collection.id)
cls.key_db.delete_container(cls.created_collection.id)
except CosmosHttpResponseError:
pass

@classmethod
def _setup(cls):
print("[AAD DEBUG] TestAggregateQuery._setup start", flush=True)
if not _config.master_key or not _config.host:
raise Exception(
"You must specify your Azure Cosmos account values for "
"'masterKey' and 'host' at the top of this class to run the "
"tests.")

cls.client = cosmos_client.CosmosClient(_config.host, _config.master_key)
cls.created_db = cls.client.get_database_client(test_config.TestConfig.TEST_DATABASE_ID)
cls.created_collection = cls._create_collection(cls.created_db)
cls.key_client, cls.key_db, cls.client, cls.created_db = (
test_config.TestConfig.create_test_clients(test_config.TestConfig.TEST_DATABASE_ID))
created_collection_ref = cls._create_collection(cls.key_db)
cls.created_collection = cls.created_db.get_container_client(created_collection_ref.id)

# test documents
document_definitions = []
Expand All @@ -81,6 +91,7 @@ def _setup(cls):
* (_config.docs_with_numeric_id + 1) / 2.0

cls._insert_doc(cls.created_collection, document_definitions)
print("[AAD DEBUG] TestAggregateQuery._setup end", flush=True)

@classmethod
def _generate_test_configs(cls):
Expand Down Expand Up @@ -129,6 +140,7 @@ def _generate_test_configs(cls):
Exception()])

def test_run_all(self):
print("[AAD DEBUG] TestAggregateQuery.test_run_all start", flush=True)
for test_name, query, expected_result in self._all_tests:
test_name = "test_%s" % test_name
try:
Expand All @@ -138,6 +150,63 @@ def test_run_all(self):
print(test_name + ': ' + query + " FAILED")
raise e

# AAD-only smoke subset.
#
# Why this exists: the CI AAD lane runs on Linux and the shared
# ``azpysdk.main whl --isolate`` bootstrap on that pool already eats
# ~90 minutes of the 120-minute job ceiling. Running the full
# ``test_run_all`` matrix (24 aggregate variants) under AAD on top of
# that bootstrap pushes the lane over the ceiling. The full matrix
# still runs under the ``cosmosQuery`` lane (key auth) -- this method
# is *additional* AAD-only coverage focused on Contoso's exact bug
# shape: cross-partition aggregate query under bearer auth, including
# the ORDER BY pagination case where token refresh mid-stream is most
# likely to surface.
#
# Three queries: cross-partition COUNT (fan-out), cross-partition SUM
# with ORDER BY (fan-out + paginated reduce -> token-refresh window),
# single-partition AVG (pinned-PK path).
@pytest.mark.cosmosAADLong
@pytest.mark.skipif(
test_config.TestConfig.data_auth_mode != "aad",
reason="AAD-only smoke subset; full coverage runs under cosmosQuery (key auth).",
)
def test_aad_aggregate_subset(self):
print("[AAD DEBUG] TestAggregateQuery.test_aad_aggregate_subset start", flush=True)
same_partition_avg = (
_config.DOCS_WITH_SAME_PARTITION_KEY * (_config.DOCS_WITH_SAME_PARTITION_KEY + 1) / 2.0
) / _config.DOCS_WITH_SAME_PARTITION_KEY
subset = [
(
"test_aad_xp_count",
"SELECT VALUE COUNT(r.{}) FROM r WHERE true".format(_config.PARTITION_KEY),
_config.DOCUMENTS_COUNT,
),
(
"test_aad_xp_sum_orderby",
"SELECT VALUE SUM(r.{f}) FROM r WHERE IS_NUMBER(r.{pk}) ORDER BY r.{pk}".format(
f=_config.PARTITION_KEY, pk=_config.PARTITION_KEY
),
_config.sum,
),
(
"test_aad_sp_avg",
"SELECT VALUE AVG(r.{f}) FROM r WHERE r.{pk} = '{val}'".format(
f=_config.FIELD,
pk=_config.PARTITION_KEY,
val=_config.UNIQUE_PARTITION_KEY,
),
same_partition_avg,
),
]
for test_name, query, expected in subset:
try:
self._run_one(query, expected)
print(test_name + ': ' + query + " PASSED", flush=True)
except Exception as e:
print(test_name + ': ' + query + " FAILED", flush=True)
raise e

def _run_one(self, query, expected_result):
self._execute_query_and_validate_results(self.created_collection, query, expected_result)

Expand Down
Loading
Loading