From ddc418d1829e17a403316173c5c5e417b27f3d35 Mon Sep 17 00:00:00 2001 From: Geethapranay1 Date: Fri, 17 Apr 2026 16:58:03 +0530 Subject: [PATCH 1/2] perf(rust): pre-reserve buffer capacity for struct primitive fields --- rust/fory-core/src/buffer.rs | 242 +++++++++++++++++++++++++++ rust/fory-derive/src/object/util.rs | 105 ++++++++++++ rust/fory-derive/src/object/write.rs | 97 ++++++++--- 3 files changed, 425 insertions(+), 19 deletions(-) diff --git a/rust/fory-core/src/buffer.rs b/rust/fory-core/src/buffer.rs index 4cab51acf1..b071fc6f1e 100644 --- a/rust/fory-core/src/buffer.rs +++ b/rust/fory-core/src/buffer.rs @@ -499,6 +499,248 @@ impl<'a> Writer<'a> { self.write_u64(combined); } } + + /// # Safety + #[inline(always)] + pub unsafe fn prepare_write(&mut self, max_bytes: usize) -> (*mut u8, usize) { + self.reserve(max_bytes); + let len = self.bf.len(); + (self.bf.as_mut_ptr().add(len), len) + } + + /// # Safety + #[inline(always)] + pub unsafe fn finish_write(&mut self, new_len: usize) { + debug_assert!(new_len <= self.bf.capacity()); + self.bf.set_len(new_len); + } + + /// # Safety + #[inline(always)] + pub unsafe fn put_bool_at(ptr: *mut u8, value: bool) -> usize { + *ptr = value as u8; + 1 + } + + /// # Safety + #[inline(always)] + pub unsafe fn put_i8_at(ptr: *mut u8, value: i8) -> usize { + *ptr = value as u8; + 1 + } + + /// # Safety + #[inline(always)] + pub unsafe fn put_u8_at(ptr: *mut u8, value: u8) -> usize { + *ptr = value; + 1 + } + + /// # Safety + #[inline(always)] + pub unsafe fn put_i16_at(ptr: *mut u8, value: i16) -> usize { + Self::put_u16_at(ptr, value as u16) + } + + /// # Safety + #[inline(always)] + pub unsafe fn put_u16_at(ptr: *mut u8, value: u16) -> usize { + std::ptr::copy_nonoverlapping(&value.to_le() as *const u16 as *const u8, ptr, 2); + 2 + } + + /// # Safety + #[inline(always)] + pub unsafe fn put_i32_at(ptr: *mut u8, value: i32) -> usize { + Self::put_u32_at(ptr, value as u32) + } + + /// # Safety + #[inline(always)] + pub 
unsafe fn put_u32_at(ptr: *mut u8, value: u32) -> usize { + std::ptr::copy_nonoverlapping(&value.to_le() as *const u32 as *const u8, ptr, 4); + 4 + } + + /// # Safety + #[inline(always)] + pub unsafe fn put_i64_at(ptr: *mut u8, value: i64) -> usize { + Self::put_u64_at(ptr, value as u64) + } + + /// # Safety + #[inline(always)] + pub unsafe fn put_u64_at(ptr: *mut u8, value: u64) -> usize { + std::ptr::copy_nonoverlapping(&value.to_le() as *const u64 as *const u8, ptr, 8); + 8 + } + + /// # Safety + #[inline(always)] + pub unsafe fn put_f16_at(ptr: *mut u8, value: float16) -> usize { + Self::put_u16_at(ptr, value.to_bits()) + } + + /// # Safety + #[inline(always)] + pub unsafe fn put_f32_at(ptr: *mut u8, value: f32) -> usize { + std::ptr::copy_nonoverlapping(&value as *const f32 as *const u8, ptr, 4); + 4 + } + + /// # Safety + #[inline(always)] + pub unsafe fn put_f64_at(ptr: *mut u8, value: f64) -> usize { + std::ptr::copy_nonoverlapping(&value as *const f64 as *const u8, ptr, 8); + 8 + } + + /// # Safety + #[inline(always)] + pub unsafe fn put_i128_at(ptr: *mut u8, value: i128) -> usize { + Self::put_u128_at(ptr, value as u128) + } + + /// # Safety + #[inline(always)] + pub unsafe fn put_u128_at(ptr: *mut u8, value: u128) -> usize { + std::ptr::copy_nonoverlapping(&value.to_le() as *const u128 as *const u8, ptr, 16); + 16 + } + + /// # Safety + #[inline(always)] + pub unsafe fn put_varint32_at(ptr: *mut u8, value: i32) -> usize { + let zigzag = ((value as i64) << 1) ^ ((value as i64) >> 31); + Self::put_var_uint32_at(ptr, zigzag as u32) + } + + /// # Safety + #[inline(always)] + pub unsafe fn put_var_uint32_at(ptr: *mut u8, value: u32) -> usize { + if value < 0x80 { + *ptr = value as u8; + 1 + } else if value < 0x4000 { + *ptr = ((value as u8) & 0x7F) | 0x80; + *ptr.add(1) = (value >> 7) as u8; + 2 + } else if value < 0x200000 { + *ptr = ((value as u8) & 0x7F) | 0x80; + *ptr.add(1) = (((value >> 7) as u8) & 0x7F) | 0x80; + *ptr.add(2) = (value >> 14) as u8; + 3 
+ } else if value < 0x10000000 { + *ptr = ((value as u8) & 0x7F) | 0x80; + *ptr.add(1) = (((value >> 7) as u8) & 0x7F) | 0x80; + *ptr.add(2) = (((value >> 14) as u8) & 0x7F) | 0x80; + *ptr.add(3) = (value >> 21) as u8; + 4 + } else { + *ptr = ((value as u8) & 0x7F) | 0x80; + *ptr.add(1) = (((value >> 7) as u8) & 0x7F) | 0x80; + *ptr.add(2) = (((value >> 14) as u8) & 0x7F) | 0x80; + *ptr.add(3) = (((value >> 21) as u8) & 0x7F) | 0x80; + *ptr.add(4) = (value >> 28) as u8; + 5 + } + } + + /// # Safety + #[inline(always)] + pub unsafe fn put_varint64_at(ptr: *mut u8, value: i64) -> usize { + let zigzag = ((value << 1) ^ (value >> 63)) as u64; + Self::put_var_uint64_at(ptr, zigzag) + } + + /// # Safety + #[inline(always)] + pub unsafe fn put_var_uint64_at(ptr: *mut u8, value: u64) -> usize { + if value < 0x80 { + *ptr = value as u8; + return 1; + } + *ptr = ((value as u8) & 0x7F) | 0x80; + if value < 0x4000 { + *ptr.add(1) = (value >> 7) as u8; + return 2; + } + *ptr.add(1) = (((value >> 7) as u8) & 0x7F) | 0x80; + if value < 0x200000 { + *ptr.add(2) = (value >> 14) as u8; + return 3; + } + *ptr.add(2) = (((value >> 14) as u8) & 0x7F) | 0x80; + if value < 0x10000000 { + *ptr.add(3) = (value >> 21) as u8; + return 4; + } + *ptr.add(3) = (((value >> 21) as u8) & 0x7F) | 0x80; + if value < 0x800000000 { + *ptr.add(4) = (value >> 28) as u8; + return 5; + } + *ptr.add(4) = (((value >> 28) as u8) & 0x7F) | 0x80; + if value < 0x40000000000 { + *ptr.add(5) = (value >> 35) as u8; + return 6; + } + *ptr.add(5) = (((value >> 35) as u8) & 0x7F) | 0x80; + if value < 0x2000000000000 { + *ptr.add(6) = (value >> 42) as u8; + return 7; + } + *ptr.add(6) = (((value >> 42) as u8) & 0x7F) | 0x80; + if value < 0x100000000000000 { + *ptr.add(7) = (value >> 49) as u8; + return 8; + } + *ptr.add(7) = (((value >> 49) as u8) & 0x7F) | 0x80; + *ptr.add(8) = (value >> 56) as u8; + 9 + } + + /// # Safety + #[inline(always)] + pub unsafe fn put_tagged_i64_at(ptr: *mut u8, value: i64) -> usize { + 
const HALF_MIN: i64 = i32::MIN as i64 / 2; + const HALF_MAX: i64 = i32::MAX as i64 / 2; + if (HALF_MIN..=HALF_MAX).contains(&value) { + let v = (value as i32) << 1; + std::ptr::copy_nonoverlapping(&v.to_le() as *const i32 as *const u8, ptr, 4); + 4 + } else { + *ptr = 0b1; + std::ptr::copy_nonoverlapping(&value.to_le() as *const i64 as *const u8, ptr.add(1), 8); + 9 + } + } + + /// # Safety + #[inline(always)] + pub unsafe fn put_tagged_u64_at(ptr: *mut u8, value: u64) -> usize { + if value <= i32::MAX as u64 { + let v = (value as u32) << 1; + std::ptr::copy_nonoverlapping(&v.to_le() as *const u32 as *const u8, ptr, 4); + 4 + } else { + *ptr = 0b1; + std::ptr::copy_nonoverlapping(&value.to_le() as *const u64 as *const u8, ptr.add(1), 8); + 9 + } + } + + /// # Safety + #[inline(always)] + pub unsafe fn put_usize_at(ptr: *mut u8, value: usize) -> usize { + const SIZE: usize = std::mem::size_of::<usize>(); + match SIZE { + 2 => Self::put_u16_at(ptr, value as u16), + 4 => Self::put_var_uint32_at(ptr, value as u32), + 8 => Self::put_var_uint64_at(ptr, value as u64), + _ => unreachable!(), + } + } } #[derive(Default)] diff --git a/rust/fory-derive/src/object/util.rs b/rust/fory-derive/src/object/util.rs index 07c40b4959..5c998ab362 100644 --- a/rust/fory-derive/src/object/util.rs +++ b/rust/fory-derive/src/object/util.rs @@ -835,6 +835,111 @@ pub(super) fn get_primitive_writer_method_with_encoding( get_primitive_writer_method(type_name) } +pub(super) fn get_max_primitive_bytes( + type_name: &str, + meta: &super::field_meta::ForyFieldMeta, +) -> usize { + use fory_core::types::TypeId; + + if type_name == "i32" { + if let Some(type_id) = meta.type_id { + if type_id == TypeId::INT32 as i16 { + return 4; + } + } + return 5; + } + + if type_name == "u32" { + if let Some(type_id) = meta.type_id { + if type_id == TypeId::INT32 as i16 || type_id == TypeId::UINT32 as i16 { + return 4; + } + } + return 5; + } + + if type_name == "u64" { + if let Some(type_id) = meta.type_id { + if
type_id == TypeId::INT32 as i16 || type_id == TypeId::UINT64 as i16 { + return 8; + } else if type_id == TypeId::TAGGED_UINT64 as i16 { + return 9; + } + } + return 9; + } + + if type_name == "i64" { + return 9; + } + + match type_name { + "bool" | "i8" | "u8" => 1, + "i16" | "u16" | "float16" => 2, + "f32" => 4, + "f64" => 8, + "i128" | "u128" => 16, + "isize" | "usize" => 9, + _ => 0, + } +} + +pub(super) fn get_put_at_method_with_encoding( + type_name: &str, + meta: &super::field_meta::ForyFieldMeta, +) -> &'static str { + use fory_core::types::TypeId; + + if type_name == "i32" { + if let Some(type_id) = meta.type_id { + if type_id == TypeId::INT32 as i16 { + return "put_i32_at"; + } + } + return "put_varint32_at"; + } + + if type_name == "u32" { + if let Some(type_id) = meta.type_id { + if type_id == TypeId::INT32 as i16 || type_id == TypeId::UINT32 as i16 { + return "put_u32_at"; + } + } + return "put_var_uint32_at"; + } + + if type_name == "u64" { + if let Some(type_id) = meta.type_id { + if type_id == TypeId::INT32 as i16 || type_id == TypeId::UINT64 as i16 { + return "put_u64_at"; + } else if type_id == TypeId::TAGGED_UINT64 as i16 { + return "put_tagged_u64_at"; + } + } + return "put_var_uint64_at"; + } + + if type_name == "i64" { + return "put_varint64_at"; + } + + match type_name { + "bool" => "put_bool_at", + "i8" => "put_i8_at", + "u8" => "put_u8_at", + "i16" => "put_i16_at", + "u16" => "put_u16_at", + "f32" => "put_f32_at", + "f64" => "put_f64_at", + "float16" => "put_f16_at", + "i128" => "put_i128_at", + "u128" => "put_u128_at", + "usize" => "put_usize_at", + _ => panic!("unsupported primitive type for put_at: {type_name}"), + } +} + /// Get the reader method name for a primitive numeric type /// Panics if type_name is not a primitive type pub(super) fn get_primitive_reader_method(type_name: &str) -> &'static str { diff --git a/rust/fory-derive/src/object/write.rs b/rust/fory-derive/src/object/write.rs index 8300e8784b..d1c29654a1 100644 --- 
a/rust/fory-derive/src/object/write.rs +++ b/rust/fory-derive/src/object/write.rs @@ -19,10 +19,10 @@ use super::field_meta::parse_field_meta; use super::util::{ classify_trait_object_field, create_wrapper_types_arc, create_wrapper_types_rc, determine_field_ref_mode, extract_type_name, gen_struct_version_hash_ts, get_field_accessor, - get_field_name, get_filtered_source_fields_iter, get_option_inner_primitive_name, - get_primitive_writer_method_with_encoding, get_struct_name, get_type_id_by_type_ast, - is_debug_enabled, is_direct_primitive_type, is_option_encoding_primitive, FieldRefMode, - StructField, + get_field_name, get_filtered_source_fields_iter, get_max_primitive_bytes, + get_option_inner_primitive_name, get_primitive_writer_method_with_encoding, + get_put_at_method_with_encoding, get_struct_name, get_type_id_by_type_ast, is_debug_enabled, + is_direct_primitive_type, is_option_encoding_primitive, FieldRefMode, StructField, }; use crate::util::SourceField; use fory_core::types::TypeId; @@ -280,15 +280,10 @@ fn gen_write_field_impl( <#ty as fory_core::Serializer>::fory_write_data(&#value_ts, context)?; } } else { - // Numeric primitives: use direct buffer methods - // For u32/u64, consider encoding attributes let writer_method = get_primitive_writer_method_with_encoding(&type_name, &meta); let writer_ident = syn::Ident::new(writer_method, proc_macro2::Span::call_site()); - // For primitives: - // - use_self=true: #value_ts is `self.field`, which is T (copy happens automatically) - // - use_self=false: #value_ts is `field` from pattern match on &self, which is &T let value_expr = if use_self { quote! 
{ #value_ts } } else { @@ -359,18 +354,82 @@ fn gen_write_field_impl( pub fn gen_write_data(source_fields: &[SourceField<'_>]) -> TokenStream { let fields: Vec<&Field> = source_fields.iter().map(|sf| sf.field).collect(); - let write_fields_ts: Vec<_> = get_filtered_source_fields_iter(source_fields) - .map(|sf| gen_write_field_with_index(sf.field, sf.original_index, true)) - .collect(); - + let filtered: Vec<_> = get_filtered_source_fields_iter(source_fields).collect(); let version_hash_ts = gen_struct_version_hash_ts(&fields); - quote! { - if context.is_check_struct_version() { - let version_hash: i32 = #version_hash_ts; - context.writer.write_i32(version_hash); + + let fast_count = if !is_debug_enabled() { + filtered + .iter() + .take_while(|sf| { + let ref_mode = determine_field_ref_mode(sf.field); + ref_mode == FieldRefMode::None + && is_direct_primitive_type(&sf.field.ty) + && extract_type_name(&sf.field.ty) != "String" + }) + .count() + } else { + 0 + }; + + if fast_count > 0 { + let (fast, rest) = filtered.split_at(fast_count); + + let max_bytes: usize = fast + .iter() + .map(|sf| { + let tn = extract_type_name(&sf.field.ty); + let meta = parse_field_meta(sf.field).unwrap_or_default(); + get_max_primitive_bytes(&tn, &meta) + }) + .sum(); + + let put_stmts: Vec<TokenStream> = fast + .iter() + .map(|sf| { + let value_ts = get_field_accessor(sf.field, sf.original_index, true); + let tn = extract_type_name(&sf.field.ty); + let meta = parse_field_meta(sf.field).unwrap_or_default(); + let method = get_put_at_method_with_encoding(&tn, &meta); + let method_ident = syn::Ident::new(method, proc_macro2::Span::call_site()); + quote! { + offset += fory_core::buffer::Writer::#method_ident(ptr.add(offset), #value_ts); + } + }) + .collect(); + + let remaining_ts: Vec<_> = rest + .iter() + .map(|sf| gen_write_field_with_index(sf.field, sf.original_index, true)) + .collect(); + + quote!
{ + if context.is_check_struct_version() { + let version_hash: i32 = #version_hash_ts; + context.writer.write_i32(version_hash); + } + unsafe { + let (ptr, base_len) = context.writer.prepare_write(#max_bytes); + let mut offset = 0usize; + #(#put_stmts)* + context.writer.finish_write(base_len + offset); + } + #(#remaining_ts)* + Ok(()) + } + } else { + let write_fields_ts: Vec<_> = filtered + .iter() + .map(|sf| gen_write_field_with_index(sf.field, sf.original_index, true)) + .collect(); + + quote! { + if context.is_check_struct_version() { + let version_hash: i32 = #version_hash_ts; + context.writer.write_i32(version_hash); + } + #(#write_fields_ts)* + Ok(()) } - #(#write_fields_ts)* - Ok(()) } } From b175a09635b57777f074622f35ee3933e54afe4f Mon Sep 17 00:00:00 2001 From: Geethapranay1 Date: Sat, 18 Apr 2026 16:13:18 +0530 Subject: [PATCH 2/2] refactor(rust): move put_at methods from buffer.rs to unsafe_util.rs --- rust/fory-core/src/buffer.rs | 227 --------------------- rust/fory-core/src/lib.rs | 1 + rust/fory-core/src/unsafe_util.rs | 293 +++++++++++++++++++++++++++ rust/fory-derive/src/object/write.rs | 2 +- 4 files changed, 295 insertions(+), 228 deletions(-) create mode 100644 rust/fory-core/src/unsafe_util.rs diff --git a/rust/fory-core/src/buffer.rs b/rust/fory-core/src/buffer.rs index b071fc6f1e..5e223c12e7 100644 --- a/rust/fory-core/src/buffer.rs +++ b/rust/fory-core/src/buffer.rs @@ -514,233 +514,6 @@ impl<'a> Writer<'a> { debug_assert!(new_len <= self.bf.capacity()); self.bf.set_len(new_len); } - - /// # Safety - #[inline(always)] - pub unsafe fn put_bool_at(ptr: *mut u8, value: bool) -> usize { - *ptr = value as u8; - 1 - } - - /// # Safety - #[inline(always)] - pub unsafe fn put_i8_at(ptr: *mut u8, value: i8) -> usize { - *ptr = value as u8; - 1 - } - - /// # Safety - #[inline(always)] - pub unsafe fn put_u8_at(ptr: *mut u8, value: u8) -> usize { - *ptr = value; - 1 - } - - /// # Safety - #[inline(always)] - pub unsafe fn put_i16_at(ptr: *mut u8, 
value: i16) -> usize { - Self::put_u16_at(ptr, value as u16) - } - - /// # Safety - #[inline(always)] - pub unsafe fn put_u16_at(ptr: *mut u8, value: u16) -> usize { - std::ptr::copy_nonoverlapping(&value.to_le() as *const u16 as *const u8, ptr, 2); - 2 - } - - /// # Safety - #[inline(always)] - pub unsafe fn put_i32_at(ptr: *mut u8, value: i32) -> usize { - Self::put_u32_at(ptr, value as u32) - } - - /// # Safety - #[inline(always)] - pub unsafe fn put_u32_at(ptr: *mut u8, value: u32) -> usize { - std::ptr::copy_nonoverlapping(&value.to_le() as *const u32 as *const u8, ptr, 4); - 4 - } - - /// # Safety - #[inline(always)] - pub unsafe fn put_i64_at(ptr: *mut u8, value: i64) -> usize { - Self::put_u64_at(ptr, value as u64) - } - - /// # Safety - #[inline(always)] - pub unsafe fn put_u64_at(ptr: *mut u8, value: u64) -> usize { - std::ptr::copy_nonoverlapping(&value.to_le() as *const u64 as *const u8, ptr, 8); - 8 - } - - /// # Safety - #[inline(always)] - pub unsafe fn put_f16_at(ptr: *mut u8, value: float16) -> usize { - Self::put_u16_at(ptr, value.to_bits()) - } - - /// # Safety - #[inline(always)] - pub unsafe fn put_f32_at(ptr: *mut u8, value: f32) -> usize { - std::ptr::copy_nonoverlapping(&value as *const f32 as *const u8, ptr, 4); - 4 - } - - /// # Safety - #[inline(always)] - pub unsafe fn put_f64_at(ptr: *mut u8, value: f64) -> usize { - std::ptr::copy_nonoverlapping(&value as *const f64 as *const u8, ptr, 8); - 8 - } - - /// # Safety - #[inline(always)] - pub unsafe fn put_i128_at(ptr: *mut u8, value: i128) -> usize { - Self::put_u128_at(ptr, value as u128) - } - - /// # Safety - #[inline(always)] - pub unsafe fn put_u128_at(ptr: *mut u8, value: u128) -> usize { - std::ptr::copy_nonoverlapping(&value.to_le() as *const u128 as *const u8, ptr, 16); - 16 - } - - /// # Safety - #[inline(always)] - pub unsafe fn put_varint32_at(ptr: *mut u8, value: i32) -> usize { - let zigzag = ((value as i64) << 1) ^ ((value as i64) >> 31); - Self::put_var_uint32_at(ptr, 
zigzag as u32) - } - - /// # Safety - #[inline(always)] - pub unsafe fn put_var_uint32_at(ptr: *mut u8, value: u32) -> usize { - if value < 0x80 { - *ptr = value as u8; - 1 - } else if value < 0x4000 { - *ptr = ((value as u8) & 0x7F) | 0x80; - *ptr.add(1) = (value >> 7) as u8; - 2 - } else if value < 0x200000 { - *ptr = ((value as u8) & 0x7F) | 0x80; - *ptr.add(1) = (((value >> 7) as u8) & 0x7F) | 0x80; - *ptr.add(2) = (value >> 14) as u8; - 3 - } else if value < 0x10000000 { - *ptr = ((value as u8) & 0x7F) | 0x80; - *ptr.add(1) = (((value >> 7) as u8) & 0x7F) | 0x80; - *ptr.add(2) = (((value >> 14) as u8) & 0x7F) | 0x80; - *ptr.add(3) = (value >> 21) as u8; - 4 - } else { - *ptr = ((value as u8) & 0x7F) | 0x80; - *ptr.add(1) = (((value >> 7) as u8) & 0x7F) | 0x80; - *ptr.add(2) = (((value >> 14) as u8) & 0x7F) | 0x80; - *ptr.add(3) = (((value >> 21) as u8) & 0x7F) | 0x80; - *ptr.add(4) = (value >> 28) as u8; - 5 - } - } - - /// # Safety - #[inline(always)] - pub unsafe fn put_varint64_at(ptr: *mut u8, value: i64) -> usize { - let zigzag = ((value << 1) ^ (value >> 63)) as u64; - Self::put_var_uint64_at(ptr, zigzag) - } - - /// # Safety - #[inline(always)] - pub unsafe fn put_var_uint64_at(ptr: *mut u8, value: u64) -> usize { - if value < 0x80 { - *ptr = value as u8; - return 1; - } - *ptr = ((value as u8) & 0x7F) | 0x80; - if value < 0x4000 { - *ptr.add(1) = (value >> 7) as u8; - return 2; - } - *ptr.add(1) = (((value >> 7) as u8) & 0x7F) | 0x80; - if value < 0x200000 { - *ptr.add(2) = (value >> 14) as u8; - return 3; - } - *ptr.add(2) = (((value >> 14) as u8) & 0x7F) | 0x80; - if value < 0x10000000 { - *ptr.add(3) = (value >> 21) as u8; - return 4; - } - *ptr.add(3) = (((value >> 21) as u8) & 0x7F) | 0x80; - if value < 0x800000000 { - *ptr.add(4) = (value >> 28) as u8; - return 5; - } - *ptr.add(4) = (((value >> 28) as u8) & 0x7F) | 0x80; - if value < 0x40000000000 { - *ptr.add(5) = (value >> 35) as u8; - return 6; - } - *ptr.add(5) = (((value >> 35) as u8) & 
0x7F) | 0x80; - if value < 0x2000000000000 { - *ptr.add(6) = (value >> 42) as u8; - return 7; - } - *ptr.add(6) = (((value >> 42) as u8) & 0x7F) | 0x80; - if value < 0x100000000000000 { - *ptr.add(7) = (value >> 49) as u8; - return 8; - } - *ptr.add(7) = (((value >> 49) as u8) & 0x7F) | 0x80; - *ptr.add(8) = (value >> 56) as u8; - 9 - } - - /// # Safety - #[inline(always)] - pub unsafe fn put_tagged_i64_at(ptr: *mut u8, value: i64) -> usize { - const HALF_MIN: i64 = i32::MIN as i64 / 2; - const HALF_MAX: i64 = i32::MAX as i64 / 2; - if (HALF_MIN..=HALF_MAX).contains(&value) { - let v = (value as i32) << 1; - std::ptr::copy_nonoverlapping(&v.to_le() as *const i32 as *const u8, ptr, 4); - 4 - } else { - *ptr = 0b1; - std::ptr::copy_nonoverlapping(&value.to_le() as *const i64 as *const u8, ptr.add(1), 8); - 9 - } - } - - /// # Safety - #[inline(always)] - pub unsafe fn put_tagged_u64_at(ptr: *mut u8, value: u64) -> usize { - if value <= i32::MAX as u64 { - let v = (value as u32) << 1; - std::ptr::copy_nonoverlapping(&v.to_le() as *const u32 as *const u8, ptr, 4); - 4 - } else { - *ptr = 0b1; - std::ptr::copy_nonoverlapping(&value.to_le() as *const u64 as *const u8, ptr.add(1), 8); - 9 - } - } - - /// # Safety - #[inline(always)] - pub unsafe fn put_usize_at(ptr: *mut u8, value: usize) -> usize { - const SIZE: usize = std::mem::size_of::<usize>(); - match SIZE { - 2 => Self::put_u16_at(ptr, value as u16), - 4 => Self::put_var_uint32_at(ptr, value as u32), - 8 => Self::put_var_uint64_at(ptr, value as u64), - _ => unreachable!(), - } - } } #[derive(Default)] diff --git a/rust/fory-core/src/lib.rs b/rust/fory-core/src/lib.rs index 976a760af6..14128c5a69 100644 --- a/rust/fory-core/src/lib.rs +++ b/rust/fory-core/src/lib.rs @@ -187,6 +187,7 @@ pub mod row; pub mod serializer; pub mod types; pub use float16::float16 as Float16; +pub mod unsafe_util; pub mod util; // Re-export paste for use in macros diff --git a/rust/fory-core/src/unsafe_util.rs b/rust/fory-core/src/unsafe_util.rs
new file mode 100644 index 0000000000..5ac089089a --- /dev/null +++ b/rust/fory-core/src/unsafe_util.rs @@ -0,0 +1,293 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::float16::float16; + +/// # Safety +#[inline(always)] +pub unsafe fn put_bool_at(ptr: *mut u8, value: bool) -> usize { + *ptr = value as u8; + 1 +} + +/// # Safety +#[inline(always)] +pub unsafe fn put_i8_at(ptr: *mut u8, value: i8) -> usize { + *ptr = value as u8; + 1 +} + +/// # Safety +#[inline(always)] +pub unsafe fn put_u8_at(ptr: *mut u8, value: u8) -> usize { + *ptr = value; + 1 +} + +/// # Safety +#[inline(always)] +pub unsafe fn put_i16_at(ptr: *mut u8, value: i16) -> usize { + put_u16_at(ptr, value as u16) +} + +/// # Safety +#[inline(always)] +pub unsafe fn put_u16_at(ptr: *mut u8, value: u16) -> usize { + (ptr as *mut u16).write_unaligned(value.to_le()); + 2 +} + +/// # Safety +#[inline(always)] +pub unsafe fn put_i32_at(ptr: *mut u8, value: i32) -> usize { + put_u32_at(ptr, value as u32) +} + +/// # Safety +#[inline(always)] +pub unsafe fn put_u32_at(ptr: *mut u8, value: u32) -> usize { + (ptr as *mut u32).write_unaligned(value.to_le()); + 4 +} + +/// # Safety +#[inline(always)] +pub unsafe fn put_i64_at(ptr: *mut u8, value: 
i64) -> usize { + put_u64_at(ptr, value as u64) +} + +/// # Safety +#[inline(always)] +pub unsafe fn put_u64_at(ptr: *mut u8, value: u64) -> usize { + (ptr as *mut u64).write_unaligned(value.to_le()); + 8 +} + +/// # Safety +#[inline(always)] +pub unsafe fn put_f16_at(ptr: *mut u8, value: float16) -> usize { + put_u16_at(ptr, value.to_bits()) +} + +/// # Safety +#[inline(always)] +pub unsafe fn put_f32_at(ptr: *mut u8, value: f32) -> usize { + (ptr as *mut f32).write_unaligned(value); + 4 +} + +/// # Safety +#[inline(always)] +pub unsafe fn put_f64_at(ptr: *mut u8, value: f64) -> usize { + (ptr as *mut f64).write_unaligned(value); + 8 +} + +/// # Safety +#[inline(always)] +pub unsafe fn put_i128_at(ptr: *mut u8, value: i128) -> usize { + put_u128_at(ptr, value as u128) +} + +/// # Safety +#[inline(always)] +pub unsafe fn put_u128_at(ptr: *mut u8, value: u128) -> usize { + (ptr as *mut u128).write_unaligned(value.to_le()); + 16 +} + +/// # Safety +#[inline(always)] +pub unsafe fn put_varint32_at(ptr: *mut u8, value: i32) -> usize { + let zigzag = ((value as i64) << 1) ^ ((value as i64) >> 31); + put_var_uint32_at(ptr, zigzag as u32) +} + +/// # Safety +#[inline(always)] +pub unsafe fn put_var_uint32_at(ptr: *mut u8, value: u32) -> usize { + if value < 0x80 { + *ptr = value as u8; + 1 + } else if value < 0x4000 { + let u1 = ((value as u8) & 0x7F) | 0x80; + let u2 = (value >> 7) as u8; + (ptr as *mut u16).write_unaligned(u16::from_ne_bytes([u1, u2])); + 2 + } else if value < 0x200000 { + let u1 = ((value as u8) & 0x7F) | 0x80; + let u2 = (((value >> 7) as u8) & 0x7F) | 0x80; + let u3 = (value >> 14) as u8; + (ptr as *mut u16).write_unaligned(u16::from_ne_bytes([u1, u2])); + *ptr.add(2) = u3; + 3 + } else if value < 0x10000000 { + let u1 = ((value as u8) & 0x7F) | 0x80; + let u2 = (((value >> 7) as u8) & 0x7F) | 0x80; + let u3 = (((value >> 14) as u8) & 0x7F) | 0x80; + let u4 = (value >> 21) as u8; + (ptr as *mut u32).write_unaligned(u32::from_ne_bytes([u1, u2, u3, 
u4])); + 4 + } else { + let u1 = ((value as u8) & 0x7F) | 0x80; + let u2 = (((value >> 7) as u8) & 0x7F) | 0x80; + let u3 = (((value >> 14) as u8) & 0x7F) | 0x80; + let u4 = (((value >> 21) as u8) & 0x7F) | 0x80; + let u5 = (value >> 28) as u8; + (ptr as *mut u32).write_unaligned(u32::from_ne_bytes([u1, u2, u3, u4])); + *ptr.add(4) = u5; + 5 + } +} + +/// # Safety +#[inline(always)] +pub unsafe fn put_varint64_at(ptr: *mut u8, value: i64) -> usize { + let zigzag = ((value << 1) ^ (value >> 63)) as u64; + put_var_uint64_at(ptr, zigzag) +} + +/// # Safety +#[inline(always)] +pub unsafe fn put_var_uint64_at(ptr: *mut u8, value: u64) -> usize { + if value < 0x80 { + *ptr = value as u8; + return 1; + } + if value < 0x4000 { + let u1 = ((value as u8) & 0x7F) | 0x80; + let u2 = (value >> 7) as u8; + (ptr as *mut u16).write_unaligned(u16::from_ne_bytes([u1, u2])); + return 2; + } + if value < 0x200000 { + let u1 = ((value as u8) & 0x7F) | 0x80; + let u2 = (((value >> 7) as u8) & 0x7F) | 0x80; + let u3 = (value >> 14) as u8; + (ptr as *mut u16).write_unaligned(u16::from_ne_bytes([u1, u2])); + *ptr.add(2) = u3; + return 3; + } + if value < 0x10000000 { + let u1 = ((value as u8) & 0x7F) | 0x80; + let u2 = (((value >> 7) as u8) & 0x7F) | 0x80; + let u3 = (((value >> 14) as u8) & 0x7F) | 0x80; + let u4 = (value >> 21) as u8; + (ptr as *mut u32).write_unaligned(u32::from_ne_bytes([u1, u2, u3, u4])); + return 4; + } + if value < 0x800000000 { + let u1 = ((value as u8) & 0x7F) | 0x80; + let u2 = (((value >> 7) as u8) & 0x7F) | 0x80; + let u3 = (((value >> 14) as u8) & 0x7F) | 0x80; + let u4 = (((value >> 21) as u8) & 0x7F) | 0x80; + let u5 = (value >> 28) as u8; + (ptr as *mut u32).write_unaligned(u32::from_ne_bytes([u1, u2, u3, u4])); + *ptr.add(4) = u5; + return 5; + } + if value < 0x40000000000 { + let u1 = ((value as u8) & 0x7F) | 0x80; + let u2 = (((value >> 7) as u8) & 0x7F) | 0x80; + let u3 = (((value >> 14) as u8) & 0x7F) | 0x80; + let u4 = (((value >> 21) as u8) & 0x7F) | 
0x80; + let u5 = (((value >> 28) as u8) & 0x7F) | 0x80; + let u6 = (value >> 35) as u8; + (ptr as *mut u32).write_unaligned(u32::from_ne_bytes([u1, u2, u3, u4])); + (ptr.add(4) as *mut u16).write_unaligned(u16::from_ne_bytes([u5, u6])); + return 6; + } + if value < 0x2000000000000 { + let u1 = ((value as u8) & 0x7F) | 0x80; + let u2 = (((value >> 7) as u8) & 0x7F) | 0x80; + let u3 = (((value >> 14) as u8) & 0x7F) | 0x80; + let u4 = (((value >> 21) as u8) & 0x7F) | 0x80; + let u5 = (((value >> 28) as u8) & 0x7F) | 0x80; + let u6 = (((value >> 35) as u8) & 0x7F) | 0x80; + let u7 = (value >> 42) as u8; + (ptr as *mut u32).write_unaligned(u32::from_ne_bytes([u1, u2, u3, u4])); + (ptr.add(4) as *mut u16).write_unaligned(u16::from_ne_bytes([u5, u6])); + *ptr.add(6) = u7; + return 7; + } + if value < 0x100000000000000 { + let u1 = ((value as u8) & 0x7F) | 0x80; + let u2 = (((value >> 7) as u8) & 0x7F) | 0x80; + let u3 = (((value >> 14) as u8) & 0x7F) | 0x80; + let u4 = (((value >> 21) as u8) & 0x7F) | 0x80; + let u5 = (((value >> 28) as u8) & 0x7F) | 0x80; + let u6 = (((value >> 35) as u8) & 0x7F) | 0x80; + let u7 = (((value >> 42) as u8) & 0x7F) | 0x80; + let u8v = (value >> 49) as u8; + (ptr as *mut u64).write_unaligned(u64::from_ne_bytes([u1, u2, u3, u4, u5, u6, u7, u8v])); + return 8; + } + let u1 = ((value as u8) & 0x7F) | 0x80; + let u2 = (((value >> 7) as u8) & 0x7F) | 0x80; + let u3 = (((value >> 14) as u8) & 0x7F) | 0x80; + let u4 = (((value >> 21) as u8) & 0x7F) | 0x80; + let u5 = (((value >> 28) as u8) & 0x7F) | 0x80; + let u6 = (((value >> 35) as u8) & 0x7F) | 0x80; + let u7 = (((value >> 42) as u8) & 0x7F) | 0x80; + let u8v = (((value >> 49) as u8) & 0x7F) | 0x80; + let u9 = (value >> 56) as u8; + (ptr as *mut u64).write_unaligned(u64::from_ne_bytes([u1, u2, u3, u4, u5, u6, u7, u8v])); + *ptr.add(8) = u9; + 9 +} + +/// # Safety +#[inline(always)] +pub unsafe fn put_tagged_i64_at(ptr: *mut u8, value: i64) -> usize { + const HALF_MIN: i64 = i32::MIN as i64 / 2; 
+ const HALF_MAX: i64 = i32::MAX as i64 / 2; + if (HALF_MIN..=HALF_MAX).contains(&value) { + let v = (value as i32) << 1; + (ptr as *mut i32).write_unaligned(v.to_le()); + 4 + } else { + *ptr = 0b1; + (ptr.add(1) as *mut i64).write_unaligned(value.to_le()); + 9 + } +} + +/// # Safety +#[inline(always)] +pub unsafe fn put_tagged_u64_at(ptr: *mut u8, value: u64) -> usize { + if value <= i32::MAX as u64 { + let v = (value as u32) << 1; + (ptr as *mut u32).write_unaligned(v.to_le()); + 4 + } else { + *ptr = 0b1; + (ptr.add(1) as *mut u64).write_unaligned(value.to_le()); + 9 + } +} + +/// # Safety +#[inline(always)] +pub unsafe fn put_usize_at(ptr: *mut u8, value: usize) -> usize { + const SIZE: usize = std::mem::size_of::<usize>(); + match SIZE { + 2 => put_u16_at(ptr, value as u16), + 4 => put_var_uint32_at(ptr, value as u32), + 8 => put_var_uint64_at(ptr, value as u64), + _ => unreachable!(), + } +} diff --git a/rust/fory-derive/src/object/write.rs b/rust/fory-derive/src/object/write.rs index d1c29654a1..1a4031c33e 100644 --- a/rust/fory-derive/src/object/write.rs +++ b/rust/fory-derive/src/object/write.rs @@ -392,7 +392,7 @@ pub fn gen_write_data(source_fields: &[SourceField<'_>]) -> TokenStream { let method = get_put_at_method_with_encoding(&tn, &meta); let method_ident = syn::Ident::new(method, proc_macro2::Span::call_site()); quote! { - offset += fory_core::buffer::Writer::#method_ident(ptr.add(offset), #value_ts); + offset += fory_core::unsafe_util::#method_ident(ptr.add(offset), #value_ts); } }) .collect();