Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
461 changes: 240 additions & 221 deletions be/src/exprs/function/function_regexp.cpp

Large diffs are not rendered by default.

151 changes: 151 additions & 0 deletions be/test/exprs/function/function_like_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include "core/column/column_string.h"
#include "core/column/column_vector.h"
#include "core/data_type/data_type_array.h"
#include "core/data_type/data_type_nullable.h"
#include "core/data_type/data_type_number.h"
#include "core/data_type/data_type_string.h"
Expand Down Expand Up @@ -248,6 +249,156 @@ TEST(FunctionLikeTest, regexp_extract_all) {
}
}

TEST(FunctionLikeTest, regexp_extract_all_array) {
std::string func_name = "regexp_extract_all_array";
auto str_type = std::make_shared<DataTypeString>();
auto return_type = make_nullable(
std::make_shared<DataTypeArray>(make_nullable(std::make_shared<DataTypeString>())));

auto run_case = [&](const std::string& str, const std::string& pattern,
const std::string& expected, bool expect_null = false) {
auto col_str = ColumnString::create();
col_str->insert_data(str.data(), str.size());
auto col_pattern = ColumnString::create();
col_pattern->insert_data(pattern.data(), pattern.size());

Block block;
block.insert({std::move(col_str), str_type, "str"});
block.insert({ColumnConst::create(std::move(col_pattern), 1), str_type, "pattern"});
block.insert({nullptr, return_type, "result"});

ColumnsWithTypeAndName arg_cols = {block.get_by_position(0), block.get_by_position(1)};
auto func =
SimpleFunctionFactory::instance().get_function(func_name, arg_cols, return_type);
ASSERT_TRUE(func != nullptr);

std::vector<DataTypePtr> arg_types = {str_type, str_type};
FunctionUtils fn_utils({}, arg_types, false);
auto* fn_ctx = fn_utils.get_fn_ctx();
fn_ctx->set_constant_cols(
{nullptr, std::make_shared<ColumnPtrWrapper>(block.get_by_position(1).column)});

ASSERT_EQ(Status::OK(), func->open(fn_ctx, FunctionContext::FRAGMENT_LOCAL));
ASSERT_EQ(Status::OK(), func->open(fn_ctx, FunctionContext::THREAD_LOCAL));
ASSERT_EQ(Status::OK(), func->execute(fn_ctx, block, {0, 1}, 2, 1));

auto result_col = block.get_by_position(2).column;
ASSERT_TRUE(result_col.get() != nullptr);
if (expect_null) {
EXPECT_TRUE(result_col->is_null_at(0));
} else {
ASSERT_FALSE(result_col->is_null_at(0));
auto result_str = return_type->to_string(*result_col, 0);
EXPECT_EQ(expected, result_str)
<< "input: '" << str << "', pattern: '" << pattern << "'";
}

static_cast<void>(func->close(fn_ctx, FunctionContext::THREAD_LOCAL));
static_cast<void>(func->close(fn_ctx, FunctionContext::FRAGMENT_LOCAL));
};

run_case("x=a3&x=18abc&x=2&y=3&x=4&x=17bcd", "x=([0-9]+)([a-z]+)", "[\"18\", \"17\"]");
run_case("x=a3&x=18abc&x=2&y=3&x=4", "^x=([a-z]+)([0-9]+)", "[\"a\"]");
run_case("http://a.m.baidu.com/i41915173660.htm", "i([0-9]+)", "[\"41915173660\"]");
run_case("http://a.m.baidu.com/i41915i73660.htm", "i([0-9]+)", "[\"41915\", \"73660\"]");
run_case("hitdecisiondlist", "(i)(.*?)(e)", "[\"i\"]");
run_case("no_match_here", "x=([0-9]+)", "[]");
run_case("abc", "([a-z]+)", "[\"abc\"]");

// Helper for testing null input propagation
auto nullable_str_type = make_nullable(str_type);
auto run_null_case = [&](bool null_str, bool null_pattern) {
ColumnPtr col_str;
DataTypePtr str_col_type;
if (null_str) {
auto col = ColumnNullable::create(ColumnString::create(), ColumnUInt8::create());
col->insert_default();
col_str = std::move(col);
str_col_type = nullable_str_type;
} else {
auto col = ColumnString::create();
col->insert_data("abc", 3);
col_str = std::move(col);
str_col_type = str_type;
}

ColumnPtr col_pattern;
DataTypePtr pattern_col_type;
if (null_pattern) {
auto col = ColumnNullable::create(ColumnString::create(), ColumnUInt8::create());
col->insert_default();
col_pattern = ColumnConst::create(std::move(col), 1);
pattern_col_type = nullable_str_type;
} else {
auto col = ColumnString::create();
col->insert_data("([a-z]+)", 8);
col_pattern = ColumnConst::create(std::move(col), 1);
pattern_col_type = str_type;
}

Block block;
block.insert({col_str, str_col_type, "str"});
block.insert({col_pattern, pattern_col_type, "pattern"});
block.insert({nullptr, return_type, "result"});

ColumnsWithTypeAndName arg_cols = {block.get_by_position(0), block.get_by_position(1)};
auto func =
SimpleFunctionFactory::instance().get_function(func_name, arg_cols, return_type);
ASSERT_TRUE(func != nullptr);

std::vector<DataTypePtr> arg_types = {str_col_type, pattern_col_type};
FunctionUtils fn_utils({}, arg_types, false);
auto* fn_ctx = fn_utils.get_fn_ctx();
fn_ctx->set_constant_cols(
{nullptr, std::make_shared<ColumnPtrWrapper>(block.get_by_position(1).column)});

ASSERT_EQ(Status::OK(), func->open(fn_ctx, FunctionContext::FRAGMENT_LOCAL));
ASSERT_EQ(Status::OK(), func->open(fn_ctx, FunctionContext::THREAD_LOCAL));
ASSERT_EQ(Status::OK(), func->execute(fn_ctx, block, {0, 1}, 2, 1));

EXPECT_TRUE(block.get_by_position(2).column->is_null_at(0))
<< "Expected null for null_str=" << null_str << " null_pattern=" << null_pattern;

static_cast<void>(func->close(fn_ctx, FunctionContext::THREAD_LOCAL));
static_cast<void>(func->close(fn_ctx, FunctionContext::FRAGMENT_LOCAL));
};

// NULL input string → null result
run_null_case(true, false);
// NULL pattern → null result
run_null_case(false, true);

// Invalid const pattern → open() should fail
{
auto col_str = ColumnString::create();
col_str->insert_data("abc", 3);
auto col_pattern = ColumnString::create();
col_pattern->insert_data("(", 1);
Block block;
block.insert({std::move(col_str), str_type, "str"});
block.insert({ColumnConst::create(std::move(col_pattern), 1), str_type, "pattern"});
block.insert({nullptr, return_type, "result"});

ColumnsWithTypeAndName arg_cols = {block.get_by_position(0), block.get_by_position(1)};
auto func =
SimpleFunctionFactory::instance().get_function(func_name, arg_cols, return_type);
ASSERT_TRUE(func != nullptr);

std::vector<DataTypePtr> arg_types = {str_type, str_type};
FunctionUtils fn_utils({}, arg_types, false);
auto* fn_ctx = fn_utils.get_fn_ctx();
fn_ctx->set_constant_cols(
{nullptr, std::make_shared<ColumnPtrWrapper>(block.get_by_position(1).column)});

ASSERT_EQ(Status::OK(), func->open(fn_ctx, FunctionContext::FRAGMENT_LOCAL));
// Invalid pattern should cause open() to fail for THREAD_LOCAL scope
EXPECT_NE(Status::OK(), func->open(fn_ctx, FunctionContext::THREAD_LOCAL));

static_cast<void>(func->close(fn_ctx, FunctionContext::THREAD_LOCAL));
static_cast<void>(func->close(fn_ctx, FunctionContext::FRAGMENT_LOCAL));
}
}
Comment on lines +252 to +400
Copy link

Copilot AI Mar 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new regexp_extract_all_array unit test only covers non-null inputs with a constant pattern. Since the function is marked always-nullable and should propagate NULLs / handle invalid patterns, add a couple of cases that assert (1) NULL input string, (2) NULL pattern, and ideally (3) invalid pattern behavior (error or NULL result depending on constness) so the null/exception semantics are covered at the unit-test level too.

Copilot uses AI. Check for mistakes.

TEST(FunctionLikeTest, regexp_replace) {
std::string func_name = "regexp_replace";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,7 @@
import org.apache.doris.nereids.trees.expressions.functions.scalar.RegexpCount;
import org.apache.doris.nereids.trees.expressions.functions.scalar.RegexpExtract;
import org.apache.doris.nereids.trees.expressions.functions.scalar.RegexpExtractAll;
import org.apache.doris.nereids.trees.expressions.functions.scalar.RegexpExtractAllArray;
import org.apache.doris.nereids.trees.expressions.functions.scalar.RegexpExtractOrNull;
import org.apache.doris.nereids.trees.expressions.functions.scalar.RegexpReplace;
import org.apache.doris.nereids.trees.expressions.functions.scalar.RegexpReplaceOne;
Expand Down Expand Up @@ -986,6 +987,7 @@ public class BuiltinScalarFunctions implements FunctionHelper {
scalar(RegexpCount.class, "regexp_count"),
scalar(RegexpExtract.class, "regexp_extract"),
scalar(RegexpExtractAll.class, "regexp_extract_all"),
scalar(RegexpExtractAllArray.class, "regexp_extract_all_array"),
scalar(RegexpExtractOrNull.class, "regexp_extract_or_null"),
scalar(RegexpReplace.class, "regexp_replace"),
scalar(RegexpReplaceOne.class, "regexp_replace_one"),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.trees.expressions.functions.scalar;

import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullLiteral;
import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.StringType;
import org.apache.doris.nereids.types.VarcharType;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;

import java.util.List;

/**
* ScalarFunction 'regexp_extract_all_array'.
* Returns all matches of a regex pattern as an Array&lt;String&gt; instead of a string-formatted array.
*/
public class RegexpExtractAllArray extends ScalarFunction
implements BinaryExpression, ExplicitlyCastableSignature, AlwaysNullable, PropagateNullLiteral {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(ArrayType.of(VarcharType.SYSTEM_DEFAULT))
.args(VarcharType.SYSTEM_DEFAULT, VarcharType.SYSTEM_DEFAULT),
FunctionSignature.ret(ArrayType.of(StringType.INSTANCE))
.args(StringType.INSTANCE, StringType.INSTANCE)
);

/**
* constructor with 2 arguments.
*/
public RegexpExtractAllArray(Expression arg0, Expression arg1) {
super("regexp_extract_all_array", arg0, arg1);
}

/** constructor for withChildren and reuse signature */
private RegexpExtractAllArray(ScalarFunctionParams functionParams) {
super(functionParams);
}

/**
* withChildren.
*/
@Override
public RegexpExtractAllArray withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() == 2);
return new RegexpExtractAllArray(getFunctionParams(children));
}

@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitRegexpExtractAllArray(this, context);
}

@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,7 @@
import org.apache.doris.nereids.trees.expressions.functions.scalar.RegexpCount;
import org.apache.doris.nereids.trees.expressions.functions.scalar.RegexpExtract;
import org.apache.doris.nereids.trees.expressions.functions.scalar.RegexpExtractAll;
import org.apache.doris.nereids.trees.expressions.functions.scalar.RegexpExtractAllArray;
import org.apache.doris.nereids.trees.expressions.functions.scalar.RegexpExtractOrNull;
import org.apache.doris.nereids.trees.expressions.functions.scalar.RegexpReplace;
import org.apache.doris.nereids.trees.expressions.functions.scalar.RegexpReplaceOne;
Expand Down Expand Up @@ -2160,6 +2161,10 @@ default R visitRegexpExtractAll(RegexpExtractAll regexpExtractAll, C context) {
return visitScalarFunction(regexpExtractAll, context);
}

default R visitRegexpExtractAllArray(RegexpExtractAllArray regexpExtractAllArray, C context) {
return visitScalarFunction(regexpExtractAllArray, context);
}

default R visitRegexpExtractOrNull(RegexpExtractOrNull regexpExtractOrNull, C context) {
return visitScalarFunction(regexpExtractOrNull, context);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,62 @@ aXb
-- !sql_regexp_extract_all_10 --
['aXb','cXd']

-- !regexp_extract_all_array_1 --
["18", "17"]

-- !regexp_extract_all_array_2 --
["41915", "73660"]

-- !regexp_extract_all_array_3 --
["abc", "def", "ghi"]

-- !regexp_extract_all_array_4 --
[]

-- !regexp_extract_all_array_5 --
\N

-- !regexp_extract_all_array_6 --
\N

-- !regexp_extract_all_array_7 --
["ab", "c", "c", "c"]

-- !regexp_extract_all_array_8 --
\N
[]
["Emmy", "eillish"]
["It", "s", "ok"]
["It", "s", "true"]
["billie", "eillish"]
["billie", "eillish"]

-- !regexp_extract_all_array_9 --
\N
[]
["mmy", "eillish"]
["t", "s", "ok"]
["t", "s", "true"]
["billie", "eillish"]
["billie", "eillish"]

-- !regexp_extract_all_array_10 --
\N 5 \N
6 []
Emmy eillish 3 ["Emmy", "eillish"]
It's ok 2 ["It", "s", "ok"]
It's true 4 ["It", "s", "true"]
billie eillish \N ["billie", "eillish"]
billie eillish 1 ["billie", "eillish"]

-- !regexp_extract_all_array_11 --
[]
[]
[]
[]
[]
[]

-- !sql --
a-b-c

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,18 @@ suite("test_string_function_regexp") {
qt_sql_regexp_extract_all_9 "SELECT REGEXP_EXTRACT_ALL(concat('aXb', char(10), 'cXd'), '(?-s)(\\\\w.\\\\w)');"
qt_sql_regexp_extract_all_10 "SELECT REGEXP_EXTRACT_ALL(concat('aXb', char(10), 'cXd'), '(\\\\w.\\\\w)');"

qt_regexp_extract_all_array_1 "SELECT regexp_extract_all_array('x=a3&x=18abc&x=2&y=3&x=4&x=17bcd', 'x=([0-9]+)([a-z]+)');"
qt_regexp_extract_all_array_2 "SELECT regexp_extract_all_array('http://a.m.baidu.com/i41915i73660.htm', 'i([0-9]+)');"
qt_regexp_extract_all_array_3 "SELECT regexp_extract_all_array('abc=111, def=222, ghi=333', '(\"[^\"]+\"|\\\\w+)=(\"[^\"]+\"|\\\\w+)');"
qt_regexp_extract_all_array_4 "select regexp_extract_all_array('xxfs','f');"
qt_regexp_extract_all_array_5 "select regexp_extract_all_array(NULL, 'pattern');"
qt_regexp_extract_all_array_6 "select regexp_extract_all_array('text', NULL);"
qt_regexp_extract_all_array_7 "select regexp_extract_all_array('abcdfesscca', '(ab|c|)');"
qt_regexp_extract_all_array_8 "SELECT regexp_extract_all_array(k, '(\\\\w+)') from test_string_function_regexp ORDER BY k;"
qt_regexp_extract_all_array_9 "SELECT regexp_extract_all_array(k, '([a-z]+)') from test_string_function_regexp ORDER BY k;"
qt_regexp_extract_all_array_10 "SELECT k, v, regexp_extract_all_array(k, '(\\\\w+)') from test_string_function_regexp ORDER BY k;"
qt_regexp_extract_all_array_11 "SELECT regexp_extract_all_array(k, concat('^', k)) from test_string_function_regexp WHERE k IS NOT NULL ORDER BY k;"

qt_sql "SELECT regexp_replace('a b c', \" \", \"-\");"
qt_sql "SELECT regexp_replace('a b c','(b)','<\\\\1>');"

Expand Down
Loading