1use crate::{os, util, Protection, QueryIter, Region, Result};
2
3#[inline]
54pub unsafe fn protect<T>(address: *const T, size: usize, protection: Protection) -> Result<()> {
55 let (address, size) = util::round_to_page_boundaries(address, size)?;
56 os::protect(address.cast(), size, protection)
57}
58
59#[allow(clippy::missing_inline_in_public_items)]
102pub unsafe fn protect_with_handle<T>(
103 address: *const T,
104 size: usize,
105 protection: Protection,
106) -> Result<ProtectGuard> {
107 let (address, size) = util::round_to_page_boundaries(address, size)?;
108
109 let mut regions = QueryIter::new(address, size)?.collect::<Result<Vec<_>>>()?;
111
112 protect(address, size, protection)?;
114
115 if let Some(region) = regions.first_mut() {
116 region.base = address.cast();
118 region.size -= address as usize - region.as_range().start;
119 }
120
121 if let Some(region) = regions.last_mut() {
122 let protect_end = address as usize + size;
124 region.size -= region.as_range().end - protect_end;
125 }
126
127 Ok(ProtectGuard::new(regions))
128}
129
130#[must_use]
135pub struct ProtectGuard {
136 regions: Vec<Region>,
137}
138
139impl ProtectGuard {
140 #[inline(always)]
141 fn new(regions: Vec<Region>) -> Self {
142 Self { regions }
143 }
144}
145
146impl Drop for ProtectGuard {
147 #[inline]
148 fn drop(&mut self) {
149 let result = self
150 .regions
151 .iter()
152 .try_for_each(|region| unsafe { protect(region.base, region.size, region.protection) });
153 debug_assert!(result.is_ok(), "restoring region protection: {:?}", result);
154 }
155}
156
157unsafe impl Send for ProtectGuard {}
158unsafe impl Sync for ProtectGuard {}
159
// Tests for `protect` and `protect_with_handle`. They rely on the crate's
// `alloc_pages` test helper, which allocates consecutive pages with the
// given per-page protections.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::util::alloc_pages;
    use crate::{page, query, query_range};

    // A null address with zero size must be rejected rather than passed to
    // the OS.
    #[test]
    fn protect_null_fails() {
        assert!(unsafe { protect(std::ptr::null::<()>(), 0, Protection::NONE) }.is_err());
    }

    // Skipped on platforms listed below, presumably because they forbid
    // writable+executable code pages.
    #[test]
    #[cfg(not(any(
        target_os = "openbsd",
        target_os = "netbsd",
        all(target_vendor = "apple", target_arch = "aarch64")
    )))]
    fn protect_can_alter_text_segments() {
        // NOTE(review): `&mut protect_can_alter_text_segments` borrows a
        // temporary of the function's zero-sized fn-item type, not its
        // machine code. `address` therefore most likely points at a stack
        // temporary rather than the text segment, and the 1-byte write below
        // goes past a zero-sized place. The code address would be
        // `protect_can_alter_text_segments as *mut u8`, but making text
        // writable may fail on some targets (e.g. segments whose maximum
        // protection excludes write). Confirm per platform before changing.
        #[allow(clippy::ptr_as_ptr)]
        let address = &mut protect_can_alter_text_segments as *mut _ as *mut u8;
        unsafe {
            protect(address, 1, Protection::READ_WRITE_EXECUTE).unwrap();
            *address = 0x90;
        }
    }

    #[test]
    fn protect_updates_both_pages_for_straddling_range() -> Result<()> {
        let pz = page::size();

        // Four pages with distinct protections so the pages under test are
        // separate regions before the call.
        let map = alloc_pages(&[
            Protection::READ,
            Protection::READ_EXECUTE,
            Protection::READ_WRITE,
            Protection::READ,
        ]);

        let exec_page = unsafe { map.as_ptr().add(pz) };
        let exec_page_end = unsafe { exec_page.add(pz - 1) };

        unsafe {
            // Two bytes: the last byte of page 1 and the first byte of
            // page 2. Both whole pages should become NONE.
            protect(exec_page_end, 2, Protection::NONE)?;
        }

        let result = query_range(exec_page, pz * 2)?.collect::<Result<Vec<_>>>()?;

        // The two now-identical pages may be reported as one region or two
        // (presumably depending on whether the OS coalesces them), but
        // together they must span exactly two pages.
        assert!(matches!(result.len(), 1 | 2));
        assert_eq!(result.iter().map(Region::len).sum::<usize>(), pz * 2);
        assert_eq!(result[0].protection(), Protection::NONE);
        Ok(())
    }

    #[test]
    fn protect_has_inclusive_lower_and_exclusive_upper_bound() -> Result<()> {
        let map = alloc_pages(&[
            Protection::READ_WRITE,
            Protection::READ,
            Protection::READ_WRITE,
            Protection::READ,
        ]);

        let second_page = unsafe { map.as_ptr().add(page::size()) };

        unsafe {
            // Only the last byte of page 1: the whole of page 1 changes, and
            // its neighbours are untouched.
            let second_page_end = second_page.offset(page::size() as isize - 1);
            protect(second_page_end, 1, Protection::NONE)?;
        }

        let regions = query_range(map.as_ptr(), page::size() * 3)?.collect::<Result<Vec<_>>>()?;
        assert_eq!(regions.len(), 3);
        assert_eq!(regions[0].protection(), Protection::READ_WRITE);
        assert_eq!(regions[1].protection(), Protection::NONE);
        assert_eq!(regions[2].protection(), Protection::READ_WRITE);

        unsafe {
            // One page plus one byte starting at page 1: pages 1 and 2
            // both change.
            protect(second_page, page::size() + 1, Protection::READ_EXECUTE)?;
        }

        let regions = query_range(map.as_ptr(), page::size() * 3)?.collect::<Result<Vec<_>>>()?;
        assert!(regions.len() >= 2);
        assert_eq!(regions[0].protection(), Protection::READ_WRITE);
        assert_eq!(regions[1].protection(), Protection::READ_EXECUTE);
        assert!(regions[1].len() >= page::size());

        Ok(())
    }

    #[test]
    fn protect_with_handle_resets_protection() -> Result<()> {
        let map = alloc_pages(&[Protection::READ]);

        unsafe {
            // The guard lives only inside this block; dropping it at the
            // closing brace must restore READ.
            let _handle = protect_with_handle(map.as_ptr(), page::size(), Protection::READ_WRITE)?;
            assert_eq!(query(map.as_ptr())?.protection(), Protection::READ_WRITE);
        };

        assert_eq!(query(map.as_ptr())?.protection(), Protection::READ);
        Ok(())
    }

    #[test]
    fn protect_with_handle_only_alters_protection_of_affected_pages() -> Result<()> {
        // Adjacent pages have different protections, so each page is
        // expected to be its own region.
        let pages = [
            Protection::READ_WRITE,
            Protection::READ,
            Protection::READ_WRITE,
            Protection::READ_EXECUTE,
            Protection::NONE,
        ];
        let map = alloc_pages(&pages);

        // Re-protect pages 1..=3. NOTE(review): the protected range starts
        // exactly at a region base, so this does not exercise clipping of a
        // region that begins below the protected address.
        let second_page = unsafe { map.as_ptr().add(page::size()) };
        let region_size = page::size() * 3;

        unsafe {
            let _handle = protect_with_handle(second_page, region_size, Protection::NONE)?;
            let region = query(second_page)?;

            assert_eq!(region.protection(), Protection::NONE);
            assert_eq!(region.as_ptr(), second_page);
        }

        // After the guard is dropped, every page must be back to its
        // original protection, including pages 0 and 4, which were never
        // touched.
        let regions =
            query_range(map.as_ptr(), page::size() * pages.len())?.collect::<Result<Vec<_>>>()?;

        assert_eq!(regions.len(), 5);
        assert_eq!(regions[0].as_ptr(), map.as_ptr());
        for i in 0..pages.len() {
            assert_eq!(regions[i].protection(), pages[i]);
        }

        Ok(())
    }
}