332 changes: 250 additions & 82 deletions conditioner.hpp

Large diffs are not rendered by default.

70 changes: 52 additions & 18 deletions flux.hpp
@@ -88,14 +88,19 @@ namespace Flux {

public:
SelfAttention(int64_t dim,
int64_t num_heads = 8,
bool qkv_bias = false,
bool proj_bias = true)
int64_t num_heads = 8,
bool qkv_bias = false,
bool proj_bias = true,
bool diffusers_style = false)
: num_heads(num_heads) {
int64_t head_dim = dim / num_heads;
blocks["qkv"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim * 3, qkv_bias));
blocks["norm"] = std::shared_ptr<GGMLBlock>(new QKNorm(head_dim));
blocks["proj"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim, proj_bias));
if (diffusers_style) {
blocks["qkv"] = std::shared_ptr<GGMLBlock>(new SplitLinear(dim, {dim, dim, dim}, qkv_bias));
} else {
blocks["qkv"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim * 3, qkv_bias));
}
blocks["norm"] = std::shared_ptr<GGMLBlock>(new QKNorm(head_dim));
blocks["proj"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim, proj_bias));
}

std::vector<struct ggml_tensor*> pre_attention(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
@@ -261,15 +266,16 @@ namespace Flux {
bool share_modulation = false,
bool mlp_proj_bias = true,
bool use_yak_mlp = false,
bool use_mlp_silu_act = false)
bool use_mlp_silu_act = false,
bool diffusers_style = false)
: idx(idx), prune_mod(prune_mod) {
int64_t mlp_hidden_dim = static_cast<int64_t>(hidden_size * mlp_ratio);

if (!prune_mod && !share_modulation) {
blocks["img_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true));
}
blocks["img_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
blocks["img_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias));
blocks["img_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias, diffusers_style));

blocks["img_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
if (use_yak_mlp) {
@@ -282,7 +288,7 @@ namespace Flux {
blocks["txt_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true));
}
blocks["txt_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
blocks["txt_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias));
blocks["txt_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias, diffusers_style));

blocks["txt_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
if (use_yak_mlp) {
@@ -424,6 +430,7 @@ namespace Flux {
bool use_yak_mlp;
bool use_mlp_silu_act;
int64_t mlp_mult_factor;
bool diffusers_style = false;

public:
SingleStreamBlock(int64_t hidden_size,
@@ -435,7 +442,8 @@ namespace Flux {
bool share_modulation = false,
bool mlp_proj_bias = true,
bool use_yak_mlp = false,
bool use_mlp_silu_act = false)
bool use_mlp_silu_act = false,
bool diffusers_style = false)
: hidden_size(hidden_size), num_heads(num_heads), idx(idx), prune_mod(prune_mod), use_yak_mlp(use_yak_mlp), use_mlp_silu_act(use_mlp_silu_act) {
int64_t head_dim = hidden_size / num_heads;
float scale = qk_scale;
@@ -447,8 +455,11 @@ namespace Flux {
if (use_yak_mlp || use_mlp_silu_act) {
mlp_mult_factor = 2;
}

blocks["linear1"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim * mlp_mult_factor, mlp_proj_bias));
if (diffusers_style) {
blocks["linear1"] = std::shared_ptr<GGMLBlock>(new SplitLinear(hidden_size, {hidden_size, hidden_size, hidden_size, mlp_hidden_dim * mlp_mult_factor}, mlp_proj_bias));
} else {
blocks["linear1"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim * mlp_mult_factor, mlp_proj_bias));
}
blocks["linear2"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size + mlp_hidden_dim, hidden_size, mlp_proj_bias));
blocks["norm"] = std::shared_ptr<GGMLBlock>(new QKNorm(head_dim));
blocks["pre_norm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
@@ -776,6 +787,7 @@ namespace Flux {
bool use_yak_mlp = false;
bool use_mlp_silu_act = false;
float ref_index_scale = 1.f;
bool diffusers_style = false;
ChromaRadianceParams chroma_radiance_params;
};

@@ -822,7 +834,8 @@ namespace Flux {
params.share_modulation,
!params.disable_bias,
params.use_yak_mlp,
params.use_mlp_silu_act);
params.use_mlp_silu_act,
params.diffusers_style);
}

for (int i = 0; i < params.depth_single_blocks; i++) {
@@ -835,7 +848,8 @@ namespace Flux {
params.share_modulation,
!params.disable_bias,
params.use_yak_mlp,
params.use_mlp_silu_act);
params.use_mlp_silu_act,
params.diffusers_style);
}

if (params.version == VERSION_CHROMA_RADIANCE) {
@@ -882,6 +896,11 @@ namespace Flux {
int64_t C = x->ne[2];
int64_t H = x->ne[1];
int64_t W = x->ne[0];
if (params.patch_size == 1) {
x = ggml_reshape_3d(ctx, x, H * W, C, N); // [N, C, H*W]
x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [N, H*W, C]
return x;
}
int64_t p = params.patch_size;
int64_t h = H / params.patch_size;
int64_t w = W / params.patch_size;
@@ -916,6 +935,12 @@ namespace Flux {
int64_t W = w * params.patch_size;
int64_t p = params.patch_size;

if (params.patch_size == 1) {
x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [N, C, H*W]
x = ggml_reshape_4d(ctx, x, W, H, C, N); // [N, C, H, W]
return x;
}

GGML_ASSERT(C * p * p == x->ne[0]);

x = ggml_reshape_4d(ctx, x, p * p, C, w * h, N); // [N, h*w, C, p*p]
@@ -1302,6 +1327,9 @@ namespace Flux {
flux_params.share_modulation = true;
flux_params.ref_index_scale = 10.f;
flux_params.use_mlp_silu_act = true;
} else if (sd_version_is_longcat(version)) {
flux_params.context_in_dim = 3584;
flux_params.vec_in_dim = 0;
}
int64_t head_dim = 0;
for (auto pair : tensor_storage_map) {
@@ -1311,6 +1339,9 @@ namespace Flux {
if (tensor_name.find("guidance_in.in_layer.weight") != std::string::npos) {
flux_params.guidance_embed = true;
}
if (tensor_name.find("model.diffusion_model.single_blocks.0.linear1.weight.1") != std::string::npos) {
flux_params.diffusers_style = true;
}
if (tensor_name.find("__x0__") != std::string::npos) {
LOG_DEBUG("using x0 prediction");
flux_params.chroma_radiance_params.use_x0 = true;
@@ -1366,6 +1397,10 @@ namespace Flux {
LOG_INFO("Using pruned modulation (Chroma)");
}

if (flux_params.diffusers_style) {
LOG_INFO("Using diffusers-style attention blocks");
}

flux = Flux(flux_params);
flux.init(params_ctx, tensor_storage_map, prefix);
}
@@ -1477,7 +1512,6 @@ namespace Flux {
} else if (version == VERSION_OVIS_IMAGE) {
txt_arange_dims = {1, 2};
}

pe_vec = Rope::gen_flux_pe(static_cast<int>(x->ne[1]),
static_cast<int>(x->ne[0]),
flux_params.patch_size,
@@ -1490,10 +1524,10 @@ namespace Flux {
flux_params.theta,
circular_y_enabled,
circular_x_enabled,
flux_params.axes_dim);
flux_params.axes_dim,
sd_version_is_longcat(version));
int pos_len = static_cast<int>(pe_vec.size() / flux_params.axes_dim_sum / 2);
// LOG_DEBUG("pos_len %d", pos_len);
auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, flux_params.axes_dim_sum / 2, pos_len);
auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, flux_params.axes_dim_sum / 2, pos_len);
// pe->data = pe_vec.data();
// print_ggml_tensor(pe);
// pe->data = nullptr;
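
Note on the new patch_size == 1 fast paths in pack/unpack above: with a patch size of 1 every pixel is its own token, so patchify collapses to a flatten plus transpose from [N, C, H, W] to [N, H*W, C] and unpatchify is the inverse, which is all the added reshape/permute pair does. The following standalone sketch of the general index math (plain C++ on a dense NCHW buffer, illustrative only and not the ggml implementation; the inner token ordering is one consistent convention, not necessarily the library's) shows why the p == 1 case needs no rearrangement:

#include <cstdint>
#include <vector>

// Illustrative patchify: [N, C, H, W] -> [N, (H/p)*(W/p), C*p*p].
// For p == 1, patch == y * W + x and inner == c, i.e. a plain [N, H*W, C]
// transpose -- exactly what the reshape + permute fast path produces.
std::vector<float> patchify(const std::vector<float>& in,
                            int64_t N, int64_t C, int64_t H, int64_t W, int64_t p) {
    const int64_t h = H / p, w = W / p;
    std::vector<float> out(static_cast<size_t>(N * h * w * C * p * p));
    for (int64_t n = 0; n < N; n++)
        for (int64_t c = 0; c < C; c++)
            for (int64_t y = 0; y < H; y++)
                for (int64_t x = 0; x < W; x++) {
                    const int64_t patch = (y / p) * w + (x / p);              // token index
                    const int64_t inner = c * p * p + (y % p) * p + (x % p);  // offset inside the token
                    out[(n * h * w + patch) * C * p * p + inner] =
                        in[((n * C + c) * H + y) * W + x];
                }
    return out;
}
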
79 changes: 78 additions & 1 deletion ggml_extend.hpp
@@ -2184,7 +2184,7 @@ class Linear : public UnaryBlock {
bool bias = true,
bool force_f32 = false,
bool force_prec_f32 = false,
float scale = 1.f)
float scale = 1.f / 128.f)
: in_features(in_features),
out_features(out_features),
bias(bias),
@@ -2209,6 +2209,83 @@
}
};

class SplitLinear : public Linear {
protected:
int64_t in_features;
std::vector<int64_t> out_features_vec;
bool bias;
bool force_f32;
bool force_prec_f32;
float scale;
std::string prefix;

void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
this->prefix = prefix;
enum ggml_type wtype = get_type(prefix + "weight", tensor_storage_map, GGML_TYPE_F32);
if (in_features % ggml_blck_size(wtype) != 0 || force_f32) {
wtype = GGML_TYPE_F32;
}
params["weight"] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features_vec[0]);
for (int i = 1; i < out_features_vec.size(); i++) {
// most likely same type as the first weight
params["weight." + std::to_string(i)] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features_vec[i]);
}
if (bias) {
enum ggml_type wtype = GGML_TYPE_F32;
params["bias"] = ggml_new_tensor_1d(ctx, wtype, out_features_vec[0]);
for (int i = 1; i < out_features_vec.size(); i++) {
params["bias." + std::to_string(i)] = ggml_new_tensor_1d(ctx, wtype, out_features_vec[i]);
}
}
}

public:
SplitLinear(int64_t in_features,
std::vector<int64_t> out_features_vec,
bool bias = true,
bool force_f32 = false,
bool force_prec_f32 = false,
float scale = 1.f)
: Linear(in_features, out_features_vec[0], bias, force_f32, force_prec_f32, scale),
in_features(in_features),
out_features_vec(out_features_vec),
bias(bias),
force_f32(force_f32),
force_prec_f32(force_prec_f32),
scale(scale) {}

struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
struct ggml_tensor* w = params["weight"];
struct ggml_tensor* b = nullptr;
if (bias) {
b = params["bias"];
}
if (ctx->weight_adapter) {
// concat all weights and biases together so it runs in one linear layer
for (int i = 1; i < out_features_vec.size(); i++) {
w = ggml_concat(ctx->ggml_ctx, w, params["weight." + std::to_string(i)], 1);
if (bias) {
b = ggml_concat(ctx->ggml_ctx, b, params["bias." + std::to_string(i)], 0);
}
}
WeightAdapter::ForwardParams forward_params;
forward_params.op_type = WeightAdapter::ForwardParams::op_type_t::OP_LINEAR;
forward_params.linear.force_prec_f32 = force_prec_f32;
forward_params.linear.scale = scale;
return ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, x, w, b, prefix, forward_params);
}
auto out = ggml_ext_linear(ctx->ggml_ctx, x, w, b, force_prec_f32, scale);
for (int i = 1; i < out_features_vec.size(); i++) {
auto wi = params["weight." + std::to_string(i)];
auto bi = bias ? params["bias." + std::to_string(i)] : nullptr;
auto curr_out = ggml_ext_linear(ctx->ggml_ctx, x, wi, bi, force_prec_f32, scale);
out = ggml_concat(ctx->ggml_ctx, out, curr_out, 0);
}

return out;
}
};

__STATIC_INLINE__ bool support_get_rows(ggml_type wtype) {
std::set<ggml_type> allow_types = {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0};
if (allow_types.find(wtype) != allow_types.end()) {
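
The SplitLinear block above keeps each output slice of a linear projection as its own weight (and bias) tensor, matching diffusers-style checkpoints that store q/k/v unfused. Its forward either applies the per-split weights one by one and concatenates the results along the feature dimension, or, when a weight adapter (LoRA) is attached, concatenates the weights first so the adapter sees a single fused linear. Both routes rely on the fact that splitting a weight matrix by output rows and concatenating the partial results reproduces the fused y = Wx + b. A self-contained sketch of that equivalence (plain C++, illustrative only, not the ggml API):

#include <cassert>
#include <vector>

// y = W x + b for a row-major [out_dim, in_dim] weight.
std::vector<float> linear(const std::vector<float>& W, const std::vector<float>& b,
                          const std::vector<float>& x, size_t out_dim, size_t in_dim) {
    std::vector<float> y(out_dim, 0.f);
    for (size_t o = 0; o < out_dim; o++) {
        float acc = b.empty() ? 0.f : b[o];
        for (size_t i = 0; i < in_dim; i++)
            acc += W[o * in_dim + i] * x[i];
        y[o] = acc;
    }
    return y;
}

int main() {
    const size_t in_dim = 4;
    const std::vector<float> x = {1.f, 2.f, 3.f, 4.f};
    const std::vector<float> W = {1, 0, 0, 0,   // fused weight, 3 output rows
                                  0, 1, 0, 0,
                                  0, 0, 1, 1};
    const std::vector<float> b = {0.5f, -0.5f, 0.f};
    const auto fused = linear(W, b, x, 3, in_dim);

    // Split into {2, 1} output rows, run each block separately, then concatenate.
    const std::vector<float> W0(W.begin(), W.begin() + 2 * in_dim), W1(W.begin() + 2 * in_dim, W.end());
    const std::vector<float> b0(b.begin(), b.begin() + 2), b1(b.begin() + 2, b.end());
    auto cat = linear(W0, b0, x, 2, in_dim);
    const auto y1 = linear(W1, b1, x, 1, in_dim);
    cat.insert(cat.end(), y1.begin(), y1.end());

    for (size_t i = 0; i < fused.size(); i++)
        assert(fused[i] == cat[i]);  // split-and-concat matches the fused projection
    return 0;
}
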
1 change: 1 addition & 0 deletions llm.hpp
@@ -1186,6 +1186,7 @@ namespace LLM {
struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);

input_ids = to_backend(input_ids);
attention_mask = to_backend(attention_mask);

for (auto& image_embed : image_embeds) {
image_embed.second = to_backend(image_embed.second);
29 changes: 18 additions & 11 deletions model.cpp
@@ -1027,7 +1027,7 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s
}

SDVersion ModelLoader::get_sd_version() {
TensorStorage token_embedding_weight, input_block_weight;
TensorStorage token_embedding_weight, input_block_weight, context_embedding_weight;

bool has_multiple_encoders = false;
bool is_unet = false;
Expand All @@ -1044,7 +1044,7 @@ SDVersion ModelLoader::get_sd_version() {

for (auto& [name, tensor_storage] : tensor_storage_map) {
if (!(is_xl)) {
if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) {
if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos || tensor_storage.name.find("model.diffusion_model.single_transformer_blocks.") != std::string::npos) {
is_flux = true;
}
if (tensor_storage.name.find("model.diffusion_model.nerf_final_layer_conv.") != std::string::npos) {
@@ -1117,6 +1117,9 @@ SDVersion ModelLoader::get_sd_version() {
tensor_storage.name == "unet.conv_in.weight") {
input_block_weight = tensor_storage;
}
if (tensor_storage.name == "model.diffusion_model.txt_in.weight" || tensor_storage.name == "model.diffusion_model.context_embedder.weight") {
context_embedding_weight = tensor_storage;
}
}
if (is_wan) {
LOG_DEBUG("patch_embedding_channels %d", patch_embedding_channels);
@@ -1144,16 +1147,20 @@
}

if (is_flux && !is_flux2) {
if (input_block_weight.ne[0] == 384) {
return VERSION_FLUX_FILL;
}
if (input_block_weight.ne[0] == 128) {
return VERSION_FLUX_CONTROLS;
}
if (input_block_weight.ne[0] == 196) {
return VERSION_FLEX_2;
if (context_embedding_weight.ne[0] == 3584) {
return VERSION_LONGCAT;
} else {
if (input_block_weight.ne[0] == 384) {
return VERSION_FLUX_FILL;
}
if (input_block_weight.ne[0] == 128) {
return VERSION_FLUX_CONTROLS;
}
if (input_block_weight.ne[0] == 196) {
return VERSION_FLEX_2;
}
return VERSION_FLUX;
}
return VERSION_FLUX;
}

if (is_flux2) {
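
The LongCat branch added above is keyed on the context embedder rather than on img_in: when the txt_in / context_embedder weight has an input width of 3584, the checkpoint is classified as VERSION_LONGCAT before the usual input-block-width checks for VERSION_FLUX_FILL, VERSION_FLUX_CONTROLS, and VERSION_FLEX_2. A compressed sketch of that ordering (hypothetical helper with illustrative names, not the real ModelLoader API):

#include <cstdint>

enum class FluxVariant { Flux, FluxFill, FluxControls, Flex2, LongCat };

// Detection order matters: the 3584-wide context embedder wins before any
// img_in-width check is consulted.
FluxVariant classify_flux(int64_t context_in_dim, int64_t img_in_dim) {
    if (context_in_dim == 3584) return FluxVariant::LongCat;       // VERSION_LONGCAT
    if (img_in_dim == 384)      return FluxVariant::FluxFill;      // VERSION_FLUX_FILL
    if (img_in_dim == 128)      return FluxVariant::FluxControls;  // VERSION_FLUX_CONTROLS
    if (img_in_dim == 196)      return FluxVariant::Flex2;         // VERSION_FLEX_2
    return FluxVariant::Flux;                                      // VERSION_FLUX
}
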
11 changes: 10 additions & 1 deletion model.h
@@ -48,6 +48,7 @@ enum SDVersion {
VERSION_FLUX2_KLEIN,
VERSION_Z_IMAGE,
VERSION_OVIS_IMAGE,
VERSION_LONGCAT,
VERSION_COUNT,
};

@@ -128,6 +129,13 @@ static inline bool sd_version_is_z_image(SDVersion version) {
return false;
}

static inline bool sd_version_is_longcat(SDVersion version) {
if (version == VERSION_LONGCAT) {
return true;
}
return false;
}

static inline bool sd_version_is_inpaint(SDVersion version) {
if (version == VERSION_SD1_INPAINT ||
version == VERSION_SD2_INPAINT ||
@@ -145,7 +153,8 @@ static inline bool sd_version_is_dit(SDVersion version) {
sd_version_is_sd3(version) ||
sd_version_is_wan(version) ||
sd_version_is_qwen_image(version) ||
sd_version_is_z_image(version)) {
sd_version_is_z_image(version) ||
sd_version_is_longcat(version)) {
return true;
}
return false;