73 changes: 56 additions & 17 deletions python/infinilm/generation/utils.py
@@ -11,32 +11,66 @@ def infini_to_ctype_dtype(infini_dtype):

if infini_dtype == infinicore.int32:
return ctypes.c_int32
elif infini_dtype == infinicore.int64:
return ctypes.c_int64
elif infini_dtype == infinicore.float32:
return ctypes.c_float
elif infini_dtype == infinicore.bfloat16:
# bfloat16 uses uint16 to read raw bytes
return ctypes.c_uint16
else:
raise ValueError(f"Unsupported py_dtype: {infini_dtype}")


def infini_to_numpy(infini_tensor: infinicore.Tensor):
# Ensure data is on CPU
if infini_tensor.device.type != "cpu":
infini_tensor_cpu = infini_tensor.to(infinicore.device("cpu", 0))
# Sync to ensure copy is complete
infinicore.sync_stream()
else:
infini_tensor_cpu = infini_tensor

    # Get data pointer and shape information
data_ptr = infini_tensor_cpu.data_ptr()
num_elements = infini_tensor_cpu.numel()
original_shape = infini_tensor_cpu.shape

    # Create a 1D NumPy array (shared memory)
ArrayType = infini_to_ctype_dtype(infini_tensor_cpu.dtype) * num_elements
array = ArrayType.from_address(data_ptr)
np_flat = np.ctypeslib.as_array(array)

    # Reshape to original shape
np_array = np_flat.reshape(original_shape)

return np.copy(np_array)
# Special handling for bfloat16
if infini_tensor_cpu.dtype == infinicore.bfloat16:
# bfloat16 is 16-bit, read as uint16
import ctypes
# Use safer approach: allocate memory first, then copy
buffer = (ctypes.c_uint16 * num_elements)()
ctypes.memmove(ctypes.addressof(buffer), data_ptr, num_elements * 2) # 2 bytes per uint16
np_uint16 = np.array(buffer, dtype=np.uint16, copy=True)

# Convert uint16 to float32
# bfloat16 memory layout: shift uint16 left by 16 bits, then read as float32
np_uint32 = np_uint16.astype(np.uint32) << 16
np_array = np_uint32.view(np.float32).reshape(original_shape)
else:
# Determine element size and numpy dtype based on dtype
dtype_info_map = {
infinicore.int32: (4, np.int32),
infinicore.int64: (8, np.int64),
infinicore.float32: (4, np.float32),
}
element_size, np_dtype = dtype_info_map.get(infini_tensor_cpu.dtype, (4, np.float32))

# Use safer approach: allocate memory first, then copy
import ctypes
Collaborator: Move this to the top of the file.

ctype = infini_to_ctype_dtype(infini_tensor_cpu.dtype)
buffer = (ctype * num_elements)()
ctypes.memmove(ctypes.addressof(buffer), data_ptr, num_elements * element_size)

# Convert to numpy array (using np.array instead of frombuffer, safer)
np_flat = np.array(buffer, dtype=np_dtype, copy=True)

# Reshape to original shape
np_array = np_flat.reshape(original_shape)

return np_array


infinicore.Tensor.to_numpy = infini_to_numpy
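For readers skimming the diff: the new bfloat16 path combines two techniques, copying the raw bytes into an owned buffer via ctypes.memmove (so the resulting NumPy array never aliases memory the runtime may free) and widening the 16-bit patterns to float32 by shifting them into the high half of a uint32. A minimal self-contained sketch of both steps; read_bf16_from_address is a hypothetical helper, not code from this PR:

import ctypes

import numpy as np


def read_bf16_from_address(data_ptr: int, num_elements: int) -> np.ndarray:
    """Copy raw bfloat16 bytes from an address and widen them to float32."""
    # Copy into an owned ctypes buffer instead of aliasing the source
    # memory (the "safer approach" the diff refers to).
    buffer = (ctypes.c_uint16 * num_elements)()
    ctypes.memmove(ctypes.addressof(buffer), data_ptr, num_elements * 2)
    bits = np.array(buffer, dtype=np.uint16)
    # bfloat16 is the top 16 bits of an IEEE-754 float32, so placing the
    # bits in the high half of a uint32 and reinterpreting them is exact.
    return (bits.astype(np.uint32) << 16).view(np.float32)


# 1.5 encodes as 0x3FC0 in bfloat16; the round trip recovers it exactly.
src = (ctypes.c_uint16 * 1)(0x3FC0)
assert read_bf16_from_address(ctypes.addressof(src), 1)[0] == 1.5

The widening is lossless in this direction; precision is only lost going from float32 down to bfloat16.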
@@ -197,6 +231,8 @@ def _sample(
# -------------------------------------------------------------------------- #
start_time = time.time()
logits = self(**model_inputs)
# Ensure computation is complete - sync stream before reading logits
Collaborator: This sync can be removed.

infinicore.sync_stream()

# -------------------------------------------------------------------------- #
# Process outputs
@@ -225,7 +261,7 @@ def _sample(
out=out,
)

    infinicore.sync_stream()  # Sync before computation ends

end_time = time.time()
time_list.append((end_time - start_time) * 1000)
@@ -245,11 +281,14 @@ def _sample(
break

print("\n</s>")
print(
f"\n\n\n Time per step: prefill {round(time_list[0], 2)} token/ms\n",
)
print(
f" Time per step: decoder {round(sum(time_list[1:]) / (len(time_list) - 1), 2)} token/ms \n",
)

if len(time_list) > 0:
Collaborator: This should be: if len(time_list) > 1

print(
f"\n\n\n Time per step: prefill {round(time_list[0], 2)} token/ms\n",
)
if len(time_list) > 1:
Collaborator: Delete this.

print(
f" Time per step: decoder {round(sum(time_list[1:]) / (len(time_list) - 1), 2)} token/ms \n",
)

return output_tokens_list, output_content
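For context on the last two review comments, here is one plausible reading of the requested change: a single guard around both prints, so the decoder average never divides by zero when only the prefill step ran. report_timings is a hypothetical helper, not code from this PR; the values in time_list are milliseconds per step:

def report_timings(time_list):
    # time_list[0] is the prefill latency; the rest are decode steps (ms).
    if len(time_list) > 1:
        print(f"Time per step: prefill {round(time_list[0], 2)} ms")
        decode_avg = sum(time_list[1:]) / (len(time_list) - 1)
        print(f"Time per step: decoder {round(decode_avg, 2)} ms")


report_timings([120.5, 4.2, 4.4])  # prefill 120.5 ms, decoder average 4.3 ms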