Monad Transformers
In the previous sections you learned about some handy monads: Option, IO, Reader, State and Except. You now know how to make a function use one of these, but not yet how to make a function use several monads at once.
For example, suppose you need a function that reads some Reader context and may also throw an exception. This requires composing two monads, ReaderM and Except, and this is what monad transformers are for.
A monad transformer is fundamentally a wrapper type. It is generally parameterized by another monadic type. You can then run actions from the inner monad, while adding your own customized behavior for combining actions in this new monad. The common transformers add T to the end of an existing monad name. You will find OptionT, ExceptT, ReaderT and StateT, but there is no transformer for IO. So generally if you need IO it becomes the innermost wrapped monad.
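As a small illustration of IO sitting at the bottom of a stack, here is a minimal sketch that layers StateT over IO. The names countAndPrint and twoMessages are hypothetical and exist only for this example. The IO action is picked up through automatic lifting, which is covered in the Monad Lifting section.

```lean
-- Hypothetical example: a message counter (StateT layer) on top of IO.
def countAndPrint (msg : String) : StateT Nat IO Unit := do
  modify (· + 1)              -- update the state held by the StateT layer
  let n ← get
  IO.println s!"{n}: {msg}"   -- an IO action from the innermost monad

def twoMessages : StateT Nat IO Unit := do
  countAndPrint "hello"
  countAndPrint "world"

-- run' supplies the initial state and discards the final one.
#eval twoMessages.run' 0
-- expected output:
-- 1: hello
-- 2: world
```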
In the following example we use ReaderT to provide some read-only context to a function, and this ReaderT transformer wraps an Except monad. If all goes well, requiredArgument returns the value of a required argument and optionalSwitch returns true if the optional argument is present.
abbrev Arguments := List String

def indexOf? [BEq α] (xs : List α) (s : α) (start := 0) : Option Nat :=
  match xs with
  | [] => none
  | a :: tail => if a == s then some start else indexOf? tail s (start+1)

def requiredArgument (name : String) : ReaderT Arguments (Except String) String := do
  let args ← read
  let value := match indexOf? args name with
    | some i => if i + 1 < args.length then args[i+1]! else ""
    | none => ""
  if value == "" then throw s!"Command line argument {name} missing"
  return value

def optionalSwitch (name : String) : ReaderT Arguments (Except String) Bool := do
  let args ← read
  return match (indexOf? args name) with
  | some _ => true
  | none => false

#eval requiredArgument "--input" |>.run ["--input", "foo"]
-- Except.ok "foo"

#eval requiredArgument "--input" |>.run ["foo", "bar"]
-- Except.error "Command line argument --input missing"

#eval optionalSwitch "--help" |>.run ["--help"]
-- Except.ok true

#eval optionalSwitch "--help" |>.run []
-- Except.ok false
Notice that throw was available from the inner Except monad. The cool thing is you can switch this around and get the exact same result using ExceptT as the outer monad transformer and ReaderM as the wrapped monad. Try changing the return type of requiredArgument to ExceptT String (ReaderM Arguments) String.
Note: the |>. notation is described in Readers.
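To make the swapped ordering concrete, here is a rough, self-contained sketch. The names Arguments', valueAfter? and requiredArgument' are hypothetical stand-ins chosen so that they do not clash with the definitions above, and valueAfter? is a simplified lookup rather than the indexOf? helper. With ExceptT on the outside, you first unwrap the ExceptT layer with one run and then supply the reader context with a second run.

```lean
abbrev Arguments' := List String

-- Hypothetical helper: the argument that follows `name`, if there is one.
def valueAfter? (name : String) : List String → Option String
  | [] => none
  | a :: tail => if a == name then tail.head? else valueAfter? name tail

-- Same idea as requiredArgument, but with ExceptT outside and ReaderM inside.
def requiredArgument' (name : String) : ExceptT String (ReaderM Arguments') String := do
  let args ← read   -- comes from the inner ReaderM
  match valueAfter? name args with
  | some value => return value
  | none => throw s!"Command line argument {name} missing"

#eval requiredArgument' "--input" |>.run |>.run ["--input", "foo"]
-- expected: Except.ok "foo"

#eval requiredArgument' "--input" |>.run |>.run ["foo", "bar"]
-- expected: Except.error "Command line argument --input missing"
```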
Adding more layers
Here's the best part about monad transformers. Since the result of a monad transformer is itself a monad, you can wrap it inside another transformer! Suppose you need to pass in some read-only context like the command line arguments, update some read-write state (like program Config) and optionally throw an exception. Then you could write this:
structure Config where
  help : Bool := false
  verbose : Bool := false
  input : String := ""
  deriving Repr

abbrev CliConfigM := StateT Config (ReaderT Arguments (Except String))

def parseArguments : CliConfigM Bool := do
  let mut config ← get
  if (← optionalSwitch "--help") then
    throw "Usage: example [--help] [--verbose] [--input <input file>]"
  config := { config with
    verbose := (← optionalSwitch "--verbose"),
    input := (← requiredArgument "--input") }
  set config
  return true

def main (args : List String) : IO Unit := do
  let config : Config := { input := "default" }
  match parseArguments |>.run config |>.run args with
  | Except.ok (_, c) => do
    IO.println s!"Processing input '{c.input}' with verbose={c.verbose}"
  | Except.error s => IO.println s

#eval main ["--help"]
-- Usage: example [--help] [--verbose] [--input <input file>]

#eval main ["--input", "foo"]
-- Processing input 'foo' with verbose=false

#eval main ["--verbose", "--input", "bar"]
-- Processing input 'bar' with verbose=true
In this example parseArguments actually runs in a stack of three monads: StateT, ReaderT and Except. Notice the convention of abbreviating long monadic types with an alias like CliConfigM.
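To see how each run call peels off one layer of such a stack, consider this smaller sketch. CounterM and tick are hypothetical names used only here: the StateT layer holds a counter, the ReaderT layer holds a limit, and Except reports when the limit is exceeded.

```lean
abbrev CounterM := StateT Nat (ReaderT Nat (Except String))

-- Increment the counter (state), failing if it would pass the limit (context).
def tick : CounterM Nat := do
  let limit ← read      -- from the ReaderT layer
  let n ← get           -- from the StateT layer
  if n + 1 > limit then
    throw s!"limit {limit} exceeded"   -- from the Except layer
  set (n + 1)
  return n + 1

-- The first run supplies the initial state, the second supplies the context.
#eval tick |>.run 0 |>.run 5   -- expected: Except.ok (1, 1)
#eval tick |>.run 5 |>.run 5   -- expected: Except.error "limit 5 exceeded"
```

The order of the run calls mirrors the order of the transformers, outermost first.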
Monad Lifting
Lean makes it easy to compose functions that use different monads using a concept of automatic monad lifting. You already used lifting in the above code, because you were able to call optionalSwitch, which has type ReaderT Arguments (Except String) Bool, from parseArguments, which has the bigger type StateT Config (ReaderT Arguments (Except String)) Bool.
This "just worked" because Lean did some magic with monad lifting.
To give you a simpler example of this, suppose you have the following function:
def divide (x : Float) (y : Float) : ExceptT String Id Float :=
  if y == 0 then
    throw "can't divide by zero"
  else
    pure (x / y)

#eval divide 6 3 -- Except.ok 2.000000
#eval divide 1 0 -- Except.error "can't divide by zero"
Notice here we used the ExceptT transformer, but we composed it with the Id identity monad. This is then the same as writing Except String Float since the identity monad does nothing.
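One way to check this claim is to ask Lean whether the two types are definitionally equal. The sketch below assumes that ExceptT String Id Float unfolds to Except String Float, so rfl should be accepted. The function divide' is a hypothetical variant written against the shorter type.

```lean
-- Both sides unfold to the same type, so rfl is expected to type-check.
example : ExceptT String Id Float = Except String Float := rfl

-- Hypothetical variant of divide using the shorter type directly.
def divide' (x y : Float) : Except String Float :=
  if y == 0 then throw "can't divide by zero" else pure (x / y)

#eval divide' 6 3  -- expected: Except.ok 2.000000
```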
Now suppose you want to count the number of times divide is called and store the result in some global state:
def divideCounter (x : Float) (y : Float) : StateT Nat (ExceptT String Id) Float := do
  modify fun s => s + 1
  divide x y

#eval divideCounter 6 3 |>.run 0 -- Except.ok (2.000000, 1)
#eval divideCounter 1 0 |>.run 0 -- Except.error "can't divide by zero"
The modify function is a helper which makes it easier to use modifyGet from the StateM monad. But something interesting is happening here: divideCounter returns the value of divide, and although the types don't match, it still works. This is monad lifting in action.
You can see this more clearly with the following test:
def liftTest (x : Except String Float) : StateT Nat (Except String) Float := x

#eval liftTest (divide 5 1) |>.run 3 -- Except.ok (5.000000, 3)
Notice that liftTest returned x without doing anything to it, yet that matched the return type StateT Nat (Except String) Float. Monad lifting is provided by monad transformers. If you #print liftTest you will see that Lean is implementing this using a call to a function named monadLift from the MonadLift type class:
class MonadLift (m : Type u → Type v) (n : Type u → Type w) where
  monadLift : {α : Type u} → m α → n α
So monadLift is a function for lifting a computation of type m α in an inner monad m to a computation of type n α in an outer monad n. You could replace x in liftTest with monadLift x if you want to be explicit about it.
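A minimal sketch of that explicit version follows. The name liftTestExplicit is hypothetical and the argument is built directly with .ok so that the snippet stands alone.

```lean
-- Same as liftTest, but with the lift written out instead of inserted by Lean.
def liftTestExplicit (x : Except String Float) : StateT Nat (Except String) Float :=
  monadLift x

#eval liftTestExplicit (.ok 5.0) |>.run 3
-- expected: Except.ok (5.000000, 3)
```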
The StateT monad transformer defines an instance of MonadLift like this:
@[inline] protected def lift {α : Type u} (t : m α) : StateT σ m α :=
  fun s => do let a ← t; pure (a, s)

instance : MonadLift m (StateT σ m) := ⟨StateT.lift⟩
This means that any monad m can be wrapped in a StateT monad by using the function fun s => do let a ← t; pure (a, s), which takes the state s, runs the inner monad action t, and returns the result together with the state in a pair (a, s), without making any changes to s.
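To watch the state pass through untouched, here is a hand-written copy of that lift under the hypothetical name myStateLift, applied to an Option action. This is a sketch for experimentation, not the library definition.

```lean
-- A hand-rolled version of the lift: run the inner action, keep the state as is.
def myStateLift {σ α : Type} {m : Type → Type} [Monad m] (t : m α) : StateT σ m α :=
  fun s => do
    let a ← t      -- run the inner action
    pure (a, s)    -- pair its result with the unchanged state

#eval (myStateLift (some 42) : StateT Nat Option Nat).run 7
-- expected: some (42, 7)
```

The initial state 7 comes back unchanged next to the lifted result.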
Because MonadLift is a type class, Lean can automatically find the required monadLift instances in order to make your code compile. In this way it was able to find the StateT.lift function and use it to wrap the result of divide so that the correct type is returned from divideCounter.
If you have an instance MonadLift m n, that means there is a way to turn a computation that happens inside of m into one that happens inside of n and (this is the key part) usually without the instance itself creating any additional data that feeds into the computation. This means you can in principle declare lifting instances from any monad to any other monad. It does not, however, mean that you should do this in all cases. You can get a very nice report on how all this was done by adding the line set_option trace.Meta.synthInstance true in before divideCounter and moving your cursor to the end of the first line after do.
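As a sketch of how you might inspect instance resolution yourself, the snippet below uses the #synth command to ask Lean for a lifting instance, and then scopes the trace option to a single definition. The name liftTraced is hypothetical, and the exact trace output depends on your Lean version, so none is shown here.

```lean
-- Ask Lean which instance lifts Except String into StateT Nat (Except String).
#synth MonadLiftT (Except String) (StateT Nat (Except String))

-- Trace type class resolution while elaborating just this one definition.
set_option trace.Meta.synthInstance true in
def liftTraced (x : Except String Float) : StateT Nat (Except String) Float := x
```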
This was a lot of detail, but it is very important to understand how monad lifting works because it is used heavily in Lean programs.
Transitive lifting
There is also a transitive version of MonadLift called MonadLiftT which can lift multiple monad layers at once. In the following example we added another monad layer with ReaderT String ... and notice that x is also automatically lifted to match.
def liftTest2 (x : Except String Float) : ReaderT String (StateT Nat (Except String)) Float := x

#eval liftTest2 (divide 5 1) |>.run "" |>.run 3 -- Except.ok (5.000000, 3)
The ReaderT monadLift is even simpler than the one for StateT:
instance : MonadLift m (ReaderT ρ m) where
  monadLift x := fun _ => x
This lift operation creates a function that accepts the required ReaderT input argument, but the inner monad doesn't know or care about ReaderT, so the monadLift function throws it away with the _ and then calls the inner monad action x. This is a perfectly legal ReaderT computation, one that simply ignores its context.
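The transitive lift can be spelled out by hand as two single-layer hops. The sketch below assumes StateT.lift as shown earlier and uses the hypothetical name liftTwice. It is meant to mirror what MonadLiftT arranges automatically, not to reproduce the library's actual instance chain.

```lean
def liftTwice (x : Except String Float) :
    ReaderT String (StateT Nat (Except String)) Float :=
  -- first hop: Except String into StateT Nat (Except String)
  let inner : StateT Nat (Except String) Float := StateT.lift x
  -- second hop: ignore the reader context and run the inner action
  fun _ => inner

#eval liftTwice (.ok 5.0) |>.run "" |>.run 3
-- expected: Except.ok (5.000000, 3)
```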
Add your own Custom MonadLift
This does not compile:
def main2 : IO Unit := do
  try
    let ret ← divideCounter 5 2 |>.run 0
    IO.println (toString ret)
  catch e =>
    IO.println e
saying:
typeclass instance problem is stuck, it is often due to metavariables
ToString ?m.4786
The reason is divideCounter returns the big StateT Nat (ExceptT String Id) Float and that type cannot be automatically lifted into the main return type of IO Unit unless you give it some help.
The following custom MonadLift solves this problem:
def liftIO (t : ExceptT String Id α) : IO α := do
  match t with
  | .ok r => EStateM.Result.ok r
  | .error s => EStateM.Result.error s

instance : MonadLift (ExceptT String Id) IO where
  monadLift := liftIO

def main3 : IO Unit := do
  try
    let ret ← divideCounter 5 2 |>.run 0
    IO.println (toString ret)
  catch e =>
    IO.println e

#eval main3 -- (2.500000, 1)
It turns out that the IO monad you see in your main function is based on the EStateM.Result type, which is similar to the Except type but has an additional return value. The liftIO function converts any Except String α into IO α by simply mapping the ok case of the Except to Result.ok and the error case to Result.error.
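If you would rather not mention EStateM.Result at all, the same mapping can be phrased with pure and throw. The following is a sketch under the hypothetical name exceptToIO. It turns the error string into an IO.userError, which is an assumption about how you want the failure reported.

```lean
-- Hypothetical alternative to liftIO, written with pure and throw.
def exceptToIO {α : Type} (t : Except String α) : IO α :=
  match t with
  | Except.ok r    => pure r                   -- success becomes a normal IO result
  | Except.error s => throw (IO.userError s)   -- failure becomes an IO.Error

#eval exceptToIO (Except.ok 42 : Except String Nat)
-- expected: 42
```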
Lifting ExceptT
In the previous Except section you saw functions that throw Except values. When you get all the way back up to your main function, which has type IO Unit, you have the same problem you had above, because Except String Float doesn't match even if you use a try/catch.
def main4 : IO Unit := do
  try
    let ret ← divide 5 0
    IO.println (toString ret) -- lifting happens here.
  catch e =>
    IO.println s!"Unhandled exception: {e}"

#eval main4 -- Unhandled exception: can't divide by zero
Without the liftIO instance, the (toString ret) expression would fail to compile with a similar error:
typeclass instance problem is stuck, it is often due to metavariables
ToString ?m.6007
So the general lesson is that if you see an error like this when using monads, check for a missing MonadLift.
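When you only need this at one or two call sites, an explicit conversion may be simpler than declaring an instance. The sketch below uses IO.ofExcept from Lean's core library, which converts an Except ε α into IO α when ε has a ToString instance. The names safeDiv and mainOfExcept are hypothetical.

```lean
-- Hypothetical stand-in for divide, so that this snippet is self-contained.
def safeDiv (x y : Float) : Except String Float :=
  if y == 0 then throw "can't divide by zero" else pure (x / y)

def mainOfExcept : IO Unit := do
  try
    -- explicit conversion instead of relying on a MonadLift instance
    let ret ← IO.ofExcept (safeDiv 5 0)
    IO.println (toString ret)
  catch e =>
    IO.println s!"Unhandled exception: {e}"

#eval mainOfExcept
-- expected: Unhandled exception: can't divide by zero
```

The trade-off is that every call site has to mention the conversion, whereas the MonadLift instance lets Lean insert it for you.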
Summary
Now that you know how to combine your monads together, you're almost done with understanding the key concepts of monads! You could probably go out now and start writing some pretty nice code! But to truly master monads, you should know how to make your own, and there's one final concept that you should understand for that. This is the idea of type "laws". Each of the structures you've learned so far has a series of laws associated with it. And for your instances of these classes to make sense, they should follow the laws! Check out Monad Laws.