 LLAMA_ATTENTION_TYPE_NON_CAUSAL = 1


+# enum llama_flash_attn_type {
+#     LLAMA_FLASH_ATTN_TYPE_AUTO = -1,
+#     LLAMA_FLASH_ATTN_TYPE_DISABLED = 0,
+#     LLAMA_FLASH_ATTN_TYPE_ENABLED = 1,
+# };
+LLAMA_FLASH_ATTN_TYPE_AUTO = -1
+LLAMA_FLASH_ATTN_TYPE_DISABLED = 0
+LLAMA_FLASH_ATTN_TYPE_ENABLED = 1
+
+
 # enum llama_split_mode {
 #     LLAMA_SPLIT_MODE_NONE = 0, // single GPU
 #     LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs
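The new module-level constants mirror the C enum one-to-one: AUTO (-1) presumably defers the flash-attention decision to llama.cpp, while DISABLED (0) and ENABLED (1) force it off or on. A minimal usage sketch, assuming the bindings are importable as `llama_cpp` and that `llama_context_default_params` is exposed as in the existing low-level API (both are assumptions, not part of this diff):

import llama_cpp

# Start from the library defaults, then pick a flash-attention policy.
params = llama_cpp.llama_context_default_params()
params.flash_attn_type = llama_cpp.LLAMA_FLASH_ATTN_TYPE_AUTO  # defer to llama.cpp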
@@ -761,6 +771,7 @@ class llama_model_params(ctypes.Structure):
 #     enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
 #     enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id
 #     enum llama_attention_type attention_type; // attention type to use for embeddings
+#     enum llama_flash_attn_type flash_attn_type; // when to enable Flash Attention

 #     // ref: https://github.com/ggml-org/llama.cpp/pull/2054
 #     float rope_freq_base; // RoPE base frequency, 0 = from model
@@ -770,7 +781,7 @@ class llama_model_params(ctypes.Structure):
 #     float yarn_beta_fast; // YaRN low correction dim
 #     float yarn_beta_slow; // YaRN high correction dim
 #     uint32_t yarn_orig_ctx; // YaRN original context size
-#     float defrag_thold; // defragment the KV cache if holes/size > thold, <= 0 disabled (default)
+#     float defrag_thold; // [DEPRECATED] defragment the KV cache if holes/size > thold, <= 0 disabled (default)

 #     ggml_backend_sched_eval_callback cb_eval;
 #     void * cb_eval_user_data;
@@ -787,15 +798,14 @@ class llama_model_params(ctypes.Structure):
 #     // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
 #     bool embeddings; // if true, extract embeddings (together with logits)
 #     bool offload_kqv; // offload the KQV ops (including the KV cache) to GPU
-#     bool flash_attn; // use flash attention [EXPERIMENTAL]
 #     bool no_perf; // measure performance timings
 #     bool op_offload; // offload host tensor operations to device
-#     bool swa_full; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
-#                    // NOTE: setting to false when n_seq_max > 1 can cause bad performance in some cases
-#                    // ref: https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573
+#     bool swa_full; // use full-size SWA cache
 #     bool kv_unified; // use a unified buffer across the input sequences when computing the attention
-#                    // try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
-#                    // ref: https://github.com/ggml-org/llama.cpp/pull/14363
+
+#     // [EXPERIMENTAL]
+#     struct llama_sampler_seq_config * samplers;
+#     size_t n_samplers;
 # };
 class llama_context_params(ctypes.Structure):
     """Parameters for llama_context
@@ -810,6 +820,7 @@ class llama_context_params(ctypes.Structure):
         rope_scaling_type (int): RoPE scaling type, from `enum llama_rope_scaling_type`
         pooling_type (int): whether to pool (sum) embedding results by sequence id (ignored if no pooling layer)
         attention_type (int): attention type to use for embeddings
+        flash_attn_type (int): when to enable Flash Attention, from `enum llama_flash_attn_type`
         rope_freq_base (float): RoPE base frequency, 0 = from model
         rope_freq_scale (float): RoPE frequency scaling factor, 0 = from model
         yarn_ext_factor (float): YaRN extrapolation mix factor, negative = from model
@@ -826,11 +837,12 @@ class llama_context_params(ctypes.Structure):
         abort_callback_data (ctypes.c_void_p): data for abort_callback
         embeddings (bool): if true, extract embeddings (together with logits)
         offload_kqv (bool): whether to offload the KQV ops (including the KV cache) to GPU
-        flash_attn (bool): whether to use flash attention
         no_perf (bool): whether to measure performance timings
         op_offload (bool): offload host tensor operations to device
         swa_full (bool): use full-size SWA cache
         kv_unified (bool): use a unified buffer across the input sequences when computing the attention
+        samplers (ctypes.c_void_p): backend sampler chain configuration [EXPERIMENTAL]
+        n_samplers (ctypes.c_size_t): number of backend sampler chains
     """

     if TYPE_CHECKING:
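Because the boolean `flash_attn` field is removed in the same change, callers that used to toggle it have to migrate to the enum-valued field. A hedged before/after sketch (field and constant names are taken from this diff; the default-params call is assumed from the existing bindings):

import llama_cpp

params = llama_cpp.llama_context_default_params()
# Previously: params.flash_attn = True  (this field no longer exists)
# Now: request flash attention explicitly via the enum-valued field.
params.flash_attn_type = llama_cpp.LLAMA_FLASH_ATTN_TYPE_ENABLED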
@@ -843,6 +855,7 @@ class llama_context_params(ctypes.Structure):
         rope_scaling_type: int
         pooling_type: int
         attention_type: int
+        flash_attn_type: int
         rope_freq_base: float
         rope_freq_scale: float
         yarn_ext_factor: float
@@ -859,11 +872,12 @@ class llama_context_params(ctypes.Structure):
         abort_callback_data: ctypes.c_void_p
         embeddings: bool
         offload_kqv: bool
-        flash_attn: bool
         no_perf: bool
         op_offload: bool
         swa_full: bool
         kv_unified: bool
+        samplers: ctypes.c_void_p
+        n_samplers: ctypes.c_size_t

     _fields_ = [
         ("n_ctx", ctypes.c_uint32),
@@ -875,6 +889,7 @@ class llama_context_params(ctypes.Structure):
875889 ("rope_scaling_type" , ctypes .c_int ),
876890 ("pooling_type" , ctypes .c_int ),
877891 ("attention_type" , ctypes .c_int ),
892+ ("flash_attn_type" , ctypes .c_int ),
878893 ("rope_freq_base" , ctypes .c_float ),
879894 ("rope_freq_scale" , ctypes .c_float ),
880895 ("yarn_ext_factor" , ctypes .c_float ),
@@ -891,11 +906,12 @@ class llama_context_params(ctypes.Structure):
891906 ("abort_callback_data" , ctypes .c_void_p ),
892907 ("embeddings" , ctypes .c_bool ),
893908 ("offload_kqv" , ctypes .c_bool ),
894- ("flash_attn" , ctypes .c_bool ),
895909 ("no_perf" , ctypes .c_bool ),
896910 ("op_offload" , ctypes .c_bool ),
897911 ("swa_full" , ctypes .c_bool ),
898912 ("kv_unified" , ctypes .c_bool ),
913+ ("samplers" , ctypes .c_void_p ),
914+ ("n_samplers" , ctypes .c_size_t ),
899915 ]
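Since ctypes structures are copied by value across the FFI boundary, the field order above must match the C struct exactly; the new `samplers`/`n_samplers` members are appended after the trailing booleans, matching the commented C struct. A self-contained sketch that instantiates the struct directly and exercises the new members (illustration only; no claim about the library's actual defaults):

import ctypes

# ctypes zero-initializes a bare instance: ints/floats are 0, pointers NULL.
params = llama_context_params()
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED

# The experimental backend-sampler fields read back as NULL/0 here:
assert params.samplers is None  # a NULL c_void_p field reads back as None
assert params.n_samplers == 0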