Skip to content

Commit 755a43c

Browse files
committed
resolved a bunch of todo's
1 parent 88e5674 commit 755a43c

9 files changed

Lines changed: 108 additions & 634 deletions

File tree

β€Žcrypto/core-interface/tests/key_material_tests.rsβ€Ž

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -439,13 +439,16 @@ mod test_key_material {
439439
}
440440

441441
#[test]
442-
fn from_keymaterial() {
443-
let key1 = KeyMaterial256::from_bytes(&DUMMY_KEY[..32]).unwrap();
444-
442+
fn from_keym() {
443+
let key1 = KeyMaterial256::from_bytes_as_type(&DUMMY_KEY[..32], KeyType::MACKey).unwrap();
444+
assert_eq!(key1.key_type(), KeyType::MACKey);
445+
assert_eq!(key1.security_strength(), SecurityStrength::_256bit);
446+
445447
// success case: same size using default From impl; only works if the sizes are the same (ie the compiler knows that they are the same type.
446448
let key2 = KeyMaterial256::from(key1.clone());
447449
assert_eq!(key1.key_len(), key2.key_len());
448450
assert_eq!(key1.key_type(), key2.key_type());
451+
assert_eq!(key1.security_strength(), key2.security_strength());
449452
assert_eq!(key1, key2);
450453

451454
// success case: same size

β€Žcrypto/mldsa/src/aux_functions.rsβ€Ž

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,16 @@ pub(crate) fn coeff_from_three_bytes(b: &[u8; 3]) -> Result<i32, ()> {
3535
/// Output: An integer between βˆ’πœ‚ and πœ‚, or βŠ₯.
3636
#[inline(always)]
3737
pub(crate) fn coeff_from_half_byte<const ETA: usize>(b: u8) -> Result<i32, ()> {
38-
// todo: there's no way this is constant time:
39-
// todo: the if statement might not be so bad because the alternative is rejection,
40-
// todo: but that % is a problem.
41-
// todo: what does openssl or rust crypto do?
4238
if ETA == 2 && b < 15 {
43-
Ok(2 - (b % 5) as i32) // todo: is constant-time?
39+
// Original code is bad because '%' is not constant-time.
40+
// Ok(2 - (b % 5) as i32)
41+
// I'm still not convinced this is constant-time, but maybe it's closer? And I can't come up with anything better.
42+
let b = match b {
43+
b if b < 5 => b,
44+
b if b < 10 => b - 5,
45+
_ => b - 10,
46+
};
47+
Ok(2 - b as i32)
4448
} else {
4549
if ETA == 4 && b < 9 { Ok(4 - b as i32) } else { Err(()) }
4650
}
@@ -412,7 +416,6 @@ pub(crate) fn sig_encode<
412416
pos += LAMBDA_over_4;
413417

414418
for i in 0..l {
415-
// todo -- remove this copy by having bitpack_gamma1 take an output slice
416419
output[pos..pos + POLY_Z_PACKED_LEN]
417420
.copy_from_slice(&bitpack_gamma1::<POLY_Z_PACKED_LEN, GAMMA1>(&z.vec[i]));
418421
pos += POLY_Z_PACKED_LEN;
@@ -504,7 +507,6 @@ pub(crate) fn sig_decode<
504507
}
505508

506509
// β–· read any leftover bytes in the first πœ” bytes of 𝑦 for malformed (nonzero) bytes
507-
508510
for j in idx..OMEGA as usize {
509511
if sig[pos + j] != 0 {
510512
return Err(());
@@ -790,7 +792,6 @@ fn test_power_2_round() {
790792
/// Decomposes π‘Ÿ into (π‘Ÿ1, π‘Ÿ0) such that π‘Ÿ ≑ π‘Ÿ1(2𝛾2) + π‘Ÿ0 mod π‘ž.
791793
/// Input: π‘Ÿ ∈ β„€π‘ž.
792794
/// Output: Integers (π‘Ÿ1, π‘Ÿ0).
793-
794795
// the hope here is that the compiler will aggressively inline this function,
795796
// and optimize away the branching.
796797
#[inline(always)]
@@ -1029,7 +1030,6 @@ pub(crate) fn ntt(w: &Polynomial) -> Polynomial {
10291030
/// Output: Polynomial 𝑀(𝑋) = βˆ‘255
10301031
/// 𝑗=0 𝑀𝑗𝑋𝑗 ∈ π‘…π‘ž
10311032
pub(crate) fn inv_ntt(w_hat: &Polynomial) -> Polynomial {
1032-
// todo: optimize to do this in-place? Might actually bench worse.
10331033
let mut w = w_hat.clone();
10341034

10351035
let mut m: usize = N;
Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,26 @@
1+
use bouncycastle_core_interface::traits::Hash;
2+
use crate::MLDSA44PublicKey;
13

24
/// Note that the PH expected here *is not the same* as the `mu` computed by [compute_mu] ... blah blah explain.
3-
struct DummyForTheCommentToCompile {}
5+
struct DummyForTheCommentToCompile {}
6+
7+
/// Note: yes, a HashMLDSAPublicKey is just a re-branded MLDSAPublicKey because they are structurally
8+
/// the same, but we are giving them different types to indicate that they _should_ not be used interchangeably
9+
/// in order to prevent certain cross-protocol attacks where a verifier is tricked into accepting an
10+
/// ML-DSA signature that was actually generated by the HashML-DSA algorithm, and vice-versa.
11+
/// IE a given key should be used ONLY for ML-DSA or HashML-DSA and should not be used sometimes for one and sometimes for the other.
12+
/// We won't prevent you from doing so; for example you can convert between them by encoding the key to bytes and
13+
/// back into the other type, but we want to make you aware that you're opening yourself up to cross-protocol attacks.
14+
pub type HashMLDSA44PublicKey = MLDSA44PublicKey;
15+
16+
/// An instance of the HashML-DSA algorithm.
17+
///
18+
/// We are exposing the HashMLDSA struct this way so that alternative hash functions can be used
19+
/// without requiring modification of this source code; you can add your own hash function
20+
/// by specifying the hash function to use (in the verifier), and specifying the bytes of the OID to
21+
/// to use as its domain separator in constructing the message representative M'.
22+
// todo: put an example of this in the unit tests and then copy that example into these ducs
23+
// todo: figure out how to do a const str param
24+
pub struct HashMLDSA<HASH: Hash + Default, const oid_name: [u8]> {
25+
_phantom: std::marker::PhantomData<HASH>,
26+
}

β€Žcrypto/mldsa/src/lib.rsβ€Ž

Lines changed: 4 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#![feature(generic_const_exprs)]
1010
#![feature(int_roundings)]
1111
#![feature(inherent_associated_types)]
12+
#![feature(adt_const_params)]
1213
// These are because I'm matching variable names exactly against FIPS 204, for example both 'K' and 'k',
1314
// or 'A' and 'a' are used and have specific meanings.
1415
// But need to tell the rust linter to not care.
@@ -19,6 +20,9 @@
1920
// MLDSA implentation, but I don't want accessed from outside, such FIPS-internal functions.
2021
#![allow(private_bounds)]
2122

23+
// Used in HashMLDSA
24+
#![feature(unsized_const_params)]
25+
2226
mod mldsa;
2327
mod hashmldsa;
2428
mod mldsa_keys;
@@ -56,104 +60,4 @@ pub use mldsa::{MLDSA87_PK_LEN, MLDSA87_SK_LEN, MLDSA87_SIG_LEN};
5660

5761

5862

59-
/*** Param traits ***/
60-
61-
// todo -- delete
62-
// /// Private trait on purpose so that only the NIST-approved params can be used.
63-
// /// Values taken directly from FIPS 204 Table 1 and Table 2
64-
// #[allow(private_bounds)]
65-
// trait MLDSAParams {
66-
// // from FIPS 204 Table 1
67-
// // q, zeta, d defined as global constants since they do not vary by parameter set
68-
// const TAU: i32;
69-
// const GAMMA1: i32;
70-
// const GAMMA2: i32;
71-
// const k: usize;
72-
// const l: usize;
73-
// const ETA: i32;
74-
// const BETA: i32; // tau * eta
75-
// const OMEGA: i32;
76-
//
77-
// // useful derived values
78-
// const C_TILDE: usize;
79-
// const POLY_VEC_H_PACKED_LEN: usize;
80-
// const POLY_Z_PACKED_LEN: usize;
81-
// const POLY_W1_PACKED_LEN: usize;
82-
// const POLY_ETA_PACKED_LEN: usize;
83-
// const GAMMA1_MASK_LEN: usize;
84-
// const LAMBDA_over_4: usize;
85-
// }
86-
87-
// pub struct MLDSA44Params;
88-
//
89-
// impl MLDSAParams for MLDSA44Params {
90-
// const TAU: i32 = 39;
91-
// const GAMMA1: i32 = 1 << 17;
92-
// const GAMMA2: i32 = (q - 1) / 88;
93-
// const k: usize = 4;
94-
// const l: usize = 4;
95-
// const ETA: i32 = 2;
96-
// const BETA: i32 = 78;
97-
// const OMEGA: i32 = 80;
98-
//
99-
// // const ALG: MldsaAlg = MldsaAlg::MlDsa44;
100-
// const C_TILDE: usize = 32;
101-
// const POLY_VEC_H_PACKED_LEN: usize = 0; // todo -- compute
102-
// const POLY_Z_PACKED_LEN: usize = 576;
103-
// const POLY_W1_PACKED_LEN: usize = 192;
104-
// const POLY_ETA_PACKED_LEN: usize = 96;
105-
//
106-
// // Alg 32
107-
// // 1: 𝑐 ← 1 + bitlen (𝛾1 βˆ’ 1)
108-
// const GAMMA1_MASK_LEN: usize = 576; // 32*(1 + bitlen (𝛾1 βˆ’ 1) )
109-
// const LAMBDA_over_4: usize = 128/4;
110-
// // todo -- bc-java does it as compute: 576usize.div_ceil(symmetric.stream_256_block_bytes) -- which should be 5
111-
// // todo -- might need to debug this against bc-java
112-
// // todo -- debug this against bc-java; or look in other implementations. I feel like this should be 32*17=544 or 32*19=608
113-
// // todo -- I'm not sure why they're adding an extra 32
114-
// // todo -- corresponds to aux_functions::expand_mask()
115-
// }
116-
117-
// pub struct MLDSA65Params;
118-
//
119-
// impl MLDSAParams for MLDSA65Params {
120-
// const TAU: i32 = 49;
121-
// const GAMMA1: i32 = 1 << 19;
122-
// const GAMMA2: i32 = (q - 1) / 32;
123-
// const k: usize = 6;
124-
// const l: usize = 5;
125-
// const ETA: i32 = 4;
126-
// const BETA: i32 = 196;
127-
// const OMEGA: i32 = 55;
128-
//
129-
// const C_TILDE: usize = 48;
130-
// const POLY_VEC_H_PACKED_LEN: usize = 0; // todo -- compute
131-
// const POLY_Z_PACKED_LEN: usize = 640;
132-
// const POLY_W1_PACKED_LEN: usize = 128;
133-
// const POLY_ETA_PACKED_LEN: usize = 128;
134-
// const GAMMA1_MASK_LEN: usize = 640; // todo -- compute: 640usize.div_ceil(symmetric.stream_256_block_bytes)
135-
// const LAMBDA_over_4: usize = 192/4;
136-
// }
137-
138-
// pub struct MLDSA87Params;
139-
//
140-
// impl MLDSAParams for MLDSA87Params {
141-
// const TAU: i32 = 60;
142-
// const GAMMA1: i32 = 1 << 19;
143-
// const GAMMA2: i32 = (q - 1) / 32;
144-
// const k: usize = 8;
145-
// const l: usize = 7;
146-
// const ETA: i32 = 2;
147-
// const BETA: i32 = 120;
148-
// const OMEGA: i32 = 75;
149-
//
150-
// const C_TILDE: usize = 64;
151-
// const POLY_VEC_H_PACKED_LEN: usize = 0; // todo -- compute
152-
// const POLY_Z_PACKED_LEN: usize = 640;
153-
// const POLY_W1_PACKED_LEN: usize = 128;
154-
// const POLY_ETA_PACKED_LEN: usize = 96;
155-
// const GAMMA1_MASK_LEN: usize = 640; // todo -- compute: 640usize.div_ceil(symmetric.stream_256_block_bytes)
156-
// const LAMBDA_over_4: usize = 256/4;
157-
// }
158-
15963
// todo -- impl bouncycastle_core_interface::traits::Algorithm with the security strengths from Table 1

β€Žcrypto/mldsa/src/matrix.rsβ€Ž

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,6 @@ impl<const LEN: usize> Vector<LEN>
188188
// 4: end for
189189
for i in 0..LEN {
190190
w1_tilde[i*POLY_W1_PACKED_LEN .. (i+1)*POLY_W1_PACKED_LEN].copy_from_slice(
191-
// todo -- optimize this to take a slice and write directly to it?
192191
&self.vec[i].w1_encode::<POLY_W1_PACKED_LEN>()
193192
)
194193
}

β€Žcrypto/mldsa/src/mldsa.rsβ€Ž

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ pub(crate) const ROOT_OF_UNITY: i32 = 1753;
2727
pub const SEED_LEN: usize = 32;
2828
pub const RND_LEN: usize = 32;
2929
pub const TR_LEN: usize = 64;
30+
pub const MU_LEN: usize = 64;
3031
pub(crate) const POLY_T1PACKED_LEN: usize = 320;
3132
pub(crate) const POLY_T0PACKED_LEN: usize = 416;
3233

@@ -287,7 +288,7 @@ impl<
287288
/// specifically takes a 32-byte [KeyMaterial256] and checks that it has [KeyType::Seed] and
288289
/// [SecurityStrength::_256bit].
289290
/// If you happen to have your seed in a larger KeyMaterial, you'll have to copy it using
290-
/// [KeyMaterial::from_key] -- todo: make sure this works and copies key type and security strength correctly.
291+
/// [KeyMaterial::from_key]
291292
fn keygen_internal(
292293
seed: &KeyMaterial256,
293294
) -> Result<
@@ -312,15 +313,11 @@ impl<
312313
let mut rho_prime: [u8; 64] = [0u8; 64];
313314
let mut K: [u8; 32] = [0u8; 32];
314315

315-
// TODO: optimization: re-use variables rather than allocating new ones?
316-
// TODO: do with benches because it might not actually be faster. Rust seems to like local vars.
317-
318316
let mut h = H::default();
319317
h.absorb(seed.ref_to_bytes());
320318
h.absorb(&(k as u8).to_le_bytes());
321319
h.absorb(&(l as u8).to_le_bytes());
322320
let bytes_written = h.squeeze_out(&mut rho);
323-
debug_assert_eq!(bytes_written, 32); // todo: remove these asserts once we have unit tests that pass?
324321
let bytes_written = h.squeeze_out(&mut rho_prime);
325322
debug_assert_eq!(bytes_written, 64);
326323
let bytes_written = h.squeeze_out(&mut K);
@@ -421,14 +418,17 @@ impl<
421418
/// (in which case a keygen_from_seed is run and then the pk's compared).
422419
///
423420
/// Returns either `()` or [SignatureError::ConsistencyCheckFailed].
424-
///
425-
/// TODO -- sync with openssl implementation
426-
/// TODO -- https://github.com/openssl/openssl/blob/master/crypto/ml_dsa/ml_dsa_key.c#L385
427421
pub fn keypair_consistency_check(
428422
pk: &PK,
429423
sk: &SK,
430424
) -> Result<(), SignatureError> {
431-
todo!()
425+
// This is maybe a computationally heavy way to compare them, but it works
426+
let derived_pk = sk.derive_public_key();
427+
if derived_pk.compute_tr() == pk.compute_tr() {
428+
Ok(())
429+
} else {
430+
Err(SignatureError::ConsistencyCheckFailed())
431+
}
432432
}
433433

434434
/// This provides the first half of the "External Mu" interface to ML-DSA which is described
@@ -465,8 +465,8 @@ impl<
465465
pub fn compute_mu_from_tr(
466466
msg: &[u8],
467467
ctx: Option<&[u8]>,
468-
tr: &[u8; 64],
469-
) -> Result<[u8; 64], SignatureError> {
468+
tr: &[u8; TR_LEN],
469+
) -> Result<[u8; TR_LEN], SignatureError> {
470470
MuBuilder::compute_mu(msg, ctx, tr)
471471
}
472472

@@ -475,7 +475,7 @@ impl<
475475
msg: &[u8],
476476
ctx: Option<&[u8]>,
477477
pk: &PK,
478-
) -> Result<[u8; 64], SignatureError> {
478+
) -> Result<[u8; MU_LEN], SignatureError> {
479479
MuBuilder::compute_mu(msg, ctx, &pk.compute_tr())
480480
}
481481

@@ -494,7 +494,7 @@ impl<
494494
/// This mode uses randomized signing (called "hedged mode" in FIPS 204) using an internal RNG.
495495
fn sign_mu(
496496
sk: &SK,
497-
mu: &[u8; 64],
497+
mu: &[u8; MU_LEN],
498498
) -> Result<[u8; SIG_LEN], SignatureError> {
499499
let mut out: [u8; SIG_LEN] = [0u8; SIG_LEN];
500500
Self::sign_mu_out(sk, mu, &mut out)?;
@@ -509,10 +509,10 @@ impl<
509509
/// Returns the number of bytes written to the output buffer. Can be called with an oversized buffer.
510510
fn sign_mu_out(
511511
sk: &SK,
512-
mu: &[u8; 64],
512+
mu: &[u8; MU_LEN],
513513
output: &mut [u8; SIG_LEN],
514514
) -> Result<usize, SignatureError> {
515-
let mut rnd: [u8; 32] = [0u8; 32];
515+
let mut rnd: [u8; RND_LEN] = [0u8; RND_LEN];
516516
HashDRBG_SHA512::new_from_os().next_bytes_out(&mut rnd)?;
517517

518518
Self::sign_mu_deterministic_out(sk, mu, rnd, output)
@@ -553,8 +553,8 @@ impl<
553553
/// Returns the number of bytes written to the output buffer. Can be called with an oversized buffer.
554554
pub(crate) fn sign_mu_deterministic_out(
555555
sk: &SK,
556-
mu: &[u8; 64],
557-
rnd: [u8; 32],
556+
mu: &[u8; MU_LEN],
557+
rnd: [u8; RND_LEN],
558558
output: &mut [u8; SIG_LEN],
559559
) -> Result<usize, SignatureError> {
560560
// 1: (𝜌, 𝐾, π‘‘π‘Ÿ, 𝐬1, 𝐬2, 𝐭0) ← skDecode(π‘ π‘˜)
@@ -608,9 +608,6 @@ impl<
608608
// 11: 𝐲 ∈ 𝑅^β„“ ← ExpandMask(πœŒβ€³, πœ…)
609609
let mut y = expand_mask::<l, GAMMA1, GAMMA1_MASK_LEN>(&rho_p_p, kappa);
610610

611-
// last use of rho_p_p, so zeroizing it
612-
rho_p_p.fill(0u8);
613-
614611
// 12: 𝐰 ← NTTβˆ’1(𝐀_hat * NTT(𝐲))
615612
let mut y_hat = y.clone();
616613
y_hat.ntt();
@@ -696,12 +693,16 @@ impl<
696693
};
697694

698695
// "In addition, there is an alternative way of implementing the validity checks on 𝐳 and the computation of
699-
// 𝐑, which is described in Section 5.1 of. This method may also be used in implementations of ML-DSA."
700-
// todo -- check this out
696+
// 𝐑, which is described in Section 5.1 of [6] (dilithium-specification-round3-20210208.pdf).
697+
// This method may also be used in implementations of ML-DSA."
698+
// todo -- I believe this code is already using this optimization, but it could use a deeper look to see if more optimization is possible.
701699

702700
break;
703701
}
704702

703+
// zeroize rho_p_p before returning it to the OS
704+
rho_p_p.fill(0u8);
705+
705706
// 33: 𝜎 ← sigEncode(𝑐, 𝐳̃ modΒ±π‘ž, 𝐑)
706707
let bytes_written = sig_encode::<GAMMA1, k, l, LAMBDA_over_4, OMEGA, POLY_Z_PACKED_LEN, SIG_LEN>
707708
(&sig_val_c_tilde, &sig_val_z, &sig_val_h, output);
@@ -721,7 +722,7 @@ impl<
721722
/// Input: Signature 𝜎 ∈ π”Ήπœ†/4+β„“β‹…32β‹…(1+bitlen (𝛾1βˆ’1))+πœ”+π‘˜.
722723
fn verify_mu_internal(
723724
pk: &PK,
724-
mu: &[u8; 64],
725+
mu: &[u8; MU_LEN],
725726
sig: &[u8; SIG_LEN],
726727
) -> bool {
727728
// 1: (𝜌, 𝐭1) ← pkDecode(π‘π‘˜)

β€Žcrypto/mldsa/src/mldsa_keys.rsβ€Ž

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ impl<const k: usize, const l: usize, const eta: usize, const SK_LEN: usize, cons
384384
seed: seed.clone(),
385385
}
386386
}
387-
387+
388388
fn rho(&self) -> &[u8; 32] { &self.rho }
389389

390390
fn K(&self) -> &[u8; 32] { &self.K }

0 commit comments

Comments
Β (0)