commonlibsse_ng\skse/trampoline.rs
//! This module lets you hook an existing function: calls to it are redirected to a replacement function, which can still run the original code through the hook's trampoline (e.g. to do extra work before or after it).
use dashmap::mapref::entry::Entry;
use dashmap::DashMap;
use retour::RawDetour;
use std::sync::OnceLock;
5
6/// - Key: 'static pointer to original function
7/// - value: trampoline pair fn ptr.
8pub fn get_trampoline() -> &'static DashMap<usize, RawDetour> {
9 static HOOKS: OnceLock<DashMap<usize, RawDetour>> = OnceLock::new();
10 HOOKS.get_or_init(DashMap::new)
11}
12
13/// Sets up a trampoline to replace the original function with a new one.
14///
15/// # Safety
16/// This is an unsafe function as it manipulates raw pointers and enables function detouring.
17/// - Both the `original` and `replacement` pointers must point to valid functions with matching signatures.
18/// - Undefined behavior may occur if the functions have different calling conventions or incompatible arguments.
19///
20/// # Errors
21/// This function returns a `retour::Error` if:
22/// - The detour cannot be created due to invalid function pointers.
23/// - Enabling the hook fails.
24///
25/// # Example
26/// ```
27/// use retour::Error;
28/// use commonlibsse_ng::skse::trampoline::add_hook;
29///
30/// fn add5(val: i32) -> i32 {
31/// val + 5
32/// }
33///
34/// fn add10(val: i32) -> i32 {
35/// val + 10
36/// }
37///
38/// let original = add5 as *const ();
39/// let replacement = add10 as *const ();
40///
41/// // Verify the original behavior
42/// assert_eq!(add5(5), 10);
43///
44/// // Replace the original function with the new one
45/// unsafe { add_hook(original, replacement) }.unwrap();
46/// assert_eq!(add5(5), 15);
47/// ```
48pub unsafe fn add_hook(original: *const (), replacement: *const ()) -> Result<(), retour::Error> {
49 let detour = unsafe { RawDetour::new(original, replacement)? };
50 unsafe { detour.enable() }?;
51 get_trampoline().insert(original.addr(), detour);
52 Ok(())
53}
54
55/// Removes a previously added hook by disabling the trampoline and restoring the original function.
56///
57/// # Errors
58/// This function returns a `retour::Error` if:
59/// - Disabling the hook fails.
60///
61/// # Example
62/// ```
63/// use commonlibsse_ng::skse::trampoline::{add_hook, remove_hook};
64/// use retour::Error;
65///
66/// fn add5(val: i32) -> i32 {
67/// val + 5
68/// }
69///
70/// fn add10(val: i32) -> i32 {
71/// val + 10
72/// }
73///
74/// let original = add5 as *const ();
75/// let replacement = add10 as *const ();
76///
77/// // Verify the original behavior
78/// assert_eq!(add5(5), 10);
79///
80/// // Replace the original function with the new one
81/// unsafe { add_hook(original, replacement) }.unwrap();
82/// assert_eq!(add5(5), 15); // Initially, the behavior is to add5 to add10
83///
84/// // Remove the hook (if added previously)
85/// remove_hook(original).unwrap();
86/// assert_eq!(add5(5), 10);
87/// ```
88pub fn remove_hook(original: *const ()) -> Result<(), retour::Error> {
89 if let Some((_, detour)) = get_trampoline().remove(&original.addr()) {
90 unsafe { detour.disable()? };
91 }
92
93 Ok(())
94}
95
96/// Enables a previously added hook, replacing the original function with the new one.
97///
98/// # Errors
99/// This function returns a `retour::Error` if:
100/// - Enabling the hook fails.
101///
102/// # Example
103/// ```
104/// use commonlibsse_ng::skse::trampoline::{add_hook, remove_hook, enable_hook};
105/// use retour::Error;
106///
107/// fn add5(val: i32) -> i32 {
108/// val + 5
109/// }
110///
111/// fn add10(val: i32) -> i32 {
112/// val + 10
113/// }
114///
115/// let original = add5 as *const ();
116/// let replacement = add10 as *const ();
117///
118/// // Initially, the behavior is to add 5
119/// assert_eq!(add5(5), 10);
120///
121/// // Add a hook to replace `add5` with `add10`
122/// unsafe { add_hook(original, replacement) }.unwrap();
123/// assert_eq!(add5(5), 15);
124///
125/// assert!(enable_hook(original).is_ok());
126/// assert_eq!(add5(5), 15);
127/// ```
128pub fn enable_hook(original: *const ()) -> Result<(), retour::Error> {
129 if let Some(detour) = get_trampoline().get(&original.addr()) {
130 unsafe { detour.enable() }?;
131 }
132 Ok(())
133}
134
135/// # Description
136/// Disables a previously added hook, restoring the original function's behavior.
137///
138/// # Errors
139/// This function returns a `retour::Error` if:
140/// - Disabling the hook fails.
141///
142/// # Example
143/// ```
144/// use commonlibsse_ng::skse::trampoline::{add_hook, remove_hook, disable_hook};
145/// use retour::Error;
146///
147/// fn add5(val: i32) -> i32 {
148/// val + 5
149/// }
150///
151/// let original = add5 as *const ();
152///
153/// // Initially, the behavior is to add 5
154/// assert_eq!(add5(5), 10);
155///
156/// assert!(disable_hook(original).is_ok());
157/// assert_eq!(add5(5), 10);
158/// ```
159pub fn disable_hook(original: *const ()) -> Result<(), retour::Error> {
160 if let Some(detour) = get_trampoline().get(&original.addr()) {
161 unsafe { detour.disable() }?;
162 }
163 Ok(())
164}
165
#[cfg(test)]
mod tests {
    use super::*;
    use std::mem;

    #[test]
    #[allow(clippy::fn_to_numeric_cast_any)]
    #[allow(clippy::missing_const_for_fn)]
    fn test_raw_detour_with_dashmap() -> Result<(), retour::Error> {
        // Function that gets hooked.
        #[allow(clippy::missing_const_for_fn)]
        fn add5(val: i32) -> i32 {
            val + 5
        }

        // Replacement the hook redirects to.
        #[allow(clippy::missing_const_for_fn)]
        fn add10(val: i32) -> i32 {
            val + 10
        }

        let target = add5 as *const ();
        let detour_fn = add10 as *const ();

        // Unhooked: the real `add5` runs.
        assert_eq!(add5(5), 10);

        // Hooked: calls to `add5` land in `add10`.
        unsafe { add_hook(target, detour_fn) }?;
        assert_eq!(add5(5), 15);

        // The trampoline still reaches the original `add5` code. The map guard is kept in its
        // own scope so it is dropped before `remove_hook` modifies the map.
        {
            let entry = get_trampoline()
                .get(&target.addr())
                .expect("hook was registered by add_hook above");

            // SAFETY: the trampoline has the same signature as `add5`.
            let call_original: fn(i32) -> i32 = unsafe { mem::transmute(entry.trampoline()) };
            assert_eq!(call_original(5), 10);
        }

        remove_hook(target)?;
        assert_eq!(add5(5), 10);

        Ok(())
    }
}