Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -477,11 +477,12 @@ private string CreateInitialQuery()
}
else if (!string.IsNullOrEmpty(CatalogName))
{
CatalogName = SqlServerEscapeHelper.EscapeStringAsLiteral(SqlServerEscapeHelper.EscapeIdentifier(CatalogName));
CatalogName = SqlServerEscapeHelper.EscapeIdentifier(CatalogName);
}

string objectName = ADP.BuildMultiPartName(parts);
string escapedObjectName = SqlServerEscapeHelper.EscapeStringAsLiteral(objectName);
string catalogNameStringLiteral = CatalogName is null ? null : SqlServerEscapeHelper.EscapeStringAsLiteral(CatalogName);
// Specify the column names explicitly. This is to ensure that we can map to hidden
// columns (e.g. columns in temporal tables.) If the target table doesn't exist,
// OBJECT_ID will return NULL and @Column_Names will remain non-null. The subsequent
Expand Down Expand Up @@ -526,6 +527,11 @@ private string CreateInitialQuery()
// we use STRING_AGG in that case and the COALESCE method otherwise.
//
// See: https://learn.microsoft.com/en-us/sql/t-sql/functions/serverproperty-transact-sql
//
// All of this is wrapped in an test against HAS_PERMS_BY_NAME. This test verifies that
// the user possesses the necessary permissions to access sys.all_columns. If they do not
// @Column_Names will remain NULL (and be coalesced to *) and SqlBulkCopy will degrade
// gracefully, silently dropping support for hidden columns and column aliases.
return $"""
SELECT @@TRANCOUNT;

Expand All @@ -535,6 +541,7 @@ private string CreateInitialQuery()
DECLARE @Column_Name_Query_SORT NVARCHAR(MAX);
DECLARE @Column_Name_Query NVARCHAR(MAX);
DECLARE @Column_Names NVARCHAR(MAX) = NULL;
DECLARE @Has_Permissions INT = HAS_PERMS_BY_NAME('{catalogNameStringLiteral}.[sys].[all_columns]', 'OBJECT', 'SELECT');

CREATE TABLE #Column_Aliases
(
Expand All @@ -554,28 +561,35 @@ IF CAST(SERVERPROPERTY('EngineEdition') AS INT) = 6
SET @Column_Name_Query_SORT = N'ORDER BY [column_id] ASC';
END

IF EXISTS (SELECT TOP 1 * FROM sys.all_columns WHERE [object_id] = OBJECT_ID('sys.all_columns') AND [name] = 'graph_type')
BEGIN
SET @Column_Name_Query_FILTER = N'WHERE [object_id] = @Object_ID AND COALESCE([graph_type], 0) NOT IN (1, 3, 4, 6, 7)';

EXEC sp_executesql N'
INSERT INTO #Column_Aliases ([Canonical_Column_Name], [Canonical_Column_Id], [Aliased_Column_Name])
SELECT [name], [column_id], ''$to_id'' FROM {CatalogName}.[sys].[all_columns] WHERE [object_id] = @Object_ID AND COALESCE([graph_type], 0) = 8
UNION ALL
SELECT [name], [column_id], ''$from_id'' FROM {CatalogName}.[sys].[all_columns] WHERE [object_id] = @Object_ID AND COALESCE([graph_type], 0) = 5
UNION ALL
SELECT [name], [column_id], ''$edge_id'' FROM {CatalogName}.[sys].[all_columns] WHERE [object_id] = @Object_ID AND COALESCE([graph_type], 0) = 2 AND [name] LIKE ''$edge[_]id[_]%''
UNION ALL
SELECT [name], [column_id], ''$node_id'' FROM {CatalogName}.[sys].[all_columns] WHERE [object_id] = @Object_ID AND COALESCE([graph_type], 0) = 2 AND [name] LIKE ''$node[_]id[_]%''',
N'@Object_ID INT', @Object_ID = @Object_ID
END
ELSE
IF @Has_Permissions = 1
BEGIN
SET @Column_Name_Query_FILTER = N'WHERE [object_id] = @Object_ID';
IF EXISTS (SELECT TOP 1 * FROM {CatalogName}.[sys].[all_columns] WHERE [object_id] = OBJECT_ID('{catalogNameStringLiteral}.[sys].[all_columns]') AND [name] = 'graph_type')
BEGIN
SET @Column_Name_Query_FILTER = N'WHERE [object_id] = @Object_ID AND COALESCE([graph_type], 0) NOT IN (1, 3, 4, 6, 7)';

EXEC sp_executesql N'
INSERT INTO #Column_Aliases ([Canonical_Column_Name], [Canonical_Column_Id], [Aliased_Column_Name])
SELECT [name], [column_id], ''$to_id'' FROM {catalogNameStringLiteral}.[sys].[all_columns] WHERE [object_id] = @Object_ID AND COALESCE([graph_type], 0) = 8
UNION ALL
SELECT [name], [column_id], ''$from_id'' FROM {catalogNameStringLiteral}.[sys].[all_columns] WHERE [object_id] = @Object_ID AND COALESCE([graph_type], 0) = 5
UNION ALL
SELECT [name], [column_id], ''$edge_id'' FROM {catalogNameStringLiteral}.[sys].[all_columns] WHERE [object_id] = @Object_ID AND COALESCE([graph_type], 0) = 2 AND [name] LIKE ''$edge[_]id[_]%''
UNION ALL
SELECT [name], [column_id], ''$node_id'' FROM {catalogNameStringLiteral}.[sys].[all_columns] WHERE [object_id] = @Object_ID AND COALESCE([graph_type], 0) = 2 AND [name] LIKE ''$node[_]id[_]%''',
N'@Object_ID INT', @Object_ID = @Object_ID
END
ELSE
BEGIN
SET @Column_Name_Query_FILTER = N'WHERE [object_id] = @Object_ID';
END
SET @Column_Name_Query = @Column_Name_Query_SELECT + ' FROM {catalogNameStringLiteral}.[sys].[all_columns] ' + @Column_Name_Query_FILTER + ' ' + @Column_Name_Query_SORT + ';'

EXEC sp_executesql @Column_Name_Query, N'@Object_ID INT, @Column_Names NVARCHAR(MAX) OUTPUT', @Object_ID = @Object_ID, @Column_Names = @Column_Names OUTPUT;

DELETE FROM #Column_Aliases
WHERE [Aliased_Column_Name] IN (SELECT [name] FROM {CatalogName}.[sys].[all_columns] WHERE [object_id] = @Object_ID)
END
SET @Column_Name_Query = @Column_Name_Query_SELECT + ' FROM {CatalogName}.[sys].[all_columns] ' + @Column_Name_Query_FILTER + ' ' + @Column_Name_Query_SORT + ';'

EXEC sp_executesql @Column_Name_Query, N'@Object_ID INT, @Column_Names NVARCHAR(MAX) OUTPUT', @Object_ID = @Object_ID, @Column_Names = @Column_Names OUTPUT;
SELECT @Column_Names = COALESCE(@Column_Names, '*');

SET FMTONLY ON;
Expand All @@ -586,7 +600,6 @@ UNION ALL

SELECT [Canonical_Column_Name], [Aliased_Column_Name]
FROM #Column_Aliases
WHERE [Aliased_Column_Name] NOT IN (SELECT [name] FROM {CatalogName}.[sys].[all_columns] WHERE [object_id] = @Object_ID)
ORDER BY [Canonical_Column_Id] ASC

DROP TABLE #Column_Aliases
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,28 @@ namespace Microsoft.Data.SqlClient.Tests.Common.Fixtures.DatabaseObjects;
/// Base class for a transient database object (such as a table, type or
/// stored procedure.)
/// </summary>
public abstract class DatabaseObject : IDisposable
/// <typeparam name="TState">
/// The type of the internal state accessible to derived types at the point of object creation
/// via the <see cref="State"/> property.
/// </typeparam>
public abstract class DatabaseObject<TState> : IDisposable
{
private readonly bool _shouldDrop;

protected SqlConnection Connection { get; }

protected TState State { get; }

public string Name { get; }

protected DatabaseObject(SqlConnection connection, string name, string definition, bool shouldCreate, bool shouldDrop)
public string UnescapedName => Name.Substring(1, Name.Length - 2).Replace("]]", "]");

protected DatabaseObject(SqlConnection connection, string name, string definition, TState state, bool shouldCreate, bool shouldDrop)
{
_shouldDrop = shouldDrop;

Connection = connection;
State = state;
Name = name;

if (shouldCreate)
Expand Down Expand Up @@ -261,3 +270,15 @@ public void Dispose()
GC.SuppressFinalize(this);
}
}

/// <summary>
/// Base class for a transient database object (such as a table, type or
/// stored procedure.)
/// </summary>
public abstract class DatabaseObject : DatabaseObject<object?>
{
protected DatabaseObject(SqlConnection connection, string name, string definition, bool shouldCreate, bool shouldDrop)
: base(connection, name, definition, state: null, shouldCreate, shouldDrop)
{
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

namespace Microsoft.Data.SqlClient.Tests.Common.Fixtures.DatabaseObjects;

/// <summary>
/// A transient database user, created at the start of its scope and dropped when disposed.
/// </summary>
/// <remarks>
/// This class assumes that the associated server login already exists.
/// </remarks>
public sealed class DatabaseUser : DatabaseObject<string>
{
public string DatabaseName => State;

/// <summary>
/// Initializes a new instance of the DatabaseUser class using the specified SQL connection
/// and associated server login.
/// </summary>
/// <param name="connection">The SQL connection used to interact with the database.</param>
/// <param name="database">The name of the database where the user will be created.</param>
/// <param name="login">The server login which the database user will be associated with.</param>
public DatabaseUser(SqlConnection connection, string database, ServerLogin login)
: base(connection, login.Name, $"FOR LOGIN {login.Name}", database, shouldCreate: true, shouldDrop: true)
{
}

protected override void CreateObject(string definition)
{
using SqlCommand createCommand = new($"CREATE USER {Name} {definition}", Connection);

ExecuteCommandInDatabase(createCommand);
}

protected override void DropObject()
{
using SqlCommand dropCommand = new($"IF USER_ID('{UnescapedName}') IS NOT NULL DROP USER {Name}", Connection);

ExecuteCommandInDatabase(dropCommand);
}

private void ExecuteCommandInDatabase(SqlCommand command)
{
string? originalDatabase = DatabaseName == command.Connection.Database ? null : command.Connection.Database;

try
{
if (originalDatabase is not null)
{
command.Connection.ChangeDatabase(DatabaseName);
}

command.ExecuteNonQuery();
}
finally
{
if (originalDatabase is not null)
{
command.Connection.ChangeDatabase(originalDatabase);
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

namespace Microsoft.Data.SqlClient.Tests.Common.Fixtures.DatabaseObjects;

/// <summary>
/// A transient server login, created at the start of its scope and dropped when disposed.
/// </summary>
public sealed class ServerLogin : DatabaseObject<string>
{
public string Password => State;

/// <summary>
/// Initializes a new instance of the ServerLogin class using the specified SQL connection, login name prefix, and default database.
/// The login will be created with a randomly generated password that meets SQL Server's password complexity requirements.
/// </summary>
/// <param name="connection">The SQL connection used to interact with the database.</param>
/// <param name="namePrefix">The prefix for the login name.</param>
/// <param name="defaultDatabase">The default database for the login. If null, not set.</param>
public ServerLogin(SqlConnection connection, string namePrefix, string? defaultDatabase = null)
: this(connection, GenerateLongName(namePrefix), GeneratePassword(), defaultDatabase)
{
}

private ServerLogin(SqlConnection connection, string namePrefix, string password, string? defaultDatabase)
: base(connection, namePrefix, GenerateDefinition(password, defaultDatabase), password, shouldCreate: true, shouldDrop: true)
{
}

private static string GenerateDefinition(string password, string? defaultDatabase) =>
$"WITH PASSWORD='{password}'" +
(string.IsNullOrEmpty(defaultDatabase) ? string.Empty : $", DEFAULT_DATABASE=[{defaultDatabase}]");

/// <summary>
/// Generates a password which meets the SQL Server password complexity requirements, which are:
/// <list type="number">
/// <item>Minimum length of 8 characters</item>
/// <item>Must contain characters from three of the following four categories:</item>
/// <list type="number">
/// <item>Uppercase letters (A-Z)</item>
/// <item>Lowercase letters (a-z)</item>
/// <item>Digits (0-9)</item>
/// <item>Non-alphanumeric characters (e.g. !, $, #, %)</item>
/// </list>
/// </list>
/// </summary>
/// <returns>A compliant password.</returns>
private static string GeneratePassword()
{
const int PasswordLength = 16;
const char UpperCaseStart = 'A';
const char LowerCaseStart = 'a';
const char DigitsStart = '0';

// First 5 characters are uppercase letters, next 5 are lowercase letters, and the last 6 are digits
Span<char> passwordDigits = stackalloc char[PasswordLength];
Random rnd = new();

for(int i = 0; i < 5; i++)
{
passwordDigits[i] = (char)(UpperCaseStart + rnd.Next(26));
}
for (int i = 5; i < 10; i++)
{
passwordDigits[i] = (char)(LowerCaseStart + rnd.Next(26));
}
for (int i = 10; i < PasswordLength; i++)
{
passwordDigits[i] = (char)(DigitsStart + rnd.Next(10));
}

return passwordDigits.ToString();
}

protected override void CreateObject(string definition)
{
using SqlCommand createCommand = new($"CREATE LOGIN {Name} {definition}", Connection);

createCommand.ExecuteNonQuery();
}

protected override void DropObject()
{
using SqlCommand dropCommand = new($"IF SUSER_ID('{UnescapedName}') IS NOT NULL DROP LOGIN {Name}", Connection);

dropCommand.ExecuteNonQuery();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ public static class DataTestUtility
private static bool? s_isVectorSupported;
private static bool? s_isVectorFloat16Supported;

// Login permissions
private static bool? s_isSysAdmin;
private static bool? s_isSecurityAdmin;

// Azure Synapse EngineEditionId == 6
// More could be read at https://learn.microsoft.com/en-us/sql/t-sql/functions/serverproperty-transact-sql?view=sql-server-ver16#propertyname
public static bool IsAzureSynapse
Expand Down Expand Up @@ -231,6 +235,20 @@ private static bool CheckVectorFloat16Supported()
}
}

public static bool IsSysAdmin =>
s_isSysAdmin ??= IsTCPConnStringSetup() &&
IsServerRoleMember("sysadmin");

public static bool IsSecurityAdmin =>
s_isSecurityAdmin ??= IsTCPConnStringSetup() &&
IsServerRoleMember("securityadmin");

public static bool CanCreateLogins =>
IsSysAdmin || IsSecurityAdmin;

public static bool CanUseSqlAuthentication =>
IsSysAdmin && GetAuthenticationMode() == 2;

static DataTestUtility()
{
Config c = Config.Load();
Expand Down Expand Up @@ -531,6 +549,30 @@ public static bool IsTypePresent(string typeName)
return (int)command.ExecuteScalar() > 0;
}

public static bool IsServerRoleMember(string roleName)
{
using SqlConnection connection = new(TCPConnectionString);
using SqlCommand command = new("SELECT IS_SRVROLEMEMBER(@role)", connection);

connection.Open();
command.Parameters.AddWithValue("@role", roleName);

// IS_SRVROLEMEMBER returns 1 if the caller is a member of the specified server role, 0 if not, and DBNull.Value if the role is not valid.
return command.ExecuteScalar() is int result && result == 1;
}

public static int GetAuthenticationMode()
{
using SqlConnection connection = new(TCPConnectionString);

connection.Open();
using SqlCommand command = new("EXEC xp_instance_regread N'HKEY_LOCAL_MACHINE', N'Software\\Microsoft\\MSSQLServer\\MSSQLServer', N'LoginMode'", connection);
using SqlDataReader reader = command.ExecuteReader();

reader.Read();
return reader.GetInt32(1);
}

public static bool IsAdmin
{
get
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

#if NETFRAMEWORK

namespace System.Diagnostics.CodeAnalysis;

#nullable enable

[AttributeUsage(AttributeTargets.Method | AttributeTargets.Property, Inherited = false, AllowMultiple = true)]
internal sealed class MemberNotNullAttribute : Attribute
{
public MemberNotNullAttribute(string member) => Members = [member];

public MemberNotNullAttribute(params string[] members) => Members = members;

public string[] Members { get; }
}
#endif
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
<Compile Include="DataCommon\SystemDataResourceManager.cs" />
<Compile Include="DataCommon\UsernamePasswordProvider.cs" />
<Compile Include="DataCommon\XEventScope.cs" />
<Compile Include="Extensions\CodeAnalysis.netfx.cs" />
<Compile Include="Extensions\StreamExtensions.netfx.cs" />
<Compile Include="SQL\Common\AsyncDebugScope.cs" />
<Compile Include="SQL\Common\ConnectionPoolWrapper.cs" />
Expand All @@ -57,6 +58,7 @@
<Compile Include="SQL\Common\SystemDataInternals\FedAuthTokenHelper.cs" />
<Compile Include="SQL\Common\SystemDataInternals\TdsParserHelper.cs" />
<Compile Include="SQL\Common\SystemDataInternals\TdsParserStateObjectHelper.cs" />
<Compile Include="SQL\SqlBulkCopyTest\UnprivilegedLogin.cs" />
<Compile Include="XUnitAssemblyAttributes.cs" />

<!-- Content files -->
Expand Down
Loading
Loading