Skip to content
Merged
97 changes: 55 additions & 42 deletions src/FSharp.Control.AsyncSeq/AsyncSeq.fs
Original file line number Diff line number Diff line change
Expand Up @@ -591,49 +591,62 @@ module AsyncSeq =
| HaveInnerEnumerator of IAsyncEnumerator<'T> * IAsyncEnumerator<'U>
| Finished

// Optimized collect implementation using direct field access instead of ref cells
type OptimizedCollectEnumerator<'T, 'U>(f: 'T -> AsyncSeq<'U>, inp: AsyncSeq<'T>) =
// Mutable fields instead of ref cells to reduce allocations
let mutable inputEnumerator: IAsyncEnumerator<'T> option = None
let mutable innerEnumerator: IAsyncEnumerator<'U> option = None
let mutable disposed = false

// Tail-recursive optimization to avoid deep continuation chains
let rec moveNextLoop () : Async<'U option> = async {
if disposed then return None
else
match innerEnumerator with
| Some inner ->
let! result = inner.MoveNext()
match result with
| Some value -> return Some value
| None ->
inner.Dispose()
innerEnumerator <- None
return! moveNextLoop ()
| None ->
match inputEnumerator with
| Some outer ->
let! result = outer.MoveNext()
match result with
| Some value ->
let newInner = (f value).GetEnumerator()
innerEnumerator <- Some newInner
return! moveNextLoop ()
| None ->
outer.Dispose()
inputEnumerator <- None
disposed <- true
return None
| None ->
let newOuter = inp.GetEnumerator()
inputEnumerator <- Some newOuter
return! moveNextLoop ()
}

interface IAsyncEnumerator<'U> with
member _.MoveNext() = moveNextLoop ()
member _.Dispose() =
if not disposed then
disposed <- true
match innerEnumerator with
| Some inner -> inner.Dispose(); innerEnumerator <- None
| None -> ()
match inputEnumerator with
| Some outer -> outer.Dispose(); inputEnumerator <- None
| None -> ()

let collect (f: 'T -> AsyncSeq<'U>) (inp: AsyncSeq<'T>) : AsyncSeq<'U> =
{ new IAsyncEnumerable<'U> with
member x.GetEnumerator() =
let state = ref (CollectState.NotStarted inp)
{ new IAsyncEnumerator<'U> with
member x.MoveNext() =
async { match !state with
| CollectState.NotStarted inp ->
return!
(let e1 = inp.GetEnumerator()
state := CollectState.HaveInputEnumerator e1
x.MoveNext())
| CollectState.HaveInputEnumerator e1 ->
let! res1 = e1.MoveNext()
return!
(match res1 with
| Some v1 ->
let e2 = (f v1).GetEnumerator()
state := CollectState.HaveInnerEnumerator (e1, e2)
| None ->
x.Dispose()
x.MoveNext())
| CollectState.HaveInnerEnumerator (e1, e2) ->
let! res2 = e2.MoveNext()
match res2 with
| None ->
state := CollectState.HaveInputEnumerator e1
dispose e2
return! x.MoveNext()
| Some _ ->
return res2
| _ ->
return None }
member x.Dispose() =
match !state with
| CollectState.HaveInputEnumerator e1 ->
state := CollectState.Finished
dispose e1
| CollectState.HaveInnerEnumerator (e1, e2) ->
state := CollectState.Finished
dispose e2
dispose e1
| _ -> () } }
{ new IAsyncEnumerable<'U> with
member _.GetEnumerator() =
new OptimizedCollectEnumerator<'T, 'U>(f, inp) :> IAsyncEnumerator<'U> }

// let collect (f: 'T -> AsyncSeq<'U>) (inp: AsyncSeq<'T>) : AsyncSeq<'U> =
// AsyncGenerator.collect f inp
Expand Down