using System;
using System.Threading;

namespace Cysharp.Threading.Tasks.Linq
{
    /// <summary>
    /// Shared base for the LINQ operator enumerators. Implements <c>IUniTaskSource&lt;bool&gt;</c>
    /// by forwarding every member to <see cref="completionSource"/>, so a derived enumerator can
    /// use itself as the source of the <c>UniTask&lt;bool&gt;</c> it returns from MoveNextAsync.
    /// </summary>
    public abstract class MoveNextSource : IUniTaskSource<bool>
    {
        /// <summary>Completion state backing the pending MoveNextAsync call.</summary>
        protected UniTaskCompletionSourceCore<bool> completionSource;

        public bool GetResult(short token)
        {
            return completionSource.GetResult(token);
        }

        public UniTaskStatus GetStatus(short token)
        {
            return completionSource.GetStatus(token);
        }

        public void OnCompleted(Action<object> continuation, object state, short token)
        {
            completionSource.OnCompleted(continuation, state, token);
        }

        public UniTaskStatus UnsafeGetStatus()
        {
            return completionSource.UnsafeGetStatus();
        }

        void IUniTaskSource.GetResult(short token)
        {
            completionSource.GetResult(token);
        }

        /// <summary>
        /// Reads the result of <paramref name="awaiter"/>. If reading it throws, the exception is
        /// handed to <see cref="completionSource"/> via TrySetException instead of propagating.
        /// </summary>
        /// <returns>
        /// <c>true</c> with the awaited value in <paramref name="result"/>; <c>false</c> with
        /// <c>default</c> in <paramref name="result"/> when GetResult threw.
        /// </returns>
        protected bool TryGetResult<T>(UniTask<T>.Awaiter awaiter, out T result)
        {
            try
            {
                result = awaiter.GetResult();
                return true;
            }
            catch (Exception ex)
            {
                completionSource.TrySetException(ex);
                result = default;
                return false;
            }
        }

        /// <summary>
        /// Observes the result of a non-generic <paramref name="awaiter"/>. If that throws, the
        /// exception is handed to <see cref="completionSource"/> instead of propagating.
        /// </summary>
        /// <returns><c>true</c> if GetResult returned normally; otherwise <c>false</c>.</returns>
        protected bool TryGetResult(UniTask.Awaiter awaiter)
        {
            try
            {
                awaiter.GetResult();
                return true;
            }
            catch (Exception ex)
            {
                completionSource.TrySetException(ex);
                return false;
            }
        }
    }

    /// <summary>
    /// Base for operators that map each element of a source sequence synchronously.
    /// A derived class implements <see cref="TryMoveNextCore"/>; this class drives the source
    /// enumerator and completes each MoveNextAsync call.
    /// </summary>
    public abstract class AsyncEnumeratorBase<TSource, TResult> : MoveNextSource, IUniTaskAsyncEnumerator<TResult>
    {
        // Cached so that registering the source continuation does not allocate a new delegate.
        static readonly Action<object> moveNextCallbackDelegate = MoveNextCallBack;

        readonly IUniTaskAsyncEnumerable<TSource> source;
        protected CancellationToken cancellationToken;

        // Created lazily by the first MoveNextAsync call.
        IUniTaskAsyncEnumerator<TSource> enumerator;
        UniTask<bool>.Awaiter sourceMoveNext;

        public AsyncEnumeratorBase(IUniTaskAsyncEnumerable<TSource> source, CancellationToken cancellationToken)
        {
            this.source = source;
            this.cancellationToken = cancellationToken;
        }

        // abstract

        /// <summary>
        /// Processes the outcome of one source MoveNextAsync.
        /// If the return value is <c>false</c>, source.MoveNextAsync is called again.
        /// </summary>
        /// <param name="sourceHasCurrent">The result of the source enumerator's MoveNextAsync.</param>
        /// <param name="result">
        /// When the method returns <c>true</c>, the value this enumerator's MoveNextAsync completes with.
        /// </param>
        protected abstract bool TryMoveNextCore(bool sourceHasCurrent, out bool result);

        // Util

        /// <summary>The source enumerator's current element.</summary>
        protected TSource SourceCurrent => enumerator.Current;

        // IUniTaskAsyncEnumerator

        public TResult Current { get; protected set; }

        /// <summary>
        /// Advances the enumerator. On the first call, the source enumerator is obtained with
        /// <see cref="cancellationToken"/>.
        /// </summary>
        public UniTask<bool> MoveNextAsync()
        {
            if (enumerator == null)
            {
                enumerator = source.GetAsyncEnumerator(cancellationToken);
            }

            completionSource.Reset();
            SourceMoveNext();
            return new UniTask<bool>(this, completionSource.Version);
        }

        /// <summary>
        /// Pulls from the source until <see cref="TryMoveNextCore"/> accepts an outcome.
        /// While the source completes synchronously this loops in place (goto) rather than
        /// recursing. Otherwise it registers <see cref="MoveNextCallBack"/> and returns.
        /// </summary>
        protected void SourceMoveNext()
        {
        CONTINUE:
            sourceMoveNext = enumerator.MoveNextAsync().GetAwaiter();
            if (sourceMoveNext.IsCompleted)
            {
                bool result = false;
                try
                {
                    if (!TryMoveNextCore(sourceMoveNext.GetResult(), out result))
                    {
                        goto CONTINUE;
                    }
                }
                catch (Exception ex)
                {
                    completionSource.TrySetException(ex);
                    return;
                }

                if (cancellationToken.IsCancellationRequested)
                {
                    completionSource.TrySetCanceled(cancellationToken);
                }
                else
                {
                    completionSource.TrySetResult(result);
                }
            }
            else
            {
                sourceMoveNext.SourceOnCompleted(moveNextCallbackDelegate, this);
            }
        }

        // Continuation for a source MoveNextAsync that completed asynchronously.
        static void MoveNextCallBack(object state)
        {
            var self = (AsyncEnumeratorBase<TSource, TResult>)state;
            bool result;
            try
            {
                if (!self.TryMoveNextCore(self.sourceMoveNext.GetResult(), out result))
                {
                    self.SourceMoveNext();
                    return;
                }
            }
            catch (Exception ex)
            {
                self.completionSource.TrySetException(ex);
                return;
            }

            if (self.cancellationToken.IsCancellationRequested)
            {
                self.completionSource.TrySetCanceled(self.cancellationToken);
            }
            else
            {
                self.completionSource.TrySetResult(result);
            }
        }

        /// <summary>
        /// Disposes the source enumerator if MoveNextAsync has created one; otherwise returns
        /// <c>default</c>. If a derived class needs to dispose additional resources, it should
        /// override this and call <c>base.DisposeAsync()</c>.
        /// </summary>
        public virtual UniTask DisposeAsync()
        {
            if (enumerator != null)
            {
                return enumerator.DisposeAsync();
            }

            return default;
        }
    }

    /// <summary>
    /// Base for operators that transform each source element through an awaitable selector
    /// (<see cref="TransformAsync"/>) and then decide, in <see cref="TrySetCurrentCore"/>, whether
    /// the awaited value produces a new Current.
    /// </summary>
    public abstract class AsyncEnumeratorAwaitSelectorBase<TSource, TResult, TAwait> : MoveNextSource, IUniTaskAsyncEnumerator<TResult>
    {
        // Cached delegates so that registering continuations does not allocate.
        static readonly Action<object> moveNextCallbackDelegate = MoveNextCallBack;
        static readonly Action<object> setCurrentCallbackDelegate = SetCurrentCallBack;

        readonly IUniTaskAsyncEnumerable<TSource> source;
        protected CancellationToken cancellationToken;

        // Created lazily by the first MoveNextAsync call.
        IUniTaskAsyncEnumerator<TSource> enumerator;
        UniTask<bool>.Awaiter sourceMoveNext;
        UniTask<TAwait>.Awaiter resultAwaiter;

        public AsyncEnumeratorAwaitSelectorBase(IUniTaskAsyncEnumerable<TSource> source, CancellationToken cancellationToken)
        {
            this.source = source;
            this.cancellationToken = cancellationToken;
        }

        // abstract

        /// <summary>Starts the asynchronous transformation of one source element.</summary>
        protected abstract UniTask<TAwait> TransformAsync(TSource sourceCurrent);

        /// <summary>
        /// Consumes the awaited transformation result. Return <c>true</c> when it produced a new
        /// Current (MoveNextAsync completes with <c>true</c>); return <c>false</c> to pull the next
        /// source element instead.
        /// </summary>
        protected abstract bool TrySetCurrentCore(TAwait awaitResult);

        // Util

        /// <summary>The source enumerator's current element.</summary>
        protected TSource SourceCurrent => enumerator.Current;

        /// <summary>
        /// Step outcome once the transformation has finished synchronously: completes with
        /// <c>true</c> if <paramref name="trySetCurrentResult"/> is set, otherwise requests another
        /// source iteration.
        /// </summary>
        protected (bool waitCallback, bool requireNextIteration) ActionCompleted(bool trySetCurrentResult, out bool moveNextResult)
        {
            if (trySetCurrentResult)
            {
                moveNextResult = true;
                return (false, false);
            }
            else
            {
                moveNextResult = default;
                return (false, true);
            }
        }

        /// <summary>
        /// Step outcome while the transformation is still pending: completion is left to the
        /// registered set-current callback.
        /// </summary>
        protected (bool waitCallback, bool requireNextIteration) WaitAwaitCallback(out bool moveNextResult)
        {
            moveNextResult = default;
            return (true, false);
        }

        /// <summary>Step outcome when the source has no more elements: completes with <c>false</c>.</summary>
        protected (bool waitCallback, bool requireNextIteration) IterateFinished(out bool moveNextResult)
        {
            moveNextResult = false;
            return (false, false);
        }

        // IUniTaskAsyncEnumerator

        public TResult Current { get; protected set; }

        /// <summary>
        /// Advances the enumerator. On the first call, the source enumerator is obtained with
        /// <see cref="cancellationToken"/>.
        /// </summary>
        public UniTask<bool> MoveNextAsync()
        {
            if (enumerator == null)
            {
                enumerator = source.GetAsyncEnumerator(cancellationToken);
            }

            completionSource.Reset();
            SourceMoveNext();
            return new UniTask<bool>(this, completionSource.Version);
        }

        /// <summary>
        /// Pulls from the source and transforms elements until a step completes or has to wait.
        /// While everything completes synchronously this loops in place (goto) rather than recursing.
        /// </summary>
        protected void SourceMoveNext()
        {
        CONTINUE:
            sourceMoveNext = enumerator.MoveNextAsync().GetAwaiter();
            if (sourceMoveNext.IsCompleted)
            {
                bool result = false;
                try
                {
                    (bool waitCallback, bool requireNextIteration) = TryMoveNextCore(sourceMoveNext.GetResult(), out result);

                    if (waitCallback)
                    {
                        return;
                    }

                    if (requireNextIteration)
                    {
                        goto CONTINUE;
                    }

                    // Honor cancellation here too, as AsyncEnumeratorBase and SetCurrentCallBack do.
                    TrySetResultOrCanceled(result);
                }
                catch (Exception ex)
                {
                    completionSource.TrySetException(ex);
                }
            }
            else
            {
                sourceMoveNext.SourceOnCompleted(moveNextCallbackDelegate, this);
            }
        }

        // Runs the selector for the current source element (if any) and reports how to proceed.
        (bool waitCallback, bool requireNextIteration) TryMoveNextCore(bool sourceHasCurrent, out bool result)
        {
            if (sourceHasCurrent)
            {
                var task = TransformAsync(enumerator.Current);
                if (UnwarapTask(task, out var taskResult))
                {
                    return ActionCompleted(TrySetCurrentCore(taskResult), out result);
                }
                else
                {
                    return WaitAwaitCallback(out result);
                }
            }

            return IterateFinished(out result);
        }

        /// <summary>
        /// If <paramref name="taskResult"/> has already completed, returns <c>true</c> with its value.
        /// Otherwise registers the set-current callback and returns <c>false</c>.
        /// (The misspelled name is kept because derived classes may call it.)
        /// </summary>
        protected bool UnwarapTask(UniTask<TAwait> taskResult, out TAwait result)
        {
            resultAwaiter = taskResult.GetAwaiter();

            if (resultAwaiter.IsCompleted)
            {
                result = resultAwaiter.GetResult();
                return true;
            }
            else
            {
                resultAwaiter.SourceOnCompleted(setCurrentCallbackDelegate, this);
                result = default;
                return false;
            }
        }

        // Completes the pending MoveNextAsync: canceled when cancellationToken is signaled,
        // otherwise with the given result.
        void TrySetResultOrCanceled(bool result)
        {
            if (cancellationToken.IsCancellationRequested)
            {
                completionSource.TrySetCanceled(cancellationToken);
            }
            else
            {
                completionSource.TrySetResult(result);
            }
        }

        // Continuation for a source MoveNextAsync that completed asynchronously.
        static void MoveNextCallBack(object state)
        {
            var self = (AsyncEnumeratorAwaitSelectorBase<TSource, TResult, TAwait>)state;
            bool result = false;
            try
            {
                (bool waitCallback, bool requireNextIteration) = self.TryMoveNextCore(self.sourceMoveNext.GetResult(), out result);

                if (waitCallback)
                {
                    return;
                }

                if (requireNextIteration)
                {
                    self.SourceMoveNext();
                    return;
                }

                self.TrySetResultOrCanceled(result);
            }
            catch (Exception ex)
            {
                self.completionSource.TrySetException(ex);
            }
        }

        // Continuation for a TransformAsync task that completed asynchronously.
        static void SetCurrentCallBack(object state)
        {
            var self = (AsyncEnumeratorAwaitSelectorBase<TSource, TResult, TAwait>)state;

            bool doneSetCurrent;
            try
            {
                var result = self.resultAwaiter.GetResult();
                doneSetCurrent = self.TrySetCurrentCore(result);
            }
            catch (Exception ex)
            {
                self.completionSource.TrySetException(ex);
                return;
            }

            if (self.cancellationToken.IsCancellationRequested)
            {
                self.completionSource.TrySetCanceled(self.cancellationToken);
            }
            else if (doneSetCurrent)
            {
                self.completionSource.TrySetResult(true);
            }
            else
            {
                self.SourceMoveNext();
            }
        }

        /// <summary>
        /// Disposes the source enumerator if MoveNextAsync has created one; otherwise returns
        /// <c>default</c>. If a derived class needs to dispose additional resources, it should
        /// override this and call <c>base.DisposeAsync()</c>.
        /// </summary>
        public virtual UniTask DisposeAsync()
        {
            if (enumerator != null)
            {
                return enumerator.DisposeAsync();
            }

            return default;
        }
    }
}