Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions crates/squawk_ide/src/goto_definition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -998,4 +998,47 @@ drop function foo(), bar$0();
╰╴ ─ 1. source
");
}

#[test]
fn goto_select_function_call() {
assert_snapshot!(goto("
create function foo() returns int as $$ select 1 $$ language sql;
select foo$0();
"), @r"
╭▸
2 │ create function foo() returns int as $$ select 1 $$ language sql;
│ ─── 2. destination
3 │ select foo();
╰╴ ─ 1. source
");
}

#[test]
fn goto_select_function_call_with_schema() {
assert_snapshot!(goto("
create function public.foo() returns int as $$ select 1 $$ language sql;
select public.foo$0();
"), @r"
╭▸
2 │ create function public.foo() returns int as $$ select 1 $$ language sql;
│ ─── 2. destination
3 │ select public.foo();
╰╴ ─ 1. source
");
}

#[test]
fn goto_select_function_call_with_search_path() {
assert_snapshot!(goto("
set search_path to myschema;
create function foo() returns int as $$ select 1 $$ language sql;
select myschema.foo$0();
"), @r"
╭▸
3 │ create function foo() returns int as $$ select 1 $$ language sql;
│ ─── 2. destination
4 │ select myschema.foo();
╰╴ ─ 1. source
");
}
}
71 changes: 71 additions & 0 deletions crates/squawk_ide/src/hover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ pub fn hover(file: &ast::SourceFile, offset: TextSize) -> Option<String> {
if is_function_ref(&name_ref) {
return hover_function(file, &name_ref, &binder);
}

if is_select_function_call(&name_ref) {
return hover_function(file, &name_ref, &binder);
}
}

if let Some(name) = ast::Name::cast(parent) {
Expand Down Expand Up @@ -299,6 +303,20 @@ fn is_function_ref(name_ref: &ast::NameRef) -> bool {
false
}

fn is_select_function_call(name_ref: &ast::NameRef) -> bool {
let mut in_call_expr = false;

for ancestor in name_ref.syntax().ancestors() {
if ast::CallExpr::can_cast(ancestor.kind()) {
in_call_expr = true;
}
if ast::Select::can_cast(ancestor.kind()) && in_call_expr {
return true;
}
}
false
}

fn hover_function(
file: &ast::SourceFile,
name_ref: &ast::NameRef,
Expand Down Expand Up @@ -856,4 +874,57 @@ drop function foo$0();
╰╴ ─ hover
");
}

#[test]
fn hover_on_select_function_call() {
assert_snapshot!(check_hover("
create function foo() returns int as $$ select 1 $$ language sql;
select foo$0();
"), @r"
hover: function public.foo() returns int
╭▸
3 │ select foo();
╰╴ ─ hover
");
}

#[test]
fn hover_on_select_function_call_with_schema() {
assert_snapshot!(check_hover("
create function public.foo() returns int as $$ select 1 $$ language sql;
select public.foo$0();
"), @r"
hover: function public.foo() returns int
╭▸
3 │ select public.foo();
╰╴ ─ hover
");
}

#[test]
fn hover_on_select_function_call_with_search_path() {
assert_snapshot!(check_hover(r#"
set search_path to myschema;
create function foo() returns int as $$ select 1 $$ language sql;
select foo$0();
"#), @r"
hover: function myschema.foo() returns int
╭▸
4 │ select foo();
╰╴ ─ hover
");
}

#[test]
fn hover_on_select_function_call_with_params() {
assert_snapshot!(check_hover("
create function add(a int, b int) returns int as $$ select a + b $$ language sql;
select add$0(1, 2);
"), @r"
hover: function public.add(a int, b int) returns int
╭▸
3 │ select add(1, 2);
╰╴ ─ hover
");
}
}
23 changes: 23 additions & 0 deletions crates/squawk_ide/src/resolve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ enum NameRefContext {
DropFunction,
CreateIndex,
CreateIndexColumn,
SelectFunctionCall,
}

pub(crate) fn resolve_name_ref(binder: &Binder, name_ref: &ast::NameRef) -> Option<SyntaxNodePtr> {
Expand Down Expand Up @@ -44,12 +45,28 @@ pub(crate) fn resolve_name_ref(binder: &Binder, name_ref: &ast::NameRef) -> Opti
let position = name_ref.syntax().text_range().start();
resolve_function(binder, &function_name, &schema, position)
}
NameRefContext::SelectFunctionCall => {
let schema = if let Some(parent_node) = name_ref.syntax().parent()
&& let Some(field_expr) = ast::FieldExpr::cast(parent_node)
{
let base = field_expr.base()?;
let schema_name_ref = ast::NameRef::cast(base.syntax().clone())?;
let schema_text = schema_name_ref.syntax().text().to_string();
Some(Schema(Name::new(schema_text)))
} else {
None
};
let function_name = Name::new(name_ref.syntax().text().to_string());
let position = name_ref.syntax().text_range().start();
resolve_function(binder, &function_name, &schema, position)
}
NameRefContext::CreateIndexColumn => resolve_create_index_column(binder, name_ref),
}
}

fn classify_name_ref_context(name_ref: &ast::NameRef) -> Option<NameRefContext> {
let mut in_partition_item = false;
let mut in_call_expr = false;

for ancestor in name_ref.syntax().ancestors() {
if ast::DropTable::can_cast(ancestor.kind()) {
Expand All @@ -73,6 +90,12 @@ fn classify_name_ref_context(name_ref: &ast::NameRef) -> Option<NameRefContext>
}
return Some(NameRefContext::CreateIndex);
}
if ast::CallExpr::can_cast(ancestor.kind()) {
in_call_expr = true;
}
if ast::Select::can_cast(ancestor.kind()) && in_call_expr {
return Some(NameRefContext::SelectFunctionCall);
}
}

None
Expand Down
Loading