Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 21 additions & 7 deletions src/main/java/com/uid2/shared/encryption/AesGcm.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,27 @@
import com.uid2.shared.model.EncryptionKey;
import com.uid2.shared.model.KeyIdentifier;
import com.uid2.shared.model.KeysetKey;
import io.vertx.core.buffer.Buffer;

import javax.crypto.Cipher;
import javax.crypto.SecretKey;
import javax.crypto.spec.GCMParameterSpec;
import javax.crypto.spec.SecretKeySpec;
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;

public class AesGcm {
private static final String cipherScheme = "AES/GCM/NoPadding";
private static final String CIPHER_SCHEME = "AES/GCM/NoPadding";
public static final int GCM_AUTHTAG_LENGTH = 16;
public static final int GCM_IV_LENGTH = 12;

private static final ThreadLocal<Cipher> CIPHER = ThreadLocal.withInitial(() -> {
try {
return Cipher.getInstance(CIPHER_SCHEME);
} catch (GeneralSecurityException e) {
throw new RuntimeException("Unable to create cipher", e);
}
});

public static EncryptedPayload encrypt(byte[] b, KeysetKey key) {
return encrypt(b, key.getKeyBytes(), key.getKeyIdentifier());
}
Expand All @@ -32,11 +40,16 @@ private static EncryptedPayload encrypt(byte[] b, byte[] secretBytes, KeyIdentif
public static byte[] encrypt(byte[] b, byte[] secretBytes) {
try {
final SecretKey k = new SecretKeySpec(secretBytes, "AES");
final Cipher c = Cipher.getInstance(cipherScheme);
final Cipher c = CIPHER.get();
final byte[] ivBytes = Random.getBytes(GCM_IV_LENGTH);
GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(GCM_AUTHTAG_LENGTH * 8, ivBytes);
final GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(GCM_AUTHTAG_LENGTH * 8, ivBytes);
c.init(Cipher.ENCRYPT_MODE, k, gcmParameterSpec);
return Buffer.buffer().appendBytes(ivBytes).appendBytes(c.doFinal(b)).getBytes();

// Pre-allocate output: IV + ciphertext + auth tag
final byte[] result = new byte[GCM_IV_LENGTH + c.getOutputSize(b.length)];
System.arraycopy(ivBytes, 0, result, 0, GCM_IV_LENGTH);
c.doFinal(b, 0, b.length, result, GCM_IV_LENGTH);
return result;
} catch (Exception e) {
throw new RuntimeException("Unable to Encrypt", e);
}
Expand All @@ -50,9 +63,10 @@ public static byte[] decrypt(byte[] encryptedBytes, int offset, byte[] secretByt
try {
final SecretKey key = new SecretKeySpec(secretBytes, "AES");
final GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(GCM_AUTHTAG_LENGTH * 8, encryptedBytes, offset, GCM_IV_LENGTH);
final Cipher c = Cipher.getInstance(cipherScheme);
final Cipher c = CIPHER.get();
c.init(Cipher.DECRYPT_MODE, key, gcmParameterSpec);
return c.doFinal(encryptedBytes, offset + GCM_IV_LENGTH, encryptedBytes.length - offset - GCM_IV_LENGTH);
final int dataOffset = offset + GCM_IV_LENGTH;
return c.doFinal(encryptedBytes, dataOffset, encryptedBytes.length - dataOffset);
} catch (Exception e) {
throw new RuntimeException("Unable to Decrypt", e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

import java.util.Collection;
import java.util.Comparator;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;

/*
1. metadata.json format
Expand Down Expand Up @@ -42,6 +44,10 @@
public class RotatingClientKeyProvider implements IClientKeyProvider, StoreReader<Collection<ClientKey>> {
private final ScopedStoreReader<Collection<ClientKey>> reader;
private final AuthorizableStore<ClientKey> authorizableStore;
private final ConcurrentHashMap<Integer, VersionedValue> oldestClientKeyBySiteIdCache = new ConcurrentHashMap<>();
private volatile long snapshotVersion = 0;

private record VersionedValue(long version, Optional<ClientKey> value) {}

public RotatingClientKeyProvider(DownloadCloudStorage fileStreamProvider, StoreScope scope) {
this.reader = new ScopedStoreReader<>(fileStreamProvider, scope, new ClientParser(), "auth keys");
Expand All @@ -64,9 +70,13 @@ public long getVersion(JsonObject metadata) {
}

@Override
public long loadContent(JsonObject metadata) throws Exception {
public long loadContent(JsonObject metadata) throws Exception {
long version = reader.loadContent(metadata, "client_keys");
authorizableStore.refresh(getAll());

// Versioning to prevent race conditions when reading the oldest client key
oldestClientKeyBySiteIdCache.clear();
snapshotVersion = getVersion(metadata);
return version;
}

Expand Down Expand Up @@ -102,10 +112,18 @@ public IAuthorizable get(String key) {

@Override
public ClientKey getOldestClientKey(int siteId) {
return this.reader.getSnapshot().stream()
.filter(k -> k.getSiteId() == siteId) // filter by site id
.sorted(Comparator.comparing(ClientKey::getCreated)) // sort by key creation timestamp ascending
.findFirst() // return the oldest key
.orElse(null);
long currentVersion = snapshotVersion;
VersionedValue cached = oldestClientKeyBySiteIdCache.get(siteId);

if (cached != null && cached.version() == currentVersion) {
return cached.value().orElse(null);
}

Optional<ClientKey> computed = this.reader.getSnapshot().stream()
.filter(k -> k.getSiteId() == siteId)
.min(Comparator.comparingLong(ClientKey::getCreated));

oldestClientKeyBySiteIdCache.put(siteId, new VersionedValue(currentVersion, computed));
return computed.orElse(null);
}
}
Loading