Skip to content

Commit 3cfb009

Browse files
maca88bahusoid
andauthored
Add an option to register a custom pre-transformer for a Linq query (#2411)
Co-authored-by: Roman Artiukhin <bahusdrive@gmail.com>
1 parent d58b668 commit 3cfb009

12 files changed

+290
-17
lines changed

doc/reference/modules/configuration.xml

+13
Original file line numberDiff line numberDiff line change
@@ -703,6 +703,19 @@ var session = sessions.OpenSession(conn);
703703
</para>
704704
</entry>
705705
</row>
706+
<row>
707+
<entry>
708+
<literal>query.pre_transformer_registrar</literal>
709+
</entry>
710+
<entry>
711+
The class name of the LINQ query pre-transformer registrar, implementing
712+
<literal>IExpressionTransformerRegistrar</literal>. Defaults to <literal>null</literal> (no registrar).
713+
<para>
714+
<emphasis role="strong">eg.</emphasis>
715+
<literal>classname.of.ExpressionTransformerRegistrar, assembly</literal>
716+
</para>
717+
</entry>
718+
</row>
706719
<row>
707720
<entry>
708721
<literal>linqtohql.generatorsregistry</literal>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
//------------------------------------------------------------------------------
2+
// <auto-generated>
3+
// This code was generated by AsyncGenerator.
4+
//
5+
// Changes to this file may cause incorrect behavior and will be lost if
6+
// the code is regenerated.
7+
// </auto-generated>
8+
//------------------------------------------------------------------------------
9+
10+
11+
using System;
12+
using System.Collections.Generic;
13+
using System.Linq;
14+
using System.Linq.Expressions;
15+
using System.Reflection;
16+
using NHibernate.Linq;
17+
using NHibernate.Linq.Visitors;
18+
using NHibernate.Util;
19+
using NUnit.Framework;
20+
using Remotion.Linq.Parsing.ExpressionVisitors.Transformation;
21+
22+
namespace NHibernate.Test.Linq
23+
{
24+
using System.Threading.Tasks;
25+
[TestFixture]
26+
public class CustomPreTransformRegistrarTestsAsync : LinqTestCase
27+
{
28+
protected override void Configure(Cfg.Configuration configuration)
29+
{
30+
configuration.Properties[Cfg.Environment.PreTransformerRegistrar] = typeof(PreTransformerRegistrar).AssemblyQualifiedName;
31+
}
32+
33+
[Test]
34+
public async Task RewriteLikeAsync()
35+
{
36+
// This example shows how to use the pre-transformer registrar to rewrite the
37+
// query so that StartsWith, EndsWith and Contains methods will generate the same sql.
38+
var queryPlanCache = GetQueryPlanCache();
39+
queryPlanCache.Clear();
40+
await (db.Customers.Where(o => o.ContactName.StartsWith("A")).ToListAsync());
41+
await (db.Customers.Where(o => o.ContactName.EndsWith("A")).ToListAsync());
42+
await (db.Customers.Where(o => o.ContactName.Contains("A")).ToListAsync());
43+
44+
Assert.That(queryPlanCache.Count, Is.EqualTo(1));
45+
}
46+
47+
[Serializable]
48+
public class PreTransformerRegistrar : IExpressionTransformerRegistrar
49+
{
50+
public void Register(ExpressionTransformerRegistry expressionTransformerRegistry)
51+
{
52+
expressionTransformerRegistry.Register(new LikeTransformer());
53+
}
54+
}
55+
56+
private class LikeTransformer : IExpressionTransformer<MethodCallExpression>
57+
{
58+
private static readonly MethodInfo Like = ReflectHelper.GetMethodDefinition(() => SqlMethods.Like(null, null));
59+
private static readonly MethodInfo EndsWith = ReflectHelper.GetMethodDefinition<string>(x => x.EndsWith(null));
60+
private static readonly MethodInfo StartsWith = ReflectHelper.GetMethodDefinition<string>(x => x.StartsWith(null));
61+
private static readonly MethodInfo Contains = ReflectHelper.GetMethodDefinition<string>(x => x.Contains(null));
62+
private static readonly Dictionary<MethodInfo, Func<object, string>> ValueTransformers =
63+
new Dictionary<MethodInfo, Func<object, string>>
64+
{
65+
{StartsWith, s => $"{s}%"},
66+
{EndsWith, s => $"%{s}"},
67+
{Contains, s => $"%{s}%"},
68+
};
69+
70+
public Expression Transform(MethodCallExpression expression)
71+
{
72+
if (ValueTransformers.TryGetValue(expression.Method, out var valueTransformer) &&
73+
expression.Arguments[0] is ConstantExpression constantExpression)
74+
{
75+
return Expression.Call(
76+
Like,
77+
expression.Object,
78+
Expression.Constant(valueTransformer(constantExpression.Value))
79+
);
80+
}
81+
82+
return expression;
83+
}
84+
85+
public ExpressionType[] SupportedExpressionTypes { get; } = {ExpressionType.Call};
86+
}
87+
}
88+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Linq.Expressions;
5+
using System.Reflection;
6+
using NHibernate.Linq;
7+
using NHibernate.Linq.Visitors;
8+
using NHibernate.Util;
9+
using NUnit.Framework;
10+
using Remotion.Linq.Parsing.ExpressionVisitors.Transformation;
11+
12+
namespace NHibernate.Test.Linq
13+
{
14+
[TestFixture]
15+
public class CustomPreTransformRegistrarTests : LinqTestCase
16+
{
17+
protected override void Configure(Cfg.Configuration configuration)
18+
{
19+
configuration.Properties[Cfg.Environment.PreTransformerRegistrar] = typeof(PreTransformerRegistrar).AssemblyQualifiedName;
20+
}
21+
22+
[Test]
23+
public void RewriteLike()
24+
{
25+
// This example shows how to use the pre-transformer registrar to rewrite the
26+
// query so that StartsWith, EndsWith and Contains methods will generate the same sql.
27+
var queryPlanCache = GetQueryPlanCache();
28+
queryPlanCache.Clear();
29+
db.Customers.Where(o => o.ContactName.StartsWith("A")).ToList();
30+
db.Customers.Where(o => o.ContactName.EndsWith("A")).ToList();
31+
db.Customers.Where(o => o.ContactName.Contains("A")).ToList();
32+
33+
Assert.That(queryPlanCache.Count, Is.EqualTo(1));
34+
}
35+
36+
[Serializable]
37+
public class PreTransformerRegistrar : IExpressionTransformerRegistrar
38+
{
39+
public void Register(ExpressionTransformerRegistry expressionTransformerRegistry)
40+
{
41+
expressionTransformerRegistry.Register(new LikeTransformer());
42+
}
43+
}
44+
45+
private class LikeTransformer : IExpressionTransformer<MethodCallExpression>
46+
{
47+
private static readonly MethodInfo Like = ReflectHelper.GetMethodDefinition(() => SqlMethods.Like(null, null));
48+
private static readonly MethodInfo EndsWith = ReflectHelper.GetMethodDefinition<string>(x => x.EndsWith(null));
49+
private static readonly MethodInfo StartsWith = ReflectHelper.GetMethodDefinition<string>(x => x.StartsWith(null));
50+
private static readonly MethodInfo Contains = ReflectHelper.GetMethodDefinition<string>(x => x.Contains(null));
51+
private static readonly Dictionary<MethodInfo, Func<object, string>> ValueTransformers =
52+
new Dictionary<MethodInfo, Func<object, string>>
53+
{
54+
{StartsWith, s => $"{s}%"},
55+
{EndsWith, s => $"%{s}"},
56+
{Contains, s => $"%{s}%"},
57+
};
58+
59+
public Expression Transform(MethodCallExpression expression)
60+
{
61+
if (ValueTransformers.TryGetValue(expression.Method, out var valueTransformer) &&
62+
expression.Arguments[0] is ConstantExpression constantExpression)
63+
{
64+
return Expression.Call(
65+
Like,
66+
expression.Object,
67+
Expression.Constant(valueTransformer(constantExpression.Value))
68+
);
69+
}
70+
71+
return expression;
72+
}
73+
74+
public ExpressionType[] SupportedExpressionTypes { get; } = {ExpressionType.Call};
75+
}
76+
}
77+
}

src/NHibernate.Test/TestCase.cs

+15-6
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,15 @@ public abstract class TestCase
2929
private SchemaExport _schemaExport;
3030

3131
private static readonly ILog log = LogManager.GetLogger(typeof(TestCase));
32+
private static readonly FieldInfo PlanCacheField;
33+
34+
static TestCase()
35+
{
36+
PlanCacheField = typeof(QueryPlanCache)
37+
.GetField("planCache", BindingFlags.NonPublic | BindingFlags.Instance)
38+
?? throw new InvalidOperationException(
39+
"planCache field does not exist in QueryPlanCache.");
40+
}
3241

3342
protected Dialect.Dialect Dialect
3443
{
@@ -488,14 +497,14 @@ protected void AssumeFunctionSupported(string functionName)
488497
$"{dialect} doesn't support {functionName} standard function.");
489498
}
490499

491-
protected void ClearQueryPlanCache()
500+
protected SoftLimitMRUCache GetQueryPlanCache()
492501
{
493-
var planCacheField = typeof(QueryPlanCache)
494-
.GetField("planCache", BindingFlags.NonPublic | BindingFlags.Instance)
495-
?? throw new InvalidOperationException("planCache field does not exist in QueryPlanCache.");
502+
return (SoftLimitMRUCache) PlanCacheField.GetValue(Sfi.QueryPlanCache);
503+
}
496504

497-
var planCache = (SoftLimitMRUCache) planCacheField.GetValue(Sfi.QueryPlanCache);
498-
planCache.Clear();
505+
protected void ClearQueryPlanCache()
506+
{
507+
GetQueryPlanCache().Clear();
499508
}
500509

501510
protected Substitute<Dialect.Dialect> SubstituteDialect()

src/NHibernate/Cfg/Environment.cs

+6
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using NHibernate.Cfg.ConfigurationSchema;
77
using NHibernate.Engine;
88
using NHibernate.Linq;
9+
using NHibernate.Linq.Visitors;
910
using NHibernate.Util;
1011

1112
namespace NHibernate.Cfg
@@ -282,6 +283,11 @@ public static string Version
282283

283284
public const string QueryModelRewriterFactory = "query.query_model_rewriter_factory";
284285

286+
/// <summary>
287+
/// The class name of the LINQ query pre-transformer registrar, implementing <see cref="IExpressionTransformerRegistrar"/>.
288+
/// </summary>
289+
public const string PreTransformerRegistrar = "query.pre_transformer_registrar";
290+
285291
/// <summary>
286292
/// Set the default length used in casting when the target type is length bound and
287293
/// does not specify it. <c>4000</c> by default, automatically trimmed down according to dialect type registration.

src/NHibernate/Cfg/Loquacious/DbIntegrationConfigurationProperties.cs

+9
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,15 @@ public void QueryModelRewriterFactory<TFactory>() where TFactory : IQueryModelRe
144144
configuration.SetProperty(Environment.QueryModelRewriterFactory, typeof(TFactory).AssemblyQualifiedName);
145145
}
146146

147+
/// <summary>
148+
/// Set the class of the LINQ query pre-transformer registrar.
149+
/// </summary>
150+
/// <typeparam name="TRegistrar">The class of the LINQ query pre-transformer registrar.</typeparam>
151+
public void PreTransformerRegistrar<TRegistrar>() where TRegistrar : IExpressionTransformerRegistrar
152+
{
153+
configuration.SetProperty(Environment.PreTransformerRegistrar, typeof(TRegistrar).AssemblyQualifiedName);
154+
}
155+
147156
#endregion
148157
}
149158
}

src/NHibernate/Cfg/Settings.cs

+9-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System;
22
using System.Collections.Generic;
33
using System.Data;
4+
using System.Linq.Expressions;
45
using NHibernate.AdoNet;
56
using NHibernate.AdoNet.Util;
67
using NHibernate.Cache;
@@ -189,7 +190,14 @@ public Settings()
189190
public bool LinqToHqlFallbackOnPreEvaluation { get; internal set; }
190191

191192
public IQueryModelRewriterFactory QueryModelRewriterFactory { get; internal set; }
192-
193+
194+
/// <summary>
195+
/// The pre-transformer registrar used to register custom expression transformers.
196+
/// </summary>
197+
public IExpressionTransformerRegistrar PreTransformerRegistrar { get; internal set; }
198+
199+
internal Func<Expression, Expression> LinqPreTransformer { get; set; }
200+
193201
#endregion
194202

195203
internal string GetFullCacheRegionName(string name)

src/NHibernate/Cfg/SettingsFactory.cs

+28-1
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,14 @@ public Settings BuildSettings(IDictionary<string, string> properties)
310310
// Not ported - JdbcBatchVersionedData
311311

312312
settings.QueryModelRewriterFactory = CreateQueryModelRewriterFactory(properties);
313-
313+
settings.PreTransformerRegistrar = CreatePreTransformerRegistrar(properties);
314+
315+
// Avoid dependency on re-linq assembly when PreTransformerRegistrar is null
316+
if (settings.PreTransformerRegistrar != null)
317+
{
318+
settings.LinqPreTransformer = NhRelinqQueryParser.CreatePreTransformer(settings.PreTransformerRegistrar);
319+
}
320+
314321
// NHibernate-specific:
315322
settings.IsolationLevel = isolation;
316323

@@ -444,5 +451,25 @@ private static IQueryModelRewriterFactory CreateQueryModelRewriterFactory(IDicti
444451
throw new HibernateException("could not instantiate IQueryModelRewriterFactory: " + className, cnfe);
445452
}
446453
}
454+
455+
private static IExpressionTransformerRegistrar CreatePreTransformerRegistrar(IDictionary<string, string> properties)
456+
{
457+
var className = PropertiesHelper.GetString(Environment.PreTransformerRegistrar, properties, null);
458+
if (className == null)
459+
return null;
460+
461+
log.Info("Pre-transformer registrar: {0}", className);
462+
463+
try
464+
{
465+
return
466+
(IExpressionTransformerRegistrar)
467+
Environment.ObjectsFactory.CreateInstance(ReflectHelper.ClassForName(className));
468+
}
469+
catch (Exception e)
470+
{
471+
throw new HibernateException("could not instantiate IExpressionTransformerRegistrar: " + className, e);
472+
}
473+
}
447474
}
448475
}

src/NHibernate/Linq/NhRelinqQueryParser.cs

+12-8
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,9 @@ namespace NHibernate.Linq
2121
public static class NhRelinqQueryParser
2222
{
2323
private static readonly QueryParser QueryParser;
24-
private static readonly IExpressionTreeProcessor PreProcessor;
2524

2625
static NhRelinqQueryParser()
2726
{
28-
var preTransformerRegistry = new ExpressionTransformerRegistry();
29-
// NH-3247: must remove .Net compiler char to int conversion before
30-
// parameterization occurs.
31-
preTransformerRegistry.Register(new RemoveCharToIntConversion());
32-
PreProcessor = new TransformingExpressionTreeProcessor(preTransformerRegistry);
33-
3427
var transformerRegistry = ExpressionTransformerRegistry.CreateDefault();
3528
transformerRegistry.Register(new RemoveRedundantCast());
3629
transformerRegistry.Register(new SimplifyCompareTransformer());
@@ -78,7 +71,7 @@ public static PreTransformationResult PreTransform(Expression expression, PreTra
7871
.EvaluateIndependentSubtrees(expression, parameters);
7972

8073
return new PreTransformationResult(
81-
PreProcessor.Process(partiallyEvaluatedExpression),
74+
parameters.PreTransformer.Invoke(partiallyEvaluatedExpression),
8275
parameters.SessionFactory,
8376
parameters.QueryVariables);
8477
}
@@ -87,6 +80,17 @@ public static QueryModel Parse(Expression expression)
8780
{
8881
return QueryParser.GetParsedQuery(expression);
8982
}
83+
84+
internal static Func<Expression, Expression> CreatePreTransformer(IExpressionTransformerRegistrar expressionTransformerRegistrar)
85+
{
86+
var preTransformerRegistry = new ExpressionTransformerRegistry();
87+
// NH-3247: must remove .Net compiler char to int conversion before
88+
// parameterization occurs.
89+
preTransformerRegistry.Register(new RemoveCharToIntConversion());
90+
expressionTransformerRegistrar?.Register(preTransformerRegistry);
91+
92+
return new TransformingExpressionTreeProcessor(preTransformerRegistry).Process;
93+
}
9094
}
9195

9296
public class NHibernateNodeTypeProvider : INodeTypeProvider
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
using Remotion.Linq.Parsing.ExpressionVisitors.Transformation;
2+
3+
namespace NHibernate.Linq.Visitors
4+
{
5+
/// <summary>
6+
/// Provides a way to register custom transformers for expressions.
7+
/// </summary>
8+
public interface IExpressionTransformerRegistrar
9+
{
10+
/// <summary>
11+
/// Registers additional transformers on the expression transformer registry.
12+
/// </summary>
13+
/// <param name="expressionTransformerRegistry">The expression transformer registry.</param>
14+
void Register(ExpressionTransformerRegistry expressionTransformerRegistry);
15+
}
16+
}

0 commit comments

Comments
 (0)