darkfi/zkas/
decoder.rs

1/* This file is part of DarkFi (https://dark.fi)
2 *
3 * Copyright (C) 2020-2026 Dyne.org foundation
4 *
5 * This program is free software: you can redistribute it and/or modify
6 * it under the terms of the GNU Affero General Public License as
7 * published by the Free Software Foundation, either version 3 of the
8 * License, or (at your option) any later version.
9 *
10 * This program is distributed in the hope that it will be useful,
11 * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 * GNU Affero General Public License for more details.
14 *
15 * You should have received a copy of the GNU Affero General Public License
16 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
17 */
18
19use darkfi_serial::{deserialize_limited_partial, deserialize_partial, VarInt};
20
21use super::{
22    compiler::MAGIC_BYTES,
23    constants::{
24        MAX_ARGS_PER_OPCODE, MAX_BIN_SIZE, MAX_CONSTANTS, MAX_HEAP_SIZE, MAX_K, MAX_LITERALS,
25        MAX_NS_LEN, MAX_OPCODES, MAX_STRING_LEN, MAX_WITNESSES, MIN_BIN_SIZE, SECTION_CIRCUIT,
26        SECTION_CONSTANT, SECTION_DEBUG, SECTION_LITERAL, SECTION_WITNESS,
27    },
28    types::HeapType,
29    LitType, Opcode, VarType,
30};
31use crate::{Error::ZkasDecoderError as ZkasErr, Result};
32
/// A ZkBinary decoded from compiled zkas code.
/// This is used by the zkvm.
///
/// The binary format consists of:
/// - Header: magic bytes, version, k param, namespace
/// - `.constant` section: constant types and names
/// - `.literal` section: literal types and values
/// - `.witness` section: witness types
/// - `.circuit` section: opcodes and their arguments
/// - `.debug` section (optional): debug information
#[derive(Clone, Debug)]
// ANCHOR: zkbinary-struct
pub struct ZkBinary {
    /// Namespace string read from the header
    pub namespace: String,
    /// The `k` parameter read from the header
    pub k: u32,
    /// Constants as (type, name) pairs, in section order
    pub constants: Vec<(VarType, String)>,
    /// Literals as (type, value string) pairs, in section order
    pub literals: Vec<(LitType, String)>,
    /// Witness types, in section order
    pub witnesses: Vec<VarType>,
    /// Opcodes, each with its arguments as (heap type, index) pairs
    pub opcodes: Vec<(Opcode, Vec<(HeapType, usize)>)>,
    /// Decoded `.debug` section; `None` if it was absent or not requested
    pub debug_info: Option<DebugInfo>,
}
// ANCHOR_END: zkbinary-struct
55
/// Debug information decoded from the optional `.debug` section.
/// Contains source mappings to help debug circuit failures.
#[derive(Clone, Debug, Default)]
pub struct DebugInfo {
    /// Source locations (line, col) for each opcode
    pub opcode_locations: Vec<(usize, usize)>,
    /// Variable names for each heap entry (constants, witnesses, assigned vars in order)
    pub heap_names: Vec<String>,
    /// Literal values as strings
    pub literal_names: Vec<String>,
}
67
// https://stackoverflow.com/questions/35901547/how-can-i-find-a-subsequence-in-a-u8-slice
/// Return the offset of the first occurrence of `needle` in `haystack`,
/// or `None` if it does not occur.
///
/// An empty `needle` matches at offset 0, mirroring `str::find("")`.
fn find_subslice(haystack: &[u8], needle: &[u8]) -> Option<usize> {
    // `slice::windows` panics on a window size of zero, so the empty
    // needle has to be handled before reaching it.
    if needle.is_empty() {
        return Some(0)
    }

    haystack.windows(needle.len()).position(|window| window == needle)
}
72
73fn find_section(bytes: &[u8], section: &[u8]) -> Result<usize> {
74    find_subslice(bytes, section).ok_or_else(|| {
75        ZkasErr(format!("Could not find {} section", String::from_utf8_lossy(section)))
76    })
77}
78
79/// Validate that a count is within limits and reasonable for the remaining bytes
80fn validate_count(
81    count: u64,
82    max: usize,
83    remaining_bytes: usize,
84    item_name: &str,
85) -> Result<usize> {
86    let count = count as usize;
87
88    if count > max {
89        return Err(ZkasErr(format!(
90            "{} count {} exceeds maximum allowed {}",
91            item_name, count, max
92        )));
93    }
94
95    // Sanity check: each item needs at least 1 byte
96    if count > remaining_bytes {
97        return Err(ZkasErr(format!(
98            "{} count {} exceeds remaining bytes {}",
99            item_name, count, remaining_bytes
100        )));
101    }
102
103    Ok(count)
104}
105
/// Byte offsets at which each section marker starts within the binary.
struct SectionOffsets {
    /// Offset of the `.constant` section marker
    constant: usize,
    /// Offset of the `.literal` section marker
    literal: usize,
    /// Offset of the `.witness` section marker
    witness: usize,
    /// Offset of the `.circuit` section marker
    circuit: usize,
    /// Offset of the `.debug` section marker, or the input length if absent
    debug: usize,
}
113
114impl SectionOffsets {
115    /// Find all section offsets in the binary and validate their order
116    fn find(bytes: &[u8]) -> Result<Self> {
117        let constant = find_section(bytes, SECTION_CONSTANT)?;
118        let literal = find_section(bytes, SECTION_LITERAL)?;
119        let witness = find_section(bytes, SECTION_WITNESS)?;
120        let circuit = find_section(bytes, SECTION_CIRCUIT)?;
121        // Debug section is optional, so use end of bytes if not present
122        let debug = find_subslice(bytes, SECTION_DEBUG).unwrap_or(bytes.len());
123
124        // Validate section order
125        let sections = [
126            (constant, ".constant"),
127            (literal, ".literal"),
128            (witness, ".witness"),
129            (circuit, ".circuit"),
130            (debug, "debug/EOF"),
131        ];
132
133        for i in 0..sections.len() - 1 {
134            if sections[i].0 > sections[i + 1].0 {
135                return Err(ZkasErr(format!(
136                    "{} section appeared before {}",
137                    sections[i + 1].1,
138                    sections[i].1
139                )));
140            }
141        }
142
143        Ok(Self { constant, literal, witness, circuit, debug })
144    }
145
146    /// Extract the bytes for the constant section
147    fn constant_bytes<'a>(&self, bytes: &'a [u8]) -> &'a [u8] {
148        &bytes[self.constant + SECTION_CONSTANT.len()..self.literal]
149    }
150
151    /// Extract the bytes for the literal section
152    fn literal_bytes<'a>(&self, bytes: &'a [u8]) -> &'a [u8] {
153        &bytes[self.literal + SECTION_LITERAL.len()..self.witness]
154    }
155
156    /// Extract the bytes for the witness section
157    fn witness_bytes<'a>(&self, bytes: &'a [u8]) -> &'a [u8] {
158        &bytes[self.witness + SECTION_WITNESS.len()..self.circuit]
159    }
160
161    /// Extract the bytes for the circuit section
162    fn circuit_bytes<'a>(&self, bytes: &'a [u8]) -> &'a [u8] {
163        &bytes[self.circuit + SECTION_CIRCUIT.len()..self.debug]
164    }
165
166    /// Extract the bytes for the debug section if present
167    fn debug_bytes<'a>(&self, bytes: &'a [u8]) -> Option<&'a [u8]> {
168        if self.debug < bytes.len() {
169            Some(&bytes[self.debug + SECTION_DEBUG.len()..])
170        } else {
171            None
172        }
173    }
174}
175
176impl ZkBinary {
177    /// Decode a ZkBinary from compiled bytes
178    pub fn decode(bytes: &[u8], decode_debug_symbols: bool) -> Result<Self> {
179        // Ensure that bytes is a certain minimum length. Otherwise the code
180        // below will panic due to an index out of bounds error.
181        if bytes.len() < MIN_BIN_SIZE {
182            return Err(ZkasErr("Not enough bytes".to_string()))
183        }
184
185        // Check max size to prevent decoding maliciously large binaries
186        if bytes.len() > MAX_BIN_SIZE {
187            return Err(ZkasErr(format!(
188                "Binary size {} exceeds maximum allowed {}",
189                bytes.len(),
190                MAX_BIN_SIZE
191            )))
192        }
193
194        let magic_bytes = &bytes[0..4];
195        if magic_bytes != MAGIC_BYTES {
196            return Err(ZkasErr("Magic bytes are incorrect".to_string()))
197        }
198
199        let _binary_version = &bytes[4];
200
201        // Deserialize the k param
202        let (k, _): (u32, _) = deserialize_partial(&bytes[5..9])?;
203
204        // For now, we'll limit k.
205        if k > MAX_K {
206            return Err(ZkasErr(format!("k param is too high, max allowed is {MAX_K}")))
207        }
208
209        // After the binary version and k, we're supposed to have the witness namespace
210        let (namespace, _) = deserialize_limited_partial::<String>(&bytes[9..], MAX_NS_LEN)?;
211
212        // ===============
213        // Section parsing
214        // ===============
215        let offsets = SectionOffsets::find(bytes)?;
216
217        let constants = Self::parse_constants(offsets.constant_bytes(bytes))?;
218        let literals = Self::parse_literals(offsets.literal_bytes(bytes))?;
219        let witnesses = Self::parse_witnesses(offsets.witness_bytes(bytes))?;
220        let opcodes = Self::parse_circuit(offsets.circuit_bytes(bytes))?;
221
222        let mut debug_info = None;
223        if decode_debug_symbols {
224            debug_info = match offsets.debug_bytes(bytes) {
225                Some(debug_bytes) => Some(Self::parse_debug(debug_bytes)?),
226                None => None,
227            };
228        }
229
230        let binary = Self { namespace, k, constants, literals, witnesses, opcodes, debug_info };
231
232        // Validate cross-references between sections
233        binary.validate()?;
234
235        Ok(binary)
236    }
237
238    /// Validate cross-references and consistency between sections.
239    /// This catches malicious binaries that pass individual section
240    /// parsing but have invalid references.
241    fn validate(&self) -> Result<()> {
242        // Calculate actual heap size: constants + witnesses + assigned vars
243        // Each opcode that produces a result adds one entry to the heap
244        let num_assignments = self
245            .opcodes
246            .iter()
247            .filter(|(op, _)| {
248                let (ret_types, _) = op.arg_types();
249                !ret_types.is_empty()
250            })
251            .count();
252
253        let heap_size = self.constants.len() + self.witnesses.len() + num_assignments;
254
255        // Validate all heap references in opcodes
256        for (op_idx, (opcode, args)) in self.opcodes.iter().enumerate() {
257            // Calculate heap size at this point in execution
258            // (constants + witnesses + results from previous opcodes)
259            let prev_assignments = self.opcodes[..op_idx]
260                .iter()
261                .filter(|(op, _)| {
262                    let (ret_types, _) = op.arg_types();
263                    !ret_types.is_empty()
264                })
265                .count();
266            let available_heap = self.constants.len() + self.witnesses.len() + prev_assignments;
267
268            for (heap_type, heap_idx) in args {
269                match heap_type {
270                    HeapType::Var => {
271                        if *heap_idx >= available_heap {
272                            return Err(ZkasErr(format!(
273                                "Opcode {} references heap idx {} but only {} entries available",
274                                opcode.name(),
275                                heap_idx,
276                                available_heap
277                            )));
278                        }
279                    }
280                    HeapType::Lit => {
281                        if *heap_idx >= self.literals.len() {
282                            return Err(ZkasErr(format!(
283                                "Opcode {} references literal idx {} but only {} literals exist",
284                                opcode.name(),
285                                heap_idx,
286                                self.literals.len()
287                            )));
288                        }
289                    }
290                }
291            }
292        }
293        // Validate debug info consistency if present
294        if let Some(ref debug) = self.debug_info {
295            if debug.opcode_locations.len() != self.opcodes.len() {
296                return Err(ZkasErr(format!(
297                    "Debug info has {} opcode locations but circuit has {} opcodes",
298                    debug.opcode_locations.len(),
299                    self.opcodes.len()
300                )));
301            }
302
303            if debug.heap_names.len() != heap_size {
304                return Err(ZkasErr(format!(
305                    "Debug info has {} heap names but heap has {} entries",
306                    debug.heap_names.len(),
307                    heap_size
308                )));
309            }
310
311            if debug.literal_names.len() != self.literals.len() {
312                return Err(ZkasErr(format!(
313                    "Debug info has {} literal names but {} literals exist",
314                    debug.literal_names.len(),
315                    self.literals.len()
316                )));
317            }
318        }
319
320        Ok(())
321    }
322
323    fn parse_constants(bytes: &[u8]) -> Result<Vec<(VarType, String)>> {
324        let mut constants = vec![];
325        let mut offset = 0;
326
327        while offset < bytes.len() {
328            // Check we haven't exceeded the limit
329            if constants.len() >= MAX_CONSTANTS {
330                return Err(ZkasErr(format!(
331                    "Too many constants, maximum allowed is {MAX_CONSTANTS}"
332                )))
333            }
334
335            let c_type = VarType::from_repr(bytes[offset]).ok_or_else(|| {
336                ZkasErr(format!("Could not decode constant VarType from {}", bytes[offset]))
337            })?;
338            offset += 1;
339
340            let (name, len) =
341                deserialize_limited_partial::<String>(&bytes[offset..], MAX_STRING_LEN)?;
342            offset += len;
343
344            constants.push((c_type, name));
345        }
346
347        Ok(constants)
348    }
349
350    fn parse_literals(bytes: &[u8]) -> Result<Vec<(LitType, String)>> {
351        let mut literals = vec![];
352        let mut offset = 0;
353
354        while offset < bytes.len() {
355            // Check we haven't exceeded the limit
356            if literals.len() >= MAX_LITERALS {
357                return Err(ZkasErr(format!(
358                    "Too many literals, maximum allowed is {MAX_LITERALS}"
359                )));
360            }
361
362            let l_type = LitType::from_repr(bytes[offset]).ok_or_else(|| {
363                ZkasErr(format!("Could not decode literal LitType from {}", bytes[offset]))
364            })?;
365            offset += 1;
366
367            let (name, len) =
368                deserialize_limited_partial::<String>(&bytes[offset..], MAX_STRING_LEN)?;
369            offset += len;
370
371            literals.push((l_type, name));
372        }
373
374        Ok(literals)
375    }
376
377    fn parse_witnesses(bytes: &[u8]) -> Result<Vec<VarType>> {
378        // Check vount before allocating
379        if bytes.len() > MAX_WITNESSES {
380            return Err(ZkasErr(format!(
381                "Too many witnesses ({}), maximum allowed is {}",
382                bytes.len(),
383                MAX_WITNESSES
384            )));
385        }
386
387        let mut witnesses = Vec::with_capacity(bytes.len());
388
389        for &byte in bytes {
390            let w_type = VarType::from_repr(byte).ok_or_else(|| {
391                ZkasErr(format!("Could not decode witness VarType from {}", byte))
392            })?;
393
394            witnesses.push(w_type);
395        }
396
397        Ok(witnesses)
398    }
399
400    #[allow(clippy::type_complexity)]
401    fn parse_circuit(bytes: &[u8]) -> Result<Vec<(Opcode, Vec<(HeapType, usize)>)>> {
402        let mut opcodes = vec![];
403        let mut offset = 0;
404
405        while offset < bytes.len() {
406            // Check opcode count limit
407            if opcodes.len() >= MAX_OPCODES {
408                return Err(ZkasErr(format!("Too many opcodes, maximum allowed is {MAX_OPCODES}")))
409            }
410
411            let opcode = Opcode::from_repr(bytes[offset]).ok_or_else(|| {
412                ZkasErr(format!("Could not decode Opcode from {}", bytes[offset]))
413            })?;
414            offset += 1;
415
416            // TODO: Check that the types and arg number are correct
417
418            // Parse argument count
419            let (arg_count, len) = deserialize_partial::<VarInt>(&bytes[offset..])?;
420            offset += len;
421
422            // Validate argument count
423            let arg_count =
424                validate_count(arg_count.0, MAX_ARGS_PER_OPCODE, bytes.len() - offset, "Argument")?;
425
426            // Parse arguments
427            let mut args = Vec::with_capacity(arg_count);
428            for _ in 0..arg_count {
429                // Check bounds to prevent panics
430                if offset >= bytes.len() {
431                    return Err(ZkasErr(format!(
432                        "Bad offset for circuit: offset {} is >= circuit len {}",
433                        offset,
434                        bytes.len()
435                    )));
436                }
437
438                let heap_type_byte = bytes[offset];
439                offset += 1;
440
441                if offset >= bytes.len() {
442                    return Err(ZkasErr(format!(
443                        "Bad offset for circuit: offset {} is >= circuit len {}",
444                        offset,
445                        bytes.len()
446                    )));
447                }
448
449                let (heap_index, len) = deserialize_partial::<VarInt>(&bytes[offset..])?;
450                offset += len;
451
452                let heap_type = HeapType::from_repr(heap_type_byte).ok_or_else(|| {
453                    ZkasErr(format!("Could not decode HeapType from {}", heap_type_byte))
454                })?;
455
456                // Validate heap index is reasonable
457                let heap_idx = heap_index.0 as usize;
458                if heap_idx > MAX_HEAP_SIZE {
459                    return Err(ZkasErr(format!(
460                        "Heap index {} exceeds maximum allowed {}",
461                        heap_idx, MAX_HEAP_SIZE
462                    )));
463                }
464
465                args.push((heap_type, heap_index.0 as usize));
466            }
467
468            opcodes.push((opcode, args));
469        }
470
471        Ok(opcodes)
472    }
473
474    fn parse_debug(bytes: &[u8]) -> Result<DebugInfo> {
475        let mut offset = 0;
476
477        // Parse opcode source locations
478        let (num_opcodes, len) = deserialize_partial::<VarInt>(&bytes[offset..])?;
479        offset += len;
480
481        let num_opcodes =
482            validate_count(num_opcodes.0, MAX_OPCODES, bytes.len() - offset, "Debug opcode")?;
483
484        let mut opcode_locations = Vec::with_capacity(num_opcodes);
485        for _ in 0..num_opcodes {
486            let (line, len) = deserialize_partial::<VarInt>(&bytes[offset..])?;
487            offset += len;
488            let (column, len) = deserialize_partial::<VarInt>(&bytes[offset..])?;
489            offset += len;
490            opcode_locations.push((line.0 as usize, column.0 as usize));
491        }
492
493        // Parse heap var names
494        let (heap_size, len) = deserialize_partial::<VarInt>(&bytes[offset..])?;
495        offset += len;
496
497        let heap_size =
498            validate_count(heap_size.0, MAX_HEAP_SIZE, bytes.len() - offset, "Debug heap")?;
499
500        let mut heap_names = Vec::with_capacity(heap_size);
501        for _ in 0..heap_size {
502            let (name, len) =
503                deserialize_limited_partial::<String>(&bytes[offset..], MAX_STRING_LEN)?;
504            offset += len;
505            heap_names.push(name);
506        }
507
508        // Parse literal names
509        let (num_literals, len) = deserialize_partial::<VarInt>(&bytes[offset..])?;
510        offset += len;
511
512        let num_literals =
513            validate_count(num_literals.0, MAX_LITERALS, bytes.len() - offset, "Debug literal")?;
514
515        let mut literal_names = Vec::with_capacity(num_literals);
516        for _ in 0..num_literals {
517            let (name, len) =
518                deserialize_limited_partial::<String>(&bytes[offset..], MAX_STRING_LEN)?;
519            offset += len;
520            literal_names.push(name);
521        }
522
523        Ok(DebugInfo { opcode_locations, heap_names, literal_names })
524    }
525
526    /// Get the source location (line, column) for a given opcode index.
527    /// Returns `None` if debug info is not present or index is OOB.
528    pub fn opcode_location(&self, opcode_idx: usize) -> Option<(usize, usize)> {
529        self.debug_info.as_ref()?.opcode_locations.get(opcode_idx).copied()
530    }
531
532    /// Get the variable name for a given heap index.
533    /// Returns `None` if debug info is not present or index is OOB.
534    pub fn heap_name(&self, heap_idx: usize) -> Option<&str> {
535        self.debug_info.as_ref()?.heap_names.get(heap_idx).map(|s| s.as_str())
536    }
537
538    /// Get the literal name/value for a given literal index.
539    /// Returns `None` if debug info is not present or index is OOB.
540    pub fn literal_name(&self, literal_idx: usize) -> Option<&str> {
541        self.debug_info.as_ref()?.literal_names.get(literal_idx).map(|s| s.as_str())
542    }
543
    /// Check if debug info is present.
    ///
    /// Returns `true` when `debug_info` is `Some`.
    pub fn has_debug_info(&self) -> bool {
        self.debug_info.is_some()
    }
548}
549
// Regression tests for inputs that previously made the decoder panic.
// The decode result is deliberately ignored in each test: the test passes
// as long as `ZkBinary::decode` returns without panicking.
#[cfg(test)]
mod tests {
    use crate::zkas::ZkBinary;

    #[test]
    fn panic_regression_001() {
        // Out-of-memory panic from string deserialization.
        // Read `doc/src/zkas/bincode.md` to understand the input.
        let data = vec![11u8, 1, 177, 53, 1, 0, 0, 0, 0, 255, 0, 204, 200, 72, 72, 72, 72, 1];
        // Only the absence of a panic matters here, not the returned value.
        let _dec = ZkBinary::decode(&data, true);
    }

    #[test]
    fn panic_regression_002() {
        // Index out of bounds panic in parse_circuit().
        // Read `doc/src/zkas/bincode.md` to understand the input.
        let data = vec![
            11u8, 1, 177, 53, 2, 13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 83, 105,
            109, 112, 108, 101, 46, 99, 111, 110, 115, 116, 97, 110, 116, 3, 18, 86, 65, 76, 85,
            69, 95, 67, 79, 77, 77, 73, 84, 95, 86, 65, 76, 85, 69, 2, 19, 86, 65, 76, 85, 69, 95,
            67, 79, 77, 77, 73, 84, 95, 82, 65, 77, 68, 79, 77, 46, 108, 105, 116, 101, 114, 97,
            108, 46, 119, 105, 116, 110, 101, 115, 115, 16, 18, 46, 99, 105, 114, 99, 117, 105,
            116, 4, 2, 0, 2, 0, 0, 2, 2, 0, 3, 0, 1, 8, 2, 0, 4, 0, 5, 8, 1, 0, 6, 9, 1, 0, 6, 240,
            1, 0, 7, 240, 41, 0, 0, 0, 1, 0, 8,
        ];
        // Only the absence of a panic matters here, not the returned value.
        let _dec = ZkBinary::decode(&data, true);
    }
}