(*  File:       ~~/src/Tools/eqsubst.ML

Single-step substitution with equations: the subst proof method.
*)
(* Interface of structure EqSubst, which implements the subst proof method
   (single-step substitution with equations in a subgoal's conclusion or
   in its assumptions). *)
signature EQSUBST =
sig
(* A single match: ((type instantiation, term instantiation), fake free
   variables standing for the loose bounds at the match position, renamed
   enclosing binders, and the surrounding term abstracted over the match
   position). *)
type match =
((indexname * (sort * typ)) list
* (indexname * (typ * term)) list)
* (string * typ) list
* (string * typ) list
* term
(* Search state: proof context, maximal variable index of the goal state,
   and a zipper focused on the term to be searched. *)
type searchinfo =
Proof.context
* int
* Zipper.T
(* SkipMore n: the searched sequence was exhausted and the wanted
   occurrence is the n-th among those that follow elsewhere.
   SkipSeq s: s starts at the wanted occurrence. *)
datatype 'a skipseq = SkipMore of int | SkipSeq of 'a Seq.seq Seq.seq
(* Run a search function and skip to the given occurrence, reporting either
   the remaining matches or how far there is still to skip. *)
val skip_first_asm_occs_search: ('a -> 'b -> 'c Seq.seq Seq.seq) -> 'a -> int -> 'b -> 'c skipseq
(* Run a search function and return the matches from the given occurrence
   onwards as one flat sequence (empty if there are too few occurrences). *)
val skip_first_occs_search: int -> ('a -> 'b -> 'c Seq.seq Seq.seq) -> 'a -> 'b -> 'c Seq.seq
(* Skip to the given occurrence in a sequence of per-position match
   sequences; empty positions are not counted. *)
val skipto_skipseq: int -> 'a Seq.seq Seq.seq -> 'a skipseq
(* Substitute in the assumptions of a subgoal, for the given occurrence
   numbers, using the given theorems. *)
val eqsubst_asm_tac: Proof.context -> int list -> thm list -> int -> tactic
(* Worker for eqsubst_asm_tac: search function, starting occurrence,
   one theorem, subgoal index. *)
val eqsubst_asm_tac': Proof.context ->
(searchinfo -> int -> term -> match skipseq) -> int -> thm -> int -> tactic
(* Substitute in the conclusion of a subgoal, for the given occurrence
   numbers, using the given theorems. *)
val eqsubst_tac: Proof.context ->
int list ->
thm list -> int -> tactic
(* Worker for eqsubst_tac: search function, one theorem, subgoal index,
   goal state. *)
val eqsubst_tac': Proof.context ->
(searchinfo -> term -> match Seq.seq)
-> thm
-> int
-> thm
-> thm Seq.seq
(* True unless the bottom-left leaf of the focused term is a schematic
   variable. *)
val valid_match_start: Zipper.T -> bool
(* Enumerate every position below a zipper (via ZipperSearch.all_bl_ur). *)
val search_lr_all: Zipper.T -> Zipper.T Seq.seq
(* Enumerate, left to right, the positions satisfying a predicate. *)
val search_lr_valid: (Zipper.T -> bool) -> Zipper.T -> Zipper.T Seq.seq
(* Search functions yielding, for each position, the unifiers of the given
   pattern there: all positions left to right; valid positions left to
   right; valid positions bottom-up. *)
val searchf_lr_unify_all: searchinfo -> term -> match Seq.seq Seq.seq
val searchf_lr_unify_valid: searchinfo -> term -> match Seq.seq Seq.seq
val searchf_bt_unify_valid: searchinfo -> term -> match Seq.seq Seq.seq
end;
structure EqSubst: EQSUBST =
struct

(* Turn a theorem into a list of meta-level equations using the context's
   simplifier mksimps setup, renumbering the schematic variables of each
   result from index 0. *)
fun prep_meta_eq ctxt =
  Simplifier.mksimps ctxt #> map Drule.zero_var_indexes;

(* Re-generalise the given fixed variables (cterms): abstract over all of
   them with Drule.forall_intr_list, then eliminate one outermost
   meta-quantifier per variable with a schematic variable. *)
fun unfix_frees frees =
  fold (K (Thm.forall_elim_var 0)) frees o Drule.forall_intr_list frees;

(* A single match: ((type instantiation, term instantiation), fake free
   variables for the loose bounds at the match position, renamed enclosing
   binders, surrounding term abstracted over the match position). *)
type match =
  ((indexname * (sort * typ)) list
   * (indexname * (typ * term)) list)
  * (string * typ) list
  * (string * typ) list
  * term;

(* Search state: context, maximal variable index of the goal state, and a
   zipper focused on the term to be searched. *)
type searchinfo =
  Proof.context
  * int
  * Zipper.T;

(* SkipMore n: the sequence ran out and the wanted occurrence is the n-th
   among those that follow elsewhere.  SkipSeq s: s starts at the wanted
   occurrence. *)
datatype 'a skipseq =
  SkipMore of int |
  SkipSeq of 'a Seq.seq Seq.seq;

(* Move to the m-th occurrence in a sequence of match sequences.  Each
   inner sequence stands for one search position; empty inner sequences
   are passed over without being counted.  A counter of at most 1 selects
   the current non-empty inner sequence, so 0 and 1 both denote the first
   occurrence. *)
fun skipto_skipseq m s =
  let
    fun skip_occs n sq =
      (case Seq.pull sq of
        NONE => SkipMore n
      | SOME (h, t) =>
          (case Seq.pull h of
            NONE => skip_occs n t
          | SOME _ =>
              if n <= 1 then SkipSeq (Seq.cons h t)
              else skip_occs (n - 1) t))
  in skip_occs m s end;

(* Abstract the surrounding term over the match position.  The hole is
   filled with Bound n applied to Bound (n - 1), ..., Bound 0, where n is the
   number of enclosing binders Ts; mkuptermfunc rebuilds the surrounding
   term around it, and the result is wrapped in an abstraction named fooabs
   whose type takes the binder types of Ts in reverse order and then
   returns the type of t. *)
fun mk_foo_match mkuptermfunc Ts t =
  let
    val ty = Term.type_of t
    val bigtype = rev (map snd Ts) ---> ty
    fun mk_foo 0 t = t
      | mk_foo i t = mk_foo (i - 1) (t $ (Bound (i - 1)))
    val num_of_bnds = length Ts
    val foo_term = mk_foo num_of_bnds (Bound num_of_bnds)
  in Abs ("fooabs", bigtype, mkuptermfunc foo_term) end;

(* Name of the fake free variable that stands for a bound variable. *)
fun mk_fake_bound_name n = ":b_" ^ n;

(* Replace the loose bound variables of t by fake free variables.  Each
   binder (name, ty) of Ts, processed from the end, gets a name variant
   distinct from those already chosen.  Returns (fake frees, renamed
   binders, t with its loose bounds substituted by the fake frees); both
   lists keep the order of Ts. *)
fun fakefree_badbounds Ts t =
  let
    val (fake_Ts, renamed_Ts, _) =
      fold_rev (fn (n, ty) => fn (fakes, renamed, usednames) =>
        let
          val newname = singleton (Name.variant_list usednames) n
        in
          ((mk_fake_bound_name newname, ty) :: fakes,
           (newname, ty) :: renamed,
           newname :: usednames)
        end) Ts ([], [], [])
  in (fake_Ts, renamed_Ts, Term.subst_bounds (map Free fake_Ts, t)) end;

(* Prepare the focused subterm of zipper z for matching: returns the
   subterm with fake frees for its loose bounds, together with the fake
   frees, the renamed enclosing binders and the abstracted context. *)
fun prep_zipper_match z =
  let
    val t = Zipper.trm z
    val c = Zipper.ctxt z
    val Ts = Zipper.C.nty_ctxt c
    val (FakeTs', Ts', t') = fakefree_badbounds Ts t
    val absterm = mk_foo_match (Zipper.C.apply c) Ts' t'
  in
    (t', (FakeTs', Ts', absterm))
  end;

(* Unify the two terms of a = (pat, tgt) starting from maximal index ix.
   Their types are unified first; if that fails, the result is empty.
   Higher-order unifiers come from Unify.smash_unifiers; the exceptions
   ListPair.UnequalLengths and TERM, raised either while setting up or
   while pulling further unifiers, end the sequence.  Each unifier is
   returned as its type and term environments in list form. *)
fun clean_unify ctxt ix (a as (pat, tgt)) =
  let
    val pat_ty = Term.type_of pat;
    val tgt_ty = Term.type_of tgt;
    val typs_unify =
      SOME (Sign.typ_unify (Proof_Context.theory_of ctxt) (pat_ty, tgt_ty) (Vartab.empty, ix))
        handle Type.TUNIFY => NONE;
  in
    (case typs_unify of
      SOME (typinsttab, ix2) =>
        let
          fun mk_insts env =
            (Vartab.dest (Envir.type_env env),
             Vartab.dest (Envir.term_env env));
          val initenv =
            Envir.Envir {maxidx = ix2, tenv = Vartab.empty, tyenv = typinsttab};
          val useq = Unify.smash_unifiers (Context.Proof ctxt) [a] initenv
            handle ListPair.UnequalLengths => Seq.empty
              | Term.TERM _ => Seq.empty;
          (* Lazily convert unifiers, stopping quietly on the same errors. *)
          fun clean_unify' useq () =
            (case (Seq.pull useq) of
              NONE => NONE
            | SOME (h, t) => SOME (mk_insts h, Seq.make (clean_unify' t)))
            handle ListPair.UnequalLengths => NONE
              | Term.TERM _ => NONE;
        in
          (Seq.make (clean_unify' useq))
        end
    | NONE => Seq.empty)
  end;

(* All matches of pattern pat at the position focused by zipper z.  Note:
   the focused subterm is passed as the first component of the pair and
   the pattern as the second. *)
fun clean_unify_z ctxt maxidx pat z =
  let val (t, (FakeTs, Ts, absterm)) = prep_zipper_match z in
    Seq.map (fn insts => (insts, FakeTs, Ts, absterm))
      (clean_unify ctxt maxidx (t, pat))
  end;

(* Leaf reached by repeatedly descending into the function part of
   applications and the body of abstractions. *)
fun bot_left_leaf_of (l $ _) = bot_left_leaf_of l
  | bot_left_leaf_of (Abs (_, _, t)) = bot_left_leaf_of t
  | bot_left_leaf_of x = x;

(* A position is a valid match start unless its bottom-left leaf is a
   schematic variable. *)
fun valid_match_start z =
  (case bot_left_leaf_of (Zipper.trm z) of
    Var _ => false
  | _ => true);

(* Enumerate all positions (ZipperSearch.all_bl_ur). *)
val search_lr_all = ZipperSearch.all_bl_ur;

(* Enumerate positions satisfying validf.  For an application the
   positions of the left subterm are listed first, then the current one,
   then those of the right subterm; for an abstraction the current position
   comes before its body.  Zipper.lzy_search is assumed to keep this order
   and to expand LookIn entries lazily. *)
fun search_lr_valid validf =
  let
    fun sf_valid_td_lr z =
      let val here = if validf z then [Zipper.Here z] else [] in
        (case Zipper.trm z of
          _ $ _ =>
            [Zipper.LookIn (Zipper.move_down_left z)] @ here @
            [Zipper.LookIn (Zipper.move_down_right z)]
        | Abs _ => here @ [Zipper.LookIn (Zipper.move_down_abs z)]
        | _ => here)
      end;
  in Zipper.lzy_search sf_valid_td_lr end;

(* Like search_lr_valid, but bottom-up: the subterms (left then right, or
   the abstraction body) are listed before the current position. *)
fun search_bt_valid validf =
  let
    fun sf_valid_bt_lr z =
      let val here = if validf z then [Zipper.Here z] else [] in
        (case Zipper.trm z of
          _ $ _ =>
            [Zipper.LookIn (Zipper.move_down_left z),
             Zipper.LookIn (Zipper.move_down_right z)] @ here
        | Abs _ => [Zipper.LookIn (Zipper.move_down_abs z)] @ here
        | _ => here)
      end;
  in Zipper.lzy_search sf_valid_bt_lr end;

(* Generic search: enumerate positions with f from zipper z (through
   Zipper.limit_apply) and map each one to its sequence of matches of lhs. *)
fun searchf_unify_gen f (ctxt, maxidx, z) lhs =
  Seq.map (clean_unify_z ctxt maxidx lhs) (Zipper.limit_apply f z);

val searchf_lr_unify_all = searchf_unify_gen search_lr_all;
val searchf_lr_unify_valid = searchf_unify_gen (search_lr_valid valid_match_start);
val searchf_bt_unify_valid = searchf_unify_gen (search_bt_valid valid_match_start);

(* Rewrite conclthm with rule under match m (RW_Inst.rw), re-generalise the
   fixed parameters cfvs, normalise by beta-eta conversion, and resolve the
   result with subgoal i of st. *)
fun apply_subst_in_concl ctxt i st (cfvs, conclthm) rule m =
  RW_Inst.rw ctxt m rule conclthm
  |> unfix_frees cfvs
  |> Conv.fconv_rule Drule.beta_eta_conversion
  |> (fn r => resolve_tac ctxt [r] i st);

(* Prepare subgoal i of gth for substitution in its conclusion.  Indexes
   are incremented by 1, the subgoal's parameters are fixed
   (IsaND.fix_alls_term) and returned as cterms in reversed order.
   Returns ((fixed parameters, trivial theorem on the conclusion), search
   info), with the zipper moved down-left then down-right from the top of
   that theorem's proposition.
   NOTE(review): assuming Thm.trivial C is C ==> C, the zipper focuses on
   the premise copy of C -- confirm. *)
fun prep_concl_subst ctxt i gth =
  let
    val th = Thm.incr_indexes 1 gth;
    val tgt_term = Thm.prop_of th;
    val (fixedbody, fvs) = IsaND.fix_alls_term ctxt i tgt_term;
    val cfvs = rev (map (Thm.cterm_of ctxt) fvs);
    val conclterm = Logic.strip_imp_concl fixedbody;
    val conclthm = Thm.trivial (Thm.cterm_of ctxt conclterm);
    val maxidx = Thm.maxidx_of th;
    val ft =
      (Zipper.move_down_right
       o Zipper.move_down_left
       o Zipper.mktop
       o Thm.prop_of) conclthm
  in
    ((cfvs, conclthm), (ctxt, maxidx, ft))
  end;

(* Substitute in the conclusion of subgoal i of st.  For every meta
   equation derived from instepthm, searchf looks for its left-hand side
   and each match is tried with apply_subst_in_concl. *)
fun eqsubst_tac' ctxt searchf instepthm i st =
  let
    val (cvfsconclthm, searchinfo) = prep_concl_subst ctxt i st;
    val stepthms = Seq.of_list (prep_meta_eq ctxt instepthm);
    fun rewrite_with_thm r =
      let val (lhs, _) = Logic.dest_equals (Thm.concl_of r) in
        searchf searchinfo lhs
        |> Seq.maps (apply_subst_in_concl ctxt i st cvfsconclthm r)
      end;
  in stepthms |> Seq.maps rewrite_with_thm end;

(* Matches from the occ-th non-empty position onwards, flattened into one
   sequence; empty if there are fewer occurrences. *)
fun skip_first_occs_search occ srchf sinfo lhs =
  (case skipto_skipseq occ (srchf sinfo lhs) of
    SkipMore _ => Seq.empty
  | SkipSeq ss => Seq.flat ss);

(* Tactic for substitution in a conclusion.  Within SELECT_GOAL, the
   requested occurrences are rewritten one after another in descending
   order (presumably so that earlier rewrites do not renumber occurrences
   still to be done).  Every theorem of thms is tried and all results are
   kept as alternatives.  The subgoal index is the premise count of the
   current state, i.e. its last subgoal.  Duplicate subgoals are removed
   at the end. *)
fun eqsubst_tac ctxt occs thms =
  SELECT_GOAL
    let
      val thmseq = Seq.of_list thms;
      fun apply_occ_tac occ st =
        thmseq |> Seq.maps (fn r =>
          eqsubst_tac' ctxt
            (skip_first_occs_search occ searchf_lr_unify_valid) r
            (Thm.nprems_of st) st);
      val sorted_occs = Library.sort (rev_order o int_ord) occs;
    in Seq.EVERY (map apply_occ_tac sorted_occs) #> Seq.maps distinct_subgoals_tac end;

(* Substitute in assumption j of subgoal i of st using rule and match m.
   The assumption is first rotated to the front of the subgoal's premises.
   The rewritten trivial theorem pth then becomes the rule to apply:
   parameters are pruned (Seq.hd raises if prune_params_tac gives no
   result), Thm.permute_prems 0 ~1 moves its last premise to the front,
   fixed parameters are re-generalised, and the result is beta-eta
   normalised.  That rule is applied with dresolve_tac, and the premises
   are rotated back by ~j. *)
fun apply_subst_in_asm ctxt i st rule ((cfvs, j, _, pth), m) =
  let
    val st2 = Thm.rotate_rule (j - 1) i st;
    val preelimrule =
      RW_Inst.rw ctxt m rule pth
      |> (Seq.hd o prune_params_tac ctxt)
      |> Thm.permute_prems 0 ~1
      |> unfix_frees cfvs
      |> Conv.fconv_rule Drule.beta_eta_conversion;
  in
    Seq.map (Thm.rotate_rule (~ j) (Thm.nprems_of rule + i))
      (dresolve_tac ctxt [preelimrule] i st2)
  end;

(* Prepare assumption j (counting from 1) of subgoal i of gth.  Returns
   ((fixed parameters, j, number of premises of the assumption, trivial
   theorem on the assumption), search info), with the zipper moved
   down-right from the top of the trivial theorem.  The premise count is
   ignored by apply_subst_in_asm. *)
fun prep_subst_in_asm ctxt i gth j =
  let
    val th = Thm.incr_indexes 1 gth;
    val tgt_term = Thm.prop_of th;
    val (fixedbody, fvs) = IsaND.fix_alls_term ctxt i tgt_term;
    val cfvs = rev (map (Thm.cterm_of ctxt) fvs);
    val asmt = nth (Logic.strip_imp_prems fixedbody) (j - 1);
    val asm_nprems = length (Logic.strip_imp_prems asmt);
    val pth = Thm.trivial ((Thm.cterm_of ctxt) asmt);
    val maxidx = Thm.maxidx_of th;
    val ft =
      (Zipper.move_down_right
       o Zipper.mktop
       o Thm.prop_of) pth
  in ((cfvs, j, asm_nprems, pth), (ctxt, maxidx, ft)) end;

(* Prepare every assumption 1, ..., n of subgoal i of gth. *)
fun prep_subst_in_asms ctxt i gth =
  map (prep_subst_in_asm ctxt i gth)
    (Library.upto (1, length (Logic.prems_of_goal (Thm.prop_of gth) i)));

(* Substitute in an assumption of subgoal i of st.  For every meta
   equation derived from instepthm, assumptions are searched in order,
   starting at occurrence skipocc in the first one.  If an assumption has
   too few occurrences, the remaining count carries over to the next one.
   Once the wanted occurrence is found, its matches and those of all later
   assumptions (from their first occurrence) are tried. *)
fun eqsubst_asm_tac' ctxt searchf skipocc instepthm i st =
  let
    val asmpreps = prep_subst_in_asms ctxt i st;
    val stepthms = Seq.of_list (prep_meta_eq ctxt instepthm);
    fun rewrite_with_thm r =
      let
        val (lhs, _) = Logic.dest_equals (Thm.concl_of r);
        fun occ_search _ [] = Seq.empty
          | occ_search occ ((asminfo, searchinfo) :: moreasms) =
              (case searchf searchinfo occ lhs of
                SkipMore remaining => occ_search remaining moreasms
              | SkipSeq ss =>
                  Seq.append (Seq.map (Library.pair asminfo) (Seq.flat ss))
                    (occ_search 1 moreasms))
      in
        occ_search skipocc asmpreps |> Seq.maps (apply_subst_in_asm ctxt i st r)
      end;
  in stepthms |> Seq.maps rewrite_with_thm end;

(* Adapter: search, then skip to occurrence occ, keeping the skip status. *)
fun skip_first_asm_occs_search searchf sinfo occ lhs =
  skipto_skipseq occ (searchf sinfo lhs);

(* Tactic for substitution in assumptions; same structure as eqsubst_tac
   (descending occurrences within SELECT_GOAL, last subgoal of the current
   state, duplicate subgoals removed at the end). *)
fun eqsubst_asm_tac ctxt occs thms =
  SELECT_GOAL
    let
      val thmseq = Seq.of_list thms;
      fun apply_occ_tac occ st =
        thmseq |> Seq.maps (fn r =>
          eqsubst_asm_tac' ctxt
            (skip_first_asm_occs_search searchf_lr_unify_valid) occ r
            (Thm.nprems_of st) st);
      val sorted_occs = Library.sort (rev_order o int_ord) occs;
    in Seq.EVERY (map apply_occ_tac sorted_occs) #> Seq.maps distinct_subgoals_tac end;

(* Method syntax: an optional asm mode, an optional parenthesised list of
   occurrence numbers (default [0], i.e. the first occurrence), then the
   theorems to rewrite with. *)
val _ =
  Theory.setup
    (Method.setup \<^binding>‹subst›
      (Scan.lift (Args.mode "asm" -- Scan.optional (Args.parens (Scan.repeat Parse.nat)) [0]) --
        Attrib.thms >> (fn ((asm, occs), inthms) => fn ctxt =>
          SIMPLE_METHOD' ((if asm then eqsubst_asm_tac else eqsubst_tac) ctxt occs inthms)))
      "single-step substitution");

end;