darkfi/zk/gadget/
less_than.rs

1/* This file is part of DarkFi (https://dark.fi)
2 *
3 * Copyright (C) 2020-2026 Dyne.org foundation
4 *
5 * This program is free software: you can redistribute it and/or modify
6 * it under the terms of the GNU Affero General Public License as
7 * published by the Free Software Foundation, either version 3 of the
8 * License, or (at your option) any later version.
9 *
10 * This program is distributed in the hope that it will be useful,
11 * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 * GNU Affero General Public License for more details.
14 *
15 * You should have received a copy of the GNU Affero General Public License
16 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
17 */
18
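//! Less-Than Gadget
//!
//! Given two values:
//!     - `a`, a value that fits in NUM_OF_BITS bits, and
//!     - `b`, a field element,
//! this gadget constrains them in the following way:
//!     - in `strict` mode, `a` is constrained to be strictly less than `b`;
//!     - else, `a` is constrained to be less than or equal to `b`.
//!
//! The comparison is enforced by range-checking both `a` and the witnessed
//! offset `a + 2^NUM_OF_BITS - b` (minus one in non-strict mode) to
//! NUM_OF_BITS bits. Consequently a proof can only be produced when
//! `b - a <= 2^NUM_OF_BITS` (strict) or `b - a < 2^NUM_OF_BITS` (non-strict);
//! this always holds when `b` also fits in NUM_OF_BITS bits. For soundness
//! over the Pallas base field, NUM_OF_BITS must not exceed 253.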
27
28use halo2_proofs::{
29    arithmetic::Field,
30    circuit::{AssignedCell, Chip, Layouter, Region, Value},
31    pasta::pallas,
32    plonk::{Advice, Column, ConstraintSystem, Error, Expression, Selector, TableColumn},
33    poly::Rotation,
34};
35
36use super::native_range_check::{NativeRangeCheckChip, NativeRangeCheckConfig};
37
38#[derive(Clone, Debug)]
39pub struct LessThanConfig<const WINDOW_SIZE: usize, const NUM_OF_BITS: usize> {
40    pub s_lt: Selector,
41    pub s_leq: Selector,
42    pub a: Column<Advice>,
43    pub b: Column<Advice>,
44    pub a_offset: Column<Advice>,
45    pub range_a_config: NativeRangeCheckConfig<WINDOW_SIZE, NUM_OF_BITS>,
46    pub range_a_offset_config: NativeRangeCheckConfig<WINDOW_SIZE, NUM_OF_BITS>,
47    pub k_values_table: TableColumn,
48}
49
50#[derive(Clone, Debug)]
51pub struct LessThanChip<const WINDOW_SIZE: usize, const NUM_OF_BITS: usize> {
52    config: LessThanConfig<WINDOW_SIZE, NUM_OF_BITS>,
53}
54
55impl<const WINDOW_SIZE: usize, const NUM_OF_BITS: usize> Chip<pallas::Base>
56    for LessThanChip<WINDOW_SIZE, NUM_OF_BITS>
57{
58    type Config = LessThanConfig<WINDOW_SIZE, NUM_OF_BITS>;
59    type Loaded = ();
60
61    fn config(&self) -> &Self::Config {
62        &self.config
63    }
64
65    fn loaded(&self) -> &Self::Loaded {
66        &()
67    }
68}
69
70impl<const WINDOW_SIZE: usize, const NUM_OF_BITS: usize> LessThanChip<WINDOW_SIZE, NUM_OF_BITS> {
71    pub fn construct(config: LessThanConfig<WINDOW_SIZE, NUM_OF_BITS>) -> Self {
72        Self { config }
73    }
74
75    pub fn configure(
76        meta: &mut ConstraintSystem<pallas::Base>,
77        a: Column<Advice>,
78        b: Column<Advice>,
79        a_offset: Column<Advice>,
80        z1: Column<Advice>,
81        z2: Column<Advice>,
82        k_values_table: TableColumn,
83    ) -> LessThanConfig<WINDOW_SIZE, NUM_OF_BITS> {
84        let s_lt = meta.selector();
85        let s_leq = meta.selector();
86
87        meta.enable_equality(a);
88        meta.enable_equality(b);
89        meta.enable_equality(a_offset);
90        meta.enable_equality(z1);
91        meta.enable_equality(z2);
92
93        // configure range check for `a` and `offset`
94        let range_a_config =
95            NativeRangeCheckChip::<WINDOW_SIZE, NUM_OF_BITS>::configure(meta, z1, k_values_table);
96
97        let range_a_offset_config =
98            NativeRangeCheckChip::<WINDOW_SIZE, NUM_OF_BITS>::configure(meta, z2, k_values_table);
99
100        let config = LessThanConfig {
101            s_lt,
102            s_leq,
103            a,
104            b,
105            a_offset,
106            range_a_config,
107            range_a_offset_config,
108            k_values_table,
109        };
110
111        meta.create_gate("a_offset", |meta| {
112            let s_lt = meta.query_selector(config.s_lt);
113            let s_leq = meta.query_selector(config.s_leq);
114            let a = meta.query_advice(config.a, Rotation::cur());
115            let b = meta.query_advice(config.b, Rotation::cur());
116            let a_offset = meta.query_advice(config.a_offset, Rotation::cur());
117            let two_pow_m =
118                Expression::Constant(pallas::Base::from(2).pow([NUM_OF_BITS as u64, 0, 0, 0]));
119
120            // If strict, a_offset = a + 2^m - b
121            let strict_check =
122                s_lt * (a_offset.clone() - two_pow_m.clone() + b.clone() - a.clone());
123            // If leq, a_offset = a + 2^m - b - 1
124            let leq_check =
125                s_leq * (a_offset - two_pow_m + b - a + Expression::Constant(pallas::Base::one()));
126
127            vec![strict_check, leq_check]
128        });
129
130        config
131    }
132
133    pub fn witness_less_than(
134        &self,
135        mut layouter: impl Layouter<pallas::Base>,
136        a: Value<pallas::Base>,
137        b: Value<pallas::Base>,
138        offset: usize,
139        strict: bool,
140    ) -> Result<(), Error> {
141        let (a, _, a_offset) = layouter.assign_region(
142            || "a less than b",
143            |mut region: Region<'_, pallas::Base>| {
144                let a = region.assign_advice(|| "a", self.config.a, offset, || a)?;
145                let b = region.assign_advice(|| "b", self.config.b, offset, || b)?;
146                let a_offset = self.less_than(region, a.clone(), b.clone(), offset, strict)?;
147                Ok((a, b, a_offset))
148            },
149        )?;
150
151        self.less_than_range_check(layouter, a, a_offset)?;
152
153        Ok(())
154    }
155
156    /*
157    pub fn witness_less_than2(
158        &self,
159        mut layouter: impl Layouter<pallas::Base>,
160        a: Value<pallas::Base>,
161        b: Value<pallas::Base>,
162        offset: usize,
163        strict: bool,
164    ) -> Result<AssignedCell<pallas::Base, pallas::Base>, Error> {
165        let (a, _, a_offset) = layouter.assign_region(
166            || "a less than b",
167            |mut region: Region<'_, pallas::Base>| {
168                let a = region.assign_advice(|| "a", self.config.a, offset, || a)?;
169                let b = region.assign_advice(|| "b", self.config.b, offset, || b)?;
170                let a_offset = self.less_than(region, a.clone(), b.clone(), offset)?;
171                Ok((a, b, a_offset))
172            },
173        )?;
174
175        self.less_than_range_check(layouter, a, a_offset.clone(), strict)?;
176
177        Ok(a_offset)
178    }
179    */
180
181    pub fn copy_less_than(
182        &self,
183        mut layouter: impl Layouter<pallas::Base>,
184        a: AssignedCell<pallas::Base, pallas::Base>,
185        b: AssignedCell<pallas::Base, pallas::Base>,
186        offset: usize,
187        strict: bool,
188    ) -> Result<(), Error> {
189        let (a, _, a_offset) = layouter.assign_region(
190            || "a less than b",
191            |mut region: Region<'_, pallas::Base>| {
192                let a = a.copy_advice(|| "a", &mut region, self.config.a, offset)?;
193                let b = b.copy_advice(|| "b", &mut region, self.config.b, offset)?;
194                let a_offset = self.less_than(region, a.clone(), b.clone(), offset, strict)?;
195                Ok((a, b, a_offset))
196            },
197        )?;
198
199        self.less_than_range_check(layouter, a, a_offset)?;
200
201        Ok(())
202    }
203
204    pub fn less_than_range_check(
205        &self,
206        mut layouter: impl Layouter<pallas::Base>,
207        a: AssignedCell<pallas::Base, pallas::Base>,
208        a_offset: AssignedCell<pallas::Base, pallas::Base>,
209    ) -> Result<(), Error> {
210        let range_a_chip = NativeRangeCheckChip::<WINDOW_SIZE, NUM_OF_BITS>::construct(
211            self.config.range_a_config.clone(),
212        );
213        let range_a_offset_chip = NativeRangeCheckChip::<WINDOW_SIZE, NUM_OF_BITS>::construct(
214            self.config.range_a_offset_config.clone(),
215        );
216
217        range_a_chip.copy_range_check(layouter.namespace(|| "a copy_range_check"), a)?;
218
219        range_a_offset_chip
220            .copy_range_check(layouter.namespace(|| "a_offset copy_range_check"), a_offset)?;
221
222        Ok(())
223    }
224
225    pub fn less_than(
226        &self,
227        mut region: Region<'_, pallas::Base>,
228        a: AssignedCell<pallas::Base, pallas::Base>,
229        b: AssignedCell<pallas::Base, pallas::Base>,
230        offset: usize,
231        strict: bool,
232    ) -> Result<AssignedCell<pallas::Base, pallas::Base>, Error> {
233        if strict {
234            // enable `less_than` selector
235            self.config.s_lt.enable(&mut region, offset)?;
236        } else {
237            self.config.s_leq.enable(&mut region, offset)?;
238        }
239
240        let two_pow_m = pallas::Base::from(2).pow([NUM_OF_BITS as u64, 0, 0, 0]);
241        let a_offset = if strict {
242            a.value().zip(b.value()).map(|(a, b)| *a + (two_pow_m - b))
243        } else {
244            a.value().zip(b.value()).map(|(a, b)| *a + (two_pow_m - b) - pallas::Base::one())
245        };
246        let a_offset =
247            region.assign_advice(|| "a_offset", self.config.a_offset, offset, || a_offset)?;
248
249        Ok(a_offset)
250    }
251}
252
#[cfg(test)]
mod tests {
    use super::*;
    use darkfi_sdk::crypto::pasta_prelude::PrimeField;
    use halo2_proofs::{
        circuit::floor_planner,
        dev::{CircuitLayout, MockProver},
        plonk::Circuit,
    };

    /// `p - 1`, the largest element of the Pallas base field.
    const P_MINUS_1: pallas::Base = pallas::Base::from_raw([
        0x992d30ed00000000,
        0x224698fc094cf91b,
        0x0000000000000000,
        0x4000000000000000,
    ]);

    /// 2^253 - 1. This is the maximum we can check with 253-bit range checks.
    const MAX_253: pallas::Base = pallas::Base::from_raw([
        0xFFFFFFFFFFFFFFFF,
        0xFFFFFFFFFFFFFFFF,
        0xFFFFFFFFFFFFFFFF,
        0x1FFFFFFFFFFFFFFF,
    ]);

    /// Random `u64` that is at least 2, so `1 < r` and `1 <= r` always hold.
    ///
    /// A raw `rand::random::<u64>()` could return 0 or 1 and make a
    /// supposedly valid pair fail.
    fn random_u64_above_one() -> u64 {
        rand::random::<u64>().max(2)
    }

    macro_rules! test_circuit {
        ($k: expr, $strict:expr, $window_size:expr, $num_bits:expr, $valid_pairs:expr, $invalid_pairs:expr) => {
            #[derive(Default)]
            struct LessThanCircuit {
                a: Value<pallas::Base>,
                b: Value<pallas::Base>,
            }

            impl Circuit<pallas::Base> for LessThanCircuit {
                type Config = (LessThanConfig<$window_size, $num_bits>, Column<Advice>);
                type FloorPlanner = floor_planner::V1;
                type Params = ();

                fn without_witnesses(&self) -> Self {
                    Self { a: Value::unknown(), b: Value::unknown() }
                }

                fn configure(meta: &mut ConstraintSystem<pallas::Base>) -> Self::Config {
                    let w = meta.advice_column();
                    meta.enable_equality(w);

                    let a = meta.advice_column();
                    let b = meta.advice_column();
                    let a_offset = meta.advice_column();
                    let z1 = meta.advice_column();
                    let z2 = meta.advice_column();

                    let k_values_table = meta.lookup_table_column();

                    let constants = meta.fixed_column();
                    meta.enable_constant(constants);

                    (
                        LessThanChip::<$window_size, $num_bits>::configure(
                            meta,
                            a,
                            b,
                            a_offset,
                            z1,
                            z2,
                            k_values_table,
                        ),
                        w,
                    )
                }

                fn synthesize(
                    &self,
                    config: Self::Config,
                    mut layouter: impl Layouter<pallas::Base>,
                ) -> Result<(), Error> {
                    let less_than_chip =
                        LessThanChip::<$window_size, $num_bits>::construct(config.0.clone());

                    NativeRangeCheckChip::<$window_size, $num_bits>::load_k_table(
                        &mut layouter,
                        config.0.k_values_table,
                    )?;

                    less_than_chip.witness_less_than(
                        layouter.namespace(|| "a < b"),
                        self.a,
                        self.b,
                        0,
                        $strict,
                    )?;

                    Ok(())
                }
            }

            use plotters::prelude::*;
            let circuit = LessThanCircuit {
                a: Value::known(pallas::Base::zero()),
                b: Value::known(pallas::Base::one()),
            };
            // Include the mode in the file name: the strict and non-strict
            // tests for the same bit width run in parallel and would otherwise
            // write the same image concurrently.
            let mode = if $strict { "lt" } else { "leq" };
            let file_name =
                format!("target/lessthan_check_{}_{:?}_circuit_layout.png", mode, $num_bits);
            let root = BitMapBackend::new(file_name.as_str(), (3840, 2160)).into_drawing_area();
            CircuitLayout::default().render($k, &circuit, &root).unwrap();

            let check = if $strict { "<" } else { "<=" };
            for (a, b) in $valid_pairs {
                println!("{:?} bit (valid) {:?} {} {:?} check", $num_bits, a, check, b);
                let circuit = LessThanCircuit { a: Value::known(a), b: Value::known(b) };
                let prover = MockProver::run($k, &circuit, vec![]).unwrap();
                prover.assert_satisfied();
            }

            for (a, b) in $invalid_pairs {
                println!("{:?} bit (invalid) {:?} {} {:?} check", $num_bits, a, check, b);
                let circuit = LessThanCircuit { a: Value::known(a), b: Value::known(b) };
                let prover = MockProver::run($k, &circuit, vec![]).unwrap();
                assert!(prover.verify().is_err())
            }
        };
    }

    #[test]
    fn leq_64() {
        let k = 5;
        const WINDOW_SIZE: usize = 3;
        const NUM_OF_BITS: usize = 64;

        let valid_pairs = [
            (pallas::Base::ZERO, pallas::Base::ZERO),
            (pallas::Base::ONE, pallas::Base::ONE),
            (pallas::Base::from(13), pallas::Base::from(15)),
            (pallas::Base::ZERO, pallas::Base::from(u64::MAX)),
            (pallas::Base::ONE, pallas::Base::from(random_u64_above_one())),
            (pallas::Base::from(u64::MAX), pallas::Base::from(u64::MAX) + pallas::Base::ONE),
            (pallas::Base::from(u64::MAX), pallas::Base::from(u64::MAX)),
        ];

        let invalid_pairs = [
            (pallas::Base::from(14), pallas::Base::from(11)),
            (pallas::Base::from(u64::MAX), pallas::Base::ZERO),
            (pallas::Base::ONE, pallas::Base::ZERO),
        ];
        test_circuit!(k, false, WINDOW_SIZE, NUM_OF_BITS, valid_pairs, invalid_pairs);
    }

    #[test]
    fn less_than_64() {
        let k = 5;
        const WINDOW_SIZE: usize = 3;
        const NUM_OF_BITS: usize = 64;

        let valid_pairs = [
            (pallas::Base::from(13), pallas::Base::from(15)),
            (pallas::Base::ZERO, pallas::Base::from(u64::MAX)),
            (pallas::Base::ONE, pallas::Base::from(random_u64_above_one())),
            (pallas::Base::from(u64::MAX), pallas::Base::from(u64::MAX) + pallas::Base::ONE),
        ];

        let invalid_pairs = [
            (pallas::Base::from(14), pallas::Base::from(11)),
            (pallas::Base::from(u64::MAX), pallas::Base::ZERO),
            (pallas::Base::ZERO, pallas::Base::ZERO),
            (pallas::Base::ONE, pallas::Base::ONE),
            (pallas::Base::ONE, pallas::Base::ZERO),
            (pallas::Base::from(u64::MAX), pallas::Base::from(u64::MAX)),
        ];
        test_circuit!(k, true, WINDOW_SIZE, NUM_OF_BITS, valid_pairs, invalid_pairs);
    }

    #[test]
    fn leq_253() {
        let k = 7;
        const WINDOW_SIZE: usize = 3;
        const NUM_OF_BITS: usize = 253;

        let valid_pairs = [
            (pallas::Base::ZERO, pallas::Base::ZERO),
            (pallas::Base::ZERO, pallas::Base::ONE),
            (pallas::Base::from(u64::MAX), pallas::Base::from(u64::MAX) + pallas::Base::ONE),
            (
                pallas::Base::from_u128(u128::MAX),
                pallas::Base::from_u128(u128::MAX) + pallas::Base::ONE,
            ),
            (MAX_253, MAX_253),
            (MAX_253 - pallas::Base::from(2), MAX_253 - pallas::Base::ONE),
            (MAX_253 - pallas::Base::ONE, MAX_253),
            (MAX_253, MAX_253 + pallas::Base::ONE),
        ];

        let invalid_pairs = [
            (pallas::Base::ONE, pallas::Base::ZERO),
            (P_MINUS_1 - pallas::Base::ONE, P_MINUS_1),
            (P_MINUS_1, pallas::Base::ZERO),
            (P_MINUS_1, P_MINUS_1),
            (MAX_253, pallas::Base::ZERO),
            (MAX_253, pallas::Base::ONE),
            (MAX_253 + pallas::Base::ONE, pallas::Base::ZERO),
            (MAX_253 + pallas::Base::ONE, pallas::Base::ONE),
            (MAX_253 + pallas::Base::ONE, MAX_253 + pallas::Base::ONE),
            (MAX_253 + pallas::Base::ONE, MAX_253 + pallas::Base::from(2)),
        ];

        test_circuit!(k, false, WINDOW_SIZE, NUM_OF_BITS, valid_pairs, invalid_pairs);
    }

    #[test]
    fn less_than_253() {
        let k = 7;
        const WINDOW_SIZE: usize = 3;
        const NUM_OF_BITS: usize = 253;

        let valid_pairs = [
            (pallas::Base::ZERO, pallas::Base::ONE),
            (pallas::Base::from(u64::MAX), pallas::Base::from(u64::MAX) + pallas::Base::ONE),
            (
                pallas::Base::from_u128(u128::MAX),
                pallas::Base::from_u128(u128::MAX) + pallas::Base::ONE,
            ),
            (MAX_253 - pallas::Base::from(2), MAX_253 - pallas::Base::ONE),
            (MAX_253 - pallas::Base::ONE, MAX_253),
            (MAX_253, MAX_253 + pallas::Base::ONE),
        ];

        let invalid_pairs = [
            (pallas::Base::ZERO, pallas::Base::ZERO),
            (pallas::Base::ONE, pallas::Base::ZERO),
            (P_MINUS_1 - pallas::Base::ONE, P_MINUS_1),
            (P_MINUS_1, P_MINUS_1),
            (P_MINUS_1, pallas::Base::ZERO),
            (MAX_253, MAX_253),
            (MAX_253, pallas::Base::ZERO),
            (MAX_253, pallas::Base::ONE),
            (MAX_253 + pallas::Base::ONE, pallas::Base::ZERO),
            (MAX_253 + pallas::Base::ONE, pallas::Base::ONE),
            (MAX_253 + pallas::Base::ONE, MAX_253 + pallas::Base::ONE),
            (MAX_253 + pallas::Base::ONE, MAX_253 + pallas::Base::from(2)),
        ];

        test_circuit!(k, true, WINDOW_SIZE, NUM_OF_BITS, valid_pairs, invalid_pairs);
    }
}