Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package io.substrait.isthmus.expression;

import io.substrait.expression.Expression;
import io.substrait.expression.FunctionArg;
import io.substrait.extension.SimpleExtension.ScalarFunctionVariant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;

/**
* Custom mapping for the Calcite {@code POSITION} function to the Substrait {@code strpos}
* function. Calcite also represents the SQL {@code STRPOS} function as {@code POSITION}.
*
* <p>Calcite {@code POSITION} has <em>substring</em> followed by <em>input</em> parameters, while
* Substrait {@code strpos} has <em>input</em> followed by <em>substring</em>. When mapping between
* Calcite and Substrait, the parameters need to be reversed.
*
* <p>{@code POSITION(substring IN input)} maps to {@code strpos(input, substring)}.
*/
final class PositionFunctionMapper implements ScalarFunctionMapper {
/** Name of the Substrait function that this mapper produces and recognises. */
private static final String strposFunctionName = "strpos";
/** The function variants named {@code strpos}, selected from the list given to the constructor. */
private final List<ScalarFunctionVariant> strposFunctions;

public PositionFunctionMapper(List<ScalarFunctionVariant> functions) {
strposFunctions =
functions.stream()
.filter(f -> strposFunctionName.equals(f.name()))
.collect(Collectors.toUnmodifiableList());
}

/**
* Maps a Calcite {@code POSITION} call to the Substrait {@code strpos} function, exchanging the
* first two operands.
*
* @param call a Calcite call.
* @return a mapping to {@code strpos}, or an empty Optional if the call's operator is not
*     {@code SqlStdOperatorTable.POSITION}.
*/
@Override
public Optional<SubstraitFunctionMapping> toSubstrait(final RexCall call) {
if (!SqlStdOperatorTable.POSITION.equals(call.op)) {
return Optional.empty();
}

// Swap on a copy rather than on the list returned by the call.
List<RexNode> operands = new ArrayList<>(call.getOperands());
// Calcite POSITION: index 0 = substring, index 1 = input.
// Substrait strpos expects (input, substring), so the two are exchanged.
Collections.swap(operands, 0, 1);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably not a bad idea to put a comment here just because we have some naked numbers used to reference parameters.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the class Javadoc not sufficient?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's probably fine, just my opinion if we wanted to make it extra clear. Feel free to ignore if you disagree.


return Optional.of(new SubstraitFunctionMapping(strposFunctionName, operands, strposFunctions));
}

/**
 * Supplies the arguments of a Substrait {@code strpos} invocation in the order used by Calcite
 * {@code POSITION}, i.e. with the first two arguments exchanged.
 *
 * @param expression a Substrait scalar function invocation.
 * @return the reordered arguments, or an empty Optional if the invocation is not {@code strpos}.
 */
@Override
public Optional<List<FunctionArg>> getExpressionArguments(
    final Expression.ScalarFunctionInvocation expression) {
  final String declaredName = expression.declaration().name();
  if (strposFunctionName.equals(declaredName)) {
    // Substrait strpos is (input, substring); Calcite POSITION wants (substring, input).
    final List<FunctionArg> reordered = new ArrayList<>(expression.arguments());
    Collections.swap(reordered, 0, 1);
    return Optional.of(reordered);
  }
  return Optional.empty();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ public ScalarFunctionConverter(
List.of(
new TrimFunctionMapper(functions),
new SqrtFunctionMapper(functions),
new ExtractDateFunctionMapper(functions));
new ExtractDateFunctionMapper(functions),
new PositionFunctionMapper(functions));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,23 @@
import org.apache.calcite.rex.RexCall;

/**
* Provides custom conversion for a Calcite call to corresponding Substrait functions and arguments.
* Provides custom conversion between a Calcite call and corresponding Substrait functions and
* arguments.
*/
interface ScalarFunctionMapper {

/**
* If the supplied call is applicable to this mapper, get the custom mapping to the corresponding
* Substrait function.
* If the supplied Calcite call is applicable to this mapper, get the custom mapping to the
* corresponding Substrait function.
*
* @param call a Calcite call.
* @return a custom function mapping, or an empty Optional if no mapping exists.
*/
Optional<SubstraitFunctionMapping> toSubstrait(RexCall call);

/**
* If the supplied expression is applicable to this mapper, get the function arguments that should
* be used for the Substrait function call.
* If the supplied Substrait expression is applicable to this mapper, get the function arguments
* that should be used when mapping to the corresponding Calcite function.
*
* @param expression an expression.
* @return a list of function arguments, or an empty Optional if no mapping exists.
Expand Down
81 changes: 44 additions & 37 deletions isthmus/src/test/java/io/substrait/isthmus/StringFunctionTest.java
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
package io.substrait.isthmus;

import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;

import io.substrait.expression.Expression;
import io.substrait.expression.Expression.FixedCharLiteral;
import io.substrait.isthmus.sql.SubstraitCreateStatementParser;
import io.substrait.plan.Plan;
import io.substrait.relation.Project;
import java.util.List;
import org.apache.calcite.prepare.CalciteCatalogReader;
import org.apache.calcite.sql.parser.SqlParseException;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.ValueSource;
Expand Down Expand Up @@ -144,9 +152,7 @@ private void assertSqlRoundTrip(String sql) throws SqlParseException {
@ParameterizedTest
@CsvSource({"c16, c16", "c16, vc32", "c16, vc", "vc32, vc32", "vc32, vc", "vc, vc"})
void testStarts_With(String left, String right) throws Exception {

String query = String.format("SELECT STARTS_WITH(%s, %s) FROM strings", left, right);

assertSqlRoundTrip(query);
}

Expand All @@ -162,9 +168,7 @@ void testStarts_WithLiteral(String left, String right) throws Exception {
@ParameterizedTest
@CsvSource({"c16, c16", "c16, vc32", "c16, vc", "vc32, vc32", "vc32, vc", "vc, vc"})
void testStartsWith(String left, String right) throws Exception {

String query = String.format("SELECT STARTSWITH(%s, %s) FROM strings", left, right);

assertSqlRoundTrip(query);
}

Expand All @@ -180,9 +184,7 @@ void testStartsWithLiteral(String left, String right) throws Exception {
@ParameterizedTest
@CsvSource({"c16, c16", "c16, vc32", "c16, vc", "vc32, vc32", "vc32, vc", "vc, vc"})
void testEnds_With(String left, String right) throws Exception {

String query = String.format("SELECT ENDS_WITH(%s, %s) FROM strings", left, right);

assertSqlRoundTrip(query);
}

Expand All @@ -198,9 +200,7 @@ void testEnds_WithLiteral(String left, String right) throws Exception {
@ParameterizedTest
@CsvSource({"c16, c16", "c16, vc32", "c16, vc", "vc32, vc32", "vc32, vc", "vc, vc"})
void testEndsWith(String left, String right) throws Exception {

String query = String.format("SELECT ENDSWITH(%s, %s) FROM strings", left, right);

assertSqlRoundTrip(query);
}

Expand All @@ -216,9 +216,7 @@ void testEndsWithLiteral(String left, String right) throws Exception {
@ParameterizedTest
@CsvSource({"c16, c16", "c16, vc32", "c16, vc", "vc32, vc32", "vc32, vc", "vc, vc"})
void testContains(String left, String right) throws Exception {

String query = String.format("SELECT CONTAINS_SUBSTR(%s, %s) FROM strings", left, right);

assertSqlRoundTrip(query);
}

Expand All @@ -227,87 +225,96 @@ void testContains(String left, String right) throws Exception {
value = {"'start', vc", "vc, 'end'"},
quoteCharacter = '`')
void testContainsWithLiteral(String left, String right) throws Exception {

String query = String.format("SELECT CONTAINS_SUBSTR(%s, %s) FROM strings", left, right);

assertSqlRoundTrip(query);
}

@ParameterizedTest
@CsvSource({"c16, c16", "c16, vc32", "c16, vc", "vc32, vc32", "vc32, vc", "vc, vc"})
void testPosition(String left, String right) throws Exception {

String query = String.format("SELECT POSITION(%s IN %s) > 0 FROM strings", left, right);

void testPosition(String substring, String input) throws Exception {
String query = String.format("SELECT POSITION(%s IN %s) FROM strings", substring, input);
assertSqlRoundTrip(query);
}

@ParameterizedTest
@CsvSource(
value = {"'start', vc", "vc, 'end'"},
value = {"'substring', vc", "vc, 'string'"},
quoteCharacter = '`')
void testPositionWithLiteral(String left, String right) throws Exception {

String query = String.format("SELECT POSITION(%s IN %s) > 0 FROM strings", left, right);

void testPositionWithLiteral(String substring, String input) throws Exception {
String query = String.format("SELECT POSITION(%s IN %s) FROM strings", substring, input);
assertSqlRoundTrip(query);
}

@ParameterizedTest
@CsvSource({"c16, c16", "c16, vc32", "c16, vc", "vc32, vc32", "vc32, vc", "vc, vc"})
void testStrpos(String left, String right) throws Exception {

String query = String.format("SELECT STRPOS(%s, %s) > 0 FROM strings", left, right);

void testStrpos(String input, String substring) throws Exception {
String query = String.format("SELECT STRPOS(%s, %s) FROM strings", input, substring);
assertSqlRoundTrip(query);
}

@ParameterizedTest
@CsvSource(
value = {"'start', vc", "vc, 'end'"},
value = {"vc, 'substring'", "'string', vc"},
quoteCharacter = '`')
void testStrposWithLiteral(String left, String right) throws Exception {
void testStrposWithLiteral(String input, String substring) throws Exception {
String query = String.format("SELECT STRPOS(%s, %s) FROM strings", input, substring);
assertSqlRoundTrip(query);
}

String query = String.format("SELECT STRPOS(%s, %s) > 0 FROM strings", left, right);
@Test
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
@Test
@Test
// Verifies POSITION→strpos parameter swap (#662).
// SQL POSITION(substring IN input) must become Substrait strpos(input, substring).

Just so someone has a reference for the future :)

// Calcite POSITION(substring in input) maps to Substrait strpos(input, substring).
// Calcite represents STRPOS as POSITION so this test covers both functions.
void testPositionParameterOrdering() throws Exception {
String input = "input";
String substring = "substring";
String sql = String.format("SELECT POSITION('%s' in '%s') FROM strings", substring, input);
CalciteCatalogReader catalog =
SubstraitCreateStatementParser.processCreateStatementsToCatalog(CREATES);

assertSqlRoundTrip(query);
Plan plan = new SqlToSubstrait().convert(sql, catalog);

List<String> expected = List.of(input, substring);

Plan.Root root = plan.getRoots().stream().findFirst().orElseThrow();
Project project = (Project) root.getInput();
Expression.ScalarFunctionInvocation strpos =
(Expression.ScalarFunctionInvocation) project.getExpressions().get(0);
List<String> actual =
strpos.arguments().stream()
.map(arg -> (FixedCharLiteral) arg)
.map(FixedCharLiteral::value)
.toList();

assertEquals(expected, actual);
}

@ParameterizedTest
@CsvSource({"vc32, i32", "vc, i32"})
void testLeft(String left, String right) throws Exception {

String query = String.format("SELECT LEFT(%s, %s) FROM int_num_strings", left, right);

assertFullRoundTrip(query, CHAR_INT_CREATES);
}

@ParameterizedTest
@CsvSource({"vc32, i32", "vc, i32"})
void testRight(String left, String right) throws Exception {

String query = String.format("SELECT RIGHT(%s, %s) FROM int_num_strings", left, right);

assertFullRoundTrip(query, CHAR_INT_CREATES);
}

@ParameterizedTest
@CsvSource({"vc32, i32, vc32", "vc, i32, vc"})
void testRpad(String left, String center, String right) throws Exception {

String query =
String.format("SELECT RPAD(%s, %s, %s) FROM int_num_strings", left, center, right);

assertFullRoundTrip(query, CHAR_INT_CREATES);
}

@ParameterizedTest
@CsvSource({"vc32, i32, vc32", "vc, i32, vc"})
void testLpad(String left, String center, String right) throws Exception {

String query =
String.format("SELECT LPAD(%s, %s, %s) FROM int_num_strings", left, center, right);

assertFullRoundTrip(query, CHAR_INT_CREATES);
}
}