Skip to content

Enable vector tests and validate ref assembly for vector APIs #3559

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ private SqlVector(int length)
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlTypes/SqlVector.xml' path='docs/members[@name="SqlVector"]/CreateNull/*' />
public static SqlVector<T> CreateNull(int length) => new(length);

/// <include file='../../../../doc/snippets/Microsoft.Data.SqlTypes/SqlVector.xml' path='docs/members[@name="SqlVector"]/ctor2/*' />
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlTypes/SqlVector.xml' path='docs/members[@name="SqlVector"]/ctor1/*' />
public SqlVector(ReadOnlyMemory<T> memory)
{
(_elementType, _elementSize) = GetTypeFieldsOrThrow();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,13 @@ public static class DataTestUtility
//SQL Server EngineEdition
private static string s_sqlServerEngineEdition;

// Currently, only Azure SQL supports vectors and JSON.
// Our CI images with specific SQL Server versions lag
// behind with vector and JSON support.
// JSON Column type
public static readonly bool IsJsonSupported = false;

public static readonly bool IsJsonSupported = !IsNotAzureServer();
// VECTOR column type
public static readonly bool IsVectorSupported = false;
public static readonly bool IsVectorSupported = !IsNotAzureServer();

// Azure Synapse EngineEditionId == 6
// More could be read at https://learn.microsoft.com/en-us/sql/t-sql/functions/serverproperty-transact-sql?view=sql-server-ver16#propertyname
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@
<Compile Include="SQL\DataSourceParserTest\DataSourceParserTest.cs" />
<Compile Include="SQL\InstanceNameTest\InstanceNameTest.cs" />
<Compile Include="SQL\IntegratedAuthenticationTest\IntegratedAuthenticationTest.cs" />
<Compile Include="SQL\JsonTest\JsonBulkCopyTest.cs" />
<Compile Include="SQL\JsonTest\JsonStreamTest.cs" />
<Compile Include="SQL\JsonTest\JsonTest.cs" />
<Compile Include="SQL\KerberosTests\KerberosTest.cs" />
<Compile Include="SQL\KerberosTests\KerberosTicketManager\KerberosTicketManager.cs" />
<Compile Include="SQL\LocalDBTest\LocalDBTest.cs" />
Expand Down Expand Up @@ -213,6 +216,9 @@
<Compile Include="SQL\UdtTest\UdtTest2.cs" />
<Compile Include="SQL\UdtTest\UdtTestHelpers.cs" />
<Compile Include="SQL\Utf8SupportTest\Utf8SupportTest.cs" />
<Compile Include="SQL\VectorTest\VectorTypeBackwardCompatibilityTests.cs" />
<Compile Include="SQL\VectorTest\NativeVectorFloat32Tests.cs" />
<Compile Include="SQL\VectorTest\VectorAPIValidationTest.cs" />
<Compile Include="SQL\WeakRefTest\WeakRefTest.cs" />
<Compile Include="SQL\WeakRefTestYukonSpecific\WeakRefTestYukonSpecific.cs" />
<Compile Include="SQL\ParameterTest\DateTimeVariantTest.cs" />
Expand Down Expand Up @@ -294,11 +300,6 @@
<Compile Include="SQL\Common\SystemDataInternals\TdsParserStateObjectHelper.cs" />
<Compile Include="SQL\ConnectionTestWithSSLCert\CertificateTest.cs" />
<Compile Include="SQL\ConnectionTestWithSSLCert\CertificateTestWithTdsServer.cs" />
<Compile Include="SQL\JsonTest\JsonBulkCopyTest.cs" />
<Compile Include="SQL\JsonTest\JsonStreamTest.cs" />
<Compile Include="SQL\JsonTest\JsonTest.cs" />
<Compile Include="SQL\VectorTest\VectorTypeBackwardCompatibilityTests.cs" />
<Compile Include="SQL\VectorTest\NativeVectorFloat32Tests.cs" />
<Compile Include="SQL\SqlCommand\SqlCommandStoredProcTest.cs" />
<Compile Include="TracingTests\TestTdsServer.cs" />
<Compile Include="XUnitAssemblyAttributes.cs" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ public static class VectorFloat32TestData
public static float[] testData = new float[] { 1.1f, 2.2f, 3.3f };
public static int vectorColumnLength = testData.Length;
// Incorrect size for SqlParameter.Size
public static int IncorrectParamSize = 3234;
public static int IncorrectParamSize = 3234;
public static IEnumerable<object[]> GetVectorFloat32TestData()
{
// Pattern 1-4 with SqlVector<float>(values: testData)
Expand All @@ -43,11 +43,11 @@ public static IEnumerable<object[]> GetVectorFloat32TestData()

// Pattern 1-4 with SqlVector<float>.Null
yield return new object[] { 1, SqlVector<float>.Null, Array.Empty<float>(), vectorColumnLength };

// Following scenario is not supported in SqlClient.
// This can only be fixed with a behavior change that SqlParameter.Value is internally set to DBNull.Value if it is set to null.
//yield return new object[] { 2, SqlVector<float>.Null, Array.Empty<float>(), vectorColumnLength };

yield return new object[] { 3, SqlVector<float>.Null, Array.Empty<float>(), vectorColumnLength };
yield return new object[] { 4, SqlVector<float>.Null, Array.Empty<float>(), vectorColumnLength };
}
Expand Down Expand Up @@ -128,7 +128,7 @@ private void ValidateInsertedData(SqlConnection connection, float[] expectedData
ValidateSqlVectorFloat32Object(reader.IsDBNull(0), (SqlVector<float>)reader.GetSqlValue(0), expectedData, expectedLength);

if (!reader.IsDBNull(0))
{
{
ValidateSqlVectorFloat32Object(reader.IsDBNull(0), (SqlVector<float>)reader.GetValue(0), expectedData, expectedLength);
ValidateSqlVectorFloat32Object(reader.IsDBNull(0), (SqlVector<float>)reader[0], expectedData, expectedLength);
ValidateSqlVectorFloat32Object(reader.IsDBNull(0), (SqlVector<float>)reader["VectorData"], expectedData, expectedLength);
Expand All @@ -148,7 +148,7 @@ private void ValidateInsertedData(SqlConnection connection, float[] expectedData
}

[ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.IsVectorSupported))]
[MemberData(nameof(VectorFloat32TestData.GetVectorFloat32TestData), MemberType = typeof(VectorFloat32TestData))]
[MemberData(nameof(VectorFloat32TestData.GetVectorFloat32TestData), MemberType = typeof(VectorFloat32TestData), DisableDiscoveryEnumeration = true)]
public void TestSqlVectorFloat32ParameterInsertionAndReads(
int pattern,
object value,
Expand Down Expand Up @@ -214,7 +214,7 @@ private async Task ValidateInsertedDataAsync(SqlConnection connection, float[] e
}

[ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.IsVectorSupported))]
[MemberData(nameof(VectorFloat32TestData.GetVectorFloat32TestData), MemberType = typeof(VectorFloat32TestData))]
[MemberData(nameof(VectorFloat32TestData.GetVectorFloat32TestData), MemberType = typeof(VectorFloat32TestData), DisableDiscoveryEnumeration = true)]
public async Task TestSqlVectorFloat32ParameterInsertionAndReadsAsync(
int pattern,
object value,
Expand Down Expand Up @@ -248,7 +248,7 @@ public async Task TestSqlVectorFloat32ParameterInsertionAndReadsAsync(
}

[ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.IsVectorSupported))]
[MemberData(nameof(VectorFloat32TestData.GetVectorFloat32TestData), MemberType = typeof(VectorFloat32TestData))]
[MemberData(nameof(VectorFloat32TestData.GetVectorFloat32TestData), MemberType = typeof(VectorFloat32TestData), DisableDiscoveryEnumeration = true)]
public void TestStoredProcParamsForVectorFloat32(
int pattern,
object value,
Expand Down Expand Up @@ -305,7 +305,7 @@ public void TestStoredProcParamsForVectorFloat32(
}

[ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.IsVectorSupported))]
[MemberData(nameof(VectorFloat32TestData.GetVectorFloat32TestData), MemberType = typeof(VectorFloat32TestData))]
[MemberData(nameof(VectorFloat32TestData.GetVectorFloat32TestData), MemberType = typeof(VectorFloat32TestData), DisableDiscoveryEnumeration = true)]
public async Task TestStoredProcParamsForVectorFloat32Async(
int pattern,
object value,
Expand Down Expand Up @@ -374,8 +374,8 @@ public void TestBulkCopyFromSqlTable(int bulkCopySourceMode)
DataTable table = null;
switch (bulkCopySourceMode)
{
case 1:

case 1:
// Use SqlServer table as source
var insertCmd = new SqlCommand($"insert into {s_bulkCopySrcTableName} values (@VectorData)", sourceConnection);
var vectorParam = new SqlParameter(s_vectorParamName, new SqlVector<float>(VectorFloat32TestData.testData));
Expand All @@ -400,8 +400,8 @@ public void TestBulkCopyFromSqlTable(int bulkCopySourceMode)
throw new ArgumentOutOfRangeException(nameof(bulkCopySourceMode), $"Unsupported bulk copy source mode: {bulkCopySourceMode}");
}



//Bulkcopy from sql server table to destination table
using SqlCommand sourceDataCommand = new SqlCommand($"SELECT Id, VectorData FROM {s_bulkCopySrcTableName}", sourceConnection);
using SqlDataReader reader = sourceDataCommand.ExecuteReader();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using Microsoft.Data.SqlTypes;
using Xunit;

namespace Microsoft.Data.SqlClient.ManualTesting.Tests.SQL.VectorTest
{
public sealed class VectorAPIValidationTest
{
// We need this testcase to validate ref assembly for vector APIs
// Unit tests are covered under SqlVectorTest.cs
[Fact]
public void VectorAPITest()
{
// Validate that SqlVector<float> is a valid type and has valid SqlDbType
Assert.True(typeof(SqlVector<float>).IsValueType, "SqlVector<float> should be a value type.");
Assert.Equal(36, (int)SqlDbTypeExtensions.Vector);

// Validate ctor1 with float[] : public SqlVector(System.ReadOnlyMemory<T> memory) { }
var vector = new SqlVector<float>(VectorFloat32TestData.testData);
Assert.Equal(VectorFloat32TestData.testData, vector.Memory.ToArray());
Assert.Equal(3, vector.Length);

// Validate ctor2 with ReadOnlyMemory<T> : public SqlVector(ReadOnlyMemory<T> memory) { }
vector = new SqlVector<float>(new ReadOnlyMemory<float>(VectorFloat32TestData.testData));
Assert.Equal(VectorFloat32TestData.testData, vector.Memory.ToArray());
Assert.Equal(3, vector.Length);

//Validate IsNull property
Assert.False(vector.IsNull, "IsNull should be false for non-null vector.");

// Validate Null property returns null
Assert.Null(SqlVector<float>.Null);

//Validate length property
Assert.Equal(3, vector.Length);

// Validate CreateNull method
vector = SqlVector<float>.CreateNull(5);
Assert.True(vector.IsNull);
Assert.Equal(5, vector.Length);
}
}
}
Loading