2025-04-21 21:14:23 +08:00

184 lines
5.5 KiB
C#

using Cysharp.Threading.Tasks.Internal;
using System;
using System.Collections.Generic;
using System.Runtime.ExceptionServices;
using System.Threading;
namespace Cysharp.Threading.Tasks
{
public partial struct UniTask
{
    /// <summary>
    /// Creates an async sequence that yields one <see cref="WhenEachResult{T}"/> for every task in
    /// <paramref name="tasks"/>, in the order the tasks finish. A faulted task produces a result
    /// carrying its exception instead of terminating the sequence.
    /// </summary>
    /// <param name="tasks">The tasks to observe. The sequence is read when enumeration starts.</param>
    public static IUniTaskAsyncEnumerable<WhenEachResult<T>> WhenEach<T>(IEnumerable<UniTask<T>> tasks)
        => new WhenEachEnumerable<T>(tasks);

    /// <summary>
    /// Array/params overload of <see cref="WhenEach{T}(IEnumerable{UniTask{T}})"/>.
    /// </summary>
    /// <param name="tasks">The tasks to observe.</param>
    public static IUniTaskAsyncEnumerable<WhenEachResult<T>> WhenEach<T>(params UniTask<T>[] tasks)
        => new WhenEachEnumerable<T>(tasks);
}
/// <summary>
/// Outcome of one task observed by <c>UniTask.WhenEach</c>: either the task's result value or the
/// exception it threw. A default-constructed instance counts as successful with a default result.
/// </summary>
public readonly struct WhenEachResult<T>
{
    /// <summary>The task's result; <c>default</c> when the task faulted.</summary>
    public T Result { get; }

    /// <summary>The exception the task threw, or <c>null</c> when it succeeded.</summary>
    public Exception Exception { get; }

    //[MemberNotNullWhen(false, nameof(Exception))]
    /// <summary>True when no exception was recorded.</summary>
    public bool IsCompletedSuccessfully => Exception is null;

    //[MemberNotNullWhen(true, nameof(Exception))]
    /// <summary>True when an exception was recorded.</summary>
    public bool IsFaulted => !IsCompletedSuccessfully;

    /// <summary>Creates a successful outcome holding <paramref name="result"/>.</summary>
    public WhenEachResult(T result)
    {
        Result = result;
        Exception = null;
    }

    /// <summary>Creates a faulted outcome holding <paramref name="exception"/>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="exception"/> is null.</exception>
    public WhenEachResult(Exception exception)
    {
        if (exception is null)
        {
            throw new ArgumentNullException(nameof(exception));
        }

        Result = default;
        Exception = exception;
    }

    /// <summary>
    /// Rethrows the recorded exception (preserving its original stack trace) when faulted;
    /// does nothing on success.
    /// </summary>
    public void TryThrow()
    {
        if (IsCompletedSuccessfully)
        {
            return;
        }

        ExceptionDispatchInfo.Capture(Exception).Throw();
    }

    /// <summary>
    /// Returns <see cref="Result"/> on success; rethrows the recorded exception when faulted.
    /// </summary>
    public T GetResult()
    {
        if (IsCompletedSuccessfully)
        {
            return Result;
        }

        ExceptionDispatchInfo.Capture(Exception).Throw();
        return default; // not reached: Throw() always throws
    }

    /// <summary>
    /// Formats the result value (empty string for a null result), or <c>Exception{message}</c> when faulted.
    /// </summary>
    public override string ToString()
    {
        return IsFaulted
            ? $"Exception{{{Exception.Message}}}"
            : Result?.ToString() ?? "";
    }
}
/// <summary>
/// Lifecycle of a <c>WhenEachEnumerable</c> enumerator.
/// <see cref="NotRunning"/> must stay the zero value: the enumerator's state field is never
/// initialized explicitly and relies on <c>default</c> meaning "not started".
/// </summary>
internal enum WhenEachState : byte
{
    /// <summary>MoveNextAsync has not yet started consuming the source tasks.</summary>
    NotRunning = 0,

    /// <summary>Source tasks have been started and the channel is still open.</summary>
    Running = 1,

    /// <summary>The channel has been completed (all tasks finished, or the enumerator was disposed).</summary>
    Completed = 2,
}
/// <summary>
/// Async sequence behind <c>UniTask.WhenEach</c>. On the first <c>MoveNextAsync</c> call, each
/// enumerator awaits every source task and yields one <see cref="WhenEachResult{T}"/> per task,
/// in the order the tasks finish.
/// </summary>
internal sealed class WhenEachEnumerable<T> : IUniTaskAsyncEnumerable<WhenEachResult<T>>
{
    readonly IEnumerable<UniTask<T>> source;

    public WhenEachEnumerable(IEnumerable<UniTask<T>> source)
    {
        this.source = source;
    }

    public IUniTaskAsyncEnumerator<WhenEachResult<T>> GetAsyncEnumerator(CancellationToken cancellationToken = default)
    {
        return new Enumerator(source, cancellationToken);
    }

    sealed class Enumerator : IUniTaskAsyncEnumerator<WhenEachResult<T>>
    {
        readonly IEnumerable<UniTask<T>> source;
        readonly CancellationToken cancellationToken;

        // Created lazily by the first MoveNextAsync; null until then.
        Channel<WhenEachResult<T>> channel;
        IUniTaskAsyncEnumerator<WhenEachResult<T>> channelEnumerator;

        // Number of source tasks that have finished (success or fault); updated with Interlocked.
        int completeCount;

        // Starts as default(WhenEachState) == NotRunning.
        WhenEachState state;

        public Enumerator(IEnumerable<UniTask<T>> source, CancellationToken cancellationToken)
        {
            this.source = source;
            this.cancellationToken = cancellationToken;
        }

        public WhenEachResult<T> Current => channelEnumerator.Current;

        /// <summary>
        /// On the first call, creates the result channel and starts awaiting every source task;
        /// then advances to the next result written to the channel.
        /// </summary>
        /// <exception cref="OperationCanceledException">The enumerator's token is cancelled.</exception>
        public UniTask<bool> MoveNextAsync()
        {
            cancellationToken.ThrowIfCancellationRequested();

            if (state == WhenEachState.NotRunning)
            {
                state = WhenEachState.Running;
                channel = Channel.CreateSingleConsumerUnbounded<WhenEachResult<T>>();
                channelEnumerator = channel.Reader.ReadAllAsync().GetAsyncEnumerator(cancellationToken);

                if (source is UniTask<T>[] array)
                {
                    ConsumeAll(this, array, array.Length);
                }
                else
                {
                    // The rented array is returned when this using block ends; ConsumeAll has
                    // already copied each UniTask into its own RunWhenEachTask call by then.
                    using (var rentArray = ArrayPoolUtil.Materialize(source))
                    {
                        ConsumeAll(this, rentArray.Array, rentArray.Length);
                    }
                }
            }

            return channelEnumerator.MoveNextAsync();
        }

        /// <summary>
        /// Starts one observer per task in <paramref name="array"/>[0..<paramref name="length"/>).
        /// With no tasks, completes the channel immediately so enumeration ends instead of hanging.
        /// </summary>
        static void ConsumeAll(Enumerator self, UniTask<T>[] array, int length)
        {
            if (length == 0)
            {
                // No RunWhenEachTask would ever reach the completion counter, so nothing else
                // would complete the channel and MoveNextAsync would wait forever.
                self.state = WhenEachState.Completed;
                self.channel.Writer.TryComplete();
                return;
            }

            for (int i = 0; i < length; i++)
            {
                RunWhenEachTask(self, array[i], length).Forget();
            }
        }

        /// <summary>
        /// Awaits <paramref name="task"/> and writes its result or exception to the channel; the
        /// observer that finishes last (per <paramref name="length"/>) completes the channel.
        /// </summary>
        static async UniTaskVoid RunWhenEachTask(Enumerator self, UniTask<T> task, int length)
        {
            try
            {
                var result = await task;
                self.channel.Writer.TryWrite(new WhenEachResult<T>(result));
            }
            catch (Exception ex)
            {
                self.channel.Writer.TryWrite(new WhenEachResult<T>(ex));
            }

            if (Interlocked.Increment(ref self.completeCount) == length)
            {
                self.state = WhenEachState.Completed;
                self.channel.Writer.TryComplete();
            }
        }

        /// <summary>
        /// Disposes the channel reader and, if the channel is still open, completes it with an
        /// <see cref="OperationCanceledException"/>.
        /// </summary>
        public async UniTask DisposeAsync()
        {
            if (channelEnumerator != null)
            {
                await channelEnumerator.DisposeAsync();
            }

            // channel is null when MoveNextAsync never got past its cancellation check (e.g. an
            // already-cancelled token); dereferencing it would throw NullReferenceException and
            // mask the original OperationCanceledException inside `await foreach`.
            if (channel != null && state != WhenEachState.Completed)
            {
                state = WhenEachState.Completed;
                channel.Writer.TryComplete(new OperationCanceledException());
            }
        }
    }
}
}