fix(core/rust): fix UB due to unaligned access in protobuf codec

This commit is contained in:
matejcik
2025-10-18 14:29:26 +02:00
committed by matejcik
parent fe21ddc727
commit 5a0f3a6220
5 changed files with 180 additions and 116 deletions

View File

@@ -91,30 +91,36 @@ if not PROTOC:
PROTOC_PREFIX = Path(PROTOC).resolve().parent.parent
ENUM_ENTRY = c.PrefixedArray(c.Byte, c.Int16ul)
ENUM_ENTRY = c.PrefixedArray(c.Int16ul, c.Int16ul)
FIELD_STRUCT = c.Struct(
"tag" / c.Byte,
"flags_and_type"
/ c.BitStruct(
"is_required" / c.Flag,
"is_repeated" / c.Flag,
"is_experimental" / c.Flag,
c.Padding(1),
"type" / c.BitsInteger(4),
FIELD_STRUCT = c.Aligned(
2,
c.Struct(
"tag" / c.Byte,
"flags_and_type"
/ c.BitStruct(
"is_required" / c.Flag,
"is_repeated" / c.Flag,
"is_experimental" / c.Flag,
c.Padding(1),
"type" / c.BitsInteger(4),
),
"enum_or_msg_offset" / c.Int16ul,
"name" / c.Int16ul,
),
"enum_or_msg_offset" / c.Int16ul,
"name" / c.Int16ul,
)
MSG_ENTRY = c.Struct(
"fields_count" / c.Rebuild(c.Byte, c.len_(c.this.fields)),
"defaults_size" / c.Rebuild(c.Byte, c.len_(c.this.defaults)),
# highest bit = is_experimental
# the rest = wire_id, 0x7FFF iff unset
"flags_and_wire_type" / c.Int16ul,
"fields" / c.Array(c.this.fields_count, FIELD_STRUCT),
"defaults" / c.Bytes(c.this.defaults_size),
MSG_ENTRY = c.Aligned(
2,
c.Struct(
"fields_count" / c.Rebuild(c.Byte, c.len_(c.this.fields)),
"defaults_size" / c.Rebuild(c.Byte, c.len_(c.this.defaults)),
# highest bit = is_experimental
# the rest = wire_id, 0x7FFF iff unset
"flags_and_wire_type" / c.Int16ul,
"fields" / c.Array(c.this.fields_count, FIELD_STRUCT),
"defaults" / c.Bytes(c.this.defaults_size),
),
)
DEFAULT_VARINT_ENTRY = c.Sequence(c.Byte, c.VarInt)

View File

@@ -0,0 +1,21 @@
// taken from: https://users.rust-lang.org/t/can-i-conveniently-compile-bytes-into-a-rust-program-with-a-specific-alignment/24049/2
#[repr(C)]
pub struct AlignedTo<Align, Bytes: ?Sized> {
pub _align: [Align; 0],
pub data: Bytes,
}
macro_rules! include_aligned {
($align:ty, $filename:expr) => {{
use $crate::align::AlignedTo;
static ALIGNED: &AlignedTo<$align, [u8]> = &AlignedTo {
_align: [],
data: *include_bytes!($filename),
};
&ALIGNED.data
}};
}
pub(crate) use include_aligned;

View File

@@ -18,6 +18,7 @@ extern crate num_derive;
#[macro_use]
mod macros;
mod align;
#[cfg(feature = "debug")]
mod coverage;
#[cfg(feature = "crypto")]

View File

@@ -1,4 +1,17 @@
use core::{mem, slice};
use core::mem;
use crate::align::include_aligned;
macro_rules! proto_def_path {
($filename:expr) => {
concat!(env!("BUILD_DIR"), "/rust/", $filename)
};
}
static ENUM_DEFS: &[u8] = include_aligned!(u16, proto_def_path!("proto_enums.data"));
static MSG_DEFS: &[u8] = include_aligned!(u16, proto_def_path!("proto_msgs.data"));
static NAME_DEFS: &[u8] = include_aligned!(NameDef, proto_def_path!("proto_names.data"));
static WIRE_DEFS: &[u8] = include_aligned!(WireDef, proto_def_path!("proto_wire.data"));
pub struct MsgDef {
pub fields: &'static [FieldDef],
@@ -10,19 +23,11 @@ pub struct MsgDef {
impl MsgDef {
pub fn for_name(msg_name: u16) -> Option<Self> {
find_msg_offset_by_name(msg_name).map(|msg_offset| unsafe {
// SAFETY: We are taking the offset right out of the definitions so we can be
// sure it's to be trusted.
get_msg(msg_offset)
})
find_msg_offset_by_name(msg_name).map(get_msg)
}
pub fn for_wire_id(enum_name: u16, wire_id: u16) -> Option<Self> {
find_msg_offset_by_wire(enum_name, wire_id).map(|msg_offset| unsafe {
// SAFETY: We are taking the offset right out of the definitions so we can be
// sure it's to be trusted.
get_msg(msg_offset)
})
find_msg_offset_by_wire(enum_name, wire_id).map(get_msg)
}
pub fn field(&self, tag: u8) -> Option<&FieldDef> {
@@ -30,7 +35,7 @@ impl MsgDef {
}
}
#[repr(C, packed)]
#[repr(C)]
pub struct FieldDef {
pub tag: u8,
flags_and_type: u8,
@@ -38,6 +43,19 @@ pub struct FieldDef {
pub name: u16,
}
const STATIC_ASSERT_FIELD_DEF_ALIGNMENT: () = {
// alignment must be that of u16 aka the largest element
debug_assert!(mem::align_of::<FieldDef>() == mem::align_of::<u16>());
// the total size must be the same as the sum of the sizes of all the elements
debug_assert!(
mem::size_of::<FieldDef>()
== mem::size_of::<u8>()
+ mem::size_of::<u8>()
+ mem::size_of::<u16>()
+ mem::size_of::<u16>()
);
};
impl FieldDef {
pub fn get_type(&self) -> FieldType {
match self.ftype() {
@@ -46,8 +64,8 @@ impl FieldDef {
2 => FieldType::Bool,
3 => FieldType::Bytes,
4 => FieldType::String,
5 => FieldType::Enum(unsafe { get_enum(self.enum_or_msg_offset) }),
6 => FieldType::Msg(unsafe { get_msg(self.enum_or_msg_offset) }),
5 => FieldType::Enum(get_enum(self.enum_or_msg_offset)),
6 => FieldType::Msg(get_msg(self.enum_or_msg_offset)),
_ => unreachable!(),
}
}
@@ -103,43 +121,66 @@ pub struct EnumDef {
pub values: &'static [u16],
}
#[repr(C, packed)]
#[repr(C)]
struct NameDef {
msg_name: u16,
msg_offset: u16,
}
macro_rules! proto_def_path {
($filename:expr) => {
concat!(env!("BUILD_DIR"), "/rust/", $filename)
};
const STATIC_ASSERT_NAME_DEF_ALIGNMENT: () = {
// alignment must be that of u16 aka the largest element
debug_assert!(mem::align_of::<NameDef>() == mem::align_of::<u16>());
// the total size must be the same as two u16s
debug_assert!(mem::size_of::<NameDef>() == 2 * mem::size_of::<u16>());
};
impl NameDef {
pub fn defs() -> &'static [NameDef] {
// SAFETY: NameDef is a packed struct of ints, so all bit patterns are valid.
let (_pre, name_defs, _post) = unsafe { NAME_DEFS.align_to::<NameDef>() };
// per `include_aligned!` macro, NAME_DEFS is aligned to NameDef, so `name_defs`
// array should be cleanly aligned.
debug_assert!(_pre.is_empty());
debug_assert!(_post.is_empty());
name_defs
}
}
static ENUM_DEFS: &[u8] = include_bytes!(proto_def_path!("proto_enums.data"));
static MSG_DEFS: &[u8] = include_bytes!(proto_def_path!("proto_msgs.data"));
static NAME_DEFS: &[u8] = include_bytes!(proto_def_path!("proto_names.data"));
static WIRE_DEFS: &[u8] = include_bytes!(proto_def_path!("proto_wire.data"));
#[repr(C)]
struct WireDef {
enum_name: u16,
wire_id: u16,
msg_offset: u16,
}
const STATIC_ASSERT_WIRE_DEF_ALIGNMENT: () = {
// alignment must be that of u16 aka the largest element
debug_assert!(mem::align_of::<WireDef>() == mem::align_of::<u16>());
// the total size must be the same as three u16s
debug_assert!(mem::size_of::<WireDef>() == 3 * mem::size_of::<u16>());
};
impl WireDef {
pub fn defs() -> &'static [WireDef] {
// SAFETY: WireDef is a packed struct of ints, so all bit patterns are valid.
let (_pre, wire_defs, _post) = unsafe { WIRE_DEFS.align_to::<WireDef>() };
// per `include_aligned!` macro, WIRE_DEFS is aligned to WireDef, so `wire_defs`
// array should be cleanly aligned.
debug_assert!(_pre.is_empty());
debug_assert!(_post.is_empty());
wire_defs
}
}
pub fn find_name_by_msg_offset(msg_offset: u16) -> Option<u16> {
let name_defs: &[NameDef] = unsafe {
slice::from_raw_parts(
NAME_DEFS.as_ptr().cast(),
NAME_DEFS.len() / mem::size_of::<NameDef>(),
)
};
name_defs
NameDef::defs()
.iter()
.find(|def| def.msg_offset == msg_offset)
.map(|def| def.msg_name)
}
fn find_msg_offset_by_name(msg_name: u16) -> Option<u16> {
let name_defs: &[NameDef] = unsafe {
slice::from_raw_parts(
NAME_DEFS.as_ptr().cast(),
NAME_DEFS.len() / mem::size_of::<NameDef>(),
)
};
let name_defs = NameDef::defs();
name_defs
.binary_search_by_key(&msg_name, |def| def.msg_name)
.map(|i| name_defs[i].msg_offset)
@@ -147,34 +188,14 @@ fn find_msg_offset_by_name(msg_name: u16) -> Option<u16> {
}
fn find_msg_offset_by_wire(enum_name: u16, wire_id: u16) -> Option<u16> {
#[repr(C, packed)]
struct WireDef {
enum_name: u16,
wire_id: u16,
msg_offset: u16,
}
let wire_defs: &[WireDef] = unsafe {
slice::from_raw_parts(
WIRE_DEFS.as_ptr().cast(),
WIRE_DEFS.len() / mem::size_of::<WireDef>(),
)
};
let wire_defs = WireDef::defs();
wire_defs
.binary_search_by(|def| {
// need to make a copy to avoid taking a reference of unaligned value from
// packed struct, see rust error E0793
let def_enum_name = def.enum_name;
def_enum_name.cmp(&enum_name).then_with(|| {
let def_wire_id = def.wire_id;
def_wire_id.cmp(&wire_id)
})
})
.binary_search_by_key(&(enum_name, wire_id), |def| (def.enum_name, def.wire_id))
.map(|i| wire_defs[i].msg_offset)
.ok()
}
pub unsafe fn get_msg(msg_offset: u16) -> MsgDef {
pub fn get_msg(msg_offset: u16) -> MsgDef {
// #[repr(C, packed)]
// struct MsgDef {
// fields_count: u8,
@@ -183,55 +204,70 @@ pub unsafe fn get_msg(msg_offset: u16) -> MsgDef {
// fields: [Field],
// defaults: [u8],
// }
let msg_def_start = &MSG_DEFS[msg_offset as usize..];
let fields_count = msg_def_start[0] as usize;
let defaults_size = msg_def_start[1] as usize;
// SAFETY: `msg_offset` has to point to a beginning of a valid message
// definition inside `MSG_DEFS`.
unsafe {
let ptr = MSG_DEFS.as_ptr().add(msg_offset as usize);
let fields_count = ptr.offset(0).read() as usize;
let defaults_size = ptr.offset(1).read() as usize;
let flags_and_wire_id = u16::from_le_bytes([msg_def_start[2], msg_def_start[3]]);
let is_experimental = flags_and_wire_id & 0x8000 != 0;
let wire_id = match flags_and_wire_id & 0x7FFF {
0x7FFF => None,
some_wire_id => Some(some_wire_id),
};
let flags_and_wire_id_lo = ptr.offset(2).read();
let flags_and_wire_id_hi = ptr.offset(3).read();
let flags_and_wire_id = u16::from_le_bytes([flags_and_wire_id_lo, flags_and_wire_id_hi]);
let fields_size_in_bytes = fields_count * mem::size_of::<FieldDef>();
let fields_start = 4;
let fields_end = fields_start + fields_size_in_bytes;
let fields_byteslice = &msg_def_start[fields_start..fields_end];
let is_experimental = flags_and_wire_id & 0x8000 != 0;
let wire_id = match flags_and_wire_id & 0x7FFF {
0x7FFF => None,
some_wire_id => Some(some_wire_id),
};
// PREREQUISITES:
// * MSG_DEFS is aligned to u16 (per `include_aligned!` macro)
// * FieldDef has the same alignment
debug_assert!(mem::align_of::<FieldDef>() == mem::align_of::<u16>());
// * both msg_offset and fields_start added together keep the alignment:
debug_assert!(fields_byteslice.as_ptr().addr() % mem::align_of::<FieldDef>() == 0);
let fields_size = fields_count * mem::size_of::<FieldDef>();
let fields_ptr = ptr.offset(4);
let defaults_ptr = ptr.offset(4).add(fields_size);
// SAFETY: FieldDef is a packed struct of ints, so all bit patterns are valid.
let (_pre, fields, _post) = unsafe { fields_byteslice.align_to::<FieldDef>() };
// Given the prerequisites, `fields` array must be cleanly aligned.
debug_assert!(_pre.is_empty());
debug_assert!(_post.is_empty());
MsgDef {
fields: slice::from_raw_parts(fields_ptr.cast(), fields_count),
defaults: slice::from_raw_parts(defaults_ptr.cast(), defaults_size),
is_experimental,
wire_id,
offset: msg_offset,
}
let defaults_start = fields_end;
let defaults_end = defaults_start + defaults_size;
// `defaults` is a byteslice so its alignment is always ok
let defaults = &msg_def_start[defaults_start..defaults_end];
MsgDef {
fields,
defaults,
is_experimental,
wire_id,
offset: msg_offset,
}
}
unsafe fn get_enum(enum_offset: u16) -> EnumDef {
// #[repr(C, packed)]
fn get_enum(enum_offset: u16) -> EnumDef {
// #[repr(C)]
// struct EnumDef {
// count: u8,
// count: u16,
// vals: [u16],
// }
const SIZE: u16 = mem::size_of::<u16>() as u16;
// SAFETY: `enum_offset` has to point to a beginning of a valid enum
// definition inside `ENUM_DEFS`.
unsafe {
let ptr = ENUM_DEFS.as_ptr().add(enum_offset as usize);
let count = ptr.offset(0).read() as usize;
let vals = ptr.offset(1);
// SAFETY: enum_defs is an array of u16, so all bit patterns are valid.
let (_pre, enum_defs, _post) = unsafe { ENUM_DEFS.align_to::<u16>() };
// ENUM_DEFS is aligned to u16 per `include_aligned!` macro
debug_assert!(_pre.is_empty());
debug_assert!(_post.is_empty());
EnumDef {
values: slice::from_raw_parts(vals.cast(), count),
}
// enum_offset is a raw byte offset, we check that it is also a valid index of
// an u16
assert!(enum_offset % SIZE == 0);
let offset: usize = (enum_offset / SIZE).into();
let count: usize = enum_defs[offset].into();
EnumDef {
values: &enum_defs[offset + 1..offset + 1 + count],
}
}

View File

@@ -49,7 +49,7 @@ impl MsgObj {
}
pub fn def(&self) -> MsgDef {
unsafe { get_msg(self.msg_offset) }
get_msg(self.msg_offset)
}
fn obj_type() -> &'static Type {