using Cysharp.Threading.Tasks.Internal;
using System;
using System.Collections.Generic;
using System.Runtime.ExceptionServices;
using System.Threading;

namespace Cysharp.Threading.Tasks
{
    public partial struct UniTask
    {
        /// <summary>
        /// Wraps <paramref name="tasks"/> in an async sequence of <see cref="WhenEachResult{T}"/>.
        /// The tasks are not touched until the first <c>MoveNextAsync</c> call on an enumerator;
        /// from then on one element is published for every task as it finishes, carrying either
        /// its result or the exception it threw.
        /// </summary>
        /// <param name="tasks">The tasks to observe.</param>
        /// <returns>An async sequence with one element per task.</returns>
        public static IUniTaskAsyncEnumerable<WhenEachResult<T>> WhenEach<T>(IEnumerable<UniTask<T>> tasks)
        {
            return new WhenEachEnumerable<T>(tasks);
        }

        /// <summary>
        /// Array / <c>params</c> overload of <c>WhenEach</c>; behaves like the
        /// <see cref="IEnumerable{T}"/> overload.
        /// </summary>
        /// <param name="tasks">The tasks to observe.</param>
        /// <returns>An async sequence with one element per task.</returns>
        public static IUniTaskAsyncEnumerable<WhenEachResult<T>> WhenEach<T>(params UniTask<T>[] tasks)
        {
            return new WhenEachEnumerable<T>(tasks);
        }
    }

    /// <summary>
    /// Outcome of a single task observed through <c>UniTask.WhenEach</c>: either a result value
    /// or the exception that was caught while awaiting the task.
    /// </summary>
    /// <typeparam name="T">Result type of the observed task.</typeparam>
    public readonly struct WhenEachResult<T>
    {
        /// <summary>The task's result; <c>default</c> when the instance holds an exception.</summary>
        public T Result { get; }

        /// <summary>The captured exception, or <c>null</c> when the task succeeded.</summary>
        public Exception Exception { get; }

        //[MemberNotNullWhen(false, nameof(Exception))]
        /// <summary><c>true</c> when no exception is stored.</summary>
        public bool IsCompletedSuccessfully => Exception == null;

        //[MemberNotNullWhen(true, nameof(Exception))]
        /// <summary><c>true</c> when an exception is stored.</summary>
        public bool IsFaulted => Exception != null;

        /// <summary>Creates a successful result.</summary>
        /// <param name="result">The value produced by the task.</param>
        public WhenEachResult(T result)
        {
            this.Result = result;
            this.Exception = null;
        }

        /// <summary>Creates a faulted result.</summary>
        /// <param name="exception">The exception raised by the task.</param>
        /// <exception cref="ArgumentNullException"><paramref name="exception"/> is <c>null</c>.</exception>
        public WhenEachResult(Exception exception)
        {
            if (exception == null) throw new ArgumentNullException(nameof(exception));
            this.Result = default;
            this.Exception = exception;
        }

        /// <summary>
        /// Rethrows the stored exception (through <see cref="ExceptionDispatchInfo"/>) when the
        /// result is faulted; does nothing otherwise.
        /// </summary>
        public void TryThrow()
        {
            if (IsFaulted)
            {
                ExceptionDispatchInfo.Capture(Exception).Throw();
            }
        }

        /// <summary>
        /// Returns <see cref="Result"/>, or rethrows the stored exception when the result is faulted.
        /// </summary>
        /// <returns>The task's result value.</returns>
        public T GetResult()
        {
            if (IsFaulted)
            {
                ExceptionDispatchInfo.Capture(Exception).Throw();
            }
            return Result;
        }

        /// <summary>
        /// Formats the result value (empty string when it is <c>null</c>) or, when faulted,
        /// the exception message wrapped as <c>Exception{...}</c>.
        /// </summary>
        public override string ToString()
        {
            if (IsCompletedSuccessfully)
            {
                return Result?.ToString() ?? "";
            }
            else
            {
                return $"Exception{{{Exception.Message}}}";
            }
        }
    }

    /// <summary>Lifecycle of a <c>WhenEach</c> enumerator.</summary>
    internal enum WhenEachState : byte
    {
        NotRunning,
        Running,
        Completed
    }

    /// <summary>
    /// Async sequence returned by <c>UniTask.WhenEach</c>. Each enumerator keeps its own channel
    /// and starts observing the source tasks on its first <c>MoveNextAsync</c> call.
    /// </summary>
    internal sealed class WhenEachEnumerable<T> : IUniTaskAsyncEnumerable<WhenEachResult<T>>
    {
        IEnumerable<UniTask<T>> source;

        public WhenEachEnumerable(IEnumerable<UniTask<T>> source)
        {
            this.source = source;
        }

        public IUniTaskAsyncEnumerator<WhenEachResult<T>> GetAsyncEnumerator(CancellationToken cancellationToken = default)
        {
            return new Enumerator(source, cancellationToken);
        }

        sealed class Enumerator : IUniTaskAsyncEnumerator<WhenEachResult<T>>
        {
            readonly IEnumerable<UniTask<T>> source;
            CancellationToken cancellationToken;

            // Both stay null until the first MoveNextAsync call.
            Channel<WhenEachResult<T>> channel;
            IUniTaskAsyncEnumerator<WhenEachResult<T>> channelEnumerator;
            int completeCount;
            WhenEachState state;

            public Enumerator(IEnumerable<UniTask<T>> source, CancellationToken cancellationToken)
            {
                this.source = source;
                this.cancellationToken = cancellationToken;
            }

            public WhenEachResult<T> Current => channelEnumerator.Current;

            /// <summary>
            /// On the first call, creates the channel and starts observing every source task;
            /// every call then forwards to the channel reader's enumerator.
            /// </summary>
            /// <exception cref="OperationCanceledException">The enumerator's token is already cancelled.</exception>
            public UniTask<bool> MoveNextAsync()
            {
                cancellationToken.ThrowIfCancellationRequested();

                if (state == WhenEachState.NotRunning)
                {
                    state = WhenEachState.Running;
                    channel = Channel.CreateSingleConsumerUnbounded<WhenEachResult<T>>();
                    channelEnumerator = channel.Reader.ReadAllAsync().GetAsyncEnumerator(cancellationToken);

                    if (source is UniTask<T>[] array)
                    {
                        ConsumeAll(this, array, array.Length);
                    }
                    else
                    {
                        // ConsumeAll hands every task to RunWhenEachTask by value before it
                        // returns, so the rented buffer can be released right afterwards.
                        using (var rentArray = ArrayPoolUtil.Materialize(source))
                        {
                            ConsumeAll(this, rentArray.Array, rentArray.Length);
                        }
                    }
                }

                return channelEnumerator.MoveNextAsync();
            }

            /// <summary>Starts one fire-and-forget observer per task in <paramref name="array"/>.</summary>
            static void ConsumeAll(Enumerator self, UniTask<T>[] array, int length)
            {
                if (length == 0)
                {
                    // With no tasks, RunWhenEachTask never runs and its completion check is
                    // never reached; close the channel here so the reader is not left waiting
                    // for a write that cannot happen.
                    self.state = WhenEachState.Completed;
                    self.channel.Writer.TryComplete();
                    return;
                }

                for (int i = 0; i < length; i++)
                {
                    RunWhenEachTask(self, array[i], length).Forget();
                }
            }

            /// <summary>
            /// Awaits one task, publishes its outcome to the channel, and completes the channel
            /// once <paramref name="length"/> outcomes have been counted.
            /// </summary>
            static async UniTaskVoid RunWhenEachTask(Enumerator self, UniTask<T> task, int length)
            {
                try
                {
                    var result = await task;
                    self.channel.Writer.TryWrite(new WhenEachResult<T>(result));
                }
                catch (Exception ex)
                {
                    // Failures are delivered to the consumer as faulted elements.
                    self.channel.Writer.TryWrite(new WhenEachResult<T>(ex));
                }

                if (Interlocked.Increment(ref self.completeCount) == length)
                {
                    self.state = WhenEachState.Completed;
                    self.channel.Writer.TryComplete();
                }
            }

            /// <summary>
            /// Disposes the channel enumerator (if one was created) and, when the tasks have not
            /// all finished, completes the channel with an <see cref="OperationCanceledException"/>.
            /// </summary>
            public async UniTask DisposeAsync()
            {
                if (channelEnumerator != null)
                {
                    await channelEnumerator.DisposeAsync();
                }

                if (state != WhenEachState.Completed)
                {
                    state = WhenEachState.Completed;
                    // channel is still null when MoveNextAsync was never called.
                    channel?.Writer.TryComplete(new OperationCanceledException());
                }
            }
        }
    }
}