Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion src/analyze/annot_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,26 @@ impl<'tcx> AnnotFnTranslator<'tcx> {
.next()
.is_some()
{
let pred = refine::user_defined_pred(self.tcx, def_id);
let param_env = self
.tcx
.param_env(self.local_def_id)
.with_reveal_all_normalized(self.tcx);
let generic_args = self.typeck.node_args(func_expr.hir_id);
let generic_args = mir_ty::EarlyBinder::bind(generic_args)
.instantiate(self.tcx, self.generic_args);
let instance = mir_ty::Instance::resolve(
self.tcx,
param_env,
def_id,
generic_args,
)
.unwrap();
let pred_def_id = if let Some(instance) = instance {
instance.def_id()
} else {
def_id
};
let pred = refine::user_defined_pred(self.tcx, pred_def_id);
let arg_terms = args.iter().map(|e| self.to_term(e)).collect();
let atom = chc::Atom::new(pred.into(), arg_terms);
return FormulaOrTerm::Formula(chc::Formula::Atom(atom));
Expand Down
95 changes: 57 additions & 38 deletions src/analyze/basic_block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -420,12 +420,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
_ty,
) => {
let func_ty = match operand.const_fn_def() {
Some((def_id, args)) => self
.ctx
.def_ty_with_args(def_id, args)
.expect("unknown def")
.ty
.clone(),
Some((def_id, args)) => self.fn_def_ty(def_id, args),
_ => unimplemented!(),
};
PlaceType::with_ty_and_term(func_ty.vacuous(), chc::Term::null())
Expand Down Expand Up @@ -573,44 +568,68 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
});
}

fn resolve_fn_def(
&self,
def_id: DefId,
args: mir_ty::GenericArgsRef<'tcx>,
) -> (DefId, mir_ty::GenericArgsRef<'tcx>) {
if self.ctx.is_fn_trait_method(def_id) {
// When calling a closure via `Fn`/`FnMut`/`FnOnce` trait,
// we simply replace the def_id with the closure's function def_id.
// This skips shims, and makes self arguments mismatch. visitor::RustCallVisitor
// adjusts the arguments accordingly.
let mir_ty::TyKind::Closure(closure_def_id, _) = args.type_at(0).kind() else {
panic!("expected closure arg for fn trait");
};
tracing::debug!(?closure_def_id, "closure instance");
(*closure_def_id, args)
} else {
let param_env = self
.tcx
.param_env(self.local_def_id)
.with_reveal_all_normalized(self.tcx);
let instance = mir_ty::Instance::resolve(self.tcx, param_env, def_id, args).unwrap();
if let Some(instance) = instance {
(instance.def_id(), instance.args)
} else {
(def_id, args)
}
}
}

fn fn_def_ty(
&mut self,
def_id: DefId,
args: mir_ty::GenericArgsRef<'tcx>,
) -> rty::Type<rty::Closed> {
if let Some(def_ty) = self.ctx.def_ty_with_args(def_id, args) {
return def_ty.ty;
}

let (resolved_def_id, resolved_args) = self.resolve_fn_def(def_id, args);
if resolved_def_id == def_id {
panic!(
"unknown def (and not resolved): {:?}, args: {:?}",
def_id, args
);
}
tracing::info!(?def_id, ?resolved_def_id, ?resolved_args, "resolved");
let Some(def_ty) = self.ctx.def_ty_with_args(resolved_def_id, resolved_args) else {
panic!(
"unknown def (resolved): {:?}, args: {:?}",
resolved_def_id, resolved_args
);
};
def_ty.ty
}

fn type_call<I>(&mut self, func: Operand<'tcx>, args: I, expected_ret: &rty::RefinedType<Var>)
where
I: IntoIterator<Item = Operand<'tcx>>,
{
// TODO: handle const_fn_def on Env side
let func_ty = if let Some((def_id, args)) = func.const_fn_def() {
let (resolved_def_id, resolved_args) = if self.ctx.is_fn_trait_method(def_id) {
// When calling a closure via `Fn`/`FnMut`/`FnOnce` trait,
// we simply replace the def_id with the closure's function def_id.
// This skips shims, and makes self arguments mismatch. visitor::RustCallVisitor
// adjusts the arguments accordingly.
let mir_ty::TyKind::Closure(closure_def_id, _) = args.type_at(0).kind() else {
panic!("expected closure arg for fn trait");
};
tracing::debug!(?closure_def_id, "closure instance");
(*closure_def_id, args)
} else {
let param_env = self
.tcx
.param_env(self.local_def_id)
.with_reveal_all_normalized(self.tcx);
let instance =
mir_ty::Instance::resolve(self.tcx, param_env, def_id, args).unwrap();
if let Some(instance) = instance {
(instance.def_id(), instance.args)
} else {
(def_id, args)
}
};
if def_id != resolved_def_id {
tracing::info!(?def_id, ?resolved_def_id, ?resolved_args, "resolved");
}

self.ctx
.def_ty_with_args(resolved_def_id, resolved_args)
.expect("unknown def")
.ty
.vacuous()
self.fn_def_ty(def_id, args).vacuous()
} else {
self.operand_type(func.clone()).ty
};
Expand Down
57 changes: 42 additions & 15 deletions src/analyze/local_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,14 +117,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {

pub fn analyze_predicate_definition(&self, local_def_id: LocalDefId) {
// predicate's name
// TODO: simply use refine::user_defined_pred for all functions
// after we dropped old annotation parser impl
let impl_type = self.impl_type();
let pred_item_name = self.tcx.item_name(local_def_id.to_def_id()).to_string();
let pred = match impl_type {
Some(t) => chc::UserDefinedPred::new(t.to_string() + "_" + &pred_item_name),
None => refine::user_defined_pred(self.tcx, local_def_id.to_def_id()),
};
let pred = refine::user_defined_pred(self.tcx, local_def_id.to_def_id());

// function's body
use rustc_hir::{Block, Expr, ExprKind};
Expand Down Expand Up @@ -276,17 +269,41 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
|| (all_params_annotated && has_ret)
}

pub fn trait_item_id(&self) -> Option<LocalDefId> {
pub fn local_trait_item_id(&self) -> Option<LocalDefId> {
let impl_item_assoc = self
.tcx
.opt_associated_item(self.local_def_id.to_def_id())?;
let trait_item_id = impl_item_assoc
.trait_item_def_id
.and_then(|id| id.as_local())?;

if trait_item_id == self.local_def_id {
return None;
}

Some(trait_item_id)
}

pub fn trait_item_ty(&mut self) -> Option<rty::RefinedType> {
let impl_did = self.tcx.parent(self.local_def_id.to_def_id());

if self.tcx.def_kind(impl_did) != (rustc_hir::def::DefKind::Impl { of_trait: true }) {
return None;
}

let trait_ref = self.tcx.impl_trait_ref(impl_did)?.instantiate_identity();
let trait_item_did = self
.tcx
.associated_item(self.local_def_id.to_def_id())
.trait_item_def_id
.unwrap();
self.ctx.def_ty_with_args(trait_item_did, trait_ref.args)
}

// Note that we do not expect predicate variables to be generated here
// when type params are still present in the type. Callers should ensure either
// - type params are fully instantiated, or
// - the function is fully annotated
pub fn expected_ty(&mut self) -> rty::RefinedType {
let sig = self
.ctx
Expand Down Expand Up @@ -324,7 +341,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
self.generic_args,
);

if let Some(trait_item_id) = self.trait_item_id() {
if let Some(trait_item_id) = self.local_trait_item_id() {
tracing::info!("trait item found: {:?}", trait_item_id);
let trait_require_annot = self.ctx.extract_require_annot(
trait_item_id,
Expand Down Expand Up @@ -364,6 +381,9 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
assert!(require_annot.is_none() || param_annots.is_empty());
assert!(ensure_annot.is_none() || ret_annot.is_none());

let trait_item_ty = self.trait_item_ty();
let is_fully_annotated = self.is_fully_annotated();

let mut builder = self.type_builder.for_function_template(&mut self.ctx, sig);
if let Some(AnnotFormula::Formula(require)) = require_annot {
let formula = require.map_var(|idx| {
Expand All @@ -387,11 +407,18 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
builder.ret_rty(ret_rty);
}

// Note that we do not expect predicate variables to be generated here
// when type params are still present in the type. Callers should ensure either
// - type params are fully instantiated, or
// - the function is fully annotated
rty::RefinedType::unrefined(builder.build().into())
if is_fully_annotated {
let expected_ty = builder.build().into();
if let Some(trait_item_ty) = trait_item_ty {
let clauses = rty::relate_sub_closed_type(&expected_ty, &trait_item_ty.ty);
self.ctx.extend_clauses(clauses);
}
rty::RefinedType::unrefined(expected_ty)
} else if let Some(trait_item_ty) = trait_item_ty {
trait_item_ty
} else {
rty::RefinedType::unrefined(builder.build().into())
}
}

/// Extract the target DefId from `#[thrust::extern_spec_fn]` function.
Expand Down
26 changes: 13 additions & 13 deletions std.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ mod thrust_models {
}

#[thrust::def::mut_model]
pub struct Mut<T>(PhantomData<T>);
pub struct Mut<T: ?Sized>(PhantomData<T>);

impl<T> Mut<T> {
#[allow(dead_code)]
Expand Down Expand Up @@ -100,7 +100,7 @@ mod thrust_models {
}

#[thrust::def::box_model]
pub struct Box<T>(PhantomData<T>);
pub struct Box<T: ?Sized>(PhantomData<T>);

impl<T> Box<T> {
#[allow(dead_code)]
Expand Down Expand Up @@ -128,7 +128,7 @@ mod thrust_models {
}

#[thrust::def::array_model]
pub struct Array<I, T>(PhantomData<I>, PhantomData<T>);
pub struct Array<I: ?Sized, T: ?Sized>(PhantomData<I>, PhantomData<T>);

impl<I, T, U> PartialEq<U> for Array<I, T> where U: super::Model<Ty = Self> {
#[thrust::ignored]
Expand Down Expand Up @@ -156,9 +156,9 @@ mod thrust_models {
}

#[thrust::def::closure_model]
pub struct Closure<T>(PhantomData<T>);
pub struct Closure<T: ?Sized>(PhantomData<T>);

pub struct Vec<T>(pub Array<Int, T>, pub Int);
pub struct Vec<T: ?Sized>(pub Array<Int, T>, pub Int);

impl<T, U> PartialEq<U> for Vec<T> where U: super::Model<Ty = Self> {
#[thrust::ignored]
Expand Down Expand Up @@ -200,7 +200,7 @@ mod thrust_models {
type Ty = bool;
}

impl<T> Model for model::Closure<T> {
impl<T: ?Sized> Model for model::Closure<T> {
type Ty = model::Closure<T>;
}

Expand All @@ -224,35 +224,35 @@ mod thrust_models {
impl_tuple_model!(T0, T1, T2, T3, T4, T5, T6, T7, T8);
impl_tuple_model!(T0, T1, T2, T3, T4, T5, T6, T7, T8, T9);

impl<'a, T> Model for &'a mut T where T: Model {
impl<'a, T: ?Sized> Model for &'a mut T where T: Model {
type Ty = model::Mut<<T as Model>::Ty>;
}

impl<T> Model for model::Mut<T> {
impl<T: ?Sized> Model for model::Mut<T> {
type Ty = model::Mut<T>;
}

impl<'a, T> Model for &'a T where T: Model {
impl<'a, T: ?Sized> Model for &'a T where T: Model {
type Ty = &'a <T as Model>::Ty;
}

impl<T> Model for Box<T> where T: Model {
impl<T: ?Sized> Model for Box<T> where T: Model {
type Ty = model::Box<<T as Model>::Ty>;
}

impl<T> Model for model::Box<T> {
impl<T: ?Sized> Model for model::Box<T> {
type Ty = model::Box<T>;
}

impl<I, T> Model for model::Array<I, T> {
impl<I: ?Sized, T: ?Sized> Model for model::Array<I, T> {
type Ty = model::Array<I, T>;
}

impl<T> Model for Vec<T> where T: Model {
type Ty = model::Vec<<T as Model>::Ty>;
}

impl<T> Model for model::Vec<T> {
impl<T: ?Sized> Model for model::Vec<T> {
type Ty = model::Vec<T>;
}

Expand Down
13 changes: 8 additions & 5 deletions tests/ui/fail/annot_preds_trait.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
//@compile-flags: -Adead_code -C debug-assertions=off

// A is represented as Tuple<Int> in SMT-LIB2 format.
#[derive(PartialEq)]
struct A {
x: i64,
}
Expand All @@ -10,25 +11,27 @@ impl thrust_models::Model for A {
type Ty = Self;
}

#[thrust_macros::context]
trait Double {
// Support annotations in trait definitions
#[thrust::predicate]
#[thrust_macros::predicate]
fn is_double(self, doubled: Self) -> bool;

// This annotations are applied to all implementors of the `Double` trait.
#[thrust::requires(true)]
#[thrust::ensures(Self::is_double(*self, ^self))]
#[thrust_macros::requires(true)]
#[thrust_macros::ensures(Self::is_double(*self, !self))]
fn double(&mut self);
}

#[thrust_macros::context]
impl Double for A {
// Write concrete definitions for predicates in `impl` blocks
#[thrust::predicate]
#[thrust_macros::predicate]
fn is_double(self, doubled: Self) -> bool {
// (tuple_proj<Int>.0 self) is equivalent to self.x
// self.x * 3 == doubled.x (this isn't actually doubled!) is written as following:
"(=
(* (tuple_proj<Int>.0 self) 3)
(* (tuple_proj<Int>.0 self_) 3)
(tuple_proj<Int>.0 doubled)
)"; true // This definition does not comply with annotations in trait!
}
Expand Down
Loading