2929
3030from executorch .extension .flat_tensor .serialize .serialize import (
3131 _deserialize_to_flat_tensor ,
32+ _FLATBUFFER_ALIGNMENT ,
3233 FlatTensorConfig ,
3334 FlatTensorHeader ,
3435 FlatTensorSerializer ,
@@ -109,8 +110,7 @@ def _check_named_data_entries(
109110 f"Named data record { key } .{ field .name } does not match." ,
110111 )
111112
112- def test_serialize (self ) -> None :
113- config = FlatTensorConfig ()
113+ def _serialize_with_alignment (self , config : FlatTensorConfig ) -> None :
114114 serializer : DataSerializer = FlatTensorSerializer (config )
115115 serialized_data = bytes (serializer .serialize (TEST_DATA_PAYLOAD ))
116116
@@ -120,15 +120,15 @@ def test_serialize(self) -> None:
120120 )
121121 self .assertTrue (header .is_valid ())
122122
123- # Header is aligned to config.segment_alignment, which is where the flatbuffer starts.
123+ # Flatbuffer is non-empty.
124+ self .assertTrue (header .flatbuffer_size > 0 )
125+
126+ # Align the flatbuffer to _FLATBUFFER_ALIGNMENT.
124127 self .assertEqual (
125128 header .flatbuffer_offset ,
126- aligned_size (FlatTensorHeader .EXPECTED_LENGTH , config . segment_alignment ),
129+ aligned_size (FlatTensorHeader .EXPECTED_LENGTH , _FLATBUFFER_ALIGNMENT ),
127130 )
128131
129- # Flatbuffer is non-empty.
130- self .assertTrue (header .flatbuffer_size > 0 )
131-
132132 # Segment base offset is aligned to config.segment_alignment.
133133 expected_segment_base_offset = aligned_size (
134134 header .flatbuffer_offset + header .flatbuffer_size , config .segment_alignment
@@ -180,12 +180,12 @@ def test_serialize(self) -> None:
180180 segments = flat_tensor .segments
181181 self .assertEqual (len (segments ), 3 )
182182
183- # Segment 0 contains fqn1, fqn2; 4 bytes, aligned to config.tensor_alignment .
183+ # Segment 0 contains fqn1, fqn2; 4 bytes, aligned to config.segment_alignment .
184184 self .assertEqual (segments [0 ].offset , 0 )
185185 self .assertEqual (segments [0 ].size , len (TEST_BUFFER [0 ]))
186186
187- # Segment 1 contains fqn3; 32 bytes, aligned to config.tensor_alignment .
188- self .assertEqual (segments [1 ].offset , config .tensor_alignment )
187+ # Segment 1 contains fqn3; 32 bytes, aligned to config.segment_alignment .
188+ self .assertEqual (segments [1 ].offset , config .segment_alignment )
189189 self .assertEqual (segments [1 ].size , len (TEST_BUFFER [1 ]))
190190
191191 # Segment 2 contains key0; 17 bytes, aligned to 64.
@@ -194,7 +194,7 @@ def test_serialize(self) -> None:
194194 )
195195 self .assertEqual (
196196 segments [2 ].offset ,
197- aligned_size (config .tensor_alignment * 3 , custom_alignment ),
197+ aligned_size (config .segment_alignment * 2 , custom_alignment ),
198198 )
199199 self .assertEqual (segments [2 ].size , len (TEST_BUFFER [2 ]))
200200
@@ -245,6 +245,18 @@ def test_serialize(self) -> None:
245245
246246 self .assertEqual (segments [2 ].offset + segments [2 ].size , len (segment_data ))
247247
248+ def test_serialize_default_alignment (self ) -> None :
249+ config = FlatTensorConfig ()
250+ self ._serialize_with_alignment (config )
251+
252+ def test_serialize_align_4096 (self ) -> None :
253+ config = FlatTensorConfig (segment_alignment = 4096 )
254+ self ._serialize_with_alignment (config )
255+
256+ def test_serialize_align_1024 (self ) -> None :
257+ config = FlatTensorConfig (segment_alignment = 1024 )
258+ self ._serialize_with_alignment (config )
259+
248260 def test_round_trip (self ) -> None :
249261 # Serialize and then deserialize the test payload. Make sure it's reconstructed
250262 # properly.
0 commit comments