1use derive_deftly::{Deftly, define_derive_deftly};
4use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
5use zeroize::Zeroize;
6
7#[cfg(feature = "memquota-memcost")]
8use tor_memquota_cost::derive_deftly_template_HasMemoryCost;
9
define_derive_deftly! {
    // Derive `subtle::ConstantTimeEq` for a struct.
    //
    // The generated `ct_eq` destructures both values, compares every field
    // with that field's own `ConstantTimeEq` implementation, and combines the
    // per-field `Choice`s with `&`. `&` on `Choice` is a bitwise operator,
    // not a short-circuiting one, so every field is always compared.
    //
    // Only structs are accepted (`for struct`). A where-clause bound
    // `FieldType: ConstantTimeEq` is added for every field type.
    //
    // The expansion names the trait as a bare `ConstantTimeEq`, so that
    // trait must be in scope wherever the derive is used. `Choice` is
    // referenced through the `subtle::` path.
    //
    // This template does not implement `PartialEq`. The `PartialEqFromCtEq`
    // template in this file can build `==` on top of the derived `ct_eq`.
    export ConstantTimeEq for struct:

    impl<$tgens> ConstantTimeEq for $ttype
    where $twheres
          $( $ftype : ConstantTimeEq , )
    {
        fn ct_eq(&self, other: &Self) -> subtle::Choice {
            match (self, other) {
                $(
                    // Bind each field of `self` as `self_<field>` and each
                    // field of `other` as `other_<field>`.
                    (${vpat fprefix=self_}, ${vpat fprefix=other_}) => {
                        $(
                            $<self_ $fname>.ct_eq($<other_ $fname>) &
                        )
                        // Identity element for `&`. It is also the whole
                        // result for a struct with no fields.
                        subtle::Choice::from(1)
                    },
                )
            }
        }
    }
}
define_derive_deftly! {
    // Implement `PartialEq` for a type in terms of its `ConstantTimeEq`
    // implementation: `a == b` evaluates `a.ct_eq(b)` and converts the
    // resulting `Choice` into a `bool`.
    //
    // The generated impl carries a `$ttype: ConstantTimeEq` bound, so the
    // type must implement that trait, for example via the `ConstantTimeEq`
    // template in this file.
    //
    // Unlike that template, this one has no `for struct` restriction.
    // It implements only `PartialEq`; add `Eq` separately if wanted.
    export PartialEqFromCtEq:

    impl<$tgens> PartialEq for $ttype
    where $twheres
          $ttype : ConstantTimeEq
    {
        fn eq(&self, other: &Self) -> bool {
            // Only the final conversion to `bool` branches on the result.
            self.ct_eq(other).into()
        }
    }
}
51pub(crate) use {derive_deftly_template_ConstantTimeEq, derive_deftly_template_PartialEqFromCtEq};
52
/// A wrapper around `[u8; N]` whose comparisons use `subtle`'s
/// constant-time primitives.
///
/// `==` is implemented through [`ConstantTimeEq::ct_eq`]. `Ord` latches the
/// first differing byte with `conditional_assign` instead of returning
/// early. Build one with `From<[u8; N]>`, and get the bytes back with
/// `From<CtByteArray<N>> for [u8; N]`, `AsRef`, or `AsMut`.
///
/// Note: the type is `Copy`. Deriving `Zeroize` lets a particular value be
/// wiped, but any copies made elsewhere are separate values and are not
/// affected.
// `PartialEq` is written by hand (via `ct_eq`) rather than derived. It still
// agrees with plain byte-wise equality of the inner array, so the derived
// `Hash` stays consistent with it, and this lint can safely be silenced.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Clone, Copy, Debug, Hash, Zeroize, Deftly)]
#[cfg_attr(feature = "memquota-memcost", derive_deftly(HasMemoryCost))]
pub struct CtByteArray<const N: usize>([u8; N]);
67
68impl<const N: usize> ConstantTimeEq for CtByteArray<N> {
69 fn ct_eq(&self, other: &Self) -> Choice {
70 self.0.ct_eq(&other.0)
71 }
72}
73
74impl<const N: usize> PartialEq for CtByteArray<N> {
75 fn eq(&self, other: &Self) -> bool {
76 self.ct_eq(other).into()
77 }
78}
79impl<const N: usize> Eq for CtByteArray<N> {}
80
81impl<const N: usize> From<[u8; N]> for CtByteArray<N> {
82 fn from(value: [u8; N]) -> Self {
83 Self(value)
84 }
85}
86
87impl<const N: usize> From<CtByteArray<N>> for [u8; N] {
88 fn from(value: CtByteArray<N>) -> Self {
89 value.0
90 }
91}
92
93impl<const N: usize> Ord for CtByteArray<N> {
94 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
95 let mut first_nonzero_difference = 0_i16;
99
100 for (a, b) in self.0.iter().zip(other.0.iter()) {
101 let difference = i16::from(*a) - i16::from(*b);
102
103 first_nonzero_difference
110 .conditional_assign(&difference, first_nonzero_difference.ct_eq(&0));
111 }
112
113 first_nonzero_difference.cmp(&0)
116 }
117}
118
119impl<const N: usize> PartialOrd for CtByteArray<N> {
120 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
121 Some(self.cmp(other))
122 }
123}
124
125impl<const N: usize> AsRef<[u8; N]> for CtByteArray<N> {
126 fn as_ref(&self) -> &[u8; N] {
127 &self.0
128 }
129}
130
131impl<const N: usize> AsMut<[u8; N]> for CtByteArray<N> {
132 fn as_mut(&mut self) -> &mut [u8; N] {
133 &mut self.0
134 }
135}
136
137pub fn ct_lookup<T, F>(array: &[T], matches: F) -> Option<&T>
153where
154 F: Fn(&T) -> Choice,
155{
156 let mut idx: u64 = 0;
159 let mut found: Choice = 0.into();
160
161 for (i, x) in array.iter().enumerate() {
162 let equal = matches(x);
163 idx.conditional_assign(&(i as u64), equal);
164 found.conditional_assign(&equal, equal);
165 }
166
167 if found.into() {
168 Some(&array[idx as usize])
169 } else {
170 None
171 }
172}
173
#[cfg(test)]
mod test {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;
    use rand::Rng;
    use std::cmp::Ordering;
    use tor_basic_utils::test_rng;

    /// Random arrays, once sorted, must be strictly increasing and agree
    /// with the ordinary (non-constant-time) array ordering.
    #[allow(clippy::nonminimal_bool)]
    #[test]
    fn test_comparisons() {
        let num = 200;
        let mut rng = test_rng::testing_rng();

        let mut array: Vec<CtByteArray<32>> =
            (0..num).map(|_| rng.random::<[u8; 32]>().into()).collect();
        array.sort();

        for i in 0..num {
            assert_eq!(array[i], array[i]);
            assert!(!(array[i] < array[i]));
            assert!(!(array[i] > array[i]));

            for j in (i + 1)..num {
                // Random 32-byte values collide with negligible probability,
                // so the sorted order is strict.
                assert!(array[i] < array[j]);
                assert_ne!(array[i], array[j]);
                assert!(array[j] > array[i]);
                assert_eq!(
                    array[i].cmp(&array[j]),
                    array[j].as_ref().cmp(array[i].as_ref()).reverse()
                );
            }
        }
    }

    /// Deterministic boundary cases for `Ord`/`PartialEq`. Random inputs
    /// almost never exercise these on purpose.
    #[test]
    fn test_ordering_edge_cases() {
        // Zero-length arrays: the comparison loop never runs, so they are equal.
        let empty_a = CtByteArray::from([0_u8; 0]);
        let empty_b = CtByteArray::from([0_u8; 0]);
        assert_eq!(empty_a.cmp(&empty_b), Ordering::Equal);
        assert_eq!(empty_a, empty_b);

        // Identical contents compare equal under both `==` and `cmp`.
        let x = CtByteArray::from([9_u8, 8, 7, 6]);
        let y = CtByteArray::from([9_u8, 8, 7, 6]);
        assert_eq!(x, y);
        assert_eq!(x.cmp(&y), Ordering::Equal);

        // The first differing byte decides, even when later bytes point the
        // other way.
        let a = CtByteArray::from([1_u8, 0, 0]);
        let b = CtByteArray::from([0_u8, 255, 255]);
        assert_eq!(a.cmp(&b), Ordering::Greater);
        assert_eq!(b.cmp(&a), Ordering::Less);
        assert_ne!(a, b);

        // A difference only in the final byte is still detected.
        let c = CtByteArray::from([5_u8, 5, 4]);
        let d = CtByteArray::from([5_u8, 5, 5]);
        assert!(c < d);
        assert!(d > c);
        assert_ne!(c, d);

        // Extreme byte values (a difference of +/-255) order correctly.
        let lo = CtByteArray::from([0_u8; 2]);
        let hi = CtByteArray::from([255_u8; 2]);
        assert_eq!(lo.cmp(&hi), Ordering::Less);
        assert_eq!(hi.cmp(&lo), Ordering::Greater);
    }

    #[test]
    fn test_lookup() {
        use super::ct_lookup as lookup;
        use subtle::ConstantTimeEq;
        let items = vec![
            "One".to_string(),
            "word".to_string(),
            "of".to_string(),
            "every".to_string(),
            "length".to_string(),
        ];
        let of_word = lookup(&items[..], |i| i.len().ct_eq(&2));
        let every_word = lookup(&items[..], |i| i.len().ct_eq(&5));
        let no_word = lookup(&items[..], |i| i.len().ct_eq(&99));
        assert_eq!(of_word.unwrap(), "of");
        assert_eq!(every_word.unwrap(), "every");
        assert_eq!(no_word, None);
    }

    /// Boundary cases for `ct_lookup`: empty input, no match, matches at
    /// either end, and multiple matches.
    #[test]
    fn test_lookup_edge_cases() {
        use subtle::{Choice, ConstantTimeEq};

        // An empty slice never matches, even with an always-true predicate.
        let empty: [u32; 0] = [];
        assert_eq!(ct_lookup(&empty[..], |_| Choice::from(1)), None);

        let items = [10_u32, 20, 30];

        // A predicate that never matches yields `None`.
        assert_eq!(ct_lookup(&items[..], |_| Choice::from(0)), None);

        // Matches at the first and at the last position are both found.
        assert_eq!(ct_lookup(&items[..], |x| x.ct_eq(&10)), Some(&10));
        assert_eq!(ct_lookup(&items[..], |x| x.ct_eq(&30)), Some(&30));

        // With several matches, one of the matching elements is returned.
        let hit = ct_lookup(&items[..], |x| Choice::from(u8::from(*x >= 20))).unwrap();
        assert!(*hit == 20 || *hit == 30);
    }
}