darkfi/zkas/analyzer.rs
/* This file is part of DarkFi (https://dark.fi)
 *
 * Copyright (C) 2020-2026 Dyne.org foundation
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as
 * published by the Free Software Foundation, either version 3 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with this program. If not, see <https://www.gnu.org/licenses/>.
 */
18
19use std::{
20 io::{stdin, stdout, Read, Result, Write},
21 str::Chars,
22};
23
24use super::{
25 ast::{Arg, Constant, Literal, Statement, StatementType, Var, Variable, Witness},
26 constants::MAX_RECURSION_DEPTH,
27 error::ErrorEmitter,
28 Opcode, VarType,
29};
30
/// Semantic analyzer over the parsed sections (constants, witnesses and
/// statements) of a zkas source file.
///
/// Type analysis fills in `literals` and `heap` and rewrites `statements`.
pub struct Analyzer {
    /// Declared constants; searched first when resolving a variable name.
    pub constants: Vec<Constant>,
    /// Declared witnesses; searched after constants when resolving a name.
    pub witnesses: Vec<Witness>,
    /// Circuit statements. Type analysis replaces these with type-annotated
    /// statements in which nested function calls have been hoisted into
    /// statements of their own.
    pub statements: Vec<Statement>,
    /// Literal arguments, in the order type analysis encountered them.
    pub literals: Vec<Literal>,
    /// Variables produced by assignments and by hoisted nested calls, in
    /// the order they were defined; searched after witnesses.
    pub heap: Vec<Variable>,
    /// Emitter used to build "Semantic" errors carrying a line and column.
    error: ErrorEmitter,
}
39
40impl Analyzer {
41 pub fn new(
42 filename: &str,
43 source: Chars,
44 constants: Vec<Constant>,
45 witnesses: Vec<Witness>,
46 statements: Vec<Statement>,
47 ) -> Self {
48 // For nice error reporting, we'll load everything into a string
49 // vector so we have references to lines.
50 let lines: Vec<String> = source.as_str().lines().map(|x| x.to_string()).collect();
51 let error = ErrorEmitter::new("Semantic", filename, lines);
52
53 Self { constants, witnesses, statements, literals: vec![], heap: vec![], error }
54 }
55
56 pub fn analyze_types(&mut self) -> Result<()> {
57 // To work around the pedantic safety, we'll make new vectors and then
58 // replace the `statements` and `heap` vectors from the `Analyzer`
59 // object when we are done.
60 let mut statements = vec![];
61 let mut heap = vec![];
62
63 let input_statements = self.statements.clone();
64 for statement in &input_statements {
65 //println!("{statement:?}");
66 let mut stmt = statement.clone();
67
68 let (return_types, arg_types) = statement.opcode.arg_types();
69 let mut rhs = vec![];
70
71 // This handling is kinda limiting, but it'll do for now.
72 if !(arg_types[0] == VarType::BaseArray || arg_types[0] == VarType::ScalarArray) {
73 // Check that number of args is correct
74 if statement.rhs.len() != arg_types.len() {
75 return Err(self.error.abort(
76 &format!(
77 "Incorrect number of arguments for statement. Expected {}, got {}.",
78 arg_types.len(),
79 statement.rhs.len()
80 ),
81 statement.line,
82 1,
83 ))
84 }
85 } else {
86 // In case of arrays, check there's at least one element.
87 if statement.rhs.is_empty() {
88 return Err(self.error.abort(
89 "Expected at least one element for statement using arrays.",
90 statement.line,
91 1,
92 ))
93 }
94 }
95
96 // Edge-cases for some opcodes
97 #[allow(clippy::single_match)]
98 match &statement.opcode {
99 Opcode::RangeCheck => {
100 if let Arg::Lit(arg0) = &statement.rhs[0] {
101 if &arg0.name != "64" && &arg0.name != "253" {
102 return Err(self.error.abort(
103 "Supported range checks are only 64 and 253 bits.",
104 arg0.line,
105 arg0.column,
106 ))
107 }
108 } else {
109 return Err(self.error.abort(
110 "Invalid argument for range_check opcode.",
111 statement.line,
112 0,
113 ))
114 }
115 }
116
117 _ => {}
118 }
119
120 for (idx, arg) in statement.rhs.iter().enumerate() {
121 // In case an argument is a function call, we will first
122 // convert it to another statement that will get executed
123 // before this one. An important assumption is that this
124 // opcode has a return value. When executed we will push
125 // this value onto the heap and use it as a reference to
126 // the actual statement we're parsing at this moment.
127 // This uses a recursive algorithm to handle arbitrarily
128 // nested functions up to MAX_RECURSION_DEPTH.
129 if let Arg::Func(func) = arg {
130 let (result_var, nested_statements) = self.process_nested_func(
131 func, &arg_types, idx, &mut heap, 1, // Start at depth 1
132 )?;
133
134 // Add all nested statements to our statement list
135 statements.extend(nested_statements);
136
137 // Replace the statement function call with the variable
138 // from the innermost statement we created.
139 stmt.rhs[idx] = Arg::Var(result_var.clone());
140 rhs.push(Arg::Var(result_var));
141
142 continue
143 } // <-- Arg::Func
144
145 // The literals get pushed on their own "heap", and
146 // then the compiler will reference them by their own
147 // index when it comes to running the statement that
148 // requires the literal type.
149 if let Arg::Lit(v) = arg {
150 // Match this literal type to a VarType for
151 // type checking.
152
153 let var_type = v.typ.to_vartype();
154 // Validation for Array types
155 if arg_types[0] == VarType::BaseArray {
156 if var_type != VarType::Base {
157 return Err(self.error.abort(
158 &format!(
159 "Incorrect argument type. Expected `{:?}`, got `{var_type:?}`.",
160 VarType::Base
161 ),
162 v.line,
163 v.column,
164 ))
165 }
166 } else if arg_types[0] == VarType::ScalarArray && var_type != VarType::Scalar {
167 return Err(self.error.abort(
168 &format!(
169 "Incorrect argument type. Expected `{:?}`, got `{var_type:?}`.",
170 VarType::Scalar
171 ),
172 v.line,
173 v.column,
174 ))
175 }
176 // Validation for non-Array types
177 if var_type != arg_types[idx] {
178 return Err(self.error.abort(
179 &format!(
180 "Incorrect argument type. Expected `{:?}`, got `{var_type:?}`.",
181 arg_types[idx]
182 ),
183 v.line,
184 v.column,
185 ))
186 }
187
188 self.literals.push(v.clone());
189 rhs.push(Arg::Lit(v.clone()));
190 continue
191 }
192
193 if let Arg::Var(v) = arg {
194 // Look up variable and check if type is correct.
195 if let Some(s_var) = self.lookup_var(&v.name) {
196 let (var_type, _ln, _col) = match s_var {
197 Var::Constant(c) => (c.typ, c.line, c.column),
198 Var::Witness(c) => (c.typ, c.line, c.column),
199 Var::Variable(c) => (c.typ, c.line, c.column),
200 };
201
202 if arg_types[0] == VarType::BaseArray {
203 if var_type != VarType::Base {
204 return Err(self.error.abort(
205 &format!(
206 "Incorrect argument type. Expected `{:?}`, got `{var_type:?}`.",
207 VarType::Base
208 ),
209 v.line,
210 v.column,
211 ))
212 }
213 } else if arg_types[0] == VarType::ScalarArray {
214 if var_type != VarType::Scalar {
215 return Err(self.error.abort(
216 &format!(
217 "Incorrect argument type. Expected `{:?}`, got `{var_type:?}`.",
218 VarType::Scalar
219 ),
220 v.line,
221 v.column,
222 ))
223 }
224 } else if var_type != arg_types[idx] && arg_types[idx] != VarType::Any {
225 return Err(self.error.abort(
226 &format!(
227 "Incorrect argument type. Expected `{:?}`, got `{var_type:?}`.",
228 arg_types[idx]
229 ),
230 v.line,
231 v.column,
232 ))
233 }
234
235 // Replace Dummy type with correct type.
236 let mut v_new = v.clone();
237 v_new.typ = var_type;
238 rhs.push(Arg::Var(v_new));
239 continue
240 }
241
242 return Err(self.error.abort(
243 &format!("Unknown variable reference `{}`.", v.name),
244 v.line,
245 v.column,
246 ))
247 }
248 } // <-- statement.rhs.iter().enumerate()
249
250 // We now type-checked and assigned types to the statement rhs,
251 // so now we apply it to the statement.
252 stmt.rhs = rhs;
253
254 // In case this statement is an assignment, we will push its
255 // result on the heap.
256 if statement.typ == StatementType::Assign {
257 let mut var = statement.lhs.clone().unwrap();
258 // Since we are doing an assignment, ensure that there is a return type.
259 if return_types.is_empty() {
260 return Err(self.error.abort(
261 "Cannot perform assignment without a return type",
262 var.line,
263 var.column,
264 ))
265 }
266 var.typ = return_types[0];
267 stmt.lhs = Some(var.clone());
268 heap.push(var.clone());
269 self.heap.clone_from(&heap);
270 }
271
272 //println!("{stmt:#?}");
273 statements.push(stmt);
274 } // <-- for statement in &self.statements
275
276 // Here we replace the self.statements and self.heap with what we
277 // built so far. These can be used later on by the compiler after
278 // this function is finished.
279 self.statements = statements;
280 self.heap = heap;
281
282 //println!("=================STATEMENTS===============\n{:#?}", self.statements);
283 //println!("====================HEAP==================\n{:#?}", self.heap);
284 //println!("==================LITERALS================\n{:#?}", self.literals);
285
286 Ok(())
287 }
288
289 /// Recursively process a nested function call.
290 /// Returns the result Variable and a Vec of Statements that need to be executed.
291 fn process_nested_func(
292 &mut self,
293 func: &Statement,
294 parent_arg_types: &[VarType],
295 parent_arg_idx: usize,
296 heap: &mut Vec<Variable>,
297 depth: usize,
298 ) -> Result<(Variable, Vec<Statement>)> {
299 if depth > MAX_RECURSION_DEPTH {
300 return Err(self.error.abort(
301 &format!(
302 "Maximum recursion depth of {} exceeded for nested function calls.",
303 MAX_RECURSION_DEPTH
304 ),
305 func.line,
306 0,
307 ))
308 }
309
310 let (f_return_types, f_arg_types) = func.opcode.arg_types();
311
312 if f_return_types.is_empty() {
313 return Err(self.error.abort(
314 &format!(
315 "Used a function argument which doesn't have a return value: {:?}",
316 func.opcode
317 ),
318 func.line,
319 1,
320 ))
321 }
322
323 // Create the result variable for this function call
324 let result_var = Variable {
325 name: func.lhs.clone().unwrap().name,
326 typ: f_return_types[0],
327 line: func.lhs.clone().unwrap().line,
328 column: func.lhs.clone().unwrap().column,
329 };
330
331 // Validate return type against parent's expected type
332 if parent_arg_types[0] == VarType::BaseArray {
333 if f_return_types[0] != VarType::Base {
334 return Err(self.error.abort(
335 &format!(
336 "Function passed as argument returns wrong type. Expected `{:?}`, got `{:?}`.",
337 VarType::Base,
338 f_return_types[0],
339 ),
340 result_var.line,
341 result_var.column,
342 ))
343 }
344 } else if parent_arg_types[0] == VarType::ScalarArray {
345 if f_return_types[0] != VarType::Scalar {
346 return Err(self.error.abort(
347 &format!(
348 "Function passed as argument returns wrong type. Expected `{:?}`, got `{:?}`.",
349 VarType::Scalar,
350 f_return_types[0],
351 ),
352 result_var.line,
353 result_var.column,
354 ))
355 }
356 } else if f_return_types[0] != parent_arg_types[parent_arg_idx] {
357 return Err(self.error.abort(
358 &format!(
359 "Function passed as argument returns wrong type. Expected `{:?}`, got `{:?}`.",
360 parent_arg_types[parent_arg_idx], f_return_types[0],
361 ),
362 result_var.line,
363 result_var.column,
364 ))
365 }
366
367 // Collect all statements that need to be generated
368 let mut nested_statements = vec![];
369 let mut rhs_inner = vec![];
370
371 // Process each argument of this nested function
372 for (inner_idx, arg) in func.rhs.iter().enumerate() {
373 match arg {
374 Arg::Var(v) => {
375 if let Some(var_ref) = self.lookup_var(&v.name) {
376 let (var_type, ln, col) = match var_ref {
377 Var::Constant(c) => (c.typ, c.line, c.column),
378 Var::Witness(c) => (c.typ, c.line, c.column),
379 Var::Variable(c) => (c.typ, c.line, c.column),
380 };
381
382 // Type checking for array types
383 if f_arg_types[0] == VarType::BaseArray {
384 if var_type != VarType::Base {
385 return Err(self.error.abort(
386 &format!(
387 "Incorrect argument type. Expected `{:?}`, got `{var_type:?}`.",
388 VarType::Base
389 ),
390 ln,
391 col,
392 ))
393 }
394 } else if f_arg_types[0] == VarType::ScalarArray {
395 if var_type != VarType::Scalar {
396 return Err(self.error.abort(
397 &format!(
398 "Incorrect argument type. Expected `{:?}`, got `{var_type:?}`.",
399 VarType::Scalar
400 ),
401 ln,
402 col,
403 ))
404 }
405 } else if var_type != f_arg_types[inner_idx] &&
406 f_arg_types[inner_idx] != VarType::Any
407 {
408 return Err(self.error.abort(
409 &format!(
410 "Incorrect argument type. Expected `{:?}`, got `{var_type:?}`.",
411 f_arg_types[inner_idx]
412 ),
413 ln,
414 col,
415 ))
416 }
417
418 // Apply the proper type
419 let mut v_new = v.clone();
420 v_new.typ = var_type;
421 rhs_inner.push(Arg::Var(v_new));
422 } else {
423 return Err(self.error.abort(
424 &format!("Unknown variable reference `{}`.", v.name),
425 v.line,
426 v.column,
427 ))
428 }
429 }
430
431 Arg::Lit(lit) => {
432 let var_type = lit.typ.to_vartype();
433
434 // Type checking for array types
435 if f_arg_types[0] == VarType::BaseArray {
436 if var_type != VarType::Base {
437 return Err(self.error.abort(
438 &format!(
439 "Incorrect argument type. Expected `{:?}`, got `{var_type:?}`.",
440 VarType::Base
441 ),
442 lit.line,
443 lit.column,
444 ))
445 }
446 } else if f_arg_types[0] == VarType::ScalarArray {
447 if var_type != VarType::Scalar {
448 return Err(self.error.abort(
449 &format!(
450 "Incorrect argument type. Expected `{:?}`, got `{var_type:?}`.",
451 VarType::Scalar
452 ),
453 lit.line,
454 lit.column,
455 ))
456 }
457 } else if var_type != f_arg_types[inner_idx] {
458 return Err(self.error.abort(
459 &format!(
460 "Incorrect argument type. Expected `{:?}`, got `{var_type:?}`.",
461 f_arg_types[inner_idx]
462 ),
463 lit.line,
464 lit.column,
465 ))
466 }
467
468 self.literals.push(lit.clone());
469 rhs_inner.push(Arg::Lit(lit.clone()));
470 }
471
472 Arg::Func(inner_func) => {
473 // Recursively process the inner function
474 let (inner_result_var, inner_statements) = self.process_nested_func(
475 inner_func,
476 &f_arg_types,
477 inner_idx,
478 heap,
479 depth + 1,
480 )?;
481
482 // Add inner statements first (they need to execute before this one)
483 nested_statements.extend(inner_statements);
484
485 // Use the result variable as an argument
486 rhs_inner.push(Arg::Var(inner_result_var));
487 }
488 }
489 }
490
491 // Create the statement for this function call
492 let stmt = Statement {
493 typ: func.typ,
494 opcode: func.opcode,
495 lhs: Some(result_var.clone()),
496 rhs: rhs_inner,
497 line: func.line,
498 };
499
500 // Add this statement to the list
501 nested_statements.push(stmt);
502
503 // Push the result variable onto the heap
504 heap.push(result_var.clone());
505 self.heap.clone_from(heap);
506
507 Ok((result_var, nested_statements))
508 }
509
510 fn lookup_var(&self, name: &str) -> Option<Var> {
511 if let Some(r) = self.lookup_constant(name) {
512 return Some(Var::Constant(r))
513 }
514
515 if let Some(r) = self.lookup_witness(name) {
516 return Some(Var::Witness(r))
517 }
518
519 if let Some(r) = self.lookup_heap(name) {
520 return Some(Var::Variable(r))
521 }
522
523 None
524 }
525
526 fn lookup_constant(&self, name: &str) -> Option<Constant> {
527 for i in &self.constants {
528 if i.name == name {
529 return Some(i.clone())
530 }
531 }
532
533 None
534 }
535
536 fn lookup_witness(&self, name: &str) -> Option<Witness> {
537 for i in &self.witnesses {
538 if i.name == name {
539 return Some(i.clone())
540 }
541 }
542
543 None
544 }
545
546 fn lookup_heap(&self, name: &str) -> Option<Variable> {
547 for i in &self.heap {
548 if i.name == name {
549 return Some(i.clone())
550 }
551 }
552
553 None
554 }
555
556 pub fn analyze_semantic(&mut self) -> Result<()> {
557 let mut heap = vec![];
558
559 println!("Loading constants...\n-----");
560 for i in &self.constants {
561 println!("Adding `{}` to heap", i.name);
562 heap.push(&i.name);
563 Analyzer::pause();
564 }
565 println!("Heap:\n{heap:#?}\n-----");
566 println!("Loading witnesses...\n-----");
567 for i in &self.witnesses {
568 println!("Adding `{}` to heap", i.name);
569 heap.push(&i.name);
570 Analyzer::pause();
571 }
572 println!("Heap:\n{heap:#?}\n-----");
573 println!("Loading circuit...");
574 for i in &self.statements {
575 let mut argnames = vec![];
576 for arg in &i.rhs {
577 if let Arg::Var(arg) = arg {
578 argnames.push(arg.name.clone());
579 } else if let Arg::Lit(lit) = arg {
580 argnames.push(lit.name.clone());
581 } else {
582 unreachable!()
583 }
584 }
585 println!("Executing: {:?}({argnames:?})", i.opcode);
586
587 Analyzer::pause();
588
589 for arg in &i.rhs {
590 if let Arg::Var(arg) = arg {
591 print!("Looking up `{}` on the heap... ", arg.name);
592 if let Some(index) = heap.iter().position(|&r| r == &arg.name) {
593 println!("Found at heap index {index}");
594 } else {
595 return Err(self.error.abort(
596 &format!("Could not find `{}` on the heap", arg.name),
597 arg.line,
598 arg.column,
599 ))
600 }
601 } else if let Arg::Lit(lit) = arg {
602 println!("Using literal `{}`", lit.name);
603 } else {
604 println!("{arg:#?}");
605 unreachable!();
606 }
607
608 Analyzer::pause();
609 }
610 match i.typ {
611 StatementType::Assign => {
612 println!("Pushing result as `{}` to heap", &i.lhs.as_ref().unwrap().name);
613 heap.push(&i.lhs.as_ref().unwrap().name);
614 println!("Heap:\n{heap:#?}\n-----");
615 }
616 StatementType::Call => {
617 println!("-----");
618 }
619 _ => unreachable!(),
620 }
621 }
622
623 Ok(())
624 }
625
626 fn pause() {
627 let msg = b"[Press Enter to continue]\r";
628 let mut stdout = stdout();
629 let _ = stdout.write(msg).unwrap();
630 stdout.flush().unwrap();
631 let _ = stdin().read(&mut [0]).unwrap();
632 write!(stdout, "\x1b[1A\r\x1b[K\r").unwrap();
633 }
634}