WhenAll and WhenAny

This commit is contained in:
neuecc
2020-04-21 13:36:23 +09:00
parent 082f3e7335
commit 3654a9e2f9
16 changed files with 11143 additions and 3797 deletions

View File

@@ -2,370 +2,365 @@
#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
using System;
using System.Runtime.CompilerServices;
using System.Runtime.ExceptionServices;
using System.Collections.Generic;
using System.Threading;
using UniRx.Async.Internal;
namespace UniRx.Async
{
public partial struct UniTask
{
// UniTask
public static async UniTask<(bool hasResultLeft, T0 result)> WhenAny<T0>(UniTask<T0> task0, UniTask task1)
public static UniTask<(bool hasResultLeft, T result)> WhenAny<T>(UniTask<T> leftTask, UniTask rightTask)
{
return await new UnitWhenAnyPromise<T0>(task0, task1);
return new UniTask<(bool, T)>(new WhenAnyLRPromise<T>(leftTask, rightTask), 0);
}
public static async UniTask<(int winArgumentIndex, T result)> WhenAny<T>(params UniTask<T>[] tasks)
public static UniTask<(int winArgumentIndex, T result)> WhenAny<T>(params UniTask<T>[] tasks)
{
return await new WhenAnyPromise<T>(tasks);
return new UniTask<(int, T)>(new WhenAnyPromise<T>(tasks, tasks.Length), 0);
}
public static UniTask<(int winArgumentIndex, T result)> WhenAny<T>(IEnumerable<UniTask<T>> tasks)
{
using (var span = ArrayPoolUtil.Materialize(tasks))
{
return new UniTask<(int, T)>(new WhenAnyPromise<T>(span.Array, span.Length), 0);
}
}
/// <summary>Return value is winArgumentIndex</summary>
public static async UniTask<int> WhenAny(params UniTask[] tasks)
public static UniTask<int> WhenAny(params UniTask[] tasks)
{
return await new WhenAnyPromise(tasks);
return new UniTask<int>(new WhenAnyPromise(tasks, tasks.Length), 0);
}
class UnitWhenAnyPromise<T0>
/// <summary>Return value is winArgumentIndex</summary>
public static UniTask<int> WhenAny(IEnumerable<UniTask> tasks)
{
T0 result0;
ExceptionDispatchInfo exception;
Action whenComplete;
int completeCount;
using (var span = ArrayPoolUtil.Materialize(tasks))
{
return new UniTask<int>(new WhenAnyPromise(span.Array, span.Length), 0);
}
}
sealed class WhenAnyLRPromise<T> : IUniTaskSource<(bool, T)>
{
int completedCount;
int winArgumentIndex;
UniTaskCompletionSourceCore<(bool, T)> core;
bool IsCompleted => exception != null || Volatile.Read(ref winArgumentIndex) != -1;
public UnitWhenAnyPromise(UniTask<T0> task0, UniTask task1)
public WhenAnyLRPromise(UniTask<T> leftTask, UniTask rightTask)
{
this.whenComplete = null;
this.exception = null;
this.completeCount = 0;
this.winArgumentIndex = -1;
this.result0 = default(T0);
TaskTracker2.TrackActiveTask(this, 3);
RunTask0(task0).Forget();
RunTask1(task1).Forget();
}
void TryCallContinuation()
{
var action = Interlocked.Exchange(ref whenComplete, null);
if (action != null)
{
action.Invoke();
}
}
async UniTaskVoid RunTask0(UniTask<T0> task)
{
T0 value;
try
{
value = await task;
}
catch (Exception ex)
{
exception = ExceptionDispatchInfo.Capture(ex);
TryCallContinuation();
return;
}
var count = Interlocked.Increment(ref completeCount);
if (count == 1)
{
result0 = value;
Volatile.Write(ref winArgumentIndex, 0);
TryCallContinuation();
}
}
async UniTaskVoid RunTask1(UniTask task)
{
try
{
await task;
}
catch (Exception ex)
{
exception = ExceptionDispatchInfo.Capture(ex);
TryCallContinuation();
return;
}
var count = Interlocked.Increment(ref completeCount);
if (count == 1)
{
Volatile.Write(ref winArgumentIndex, 1);
TryCallContinuation();
}
}
public Awaiter GetAwaiter()
{
return new Awaiter(this);
}
public struct Awaiter : ICriticalNotifyCompletion
{
UnitWhenAnyPromise<T0> parent;
public Awaiter(UnitWhenAnyPromise<T0> parent)
{
this.parent = parent;
}
public bool IsCompleted
{
get
UniTask<T>.Awaiter awaiter;
try
{
return parent.IsCompleted;
awaiter = leftTask.GetAwaiter();
}
}
public (bool, T0) GetResult()
{
if (parent.exception != null)
catch (Exception ex)
{
parent.exception.Throw();
core.TrySetException(ex);
goto RIGHT;
}
return (parent.winArgumentIndex == 0, parent.result0);
}
public void OnCompleted(Action continuation)
{
UnsafeOnCompleted(continuation);
}
public void UnsafeOnCompleted(Action continuation)
{
parent.whenComplete = continuation;
if (IsCompleted)
if (awaiter.IsCompleted)
{
var action = Interlocked.Exchange(ref parent.whenComplete, null);
if (action != null)
TryLeftInvokeContinuation(this, awaiter);
}
else
{
awaiter.SourceOnCompleted(state =>
{
action();
}
using (var t = (StateTuple<WhenAnyLRPromise<T>, UniTask<T>.Awaiter>)state)
{
TryLeftInvokeContinuation(t.Item1, t.Item2);
}
}, StateTuple.Create(this, awaiter));
}
}
RIGHT:
{
UniTask.Awaiter awaiter;
try
{
awaiter = rightTask.GetAwaiter();
}
catch (Exception ex)
{
core.TrySetException(ex);
return;
}
if (awaiter.IsCompleted)
{
TryRightInvokeContinuation(this, awaiter);
}
else
{
awaiter.SourceOnCompleted(state =>
{
using (var t = (StateTuple<WhenAnyLRPromise<T>, UniTask.Awaiter>)state)
{
TryRightInvokeContinuation(t.Item1, t.Item2);
}
}, StateTuple.Create(this, awaiter));
}
}
}
static void TryLeftInvokeContinuation(WhenAnyLRPromise<T> self, in UniTask<T>.Awaiter awaiter)
{
T result;
try
{
result = awaiter.GetResult();
}
catch (Exception ex)
{
self.core.TrySetException(ex);
return;
}
if (Interlocked.Increment(ref self.completedCount) == 1)
{
self.core.TrySetResult((true, result));
}
}
static void TryRightInvokeContinuation(WhenAnyLRPromise<T> self, in UniTask.Awaiter awaiter)
{
try
{
awaiter.GetResult();
}
catch (Exception ex)
{
self.core.TrySetException(ex);
return;
}
if (Interlocked.Increment(ref self.completedCount) == 1)
{
self.core.TrySetResult((false, default));
}
}
public (bool, T) GetResult(short token)
{
TaskTracker2.RemoveTracking(this);
GC.SuppressFinalize(this);
return core.GetResult(token);
}
public UniTaskStatus GetStatus(short token)
{
return core.GetStatus(token);
}
public void OnCompleted(Action<object> continuation, object state, short token)
{
core.OnCompleted(continuation, state, token);
}
public UniTaskStatus UnsafeGetStatus()
{
return core.UnsafeGetStatus();
}
void IUniTaskSource.GetResult(short token)
{
GetResult(token);
}
~WhenAnyLRPromise()
{
core.Reset();
}
}
class WhenAnyPromise<T>
sealed class WhenAnyPromise<T> : IUniTaskSource<(int, T)>
{
T result;
int completeCount;
int completedCount;
int winArgumentIndex;
Action whenComplete;
ExceptionDispatchInfo exception;
UniTaskCompletionSourceCore<(int, T)> core;
public bool IsComplete => exception != null || Volatile.Read(ref winArgumentIndex) != -1;
public WhenAnyPromise(UniTask<T>[] tasks)
public WhenAnyPromise(UniTask<T>[] tasks, int tasksLength)
{
this.completeCount = 0;
this.winArgumentIndex = -1;
this.whenComplete = null;
this.exception = null;
this.result = default(T);
TaskTracker2.TrackActiveTask(this, 3);
for (int i = 0; i < tasks.Length; i++)
for (int i = 0; i < tasksLength; i++)
{
RunTask(tasks[i], i).Forget();
UniTask<T>.Awaiter awaiter;
try
{
awaiter = tasks[i].GetAwaiter();
}
catch (Exception ex)
{
core.TrySetException(ex);
continue; // consume others.
}
if (awaiter.IsCompleted)
{
TryInvokeContinuation(this, awaiter, i);
}
else
{
awaiter.SourceOnCompleted(state =>
{
using (var t = (StateTuple<WhenAnyPromise<T>, UniTask<T>.Awaiter, int>)state)
{
TryInvokeContinuation(t.Item1, t.Item2, t.Item3);
}
}, StateTuple.Create(this, awaiter, i));
}
}
}
async UniTaskVoid RunTask(UniTask<T> task, int index)
static void TryInvokeContinuation(WhenAnyPromise<T> self, in UniTask<T>.Awaiter awaiter, int i)
{
T value;
T result;
try
{
value = await task;
result = awaiter.GetResult();
}
catch (Exception ex)
{
exception = ExceptionDispatchInfo.Capture(ex);
TryCallContinuation();
self.core.TrySetException(ex);
return;
}
var count = Interlocked.Increment(ref completeCount);
if (count == 1)
if (Interlocked.Increment(ref self.completedCount) == 1)
{
result = value;
Volatile.Write(ref winArgumentIndex, index);
TryCallContinuation();
self.core.TrySetResult((i, result));
}
}
void TryCallContinuation()
public (int, T) GetResult(short token)
{
var action = Interlocked.Exchange(ref whenComplete, null);
if (action != null)
{
action.Invoke();
}
TaskTracker2.RemoveTracking(this);
GC.SuppressFinalize(this);
return core.GetResult(token);
}
public Awaiter GetAwaiter()
public UniTaskStatus GetStatus(short token)
{
return new Awaiter(this);
return core.GetStatus(token);
}
public struct Awaiter : ICriticalNotifyCompletion
public void OnCompleted(Action<object> continuation, object state, short token)
{
WhenAnyPromise<T> parent;
core.OnCompleted(continuation, state, token);
}
public Awaiter(WhenAnyPromise<T> parent)
{
this.parent = parent;
}
public UniTaskStatus UnsafeGetStatus()
{
return core.UnsafeGetStatus();
}
public bool IsCompleted
{
get
{
return parent.IsComplete;
}
}
void IUniTaskSource.GetResult(short token)
{
GetResult(token);
}
public (int, T) GetResult()
{
if (parent.exception != null)
{
parent.exception.Throw();
}
return (parent.winArgumentIndex, parent.result);
}
public void OnCompleted(Action continuation)
{
UnsafeOnCompleted(continuation);
}
public void UnsafeOnCompleted(Action continuation)
{
parent.whenComplete = continuation;
if (IsCompleted)
{
var action = Interlocked.Exchange(ref parent.whenComplete, null);
if (action != null)
{
action();
}
}
}
~WhenAnyPromise()
{
core.Reset();
}
}
class WhenAnyPromise
sealed class WhenAnyPromise : IUniTaskSource<int>
{
int completeCount;
int completedCount;
int winArgumentIndex;
Action whenComplete;
ExceptionDispatchInfo exception;
UniTaskCompletionSourceCore<int> core;
public bool IsComplete => exception != null || Volatile.Read(ref winArgumentIndex) != -1;
public WhenAnyPromise(UniTask[] tasks)
public WhenAnyPromise(UniTask[] tasks, int tasksLength)
{
this.completeCount = 0;
this.winArgumentIndex = -1;
this.whenComplete = null;
this.exception = null;
TaskTracker2.TrackActiveTask(this, 3);
for (int i = 0; i < tasks.Length; i++)
for (int i = 0; i < tasksLength; i++)
{
RunTask(tasks[i], i).Forget();
UniTask.Awaiter awaiter;
try
{
awaiter = tasks[i].GetAwaiter();
}
catch (Exception ex)
{
core.TrySetException(ex);
continue; // consume others.
}
if (awaiter.IsCompleted)
{
TryInvokeContinuation(this, awaiter, i);
}
else
{
awaiter.SourceOnCompleted(state =>
{
using (var t = (StateTuple<WhenAnyPromise, UniTask.Awaiter, int>)state)
{
TryInvokeContinuation(t.Item1, t.Item2, t.Item3);
}
}, StateTuple.Create(this, awaiter, i));
}
}
}
async UniTaskVoid RunTask(UniTask task, int index)
static void TryInvokeContinuation(WhenAnyPromise self, in UniTask.Awaiter awaiter, int i)
{
try
{
await task;
awaiter.GetResult();
}
catch (Exception ex)
{
exception = ExceptionDispatchInfo.Capture(ex);
TryCallContinuation();
self.core.TrySetException(ex);
return;
}
var count = Interlocked.Increment(ref completeCount);
if (count == 1)
if (Interlocked.Increment(ref self.completedCount) == 1)
{
Volatile.Write(ref winArgumentIndex, index);
TryCallContinuation();
self.core.TrySetResult(i);
}
}
void TryCallContinuation()
public int GetResult(short token)
{
var action = Interlocked.Exchange(ref whenComplete, null);
if (action != null)
{
action.Invoke();
}
TaskTracker2.RemoveTracking(this);
GC.SuppressFinalize(this);
return core.GetResult(token);
}
public Awaiter GetAwaiter()
public UniTaskStatus GetStatus(short token)
{
return new Awaiter(this);
return core.GetStatus(token);
}
public struct Awaiter : ICriticalNotifyCompletion
public void OnCompleted(Action<object> continuation, object state, short token)
{
WhenAnyPromise parent;
core.OnCompleted(continuation, state, token);
}
public Awaiter(WhenAnyPromise parent)
{
this.parent = parent;
}
public UniTaskStatus UnsafeGetStatus()
{
return core.UnsafeGetStatus();
}
public bool IsCompleted
{
get
{
return parent.IsComplete;
}
}
void IUniTaskSource.GetResult(short token)
{
GetResult(token);
}
public int GetResult()
{
if (parent.exception != null)
{
parent.exception.Throw();
}
return parent.winArgumentIndex;
}
public void OnCompleted(Action continuation)
{
UnsafeOnCompleted(continuation);
}
public void UnsafeOnCompleted(Action continuation)
{
parent.whenComplete = continuation;
if (IsCompleted)
{
var action = Interlocked.Exchange(ref parent.whenComplete, null);
if (action != null)
{
action();
}
}
}
~WhenAnyPromise()
{
core.Reset();
}
}
}