darkfi/zkas/
analyzer.rs

1/* This file is part of DarkFi (https://dark.fi)
2 *
3 * Copyright (C) 2020-2026 Dyne.org foundation
4 *
5 * This program is free software: you can redistribute it and/or modify
6 * it under the terms of the GNU Affero General Public License as
7 * published by the Free Software Foundation, either version 3 of the
8 * License, or (at your option) any later version.
9 *
10 * This program is distributed in the hope that it will be useful,
11 * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 * GNU Affero General Public License for more details.
14 *
15 * You should have received a copy of the GNU Affero General Public License
16 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
17 */
18
19use std::{
20    io::{stdin, stdout, Read, Result, Write},
21    str::Chars,
22};
23
24use super::{
25    ast::{Arg, Constant, Literal, Statement, StatementType, Var, Variable, Witness},
26    constants::MAX_RECURSION_DEPTH,
27    error::ErrorEmitter,
28    Opcode, VarType,
29};
30
31pub struct Analyzer {
32    pub constants: Vec<Constant>,
33    pub witnesses: Vec<Witness>,
34    pub statements: Vec<Statement>,
35    pub literals: Vec<Literal>,
36    pub heap: Vec<Variable>,
37    error: ErrorEmitter,
38}
39
40impl Analyzer {
41    pub fn new(
42        filename: &str,
43        source: Chars,
44        constants: Vec<Constant>,
45        witnesses: Vec<Witness>,
46        statements: Vec<Statement>,
47    ) -> Self {
48        // For nice error reporting, we'll load everything into a string
49        // vector so we have references to lines.
50        let lines: Vec<String> = source.as_str().lines().map(|x| x.to_string()).collect();
51        let error = ErrorEmitter::new("Semantic", filename, lines);
52
53        Self { constants, witnesses, statements, literals: vec![], heap: vec![], error }
54    }
55
56    pub fn analyze_types(&mut self) -> Result<()> {
57        // To work around the pedantic safety, we'll make new vectors and then
58        // replace the `statements` and `heap` vectors from the `Analyzer`
59        // object when we are done.
60        let mut statements = vec![];
61        let mut heap = vec![];
62
63        let input_statements = self.statements.clone();
64        for statement in &input_statements {
65            //println!("{statement:?}");
66            let mut stmt = statement.clone();
67
68            let (return_types, arg_types) = statement.opcode.arg_types();
69            let mut rhs = vec![];
70
71            // This handling is kinda limiting, but it'll do for now.
72            if !(arg_types[0] == VarType::BaseArray || arg_types[0] == VarType::ScalarArray) {
73                // Check that number of args is correct
74                if statement.rhs.len() != arg_types.len() {
75                    return Err(self.error.abort(
76                        &format!(
77                            "Incorrect number of arguments for statement. Expected {}, got {}.",
78                            arg_types.len(),
79                            statement.rhs.len()
80                        ),
81                        statement.line,
82                        1,
83                    ))
84                }
85            } else {
86                // In case of arrays, check there's at least one element.
87                if statement.rhs.is_empty() {
88                    return Err(self.error.abort(
89                        "Expected at least one element for statement using arrays.",
90                        statement.line,
91                        1,
92                    ))
93                }
94            }
95
96            // Edge-cases for some opcodes
97            #[allow(clippy::single_match)]
98            match &statement.opcode {
99                Opcode::RangeCheck => {
100                    if let Arg::Lit(arg0) = &statement.rhs[0] {
101                        if &arg0.name != "64" && &arg0.name != "253" {
102                            return Err(self.error.abort(
103                                "Supported range checks are only 64 and 253 bits.",
104                                arg0.line,
105                                arg0.column,
106                            ))
107                        }
108                    } else {
109                        return Err(self.error.abort(
110                            "Invalid argument for range_check opcode.",
111                            statement.line,
112                            0,
113                        ))
114                    }
115                }
116
117                _ => {}
118            }
119
120            for (idx, arg) in statement.rhs.iter().enumerate() {
121                // In case an argument is a function call, we will first
122                // convert it to another statement that will get executed
123                // before this one. An important assumption is that this
124                // opcode has a return value. When executed we will push
125                // this value onto the heap and use it as a reference to
126                // the actual statement we're parsing at this moment.
127                // This uses a recursive algorithm to handle arbitrarily
128                // nested functions up to MAX_RECURSION_DEPTH.
129                if let Arg::Func(func) = arg {
130                    let (result_var, nested_statements) = self.process_nested_func(
131                        func, &arg_types, idx, &mut heap, 1, // Start at depth 1
132                    )?;
133
134                    // Add all nested statements to our statement list
135                    statements.extend(nested_statements);
136
137                    // Replace the statement function call with the variable
138                    // from the innermost statement we created.
139                    stmt.rhs[idx] = Arg::Var(result_var.clone());
140                    rhs.push(Arg::Var(result_var));
141
142                    continue
143                } // <-- Arg::Func
144
145                // The literals get pushed on their own "heap", and
146                // then the compiler will reference them by their own
147                // index when it comes to running the statement that
148                // requires the literal type.
149                if let Arg::Lit(v) = arg {
150                    // Match this literal type to a VarType for
151                    // type checking.
152
153                    let var_type = v.typ.to_vartype();
154                    // Validation for Array types
155                    if arg_types[0] == VarType::BaseArray {
156                        if var_type != VarType::Base {
157                            return Err(self.error.abort(
158                                &format!(
159                                    "Incorrect argument type. Expected `{:?}`, got `{var_type:?}`.",
160                                    VarType::Base
161                                ),
162                                v.line,
163                                v.column,
164                            ))
165                        }
166                    } else if arg_types[0] == VarType::ScalarArray && var_type != VarType::Scalar {
167                        return Err(self.error.abort(
168                            &format!(
169                                "Incorrect argument type. Expected `{:?}`, got `{var_type:?}`.",
170                                VarType::Scalar
171                            ),
172                            v.line,
173                            v.column,
174                        ))
175                    }
176                    // Validation for non-Array types
177                    if var_type != arg_types[idx] {
178                        return Err(self.error.abort(
179                            &format!(
180                                "Incorrect argument type. Expected `{:?}`, got `{var_type:?}`.",
181                                arg_types[idx]
182                            ),
183                            v.line,
184                            v.column,
185                        ))
186                    }
187
188                    self.literals.push(v.clone());
189                    rhs.push(Arg::Lit(v.clone()));
190                    continue
191                }
192
193                if let Arg::Var(v) = arg {
194                    // Look up variable and check if type is correct.
195                    if let Some(s_var) = self.lookup_var(&v.name) {
196                        let (var_type, _ln, _col) = match s_var {
197                            Var::Constant(c) => (c.typ, c.line, c.column),
198                            Var::Witness(c) => (c.typ, c.line, c.column),
199                            Var::Variable(c) => (c.typ, c.line, c.column),
200                        };
201
202                        if arg_types[0] == VarType::BaseArray {
203                            if var_type != VarType::Base {
204                                return Err(self.error.abort(
205                                    &format!(
206                                        "Incorrect argument type. Expected `{:?}`, got `{var_type:?}`.",
207                                        VarType::Base
208                                    ),
209                                    v.line,
210                                    v.column,
211                                ))
212                            }
213                        } else if arg_types[0] == VarType::ScalarArray {
214                            if var_type != VarType::Scalar {
215                                return Err(self.error.abort(
216                                    &format!(
217                                        "Incorrect argument type. Expected `{:?}`, got `{var_type:?}`.",
218                                        VarType::Scalar
219                                    ),
220                                    v.line,
221                                    v.column,
222                                ))
223                            }
224                        } else if var_type != arg_types[idx] && arg_types[idx] != VarType::Any {
225                            return Err(self.error.abort(
226                                &format!(
227                                    "Incorrect argument type. Expected `{:?}`, got `{var_type:?}`.",
228                                    arg_types[idx]
229                                ),
230                                v.line,
231                                v.column,
232                            ))
233                        }
234
235                        // Replace Dummy type with correct type.
236                        let mut v_new = v.clone();
237                        v_new.typ = var_type;
238                        rhs.push(Arg::Var(v_new));
239                        continue
240                    }
241
242                    return Err(self.error.abort(
243                        &format!("Unknown variable reference `{}`.", v.name),
244                        v.line,
245                        v.column,
246                    ))
247                }
248            } // <-- statement.rhs.iter().enumerate()
249
250            // We now type-checked and assigned types to the statement rhs,
251            // so now we apply it to the statement.
252            stmt.rhs = rhs;
253
254            // In case this statement is an assignment, we will push its
255            // result on the heap.
256            if statement.typ == StatementType::Assign {
257                let mut var = statement.lhs.clone().unwrap();
258                // Since we are doing an assignment, ensure that there is a return type.
259                if return_types.is_empty() {
260                    return Err(self.error.abort(
261                        "Cannot perform assignment without a return type",
262                        var.line,
263                        var.column,
264                    ))
265                }
266                var.typ = return_types[0];
267                stmt.lhs = Some(var.clone());
268                heap.push(var.clone());
269                self.heap.clone_from(&heap);
270            }
271
272            //println!("{stmt:#?}");
273            statements.push(stmt);
274        } // <-- for statement in &self.statements
275
276        // Here we replace the self.statements and self.heap with what we
277        // built so far. These can be used later on by the compiler after
278        // this function is finished.
279        self.statements = statements;
280        self.heap = heap;
281
282        //println!("=================STATEMENTS===============\n{:#?}", self.statements);
283        //println!("====================HEAP==================\n{:#?}", self.heap);
284        //println!("==================LITERALS================\n{:#?}", self.literals);
285
286        Ok(())
287    }
288
289    /// Recursively process a nested function call.
290    /// Returns the result Variable and a Vec of Statements that need to be executed.
291    fn process_nested_func(
292        &mut self,
293        func: &Statement,
294        parent_arg_types: &[VarType],
295        parent_arg_idx: usize,
296        heap: &mut Vec<Variable>,
297        depth: usize,
298    ) -> Result<(Variable, Vec<Statement>)> {
299        if depth > MAX_RECURSION_DEPTH {
300            return Err(self.error.abort(
301                &format!(
302                    "Maximum recursion depth of {} exceeded for nested function calls.",
303                    MAX_RECURSION_DEPTH
304                ),
305                func.line,
306                0,
307            ))
308        }
309
310        let (f_return_types, f_arg_types) = func.opcode.arg_types();
311
312        if f_return_types.is_empty() {
313            return Err(self.error.abort(
314                &format!(
315                    "Used a function argument which doesn't have a return value: {:?}",
316                    func.opcode
317                ),
318                func.line,
319                1,
320            ))
321        }
322
323        // Create the result variable for this function call
324        let result_var = Variable {
325            name: func.lhs.clone().unwrap().name,
326            typ: f_return_types[0],
327            line: func.lhs.clone().unwrap().line,
328            column: func.lhs.clone().unwrap().column,
329        };
330
331        // Validate return type against parent's expected type
332        if parent_arg_types[0] == VarType::BaseArray {
333            if f_return_types[0] != VarType::Base {
334                return Err(self.error.abort(
335                    &format!(
336                        "Function passed as argument returns wrong type. Expected `{:?}`, got `{:?}`.",
337                        VarType::Base,
338                        f_return_types[0],
339                    ),
340                    result_var.line,
341                    result_var.column,
342                ))
343            }
344        } else if parent_arg_types[0] == VarType::ScalarArray {
345            if f_return_types[0] != VarType::Scalar {
346                return Err(self.error.abort(
347                    &format!(
348                        "Function passed as argument returns wrong type. Expected `{:?}`, got `{:?}`.",
349                        VarType::Scalar,
350                        f_return_types[0],
351                    ),
352                    result_var.line,
353                    result_var.column,
354                ))
355            }
356        } else if f_return_types[0] != parent_arg_types[parent_arg_idx] {
357            return Err(self.error.abort(
358                &format!(
359                    "Function passed as argument returns wrong type. Expected `{:?}`, got `{:?}`.",
360                    parent_arg_types[parent_arg_idx], f_return_types[0],
361                ),
362                result_var.line,
363                result_var.column,
364            ))
365        }
366
367        // Collect all statements that need to be generated
368        let mut nested_statements = vec![];
369        let mut rhs_inner = vec![];
370
371        // Process each argument of this nested function
372        for (inner_idx, arg) in func.rhs.iter().enumerate() {
373            match arg {
374                Arg::Var(v) => {
375                    if let Some(var_ref) = self.lookup_var(&v.name) {
376                        let (var_type, ln, col) = match var_ref {
377                            Var::Constant(c) => (c.typ, c.line, c.column),
378                            Var::Witness(c) => (c.typ, c.line, c.column),
379                            Var::Variable(c) => (c.typ, c.line, c.column),
380                        };
381
382                        // Type checking for array types
383                        if f_arg_types[0] == VarType::BaseArray {
384                            if var_type != VarType::Base {
385                                return Err(self.error.abort(
386                                    &format!(
387                                        "Incorrect argument type. Expected `{:?}`, got `{var_type:?}`.",
388                                        VarType::Base
389                                    ),
390                                    ln,
391                                    col,
392                                ))
393                            }
394                        } else if f_arg_types[0] == VarType::ScalarArray {
395                            if var_type != VarType::Scalar {
396                                return Err(self.error.abort(
397                                    &format!(
398                                        "Incorrect argument type. Expected `{:?}`, got `{var_type:?}`.",
399                                        VarType::Scalar
400                                    ),
401                                    ln,
402                                    col,
403                                ))
404                            }
405                        } else if var_type != f_arg_types[inner_idx] &&
406                            f_arg_types[inner_idx] != VarType::Any
407                        {
408                            return Err(self.error.abort(
409                                &format!(
410                                    "Incorrect argument type. Expected `{:?}`, got `{var_type:?}`.",
411                                    f_arg_types[inner_idx]
412                                ),
413                                ln,
414                                col,
415                            ))
416                        }
417
418                        // Apply the proper type
419                        let mut v_new = v.clone();
420                        v_new.typ = var_type;
421                        rhs_inner.push(Arg::Var(v_new));
422                    } else {
423                        return Err(self.error.abort(
424                            &format!("Unknown variable reference `{}`.", v.name),
425                            v.line,
426                            v.column,
427                        ))
428                    }
429                }
430
431                Arg::Lit(lit) => {
432                    let var_type = lit.typ.to_vartype();
433
434                    // Type checking for array types
435                    if f_arg_types[0] == VarType::BaseArray {
436                        if var_type != VarType::Base {
437                            return Err(self.error.abort(
438                                &format!(
439                                    "Incorrect argument type. Expected `{:?}`, got `{var_type:?}`.",
440                                    VarType::Base
441                                ),
442                                lit.line,
443                                lit.column,
444                            ))
445                        }
446                    } else if f_arg_types[0] == VarType::ScalarArray {
447                        if var_type != VarType::Scalar {
448                            return Err(self.error.abort(
449                                &format!(
450                                    "Incorrect argument type. Expected `{:?}`, got `{var_type:?}`.",
451                                    VarType::Scalar
452                                ),
453                                lit.line,
454                                lit.column,
455                            ))
456                        }
457                    } else if var_type != f_arg_types[inner_idx] {
458                        return Err(self.error.abort(
459                            &format!(
460                                "Incorrect argument type. Expected `{:?}`, got `{var_type:?}`.",
461                                f_arg_types[inner_idx]
462                            ),
463                            lit.line,
464                            lit.column,
465                        ))
466                    }
467
468                    self.literals.push(lit.clone());
469                    rhs_inner.push(Arg::Lit(lit.clone()));
470                }
471
472                Arg::Func(inner_func) => {
473                    // Recursively process the inner function
474                    let (inner_result_var, inner_statements) = self.process_nested_func(
475                        inner_func,
476                        &f_arg_types,
477                        inner_idx,
478                        heap,
479                        depth + 1,
480                    )?;
481
482                    // Add inner statements first (they need to execute before this one)
483                    nested_statements.extend(inner_statements);
484
485                    // Use the result variable as an argument
486                    rhs_inner.push(Arg::Var(inner_result_var));
487                }
488            }
489        }
490
491        // Create the statement for this function call
492        let stmt = Statement {
493            typ: func.typ,
494            opcode: func.opcode,
495            lhs: Some(result_var.clone()),
496            rhs: rhs_inner,
497            line: func.line,
498        };
499
500        // Add this statement to the list
501        nested_statements.push(stmt);
502
503        // Push the result variable onto the heap
504        heap.push(result_var.clone());
505        self.heap.clone_from(heap);
506
507        Ok((result_var, nested_statements))
508    }
509
510    fn lookup_var(&self, name: &str) -> Option<Var> {
511        if let Some(r) = self.lookup_constant(name) {
512            return Some(Var::Constant(r))
513        }
514
515        if let Some(r) = self.lookup_witness(name) {
516            return Some(Var::Witness(r))
517        }
518
519        if let Some(r) = self.lookup_heap(name) {
520            return Some(Var::Variable(r))
521        }
522
523        None
524    }
525
526    fn lookup_constant(&self, name: &str) -> Option<Constant> {
527        for i in &self.constants {
528            if i.name == name {
529                return Some(i.clone())
530            }
531        }
532
533        None
534    }
535
536    fn lookup_witness(&self, name: &str) -> Option<Witness> {
537        for i in &self.witnesses {
538            if i.name == name {
539                return Some(i.clone())
540            }
541        }
542
543        None
544    }
545
546    fn lookup_heap(&self, name: &str) -> Option<Variable> {
547        for i in &self.heap {
548            if i.name == name {
549                return Some(i.clone())
550            }
551        }
552
553        None
554    }
555
556    pub fn analyze_semantic(&mut self) -> Result<()> {
557        let mut heap = vec![];
558
559        println!("Loading constants...\n-----");
560        for i in &self.constants {
561            println!("Adding `{}` to heap", i.name);
562            heap.push(&i.name);
563            Analyzer::pause();
564        }
565        println!("Heap:\n{heap:#?}\n-----");
566        println!("Loading witnesses...\n-----");
567        for i in &self.witnesses {
568            println!("Adding `{}` to heap", i.name);
569            heap.push(&i.name);
570            Analyzer::pause();
571        }
572        println!("Heap:\n{heap:#?}\n-----");
573        println!("Loading circuit...");
574        for i in &self.statements {
575            let mut argnames = vec![];
576            for arg in &i.rhs {
577                if let Arg::Var(arg) = arg {
578                    argnames.push(arg.name.clone());
579                } else if let Arg::Lit(lit) = arg {
580                    argnames.push(lit.name.clone());
581                } else {
582                    unreachable!()
583                }
584            }
585            println!("Executing: {:?}({argnames:?})", i.opcode);
586
587            Analyzer::pause();
588
589            for arg in &i.rhs {
590                if let Arg::Var(arg) = arg {
591                    print!("Looking up `{}` on the heap... ", arg.name);
592                    if let Some(index) = heap.iter().position(|&r| r == &arg.name) {
593                        println!("Found at heap index {index}");
594                    } else {
595                        return Err(self.error.abort(
596                            &format!("Could not find `{}` on the heap", arg.name),
597                            arg.line,
598                            arg.column,
599                        ))
600                    }
601                } else if let Arg::Lit(lit) = arg {
602                    println!("Using literal `{}`", lit.name);
603                } else {
604                    println!("{arg:#?}");
605                    unreachable!();
606                }
607
608                Analyzer::pause();
609            }
610            match i.typ {
611                StatementType::Assign => {
612                    println!("Pushing result as `{}` to heap", &i.lhs.as_ref().unwrap().name);
613                    heap.push(&i.lhs.as_ref().unwrap().name);
614                    println!("Heap:\n{heap:#?}\n-----");
615                }
616                StatementType::Call => {
617                    println!("-----");
618                }
619                _ => unreachable!(),
620            }
621        }
622
623        Ok(())
624    }
625
626    fn pause() {
627        let msg = b"[Press Enter to continue]\r";
628        let mut stdout = stdout();
629        let _ = stdout.write(msg).unwrap();
630        stdout.flush().unwrap();
631        let _ = stdin().read(&mut [0]).unwrap();
632        write!(stdout, "\x1b[1A\r\x1b[K\r").unwrap();
633    }
634}