commonlibsse_ng\skse/
trampoline.rs

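//! Function hooking via detours: redirect calls to an existing function to a replacement function, while keeping a trampoline through which the original code can still be called (for example, from inside the replacement).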
2use dashmap::DashMap;
3use retour::RawDetour;
4use std::sync::OnceLock;
5
6/// - Key: 'static pointer to original function
7/// - value: trampoline pair fn ptr.
8pub fn get_trampoline() -> &'static DashMap<usize, RawDetour> {
9    static HOOKS: OnceLock<DashMap<usize, RawDetour>> = OnceLock::new();
10    HOOKS.get_or_init(DashMap::new)
11}
12
13/// Sets up a trampoline to replace the original function with a new one.
14///
15/// # Safety
16/// This is an unsafe function as it manipulates raw pointers and enables function detouring.
17/// - Both the `original` and `replacement` pointers must point to valid functions with matching signatures.
18/// - Undefined behavior may occur if the functions have different calling conventions or incompatible arguments.
19///
20/// # Errors
21/// This function returns a `retour::Error` if:
22/// - The detour cannot be created due to invalid function pointers.
23/// - Enabling the hook fails.
24///
25/// # Example
26/// ```
27/// use retour::Error;
28/// use commonlibsse_ng::skse::trampoline::add_hook;
29///
30/// fn add5(val: i32) -> i32 {
31///     val + 5
32/// }
33///
34/// fn add10(val: i32) -> i32 {
35///     val + 10
36/// }
37///
38/// let original = add5 as *const ();
39/// let replacement = add10 as *const ();
40///
41/// // Verify the original behavior
42/// assert_eq!(add5(5), 10);
43///
44/// // Replace the original function with the new one
45/// unsafe { add_hook(original, replacement) }.unwrap();
46/// assert_eq!(add5(5), 15);
47/// ```
48pub unsafe fn add_hook(original: *const (), replacement: *const ()) -> Result<(), retour::Error> {
49    let detour = unsafe { RawDetour::new(original, replacement)? };
50    unsafe { detour.enable() }?;
51    get_trampoline().insert(original.addr(), detour);
52    Ok(())
53}
54
55/// Removes a previously added hook by disabling the trampoline and restoring the original function.
56///
57/// # Errors
58/// This function returns a `retour::Error` if:
59/// - Disabling the hook fails.
60///
61/// # Example
62/// ```
63/// use commonlibsse_ng::skse::trampoline::{add_hook, remove_hook};
64/// use retour::Error;
65///
66/// fn add5(val: i32) -> i32 {
67///     val + 5
68/// }
69///
70/// fn add10(val: i32) -> i32 {
71///     val + 10
72/// }
73///
74/// let original = add5 as *const ();
75/// let replacement = add10 as *const ();
76///
77/// // Verify the original behavior
78/// assert_eq!(add5(5), 10);
79///
80/// // Replace the original function with the new one
81/// unsafe { add_hook(original, replacement) }.unwrap();
82/// assert_eq!(add5(5), 15); // Initially, the behavior is to add5 to add10
83///
84/// // Remove the hook (if added previously)
85/// remove_hook(original).unwrap();
86/// assert_eq!(add5(5), 10);
87/// ```
88pub fn remove_hook(original: *const ()) -> Result<(), retour::Error> {
89    if let Some((_, detour)) = get_trampoline().remove(&original.addr()) {
90        unsafe { detour.disable()? };
91    }
92
93    Ok(())
94}
95
96/// Enables a previously added hook, replacing the original function with the new one.
97///
98/// # Errors
99/// This function returns a `retour::Error` if:
100/// - Enabling the hook fails.
101///
102/// # Example
103/// ```
104/// use commonlibsse_ng::skse::trampoline::{add_hook, remove_hook, enable_hook};
105/// use retour::Error;
106///
107/// fn add5(val: i32) -> i32 {
108///     val + 5
109/// }
110///
111/// fn add10(val: i32) -> i32 {
112///     val + 10
113/// }
114///
115/// let original = add5 as *const ();
116/// let replacement = add10 as *const ();
117///
118/// // Initially, the behavior is to add 5
119/// assert_eq!(add5(5), 10);
120///
121/// // Add a hook to replace `add5` with `add10`
122/// unsafe { add_hook(original, replacement) }.unwrap();
123/// assert_eq!(add5(5), 15);
124///
125/// assert!(enable_hook(original).is_ok());
126/// assert_eq!(add5(5), 15);
127/// ```
128pub fn enable_hook(original: *const ()) -> Result<(), retour::Error> {
129    if let Some(detour) = get_trampoline().get(&original.addr()) {
130        unsafe { detour.enable() }?;
131    }
132    Ok(())
133}
134
135/// # Description
136/// Disables a previously added hook, restoring the original function's behavior.
137///
138/// # Errors
139/// This function returns a `retour::Error` if:
140/// - Disabling the hook fails.
141///
142/// # Example
143/// ```
144/// use commonlibsse_ng::skse::trampoline::{add_hook, remove_hook, disable_hook};
145/// use retour::Error;
146///
147/// fn add5(val: i32) -> i32 {
148///     val + 5
149/// }
150///
151/// let original = add5 as *const ();
152///
153/// // Initially, the behavior is to add 5
154/// assert_eq!(add5(5), 10);
155///
156/// assert!(disable_hook(original).is_ok());
157/// assert_eq!(add5(5), 10);
158/// ```
159pub fn disable_hook(original: *const ()) -> Result<(), retour::Error> {
160    if let Some(detour) = get_trampoline().get(&original.addr()) {
161        unsafe { detour.disable() }?;
162    }
163    Ok(())
164}
165
#[cfg(test)]
mod tests {
    use super::*;
    use std::hint::black_box;
    use std::mem;

    /// Calls `f` through an opaque function pointer.
    ///
    /// A direct call to a local fn may be inlined or constant-folded in optimized builds,
    /// which would never reach the patched machine code. Going through `black_box` forces a
    /// real call to `f`'s address, so the hook is observable.
    fn call(f: fn(i32) -> i32, val: i32) -> i32 {
        black_box(f)(val)
    }

    #[test]
    #[allow(clippy::fn_to_numeric_cast_any)] // casting fn items to raw pointers is the point here
    fn test_raw_detour_with_dashmap() -> Result<(), retour::Error> {
        // `inline(never)` keeps an out-of-line body at a stable address for the detour to patch.
        #[inline(never)]
        #[allow(clippy::missing_const_for_fn)]
        fn add5(val: i32) -> i32 {
            val + 5
        }

        #[inline(never)]
        #[allow(clippy::missing_const_for_fn)]
        fn add10(val: i32) -> i32 {
            val + 10
        }

        let original = add5 as *const ();
        let replacement = add10 as *const ();

        assert_eq!(call(add5, 5), 10);

        // SAFETY: `add5` and `add10` have identical signatures and calling conventions.
        unsafe { add_hook(original, replacement) }?;
        assert_eq!(call(add5, 5), 15);

        {
            let detour = get_trampoline()
                .get(&original.addr())
                .expect("hook was just registered by add_hook");

            // SAFETY: the trampoline runs `add5`'s original code, so it has `add5`'s signature.
            let original_fn: fn(i32) -> i32 = unsafe { mem::transmute(detour.trampoline()) };
            assert_eq!(call(original_fn, 5), 10);
            // `detour` (a map read guard) is dropped here, before `remove_hook` mutates the map.
        }

        remove_hook(original)?;
        assert_eq!(call(add5, 5), 10);
        assert!(!get_trampoline().contains_key(&original.addr()));

        Ok(())
    }
}