Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,30 @@ public void serialize_Deserialize() throws Exception {
.isEqualTo(altLink);
assertThat(deserializedDocumentCollection.getWrappedItem().toJson()).isEqualTo(collection.toJson());
}

@Test(groups = { "unit" })
public void deserializeWithInvalidClassType_shouldFail() throws Exception {
// Create a malicious payload with a different class type instead of ObjectNode
ByteArrayOutputStream baos = new ByteArrayOutputStream();
ObjectOutputStream objectOutputStream = new ObjectOutputStream(baos);

// Write a malicious object instead of ObjectNode
objectOutputStream.writeObject("MaliciousString");
objectOutputStream.flush();
objectOutputStream.close();

// Attempt to deserialize - should fail with InvalidClassException
byte[] bytes = baos.toByteArray();
try {
ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
ObjectInputStream ois = new ObjectInputStream(bais);
SerializableDocumentCollection deserializedDocumentCollection = (SerializableDocumentCollection) ois.readObject();

// Should not reach here
org.testng.Assert.fail("Expected InvalidClassException to be thrown");
} catch (java.io.InvalidClassException e) {
// Expected - the malicious class type was rejected
assertThat(e.getMessage()).contains("Expected ObjectNode");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.
package com.azure.cosmos.implementation.caches;

import com.azure.cosmos.implementation.CosmosClientMetadataCachesSnapshot;
import com.azure.cosmos.implementation.DocumentCollection;
import com.azure.cosmos.implementation.apachecommons.lang.RandomStringUtils;
import com.azure.cosmos.implementation.directconnectivity.ReflectionUtils;
Expand Down Expand Up @@ -116,6 +117,77 @@ private DocumentCollection generateDocumentCollectionDefinition() {

return collection;
}

@Test(groups = { "unit" })
public void deserializeWithInvalidClassType_shouldFail() throws Exception {
// Create a malicious payload with a different class type
ByteArrayOutputStream baos = new ByteArrayOutputStream();
ObjectOutputStream oos = new ObjectOutputStream(baos);

// Write a valid SerializableAsyncCollectionCache structure but with a malicious value
oos.writeInt(1); // size = 1
oos.writeUTF("testKey"); // key

// Write a malicious object instead of SerializableDocumentCollection
oos.writeObject("MaliciousString");

// Write the equality comparer
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tests are failing as well please fix them c.a.c.i.c.SerializableAsyncCacheTest.deserializeWithInvalidClassType_shouldFail
c.a.c.SerializableDocumentCollectionTests.deserializeWithInvalidClassType_shouldFail
c.a.c.i.c.SerializableAsyncCacheTest.deserializeWithInvalidClassType_shouldFail
c.a.c.i.c.SerializableAsyncCacheTest.deserializeWithInvalidClassType_shouldFail
c.a.c.i.c.SerializableAsyncCacheTest.safeObjectInputStream_rejectsUnauthorizedClasses
c.a.c.i.c.SerializableAsyncCacheTest.safeObjectInputStream_rejectsUnauthorizedClasses

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in commit ad3bed2. The tests were failing because the allowlist was incomplete - it only included the top-level class but ObjectInputStream.resolveClass() is called for every class in the deserialization chain. The allowlist now includes all transitively deserialized classes.

oos.writeObject((IEqualityComparer<DocumentCollection>) (v1, v2) -> v1 == v2);
oos.flush();

byte[] maliciousBytes = baos.toByteArray();

// Attempt to deserialize - should fail with InvalidClassException
try {
ByteArrayInputStream bais = new ByteArrayInputStream(maliciousBytes);
ObjectInputStream ois = new ObjectInputStream(bais);
SerializableAsyncCollectionCache cache = (SerializableAsyncCollectionCache) ois.readObject();

// Should not reach here
org.testng.Assert.fail("Expected InvalidClassException to be thrown");
} catch (java.io.InvalidClassException e) {
// Expected - the malicious class type was rejected
assertThat(e.getMessage()).contains("Expected SerializableDocumentCollection");
}
}

@Test(groups = { "unit" })
public void safeObjectInputStream_rejectsUnauthorizedClasses() throws Exception {
// Create a malicious payload with an unauthorized class
ByteArrayOutputStream baos = new ByteArrayOutputStream();
ObjectOutputStream oos = new ObjectOutputStream(baos);

// Write a malicious object (String instead of SerializableAsyncCollectionCache)
oos.writeObject("MaliciousPayload");
oos.flush();

byte[] maliciousBytes = baos.toByteArray();

// Create a CosmosClientMetadataCachesSnapshot with the malicious payload
CosmosClientMetadataCachesSnapshot snapshot = new CosmosClientMetadataCachesSnapshot();
snapshot.collectionInfoByNameCache = maliciousBytes;

// Attempt to deserialize - should fail with InvalidClassException
try {
AsyncCache<String, DocumentCollection> cache = snapshot.getCollectionInfoByNameCache();

// Should not reach here
org.testng.Assert.fail("Expected exception to be thrown for unauthorized class");
} catch (Exception e) {
// Expected - the unauthorized class was rejected
// The exception could be wrapped in a CosmosException, so check the cause chain
Throwable cause = e;
boolean foundInvalidClassException = false;
while (cause != null && !foundInvalidClassException) {
if (cause instanceof java.io.InvalidClassException) {
foundInvalidClassException = true;
assertThat(cause.getMessage()).contains("Unauthorized deserialization attempt");
}
cause = cause.getCause();
}
assertThat(foundInvalidClassException).as("Expected InvalidClassException in cause chain").isTrue();
}
}
}


3 changes: 2 additions & 1 deletion sdk/cosmos/azure-cosmos/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
#### Bugs Fixed
* Fixed an issue where `query plan` failed with `400` or query return empty result when `CosmosQueryRequestOptions` has partition key filter and partition key value contains non-ascii character. See [PR 47881](https://github.com/Azure/azure-sdk-for-java/pull/47881)
* Fixed an issue where operation failed with `400` when configured with pre-trigger or post-trigger with non-ascii character. Only impact for gateway mode. See [PR 47881](https://github.com/Azure/azure-sdk-for-java/pull/47881)
* Fixed Remote Code Execution (RCE) vulnerability via unsafe Java deserialization in `CosmosClientMetadataCachesSnapshot`, `AsyncCache`, and `DocumentCollection`. Added `SafeObjectInputStream` with class allowlisting to prevent deserialization of unauthorized classes. See [PR 47971](https://github.com/Azure/azure-sdk-for-java/pull/47971)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
* Fixed Remote Code Execution (RCE) vulnerability via unsafe Java deserialization in `CosmosClientMetadataCachesSnapshot`, `AsyncCache`, and `DocumentCollection`. Added `SafeObjectInputStream` with class allowlisting to prevent deserialization of unauthorized classes. See [PR 47971](https://github.com/Azure/azure-sdk-for-java/pull/47971)
* Fixed Remote Code Execution (RCE) vulnerability via unsafe Java deserialization in `CosmosClientMetadataCachesSnapshot`, `AsyncCache`, and `DocumentCollection`. Added `SafeObjectInputStream` with class allowlisting to prevent deserialization of unauthorized classes. - [PR 47971](https://github.com/Azure/azure-sdk-for-java/pull/47971)


#### Other Changes
* Added `x-ms-hub-region-processing-only` header to allow hub-region stickiness when 404 `READ SESSION NOT AVAIALBLE` is hit for Single-Writer accounts. - [PR 47631](https://github.com/Azure/azure-sdk-for-java/pull/47631)
* Added `x-ms-hub-region-processing-only` header to allow hub-region stickiness when 404 `READ SESSION NOT AVAILABLE` is hit for Single-Writer accounts. - [PR 47631](https://github.com/Azure/azure-sdk-for-java/pull/47631)

### 4.77.0 (2026-01-26)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
import com.azure.cosmos.CosmosAsyncClient;
import com.azure.cosmos.CosmosBridgeInternal;
import com.azure.cosmos.implementation.caches.AsyncCache;
import com.azure.cosmos.implementation.caches.SafeObjectInputStream;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.Arrays;
import java.util.HashSet;

public class CosmosClientMetadataCachesSnapshot implements Serializable {
private static final long serialVersionUID = 1l;
Expand Down Expand Up @@ -53,8 +55,32 @@ private byte[] serializeAsyncCollectionCache(AsyncCache<String, DocumentCollecti

public AsyncCache<String, DocumentCollection> getCollectionInfoByNameCache() {
try {
// Create allowlist for all classes that may be deserialized
SafeObjectInputStream ois = new SafeObjectInputStream(
new ByteArrayInputStream(collectionInfoByNameCache),
new HashSet<>(Arrays.asList(
// Top-level serialized cache class
AsyncCache.SerializableAsyncCache.SerializableAsyncCollectionCache.class.getName(),
// Nested classes deserialized by SerializableAsyncCollectionCache
DocumentCollection.SerializableDocumentCollection.class.getName(),
// Jackson classes used by SerializableDocumentCollection
"com.fasterxml.jackson.databind.node.ObjectNode",
"com.fasterxml.jackson.databind.node.TextNode",
// Internal Jackson classes that may be involved
"com.fasterxml.jackson.databind.node.BaseJsonNode",
"com.fasterxml.jackson.databind.node.ContainerNode",
"com.fasterxml.jackson.databind.node.ValueNode",
"com.fasterxml.jackson.databind.JsonNode",
// Equality comparer - we skip deserialization but still need to allow reading it
"com.azure.cosmos.implementation.caches.RxCollectionCache$CollectionRidComparer",
// Java collections and concurrent classes used internally
"java.util.concurrent.ConcurrentHashMap",
"java.util.HashMap",
"java.util.LinkedHashMap"
))
);
return ((AsyncCache.SerializableAsyncCache.SerializableAsyncCollectionCache)
new ObjectInputStream(new ByteArrayInputStream(collectionInfoByNameCache)).readObject())
ois.readObject())
.toAsyncCache();
} catch (IOException | ClassNotFoundException e) {
throw CosmosBridgeInternal.cosmosException(ERROR_CODE, e);
Expand All @@ -63,8 +89,32 @@ public AsyncCache<String, DocumentCollection> getCollectionInfoByNameCache() {

public AsyncCache<String, DocumentCollection> getCollectionInfoByIdCache() {
try {
// Create allowlist for all classes that may be deserialized
SafeObjectInputStream ois = new SafeObjectInputStream(
new ByteArrayInputStream(collectionInfoByIdCache),
new HashSet<>(Arrays.asList(
// Top-level serialized cache class
AsyncCache.SerializableAsyncCache.SerializableAsyncCollectionCache.class.getName(),
// Nested classes deserialized by SerializableAsyncCollectionCache
DocumentCollection.SerializableDocumentCollection.class.getName(),
// Jackson classes used by SerializableDocumentCollection
"com.fasterxml.jackson.databind.node.ObjectNode",
"com.fasterxml.jackson.databind.node.TextNode",
// Internal Jackson classes that may be involved
"com.fasterxml.jackson.databind.node.BaseJsonNode",
"com.fasterxml.jackson.databind.node.ContainerNode",
"com.fasterxml.jackson.databind.node.ValueNode",
"com.fasterxml.jackson.databind.JsonNode",
// Equality comparer - we skip deserialization but still need to allow reading it
"com.azure.cosmos.implementation.caches.RxCollectionCache$CollectionRidComparer",
// Java collections and concurrent classes used internally
"java.util.concurrent.ConcurrentHashMap",
"java.util.HashMap",
"java.util.LinkedHashMap"
))
);
return ((AsyncCache.SerializableAsyncCache.SerializableAsyncCollectionCache)
new ObjectInputStream(new ByteArrayInputStream(collectionInfoByIdCache)).readObject())
ois.readObject())
.toAsyncCache();
} catch (IOException | ClassNotFoundException e) {
throw CosmosBridgeInternal.cosmosException(ERROR_CODE, e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.fasterxml.jackson.databind.node.TextNode;

import java.io.IOException;
import java.io.InvalidClassException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.Collection;
Expand Down Expand Up @@ -537,7 +538,17 @@ private void writeObject(ObjectOutputStream objectOutputStream) throws IOExcepti
}

private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
ObjectNode objectNode = (ObjectNode) objectInputStream.readObject();
Object obj = objectInputStream.readObject();
// Security fix: Validate that the deserialized object is the expected type before casting.
// Without this check, an attacker could provide a malicious object that would be blindly cast to ObjectNode,
// potentially leading to code execution vulnerabilities.
if (!(obj instanceof ObjectNode)) {
throw new InvalidClassException(
"Expected ObjectNode but got " +
(obj == null ? "null" : obj.getClass().getName())
);
}
ObjectNode objectNode = (ObjectNode) obj;
ObjectNode collectionNode = (ObjectNode)objectNode.get(COLLECTIONS_ROOT_PROPERTY_NAME);
String altLink = objectNode.get(ALT_LINK_PROPERTY_NAME).asText();
this.documentCollection = new DocumentCollection(collectionNode);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import reactor.core.publisher.Mono;

import java.io.IOException;
import java.io.InvalidClassException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
Expand Down Expand Up @@ -175,7 +176,15 @@ protected String deserializeKey(ObjectInputStream ois) throws IOException {
@Override
protected DocumentCollection deserializeValue(ObjectInputStream ois) throws IOException,
ClassNotFoundException {
return ((DocumentCollection.SerializableDocumentCollection) ois.readObject()).getWrappedItem();
Object obj = ois.readObject();
// Security fix: Validate that the deserialized object is the expected type
if (!(obj instanceof DocumentCollection.SerializableDocumentCollection)) {
throw new InvalidClassException(
"Expected SerializableDocumentCollection but got " +
(obj == null ? "null" : obj.getClass().getName())
);
}
return ((DocumentCollection.SerializableDocumentCollection) obj).getWrappedItem();
}
}

Expand Down Expand Up @@ -237,8 +246,27 @@ private void readObject(ObjectInputStream ois) throws ClassNotFoundException, IO
pairs.put(key, new AsyncLazy<>(value));
}

// Security fix: Don't deserialize the IEqualityComparer as it could be a malicious object
// (e.g., a crafted lambda that executes arbitrary code).
// Instead, skip it and use the default equality comparer.
// This is safe because:
// 1. Most production code uses the default equality comparer (via no-arg AsyncCache constructor).
// RxCollectionCache uses CollectionRidComparer, but we restore the correct comparer by
// always using the default on deserialization. This is acceptable because the comparer
// only affects cache staleness checks, not correctness.
// 2. The serialization format remains unchanged (we still write the comparer for backward compatibility)
// 3. Future format changes should increment the serialVersionUID to handle compatibility explicitly
Object unusedComparer = ois.readObject(); // Read and discard the serialized comparer to maintain format compatibility

// Use the default equality comparer (same as AsyncCache constructor)
@SuppressWarnings("unchecked")
IEqualityComparer<TValue> equalityComparer = (IEqualityComparer<TValue>) ois.readObject();
IEqualityComparer<TValue> equalityComparer = (value1, value2) -> {
if (value1 == value2)
return true;
if (value1 == null || value2 == null)
return false;
return value1.equals(value2);
};
this.cache = new AsyncCache<>(equalityComparer, pairs);
}
}
Expand Down
Loading
Loading