1use hashbrown::{HashMap, HashSet};
21use sled_overlay::{sled::Tree, SledDbOverlay};
22
23use super::{
24 bits::{merge_owned_and_bits, Bits, BitsOwned},
25 node::{Node, Unit},
26 utils::{get_sorted_indices, slice_to_hash},
27 Hash, Proof, HASH_LEN, ROOT_KEY,
28};
29use crate::{ContractError, GenericResult};
30
31#[derive(Clone, Debug)]
32pub(crate) struct MemCache {
33 pub(crate) set: HashSet<Hash>,
34 pub(crate) map: HashMap<Hash, Vec<u8>>,
35}
36
37#[allow(dead_code)]
38impl MemCache {
39 pub(crate) fn new() -> Self {
40 Self { set: HashSet::new(), map: HashMap::with_capacity(1 << 12) }
41 }
42
43 pub(crate) fn clear(&mut self) {
44 self.set.clear();
45 self.map.clear();
46 }
47
48 pub(crate) fn contains(&self, key: &[u8]) -> bool {
49 !self.set.contains(key) && self.map.contains_key(key)
50 }
51
52 pub(crate) fn get(&self, key: &[u8]) -> Option<Vec<u8>> {
53 self.map.get(key).cloned()
54 }
55
56 pub(crate) fn put(&mut self, key: &[u8], value: Vec<u8>) {
57 self.map.insert(slice_to_hash(key), value);
58 self.set.remove(key);
59 }
60
61 pub(crate) fn del(&mut self, key: &[u8]) {
62 self.set.insert(slice_to_hash(key));
63 }
64}
65
/// Key/value storage interface the tree uses to persist its nodes.
///
/// The implementations in this file additionally support a "batch" mode in
/// which writes are buffered in memory between `init_batch` and `finish_batch`.
pub trait MonotreeStorageAdapter {
    /// Stores `value` under `key`.
    fn put(&mut self, key: &Hash, value: Vec<u8>) -> GenericResult<()>;
    /// Looks up `key`, returning `Ok(None)` when it is not present.
    fn get(&self, key: &[u8]) -> GenericResult<Option<Vec<u8>>>;
    /// Removes `key` from the store.
    fn del(&mut self, key: &Hash) -> GenericResult<()>;
    /// Starts a write batch.
    fn init_batch(&mut self) -> GenericResult<()>;
    /// Ends the current write batch, applying its buffered changes.
    fn finish_batch(&mut self) -> GenericResult<()>;
}
79
80#[derive(Clone, Debug)]
82pub struct MemoryDb {
83 db: HashMap<Hash, Vec<u8>>,
84 batch: MemCache,
85 batch_on: bool,
86}
87
88#[allow(clippy::new_without_default)]
89impl MemoryDb {
90 pub fn new() -> Self {
91 Self { db: HashMap::new(), batch: MemCache::new(), batch_on: false }
92 }
93}
94
95impl MonotreeStorageAdapter for MemoryDb {
96 fn put(&mut self, key: &Hash, value: Vec<u8>) -> GenericResult<()> {
97 if self.batch_on {
98 self.batch.put(key, value);
99 } else {
100 self.db.insert(slice_to_hash(key), value);
101 }
102
103 Ok(())
104 }
105
106 fn get(&self, key: &[u8]) -> GenericResult<Option<Vec<u8>>> {
107 if self.batch_on && self.batch.contains(key) {
108 return Ok(self.batch.get(key));
109 }
110
111 match self.db.get(key) {
112 Some(v) => Ok(Some(v.to_owned())),
113 None => Ok(None),
114 }
115 }
116
117 fn del(&mut self, key: &Hash) -> GenericResult<()> {
118 if self.batch_on {
119 self.batch.del(key);
120 } else {
121 self.db.remove(key);
122 }
123
124 Ok(())
125 }
126
127 fn init_batch(&mut self) -> GenericResult<()> {
128 if !self.batch_on {
129 self.batch.clear();
130 self.batch_on = true;
131 }
132
133 Ok(())
134 }
135
136 fn finish_batch(&mut self) -> GenericResult<()> {
137 if self.batch_on {
138 for (key, value) in self.batch.map.drain() {
139 self.db.insert(key, value);
140 }
141 for key in self.batch.set.drain() {
142 self.db.remove(&key);
143 }
144 self.batch_on = false;
145 }
146
147 Ok(())
148 }
149}
150
151#[derive(Clone)]
153pub struct SledTreeDb {
154 tree: Tree,
155 batch: MemCache,
156 batch_on: bool,
157}
158
159impl SledTreeDb {
160 pub fn new(tree: &Tree) -> Self {
161 Self { tree: tree.clone(), batch: MemCache::new(), batch_on: false }
162 }
163}
164
165impl MonotreeStorageAdapter for SledTreeDb {
166 fn put(&mut self, key: &Hash, value: Vec<u8>) -> GenericResult<()> {
167 if self.batch_on {
168 self.batch.put(key, value);
169 } else if let Err(e) = self.tree.insert(slice_to_hash(key), value) {
170 return Err(ContractError::IoError(e.to_string()))
171 }
172
173 Ok(())
174 }
175
176 fn get(&self, key: &[u8]) -> GenericResult<Option<Vec<u8>>> {
177 if self.batch_on && self.batch.contains(key) {
178 return Ok(self.batch.get(key));
179 }
180
181 match self.tree.get(key) {
182 Ok(Some(v)) => Ok(Some(v.to_vec())),
183 Ok(None) => Ok(None),
184 Err(e) => Err(ContractError::IoError(e.to_string())),
185 }
186 }
187
188 fn del(&mut self, key: &Hash) -> GenericResult<()> {
189 if self.batch_on {
190 self.batch.del(key);
191 } else if let Err(e) = self.tree.remove(key) {
192 return Err(ContractError::IoError(e.to_string()));
193 }
194
195 Ok(())
196 }
197
198 fn init_batch(&mut self) -> GenericResult<()> {
199 if !self.batch_on {
200 self.batch.clear();
201 self.batch_on = true;
202 }
203
204 Ok(())
205 }
206
207 fn finish_batch(&mut self) -> GenericResult<()> {
208 if self.batch_on {
209 for (key, value) in self.batch.map.drain() {
210 if let Err(e) = self.tree.insert(key, value) {
211 return Err(ContractError::IoError(e.to_string()))
212 }
213 }
214 for key in self.batch.set.drain() {
215 if let Err(e) = self.tree.remove(key) {
216 return Err(ContractError::IoError(e.to_string()))
217 }
218 }
219 self.batch_on = false;
220 }
221
222 Ok(())
223 }
224}
225
226pub struct SledOverlayDb<'a> {
228 overlay: &'a mut SledDbOverlay,
229 tree: [u8; 32],
230 batch: MemCache,
231 batch_on: bool,
232}
233
234impl<'a> SledOverlayDb<'a> {
235 pub fn new(
236 overlay: &'a mut SledDbOverlay,
237 tree: &[u8; 32],
238 ) -> GenericResult<SledOverlayDb<'a>> {
239 if let Err(e) = overlay.open_tree(tree, false) {
240 return Err(ContractError::IoError(e.to_string()))
241 };
242 Ok(Self { overlay, tree: *tree, batch: MemCache::new(), batch_on: false })
243 }
244}
245
246impl MonotreeStorageAdapter for SledOverlayDb<'_> {
247 fn put(&mut self, key: &Hash, value: Vec<u8>) -> GenericResult<()> {
248 if self.batch_on {
249 self.batch.put(key, value);
250 } else if let Err(e) = self.overlay.insert(&self.tree, &slice_to_hash(key), &value) {
251 return Err(ContractError::IoError(e.to_string()))
252 }
253
254 Ok(())
255 }
256
257 fn get(&self, key: &[u8]) -> GenericResult<Option<Vec<u8>>> {
258 if self.batch_on && self.batch.contains(key) {
259 return Ok(self.batch.get(key));
260 }
261
262 match self.overlay.get(&self.tree, key) {
263 Ok(Some(v)) => Ok(Some(v.to_vec())),
264 Ok(None) => Ok(None),
265 Err(e) => Err(ContractError::IoError(e.to_string())),
266 }
267 }
268
269 fn del(&mut self, key: &Hash) -> GenericResult<()> {
270 if self.batch_on {
271 self.batch.del(key);
272 } else if let Err(e) = self.overlay.remove(&self.tree, key) {
273 return Err(ContractError::IoError(e.to_string()));
274 }
275
276 Ok(())
277 }
278
279 fn init_batch(&mut self) -> GenericResult<()> {
280 if !self.batch_on {
281 self.batch.clear();
282 self.batch_on = true;
283 }
284
285 Ok(())
286 }
287
288 fn finish_batch(&mut self) -> GenericResult<()> {
289 if self.batch_on {
290 for (key, value) in self.batch.map.drain() {
291 if let Err(e) = self.overlay.insert(&self.tree, &key, &value) {
292 return Err(ContractError::IoError(e.to_string()))
293 }
294 }
295 for key in self.batch.set.drain() {
296 if let Err(e) = self.overlay.remove(&self.tree, &key) {
297 return Err(ContractError::IoError(e.to_string()))
298 }
299 }
300 self.batch_on = false;
301 }
302
303 Ok(())
304 }
305}
306
/// Merkle tree whose nodes are persisted through a [`MonotreeStorageAdapter`].
#[derive(Clone, Debug)]
pub struct Monotree<D: MonotreeStorageAdapter> {
    /// Storage backend every node read and write goes through.
    db: D,
}
315
316impl<D: MonotreeStorageAdapter> Monotree<D> {
317 pub fn new(db: D) -> Self {
318 Self { db }
319 }
320
321 fn hash_digest(bytes: &[u8]) -> Hash {
322 let mut hasher = blake3::Hasher::new();
323 hasher.update(bytes);
324 let hash = hasher.finalize();
325 slice_to_hash(hash.as_bytes())
326 }
327
328 pub fn get_headroot(&self) -> GenericResult<Option<Hash>> {
330 let headroot = self.db.get(ROOT_KEY)?;
331 match headroot {
332 Some(root) => Ok(Some(slice_to_hash(&root))),
333 None => Ok(None),
334 }
335 }
336
337 pub fn set_headroot(&mut self, headroot: Option<&Hash>) {
339 if let Some(root) = headroot {
340 self.db.put(ROOT_KEY, root.to_vec()).expect("set_headroot(): hash");
341 }
342 }
343
344 pub fn prepare(&mut self) {
345 self.db.init_batch().expect("prepare(): failed to initialize batch");
346 }
347
348 pub fn commit(&mut self) {
349 self.db.finish_batch().expect("commit(): failed to initialize batch");
350 }
351
352 pub fn insert(
354 &mut self,
355 root: Option<&Hash>,
356 key: &Hash,
357 leaf: &Hash,
358 ) -> GenericResult<Option<Hash>> {
359 match root {
360 None => {
361 let (hash, bits) = (leaf, Bits::new(key));
362 self.put_node(Node::new(Some(Unit { hash, bits }), None))
363 }
364 Some(root) => self.put(root, Bits::new(key), leaf),
365 }
366 }
367
368 fn put_node(&mut self, node: Node) -> GenericResult<Option<Hash>> {
369 let bytes = node.to_bytes()?;
370 let hash = Self::hash_digest(&bytes);
371 self.db.put(&hash, bytes)?;
372 Ok(Some(hash))
373 }
374
375 fn put_soft_node_owned(
377 &mut self,
378 target_hash: &[u8],
379 bits: &BitsOwned,
380 ) -> GenericResult<Option<Hash>> {
381 let bits_bytes = bits.to_bytes()?;
382 let node_bytes = [target_hash, &bits_bytes[..], &[0x00u8]].concat();
383 let node_hash = Self::hash_digest(&node_bytes);
384 self.db.put(&node_hash, node_bytes)?;
385 Ok(Some(node_hash))
386 }
387
388 fn put_hard_node_mixed(
390 &mut self,
391 left_hash: &[u8],
392 left_bits: &BitsOwned,
393 right: &Unit,
394 ) -> GenericResult<Option<Hash>> {
395 let lb_bytes = left_bits.to_bytes()?;
396 let rb_bytes = right.bits.to_bytes()?;
397
398 let (lh, lb, rh, rb) = if right.bits.first() {
399 (left_hash, &lb_bytes[..], right.hash, &rb_bytes[..])
400 } else {
401 (right.hash, &rb_bytes[..], left_hash, &lb_bytes[..])
402 };
403
404 let node_bytes = [lh, lb, rb, rh, &[0x01u8]].concat();
405 let node_hash = Self::hash_digest(&node_bytes);
406 self.db.put(&node_hash, node_bytes)?;
407 Ok(Some(node_hash))
408 }
409
410 fn collapse_to_target(
413 &mut self,
414 hash: &[u8],
415 prefix: BitsOwned,
416 ) -> GenericResult<(Hash, BitsOwned)> {
417 let Some(bytes) = self.db.get(hash)? else {
418 return Ok((slice_to_hash(hash), prefix))
420 };
421
422 let node = Node::from_bytes(&bytes)?;
423 match node {
424 Node::Soft(Some(child)) => {
425 let merged = merge_owned_and_bits(&prefix, &child.bits);
426 self.collapse_to_target(child.hash, merged)
427 }
428 Node::Hard(_, _) => Ok((slice_to_hash(hash), prefix)),
429 _ => unreachable!("unexpected node type in collapse_to_target"),
430 }
431 }
432
433 fn put(&mut self, root: &[u8], bits: Bits, leaf: &[u8]) -> GenericResult<Option<Hash>> {
459 let bytes = self.db.get(root)?.expect("put(): bytes");
460 let (left, right) = Node::cells_from_bytes(&bytes, bits.first())?;
461 let unit = left.as_ref().expect("put(): left-unit");
462 let n = Bits::len_common_bits(&unit.bits, &bits);
463
464 match n {
465 0 => self.put_node(Node::new(left, Some(Unit { hash: leaf, bits }))),
466 n if n == bits.len() => {
467 self.put_node(Node::new(Some(Unit { hash: leaf, bits }), right))
468 }
469 n if n == unit.bits.len() => {
470 let hash =
471 &self.put(unit.hash, bits.drop(n), leaf)?.expect("put(): consume & pass-over");
472
473 self.put_node(Node::new(Some(Unit { hash, bits: unit.bits.to_owned() }), right))
474 }
475 _ => {
476 let hash = &self
477 .put_node(Node::new(
478 Some(Unit { hash: unit.hash, bits: unit.bits.drop(n) }),
479 Some(Unit { hash: leaf, bits: bits.drop(n) }),
480 ))?
481 .expect("put(): split-node");
482
483 self.put_node(Node::new(Some(Unit { hash, bits: unit.bits.take(n) }), right))
484 }
485 }
486 }
487
488 pub fn get(&mut self, root: Option<&Hash>, key: &Hash) -> GenericResult<Option<Hash>> {
490 match root {
491 None => Ok(None),
492 Some(root) => self.find_key(root, Bits::new(key)),
493 }
494 }
495
496 fn find_key(&mut self, root: &[u8], bits: Bits) -> GenericResult<Option<Hash>> {
497 let bytes = self.db.get(root)?.expect("find_key(): bytes");
498 let (cell, _) = Node::cells_from_bytes(&bytes, bits.first())?;
499 let unit = cell.as_ref().expect("find_key(): left-unit");
500 let n = Bits::len_common_bits(&unit.bits, &bits);
501 match n {
502 n if n == bits.len() => Ok(Some(slice_to_hash(unit.hash))),
503 n if n == unit.bits.len() => self.find_key(unit.hash, bits.drop(n)),
504 _ => Ok(None),
505 }
506 }
507
508 pub fn remove(&mut self, root: Option<&Hash>, key: &[u8]) -> GenericResult<Option<Hash>> {
510 match root {
511 None => Ok(None),
512 Some(root) => self.delete_key(root, Bits::new(key)),
513 }
514 }
515
516 fn delete_key(&mut self, root: &[u8], bits: Bits) -> GenericResult<Option<Hash>> {
517 let bytes = self.db.get(root)?.expect("delete_key(): bytes");
518 let (left, right) = Node::cells_from_bytes(&bytes, bits.first())?;
519 let unit = left.as_ref().expect("delete_key(): left-unit");
520 let n = Bits::len_common_bits(&unit.bits, &bits);
521
522 match n {
523 n if n == bits.len() => {
525 match right {
526 Some(ref sibling) => {
527 let prefix = sibling.bits.to_bits_owned();
529 let (target, merged_bits) =
530 self.collapse_to_target(sibling.hash, prefix)?;
531 self.put_soft_node_owned(&target, &merged_bits)
532 }
533 None => Ok(None),
534 }
535 }
536 n if n == unit.bits.len() => {
538 let hash = self.delete_key(unit.hash, bits.drop(n))?;
539 match (hash, &right) {
540 (None, None) => Ok(None),
541
542 (None, Some(sibling)) => {
543 let prefix = sibling.bits.to_bits_owned();
545 let (target, merged_bits) =
546 self.collapse_to_target(sibling.hash, prefix)?;
547 self.put_soft_node_owned(&target, &merged_bits)
548 }
549
550 (Some(ref new_child), None) => {
551 let prefix = unit.bits.to_bits_owned();
553 let (target, merged_bits) = self.collapse_to_target(new_child, prefix)?;
554 self.put_soft_node_owned(&target, &merged_bits)
555 }
556
557 (Some(ref new_child), Some(sibling)) => {
558 match self.db.get(new_child)? {
560 Some(child_bytes) => {
561 match Node::from_bytes(&child_bytes)? {
562 Node::Soft(Some(inner)) => {
563 let merged = Bits::merge(&unit.bits, &inner.bits);
565 self.put_hard_node_mixed(inner.hash, &merged, sibling)
566 }
567 Node::Hard(_, _) => {
568 let parent_bits = unit.bits.to_bits_owned();
570 self.put_hard_node_mixed(new_child, &parent_bits, sibling)
571 }
572 _ => unreachable!(),
573 }
574 }
575 None => {
576 let parent_bits = unit.bits.to_bits_owned();
578 self.put_hard_node_mixed(new_child, &parent_bits, sibling)
579 }
580 }
581 }
582 }
583 }
584 _ => Ok(None),
585 }
586 }
587
588 pub fn inserts(
591 &mut self,
592 root: Option<&Hash>,
593 keys: &[Hash],
594 leaves: &[Hash],
595 ) -> GenericResult<Option<Hash>> {
596 let indices = get_sorted_indices(keys, false);
597 self.prepare();
598
599 let mut root = root.cloned();
600 for i in indices.iter() {
601 root = self.insert(root.as_ref(), &keys[*i], &leaves[*i])?;
602 }
603
604 self.commit();
605 Ok(root)
606 }
607
608 pub fn gets(&mut self, root: Option<&Hash>, keys: &[Hash]) -> GenericResult<Vec<Option<Hash>>> {
610 let mut leaves: Vec<Option<Hash>> = vec![];
611 for key in keys.iter() {
612 leaves.push(self.get(root, key)?);
613 }
614 Ok(leaves)
615 }
616
617 pub fn removes(&mut self, root: Option<&Hash>, keys: &[Hash]) -> GenericResult<Option<Hash>> {
620 let indices = get_sorted_indices(keys, false);
621 let mut root = root.cloned();
622 self.prepare();
623
624 for i in indices.iter() {
625 root = self.remove(root.as_ref(), &keys[*i])?;
626 }
627
628 self.commit();
629 Ok(root)
630 }
631
632 pub fn get_merkle_proof(
634 &mut self,
635 root: Option<&Hash>,
636 key: &[u8],
637 ) -> GenericResult<Option<Proof>> {
638 let mut proof: Proof = vec![];
639 match root {
640 None => Ok(None),
641 Some(root) => self.gen_proof(root, Bits::new(key), &mut proof),
642 }
643 }
644
645 fn gen_proof(
646 &mut self,
647 root: &[u8],
648 bits: Bits,
649 proof: &mut Proof,
650 ) -> GenericResult<Option<Proof>> {
651 let bytes = self.db.get(root)?.expect("gen_proof(): bytes");
652 let (cell, _) = Node::cells_from_bytes(&bytes, bits.first())?;
653 let unit = cell.as_ref().expect("gen_proof(): left-unit");
654 let n = Bits::len_common_bits(&unit.bits, &bits);
655
656 match n {
657 n if n == bits.len() => {
658 proof.push(self.encode_proof(&bytes, bits.first())?);
659 Ok(Some(proof.to_owned()))
660 }
661 n if n == unit.bits.len() => {
662 proof.push(self.encode_proof(&bytes, bits.first())?);
663 self.gen_proof(unit.hash, bits.drop(n), proof)
664 }
665 _ => Ok(None),
666 }
667 }
668
669 fn encode_proof(&self, bytes: &[u8], right: bool) -> GenericResult<(bool, Vec<u8>)> {
670 match Node::from_bytes(bytes)? {
671 Node::Soft(_) => Ok((false, bytes[HASH_LEN..].to_vec())),
672 Node::Hard(_, _) => {
673 if right {
674 Ok((true, [&bytes[..bytes.len() - HASH_LEN - 1], &[0x01]].concat()))
675 } else {
676 Ok((false, bytes[HASH_LEN..].to_vec()))
677 }
678 }
679 }
680 }
681}
682
683pub fn verify_proof(root: Option<&Hash>, leaf: &Hash, proof: Option<&Proof>) -> bool {
687 match proof {
688 None => false,
689 Some(proof) => {
690 let mut hash = leaf.to_owned();
691 proof.iter().rev().for_each(|(right, cut)| {
692 if *right {
693 let l = cut.len();
694 let o = [&cut[..l - 1], &hash[..], &cut[l - 1..]].concat();
695 hash = Monotree::<MemoryDb>::hash_digest(&o);
696 } else {
697 let o = [&hash[..], &cut[..]].concat();
698 hash = Monotree::<MemoryDb>::hash_digest(&o);
699 }
700 });
701 root.expect("verify_proof(): root") == &hash
702 }
703 }
704}