Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 58 additions & 52 deletions Cslib/Foundations/Control/Monad/Free/Effects.lean
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,8 @@ The canonical interpreter `toStateM` derived from `liftM` agrees with the hand-w
recursive interpreter `run` for `FreeState`.
-/
@[simp]
theorem run_toStateM {α : Type u} (comp : FreeState σ α) :
(toStateM comp).run = run comp := by
ext s₀ : 1
theorem run_toStateM {α : Type u} (comp : FreeState σ α) (s₀ : σ) :
(toStateM comp).run s₀ = pure (run comp s₀) := by
induction comp generalizing s₀ with
| pure a => rfl
| liftBind op cont ih =>
Expand All @@ -124,10 +123,9 @@ lemma run_set (s' : σ) (k : PUnit → FreeState σ α) (s₀ : σ) :
def run' (c : FreeState σ α) (s₀ : σ) : α := (run c s₀).1

@[simp]
theorem run'_toStateM {α : Type u} (comp : FreeState σ α) :
(toStateM comp).run' = run' comp := by
ext s₀ : 1
rw [run', ← run_toStateM]
theorem run'_toStateM {α : Type u} (comp : FreeState σ α) (s₀ : σ) :
(toStateM comp).run' s₀ = pure (run' comp s₀) := by
rw [run', StateT.run'_eq, run_toStateM]
rfl

@[simp]
Expand Down Expand Up @@ -176,16 +174,6 @@ def toWriterT {α : Type u} [Monoid ω] (comp : FreeWriter ω α) : WriterT ω I
theorem toWriterT_unique {α : Type u} [Monoid ω] (g : FreeWriter ω α → WriterT ω Id α)
(h : Interprets writerInterp g) : g = toWriterT := h.eq

/--
Writes a log entry. This creates an effectful node in the computation tree.
-/
abbrev tell (w : ω) : FreeWriter ω PUnit :=
lift (.tell w)

@[simp]
lemma tell_def (w : ω) :
tell w = .lift (.tell w) := rfl

/--
Interprets a `FreeWriter` computation by recursively traversing the tree, accumulating
log entries with the monoid operation, and returns the final value paired with the accumulated log.
Expand All @@ -204,30 +192,59 @@ lemma run_pure [Monoid ω] (a : α) :
lemma run_liftBind_tell [Monoid ω] (w : ω) (k : PUnit → FreeWriter ω α) :
run (liftBind (.tell w) k) = (let (a, w') := run (k .unit); (a, w * w')) := rfl


-- https://github.com/leanprover-community/mathlib4/pull/36497
section missing_from_mathlib

@[simp]
theorem _root_.WriterT.run_pure [Monoid ω] [Monad M] (a : α) :
WriterT.run (pure a : WriterT ω M α) = pure (a, 1) := rfl

@[simp]
theorem _root_.WriterT.run_bind [Monoid ω] [Monad M] (x : WriterT ω M α) (f : α → WriterT ω M β) :
WriterT.run (x >>= f) = x.run >>= fun (a, w₁) => (fun (b, w₂) => (b, w₁ * w₂)) <$> (f a).run :=
rfl

@[simp]
theorem _root_.WriterT.run_tell [Monad M] (w : ω) :
WriterT.run (MonadWriter.tell w : WriterT ω M PUnit) = pure (.unit, w) := rfl

end missing_from_mathlib

/--
The canonical interpreter `toWriterT` derived from `liftM` agrees with the hand-written
recursive interpreter `run` for `FreeWriter`.
-/
@[simp]
theorem run_toWriterT {α : Type u} [Monoid ω] :
∀ comp : FreeWriter ω α, (toWriterT comp).run = run comp
| .pure _ => by simp only [toWriterT, liftM_pure, run_pure, pure, WriterT.run]
| liftBind (.tell w) cont => by
simp only [toWriterT, liftM_liftBind, run_liftBind_tell] at *
rw [← run_toWriterT]
congr
theorem run_toWriterT {α : Type u} [Monoid ω] (comp : FreeWriter ω α) :
(toWriterT comp).run = pure (run comp) := by
ext : 1
induction comp with
| pure _ => simp only [toWriterT, liftM_pure, run_pure, pure, WriterT.run]
| liftBind op cont ih =>
cases op
simp only [toWriterT, liftM_liftBind, run_liftBind_tell, Id.run_pure] at *
rw [ ← ih]
simp [WriterT.run_bind, writerInterp]

/--
`listen` captures the log produced by a subcomputation incrementally. It traverses the computation,
emitting log entries as encountered, and returns the accumulated log as a result.
-/
def listen [Monoid ω] : FreeWriter ω α → FreeWriter ω (α × ω)
/-- Implementation of `MonadWriter.listen`. -/
protected def listen [Monoid ω] : FreeWriter ω α → FreeWriter ω (α × ω)
| .pure a => .pure (a, 1)
| .liftBind (.tell w) k =>
liftBind (.tell w) fun _ =>
listen (k .unit) >>= fun (a, w') =>
FreeWriter.listen (k .unit) >>= fun (a, w') =>
pure (a, w * w')

/-- Implementation of `MonadWriter.pass`. -/
protected def pass [Monoid ω] (m : FreeWriter ω (α × (ω → ω))) : FreeWriter ω α :=
let ((a, f), w) := run m
liftBind (.tell (f w)) (fun _ => .pure a)

instance [Monoid ω] : MonadWriter ω (FreeWriter ω) where
tell w := lift (.tell w)
listen := FreeWriter.listen
pass := FreeWriter.pass

@[simp]
lemma listen_pure [Monoid ω] (a : α) :
listen (.pure a : FreeWriter ω α) = .pure (a, 1) := rfl
Expand All @@ -241,24 +258,14 @@ lemma listen_liftBind_tell [Monoid ω] (w : ω)
pure (a, w * w')) := by
rfl

/--
`pass` allows a subcomputation to modify its own log. After traversing the computation and
accumulating its log, the resulting function is applied to rewrite the accumulated log
before re-emission.
-/
def pass [Monoid ω] (m : FreeWriter ω (α × (ω → ω))) : FreeWriter ω α :=
let ((a, f), w) := run m
liftBind (.tell (f w)) (fun _ => .pure a)
@[simp]
lemma tell_def [Monoid ω] (w : ω) :
(tell w : FreeWriter ω _) = .lift (.tell w) := rfl

@[simp]
lemma pass_def [Monoid ω] (m : FreeWriter ω (α × (ω → ω))) :
pass m = let ((a, f), w) := run m; liftBind (.tell (f w)) fun _ => .pure a := rfl

instance [Monoid ω] : MonadWriter ω (FreeWriter ω) where
tell := tell
listen := listen
pass := pass

end FreeWriter

/-! ### Continuation Monad via `FreeM` -/
Expand Down Expand Up @@ -301,9 +308,8 @@ The canonical interpreter `toContT` derived from `liftM` agrees with the hand-wr
recursive interpreter `run` for `FreeCont`.
-/
@[simp]
theorem run_toContT {α : Type u} (comp : FreeCont r α) :
(toContT comp).run = run comp := by
ext k
theorem run_toContT {α : Type u} (comp : FreeCont r α) (k : α → r) :
(toContT comp).run k = pure (run comp k) := by
induction comp with
| pure a => rfl
| liftBind op cont ih =>
Expand All @@ -322,7 +328,7 @@ lemma run_liftBind_callCC (g : (α → r) → r)
(cont : α → FreeCont r β) (k : β → r) :
run (liftBind (.callCC g) cont) k = g (fun a => run (cont a) k) := rfl

/-- Call with current continuation for the Free continuation monad. -/
/-- Universe-generic version of `MonadCont.callCC` -/
def callCC (f : MonadCont.Label α (FreeCont r) β → FreeCont r α) :
FreeCont r α :=
liftBind (.callCC fun k => run (f ⟨fun x => liftBind (.callCC fun _ => k x) pure⟩) k) pure
Expand All @@ -333,15 +339,15 @@ lemma callCC_def (f : MonadCont.Label α (FreeCont r) β → FreeCont r α) :
liftBind (.callCC fun k => run (f ⟨fun x => liftBind (.callCC fun _ => k x) pure⟩) k) pure :=
rfl

instance : MonadCont (FreeCont r) where
callCC := .callCC

/-- `run` of a `callCC` node simplifies to running the handler with the current continuation. -/
@[simp]
lemma run_callCC (f : MonadCont.Label α (FreeCont r) β → FreeCont r α) (k : α → r) :
run (callCC f) k = run (f ⟨fun x => liftBind (.callCC fun _ => k x) pure⟩) k := by
simp [callCC, run_liftBind_callCC]

instance : MonadCont (FreeCont r) where
callCC := .callCC

end FreeCont

/-- Type constructor for reader operations. -/
Expand Down Expand Up @@ -369,7 +375,7 @@ def readInterp {α : Type u} : ReaderF σ α → ReaderM σ α

/-- Convert a `FreeReader` computation into a `ReaderM` computation. This is the canonical
interpreter derived from `liftM`. -/
def toReaderM {α : Type u} (comp : FreeReader σ α) : ReaderM σ α :=
abbrev toReaderM {α : Type u} (comp : FreeReader σ α) : ReaderM σ α :=
comp.liftM readInterp

/-- `toReaderM` is the unique interpreter extending `readInterp`. -/
Expand All @@ -387,7 +393,7 @@ The canonical interpreter `toReaderM` derived from `liftM` agrees with the hand-
recursive interpreter `run` for `FreeReader` -/
@[simp]
theorem run_toReaderM {α : Type u} (comp : FreeReader σ α) (s : σ) :
(toReaderM comp).run s = run comp s := by
(toReaderM comp).run s = pure (run comp s) := by
induction comp generalizing s with
| pure a => rfl
| liftBind op cont ih =>
Expand Down
Loading