snafu_derive/
shared.rs

1use std::collections::BTreeSet;
2
3pub(crate) use self::context_module::ContextModule;
4pub(crate) use self::context_selector::ContextSelector;
5pub(crate) use self::display::{Display, DisplayMatchArm};
6pub(crate) use self::error::{Error, ErrorProvideMatchArm, ErrorSourceMatchArm};
7pub(crate) use self::error_compat::{ErrorCompat, ErrorCompatBacktraceMatchArm};
8
9pub(crate) struct StaticIdent(&'static str);
10
11impl quote::ToTokens for StaticIdent {
12    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
13        proc_macro2::Ident::new(self.0, proc_macro2::Span::call_site()).to_tokens(tokens)
14    }
15}
16
17struct AllFieldNames<'a>(&'a crate::FieldContainer);
18
19impl<'a> AllFieldNames<'a> {
20    fn field_names(&self) -> BTreeSet<&'a proc_macro2::Ident> {
21        let user_fields = self.0.selector_kind.user_fields();
22        let backtrace_field = self.0.backtrace_field.as_ref();
23        let implicit_fields = &self.0.implicit_fields;
24        let message_field = self.0.selector_kind.message_field();
25        let source_field = self.0.selector_kind.source_field();
26
27        user_fields
28            .iter()
29            .chain(backtrace_field)
30            .chain(implicit_fields)
31            .chain(message_field)
32            .map(crate::Field::name)
33            .chain(source_field.map(crate::SourceField::name))
34            .collect()
35    }
36}
37
38pub mod context_module {
39    use crate::ModuleName;
40    use heck::ToSnakeCase;
41    use proc_macro2::TokenStream;
42    use quote::{quote, ToTokens};
43    use syn::Ident;
44
45    #[derive(Copy, Clone)]
46    pub(crate) struct ContextModule<'a, T> {
47        pub container_name: &'a Ident,
48        pub module_name: &'a ModuleName,
49        pub visibility: Option<&'a dyn ToTokens>,
50        pub body: &'a T,
51    }
52
53    impl<'a, T> ToTokens for ContextModule<'a, T>
54    where
55        T: ToTokens,
56    {
57        fn to_tokens(&self, stream: &mut TokenStream) {
58            let module_name = match self.module_name {
59                ModuleName::Default => {
60                    let name_str = self.container_name.to_string().to_snake_case();
61                    syn::Ident::new(&name_str, self.container_name.span())
62                }
63                ModuleName::Custom(name) => name.clone(),
64            };
65
66            let visibility = self.visibility;
67            let body = self.body;
68
69            let module_tokens = quote! {
70                #visibility mod #module_name {
71                    use super::*;
72
73                    #body
74                }
75            };
76
77            stream.extend(module_tokens);
78        }
79    }
80}
81
82pub mod context_selector {
83    use crate::{ContextSelectorKind, Field, SuffixKind};
84    use proc_macro2::TokenStream;
85    use quote::{format_ident, quote, IdentFragment, ToTokens};
86
87    const DEFAULT_SUFFIX: &str = "Snafu";
88
89    #[derive(Copy, Clone)]
90    pub(crate) struct ContextSelector<'a> {
91        pub backtrace_field: Option<&'a Field>,
92        pub implicit_fields: &'a [Field],
93        pub crate_root: &'a dyn ToTokens,
94        pub error_constructor_name: &'a dyn ToTokens,
95        pub original_generics_without_defaults: &'a [TokenStream],
96        pub parameterized_error_name: &'a dyn ToTokens,
97        pub selector_doc_string: &'a str,
98        pub selector_kind: &'a ContextSelectorKind,
99        pub selector_name: &'a proc_macro2::Ident,
100        pub user_fields: &'a [Field],
101        pub visibility: Option<&'a dyn ToTokens>,
102        pub where_clauses: &'a [TokenStream],
103        pub default_suffix: &'a SuffixKind,
104    }
105
106    impl ToTokens for ContextSelector<'_> {
107        fn to_tokens(&self, stream: &mut TokenStream) {
108            use self::ContextSelectorKind::*;
109
110            let context_selector = match self.selector_kind {
111                Context { source_field, .. } => {
112                    let context_selector_type = self.generate_type();
113                    let context_selector_impl = match source_field {
114                        Some(_) => None,
115                        None => Some(self.generate_leaf()),
116                    };
117                    let context_selector_into_error_impl =
118                        self.generate_into_error(source_field.as_ref());
119
120                    quote! {
121                        #context_selector_type
122                        #context_selector_impl
123                        #context_selector_into_error_impl
124                    }
125                }
126                Whatever {
127                    source_field,
128                    message_field,
129                } => self.generate_whatever(source_field.as_ref(), message_field),
130                NoContext { source_field } => self.generate_from_source(source_field),
131            };
132
133            stream.extend(context_selector)
134        }
135    }
136
137    impl ContextSelector<'_> {
138        fn user_field_generics(&self) -> Vec<proc_macro2::Ident> {
139            (0..self.user_fields.len())
140                .map(|i| format_ident!("__T{}", i))
141                .collect()
142        }
143
144        fn user_field_names(&self) -> Vec<&syn::Ident> {
145            self.user_fields
146                .iter()
147                .map(|Field { name, .. }| name)
148                .collect()
149        }
150
151        fn parameterized_selector_name(&self) -> TokenStream {
152            let selector_name = self.selector_name.to_string();
153            let selector_name = selector_name.trim_end_matches("Error");
154            let suffix: &dyn IdentFragment = match self.selector_kind {
155                ContextSelectorKind::Context { suffix, .. } => {
156                    match suffix.resolve_with_default(self.default_suffix) {
157                        SuffixKind::Some(s) => s,
158                        SuffixKind::None => &"",
159                        SuffixKind::Default => &DEFAULT_SUFFIX,
160                    }
161                }
162                _ => &DEFAULT_SUFFIX,
163            };
164            let selector_name = format_ident!(
165                "{}{}",
166                selector_name,
167                suffix,
168                span = self.selector_name.span()
169            );
170            let user_generics = self.user_field_generics();
171
172            quote! { #selector_name<#(#user_generics,)*> }
173        }
174
175        fn extended_where_clauses(&self) -> Vec<TokenStream> {
176            let user_fields = self.user_fields;
177            let user_field_generics = self.user_field_generics();
178            let where_clauses = self.where_clauses;
179
180            let target_types = user_fields
181                .iter()
182                .map(|Field { ty, .. }| quote! { ::core::convert::Into<#ty>});
183
184            user_field_generics
185                .into_iter()
186                .zip(target_types)
187                .map(|(gen, bound)| quote! { #gen: #bound })
188                .chain(where_clauses.iter().cloned())
189                .collect()
190        }
191
192        fn transfer_user_fields(&self) -> Vec<TokenStream> {
193            self.user_field_names()
194                .into_iter()
195                .map(|name| {
196                    quote! { #name: ::core::convert::Into::into(self.#name) }
197                })
198                .collect()
199        }
200
201        fn construct_implicit_fields(&self) -> TokenStream {
202            let crate_root = self.crate_root;
203            let expression = quote! {
204                #crate_root::GenerateImplicitData::generate()
205            };
206
207            self.construct_implicit_fields_with_expression(expression)
208        }
209
210        fn construct_implicit_fields_with_source(&self) -> TokenStream {
211            let crate_root = self.crate_root;
212            let expression = quote! { {
213                use #crate_root::AsErrorSource;
214                let error = error.as_error_source();
215                #crate_root::GenerateImplicitData::generate_with_source(error)
216            } };
217
218            self.construct_implicit_fields_with_expression(expression)
219        }
220
221        fn construct_implicit_fields_with_expression(
222            &self,
223            expression: TokenStream,
224        ) -> TokenStream {
225            self.implicit_fields
226                .iter()
227                .chain(self.backtrace_field)
228                .map(|field| {
229                    let name = &field.name;
230                    quote! { #name: #expression, }
231                })
232                .collect()
233        }
234
235        fn generate_type(self) -> TokenStream {
236            let visibility = self.visibility;
237            let parameterized_selector_name = self.parameterized_selector_name();
238            let user_field_generics = self.user_field_generics();
239            let user_field_names = self.user_field_names();
240            let selector_doc_string = self.selector_doc_string;
241
242            let body = if user_field_names.is_empty() {
243                quote! { ; }
244            } else {
245                quote! {
246                    {
247                        #(
248                            #[allow(missing_docs)]
249                            #visibility #user_field_names: #user_field_generics
250                        ),*
251                    }
252                }
253            };
254
255            quote! {
256                #[derive(Debug, Copy, Clone)]
257                #[doc = #selector_doc_string]
258                #visibility struct #parameterized_selector_name #body
259            }
260        }
261
262        fn generate_leaf(self) -> TokenStream {
263            let error_constructor_name = self.error_constructor_name;
264            let original_generics_without_defaults = self.original_generics_without_defaults;
265            let parameterized_error_name = self.parameterized_error_name;
266            let parameterized_selector_name = self.parameterized_selector_name();
267            let user_field_generics = self.user_field_generics();
268            let visibility = self.visibility;
269            let extended_where_clauses = self.extended_where_clauses();
270            let transfer_user_fields = self.transfer_user_fields();
271            let construct_implicit_fields = self.construct_implicit_fields();
272
273            quote! {
274                impl<#(#user_field_generics,)*> #parameterized_selector_name {
275                    #[doc = "Consume the selector and return the associated error"]
276                    #[must_use]
277                    #[track_caller]
278                    #visibility fn build<#(#original_generics_without_defaults,)*>(self) -> #parameterized_error_name
279                    where
280                        #(#extended_where_clauses),*
281                    {
282                        #error_constructor_name {
283                            #construct_implicit_fields
284                            #(#transfer_user_fields,)*
285                        }
286                    }
287
288                    #[doc = "Consume the selector and return a `Result` with the associated error"]
289                    #[track_caller]
290                    #visibility fn fail<#(#original_generics_without_defaults,)* __T>(self) -> ::core::result::Result<__T, #parameterized_error_name>
291                    where
292                        #(#extended_where_clauses),*
293                    {
294                        ::core::result::Result::Err(self.build())
295                    }
296                }
297            }
298        }
299
300        fn generate_into_error(self, source_field: Option<&crate::SourceField>) -> TokenStream {
301            let crate_root = self.crate_root;
302            let error_constructor_name = self.error_constructor_name;
303            let original_generics_without_defaults = self.original_generics_without_defaults;
304            let parameterized_error_name = self.parameterized_error_name;
305            let parameterized_selector_name = self.parameterized_selector_name();
306            let user_field_generics = self.user_field_generics();
307            let extended_where_clauses = self.extended_where_clauses();
308            let transfer_user_fields = self.transfer_user_fields();
309            let construct_implicit_fields = if source_field.is_some() {
310                self.construct_implicit_fields_with_source()
311            } else {
312                self.construct_implicit_fields()
313            };
314
315            let (source_ty, transform_source, transfer_source_field) = match source_field {
316                Some(source_field) => {
317                    let SourceInfo {
318                        source_field_type,
319                        transform_source,
320                        transfer_source_field,
321                    } = build_source_info(source_field);
322                    (
323                        quote! { #source_field_type },
324                        Some(transform_source),
325                        Some(transfer_source_field),
326                    )
327                }
328                None => (quote! { #crate_root::NoneError }, None, None),
329            };
330
331            quote! {
332                impl<#(#original_generics_without_defaults,)* #(#user_field_generics,)*> #crate_root::IntoError<#parameterized_error_name> for #parameterized_selector_name
333                where
334                    #parameterized_error_name: #crate_root::Error + #crate_root::ErrorCompat,
335                    #(#extended_where_clauses),*
336                {
337                    type Source = #source_ty;
338
339                    #[track_caller]
340                    fn into_error(self, error: Self::Source) -> #parameterized_error_name {
341                        #transform_source;
342                        #error_constructor_name {
343                            #construct_implicit_fields
344                            #transfer_source_field
345                            #(#transfer_user_fields),*
346                        }
347                    }
348                }
349            }
350        }
351
352        fn generate_whatever(
353            self,
354            source_field: Option<&crate::SourceField>,
355            message_field: &crate::Field,
356        ) -> TokenStream {
357            let crate_root = self.crate_root;
358            let parameterized_error_name = self.parameterized_error_name;
359            let error_constructor_name = self.error_constructor_name;
360            let construct_implicit_fields = self.construct_implicit_fields();
361            let original_generics_without_defaults = self.original_generics_without_defaults;
362            let construct_implicit_fields_with_source =
363                self.construct_implicit_fields_with_source();
364            let extended_where_clauses = self.extended_where_clauses();
365
366            // testme: transform
367
368            let (source_ty, transfer_source_field, empty_source_field) = match source_field {
369                Some(f) => {
370                    let source_field_type = f.transformation.source_ty();
371                    let source_field_name = &f.name;
372                    let source_transformation = f.transformation.transformation();
373
374                    (
375                        quote! { #source_field_type },
376                        Some(quote! { #source_field_name: (#source_transformation)(error), }),
377                        Some(quote! { #source_field_name: core::option::Option::None, }),
378                    )
379                }
380                None => (quote! { #crate_root::NoneError }, None, None),
381            };
382
383            let message_field_name = &message_field.name;
384
385            quote! {
386                impl<#(#original_generics_without_defaults,)*> #crate_root::FromString for #parameterized_error_name
387                where
388                    #(#extended_where_clauses),*
389                {
390                    type Source = #source_ty;
391
392                    #[track_caller]
393                    fn without_source(message: String) -> Self {
394                        #error_constructor_name {
395                            #construct_implicit_fields
396                            #empty_source_field
397                            #message_field_name: message,
398                        }
399                    }
400
401                    #[track_caller]
402                    fn with_source(error: Self::Source, message: String) -> Self {
403                        #error_constructor_name {
404                            #construct_implicit_fields_with_source
405                            #transfer_source_field
406                            #message_field_name: message,
407                        }
408                    }
409                }
410            }
411        }
412
413        fn generate_from_source(self, source_field: &crate::SourceField) -> TokenStream {
414            let parameterized_error_name = self.parameterized_error_name;
415            let error_constructor_name = self.error_constructor_name;
416            let construct_implicit_fields_with_source =
417                self.construct_implicit_fields_with_source();
418            let original_generics_without_defaults = self.original_generics_without_defaults;
419            let user_field_generics = self.user_field_generics();
420            let where_clauses = self.where_clauses;
421
422            let SourceInfo {
423                source_field_type,
424                transform_source,
425                transfer_source_field,
426            } = build_source_info(source_field);
427
428            quote! {
429                impl<#(#original_generics_without_defaults,)* #(#user_field_generics,)*> ::core::convert::From<#source_field_type> for #parameterized_error_name
430                where
431                    #(#where_clauses),*
432                {
433                    #[track_caller]
434                    fn from(error: #source_field_type) -> Self {
435                        #transform_source;
436                        #error_constructor_name {
437                            #construct_implicit_fields_with_source
438                            #transfer_source_field
439                        }
440                    }
441                }
442            }
443        }
444    }
445
446    struct SourceInfo<'a> {
447        source_field_type: &'a syn::Type,
448        transform_source: TokenStream,
449        transfer_source_field: TokenStream,
450    }
451
452    // Assumes that the error is in a variable called "error"
453    fn build_source_info(source_field: &crate::SourceField) -> SourceInfo<'_> {
454        let source_field_name = source_field.name();
455        let source_field_type = source_field.transformation.source_ty();
456        let target_field_type = source_field.transformation.target_ty();
457        let source_transformation = source_field.transformation.transformation();
458
459        let transform_source =
460            quote! { let error: #target_field_type = (#source_transformation)(error) };
461        let transfer_source_field = quote! { #source_field_name: error, };
462
463        SourceInfo {
464            source_field_type,
465            transform_source,
466            transfer_source_field,
467        }
468    }
469}
470
471pub mod display {
472    use super::StaticIdent;
473    use proc_macro2::TokenStream;
474    use quote::{quote, ToTokens};
475    use std::collections::BTreeSet;
476
477    const FORMATTER_ARG: StaticIdent = StaticIdent("__snafu_display_formatter");
478
479    pub(crate) struct Display<'a> {
480        pub(crate) arms: &'a [TokenStream],
481        pub(crate) original_generics: &'a [TokenStream],
482        pub(crate) parameterized_error_name: &'a dyn ToTokens,
483        pub(crate) where_clauses: &'a [TokenStream],
484    }
485
486    impl ToTokens for Display<'_> {
487        fn to_tokens(&self, stream: &mut TokenStream) {
488            let Self {
489                arms,
490                original_generics,
491                parameterized_error_name,
492                where_clauses,
493            } = *self;
494
495            let display_impl = quote! {
496                #[allow(single_use_lifetimes)]
497                impl<#(#original_generics),*> ::core::fmt::Display for #parameterized_error_name
498                where
499                    #(#where_clauses),*
500                {
501                    fn fmt(&self, #FORMATTER_ARG: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
502                        #[allow(unused_variables)]
503                        match *self {
504                            #(#arms),*
505                        }
506                    }
507                }
508            };
509
510            stream.extend(display_impl);
511        }
512    }
513
514    pub(crate) struct DisplayMatchArm<'a> {
515        pub(crate) field_container: &'a crate::FieldContainer,
516        pub(crate) default_name: &'a dyn ToTokens,
517        pub(crate) display_format: Option<&'a crate::Display>,
518        pub(crate) doc_comment: Option<&'a crate::DocComment>,
519        pub(crate) pattern_ident: &'a dyn ToTokens,
520        pub(crate) selector_kind: &'a crate::ContextSelectorKind,
521    }
522
523    impl ToTokens for DisplayMatchArm<'_> {
524        fn to_tokens(&self, stream: &mut TokenStream) {
525            let Self {
526                field_container,
527                default_name,
528                display_format,
529                doc_comment,
530                pattern_ident,
531                selector_kind,
532            } = *self;
533
534            let source_field = selector_kind.source_field();
535
536            if field_container.is_transparent {
537                // transparent errors always have a source field
538                let source_field_name = source_field.unwrap().name();
539
540                let match_arm = quote! {
541                    #pattern_ident { ref #source_field_name, .. } => {
542                        ::core::fmt::Display::fmt(#source_field_name, #FORMATTER_ARG)
543                    }
544                };
545
546                stream.extend(match_arm);
547                return;
548            }
549
550            let mut shorthand_names = &BTreeSet::new();
551            let mut assigned_names = &BTreeSet::new();
552
553            let format = match (display_format, doc_comment) {
554                (Some(v), _) => {
555                    let exprs = &v.exprs;
556                    shorthand_names = &v.shorthand_names;
557                    assigned_names = &v.assigned_names;
558                    quote! { #(#exprs),* }
559                }
560                (_, Some(d)) => {
561                    let content = &d.content;
562                    shorthand_names = &d.shorthand_names;
563                    quote! { #content }
564                }
565                _ => quote! { stringify!(#default_name) },
566            };
567
568            let field_names = super::AllFieldNames(field_container).field_names();
569
570            let shorthand_names = shorthand_names.iter().collect::<BTreeSet<_>>();
571            let assigned_names = assigned_names.iter().collect::<BTreeSet<_>>();
572
573            let shorthand_fields = &shorthand_names & &field_names;
574            let shorthand_fields = &shorthand_fields - &assigned_names;
575
576            let shorthand_assignments = quote! { #( #shorthand_fields = #shorthand_fields ),* };
577
578            let match_arm = quote! {
579                #pattern_ident { #(ref #field_names),* } => {
580                    write!(#FORMATTER_ARG, #format, #shorthand_assignments)
581                }
582            };
583
584            stream.extend(match_arm);
585        }
586    }
587}
588
589pub mod error {
590    use super::StaticIdent;
591    use crate::{FieldContainer, Provide, SourceField};
592    use proc_macro2::TokenStream;
593    use quote::{format_ident, quote, ToTokens};
594
595    pub(crate) const PROVIDE_ARG: StaticIdent = StaticIdent("__snafu_provide_demand");
596
597    pub(crate) struct Error<'a> {
598        pub(crate) crate_root: &'a dyn ToTokens,
599        pub(crate) description_arms: &'a [TokenStream],
600        pub(crate) original_generics: &'a [TokenStream],
601        pub(crate) parameterized_error_name: &'a dyn ToTokens,
602        pub(crate) provide_arms: &'a [TokenStream],
603        pub(crate) source_arms: &'a [TokenStream],
604        pub(crate) where_clauses: &'a [TokenStream],
605    }
606
607    impl ToTokens for Error<'_> {
608        fn to_tokens(&self, stream: &mut TokenStream) {
609            let Self {
610                crate_root,
611                description_arms,
612                original_generics,
613                parameterized_error_name,
614                provide_arms,
615                source_arms,
616                where_clauses,
617            } = *self;
618
619            let description_fn = quote! {
620                fn description(&self) -> &str {
621                    match *self {
622                        #(#description_arms)*
623                    }
624                }
625            };
626
627            let source_body = quote! {
628                use #crate_root::AsErrorSource;
629                match *self {
630                    #(#source_arms)*
631                }
632            };
633
634            let cause_fn = quote! {
635                fn cause(&self) -> ::core::option::Option<&dyn #crate_root::Error> {
636                    #source_body
637                }
638            };
639
640            let source_fn = quote! {
641                fn source(&self) -> ::core::option::Option<&(dyn #crate_root::Error + 'static)> {
642                    #source_body
643                }
644            };
645
646            let provide_fn = if cfg!(feature = "unstable-provider-api") {
647                Some(quote! {
648                    fn provide<'a>(&'a self, #PROVIDE_ARG: &mut #crate_root::error::Request<'a>) {
649                        match *self {
650                            #(#provide_arms,)*
651                        };
652                    }
653                })
654            } else {
655                None
656            };
657
658            let error = quote! {
659                #[allow(single_use_lifetimes)]
660                impl<#(#original_generics),*> #crate_root::Error for #parameterized_error_name
661                where
662                    Self: ::core::fmt::Debug + ::core::fmt::Display,
663                    #(#where_clauses),*
664                {
665                    #description_fn
666                    #cause_fn
667                    #source_fn
668                    #provide_fn
669                }
670            };
671
672            stream.extend(error);
673        }
674    }
675
676    pub(crate) struct ErrorSourceMatchArm<'a> {
677        pub(crate) field_container: &'a FieldContainer,
678        pub(crate) pattern_ident: &'a dyn ToTokens,
679    }
680
681    impl ToTokens for ErrorSourceMatchArm<'_> {
682        fn to_tokens(&self, stream: &mut TokenStream) {
683            let Self {
684                field_container:
685                    FieldContainer {
686                        selector_kind,
687                        is_transparent,
688                        ..
689                    },
690                pattern_ident,
691            } = *self;
692
693            let source_field = selector_kind.source_field();
694
695            let arm = match source_field {
696                Some(source_field) => {
697                    let SourceField {
698                        name: field_name, ..
699                    } = source_field;
700
701                    let convert_to_error_source = if selector_kind.is_whatever() {
702                        quote! {
703                            #field_name.as_ref().map(|e| e.as_error_source())
704                        }
705                    } else if *is_transparent {
706                        quote! {
707                            #field_name.as_error_source().source()
708                        }
709                    } else {
710                        quote! {
711                            ::core::option::Option::Some(#field_name.as_error_source())
712                        }
713                    };
714
715                    quote! {
716                        #pattern_ident { ref #field_name, .. } => {
717                            #convert_to_error_source
718                        }
719                    }
720                }
721                None => {
722                    quote! {
723                        #pattern_ident { .. } => { ::core::option::Option::None }
724                    }
725                }
726            };
727
728            stream.extend(arm);
729        }
730    }
731
    /// A `Provide` attribute paired with a unique local-variable name.
    ///
    /// For `chain` provides, the generated `provide` arm stores the expression's
    /// value in this variable.
    pub(crate) struct ProvidePlus<'a> {
        provide: &'a Provide,
        // `__snafu_cached_expr_{index}`, assigned by `enhance_provider_list`.
        cached_name: proc_macro2::Ident,
    }
736
737    pub(crate) struct ErrorProvideMatchArm<'a> {
738        pub(crate) crate_root: &'a dyn ToTokens,
739        pub(crate) field_container: &'a FieldContainer,
740        pub(crate) pattern_ident: &'a dyn ToTokens,
741    }
742
743    impl<'a> ToTokens for ErrorProvideMatchArm<'a> {
744        fn to_tokens(&self, stream: &mut TokenStream) {
745            let Self {
746                crate_root,
747                field_container,
748                pattern_ident,
749            } = *self;
750
751            let user_fields = field_container.user_fields();
752            let provides = enhance_provider_list(field_container.provides());
753            let field_names = super::AllFieldNames(field_container).field_names();
754
755            let (hi_explicit_calls, lo_explicit_calls) = build_explicit_provide_calls(&provides);
756
757            let cached_expressions = quote_cached_expressions(&provides);
758
759            let provide_refs = user_fields
760                .iter()
761                .chain(&field_container.implicit_fields)
762                .chain(field_container.selector_kind.message_field())
763                .flat_map(|f| {
764                    if f.provide {
765                        Some((&f.ty, f.name()))
766                    } else {
767                        None
768                    }
769                });
770
771            let provided_source = field_container
772                .selector_kind
773                .source_field()
774                .filter(|f| f.provide);
775
776            let source_provide_ref =
777                provided_source.map(|f| (f.transformation.source_ty(), f.name()));
778
779            let provide_refs = provide_refs.chain(source_provide_ref);
780
781            let source_chain = provided_source.map(|f| {
782                let name = f.name();
783                quote! {
784                    #name.provide(#PROVIDE_ARG);
785                }
786            });
787
788            let user_chained = quote_chained(crate_root, &provides);
789
790            let shorthand_calls = provide_refs.map(|(ty, name)| {
791                quote! { #PROVIDE_ARG.provide_ref::<#ty>(#name) }
792            });
793
794            let provided_backtrace = field_container
795                .backtrace_field
796                .as_ref()
797                .filter(|f| f.provide);
798
799            let provide_backtrace = provided_backtrace.map(|f| {
800                let name = f.name();
801                quote! {
802                    if #PROVIDE_ARG.would_be_satisfied_by_ref_of::<#crate_root::Backtrace>() {
803                        if let ::core::option::Option::Some(bt) = #crate_root::AsBacktrace::as_backtrace(#name) {
804                            #PROVIDE_ARG.provide_ref::<#crate_root::Backtrace>(bt);
805                        }
806                    }
807                }
808            });
809
810            let arm = quote! {
811                #pattern_ident { #(ref #field_names,)* .. } => {
812                    #(#cached_expressions;)*
813                    #(#hi_explicit_calls;)*
814                    #source_chain;
815                    #(#user_chained;)*
816                    #provide_backtrace;
817                    #(#shorthand_calls;)*
818                    #(#lo_explicit_calls;)*
819                }
820            };
821
822            stream.extend(arm);
823        }
824    }
825
826    pub(crate) fn enhance_provider_list(provides: &[Provide]) -> Vec<ProvidePlus<'_>> {
827        provides
828            .iter()
829            .enumerate()
830            .map(|(i, provide)| {
831                let cached_name = format_ident!("__snafu_cached_expr_{}", i);
832                ProvidePlus {
833                    provide,
834                    cached_name,
835                }
836            })
837            .collect()
838    }
839
840    pub(crate) fn quote_cached_expressions<'a>(
841        provides: &'a [ProvidePlus<'a>],
842    ) -> impl Iterator<Item = proc_macro2::TokenStream> + 'a {
843        provides.iter().filter(|pp| pp.provide.is_chain).map(|pp| {
844            let cached_name = &pp.cached_name;
845            let expr = &pp.provide.expr;
846
847            quote! {
848                let #cached_name = #expr;
849            }
850        })
851    }
852
853    pub(crate) fn quote_chained<'a>(
854        crate_root: &'a dyn ToTokens,
855        provides: &'a [ProvidePlus<'a>],
856    ) -> impl Iterator<Item = proc_macro2::TokenStream> + 'a {
857        provides
858            .iter()
859            .filter(|pp| pp.provide.is_chain)
860            .map(move |pp| {
861                let arm = if pp.provide.is_opt {
862                    quote! { ::core::option::Option::Some(chained_item) }
863                } else {
864                    quote! { chained_item }
865                };
866                let cached_name = &pp.cached_name;
867
868                quote! {
869                    if let #arm = #cached_name {
870                        #crate_root::Error::provide(chained_item, #PROVIDE_ARG);
871                    }
872                }
873            })
874    }
875
876    fn quote_provides<'a, I>(provides: I) -> impl Iterator<Item = proc_macro2::TokenStream> + 'a
877    where
878        I: IntoIterator<Item = &'a ProvidePlus<'a>>,
879        I::IntoIter: 'a,
880    {
881        provides.into_iter().map(|pp| {
882            let ProvidePlus {
883                provide:
884                    Provide {
885                        is_chain,
886                        is_opt,
887                        is_priority: _,
888                        is_ref,
889                        ty,
890                        expr,
891                    },
892                cached_name,
893            } = pp;
894
895            let effective_expr = if *is_chain {
896                quote! { #cached_name }
897            } else {
898                quote! { #expr }
899            };
900
901            match (is_opt, is_ref) {
902                (true, true) => {
903                    quote! {
904                        if #PROVIDE_ARG.would_be_satisfied_by_ref_of::<#ty>() {
905                            if let ::core::option::Option::Some(v) = #effective_expr {
906                                #PROVIDE_ARG.provide_ref::<#ty>(v);
907                            }
908                        }
909                    }
910                }
911                (true, false) => {
912                    quote! {
913                        if #PROVIDE_ARG.would_be_satisfied_by_value_of::<#ty>() {
914                            if let ::core::option::Option::Some(v) = #effective_expr {
915                                #PROVIDE_ARG.provide_value::<#ty>(v);
916                            }
917                        }
918                    }
919                }
920                (false, true) => {
921                    quote! { #PROVIDE_ARG.provide_ref_with::<#ty>(|| #effective_expr) }
922                }
923                (false, false) => {
924                    quote! { #PROVIDE_ARG.provide_value_with::<#ty>(|| #effective_expr) }
925                }
926            }
927        })
928    }
929
930    pub(crate) fn build_explicit_provide_calls<'a>(
931        provides: &'a [ProvidePlus<'a>],
932    ) -> (
933        impl Iterator<Item = TokenStream> + 'a,
934        impl Iterator<Item = TokenStream> + 'a,
935    ) {
936        let (high_priority, low_priority): (Vec<_>, Vec<_>) =
937            provides.iter().partition(|pp| pp.provide.is_priority);
938
939        let hi_explicit_calls = quote_provides(high_priority);
940        let lo_explicit_calls = quote_provides(low_priority);
941
942        (hi_explicit_calls, lo_explicit_calls)
943    }
944}
945
946pub mod error_compat {
947    use crate::{Field, FieldContainer, SourceField};
948    use proc_macro2::TokenStream;
949    use quote::{quote, ToTokens};
950
951    pub(crate) struct ErrorCompat<'a> {
952        pub(crate) crate_root: &'a dyn ToTokens,
953        pub(crate) parameterized_error_name: &'a dyn ToTokens,
954        pub(crate) backtrace_arms: &'a [TokenStream],
955        pub(crate) original_generics: &'a [TokenStream],
956        pub(crate) where_clauses: &'a [TokenStream],
957    }
958
959    impl ToTokens for ErrorCompat<'_> {
960        fn to_tokens(&self, stream: &mut TokenStream) {
961            let Self {
962                crate_root,
963                parameterized_error_name,
964                backtrace_arms,
965                original_generics,
966                where_clauses,
967            } = *self;
968
969            let backtrace_fn = quote! {
970                fn backtrace(&self) -> ::core::option::Option<&#crate_root::Backtrace> {
971                    match *self {
972                        #(#backtrace_arms),*
973                    }
974                }
975            };
976
977            let error_compat_impl = quote! {
978                #[allow(single_use_lifetimes)]
979                impl<#(#original_generics),*> #crate_root::ErrorCompat for #parameterized_error_name
980                where
981                    #(#where_clauses),*
982                {
983                    #backtrace_fn
984                }
985            };
986
987            stream.extend(error_compat_impl);
988        }
989    }
990
991    pub(crate) struct ErrorCompatBacktraceMatchArm<'a> {
992        pub(crate) crate_root: &'a dyn ToTokens,
993        pub(crate) field_container: &'a FieldContainer,
994        pub(crate) pattern_ident: &'a dyn ToTokens,
995    }
996
997    impl ToTokens for ErrorCompatBacktraceMatchArm<'_> {
998        fn to_tokens(&self, stream: &mut TokenStream) {
999            let Self {
1000                crate_root,
1001                field_container:
1002                    FieldContainer {
1003                        backtrace_field,
1004                        selector_kind,
1005                        ..
1006                    },
1007                pattern_ident,
1008            } = *self;
1009
1010            let match_arm = match (selector_kind.source_field(), backtrace_field) {
1011                (Some(source_field), _) if source_field.backtrace_delegate => {
1012                    let SourceField {
1013                        name: field_name, ..
1014                    } = source_field;
1015                    quote! {
1016                        #pattern_ident { ref #field_name, .. } => { #crate_root::ErrorCompat::backtrace(#field_name) }
1017                    }
1018                }
1019                (_, Some(backtrace_field)) => {
1020                    let Field {
1021                        name: field_name, ..
1022                    } = backtrace_field;
1023                    quote! {
1024                        #pattern_ident { ref #field_name, .. } => { #crate_root::AsBacktrace::as_backtrace(#field_name) }
1025                    }
1026                }
1027                _ => {
1028                    quote! {
1029                        #pattern_ident { .. } => { ::core::option::Option::None }
1030                    }
1031                }
1032            };
1033
1034            stream.extend(match_arm);
1035        }
1036    }
1037}