windows_implement/
lib.rs

1//! Implement COM interfaces for Rust types.
2//!
3//! Take a look at [macro@implement] for an example.
4//!
5//! Learn more about Rust for Windows here: <https://github.com/microsoft/windows-rs>
6
7use quote::{quote, ToTokens};
8
9/// Implements one or more COM interfaces.
10///
11/// # Example
12/// ```rust,no_run
13/// use windows_core::*;
14///
15/// #[interface("094d70d6-5202-44b8-abb8-43860da5aca2")]
16/// unsafe trait IValue: IUnknown {
17///     fn GetValue(&self, value: *mut i32) -> HRESULT;
18/// }
19///
20/// #[implement(IValue)]
21/// struct Value(i32);
22///
23/// impl IValue_Impl for Value_Impl {
24///     unsafe fn GetValue(&self, value: *mut i32) -> HRESULT {
25///         *value = self.0;
26///         HRESULT(0)
27///     }
28/// }
29///
30/// let object: IValue = Value(123).into();
31/// // Call interface methods...
32/// ```
33#[proc_macro_attribute]
34pub fn implement(
35    attributes: proc_macro::TokenStream,
36    original_type: proc_macro::TokenStream,
37) -> proc_macro::TokenStream {
38    let attributes = syn::parse_macro_input!(attributes as ImplementAttributes);
39    let interfaces_len = proc_macro2::Literal::usize_unsuffixed(attributes.implement.len());
40
41    let identity_type = if let Some(first) = attributes.implement.first() {
42        first.to_ident()
43    } else {
44        quote! { ::windows_core::IInspectable }
45    };
46
47    let original_type2 = original_type.clone();
48    let original_type2 = syn::parse_macro_input!(original_type2 as syn::ItemStruct);
49    let vis = &original_type2.vis;
50    let original_ident = &original_type2.ident;
51    let mut constraints = quote! {};
52
53    if let Some(where_clause) = &original_type2.generics.where_clause {
54        where_clause.predicates.to_tokens(&mut constraints);
55    }
56
57    let generics = if original_type2.generics.lt_token.is_some() {
58        let mut params = quote! {};
59        original_type2.generics.params.to_tokens(&mut params);
60        quote! { <#params> }
61    } else {
62        quote! { <> }
63    };
64
65    let impl_ident = quote::format_ident!("{}_Impl", original_ident);
66    let vtbl_idents = attributes
67        .implement
68        .iter()
69        .map(|implement| implement.to_vtbl_ident());
70    let vtbl_idents2 = vtbl_idents.clone();
71
72    let vtable_news = attributes
73        .implement
74        .iter()
75        .enumerate()
76        .map(|(enumerate, implement)| {
77            let vtbl_ident = implement.to_vtbl_ident();
78            let offset = proc_macro2::Literal::isize_unsuffixed(-1 - enumerate as isize);
79            quote! { #vtbl_ident::new::<Self, #offset>() }
80        });
81
82    let offset = attributes
83        .implement
84        .iter()
85        .enumerate()
86        .map(|(offset, _)| proc_macro2::Literal::usize_unsuffixed(offset));
87
88    let queries = attributes
89        .implement
90        .iter()
91        .enumerate()
92        .map(|(count, implement)| {
93            let vtbl_ident = implement.to_vtbl_ident();
94            let offset = proc_macro2::Literal::usize_unsuffixed(count);
95            quote! {
96                else if #vtbl_ident::matches(iid) {
97                    &self.vtables.#offset as *const _ as *mut _
98                }
99            }
100        });
101
102    // Dynamic casting requires that the object not contain non-static lifetimes.
103    let enable_dyn_casting = original_type2.generics.lifetimes().count() == 0;
104    let dynamic_cast_query = if enable_dyn_casting {
105        quote! {
106            else if *iid == ::windows_core::DYNAMIC_CAST_IID {
107                // DYNAMIC_CAST_IID is special. We _do not_ increase the reference count for this pseudo-interface.
108                // Also, instead of returning an interface pointer, we simply write the `&dyn Any` directly to the
109                // 'interface' pointer. Since the size of `&dyn Any` is 2 pointers, not one, the caller must be
110                // prepared for this. This is not a normal QueryInterface call.
111                //
112                // See the `Interface::cast_to_any` method, which is the only caller that should use DYNAMIC_CAST_ID.
113                (interface as *mut *const dyn core::any::Any).write(self as &dyn ::core::any::Any as *const dyn ::core::any::Any);
114                return ::windows_core::HRESULT(0);
115            }
116        }
117    } else {
118        quote!()
119    };
120
121    // The distance from the beginning of the generated type to the 'this' field, in units of pointers (not bytes).
122    let offset_of_this_in_pointers = 1 + attributes.implement.len();
123    let offset_of_this_in_pointers_token =
124        proc_macro2::Literal::usize_unsuffixed(offset_of_this_in_pointers);
125
126    let trust_level = proc_macro2::Literal::usize_unsuffixed(attributes.trust_level);
127
128    let conversions = attributes.implement.iter().enumerate().map(|(enumerate, implement)| {
129        let interface_ident = implement.to_ident();
130        let offset = proc_macro2::Literal::usize_unsuffixed(enumerate);
131        quote! {
132            impl #generics ::core::convert::From<#original_ident::#generics> for #interface_ident where #constraints {
133                #[inline(always)]
134                fn from(this: #original_ident::#generics) -> Self {
135                    let com_object = ::windows_core::ComObject::new(this);
136                    com_object.into_interface()
137                }
138            }
139
140            impl #generics ::windows_core::ComObjectInterface<#interface_ident> for #impl_ident::#generics where #constraints {
141                #[inline(always)]
142                fn as_interface_ref(&self) -> ::windows_core::InterfaceRef<'_, #interface_ident> {
143                    unsafe {
144                        let interface_ptr = &self.vtables.#offset;
145                        ::core::mem::transmute(interface_ptr)
146                    }
147                }
148            }
149
150            impl #generics ::windows_core::AsImpl<#original_ident::#generics> for #interface_ident where #constraints {
151                // SAFETY: the offset is guranteed to be in bounds, and the implementation struct
152                // is guaranteed to live at least as long as `self`.
153                #[inline(always)]
154                unsafe fn as_impl_ptr(&self) -> ::core::ptr::NonNull<#original_ident::#generics> {
155                    let this = ::windows_core::Interface::as_raw(self);
156                    // Subtract away the vtable offset plus 1, for the `identity` field, to get
157                    // to the impl struct which contains that original implementation type.
158                    let this = (this as *mut *mut ::core::ffi::c_void).sub(1 + #offset) as *mut #impl_ident::#generics;
159                    ::core::ptr::NonNull::new_unchecked(::core::ptr::addr_of!((*this).this) as *const #original_ident::#generics as *mut #original_ident::#generics)
160                }
161            }
162        }
163    });
164
165    let tokens = quote! {
166        #[repr(C)]
167        #[allow(non_camel_case_types)]
168        #vis struct #impl_ident #generics where #constraints {
169            identity: &'static ::windows_core::IInspectable_Vtbl,
170            vtables: (#(&'static #vtbl_idents,)*),
171            this: #original_ident::#generics,
172            count: ::windows_core::imp::WeakRefCount,
173        }
174
175        impl #generics #impl_ident::#generics where #constraints {
176            const VTABLES: (#(#vtbl_idents2,)*) = (#(#vtable_news,)*);
177            const IDENTITY: ::windows_core::IInspectable_Vtbl = ::windows_core::IInspectable_Vtbl::new::<Self, #identity_type, 0>();
178        }
179
180        impl #generics #original_ident::#generics where #constraints {
181            /// This converts a partially-constructed COM object (in the sense that it contains
182            /// application state but does not yet have vtable and reference count constructed)
183            /// into a `StaticComObject`. This allows the COM object to be stored in static
184            /// (global) variables.
185            pub const fn into_static(self) -> ::windows_core::StaticComObject<Self> {
186                ::windows_core::StaticComObject::from_outer(self.into_outer())
187            }
188
189            // This constructs an "outer" object. This should only be used by the implementation
190            // of the outer object, never by application code.
191            //
192            // The callers of this function (`into_static` and `into_object`) are both responsible
193            // for maintaining one of our invariants: Application code never has an owned instance
194            // of the outer (implementation) type. into_static() maintains this invariant by
195            // returning a wrapped StaticComObject value, which owns its contents but never gives
196            // application code a way to mutably access its contents. This prevents the refcount
197            // shearing problem.
198            //
199            // TODO: Make it impossible for app code to call this function, by placing it in a
200            // module and marking this as private to the module.
201            #[inline(always)]
202            const fn into_outer(self) -> #impl_ident::#generics {
203                #impl_ident::#generics {
204                    identity: &#impl_ident::#generics::IDENTITY,
205                    vtables: (#(&#impl_ident::#generics::VTABLES.#offset,)*),
206                    this: self,
207                    count: ::windows_core::imp::WeakRefCount::new(),
208                }
209            }
210        }
211
212        impl #generics ::windows_core::ComObjectInner for #original_ident::#generics where #constraints {
213            type Outer = #impl_ident::#generics;
214
215            // IMPORTANT! This function handles assembling the "boxed" type of a COM object.
216            // It immediately moves the box into a heap allocation (box) and returns only a ComObject
217            // reference that points to it. We intentionally _do not_ expose any owned instances of
218            // Foo_Impl to safe Rust code, because doing so would allow unsound behavior in safe Rust
219            // code, due to the adjustments of the reference count that Foo_Impl permits.
220            //
221            // This is why this function returns ComObject<Self> instead of returning #impl_ident.
222
223            fn into_object(self) -> ::windows_core::ComObject<Self> {
224                let boxed = ::windows_core::imp::Box::<#impl_ident::#generics>::new(self.into_outer());
225                unsafe {
226                    let ptr = ::windows_core::imp::Box::into_raw(boxed);
227                    ::windows_core::ComObject::from_raw(
228                        ::core::ptr::NonNull::new_unchecked(ptr)
229                    )
230                }
231            }
232        }
233
234        impl #generics ::windows_core::IUnknownImpl for #impl_ident::#generics where #constraints {
235            type Impl = #original_ident::#generics;
236
237            #[inline(always)]
238            fn get_impl(&self) -> &Self::Impl {
239                &self.this
240            }
241
242            #[inline(always)]
243            fn get_impl_mut(&mut self) -> &mut Self::Impl {
244                &mut self.this
245            }
246
247            #[inline(always)]
248            fn is_reference_count_one(&self) -> bool {
249                self.count.is_one()
250            }
251
252            #[inline(always)]
253            fn into_inner(self) -> Self::Impl {
254                self.this
255            }
256
257            unsafe fn QueryInterface(&self, iid: *const ::windows_core::GUID, interface: *mut *mut ::core::ffi::c_void) -> ::windows_core::HRESULT {
258                if iid.is_null() || interface.is_null() {
259                    return ::windows_core::imp::E_POINTER;
260                }
261
262                let iid = &*iid;
263
264                let interface_ptr: *mut ::core::ffi::c_void = if iid == &<::windows_core::IUnknown as ::windows_core::Interface>::IID
265                    || iid == &<::windows_core::IInspectable as ::windows_core::Interface>::IID
266                    || iid == &<::windows_core::imp::IAgileObject as ::windows_core::Interface>::IID {
267                        &self.identity as *const _ as *mut _
268                }
269                #(#queries)*
270                #dynamic_cast_query
271                else {
272                    ::core::ptr::null_mut()
273                };
274
275                if !interface_ptr.is_null() {
276                    *interface = interface_ptr;
277                    self.count.add_ref();
278                    return ::windows_core::HRESULT(0);
279                }
280
281                let interface_ptr = self.count.query(iid, &self.identity as *const _ as *mut _);
282                *interface = interface_ptr;
283
284                if interface_ptr.is_null() {
285                    ::windows_core::imp::E_NOINTERFACE
286                } else {
287                    ::windows_core::HRESULT(0)
288                }
289            }
290
291            #[inline(always)]
292            fn AddRef(&self) -> u32 {
293                self.count.add_ref()
294            }
295
296            #[inline(always)]
297            unsafe fn Release(self_: *mut Self) -> u32 {
298                let remaining = (*self_).count.release();
299                if remaining == 0 {
300                    _ = ::windows_core::imp::Box::from_raw(self_);
301                }
302                remaining
303            }
304
305            unsafe fn GetTrustLevel(&self, value: *mut i32) -> ::windows_core::HRESULT {
306                if value.is_null() {
307                    return ::windows_core::imp::E_POINTER;
308                }
309                *value = #trust_level;
310                ::windows_core::HRESULT(0)
311            }
312
313            unsafe fn from_inner_ref(inner: &Self::Impl) -> &Self {
314                &*((inner as *const Self::Impl as *const *const ::core::ffi::c_void)
315                    .sub(#offset_of_this_in_pointers_token) as *const Self)
316            }
317
318            fn to_object(&self) -> ::windows_core::ComObject<Self::Impl> {
319                self.count.add_ref();
320                unsafe {
321                    ::windows_core::ComObject::from_raw(
322                        ::core::ptr::NonNull::new_unchecked(self as *const Self as *mut Self)
323                    )
324                }
325            }
326
327            const INNER_OFFSET_IN_POINTERS: usize = #offset_of_this_in_pointers_token;
328        }
329
330        impl #generics #original_ident::#generics where #constraints {
331            /// Try casting as the provided interface
332            ///
333            /// # Safety
334            ///
335            /// This function can only be safely called if `self` has been heap allocated and pinned using
336            /// the mechanisms provided by `implement` macro.
337            #[inline(always)]
338            unsafe fn cast<I: ::windows_core::Interface>(&self) -> ::windows_core::Result<I> {
339                let boxed = (self as *const _ as *const *mut ::core::ffi::c_void).sub(1 + #interfaces_len) as *mut #impl_ident::#generics;
340                let mut result = ::core::ptr::null_mut();
341                _ = <#impl_ident::#generics as ::windows_core::IUnknownImpl>::QueryInterface(&*boxed, &I::IID, &mut result);
342                ::windows_core::Type::from_abi(result)
343            }
344        }
345
346        impl #generics ::core::convert::From<#original_ident::#generics> for ::windows_core::IUnknown where #constraints {
347            #[inline(always)]
348            fn from(this: #original_ident::#generics) -> Self {
349                let com_object = ::windows_core::ComObject::new(this);
350                com_object.into_interface()
351            }
352        }
353
354        impl #generics ::core::convert::From<#original_ident::#generics> for ::windows_core::IInspectable where #constraints {
355            #[inline(always)]
356            fn from(this: #original_ident::#generics) -> Self {
357                let com_object = ::windows_core::ComObject::new(this);
358                com_object.into_interface()
359            }
360        }
361
362        impl #generics ::windows_core::ComObjectInterface<::windows_core::IUnknown> for #impl_ident::#generics where #constraints {
363            #[inline(always)]
364            fn as_interface_ref(&self) -> ::windows_core::InterfaceRef<'_, ::windows_core::IUnknown> {
365                unsafe {
366                    let interface_ptr = &self.identity;
367                    ::core::mem::transmute(interface_ptr)
368                }
369            }
370        }
371
372        impl #generics ::windows_core::ComObjectInterface<::windows_core::IInspectable> for #impl_ident::#generics where #constraints {
373            #[inline(always)]
374            fn as_interface_ref(&self) -> ::windows_core::InterfaceRef<'_, ::windows_core::IInspectable> {
375                unsafe {
376                    let interface_ptr = &self.identity;
377                    ::core::mem::transmute(interface_ptr)
378                }
379            }
380        }
381
382        impl #generics ::windows_core::AsImpl<#original_ident::#generics> for ::windows_core::IUnknown where #constraints {
383            // SAFETY: the offset is guranteed to be in bounds, and the implementation struct
384            // is guaranteed to live at least as long as `self`.
385            #[inline(always)]
386            unsafe fn as_impl_ptr(&self) -> ::core::ptr::NonNull<#original_ident::#generics> {
387                let this = ::windows_core::Interface::as_raw(self);
388                // Subtract away the vtable offset plus 1, for the `identity` field, to get
389                // to the impl struct which contains that original implementation type.
390                let this = (this as *mut *mut ::core::ffi::c_void).sub(1) as *mut #impl_ident::#generics;
391                ::core::ptr::NonNull::new_unchecked(::core::ptr::addr_of!((*this).this) as *const #original_ident::#generics as *mut #original_ident::#generics)
392            }
393        }
394
395        impl #generics ::core::ops::Deref for #impl_ident::#generics where #constraints {
396            type Target = #original_ident::#generics;
397
398            #[inline(always)]
399            fn deref(&self) -> &Self::Target {
400                &self.this
401            }
402        }
403
404        // We intentionally do not provide a DerefMut impl, due to paranoia around soundness.
405
406        #(#conversions)*
407    };
408
409    let mut tokens: proc_macro::TokenStream = tokens.into();
410    tokens.extend(core::iter::once(original_type));
411    tokens
412}
413
414#[derive(Default)]
415struct ImplementType {
416    type_name: String,
417    generics: Vec<ImplementType>,
418}
419
420impl ImplementType {
421    fn to_ident(&self) -> proc_macro2::TokenStream {
422        let type_name = syn::parse_str::<proc_macro2::TokenStream>(&self.type_name)
423            .expect("Invalid token stream");
424        let generics = self.generics.iter().map(|g| g.to_ident());
425        quote! { #type_name<#(#generics,)*> }
426    }
427    fn to_vtbl_ident(&self) -> proc_macro2::TokenStream {
428        let ident = self.to_ident();
429        quote! {
430            <#ident as ::windows_core::Interface>::Vtable
431        }
432    }
433}
434
435#[derive(Default)]
436struct ImplementAttributes {
437    pub implement: Vec<ImplementType>,
438    pub trust_level: usize,
439}
440
441impl syn::parse::Parse for ImplementAttributes {
442    fn parse(cursor: syn::parse::ParseStream<'_>) -> syn::parse::Result<Self> {
443        let mut input = Self::default();
444
445        while !cursor.is_empty() {
446            input.parse_implement(cursor)?;
447        }
448
449        Ok(input)
450    }
451}
452
453impl ImplementAttributes {
454    fn parse_implement(&mut self, cursor: syn::parse::ParseStream<'_>) -> syn::parse::Result<()> {
455        let tree = cursor.parse::<UseTree2>()?;
456        self.walk_implement(&tree, &mut String::new())?;
457
458        if !cursor.is_empty() {
459            cursor.parse::<syn::Token![,]>()?;
460        }
461
462        Ok(())
463    }
464
465    fn walk_implement(
466        &mut self,
467        tree: &UseTree2,
468        namespace: &mut String,
469    ) -> syn::parse::Result<()> {
470        match tree {
471            UseTree2::Path(input) => {
472                if !namespace.is_empty() {
473                    namespace.push_str("::");
474                }
475
476                namespace.push_str(&input.ident.to_string());
477                self.walk_implement(&input.tree, namespace)?;
478            }
479            UseTree2::Name(_) => {
480                self.implement.push(tree.to_element_type(namespace)?);
481            }
482            UseTree2::Group(input) => {
483                for tree in &input.items {
484                    self.walk_implement(tree, namespace)?;
485                }
486            }
487            UseTree2::TrustLevel(input) => self.trust_level = *input,
488        }
489
490        Ok(())
491    }
492}
493
494enum UseTree2 {
495    Path(UsePath2),
496    Name(UseName2),
497    Group(UseGroup2),
498    TrustLevel(usize),
499}
500
501impl UseTree2 {
502    fn to_element_type(&self, namespace: &mut String) -> syn::parse::Result<ImplementType> {
503        match self {
504            UseTree2::Path(input) => {
505                if !namespace.is_empty() {
506                    namespace.push_str("::");
507                }
508
509                namespace.push_str(&input.ident.to_string());
510                input.tree.to_element_type(namespace)
511            }
512            UseTree2::Name(input) => {
513                let mut type_name = input.ident.to_string();
514
515                if !namespace.is_empty() {
516                    type_name = format!("{namespace}::{type_name}");
517                }
518
519                let mut generics = vec![];
520
521                for g in &input.generics {
522                    generics.push(g.to_element_type(&mut String::new())?);
523                }
524
525                Ok(ImplementType {
526                    type_name,
527                    generics,
528                })
529            }
530            UseTree2::Group(input) => Err(syn::parse::Error::new(
531                input.brace_token.span.join(),
532                "Syntax not supported",
533            )),
534            _ => unimplemented!(),
535        }
536    }
537}
538
539struct UsePath2 {
540    pub ident: syn::Ident,
541    pub tree: Box<UseTree2>,
542}
543
544struct UseName2 {
545    pub ident: syn::Ident,
546    pub generics: Vec<UseTree2>,
547}
548
549struct UseGroup2 {
550    pub brace_token: syn::token::Brace,
551    pub items: syn::punctuated::Punctuated<UseTree2, syn::Token![,]>,
552}
553
554impl syn::parse::Parse for UseTree2 {
555    fn parse(input: syn::parse::ParseStream<'_>) -> syn::parse::Result<UseTree2> {
556        let lookahead = input.lookahead1();
557        if lookahead.peek(syn::Ident) {
558            use syn::ext::IdentExt;
559            let ident = input.call(syn::Ident::parse_any)?;
560            if input.peek(syn::Token![::]) {
561                input.parse::<syn::Token![::]>()?;
562                Ok(UseTree2::Path(UsePath2 {
563                    ident,
564                    tree: Box::new(input.parse()?),
565                }))
566            } else if input.peek(syn::Token![=]) {
567                if ident != "TrustLevel" {
568                    return Err(syn::parse::Error::new(
569                        ident.span(),
570                        "Unrecognized key-value pair",
571                    ));
572                }
573                input.parse::<syn::Token![=]>()?;
574                let span = input.span();
575                let value = input.call(syn::Ident::parse_any)?;
576                match value.to_string().as_str() {
577                    "Partial" => Ok(UseTree2::TrustLevel(1)),
578                    "Full" => Ok(UseTree2::TrustLevel(2)),
579                    _ => Err(syn::parse::Error::new(
580                        span,
581                        "`TrustLevel` must be `Partial` or `Full`",
582                    )),
583                }
584            } else {
585                let generics = if input.peek(syn::Token![<]) {
586                    input.parse::<syn::Token![<]>()?;
587                    let mut generics = Vec::new();
588                    loop {
589                        generics.push(input.parse::<UseTree2>()?);
590
591                        if input.parse::<syn::Token![,]>().is_err() {
592                            break;
593                        }
594                    }
595                    input.parse::<syn::Token![>]>()?;
596                    generics
597                } else {
598                    Vec::new()
599                };
600
601                Ok(UseTree2::Name(UseName2 { ident, generics }))
602            }
603        } else if lookahead.peek(syn::token::Brace) {
604            let content;
605            let brace_token = syn::braced!(content in input);
606            let items = content.parse_terminated(UseTree2::parse, syn::Token![,])?;
607
608            Ok(UseTree2::Group(UseGroup2 { brace_token, items }))
609        } else {
610            Err(lookahead.error())
611        }
612    }
613}