1use proc_macro2::{Ident, Span, TokenStream};
21use quote::quote;
22use syn::{
23 Fields, FieldsNamed, FieldsUnnamed, Index, ItemEnum, ItemStruct, WhereClause, WherePredicate,
24};
25
26use super::{contains_initialize_with, contains_skip, discriminant_map, VariantParts};
27
28fn named_fields(
29 cratename: &Ident,
30 enum_ident: &Ident,
31 variant_ident: &Ident,
32 discriminant_value: &TokenStream,
33 fields: &FieldsNamed,
34) -> syn::Result<VariantParts> {
35 let mut where_predicates: Vec<WherePredicate> = vec![];
36 let mut variant_header = TokenStream::new();
37 let mut variant_body = TokenStream::new();
38
39 for field in &fields.named {
40 if !contains_skip(&field.attrs) {
41 let field_ident = field.ident.clone().unwrap();
42
43 variant_header.extend(quote! { #field_ident, });
44
45 let field_type = &field.ty;
46 where_predicates.push(
47 syn::parse2(quote! {
48 #field_type: #cratename::Encodable
49 })
50 .unwrap(),
51 );
52
53 variant_body.extend(quote! {
54 len += #field_ident.encode(s)?;
55 })
56 }
57 }
58
59 variant_header = quote! { { #variant_header .. }};
61 let variant_idx_body = quote!(
62 #enum_ident::#variant_ident { .. } => #discriminant_value,
63 );
64
65 Ok(VariantParts { where_predicates, variant_header, variant_body, variant_idx_body })
66}
67
68fn unnamed_fields(
69 cratename: &Ident,
70 enum_ident: &Ident,
71 variant_ident: &Ident,
72 discriminant_value: &TokenStream,
73 fields: &FieldsUnnamed,
74) -> syn::Result<VariantParts> {
75 let mut where_predicates: Vec<WherePredicate> = vec![];
76 let mut variant_header = TokenStream::new();
77 let mut variant_body = TokenStream::new();
78
79 for (field_idx, field) in fields.unnamed.iter().enumerate() {
80 let field_idx = u32::try_from(field_idx).expect("up to 2^32 fields are supported");
81 if contains_skip(&field.attrs) {
82 let field_ident = Ident::new(format!("_id{}", field_idx).as_str(), Span::mixed_site());
83 variant_header.extend(quote! { #field_ident, });
84 } else {
85 let field_ident = Ident::new(format!("id{}", field_idx).as_str(), Span::mixed_site());
86 variant_header.extend(quote! { #field_ident, });
87
88 let field_type = &field.ty;
89 where_predicates.push(
90 syn::parse2(quote! {
91 #field_type: #cratename::Encodable
92 })
93 .unwrap(),
94 );
95
96 variant_body.extend(quote! {
97 len += #field_ident.encode(s)?;
98 })
99 }
100 }
101
102 variant_header = quote! { ( #variant_header )};
103 let variant_idx_body = quote!(
104 #enum_ident::#variant_ident(..) => #discriminant_value,
105 );
106
107 Ok(VariantParts { where_predicates, variant_header, variant_body, variant_idx_body })
108}
109
110pub fn enum_ser(input: &ItemEnum, cratename: Ident) -> syn::Result<TokenStream> {
111 let enum_ident = &input.ident;
112 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
113 let mut where_clause = where_clause.map_or_else(
114 || WhereClause { where_token: Default::default(), predicates: Default::default() },
115 Clone::clone,
116 );
117 let mut all_variants_idx_body = TokenStream::new();
118 let mut fields_body = TokenStream::new();
119 let discriminants = discriminant_map(&input.variants);
120
121 for variant in input.variants.iter() {
122 let variant_ident = &variant.ident;
123 let discriminant_value = discriminants.get(variant_ident).unwrap();
124 let VariantParts { where_predicates, variant_header, variant_body, variant_idx_body } =
125 match &variant.fields {
126 Fields::Named(fields) => {
127 named_fields(&cratename, enum_ident, variant_ident, discriminant_value, fields)?
128 }
129 Fields::Unnamed(fields) => unnamed_fields(
130 &cratename,
131 enum_ident,
132 variant_ident,
133 discriminant_value,
134 fields,
135 )?,
136 Fields::Unit => {
137 let variant_idx_body = quote!(
138 #enum_ident::#variant_ident => #discriminant_value,
139 );
140 VariantParts {
141 where_predicates: vec![],
142 variant_header: TokenStream::new(),
143 variant_body: TokenStream::new(),
144 variant_idx_body,
145 }
146 }
147 };
148 where_predicates.into_iter().for_each(|predicate| where_clause.predicates.push(predicate));
149 all_variants_idx_body.extend(variant_idx_body);
150 fields_body.extend(quote!(
151 #enum_ident::#variant_ident #variant_header => {
152 #variant_body
153 }
154 ))
155 }
156
157 Ok(quote! {
158 impl #impl_generics #cratename::Encodable for #enum_ident #ty_generics #where_clause {
159 fn encode<S: std::io::Write>(&self, s: &mut S) -> ::core::result::Result<usize, std::io::Error> {
160 let variant_idx: u8 = match self {
161 #all_variants_idx_body
162 };
163
164 let mut len = 0;
165 let bytes = variant_idx.to_le_bytes();
166 s.write_all(&bytes)?;
167 len += bytes.len();
168
169 match self {
170 #fields_body
171 }
172
173 Ok(len)
174 }
175 }
176 })
177}
178
179pub fn enum_de(input: &ItemEnum, cratename: Ident) -> syn::Result<TokenStream> {
180 let name = &input.ident;
181 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
182 let mut where_clause = where_clause.map_or_else(
183 || WhereClause { where_token: Default::default(), predicates: Default::default() },
184 Clone::clone,
185 );
186
187 let init_method = contains_initialize_with(&input.attrs);
188 let mut variant_arms = TokenStream::new();
189 let discriminants = discriminant_map(&input.variants);
190
191 for variant in input.variants.iter() {
192 let variant_ident = &variant.ident;
193 let discriminant = discriminants.get(variant_ident).unwrap();
194 let mut variant_header = TokenStream::new();
195 match &variant.fields {
196 Fields::Named(fields) => {
197 for field in &fields.named {
198 let field_name = field.ident.as_ref().unwrap();
199 if contains_skip(&field.attrs) {
200 variant_header.extend(quote! {
201 #field_name: Default::default(),
202 });
203 } else {
204 let field_type = &field.ty;
205 where_clause.predicates.push(
206 syn::parse2(quote! {
207 #field_type: #cratename::Decodable
208 })
209 .unwrap(),
210 );
211
212 variant_header.extend(quote! {
213 #field_name: #cratename::Decodable::decode(d)?,
214 });
215 }
216 }
217 variant_header = quote! { { #variant_header }};
218 }
219 Fields::Unnamed(fields) => {
220 for field in fields.unnamed.iter() {
221 if contains_skip(&field.attrs) {
222 variant_header.extend(quote! { Default::default(), });
223 } else {
224 let field_type = &field.ty;
225 where_clause.predicates.push(
226 syn::parse2(quote! {
227 #field_type: #cratename::Decodable
228 })
229 .unwrap(),
230 );
231
232 variant_header.extend(quote! {
233 #cratename::Decodable::decode(d)?,
234 });
235 }
236 }
237 variant_header = quote! { ( #variant_header )};
238 }
239 Fields::Unit => {}
240 }
241 variant_arms.extend(quote! {
242 if variant_tag == #discriminant { #name::#variant_ident #variant_header } else
243 });
244 }
245
246 let init = if let Some(method_ident) = init_method {
247 quote! {
248 return_value.#method_ident();
249 }
250 } else {
251 quote! {}
252 };
253
254 Ok(quote! {
255 impl #impl_generics #cratename::Decodable for #name #ty_generics #where_clause {
256 fn decode<D: std::io::Read>(d: &mut D) -> ::core::result::Result<Self, std::io::Error> {
257 let variant_tag: u8 = #cratename::Decodable::decode(d)?;
258
259 let mut return_value =
260 #variant_arms {
261 return Err(std::io::Error::new(
262 std::io::ErrorKind::InvalidData,
263 format!("Unexpected variant tag: {:?}", variant_tag),
264 ))
265 };
266 #init
267 Ok(return_value)
268 }
269 }
270 })
271}
272
273pub fn struct_ser(input: &ItemStruct, cratename: Ident) -> syn::Result<TokenStream> {
274 let name = &input.ident;
275 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
276 let mut where_clause = where_clause.map_or_else(
277 || WhereClause { where_token: Default::default(), predicates: Default::default() },
278 Clone::clone,
279 );
280
281 let mut body = TokenStream::new();
282
283 match &input.fields {
284 Fields::Named(fields) => {
285 for field in &fields.named {
286 if contains_skip(&field.attrs) {
287 continue
288 }
289
290 let field_name = field.ident.as_ref().unwrap();
291 let delta = quote! {
292 len += self.#field_name.encode(s)?;
293 };
294 body.extend(delta);
295
296 let field_type = &field.ty;
297 where_clause.predicates.push(
298 syn::parse2(quote! {
299 #field_type: #cratename::Encodable
300 })
301 .unwrap(),
302 );
303 }
304 }
305 Fields::Unnamed(fields) => {
306 for field_idx in 0..fields.unnamed.len() {
307 let field_idx = Index {
308 index: u32::try_from(field_idx).expect("up to 2^32 fields are supported"),
309 span: Span::call_site(),
310 };
311 let delta = quote! {
312 len += self.#field_idx.encode(s)?;
313 };
314 body.extend(delta);
315 }
316 }
317 Fields::Unit => {}
318 }
319
320 Ok(quote! {
321 impl #impl_generics #cratename::Encodable for #name #ty_generics #where_clause {
322 fn encode<S: std::io::Write>(&self, s: &mut S) -> ::core::result::Result<usize, std::io::Error> {
323 let mut len = 0;
324 #body
325 Ok(len)
326 }
327 }
328 })
329}
330
331pub fn struct_de(input: &ItemStruct, cratename: Ident) -> syn::Result<TokenStream> {
332 let name = &input.ident;
333 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
334 let mut where_clause = where_clause.map_or_else(
335 || WhereClause { where_token: Default::default(), predicates: Default::default() },
336 Clone::clone,
337 );
338
339 let init_method = contains_initialize_with(&input.attrs);
340 let return_value = match &input.fields {
341 Fields::Named(fields) => {
342 let mut body = TokenStream::new();
343 for field in &fields.named {
344 let field_name = field.ident.as_ref().unwrap();
345
346 let delta = if contains_skip(&field.attrs) {
347 quote! {
348 #field_name: Default::default(),
349 }
350 } else {
351 let field_type = &field.ty;
352 where_clause.predicates.push(
353 syn::parse2(quote! {
354 #field_type: #cratename::Decodable
355 })
356 .unwrap(),
357 );
358
359 quote! {
360 #field_name: #cratename::Decodable::decode(d)?,
361 }
362 };
363 body.extend(delta);
364 }
365 quote! {
366 Self { #body }
367 }
368 }
369 Fields::Unnamed(fields) => {
370 let mut body = TokenStream::new();
371 for _ in 0..fields.unnamed.len() {
372 let delta = quote! {
373 #cratename::Decodable::decode(d)?,
374 };
375 body.extend(delta);
376 }
377 quote! {
378 Self( #body )
379 }
380 }
381 Fields::Unit => {
382 quote! {
383 Self {}
384 }
385 }
386 };
387
388 if let Some(method_ident) = init_method {
389 Ok(quote! {
390 impl #impl_generics #cratename::Decodable for #name #ty_generics #where_clause {
391 fn decode<D: std::io::Read>(d: &mut D) -> ::core::result::Result<Self, std::io::Error> {
392 let mut return_value = #return_value;
393 return_value.#method_ident();
394 Ok(return_value)
395 }
396 }
397 })
398 } else {
399 Ok(quote! {
400 impl #impl_generics #cratename::Decodable for #name #ty_generics #where_clause {
401 fn decode<D: std::io::Read>(d: &mut D) -> ::core::result::Result<Self, std::io::Error> {
402 Ok(#return_value)
403 }
404 }
405 })
406 }
407}