Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions crates/emmylua_code_analysis/src/compilation/test/flow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1866,6 +1866,27 @@ end
assert_eq!(b, b_expected);
}

#[test]
fn test_feature_const_local_alias_chain_does_not_inherit_flow() {
let mut ws = VirtualWorkspace::new_with_init_std_lib();

ws.def(
r#"
local ret --- @type string | nil

local is_string = type(ret) == "string"
local ok = is_string
if ok then
a = ret
end
"#,
);

let a = ws.expr_ty("a");
let a_expected = ws.ty("string?");
assert_eq!(a, a_expected);
}

#[test]
fn test_feature_generic_type_guard() {
let mut ws = VirtualWorkspace::new();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -368,4 +368,99 @@ mod test {

assert_eq!(ws.expr_ty("result"), ws.ty("never"));
}

#[test]
fn test_rawget_guard_narrows_matching_index_expr() {
let mut ws = VirtualWorkspace::new_with_init_std_lib();

ws.def(
r#"
---@class T
---@field x? integer

---@type T
local t = {}

if rawget(t, "x") then
result = t.x
end
"#,
);

assert_eq!(ws.expr_ty("result"), LuaType::Integer);
}

#[test]
fn test_type_guard_call_narrows_matching_index_expr() {
let mut ws = VirtualWorkspace::new();

ws.def(
r#"
---@generic T
---@param inst any
---@param type `T`
---@return TypeGuard<T>
local function instance_of(inst, type)
return true
end

---@class T
---@field x? string|integer

---@type T
local t = {}

if instance_of(t.x, "string") then
result = t.x
end
"#,
);

assert_eq!(ws.expr_ty("result"), LuaType::String);
}

#[test]
fn test_alias_predicate_guard_narrows_matching_index_expr() {
let mut ws = VirtualWorkspace::new();

ws.def(
r#"
---@class T
---@field x? integer

---@type T
local t = {}

local ok = t.x ~= nil
if ok then
result = t.x
end
"#,
);

assert_eq!(ws.expr_ty("result"), LuaType::Integer);
}

#[test]
fn test_alias_chain_predicate_guard_keeps_matching_index_expr_wide() {
let mut ws = VirtualWorkspace::new();

ws.def(
r#"
---@class T
---@field x? integer

---@type T
local t = {}

local has_x = t.x ~= nil
local ok = has_x
if ok then
result = t.x
end
"#,
);

assert_eq!(ws.expr_ty("result"), ws.ty("integer?"));
}
}
45 changes: 36 additions & 9 deletions crates/emmylua_code_analysis/src/semantic/cache/mod.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
mod cache_options;

pub use cache_options::{CacheOptions, LuaAnalysisPhase};
use emmylua_parser::LuaSyntaxId;
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use emmylua_parser::{LuaExpr, LuaSyntaxId, LuaVarExpr};
use hashbrown::{HashMap, HashSet};
use std::{rc::Rc, sync::Arc};

use crate::{
FileId, FlowId, LuaFunctionType,
Expand All @@ -19,16 +17,35 @@ pub enum CacheEntry<T> {
Cache(T),
}

#[derive(Debug, Clone)]
pub(in crate::semantic) struct FlowConditionInfo {
pub expr: LuaExpr,
pub index_var_ref_id: Option<VarRefId>,
pub index_prefix_var_ref_id: Option<VarRefId>,
}

#[derive(Debug, Clone)]
pub(in crate::semantic) struct FlowAssignmentInfo {
pub vars: Vec<LuaVarExpr>,
pub exprs: Vec<LuaExpr>,
pub var_ref_ids: Vec<Option<VarRefId>>,
}

#[derive(Debug)]
pub struct LuaInferCache {
file_id: FileId,
config: CacheOptions,
pub expr_cache: HashMap<LuaSyntaxId, CacheEntry<LuaType>>,
pub call_cache:
HashMap<(LuaSyntaxId, Option<usize>, LuaType), CacheEntry<Arc<LuaFunctionType>>>,
pub(crate) flow_node_cache: HashMap<(VarRefId, FlowId, bool), CacheEntry<LuaType>>,
pub(in crate::semantic) flow_cache_var_ref_ids: HashMap<VarRefId, u32>,
pub(in crate::semantic) next_flow_cache_var_ref_id: u32,
pub(crate) flow_node_cache: Vec<HashMap<u32, [Option<CacheEntry<LuaType>>; 2]>>,
pub(in crate::semantic) flow_branch_antecedent_cache: Vec<Option<Rc<Vec<FlowId>>>>,
pub(in crate::semantic) flow_condition_info_cache: Vec<Option<Rc<FlowConditionInfo>>>,
pub(in crate::semantic) flow_assignment_info_cache: Vec<Option<Rc<FlowAssignmentInfo>>>,
pub(in crate::semantic) condition_flow_cache:
HashMap<(VarRefId, FlowId, bool), CacheEntry<ConditionFlowAction>>,
Vec<HashMap<u32, [Option<CacheEntry<ConditionFlowAction>>; 2]>>,
pub index_ref_origin_type_cache: HashMap<VarRefId, CacheEntry<LuaType>>,
pub expr_var_ref_id_cache: HashMap<LuaSyntaxId, VarRefId>,
pub narrow_by_literal_stop_position_cache: HashSet<LuaSyntaxId>,
Expand All @@ -41,8 +58,13 @@ impl LuaInferCache {
config,
expr_cache: HashMap::new(),
call_cache: HashMap::new(),
flow_node_cache: HashMap::new(),
condition_flow_cache: HashMap::new(),
flow_cache_var_ref_ids: HashMap::new(),
next_flow_cache_var_ref_id: 0,
flow_node_cache: Vec::new(),
flow_branch_antecedent_cache: Vec::new(),
flow_condition_info_cache: Vec::new(),
flow_assignment_info_cache: Vec::new(),
condition_flow_cache: Vec::new(),
index_ref_origin_type_cache: HashMap::new(),
expr_var_ref_id_cache: HashMap::new(),
narrow_by_literal_stop_position_cache: HashSet::new(),
Expand All @@ -64,7 +86,12 @@ impl LuaInferCache {
pub fn clear(&mut self) {
self.expr_cache.clear();
self.call_cache.clear();
self.flow_cache_var_ref_ids.clear();
self.next_flow_cache_var_ref_id = 0;
self.flow_node_cache.clear();
self.flow_branch_antecedent_cache.clear();
self.flow_condition_info_cache.clear();
self.flow_assignment_info_cache.clear();
self.condition_flow_cache.clear();
self.index_ref_origin_type_cache.clear();
self.expr_var_ref_id_cache.clear();
Expand Down
137 changes: 131 additions & 6 deletions crates/emmylua_code_analysis/src/semantic/infer/infer_index/mod.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
mod infer_array;

use std::collections::HashSet;

use emmylua_parser::{
LuaExpr, LuaIndexExpr, LuaIndexKey, LuaIndexMemberExpr, NumberResult, PathTrait,
LuaAstNode, LuaChunk, LuaExpr, LuaIndexExpr, LuaIndexKey, LuaIndexMemberExpr, NumberResult,
PathTrait,
};
use hashbrown::HashSet;
use internment::ArcIntern;
use rowan::TextRange;
use smol_str::SmolStr;

use crate::{
CacheEntry, GenericTpl, InFiled, InferGuardRef, LuaAliasCallKind, LuaDeclOrMemberId,
LuaInferCache, LuaInstanceType, LuaMemberOwner, LuaOperatorOwner, TypeOps,
CacheEntry, FlowAntecedent, FlowId, FlowNode, FlowNodeKind, FlowTree, GenericTpl, InFiled,
InferGuardRef, LuaAliasCallKind, LuaDeclOrMemberId, LuaInferCache, LuaInstanceType,
LuaMemberOwner, LuaOperatorOwner, TypeOps,
db_index::{
DbIndex, LuaGenericType, LuaIntersectionType, LuaMemberKey, LuaObjectType,
LuaOperatorMetaMethod, LuaTupleType, LuaType, LuaTypeDeclId, LuaUnionType,
Expand All @@ -24,7 +25,11 @@ use crate::{
VarRefId,
infer_index::infer_array::{check_iter_var_range, infer_array_member},
infer_name::get_name_expr_var_ref_id,
narrow::infer_expr_narrow_type,
narrow::{
ConditionFlowAction, InferConditionFlow, get_condition_flow_action,
get_flow_assignment_info, get_flow_cache_var_ref_id, get_flow_condition_info,
get_var_expr_var_ref_id, infer_expr_narrow_type,
},
},
member::get_buildin_type_map_type_id,
member::intersect_member_types,
Expand Down Expand Up @@ -107,13 +112,133 @@ fn infer_member_type_pass_flow(
cache
.index_ref_origin_type_cache
.insert(var_ref_id.clone(), CacheEntry::Cache(member_type.clone()));
let file_id = cache.get_file_id();
if let Some(flow_tree) = db.get_flow_index().get_flow_tree(&file_id)
&& let Some(flow_id) = flow_tree.get_flow_id(index_expr.get_syntax_id())
&& let Some(root) = LuaChunk::cast(index_expr.get_root())
&& matches!(
has_direct_index_flow_effect(db, cache, flow_tree, &root, &var_ref_id, flow_id),
Ok(false)
)
{
return Ok(member_type);
}

let result = infer_expr_narrow_type(db, cache, LuaExpr::IndexExpr(index_expr), var_ref_id);
match &result {
Err(InferFailReason::None) => Ok(member_type.clone()),
_ => result,
}
}

fn has_direct_index_flow_effect(
db: &DbIndex,
cache: &mut LuaInferCache,
tree: &FlowTree,
root: &LuaChunk,
var_ref_id: &VarRefId,
start_flow_id: FlowId,
) -> Result<bool, InferFailReason> {
let mut pending = vec![start_flow_id];
let mut visited_labels = HashSet::new();
let var_ref_cache_id = get_flow_cache_var_ref_id(cache, var_ref_id);

while let Some(flow_id) = pending.pop() {
let flow_node = tree.get_flow_node(flow_id).ok_or(InferFailReason::None)?;
match &flow_node.kind {
FlowNodeKind::Start | FlowNodeKind::Unreachable => {}
FlowNodeKind::BranchLabel | FlowNodeKind::NamedLabel(_) => {
if visited_labels.insert(flow_id) {
extend_flow_antecedents(tree, flow_node, &mut pending)?;
}
}
FlowNodeKind::Assignment(assign_ptr) => {
let assignment_info =
get_flow_assignment_info(db, cache, root, flow_node.id, assign_ptr)?;
if assignment_info
.var_ref_ids
.iter()
.flatten()
.any(|assignment_var_ref_id| assignment_var_ref_id == var_ref_id)
{
return Ok(true);
}
extend_flow_antecedents(tree, flow_node, &mut pending)?;
}
FlowNodeKind::TrueCondition(condition_ptr)
| FlowNodeKind::FalseCondition(condition_ptr) => {
let condition_info =
get_flow_condition_info(db, cache, root, flow_node.id, condition_ptr)?;
let condition_flow = if matches!(&flow_node.kind, FlowNodeKind::TrueCondition(_)) {
InferConditionFlow::TrueCondition
} else {
InferConditionFlow::FalseCondition
};
let condition_action = get_condition_flow_action(
db,
tree,
cache,
root,
var_ref_id,
var_ref_cache_id,
flow_node,
&condition_info,
condition_flow,
)?;
if !matches!(condition_action, ConditionFlowAction::Continue) {
return Ok(true);
}
extend_flow_antecedents(tree, flow_node, &mut pending)?;
}
FlowNodeKind::ImplFunc(func_ptr) => {
let func_stat = func_ptr.to_node(root).ok_or(InferFailReason::None)?;
if let Some(func_name) = func_stat.get_func_name()
&& get_var_expr_var_ref_id(db, cache, func_name.to_expr()).as_ref()
== Some(var_ref_id)
{
return Ok(true);
}
extend_flow_antecedents(tree, flow_node, &mut pending)?;
}
FlowNodeKind::TagCast(tag_cast_ptr) => {
let tag_cast = tag_cast_ptr.to_node(root).ok_or(InferFailReason::None)?;
if let Some(key_expr) = tag_cast.get_key_expr()
&& get_var_expr_var_ref_id(db, cache, key_expr).as_ref() == Some(var_ref_id)
{
return Ok(true);
}
extend_flow_antecedents(tree, flow_node, &mut pending)?;
}
FlowNodeKind::LoopLabel
| FlowNodeKind::DeclPosition(_)
| FlowNodeKind::ForIStat(_)
| FlowNodeKind::Break
| FlowNodeKind::Return => {
extend_flow_antecedents(tree, flow_node, &mut pending)?;
}
}
}

Ok(false)
}

fn extend_flow_antecedents(
tree: &FlowTree,
flow_node: &FlowNode,
pending: &mut Vec<FlowId>,
) -> Result<(), InferFailReason> {
match flow_node.antecedent.as_ref() {
Some(FlowAntecedent::Single(flow_id)) => pending.push(*flow_id),
Some(FlowAntecedent::Multiple(multi_id)) => pending.extend(
tree.get_multi_antecedents(*multi_id)
.ok_or(InferFailReason::None)?,
),
None => {}
}

Ok(())
}

pub fn get_index_expr_var_ref_id(
db: &DbIndex,
cache: &mut LuaInferCache,
Expand Down
Loading
Loading