1use hashbrown::{HashMap, HashSet};
21use sled_overlay::{sled::Tree, SledDbOverlay};
22
23use super::{
24 bits::{merge_owned_and_bits, Bits, BitsOwned},
25 node::{Node, Unit},
26 utils::{get_sorted_indices, slice_to_hash},
27 Hash, Proof, HASH_LEN, ROOT_KEY,
28};
29use crate::{ContractError, GenericResult};
30
31#[derive(Clone, Debug)]
32pub(crate) struct MemCache {
33 pub(crate) set: HashSet<Hash>,
34 pub(crate) map: HashMap<Hash, Vec<u8>>,
35}
36
37#[allow(dead_code)]
38impl MemCache {
39 pub(crate) fn new() -> Self {
40 Self { set: HashSet::new(), map: HashMap::with_capacity(1 << 12) }
41 }
42
43 pub(crate) fn clear(&mut self) {
44 self.set.clear();
45 self.map.clear();
46 }
47
48 pub(crate) fn contains(&self, key: &[u8]) -> bool {
49 !self.set.contains(key) && self.map.contains_key(key)
50 }
51
52 pub(crate) fn get(&self, key: &[u8]) -> Option<Vec<u8>> {
53 self.map.get(key).cloned()
54 }
55
56 pub(crate) fn put(&mut self, key: &[u8], value: Vec<u8>) {
57 self.map.insert(slice_to_hash(key), value);
58 self.set.remove(key);
59 }
60
61 pub(crate) fn del(&mut self, key: &[u8]) {
62 self.set.insert(slice_to_hash(key));
63 }
64}
65
/// Key/value storage backend used by [`Monotree`] to persist its nodes and head root.
///
/// The implementations in this module support a batch mode. Between
/// [`Self::init_batch`] and [`Self::finish_batch`], writes and deletions are
/// staged in memory and applied only when the batch is finished.
pub trait MonotreeStorageAdapter {
    /// Store `value` under `key` (or stage it while a batch is active).
    fn put(&mut self, key: &Hash, value: Vec<u8>) -> GenericResult<()>;
    /// Fetch the value stored under `key`, or `None` if there is none.
    fn get(&self, key: &[u8]) -> GenericResult<Option<Vec<u8>>>;
    /// Remove the value stored under `key` (or stage the removal while a batch is active).
    fn del(&mut self, key: &Hash) -> GenericResult<()>;
    /// Enter batch mode.
    fn init_batch(&mut self) -> GenericResult<()>;
    /// Apply everything staged since [`Self::init_batch`] and leave batch mode.
    fn finish_batch(&mut self) -> GenericResult<()>;
}
79
80#[derive(Clone, Debug)]
82pub struct MemoryDb {
83 db: HashMap<Hash, Vec<u8>>,
84 batch: MemCache,
85 batch_on: bool,
86}
87
88#[allow(clippy::new_without_default)]
89impl MemoryDb {
90 pub fn new() -> Self {
91 Self { db: HashMap::new(), batch: MemCache::new(), batch_on: false }
92 }
93}
94
95impl MonotreeStorageAdapter for MemoryDb {
96 fn put(&mut self, key: &Hash, value: Vec<u8>) -> GenericResult<()> {
97 if self.batch_on {
98 self.batch.put(key, value);
99 } else {
100 self.db.insert(slice_to_hash(key), value);
101 }
102
103 Ok(())
104 }
105
106 fn get(&self, key: &[u8]) -> GenericResult<Option<Vec<u8>>> {
107 if self.batch_on && self.batch.contains(key) {
108 return Ok(self.batch.get(key));
109 }
110
111 match self.db.get(key) {
112 Some(v) => Ok(Some(v.to_owned())),
113 None => Ok(None),
114 }
115 }
116
117 fn del(&mut self, key: &Hash) -> GenericResult<()> {
118 if self.batch_on {
119 self.batch.del(key);
120 } else {
121 self.db.remove(key);
122 }
123
124 Ok(())
125 }
126
127 fn init_batch(&mut self) -> GenericResult<()> {
128 if !self.batch_on {
129 self.batch.clear();
130 self.batch_on = true;
131 }
132
133 Ok(())
134 }
135
136 fn finish_batch(&mut self) -> GenericResult<()> {
137 if self.batch_on {
138 for (key, value) in self.batch.map.drain() {
139 self.db.insert(key, value);
140 }
141 for key in self.batch.set.drain() {
142 self.db.remove(&key);
143 }
144 self.batch_on = false;
145 }
146
147 Ok(())
148 }
149}
150
151#[derive(Clone)]
153pub struct SledTreeDb {
154 tree: Tree,
155 batch: MemCache,
156 batch_on: bool,
157}
158
159impl SledTreeDb {
160 pub fn new(tree: &Tree) -> Self {
161 Self { tree: tree.clone(), batch: MemCache::new(), batch_on: false }
162 }
163}
164
165impl MonotreeStorageAdapter for SledTreeDb {
166 fn put(&mut self, key: &Hash, value: Vec<u8>) -> GenericResult<()> {
167 if self.batch_on {
168 self.batch.put(key, value);
169 } else if let Err(e) = self.tree.insert(slice_to_hash(key), value) {
170 return Err(ContractError::IoError(e.to_string()))
171 }
172
173 Ok(())
174 }
175
176 fn get(&self, key: &[u8]) -> GenericResult<Option<Vec<u8>>> {
177 if self.batch_on && self.batch.contains(key) {
178 return Ok(self.batch.get(key));
179 }
180
181 match self.tree.get(key) {
182 Ok(Some(v)) => Ok(Some(v.to_vec())),
183 Ok(None) => Ok(None),
184 Err(e) => Err(ContractError::IoError(e.to_string())),
185 }
186 }
187
188 fn del(&mut self, key: &Hash) -> GenericResult<()> {
189 if self.batch_on {
190 self.batch.del(key);
191 } else if let Err(e) = self.tree.remove(key) {
192 return Err(ContractError::IoError(e.to_string()));
193 }
194
195 Ok(())
196 }
197
198 fn init_batch(&mut self) -> GenericResult<()> {
199 if !self.batch_on {
200 self.batch.clear();
201 self.batch_on = true;
202 }
203
204 Ok(())
205 }
206
207 fn finish_batch(&mut self) -> GenericResult<()> {
208 if self.batch_on {
209 for (key, value) in self.batch.map.drain() {
210 if let Err(e) = self.tree.insert(key, value) {
211 return Err(ContractError::IoError(e.to_string()))
212 }
213 }
214 for key in self.batch.set.drain() {
215 if let Err(e) = self.tree.remove(key) {
216 return Err(ContractError::IoError(e.to_string()))
217 }
218 }
219 self.batch_on = false;
220 }
221
222 Ok(())
223 }
224}
225
226pub struct SledOverlayDb<'a> {
228 overlay: &'a mut SledDbOverlay,
229 tree: [u8; 32],
230 batch: MemCache,
231 batch_on: bool,
232}
233
234impl<'a> SledOverlayDb<'a> {
235 pub fn new(
236 overlay: &'a mut SledDbOverlay,
237 tree: &[u8; 32],
238 ) -> GenericResult<SledOverlayDb<'a>> {
239 if let Err(e) = overlay.open_tree(tree, false) {
240 return Err(ContractError::IoError(e.to_string()))
241 };
242 Ok(Self { overlay, tree: *tree, batch: MemCache::new(), batch_on: false })
243 }
244}
245
246impl MonotreeStorageAdapter for SledOverlayDb<'_> {
247 fn put(&mut self, key: &Hash, value: Vec<u8>) -> GenericResult<()> {
248 if self.batch_on {
249 self.batch.put(key, value);
250 } else if let Err(e) = self.overlay.insert(&self.tree, &slice_to_hash(key), &value) {
251 return Err(ContractError::IoError(e.to_string()))
252 }
253
254 Ok(())
255 }
256
257 fn get(&self, key: &[u8]) -> GenericResult<Option<Vec<u8>>> {
258 if self.batch_on && self.batch.contains(key) {
259 return Ok(self.batch.get(key));
260 }
261
262 match self.overlay.get(&self.tree, key) {
263 Ok(Some(v)) => Ok(Some(v.to_vec())),
264 Ok(None) => Ok(None),
265 Err(e) => Err(ContractError::IoError(e.to_string())),
266 }
267 }
268
269 fn del(&mut self, key: &Hash) -> GenericResult<()> {
270 if self.batch_on {
271 self.batch.del(key);
272 } else if let Err(e) = self.overlay.remove(&self.tree, key) {
273 return Err(ContractError::IoError(e.to_string()));
274 }
275
276 Ok(())
277 }
278
279 fn init_batch(&mut self) -> GenericResult<()> {
280 if !self.batch_on {
281 self.batch.clear();
282 self.batch_on = true;
283 }
284
285 Ok(())
286 }
287
288 fn finish_batch(&mut self) -> GenericResult<()> {
289 if self.batch_on {
290 for (key, value) in self.batch.map.drain() {
291 if let Err(e) = self.overlay.insert(&self.tree, &key, &value) {
292 return Err(ContractError::IoError(e.to_string()))
293 }
294 }
295 for key in self.batch.set.drain() {
296 if let Err(e) = self.overlay.remove(&self.tree, &key) {
297 return Err(ContractError::IoError(e.to_string()))
298 }
299 }
300 self.batch_on = false;
301 }
302
303 Ok(())
304 }
305}
306
307#[derive(Clone, Debug)]
312pub struct Monotree<D: MonotreeStorageAdapter> {
313 db: D,
314}
315
316impl<D: MonotreeStorageAdapter> Monotree<D> {
317 pub fn new(db: D) -> Self {
318 Self { db }
319 }
320
321 fn hash_digest(bytes: &[u8]) -> Hash {
322 let mut hasher = blake3::Hasher::new();
323 hasher.update(bytes);
324 let hash = hasher.finalize();
325 slice_to_hash(hash.as_bytes())
326 }
327
328 pub fn get_headroot(&self) -> GenericResult<Option<Hash>> {
330 let headroot = self.db.get(ROOT_KEY)?;
331 match headroot {
332 Some(root) => Ok(Some(slice_to_hash(&root))),
333 None => Ok(None),
334 }
335 }
336
337 pub fn set_headroot(&mut self, headroot: Option<&Hash>) -> GenericResult<()> {
339 if let Some(root) = headroot {
340 self.db.put(ROOT_KEY, root.to_vec())?;
341 }
342
343 Ok(())
344 }
345
346 pub fn prepare(&mut self) -> GenericResult<()> {
347 self.db.init_batch()
348 }
349
350 pub fn commit(&mut self) -> GenericResult<()> {
351 self.db.finish_batch()
352 }
353
354 pub fn insert(
356 &mut self,
357 root: Option<&Hash>,
358 key: &Hash,
359 leaf: &Hash,
360 ) -> GenericResult<Option<Hash>> {
361 match root {
362 None => {
363 let (hash, bits) = (leaf, Bits::new(key));
364 self.put_node(Node::new(Some(Unit { hash, bits }), None))
365 }
366 Some(root) => self.put(root, Bits::new(key), leaf),
367 }
368 }
369
370 fn put_node(&mut self, node: Node) -> GenericResult<Option<Hash>> {
371 let bytes = node.to_bytes()?;
372 let hash = Self::hash_digest(&bytes);
373 self.db.put(&hash, bytes)?;
374 Ok(Some(hash))
375 }
376
377 fn put_soft_node_owned(
379 &mut self,
380 target_hash: &[u8],
381 bits: &BitsOwned,
382 ) -> GenericResult<Option<Hash>> {
383 let bits_bytes = bits.to_bytes()?;
384 let node_bytes = [target_hash, &bits_bytes[..], &[0x00u8]].concat();
385 let node_hash = Self::hash_digest(&node_bytes);
386 self.db.put(&node_hash, node_bytes)?;
387 Ok(Some(node_hash))
388 }
389
390 fn put_hard_node_mixed(
392 &mut self,
393 left_hash: &[u8],
394 left_bits: &BitsOwned,
395 right: &Unit,
396 ) -> GenericResult<Option<Hash>> {
397 let lb_bytes = left_bits.to_bytes()?;
398 let rb_bytes = right.bits.to_bytes()?;
399
400 let (lh, lb, rh, rb) = if right.bits.first() {
401 (left_hash, &lb_bytes[..], right.hash, &rb_bytes[..])
402 } else {
403 (right.hash, &rb_bytes[..], left_hash, &lb_bytes[..])
404 };
405
406 let node_bytes = [lh, lb, rb, rh, &[0x01u8]].concat();
407 let node_hash = Self::hash_digest(&node_bytes);
408 self.db.put(&node_hash, node_bytes)?;
409 Ok(Some(node_hash))
410 }
411
412 fn collapse_to_target(
415 &mut self,
416 hash: &[u8],
417 prefix: BitsOwned,
418 ) -> GenericResult<(Hash, BitsOwned)> {
419 let Some(bytes) = self.db.get(hash)? else {
420 return Ok((slice_to_hash(hash), prefix))
422 };
423
424 let node = Node::from_bytes(&bytes)?;
425 match node {
426 Node::Soft(Some(child)) => {
427 let merged = merge_owned_and_bits(&prefix, &child.bits);
428 self.collapse_to_target(child.hash, merged)
429 }
430 Node::Hard(_, _) => Ok((slice_to_hash(hash), prefix)),
431 _ => unreachable!("unexpected node type in collapse_to_target"),
432 }
433 }
434
435 fn put(&mut self, root: &[u8], bits: Bits, leaf: &[u8]) -> GenericResult<Option<Hash>> {
461 let bytes =
462 self.db.get(root)?.ok_or(ContractError::MonotreeError("put(): bytes".to_string()))?;
463 let (left, right) = Node::cells_from_bytes(&bytes, bits.first())?;
464 let unit =
465 left.as_ref().ok_or(ContractError::MonotreeError("put(): left-unit".to_string()))?;
466 let n = Bits::len_common_bits(&unit.bits, &bits)?;
467
468 match n {
469 0 => self.put_node(Node::new(left, Some(Unit { hash: leaf, bits }))),
470 n if n == bits.len() => {
471 self.put_node(Node::new(Some(Unit { hash: leaf, bits }), right))
472 }
473 n if n == unit.bits.len() => {
474 let hash = &self.put(unit.hash, bits.drop(n), leaf)?.ok_or(
475 ContractError::MonotreeError("put(): consume & pass-over".to_string()),
476 )?;
477
478 self.put_node(Node::new(Some(Unit { hash, bits: unit.bits.to_owned() }), right))
479 }
480 _ => {
481 let hash = &self
482 .put_node(Node::new(
483 Some(Unit { hash: unit.hash, bits: unit.bits.drop(n) }),
484 Some(Unit { hash: leaf, bits: bits.drop(n) }),
485 ))?
486 .ok_or(ContractError::MonotreeError("put(): split-node".to_string()))?;
487
488 self.put_node(Node::new(Some(Unit { hash, bits: unit.bits.take(n)? }), right))
489 }
490 }
491 }
492
493 pub fn get(&mut self, root: Option<&Hash>, key: &Hash) -> GenericResult<Option<Hash>> {
495 match root {
496 None => Ok(None),
497 Some(root) => self.find_key(root, Bits::new(key)),
498 }
499 }
500
501 fn find_key(&mut self, root: &[u8], bits: Bits) -> GenericResult<Option<Hash>> {
502 let bytes = self
503 .db
504 .get(root)?
505 .ok_or(ContractError::MonotreeError("find_key(): bytes".to_string()))?;
506 let (cell, _) = Node::cells_from_bytes(&bytes, bits.first())?;
507 let unit = cell
508 .as_ref()
509 .ok_or(ContractError::MonotreeError("find_key(): left-unit".to_string()))?;
510 let n = Bits::len_common_bits(&unit.bits, &bits)?;
511 match n {
512 n if n == bits.len() => Ok(Some(slice_to_hash(unit.hash))),
513 n if n == unit.bits.len() => self.find_key(unit.hash, bits.drop(n)),
514 _ => Ok(None),
515 }
516 }
517
518 pub fn remove(&mut self, root: Option<&Hash>, key: &[u8]) -> GenericResult<Option<Hash>> {
520 match root {
521 None => Ok(None),
522 Some(root) => self.delete_key(root, Bits::new(key)),
523 }
524 }
525
526 fn delete_key(&mut self, root: &[u8], bits: Bits) -> GenericResult<Option<Hash>> {
527 let bytes = self
528 .db
529 .get(root)?
530 .ok_or(ContractError::MonotreeError("delete_key(): bytes".to_string()))?;
531 let (left, right) = Node::cells_from_bytes(&bytes, bits.first())?;
532 let unit = left
533 .as_ref()
534 .ok_or(ContractError::MonotreeError("delete_key(): left-unit".to_string()))?;
535 let n = Bits::len_common_bits(&unit.bits, &bits)?;
536
537 match n {
538 n if n == bits.len() => {
540 match right {
541 Some(ref sibling) => {
542 let prefix = sibling.bits.to_bits_owned();
544 let (target, merged_bits) =
545 self.collapse_to_target(sibling.hash, prefix)?;
546 self.put_soft_node_owned(&target, &merged_bits)
547 }
548 None => Ok(None),
549 }
550 }
551 n if n == unit.bits.len() => {
553 let hash = self.delete_key(unit.hash, bits.drop(n))?;
554 match (hash, &right) {
555 (None, None) => Ok(None),
556
557 (None, Some(sibling)) => {
558 let prefix = sibling.bits.to_bits_owned();
560 let (target, merged_bits) =
561 self.collapse_to_target(sibling.hash, prefix)?;
562 self.put_soft_node_owned(&target, &merged_bits)
563 }
564
565 (Some(ref new_child), None) => {
566 let prefix = unit.bits.to_bits_owned();
568 let (target, merged_bits) = self.collapse_to_target(new_child, prefix)?;
569 self.put_soft_node_owned(&target, &merged_bits)
570 }
571
572 (Some(ref new_child), Some(sibling)) => {
573 match self.db.get(new_child)? {
575 Some(child_bytes) => {
576 match Node::from_bytes(&child_bytes)? {
577 Node::Soft(Some(inner)) => {
578 let merged = Bits::merge(&unit.bits, &inner.bits);
580 self.put_hard_node_mixed(inner.hash, &merged, sibling)
581 }
582 Node::Hard(_, _) => {
583 let parent_bits = unit.bits.to_bits_owned();
585 self.put_hard_node_mixed(new_child, &parent_bits, sibling)
586 }
587 _ => unreachable!(),
588 }
589 }
590 None => {
591 let parent_bits = unit.bits.to_bits_owned();
593 self.put_hard_node_mixed(new_child, &parent_bits, sibling)
594 }
595 }
596 }
597 }
598 }
599 _ => Ok(None),
600 }
601 }
602
603 pub fn inserts(
606 &mut self,
607 root: Option<&Hash>,
608 keys: &[Hash],
609 leaves: &[Hash],
610 ) -> GenericResult<Option<Hash>> {
611 let indices = get_sorted_indices(keys, false);
612 self.prepare()?;
613
614 let mut root = root.cloned();
615 for i in indices.iter() {
616 root = self.insert(root.as_ref(), &keys[*i], &leaves[*i])?;
617 }
618
619 self.commit()?;
620 Ok(root)
621 }
622
623 pub fn gets(&mut self, root: Option<&Hash>, keys: &[Hash]) -> GenericResult<Vec<Option<Hash>>> {
625 let mut leaves: Vec<Option<Hash>> = vec![];
626 for key in keys.iter() {
627 leaves.push(self.get(root, key)?);
628 }
629 Ok(leaves)
630 }
631
632 pub fn removes(&mut self, root: Option<&Hash>, keys: &[Hash]) -> GenericResult<Option<Hash>> {
635 let indices = get_sorted_indices(keys, false);
636 let mut root = root.cloned();
637 self.prepare()?;
638
639 for i in indices.iter() {
640 root = self.remove(root.as_ref(), &keys[*i])?;
641 }
642
643 self.commit()?;
644 Ok(root)
645 }
646
647 pub fn get_merkle_proof(
649 &mut self,
650 root: Option<&Hash>,
651 key: &[u8],
652 ) -> GenericResult<Option<Proof>> {
653 let mut proof: Proof = vec![];
654 match root {
655 None => Ok(None),
656 Some(root) => self.gen_proof(root, Bits::new(key), &mut proof),
657 }
658 }
659
660 fn gen_proof(
661 &mut self,
662 root: &[u8],
663 bits: Bits,
664 proof: &mut Proof,
665 ) -> GenericResult<Option<Proof>> {
666 let bytes = self
667 .db
668 .get(root)?
669 .ok_or(ContractError::MonotreeError("gen_proof(): bytes".to_string()))?;
670 let (cell, _) = Node::cells_from_bytes(&bytes, bits.first())?;
671 let unit = cell
672 .as_ref()
673 .ok_or(ContractError::MonotreeError("gen_proof(): left-unit".to_string()))?;
674 let n = Bits::len_common_bits(&unit.bits, &bits)?;
675
676 match n {
677 n if n == bits.len() => {
678 proof.push(self.encode_proof(&bytes, bits.first())?);
679 Ok(Some(proof.to_owned()))
680 }
681 n if n == unit.bits.len() => {
682 proof.push(self.encode_proof(&bytes, bits.first())?);
683 self.gen_proof(unit.hash, bits.drop(n), proof)
684 }
685 _ => Ok(None),
686 }
687 }
688
689 fn encode_proof(&self, bytes: &[u8], right: bool) -> GenericResult<(bool, Vec<u8>)> {
690 match Node::from_bytes(bytes)? {
691 Node::Soft(_) => Ok((false, bytes[HASH_LEN..].to_vec())),
692 Node::Hard(_, _) => {
693 if right {
694 Ok((true, [&bytes[..bytes.len() - HASH_LEN - 1], &[0x01]].concat()))
695 } else {
696 Ok((false, bytes[HASH_LEN..].to_vec()))
697 }
698 }
699 }
700 }
701}
702
703pub fn verify_proof(
707 root: Option<&Hash>,
708 leaf: &Hash,
709 proof: Option<&Proof>,
710) -> GenericResult<bool> {
711 match proof {
712 None => Ok(false),
713 Some(proof) => {
714 let mut hash = leaf.to_owned();
715 proof.iter().rev().for_each(|(right, cut)| {
716 if *right {
717 let l = cut.len();
718 let o = [&cut[..l - 1], &hash[..], &cut[l - 1..]].concat();
719 hash = Monotree::<MemoryDb>::hash_digest(&o);
720 } else {
721 let o = [&hash[..], &cut[..]].concat();
722 hash = Monotree::<MemoryDb>::hash_digest(&o);
723 }
724 });
725
726 Ok(root.ok_or(ContractError::MonotreeError("verify_proof(): root".to_string()))? ==
727 &hash)
728 }
729 }
730}