1use proc_macro2::{Ident, Span, TokenStream};
21use quote::quote;
22use syn::{
23 Fields, FieldsNamed, FieldsUnnamed, Index, ItemEnum, ItemStruct, WhereClause, WherePredicate,
24};
25
26use super::{contains_initialize_with, contains_skip, discriminant_map, VariantParts};
27
28fn named_fields(
29 cratename: &Ident,
30 enum_ident: &Ident,
31 variant_ident: &Ident,
32 discriminant_value: &TokenStream,
33 fields: &FieldsNamed,
34) -> syn::Result<VariantParts> {
35 let mut where_predicates: Vec<WherePredicate> = vec![];
36 let mut variant_header = TokenStream::new();
37 let mut variant_body = TokenStream::new();
38
39 for field in &fields.named {
40 if !contains_skip(&field.attrs) {
41 let field_ident = field.ident.clone().unwrap();
42
43 variant_header.extend(quote! { #field_ident, });
44
45 let field_type = &field.ty;
46 where_predicates.push(
47 syn::parse2(quote! {
48 #field_type: #cratename::AsyncEncodable
49 })
50 .unwrap(),
51 );
52
53 variant_body.extend(quote! {
54 len += #field_ident.encode_async(s).await?;
55 })
56 }
57 }
58
59 variant_header = quote! { { #variant_header .. }};
61 let variant_idx_body = quote!(
62 #enum_ident::#variant_ident { .. } => #discriminant_value,
63 );
64
65 Ok(VariantParts { where_predicates, variant_header, variant_body, variant_idx_body })
66}
67
68fn unnamed_fields(
69 cratename: &Ident,
70 enum_ident: &Ident,
71 variant_ident: &Ident,
72 discriminant_value: &TokenStream,
73 fields: &FieldsUnnamed,
74) -> syn::Result<VariantParts> {
75 let mut where_predicates: Vec<WherePredicate> = vec![];
76 let mut variant_header = TokenStream::new();
77 let mut variant_body = TokenStream::new();
78
79 for (field_idx, field) in fields.unnamed.iter().enumerate() {
80 let field_idx = u32::try_from(field_idx).expect("up to 2^32 fields are supported");
81 if contains_skip(&field.attrs) {
82 let field_ident = Ident::new(format!("_id{}", field_idx).as_str(), Span::mixed_site());
83 variant_header.extend(quote! { #field_ident, });
84 } else {
85 let field_ident = Ident::new(format!("id{}", field_idx).as_str(), Span::mixed_site());
86 variant_header.extend(quote! { #field_ident, });
87
88 let field_type = &field.ty;
89 where_predicates.push(
90 syn::parse2(quote! {
91 #field_type: #cratename::AsyncEncodable
92 })
93 .unwrap(),
94 );
95
96 variant_body.extend(quote! {
97 len += #field_ident.encode_async(s).await?;
98 })
99 }
100 }
101
102 variant_header = quote! { ( #variant_header )};
103 let variant_idx_body = quote!(
104 #enum_ident::#variant_ident(..) => #discriminant_value,
105 );
106
107 Ok(VariantParts { where_predicates, variant_header, variant_body, variant_idx_body })
108}
109
110pub fn async_enum_ser(input: &ItemEnum, cratename: Ident) -> syn::Result<TokenStream> {
111 let enum_ident = &input.ident;
112 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
113 let mut where_clause = where_clause.map_or_else(
114 || WhereClause { where_token: Default::default(), predicates: Default::default() },
115 Clone::clone,
116 );
117 let mut all_variants_idx_body = TokenStream::new();
118 let mut fields_body = TokenStream::new();
119 let discriminants = discriminant_map(&input.variants);
120
121 for variant in input.variants.iter() {
122 let variant_ident = &variant.ident;
123 let discriminant_value = discriminants.get(variant_ident).unwrap();
124 let VariantParts { where_predicates, variant_header, variant_body, variant_idx_body } =
125 match &variant.fields {
126 Fields::Named(fields) => {
127 named_fields(&cratename, enum_ident, variant_ident, discriminant_value, fields)?
128 }
129 Fields::Unnamed(fields) => unnamed_fields(
130 &cratename,
131 enum_ident,
132 variant_ident,
133 discriminant_value,
134 fields,
135 )?,
136 Fields::Unit => {
137 let variant_idx_body = quote!(
138 #enum_ident::#variant_ident => #discriminant_value,
139 );
140 VariantParts {
141 where_predicates: vec![],
142 variant_header: TokenStream::new(),
143 variant_body: TokenStream::new(),
144 variant_idx_body,
145 }
146 }
147 };
148 where_predicates.into_iter().for_each(|predicate| where_clause.predicates.push(predicate));
149 all_variants_idx_body.extend(variant_idx_body);
150 fields_body.extend(quote!(
151 #enum_ident::#variant_ident #variant_header => {
152 #variant_body
153 }
154 ))
155 }
156
157 Ok(quote! {
158 #[async_trait]
159 impl #impl_generics #cratename::AsyncEncodable for #enum_ident #ty_generics #where_clause {
160 async fn encode_async<S: #cratename::AsyncWrite + Unpin + Send>(&self, s: &mut S) -> ::std::io::Result<usize> {
161 let variant_idx: u8 = match self {
162 #all_variants_idx_body
163 };
164
165 let mut len = 0;
166 let bytes = variant_idx.to_le_bytes();
167 s.write_all(&bytes).await?;
168 len += bytes.len();
169
170 match self {
171 #fields_body
172 }
173
174 Ok(len)
175 }
176 }
177 })
178}
179
180pub fn async_enum_de(input: &ItemEnum, cratename: Ident) -> syn::Result<TokenStream> {
181 let name = &input.ident;
182 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
183 let mut where_clause = where_clause.map_or_else(
184 || WhereClause { where_token: Default::default(), predicates: Default::default() },
185 Clone::clone,
186 );
187
188 let init_method = contains_initialize_with(&input.attrs);
189 let mut variant_arms = TokenStream::new();
190 let discriminants = discriminant_map(&input.variants);
191
192 for variant in input.variants.iter() {
193 let variant_ident = &variant.ident;
194 let discriminant = discriminants.get(variant_ident).unwrap();
195 let mut variant_header = TokenStream::new();
196 match &variant.fields {
197 Fields::Named(fields) => {
198 for field in &fields.named {
199 let field_name = field.ident.as_ref().unwrap();
200 if contains_skip(&field.attrs) {
201 variant_header.extend(quote! {
202 #field_name: Default::default(),
203 });
204 } else {
205 let field_type = &field.ty;
206 where_clause.predicates.push(
207 syn::parse2(quote! {
208 #field_type: #cratename::AsyncDecodable
209 })
210 .unwrap(),
211 );
212
213 variant_header.extend(quote! {
214 #field_name: #cratename::AsyncDecodable::decode_async(d).await?,
215 });
216 }
217 }
218 variant_header = quote! { { #variant_header }};
219 }
220 Fields::Unnamed(fields) => {
221 for field in fields.unnamed.iter() {
222 if contains_skip(&field.attrs) {
223 variant_header.extend(quote! { Default::default(), });
224 } else {
225 let field_type = &field.ty;
226 where_clause.predicates.push(
227 syn::parse2(quote! {
228 #field_type: #cratename::AsyncDecodable
229 })
230 .unwrap(),
231 );
232
233 variant_header.extend(quote! {
234 #cratename::AsyncDecodable::decode_async(d).await?,
235 });
236 }
237 }
238 variant_header = quote! { ( #variant_header )};
239 }
240 Fields::Unit => {}
241 }
242 variant_arms.extend(quote! {
243 if variant_tag == #discriminant { #name::#variant_ident #variant_header } else
244 });
245 }
246
247 let init = if let Some(method_ident) = init_method {
248 quote! {
249 return_value.#method_ident();
250 }
251 } else {
252 quote! {}
253 };
254
255 Ok(quote! {
256 #[async_trait]
257 impl #impl_generics #cratename::AsyncDecodable for #name #ty_generics #where_clause {
258 async fn decode_async<D: #cratename::AsyncRead + Unpin + Send>(d: &mut D) -> ::std::io::Result<Self> {
259 let variant_tag: u8 = #cratename::AsyncDecodable::decode_async(d).await?;
260
261 let mut return_value =
262 #variant_arms {
263 return Err(std::io::Error::new(
264 std::io::ErrorKind::InvalidData,
265 format!("Unexpected variant tag: {:?}", variant_tag),
266 ))
267 };
268 #init
269 Ok(return_value)
270 }
271 }
272 })
273}
274
275pub fn async_struct_ser(input: &ItemStruct, cratename: Ident) -> syn::Result<TokenStream> {
276 let name = &input.ident;
277 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
278 let mut where_clause = where_clause.map_or_else(
279 || WhereClause { where_token: Default::default(), predicates: Default::default() },
280 Clone::clone,
281 );
282
283 let mut body = TokenStream::new();
284
285 match &input.fields {
286 Fields::Named(fields) => {
287 for field in &fields.named {
288 if contains_skip(&field.attrs) {
289 continue
290 }
291
292 let field_name = field.ident.as_ref().unwrap();
293 let delta = quote! {
294 len += self.#field_name.encode_async(s).await?;
295 };
296 body.extend(delta);
297
298 let field_type = &field.ty;
299 where_clause.predicates.push(
300 syn::parse2(quote! {
301 #field_type: #cratename::AsyncEncodable
302 })
303 .unwrap(),
304 );
305 }
306 }
307 Fields::Unnamed(fields) => {
308 for field_idx in 0..fields.unnamed.len() {
309 let field_idx = Index {
310 index: u32::try_from(field_idx).expect("up to 2^32 fields are supported"),
311 span: Span::call_site(),
312 };
313 let delta = quote! {
314 len += self.#field_idx.encode_async(s).await?;
315 };
316 body.extend(delta);
317 }
318 }
319 Fields::Unit => {}
320 }
321
322 Ok(quote! {
323 #[async_trait]
324 impl #impl_generics #cratename::AsyncEncodable for #name #ty_generics #where_clause {
325 async fn encode_async<S: #cratename::AsyncWrite + Unpin + Send>(&self, s: &mut S) -> ::std::io::Result<usize> {
326 let mut len = 0;
327 #body
328 Ok(len)
329 }
330 }
331 })
332}
333
334pub fn async_struct_de(input: &ItemStruct, cratename: Ident) -> syn::Result<TokenStream> {
335 let name = &input.ident;
336 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
337 let mut where_clause = where_clause.map_or_else(
338 || WhereClause { where_token: Default::default(), predicates: Default::default() },
339 Clone::clone,
340 );
341
342 let init_method = contains_initialize_with(&input.attrs);
343 let return_value = match &input.fields {
344 Fields::Named(fields) => {
345 let mut body = TokenStream::new();
346 for field in &fields.named {
347 let field_name = field.ident.as_ref().unwrap();
348
349 let delta = if contains_skip(&field.attrs) {
350 quote! {
351 #field_name: Default::default(),
352 }
353 } else {
354 let field_type = &field.ty;
355 where_clause.predicates.push(
356 syn::parse2(quote! {
357 #field_type: #cratename::AsyncDecodable
358 })
359 .unwrap(),
360 );
361
362 quote! {
363 #field_name: #cratename::AsyncDecodable::decode_async(d).await?,
364 }
365 };
366 body.extend(delta);
367 }
368 quote! {
369 Self { #body }
370 }
371 }
372 Fields::Unnamed(fields) => {
373 let mut body = TokenStream::new();
374 for _ in 0..fields.unnamed.len() {
375 let delta = quote! {
376 #cratename::AsyncDecodable::decode_async(d).await?,
377 };
378 body.extend(delta);
379 }
380 quote! {
381 Self( #body )
382 }
383 }
384 Fields::Unit => {
385 quote! {
386 Self {}
387 }
388 }
389 };
390
391 if let Some(method_ident) = init_method {
392 Ok(quote! {
393 #[async_trait]
394 impl #impl_generics #cratename::AsyncDecodable for #name #ty_generics #where_clause {
395 async fn decode_async<D: #cratename::AsyncRead + Unpin + Send>(d: &mut D) -> ::std::io::Result<Self> {
396 let mut return_value = #return_value;
397 return_value.#method_ident();
398 Ok(return_value)
399 }
400 }
401 })
402 } else {
403 Ok(quote! {
404 #[async_trait]
405 impl #impl_generics #cratename::AsyncDecodable for #name #ty_generics #where_clause {
406 async fn decode_async<D: #cratename::AsyncRead + Unpin + Send>(d: &mut D) -> ::std::io::Result<Self> {
407 Ok(#return_value)
408 }
409 }
410 })
411 }
412}