
Commit 24f2562

Sync llama : refactor llama_model_quantize_params to expose a pure C interface (#20346)

Signed-off-by: JamePeng <jame_peng@sina.com>
1 parent 7036ac3

1 file changed: llama_cpp/llama_cpp.py (65 additions, 26 deletions)
@@ -677,6 +677,8 @@ class llama_model_kv_override(ctypes.Structure):
         key: bytes
         value: Union[int, float, bool, bytes]
 
+llama_model_kv_override_p = ctypes.POINTER(llama_model_kv_override)
+
 # struct llama_model_tensor_buft_override {
 #     const char * pattern;
 #     ggml_backend_buffer_type_t buft;
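The new `llama_model_kv_override_p` alias lets callers hand the quantizer a typed array of overrides rather than an opaque `void *`. A minimal sketch of building such an array (assumptions not in this diff: the `tag`/`key`/`value` union layout `llama_model_kv_override` already has in this file, the GGUF key shown, and an empty-key terminator entry, which is the convention llama.cpp's model-load overrides use):

```python
import ctypes

from llama_cpp.llama_cpp import (
    LLAMA_KV_OVERRIDE_TYPE_INT,
    llama_model_kv_override,
    llama_model_kv_override_p,
)

# One real override plus a zeroed entry whose empty key marks the end.
overrides = (llama_model_kv_override * 2)()
overrides[0].key = b"llama.expert_used_count"
overrides[0].tag = LLAMA_KV_OVERRIDE_TYPE_INT
overrides[0].value.val_i64 = 4

# ctypes arrays do not coerce to pointer-typed fields; cast explicitly.
kv_ptr = ctypes.cast(overrides, llama_model_kv_override_p)
```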
@@ -975,22 +977,59 @@ class llama_context_params(ctypes.Structure):
 llama_context_params_p = ctypes.POINTER(llama_context_params)
 
 
+# struct llama_model_tensor_override {
+#     const char * pattern;
+#     enum ggml_type type;
+# };
+class llama_model_tensor_override(ctypes.Structure):
+    _fields_ = [
+        ("pattern", ctypes.c_char_p),
+        ("type", ctypes.c_int),
+    ]
+
+    if TYPE_CHECKING:
+        pattern: ctypes.c_char_p
+        type: ctypes.c_int
+
+llama_model_tensor_override_p = ctypes.POINTER(llama_model_tensor_override)
+
+
+# struct llama_model_imatrix_data {
+#     const char * name;
+#     const float * data;
+#     size_t size;
+# };
+class llama_model_imatrix_data(ctypes.Structure):
+    _fields_ = [
+        ("name", ctypes.c_char_p),
+        ("data", ctypes.POINTER(ctypes.c_float)),
+        ("size", ctypes.c_size_t),
+    ]
+
+    if TYPE_CHECKING:
+        name: ctypes.c_char_p
+        data: ctypes.POINTER(ctypes.c_float)
+        size: ctypes.c_size_t
+
+llama_model_imatrix_data_p = ctypes.POINTER(llama_model_imatrix_data)
+
+
 # // model quantization parameters
 # typedef struct llama_model_quantize_params {
-#     int32_t nthread;                     // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
-#     enum llama_ftype ftype;              // quantize to this llama_ftype
-#     enum ggml_type output_tensor_type;   // output tensor type
-#     enum ggml_type token_embedding_type; // token embeddings tensor type
-#     bool allow_requantize;               // allow quantizing non-f32/f16 tensors
-#     bool quantize_output_tensor;         // quantize output.weight
-#     bool only_copy;                      // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
-#     bool pure;                           // quantize all tensors to the default type
-#     bool keep_split;                     // quantize to the same number of shards
-#     bool dry_run;                        // calculate and show the final quantization size without performing quantization
-#     void * imatrix;                      // pointer to importance matrix data
-#     void * kv_overrides;                 // pointer to vector containing overrides
-#     void * tensor_types;                 // pointer to vector containing tensor types
-#     void * prune_layers;                 // pointer to vector containing layer indices to prune
+#     int32_t nthread;                     // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
+#     enum llama_ftype ftype;              // quantize to this llama_ftype
+#     enum ggml_type output_tensor_type;   // output tensor type
+#     enum ggml_type token_embedding_type; // token embeddings tensor type
+#     bool allow_requantize;               // allow quantizing non-f32/f16 tensors
+#     bool quantize_output_tensor;         // quantize output.weight
+#     bool only_copy;                      // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
+#     bool pure;                           // quantize all tensors to the default type
+#     bool keep_split;                     // quantize to the same number of shards
+#     bool dry_run;                        // calculate and show the final quantization size without performing quantization
+#     const struct llama_model_imatrix_data * imatrix;         // pointer to importance matrix data
+#     const struct llama_model_kv_override * kv_overrides;     // pointer to kv overrides
+#     const struct llama_model_tensor_override * tt_overrides; // pointer to tensor overrides
+#     const int32_t * prune_layers;                            // pointer to layer indices to prune
 # } llama_model_quantize_params;
 class llama_model_quantize_params(ctypes.Structure):
     """Parameters for llama_model_quantize
@@ -1006,10 +1045,10 @@ class llama_model_quantize_params(ctypes.Structure):
         pure (bool): quantize all tensors to the default type
         keep_split (bool): quantize to the same number of shards
         dry_run (bool): calculate and show the final quantization size without performing quantization
-        imatrix (ctypes.c_void_p): pointer to importance matrix data
-        kv_overrides (ctypes.c_void_p): pointer to vector containing overrides
-        tensor_types (ctypes.c_void_p): pointer to vector containing tensor types
-        prune_layers (ctypes.c_void_p): pointer to vector containing layer indices to prune
+        imatrix (POINTER(llama_model_imatrix_data)): Pointer to importance matrix data.
+        kv_overrides (POINTER(llama_model_kv_override)): Pointer to KV overrides.
+        tt_overrides (POINTER(llama_model_tensor_override)): Pointer to tensor overrides.
+        prune_layers (POINTER(c_int32)): Pointer to layer indices to prune.
     """
 
     if TYPE_CHECKING:
@@ -1023,10 +1062,10 @@ class llama_model_quantize_params(ctypes.Structure):
         pure: bool
         keep_split: bool
         dry_run: bool
-        imatrix: ctypes.c_void_p
-        kv_overrides: ctypes.c_void_p
-        tensor_types: ctypes.c_void_p
-        prune_layers: ctypes.c_void_p
+        imatrix: ctypes.POINTER(llama_model_imatrix_data)
+        kv_overrides: ctypes.POINTER(llama_model_kv_override)
+        tt_overrides: ctypes.POINTER(llama_model_tensor_override)
+        prune_layers: ctypes.POINTER(ctypes.c_int32)
 
     _fields_ = [
         ("nthread", ctypes.c_int32),
@@ -1039,10 +1078,10 @@ class llama_model_quantize_params(ctypes.Structure):
         ("pure", ctypes.c_bool),
         ("keep_split", ctypes.c_bool),
         ("dry_run", ctypes.c_bool),
-        ("imatrix", ctypes.c_void_p),
-        ("kv_overrides", ctypes.c_void_p),
-        ("tensor_types", ctypes.c_void_p),
-        ("prune_layers", ctypes.c_void_p),
+        ("imatrix", ctypes.POINTER(llama_model_imatrix_data)),
+        ("kv_overrides", ctypes.POINTER(llama_model_kv_override)),
+        ("tt_overrides", ctypes.POINTER(llama_model_tensor_override)),
+        ("prune_layers", ctypes.POINTER(ctypes.c_int32)),
     ]
 
 
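End to end, the refactored struct can now be populated without any `void *` casts. A hedged sketch using the module's existing `llama_model_quantize_default_params` and `llama_model_quantize` bindings (the file names are placeholders, and how the C side learns the length of the `prune_layers` array is not visible in this diff):

```python
import ctypes

from llama_cpp import llama_cpp

params = llama_cpp.llama_model_quantize_default_params()
params.nthread = 8
params.ftype = llama_cpp.LLAMA_FTYPE_MOSTLY_Q4_K_M

# prune_layers is now a plain int32 array instead of a pointer to a
# C++ std::vector; keep the array alive until the call returns.
layers = (ctypes.c_int32 * 2)(30, 31)
params.prune_layers = ctypes.cast(layers, ctypes.POINTER(ctypes.c_int32))

# Placeholder paths; llama_model_quantize returns 0 on success.
rc = llama_cpp.llama_model_quantize(
    b"model-f16.gguf",
    b"model-q4_k_m.gguf",
    ctypes.byref(params),
)
```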