Skip to content

Commit 9a2f90b

Browse files
committed
Add a Threshold<T> type
We have various enums in the codebase that include a `Thresh` variant, we have to explicitly check that invariants are maintained all over the place because these enums are public (eg, `policy::Concrete`). Add a `Threshold<T>` type that abstracts over a threshold and maintains the following invariants: - v.len() > 0 - k > 0 - k <= v.len()
1 parent b60a702 commit 9a2f90b

File tree

2 files changed

+105
-0
lines changed

2 files changed

+105
-0
lines changed

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ pub mod miniscript;
126126
pub mod plan;
127127
pub mod policy;
128128
pub mod psbt;
129+
pub mod threshold;
129130

130131
#[cfg(test)]
131132
mod test_utils;

src/threshold.rs

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
// SPDX-License-Identifier: CC0-1.0
2+
3+
//! A generic (k,n)-threshold type.
4+
5+
use core::fmt;
6+
7+
use crate::prelude::Vec;
8+
9+
/// A (k, n)-threshold.
10+
///
11+
/// This type maintains the following invariants:
12+
/// - n > 0
13+
/// - k > 0
14+
/// - k <= n
15+
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
16+
pub struct Threshold<T> {
17+
k: usize,
18+
v: Vec<T>,
19+
}
20+
21+
impl<T> Threshold<T> {
22+
/// Creates a `Theshold<T>` after checking that invariants hold.
23+
pub fn new(k: usize, v: Vec<T>) -> Result<Threshold<T>, Error> {
24+
if v.len() == 0 {
25+
Err(Error::ZeroN)
26+
} else if k == 0 {
27+
Err(Error::ZeroK)
28+
} else if k > v.len() {
29+
Err(Error::BigK)
30+
} else {
31+
Ok(Threshold { k, v })
32+
}
33+
}
34+
35+
/// Creates a `Theshold<T>` without checking that invariants hold.
36+
#[cfg(test)]
37+
pub fn new_unchecked(k: usize, v: Vec<T>) -> Threshold<T> { Threshold { k, v } }
38+
39+
/// Returns `k`, the threshold value.
40+
pub fn k(&self) -> usize { self.k }
41+
42+
/// Returns `n`, the total number of elements in the threshold.
43+
pub fn n(&self) -> usize { self.v.len() }
44+
45+
/// Returns a read-only iterator over the threshold elements.
46+
pub fn iter(&self) -> core::slice::Iter<'_, T> { self.v.iter() }
47+
48+
/// Returns the threshold elements, consuming self.
49+
// TODO: Find a better name for this functiion.
50+
pub fn into_elements(self) -> Vec<T> { self.v }
51+
52+
/// Maps `Thresh<T>` to `Thresh<U>`.
53+
///
54+
/// Typically you would call this function after collecting a vector that explicitly contains
55+
/// the correct number of elements e.g.,
56+
///
57+
/// `thresh.map((0..thresh.n()).map(|element| some_function(element)).collect())`
58+
///
59+
/// # Panics
60+
///
61+
/// Panics if the new vector is not the same length as the
62+
/// original i.e., `new.len() != self.n()`.
63+
pub fn map<U>(&self, new: Vec<U>) -> Threshold<U> {
64+
if self.n() != new.len() {
65+
panic!("cannot map to a different length vector")
66+
}
67+
Threshold { k: self.k(), v: new }
68+
}
69+
}
70+
71+
/// An error attempting to construct a `Threshold<T>`.
72+
#[derive(Debug, Clone, PartialEq, Eq)]
73+
#[non_exhaustive]
74+
pub enum Error {
75+
/// Threshold `n` value must be non-zero.
76+
ZeroN,
77+
/// Threshold `k` value must be non-zero.
78+
ZeroK,
79+
/// Threshold `k` value must be <= `n`.
80+
BigK,
81+
}
82+
83+
impl fmt::Display for Error {
84+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
85+
use Error::*;
86+
87+
match *self {
88+
ZeroN => f.write_str("threshold `n` value must be non-zero"),
89+
ZeroK => f.write_str("threshold `k` value must be non-zero"),
90+
BigK => f.write_str("threshold `k` value must be <= `n`"),
91+
}
92+
}
93+
}
94+
95+
#[cfg(feature = "std")]
96+
impl std::error::Error for Error {
97+
fn cause(&self) -> Option<&dyn std::error::Error> {
98+
use Error::*;
99+
100+
match *self {
101+
ZeroN | ZeroK | BigK => None,
102+
}
103+
}
104+
}

0 commit comments

Comments
 (0)