Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,16 @@
package org.apache.iotdb.ainode.it;

import org.apache.iotdb.it.env.EnvFactory;
import org.apache.iotdb.it.framework.IoTDBTestRunner;
import org.apache.iotdb.itbase.category.AIClusterIT;
import org.apache.iotdb.itbase.env.BaseEnv;

import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.junit.runner.RunWith;

import java.sql.Connection;
import java.sql.ResultSet;
Expand All @@ -41,9 +45,13 @@
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkModelOnSpecifiedDevice;
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest;

@RunWith(IoTDBTestRunner.class)
@Category({AIClusterIT.class})
public class AINodeInstanceManagementIT {

private static final Set<String> TARGET_DEVICES = new HashSet<>(Arrays.asList("cpu", "0", "1"));
private static final String TARGET_DEVICES_STR = "0,1";
private static final Set<String> TARGET_DEVICES =
new HashSet<>(Arrays.asList(TARGET_DEVICES_STR.split(",")));

@BeforeClass
public static void setUp() throws Exception {
Expand Down Expand Up @@ -76,53 +84,57 @@ private void basicManagementTest(Statement statement) throws SQLException, Inter
// Ensure resources
try (ResultSet resultSet = statement.executeQuery("SHOW AI_DEVICES")) {
ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
checkHeader(resultSetMetaData, "DeviceID");
checkHeader(resultSetMetaData, "DeviceId,DeviceType");
final Set<String> resultDevices = new HashSet<>();
while (resultSet.next()) {
resultDevices.add(resultSet.getString("DeviceID"));
resultDevices.add(resultSet.getString("DeviceId"));
}
Assert.assertEquals(TARGET_DEVICES, resultDevices);
Set<String> expected = new HashSet<>(TARGET_DEVICES);
expected.add("cpu");
Assert.assertEquals(expected, resultDevices);
}

// Load sundial to each device
statement.execute(String.format("LOAD MODEL sundial TO DEVICES '%s'", TARGET_DEVICES));
checkModelOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES.toString());
statement.execute(String.format("LOAD MODEL sundial TO DEVICES '%s'", TARGET_DEVICES_STR));
checkModelOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES_STR);
// Unload sundial from each device
statement.execute(String.format("UNLOAD MODEL sundial FROM DEVICES '%s'", TARGET_DEVICES));
checkModelNotOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES.toString());
statement.execute(String.format("UNLOAD MODEL sundial FROM DEVICES '%s'", TARGET_DEVICES_STR));
checkModelNotOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES_STR);

// Load timer_xl to each device
statement.execute(String.format("LOAD MODEL timer_xl TO DEVICES '%s'", TARGET_DEVICES));
checkModelOnSpecifiedDevice(statement, "timer_xl", TARGET_DEVICES.toString());
statement.execute(String.format("LOAD MODEL timer_xl TO DEVICES '%s'", TARGET_DEVICES_STR));
checkModelOnSpecifiedDevice(statement, "timer_xl", TARGET_DEVICES_STR);
// Unload timer_xl from each device
statement.execute(String.format("UNLOAD MODEL timer_xl FROM DEVICES '%s'", TARGET_DEVICES));
checkModelNotOnSpecifiedDevice(statement, "timer_xl", TARGET_DEVICES.toString());
statement.execute(String.format("UNLOAD MODEL timer_xl FROM DEVICES '%s'", TARGET_DEVICES_STR));
checkModelNotOnSpecifiedDevice(statement, "timer_xl", TARGET_DEVICES_STR);
}

private static final int LOOP_CNT = 10;

@Test
// @Test
public void repeatLoadAndUnloadTest() throws SQLException, InterruptedException {
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
Statement statement = connection.createStatement()) {
for (int i = 0; i < LOOP_CNT; i++) {
statement.execute("LOAD MODEL sundial TO DEVICES \"cpu,0,1\"");
checkModelOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES.toString());
statement.execute("UNLOAD MODEL sundial FROM DEVICES \"cpu,0,1\"");
checkModelNotOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES.toString());
statement.execute(String.format("LOAD MODEL sundial TO DEVICES '%s'", TARGET_DEVICES_STR));
checkModelOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES_STR);
statement.execute(
String.format("UNLOAD MODEL sundial FROM DEVICES '%s'", TARGET_DEVICES_STR));
checkModelNotOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES_STR);
}
}
}

@Test
// @Test
public void concurrentLoadAndUnloadTest() throws SQLException, InterruptedException {
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
Statement statement = connection.createStatement()) {
for (int i = 0; i < LOOP_CNT; i++) {
statement.execute("LOAD MODEL sundial TO DEVICES \"cpu,0,1\"");
statement.execute("UNLOAD MODEL sundial FROM DEVICES \"cpu,0,1\"");
statement.execute(String.format("LOAD MODEL sundial TO DEVICES '%s'", TARGET_DEVICES_STR));
statement.execute(
String.format("UNLOAD MODEL sundial FROM DEVICES '%s'", TARGET_DEVICES_STR));
}
checkModelNotOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES.toString());
checkModelNotOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES_STR);
}
}

Expand All @@ -145,23 +157,23 @@ public void failTestInTableModel() throws SQLException {
private void failTest(Statement statement) {
errorTest(
statement,
"LOAD MODEL unknown TO DEVICES \"cpu,0,1\"",
"1505: Cannot load model [unknown], because it is neither a built-in nor a fine-tuned model. You can use 'SHOW MODELS' to retrieve the available models.");
"LOAD MODEL unknown TO DEVICES 'cpu,0,1'",
"1504: Model [unknown] is not registered yet. You can use 'SHOW MODELS' to retrieve the available models.");
errorTest(
statement,
"LOAD MODEL sundial TO DEVICES \"unknown\"",
"1507: Device ID [unknown] is not available. You can use 'SHOW AI_DEVICES' to retrieve the available devices.");
"LOAD MODEL sundial TO DEVICES '999'",
"1508: AIDevice ID [999] is not available. You can use 'SHOW AI_DEVICES' to retrieve the available devices.");
errorTest(
statement,
"UNLOAD MODEL sundial FROM DEVICES \"unknown\"",
"1507: Device ID [unknown] is not available. You can use 'SHOW AI_DEVICES' to retrieve the available devices.");
"UNLOAD MODEL sundial FROM DEVICES '999'",
"1508: AIDevice ID [999] is not available. You can use 'SHOW AI_DEVICES' to retrieve the available devices.");
errorTest(
statement,
"LOAD MODEL sundial TO DEVICES \"0,0\"",
"LOAD MODEL sundial TO DEVICES '0,0'",
"1509: Device ID list contains duplicate entries.");
errorTest(
statement,
"UNLOAD MODEL sundial FROM DEVICES \"0,0\"",
"UNLOAD MODEL sundial FROM DEVICES '0,0'",
"1510: Device ID list contains duplicate entries.");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -71,39 +71,58 @@ public static void tearDown() throws Exception {
public void userDefinedModelManagementTestInTree() throws SQLException, InterruptedException {
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
Statement statement = connection.createStatement()) {
registerUserDefinedModel(statement);
callInferenceTest(
statement, new FakeModelInfo("user_chronos", "custom_t5", "user_defined", "active"));
dropUserDefinedModel(statement);
// Test transformers model (chronos2) in tree.
AINodeTestUtils.FakeModelInfo modelInfo =
new FakeModelInfo("user_chronos", "custom_t5", "user_defined", "active");
registerUserDefinedModel(statement, modelInfo, "file:///data/chronos2");
callInferenceTest(statement, modelInfo);
dropUserDefinedModel(statement, modelInfo.getModelId());
errorTest(
statement,
"create model origin_chronos using uri \"file:///data/chronos2_origin\"",
"1505: 't5' is already used by a Transformers config, pick another name.");
statement.execute("drop model origin_chronos");

// Test PytorchModelHubMixin model (mantis) in tree.
modelInfo = new FakeModelInfo("user_mantis", "custom_mantis", "user_defined", "active");
registerUserDefinedModel(statement, modelInfo, "file:///data/mantis");
dropUserDefinedModel(statement, modelInfo.getModelId());
}
}

@Test
public void userDefinedModelManagementTestInTable() throws SQLException, InterruptedException {
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
Statement statement = connection.createStatement()) {
registerUserDefinedModel(statement);
forecastTableFunctionTest(
statement, new FakeModelInfo("user_chronos", "custom_t5", "user_defined", "active"));
dropUserDefinedModel(statement);
// Test transformers model (chronos2) in table.
AINodeTestUtils.FakeModelInfo modelInfo =
new FakeModelInfo("user_chronos", "custom_t5", "user_defined", "active");
registerUserDefinedModel(statement, modelInfo, "file:///data/chronos2");
forecastTableFunctionTest(statement, modelInfo);
dropUserDefinedModel(statement, modelInfo.getModelId());
errorTest(
statement,
"create model origin_chronos using uri \"file:///data/chronos2_origin\"",
"1505: 't5' is already used by a Transformers config, pick another name.");
statement.execute("drop model origin_chronos");

// Test PytorchModelHubMixin model (mantis) in table.
modelInfo = new FakeModelInfo("user_mantis", "custom_mantis", "user_defined", "active");
registerUserDefinedModel(statement, modelInfo, "file:///data/mantis");
dropUserDefinedModel(statement, modelInfo.getModelId());
}
}

private void registerUserDefinedModel(Statement statement)
public static void registerUserDefinedModel(
Statement statement, AINodeTestUtils.FakeModelInfo modelInfo, String uri)
throws SQLException, InterruptedException {
String modelId = modelInfo.getModelId();
String modelType = modelInfo.getModelType();
String category = modelInfo.getCategory();
final String CREATE_MODEL_TEMPLATE = "create model %s using uri \"%s\"";
final String alterConfigSQL = "set configuration \"trusted_uri_pattern\"='.*'";
final String registerSql = "create model user_chronos using uri \"file:///data/chronos2\"";
final String showSql = "SHOW MODELS user_chronos";
final String registerSql = String.format(CREATE_MODEL_TEMPLATE, modelId, uri);
final String showSql = String.format("SHOW MODELS %s", modelId);
statement.execute(alterConfigSQL);
statement.execute(registerSql);
boolean loading = true;
Expand All @@ -112,13 +131,13 @@ private void registerUserDefinedModel(Statement statement)
ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
checkHeader(resultSetMetaData, "ModelId,ModelType,Category,State");
while (resultSet.next()) {
String modelId = resultSet.getString(1);
String modelType = resultSet.getString(2);
String category = resultSet.getString(3);
String resultModelId = resultSet.getString(1);
String resultModelType = resultSet.getString(2);
String resultCategory = resultSet.getString(3);
String state = resultSet.getString(4);
assertEquals("user_chronos", modelId);
assertEquals("custom_t5", modelType);
assertEquals("user_defined", category);
assertEquals(modelId, resultModelId);
assertEquals(modelType, resultModelType);
assertEquals(category, resultCategory);
if (state.equals("active")) {
loading = false;
} else if (state.equals("loading")) {
Expand All @@ -136,9 +155,9 @@ private void registerUserDefinedModel(Statement statement)
assertFalse(loading);
}

private void dropUserDefinedModel(Statement statement) throws SQLException {
final String showSql = "SHOW MODELS user_chronos";
final String dropSql = "DROP MODEL user_chronos";
public static void dropUserDefinedModel(Statement statement, String modelId) throws SQLException {
final String showSql = String.format("SHOW MODELS %s", modelId);
final String dropSql = String.format("DROP MODEL %s", modelId);
statement.execute(dropSql);
try (ResultSet resultSet = statement.executeQuery(showSql)) {
ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ public class AINodeTestUtils {

public static final Map<String, FakeModelInfo> BUILTIN_LTSM_MAP =
Stream.of(
new AbstractMap.SimpleEntry<>(
"sundial", new FakeModelInfo("sundial", "sundial", "builtin", "active")),
new AbstractMap.SimpleEntry<>(
"timer_xl", new FakeModelInfo("timer_xl", "timer", "builtin", "active")),
new AbstractMap.SimpleEntry<>(
"sundial", new FakeModelInfo("sundial", "sundial", "builtin", "active")),
new AbstractMap.SimpleEntry<>(
"chronos2", new FakeModelInfo("chronos2", "t5", "builtin", "active")),
new AbstractMap.SimpleEntry<>(
Expand Down Expand Up @@ -171,7 +171,7 @@ public static void checkModelOnSpecifiedDevice(Statement statement, String model
LOGGER.info("Model {} found in device {}, count {}", loadedModelId, deviceId, count);
if (loadedModelId.equals(modelId) && targetDevices.contains(deviceId) && count > 0) {
foundDevices.add(deviceId);
LOGGER.info("Model {} is loaded to device {}", modelId, device);
LOGGER.info("Model {} is loaded to device {}", modelId, deviceId);
}
}
if (foundDevices.containsAll(targetDevices)) {
Expand Down Expand Up @@ -252,6 +252,32 @@ public static void prepareDataInTable() throws SQLException {
}
}

/** Prepare db.AI2(s0 FLOAT,...) with 2880 rows of data in table. */
public static void prepareDataInTable2() throws SQLException {
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
Statement statement = connection.createStatement()) {
statement.execute("CREATE DATABASE db");
statement.execute(
"CREATE TABLE db.AI2 (s0 FLOAT FIELD, s1 DOUBLE FIELD, s2 INT32 FIELD, s3 INT64 FIELD, s4 FLOAT FIELD, s5 DOUBLE FIELD, s6 INT32 FIELD, s7 INT64 FIELD, s8 FLOAT FIELD, s9 DOUBLE FIELD)");
for (int i = 0; i < 2880; i++) {
statement.execute(
String.format(
"INSERT INTO db.AI2(time,s0,s1,s2,s3,s4,s5,s6,s7,s8,s9) VALUES(%d,%f,%f,%d,%d,%f,%f,%d,%d,%f,%f)",
i,
(float) i,
(double) i,
i,
i,
(float) (i * 2),
(double) (i * 2),
i * 2,
i * 2,
(float) (i * 3),
(double) (i * 3)));
}
}
}

public static class FakeModelInfo {

private final String modelId;
Expand Down
6 changes: 0 additions & 6 deletions iotdb-core/ainode/iotdb/ainode/core/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,6 @@
# specific language governing permissions and limitations
# under the License.
#
import re

from iotdb.ainode.core.model.model_constants import (
MODEL_CONFIG_FILE_IN_YAML,
MODEL_WEIGHTS_FILE_IN_PT,
)


class _BaseException(Exception):
Expand Down
Loading
Loading