Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 34 additions & 18 deletions src/FSharp.Control.AsyncSeq/AsyncSeq.fs
Original file line number Diff line number Diff line change
Expand Up @@ -297,20 +297,20 @@ module AsyncSeqOp =
/// Enumerator that produces elements by repeatedly applying the generator `f`
/// to a state value, starting from `init`.
///
/// `f` returns `Some (value, nextState)` to emit `value` and continue from
/// `nextState`, or `None` to signal the end of the sequence.
type OptimizedUnfoldEnumerator<'S, 'T> (f:'S -> Async<('T * 'S) option>, init:'S) =
    // State handed to the next invocation of `f`.
    let mutable currentState = init
    // Set by Dispose; checked at the start of every MoveNext call.
    let mutable disposed = false

    interface IAsyncEnumerator<'T> with
        /// Advances the enumerator. Returns `None` once `f` returns `None`,
        /// or immediately (without calling `f`) if the enumerator was disposed
        /// before MoveNext was called.
        member __.MoveNext () : Async<'T option> =
            // The disposed check runs when MoveNext is called, not when the
            // returned computation is started.
            if disposed then async.Return None
            else async {
                let! result = f currentState
                match result with
                | None ->
                    return None
                | Some (value, nextState) ->
                    // Remember where to resume before handing the value out.
                    currentState <- nextState
                    return Some value
            }

        /// Marks the enumerator as disposed; later MoveNext calls return `None`.
        member __.Dispose () =
            disposed <- true

type UnfoldAsyncEnumerator<'S, 'T> (f:'S -> Async<('T * 'S) option>, init:'S) =
Expand Down Expand Up @@ -606,10 +606,10 @@ module AsyncSeq =
// Optimized collect implementation using direct field access instead of ref cells
type OptimizedCollectEnumerator<'T, 'U>(f: 'T -> AsyncSeq<'U>, inp: AsyncSeq<'T>) =
// Mutable fields instead of ref cells to reduce allocations
let mutable inputEnumerator: IAsyncEnumerator<'T> option = None
let mutable inputEnumerator: IAsyncEnumerator<'T> option = None
let mutable innerEnumerator: IAsyncEnumerator<'U> option = None
let mutable disposed = false

// Tail-recursive optimization to avoid deep continuation chains
let rec moveNextLoop () : Async<'U option> = async {
if disposed then return None
Expand Down Expand Up @@ -642,7 +642,7 @@ module AsyncSeq =
inputEnumerator <- Some newOuter
return! moveNextLoop ()
}

interface IAsyncEnumerator<'U> with
member _.MoveNext() = moveNextLoop ()
member _.Dispose() =
Expand All @@ -651,13 +651,13 @@ module AsyncSeq =
match innerEnumerator with
| Some inner -> inner.Dispose(); innerEnumerator <- None
| None -> ()
match inputEnumerator with
match inputEnumerator with
| Some outer -> outer.Dispose(); inputEnumerator <- None
| None -> ()

/// Returns a sequence whose enumerators are `OptimizedCollectEnumerator`
/// instances built from `f` and `inp`. A new enumerator is created on every
/// call to GetEnumerator.
// NOTE(review): presumably this flattens the inner sequences produced by `f`
// in order — confirm against OptimizedCollectEnumerator.
let collect (f: 'T -> AsyncSeq<'U>) (inp: AsyncSeq<'T>) : AsyncSeq<'U> =
    { new IAsyncEnumerable<'U> with
        member _.GetEnumerator() =
            new OptimizedCollectEnumerator<'T, 'U>(f, inp) :> IAsyncEnumerator<'U> }

// let collect (f: 'T -> AsyncSeq<'U>) (inp: AsyncSeq<'T>) : AsyncSeq<'U> =
Expand Down Expand Up @@ -749,7 +749,7 @@ module AsyncSeq =
// Optimized iterAsync implementation to reduce allocations
type internal OptimizedIterAsyncEnumerator<'T>(enumerator: IAsyncEnumerator<'T>, f: 'T -> Async<unit>) =
let mutable disposed = false

member _.IterateAsync() =
let rec loop() = async {
let! next = enumerator.MoveNext()
Expand All @@ -760,17 +760,17 @@ module AsyncSeq =
| None -> return ()
}
loop()

interface IDisposable with
member _.Dispose() =
if not disposed then
disposed <- true
enumerator.Dispose()

// Optimized iteriAsync implementation with direct tail recursion
type internal OptimizedIteriAsyncEnumerator<'T>(enumerator: IAsyncEnumerator<'T>, f: int -> 'T -> Async<unit>) =
let mutable disposed = false

member _.IterateAsync() =
let rec loop count = async {
let! next = enumerator.MoveNext()
Expand All @@ -781,7 +781,7 @@ module AsyncSeq =
| None -> return ()
}
loop 0

interface IDisposable with
member _.Dispose() =
if not disposed then
Expand All @@ -798,7 +798,7 @@ module AsyncSeq =
let iterAsync (f: 'T -> Async<unit>) (source: AsyncSeq<'T>) =
match source with
| :? AsyncSeqOp<'T> as source -> source.IterAsync f
| _ ->
| _ ->
async {
let enum = source.GetEnumerator()
use optimizer = new OptimizedIterAsyncEnumerator<_>(enum, f)
Expand Down Expand Up @@ -864,7 +864,7 @@ module AsyncSeq =
// Optimized mapAsync enumerator that avoids computation builder overhead
type private OptimizedMapAsyncEnumerator<'T, 'TResult>(source: IAsyncEnumerator<'T>, f: 'T -> Async<'TResult>) =
let mutable disposed = false

interface IAsyncEnumerator<'TResult> with
member _.MoveNext() = async {
let! moveResult = source.MoveNext()
Expand All @@ -874,7 +874,7 @@ module AsyncSeq =
let! mapped = f value
return Some mapped
}

member _.Dispose() =
if not disposed then
disposed <- true
Expand All @@ -885,7 +885,7 @@ module AsyncSeq =
| :? AsyncSeqOp<'T> as source -> source.MapAsync f
| _ ->
{ new IAsyncEnumerable<'TResult> with
member _.GetEnumerator() =
member _.GetEnumerator() =
new OptimizedMapAsyncEnumerator<'T, 'TResult>(source.GetEnumerator(), f) :> IAsyncEnumerator<'TResult> }

let mapiAsync f (source : AsyncSeq<'T>) : AsyncSeq<'TResult> = asyncSeq {
Expand Down Expand Up @@ -1125,6 +1125,22 @@ module AsyncSeq =
/// Filters `source` with the synchronous predicate `f` by wrapping each
/// result of `f` in an already-completed async and delegating to `filterAsync`.
let filter f (source : AsyncSeq<'T>) =
    source |> filterAsync (fun item -> async.Return (f item))

/// Splits `source` into arrays of at most `chunkSize` consecutive elements.
///
/// Each emitted array holds `chunkSize` elements, except possibly the last,
/// which holds whatever elements remained when the source ended. An empty
/// array is never emitted.
///
/// Raises an argument exception (via `invalidArg`) when `chunkSize` is less
/// than 1; the check runs when `chunkBySize` is called, before enumeration.
let chunkBySize (chunkSize: int) (source: AsyncSeq<'T>) : AsyncSeq<'T array> =
    if chunkSize < 1 then
        invalidArg (nameof chunkSize) "must be greater than zero"
    asyncSeq {
        use enumerator = source.GetEnumerator()
        // A ref cell rather than `let mutable`: the flag is read in the loop
        // guards and written in the `match!` continuation, all of which the
        // builder turns into closures, and mutable locals cannot be captured
        // by closures (error FS0407).
        let isFinished = ref false
        while not isFinished.Value do
            // Fresh buffer per chunk, pre-sized to the requested chunk size.
            let chunk = ResizeArray<'T>(chunkSize)
            while chunk.Count < chunkSize && not isFinished.Value do
                match! enumerator.MoveNext() with
                | Some item -> chunk.Add(item)
                | None -> isFinished.Value <- true
            // The source may end exactly on a chunk boundary, leaving the
            // buffer empty; skip it so no empty chunk is emitted.
            if chunk.Count > 0 then
                yield chunk.ToArray()
    }

#if !FABLE_COMPILER
let iterAsyncParallel (f:'a -> Async<unit>) (s:AsyncSeq<'a>) : Async<unit> = async {
use mb = MailboxProcessor.Start (ignore >> async.Return)
Expand Down
5 changes: 4 additions & 1 deletion src/FSharp.Control.AsyncSeq/AsyncSeq.fsi
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,9 @@ module AsyncSeq =
/// and processes the input element immediately.
val filter : predicate:('T -> bool) -> source:AsyncSeq<'T> -> AsyncSeq<'T>

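/// Groups the elements of the source into arrays of `chunkSize` consecutive elements
/// and yields each array once it is full. The final array may contain fewer elements.
/// Raises an argument exception if `chunkSize` is less than 1.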
val chunkBySize : chunkSize: int -> source: AsyncSeq<'T> -> AsyncSeq<'T array>

/// Creates an asynchronous sequence that lazily takes elements from an
/// input synchronous sequence and returns them one by one.
val ofSeq : source:seq<'T> -> AsyncSeq<'T>
Expand Down Expand Up @@ -524,7 +527,7 @@ module AsyncSeq =
/// Builds a new asynchronous sequence whose elements are generated by
/// applying the specified function to all elements of the input sequence.
///
/// The function is applied to elements in parallel, and results are emitted
/// in the order they complete (unordered), without preserving the original order.
/// This can provide better performance than mapAsyncParallel when order doesn't matter.
/// Parallelism is bound by the ThreadPool.
Expand Down
17 changes: 17 additions & 0 deletions tests/FSharp.Control.AsyncSeq.Tests/AsyncSeqTests.fs
Original file line number Diff line number Diff line change
Expand Up @@ -889,6 +889,23 @@ let ``AsyncSeq.filter``() =
let expected = ls |> Seq.filter p |> AsyncSeq.ofSeq
Assert.True(EQ expected actual)

[<Test>]
let ``AsyncSeq.chunkBySize``() =
    // Five items in chunks of two: two full chunks, then one chunk holding
    // the single leftover item.
    let expected = [ [ "a"; "b" ]; [ "c"; "d" ]; [ "e" ] ]
    let letters = [ "a"; "b"; "c"; "d"; "e" ]
    let chunks =
        letters
        |> AsyncSeq.ofSeq
        |> AsyncSeq.chunkBySize 2
        |> AsyncSeq.toListSynchronously
    // Convert each emitted array to a list so it compares against `expected`.
    let actual = chunks |> List.map List.ofSeq
    Assert.AreEqual(expected, actual)

[<Test>]
let ``AsyncSeq.merge``() =
let ls1 = [1;2;3;4;5]
Expand Down