Skip to content

Commit 1542226

Browse files
committed
Add a Threshold<T> type
We have various enums in the codebase that include a `Thresh` variant, we have to explicitly check that invariants are maintained all over the place because these enums are public (eg, `policy::Concrete`). Add a `Threshold<T>` type that abstracts over a threshold and maintains the following invariants: - v.len() > 0 - k > 0 - k <= v.len()
1 parent 4734ed4 commit 1542226

File tree

2 files changed

+128
-0
lines changed

2 files changed

+128
-0
lines changed

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ pub mod miniscript;
126126
pub mod plan;
127127
pub mod policy;
128128
pub mod psbt;
129+
pub mod threshold;
129130

130131
#[cfg(test)]
131132
mod test_utils;

src/threshold.rs

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
// SPDX-License-Identifier: CC0-1.0
2+
3+
//! A generic (k,n)-threshold type.
4+
5+
use core::fmt;
6+
7+
use crate::prelude::Vec;
8+
9+
/// A (k, n)-threshold.
10+
///
11+
/// This type maintains the following invariants:
12+
/// - n > 0
13+
/// - k > 0
14+
/// - k <= n
15+
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
16+
pub struct Threshold<T> {
17+
k: usize,
18+
v: Vec<T>,
19+
}
20+
21+
impl<T> Threshold<T> {
22+
/// Creates a `Theshold<T>` after checking that invariants hold.
23+
pub fn new(k: usize, v: Vec<T>) -> Result<Threshold<T>, Error> {
24+
if v.len() == 0 {
25+
Err(Error::ZeroN)
26+
} else if k == 0 {
27+
Err(Error::ZeroK)
28+
} else if k > v.len() {
29+
Err(Error::BigK)
30+
} else {
31+
Ok(Threshold { k, v })
32+
}
33+
}
34+
35+
/// Creates a `Theshold<T>` without checking that invariants hold.
36+
#[cfg(test)]
37+
pub fn new_unchecked(k: usize, v: Vec<T>) -> Threshold<T> { Threshold { k, v } }
38+
39+
/// Returns `k`, the threshold value.
40+
pub fn k(&self) -> usize { self.k }
41+
42+
/// Returns `n`, the total number of elements in the threshold.
43+
pub fn n(&self) -> usize { self.v.len() }
44+
45+
/// Returns a read-only iterator over the threshold elements.
46+
pub fn iter(&self) -> core::slice::Iter<'_, T> { self.v.iter() }
47+
48+
/// Returns the threshold elements, consuming self.
49+
pub fn into_elements(self) -> Vec<T> { self.v }
50+
51+
/// Maps `Thresh<T>` to `Thresh<U>`.
52+
///
53+
/// Typically you would call this function after collecting a vector that explicitly contains
54+
/// the correct number of elements e.g.,
55+
///
56+
/// `thresh.map((0..thresh.n()).map(|element| some_function(element)).collect())`
57+
///
58+
/// # Panics
59+
///
60+
/// Panics if the new vector is not the same length as the
61+
/// original i.e., `new.len() != self.n()`.
62+
pub fn map<U>(&self, new: Vec<U>) -> Threshold<U> {
63+
if self.n() != new.len() {
64+
panic!("cannot map to a different length vector")
65+
}
66+
Threshold { k: self.k(), v: new }
67+
}
68+
}
69+
70+
/// An error attempting to construct a `Threshold<T>`.
71+
#[derive(Debug, Clone, PartialEq, Eq)]
72+
#[non_exhaustive]
73+
pub enum Error {
74+
/// Threshold `n` value must be non-zero.
75+
ZeroN,
76+
/// Threshold `k` value must be non-zero.
77+
ZeroK,
78+
/// Threshold `k` value must be <= `n`.
79+
BigK,
80+
}
81+
82+
impl fmt::Display for Error {
83+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
84+
use Error::*;
85+
86+
match *self {
87+
ZeroN => f.write_str("threshold `n` value must be non-zero"),
88+
ZeroK => f.write_str("threshold `k` value must be non-zero"),
89+
BigK => f.write_str("threshold `k` value must be <= `n`"),
90+
}
91+
}
92+
}
93+
94+
#[cfg(feature = "std")]
95+
impl std::error::Error for Error {
96+
fn cause(&self) -> Option<&dyn std::error::Error> {
97+
use Error::*;
98+
99+
match *self {
100+
ZeroN | ZeroK | BigK => None,
101+
}
102+
}
103+
}
104+
105+
#[cfg(test)]
106+
mod tests {
107+
use super::*;
108+
109+
#[test]
110+
fn threshold_constructor_valid() {
111+
let v = vec![1, 2, 3];
112+
let n = 3;
113+
114+
for k in 1..=3 {
115+
let thresh = Threshold::new(k, v.clone()).expect("failed to create threshold");
116+
assert_eq!(thresh.k(), k);
117+
assert_eq!(thresh.n(), n);
118+
}
119+
}
120+
121+
#[test]
122+
fn threshold_constructor_invalid() {
123+
let v = vec![1, 2, 3];
124+
assert!(Threshold::new(0, v.clone()).is_err());
125+
assert!(Threshold::new(4, v.clone()).is_err());
126+
}
127+
}

0 commit comments

Comments
 (0)