darkfi_derive_internal/
async_derive.rs

1/* This file is part of DarkFi (https://dark.fi)
2 *
3 * Copyright (C) 2020-2026 Dyne.org foundation
4 *
5 * This program is free software: you can redistribute it and/or modify
6 * it under the terms of the GNU Affero General Public License as
7 * published by the Free Software Foundation, either version 3 of the
8 * License, or (at your option) any later version.
9 *
10 * This program is distributed in the hope that it will be useful,
11 * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 * GNU Affero General Public License for more details.
14 *
15 * You should have received a copy of the GNU Affero General Public License
16 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
17 */
18
//! Derive async (de)serialization (`AsyncEncodable` / `AsyncDecodable`) for enums and structs; see src/serial/derive
20use proc_macro2::{Ident, Span, TokenStream};
21use quote::quote;
22use syn::{
23    Fields, FieldsNamed, FieldsUnnamed, Index, ItemEnum, ItemStruct, WhereClause, WherePredicate,
24};
25
26use super::{contains_initialize_with, contains_skip, discriminant_map, VariantParts};
27
28fn named_fields(
29    cratename: &Ident,
30    enum_ident: &Ident,
31    variant_ident: &Ident,
32    discriminant_value: &TokenStream,
33    fields: &FieldsNamed,
34) -> syn::Result<VariantParts> {
35    let mut where_predicates: Vec<WherePredicate> = vec![];
36    let mut variant_header = TokenStream::new();
37    let mut variant_body = TokenStream::new();
38
39    for field in &fields.named {
40        if !contains_skip(&field.attrs) {
41            let field_ident = field.ident.clone().unwrap();
42
43            variant_header.extend(quote! { #field_ident, });
44
45            let field_type = &field.ty;
46            where_predicates.push(
47                syn::parse2(quote! {
48                    #field_type: #cratename::AsyncEncodable
49                })
50                .unwrap(),
51            );
52
53            variant_body.extend(quote! {
54                len += #field_ident.encode_async(s).await?;
55            })
56        }
57    }
58
59    // `..` pattern matching works even if all fields were specified
60    variant_header = quote! { { #variant_header .. }};
61    let variant_idx_body = quote!(
62        #enum_ident::#variant_ident { .. } => #discriminant_value,
63    );
64
65    Ok(VariantParts { where_predicates, variant_header, variant_body, variant_idx_body })
66}
67
68fn unnamed_fields(
69    cratename: &Ident,
70    enum_ident: &Ident,
71    variant_ident: &Ident,
72    discriminant_value: &TokenStream,
73    fields: &FieldsUnnamed,
74) -> syn::Result<VariantParts> {
75    let mut where_predicates: Vec<WherePredicate> = vec![];
76    let mut variant_header = TokenStream::new();
77    let mut variant_body = TokenStream::new();
78
79    for (field_idx, field) in fields.unnamed.iter().enumerate() {
80        let field_idx = u32::try_from(field_idx).expect("up to 2^32 fields are supported");
81        if contains_skip(&field.attrs) {
82            let field_ident = Ident::new(format!("_id{}", field_idx).as_str(), Span::mixed_site());
83            variant_header.extend(quote! { #field_ident, });
84        } else {
85            let field_ident = Ident::new(format!("id{}", field_idx).as_str(), Span::mixed_site());
86            variant_header.extend(quote! { #field_ident, });
87
88            let field_type = &field.ty;
89            where_predicates.push(
90                syn::parse2(quote! {
91                    #field_type: #cratename::AsyncEncodable
92                })
93                .unwrap(),
94            );
95
96            variant_body.extend(quote! {
97                len += #field_ident.encode_async(s).await?;
98            })
99        }
100    }
101
102    variant_header = quote! { ( #variant_header )};
103    let variant_idx_body = quote!(
104        #enum_ident::#variant_ident(..) => #discriminant_value,
105    );
106
107    Ok(VariantParts { where_predicates, variant_header, variant_body, variant_idx_body })
108}
109
110pub fn async_enum_ser(input: &ItemEnum, cratename: Ident) -> syn::Result<TokenStream> {
111    let enum_ident = &input.ident;
112    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
113    let mut where_clause = where_clause.map_or_else(
114        || WhereClause { where_token: Default::default(), predicates: Default::default() },
115        Clone::clone,
116    );
117    let mut all_variants_idx_body = TokenStream::new();
118    let mut fields_body = TokenStream::new();
119    let discriminants = discriminant_map(&input.variants);
120
121    for variant in input.variants.iter() {
122        let variant_ident = &variant.ident;
123        let discriminant_value = discriminants.get(variant_ident).unwrap();
124        let VariantParts { where_predicates, variant_header, variant_body, variant_idx_body } =
125            match &variant.fields {
126                Fields::Named(fields) => {
127                    named_fields(&cratename, enum_ident, variant_ident, discriminant_value, fields)?
128                }
129                Fields::Unnamed(fields) => unnamed_fields(
130                    &cratename,
131                    enum_ident,
132                    variant_ident,
133                    discriminant_value,
134                    fields,
135                )?,
136                Fields::Unit => {
137                    let variant_idx_body = quote!(
138                        #enum_ident::#variant_ident => #discriminant_value,
139                    );
140                    VariantParts {
141                        where_predicates: vec![],
142                        variant_header: TokenStream::new(),
143                        variant_body: TokenStream::new(),
144                        variant_idx_body,
145                    }
146                }
147            };
148        where_predicates.into_iter().for_each(|predicate| where_clause.predicates.push(predicate));
149        all_variants_idx_body.extend(variant_idx_body);
150        fields_body.extend(quote!(
151            #enum_ident::#variant_ident #variant_header => {
152                #variant_body
153            }
154        ))
155    }
156
157    Ok(quote! {
158    #[async_trait]
159    impl #impl_generics #cratename::AsyncEncodable for #enum_ident #ty_generics #where_clause {
160        async fn encode_async<S: #cratename::AsyncWrite + Unpin + Send>(&self, s: &mut S) -> ::std::io::Result<usize> {
161            let variant_idx: u8 = match self {
162                #all_variants_idx_body
163            };
164
165            let mut len = 0;
166            let bytes = variant_idx.to_le_bytes();
167            s.write_all(&bytes).await?;
168            len += bytes.len();
169
170            match self {
171                #fields_body
172            }
173
174            Ok(len)
175        }
176    }
177    })
178}
179
180pub fn async_enum_de(input: &ItemEnum, cratename: Ident) -> syn::Result<TokenStream> {
181    let name = &input.ident;
182    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
183    let mut where_clause = where_clause.map_or_else(
184        || WhereClause { where_token: Default::default(), predicates: Default::default() },
185        Clone::clone,
186    );
187
188    let init_method = contains_initialize_with(&input.attrs);
189    let mut variant_arms = TokenStream::new();
190    let discriminants = discriminant_map(&input.variants);
191
192    for variant in input.variants.iter() {
193        let variant_ident = &variant.ident;
194        let discriminant = discriminants.get(variant_ident).unwrap();
195        let mut variant_header = TokenStream::new();
196        match &variant.fields {
197            Fields::Named(fields) => {
198                for field in &fields.named {
199                    let field_name = field.ident.as_ref().unwrap();
200                    if contains_skip(&field.attrs) {
201                        variant_header.extend(quote! {
202                            #field_name: Default::default(),
203                        });
204                    } else {
205                        let field_type = &field.ty;
206                        where_clause.predicates.push(
207                            syn::parse2(quote! {
208                                #field_type: #cratename::AsyncDecodable
209                            })
210                            .unwrap(),
211                        );
212
213                        variant_header.extend(quote! {
214                            #field_name: #cratename::AsyncDecodable::decode_async(d).await?,
215                        });
216                    }
217                }
218                variant_header = quote! { { #variant_header }};
219            }
220            Fields::Unnamed(fields) => {
221                for field in fields.unnamed.iter() {
222                    if contains_skip(&field.attrs) {
223                        variant_header.extend(quote! { Default::default(), });
224                    } else {
225                        let field_type = &field.ty;
226                        where_clause.predicates.push(
227                            syn::parse2(quote! {
228                                #field_type: #cratename::AsyncDecodable
229                            })
230                            .unwrap(),
231                        );
232
233                        variant_header.extend(quote! {
234                            #cratename::AsyncDecodable::decode_async(d).await?,
235                        });
236                    }
237                }
238                variant_header = quote! { ( #variant_header )};
239            }
240            Fields::Unit => {}
241        }
242        variant_arms.extend(quote! {
243            if variant_tag == #discriminant { #name::#variant_ident #variant_header } else
244        });
245    }
246
247    let init = if let Some(method_ident) = init_method {
248        quote! {
249            return_value.#method_ident();
250        }
251    } else {
252        quote! {}
253    };
254
255    Ok(quote! {
256    #[async_trait]
257    impl #impl_generics #cratename::AsyncDecodable for #name #ty_generics #where_clause {
258        async fn decode_async<D: #cratename::AsyncRead + Unpin + Send>(d: &mut D) -> ::std::io::Result<Self> {
259            let variant_tag: u8 = #cratename::AsyncDecodable::decode_async(d).await?;
260
261            let mut return_value =
262                #variant_arms {
263                    return Err(std::io::Error::new(
264                        std::io::ErrorKind::InvalidData,
265                        format!("Unexpected variant tag: {:?}", variant_tag),
266                    ))
267                };
268                #init
269                Ok(return_value)
270            }
271        }
272    })
273}
274
275pub fn async_struct_ser(input: &ItemStruct, cratename: Ident) -> syn::Result<TokenStream> {
276    let name = &input.ident;
277    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
278    let mut where_clause = where_clause.map_or_else(
279        || WhereClause { where_token: Default::default(), predicates: Default::default() },
280        Clone::clone,
281    );
282
283    let mut body = TokenStream::new();
284
285    match &input.fields {
286        Fields::Named(fields) => {
287            for field in &fields.named {
288                if contains_skip(&field.attrs) {
289                    continue
290                }
291
292                let field_name = field.ident.as_ref().unwrap();
293                let delta = quote! {
294                    len += self.#field_name.encode_async(s).await?;
295                };
296                body.extend(delta);
297
298                let field_type = &field.ty;
299                where_clause.predicates.push(
300                    syn::parse2(quote! {
301                        #field_type: #cratename::AsyncEncodable
302                    })
303                    .unwrap(),
304                );
305            }
306        }
307        Fields::Unnamed(fields) => {
308            for field_idx in 0..fields.unnamed.len() {
309                let field_idx = Index {
310                    index: u32::try_from(field_idx).expect("up to 2^32 fields are supported"),
311                    span: Span::call_site(),
312                };
313                let delta = quote! {
314                    len += self.#field_idx.encode_async(s).await?;
315                };
316                body.extend(delta);
317            }
318        }
319        Fields::Unit => {}
320    }
321
322    Ok(quote! {
323    #[async_trait]
324    impl #impl_generics #cratename::AsyncEncodable for #name #ty_generics #where_clause {
325        async fn encode_async<S: #cratename::AsyncWrite + Unpin + Send>(&self, s: &mut S) -> ::std::io::Result<usize> {
326            let mut len = 0;
327            #body
328            Ok(len)
329        }
330    }
331    })
332}
333
334pub fn async_struct_de(input: &ItemStruct, cratename: Ident) -> syn::Result<TokenStream> {
335    let name = &input.ident;
336    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
337    let mut where_clause = where_clause.map_or_else(
338        || WhereClause { where_token: Default::default(), predicates: Default::default() },
339        Clone::clone,
340    );
341
342    let init_method = contains_initialize_with(&input.attrs);
343    let return_value = match &input.fields {
344        Fields::Named(fields) => {
345            let mut body = TokenStream::new();
346            for field in &fields.named {
347                let field_name = field.ident.as_ref().unwrap();
348
349                let delta = if contains_skip(&field.attrs) {
350                    quote! {
351                        #field_name: Default::default(),
352                    }
353                } else {
354                    let field_type = &field.ty;
355                    where_clause.predicates.push(
356                        syn::parse2(quote! {
357                            #field_type: #cratename::AsyncDecodable
358                        })
359                        .unwrap(),
360                    );
361
362                    quote! {
363                        #field_name: #cratename::AsyncDecodable::decode_async(d).await?,
364                    }
365                };
366                body.extend(delta);
367            }
368            quote! {
369                Self { #body }
370            }
371        }
372        Fields::Unnamed(fields) => {
373            let mut body = TokenStream::new();
374            for _ in 0..fields.unnamed.len() {
375                let delta = quote! {
376                    #cratename::AsyncDecodable::decode_async(d).await?,
377                };
378                body.extend(delta);
379            }
380            quote! {
381                Self( #body )
382            }
383        }
384        Fields::Unit => {
385            quote! {
386                Self {}
387            }
388        }
389    };
390
391    if let Some(method_ident) = init_method {
392        Ok(quote! {
393        #[async_trait]
394        impl #impl_generics #cratename::AsyncDecodable for #name #ty_generics #where_clause {
395            async fn decode_async<D: #cratename::AsyncRead + Unpin + Send>(d: &mut D) -> ::std::io::Result<Self> {
396                let mut return_value = #return_value;
397                return_value.#method_ident();
398                Ok(return_value)
399            }
400        }
401        })
402    } else {
403        Ok(quote! {
404        #[async_trait]
405        impl #impl_generics #cratename::AsyncDecodable for #name #ty_generics #where_clause {
406            async fn decode_async<D: #cratename::AsyncRead + Unpin + Send>(d: &mut D) -> ::std::io::Result<Self> {
407                Ok(#return_value)
408            }
409        }
410        })
411    }
412}