Skip to content

Commit c52c7f1

Browse files
authoredMay 17, 2023
Linq: add enum Equals and object Equals support (#3242)
1 parent 79ad5b8 commit c52c7f1

File tree

4 files changed

+64
-4
lines changed

4 files changed

+64
-4
lines changed
 

‎src/NHibernate.Test/Async/Linq/FunctionTests.cs

+27
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,33 @@ where item.Discount.Equals(-1)
459459
await (ObjectDumper.WriteAsync(query));
460460
}
461461

462+
[Test]
463+
public async Task WhereEnumEqualAsync()
464+
{
465+
var query = from item in db.PatientRecords
466+
where item.Gender.Equals(Gender.Female)
467+
select item;
468+
469+
await (ObjectDumper.WriteAsync(query));
470+
471+
query = from item in db.PatientRecords
472+
where item.Gender.Equals(item.Gender)
473+
select item;
474+
475+
await (ObjectDumper.WriteAsync(query));
476+
}
477+
478+
479+
[Test]
480+
public async Task WhereObjectEqualAsync()
481+
{
482+
var query = from item in db.PatientRecords
483+
where ((object) item.Gender).Equals(Gender.Female)
484+
select item;
485+
486+
await (ObjectDumper.WriteAsync(query));
487+
}
488+
462489
[Test]
463490
public async Task WhereEquatableEqualAsync()
464491
{

‎src/NHibernate.Test/Linq/FunctionTests.cs

+27
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,33 @@ where item.Discount.Equals(-1)
491491
ObjectDumper.Write(query);
492492
}
493493

494+
[Test]
495+
public void WhereEnumEqual()
496+
{
497+
var query = from item in db.PatientRecords
498+
where item.Gender.Equals(Gender.Female)
499+
select item;
500+
501+
ObjectDumper.Write(query);
502+
503+
query = from item in db.PatientRecords
504+
where item.Gender.Equals(item.Gender)
505+
select item;
506+
507+
ObjectDumper.Write(query);
508+
}
509+
510+
511+
[Test]
512+
public void WhereObjectEqual()
513+
{
514+
var query = from item in db.PatientRecords
515+
where ((object) item.Gender).Equals(Gender.Female)
516+
select item;
517+
518+
ObjectDumper.Write(query);
519+
}
520+
494521
[Test]
495522
public void WhereEquatableEqual()
496523
{

‎src/NHibernate/Linq/Functions/DefaultLinqToHqlGeneratorsRegistry.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -111,12 +111,12 @@ public virtual bool TryGetGenerator(MemberInfo property, out IHqlGeneratorForPro
111111

112112
public virtual void RegisterGenerator(MethodInfo method, IHqlGeneratorForMethod generator)
113113
{
114-
registeredMethods.Add(method, generator);
114+
registeredMethods[method] = generator;
115115
}
116116

117117
public virtual void RegisterGenerator(MemberInfo property, IHqlGeneratorForProperty generator)
118118
{
119-
registeredProperties.Add(property, generator);
119+
registeredProperties[property] = generator;
120120
}
121121

122122
public void RegisterGenerator(IRuntimeMethodHqlGenerator generator)

‎src/NHibernate/Linq/Functions/EqualsGenerator.cs

+8-2
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,10 @@ public class EqualsGenerator : BaseHqlGeneratorForMethod
5858
ReflectHelper.GetMethodDefinition<IEquatable<DateTime>>(x => x.Equals(default(DateTime))),
5959
ReflectHelper.GetMethodDefinition<IEquatable<DateTimeOffset>>(x => x.Equals(default(DateTimeOffset))),
6060
ReflectHelper.GetMethodDefinition<IEquatable<TimeSpan>>(x => x.Equals(default(TimeSpan))),
61-
ReflectHelper.GetMethodDefinition<IEquatable<bool>>(x => x.Equals(default(bool)))
61+
ReflectHelper.GetMethodDefinition<IEquatable<bool>>(x => x.Equals(default(bool))),
62+
ReflectHelper.GetMethodDefinition<object>(x => x.Equals(default(object))), // this covers also Enum.Equals
63+
ReflectHelper.GetMethodDefinition<IEquatable<object>>(x => x.Equals(default(object))),
64+
ReflectHelper.GetMethodDefinition<IEquatable<Enum>>(x => x.Equals(default(Enum)))
6265
};
6366

6467
public EqualsGenerator()
@@ -72,7 +75,10 @@ public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject,
7275
{
7376
Expression lhs = arguments.Count == 1 ? targetObject : arguments[0];
7477
Expression rhs = arguments.Count == 1 ? arguments[0] : arguments[1];
75-
78+
if (lhs.Type.IsEnum)
79+
{
80+
return visitor.Visit(Expression.Equal(lhs, Expression.Convert(rhs, lhs.Type)));
81+
}
7682
return visitor.Visit(Expression.Equal(lhs, rhs));
7783
}
7884
}

0 commit comments

Comments
 (0)