diff --git a/python/README.md b/python/README.md index ee5215f0ec..970ef69a94 100644 --- a/python/README.md +++ b/python/README.md @@ -490,6 +490,34 @@ fory.register(Person.class, "example.Person"); Person person = (Person) fory.deserialize(binaryData); ``` +### BFloat16 Support + +`pyfory` supports `bfloat16` scalar values and `bfloat16` arrays in xlang mode: + +- Scalar type: `pyfory.BFloat16` (type id `18`) +- Array type: `pyfory.BFloat16Array` (type id `54`) + +```python +import pyfory +from pyfory import BFloat16, BFloat16Array + +fory = pyfory.Fory(xlang=True, ref=False, strict=True) + +# Scalar bfloat16 +v = BFloat16(3.1415926) +data = fory.serialize(v) +out = fory.deserialize(data) +print(float(out)) + +# bfloat16 array +arr = BFloat16Array([1.0, 2.5, -3.25]) +data = fory.serialize(arr) +out = fory.deserialize(data) +print(out) +``` + +`BFloat16Array` stores values in a packed `array('H')` representation and writes bytes in little-endian order for cross-language compatibility. + ## 📊 Row Format - Zero-Copy Processing Apache Fury™ provides a random-access row format that enables reading nested fields from binary data without full deserialization. This drastically reduces overhead when working with large objects where only partial data access is needed. The format also supports memory-mapped files for ultra-low memory footprint. diff --git a/python/pyfory/__init__.py b/python/pyfory/__init__.py index 638ee8bd02..637857f217 100644 --- a/python/pyfory/__init__.py +++ b/python/pyfory/__init__.py @@ -52,6 +52,7 @@ TaggedUint64Serializer, Float32Serializer, Float64Serializer, + BFloat16Serializer, StringSerializer, DateSerializer, TimestampSerializer, @@ -90,6 +91,8 @@ tagged_uint64, float32, float64, + bfloat16 as bfloat16_type, + bfloat16_array, int8_array, uint8_array, int16_array, @@ -119,6 +122,13 @@ ) from pyfory.policy import DeserializationPolicy # noqa: F401 # pylint: disable=unused-import +# BFloat16 support +from pyfory.bfloat16 import bfloat16 # noqa: F401 +from pyfory.bfloat16_array import BFloat16Array # noqa: F401 + +# Keep compatibility with existing API naming. +BFloat16 = bfloat16 + __version__ = "0.16.0.dev0" __all__ = [ @@ -152,6 +162,10 @@ "tagged_uint64", "float32", "float64", + "BFloat16", + "BFloat16Array", + "bfloat16", + "bfloat16_array", "int8_array", "uint8_array", "int16_array", @@ -193,6 +207,7 @@ "TaggedUint64Serializer", "Float32Serializer", "Float64Serializer", + "BFloat16Serializer", "StringSerializer", "DateSerializer", "TimestampSerializer", diff --git a/python/pyfory/bfloat16.pxd b/python/pyfory/bfloat16.pxd new file mode 100644 index 0000000000..54f7318970 --- /dev/null +++ b/python/pyfory/bfloat16.pxd @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from libc.stdint cimport uint16_t + +cdef class bfloat16: + cdef uint16_t _bits + + +cdef bfloat16 bfloat16_from_bits(uint16_t bits) \ No newline at end of file diff --git a/python/pyfory/bfloat16.pyx b/python/pyfory/bfloat16.pyx new file mode 100644 index 0000000000..2779611efd --- /dev/null +++ b/python/pyfory/bfloat16.pyx @@ -0,0 +1,130 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# distutils: language = c++ +# cython: embedsignature = True +# cython: language_level = 3 + +from libc.stdint cimport uint16_t, uint32_t +from libc.string cimport memcpy + +cdef inline uint16_t float32_to_bfloat16_bits(float value) nogil: + cdef uint32_t f32_bits + cdef uint16_t bf16_bits + cdef uint16_t truncated + memcpy(&f32_bits, &value, 4) + # Preserve NaN payloads and force the quiet-NaN bit so signaling NaNs do + # not collapse into infinity when the lower 16 payload bits are truncated. + if (f32_bits & 0x7FFFFFFF) > 0x7F800000: + return ((f32_bits >> 16)) | 0x0040 + bf16_bits = (f32_bits >> 16) + truncated = (f32_bits & 0xFFFF) + if truncated > 0x8000: + bf16_bits += 1 + if (bf16_bits & 0x7F80) == 0x7F80: + bf16_bits = (bf16_bits & 0x8000) | 0x7F80 + elif truncated == 0x8000 and (bf16_bits & 1): + bf16_bits += 1 + if (bf16_bits & 0x7F80) == 0x7F80: + bf16_bits = (bf16_bits & 0x8000) | 0x7F80 + return bf16_bits + +cdef inline float bfloat16_bits_to_float32(uint16_t bits) nogil: + cdef uint32_t f32_bits = bits << 16 + cdef float result + memcpy(&result, &f32_bits, 4) + return result + + +cdef bfloat16 bfloat16_from_bits(uint16_t bits): + cdef bfloat16 value = bfloat16.__new__(bfloat16) + value._bits = bits + return value + + +cdef class bfloat16: + def __init__(self, value): + if isinstance(value, bfloat16): + self._bits = (value)._bits + else: + self._bits = float32_to_bfloat16_bits(float(value)) + + @staticmethod + def from_bits(uint16_t bits): + return bfloat16_from_bits(bits) + + def to_bits(self): + return self._bits + + def to_float32(self): + return bfloat16_bits_to_float32(self._bits) + + def __float__(self): + return float(self.to_float32()) + + def __repr__(self): + return f"bfloat16({self.to_float32()})" + + def __str__(self): + return str(self.to_float32()) + + def __eq__(self, other): + if isinstance(other, bfloat16): + if self.is_nan() or (other).is_nan(): + return False + if self.is_zero() and (other).is_zero(): + return True + return self._bits == (other)._bits + return False + + def __hash__(self): + if self.is_zero(): + return hash(0) + return hash(self._bits) + + def is_nan(self): + cdef uint16_t exp = (self._bits >> 7) & 0xFF + cdef uint16_t mant = self._bits & 0x7F + return exp == 0xFF and mant != 0 + + def is_inf(self): + cdef uint16_t exp = (self._bits >> 7) & 0xFF + cdef uint16_t mant = self._bits & 0x7F + return exp == 0xFF and mant == 0 + + def is_zero(self): + return (self._bits & 0x7FFF) == 0 + + def is_finite(self): + cdef uint16_t exp = (self._bits >> 7) & 0xFF + return exp != 0xFF + + def is_normal(self): + cdef uint16_t exp = (self._bits >> 7) & 0xFF + return exp != 0 and exp != 0xFF + + def is_subnormal(self): + cdef uint16_t exp = (self._bits >> 7) & 0xFF + cdef uint16_t mant = self._bits & 0x7F + return exp == 0 and mant != 0 + + def signbit(self): + return (self._bits & 0x8000) != 0 + + +# Backward-compatible alias for existing user code. +BFloat16 = bfloat16 diff --git a/python/pyfory/bfloat16_array.py b/python/pyfory/bfloat16_array.py new file mode 100644 index 0000000000..5287e4be15 --- /dev/null +++ b/python/pyfory/bfloat16_array.py @@ -0,0 +1,104 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import array + +from pyfory.bfloat16 import bfloat16 +from pyfory.utils import is_little_endian + + +class BFloat16Array: + def __init__(self, values=None): + if values is None: + self._data = array.array("H") + elif isinstance(values, BFloat16Array): + self._data = array.array("H", values._data) + elif isinstance(values, array.array) and values.typecode == "H": + self._data = array.array("H", values) + else: + self._data = array.array( + "H", + (v.to_bits() if isinstance(v, bfloat16) else bfloat16(v).to_bits() for v in values), + ) + + def __len__(self): + return len(self._data) + + def __getitem__(self, index): + return bfloat16.from_bits(self._data[index]) + + def __setitem__(self, index, value): + if isinstance(value, bfloat16): + self._data[index] = value.to_bits() + else: + self._data[index] = bfloat16(value).to_bits() + + def __iter__(self): + for bits in self._data: + yield bfloat16.from_bits(bits) + + def __repr__(self): + return f"BFloat16Array([{', '.join(str(bf16) for bf16 in self)}])" + + def __eq__(self, other): + if not isinstance(other, BFloat16Array): + return False + return self._data == other._data + + def append(self, value): + if isinstance(value, bfloat16): + self._data.append(value.to_bits()) + else: + self._data.append(bfloat16(value).to_bits()) + + def extend(self, values): + if isinstance(values, BFloat16Array): + self._data.extend(values._data) + return + for value in values: + self.append(value) + + @property + def itemsize(self): + return 2 + + def tobytes(self): + if is_little_endian: + return self._data.tobytes() + data = array.array("H", self._data) + data.byteswap() + return data.tobytes() + + def to_bits_array(self): + return array.array("H", self._data) + + @classmethod + def from_bits_array(cls, values): + arr = cls() + arr._data = array.array("H", values) + return arr + + @classmethod + def frombytes(cls, data): + if len(data) % 2 != 0: + raise ValueError("bfloat16 byte payload length must be a multiple of 2") + arr = cls() + arr._data = array.array("H") + arr._data.frombytes(data) + if not is_little_endian: + arr._data.byteswap() + return arr diff --git a/python/pyfory/buffer.pxi b/python/pyfory/buffer.pxi index 9424189bf7..9e8379c984 100644 --- a/python/pyfory/buffer.pxi +++ b/python/pyfory/buffer.pxi @@ -20,6 +20,7 @@ from cython.operator cimport dereference as deref from libcpp.string cimport string as c_string from libc.stdint cimport * from libcpp cimport bool as c_bool +from pyfory.bfloat16 cimport bfloat16, bfloat16_from_bits from pyfory.includes.libutil cimport( CBuffer, COutputStream, allocate_buffer, get_bit as c_get_bit, set_bit as c_set_bit, clear_bit as c_clear_bit, set_bit_to as c_set_bit_to, CError, CErrorCode, CResultVoidError, utf16_has_surrogate_pairs @@ -380,6 +381,14 @@ cdef class Buffer: cpdef inline write_float64(self, double value): self.c_buffer.write_double(value) + cpdef inline write_bfloat16(self, uint16_t value): + self.c_buffer.write_uint16(value) + + cpdef inline bfloat16 read_bfloat16(self): + cdef uint16_t value = self.c_buffer.read_uint16(self._error) + self._raise_if_error() + return bfloat16_from_bits(value) + cpdef put_buffer(self, uint32_t offset, v, int32_t src_index, int32_t length): if length == 0: # access an emtpy buffer may raise out-of-bound exception. return diff --git a/python/pyfory/codegen.py b/python/pyfory/codegen.py index 820484de7c..6b9d72d684 100644 --- a/python/pyfory/codegen.py +++ b/python/pyfory/codegen.py @@ -58,6 +58,7 @@ "write_nullable_pyfloat64", "read_nullable_pyfloat64", ), + "bfloat16": ("write_bfloat16", "read_bfloat16", "write_nullable_bfloat16", "read_nullable_bfloat16"), } @@ -144,6 +145,8 @@ def compile_function( context["read_nullable_pyfloat64"] = serialization.read_nullable_pyfloat64 context["write_nullable_pystr"] = serialization.write_nullable_pystr context["read_nullable_pystr"] = serialization.read_nullable_pystr + context["write_nullable_bfloat16"] = serialization.write_nullable_bfloat16 + context["read_nullable_bfloat16"] = serialization.read_nullable_bfloat16 stmts = [f"{ident(statement)}" for statement in stmts] # Sanitize the function name to ensure it is valid Python syntax sanitized_function_name = _sanitize_function_name(function_name) diff --git a/python/pyfory/collection.py b/python/pyfory/collection.py index 35c0a7946a..52a455b915 100644 --- a/python/pyfory/collection.py +++ b/python/pyfory/collection.py @@ -205,7 +205,7 @@ def _read_same_type_no_ref(self, buffer, len_, collection_, serializer): for _ in range(len_): self._add_element( collection_, - self.fory.read_no_ref(buffer, serializer=serializer), + self.fory.read_no_ref(buffer, serializer), ) self.fory.dec_depth() @@ -217,7 +217,7 @@ def _read_same_type_has_null(self, buffer, len_, collection_, serializer): else: self._add_element( collection_, - self.fory.read_no_ref(buffer, serializer=serializer), + self.fory.read_no_ref(buffer, serializer), ) self.fory.dec_depth() @@ -249,7 +249,7 @@ def _read_different_types(self, buffer, len_, collection_, collect_flag): if typeinfo is None: elem = None else: - elem = self.fory.read_no_ref(buffer, serializer=typeinfo.serializer) + elem = self.fory.read_no_ref(buffer, typeinfo.serializer) self._add_element(collection_, elem) else: # When ref tracking is disabled but has nulls, read null flag first @@ -262,7 +262,7 @@ def _read_different_types(self, buffer, len_, collection_, collect_flag): if typeinfo is None: elem = None else: - elem = self.fory.read_no_ref(buffer, serializer=typeinfo.serializer) + elem = self.fory.read_no_ref(buffer, typeinfo.serializer) self._add_element(collection_, elem) self.fory.dec_depth() @@ -587,7 +587,7 @@ def _read_obj(self, serializer, buffer): return serializer.read(buffer) def _read_obj_no_ref(self, serializer, buffer): - return self.fory.read_no_ref(buffer, serializer=serializer) + return self.fory.read_no_ref(buffer, serializer) SubMapSerializer = MapSerializer diff --git a/python/pyfory/format/__init__.py b/python/pyfory/format/__init__.py index 6c9fb205d8..d2732299af 100644 --- a/python/pyfory/format/__init__.py +++ b/python/pyfory/format/__init__.py @@ -36,6 +36,7 @@ int32, int64, float16, + bfloat16, float32, float64, utf8, diff --git a/python/pyfory/format/schema.pxi b/python/pyfory/format/schema.pxi index 84c859cf12..b155b19d17 100644 --- a/python/pyfory/format/schema.pxi +++ b/python/pyfory/format/schema.pxi @@ -42,6 +42,7 @@ from pyfory.includes.libformat cimport ( ) + # Create Python-accessible TypeId enum # The CTypeId enum from libformat.pxd is only accessible from Cython class TypeId: @@ -417,6 +418,12 @@ def float16(): """Create a 16-bit floating point type.""" return DataType.wrap(c_float16()) +def bfloat16(): + """Create a 16-bit brain floating point type.""" + # TODO: Use c_bfloat16() when C++ row format supports bfloat16 + # For now, use float16 as a temporary workaround since C++ doesn't have bfloat16() yet + return DataType.wrap(c_float16()) + def float32(): """Create a 32-bit floating point type.""" return DataType.wrap(c_float32()) diff --git a/python/pyfory/format/schema.py b/python/pyfory/format/schema.py index 18baa20378..68309b15c0 100644 --- a/python/pyfory/format/schema.py +++ b/python/pyfory/format/schema.py @@ -56,6 +56,8 @@ def arrow_type_to_fory_type_id(arrow_type): # Floating point types if pa_types.is_float16(arrow_type): return 17 # FLOAT16 + if hasattr(pa_types, "is_bfloat16") and pa_types.is_bfloat16(arrow_type): + return 18 # BFLOAT16 if pa_types.is_float32(arrow_type): return 19 # FLOAT32 if pa_types.is_float64(arrow_type): @@ -116,6 +118,7 @@ def fory_type_id_to_arrow_type(type_id, precision=None, scale=None, list_type=No 4: pa.int32(), # INT32 6: pa.int64(), # INT64 17: pa.float16(), # FLOAT16 + 18: pa.float16(), # BFLOAT16 (Arrow doesn't have native bfloat16, map to float16) 19: pa.float32(), # FLOAT32 20: pa.float64(), # FLOAT64 21: pa.utf8(), # STRING @@ -204,17 +207,17 @@ def fory_field_list_to_arrow_schema(field_list): nullable = field_spec.get("nullable", True) # Handle nested types - if type_id == 21: # LIST + if type_id == 22: # LIST value_type = field_spec.get("value_type") arrow_type = pa.list_(value_type) - elif type_id == 23: # MAP + elif type_id == 24: # MAP key_type = field_spec.get("key_type") item_type = field_spec.get("item_type") arrow_type = pa.map_(key_type, item_type) - elif type_id == 15: # STRUCT + elif type_id == 27: # STRUCT struct_fields = field_spec.get("struct_fields", []) arrow_type = pa.struct(struct_fields) - elif type_id == 27: # DECIMAL + elif type_id == 40: # DECIMAL precision = field_spec.get("precision", 38) scale = field_spec.get("scale", 18) arrow_type = pa.decimal128(precision, scale) @@ -274,20 +277,20 @@ def reconstruct_arrow_type(spec): """ type_id = spec["type_id"] - if type_id == 21: # LIST + if type_id == 22: # LIST value_type = reconstruct_arrow_type(spec["value_type"]) return pa.list_(value_type) - elif type_id == 23: # MAP + elif type_id == 24: # MAP key_type = reconstruct_arrow_type(spec["key_type"]) item_type = reconstruct_arrow_type(spec["item_type"]) return pa.map_(key_type, item_type) - elif type_id == 15: # STRUCT + elif type_id == 27: # STRUCT fields = [] for field_spec in spec["fields"]: field_type = reconstruct_arrow_type(field_spec["type"]) fields.append(pa.field(field_spec["name"], field_type, nullable=field_spec.get("nullable", True))) return pa.struct(fields) - elif type_id == 27: # DECIMAL + elif type_id == 40: # DECIMAL return pa.decimal128(spec.get("precision", 38), spec.get("scale", 18)) else: return fory_type_id_to_arrow_type(type_id) diff --git a/python/pyfory/includes/libformat.pxd b/python/pyfory/includes/libformat.pxd index 240cb0f44a..d579ac67c7 100755 --- a/python/pyfory/includes/libformat.pxd +++ b/python/pyfory/includes/libformat.pxd @@ -136,6 +136,9 @@ cdef extern from "fory/row/schema.h" namespace "fory::row" nogil: cdef cppclass CFloat16Type" fory::row::Float16Type"(CFixedWidthType): pass + cdef cppclass CBFloat16Type" fory::row::BFloat16Type"(CFixedWidthType): + pass + cdef cppclass CFloat32Type" fory::row::Float32Type"(CFixedWidthType): pass @@ -223,6 +226,8 @@ cdef extern from "fory/row/schema.h" namespace "fory::row" nogil: shared_ptr[CDataType] int32" fory::row::int32"() shared_ptr[CDataType] int64" fory::row::int64"() shared_ptr[CDataType] float16" fory::row::float16"() + # TODO: Uncomment when C++ row format supports bfloat16 + # shared_ptr[CDataType] bfloat16" fory::row::bfloat16"() shared_ptr[CDataType] float32" fory::row::float32"() shared_ptr[CDataType] float64" fory::row::float64"() shared_ptr[CDataType] utf8" fory::row::utf8"() diff --git a/python/pyfory/registry.py b/python/pyfory/registry.py index c3d584c7c3..d0ff6b0136 100644 --- a/python/pyfory/registry.py +++ b/python/pyfory/registry.py @@ -35,7 +35,6 @@ Serializer, Numpy1DArraySerializer, NDArraySerializer, - PythonNDArraySerializer, PyArraySerializer, DynamicPyArraySerializer, NoneSerializer, @@ -56,6 +55,7 @@ TaggedUint64Serializer, Float32Serializer, Float64Serializer, + BFloat16Serializer, StringSerializer, DateSerializer, TimestampSerializer, @@ -270,7 +270,7 @@ def _initialize_py(self): register(tuple, serializer=TupleSerializer) register(slice, serializer=SliceSerializer) if np is not None: - register(np.ndarray, serializer=PythonNDArraySerializer) + register(np.ndarray, serializer=NDArraySerializer) register(array.array, serializer=DynamicPyArraySerializer) register(types.MappingProxyType, serializer=MappingProxySerializer) register(pickle.PickleBuffer, serializer=PickleBufferSerializer) @@ -342,6 +342,13 @@ def _initialize_common(self): serializer=Float64Serializer, ) register(float, type_id=TypeId.FLOAT64, serializer=Float64Serializer) + from pyfory.bfloat16 import bfloat16 + + register( + bfloat16, + type_id=TypeId.BFLOAT16, + serializer=BFloat16Serializer, + ) register(str, type_id=TypeId.STRING, serializer=StringSerializer) # TODO(chaokunyang) DURATION DECIMAL register(datetime.datetime, type_id=TypeId.TIMESTAMP, serializer=TimestampSerializer) @@ -353,6 +360,14 @@ def _initialize_common(self): type_id=typeid, serializer=PyArraySerializer(self.fory, ftype, typeid), ) + from pyfory.bfloat16_array import BFloat16Array + from pyfory.serializer import BFloat16ArraySerializer + + register( + BFloat16Array, + type_id=TypeId.BFLOAT16_ARRAY, + serializer=BFloat16ArraySerializer(self.fory, BFloat16Array, TypeId.BFLOAT16_ARRAY), + ) if np: # overwrite pyarray with same type id. # if pyarray are needed, one must annotate that value with XXXArrayType @@ -472,7 +487,8 @@ def _register_type( raise TypeError(f"type name {typename} and id {type_id} should not be set at the same time") if cls in self._types_info: raise TypeError(f"{cls} registered already") - return self._register_xtype( + register_type = self._register_xtype if self.fory.xlang else self._register_pytype + return register_type( cls, type_id=type_id, user_type_id=user_type_id, @@ -540,6 +556,30 @@ def _register_xtype( internal=internal, ) + def _register_pytype( + self, + cls: Union[type, TypeVar], + *, + type_id: int = None, + user_type_id: int = NO_USER_TYPE_ID, + namespace: str = None, + typename: str = None, + serializer: Serializer = None, + internal: bool = False, + ): + # Set default type_id when None, similar to _register_xtype + if type_id is None and typename is not None: + type_id = self._next_type_id() + return self.__register_type( + cls, + type_id=type_id, + user_type_id=user_type_id, + namespace=namespace, + typename=typename, + serializer=serializer, + internal=internal, + ) + def __register_type( self, cls: Union[type, TypeVar], @@ -588,7 +628,7 @@ def __register_type( if user_type_id not in self._user_type_id_to_type_info or not internal: self._user_type_id_to_type_info[user_type_id] = typeinfo self._used_user_type_ids.add(user_type_id) - elif not TypeId.is_namespaced_type(type_id): + elif not self.fory.xlang or not TypeId.is_namespaced_type(type_id): if type_id not in self._type_id_to_type_info or not internal: self._type_id_to_type_info[type_id] = typeinfo self._types_info[cls] = typeinfo @@ -611,6 +651,9 @@ def register_serializer(self, cls: Union[type, TypeVar], serializer): if cls not in self._types_info: raise TypeUnregisteredError(f"{cls} not registered") typeinfo = self._types_info[cls] + if not self.fory.xlang: + typeinfo.serializer = serializer + return prev_type_id = typeinfo.type_id prev_user_type_id = typeinfo.user_type_id if needs_user_type_id(prev_type_id) and prev_user_type_id not in {None, NO_USER_TYPE_ID}: diff --git a/python/pyfory/serializer.py b/python/pyfory/serializer.py index 4d30894144..415af11af9 100644 --- a/python/pyfory/serializer.py +++ b/python/pyfory/serializer.py @@ -64,6 +64,7 @@ TaggedUint64Serializer, Float32Serializer, Float64Serializer, + BFloat16Serializer, StringSerializer, DateSerializer, TimestampSerializer, @@ -99,6 +100,7 @@ TaggedUint64Serializer, Float32Serializer, Float64Serializer, + BFloat16Serializer, StringSerializer, DateSerializer, TimestampSerializer, @@ -147,6 +149,12 @@ def __init__(self, fory): super().__init__(fory, None) self.need_to_write_ref = False + def xwrite(self, buffer, value): + raise NotImplementedError + + def xread(self, buffer): + raise NotImplementedError + def write(self, buffer, value): pass @@ -214,6 +222,12 @@ def read(self, buffer): name = self.fory.read_ref(buffer) return self.type_(start, stop, step, dtype=dtype, name=name) + def xwrite(self, buffer, value): + raise NotImplementedError + + def xread(self, buffer): + raise NotImplementedError + # Use numpy array or python array module. typecode_dict = ( @@ -270,6 +284,7 @@ def read(self, buffer): TypeId.UINT64_ARRAY: "Q", TypeId.FLOAT32_ARRAY: "f", TypeId.FLOAT64_ARRAY: "d", + TypeId.BFLOAT16_ARRAY: "H", # bfloat16 uses 'H' typecode (uint16) } ) @@ -309,7 +324,7 @@ def __init__(self, fory, ftype, type_id: str): self.typecode = typeid_code[type_id] self.itemsize, ftype, self.type_id = typecode_dict[self.typecode] - def write(self, buffer, value): + def xwrite(self, buffer, value): assert value.itemsize == self.itemsize view = memoryview(value) assert view.format == self.typecode @@ -325,7 +340,7 @@ def write(self, buffer, value): swapped.byteswap() buffer.write_buffer(swapped) - def read(self, buffer): + def xread(self, buffer): data = buffer.read_bytes_and_size() arr = array.array(self.typecode, []) arr.frombytes(data) @@ -334,14 +349,37 @@ def read(self, buffer): arr.byteswap() return arr + def write(self, buffer, value: array.array): + nbytes = len(value) * value.itemsize + buffer.write_string(value.typecode) + buffer.write_var_uint32(nbytes) + if is_little_endian or value.itemsize == 1: + buffer.write_buffer(value) + else: + # Swap bytes on big-endian machines for multi-byte types + swapped = array.array(value.typecode, value) + swapped.byteswap() + buffer.write_buffer(swapped) + + def read(self, buffer): + typecode = buffer.read_string() + data = buffer.read_bytes_and_size() + arr = array.array(typecode[0], []) # Take first character + arr.frombytes(data) + if not is_little_endian and arr.itemsize > 1: + # Swap bytes on big-endian machines for multi-byte types + arr.byteswap() + return arr + class DynamicPyArraySerializer(Serializer): """Serializer for dynamic Python arrays that handles any typecode.""" def __init__(self, fory, cls): super().__init__(fory, cls) + self._serializer = ReduceSerializer(fory, cls) - def write(self, buffer, value): + def xwrite(self, buffer, value): itemsize, ftype, type_id = typecode_dict[value.typecode] view = memoryview(value) nbytes = len(value) * itemsize @@ -363,7 +401,7 @@ def write(self, buffer, value): swapped.byteswap() buffer.write_buffer(swapped) - def read(self, buffer): + def xread(self, buffer): type_id = buffer.read_uint8() typecode = typeid_code[type_id] itemsize = typecode_dict[typecode][0] @@ -374,6 +412,50 @@ def read(self, buffer): arr.byteswap() return arr + def write(self, buffer, value): + self._serializer.write(buffer, value) + + def read(self, buffer): + return self._serializer.read(buffer) + + +class BFloat16ArraySerializer(Serializer): + def __init__(self, fory, ftype, type_id: int): + super().__init__(fory, ftype) + self.type_id = type_id + self.itemsize = 2 + + def write(self, buffer, value): + from pyfory.bfloat16_array import BFloat16Array + + if isinstance(value, BFloat16Array): + arr_data = value._data + elif isinstance(value, array.array) and value.typecode == "H": + arr_data = value + else: + arr_data = BFloat16Array(value)._data + nbytes = len(arr_data) * 2 + buffer.write_var_uint32(nbytes) + if nbytes > 0: + if is_little_endian: + buffer.write_buffer(arr_data) + else: + swapped = array.array("H", arr_data) + swapped.byteswap() + buffer.write_buffer(swapped) + + def read(self, buffer): + from pyfory.bfloat16_array import BFloat16Array + + data = buffer.read_bytes_and_size() + arr = array.array("H", []) + arr.frombytes(data) + if not is_little_endian: + arr.byteswap() + bf16_arr = BFloat16Array.__new__(BFloat16Array) + bf16_arr._data = arr + return bf16_arr + if np: _np_dtypes_dict = ( @@ -407,7 +489,6 @@ def read(self, buffer): ) else: _np_dtypes_dict = {} -_np_typeid_to_dtype = {type_id: dtype for dtype, (_, _, _, type_id) in _np_dtypes_dict.items()} class Numpy1DArraySerializer(Serializer): @@ -417,8 +498,9 @@ def __init__(self, fory, ftype, dtype): super().__init__(fory, ftype) self.dtype = dtype self.itemsize, self.typecode, _, self.type_id = _np_dtypes_dict[self.dtype] + self._serializer = ReduceSerializer(fory, np.ndarray) - def write(self, buffer, value): + def xwrite(self, buffer, value): assert value.itemsize == self.itemsize view = memoryview(value) try: @@ -440,7 +522,7 @@ def write(self, buffer, value): # Swap bytes on big-endian machines for multi-byte types buffer.write_bytes(value.astype(value.dtype.newbyteorder("<")).tobytes()) - def read(self, buffer): + def xread(self, buffer): data = buffer.read_bytes_and_size() arr = np.frombuffer(data, dtype=self.dtype.newbyteorder("<")) if self.itemsize > 1: @@ -452,53 +534,32 @@ def read(self, buffer): arr = arr.astype(self.dtype) return arr + def write(self, buffer, value): + self._serializer.write(buffer, value) + + def read(self, buffer): + return self._serializer.read(buffer) + class NDArraySerializer(Serializer): - def write(self, buffer, value): - # Write concrete 1D primitive ndarray using type id + bytes payload. - dtype_info = _np_dtypes_dict.get(value.dtype) - if dtype_info is None or value.ndim != 1: - raise NotImplementedError(f"Unsupported ndarray: dtype={value.dtype}, ndim={value.ndim}") - itemsize, _typecode, _ftype, type_id = dtype_info + def xwrite(self, buffer, value): + itemsize, typecode, ftype, type_id = _np_dtypes_dict[value.dtype] view = memoryview(value) nbytes = len(value) * itemsize buffer.write_uint8(type_id) buffer.write_var_uint32(nbytes) if value.dtype == np.dtype("bool") or not view.c_contiguous: - if not is_little_endian and itemsize > 1: - buffer.write_bytes(value.astype(value.dtype.newbyteorder("<")).tobytes()) - else: - buffer.write_bytes(value.tobytes()) - elif is_little_endian or itemsize == 1: - buffer.write_buffer(value) + buffer.write_bytes(value.tobytes()) else: - buffer.write_bytes(value.astype(value.dtype.newbyteorder("<")).tobytes()) - - def read(self, buffer): - type_id = buffer.read_uint8() - dtype = _np_typeid_to_dtype.get(type_id) - if dtype is None: - raise NotImplementedError(f"Unsupported ndarray type id: {type_id}") - data = buffer.read_bytes_and_size() - arr = np.frombuffer(data, dtype=dtype.newbyteorder("<")) - if dtype.itemsize > 1: - if is_little_endian: - arr = arr.view(dtype) - else: - arr = arr.astype(dtype) - return arr + buffer.write_buffer(value) + def xread(self, buffer): + raise NotImplementedError("Multi-dimensional array not supported currently") -class PythonNDArraySerializer(NDArraySerializer): def write(self, buffer, value): - dtype_info = _np_dtypes_dict.get(value.dtype) - if dtype_info is not None and value.ndim == 1: - super().write(buffer, value) - return - fory = self.fory dtype = value.dtype - buffer.write_string(dtype.str) + fory.write_ref(buffer, dtype) buffer.write_var_uint32(len(value.shape)) for dim in value.shape: buffer.write_var_uint32(dim) @@ -510,22 +571,8 @@ def write(self, buffer, value): fory.write_buffer_object(buffer, NDArrayBufferObject(value)) def read(self, buffer): - reader_index = buffer.get_reader_index() - type_id = buffer.read_uint8() - dtype = _np_typeid_to_dtype.get(type_id) - if dtype is not None: - data = buffer.read_bytes_and_size() - arr = np.frombuffer(data, dtype=dtype.newbyteorder("<")) - if dtype.itemsize > 1: - if is_little_endian: - arr = arr.view(dtype) - else: - arr = arr.astype(dtype) - return arr - - buffer.set_reader_index(reader_index) fory = self.fory - dtype = np.dtype(buffer.read_string()) + dtype = fory.read_ref(buffer) ndim = buffer.read_var_uint32() shape = tuple(buffer.read_var_uint32() for _ in range(ndim)) if dtype.kind == "O": @@ -1200,6 +1247,12 @@ def _deserialize_function(self, buffer): func = result return func + def xwrite(self, buffer, value): + raise NotImplementedError() + + def xread(self, buffer): + raise NotImplementedError() + def write(self, buffer, value): """Serialize a function for Python-only mode.""" self._serialize_function(buffer, value) @@ -1264,6 +1317,12 @@ def read(self, buffer): method = result return method + def xwrite(self, buffer, value): + return self.write(buffer, value) + + def xread(self, buffer): + return self.read(buffer) + class ObjectSerializer(Serializer): """Serializer for regular Python objects. @@ -1307,6 +1366,14 @@ def read(self, buffer): setattr(obj, field_name, field_value) return obj + def xwrite(self, buffer, value): + # for cross-language or minimal framing, reuse the same logic + return self.write(buffer, value) + + def xread(self, buffer): + # symmetric to xwrite + return self.read(buffer) + @dataclasses.dataclass class NonExistEnum: @@ -1324,9 +1391,16 @@ def support_subclass(cls) -> bool: return True def write(self, buffer, value): - buffer.write_var_uint32(value.value) + buffer.write_string(value.name) def read(self, buffer): + name = buffer.read_string() + return NonExistEnum(name=name) + + def xwrite(self, buffer, value): + buffer.write_var_uint32(value.value) + + def xread(self, buffer): value = buffer.read_var_uint32() return NonExistEnum(value=value) @@ -1338,6 +1412,12 @@ def write(self, buffer, value): def read(self, buffer): return self.fory.handle_unsupported_read(buffer) + def xwrite(self, buffer, value): + raise NotImplementedError(f"{self.type_} is not supported for xwrite") + + def xread(self, buffer): + raise NotImplementedError(f"{self.type_} is not supported for xread") + __all__ = [ # Base serializers (imported) @@ -1362,6 +1442,7 @@ def read(self, buffer): "TaggedUint64Serializer", "Float32Serializer", "Float64Serializer", + "BFloat16Serializer", "StringSerializer", "DateSerializer", "TimestampSerializer", diff --git a/python/pyfory/struct.py b/python/pyfory/struct.py index 3a842d8d78..dbe3d75018 100644 --- a/python/pyfory/struct.py +++ b/python/pyfory/struct.py @@ -464,7 +464,7 @@ def _write_field_value(self, buffer, serializer, field_value, is_nullable, is_dy serializer.write(buffer, field_value) return if is_tracking_ref: - self.fory.write_ref(buffer, field_value, serializer=None if is_dynamic else serializer) + self.fory.write_ref(buffer, field_value, None if is_dynamic else serializer) return if is_nullable: if field_value is None: @@ -474,7 +474,7 @@ def _write_field_value(self, buffer, serializer, field_value, is_nullable, is_dy if is_dynamic: self.fory.write_no_ref(buffer, field_value) else: - self.fory.write_no_ref(buffer, field_value, serializer=serializer) + self.fory.write_no_ref(buffer, field_value, serializer) def _read_field_value(self, buffer, serializer, is_nullable, is_dynamic, is_basic, is_tracking_ref): if is_nullable and is_basic: @@ -484,12 +484,12 @@ def _read_field_value(self, buffer, serializer, is_nullable, is_dynamic, is_basi if is_basic: return serializer.read(buffer) if is_tracking_ref: - return self.fory.read_ref(buffer, serializer=None if is_dynamic else serializer) + return self.fory.read_ref(buffer, None if is_dynamic else serializer) if is_nullable and buffer.read_int8() == NULL_FLAG: return None if is_dynamic: return self.fory.read_no_ref(buffer) - return self.fory.read_no_ref(buffer, serializer=serializer) + return self.fory.read_no_ref(buffer, serializer) def write(self, buffer: Buffer, value): if not self.fory.compatible: diff --git a/python/pyfory/tests/test_bfloat16.py b/python/pyfory/tests/test_bfloat16.py new file mode 100644 index 0000000000..eae52ea27c --- /dev/null +++ b/python/pyfory/tests/test_bfloat16.py @@ -0,0 +1,405 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import array +import math +import struct + +import pytest + +from pyfory import Fory +from pyfory.bfloat16 import bfloat16 +from pyfory.bfloat16_array import BFloat16Array +from pyfory.types import TypeId + + +def ser_de(fory, value): + data = fory.serialize(value) + return fory.deserialize(data) + + +# --------------- scalar construction --------------- + + +def test_bfloat16_basic(): + bf16 = bfloat16(3.14) + assert isinstance(bf16, bfloat16) + assert bf16.to_float32() == pytest.approx(3.14, abs=0.01) + bits = bf16.to_bits() + assert bfloat16.from_bits(bits).to_bits() == bits + + +def test_bfloat16_from_bits_roundtrip(): + for bits in (0x0000, 0x3F80, 0x4049, 0x7F80, 0xFF80, 0x7FC0): + assert bfloat16.from_bits(bits).to_bits() == bits + + +# --------------- special values --------------- + + +def test_bfloat16_special_values(): + assert bfloat16(float("nan")).is_nan() + assert bfloat16(float("inf")).is_inf() + assert bfloat16(float("-inf")).is_inf() + assert bfloat16(0.0).is_zero() + assert bfloat16(1.0).is_finite() + assert not bfloat16(1.0).is_nan() + assert not bfloat16(1.0).is_inf() + + +def test_bfloat16_positive_zero(): + pz = bfloat16(0.0) + assert pz.to_bits() == 0x0000 + assert pz.is_zero() + assert pz.to_float32() == 0.0 + + +def test_bfloat16_negative_zero(): + nz = bfloat16(-0.0) + assert nz.to_bits() == 0x8000 + assert nz.is_zero() + assert math.copysign(1.0, nz.to_float32()) == -1.0 + + +def test_bfloat16_positive_infinity(): + pinf = bfloat16(float("inf")) + assert pinf.to_bits() == 0x7F80 + assert pinf.is_inf() + assert not pinf.is_nan() + assert pinf.to_float32() == float("inf") + + +def test_bfloat16_negative_infinity(): + ninf = bfloat16(float("-inf")) + assert ninf.to_bits() == 0xFF80 + assert ninf.is_inf() + assert ninf.to_float32() == float("-inf") + + +def test_bfloat16_subnormal(): + bf = bfloat16.from_bits(0x0001) + assert bf.is_finite() + assert not bf.is_zero() + assert bf.to_float32() != 0.0 + + +def test_bfloat16_max_normal(): + bf = bfloat16.from_bits(0x7F7F) + assert bf.is_finite() + val = bf.to_float32() + assert val > 3.0e38 + + +def test_bfloat16_min_positive_normal(): + bf = bfloat16.from_bits(0x0080) + assert bf.is_finite() + val = bf.to_float32() + assert 0 < val < 1e-37 + + +# --------------- NaN behaviour --------------- + + +def test_bfloat16_nan_not_equal_to_itself(): + n = bfloat16(float("nan")) + assert n != n + + +def test_bfloat16_signaling_nan_preserved(): + snan_f32_bits = 0x7F800001 + snan_f32 = struct.unpack("f", struct.pack("I", snan_f32_bits))[0] + bf = bfloat16(snan_f32) + assert bf.is_nan(), "signaling NaN must stay NaN, not become infinity" + assert bf.to_bits() != 0x7F80, "must not collapse to infinity" + + +def test_bfloat16_quiet_nan_roundtrip(): + qnan = bfloat16.from_bits(0x7FC0) + assert qnan.is_nan() + assert qnan.to_bits() == 0x7FC0 + + +# --------------- equality / hashing --------------- + + +def test_bfloat16_equality(): + assert bfloat16(1.0) == bfloat16(1.0) + assert bfloat16(0.0) == bfloat16(-0.0) + assert bfloat16(1.0) != bfloat16(2.0) + + +def test_bfloat16_hash_contract(): + pz = bfloat16(0.0) + nz = bfloat16(-0.0) + assert pz == nz + assert hash(pz) == hash(nz), "+0 and -0 are equal so must share the same hash" + + +def test_bfloat16_hash_consistency(): + for v in (1.0, -1.0, 3.14, float("inf")): + a = bfloat16(v) + b = bfloat16(v) + assert hash(a) == hash(b) + + +# --------------- rounding --------------- + + +def test_bfloat16_round_to_nearest_even(): + bf1 = bfloat16(1.0) + bf2 = bfloat16(1.0 + 2**-8) + assert bf1.to_bits() == bf2.to_bits() or abs(bf1.to_bits() - bf2.to_bits()) <= 1 + + +# --------------- conversion --------------- + + +def test_bfloat16_conversion(): + assert bfloat16(0.0).to_float32() == 0.0 + assert bfloat16(1.0).to_float32() == 1.0 + assert bfloat16(-1.0).to_float32() == -1.0 + assert bfloat16(3.14).to_float32() == pytest.approx(3.14, abs=0.01) + assert math.isnan(bfloat16(float("nan")).to_float32()) + assert math.isinf(bfloat16(float("inf")).to_float32()) + assert math.isinf(bfloat16(float("-inf")).to_float32()) + + +def test_bfloat16_float_dunder(): + bf = bfloat16(2.5) + assert float(bf) == pytest.approx(2.5) + + +def test_bfloat16_repr_and_str(): + bf = bfloat16(1.0) + r = repr(bf) + assert "bfloat16" in r + assert "1.0" in str(bf) or "1" in str(bf) + + +# --------------- scalar serialization --------------- + + +def test_bfloat16_serialization(): + fory = Fory(xlang=True) + assert ser_de(fory, bfloat16(0.0)).to_bits() == bfloat16(0.0).to_bits() + assert ser_de(fory, bfloat16(1.0)).to_bits() == bfloat16(1.0).to_bits() + assert ser_de(fory, bfloat16(3.14)).to_bits() == bfloat16(3.14).to_bits() + assert ser_de(fory, bfloat16(float("inf"))).is_inf() + assert ser_de(fory, bfloat16(float("nan"))).is_nan() + + +def test_bfloat16_serialization_special_bits(): + fory = Fory(xlang=True) + for bits in (0x0000, 0x8000, 0x7F80, 0xFF80, 0x7FC0, 0x0001, 0x7F7F): + original = bfloat16.from_bits(bits) + result = ser_de(fory, original) + assert result.to_bits() == bits + + +# --------------- array construction --------------- + + +def test_bfloat16_array_basic(): + arr = BFloat16Array([1.0, 2.0, 3.14]) + assert len(arr) == 3 + assert arr[0].to_float32() == pytest.approx(1.0) + arr[0] = bfloat16(5.0) + assert arr[0].to_float32() == pytest.approx(5.0) + + +def test_bfloat16_array_empty(): + arr = BFloat16Array() + assert len(arr) == 0 + assert arr.tobytes() == b"" + + +def test_bfloat16_array_from_bfloat16_values(): + values = [bfloat16(1.0), bfloat16(2.0)] + arr = BFloat16Array(values) + assert len(arr) == 2 + assert arr[0].to_bits() == bfloat16(1.0).to_bits() + + +def test_bfloat16_array_copy_constructor(): + original = BFloat16Array([1.0, 2.0, 3.0]) + copy = BFloat16Array(original) + assert copy == original + copy[0] = bfloat16(99.0) + assert copy != original + + +# --------------- array bytes round-trip --------------- + + +def test_bfloat16_array_tobytes_frombytes(): + arr = BFloat16Array([1.0, 2.0, 3.0]) + raw = arr.tobytes() + assert len(raw) == 6 + restored = BFloat16Array.frombytes(raw) + assert restored == arr + + +def test_bfloat16_array_frombytes_odd_length(): + with pytest.raises(ValueError): + BFloat16Array.frombytes(b"\x00\x01\x02") + + +# --------------- array append / extend --------------- + + +def test_bfloat16_array_append_extend(): + arr = BFloat16Array() + arr.append(1.0) + arr.append(bfloat16(2.0)) + assert len(arr) == 2 + arr.extend(BFloat16Array([3.0, 4.0])) + assert len(arr) == 4 + arr.extend([5.0]) + assert len(arr) == 5 + + +# --------------- array iteration / equality --------------- + + +def test_bfloat16_array_iteration(): + arr = BFloat16Array([1.0, 2.0, 3.0]) + floats = [float(v) for v in arr] + assert floats == [pytest.approx(1.0), pytest.approx(2.0), pytest.approx(3.0)] + + +def test_bfloat16_array_equality(): + a = BFloat16Array([1.0, 2.0]) + b = BFloat16Array([1.0, 2.0]) + c = BFloat16Array([1.0, 3.0]) + assert a == b + assert a != c + assert a != "not an array" + + +def test_bfloat16_array_repr(): + arr = BFloat16Array([1.0]) + assert "BFloat16Array" in repr(arr) + + +def test_bfloat16_array_itemsize(): + arr = BFloat16Array() + assert arr.itemsize == 2 + + +def test_bfloat16_array_to_bits_array(): + arr = BFloat16Array([1.0, 2.0]) + bits = arr.to_bits_array() + assert isinstance(bits, array.array) + assert bits.typecode == "H" + assert len(bits) == 2 + + +def test_bfloat16_array_from_bits_array(): + bits = array.array("H", [bfloat16(1.0).to_bits(), bfloat16(2.0).to_bits()]) + arr = BFloat16Array.from_bits_array(bits) + assert len(arr) == 2 + assert arr[0].to_float32() == pytest.approx(1.0) + + +# --------------- array special values --------------- + + +def test_bfloat16_array_special_values(): + arr = BFloat16Array([0.0, -0.0, float("inf"), float("-inf"), float("nan")]) + assert arr[0].is_zero() + assert arr[1].is_zero() + assert arr[2].is_inf() + assert arr[3].is_inf() + assert arr[4].is_nan() + + +# --------------- array serialization --------------- + + +def test_bfloat16_array_serialization(): + fory = Fory(xlang=True) + arr = BFloat16Array([1.0, 2.0, 3.14]) + result = ser_de(fory, arr) + assert isinstance(result, BFloat16Array) + assert len(result) == 3 + assert result[0].to_float32() == pytest.approx(1.0) + + +def test_bfloat16_array_serialization_empty(): + fory = Fory(xlang=True) + arr = BFloat16Array() + result = ser_de(fory, arr) + assert isinstance(result, BFloat16Array) + assert len(result) == 0 + + +def test_bfloat16_array_serialization_special(): + fory = Fory(xlang=True) + arr = BFloat16Array([0.0, float("inf"), float("nan")]) + result = ser_de(fory, arr) + assert result[0].is_zero() + assert result[1].is_inf() + assert result[2].is_nan() + + +# --------------- integration --------------- + + +def test_bfloat16_in_dataclass(): + from dataclasses import dataclass + + @dataclass + class TestStruct: + value: bfloat16 + arr: BFloat16Array + + fory = Fory(xlang=True) + fory.register_type(TestStruct) + obj = TestStruct(value=bfloat16(3.14), arr=BFloat16Array([1.0, 2.0])) + result = ser_de(fory, obj) + assert result.value.to_float32() == pytest.approx(3.14, abs=0.01) + assert len(result.arr) == 2 + + +def test_bfloat16_in_list(): + fory = Fory(xlang=True) + values = [bfloat16(1.0), bfloat16(2.0)] + result = ser_de(fory, values) + assert len(result) == 2 + assert result[0].to_float32() == pytest.approx(1.0) + + +def test_bfloat16_in_map(): + fory = Fory(xlang=True) + data = {"a": bfloat16(1.0), "b": bfloat16(2.0)} + result = ser_de(fory, data) + assert result["a"].to_float32() == pytest.approx(1.0) + + +# --------------- type registration --------------- + + +def test_bfloat16_type_registration(): + fory = Fory(xlang=True) + type_info = fory.type_resolver.get_type_info(bfloat16) + assert type_info.type_id == TypeId.BFLOAT16 + + +def test_bfloat16_array_type_registration(): + fory = Fory(xlang=True) + type_info = fory.type_resolver.get_type_info(BFloat16Array) + assert type_info.type_id == TypeId.BFLOAT16_ARRAY diff --git a/python/pyfory/types.py b/python/pyfory/types.py index 7f8f871dd8..db4333a83b 100644 --- a/python/pyfory/types.py +++ b/python/pyfory/types.py @@ -198,6 +198,7 @@ def is_type_share_meta(type_id: int) -> bool: tagged_uint64 = TypeVar("tagged_uint64", bound=int) float32 = TypeVar("float32", bound=float) float64 = TypeVar("float64", bound=float) +bfloat16 = TypeVar("bfloat16", bound=float) class RefMeta: @@ -314,6 +315,7 @@ def get_primitive_type_size(type_id) -> int: uint64_array = TypeVar("uint64_array", bound=array.ArrayType) float32_array = TypeVar("float32_array", bound=array.ArrayType) float64_array = TypeVar("float64_array", bound=array.ArrayType) +bfloat16_array = TypeVar("bfloat16_array", bound=array.ArrayType) BoolNDArrayType = TypeVar("BoolNDArrayType", bound=ndarray) Int8NDArrayType = TypeVar("Int8NDArrayType", bound=ndarray) Uint8NDArrayType = TypeVar("Uint8NDArrayType", bound=ndarray) @@ -351,6 +353,7 @@ def get_primitive_type_size(type_id) -> int: uint64_array, float32_array, float64_array, + bfloat16_array, } _np_array_types = { BoolNDArrayType, @@ -384,6 +387,7 @@ def is_py_array_type(type_) -> bool: TypeId.UINT64_ARRAY, TypeId.FLOAT32_ARRAY, TypeId.FLOAT64_ARRAY, + TypeId.BFLOAT16_ARRAY, }