1
0
mirror of https://github.com/google/comprehensive-rust.git synced 2025-06-17 14:47:35 +02:00

Use safe-mmio crate in PL011 UART driver example and RTC exercise solution (#2752)

This commit is contained in:
Andrew Walbran
2025-06-02 15:28:06 +01:00
committed by GitHub
parent f3e369274a
commit 64ef712d7d
15 changed files with 122 additions and 113 deletions

View File

@ -11,5 +11,7 @@ with bitflags.
- The `bitflags!` macro creates a newtype something like `Flags(u16)`, along - The `bitflags!` macro creates a newtype something like `Flags(u16)`, along
with a bunch of method implementations to get and set flags. with a bunch of method implementations to get and set flags.
- We need to derive `FromBytes` and `IntoBytes` for use with `safe-mmio`, which
we'll see on the next page.
</details> </details>

View File

@ -8,7 +8,16 @@ Now let's use the new `Registers` struct in our driver.
<details> <details>
- Note the use of `&raw const` / `&raw mut` to get pointers to individual fields - `UniqueMmioPointer` is a wrapper around a raw pointer to an MMIO device or
without creating an intermediate reference, which would be unsound. register. The caller of `UniqueMmioPointer::new` promises that it is valid and
unique for the given lifetime, so it can provide safe methods to read and
write fields.
- These MMIO accesses are generally a wrapper around `read_volatile` and
`write_volatile`, though on aarch64 they are instead implemented in assembly
to work around a bug where the compiler can emit instructions that prevent
MMIO virtualisation.
- The `field!` and `field_shared!` macros internally use `&raw mut` and
`&raw const` to get pointers to individual fields without creating an
intermediate reference, which would be unsound.
</details> </details>

View File

@ -1,6 +1,8 @@
# Multiple registers # Multiple registers
We can use a struct to represent the memory layout of the UART's registers. We can use a struct to represent the memory layout of the UART's registers,
using types from the `safe-mmio` crate to wrap ones which can be read or written
safely.
<!-- mdbook-xgettext: skip --> <!-- mdbook-xgettext: skip -->
@ -15,5 +17,12 @@ We can use a struct to represent the memory layout of the UART's registers.
rules as C. This is necessary for our struct to have a predictable layout, as rules as C. This is necessary for our struct to have a predictable layout, as
default Rust representation allows the compiler to (among other things) default Rust representation allows the compiler to (among other things)
reorder fields however it sees fit. reorder fields however it sees fit.
- There are a number of different crates providing safe abstractions around MMIO
operations; we recommend the `safe-mmio` crate.
- The difference between `ReadPure` or `ReadOnly` (and likewise between
`ReadPureWrite` and `ReadWrite`) is whether reading a register can have
side-effects which change the state of the device. E.g. reading the data
register pops a byte from the receive FIFO. `ReadPure` means that reads have
no side-effects, they are purely reading data.
</details> </details>

View File

@ -31,8 +31,10 @@ dependencies = [
"arm-pl011-uart", "arm-pl011-uart",
"bitflags", "bitflags",
"log", "log",
"safe-mmio",
"smccc", "smccc",
"spin", "spin",
"zerocopy",
] ]
[[package]] [[package]]

View File

@ -12,8 +12,10 @@ aarch64-rt = "0.1.3"
arm-pl011-uart = "0.3.1" arm-pl011-uart = "0.3.1"
bitflags = "2.9.1" bitflags = "2.9.1"
log = "0.4.27" log = "0.4.27"
safe-mmio = "0.2.5"
smccc = "0.2.0" smccc = "0.2.0"
spin = "0.10.0" spin = "0.10.0"
zerocopy = "0.8.25"
[[bin]] [[bin]]
name = "improved" name = "improved"

View File

@ -21,7 +21,7 @@ use spin::mutex::SpinMutex;
static LOGGER: Logger = Logger { uart: SpinMutex::new(None) }; static LOGGER: Logger = Logger { uart: SpinMutex::new(None) };
struct Logger { struct Logger {
uart: SpinMutex<Option<Uart>>, uart: SpinMutex<Option<Uart<'static>>>,
} }
impl Log for Logger { impl Log for Logger {
@ -43,7 +43,10 @@ impl Log for Logger {
} }
/// Initialises UART logger. /// Initialises UART logger.
pub fn init(uart: Uart, max_level: LevelFilter) -> Result<(), SetLoggerError> { pub fn init(
uart: Uart<'static>,
max_level: LevelFilter,
) -> Result<(), SetLoggerError> {
LOGGER.uart.lock().replace(uart); LOGGER.uart.lock().replace(uart);
log::set_logger(&LOGGER)?; log::set_logger(&LOGGER)?;

View File

@ -22,19 +22,22 @@ mod pl011;
use crate::pl011::Uart; use crate::pl011::Uart;
use core::fmt::Write; use core::fmt::Write;
use core::panic::PanicInfo; use core::panic::PanicInfo;
use core::ptr::NonNull;
use log::error; use log::error;
use safe_mmio::UniqueMmioPointer;
use smccc::Hvc; use smccc::Hvc;
use smccc::psci::system_off; use smccc::psci::system_off;
/// Base address of the primary PL011 UART. /// Base address of the primary PL011 UART.
const PL011_BASE_ADDRESS: *mut u32 = 0x900_0000 as _; const PL011_BASE_ADDRESS: NonNull<pl011::Registers> =
NonNull::new(0x900_0000 as _).unwrap();
// SAFETY: There is no other global function of this name. // SAFETY: There is no other global function of this name.
#[unsafe(no_mangle)] #[unsafe(no_mangle)]
extern "C" fn main(x0: u64, x1: u64, x2: u64, x3: u64) { extern "C" fn main(x0: u64, x1: u64, x2: u64, x3: u64) {
// SAFETY: `PL011_BASE_ADDRESS` is the base address of a PL011 device, and // SAFETY: `PL011_BASE_ADDRESS` is the base address of a PL011 device, and
// nothing else accesses that address range. // nothing else accesses that address range.
let mut uart = unsafe { Uart::new(PL011_BASE_ADDRESS) }; let mut uart = unsafe { Uart::new(UniqueMmioPointer::new(PL011_BASE_ADDRESS)) };
writeln!(uart, "main({x0:#x}, {x1:#x}, {x2:#x}, {x3:#x})").unwrap(); writeln!(uart, "main({x0:#x}, {x1:#x}, {x2:#x}, {x3:#x})").unwrap();

View File

@ -22,19 +22,22 @@ mod pl011;
use crate::pl011::Uart; use crate::pl011::Uart;
use core::panic::PanicInfo; use core::panic::PanicInfo;
use core::ptr::NonNull;
use log::{LevelFilter, error, info}; use log::{LevelFilter, error, info};
use safe_mmio::UniqueMmioPointer;
use smccc::Hvc; use smccc::Hvc;
use smccc::psci::system_off; use smccc::psci::system_off;
/// Base address of the primary PL011 UART. /// Base address of the primary PL011 UART.
const PL011_BASE_ADDRESS: *mut u32 = 0x900_0000 as _; const PL011_BASE_ADDRESS: NonNull<pl011::Registers> =
NonNull::new(0x900_0000 as _).unwrap();
// SAFETY: There is no other global function of this name. // SAFETY: There is no other global function of this name.
#[unsafe(no_mangle)] #[unsafe(no_mangle)]
extern "C" fn main(x0: u64, x1: u64, x2: u64, x3: u64) { extern "C" fn main(x0: u64, x1: u64, x2: u64, x3: u64) {
// SAFETY: `PL011_BASE_ADDRESS` is the base address of a PL011 device, and // SAFETY: `PL011_BASE_ADDRESS` is the base address of a PL011 device, and
// nothing else accesses that address range. // nothing else accesses that address range.
let uart = unsafe { Uart::new(PL011_BASE_ADDRESS) }; let uart = unsafe { Uart::new(UniqueMmioPointer::new(PL011_BASE_ADDRESS)) };
logger::init(uart, LevelFilter::Trace).unwrap(); logger::init(uart, LevelFilter::Trace).unwrap();
info!("main({x0:#x}, {x1:#x}, {x2:#x}, {x3:#x})"); info!("main({x0:#x}, {x1:#x}, {x2:#x}, {x3:#x})");

View File

@ -17,12 +17,15 @@ use core::fmt::{self, Write};
// ANCHOR: Flags // ANCHOR: Flags
use bitflags::bitflags; use bitflags::bitflags;
use zerocopy::{FromBytes, IntoBytes};
bitflags! {
/// Flags from the UART flag register. /// Flags from the UART flag register.
#[repr(transparent)] #[repr(transparent)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)] #[derive(Copy, Clone, Debug, Eq, FromBytes, IntoBytes, PartialEq)]
struct Flags: u16 { struct Flags(u16);
bitflags! {
impl Flags: u16 {
/// Clear to send. /// Clear to send.
const CTS = 1 << 0; const CTS = 1 << 0;
/// Data set ready. /// Data set ready.
@ -45,11 +48,13 @@ bitflags! {
} }
// ANCHOR_END: Flags // ANCHOR_END: Flags
bitflags! {
/// Flags from the UART Receive Status Register / Error Clear Register. /// Flags from the UART Receive Status Register / Error Clear Register.
#[repr(transparent)] #[repr(transparent)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)] #[derive(Copy, Clone, Debug, Eq, FromBytes, IntoBytes, PartialEq)]
struct ReceiveStatus: u16 { struct ReceiveStatus(u16);
bitflags! {
impl ReceiveStatus: u16 {
/// Framing error. /// Framing error.
const FE = 1 << 0; const FE = 1 << 0;
/// Parity error. /// Parity error.
@ -62,70 +67,64 @@ bitflags! {
} }
// ANCHOR: Registers // ANCHOR: Registers
use safe_mmio::fields::{ReadPure, ReadPureWrite, ReadWrite, WriteOnly};
#[repr(C, align(4))] #[repr(C, align(4))]
struct Registers { pub struct Registers {
dr: u16, dr: ReadWrite<u16>,
_reserved0: [u8; 2], _reserved0: [u8; 2],
rsr: ReceiveStatus, rsr: ReadPure<ReceiveStatus>,
_reserved1: [u8; 19], _reserved1: [u8; 19],
fr: Flags, fr: ReadPure<Flags>,
_reserved2: [u8; 6], _reserved2: [u8; 6],
ilpr: u8, ilpr: ReadPureWrite<u8>,
_reserved3: [u8; 3], _reserved3: [u8; 3],
ibrd: u16, ibrd: ReadPureWrite<u16>,
_reserved4: [u8; 2], _reserved4: [u8; 2],
fbrd: u8, fbrd: ReadPureWrite<u8>,
_reserved5: [u8; 3], _reserved5: [u8; 3],
lcr_h: u8, lcr_h: ReadPureWrite<u8>,
_reserved6: [u8; 3], _reserved6: [u8; 3],
cr: u16, cr: ReadPureWrite<u16>,
_reserved7: [u8; 3], _reserved7: [u8; 3],
ifls: u8, ifls: ReadPureWrite<u8>,
_reserved8: [u8; 3], _reserved8: [u8; 3],
imsc: u16, imsc: ReadPureWrite<u16>,
_reserved9: [u8; 2], _reserved9: [u8; 2],
ris: u16, ris: ReadPure<u16>,
_reserved10: [u8; 2], _reserved10: [u8; 2],
mis: u16, mis: ReadPure<u16>,
_reserved11: [u8; 2], _reserved11: [u8; 2],
icr: u16, icr: WriteOnly<u16>,
_reserved12: [u8; 2], _reserved12: [u8; 2],
dmacr: u8, dmacr: ReadPureWrite<u8>,
_reserved13: [u8; 3], _reserved13: [u8; 3],
} }
// ANCHOR_END: Registers // ANCHOR_END: Registers
// ANCHOR: Uart // ANCHOR: Uart
use safe_mmio::{UniqueMmioPointer, field, field_shared};
/// Driver for a PL011 UART. /// Driver for a PL011 UART.
#[derive(Debug)] #[derive(Debug)]
pub struct Uart { pub struct Uart<'a> {
registers: *mut Registers, registers: UniqueMmioPointer<'a, Registers>,
} }
impl Uart { impl<'a> Uart<'a> {
/// Constructs a new instance of the UART driver for a PL011 device at the /// Constructs a new instance of the UART driver for a PL011 device with the
/// given base address. /// given set of registers.
/// pub fn new(registers: UniqueMmioPointer<'a, Registers>) -> Self {
/// # Safety Self { registers }
///
/// The given base address must point to the 8 MMIO control registers of a
/// PL011 device, which must be mapped into the address space of the process
/// as device memory and not have any other aliases.
pub unsafe fn new(base_address: *mut u32) -> Self {
Self { registers: base_address as *mut Registers }
} }
/// Writes a single byte to the UART. /// Writes a single byte to the UART.
pub fn write_byte(&self, byte: u8) { pub fn write_byte(&mut self, byte: u8) {
// Wait until there is room in the TX buffer. // Wait until there is room in the TX buffer.
while self.read_flag_register().contains(Flags::TXFF) {} while self.read_flag_register().contains(Flags::TXFF) {}
// SAFETY: We know that self.registers points to the control registers
// of a PL011 device which is appropriately mapped.
unsafe {
// Write to the TX buffer. // Write to the TX buffer.
(&raw mut (*self.registers).dr).write_volatile(byte.into()); field!(self.registers, dr).write(byte.into());
}
// Wait until the UART is no longer busy. // Wait until the UART is no longer busy.
while self.read_flag_register().contains(Flags::BUSY) {} while self.read_flag_register().contains(Flags::BUSY) {}
@ -133,27 +132,23 @@ impl Uart {
/// Reads and returns a pending byte, or `None` if nothing has been /// Reads and returns a pending byte, or `None` if nothing has been
/// received. /// received.
pub fn read_byte(&self) -> Option<u8> { pub fn read_byte(&mut self) -> Option<u8> {
if self.read_flag_register().contains(Flags::RXFE) { if self.read_flag_register().contains(Flags::RXFE) {
None None
} else { } else {
// SAFETY: We know that self.registers points to the control let data = field!(self.registers, dr).read();
// registers of a PL011 device which is appropriately mapped.
let data = unsafe { (&raw const (*self.registers).dr).read_volatile() };
// TODO: Check for error conditions in bits 8-11. // TODO: Check for error conditions in bits 8-11.
Some(data as u8) Some(data as u8)
} }
} }
fn read_flag_register(&self) -> Flags { fn read_flag_register(&self) -> Flags {
// SAFETY: We know that self.registers points to the control registers field_shared!(self.registers, fr).read()
// of a PL011 device which is appropriately mapped.
unsafe { (&raw const (*self.registers).fr).read_volatile() }
} }
} }
// ANCHOR_END: Uart // ANCHOR_END: Uart
impl Write for Uart { impl Write for Uart<'_> {
fn write_str(&mut self, s: &str) -> fmt::Result { fn write_str(&mut self, s: &str) -> fmt::Result {
for c in s.as_bytes() { for c in s.as_bytes() {
self.write_byte(*c); self.write_byte(*c);
@ -161,7 +156,3 @@ impl Write for Uart {
Ok(()) Ok(())
} }
} }
// Safe because it just contains a pointer to device memory, which can be
// accessed from any context.
unsafe impl Send for Uart {}

View File

@ -157,8 +157,10 @@ dependencies = [
"bitflags", "bitflags",
"chrono", "chrono",
"log", "log",
"safe-mmio",
"smccc", "smccc",
"spin", "spin",
"zerocopy",
] ]
[[package]] [[package]]

View File

@ -14,5 +14,7 @@ arm-pl011-uart = "0.3.1"
bitflags = "2.9.1" bitflags = "2.9.1"
chrono = { version = "0.4.41", default-features = false } chrono = { version = "0.4.41", default-features = false }
log = "0.4.27" log = "0.4.27"
safe-mmio = "0.2.5"
smccc = "0.2.1" smccc = "0.2.1"
spin = "0.10.0" spin = "0.10.0"
zerocopy = "0.8.25"

View File

@ -72,7 +72,8 @@ initial_pagetable!({
// ANCHOR_END: imports // ANCHOR_END: imports
/// Base address of the PL031 RTC. /// Base address of the PL031 RTC.
const PL031_BASE_ADDRESS: *mut u32 = 0x901_0000 as _; const PL031_BASE_ADDRESS: NonNull<pl031::Registers> =
NonNull::new(0x901_0000 as _).unwrap();
/// The IRQ used by the PL031 RTC. /// The IRQ used by the PL031 RTC.
const PL031_IRQ: IntId = IntId::spi(2); const PL031_IRQ: IntId = IntId::spi(2);
@ -96,7 +97,7 @@ fn main(x0: u64, x1: u64, x2: u64, x3: u64) -> ! {
// SAFETY: `PL031_BASE_ADDRESS` is the base address of a PL031 device, and // SAFETY: `PL031_BASE_ADDRESS` is the base address of a PL031 device, and
// nothing else accesses that address range. // nothing else accesses that address range.
let mut rtc = unsafe { Rtc::new(PL031_BASE_ADDRESS) }; let mut rtc = unsafe { Rtc::new(UniqueMmioPointer::new(PL031_BASE_ADDRESS)) };
let timestamp = rtc.read(); let timestamp = rtc.read();
let time = Utc.timestamp_opt(timestamp.into(), 0).unwrap(); let time = Utc.timestamp_opt(timestamp.into(), 0).unwrap();
info!("RTC: {time}"); info!("RTC: {time}");

View File

@ -12,72 +12,63 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use safe_mmio::fields::{ReadPure, ReadPureWrite, WriteOnly};
use safe_mmio::{UniqueMmioPointer, field, field_shared};
// ANCHOR: solution // ANCHOR: solution
#[repr(C, align(4))] #[repr(C, align(4))]
struct Registers { pub struct Registers {
/// Data register /// Data register
dr: u32, dr: ReadPure<u32>,
/// Match register /// Match register
mr: u32, mr: ReadPureWrite<u32>,
/// Load register /// Load register
lr: u32, lr: ReadPureWrite<u32>,
/// Control register /// Control register
cr: u8, cr: ReadPureWrite<u8>,
_reserved0: [u8; 3], _reserved0: [u8; 3],
/// Interrupt Mask Set or Clear register /// Interrupt Mask Set or Clear register
imsc: u8, imsc: ReadPureWrite<u8>,
_reserved1: [u8; 3], _reserved1: [u8; 3],
/// Raw Interrupt Status /// Raw Interrupt Status
ris: u8, ris: ReadPure<u8>,
_reserved2: [u8; 3], _reserved2: [u8; 3],
/// Masked Interrupt Status /// Masked Interrupt Status
mis: u8, mis: ReadPure<u8>,
_reserved3: [u8; 3], _reserved3: [u8; 3],
/// Interrupt Clear Register /// Interrupt Clear Register
icr: u8, icr: WriteOnly<u8>,
_reserved4: [u8; 3], _reserved4: [u8; 3],
} }
/// Driver for a PL031 real-time clock. /// Driver for a PL031 real-time clock.
#[derive(Debug)] #[derive(Debug)]
pub struct Rtc { pub struct Rtc<'a> {
registers: *mut Registers, registers: UniqueMmioPointer<'a, Registers>,
} }
impl Rtc { impl<'a> Rtc<'a> {
/// Constructs a new instance of the RTC driver for a PL031 device at the /// Constructs a new instance of the RTC driver for a PL031 device with the
/// given base address. /// given set of registers.
/// pub fn new(registers: UniqueMmioPointer<'a, Registers>) -> Self {
/// # Safety Self { registers }
///
/// The given base address must point to the MMIO control registers of a
/// PL031 device, which must be mapped into the address space of the process
/// as device memory and not have any other aliases.
pub unsafe fn new(base_address: *mut u32) -> Self {
Self { registers: base_address as *mut Registers }
} }
/// Reads the current RTC value. /// Reads the current RTC value.
pub fn read(&self) -> u32 { pub fn read(&self) -> u32 {
// SAFETY: We know that self.registers points to the control registers field_shared!(self.registers, dr).read()
// of a PL031 device which is appropriately mapped.
unsafe { (&raw const (*self.registers).dr).read_volatile() }
} }
/// Writes a match value. When the RTC value matches this then an interrupt /// Writes a match value. When the RTC value matches this then an interrupt
/// will be generated (if it is enabled). /// will be generated (if it is enabled).
pub fn set_match(&mut self, value: u32) { pub fn set_match(&mut self, value: u32) {
// SAFETY: We know that self.registers points to the control registers field!(self.registers, mr).write(value);
// of a PL031 device which is appropriately mapped.
unsafe { (&raw mut (*self.registers).mr).write_volatile(value) }
} }
/// Returns whether the match register matches the RTC value, whether or not /// Returns whether the match register matches the RTC value, whether or not
/// the interrupt is enabled. /// the interrupt is enabled.
pub fn matched(&self) -> bool { pub fn matched(&self) -> bool {
// SAFETY: We know that self.registers points to the control registers let ris = field_shared!(self.registers, ris).read();
// of a PL031 device which is appropriately mapped.
let ris = unsafe { (&raw const (*self.registers).ris).read_volatile() };
(ris & 0x01) != 0 (ris & 0x01) != 0
} }
@ -86,10 +77,8 @@ impl Rtc {
/// This should be true if and only if `matched` returns true and the /// This should be true if and only if `matched` returns true and the
/// interrupt is masked. /// interrupt is masked.
pub fn interrupt_pending(&self) -> bool { pub fn interrupt_pending(&self) -> bool {
// SAFETY: We know that self.registers points to the control registers let mis = field_shared!(self.registers, mis).read();
// of a PL031 device which is appropriately mapped. (mis & 0x01) != 0
let ris = unsafe { (&raw const (*self.registers).mis).read_volatile() };
(ris & 0x01) != 0
} }
/// Sets or clears the interrupt mask. /// Sets or clears the interrupt mask.
@ -98,19 +87,11 @@ impl Rtc {
/// interrupt is disabled. /// interrupt is disabled.
pub fn enable_interrupt(&mut self, mask: bool) { pub fn enable_interrupt(&mut self, mask: bool) {
let imsc = if mask { 0x01 } else { 0x00 }; let imsc = if mask { 0x01 } else { 0x00 };
// SAFETY: We know that self.registers points to the control registers field!(self.registers, imsc).write(imsc);
// of a PL031 device which is appropriately mapped.
unsafe { (&raw mut (*self.registers).imsc).write_volatile(imsc) }
} }
/// Clears a pending interrupt, if any. /// Clears a pending interrupt, if any.
pub fn clear_interrupt(&mut self) { pub fn clear_interrupt(&mut self) {
// SAFETY: We know that self.registers points to the control registers field!(self.registers, icr).write(0x01);
// of a PL031 device which is appropriately mapped.
unsafe { (&raw mut (*self.registers).icr).write_volatile(0x01) }
} }
} }
// SAFETY: `Rtc` just contains a pointer to device memory, which can be
// accessed from any context.
unsafe impl Send for Rtc {}

View File

@ -12,6 +12,6 @@ _main.rs_:
_pl031.rs_: _pl031.rs_:
```rust ```rust,compile_fail
{{#include rtc/src/pl031.rs:solution}} {{#include rtc/src/pl031.rs:solution}}
``` ```

View File

@ -18,7 +18,6 @@ export const size_exemptions = [
]; ];
export const playground_size_exemptions = [ export const playground_size_exemptions = [
"bare-metal/aps/better-uart/driver.html",
"bare-metal/microcontrollers/type-state.html", "bare-metal/microcontrollers/type-state.html",
"concurrency/async-pitfalls/cancellation.html", "concurrency/async-pitfalls/cancellation.html",
"iterators/intoiterator.html", "iterators/intoiterator.html",