diff --git a/src/lifetimes/exercise.md b/src/lifetimes/exercise.md index c742abf9..4c48bc62 100644 --- a/src/lifetimes/exercise.md +++ b/src/lifetimes/exercise.md @@ -42,7 +42,7 @@ a message into a series of calls to those callbacks. What remains for you is to implement the `parse_field` function and the `ProtoMessage` trait for `Person` and `PhoneNumber`. - + ```rust,editable,compile_fail {{#include exercise.rs:preliminaries }} @@ -62,3 +62,13 @@ What remains for you is to implement the `parse_field` function and the {{#include exercise.rs:main }} ``` + +
+ +- In this exercise there are various cases where protobuf parsing might fail, + e.g. if you try to parse an `i32` when there are fewer than 4 bytes left in + the data buffer. In normal Rust code we'd handle this with the `Result` enum, + but for simplicity in this exercise we panic if any errors are encountered. On + day 4 we'll cover error handling in Rust in more detail. + +
diff --git a/src/lifetimes/exercise.rs b/src/lifetimes/exercise.rs index ab73d9ce..7ab8e891 100644 --- a/src/lifetimes/exercise.rs +++ b/src/lifetimes/exercise.rs @@ -14,25 +14,6 @@ // ANCHOR: solution // ANCHOR: preliminaries -use std::convert::TryFrom; -use thiserror::Error; - -#[derive(Debug, Error)] -enum Error { - #[error("Invalid varint")] - InvalidVarint, - #[error("Invalid wire-type")] - InvalidWireType, - #[error("Unexpected EOF")] - UnexpectedEOF, - #[error("Invalid length")] - InvalidSize(#[from] std::num::TryFromIntError), - #[error("Unexpected wire-type)")] - UnexpectedWireType, - #[error("Invalid string (not UTF-8)")] - InvalidString, -} - /// A wire type as seen on the wire. enum WireType { /// The Varint WireType indicates the value is a single VARINT. @@ -63,51 +44,57 @@ struct Field<'a> { } trait ProtoMessage<'a>: Default + 'a { - fn add_field(&mut self, field: Field<'a>) -> Result<(), Error>; + fn add_field(&mut self, field: Field<'a>); } -impl TryFrom for WireType { - type Error = Error; - - fn try_from(value: u64) -> Result { - Ok(match value { +impl From for WireType { + fn from(value: u64) -> Self { + match value { 0 => WireType::Varint, //1 => WireType::I64, -- not needed for this exercise 2 => WireType::Len, 5 => WireType::I32, - _ => return Err(Error::InvalidWireType), - }) + _ => panic!("Invalid wire type: {value}"), + } } } impl<'a> FieldValue<'a> { - fn as_string(&self) -> Result<&'a str, Error> { + fn as_string(&self) -> &'a str { let FieldValue::Len(data) = self else { - return Err(Error::UnexpectedWireType); + panic!("Expected string to be a `Len` field"); }; - std::str::from_utf8(data).map_err(|_| Error::InvalidString) + std::str::from_utf8(data).expect("Invalid string") } - fn as_bytes(&self) -> Result<&'a [u8], Error> { + fn as_bytes(&self) -> &'a [u8] { let FieldValue::Len(data) = self else { - return Err(Error::UnexpectedWireType); + panic!("Expected bytes to be a `Len` field"); }; - Ok(data) + data } - fn as_u64(&self) -> Result { + fn as_u64(&self) -> u64 { let FieldValue::Varint(value) = self else { - return Err(Error::UnexpectedWireType); + panic!("Expected `u64` to be a `Varint` field"); }; - Ok(*value) + *value + } + + #[allow(dead_code)] + fn as_i32(&self) -> i32 { + let FieldValue::I32(value) = self else { + panic!("Expected `i32` to be an `I32` field"); + }; + *value } } /// Parse a VARINT, returning the parsed value and the remaining bytes. -fn parse_varint(data: &[u8]) -> Result<(u64, &[u8]), Error> { +fn parse_varint(data: &[u8]) -> (u64, &[u8]) { for i in 0..7 { let Some(b) = data.get(i) else { - return Err(Error::InvalidVarint); + panic!("Not enough bytes for varint"); }; if b & 0x80 == 0 { // This is the last byte of the VARINT, so convert it to @@ -116,45 +103,45 @@ fn parse_varint(data: &[u8]) -> Result<(u64, &[u8]), Error> { for b in data[..=i].iter().rev() { value = (value << 7) | (b & 0x7f) as u64; } - return Ok((value, &data[i + 1..])); + return (value, &data[i + 1..]); } } // More than 7 bytes is invalid. - Err(Error::InvalidVarint) + panic!("Too many bytes for varint"); } /// Convert a tag into a field number and a WireType. -fn unpack_tag(tag: u64) -> Result<(u64, WireType), Error> { +fn unpack_tag(tag: u64) -> (u64, WireType) { let field_num = tag >> 3; - let wire_type = WireType::try_from(tag & 0x7)?; - Ok((field_num, wire_type)) + let wire_type = WireType::from(tag & 0x7); + (field_num, wire_type) } // ANCHOR_END: preliminaries // ANCHOR: parse_field /// Parse a field, returning the remaining bytes -fn parse_field(data: &[u8]) -> Result<(Field, &[u8]), Error> { - let (tag, remainder) = parse_varint(data)?; - let (field_num, wire_type) = unpack_tag(tag)?; +fn parse_field(data: &[u8]) -> (Field, &[u8]) { + let (tag, remainder) = parse_varint(data); + let (field_num, wire_type) = unpack_tag(tag); let (fieldvalue, remainder) = match wire_type { // ANCHOR_END: parse_field WireType::Varint => { - let (value, remainder) = parse_varint(remainder)?; + let (value, remainder) = parse_varint(remainder); (FieldValue::Varint(value), remainder) } WireType::Len => { - let (len, remainder) = parse_varint(remainder)?; - let len: usize = len.try_into()?; + let (len, remainder) = parse_varint(remainder); + let len: usize = len.try_into().expect("len not a valid `usize`"); if remainder.len() < len { - return Err(Error::UnexpectedEOF); + panic!("Unexpected EOF"); } let (value, remainder) = remainder.split_at(len); (FieldValue::Len(value), remainder) } WireType::I32 => { if remainder.len() < 4 { - return Err(Error::UnexpectedEOF); + panic!("Unexpected EOF"); } let (value, remainder) = remainder.split_at(4); // Unwrap error because `value` is definitely 4 bytes long. @@ -162,7 +149,7 @@ fn parse_field(data: &[u8]) -> Result<(Field, &[u8]), Error> { (FieldValue::I32(value), remainder) } }; - Ok((Field { field_num, value: fieldvalue }, remainder)) + (Field { field_num, value: fieldvalue }, remainder) } // ANCHOR: parse_message @@ -170,14 +157,14 @@ fn parse_field(data: &[u8]) -> Result<(Field, &[u8]), Error> { /// the message. /// /// The entire input is consumed. -fn parse_message<'a, T: ProtoMessage<'a>>(mut data: &'a [u8]) -> Result { +fn parse_message<'a, T: ProtoMessage<'a>>(mut data: &'a [u8]) -> T { let mut result = T::default(); while !data.is_empty() { - let parsed = parse_field(data)?; - result.add_field(parsed.0)?; + let parsed = parse_field(data); + result.add_field(parsed.0); data = parsed.1; } - Ok(result) + result } // ANCHOR_END: parse_message @@ -197,25 +184,23 @@ struct Person<'a> { // ANCHOR_END: message_types impl<'a> ProtoMessage<'a> for Person<'a> { - fn add_field(&mut self, field: Field<'a>) -> Result<(), Error> { + fn add_field(&mut self, field: Field<'a>) { match field.field_num { - 1 => self.name = field.value.as_string()?, - 2 => self.id = field.value.as_u64()?, - 3 => self.phone.push(parse_message(field.value.as_bytes()?)?), + 1 => self.name = field.value.as_string(), + 2 => self.id = field.value.as_u64(), + 3 => self.phone.push(parse_message(field.value.as_bytes())), _ => {} // skip everything else } - Ok(()) } } impl<'a> ProtoMessage<'a> for PhoneNumber<'a> { - fn add_field(&mut self, field: Field<'a>) -> Result<(), Error> { + fn add_field(&mut self, field: Field<'a>) { match field.field_num { - 1 => self.number = field.value.as_string()?, - 2 => self.type_ = field.value.as_string()?, + 1 => self.number = field.value.as_string(), + 2 => self.type_ = field.value.as_string(), _ => {} // skip everything else } - Ok(()) } } @@ -228,36 +213,7 @@ fn main() { 0x18, 0x0a, 0x0e, 0x2b, 0x31, 0x38, 0x30, 0x30, 0x2d, 0x38, 0x36, 0x37, 0x2d, 0x35, 0x33, 0x30, 0x38, 0x12, 0x06, 0x6d, 0x6f, 0x62, 0x69, 0x6c, 0x65, - ]) - .unwrap(); + ]); println!("{:#?}", person); } // ANCHOR_END: main - -// ANCHOR: tests -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn as_string() { - assert!(FieldValue::Varint(10).as_string().is_err()); - assert!(FieldValue::I32(10).as_string().is_err()); - assert_eq!(FieldValue::Len(b"hello").as_string().unwrap(), "hello"); - } - - #[test] - fn as_bytes() { - assert!(FieldValue::Varint(10).as_bytes().is_err()); - assert!(FieldValue::I32(10).as_bytes().is_err()); - assert_eq!(FieldValue::Len(b"hello").as_bytes().unwrap(), b"hello"); - } - - #[test] - fn as_u64() { - assert_eq!(FieldValue::Varint(10).as_u64().unwrap(), 10u64); - assert!(FieldValue::I32(10).as_u64().is_err()); - assert!(FieldValue::Len(b"hello").as_u64().is_err()); - } -} -// ANCHOR_END: tests diff --git a/src/lifetimes/solution.md b/src/lifetimes/solution.md index 2dc2bc58..b4a4c92c 100644 --- a/src/lifetimes/solution.md +++ b/src/lifetimes/solution.md @@ -1,7 +1,5 @@ # Solution - - -```rust,editable,compile_fail +```rust,editable {{#include exercise.rs:solution}} ```