diff --git a/examples/taproot.rs b/examples/taproot.rs index 3c8c6b66e..f0abd7885 100644 --- a/examples/taproot.rs +++ b/examples/taproot.rs @@ -101,7 +101,7 @@ fn main() { let real_desc = desc.translate_pk(&mut t).unwrap(); // Max satisfaction weight for compilation, corresponding to the script-path spend - // `multi_a(2,PUBKEY_1,PUBKEY_2) at taptree depth 1, having: + // `multi_a(2,PUBKEY_1,PUBKEY_2) at tap tree depth 1, having: // // max_witness_size = varint(control_block_size) + control_block size + // varint(script_size) + script_size + max_satisfaction_size diff --git a/src/descriptor/tr.rs b/src/descriptor/tr.rs index af8aeedab..23c1742d0 100644 --- a/src/descriptor/tr.rs +++ b/src/descriptor/tr.rs @@ -1,8 +1,7 @@ // SPDX-License-Identifier: CC0-1.0 -use core::cmp::{self, max}; use core::str::FromStr; -use core::{fmt, hash}; +use core::{cmp, fmt, hash}; use bitcoin::taproot::{ LeafVersion, TaprootBuilder, TaprootSpendInfo, TAPROOT_CONTROL_BASE_SIZE, @@ -29,7 +28,14 @@ use crate::{ #[derive(Clone, Ord, PartialOrd, Eq, PartialEq, Hash)] pub enum TapTree { /// A taproot tree structure - Tree(Arc>, Arc>), + Tree { + /// Left tree branch. + left: Arc>, + /// Right tree branch. + right: Arc>, + /// Tree height, defined as `1 + max(left_height, right_height)`. + height: usize, + }, /// A taproot leaf denoting a spending condition // A new leaf version would require a new Context, therefore there is no point // in adding a LeafVersion with Leaf type here. All Miniscripts right now @@ -108,14 +114,24 @@ impl hash::Hash for Tr { } impl TapTree { - // Helper function to compute height - // TODO: Instead of computing this every time we add a new leaf, we should - // add height as a separate field in taptree - fn taptree_height(&self) -> usize { + /// Creates a `TapTree` by combining `left` and `right` tree nodes. + pub(crate) fn combine(left: TapTree, right: TapTree) -> Self { + let height = 1 + cmp::max(left.height(), right.height()); + TapTree::Tree { + left: Arc::new(left), + right: Arc::new(right), + height, + } + } + + /// Returns the height of this tree. + fn height(&self) -> usize { match *self { - TapTree::Tree(ref left_tree, ref right_tree) => { - 1 + max(left_tree.taptree_height(), right_tree.taptree_height()) - } + TapTree::Tree { + left: _, + right: _, + height, + } => height, TapTree::Leaf(..) => 0, } } @@ -134,12 +150,17 @@ impl TapTree { T: Translator, Q: MiniscriptKey, { - let frag = match self { - TapTree::Tree(l, r) => TapTree::Tree( - Arc::new(l.translate_helper(t)?), - Arc::new(r.translate_helper(t)?), - ), - TapTree::Leaf(ms) => TapTree::Leaf(Arc::new(ms.translate_pk(t)?)), + let frag = match *self { + TapTree::Tree { + ref left, + ref right, + ref height, + } => TapTree::Tree { + left: Arc::new(left.translate_helper(t)?), + right: Arc::new(right.translate_helper(t)?), + height: *height, + }, + TapTree::Leaf(ref ms) => TapTree::Leaf(Arc::new(ms.translate_pk(t)?)), }; Ok(frag) } @@ -148,7 +169,11 @@ impl TapTree { impl fmt::Display for TapTree { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - TapTree::Tree(ref left, ref right) => write!(f, "{{{},{}}}", *left, *right), + TapTree::Tree { + ref left, + ref right, + height: _, + } => write!(f, "{{{},{}}}", *left, *right), TapTree::Leaf(ref script) => write!(f, "{}", *script), } } @@ -157,7 +182,11 @@ impl fmt::Display for TapTree { impl fmt::Debug for TapTree { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - TapTree::Tree(ref left, ref right) => write!(f, "{{{:?},{:?}}}", *left, *right), + TapTree::Tree { + ref left, + ref right, + height: _, + } => write!(f, "{{{:?},{:?}}}", *left, *right), TapTree::Leaf(ref script) => write!(f, "{:?}", *script), } } @@ -167,7 +196,7 @@ impl Tr { /// Create a new [`Tr`] descriptor from internal key and [`TapTree`] pub fn new(internal_key: Pk, tree: Option>) -> Result { Tap::check_pk(&internal_key)?; - let nodes = tree.as_ref().map(|t| t.taptree_height()).unwrap_or(0); + let nodes = tree.as_ref().map(|t| t.height()).unwrap_or(0); if nodes <= TAPROOT_CONTROL_MAX_NODE_COUNT { Ok(Self { @@ -186,10 +215,16 @@ impl Tr { } /// Obtain the [`TapTree`] of the [`Tr`] descriptor - pub fn taptree(&self) -> &Option> { + pub fn tap_tree(&self) -> &Option> { &self.tree } + /// Obtain the [`TapTree`] of the [`Tr`] descriptor + #[deprecated(since = "11.0.0", note = "use tap_tree instead")] + pub fn taptree(&self) -> &Option> { + self.tap_tree() + } + /// Iterate over all scripts in merkle tree. If there is no script path, the iterator /// yields [`None`] pub fn iter_scripts(&self) -> TapTreeIter { @@ -258,7 +293,7 @@ impl Tr { /// # Errors /// When the descriptor is impossible to safisfy (ex: sh(OP_FALSE)). pub fn max_weight_to_satisfy(&self) -> Result { - let tree = match self.taptree() { + let tree = match self.tap_tree() { None => { // key spend path // item: varint(sig+sigHash) + @@ -309,7 +344,7 @@ impl Tr { /// When the descriptor is impossible to safisfy (ex: sh(OP_FALSE)). #[deprecated(note = "use max_weight_to_satisfy instead")] pub fn max_satisfaction_weight(&self) -> Result { - let tree = match self.taptree() { + let tree = match self.tap_tree() { // key spend path: // scriptSigLen(4) + stackLen(1) + stack[Sig]Len(1) + stack[Sig](65) None => return Ok(4 + 1 + 1 + 65), @@ -407,9 +442,13 @@ where fn next(&mut self) -> Option { while let Some((depth, last)) = self.stack.pop() { match *last { - TapTree::Tree(ref l, ref r) => { - self.stack.push((depth + 1, r)); - self.stack.push((depth + 1, l)); + TapTree::Tree { + ref left, + ref right, + height: _, + } => { + self.stack.push((depth + 1, right)); + self.stack.push((depth + 1, left)); } TapTree::Leaf(ref ms) => return Some((depth, ms)), } @@ -431,7 +470,7 @@ impl_block_str!( expression::Tree { name, args } if name.is_empty() && args.len() == 2 => { let left = Self::parse_tr_script_spend(&args[0])?; let right = Self::parse_tr_script_spend(&args[1])?; - Ok(TapTree::Tree(Arc::new(left), Arc::new(right))) + Ok(TapTree::combine(left, right)) } _ => Err(Error::Unexpected( "unknown format for script spending paths while parsing taproot descriptor" @@ -589,10 +628,15 @@ fn split_once(inp: &str, delim: char) -> Option<(&str, &str)> { impl Liftable for TapTree { fn lift(&self) -> Result, Error> { fn lift_helper(s: &TapTree) -> Result, Error> { - match s { - TapTree::Tree(ref l, ref r) => { - Ok(Policy::Threshold(1, vec![lift_helper(l)?, lift_helper(r)?])) - } + match *s { + TapTree::Tree { + ref left, + ref right, + height: _, + } => Ok(Policy::Threshold( + 1, + vec![lift_helper(left)?, lift_helper(right)?], + )), TapTree::Leaf(ref leaf) => leaf.lift(), } } @@ -713,10 +757,8 @@ where #[cfg(test)] mod tests { use super::*; - use crate::ForEachKey; - #[test] - fn test_for_each() { + fn descriptor() -> String { let desc = "tr(acc0, { multi_a(3, acc10, acc11, acc12), { and_v( @@ -729,9 +771,21 @@ mod tests { ) } })"; - let desc = desc.replace(&[' ', '\n'][..], ""); + desc.replace(&[' ', '\n'][..], "") + } + + #[test] + fn for_each() { + let desc = descriptor(); let tr = Tr::::from_str(&desc).unwrap(); // Note the last ac12 only has ac and fails the predicate assert!(!tr.for_each_key(|k| k.starts_with("acc"))); } + + #[test] + fn height() { + let desc = descriptor(); + let tr = Tr::::from_str(&desc).unwrap(); + assert_eq!(tr.tap_tree().as_ref().unwrap().height(), 2); + } } diff --git a/src/policy/concrete.rs b/src/policy/concrete.rs index 0b9045ee2..d030a99d2 100644 --- a/src/policy/concrete.rs +++ b/src/policy/concrete.rs @@ -403,8 +403,8 @@ impl Policy { compilation.sanity_check()?; leaf_compilations.push((OrdF64(prob), compilation)); } - let taptree = with_huffman_tree::(leaf_compilations)?; - Some(taptree) + let tap_tree = with_huffman_tree::(leaf_compilations)?; + Some(tap_tree) } }, )?; @@ -462,8 +462,8 @@ impl Policy { ) }) .collect(); - let taptree = with_huffman_tree::(leaf_compilations).unwrap(); - Some(taptree) + let tap_tree = with_huffman_tree::(leaf_compilations).unwrap(); + Some(tap_tree) } }, )?; @@ -1202,10 +1202,7 @@ fn with_huffman_tree( let (p2, s2) = node_weights.pop().expect("len must atleast be two"); let p = (p1.0).0 + (p2.0).0; - node_weights.push(( - Reverse(OrdF64(p)), - TapTree::Tree(Arc::from(s1), Arc::from(s2)), - )); + node_weights.push((Reverse(OrdF64(p)), TapTree::combine(s1, s2))); } debug_assert!(node_weights.len() == 1); diff --git a/src/policy/mod.rs b/src/policy/mod.rs index 8e80dd489..27165c2e2 100644 --- a/src/policy/mod.rs +++ b/src/policy/mod.rs @@ -387,9 +387,11 @@ mod tests { Arc::new(ms_str!("and_v(v:pk(C),pk(D))")); let right_ms_compilation: Arc> = Arc::new(ms_str!("and_v(v:pk(A),pk(B))")); - let left_node: Arc> = Arc::from(TapTree::Leaf(left_ms_compilation)); - let right_node: Arc> = Arc::from(TapTree::Leaf(right_ms_compilation)); - let tree: TapTree = TapTree::Tree(left_node, right_node); + + let left = TapTree::Leaf(left_ms_compilation); + let right = TapTree::Leaf(right_ms_compilation); + let tree = TapTree::combine(left, right); + let expected_descriptor = Descriptor::new_tr(unspendable_key.clone(), Some(tree)).unwrap(); assert_eq!(descriptor, expected_descriptor); @@ -457,21 +459,18 @@ mod tests { .collect::>(); // Arrange leaf compilations (acc. to probabilities) using huffman encoding into a TapTree - let tree = TapTree::Tree( - Arc::from(TapTree::Tree( - Arc::from(node_compilations[4].clone()), - Arc::from(node_compilations[5].clone()), - )), - Arc::from(TapTree::Tree( - Arc::from(TapTree::Tree( - Arc::from(TapTree::Tree( - Arc::from(node_compilations[0].clone()), - Arc::from(node_compilations[1].clone()), - )), - Arc::from(node_compilations[3].clone()), - )), - Arc::from(node_compilations[6].clone()), - )), + let tree = TapTree::combine( + TapTree::combine(node_compilations[4].clone(), node_compilations[5].clone()), + TapTree::combine( + TapTree::combine( + TapTree::combine( + node_compilations[0].clone(), + node_compilations[1].clone(), + ), + node_compilations[3].clone(), + ), + node_compilations[6].clone(), + ), ); let expected_descriptor = Descriptor::new_tr("E".to_string(), Some(tree)).unwrap(); diff --git a/src/psbt/mod.rs b/src/psbt/mod.rs index d9f0ed047..1668f2f66 100644 --- a/src/psbt/mod.rs +++ b/src/psbt/mod.rs @@ -1222,7 +1222,7 @@ fn update_item_with_descriptor_helper( match item.tap_tree() { // Only set the tap_tree if the item supports it (it's an output) and the descriptor actually // contains one, otherwise it'll just be empty - Some(tap_tree) if tr_derived.taptree().is_some() => { + Some(tap_tree) if tr_derived.tap_tree().is_some() => { *tap_tree = Some( taproot::TapTree::try_from(builder) .expect("The tree should always be valid"),