Distinct, Except, Intersect, Union

This commit is contained in:
neuecc
2020-05-11 15:53:27 +09:00
parent 8ef7a66081
commit b20b37e7a5
11 changed files with 485 additions and 3912 deletions

View File

@@ -100,10 +100,18 @@ namespace Cysharp.Threading.Tasks.Linq
}
completionSource.Reset();
SourceMoveNext();
if (!OnFirstIteration())
{
SourceMoveNext();
}
return new UniTask<bool>(this, completionSource.Version);
}
protected virtual bool OnFirstIteration()
{
return false;
}
protected void SourceMoveNext()
{
CONTINUE:

View File

@@ -1,775 +1,76 @@
namespace Cysharp.Threading.Tasks.Linq
using Cysharp.Threading.Tasks.Internal;
using System;
using System.Collections.Generic;
using System.Threading;
namespace Cysharp.Threading.Tasks.Linq
{
internal sealed class Distinct
public static partial class UniTaskAsyncEnumerable
{
public static IUniTaskAsyncEnumerable<TSource> Distinct<TSource>(this IUniTaskAsyncEnumerable<TSource> source)
{
Error.ThrowArgumentNullException(source, nameof(source));
return Distinct(source, EqualityComparer<TSource>.Default);
}
public static IUniTaskAsyncEnumerable<TSource> Distinct<TSource>(this IUniTaskAsyncEnumerable<TSource> source, IEqualityComparer<TSource> comparer)
{
Error.ThrowArgumentNullException(source, nameof(source));
Error.ThrowArgumentNullException(comparer, nameof(comparer));
return new Distinct<TSource>(source, comparer);
}
}
}
internal sealed class Distinct<TSource> : IUniTaskAsyncEnumerable<TSource>
{
readonly IUniTaskAsyncEnumerable<TSource> source;
readonly IEqualityComparer<TSource> comparer;
public Distinct(IUniTaskAsyncEnumerable<TSource> source, IEqualityComparer<TSource> comparer)
{
this.source = source;
this.comparer = comparer;
}
public IUniTaskAsyncEnumerator<TSource> GetAsyncEnumerator(CancellationToken cancellationToken = default)
{
return new Enumerator(source, comparer, cancellationToken);
}
class Enumerator : AsyncEnumeratorBase<TSource, TSource>
{
readonly HashSet<TSource> set;
public Enumerator(IUniTaskAsyncEnumerable<TSource> source, IEqualityComparer<TSource> comparer, CancellationToken cancellationToken)
: base(source, cancellationToken)
{
this.set = new HashSet<TSource>(comparer);
}
protected override bool TryMoveNextCore(bool sourceHasCurrent, out bool result)
{
if (sourceHasCurrent)
{
var v = SourceCurrent;
if (set.Add(v))
{
Current = v;
result = true;
return true;
}
else
{
result = default;
return false;
}
}
result = false;
return true;
}
}
}
}

View File

@@ -1,775 +1,116 @@
namespace Cysharp.Threading.Tasks.Linq
using Cysharp.Threading.Tasks.Internal;
using System;
using System.Collections.Generic;
using System.Threading;
namespace Cysharp.Threading.Tasks.Linq
{
internal sealed class Except
public static partial class UniTaskAsyncEnumerable
{
public static IUniTaskAsyncEnumerable<TSource> Except<TSource>(this IUniTaskAsyncEnumerable<TSource> first, IUniTaskAsyncEnumerable<TSource> second)
{
Error.ThrowArgumentNullException(first, nameof(first));
Error.ThrowArgumentNullException(second, nameof(second));
return new Except<TSource>(first, second, EqualityComparer<TSource>.Default);
}
public static IUniTaskAsyncEnumerable<TSource> Except<TSource>(this IUniTaskAsyncEnumerable<TSource> first, IUniTaskAsyncEnumerable<TSource> second, IEqualityComparer<TSource> comparer)
{
Error.ThrowArgumentNullException(first, nameof(first));
Error.ThrowArgumentNullException(second, nameof(second));
Error.ThrowArgumentNullException(comparer, nameof(comparer));
return new Except<TSource>(first, second, comparer);
}
}
}
internal sealed class Except<TSource> : IUniTaskAsyncEnumerable<TSource>
{
readonly IUniTaskAsyncEnumerable<TSource> first;
readonly IUniTaskAsyncEnumerable<TSource> second;
readonly IEqualityComparer<TSource> comparer;
public Except(IUniTaskAsyncEnumerable<TSource> first, IUniTaskAsyncEnumerable<TSource> second, IEqualityComparer<TSource> comparer)
{
this.first = first;
this.second = second;
this.comparer = comparer;
}
public IUniTaskAsyncEnumerator<TSource> GetAsyncEnumerator(CancellationToken cancellationToken = default)
{
return new Enumerator(first, second, comparer, cancellationToken);
}
class Enumerator : AsyncEnumeratorBase<TSource, TSource>
{
static Action<object> HashSetAsyncCoreDelegate = HashSetAsyncCore;
readonly IEqualityComparer<TSource> comparer;
readonly IUniTaskAsyncEnumerable<TSource> second;
HashSet<TSource> set;
UniTask<HashSet<TSource>>.Awaiter awaiter;
public Enumerator(IUniTaskAsyncEnumerable<TSource> first, IUniTaskAsyncEnumerable<TSource> second, IEqualityComparer<TSource> comparer, CancellationToken cancellationToken)
: base(first, cancellationToken)
{
this.second = second;
this.comparer = comparer;
}
protected override bool OnFirstIteration()
{
if (set != null) return false;
awaiter = second.ToHashSetAsync(cancellationToken).GetAwaiter();
if (awaiter.IsCompleted)
{
set = awaiter.GetResult();
SourceMoveNext();
}
else
{
awaiter.SourceOnCompleted(HashSetAsyncCoreDelegate, this);
}
return true;
}
static void HashSetAsyncCore(object state)
{
var self = (Enumerator)state;
if (self.TryGetResult(self.awaiter, out var result))
{
self.set = result;
self.SourceMoveNext();
}
}
protected override bool TryMoveNextCore(bool sourceHasCurrent, out bool result)
{
if (sourceHasCurrent)
{
var v = SourceCurrent;
if (set.Add(v))
{
Current = v;
result = true;
return true;
}
else
{
result = default;
return false;
}
}
result = false;
return true;
}
}
}
}

View File

@@ -1,775 +1,117 @@
namespace Cysharp.Threading.Tasks.Linq
using Cysharp.Threading.Tasks.Internal;
using System;
using System.Collections.Generic;
using System.Threading;
namespace Cysharp.Threading.Tasks.Linq
{
internal sealed class Intersect
public static partial class UniTaskAsyncEnumerable
{
public static IUniTaskAsyncEnumerable<TSource> Intersect<TSource>(this IUniTaskAsyncEnumerable<TSource> first, IUniTaskAsyncEnumerable<TSource> second)
{
Error.ThrowArgumentNullException(first, nameof(first));
Error.ThrowArgumentNullException(second, nameof(second));
return new Intersect<TSource>(first, second, EqualityComparer<TSource>.Default);
}
public static IUniTaskAsyncEnumerable<TSource> Intersect<TSource>(this IUniTaskAsyncEnumerable<TSource> first, IUniTaskAsyncEnumerable<TSource> second, IEqualityComparer<TSource> comparer)
{
Error.ThrowArgumentNullException(first, nameof(first));
Error.ThrowArgumentNullException(second, nameof(second));
Error.ThrowArgumentNullException(comparer, nameof(comparer));
return new Intersect<TSource>(first, second, comparer);
}
}
}
internal sealed class Intersect<TSource> : IUniTaskAsyncEnumerable<TSource>
{
readonly IUniTaskAsyncEnumerable<TSource> first;
readonly IUniTaskAsyncEnumerable<TSource> second;
readonly IEqualityComparer<TSource> comparer;
public Intersect(IUniTaskAsyncEnumerable<TSource> first, IUniTaskAsyncEnumerable<TSource> second, IEqualityComparer<TSource> comparer)
{
this.first = first;
this.second = second;
this.comparer = comparer;
}
public IUniTaskAsyncEnumerator<TSource> GetAsyncEnumerator(CancellationToken cancellationToken = default)
{
return new Enumerator(first, second, comparer, cancellationToken);
}
class Enumerator : AsyncEnumeratorBase<TSource, TSource>
{
static Action<object> HashSetAsyncCoreDelegate = HashSetAsyncCore;
readonly IEqualityComparer<TSource> comparer;
readonly IUniTaskAsyncEnumerable<TSource> second;
HashSet<TSource> set;
UniTask<HashSet<TSource>>.Awaiter awaiter;
public Enumerator(IUniTaskAsyncEnumerable<TSource> first, IUniTaskAsyncEnumerable<TSource> second, IEqualityComparer<TSource> comparer, CancellationToken cancellationToken)
: base(first, cancellationToken)
{
this.second = second;
this.comparer = comparer;
}
protected override bool OnFirstIteration()
{
if (set != null) return false;
awaiter = second.ToHashSetAsync(cancellationToken).GetAwaiter();
if (awaiter.IsCompleted)
{
set = awaiter.GetResult();
SourceMoveNext();
}
else
{
awaiter.SourceOnCompleted(HashSetAsyncCoreDelegate, this);
}
return true;
}
static void HashSetAsyncCore(object state)
{
var self = (Enumerator)state;
if (self.TryGetResult(self.awaiter, out var result))
{
self.set = result;
self.SourceMoveNext();
}
}
protected override bool TryMoveNextCore(bool sourceHasCurrent, out bool result)
{
if (sourceHasCurrent)
{
var v = SourceCurrent;
if (set.Remove(v))
{
Current = v;
result = true;
return true;
}
else
{
result = default;
return false;
}
}
result = false;
return true;
}
}
}
}

View File

@@ -1,775 +1 @@
namespace Cysharp.Threading.Tasks.Linq
{
internal sealed class Join
{
}
}


View File

@@ -10,15 +10,23 @@ namespace Cysharp.Threading.Tasks.Linq
{
Error.ThrowArgumentNullException(source, nameof(source));
return Cysharp.Threading.Tasks.Linq.ToHashSet.InvokeAsync(source, cancellationToken);
return Cysharp.Threading.Tasks.Linq.ToHashSet.InvokeAsync(source, EqualityComparer<TSource>.Default, cancellationToken);
}
public static UniTask<HashSet<TSource>> ToHashSetAsync<TSource>(this IUniTaskAsyncEnumerable<TSource> source, IEqualityComparer<TSource> comparer, CancellationToken cancellationToken = default)
{
Error.ThrowArgumentNullException(source, nameof(source));
Error.ThrowArgumentNullException(comparer, nameof(comparer));
return Cysharp.Threading.Tasks.Linq.ToHashSet.InvokeAsync(source, comparer, cancellationToken);
}
}
internal static class ToHashSet
{
internal static async UniTask<HashSet<TSource>> InvokeAsync<TSource>(IUniTaskAsyncEnumerable<TSource> source, CancellationToken cancellationToken)
internal static async UniTask<HashSet<TSource>> InvokeAsync<TSource>(IUniTaskAsyncEnumerable<TSource> source, IEqualityComparer<TSource> comparer, CancellationToken cancellationToken)
{
var set = new HashSet<TSource>();
var set = new HashSet<TSource>(comparer);
var e = source.GetAsyncEnumerator(cancellationToken);
try

View File

@@ -1,775 +1,26 @@
namespace Cysharp.Threading.Tasks.Linq
using Cysharp.Threading.Tasks.Internal;
using System.Collections.Generic;
namespace Cysharp.Threading.Tasks.Linq
{
internal sealed class Union
public static partial class UniTaskAsyncEnumerable
{
public static IUniTaskAsyncEnumerable<TSource> Union<TSource>(this IUniTaskAsyncEnumerable<TSource> first, IUniTaskAsyncEnumerable<TSource> second)
{
Error.ThrowArgumentNullException(first, nameof(first));
Error.ThrowArgumentNullException(second, nameof(second));
return Union<TSource>(first, second, EqualityComparer<TSource>.Default);
}
public static IUniTaskAsyncEnumerable<TSource> Union<TSource>(this IUniTaskAsyncEnumerable<TSource> first, IUniTaskAsyncEnumerable<TSource> second, IEqualityComparer<TSource> comparer)
{
Error.ThrowArgumentNullException(first, nameof(first));
Error.ThrowArgumentNullException(second, nameof(second));
Error.ThrowArgumentNullException(comparer, nameof(comparer));
// improv without combinate?
return first.Concat(second).Distinct(comparer);
}
}
}
}

View File

@@ -26,26 +26,6 @@ namespace ___Dummy
public static IUniTaskAsyncEnumerable<TSource> Distinct<TSource>(this IUniTaskAsyncEnumerable<TSource> source)
{
throw new NotImplementedException();
}
public static IUniTaskAsyncEnumerable<TSource> Distinct<TSource>(this IUniTaskAsyncEnumerable<TSource> source, IEqualityComparer<TSource> comparer)
{
throw new NotImplementedException();
}
public static IUniTaskAsyncEnumerable<TSource> Except<TSource>(this IUniTaskAsyncEnumerable<TSource> first, IUniTaskAsyncEnumerable<TSource> second)
{
throw new NotImplementedException();
}
public static IUniTaskAsyncEnumerable<TSource> Except<TSource>(this IUniTaskAsyncEnumerable<TSource> first, IUniTaskAsyncEnumerable<TSource> second, IEqualityComparer<TSource> comparer)
{
throw new NotImplementedException();
}
@@ -200,16 +180,6 @@ namespace ___Dummy
throw new NotImplementedException();
}
public static IUniTaskAsyncEnumerable<TSource> Intersect<TSource>(this IUniTaskAsyncEnumerable<TSource> first, IUniTaskAsyncEnumerable<TSource> second)
{
throw new NotImplementedException();
}
public static IUniTaskAsyncEnumerable<TSource> Intersect<TSource>(this IUniTaskAsyncEnumerable<TSource> first, IUniTaskAsyncEnumerable<TSource> second, IEqualityComparer<TSource> comparer)
{
throw new NotImplementedException();
}
public static IUniTaskAsyncEnumerable<TResult> Join<TOuter, TInner, TKey, TResult>(this IUniTaskAsyncEnumerable<TOuter> outer, IUniTaskAsyncEnumerable<TInner> inner, Func<TOuter, TKey> outerKeySelector, Func<TInner, TKey> innerKeySelector, Func<TOuter, TInner, TResult> resultSelector)
{
throw new NotImplementedException();
@@ -309,10 +279,6 @@ namespace ___Dummy
public static IUniTaskAsyncEnumerable<TSource> TakeLast<TSource>(this IUniTaskAsyncEnumerable<TSource> source, Int32 count)
{
throw new NotImplementedException();
}
public static IOrderedAsyncEnumerable<TSource> ThenBy<TSource, TKey>(this IOrderedAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector)
{
@@ -380,19 +346,6 @@ namespace ___Dummy
public static IUniTaskAsyncEnumerable<TSource> Union<TSource>(this IUniTaskAsyncEnumerable<TSource> first, IUniTaskAsyncEnumerable<TSource> second)
{
throw new NotImplementedException();
}
public static IUniTaskAsyncEnumerable<TSource> Union<TSource>(this IUniTaskAsyncEnumerable<TSource> first, IUniTaskAsyncEnumerable<TSource> second, IEqualityComparer<TSource> comparer)
{
throw new NotImplementedException();
}
}