diff --git a/src/NHibernate.Test/NHibernate.Test.csproj b/src/NHibernate.Test/NHibernate.Test.csproj index 8b57857cf85..061199e776f 100644 --- a/src/NHibernate.Test/NHibernate.Test.csproj +++ b/src/NHibernate.Test/NHibernate.Test.csproj @@ -49,6 +49,9 @@ UtilityTest\AsyncReaderWriterLock.cs + + UtilityTest\SetSnapShot.cs + diff --git a/src/NHibernate.Test/UtilityTest/SetSnapShotFixture.cs b/src/NHibernate.Test/UtilityTest/SetSnapShotFixture.cs new file mode 100644 index 00000000000..c13478424b0 --- /dev/null +++ b/src/NHibernate.Test/UtilityTest/SetSnapShotFixture.cs @@ -0,0 +1,97 @@ +using System.Collections.Generic; +using System.IO; +using System.Runtime.Serialization.Formatters.Binary; +using NHibernate.Collection.Generic.SetHelpers; +using NUnit.Framework; + +namespace NHibernate.Test.UtilityTest +{ + [TestFixture] + public class SetSnapShotFixture + { + [Test] + public void TestNullValue() + { + var sn = new SetSnapShot(1); + Assert.That(sn, Has.Count.EqualTo(0)); + Assert.That(sn, Is.EquivalentTo(new object[0])); + Assert.That(sn.Contains(null), Is.False); + Assert.That(sn.TryGetValue(null, out _), Is.False); + + sn.Add(null); + Assert.That(sn, Has.Count.EqualTo(1)); + Assert.That(sn, Is.EquivalentTo(new object[] {null})); + + Assert.That(sn.TryGetValue(null, out var value), Is.True); + Assert.That(sn.Contains(null), Is.True); + Assert.That(value, Is.Null); + + Assert.That(sn.Remove(null), Is.True); + Assert.That(sn, Has.Count.EqualTo(0)); + Assert.That(sn, Is.EquivalentTo(new object[0])); + + sn.Add(null); + Assert.That(sn, Has.Count.EqualTo(1)); + + sn.Clear(); + Assert.That(sn, Has.Count.EqualTo(0)); + Assert.That(sn, Is.EquivalentTo(new object[0])); + } + + [Test] + public void TestInitialization() + { + var list = new List {"test1", null, "test2"}; + var sn = new SetSnapShot(list); + Assert.That(sn, Has.Count.EqualTo(list.Count)); + Assert.That(sn, Is.EquivalentTo(list)); + Assert.That(sn.TryGetValue("test1", out _), Is.True); + Assert.That(sn.TryGetValue(null, out _), Is.True); + } + + [Test] + public void TestCopyTo() + { + var list = new List {"test1", null, "test2"}; + var sn = new SetSnapShot(list); + + var array = new string[3]; + sn.CopyTo(array, 0); + Assert.That(list, Is.EquivalentTo(array)); + } + + [Test] + public void TestSerialization() + { + var list = new List {"test1", null, "test2"}; + var sn = new SetSnapShot(list); + + sn = Deserialize>(Serialize(sn)); + Assert.That(sn, Has.Count.EqualTo(list.Count)); + Assert.That(sn, Is.EquivalentTo(list)); + Assert.That(sn.TryGetValue("test1", out var item1), Is.True); + Assert.That(item1, Is.EqualTo("test1")); + Assert.That(sn.TryGetValue(null, out var nullValue), Is.True); + Assert.That(nullValue, Is.Null); + } + + private static byte[] Serialize(T obj) + { + var serializer = new BinaryFormatter(); + using (var stream = new MemoryStream()) + { + serializer.Serialize(stream, obj); + return stream.ToArray(); + } + } + + private static T Deserialize(byte[] value) + { + var serializer = new BinaryFormatter(); + using (var stream = new MemoryStream(value)) + { + return (T) serializer.Deserialize(stream); + } + } + } +} diff --git a/src/NHibernate/Async/Collection/Generic/PersistentGenericSet.cs b/src/NHibernate/Async/Collection/Generic/PersistentGenericSet.cs index e608a2ffae2..65d992d6c28 100644 --- a/src/NHibernate/Async/Collection/Generic/PersistentGenericSet.cs +++ b/src/NHibernate/Async/Collection/Generic/PersistentGenericSet.cs @@ -54,7 +54,7 @@ public override async Task EqualsSnapshotAsync(ICollectionPersister persis { cancellationToken.ThrowIfCancellationRequested(); var elementType = persister.ElementType; - var snapshot = (ISetSnapshot)GetSnapshot(); + var snapshot = (SetSnapShot)GetSnapshot(); if (((ICollection)snapshot).Count != WrappedSet.Count) { return false; @@ -122,7 +122,7 @@ public override async Task GetDeletesAsync(ICollectionPersister per { cancellationToken.ThrowIfCancellationRequested(); IType elementType = persister.ElementType; - var sn = (ISetSnapshot)GetSnapshot(); + var sn = (SetSnapShot)GetSnapshot(); var deletes = new List(((ICollection)sn).Count); deletes.AddRange(sn.Where(obj => !WrappedSet.Contains(obj))); @@ -140,7 +140,7 @@ public override async Task GetDeletesAsync(ICollectionPersister per public override async Task NeedsInsertingAsync(object entry, int i, IType elemType, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); - var sn = (ISetSnapshot)GetSnapshot(); + var sn = (SetSnapShot)GetSnapshot(); T oldKey; // note that it might be better to iterate the snapshot but this is safe, diff --git a/src/NHibernate/Collection/Generic/PersistentGenericSet.cs b/src/NHibernate/Collection/Generic/PersistentGenericSet.cs index 03a273cd32e..bec3cc7505f 100644 --- a/src/NHibernate/Collection/Generic/PersistentGenericSet.cs +++ b/src/NHibernate/Collection/Generic/PersistentGenericSet.cs @@ -103,7 +103,7 @@ public override ICollection GetOrphans(object snapshot, string entityName) public override bool EqualsSnapshot(ICollectionPersister persister) { var elementType = persister.ElementType; - var snapshot = (ISetSnapshot)GetSnapshot(); + var snapshot = (SetSnapShot)GetSnapshot(); if (((ICollection)snapshot).Count != WrappedSet.Count) { return false; @@ -217,7 +217,7 @@ public override object Disassemble(ICollectionPersister persister) public override IEnumerable GetDeletes(ICollectionPersister persister, bool indexIsFormula) { IType elementType = persister.ElementType; - var sn = (ISetSnapshot)GetSnapshot(); + var sn = (SetSnapShot)GetSnapshot(); var deletes = new List(((ICollection)sn).Count); deletes.AddRange(sn.Where(obj => !WrappedSet.Contains(obj))); @@ -234,7 +234,7 @@ public override IEnumerable GetDeletes(ICollectionPersister persister, bool inde public override bool NeedsInserting(object entry, int i, IType elemType) { - var sn = (ISetSnapshot)GetSnapshot(); + var sn = (SetSnapShot)GetSnapshot(); T oldKey; // note that it might be better to iterate the snapshot but this is safe, diff --git a/src/NHibernate/Collection/Generic/SetHelpers/ISetSnapshot.cs b/src/NHibernate/Collection/Generic/SetHelpers/ISetSnapshot.cs deleted file mode 100644 index 20f9983911f..00000000000 --- a/src/NHibernate/Collection/Generic/SetHelpers/ISetSnapshot.cs +++ /dev/null @@ -1,10 +0,0 @@ -using System.Collections; -using System.Collections.Generic; - -namespace NHibernate.Collection.Generic.SetHelpers -{ - internal interface ISetSnapshot : ICollection, IReadOnlyCollection, ICollection - { - bool TryGetValue(T element, out T value); - } -} diff --git a/src/NHibernate/Collection/Generic/SetHelpers/SetSnapShot.cs b/src/NHibernate/Collection/Generic/SetHelpers/SetSnapShot.cs index 143d7f2564d..eb29c6b6fd5 100644 --- a/src/NHibernate/Collection/Generic/SetHelpers/SetSnapShot.cs +++ b/src/NHibernate/Collection/Generic/SetHelpers/SetSnapShot.cs @@ -1,109 +1,180 @@ using System; using System.Collections; using System.Collections.Generic; +#if NETCOREAPP2_0 +using System.Runtime.Serialization; +using System.Threading; +#endif namespace NHibernate.Collection.Generic.SetHelpers { +#if NETFX || NETSTANDARD2_0 + // TODO 6.0: Consider removing this class in case we upgrade to .NET 4.7.2 and NET Standard 2.1, + // which have HashSet.TryGetValue [Serializable] - internal class SetSnapShot : ISetSnapshot + internal class SetSnapShot : ICollection, IReadOnlyCollection, ICollection { - private readonly List _elements; - public SetSnapShot() - { - _elements = new List(); - } + private readonly Dictionary _values; + private bool _hasNull; public SetSnapShot(int capacity) { - _elements = new List(capacity); + _values = new Dictionary(capacity); } public SetSnapShot(IEnumerable collection) { - _elements = new List(collection); + _values = new Dictionary(); + foreach (var item in collection) + { + if (item == null) + { + _hasNull = true; + } + else + { + _values.Add(item, item); + } + } } - public IEnumerator GetEnumerator() + public bool TryGetValue(T equalValue, out T actualValue) { - return _elements.GetEnumerator(); + if (equalValue != null) + { + return _values.TryGetValue(equalValue, out actualValue); + } + + actualValue = default(T); + return _hasNull; } - IEnumerator IEnumerable.GetEnumerator() + public IEnumerator GetEnumerator() { - return GetEnumerator(); + if (_hasNull) + { + yield return default(T); + } + + foreach (var item in _values.Keys) + { + yield return item; + } } public void Add(T item) { - _elements.Add(item); + if (item == null) + { + _hasNull = true; + return; + } + + _values.Add(item, item); } public void Clear() { - throw new InvalidOperationException(); + _values.Clear(); + _hasNull = false; } public bool Contains(T item) { - return _elements.Contains(item); + return item == null ? _hasNull : _values.ContainsKey(item); } public void CopyTo(T[] array, int arrayIndex) { - _elements.CopyTo(array, arrayIndex); + if (_hasNull) + array[arrayIndex] = default(T); + _values.Keys.CopyTo(array, arrayIndex + (_hasNull ? 1 : 0)); } public bool Remove(T item) { - throw new InvalidOperationException(); - } + if (item != null) + { + return _values.Remove(item); + } - public void CopyTo(Array array, int index) - { - ((ICollection)_elements).CopyTo(array, index); + if (!_hasNull) + { + return false; + } + + _hasNull = false; + return true; } - int ICollection.Count + IEnumerator IEnumerable.GetEnumerator() { - get { return _elements.Count; } + return GetEnumerator(); } - public object SyncRoot + void ICollection.CopyTo(Array array, int index) { - get { return ((ICollection)_elements).SyncRoot; } + if (!(array is T[] typedArray)) + { + throw new ArgumentException($"Array must be of type {typeof(T[])}.", nameof(array)); + } + + CopyTo(typedArray, index); } - public bool IsSynchronized + public int Count => _values.Count + (_hasNull ? 1 : 0); + + public bool IsReadOnly => ((ICollection>) _values).IsReadOnly; + + public object SyncRoot => ((ICollection) _values).SyncRoot; + + public bool IsSynchronized => ((ICollection) _values).IsSynchronized; + } +#endif + +#if NETCOREAPP2_0 + [Serializable] + internal class SetSnapShot : HashSet, ICollection + { + [NonSerialized] + private object _syncRoot; + + public SetSnapShot(int capacity) : base(capacity) { - get { return ((ICollection)_elements).IsSynchronized; } } - int ICollection.Count + public SetSnapShot(IEnumerable collection) : base(collection) { - get { return _elements.Count; } } - int IReadOnlyCollection.Count + protected SetSnapShot(SerializationInfo info, StreamingContext context) : base(info, context) { - get { return _elements.Count; } } - public bool IsReadOnly + void ICollection.CopyTo(Array array, int index) { - get { return ((ICollection)_elements).IsReadOnly; } + if (!(array is T[] typedArray)) + { + throw new ArgumentException($"Array must be of type {typeof(T[])}.", nameof(array)); + } + + CopyTo(typedArray, index); } - public bool TryGetValue(T element, out T value) + bool ICollection.IsSynchronized => false; + + object ICollection.SyncRoot { - var idx = _elements.IndexOf(element); - if (idx >= 0) + get { - value = _elements[idx]; - return true; - } + if (_syncRoot == null) + { + Interlocked.CompareExchange(ref _syncRoot, new object(), null); + } - value = default(T); - return false; + return _syncRoot; + } } } +#endif }