Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,32 +13,6 @@ namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal;
/// </summary>
public class CosmosDateTimeMethodTranslator(ISqlExpressionFactory sqlExpressionFactory) : IMethodCallTranslator
{
private static readonly Dictionary<MethodInfo, string> MethodInfoDatePartMapping = new()
{
{ typeof(DateTime).GetRuntimeMethod(nameof(DateTime.AddYears), [typeof(int)])!, "yyyy" },
{ typeof(DateTime).GetRuntimeMethod(nameof(DateTime.AddMonths), [typeof(int)])!, "mm" },
{ typeof(DateTime).GetRuntimeMethod(nameof(DateTime.AddDays), [typeof(double)])!, "dd" },
{ typeof(DateTime).GetRuntimeMethod(nameof(DateTime.AddHours), [typeof(double)])!, "hh" },
{ typeof(DateTime).GetRuntimeMethod(nameof(DateTime.AddMinutes), [typeof(double)])!, "mi" },
{ typeof(DateTime).GetRuntimeMethod(nameof(DateTime.AddSeconds), [typeof(double)])!, "ss" },
{ typeof(DateTime).GetRuntimeMethod(nameof(DateTime.AddMilliseconds), [typeof(double)])!, "ms" },
{ typeof(DateTime).GetRuntimeMethod(nameof(DateTime.AddMicroseconds), [typeof(double)])!, "mcs" },
{ typeof(DateTimeOffset).GetRuntimeMethod(nameof(DateTimeOffset.AddYears), [typeof(int)])!, "yyyy" },
{ typeof(DateTimeOffset).GetRuntimeMethod(nameof(DateTimeOffset.AddMonths), [typeof(int)])!, "mm" },
{ typeof(DateTimeOffset).GetRuntimeMethod(nameof(DateTimeOffset.AddDays), [typeof(double)])!, "dd" },
{ typeof(DateTimeOffset).GetRuntimeMethod(nameof(DateTimeOffset.AddHours), [typeof(double)])!, "hh" },
{ typeof(DateTimeOffset).GetRuntimeMethod(nameof(DateTimeOffset.AddMinutes), [typeof(double)])!, "mi" },
{ typeof(DateTimeOffset).GetRuntimeMethod(nameof(DateTimeOffset.AddSeconds), [typeof(double)])!, "ss" },
{ typeof(DateTimeOffset).GetRuntimeMethod(nameof(DateTimeOffset.AddMilliseconds), [typeof(double)])!, "ms" },
{ typeof(DateTimeOffset).GetRuntimeMethod(nameof(DateTimeOffset.AddMicroseconds), [typeof(double)])!, "mcs" }
};

private static readonly Dictionary<MethodInfo, string> MethodInfoDateDiffMapping = new()
{
{ typeof(DateTimeOffset).GetRuntimeMethod(nameof(DateTimeOffset.ToUnixTimeSeconds), Type.EmptyTypes)!, "second" },
{ typeof(DateTimeOffset).GetRuntimeMethod(nameof(DateTimeOffset.ToUnixTimeMilliseconds), Type.EmptyTypes)!, "millisecond" }
};

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
Expand All @@ -56,16 +30,30 @@ public class CosmosDateTimeMethodTranslator(ISqlExpressionFactory sqlExpressionF
return null;
}

if (MethodInfoDatePartMapping.TryGetValue(method, out var datePart)
&& instance != null)
if (instance is null || arguments is not [var arg])
{
return sqlExpressionFactory.Function(
"DateTimeAdd",
arguments: [sqlExpressionFactory.Constant(datePart), arguments[0], instance],
instance.Type,
instance.TypeMapping);
return null;
}

return null;
var datePart = method.Name switch
{
nameof(DateTime.AddYears) => "yyyy",
nameof(DateTime.AddMonths) => "mm",
nameof(DateTime.AddDays) => "dd",
nameof(DateTime.AddHours) => "hh",
nameof(DateTime.AddMinutes) => "mi",
nameof(DateTime.AddSeconds) => "ss",
nameof(DateTime.AddMilliseconds) => "ms",
nameof(DateTime.AddMicroseconds) => "mcs",
_ => (string?)null
};

return datePart is not null
? sqlExpressionFactory.Function(
"DateTimeAdd",
arguments: [sqlExpressionFactory.Constant(datePart), arg, instance],
instance.Type,
instance.TypeMapping)
: null;
}
}
146 changes: 75 additions & 71 deletions src/EFCore.Cosmos/Query/Internal/Translators/CosmosMathTranslator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,66 +13,6 @@ namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal;
/// </summary>
public class CosmosMathTranslator(ISqlExpressionFactory sqlExpressionFactory) : IMethodCallTranslator
{
private static readonly Dictionary<MethodInfo, string> SupportedMethodTranslations = new()
{
{ typeof(Math).GetRuntimeMethod(nameof(Math.Abs), [typeof(decimal)])!, "ABS" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Abs), [typeof(double)])!, "ABS" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Abs), [typeof(float)])!, "ABS" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Abs), [typeof(int)])!, "ABS" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Abs), [typeof(long)])!, "ABS" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Abs), [typeof(sbyte)])!, "ABS" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Abs), [typeof(short)])!, "ABS" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Ceiling), [typeof(decimal)])!, "CEILING" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Ceiling), [typeof(double)])!, "CEILING" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Floor), [typeof(decimal)])!, "FLOOR" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Floor), [typeof(double)])!, "FLOOR" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Pow), [typeof(double), typeof(double)])!, "POWER" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Exp), [typeof(double)])!, "EXP" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Log10), [typeof(double)])!, "LOG10" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Log), [typeof(double)])!, "LOG" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Log), [typeof(double), typeof(double)])!, "LOG" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Sqrt), [typeof(double)])!, "SQRT" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Acos), [typeof(double)])!, "ACOS" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Asin), [typeof(double)])!, "ASIN" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Atan), [typeof(double)])!, "ATAN" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Atan2), [typeof(double), typeof(double)])!, "ATN2" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Cos), [typeof(double)])!, "COS" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Sin), [typeof(double)])!, "SIN" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Tan), [typeof(double)])!, "TAN" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Sign), [typeof(decimal)])!, "SIGN" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Sign), [typeof(double)])!, "SIGN" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Sign), [typeof(float)])!, "SIGN" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Sign), [typeof(int)])!, "SIGN" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Sign), [typeof(long)])!, "SIGN" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Sign), [typeof(sbyte)])!, "SIGN" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Sign), [typeof(short)])!, "SIGN" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Truncate), [typeof(decimal)])!, "TRUNC" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Truncate), [typeof(double)])!, "TRUNC" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Round), [typeof(decimal)])!, "ROUND" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Round), [typeof(double)])!, "ROUND" },
{ typeof(double).GetRuntimeMethod(nameof(double.DegreesToRadians), [typeof(double)])!, "RADIANS" },
{ typeof(double).GetRuntimeMethod(nameof(double.RadiansToDegrees), [typeof(double)])!, "DEGREES" },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Ceiling), [typeof(float)])!, "CEILING" },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Floor), [typeof(float)])!, "FLOOR" },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Pow), [typeof(float), typeof(float)])!, "POWER" },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Exp), [typeof(float)])!, "EXP" },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Log10), [typeof(float)])!, "LOG10" },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Log), [typeof(float)])!, "LOG" },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Log), [typeof(float), typeof(float)])!, "LOG" },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Sqrt), [typeof(float)])!, "SQRT" },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Acos), [typeof(float)])!, "ACOS" },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Asin), [typeof(float)])!, "ASIN" },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Atan), [typeof(float)])!, "ATAN" },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Atan2), [typeof(float), typeof(float)])!, "ATN2" },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Cos), [typeof(float)])!, "COS" },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Sin), [typeof(float)])!, "SIN" },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Tan), [typeof(float)])!, "TAN" },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Truncate), [typeof(float)])!, "TRUNC" },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Round), [typeof(float)])!, "ROUND" },
{ typeof(float).GetRuntimeMethod(nameof(float.DegreesToRadians), [typeof(float)])!, "RADIANS" },
{ typeof(float).GetRuntimeMethod(nameof(float.RadiansToDegrees), [typeof(float)])!, "DEGREES" },
};

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
Expand All @@ -85,21 +25,85 @@ public class CosmosMathTranslator(ISqlExpressionFactory sqlExpressionFactory) :
IReadOnlyList<SqlExpression> arguments,
IDiagnosticsLogger<DbLoggerCategory.Query> logger)
{
if (SupportedMethodTranslations.TryGetValue(method, out var sqlFunctionName))
if (method.DeclaringType != typeof(Math)
&& method.DeclaringType != typeof(MathF)
&& method.DeclaringType != typeof(double)
&& method.DeclaringType != typeof(float))
{
return null;
}

var sqlFunctionName = method.Name switch
{
var typeMapping = arguments.Count == 1
? ExpressionExtensions.InferTypeMapping(arguments[0])
: ExpressionExtensions.InferTypeMapping(arguments[0], arguments[1]);
nameof(Math.Abs) when arguments is [var arg]
&& arg.Type is { } t && (t == typeof(decimal) || t == typeof(double) || t == typeof(float)
|| t == typeof(int) || t == typeof(long) || t == typeof(sbyte) || t == typeof(short))
=> "ABS",

var newArguments = arguments.Select(e => sqlExpressionFactory.ApplyTypeMapping(e, typeMapping!));
nameof(Math.Ceiling) when arguments is [var arg]
&& arg.Type is { } t && (t == typeof(decimal) || t == typeof(double) || t == typeof(float))
=> "CEILING",
nameof(Math.Floor) when arguments is [var arg]
&& arg.Type is { } t && (t == typeof(decimal) || t == typeof(double) || t == typeof(float))
=> "FLOOR",
nameof(Math.Round) when arguments is [var arg]
&& arg.Type is { } t && (t == typeof(decimal) || t == typeof(double) || t == typeof(float))
=> "ROUND",
nameof(Math.Truncate) when arguments is [var arg]
&& arg.Type is { } t && (t == typeof(decimal) || t == typeof(double) || t == typeof(float))
=> "TRUNC",
nameof(Math.Sign) when arguments is [var arg]
&& arg.Type is { } t && (t == typeof(decimal) || t == typeof(double) || t == typeof(float)
|| t == typeof(int) || t == typeof(long) || t == typeof(sbyte) || t == typeof(short))
=> "SIGN",

return sqlExpressionFactory.Function(
sqlFunctionName,
newArguments,
method.ReturnType,
typeMapping);
nameof(Math.Pow) when arguments is [_, _]
=> "POWER",
nameof(Math.Exp) when arguments is [_]
=> "EXP",
nameof(Math.Log10) when arguments is [_]
=> "LOG10",
nameof(Math.Log) when arguments is [_] or [_, _]
=> "LOG",
nameof(Math.Sqrt) when arguments is [_]
=> "SQRT",
nameof(Math.Acos) when arguments is [_]
=> "ACOS",
nameof(Math.Asin) when arguments is [_]
=> "ASIN",
nameof(Math.Atan) when arguments is [_]
=> "ATAN",
nameof(Math.Atan2) when arguments is [_, _]
=> "ATN2",
nameof(Math.Cos) when arguments is [_]
=> "COS",
nameof(Math.Sin) when arguments is [_]
=> "SIN",
nameof(Math.Tan) when arguments is [_]
=> "TAN",
nameof(double.DegreesToRadians) when arguments is [_]
=> "RADIANS",
nameof(double.RadiansToDegrees) when arguments is [_]
=> "DEGREES",

_ => null
};

if (sqlFunctionName is null)
{
return null;
}

return null;
var typeMapping = arguments.Count == 1
? ExpressionExtensions.InferTypeMapping(arguments[0])
: ExpressionExtensions.InferTypeMapping(arguments[0], arguments[1]);

var newArguments = arguments.Select(e => sqlExpressionFactory.ApplyTypeMapping(e, typeMapping!));

return sqlExpressionFactory.Function(
sqlFunctionName,
newArguments,
method.ReturnType,
typeMapping);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,6 @@ namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal;
/// </summary>
public class CosmosRandomTranslator(ISqlExpressionFactory sqlExpressionFactory) : IMethodCallTranslator
{
private static readonly MethodInfo MethodInfo = typeof(DbFunctionsExtensions).GetRuntimeMethod(
nameof(DbFunctionsExtensions.Random), [typeof(DbFunctions)])!;

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
Expand All @@ -27,10 +24,11 @@ public class CosmosRandomTranslator(ISqlExpressionFactory sqlExpressionFactory)
MethodInfo method,
IReadOnlyList<SqlExpression> arguments,
IDiagnosticsLogger<DbLoggerCategory.Query> logger)
=> MethodInfo.Equals(method)
? sqlExpressionFactory.Function(
"RAND",
[],
method.ReturnType)
: null;
=> method.DeclaringType == typeof(DbFunctionsExtensions)
&& method.Name == nameof(DbFunctionsExtensions.Random)
? sqlExpressionFactory.Function(
"RAND",
[],
method.ReturnType)
: null;
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,6 @@ namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal;
public class CosmosRegexTranslator(ISqlExpressionFactory sqlExpressionFactory)
: IMethodCallTranslator
{
private static readonly MethodInfo IsMatch =
typeof(Regex).GetRuntimeMethod(nameof(Regex.IsMatch), [typeof(string), typeof(string)])!;

private static readonly MethodInfo IsMatchWithRegexOptions =
typeof(Regex).GetRuntimeMethod(nameof(Regex.IsMatch), [typeof(string), typeof(string), typeof(RegexOptions)])!;

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
Expand All @@ -33,7 +27,13 @@ public class CosmosRegexTranslator(ISqlExpressionFactory sqlExpressionFactory)
IReadOnlyList<SqlExpression> arguments,
IDiagnosticsLogger<DbLoggerCategory.Query> logger)
{
if (method != IsMatch && method != IsMatchWithRegexOptions)
if (method.DeclaringType != typeof(Regex)
|| method.Name != nameof(Regex.IsMatch))
{
return null;
}

if (arguments is not ([_, _] or [_, _, _]))
{
return null;
}
Expand All @@ -44,7 +44,7 @@ public class CosmosRegexTranslator(ISqlExpressionFactory sqlExpressionFactory)
sqlExpressionFactory.ApplyTypeMapping(input, typeMapping),
sqlExpressionFactory.ApplyTypeMapping(pattern, typeMapping));

if (method == IsMatch || arguments[2] is SqlConstantExpression { Value: RegexOptions.None })
if (arguments.Count == 2 || arguments[2] is SqlConstantExpression { Value: RegexOptions.None })
{
return sqlExpressionFactory.Function("RegexMatch", [input, pattern], typeof(bool));
}
Expand Down
Loading
Loading