WhenAll and WhenAny

This commit is contained in:
neuecc
2020-04-21 13:36:23 +09:00
parent 082f3e7335
commit 3654a9e2f9
16 changed files with 11143 additions and 3797 deletions

View File

@@ -3,8 +3,6 @@
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Runtime.ExceptionServices;
using System.Threading;
using UniRx.Async.Internal;
@@ -12,285 +10,214 @@ namespace UniRx.Async
{
public partial struct UniTask
{
// UniTask
public static async UniTask<T[]> WhenAll<T>(params UniTask<T>[] tasks)
public static UniTask<T[]> WhenAll<T>(params UniTask<T>[] tasks)
{
return await new WhenAllPromise<T>(tasks, tasks.Length);
return new UniTask<T[]>(new WhenAllPromise<T>(tasks, tasks.Length), 0);
}
public static async UniTask<T[]> WhenAll<T>(IEnumerable<UniTask<T>> tasks)
public static UniTask<T[]> WhenAll<T>(IEnumerable<UniTask<T>> tasks)
{
WhenAllPromise<T> promise;
using (var span = ArrayPoolUtil.Materialize(tasks))
{
promise = new WhenAllPromise<T>(span.Array, span.Length);
var promise = new WhenAllPromise<T>(span.Array, span.Length); // consumed array in constructor.
return new UniTask<T[]>(promise, 0);
}
return await promise;
}
public static async UniTask WhenAll(params UniTask[] tasks)
public static UniTask WhenAll(params UniTask[] tasks)
{
await new WhenAllPromise(tasks, tasks.Length);
return new UniTask(new WhenAllPromise(tasks, tasks.Length), 0);
}
public static async UniTask WhenAll(IEnumerable<UniTask> tasks)
public static UniTask WhenAll(IEnumerable<UniTask> tasks)
{
WhenAllPromise promise;
using (var span = ArrayPoolUtil.Materialize(tasks))
{
promise = new WhenAllPromise(span.Array, span.Length);
var promise = new WhenAllPromise(span.Array, span.Length); // consumed array in constructor.
return new UniTask(promise, 0);
}
await promise;
}
class WhenAllPromise<T>
sealed class WhenAllPromise<T> : IUniTaskSource<T[]>
{
readonly T[] result;
T[] result;
int completeCount;
Action whenComplete;
ExceptionDispatchInfo exception;
UniTaskCompletionSourceCore<T[]> core; // don't reset(called after GetResult, will invoke TrySetException.)
public WhenAllPromise(UniTask<T>[] tasks, int tasksLength)
{
TaskTracker2.TrackActiveTask(this, 3);
this.completeCount = 0;
this.whenComplete = null;
this.exception = null;
this.result = new T[tasksLength];
for (int i = 0; i < tasksLength; i++)
{
if (tasks[i].Status.IsCompleted())
UniTask<T>.Awaiter awaiter;
try
{
T value = default(T);
try
{
value = tasks[i].GetAwaiter().GetResult();
}
catch (Exception ex)
{
exception = ExceptionDispatchInfo.Capture(ex);
TryCallContinuation();
continue;
}
awaiter = tasks[i].GetAwaiter();
}
catch (Exception ex)
{
core.TrySetException(ex);
continue;
}
result[i] = value;
var count = Interlocked.Increment(ref completeCount);
if (count == result.Length)
{
TryCallContinuation();
}
if (awaiter.IsCompleted)
{
TryInvokeContinuation(this, awaiter, i);
}
else
{
RunTask(tasks[i], i).Forget();
awaiter.SourceOnCompleted(state =>
{
using (var t = (StateTuple<WhenAllPromise<T>, UniTask<T>.Awaiter, int>)state)
{
TryInvokeContinuation(t.Item1, t.Item2, t.Item3);
}
}, StateTuple.Create(this, awaiter, i));
}
}
}
void TryCallContinuation()
static void TryInvokeContinuation(WhenAllPromise<T> self, in UniTask<T>.Awaiter awaiter, int i)
{
var action = Interlocked.Exchange(ref whenComplete, null);
if (action != null)
{
action.Invoke();
}
}
async UniTaskVoid RunTask(UniTask<T> task, int index)
{
T value = default(T);
try
{
value = await task;
self.result[i] = awaiter.GetResult();
}
catch (Exception ex)
{
exception = ExceptionDispatchInfo.Capture(ex);
TryCallContinuation();
self.core.TrySetException(ex);
return;
}
result[index] = value;
var count = Interlocked.Increment(ref completeCount);
if (count == result.Length)
if (Interlocked.Increment(ref self.completeCount) == self.result.Length)
{
TryCallContinuation();
self.core.TrySetResult(self.result);
}
}
public Awaiter GetAwaiter()
public T[] GetResult(short token)
{
return new Awaiter(this);
TaskTracker2.RemoveTracking(this);
GC.SuppressFinalize(this);
return core.GetResult(token);
}
public struct Awaiter : ICriticalNotifyCompletion
void IUniTaskSource.GetResult(short token)
{
WhenAllPromise<T> parent;
GetResult(token);
}
public Awaiter(WhenAllPromise<T> parent)
{
this.parent = parent;
}
public UniTaskStatus GetStatus(short token)
{
return core.GetStatus(token);
}
public bool IsCompleted
{
get
{
return parent.exception != null || parent.result.Length == parent.completeCount;
}
}
public UniTaskStatus UnsafeGetStatus()
{
return core.UnsafeGetStatus();
}
public T[] GetResult()
{
if (parent.exception != null)
{
parent.exception.Throw();
}
public void OnCompleted(Action<object> continuation, object state, short token)
{
core.OnCompleted(continuation, state, token);
}
return parent.result;
}
public void OnCompleted(Action continuation)
{
UnsafeOnCompleted(continuation);
}
public void UnsafeOnCompleted(Action continuation)
{
parent.whenComplete = continuation;
if (IsCompleted)
{
var action = Interlocked.Exchange(ref parent.whenComplete, null);
if (action != null)
{
action();
}
}
}
~WhenAllPromise()
{
core.Reset();
}
}
class WhenAllPromise
sealed class WhenAllPromise : IUniTaskSource
{
int completeCount;
int resultLength;
Action whenComplete;
ExceptionDispatchInfo exception;
int tasksLength;
UniTaskCompletionSourceCore<AsyncUnit> core; // don't reset(called after GetResult, will invoke TrySetException.)
public WhenAllPromise(UniTask[] tasks, int tasksLength)
{
TaskTracker2.TrackActiveTask(this, 3);
this.tasksLength = tasksLength;
this.completeCount = 0;
this.whenComplete = null;
this.exception = null;
this.resultLength = tasksLength;
for (int i = 0; i < tasksLength; i++)
{
if (tasks[i].Status.IsCompleted())
UniTask.Awaiter awaiter;
try
{
try
{
tasks[i].GetAwaiter().GetResult();
}
catch (Exception ex)
{
exception = ExceptionDispatchInfo.Capture(ex);
TryCallContinuation();
continue;
}
awaiter = tasks[i].GetAwaiter();
}
catch (Exception ex)
{
core.TrySetException(ex);
continue;
}
var count = Interlocked.Increment(ref completeCount);
if (count == resultLength)
{
TryCallContinuation();
}
if (awaiter.IsCompleted)
{
TryInvokeContinuation(this, awaiter);
}
else
{
RunTask(tasks[i], i).Forget();
awaiter.SourceOnCompleted(state =>
{
using (var t = (StateTuple<WhenAllPromise, UniTask.Awaiter>)state)
{
TryInvokeContinuation(t.Item1, t.Item2);
}
}, StateTuple.Create(this, awaiter));
}
}
}
void TryCallContinuation()
{
var action = Interlocked.Exchange(ref whenComplete, null);
if (action != null)
{
action.Invoke();
}
}
async UniTaskVoid RunTask(UniTask task, int index)
static void TryInvokeContinuation(WhenAllPromise self, in UniTask.Awaiter awaiter)
{
try
{
await task;
awaiter.GetResult();
}
catch (Exception ex)
{
exception = ExceptionDispatchInfo.Capture(ex);
TryCallContinuation();
self.core.TrySetException(ex);
return;
}
var count = Interlocked.Increment(ref completeCount);
if (count == resultLength)
if (Interlocked.Increment(ref self.completeCount) == self.tasksLength)
{
TryCallContinuation();
self.core.TrySetResult(AsyncUnit.Default);
}
}
public Awaiter GetAwaiter()
public void GetResult(short token)
{
return new Awaiter(this);
TaskTracker2.RemoveTracking(this);
GC.SuppressFinalize(this);
core.GetResult(token);
}
public struct Awaiter : ICriticalNotifyCompletion
public UniTaskStatus GetStatus(short token)
{
WhenAllPromise parent;
return core.GetStatus(token);
}
public Awaiter(WhenAllPromise parent)
{
this.parent = parent;
}
public UniTaskStatus UnsafeGetStatus()
{
return core.UnsafeGetStatus();
}
public bool IsCompleted
{
get
{
return parent.exception != null || parent.resultLength == parent.completeCount;
}
}
public void OnCompleted(Action<object> continuation, object state, short token)
{
core.OnCompleted(continuation, state, token);
}
public void GetResult()
{
if (parent.exception != null)
{
parent.exception.Throw();
}
}
public void OnCompleted(Action continuation)
{
UnsafeOnCompleted(continuation);
}
public void UnsafeOnCompleted(Action continuation)
{
parent.whenComplete = continuation;
if (IsCompleted)
{
var action = Interlocked.Exchange(ref parent.whenComplete, null);
if (action != null)
{
action();
}
}
}
~WhenAllPromise()
{
core.Reset();
}
}
}