Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions src/analyze/annot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,54 @@ pub fn array_model_store_path() -> [Symbol; 3] {
]
}

pub fn seq_model_path() -> [Symbol; 3] {
[
Symbol::intern("thrust"),
Symbol::intern("def"),
Symbol::intern("seq_model"),
]
}

pub fn seq_empty_path() -> [Symbol; 3] {
[
Symbol::intern("thrust"),
Symbol::intern("def"),
Symbol::intern("seq_empty"),
]
}

pub fn seq_singleton_path() -> [Symbol; 3] {
[
Symbol::intern("thrust"),
Symbol::intern("def"),
Symbol::intern("seq_singleton"),
]
}

pub fn seq_len_path() -> [Symbol; 3] {
[
Symbol::intern("thrust"),
Symbol::intern("def"),
Symbol::intern("seq_len"),
]
}

pub fn seq_push_path() -> [Symbol; 3] {
[
Symbol::intern("thrust"),
Symbol::intern("def"),
Symbol::intern("seq_push"),
]
}

pub fn seq_concat_path() -> [Symbol; 3] {
[
Symbol::intern("thrust"),
Symbol::intern("def"),
Symbol::intern("seq_concat"),
]
}

pub fn exists_path() -> [Symbol; 3] {
[
Symbol::intern("thrust"),
Expand Down
83 changes: 82 additions & 1 deletion src/analyze/annot_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,25 @@ impl<'a, 'tcx> AnnotFnTranslator<'a, 'tcx> {
FormulaOrTerm::Formula(fn_ty.postcondition_formula(&param_args, result))
}

fn node_arg_type_at(&self, hir_id: HirId, idx: usize) -> rty::Type<rty::Closed> {
let generic_args = self.typeck.node_args(hir_id);
let generic_args =
mir_ty::EarlyBinder::bind(generic_args).instantiate(self.tcx, self.generic_args);
let elem_ty = generic_args.type_at(idx);
self.type_builder.build(elem_ty)
}

fn adt_arg_type_at(
&self,
expr: &'tcx rustc_hir::Expr<'tcx>,
idx: usize,
) -> rty::Type<rty::Closed> {
let mir_ty::TyKind::Adt(_, args) = self.expr_ty(expr).kind() else {
panic!("expected ADT");
};
self.type_builder.build(args.type_at(idx))
}

fn variant_ctor_term(
&self,
ctor_did: rustc_span::def_id::DefId,
Expand Down Expand Up @@ -623,9 +642,18 @@ impl<'a, 'tcx> AnnotFnTranslator<'a, 'tcx> {
FormulaOrTerm::Term(term.tuple_proj(index))
}
ExprKind::Index(array, index, _) => {
let array_ty = self.expr_ty(array);
let array_term = self.to_term(array);
let index_term = self.to_term(index);
FormulaOrTerm::Term(array_term.select(index_term))
let is_seq = array_ty
.ty_adt_def()
.is_some_and(|adt| Some(adt.did()) == self.def_ids.seq_model());
let array_inner = if is_seq {
array_term.tuple_proj(0)
} else {
array_term
};
FormulaOrTerm::Term(array_inner.select(index_term))
}
ExprKind::MethodCall(method, receiver, args, _) => {
if let Some(def_id) = self.typeck.type_dependent_def_id(hir.hir_id) {
Expand All @@ -644,6 +672,40 @@ impl<'a, 'tcx> AnnotFnTranslator<'a, 'tcx> {
let t = self.to_term(receiver);
return FormulaOrTerm::Term(t);
}
if Some(def_id) == self.def_ids.seq_len() {
assert!(args.is_empty(), "Seq::len does not take any arguments");
let t = self.to_term(receiver);
return FormulaOrTerm::Term(t.tuple_proj(1));
}
if Some(def_id) == self.def_ids.seq_push() {
assert_eq!(args.len(), 1, "Seq::push takes exactly 1 argument");
let t = self.to_term(receiver);
let v = self.to_term(&args[0]);
let arr = t.clone().tuple_proj(0);
let len = t.tuple_proj(1);
let new_arr = arr.store(len.clone(), v);
let new_len = len.add(chc::Term::int(1));
return FormulaOrTerm::Term(chc::Term::tuple(vec![new_arr, new_len]));
}
if Some(def_id) == self.def_ids.seq_concat() {
assert_eq!(args.len(), 1, "Seq::concat takes exactly 1 argument");
let elem_sort = self.adt_arg_type_at(receiver, 0).to_sort();
let t = self.to_term(receiver);
let other = self.to_term(&args[0]);
let a_arr = t.clone().tuple_proj(0);
let a_len = t.tuple_proj(1);
let b_arr = other.clone().tuple_proj(0);
let b_len = other.tuple_proj(1);
let new_arr = chc::Term::array_concat(
elem_sort,
a_arr,
a_len.clone(),
b_arr,
b_len.clone(),
);
let new_len = a_len.add(b_len);
return FormulaOrTerm::Term(chc::Term::tuple(vec![new_arr, new_len]));
}
}
unimplemented!("unsupported method call in formula: {:?}", method)
}
Expand Down Expand Up @@ -719,6 +781,25 @@ impl<'a, 'tcx> AnnotFnTranslator<'a, 'tcx> {
let t = self.to_term(&args[0]);
return FormulaOrTerm::Term(chc::Term::box_(t));
}
if Some(def_id) == self.def_ids.seq_empty() {
assert!(args.is_empty(), "Seq::empty does not take any arguments");
let elem_sort = self.node_arg_type_at(func_expr.hir_id, 0).to_sort();
return FormulaOrTerm::Term(chc::Term::tuple(vec![
chc::Term::array_empty(chc::Sort::int(), elem_sort),
chc::Term::int(0),
]));
}
if Some(def_id) == self.def_ids.seq_singleton() {
assert_eq!(args.len(), 1, "Seq::singleton takes exactly 1 argument");
let v = self.to_term(&args[0]);
let elem_sort = self.node_arg_type_at(func_expr.hir_id, 0).to_sort();
let new_arr = chc::Term::array_empty(chc::Sort::int(), elem_sort)
.store(chc::Term::int(0), v);
return FormulaOrTerm::Term(chc::Term::tuple(vec![
new_arr,
chc::Term::int(1),
]));
}
if let rustc_hir::def::DefKind::Ctor(ctor_of, _) = def_kind {
let terms = args.iter().map(|e| self.to_term(e)).collect();
match ctor_of {
Expand Down
49 changes: 49 additions & 0 deletions src/analyze/did_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ struct DefIds {
box_model_new: OnceCell<Option<DefId>>,
array_model_store: OnceCell<Option<DefId>>,

seq_model: OnceCell<Option<DefId>>,
seq_empty: OnceCell<Option<DefId>>,
seq_singleton: OnceCell<Option<DefId>>,
seq_len: OnceCell<Option<DefId>>,
seq_push: OnceCell<Option<DefId>>,
seq_concat: OnceCell<Option<DefId>>,

exists: OnceCell<Option<DefId>>,
forall: OnceCell<Option<DefId>>,
implies: OnceCell<Option<DefId>>,
Expand Down Expand Up @@ -179,6 +186,48 @@ impl<'tcx> DefIdCache<'tcx> {
.get_or_init(|| self.annotated_def(&crate::analyze::annot::array_model_store_path()))
}

pub fn seq_model(&self) -> Option<DefId> {
*self
.def_ids
.seq_model
.get_or_init(|| self.annotated_def(&crate::analyze::annot::seq_model_path()))
}

pub fn seq_empty(&self) -> Option<DefId> {
*self
.def_ids
.seq_empty
.get_or_init(|| self.annotated_def(&crate::analyze::annot::seq_empty_path()))
}

pub fn seq_singleton(&self) -> Option<DefId> {
*self
.def_ids
.seq_singleton
.get_or_init(|| self.annotated_def(&crate::analyze::annot::seq_singleton_path()))
}

pub fn seq_len(&self) -> Option<DefId> {
*self
.def_ids
.seq_len
.get_or_init(|| self.annotated_def(&crate::analyze::annot::seq_len_path()))
}

pub fn seq_push(&self) -> Option<DefId> {
*self
.def_ids
.seq_push
.get_or_init(|| self.annotated_def(&crate::analyze::annot::seq_push_path()))
}

pub fn seq_concat(&self) -> Option<DefId> {
*self
.def_ids
.seq_concat
.get_or_init(|| self.annotated_def(&crate::analyze::annot::seq_concat_path()))
}

pub fn exists(&self) -> Option<DefId> {
*self
.def_ids
Expand Down
8 changes: 4 additions & 4 deletions src/annot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ pub struct AnnotPathSegment {

/// A trait for resolving variables in annotations to their logical representation and their sorts.
pub trait Resolver {
type Output;
type Output: Clone;
fn resolve(&self, ident: Ident) -> Option<(Self::Output, chc::Sort)>;
}

Expand Down Expand Up @@ -1222,7 +1222,7 @@ struct RefinementResolver<'a, T> {
self_: Option<(Ident, chc::Sort)>,
}

impl<'a, T> Resolver for RefinementResolver<'a, T> {
impl<'a, T: Clone> Resolver for RefinementResolver<'a, T> {
type Output = rty::RefinedTypeVar<T>;
fn resolve(&self, ident: Ident) -> Option<(Self::Output, chc::Sort)> {
if let Some((self_ident, sort)) = &self.self_ {
Expand Down Expand Up @@ -1256,7 +1256,7 @@ pub struct MappedResolver<'a, T, F> {
map: F,
}

impl<'a, T, F, U> Resolver for MappedResolver<'a, T, F>
impl<'a, T: Clone, F, U: Clone> Resolver for MappedResolver<'a, T, F>
where
F: Fn(T) -> U,
{
Expand Down Expand Up @@ -1290,7 +1290,7 @@ impl<'a, T> Default for StackedResolver<'a, T> {
}
}

impl<'a, T> Resolver for StackedResolver<'a, T> {
impl<'a, T: Clone> Resolver for StackedResolver<'a, T> {
type Output = T;
fn resolve(&self, ident: Ident) -> Option<(Self::Output, chc::Sort)> {
for resolver in &self.resolvers {
Expand Down
Loading