Concurrent HashSet Class Without Blocking
Updated: July 23, 2021
using System;
using System.Collections;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.Threading;
/// <summary>
/// A thread-safe hash set. Items are stored as the keys of a
/// <see cref="ConcurrentDictionary{TKey,TValue}"/>, so lookups are lock-free and
/// adds are atomic: an item can never be stored twice, even when several threads
/// add it at the same moment.
/// </summary>
/// <remarks>
/// Writes use the dictionary's fine-grained (striped) locking; <see cref="Count"/> and
/// <see cref="ToArray"/> take a consistent snapshot of the contents.
/// Null items are supported (as with <see cref="HashSet{T}"/>) via a separate flag,
/// because <see cref="ConcurrentDictionary{TKey,TValue}"/> rejects null keys.
/// </remarks>
/// <typeparam name="T">The element type.</typeparam>
[DebuggerDisplay("Count = {" + nameof(Count) + "}")]
public class CcHashSet<T> : IEnumerable<T>
{
    // Set members are the dictionary keys; the byte value is unused.
    private readonly ConcurrentDictionary<T, byte> _items;

    // ManagedThreadIds of every thread that has called Add; only used to
    // maintain NumberOfActiveThreads.
    private readonly ConcurrentDictionary<int, byte> _contributors = new ConcurrentDictionary<int, byte>();

    // Becomes 1 once a null item has been added (0 otherwise).
    private int _hasNull;

    /// <summary>Number of distinct threads (by ManagedThreadId) that have called <see cref="Add"/>.</summary>
    public volatile int NumberOfActiveThreads;

    /// <summary>Creates an empty set with an initial capacity of 1024 and the default comparer.</summary>
    public CcHashSet() : this(1024, null)
    {
    }

    /// <summary>Creates an empty set with the given initial capacity and the default comparer.</summary>
    /// <param name="size">Initial capacity; must be non-negative.</param>
    public CcHashSet(int size) : this(size, null)
    {
    }

    /// <summary>Creates an empty set with the given initial capacity and equality comparer.</summary>
    /// <param name="size">Initial capacity; must be non-negative.</param>
    /// <param name="comparer">Item comparer, or <c>null</c> for <see cref="EqualityComparer{T}.Default"/>.</param>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="size"/> is negative.</exception>
    public CcHashSet(int size, IEqualityComparer<T> comparer)
    {
        // Fail fast instead of on the first Add.
        if (size < 0)
            throw new ArgumentOutOfRangeException(nameof(size), size, "Capacity must be non-negative.");

        // Explicit default comparer: some ConcurrentDictionary versions reject a null comparer.
        _items = new ConcurrentDictionary<T, byte>(
            Environment.ProcessorCount,
            size,
            comparer ?? EqualityComparer<T>.Default);
        NumberOfActiveThreads = 0;
    }

    /// <summary>Number of distinct items currently in the set (including null, if added).</summary>
    public int Count => _items.Count + Volatile.Read(ref _hasNull);

    IEnumerator IEnumerable.GetEnumerator() => GetEnum();

    /// <summary>Enumerates a snapshot of the set taken when enumeration starts.</summary>
    public IEnumerator<T> GetEnumerator() => GetEnum();

    /// <summary>Adds <paramref name="item"/> if it is not already present. Safe to call from any thread.</summary>
    /// <param name="item">The item to add; may be null.</param>
    public void Add(T item)
    {
        RegisterCurrentThread();

        if (item == null)
        {
            Interlocked.Exchange(ref _hasNull, 1);
            return;
        }

        // TryAdd is atomic, so concurrent adds of an equal item store it exactly once.
        _items.TryAdd(item, 0);
    }

    /// <summary>Returns whether <paramref name="item"/> is in the set. Safe to call from any thread.</summary>
    /// <param name="item">The item to look for; may be null.</param>
    public bool Contains(T item)
    {
        if (item == null)
            return Volatile.Read(ref _hasNull) != 0;
        return _items.ContainsKey(item);
    }

    /// <summary>Lazily enumerates a snapshot produced by <see cref="ToArray"/> on the first MoveNext.</summary>
    public IEnumerator<T> GetEnum()
    {
        foreach (var item in ToArray())
            yield return item;
    }

    /// <summary>Copies a point-in-time snapshot of the set into a new array (order unspecified).</summary>
    public T[] ToArray()
    {
        // Keys is a snapshot collection, so its size cannot change during CopyTo
        // even if other threads keep adding.
        ICollection<T> keys = _items.Keys;
        bool hasNull = Volatile.Read(ref _hasNull) != 0;

        var result = new T[keys.Count + (hasNull ? 1 : 0)];
        keys.CopyTo(result, 0);
        // When hasNull is true the trailing slot is left as default(T), i.e. the null item.
        return result;
    }

    // Counts the calling thread in NumberOfActiveThreads the first time it calls Add.
    private void RegisterCurrentThread()
    {
        int id = Thread.CurrentThread.ManagedThreadId;
        // ContainsKey is a lock-free read; TryAdd decides the race if two calls see "absent".
        if (!_contributors.ContainsKey(id) && _contributors.TryAdd(id, 0))
        {
#pragma warning disable CS0420 // Interlocked supplies its own memory barrier; volatile is kept for plain readers.
            Interlocked.Increment(ref NumberOfActiveThreads);
#pragma warning restore CS0420
        }
    }
}