Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ public interface IText2SqlHook : IHookBase
{
// Get database type
string GetDatabaseType(RoleDialogModel message);
string GetConnectionString(RoleDialogModel message);
string? GetConnectionString(RoleDialogModel message);
Task SqlGenerated(RoleDialogModel message);
Task SqlExecuting(RoleDialogModel message);
Task SqlExecuted(RoleDialogModel message);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,11 @@ public async Task<IActionResult> ExecuteSqlQuery([FromRoute] string conversation

var fn = _services.GetRequiredService<IRoutingService>();
var conv = _services.GetRequiredService<IConversationService>();
conv.SetConversationId(conversationId, [new MessageState("database_type", sqlQueryRequest.DbType)]);
conv.SetConversationId(conversationId,
[
new MessageState("database_type", sqlQueryRequest.DbType),
new MessageState("data_source_name", sqlQueryRequest.DataSource),
]);

var msg = new RoleDialogModel(AgentRole.User, sqlQueryRequest.SqlStatement)
{
Expand All @@ -38,6 +42,7 @@ public async Task<IActionResult> ExecuteSqlQuery([FromRoute] string conversation
msg.FunctionArgs = JsonSerializer.Serialize(new ExecuteQueryArgs
{
DbType = sqlQueryRequest.DbType,
DataSource = sqlQueryRequest.DataSource,
SqlStatements = [sqlQueryRequest.SqlStatement],
ResultFormat = sqlQueryRequest.ResultFormat
});
Expand Down Expand Up @@ -79,4 +84,20 @@ public async Task<IActionResult> AddQueryExecutionResult([FromRoute] string conv

return Ok(dialog);
}

[HttpGet]
[Route("/sql-driver/connections")]
public IActionResult GetConnectionSettings()
{
var settings = _services.GetRequiredService<SqlDriverSetting>();

var connections = settings.Connections.Select(x => new DataSourceSetting
{
DbType = x.DbType,
Name = x.Name,
ConnectionString = "**********"
}).ToArray();

return Ok(connections);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ public class SqlQueryRequest
{
public string AgentId { get; set; } = null!;
public string DbType { get; set; } = null!;
/// <summary>
/// Data source name
/// </summary>
public string DataSource { get; set; } = null!;
public string SqlStatement { get; set; } = null!;
public string ResultFormat { get; set; } = "markdown";
public bool IsEphemeral { get; set; } = false;
Expand Down
19 changes: 10 additions & 9 deletions src/Plugins/BotSharp.Plugin.SqlDriver/Functions/ExecuteQueryFn.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@ public ExecuteQueryFn(IServiceProvider services, SqlDriverSetting setting, ILogg

public async Task<bool> Execute(RoleDialogModel message)
{
var args = JsonSerializer.Deserialize<ExecuteQueryArgs>(message.FunctionArgs);
var args = JsonSerializer.Deserialize<ExecuteQueryArgs>(message.FunctionArgs) ?? new();
//var refinedArgs = await RefineSqlStatement(message, args);
var dbHook = _services.GetRequiredService<IText2SqlHook>();
var dbType = dbHook.GetDatabaseType(message);
var dbConnectionString = dbHook.GetConnectionString(message);
var connectionString = _setting.Connections.FirstOrDefault(x => x.Name.Equals(args.DataSource, StringComparison.OrdinalIgnoreCase))?.ConnectionString;
var dbConnectionString = dbHook.GetConnectionString(message) ?? connectionString ?? throw new Exception("database connection is not found");

// Print all the SQL statements for debugging
_logger.LogInformation("Executing SQL Statements: {SqlStatements}", string.Join("\r\n", args.SqlStatements));
Expand All @@ -39,9 +40,9 @@ public async Task<bool> Execute(RoleDialogModel message)
{
results = dbType.ToLower() switch
{
"mysql" => RunQueryInMySql(args.SqlStatements),
"mysql" => RunQueryInMySql(dbConnectionString, args.SqlStatements),
"sqlserver" or "mssql" => RunQueryInSqlServer(dbConnectionString, args.SqlStatements),
"redshift" => RunQueryInRedshift(args.SqlStatements),
"redshift" => RunQueryInRedshift(dbConnectionString, args.SqlStatements),
"sqlite" => RunQueryInSqlite(dbConnectionString, args.SqlStatements),
_ => throw new NotImplementedException($"Database type {dbType} is not supported.")
};
Expand Down Expand Up @@ -214,24 +215,24 @@ private string EscapeMarkdownField(string field)
return field.Replace("|", "\\|");
}

private IEnumerable<dynamic> RunQueryInMySql(string[] sqlTexts)
private IEnumerable<dynamic> RunQueryInMySql(string connectionString, string[] sqlTexts)
{
var settings = _services.GetRequiredService<SqlDriverSetting>();
using var connection = new MySqlConnection(settings.MySqlExecutionConnectionString ?? settings.MySqlConnectionString);
using var connection = new MySqlConnection(connectionString);
return connection.Query(string.Join(";\r\n", sqlTexts));
}

private IEnumerable<dynamic> RunQueryInSqlServer(string connectionString, string[] sqlTexts)
{
var settings = _services.GetRequiredService<SqlDriverSetting>();
using var connection = new SqlConnection(settings.SqlServerExecutionConnectionString ?? settings.SqlServerConnectionString ?? connectionString);
using var connection = new SqlConnection(connectionString);
return connection.Query(string.Join("\r\n", sqlTexts));
}

private IEnumerable<dynamic> RunQueryInRedshift(string[] sqlTexts)
private IEnumerable<dynamic> RunQueryInRedshift(string connectionString, string[] sqlTexts)
{
var settings = _services.GetRequiredService<SqlDriverSetting>();
using var connection = new NpgsqlConnection(settings.RedshiftConnectionString);
using var connection = new NpgsqlConnection(connectionString);
return connection.Query(string.Join("\r\n", sqlTexts));
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Microsoft.Data.SqlClient;
using MongoDB.Driver.Core.Configuration;
using MySqlConnector;
using Npgsql;

Expand Down Expand Up @@ -44,7 +45,9 @@ private List<string> GetDdlFromMySql(string[] tables)
{
var settings = _services.GetRequiredService<SqlDriverSetting>();
var tableDdls = new List<string>();
using var connection = new MySqlConnection(settings.MySqlMetaConnectionString ?? settings.MySqlConnectionString);

var connectionString = settings.Connections.FirstOrDefault(x => x.DbType == "mysql")?.ConnectionString;
using var connection = new MySqlConnection(connectionString);
connection.Open();

foreach (var table in tables)
Expand Down Expand Up @@ -79,7 +82,8 @@ private List<string> GetDdlFromSqlServer(string[] tables)
{
var settings = _services.GetRequiredService<SqlDriverSetting>();
var tableDdls = new List<string>();
using var connection = new SqlConnection(settings.SqlServerExecutionConnectionString ?? settings.SqlServerConnectionString);
var connectionString = settings.Connections.FirstOrDefault(x => x.DbType == "mssql")?.ConnectionString;
using var connection = new SqlConnection(connectionString);
connection.Open();

foreach (var table in tables)
Expand Down Expand Up @@ -132,7 +136,8 @@ private List<string> GetDdlFromRedshift(string[] tables)
var settings = _services.GetRequiredService<SqlDriverSetting>();
var tableDdls = new List<string>();
var schemas = "'onebi_hour','onebi_day'";
using var connection = new NpgsqlConnection(settings.RedshiftConnectionString);
var connectionString = settings.Connections.FirstOrDefault(x => x.DbType == "redshift")?.ConnectionString;
using var connection = new NpgsqlConnection(connectionString);
connection.Open();

foreach (var table in tables)
Expand Down
35 changes: 18 additions & 17 deletions src/Plugins/BotSharp.Plugin.SqlDriver/Functions/SqlSelect.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@ public class SqlSelect : IFunctionCallback
{
public string Name => "sql_select";
private readonly IServiceProvider _services;
private readonly SqlDriverSetting _settings;

public SqlSelect(IServiceProvider services)
public SqlSelect(IServiceProvider services, SqlDriverSetting settings)
{
_settings = settings;
_services = services;
}

Expand All @@ -30,13 +32,16 @@ public async Task<bool> Execute(RoleDialogModel message)
// check if need to instantely
var dbHook = _services.GetRequiredService<IText2SqlHook>();
var dbType = dbHook.GetDatabaseType(message);
var dbConnectionString = dbHook.GetConnectionString(message) ??
_settings.Connections.FirstOrDefault(c => c.DbType == dbType)?.ConnectionString ??
throw new Exception("database connectdion is not found");

var result = dbType switch
{
"mysql" => RunQueryInMySql(args),
"sqlserver" or "mssql" => RunQueryInSqlServer(args),
"redshift" => RunQueryInRedshift(args),
"mongodb" => RunQueryInMongoDb(args),
"mysql" => RunQueryInMySql(dbConnectionString, args),
"sqlserver" or "mssql" => RunQueryInSqlServer(dbConnectionString, args),
"redshift" => RunQueryInRedshift(dbConnectionString, args),
"mongodb" => RunQueryInMongoDb(dbConnectionString, args),
_ => throw new NotImplementedException($"Database type {dbType} is not supported.")
};

Expand All @@ -54,10 +59,9 @@ public async Task<bool> Execute(RoleDialogModel message)
return true;
}

private IEnumerable<dynamic> RunQueryInMySql(SqlStatement args)
private IEnumerable<dynamic> RunQueryInMySql(string connectionString, SqlStatement args)
{
var settings = _services.GetRequiredService<SqlDriverSetting>();
using var connection = new MySqlConnection(settings.MySqlExecutionConnectionString);
using var connection = new MySqlConnection(connectionString);
var dictionary = new Dictionary<string, object>();
foreach (var p in args.Parameters)
{
Expand All @@ -66,10 +70,9 @@ private IEnumerable<dynamic> RunQueryInMySql(SqlStatement args)
return connection.Query(args.Statement, dictionary);
}

private IEnumerable<dynamic> RunQueryInSqlServer(SqlStatement args)
private IEnumerable<dynamic> RunQueryInSqlServer(string connectionString, SqlStatement args)
{
var settings = _services.GetRequiredService<SqlDriverSetting>();
using var connection = new SqlConnection(settings.SqlServerExecutionConnectionString ?? settings.SqlServerConnectionString);
using var connection = new SqlConnection(connectionString);
var dictionary = new Dictionary<string, object>();
foreach (var p in args.Parameters)
{
Expand All @@ -78,10 +81,9 @@ private IEnumerable<dynamic> RunQueryInSqlServer(SqlStatement args)
return connection.Query(args.Statement, dictionary);
}

private IEnumerable<dynamic> RunQueryInRedshift(SqlStatement args)
private IEnumerable<dynamic> RunQueryInRedshift(string connectionString, SqlStatement args)
{
var settings = _services.GetRequiredService<SqlDriverSetting>();
using var connection = new NpgsqlConnection(settings.RedshiftConnectionString);
using var connection = new NpgsqlConnection(connectionString);
var dictionary = new Dictionary<string, object>();
foreach (var p in args.Parameters)
{
Expand All @@ -90,10 +92,9 @@ private IEnumerable<dynamic> RunQueryInRedshift(SqlStatement args)
return connection.Query(args.Statement, dictionary);
}

private IEnumerable<dynamic> RunQueryInMongoDb(SqlStatement args)
private IEnumerable<dynamic> RunQueryInMongoDb(string connectionString, SqlStatement args)
{
var settings = _services.GetRequiredService<SqlDriverSetting>();
var client = new MongoClient(settings.MongoDbConnectionString);
var client = new MongoClient(connectionString);

// Normalize multi-line query to single line
var statement = Regex.Replace(args.Statement.Trim(), @"\s+", " ");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ public async Task<bool> Execute(RoleDialogModel message)
IEnumerable<dynamic>? result = null;
if (!string.IsNullOrWhiteSpace(args.SqlStatement))
{
var settings = _services.GetRequiredService<SqlDriverSetting>();
using var connection = new MySqlConnection(settings.MySqlExecutionConnectionString);
var sqlHook = _services.GetRequiredService<IText2SqlHook>();
var connectionString = sqlHook.GetConnectionString(message);
using var connection = new MySqlConnection(connectionString);
result = connection.Query(args.SqlStatement);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,10 @@ public async Task OnPlanningCompleted(string planner, RoleDialogModel msg)
public async Task<string> GetSummaryAdditionalRequirements(string planner, RoleDialogModel message)
{
var settings = _services.GetRequiredService<SqlDriverSetting>();
var sqlHooks = _services.GetServices<IText2SqlHook>();
var sqlHook = _services.GetRequiredService<IText2SqlHook>();
var agentService = _services.GetRequiredService<IAgentService>();

var dbType = !sqlHooks.IsNullOrEmpty() ? sqlHooks.First().GetDatabaseType(message) : settings.DatabaseType;
var dbType = sqlHook.GetDatabaseType(message);
var agent = await agentService.LoadAgent(BuiltInAgentId.SqlDriver);

return agent.Templates.FirstOrDefault(x => x.Name == $"database.summarize.{dbType}")?.Content ?? string.Empty;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ public class ExecuteQueryArgs

public string DbType { get; set; } = null!;

public string DataSource { get;set; } = null!;

public string? ConnectionString { get; set; }

/// <summary>
/// Beautifying query result
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ public async Task<bool> Import(ImportDbKnowledgeRequest request)
var collectionName = request.KnowledgebaseCollection;

var tables = new HashSet<string>();
using var connection = new MySqlConnection(sqlDriverSettings.MySqlConnectionString);

using var connection = new MySqlConnection(string.Empty);

var sql = $"select table_name from information_schema.tables where table_schema = @tableSchema";
var results = connection.Query(sql, new
Expand Down Expand Up @@ -98,7 +99,7 @@ private string GetTableStructure(string table)
var escapedTableName = MySqlHelper.EscapeString(table);
var sql = $"SHOW CREATE TABLE `{escapedTableName}`";

using var connection = new MySqlConnection(settings.MySqlConnectionString);
using var connection = new MySqlConnection(string.Empty);
connection.Open();
using var command = new MySqlCommand(sql, connection);
using var reader = command.ExecuteReader();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
namespace BotSharp.Plugin.SqlDriver.Settings;


public class DataSourceSetting
{
public string Name { get; set; } = "default";
public string DbType { get; set; } = "mysql";
public string ConnectionString { get; set; } = "localhost";
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,7 @@ namespace BotSharp.Plugin.SqlDriver.Settings;

public class SqlDriverSetting
{
public string DatabaseType { get; set; } = "mysql";
public string MySqlConnectionString { get; set; } = null!;
public string MySqlExecutionConnectionString { get; set; } = null!;
public string MySqlMetaConnectionString { get; set; } = null!;
public string SqlServerConnectionString { get; set; } = null!;
public string SqlServerExecutionConnectionString { get; set; } = null!;
public string RedshiftConnectionString { get; set; } = null!;
public string MongoDbConnectionString { get; set; } = null!;
public DataSourceSetting[] Connections { get; set; } = [];
public bool ExecuteSqlSelectAutonomous { get; set; } = false;
public bool FormattingResult { get; set; } = true;
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ private List<string> GetDdlFromMySql(string[] tables)
{
var settings = _services.GetRequiredService<SqlDriverSetting>();
var tableDdls = new List<string>();
using var connection = new MySqlConnection(settings.MySqlMetaConnectionString ?? settings.MySqlConnectionString);
using var connection = new MySqlConnection(string.Empty);
connection.Open();

foreach (var table in tables)
Expand Down Expand Up @@ -79,7 +79,7 @@ private List<string> GetDdlFromSqlServer(string[] tables)
{
var settings = _services.GetRequiredService<SqlDriverSetting>();
var tableDdls = new List<string>();
using var connection = new SqlConnection(settings.SqlServerExecutionConnectionString ?? settings.SqlServerConnectionString);
using var connection = new SqlConnection(string.Empty);
connection.Open();

foreach (var table in tables)
Expand Down Expand Up @@ -132,7 +132,7 @@ private List<string> GetDdlFromRedshift(string[] tables, string schema)
var settings = _services.GetRequiredService<SqlDriverSetting>();
var tableDdls = new List<string>();
var schemas = "'onebi_hour','onebi_day'";
using var connection = new NpgsqlConnection(settings.RedshiftConnectionString);
using var connection = new NpgsqlConnection(string.Empty);
connection.Open();

foreach (var table in tables)
Expand Down
Loading
Loading