Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion PLAN.md
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,7 @@ FROM (
WHERE total_amount > 1000;
```

### Rule: aggregate free having condition
### Rule: aggregate free `having` condition

```sql
select a from t group by a having a > 10;
Expand All @@ -600,6 +600,21 @@ select a from t group by a having a > 10;
select a from t where a > 10 group by a;
```

### Rule: conflicting function and aggregate definitions

```sql
create function foo(int) returns int as $$
select $1 * 2;
$$ language sql;

create aggregate foo(int) (
sfunc = int4pl,
stype = int,
initcond = '0'
);
-- Query 1 ERROR at Line 1: : ERROR: function "foo" already exists with same argument types
```

### Rule: order direction is redundent

```sql
Expand Down
176 changes: 176 additions & 0 deletions crates/squawk_ide/src/hover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ pub fn hover(file: &ast::SourceFile, offset: TextSize) -> Option<String> {
return hover_function(file, &name_ref, &binder);
}

if is_aggregate_ref(&name_ref) {
return hover_aggregate(file, &name_ref, &binder);
}

if is_select_function_call(&name_ref) {
// Try function first, but fall back to column if no function found
// (handles function-call-style column access like `select a(t)`)
Expand Down Expand Up @@ -85,6 +89,14 @@ pub fn hover(file: &ast::SourceFile, offset: TextSize) -> Option<String> {
return format_create_function(&create_function, &binder);
}

if let Some(create_aggregate) = name
.syntax()
.ancestors()
.find_map(ast::CreateAggregate::cast)
{
return format_create_aggregate(&create_aggregate, &binder);
}

if let Some(create_schema) = name.syntax().ancestors().find_map(ast::CreateSchema::cast) {
return format_create_schema(&create_schema);
}
Expand Down Expand Up @@ -353,6 +365,15 @@ fn is_function_ref(name_ref: &ast::NameRef) -> bool {
false
}

fn is_aggregate_ref(name_ref: &ast::NameRef) -> bool {
for ancestor in name_ref.syntax().ancestors() {
if ast::DropAggregate::can_cast(ancestor.kind()) {
return true;
}
}
false
}

fn is_select_function_call(name_ref: &ast::NameRef) -> bool {
let mut in_call_expr = false;
let mut in_arg_list = false;
Expand Down Expand Up @@ -498,6 +519,53 @@ fn function_schema(
search_path.first().map(|s| s.to_string())
}

fn hover_aggregate(
file: &ast::SourceFile,
name_ref: &ast::NameRef,
binder: &binder::Binder,
) -> Option<String> {
let aggregate_ptr = resolve::resolve_name_ref(binder, name_ref)?;

let root = file.syntax();
let aggregate_name_node = aggregate_ptr.to_node(root);

let create_aggregate = aggregate_name_node
.ancestors()
.find_map(ast::CreateAggregate::cast)?;

format_create_aggregate(&create_aggregate, binder)
}

fn format_create_aggregate(
create_aggregate: &ast::CreateAggregate,
binder: &binder::Binder,
) -> Option<String> {
let path = create_aggregate.path()?;
let segment = path.segment()?;
let name = segment.name()?;
let aggregate_name = name.syntax().text().to_string();

let schema = if let Some(qualifier) = path.qualifier() {
qualifier.syntax().text().to_string()
} else {
aggregate_schema(create_aggregate, binder)?
};

let param_list = create_aggregate.param_list()?;
let params = param_list.syntax().text().to_string();

Some(format!("aggregate {}.{}{}", schema, aggregate_name, params))
}

fn aggregate_schema(
create_aggregate: &ast::CreateAggregate,
binder: &binder::Binder,
) -> Option<String> {
let position = create_aggregate.syntax().text_range().start();
let search_path = binder.search_path_at(position);
search_path.first().map(|s| s.to_string())
}

#[cfg(test)]
mod test {
use crate::hover::hover;
Expand Down Expand Up @@ -1003,6 +1071,114 @@ drop function foo$0();
");
}

#[test]
fn hover_on_drop_function_overloaded() {
assert_snapshot!(check_hover("
create function add(complex) returns complex as $$ select null $$ language sql;
create function add(bigint) returns bigint as $$ select 1 $$ language sql;
drop function add$0(complex);
"), @r"
hover: function public.add(complex) returns complex
╭▸
4 │ drop function add(complex);
╰╴ ─ hover
");
}

#[test]
fn hover_on_drop_function_second_overload() {
assert_snapshot!(check_hover("
create function add(complex) returns complex as $$ select null $$ language sql;
create function add(bigint) returns bigint as $$ select 1 $$ language sql;
drop function add$0(bigint);
"), @r"
hover: function public.add(bigint) returns bigint
╭▸
4 │ drop function add(bigint);
╰╴ ─ hover
");
}

#[test]
fn hover_on_drop_aggregate() {
assert_snapshot!(check_hover("
create aggregate myavg(int) (sfunc = int4_avg_accum, stype = _int8);
drop aggregate myavg$0(int);
"), @r"
hover: aggregate public.myavg(int)
╭▸
3 │ drop aggregate myavg(int);
╰╴ ─ hover
");
}

#[test]
fn hover_on_drop_aggregate_with_schema() {
assert_snapshot!(check_hover("
create aggregate myschema.myavg(int) (sfunc = int4_avg_accum, stype = _int8);
drop aggregate myschema.myavg$0(int);
"), @r"
hover: aggregate myschema.myavg(int)
╭▸
3 │ drop aggregate myschema.myavg(int);
╰╴ ─ hover
");
}

#[test]
fn hover_on_create_aggregate_definition() {
assert_snapshot!(check_hover("
create aggregate myavg$0(int) (sfunc = int4_avg_accum, stype = _int8);
"), @r"
hover: aggregate public.myavg(int)
╭▸
2 │ create aggregate myavg(int) (sfunc = int4_avg_accum, stype = _int8);
╰╴ ─ hover
");
}

#[test]
fn hover_on_drop_aggregate_with_search_path() {
assert_snapshot!(check_hover(r#"
set search_path to myschema;
create aggregate myavg(int) (sfunc = int4_avg_accum, stype = _int8);
drop aggregate myavg$0(int);
"#), @r"
hover: aggregate myschema.myavg(int)
╭▸
4 │ drop aggregate myavg(int);
╰╴ ─ hover
");
}

#[test]
fn hover_on_drop_aggregate_overloaded() {
assert_snapshot!(check_hover("
create aggregate sum(complex) (sfunc = complex_add, stype = complex, initcond = '(0,0)');
create aggregate sum(bigint) (sfunc = bigint_add, stype = bigint, initcond = '0');
drop aggregate sum$0(complex);
"), @r"
hover: aggregate public.sum(complex)
╭▸
4 │ drop aggregate sum(complex);
╰╴ ─ hover
");
}

#[test]
fn hover_on_drop_aggregate_second_overload() {
assert_snapshot!(check_hover("
create aggregate sum(complex) (sfunc = complex_add, stype = complex, initcond = '(0,0)');
create aggregate sum(bigint) (sfunc = bigint_add, stype = bigint, initcond = '0');
drop aggregate sum$0(bigint);
"), @r"
hover: aggregate public.sum(bigint)
╭▸
4 │ drop aggregate sum(bigint);
╰╴ ─ hover
");
}

#[test]
fn hover_on_select_function_call() {
assert_snapshot!(check_hover("
Expand Down
Loading