Skip to content

Commit 26a9b02

Browse files
authored
Update ggml_extend.hpp
1 parent 713e721 commit 26a9b02

File tree

1 file changed

+32
-324
lines changed

1 file changed

+32
-324
lines changed

ggml_extend.hpp

Lines changed: 32 additions & 324 deletions
Original file line numberDiff line numberDiff line change
@@ -1277,11 +1277,38 @@ class Linear : public UnaryBlock {
12771277
enum ggml_type wtype = tensor_types[prefix + "A"];
12781278

12791279
uint64_t rank = 64;
1280-
auto entryName = prefix.substr(23, 100);
1281-
if (sizes.find(entryName) != sizes.end()) {
1282-
rank = sizes[entryName];
1283-
} else {
1284-
LOG_INFO("not found: %s", entryName);
1280+
if (prefix.find( "proj") != std::string::npos) {
1281+
rank = 1024;
1282+
} else if (prefix.find(prefix + "qkv") != std::string::npos) {
1283+
rank = 1280;
1284+
} else if (prefix.find(prefix + "mlp.2") != std::string::npos) {
1285+
rank = 1408;
1286+
} else if (prefix.find(prefix + "mlp.0") != std::string::npos) {
1287+
rank = 1344;
1288+
} else if (prefix.find(prefix + "modulation.lin") != std::string::npos) {
1289+
rank = 1280;
1290+
} else if (prefix.find(prefix + "_mod.lin.") != std::string::npos) {
1291+
rank = 1344;
1292+
} else if (prefix.find(prefix + "linear1") != std::string::npos) {
1293+
rank = 1600;
1294+
} else if (prefix.find(prefix + "linear2") != std::string::npos) {
1295+
rank = 1600;
1296+
} else if (prefix.find(prefix + "time_in.in_layer") != std::string::npos) {
1297+
rank = 128;
1298+
} else if (prefix.find(prefix + "time_in.out_layer") != std::string::npos) {
1299+
rank = 768;
1300+
} else if (prefix.find(prefix + "txt_in") != std::string::npos) {
1301+
rank = 960;
1302+
} else if (prefix.find(prefix + "vector_in.in_layer") != std::string::npos) {
1303+
rank = 256;
1304+
} else if (prefix.find(prefix + "vector_in.in_layer") != std::string::npos) {
1305+
rank = 576;
1306+
} else if (prefix.find(prefix + "final_layer.adaLN_modulation") != std::string::npos) {
1307+
rank = 1088;
1308+
} else if (prefix.find(prefix + "guidance_in.in_layer") != std::string::npos) {
1309+
rank = 192;
1310+
} else if (prefix.find(prefix + "guidance_in.out_layer") != std::string::npos) {
1311+
rank = 1280;
12851312
}
12861313

12871314
params["A"] = ggml_new_tensor_2d(ctx, wtype, in_features, rank);
@@ -1300,8 +1327,6 @@ class Linear : public UnaryBlock {
13001327
}
13011328
}
13021329

1303-
static std::map<std::string, uint64_t> sizes;
1304-
13051330
public:
13061331
Linear(int64_t in_features,
13071332
int64_t out_features,
@@ -1332,323 +1357,6 @@ class Linear : public UnaryBlock {
13321357
}
13331358
};
13341359

1335-
std::map<std::string, uint64_t> Linear::sizes = {
1336-
{"double_blocks.0.img_attn.proj.", 960},
1337-
{"double_blocks.0.img_attn.qkv.", 1280},
1338-
{"double_blocks.0.img_mlp.0.", 1472},
1339-
{"double_blocks.0.img_mlp.2.", 1216},
1340-
{"double_blocks.0.img_mod.lin.", 1344},
1341-
{"double_blocks.0.txt_attn.proj.", 960},
1342-
{"double_blocks.0.txt_attn.qkv.", 1088},
1343-
{"double_blocks.0.txt_mlp.0.", 1216},
1344-
{"double_blocks.0.txt_mlp.2.", 1472},
1345-
{"double_blocks.0.txt_mod.lin.", 1344},
1346-
{"double_blocks.1.img_attn.proj.", 896},
1347-
{"double_blocks.1.img_attn.qkv.", 1152},
1348-
{"double_blocks.1.img_mlp.0.", 1408},
1349-
{"double_blocks.1.img_mlp.2.", 1216},
1350-
{"double_blocks.1.img_mod.lin.", 1344},
1351-
{"double_blocks.1.txt_attn.proj.", 896},
1352-
{"double_blocks.1.txt_attn.qkv.", 1152},
1353-
{"double_blocks.1.txt_mlp.0.", 1280},
1354-
{"double_blocks.1.txt_mlp.2.", 1472},
1355-
{"double_blocks.1.txt_mod.lin.", 1280},
1356-
{"double_blocks.10.img_attn.proj.", 1024},
1357-
{"double_blocks.10.img_attn.qkv.", 1280},
1358-
{"double_blocks.10.img_mlp.0.", 1344},
1359-
{"double_blocks.10.img_mlp.2.", 1472},
1360-
{"double_blocks.10.img_mod.lin.", 1344},
1361-
{"double_blocks.10.txt_attn.proj.", 960},
1362-
{"double_blocks.10.txt_attn.qkv.", 1280},
1363-
{"double_blocks.10.txt_mlp.0.", 1344},
1364-
{"double_blocks.10.txt_mlp.2.", 1472},
1365-
{"double_blocks.10.txt_mod.lin.", 1344},
1366-
{"double_blocks.11.img_attn.proj.", 1024},
1367-
{"double_blocks.11.img_attn.qkv.", 1280},
1368-
{"double_blocks.11.img_mlp.0.", 1408},
1369-
{"double_blocks.11.img_mlp.2.", 1472},
1370-
{"double_blocks.11.img_mod.lin.", 1344},
1371-
{"double_blocks.11.txt_attn.proj.", 960},
1372-
{"double_blocks.11.txt_attn.qkv.", 1280},
1373-
{"double_blocks.11.txt_mlp.0.", 1344},
1374-
{"double_blocks.11.txt_mlp.2.", 1472},
1375-
{"double_blocks.11.txt_mod.lin.", 1344},
1376-
{"double_blocks.12.img_attn.proj.", 1024},
1377-
{"double_blocks.12.img_attn.qkv.", 1344},
1378-
{"double_blocks.12.img_mlp.0.", 1408},
1379-
{"double_blocks.12.img_mlp.2.", 1472},
1380-
{"double_blocks.12.img_mod.lin.", 1344},
1381-
{"double_blocks.12.txt_attn.proj.", 1024},
1382-
{"double_blocks.12.txt_attn.qkv.", 1344},
1383-
{"double_blocks.12.txt_mlp.0.", 1472},
1384-
{"double_blocks.12.txt_mlp.2.", 1600},
1385-
{"double_blocks.12.txt_mod.lin.", 1344},
1386-
{"double_blocks.13.img_attn.proj.", 1024},
1387-
{"double_blocks.13.img_attn.qkv.", 1280},
1388-
{"double_blocks.13.img_mlp.0.", 1408},
1389-
{"double_blocks.13.img_mlp.2.", 1472},
1390-
{"double_blocks.13.img_mod.lin.", 1344},
1391-
{"double_blocks.13.txt_attn.proj.", 1024},
1392-
{"double_blocks.13.txt_attn.qkv.", 1280},
1393-
{"double_blocks.13.txt_mlp.0.", 1408},
1394-
{"double_blocks.13.txt_mlp.2.", 1536},
1395-
{"double_blocks.13.txt_mod.lin.", 1344},
1396-
{"double_blocks.14.img_attn.proj.", 1024},
1397-
{"double_blocks.14.img_attn.qkv.", 1344},
1398-
{"double_blocks.14.img_mlp.0.", 1408},
1399-
{"double_blocks.14.img_mlp.2.", 1472},
1400-
{"double_blocks.14.img_mod.lin.", 1344},
1401-
{"double_blocks.14.txt_attn.proj.", 1024},
1402-
{"double_blocks.14.txt_attn.qkv.", 1280},
1403-
{"double_blocks.14.txt_mlp.0.", 1472},
1404-
{"double_blocks.14.txt_mlp.2.", 1536},
1405-
{"double_blocks.14.txt_mod.lin.", 1344},
1406-
{"double_blocks.15.img_attn.proj.", 1024},
1407-
{"double_blocks.15.img_attn.qkv.", 1344},
1408-
{"double_blocks.15.img_mlp.0.", 1408},
1409-
{"double_blocks.15.img_mlp.2.", 1472},
1410-
{"double_blocks.15.img_mod.lin.", 1344},
1411-
{"double_blocks.15.txt_attn.proj.", 1024},
1412-
{"double_blocks.15.txt_attn.qkv.", 1280},
1413-
{"double_blocks.15.txt_mlp.0.", 1408},
1414-
{"double_blocks.15.txt_mlp.2.", 1472},
1415-
{"double_blocks.15.txt_mod.lin.", 1344},
1416-
{"double_blocks.16.img_attn.proj.", 1088},
1417-
{"double_blocks.16.img_attn.qkv.", 1344},
1418-
{"double_blocks.16.img_mlp.0.", 1408},
1419-
{"double_blocks.16.img_mlp.2.", 1472},
1420-
{"double_blocks.16.img_mod.lin.", 1344},
1421-
{"double_blocks.16.txt_attn.proj.", 1024},
1422-
{"double_blocks.16.txt_attn.qkv.", 1344},
1423-
{"double_blocks.16.txt_mlp.0.", 1344},
1424-
{"double_blocks.16.txt_mlp.2.", 1472},
1425-
{"double_blocks.16.txt_mod.lin.", 1344},
1426-
{"double_blocks.17.img_attn.proj.", 1088},
1427-
{"double_blocks.17.img_attn.qkv.", 1344},
1428-
{"double_blocks.17.img_mlp.0.", 1408},
1429-
{"double_blocks.17.img_mlp.2.", 1472},
1430-
{"double_blocks.17.img_mod.lin.", 1408},
1431-
{"double_blocks.17.txt_attn.proj.", 1024},
1432-
{"double_blocks.17.txt_attn.qkv.", 1344},
1433-
{"double_blocks.17.txt_mlp.0.", 1344},
1434-
{"double_blocks.17.txt_mlp.2.", 1472},
1435-
{"double_blocks.17.txt_mod.lin.", 1344},
1436-
{"double_blocks.18.img_attn.proj.", 1088},
1437-
{"double_blocks.18.img_attn.qkv.", 1344},
1438-
{"double_blocks.18.img_mlp.0.", 1472},
1439-
{"double_blocks.18.img_mlp.2.", 1536},
1440-
{"double_blocks.18.img_mod.lin.", 1408},
1441-
{"double_blocks.18.txt_attn.proj.", 1024},
1442-
{"double_blocks.18.txt_attn.qkv.", 1344},
1443-
{"double_blocks.18.txt_mlp.0.", 1344},
1444-
{"double_blocks.18.txt_mlp.2.", 1472},
1445-
{"double_blocks.18.txt_mod.lin.", 1344},
1446-
{"double_blocks.2.img_attn.proj.", 832},
1447-
{"double_blocks.2.img_attn.qkv.", 1088},
1448-
{"double_blocks.2.img_mlp.0.", 1216},
1449-
{"double_blocks.2.img_mlp.2.", 1216},
1450-
{"double_blocks.2.img_mod.lin.", 1344},
1451-
{"double_blocks.2.txt_attn.proj.", 896},
1452-
{"double_blocks.2.txt_attn.qkv.", 1152},
1453-
{"double_blocks.2.txt_mlp.0.", 1280},
1454-
{"double_blocks.2.txt_mlp.2.", 1408},
1455-
{"double_blocks.2.txt_mod.lin.", 1344},
1456-
{"double_blocks.3.img_attn.proj.", 896},
1457-
{"double_blocks.3.img_attn.qkv.", 1088},
1458-
{"double_blocks.3.img_mlp.0.", 1152},
1459-
{"double_blocks.3.img_mlp.2.", 1088},
1460-
{"double_blocks.3.img_mod.lin.", 1344},
1461-
{"double_blocks.3.txt_attn.proj.", 896},
1462-
{"double_blocks.3.txt_attn.qkv.", 1152},
1463-
{"double_blocks.3.txt_mlp.0.", 1280},
1464-
{"double_blocks.3.txt_mlp.2.", 1408},
1465-
{"double_blocks.3.txt_mod.lin.", 1344},
1466-
{"double_blocks.4.img_attn.proj.", 896},
1467-
{"double_blocks.4.img_attn.qkv.", 1088},
1468-
{"double_blocks.4.img_mlp.0.", 1152},
1469-
{"double_blocks.4.img_mlp.2.", 1472},
1470-
{"double_blocks.4.img_mod.lin.", 1344},
1471-
{"double_blocks.4.txt_attn.proj.", 896},
1472-
{"double_blocks.4.txt_attn.qkv.", 1216},
1473-
{"double_blocks.4.txt_mlp.0.", 1280},
1474-
{"double_blocks.4.txt_mlp.2.", 1408},
1475-
{"double_blocks.4.txt_mod.lin.", 1344},
1476-
{"double_blocks.5.img_attn.proj.", 960},
1477-
{"double_blocks.5.img_attn.qkv.", 1152},
1478-
{"double_blocks.5.img_mlp.0.", 1216},
1479-
{"double_blocks.5.img_mlp.2.", 1472},
1480-
{"double_blocks.5.img_mod.lin.", 1344},
1481-
{"double_blocks.5.txt_attn.proj.", 896},
1482-
{"double_blocks.5.txt_attn.qkv.", 1216},
1483-
{"double_blocks.5.txt_mlp.0.", 1280},
1484-
{"double_blocks.5.txt_mlp.2.", 1408},
1485-
{"double_blocks.5.txt_mod.lin.", 1344},
1486-
{"double_blocks.6.img_attn.proj.", 960},
1487-
{"double_blocks.6.img_attn.qkv.", 1216},
1488-
{"double_blocks.6.img_mlp.0.", 1280},
1489-
{"double_blocks.6.img_mlp.2.", 1472},
1490-
{"double_blocks.6.img_mod.lin.", 1344},
1491-
{"double_blocks.6.txt_attn.proj.", 960},
1492-
{"double_blocks.6.txt_attn.qkv.", 1216},
1493-
{"double_blocks.6.txt_mlp.0.", 1280},
1494-
{"double_blocks.6.txt_mlp.2.", 1408},
1495-
{"double_blocks.6.txt_mod.lin.", 1344},
1496-
{"double_blocks.7.img_attn.proj.", 960},
1497-
{"double_blocks.7.img_attn.qkv.", 1216},
1498-
{"double_blocks.7.img_mlp.0.", 1280},
1499-
{"double_blocks.7.img_mlp.2.", 1408},
1500-
{"double_blocks.7.img_mod.lin.", 1344},
1501-
{"double_blocks.7.txt_attn.proj.", 960},
1502-
{"double_blocks.7.txt_attn.qkv.", 1280},
1503-
{"double_blocks.7.txt_mlp.0.", 1280},
1504-
{"double_blocks.7.txt_mlp.2.", 1408},
1505-
{"double_blocks.7.txt_mod.lin.", 1344},
1506-
{"double_blocks.8.img_attn.proj.", 1024},
1507-
{"double_blocks.8.img_attn.qkv.", 1280},
1508-
{"double_blocks.8.img_mlp.0.", 1280},
1509-
{"double_blocks.8.img_mlp.2.", 1472},
1510-
{"double_blocks.8.img_mod.lin.", 1344},
1511-
{"double_blocks.8.txt_attn.proj.", 1024},
1512-
{"double_blocks.8.txt_attn.qkv.", 1280},
1513-
{"double_blocks.8.txt_mlp.0.", 1280},
1514-
{"double_blocks.8.txt_mlp.2.", 1408},
1515-
{"double_blocks.8.txt_mod.lin.", 1344},
1516-
{"double_blocks.9.img_attn.proj.", 1024},
1517-
{"double_blocks.9.img_attn.qkv.", 1216},
1518-
{"double_blocks.9.img_mlp.0.", 1344},
1519-
{"double_blocks.9.img_mlp.2.", 1472},
1520-
{"double_blocks.9.img_mod.lin.", 1344},
1521-
{"double_blocks.9.txt_attn.proj.", 960},
1522-
{"double_blocks.9.txt_attn.qkv.", 1280},
1523-
{"double_blocks.9.txt_mlp.0.", 1344},
1524-
{"double_blocks.9.txt_mlp.2.", 1408},
1525-
{"double_blocks.9.txt_mod.lin.", 1344},
1526-
{"final_layer.adaLN_modulation.1.", 1088},
1527-
{"final_layer.linear.", 64},
1528-
{"guidance_in.in_layer.", 192},
1529-
{"guidance_in.out_layer.", 1280},
1530-
{"img_in.", 64},
1531-
{"single_blocks.0.linear1.", 1600},
1532-
{"single_blocks.0.linear2.", 1600},
1533-
{"single_blocks.0.modulation.lin.", 1280},
1534-
{"single_blocks.1.linear1.", 1600},
1535-
{"single_blocks.1.linear2.", 1600},
1536-
{"single_blocks.1.modulation.lin.", 1280},
1537-
{"single_blocks.10.linear1.", 1600},
1538-
{"single_blocks.10.linear2.", 1600},
1539-
{"single_blocks.10.modulation.lin.", 1280},
1540-
{"single_blocks.11.linear1.", 1600},
1541-
{"single_blocks.11.linear2.", 1600},
1542-
{"single_blocks.11.modulation.lin.", 1280},
1543-
{"single_blocks.12.linear1.", 1664},
1544-
{"single_blocks.12.linear2.", 1536},
1545-
{"single_blocks.12.modulation.lin.", 1280},
1546-
{"single_blocks.13.linear1.", 1664},
1547-
{"single_blocks.13.linear2.", 1600},
1548-
{"single_blocks.13.modulation.lin.", 1280},
1549-
{"single_blocks.14.linear1.", 1664},
1550-
{"single_blocks.14.linear2.", 1600},
1551-
{"single_blocks.14.modulation.lin.", 1280},
1552-
{"single_blocks.15.linear1.", 1664},
1553-
{"single_blocks.15.linear2.", 1600},
1554-
{"single_blocks.15.modulation.lin.", 1280},
1555-
{"single_blocks.16.linear1.", 1600},
1556-
{"single_blocks.16.linear2.", 1600},
1557-
{"single_blocks.16.modulation.lin.", 1280},
1558-
{"single_blocks.17.linear1.", 1664},
1559-
{"single_blocks.17.linear2.", 1600},
1560-
{"single_blocks.17.modulation.lin.", 1280},
1561-
{"single_blocks.18.linear1.", 1600},
1562-
{"single_blocks.18.linear2.", 1600},
1563-
{"single_blocks.18.modulation.lin.", 1280},
1564-
{"single_blocks.19.linear1.", 1600},
1565-
{"single_blocks.19.linear2.", 1600},
1566-
{"single_blocks.19.modulation.lin.", 1280},
1567-
{"single_blocks.2.linear1.", 1600},
1568-
{"single_blocks.2.linear2.", 1600},
1569-
{"single_blocks.2.modulation.lin.", 1280},
1570-
{"single_blocks.20.linear1.", 1600},
1571-
{"single_blocks.20.linear2.", 1600},
1572-
{"single_blocks.20.modulation.lin.", 1280},
1573-
{"single_blocks.21.linear1.", 1600},
1574-
{"single_blocks.21.linear2.", 1600},
1575-
{"single_blocks.21.modulation.lin.", 1280},
1576-
{"single_blocks.22.linear1.", 1600},
1577-
{"single_blocks.22.linear2.", 1600},
1578-
{"single_blocks.22.modulation.lin.", 1280},
1579-
{"single_blocks.23.linear1.", 1600},
1580-
{"single_blocks.23.linear2.", 1600},
1581-
{"single_blocks.23.modulation.lin.", 1280},
1582-
{"single_blocks.24.linear1.", 1600},
1583-
{"single_blocks.24.linear2.", 1536},
1584-
{"single_blocks.24.modulation.lin.", 1280},
1585-
{"single_blocks.25.linear1.", 1664},
1586-
{"single_blocks.25.linear2.", 1600},
1587-
{"single_blocks.25.modulation.lin.", 1280},
1588-
{"single_blocks.26.linear1.", 1664},
1589-
{"single_blocks.26.linear2.", 1600},
1590-
{"single_blocks.26.modulation.lin.", 1216},
1591-
{"single_blocks.27.linear1.", 1600},
1592-
{"single_blocks.27.linear2.", 1600},
1593-
{"single_blocks.27.modulation.lin.", 1216},
1594-
{"single_blocks.28.linear1.", 1600},
1595-
{"single_blocks.28.linear2.", 1536},
1596-
{"single_blocks.28.modulation.lin.", 1216},
1597-
{"single_blocks.29.linear1.", 1600},
1598-
{"single_blocks.29.linear2.", 1536},
1599-
{"single_blocks.29.modulation.lin.", 1216},
1600-
{"single_blocks.3.linear1.", 1600},
1601-
{"single_blocks.3.linear2.", 1600},
1602-
{"single_blocks.3.modulation.lin.", 1280},
1603-
{"single_blocks.30.linear1.", 1600},
1604-
{"single_blocks.30.linear2.", 1536},
1605-
{"single_blocks.30.modulation.lin.", 1216},
1606-
{"single_blocks.31.linear1.", 1600},
1607-
{"single_blocks.31.linear2.", 1472},
1608-
{"single_blocks.31.modulation.lin.", 1152},
1609-
{"single_blocks.32.linear1.", 1600},
1610-
{"single_blocks.32.linear2.", 1472},
1611-
{"single_blocks.32.modulation.lin.", 1152},
1612-
{"single_blocks.33.linear1.", 1600},
1613-
{"single_blocks.33.linear2.", 1408},
1614-
{"single_blocks.33.modulation.lin.", 1152},
1615-
{"single_blocks.34.linear1.", 1536},
1616-
{"single_blocks.34.linear2.", 1280},
1617-
{"single_blocks.34.modulation.lin.", 1152},
1618-
{"single_blocks.35.linear1.", 1536},
1619-
{"single_blocks.35.linear2.", 1088},
1620-
{"single_blocks.35.modulation.lin.", 1088},
1621-
{"single_blocks.36.linear1.", 1536},
1622-
{"single_blocks.36.linear2.", 1024},
1623-
{"single_blocks.36.modulation.lin.", 1152},
1624-
{"single_blocks.37.linear1.", 1344},
1625-
{"single_blocks.37.linear2.", 896},
1626-
{"single_blocks.37.modulation.lin.", 1280},
1627-
{"single_blocks.4.linear1.", 1600},
1628-
{"single_blocks.4.linear2.", 1600},
1629-
{"single_blocks.4.modulation.lin.", 1280},
1630-
{"single_blocks.5.linear1.", 1600},
1631-
{"single_blocks.5.linear2.", 1600},
1632-
{"single_blocks.5.modulation.lin.", 1280},
1633-
{"single_blocks.6.linear1.", 1600},
1634-
{"single_blocks.6.linear2.", 1600},
1635-
{"single_blocks.6.modulation.lin.", 1280},
1636-
{"single_blocks.7.linear1.", 1600},
1637-
{"single_blocks.7.linear2.", 1600},
1638-
{"single_blocks.7.modulation.lin.", 1280},
1639-
{"single_blocks.8.linear1.", 1600},
1640-
{"single_blocks.8.linear2.", 1600},
1641-
{"single_blocks.8.modulation.lin.", 1280},
1642-
{"single_blocks.9.linear1.", 1664},
1643-
{"single_blocks.9.linear2.", 1600},
1644-
{"single_blocks.9.modulation.lin.", 1280},
1645-
{"time_in.in_layer.", 128},
1646-
{"time_in.out_layer.", 768},
1647-
{"txt_in.", 960},
1648-
{"vector_in.in_layer.", 256},
1649-
{"vector_in.out_layer.", 576},
1650-
};
1651-
16521360
class Embedding : public UnaryBlock {
16531361
protected:
16541362
int64_t embedding_dim;

0 commit comments

Comments
 (0)