From 569293374117796e8f08a82a0899a9fc875d45b6 Mon Sep 17 00:00:00 2001 From: Raif Atef Date: Wed, 13 Oct 2021 06:20:03 +0200 Subject: [PATCH] Support raw SQL projections in Select and OrderBy statements. --- .../CoreFunctionality/query_by_sql.cs | 46 ++++++++ .../Linq/SqlProjection/SqlProjectionTests.cs | 17 +++ .../Linq/Parsing/SelectTransformBuilder.cs | 100 ++++++++++++++++-- src/Marten/Linq/SqlGeneration/Statement.cs | 19 +++- .../SqlProjection/SqlProjectionExtensions.cs | 18 ++++ .../SqlProjection/SqlProjectionSqlFragment.cs | 48 +++++++++ 6 files changed, 234 insertions(+), 14 deletions(-) create mode 100644 src/Marten.Testing/Linq/SqlProjection/SqlProjectionTests.cs create mode 100644 src/Marten/Linq/SqlProjection/SqlProjectionExtensions.cs create mode 100644 src/Marten/Linq/SqlProjection/SqlProjectionSqlFragment.cs diff --git a/src/Marten.Testing/CoreFunctionality/query_by_sql.cs b/src/Marten.Testing/CoreFunctionality/query_by_sql.cs index 78ec229382..7216a0d7f9 100644 --- a/src/Marten.Testing/CoreFunctionality/query_by_sql.cs +++ b/src/Marten.Testing/CoreFunctionality/query_by_sql.cs @@ -2,6 +2,7 @@ using System.Linq; using System.Threading.Tasks; using Marten.Linq.MatchesSql; +using Marten.Linq.SqlProjection; using Marten.Testing.Documents; using Marten.Testing.Harness; using Shouldly; @@ -324,6 +325,51 @@ public async Task query_with_select_in_query_async() } } + [Fact] + public async Task query_with_select_sql_projection_async() + { + using (var session = theStore.OpenSession()) + { + var u = new User {FirstName = "Jeremy", LastName = "Miller", Age = 1337}; + session.Store(u); + session.SaveChanges(); + + #region sample_using-sql-projection-queryasync + + var users = await session.Query() + .Select(x => new { Age = x.SqlProjection("(data->>'Age')::integer") }) + .OrderBy(x => x.Age) + .ToListAsync(); + var user = users.Single(); + + #endregion + + user.Age.ShouldBe(1337); + } + } + + [Fact] + public async Task query_with_order_by_sql_projection_async() + { + using (var session = theStore.OpenSession()) + { + var u = new User {FirstName = "Jeremy", LastName = "Miller"}; + session.Store(u); + session.SaveChanges(); + + #region sample_using-sql-projection-queryasync + + var users = await session.Query() + .OrderBy(x => (object)x.SqlProjection("data->>'FirstName'")) + .ToListAsync(); + var user = users.Single(); + + #endregion + + user.FirstName.ShouldBe("Jeremy"); + } + } + [Fact] public async Task get_sum_of_integers_asynchronously() { diff --git a/src/Marten.Testing/Linq/SqlProjection/SqlProjectionTests.cs b/src/Marten.Testing/Linq/SqlProjection/SqlProjectionTests.cs new file mode 100644 index 0000000000..ba30ecfa8f --- /dev/null +++ b/src/Marten.Testing/Linq/SqlProjection/SqlProjectionTests.cs @@ -0,0 +1,17 @@ +using System; +using Marten.Linq.SqlProjection; +using Shouldly; +using Xunit; + +namespace Marten.Testing.Linq.SqlProjection +{ + public class SqlProjectionTests + { + [Fact] + public void Throws_NotSupportedException_when_called_directly() + { + Should.Throw( + () => new object().SqlProjection("COALESCE(d.data ->> 'UserName', ?)", "baz")); + } + } +} diff --git a/src/Marten/Linq/Parsing/SelectTransformBuilder.cs b/src/Marten/Linq/Parsing/SelectTransformBuilder.cs index 21018d677d..e6c1ec71f3 100644 --- a/src/Marten/Linq/Parsing/SelectTransformBuilder.cs +++ b/src/Marten/Linq/Parsing/SelectTransformBuilder.cs @@ -6,14 +6,16 @@ using System.Reflection; using Baseline; using Marten.Linq.Fields; +using Marten.Linq.SqlProjection; using Remotion.Linq.Parsing; +using Weasel.Postgresql.SqlGeneration; namespace Marten.Linq.Parsing { internal class SelectTransformBuilder : RelinqExpressionVisitor { private TargetObject _target; - private SelectedField _currentField; + private BindingTarget _currentTarget; public SelectTransformBuilder(Expression clause, IFieldMapping fields, ISerializer serializer) { @@ -35,7 +37,7 @@ protected override Expression VisitNew(NewExpression expression) for (var i = 0; i < parameters.Length; i++) { - _currentField = _target.StartBinding(parameters[i].Name); + _currentTarget = _target.StartBinding(parameters[i].Name); Visit(expression.Arguments[i]); } @@ -44,21 +46,76 @@ protected override Expression VisitNew(NewExpression expression) protected override Expression VisitMember(MemberExpression node) { - _currentField.Add(node.Member); + _currentTarget.AddMember(node.Member); return base.VisitMember(node); } protected override MemberBinding VisitMemberBinding(MemberBinding node) { - _currentField = _target.StartBinding(node.Member.Name); + _currentTarget = _target.StartBinding(node.Member.Name); return base.VisitMemberBinding(node); } + protected override Expression VisitMethodCall(MethodCallExpression node) + { + var fragment = SqlProjectionSqlFragment.TryParse(node); + if (fragment == null) + { + throw new NotSupportedException( + $"Method {node.Method.DeclaringType?.FullName}.{node.Method.Name} is not supported."); + } + + _currentTarget.AddSqlProjection(fragment); + + return base.VisitMethodCall(node); + } + + public class BindingTarget : TargetObject.ISetterBinding + { + private readonly string _name; + private TargetObject.SetterBinding _field; + private TargetObject.SqlProjectionBinding _sqlProjection; + + public BindingTarget(string name) + { + _name = name; + } + + public void AddMember(MemberInfo memberInfo) + { + if (_sqlProjection != null) + { + throw new InvalidOperationException( + "Cannot bind to a member after having bound to a sql projection"); + } + + _field ??= new TargetObject.SetterBinding(_name); + _field.Field.Add(memberInfo); + } + + public void AddSqlProjection(ISqlFragment sqlProjectionClause) + { + if (_field != null) + { + throw new InvalidOperationException( + "Cannot bind to a sql projection after having bound to a member."); + } + + _sqlProjection = new TargetObject.SqlProjectionBinding(_name, sqlProjectionClause); + } + + public string ToJsonBuildObjectPair(IFieldMapping mapping, ISerializer serializer) + { + return _field?.ToJsonBuildObjectPair(mapping, serializer) + ?? _sqlProjection?.ToJsonBuildObjectPair(mapping, serializer) + ?? string.Empty; + } + } public class TargetObject { - private readonly IList _setters = new List(); + private readonly IList _setters = new List(); public TargetObject(Type type) { @@ -67,12 +124,11 @@ public TargetObject(Type type) public Type Type { get; } - public SelectedField StartBinding(string bindingName) + public BindingTarget StartBinding(string bindingName) { - var setter = new SetterBinding(bindingName); - _setters.Add(setter); - - return setter.Field; + var bindingTarget = new BindingTarget(bindingName); + _setters.Add(bindingTarget); + return bindingTarget; } public string ToSelectField(IFieldMapping fields, ISerializer serializer) @@ -81,7 +137,12 @@ public string ToSelectField(IFieldMapping fields, ISerializer serializer) return $"jsonb_build_object({jsonBuildObjectArgs})"; } - private class SetterBinding + public interface ISetterBinding + { + string ToJsonBuildObjectPair(IFieldMapping mapping, ISerializer serializer); + } + + public class SetterBinding: ISetterBinding { public SetterBinding(string name) { @@ -101,6 +162,23 @@ public string ToJsonBuildObjectPair(IFieldMapping mapping, ISerializer serialize return $"'{Name}', {locator}"; } } + + public class SqlProjectionBinding: ISetterBinding + { + public SqlProjectionBinding(string name, ISqlFragment projectionFragment) + { + Name = name; + ProjectionFragment = projectionFragment; + } + + private string Name { get; } + private ISqlFragment ProjectionFragment { get; } + + public string ToJsonBuildObjectPair(IFieldMapping mapping, ISerializer serializer) + { + return $"'{Name}', ({ProjectionFragment.ToSql()})"; + } + } } public class SelectedField: IEnumerable diff --git a/src/Marten/Linq/SqlGeneration/Statement.cs b/src/Marten/Linq/SqlGeneration/Statement.cs index 7c9ba28784..50076dad61 100644 --- a/src/Marten/Linq/SqlGeneration/Statement.cs +++ b/src/Marten/Linq/SqlGeneration/Statement.cs @@ -1,9 +1,12 @@ +using System; using System.Collections.Generic; using System.Linq; +using System.Linq.Expressions; using Baseline; using Marten.Internal; using Marten.Linq.Fields; using Marten.Linq.Parsing; +using Marten.Linq.SqlProjection; using Weasel.Postgresql; using Npgsql; using Remotion.Linq.Clauses; @@ -95,9 +98,19 @@ protected void writeWhereClause(CommandBuilder sql) protected void writeOrderByFragment(CommandBuilder sql, Ordering clause) { - var field = Fields.FieldFor(clause.Expression); - var locator = field.ToOrderExpression(clause.Expression); - sql.Append(locator); + var expression = clause.Expression; + + var sqlProjectionFragment = SqlProjectionSqlFragment.TryParse(expression); + if (sqlProjectionFragment != null) + { + sqlProjectionFragment.Apply(sql); + } + else + { + var field = Fields.FieldFor(expression); + var locator = field.ToOrderExpression(expression); + sql.Append(locator); + } if (clause.OrderingDirection == OrderingDirection.Desc) sql.Append(" desc"); } diff --git a/src/Marten/Linq/SqlProjection/SqlProjectionExtensions.cs b/src/Marten/Linq/SqlProjection/SqlProjectionExtensions.cs new file mode 100644 index 0000000000..c103d85c8d --- /dev/null +++ b/src/Marten/Linq/SqlProjection/SqlProjectionExtensions.cs @@ -0,0 +1,18 @@ +using System; +using System.Reflection; + +namespace Marten.Linq.SqlProjection +{ + public static class SqlProjectionExtensions + { + public static readonly MethodInfo MethodInfo = typeof(SqlProjectionExtensions) + .GetMethod(nameof(SqlProjection), + BindingFlags.Public | BindingFlags.Static); + + public static T SqlProjection(this object doc, string sql, params object[] parameters) + { + throw new NotSupportedException( + $"{nameof(SqlProjection)} extension method can only be used in Marten Linq queries."); + } + } +} diff --git a/src/Marten/Linq/SqlProjection/SqlProjectionSqlFragment.cs b/src/Marten/Linq/SqlProjection/SqlProjectionSqlFragment.cs new file mode 100644 index 0000000000..b34150587c --- /dev/null +++ b/src/Marten/Linq/SqlProjection/SqlProjectionSqlFragment.cs @@ -0,0 +1,48 @@ +using System; +using System.Linq.Expressions; +using Weasel.Postgresql.SqlGeneration; + +namespace Marten.Linq.SqlProjection +{ + public static class SqlProjectionSqlFragment + { + public static ISqlFragment TryParse(Expression expression, Func visit = null) + { + if (expression == null) + { + return null; + } + + if (expression is UnaryExpression { NodeType: ExpressionType.Convert } unaryExpression) + { + expression = unaryExpression.Operand; + } + + visit ??= x => x; + + if (expression is not MethodCallExpression methodCall) + { + return null; + } + + if (!methodCall.Method.IsGenericMethod || + methodCall.Method.GetGenericMethodDefinition() != SqlProjectionExtensions.MethodInfo) + { + return null; + } + + if (visit(methodCall.Arguments[1]) is not ConstantExpression { Value: string sql }) + { + throw new NotSupportedException("SqlProjection first parameter needs to resolve to a string"); + } + + if (visit(methodCall.Arguments[2]) is not ConstantExpression { Value: object[] sqlArguments }) + { + throw new NotSupportedException("SqlProjection second parameter needs to resolve to an object[]"); + } + + var whereFragment = new WhereFragment(sql, sqlArguments); + return whereFragment; + } + } +}