Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pyrefly/lib/lsp/non_wasm/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -802,6 +802,7 @@ pub fn capabilities(
code_action_kinds: Some(vec![
CodeActionKind::QUICKFIX,
CodeActionKind::REFACTOR_EXTRACT,
CodeActionKind::REFACTOR_REWRITE,
CodeActionKind::new("refactor.move"),
CodeActionKind::REFACTOR_INLINE,
]),
Expand Down Expand Up @@ -3189,6 +3190,9 @@ impl Server {
if let Some(refactors) = transaction.introduce_parameter_code_actions(&handle, range) {
push_refactor_actions(refactors);
}
if let Some(refactors) = transaction.convert_star_import_code_actions(&handle, range) {
push_refactor_actions(refactors);
}
if let Some(action) =
convert_module_package_code_actions(&self.initialize_params.capabilities, uri)
{
Expand Down
7 changes: 7 additions & 0 deletions pyrefly/lib/state/lsp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2071,6 +2071,13 @@ impl<'a> Transaction<'a> {
) -> Option<Vec<LocalRefactorCodeAction>> {
quick_fixes::introduce_parameter::introduce_parameter_code_actions(self, handle, selection)
}
pub fn convert_star_import_code_actions(
&self,
handle: &Handle,
selection: TextRange,
) -> Option<Vec<LocalRefactorCodeAction>> {
quick_fixes::convert_star_import::convert_star_import_code_actions(self, handle, selection)
}

/// Determines whether a module is a third-party package.
///
Expand Down
174 changes: 174 additions & 0 deletions pyrefly/lib/state/lsp/quick_fixes/convert_star_import.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

use std::collections::BTreeSet;

use dupe::Dupe;
use lsp_types::CodeActionKind;
use pyrefly_build::handle::Handle;
use pyrefly_python::module_name::ModuleName;
use pyrefly_python::short_identifier::ShortIdentifier;
use ruff_python_ast::Expr;
use ruff_python_ast::ExprContext;
use ruff_python_ast::ModModule;
use ruff_python_ast::Stmt;
use ruff_python_ast::StmtImportFrom;
use ruff_python_ast::visitor::Visitor;
use ruff_text_size::Ranged;
use ruff_text_size::TextRange;
use ruff_text_size::TextSize;

use super::extract_shared::line_indent_and_start;
use super::extract_shared::selection_anchor;
use crate::binding::binding::Key;
use crate::state::ide::IntermediateDefinition;
use crate::state::ide::key_to_intermediate_definition;
use crate::state::lsp::LocalRefactorCodeAction;
use crate::state::lsp::Transaction;

fn rewrite_kind() -> CodeActionKind {
CodeActionKind::new("refactor.rewrite")
}

/// Builds convert-star-import refactor actions for the supplied selection.
pub(crate) fn convert_star_import_code_actions(
transaction: &Transaction<'_>,
handle: &Handle,
selection: TextRange,
) -> Option<Vec<LocalRefactorCodeAction>> {
let module_info = transaction.get_module_info(handle)?;
let ast = transaction.get_ast(handle)?;
let source = module_info.contents();
let selection_point = selection_anchor(source, selection);
let (import_from, star_range) = find_star_import(ast.as_ref(), selection_point)?;
let module_name = resolve_import_module_name(&module_info, import_from)?;
let bindings = transaction.get_bindings(handle)?;

let names = collect_star_imported_names(ast.as_ref(), &bindings, module_name, star_range);
if names.is_empty() {
return None;
}

let (indent, line_start) = line_indent_and_start(source, import_from.range().start())?;
let line_end = line_end_position(source, import_from.range().end());
let line_range = TextRange::new(line_start, line_end);
let line_text =
&source[line_start.to_usize().min(source.len())..line_end.to_usize().min(source.len())];
let comment = trailing_comment(line_text);

let from_module = import_from_module_text(import_from);
let import_list = names.join(", ");
let mut replacement = format!("{indent}from {from_module} import {import_list}");
if let Some(comment) = comment {
replacement.push(' ');
replacement.push_str(comment.trim_start());
}
replacement.push('\n');

Some(vec![LocalRefactorCodeAction {
title: format!(
"Convert to explicit imports from `{}`",
module_name.as_str()
),
edits: vec![(module_info.dupe(), line_range, replacement)],
kind: rewrite_kind(),
}])
}

fn find_star_import<'a>(
ast: &'a ModModule,
selection: TextSize,
) -> Option<(&'a StmtImportFrom, TextRange)> {
ast.body.iter().find_map(|stmt| match stmt {
Stmt::ImportFrom(import_from) if import_from.range().contains(selection) => {
let star = import_from.names.iter().find(|alias| &alias.name == "*")?;
Some((import_from, star.range))
}
_ => None,
})
}

fn resolve_import_module_name(
module_info: &pyrefly_python::module::Module,
import_from: &StmtImportFrom,
) -> Option<ModuleName> {
module_info.name().new_maybe_relative(
module_info.path().is_init(),
import_from.level,
import_from.module.as_ref().map(|module| &module.id),
)
}

fn import_from_module_text(import_from: &StmtImportFrom) -> String {
let mut module_text = ".".repeat(import_from.level as usize);
if let Some(module) = &import_from.module {
module_text.push_str(module.id.as_str());
}
module_text
}

fn collect_star_imported_names(
ast: &ModModule,
bindings: &crate::binding::bindings::Bindings,
module_name: ModuleName,
star_range: TextRange,
) -> Vec<String> {
struct NameCollector<'a> {
bindings: &'a crate::binding::bindings::Bindings,
module_name: ModuleName,
star_range: TextRange,
names: BTreeSet<String>,
}

impl<'a> Visitor<'a> for NameCollector<'a> {
fn visit_expr(&mut self, expr: &'a Expr) {
if let Expr::Name(name) = expr
&& matches!(name.ctx, ExprContext::Load | ExprContext::Del)
{
let key = Key::BoundName(ShortIdentifier::expr_name(name));
if self.bindings.is_valid_key(&key)
&& let Some(intermediate) = key_to_intermediate_definition(self.bindings, &key)
&& let IntermediateDefinition::NamedImport(
import_range,
import_module,
import_name,
_,
) = intermediate
&& import_range == self.star_range
&& import_module == self.module_name
{
self.names.insert(import_name.as_str().to_owned());
}
}
ruff_python_ast::visitor::walk_expr(self, expr);
}
}

let mut collector = NameCollector {
bindings,
module_name,
star_range,
names: BTreeSet::new(),
};
collector.visit_body(&ast.body);
collector.names.into_iter().collect()
}

fn trailing_comment(line: &str) -> Option<&str> {
let trimmed = line.strip_suffix("\n").unwrap_or(line);
let trimmed = trimmed.strip_suffix("\r").unwrap_or(trimmed);
trimmed.find('#').map(|idx| &trimmed[idx..])
}

fn line_end_position(source: &str, position: TextSize) -> TextSize {
let idx = position.to_usize().min(source.len());
if let Some(offset) = source[idx..].find('\n') {
TextSize::try_from(idx + offset + 1).unwrap_or(position)
} else {
TextSize::try_from(source.len()).unwrap_or(position)
}
}
2 changes: 1 addition & 1 deletion pyrefly/lib/state/lsp/quick_fixes/extract_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ use super::extract_shared::first_parameter_name;
use super::extract_shared::is_static_or_class_method;
use super::extract_shared::line_indent_and_start;
use super::extract_shared::validate_non_empty_selection;
use super::types::LocalRefactorCodeAction;
use crate::state::lsp::FindPreference;
use crate::state::lsp::LocalRefactorCodeAction;
use crate::state::lsp::Transaction;

const HELPER_INDENT: &str = " ";
Expand Down
1 change: 1 addition & 0 deletions pyrefly/lib/state/lsp/quick_fixes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
* LICENSE file in the root directory of this source tree.
*/

pub(crate) mod convert_star_import;
pub(crate) mod extract_field;
pub(crate) mod extract_function;
mod extract_shared;
Expand Down
139 changes: 139 additions & 0 deletions pyrefly/lib/test/lsp/code_actions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,29 @@ fn compute_make_top_level_actions(
(module_info, edit_sets, titles)
}

fn compute_convert_star_import_actions(
code_by_module: &[(&'static str, &str)],
module_name: &'static str,
selection: TextRange,
) -> (
ModuleInfo,
Vec<Vec<(Module, TextRange, String)>>,
Vec<String>,
) {
let (handles, state) =
mk_multi_file_state_assert_no_errors(code_by_module, Require::Everything);
let handle = handles.get(module_name).unwrap();
let transaction = state.transaction();
let module_info = transaction.get_module_info(handle).unwrap();
let actions = transaction
.convert_star_import_code_actions(handle, selection)
.unwrap_or_default();
let edit_sets: Vec<Vec<(Module, TextRange, String)>> =
actions.iter().map(|action| action.edits.clone()).collect();
let titles = actions.iter().map(|action| action.title.clone()).collect();
(module_info, edit_sets, titles)
}

fn compute_pull_up_actions(
code: &str,
) -> (
Expand Down Expand Up @@ -1410,6 +1433,122 @@ class C:
assert_eq!(expected.trim(), updated.trim());
}

#[test]
fn convert_star_import_basic() {
let code_main = r#"
# CONVERT-START
from foo import * # noqa: F401
# CONVERT-END
a = A
b = B
"#;
let code_foo = r#"
A = 1
B = 2
C = 3
"#;
let selection = find_marked_range_with(code_main, "# CONVERT-START", "# CONVERT-END");
let (module_info, actions, titles) = compute_convert_star_import_actions(
&[("main", code_main), ("foo", code_foo)],
"main",
selection,
);
assert_eq!(vec!["Convert to explicit imports from `foo`"], titles);
let updated = apply_refactor_edits_for_module(&module_info, &actions[0]);
let expected = r#"
# CONVERT-START
from foo import A, B # noqa: F401
# CONVERT-END
a = A
b = B
"#;
assert_eq!(expected.trim(), updated.trim());
}

#[test]
fn convert_star_import_relative() {
let code_main = r#"
# CONVERT-START
from .foo import *
# CONVERT-END
x = A
"#;
let code_foo = r#"
A = 1
"#;
let selection = find_marked_range_with(code_main, "# CONVERT-START", "# CONVERT-END");
let (module_info, actions, titles) = compute_convert_star_import_actions(
&[("pkg.main", code_main), ("pkg.foo", code_foo)],
"pkg.main",
selection,
);
assert_eq!(vec!["Convert to explicit imports from `pkg.foo`"], titles);
let updated = apply_refactor_edits_for_module(&module_info, &actions[0]);
let expected = r#"
# CONVERT-START
from .foo import A
# CONVERT-END
x = A
"#;
assert_eq!(expected.trim(), updated.trim());
}

#[test]
fn convert_star_import_selects_correct_import() {
let code_main = r#"
# CONVERT-START
from foo import *
# CONVERT-END
from bar import *
a = A
b = B
"#;
let code_foo = r#"
A = 1
"#;
let code_bar = r#"
B = 2
"#;
let selection = find_marked_range_with(code_main, "# CONVERT-START", "# CONVERT-END");
let (module_info, actions, titles) = compute_convert_star_import_actions(
&[("main", code_main), ("foo", code_foo), ("bar", code_bar)],
"main",
selection,
);
assert_eq!(vec!["Convert to explicit imports from `foo`"], titles);
let updated = apply_refactor_edits_for_module(&module_info, &actions[0]);
let expected = r#"
# CONVERT-START
from foo import A
# CONVERT-END
from bar import *
a = A
b = B
"#;
assert_eq!(expected.trim(), updated.trim());
}

#[test]
fn convert_star_import_no_action_when_unused() {
let code_main = r#"
# CONVERT-START
from foo import *
# CONVERT-END
x = 1
"#;
let code_foo = r#"
A = 1
"#;
let selection = find_marked_range_with(code_main, "# CONVERT-START", "# CONVERT-END");
let (_module_info, actions, titles) = compute_convert_star_import_actions(
&[("main", code_main), ("foo", code_foo)],
"main",
selection,
);
assert!(actions.is_empty());
assert!(titles.is_empty());
}

#[test]
fn extract_variable_name_increments_when_taken() {
let code = r#"
Expand Down
2 changes: 1 addition & 1 deletion pyrefly/lib/test/lsp/lsp_interaction/basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ fn test_initialize_basic() {
"definitionProvider": true,
"typeDefinitionProvider": true,
"codeActionProvider": {
"codeActionKinds": ["quickfix", "refactor.extract", "refactor.move", "refactor.inline"]
"codeActionKinds": ["quickfix", "refactor.extract", "refactor.rewrite", "refactor.move", "refactor.inline"]
},
"completionProvider": {
"triggerCharacters": [".", "'", "\""]
Expand Down
Loading