Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize PersistentGenericSet snapshot #2394

Merged
merged 7 commits into from
May 31, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/NHibernate.Test/NHibernate.Test.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@
<Compile Include="..\NHibernate\Util\AsyncReaderWriterLock.cs">
<Link>UtilityTest\AsyncReaderWriterLock.cs</Link>
</Compile>
<Compile Include="..\NHibernate\Collection\Generic\SetHelpers\SetSnapShot.cs">
<Link>UtilityTest\SetSnapShot.cs</Link>
</Compile>
</ItemGroup>
<ItemGroup>
<PackageReference Include="log4net" Version="2.0.8" />
Expand Down
97 changes: 97 additions & 0 deletions src/NHibernate.Test/UtilityTest/SetSnapShotFixture.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
using System.Collections.Generic;
using System.IO;
using System.Runtime.Serialization.Formatters.Binary;
using NHibernate.Collection.Generic.SetHelpers;
using NUnit.Framework;

namespace NHibernate.Test.UtilityTest
{
[TestFixture]
public class SetSnapShotFixture
{
[Test]
public void TestNullValue()
{
var sn = new SetSnapShot<object>(1);
Assert.That(sn, Has.Count.EqualTo(0));
Assert.That(sn, Is.EquivalentTo(new object[0]));
Assert.That(sn.Contains(null), Is.False);
Assert.That(sn.TryGetValue(null, out _), Is.False);

sn.Add(null);
Assert.That(sn, Has.Count.EqualTo(1));
Assert.That(sn, Is.EquivalentTo(new object[] {null}));

Assert.That(sn.TryGetValue(null, out var value), Is.True);
Assert.That(sn.Contains(null), Is.True);
Assert.That(value, Is.Null);

Assert.That(sn.Remove(null), Is.True);
Assert.That(sn, Has.Count.EqualTo(0));
Assert.That(sn, Is.EquivalentTo(new object[0]));

sn.Add(null);
Assert.That(sn, Has.Count.EqualTo(1));

sn.Clear();
Assert.That(sn, Has.Count.EqualTo(0));
Assert.That(sn, Is.EquivalentTo(new object[0]));
}

[Test]
public void TestInitialization()
{
var list = new List<string> {"test1", null, "test2"};
var sn = new SetSnapShot<string>(list);
Assert.That(sn, Has.Count.EqualTo(list.Count));
Assert.That(sn, Is.EquivalentTo(list));
Assert.That(sn.TryGetValue("test1", out _), Is.True);
Assert.That(sn.TryGetValue(null, out _), Is.True);
}

[Test]
public void TestCopyTo()
{
var list = new List<string> {"test1", null, "test2"};
var sn = new SetSnapShot<string>(list);

var array = new string[3];
sn.CopyTo(array, 0);
Assert.That(list, Is.EquivalentTo(array));
}

[Test]
public void TestSerialization()
{
var list = new List<string> {"test1", null, "test2"};
var sn = new SetSnapShot<string>(list);

sn = Deserialize<SetSnapShot<string>>(Serialize(sn));
Assert.That(sn, Has.Count.EqualTo(list.Count));
Assert.That(sn, Is.EquivalentTo(list));
Assert.That(sn.TryGetValue("test1", out var item1), Is.True);
Assert.That(item1, Is.EqualTo("test1"));
Assert.That(sn.TryGetValue(null, out var nullValue), Is.True);
Assert.That(nullValue, Is.Null);
}

private static byte[] Serialize<T>(T obj)
{
var serializer = new BinaryFormatter();
using (var stream = new MemoryStream())
{
serializer.Serialize(stream, obj);
return stream.ToArray();
}
}

private static T Deserialize<T>(byte[] value)
{
var serializer = new BinaryFormatter();
using (var stream = new MemoryStream(value))
{
return (T) serializer.Deserialize(stream);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public override async Task<bool> EqualsSnapshotAsync(ICollectionPersister persis
{
cancellationToken.ThrowIfCancellationRequested();
var elementType = persister.ElementType;
var snapshot = (ISetSnapshot<T>)GetSnapshot();
var snapshot = (SetSnapShot<T>)GetSnapshot();
if (((ICollection)snapshot).Count != WrappedSet.Count)
{
return false;
Expand Down Expand Up @@ -122,7 +122,7 @@ public override async Task<IEnumerable> GetDeletesAsync(ICollectionPersister per
{
cancellationToken.ThrowIfCancellationRequested();
IType elementType = persister.ElementType;
var sn = (ISetSnapshot<T>)GetSnapshot();
var sn = (SetSnapShot<T>)GetSnapshot();
var deletes = new List<T>(((ICollection<T>)sn).Count);

deletes.AddRange(sn.Where(obj => !WrappedSet.Contains(obj)));
Expand All @@ -140,7 +140,7 @@ public override async Task<IEnumerable> GetDeletesAsync(ICollectionPersister per
public override async Task<bool> NeedsInsertingAsync(object entry, int i, IType elemType, CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();
var sn = (ISetSnapshot<T>)GetSnapshot();
var sn = (SetSnapShot<T>)GetSnapshot();
T oldKey;

// note that it might be better to iterate the snapshot but this is safe,
Expand Down
6 changes: 3 additions & 3 deletions src/NHibernate/Collection/Generic/PersistentGenericSet.cs
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ public override ICollection GetOrphans(object snapshot, string entityName)
public override bool EqualsSnapshot(ICollectionPersister persister)
{
var elementType = persister.ElementType;
var snapshot = (ISetSnapshot<T>)GetSnapshot();
var snapshot = (SetSnapShot<T>)GetSnapshot();
if (((ICollection)snapshot).Count != WrappedSet.Count)
{
return false;
Expand Down Expand Up @@ -217,7 +217,7 @@ public override object Disassemble(ICollectionPersister persister)
public override IEnumerable GetDeletes(ICollectionPersister persister, bool indexIsFormula)
{
IType elementType = persister.ElementType;
var sn = (ISetSnapshot<T>)GetSnapshot();
var sn = (SetSnapShot<T>)GetSnapshot();
var deletes = new List<T>(((ICollection<T>)sn).Count);

deletes.AddRange(sn.Where(obj => !WrappedSet.Contains(obj)));
Expand All @@ -234,7 +234,7 @@ public override IEnumerable GetDeletes(ICollectionPersister persister, bool inde

public override bool NeedsInserting(object entry, int i, IType elemType)
{
var sn = (ISetSnapshot<T>)GetSnapshot();
var sn = (SetSnapShot<T>)GetSnapshot();
T oldKey;

// note that it might be better to iterate the snapshot but this is safe,
Expand Down
10 changes: 0 additions & 10 deletions src/NHibernate/Collection/Generic/SetHelpers/ISetSnapshot.cs

This file was deleted.

153 changes: 112 additions & 41 deletions src/NHibernate/Collection/Generic/SetHelpers/SetSnapShot.cs
Original file line number Diff line number Diff line change
@@ -1,109 +1,180 @@
using System;
using System.Collections;
using System.Collections.Generic;
#if NETCOREAPP2_0
using System.Runtime.Serialization;
using System.Threading;
#endif

namespace NHibernate.Collection.Generic.SetHelpers
{
#if NETFX || NETSTANDARD2_0
// TODO 6.0: Consider removing this class in case we upgrade to .NET 4.7.2 and NET Standard 2.1,
// which have HashSet<T>.TryGetValue
[Serializable]
internal class SetSnapShot<T> : ISetSnapshot<T>
internal class SetSnapShot<T> : ICollection<T>, IReadOnlyCollection<T>, ICollection
{
private readonly List<T> _elements;
public SetSnapShot()
{
_elements = new List<T>();
}
private readonly Dictionary<T, T> _values;
private bool _hasNull;

public SetSnapShot(int capacity)
{
_elements = new List<T>(capacity);
_values = new Dictionary<T, T>(capacity);
}

public SetSnapShot(IEnumerable<T> collection)
{
_elements = new List<T>(collection);
_values = new Dictionary<T, T>();
foreach (var item in collection)
{
if (item == null)
{
_hasNull = true;
}
else
{
_values.Add(item, item);
}
}
}

public IEnumerator<T> GetEnumerator()
public bool TryGetValue(T equalValue, out T actualValue)
{
return _elements.GetEnumerator();
if (equalValue != null)
{
return _values.TryGetValue(equalValue, out actualValue);
}

actualValue = default(T);
return _hasNull;
}

IEnumerator IEnumerable.GetEnumerator()
public IEnumerator<T> GetEnumerator()
{
return GetEnumerator();
if (_hasNull)
{
yield return default(T);
}

foreach (var item in _values.Keys)
{
yield return item;
}
}

public void Add(T item)
{
_elements.Add(item);
if (item == null)
{
_hasNull = true;
return;
}

_values.Add(item, item);
}

public void Clear()
{
throw new InvalidOperationException();
_values.Clear();
_hasNull = false;
}

public bool Contains(T item)
{
return _elements.Contains(item);
return item == null ? _hasNull : _values.ContainsKey(item);
}

public void CopyTo(T[] array, int arrayIndex)
{
_elements.CopyTo(array, arrayIndex);
if (_hasNull)
array[arrayIndex] = default(T);
_values.Keys.CopyTo(array, arrayIndex + (_hasNull ? 1 : 0));
}

public bool Remove(T item)
{
throw new InvalidOperationException();
}
if (item != null)
{
return _values.Remove(item);
}

public void CopyTo(Array array, int index)
{
((ICollection)_elements).CopyTo(array, index);
if (!_hasNull)
{
return false;
}

_hasNull = false;
return true;
}

int ICollection.Count
IEnumerator IEnumerable.GetEnumerator()
{
get { return _elements.Count; }
return GetEnumerator();
}

public object SyncRoot
void ICollection.CopyTo(Array array, int index)
{
get { return ((ICollection)_elements).SyncRoot; }
if (!(array is T[] typedArray))
{
throw new ArgumentException($"Array must be of type {typeof(T[])}.", nameof(array));
}

CopyTo(typedArray, index);
}

public bool IsSynchronized
public int Count => _values.Count + (_hasNull ? 1 : 0);

public bool IsReadOnly => ((ICollection<KeyValuePair<T, T>>) _values).IsReadOnly;

public object SyncRoot => ((ICollection) _values).SyncRoot;

public bool IsSynchronized => ((ICollection) _values).IsSynchronized;
}
#endif

#if NETCOREAPP2_0
[Serializable]
internal class SetSnapShot<T> : HashSet<T>, ICollection
{
[NonSerialized]
private object _syncRoot;

public SetSnapShot(int capacity) : base(capacity)
{
get { return ((ICollection)_elements).IsSynchronized; }
}

int ICollection<T>.Count
public SetSnapShot(IEnumerable<T> collection) : base(collection)
{
get { return _elements.Count; }
}

int IReadOnlyCollection<T>.Count
protected SetSnapShot(SerializationInfo info, StreamingContext context) : base(info, context)
{
get { return _elements.Count; }
}

public bool IsReadOnly
void ICollection.CopyTo(Array array, int index)
{
get { return ((ICollection<T>)_elements).IsReadOnly; }
if (!(array is T[] typedArray))
{
throw new ArgumentException($"Array must be of type {typeof(T[])}.", nameof(array));
}

CopyTo(typedArray, index);
}

public bool TryGetValue(T element, out T value)
bool ICollection.IsSynchronized => false;

object ICollection.SyncRoot
{
var idx = _elements.IndexOf(element);
if (idx >= 0)
get
{
value = _elements[idx];
return true;
}
if (_syncRoot == null)
{
Interlocked.CompareExchange<object>(ref _syncRoot, new object(), null);
Copy link
Contributor Author

@maca88 maca88 May 21, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Used the same logic as List<> has.

}

value = default(T);
return false;
return _syncRoot;
}
}
}
#endif
}