mirror of
https://github.com/trezor/trezor-firmware.git
synced 2026-02-20 00:33:30 +01:00
fix(core/rust): fix UB due to unaligned access in protobuf codec
This commit is contained in:
@@ -91,30 +91,36 @@ if not PROTOC:
|
||||
PROTOC_PREFIX = Path(PROTOC).resolve().parent.parent
|
||||
|
||||
|
||||
ENUM_ENTRY = c.PrefixedArray(c.Byte, c.Int16ul)
|
||||
ENUM_ENTRY = c.PrefixedArray(c.Int16ul, c.Int16ul)
|
||||
|
||||
FIELD_STRUCT = c.Struct(
|
||||
"tag" / c.Byte,
|
||||
"flags_and_type"
|
||||
/ c.BitStruct(
|
||||
"is_required" / c.Flag,
|
||||
"is_repeated" / c.Flag,
|
||||
"is_experimental" / c.Flag,
|
||||
c.Padding(1),
|
||||
"type" / c.BitsInteger(4),
|
||||
FIELD_STRUCT = c.Aligned(
|
||||
2,
|
||||
c.Struct(
|
||||
"tag" / c.Byte,
|
||||
"flags_and_type"
|
||||
/ c.BitStruct(
|
||||
"is_required" / c.Flag,
|
||||
"is_repeated" / c.Flag,
|
||||
"is_experimental" / c.Flag,
|
||||
c.Padding(1),
|
||||
"type" / c.BitsInteger(4),
|
||||
),
|
||||
"enum_or_msg_offset" / c.Int16ul,
|
||||
"name" / c.Int16ul,
|
||||
),
|
||||
"enum_or_msg_offset" / c.Int16ul,
|
||||
"name" / c.Int16ul,
|
||||
)
|
||||
|
||||
MSG_ENTRY = c.Struct(
|
||||
"fields_count" / c.Rebuild(c.Byte, c.len_(c.this.fields)),
|
||||
"defaults_size" / c.Rebuild(c.Byte, c.len_(c.this.defaults)),
|
||||
# highest bit = is_experimental
|
||||
# the rest = wire_id, 0x7FFF iff unset
|
||||
"flags_and_wire_type" / c.Int16ul,
|
||||
"fields" / c.Array(c.this.fields_count, FIELD_STRUCT),
|
||||
"defaults" / c.Bytes(c.this.defaults_size),
|
||||
MSG_ENTRY = c.Aligned(
|
||||
2,
|
||||
c.Struct(
|
||||
"fields_count" / c.Rebuild(c.Byte, c.len_(c.this.fields)),
|
||||
"defaults_size" / c.Rebuild(c.Byte, c.len_(c.this.defaults)),
|
||||
# highest bit = is_experimental
|
||||
# the rest = wire_id, 0x7FFF iff unset
|
||||
"flags_and_wire_type" / c.Int16ul,
|
||||
"fields" / c.Array(c.this.fields_count, FIELD_STRUCT),
|
||||
"defaults" / c.Bytes(c.this.defaults_size),
|
||||
),
|
||||
)
|
||||
|
||||
DEFAULT_VARINT_ENTRY = c.Sequence(c.Byte, c.VarInt)
|
||||
|
||||
21
core/embed/rust/src/align.rs
Normal file
21
core/embed/rust/src/align.rs
Normal file
@@ -0,0 +1,21 @@
|
||||
// taken from: https://users.rust-lang.org/t/can-i-conveniently-compile-bytes-into-a-rust-program-with-a-specific-alignment/24049/2
|
||||
|
||||
#[repr(C)]
|
||||
pub struct AlignedTo<Align, Bytes: ?Sized> {
|
||||
pub _align: [Align; 0],
|
||||
pub data: Bytes,
|
||||
}
|
||||
|
||||
macro_rules! include_aligned {
|
||||
($align:ty, $filename:expr) => {{
|
||||
use $crate::align::AlignedTo;
|
||||
|
||||
static ALIGNED: &AlignedTo<$align, [u8]> = &AlignedTo {
|
||||
_align: [],
|
||||
data: *include_bytes!($filename),
|
||||
};
|
||||
|
||||
&ALIGNED.data
|
||||
}};
|
||||
}
|
||||
pub(crate) use include_aligned;
|
||||
@@ -18,6 +18,7 @@ extern crate num_derive;
|
||||
#[macro_use]
|
||||
mod macros;
|
||||
|
||||
mod align;
|
||||
#[cfg(feature = "debug")]
|
||||
mod coverage;
|
||||
#[cfg(feature = "crypto")]
|
||||
|
||||
@@ -1,4 +1,17 @@
|
||||
use core::{mem, slice};
|
||||
use core::mem;
|
||||
|
||||
use crate::align::include_aligned;
|
||||
|
||||
macro_rules! proto_def_path {
|
||||
($filename:expr) => {
|
||||
concat!(env!("BUILD_DIR"), "/rust/", $filename)
|
||||
};
|
||||
}
|
||||
|
||||
static ENUM_DEFS: &[u8] = include_aligned!(u16, proto_def_path!("proto_enums.data"));
|
||||
static MSG_DEFS: &[u8] = include_aligned!(u16, proto_def_path!("proto_msgs.data"));
|
||||
static NAME_DEFS: &[u8] = include_aligned!(NameDef, proto_def_path!("proto_names.data"));
|
||||
static WIRE_DEFS: &[u8] = include_aligned!(WireDef, proto_def_path!("proto_wire.data"));
|
||||
|
||||
pub struct MsgDef {
|
||||
pub fields: &'static [FieldDef],
|
||||
@@ -10,19 +23,11 @@ pub struct MsgDef {
|
||||
|
||||
impl MsgDef {
|
||||
pub fn for_name(msg_name: u16) -> Option<Self> {
|
||||
find_msg_offset_by_name(msg_name).map(|msg_offset| unsafe {
|
||||
// SAFETY: We are taking the offset right out of the definitions so we can be
|
||||
// sure it's to be trusted.
|
||||
get_msg(msg_offset)
|
||||
})
|
||||
find_msg_offset_by_name(msg_name).map(get_msg)
|
||||
}
|
||||
|
||||
pub fn for_wire_id(enum_name: u16, wire_id: u16) -> Option<Self> {
|
||||
find_msg_offset_by_wire(enum_name, wire_id).map(|msg_offset| unsafe {
|
||||
// SAFETY: We are taking the offset right out of the definitions so we can be
|
||||
// sure it's to be trusted.
|
||||
get_msg(msg_offset)
|
||||
})
|
||||
find_msg_offset_by_wire(enum_name, wire_id).map(get_msg)
|
||||
}
|
||||
|
||||
pub fn field(&self, tag: u8) -> Option<&FieldDef> {
|
||||
@@ -30,7 +35,7 @@ impl MsgDef {
|
||||
}
|
||||
}
|
||||
|
||||
#[repr(C, packed)]
|
||||
#[repr(C)]
|
||||
pub struct FieldDef {
|
||||
pub tag: u8,
|
||||
flags_and_type: u8,
|
||||
@@ -38,6 +43,19 @@ pub struct FieldDef {
|
||||
pub name: u16,
|
||||
}
|
||||
|
||||
const STATIC_ASSERT_FIELD_DEF_ALIGNMENT: () = {
|
||||
// alignment must be that of u16 aka the largest element
|
||||
debug_assert!(mem::align_of::<FieldDef>() == mem::align_of::<u16>());
|
||||
// the total size must be the same as the sum of the sizes of all the elements
|
||||
debug_assert!(
|
||||
mem::size_of::<FieldDef>()
|
||||
== mem::size_of::<u8>()
|
||||
+ mem::size_of::<u8>()
|
||||
+ mem::size_of::<u16>()
|
||||
+ mem::size_of::<u16>()
|
||||
);
|
||||
};
|
||||
|
||||
impl FieldDef {
|
||||
pub fn get_type(&self) -> FieldType {
|
||||
match self.ftype() {
|
||||
@@ -46,8 +64,8 @@ impl FieldDef {
|
||||
2 => FieldType::Bool,
|
||||
3 => FieldType::Bytes,
|
||||
4 => FieldType::String,
|
||||
5 => FieldType::Enum(unsafe { get_enum(self.enum_or_msg_offset) }),
|
||||
6 => FieldType::Msg(unsafe { get_msg(self.enum_or_msg_offset) }),
|
||||
5 => FieldType::Enum(get_enum(self.enum_or_msg_offset)),
|
||||
6 => FieldType::Msg(get_msg(self.enum_or_msg_offset)),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
@@ -103,43 +121,66 @@ pub struct EnumDef {
|
||||
pub values: &'static [u16],
|
||||
}
|
||||
|
||||
#[repr(C, packed)]
|
||||
#[repr(C)]
|
||||
struct NameDef {
|
||||
msg_name: u16,
|
||||
msg_offset: u16,
|
||||
}
|
||||
|
||||
macro_rules! proto_def_path {
|
||||
($filename:expr) => {
|
||||
concat!(env!("BUILD_DIR"), "/rust/", $filename)
|
||||
};
|
||||
const STATIC_ASSERT_NAME_DEF_ALIGNMENT: () = {
|
||||
// alignment must be that of u16 aka the largest element
|
||||
debug_assert!(mem::align_of::<NameDef>() == mem::align_of::<u16>());
|
||||
// the total size must be the same as two u16s
|
||||
debug_assert!(mem::size_of::<NameDef>() == 2 * mem::size_of::<u16>());
|
||||
};
|
||||
|
||||
impl NameDef {
|
||||
pub fn defs() -> &'static [NameDef] {
|
||||
// SAFETY: NameDef is a packed struct of ints, so all bit patterns are valid.
|
||||
let (_pre, name_defs, _post) = unsafe { NAME_DEFS.align_to::<NameDef>() };
|
||||
// per `include_aligned!` macro, NAME_DEFS is aligned to NameDef, so `name_defs`
|
||||
// array should be cleanly aligned.
|
||||
debug_assert!(_pre.is_empty());
|
||||
debug_assert!(_post.is_empty());
|
||||
name_defs
|
||||
}
|
||||
}
|
||||
|
||||
static ENUM_DEFS: &[u8] = include_bytes!(proto_def_path!("proto_enums.data"));
|
||||
static MSG_DEFS: &[u8] = include_bytes!(proto_def_path!("proto_msgs.data"));
|
||||
static NAME_DEFS: &[u8] = include_bytes!(proto_def_path!("proto_names.data"));
|
||||
static WIRE_DEFS: &[u8] = include_bytes!(proto_def_path!("proto_wire.data"));
|
||||
#[repr(C)]
|
||||
struct WireDef {
|
||||
enum_name: u16,
|
||||
wire_id: u16,
|
||||
msg_offset: u16,
|
||||
}
|
||||
|
||||
const STATIC_ASSERT_WIRE_DEF_ALIGNMENT: () = {
|
||||
// alignment must be that of u16 aka the largest element
|
||||
debug_assert!(mem::align_of::<WireDef>() == mem::align_of::<u16>());
|
||||
// the total size must be the same as three u16s
|
||||
debug_assert!(mem::size_of::<WireDef>() == 3 * mem::size_of::<u16>());
|
||||
};
|
||||
|
||||
impl WireDef {
|
||||
pub fn defs() -> &'static [WireDef] {
|
||||
// SAFETY: WireDef is a packed struct of ints, so all bit patterns are valid.
|
||||
let (_pre, wire_defs, _post) = unsafe { WIRE_DEFS.align_to::<WireDef>() };
|
||||
// per `include_aligned!` macro, WIRE_DEFS is aligned to WireDef, so `wire_defs`
|
||||
// array should be cleanly aligned.
|
||||
debug_assert!(_pre.is_empty());
|
||||
debug_assert!(_post.is_empty());
|
||||
wire_defs
|
||||
}
|
||||
}
|
||||
|
||||
pub fn find_name_by_msg_offset(msg_offset: u16) -> Option<u16> {
|
||||
let name_defs: &[NameDef] = unsafe {
|
||||
slice::from_raw_parts(
|
||||
NAME_DEFS.as_ptr().cast(),
|
||||
NAME_DEFS.len() / mem::size_of::<NameDef>(),
|
||||
)
|
||||
};
|
||||
name_defs
|
||||
NameDef::defs()
|
||||
.iter()
|
||||
.find(|def| def.msg_offset == msg_offset)
|
||||
.map(|def| def.msg_name)
|
||||
}
|
||||
|
||||
fn find_msg_offset_by_name(msg_name: u16) -> Option<u16> {
|
||||
let name_defs: &[NameDef] = unsafe {
|
||||
slice::from_raw_parts(
|
||||
NAME_DEFS.as_ptr().cast(),
|
||||
NAME_DEFS.len() / mem::size_of::<NameDef>(),
|
||||
)
|
||||
};
|
||||
let name_defs = NameDef::defs();
|
||||
name_defs
|
||||
.binary_search_by_key(&msg_name, |def| def.msg_name)
|
||||
.map(|i| name_defs[i].msg_offset)
|
||||
@@ -147,34 +188,14 @@ fn find_msg_offset_by_name(msg_name: u16) -> Option<u16> {
|
||||
}
|
||||
|
||||
fn find_msg_offset_by_wire(enum_name: u16, wire_id: u16) -> Option<u16> {
|
||||
#[repr(C, packed)]
|
||||
struct WireDef {
|
||||
enum_name: u16,
|
||||
wire_id: u16,
|
||||
msg_offset: u16,
|
||||
}
|
||||
|
||||
let wire_defs: &[WireDef] = unsafe {
|
||||
slice::from_raw_parts(
|
||||
WIRE_DEFS.as_ptr().cast(),
|
||||
WIRE_DEFS.len() / mem::size_of::<WireDef>(),
|
||||
)
|
||||
};
|
||||
let wire_defs = WireDef::defs();
|
||||
wire_defs
|
||||
.binary_search_by(|def| {
|
||||
// need to make a copy to avoid taking a reference of unaligned value from
|
||||
// packed struct, see rust error E0793
|
||||
let def_enum_name = def.enum_name;
|
||||
def_enum_name.cmp(&enum_name).then_with(|| {
|
||||
let def_wire_id = def.wire_id;
|
||||
def_wire_id.cmp(&wire_id)
|
||||
})
|
||||
})
|
||||
.binary_search_by_key(&(enum_name, wire_id), |def| (def.enum_name, def.wire_id))
|
||||
.map(|i| wire_defs[i].msg_offset)
|
||||
.ok()
|
||||
}
|
||||
|
||||
pub unsafe fn get_msg(msg_offset: u16) -> MsgDef {
|
||||
pub fn get_msg(msg_offset: u16) -> MsgDef {
|
||||
// #[repr(C, packed)]
|
||||
// struct MsgDef {
|
||||
// fields_count: u8,
|
||||
@@ -183,55 +204,70 @@ pub unsafe fn get_msg(msg_offset: u16) -> MsgDef {
|
||||
// fields: [Field],
|
||||
// defaults: [u8],
|
||||
// }
|
||||
let msg_def_start = &MSG_DEFS[msg_offset as usize..];
|
||||
let fields_count = msg_def_start[0] as usize;
|
||||
let defaults_size = msg_def_start[1] as usize;
|
||||
|
||||
// SAFETY: `msg_offset` has to point to a beginning of a valid message
|
||||
// definition inside `MSG_DEFS`.
|
||||
unsafe {
|
||||
let ptr = MSG_DEFS.as_ptr().add(msg_offset as usize);
|
||||
let fields_count = ptr.offset(0).read() as usize;
|
||||
let defaults_size = ptr.offset(1).read() as usize;
|
||||
let flags_and_wire_id = u16::from_le_bytes([msg_def_start[2], msg_def_start[3]]);
|
||||
let is_experimental = flags_and_wire_id & 0x8000 != 0;
|
||||
let wire_id = match flags_and_wire_id & 0x7FFF {
|
||||
0x7FFF => None,
|
||||
some_wire_id => Some(some_wire_id),
|
||||
};
|
||||
|
||||
let flags_and_wire_id_lo = ptr.offset(2).read();
|
||||
let flags_and_wire_id_hi = ptr.offset(3).read();
|
||||
let flags_and_wire_id = u16::from_le_bytes([flags_and_wire_id_lo, flags_and_wire_id_hi]);
|
||||
let fields_size_in_bytes = fields_count * mem::size_of::<FieldDef>();
|
||||
let fields_start = 4;
|
||||
let fields_end = fields_start + fields_size_in_bytes;
|
||||
let fields_byteslice = &msg_def_start[fields_start..fields_end];
|
||||
|
||||
let is_experimental = flags_and_wire_id & 0x8000 != 0;
|
||||
let wire_id = match flags_and_wire_id & 0x7FFF {
|
||||
0x7FFF => None,
|
||||
some_wire_id => Some(some_wire_id),
|
||||
};
|
||||
// PREREQUISITES:
|
||||
// * MSG_DEFS is aligned to u16 (per `include_aligned!` macro)
|
||||
// * FieldDef has the same alignment
|
||||
debug_assert!(mem::align_of::<FieldDef>() == mem::align_of::<u16>());
|
||||
// * both msg_offset and fields_start added together keep the alignment:
|
||||
debug_assert!(fields_byteslice.as_ptr().addr() % mem::align_of::<FieldDef>() == 0);
|
||||
|
||||
let fields_size = fields_count * mem::size_of::<FieldDef>();
|
||||
let fields_ptr = ptr.offset(4);
|
||||
let defaults_ptr = ptr.offset(4).add(fields_size);
|
||||
// SAFETY: FieldDef is a packed struct of ints, so all bit patterns are valid.
|
||||
let (_pre, fields, _post) = unsafe { fields_byteslice.align_to::<FieldDef>() };
|
||||
// Given the prerequisites, `fields` array must be cleanly aligned.
|
||||
debug_assert!(_pre.is_empty());
|
||||
debug_assert!(_post.is_empty());
|
||||
|
||||
MsgDef {
|
||||
fields: slice::from_raw_parts(fields_ptr.cast(), fields_count),
|
||||
defaults: slice::from_raw_parts(defaults_ptr.cast(), defaults_size),
|
||||
is_experimental,
|
||||
wire_id,
|
||||
offset: msg_offset,
|
||||
}
|
||||
let defaults_start = fields_end;
|
||||
let defaults_end = defaults_start + defaults_size;
|
||||
// `defaults` is a byteslice so its alignment is always ok
|
||||
let defaults = &msg_def_start[defaults_start..defaults_end];
|
||||
|
||||
MsgDef {
|
||||
fields,
|
||||
defaults,
|
||||
is_experimental,
|
||||
wire_id,
|
||||
offset: msg_offset,
|
||||
}
|
||||
}
|
||||
|
||||
unsafe fn get_enum(enum_offset: u16) -> EnumDef {
|
||||
// #[repr(C, packed)]
|
||||
fn get_enum(enum_offset: u16) -> EnumDef {
|
||||
// #[repr(C)]
|
||||
// struct EnumDef {
|
||||
// count: u8,
|
||||
// count: u16,
|
||||
// vals: [u16],
|
||||
// }
|
||||
const SIZE: u16 = mem::size_of::<u16>() as u16;
|
||||
|
||||
// SAFETY: `enum_offset` has to point to a beginning of a valid enum
|
||||
// definition inside `ENUM_DEFS`.
|
||||
unsafe {
|
||||
let ptr = ENUM_DEFS.as_ptr().add(enum_offset as usize);
|
||||
let count = ptr.offset(0).read() as usize;
|
||||
let vals = ptr.offset(1);
|
||||
// SAFETY: enum_defs is an array of u16, so all bit patterns are valid.
|
||||
let (_pre, enum_defs, _post) = unsafe { ENUM_DEFS.align_to::<u16>() };
|
||||
// ENUM_DEFS is aligned to u16 per `include_aligned!` macro
|
||||
debug_assert!(_pre.is_empty());
|
||||
debug_assert!(_post.is_empty());
|
||||
|
||||
EnumDef {
|
||||
values: slice::from_raw_parts(vals.cast(), count),
|
||||
}
|
||||
// enum_offset is a raw byte offset, we check that it is also a valid index of
|
||||
// an u16
|
||||
assert!(enum_offset % SIZE == 0);
|
||||
let offset: usize = (enum_offset / SIZE).into();
|
||||
let count: usize = enum_defs[offset].into();
|
||||
EnumDef {
|
||||
values: &enum_defs[offset + 1..offset + 1 + count],
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -49,7 +49,7 @@ impl MsgObj {
|
||||
}
|
||||
|
||||
pub fn def(&self) -> MsgDef {
|
||||
unsafe { get_msg(self.msg_offset) }
|
||||
get_msg(self.msg_offset)
|
||||
}
|
||||
|
||||
fn obj_type() -> &'static Type {
|
||||
|
||||
Reference in New Issue
Block a user