Skip to content

Commit 34c1aa9

Browse files
authored
CSHARP-5527: Add support for $sigmoid expression in LINQ (#1638)
1 parent 388c177 commit 34c1aa9

File tree

7 files changed

+157
-0
lines changed

7 files changed

+157
-0
lines changed

src/MongoDB.Driver/Core/Misc/Feature.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ public class Feature
9292
private static readonly Feature __setWindowFields = new Feature("SetWindowFields", WireVersion.Server50);
9393
private static readonly Feature __setWindowFieldsLocf = new Feature("SetWindowFieldsLocf", WireVersion.Server52);
9494
private static readonly Feature __shardedTransactions = new Feature("ShardedTransactions", WireVersion.Server42);
95+
private static readonly Feature __sigmoidOperator = new Feature("SigmoidOperator", WireVersion.Server81);
9596
private static readonly Feature __snapshotReads = new Feature("SnapshotReads", WireVersion.Server50, notSupportedMessage: "Snapshot reads require MongoDB 5.0 or later");
9697
private static readonly Feature __sortArrayOperator = new Feature("SortArrayOperator", WireVersion.Server52);
9798
private static readonly Feature __speculativeAuthentication = new Feature("SpeculativeAuthentication", WireVersion.Server44);
@@ -437,6 +438,11 @@ public class Feature
437438
/// </summary>
438439
public static Feature ShardedTransactions => __shardedTransactions;
439440

441+
/// <summary>
442+
/// Gets the $sigmoid operator feature.
443+
/// </summary>
444+
public static Feature SigmoidOperator => __sigmoidOperator;
445+
440446
/// <summary>
441447
/// Gets the snapshot reads feature.
442448
/// </summary>

src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstUnaryOperator.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ internal enum AstUnaryOperator
5858
Round,
5959
SetIntersection,
6060
SetUnion,
61+
Sigmoid,
6162
Sin,
6263
Sinh,
6364
Size,
@@ -161,6 +162,7 @@ public static string Render(this AstUnaryOperator @operator)
161162
AstUnaryOperator.Round => "$round",
162163
AstUnaryOperator.SetIntersection => "$setIntersection",
163164
AstUnaryOperator.SetUnion => "$setUnion",
165+
AstUnaryOperator.Sigmoid => "$sigmoid",
164166
AstUnaryOperator.Sin => "$sin",
165167
AstUnaryOperator.Sinh => "$sinh",
166168
AstUnaryOperator.Size => "$size",

src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MqlMethod.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ internal static class MqlMethod
3333
private static readonly MethodInfo __field;
3434
private static readonly MethodInfo __isMissing;
3535
private static readonly MethodInfo __isNullOrMissing;
36+
private static readonly MethodInfo __sigmoid;
3637

3738
// static constructor
3839
static MqlMethod()
@@ -47,6 +48,7 @@ static MqlMethod()
4748
__field = ReflectionInfo.Method((object container, string fieldName, IBsonSerializer<object> serializer) => Mql.Field<object, object>(container, fieldName, serializer));
4849
__isMissing = ReflectionInfo.Method((object field) => Mql.IsMissing(field));
4950
__isNullOrMissing = ReflectionInfo.Method((object field) => Mql.IsNullOrMissing(field));
51+
__sigmoid = ReflectionInfo.Method((double value) => Mql.Sigmoid(value));
5052
}
5153

5254
// public properties
@@ -60,5 +62,6 @@ static MqlMethod()
6062
public static MethodInfo Field => __field;
6163
public static MethodInfo IsMissing => __isMissing;
6264
public static MethodInfo IsNullOrMissing => __isNullOrMissing;
65+
public static MethodInfo Sigmoid => __sigmoid;
6366
}
6467
}

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodCallExpressionToAggregationExpressionTranslator.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC
7777
case "SequenceEqual": return SequenceEqualMethodToAggregationExpressionTranslator.Translate(context, expression);
7878
case "SetEquals": return SetEqualsMethodToAggregationExpressionTranslator.Translate(context, expression);
7979
case "Shift": return ShiftMethodToAggregationExpressionTranslator.Translate(context, expression);
80+
case "Sigmoid": return SigmoidMethodToAggregationExpressionTranslator.Translate(context, expression);
8081
case "Split": return SplitMethodToAggregationExpressionTranslator.Translate(context, expression);
8182
case "Sqrt": return SqrtMethodToAggregationExpressionTranslator.Translate(context, expression);
8283
case "StrLenBytes": return StrLenBytesMethodToAggregationExpressionTranslator.Translate(context, expression);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using System.Linq;
17+
using System.Linq.Expressions;
18+
using MongoDB.Bson.Serialization.Serializers;
19+
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions;
20+
using MongoDB.Driver.Linq.Linq3Implementation.Misc;
21+
using MongoDB.Driver.Linq.Linq3Implementation.Reflection;
22+
23+
namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators.MethodTranslators
24+
{
25+
internal static class SigmoidMethodToAggregationExpressionTranslator
26+
{
27+
public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression)
28+
{
29+
var method = expression.Method;
30+
var arguments = expression.Arguments;
31+
32+
if (method.Is(MqlMethod.Sigmoid))
33+
{
34+
var valueExpression = arguments.Single();
35+
var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, valueExpression);
36+
SerializationHelper.EnsureRepresentationIsNumeric(expression, valueExpression, valueTranslation);
37+
38+
return new TranslatedExpression(
39+
expression,
40+
AstExpression.Unary(AstUnaryOperator.Sigmoid, valueTranslation.Ast),
41+
DoubleSerializer.Instance);
42+
}
43+
44+
throw new ExpressionNotSupportedException(expression);
45+
}
46+
}
47+
}

src/MongoDB.Driver/Mql.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,5 +152,15 @@ public static bool IsNullOrMissing<TField>(TField field)
152152
{
153153
throw CustomLinqExtensionMethodHelper.CreateNotSupportedException();
154154
}
155+
156+
/// <summary>
157+
/// Transforms a real-valued input into a value between 0 and 1 using the $sigmoid operator.
158+
/// </summary>
159+
/// <param name="value">The input value.</param>
160+
/// <returns>The transformed value.</returns>
161+
public static double Sigmoid(double value)
162+
{
163+
throw CustomLinqExtensionMethodHelper.CreateNotSupportedException();
164+
}
155165
}
156166
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using System;
17+
using System.Collections.Generic;
18+
using System.Linq;
19+
using FluentAssertions;
20+
using MongoDB.Bson;
21+
using MongoDB.Bson.Serialization.Attributes;
22+
using MongoDB.Driver.Core.Misc;
23+
using MongoDB.Driver.Linq;
24+
using MongoDB.Driver.TestHelpers;
25+
using Xunit;
26+
27+
namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators.MethodTranslators
28+
{
29+
public class SigmoidMethodToAggregationExpressionTranslatorTests : LinqIntegrationTest<SigmoidMethodToAggregationExpressionTranslatorTests.ClassFixture>
30+
{
31+
public SigmoidMethodToAggregationExpressionTranslatorTests(ClassFixture fixture)
32+
: base(fixture, server => server.Supports(Feature.SigmoidOperator))
33+
{
34+
}
35+
36+
[Fact]
37+
public void Sigmoid_should_work()
38+
{
39+
var collection = Fixture.Collection;
40+
41+
var queryable = collection
42+
.AsQueryable()
43+
.Select(x => Mql.Sigmoid(x.X));
44+
45+
var stages = Translate(collection, queryable);
46+
AssertStages(stages, "{ $project : { _v : { $sigmoid : '$X' }, _id : 0 } }");
47+
48+
var result = queryable.ToList();
49+
result.Should().BeEquivalentTo(new[] { 0.7310585786300049, 0.9933071490757153, 0.999997739675702, 0.9999999992417439});
50+
}
51+
52+
[Fact]
53+
public void Sigmoid_with_non_numeric_representation_should_throw()
54+
{
55+
var exception = Record.Exception(() =>
56+
{
57+
var collection = Fixture.Collection;
58+
59+
var queryable = collection
60+
.AsQueryable()
61+
.Select(x => Mql.Sigmoid(x.Y));
62+
63+
Translate(collection, queryable);
64+
});
65+
66+
exception.Should().BeOfType<ExpressionNotSupportedException>();
67+
exception?.Message.Should().Contain("uses a non-numeric representation");
68+
}
69+
70+
public class C
71+
{
72+
[BsonRepresentation(BsonType.String)]
73+
public double Y { get; set; }
74+
public double X { get; set; }
75+
}
76+
77+
public sealed class ClassFixture : MongoCollectionFixture<C>
78+
{
79+
protected override IEnumerable<C> InitialData =>
80+
[
81+
new() { X = 1.0 },
82+
new() { X = 5.0 },
83+
new() { X = 13.0 },
84+
new() { X = 21.0 },
85+
];
86+
}
87+
}
88+
}

0 commit comments

Comments
 (0)