darkfi/zk/gadget/
native_range_check.rs

1/* This file is part of DarkFi (https://dark.fi)
2 *
3 * Copyright (C) 2020-2026 Dyne.org foundation
4 *
5 * This program is free software: you can redistribute it and/or modify
6 * it under the terms of the GNU Affero General Public License as
7 * published by the Free Software Foundation, either version 3 of the
8 * License, or (at your option) any later version.
9 *
10 * This program is distributed in the hope that it will be useful,
11 * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 * GNU Affero General Public License for more details.
14 *
15 * You should have received a copy of the GNU Affero General Public License
16 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
17 */
18
19use halo2_proofs::{
20    circuit::{AssignedCell, Chip, Layouter, Region, Value},
21    pasta::{
22        group::ff::{Field, PrimeFieldBits},
23        pallas,
24    },
25    plonk,
26    plonk::{Advice, Column, ConstraintSystem, Constraints, Selector, TableColumn},
27    poly::Rotation,
28};
29
30#[derive(Clone, Debug)]
31pub struct NativeRangeCheckConfig<const WINDOW_SIZE: usize, const NUM_BITS: usize> {
32    pub z: Column<Advice>,
33    pub s_rc: Selector,
34    pub s_short: Selector,
35    pub k_values_table: TableColumn,
36}
37
38#[derive(Clone, Debug)]
39pub struct NativeRangeCheckChip<const WINDOW_SIZE: usize, const NUM_BITS: usize> {
40    config: NativeRangeCheckConfig<WINDOW_SIZE, NUM_BITS>,
41}
42
43impl<const WINDOW_SIZE: usize, const NUM_BITS: usize> Chip<pallas::Base>
44    for NativeRangeCheckChip<WINDOW_SIZE, NUM_BITS>
45{
46    type Config = NativeRangeCheckConfig<WINDOW_SIZE, NUM_BITS>;
47    type Loaded = ();
48
49    fn config(&self) -> &Self::Config {
50        &self.config
51    }
52
53    fn loaded(&self) -> &Self::Loaded {
54        &()
55    }
56}
57
58impl<const WINDOW_SIZE: usize, const NUM_BITS: usize> NativeRangeCheckChip<WINDOW_SIZE, NUM_BITS> {
59    pub fn construct(config: NativeRangeCheckConfig<WINDOW_SIZE, NUM_BITS>) -> Self {
60        Self { config }
61    }
62
63    pub fn configure(
64        meta: &mut ConstraintSystem<pallas::Base>,
65        z: Column<Advice>,
66        k_values_table: TableColumn,
67    ) -> NativeRangeCheckConfig<WINDOW_SIZE, NUM_BITS> {
68        // Enable permutation on z column
69        meta.enable_equality(z);
70
71        let s_rc = meta.complex_selector();
72        let s_short = meta.complex_selector();
73
74        // Running sum decomposition
75        meta.lookup(|meta| {
76            let s_rc = meta.query_selector(s_rc);
77            let z_curr = meta.query_advice(z, Rotation::cur());
78            let z_next = meta.query_advice(z, Rotation::next());
79
80            //    z_next = (z_curr - k_i) / 2^K
81            // => k_i = z_curr - (z_next * 2^K)
82            vec![(s_rc * (z_curr - z_next * pallas::Base::from(1 << WINDOW_SIZE)), k_values_table)]
83        });
84
85        // Checks that are enabled if the last chunk is an `s`-bit value
86        // where `s < WINDOW_SIZE`:
87        //
88        //  |s_rc | s_short |                z                |
89        //  ---------------------------------------------------
90        //  |  1  |    0    |            last_chunk           |
91        //  |  0  |    1    |                0                |
92        //  |  0  |    0    | last_chunk << (WINDOW_SIZE - s) |
93
94        // Check that `shifted_last_chunk` is `WINDOW_SIZE` bits,
95        // where shifted_last_chunk = last_chunk << (WINDOW_SIZE - s)
96        //                          = last_chunk * 2^(WINDOW_SIZE - s)
97        meta.lookup(|meta| {
98            let s_short = meta.query_selector(s_short);
99            let shifted_last_chunk = meta.query_advice(z, Rotation::next());
100            vec![(s_short * shifted_last_chunk, k_values_table)]
101        });
102
103        // Check that `shifted_last_chunk = last_chunk << (WINDOW_SIZE - s)`
104        meta.create_gate("Short lookup bitshift", |meta| {
105            let two_pow_window_size = pallas::Base::from(1 << WINDOW_SIZE);
106            let s_short = meta.query_selector(s_short);
107            let last_chunk = meta.query_advice(z, Rotation::prev());
108            // Rotation::cur() is copy-constrained to be zero elsewhere in this gadget.
109            let shifted_last_chunk = meta.query_advice(z, Rotation::next());
110            // inv_two_pow_s = 1 >> s = 2^{-s}
111            let inv_two_pow_s = {
112                let s = NUM_BITS % WINDOW_SIZE;
113                pallas::Base::from(1 << s).invert().unwrap()
114            };
115
116            // shifted_last_chunk = last_chunk << (WINDOW_SIZE - s)
117            //                    = last_chunk * 2^WINDOW_SIZE * 2^{-s}
118            Constraints::with_selector(
119                s_short,
120                Some(last_chunk * two_pow_window_size * inv_two_pow_s - shifted_last_chunk),
121            )
122        });
123
124        NativeRangeCheckConfig { z, s_rc, s_short, k_values_table }
125    }
126
127    /// `k_values_table` should be reused across different chips
128    /// which is why we don't limit it to a specific instance.
129    pub fn load_k_table(
130        layouter: &mut impl Layouter<pallas::Base>,
131        k_values_table: TableColumn,
132    ) -> Result<(), plonk::Error> {
133        layouter.assign_table(
134            || format!("{WINDOW_SIZE} window table"),
135            |mut table| {
136                for index in 0..(1 << WINDOW_SIZE) {
137                    table.assign_cell(
138                        || format!("{WINDOW_SIZE} window assign"),
139                        k_values_table,
140                        index,
141                        || Value::known(pallas::Base::from(index as u64)),
142                    )?;
143                }
144                Ok(())
145            },
146        )
147    }
148
149    fn decompose_value(value: &pallas::Base) -> Vec<[bool; WINDOW_SIZE]> {
150        let bits: Vec<_> = value
151            .to_le_bits()
152            .into_iter()
153            .take(NUM_BITS)
154            .chain(std::iter::repeat_n(false, WINDOW_SIZE - (NUM_BITS % WINDOW_SIZE)))
155            .collect();
156
157        bits.chunks_exact(WINDOW_SIZE)
158            .map(|x| {
159                // Because bits <= WINDOW_SIZE * NUM_BITS, the last window may be
160                // smaller than WINDOW_SIZE.
161                // Additionally we have a slice, so convert them all to a fixed length array.
162                let mut chunks = [false; WINDOW_SIZE];
163                chunks.copy_from_slice(x);
164                chunks
165            })
166            .collect()
167    }
168
169    /// This is the main chip function. Attempts to witness the bits for `z_0` proving
170    /// it is within the allowed range.
171    pub fn decompose(
172        &self,
173        region: &mut Region<'_, pallas::Base>,
174        z_0: AssignedCell<pallas::Base, pallas::Base>,
175        offset: usize,
176    ) -> Result<(), plonk::Error> {
177        let num_windows = NUM_BITS.div_ceil(WINDOW_SIZE);
178
179        // The number of bits in the last chunk.
180        let last_chunk_length = NUM_BITS - (WINDOW_SIZE * (num_windows - 1));
181        assert!(last_chunk_length > 0);
182
183        // Enable selectors for running sum decomposition
184        for index in 0..num_windows {
185            self.config.s_rc.enable(region, index + offset)?;
186        }
187
188        let mut z_values: Vec<AssignedCell<pallas::Base, pallas::Base>> = vec![z_0.clone()];
189        let mut z = z_0;
190        // Convert `z_0` into a `Vec<Value<Fp>>` where each value corresponds to a chunk.
191        let decomposed_chunks = z.value().map(Self::decompose_value).transpose_vec(num_windows);
192
193        let two_pow_k = pallas::Base::from(1 << WINDOW_SIZE as u64);
194        let two_pow_k_inverse = Value::known(two_pow_k.invert().unwrap());
195
196        //   z = 2⁰b₀ + 2¹b₁ + ⋯ + 2ⁿbₙ
197        //     = c₀ + 2ʷc₁ + 2²ʷc₂ + ⋯ + 2ᵐʷcₘ
198        // where cᵢ are the chunks.
199        //
200        // We want to show each cᵢ consists of WINDOW_SIZE bits which we do using
201        // the lookup table.
202        // The algo starts with z₀ = z, then calculates:
203        //   zᵢ = (zᵢ₋₁ - cᵢ₋₁)/2ʷ
204        // Doing this for all chunks, we end up with zₘ = 0 which is done after.
205
206        // Loop over the decomposed chunks...
207        for (i, chunk) in decomposed_chunks.iter().enumerate() {
208            let z_next = {
209                let z_curr = z.value().copied();
210                // Convert the chunk Value<[bool; WINDOW_SIZE]> into Value<pallas::Base>
211                let chunk_value = chunk.map(|c| {
212                    pallas::Base::from(c.iter().rev().fold(0, |acc, c| (acc << 1) + *c as u64))
213                });
214                // Calc z_next = (z_curr - k_i) / 2^K
215                let z_next = (z_curr - chunk_value) * two_pow_k_inverse;
216                // Witness z_next into the running sum decomposition gate
217                region.assign_advice(
218                    || format!("z_{}", i + offset + 1),
219                    self.config.z,
220                    i + offset + 1,
221                    || z_next,
222                )?
223            };
224            z_values.push(z_next.clone());
225            z = z_next.clone();
226        }
227
228        assert!(z_values.len() == num_windows + 1);
229
230        // Constrain the last chunk zₘ = 0
231        region.constrain_constant(z_values.last().unwrap().cell(), pallas::Base::zero())?;
232
233        // If the last chunk is `s` bits where `s < WINDOW_SIZE`,
234        // perform short range check
235        //
236        //  |s_rc | s_short |                z                |
237        //  ---------------------------------------------------
238        //  |  1  |    0    |            last_chunk           |
239        //  |  0  |    1    |                0                |
240        //  |  0  |    0    | last_chunk << (WINDOW_SIZE - s) |
241        //  |  0  |    0    |             1 >> s              |
242
243        if last_chunk_length < WINDOW_SIZE {
244            let s_short_offset = num_windows + offset;
245            self.config.s_short.enable(region, s_short_offset)?;
246
247            // 1 >> s = 2^{-s}
248            let inv_two_pow_s = pallas::Base::from(1 << last_chunk_length).invert().unwrap();
249            region.assign_advice_from_constant(
250                || "inv_two_pow_s",
251                self.config.z,
252                s_short_offset + 2,
253                inv_two_pow_s,
254            )?;
255
256            // shifted_last_chunk = last_chunk * 2^{WINDOW_SIZE-s}
257            //                    = last_chunk * 2^WINDOW_SIZE * inv_two_pow_s
258            let last_chunk = {
259                let chunk = decomposed_chunks.last().unwrap();
260                chunk.map(|c| {
261                    pallas::Base::from(c.iter().rev().fold(0, |acc, c| (acc << 1) + *c as u64))
262                })
263            };
264            let shifted_last_chunk =
265                last_chunk * Value::known(two_pow_k) * Value::known(inv_two_pow_s);
266            region.assign_advice(
267                || "shifted_last_chunk",
268                self.config.z,
269                s_short_offset + 1,
270                || shifted_last_chunk,
271            )?;
272        }
273
274        Ok(())
275    }
276
277    pub fn witness_range_check(
278        &self,
279        mut layouter: impl Layouter<pallas::Base>,
280        value: Value<pallas::Base>,
281    ) -> Result<(), plonk::Error> {
282        layouter.assign_region(
283            || format!("witness {NUM_BITS}-bit native range check"),
284            |mut region: Region<'_, pallas::Base>| {
285                let z_0 = region.assign_advice(|| "z_0", self.config.z, 0, || value)?;
286                self.decompose(&mut region, z_0, 0)?;
287                Ok(())
288            },
289        )
290    }
291
292    pub fn copy_range_check(
293        &self,
294        mut layouter: impl Layouter<pallas::Base>,
295        value: AssignedCell<pallas::Base, pallas::Base>,
296    ) -> Result<(), plonk::Error> {
297        layouter.assign_region(
298            || format!("copy {NUM_BITS}-bit native range check"),
299            |mut region: Region<'_, pallas::Base>| {
300                let z_0 = value.copy_advice(|| "z_0", &mut region, self.config.z, 0)?;
301                self.decompose(&mut region, z_0, 0)?;
302                Ok(())
303            },
304        )
305    }
306}
307
308#[cfg(test)]
309mod tests {
310    use super::*;
311    use crate::zk::assign_free_advice;
312    use halo2_proofs::{
313        circuit::floor_planner,
314        dev::{CircuitLayout, MockProver},
315        pasta::group::ff::PrimeField,
316        plonk::Circuit,
317    };
318
319    macro_rules! test_circuit {
320        ($k: expr, $window_size:expr, $num_bits: expr, $valid_values:expr, $invalid_values:expr) => {
321            #[derive(Default)]
322            struct RangeCheckCircuit {
323                a: Value<pallas::Base>,
324            }
325
326            impl Circuit<pallas::Base> for RangeCheckCircuit {
327                type Config = (NativeRangeCheckConfig<$window_size, $num_bits>, Column<Advice>);
328                type FloorPlanner = floor_planner::V1;
329                type Params = ();
330
331                fn without_witnesses(&self) -> Self {
332                    Self::default()
333                }
334
335                fn configure(meta: &mut ConstraintSystem<pallas::Base>) -> Self::Config {
336                    let w = meta.advice_column();
337                    meta.enable_equality(w);
338                    let z = meta.advice_column();
339                    let table_column = meta.lookup_table_column();
340
341                    let constants = meta.fixed_column();
342                    meta.enable_constant(constants);
343                    (
344                        NativeRangeCheckChip::<$window_size, $num_bits>::configure(
345                            meta,
346                            z,
347                            table_column,
348                        ),
349                        w,
350                    )
351                }
352
353                fn synthesize(
354                    &self,
355                    config: Self::Config,
356                    mut layouter: impl Layouter<pallas::Base>,
357                ) -> Result<(), plonk::Error> {
358                    let rangecheck_chip =
359                        NativeRangeCheckChip::<$window_size, $num_bits>::construct(
360                            config.0.clone(),
361                        );
362                    NativeRangeCheckChip::<$window_size, $num_bits>::load_k_table(
363                        &mut layouter,
364                        config.0.k_values_table,
365                    )?;
366
367                    let a = assign_free_advice(layouter.namespace(|| "load a"), config.1, self.a)?;
368                    rangecheck_chip
369                        .copy_range_check(layouter.namespace(|| "copy a and range check"), a)?;
370
371                    rangecheck_chip.witness_range_check(
372                        layouter.namespace(|| "witness a and range check"),
373                        self.a,
374                    )?;
375
376                    Ok(())
377                }
378            }
379
380            use plotters::prelude::*;
381            let circuit = RangeCheckCircuit { a: Value::known(pallas::Base::one()) };
382            let file_name = format!("target/native_range_check_{:?}_circuit_layout.png", $num_bits);
383            let root = BitMapBackend::new(file_name.as_str(), (3840, 2160)).into_drawing_area();
384            root.fill(&WHITE).unwrap();
385            let root = root
386                .titled(
387                    format!("{:?}-bit Native Range Check Circuit Layout", $num_bits).as_str(),
388                    ("sans-serif", 60),
389                )
390                .unwrap();
391            CircuitLayout::default().render($k, &circuit, &root).unwrap();
392
393            for i in $valid_values {
394                println!("{:?}-bit (valid) range check for {:?}", $num_bits, i);
395                let circuit = RangeCheckCircuit { a: Value::known(i) };
396                let prover = MockProver::run($k, &circuit, vec![]).unwrap();
397                prover.assert_satisfied();
398                println!("Constraints satisfied");
399            }
400
401            for i in $invalid_values {
402                println!("{:?}-bit (invalid) range check for {:?}", $num_bits, i);
403                let circuit = RangeCheckCircuit { a: Value::known(i) };
404                let prover = MockProver::run($k, &circuit, vec![]).unwrap();
405                assert!(prover.verify().is_err());
406            }
407        };
408    }
409
410    // cargo test --release --all-features --lib native_range_check -- --nocapture
411    #[test]
412    fn native_range_check_2() {
413        let k = 6;
414        const WINDOW_SIZE: usize = 5;
415        const NUM_BITS: usize = 2;
416
417        // [0, 1, 2, 3]
418        let valid_values: Vec<_> = (0..(1 << NUM_BITS)).map(pallas::Base::from).collect();
419        // [4, 5, 6, ..., 32]
420        let invalid_values: Vec<_> =
421            ((1 << NUM_BITS)..=(1 << WINDOW_SIZE)).map(pallas::Base::from).collect();
422        test_circuit!(k, WINDOW_SIZE, NUM_BITS, valid_values, invalid_values);
423    }
424
425    #[test]
426    fn native_range_check_64() {
427        let k = 6;
428        const WINDOW_SIZE: usize = 3;
429        const NUM_BITS: usize = 64;
430
431        let valid_values = vec![
432            pallas::Base::zero(),
433            pallas::Base::one(),
434            pallas::Base::from(u64::MAX),
435            pallas::Base::from(rand::random::<u64>()),
436        ];
437
438        let invalid_values = vec![
439            -pallas::Base::one(),
440            pallas::Base::from_u128(u64::MAX as u128 + 1),
441            -pallas::Base::from_u128(u64::MAX as u128 + 1),
442            // The following two are valid
443            // 2 = -28948022309329048855892746252171976963363056481941560715954676764349967630335
444            //-pallas::Base::from_str_vartime(
445            //    "28948022309329048855892746252171976963363056481941560715954676764349967630335",
446            //)
447            //.unwrap(),
448            // 1 = -28948022309329048855892746252171976963363056481941560715954676764349967630336
449            //-pallas::Base::from_str_vartime(
450            //    "28948022309329048855892746252171976963363056481941560715954676764349967630336",
451            //)
452            //.unwrap(),
453        ];
454        test_circuit!(k, WINDOW_SIZE, NUM_BITS, valid_values, invalid_values);
455    }
456
457    #[test]
458    fn native_range_check_128() {
459        let k = 7;
460        const WINDOW_SIZE: usize = 3;
461        const NUM_BITS: usize = 128;
462
463        let valid_values = vec![
464            pallas::Base::zero(),
465            pallas::Base::one(),
466            pallas::Base::from_u128(u128::MAX),
467            pallas::Base::from_u128(rand::random::<u128>()),
468        ];
469
470        let invalid_values = vec![
471            -pallas::Base::one(),
472            pallas::Base::from_u128(u128::MAX) + pallas::Base::one(),
473            -pallas::Base::from_u128(u128::MAX) + pallas::Base::one(),
474            -pallas::Base::from_u128(u128::MAX),
475        ];
476        test_circuit!(k, WINDOW_SIZE, NUM_BITS, valid_values, invalid_values);
477    }
478
479    #[test]
480    fn native_range_check_253() {
481        let k = 8;
482        const WINDOW_SIZE: usize = 3;
483        const NUM_BITS: usize = 253;
484
485        // 2^253 - 1
486        let max_253 = pallas::Base::from_str_vartime(
487            "14474011154664524427946373126085988481658748083205070504932198000989141204991",
488        )
489        .unwrap();
490
491        let valid_values = vec![
492            pallas::Base::zero(),
493            pallas::Base::one(),
494            max_253,
495            // 2^253 / 2
496            pallas::Base::from_str_vartime(
497                "7237005577332262213973186563042994240829374041602535252466099000494570602496",
498            )
499            .unwrap(),
500        ];
501
502        let invalid_values = vec![
503            -pallas::Base::one(),
504            // p - 1
505            pallas::Base::from_str_vartime(
506                "28948022309329048855892746252171976963363056481941560715954676764349967630336",
507            )
508            .unwrap(),
509            max_253 + pallas::Base::one(),
510        ];
511        test_circuit!(k, WINDOW_SIZE, NUM_BITS, valid_values, invalid_values);
512    }
513}