diff --git a/.devcontainer/docker-compose.yml b/.devcontainer/docker-compose.yml index ae4cd58f96..84913d6fdc 100644 --- a/.devcontainer/docker-compose.yml +++ b/.devcontainer/docker-compose.yml @@ -3,7 +3,7 @@ version: '3' services: npgsql-dev: # Source for tags: https://mcr.microsoft.com/v2/dotnet/sdk/tags/list - image: mcr.microsoft.com/dotnet/sdk:8.0.100-preview.6 + image: mcr.microsoft.com/dotnet/sdk:8.0.100-preview.7 volumes: - ..:/workspace:cached tty: true diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 065f6a2924..66e0ff5bd6 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -15,7 +15,7 @@ concurrency: cancel-in-progress: true env: - dotnet_sdk_version: '8.0.100-preview.6.23330.14' + dotnet_sdk_version: '8.0.100-rc.1.23463.5' postgis_version: 3 DOTNET_SKIP_FIRST_TIME_EXPERIENCE: true # Windows comes with PG pre-installed, and defines the PGPASSWORD environment variable. Remove it as it interferes @@ -315,7 +315,6 @@ jobs: run: | if [ -z "${{ matrix.pg_prerelease }}" ]; then dotnet test -c ${{ matrix.config }} -f ${{ matrix.test_tfm }} test/Npgsql.PluginTests --logger "GitHubActions;report-warnings=false" - dotnet test -c ${{ matrix.config }} -f ${{ matrix.test_tfm }} test/Npgsql.NodaTime.Tests --logger "GitHubActions;report-warnings=false" fi shell: bash diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index a4b0f44bbb..93c46a180c 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -32,7 +32,7 @@ concurrency: cancel-in-progress: true env: - dotnet_sdk_version: '8.0.100-preview.6.23330.14' + dotnet_sdk_version: '8.0.100-rc.1.23463.5' DOTNET_SKIP_FIRST_TIME_EXPERIENCE: true jobs: diff --git a/.github/workflows/native-aot.yml b/.github/workflows/native-aot.yml index 6b58f5f7cd..1c6ecce4ba 100644 --- a/.github/workflows/native-aot.yml +++ b/.github/workflows/native-aot.yml @@ -15,31 +15,32 @@ concurrency: 
cancel-in-progress: true env: - dotnet_sdk_version: '8.0.100-preview.6.23330.14' + dotnet_sdk_version: '8.0.100-rc.1.23463.5' DOTNET_SKIP_FIRST_TIME_EXPERIENCE: true - nuget_config: | - - - - - - - - - - - - - - - - - - - - - - + # Uncomment and edit the following to use nightly/preview builds +# nuget_config: | +# +# +# +# +# +# +# +# +# +# +# +# +# +# +# +# +# +# +# +# +# +# jobs: build: runs-on: ${{ matrix.os }} @@ -68,8 +69,8 @@ jobs: dotnet-version: | ${{ env.dotnet_sdk_version }} - - name: Setup nuget config - run: echo "$nuget_config" > NuGet.config +# - name: Setup nuget config +# run: echo "$nuget_config" > NuGet.config - name: Setup Native AOT prerequisites run: sudo apt-get install clang zlib1g-dev @@ -108,6 +109,20 @@ jobs: path: "test/Npgsql.NativeAotTests/obj/Release/net8.0/linux-x64/native/Npgsql.NativeAotTests.mstat" retention-days: 3 + - name: Upload codedgen dgml + uses: actions/upload-artifact@v3.1.2 + with: + name: npgsql.codegen.dgml.xml + path: "test/Npgsql.NativeAotTests/obj/Release/net8.0/linux-x64/native/Npgsql.NativeAotTests.codegen.dgml.xml" + retention-days: 3 + + - name: Upload scan dgml + uses: actions/upload-artifact@v3.1.2 + with: + name: npgsql.scan.dgml.xml + path: "test/Npgsql.NativeAotTests/obj/Release/net8.0/linux-x64/native/Npgsql.NativeAotTests.scan.dgml.xml" + retention-days: 3 + - name: Assert binary size run: | size="$(ls -l test/Npgsql.NativeAotTests/bin/Release/net8.0/linux-x64/native/Npgsql.NativeAotTests | cut -d ' ' -f 5)" diff --git a/.github/workflows/rich-code-nav.yml b/.github/workflows/rich-code-nav.yml index bc2db9b271..e47a8a2adb 100644 --- a/.github/workflows/rich-code-nav.yml +++ b/.github/workflows/rich-code-nav.yml @@ -9,7 +9,7 @@ on: - '*' env: - dotnet_sdk_version: '8.0.100-preview.6.23330.14' + dotnet_sdk_version: '8.0.100-rc.1.23463.5' DOTNET_SKIP_FIRST_TIME_EXPERIENCE: true jobs: diff --git a/Directory.Packages.props b/Directory.Packages.props index 41220a98ac..2648b64175 100644 --- 
a/Directory.Packages.props +++ b/Directory.Packages.props @@ -1,27 +1,40 @@ + + 8.0.0-rc.1.23419.4 + $(SystemVersion) + + - - - - + + - - - - + + + + + + + - + + + + + + + + - - + + @@ -29,7 +42,7 @@ - + @@ -39,20 +52,4 @@ - - - - - - - - - - - - - - - - diff --git a/Npgsql.sln b/Npgsql.sln index 007681d5bb..80ef02c3a8 100644 --- a/Npgsql.sln +++ b/Npgsql.sln @@ -37,8 +37,6 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Npgsql.SourceGenerators", "src\Npgsql.SourceGenerators\Npgsql.SourceGenerators.csproj", "{63026A19-60B8-4906-81CB-216F30E8094B}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Npgsql.NodaTime.Tests", "test\Npgsql.NodaTime.Tests\Npgsql.NodaTime.Tests.csproj", "{C00D2EB1-5719-4372-9E1C-5ED05DC23A00}" -EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Npgsql.OpenTelemetry", "src\Npgsql.OpenTelemetry\Npgsql.OpenTelemetry.csproj", "{DA29F063-1828-47D8-B051-800AF7C9A0BE}" EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Github", "Github", "{BA7B6F53-D24D-45AC-927A-266857EA8D1E}" @@ -144,14 +142,6 @@ Global {63026A19-60B8-4906-81CB-216F30E8094B}.Release|Any CPU.Build.0 = Release|Any CPU {63026A19-60B8-4906-81CB-216F30E8094B}.Release|x86.ActiveCfg = Release|Any CPU {63026A19-60B8-4906-81CB-216F30E8094B}.Release|x86.Build.0 = Release|Any CPU - {C00D2EB1-5719-4372-9E1C-5ED05DC23A00}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {C00D2EB1-5719-4372-9E1C-5ED05DC23A00}.Debug|Any CPU.Build.0 = Debug|Any CPU - {C00D2EB1-5719-4372-9E1C-5ED05DC23A00}.Debug|x86.ActiveCfg = Debug|Any CPU - {C00D2EB1-5719-4372-9E1C-5ED05DC23A00}.Debug|x86.Build.0 = Debug|Any CPU - {C00D2EB1-5719-4372-9E1C-5ED05DC23A00}.Release|Any CPU.ActiveCfg = Release|Any CPU - {C00D2EB1-5719-4372-9E1C-5ED05DC23A00}.Release|Any CPU.Build.0 = Release|Any CPU - {C00D2EB1-5719-4372-9E1C-5ED05DC23A00}.Release|x86.ActiveCfg = Release|Any CPU - 
{C00D2EB1-5719-4372-9E1C-5ED05DC23A00}.Release|x86.Build.0 = Release|Any CPU {DA29F063-1828-47D8-B051-800AF7C9A0BE}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {DA29F063-1828-47D8-B051-800AF7C9A0BE}.Debug|Any CPU.Build.0 = Debug|Any CPU {DA29F063-1828-47D8-B051-800AF7C9A0BE}.Debug|x86.ActiveCfg = Debug|Any CPU @@ -199,7 +189,6 @@ Global {F7C53EBD-0075-474F-A083-419257D04080} = {8537E50E-CF7F-49CB-B4EF-3E2A1B11F050} {A77E5FAF-D775-4AB4-8846-8965C2104E60} = {ED612DB1-AB32-4603-95E7-891BACA71C39} {63026A19-60B8-4906-81CB-216F30E8094B} = {8537E50E-CF7F-49CB-B4EF-3E2A1B11F050} - {C00D2EB1-5719-4372-9E1C-5ED05DC23A00} = {ED612DB1-AB32-4603-95E7-891BACA71C39} {DA29F063-1828-47D8-B051-800AF7C9A0BE} = {8537E50E-CF7F-49CB-B4EF-3E2A1B11F050} {BA7B6F53-D24D-45AC-927A-266857EA8D1E} = {004A2E0F-D34A-44D4-8DF0-D2BC63B57073} {B58E12EB-E43D-4D77-894E-5157D2269836} = {8537E50E-CF7F-49CB-B4EF-3E2A1B11F050} diff --git a/src/Npgsql.GeoJSON/Internal/CrsMap.WellKnown.cs b/src/Npgsql.GeoJSON/CrsMap.WellKnown.cs similarity index 99% rename from src/Npgsql.GeoJSON/Internal/CrsMap.WellKnown.cs rename to src/Npgsql.GeoJSON/CrsMap.WellKnown.cs index 9d733830ea..14da2f893e 100644 --- a/src/Npgsql.GeoJSON/Internal/CrsMap.WellKnown.cs +++ b/src/Npgsql.GeoJSON/CrsMap.WellKnown.cs @@ -1,6 +1,6 @@ -namespace Npgsql.GeoJSON.Internal; +namespace Npgsql.GeoJSON; -readonly partial struct CrsMap +public partial class CrsMap { /// /// These entries came from spatial_res_sys. They are used to elide memory allocations @@ -586,4 +586,4 @@ readonly partial struct CrsMap new(32766, 32766, "EPSG"), new(900913, 900913, "spatialreferencing.org"), }; -} \ No newline at end of file +} diff --git a/src/Npgsql.GeoJSON/CrsMap.cs b/src/Npgsql.GeoJSON/CrsMap.cs new file mode 100644 index 0000000000..dd556d9b33 --- /dev/null +++ b/src/Npgsql.GeoJSON/CrsMap.cs @@ -0,0 +1,59 @@ + +namespace Npgsql.GeoJSON; + +/// +/// A map of entries that map the authority to the inclusive range of SRID. 
+/// +public partial class CrsMap +{ + readonly CrsMapEntry[]? _overriden; + + internal CrsMap(CrsMapEntry[]? overriden) + => _overriden = overriden; + + internal string? GetAuthority(int srid) + => GetAuthority(_overriden, srid) ?? GetAuthority(WellKnown, srid); + + static string? GetAuthority(CrsMapEntry[]? entries, int srid) + { + if (entries == null) + return null; + + var left = 0; + var right = entries.Length; + while (left <= right) + { + var middle = left + (right - left) / 2; + var entry = entries[middle]; + + if (srid < entry.MinSrid) + right = middle - 1; + else + if (srid > entry.MaxSrid) + left = middle + 1; + else + return entry.Authority; + } + + return null; + } +} + +/// +/// An entry which maps the authority to the inclusive range of SRID. +/// +readonly struct CrsMapEntry +{ + internal readonly int MinSrid; + internal readonly int MaxSrid; + internal readonly string? Authority; + + internal CrsMapEntry(int minSrid, int maxSrid, string? authority) + { + MinSrid = minSrid; + MaxSrid = maxSrid; + Authority = authority != null + ? string.IsInterned(authority) ?? authority + : null; + } +} diff --git a/src/Npgsql.GeoJSON/CrsMapExtensions.cs b/src/Npgsql.GeoJSON/CrsMapExtensions.cs new file mode 100644 index 0000000000..329b7d9265 --- /dev/null +++ b/src/Npgsql.GeoJSON/CrsMapExtensions.cs @@ -0,0 +1,50 @@ +using System.Threading.Tasks; +using Npgsql.GeoJSON.Internal; + +namespace Npgsql.GeoJSON; + +/// +/// Extensions for getting a CrsMap from a database. +/// +public static class CrsMapExtensions +{ + /// + /// Gets the full crs details from the database. 
+ /// + /// + public static async Task GetCrsMapAsync(this NpgsqlDataSource dataSource) + { + var builder = new CrsMapBuilder(); + using var cmd = GetCsrCommand(dataSource); + await using var reader = await cmd.ExecuteReaderAsync(); + + while (await reader.ReadAsync()) + builder.Add(new CrsMapEntry(reader.GetInt32(0), reader.GetInt32(1), reader.GetString(2))); + + return builder.Build(); + } + + /// + /// Gets the full crs details from the database. + /// + /// + public static CrsMap GetCrsMap(this NpgsqlDataSource dataSource) + { + var builder = new CrsMapBuilder(); + using var cmd = GetCsrCommand(dataSource); + using var reader = cmd.ExecuteReader(); + + while (reader.Read()) + builder.Add(new CrsMapEntry(reader.GetInt32(0), reader.GetInt32(1), reader.GetString(2))); + + return builder.Build(); + } + + static NpgsqlCommand GetCsrCommand(NpgsqlDataSource dataSource) + => dataSource.CreateCommand(""" + SELECT min(srid), max(srid), auth_name + FROM(SELECT srid, auth_name, srid - rank() OVER(PARTITION BY auth_name ORDER BY srid) AS range FROM spatial_ref_sys) AS s + GROUP BY range, auth_name + ORDER BY 1; + """); +} diff --git a/src/Npgsql.GeoJSON/Internal/CrsMap.cs b/src/Npgsql.GeoJSON/Internal/CrsMapBuilder.cs similarity index 52% rename from src/Npgsql.GeoJSON/Internal/CrsMap.cs rename to src/Npgsql.GeoJSON/Internal/CrsMapBuilder.cs index aa7dc58e2d..44829761c9 100644 --- a/src/Npgsql.GeoJSON/Internal/CrsMap.cs +++ b/src/Npgsql.GeoJSON/Internal/CrsMapBuilder.cs @@ -2,25 +2,6 @@ namespace Npgsql.GeoJSON.Internal; -/// -/// An entry which maps the authority to the inclusive range of SRID. -/// -readonly struct CrsMapEntry -{ - internal readonly int MinSrid; - internal readonly int MaxSrid; - internal readonly string? Authority; - - internal CrsMapEntry(int minSrid, int maxSrid, string? authority) - { - MinSrid = minSrid; - MaxSrid = maxSrid; - Authority = authority != null - ? string.IsInterned(authority) ?? 
authority - : null; - } -} - struct CrsMapBuilder { CrsMapEntry[] _overrides; @@ -71,38 +52,3 @@ internal CrsMap Build() return new CrsMap(_overrides); } } - -readonly partial struct CrsMap -{ - readonly CrsMapEntry[]? _overriden; - - internal CrsMap(CrsMapEntry[]? overriden) - => _overriden = overriden; - - internal string? GetAuthority(int srid) - => GetAuthority(_overriden, srid) ?? GetAuthority(WellKnown, srid); - - static string? GetAuthority(CrsMapEntry[]? entries, int srid) - { - if (entries == null) - return null; - - var left = 0; - var right = entries.Length; - while (left <= right) - { - var middle = left + (right - left) / 2; - var entry = entries[middle]; - - if (srid < entry.MinSrid) - right = middle - 1; - else - if (srid > entry.MaxSrid) - left = middle + 1; - else - return entry.Authority; - } - - return null; - } -} \ No newline at end of file diff --git a/src/Npgsql.GeoJSON/Internal/GeoJSONConverter.cs b/src/Npgsql.GeoJSON/Internal/GeoJSONConverter.cs new file mode 100644 index 0000000000..2f6ece1fd8 --- /dev/null +++ b/src/Npgsql.GeoJSON/Internal/GeoJSONConverter.cs @@ -0,0 +1,748 @@ +using System; +using System.Buffers.Binary; +using System.Collections.Concurrent; +using System.Collections.ObjectModel; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using GeoJSON.Net; +using GeoJSON.Net.CoordinateReferenceSystem; +using GeoJSON.Net.Geometry; +using Npgsql.Internal; + +namespace Npgsql.GeoJSON.Internal; + +sealed class GeoJSONConverter : PgStreamingConverter where T : IGeoJSONObject +{ + readonly ConcurrentDictionary _cachedCrs = new(); + readonly GeoJSONOptions _options; + readonly Func _getCrs; + + public GeoJSONConverter(GeoJSONOptions options, CrsMap crsMap) + { + _options = options; + _getCrs = GetCrs( + crsMap, + _cachedCrs, + crsType: _options & (GeoJSONOptions.ShortCRS | GeoJSONOptions.LongCRS) + ); + } + + bool BoundingBox => (_options & GeoJSONOptions.BoundingBox) != 0; + + public 
override T Read(PgReader reader) + => (T)GeoJSONConverter.Read(async: false, reader, BoundingBox ? new BoundingBoxBuilder() : null, _getCrs, CancellationToken.None).GetAwaiter().GetResult(); + + public override async ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + => (T)await GeoJSONConverter.Read(async: true, reader, BoundingBox ? new BoundingBoxBuilder() : null, _getCrs, cancellationToken).ConfigureAwait(false); + + public override Size GetSize(SizeContext context, T value, ref object? writeState) + => GeoJSONConverter.GetSize(context, value, ref writeState); + + public override void Write(PgWriter writer, T value) + => GeoJSONConverter.Write(async: false, writer, value, CancellationToken.None).GetAwaiter().GetResult(); + + public override ValueTask WriteAsync(PgWriter writer, T value, CancellationToken cancellationToken = default) + => GeoJSONConverter.Write(async: true, writer, value, CancellationToken.None); + + static Func GetCrs(CrsMap crsMap, ConcurrentDictionary cachedCrs, GeoJSONOptions crsType) + => srid => + { + if (crsType == GeoJSONOptions.None) + return null; + +#if NETSTANDARD2_0 + return cachedCrs.GetOrAdd(srid, srid => + { + var authority = crsMap.GetAuthority(srid); + + return authority is null + ? throw new InvalidOperationException($"SRID {srid} unknown in spatial_ref_sys table") + : new NamedCRS(crsType == GeoJSONOptions.LongCRS + ? "urn:ogc:def:crs:" + authority + "::" + srid + : authority + ":" + srid); + }); +#else + return cachedCrs.GetOrAdd(srid, static (srid, state) => + { + var (crsMap, crsType) = state; + var authority = crsMap.GetAuthority(srid); + + return authority is null + ? throw new InvalidOperationException($"SRID {srid} unknown in spatial_ref_sys table") + : new NamedCRS(crsType == GeoJSONOptions.LongCRS + ? 
"urn:ogc:def:crs:" + authority + "::" + srid + : authority + ":" + srid); + }, (crsMap, crsType)); +#endif + }; +} + +static class GeoJSONConverter +{ + public static async ValueTask Read(bool async, PgReader reader, BoundingBoxBuilder? boundingBox, Func getCrs, CancellationToken cancellationToken) + { + var geometry = await Core(async, reader, boundingBox, getCrs, cancellationToken).ConfigureAwait(false); + geometry.BoundingBoxes = boundingBox?.Build(); + return geometry; + + static async ValueTask Core(bool async, PgReader reader, BoundingBoxBuilder? boundingbox, Func getCrs, CancellationToken cancellationToken) + { + if (reader.ShouldBuffer(SizeOfHeader)) + await reader.BufferData(async, SizeOfHeader, cancellationToken).ConfigureAwait(false); + + var littleEndian = reader.ReadByte() > 0; + var type = (EwkbGeometryType)ReadUInt32(littleEndian); + + GeoJSONObject geometry; + NamedCRS? crs = null; + + if (HasSrid(type)) + { + if (reader.ShouldBuffer(sizeof(int))) + await reader.BufferData(async, sizeof(int), cancellationToken).ConfigureAwait(false); + crs = getCrs(ReadInt32(littleEndian)); + } + + switch (type & EwkbGeometryType.BaseType) + { + case EwkbGeometryType.Point: + { + if (SizeOfPoint(type) is var size && reader.ShouldBuffer(size)) + await reader.BufferData(async, size, cancellationToken).ConfigureAwait(false); + + var position = ReadPosition(reader, type, littleEndian); + boundingbox?.Accumulate(position); + geometry = new Point(position); + break; + } + + case EwkbGeometryType.LineString: + { + if (reader.ShouldBuffer(SizeOfLength)) + await reader.BufferData(async, SizeOfLength, cancellationToken).ConfigureAwait(false); + + var coordinates = new Position[ReadInt32(littleEndian)]; + for (var i = 0; i < coordinates.Length; ++i) + { + if (SizeOfPoint(type) is var size && reader.ShouldBuffer(size)) + await reader.BufferData(async, size, cancellationToken).ConfigureAwait(false); + + var position = ReadPosition(reader, type, littleEndian); + 
boundingbox?.Accumulate(position); + coordinates[i] = position; + } + geometry = new LineString(coordinates); + break; + } + + case EwkbGeometryType.Polygon: + { + if (reader.ShouldBuffer(SizeOfLength)) + await reader.BufferData(async, SizeOfLength, cancellationToken).ConfigureAwait(false); + + var lines = new LineString[ReadInt32(littleEndian)]; + for (var i = 0; i < lines.Length; ++i) + { + if (reader.ShouldBuffer(SizeOfLength)) + await reader.BufferData(async, SizeOfLength, cancellationToken).ConfigureAwait(false); + + var coordinates = new Position[ReadInt32(littleEndian)]; + for (var j = 0; j < coordinates.Length; ++j) + { + if (SizeOfPoint(type) is var size && reader.ShouldBuffer(size)) + await reader.BufferData(async, size, cancellationToken).ConfigureAwait(false); + + var position = ReadPosition(reader, type, littleEndian); + boundingbox?.Accumulate(position); + coordinates[j] = position; + } + lines[i] = new LineString(coordinates); + } + geometry = new Polygon(lines); + break; + } + + case EwkbGeometryType.MultiPoint: + { + if (reader.ShouldBuffer(SizeOfLength)) + await reader.BufferData(async, SizeOfLength, cancellationToken).ConfigureAwait(false); + + var points = new Point[ReadInt32(littleEndian)]; + for (var i = 0; i < points.Length; ++i) + { + if (SizeOfHeader + SizeOfPoint(type) is var size && reader.ShouldBuffer(size)) + await reader.BufferData(async, size, cancellationToken).ConfigureAwait(false); + + if (async) + await reader.ConsumeAsync(SizeOfHeader, cancellationToken).ConfigureAwait(false); + else + reader.Consume(SizeOfHeader); + + var position = ReadPosition(reader, type, littleEndian); + boundingbox?.Accumulate(position); + points[i] = new Point(position); + } + geometry = new MultiPoint(points); + break; + } + + case EwkbGeometryType.MultiLineString: + { + if (reader.ShouldBuffer(SizeOfLength)) + await reader.BufferData(async, SizeOfLength, cancellationToken).ConfigureAwait(false); + + var lines = new LineString[ReadInt32(littleEndian)]; + 
for (var i = 0; i < lines.Length; ++i) + { + if (reader.ShouldBuffer(SizeOfHeaderWithLength)) + await reader.BufferData(async, SizeOfHeaderWithLength, cancellationToken).ConfigureAwait(false); + + if (async) + await reader.ConsumeAsync(SizeOfHeader, cancellationToken).ConfigureAwait(false); + else + reader.Consume(SizeOfHeader); + + var coordinates = new Position[ReadInt32(littleEndian)]; + for (var j = 0; j < coordinates.Length; ++j) + { + if (SizeOfPoint(type) is var size && reader.ShouldBuffer(size)) + await reader.BufferData(async, size, cancellationToken).ConfigureAwait(false); + var position = ReadPosition(reader, type, littleEndian); + boundingbox?.Accumulate(position); + coordinates[j] = position; + } + lines[i] = new LineString(coordinates); + } + geometry = new MultiLineString(lines); + break; + } + + case EwkbGeometryType.MultiPolygon: + { + if (reader.ShouldBuffer(SizeOfLength)) + await reader.BufferData(async, SizeOfLength, cancellationToken).ConfigureAwait(false); + + var polygons = new Polygon[ReadInt32(littleEndian)]; + for (var i = 0; i < polygons.Length; ++i) + { + if (reader.ShouldBuffer(SizeOfHeaderWithLength)) + await reader.BufferData(async, SizeOfHeaderWithLength, cancellationToken).ConfigureAwait(false); + + if (async) + await reader.ConsumeAsync(SizeOfHeader, cancellationToken).ConfigureAwait(false); + else + reader.Consume(SizeOfHeader); + + var lines = new LineString[ReadInt32(littleEndian)]; + for (var j = 0; j < lines.Length; ++j) + { + if (reader.ShouldBuffer(SizeOfLength)) + await reader.BufferData(async, SizeOfLength, cancellationToken).ConfigureAwait(false); + var coordinates = new Position[ReadInt32(littleEndian)]; + for (var k = 0; k < coordinates.Length; ++k) + { + if (SizeOfPoint(type) is var size && reader.ShouldBuffer(size)) + await reader.BufferData(async, size, cancellationToken).ConfigureAwait(false); + var position = ReadPosition(reader, type, littleEndian); + boundingbox?.Accumulate(position); + coordinates[k] = position; 
+ } + lines[j] = new LineString(coordinates); + } + polygons[i] = new Polygon(lines); + } + geometry = new MultiPolygon(polygons); + break; + } + + case EwkbGeometryType.GeometryCollection: + { + if (reader.ShouldBuffer(SizeOfLength)) + await reader.BufferData(async, SizeOfLength, cancellationToken).ConfigureAwait(false); + + var elements = new IGeometryObject[ReadInt32(littleEndian)]; + for (var i = 0; i < elements.Length; ++i) + elements[i] = (IGeometryObject)await Core(async, reader, boundingbox, getCrs, cancellationToken).ConfigureAwait(false); + geometry = new GeometryCollection(elements); + break; + } + + default: + throw UnknownPostGisType(); + } + + geometry.CRS = crs; + return geometry; + + int ReadInt32(bool littleEndian) + => littleEndian ? BinaryPrimitives.ReverseEndianness(reader.ReadInt32()) : reader.ReadInt32(); + uint ReadUInt32(bool littleEndian) + => littleEndian ? BinaryPrimitives.ReverseEndianness(reader.ReadUInt32()) : reader.ReadUInt32(); + } + + static Position ReadPosition(PgReader reader, EwkbGeometryType type, bool littleEndian) + { + var position = new Position( + longitude: ReadDouble(littleEndian), + latitude: ReadDouble(littleEndian), + altitude: HasZ(type) ? reader.ReadDouble() : null); + if (HasM(type)) ReadDouble(littleEndian); + return position; + + double ReadDouble(bool littleEndian) + => littleEndian + // Netstandard is missing ReverseEndianness apis for double. + ? Unsafe.As(ref Unsafe.AsRef( + BinaryPrimitives.ReverseEndianness(Unsafe.As(ref Unsafe.AsRef(reader.ReadDouble()))))) + : reader.ReadDouble(); + } + } + + public static Size GetSize(SizeContext context, IGeoJSONObject value, ref object? 
writeState) + => value.Type switch + { + GeoJSONObjectType.Point => GetSize((Point)value), + GeoJSONObjectType.LineString => GetSize((LineString)value), + GeoJSONObjectType.Polygon => GetSize((Polygon)value), + GeoJSONObjectType.MultiPoint => GetSize((MultiPoint)value), + GeoJSONObjectType.MultiLineString => GetSize((MultiLineString)value), + GeoJSONObjectType.MultiPolygon => GetSize((MultiPolygon)value), + GeoJSONObjectType.GeometryCollection => GetSize(context, (GeometryCollection)value, ref writeState), + _ => throw UnknownPostGisType() + }; + + static bool NotValid(ReadOnlyCollection coordinates, out bool hasZ) + { + if (coordinates.Count == 0) + hasZ = false; + else + { + hasZ = HasZ(coordinates[0]); + for (var i = 1; i < coordinates.Count; ++i) + if (HasZ(coordinates[i]) != hasZ) return true; + } + return false; + } + + static Size GetSize(Point value) + { + var length = Size.Create(SizeOfHeader + SizeOfPoint(HasZ(value.Coordinates))); + if (GetSrid(value.CRS) != 0) + length = length.Combine(sizeof(int)); + + return length; + } + + static Size GetSize(LineString value) + { + var coordinates = value.Coordinates; + if (NotValid(coordinates, out var hasZ)) + throw AllOrNoneCoordiantesMustHaveZ(nameof(LineString)); + + var length = Size.Create(SizeOfHeaderWithLength + coordinates.Count * SizeOfPoint(hasZ)); + if (GetSrid(value.CRS) != 0) + length = length.Combine(sizeof(int)); + + return length; + } + + static Size GetSize(Polygon value) + { + var lines = value.Coordinates; + var length = Size.Create(SizeOfHeaderWithLength + SizeOfLength * lines.Count); + if (GetSrid(value.CRS) != 0) + length = length.Combine(sizeof(int)); + + var hasZ = false; + for (var i = 0; i < lines.Count; ++i) + { + var coordinates = lines[i].Coordinates; + if (NotValid(coordinates, out var lineHasZ)) + throw AllOrNoneCoordiantesMustHaveZ(nameof(Polygon)); + + if (hasZ != lineHasZ) + { + if (i == 0) hasZ = lineHasZ; + else throw AllOrNoneCoordiantesMustHaveZ(nameof(LineString)); + } + + 
length = length.Combine(coordinates.Count * SizeOfPoint(hasZ)); + } + + return length; + } + + static Size GetSize(MultiPoint value) + { + var length = Size.Create(SizeOfHeaderWithLength); + if (GetSrid(value.CRS) != 0) + length = length.Combine(sizeof(int)); + + var coordinates = value.Coordinates; + foreach (var t in coordinates) + length = length.Combine(GetSize(t)); + + return length; + } + + static Size GetSize(MultiLineString value) + { + var length = Size.Create(SizeOfHeaderWithLength); + if (GetSrid(value.CRS) != 0) + length = length.Combine(sizeof(int)); + + var coordinates = value.Coordinates; + foreach (var t in coordinates) + length = length.Combine(GetSize(t)); + + return length; + } + + static Size GetSize(MultiPolygon value) + { + var length = Size.Create(SizeOfHeaderWithLength); + if (GetSrid(value.CRS) != 0) + length = length.Combine(sizeof(int)); + + var coordinates = value.Coordinates; + foreach (var t in coordinates) + length = length.Combine(GetSize(t)); + + return length; + } + + static Size GetSize(SizeContext context, GeometryCollection value, ref object? 
writeState) + { + var length = Size.Create(SizeOfHeaderWithLength); + if (GetSrid(value.CRS) != 0) + length = length.Combine(sizeof(int)); + + var geometries = value.Geometries; + foreach (var t in geometries) + length = length.Combine(GetSize(context, (IGeoJSONObject)t, ref writeState)); + + return length; + } + + public static ValueTask Write(bool async, PgWriter writer, IGeoJSONObject value, CancellationToken cancellationToken = default) + => value.Type switch + { + GeoJSONObjectType.Point => Write(async, writer, (Point)value, cancellationToken), + GeoJSONObjectType.LineString => Write(async, writer, (LineString)value, cancellationToken), + GeoJSONObjectType.Polygon => Write(async, writer, (Polygon)value, cancellationToken), + GeoJSONObjectType.MultiPoint => Write(async, writer, (MultiPoint)value, cancellationToken), + GeoJSONObjectType.MultiLineString => Write(async, writer, (MultiLineString)value, cancellationToken), + GeoJSONObjectType.MultiPolygon => Write(async, writer, (MultiPolygon)value, cancellationToken), + GeoJSONObjectType.GeometryCollection => Write(async, writer, (GeometryCollection)value, cancellationToken), + _ => throw UnknownPostGisType() + }; + + static async ValueTask Write(bool async, PgWriter writer, Point value, CancellationToken cancellationToken) + { + var type = EwkbGeometryType.Point; + var size = SizeOfHeader; + var srid = GetSrid(value.CRS); + if (srid != 0) + { + size += sizeof(int); + type |= EwkbGeometryType.HasSrid; + } + + if (writer.ShouldFlush(size)) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + + writer.WriteByte(0); // Most significant byte first + writer.WriteInt32((int)type); + + if (srid != 0) + writer.WriteInt32(srid); + + await WritePosition(async, writer, value.Coordinates, cancellationToken).ConfigureAwait(false); + } + + static async ValueTask Write(bool async, PgWriter writer, LineString value, CancellationToken cancellationToken) + { + var type = EwkbGeometryType.LineString; + var size = 
SizeOfHeaderWithLength; + var srid = GetSrid(value.CRS); + if (srid != 0) + { + size += sizeof(int); + type |= EwkbGeometryType.HasSrid; + } + + if (writer.ShouldFlush(size)) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + + var coordinates = value.Coordinates; + + writer.WriteByte(0); // Most significant byte first + writer.WriteInt32((int)type); + writer.WriteInt32(coordinates.Count); + + if (srid != 0) + writer.WriteInt32(srid); + + foreach (var t in coordinates) + await WritePosition(async, writer, t, cancellationToken).ConfigureAwait(false); + } + + static async ValueTask Write(bool async, PgWriter writer, Polygon value, CancellationToken cancellationToken) + { + var type = EwkbGeometryType.Polygon; + var size = SizeOfHeaderWithLength; + var srid = GetSrid(value.CRS); + if (srid != 0) + { + size += sizeof(int); + type |= EwkbGeometryType.HasSrid; + } + + if (writer.ShouldFlush(size)) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + + var lines = value.Coordinates; + + writer.WriteByte(0); // Most significant byte first + writer.WriteInt32((int)type); + writer.WriteInt32(lines.Count); + + if (srid != 0) + writer.WriteInt32(srid); + + foreach (var t in lines) + { + if (writer.ShouldFlush(SizeOfLength)) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + var coordinates = t.Coordinates; + writer.WriteInt32(coordinates.Count); + foreach (var t1 in coordinates) + await WritePosition(async, writer, t1, cancellationToken).ConfigureAwait(false); + } + } + + static async ValueTask Write(bool async, PgWriter writer, MultiPoint value, CancellationToken cancellationToken) + { + var type = EwkbGeometryType.MultiPoint; + var size = SizeOfHeaderWithLength; + var srid = GetSrid(value.CRS); + if (srid != 0) + { + size += sizeof(int); + type |= EwkbGeometryType.HasSrid; + } + + if (writer.ShouldFlush(size)) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + + var coordinates = 
value.Coordinates; + + writer.WriteByte(0); // Most significant byte first + writer.WriteInt32((int)type); + writer.WriteInt32(coordinates.Count); + + if (srid != 0) + writer.WriteInt32(srid); + + foreach (var t in coordinates) + await Write(async, writer, t, cancellationToken).ConfigureAwait(false); + } + + static async ValueTask Write(bool async, PgWriter writer, MultiLineString value, CancellationToken cancellationToken) + { + var type = EwkbGeometryType.MultiLineString; + var size = SizeOfHeaderWithLength; + var srid = GetSrid(value.CRS); + if (srid != 0) + { + size += sizeof(int); + type |= EwkbGeometryType.HasSrid; + } + + if (writer.ShouldFlush(size)) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + + var coordinates = value.Coordinates; + + writer.WriteByte(0); // Most significant byte first + writer.WriteInt32((int)type); + writer.WriteInt32(coordinates.Count); + + if (srid != 0) + writer.WriteInt32(srid); + + foreach (var t in coordinates) + await Write(async, writer, t, cancellationToken).ConfigureAwait(false); + } + + static async ValueTask Write(bool async, PgWriter writer, MultiPolygon value, CancellationToken cancellationToken) + { + var type = EwkbGeometryType.MultiPolygon; + var size = SizeOfHeaderWithLength; + var srid = GetSrid(value.CRS); + if (srid != 0) + { + size += sizeof(int); + type |= EwkbGeometryType.HasSrid; + } + + if (writer.ShouldFlush(size)) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + + var coordinates = value.Coordinates; + + writer.WriteByte(0); // Most significant byte first + writer.WriteInt32((int)type); + writer.WriteInt32(coordinates.Count); + + if (srid != 0) + writer.WriteInt32(srid); + foreach (var t in coordinates) + await Write(async, writer, t, cancellationToken).ConfigureAwait(false); + } + + static async ValueTask Write(bool async, PgWriter writer, GeometryCollection value, CancellationToken cancellationToken) + { + var type = EwkbGeometryType.GeometryCollection; + 
var size = SizeOfHeaderWithLength; + var srid = GetSrid(value.CRS); + if (srid != 0) + { + size += sizeof(int); + type |= EwkbGeometryType.HasSrid; + } + + if (writer.ShouldFlush(size)) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + + var geometries = value.Geometries; + + writer.WriteByte(0); // Most significant byte first + writer.WriteInt32((int)type); + writer.WriteInt32(geometries.Count); + + if (srid != 0) + writer.WriteInt32(srid); + + foreach (var t in geometries) + await Write(async, writer, (IGeoJSONObject)t, cancellationToken).ConfigureAwait(false); + } + + static async ValueTask WritePosition(bool async, PgWriter writer, IPosition coordinate, CancellationToken cancellationToken) + { + var altitude = coordinate.Altitude; + if (SizeOfPoint(altitude.HasValue) is var size && writer.ShouldFlush(size)) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + + writer.WriteDouble(coordinate.Longitude); + writer.WriteDouble(coordinate.Latitude); + if (altitude.HasValue) + writer.WriteDouble(altitude.Value); + } + + static ValueTask BufferData(this PgReader reader, bool async, int byteCount, CancellationToken cancellationToken) + { + if (async) + return reader.BufferAsync(byteCount, cancellationToken); + + reader.Buffer(byteCount); + return new(); + } + + static ValueTask Flush(this PgWriter writer, bool async, CancellationToken cancellationToken) + { + if (async) + return writer.FlushAsync(cancellationToken); + + writer.Flush(); + return new(); + } + + static bool HasSrid(EwkbGeometryType type) + => (type & EwkbGeometryType.HasSrid) != 0; + + static bool HasZ(EwkbGeometryType type) + => (type & EwkbGeometryType.HasZ) != 0; + + static bool HasM(EwkbGeometryType type) + => (type & EwkbGeometryType.HasM) != 0; + + static bool HasZ(IPosition coordinates) + => coordinates.Altitude.HasValue; + + const int SizeOfLength = sizeof(int); + const int SizeOfHeader = sizeof(byte) + sizeof(EwkbGeometryType); + const int 
SizeOfHeaderWithLength = SizeOfHeader + SizeOfLength; + const int SizeOfPoint2D = 2 * sizeof(double); + const int SizeOfPoint3D = 3 * sizeof(double); + + static int SizeOfPoint(bool hasZ) + => hasZ ? SizeOfPoint3D : SizeOfPoint2D; + + static int SizeOfPoint(EwkbGeometryType type) + { + var size = SizeOfPoint2D; + if (HasZ(type)) + size += sizeof(double); + if (HasM(type)) + size += sizeof(double); + return size; + } + + static Exception UnknownPostGisType() + => throw new InvalidOperationException("Invalid PostGIS type"); + + static Exception AllOrNoneCoordiantesMustHaveZ(string typeName) + => new ArgumentException($"The Z coordinate must be specified for all or none elements of {typeName}"); + + static int GetSrid(ICRSObject crs) + { + if (crs is null or UnspecifiedCRS) + return 0; + + var namedCrs = crs as NamedCRS; + if (namedCrs == null) + throw new NotSupportedException("The LinkedCRS class isn't supported"); + + if (namedCrs.Properties.TryGetValue("name", out var value) && value != null) + { + var name = value.ToString()!; + if (string.Equals(name, "urn:ogc:def:crs:OGC::CRS84", StringComparison.Ordinal)) + return 4326; + + var index = name.LastIndexOf(':'); + if (index != -1 && int.TryParse(name.Substring(index + 1), out var srid)) + return srid; + + throw new FormatException("The specified CRS isn't properly named"); + } + + return 0; + } +} + +/// +/// Represents the identifier of the Well Known Binary representation of a geographical feature specified by the OGC. 
+/// http://portal.opengeospatial.org/files/?artifact_id=13227 Chapter 6.3.2.7 +/// +[Flags] +enum EwkbGeometryType : uint +{ + // Types + Point = 1, + LineString = 2, + Polygon = 3, + MultiPoint = 4, + MultiLineString = 5, + MultiPolygon = 6, + GeometryCollection = 7, + + // Masks + BaseType = Point | LineString | Polygon | MultiPoint | MultiLineString | MultiPolygon | GeometryCollection, + + // Flags + HasSrid = 0x20000000, + HasM = 0x40000000, + HasZ = 0x80000000 +} diff --git a/src/Npgsql.GeoJSON/Internal/GeoJSONHandler.cs b/src/Npgsql.GeoJSON/Internal/GeoJSONHandler.cs deleted file mode 100644 index ba040ed79d..0000000000 --- a/src/Npgsql.GeoJSON/Internal/GeoJSONHandler.cs +++ /dev/null @@ -1,722 +0,0 @@ -using System; -using System.Collections.Concurrent; -using System.Collections.ObjectModel; -using System.Threading; -using System.Threading.Tasks; -using GeoJSON.Net; -using GeoJSON.Net.CoordinateReferenceSystem; -using GeoJSON.Net.Geometry; -using Npgsql.BackendMessages; -using Npgsql.Internal; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; - -namespace Npgsql.GeoJSON.Internal; - -sealed partial class GeoJsonHandler : NpgsqlTypeHandler, - INpgsqlTypeHandler, INpgsqlTypeHandler, - INpgsqlTypeHandler, INpgsqlTypeHandler, - INpgsqlTypeHandler, INpgsqlTypeHandler, - INpgsqlTypeHandler, - INpgsqlTypeHandler, - INpgsqlTypeHandler -{ - readonly GeoJSONOptions _options; - readonly CrsMap _crsMap; - readonly ConcurrentDictionary _cachedCrs = new(); - - internal GeoJsonHandler(PostgresType postgresType, GeoJSONOptions options, CrsMap crsMap) - : base(postgresType) - { - _options = options; - _crsMap = crsMap; - } - - GeoJSONOptions CrsType => _options & (GeoJSONOptions.ShortCRS | GeoJSONOptions.LongCRS); - - bool BoundingBox => (_options & GeoJSONOptions.BoundingBox) != 0; - - static bool HasSrid(EwkbGeometryType type) - => (type & EwkbGeometryType.HasSrid) != 0; - - static bool HasZ(EwkbGeometryType type) - => (type & EwkbGeometryType.HasZ) != 0; - 
- static bool HasM(EwkbGeometryType type) - => (type & EwkbGeometryType.HasM) != 0; - - static bool HasZ(IPosition coordinates) - => coordinates.Altitude.HasValue; - - const int SizeOfLength = sizeof(int); - const int SizeOfHeader = sizeof(byte) + sizeof(EwkbGeometryType); - const int SizeOfHeaderWithLength = SizeOfHeader + SizeOfLength; - const int SizeOfPoint2D = 2 * sizeof(double); - const int SizeOfPoint3D = 3 * sizeof(double); - - static int SizeOfPoint(bool hasZ) - => hasZ ? SizeOfPoint3D : SizeOfPoint2D; - - static int SizeOfPoint(EwkbGeometryType type) - { - var size = SizeOfPoint2D; - if (HasZ(type)) - size += sizeof(double); - if (HasM(type)) - size += sizeof(double); - return size; - } - - #region Throw - - static Exception UnknownPostGisType() - => throw new InvalidOperationException("Invalid PostGIS type"); - - static Exception AllOrNoneCoordiantesMustHaveZ(NpgsqlParameter? parameter, string typeName) - => parameter is null - ? new ArgumentException($"The Z coordinate must be specified for all or none elements of {typeName}") - : new ArgumentException($"The Z coordinate must be specified for all or none elements of {typeName} in the {parameter.ParameterName} parameter", parameter.ParameterName); - - #endregion - - #region Read - - public override ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - => ReadGeometry(buf, async); - - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => (Point)await ReadGeometry(buf, async); - - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => (LineString)await ReadGeometry(buf, async); - - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? 
fieldDescription) - => (Polygon)await ReadGeometry(buf, async); - - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => (MultiPoint)await ReadGeometry(buf, async); - - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => (MultiLineString)await ReadGeometry(buf, async); - - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => (MultiPolygon)await ReadGeometry(buf, async); - - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => (GeometryCollection)await ReadGeometry(buf, async); - - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => await ReadGeometry(buf, async); - - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => (IGeometryObject)await ReadGeometry(buf, async); - - async ValueTask ReadGeometry(NpgsqlReadBuffer buf, bool async) - { - var boundingBox = BoundingBox ? new BoundingBoxBuilder() : null; - var geometry = await ReadGeometryCore(buf, async, boundingBox); - - geometry.BoundingBoxes = boundingBox?.Build(); - return geometry; - } - - async ValueTask ReadGeometryCore(NpgsqlReadBuffer buf, bool async, BoundingBoxBuilder? boundingBox) - { - await buf.Ensure(SizeOfHeader, async); - var littleEndian = buf.ReadByte() > 0; - var type = (EwkbGeometryType)buf.ReadUInt32(littleEndian); - - GeoJSONObject geometry; - NamedCRS? 
crs = null; - - if (HasSrid(type)) - { - await buf.Ensure(4, async); - crs = GetCrs(buf.ReadInt32(littleEndian)); - } - - switch (type & EwkbGeometryType.BaseType) - { - case EwkbGeometryType.Point: - { - await buf.Ensure(SizeOfPoint(type), async); - var position = ReadPosition(buf, type, littleEndian); - boundingBox?.Accumulate(position); - geometry = new Point(position); - break; - } - - case EwkbGeometryType.LineString: - { - await buf.Ensure(SizeOfLength, async); - var coordinates = new Position[buf.ReadInt32(littleEndian)]; - for (var i = 0; i < coordinates.Length; ++i) - { - await buf.Ensure(SizeOfPoint(type), async); - var position = ReadPosition(buf, type, littleEndian); - boundingBox?.Accumulate(position); - coordinates[i] = position; - } - geometry = new LineString(coordinates); - break; - } - - case EwkbGeometryType.Polygon: - { - await buf.Ensure(SizeOfLength, async); - var lines = new LineString[buf.ReadInt32(littleEndian)]; - for (var i = 0; i < lines.Length; ++i) - { - await buf.Ensure(SizeOfLength, async); - var coordinates = new Position[buf.ReadInt32(littleEndian)]; - for (var j = 0; j < coordinates.Length; ++j) - { - await buf.Ensure(SizeOfPoint(type), async); - var position = ReadPosition(buf, type, littleEndian); - boundingBox?.Accumulate(position); - coordinates[j] = position; - } - lines[i] = new LineString(coordinates); - } - geometry = new Polygon(lines); - break; - } - - case EwkbGeometryType.MultiPoint: - { - await buf.Ensure(SizeOfLength, async); - var points = new Point[buf.ReadInt32(littleEndian)]; - for (var i = 0; i < points.Length; ++i) - { - await buf.Ensure(SizeOfHeader + SizeOfPoint(type), async); - await buf.Skip(SizeOfHeader, async); - var position = ReadPosition(buf, type, littleEndian); - boundingBox?.Accumulate(position); - points[i] = new Point(position); - } - geometry = new MultiPoint(points); - break; - } - - case EwkbGeometryType.MultiLineString: - { - await buf.Ensure(SizeOfLength, async); - var lines = new 
LineString[buf.ReadInt32(littleEndian)]; - for (var i = 0; i < lines.Length; ++i) - { - await buf.Ensure(SizeOfHeaderWithLength, async); - await buf.Skip(SizeOfHeader, async); - var coordinates = new Position[buf.ReadInt32(littleEndian)]; - for (var j = 0; j < coordinates.Length; ++j) - { - await buf.Ensure(SizeOfPoint(type), async); - var position = ReadPosition(buf, type, littleEndian); - boundingBox?.Accumulate(position); - coordinates[j] = position; - } - lines[i] = new LineString(coordinates); - } - geometry = new MultiLineString(lines); - break; - } - - case EwkbGeometryType.MultiPolygon: - { - await buf.Ensure(SizeOfLength, async); - var polygons = new Polygon[buf.ReadInt32(littleEndian)]; - for (var i = 0; i < polygons.Length; ++i) - { - await buf.Ensure(SizeOfHeaderWithLength, async); - await buf.Skip(SizeOfHeader, async); - var lines = new LineString[buf.ReadInt32(littleEndian)]; - for (var j = 0; j < lines.Length; ++j) - { - await buf.Ensure(SizeOfLength, async); - var coordinates = new Position[buf.ReadInt32(littleEndian)]; - for (var k = 0; k < coordinates.Length; ++k) - { - await buf.Ensure(SizeOfPoint(type), async); - var position = ReadPosition(buf, type, littleEndian); - boundingBox?.Accumulate(position); - coordinates[k] = position; - } - lines[j] = new LineString(coordinates); - } - polygons[i] = new Polygon(lines); - } - geometry = new MultiPolygon(polygons); - break; - } - - case EwkbGeometryType.GeometryCollection: - { - await buf.Ensure(SizeOfLength, async); - var elements = new IGeometryObject[buf.ReadInt32(littleEndian)]; - for (var i = 0; i < elements.Length; ++i) - elements[i] = (IGeometryObject)await ReadGeometryCore(buf, async, boundingBox); - geometry = new GeometryCollection(elements); - break; - } - - default: - throw UnknownPostGisType(); - } - - geometry.CRS = crs; - return geometry; - } - - static Position ReadPosition(NpgsqlReadBuffer buf, EwkbGeometryType type, bool littleEndian) - { - var position = new Position( - longitude: 
buf.ReadDouble(littleEndian), - latitude: buf.ReadDouble(littleEndian), - altitude: HasZ(type) ? buf.ReadDouble() : (double?)null); - if (HasM(type)) buf.ReadDouble(littleEndian); - return position; - } - - #endregion - - #region Write - - public override int ValidateAndGetLength(GeoJSONObject value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => value.Type switch - { - GeoJSONObjectType.Point => ValidateAndGetLength((Point)value, ref lengthCache, parameter), - GeoJSONObjectType.LineString => ValidateAndGetLength((LineString)value, ref lengthCache, parameter), - GeoJSONObjectType.Polygon => ValidateAndGetLength((Polygon)value, ref lengthCache, parameter), - GeoJSONObjectType.MultiPoint => ValidateAndGetLength((MultiPoint)value, ref lengthCache, parameter), - GeoJSONObjectType.MultiLineString => ValidateAndGetLength((MultiLineString)value, ref lengthCache, parameter), - GeoJSONObjectType.MultiPolygon => ValidateAndGetLength((MultiPolygon)value, ref lengthCache, parameter), - GeoJSONObjectType.GeometryCollection => ValidateAndGetLength((GeometryCollection)value, ref lengthCache, parameter), - _ => throw UnknownPostGisType() - }; - - public int ValidateAndGetLength(Point value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - { - var length = SizeOfHeader + SizeOfPoint(HasZ(value.Coordinates)); - if (GetSrid(value.CRS) != 0) - length += sizeof(int); - - return length; - } - - public int ValidateAndGetLength(LineString value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - { - var coordinates = value.Coordinates; - if (NotValid(coordinates, out var hasZ)) - throw AllOrNoneCoordiantesMustHaveZ(parameter, nameof(LineString)); - - var length = SizeOfHeaderWithLength + coordinates.Count * SizeOfPoint(hasZ); - if (GetSrid(value.CRS) != 0) - length += sizeof(int); - - return length; - } - - public int ValidateAndGetLength(Polygon value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter) - { - var lines = value.Coordinates; - var length = SizeOfHeaderWithLength + SizeOfLength * lines.Count; - if (GetSrid(value.CRS) != 0) - length += sizeof(int); - - var hasZ = false; - for (var i = 0; i < lines.Count; ++i) - { - var coordinates = lines[i].Coordinates; - if (NotValid(coordinates, out var lineHasZ)) - throw AllOrNoneCoordiantesMustHaveZ(parameter, nameof(Polygon)); - - if (hasZ != lineHasZ) - { - if (i == 0) hasZ = lineHasZ; - else throw AllOrNoneCoordiantesMustHaveZ(parameter, nameof(LineString)); - } - - length += coordinates.Count * SizeOfPoint(hasZ); - } - - return length; - } - - static bool NotValid(ReadOnlyCollection coordinates, out bool hasZ) - { - if (coordinates.Count == 0) - hasZ = false; - else - { - hasZ = HasZ(coordinates[0]); - for (var i = 1; i < coordinates.Count; ++i) - if (HasZ(coordinates[i]) != hasZ) return true; - } - return false; - } - - public int ValidateAndGetLength(MultiPoint value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - { - var length = SizeOfHeaderWithLength; - if (GetSrid(value.CRS) != 0) - length += sizeof(int); - - var coordinates = value.Coordinates; - for (var i = 0; i < coordinates.Count; ++i) - length += ValidateAndGetLength(coordinates[i], ref lengthCache, parameter); - - return length; - } - - public int ValidateAndGetLength(MultiLineString value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - { - var length = SizeOfHeaderWithLength; - if (GetSrid(value.CRS) != 0) - length += sizeof(int); - - var coordinates = value.Coordinates; - for (var i = 0; i < coordinates.Count; ++i) - length += ValidateAndGetLength(coordinates[i], ref lengthCache, parameter); - - return length; - } - - public int ValidateAndGetLength(MultiPolygon value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter) - { - var length = SizeOfHeaderWithLength; - if (GetSrid(value.CRS) != 0) - length += sizeof(int); - - var coordinates = value.Coordinates; - for (var i = 0; i < coordinates.Count; ++i) - length += ValidateAndGetLength(coordinates[i], ref lengthCache, parameter); - - return length; - } - - public int ValidateAndGetLength(GeometryCollection value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - { - var length = SizeOfHeaderWithLength; - if (GetSrid(value.CRS) != 0) - length += sizeof(int); - - var geometries = value.Geometries; - for (var i = 0; i < geometries.Count; ++i) - length += ValidateAndGetLength((GeoJSONObject)geometries[i], ref lengthCache, parameter); - - return length; - } - - int INpgsqlTypeHandler.ValidateAndGetLength(IGeoJSONObject value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength((GeoJSONObject)value, ref lengthCache, parameter); - - int INpgsqlTypeHandler.ValidateAndGetLength(IGeometryObject value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength((GeoJSONObject)value, ref lengthCache, parameter); - - public override Task Write(GeoJSONObject value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter, bool async, CancellationToken cancellationToken = default) - => value.Type switch - { - GeoJSONObjectType.Point => Write((Point)value, buf, lengthCache, parameter, async, cancellationToken), - GeoJSONObjectType.LineString => Write((LineString)value, buf, lengthCache, parameter, async, cancellationToken), - GeoJSONObjectType.Polygon => Write((Polygon)value, buf, lengthCache, parameter, async, cancellationToken), - GeoJSONObjectType.MultiPoint => Write((MultiPoint)value, buf, lengthCache, parameter, async, cancellationToken), - GeoJSONObjectType.MultiLineString => Write((MultiLineString)value, buf, lengthCache, parameter, async, cancellationToken), - GeoJSONObjectType.MultiPolygon => Write((MultiPolygon)value, buf, lengthCache, parameter, async, cancellationToken), - GeoJSONObjectType.GeometryCollection => Write((GeometryCollection)value, buf, lengthCache, parameter, async, cancellationToken), - _ => throw UnknownPostGisType() - }; - - public async Task Write(Point value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - var type = EwkbGeometryType.Point; - var size = SizeOfHeader; - var srid = GetSrid(value.CRS); - if (srid != 0) - { - size += sizeof(int); - type |= EwkbGeometryType.HasSrid; - } - - if (buf.WriteSpaceLeft < size) - await buf.Flush(async, cancellationToken); - - buf.WriteByte(0); // Most significant byte first - buf.WriteInt32((int)type); - - if (srid != 0) - buf.WriteInt32(srid); - - await WritePosition(value.Coordinates, buf, async, cancellationToken); - } - - public async Task Write(LineString value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter, bool async, CancellationToken cancellationToken = default) - { - var type = EwkbGeometryType.LineString; - var size = SizeOfHeaderWithLength; - var srid = GetSrid(value.CRS); - if (srid != 0) - { - size += sizeof(int); - type |= EwkbGeometryType.HasSrid; - } - - if (buf.WriteSpaceLeft < size) - await buf.Flush(async, cancellationToken); - - var coordinates = value.Coordinates; - - buf.WriteByte(0); // Most significant byte first - buf.WriteInt32((int)type); - buf.WriteInt32(coordinates.Count); - - if (srid != 0) - buf.WriteInt32(srid); - - for (var i = 0; i < coordinates.Count; ++i) - await WritePosition(coordinates[i], buf, async, cancellationToken); - } - - public async Task Write(Polygon value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - var type = EwkbGeometryType.Polygon; - var size = SizeOfHeaderWithLength; - var srid = GetSrid(value.CRS); - if (srid != 0) - { - size += sizeof(int); - type |= EwkbGeometryType.HasSrid; - } - - if (buf.WriteSpaceLeft < size) - await buf.Flush(async, cancellationToken); - - var lines = value.Coordinates; - - buf.WriteByte(0); // Most significant byte first - buf.WriteInt32((int)type); - buf.WriteInt32(lines.Count); - - if (srid != 0) - buf.WriteInt32(srid); - - for (var i = 0; i < lines.Count; ++i) - { - if (buf.WriteSpaceLeft < SizeOfLength) - await buf.Flush(async, cancellationToken); - var coordinates = lines[i].Coordinates; - buf.WriteInt32(coordinates.Count); - for (var j = 0; j < coordinates.Count; ++j) - await WritePosition(coordinates[j], buf, async, cancellationToken); - } - } - - public async Task Write(MultiPoint value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter, bool async, CancellationToken cancellationToken = default) - { - var type = EwkbGeometryType.MultiPoint; - var size = SizeOfHeaderWithLength; - var srid = GetSrid(value.CRS); - if (srid != 0) - { - size += sizeof(int); - type |= EwkbGeometryType.HasSrid; - } - - if (buf.WriteSpaceLeft < size) - await buf.Flush(async, cancellationToken); - - var coordinates = value.Coordinates; - - buf.WriteByte(0); // Most significant byte first - buf.WriteInt32((int)type); - buf.WriteInt32(coordinates.Count); - - if (srid != 0) - buf.WriteInt32(srid); - - for (var i = 0; i < coordinates.Count; ++i) - await Write(coordinates[i], buf, lengthCache, parameter, async, cancellationToken); - } - - public async Task Write(MultiLineString value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - var type = EwkbGeometryType.MultiLineString; - var size = SizeOfHeaderWithLength; - var srid = GetSrid(value.CRS); - if (srid != 0) - { - size += sizeof(int); - type |= EwkbGeometryType.HasSrid; - } - - if (buf.WriteSpaceLeft < size) - await buf.Flush(async, cancellationToken); - - var coordinates = value.Coordinates; - - buf.WriteByte(0); // Most significant byte first - buf.WriteInt32((int)type); - buf.WriteInt32(coordinates.Count); - - if (srid != 0) - buf.WriteInt32(srid); - - for (var i = 0; i < coordinates.Count; ++i) - await Write(coordinates[i], buf, lengthCache, parameter, async, cancellationToken); - } - - public async Task Write(MultiPolygon value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter, bool async, CancellationToken cancellationToken = default) - { - var type = EwkbGeometryType.MultiPolygon; - var size = SizeOfHeaderWithLength; - var srid = GetSrid(value.CRS); - if (srid != 0) - { - size += sizeof(int); - type |= EwkbGeometryType.HasSrid; - } - - if (buf.WriteSpaceLeft < size) - await buf.Flush(async, cancellationToken); - - var coordinates = value.Coordinates; - - buf.WriteByte(0); // Most significant byte first - buf.WriteInt32((int)type); - buf.WriteInt32(coordinates.Count); - - if (srid != 0) - buf.WriteInt32(srid); - for (var i = 0; i < coordinates.Count; ++i) - await Write(coordinates[i], buf, lengthCache, parameter, async, cancellationToken); - } - - public async Task Write(GeometryCollection value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - var type = EwkbGeometryType.GeometryCollection; - var size = SizeOfHeaderWithLength; - var srid = GetSrid(value.CRS); - if (srid != 0) - { - size += sizeof(int); - type |= EwkbGeometryType.HasSrid; - } - - if (buf.WriteSpaceLeft < size) - await buf.Flush(async, cancellationToken); - - var geometries = value.Geometries; - - buf.WriteByte(0); // Most significant byte first - buf.WriteInt32((int)type); - buf.WriteInt32(geometries.Count); - - if (srid != 0) - buf.WriteInt32(srid); - - for (var i = 0; i < geometries.Count; ++i) - await Write((GeoJSONObject) geometries[i], buf, lengthCache, parameter, async, cancellationToken); - } - - Task INpgsqlTypeHandler.Write(IGeoJSONObject value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken) - => Write((GeoJSONObject)value, buf, lengthCache, parameter, async, cancellationToken); - - Task INpgsqlTypeHandler.Write(IGeometryObject value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter, bool async, CancellationToken cancellationToken) - => Write((GeoJSONObject)value, buf, lengthCache, parameter, async, cancellationToken); - - static async Task WritePosition(IPosition coordinate, NpgsqlWriteBuffer buf, bool async, CancellationToken cancellationToken = default) - { - var altitude = coordinate.Altitude; - if (buf.WriteSpaceLeft < SizeOfPoint(altitude.HasValue)) - await buf.Flush(async, cancellationToken); - buf.WriteDouble(coordinate.Longitude); - buf.WriteDouble(coordinate.Latitude); - if (altitude.HasValue) - buf.WriteDouble(altitude.Value); - } - - #endregion - - #region Crs - - NamedCRS? GetCrs(int srid) - { - var crsType = CrsType; - if (crsType == GeoJSONOptions.None) - return null; - -#if NETSTANDARD2_0 - return _cachedCrs.GetOrAdd(srid, srid => - { - var authority = _crsMap.GetAuthority(srid); - - return authority is null - ? throw new InvalidOperationException($"SRID {srid} unknown in spatial_ref_sys table") - : new NamedCRS(crsType == GeoJSONOptions.LongCRS - ? "urn:ogc:def:crs:" + authority + "::" + srid - : authority + ":" + srid); - }); -#else - return _cachedCrs.GetOrAdd(srid, static (srid, me) => - { - var authority = me._crsMap.GetAuthority(srid); - - return authority is null - ? throw new InvalidOperationException($"SRID {srid} unknown in spatial_ref_sys table") - : new NamedCRS(me.CrsType == GeoJSONOptions.LongCRS - ? 
"urn:ogc:def:crs:" + authority + "::" + srid - : authority + ":" + srid); - }, this); -#endif - } - - static int GetSrid(ICRSObject crs) - { - if (crs == null || crs is UnspecifiedCRS) - return 0; - - var namedCrs = crs as NamedCRS; - if (namedCrs == null) - throw new NotSupportedException("The LinkedCRS class isn't supported"); - - if (namedCrs.Properties.TryGetValue("name", out var value) && value != null) - { - var name = value.ToString()!; - if (string.Equals(name, "urn:ogc:def:crs:OGC::CRS84", StringComparison.Ordinal)) - return 4326; - - var index = name.LastIndexOf(':'); - if (index != -1 && int.TryParse(name.Substring(index + 1), out var srid)) - return srid; - - throw new FormatException("The specified CRS isn't properly named"); - } - - return 0; - } - - #endregion -} - -/// -/// Represents the identifier of the Well Known Binary representation of a geographical feature specified by the OGC. -/// http://portal.opengeospatial.org/files/?artifact_id=13227 Chapter 6.3.2.7 -/// -[Flags] -enum EwkbGeometryType : uint -{ - // Types - Point = 1, - LineString = 2, - Polygon = 3, - MultiPoint = 4, - MultiLineString = 5, - MultiPolygon = 6, - GeometryCollection = 7, - - // Masks - BaseType = Point | LineString | Polygon | MultiPoint | MultiLineString | MultiPolygon | GeometryCollection, - - // Flags - HasSrid = 0x20000000, - HasM = 0x40000000, - HasZ = 0x80000000 -} diff --git a/src/Npgsql.GeoJSON/Internal/GeoJSONTypeHandlerResolver.cs b/src/Npgsql.GeoJSON/Internal/GeoJSONTypeHandlerResolver.cs deleted file mode 100644 index a937c1d62b..0000000000 --- a/src/Npgsql.GeoJSON/Internal/GeoJSONTypeHandlerResolver.cs +++ /dev/null @@ -1,80 +0,0 @@ -using System; -using System.Collections.Concurrent; -using System.Collections.Generic; -using System.Data; -using GeoJSON.Net; -using GeoJSON.Net.Geometry; -using Newtonsoft.Json; -using Npgsql.Internal; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -namespace 
Npgsql.GeoJSON.Internal; - -public class GeoJSONTypeHandlerResolver : TypeHandlerResolver -{ - readonly NpgsqlDatabaseInfo _databaseInfo; - readonly GeoJsonHandler? _geometryHandler, _geographyHandler; - readonly bool _geographyAsDefault; - - static readonly ConcurrentDictionary CRSMaps = new(); - - internal GeoJSONTypeHandlerResolver(NpgsqlConnector connector, GeoJSONOptions options, bool geographyAsDefault) - { - _databaseInfo = connector.DatabaseInfo; - _geographyAsDefault = geographyAsDefault; - - var crsMap = (options & (GeoJSONOptions.ShortCRS | GeoJSONOptions.LongCRS)) == GeoJSONOptions.None - ? default : CRSMaps.GetOrAdd(connector.Settings.ConnectionString, _ => - { - var builder = new CrsMapBuilder(); - using var cmd = connector.CreateCommand( - "SELECT min(srid), max(srid), auth_name " + - "FROM(SELECT srid, auth_name, srid - rank() OVER(ORDER BY srid) AS range " + - "FROM spatial_ref_sys) AS s GROUP BY range, auth_name ORDER BY 1;"); - cmd.AllResultTypesAreUnknown = true; - using var reader = cmd.ExecuteReader(); - - while (reader.Read()) - { - builder.Add(new CrsMapEntry( - int.Parse(reader.GetString(0)), - int.Parse(reader.GetString(1)), - reader.GetString(2))); - } - - return builder.Build(); - }); - - var (pgGeometryType, pgGeographyType) = (PgType("geometry"), PgType("geography")); - - if (pgGeometryType is not null) - _geometryHandler = new GeoJsonHandler(pgGeometryType, options, crsMap); - if (pgGeographyType is not null) - _geographyHandler = new GeoJsonHandler(pgGeographyType, options, crsMap); - } - - public override NpgsqlTypeHandler? ResolveByDataTypeName(string typeName) - => typeName switch - { - "geometry" => _geometryHandler, - "geography" => _geographyHandler, - _ => null - }; - - public override NpgsqlTypeHandler? ResolveByClrType(Type type) - => ClrTypeToDataTypeName(type, _geographyAsDefault) is { } dataTypeName && ResolveByDataTypeName(dataTypeName) is { } handler - ? handler - : null; - - internal static string? 
ClrTypeToDataTypeName(Type type, bool geographyAsDefault) - => type.BaseType != typeof(GeoJSONObject) - ? null - : geographyAsDefault - ? "geography" - : "geometry"; - - PostgresType? PgType(string pgTypeName) => _databaseInfo.TryGetPostgresTypeByName(pgTypeName, out var pgType) ? pgType : null; -} \ No newline at end of file diff --git a/src/Npgsql.GeoJSON/Internal/GeoJSONTypeHandlerResolverFactory.cs b/src/Npgsql.GeoJSON/Internal/GeoJSONTypeHandlerResolverFactory.cs deleted file mode 100644 index aae2c9102a..0000000000 --- a/src/Npgsql.GeoJSON/Internal/GeoJSONTypeHandlerResolverFactory.cs +++ /dev/null @@ -1,21 +0,0 @@ -using System; -using Npgsql.Internal; -using Npgsql.Internal.TypeHandling; -using Npgsql.Internal.TypeMapping; -using Npgsql.TypeMapping; - -namespace Npgsql.GeoJSON.Internal; - -public class GeoJSONTypeHandlerResolverFactory : TypeHandlerResolverFactory -{ - readonly GeoJSONOptions _options; - readonly bool _geographyAsDefault; - - public GeoJSONTypeHandlerResolverFactory(GeoJSONOptions options, bool geographyAsDefault) - => (_options, _geographyAsDefault) = (options, geographyAsDefault); - - public override TypeHandlerResolver Create(TypeMapper typeMapper, NpgsqlConnector connector) - => new GeoJSONTypeHandlerResolver(connector, _options, _geographyAsDefault); - - public override TypeMappingResolver CreateMappingResolver() => new GeoJsonTypeMappingResolver(_geographyAsDefault); -} diff --git a/src/Npgsql.GeoJSON/Internal/GeoJSONTypeInfoResolver.cs b/src/Npgsql.GeoJSON/Internal/GeoJSONTypeInfoResolver.cs new file mode 100644 index 0000000000..5ea3b8c9e3 --- /dev/null +++ b/src/Npgsql.GeoJSON/Internal/GeoJSONTypeInfoResolver.cs @@ -0,0 +1,76 @@ +using System; +using GeoJSON.Net; +using GeoJSON.Net.Geometry; +using Npgsql.Internal; +using Npgsql.Internal.Postgres; + +namespace Npgsql.GeoJSON.Internal; + +sealed class GeoJSONTypeInfoResolver : IPgTypeInfoResolver +{ + TypeInfoMappingCollection Mappings { get; } + + internal 
GeoJSONTypeInfoResolver(GeoJSONOptions options, bool geographyAsDefault, CrsMap? crsMap = null) + { + Mappings = new TypeInfoMappingCollection(); + AddInfos(Mappings, options, geographyAsDefault, crsMap); + // TODO opt-in arrays + AddArrayInfos(Mappings); + } + + public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); + + static void AddInfos(TypeInfoMappingCollection mappings, GeoJSONOptions geoJsonOptions, bool geographyAsDefault, CrsMap? crsMap) + { + crsMap ??= new CrsMap(CrsMap.WellKnown); + + var geometryMatchRequirement = !geographyAsDefault ? MatchRequirement.Single : MatchRequirement.DataTypeName; + var geographyMatchRequirement = geographyAsDefault ? MatchRequirement.Single : MatchRequirement.DataTypeName; + + foreach (var dataTypeName in new[] { "geometry", "geography" }) + { + var matchRequirement = dataTypeName == "geometry" ? geometryMatchRequirement : geographyMatchRequirement; + + mappings.AddType(dataTypeName, + (options, mapping, _) => mapping.CreateInfo(options, new GeoJSONConverter(geoJsonOptions, crsMap)), + matchRequirement); + mappings.AddType(dataTypeName, + (options, mapping, _) => mapping.CreateInfo(options, new GeoJSONConverter(geoJsonOptions, crsMap)), + matchRequirement); + mappings.AddType(dataTypeName, + (options, mapping, _) => mapping.CreateInfo(options, new GeoJSONConverter(geoJsonOptions, crsMap)), + matchRequirement); + mappings.AddType(dataTypeName, + (options, mapping, _) => mapping.CreateInfo(options, new GeoJSONConverter(geoJsonOptions, crsMap)), + matchRequirement); + mappings.AddType(dataTypeName, + (options, mapping, _) => mapping.CreateInfo(options, new GeoJSONConverter(geoJsonOptions, crsMap)), + matchRequirement); + mappings.AddType(dataTypeName, + (options, mapping, _) => mapping.CreateInfo(options, new GeoJSONConverter(geoJsonOptions, crsMap)), + matchRequirement); + mappings.AddType(dataTypeName, + (options, mapping, _) => 
mapping.CreateInfo(options, new GeoJSONConverter(geoJsonOptions, crsMap)), + matchRequirement); + mappings.AddType(dataTypeName, + (options, mapping, _) => mapping.CreateInfo(options, new GeoJSONConverter(geoJsonOptions, crsMap)), + matchRequirement); + } + } + + static void AddArrayInfos(TypeInfoMappingCollection mappings) + { + foreach (var dataTypeName in new[] { "geometry", "geography" }) + { + mappings.AddArrayType(dataTypeName); + mappings.AddArrayType(dataTypeName); + mappings.AddArrayType(dataTypeName); + mappings.AddArrayType(dataTypeName); + mappings.AddArrayType(dataTypeName); + mappings.AddArrayType(dataTypeName); + mappings.AddArrayType(dataTypeName); + mappings.AddArrayType(dataTypeName); + } + } +} diff --git a/src/Npgsql.GeoJSON/Internal/GeoJsonTypeMappingResolver.cs b/src/Npgsql.GeoJSON/Internal/GeoJsonTypeMappingResolver.cs deleted file mode 100644 index 137606538b..0000000000 --- a/src/Npgsql.GeoJSON/Internal/GeoJsonTypeMappingResolver.cs +++ /dev/null @@ -1,28 +0,0 @@ -using System; -using Npgsql.Internal.TypeHandling; -using Npgsql.Internal.TypeMapping; -using Npgsql.PostgresTypes; -using NpgsqlTypes; - -namespace Npgsql.GeoJSON.Internal; - -public class GeoJsonTypeMappingResolver : TypeMappingResolver -{ - readonly bool _geographyAsDefault; - - public GeoJsonTypeMappingResolver(bool geographyAsDefault) => _geographyAsDefault = geographyAsDefault; - - public override string? GetDataTypeNameByClrType(Type type) - => GeoJSONTypeHandlerResolver.ClrTypeToDataTypeName(type, _geographyAsDefault); - - public override TypeMappingInfo? GetMappingByDataTypeName(string dataTypeName) - => DoGetMappingByDataTypeName(dataTypeName); - - static TypeMappingInfo? 
DoGetMappingByDataTypeName(string dataTypeName) - => dataTypeName switch - { - "geometry" => new(NpgsqlDbType.Geometry, "geometry"), - "geography" => new(NpgsqlDbType.Geography, "geography"), - _ => null - }; -} diff --git a/src/Npgsql.GeoJSON/NpgsqlGeoJSONExtensions.cs b/src/Npgsql.GeoJSON/NpgsqlGeoJSONExtensions.cs index c59b0b21c7..6817094caa 100644 --- a/src/Npgsql.GeoJSON/NpgsqlGeoJSONExtensions.cs +++ b/src/Npgsql.GeoJSON/NpgsqlGeoJSONExtensions.cs @@ -1,4 +1,5 @@ -using Npgsql.GeoJSON.Internal; +using Npgsql.GeoJSON; +using Npgsql.GeoJSON.Internal; using Npgsql.TypeMapping; // ReSharper disable once CheckNamespace @@ -17,7 +18,20 @@ public static class NpgsqlGeoJSONExtensions /// Specifies that the geography type is used for mapping by default. public static INpgsqlTypeMapper UseGeoJson(this INpgsqlTypeMapper mapper, GeoJSONOptions options = GeoJSONOptions.None, bool geographyAsDefault = false) { - mapper.AddTypeResolverFactory(new GeoJSONTypeHandlerResolverFactory(options, geographyAsDefault)); + mapper.AddTypeInfoResolver(new GeoJSONTypeInfoResolver(options, geographyAsDefault, crsMap: null)); return mapper; } -} \ No newline at end of file + + /// + /// Sets up GeoJSON mappings for the PostGIS types. + /// + /// The type mapper to set up (global or connection-specific) + /// A custom crs map that might contain more or less entries than the default well-known crs map. + /// Options to use when constructing objects. + /// Specifies that the geography type is used for mapping by default. 
+ public static INpgsqlTypeMapper UseGeoJson(this INpgsqlTypeMapper mapper, CrsMap crsMap, GeoJSONOptions options = GeoJSONOptions.None, bool geographyAsDefault = false) + { + mapper.AddTypeInfoResolver(new GeoJSONTypeInfoResolver(options, geographyAsDefault, crsMap)); + return mapper; + } +} diff --git a/src/Npgsql.GeoJSON/PublicAPI.Unshipped.txt b/src/Npgsql.GeoJSON/PublicAPI.Unshipped.txt index ab058de62d..be72efeb37 100644 --- a/src/Npgsql.GeoJSON/PublicAPI.Unshipped.txt +++ b/src/Npgsql.GeoJSON/PublicAPI.Unshipped.txt @@ -1 +1,5 @@ -#nullable enable +Npgsql.GeoJSON.CrsMap +Npgsql.GeoJSON.CrsMapExtensions +static Npgsql.GeoJSON.CrsMapExtensions.GetCrsMap(this Npgsql.NpgsqlDataSource! dataSource) -> Npgsql.GeoJSON.CrsMap! +static Npgsql.GeoJSON.CrsMapExtensions.GetCrsMapAsync(this Npgsql.NpgsqlDataSource! dataSource) -> System.Threading.Tasks.Task! +static Npgsql.NpgsqlGeoJSONExtensions.UseGeoJson(this Npgsql.TypeMapping.INpgsqlTypeMapper! mapper, Npgsql.GeoJSON.CrsMap! crsMap, Npgsql.GeoJSONOptions options = Npgsql.GeoJSONOptions.None, bool geographyAsDefault = false) -> Npgsql.TypeMapping.INpgsqlTypeMapper! 
\ No newline at end of file diff --git a/src/Npgsql.Json.NET/Internal/JsonNetJsonConverter.cs b/src/Npgsql.Json.NET/Internal/JsonNetJsonConverter.cs new file mode 100644 index 0000000000..42b7c88e0d --- /dev/null +++ b/src/Npgsql.Json.NET/Internal/JsonNetJsonConverter.cs @@ -0,0 +1,121 @@ +using System; +using System.Globalization; +using System.IO; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Newtonsoft.Json; +using Npgsql.Internal; +using JsonSerializer = Newtonsoft.Json.JsonSerializer; + +namespace Npgsql.Json.NET.Internal; + +sealed class JsonNetJsonConverter : PgStreamingConverter +{ + readonly bool _jsonb; + readonly Encoding _textEncoding; + readonly JsonSerializerSettings _settings; + + public JsonNetJsonConverter(bool jsonb, Encoding textEncoding, JsonSerializerSettings settings) + { + _jsonb = jsonb; + _textEncoding = textEncoding; + _settings = settings; + } + + public override T? Read(PgReader reader) + => (T?)JsonNetJsonConverter.Read(async: false, _jsonb, reader, typeof(T), _settings, _textEncoding, CancellationToken.None).GetAwaiter().GetResult(); + public override async ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + => (T?)await JsonNetJsonConverter.Read(async: true, _jsonb, reader, typeof(T), _settings, _textEncoding, cancellationToken).ConfigureAwait(false); + + public override Size GetSize(SizeContext context, T? value, ref object? writeState) + => JsonNetJsonConverter.GetSize(_jsonb, context, typeof(T), _settings, _textEncoding, value, ref writeState); + + public override void Write(PgWriter writer, T? value) + => JsonNetJsonConverter.Write(_jsonb, async: false, writer, CancellationToken.None).GetAwaiter().GetResult(); + + public override ValueTask WriteAsync(PgWriter writer, T? value, CancellationToken cancellationToken = default) + => JsonNetJsonConverter.Write(_jsonb, async: true, writer, cancellationToken); +} + +// Split out to avoid unneccesary code duplication. 
+static class JsonNetJsonConverter +{ + public const byte JsonbProtocolVersion = 1; + + public static async ValueTask Read(bool async, bool jsonb, PgReader reader, Type type, JsonSerializerSettings settings, Encoding encoding, CancellationToken cancellationToken) + { + if (jsonb) + { + if (reader.ShouldBuffer(sizeof(byte))) + { + if (async) + await reader.BufferAsync(sizeof(byte), cancellationToken).ConfigureAwait(false); + else + reader.Buffer(sizeof(byte)); + } + var version = reader.ReadByte(); + if (version != JsonbProtocolVersion) + throw new InvalidCastException($"Unknown jsonb wire format version {version}"); + } + + using var stream = reader.GetStream(); + var mem = new MemoryStream(); + if (async) + await stream.CopyToAsync(mem, Math.Min((int)mem.Length, 81920), cancellationToken).ConfigureAwait(false); + else + stream.CopyTo(mem); + mem.Position = 0; + var jsonSerializer = JsonSerializer.CreateDefault(settings); + using var textReader = new JsonTextReader(new StreamReader(mem, encoding)); + return jsonSerializer.Deserialize(textReader, type); + } + + public static Size GetSize(bool jsonb, SizeContext context, Type type, JsonSerializerSettings settings, Encoding encoding, object? value, ref object? writeState) + { + var jsonSerializer = JsonSerializer.CreateDefault(settings); + var sb = new StringBuilder(256); + var sw = new StringWriter(sb, CultureInfo.InvariantCulture); + using (var jsonWriter = new JsonTextWriter(sw)) + { + jsonWriter.Formatting = jsonSerializer.Formatting; + + jsonSerializer.Serialize(jsonWriter, value, type); + } + + var str = sw.ToString(); + var bytes = encoding.GetBytes(str); + writeState = bytes; + return bytes.Length + (jsonb ? 
sizeof(byte) : 0); + } + + public static async ValueTask Write(bool jsonb, bool async, PgWriter writer, CancellationToken cancellationToken) + { + if (jsonb) + { + if (writer.ShouldFlush(sizeof(byte))) + { + if (async) + await writer.FlushAsync(cancellationToken).ConfigureAwait(false); + else + writer.Flush(); + } + writer.WriteByte(JsonbProtocolVersion); + } + + ArraySegment buffer; + switch (writer.Current.WriteState) + { + case byte[] bytes: + buffer = new ArraySegment(bytes); + break; + default: + throw new InvalidCastException($"Invalid state {writer.Current.WriteState?.GetType().FullName}."); + } + + if (async) + await writer.WriteBytesAsync(buffer.AsMemory(), cancellationToken).ConfigureAwait(false); + else + writer.WriteBytes(buffer.AsSpan()); + } +} diff --git a/src/Npgsql.Json.NET/Internal/JsonNetJsonHandler.cs b/src/Npgsql.Json.NET/Internal/JsonNetJsonHandler.cs deleted file mode 100644 index cbf8ca3ae2..0000000000 --- a/src/Npgsql.Json.NET/Internal/JsonNetJsonHandler.cs +++ /dev/null @@ -1,64 +0,0 @@ -using System; -using System.Diagnostics.CodeAnalysis; -using System.Threading; -using System.Threading.Tasks; -using Newtonsoft.Json; -using Npgsql.BackendMessages; -using Npgsql.Internal; -using Npgsql.Internal.TypeHandlers; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; - -namespace Npgsql.Json.NET.Internal; - -class JsonNetJsonHandler : JsonTextHandler -{ - readonly JsonSerializerSettings _settings; - - public JsonNetJsonHandler(PostgresType postgresType, NpgsqlConnector connector, bool isJsonb, JsonSerializerSettings settings) - : base(postgresType, connector.TextEncoding, isJsonb) => _settings = settings; - - protected override async ValueTask ReadCustom(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - { - if (IsSupportedAsText()) - return await base.ReadCustom(buf, len, async, fieldDescription); - - // JSON.NET returns null if no JSON content was found. 
This means null may get returned even if T is a non-nullable reference - // type (for value types, an exception will be thrown). - return JsonConvert.DeserializeObject(await base.Read(buf, len, async, fieldDescription), _settings)!; - } - - protected override int ValidateAndGetLengthCustom([DisallowNull] TAny value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - { - if (IsSupportedAsText()) - return base.ValidateAndGetLengthCustom(value, ref lengthCache, parameter); - - var serialized = JsonConvert.SerializeObject(value, _settings); - if (parameter != null) - parameter.ConvertedValue = serialized; - return base.ValidateAndGetLengthCustom(serialized, ref lengthCache, parameter); - } - - protected override Task WriteWithLengthCustom([DisallowNull] TAny value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - if (IsSupportedAsText()) - return base.WriteWithLengthCustom(value, buf, lengthCache, parameter, async, cancellationToken); - - // User POCO, read serialized representation from the validation phase - var serialized = parameter?.ConvertedValue != null - ? (string)parameter.ConvertedValue - : JsonConvert.SerializeObject(value, _settings); - return base.WriteWithLengthCustom(serialized, buf, lengthCache, parameter, async, cancellationToken); - } - - public override int ValidateObjectAndGetLength(object value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => IsSupported(value.GetType()) - ? base.ValidateObjectAndGetLength(value, ref lengthCache, parameter) - : ValidateAndGetLengthCustom(value, ref lengthCache, parameter); - - public override Task WriteObjectWithLength(object? value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, - NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => value is null or DBNull || IsSupported(value.GetType()) - ? 
base.WriteObjectWithLength(value, buf, lengthCache, parameter, async, cancellationToken) - : WriteWithLengthCustom(value, buf, lengthCache, parameter, async, cancellationToken); -} \ No newline at end of file diff --git a/src/Npgsql.Json.NET/Internal/JsonNetPocoTypeInfoResolver.cs b/src/Npgsql.Json.NET/Internal/JsonNetPocoTypeInfoResolver.cs new file mode 100644 index 0000000000..a9d54d863f --- /dev/null +++ b/src/Npgsql.Json.NET/Internal/JsonNetPocoTypeInfoResolver.cs @@ -0,0 +1,105 @@ +using System; +using System.Diagnostics.CodeAnalysis; +using System.Text; +using Newtonsoft.Json; +using Npgsql.Internal; +using Npgsql.Internal.Postgres; + +namespace Npgsql.Json.NET.Internal; + +[RequiresUnreferencedCode("Json serializer may perform reflection on trimmed types.")] +[RequiresDynamicCode("Serializing arbitary types to json can require creating new generic types or methods, which requires creating code at runtime. This may not work when AOT compiling.")] +class JsonNetPocoTypeInfoResolver : DynamicTypeInfoResolver, IPgTypeInfoResolver +{ + protected TypeInfoMappingCollection Mappings { get; } = new(); + protected JsonSerializerSettings _serializerSettings; + + const string JsonDataTypeName = "pg_catalog.json"; + const string JsonbDataTypeName = "pg_catalog.jsonb"; + + public JsonNetPocoTypeInfoResolver(Type[]? jsonbClrTypes = null, Type[]? jsonClrTypes = null, JsonSerializerSettings? serializerSettings = null) + { + // Capture default settings during construction. + _serializerSettings = serializerSettings ??= JsonConvert.DefaultSettings?.Invoke() ?? new JsonSerializerSettings(); + + AddMappings(Mappings, jsonbClrTypes ?? Array.Empty(), jsonClrTypes ?? 
Array.Empty(), serializerSettings); + } + + void AddMappings(TypeInfoMappingCollection mappings, Type[] jsonbClrTypes, Type[] jsonClrTypes, JsonSerializerSettings serializerSettings) + { + AddUserMappings(jsonb: true, jsonbClrTypes); + AddUserMappings(jsonb: false, jsonClrTypes); + + void AddUserMappings(bool jsonb, Type[] clrTypes) + { + var dynamicMappings = CreateCollection(); + var dataTypeName = jsonb ? JsonbDataTypeName : JsonDataTypeName; + foreach (var jsonType in clrTypes) + { + dynamicMappings.AddMapping(jsonType, dataTypeName, + factory: (options, mapping, _) => mapping.CreateInfo(options, + CreateConverter(mapping.Type, jsonb, options.TextEncoding, serializerSettings))); + } + mappings.AddRange(dynamicMappings.ToTypeInfoMappingCollection()); + } + } + + protected void AddArrayInfos(TypeInfoMappingCollection mappings, TypeInfoMappingCollection baseMappings) + { + if (baseMappings.Items.Count == 0) + return; + + var dynamicMappings = CreateCollection(baseMappings); + foreach (var mapping in baseMappings.Items) + dynamicMappings.AddArrayMapping(mapping.Type, mapping.DataTypeName); + mappings.AddRange(dynamicMappings.ToTypeInfoMappingCollection()); + } + + public new PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options) ?? base.GetTypeInfo(type, dataTypeName, options); + + protected override DynamicMappingCollection? GetMappings(Type? type, DataTypeName dataTypeName, PgSerializerOptions options) + { + // Match all types except null, object and text types as long as DataTypeName (json/jsonb) is present. 
+ if (type is null || type == typeof(object) || Array.IndexOf(PgSerializerOptions.WellKnownTextTypes, type) != -1 + || dataTypeName != JsonbDataTypeName && dataTypeName != JsonDataTypeName) + return null; + + return CreateCollection().AddMapping(type, dataTypeName, (options, mapping, _) => + { + var jsonb = dataTypeName == JsonbDataTypeName; + return mapping.CreateInfo(options, + CreateConverter(mapping.Type, jsonb, options.TextEncoding, _serializerSettings)); + }); + } + + static PgConverter CreateConverter(Type valueType, bool jsonb, Encoding textEncoding, JsonSerializerSettings settings) + => (PgConverter)Activator.CreateInstance( + typeof(JsonNetJsonConverter<>).MakeGenericType(valueType), + jsonb, + textEncoding, + settings + )!; +} + +[RequiresUnreferencedCode("Json serializer may perform reflection on trimmed types.")] +[RequiresDynamicCode("Serializing arbitary types to json can require creating new generic types or methods, which requires creating code at runtime. This may not work when AOT compiling.")] +sealed class JsonNetPocoArrayTypeInfoResolver : JsonNetPocoTypeInfoResolver, IPgTypeInfoResolver +{ + new TypeInfoMappingCollection Mappings { get; } + + public JsonNetPocoArrayTypeInfoResolver(Type[]? jsonbClrTypes = null, Type[]? jsonClrTypes = null, JsonSerializerSettings? serializerSettings = null) + : base(jsonbClrTypes, jsonClrTypes, serializerSettings) + { + Mappings = new TypeInfoMappingCollection(base.Mappings); + AddArrayInfos(Mappings, base.Mappings); + } + + public new PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options) ?? base.GetTypeInfo(type, dataTypeName, options); + + protected override DynamicMappingCollection? GetMappings(Type? type, DataTypeName dataTypeName, PgSerializerOptions options) + => type is not null && IsArrayLikeType(type, out var elementType) && IsArrayDataTypeName(dataTypeName, options, out var elementDataTypeName) + ? 
base.GetMappings(elementType, elementDataTypeName, options)?.AddArrayMapping(elementType, elementDataTypeName) + : null; +} diff --git a/src/Npgsql.Json.NET/Internal/JsonNetTypeHandlerResolver.cs b/src/Npgsql.Json.NET/Internal/JsonNetTypeHandlerResolver.cs deleted file mode 100644 index 04bb63bdf1..0000000000 --- a/src/Npgsql.Json.NET/Internal/JsonNetTypeHandlerResolver.cs +++ /dev/null @@ -1,48 +0,0 @@ -using System; -using System.Collections.Generic; -using Newtonsoft.Json; -using Npgsql.Internal; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; -using NpgsqlTypes; - -namespace Npgsql.Json.NET.Internal; - -public class JsonNetTypeHandlerResolver : TypeHandlerResolver -{ - readonly NpgsqlDatabaseInfo _databaseInfo; - readonly JsonNetJsonHandler _jsonNetJsonbHandler; - readonly JsonNetJsonHandler _jsonNetJsonHandler; - readonly Dictionary _dataTypeNamesByClrType; - - internal JsonNetTypeHandlerResolver( - NpgsqlConnector connector, - Dictionary dataTypeNamesByClrType, - JsonSerializerSettings settings) - { - _databaseInfo = connector.DatabaseInfo; - - _jsonNetJsonbHandler = new JsonNetJsonHandler(PgType("jsonb"), connector, isJsonb: true, settings); - _jsonNetJsonHandler = new JsonNetJsonHandler(PgType("json"), connector, isJsonb: false, settings); - - _dataTypeNamesByClrType = dataTypeNamesByClrType; - } - - public override NpgsqlTypeHandler? ResolveByDataTypeName(string typeName) - => typeName switch - { - "jsonb" => _jsonNetJsonbHandler, - "json" => _jsonNetJsonHandler, - _ => null - }; - - public override NpgsqlTypeHandler? ResolveByClrType(Type type) - => ClrTypeToDataTypeName(type, _dataTypeNamesByClrType) is { } dataTypeName && ResolveByDataTypeName(dataTypeName) is { } handler - ? handler - : null; - - internal static string? ClrTypeToDataTypeName(Type type, Dictionary clrTypes) - => clrTypes.TryGetValue(type, out var dataTypeName) ? 
dataTypeName : null; - - PostgresType PgType(string pgTypeName) => _databaseInfo.GetPostgresTypeByName(pgTypeName); -} \ No newline at end of file diff --git a/src/Npgsql.Json.NET/Internal/JsonNetTypeHandlerResolverFactory.cs b/src/Npgsql.Json.NET/Internal/JsonNetTypeHandlerResolverFactory.cs deleted file mode 100644 index 739efc6d2c..0000000000 --- a/src/Npgsql.Json.NET/Internal/JsonNetTypeHandlerResolverFactory.cs +++ /dev/null @@ -1,43 +0,0 @@ -using System; -using System.Collections.Generic; -using Newtonsoft.Json; -using Newtonsoft.Json.Linq; -using Npgsql.Internal; -using Npgsql.Internal.TypeHandling; -using Npgsql.Internal.TypeMapping; -using Npgsql.TypeMapping; - -namespace Npgsql.Json.NET.Internal; - -public class JsonNetTypeHandlerResolverFactory : TypeHandlerResolverFactory -{ - readonly JsonSerializerSettings _settings; - readonly Dictionary _byType; - - public JsonNetTypeHandlerResolverFactory( - Type[]? jsonbClrTypes, - Type[]? jsonClrTypes, - JsonSerializerSettings? settings) - { - _settings = settings ?? 
new JsonSerializerSettings(); - - _byType = new() - { - { typeof(JObject), "jsonb" }, - { typeof(JArray), "jsonb" } - }; - - if (jsonbClrTypes is not null) - foreach (var type in jsonbClrTypes) - _byType[type] = "jsonb"; - - if (jsonClrTypes is not null) - foreach (var type in jsonClrTypes) - _byType[type] = "json"; - } - - public override TypeHandlerResolver Create(TypeMapper typeMapper, NpgsqlConnector connector) - => new JsonNetTypeHandlerResolver(connector, _byType, _settings); - - public override TypeMappingResolver CreateMappingResolver() => new JsonNetTypeMappingResolver(_byType); -} diff --git a/src/Npgsql.Json.NET/Internal/JsonNetTypeInfoResolver.cs b/src/Npgsql.Json.NET/Internal/JsonNetTypeInfoResolver.cs new file mode 100644 index 0000000000..7954c4bb2f --- /dev/null +++ b/src/Npgsql.Json.NET/Internal/JsonNetTypeInfoResolver.cs @@ -0,0 +1,67 @@ +using System; +using System.Diagnostics.CodeAnalysis; +using System.Reflection; +using System.Text; +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; +using Npgsql.Internal; +using Npgsql.Internal.Postgres; + +namespace Npgsql.Json.NET.Internal; + +class JsonNetTypeInfoResolver : IPgTypeInfoResolver +{ + protected TypeInfoMappingCollection Mappings { get; } = new(); + + public JsonNetTypeInfoResolver(JsonSerializerSettings? settings = null) + => AddTypeInfos(Mappings, settings); + + static void AddTypeInfos(TypeInfoMappingCollection mappings, JsonSerializerSettings? settings = null) + { + // Capture default settings during construction. + settings ??= JsonConvert.DefaultSettings?.Invoke() ?? new JsonSerializerSettings(); + + // Jsonb is the first default for JToken etc. 
+ foreach (var dataTypeName in new[] { "jsonb", "json" }) + { + var jsonb = dataTypeName == "jsonb"; + mappings.AddType(dataTypeName, (options, mapping, _) => + mapping.CreateInfo(options, new JsonNetJsonConverter(jsonb, options.TextEncoding, settings)), + isDefault: true); + mappings.AddType(dataTypeName, (options, mapping, _) => + mapping.CreateInfo(options, new JsonNetJsonConverter(jsonb, options.TextEncoding, settings))); + mappings.AddType(dataTypeName, (options, mapping, _) => + mapping.CreateInfo(options, new JsonNetJsonConverter(jsonb, options.TextEncoding, settings))); + mappings.AddType(dataTypeName, (options, mapping, _) => + mapping.CreateInfo(options, new JsonNetJsonConverter(jsonb, options.TextEncoding, settings))); + } + } + + protected static void AddArrayInfos(TypeInfoMappingCollection mappings) + { + foreach (var dataTypeName in new[] { "jsonb", "json" }) + { + mappings.AddArrayType(dataTypeName); + mappings.AddArrayType(dataTypeName); + mappings.AddArrayType(dataTypeName); + mappings.AddArrayType(dataTypeName); + } + } + + public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); +} + +sealed class JsonNetArrayTypeInfoResolver : JsonNetTypeInfoResolver, IPgTypeInfoResolver +{ + new TypeInfoMappingCollection Mappings { get; } + + public JsonNetArrayTypeInfoResolver(JsonSerializerSettings? settings = null) : base(settings) + { + Mappings = new TypeInfoMappingCollection(base.Mappings); + AddArrayInfos(Mappings); + } + + public new PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? 
dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); +} diff --git a/src/Npgsql.Json.NET/Internal/JsonNetTypeMappingResolver.cs b/src/Npgsql.Json.NET/Internal/JsonNetTypeMappingResolver.cs deleted file mode 100644 index 119882f37e..0000000000 --- a/src/Npgsql.Json.NET/Internal/JsonNetTypeMappingResolver.cs +++ /dev/null @@ -1,29 +0,0 @@ -using System; -using System.Collections.Generic; -using Npgsql.Internal.TypeHandling; -using Npgsql.Internal.TypeMapping; -using Npgsql.PostgresTypes; -using NpgsqlTypes; - -namespace Npgsql.Json.NET.Internal; - -public class JsonNetTypeMappingResolver : TypeMappingResolver -{ - readonly Dictionary _byType; - - public JsonNetTypeMappingResolver(Dictionary byType) => _byType = byType; - - public override string? GetDataTypeNameByClrType(Type type) - => JsonNetTypeHandlerResolver.ClrTypeToDataTypeName(type, _byType); - - public override TypeMappingInfo? GetMappingByDataTypeName(string dataTypeName) - => DoGetMappingByDataTypeName(dataTypeName); - - static TypeMappingInfo? DoGetMappingByDataTypeName(string dataTypeName) - => dataTypeName switch - { - "jsonb" => new(NpgsqlDbType.Jsonb, "jsonb"), - "json" => new(NpgsqlDbType.Json, "json"), - _ => null - }; -} diff --git a/src/Npgsql.Json.NET/Npgsql.Json.NET.csproj b/src/Npgsql.Json.NET/Npgsql.Json.NET.csproj index abd1a4ea6d..baff7e6af6 100644 --- a/src/Npgsql.Json.NET/Npgsql.Json.NET.csproj +++ b/src/Npgsql.Json.NET/Npgsql.Json.NET.csproj @@ -3,8 +3,9 @@ Shay Rojansky Json.NET plugin for Npgsql, allowing transparent serialization/deserialization of JSON objects directly to and from the database. 
npgsql;postgresql;json;postgres;ado;ado.net;database;sql - netstandard2.0 + netstandard2.0;net6.0 net8.0 + enable diff --git a/src/Npgsql.Json.NET/NpgsqlJsonNetExtensions.cs b/src/Npgsql.Json.NET/NpgsqlJsonNetExtensions.cs index bd3b7b41f8..9cb70d86f1 100644 --- a/src/Npgsql.Json.NET/NpgsqlJsonNetExtensions.cs +++ b/src/Npgsql.Json.NET/NpgsqlJsonNetExtensions.cs @@ -29,7 +29,12 @@ public static INpgsqlTypeMapper UseJsonNet( Type[]? jsonbClrTypes = null, Type[]? jsonClrTypes = null) { - mapper.AddTypeResolverFactory(new JsonNetTypeHandlerResolverFactory(jsonbClrTypes, jsonClrTypes, settings)); + // TODO opt-in of arrays. + // Reverse order + mapper.AddTypeInfoResolver(new JsonNetPocoArrayTypeInfoResolver(jsonbClrTypes, jsonClrTypes, settings)); + mapper.AddTypeInfoResolver(new JsonNetArrayTypeInfoResolver(settings)); + mapper.AddTypeInfoResolver(new JsonNetPocoTypeInfoResolver(jsonbClrTypes, jsonClrTypes, settings)); + mapper.AddTypeInfoResolver(new JsonNetTypeInfoResolver(settings)); return mapper; } -} \ No newline at end of file +} diff --git a/src/Npgsql.LegacyPostgis/Properties/AssemblyInfo.cs b/src/Npgsql.LegacyPostgis/Properties/AssemblyInfo.cs deleted file mode 100644 index 1a340b1a15..0000000000 --- a/src/Npgsql.LegacyPostgis/Properties/AssemblyInfo.cs +++ /dev/null @@ -1,5 +0,0 @@ -using System.Runtime.CompilerServices; - -#if NET5_0_OR_GREATER -[module: SkipLocalsInit] -#endif diff --git a/src/Npgsql.NetTopologySuite/Internal/NetTopologySuiteConverter.cs b/src/Npgsql.NetTopologySuite/Internal/NetTopologySuiteConverter.cs new file mode 100644 index 0000000000..467356164e --- /dev/null +++ b/src/Npgsql.NetTopologySuite/Internal/NetTopologySuiteConverter.cs @@ -0,0 +1,81 @@ +using System; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using NetTopologySuite.Geometries; +using NetTopologySuite.IO; +using Npgsql.Internal; + +namespace Npgsql.NetTopologySuite.Internal; + +sealed class NetTopologySuiteConverter : PgStreamingConverter 
+ where T : Geometry +{ + readonly PostGisReader _reader; + readonly PostGisWriter _writer; + + internal NetTopologySuiteConverter(PostGisReader reader, PostGisWriter writer) + => (_reader, _writer) = (reader, writer); + + public override T Read(PgReader reader) + => (T)_reader.Read(reader.GetStream()); + + // PostGisReader/PostGisWriter doesn't support async + public override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + => new(Read(reader)); + + public override Size GetSize(SizeContext context, T value, ref object? writeState) + { + var lengthStream = new LengthStream(); + lengthStream.SetLength(0); + _writer.Write(value, lengthStream); + return (int)lengthStream.Length; + } + +#pragma warning disable CA2252 // GetStream() is a "preview" feature + public override void Write(PgWriter writer, T value) + => _writer.Write(value, writer.GetStream()); +#pragma warning restore CA2252 + + // PostGisReader/PostGisWriter doesn't support async + public override ValueTask WriteAsync(PgWriter writer, T value, CancellationToken cancellationToken = default) + { + Write(writer, value); + return default; + } + + sealed class LengthStream : Stream + { + long _length; + + public override bool CanRead => false; + + public override bool CanSeek => false; + + public override bool CanWrite => true; + + public override long Length => _length; + + public override long Position + { + get => _length; + set => throw new NotSupportedException(); + } + + public override void Flush() + { + } + + public override int Read(byte[] buffer, int offset, int count) + => throw new NotSupportedException(); + + public override long Seek(long offset, SeekOrigin origin) + => throw new NotSupportedException(); + + public override void SetLength(long value) + => _length = value; + + public override void Write(byte[] buffer, int offset, int count) + => _length += count; + } +} diff --git a/src/Npgsql.NetTopologySuite/Internal/NetTopologySuiteHandler.cs 
b/src/Npgsql.NetTopologySuite/Internal/NetTopologySuiteHandler.cs deleted file mode 100644 index f75be9f4a7..0000000000 --- a/src/Npgsql.NetTopologySuite/Internal/NetTopologySuiteHandler.cs +++ /dev/null @@ -1,168 +0,0 @@ -using System; -using System.IO; -using System.Threading; -using System.Threading.Tasks; -using NetTopologySuite.Geometries; -using NetTopologySuite.IO; -using Npgsql.BackendMessages; -using Npgsql.Internal; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; - -namespace Npgsql.NetTopologySuite.Internal; - -partial class NetTopologySuiteHandler : NpgsqlTypeHandler, - INpgsqlTypeHandler, - INpgsqlTypeHandler, - INpgsqlTypeHandler, - INpgsqlTypeHandler, - INpgsqlTypeHandler, - INpgsqlTypeHandler, - INpgsqlTypeHandler -{ - readonly PostGisReader _reader; - readonly PostGisWriter _writer; - - internal NetTopologySuiteHandler(PostgresType postgresType, PostGisReader reader, PostGisWriter writer) - : base(postgresType) - { - _reader = reader; - _writer = writer; - } - - #region Read - - public override ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - => ReadCore(buf, len); - - ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => ReadCore(buf, len); - - ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => ReadCore(buf, len); - - ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => ReadCore(buf, len); - - ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => ReadCore(buf, len); - - ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => ReadCore(buf, len); - - ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? 
fieldDescription) - => ReadCore(buf, len); - - ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => ReadCore(buf, len); - - ValueTask ReadCore(NpgsqlReadBuffer buf, int len) - where T : Geometry - => new((T)_reader.Read(buf.GetStream(len, false))); - - #endregion - - #region ValidateAndGetLength - - public override int ValidateAndGetLength(Geometry value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLengthCore(value); - - int INpgsqlTypeHandler.ValidateAndGetLength(Point value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength(value, ref lengthCache, parameter); - - int INpgsqlTypeHandler.ValidateAndGetLength(LineString value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength(value, ref lengthCache, parameter); - - int INpgsqlTypeHandler.ValidateAndGetLength(Polygon value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength(value, ref lengthCache, parameter); - - int INpgsqlTypeHandler.ValidateAndGetLength(MultiPoint value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength(value, ref lengthCache, parameter); - - int INpgsqlTypeHandler.ValidateAndGetLength(MultiLineString value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength(value, ref lengthCache, parameter); - - int INpgsqlTypeHandler.ValidateAndGetLength(MultiPolygon value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength(value, ref lengthCache, parameter); - - int INpgsqlTypeHandler.ValidateAndGetLength(GeometryCollection value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter) - => ValidateAndGetLength(value, ref lengthCache, parameter); - - int ValidateAndGetLengthCore(Geometry value) - { - var lengthStream = new LengthStream(); - lengthStream.SetLength(0); - _writer.Write(value, lengthStream); - return (int)lengthStream.Length; - } - - sealed class LengthStream : Stream - { - long _length; - - public override bool CanRead => false; - - public override bool CanSeek => false; - - public override bool CanWrite => true; - - public override long Length => _length; - - public override long Position - { - get => _length; - set => throw new NotSupportedException(); - } - - public override void Flush() - { } - - public override int Read(byte[] buffer, int offset, int count) - => throw new NotSupportedException(); - - public override long Seek(long offset, SeekOrigin origin) - => throw new NotSupportedException(); - - public override void SetLength(long value) - => _length = value; - - public override void Write(byte[] buffer, int offset, int count) - => _length += count; - } - - #endregion - - #region Write - - public override Task Write(Geometry value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => WriteCore(value, buf); - - Task INpgsqlTypeHandler.Write(Point value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken) - => WriteCore(value, buf); - - Task INpgsqlTypeHandler.Write(LineString value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken) - => WriteCore(value, buf); - - Task INpgsqlTypeHandler.Write(Polygon value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken) - => WriteCore(value, buf); - - Task INpgsqlTypeHandler.Write(MultiPoint value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? 
lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToke) - => WriteCore(value, buf); - - Task INpgsqlTypeHandler.Write(MultiLineString value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken) - => WriteCore(value, buf); - - Task INpgsqlTypeHandler.Write(MultiPolygon value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken) - => WriteCore(value, buf); - - Task INpgsqlTypeHandler.Write(GeometryCollection value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken) - => WriteCore(value, buf); - - Task WriteCore(Geometry value, NpgsqlWriteBuffer buf) - { - _writer.Write(value, buf.GetStream()); - return Task.CompletedTask; - } - - #endregion -} \ No newline at end of file diff --git a/src/Npgsql.NetTopologySuite/Internal/NetTopologySuiteTypeHandlerResolver.cs b/src/Npgsql.NetTopologySuite/Internal/NetTopologySuiteTypeHandlerResolver.cs deleted file mode 100644 index 8f270ac90f..0000000000 --- a/src/Npgsql.NetTopologySuite/Internal/NetTopologySuiteTypeHandlerResolver.cs +++ /dev/null @@ -1,55 +0,0 @@ -using System; -using System.Data; -using NetTopologySuite.Geometries; -using NetTopologySuite.IO; -using Npgsql.Internal; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -namespace Npgsql.NetTopologySuite.Internal; - -public class NetTopologySuiteTypeHandlerResolver : TypeHandlerResolver -{ - readonly NpgsqlDatabaseInfo _databaseInfo; - readonly bool _geographyAsDefault; - - readonly NetTopologySuiteHandler? 
_geometryHandler, _geographyHandler; - - internal NetTopologySuiteTypeHandlerResolver( - NpgsqlConnector connector, - CoordinateSequenceFactory coordinateSequenceFactory, - PrecisionModel precisionModel, - Ordinates handleOrdinates, - bool geographyAsDefault) - { - _databaseInfo = connector.DatabaseInfo; - _geographyAsDefault = geographyAsDefault; - - var (pgGeometryType, pgGeographyType) = (PgType("geometry"), PgType("geography")); - - var reader = new PostGisReader(coordinateSequenceFactory, precisionModel, handleOrdinates); - var writer = new PostGisWriter(); - - if (pgGeometryType is not null) - _geometryHandler = new NetTopologySuiteHandler(pgGeometryType, reader, writer); - if (pgGeographyType is not null) - _geographyHandler = new NetTopologySuiteHandler(pgGeographyType, reader, writer); - } - - public override NpgsqlTypeHandler? ResolveByDataTypeName(string typeName) - => typeName switch - { - "geometry" => _geometryHandler, - "geography" => _geographyHandler, - _ => null - }; - - public override NpgsqlTypeHandler? ResolveByClrType(Type type) - => NetTopologySuiteTypeMappingResolver.ClrTypeToDataTypeName(type, _geographyAsDefault) is { } dataTypeName && ResolveByDataTypeName(dataTypeName) is { } handler - ? handler - : null; - - PostgresType? PgType(string pgTypeName) => _databaseInfo.TryGetPostgresTypeByName(pgTypeName, out var pgType) ? 
pgType : null; -} \ No newline at end of file diff --git a/src/Npgsql.NetTopologySuite/Internal/NetTopologySuiteTypeHandlerResolverFactory.cs b/src/Npgsql.NetTopologySuite/Internal/NetTopologySuiteTypeHandlerResolverFactory.cs deleted file mode 100644 index 1aed03a058..0000000000 --- a/src/Npgsql.NetTopologySuite/Internal/NetTopologySuiteTypeHandlerResolverFactory.cs +++ /dev/null @@ -1,33 +0,0 @@ -using NetTopologySuite; -using NetTopologySuite.Geometries; -using Npgsql.Internal; -using Npgsql.Internal.TypeHandling; -using Npgsql.Internal.TypeMapping; - -namespace Npgsql.NetTopologySuite.Internal; - -public class NetTopologySuiteTypeHandlerResolverFactory : TypeHandlerResolverFactory -{ - readonly CoordinateSequenceFactory _coordinateSequenceFactory; - readonly PrecisionModel _precisionModel; - readonly Ordinates _handleOrdinates; - readonly bool _geographyAsDefault; - - public NetTopologySuiteTypeHandlerResolverFactory( - CoordinateSequenceFactory? coordinateSequenceFactory, - PrecisionModel? precisionModel, - Ordinates handleOrdinates, - bool geographyAsDefault) - { - _coordinateSequenceFactory = coordinateSequenceFactory ?? NtsGeometryServices.Instance.DefaultCoordinateSequenceFactory;; - _precisionModel = precisionModel ?? NtsGeometryServices.Instance.DefaultPrecisionModel; - _handleOrdinates = handleOrdinates == Ordinates.None ? 
_coordinateSequenceFactory.Ordinates : handleOrdinates; - _geographyAsDefault = geographyAsDefault; - } - - public override TypeHandlerResolver Create(TypeMapper typeMapper, NpgsqlConnector connector) - => new NetTopologySuiteTypeHandlerResolver(connector, _coordinateSequenceFactory, _precisionModel, _handleOrdinates, - _geographyAsDefault); - - public override TypeMappingResolver CreateMappingResolver() => new NetTopologySuiteTypeMappingResolver(_geographyAsDefault); -} diff --git a/src/Npgsql.NetTopologySuite/Internal/NetTopologySuiteTypeInfoResolver.cs b/src/Npgsql.NetTopologySuite/Internal/NetTopologySuiteTypeInfoResolver.cs new file mode 100644 index 0000000000..a934bd2f86 --- /dev/null +++ b/src/Npgsql.NetTopologySuite/Internal/NetTopologySuiteTypeInfoResolver.cs @@ -0,0 +1,115 @@ +using System; +using NetTopologySuite; +using NetTopologySuite.Geometries; +using NetTopologySuite.IO; +using Npgsql.Internal; +using Npgsql.Internal.Postgres; + +namespace Npgsql.NetTopologySuite.Internal; + +sealed class NetTopologySuiteTypeInfoResolver : IPgTypeInfoResolver +{ + TypeInfoMappingCollection Mappings { get; } + + public NetTopologySuiteTypeInfoResolver( + CoordinateSequenceFactory? coordinateSequenceFactory, + PrecisionModel? precisionModel, + Ordinates handleOrdinates, + bool geographyAsDefault) + { + coordinateSequenceFactory ??= NtsGeometryServices.Instance.DefaultCoordinateSequenceFactory; + precisionModel ??= NtsGeometryServices.Instance.DefaultPrecisionModel; + handleOrdinates = handleOrdinates == Ordinates.None ? coordinateSequenceFactory.Ordinates : handleOrdinates; + + var reader = new PostGisReader(coordinateSequenceFactory, precisionModel, handleOrdinates); + var writer = new PostGisWriter(); + + Mappings = new TypeInfoMappingCollection(); + AddInfos(Mappings, reader, writer, geographyAsDefault); + // TODO: Opt-in only + AddArrayInfos(Mappings); + } + + public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? 
dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); + + static void AddInfos(TypeInfoMappingCollection mappings, PostGisReader reader, PostGisWriter writer, bool geographyAsDefault) + { + // geometry + mappings.AddType("geometry", + (options, mapping, _) => mapping.CreateInfo(options, new NetTopologySuiteConverter(reader, writer)), + matchRequirement: !geographyAsDefault ? MatchRequirement.Single : MatchRequirement.DataTypeName); + + mappings.AddType("geometry", + (options, mapping, _) => mapping.CreateInfo(options, new NetTopologySuiteConverter(reader, writer)), + matchRequirement: !geographyAsDefault ? MatchRequirement.All : MatchRequirement.DataTypeName); + mappings.AddType("geometry", + (options, mapping, _) => mapping.CreateInfo(options, new NetTopologySuiteConverter(reader, writer)), + matchRequirement: !geographyAsDefault ? MatchRequirement.All : MatchRequirement.DataTypeName); + mappings.AddType("geometry", + (options, mapping, _) => mapping.CreateInfo(options, new NetTopologySuiteConverter(reader, writer)), + matchRequirement: !geographyAsDefault ? MatchRequirement.All : MatchRequirement.DataTypeName); + mappings.AddType("geometry", + (options, mapping, _) => mapping.CreateInfo(options, new NetTopologySuiteConverter(reader, writer)), + matchRequirement: !geographyAsDefault ? MatchRequirement.All : MatchRequirement.DataTypeName); + mappings.AddType("geometry", + (options, mapping, _) => mapping.CreateInfo(options, new NetTopologySuiteConverter(reader, writer)), + matchRequirement: !geographyAsDefault ? MatchRequirement.All : MatchRequirement.DataTypeName); + mappings.AddType("geometry", + (options, mapping, _) => mapping.CreateInfo(options, new NetTopologySuiteConverter(reader, writer)), + matchRequirement: !geographyAsDefault ? 
MatchRequirement.All : MatchRequirement.DataTypeName); + mappings.AddType("geometry", + (options, mapping, _) => mapping.CreateInfo(options, new NetTopologySuiteConverter(reader, writer)), + matchRequirement: !geographyAsDefault ? MatchRequirement.All : MatchRequirement.DataTypeName); + + // geography + mappings.AddType("geography", + (options, mapping, _) => mapping.CreateInfo(options, new NetTopologySuiteConverter(reader, writer)), + matchRequirement: geographyAsDefault ? MatchRequirement.Single : MatchRequirement.DataTypeName); + + mappings.AddType("geography", + (options, mapping, _) => mapping.CreateInfo(options, new NetTopologySuiteConverter(reader, writer)), + matchRequirement: geographyAsDefault ? MatchRequirement.All : MatchRequirement.DataTypeName); + mappings.AddType("geography", + (options, mapping, _) => mapping.CreateInfo(options, new NetTopologySuiteConverter(reader, writer)), + matchRequirement: geographyAsDefault ? MatchRequirement.All : MatchRequirement.DataTypeName); + mappings.AddType("geography", + (options, mapping, _) => mapping.CreateInfo(options, new NetTopologySuiteConverter(reader, writer)), + matchRequirement: geographyAsDefault ? MatchRequirement.All : MatchRequirement.DataTypeName); + mappings.AddType("geography", + (options, mapping, _) => mapping.CreateInfo(options, new NetTopologySuiteConverter(reader, writer)), + matchRequirement: geographyAsDefault ? MatchRequirement.All : MatchRequirement.DataTypeName); + mappings.AddType("geography", + (options, mapping, _) => mapping.CreateInfo(options, new NetTopologySuiteConverter(reader, writer)), + matchRequirement: geographyAsDefault ? MatchRequirement.All : MatchRequirement.DataTypeName); + mappings.AddType("geography", + (options, mapping, _) => mapping.CreateInfo(options, new NetTopologySuiteConverter(reader, writer)), + matchRequirement: geographyAsDefault ? 
MatchRequirement.All : MatchRequirement.DataTypeName); + mappings.AddType("geography", + (options, mapping, _) => mapping.CreateInfo(options, new NetTopologySuiteConverter(reader, writer)), + matchRequirement: geographyAsDefault ? MatchRequirement.All : MatchRequirement.DataTypeName); + } + + static void AddArrayInfos(TypeInfoMappingCollection mappings) + { + // geometry + mappings.AddArrayType("geometry"); + mappings.AddArrayType("geometry"); + mappings.AddArrayType("geometry"); + mappings.AddArrayType("geometry"); + mappings.AddArrayType("geometry"); + mappings.AddArrayType("geometry"); + mappings.AddArrayType("geometry"); + mappings.AddArrayType("geometry"); + + // geography + mappings.AddArrayType("geography"); + mappings.AddArrayType("geography"); + mappings.AddArrayType("geography"); + mappings.AddArrayType("geography"); + mappings.AddArrayType("geography"); + mappings.AddArrayType("geography"); + mappings.AddArrayType("geography"); + mappings.AddArrayType("geography"); + } +} diff --git a/src/Npgsql.NetTopologySuite/Internal/NetTopologySuiteTypeMappingResolver.cs b/src/Npgsql.NetTopologySuite/Internal/NetTopologySuiteTypeMappingResolver.cs deleted file mode 100644 index f087d6c55e..0000000000 --- a/src/Npgsql.NetTopologySuite/Internal/NetTopologySuiteTypeMappingResolver.cs +++ /dev/null @@ -1,36 +0,0 @@ -using System; -using NetTopologySuite.Geometries; -using Npgsql.Internal.TypeHandling; -using Npgsql.Internal.TypeMapping; -using Npgsql.PostgresTypes; -using NpgsqlTypes; - -namespace Npgsql.NetTopologySuite.Internal; - -public class NetTopologySuiteTypeMappingResolver : TypeMappingResolver -{ - readonly bool _geographyAsDefault; - - public NetTopologySuiteTypeMappingResolver(bool geographyAsDefault) => _geographyAsDefault = geographyAsDefault; - - public override string? GetDataTypeNameByClrType(Type type) - => ClrTypeToDataTypeName(type, _geographyAsDefault); - - public override TypeMappingInfo? 
GetMappingByDataTypeName(string dataTypeName) - => DoGetMappingByDataTypeName(dataTypeName); - - internal static string? ClrTypeToDataTypeName(Type type, bool geographyAsDefault) - => type != typeof(Geometry) && type.BaseType != typeof(Geometry) && type.BaseType != typeof(GeometryCollection) - ? null - : geographyAsDefault - ? "geography" - : "geometry"; - - static TypeMappingInfo? DoGetMappingByDataTypeName(string dataTypeName) - => dataTypeName switch - { - "geometry" => new(NpgsqlDbType.Geometry, "geometry"), - "geography" => new(NpgsqlDbType.Geography, "geography"), - _ => null - }; -} diff --git a/src/Npgsql.NetTopologySuite/Npgsql.NetTopologySuite.csproj b/src/Npgsql.NetTopologySuite/Npgsql.NetTopologySuite.csproj index e09653ac97..c36aec8652 100644 --- a/src/Npgsql.NetTopologySuite/Npgsql.NetTopologySuite.csproj +++ b/src/Npgsql.NetTopologySuite/Npgsql.NetTopologySuite.csproj @@ -24,6 +24,6 @@ - + diff --git a/src/Npgsql.NetTopologySuite/NpgsqlNetTopologySuiteExtensions.cs b/src/Npgsql.NetTopologySuite/NpgsqlNetTopologySuiteExtensions.cs index a867fea349..1408709236 100644 --- a/src/Npgsql.NetTopologySuite/NpgsqlNetTopologySuiteExtensions.cs +++ b/src/Npgsql.NetTopologySuite/NpgsqlNetTopologySuiteExtensions.cs @@ -27,9 +27,7 @@ public static INpgsqlTypeMapper UseNetTopologySuite( Ordinates handleOrdinates = Ordinates.None, bool geographyAsDefault = false) { - mapper.AddTypeResolverFactory( - new NetTopologySuiteTypeHandlerResolverFactory( - coordinateSequenceFactory, precisionModel, handleOrdinates, geographyAsDefault)); + mapper.AddTypeInfoResolver(new NetTopologySuiteTypeInfoResolver(coordinateSequenceFactory, precisionModel, handleOrdinates, geographyAsDefault)); return mapper; } -} \ No newline at end of file +} diff --git a/src/Npgsql.NodaTime/Internal/DateHandler.cs b/src/Npgsql.NodaTime/Internal/DateHandler.cs deleted file mode 100644 index 9ae07b040a..0000000000 --- a/src/Npgsql.NodaTime/Internal/DateHandler.cs +++ /dev/null @@ -1,91 +0,0 @@ -using 
System; -using NodaTime; -using Npgsql.BackendMessages; -using Npgsql.Internal; -using Npgsql.Internal.TypeHandling; -using Npgsql.NodaTime.Properties; -using Npgsql.PostgresTypes; -using static Npgsql.NodaTime.Internal.NodaTimeUtils; -using BclDateHandler = Npgsql.Internal.TypeHandlers.DateTimeHandlers.DateHandler; - -namespace Npgsql.NodaTime.Internal; - -sealed partial class DateHandler : NpgsqlSimpleTypeHandler, - INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler -#if NET6_0_OR_GREATER - , INpgsqlSimpleTypeHandler -#endif -{ - readonly BclDateHandler _bclHandler; - - internal DateHandler(PostgresType postgresType) - : base(postgresType) - => _bclHandler = new BclDateHandler(postgresType); - - public override LocalDate Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => buf.ReadInt32() switch - { - int.MaxValue => DisableDateTimeInfinityConversions - ? throw new InvalidCastException(NpgsqlNodaTimeStrings.CannotReadInfinityValue) - : LocalDate.MaxIsoValue, - int.MinValue => DisableDateTimeInfinityConversions - ? throw new InvalidCastException(NpgsqlNodaTimeStrings.CannotReadInfinityValue) - : LocalDate.MinIsoValue, - var value => new LocalDate().PlusDays(value + 730119) - }; - - public override int ValidateAndGetLength(LocalDate value, NpgsqlParameter? parameter) - => 4; - - public override void Write(LocalDate value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - { - if (!DisableDateTimeInfinityConversions) - { - if (value == LocalDate.MaxIsoValue) - { - buf.WriteInt32(int.MaxValue); - return; - } - if (value == LocalDate.MinIsoValue) - { - buf.WriteInt32(int.MinValue); - return; - } - } - - var totalDaysSinceEra = Period.Between(default, value, PeriodUnits.Days).Days; - buf.WriteInt32(totalDaysSinceEra - 730119); - } - - DateTime INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? 
fieldDescription) - => _bclHandler.Read(buf, len, fieldDescription); - - int INpgsqlSimpleTypeHandler.ValidateAndGetLength(DateTime value, NpgsqlParameter? parameter) - => _bclHandler.ValidateAndGetLength(value, parameter); - - void INpgsqlSimpleTypeHandler.Write(DateTime value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => _bclHandler.Write(value, buf, parameter); - - int INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => _bclHandler.Read(buf, len, fieldDescription); - - int INpgsqlSimpleTypeHandler.ValidateAndGetLength(int value, NpgsqlParameter? parameter) - => _bclHandler.ValidateAndGetLength(value, parameter); - - void INpgsqlSimpleTypeHandler.Write(int value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => _bclHandler.Write(value, buf, parameter); - -#if NET6_0_OR_GREATER - DateOnly INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => _bclHandler.Read(buf, len, fieldDescription); - - public int ValidateAndGetLength(DateOnly value, NpgsqlParameter? parameter) - => _bclHandler.ValidateAndGetLength(value, parameter); - - public void Write(DateOnly value, NpgsqlWriteBuffer buf, NpgsqlParameter? 
parameter) - => _bclHandler.Write(value, buf, parameter); -#endif - - public override NpgsqlTypeHandler CreateRangeHandler(PostgresType pgRangeType) - => new DateRangeHandler(pgRangeType, this); -} \ No newline at end of file diff --git a/src/Npgsql.NodaTime/Internal/DateIntervalConverter.cs b/src/Npgsql.NodaTime/Internal/DateIntervalConverter.cs new file mode 100644 index 0000000000..5e25d8bfcc --- /dev/null +++ b/src/Npgsql.NodaTime/Internal/DateIntervalConverter.cs @@ -0,0 +1,49 @@ +using System.Threading; +using System.Threading.Tasks; +using NodaTime; +using Npgsql.Internal; +using NpgsqlTypes; + +namespace Npgsql.NodaTime.Internal; + +public class DateIntervalConverter : PgStreamingConverter +{ + readonly bool _dateTimeInfinityConversions; + readonly PgConverter> _rangeConverter; + + public DateIntervalConverter(PgConverter> rangeConverter, bool dateTimeInfinityConversions) + { + _rangeConverter = rangeConverter; + _dateTimeInfinityConversions = dateTimeInfinityConversions; + } + + public override DateInterval Read(PgReader reader) + => Read(async: false, reader, CancellationToken.None).GetAwaiter().GetResult(); + + public override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + => Read(async: true, reader, cancellationToken); + + async ValueTask Read(bool async, PgReader reader, CancellationToken cancellationToken) + { + var range = async + ? await _rangeConverter.ReadAsync(reader, cancellationToken).ConfigureAwait(false) + // ReSharper disable once MethodHasAsyncOverloadWithCancellation + : _rangeConverter.Read(reader); + + var upperBound = range.UpperBound; + + if (upperBound != LocalDate.MaxIsoValue || !_dateTimeInfinityConversions) + upperBound -= Period.FromDays(1); + + return new(range.LowerBound, upperBound); + } + + public override Size GetSize(SizeContext context, DateInterval value, ref object? 
writeState) + => _rangeConverter.GetSize(context, new NpgsqlRange(value.Start, value.End), ref writeState); + + public override void Write(PgWriter writer, DateInterval value) + => _rangeConverter.Write(writer, new NpgsqlRange(value.Start, value.End)); + + public override ValueTask WriteAsync(PgWriter writer, DateInterval value, CancellationToken cancellationToken = default) + => _rangeConverter.WriteAsync(writer, new NpgsqlRange(value.Start, value.End), cancellationToken); +} diff --git a/src/Npgsql.NodaTime/Internal/DateMultirangeHandler.cs b/src/Npgsql.NodaTime/Internal/DateMultirangeHandler.cs deleted file mode 100644 index 167b8eb310..0000000000 --- a/src/Npgsql.NodaTime/Internal/DateMultirangeHandler.cs +++ /dev/null @@ -1,120 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Threading; -using System.Threading.Tasks; -using NodaTime; -using Npgsql.BackendMessages; -using Npgsql.Internal; -using Npgsql.Internal.TypeHandlers; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; -using NpgsqlTypes; - -namespace Npgsql.NodaTime.Internal; - -public partial class DateMultirangeHandler : MultirangeHandler, - INpgsqlTypeHandler, INpgsqlTypeHandler> -{ - readonly INpgsqlTypeHandler _dateIntervalHandler; - - public DateMultirangeHandler(PostgresMultirangeType multirangePostgresType, DateRangeHandler rangeHandler) - : base(multirangePostgresType, rangeHandler) - => _dateIntervalHandler = rangeHandler; - - public override Type GetFieldType(FieldDescription? fieldDescription = null) => typeof(DateInterval[]); - - public override async ValueTask ReadAsObject(NpgsqlReadBuffer buf, int len, bool async, - FieldDescription? fieldDescription = null) - => (await Read(buf, len, async, fieldDescription))!; - - async ValueTask INpgsqlTypeHandler.Read( - NpgsqlReadBuffer buf, int len, bool async, FieldDescription? 
fieldDescription) - { - await buf.Ensure(4, async); - var numRanges = buf.ReadInt32(); - var multirange = new DateInterval[numRanges]; - - for (var i = 0; i < multirange.Length; i++) - { - await buf.Ensure(4, async); - var rangeLen = buf.ReadInt32(); - multirange[i] = await _dateIntervalHandler.Read(buf, rangeLen, async, fieldDescription); - } - - return multirange; - } - - async ValueTask> INpgsqlTypeHandler>.Read( - NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - { - await buf.Ensure(4, async); - var numRanges = buf.ReadInt32(); - var multirange = new List(numRanges); - - for (var i = 0; i < numRanges; i++) - { - await buf.Ensure(4, async); - var rangeLen = buf.ReadInt32(); - multirange.Add(await _dateIntervalHandler.Read(buf, rangeLen, async, fieldDescription)); - } - - return multirange; - } - - public int ValidateAndGetLength(DateInterval[] value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLengthCore(value, ref lengthCache); - - public int ValidateAndGetLength(List value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLengthCore(value, ref lengthCache); - - int ValidateAndGetLengthCore(IList value, ref NpgsqlLengthCache? lengthCache) - { - lengthCache ??= new NpgsqlLengthCache(1); - if (lengthCache.IsPopulated) - return lengthCache.Get(); - - var sum = 4 + 4 * value.Count; - for (var i = 0; i < value.Count; i++) - sum += _dateIntervalHandler.ValidateAndGetLength(value[i], ref lengthCache, parameter: null); - - return lengthCache!.Set(sum); - } - - public async Task Write( - DateInterval[] value, - NpgsqlWriteBuffer buf, - NpgsqlLengthCache? lengthCache, - NpgsqlParameter? 
parameter, - bool async, - CancellationToken cancellationToken = default) - { - if (buf.WriteSpaceLeft < 4) - await buf.Flush(async, cancellationToken); - - buf.WriteInt32(value.Length); - - for (var i = 0; i < value.Length; i++) - await RangeHandler.WriteWithLength(value[i], buf, lengthCache, parameter: null, async, cancellationToken); - } - - public async Task Write( - List value, - NpgsqlWriteBuffer buf, - NpgsqlLengthCache? lengthCache, - NpgsqlParameter? parameter, - bool async, - CancellationToken cancellationToken = default) - { - if (buf.WriteSpaceLeft < 4) - await buf.Flush(async, cancellationToken); - - buf.WriteInt32(value.Count); - - for (var i = 0; i < value.Count; i++) - { - var interval = value[i]; - await RangeHandler.WriteWithLength( - new NpgsqlRange(interval.Start, interval.End), buf, lengthCache, parameter: null, async, cancellationToken); - } - } -} \ No newline at end of file diff --git a/src/Npgsql.NodaTime/Internal/DateRangeHandler.cs b/src/Npgsql.NodaTime/Internal/DateRangeHandler.cs deleted file mode 100644 index 601a0cfb45..0000000000 --- a/src/Npgsql.NodaTime/Internal/DateRangeHandler.cs +++ /dev/null @@ -1,69 +0,0 @@ -using System; -using System.Threading; -using System.Threading.Tasks; -using NodaTime; -using Npgsql.BackendMessages; -using Npgsql.Internal; -using Npgsql.Internal.TypeHandlers; -using Npgsql.Internal.TypeHandling; -using Npgsql.NodaTime.Properties; -using Npgsql.PostgresTypes; -using NpgsqlTypes; -using static Npgsql.NodaTime.Internal.NodaTimeUtils; - -namespace Npgsql.NodaTime.Internal; - -public partial class DateRangeHandler : RangeHandler, INpgsqlTypeHandler -#if NET6_0_OR_GREATER - , INpgsqlTypeHandler> -#endif -{ - public DateRangeHandler(PostgresType rangePostgresType, NpgsqlTypeHandler subtypeHandler) - : base(rangePostgresType, subtypeHandler) - { - } - - public override Type GetFieldType(FieldDescription? 
fieldDescription = null) => typeof(DateInterval); - - public override async ValueTask ReadAsObject(NpgsqlReadBuffer buf, int len, bool async, - FieldDescription? fieldDescription = null) - => (await Read(buf, len, async, fieldDescription))!; - - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - { - var range = await Read(buf, len, async, fieldDescription); - - var upperBound = range.UpperBound; - - if (DisableDateTimeInfinityConversions || upperBound != LocalDate.MaxIsoValue) - upperBound -= Period.FromDays(1); - - return new(range.LowerBound, upperBound); - } - - public int ValidateAndGetLength(DateInterval value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLengthRange(new NpgsqlRange(value.Start, value.End), ref lengthCache, parameter); - - public Task Write( - DateInterval value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, - CancellationToken cancellationToken = default) - => WriteRange(new NpgsqlRange(value.Start, value.End), buf, lengthCache, parameter, async, cancellationToken); - -#if NET6_0_OR_GREATER - ValueTask> INpgsqlTypeHandler>.Read( - NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => ReadRange(buf, len, async, fieldDescription); - - public int ValidateAndGetLength(NpgsqlRange value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLengthRange(value, ref lengthCache, parameter); - - public Task Write( - NpgsqlRange value, - NpgsqlWriteBuffer buf, - NpgsqlLengthCache? lengthCache, - NpgsqlParameter? 
parameter, - bool async, - CancellationToken cancellationToken = default) - => WriteRange(value, buf, lengthCache, parameter, async, cancellationToken); -#endif -} \ No newline at end of file diff --git a/src/Npgsql.NodaTime/Internal/DurationConverter.cs b/src/Npgsql.NodaTime/Internal/DurationConverter.cs new file mode 100644 index 0000000000..940ef29464 --- /dev/null +++ b/src/Npgsql.NodaTime/Internal/DurationConverter.cs @@ -0,0 +1,42 @@ +using System; +using NodaTime; +using Npgsql.Internal; +using Npgsql.NodaTime.Properties; + +namespace Npgsql.NodaTime.Internal; + +sealed class DurationConverter : PgBufferedConverter +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long) + sizeof(int) + sizeof(int)); + return format is DataFormat.Binary; + } + + protected override Duration ReadCore(PgReader reader) + { + var microsecondsInDay = reader.ReadInt64(); + var days = reader.ReadInt32(); + var totalMonths = reader.ReadInt32(); + + if (totalMonths != 0) + throw new InvalidCastException(NpgsqlNodaTimeStrings.CannotReadIntervalWithMonthsAsDuration); + + return Duration.FromDays(days) + Duration.FromNanoseconds(microsecondsInDay * 1000); + } + + protected override void WriteCore(PgWriter writer, Duration value) + { + const long microsecondsPerSecond = 1_000_000; + + // Note that the end result must be long + // see #3438 + var microsecondsInDay = + (((value.Hours * NodaConstants.MinutesPerHour + value.Minutes) * NodaConstants.SecondsPerMinute + value.Seconds) * + microsecondsPerSecond + value.SubsecondNanoseconds / 1000); // Take the microseconds, discard the nanosecond remainder + + writer.WriteInt64(microsecondsInDay); + writer.WriteInt32(value.Days); // days + writer.WriteInt32(0); // months + } +} diff --git a/src/Npgsql.NodaTime/Internal/IntervalConverter.cs b/src/Npgsql.NodaTime/Internal/IntervalConverter.cs new file mode 100644 index 
0000000000..3ca9ca9ab0 --- /dev/null +++ b/src/Npgsql.NodaTime/Internal/IntervalConverter.cs @@ -0,0 +1,57 @@ +using System.Threading; +using System.Threading.Tasks; +using NodaTime; +using Npgsql.Internal; +using NpgsqlTypes; + +namespace Npgsql.NodaTime.Internal; + +public class IntervalConverter : PgStreamingConverter +{ + readonly PgConverter> _rangeConverter; + + public IntervalConverter(PgConverter> rangeConverter) + => _rangeConverter = rangeConverter; + + public override Interval Read(PgReader reader) + => Read(async: false, reader, CancellationToken.None).GetAwaiter().GetResult(); + + public override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + => Read(async: true, reader, cancellationToken); + + async ValueTask Read(bool async, PgReader reader, CancellationToken cancellationToken) + { + var range = async + ? await _rangeConverter.ReadAsync(reader, cancellationToken).ConfigureAwait(false) + // ReSharper disable once MethodHasAsyncOverloadWithCancellation + : _rangeConverter.Read(reader); + + // NodaTime Interval includes the start instant and excludes the end instant. + Instant? start = range.LowerBoundInfinite + ? null + : range.LowerBoundIsInclusive + ? range.LowerBound + : range.LowerBound + Duration.Epsilon; + Instant? end = range.UpperBoundInfinite + ? null + : range.UpperBoundIsInclusive + ? range.UpperBound + Duration.Epsilon + : range.UpperBound; + + return new(start, end); + } + + public override Size GetSize(SizeContext context, Interval value, ref object? 
writeState) + => _rangeConverter.GetSize(context, IntervalToNpgsqlRange(value), ref writeState); + + public override void Write(PgWriter writer, Interval value) + => _rangeConverter.Write(writer, IntervalToNpgsqlRange(value)); + + public override ValueTask WriteAsync(PgWriter writer, Interval value, CancellationToken cancellationToken = default) + => _rangeConverter.WriteAsync(writer, IntervalToNpgsqlRange(value), cancellationToken); + + static NpgsqlRange IntervalToNpgsqlRange(Interval interval) + => new( + interval.HasStart ? interval.Start : default, true, !interval.HasStart, + interval.HasEnd ? interval.End : default, false, !interval.HasEnd); +} diff --git a/src/Npgsql.NodaTime/Internal/IntervalHandler.cs b/src/Npgsql.NodaTime/Internal/IntervalHandler.cs deleted file mode 100644 index 4e9305a20b..0000000000 --- a/src/Npgsql.NodaTime/Internal/IntervalHandler.cs +++ /dev/null @@ -1,106 +0,0 @@ -using System; -using NodaTime; -using Npgsql.BackendMessages; -using Npgsql.Internal; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; -using NpgsqlTypes; -using BclIntervalHandler = Npgsql.Internal.TypeHandlers.DateTimeHandlers.IntervalHandler; - -namespace Npgsql.NodaTime.Internal; - -sealed partial class IntervalHandler : - NpgsqlSimpleTypeHandler, - INpgsqlSimpleTypeHandler, - INpgsqlSimpleTypeHandler, - INpgsqlSimpleTypeHandler -{ - readonly BclIntervalHandler _bclHandler; - - internal IntervalHandler(PostgresType postgresType) - : base(postgresType) - => _bclHandler = new BclIntervalHandler(postgresType); - - public override Period Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - { - var microsecondsInDay = buf.ReadInt64(); - var days = buf.ReadInt32(); - var totalMonths = buf.ReadInt32(); - - // NodaTime will normalize most things (i.e. nanoseconds to milliseconds, seconds...) - // but it will not normalize months to years. 
- var months = totalMonths % 12; - var years = totalMonths / 12; - - return new PeriodBuilder - { - Nanoseconds = microsecondsInDay * 1000, - Days = days, - Months = months, - Years = years - }.Build().Normalize(); - } - - public override int ValidateAndGetLength(Period value, NpgsqlParameter? parameter) - => 16; - - public override void Write(Period value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - { - // Note that the end result must be long - // see #3438 - var microsecondsInDay = - (((value.Hours * NodaConstants.MinutesPerHour + value.Minutes) * NodaConstants.SecondsPerMinute + value.Seconds) * NodaConstants.MillisecondsPerSecond + value.Milliseconds) * 1000 + - value.Nanoseconds / 1000; // Take the microseconds, discard the nanosecond remainder - - buf.WriteInt64(microsecondsInDay); - buf.WriteInt32(value.Weeks * 7 + value.Days); // days - buf.WriteInt32(value.Years * 12 + value.Months); // months - } - - Duration INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - { - var microsecondsInDay = buf.ReadInt64(); - var days = buf.ReadInt32(); - var totalMonths = buf.ReadInt32(); - - if (totalMonths != 0) - throw new NpgsqlException("Cannot read PostgreSQL interval with non-zero months to NodaTime Duration. Try reading as a NodaTime Period instead."); - - return Duration.FromDays(days) + Duration.FromNanoseconds(microsecondsInDay * 1000); - } - - public int ValidateAndGetLength(Duration value, NpgsqlParameter? parameter) => 16; - - public void Write(Duration value, NpgsqlWriteBuffer buf, NpgsqlParameter? 
parameter) - { - const long microsecondsPerSecond = 1_000_000; - - // Note that the end result must be long - // see #3438 - var microsecondsInDay = - (((value.Hours * NodaConstants.MinutesPerHour + value.Minutes) * NodaConstants.SecondsPerMinute + value.Seconds) * - microsecondsPerSecond + value.SubsecondNanoseconds / 1000); // Take the microseconds, discard the nanosecond remainder - - buf.WriteInt64(microsecondsInDay); - buf.WriteInt32(value.Days); // days - buf.WriteInt32(0); // months - } - - TimeSpan INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => _bclHandler.Read(buf, len, fieldDescription); - - int INpgsqlSimpleTypeHandler.ValidateAndGetLength(TimeSpan value, NpgsqlParameter? parameter) - => ((INpgsqlSimpleTypeHandler)_bclHandler).ValidateAndGetLength(value, parameter); - - void INpgsqlSimpleTypeHandler.Write(TimeSpan value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => ((INpgsqlSimpleTypeHandler)_bclHandler).Write(value, buf, parameter); - - NpgsqlInterval INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => _bclHandler.Read(buf, len, fieldDescription); - - int INpgsqlSimpleTypeHandler.ValidateAndGetLength(NpgsqlInterval value, NpgsqlParameter? parameter) - => ((INpgsqlSimpleTypeHandler)_bclHandler).ValidateAndGetLength(value, parameter); - - void INpgsqlSimpleTypeHandler.Write(NpgsqlInterval value, NpgsqlWriteBuffer buf, NpgsqlParameter? 
parameter) - => ((INpgsqlSimpleTypeHandler)_bclHandler).Write(value, buf, parameter); -} \ No newline at end of file diff --git a/src/Npgsql.NodaTime/Internal/LegacyConverters.cs b/src/Npgsql.NodaTime/Internal/LegacyConverters.cs new file mode 100644 index 0000000000..54393a4821 --- /dev/null +++ b/src/Npgsql.NodaTime/Internal/LegacyConverters.cs @@ -0,0 +1,78 @@ +using System; +using NodaTime; +using Npgsql.Internal; +using static Npgsql.NodaTime.Internal.NodaTimeUtils; + +namespace Npgsql.NodaTime.Internal; + +sealed class LegacyTimestampTzZonedDateTimeConverter : PgBufferedConverter +{ + readonly DateTimeZone _dateTimeZone; + readonly bool _dateTimeInfinityConversions; + + public LegacyTimestampTzZonedDateTimeConverter(DateTimeZone dateTimeZone, bool dateTimeInfinityConversions) + { + _dateTimeZone = dateTimeZone; + _dateTimeInfinityConversions = dateTimeInfinityConversions; + } + + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long)); + return format is DataFormat.Binary; + } + + protected override ZonedDateTime ReadCore(PgReader reader) + { + var instant = DecodeInstant(reader.ReadInt64(), _dateTimeInfinityConversions); + if (_dateTimeInfinityConversions && (instant == Instant.MaxValue || instant == Instant.MinValue)) + throw new InvalidCastException("Infinity values not supported for timestamp with time zone"); + + return instant.InZone(_dateTimeZone); + } + + protected override void WriteCore(PgWriter writer, ZonedDateTime value) + { + var instant = value.ToInstant(); + if (_dateTimeInfinityConversions && (instant == Instant.MaxValue || instant == Instant.MinValue)) + throw new ArgumentException("Infinity values not supported for timestamp with time zone"); + + writer.WriteInt64(EncodeInstant(instant, _dateTimeInfinityConversions)); + } +} + +sealed class LegacyTimestampTzOffsetDateTimeConverter : PgBufferedConverter +{ + readonly bool 
_dateTimeInfinityConversions; + readonly DateTimeZone _dateTimeZone; + + public LegacyTimestampTzOffsetDateTimeConverter(DateTimeZone dateTimeZone, bool dateTimeInfinityConversions) + { + _dateTimeInfinityConversions = dateTimeInfinityConversions; + _dateTimeZone = dateTimeZone; + } + + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long)); + return format is DataFormat.Binary; + } + + protected override OffsetDateTime ReadCore(PgReader reader) + { + var instant = DecodeInstant(reader.ReadInt64(), _dateTimeInfinityConversions); + if (_dateTimeInfinityConversions && (instant == Instant.MaxValue || instant == Instant.MinValue)) + throw new InvalidCastException("Infinity values not supported for timestamp with time zone"); + + return instant.InZone(_dateTimeZone).ToOffsetDateTime(); + } + + protected override void WriteCore(PgWriter writer, OffsetDateTime value) + { + var instant = value.ToInstant(); + if (_dateTimeInfinityConversions && (instant == Instant.MaxValue || instant == Instant.MinValue)) + throw new ArgumentException("Infinity values not supported for timestamp with time zone"); + + writer.WriteInt64(EncodeInstant(instant, true)); + } +} diff --git a/src/Npgsql.NodaTime/Internal/LegacyTimestampHandler.cs b/src/Npgsql.NodaTime/Internal/LegacyTimestampHandler.cs deleted file mode 100644 index ee2ba1a130..0000000000 --- a/src/Npgsql.NodaTime/Internal/LegacyTimestampHandler.cs +++ /dev/null @@ -1,64 +0,0 @@ -using System; -using System.Diagnostics; -using NodaTime; -using Npgsql.BackendMessages; -using Npgsql.Internal; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; -using BclTimestampHandler = Npgsql.Internal.TypeHandlers.DateTimeHandlers.TimestampHandler; - -namespace Npgsql.NodaTime.Internal; - -sealed partial class LegacyTimestampHandler : NpgsqlSimpleTypeHandler, - INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler, 
INpgsqlSimpleTypeHandler -{ - readonly BclTimestampHandler _bclHandler; - - internal LegacyTimestampHandler(PostgresType postgresType) - : base(postgresType) - => _bclHandler = new BclTimestampHandler(postgresType); - - #region Read - - public override Instant Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => TimestampTzHandler.ReadInstant(buf); - - LocalDateTime INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => TimestampHandler.ReadLocalDateTime(buf); - - DateTime INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => _bclHandler.Read(buf, len, fieldDescription); - - long INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => ((INpgsqlSimpleTypeHandler)_bclHandler).Read(buf, len, fieldDescription); - - #endregion Read - - #region Write - - public override int ValidateAndGetLength(Instant value, NpgsqlParameter? parameter) - => 8; - - int INpgsqlSimpleTypeHandler.ValidateAndGetLength(LocalDateTime value, NpgsqlParameter? parameter) - => 8; - - public override void Write(Instant value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => TimestampTzHandler.WriteInstant(value, buf); - - void INpgsqlSimpleTypeHandler.Write(LocalDateTime value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => TimestampHandler.WriteLocalDateTime(value, buf); - - int INpgsqlSimpleTypeHandler.ValidateAndGetLength(DateTime value, NpgsqlParameter? parameter) - => ((INpgsqlSimpleTypeHandler)_bclHandler).ValidateAndGetLength(value, parameter); - - public int ValidateAndGetLength(long value, NpgsqlParameter? parameter) - => ((INpgsqlSimpleTypeHandler)_bclHandler).ValidateAndGetLength(value, parameter); - - void INpgsqlSimpleTypeHandler.Write(DateTime value, NpgsqlWriteBuffer buf, NpgsqlParameter? 
parameter) - => ((INpgsqlSimpleTypeHandler)_bclHandler).Write(value, buf, parameter); - - void INpgsqlSimpleTypeHandler.Write(long value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => ((INpgsqlSimpleTypeHandler)_bclHandler).Write(value, buf, parameter); - - #endregion Write -} \ No newline at end of file diff --git a/src/Npgsql.NodaTime/Internal/LegacyTimestampTzHandler.cs b/src/Npgsql.NodaTime/Internal/LegacyTimestampTzHandler.cs deleted file mode 100644 index c299193343..0000000000 --- a/src/Npgsql.NodaTime/Internal/LegacyTimestampTzHandler.cs +++ /dev/null @@ -1,121 +0,0 @@ -using System; -using NodaTime; -using NodaTime.TimeZones; -using Npgsql.BackendMessages; -using Npgsql.Internal; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; -using BclTimestampTzHandler = Npgsql.Internal.TypeHandlers.DateTimeHandlers.TimestampTzHandler; -using static Npgsql.NodaTime.Internal.NodaTimeUtils; - -namespace Npgsql.NodaTime.Internal; - -sealed partial class LegacyTimestampTzHandler : NpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler, - INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler, - INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler -{ - readonly IDateTimeZoneProvider _dateTimeZoneProvider; - readonly TimestampTzHandler _wrappedHandler; - - public LegacyTimestampTzHandler(PostgresType postgresType) - : base(postgresType) - { - _dateTimeZoneProvider = DateTimeZoneProviders.Tzdb; - _wrappedHandler = new TimestampTzHandler(postgresType); - } - - #region Read - - public override Instant Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => _wrappedHandler.Read(buf, len, fieldDescription); - - ZonedDateTime INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? 
fieldDescription) - { - try - { - var instant = Read(buf, len, fieldDescription); - - if (!DisableDateTimeInfinityConversions && (instant == Instant.MaxValue || instant == Instant.MinValue)) - throw new InvalidCastException("Infinity values not supported for timestamp with time zone"); - - return instant.InZone(_dateTimeZoneProvider[buf.Connection.Timezone]); - } - catch (Exception e) when ( - string.Equals(buf.Connection.Timezone, "localtime", StringComparison.OrdinalIgnoreCase) && - (e is TimeZoneNotFoundException || e is DateTimeZoneNotFoundException)) - { - throw new TimeZoneNotFoundException( - "The special PostgreSQL timezone 'localtime' is not supported when reading values of type 'timestamp with time zone'. " + - "Please specify a real timezone in 'postgresql.conf' on the server, or set the 'PGTZ' environment variable on the client.", - e); - } - } - - OffsetDateTime INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => ((INpgsqlSimpleTypeHandler)this).Read(buf, len, fieldDescription).ToOffsetDateTime(); - - DateTimeOffset INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => _wrappedHandler.Read(buf, len, fieldDescription); - - DateTime INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => _wrappedHandler.Read(buf, len, fieldDescription); - - long INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => _wrappedHandler.Read(buf, len, fieldDescription); - - #endregion Read - - #region Write - - public override int ValidateAndGetLength(Instant value, NpgsqlParameter? parameter) - => 8; - - int INpgsqlSimpleTypeHandler.ValidateAndGetLength(ZonedDateTime value, NpgsqlParameter? parameter) - => 8; - - public int ValidateAndGetLength(OffsetDateTime value, NpgsqlParameter? parameter) - => 8; - - public override void Write(Instant value, NpgsqlWriteBuffer buf, NpgsqlParameter? 
parameter) - => _wrappedHandler.Write(value, buf, parameter); - - void INpgsqlSimpleTypeHandler.Write(ZonedDateTime value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - { - var instant = value.ToInstant(); - - if (!DisableDateTimeInfinityConversions && (instant == Instant.MaxValue || instant == Instant.MinValue)) - throw new InvalidCastException("Infinity values not supported for timestamp with time zone"); - - _wrappedHandler.Write(instant, buf, parameter); - } - - public void Write(OffsetDateTime value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - { - var instant = value.ToInstant(); - - if (!DisableDateTimeInfinityConversions && (instant == Instant.MaxValue || instant == Instant.MinValue)) - throw new InvalidCastException("Infinity values not supported for timestamp with time zone"); - - _wrappedHandler.Write(instant, buf, parameter); - } - - int INpgsqlSimpleTypeHandler.ValidateAndGetLength(DateTimeOffset value, NpgsqlParameter? parameter) - => ((INpgsqlSimpleTypeHandler)_wrappedHandler).ValidateAndGetLength(value, parameter); - - void INpgsqlSimpleTypeHandler.Write(DateTimeOffset value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => ((INpgsqlSimpleTypeHandler)_wrappedHandler).Write(value, buf, parameter); - - int INpgsqlSimpleTypeHandler.ValidateAndGetLength(DateTime value, NpgsqlParameter? parameter) - => ((INpgsqlSimpleTypeHandler)_wrappedHandler).ValidateAndGetLength(value, parameter); - - void INpgsqlSimpleTypeHandler.Write(DateTime value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => ((INpgsqlSimpleTypeHandler)_wrappedHandler).Write(value, buf, parameter); - - int INpgsqlSimpleTypeHandler.ValidateAndGetLength(long value, NpgsqlParameter? parameter) - => ((INpgsqlSimpleTypeHandler)_wrappedHandler).ValidateAndGetLength(value, parameter); - - void INpgsqlSimpleTypeHandler.Write(long value, NpgsqlWriteBuffer buf, NpgsqlParameter? 
parameter) - => ((INpgsqlSimpleTypeHandler)_wrappedHandler).Write(value, buf, parameter); - - #endregion Write -} \ No newline at end of file diff --git a/src/Npgsql.NodaTime/Internal/LocalDateConverter.cs b/src/Npgsql.NodaTime/Internal/LocalDateConverter.cs new file mode 100644 index 0000000000..e6be7fe69b --- /dev/null +++ b/src/Npgsql.NodaTime/Internal/LocalDateConverter.cs @@ -0,0 +1,52 @@ +using System; +using NodaTime; +using Npgsql.Internal; +using Npgsql.NodaTime.Properties; + +namespace Npgsql.NodaTime.Internal; + +sealed class LocalDateConverter : PgBufferedConverter +{ + readonly bool _dateTimeInfinityConversions; + + public LocalDateConverter(bool dateTimeInfinityConversions) + => _dateTimeInfinityConversions = dateTimeInfinityConversions; + + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(int)); + return format is DataFormat.Binary; + } + + protected override LocalDate ReadCore(PgReader reader) + => reader.ReadInt32() switch + { + int.MaxValue => _dateTimeInfinityConversions + ? LocalDate.MaxIsoValue + : throw new InvalidCastException(NpgsqlNodaTimeStrings.CannotReadInfinityValue), + int.MinValue => _dateTimeInfinityConversions + ? 
LocalDate.MinIsoValue + : throw new InvalidCastException(NpgsqlNodaTimeStrings.CannotReadInfinityValue), + var value => new LocalDate().PlusDays(value + 730119) + }; + + protected override void WriteCore(PgWriter writer, LocalDate value) + { + if (_dateTimeInfinityConversions) + { + if (value == LocalDate.MaxIsoValue) + { + writer.WriteInt32(int.MaxValue); + return; + } + if (value == LocalDate.MinIsoValue) + { + writer.WriteInt32(int.MinValue); + return; + } + } + + var totalDaysSinceEra = Period.Between(default, value, PeriodUnits.Days).Days; + writer.WriteInt32(totalDaysSinceEra - 730119); + } +} diff --git a/src/Npgsql.NodaTime/Internal/LocalTimeConverter.cs b/src/Npgsql.NodaTime/Internal/LocalTimeConverter.cs new file mode 100644 index 0000000000..5849f45dfc --- /dev/null +++ b/src/Npgsql.NodaTime/Internal/LocalTimeConverter.cs @@ -0,0 +1,20 @@ +using NodaTime; +using Npgsql.Internal; + +namespace Npgsql.NodaTime.Internal; + +sealed class LocalTimeConverter : PgBufferedConverter +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long)); + return format is DataFormat.Binary; + } + + // PostgreSQL time resolution == 1 microsecond == 10 ticks + protected override LocalTime ReadCore(PgReader reader) + => LocalTime.FromTicksSinceMidnight(reader.ReadInt64() * 10); + + protected override void WriteCore(PgWriter writer, LocalTime value) + => writer.WriteInt64(value.TickOfDay / 10); +} diff --git a/src/Npgsql.NodaTime/Internal/NodaTimeTypeHandlerResolver.cs b/src/Npgsql.NodaTime/Internal/NodaTimeTypeHandlerResolver.cs deleted file mode 100644 index c0b1cc60c6..0000000000 --- a/src/Npgsql.NodaTime/Internal/NodaTimeTypeHandlerResolver.cs +++ /dev/null @@ -1,155 +0,0 @@ -using System; -using NodaTime; -using Npgsql.Internal; -using Npgsql.Internal.TypeHandlers; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; -using NpgsqlTypes; -using 
static Npgsql.NodaTime.Internal.NodaTimeUtils; - -namespace Npgsql.NodaTime.Internal; - -public class NodaTimeTypeHandlerResolver : TypeHandlerResolver -{ - readonly NpgsqlDatabaseInfo _databaseInfo; - - readonly NpgsqlTypeHandler _timestampHandler; - readonly NpgsqlTypeHandler _timestampTzHandler; - readonly DateHandler _dateHandler; - readonly TimeHandler _timeHandler; - readonly TimeTzHandler _timeTzHandler; - readonly IntervalHandler _intervalHandler; - - TimestampTzRangeHandler? _timestampTzRangeHandler; - DateRangeHandler? _dateRangeHandler; - DateMultirangeHandler? _dateMultirangeHandler; - TimestampTzMultirangeHandler? _timestampTzMultirangeHandler; - - NpgsqlTypeHandler? _timestampTzRangeArray; - NpgsqlTypeHandler? _dateRangeArray; - - readonly ArrayNullabilityMode _arrayNullabilityMode; - - internal NodaTimeTypeHandlerResolver(NpgsqlConnector connector) - { - _databaseInfo = connector.DatabaseInfo; - - _timestampHandler = LegacyTimestampBehavior - ? new LegacyTimestampHandler(PgType("timestamp without time zone")) - : new TimestampHandler(PgType("timestamp without time zone")); - _timestampTzHandler = LegacyTimestampBehavior - ? new LegacyTimestampTzHandler(PgType("timestamp with time zone")) - : new TimestampTzHandler(PgType("timestamp with time zone")); - _dateHandler = new DateHandler(PgType("date")); - _timeHandler = new TimeHandler(PgType("time without time zone")); - _timeTzHandler = new TimeTzHandler(PgType("time with time zone")); - _intervalHandler = new IntervalHandler(PgType("interval")); - - // Note that the range handlers are absent on some pseudo-PostgreSQL databases (e.g. CockroachDB), and multirange types - // were only introduced in PG14. So we resolve these lazily. - - _arrayNullabilityMode = connector.Settings.ArrayNullabilityMode; - } - - public override NpgsqlTypeHandler? 
ResolveByDataTypeName(string typeName) - => typeName switch - { - "timestamp" or "timestamp without time zone" => _timestampHandler, - "timestamptz" or "timestamp with time zone" => _timestampTzHandler, - "date" => _dateHandler, - "time without time zone" => _timeHandler, - "time with time zone" => _timeTzHandler, - "interval" => _intervalHandler, - - "tstzrange" => TsTzRange(), - "daterange" => DateRange(), - "tstzmultirange" => TsTzMultirange(), - "datemultirange" => DateMultirange(), - - "tstzrange[]" => TsTzRangeArray(), - "daterange[]" => DateRangeArray(), - - _ => null - }; - - public override NpgsqlTypeHandler? ResolveByClrType(Type type) - => NodaTimeTypeMappingResolver.ClrTypeToDataTypeName(type) is { } dataTypeName && ResolveByDataTypeName(dataTypeName) is { } handler - ? handler - : null; - - public override NpgsqlTypeHandler? ResolveByNpgsqlDbType(NpgsqlDbType npgsqlDbType) - => npgsqlDbType switch - { - NpgsqlDbType.TimestampTzRange => TsTzRange(), - NpgsqlDbType.DateRange => DateRange(), - NpgsqlDbType.TimestampTzMultirange => TsTzMultirange(), - NpgsqlDbType.DateMultirange => DateMultirange(), - NpgsqlDbType.TimestampTzRange | NpgsqlDbType.Array => TsTzRangeArray(), - NpgsqlDbType.DateRange | NpgsqlDbType.Array => TsTzRangeArray(), - _ => null - }; - - public override NpgsqlTypeHandler? ResolveValueTypeGenerically(T value) - { - // This method only ever gets called for value types, and relies on the JIT specializing the method for T by eliding all the - // type checks below. - - if (typeof(T) == typeof(Instant)) - return LegacyTimestampBehavior ? 
_timestampHandler : _timestampTzHandler; - - if (typeof(T) == typeof(LocalDateTime)) - return _timestampHandler; - if (typeof(T) == typeof(ZonedDateTime)) - return _timestampTzHandler; - if (typeof(T) == typeof(OffsetDateTime)) - return _timestampTzHandler; - if (typeof(T) == typeof(LocalDate)) - return _dateHandler; - if (typeof(T) == typeof(LocalTime)) - return _timeHandler; - if (typeof(T) == typeof(OffsetTime)) - return _timeTzHandler; - if (typeof(T) == typeof(Period)) - return _intervalHandler; - if (typeof(T) == typeof(Duration)) - return _intervalHandler; - - if (typeof(T) == typeof(Interval)) - return _timestampTzRangeHandler; - if (typeof(T) == typeof(NpgsqlRange)) - return _timestampTzRangeHandler; - if (typeof(T) == typeof(NpgsqlRange)) - return _timestampTzRangeHandler; - if (typeof(T) == typeof(NpgsqlRange)) - return _timestampTzRangeHandler; - - // Note that DateInterval is a reference type, so not included in this method - if (typeof(T) == typeof(NpgsqlRange)) - return _dateRangeHandler; - - return null; - } - - PostgresType PgType(string pgTypeName) => _databaseInfo.GetPostgresTypeByName(pgTypeName); - - TimestampTzRangeHandler TsTzRange() - => _timestampTzRangeHandler ??= new TimestampTzRangeHandler(PgType("tstzrange"), _timestampTzHandler); - - DateRangeHandler DateRange() - => _dateRangeHandler ??= new DateRangeHandler(PgType("daterange"), _dateHandler); - - NpgsqlTypeHandler TsTzMultirange() - => _timestampTzMultirangeHandler ??= - new TimestampTzMultirangeHandler((PostgresMultirangeType)PgType("tstzmultirange"), TsTzRange()); - - NpgsqlTypeHandler DateMultirange() - => _dateMultirangeHandler ??= new DateMultirangeHandler((PostgresMultirangeType)PgType("datemultirange"), DateRange()); - - NpgsqlTypeHandler TsTzRangeArray() - => _timestampTzRangeArray ??= - new ArrayHandler((PostgresArrayType)PgType("tstzrange[]"), TsTzRange(), _arrayNullabilityMode); - - NpgsqlTypeHandler DateRangeArray() - => _dateRangeArray ??= - new 
ArrayHandler((PostgresArrayType)PgType("daterange[]"), DateRange(), _arrayNullabilityMode); -} diff --git a/src/Npgsql.NodaTime/Internal/NodaTimeTypeHandlerResolverFactory.cs b/src/Npgsql.NodaTime/Internal/NodaTimeTypeHandlerResolverFactory.cs deleted file mode 100644 index d1034e7f5e..0000000000 --- a/src/Npgsql.NodaTime/Internal/NodaTimeTypeHandlerResolverFactory.cs +++ /dev/null @@ -1,15 +0,0 @@ -using Npgsql.Internal; -using Npgsql.Internal.TypeHandling; -using Npgsql.Internal.TypeMapping; - -namespace Npgsql.NodaTime.Internal; - -public class NodaTimeTypeHandlerResolverFactory : TypeHandlerResolverFactory -{ - public override TypeHandlerResolver Create(TypeMapper typeMapper, NpgsqlConnector connector) - => new NodaTimeTypeHandlerResolver(connector); - - public override TypeMappingResolver CreateMappingResolver() => new NodaTimeTypeMappingResolver(); - - public override TypeMappingResolver CreateGlobalMappingResolver() => new NodaTimeTypeMappingResolver(); -} diff --git a/src/Npgsql.NodaTime/Internal/NodaTimeTypeInfoResolver.cs b/src/Npgsql.NodaTime/Internal/NodaTimeTypeInfoResolver.cs new file mode 100644 index 0000000000..66dcfc35dc --- /dev/null +++ b/src/Npgsql.NodaTime/Internal/NodaTimeTypeInfoResolver.cs @@ -0,0 +1,265 @@ +using System; +using System.Collections.Generic; +using NodaTime; +using Npgsql.Internal; +using Npgsql.Internal.Postgres; +using NpgsqlTypes; +using static Npgsql.NodaTime.Internal.NodaTimeUtils; +using static Npgsql.Internal.PgConverterFactory; + +namespace Npgsql.NodaTime.Internal; + +sealed class NodaTimeTypeInfoResolver : IPgTypeInfoResolver +{ + static DataTypeName TimestampTzDataTypeName => new("pg_catalog.timestamptz"); + static DataTypeName TimestampDataTypeName => new("pg_catalog.timestamp"); + static DataTypeName DateDataTypeName => new("pg_catalog.date"); + static DataTypeName TimeDataTypeName => new("pg_catalog.time"); + static DataTypeName TimeTzDataTypeName => new("pg_catalog.timetz"); + static DataTypeName 
IntervalDataTypeName => new("pg_catalog.interval"); + + static DataTypeName DateRangeDataTypeName => new("pg_catalog.daterange"); + static DataTypeName DateMultirangeDataTypeName => new("pg_catalog.datemultirange"); + static DataTypeName TimestampTzRangeDataTypeName => new("pg_catalog.tstzrange"); + static DataTypeName TimestampTzMultirangeDataTypeName => new("pg_catalog.tstzmultirange"); + static DataTypeName TimestampRangeDataTypeName => new("pg_catalog.tsrange"); + static DataTypeName TimestampMultirangeDataTypeName => new("pg_catalog.tsmultirange"); + + TypeInfoMappingCollection Mappings { get; } + + public NodaTimeTypeInfoResolver() + { + Mappings = new TypeInfoMappingCollection(); + AddInfos(Mappings); + // TODO: Opt-in only + AddArrayInfos(Mappings); + } + + public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); + + static void AddInfos(TypeInfoMappingCollection mappings) + { + // timestamp and timestamptz, legacy and non-legacy modes + if (LegacyTimestampBehavior) + { + // timestamptz + mappings.AddStructType(new DataTypeName("pg_catalog.timestamptz"), + static (options, mapping, _) => + mapping.CreateInfo(options, new InstantConverter(options.EnableDateTimeInfinityConversions)), isDefault: false); + mappings.AddStructType(new DataTypeName("pg_catalog.timestamptz"), + static (options, mapping, _) => + mapping.CreateInfo(options, new LegacyTimestampTzZonedDateTimeConverter( + DateTimeZoneProviders.Tzdb[options.TimeZone], options.EnableDateTimeInfinityConversions))); + mappings.AddStructType(new DataTypeName("pg_catalog.timestamptz"), + static (options, mapping, _) => + mapping.CreateInfo(options, new LegacyTimestampTzOffsetDateTimeConverter( + DateTimeZoneProviders.Tzdb[options.TimeZone], options.EnableDateTimeInfinityConversions))); + + // timestamp + mappings.AddStructType(TimestampDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, new 
InstantConverter(options.EnableDateTimeInfinityConversions)), + isDefault: true); + mappings.AddStructType(TimestampDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, new LocalDateTimeConverter(options.EnableDateTimeInfinityConversions)), + isDefault: false); + } + else + { + // timestamptz + mappings.AddStructType(TimestampTzDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, new InstantConverter(options.EnableDateTimeInfinityConversions)), isDefault: true); + mappings.AddStructType(new DataTypeName("pg_catalog.timestamptz"), + static (options, mapping, _) => + mapping.CreateInfo(options, new ZonedDateTimeConverter(options.EnableDateTimeInfinityConversions))); + mappings.AddStructType(new DataTypeName("pg_catalog.timestamptz"), + static (options, mapping, _) => + mapping.CreateInfo(options, new OffsetDateTimeConverter(options.EnableDateTimeInfinityConversions))); + + // timestamp + mappings.AddStructType(TimestampDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, new LocalDateTimeConverter(options.EnableDateTimeInfinityConversions)), isDefault: true); + } + + // date + mappings.AddStructType(DateDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, new LocalDateConverter(options.EnableDateTimeInfinityConversions)), isDefault: true); + + // time + mappings.AddStructType(TimeDataTypeName, + static (options, mapping, _) => mapping.CreateInfo(options, new LocalTimeConverter()), isDefault: true); + + // timetz + mappings.AddStructType(TimeTzDataTypeName, + static (options, mapping, _) => mapping.CreateInfo(options, new OffsetTimeConverter()), isDefault: true); + + // interval + mappings.AddType(IntervalDataTypeName, + static (options, mapping, _) => mapping.CreateInfo(options, new PeriodConverter()), isDefault: true); + mappings.AddStructType(IntervalDataTypeName, + static (options, mapping, _) => mapping.CreateInfo(options, new DurationConverter())); + + // 
tstzrange + mappings.AddStructType(TimestampTzRangeDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, new IntervalConverter(CreateRangeConverter(new InstantConverter(options.EnableDateTimeInfinityConversions), options))), isDefault: true); + mappings.AddStructType>(TimestampTzRangeDataTypeName, + static (options, mapping, _) => mapping.CreateInfo(options, + CreateRangeConverter(new InstantConverter(options.EnableDateTimeInfinityConversions), options))); + mappings.AddStructType>(TimestampTzRangeDataTypeName, + static (options, mapping, _) => mapping.CreateInfo(options, + CreateRangeConverter(new ZonedDateTimeConverter(options.EnableDateTimeInfinityConversions), options))); + mappings.AddStructType>(TimestampTzRangeDataTypeName, + static (options, mapping, _) => mapping.CreateInfo(options, + CreateRangeConverter(new OffsetDateTimeConverter(options.EnableDateTimeInfinityConversions), options))); + + // tstzmultirange + mappings.AddType(TimestampTzMultirangeDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateArrayMultirangeConverter(new IntervalConverter( + CreateRangeConverter(new InstantConverter(options.EnableDateTimeInfinityConversions), options)), options)), + isDefault: true); + mappings.AddType>(TimestampTzMultirangeDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateListMultirangeConverter(new IntervalConverter( + CreateRangeConverter(new InstantConverter(options.EnableDateTimeInfinityConversions), options)), options))); + mappings.AddType[]>(TimestampTzMultirangeDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateArrayMultirangeConverter(CreateRangeConverter(new InstantConverter(options.EnableDateTimeInfinityConversions), options), options))); + mappings.AddType>>(TimestampTzMultirangeDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateListMultirangeConverter(CreateRangeConverter(new 
InstantConverter(options.EnableDateTimeInfinityConversions), options), options))); + mappings.AddType[]>(TimestampTzMultirangeDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateArrayMultirangeConverter(CreateRangeConverter(new ZonedDateTimeConverter(options.EnableDateTimeInfinityConversions), options), options))); + mappings.AddType>>(TimestampTzMultirangeDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateListMultirangeConverter(CreateRangeConverter(new ZonedDateTimeConverter(options.EnableDateTimeInfinityConversions), options), options))); + mappings.AddType[]>(TimestampTzMultirangeDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateArrayMultirangeConverter(CreateRangeConverter(new OffsetDateTimeConverter(options.EnableDateTimeInfinityConversions), options), options))); + mappings.AddType>>(TimestampTzMultirangeDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateListMultirangeConverter(CreateRangeConverter(new OffsetDateTimeConverter(options.EnableDateTimeInfinityConversions), options), options))); + + // tsrange + mappings.AddStructType>(TimestampRangeDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateRangeConverter(new LocalDateTimeConverter(options.EnableDateTimeInfinityConversions), options)), + isDefault: true); + + // tsmultirange + mappings.AddType[]>(TimestampMultirangeDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateArrayMultirangeConverter(CreateRangeConverter(new LocalDateTimeConverter(options.EnableDateTimeInfinityConversions), options), options)), + isDefault: true); + mappings.AddType>>(TimestampMultirangeDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateListMultirangeConverter(CreateRangeConverter(new LocalDateTimeConverter(options.EnableDateTimeInfinityConversions), options), options))); + + // daterange + 
mappings.AddType(DateRangeDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, new DateIntervalConverter( + CreateRangeConverter(new LocalDateConverter(options.EnableDateTimeInfinityConversions), options), + options.EnableDateTimeInfinityConversions)), isDefault: true); + mappings.AddStructType>(DateRangeDataTypeName, + static (options, mapping, _) => mapping.CreateInfo(options, + CreateRangeConverter(new LocalDateConverter(options.EnableDateTimeInfinityConversions), options))); + + // datemultirange + mappings.AddType(DateMultirangeDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateArrayMultirangeConverter(new DateIntervalConverter( + CreateRangeConverter(new LocalDateConverter(options.EnableDateTimeInfinityConversions), options), + options.EnableDateTimeInfinityConversions), options)), + isDefault: true); + mappings.AddType>(DateMultirangeDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateListMultirangeConverter(new DateIntervalConverter( + CreateRangeConverter(new LocalDateConverter(options.EnableDateTimeInfinityConversions), options), + options.EnableDateTimeInfinityConversions), options))); + mappings.AddType[]>(DateMultirangeDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateArrayMultirangeConverter(CreateRangeConverter(new LocalDateConverter(options.EnableDateTimeInfinityConversions), options), options))); + mappings.AddType>>(DateMultirangeDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateListMultirangeConverter(CreateRangeConverter(new LocalDateConverter(options.EnableDateTimeInfinityConversions), options), options))); + } + + static void AddArrayInfos(TypeInfoMappingCollection mappings) + { + // timestamptz + mappings.AddStructArrayType(TimestampTzDataTypeName); + mappings.AddStructArrayType(TimestampTzDataTypeName); + mappings.AddStructArrayType(TimestampTzDataTypeName); + + // timestamp + if 
(LegacyTimestampBehavior) + { + mappings.AddStructArrayType(TimestampDataTypeName); + + mappings.AddStructType(TimestampDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, new InstantConverter(options.EnableDateTimeInfinityConversions)), + isDefault: true); + mappings.AddStructType(TimestampDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, new LocalDateTimeConverter(options.EnableDateTimeInfinityConversions)), + isDefault: false); + } + else + { + mappings.AddStructType(TimestampDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, new LocalDateTimeConverter(options.EnableDateTimeInfinityConversions)), + isDefault: true); + } + mappings.AddStructArrayType(TimestampDataTypeName); + + // other + mappings.AddStructArrayType(DateDataTypeName); + mappings.AddStructArrayType(TimeDataTypeName); + mappings.AddStructArrayType(TimeTzDataTypeName); + mappings.AddArrayType(IntervalDataTypeName); + mappings.AddStructArrayType(IntervalDataTypeName); + + // tstzrange + mappings.AddStructArrayType(TimestampTzRangeDataTypeName); + mappings.AddStructArrayType>(TimestampTzRangeDataTypeName); + mappings.AddStructArrayType>(TimestampTzRangeDataTypeName); + mappings.AddStructArrayType>(TimestampTzRangeDataTypeName); + + // tstzmultirange + mappings.AddArrayType(TimestampTzMultirangeDataTypeName); + mappings.AddArrayType>(TimestampTzMultirangeDataTypeName); + mappings.AddArrayType[]>(TimestampTzMultirangeDataTypeName); + mappings.AddArrayType>>(TimestampTzMultirangeDataTypeName); + mappings.AddArrayType[]>(TimestampTzMultirangeDataTypeName); + mappings.AddArrayType>>(TimestampTzMultirangeDataTypeName); + mappings.AddArrayType[]>(TimestampTzMultirangeDataTypeName); + mappings.AddArrayType>>(TimestampTzMultirangeDataTypeName); + + // tsrange + mappings.AddStructArrayType>(TimestampRangeDataTypeName); + + // tsmultirange + mappings.AddArrayType[]>(TimestampMultirangeDataTypeName); + 
mappings.AddArrayType>>(TimestampMultirangeDataTypeName); + + // daterange + mappings.AddArrayType(DateRangeDataTypeName); + mappings.AddStructArrayType>(DateRangeDataTypeName); + + // datemultirange + mappings.AddArrayType(DateMultirangeDataTypeName); + mappings.AddArrayType>(DateMultirangeDataTypeName); + mappings.AddArrayType[]>(DateMultirangeDataTypeName); + mappings.AddArrayType>>(DateMultirangeDataTypeName); + } +} diff --git a/src/Npgsql.NodaTime/Internal/NodaTimeTypeMappingResolver.cs b/src/Npgsql.NodaTime/Internal/NodaTimeTypeMappingResolver.cs deleted file mode 100644 index dd5f271050..0000000000 --- a/src/Npgsql.NodaTime/Internal/NodaTimeTypeMappingResolver.cs +++ /dev/null @@ -1,99 +0,0 @@ -using System; -using System.Collections.Generic; -using NodaTime; -using Npgsql.Internal.TypeHandling; -using Npgsql.Internal.TypeMapping; -using Npgsql.PostgresTypes; -using NpgsqlTypes; -using static Npgsql.NodaTime.Internal.NodaTimeUtils; - -namespace Npgsql.NodaTime.Internal; - -public class NodaTimeTypeMappingResolver : TypeMappingResolver -{ - public override string? GetDataTypeNameByClrType(Type type) - => ClrTypeToDataTypeName(type); - - public override TypeMappingInfo? GetMappingByDataTypeName(string dataTypeName) - => DoGetMappingByDataTypeName(dataTypeName); - - static TypeMappingInfo? 
DoGetMappingByDataTypeName(string dataTypeName) - => dataTypeName switch - { - "timestamp" or "timestamp without time zone" => new(NpgsqlDbType.Timestamp, "timestamp without time zone"), - "timestamptz" or "timestamp with time zone" => new(NpgsqlDbType.TimestampTz, "timestamp with time zone"), - "date" => new(NpgsqlDbType.Date, "date"), - "time without time zone" => new(NpgsqlDbType.Time, "time without time zone"), - "time with time zone" => new(NpgsqlDbType.TimeTz, "time with time zone"), - "interval" => new(NpgsqlDbType.Interval, "interval"), - - "tsrange" => new(NpgsqlDbType.TimestampRange, "tsrange"), - "tstzrange" => new(NpgsqlDbType.TimestampTzRange, "tstzrange"), - "daterange" => new(NpgsqlDbType.DateRange, "daterange"), - - "tsmultirange" => new(NpgsqlDbType.TimestampMultirange, "tsmultirange"), - "tstzmultirange" => new(NpgsqlDbType.TimestampTzMultirange, "tstzmultirange"), - "datemultirange" => new(NpgsqlDbType.DateMultirange, "datemultirange"), - - _ => null - }; - - internal static string? ClrTypeToDataTypeName(Type type) - { - if (type == typeof(Instant)) - return LegacyTimestampBehavior ? 
"timestamp without time zone" : "timestamp with time zone"; - - if (type == typeof(LocalDateTime)) - return "timestamp without time zone"; - if (type == typeof(ZonedDateTime) || type == typeof(OffsetDateTime)) - return "timestamp with time zone"; - if (type == typeof(LocalDate)) - return "date"; - if (type == typeof(LocalTime)) - return "time without time zone"; - if (type == typeof(OffsetTime)) - return "time with time zone"; - if (type == typeof(Period) || type == typeof(Duration)) - return "interval"; - - // Ranges - if (type == typeof(NpgsqlRange)) - return "tsrange"; - - if (type == typeof(Interval) || - type == typeof(NpgsqlRange) || - type == typeof(NpgsqlRange) || - type == typeof(NpgsqlRange)) - { - return "tstzrange"; - } - - if (type == typeof(DateInterval) || type == typeof(NpgsqlRange)) - return "daterange"; - - // Multiranges - if (type == typeof(NpgsqlRange[]) || type == typeof(List>)) - return "tsmultirange"; - - if (type == typeof(Interval[]) || - type == typeof(List) || - type == typeof(NpgsqlRange[]) || - type == typeof(List>) || - type == typeof(NpgsqlRange[]) || - type == typeof(List>) || - type == typeof(NpgsqlRange[]) || - type == typeof(List>)) - { - return "tstzmultirange"; - } - if (type == typeof(DateInterval[]) || - type == typeof(List) || - type == typeof(NpgsqlRange[]) || - type == typeof(List>)) - { - return "datemultirange"; - } - - return null; - } -} diff --git a/src/Npgsql.NodaTime/Internal/NodaTimeUtils.cs b/src/Npgsql.NodaTime/Internal/NodaTimeUtils.cs index ff37bdd196..1cf433759a 100644 --- a/src/Npgsql.NodaTime/Internal/NodaTimeUtils.cs +++ b/src/Npgsql.NodaTime/Internal/NodaTimeUtils.cs @@ -1,5 +1,6 @@ using System; using NodaTime; +using Npgsql.NodaTime.Properties; namespace Npgsql.NodaTime.Internal; @@ -7,17 +8,11 @@ static class NodaTimeUtils { #if DEBUG internal static bool LegacyTimestampBehavior; - internal static bool DisableDateTimeInfinityConversions; #else internal static readonly bool LegacyTimestampBehavior; - 
internal static readonly bool DisableDateTimeInfinityConversions; #endif - static NodaTimeUtils() - { - LegacyTimestampBehavior = AppContext.TryGetSwitch("Npgsql.EnableLegacyTimestampBehavior", out var enabled) && enabled; - DisableDateTimeInfinityConversions = AppContext.TryGetSwitch("Npgsql.DisableDateTimeInfinityConversions", out enabled) && enabled; - } + static NodaTimeUtils() => LegacyTimestampBehavior = AppContext.TryGetSwitch("Npgsql.EnableLegacyTimestampBehavior", out var enabled) && enabled; static readonly Instant Instant2000 = Instant.FromUtc(2000, 1, 1, 0, 0, 0); static readonly Duration Plus292Years = Duration.FromDays(292 * 365); @@ -27,17 +22,36 @@ static NodaTimeUtils() /// Decodes a PostgreSQL timestamp/timestamptz into a NodaTime Instant. /// /// The number of microseconds from 2000-01-01T00:00:00. + /// Whether infinity date/time conversions are enabled. /// /// Unfortunately NodaTime doesn't have Duration.FromMicroseconds(), so we decompose into milliseconds and nanoseconds. /// - internal static Instant DecodeInstant(long value) - => Instant2000 + Duration.FromMilliseconds(value / 1000) + Duration.FromNanoseconds(value % 1000 * 1000); + internal static Instant DecodeInstant(long value, bool dateTimeInfinityConversions) + => value switch + { + long.MaxValue => dateTimeInfinityConversions + ? Instant.MaxValue + : throw new InvalidCastException(NpgsqlNodaTimeStrings.CannotReadInfinityValue), + long.MinValue => dateTimeInfinityConversions + ? Instant.MinValue + : throw new InvalidCastException(NpgsqlNodaTimeStrings.CannotReadInfinityValue), + _ => Instant2000 + Duration.FromMilliseconds(value / 1000) + Duration.FromNanoseconds(value % 1000 * 1000) + }; /// /// Encodes a NodaTime Instant to a PostgreSQL timestamp/timestamptz. 
/// - internal static long EncodeInstant(Instant instant) + internal static long EncodeInstant(Instant instant, bool dateTimeInfinityConversions) { + if (dateTimeInfinityConversions) + { + if (instant == Instant.MaxValue) + return long.MaxValue; + + if (instant == Instant.MinValue) + return long.MinValue; + } + // We need to write the number of microseconds from 2000-01-01T00:00:00. var since2000 = instant - Instant2000; @@ -46,4 +60,4 @@ internal static long EncodeInstant(Instant instant) ? since2000.ToInt64Nanoseconds() / 1000 : (long)(since2000.ToBigIntegerNanoseconds() / 1000); } -} \ No newline at end of file +} diff --git a/src/Npgsql.NodaTime/Internal/OffsetTimeConverter.cs b/src/Npgsql.NodaTime/Internal/OffsetTimeConverter.cs new file mode 100644 index 0000000000..7c5499c2f8 --- /dev/null +++ b/src/Npgsql.NodaTime/Internal/OffsetTimeConverter.cs @@ -0,0 +1,23 @@ +using NodaTime; +using Npgsql.Internal; + +namespace Npgsql.NodaTime.Internal; + +sealed class OffsetTimeConverter : PgBufferedConverter +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long) + sizeof(int)); + return format is DataFormat.Binary; + } + + // Adjust from 1 microsecond to 100ns. Time zone (in seconds) is inverted. 
+ protected override OffsetTime ReadCore(PgReader reader) + => new(LocalTime.FromTicksSinceMidnight(reader.ReadInt64() * 10), Offset.FromSeconds(-reader.ReadInt32())); + + protected override void WriteCore(PgWriter writer, OffsetTime value) + { + writer.WriteInt64(value.TickOfDay / 10); + writer.WriteInt32(-(int)(value.Offset.Ticks / NodaConstants.TicksPerSecond)); + } +} diff --git a/src/Npgsql.NodaTime/Internal/PeriodConverter.cs b/src/Npgsql.NodaTime/Internal/PeriodConverter.cs new file mode 100644 index 0000000000..4dbde48dbc --- /dev/null +++ b/src/Npgsql.NodaTime/Internal/PeriodConverter.cs @@ -0,0 +1,46 @@ +using NodaTime; +using Npgsql.Internal; + +namespace Npgsql.NodaTime.Internal; + +sealed class PeriodConverter : PgBufferedConverter +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long) + sizeof(int) + sizeof(int)); + return format is DataFormat.Binary; + } + + protected override Period ReadCore(PgReader reader) + { + var microsecondsInDay = reader.ReadInt64(); + var days = reader.ReadInt32(); + var totalMonths = reader.ReadInt32(); + + // NodaTime will normalize most things (i.e. nanoseconds to milliseconds, seconds...) + // but it will not normalize months to years. 
+ var months = totalMonths % 12; + var years = totalMonths / 12; + + return new PeriodBuilder + { + Nanoseconds = microsecondsInDay * 1000, + Days = days, + Months = months, + Years = years + }.Build().Normalize(); + } + + protected override void WriteCore(PgWriter writer, Period value) + { + // Note that the end result must be long + // see #3438 + var microsecondsInDay = + (((value.Hours * NodaConstants.MinutesPerHour + value.Minutes) * NodaConstants.SecondsPerMinute + value.Seconds) * NodaConstants.MillisecondsPerSecond + value.Milliseconds) * 1000 + + value.Nanoseconds / 1000; // Take the microseconds, discard the nanosecond remainder + + writer.WriteInt64(microsecondsInDay); + writer.WriteInt32(value.Weeks * 7 + value.Days); // days + writer.WriteInt32(value.Years * 12 + value.Months); // months + } +} diff --git a/src/Npgsql.NodaTime/Internal/TimeHandler.cs b/src/Npgsql.NodaTime/Internal/TimeHandler.cs deleted file mode 100644 index 5171745764..0000000000 --- a/src/Npgsql.NodaTime/Internal/TimeHandler.cs +++ /dev/null @@ -1,53 +0,0 @@ -using System; -using System.Threading; -using System.Threading.Tasks; -using NodaTime; -using Npgsql.BackendMessages; -using Npgsql.Internal; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; -using BclTimeHandler = Npgsql.Internal.TypeHandlers.DateTimeHandlers.TimeHandler; - -namespace Npgsql.NodaTime.Internal; - -sealed partial class TimeHandler : NpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler -#if NET6_0_OR_GREATER - , INpgsqlSimpleTypeHandler -#endif -{ - readonly BclTimeHandler _bclHandler; - - internal TimeHandler(PostgresType postgresType) - : base(postgresType) - => _bclHandler = new BclTimeHandler(postgresType); - - // PostgreSQL time resolution == 1 microsecond == 10 ticks - public override LocalTime Read(NpgsqlReadBuffer buf, int len, FieldDescription? 
fieldDescription = null) - => LocalTime.FromTicksSinceMidnight(buf.ReadInt64() * 10); - - public override int ValidateAndGetLength(LocalTime value, NpgsqlParameter? parameter) - => 8; - - public override void Write(LocalTime value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => buf.WriteInt64(value.TickOfDay / 10); - - TimeSpan INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => _bclHandler.Read(buf, len, fieldDescription); - - int INpgsqlSimpleTypeHandler.ValidateAndGetLength(TimeSpan value, NpgsqlParameter? parameter) - => _bclHandler.ValidateAndGetLength(value, parameter); - - void INpgsqlSimpleTypeHandler.Write(TimeSpan value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => _bclHandler.Write(value, buf, parameter); - -#if NET6_0_OR_GREATER - TimeOnly INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => _bclHandler.Read(buf, len, fieldDescription); - - public int ValidateAndGetLength(TimeOnly value, NpgsqlParameter? parameter) - => _bclHandler.ValidateAndGetLength(value, parameter); - - public void Write(TimeOnly value, NpgsqlWriteBuffer buf, NpgsqlParameter? 
parameter) - => _bclHandler.Write(value, buf, parameter); -#endif -} \ No newline at end of file diff --git a/src/Npgsql.NodaTime/Internal/TimeTzHandler.cs b/src/Npgsql.NodaTime/Internal/TimeTzHandler.cs deleted file mode 100644 index d8ace650dc..0000000000 --- a/src/Npgsql.NodaTime/Internal/TimeTzHandler.cs +++ /dev/null @@ -1,41 +0,0 @@ -using System; -using NodaTime; -using Npgsql.BackendMessages; -using Npgsql.Internal; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; -using BclTimeTzHandler = Npgsql.Internal.TypeHandlers.DateTimeHandlers.TimeTzHandler; - -namespace Npgsql.NodaTime.Internal; - -sealed partial class TimeTzHandler : NpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler -{ - readonly BclTimeTzHandler _bclHandler; - - internal TimeTzHandler(PostgresType postgresType) - : base(postgresType) - => _bclHandler = new BclTimeTzHandler(postgresType); - - // Adjust from 1 microsecond to 100ns. Time zone (in seconds) is inverted. - public override OffsetTime Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => new( - LocalTime.FromTicksSinceMidnight(buf.ReadInt64() * 10), - Offset.FromSeconds(-buf.ReadInt32())); - - public override int ValidateAndGetLength(OffsetTime value, NpgsqlParameter? parameter) => 12; - - public override void Write(OffsetTime value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - { - buf.WriteInt64(value.TickOfDay / 10); - buf.WriteInt32(-(int)(value.Offset.Ticks / NodaConstants.TicksPerSecond)); - } - - DateTimeOffset INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => _bclHandler.Read(buf, len, fieldDescription); - - int INpgsqlSimpleTypeHandler.ValidateAndGetLength(DateTimeOffset value, NpgsqlParameter? parameter) - => _bclHandler.ValidateAndGetLength(value, parameter); - - void INpgsqlSimpleTypeHandler.Write(DateTimeOffset value, NpgsqlWriteBuffer buf, NpgsqlParameter? 
parameter) - => _bclHandler.Write(value, buf, parameter); -} \ No newline at end of file diff --git a/src/Npgsql.NodaTime/Internal/TimestampConverters.cs b/src/Npgsql.NodaTime/Internal/TimestampConverters.cs new file mode 100644 index 0000000000..6808503638 --- /dev/null +++ b/src/Npgsql.NodaTime/Internal/TimestampConverters.cs @@ -0,0 +1,106 @@ +using System; +using NodaTime; +using Npgsql.Internal; +using static Npgsql.NodaTime.Internal.NodaTimeUtils; + +namespace Npgsql.NodaTime.Internal; + +sealed class InstantConverter : PgBufferedConverter +{ + readonly bool _dateTimeInfinityConversions; + + public InstantConverter(bool dateTimeInfinityConversions) + => _dateTimeInfinityConversions = dateTimeInfinityConversions; + + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long)); + return format is DataFormat.Binary; + } + + protected override Instant ReadCore(PgReader reader) + => DecodeInstant(reader.ReadInt64(), _dateTimeInfinityConversions); + + protected override void WriteCore(PgWriter writer, Instant value) + => writer.WriteInt64(EncodeInstant(value, _dateTimeInfinityConversions)); +} + +sealed class ZonedDateTimeConverter : PgBufferedConverter +{ + readonly bool _dateTimeInfinityConversions; + + public ZonedDateTimeConverter(bool dateTimeInfinityConversions) + => _dateTimeInfinityConversions = dateTimeInfinityConversions; + + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long)); + return format is DataFormat.Binary; + } + + protected override ZonedDateTime ReadCore(PgReader reader) + => DecodeInstant(reader.ReadInt64(), _dateTimeInfinityConversions).InUtc(); + + protected override void WriteCore(PgWriter writer, ZonedDateTime value) + { + if (value.Zone != DateTimeZone.Utc && !LegacyTimestampBehavior) + { + throw new 
ArgumentException( + $"Cannot write ZonedDateTime with Zone={value.Zone} to PostgreSQL type 'timestamp with time zone', " + + "only UTC is supported. " + + "See the Npgsql.EnableLegacyTimestampBehavior AppContext switch to enable legacy behavior."); + } + + writer.WriteInt64(EncodeInstant(value.ToInstant(), _dateTimeInfinityConversions)); + } +} + +sealed class OffsetDateTimeConverter : PgBufferedConverter +{ + readonly bool _dateTimeInfinityConversions; + + public OffsetDateTimeConverter(bool dateTimeInfinityConversions) + => _dateTimeInfinityConversions = dateTimeInfinityConversions; + + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long)); + return format is DataFormat.Binary; + } + + protected override OffsetDateTime ReadCore(PgReader reader) + => DecodeInstant(reader.ReadInt64(), _dateTimeInfinityConversions).WithOffset(Offset.Zero); + + protected override void WriteCore(PgWriter writer, OffsetDateTime value) + { + if (value.Offset != Offset.Zero && !LegacyTimestampBehavior) + { + throw new ArgumentException( + $"Cannot write OffsetDateTime with Offset={value.Offset} to PostgreSQL type 'timestamp with time zone', " + + "only offset 0 (UTC) is supported. 
" + + "See the Npgsql.EnableLegacyTimestampBehavior AppContext switch to enable legacy behavior."); + } + + writer.WriteInt64(EncodeInstant(value.ToInstant(), _dateTimeInfinityConversions)); + } +} + +sealed class LocalDateTimeConverter : PgBufferedConverter +{ + readonly bool _dateTimeInfinityConversions; + + public LocalDateTimeConverter(bool dateTimeInfinityConversions) + => _dateTimeInfinityConversions = dateTimeInfinityConversions; + + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long)); + return format is DataFormat.Binary; + } + + protected override LocalDateTime ReadCore(PgReader reader) + => DecodeInstant(reader.ReadInt64(), _dateTimeInfinityConversions).InUtc().LocalDateTime; + + protected override void WriteCore(PgWriter writer, LocalDateTime value) + => writer.WriteInt64(EncodeInstant(value.InUtc().ToInstant(), _dateTimeInfinityConversions)); +} diff --git a/src/Npgsql.NodaTime/Internal/TimestampHandler.cs b/src/Npgsql.NodaTime/Internal/TimestampHandler.cs deleted file mode 100644 index 15c254e3d0..0000000000 --- a/src/Npgsql.NodaTime/Internal/TimestampHandler.cs +++ /dev/null @@ -1,88 +0,0 @@ -using System; -using NodaTime; -using Npgsql.BackendMessages; -using Npgsql.Internal; -using Npgsql.Internal.TypeHandling; -using Npgsql.NodaTime.Properties; -using Npgsql.PostgresTypes; -using BclTimestampHandler = Npgsql.Internal.TypeHandlers.DateTimeHandlers.TimestampHandler; -using static Npgsql.NodaTime.Internal.NodaTimeUtils; - -namespace Npgsql.NodaTime.Internal; - -sealed partial class TimestampHandler : NpgsqlSimpleTypeHandler, - INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler -{ - readonly BclTimestampHandler _bclHandler; - - internal TimestampHandler(PostgresType postgresType) - : base(postgresType) - => _bclHandler = new BclTimestampHandler(postgresType); - - #region Read - - public override LocalDateTime Read(NpgsqlReadBuffer buf, 
int len, FieldDescription? fieldDescription) - => ReadLocalDateTime(buf); - - internal static LocalDateTime ReadLocalDateTime(NpgsqlReadBuffer buf) - => buf.ReadInt64() switch - { - long.MaxValue => DisableDateTimeInfinityConversions - ? throw new InvalidCastException(NpgsqlNodaTimeStrings.CannotReadInfinityValue) - : LocalDateTime.MaxIsoValue, - long.MinValue => DisableDateTimeInfinityConversions - ? throw new InvalidCastException(NpgsqlNodaTimeStrings.CannotReadInfinityValue) - : LocalDateTime.MinIsoValue, - var value => DecodeInstant(value).InUtc().LocalDateTime - }; - - DateTime INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => _bclHandler.Read(buf, len, fieldDescription); - - long INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => ((INpgsqlSimpleTypeHandler)_bclHandler).Read(buf, len, fieldDescription); - - #endregion Read - - #region Write - - public override int ValidateAndGetLength(LocalDateTime value, NpgsqlParameter? parameter) - => 8; - - public override void Write(LocalDateTime value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => WriteLocalDateTime(value, buf); - - internal static void WriteLocalDateTime(LocalDateTime value, NpgsqlWriteBuffer buf) - { - if (!DisableDateTimeInfinityConversions) - { - if (value == LocalDateTime.MaxIsoValue) - { - buf.WriteInt64(long.MaxValue); - return; - } - - if (value == LocalDateTime.MinIsoValue) - { - buf.WriteInt64(long.MinValue); - return; - } - } - - buf.WriteInt64(EncodeInstant(value.InUtc().ToInstant())); - } - - public int ValidateAndGetLength(DateTime value, NpgsqlParameter? parameter) - => ((INpgsqlSimpleTypeHandler)_bclHandler).ValidateAndGetLength(value, parameter); - - public int ValidateAndGetLength(long value, NpgsqlParameter? 
parameter) - => ((INpgsqlSimpleTypeHandler)_bclHandler).ValidateAndGetLength(value, parameter); - - void INpgsqlSimpleTypeHandler.Write(DateTime value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => ((INpgsqlSimpleTypeHandler)_bclHandler).Write(value, buf, parameter); - - void INpgsqlSimpleTypeHandler.Write(long value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => ((INpgsqlSimpleTypeHandler)_bclHandler).Write(value, buf, parameter); - - #endregion Write -} \ No newline at end of file diff --git a/src/Npgsql.NodaTime/Internal/TimestampTzHandler.cs b/src/Npgsql.NodaTime/Internal/TimestampTzHandler.cs deleted file mode 100644 index fa1924656a..0000000000 --- a/src/Npgsql.NodaTime/Internal/TimestampTzHandler.cs +++ /dev/null @@ -1,126 +0,0 @@ -using System; -using NodaTime; -using Npgsql.BackendMessages; -using Npgsql.Internal; -using Npgsql.Internal.TypeHandling; -using Npgsql.NodaTime.Properties; -using Npgsql.PostgresTypes; -using BclTimestampTzHandler = Npgsql.Internal.TypeHandlers.DateTimeHandlers.TimestampTzHandler; -using static Npgsql.NodaTime.Internal.NodaTimeUtils; - -namespace Npgsql.NodaTime.Internal; - -sealed partial class TimestampTzHandler : NpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler, - INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler, - INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler -{ - readonly BclTimestampTzHandler _bclHandler; - - public TimestampTzHandler(PostgresType postgresType) - : base(postgresType) - => _bclHandler = new BclTimestampTzHandler(postgresType); - - #region Read - - public override Instant Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => ReadInstant(buf); - - internal static Instant ReadInstant(NpgsqlReadBuffer buf) - => buf.ReadInt64() switch - { - long.MaxValue => DisableDateTimeInfinityConversions - ? throw new InvalidCastException(NpgsqlNodaTimeStrings.CannotReadInfinityValue) - : Instant.MaxValue, - long.MinValue => DisableDateTimeInfinityConversions - ? 
throw new InvalidCastException(NpgsqlNodaTimeStrings.CannotReadInfinityValue) - : Instant.MinValue, - var value => DecodeInstant(value) - }; - - ZonedDateTime INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => Read(buf, len, fieldDescription).InUtc(); - - OffsetDateTime INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => Read(buf, len, fieldDescription).WithOffset(Offset.Zero); - - DateTimeOffset INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => _bclHandler.Read(buf, len, fieldDescription); - - DateTime INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => _bclHandler.Read(buf, len, fieldDescription); - - long INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => ((INpgsqlSimpleTypeHandler)_bclHandler).Read(buf, len, fieldDescription); - - #endregion Read - - #region Write - - public override int ValidateAndGetLength(Instant value, NpgsqlParameter? parameter) - => 8; - - int INpgsqlSimpleTypeHandler.ValidateAndGetLength(ZonedDateTime value, NpgsqlParameter? parameter) - => value.Zone == DateTimeZone.Utc || LegacyTimestampBehavior - ? 8 - : throw new InvalidCastException( - $"Cannot write ZonedDateTime with Zone={value.Zone} to PostgreSQL type 'timestamp with time zone', " + - "only UTC is supported. " + - "See the Npgsql.EnableLegacyTimestampBehavior AppContext switch to enable legacy behavior."); - - public int ValidateAndGetLength(OffsetDateTime value, NpgsqlParameter? parameter) - => value.Offset == Offset.Zero || LegacyTimestampBehavior - ? 8 - : throw new InvalidCastException( - $"Cannot write OffsetDateTime with Offset={value.Offset} to PostgreSQL type 'timestamp with time zone', " + - "only offset 0 (UTC) is supported. 
" + - "See the Npgsql.EnableLegacyTimestampBehavior AppContext switch to enable legacy behavior."); - - public override void Write(Instant value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => WriteInstant(value, buf); - - internal static void WriteInstant(Instant value, NpgsqlWriteBuffer buf) - { - if (!DisableDateTimeInfinityConversions) - { - if (value == Instant.MaxValue) - { - buf.WriteInt64(long.MaxValue); - return; - } - - if (value == Instant.MinValue) - { - buf.WriteInt64(long.MinValue); - return; - } - } - - buf.WriteInt64(EncodeInstant(value)); - } - - void INpgsqlSimpleTypeHandler.Write(ZonedDateTime value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => Write(value.ToInstant(), buf, parameter); - - public void Write(OffsetDateTime value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => Write(value.ToInstant(), buf, parameter); - - int INpgsqlSimpleTypeHandler.ValidateAndGetLength(DateTimeOffset value, NpgsqlParameter? parameter) - => _bclHandler.ValidateAndGetLength(value, parameter); - - void INpgsqlSimpleTypeHandler.Write(DateTimeOffset value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => _bclHandler.Write(value, buf, parameter); - - int INpgsqlSimpleTypeHandler.ValidateAndGetLength(DateTime value, NpgsqlParameter? parameter) - => ((INpgsqlSimpleTypeHandler)_bclHandler).ValidateAndGetLength(value, parameter); - - public int ValidateAndGetLength(long value, NpgsqlParameter? parameter) - => ((INpgsqlSimpleTypeHandler)_bclHandler).ValidateAndGetLength(value, parameter); - - void INpgsqlSimpleTypeHandler.Write(DateTime value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => _bclHandler.Write(value, buf, parameter); - - void INpgsqlSimpleTypeHandler.Write(long value, NpgsqlWriteBuffer buf, NpgsqlParameter? 
parameter) - => ((INpgsqlSimpleTypeHandler)_bclHandler).Write(value, buf, parameter); - - #endregion Write -} \ No newline at end of file diff --git a/src/Npgsql.NodaTime/Internal/TimestampTzMultirangeHandler.cs b/src/Npgsql.NodaTime/Internal/TimestampTzMultirangeHandler.cs deleted file mode 100644 index a13bb091b2..0000000000 --- a/src/Npgsql.NodaTime/Internal/TimestampTzMultirangeHandler.cs +++ /dev/null @@ -1,202 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Threading; -using System.Threading.Tasks; -using NodaTime; -using Npgsql.BackendMessages; -using Npgsql.Internal; -using Npgsql.Internal.TypeHandlers; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; -using NpgsqlTypes; - -namespace Npgsql.NodaTime.Internal; - -public partial class TimestampTzMultirangeHandler : MultirangeHandler, - INpgsqlTypeHandler, INpgsqlTypeHandler>, - INpgsqlTypeHandler[]>, INpgsqlTypeHandler>>, - INpgsqlTypeHandler[]>, INpgsqlTypeHandler>>, - INpgsqlTypeHandler[]>, INpgsqlTypeHandler>>, - INpgsqlTypeHandler[]>, INpgsqlTypeHandler>> -{ - readonly INpgsqlTypeHandler _intervalHandler; - - public override Type GetFieldType(FieldDescription? fieldDescription = null) => typeof(Interval[]); - - public TimestampTzMultirangeHandler(PostgresMultirangeType pgMultirangeType, TimestampTzRangeHandler rangeHandler) - : base(pgMultirangeType, rangeHandler) - => _intervalHandler = rangeHandler; - - public override async ValueTask ReadAsObject(NpgsqlReadBuffer buf, int len, bool async, - FieldDescription? fieldDescription = null) - => (await Read(buf, len, async, fieldDescription))!; - - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, - FieldDescription? 
fieldDescription) - { - await buf.Ensure(4, async); - var numRanges = buf.ReadInt32(); - var multirange = new Interval[numRanges]; - - for (var i = 0; i < multirange.Length; i++) - { - await buf.Ensure(4, async); - var rangeLen = buf.ReadInt32(); - multirange[i] = await _intervalHandler.Read(buf, rangeLen, async, fieldDescription); - } - - return multirange; - } - - async ValueTask> INpgsqlTypeHandler>.Read( - NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - { - await buf.Ensure(4, async); - var numRanges = buf.ReadInt32(); - var multirange = new List(numRanges); - - for (var i = 0; i < numRanges; i++) - { - await buf.Ensure(4, async); - var rangeLen = buf.ReadInt32(); - multirange.Add(await _intervalHandler.Read(buf, rangeLen, async, fieldDescription)); - } - - return multirange; - } - - public int ValidateAndGetLength(List value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLengthCore(value, ref lengthCache); - - public int ValidateAndGetLength(Interval[] value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLengthCore(value, ref lengthCache); - - int ValidateAndGetLengthCore(IList value, ref NpgsqlLengthCache? lengthCache) - { - lengthCache ??= new NpgsqlLengthCache(1); - if (lengthCache.IsPopulated) - return lengthCache.Get(); - - var sum = 4 + 4 * value.Count; - for (var i = 0; i < value.Count; i++) - sum += _intervalHandler.ValidateAndGetLength(value[i], ref lengthCache, parameter: null); - - return lengthCache!.Set(sum); - } - - public async Task Write(Interval[] value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, - NpgsqlParameter? 
parameter, bool async, CancellationToken cancellationToken = default) - { - if (buf.WriteSpaceLeft < 4) - await buf.Flush(async, cancellationToken); - - buf.WriteInt32(value.Length); - - for (var i = 0; i < value.Length; i++) - await RangeHandler.WriteWithLength(value[i], buf, lengthCache, parameter: null, async, cancellationToken); - } - - public async Task Write(List value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, - NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - if (buf.WriteSpaceLeft < 4) - await buf.Flush(async, cancellationToken); - - buf.WriteInt32(value.Count); - - for (var i = 0; i < value.Count; i++) - await RangeHandler.WriteWithLength(value[i], buf, lengthCache, parameter: null, async, cancellationToken); - } - - #region Boilerplate - - ValueTask[]> INpgsqlTypeHandler[]>.Read( - NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => ReadMultirangeArray(buf, len, async, fieldDescription); - - ValueTask>> INpgsqlTypeHandler>>.Read( - NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => ReadMultirangeList(buf, len, async, fieldDescription); - - ValueTask[]> INpgsqlTypeHandler[]>.Read( - NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => ReadMultirangeArray(buf, len, async, fieldDescription); - - ValueTask>> INpgsqlTypeHandler>>.Read( - NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => ReadMultirangeList(buf, len, async, fieldDescription); - - ValueTask[]> INpgsqlTypeHandler[]>.Read( - NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => ReadMultirangeArray(buf, len, async, fieldDescription); - - ValueTask>> INpgsqlTypeHandler>>.Read( - NpgsqlReadBuffer buf, int len, bool async, FieldDescription? 
fieldDescription) - => ReadMultirangeList(buf, len, async, fieldDescription); - - ValueTask[]> INpgsqlTypeHandler[]>.Read( - NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => ReadMultirangeArray(buf, len, async, fieldDescription); - - ValueTask>> INpgsqlTypeHandler>>.Read( - NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => ReadMultirangeList(buf, len, async, fieldDescription); - - public int ValidateAndGetLength(NpgsqlRange[] value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLengthMultirange(value, ref lengthCache, parameter); - - public int ValidateAndGetLength(List> value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLengthMultirange(value, ref lengthCache, parameter); - - public int ValidateAndGetLength(NpgsqlRange[] value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLengthMultirange(value, ref lengthCache, parameter); - - public int ValidateAndGetLength(List> value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLengthMultirange(value, ref lengthCache, parameter); - - public int ValidateAndGetLength(NpgsqlRange[] value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLengthMultirange(value, ref lengthCache, parameter); - - public int ValidateAndGetLength(List> value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLengthMultirange(value, ref lengthCache, parameter); - - public int ValidateAndGetLength(NpgsqlRange[] value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLengthMultirange(value, ref lengthCache, parameter); - - public int ValidateAndGetLength(List> value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter) - => ValidateAndGetLengthMultirange(value, ref lengthCache, parameter); - - public Task Write(NpgsqlRange[] value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, - NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => WriteMultirange(value, buf, lengthCache, parameter, async, cancellationToken); - - public Task Write(List> value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, - NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => WriteMultirange(value, buf, lengthCache, parameter, async, cancellationToken); - - public Task Write(NpgsqlRange[] value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, - NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => WriteMultirange(value, buf, lengthCache, parameter, async, cancellationToken); - - public Task Write(List> value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, - NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => WriteMultirange(value, buf, lengthCache, parameter, async, cancellationToken); - - public Task Write(NpgsqlRange[] value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, - NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => WriteMultirange(value, buf, lengthCache, parameter, async, cancellationToken); - - public Task Write(List> value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, - NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => WriteMultirange(value, buf, lengthCache, parameter, async, cancellationToken); - - public Task Write(NpgsqlRange[] value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, - NpgsqlParameter? 
parameter, bool async, CancellationToken cancellationToken = default) - => WriteMultirange(value, buf, lengthCache, parameter, async, cancellationToken); - - public Task Write(List> value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, - NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => WriteMultirange(value, buf, lengthCache, parameter, async, cancellationToken); - - #endregion Boilerplate -} \ No newline at end of file diff --git a/src/Npgsql.NodaTime/Internal/TimestampTzRangeHandler.cs b/src/Npgsql.NodaTime/Internal/TimestampTzRangeHandler.cs deleted file mode 100644 index 8205cc17ef..0000000000 --- a/src/Npgsql.NodaTime/Internal/TimestampTzRangeHandler.cs +++ /dev/null @@ -1,105 +0,0 @@ -using System; -using System.Threading; -using System.Threading.Tasks; -using NodaTime; -using Npgsql.BackendMessages; -using Npgsql.Internal; -using Npgsql.Internal.TypeHandlers; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; -using NpgsqlTypes; - -namespace Npgsql.NodaTime.Internal; - -public partial class TimestampTzRangeHandler : RangeHandler, - INpgsqlTypeHandler, INpgsqlTypeHandler>, INpgsqlTypeHandler>, - INpgsqlTypeHandler>, INpgsqlTypeHandler> -{ - public override Type GetFieldType(FieldDescription? fieldDescription = null) => typeof(Interval); - - public TimestampTzRangeHandler(PostgresType rangePostgresType, NpgsqlTypeHandler subtypeHandler) - : base(rangePostgresType, subtypeHandler) - { - } - - public override async ValueTask ReadAsObject(NpgsqlReadBuffer buf, int len, bool async, - FieldDescription? fieldDescription = null) - => (await Read(buf, len, async, fieldDescription))!; - - // internal Interval ConvertRangetoInterval(NpgsqlRange range) - async ValueTask INpgsqlTypeHandler.Read( - NpgsqlReadBuffer buf, int len, bool async, FieldDescription? 
fieldDescription) - { - var range = await Read(buf, len, async, fieldDescription); - - // NodaTime Interval includes the start instant and excludes the end instant. - Instant? start = range.LowerBoundInfinite - ? null - : range.LowerBoundIsInclusive - ? range.LowerBound - : range.LowerBound + Duration.Epsilon; - Instant? end = range.UpperBoundInfinite - ? null - : range.UpperBoundIsInclusive - ? range.UpperBound + Duration.Epsilon - : range.UpperBound; - return new(start, end); - } - - public int ValidateAndGetLength(Interval value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLengthRange( - new NpgsqlRange(value.HasStart ? value.Start : default, true, !value.HasStart, value.HasEnd ? value.End : default, false, !value.HasEnd), ref lengthCache, parameter); - - public Task Write(Interval value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, - NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => WriteRange(new NpgsqlRange(value.HasStart ? value.Start : default, true, !value.HasStart, value.HasEnd ? value.End : default, false, !value.HasEnd), - buf, lengthCache, parameter, async, cancellationToken); - - #region Boilerplate - - ValueTask> INpgsqlTypeHandler>.Read( - NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => ReadRange(buf, len, async, fieldDescription); - - ValueTask> INpgsqlTypeHandler>.Read( - NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => ReadRange(buf, len, async, fieldDescription); - - ValueTask> INpgsqlTypeHandler>.Read( - NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => ReadRange(buf, len, async, fieldDescription); - - ValueTask> INpgsqlTypeHandler>.Read( - NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => ReadRange(buf, len, async, fieldDescription); - - public int ValidateAndGetLength(NpgsqlRange value, ref NpgsqlLengthCache? 
lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLengthRange(value, ref lengthCache, parameter); - - public int ValidateAndGetLength(NpgsqlRange value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLengthRange(value, ref lengthCache, parameter); - - public int ValidateAndGetLength(NpgsqlRange value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLengthRange(value, ref lengthCache, parameter); - - public int ValidateAndGetLength(NpgsqlRange value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLengthRange(value, ref lengthCache, parameter); - - public Task Write(NpgsqlRange value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, - NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => WriteRange(value, buf, lengthCache, parameter, async, cancellationToken); - - public Task Write(NpgsqlRange value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, - NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => WriteRange(value, buf, lengthCache, parameter, async, cancellationToken); - - public Task Write(NpgsqlRange value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, - NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => WriteRange(value, buf, lengthCache, parameter, async, cancellationToken); - - public Task Write(NpgsqlRange value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, - NpgsqlParameter? 
parameter, bool async, CancellationToken cancellationToken = default) - => WriteRange(value, buf, lengthCache, parameter, async, cancellationToken); - - #endregion Boilerplate -} \ No newline at end of file diff --git a/src/Npgsql.NodaTime/NpgsqlNodaTimeExtensions.cs b/src/Npgsql.NodaTime/NpgsqlNodaTimeExtensions.cs index 9fe67ec485..030f1ec1be 100644 --- a/src/Npgsql.NodaTime/NpgsqlNodaTimeExtensions.cs +++ b/src/Npgsql.NodaTime/NpgsqlNodaTimeExtensions.cs @@ -15,7 +15,7 @@ public static class NpgsqlNodaTimeExtensions /// The type mapper to set up (global or connection-specific) public static INpgsqlTypeMapper UseNodaTime(this INpgsqlTypeMapper mapper) { - mapper.AddTypeResolverFactory(new NodaTimeTypeHandlerResolverFactory()); + mapper.AddTypeInfoResolver(new NodaTimeTypeInfoResolver()); return mapper; } -} \ No newline at end of file +} diff --git a/src/Npgsql.NodaTime/Properties/AssemblyInfo.cs b/src/Npgsql.NodaTime/Properties/AssemblyInfo.cs index cf71b6d0b6..a03d5a93d6 100644 --- a/src/Npgsql.NodaTime/Properties/AssemblyInfo.cs +++ b/src/Npgsql.NodaTime/Properties/AssemblyInfo.cs @@ -4,7 +4,7 @@ [module: SkipLocalsInit] #endif -[assembly: InternalsVisibleTo("Npgsql.NodaTime.Tests, PublicKey=" + +[assembly: InternalsVisibleTo("Npgsql.PluginTests, PublicKey=" + "0024000004800000940000000602000000240000525341310004000001000100" + "2b3c590b2a4e3d347e6878dc0ff4d21eb056a50420250c6617044330701d35c9" + "8078a5df97a62d83c9a2db2d072523a8fc491398254c6b89329b8c1dcef43a1e" + diff --git a/src/Npgsql.NodaTime/Properties/NpgsqlNodaTimeStrings.Designer.cs b/src/Npgsql.NodaTime/Properties/NpgsqlNodaTimeStrings.Designer.cs index e47b9140b5..bc6511ea9a 100644 --- a/src/Npgsql.NodaTime/Properties/NpgsqlNodaTimeStrings.Designer.cs +++ b/src/Npgsql.NodaTime/Properties/NpgsqlNodaTimeStrings.Designer.cs @@ -50,5 +50,11 @@ internal static string CannotReadInfinityValue { return ResourceManager.GetString("CannotReadInfinityValue", resourceCulture); } } + + internal static string 
CannotReadIntervalWithMonthsAsDuration { + get { + return ResourceManager.GetString("CannotReadIntervalWithMonthsAsDuration", resourceCulture); + } + } } } diff --git a/src/Npgsql.NodaTime/Properties/NpgsqlNodaTimeStrings.resx b/src/Npgsql.NodaTime/Properties/NpgsqlNodaTimeStrings.resx index d05d0c3a62..d3329f2a80 100644 --- a/src/Npgsql.NodaTime/Properties/NpgsqlNodaTimeStrings.resx +++ b/src/Npgsql.NodaTime/Properties/NpgsqlNodaTimeStrings.resx @@ -21,4 +21,7 @@ Cannot read infinity value since Npgsql.DisableDateTimeInfinityConversions is enabled. - \ No newline at end of file + + Cannot read PostgreSQL interval with non-zero months to NodaTime Duration. Try reading as a NodaTime Period instead. + + diff --git a/src/Npgsql.SourceGenerators/Npgsql.SourceGenerators.csproj b/src/Npgsql.SourceGenerators/Npgsql.SourceGenerators.csproj index 434936cefe..bc0f37e9bb 100644 --- a/src/Npgsql.SourceGenerators/Npgsql.SourceGenerators.csproj +++ b/src/Npgsql.SourceGenerators/Npgsql.SourceGenerators.csproj @@ -27,7 +27,6 @@ - diff --git a/src/Npgsql.SourceGenerators/NpgsqlConnectionStringBuilderSourceGenerator.cs b/src/Npgsql.SourceGenerators/NpgsqlConnectionStringBuilderSourceGenerator.cs index f7008610b0..a25495a40a 100644 --- a/src/Npgsql.SourceGenerators/NpgsqlConnectionStringBuilderSourceGenerator.cs +++ b/src/Npgsql.SourceGenerators/NpgsqlConnectionStringBuilderSourceGenerator.cs @@ -1,4 +1,3 @@ -using System; using System.Collections.Generic; using System.Linq; using System.Text; diff --git a/src/Npgsql.SourceGenerators/TypeHandler.snbtxt b/src/Npgsql.SourceGenerators/TypeHandler.snbtxt deleted file mode 100644 index 041c948881..0000000000 --- a/src/Npgsql.SourceGenerators/TypeHandler.snbtxt +++ /dev/null @@ -1,36 +0,0 @@ -{{ for using in usings }} -using {{ using }}; -{{ end }} - -#nullable enable -#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member -#pragma warning disable RS0016 // Add public types and members to the declared API 
-#pragma warning disable 618 // Member is obsolete - -namespace {{ namespace }} -{ - partial class {{ type_name }} - { - public override int ValidateObjectAndGetLength(object value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => value switch - { - {{ for interface in interfaces }} - {{ interface.handled_type }} converted => (({{ interface.name }})this).ValidateAndGetLength(converted, {{ is_simple ? "" : "ref lengthCache, " }}parameter), - {{ end }} - - _ => throw new InvalidCastException($"Can't write CLR type {value.GetType()} with handler type {{ type_name }}") - }; - - public override Task WriteObjectWithLength(object? value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => value switch - { - {{ for interface in interfaces }} - {{ interface.handled_type }} converted => WriteWithLength(converted, buf, lengthCache, parameter, async, cancellationToken), - {{ end }} - - DBNull => WriteWithLength(DBNull.Value, buf, lengthCache, parameter, async, cancellationToken), - null => WriteWithLength(DBNull.Value, buf, lengthCache, parameter, async, cancellationToken), - _ => throw new InvalidCastException($"Can't write CLR type {value.GetType()} with handler type {{ type_name }}") - }; - } -} diff --git a/src/Npgsql.SourceGenerators/TypeHandlerSourceGenerator.cs b/src/Npgsql.SourceGenerators/TypeHandlerSourceGenerator.cs deleted file mode 100644 index d36cc41988..0000000000 --- a/src/Npgsql.SourceGenerators/TypeHandlerSourceGenerator.cs +++ /dev/null @@ -1,129 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp; -using Microsoft.CodeAnalysis.CSharp.Syntax; -using Microsoft.CodeAnalysis.Text; -using Scriban; - -namespace Npgsql.SourceGenerators; - -[Generator] -sealed class TypeHandlerSourceGenerator : ISourceGenerator -{ - public void 
Initialize(GeneratorInitializationContext context) - => context.RegisterForSyntaxNotifications(() => new MySyntaxReceiver()); - - public void Execute(GeneratorExecutionContext context) - { - var compilation = context.Compilation; - - var (simpleTypeHandlerInterfaceSymbol, typeHandlerInterfaceSymbol) = ( - compilation.GetTypeByMetadataName("Npgsql.Internal.TypeHandling.INpgsqlSimpleTypeHandler`1"), - compilation.GetTypeByMetadataName("Npgsql.Internal.TypeHandling.INpgsqlTypeHandler`1")); - - if (simpleTypeHandlerInterfaceSymbol is null || typeHandlerInterfaceSymbol is null) - throw new Exception("Could not find INpgsqlSimpleTypeHandler or INpgsqlTypeHandler"); - - var template = Template.Parse(EmbeddedResource.GetContent("TypeHandler.snbtxt"), "TypeHandler.snbtxt"); - - foreach (var cds in ((MySyntaxReceiver)context.SyntaxReceiver!).TypeHandlerCandidates) - { - var semanticModel = compilation.GetSemanticModel(cds.SyntaxTree); - if (semanticModel.GetDeclaredSymbol(cds) is not INamedTypeSymbol typeSymbol) - continue; - - if (typeSymbol.AllInterfaces.Any(i => - i.OriginalDefinition.Equals(simpleTypeHandlerInterfaceSymbol, SymbolEqualityComparer.Default))) - { - AugmentTypeHandler(template, typeSymbol, cds, isSimple: true); - continue; - } - - if (typeSymbol.AllInterfaces.Any(i => - i.OriginalDefinition.Equals(typeHandlerInterfaceSymbol, SymbolEqualityComparer.Default))) - { - AugmentTypeHandler(template, typeSymbol, cds, isSimple: false); - } - } - - void AugmentTypeHandler( - Template template, - INamedTypeSymbol typeSymbol, - ClassDeclarationSyntax classDeclarationSyntax, - bool isSimple) - { - var usings = new HashSet( - new[] - { - "System", - "System.Threading", - "System.Threading.Tasks", - "Npgsql.Internal" - }.Concat(classDeclarationSyntax.SyntaxTree.GetCompilationUnitRoot().Usings - .Where(u => u.Name is not null && u.Alias is null && u.StaticKeyword.IsKind(SyntaxKind.None)) - .Select(u => u.Name!.ToString()))); - - var interfaces = typeSymbol.AllInterfaces - 
.Where(i => i.OriginalDefinition.Equals(isSimple ? simpleTypeHandlerInterfaceSymbol : typeHandlerInterfaceSymbol, - SymbolEqualityComparer.Default)) - // Hacky: we want to emit switch arms for abstract types after concrete ones, since otherwise the compiled complains about - // unreachable arms - .OrderBy(i => i.TypeArguments[0].IsAbstract); - - var output = template.Render(new - { - Usings = usings, - TypeName = FormatTypeName(typeSymbol), - Namespace = typeSymbol.ContainingNamespace.ToDisplayString(), - IsSimple = isSimple, - Interfaces = interfaces.Select(i => new - { - Name = FormatTypeName(i), - HandledType = FormatTypeName(i.TypeArguments[0]), - }) - }); - - context.AddSource(typeSymbol.Name + ".Generated.cs", SourceText.From(output, Encoding.UTF8)); - } - - static string FormatTypeName(ITypeSymbol typeSymbol) - { - if (typeSymbol is INamedTypeSymbol namedTypeSymbol) - { - return namedTypeSymbol.IsGenericType - ? new StringBuilder(namedTypeSymbol.Name) - .Append('<') - .Append(string.Join(",", namedTypeSymbol.TypeArguments.Select(FormatTypeName))) - .Append('>') - .ToString() - : namedTypeSymbol.Name; - } - - if (typeSymbol.TypeKind == TypeKind.Array) - { - return $"{FormatTypeName(((IArrayTypeSymbol)typeSymbol).ElementType)}[]"; - // return "int"; - } - - return typeSymbol.ToString(); - } - } - - sealed class MySyntaxReceiver : ISyntaxReceiver - { - public List TypeHandlerCandidates { get; } = new(); - - public void OnVisitSyntaxNode(SyntaxNode syntaxNode) - { - if (syntaxNode is ClassDeclarationSyntax cds && - cds.BaseList is not null && - cds.Modifiers.Any(SyntaxKind.PartialKeyword)) - { - TypeHandlerCandidates.Add(cds); - } - } - } -} diff --git a/src/Npgsql/BackendMessages/AuthenticationMessages.cs b/src/Npgsql/BackendMessages/AuthenticationMessages.cs index 31a6c06e24..b6320e87b8 100644 --- a/src/Npgsql/BackendMessages/AuthenticationMessages.cs +++ b/src/Npgsql/BackendMessages/AuthenticationMessages.cs @@ -2,7 +2,6 @@ using System.Collections.Generic; 
using Microsoft.Extensions.Logging; using Npgsql.Internal; -using Npgsql.Util; namespace Npgsql.BackendMessages; @@ -136,7 +135,7 @@ sealed class AuthenticationSCRAMServerFirstMessage internal static AuthenticationSCRAMServerFirstMessage Load(byte[] bytes, ILogger connectionLogger) { - var data = PGUtil.UTF8Encoding.GetString(bytes); + var data = NpgsqlWriteBuffer.UTF8Encoding.GetString(bytes); string? nonce = null, salt = null; var iteration = -1; @@ -188,7 +187,7 @@ sealed class AuthenticationSCRAMServerFinalMessage internal static AuthenticationSCRAMServerFinalMessage Load(byte[] bytes, ILogger connectionLogger) { - var data = PGUtil.UTF8Encoding.GetString(bytes); + var data = NpgsqlWriteBuffer.UTF8Encoding.GetString(bytes); string? serverSignature = null; foreach (var part in data.Split(',')) diff --git a/src/Npgsql/BackendMessages/CopyMessages.cs b/src/Npgsql/BackendMessages/CopyMessages.cs index 67ee5da526..1aa8aec0c2 100644 --- a/src/Npgsql/BackendMessages/CopyMessages.cs +++ b/src/Npgsql/BackendMessages/CopyMessages.cs @@ -1,7 +1,6 @@ using System; using System.Collections.Generic; using Npgsql.Internal; -using Npgsql.Util; namespace Npgsql.BackendMessages; @@ -11,11 +10,11 @@ abstract class CopyResponseMessageBase : IBackendMessage internal bool IsBinary { get; private set; } internal short NumColumns { get; private set; } - internal List ColumnFormatCodes { get; } + internal List ColumnFormatCodes { get; } internal CopyResponseMessageBase() { - ColumnFormatCodes = new List(); + ColumnFormatCodes = new List(); } internal void Load(NpgsqlReadBuffer buf) @@ -32,7 +31,7 @@ internal void Load(NpgsqlReadBuffer buf) NumColumns = buf.ReadInt16(); for (var i = 0; i < NumColumns; i++) - ColumnFormatCodes.Add((FormatCode)buf.ReadInt16()); + ColumnFormatCodes.Add(DataFormatUtils.Create(buf.ReadInt16())); } } @@ -91,4 +90,4 @@ sealed class CopyDoneMessage : IBackendMessage public BackendMessageCode Code => BackendMessageCode.CopyDone; internal static readonly 
CopyDoneMessage Instance = new(); CopyDoneMessage() { } -} \ No newline at end of file +} diff --git a/src/Npgsql/BackendMessages/RowDescriptionMessage.cs b/src/Npgsql/BackendMessages/RowDescriptionMessage.cs index b2a9a6f111..a963d165c9 100644 --- a/src/Npgsql/BackendMessages/RowDescriptionMessage.cs +++ b/src/Npgsql/BackendMessages/RowDescriptionMessage.cs @@ -3,14 +3,12 @@ using System.Collections.Generic; using System.Diagnostics; using System.Globalization; +using System.Runtime.CompilerServices; +using System.Threading; using Npgsql.Internal; -using Npgsql.Internal.TypeHandlers; -using Npgsql.Internal.TypeHandling; -using Npgsql.Internal.TypeMapping; +using Npgsql.Internal.Postgres; using Npgsql.PostgresTypes; using Npgsql.Replication.PgOutput.Messages; -using Npgsql.TypeMapping; -using Npgsql.Util; namespace Npgsql.BackendMessages; @@ -22,12 +20,15 @@ namespace Npgsql.BackendMessages; /// sealed class RowDescriptionMessage : IBackendMessage, IReadOnlyList { + readonly bool _connectorOwned; FieldDescription?[] _fields; readonly Dictionary _nameIndex; Dictionary? _insensitiveIndex; + PgConverterInfo[]? 
_lastConverterInfoCache; - internal RowDescriptionMessage(int numFields = 10) + internal RowDescriptionMessage(bool connectorOwned, int numFields = 10) { + _connectorOwned = connectorOwned; _fields = new FieldDescription[numFields]; _nameIndex = new Dictionary(); } @@ -43,7 +44,7 @@ internal RowDescriptionMessage(int numFields = 10) _insensitiveIndex = new Dictionary(source._insensitiveIndex); } - internal RowDescriptionMessage Load(NpgsqlReadBuffer buf, TypeMapper typeMapper) + internal RowDescriptionMessage Load(NpgsqlReadBuffer buf, PgSerializerOptions options) { _nameIndex.Clear(); _insensitiveIndex?.Clear(); @@ -61,14 +62,14 @@ internal RowDescriptionMessage Load(NpgsqlReadBuffer buf, TypeMapper typeMapper) var field = _fields[i] ??= new(); field.Populate( - typeMapper, + options, name: buf.ReadNullTerminatedString(), tableOID: buf.ReadUInt32(), columnAttributeNumber: buf.ReadInt16(), oid: buf.ReadUInt32(), typeSize: buf.ReadInt16(), typeModifier: buf.ReadInt32(), - formatCode: (FormatCode)buf.ReadInt16() + dataFormat: DataFormatUtils.Create(buf.ReadInt16()) ); _nameIndex.TryAdd(field.Name, i); @@ -78,9 +79,9 @@ internal RowDescriptionMessage Load(NpgsqlReadBuffer buf, TypeMapper typeMapper) } internal static RowDescriptionMessage CreateForReplication( - TypeMapper typeMapper, uint tableOID, FormatCode formatCode, IReadOnlyList columns) + PgSerializerOptions options, uint tableOID, DataFormat dataFormat, IReadOnlyList columns) { - var msg = new RowDescriptionMessage(columns.Count); + var msg = new RowDescriptionMessage(false, columns.Count); var numFields = msg.Count = columns.Count; for (var i = 0; i < numFields; ++i) @@ -89,14 +90,14 @@ internal static RowDescriptionMessage CreateForReplication( var column = columns[i]; field.Populate( - typeMapper, - name: column.ColumnName, - tableOID: tableOID, + options, + name: column.ColumnName, + tableOID: tableOID, columnAttributeNumber: checked((short)i), - oid: column.DataTypeId, - typeSize: 0, // TODO: Confirm we 
don't have this in replication - typeModifier: column.TypeModifier, - formatCode: formatCode + oid: column.DataTypeId, + typeSize: 0, // TODO: Confirm we don't have this in replication + typeModifier: column.TypeModifier, + dataFormat: dataFormat ); if (!msg._nameIndex.ContainsKey(field.Name)) @@ -108,6 +109,7 @@ internal static RowDescriptionMessage CreateForReplication( public FieldDescription this[int index] { + [MethodImpl(MethodImplOptions.AggressiveInlining)] get { Debug.Assert(index < Count); @@ -117,6 +119,20 @@ public FieldDescription this[int index] } } + internal void SetConverterInfoCache(ReadOnlySpan values) + { + if (_connectorOwned || _lastConverterInfoCache is not null) + return; + Interlocked.CompareExchange(ref _lastConverterInfoCache, values.ToArray(), null); + } + + internal void LoadConverterInfoCache(PgConverterInfo[] values) + { + if (_lastConverterInfoCache is not { } cache) + return; + cache.CopyTo(values.AsSpan()); + } + public int Count { get; private set; } public IEnumerator GetEnumerator() => new Enumerator(this); @@ -164,7 +180,7 @@ sealed class InsensitiveComparer : IEqualityComparer public static readonly InsensitiveComparer Instance = new(); static readonly CompareInfo CompareInfo = CultureInfo.InvariantCulture.CompareInfo; - InsensitiveComparer() {} + InsensitiveComparer() { } // We should really have CompareOptions.IgnoreKanaType here, but see // https://github.com/dotnet/corefx/issues/12518#issuecomment-389658716 @@ -204,7 +220,7 @@ public bool MoveNext() } public void Reset() => _pos = -1; - public void Dispose() {} + public void Dispose() { } } } @@ -215,14 +231,14 @@ public void Dispose() {} public sealed class FieldDescription { #pragma warning disable CS8618 // Lazy-initialized type - internal FieldDescription() {} + internal FieldDescription() { } internal FieldDescription(uint oid) - : this("?", 0, 0, oid, 0, 0, FormatCode.Binary) {} + : this("?", 0, 0, oid, 0, 0, DataFormat.Binary) { } internal FieldDescription( string 
name, uint tableOID, short columnAttributeNumber, - uint oid, short typeSize, int typeModifier, FormatCode formatCode) + uint oid, short typeSize, int typeModifier, DataFormat dataFormat) { Name = name; TableOID = tableOID; @@ -230,38 +246,41 @@ internal FieldDescription( TypeOID = oid; TypeSize = typeSize; TypeModifier = typeModifier; - FormatCode = formatCode; + DataFormat = dataFormat; } #pragma warning restore CS8618 internal FieldDescription(FieldDescription source) { - _typeMapper = source._typeMapper; + _serializerOptions = source._serializerOptions; Name = source.Name; TableOID = source.TableOID; ColumnAttributeNumber = source.ColumnAttributeNumber; TypeOID = source.TypeOID; TypeSize = source.TypeSize; TypeModifier = source.TypeModifier; - FormatCode = source.FormatCode; - Handler = source.Handler; + DataFormat = source.DataFormat; + PostgresType = source.PostgresType; + Field = source.Field; + _objectOrDefaultInfo = source._objectOrDefaultInfo; } internal void Populate( - TypeMapper typeMapper, string name, uint tableOID, short columnAttributeNumber, - uint oid, short typeSize, int typeModifier, FormatCode formatCode + PgSerializerOptions serializerOptions, string name, uint tableOID, short columnAttributeNumber, + uint oid, short typeSize, int typeModifier, DataFormat dataFormat ) { - _typeMapper = typeMapper; + _serializerOptions = serializerOptions; Name = name; TableOID = tableOID; ColumnAttributeNumber = columnAttributeNumber; TypeOID = oid; TypeSize = typeSize; TypeModifier = typeModifier; - FormatCode = formatCode; - - ResolveHandler(); + DataFormat = dataFormat; + PostgresType = _serializerOptions.DatabaseInfo.FindPostgresType((Oid)TypeOID)?.GetRepresentationalType() ?? UnknownBackendType.Instance; + Field = new(Name, _serializerOptions.ToCanonicalTypeId(PostgresType), TypeModifier); + _objectOrDefaultInfo = default; } /// @@ -296,43 +315,94 @@ internal void Populate( /// /// The format code being used for the field. 
- /// Currently will be zero (text) or one (binary). + /// Currently will be text or binary. /// In a RowDescription returned from the statement variant of Describe, the format code is not yet known and will always be zero. /// - internal FormatCode FormatCode { get; set; } + internal DataFormat DataFormat { get; set; } - internal string TypeDisplayName => PostgresType.GetDisplayNameWithFacets(TypeModifier); + internal Field Field { get; private set; } - /// - /// The Npgsql type handler assigned to handle this field. - /// Returns for fields with format text. - /// - internal NpgsqlTypeHandler Handler { get; private set; } + internal string TypeDisplayName => PostgresType.GetDisplayNameWithFacets(TypeModifier); - internal PostgresType PostgresType - => _typeMapper.DatabaseInfo.ByOID.TryGetValue(TypeOID, out var postgresType) - ? postgresType - : UnknownBackendType.Instance; + internal PostgresType PostgresType { get; private set; } - internal Type FieldType => Handler.GetFieldType(this); + internal Type FieldType => ObjectOrDefaultInfo.TypeToConvert; - internal void ResolveHandler() - => Handler = IsBinaryFormat ? _typeMapper.ResolveByOID(TypeOID) : _typeMapper.UnrecognizedTypeHandler; + PgConverterInfo _objectOrDefaultInfo; + internal PgConverterInfo ObjectOrDefaultInfo + { + get + { + if (!_objectOrDefaultInfo.IsDefault) + return _objectOrDefaultInfo; - TypeMapper _typeMapper; + ref var info = ref _objectOrDefaultInfo; + GetInfo(null, ref _objectOrDefaultInfo); + return info; + } + } - internal bool IsBinaryFormat => FormatCode == FormatCode.Binary; - internal bool IsTextFormat => FormatCode == FormatCode.Text; + PgSerializerOptions _serializerOptions; internal FieldDescription Clone() { - var field = new FieldDescription(this); - field.ResolveHandler(); + var field = new FieldDescription(this); return field; } + internal void GetInfo(Type? 
type, ref PgConverterInfo lastConverterInfo) + { + Debug.Assert(lastConverterInfo.IsDefault || ( + ReferenceEquals(_serializerOptions, lastConverterInfo.TypeInfo.Options) && + lastConverterInfo.TypeInfo.PgTypeId == _serializerOptions.ToCanonicalTypeId(PostgresType)), "Cache is bleeding over"); + + if (!lastConverterInfo.IsDefault && lastConverterInfo.TypeToConvert == type) + return; + + // Have to check for null as it's a sentinel value used by ObjectOrDefaultTypeInfo init itself. + if (type is not null && ObjectOrDefaultInfo is var odfInfo) + { + if (typeof(object) == type) + { + lastConverterInfo = odfInfo with { AsObject = true }; + return; + } + if (odfInfo.TypeToConvert == type) + { + lastConverterInfo = odfInfo; + return; + } + } + + GetInfoSlow(out lastConverterInfo); + + [MethodImpl(MethodImplOptions.NoInlining)] + void GetInfoSlow(out PgConverterInfo lastConverterInfo) + { + PgConverterInfo converterInfo; + var typeInfo = AdoSerializerHelpers.GetTypeInfoForReading(type ?? typeof(object), PostgresType, _serializerOptions); + switch (DataFormat) + { + case DataFormat.Binary: + // If we don't support binary we'll just throw. + converterInfo = typeInfo.Bind(Field, DataFormat); + break; + default: + // For text we'll fall back to any available text converter for the expected clr type or throw. + if (!typeInfo.TryBind(Field, DataFormat, out converterInfo)) + { + typeInfo = AdoSerializerHelpers.GetTypeInfoForReading(type ?? typeof(string), _serializerOptions.UnknownPgType, _serializerOptions); + converterInfo = typeInfo.Bind(Field, DataFormat); + } + break; + } + + lastConverterInfo = converterInfo; + } + } + /// /// Returns a string that represents the current object. /// - public override string ToString() => Name + (Handler == null ? 
"" : $"({Handler.PgDisplayName})"); + public override string ToString() => Name + $"({PostgresType.DisplayName})"; } diff --git a/src/Npgsql/Internal/AdoSerializerHelpers.cs b/src/Npgsql/Internal/AdoSerializerHelpers.cs new file mode 100644 index 0000000000..f9e63e7a40 --- /dev/null +++ b/src/Npgsql/Internal/AdoSerializerHelpers.cs @@ -0,0 +1,58 @@ +using System; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using Npgsql.Internal.Postgres; +using Npgsql.PostgresTypes; +using NpgsqlTypes; + +namespace Npgsql.Internal; + +static class AdoSerializerHelpers +{ + public static PgTypeInfo GetTypeInfoForReading(Type type, PostgresType postgresType, PgSerializerOptions options) + { + PgTypeInfo? typeInfo = null; + Exception? inner = null; + try + { + typeInfo = type == typeof(object) ? options.GetObjectOrDefaultTypeInfo(postgresType) : options.GetTypeInfo(type, postgresType); + } + catch (Exception ex) + { + inner = ex; + } + return typeInfo ?? ThrowReadingNotSupported(type, postgresType.DisplayName, inner); + + // InvalidCastException thrown to align with ADO.NET convention. + [DoesNotReturn] + static PgTypeInfo ThrowReadingNotSupported(Type? type, string displayName, Exception? inner = null) + => throw new InvalidCastException($"Reading{(type is null ? "" : $" as '{type.FullName}'")} is not supported for fields having DataTypeName '{displayName}'", inner); + } + + public static PgTypeInfo GetTypeInfoForWriting(Type? type, PgTypeId? pgTypeId, PgSerializerOptions options, NpgsqlDbType? npgsqlDbType = null) + { + Debug.Assert(type != typeof(object), "Parameters of type object are not supported."); + + PgTypeInfo? typeInfo = null; + Exception? inner = null; + try + { + typeInfo = type is null ? options.GetDefaultTypeInfo(pgTypeId!.Value) : options.GetTypeInfo(type, pgTypeId); + } + catch (Exception ex) + { + inner = ex; + } + return typeInfo ?? ThrowWritingNotSupported(type, + pgTypeString: + pgTypeId is null ? "no NpgsqlDbType or DataTypeName. 
Try setting one of these values to the expected database type." : + npgsqlDbType is null + ? $"DataTypeName '{options.DatabaseInfo.FindPostgresType(pgTypeId.GetValueOrDefault())?.DisplayName ?? "unknown"}'" + : $"NpgsqlDbType '{npgsqlDbType}'", inner); + + // InvalidCastException thrown to align with ADO.NET convention. + [DoesNotReturn] + static PgTypeInfo ThrowWritingNotSupported(Type? type, string pgTypeString, Exception? inner = null) + => throw new InvalidCastException($"Writing{(type is null ? "" : $" values of '{type.FullName}'")} is not supported for parameters having {pgTypeString}.", inner); + } +} diff --git a/src/Npgsql/Internal/BufferRequirements.cs b/src/Npgsql/Internal/BufferRequirements.cs new file mode 100644 index 0000000000..cd32c0cbd1 --- /dev/null +++ b/src/Npgsql/Internal/BufferRequirements.cs @@ -0,0 +1,43 @@ +using System; + +namespace Npgsql.Internal; + +public readonly struct BufferRequirements : IEquatable +{ + readonly Size _read; + readonly Size _write; + + BufferRequirements(Size read, Size write) + { + _read = read; + _write = write; + } + + public Size Read => _read; + public Size Write => _write; + + /// Streaming + public static BufferRequirements None => new(Size.Unknown, Size.Unknown); + /// Entire value should be buffered + public static BufferRequirements Value => new(Size.CreateUpperBound(int.MaxValue), Size.CreateUpperBound(int.MaxValue)); + /// Fixed size value should be buffered + public static BufferRequirements CreateFixedSize(int byteCount) => new(byteCount, byteCount); + /// Custom requirements + public static BufferRequirements Create(Size value) => new(value, value); + public static BufferRequirements Create(Size read, Size write) => new(read, write); + + public BufferRequirements Combine(Size read, Size write) + => new(_read.Combine(read), _write.Combine(write)); + + public BufferRequirements Combine(BufferRequirements other) + => Combine(other._read, other._write); + + public BufferRequirements Combine(int 
byteCount) + => Combine(CreateFixedSize(byteCount)); + + public bool Equals(BufferRequirements other) => _read.Equals(other._read) && _write.Equals(other._write); + public override bool Equals(object? obj) => obj is BufferRequirements other && Equals(other); + public override int GetHashCode() => HashCode.Combine(_read, _write); + public static bool operator ==(BufferRequirements left, BufferRequirements right) => left.Equals(right); + public static bool operator !=(BufferRequirements left, BufferRequirements right) => !left.Equals(right); +} diff --git a/src/Npgsql/Internal/Composites/Metadata/CompositeBuilder.cs b/src/Npgsql/Internal/Composites/Metadata/CompositeBuilder.cs new file mode 100644 index 0000000000..c51c0dafa0 --- /dev/null +++ b/src/Npgsql/Internal/Composites/Metadata/CompositeBuilder.cs @@ -0,0 +1,109 @@ +using System; +using System.Buffers; +using Npgsql.Util; + +namespace Npgsql.Internal.Composites; + +abstract class CompositeBuilder +{ + protected StrongBox[] _tempBoxes; + protected int _currentField; + + protected CompositeBuilder(StrongBox[] tempBoxes) => _tempBoxes = tempBoxes; + + protected abstract void Construct(); + protected abstract void SetField(TValue value); + + public void AddValue(TValue value) + { + var tempBoxes = _tempBoxes; + var currentField = _currentField; + if (currentField >= tempBoxes.Length) + { + if (currentField == tempBoxes.Length) + Construct(); + SetField(value); + } + else + { + ((StrongBox)tempBoxes[currentField]).TypedValue = value; + if (currentField + 1 == tempBoxes.Length) + Construct(); + } + + _currentField++; + } +} + +sealed class CompositeBuilder : CompositeBuilder, IDisposable +{ + readonly CompositeInfo _compositeInfo; + T _instance = default!; + object? 
_boxedInstance; + + public CompositeBuilder(CompositeInfo compositeInfo) + : base(compositeInfo.CreateTempBoxes()) + => _compositeInfo = compositeInfo; + + public T Complete() + { + if (_currentField < _compositeInfo.Fields.Count) + throw new InvalidOperationException($"Missing values, expected: {_compositeInfo.Fields.Count} got: {_currentField}"); + + return (T)(_boxedInstance ?? _instance!); + } + + public void Reset() + { + _instance = default!; + _boxedInstance = null; + _currentField = 0; + foreach (var box in _tempBoxes) + box.Clear(); + } + + public void Dispose() => Reset(); + + protected override void Construct() + { + var tempBoxes = _tempBoxes; + if (_currentField < tempBoxes.Length - 1) + throw new InvalidOperationException($"Missing values, expected: {tempBoxes.Length} got: {_currentField + 1}"); + + var fields = _compositeInfo.Fields; + var args = ArrayPool.Shared.Rent(_compositeInfo.ConstructorParameters); + for (var i = 0; i < tempBoxes.Length; i++) + { + var field = fields[i]; + if (field.ConstructorParameterIndex is { } argIndex) + args[argIndex] = tempBoxes[i]; + } + _instance = _compositeInfo.Constructor(args)!; + ArrayPool.Shared.Return(args); + + if (tempBoxes.Length == _compositeInfo.Fields.Count) + return; + + // We're expecting or already have stored more fields, so box the instance once here. 
+ _boxedInstance = _instance; + for (var i = 0; i < tempBoxes.Length; i++) + { + var field = _compositeInfo.Fields[i]; + if (field.ConstructorParameterIndex is null) + field.Set(_boxedInstance, tempBoxes[i]); + } + } + + protected override void SetField(TValue value) + { + if (_boxedInstance is null) + ThrowHelper.ThrowInvalidOperationException("Not constructed yet, or no more fields were expected."); + + var currentField = _currentField; + var fields = _compositeInfo.Fields; + if (currentField > fields.Count - 1) + ThrowHelper.ThrowIndexOutOfRangeException($"Cannot set field {value} at position {currentField} - all fields have already been set"); + + ((CompositeFieldInfo)fields[currentField]).Set(_boxedInstance, value); + } +} diff --git a/src/Npgsql/Internal/Composites/Metadata/CompositeFieldInfo.cs b/src/Npgsql/Internal/Composites/Metadata/CompositeFieldInfo.cs new file mode 100644 index 0000000000..765399bf76 --- /dev/null +++ b/src/Npgsql/Internal/Composites/Metadata/CompositeFieldInfo.cs @@ -0,0 +1,192 @@ +using System; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Npgsql.Internal.Postgres; +using Npgsql.Util; + +namespace Npgsql.Internal.Composites; + +abstract class CompositeFieldInfo +{ + protected PgConverter Converter { get; } + protected BufferRequirements _binaryBufferRequirements; + + private protected CompositeFieldInfo(string name, PgConverterResolution resolution) + { + Name = name; + Converter = resolution.Converter; + PgTypeId = resolution.PgTypeId; + + if (!Converter.CanConvert(DataFormat.Binary, out _binaryBufferRequirements)) + throw new InvalidOperationException("Converter must support binary format to participate in composite types."); + } + + protected PgConverter GetConverter() => (PgConverter)Converter; + + protected ValueTask ReadAsObject(bool async, CompositeBuilder builder, PgReader reader, CancellationToken cancellationToken) + { + if (async) + { + var task = 
Converter.ReadAsObjectAsync(reader, cancellationToken); + if (!task.IsCompletedSuccessfully) + return Core(builder, task); + + AddValue(builder, task.Result); + } + else + AddValue(builder, Converter.ReadAsObject(reader)); + return new(); +#if NET6_0_OR_GREATER + [AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder))] +#endif + async ValueTask Core(CompositeBuilder builder, ValueTask task) + { + builder.AddValue(await task.ConfigureAwait(false)); + } + } + + protected ValueTask WriteAsObject(bool async, PgWriter writer, object value, CancellationToken cancellationToken) + { + if (async) + return Converter.WriteAsObjectAsync(writer, value, cancellationToken); + + Converter.WriteAsObject(writer, value); + return new(); + } + + public string Name { get; } + public PgTypeId PgTypeId { get; } + public Size BinaryReadRequirement => _binaryBufferRequirements.Read; + public Size BinaryWriteRequirement => _binaryBufferRequirements.Write; + + public abstract Type Type { get; } + + protected abstract void AddValue(CompositeBuilder builder, object value); + + public abstract StrongBox CreateBox(); + public abstract void Set(object instance, StrongBox value); + public abstract int? ConstructorParameterIndex { get; } + public abstract bool IsDbNullable { get; } + + public abstract void ReadDbNull(CompositeBuilder builder); + public abstract ValueTask Read(bool async, CompositeBuilder builder, PgReader reader, CancellationToken cancellationToken = default); + public abstract bool IsDbNull(object instance); + public abstract Size? GetSizeOrDbNull(DataFormat format, object instance, ref object? writeState); + public abstract ValueTask Write(bool async, PgWriter writer, object instance, CancellationToken cancellationToken); +} + +sealed class CompositeFieldInfo : CompositeFieldInfo +{ + readonly Action? 
_setter; + readonly int _parameterIndex; + readonly Func _getter; + readonly bool _asObject; + + CompositeFieldInfo(string name, PgConverterResolution resolution, Func getter) + : base(name, resolution) + { + var typeToConvert = resolution.Converter.TypeToConvert; + if (!typeToConvert.IsAssignableFrom(typeof(T))) + throw new InvalidOperationException($"Converter type '{typeToConvert}' must be assignable from field type '{typeof(T)}'."); + + _getter = getter; + _asObject = typeToConvert != typeof(T); + } + + public CompositeFieldInfo(string name, PgConverterResolution resolution, Func getter, int parameterIndex) + : this(name, resolution, getter) + => _parameterIndex = parameterIndex; + + public CompositeFieldInfo(string name, PgConverterResolution resolution, Func getter, Action setter) + : this(name, resolution, getter) + => _setter = setter; + + public override Type Type => typeof(T); + + public override int? ConstructorParameterIndex => _setter is not null ? null : _parameterIndex; + + public T Get(object instance) => _getter(instance); + + public override StrongBox CreateBox() => new Util.StrongBox(); + + public void Set(object instance, T value) + { + if (_setter is null) + throw new InvalidOperationException("Not a composite field for a clr field."); + + _setter(instance, value); + } + + public override void Set(object instance, StrongBox value) + { + if (_setter is null) + throw new InvalidOperationException("Not a composite field for a clr field."); + + _setter(instance, ((Util.StrongBox)value).TypedValue!); + } + + public override void ReadDbNull(CompositeBuilder builder) + { + if (default(T) != null) + throw new InvalidCastException($"Type {typeof(T).FullName} does not have null as a possible value."); + + builder.AddValue((T?)default); + } + + protected override void AddValue(CompositeBuilder builder, object value) => builder.AddValue((T)value); + + public override ValueTask Read(bool async, CompositeBuilder builder, PgReader reader, CancellationToken 
cancellationToken = default) + { + if (_asObject) + return ReadAsObject(async, builder, reader, cancellationToken); + + if (async) + { + var task = GetConverter().ReadAsync(reader, cancellationToken); + if (!task.IsCompletedSuccessfully) + return Core(builder, task); + + builder.AddValue(task.Result); + } + else + builder.AddValue(GetConverter().Read(reader)); + return new(); +#if NET6_0_OR_GREATER + [AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder))] +#endif + async ValueTask Core(CompositeBuilder builder, ValueTask task) + { + builder.AddValue(await task.ConfigureAwait(false)); + } + } + + public override bool IsDbNullable => Converter.IsDbNullable; + + public override bool IsDbNull(object instance) + { + var value = _getter(instance); + return _asObject ? Converter.IsDbNullAsObject(value) : GetConverter().IsDbNull(value); + } + + public override Size? GetSizeOrDbNull(DataFormat format, object instance, ref object? writeState) + { + var value = _getter(instance); + return _asObject + ? 
Converter.GetSizeOrDbNullAsObject(format, _binaryBufferRequirements.Write, value, ref writeState) + : GetConverter().GetSizeOrDbNull(format, _binaryBufferRequirements.Write, value, ref writeState); + } + + public override ValueTask Write(bool async, PgWriter writer, object instance, CancellationToken cancellationToken) + { + var value = _getter(instance); + if (_asObject) + return WriteAsObject(async, writer, value!, cancellationToken); + + if (async) + return GetConverter().WriteAsync(writer, value!, cancellationToken); + + GetConverter().Write(writer, value!); + return new(); + } +} diff --git a/src/Npgsql/Internal/Composites/Metadata/CompositeInfo.cs b/src/Npgsql/Internal/Composites/Metadata/CompositeInfo.cs new file mode 100644 index 0000000000..95a2c316a1 --- /dev/null +++ b/src/Npgsql/Internal/Composites/Metadata/CompositeInfo.cs @@ -0,0 +1,74 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using Npgsql.Util; + +namespace Npgsql.Internal.Composites; + +sealed class CompositeInfo +{ + readonly int _lastConstructorFieldIndex; + readonly CompositeFieldInfo[] _fields; + + public CompositeInfo(CompositeFieldInfo[] fields, int? constructorParameters, Func? constructor) + { + _lastConstructorFieldIndex = -1; + for (var i = fields.Length - 1; i >= 0; i--) + if (fields[i].ConstructorParameterIndex is not null) + { + _lastConstructorFieldIndex = i; + break; + } + + var parameterSum = 0; + for(var i = constructorParameters - 1 ?? 
0; i > 0; i--) + parameterSum += i; + + var argumentsSum = 0; + if (parameterSum > 0) + { + foreach (var field in fields) + if (field.ConstructorParameterIndex is { } index) + argumentsSum += index; + } + + if (parameterSum != argumentsSum) + throw new InvalidOperationException($"Missing composite fields to map to the required {constructorParameters} constructor parameters."); + + _fields = fields; + if (constructor is null) + Constructor = _ => Activator.CreateInstance(); + else + { + var arguments = new CompositeFieldInfo[constructorParameters.GetValueOrDefault()]; + foreach (var field in fields) + { + if (field.ConstructorParameterIndex is { } index) + arguments[index] = field; + } + Constructor = constructor; + } + + ConstructorParameters = constructorParameters ?? 0; + } + + public IReadOnlyList Fields => _fields; + + public int ConstructorParameters { get; } + public Func Constructor { get; } + + /// + /// Create temporary storage for all values that come before the constructor parameters can be saturated. + /// + /// + public StrongBox[] CreateTempBoxes() + { + var valueCache = _lastConstructorFieldIndex + 1 is 0 ? 
Array.Empty() : new StrongBox[_lastConstructorFieldIndex + 1]; + var fields = _fields; + + for (var i = 0; i < valueCache.Length; i++) + valueCache[i] = fields[i].CreateBox(); + + return valueCache; + } +} diff --git a/src/Npgsql/Internal/Composites/ReflectionCompositeInfoFactory.cs b/src/Npgsql/Internal/Composites/ReflectionCompositeInfoFactory.cs new file mode 100644 index 0000000000..1fe217f5dc --- /dev/null +++ b/src/Npgsql/Internal/Composites/ReflectionCompositeInfoFactory.cs @@ -0,0 +1,296 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using Npgsql.PostgresTypes; +using Npgsql.Util; +using NpgsqlTypes; + +namespace Npgsql.Internal.Composites; + +static class ReflectionCompositeInfoFactory +{ + public static CompositeInfo CreateCompositeInfo(PostgresCompositeType pgType, INpgsqlNameTranslator nameTranslator, PgSerializerOptions options) + { + var pgFields = pgType.Fields; + var propertyMap = MapProperties(pgFields, nameTranslator); + var fieldMap = MapFields(pgFields, nameTranslator); + + var duplicates = propertyMap.Keys.Intersect(fieldMap.Keys).ToArray(); + if (duplicates.Length > 0) + throw new AmbiguousMatchException($"Property {propertyMap[duplicates[0]].Name} and field {fieldMap[duplicates[0]].Name} map to the same '{pgFields[duplicates[0]].Name}' composite field name."); + + var (constructorInfo, parameterFieldMap) = MapBestMatchingConstructor(pgFields, nameTranslator); + var constructorParameters = constructorInfo?.GetParameters() ?? 
Array.Empty(); + var compositeFields = new CompositeFieldInfo?[pgFields.Count]; + for (var i = 0; i < parameterFieldMap.Length; i++) + { + var fieldIndex = parameterFieldMap[i]; + var pgField = pgFields[fieldIndex]; + var parameter = constructorParameters[i]; + PgTypeInfo pgTypeInfo; + Delegate getter; + if (propertyMap.TryGetValue(fieldIndex, out var property) && property.GetMethod is not null) + { + if (property.PropertyType != parameter.ParameterType) + throw new InvalidOperationException($"Could not find a matching getter for constructor parameter {parameter.Name} and type {parameter.ParameterType} mapped to composite field {pgFields[fieldIndex].Name}."); + + pgTypeInfo = options.GetTypeInfo(property.PropertyType, pgField.Type.GetRepresentationalType()) ?? throw NotSupportedField(pgType, pgField, isField: false, property.Name, property.PropertyType); + getter = CreateGetter(property); + } + else if (fieldMap.TryGetValue(fieldIndex, out var field)) + { + if (field.FieldType != parameter.ParameterType) + throw new InvalidOperationException($"Could not find a matching getter for constructor parameter {parameter.Name} and type {parameter.ParameterType} mapped to composite field {pgFields[fieldIndex].Name}."); + + pgTypeInfo = options.GetTypeInfo(field.FieldType, pgField.Type.GetRepresentationalType()) ?? throw NotSupportedField(pgType, pgField, isField: true, field.Name, field.FieldType); + getter = CreateGetter(field); + } + else + throw new InvalidOperationException($"Cannot find property or field for composite field {pgFields[fieldIndex].Name}."); + + compositeFields[fieldIndex] = CreateCompositeFieldInfo(pgField.Name, pgTypeInfo.Type, MapResolution(pgField, pgTypeInfo.GetConcreteResolution()), getter, i); + } + + for (var fieldIndex = 0; fieldIndex < pgFields.Count; fieldIndex++) + { + // Handled by constructor. 
+ if (compositeFields[fieldIndex] is not null) + continue; + + var pgField = pgFields[fieldIndex]; + PgTypeInfo pgTypeInfo; + Delegate getter; + Delegate setter; + if (propertyMap.TryGetValue(fieldIndex, out var property)) + { + pgTypeInfo = options.GetTypeInfo(property.PropertyType, pgField.Type.GetRepresentationalType()) + ?? throw NotSupportedField(pgType, pgField, isField: false, property.Name, property.PropertyType); + getter = CreateGetter(property); + setter = CreateSetter(property); + } + else if (fieldMap.TryGetValue(fieldIndex, out var field)) + { + pgTypeInfo = options.GetTypeInfo(field.FieldType, pgField.Type.GetRepresentationalType()) + ?? throw NotSupportedField(pgType, pgField, isField: true, field.Name, field.FieldType); + getter = CreateGetter(field); + setter = CreateSetter(field); + } + else + throw new InvalidOperationException($"Cannot find property or field for composite field '{pgFields[fieldIndex].Name}'."); + + compositeFields[fieldIndex] = CreateCompositeFieldInfo(pgField.Name, pgTypeInfo.Type, MapResolution(pgField, pgTypeInfo.GetConcreteResolution()), getter, setter); + } + + Debug.Assert(compositeFields.All(x => x is not null)); + + var constructor = constructorInfo is null ? null : CreateStrongBoxConstructor(constructorInfo); + return new CompositeInfo(compositeFields!, constructorInfo is null ? null : constructorParameters.Length, constructor); + + // We have to map the pg type back to the composite field type, as we've resolved based on the representational pg type. + PgConverterResolution MapResolution(PostgresCompositeType.Field field, PgConverterResolution resolution) + => new(resolution.Converter, options.ToCanonicalTypeId(field.Type)); + + static NotSupportedException NotSupportedField(PostgresCompositeType composite, PostgresCompositeType.Field field, bool isField, string name, Type type) + => new($"No resolution could be found for ('{type.FullName}', '{field.Type.FullName}'). Mapping: CLR {(isField ? 
"field" : "property")} '{type.Name}.{name}' <-> Composite field '{composite.Name}.{field.Name}'"); + } + + static Delegate CreateGetter(FieldInfo info) + { + var instance = Expression.Parameter(typeof(object), "instance"); + return Expression + .Lambda(typeof(Func<,>).MakeGenericType(typeof(object), info.FieldType), + Expression.Field(UnboxAny(instance, typeof(T)), info), + instance) + .Compile(); + } + + static Delegate CreateSetter(FieldInfo info) + { + var instance = Expression.Parameter(typeof(object), "instance"); + var value = Expression.Parameter(info.FieldType, "value"); + + return Expression + .Lambda(typeof(Action<,>).MakeGenericType(typeof(object), info.FieldType), + Expression.Assign(Expression.Field(UnboxAny(instance, typeof(T)), info), value), instance, value) + .Compile(); + } + + static Delegate CreateGetter(PropertyInfo info) + { + var invalidOpExceptionMessageConstructor = typeof(InvalidOperationException).GetConstructor(new []{ typeof(string) })!; + var instance = Expression.Parameter(typeof(object), "instance"); + var body = info.GetMethod is null || !info.GetMethod.IsPublic + ? (Expression)Expression.Throw(Expression.New(invalidOpExceptionMessageConstructor, + Expression.Constant($"No (public) getter for '{info}' on type {typeof(T)}")), info.PropertyType) + : Expression.Property(UnboxAny(instance, typeof(T)), info); + + return Expression + .Lambda(typeof(Func<,>).MakeGenericType(typeof(object), info.PropertyType), body, instance) + .Compile(); + } + + static Delegate CreateSetter(PropertyInfo info) + { + var instance = Expression.Parameter(typeof(object), "instance"); + var value = Expression.Parameter(info.PropertyType, "value"); + + var invalidOpExceptionMessageConstructor = typeof(InvalidOperationException).GetConstructor(new []{ typeof(string) })!; + var body = info.SetMethod is null || !info.SetMethod.IsPublic + ? 
(Expression)Expression.Throw(Expression.New(invalidOpExceptionMessageConstructor, + Expression.Constant($"No (public) getter for '{info}' on type {typeof(T)}")), info.PropertyType) + : Expression.Call(UnboxAny(instance, typeof(T)), info.SetMethod, value); + + return Expression + .Lambda(typeof(Action<,>).MakeGenericType(typeof(object), info.PropertyType), body, instance, value) + .Compile(); + } + + static Expression UnboxAny(Expression expression, Type type) + => type.IsValueType ? Expression.Unbox(expression, type) : Expression.Convert(expression, type, null); + + static Func CreateStrongBoxConstructor(ConstructorInfo constructorInfo) + { + var values = Expression.Parameter(typeof(StrongBox[]), "values"); + + var parameters = constructorInfo.GetParameters(); + var parameterCount = Expression.Constant(parameters.Length); + var argumentExceptionNameMessageConstructor = typeof(ArgumentException).GetConstructor(new []{ typeof(string), typeof(string) })!; + return Expression + .Lambda>( + Expression.Block( + Expression.IfThen( + Expression.LessThan(Expression.Property(values, "Length"), parameterCount), + + Expression.Throw(Expression.New(argumentExceptionNameMessageConstructor, + Expression.Constant("Passed fewer arguments than there are constructor parameters."), Expression.Constant(values.Name))) + ), + Expression.New(constructorInfo, parameters.Select((parameter, i) => + Expression.Property( + UnboxAny( + Expression.ArrayIndex(values, Expression.Constant(i)), + typeof(StrongBox<>).MakeGenericType(parameter.ParameterType) + ), + "TypedValue" + ) + )) + ), values) + .Compile(); + } + static CompositeFieldInfo CreateCompositeFieldInfo(string name, Type type, PgConverterResolution converterResolution, Delegate getter, int constructorParameterIndex) + => (CompositeFieldInfo)Activator.CreateInstance( + typeof(CompositeFieldInfo<>).MakeGenericType(type), name, converterResolution, getter, constructorParameterIndex)!; + + static CompositeFieldInfo 
CreateCompositeFieldInfo(string name, Type type, PgConverterResolution converterResolution, Delegate getter, Delegate setter) + => (CompositeFieldInfo)Activator.CreateInstance( + typeof(CompositeFieldInfo<>).MakeGenericType(type), name, converterResolution, getter, setter)!; + + static Dictionary MapProperties(IReadOnlyList fields, INpgsqlNameTranslator nameTranslator) + { + var properties = typeof(T).GetProperties(BindingFlags.Public | BindingFlags.Instance); + var propertiesAndNames = properties.Select(x => + { + var attr = x.GetCustomAttribute(); + var name = attr?.PgName ?? nameTranslator.TranslateMemberName(x.Name); + return new KeyValuePair(name, x); + }).ToArray(); + + var duplicates = propertiesAndNames.Except(propertiesAndNames.Distinct()).ToArray(); + if (duplicates.Length > 0) + throw new AmbiguousMatchException($"Multiple properties are mapped to the '{duplicates[0].Key}' field."); + + var propertiesMap = propertiesAndNames.ToDictionary(x => x.Key, x => x.Value); + var result = new Dictionary(); + for (var i = 0; i < fields.Count; i++) + { + var field = fields[i]; + if (!propertiesMap.TryGetValue(field.Name, out var value)) + continue; + + result[i] = value; + } + + return result; + } + + static Dictionary MapFields(IReadOnlyList fields, INpgsqlNameTranslator nameTranslator) + { + var clrFields = typeof(T).GetFields(BindingFlags.Public | BindingFlags.Instance); + var clrFieldsAndNames = clrFields.Select(x => + { + var attr = x.GetCustomAttribute(); + var name = attr?.PgName ?? 
nameTranslator.TranslateMemberName(x.Name); + return new KeyValuePair(name, x); + }).ToArray(); + + var duplicates = clrFieldsAndNames.Except(clrFieldsAndNames.Distinct()).ToArray(); + if (duplicates.Length > 0) + throw new AmbiguousMatchException($"Multiple properties are mapped to the '{duplicates[0].Key}' field."); + + var clrFieldsMap = clrFieldsAndNames.ToDictionary(x => x.Key, x => x.Value); + var result = new Dictionary(); + for (var i = 0; i < fields.Count; i++) + { + var field = fields[i]; + if (!clrFieldsMap.TryGetValue(field.Name, out var value)) + continue; + + result[i] = value; + } + + return result; + } + + static (ConstructorInfo? ConstructorInfo, int[] ParameterFieldMap) MapBestMatchingConstructor(IReadOnlyList fields, INpgsqlNameTranslator nameTranslator) + { + ConstructorInfo? clrDefaultConstructor = null; + foreach (var constructor in typeof(T).GetConstructors().OrderByDescending(x => x.GetParameters().Length)) + { + var parameters = constructor.GetParameters(); + if (parameters.Length != fields.Count) + { + if (parameters.Length == 0) + clrDefaultConstructor = constructor; + + continue; + } + + var parametersMapped = 0; + var parametersMap = new int[parameters.Length]; + + for (var i = 0; i < parameters.Length; i++) + { + var clrParameter = parameters[i]; + var attr = clrParameter.GetCustomAttribute(); + var name = attr?.PgName ?? (clrParameter.Name is { } clrName ? 
nameTranslator.TranslateMemberName(clrName) : null); + if (name is null) + break; + + for (var pgFieldIndex = 0; pgFieldIndex < fields.Count; pgFieldIndex++) + { + var pgField = fields[pgFieldIndex]; + if (pgField.Name != name) + continue; + + parametersMapped++; + parametersMap[i] = pgFieldIndex; + break; + } + } + + var duplicates = parametersMap.Except(parametersMap.Distinct()).ToArray(); + if (duplicates.Length > 0) + throw new AmbiguousMatchException($"Multiple constructor parameters are mapped to the '{fields[duplicates[0]].Name}' field."); + + if (parametersMapped == parameters.Length) + return (constructor, parametersMap); + } + + if (clrDefaultConstructor is null && !typeof(T).IsValueType) + throw new InvalidOperationException($"No parameterless constructor defined for type '{typeof(T)}'."); + + return (clrDefaultConstructor, Array.Empty()); + } +} diff --git a/src/Npgsql/Internal/Converters/ArrayConverter.cs b/src/Npgsql/Internal/Converters/ArrayConverter.cs new file mode 100644 index 0000000000..2801714cc5 --- /dev/null +++ b/src/Npgsql/Internal/Converters/ArrayConverter.cs @@ -0,0 +1,675 @@ +using System; +using System.Buffers; +using System.Collections; +using System.Collections.Generic; +using System.Collections.Concurrent; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Npgsql.Internal.Postgres; + +namespace Npgsql.Internal.Converters; + +interface IElementOperations +{ + object CreateCollection(int[] lengths); + int GetCollectionCount(object collection, out int[]? lengths); + Size? GetSizeOrDbNull(SizeContext context, object collection, int[] indices, ref object? 
writeState); + ValueTask Read(bool async, PgReader reader, bool isDbNull, object collection, int[] indices, CancellationToken cancellationToken = default); + ValueTask Write(bool async, PgWriter writer, object collection, int[] indices, CancellationToken cancellationToken = default); +} + +readonly struct PgArrayConverter +{ + internal const string ReadNonNullableCollectionWithNullsExceptionMessage = "Cannot read a non-nullable collection of elements because the returned array contains nulls. Call GetFieldValue with a nullable collection type instead."; + + readonly IElementOperations _elemOps; + readonly int? _expectedDimensions; + readonly BufferRequirements _bufferRequirements; + public bool ElemTypeDbNullable { get; } + readonly int _pgLowerBound; + readonly PgTypeId _elemTypeId; + + public PgArrayConverter(IElementOperations elemOps, bool elemTypeDbNullable, int? expectedDimensions, BufferRequirements bufferRequirements, PgTypeId elemTypeId, int pgLowerBound = 1) + { + _elemTypeId = elemTypeId; + ElemTypeDbNullable = elemTypeDbNullable; + _pgLowerBound = pgLowerBound; + _elemOps = elemOps; + _expectedDimensions = expectedDimensions; + _bufferRequirements = bufferRequirements; + } + + bool IsDbNull(object values, int[] indices) + { + object? state = null; + return _elemOps.GetSizeOrDbNull(new(DataFormat.Binary, _bufferRequirements.Write), values, indices, ref state) is null; + } + + Size GetElemsSize(object values, (Size, object?)[] elemStates, out bool anyElementState, DataFormat format, int count, int[] indices, int[]? lengths = null) + { + Debug.Assert(elemStates.Length >= count); + var totalSize = Size.Zero; + var context = new SizeContext(format, _bufferRequirements.Write); + anyElementState = false; + var lastLength = lengths?[lengths.Length - 1] ?? 
count; + ref var lastIndex = ref indices[indices.Length - 1]; + var i = 0; + do + { + ref var elemItem = ref elemStates[i++]; + var elemState = (object?)null; + var size = _elemOps.GetSizeOrDbNull(context, values, indices, ref elemState); + anyElementState = anyElementState || elemState is not null; + elemItem = (size ?? -1, elemState); + totalSize = totalSize.Combine(size ?? 0); + } + // We can immediately continue if we didn't reach the end of the last dimension. + while (++lastIndex < lastLength || (indices.Length > 1 && CarryIndices(lengths!, indices))); + + return totalSize; + } + + Size GetFixedElemsSize(Size elemSize, object values, int count, int[] indices, int[]? lengths = null) + { + var nulls = 0; + var lastLength = lengths?[lengths.Length - 1] ?? count; + ref var lastIndex = ref indices[indices.Length - 1]; + if (ElemTypeDbNullable) + do + { + if (IsDbNull(values, indices)) + nulls++; + } + // We can immediately continue if we didn't reach the end of the last dimension. + while (++lastIndex < lastLength || (indices.Length > 1 && CarryIndices(lengths!, indices))); + + return (count - nulls) * elemSize.Value; + } + + int GetFormatSize(int count, int dimensions) + => sizeof(int) + // Dimensions + sizeof(int) + // Flags + sizeof(int) + // Element OID + dimensions * (sizeof(int) + sizeof(int)) + // Dimensions * (array length and lower bound) + sizeof(int) * count; // Element length integers + + public Size GetSize(SizeContext context, object values, ref object? writeState) + { + var count = _elemOps.GetCollectionCount(values, out var lengths); + var dimensions = lengths?.Length ?? 
1; + if (dimensions > 8) + throw new ArgumentException(nameof(values), "Postgres arrays can have at most 8 dimensions."); + + var formatSize = Size.Create(GetFormatSize(count, dimensions)); + if (count is 0) + return formatSize; + + Size elemsSize; + var indices = new int[dimensions]; + if (_bufferRequirements.Write is { Kind: SizeKind.Exact } req) + { + elemsSize = GetFixedElemsSize(req, values, count, indices, lengths); + writeState = new WriteState { Count = count, Indices = indices, Lengths = lengths, ArrayPool = null, Data = default, AnyWriteState = false }; + } + else + { + var arrayPool = ArrayPool<(Size, object?)>.Shared; + var data = ArrayPool<(Size, object?)>.Shared.Rent(count); + elemsSize = GetElemsSize(values, data, out var elemStateDisposable, context.Format, count, indices, lengths); + writeState = new WriteState + { Count = count, Indices = indices, Lengths = lengths, + ArrayPool = arrayPool, Data = new(data, 0, count), AnyWriteState = elemStateDisposable }; + } + + return formatSize.Combine(elemsSize); + } + + sealed class WriteState : MultiWriteState + { + public required int Count { get; init; } + public required int[] Indices { get; init; } + public required int[]? Lengths { get; init; } + } + + public async ValueTask Read(bool async, PgReader reader, CancellationToken cancellationToken = default) + { + var dimensions = reader.ReadInt32(); + var containsNulls = reader.ReadInt32() is 1; + _ = reader.ReadUInt32(); // Element OID. + + if (dimensions is not 0 && _expectedDimensions is not null && dimensions != _expectedDimensions) + ThrowHelper.ThrowInvalidCastException( + $"Cannot read an array value with {dimensions} dimension{(dimensions == 1 ? "" : "s")} into a " + + $"collection type with {_expectedDimensions} dimension{(_expectedDimensions == 1 ? "" : "s")}. 
" + + $"Call GetValue or a version of GetFieldValue with the commas being the expected amount of dimensions."); + + if (containsNulls && !ElemTypeDbNullable) + ThrowHelper.ThrowInvalidCastException(ReadNonNullableCollectionWithNullsExceptionMessage); + + // Make sure we can read length + lower bound N dimension times. + if (reader.ShouldBuffer((sizeof(int) + sizeof(int)) * dimensions)) + await reader.Buffer(async, (sizeof(int) + sizeof(int)) * dimensions, cancellationToken).ConfigureAwait(false); + + var dimLengths = new int[_expectedDimensions ?? dimensions]; + var lastDimLength = 0; + for (var i = 0; i < dimensions; i++) + { + lastDimLength = reader.ReadInt32(); + reader.ReadInt32(); // Lower bound + if (dimLengths.Length is 0) + break; + dimLengths[i] = lastDimLength; + } + + var collection = _elemOps.CreateCollection(dimLengths); + Debug.Assert(dimensions <= 1 || collection is Array a && a.Rank == dimensions); + + if (dimensions is 0 || lastDimLength is 0) + return collection; + + int[] indices; + // Reuse array for dim <= 1 + if (dimensions == 1) + { + dimLengths[0] = 0; + indices = dimLengths; + } + else + indices = new int[dimensions]; + do + { + if (reader.ShouldBuffer(sizeof(int))) + await reader.Buffer(async, sizeof(int), cancellationToken).ConfigureAwait(false); + + var length = reader.ReadInt32(); + var isDbNull = length == -1; + if (!isDbNull) + { + var scope = await reader.BeginNestedRead(async, length, _bufferRequirements.Read, cancellationToken).ConfigureAwait(false); + try + { + await _elemOps.Read(async, reader, isDbNull, collection, indices, cancellationToken).ConfigureAwait(false); + } + finally + { + if (async) + await scope.DisposeAsync().ConfigureAwait(false); + else + scope.Dispose(); + } + } + else + await _elemOps.Read(async, reader, isDbNull, collection, indices, cancellationToken).ConfigureAwait(false); + } + // We can immediately continue if we didn't reach the end of the last dimension. 
+ while (++indices[indices.Length - 1] < lastDimLength || (dimensions > 1 && CarryIndices(dimLengths, indices))); + + return collection; + } + + static bool CarryIndices(int[] lengths, int[] indices) + { + Debug.Assert(lengths.Length > 1); + + // Find the first dimension from the end that isn't at or past its length, increment it and bring all previous dimensions to zero. + for (var dim = indices.Length - 1; dim >= 0; dim--) + { + if (indices[dim] >= lengths[dim] - 1) + continue; + + indices.AsSpan().Slice(dim + 1).Clear(); + indices[dim]++; + return true; + } + + // We're done if we can't find any dimension that isn't at its length. + return false; + } + + public async ValueTask Write(bool async, PgWriter writer, object values, CancellationToken cancellationToken) + { + var (count, dims, state) = writer.Current.WriteState switch + { + WriteState writeState => (writeState.Count, writeState.Lengths?.Length ?? 1 , writeState), + null => (0, values is Array a ? a.Rank : 1, null), + _ => throw new InvalidCastException($"Invalid write state, expected {typeof(WriteState).FullName}.") + }; + + if (writer.ShouldFlush(GetFormatSize(count, dims))) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + + writer.WriteInt32(dims); // Dimensions + writer.WriteInt32(0); // Flags (not really used) + writer.WriteAsOid(_elemTypeId); + for (var dim = 0; dim < dims; dim++) + { + writer.WriteInt32(state?.Lengths?[dim] ?? count); + writer.WriteInt32(_pgLowerBound); // Lower bound + } + + // We can stop here for empty collections. + if (state is null) + return; + + var elemTypeDbNullable = ElemTypeDbNullable; + var elemData = state.Data.Array; + + var indices = state.Indices; + Array.Clear(indices, 0 , indices.Length); + var lastLength = state.Lengths?[state.Lengths.Length - 1] ?? 
state.Count; + var i = state.Data.Offset; + do + { + if (writer.ShouldFlush(sizeof(int))) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + + var elem = elemData?[i++]; + var size = elem?.Size ?? (elemTypeDbNullable && IsDbNull(values, indices) ? -1 : _bufferRequirements.Write); + if (size.Kind is SizeKind.Unknown) + throw new NotImplementedException(); + + var length = size.Value; + writer.WriteInt32(length); + if (length != -1) + { + using var _ = await writer.BeginNestedWrite(async, _bufferRequirements.Write, length, elem?.WriteState, cancellationToken).ConfigureAwait(false); + await _elemOps.Write(async, writer, values, indices, cancellationToken).ConfigureAwait(false); + } + } + // We can immediately continue if we didn't reach the end of the last dimension. + while (++indices[indices.Length - 1] < lastLength || (indices.Length > 1 && CarryIndices(state.Lengths!, indices))); + } +} + +// Class constraint exists to make Unsafe.As, ValueTask> safe, don't remove unless that unsafe cast is also removed. +abstract class ArrayConverter : PgStreamingConverter where T : class +{ + protected PgConverterResolution ElemResolution { get; } + protected Type ElemTypeToConvert { get; } + + readonly PgArrayConverter _pgArrayConverter; + + private protected ArrayConverter(int? 
expectedDimensions, PgConverterResolution elemResolution, int pgLowerBound = 1) + { + if (!elemResolution.Converter.CanConvert(DataFormat.Binary, out var bufferRequirements)) + throw new NotSupportedException("Element converter has to support the binary format to be compatible."); + + ElemResolution = elemResolution; + ElemTypeToConvert = elemResolution.Converter.TypeToConvert; + _pgArrayConverter = new((IElementOperations)this, elemResolution.Converter.IsDbNullable, expectedDimensions, + bufferRequirements, elemResolution.PgTypeId, pgLowerBound); + } + + public override T Read(PgReader reader) => (T)_pgArrayConverter.Read(async: false, reader).Result; + + public override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) +#pragma warning disable CS9193 + => Unsafe.As, ValueTask>(ref Unsafe.AsRef(_pgArrayConverter.Read(async: true, reader, cancellationToken))); +#pragma warning restore + + public override Size GetSize(SizeContext context, T values, ref object? writeState) + => _pgArrayConverter.GetSize(context, values, ref writeState); + + public override void Write(PgWriter writer, T values) + => _pgArrayConverter.Write(async: false, writer, values, CancellationToken.None).GetAwaiter().GetResult(); + + public override ValueTask WriteAsync(PgWriter writer, T values, CancellationToken cancellationToken = default) + => _pgArrayConverter.Write(async: true, writer, values, cancellationToken); + + // Using a function pointer here is safe against assembly unloading as the instance reference that the static pointer method lives on is passed along. + // As such the instance cannot be collected by the gc which means the entire assembly is prevented from unloading until we're done. + // The alternatives are: + // 1. Add a virtual method and make AwaitTask call into it (bloating the vtable of all derived types). + // 2. 
Using a delegate, meaning we add a static field + an alloc per T + metadata, slightly slower dispatch perf so overall strictly worse as well. +#if NET6_0_OR_GREATER + [AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder))] +#endif + private protected static async ValueTask AwaitTask(Task task, Continuation continuation, object collection, int[] indices) + { + await task.ConfigureAwait(false); + continuation.Invoke(task, collection, indices); + // Guarantee the type stays loaded until the function pointer call is done. + GC.KeepAlive(continuation.Handle); + } + + // Split out into a struct as unsafe and async don't mix, while we do want a nicely typed function pointer signature to prevent mistakes. + protected readonly unsafe struct Continuation + { + public object Handle { get; } + readonly delegate* _continuation; + + /// A reference to the type that houses the static method points to. + /// The continuation + public Continuation(object handle, delegate* continuation) + { + Handle = handle; + _continuation = continuation; + } + + public void Invoke(Task task, object collection, int[] indices) => _continuation(task, collection, indices); + } + + protected static int[]? GetLengths(Array array) + { + if (array.Rank == 1) + return null; + + var lengths = new int[array.Rank]; + for (var i = 0; i < lengths.Length; i++) + lengths[i] = array.GetLength(i); + + return lengths; + } +} + +sealed class ArrayBasedArrayConverter : ArrayConverter, IElementOperations where T : class, IList +{ + readonly PgConverter _elemConverter; + + public ArrayBasedArrayConverter(PgConverterResolution elemResolution, Type? effectiveType = null, int pgLowerBound = 1) + : base( + expectedDimensions: effectiveType is null ? 1 : effectiveType.IsArray ? effectiveType.GetArrayRank() : null, + elemResolution, pgLowerBound) + => _elemConverter = elemResolution.GetConverter(); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static TElement? 
GetValue(object collection, int[] indices) + { + switch (indices.Length) + { + case 1: + Debug.Assert(collection is TElement?[]); + return Unsafe.As(collection)[indices[0]]; + default: + Debug.Assert(collection is Array); + return (TElement?)Unsafe.As(collection).GetValue(indices); + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SetValue(object collection, int[] indices, TElement? value) + { + switch (indices.Length) + { + case 1: + Debug.Assert(collection is TElement?[]); + Unsafe.As(collection)[indices[0]] = value; + break; + default: + Debug.Assert(collection is Array); + Unsafe.As(collection).SetValue(value, indices); + break; + } + } + + object IElementOperations.CreateCollection(int[] lengths) + => lengths.Length switch + { + 0 => Array.Empty(), + 1 when lengths[0] == 0 => Array.Empty(), + 1 => new TElement?[lengths[0]], + 2 => new TElement?[lengths[0],lengths[1]], + 3 => new TElement?[lengths[0],lengths[1], lengths[2]], + 4 => new TElement?[lengths[0],lengths[1], lengths[2], lengths[3]], + 5 => new TElement?[lengths[0],lengths[1], lengths[2], lengths[3], lengths[4]], + 6 => new TElement?[lengths[0],lengths[1], lengths[2], lengths[3], lengths[4], lengths[5]], + 7 => new TElement?[lengths[0],lengths[1], lengths[2], lengths[3], lengths[4], lengths[5], lengths[6]], + 8 => new TElement?[lengths[0],lengths[1], lengths[2], lengths[3], lengths[4], lengths[5], lengths[6], lengths[7]], + _ => throw new InvalidOperationException("Postgres arrays can have at most 8 dimensions.") + }; + + int IElementOperations.GetCollectionCount(object collection, out int[]? lengths) + { + Debug.Assert(collection is Array); + var array = Unsafe.As(collection); + lengths = GetLengths(array); + return array.Length; + } + + Size? IElementOperations.GetSizeOrDbNull(SizeContext context, object collection, int[] indices, ref object? 
writeState) + => _elemConverter.GetSizeOrDbNull(context.Format, context.BufferRequirement, GetValue(collection, indices), ref writeState); + + unsafe ValueTask IElementOperations.Read(bool async, PgReader reader, bool isDbNull, object collection, int[] indices, CancellationToken cancellationToken) + { + TElement? result; + if (isDbNull) + result = default; + else if (!async) + result = _elemConverter.Read(reader); + else + { + var task = _elemConverter.ReadAsync(reader, cancellationToken); + if (!task.IsCompletedSuccessfully) + return AwaitTask(task.AsTask(), new(this, &SetResult), collection, indices); + + result = task.Result; + } + + SetValue(collection, indices, result); + return new(); + + // Using .Result on ValueTask is equivalent to GetAwaiter().GetResult(), this removes TaskAwaiter rooting. + static void SetResult(Task task, object collection, int[] indices) + { + Debug.Assert(task is Task); + SetValue(collection, indices, new ValueTask(Unsafe.As>(task)).Result); + } + } + + ValueTask IElementOperations.Write(bool async, PgWriter writer, object collection, int[] indices, CancellationToken cancellationToken) + { + if (async) + return _elemConverter.WriteAsync(writer, GetValue(collection, indices)!, cancellationToken); + + _elemConverter.Write(writer, GetValue(collection, indices)!); + return new(); + } +} + +sealed class ListBasedArrayConverter : ArrayConverter, IElementOperations where T : class, IList +{ + readonly PgConverter _elemConverter; + + public ListBasedArrayConverter(PgConverterResolution elemResolution, int pgLowerBound = 1) + : base(expectedDimensions: 1, elemResolution, pgLowerBound) + => _elemConverter = elemResolution.GetConverter(); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static TElement? 
GetValue(object collection, int index) + { + Debug.Assert(collection is List); + return Unsafe.As>(collection)[index]; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SetValue(object collection, int index, TElement? value) + { + Debug.Assert(collection is List); + var list = Unsafe.As>(collection); + list.Insert(index, value); + } + + object IElementOperations.CreateCollection(int[] lengths) + => new List(lengths.Length is 0 ? 0 : lengths[0]); + + int IElementOperations.GetCollectionCount(object collection, out int[]? lengths) + { + Debug.Assert(collection is List); + lengths = null; + return Unsafe.As>(collection).Count; + } + + Size? IElementOperations.GetSizeOrDbNull(SizeContext context, object collection, int[] indices, ref object? writeState) + => _elemConverter.GetSizeOrDbNull(context.Format, context.BufferRequirement, GetValue(collection, indices[0]), ref writeState); + + unsafe ValueTask IElementOperations.Read(bool async, PgReader reader, bool isDbNull, object collection, int[] indices, CancellationToken cancellationToken) + { + Debug.Assert(indices.Length is 1); + TElement? result; + if (isDbNull) + result = default; + else if (!async) + result = _elemConverter.Read(reader); + else + { + var task = _elemConverter.ReadAsync(reader, cancellationToken); + if (!task.IsCompletedSuccessfully) + return AwaitTask(task.AsTask(), new(this, &SetResult), collection, indices); + + result = task.Result; + } + + SetValue(collection, indices[0], result); + return new(); + + // Using .Result on ValueTask is equivalent to GetAwaiter().GetResult(), this removes TaskAwaiter rooting. 
+ static void SetResult(Task task, object collection, int[] indices) + { + Debug.Assert(task is Task); + SetValue(collection, indices[0], new ValueTask(Unsafe.As>(task)).Result); + } + } + + ValueTask IElementOperations.Write(bool async, PgWriter writer, object collection, int[] indices, CancellationToken cancellationToken) + { + Debug.Assert(indices.Length is 1); + if (async) + return _elemConverter.WriteAsync(writer, GetValue(collection, indices[0])!, cancellationToken); + + _elemConverter.Write(writer, GetValue(collection, indices[0])!); + return new(); + } +} + +sealed class ArrayConverterResolver : PgComposingConverterResolver where T : class, IList +{ + readonly Type _effectiveType; + + public ArrayConverterResolver(PgResolverTypeInfo elementTypeInfo, Type effectiveType) + : base(elementTypeInfo.PgTypeId is { } id ? elementTypeInfo.Options.GetArrayTypeId(id) : null, elementTypeInfo) + => _effectiveType = effectiveType; + + PgSerializerOptions Options => EffectiveTypeInfo.Options; + + protected override PgTypeId GetEffectivePgTypeId(PgTypeId pgTypeId) => Options.GetArrayElementTypeId(pgTypeId); + protected override PgTypeId GetPgTypeId(PgTypeId effectivePgTypeId) => Options.GetArrayTypeId(effectivePgTypeId); + + protected override PgConverter CreateConverter(PgConverterResolution effectiveResolution) + => typeof(T).IsConstructedGenericType && typeof(T).GetGenericTypeDefinition() == typeof(List<>) + ? new ListBasedArrayConverter(effectiveResolution) + : new ArrayBasedArrayConverter(effectiveResolution, _effectiveType); + + protected override PgConverterResolution? GetEffectiveResolution(T? values, PgTypeId? expectedEffectivePgTypeId) + { + PgConverterResolution? resolution = null; + if (values is null) + { + resolution = EffectiveTypeInfo.GetDefaultResolution(expectedEffectivePgTypeId); + } + else + { + switch (values) + { + case TElement[] array: + foreach (var value in array) + { + var result = EffectiveTypeInfo.GetResolution(value, resolution?.PgTypeId ?? 
expectedEffectivePgTypeId); + resolution ??= result; + } + break; + case IList list: + foreach (var value in list) + { + var result = EffectiveTypeInfo.GetResolution(value, resolution?.PgTypeId ?? expectedEffectivePgTypeId); + resolution ??= result; + } + break; + default: + foreach (var value in values) + { + var result = EffectiveTypeInfo.GetResolutionAsObject(value, resolution?.PgTypeId ?? expectedEffectivePgTypeId); + resolution ??= result; + } + break; + } + } + + return resolution; + } +} + +// T is Array as we only know what type it will be after reading 'contains nulls'. +sealed class PolymorphicArrayConverter : PgStreamingConverter +{ + readonly PgConverter _structElementCollectionConverter; + readonly PgConverter _nullableElementCollectionConverter; + + public PolymorphicArrayConverter(PgConverter structElementCollectionConverter, PgConverter nullableElementCollectionConverter) + { + _structElementCollectionConverter = structElementCollectionConverter; + _nullableElementCollectionConverter = nullableElementCollectionConverter; + } + + public override TBase Read(PgReader reader) + { + _ = reader.ReadInt32(); + var containsNulls = reader.ReadInt32() is 1; + reader.Rewind(sizeof(int) + sizeof(int)); + return containsNulls + ? _nullableElementCollectionConverter.Read(reader) + : _structElementCollectionConverter.Read(reader); + } + + public override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + { + _ = reader.ReadInt32(); + var containsNulls = reader.ReadInt32() is 1; + reader.Rewind(sizeof(int) + sizeof(int)); + return containsNulls + ? _nullableElementCollectionConverter.ReadAsync(reader, cancellationToken) + : _structElementCollectionConverter.ReadAsync(reader, cancellationToken); + } + + public override Size GetSize(SizeContext context, TBase value, ref object? 
writeState) + => throw new NotSupportedException("Polymorphic writing is not supported"); + + public override void Write(PgWriter writer, TBase value) + => throw new NotSupportedException("Polymorphic writing is not supported"); + + public override ValueTask WriteAsync(PgWriter writer, TBase value, CancellationToken cancellationToken = default) + => throw new NotSupportedException("Polymorphic writing is not supported"); +} + +sealed class PolymorphicArrayConverterResolver : PolymorphicConverterResolver +{ + readonly PgResolverTypeInfo _effectiveInfo; + readonly PgResolverTypeInfo _effectiveNullableInfo; + readonly ConcurrentDictionary _converterCache = new(ReferenceEqualityComparer.Instance); + + public PolymorphicArrayConverterResolver(PgResolverTypeInfo effectiveInfo, PgResolverTypeInfo effectiveNullableInfo) + : base(effectiveInfo.PgTypeId!.Value) + { + if (effectiveInfo.PgTypeId is null || effectiveNullableInfo.PgTypeId is null) + throw new InvalidOperationException("Cannot accept undecided infos"); + + _effectiveInfo = effectiveInfo; + _effectiveNullableInfo = effectiveNullableInfo; + } + + protected override PgConverter Get(Field? maybeField) + { + var structResolution = maybeField is { } field + ? _effectiveInfo.GetResolution(field) + : _effectiveInfo.GetDefaultResolution(PgTypeId); + var nullableResolution = maybeField is { } field2 + ? 
_effectiveNullableInfo.GetResolution(field2) + : _effectiveNullableInfo.GetDefaultResolution(PgTypeId); + + (PgConverter StructConverter, PgConverter NullableConverter) state = (structResolution.Converter, nullableResolution.Converter); + return _converterCache.GetOrAdd(structResolution.Converter, + static (_, state) => new PolymorphicArrayConverter((PgConverter)state.StructConverter, (PgConverter)state.NullableConverter), + state); + } +} diff --git a/src/Npgsql/Internal/Converters/AsyncHelpers.cs b/src/Npgsql/Internal/Converters/AsyncHelpers.cs new file mode 100644 index 0000000000..339378fdd7 --- /dev/null +++ b/src/Npgsql/Internal/Converters/AsyncHelpers.cs @@ -0,0 +1,114 @@ +using System; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace Npgsql.Internal.Converters; + +// Using a function pointer here is safe against assembly unloading as the instance reference that the static pointer method lives on is passed along. +// As such the instance cannot be collected by the gc which means the entire assembly is prevented from unloading until we're done. +static class AsyncHelpers +{ + static async void AwaitTask(Task task, CompletionSource tcs, Continuation continuation) + { + try + { + await task.ConfigureAwait(false); + continuation.Invoke(task, tcs); + } + catch (Exception ex) + { + tcs.SetException(ex); + } + // Guarantee the type stays loaded until the function pointer call is done. 
+ GC.KeepAlive(continuation.Handle); + } + + abstract class CompletionSource + { + public abstract void SetException(Exception exception); + } + + sealed class CompletionSource : CompletionSource + { +#if NETSTANDARD + AsyncValueTaskMethodBuilder _amb = AsyncValueTaskMethodBuilder.Create(); +#else + PoolingAsyncValueTaskMethodBuilder _amb = PoolingAsyncValueTaskMethodBuilder.Create(); +#endif + public ValueTask Task => _amb.Task; + + public void SetResult(T value) + => _amb.SetResult(value); + + public override void SetException(Exception exception) + => _amb.SetException(exception); + } + + // Split out into a struct as unsafe and async don't mix, while we do want a nicely typed function pointer signature to prevent mistakes. + readonly unsafe struct Continuation + { + public object Handle { get; } + readonly delegate* _continuation; + + /// A reference to the type that houses the static method points to. + /// The continuation + public Continuation(object handle, delegate* continuation) + { + Handle = handle; + _continuation = continuation; + } + + public void Invoke(Task task, CompletionSource tcs) => _continuation(task, tcs); + } + + public static unsafe ValueTask ComposingReadAsync(this PgConverter instance, PgConverter effectiveConverter, PgReader reader, CancellationToken cancellationToken) + { + if (!typeof(T).IsValueType && !typeof(TEffective).IsValueType) +#pragma warning disable CS9193 + return Unsafe.As, ValueTask>(ref Unsafe.AsRef(effectiveConverter.ReadAsync(reader, cancellationToken))); +#pragma warning restore + // Easy if we have all the data. + var task = effectiveConverter.ReadAsync(reader, cancellationToken); + if (task.IsCompletedSuccessfully) + return new((T)(object)task.Result!); + + // Otherwise we do one additional allocation, this allow us to share state machine codegen for all Ts. 
+ var source = new CompletionSource(); + AwaitTask(task.AsTask(), source, new(instance, &UnboxAndComplete)); + return source.Task; + + static void UnboxAndComplete(Task task, CompletionSource completionSource) + { + Debug.Assert(task is Task); + Debug.Assert(completionSource is CompletionSource); + Unsafe.As>(completionSource).SetResult(new ValueTask(Unsafe.As>(task)).Result); + } + } + + public static unsafe ValueTask ComposingReadAsObjectAsync(this PgConverter instance, PgConverter effectiveConverter, PgReader reader, CancellationToken cancellationToken) + { + if (!typeof(T).IsValueType) +#pragma warning disable CS9193 + return Unsafe.As, ValueTask>(ref Unsafe.AsRef(effectiveConverter.ReadAsObjectAsync(reader, cancellationToken))); +#pragma warning restore + + // Easy if we have all the data. + var task = effectiveConverter.ReadAsObjectAsync(reader, cancellationToken); + if (task.IsCompletedSuccessfully) + return new((T)task.Result); + + // Otherwise we do one additional allocation, this allow us to share state machine codegen for all Ts. 
+ var source = new CompletionSource(); + AwaitTask(task.AsTask(), source, new(instance, &UnboxAndComplete)); + return source.Task; + + static void UnboxAndComplete(Task task, CompletionSource completionSource) + { + Debug.Assert(task is Task); + Debug.Assert(completionSource is CompletionSource); + Unsafe.As>(completionSource).SetResult((T)new ValueTask(Unsafe.As>(task)).Result); + } + } +} diff --git a/src/Npgsql/Internal/Converters/BitStringConverters.cs b/src/Npgsql/Internal/Converters/BitStringConverters.cs new file mode 100644 index 0000000000..b7597f96d9 --- /dev/null +++ b/src/Npgsql/Internal/Converters/BitStringConverters.cs @@ -0,0 +1,249 @@ +using System; +using System.Buffers; +using System.Collections; +using System.Collections.Specialized; +using System.Diagnostics; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Npgsql.Internal.Postgres; +using static Npgsql.Internal.Converters.BitStringHelpers; + +namespace Npgsql.Internal.Converters; + +static class BitStringHelpers +{ + public static int GetByteLengthFromBits(int n) + { + const int BitShiftPerByte = 3; + Debug.Assert(n >= 0); + // Due to sign extension, we don't need to special case for n == 0, since ((n - 1) >> 3) + 1 = 0 + // This doesn't hold true for ((n - 1) / 8) + 1, which equals 1. 
/// <summary>
/// Converts PostgreSQL bit/varbit values to and from <see cref="BitArray"/>.
/// Wire format: an int32 bit count followed by ceil(bits / 8) data bytes, most significant bit first.
/// </summary>
sealed class BitArrayBitStringConverter : PgStreamingConverter<BitArray>
{
    public override BitArray Read(PgReader reader)
    {
        if (reader.ShouldBuffer(sizeof(int)))
            reader.Buffer(sizeof(int));

        var bits = reader.ReadInt32();
        var bytes = new byte[GetByteLengthFromBits(bits)];
        reader.ReadBytes(bytes);
        return ReadValue(bytes, bits);
    }

    public override async ValueTask<BitArray> ReadAsync(PgReader reader, CancellationToken cancellationToken = default)
    {
        if (reader.ShouldBuffer(sizeof(int)))
            await reader.BufferAsync(sizeof(int), cancellationToken).ConfigureAwait(false);

        var bits = reader.ReadInt32();
        var bytes = new byte[GetByteLengthFromBits(bits)];
        await reader.ReadBytesAsync(bytes, cancellationToken).ConfigureAwait(false);
        return ReadValue(bytes, bits);
    }

    // BitArray stores bits LSB-first within each byte while PostgreSQL sends them MSB-first,
    // so each byte is reversed in place before handing the buffer to BitArray.
    internal static BitArray ReadValue(byte[] bytes, int bits)
    {
        for (var i = 0; i < bytes.Length; i++)
        {
            ref var b = ref bytes[i];
            b = ReverseBits(b);
        }

        return new(bytes) { Length = bits };
    }

    public override Size GetSize(SizeContext context, BitArray value, ref object? writeState)
        => sizeof(int) + GetByteLengthFromBits(value.Length);

    public override void Write(PgWriter writer, BitArray value)
        => Write(async: false, writer, value, CancellationToken.None).GetAwaiter().GetResult();
    public override ValueTask WriteAsync(PgWriter writer, BitArray value, CancellationToken cancellationToken = default)
        => Write(async: true, writer, value, cancellationToken);

    async ValueTask Write(bool async, PgWriter writer, BitArray value, CancellationToken cancellationToken = default)
    {
        var byteCount = writer.Current.Size.Value - sizeof(int);
        var array = ArrayPool<byte>.Shared.Rent(byteCount);
        try
        {
            // Pack bits MSB-first, zero-padding the trailing byte.
            for (var pos = 0; pos < byteCount; pos++)
            {
                var bitPos = pos * 8;
                var bits = Math.Min(8, value.Length - bitPos);
                var b = 0;
                for (var i = 0; i < bits; i++)
                    b += (value[bitPos + i] ? 1 : 0) << (8 - i - 1);
                array[pos] = (byte)b;
            }

            if (writer.ShouldFlush(sizeof(int)))
                await writer.Flush(async, cancellationToken).ConfigureAwait(false);

            writer.WriteInt32(value.Length);
            if (async)
                await writer.WriteBytesAsync(new ReadOnlyMemory<byte>(array, 0, byteCount), cancellationToken).ConfigureAwait(false);
            else
                writer.WriteBytes(new ReadOnlySpan<byte>(array, 0, byteCount));
        }
        finally
        {
            // Return the rented buffer even when the write faults, instead of leaking it from the pool.
            ArrayPool<byte>.Shared.Return(array);
        }
    }
}
/// <summary>Converts BIT(N) values of up to 32 bits to and from <see cref="BitVector32"/>.</summary>
sealed class BitVector32BitStringConverter : PgBufferedConverter<BitVector32>
{
    // int32 bit count plus up to 4 data bytes.
    static int MaxSize => sizeof(int) + sizeof(int);

    public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements)
    {
        bufferRequirements = BufferRequirements.Create(Size.CreateUpperBound(MaxSize));
        return format is DataFormat.Binary;
    }

    protected override BitVector32 ReadCore(PgReader reader)
    {
        if (reader.CurrentRemaining > sizeof(int) + sizeof(int))
            throw new InvalidCastException("Can't read a BIT(N) with more than 32 bits to BitVector32, only up to BIT(32).");

        var bits = reader.ReadInt32();
        // The data bytes arrive MSB-first; left-align them into the 32-bit value.
        // The 3-byte case previously shifted by 8/0 instead of 16/8, misaligning the
        // result by one byte relative to the 1-, 2- and 4-byte cases.
        return GetByteLengthFromBits(bits) switch
        {
            4 => new(reader.ReadInt32()),
            3 => new((reader.ReadInt16() << 16) + (reader.ReadByte() << 8)),
            2 => new(reader.ReadInt16() << 16),
            1 => new(reader.ReadByte() << 24),
            _ => new(0)
        };
    }

    public override Size GetSize(SizeContext context, BitVector32 value, ref object? writeState)
        => value.Data is 0 ? 4 : MaxSize;

    protected override void WriteCore(PgWriter writer, BitVector32 value)
    {
        // An all-zero vector is written as a zero-length bit string.
        if (value.Data == 0)
            writer.WriteInt32(0);
        else
        {
            writer.WriteInt32(32);
            writer.WriteInt32(value.Data);
        }
    }
}
/// <summary>Converts BIT(1) values to and from <see cref="bool"/>.</summary>
sealed class BoolBitStringConverter : PgBufferedConverter<bool>
{
    // int32 bit count plus a single data byte.
    static int MaxSize => sizeof(int) + sizeof(byte);

    public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements)
    {
        bufferRequirements = BufferRequirements.Create(read: Size.CreateUpperBound(MaxSize), write: MaxSize);
        return format is DataFormat.Binary;
    }

    protected override bool ReadCore(PgReader reader)
    {
        var bits = reader.ReadInt32();
        return bits switch
        {
            > 1 => throw new InvalidCastException("Can't read a BIT(N) type to bool, only BIT(1)."),
            // We make an accommodation for varbit with no data.
            0 => false,
            // The single bit sits in the most significant position of the data byte.
            _ => (reader.ReadByte() & 128) is not 0
        };
    }

    public override Size GetSize(SizeContext context, bool value, ref object? writeState) => MaxSize;

    protected override void WriteCore(PgWriter writer, bool value)
    {
        writer.WriteInt32(1);
        writer.WriteByte(value ? (byte)128 : (byte)0);
    }
}
/// <summary>Converts bit/varbit values to and from their textual '0'/'1' representation.</summary>
sealed class StringBitStringConverter : PgStreamingConverter<string>
{
    public override string Read(PgReader reader)
        => Read(async: false, reader, CancellationToken.None).GetAwaiter().GetResult();
    public override ValueTask<string> ReadAsync(PgReader reader, CancellationToken cancellationToken = default)
        => Read(async: true, reader, cancellationToken);

    async ValueTask<string> Read(bool async, PgReader reader, CancellationToken cancellationToken)
    {
        if (reader.ShouldBuffer(sizeof(int)))
            await reader.Buffer(async, sizeof(int), cancellationToken).ConfigureAwait(false);

        var bits = reader.ReadInt32();
        var bytes = new byte[GetByteLengthFromBits(bits)];
        if (async)
            await reader.ReadBytesAsync(bytes, cancellationToken).ConfigureAwait(false);
        else
            reader.ReadBytes(bytes);

        // Reuse the BitArray read path, then render each bit as '0'/'1'.
        var bitArray = BitArrayBitStringConverter.ReadValue(bytes, bits);
        var sb = new StringBuilder(bits);
        for (var i = 0; i < bitArray.Count; i++)
            sb.Append(bitArray[i] ? '1' : '0');

        return sb.ToString();
    }

    public override Size GetSize(SizeContext context, string value, ref object? writeState)
    {
        // Validate up front so Write can assume a well-formed bit string.
        if (value.AsSpan().IndexOfAnyExcept('0', '1') is not -1 and var index)
            throw new ArgumentException($"Invalid bitstring character '{value[index]}' at index: {index}", nameof(value));

        return sizeof(int) + GetByteLengthFromBits(value.Length);
    }

    public override void Write(PgWriter writer, string value)
        => Write(async: false, writer, value, CancellationToken.None).GetAwaiter().GetResult();
    public override ValueTask WriteAsync(PgWriter writer, string value, CancellationToken cancellationToken = default)
        => Write(async: true, writer, value, cancellationToken);

    async ValueTask Write(bool async, PgWriter writer, string value, CancellationToken cancellationToken)
    {
        var byteCount = writer.Current.Size.Value - sizeof(int);
        var array = ArrayPool<byte>.Shared.Rent(byteCount);
        try
        {
            // Pack the '0'/'1' characters MSB-first, zero-padding the trailing byte.
            for (var pos = 0; pos < byteCount; pos++)
            {
                var bitPos = pos * 8;
                var bits = Math.Min(8, value.Length - bitPos);
                var b = 0;
                for (var i = 0; i < bits; i++)
                    b += (value[bitPos + i] == '1' ? 1 : 0) << (8 - i - 1);
                array[pos] = (byte)b;
            }

            if (writer.ShouldFlush(sizeof(int)))
                await writer.Flush(async, cancellationToken).ConfigureAwait(false);

            writer.WriteInt32(value.Length);
            if (async)
                await writer.WriteBytesAsync(new ReadOnlyMemory<byte>(array, 0, byteCount), cancellationToken).ConfigureAwait(false);
            else
                writer.WriteBytes(new ReadOnlySpan<byte>(array, 0, byteCount));
        }
        finally
        {
            // Return the rented buffer even when the write faults, instead of leaking it from the pool.
            ArrayPool<byte>.Shared.Return(array);
        }
    }
}
/// <summary>Resolves bit string fields to a converter based on the field's type modifier.</summary>
/// <remarks>Note that for BIT(1), this resolver will return a bool by default, to align with SqlClient
/// (see discussion https://github.com/npgsql/npgsql/pull/362#issuecomment-59622101).</remarks>
sealed class PolymorphicBitStringConverterResolver : PolymorphicConverterResolver
{
    BoolBitStringConverter? _bitOneConverter;
    BitArrayBitStringConverter? _defaultConverter;

    public PolymorphicBitStringConverterResolver(PgTypeId bitString) : base(bitString) { }

    protected override PgConverter Get(Field? field)
    {
        // A type modifier of exactly 1 means BIT(1), which maps to bool; everything else maps to BitArray.
        if (field?.TypeModifier is 1)
            return _bitOneConverter ??= new BoolBitStringConverter();

        return _defaultConverter ??= new BitArrayBitStringConverter();
    }
}
/// <summary>A converter to map strongly typed apis onto boxed converter results to produce a strongly typed converter over T.</summary>
sealed class CastingConverter<T> : PgConverter<T>
{
    readonly PgConverter _effectiveConverter;

    public CastingConverter(PgConverter effectiveConverter)
        : base(effectiveConverter.DbNullPredicateKind is DbNullPredicate.Custom)
        => _effectiveConverter = effectiveConverter;

    protected override bool IsDbNullValue(T? value) => _effectiveConverter.IsDbNullAsObject(value);

    public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements)
        => _effectiveConverter.CanConvert(format, out bufferRequirements);

    public override T Read(PgReader reader) => (T)_effectiveConverter.ReadAsObject(reader);

    public override ValueTask<T> ReadAsync(PgReader reader, CancellationToken cancellationToken = default)
        => this.ComposingReadAsObjectAsync(_effectiveConverter, reader, cancellationToken);

    public override Size GetSize(SizeContext context, T value, ref object? writeState)
        => _effectiveConverter.GetSizeAsObject(context, value!, ref writeState);

    public override void Write(PgWriter writer, T value)
        => _effectiveConverter.WriteAsObject(writer, value!);

    public override ValueTask WriteAsync(PgWriter writer, T value, CancellationToken cancellationToken = default)
        => _effectiveConverter.WriteAsObjectAsync(writer, value!, cancellationToken);

    // Boxed paths delegate straight through; no cast is needed when the caller wants object anyway.
    internal override ValueTask<object> ReadAsObject(bool async, PgReader reader, CancellationToken cancellationToken)
        => async
            ? _effectiveConverter.ReadAsObjectAsync(reader, cancellationToken)
            : new(_effectiveConverter.ReadAsObject(reader));

    internal override ValueTask WriteAsObject(bool async, PgWriter writer, object value, CancellationToken cancellationToken)
    {
        if (async)
            return _effectiveConverter.WriteAsObjectAsync(writer, value, cancellationToken);

        _effectiveConverter.WriteAsObject(writer, value);
        return new();
    }
}
// Given there aren't many instantiations of converter resolvers (and it's fairly involved to write a fast one)
// we use the composing base class.
sealed class CastingConverterResolver<T> : PgComposingConverterResolver<T>
{
    public CastingConverterResolver(PgResolverTypeInfo effectiveResolverTypeInfo)
        : base(effectiveResolverTypeInfo.PgTypeId, effectiveResolverTypeInfo) { }

    // The effective resolver works with the same pg type ids, so both mappings are identity.
    protected override PgTypeId GetEffectivePgTypeId(PgTypeId pgTypeId) => pgTypeId;
    protected override PgTypeId GetPgTypeId(PgTypeId effectivePgTypeId) => effectivePgTypeId;

    protected override PgConverter<T> CreateConverter(PgConverterResolution effectiveResolution)
        => new CastingConverter<T>(effectiveResolution.Converter);

    protected override PgConverterResolution? GetEffectiveResolution(T? value, PgTypeId? expectedEffectiveTypeId)
        => EffectiveTypeInfo.GetResolutionAsObject(value, expectedEffectiveTypeId);
}

static class CastingTypeInfoExtensions
{
    /// <summary>
    /// Wraps a boxing type info in a casting converter (resolver) so it presents a strongly typed
    /// interface for its <c>Type</c>; returns the info unchanged when it is not boxing.
    /// </summary>
    internal static PgTypeInfo ToNonBoxing(this PgTypeInfo typeInfo)
    {
        if (!typeInfo.IsBoxing)
            return typeInfo;

        var type = typeInfo.Type;
        if (typeInfo is PgResolverTypeInfo resolverTypeInfo)
            return new PgResolverTypeInfo(typeInfo.Options,
                (PgConverterResolver)Activator.CreateInstance(typeof(CastingConverterResolver<>).MakeGenericType(type),
                    resolverTypeInfo)!, typeInfo.PgTypeId);

        var resolution = typeInfo.GetConcreteResolution();
        return new PgTypeInfo(typeInfo.Options,
            (PgConverter)Activator.CreateInstance(typeof(CastingConverter<>).MakeGenericType(type), resolution.Converter)!, resolution.PgTypeId);
    }
}
/// <summary>Converts PostgreSQL composite (row) types to and from the CLR type T described by a <c>CompositeInfo</c>.</summary>
sealed class CompositeConverter<T> : PgStreamingConverter<T> where T : notnull
{
    readonly CompositeInfo<T> _composite;
    readonly BufferRequirements _bufferRequirements;

    public CompositeConverter(CompositeInfo<T> composite)
    {
        _composite = composite;

        // Header: field count, plus an oid and a length per field.
        var req = BufferRequirements.CreateFixedSize(sizeof(int) + _composite.Fields.Count * (sizeof(uint) + sizeof(int)));
        foreach (var field in _composite.Fields)
        {
            var readReq = field.BinaryReadRequirement;
            var writeReq = field.BinaryWriteRequirement;

            // If the field is nullable we cannot depend on its buffer size being fixed.
            if (field.IsDbNullable)
            {
                readReq = readReq.Combine(Size.CreateUpperBound(0));
                // Was 'readReq.Combine(...)', which silently discarded the field's write requirement.
                writeReq = writeReq.Combine(Size.CreateUpperBound(0));
            }

            req = req.Combine(
                // If a read is Unknown (streaming) we can map it to zero as we just want a minimum buffered size.
                readReq is { Kind: SizeKind.Unknown } ? Size.Zero : readReq,
                // For writes Unknown means our size is dependent on the value so we can't ignore it.
                writeReq);
        }

        // We have to put a limit on the requirements we report otherwise smaller buffer sizes won't work.
        req = BufferRequirements.Create(Limit(req.Read), Limit(req.Write));

        _bufferRequirements = req;

        Size Limit(Size requirement)
        {
            const int maxByteCount = 1024;
            return requirement switch
            {
                { Kind: SizeKind.UpperBound } => Size.CreateUpperBound(Math.Min(maxByteCount, requirement.Value)),
                { Kind: SizeKind.Exact } => Size.Create(Math.Min(maxByteCount, requirement.Value)),
                _ => Size.Unknown
            };
        }
    }

    public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements)
    {
        bufferRequirements = _bufferRequirements;
        return format is DataFormat.Binary;
    }

    public override T Read(PgReader reader)
        => Read(async: false, reader, CancellationToken.None).GetAwaiter().GetResult();

    public override ValueTask<T> ReadAsync(PgReader reader, CancellationToken cancellationToken = default)
        => Read(async: true, reader, cancellationToken);

    async ValueTask<T> Read(bool async, PgReader reader, CancellationToken cancellationToken)
    {
        // TODO we can make a nice thread-static cache for this.
        using var builder = new CompositeBuilder<T>(_composite);
        // Buffer the field count before reading it; previously the ShouldBuffer check ran after ReadInt32, too late to help.
        if (reader.ShouldBuffer(sizeof(int)))
            await reader.Buffer(async, sizeof(int), cancellationToken).ConfigureAwait(false);
        var count = reader.ReadInt32();
        if (count != _composite.Fields.Count)
            throw new InvalidOperationException("Cannot read composite type with mismatched number of fields");
        foreach (var field in _composite.Fields)
        {
            if (reader.ShouldBuffer(sizeof(uint) + sizeof(int)))
                await reader.Buffer(async, sizeof(uint) + sizeof(int), cancellationToken).ConfigureAwait(false);

            var oid = reader.ReadUInt32();
            var length = reader.ReadInt32();

            // We're only requiring the PgTypeIds to be oids if this converter is actually used during execution.
            // As a result we can still introspect in the global mapper and create all the info with portable ids.
            if (oid != field.PgTypeId.Oid)
                // We could remove this requirement by storing a dictionary of CompositeInfos keyed by backend.
                throw new InvalidCastException(
                    $"Cannot read oid {oid} into composite field {field.Name} with oid {field.PgTypeId}. " +
                    $"This could be caused by a DDL change after this DataSource loaded its types, or a difference between column order of table composites between backends make sure these line up identically.");

            if (length is -1)
                field.ReadDbNull(builder);
            else
            {
                var scope = await reader.BeginNestedRead(async, length, field.BinaryReadRequirement, cancellationToken).ConfigureAwait(false);
                try
                {
                    await field.Read(async, builder, reader, cancellationToken).ConfigureAwait(false);
                }
                finally
                {
                    if (async)
                        await scope.DisposeAsync().ConfigureAwait(false);
                    else
                        scope.Dispose();
                }
            }
        }

        return builder.Complete();
    }

    public override Size GetSize(SizeContext context, T value, ref object? writeState)
    {
        var arrayPool = ArrayPool<(Size Size, object? WriteState)>.Shared;
        var data = arrayPool.Rent(_composite.Fields.Count);

        var totalSize = Size.Create(sizeof(int) + _composite.Fields.Count * (sizeof(uint) + sizeof(int)));
        // Box once here so the per-field calls below don't each box the value.
        var boxedValue = (object)value;
        var anyWriteState = false;
        for (var i = 0; i < _composite.Fields.Count; i++)
        {
            var field = _composite.Fields[i];
            object? fieldState = null;
            var fieldSize = field.GetSizeOrDbNull(context.Format, boxedValue, ref fieldState);
            anyWriteState = anyWriteState || fieldState is not null;
            data[i] = (fieldSize ?? -1, fieldState);
            totalSize = totalSize.Combine(fieldSize ?? 0);
        }

        writeState = new WriteState
        {
            ArrayPool = arrayPool,
            BoxedInstance = boxedValue,
            Data = new(data, 0, _composite.Fields.Count),
            AnyWriteState = anyWriteState
        };
        return totalSize;
    }

    public override void Write(PgWriter writer, T value)
        => Write(async: false, writer, value, CancellationToken.None).GetAwaiter().GetResult();

    public override ValueTask WriteAsync(PgWriter writer, T value, CancellationToken cancellationToken = default)
        => Write(async: true, writer, value, cancellationToken);

    async ValueTask Write(bool async, PgWriter writer, T value, CancellationToken cancellationToken)
    {
        if (writer.Current.WriteState is not null and not WriteState)
            throw new InvalidCastException($"Invalid write state, expected {typeof(WriteState).FullName}.");

        if (writer.ShouldFlush(sizeof(int)))
            await writer.Flush(async, cancellationToken).ConfigureAwait(false);

        writer.WriteInt32(_composite.Fields.Count);

        var writeState = writer.Current.WriteState as WriteState;
        var boxedInstance = writeState?.BoxedInstance ?? value!;
        var data = writeState?.Data.Array;
        for (var i = 0; i < _composite.Fields.Count; i++)
        {
            if (writer.ShouldFlush(sizeof(uint) + sizeof(int)))
                await writer.Flush(async, cancellationToken).ConfigureAwait(false);

            var field = _composite.Fields[i];
            writer.WriteAsOid(field.PgTypeId);

            // NOTE(review): the no-write-state fallback uses BinaryReadRequirement as the field size —
            // confirm this shouldn't be BinaryWriteRequirement.
            var (size, fieldState) = data?[i] ?? (field.IsDbNull(boxedInstance) ? -1 : field.BinaryReadRequirement, null);

            var length = size.Value;
            writer.WriteInt32(length);
            if (length != -1)
            {
                using var _ = await writer.BeginNestedWrite(async, _bufferRequirements.Write, length, fieldState, cancellationToken).ConfigureAwait(false);
                await field.Write(async, writer, boxedInstance, cancellationToken).ConfigureAwait(false);
            }
        }
    }

    sealed class WriteState : MultiWriteState
    {
        public required object BoxedInstance { get; init; }
    }
}
/// <summary>Converts between a CLR enum and a PostgreSQL enum type using the label dictionaries built during mapping.</summary>
[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)]
sealed class EnumConverter<TEnum> : PgBufferedConverter<TEnum> where TEnum : struct, Enum
{
    readonly Dictionary<TEnum, string> _enumToLabel;
    readonly Dictionary<string, TEnum> _labelToEnum;
    readonly Encoding _encoding;

    // Unmapped enums: the dictionaries are keyed by boxed Enum, so unbox into strongly typed copies once, up front.
    public EnumConverter(Dictionary<Enum, string> enumToLabel, Dictionary<string, Enum> labelToEnum, Encoding encoding)
    {
        _enumToLabel = new(enumToLabel.Count);
        foreach (var kv in enumToLabel)
            _enumToLabel.Add((TEnum)kv.Key, kv.Value);

        _labelToEnum = new(labelToEnum.Count);
        foreach (var kv in labelToEnum)
            _labelToEnum.Add(kv.Key, (TEnum)kv.Value);

        _encoding = encoding;
    }

    public EnumConverter(Dictionary<TEnum, string> enumToLabel, Dictionary<string, TEnum> labelToEnum, Encoding encoding)
    {
        _enumToLabel = enumToLabel;
        _labelToEnum = labelToEnum;
        _encoding = encoding;
    }

    public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements)
    {
        bufferRequirements = BufferRequirements.Value;
        return format is DataFormat.Binary or DataFormat.Text;
    }

    public override Size GetSize(SizeContext context, TEnum value, ref object? writeState)
    {
        if (!_enumToLabel.TryGetValue(value, out var str))
            throw new InvalidCastException($"Can't write value {value} as enum {typeof(TEnum)}");

        return _encoding.GetByteCount(str);
    }

    protected override TEnum ReadCore(PgReader reader)
    {
        var str = _encoding.GetString(reader.ReadBytes(reader.CurrentRemaining));
        var success = _labelToEnum.TryGetValue(str, out var value);

        if (!success)
            throw new InvalidCastException($"Received enum value '{str}' from database which wasn't found on enum {typeof(TEnum)}");

        return value;
    }

    protected override void WriteCore(PgWriter writer, TEnum value)
    {
        if (!_enumToLabel.TryGetValue(value, out var str))
            throw new InvalidCastException($"Can't write value {value} as enum {typeof(TEnum)}");

        writer.WriteBytes(new ReadOnlySpan<byte>(_encoding.GetBytes(str)));
    }
}
/// <summary>
/// Converts PostgreSQL tsquery values: reads the wire-level prefix-order token stream into an
/// <see cref="NpgsqlTsQuery"/> tree and writes the tree back out in the same prefix order.
/// </summary>
sealed class TsQueryConverter<T> : PgStreamingConverter<T>
    where T : NpgsqlTsQuery
{
    readonly Encoding _encoding;

    public TsQueryConverter(Encoding encoding)
        => _encoding = encoding;

    public override T Read(PgReader reader)
        => (T)Read(async: false, reader, CancellationToken.None).GetAwaiter().GetResult();

    public override async ValueTask<T> ReadAsync(PgReader reader, CancellationToken cancellationToken = default)
        => (T)await Read(async: true, reader, cancellationToken).ConfigureAwait(false);

    async ValueTask<NpgsqlTsQuery> Read(bool async, PgReader reader, CancellationToken cancellationToken)
    {
        if (reader.ShouldBuffer(sizeof(int)))
            await reader.Buffer(async, sizeof(int), cancellationToken).ConfigureAwait(false);
        var numTokens = reader.ReadInt32();
        if (numTokens == 0)
            return new NpgsqlTsQueryEmpty();

        NpgsqlTsQuery? value = null;
        // Stack of parent nodes still waiting for children. Location: 0 = Not child, 1 = left, 2 = right.
        var nodes = new Stack<(NpgsqlTsQuery Node, int Location)>();

        for (var i = 0; i < numTokens; i++)
        {
            if (reader.ShouldBuffer(sizeof(byte)))
                await reader.Buffer(async, sizeof(byte), cancellationToken).ConfigureAwait(false);

            switch (reader.ReadByte())
            {
            case 1: // lexeme
                if (reader.ShouldBuffer(sizeof(byte) + sizeof(byte)))
                    await reader.Buffer(async, sizeof(byte) + sizeof(byte), cancellationToken).ConfigureAwait(false);
                var weight = (NpgsqlTsQueryLexeme.Weight)reader.ReadByte();
                var prefix = reader.ReadByte() != 0;

                var str = async
                    ? await reader.ReadNullTerminatedStringAsync(_encoding, cancellationToken).ConfigureAwait(false)
                    : reader.ReadNullTerminatedString(_encoding);
                InsertInTree(new NpgsqlTsQueryLexeme(str, weight, prefix), nodes, ref value);
                continue;

            case 2: // operation
                if (reader.ShouldBuffer(sizeof(byte)))
                    await reader.Buffer(async, sizeof(byte), cancellationToken).ConfigureAwait(false);
                var kind = (NpgsqlTsQuery.NodeKind)reader.ReadByte();

                NpgsqlTsQuery node;
                switch (kind)
                {
                case Not:
                    // Unary: push a single pending child slot.
                    node = new NpgsqlTsQueryNot(null!);
                    InsertInTree(node, nodes, ref value);
                    nodes.Push((node, 0));
                    continue;

                case And:
                    node = new NpgsqlTsQueryAnd(null!, null!);
                    break;
                case Or:
                    node = new NpgsqlTsQueryOr(null!, null!);
                    break;
                case Phrase:
                    if (reader.ShouldBuffer(sizeof(short)))
                        await reader.Buffer(async, sizeof(short), cancellationToken).ConfigureAwait(false);
                    node = new NpgsqlTsQueryFollowedBy(null!, reader.ReadInt16(), null!);
                    break;
                default:
                    throw new UnreachableException(
                        $"Internal Npgsql bug: unexpected value {kind} of enum {nameof(NpgsqlTsQuery.NodeKind)}. Please file a bug.");
                }

                InsertInTree(node, nodes, ref value);

                // Binary operators arrive before their operands; the right operand is serialized first.
                nodes.Push((node, 1));
                nodes.Push((node, 2));
                continue;

            case var tokenType:
                throw new UnreachableException(
                    $"Internal Npgsql bug: unexpected token type {tokenType} when reading tsquery. Please file a bug.");
            }
        }

        if (nodes.Count != 0)
            throw new UnreachableException("Internal Npgsql bug, please report.");

        return value!;

        static void InsertInTree(NpgsqlTsQuery node, Stack<(NpgsqlTsQuery Node, int Location)> nodes, ref NpgsqlTsQuery? value)
        {
            if (nodes.Count == 0)
                value = node;
            else
            {
                var parent = nodes.Pop();
                switch (parent.Location)
                {
                case 0:
                    ((NpgsqlTsQueryNot)parent.Node).Child = node;
                    break;
                case 1:
                    ((NpgsqlTsQueryBinOp)parent.Node).Left = node;
                    break;
                case 2:
                    ((NpgsqlTsQueryBinOp)parent.Node).Right = node;
                    break;
                default:
                    throw new UnreachableException("Internal Npgsql bug, please report.");
                }
            }
        }
    }

    public override Size GetSize(SizeContext context, T value, ref object? writeState)
        => value.Kind is Empty
            ? 4
            : 4 + GetNodeLength(value);

    // Per-token sizes: 1-byte token type + payload (lexeme: weight, prefix, text, NUL; operator: kind [+ distance]).
    int GetNodeLength(NpgsqlTsQuery node)
        => node.Kind switch
        {
            Lexeme when _encoding.GetByteCount(((NpgsqlTsQueryLexeme)node).Text) is var strLen
                => strLen > 2046
                    ? throw new InvalidCastException("Lexeme text too long. Must be at most 2046 encoded bytes.")
                    : 4 + strLen,
            And or Or => 2 + GetNodeLength(((NpgsqlTsQueryBinOp)node).Left) + GetNodeLength(((NpgsqlTsQueryBinOp)node).Right),
            Not => 2 + GetNodeLength(((NpgsqlTsQueryNot)node).Child),
            Empty => throw new InvalidOperationException("Empty tsquery nodes must be top-level"),

            // 2 additional bytes for uint16 phrase operator "distance" field.
            Phrase => 4 + GetNodeLength(((NpgsqlTsQueryBinOp)node).Left) + GetNodeLength(((NpgsqlTsQueryBinOp)node).Right),

            _ => throw new UnreachableException(
                $"Internal Npgsql bug: unexpected value {node.Kind} of enum {nameof(NpgsqlTsQuery.NodeKind)}. Please file a bug.")
        };

    public override void Write(PgWriter writer, T value)
        => Write(async: false, writer, value, CancellationToken.None).GetAwaiter().GetResult();

    public override ValueTask WriteAsync(PgWriter writer, T value, CancellationToken cancellationToken = default)
        => Write(async: true, writer, value, cancellationToken);

    async ValueTask Write(bool async, PgWriter writer, NpgsqlTsQuery value, CancellationToken cancellationToken)
    {
        var numTokens = GetTokenCount(value);

        if (writer.ShouldFlush(sizeof(int)))
            await writer.Flush(async, cancellationToken).ConfigureAwait(false);
        writer.WriteInt32(numTokens);

        if (numTokens is 0)
            return;

        await WriteCore(value).ConfigureAwait(false);

        async Task WriteCore(NpgsqlTsQuery node)
        {
            if (writer.ShouldFlush(sizeof(byte)))
                await writer.Flush(async, cancellationToken).ConfigureAwait(false);
            writer.WriteByte(node.Kind is Lexeme ? (byte)1 : (byte)2);

            if (node.Kind is Lexeme)
            {
                var lexemeNode = (NpgsqlTsQueryLexeme)node;

                if (writer.ShouldFlush(sizeof(byte) + sizeof(byte)))
                    await writer.Flush(async, cancellationToken).ConfigureAwait(false);

                writer.WriteByte((byte)lexemeNode.Weights);
                writer.WriteByte(lexemeNode.IsPrefixSearch ? (byte)1 : (byte)0);

                if (async)
                    await writer.WriteCharsAsync(lexemeNode.Text.AsMemory(), _encoding, cancellationToken).ConfigureAwait(false);
                else
                    writer.WriteChars(lexemeNode.Text.AsMemory().Span, _encoding);

                if (writer.ShouldFlush(sizeof(byte)))
                    await writer.Flush(async, cancellationToken).ConfigureAwait(false);

                // NUL terminator for the lexeme text.
                writer.WriteByte(0);
                return;
            }

            writer.WriteByte((byte)node.Kind);

            switch (node.Kind)
            {
            case Not:
                await WriteCore(((NpgsqlTsQueryNot)node).Child).ConfigureAwait(false);
                return;
            case Phrase:
                writer.WriteInt16(((NpgsqlTsQueryFollowedBy)node).Distance);
                break;
            }

            // The wire format serializes the right operand before the left one.
            await WriteCore(((NpgsqlTsQueryBinOp)node).Right).ConfigureAwait(false);
            await WriteCore(((NpgsqlTsQueryBinOp)node).Left).ConfigureAwait(false);
        }
    }

    int GetTokenCount(NpgsqlTsQuery node)
        => node.Kind switch
        {
            Lexeme => 1,
            And or Or or Phrase => 1 + GetTokenCount(((NpgsqlTsQueryBinOp)node).Left) + GetTokenCount(((NpgsqlTsQueryBinOp)node).Right),
            Not => 1 + GetTokenCount(((NpgsqlTsQueryNot)node).Child),
            Empty => 0,

            _ => throw new UnreachableException(
                $"Internal Npgsql bug: unexpected value {node.Kind} of enum {nameof(NpgsqlTsQuery.NodeKind)}. Please file a bug.")
        };
}
/// <summary>Converts PostgreSQL tsvector values to and from <see cref="NpgsqlTsVector"/>.</summary>
sealed class TsVectorConverter : PgStreamingConverter<NpgsqlTsVector>
{
    readonly Encoding _encoding;

    public TsVectorConverter(Encoding encoding)
        => _encoding = encoding;

    public override NpgsqlTsVector Read(PgReader reader)
        => Read(async: false, reader, CancellationToken.None).GetAwaiter().GetResult();

    public override ValueTask<NpgsqlTsVector> ReadAsync(PgReader reader, CancellationToken cancellationToken = default)
        => Read(async: true, reader, cancellationToken);

    async ValueTask<NpgsqlTsVector> Read(bool async, PgReader reader, CancellationToken cancellationToken)
    {
        if (reader.ShouldBuffer(sizeof(int)))
            await reader.Buffer(async, sizeof(int), cancellationToken).ConfigureAwait(false);

        var numLexemes = reader.ReadInt32();
        var lexemes = new List<NpgsqlTsVector.Lexeme>(numLexemes);

        for (var i = 0; i < numLexemes; i++)
        {
            var lexemeString = async
                ? await reader.ReadNullTerminatedStringAsync(_encoding, cancellationToken).ConfigureAwait(false)
                : reader.ReadNullTerminatedString(_encoding);

            if (reader.ShouldBuffer(sizeof(short)))
                await reader.Buffer(async, sizeof(short), cancellationToken).ConfigureAwait(false);
            var numPositions = reader.ReadInt16();

            if (numPositions == 0)
            {
                lexemes.Add(new NpgsqlTsVector.Lexeme(lexemeString, wordEntryPositions: null, noCopy: true));
                continue;
            }

            // There can only be a maximum of 256 positions, so we just buffer them all (256 * sizeof(short) = 512)
            if (numPositions > 256)
                throw new NpgsqlException($"Got {numPositions} lexeme positions when reading tsvector");

            if (reader.ShouldBuffer(numPositions * sizeof(short)))
                await reader.Buffer(async, numPositions * sizeof(short), cancellationToken).ConfigureAwait(false);

            var positions = new List<NpgsqlTsVector.Lexeme.WordEntryPos>(numPositions);

            for (var j = 0; j < numPositions; j++)
            {
                var wordEntryPos = reader.ReadInt16();
                positions.Add(new NpgsqlTsVector.Lexeme.WordEntryPos(wordEntryPos));
            }

            lexemes.Add(new NpgsqlTsVector.Lexeme(lexemeString, positions, noCopy: true));
        }

        return new NpgsqlTsVector(lexemes, noCheck: true);
    }

    // 4-byte lexeme count, then per lexeme: text + NUL terminator + int16 position count + int16 per position.
    public override Size GetSize(SizeContext context, NpgsqlTsVector value, ref object? writeState)
        => 4 + value.Sum(l => _encoding.GetByteCount(l.Text) + 1 + 2 + l.Count * 2);

    public override void Write(PgWriter writer, NpgsqlTsVector value)
        => Write(async: false, writer, value, CancellationToken.None).GetAwaiter().GetResult();

    public override ValueTask WriteAsync(PgWriter writer, NpgsqlTsVector value, CancellationToken cancellationToken = default)
        => Write(async: true, writer, value, cancellationToken);

    async ValueTask Write(bool async, PgWriter writer, NpgsqlTsVector value, CancellationToken cancellationToken)
    {
        if (writer.ShouldFlush(sizeof(int)))
            // Was an unconditional FlushAsync, which made the sync path (async: false) flush asynchronously.
            await writer.Flush(async, cancellationToken).ConfigureAwait(false);
        writer.WriteInt32(value.Count);

        foreach (var lexeme in value)
        {
            if (async)
                await writer.WriteCharsAsync(lexeme.Text.AsMemory(), _encoding, cancellationToken).ConfigureAwait(false);
            else
                writer.WriteChars(lexeme.Text.AsMemory().Span, _encoding);

            if (writer.ShouldFlush(sizeof(byte) + sizeof(short)))
                await writer.Flush(async, cancellationToken).ConfigureAwait(false);

            // NUL terminator for the lexeme text, then its position count.
            writer.WriteByte(0);
            writer.WriteInt16((short)lexeme.Count);

            for (var i = 0; i < lexeme.Count; i++)
            {
                if (writer.ShouldFlush(sizeof(short)))
                    await writer.Flush(async, cancellationToken).ConfigureAwait(false);

                writer.WriteInt16(lexeme[i].Value);
            }
        }
    }
}
override NpgsqlBox ReadCore(PgReader reader) + => new( + new NpgsqlPoint(reader.ReadDouble(), reader.ReadDouble()), + new NpgsqlPoint(reader.ReadDouble(), reader.ReadDouble())); + + protected override void WriteCore(PgWriter writer, NpgsqlBox value) + { + writer.WriteDouble(value.Right); + writer.WriteDouble(value.Top); + writer.WriteDouble(value.Left); + writer.WriteDouble(value.Bottom); + } +} diff --git a/src/Npgsql/Internal/Converters/Geometric/CircleConverter.cs b/src/Npgsql/Internal/Converters/Geometric/CircleConverter.cs new file mode 100644 index 0000000000..51eea75814 --- /dev/null +++ b/src/Npgsql/Internal/Converters/Geometric/CircleConverter.cs @@ -0,0 +1,23 @@ +using NpgsqlTypes; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class CircleConverter : PgBufferedConverter +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(double) * 3); + return format is DataFormat.Binary; + } + + protected override NpgsqlCircle ReadCore(PgReader reader) + => new(reader.ReadDouble(), reader.ReadDouble(), reader.ReadDouble()); + + protected override void WriteCore(PgWriter writer, NpgsqlCircle value) + { + writer.WriteDouble(value.X); + writer.WriteDouble(value.Y); + writer.WriteDouble(value.Radius); + } +} diff --git a/src/Npgsql/Internal/Converters/Geometric/LineConverter.cs b/src/Npgsql/Internal/Converters/Geometric/LineConverter.cs new file mode 100644 index 0000000000..17d89909b9 --- /dev/null +++ b/src/Npgsql/Internal/Converters/Geometric/LineConverter.cs @@ -0,0 +1,23 @@ +using NpgsqlTypes; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class LineConverter : PgBufferedConverter +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(double) * 3); + 
return format is DataFormat.Binary; + } + + protected override NpgsqlLine ReadCore(PgReader reader) + => new(reader.ReadDouble(), reader.ReadDouble(), reader.ReadDouble()); + + protected override void WriteCore(PgWriter writer, NpgsqlLine value) + { + writer.WriteDouble(value.A); + writer.WriteDouble(value.B); + writer.WriteDouble(value.C); + } +} diff --git a/src/Npgsql/Internal/Converters/Geometric/LineSegmentConverter.cs b/src/Npgsql/Internal/Converters/Geometric/LineSegmentConverter.cs new file mode 100644 index 0000000000..117a108379 --- /dev/null +++ b/src/Npgsql/Internal/Converters/Geometric/LineSegmentConverter.cs @@ -0,0 +1,24 @@ +using NpgsqlTypes; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class LineSegmentConverter : PgBufferedConverter +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(double) * 4); + return format is DataFormat.Binary; + } + + protected override NpgsqlLSeg ReadCore(PgReader reader) + => new(reader.ReadDouble(), reader.ReadDouble(), reader.ReadDouble(), reader.ReadDouble()); + + protected override void WriteCore(PgWriter writer, NpgsqlLSeg value) + { + writer.WriteDouble(value.Start.X); + writer.WriteDouble(value.Start.Y); + writer.WriteDouble(value.End.X); + writer.WriteDouble(value.End.Y); + } +} diff --git a/src/Npgsql/Internal/Converters/Geometric/PathConverter.cs b/src/Npgsql/Internal/Converters/Geometric/PathConverter.cs new file mode 100644 index 0000000000..c78ba84013 --- /dev/null +++ b/src/Npgsql/Internal/Converters/Geometric/PathConverter.cs @@ -0,0 +1,68 @@ +using System.Diagnostics; +using System.Threading; +using System.Threading.Tasks; +using NpgsqlTypes; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class PathConverter : PgStreamingConverter +{ + public override NpgsqlPath Read(PgReader reader) + => 
Read(async: false, reader, CancellationToken.None).GetAwaiter().GetResult(); + + public override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + => Read(async: true, reader, cancellationToken); + + async ValueTask Read(bool async, PgReader reader, CancellationToken cancellationToken) + { + if (reader.ShouldBuffer(sizeof(byte) + sizeof(int))) + await reader.Buffer(async, sizeof(byte) + sizeof(int), cancellationToken).ConfigureAwait(false); + + var open = reader.ReadByte() switch + { + 1 => false, + 0 => true, + _ => throw new UnreachableException("Error decoding binary geometric path: bad open byte") + }; + + var numPoints = reader.ReadInt32(); + var result = new NpgsqlPath(numPoints, open); + + for (var i = 0; i < numPoints; i++) + { + if (reader.ShouldBuffer(sizeof(double) * 2)) + await reader.Buffer(async, sizeof(double) * 2, cancellationToken).ConfigureAwait(false); + + result.Add(new NpgsqlPoint(reader.ReadDouble(), reader.ReadDouble())); + } + + return result; + } + + public override Size GetSize(SizeContext context, NpgsqlPath value, ref object? writeState) + => 5 + value.Count * sizeof(double) * 2; + + public override void Write(PgWriter writer, NpgsqlPath value) + => Write(async: false, writer, value, CancellationToken.None).GetAwaiter().GetResult(); + + public override ValueTask WriteAsync(PgWriter writer, NpgsqlPath value, CancellationToken cancellationToken = default) + => Write(async: true, writer, value, cancellationToken); + + async ValueTask Write(bool async, PgWriter writer, NpgsqlPath value, CancellationToken cancellationToken) + { + if (writer.ShouldFlush(sizeof(byte) + sizeof(int))) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + + writer.WriteByte((byte)(value.Open ?
0 : 1)); + writer.WriteInt32(value.Count); + + foreach (var p in value) + { + if (writer.ShouldFlush(sizeof(double) * 2)) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + writer.WriteDouble(p.X); + writer.WriteDouble(p.Y); + } + } +} diff --git a/src/Npgsql/Internal/Converters/Geometric/PointConverter.cs b/src/Npgsql/Internal/Converters/Geometric/PointConverter.cs new file mode 100644 index 0000000000..03e84c05bd --- /dev/null +++ b/src/Npgsql/Internal/Converters/Geometric/PointConverter.cs @@ -0,0 +1,22 @@ +using NpgsqlTypes; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class PointConverter : PgBufferedConverter +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(double) * 2); + return format is DataFormat.Binary; + } + + protected override NpgsqlPoint ReadCore(PgReader reader) + => new(reader.ReadDouble(), reader.ReadDouble()); + + protected override void WriteCore(PgWriter writer, NpgsqlPoint value) + { + writer.WriteDouble(value.X); + writer.WriteDouble(value.Y); + } +} diff --git a/src/Npgsql/Internal/Converters/Geometric/PolygonConverter.cs b/src/Npgsql/Internal/Converters/Geometric/PolygonConverter.cs new file mode 100644 index 0000000000..9a889b4323 --- /dev/null +++ b/src/Npgsql/Internal/Converters/Geometric/PolygonConverter.cs @@ -0,0 +1,55 @@ +using System.Threading; +using System.Threading.Tasks; +using NpgsqlTypes; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class PolygonConverter : PgStreamingConverter +{ + public override NpgsqlPolygon Read(PgReader reader) + => Read(async: false, reader, CancellationToken.None).GetAwaiter().GetResult(); + + public override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + => Read(async: true, reader, cancellationToken); + + async ValueTask Read(bool 
async, PgReader reader, CancellationToken cancellationToken) + { + if (reader.ShouldBuffer(sizeof(int))) + await reader.Buffer(async, sizeof(int), cancellationToken).ConfigureAwait(false); + var numPoints = reader.ReadInt32(); + var result = new NpgsqlPolygon(numPoints); + for (var i = 0; i < numPoints; i++) + { + if (reader.ShouldBuffer(sizeof(double) * 2)) + await reader.Buffer(async, sizeof(double) * 2, cancellationToken).ConfigureAwait(false); + result.Add(new NpgsqlPoint(reader.ReadDouble(), reader.ReadDouble())); + } + + return result; + } + + public override Size GetSize(SizeContext context, NpgsqlPolygon value, ref object? writeState) + => 4 + value.Count * sizeof(double) * 2; + + public override void Write(PgWriter writer, NpgsqlPolygon value) + => Write(async: false, writer, value, CancellationToken.None).GetAwaiter().GetResult(); + + public override ValueTask WriteAsync(PgWriter writer, NpgsqlPolygon value, CancellationToken cancellationToken = default) + => Write(async: true, writer, value, cancellationToken); + + async ValueTask Write(bool async, PgWriter writer, NpgsqlPolygon value, CancellationToken cancellationToken) + { + if (writer.ShouldFlush(sizeof(int))) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + writer.WriteInt32(value.Count); + + foreach (var p in value) + { + if (writer.ShouldFlush(sizeof(double) * 2)) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + writer.WriteDouble(p.X); + writer.WriteDouble(p.Y); + } + } +} diff --git a/src/Npgsql/Internal/Converters/HstoreConverter.cs b/src/Npgsql/Internal/Converters/HstoreConverter.cs new file mode 100644 index 0000000000..5f99fd128c --- /dev/null +++ b/src/Npgsql/Internal/Converters/HstoreConverter.cs @@ -0,0 +1,159 @@ +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace Npgsql.Internal.Converters; + +sealed class HstoreConverter : 
PgStreamingConverter where T : ICollection> +{ + readonly Encoding _encoding; + readonly Func>, T>? _convert; + + public HstoreConverter(Encoding encoding, Func>, T>? convert = null) + { + _encoding = encoding; + _convert = convert; + } + + public override T Read(PgReader reader) + => Read(async: false, reader, CancellationToken.None).Result; + + public override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + => Read(async: true, reader, cancellationToken); + + public override Size GetSize(SizeContext context, T value, ref object? writeState) + { + // Number of lengths (count, key length, value length). + var totalSize = sizeof(int) + value.Count * (sizeof(int) + sizeof(int)); + if (value.Count is 0) + return totalSize; + + var arrayPool = ArrayPool<(Size Size, object? WriteState)>.Shared; + var data = arrayPool.Rent(value.Count * 2); + + var i = 0; + foreach (var kv in value) + { + if (kv.Key is null) + throw new ArgumentException("Hstore doesn't support null keys", nameof(value)); + + var keySize = _encoding.GetByteCount(kv.Key); + var valueSize = kv.Value is null ? -1 : _encoding.GetByteCount(kv.Value); + totalSize += keySize + (valueSize is -1 ? 
0 : valueSize); + data[i] = (keySize, null); + data[i + 1] = (valueSize, null); + i += 2; + } + writeState = new WriteState + { + ArrayPool = arrayPool, + Data = new(data, 0, value.Count * 2), + AnyWriteState = false + }; + return totalSize; + } + + public override void Write(PgWriter writer, T value) + => Write(async: false, writer, value, CancellationToken.None).GetAwaiter().GetResult(); + + public override ValueTask WriteAsync(PgWriter writer, T value, CancellationToken cancellationToken = default) + => Write(async: true, writer, value, cancellationToken); + + async ValueTask Read(bool async, PgReader reader, CancellationToken cancellationToken) + { + if (reader.ShouldBuffer(sizeof(int))) + await reader.Buffer(async, sizeof(int), cancellationToken).ConfigureAwait(false); + + var count = reader.ReadInt32(); + + var result = typeof(T) == typeof(Dictionary) || typeof(T) == typeof(IDictionary) + ? (ICollection>)new Dictionary(count) + : new List>(count); + + for (var i = 0; i < count; i++) + { + if (reader.ShouldBuffer(sizeof(int))) + await reader.Buffer(async, sizeof(int), cancellationToken).ConfigureAwait(false); + var keySize = reader.ReadInt32(); + var key = _encoding.GetString(async + ? await reader.ReadBytesAsync(keySize, cancellationToken).ConfigureAwait(false) + : reader.ReadBytes(keySize) + ); + + if (reader.ShouldBuffer(sizeof(int))) + await reader.Buffer(async, sizeof(int), cancellationToken).ConfigureAwait(false); + var valueSize = reader.ReadInt32(); + string? value = null; + if (valueSize is not -1) + value = _encoding.GetString(async + ? await reader.ReadBytesAsync(valueSize, cancellationToken).ConfigureAwait(false) + : reader.ReadBytes(valueSize) + ); + + result.Add(new(key, value)); + } + + if (typeof(T) == typeof(Dictionary) || typeof(T) == typeof(IDictionary)) + return (T)result; + + return _convert is null ? 
throw new NotSupportedException() : _convert(result); + } + + async ValueTask Write(bool async, PgWriter writer, T value, CancellationToken cancellationToken) + { + if (writer.Current.WriteState is not WriteState && value.Count is not 0) + throw new InvalidCastException($"Invalid write state, expected {typeof(WriteState).FullName}."); + + // Number of lengths (count, key length, value length). + if (writer.ShouldFlush(sizeof(int))) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + writer.WriteInt32(value.Count); + + if (value.Count is 0 || writer.Current.WriteState is not WriteState writeState) + return; + + var data = writeState.Data; + var i = data.Offset; + foreach (var kv in value) + { + if (writer.ShouldFlush(sizeof(int))) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + + var (size, _) = data.Array![i]; + if (size.Kind is SizeKind.Unknown) + throw new NotImplementedException(); + + var length = size.Value; + writer.WriteInt32(length); + if (async) + await writer.WriteCharsAsync(kv.Key.AsMemory(), _encoding, cancellationToken).ConfigureAwait(false); + else + writer.WriteChars(kv.Key.AsSpan(), _encoding); + + if (writer.ShouldFlush(sizeof(int))) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + + var (valueSize, _) = data.Array![i + 1]; + if (valueSize.Kind is SizeKind.Unknown) + throw new NotImplementedException(); + + var valueLength = valueSize.Value; + writer.WriteInt32(valueLength); + if (valueLength is not -1) + { + if (async) + await writer.WriteCharsAsync(kv.Value.AsMemory(), _encoding, cancellationToken).ConfigureAwait(false); + else + writer.WriteChars(kv.Value.AsSpan(), _encoding); + } + i += 2; + } + } + + sealed class WriteState : MultiWriteState + { + } +} diff --git a/src/Npgsql/Internal/Converters/Internal/InternalCharConverter.cs b/src/Npgsql/Internal/Converters/Internal/InternalCharConverter.cs new file mode 100644 index 0000000000..5d00a26dcb --- /dev/null +++
b/src/Npgsql/Internal/Converters/Internal/InternalCharConverter.cs @@ -0,0 +1,43 @@ +using System; +using System.Numerics; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class InternalCharConverter : PgBufferedConverter +#if NET7_0_OR_GREATER + where T : INumberBase +#endif +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(byte)); + return format is DataFormat.Binary; + } + +#if NET7_0_OR_GREATER + protected override T ReadCore(PgReader reader) => T.CreateChecked(reader.ReadByte()); + protected override void WriteCore(PgWriter writer, T value) => writer.WriteByte(byte.CreateChecked(value)); +#else + protected override T ReadCore(PgReader reader) + { + var value = reader.ReadByte(); + if (typeof(byte) == typeof(T)) + return (T)(object)value; + if (typeof(char) == typeof(T)) + return (T)(object)(char)value; + + throw new NotSupportedException(); + } + + protected override void WriteCore(PgWriter writer, T value) + { + if (typeof(byte) == typeof(T)) + writer.WriteByte((byte)(object)value!); + else if (typeof(char) == typeof(T)) + writer.WriteByte(checked((byte)(char)(object)value!)); + else + throw new NotSupportedException(); + } +#endif +} diff --git a/src/Npgsql/Internal/Converters/Internal/PgLsnConverter.cs b/src/Npgsql/Internal/Converters/Internal/PgLsnConverter.cs new file mode 100644 index 0000000000..96730c857a --- /dev/null +++ b/src/Npgsql/Internal/Converters/Internal/PgLsnConverter.cs @@ -0,0 +1,15 @@ +using NpgsqlTypes; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class PgLsnConverter : PgBufferedConverter +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(ulong)); + return format is DataFormat.Binary; + } + protected override 
NpgsqlLogSequenceNumber ReadCore(PgReader reader) => new(reader.ReadUInt64()); + protected override void WriteCore(PgWriter writer, NpgsqlLogSequenceNumber value) => writer.WriteUInt64((ulong)value); +} diff --git a/src/Npgsql/Internal/Converters/Internal/TidConverter.cs b/src/Npgsql/Internal/Converters/Internal/TidConverter.cs new file mode 100644 index 0000000000..747d98fe17 --- /dev/null +++ b/src/Npgsql/Internal/Converters/Internal/TidConverter.cs @@ -0,0 +1,19 @@ +using NpgsqlTypes; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class TidConverter : PgBufferedConverter +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(uint) + sizeof(ushort)); + return format is DataFormat.Binary; + } + protected override NpgsqlTid ReadCore(PgReader reader) => new(reader.ReadUInt32(), reader.ReadUInt16()); + protected override void WriteCore(PgWriter writer, NpgsqlTid value) + { + writer.WriteUInt32(value.BlockNumber); + writer.WriteUInt16(value.OffsetNumber); + } +} diff --git a/src/Npgsql/Internal/Converters/Internal/UInt32Converter.cs b/src/Npgsql/Internal/Converters/Internal/UInt32Converter.cs new file mode 100644 index 0000000000..92061b1fd2 --- /dev/null +++ b/src/Npgsql/Internal/Converters/Internal/UInt32Converter.cs @@ -0,0 +1,13 @@ +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class UInt32Converter : PgBufferedConverter +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(uint)); + return format is DataFormat.Binary; + } + protected override uint ReadCore(PgReader reader) => reader.ReadUInt32(); + protected override void WriteCore(PgWriter writer, uint value) => writer.WriteUInt32(value); +} diff --git 
a/src/Npgsql/Internal/Converters/Internal/UInt64Converter.cs b/src/Npgsql/Internal/Converters/Internal/UInt64Converter.cs new file mode 100644 index 0000000000..fcf5e3695a --- /dev/null +++ b/src/Npgsql/Internal/Converters/Internal/UInt64Converter.cs @@ -0,0 +1,13 @@ +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class UInt64Converter : PgBufferedConverter +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(ulong)); + return format is DataFormat.Binary; + } + protected override ulong ReadCore(PgReader reader) => reader.ReadUInt64(); + protected override void WriteCore(PgWriter writer, ulong value) => writer.WriteUInt64(value); +} diff --git a/src/Npgsql/Internal/Converters/Internal/VoidConverter.cs b/src/Npgsql/Internal/Converters/Internal/VoidConverter.cs new file mode 100644 index 0000000000..45b48df5b5 --- /dev/null +++ b/src/Npgsql/Internal/Converters/Internal/VoidConverter.cs @@ -0,0 +1,13 @@ +using System; + +namespace Npgsql.Internal.Converters.Internal; + +// Void is not a value so we read it as a null reference, not a DBNull. +sealed class VoidConverter : PgBufferedConverter +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + => CanConvertBufferedDefault(DataFormat.Binary, out bufferRequirements); // Text is identical + + protected override object? ReadCore(PgReader reader) => null; + protected override void WriteCore(PgWriter writer, object? 
value) => throw new NotSupportedException(); +} diff --git a/src/Npgsql/Internal/Converters/MoneyConverter.cs b/src/Npgsql/Internal/Converters/MoneyConverter.cs new file mode 100644 index 0000000000..8443acedc3 --- /dev/null +++ b/src/Npgsql/Internal/Converters/MoneyConverter.cs @@ -0,0 +1,74 @@ +using System; +using System.Numerics; + +namespace Npgsql.Internal.Converters; + +sealed class MoneyConverter : PgBufferedConverter +#if NET7_0_OR_GREATER + where T : INumberBase +#endif +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long)); + return format is DataFormat.Binary; + } + protected override T ReadCore(PgReader reader) => ConvertTo(new PgMoney(reader.ReadInt64())); + protected override void WriteCore(PgWriter writer, T value) => writer.WriteInt64(ConvertFrom(value).GetValue()); + + static PgMoney ConvertFrom(T value) + { +#if !NET7_0_OR_GREATER + if (typeof(short) == typeof(T)) + return new PgMoney((decimal)(short)(object)value!); + if (typeof(int) == typeof(T)) + return new PgMoney((decimal)(int)(object)value!); + if (typeof(long) == typeof(T)) + return new PgMoney((decimal)(long)(object)value!); + + if (typeof(byte) == typeof(T)) + return new PgMoney((decimal)(byte)(object)value!); + if (typeof(sbyte) == typeof(T)) + return new PgMoney((decimal)(sbyte)(object)value!); + + if (typeof(float) == typeof(T)) + return new PgMoney((decimal)(float)(object)value!); + if (typeof(double) == typeof(T)) + return new PgMoney((decimal)(double)(object)value!); + if (typeof(decimal) == typeof(T)) + return new PgMoney((decimal)(object)value!); + + throw new NotSupportedException(); +#else + return new PgMoney(decimal.CreateChecked(value)); +#endif + } + + static T ConvertTo(PgMoney money) + { +#if !NET7_0_OR_GREATER + if (typeof(short) == typeof(T)) + return (T)(object)(short)money.ToDecimal(); + if (typeof(int) == typeof(T)) + return 
(T)(object)(int)money.ToDecimal(); + if (typeof(long) == typeof(T)) + return (T)(object)(long)money.ToDecimal(); + + if (typeof(byte) == typeof(T)) + return (T)(object)(byte)money.ToDecimal(); + if (typeof(sbyte) == typeof(T)) + return (T)(object)(sbyte)money.ToDecimal(); + + if (typeof(float) == typeof(T)) + return (T)(object)(float)money.ToDecimal(); + if (typeof(double) == typeof(T)) + return (T)(object)(double)money.ToDecimal(); + if (typeof(decimal) == typeof(T)) + return (T)(object)money.ToDecimal(); + + throw new NotSupportedException(); +#else + return T.CreateChecked(money.ToDecimal()); +#endif + } +} diff --git a/src/Npgsql/Internal/Converters/MultirangeConverter.cs b/src/Npgsql/Internal/Converters/MultirangeConverter.cs new file mode 100644 index 0000000000..524901977b --- /dev/null +++ b/src/Npgsql/Internal/Converters/MultirangeConverter.cs @@ -0,0 +1,142 @@ +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Diagnostics; +using System.Threading; +using System.Threading.Tasks; + +namespace Npgsql.Internal.Converters; + +sealed class MultirangeConverter : PgStreamingConverter + where T : IList + where TRange : notnull +{ + readonly PgConverter _rangeConverter; + readonly BufferRequirements _rangeRequirements; + + static MultirangeConverter() + => Debug.Assert(typeof(T).IsArray || typeof(T).IsGenericType && typeof(T).GetGenericTypeDefinition() == typeof(List<>)); + + public MultirangeConverter(PgConverter rangeConverter) + { + if (!rangeConverter.CanConvert(DataFormat.Binary, out var bufferRequirements)) + throw new NotSupportedException("Range subtype converter has to support the binary format to be compatible."); + _rangeRequirements = bufferRequirements; + _rangeConverter = rangeConverter; + } + + public override T Read(PgReader reader) + => Read(async: false, reader, CancellationToken.None).GetAwaiter().GetResult(); + + public override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = 
default) + => Read(async: true, reader, cancellationToken); + + public async ValueTask Read(bool async, PgReader reader, CancellationToken cancellationToken) + { + if (reader.ShouldBuffer(sizeof(int))) + await reader.Buffer(async, sizeof(int), cancellationToken).ConfigureAwait(false); + var numRanges = reader.ReadInt32(); + var multirange = (T)(object)(typeof(T).IsArray ? new TRange[numRanges] : new List()); + + for (var i = 0; i < numRanges; i++) + { + if (reader.ShouldBuffer(sizeof(int))) + await reader.Buffer(async, sizeof(int), cancellationToken).ConfigureAwait(false); + var length = reader.ReadInt32(); + Debug.Assert(length != -1); + + var scope = await reader.BeginNestedRead(async, length, _rangeRequirements.Read, cancellationToken).ConfigureAwait(false); + try + { + var range = async + ? await _rangeConverter.ReadAsync(reader, cancellationToken).ConfigureAwait(false) + // ReSharper disable once MethodHasAsyncOverloadWithCancellation + : _rangeConverter.Read(reader); + + if (typeof(T).IsArray) + multirange[i] = range; + else + multirange.Add(range); + } + finally + { + if (async) + await scope.DisposeAsync().ConfigureAwait(false); + else + scope.Dispose(); + } + } + + return multirange; + } + + public override Size GetSize(SizeContext context, T value, ref object? writeState) + { + var arrayPool = ArrayPool<(Size Size, object? WriteState)>.Shared; + var data = arrayPool.Rent(value.Count); + + var totalSize = Size.Create(sizeof(int) + sizeof(int) * value.Count); + var anyWriteState = false; + for (var i = 0; i < value.Count; i++) + { + object? innerState = null; + var rangeSize = _rangeConverter.GetSizeOrDbNull(context.Format, _rangeRequirements.Write, value[i], ref innerState); + anyWriteState = anyWriteState || innerState is not null; + // Ranges should never be NULL. 
+ Debug.Assert(rangeSize.HasValue); + data[i] = new(rangeSize.Value, innerState); + totalSize = totalSize.Combine(rangeSize.Value); + } + + writeState = new WriteState + { + ArrayPool = arrayPool, + Data = new(data, 0, value.Count), + AnyWriteState = anyWriteState + }; + return totalSize; + } + + public override void Write(PgWriter writer, T value) + => Write(async: false, writer, value, CancellationToken.None).GetAwaiter().GetResult(); + + public override ValueTask WriteAsync(PgWriter writer, T value, CancellationToken cancellationToken = default) + => Write(async: true, writer, value, cancellationToken); + + async ValueTask Write(bool async, PgWriter writer, T value, CancellationToken cancellationToken) + { + if (writer.Current.WriteState is not WriteState writeState) + throw new InvalidCastException($"Invalid state {writer.Current.WriteState?.GetType().FullName}."); + + if (writer.ShouldFlush(sizeof(int))) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + writer.WriteInt32(value.Count); + + var data = writeState.Data.Array!; + for (var i = 0; i < value.Count; i++) + { + if (writer.ShouldFlush(sizeof(int))) // Length + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + + var (size, state) = data[i]; + if (size.Kind is SizeKind.Unknown) + throw new NotImplementedException(); + + var length = size.Value; + writer.WriteInt32(length); + if (length != -1) + { + using var _ = await writer.BeginNestedWrite(async, _rangeRequirements.Write, length, state, cancellationToken).ConfigureAwait(false); + if (async) + await _rangeConverter.WriteAsync(writer, value[i], cancellationToken).ConfigureAwait(false); + else + // ReSharper disable once MethodHasAsyncOverloadWithCancellation + _rangeConverter.Write(writer, value[i]); + } + } + } + + sealed class WriteState : MultiWriteState + { + } +} diff --git a/src/Npgsql/Internal/Converters/Networking/IPAddressConverter.cs b/src/Npgsql/Internal/Converters/Networking/IPAddressConverter.cs new 
file mode 100644 index 0000000000..9050f36f16 --- /dev/null +++ b/src/Npgsql/Internal/Converters/Networking/IPAddressConverter.cs @@ -0,0 +1,23 @@ +using System.Net; +using System.Net.Sockets; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class IPAddressConverter : PgBufferedConverter +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + => CanConvertBufferedDefault(format, out bufferRequirements); + + public override Size GetSize(SizeContext context, IPAddress value, ref object? writeState) + => NpgsqlInetConverter.GetSizeImpl(context, value, ref writeState); + + protected override IPAddress ReadCore(PgReader reader) + => NpgsqlInetConverter.ReadImpl(reader, shouldBeCidr: false).Address; + + protected override void WriteCore(PgWriter writer, IPAddress value) + => NpgsqlInetConverter.WriteImpl( + writer, + (value, (byte)(value.AddressFamily == AddressFamily.InterNetwork ? 32 : 128)), + isCidr: false); +} diff --git a/src/Npgsql/Internal/Converters/Networking/MacaddrConverter.cs b/src/Npgsql/Internal/Converters/Networking/MacaddrConverter.cs new file mode 100644 index 0000000000..dd8aac78bc --- /dev/null +++ b/src/Npgsql/Internal/Converters/Networking/MacaddrConverter.cs @@ -0,0 +1,40 @@ +using System; +using System.Diagnostics; +using System.Net.NetworkInformation; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class MacaddrConverter : PgBufferedConverter +{ + readonly bool _macaddr8; + + public MacaddrConverter(bool macaddr8) => _macaddr8 = macaddr8; + + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = _macaddr8 ? BufferRequirements.Create(Size.CreateUpperBound(8)) : BufferRequirements.CreateFixedSize(6); + return format is DataFormat.Binary; + } + + public override Size GetSize(SizeContext context, PhysicalAddress value, ref object? 
writeState) + => value.GetAddressBytes().Length; + + protected override PhysicalAddress ReadCore(PgReader reader) + { + var len = reader.CurrentRemaining; + Debug.Assert(len is 6 or 8); + + var bytes = new byte[len]; + reader.Read(bytes); + return new PhysicalAddress(bytes); + } + + protected override void WriteCore(PgWriter writer, PhysicalAddress value) + { + var bytes = value.GetAddressBytes(); + if (!_macaddr8 && bytes.Length is not 6) + throw new ArgumentException("A macaddr value must be 6 bytes long."); + writer.WriteBytes(bytes); + } +} diff --git a/src/Npgsql/Internal/Converters/Networking/NpgsqlCidrConverter.cs b/src/Npgsql/Internal/Converters/Networking/NpgsqlCidrConverter.cs new file mode 100644 index 0000000000..249ec9a68f --- /dev/null +++ b/src/Npgsql/Internal/Converters/Networking/NpgsqlCidrConverter.cs @@ -0,0 +1,22 @@ +using NpgsqlTypes; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class NpgsqlCidrConverter : PgBufferedConverter +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + => CanConvertBufferedDefault(format, out bufferRequirements); + + public override Size GetSize(SizeContext context, NpgsqlCidr value, ref object? 
writeState) + => NpgsqlInetConverter.GetSizeImpl(context, value.Address, ref writeState); + + protected override NpgsqlCidr ReadCore(PgReader reader) + { + var (ip, netmask) = NpgsqlInetConverter.ReadImpl(reader, shouldBeCidr: true); + return new(ip, netmask); + } + + protected override void WriteCore(PgWriter writer, NpgsqlCidr value) + => NpgsqlInetConverter.WriteImpl(writer, (value.Address, value.Netmask), isCidr: false); +} diff --git a/src/Npgsql/Internal/Converters/Networking/NpgsqlInetConverter.cs b/src/Npgsql/Internal/Converters/Networking/NpgsqlInetConverter.cs new file mode 100644 index 0000000000..f3af04e80a --- /dev/null +++ b/src/Npgsql/Internal/Converters/Networking/NpgsqlInetConverter.cs @@ -0,0 +1,73 @@ +using System; +using System.Diagnostics; +using System.Net; +using System.Net.Sockets; +using NpgsqlTypes; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class NpgsqlInetConverter : PgBufferedConverter +{ + const byte IPv4 = 2; + const byte IPv6 = 3; + + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + => CanConvertBufferedDefault(format, out bufferRequirements); + + public override Size GetSize(SizeContext context, NpgsqlInet value, ref object? writeState) + => GetSizeImpl(context, value.Address, ref writeState); + + internal static Size GetSizeImpl(SizeContext context, IPAddress ipAddress, ref object? 
writeState) + => ipAddress.AddressFamily switch + { + AddressFamily.InterNetwork => 8, + AddressFamily.InterNetworkV6 => 20, + _ => throw new InvalidCastException( + $"Can't handle IPAddress with AddressFamily {ipAddress.AddressFamily}, only InterNetwork or InterNetworkV6!") + }; + + protected override NpgsqlInet ReadCore(PgReader reader) + { + var (ip, netmask) = ReadImpl(reader, shouldBeCidr: false); + return new(ip, netmask); + } + + internal static (IPAddress Address, byte Netmask) ReadImpl(PgReader reader, bool shouldBeCidr) + { + _ = reader.ReadByte(); // addressFamily + var mask = reader.ReadByte(); // mask + + var isCidr = reader.ReadByte() == 1; + Debug.Assert(isCidr == shouldBeCidr); + + var numBytes = reader.ReadByte(); + Span bytes = stackalloc byte[numBytes]; + reader.Read(bytes); +#if NETSTANDARD2_0 + return (new IPAddress(bytes.ToArray()), mask); +#else + return (new IPAddress(bytes), mask); +#endif + } + + protected override void WriteCore(PgWriter writer, NpgsqlInet value) + => WriteImpl(writer, (value.Address, value.Netmask), isCidr: false); + + internal static void WriteImpl(PgWriter writer, (IPAddress Address, byte Netmask) value, bool isCidr) + { + writer.WriteByte(value.Address.AddressFamily switch + { + AddressFamily.InterNetwork => IPv4, + AddressFamily.InterNetworkV6 => IPv6, + _ => throw new InvalidCastException( + $"Can't handle IPAddress with AddressFamily {value.Address.AddressFamily}, only InterNetwork or InterNetworkV6!") + }); + + writer.WriteByte(value.Netmask); + writer.WriteByte((byte)(isCidr ? 
1 : 0)); // Ignored on server side + var bytes = value.Address.GetAddressBytes(); + writer.WriteByte((byte)bytes.Length); + writer.WriteBytes(bytes); + } +} diff --git a/src/Npgsql/Internal/Converters/NullableConverter.cs b/src/Npgsql/Internal/Converters/NullableConverter.cs new file mode 100644 index 0000000000..b3f8a8a0b2 --- /dev/null +++ b/src/Npgsql/Internal/Converters/NullableConverter.cs @@ -0,0 +1,60 @@ +using System.Diagnostics.CodeAnalysis; +using System.Threading; +using System.Threading.Tasks; +using Npgsql.Internal.Postgres; + +namespace Npgsql.Internal.Converters; + +// NULL writing is always responsibility of the caller writing the length, so there is not much we do here. +/// Special value converter to be able to use struct converters as System.Nullable converters, it delegates all behavior to the effective converter. +sealed class NullableConverter : PgConverter where T : struct +{ + readonly PgConverter _effectiveConverter; + public NullableConverter(PgConverter effectiveConverter) + : base(effectiveConverter.DbNullPredicateKind is DbNullPredicate.Custom) + => _effectiveConverter = effectiveConverter; + + protected override bool IsDbNullValue(T? value) + => value is null || _effectiveConverter.IsDbNull(value.GetValueOrDefault()); + + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + => _effectiveConverter.CanConvert(format, out bufferRequirements); + + public override T? Read(PgReader reader) + => _effectiveConverter.Read(reader); + + public override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + => this.ComposingReadAsync(_effectiveConverter, reader, cancellationToken); + + public override Size GetSize(SizeContext context, [DisallowNull]T? value, ref object? writeState) + => _effectiveConverter.GetSize(context, value.GetValueOrDefault(), ref writeState); + + public override void Write(PgWriter writer, T? 
value) + => _effectiveConverter.Write(writer, value.GetValueOrDefault()); + + public override ValueTask WriteAsync(PgWriter writer, T? value, CancellationToken cancellationToken = default) + => _effectiveConverter.WriteAsync(writer, value.GetValueOrDefault(), cancellationToken); + + internal override ValueTask ReadAsObject(bool async, PgReader reader, CancellationToken cancellationToken) + => _effectiveConverter.ReadAsObject(async, reader, cancellationToken); + + internal override ValueTask WriteAsObject(bool async, PgWriter writer, object value, CancellationToken cancellationToken) + => _effectiveConverter.WriteAsObject(async, writer, value, cancellationToken); +} + +sealed class NullableConverterResolver : PgComposingConverterResolver where T : struct +{ + public NullableConverterResolver(PgResolverTypeInfo effectiveTypeInfo) + : base(effectiveTypeInfo.PgTypeId, effectiveTypeInfo) { } + + protected override PgTypeId GetEffectivePgTypeId(PgTypeId pgTypeId) => pgTypeId; + protected override PgTypeId GetPgTypeId(PgTypeId effectivePgTypeId) => effectivePgTypeId; + + protected override PgConverter CreateConverter(PgConverterResolution effectiveResolution) + => new NullableConverter(effectiveResolution.GetConverter()); + + protected override PgConverterResolution? GetEffectiveResolution(T? value, PgTypeId? expectedEffectivePgTypeId) + => value is null + ? 
EffectiveTypeInfo.GetDefaultResolution(expectedEffectivePgTypeId) + : EffectiveTypeInfo.GetResolution(value.GetValueOrDefault(), expectedEffectivePgTypeId); +} diff --git a/src/Npgsql/Internal/Converters/ObjectArrayRecordConverter.cs b/src/Npgsql/Internal/Converters/ObjectArrayRecordConverter.cs new file mode 100644 index 0000000000..9b028c02cc --- /dev/null +++ b/src/Npgsql/Internal/Converters/ObjectArrayRecordConverter.cs @@ -0,0 +1,79 @@ +using System; +using System.Threading; +using System.Threading.Tasks; +using Npgsql.Internal.Postgres; + +namespace Npgsql.Internal.Converters; + +sealed class ObjectArrayRecordConverter : PgStreamingConverter +{ + readonly PgSerializerOptions _serializerOptions; + readonly Func? _factory; + + public ObjectArrayRecordConverter(PgSerializerOptions serializerOptions, Func? factory = null) + { + _serializerOptions = serializerOptions; + _factory = factory; + } + + public override T Read(PgReader reader) + => Read(async: false, reader, CancellationToken.None).GetAwaiter().GetResult(); + + public override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + => Read(async: true, reader, cancellationToken); + + async ValueTask Read(bool async, PgReader reader, CancellationToken cancellationToken) + { + if (reader.ShouldBuffer(sizeof(int))) + await reader.Buffer(async, sizeof(int), cancellationToken).ConfigureAwait(false); + var fieldCount = reader.ReadInt32(); + var result = new object[fieldCount]; + for (var i = 0; i < fieldCount; i++) + { + if (reader.ShouldBuffer(sizeof(uint) + sizeof(int))) + await reader.Buffer(async, sizeof(uint) + sizeof(int), cancellationToken).ConfigureAwait(false); + + var typeOid = reader.ReadUInt32(); + var length = reader.ReadInt32(); + + // Note that we leave .NET nulls in the object array rather than DBNull. + if (length == -1) + continue; + + var postgresType = + _serializerOptions.DatabaseInfo.GetPostgresType(typeOid).GetRepresentationalType() + ?? 
throw new NotSupportedException($"Reading isn't supported for record field {i} (unknown type OID {typeOid}"); + + var typeInfo = _serializerOptions.GetObjectOrDefaultTypeInfo(postgresType) + ?? throw new NotSupportedException( + $"Reading isn't supported for record field {i} (PG type '{postgresType.DisplayName}'"); + var resolution = typeInfo.GetConcreteResolution(); + if (typeInfo.GetBufferRequirements(resolution.Converter, DataFormat.Binary) is not { } bufferRequirements) + throw new NotSupportedException($"Resolved record field converter '{resolution.Converter.GetType()}' has to support the binary format to be compatible."); + + var scope = await reader.BeginNestedRead(async, length, bufferRequirements.Read, cancellationToken).ConfigureAwait(false); + try + { + result[i] = await resolution.Converter.ReadAsObject(async, reader, cancellationToken).ConfigureAwait(false); + } + finally + { + if (async) + await scope.DisposeAsync().ConfigureAwait(false); + else + scope.Dispose(); + } + } + + return _factory is null ? (T)(object)result : _factory(result); + } + + public override Size GetSize(SizeContext context, T value, ref object? 
writeState) + => throw new NotSupportedException(); + + public override void Write(PgWriter writer, T value) + => throw new NotSupportedException(); + + public override ValueTask WriteAsync(PgWriter writer, T value, CancellationToken cancellationToken = default) + => throw new NotSupportedException(); +} diff --git a/src/Npgsql/Internal/Converters/PolymorphicConverterResolver.cs b/src/Npgsql/Internal/Converters/PolymorphicConverterResolver.cs new file mode 100644 index 0000000000..7c78e34a24 --- /dev/null +++ b/src/Npgsql/Internal/Converters/PolymorphicConverterResolver.cs @@ -0,0 +1,68 @@ +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using Npgsql.Internal.Postgres; + +namespace Npgsql.Internal.Converters; + +abstract class PolymorphicConverterResolver : PgConverterResolver +{ + protected PolymorphicConverterResolver(PgTypeId pgTypeId) => PgTypeId = pgTypeId; + + protected PgTypeId PgTypeId { get; } + + protected abstract PgConverter Get(Field? field); + + public sealed override PgConverterResolution GetDefault(PgTypeId? pgTypeId) + { + if (pgTypeId is not null && pgTypeId != PgTypeId) + throw CreateUnsupportedPgTypeIdException(pgTypeId.Value); + + return new(Get(null), PgTypeId); + } + + public sealed override PgConverterResolution? Get(TBase? value, PgTypeId? expectedPgTypeId) + => new(Get(null), PgTypeId); + + public sealed override PgConverterResolution Get(Field field) + { + if (field.PgTypeId != PgTypeId) + throw CreateUnsupportedPgTypeIdException(field.PgTypeId); + + var converter = Get(field); + return new(converter, PgTypeId); + } +} + +// Many ways to achieve strongly typed composition on top of a polymorphic element type. +// Including pushing construction through a GVM visitor pattern on the element handler, +// manual reimplementation of the element logic in the array resolver, and other ways. +// This one however is by far the most lightweight on both the implementation duplication and code bloat axes. 
+sealed class ArrayPolymorphicConverterResolver : PolymorphicConverterResolver +{ + readonly PgResolverTypeInfo _elemTypeInfo; + readonly Func _elemToArrayConverterFactory; + readonly PgTypeId _elemPgTypeId; + readonly ConcurrentDictionary _converterCache = new(ReferenceEqualityComparer.Instance); + + public ArrayPolymorphicConverterResolver(PgTypeId pgTypeId, PgResolverTypeInfo elemTypeInfo, Func elemToArrayConverterFactory) + : base(pgTypeId) + { + if (elemTypeInfo.PgTypeId is null) + throw new ArgumentException("elemTypeInfo.PgTypeId must be non-null.", nameof(elemTypeInfo)); + + _elemTypeInfo = elemTypeInfo; + _elemToArrayConverterFactory = elemToArrayConverterFactory; + _elemPgTypeId = elemTypeInfo.PgTypeId!.Value; + } + + protected override PgConverter Get(Field? maybeField) + { + var elemResolution = maybeField is { } field + ? _elemTypeInfo.GetResolution(field with { PgTypeId = _elemPgTypeId }) + : _elemTypeInfo.GetDefaultResolution(_elemPgTypeId); + + (Func Factory, PgConverterResolution Resolution) state = (_elemToArrayConverterFactory, elemResolution); + return _converterCache.GetOrAdd(elemResolution.Converter, static (_, state) => state.Factory(state.Resolution), state); + } +} diff --git a/src/Npgsql/Internal/Converters/Primitive/BoolConverter.cs b/src/Npgsql/Internal/Converters/Primitive/BoolConverter.cs new file mode 100644 index 0000000000..196877ad0e --- /dev/null +++ b/src/Npgsql/Internal/Converters/Primitive/BoolConverter.cs @@ -0,0 +1,13 @@ +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class BoolConverter : PgBufferedConverter +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(byte)); + return format is DataFormat.Binary; + } + protected override bool ReadCore(PgReader reader) => reader.ReadByte() is not 0; + protected override void WriteCore(PgWriter writer, bool value) => 
writer.WriteByte((byte)(value ? 1 : 0)); +} diff --git a/src/Npgsql/Internal/Converters/Primitive/ByteaConverters.cs b/src/Npgsql/Internal/Converters/Primitive/ByteaConverters.cs new file mode 100644 index 0000000000..1d2b1ce531 --- /dev/null +++ b/src/Npgsql/Internal/Converters/Primitive/ByteaConverters.cs @@ -0,0 +1,124 @@ +using System; +using System.Diagnostics; +using System.IO; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Threading; +using System.Threading.Tasks; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +abstract class ByteaConverters : PgStreamingConverter +{ + public override T Read(PgReader reader) + => Read(async: false, reader, CancellationToken.None).Result; + + public override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + => Read(async: true, reader, cancellationToken); + + public override Size GetSize(SizeContext context, T value, ref object? writeState) + => ConvertTo(value).Length; + + public override void Write(PgWriter writer, T value) + => writer.WriteBytes(ConvertTo(value).Span); + + public override ValueTask WriteAsync(PgWriter writer, T value, CancellationToken cancellationToken = default) + => writer.WriteBytesAsync(ConvertTo(value), cancellationToken); + +#if NET6_0_OR_GREATER + [AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder<>))] +#endif + async ValueTask Read(bool async, PgReader reader, CancellationToken cancellationToken) + { + var bytes = new byte[reader.CurrentRemaining]; + if (async) + await reader.ReadBytesAsync(bytes, cancellationToken).ConfigureAwait(false); + else + reader.ReadBytes(bytes); + + return ConvertFrom(new(bytes)); + } + + protected abstract Memory ConvertTo(T value); + protected abstract T ConvertFrom(Memory value); +} + +sealed class ArraySegmentByteaConverter : ByteaConverters> +{ + protected override Memory ConvertTo(ArraySegment value) => value; + protected override 
ArraySegment ConvertFrom(Memory value) + => MemoryMarshal.TryGetArray(value, out var segment) + ? segment + : throw new UnreachableException("Expected array-backed memory"); +} + +sealed class ArrayByteaConverter : PgStreamingConverter +{ + public override byte[] Read(PgReader reader) + { + var bytes = new byte[reader.CurrentRemaining]; + reader.ReadBytes(bytes); + return bytes; + } + + public override async ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + { + var bytes = new byte[reader.CurrentRemaining]; + await reader.ReadBytesAsync(bytes, cancellationToken).ConfigureAwait(false); + return bytes; + } + + public override Size GetSize(SizeContext context, byte[] value, ref object? writeState) + => value.Length; + + public override void Write(PgWriter writer, byte[] value) + => writer.WriteBytes(value); + + public override ValueTask WriteAsync(PgWriter writer, byte[] value, CancellationToken cancellationToken = default) + => writer.WriteBytesAsync(value, cancellationToken); +} + +sealed class ReadOnlyMemoryByteaConverter : ByteaConverters> +{ + protected override Memory ConvertTo(ReadOnlyMemory value) => MemoryMarshal.AsMemory(value); + protected override ReadOnlyMemory ConvertFrom(Memory value) => value; +} + +sealed class MemoryByteaConverter : ByteaConverters> +{ + protected override Memory ConvertTo(Memory value) => value; + protected override Memory ConvertFrom(Memory value) => value; +} + +sealed class StreamByteaConverter : PgStreamingConverter +{ + public override Stream Read(PgReader reader) + => throw new NotSupportedException("Handled by generic stream support in NpgsqlDataReader"); + + public override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + => throw new NotSupportedException("Handled by generic stream support in NpgsqlDataReader"); + + public override Size GetSize(SizeContext context, Stream value, ref object? 
writeState) + { + var memoryStream = new MemoryStream(value.CanSeek ? (int)(value.Length - value.Position) : 0); + value.CopyTo(memoryStream); + writeState = memoryStream; + return checked((int)memoryStream.Length); + } + + public override void Write(PgWriter writer, Stream value) + { + if (!((MemoryStream)writer.Current.WriteState!).TryGetBuffer(out var segment)) + throw new InvalidOperationException(); + writer.WriteBytes(segment.AsSpan()); + } + + public override ValueTask WriteAsync(PgWriter writer, Stream value, CancellationToken cancellationToken = default) + { + if (!((MemoryStream)writer.Current.WriteState!).TryGetBuffer(out var segment)) + throw new InvalidOperationException(); + + return writer.WriteBytesAsync(segment.AsMemory(), cancellationToken); + } +} diff --git a/src/Npgsql/Internal/Converters/Primitive/DoubleConverter.cs b/src/Npgsql/Internal/Converters/Primitive/DoubleConverter.cs new file mode 100644 index 0000000000..74a56d06ae --- /dev/null +++ b/src/Npgsql/Internal/Converters/Primitive/DoubleConverter.cs @@ -0,0 +1,43 @@ +using System; +using System.Numerics; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class DoubleConverter : PgBufferedConverter +#if NET7_0_OR_GREATER + where T : INumberBase +#endif +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(double)); + return format is DataFormat.Binary; + } + +#if NET7_0_OR_GREATER + protected override T ReadCore(PgReader reader) => T.CreateChecked(reader.ReadDouble()); + protected override void WriteCore(PgWriter writer, T value) => writer.WriteDouble(double.CreateChecked(value)); +#else + protected override T ReadCore(PgReader reader) + { + var value = reader.ReadDouble(); + if (typeof(float) == typeof(T)) + return (T)(object)value; + if (typeof(double) == typeof(T)) + return (T)(object)value; + + throw new NotSupportedException(); + } 
+ + protected override void WriteCore(PgWriter writer, T value) + { + if (typeof(float) == typeof(T)) + writer.WriteDouble((float)(object)value!); + else if (typeof(double) == typeof(T)) + writer.WriteDouble((double)(object)value!); + else + throw new NotSupportedException(); + } +#endif +} diff --git a/src/Npgsql/Internal/Converters/Primitive/GuidUuidConverter.cs b/src/Npgsql/Internal/Converters/Primitive/GuidUuidConverter.cs new file mode 100644 index 0000000000..596deedfce --- /dev/null +++ b/src/Npgsql/Internal/Converters/Primitive/GuidUuidConverter.cs @@ -0,0 +1,70 @@ +using System; +using System.Buffers.Binary; +using System.Runtime.InteropServices; + +namespace Npgsql.Internal.Converters; + +sealed class GuidUuidConverter : PgBufferedConverter +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(16 * sizeof(byte)); + return format is DataFormat.Binary; + } + protected override Guid ReadCore(PgReader reader) + { +#if NET8_0_OR_GREATER + return new Guid(reader.ReadBytes(16).FirstSpan, bigEndian: true); +#else + return new GuidRaw + { + Data1 = reader.ReadInt32(), + Data2 = reader.ReadInt16(), + Data3 = reader.ReadInt16(), + Data4 = BitConverter.IsLittleEndian ? BinaryPrimitives.ReverseEndianness(reader.ReadInt64()) : reader.ReadInt64() + }.Value; +#endif + } + + protected override void WriteCore(PgWriter writer, Guid value) + { +#if NET8_0_OR_GREATER + Span bytes = stackalloc byte[16]; + value.TryWriteBytes(bytes, bigEndian: true, out _); + writer.WriteBytes(bytes); +#else + var raw = new GuidRaw(value); + + writer.WriteInt32(raw.Data1); + writer.WriteInt16(raw.Data2); + writer.WriteInt16(raw.Data3); + writer.WriteInt64(BitConverter.IsLittleEndian ? BinaryPrimitives.ReverseEndianness(raw.Data4) : raw.Data4); +#endif + } + +#if !NET8_0_OR_GREATER + // The following table shows .NET GUID vs Postgres UUID (RFC 4122) layouts. 
+ // + // Note that the first fields are converted from/to native endianness (handled by the Read* + // and Write* methods), while the last field is always read/written in big-endian format. + // + // We're reverting endianness on little endian systems to get it into big endian format. + // + // | Bits | Bytes | Name | Endianness (GUID) | Endianness (RFC 4122) | + // | ---- | ----- | ----- | ----------------- | --------------------- | + // | 32 | 4 | Data1 | Native | Big | + // | 16 | 2 | Data2 | Native | Big | + // | 16 | 2 | Data3 | Native | Big | + // | 64 | 8 | Data4 | Big | Big | + [StructLayout(LayoutKind.Explicit)] + struct GuidRaw + { + [FieldOffset(0)] public Guid Value; + [FieldOffset(0)] public int Data1; + [FieldOffset(4)] public short Data2; + [FieldOffset(6)] public short Data3; + [FieldOffset(8)] public long Data4; + public GuidRaw(Guid value) : this() => Value = value; + } +#endif +} diff --git a/src/Npgsql/Internal/Converters/Primitive/Int2Converter.cs b/src/Npgsql/Internal/Converters/Primitive/Int2Converter.cs new file mode 100644 index 0000000000..e54658d925 --- /dev/null +++ b/src/Npgsql/Internal/Converters/Primitive/Int2Converter.cs @@ -0,0 +1,70 @@ +using System; +using System.Numerics; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class Int2Converter : PgBufferedConverter +#if NET7_0_OR_GREATER + where T : INumberBase +#endif +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(short)); + return format is DataFormat.Binary; + } +#if NET7_0_OR_GREATER + protected override T ReadCore(PgReader reader) => T.CreateChecked(reader.ReadInt16()); + protected override void WriteCore(PgWriter writer, T value) => writer.WriteInt16(short.CreateChecked(value)); +#else + protected override T ReadCore(PgReader reader) + { + var value = reader.ReadInt16(); + if (typeof(short) == typeof(T)) + return 
(T)(object)value; + if (typeof(int) == typeof(T)) + return (T)(object)(int)value; + if (typeof(long) == typeof(T)) + return (T)(object)(long)value; + + if (typeof(byte) == typeof(T)) + return (T)(object)checked((byte)value); + if (typeof(sbyte) == typeof(T)) + return (T)(object)checked((sbyte)value); + + if (typeof(float) == typeof(T)) + return (T)(object)(float)value; + if (typeof(double) == typeof(T)) + return (T)(object)(double)value; + if (typeof(decimal) == typeof(T)) + return (T)(object)(decimal)value; + + throw new NotSupportedException(); + } + + protected override void WriteCore(PgWriter writer, T value) + { + if (typeof(short) == typeof(T)) + writer.WriteInt16((short)(object)value!); + else if (typeof(int) == typeof(T)) + writer.WriteInt16(checked((short)(int)(object)value!)); + else if (typeof(long) == typeof(T)) + writer.WriteInt16(checked((short)(long)(object)value!)); + + else if (typeof(byte) == typeof(T)) + writer.WriteInt16((byte)(object)value!); + else if (typeof(sbyte) == typeof(T)) + writer.WriteInt16((sbyte)(object)value!); + + else if (typeof(float) == typeof(T)) + writer.WriteInt16(checked((short)(float)(object)value!)); + else if (typeof(double) == typeof(T)) + writer.WriteInt16(checked((short)(double)(object)value!)); + else if (typeof(decimal) == typeof(T)) + writer.WriteInt16((short)(decimal)(object)value!); + else + throw new NotSupportedException(); + } +#endif +} diff --git a/src/Npgsql/Internal/Converters/Primitive/Int4Converter.cs b/src/Npgsql/Internal/Converters/Primitive/Int4Converter.cs new file mode 100644 index 0000000000..1831ca9b1e --- /dev/null +++ b/src/Npgsql/Internal/Converters/Primitive/Int4Converter.cs @@ -0,0 +1,71 @@ +using System; +using System.Numerics; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class Int4Converter : PgBufferedConverter +#if NET7_0_OR_GREATER + where T : INumberBase +#endif +{ + public override bool CanConvert(DataFormat format, out BufferRequirements 
bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(int)); + return format is DataFormat.Binary; + } + +#if NET7_0_OR_GREATER + protected override T ReadCore(PgReader reader) => T.CreateChecked(reader.ReadInt32()); + protected override void WriteCore(PgWriter writer, T value) => writer.WriteInt32(int.CreateChecked(value)); +#else + protected override T ReadCore(PgReader reader) + { + var value = reader.ReadInt32(); + if (typeof(short) == typeof(T)) + return (T)(object)checked((short)value); + if (typeof(int) == typeof(T)) + return (T)(object)value; + if (typeof(long) == typeof(T)) + return (T)(object)(long)value; + + if (typeof(byte) == typeof(T)) + return (T)(object)checked((byte)value); + if (typeof(sbyte) == typeof(T)) + return (T)(object)checked((sbyte)value); + + if (typeof(float) == typeof(T)) + return (T)(object)(float)value; + if (typeof(double) == typeof(T)) + return (T)(object)(double)value; + if (typeof(decimal) == typeof(T)) + return (T)(object)(decimal)value; + + throw new NotSupportedException(); + } + + protected override void WriteCore(PgWriter writer, T value) + { + if (typeof(short) == typeof(T)) + writer.WriteInt32((short)(object)value!); + else if (typeof(int) == typeof(T)) + writer.WriteInt32((int)(object)value!); + else if (typeof(long) == typeof(T)) + writer.WriteInt32(checked((int)(long)(object)value!)); + + else if (typeof(byte) == typeof(T)) + writer.WriteInt32((byte)(object)value!); + else if (typeof(sbyte) == typeof(T)) + writer.WriteInt32((sbyte)(object)value!); + + else if (typeof(float) == typeof(T)) + writer.WriteInt32(checked((int)(float)(object)value!)); + else if (typeof(double) == typeof(T)) + writer.WriteInt32(checked((int)(double)(object)value!)); + else if (typeof(decimal) == typeof(T)) + writer.WriteInt32((int)(decimal)(object)value!); + else + throw new NotSupportedException(); + } +#endif +} diff --git a/src/Npgsql/Internal/Converters/Primitive/Int8Converter.cs 
b/src/Npgsql/Internal/Converters/Primitive/Int8Converter.cs new file mode 100644 index 0000000000..b422816244 --- /dev/null +++ b/src/Npgsql/Internal/Converters/Primitive/Int8Converter.cs @@ -0,0 +1,72 @@ +using System; +using System.Numerics; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class Int8Converter : PgBufferedConverter +#if NET7_0_OR_GREATER + where T : INumberBase +#endif +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long)); + return format is DataFormat.Binary; + } + +#if NET7_0_OR_GREATER + protected override T ReadCore(PgReader reader) => T.CreateChecked(reader.ReadInt64()); + protected override void WriteCore(PgWriter writer, T value) => writer.WriteInt64(long.CreateChecked(value)); +#else + protected override T ReadCore(PgReader reader) + { + var value = reader.ReadInt64(); + if (typeof(long) == typeof(T)) + return (T)(object)value; + + if (typeof(short) == typeof(T)) + return (T)(object)checked((short)value); + if (typeof(int) == typeof(T)) + return (T)(object)checked((int)value); + + if (typeof(byte) == typeof(T)) + return (T)(object)checked((byte)value); + if (typeof(sbyte) == typeof(T)) + return (T)(object)checked((sbyte)value); + + if (typeof(float) == typeof(T)) + return (T)(object)(float)value; + if (typeof(double) == typeof(T)) + return (T)(object)(double)value; + if (typeof(decimal) == typeof(T)) + return (T)(object)(decimal)value; + + throw new NotSupportedException(); + } + + protected override void WriteCore(PgWriter writer, T value) + { + if (typeof(short) == typeof(T)) + writer.WriteInt64((short)(object)value!); + else if (typeof(int) == typeof(T)) + writer.WriteInt64((int)(object)value!); + else if (typeof(long) == typeof(T)) + writer.WriteInt64((long)(object)value!); + + else if (typeof(byte) == typeof(T)) + writer.WriteInt64((byte)(object)value!); + else if 
(typeof(sbyte) == typeof(T)) + writer.WriteInt64((sbyte)(object)value!); + + else if (typeof(float) == typeof(T)) + writer.WriteInt64(checked((long)(float)(object)value!)); + else if (typeof(double) == typeof(T)) + writer.WriteInt64(checked((long)(double)(object)value!)); + else if (typeof(decimal) == typeof(T)) + writer.WriteInt64((long)(decimal)(object)value!); + else + throw new NotSupportedException(); + } +#endif +} diff --git a/src/Npgsql/Internal/Converters/Primitive/NumericConverters.cs b/src/Npgsql/Internal/Converters/Primitive/NumericConverters.cs new file mode 100644 index 0000000000..c43e90a1f7 --- /dev/null +++ b/src/Npgsql/Internal/Converters/Primitive/NumericConverters.cs @@ -0,0 +1,262 @@ +using System; +using System.Buffers; +using System.Numerics; +using System.Threading; +using System.Threading.Tasks; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class BigIntegerNumericConverter : PgStreamingConverter +{ + const int StackAllocByteThreshold = 64 * sizeof(uint); + + public override BigInteger Read(PgReader reader) + { + var digitCount = reader.ReadInt16(); + short[]? digitsFromPool = null; + var digits = (digitCount <= StackAllocByteThreshold / sizeof(short) + ? stackalloc short[StackAllocByteThreshold / sizeof(short)] + : (digitsFromPool = ArrayPool.Shared.Rent(digitCount)).AsSpan()).Slice(0, digitCount); + + var value = ConvertTo(NumericConverter.Read(reader, digits)); + + if (digitsFromPool is not null) + ArrayPool.Shared.Return(digitsFromPool); + + return value; + } + + public override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + { + // If we don't need a read and can read buffered we delegate to our sync read method which won't do IO in such a case. 
+ if (!reader.ShouldBuffer(reader.CurrentRemaining)) + Read(reader); + + return AsyncCore(reader, cancellationToken); + + static async ValueTask AsyncCore(PgReader reader, CancellationToken cancellationToken) + { + await reader.BufferAsync(PgNumeric.GetByteCount(0), cancellationToken).ConfigureAwait(false); + var digitCount = reader.ReadInt16(); + var digits = new ArraySegment(ArrayPool.Shared.Rent(digitCount), 0, digitCount); + var value = ConvertTo(await NumericConverter.ReadAsync(reader, digits, cancellationToken).ConfigureAwait(false)); + + ArrayPool.Shared.Return(digits.Array!); + + return value; + } + } + + public override Size GetSize(SizeContext context, BigInteger value, ref object? writeState) => + PgNumeric.GetByteCount(PgNumeric.GetDigitCount(value)); + + public override void Write(PgWriter writer, BigInteger value) + { + // We don't know how many digits we need so we allocate a decent chunk of stack for the builder to use. + // If it's not enough for the builder will do a heap allocation (for decimal it's always enough). + Span destination = stackalloc short[StackAllocByteThreshold / sizeof(short)]; + var numeric = ConvertFrom(value, destination); + NumericConverter.Write(writer, numeric); + } + + public override ValueTask WriteAsync(PgWriter writer, BigInteger value, CancellationToken cancellationToken = default) + { + if (writer.ShouldFlush(writer.Current.Size)) + return AsyncCore(writer, value, cancellationToken); + + // If we don't need a flush and can write buffered we delegate to our sync write method which won't flush in such a case. 
+ Write(writer, value); + return new(); + + static async ValueTask AsyncCore(PgWriter writer, BigInteger value, CancellationToken cancellationToken) + { + await writer.FlushAsync(cancellationToken).ConfigureAwait(false); + var numeric = ConvertFrom(value, Array.Empty()).Build(); + await NumericConverter.WriteAsync(writer, numeric, cancellationToken).ConfigureAwait(false); + } + } + + static PgNumeric.Builder ConvertFrom(BigInteger value, Span destination) => new(value, destination); + static BigInteger ConvertTo(in PgNumeric.Builder numeric) => numeric.ToBigInteger(); + static BigInteger ConvertTo(in PgNumeric numeric) => numeric.ToBigInteger(); +} + +sealed class DecimalNumericConverter : PgBufferedConverter +#if NET7_0_OR_GREATER + where T : INumberBase +#else + where T : notnull +#endif +{ + const int StackAllocByteThreshold = 64 * sizeof(uint); + + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + // This upper bound would already cause an overflow exception in the builder, no need to do + 1. + bufferRequirements = BufferRequirements.Create(Size.CreateUpperBound(NumericConverter.DecimalBasedMaxByteCount)); + return format is DataFormat.Binary; + } + + protected override T ReadCore(PgReader reader) + { + var digitCount = reader.ReadInt16(); + var digits = stackalloc short[StackAllocByteThreshold / sizeof(short)].Slice(0, digitCount);; + var value = ConvertTo(NumericConverter.Read(reader, digits)); + return value; + } + + public override Size GetSize(SizeContext context, T value, ref object? 
writeState) => + PgNumeric.GetByteCount(default(T) switch + { + _ when typeof(decimal) == typeof(T) => PgNumeric.GetDigitCount((decimal)(object)value), + _ when typeof(short) == typeof(T) => PgNumeric.GetDigitCount((decimal)(short)(object)value), + _ when typeof(int) == typeof(T) => PgNumeric.GetDigitCount((decimal)(int)(object)value), + _ when typeof(long) == typeof(T) => PgNumeric.GetDigitCount((decimal)(long)(object)value), + _ when typeof(byte) == typeof(T) => PgNumeric.GetDigitCount((decimal)(byte)(object)value), + _ when typeof(sbyte) == typeof(T) => PgNumeric.GetDigitCount((decimal)(sbyte)(object)value), + _ when typeof(float) == typeof(T) => PgNumeric.GetDigitCount((decimal)(float)(object)value), + _ when typeof(double) == typeof(T) => PgNumeric.GetDigitCount((decimal)(double)(object)value), + _ => throw new NotSupportedException() + }); + + protected override void WriteCore(PgWriter writer, T value) + { + // We don't know how many digits we need so we allocate enough for the builder to use. 
+ Span destination = stackalloc short[PgNumeric.Builder.MaxDecimalNumericDigits]; + var numeric = ConvertFrom(value, destination); + NumericConverter.Write(writer, numeric); + } + + static PgNumeric.Builder ConvertFrom(T value, Span destination) + { +#if !NET7_0_OR_GREATER + if (typeof(short) == typeof(T)) + return new PgNumeric.Builder((decimal)(short)(object)value!, destination); + if (typeof(int) == typeof(T)) + return new PgNumeric.Builder((decimal)(int)(object)value!, destination); + if (typeof(long) == typeof(T)) + return new PgNumeric.Builder((decimal)(long)(object)value!, destination); + + if (typeof(byte) == typeof(T)) + return new PgNumeric.Builder((decimal)(byte)(object)value!, destination); + if (typeof(sbyte) == typeof(T)) + return new PgNumeric.Builder((decimal)(sbyte)(object)value!, destination); + + if (typeof(float) == typeof(T)) + return new PgNumeric.Builder((decimal)(float)(object)value!, destination); + if (typeof(double) == typeof(T)) + return new PgNumeric.Builder((decimal)(double)(object)value!, destination); + if (typeof(decimal) == typeof(T)) + return new PgNumeric.Builder((decimal)(object)value!, destination); + + throw new NotSupportedException(); +#else + return new PgNumeric.Builder(decimal.CreateChecked(value), destination); +#endif + } + + static T ConvertTo(in PgNumeric.Builder numeric) + { +#if !NET7_0_OR_GREATER + if (typeof(short) == typeof(T)) + return (T)(object)(short)numeric.ToDecimal(); + if (typeof(int) == typeof(T)) + return (T)(object)(int)numeric.ToDecimal(); + if (typeof(long) == typeof(T)) + return (T)(object)(long)numeric.ToDecimal(); + + if (typeof(byte) == typeof(T)) + return (T)(object)(byte)numeric.ToDecimal(); + if (typeof(sbyte) == typeof(T)) + return (T)(object)(sbyte)numeric.ToDecimal(); + + if (typeof(float) == typeof(T)) + return (T)(object)(float)numeric.ToDecimal(); + if (typeof(double) == typeof(T)) + return (T)(object)(double)numeric.ToDecimal(); + if (typeof(decimal) == typeof(T)) + return 
(T)(object)numeric.ToDecimal(); + + throw new NotSupportedException(); +#else + return T.CreateChecked(numeric.ToDecimal()); +#endif + } +} + +static class NumericConverter +{ + public static int DecimalBasedMaxByteCount = PgNumeric.GetByteCount(PgNumeric.Builder.MaxDecimalNumericDigits); + + public static PgNumeric.Builder Read(PgReader reader, Span digits) + { + var remainingStructureSize = PgNumeric.GetByteCount(0) - sizeof(short); + if (reader.ShouldBuffer(remainingStructureSize)) + reader.Buffer(remainingStructureSize); + var weight = reader.ReadInt16(); + var sign = reader.ReadInt16(); + var scale = reader.ReadInt16(); + foreach (ref var digit in digits) + { + if (reader.ShouldBuffer(sizeof(short))) + reader.Buffer(sizeof(short)); + digit = reader.ReadInt16(); + } + + return new PgNumeric.Builder(digits, weight, sign, scale); + } + + public static async ValueTask ReadAsync(PgReader reader, ArraySegment digits, CancellationToken cancellationToken) + { + var remainingStructureSize = PgNumeric.GetByteCount(0) - sizeof(short); + if (reader.ShouldBuffer(remainingStructureSize)) + await reader.BufferAsync(remainingStructureSize, cancellationToken).ConfigureAwait(false); + var weight = reader.ReadInt16(); + var sign = reader.ReadInt16(); + var scale = reader.ReadInt16(); + var array = digits.Array!; + for (var i = digits.Offset; i < array.Length; i++) + { + if (reader.ShouldBuffer(sizeof(short))) + await reader.BufferAsync(sizeof(short), cancellationToken).ConfigureAwait(false); + array[i] = reader.ReadInt16(); + } + + return new PgNumeric.Builder(digits, weight, sign, scale).Build(); + } + + public static void Write(PgWriter writer, PgNumeric.Builder numeric) + { + if (writer.ShouldFlush(PgNumeric.GetByteCount(0))) + writer.Flush(); + writer.WriteInt16((short)numeric.Digits.Length); + writer.WriteInt16(numeric.Weight); + writer.WriteInt16(numeric.Sign); + writer.WriteInt16(numeric.Scale); + + foreach (var digit in numeric.Digits) + { + if 
(writer.ShouldFlush(sizeof(short))) + writer.Flush(); + writer.WriteInt16(digit); + } + } + + public static async ValueTask WriteAsync(PgWriter writer, PgNumeric numeric, CancellationToken cancellationToken) + { + if (writer.ShouldFlush(PgNumeric.GetByteCount(0))) + await writer.FlushAsync(cancellationToken).ConfigureAwait(false); + writer.WriteInt16((short)numeric.Digits.Count); + writer.WriteInt16(numeric.Weight); + writer.WriteInt16(numeric.Sign); + writer.WriteInt16(numeric.Scale); + + foreach (var digit in numeric.Digits) + { + if (writer.ShouldFlush(sizeof(short))) + await writer.FlushAsync(cancellationToken).ConfigureAwait(false); + writer.WriteInt16(digit); + } + } +} diff --git a/src/Npgsql/Internal/Converters/Primitive/PgMoney.cs b/src/Npgsql/Internal/Converters/Primitive/PgMoney.cs new file mode 100644 index 0000000000..495e2a8aba --- /dev/null +++ b/src/Npgsql/Internal/Converters/Primitive/PgMoney.cs @@ -0,0 +1,104 @@ +using System; +using System.Diagnostics; +using System.Runtime.InteropServices; + +namespace Npgsql.Internal.Converters; + +readonly struct PgMoney +{ + const int DecimalBits = 4; + const int MoneyScale = 2; + readonly long _value; + + public PgMoney(long value) => _value = value; + + public PgMoney(decimal value) + { + if (value is < -92233720368547758.08M or > 92233720368547758.07M) + throw new OverflowException($"The supplied value '{value}' is outside the range for a PostgreSQL money value."); + + // No-op if scale was already 2 or less. + value = decimal.Round(value, MoneyScale, MidpointRounding.AwayFromZero); + + Span bits = stackalloc uint[DecimalBits]; + GetDecimalBits(value, bits, out var scale); + + var money = (long)bits[1] << 32 | bits[0]; + if (value < 0) + money = -money; + + // If we were less than scale 2, multiply. 
+ _value = (MoneyScale - scale) switch + { + 1 => money * 10, + 2 => money * 100, + _ => money + }; + } + + public long GetValue() => _value; + + public decimal ToDecimal() + { + var result = new decimal(_value); + var scaleFactor = new decimal(1, 0, 0, false, MoneyScale); + result *= scaleFactor; + return result; + } + + static void GetDecimalBits(decimal value, Span destination, out short scale) + { + Debug.Assert(destination.Length >= DecimalBits); + +#if NETSTANDARD + var raw = new DecimalRaw(value); + destination[0] = raw.Low; + destination[1] = raw.Mid; + destination[2] = raw.High; + destination[3] = (uint)raw.Flags; + scale = raw.Scale; +#else + decimal.GetBits(value, MemoryMarshal.Cast(destination)); +#endif +#if NET7_0_OR_GREATER + scale = value.Scale; +#else + scale = (byte)(destination[3] >> 16); +#endif + } + +#if NETSTANDARD + // Zero-alloc access to the decimal bits on netstandard. + [StructLayout(LayoutKind.Explicit)] + readonly struct DecimalRaw + { + const int ScaleMask = 0x00FF0000; + const int ScaleShift = 16; + + // Do not change the order in which these fields are declared. It + // should be same as in the System.Decimal.DecCalc struct. + [FieldOffset(0)] + readonly decimal _value; + [FieldOffset(0)] + readonly int _flags; + [FieldOffset(4)] + readonly uint _high; + [FieldOffset(8)] + readonly ulong _low64; + + // Convenience aliased fields but their usage needs to take endianness into account. + [FieldOffset(8)] + readonly uint _low; + [FieldOffset(12)] + readonly uint _mid; + + public DecimalRaw(decimal value) : this() => _value = value; + + public uint High => _high; + public uint Mid => BitConverter.IsLittleEndian ? _mid : _low; + public uint Low => BitConverter.IsLittleEndian ? 
_low : _mid; + public int Flags => _flags; + public short Scale => (short)((_flags & ScaleMask) >> ScaleShift); + } +#endif +} diff --git a/src/Npgsql/Internal/Converters/Primitive/PgNumeric.cs b/src/Npgsql/Internal/Converters/Primitive/PgNumeric.cs new file mode 100644 index 0000000000..fad0fd50a9 --- /dev/null +++ b/src/Npgsql/Internal/Converters/Primitive/PgNumeric.cs @@ -0,0 +1,462 @@ +using System; +using System.Buffers; +using System.Buffers.Binary; +using System.Diagnostics; +using System.Numerics; +using System.Runtime.InteropServices; +using static Npgsql.Internal.Converters.PgNumeric.Builder; + +namespace Npgsql.Internal.Converters; + +readonly struct PgNumeric +{ + // numeric digit count + weight + sign + scale + const int StructureByteCount = 4 * sizeof(short); + const int DecimalBits = 4; + const int StackAllocByteThreshold = 64 * sizeof(uint); + + readonly ushort _sign; + + public PgNumeric(ArraySegment digits, short weight, short sign, short scale) + { + Digits = digits; + Weight = weight; + _sign = (ushort)sign; + Scale = scale; + } + + /// Big endian array of numeric digits + public ArraySegment Digits { get; } + public short Weight { get; } + public short Sign => (short)_sign; + public short Scale { get; } + + public int GetByteCount() => GetByteCount(Digits.Count); + public static int GetByteCount(int digitCount) => StructureByteCount + digitCount * sizeof(short); + + static void GetDecimalBits(decimal value, Span destination, out short scale) + { + Debug.Assert(destination.Length >= DecimalBits); + +#if NETSTANDARD + var raw = new DecimalRaw(value); + destination[0] = raw.Low; + destination[1] = raw.Mid; + destination[2] = raw.High; + destination[3] = (uint)raw.Flags; + scale = raw.Scale; +#else + decimal.GetBits(value, MemoryMarshal.Cast(destination)); +#endif +#if NET7_0_OR_GREATER + scale = value.Scale; +#else + scale = (byte)(destination[3] >> 16); +#endif + } + + public static int GetDigitCount(decimal value) + { + Span bits = stackalloc 
uint[DecimalBits];
        GetDecimalBits(value, bits, out var scale);
        bits = bits.Slice(0, DecimalBits - 1);
        return GetDigitCountCore(bits, scale);
    }

    public static int GetDigitCount(BigInteger value)
    {
#if NETSTANDARD2_0
        var bits = value.ToByteArray().AsSpan();
        // Detect the presence of a padding byte and slice it away (as we don't have isUnsigned: true overloads on ns2.0).
        if (value.Sign == 1 && bits.Length > 2 && (bits[bits.Length - 2] & 0x80) != 0 && bits[bits.Length - 1] == 0)
            bits = bits.Slice(0, bits.Length - 1);
        var uintRoundedByteCount = (bits.Length + (sizeof(uint) - 1)) / sizeof(uint) * sizeof(uint);
#else
        var absValue = BigInteger.Abs(value); // isUnsigned: true fails for negative values.
        var uintRoundedByteCount = (absValue.GetByteCount(isUnsigned: true) + (sizeof(uint) - 1)) / sizeof(uint) * sizeof(uint);
#endif
        byte[]? uintRoundedBitsFromPool = null;
        var uintRoundedBits = (uintRoundedByteCount <= StackAllocByteThreshold
            ? stackalloc byte[StackAllocByteThreshold]
            : uintRoundedBitsFromPool = ArrayPool<byte>.Shared.Rent(uintRoundedByteCount)
            ).Slice(0, uintRoundedByteCount);
        // Fill the last uint worth of bytes as it may only be partially written to.
        uintRoundedBits.Slice(uintRoundedBits.Length - sizeof(uint)).Fill(0);

#if NETSTANDARD2_0
        bits.CopyTo(uintRoundedBits);
#else
        var success = absValue.TryWriteBytes(uintRoundedBits, out _, isUnsigned: true);
        Debug.Assert(success);
#endif
        var uintBits = MemoryMarshal.Cast<byte, uint>(uintRoundedBits);
        if (!BitConverter.IsLittleEndian)
            for (var i = 0; i < uintBits.Length; i++)
                uintBits[i] = BinaryPrimitives.ReverseEndianness(uintBits[i]);

        var size = GetDigitCountCore(uintBits, scale: 0);

        if (uintRoundedBitsFromPool is not null)
            ArrayPool<byte>.Shared.Return(uintRoundedBitsFromPool);

        return size;
    }

    public decimal ToDecimal() => Builder.ToDecimal(Scale, Weight, _sign, Digits);
    public BigInteger ToBigInteger() => Builder.ToBigInteger(Weight, _sign, Digits);

    /// <summary>Stack-friendly builder that converts between CLR numbers and base-10000 numeric digits.</summary>
    public readonly ref struct Builder
    {
        const ushort SignPositive = 0x0000;
        const ushort SignNegative = 0x4000;
        const ushort SignNan = 0xC000;
        const ushort SignPinf = 0xD000;
        const ushort SignNinf = 0xF000;

        const uint NumericBase = 10000;
        const int NumericBaseLog10 = 4; // log10(10000)

        internal const int MaxDecimalNumericDigits = 8;

        // Fast access for 10^n where n is 0-9
        static ReadOnlySpan<uint> UIntPowers10 => new uint[] {
            1,
            10,
            100,
            1000,
            10000,
            100000,
            1000000,
            10000000,
            100000000,
            1000000000
        };

        const int MaxUInt32Scale = 9;
        const int MaxUInt16Scale = 4;

        public short Weight { get; }

        readonly ushort _sign;
        public short Sign => (short)_sign;

        public short Scale { get; }
        public Span<short> Digits { get; }
        readonly short[]? _digitsArray;

        public Builder(Span<short> digits, short weight, short sign, short scale)
        {
            Digits = digits;
            Weight = weight;
            _sign = (ushort)sign;
            Scale = scale;
        }

        public Builder(short[] digits, short weight, short sign, short scale)
        {
            Digits = _digitsArray = digits;
            Weight = weight;
            _sign = (ushort)sign;
            Scale = scale;
        }

        [Conditional("DEBUG")]
        static void AssertInvariants()
        {
            Debug.Assert(UIntPowers10.Length >= NumericBaseLog10);
            Debug.Assert(NumericBase < short.MaxValue);
        }

        static void Create(ref short[]? digitsArray, ref Span<short> destination, scoped Span<uint> bits, short scale, out short weight, out int digitCount)
        {
            AssertInvariants();
            digitCount = 0;
            var digitWeight = -scale / NumericBaseLog10 - 1;

            var bitsUpperBound = (bits.Length * (MaxUInt32Scale + 1) + MaxUInt16Scale - 1) / MaxUInt16Scale + 1;
            if (bitsUpperBound > destination.Length)
                destination = digitsArray = new short[bitsUpperBound];

            // When the given scale does not sit on a numeric digit boundary we divide once by the remainder power of 10 instead of the base.
            // As a result the quotient is aligned to a digit boundary, we must then scale up the remainder by the missed power of 10 to compensate.
            var scaleRemainder = scale % NumericBaseLog10;
            if (scaleRemainder > 0 && DivideInPlace(bits, UIntPowers10[scaleRemainder], out var remainder) && remainder != 0)
            {
                remainder *= UIntPowers10[NumericBaseLog10 - scaleRemainder];
                digitWeight--;
                destination[destination.Length - 1 - digitCount++] = (short)remainder;
            }
            while (DivideInPlace(bits, NumericBase, out remainder))
            {
                // Initial zero remainders are skipped as these present trailing zero digits, which should not be stored.
                if (digitCount == 0 && remainder == 0)
                    digitWeight++;
                else
                    // We store the results starting from the end so the final digits end up in big endian.
                    destination[destination.Length - 1 - digitCount++] = (short)remainder;
            }

            weight = (short)(digitWeight + digitCount);
        }

        public Builder(decimal value, Span<short> destination)
        {
            Span<uint> bits = stackalloc uint[DecimalBits];
            GetDecimalBits(value, bits, out var scale);
            bits = bits.Slice(0, DecimalBits - 1);

            Create(ref _digitsArray, ref destination, bits, scale, out var weight, out var digitCount);
            Digits = destination.Slice(destination.Length - digitCount);
            Weight = weight;
            _sign = value < 0 ? SignNegative : SignPositive;
            Scale = scale;
        }

        /// <summary>Builds numeric digits for a BigInteger.</summary>
        /// <remarks>If the destination ends up being too small the builder allocates instead.</remarks>
        public Builder(BigInteger value, Span<short> destination)
        {
#if NETSTANDARD2_0
            var bits = value.ToByteArray().AsSpan();
            // Detect the presence of a padding byte and slice it away (as we don't have isUnsigned: true overloads on ns2.0).
            if (value.Sign == 1 && bits.Length > 2 && (bits[bits.Length - 2] & 0x80) != 0 && bits[bits.Length - 1] == 0)
                bits = bits.Slice(0, bits.Length - 1);
            var uintRoundedByteCount = (bits.Length + (sizeof(uint) - 1)) / sizeof(uint) * sizeof(uint);
#else
            var absValue = BigInteger.Abs(value); // isUnsigned: true fails for negative values.
            var uintRoundedByteCount = (absValue.GetByteCount(isUnsigned: true) + (sizeof(uint) - 1)) / sizeof(uint) * sizeof(uint);
#endif
            byte[]? uintRoundedBitsFromPool = null;
            var uintRoundedBits = (uintRoundedByteCount <= StackAllocByteThreshold
                ? stackalloc byte[StackAllocByteThreshold]
                : uintRoundedBitsFromPool = ArrayPool<byte>.Shared.Rent(uintRoundedByteCount)
                ).Slice(0, uintRoundedByteCount);
            // Fill the last uint worth of bytes as it may only be partially written to.
            uintRoundedBits.Slice(uintRoundedBits.Length - sizeof(uint)).Fill(0);

#if NETSTANDARD2_0
            bits.CopyTo(uintRoundedBits);
#else
            var success = absValue.TryWriteBytes(uintRoundedBits, out _, isUnsigned: true);
            Debug.Assert(success);
#endif
            var uintBits = MemoryMarshal.Cast<byte, uint>(uintRoundedBits);

            // Our calculations are all done in little endian, meaning the least significant *uint* is first, just like in BigInteger.
            // The bytes comprising every individual uint should still be converted to big endian though.
            // As a result an array of bytes like [ 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8 ] should become [ 0x4, 0x3, 0x2, 0x1, 0x8, 0x7, 0x6, 0x5 ].
            if (!BitConverter.IsLittleEndian)
                for (var i = 0; i < uintBits.Length; i++)
                    uintBits[i] = BinaryPrimitives.ReverseEndianness(uintBits[i]);

            Create(ref _digitsArray, ref destination, uintBits, scale: 0, out var weight, out var digitCount);
            Digits = destination.Slice(destination.Length - digitCount);
            Weight = weight;
            _sign = value < 0 ? SignNegative : SignPositive;
            Scale = 0;

            if (uintRoundedBitsFromPool is not null)
                ArrayPool<byte>.Shared.Return(uintRoundedBitsFromPool);
        }

        public PgNumeric Build()
        {
            var digitsArray = _digitsArray is not null
                ? new ArraySegment<short>(_digitsArray, _digitsArray.Length - Digits.Length, Digits.Length)
                : new ArraySegment<short>(Digits.ToArray());

            return new(digitsArray, Weight, Sign, Scale);
        }

        public decimal ToDecimal() => ToDecimal(Scale, Weight, _sign, Digits);
        public BigInteger ToBigInteger() => ToBigInteger(Weight, _sign, Digits);

        int DigitCount => Digits.Length;

        /// <returns>Whether the input consists of any non zero bits</returns>
        static bool DivideInPlace(Span<uint> left, uint right, out uint remainder)
            => Divide(left, right, left, out remainder);

        /// <summary>Adapted from BigInteger, to allow us to operate directly on stack allocated bits</summary>
        static bool Divide(ReadOnlySpan<uint> left, uint right, Span<uint> quotient, out uint remainder)
        {
            Debug.Assert(quotient.Length == left.Length);

            // Executes the division for one big and one 32-bit integer.
            // Thus, we've similar code than below, but there is no loop for
            // processing the 32-bit integer, since it's a single element.

            var carry = 0UL;

            var nonZeroInput = false;
            for (var i = left.Length - 1; i >= 0; i--)
            {
                var value = (carry << 32) | left[i];
                nonZeroInput = nonZeroInput || value != 0;
                var digit = value / right;
                quotient[i] = (uint)digit;
                carry = value - digit * right;
            }
            remainder = (uint)carry;

            return nonZeroInput;
        }

        internal static int GetDigitCountCore(Span<uint> bits, int scale)
        {
            AssertInvariants();
            // When a fractional result is expected we must send two numeric digits.
            // When the given scale does not sit on a numeric digit boundary-
            // we divide once by the remaining power of 10 instead of the full base to align things.
            var baseLogRemainder = scale % NumericBaseLog10;
            var den = baseLogRemainder > 0 ? UIntPowers10[baseLogRemainder] : NumericBase;
            var digits = 0;
            while (DivideInPlace(bits, den, out var remainder))
            {
                den = NumericBase;
                // Initial zero remainders are skipped as these present trailing zero digits, which should not be transmitted.
                if (digits != 0 || remainder != 0)
                    digits++;
            }

            return digits;
        }

        internal static decimal ToDecimal(short scale, short weight, ushort sign, Span<short> digits)
        {
            const int MaxUIntScale = 9;
            const int MaxDecimalScale = 28;

            var digitCount = digits.Length;
            if (digitCount > MaxDecimalNumericDigits)
                throw new OverflowException("Numeric value does not fit in a System.Decimal");

            if (Math.Abs(scale) > MaxDecimalScale)
                throw new OverflowException("Numeric value does not fit in a System.Decimal");

            if (digitCount == 0)
                return sign switch
                {
                    SignPositive or SignNegative => decimal.Zero,
                    SignNan => throw new InvalidCastException("Numeric NaN not supported by System.Decimal"),
                    SignPinf => throw new InvalidCastException("Numeric Infinity not supported by System.Decimal"),
                    SignNinf => throw new InvalidCastException("Numeric -Infinity not supported by System.Decimal"),
                    _ => throw new ArgumentOutOfRangeException()
                };

            var numericBase = new decimal(NumericBase);
            var result = decimal.Zero;
            for (var i = 0; i < digitCount - 1; i++)
            {
                result *= numericBase;
                result += digits[i];
            }

            var digitScale = (weight + 1 - digitCount) * NumericBaseLog10;
            var scaleDifference = scale < 0 ? digitScale : digitScale + scale;

            var digit = digits[digitCount - 1];
            if (digitCount == MaxDecimalNumericDigits)
            {
                // On the max group we adjust the base based on the scale difference, to prevent overflow for valid values.
                var pow = UIntPowers10[-scaleDifference];
                result *= numericBase / pow;
                result += new decimal(digit / pow);
            }
            else
            {
                result *= numericBase;
                result += digit;

                if (scaleDifference < 0)
                    result /= UIntPowers10[-scaleDifference];
                else
                    while (scaleDifference > 0)
                    {
                        var scaleChunk = Math.Min(MaxUIntScale, scaleDifference);
                        result *= UIntPowers10[scaleChunk];
                        scaleDifference -= scaleChunk;
                    }
            }

            var scaleFactor = new decimal(1, 0, 0, false, (byte)(scale > 0 ? scale : 0));
            result *= scaleFactor;
            return sign == SignNegative ? -result : result;
        }

        internal static BigInteger ToBigInteger(short weight, ushort sign, Span<short> digits)
        {
            var digitCount = digits.Length;
            if (digitCount == 0)
                return sign switch
                {
                    SignPositive or SignNegative => BigInteger.Zero,
                    SignNan => throw new InvalidCastException("Numeric NaN not supported by BigInteger"),
                    SignPinf => throw new InvalidCastException("Numeric Infinity not supported by BigInteger"),
                    SignNinf => throw new InvalidCastException("Numeric -Infinity not supported by BigInteger"),
                    _ => throw new ArgumentOutOfRangeException()
                };

            var digitWeight = weight + 1 - digitCount;
            if (digitWeight < 0)
                throw new InvalidCastException("Numeric value with non-zero fractional digits not supported by BigInteger");

            var numericBase = new BigInteger(NumericBase);
            var result = BigInteger.Zero;
            foreach (var digit in digits)
            {
                result *= numericBase;
                result += new BigInteger(digit);
            }

            var exponentCorrection = BigInteger.Pow(numericBase, digitWeight);
            result *= exponentCorrection;
            return sign == SignNegative ? -result : result;
        }
    }

#if NETSTANDARD
    // Zero-alloc access to the decimal bits on netstandard.
    [StructLayout(LayoutKind.Explicit)]
    readonly struct DecimalRaw
    {
        const int ScaleMask = 0x00FF0000;
        const int ScaleShift = 16;

        // Do not change the order in which these fields are declared. It
        // should be same as in the System.Decimal.DecCalc struct.
        [FieldOffset(0)]
        readonly decimal _value;
        [FieldOffset(0)]
        readonly int _flags;
        [FieldOffset(4)]
        readonly uint _high;
        [FieldOffset(8)]
        readonly ulong _low64;

        // Convenience aliased fields but their usage needs to take endianness into account.
        [FieldOffset(8)]
        readonly uint _low;
        [FieldOffset(12)]
        readonly uint _mid;

        public DecimalRaw(decimal value) : this() => _value = value;

        public uint High => _high;
        public uint Mid => BitConverter.IsLittleEndian ? _mid : _low;
        public uint Low => BitConverter.IsLittleEndian ? _low : _mid;
        public int Flags => _flags;
        public short Scale => (short)((_flags & ScaleMask) >> ScaleShift);
    }
#endif
}
diff --git a/src/Npgsql/Internal/Converters/Primitive/RealConverter.cs b/src/Npgsql/Internal/Converters/Primitive/RealConverter.cs
new file mode 100644
index 0000000000..b47e641aa5
--- /dev/null
+++ b/src/Npgsql/Internal/Converters/Primitive/RealConverter.cs
@@ -0,0 +1,43 @@
using System;
using System.Numerics;

// ReSharper disable once CheckNamespace
namespace Npgsql.Internal.Converters;

/// <summary>Converts float/double to/from the 4-byte PostgreSQL real type.</summary>
sealed class RealConverter<T> : PgBufferedConverter<T>
#if NET7_0_OR_GREATER
    where T : INumberBase<T>
#endif
{
    public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements)
    {
        bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(float));
        return format is DataFormat.Binary;
    }

#if NET7_0_OR_GREATER
    protected override T ReadCore(PgReader reader) => T.CreateChecked(reader.ReadFloat());
    protected override void WriteCore(PgWriter writer, T value) => writer.WriteFloat(float.CreateChecked(value));
#else
    protected override T ReadCore(PgReader reader)
    {
        var value = reader.ReadFloat();
        if (typeof(float) == typeof(T))
            return (T)(object)value;
        if (typeof(double) == typeof(T))
            return (T)(object)(double)value;

        throw new NotSupportedException();
    }

    protected override void WriteCore(PgWriter writer, T value)
    {
        if (typeof(float) == typeof(T))
            writer.WriteFloat((float)(object)value!);
        else if (typeof(double) == typeof(T))
            writer.WriteFloat((float)(double)(object)value!);
        else
            throw new NotSupportedException();
    }
#endif
}
diff --git
a/src/Npgsql/Internal/Converters/Primitive/TextConverters.cs b/src/Npgsql/Internal/Converters/Primitive/TextConverters.cs
new file mode 100644
index 0000000000..8fc04f1360
--- /dev/null
+++ b/src/Npgsql/Internal/Converters/Primitive/TextConverters.cs
@@ -0,0 +1,355 @@
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

// ReSharper disable once CheckNamespace
namespace Npgsql.Internal.Converters;

/// <summary>Base converter for CLR types that round-trip through a System.String.</summary>
abstract class StringBasedTextConverter<T> : PgStreamingConverter<T>
{
    readonly Encoding _encoding;
    protected StringBasedTextConverter(Encoding encoding) => _encoding = encoding;

    public override T Read(PgReader reader)
        => Read(async: false, reader, _encoding).GetAwaiter().GetResult();

    public override ValueTask<T> ReadAsync(PgReader reader, CancellationToken cancellationToken = default)
        => Read(async: true, reader, _encoding, cancellationToken);

    public override Size GetSize(SizeContext context, T value, ref object? writeState)
        => TextConverter.GetSize(ref context, ConvertTo(value), _encoding);

    public override void Write(PgWriter writer, T value)
        => writer.WriteChars(ConvertTo(value).Span, _encoding);

    public override ValueTask WriteAsync(PgWriter writer, T value, CancellationToken cancellationToken = default)
        => writer.WriteCharsAsync(ConvertTo(value), _encoding, cancellationToken);

    public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements)
    {
        bufferRequirements = BufferRequirements.None;
        return format is DataFormat.Binary or DataFormat.Text;
    }

    protected abstract ReadOnlyMemory<char> ConvertTo(T value);
    protected abstract T ConvertFrom(string value);

    ValueTask<T> Read(bool async, PgReader reader, Encoding encoding, CancellationToken cancellationToken = default)
    {
        return async
            ? ReadAsync(reader, encoding, cancellationToken)
            : new(ConvertFrom(encoding.GetString(reader.ReadBytes(reader.CurrentRemaining))));

#if NET6_0_OR_GREATER
        [AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder<>))]
#endif
        async ValueTask<T> ReadAsync(PgReader reader, Encoding encoding, CancellationToken cancellationToken)
            => ConvertFrom(encoding.GetString(await reader.ReadBytesAsync(reader.CurrentRemaining, cancellationToken).ConfigureAwait(false)));
    }
}

sealed class ReadOnlyMemoryTextConverter : StringBasedTextConverter<ReadOnlyMemory<char>>
{
    public ReadOnlyMemoryTextConverter(Encoding encoding) : base(encoding) { }
    protected override ReadOnlyMemory<char> ConvertTo(ReadOnlyMemory<char> value) => value;
    protected override ReadOnlyMemory<char> ConvertFrom(string value) => value.AsMemory();
}

sealed class StringTextConverter : StringBasedTextConverter<string>
{
    public StringTextConverter(Encoding encoding) : base(encoding) { }
    protected override ReadOnlyMemory<char> ConvertTo(string value) => value.AsMemory();
    protected override string ConvertFrom(string value) => value;
}

/// <summary>Base converter for CLR types that round-trip through a char array segment.</summary>
abstract class ArrayBasedTextConverter<T> : PgStreamingConverter<T>
{
    readonly Encoding _encoding;
    protected ArrayBasedTextConverter(Encoding encoding) => _encoding = encoding;

    public override T Read(PgReader reader)
        => Read(async: false, reader, _encoding).GetAwaiter().GetResult();

    public override ValueTask<T> ReadAsync(PgReader reader, CancellationToken cancellationToken = default)
        // BUG FIX: the caller's cancellation token was silently dropped here; flow it into the
        // async read path, mirroring StringBasedTextConverter.
        => Read(async: true, reader, _encoding, cancellationToken);

    public override Size GetSize(SizeContext context, T value, ref object? writeState)
        => TextConverter.GetSize(ref context, ConvertTo(value), _encoding);

    public override void Write(PgWriter writer, T value)
        => writer.WriteChars(ConvertTo(value).AsSpan(), _encoding);

    public override ValueTask WriteAsync(PgWriter writer, T value, CancellationToken cancellationToken = default)
        => writer.WriteCharsAsync(ConvertTo(value), _encoding, cancellationToken);

    public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements)
    {
        bufferRequirements = BufferRequirements.None;
        return format is DataFormat.Binary or DataFormat.Text;
    }

    protected abstract ArraySegment<char> ConvertTo(T value);
    protected abstract T ConvertFrom(ArraySegment<char> value);

    ValueTask<T> Read(bool async, PgReader reader, Encoding encoding, CancellationToken cancellationToken = default)
    {
        return async
            ? ReadAsync(reader, encoding, cancellationToken)
            : new(ConvertFrom(GetSegment(reader.ReadBytes(reader.CurrentRemaining), encoding)));

#if NET6_0_OR_GREATER
        [AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder<>))]
#endif
        async ValueTask<T> ReadAsync(PgReader reader, Encoding encoding, CancellationToken cancellationToken)
            => ConvertFrom(GetSegment(await reader.ReadBytesAsync(reader.CurrentRemaining, cancellationToken).ConfigureAwait(false), encoding));

        static ArraySegment<char> GetSegment(ReadOnlySequence<byte> bytes, Encoding encoding)
        {
            var array = TextConverter.GetChars(encoding, bytes);
            return new(array, 0, array.Length);
        }
    }
}

sealed class CharArraySegmentTextConverter : ArrayBasedTextConverter<ArraySegment<char>>
{
    public CharArraySegmentTextConverter(Encoding encoding) : base(encoding) { }
    protected override ArraySegment<char> ConvertTo(ArraySegment<char> value) => value;
    protected override ArraySegment<char> ConvertFrom(ArraySegment<char> value) => value;
}

sealed class CharArrayTextConverter : ArrayBasedTextConverter<char[]>
{
    public CharArrayTextConverter(Encoding encoding) : base(encoding) { }
    protected override ArraySegment<char> ConvertTo(char[] value) => new(value, 0, value.Length);
    protected override char[] ConvertFrom(ArraySegment<char> value)
    {
        // Reuse the backing array when the segment covers it entirely; otherwise copy out.
        if (value.Array?.Length == value.Count)
            return value.Array!;

        var array = new char[value.Count];
        Array.Copy(value.Array!, value.Offset, array, 0, value.Count);
        return array;
    }
}

/// <summary>Converts a single character to/from a PostgreSQL text value.</summary>
sealed class CharTextConverter : PgBufferedConverter<char>
{
    readonly Encoding _encoding;
    readonly Size _oneCharMaxByteCount;

    public CharTextConverter(Encoding encoding)
    {
        _encoding = encoding;
        _oneCharMaxByteCount = Size.CreateUpperBound(encoding.GetMaxByteCount(1));
    }

    public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements)
    {
        bufferRequirements = BufferRequirements.Create(_oneCharMaxByteCount);
        return format is DataFormat.Binary or DataFormat.Text;
    }

    protected override char ReadCore(PgReader reader)
    {
        var byteSeq = reader.ReadBytes(Math.Min(_oneCharMaxByteCount.Value, reader.CurrentRemaining));
        Debug.Assert(byteSeq.IsSingleSegment);
        var bytes = byteSeq.GetFirstSpan();

        var chars = _encoding.GetCharCount(bytes);
        if (chars < 1)
            throw new NpgsqlException("Could not read char - string was empty");

        Span<char> destination = stackalloc char[chars];
        _encoding.GetChars(bytes, destination);
        return destination[0];
    }

    public override Size GetSize(SizeContext context, char value, ref object? writeState)
    {
        Span<char> spanValue = stackalloc char[] { value };
        return _encoding.GetByteCount(spanValue);
    }

    protected override void WriteCore(PgWriter writer, char value)
    {
        Span<char> spanValue = stackalloc char[] { value };
        writer.WriteChars(spanValue, _encoding);
    }
}

/// <summary>Read-only converter exposing a column as a TextReader; writing is not supported.</summary>
sealed class TextReaderTextConverter : PgStreamingConverter<TextReader>
{
    readonly Encoding _encoding;
    public TextReaderTextConverter(Encoding encoding) => _encoding = encoding;

    public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements)
    {
        bufferRequirements = BufferRequirements.None;
        return format is DataFormat.Binary or DataFormat.Text;
    }

    public override TextReader Read(PgReader reader)
        => reader.GetTextReader(_encoding);

    public override ValueTask<TextReader> ReadAsync(PgReader reader, CancellationToken cancellationToken = default)
        => reader.GetTextReaderAsync(_encoding, cancellationToken);

    public override Size GetSize(SizeContext context, TextReader value, ref object? writeState) => throw new NotImplementedException();
    public override void Write(PgWriter writer, TextReader value) => throw new NotImplementedException();
    public override ValueTask WriteAsync(PgWriter writer, TextReader value, CancellationToken cancellationToken = default) => throw new NotImplementedException();
}

/// <summary>Result of a DbDataReader.GetChars style read: the number of chars read.</summary>
readonly struct GetChars
{
    public int Read { get; }
    public GetChars(int read) => Read = read;
}

sealed class GetCharsTextConverter : PgStreamingConverter<GetChars>, IResumableRead
{
    readonly Encoding _encoding;
    public GetCharsTextConverter(Encoding encoding) => _encoding = encoding;

    public override GetChars Read(PgReader reader)
        => reader.IsCharsRead
            ? ResumableRead(reader)
            : throw new NotSupportedException();

    public override ValueTask<GetChars> ReadAsync(PgReader reader, CancellationToken cancellationToken = default)
        => throw new NotSupportedException();

    public override Size GetSize(SizeContext context, GetChars value, ref object?
writeState) => throw new NotSupportedException(); + public override void Write(PgWriter writer, GetChars value) => throw new NotSupportedException(); + public override ValueTask WriteAsync(PgWriter writer, GetChars value, CancellationToken cancellationToken = default) => throw new NotSupportedException(); + + GetChars ResumableRead(PgReader reader) + { + reader.GetCharsReadInfo(_encoding, out var charsRead, out var textReader, out var charsOffset, out var buffer); + if (charsOffset < charsRead || (buffer is null && charsRead > 0)) + { + // With variable length encodings, moving backwards based on bytes means we have to start over. + reader.ResetCharsRead(out charsRead); + } + + // First seek towards the charsOffset. + // If buffer is null read the entire thing and report the length, see sql client remarks. + // https://learn.microsoft.com/en-us/dotnet/api/system.data.sqlclient.sqldatareader.getchars + int read; + if (buffer is null) + { + read = ConsumeChars(textReader, null); + } + else + { + var consumed = ConsumeChars(textReader, charsOffset - charsRead); + Debug.Assert(consumed == charsOffset - charsRead); + read = textReader.ReadBlock(buffer.GetValueOrDefault().Array!, buffer.GetValueOrDefault().Offset, buffer.GetValueOrDefault().Count); + } + + return new(read); + + static int ConsumeChars(TextReader reader, int? count) + { + if (count is 0) + return 0; + + const int maxStackAlloc = 512; +#if NETSTANDARD + var tempCharBuf = new char[maxStackAlloc]; +#else + Span tempCharBuf = stackalloc char[maxStackAlloc]; +#endif + var totalRead = 0; + var fin = false; + while (!fin) + { + var toRead = count is null ? maxStackAlloc : Math.Min(maxStackAlloc, count.Value - totalRead); +#if NETSTANDARD + var read = reader.ReadBlock(tempCharBuf, 0, toRead); +#else + var read = reader.ReadBlock(tempCharBuf.Slice(0, toRead)); +#endif + totalRead += read; + if (count is not null && read is 0) + throw new EndOfStreamException(); + + fin = count is null ? 
read is 0 : totalRead >= count; + } + return totalRead; + } + } + + bool IResumableRead.Supported => true; +} + +// Moved out for code size/sharing. +static class TextConverter +{ + public static Size GetSize(ref SizeContext context, ReadOnlyMemory value, Encoding encoding) + => encoding.GetByteCount(value.Span); + + // Adapted version of GetString(ROSeq) removing the intermediate string allocation to make a contiguous char array. + public static char[] GetChars(Encoding encoding, ReadOnlySequence bytes) + { + if (bytes.IsSingleSegment) + { + // If the incoming sequence is single-segment, one-shot this. + var firstSpan = bytes.First.Span; + var chars = new char[encoding.GetCharCount(firstSpan)]; + encoding.GetChars(bytes.First.Span, chars); + return chars; + } + else + { + // If the incoming sequence is multi-segment, create a stateful Decoder + // and use it as the workhorse. On the final iteration we'll pass flush=true. + + var decoder = encoding.GetDecoder(); + + // Maintain a list of all the segments we'll need to concat together. + // These will be released back to the pool at the end of the method. 
+ + var listOfSegments = new List<(char[], int)>(); + var totalCharCount = 0; + + var remainingBytes = bytes; + bool isFinalSegment; + + do + { + var firstSpan = remainingBytes.First.Span; + var next = remainingBytes.GetPosition(firstSpan.Length); + isFinalSegment = remainingBytes.IsSingleSegment; + + var charCountThisIteration = decoder.GetCharCount(firstSpan, flush: isFinalSegment); // could throw ArgumentException if overflow would occur + var rentedArray = ArrayPool.Shared.Rent(charCountThisIteration); + var actualCharsWrittenThisIteration = decoder.GetChars(firstSpan, rentedArray, flush: isFinalSegment); + listOfSegments.Add((rentedArray, actualCharsWrittenThisIteration)); + + totalCharCount += actualCharsWrittenThisIteration; + if (totalCharCount < 0) + throw new OutOfMemoryException(); + + remainingBytes = remainingBytes.Slice(next); + } while (!isFinalSegment); + + // Now build up the string to return, then release all of our scratch buffers + // back to the shared pool. + var chars = new char[totalCharCount]; + var span = chars.AsSpan(); + foreach (var (array, length) in listOfSegments) + { + array.AsSpan(0, length).CopyTo(span); + ArrayPool.Shared.Return(array); + span = span.Slice(length); + } + + return chars; + } + } +} diff --git a/src/Npgsql/Internal/Converters/RangeConverter.cs b/src/Npgsql/Internal/Converters/RangeConverter.cs new file mode 100644 index 0000000000..c378d830f7 --- /dev/null +++ b/src/Npgsql/Internal/Converters/RangeConverter.cs @@ -0,0 +1,216 @@ +using System; +using System.Diagnostics; +using System.Threading; +using System.Threading.Tasks; +using NpgsqlTypes; + +namespace Npgsql.Internal.Converters; + +sealed class RangeConverter : PgStreamingConverter> +{ + readonly PgConverter _subtypeConverter; + readonly BufferRequirements _subtypeRequirements; + + public RangeConverter(PgConverter subtypeConverter) + { + if (!subtypeConverter.CanConvert(DataFormat.Binary, out var bufferRequirements)) + throw new NotSupportedException("Range 
subtype converter has to support the binary format to be compatible."); + _subtypeRequirements = bufferRequirements; + _subtypeConverter = subtypeConverter; + } + + public override NpgsqlRange Read(PgReader reader) + => Read(async: false, reader, CancellationToken.None).GetAwaiter().GetResult(); + + public override ValueTask> ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + => Read(async: true, reader, cancellationToken); + + async ValueTask> Read(bool async, PgReader reader, CancellationToken cancellationToken) + { + if (reader.ShouldBuffer(sizeof(byte))) + await reader.Buffer(async, sizeof(byte), cancellationToken).ConfigureAwait(false); + + var flags = (RangeFlags)reader.ReadByte(); + if ((flags & RangeFlags.Empty) != 0) + return NpgsqlRange.Empty; + + var lowerBound = default(TSubtype); + var upperBound = default(TSubtype); + + var converter = _subtypeConverter; + if ((flags & RangeFlags.LowerBoundInfinite) == 0) + { + if (reader.ShouldBuffer(sizeof(int))) + await reader.Buffer(async, sizeof(int), cancellationToken).ConfigureAwait(false); + var length = reader.ReadInt32(); + + // Note that we leave the CLR default for nulls + if (length != -1) + { + var scope = await reader.BeginNestedRead(async, length, _subtypeRequirements.Read, cancellationToken).ConfigureAwait(false); + try + { + lowerBound = async + ? 
await converter.ReadAsync(reader, cancellationToken).ConfigureAwait(false) + // ReSharper disable once MethodHasAsyncOverloadWithCancellation + : converter.Read(reader); + } + finally + { + if (async) + await scope.DisposeAsync().ConfigureAwait(false); + else + scope.Dispose(); + } + } + } + + if ((flags & RangeFlags.UpperBoundInfinite) == 0) + { + if (reader.ShouldBuffer(sizeof(int))) + await reader.Buffer(async, sizeof(int), cancellationToken).ConfigureAwait(false); + var length = reader.ReadInt32(); + + // Note that we leave the CLR default for nulls + if (length != -1) + { + var scope = await reader.BeginNestedRead(async, length, _subtypeRequirements.Read, cancellationToken).ConfigureAwait(false); + try + { + upperBound = async + ? await converter.ReadAsync(reader, cancellationToken).ConfigureAwait(false) + // ReSharper disable once MethodHasAsyncOverloadWithCancellation + : converter.Read(reader); + } + finally + { + if (async) + await scope.DisposeAsync().ConfigureAwait(false); + else + scope.Dispose(); + } + } + } + + return new NpgsqlRange(lowerBound, upperBound, flags); + } + + public override Size GetSize(SizeContext context, NpgsqlRange value, ref object? writeState) + { + var totalSize = Size.Create(1); + if (value.IsEmpty) + return totalSize; // Just flags. + + WriteState? 
state = null; + if (!value.LowerBoundInfinite) + { + totalSize = totalSize.Combine(sizeof(int)); + var subTypeState = (object?)null; + if (_subtypeConverter.GetSizeOrDbNull(context.Format, _subtypeRequirements.Write, value.LowerBound, ref subTypeState) is { } size) + { + totalSize = totalSize.Combine(size); + (state ??= new WriteState()).LowerBoundSize = size; + state.LowerBoundWriteState = subTypeState; + } + else if (state is not null) + state.LowerBoundSize = -1; + } + + if (!value.UpperBoundInfinite) + { + totalSize = totalSize.Combine(sizeof(int)); + var subTypeState = (object?)null; + if (_subtypeConverter.GetSizeOrDbNull(context.Format, _subtypeRequirements.Write, value.UpperBound, ref subTypeState) is { } size) + { + totalSize = totalSize.Combine(size); + (state ??= new WriteState()).UpperBoundSize = size; + state.UpperBoundWriteState = subTypeState; + } + else if (state is not null) + state.UpperBoundSize = -1; + } + + writeState = state; + return totalSize; + } + + public override void Write(PgWriter writer, NpgsqlRange value) + => Write(async: false, writer, value, CancellationToken.None).GetAwaiter().GetResult(); + + public override ValueTask WriteAsync(PgWriter writer, NpgsqlRange value, CancellationToken cancellationToken = default) + => Write(async: true, writer, value, cancellationToken); + + async ValueTask Write(bool async, PgWriter writer, NpgsqlRange value, CancellationToken cancellationToken) + { + var writeState = writer.Current.WriteState as WriteState; + var lowerBoundSize = writeState?.LowerBoundSize ?? -1; + var upperBoundSize = writeState?.UpperBoundSize ?? -1; + + var flags = value.Flags; + if (!value.IsEmpty) + { + // Normalize nulls to infinite, as pg does. 
+ if (lowerBoundSize == -1 && !value.LowerBoundInfinite) + flags = (flags & ~RangeFlags.LowerBoundInclusive) | RangeFlags.LowerBoundInfinite; + + if (upperBoundSize == -1 && !value.UpperBoundInfinite) + flags = (flags & ~RangeFlags.UpperBoundInclusive) | RangeFlags.UpperBoundInfinite; + } + + if (writer.ShouldFlush(sizeof(byte))) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + writer.WriteByte((byte)flags); + var lowerBoundInfinite = flags.HasFlag(RangeFlags.LowerBoundInfinite); + var upperBoundInfinite = flags.HasFlag(RangeFlags.UpperBoundInfinite); + if (value.IsEmpty || (lowerBoundInfinite && upperBoundInfinite)) + return; + + // Always need write state from this point. + if (writeState is null) + throw new InvalidCastException($"Invalid write state, expected {typeof(WriteState).FullName}."); + + if (!lowerBoundInfinite) + { + Debug.Assert(lowerBoundSize.Value != -1); + if (lowerBoundSize.Kind is SizeKind.Unknown) + throw new NotImplementedException(); + + var byteCount = lowerBoundSize.Value; // Never -1 so it's a byteCount. + if (writer.ShouldFlush(sizeof(int))) // Length + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + writer.WriteInt32(byteCount); + using var _ = await writer.BeginNestedWrite(async, _subtypeRequirements.Write, byteCount, + writeState.LowerBoundWriteState, cancellationToken).ConfigureAwait(false); + if (async) + await _subtypeConverter.WriteAsync(writer, value.LowerBound!, cancellationToken).ConfigureAwait(false); + else + _subtypeConverter.Write(writer, value.LowerBound!); + } + + if (!upperBoundInfinite) + { + Debug.Assert(upperBoundSize.Value != -1); + if (upperBoundSize.Kind is SizeKind.Unknown) + throw new NotImplementedException(); + + var byteCount = upperBoundSize.Value; // Never -1 so it's a byteCount. 
+ if (writer.ShouldFlush(sizeof(int))) // Length + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + writer.WriteInt32(byteCount); + using var _ = await writer.BeginNestedWrite(async, _subtypeRequirements.Write, byteCount, + writeState.UpperBoundWriteState, cancellationToken).ConfigureAwait(false); + if (async) + await _subtypeConverter.WriteAsync(writer, value.UpperBound!, cancellationToken).ConfigureAwait(false); + else + _subtypeConverter.Write(writer, value.UpperBound!); + } + } + + sealed class WriteState + { + internal Size LowerBoundSize { get; set; } + internal object? LowerBoundWriteState { get; set; } + internal Size UpperBoundSize { get; set; } + internal object? UpperBoundWriteState { get; set; } + } +} diff --git a/src/Npgsql/Internal/Converters/SystemTextJsonConverter.cs b/src/Npgsql/Internal/Converters/SystemTextJsonConverter.cs new file mode 100644 index 0000000000..cedf1664f2 --- /dev/null +++ b/src/Npgsql/Internal/Converters/SystemTextJsonConverter.cs @@ -0,0 +1,205 @@ +using System; +using System.Buffers; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.IO; +using System.Text; +using System.Text.Json; +using System.Text.Json.Serialization.Metadata; +using System.Threading; +using System.Threading.Tasks; + +namespace Npgsql.Internal.Converters; + +sealed class SystemTextJsonConverter : PgStreamingConverter where T: TBase? +{ + readonly bool _jsonb; + readonly Encoding _textEncoding; + readonly JsonTypeInfo _jsonTypeInfo; + readonly JsonTypeInfo? _objectTypeInfo; + + public SystemTextJsonConverter(bool jsonb, Encoding textEncoding, JsonSerializerOptions serializerOptions) + { + // We do GetTypeInfo calls directly so we need a resolver. + if (serializerOptions.TypeInfoResolver is null) + serializerOptions.TypeInfoResolver = new DefaultJsonTypeInfoResolver(); + + _jsonb = jsonb; + _textEncoding = textEncoding; + _jsonTypeInfo = typeof(TBase) != typeof(object) && typeof(T) != typeof(TBase) + ? 
(JsonTypeInfo)serializerOptions.GetTypeInfo(typeof(TBase)) + : (JsonTypeInfo)serializerOptions.GetTypeInfo(typeof(T)); + // Unspecified polymorphism, let STJ handle it. + _objectTypeInfo = typeof(TBase) == typeof(object) + ? (JsonTypeInfo)serializerOptions.GetTypeInfo(typeof(object)) + : null; + } + + public override T? Read(PgReader reader) + => Read(async: false, reader, CancellationToken.None).GetAwaiter().GetResult(); + public override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + => Read(async: true, reader, cancellationToken); + + async ValueTask Read(bool async, PgReader reader, CancellationToken cancellationToken) + { + if (_jsonb && reader.ShouldBuffer(sizeof(byte))) + await reader.Buffer(async, sizeof(byte), cancellationToken).ConfigureAwait(false); + + // We always fall back to buffers on older targets due to the absence of transcoding stream. + if (SystemTextJsonConverter.TryReadStream(_jsonb, _textEncoding, reader, out var byteCount, out var stream)) + { + using var _ = stream; + if (_jsonTypeInfo is JsonTypeInfo typeInfoOfT) + return async + ? await JsonSerializer.DeserializeAsync(stream, typeInfoOfT, cancellationToken).ConfigureAwait(false) + : JsonSerializer.Deserialize(stream, typeInfoOfT); + + return (T?)(async + ? await JsonSerializer.DeserializeAsync(stream, (JsonTypeInfo)_jsonTypeInfo, cancellationToken).ConfigureAwait(false) + : JsonSerializer.Deserialize(stream, (JsonTypeInfo)_jsonTypeInfo)); + } + else + { + var (rentedChars, rentedBytes) = await SystemTextJsonConverter.ReadRentedBuffer(async, _textEncoding, byteCount, reader, cancellationToken).ConfigureAwait(false); + var result = _jsonTypeInfo is JsonTypeInfo typeInfoOfT + ? 
JsonSerializer.Deserialize(rentedChars.AsSpan(), typeInfoOfT) + : (T?)JsonSerializer.Deserialize(rentedChars.AsSpan(), (JsonTypeInfo)_jsonTypeInfo); + + ArrayPool.Shared.Return(rentedChars.Array!); + if (rentedBytes is not null) + ArrayPool.Shared.Return(rentedBytes); + + return result; + } + } + + public override Size GetSize(SizeContext context, T? value, ref object? writeState) + { + var capacity = 0; + if (typeof(T) == typeof(JsonDocument)) + capacity = ((JsonDocument?)(object?)value)?.RootElement.GetRawText().Length ?? 0; + var stream = new MemoryStream(capacity); + + // Mirroring ASP.NET Core serialization strategy https://github.com/dotnet/aspnetcore/issues/47548 + if (_objectTypeInfo is null) + JsonSerializer.Serialize(stream, value, (JsonTypeInfo)_jsonTypeInfo); + else + JsonSerializer.Serialize(stream, value, _objectTypeInfo); + + return SystemTextJsonConverter.GetSizeCore(_jsonb, stream, _textEncoding, ref writeState); + } + + public override void Write(PgWriter writer, T? value) + => SystemTextJsonConverter.Write(_jsonb, async: false, writer, CancellationToken.None).GetAwaiter().GetResult(); + + public override ValueTask WriteAsync(PgWriter writer, T? value, CancellationToken cancellationToken = default) + => SystemTextJsonConverter.Write(_jsonb, async: true, writer, cancellationToken); +} + +// Split out to avoid unneccesary code duplication. +static class SystemTextJsonConverter +{ + public const byte JsonbProtocolVersion = 1; + // We pick a value that is the largest multiple of 4096 that is still smaller than the large object heap threshold (85K). + const int StreamingThreshold = 81920; + + public static bool TryReadStream(bool jsonb, Encoding encoding, PgReader reader, out int byteCount, [NotNullWhen(true)]out Stream? 
stream) + { + if (jsonb) + { + var version = reader.ReadByte(); + if (version != JsonbProtocolVersion) + throw new InvalidCastException($"Unknown jsonb wire format version {version}"); + } + + var isUtf8 = encoding.CodePage == Encoding.UTF8.CodePage; + byteCount = reader.CurrentRemaining; + // We always fall back to buffers on older targets + if (isUtf8 +#if !NETSTANDARD + || byteCount >= StreamingThreshold +#endif + ) + { + stream = +#if !NETSTANDARD + !isUtf8 + ? Encoding.CreateTranscodingStream(reader.GetStream(), encoding, Encoding.UTF8) + : reader.GetStream(); +#else + reader.GetStream(); + Debug.Assert(isUtf8); +#endif + } + else + stream = null; + + return stream is not null; + } + + public static async ValueTask<(ArraySegment RentedChars, byte[]? RentedBytes)> ReadRentedBuffer(bool async, Encoding encoding, int byteCount, PgReader reader, CancellationToken cancellationToken) + { + // Never utf8, but we may still be able to save a copy. + byte[]? rentedBuffer = null; + if (!reader.TryReadBytes(byteCount, out ReadOnlyMemory buffer)) + { + rentedBuffer = ArrayPool.Shared.Rent(byteCount); + if (async) + await reader.ReadBytesAsync(rentedBuffer.AsMemory(0, byteCount), cancellationToken).ConfigureAwait(false); + else + reader.ReadBytes(rentedBuffer.AsSpan(0, byteCount)); + buffer = rentedBuffer.AsMemory(0, byteCount); + } + + var charCount = encoding.GetCharCount(buffer.Span); + var chars = ArrayPool.Shared.Rent(charCount); + encoding.GetChars(buffer.Span, chars); + + return (new(chars, 0, charCount), rentedBuffer); + } + + public static Size GetSizeCore(bool jsonb, MemoryStream stream, Encoding encoding, ref object? writeState) + { + if (encoding.CodePage == Encoding.UTF8.CodePage) + { + writeState = stream; + return (int)stream.Length + (jsonb ? 
sizeof(byte) : 0); + } + + if (!stream.TryGetBuffer(out var buffer)) + throw new InvalidOperationException(); + + var bytes = encoding.GetBytes(Encoding.UTF8.GetChars(buffer.Array!, buffer.Offset, buffer.Count)); + writeState = bytes; + return bytes.Length + (jsonb ? sizeof(byte) : 0); + } + + public static async ValueTask Write(bool jsonb, bool async, PgWriter writer, CancellationToken cancellationToken) + { + if (jsonb) + { + if (writer.ShouldFlush(sizeof(byte))) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + writer.WriteByte(JsonbProtocolVersion); + } + + ArraySegment buffer; + switch (writer.Current.WriteState) + { + case MemoryStream stream: + if (!stream.TryGetBuffer(out buffer)) + throw new InvalidOperationException(); + break; + case byte[] bytes: + buffer = new ArraySegment(bytes); + break; + default: + throw new InvalidCastException($"Invalid state {writer.Current.WriteState?.GetType().FullName}."); + } + + if (async) + await writer.WriteBytesAsync(buffer.AsMemory(), cancellationToken).ConfigureAwait(false); + else + writer.WriteBytes(buffer.AsSpan()); + } +} diff --git a/src/Npgsql/Internal/Converters/Temporal/DateConverters.cs b/src/Npgsql/Internal/Converters/Temporal/DateConverters.cs new file mode 100644 index 0000000000..261d305439 --- /dev/null +++ b/src/Npgsql/Internal/Converters/Temporal/DateConverters.cs @@ -0,0 +1,103 @@ +using System; +using Npgsql.Properties; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class DateTimeDateConverter : PgBufferedConverter +{ + readonly bool _dateTimeInfinityConversions; + + static readonly DateTime BaseValue = new(2000, 1, 1, 0, 0, 0); + + public DateTimeDateConverter(bool dateTimeInfinityConversions) + => _dateTimeInfinityConversions = dateTimeInfinityConversions; + + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(int)); + 
return format is DataFormat.Binary; + } + + protected override DateTime ReadCore(PgReader reader) + => reader.ReadInt32() switch + { + int.MaxValue => _dateTimeInfinityConversions + ? DateTime.MaxValue + : throw new InvalidCastException(NpgsqlStrings.CannotReadInfinityValue), + int.MinValue => _dateTimeInfinityConversions + ? DateTime.MinValue + : throw new InvalidCastException(NpgsqlStrings.CannotReadInfinityValue), + var value => BaseValue + TimeSpan.FromDays(value) + }; + + protected override void WriteCore(PgWriter writer, DateTime value) + { + if (_dateTimeInfinityConversions) + { + if (value == DateTime.MaxValue) + { + writer.WriteInt32(int.MaxValue); + return; + } + + if (value == DateTime.MinValue) + { + writer.WriteInt32(int.MinValue); + return; + } + } + + writer.WriteInt32((value.Date - BaseValue).Days); + } +} + +#if NET6_0_OR_GREATER +sealed class DateOnlyDateConverter : PgBufferedConverter +{ + readonly bool _dateTimeInfinityConversions; + + static readonly DateOnly BaseValue = new(2000, 1, 1); + + public DateOnlyDateConverter(bool dateTimeInfinityConversions) + => _dateTimeInfinityConversions = dateTimeInfinityConversions; + + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(int)); + return format is DataFormat.Binary; + } + + protected override DateOnly ReadCore(PgReader reader) + => reader.ReadInt32() switch + { + int.MaxValue => _dateTimeInfinityConversions + ? DateOnly.MaxValue + : throw new InvalidCastException(NpgsqlStrings.CannotReadInfinityValue), + int.MinValue => _dateTimeInfinityConversions + ? 
DateOnly.MinValue + : throw new InvalidCastException(NpgsqlStrings.CannotReadInfinityValue), + var value => BaseValue.AddDays(value) + }; + + protected override void WriteCore(PgWriter writer, DateOnly value) + { + if (_dateTimeInfinityConversions) + { + if (value == DateOnly.MaxValue) + { + writer.WriteInt32(int.MaxValue); + return; + } + + if (value == DateOnly.MinValue) + { + writer.WriteInt32(int.MinValue); + return; + } + } + + writer.WriteInt32(value.DayNumber - BaseValue.DayNumber); + } +} +#endif diff --git a/src/Npgsql/Internal/Converters/Temporal/DateTimeConverterResolver.cs b/src/Npgsql/Internal/Converters/Temporal/DateTimeConverterResolver.cs new file mode 100644 index 0000000000..6ae5a783a1 --- /dev/null +++ b/src/Npgsql/Internal/Converters/Temporal/DateTimeConverterResolver.cs @@ -0,0 +1,143 @@ +using System; +using System.Collections.Generic; +using Npgsql.Internal.Postgres; +using Npgsql.Properties; +using NpgsqlTypes; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class DateTimeConverterResolver : PgConverterResolver +{ + readonly PgSerializerOptions _options; + readonly Func, T?, PgTypeId?, PgConverterResolution?> _resolver; + readonly Func _factory; + readonly PgTypeId _timestampTz; + PgConverter? _timestampTzConverter; + readonly PgTypeId _timestamp; + PgConverter? _timestampConverter; + readonly bool _dateTimeInfinityConversions; + + internal DateTimeConverterResolver(PgSerializerOptions options, Func, T?, PgTypeId?, PgConverterResolution?> resolver, Func factory, PgTypeId timestampTz, PgTypeId timestamp, bool dateTimeInfinityConversions) + { + _options = options; + _resolver = resolver; + _factory = factory; + _timestampTz = timestampTz; + _timestamp = timestamp; + _dateTimeInfinityConversions = dateTimeInfinityConversions; + } + + public override PgConverterResolution GetDefault(PgTypeId? 
pgTypeId) + { + if (pgTypeId == _timestampTz) + return new(_timestampTzConverter ??= _factory(_timestampTz), _timestampTz); + if (pgTypeId is null || pgTypeId == _timestamp) + return new(_timestampConverter ??= _factory(_timestamp), _timestamp); + + throw CreateUnsupportedPgTypeIdException(pgTypeId.Value); + } + + public PgConverterResolution? Get(DateTime value, PgTypeId? expectedPgTypeId, bool validateOnly = false) + { + if (value.Kind is DateTimeKind.Utc) + { + // We coalesce with expectedPgTypeId to throw on unknown type ids. + return expectedPgTypeId == _timestamp + ? throw new ArgumentException( + string.Format(NpgsqlStrings.TimestampNoDateTimeUtc, _options.GetDataTypeName(_timestamp).DisplayName, _options.GetDataTypeName(_timestampTz).DisplayName), nameof(value)) + : validateOnly ? null : GetDefault(expectedPgTypeId ?? _timestampTz); + } + + // For timestamptz types we'll accept unspecified MinValue/MaxValue as well. + if (expectedPgTypeId == _timestampTz + && !(_dateTimeInfinityConversions && (value == DateTime.MinValue || value == DateTime.MaxValue))) + { + throw new ArgumentException( + string.Format(NpgsqlStrings.TimestampTzNoDateTimeUnspecified, value.Kind, _options.GetDataTypeName(_timestampTz).DisplayName), nameof(value)); + } + + // We coalesce with expectedPgTypeId to throw on unknown type ids. + return GetDefault(expectedPgTypeId ?? _timestamp); + } + + public override PgConverterResolution? Get(T? value, PgTypeId? 
expectedPgTypeId) + => _resolver(this, value, expectedPgTypeId); +} + +sealed class DateTimeConverterResolver +{ + public static DateTimeConverterResolver CreateResolver(PgSerializerOptions options, PgTypeId timestampTz, PgTypeId timestamp, bool dateTimeInfinityConversions) + => new(options, static (resolver, value, expectedPgTypeId) => resolver.Get(value, expectedPgTypeId), pgTypeId => + { + if (pgTypeId == timestampTz) + return new DateTimeConverter(dateTimeInfinityConversions, DateTimeKind.Utc); + if (pgTypeId == timestamp) + return new DateTimeConverter(dateTimeInfinityConversions, DateTimeKind.Unspecified); + + throw new NotSupportedException(); + }, timestampTz, timestamp, dateTimeInfinityConversions); + + public static DateTimeConverterResolver> CreateRangeResolver(PgSerializerOptions options, PgTypeId timestampTz, PgTypeId timestamp, bool dateTimeInfinityConversions) + => new(options, static (resolver, value, expectedPgTypeId) => + { + // Resolve both sides to make sure we end up with consistent PgTypeIds. + PgConverterResolution? resolution = null; + if (!value.LowerBoundInfinite) + resolution = resolver.Get(value.LowerBound, expectedPgTypeId); + + if (!value.UpperBoundInfinite) + { + var result = resolver.Get(value.UpperBound, resolution?.PgTypeId ?? 
expectedPgTypeId, validateOnly: resolution is not null); + resolution ??= result; + } + + return resolution; + }, pgTypeId => + { + if (pgTypeId == timestampTz) + return new RangeConverter(new DateTimeConverter(dateTimeInfinityConversions, DateTimeKind.Utc)); + if (pgTypeId == timestamp) + return new RangeConverter(new DateTimeConverter(dateTimeInfinityConversions, DateTimeKind.Unspecified)); + + throw new NotSupportedException(); + }, timestampTz, timestamp, dateTimeInfinityConversions); + + public static DateTimeConverterResolver CreateMultirangeResolver(PgSerializerOptions options, PgTypeId timestampTz, PgTypeId timestamp, bool dateTimeInfinityConversions) + where T : IList where TElement : notnull + { + if (typeof(TElement) != typeof(NpgsqlRange)) + ThrowHelper.ThrowNotSupportedException("Unsupported element type"); + + return new DateTimeConverterResolver(options, static (resolver, value, expectedPgTypeId) => + { + PgConverterResolution? resolution = null; + if (value is null) + return null; + + foreach (var element in (IList>)value) + { + PgConverterResolution? result; + if (!element.LowerBoundInfinite) + { + result = resolver.Get(element.LowerBound, resolution?.PgTypeId ?? expectedPgTypeId, validateOnly: resolution is not null); + resolution ??= result; + } + if (!element.UpperBoundInfinite) + { + result = resolver.Get(element.UpperBound, resolution?.PgTypeId ?? 
expectedPgTypeId, validateOnly: resolution is not null); + resolution ??= result; + } + } + return resolution; + }, pgTypeId => + { + if (pgTypeId == timestampTz) + return new MultirangeConverter((PgConverter)(object)new RangeConverter(new DateTimeConverter(dateTimeInfinityConversions, DateTimeKind.Utc))); + if (pgTypeId == timestamp) + return new MultirangeConverter((PgConverter)(object)new RangeConverter(new DateTimeConverter(dateTimeInfinityConversions, DateTimeKind.Unspecified))); + + throw new NotSupportedException(); + }, timestampTz, timestamp, dateTimeInfinityConversions); + } +} diff --git a/src/Npgsql/Internal/Converters/Temporal/DateTimeConverters.cs b/src/Npgsql/Internal/Converters/Temporal/DateTimeConverters.cs new file mode 100644 index 0000000000..0047e3c572 --- /dev/null +++ b/src/Npgsql/Internal/Converters/Temporal/DateTimeConverters.cs @@ -0,0 +1,53 @@ +using System; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class DateTimeConverter : PgBufferedConverter +{ + readonly bool _dateTimeInfinityConversions; + readonly DateTimeKind _kind; + + public DateTimeConverter(bool dateTimeInfinityConversions, DateTimeKind kind) + { + _dateTimeInfinityConversions = dateTimeInfinityConversions; + _kind = kind; + } + + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long)); + return format is DataFormat.Binary; + } + + protected override DateTime ReadCore(PgReader reader) + => PgTimestamp.Decode(reader.ReadInt64(), _kind, _dateTimeInfinityConversions); + + protected override void WriteCore(PgWriter writer, DateTime value) + => writer.WriteInt64(PgTimestamp.Encode(value, _dateTimeInfinityConversions)); +} + +sealed class DateTimeOffsetConverter : PgBufferedConverter +{ + readonly bool _dateTimeInfinityConversions; + public DateTimeOffsetConverter(bool dateTimeInfinityConversions) + => 
_dateTimeInfinityConversions = dateTimeInfinityConversions; + + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long)); + return format is DataFormat.Binary; + } + + protected override DateTimeOffset ReadCore(PgReader reader) + => PgTimestamp.Decode(reader.ReadInt64(), DateTimeKind.Utc, _dateTimeInfinityConversions); + + protected override void WriteCore(PgWriter writer, DateTimeOffset value) + { + if (value.Offset != TimeSpan.Zero) + throw new ArgumentException($"Cannot write DateTimeOffset with Offset={value.Offset} to PostgreSQL type 'timestamp with time zone', only offset 0 (UTC) is supported. ", nameof(value)); + + writer.WriteInt64(PgTimestamp.Encode(value.DateTime, _dateTimeInfinityConversions)); + + } +} diff --git a/src/Npgsql/Internal/Converters/Temporal/IntervalConverters.cs b/src/Npgsql/Internal/Converters/Temporal/IntervalConverters.cs new file mode 100644 index 0000000000..1e1cbe9df2 --- /dev/null +++ b/src/Npgsql/Internal/Converters/Temporal/IntervalConverters.cs @@ -0,0 +1,58 @@ +using System; +using NpgsqlTypes; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class TimeSpanIntervalConverter : PgBufferedConverter +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long) + sizeof(int) + sizeof(int)); + return format is DataFormat.Binary; + } + + protected override TimeSpan ReadCore(PgReader reader) + { + var microseconds = reader.ReadInt64(); + var days = reader.ReadInt32(); + var months = reader.ReadInt32(); + + return months > 0 + ? throw new InvalidCastException( + "Cannot read interval values with non-zero months as TimeSpan, since that type doesn't support months. 
Consider using NodaTime Period which better corresponds to PostgreSQL interval, or read the value as NpgsqlInterval, or transform the interval to not contain months or years in PostgreSQL before reading it.") + : new(microseconds * 10 + days * TimeSpan.TicksPerDay); + } + + protected override void WriteCore(PgWriter writer, TimeSpan value) + { + var ticksInDay = value.Ticks - TimeSpan.TicksPerDay * value.Days; + writer.WriteInt64(ticksInDay / 10); + writer.WriteInt32(value.Days); + writer.WriteInt32(0); + } +} + +sealed class NpgsqlIntervalConverter : PgBufferedConverter +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long) + sizeof(int) + sizeof(int)); + return format is DataFormat.Binary; + } + + protected override NpgsqlInterval ReadCore(PgReader reader) + { + var ticks = reader.ReadInt64(); + var day = reader.ReadInt32(); + var month = reader.ReadInt32(); + return new NpgsqlInterval(month, day, ticks); + } + + protected override void WriteCore(PgWriter writer, NpgsqlInterval value) + { + writer.WriteInt64(value.Time); + writer.WriteInt32(value.Days); + writer.WriteInt32(value.Months); + } +} diff --git a/src/Npgsql/Internal/Converters/Temporal/LegacyDateTimeConverter.cs b/src/Npgsql/Internal/Converters/Temporal/LegacyDateTimeConverter.cs new file mode 100644 index 0000000000..5e6306da56 --- /dev/null +++ b/src/Npgsql/Internal/Converters/Temporal/LegacyDateTimeConverter.cs @@ -0,0 +1,62 @@ +using System; + +namespace Npgsql.Internal.Converters; + +sealed class LegacyDateTimeConverter : PgBufferedConverter +{ + readonly bool _dateTimeInfinityConversions; + readonly bool _timestamp; + + public LegacyDateTimeConverter(bool dateTimeInfinityConversions, bool timestamp) + { + _dateTimeInfinityConversions = dateTimeInfinityConversions; + _timestamp = timestamp; + } + + public override bool CanConvert(DataFormat format, out BufferRequirements 
bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long)); + return format is DataFormat.Binary; + } + + protected override DateTime ReadCore(PgReader reader) + { + var dateTime = PgTimestamp.Decode(reader.ReadInt64(), DateTimeKind.Utc, _dateTimeInfinityConversions); + return !_timestamp && (!_dateTimeInfinityConversions || dateTime != DateTime.MaxValue && dateTime != DateTime.MinValue) + ? dateTime.ToLocalTime() + : dateTime; + } + + protected override void WriteCore(PgWriter writer, DateTime value) + { + if (!_timestamp && value.Kind is DateTimeKind.Local) + value = value.ToUniversalTime(); + + writer.WriteInt64(PgTimestamp.Encode(value, _dateTimeInfinityConversions)); + } +} + +sealed class LegacyDateTimeOffsetConverter : PgBufferedConverter +{ + readonly bool _dateTimeInfinityConversions; + + public LegacyDateTimeOffsetConverter(bool dateTimeInfinityConversions) + => _dateTimeInfinityConversions = dateTimeInfinityConversions; + + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long)); + return format is DataFormat.Binary; + } + + protected override DateTimeOffset ReadCore(PgReader reader) + { + var dateTime = PgTimestamp.Decode(reader.ReadInt64(), DateTimeKind.Utc, _dateTimeInfinityConversions); + return !_dateTimeInfinityConversions || dateTime != DateTime.MaxValue && dateTime != DateTime.MinValue + ? 
dateTime.ToLocalTime() + : dateTime; + } + + protected override void WriteCore(PgWriter writer, DateTimeOffset value) + => writer.WriteInt64(PgTimestamp.Encode(value.UtcDateTime, _dateTimeInfinityConversions)); +} diff --git a/src/Npgsql/Internal/Converters/Temporal/PgTimestamp.cs b/src/Npgsql/Internal/Converters/Temporal/PgTimestamp.cs new file mode 100644 index 0000000000..6a44ccbdc9 --- /dev/null +++ b/src/Npgsql/Internal/Converters/Temporal/PgTimestamp.cs @@ -0,0 +1,43 @@ +using System; + +namespace Npgsql.Internal.Converters; + +static class PgTimestamp +{ + const long PostgresTimestampOffsetTicks = 630822816000000000L; + + internal static long Encode(DateTime value, bool dateTimeInfinityConversions) + { + if (dateTimeInfinityConversions) + { + if (value.Ticks == DateTime.MaxValue.Ticks) + return long.MaxValue; + if (value.Ticks == DateTime.MinValue.Ticks) + return long.MinValue; + } + // Rounding here would cause problems because we would round up DateTime.MaxValue + // which would make it impossible to retrieve it back from the database, so we just drop the additional precision + return (value.Ticks - PostgresTimestampOffsetTicks) / 10; + } + + internal static DateTime Decode(long value, DateTimeKind kind, bool dateTimeInfinityConversions) + { + try + { + return value switch + { + long.MaxValue => dateTimeInfinityConversions + ? DateTime.MaxValue + : throw new InvalidCastException("Cannot read infinity value since DisableDateTimeInfinityConversions is true."), + long.MinValue => dateTimeInfinityConversions + ? 
DateTime.MinValue + : throw new InvalidCastException("Cannot read infinity value since DisableDateTimeInfinityConversions is true."), + _ => new(value * 10 + PostgresTimestampOffsetTicks, kind) + }; + } + catch (ArgumentOutOfRangeException e) + { + throw new InvalidCastException("Out of range of DateTime (year must be between 1 and 9999).", e); + } + } +} diff --git a/src/Npgsql/Internal/Converters/Temporal/TimeConverters.cs b/src/Npgsql/Internal/Converters/Temporal/TimeConverters.cs new file mode 100644 index 0000000000..d2fbf60fda --- /dev/null +++ b/src/Npgsql/Internal/Converters/Temporal/TimeConverters.cs @@ -0,0 +1,52 @@ +using System; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class TimeSpanTimeConverter : PgBufferedConverter +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long)); + return format is DataFormat.Binary; + } + protected override TimeSpan ReadCore(PgReader reader) => new(reader.ReadInt64() * 10); + protected override void WriteCore(PgWriter writer, TimeSpan value) => writer.WriteInt64(value.Ticks / 10); +} + +#if NET6_0_OR_GREATER +sealed class TimeOnlyTimeConverter : PgBufferedConverter +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long)); + return format is DataFormat.Binary; + } + protected override TimeOnly ReadCore(PgReader reader) => new(reader.ReadInt64() * 10); + protected override void WriteCore(PgWriter writer, TimeOnly value) => writer.WriteInt64(value.Ticks / 10); +} +#endif + +sealed class DateTimeOffsetTimeTzConverter : PgBufferedConverter +{ + // Binary Format: int64 expressing microseconds, int32 expressing timezone in seconds, negative + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + 
bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long) + sizeof(int)); + return format is DataFormat.Binary; + } + + protected override DateTimeOffset ReadCore(PgReader reader) + { + // Adjust from 1 microsecond to 100ns. Time zone (in seconds) is inverted. + var ticks = reader.ReadInt64() * 10; + var offset = new TimeSpan(0, 0, -reader.ReadInt32()); + return new DateTimeOffset(ticks + TimeSpan.TicksPerDay, offset); + } + + protected override void WriteCore(PgWriter writer, DateTimeOffset value) + { + writer.WriteInt64(value.Ticks / 10); + writer.WriteInt32(-(int)(value.Offset.Ticks / TimeSpan.TicksPerSecond)); + } +} diff --git a/src/Npgsql/Internal/Converters/VersionPrefixedTextConverter.cs b/src/Npgsql/Internal/Converters/VersionPrefixedTextConverter.cs new file mode 100644 index 0000000000..d4776550fd --- /dev/null +++ b/src/Npgsql/Internal/Converters/VersionPrefixedTextConverter.cs @@ -0,0 +1,107 @@ +using System; +using System.Diagnostics.CodeAnalysis; +using System.Threading; +using System.Threading.Tasks; + +namespace Npgsql.Internal.Converters; + +sealed class VersionPrefixedTextConverter : PgStreamingConverter, IResumableRead +{ + readonly byte _versionPrefix; + readonly PgConverter _textConverter; + BufferRequirements _innerRequirements; + + public VersionPrefixedTextConverter(byte versionPrefix, PgConverter textConverter) + : base(textConverter.DbNullPredicateKind is DbNullPredicate.Custom) + { + _versionPrefix = versionPrefix; + _textConverter = textConverter; + } + + protected override bool IsDbNullValue(T? 
value) => _textConverter.IsDbNull(value); + + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + => VersionPrefixedTextConverter.CanConvert(_textConverter, format, out _innerRequirements, out bufferRequirements); + + public override T Read(PgReader reader) + => Read(async: false, reader, CancellationToken.None).Result; + + public override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + => Read(async: true, reader, cancellationToken); + + public override Size GetSize(SizeContext context, [DisallowNull]T value, ref object? writeState) + => _textConverter.GetSize(context, value, ref writeState).Combine(context.Format is DataFormat.Binary ? sizeof(byte) : 0); + + public override void Write(PgWriter writer, [DisallowNull]T value) + => Write(async: false, writer, value, CancellationToken.None).GetAwaiter().GetResult(); + + public override ValueTask WriteAsync(PgWriter writer, [DisallowNull]T value, CancellationToken cancellationToken = default) + => Write(async: true, writer, value, cancellationToken); + + async ValueTask Read(bool async, PgReader reader, CancellationToken cancellationToken) + { + await VersionPrefixedTextConverter.ReadVersion(async, _versionPrefix, reader, _innerRequirements.Read, cancellationToken).ConfigureAwait(false); + return async ? 
await _textConverter.ReadAsync(reader, cancellationToken).ConfigureAwait(false) : _textConverter.Read(reader); + } + + async ValueTask Write(bool async, PgWriter writer, [DisallowNull]T value, CancellationToken cancellationToken) + { + await VersionPrefixedTextConverter.WriteVersion(async, _versionPrefix, writer, cancellationToken).ConfigureAwait(false); + if (async) + await _textConverter.WriteAsync(writer, value, cancellationToken).ConfigureAwait(false); + else + _textConverter.Write(writer, value); + } + + bool IResumableRead.Supported => _textConverter is IResumableRead { Supported: true }; +} + +static class VersionPrefixedTextConverter +{ + public static async ValueTask WriteVersion(bool async, byte version, PgWriter writer, CancellationToken cancellationToken) + { + if (writer.Current.Format is not DataFormat.Binary) + return; + + if (writer.ShouldFlush(sizeof(byte))) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + writer.WriteByte(version); + } + + public static async ValueTask ReadVersion(bool async, byte expectedVersion, PgReader reader, Size textConverterReadRequirement, CancellationToken cancellationToken) + { + if (reader.Current.Format is not DataFormat.Binary) + return; + + if (!reader.IsResumed) + { + if (reader.ShouldBuffer(sizeof(byte))) + await reader.Buffer(async, sizeof(byte), cancellationToken).ConfigureAwait(false); + + var actualVersion = reader.ReadByte(); + if (actualVersion != expectedVersion) + throw new InvalidCastException($"Unknown wire format version: {actualVersion}"); + } + + // No need for a nested read, all text converters will read CurrentRemaining bytes. + // We only need to buffer data if we're binary, otherwise the caller would have had to do so + // as we directly expose the underlying text converter requirements for the text data format. 
+ await reader.Buffer(async, textConverterReadRequirement, cancellationToken).ConfigureAwait(false); + } + + public static bool CanConvert(PgConverter textConverter, DataFormat format, out BufferRequirements textConverterRequirements, out BufferRequirements bufferRequirements) + { + var success = textConverter.CanConvert(format, out textConverterRequirements); + if (!success) + { + bufferRequirements = default; + return false; + } + if (textConverter.CanConvert(format is DataFormat.Binary ? DataFormat.Text : DataFormat.Binary, out var otherRequirements) && otherRequirements != textConverterRequirements) + throw new InvalidOperationException("Text converter should have identical requirements for text and binary formats."); + + bufferRequirements = format is DataFormat.Binary ? textConverterRequirements.Combine(sizeof(byte)) : textConverterRequirements; + + return success; + } +} diff --git a/src/Npgsql/Internal/DataFormat.cs b/src/Npgsql/Internal/DataFormat.cs new file mode 100644 index 0000000000..c9950ea417 --- /dev/null +++ b/src/Npgsql/Internal/DataFormat.cs @@ -0,0 +1,29 @@ +using System; +using System.Diagnostics; + +namespace Npgsql.Internal; + +public enum DataFormat : byte +{ + Binary, + Text +} + +static class DataFormatUtils +{ + public static DataFormat Create(short formatCode) + => formatCode switch + { + 0 => DataFormat.Text, + 1 => DataFormat.Binary, + _ => throw new ArgumentOutOfRangeException(nameof(formatCode), formatCode, "Unknown postgres format code, please file a bug,") + }; + + public static short ToFormatCode(this DataFormat dataFormat) + => dataFormat switch + { + DataFormat.Text => 0, + DataFormat.Binary => 1, + _ => throw new UnreachableException() + }; +} diff --git a/src/Npgsql/Internal/DynamicTypeInfoResolver.cs b/src/Npgsql/Internal/DynamicTypeInfoResolver.cs new file mode 100644 index 0000000000..22ffcd2248 --- /dev/null +++ b/src/Npgsql/Internal/DynamicTypeInfoResolver.cs @@ -0,0 +1,132 @@ +using System; +using 
System.Diagnostics.CodeAnalysis; +using System.Reflection; +using Npgsql.Internal.Postgres; +using Npgsql.PostgresTypes; + +namespace Npgsql.Internal; + +[RequiresUnreferencedCode("A dynamic type info resolver may perform reflection on types that were trimmed if not referenced directly.")] +[RequiresDynamicCode("A dynamic type info resolver may need to construct a generic converter for a statically unknown type.")] +public abstract class DynamicTypeInfoResolver : IPgTypeInfoResolver +{ + public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + { + if (dataTypeName is null) + return null; + + var context = GetMappings(type, dataTypeName.GetValueOrDefault(), options); + return context?.Find(type, dataTypeName.GetValueOrDefault(), options); + } + + protected DynamicMappingCollection CreateCollection(TypeInfoMappingCollection? baseCollection = null) => new(baseCollection); + + protected static bool IsTypeOrNullableOfType(Type type, Func predicate, out Type matchedType) + { + matchedType = Nullable.GetUnderlyingType(type) ?? type; + return predicate(matchedType); + } + + protected static bool IsArrayLikeType(Type type, [NotNullWhen(true)]out Type? elementType) => TypeInfoMappingCollection.IsArrayLikeType(type, out elementType); + + protected static bool IsArrayDataTypeName(DataTypeName dataTypeName, PgSerializerOptions options, out DataTypeName elementDataTypeName) + { + if (options.DatabaseInfo.GetPostgresType(dataTypeName) is PostgresArrayType arrayType) + { + elementDataTypeName = arrayType.Element.DataTypeName; + return true; + } + + elementDataTypeName = default; + return false; + } + + protected abstract DynamicMappingCollection? GetMappings(Type? type, DataTypeName dataTypeName, PgSerializerOptions options); + + protected class DynamicMappingCollection + { + TypeInfoMappingCollection? 
_mappings; + + static readonly MethodInfo AddTypeMethodInfo = typeof(TypeInfoMappingCollection).GetMethod(nameof(TypeInfoMappingCollection.AddType), + new[] { typeof(string), typeof(TypeInfoFactory), typeof(Func) }) ?? throw new NullReferenceException(); + + static readonly MethodInfo AddArrayTypeMethodInfo = typeof(TypeInfoMappingCollection) + .GetMethod(nameof(TypeInfoMappingCollection.AddArrayType), new[] { typeof(string) }) ?? throw new NullReferenceException(); + + static readonly MethodInfo AddStructTypeMethodInfo = typeof(TypeInfoMappingCollection).GetMethod(nameof(TypeInfoMappingCollection.AddStructType), + new[] { typeof(string), typeof(TypeInfoFactory), typeof(Func) }) ?? throw new NullReferenceException(); + + static readonly MethodInfo AddStructArrayTypeMethodInfo = typeof(TypeInfoMappingCollection) + .GetMethod(nameof(TypeInfoMappingCollection.AddStructArrayType), new[] { typeof(string) }) ?? throw new NullReferenceException(); + + static readonly MethodInfo AddResolverTypeMethodInfo = typeof(TypeInfoMappingCollection).GetMethod( + nameof(TypeInfoMappingCollection.AddResolverType), + new[] { typeof(string), typeof(TypeInfoFactory), typeof(Func) }) ?? throw new NullReferenceException(); + + static readonly MethodInfo AddResolverArrayTypeMethodInfo = typeof(TypeInfoMappingCollection) + .GetMethod(nameof(TypeInfoMappingCollection.AddResolverArrayType), new[] { typeof(string) }) ?? throw new NullReferenceException(); + + static readonly MethodInfo AddResolverStructTypeMethodInfo = typeof(TypeInfoMappingCollection).GetMethod( + nameof(TypeInfoMappingCollection.AddResolverStructType), + new[] { typeof(string), typeof(TypeInfoFactory), typeof(Func) }) ?? throw new NullReferenceException(); + + static readonly MethodInfo AddResolverStructArrayTypeMethodInfo = typeof(TypeInfoMappingCollection) + .GetMethod(nameof(TypeInfoMappingCollection.AddResolverStructArrayType), new[] { typeof(string) }) ?? 
throw new NullReferenceException(); + + internal DynamicMappingCollection(TypeInfoMappingCollection? baseCollection = null) + { + if (baseCollection is not null) + _mappings = new(baseCollection); + } + + public DynamicMappingCollection AddMapping(Type type, string dataTypeName, TypeInfoFactory factory, Func? configureMapping = null) + { + if (type.IsValueType && Nullable.GetUnderlyingType(type) is not null) + throw new NotSupportedException("Mapping nullable types is not supported, map its underlying type instead to get both."); + + (type.IsValueType ? AddStructTypeMethodInfo : AddTypeMethodInfo) + .MakeGenericMethod(type).Invoke(_mappings ??= new(), new object?[] + { + dataTypeName, + factory, + configureMapping + }); + return this; + } + + public DynamicMappingCollection AddArrayMapping(Type elementType, string dataTypeName) + { + (elementType.IsValueType ? AddStructArrayTypeMethodInfo : AddArrayTypeMethodInfo) + .MakeGenericMethod(elementType).Invoke(_mappings ??= new(), new object?[] { dataTypeName }); + return this; + } + + public DynamicMappingCollection AddResolverMapping(Type type, string dataTypeName, TypeInfoFactory factory, Func? configureMapping = null) + { + if (type.IsValueType && Nullable.GetUnderlyingType(type) is not null) + throw new NotSupportedException("Mapping nullable types is not supported"); + + (type.IsValueType ? AddResolverStructTypeMethodInfo : AddResolverTypeMethodInfo) + .MakeGenericMethod(type).Invoke(_mappings ??= new(), new object?[] + { + dataTypeName, + factory, + configureMapping + }); + return this; + } + + public DynamicMappingCollection AddResolverArrayMapping(Type elementType, string dataTypeName) + { + (elementType.IsValueType ? AddResolverStructArrayTypeMethodInfo : AddResolverArrayTypeMethodInfo) + .MakeGenericMethod(elementType).Invoke(_mappings ??= new(), new object?[] { dataTypeName }); + return this; + } + + internal PgTypeInfo? Find(Type? 
type, DataTypeName dataTypeName, PgSerializerOptions options) + => _mappings?.Find(type, dataTypeName, options); + + public TypeInfoMappingCollection ToTypeInfoMappingCollection() + => new(_mappings?.Items ?? Array.Empty()); + } +} diff --git a/src/Npgsql/Internal/IPgTypeInfoResolver.cs b/src/Npgsql/Internal/IPgTypeInfoResolver.cs new file mode 100644 index 0000000000..62955446eb --- /dev/null +++ b/src/Npgsql/Internal/IPgTypeInfoResolver.cs @@ -0,0 +1,19 @@ +using System; +using Npgsql.Internal.Postgres; + +namespace Npgsql.Internal; + +/// +/// An Npgsql resolver for type info. Used by Npgsql to read and write values to PostgreSQL. +/// +public interface IPgTypeInfoResolver +{ + /// + /// Resolve a type info for a given type and data type name, at least one value will be non-null. + /// + /// The clr type being requested. + /// The postgres type being requested. + /// Used for configuration state and Npgsql type info or PostgreSQL type catalog lookups. + /// A result, or null if there was no match. + PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? 
dataTypeName, PgSerializerOptions options); +} diff --git a/src/Npgsql/Internal/NpgsqlConnector.Auth.cs b/src/Npgsql/Internal/NpgsqlConnector.Auth.cs index d5ea9af5e1..25847da65e 100644 --- a/src/Npgsql/Internal/NpgsqlConnector.Auth.cs +++ b/src/Npgsql/Internal/NpgsqlConnector.Auth.cs @@ -121,7 +121,7 @@ async Task AuthenticateSASL(List mechanisms, string username, bool async // Assumption: the write buffer is big enough to contain all our outgoing messages var clientNonce = GetNonce(); - await WriteSASLInitialResponse(mechanism, PGUtil.UTF8Encoding.GetBytes($"{cbindFlag},,n=*,r={clientNonce}"), async, cancellationToken); + await WriteSASLInitialResponse(mechanism, NpgsqlWriteBuffer.UTF8Encoding.GetBytes($"{cbindFlag},,n=*,r={clientNonce}"), async, cancellationToken); await Flush(async, cancellationToken); var saslContinueMsg = Expect(await ReadMessage(async), this); @@ -280,8 +280,8 @@ async Task AuthenticateMD5(string username, byte[] salt, bool async, Cancellatio using (var md5 = MD5.Create()) { // First phase - var passwordBytes = PGUtil.UTF8Encoding.GetBytes(passwd); - var usernameBytes = PGUtil.UTF8Encoding.GetBytes(username); + var passwordBytes = NpgsqlWriteBuffer.UTF8Encoding.GetBytes(passwd); + var usernameBytes = NpgsqlWriteBuffer.UTF8Encoding.GetBytes(username); var cryptBuf = new byte[passwordBytes.Length + usernameBytes.Length]; passwordBytes.CopyTo(cryptBuf, 0); usernameBytes.CopyTo(cryptBuf, passwordBytes.Length); @@ -293,7 +293,7 @@ async Task AuthenticateMD5(string username, byte[] salt, bool async, Cancellatio var prehash = sb.ToString(); - var prehashbytes = PGUtil.UTF8Encoding.GetBytes(prehash); + var prehashbytes = NpgsqlWriteBuffer.UTF8Encoding.GetBytes(prehash); cryptBuf = new byte[prehashbytes.Length + 4]; Array.Copy(salt, 0, cryptBuf, prehashbytes.Length, 4); diff --git a/src/Npgsql/Internal/NpgsqlConnector.FrontendMessages.cs b/src/Npgsql/Internal/NpgsqlConnector.FrontendMessages.cs index c2c6c23976..91a492ae5b 100644 --- 
a/src/Npgsql/Internal/NpgsqlConnector.FrontendMessages.cs +++ b/src/Npgsql/Internal/NpgsqlConnector.FrontendMessages.cs @@ -4,8 +4,6 @@ using System.Linq; using System.Threading; using System.Threading.Tasks; -using Npgsql.Util; -// ReSharper disable VariableHidesOuterVariable namespace Npgsql.Internal; @@ -141,7 +139,7 @@ internal async Task WriteParse(string sql, string statementName, List 0 ? size.Value : 0; + formatCodesSum += format.ToFormatCode(); } var formatCodeListLength = formatCodesSum == 0 ? 0 : formatCodesSum == parameters.Count ? 1 : parameters.Count; @@ -201,30 +199,38 @@ internal async Task WriteBind( // 0 length implicitly means all-text, 1 means all-binary, >1 means mix-and-match if (formatCodeListLength == 1) { - if (WriteBuffer.WriteSpaceLeft < 2) + if (WriteBuffer.WriteSpaceLeft < sizeof(short)) await Flush(async, cancellationToken).ConfigureAwait(false); - WriteBuffer.WriteInt16((short)FormatCode.Binary); + WriteBuffer.WriteInt16(DataFormat.Binary.ToFormatCode()); } else if (formatCodeListLength > 1) { for (var paramIndex = 0; paramIndex < parameters.Count; paramIndex++) { - if (WriteBuffer.WriteSpaceLeft < 2) + if (WriteBuffer.WriteSpaceLeft < sizeof(short)) await Flush(async, cancellationToken).ConfigureAwait(false); - WriteBuffer.WriteInt16((short)parameters[paramIndex].FormatCode); + WriteBuffer.WriteInt16(parameters[paramIndex].Format.ToFormatCode()); } } - if (WriteBuffer.WriteSpaceLeft < 2) + if (WriteBuffer.WriteSpaceLeft < sizeof(ushort)) await Flush(async, cancellationToken).ConfigureAwait(false); WriteBuffer.WriteUInt16((ushort)parameters.Count); - for (var paramIndex = 0; paramIndex < parameters.Count; paramIndex++) + var writer = WriteBuffer.GetWriter(DatabaseInfo, async ? 
FlushMode.NonBlocking : FlushMode.Blocking); + try { - var param = parameters[paramIndex]; - param.LengthCache?.Rewind(); - await param.WriteWithLength(WriteBuffer, async, cancellationToken).ConfigureAwait(false); + for (var paramIndex = 0; paramIndex < parameters.Count; paramIndex++) + { + var param = parameters[paramIndex]; + await param.Write(async, writer, cancellationToken).ConfigureAwait(false); + } + } + catch(Exception ex) + { + Break(ex); + throw; } if (unknownResultTypeList != null) @@ -375,8 +381,8 @@ internal void WriteStartup(Dictionary parameters) sizeof(byte); // Trailing zero byte foreach (var kvp in parameters) - len += PGUtil.UTF8Encoding.GetByteCount(kvp.Key) + 1 + - PGUtil.UTF8Encoding.GetByteCount(kvp.Value) + 1; + len += NpgsqlWriteBuffer.UTF8Encoding.GetByteCount(kvp.Key) + 1 + + NpgsqlWriteBuffer.UTF8Encoding.GetByteCount(kvp.Value) + 1; // Should really never happen, just in case if (len > WriteBuffer.Size) @@ -422,7 +428,7 @@ internal async Task WriteSASLInitialResponse(string mechanism, byte[] initialRes { var len = sizeof(byte) + // Message code sizeof(int) + // Length - PGUtil.UTF8Encoding.GetByteCount(mechanism) + sizeof(byte) + // Mechanism plus null terminator + NpgsqlWriteBuffer.UTF8Encoding.GetByteCount(mechanism) + sizeof(byte) + // Mechanism plus null terminator sizeof(int) + // Initial response length (initialResponse?.Length ?? 
0); // Initial response payload diff --git a/src/Npgsql/Internal/NpgsqlConnector.cs b/src/Npgsql/Internal/NpgsqlConnector.cs index 4fb25fa761..d7e359b6af 100644 --- a/src/Npgsql/Internal/NpgsqlConnector.cs +++ b/src/Npgsql/Internal/NpgsqlConnector.cs @@ -18,12 +18,10 @@ using System.Threading.Channels; using System.Threading.Tasks; using Npgsql.BackendMessages; -using Npgsql.TypeMapping; using Npgsql.Util; using static Npgsql.Util.Statics; using System.Transactions; using Microsoft.Extensions.Logging; -using Npgsql.Internal.TypeMapping; using Npgsql.Properties; namespace Npgsql.Internal; @@ -115,13 +113,13 @@ internal string InferredUserName /// internal int Id => BackendProcessId; + internal PgSerializerOptions SerializerOptions { get; set; } = default!; + /// /// Information about PostgreSQL and PostgreSQL-like databases (e.g. type definitions, capabilities...). /// public NpgsqlDatabaseInfo DatabaseInfo { get; internal set; } = default!; - internal TypeMapper TypeMapper { get; set; } = default!; - /// /// The current transaction status for this connector. /// @@ -182,6 +180,9 @@ internal string InferredUserName /// volatile Exception? _breakReason; + // Used by replication to change our cancellation behaviour on ColumnStreams. + internal bool LongRunningConnection { get; set; } + /// /// /// Used by the pool to indicate that I/O is currently in progress on this connector, so that another write @@ -319,7 +320,7 @@ internal bool PostgresCancellationPerformed readonly ReadyForQueryMessage _readyForQueryMessage = new(); readonly ParameterDescriptionMessage _parameterDescriptionMessage = new(); readonly DataRowMessage _dataRowMessage = new(); - readonly RowDescriptionMessage _rowDescriptionMessage = new(); + readonly RowDescriptionMessage _rowDescriptionMessage = new(connectorOwned: true); // Since COPY is rarely used, allocate these lazily CopyInResponseMessage? 
_copyInResponseMessage; @@ -500,9 +501,9 @@ internal async Task Open(NpgsqlTimeout timeout, bool async, CancellationToken ca await DataSource.Bootstrap(this, timeout, forceReload: false, async, cancellationToken); - Debug.Assert(DataSource.TypeMapper is not null); + Debug.Assert(DataSource.SerializerOptions is not null); Debug.Assert(DataSource.DatabaseInfo is not null); - TypeMapper = DataSource.TypeMapper; + SerializerOptions = DataSource.SerializerOptions; DatabaseInfo = DataSource.DatabaseInfo; if (Settings.Pooling && !Settings.Multiplexing && !Settings.NoResetOnClose && DatabaseInfo.SupportsDiscard) @@ -770,8 +771,8 @@ async Task RawOpen(SslMode sslMode, NpgsqlTimeout timeout, bool async, Cancellat if (Settings.Encoding == "UTF8") { - TextEncoding = PGUtil.UTF8Encoding; - RelaxedTextEncoding = PGUtil.RelaxedUTF8Encoding; + TextEncoding = NpgsqlWriteBuffer.UTF8Encoding; + RelaxedTextEncoding = NpgsqlWriteBuffer.RelaxedUTF8Encoding; } else { @@ -1242,7 +1243,7 @@ async Task MultiplexingReadLoop() // TODO: the exception we have here is sometimes just the result of the write loop breaking // the connector, so it doesn't represent the actual root cause. - pendingCommand.ExecutionCompletion.SetException(_breakReason!); + pendingCommand.ExecutionCompletion.SetException(new NpgsqlException("A previous command on this connection caused an error requiring all pending commands on this connection to be aborted", _breakReason!)); } } catch (ChannelClosedException) @@ -1303,7 +1304,7 @@ internal ValueTask ReadMessage( return ReadMessageLong(async, dataRowLoadingMode, readingNotifications: false)!; } - PGUtil.ValidateBackendMessageCode(messageCode); + ValidateBackendMessageCode(messageCode); var len = ReadBuffer.ReadInt32() - 4; // Transmitted length includes itself if (len > ReadBuffer.ReadBytesLeft) { @@ -1337,7 +1338,8 @@ internal ValueTask ReadMessage( { // Prepended queries should never fail. // If they do, we're not even going to attempt to salvage the connector. 
- throw Break(e); + Break(e); + throw; } } @@ -1351,7 +1353,7 @@ internal ValueTask ReadMessage( { await ReadBuffer.Ensure(5, async, readingNotifications); var messageCode = (BackendMessageCode)ReadBuffer.ReadByte(); - PGUtil.ValidateBackendMessageCode(messageCode); + ValidateBackendMessageCode(messageCode); var len = ReadBuffer.ReadInt32() - 4; // Transmitted length includes itself if ((messageCode == BackendMessageCode.DataRow && @@ -1432,6 +1434,12 @@ internal ValueTask ReadMessage( } Debug.Assert(msg != null, "Message is null for code: " + messageCode); + + // Reset flushed bytes after any RFQ or in between potentially long running operations. + // Just in case we'll hit that 15 exbibyte limit of a signed long... + if (messageCode is BackendMessageCode.ReadyForQuery or BackendMessageCode.CopyData or BackendMessageCode.NotificationResponse) + ReadBuffer.ResetFlushedBytes(); + return msg; } } @@ -1464,7 +1472,7 @@ internal ValueTask ReadMessage( switch (code) { case BackendMessageCode.RowDescription: - return _rowDescriptionMessage.Load(buf, TypeMapper); + return _rowDescriptionMessage.Load(buf, SerializerOptions); case BackendMessageCode.DataRow: return _dataRowMessage.Load(len); case BackendMessageCode.CommandComplete: @@ -1891,15 +1899,44 @@ internal CancellationTokenRegistration StartCancellableOperation( /// PostgreSQL cancellation will be skipped and client-socket cancellation will occur immediately. 
/// [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal CancellationTokenRegistration StartNestedCancellableOperation( + internal NestedCancellableScope StartNestedCancellableOperation( CancellationToken cancellationToken = default, bool attemptPgCancellation = true) { + var currentUserCancellationToken = UserCancellationToken; UserCancellationToken = cancellationToken; + var currentAttemptPostgresCancellation = AttemptPostgresCancellation; AttemptPostgresCancellation = attemptPgCancellation; - return _cancellationTokenRegistration = - cancellationToken.Register(static c => ((NpgsqlConnector)c!).PerformUserCancellation(), this); + var registration = cancellationToken.Register(static c => ((NpgsqlConnector)c!).PerformUserCancellation(), this); + + return new(this, registration, currentUserCancellationToken, currentAttemptPostgresCancellation); + } + + internal readonly struct NestedCancellableScope : IDisposable + { + readonly NpgsqlConnector _connector; + readonly CancellationTokenRegistration _registration; + readonly CancellationToken _previousCancellationToken; + readonly bool _previousAttemptPostgresCancellation; + + public NestedCancellableScope(NpgsqlConnector connector, CancellationTokenRegistration registration, CancellationToken previousCancellationToken, bool previousAttemptPostgresCancellation) + { + _connector = connector; + _registration = registration; + _previousCancellationToken = previousCancellationToken; + _previousAttemptPostgresCancellation = previousAttemptPostgresCancellation; + } + + public void Dispose() + { + if (_connector is null) + return; + + _connector.UserCancellationToken = _previousCancellationToken; + _connector.AttemptPostgresCancellation = _previousAttemptPostgresCancellation; + _registration.Dispose(); + } } #endregion Cancel @@ -2318,6 +2355,7 @@ internal async Task Reset(bool async) [MethodImpl(MethodImplOptions.AggressiveInlining)] void ResetReadBuffer() { + LongRunningConnection = false; if (_origReadBuffer != 
null) { Debug.Assert(_origReadBuffer.ReadBytesLeft == 0); @@ -2615,7 +2653,8 @@ internal async Task Wait(bool async, int timeout, CancellationToken cancel { // We're somewhere in the middle of a reading keepalive messages // Breaking the connection, as we've lost protocol sync - throw Break(e); + Break(e); + throw; } if (msg == null) diff --git a/src/Npgsql/Internal/NpgsqlDatabaseInfo.cs b/src/Npgsql/Internal/NpgsqlDatabaseInfo.cs index 09417eef21..f3c8ea52a3 100644 --- a/src/Npgsql/Internal/NpgsqlDatabaseInfo.cs +++ b/src/Npgsql/Internal/NpgsqlDatabaseInfo.cs @@ -1,9 +1,9 @@ using System; -using System.Collections.Concurrent; using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Threading.Tasks; +using Npgsql.Internal.Postgres; using Npgsql.PostgresTypes; using Npgsql.Util; @@ -17,8 +17,7 @@ public abstract class NpgsqlDatabaseInfo { #region Fields - static volatile INpgsqlDatabaseInfoFactory[] Factories = new INpgsqlDatabaseInfoFactory[] - { + static volatile INpgsqlDatabaseInfoFactory[] Factories = { new PostgresMinimalDatabaseInfoFactory(), new PostgresDatabaseInfoFactory() }; @@ -138,7 +137,7 @@ public abstract class NpgsqlDatabaseInfo internal Dictionary ByOID { get; } = new(); /// - /// Indexes backend types by their PostgreSQL name, including namespace (e.g. pg_catalog.int4). + /// Indexes backend types by their PostgreSQL internal name, including namespace (e.g. pg_catalog.int4). /// Only used for enums and composites. /// internal Dictionary ByFullName { get; } = new(); @@ -179,10 +178,22 @@ private protected NpgsqlDatabaseInfo(string host, int port, string databaseName, Version = ParseServerVersion(serverVersion); } - public PostgresType GetPostgresTypeByName(string pgName) + internal PostgresType GetPostgresType(Oid oid) => GetPostgresType(oid.Value); + + public PostgresType GetPostgresType(uint oid) + => ByOID.TryGetValue(oid, out var pgType) + ? 
pgType + : throw new ArgumentException($"A PostgreSQL type with the oid '{oid}' was not found in the current database info"); + + internal PostgresType GetPostgresType(DataTypeName dataTypeName) + => ByFullName.TryGetValue(dataTypeName.Value, out var value) + ? value + : throw new ArgumentException($"A PostgreSQL type with the name '{dataTypeName}' was not found in the current database info"); + + public PostgresType GetPostgresType(string pgName) => TryGetPostgresTypeByName(pgName, out var pgType) ? pgType - : throw new ArgumentException($"A PostgreSQL type with the name '{pgName}' was not found in the database"); + : throw new ArgumentException($"A PostgreSQL type with the name '{pgName}' was not found in the current database info"); public bool TryGetPostgresTypeByName(string pgName, [NotNullWhen(true)] out PostgresType? pgType) { @@ -217,10 +228,10 @@ internal void ProcessTypes() foreach (var type in GetTypes()) { ByOID[type.OID] = type; - ByFullName[type.FullName] = type; + ByFullName[type.DataTypeName.Value] = type; // If more than one type exists with the same partial name, we place a null value. // This allows us to detect this case later and force the user to use full names only. - ByName[type.Name] = ByName.ContainsKey(type.Name) + ByName[type.InternalName] = ByName.ContainsKey(type.InternalName) ? null : type; @@ -326,4 +337,24 @@ internal static void ResetFactories() }; #endregion Factory management -} \ No newline at end of file + + internal Oid GetOid(PgTypeId pgTypeId, bool validate = false) + => pgTypeId.IsOid + ? validate ? GetPostgresType(pgTypeId.Oid).OID : pgTypeId.Oid + : GetPostgresType(pgTypeId.DataTypeName).OID; + + internal DataTypeName GetDataTypeName(PgTypeId pgTypeId, bool validate = false) + => pgTypeId.IsDataTypeName + ? validate ? 
GetPostgresType(pgTypeId.DataTypeName).DataTypeName : pgTypeId.DataTypeName + : GetPostgresType(pgTypeId.Oid).DataTypeName; + + internal PostgresType GetPostgresType(PgTypeId pgTypeId) + => pgTypeId.IsOid + ? GetPostgresType(pgTypeId.Oid.Value) + : GetPostgresType(pgTypeId.DataTypeName.Value); + + internal PostgresType? FindPostgresType(PgTypeId pgTypeId) + => pgTypeId.IsOid + ? ByOID.TryGetValue(pgTypeId.Oid.Value, out var pgType) ? pgType : null + : TryGetPostgresTypeByName(pgTypeId.DataTypeName.Value, out pgType) ? pgType : null; +} diff --git a/src/Npgsql/Internal/NpgsqlReadBuffer.Stream.cs b/src/Npgsql/Internal/NpgsqlReadBuffer.Stream.cs index cd38bcad0f..e99b77fa1b 100644 --- a/src/Npgsql/Internal/NpgsqlReadBuffer.Stream.cs +++ b/src/Npgsql/Internal/NpgsqlReadBuffer.Stream.cs @@ -6,33 +6,44 @@ namespace Npgsql.Internal; -public sealed partial class NpgsqlReadBuffer +sealed partial class NpgsqlReadBuffer { internal sealed class ColumnStream : Stream +#if NETSTANDARD2_0 + , IAsyncDisposable +#endif { readonly NpgsqlConnector _connector; readonly NpgsqlReadBuffer _buf; - int _start, _len, _read; + long _startPos; + int _start; + int _read; bool _canSeek; - readonly bool _startCancellableOperations; + bool _commandScoped; + /// Does not throw ODE. 
+ internal int CurrentLength { get; private set; } internal bool IsDisposed { get; private set; } - internal ColumnStream(NpgsqlConnector connector, bool startCancellableOperations = true) + internal ColumnStream(NpgsqlConnector connector) { _connector = connector; _buf = connector.ReadBuffer; - _startCancellableOperations = startCancellableOperations; IsDisposed = true; } - internal void Init(int len, bool canSeek) + internal void Init(int len, bool canSeek, bool commandScoped) { Debug.Assert(!canSeek || _buf.ReadBytesLeft >= len, "Seekable stream constructed but not all data is in buffer (sequential)"); - _start = _buf.ReadPosition; - _len = len; - _read = 0; + _startPos = _buf.CumulativeReadPosition; + _canSeek = canSeek; + _start = canSeek ? _buf.ReadPosition : 0; + + CurrentLength = len; + _read = 0; + + _commandScoped = commandScoped; IsDisposed = false; } @@ -47,7 +58,7 @@ public override long Length get { CheckDisposed(); - return _len; + return CurrentLength; } } @@ -102,11 +113,11 @@ public override long Seek(long offset, SeekOrigin origin) } case SeekOrigin.End: { - var tempPosition = unchecked(_start + _len + (int)offset); - if (unchecked(_start + _len + offset) < _start || tempPosition < _start) + var tempPosition = unchecked(_start + CurrentLength + (int)offset); + if (unchecked(_start + CurrentLength + offset) < _start || tempPosition < _start) throw new IOException(seekBeforeBegin); _buf.ReadPosition = tempPosition; - _read = _len + (int)offset; + _read = CurrentLength + (int)offset; return _read; } default: @@ -140,9 +151,7 @@ public override int Read(byte[] buffer, int offset, int count) public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { ValidateArguments(buffer, offset, count); - - using (NoSynchronizationContextScope.Enter()) - return ReadAsync(new Memory(buffer, offset, count), cancellationToken).AsTask(); + return ReadAsync(new Memory(buffer, offset, count), cancellationToken).AsTask(); 
} #if NETSTANDARD2_0 @@ -153,12 +162,12 @@ public override int Read(Span span) { CheckDisposed(); - var count = Math.Min(span.Length, _len - _read); + var count = Math.Min(span.Length, CurrentLength - _read); if (count == 0) return 0; - var read = _buf.Read(span.Slice(0, count)); + var read = _buf.Read(_commandScoped, span.Slice(0, count)); _read += read; return read; @@ -172,20 +181,16 @@ public override ValueTask ReadAsync(Memory buffer, CancellationToken { CheckDisposed(); - var count = Math.Min(buffer.Length, _len - _read); - - if (count == 0) - return new ValueTask(0); - - using (NoSynchronizationContextScope.Enter()) - return ReadLong(this, buffer.Slice(0, count), cancellationToken); + var count = Math.Min(buffer.Length, CurrentLength - _read); + return count == 0 ? new ValueTask(0) : ReadLong(this, buffer.Slice(0, count), cancellationToken); static async ValueTask ReadLong(ColumnStream stream, Memory buffer, CancellationToken cancellationToken = default) { - using var registration = stream._startCancellableOperations + using var registration = cancellationToken.CanBeCanceled ? 
stream._connector.StartNestedCancellableOperation(cancellationToken, attemptPgCancellation: false) : default; - var read = await stream._buf.ReadAsync(buffer, cancellationToken); + + var read = await stream._buf.ReadAsync(stream._commandScoped, buffer, cancellationToken).ConfigureAwait(false); stream._read += read; return read; } @@ -208,24 +213,21 @@ public ValueTask DisposeAsync() #else public override ValueTask DisposeAsync() #endif - { - using (NoSynchronizationContextScope.Enter()) - return DisposeAsync(disposing: true, async: true); - } + => DisposeAsync(disposing: true, async: true); async ValueTask DisposeAsync(bool disposing, bool async) { if (IsDisposed || !disposing) return; - var leftToSkip = _len - _read; - if (leftToSkip > 0) + if (!_connector.IsBroken) { - if (async) - await _buf.Skip(leftToSkip, async); - else - _buf.Skip(leftToSkip, async).GetAwaiter().GetResult(); + var pos = _buf.CumulativeReadPosition - _startPos; + var remaining = checked((int)(CurrentLength - pos)); + if (remaining > 0) + await _buf.Skip(remaining, async).ConfigureAwait(false); } + IsDisposed = true; } } diff --git a/src/Npgsql/Internal/NpgsqlReadBuffer.cs b/src/Npgsql/Internal/NpgsqlReadBuffer.cs index f854f27476..cb28028815 100644 --- a/src/Npgsql/Internal/NpgsqlReadBuffer.cs +++ b/src/Npgsql/Internal/NpgsqlReadBuffer.cs @@ -2,7 +2,6 @@ using System.Buffers; using System.Buffers.Binary; using System.Diagnostics; -using System.Diagnostics.CodeAnalysis; using System.IO; using System.Net.Sockets; using System.Runtime.CompilerServices; @@ -12,15 +11,13 @@ using Npgsql.Util; using static System.Threading.Timeout; -#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member - namespace Npgsql.Internal; /// /// A buffer used by Npgsql to read data from the socket efficiently. /// Provides methods which decode different values types and tracks the current position. 
/// -public sealed partial class NpgsqlReadBuffer : IDisposable +sealed partial class NpgsqlReadBuffer : IDisposable { #region Fields and Properties @@ -74,13 +71,15 @@ internal TimeSpan Timeout internal int ReadPosition { get; set; } internal int ReadBytesLeft => FilledBytes - ReadPosition; + internal PgReader PgReader { get; } + + long _flushedBytes; // this will always fit at least one message. + internal long CumulativeReadPosition => unchecked(_flushedBytes + ReadPosition); internal readonly byte[] Buffer; internal int FilledBytes; - ColumnStream? _columnStream; - - PreparedTextReader? _preparedTextReader; + internal ReadOnlySpan Span => Buffer.AsSpan(ReadPosition, ReadBytesLeft); readonly bool _usePool; bool _disposed; @@ -120,20 +119,163 @@ internal NpgsqlReadBuffer( TextEncoding = textEncoding; RelaxedTextEncoding = relaxedTextEncoding; + PgReader = new PgReader(this); } #endregion #region I/O - internal void Ensure(int count) => Ensure(count, false).GetAwaiter().GetResult(); - public Task Ensure(int count, bool async) => Ensure(count, async, readingNotifications: false); public Task EnsureAsync(int count) => Ensure(count, async: true, readingNotifications: false); + // Can't share due to Span vs Memory difference (can't make a memory out of a span). + int ReadWithTimeout(Span buffer) + { + while (true) + { + try + { + var read = Underlying.Read(buffer); + _flushedBytes = unchecked(_flushedBytes + read); + NpgsqlEventSource.Log.BytesRead(read); + return read; + } + catch (Exception ex) + { + var connector = Connector; + switch (ex) + { + // Note that mono throws SocketException with the wrong error (see #1330) + case IOException e when (e.InnerException as SocketException)?.SocketErrorCode == + (Type.GetType("Mono.Runtime") == null ? SocketError.TimedOut : SocketError.WouldBlock): + { + var isStreamBroken = false; +#if NETSTANDARD2_0 + // SslStream on .NET Framework treats any IOException (including timeouts) as fatal and may + // return garbage if reused. 
To prevent this, we flow down and break the connection immediately. + // See #4305. + isStreamBroken = connector.IsSecure && ex is IOException; +#endif + + if (!isStreamBroken) + { + // If we should attempt PostgreSQL cancellation, do it the first time we get a timeout. + // TODO: As an optimization, we can still attempt to send a cancellation request, but after + // that immediately break the connection + if (connector.AttemptPostgresCancellation && + !connector.PostgresCancellationPerformed && + connector.PerformPostgresCancellation()) + { + // Note that if the cancellation timeout is negative, we flow down and break the + // connection immediately. + var cancellationTimeout = connector.Settings.CancellationTimeout; + if (cancellationTimeout >= 0) + { + if (cancellationTimeout > 0) + Timeout = TimeSpan.FromMilliseconds(cancellationTimeout); + + continue; + } + } + } + + // If we're here, the PostgreSQL cancellation either failed or skipped entirely. + // Break the connection, bubbling up the correct exception type (cancellation or timeout) + throw connector.Break(CreateCancelException(connector)); + } + default: + throw connector.Break(new NpgsqlException("Exception while reading from stream", ex)); + } + } + } + } + + async ValueTask ReadWithTimeoutAsync(Memory buffer, CancellationToken cancellationToken) + { + var finalCt = Timeout != TimeSpan.Zero + ? 
Cts.Start(cancellationToken) + : Cts.Reset(); + + while (true) + { + try + { + var read = await Underlying.ReadAsync(buffer, finalCt).ConfigureAwait(false); + _flushedBytes = unchecked(_flushedBytes + read); + Cts.Stop(); + NpgsqlEventSource.Log.BytesRead(read); + return read; + } + catch (Exception ex) + { + var connector = Connector; + Cts.Stop(); + switch (ex) + { + // Read timeout + case OperationCanceledException: + // Note that mono throws SocketException with the wrong error (see #1330) + case IOException e when (e.InnerException as SocketException)?.SocketErrorCode == + (Type.GetType("Mono.Runtime") == null ? SocketError.TimedOut : SocketError.WouldBlock): + { + Debug.Assert(ex is OperationCanceledException); + var isStreamBroken = false; +#if NETSTANDARD2_0 + // SslStream on .NET Framework treats any IOException (including timeouts) as fatal and may + // return garbage if reused. To prevent this, we flow down and break the connection immediately. + // See #4305. + isStreamBroken = connector.IsSecure && ex is IOException; +#endif + + if (!isStreamBroken) + { + // If we should attempt PostgreSQL cancellation, do it the first time we get a timeout. + // TODO: As an optimization, we can still attempt to send a cancellation request, but after + // that immediately break the connection + if (connector.AttemptPostgresCancellation && + !connector.PostgresCancellationPerformed && + connector.PerformPostgresCancellation()) + { + // Note that if the cancellation timeout is negative, we flow down and break the + // connection immediately. + var cancellationTimeout = connector.Settings.CancellationTimeout; + if (cancellationTimeout >= 0) + { + if (cancellationTimeout > 0) + Timeout = TimeSpan.FromMilliseconds(cancellationTimeout); + + finalCt = Cts.Start(cancellationToken); + continue; + } + } + } + + // If we're here, the PostgreSQL cancellation either failed or skipped entirely. 
+ // Break the connection, bubbling up the correct exception type (cancellation or timeout) + throw connector.Break(CreateCancelException(connector)); + } + default: + throw connector.Break(new NpgsqlException("Exception while reading from stream", ex)); + } + } + } + } + + static Exception CreateCancelException(NpgsqlConnector connector) + => !connector.UserCancellationRequested + ? NpgsqlTimeoutException() + : connector.PostgresCancellationPerformed + ? new OperationCanceledException("Query was cancelled", TimeoutException(), connector.UserCancellationToken) + : new OperationCanceledException("Query was cancelled", connector.UserCancellationToken); + + static Exception NpgsqlTimeoutException() => new NpgsqlException("Exception while reading from stream", TimeoutException()); + + static Exception TimeoutException() => new TimeoutException("Timeout during reading attempt"); + /// /// Ensures that bytes are available in the buffer, and if /// not, reads from the socket until enough is available. @@ -154,12 +296,13 @@ static async Task EnsureLong( if (buffer.ReadPosition == buffer.FilledBytes) { - buffer.Clear(); + buffer.ResetPosition(); } else if (count > buffer.Size - buffer.FilledBytes) { Array.Copy(buffer.Buffer, buffer.ReadPosition, buffer.Buffer, 0, buffer.ReadBytesLeft); buffer.FilledBytes = buffer.ReadBytesLeft; + buffer._flushedBytes = unchecked(buffer._flushedBytes + buffer.ReadPosition); buffer.ReadPosition = 0; } @@ -174,7 +317,7 @@ static async Task EnsureLong( { var toRead = buffer.Size - buffer.FilledBytes; var read = async - ? await buffer.Underlying.ReadAsync(buffer.Buffer.AsMemory(buffer.FilledBytes, toRead), finalCt) + ? 
await buffer.Underlying.ReadAsync(buffer.Buffer.AsMemory(buffer.FilledBytes, toRead), finalCt).ConfigureAwait(false) : buffer.Underlying.Read(buffer.Buffer, buffer.FilledBytes, toRead); if (read == 0) @@ -287,23 +430,23 @@ internal NpgsqlReadBuffer AllocateOversize(int count) if (_underlyingSocket != null) tempBuf.Timeout = Timeout; CopyTo(tempBuf); - Clear(); + ResetPosition(); return tempBuf; } /// /// Does not perform any I/O - assuming that the bytes to be skipped are in the memory buffer. /// - internal void Skip(long len) + internal void Skip(int len) { Debug.Assert(ReadBytesLeft >= len); - ReadPosition += (int)len; + ReadPosition += len; } /// /// Skip a given number of bytes. /// - public async Task Skip(long len, bool async) + public async Task Skip(int len, bool async) { Debug.Assert(len >= 0); @@ -312,15 +455,15 @@ public async Task Skip(long len, bool async) len -= ReadBytesLeft; while (len > Size) { - Clear(); - await Ensure(Size, async); + ResetPosition(); + await Ensure(Size, async).ConfigureAwait(false); len -= Size; } - Clear(); - await Ensure((int)len, async); + ResetPosition(); + await Ensure(len, async).ConfigureAwait(false); } - ReadPosition += (int)len; + ReadPosition += len; } #endregion @@ -428,19 +571,14 @@ public double ReadDouble(bool littleEndian) } [MethodImpl(MethodImplOptions.AggressiveInlining)] - T Read() + unsafe T Read() where T : unmanaged { - if (Unsafe.SizeOf() > ReadBytesLeft) - ThrowNotSpaceLeft(); - + Debug.Assert(sizeof(T) <= ReadBytesLeft, "There is not enough space left in the buffer."); var result = Unsafe.ReadUnaligned(ref Buffer[ReadPosition]); - ReadPosition += Unsafe.SizeOf(); + ReadPosition += sizeof(T); return result; } - static void ThrowNotSpaceLeft() - => ThrowHelper.ThrowInvalidOperationException("There is not enough space left in the buffer."); - public string ReadString(int byteLen) { Debug.Assert(byteLen <= ReadBytesLeft); @@ -449,14 +587,6 @@ public string ReadString(int byteLen) return result; } - public 
char[] ReadChars(int byteLen) - { - Debug.Assert(byteLen <= ReadBytesLeft); - var result = TextEncoding.GetChars(Buffer, ReadPosition, byteLen); - ReadPosition += byteLen; - return result; - } - public void ReadBytes(Span output) { Debug.Assert(output.Length <= ReadBytesLeft); @@ -467,14 +597,6 @@ public void ReadBytes(Span output) public void ReadBytes(byte[] output, int outputOffset, int len) => ReadBytes(new Span(output, outputOffset, len)); - public ReadOnlySpan ReadSpan(int len) - { - Debug.Assert(len <= ReadBytesLeft); - var span = new ReadOnlySpan(Buffer, ReadPosition, len); - ReadPosition += len; - return span; - } - public ReadOnlyMemory ReadMemory(int len) { Debug.Assert(len <= ReadBytesLeft); @@ -487,26 +609,31 @@ public ReadOnlyMemory ReadMemory(int len) #region Read Complex - public int Read(Span output) + public int Read(bool commandScoped, Span output) { var readFromBuffer = Math.Min(ReadBytesLeft, output.Length); if (readFromBuffer > 0) { - new Span(Buffer, ReadPosition, readFromBuffer).CopyTo(output); + Buffer.AsSpan(ReadPosition, readFromBuffer).CopyTo(output); ReadPosition += readFromBuffer; return readFromBuffer; } - if (output.Length == 0) - return 0; + // Only reset if we'll be able to read data, this is to support zero-byte reads. 
+ if (output.Length > 0) + { + Debug.Assert(ReadBytesLeft == 0); + ResetPosition(); + } + + if (commandScoped) + return ReadWithTimeout(output); - Debug.Assert(ReadBytesLeft == 0); - Clear(); try { var read = Underlying.Read(output); - if (read == 0) - throw new EndOfStreamException(); + _flushedBytes = unchecked(_flushedBytes + read); + NpgsqlEventSource.Log.BytesRead(read); return read; } catch (Exception e) @@ -515,30 +642,35 @@ public int Read(Span output) } } - public ValueTask ReadAsync(Memory output, CancellationToken cancellationToken = default) + public ValueTask ReadAsync(bool commandScoped, Memory output, CancellationToken cancellationToken = default) { var readFromBuffer = Math.Min(ReadBytesLeft, output.Length); if (readFromBuffer > 0) { - new Span(Buffer, ReadPosition, readFromBuffer).CopyTo(output.Span); + Buffer.AsSpan(ReadPosition, readFromBuffer).CopyTo(output.Span); ReadPosition += readFromBuffer; return new ValueTask(readFromBuffer); } - if (output.Length == 0) - return new ValueTask(0); + return ReadAsyncLong(this, commandScoped, output, cancellationToken); - return ReadAsyncLong(this, output, cancellationToken); - - static async ValueTask ReadAsyncLong(NpgsqlReadBuffer buffer, Memory output, CancellationToken cancellationToken) + static async ValueTask ReadAsyncLong(NpgsqlReadBuffer buffer, bool commandScoped, Memory output, CancellationToken cancellationToken) { - Debug.Assert(buffer.ReadBytesLeft == 0); - buffer.Clear(); + // Only reset if we'll be able to read data, this is to support zero-byte reads. 
+ if (output.Length > 0) + { + Debug.Assert(buffer.ReadBytesLeft == 0); + buffer.ResetPosition(); + } + + if (commandScoped) + return await buffer.ReadWithTimeoutAsync(output, cancellationToken).ConfigureAwait(false); + try { - var read = await buffer.Underlying.ReadAsync(output, cancellationToken); - if (read == 0) - throw new EndOfStreamException(); + var read = await buffer.Underlying.ReadAsync(output, cancellationToken).ConfigureAwait(false); + buffer._flushedBytes = unchecked(buffer._flushedBytes + read); + NpgsqlEventSource.Log.BytesRead(read); return read; } catch (Exception e) @@ -548,22 +680,13 @@ static async ValueTask ReadAsyncLong(NpgsqlReadBuffer buffer, Memory } } - public Stream GetStream(int len, bool canSeek) + ColumnStream? _lastStream; + public ColumnStream CreateStream(int len, bool canSeek) { - if (_columnStream == null) - _columnStream = new ColumnStream(Connector); - - _columnStream.Init(len, canSeek); - return _columnStream; - } - - public TextReader GetPreparedTextReader(string str, Stream stream) - { - if (_preparedTextReader is not { IsDisposed: true }) - _preparedTextReader = new PreparedTextReader(); - - _preparedTextReader.Init(str, (ColumnStream)stream); - return _preparedTextReader; + if (_lastStream is not { IsDisposed: true }) + _lastStream = new ColumnStream(Connector); + _lastStream.Init(len, canSeek, !Connector.LongRunningConnection); + return _lastStream; } /// @@ -588,9 +711,9 @@ public ValueTask ReadNullTerminatedString(bool async, CancellationToken /// Seeks the first null terminator (\0) and returns the string up to it. Reads additional data from the network if a null /// terminator isn't found in the buffered data. 
/// - ValueTask ReadNullTerminatedString(Encoding encoding, bool async, CancellationToken cancellationToken = default) + public ValueTask ReadNullTerminatedString(Encoding encoding, bool async, CancellationToken cancellationToken = default) { - var index = Buffer.AsSpan(ReadPosition, FilledBytes - ReadPosition).IndexOf((byte)0); + var index = Span.IndexOf((byte)0); if (index >= 0) { var result = new ValueTask(encoding.GetString(Buffer, ReadPosition, index)); @@ -614,7 +737,7 @@ async ValueTask ReadLong(Encoding encoding, bool async) do { - await ReadMore(async); + await ReadMore(async).ConfigureAwait(false); Debug.Assert(ReadPosition == 0); foundTerminator = false; @@ -655,7 +778,7 @@ async ValueTask ReadLong(Encoding encoding, bool async) public ReadOnlySpan GetNullTerminatedBytes() { - var i = Buffer.AsSpan(ReadPosition).IndexOf((byte)0); + var i = Span.IndexOf((byte)0); Debug.Assert(i >= 0); var result = new ReadOnlySpan(Buffer, ReadPosition, i); ReadPosition += i + 1; @@ -682,12 +805,15 @@ public void Dispose() #region Misc - internal void Clear() + void ResetPosition() { + _flushedBytes = unchecked(_flushedBytes + FilledBytes); ReadPosition = 0; FilledBytes = 0; } + internal void ResetFlushedBytes() => _flushedBytes = 0; + internal void CopyTo(NpgsqlReadBuffer other) { Debug.Assert(other.Size - other.FilledBytes >= ReadBytesLeft); diff --git a/src/Npgsql/Internal/NpgsqlWriteBuffer.Stream.cs b/src/Npgsql/Internal/NpgsqlWriteBuffer.Stream.cs deleted file mode 100644 index 428fb0ec30..0000000000 --- a/src/Npgsql/Internal/NpgsqlWriteBuffer.Stream.cs +++ /dev/null @@ -1,122 +0,0 @@ -using System; -using System.IO; -using System.Threading; -using System.Threading.Tasks; - -namespace Npgsql.Internal; - -public sealed partial class NpgsqlWriteBuffer -{ - sealed class ParameterStream : Stream - { - readonly NpgsqlWriteBuffer _buf; - bool _disposed; - - internal ParameterStream(NpgsqlWriteBuffer buf) - => _buf = buf; - - internal void Init() - => _disposed = false; - - 
public override bool CanRead => false; - - public override bool CanWrite => true; - - public override bool CanSeek => false; - - public override long Length => throw new NotSupportedException(); - - public override void SetLength(long value) - => throw new NotSupportedException(); - - public override long Position - { - get => throw new NotSupportedException(); - set => throw new NotSupportedException(); - } - - public override long Seek(long offset, SeekOrigin origin) - => throw new NotSupportedException(); - - public override void Flush() - => CheckDisposed(); - - public override Task FlushAsync(CancellationToken cancellationToken = default) - { - CheckDisposed(); - return cancellationToken.IsCancellationRequested - ? Task.FromCanceled(cancellationToken) : Task.CompletedTask; - } - - public override int Read(byte[] buffer, int offset, int count) - => throw new NotSupportedException(); - - public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken = default) - => throw new NotSupportedException(); - - public override void Write(byte[] buffer, int offset, int count) - => Write(buffer, offset, count, false); - - public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken = default) - { - using (NoSynchronizationContextScope.Enter()) - return Write(buffer, offset, count, true, cancellationToken); - } - - Task Write(byte[] buffer, int offset, int count, bool async, CancellationToken cancellationToken = default) - { - CheckDisposed(); - - if (buffer == null) - throw new ArgumentNullException(nameof(buffer)); - if (offset < 0) - throw new ArgumentNullException(nameof(offset)); - if (count < 0) - throw new ArgumentNullException(nameof(count)); - if (buffer.Length - offset < count) - throw new ArgumentException("Offset and length were out of bounds for the array or count is greater than the number of elements from index to the end of the source collection."); - if 
(cancellationToken.IsCancellationRequested) - return Task.FromCanceled(cancellationToken); - - while (count > 0) - { - var left = _buf.WriteSpaceLeft; - if (left == 0) - return WriteLong(buffer, offset, count, async, cancellationToken); - - var slice = Math.Min(count, left); - _buf.WriteBytes(buffer, offset, slice); - offset += slice; - count -= slice; - } - - return Task.CompletedTask; - } - - async Task WriteLong(byte[] buffer, int offset, int count, bool async, CancellationToken cancellationToken = default) - { - while (count > 0) - { - var left = _buf.WriteSpaceLeft; - if (left == 0) - { - await _buf.Flush(async, cancellationToken); - continue; - } - var slice = Math.Min(count, left); - _buf.WriteBytes(buffer, offset, slice); - offset += slice; - count -= slice; - } - } - - void CheckDisposed() - { - if (_disposed) - ThrowHelper.ThrowObjectDisposedException(nameof(ParameterStream)); - } - - protected override void Dispose(bool disposing) - => _disposed = true; - } -} \ No newline at end of file diff --git a/src/Npgsql/Internal/NpgsqlWriteBuffer.cs b/src/Npgsql/Internal/NpgsqlWriteBuffer.cs index 451e7d5263..94775ec3ad 100644 --- a/src/Npgsql/Internal/NpgsqlWriteBuffer.cs +++ b/src/Npgsql/Internal/NpgsqlWriteBuffer.cs @@ -1,5 +1,4 @@ using System; -using System.Buffers; using System.Buffers.Binary; using System.Diagnostics; using System.IO; @@ -19,10 +18,13 @@ namespace Npgsql.Internal; /// A buffer used by Npgsql to write data to the socket efficiently. /// Provides methods which encode different values types and tracks the current position. 
/// -public sealed partial class NpgsqlWriteBuffer : IDisposable +sealed class NpgsqlWriteBuffer : IDisposable { #region Fields and Properties + internal static readonly UTF8Encoding UTF8Encoding = new(false, true); + internal static readonly UTF8Encoding RelaxedUTF8Encoding = new(false, false); + internal readonly NpgsqlConnector Connector; internal Stream Underlying { private get; set; } @@ -67,14 +69,23 @@ internal TimeSpan Timeout public int WriteSpaceLeft => Size - WritePosition; + internal PgWriter GetWriter(NpgsqlDatabaseInfo typeCatalog, FlushMode? flushMode = null) + { + // Make sure we'll refetch from the write buffer. + _pgWriter.Reset(); + var writer = _pgWriter.Init(typeCatalog); + if (flushMode is not null) + writer.WithFlushMode(flushMode.GetValueOrDefault()); + return writer; + } + internal readonly byte[] Buffer; readonly Encoder _textEncoder; internal int WritePosition; - ParameterStream? _parameterStream; - bool _disposed; + readonly PgWriter _pgWriter; /// /// The minimum buffer size possible. @@ -106,6 +117,7 @@ internal NpgsqlWriteBuffer( TextEncoding = textEncoding; _textEncoder = TextEncoding.GetEncoder(); + _pgWriter = new PgWriter(new NpgsqlBufferWriter(this)); } #endregion @@ -378,60 +390,12 @@ static async Task WriteStringLong(NpgsqlWriteBuffer buffer, bool async, string s } } - internal Task WriteChars(char[] chars, int offset, int charLen, int byteLen, bool async, CancellationToken cancellationToken = default) - { - if (byteLen <= WriteSpaceLeft) - { - WriteChars(chars, offset, charLen); - return Task.CompletedTask; - } - return WriteCharsLong(this, async, chars, offset, charLen, byteLen, cancellationToken); - - static async Task WriteCharsLong(NpgsqlWriteBuffer buffer, bool async, char[] chars, int offset, int charLen, int byteLen, CancellationToken cancellationToken) - { - Debug.Assert(byteLen > buffer.WriteSpaceLeft); - if (byteLen <= buffer.Size) - { - // String can fit entirely in an empty buffer. 
Flush and retry rather than - // going into the partial writing flow below (which requires ToCharArray()) - await buffer.Flush(async, cancellationToken); - buffer.WriteChars(chars, offset, charLen); - } - else - { - var charPos = 0; - - while (true) - { - buffer.WriteStringChunked(chars, charPos + offset, charLen - charPos, true, out var charsUsed, out var completed); - if (completed) - break; - await buffer.Flush(async, cancellationToken); - charPos += charsUsed; - } - } - } - } - public void WriteString(string s, int len = 0) { Debug.Assert(TextEncoding.GetByteCount(s) <= WriteSpaceLeft); WritePosition += TextEncoding.GetBytes(s, 0, len == 0 ? s.Length : len, Buffer, WritePosition); } - internal void WriteChars(char[] chars, int offset, int len) - { - var charCount = len == 0 ? chars.Length : len; - Debug.Assert(TextEncoding.GetByteCount(chars, 0, charCount) <= WriteSpaceLeft); - WritePosition += TextEncoding.GetBytes(chars, offset, charCount, Buffer, WritePosition); - } - - internal void WriteChars(ReadOnlySpan chars) - { - Debug.Assert(TextEncoding.GetByteCount(chars) <= WriteSpaceLeft); - WritePosition += TextEncoding.GetBytes(chars, Buffer.AsSpan(WritePosition)); - } - public void WriteBytes(ReadOnlySpan buf) { Debug.Assert(buf.Length <= WriteSpaceLeft); @@ -518,15 +482,6 @@ public void WriteNullTerminatedString(string s) #region Write Complex - public Stream GetStream() - { - if (_parameterStream == null) - _parameterStream = new ParameterStream(this); - - _parameterStream.Init(); - return _parameterStream; - } - internal void WriteStringChunked(char[] chars, int charIndex, int charCount, bool flush, out int charsUsed, out bool completed) { diff --git a/src/Npgsql/Internal/PgBufferedConverter.cs b/src/Npgsql/Internal/PgBufferedConverter.cs new file mode 100644 index 0000000000..7faf7bb0c4 --- /dev/null +++ b/src/Npgsql/Internal/PgBufferedConverter.cs @@ -0,0 +1,52 @@ +using System; +using System.Diagnostics.CodeAnalysis; +using System.Threading; +using 
System.Threading.Tasks; + +namespace Npgsql.Internal; + +public abstract class PgBufferedConverter : PgConverter +{ + protected PgBufferedConverter(bool customDbNullPredicate = false) : base(customDbNullPredicate) { } + + protected abstract T ReadCore(PgReader reader); + protected abstract void WriteCore(PgWriter writer, T value); + + public override Size GetSize(SizeContext context, T value, ref object? writeState) + => throw new NotSupportedException(); + + public sealed override T Read(PgReader reader) + { + // We check IsAtStart first to speed up primitive reads. + if (!reader.IsAtStart && reader.ShouldBufferCurrent()) + ThrowIORequired(); + + return ReadCore(reader); + } + + public sealed override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + => new(Read(reader)); + + internal sealed override ValueTask ReadAsObject(bool async, PgReader reader, CancellationToken cancellationToken) + => new(Read(reader)!); + + public sealed override void Write(PgWriter writer, T value) + { + if (!writer.BufferingWrite && writer.ShouldFlush(writer.CurrentBufferRequirement)) + ThrowIORequired(); + + WriteCore(writer, value); + } + + public sealed override ValueTask WriteAsync(PgWriter writer, [DisallowNull] T value, CancellationToken cancellationToken = default) + { + Write(writer, value); + return new(); + } + + internal sealed override ValueTask WriteAsObject(bool async, PgWriter writer, object value, CancellationToken cancellationToken) + { + Write(writer, (T)value); + return new(); + } +} diff --git a/src/Npgsql/Internal/PgComposingConverterResolver.cs b/src/Npgsql/Internal/PgComposingConverterResolver.cs new file mode 100644 index 0000000000..543ef8bdbd --- /dev/null +++ b/src/Npgsql/Internal/PgComposingConverterResolver.cs @@ -0,0 +1,68 @@ +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using Npgsql.Internal.Postgres; + +namespace Npgsql.Internal; + +abstract class PgComposingConverterResolver : 
PgConverterResolver +{ + readonly PgTypeId? _pgTypeId; + public PgResolverTypeInfo EffectiveTypeInfo { get; } + readonly ConcurrentDictionary _converters = new(ReferenceEqualityComparer.Instance); + + protected PgComposingConverterResolver(PgTypeId? pgTypeId, PgResolverTypeInfo effectiveTypeInfo) + { + if (pgTypeId is null && effectiveTypeInfo.PgTypeId is not null) + throw new ArgumentNullException(nameof(pgTypeId), $"Cannot be null if {nameof(effectiveTypeInfo)}.{nameof(PgTypeInfo.PgTypeId)} is not null."); + + _pgTypeId = pgTypeId; + EffectiveTypeInfo = effectiveTypeInfo; + } + + protected abstract PgTypeId GetEffectivePgTypeId(PgTypeId pgTypeId); + protected abstract PgTypeId GetPgTypeId(PgTypeId effectivePgTypeId); + protected abstract PgConverter CreateConverter(PgConverterResolution effectiveResolution); + protected abstract PgConverterResolution? GetEffectiveResolution(T? value, PgTypeId? expectedEffectivePgTypeId); + + public override PgConverterResolution GetDefault(PgTypeId? pgTypeId) + { + PgTypeId? effectivePgTypeId = pgTypeId is not null ? GetEffectiveTypeId(pgTypeId.GetValueOrDefault()) : null; + var effectiveResolution = EffectiveTypeInfo.GetDefaultResolution(effectivePgTypeId); + return new(GetOrAdd(effectiveResolution), pgTypeId ?? _pgTypeId ?? GetPgTypeId(effectiveResolution.PgTypeId)); + } + + public override PgConverterResolution? Get(T? value, PgTypeId? expectedPgTypeId) + { + PgTypeId? expectedEffectiveId = expectedPgTypeId is not null ? GetEffectiveTypeId(expectedPgTypeId.GetValueOrDefault()) : null; + if (GetEffectiveResolution(value, expectedEffectiveId) is { } resolution) + return new PgConverterResolution(GetOrAdd(resolution), expectedPgTypeId ?? _pgTypeId ?? 
GetPgTypeId(resolution.PgTypeId)); + + return null; + } + + public override PgConverterResolution Get(Field field) + { + var effectiveResolution = EffectiveTypeInfo.GetResolution(field with { PgTypeId = GetEffectiveTypeId(field.PgTypeId) }); + return new PgConverterResolution(GetOrAdd(effectiveResolution), field.PgTypeId); + } + + PgTypeId GetEffectiveTypeId(PgTypeId pgTypeId) + { + if (_pgTypeId == pgTypeId) + return EffectiveTypeInfo.PgTypeId.GetValueOrDefault(); + + // We have an undecided type info which is asked to resolve for a specific type id + // we'll unfortunately have to look up the effective id, this is rare though. + return GetEffectivePgTypeId(pgTypeId); + } + + PgConverter GetOrAdd(PgConverterResolution effectiveResolution) + { + (PgComposingConverterResolver Instance, PgConverterResolution EffectiveResolution) state = (this, effectiveResolution); + return (PgConverter)_converters.GetOrAdd( + effectiveResolution.Converter, + static (_, state) => state.Instance.CreateConverter(state.EffectiveResolution), + state); + } +} diff --git a/src/Npgsql/Internal/PgConverter.cs b/src/Npgsql/Internal/PgConverter.cs new file mode 100644 index 0000000000..e136e9a904 --- /dev/null +++ b/src/Npgsql/Internal/PgConverter.cs @@ -0,0 +1,205 @@ +using System; +using System.Buffers; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace Npgsql.Internal; + +public abstract class PgConverter +{ + internal DbNullPredicate DbNullPredicateKind { get; } + public bool IsDbNullable => DbNullPredicateKind is not DbNullPredicate.None; + + private protected PgConverter(Type type, bool isNullDefaultValue, bool customDbNullPredicate = false) + => DbNullPredicateKind = customDbNullPredicate ? DbNullPredicate.Custom : InferDbNullPredicate(type, isNullDefaultValue); + + /// + /// Whether this converter can handle the given format and with which buffer requirements. + /// + /// The data format. 
+ /// Returns the buffer requirements. + /// Returns true if the given data format is supported. + /// The buffer requirements should not cover database NULL reads or writes, these are handled by the caller. + public abstract bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements); + + internal abstract Type TypeToConvert { get; } + + internal bool IsDbNullAsObject([NotNullWhen(false)] object? value) + => DbNullPredicateKind switch + { + DbNullPredicate.Null => value is null, + DbNullPredicate.None => false, + DbNullPredicate.PolymorphicNull => value is null or DBNull, + // We do the null check to keep the NotNullWhen(false) invariant. + _ => IsDbNullValueAsObject(value) || (value is null && ThrowInvalidNullValue()) + }; + + private protected abstract bool IsDbNullValueAsObject(object? value); + + internal abstract Size GetSizeAsObject(SizeContext context, object value, ref object? writeState); + + internal object ReadAsObject(PgReader reader) + => ReadAsObject(async: false, reader, CancellationToken.None).GetAwaiter().GetResult(); + internal ValueTask ReadAsObjectAsync(PgReader reader, CancellationToken cancellationToken = default) + => ReadAsObject(async: true, reader, cancellationToken); + + // Shared sync/async abstract to reduce virtual method table size overhead and code size for each NpgsqlConverter instantiation. + internal abstract ValueTask ReadAsObject(bool async, PgReader reader, CancellationToken cancellationToken); + + internal void WriteAsObject(PgWriter writer, object value) + => WriteAsObject(async: false, writer, value, CancellationToken.None).GetAwaiter().GetResult(); + internal ValueTask WriteAsObjectAsync(PgWriter writer, object value, CancellationToken cancellationToken = default) + => WriteAsObject(async: true, writer, value, cancellationToken); + + // Shared sync/async abstract to reduce virtual method table size overhead and code size for each NpgsqlConverter instantiation. 
+ internal abstract ValueTask WriteAsObject(bool async, PgWriter writer, object value, CancellationToken cancellationToken); + + static DbNullPredicate InferDbNullPredicate(Type type, bool isNullDefaultValue) + => type == typeof(object) || type == typeof(DBNull) + ? DbNullPredicate.PolymorphicNull + : isNullDefaultValue + ? DbNullPredicate.Null + : DbNullPredicate.None; + + internal enum DbNullPredicate : byte + { + /// Never DbNull (struct types) + None, + /// DbNull when *user code* + Custom, + /// DbNull when value is null + Null, + /// DbNull when value is null or DBNull + PolymorphicNull + } + + [DoesNotReturn] + private protected static void ThrowIORequired() + => throw new InvalidOperationException("Buffer requirements for format not respected, expected no IO to be required."); + + private protected static bool ThrowInvalidNullValue() + => throw new ArgumentNullException("value", "Null value given for non-nullable type converter"); + + protected bool CanConvertBufferedDefault(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.Value; + return format is DataFormat.Binary; + } +} + +public abstract class PgConverter : PgConverter +{ + private protected PgConverter(bool customDbNullPredicate) + : base(typeof(T), default(T) is null, customDbNullPredicate) { } + + protected virtual bool IsDbNullValue(T? value) => throw new NotSupportedException(); + + // Object null semantics as follows, if T is a struct (so excluding nullable) report false for null values, don't throw on the cast. + // As a result this creates symmetry with IsDbNull when we're dealing with a struct T, as it cannot be passed null at all. + private protected override bool IsDbNullValueAsObject(object? value) + => (default(T) is null || value is not null) && IsDbNullValue(Downcast(value)); + + public bool IsDbNull([NotNullWhen(false)] T? 
value) + { + return DbNullPredicateKind switch + { + DbNullPredicate.Null => value is null, + DbNullPredicate.None => false, + DbNullPredicate.PolymorphicNull => value is null or DBNull, + // We do the null check to keep the NotNullWhen(false) invariant. + DbNullPredicate.Custom => IsDbNullValue(value) || (value is null && ThrowInvalidNullValue()), + _ => ThrowOutOfRange() + }; + + bool ThrowOutOfRange() => throw new ArgumentOutOfRangeException(nameof(DbNullPredicateKind), "Unknown case", DbNullPredicateKind.ToString()); + } + + public abstract T Read(PgReader reader); + public abstract ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default); + + public abstract Size GetSize(SizeContext context, [DisallowNull]T value, ref object? writeState); + public abstract void Write(PgWriter writer, [DisallowNull] T value); + public abstract ValueTask WriteAsync(PgWriter writer, [DisallowNull] T value, CancellationToken cancellationToken = default); + + internal sealed override Type TypeToConvert => typeof(T); + + internal sealed override Size GetSizeAsObject(SizeContext context, object value, ref object? writeState) + => GetSize(context, Downcast(value), ref writeState); + + [MethodImpl(MethodImplOptions.NoInlining)] + [return: NotNullIfNotNull(nameof(value))] + static T? Downcast(object? value) => (T?)value; +} + +static class PgConverterExtensions +{ + public static Size? GetSizeOrDbNull(this PgConverter converter, DataFormat format, Size writeRequirement, T? value, ref object? writeState) + { + if (converter.IsDbNull(value)) + return null; + + if (writeRequirement is { Kind: SizeKind.Exact, Value: var byteCount }) + return byteCount; + var size = converter.GetSize(new(format, writeRequirement), value, ref writeState); + if (size.Kind is SizeKind.UpperBound) + throw new InvalidOperationException("SizeKind.UpperBound is not a valid return value for GetSize."); + return size; + } + + public static Size? 
GetSizeOrDbNullAsObject(this PgConverter converter, DataFormat format, Size writeRequirement, object? value, ref object? writeState) + { + if (converter.IsDbNullAsObject(value)) + return null; + + if (writeRequirement is { Kind: SizeKind.Exact, Value: var byteCount }) + return byteCount; + var size = converter.GetSizeAsObject(new(format, writeRequirement), value, ref writeState); + if (size.Kind is SizeKind.UpperBound) + throw new InvalidOperationException("SizeKind.UpperBound is not a valid return value for GetSize."); + return size; + } +} + +interface IResumableRead +{ + bool Supported { get; } +} + +public readonly struct SizeContext +{ + [SetsRequiredMembers] + public SizeContext(DataFormat format, Size bufferRequirement) + { + Format = format; + BufferRequirement = bufferRequirement; + } + + public DataFormat Format { get; } + public required Size BufferRequirement { get; init; } +} + +class MultiWriteState : IDisposable +{ + public required ArrayPool<(Size Size, object? WriteState)>? ArrayPool { get; init; } + public required ArraySegment<(Size Size, object? WriteState)> Data { get; init; } + public required bool AnyWriteState { get; init; } + + public void Dispose() + { + if (Data.Array is not { } array) + return; + + if (AnyWriteState) + { + for (var i = Data.Offset; i < array.Length; i++) + if (array[i].WriteState is IDisposable disposable) + disposable.Dispose(); + + Array.Clear(Data.Array, Data.Offset, Data.Count); + } + + ArrayPool?.Return(Data.Array); + } +} diff --git a/src/Npgsql/Internal/PgConverterResolver.cs b/src/Npgsql/Internal/PgConverterResolver.cs new file mode 100644 index 0000000000..baee09d58e --- /dev/null +++ b/src/Npgsql/Internal/PgConverterResolver.cs @@ -0,0 +1,109 @@ +using System; +using Npgsql.Internal.Postgres; + +namespace Npgsql.Internal; + +public abstract class PgConverterResolver +{ + private protected PgConverterResolver() { } + + /// + /// Gets the appropriate converter solely based on PgTypeId. 
+ /// + /// + /// The converter resolution. + /// + /// Implementations should not return new instances of the possible converters that can be returned, instead its expected these are cached once used. + /// Array or other collection converters depend on this to cache their own converter - which wraps the element converter - with the cache key being the element converter reference. + /// + public abstract PgConverterResolution GetDefault(PgTypeId? pgTypeId); + + /// + /// Gets the appropriate converter to read with based on the given field info. + /// + /// + /// The converter resolution. + /// + /// Implementations should not return new instances of the possible converters that can be returned, instead its expected these are cached once used. + /// Array or other collection converters depend on this to cache their own converter - which wraps the element converter - with the cache key being the element converter reference. + /// + public virtual PgConverterResolution Get(Field field) => GetDefault(field.PgTypeId); + + internal abstract Type TypeToConvert { get; } + + internal abstract PgConverterResolution? GetAsObjectInternal(PgTypeInfo typeInfo, object? value, PgTypeId? expectedPgTypeId); + + internal PgConverterResolution GetDefaultInternal(bool validate, bool expectPortableTypeIds, PgTypeId? pgTypeId) + { + var resolution = GetDefault(pgTypeId); + if (validate) + Validate(nameof(GetDefault), resolution, TypeToConvert, pgTypeId, expectPortableTypeIds); + return resolution; + } + + internal PgConverterResolution GetInternal(PgTypeInfo typeInfo, Field field) + { + var resolution = Get(field); + if (typeInfo.ValidateResolution) + Validate(nameof(Get), resolution, TypeToConvert, field.PgTypeId, typeInfo.Options.PortableTypeIds); + return resolution; + } + + private protected static void Validate(string methodName, PgConverterResolution resolution, Type expectedTypeToConvert, PgTypeId? 
expectedPgTypeId, bool expectPortableTypeIds) + { + if (resolution.Converter is null) + throw new InvalidOperationException($"'{methodName}' returned a null {nameof(PgConverterResolution.Converter)} unexpectedly."); + + // We allow object resolvers to return any converter, this is to help: + // - Composing resolvers being able to use converter type identity (instead of everything being CastingConverter). + // - Reduce indirection by allowing disparate type converters to be returned directly. + // As a consequence any object typed resolver info is always a boxing one, to reduce the chances invalid casts to PgConverter are attempted. + if (expectedTypeToConvert != typeof(object) && resolution.Converter.TypeToConvert != expectedTypeToConvert) + throw new InvalidOperationException($"'{methodName}' returned a {nameof(PgConverterResolution.Converter)} of type {resolution.Converter.TypeToConvert} instead of {expectedTypeToConvert} unexpectedly."); + + if (expectPortableTypeIds && resolution.PgTypeId.IsOid || !expectPortableTypeIds && resolution.PgTypeId.IsDataTypeName) + throw new InvalidOperationException($"{methodName}' returned a resolution with a {nameof(PgConverterResolution.PgTypeId)} that was not in canonical form."); + + if (expectedPgTypeId is not null && resolution.PgTypeId != expectedPgTypeId) + throw new InvalidOperationException( + $"'{methodName}' returned a different {nameof(PgConverterResolution.PgTypeId)} than was passed in as expected." + + $" If such a mismatch occurs an exception should be thrown instead."); + } + + protected ArgumentOutOfRangeException CreateUnsupportedPgTypeIdException(PgTypeId pgTypeId) + => new(nameof(pgTypeId), pgTypeId, "Unsupported PgTypeId."); +} + +public abstract class PgConverterResolver : PgConverterResolver +{ + /// + /// Gets the appropriate converter to write with based on the given value. + /// + /// + /// + /// The converter resolution. 
+ /// + /// Implementations should not return new instances of the possible converters that can be returned, instead its expected these are + /// cached once used. Array or other collection converters depend on this to cache their own converter - which wraps the element + /// converter - with the cache key being the element converter reference. + /// + public abstract PgConverterResolution? Get(T? value, PgTypeId? expectedPgTypeId); + + internal sealed override Type TypeToConvert => typeof(T); + + internal PgConverterResolution? GetInternal(PgTypeInfo typeInfo, T? value, PgTypeId? expectedPgTypeId) + { + var resolution = Get(value, expectedPgTypeId); + if (typeInfo.ValidateResolution && resolution is not null) + Validate(nameof(Get), resolution.GetValueOrDefault(), TypeToConvert, expectedPgTypeId, typeInfo.Options.PortableTypeIds); + return resolution; + } + + internal sealed override PgConverterResolution? GetAsObjectInternal(PgTypeInfo typeInfo, object? value, PgTypeId? expectedPgTypeId) + { + var resolution = Get(value is null ? default : (T)value, expectedPgTypeId); + if (typeInfo.ValidateResolution && resolution is not null) + Validate(nameof(Get), resolution.GetValueOrDefault(), TypeToConvert, expectedPgTypeId, typeInfo.Options.PortableTypeIds); + return resolution; + } +} diff --git a/src/Npgsql/Internal/PgReader.cs b/src/Npgsql/Internal/PgReader.cs new file mode 100644 index 0000000000..f1f448bc65 --- /dev/null +++ b/src/Npgsql/Internal/PgReader.cs @@ -0,0 +1,723 @@ +using System; +using System.Buffers; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.IO; +using System.Runtime.CompilerServices; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace Npgsql.Internal; + +public class PgReader +{ + readonly NpgsqlReadBuffer _buffer; + + bool _resumable; + + byte[]? _pooledArray; + NpgsqlReadBuffer.ColumnStream? _userActiveStream; + PreparedTextReader? 
_preparedTextReader; + + long _fieldStartPos; + Size _fieldBufferRequirement; + DataFormat _fieldFormat; + int _fieldSize; + + // This position is relative to _fieldStartPos, which is why it can be an int. + int _currentStartPos; + Size _currentBufferRequirement; + int _currentSize; + + // GetChars Internal state + TextReader? _charsReadReader; + int _charsRead; + + // GetChars User state + int? _charsReadOffset; + ArraySegment? _charsReadBuffer; + + bool _requiresCleanup; + + internal PgReader(NpgsqlReadBuffer buffer) + { + _buffer = buffer; + _fieldStartPos = -1; + _currentSize = -1; + } + + internal long FieldStartPos => _fieldStartPos; + internal int FieldSize => _fieldSize; + internal bool Initialized => _fieldStartPos is not -1; + internal int FieldOffset => (int)(_buffer.CumulativeReadPosition - _fieldStartPos); + internal int FieldRemaining => FieldSize - FieldOffset; + + bool HasCurrent => _currentSize >= 0; + int CurrentSize => HasCurrent ? _currentSize : _fieldSize; + + public ValueMetadata Current => new() { Size = CurrentSize, Format = _fieldFormat, BufferRequirement = CurrentBufferRequirement }; + public int CurrentRemaining => HasCurrent ? _currentSize - CurrentOffset : FieldRemaining; + + Size CurrentBufferRequirement => HasCurrent ? 
_currentBufferRequirement : _fieldBufferRequirement; + int CurrentOffset => FieldOffset - _currentStartPos; + + int BufferSize => _buffer.Size; + int BufferBytesRemaining => _buffer.ReadBytesLeft; + + internal bool IsAtStart => FieldOffset is 0; + internal bool Resumable => _resumable; + public bool IsResumed => Resumable && CurrentSize != CurrentRemaining; + + ArrayPool ArrayPool => ArrayPool.Shared; + + [MemberNotNullWhen(true, nameof(_charsReadReader))] + internal bool IsCharsRead => _charsReadOffset is not null; + + // Here for testing purposes + internal void BreakConnection() => throw _buffer.Connector.Break(new Exception("Broken")); + + internal void Revert(int size, int startPos, Size bufferRequirement) + { + if (startPos > FieldOffset) + ThrowHelper.ThrowArgumentOutOfRangeException(nameof(startPos), "Can't revert forwardly"); + + _currentStartPos = startPos; + _currentBufferRequirement = bufferRequirement; + _currentSize = size; + } + + [Conditional("DEBUG")] + void CheckBounds(int count) + { + if (count > FieldRemaining) + ThrowHelper.ThrowInvalidOperationException("Attempt to read past the end of the field."); + } + + public byte ReadByte() + { + CheckBounds(sizeof(byte)); + var result = _buffer.ReadByte(); + return result; + } + + public short ReadInt16() + { + CheckBounds(sizeof(short)); + var result = _buffer.ReadInt16(); + return result; + } + + public int ReadInt32() + { + CheckBounds(sizeof(int)); + var result = _buffer.ReadInt32(); + return result; + } + + public long ReadInt64() + { + CheckBounds(sizeof(long)); + var result = _buffer.ReadInt64(); + return result; + } + + public ushort ReadUInt16() + { + CheckBounds(sizeof(ushort)); + var result = _buffer.ReadUInt16(); + return result; + } + + public uint ReadUInt32() + { + CheckBounds(sizeof(uint)); + var result = _buffer.ReadUInt32(); + return result; + } + + public ulong ReadUInt64() + { + CheckBounds(sizeof(ulong)); + var result = _buffer.ReadUInt64(); + return result; + } + + public float 
ReadFloat() + { + CheckBounds(sizeof(float)); + var result = _buffer.ReadSingle(); + return result; + } + + public double ReadDouble() + { + CheckBounds(sizeof(double)); + var result = _buffer.ReadDouble(); + return result; + } + + public void Read(Span destination) + { + CheckBounds(destination.Length); + _buffer.ReadBytes(destination); + } + + public async ValueTask ReadNullTerminatedStringAsync(Encoding encoding, CancellationToken cancellationToken = default) + { + var result = await _buffer.ReadNullTerminatedString(encoding, async: true, cancellationToken).ConfigureAwait(false); + // Can only check after the fact. + CheckBounds(0); + return result; + } + + public string ReadNullTerminatedString(Encoding encoding) + { + var result = _buffer.ReadNullTerminatedString(encoding, async: false, CancellationToken.None).GetAwaiter().GetResult(); + CheckBounds(0); + return result; + } + public Stream GetStream(int? length = null) => GetColumnStream(false, length); + + internal Stream GetStream(bool canSeek, int? length = null) => GetColumnStream(canSeek, length); + + NpgsqlReadBuffer.ColumnStream GetColumnStream(bool canSeek = false, int? length = null) + { + if (length > CurrentRemaining) + throw new ArgumentOutOfRangeException(nameof(length), "Length is larger than the current remaining value size"); + + _requiresCleanup = true; + // This will cause any previously handed out StreamReaders etc to throw, as intended. 
+ if (_userActiveStream is not null) + DisposeUserActiveStream(async: false).GetAwaiter().GetResult(); + + length ??= CurrentRemaining; + CheckBounds(length.GetValueOrDefault()); + return _userActiveStream = _buffer.CreateStream(length.GetValueOrDefault(), canSeek && length <= BufferBytesRemaining); + } + + public TextReader GetTextReader(Encoding encoding) + => GetTextReader(async: false, encoding, CancellationToken.None).GetAwaiter().GetResult(); + + public ValueTask GetTextReaderAsync(Encoding encoding, CancellationToken cancellationToken) + => GetTextReader(async: true, encoding, cancellationToken); + + async ValueTask GetTextReader(bool async, Encoding encoding, CancellationToken cancellationToken) + { + // We don't want to add a ton of memory pressure for large strings. + const int maxPreparedSize = 1024 * 64; + + _requiresCleanup = true; + if (CurrentRemaining > BufferBytesRemaining || CurrentRemaining > maxPreparedSize) + return new StreamReader(GetColumnStream(), encoding, detectEncodingFromByteOrderMarks: false); + + if (_preparedTextReader is { IsDisposed: false }) + { + _preparedTextReader.Dispose(); + _preparedTextReader = null; + } + + _preparedTextReader ??= new PreparedTextReader(); + _preparedTextReader.Init( + encoding.GetString(async + ? 
await ReadBytesAsync(CurrentRemaining, cancellationToken).ConfigureAwait(false) + : ReadBytes(CurrentRemaining)), GetColumnStream(canSeek: false, 0)); + return _preparedTextReader; + } + + public ValueTask ReadBytesAsync(Memory buffer, CancellationToken cancellationToken = default) + { + var count = buffer.Length; + CheckBounds(count); + if (BufferBytesRemaining >= count) + { + _buffer.Buffer.AsSpan(_buffer.ReadPosition, count).CopyTo(buffer.Span); + _buffer.ReadPosition += count; + return new(); + } + + return Slow(); + + async ValueTask Slow() + { + var stream = _buffer.CreateStream(count, canSeek: false); + await using var _ = stream.ConfigureAwait(false); + await stream.ReadExactlyAsync(buffer, cancellationToken).ConfigureAwait(false); + } + } + + public void ReadBytes(Span buffer) + { + var count = buffer.Length; + CheckBounds(count); + if (BufferBytesRemaining >= count) + { + _buffer.Buffer.AsSpan(_buffer.ReadPosition, count).CopyTo(buffer); + _buffer.ReadPosition += count; + return; + } + + Slow(buffer); + + void Slow(Span buffer) + { + using var stream = _buffer.CreateStream(count, canSeek: false); + stream.ReadExactly(buffer); + } + } + + public bool TryReadBytes(int count, out ReadOnlySpan bytes) + { + CheckBounds(count); + if (BufferBytesRemaining >= count) + { + bytes = new ReadOnlySpan(_buffer.Buffer, _buffer.ReadPosition, count); + _buffer.ReadPosition += count; + return true; + } + bytes = default; + return false; + } + + public bool TryReadBytes(int count, out ReadOnlyMemory bytes) + { + CheckBounds(count); + if (BufferBytesRemaining >= count) + { + bytes = new ReadOnlyMemory(_buffer.Buffer, _buffer.ReadPosition, count); + _buffer.ReadPosition += count; + return true; + } + bytes = default; + return false; + } + + /// ReadBytes without memory management, the next read invalidates the underlying buffer(s), only use this for intermediate transformations. 
+ public ReadOnlySequence ReadBytes(int count) + { + CheckBounds(count); + if (BufferBytesRemaining >= count) + { + var result = new ReadOnlySequence(_buffer.Buffer, _buffer.ReadPosition, count); + _buffer.ReadPosition += count; + return result; + } + + var array = RentArray(count); + ReadBytes(array.AsSpan(0, count)); + return new(array, 0, count); + } + + /// ReadBytesAsync without memory management, the next read invalidates the underlying buffer(s), only use this for intermediate transformations. + public async ValueTask> ReadBytesAsync(int count, CancellationToken cancellationToken = default) + { + CheckBounds(count); + if (BufferBytesRemaining >= count) + { + var result = new ReadOnlySequence(_buffer.Buffer, _buffer.ReadPosition, count); + _buffer.ReadPosition += count; + return result; + } + + var array = RentArray(count); + await ReadBytesAsync(array.AsMemory(0, count), cancellationToken).ConfigureAwait(false); + return new(array, 0, count); + } + + public void Rewind(int count) + { + // Shut down any streaming going on on the column + DisposeUserActiveStream(async: false).GetAwaiter().GetResult(); + + if (_buffer.ReadPosition < count) + throw new ArgumentOutOfRangeException("Cannot rewind further than the buffer start"); + + if (CurrentOffset < count) + throw new ArgumentOutOfRangeException("Cannot rewind further than the current field offset"); + + _buffer.ReadPosition -= count; + } + + /// + /// + /// + /// + /// The stream length, if any + async ValueTask DisposeUserActiveStream(bool async) + { + if (_userActiveStream is { IsDisposed: false }) + { + if (async) + await _userActiveStream.DisposeAsync().ConfigureAwait(false); + else + _userActiveStream.Dispose(); + } + + _userActiveStream = null; + } + + internal bool GetCharsReadInfo(Encoding encoding, out int charsRead, out TextReader reader, out int charsOffset, out ArraySegment? 
buffer) + { + if (!IsCharsRead) + throw new InvalidOperationException("No active chars read"); + + if (_charsReadReader is null) + { + charsRead = 0; + reader = _charsReadReader = GetTextReader(encoding); + charsOffset = _charsReadOffset ??= 0; + buffer = _charsReadBuffer; + return true; + } + + charsRead = _charsRead; + reader = _charsReadReader; + charsOffset = _charsReadOffset!.Value; + buffer = _charsReadBuffer; + + return false; + } + + internal void ResetCharsRead(out int charsRead) + { + if (!IsCharsRead) + throw new InvalidOperationException("No active chars read"); + + switch (_charsReadReader) + { + case PreparedTextReader reader: + reader.Restart(); + break; + case StreamReader reader: + reader.BaseStream.Seek(0, SeekOrigin.Begin); + reader.DiscardBufferedData(); + break; + } + _charsRead = charsRead = 0; + } + + internal void AdvanceCharsRead(int charsRead) + { + _charsRead += charsRead; + _charsReadOffset = null; + _charsReadBuffer = null; + } + + internal void InitCharsRead(int dataOffset, ArraySegment? buffer, out int? charsRead) + { + if (!Resumable) + throw new InvalidOperationException("Wasn't initialized as resumed"); + + charsRead = _charsReadReader is null ? 
null : _charsRead; + _charsReadOffset = dataOffset; + _charsReadBuffer = buffer; + } + + internal PgReader Init(int fieldLength, DataFormat format, bool resumable = false) + { + if (resumable) + { + if (Resumable) + { + Debug.Assert(Initialized); + return this; + } + _resumable = true; + } + else if (Initialized) + { + if (!IsAtStart) + ThrowHelper.ThrowInvalidOperationException("Cannot be initialized to be non-resumable until a commit is issued."); + _resumable = false; + } + + // Debug.Assert(!Initialized || Resumable, "Reader wasn't properly committed before next init"); + Debug.Assert(!_requiresCleanup, "Reader wasn't properly committed before next init"); + + _fieldStartPos = _buffer.CumulativeReadPosition; + _fieldFormat = format; + _fieldSize = fieldLength; + return this; + } + + internal void StartRead(Size bufferRequirement) + { + Debug.Assert(FieldSize >= 0); + _fieldBufferRequirement = bufferRequirement; + if (ShouldBuffer(bufferRequirement)) + Buffer(bufferRequirement); + } + + internal ValueTask StartReadAsync(Size bufferRequirement, CancellationToken cancellationToken) + { + Debug.Assert(FieldSize >= 0); + _fieldBufferRequirement = bufferRequirement; + return ShouldBuffer(bufferRequirement) ? BufferAsync(bufferRequirement, cancellationToken) : new(); + } + + internal void EndRead() + { + if (_resumable) + return; + + // If it was upper bound we should consume. + if (_fieldBufferRequirement is { Kind: SizeKind.UpperBound }) + { + Consume(FieldRemaining); + return; + } + + if (FieldOffset != FieldSize) + ThrowNotConsumedExactly(); + } + + internal ValueTask EndReadAsync() + { + if (_resumable) + return new(); + + // If it was upper bound we should consume. 
+ if (_fieldBufferRequirement is { Kind: SizeKind.UpperBound }) + return ConsumeAsync(FieldRemaining); + + if (FieldOffset != FieldSize) + ThrowNotConsumedExactly(); + return new(); + } + + internal async ValueTask BeginNestedRead(bool async, int size, Size bufferRequirement, CancellationToken cancellationToken = default) + { + if (size > CurrentRemaining) + throw new ArgumentOutOfRangeException(nameof(size), "Cannot begin a read for a larger size than the current remaining size."); + + if (size < 0) + throw new ArgumentOutOfRangeException(nameof(size), "Cannot be negative"); + + var previousSize = CurrentSize; + var previousStartPos = _currentStartPos; + var previousBufferRequirement = CurrentBufferRequirement; + _currentSize = size; + _currentBufferRequirement = bufferRequirement; + _currentStartPos = FieldOffset; + + await Buffer(async, bufferRequirement, cancellationToken).ConfigureAwait(false); + return new NestedReadScope(async, this, previousSize, previousStartPos, previousBufferRequirement); + } + + public NestedReadScope BeginNestedRead(int size, Size bufferRequirement) + => BeginNestedRead(async: false, size, bufferRequirement, CancellationToken.None).GetAwaiter().GetResult(); + + public ValueTask BeginNestedReadAsync(int size, Size bufferRequirement, CancellationToken cancellationToken = default) + => BeginNestedRead(async: true, size, bufferRequirement, cancellationToken); + + internal void Seek(int offset) + { + if (CurrentOffset > offset) + Rewind(CurrentOffset - offset); + else if (CurrentOffset < offset) + Consume(offset - CurrentOffset); + } + + internal async ValueTask Consume(bool async, int? count = null, CancellationToken cancellationToken = default) + { + if (count <= 0 || FieldSize < 0 || FieldRemaining == 0) + return; + + var remaining = count ?? CurrentRemaining; + CheckBounds(remaining); + + var origOffset = FieldOffset; + // A breaking exception unwind from a nested scope should not try to consume its remaining data. 
+ if (!_buffer.Connector.IsBroken) + await _buffer.Skip(remaining, async).ConfigureAwait(false); + + Debug.Assert(FieldRemaining == FieldSize - origOffset - remaining); + } + + public void Consume(int? count = null) => Consume(async: false, count).GetAwaiter().GetResult(); + public ValueTask ConsumeAsync(int? count = null, CancellationToken cancellationToken = default) => Consume(async: true, count, cancellationToken); + + internal void ThrowIfStreamActive() + { + if (_userActiveStream is { IsDisposed: false}) + ThrowHelper.ThrowInvalidOperationException("A stream is already open for this reader"); + } + + internal bool CommitHasIO(bool resuming) => Initialized && !resuming && FieldRemaining > 0; + internal ValueTask Commit(bool async, bool resuming) + { + if (!Initialized) + return new(); + + if (resuming) + { + if (!Resumable) + ThrowHelper.ThrowInvalidOperationException("Cannot resume a non-resumable read."); + return new(); + } + + // We don't rely on CurrentRemaining, just to make sure we consume fully in the event of a nested scope not being disposed. + // Also shut down any streaming, pooled arrays etc. + if (_requiresCleanup || FieldRemaining > 0) + return Slow(async); + + _fieldSize = default; + _fieldStartPos = -1; + _resumable = false; + _fieldFormat = default; + if (_currentSize is not -1) + { + _currentStartPos = 0; + _currentBufferRequirement = default; + _currentSize = -1; + } + Debug.Assert(!Initialized); + return new(); + + async ValueTask Slow(bool async) + { + // Shut down any streaming and pooling going on on the column. 
+ if (_requiresCleanup) + { + if (_userActiveStream is { IsDisposed: false }) + await DisposeUserActiveStream(async).ConfigureAwait(false); + + if (_pooledArray is not null) + { + ArrayPool.Return(_pooledArray); + _pooledArray = null; + } + + if (_charsReadReader is not null) + { + _charsReadReader.Dispose(); + _charsReadReader = null; + _charsRead = default; + } + _requiresCleanup = false; + } + + await Consume(async, count: FieldRemaining).ConfigureAwait(false); + _fieldSize = default; + _fieldStartPos = -1; + _resumable = false; + _fieldFormat = default; + _currentStartPos = 0; + _currentBufferRequirement = default; + _currentSize = -1; + Debug.Assert(!Initialized); + } + } + + byte[] RentArray(int count) + { + _requiresCleanup = true; + var pooledArray = _pooledArray; + var array = _pooledArray = ArrayPool.Rent(count); + if (pooledArray is not null) + ArrayPool.Return(pooledArray); + return array; + } + + int GetBufferRequirementByteCount(Size bufferRequirement) + => bufferRequirement is { Kind: SizeKind.UpperBound } + ? 
Math.Min(CurrentRemaining, bufferRequirement.Value) + : bufferRequirement.GetValueOrDefault(); + + internal bool ShouldBufferCurrent() => ShouldBuffer(CurrentBufferRequirement); + + public bool ShouldBuffer(Size bufferRequirement) + => ShouldBuffer(GetBufferRequirementByteCount(bufferRequirement)); + public bool ShouldBuffer(int byteCount) + { + return BufferBytesRemaining < byteCount && ShouldBufferSlow(); + + [MethodImpl(MethodImplOptions.NoInlining)] + bool ShouldBufferSlow() + { + if (byteCount > BufferSize) + ThrowArgumentOutOfRange(); + if (byteCount > CurrentRemaining) + ThrowArgumentOutOfRangeOfValue(); + + return true; + } + + static void ThrowArgumentOutOfRange() + => throw new ArgumentOutOfRangeException(nameof(byteCount), + "Buffer requirement is larger than the buffer size, this can never succeed by buffering data but requires a larger buffer size instead."); + static void ThrowArgumentOutOfRangeOfValue() + => throw new ArgumentOutOfRangeException(nameof(byteCount), + "Buffer requirement is larger than the remaining length of the value, make sure the value is always at least this size or use an upper bound requirement instead."); + } + + public void Buffer(Size bufferRequirement) + => Buffer(GetBufferRequirementByteCount(bufferRequirement)); + public void Buffer(int byteCount) => _buffer.Ensure(byteCount, async: false).GetAwaiter().GetResult(); + + public ValueTask BufferAsync(Size bufferRequirement, CancellationToken cancellationToken) + => BufferAsync(GetBufferRequirementByteCount(bufferRequirement), cancellationToken); + public ValueTask BufferAsync(int byteCount, CancellationToken cancellationToken) => new(_buffer.EnsureAsync(byteCount)); + + internal ValueTask Buffer(bool async, Size bufferRequirement, CancellationToken cancellationToken) + => Buffer(async, GetBufferRequirementByteCount(bufferRequirement), cancellationToken); + internal ValueTask Buffer(bool async, int byteCount, CancellationToken cancellationToken) + { + if (async) + return 
BufferAsync(byteCount, cancellationToken); + + Buffer(byteCount); + return new(); + } + + void ThrowNotConsumedExactly() => + throw _buffer.Connector.Break( + new InvalidOperationException( + FieldOffset < FieldSize + ? $"The read on this field has not consumed all of its bytes (pos: {FieldOffset}, len: {FieldSize})" + : $"The read on this field has consumed all of its bytes and read into the subsequent bytes (pos: {FieldOffset}, len: {FieldSize})")); +} + +public readonly struct NestedReadScope : IDisposable, IAsyncDisposable +{ + readonly PgReader _reader; + readonly int _previousSize; + readonly int _previousStartPos; + readonly Size _previousBufferRequirement; + readonly bool _async; + + internal NestedReadScope(bool async, PgReader reader, int previousSize, int previousStartPos, Size previousBufferRequirement) + { + _async = async; + _reader = reader; + _previousSize = previousSize; + _previousStartPos = previousStartPos; + _previousBufferRequirement = previousBufferRequirement; + } + + public void Dispose() + { + if (_async) + throw new InvalidOperationException("Cannot synchronously dispose async scopes, call DisposeAsync instead."); + DisposeAsync().GetAwaiter().GetResult(); + } + + public ValueTask DisposeAsync() + { + if (_reader.CurrentRemaining > 0) + { + if (_async) + return AsyncCore(_reader, _previousSize, _previousStartPos, _previousBufferRequirement); + + _reader.Consume(); + } + _reader.Revert(_previousSize, _previousStartPos, _previousBufferRequirement); + return new(); + + static async ValueTask AsyncCore(PgReader reader, int previousSize, int previousStartPos, Size previousBufferRequirement) + { + await reader.ConsumeAsync().ConfigureAwait(false); + reader.Revert(previousSize, previousStartPos, previousBufferRequirement); + } + } +} diff --git a/src/Npgsql/Internal/PgSerializerOptions.cs b/src/Npgsql/Internal/PgSerializerOptions.cs new file mode 100644 index 0000000000..5ee9077458 --- /dev/null +++ b/src/Npgsql/Internal/PgSerializerOptions.cs 
@@ -0,0 +1,146 @@ +using System; +using System.Runtime.CompilerServices; +using System.Text; +using Npgsql.Internal.Postgres; +using Npgsql.NameTranslation; +using Npgsql.PostgresTypes; + +namespace Npgsql.Internal; + +public sealed class PgSerializerOptions +{ + /// + /// Used by GetSchema to be able to attempt to resolve all type catalog types without exceptions. + /// + [field: ThreadStatic] + internal static bool IntrospectionCaller { get; set; } + + readonly Func? _timeZoneProvider; + readonly object _typeInfoCache; + + internal PgSerializerOptions(NpgsqlDatabaseInfo databaseInfo, Func? timeZoneProvider = null) + { + _timeZoneProvider = timeZoneProvider; + DatabaseInfo = databaseInfo; + UnknownPgType = databaseInfo.GetPostgresType("unknown"); + _typeInfoCache = PortableTypeIds ? new TypeInfoCache(this) : new TypeInfoCache(this); + } + + // Represents the 'unknown' type, which can be used for reading and writing arbitrary text values. + public PostgresType UnknownPgType { get; } + + // Used purely for type mapping, where we don't have a full set of types but resolvers might know enough. + readonly bool _introspectionInstance; + internal bool IntrospectionMode + { + get => _introspectionInstance || IntrospectionCaller; + init => _introspectionInstance = value; + } + + /// Whether options should return a portable identifier (data type name) to prevent any generated id (oid) confusion across backends, this comes with a perf penalty. + internal bool PortableTypeIds { get; init; } + internal NpgsqlDatabaseInfo DatabaseInfo { get; } + + public string TimeZone => _timeZoneProvider?.Invoke() ?? 
throw new NotSupportedException("TimeZone was not configured."); + public Encoding TextEncoding { get; init; } = Encoding.UTF8; + public required IPgTypeInfoResolver TypeInfoResolver { get; init; } + public bool EnableDateTimeInfinityConversions { get; init; } = true; + + public ArrayNullabilityMode ArrayNullabilityMode { get; init; } = ArrayNullabilityMode.Never; + public INpgsqlNameTranslator DefaultNameTranslator { get; init; } = NpgsqlSnakeCaseNameTranslator.Instance; + + public static Type[] WellKnownTextTypes { get; } = { + typeof(string), typeof(char[]), typeof(byte[]), + typeof(ArraySegment), typeof(ArraySegment?), + typeof(char), typeof(char?) + }; + + // We don't verify the kind of pgTypeId we get, it'll throw if it's incorrect. + // It's up to the caller to call GetCanonicalTypeId if they want to use an oid instead of a DataTypeName. + // This also makes it easier to realize it should be a cached value if infos for different CLR types are requested for the same + // pgTypeId. Effectively it should be 'impossible' to get the wrong kind via any PgConverterOptions api which is what this is mainly + // for. + PgTypeInfo? GetTypeInfoCore(Type? type, PgTypeId? pgTypeId, bool defaultTypeFallback) + => PortableTypeIds + ? Unsafe.As>(_typeInfoCache).GetOrAddInfo(type, pgTypeId?.DataTypeName, defaultTypeFallback) + : Unsafe.As>(_typeInfoCache).GetOrAddInfo(type, pgTypeId?.Oid, defaultTypeFallback); + + public PgTypeInfo? GetDefaultTypeInfo(PostgresType pgType) + => GetTypeInfoCore(null, ToCanonicalTypeId(pgType), false); + + public PgTypeInfo? GetDefaultTypeInfo(PgTypeId pgTypeId) + => GetTypeInfoCore(null, pgTypeId, false); + + public PgTypeInfo? GetTypeInfo(Type type, PostgresType pgType) + => GetTypeInfoCore(type, ToCanonicalTypeId(pgType), false); + + public PgTypeInfo? GetTypeInfo(Type type, PgTypeId? pgTypeId = null) + => GetTypeInfoCore(type, pgTypeId, false); + + public PgTypeInfo? 
GetObjectOrDefaultTypeInfo(PostgresType pgType) + => GetTypeInfoCore(typeof(object), ToCanonicalTypeId(pgType), true); + + public PgTypeInfo? GetObjectOrDefaultTypeInfo(PgTypeId pgTypeId) + => GetTypeInfoCore(typeof(object), pgTypeId, true); + + // If a given type id is in the opposite form than what was expected it will be mapped according to the requirement. + internal PgTypeId GetCanonicalTypeId(PgTypeId pgTypeId) + => PortableTypeIds ? DatabaseInfo.GetDataTypeName(pgTypeId) : DatabaseInfo.GetOid(pgTypeId); + + // If a given type id is in the opposite form than what was expected it will be mapped according to the requirement. + internal PgTypeId ToCanonicalTypeId(PostgresType pgType) + => PortableTypeIds ? pgType.DataTypeName : (Oid)pgType.OID; + + public PgTypeId GetArrayTypeId(PgTypeId elementTypeId) + { + // Static affordance to help the global type mapper. + if (PortableTypeIds && elementTypeId.IsDataTypeName) + return elementTypeId.DataTypeName.ToArrayName(); + + return ToCanonicalTypeId(DatabaseInfo.GetPostgresType(elementTypeId).Array + ?? throw new NotSupportedException("Cannot resolve array type id")); + } + + public PgTypeId GetArrayElementTypeId(PgTypeId arrayTypeId) + { + // Static affordance to help the global type mapper. + if (PortableTypeIds && arrayTypeId.IsDataTypeName && arrayTypeId.DataTypeName.UnqualifiedNameSpan.StartsWith("_".AsSpan(), StringComparison.Ordinal)) + return new DataTypeName(arrayTypeId.DataTypeName.Schema + arrayTypeId.DataTypeName.UnqualifiedNameSpan.Slice(1).ToString()); + + return ToCanonicalTypeId((DatabaseInfo.GetPostgresType(arrayTypeId) as PostgresArrayType)?.Element + ?? throw new NotSupportedException("Cannot resolve array element type id")); + } + + public PgTypeId GetRangeTypeId(PgTypeId subtypeTypeId) => + ToCanonicalTypeId(DatabaseInfo.GetPostgresType(subtypeTypeId).Range + ?? 
throw new NotSupportedException("Cannot resolve range type id")); + + public PgTypeId GetRangeSubtypeTypeId(PgTypeId rangeTypeId) => + ToCanonicalTypeId((DatabaseInfo.GetPostgresType(rangeTypeId) as PostgresRangeType)?.Subtype + ?? throw new NotSupportedException("Cannot resolve range subtype type id")); + + public PgTypeId GetMultirangeTypeId(PgTypeId rangeTypeId) => + ToCanonicalTypeId((DatabaseInfo.GetPostgresType(rangeTypeId) as PostgresRangeType)?.Multirange + ?? throw new NotSupportedException("Cannot resolve multirange type id")); + + public PgTypeId GetMultirangeElementTypeId(PgTypeId multirangeTypeId) => + ToCanonicalTypeId((DatabaseInfo.GetPostgresType(multirangeTypeId) as PostgresMultirangeType)?.Subrange + ?? throw new NotSupportedException("Cannot resolve multirange element type id")); + + public bool TryGetDataTypeName(PgTypeId pgTypeId, out DataTypeName dataTypeName) + { + if (DatabaseInfo.FindPostgresType(pgTypeId) is { } pgType) + { + dataTypeName = pgType.DataTypeName; + return true; + } + + dataTypeName = default; + return false; + } + + public DataTypeName GetDataTypeName(PgTypeId pgTypeId) + => !TryGetDataTypeName(pgTypeId, out var name) + ? 
throw new ArgumentException("Unknown type id", nameof(pgTypeId)) + : name; +} diff --git a/src/Npgsql/Internal/PgStreamingConverter.cs b/src/Npgsql/Internal/PgStreamingConverter.cs new file mode 100644 index 0000000000..09176f82d9 --- /dev/null +++ b/src/Npgsql/Internal/PgStreamingConverter.cs @@ -0,0 +1,87 @@ +using System; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace Npgsql.Internal; + +public abstract class PgStreamingConverter : PgConverter +{ + protected PgStreamingConverter(bool customDbNullPredicate = false) : base(customDbNullPredicate) { } + + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.None; + return format is DataFormat.Binary; + } + + internal sealed override unsafe ValueTask ReadAsObject( + bool async, PgReader reader, CancellationToken cancellationToken) + { + if (!async) + return new(Read(reader)!); + + var task = ReadAsync(reader, cancellationToken); + return task.IsCompletedSuccessfully + ? new(task.Result!) + : PgStreamingConverterHelpers.AwaitTask(task.AsTask(), new(this, &BoxResult)); + + static object BoxResult(Task task) + { + Debug.Assert(task is Task); + return new ValueTask(Unsafe.As>(task)).Result; + } + } + + internal sealed override ValueTask WriteAsObject(bool async, PgWriter writer, object value, CancellationToken cancellationToken) + { + if (async) + return WriteAsync(writer, (T)value, cancellationToken); + + Write(writer, (T)value); + return new(); + } +} + +// Using a function pointer here is safe against assembly unloading as the instance reference that the static pointer method lives on is +// passed along. As such the instance cannot be collected by the gc which means the entire assembly is prevented from unloading until we're +// done. +// The alternatives are: +// 1. 
Add a virtual method and make AwaitTask call into it (bloating the vtable of all derived types). +// 2. Using a delegate, meaning we add a static field + an alloc per T + metadata, slightly slower dispatch perf so overall strictly worse +// as well. +static class PgStreamingConverterHelpers +{ + // Split out from the generic class to amortize the huge size penalty per async state machine, which would otherwise be per + // instantiation. +#if NET6_0_OR_GREATER + [AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder<>))] +#endif + public static async ValueTask AwaitTask(Task task, Continuation continuation) + { + await task.ConfigureAwait(false); + var result = continuation.Invoke(task); + // Guarantee the type stays loaded until the function pointer call is done. + GC.KeepAlive(continuation.Handle); + return result; + } + + // Split out into a struct as unsafe and async don't mix, while we do want a nicely typed function pointer signature to prevent + // mistakes. + public readonly unsafe struct Continuation + { + public object Handle { get; } + readonly delegate* _continuation; + + /// A reference to the type that houses the static method points to. + /// The continuation + public Continuation(object handle, delegate* continuation) + { + Handle = handle; + _continuation = continuation; + } + + public object Invoke(Task task) => _continuation(task); + } +} diff --git a/src/Npgsql/Internal/PgTypeInfo.cs b/src/Npgsql/Internal/PgTypeInfo.cs new file mode 100644 index 0000000000..8b0dc22c2d --- /dev/null +++ b/src/Npgsql/Internal/PgTypeInfo.cs @@ -0,0 +1,362 @@ +using System; +using System.Diagnostics.CodeAnalysis; +using Npgsql.Internal.Postgres; + +namespace Npgsql.Internal; + +public class PgTypeInfo +{ + readonly bool _canBinaryConvert; + readonly BufferRequirements _binaryBufferRequirements; + + readonly bool _canTextConvert; + readonly BufferRequirements _textBufferRequirements; + + PgTypeInfo(PgSerializerOptions options, Type type, Type? 
unboxedType) + { + if (unboxedType is not null && !type.IsAssignableFrom(unboxedType)) + throw new ArgumentException("A value of unboxed type is not assignable to converter type", nameof(unboxedType)); + + Options = options; + IsBoxing = unboxedType is not null; + Type = unboxedType ?? type; + SupportsWriting = true; + } + + public PgTypeInfo(PgSerializerOptions options, PgConverter converter, PgTypeId pgTypeId, Type? unboxedType = null) + : this(options, converter.TypeToConvert, unboxedType) + { + Converter = converter; + PgTypeId = options.GetCanonicalTypeId(pgTypeId); + _canBinaryConvert = converter.CanConvert(DataFormat.Binary, out _binaryBufferRequirements); + _canTextConvert = converter.CanConvert(DataFormat.Text, out _textBufferRequirements); + } + + private protected PgTypeInfo(PgSerializerOptions options, Type type, PgConverterResolution? resolution, Type? unboxedType = null) + : this(options, type, unboxedType) + { + if (resolution is { } res) + { + // Resolutions should always be in canonical form already. + if (options.PortableTypeIds && res.PgTypeId.IsOid || !options.PortableTypeIds && res.PgTypeId.IsDataTypeName) + throw new ArgumentException("Given type id is not in canonical form. Make sure ConverterResolver implementations close over canonical ids, e.g. by calling options.GetCanonicalTypeId(pgTypeId) on the constructor arguments.", nameof(PgTypeId)); + + PgTypeId = res.PgTypeId; + Converter = res.Converter; + _canBinaryConvert = res.Converter.CanConvert(DataFormat.Binary, out _binaryBufferRequirements); + _canTextConvert = res.Converter.CanConvert(DataFormat.Text, out _textBufferRequirements); + } + } + + bool HasCachedInfo(PgConverter converter) => ReferenceEquals(Converter, converter); + + public Type Type { get; } + public PgSerializerOptions Options { get; } + + public bool SupportsWriting { get; init; } + public DataFormat? 
PreferredFormat { get; init; } + + // Doubles as the storage for the converter coming from a default resolution (used to confirm whether we can use cached info). + PgConverter? Converter { get; } + [MemberNotNullWhen(false, nameof(Converter))] + [MemberNotNullWhen(false, nameof(PgTypeId))] + internal bool IsResolverInfo => GetType() == typeof(PgResolverTypeInfo); + + // TODO pull validate from options + internal exempt for perf? + internal bool ValidateResolution => true; + + // Used for internal converters to save on binary bloat. + internal bool IsBoxing { get; } + + public PgTypeId? PgTypeId { get; } + + // Having it here so we can easily extend any behavior. + internal void DisposeWriteState(object writeState) + { + if (writeState is IDisposable disposable) + disposable.Dispose(); + } + + internal bool TryBind(Field field, DataFormat format, out PgConverterInfo info) + { + switch (this) + { + case { IsResolverInfo: false }: + // Type lies when IsBoxing is true. + var typeToConvert = IsBoxing ? typeof(object) : Type; + if (!CachedCanConvert(format, out var bufferRequirements)) + { + info = default; + return false; + } + info = CreateConverterInfo(bufferRequirements, isRead: true, Converter, typeToConvert); + return true; + case PgResolverTypeInfo resolverInfo: + var resolution = resolverInfo.GetResolution(field); + if (!HasCachedInfo(resolution.Converter) + ? !CachedCanConvert(format, out bufferRequirements) + : !resolution.Converter.CanConvert(format, out bufferRequirements)) + { + info = default; + return false; + } + info = CreateConverterInfo(bufferRequirements, isRead: true, resolution.Converter, resolution.Converter.TypeToConvert); + return true; + default: + throw new NotSupportedException("Should not happen, please file a bug."); + } + } + + // Bind for reading. 
+ internal PgConverterInfo Bind(Field field, DataFormat format) + { + if (!TryBind(field, format, out var info)) + ThrowHelper.ThrowInvalidOperationException($"Resolved converter does not support {format} format."); + + return info; + } + + public PgConverterResolution GetResolution(T? value) + { + // Other cases, to keep binary bloat minimal. + if (this is not PgResolverTypeInfo resolverInfo) + return GetObjectResolution(null); + var resolution = resolverInfo.GetResolution(value, null); + return resolution ?? resolverInfo.GetDefaultResolution(null); + } + + // Note: this api is not called GetResolutionAsObject as the semantics are extended, DBNull is a NULL value for all object values. + public PgConverterResolution GetObjectResolution(object? value) + { + switch (this) + { + case { IsResolverInfo: false }: + return new(Converter, PgTypeId.GetValueOrDefault()); + case PgResolverTypeInfo resolverInfo: + PgConverterResolution? resolution = null; + if (value is not DBNull) + resolution = resolverInfo.GetResolutionAsObject(value, null); + return resolution ?? resolverInfo.GetDefaultResolution(null); + default: + return ThrowNotSupported(); + } + + static PgConverterResolution ThrowNotSupported() + => throw new NotSupportedException("Should not happen, please file a bug."); + } + + /// Throws if the type info is undecided in its PgTypeId. 
+ internal PgConverterResolution GetConcreteResolution() + { + var pgTypeId = PgTypeId; + if (pgTypeId is null) + ThrowHelper.ThrowInvalidOperationException("PgTypeId is null."); + + return this switch + { + { IsResolverInfo: false } => new(Converter, pgTypeId.GetValueOrDefault()), + PgResolverTypeInfo resolverInfo => resolverInfo.GetDefaultResolution(null), + _ => ThrowNotSupported() + }; + + static PgConverterResolution ThrowNotSupported() + => throw new NotSupportedException("Should not happen, please file a bug."); + } + + PgConverterInfo CreateConverterInfo(BufferRequirements bufferRequirements, bool isRead, PgConverter converter, Type typeToConvert) + => new() + { + TypeInfo = this, + Converter = converter, + AsObject = Type != typeToConvert, + BufferRequirement = isRead ? bufferRequirements.Read : bufferRequirements.Write + }; + + bool CachedCanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + if (format is DataFormat.Binary) + { + bufferRequirements = _binaryBufferRequirements; + return _canBinaryConvert; + } + + bufferRequirements = _textBufferRequirements; + return _canTextConvert; + } + + public BufferRequirements? GetBufferRequirements(PgConverter converter, DataFormat format) + { + var success = HasCachedInfo(converter) + ? CachedCanConvert(format, out var bufferRequirements) + : converter.CanConvert(format, out bufferRequirements); + + return success ? bufferRequirements : null; + } + + // Bind for writing. + /// When result is null, the value was interpreted to be a SQL NULL. + internal PgConverterInfo? Bind(PgConverter converter, T? value, out Size size, out object? writeState, out DataFormat format, DataFormat? formatPreference = null) + { + // Basically exists to catch cases like object[] resolving a polymorphic read converter, better to fail during binding than writing. 
+ if (!SupportsWriting) + ThrowHelper.ThrowNotSupportedException($"Writing {Type} is not supported for this type info."); + + format = ResolveFormat(converter, out var bufferRequirements, formatPreference ?? PreferredFormat); + if (converter.IsDbNull(value)) + { + writeState = null; + size = default; + return null; + } + writeState = null; + var context = new SizeContext(format, bufferRequirements.Write); + size = bufferRequirements.Write is { Kind: SizeKind.Exact } req ? req : converter.GetSize(context, value, ref writeState); + + if (size is { Kind: SizeKind.Unknown}) + ThrowHelper.ThrowNotSupportedException($"Returning {nameof(Size.Unknown)} from {nameof(PgConverter.GetSize)} is not supported yet."); + + return new() + { + TypeInfo = this, + Converter = converter, + AsObject = IsBoxing, + BufferRequirement = bufferRequirements.Write, + }; + } + + // Bind for writing. + // Note: this api is not called BindAsObject as the semantics are extended, DBNull is a NULL value for all object values. + /// When result is null or DBNull, the value was interpreted to be a SQL NULL. + internal PgConverterInfo? BindObject(PgConverter converter, object? value, out Size size, out object? writeState, out DataFormat format, DataFormat? formatPreference = null) + { + // Basically exists to catch cases like object[] resolving a polymorphic read converter, better to fail during binding than writing. + if (!SupportsWriting) + throw new NotSupportedException($"Writing {Type} is not supported for this type info."); + + format = ResolveFormat(converter, out var bufferRequirements, formatPreference ?? PreferredFormat); + + // Given SQL values are effectively a union of T | NULL we support DBNull.Value to signify a NULL value for all types except DBNull in this api. 
+ if (value is DBNull && Type != typeof(DBNull) || converter.IsDbNullAsObject(value)) + { + writeState = null; + size = default; + return null; + } + writeState = null; + var context = new SizeContext(format, bufferRequirements.Write); + size = bufferRequirements.Write is { Kind: SizeKind.Exact } req ? req : converter.GetSizeAsObject(context, value, ref writeState); + + if (size is { Kind: SizeKind.Unknown}) + ThrowHelper.ThrowNotSupportedException($"Returning {nameof(Size.Unknown)} from {nameof(PgConverter.GetSizeAsObject)} is not supported yet."); + + return new() + { + TypeInfo = this, + Converter = converter, + AsObject = Type != typeof(object), + BufferRequirement = bufferRequirements.Write, + }; + } + + // If we don't have a converter stored we must ask the retrieved one. + DataFormat ResolveFormat(PgConverter converter, out BufferRequirements bufferRequirements, DataFormat? formatPreference = null) + { + switch (formatPreference) + { + // The common case, no preference means we default to binary if supported. + case null or DataFormat.Binary when HasCachedInfo(converter) ? CachedCanConvert(DataFormat.Binary, out bufferRequirements) : converter.CanConvert(DataFormat.Binary, out bufferRequirements): + return DataFormat.Binary; + // In this case we either prefer text or we have no preference and our converter doesn't support binary. + case null or DataFormat.Text: + var canTextConvert = HasCachedInfo(converter) ? CachedCanConvert(DataFormat.Text, out bufferRequirements) : converter.CanConvert(DataFormat.Text, out bufferRequirements); + if (!canTextConvert) + { + if (formatPreference is null) + throw new InvalidOperationException("Converter doesn't support any data format."); + // Rerun without preference. 
+ return ResolveFormat(converter, out bufferRequirements); + } + return DataFormat.Text; + default: + throw new ArgumentOutOfRangeException(); + } + } +} + +public sealed class PgResolverTypeInfo : PgTypeInfo +{ + internal readonly PgConverterResolver _converterResolver; + + public PgResolverTypeInfo(PgSerializerOptions options, PgConverterResolver converterResolver, PgTypeId? pgTypeId, Type? unboxedType = null) + : base(options, + converterResolver.TypeToConvert, + pgTypeId is { } typeId ? ResolveDefaultId(options, converterResolver, typeId) : null, + // We always mark resolvers with type object as boxing, as they may freely return converters for any type (see PgConverterResolver.Validate). + unboxedType ?? (converterResolver.TypeToConvert == typeof(object) ? typeof(object) : null)) + => _converterResolver = converterResolver; + + // We'll always validate the default resolution, the info will be re-used so there is no real downside. + static PgConverterResolution ResolveDefaultId(PgSerializerOptions options, PgConverterResolver converterResolver, PgTypeId typeId) + => converterResolver.GetDefaultInternal(validate: true, options.PortableTypeIds, options.GetCanonicalTypeId(typeId)); + + public PgConverterResolution? GetResolution(T? value, PgTypeId? expectedPgTypeId) + { + return _converterResolver is PgConverterResolver resolverT + ? resolverT.GetInternal(this, value, expectedPgTypeId ?? PgTypeId) + : ThrowNotSupportedType(typeof(T)); + + PgConverterResolution ThrowNotSupportedType(Type? type) + => throw new NotSupportedException(IsBoxing + ? "TypeInfo only supports boxing conversions, call GetResolutionAsObject instead." + : $"TypeInfo is not of type {type}"); + } + + public PgConverterResolution? GetResolutionAsObject(object? value, PgTypeId? expectedPgTypeId) + => _converterResolver.GetAsObjectInternal(this, value, expectedPgTypeId ?? 
PgTypeId); + + public PgConverterResolution GetResolution(Field field) + => _converterResolver.GetInternal(this, field); + + public PgConverterResolution GetDefaultResolution(PgTypeId? pgTypeId) + => _converterResolver.GetDefaultInternal(ValidateResolution, Options.PortableTypeIds, pgTypeId ?? PgTypeId); +} + +public readonly struct PgConverterResolution +{ + public PgConverterResolution(PgConverter converter, PgTypeId pgTypeId) + { + Converter = converter; + PgTypeId = pgTypeId; + } + + public PgConverter Converter { get; } + public PgTypeId PgTypeId { get; } + + public PgConverter GetConverter() => (PgConverter)Converter; +} + +readonly struct PgConverterInfo +{ + public bool IsDefault => TypeInfo is null; + + public Type TypeToConvert + { + get + { + // Object typed resolvers can return any type of converter, so we check the type of the converter instead. + // We cannot do this in general as we should respect the 'unboxed type' of infos, which can differ from the converter type. + if (TypeInfo.IsResolverInfo && TypeInfo.Type == typeof(object)) + return Converter.TypeToConvert; + + return TypeInfo.Type; + } + } + + public required PgTypeInfo TypeInfo { get; init; } + public required PgConverter Converter { get; init; } + public required Size BufferRequirement { get; init; } + // Whether Converter.TypeToConvert matches the PgTypeInfo.Type, if it doesn't object apis and a downcast should be used. 
/// <summary>Determines whether and how a PgWriter may flush its underlying writer.</summary>
enum FlushMode
{
    None,
    Blocking,
    NonBlocking
}

// A streaming alternative to a System.IO.Stream, instead based on the preferable IBufferWriter.
interface IStreamingWriter<T> : IBufferWriter<T>
{
    void Flush(TimeSpan timeout = default);
    ValueTask FlushAsync(CancellationToken cancellationToken = default);
}

/// <summary>IStreamingWriter implementation over an NpgsqlWriteBuffer.</summary>
sealed class NpgsqlBufferWriter : IStreamingWriter<byte>
{
    readonly NpgsqlWriteBuffer _buffer;
    // Size of the segment last handed out by GetMemory/GetSpan; cleared again once advanced over.
    int? _lastBufferSize;

    public NpgsqlBufferWriter(NpgsqlWriteBuffer buffer) => _buffer = buffer;

    public void Advance(int count)
    {
        // Note: a null _lastBufferSize (no preceding GetMemory/GetSpan) falls through to the write-space check.
        if (_lastBufferSize < count || _buffer.WriteSpaceLeft < count)
            throw new InvalidOperationException("Cannot advance past the end of the current buffer.");
        _lastBufferSize = null;
        _buffer.WritePosition += count;
    }

    public Memory<byte> GetMemory(int sizeHint = 0)
    {
        // NOTE(review): a sizeHint larger than the remaining write space throws rather than flushing/growing —
        // presumably callers are expected to flush first; confirm against PgWriter usage.
        if (sizeHint > _buffer.WriteSpaceLeft)
            throw new OutOfMemoryException("Not enough space left in buffer.");

        var bufferSize = _buffer.WriteSpaceLeft;
        _lastBufferSize = bufferSize;
        return _buffer.Buffer.AsMemory(_buffer.WritePosition, bufferSize);
    }

    public Span<byte> GetSpan(int sizeHint = 0)
    {
        if (sizeHint > _buffer.WriteSpaceLeft)
            throw new OutOfMemoryException("Not enough space left in buffer.");

        var bufferSize = _buffer.WriteSpaceLeft;
        _lastBufferSize = bufferSize;
        return _buffer.Buffer.AsSpan(_buffer.WritePosition, bufferSize);
    }

    public void Flush(TimeSpan timeout = default)
    {
        // Fast path: no timeout override requested.
        if (timeout == TimeSpan.Zero)
        {
            _buffer.Flush();
            return;
        }

        // Simplified from a nested `timeout != TimeSpan.Zero` re-check that was always true on this path.
        var originalTimeout = _buffer.Timeout;
        try
        {
            _buffer.Timeout = timeout;
            _buffer.Flush();
        }
        finally
        {
            _buffer.Timeout = originalTimeout;
        }
    }

    public ValueTask FlushAsync(CancellationToken cancellationToken = default)
        => new(_buffer.Flush(async: true, cancellationToken));
}

/// <summary>A buffered big-endian binary writer over an IBufferWriter, used by converters to write postgres wire values.</summary>
public sealed class PgWriter
{
    readonly IBufferWriter<byte> _writer;

    byte[]? _buffer;
    int _offset;
    int _pos;
    int _length;

    int _totalBytesWritten;

    ValueMetadata _current;
    NpgsqlDatabaseInfo? _typeCatalog;
_typeCatalog; + + internal PgWriter(IBufferWriter writer) => _writer = writer; + + internal PgWriter Init(NpgsqlDatabaseInfo typeCatalog) + { + if (_typeCatalog is not null) + throw new InvalidOperationException("Invalid concurrent use or PgWriter was not reset properly."); + + _typeCatalog = typeCatalog; + return this; + } + + internal void Reset() + { + if (_pos != _offset) + throw new InvalidOperationException("PgWriter still has uncommitted bytes."); + + _typeCatalog = null; + FlushMode = FlushMode.None; + _totalBytesWritten = 0; + ResetBuffer(); + } + + void ResetBuffer() + { + _buffer = null; + _pos = 0; + _offset = 0; + _length = 0; + } + + internal FlushMode FlushMode { get; private set; } + + internal PgWriter Refresh() + { + if (_buffer is not null) + ResetBuffer(); + return this; + } + + internal PgWriter WithFlushMode(FlushMode mode) + { + FlushMode = mode; + return this; + } + + // TODO if we're working on a normal buffer writer we should use normal Ensure (so commit and get another buffer) semantics. + void Ensure(int count = 1) + { + if (_buffer is null) + SetBuffer(); + + if (count > _length - _pos) + ThrowOutOfRange(); + + void ThrowOutOfRange() => throw new ArgumentOutOfRangeException(nameof(count), "Coud not ensure enough space in buffer."); + [MethodImpl(MethodImplOptions.NoInlining)] + void SetBuffer() + { + // GetMemory will check whether count is larger than the max buffer size. + var mem = _writer.GetMemory(count); + if (!MemoryMarshal.TryGetArray(mem, out var segment)) + throw new NotSupportedException("Only array backed writers are supported."); + + _buffer = segment.Array!; + _offset = segment.Offset; + _pos = segment.Offset; + _length = segment.Offset + segment.Count; + } + } + + Span Span => _buffer.AsSpan(_pos, _length - _pos); + + int Remaining + { + get + { + if (_buffer is null) + Ensure(count: 0); + return _length - _pos; + } + } + + void Advance(int count) => _pos += count; + + internal void Commit(int? 
expectedByteCount = null) + { + _totalBytesWritten += _pos - _offset; + _writer.Advance(_pos - _offset); + _offset = _pos; + + if (expectedByteCount is not null) + { + var totalBytesWritten = _totalBytesWritten; + _totalBytesWritten = 0; + if (totalBytesWritten != expectedByteCount) + throw new InvalidOperationException($"Bytes written ({totalBytesWritten}) and expected byte count ({expectedByteCount}) don't match."); + } + } + + internal ValueTask BeginWrite(bool async, ValueMetadata current, CancellationToken cancellationToken) + { + _current = current; + if (ShouldFlush(current.BufferRequirement)) + return Flush(async, cancellationToken); + + return new(); + } + + public ValueMetadata Current => _current; + internal Size CurrentBufferRequirement => _current.BufferRequirement; + + // When we don't know the size during writing we're using the writer buffer as a sizing mechanism. + internal bool BufferingWrite => Current.Size.Kind is SizeKind.Unknown; + + // This method lives here to remove the chances oids will be cached on converters inadvertently when data type names should be used. + // Such a mapping (for instance for array element oids) should be done per operation to ensure it is done in the context of a specific backend. 
+ public void WriteAsOid(PgTypeId pgTypeId) + { + var oid = _typeCatalog!.GetOid(pgTypeId); + WriteUInt32((uint)oid); + } + + public void WriteByte(byte value) + { + Ensure(sizeof(byte)); + Span[0] = value; + Advance(sizeof(byte)); + } + + public void WriteInt16(short value) + { + Ensure(sizeof(short)); + BinaryPrimitives.WriteInt16BigEndian(Span, value); + Advance(sizeof(short)); + } + + public void WriteInt32(int value) + { + Ensure(sizeof(int)); + BinaryPrimitives.WriteInt32BigEndian(Span, value); + Advance(sizeof(int)); + } + + public void WriteInt64(long value) + { + Ensure(sizeof(long)); + BinaryPrimitives.WriteInt64BigEndian(Span, value); + Advance(sizeof(long)); + } + + public void WriteUInt16(ushort value) + { + Ensure(sizeof(ushort)); + BinaryPrimitives.WriteUInt16BigEndian(Span, value); + Advance(sizeof(ushort)); + } + + public void WriteUInt32(uint value) + { + Ensure(sizeof(uint)); + BinaryPrimitives.WriteUInt32BigEndian(Span, value); + Advance(sizeof(uint)); + } + + public void WriteUInt64(ulong value) + { + Ensure(sizeof(ulong)); + BinaryPrimitives.WriteUInt64BigEndian(Span, value); + Advance(sizeof(ulong)); + } + + public void WriteFloat(float value) + { +#if NET5_0_OR_GREATER + Ensure(sizeof(float)); + BinaryPrimitives.WriteSingleBigEndian(Span, value); + Advance(sizeof(float)); +#else + WriteUInt32(Unsafe.As(ref value)); +#endif + } + + public void WriteDouble(double value) + { +#if NET5_0_OR_GREATER + Ensure(sizeof(double)); + BinaryPrimitives.WriteDoubleBigEndian(Span, value); + Advance(sizeof(double)); +#else + WriteUInt64(Unsafe.As(ref value)); +#endif + } + + public void WriteChars(ReadOnlySpan data, Encoding encoding) + { + // If we have more chars than bytes remaining we can immediately go to the slow path. + if (data.Length <= Remaining) + { + // If not, it's worth a shot to see if we can convert in one go. 
+ var encodedLength = encoding.GetByteCount(data); + if (!ShouldFlush(encodedLength)) + { + var count = encoding.GetBytes(data, Span); + Advance(count); + return; + } + } + Core(data, encoding); + + void Core(ReadOnlySpan data, Encoding encoding) + { + var encoder = encoding.GetEncoder(); + var minBufferSize = encoding.GetMaxByteCount(1); + + bool completed; + do + { + if (ShouldFlush(minBufferSize)) + Flush(); + Ensure(minBufferSize); + encoder.Convert(data, Span, flush: data.Length <= Span.Length, out var charsUsed, out var bytesUsed, out completed); + data = data.Slice(charsUsed); + Advance(bytesUsed); + } while (!completed); + } + } + + public ValueTask WriteCharsAsync(ReadOnlyMemory data, Encoding encoding, CancellationToken cancellationToken = default) + { + var dataSpan = data.Span; + // If we have more chars than bytes remaining we can immediately go to the slow path. + if (data.Length <= Remaining) + { + // If not, it's worth a shot to see if we can convert in one go. + var encodedLength = encoding.GetByteCount(dataSpan); + if (!ShouldFlush(encodedLength)) + { + var count = encoding.GetBytes(dataSpan, Span); + Advance(count); + return new(); + } + } + + return Core(data, encoding, cancellationToken); + + async ValueTask Core(ReadOnlyMemory data, Encoding encoding, CancellationToken cancellationToken) + { + var encoder = encoding.GetEncoder(); + var minBufferSize = encoding.GetMaxByteCount(1); + + bool completed; + do + { + if (ShouldFlush(minBufferSize)) + await FlushAsync(cancellationToken).ConfigureAwait(false); + Ensure(minBufferSize); + encoder.Convert(data.Span, Span, flush: data.Length <= Span.Length, out var charsUsed, out var bytesUsed, out completed); + data = data.Slice(charsUsed); + Advance(bytesUsed); + } while (!completed); + } + } + + public void WriteBytes(ReadOnlySpan buffer) + { + while (!buffer.IsEmpty) + { + var write = Math.Min(buffer.Length, Remaining); + buffer.Slice(0, write).CopyTo(Span); + Advance(write); + buffer = 
buffer.Slice(write); + if (Remaining is 0) + Flush(); + } + } + + public ValueTask WriteBytesAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) + { + if (buffer.Length <= Remaining) + { + buffer.Span.CopyTo(Span); + Advance(buffer.Length); + return new(); + } + + return Core(buffer, cancellationToken); + + async ValueTask Core(ReadOnlyMemory buffer, CancellationToken cancellationToken) + { + while (!buffer.IsEmpty) + { + var write = Math.Min(buffer.Length, Remaining); + buffer.Span.Slice(0, write).CopyTo(Span); + Advance(write); + buffer = buffer.Slice(write); + if (Remaining is 0) + await FlushAsync(cancellationToken).ConfigureAwait(false); + } + } + } + + public Stream GetStream() + => new PgWriterStream(this); + + public bool ShouldFlush(Size bufferRequirement) + => ShouldFlush(bufferRequirement is { Kind: SizeKind.UpperBound } + ? Math.Min(Current.Size.Value, bufferRequirement.Value) + : bufferRequirement.GetValueOrDefault()); + + public bool ShouldFlush(int byteCount) => Remaining < byteCount && FlushMode is not FlushMode.None; + + public void Flush(TimeSpan timeout = default) + { + switch (FlushMode) + { + case FlushMode.None: + return; + case FlushMode.NonBlocking: + throw new NotSupportedException($"Cannot call {nameof(Flush)} on a non-blocking {nameof(PgWriter)}, you might need to override {nameof(PgConverter.WriteAsync)} on {nameof(PgConverter)} if you want to call flush."); + } + + if (_writer is not IStreamingWriter writer) + throw new NotSupportedException($"Cannot call {nameof(Flush)} on a buffered {nameof(PgWriter)}, {nameof(FlushMode)}.{nameof(FlushMode.None)} should be used to prevent this."); + + Commit(); + ResetBuffer(); + writer.Flush(timeout); + } + + public ValueTask FlushAsync(CancellationToken cancellationToken = default) + { + switch (FlushMode) + { + case FlushMode.None: + return new(); + case FlushMode.Blocking: + throw new NotSupportedException($"Cannot call {nameof(FlushAsync)} on a blocking {nameof(PgWriter)}, 
call Flush instead."); + } + + if (_writer is not IStreamingWriter writer) + throw new NotSupportedException($"Cannot call {nameof(FlushAsync)} on a buffered {nameof(PgWriter)}, {nameof(FlushMode)}.{nameof(FlushMode.None)} should be used to prevent this."); + + Commit(); + ResetBuffer(); + return writer.FlushAsync(cancellationToken); + } + + internal ValueTask Flush(bool async, CancellationToken cancellationToken = default) + { + if (async) + return FlushAsync(cancellationToken); + + Flush(); + return new(); + } + + internal ValueTask BeginNestedWrite(bool async, Size bufferRequirement, int byteCount, object? state, CancellationToken cancellationToken) + { + Debug.Assert(bufferRequirement != -1); + if (ShouldFlush(bufferRequirement)) + return Core(async, bufferRequirement, byteCount, state, cancellationToken); + + _current = new() { Format = _current.Format, Size = byteCount, BufferRequirement = bufferRequirement, WriteState = state }; + + return new(new NestedWriteScope()); +#if NET6_0_OR_GREATER + [AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder<>))] +#endif + async ValueTask Core(bool async, Size bufferRequirement, int byteCount, object? state, CancellationToken cancellationToken) + { + await Flush(async, cancellationToken).ConfigureAwait(false); + _current = new() { Format = _current.Format, Size = byteCount, BufferRequirement = bufferRequirement, WriteState = state }; + return new(); + } + } + + public NestedWriteScope BeginNestedWrite(Size bufferRequirement, int byteCount, object? state) + => BeginNestedWrite(async: false, bufferRequirement, byteCount, state, CancellationToken.None).GetAwaiter().GetResult(); + + public ValueTask BeginNestedWriteAsync(Size bufferRequirement, int byteCount, object? 
state, CancellationToken cancellationToken = default) + => BeginNestedWrite(async: true, bufferRequirement, byteCount, state, cancellationToken); + + sealed class PgWriterStream : Stream + { + readonly PgWriter _writer; + + internal PgWriterStream(PgWriter writer) + => _writer = writer; + + public override void Write(byte[] buffer, int offset, int count) + => Write(async: false, buffer: buffer, offset: offset, count: count, CancellationToken.None).GetAwaiter().GetResult(); + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + => Write(async: true, buffer: buffer, offset: offset, count: count, cancellationToken: cancellationToken); + + Task Write(bool async, byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + if (buffer is null) + throw new ArgumentNullException(nameof(buffer)); + if (offset < 0) + throw new ArgumentNullException(nameof(offset)); + if (count < 0) + throw new ArgumentNullException(nameof(count)); + if (buffer.Length - offset < count) + throw new ArgumentException("Offset and length were out of bounds for the array or count is greater than the number of elements from index to the end of the source collection."); + + if (async) + { + if (cancellationToken.IsCancellationRequested) + return Task.FromCanceled(cancellationToken); + + return _writer.WriteBytesAsync(buffer, cancellationToken).AsTask(); + } + + _writer.WriteBytes(new Span(buffer, offset, count)); + return Task.CompletedTask; + } + +#if !NETSTANDARD2_0 + public override void Write(ReadOnlySpan buffer) => _writer.WriteBytes(buffer); + + public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) + { + if (cancellationToken.IsCancellationRequested) + return new(Task.FromCanceled(cancellationToken)); + + return _writer.WriteBytesAsync(buffer, cancellationToken); + } +#endif + + public override void Flush() + => _writer.Flush(); + + public override Task 
FlushAsync(CancellationToken cancellationToken) + => _writer.FlushAsync(cancellationToken).AsTask(); + + public override bool CanRead => false; + public override bool CanWrite => true; + public override bool CanSeek => false; + + public override int Read(byte[] buffer, int offset, int count) + => throw new NotSupportedException(); + + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + => throw new NotSupportedException(); + + public override long Length => throw new NotSupportedException(); + public override void SetLength(long value) + => throw new NotSupportedException(); + + public override long Position + { + get => throw new NotSupportedException(); + set => throw new NotSupportedException(); + } + public override long Seek(long offset, SeekOrigin origin) + => throw new NotSupportedException(); + } +} + +// No-op for now. +public struct NestedWriteScope : IDisposable +{ + public void Dispose() + { + } +} diff --git a/src/Npgsql/Internal/Postgres/DataTypeName.cs b/src/Npgsql/Internal/Postgres/DataTypeName.cs new file mode 100644 index 0000000000..2384ec723d --- /dev/null +++ b/src/Npgsql/Internal/Postgres/DataTypeName.cs @@ -0,0 +1,234 @@ +using System; +using System.Diagnostics; + +namespace Npgsql.Internal.Postgres; + +/// +/// Represents the fully-qualified name of a PostgreSQL type. +/// +[DebuggerDisplay("{DisplayName,nq}")] +public readonly struct DataTypeName : IEquatable +{ + /// + /// The maximum length of names in an unmodified PostgreSQL installation. + /// + /// + /// We need to respect this to get to valid names when deriving them (for multirange/arrays etc). + /// This does not include the namespace. + /// + const int NAMEDATALEN = 64 - 1; // Minus null terminator. 
/// <summary>
/// Represents the fully-qualified name of a PostgreSQL type.
/// </summary>
[DebuggerDisplay("{DisplayName,nq}")]
public readonly struct DataTypeName : IEquatable<DataTypeName>
{
    /// <summary>
    /// The maximum length of names in an unmodified PostgreSQL installation.
    /// </summary>
    /// <remarks>
    /// We need to respect this to get to valid names when deriving them (for multirange/arrays etc).
    /// This does not include the namespace.
    /// </remarks>
    const int NAMEDATALEN = 64 - 1; // Minus null terminator.

    readonly string _value;

    DataTypeName(string fullyQualifiedDataTypeName, bool validated)
    {
        if (!validated)
        {
            var schemaEndIndex = fullyQualifiedDataTypeName.IndexOf('.');
            if (schemaEndIndex == -1)
                throw new ArgumentException("Given value does not contain a schema.", nameof(fullyQualifiedDataTypeName));

            // Friendly array syntax is the only fully qualified name quirk that's allowed by postgres (see FromDisplayName).
            if (fullyQualifiedDataTypeName.AsSpan(schemaEndIndex).EndsWith("[]".AsSpan()))
                fullyQualifiedDataTypeName = NormalizeName(fullyQualifiedDataTypeName);

            // Fix: the unqualified name starts after the dot, so its length is Length - (schemaEndIndex + 1);
            // the previous `- schemaEndIndex + 1` overestimated by two, rejecting legal 62-63 char names.
            var typeNameLength = fullyQualifiedDataTypeName.Length - (schemaEndIndex + 1);
            if (typeNameLength > NAMEDATALEN)
                throw new ArgumentException(
                    $"Name is too long and would be truncated to: {fullyQualifiedDataTypeName.Substring(0, fullyQualifiedDataTypeName.Length - typeNameLength + NAMEDATALEN)}");
        }

        _value = fullyQualifiedDataTypeName;
    }

    public DataTypeName(string fullyQualifiedDataTypeName)
        : this(fullyQualifiedDataTypeName, validated: false) { }

    internal static DataTypeName ValidatedName(string fullyQualifiedDataTypeName)
        => new(fullyQualifiedDataTypeName, validated: true);

    // Includes schema unless it's pg_catalog.
    public string DisplayName =>
        Value.StartsWith("pg_catalog", StringComparison.Ordinal)
            ? UnqualifiedDisplayName
            : Schema + "." + UnqualifiedDisplayName;

    public string UnqualifiedDisplayName => ToDisplayName(UnqualifiedNameSpan);

    public string Schema => Value.Substring(0, _value.IndexOf('.'));
    internal ReadOnlySpan<char> UnqualifiedNameSpan => Value.AsSpan().Slice(_value.IndexOf('.') + 1);
    public string UnqualifiedName => Value.Substring(_value.IndexOf('.') + 1);
    public string Value => _value is null ? ThrowDefaultException() : _value;

    static string ThrowDefaultException() =>
        throw new InvalidOperationException($"This operation cannot be performed on a default instance of {nameof(DataTypeName)}.");

    public static implicit operator string(DataTypeName value) => value.Value;

    public bool IsDefault => _value is null;

    public bool IsArray => UnqualifiedNameSpan.StartsWith("_".AsSpan(), StringComparison.Ordinal);

    internal static DataTypeName CreateFullyQualifiedName(string dataTypeName)
        => dataTypeName.IndexOf('.') != -1 ? new(dataTypeName) : new("pg_catalog." + dataTypeName);

    // Static transform as defined by https://www.postgresql.org/docs/current/sql-createtype.html#SQL-CREATETYPE-ARRAY
    // We don't have to deal with [] as we're always starting from a normalized fully qualified name.
    public DataTypeName ToArrayName()
    {
        var unqualifiedNameSpan = UnqualifiedNameSpan;
        if (unqualifiedNameSpan.StartsWith("_".AsSpan(), StringComparison.Ordinal))
            return this;

        var unqualifiedName = unqualifiedNameSpan.ToString();
        if (unqualifiedName.Length + "_".Length > NAMEDATALEN)
            unqualifiedName = unqualifiedName.Substring(0, NAMEDATALEN - "_".Length);

        return new(Schema + "._" + unqualifiedName);
    }

    // Static transform as defined by https://www.postgresql.org/docs/current/sql-createtype.html#SQL-CREATETYPE-RANGE
    // Manual testing on PG confirmed it's only the first occurrence of 'range' that gets replaced.
    public DataTypeName ToDefaultMultirangeName()
    {
        var unqualifiedNameSpan = UnqualifiedNameSpan;
        if (UnqualifiedNameSpan.IndexOf("multirange".AsSpan(), StringComparison.Ordinal) != -1)
            return this;

        var unqualifiedName = unqualifiedNameSpan.ToString();
        var rangeIndex = unqualifiedName.IndexOf("range", StringComparison.Ordinal);
        if (rangeIndex != -1)
        {
            var str = unqualifiedName.Substring(0, rangeIndex) + "multirange" + unqualifiedName.Substring(rangeIndex + "range".Length);

            return new($"{Schema}." + (unqualifiedName.Length + "multi".Length > NAMEDATALEN
                ? str.Substring(0, NAMEDATALEN - "multi".Length)
                : str));
        }

        return new($"{Schema}." + (unqualifiedName.Length + "multi".Length > NAMEDATALEN
            ? unqualifiedName.Substring(0, NAMEDATALEN - "_multirange".Length) + "_multirange"
            : unqualifiedName + "_multirange"));
    }

    // Create a DataTypeName from a broader range of valid names.
    // including SQL aliases like 'timestamp without time zone', trailing facet info etc.
    public static DataTypeName FromDisplayName(string displayName, string? schema = null)
    {
        var displayNameSpan = displayName.AsSpan().Trim();

        // If we have a schema we're done, Postgres doesn't do display name conversions on fully qualified names.
        // There is one exception and that's array syntax, which is always resolvable in both ways, while we want the canonical name.
        // NOTE(review): Slice(schemaEndIndex) starts at the dot so StartsWith("_") can never match here — confirm intent (harmless
        // either way, as a canonical "_name" round-trips identically through the parse below).
        var schemaEndIndex = displayNameSpan.IndexOf('.');
        if (schemaEndIndex is not -1 &&
            !displayNameSpan.Slice(schemaEndIndex).StartsWith("_".AsSpan(), StringComparison.Ordinal) &&
            !displayNameSpan.EndsWith("[]".AsSpan(), StringComparison.Ordinal))
            return new(displayName);

        // First we strip the schema to get the type name.
        if (schemaEndIndex is not -1)
        {
            schema = displayNameSpan.Slice(0, schemaEndIndex).ToString();
            displayNameSpan = displayNameSpan.Slice(schemaEndIndex + 1);
        }

        // Then we strip either of the two valid array representations to get the base type name (with or without facets).
        var isArray = false;
        if (displayNameSpan.StartsWith("_".AsSpan()))
        {
            isArray = true;
            displayNameSpan = displayNameSpan.Slice(1);
        }
        else if (displayNameSpan.EndsWith("[]".AsSpan()))
        {
            isArray = true;
            displayNameSpan = displayNameSpan.Slice(0, displayNameSpan.Length - 2);
        }

        string mapped;
        if (schemaEndIndex is -1)
        {
            // Finally we strip the facet info.
            var parenIndex = displayNameSpan.IndexOf('(');
            if (parenIndex > -1)
                displayNameSpan = displayNameSpan.Slice(0, parenIndex);

            // Map any aliases to the internal type name.
            mapped = displayNameSpan.ToString() switch
            {
                "boolean" => "bool",
                "character" => "bpchar",
                "decimal" => "numeric",
                "real" => "float4",
                "double precision" => "float8",
                "smallint" => "int2",
                "integer" => "int4",
                "bigint" => "int8",
                "time without time zone" => "time",
                "timestamp without time zone" => "timestamp",
                "time with time zone" => "timetz",
                "timestamp with time zone" => "timestamptz",
                "bit varying" => "varbit",
                "character varying" => "varchar",
                var value => value
            };
        }
        else
        {
            // If we had a schema originally we stop here, see comment at schemaEndIndex.
            mapped = displayNameSpan.ToString();
        }

        return new((schema ?? "pg_catalog") + "." + (isArray ? "_" : "") + mapped);
    }

    // The type names stored in a DataTypeName are usually the actual typname from the pg_type column.
    // There are some canonical aliases defined in the SQL standard which we take into account.
    // Additionally array types have a '_' prefix while for readability their element type should be postfixed with '[]'.
    // See the table for all the aliases https://www.postgresql.org/docs/current/static/datatype.html#DATATYPE-TABLE
    // Alternatively some of the source lives at https://github.com/postgres/postgres/blob/c8e1ba736b2b9e8c98d37a5b77c4ed31baf94147/src/backend/utils/adt/format_type.c#L186
    static string ToDisplayName(ReadOnlySpan<char> unqualifiedName)
    {
        var isArray = unqualifiedName.IndexOf('_') == 0;
        var baseTypeName = isArray ? unqualifiedName.Slice(1).ToString() : unqualifiedName.ToString();

        var mappedBaseType = baseTypeName switch
        {
            "bool" => "boolean",
            "bpchar" => "character",
            "decimal" => "numeric",
            "float4" => "real",
            "float8" => "double precision",
            "int2" => "smallint",
            "int4" => "integer",
            "int8" => "bigint",
            "time" => "time without time zone",
            "timestamp" => "timestamp without time zone",
            "timetz" => "time with time zone",
            "timestamptz" => "timestamp with time zone",
            "varbit" => "bit varying",
            "varchar" => "character varying",
            _ => baseTypeName
        };

        if (isArray)
            return mappedBaseType + "[]";

        return mappedBaseType;
    }

    internal static bool IsFullyQualified(ReadOnlySpan<char> dataTypeName) => dataTypeName.Contains(".".AsSpan(), StringComparison.Ordinal);

    internal static string NormalizeName(string dataTypeName)
    {
        var fqName = FromDisplayName(dataTypeName);
        return IsFullyQualified(dataTypeName.AsSpan()) ? fqName.Value : fqName.UnqualifiedName;
    }

    public override string ToString() => Value;
    public bool Equals(DataTypeName other) => !IsDefault && !other.IsDefault && _value == other._value;
    public override bool Equals(object? obj) => obj is DataTypeName other && Equals(other);
    // Fix: a default instance has a null _value; previously this threw NullReferenceException.
    public override int GetHashCode() => _value?.GetHashCode() ?? 0;
    public static bool operator ==(DataTypeName left, DataTypeName right) => left.Equals(right);
    public static bool operator !=(DataTypeName left, DataTypeName right) => !left.Equals(right);
}
/// <summary>
/// Well-known PostgreSQL data type names.
/// </summary>
static class DataTypeNames
{
    // Note: The names are fully qualified in source so the strings are constants and instances will be interned after the first call.
    // Uses an internal constructor bypassing the public DataTypeName constructor validation, as we don't want to store all these names on
    // fields either.

    // Numeric types.
    public static DataTypeName Int2 => ValidatedName("pg_catalog.int2");
    public static DataTypeName Int4 => ValidatedName("pg_catalog.int4");
    public static DataTypeName Int4Range => ValidatedName("pg_catalog.int4range");
    public static DataTypeName Int4Multirange => ValidatedName("pg_catalog.int4multirange");
    public static DataTypeName Int8 => ValidatedName("pg_catalog.int8");
    public static DataTypeName Int8Range => ValidatedName("pg_catalog.int8range");
    public static DataTypeName Int8Multirange => ValidatedName("pg_catalog.int8multirange");
    public static DataTypeName Float4 => ValidatedName("pg_catalog.float4");
    public static DataTypeName Float8 => ValidatedName("pg_catalog.float8");
    public static DataTypeName Numeric => ValidatedName("pg_catalog.numeric");
    public static DataTypeName NumRange => ValidatedName("pg_catalog.numrange");
    public static DataTypeName NumMultirange => ValidatedName("pg_catalog.nummultirange");
    public static DataTypeName Money => ValidatedName("pg_catalog.money");
    // Boolean and geometric types.
    public static DataTypeName Bool => ValidatedName("pg_catalog.bool");
    public static DataTypeName Box => ValidatedName("pg_catalog.box");
    public static DataTypeName Circle => ValidatedName("pg_catalog.circle");
    public static DataTypeName Line => ValidatedName("pg_catalog.line");
    public static DataTypeName LSeg => ValidatedName("pg_catalog.lseg");
    public static DataTypeName Path => ValidatedName("pg_catalog.path");
    public static DataTypeName Point => ValidatedName("pg_catalog.point");
    public static DataTypeName Polygon => ValidatedName("pg_catalog.polygon");
    // Character and binary types.
    public static DataTypeName Bpchar => ValidatedName("pg_catalog.bpchar");
    public static DataTypeName Text => ValidatedName("pg_catalog.text");
    public static DataTypeName Varchar => ValidatedName("pg_catalog.varchar");
    public static DataTypeName Char => ValidatedName("pg_catalog.char");
    public static DataTypeName Name => ValidatedName("pg_catalog.name");
    public static DataTypeName Bytea => ValidatedName("pg_catalog.bytea");
    // Date/time types.
    public static DataTypeName Date => ValidatedName("pg_catalog.date");
    public static DataTypeName DateRange => ValidatedName("pg_catalog.daterange");
    public static DataTypeName DateMultirange => ValidatedName("pg_catalog.datemultirange");
    public static DataTypeName Time => ValidatedName("pg_catalog.time");
    public static DataTypeName Timestamp => ValidatedName("pg_catalog.timestamp");
    public static DataTypeName TsRange => ValidatedName("pg_catalog.tsrange");
    public static DataTypeName TsMultirange => ValidatedName("pg_catalog.tsmultirange");
    public static DataTypeName TimestampTz => ValidatedName("pg_catalog.timestamptz");
    public static DataTypeName TsTzRange => ValidatedName("pg_catalog.tstzrange");
    public static DataTypeName TsTzMultirange => ValidatedName("pg_catalog.tstzmultirange");
    public static DataTypeName Interval => ValidatedName("pg_catalog.interval");
    public static DataTypeName TimeTz => ValidatedName("pg_catalog.timetz");
    // Network address types.
    public static DataTypeName Inet => ValidatedName("pg_catalog.inet");
    public static DataTypeName Cidr => ValidatedName("pg_catalog.cidr");
    public static DataTypeName MacAddr => ValidatedName("pg_catalog.macaddr");
    public static DataTypeName MacAddr8 => ValidatedName("pg_catalog.macaddr8");
    // Bit string and text search types.
    public static DataTypeName Bit => ValidatedName("pg_catalog.bit");
    public static DataTypeName Varbit => ValidatedName("pg_catalog.varbit");
    public static DataTypeName TsVector => ValidatedName("pg_catalog.tsvector");
    public static DataTypeName TsQuery => ValidatedName("pg_catalog.tsquery");
    public static DataTypeName RegConfig => ValidatedName("pg_catalog.regconfig");
    // Document and misc types.
    public static DataTypeName Uuid => ValidatedName("pg_catalog.uuid");
    public static DataTypeName Xml => ValidatedName("pg_catalog.xml");
    public static DataTypeName Json => ValidatedName("pg_catalog.json");
    public static DataTypeName Jsonb => ValidatedName("pg_catalog.jsonb");
    public static DataTypeName Jsonpath => ValidatedName("pg_catalog.jsonpath");
    public static DataTypeName Record => ValidatedName("pg_catalog.record");
    public static DataTypeName RefCursor => ValidatedName("pg_catalog.refcursor");
    // Internal/system types.
    public static DataTypeName OidVector => ValidatedName("pg_catalog.oidvector");
    public static DataTypeName Int2Vector => ValidatedName("pg_catalog.int2vector");
    public static DataTypeName Oid => ValidatedName("pg_catalog.oid");
    public static DataTypeName Xid => ValidatedName("pg_catalog.xid");
    public static DataTypeName Xid8 => ValidatedName("pg_catalog.xid8");
    public static DataTypeName Cid => ValidatedName("pg_catalog.cid");
    public static DataTypeName RegType => ValidatedName("pg_catalog.regtype");
    public static DataTypeName Tid => ValidatedName("pg_catalog.tid");
    public static DataTypeName PgLsn => ValidatedName("pg_catalog.pg_lsn");
    public static DataTypeName Unknown => ValidatedName("pg_catalog.unknown");
    public static DataTypeName Void => ValidatedName("pg_catalog.void");
}
+public readonly struct Field +{ + public Field(string name, PgTypeId pgTypeId, int typeModifier) + { + Name = name; + PgTypeId = pgTypeId; + TypeModifier = typeModifier; + } + + public string Name { get; init; } + public PgTypeId PgTypeId { get; init; } + public int TypeModifier { get; init; } +} diff --git a/src/Npgsql/Internal/Postgres/Oid.cs b/src/Npgsql/Internal/Postgres/Oid.cs new file mode 100644 index 0000000000..ac9577609d --- /dev/null +++ b/src/Npgsql/Internal/Postgres/Oid.cs @@ -0,0 +1,19 @@ +using System; + +namespace Npgsql.Internal.Postgres; + +public readonly struct Oid: IEquatable +{ + public Oid(uint value) => Value = value; + + public static explicit operator uint(Oid oid) => oid.Value; + public static implicit operator Oid(uint oid) => new(oid); + public uint Value { get; init; } + + public override string ToString() => Value.ToString(); + public bool Equals(Oid other) => Value == other.Value; + public override bool Equals(object? obj) => obj is Oid other && Equals(other); + public override int GetHashCode() => (int)Value; + public static bool operator ==(Oid left, Oid right) => left.Equals(right); + public static bool operator !=(Oid left, Oid right) => !left.Equals(right); +} diff --git a/src/Npgsql/Internal/Postgres/PgTypeId.cs b/src/Npgsql/Internal/Postgres/PgTypeId.cs new file mode 100644 index 0000000000..e363969a47 --- /dev/null +++ b/src/Npgsql/Internal/Postgres/PgTypeId.cs @@ -0,0 +1,44 @@ +using System; +using System.Diagnostics.CodeAnalysis; + +namespace Npgsql.Internal.Postgres; + +/// +/// A discriminated union of and . 
+/// +public readonly struct PgTypeId: IEquatable +{ + readonly DataTypeName _dataTypeName; + readonly Oid _oid; + + public PgTypeId(DataTypeName name) => _dataTypeName = name; + public PgTypeId(Oid oid) => _oid = oid; + + [MemberNotNullWhen(true, nameof(_dataTypeName))] + public bool IsDataTypeName => !_dataTypeName.IsDefault; + public bool IsOid => _dataTypeName.IsDefault; + + public DataTypeName DataTypeName + => IsDataTypeName ? _dataTypeName : throw new InvalidOperationException("This value does not describe a DataTypeName."); + + public Oid Oid + => IsOid ? _oid : throw new InvalidOperationException("This value does not describe an Oid."); + + public static implicit operator PgTypeId(DataTypeName name) => new(name); + public static implicit operator PgTypeId(Oid id) => new(id); + + public override string ToString() => IsOid ? _oid.ToString() : _dataTypeName.Value; + + public bool Equals(PgTypeId other) + => (this, other) switch + { + ({ IsOid: true }, { IsOid: true }) => _oid == other._oid, + ({ IsDataTypeName: true }, { IsDataTypeName: true }) => _dataTypeName.Equals(other._dataTypeName), + _ => false + }; + + public override bool Equals(object? obj) => obj is PgTypeId other && Equals(other); + public override int GetHashCode() => IsOid ? 
_oid.GetHashCode() : _dataTypeName.GetHashCode(); + public static bool operator ==(PgTypeId left, PgTypeId right) => left.Equals(right); + public static bool operator !=(PgTypeId left, PgTypeId right) => !left.Equals(right); +} diff --git a/src/Npgsql/Internal/Resolvers/AdoTypeInfoResolver.cs b/src/Npgsql/Internal/Resolvers/AdoTypeInfoResolver.cs new file mode 100644 index 0000000000..0f2c077aad --- /dev/null +++ b/src/Npgsql/Internal/Resolvers/AdoTypeInfoResolver.cs @@ -0,0 +1,491 @@ +using System; +using System.Collections; +using System.Collections.Generic; +using System.Collections.Specialized; +using System.Diagnostics; +using System.IO; +using Npgsql.Internal.Converters; +using Npgsql.Internal.Converters.Internal; +using Npgsql.Internal.Postgres; +using Npgsql.PostgresTypes; +using Npgsql.Util; +using NpgsqlTypes; + +namespace Npgsql.Internal.Resolvers; + +// Baseline types that are always supported. +class AdoTypeInfoResolver : IPgTypeInfoResolver +{ + public AdoTypeInfoResolver() + { + Mappings = new TypeInfoMappingCollection(); + AddInfos(Mappings); + } + + public static AdoTypeInfoResolver Instance { get; } = new(); + + protected TypeInfoMappingCollection Mappings { get; } + + public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + { + var info = Mappings.Find(type, dataTypeName, options); + if (info is null && dataTypeName is not null) + info = GetEnumTypeInfo(type, dataTypeName.GetValueOrDefault(), options); + return info; + } + + protected static PgTypeInfo? GetEnumTypeInfo(Type? 
type, DataTypeName dataTypeName, PgSerializerOptions options) + { + if (type is not null && type != typeof(string)) + return null; + + if (options.DatabaseInfo.GetPostgresType(dataTypeName) is not PostgresEnumType) + return null; + + return new PgTypeInfo(options, new StringTextConverter(options.TextEncoding), dataTypeName); + } + + static void AddInfos(TypeInfoMappingCollection mappings) + { + // Bool + mappings.AddStructType(DataTypeNames.Bool, + static (options, mapping, _) => mapping.CreateInfo(options, new BoolConverter()), isDefault: true); + + // Numeric + mappings.AddStructType(DataTypeNames.Int2, + static (options, mapping, _) => mapping.CreateInfo(options, new Int2Converter()), isDefault: true); + mappings.AddStructType(DataTypeNames.Int4, + static (options, mapping, _) => mapping.CreateInfo(options, new Int4Converter()), isDefault: true); + mappings.AddStructType(DataTypeNames.Int8, + static (options, mapping, _) => mapping.CreateInfo(options, new Int8Converter()), isDefault: true); + mappings.AddStructType(DataTypeNames.Float4, + static (options, mapping, _) => mapping.CreateInfo(options, new RealConverter()), isDefault: true); + mappings.AddStructType(DataTypeNames.Float8, + static (options, mapping, _) => mapping.CreateInfo(options, new DoubleConverter()), isDefault: true); + mappings.AddStructType(DataTypeNames.Numeric, + static (options, mapping, _) => mapping.CreateInfo(options, new DecimalNumericConverter()), isDefault: true); + mappings.AddStructType(DataTypeNames.Money, + static (options, mapping, _) => mapping.CreateInfo(options, new MoneyConverter()), MatchRequirement.DataTypeName); + + // Text + mappings.AddType(DataTypeNames.Text, + static (options, mapping, _) => mapping.CreateInfo(options, new StringTextConverter(options.TextEncoding), preferredFormat: DataFormat.Text), isDefault: true); + mappings.AddStructType(DataTypeNames.Text, + static (options, mapping, _) => mapping.CreateInfo(options, new CharTextConverter(options.TextEncoding), 
preferredFormat: DataFormat.Text)); + // Uses the bytea converters, as neither type has a header. + mappings.AddType(DataTypeNames.Text, + static (options, mapping, _) => mapping.CreateInfo(options, new ArrayByteaConverter()), + MatchRequirement.DataTypeName); + mappings.AddStructType>(DataTypeNames.Text, + static (options, mapping, _) => mapping.CreateInfo(options, new ReadOnlyMemoryByteaConverter()), + MatchRequirement.DataTypeName); + //Special mappings, these have no corresponding array mapping. + mappings.AddType(DataTypeNames.Text, + static (options, mapping, _) => mapping.CreateInfo(options, new TextReaderTextConverter(options.TextEncoding), supportsWriting: false, preferredFormat: DataFormat.Text), + MatchRequirement.DataTypeName); + mappings.AddStructType(DataTypeNames.Text, + static (options, mapping, _) => mapping.CreateInfo(options, new GetCharsTextConverter(options.TextEncoding), supportsWriting: false, preferredFormat: DataFormat.Text), + MatchRequirement.DataTypeName); + + // Alternative text types + foreach(var dataTypeName in new[] { "citext", DataTypeNames.Varchar, + DataTypeNames.Bpchar, DataTypeNames.Json, + DataTypeNames.Xml, DataTypeNames.Name, DataTypeNames.RefCursor }) + { + mappings.AddType(dataTypeName, + static (options, mapping, _) => mapping.CreateInfo(options, new StringTextConverter(options.TextEncoding), preferredFormat: DataFormat.Text), + MatchRequirement.DataTypeName); + mappings.AddStructType(dataTypeName, + static (options, mapping, _) => mapping.CreateInfo(options, new CharTextConverter(options.TextEncoding), preferredFormat: DataFormat.Text), + MatchRequirement.DataTypeName); + // Uses the bytea converters, as neither type has a header. 
+ mappings.AddType(dataTypeName, + static (options, mapping, _) => mapping.CreateInfo(options, new ArrayByteaConverter()), + MatchRequirement.DataTypeName); + mappings.AddStructType>(dataTypeName, + static (options, mapping, _) => mapping.CreateInfo(options, new ReadOnlyMemoryByteaConverter()), + MatchRequirement.DataTypeName); + //Special mappings, these have no corresponding array mapping. + mappings.AddType(dataTypeName, + static (options, mapping, _) => mapping.CreateInfo(options, new TextReaderTextConverter(options.TextEncoding), supportsWriting: false, preferredFormat: DataFormat.Text), + MatchRequirement.DataTypeName); + mappings.AddStructType(dataTypeName, + static (options, mapping, _) => mapping.CreateInfo(options, new GetCharsTextConverter(options.TextEncoding), supportsWriting: false, preferredFormat: DataFormat.Text), + MatchRequirement.DataTypeName); + } + + // Jsonb + const byte jsonbVersion = 1; + mappings.AddType(DataTypeNames.Jsonb, + static (options, mapping, _) => mapping.CreateInfo(options, new VersionPrefixedTextConverter(jsonbVersion, new StringTextConverter(options.TextEncoding))), isDefault: true); + mappings.AddStructType(DataTypeNames.Jsonb, + static (options, mapping, _) => mapping.CreateInfo(options, new VersionPrefixedTextConverter(jsonbVersion, new CharTextConverter(options.TextEncoding)))); + mappings.AddType(DataTypeNames.Jsonb, + static (options, mapping, _) => mapping.CreateInfo(options, new VersionPrefixedTextConverter(jsonbVersion, new ArrayByteaConverter())), + MatchRequirement.DataTypeName); + mappings.AddStructType>(DataTypeNames.Jsonb, + static (options, mapping, _) => mapping.CreateInfo(options, new VersionPrefixedTextConverter>(jsonbVersion, new ReadOnlyMemoryByteaConverter())), + MatchRequirement.DataTypeName); + //Special mappings, these have no corresponding array mapping. 
+ mappings.AddType(DataTypeNames.Jsonb, + static (options, mapping, _) => mapping.CreateInfo(options, new VersionPrefixedTextConverter(jsonbVersion, new TextReaderTextConverter(options.TextEncoding)), supportsWriting: false, preferredFormat: DataFormat.Text), + MatchRequirement.DataTypeName); + mappings.AddStructType(DataTypeNames.Jsonb, + static (options, mapping, _) => mapping.CreateInfo(options, new VersionPrefixedTextConverter(jsonbVersion, new GetCharsTextConverter(options.TextEncoding)), supportsWriting: false, preferredFormat: DataFormat.Text), + MatchRequirement.DataTypeName); + + // Jsonpath + const byte jsonpathVersion = 1; + mappings.AddType(DataTypeNames.Jsonpath, + static (options, mapping, _) => mapping.CreateInfo(options, new VersionPrefixedTextConverter(jsonpathVersion, new StringTextConverter(options.TextEncoding))), isDefault: true); + //Special mappings, these have no corresponding array mapping. + mappings.AddType(DataTypeNames.Jsonpath, + static (options, mapping, _) => mapping.CreateInfo(options, new VersionPrefixedTextConverter(jsonpathVersion, new TextReaderTextConverter(options.TextEncoding)), supportsWriting: false, preferredFormat: DataFormat.Text), + MatchRequirement.DataTypeName); + mappings.AddStructType(DataTypeNames.Jsonpath, + static (options, mapping, _) => mapping.CreateInfo(options, new VersionPrefixedTextConverter(jsonpathVersion, new GetCharsTextConverter(options.TextEncoding)), supportsWriting: false, preferredFormat: DataFormat.Text), + MatchRequirement.DataTypeName); + + // Bytea + mappings.AddType(DataTypeNames.Bytea, + static (options, mapping, _) => mapping.CreateInfo(options, new ArrayByteaConverter()), isDefault: true); + mappings.AddStructType>(DataTypeNames.Bytea, + static (options, mapping, _) => mapping.CreateInfo(options, new ReadOnlyMemoryByteaConverter())); + mappings.AddType(DataTypeNames.Bytea, + static (options, mapping, _) => new PgTypeInfo(options, new StreamByteaConverter(), new 
DataTypeName(mapping.DataTypeName), unboxedType: mapping.Type), + mapping => mapping with { TypeMatchPredicate = type => typeof(Stream).IsAssignableFrom(type) }); + + // Varbit + mappings.AddType(DataTypeNames.Varbit, + static (options, mapping, _) => mapping.CreateInfo(options, + new PolymorphicBitStringConverterResolver(options.GetCanonicalTypeId(DataTypeNames.Varbit)), supportsWriting: false)); + mappings.AddType(DataTypeNames.Varbit, + static (options, mapping, _) => mapping.CreateInfo(options, new BitArrayBitStringConverter()), isDefault: true); + mappings.AddStructType(DataTypeNames.Varbit, + static (options, mapping, _) => mapping.CreateInfo(options, new BoolBitStringConverter())); + mappings.AddStructType(DataTypeNames.Varbit, + static (options, mapping, _) => mapping.CreateInfo(options, new BitVector32BitStringConverter())); + + // Bit + mappings.AddType(DataTypeNames.Bit, + static (options, mapping, _) => mapping.CreateInfo(options, + new PolymorphicBitStringConverterResolver(options.GetCanonicalTypeId(DataTypeNames.Bit)), supportsWriting: false)); + mappings.AddType(DataTypeNames.Bit, + static (options, mapping, _) => mapping.CreateInfo(options, new BitArrayBitStringConverter()), isDefault: true); + mappings.AddStructType(DataTypeNames.Bit, + static (options, mapping, _) => mapping.CreateInfo(options, new BoolBitStringConverter())); + mappings.AddStructType(DataTypeNames.Bit, + static (options, mapping, _) => mapping.CreateInfo(options, new BitVector32BitStringConverter())); + + // Timestamp + if (Statics.LegacyTimestampBehavior) + { + mappings.AddStructType(DataTypeNames.Timestamp, + static (options, mapping, _) => mapping.CreateInfo(options, + new LegacyDateTimeConverter(options.EnableDateTimeInfinityConversions, timestamp: true)), isDefault: true); + } + else + { + mappings.AddResolverStructType(DataTypeNames.Timestamp, + static (options, mapping, dataTypeNameMatch) => mapping.CreateInfo(options, + DateTimeConverterResolver.CreateResolver(options, 
options.GetCanonicalTypeId(DataTypeNames.TimestampTz), options.GetCanonicalTypeId(DataTypeNames.Timestamp), + options.EnableDateTimeInfinityConversions), dataTypeNameMatch), isDefault: true); + } + mappings.AddStructType(DataTypeNames.Timestamp, + static (options, mapping, _) => mapping.CreateInfo(options, new Int8Converter())); + + // TimestampTz + if (Statics.LegacyTimestampBehavior) + { + mappings.AddStructType(DataTypeNames.TimestampTz, + static (options, mapping, _) => mapping.CreateInfo(options, + new LegacyDateTimeConverter(options.EnableDateTimeInfinityConversions, timestamp: false)), matchRequirement: MatchRequirement.DataTypeName); + mappings.AddStructType(DataTypeNames.TimestampTz, + static (options, mapping, _) => mapping.CreateInfo(options, new LegacyDateTimeOffsetConverter(options.EnableDateTimeInfinityConversions))); + } + else + { + mappings.AddResolverStructType(DataTypeNames.TimestampTz, + static (options, mapping, dataTypeNameMatch) => mapping.CreateInfo(options, + DateTimeConverterResolver.CreateResolver(options, options.GetCanonicalTypeId(DataTypeNames.TimestampTz), options.GetCanonicalTypeId(DataTypeNames.Timestamp), + options.EnableDateTimeInfinityConversions), dataTypeNameMatch), isDefault: true); + mappings.AddStructType(DataTypeNames.TimestampTz, + static (options, mapping, _) => mapping.CreateInfo(options, new DateTimeOffsetConverter(options.EnableDateTimeInfinityConversions))); + } + mappings.AddStructType(DataTypeNames.TimestampTz, + static (options, mapping, _) => mapping.CreateInfo(options, new Int8Converter())); + + // Date + mappings.AddStructType(DataTypeNames.Date, + static (options, mapping, _) => + mapping.CreateInfo(options, new DateTimeDateConverter(options.EnableDateTimeInfinityConversions)), + MatchRequirement.DataTypeName); + mappings.AddStructType(DataTypeNames.Date, + static (options, mapping, _) => mapping.CreateInfo(options, new Int4Converter())); +#if NET6_0_OR_GREATER + mappings.AddStructType(DataTypeNames.Date, + 
static (options, mapping, _) => mapping.CreateInfo(options, new DateOnlyDateConverter(options.EnableDateTimeInfinityConversions))); +#endif + + // Time + mappings.AddStructType(DataTypeNames.Time, + static (options, mapping, _) => mapping.CreateInfo(options, new TimeSpanTimeConverter()), isDefault: true); + mappings.AddStructType(DataTypeNames.Time, + static (options, mapping, _) => mapping.CreateInfo(options, new Int8Converter())); +#if NET6_0_OR_GREATER + mappings.AddStructType(DataTypeNames.Time, + static (options, mapping, _) => mapping.CreateInfo(options, new TimeOnlyTimeConverter())); +#endif + + // TimeTz + mappings.AddStructType(DataTypeNames.TimeTz, + static (options, mapping, _) => mapping.CreateInfo(options, new DateTimeOffsetTimeTzConverter()), + MatchRequirement.DataTypeName); + + // Interval + mappings.AddStructType(DataTypeNames.Interval, + static (options, mapping, _) => mapping.CreateInfo(options, new TimeSpanIntervalConverter()), + MatchRequirement.DataTypeName); + mappings.AddStructType(DataTypeNames.Interval, + static (options, mapping, _) => mapping.CreateInfo(options, new NpgsqlIntervalConverter())); + + // Uuid + mappings.AddStructType(DataTypeNames.Uuid, + static (options, mapping, _) => mapping.CreateInfo(options, new GuidUuidConverter()), isDefault: true); + + // Hstore + mappings.AddType>("hstore", + static (options, mapping, _) => mapping.CreateInfo(options, new HstoreConverter>(options.TextEncoding)), isDefault: true); + mappings.AddType>("hstore", + static (options, mapping, _) => mapping.CreateInfo(options, new HstoreConverter>(options.TextEncoding))); + + // Unknown + mappings.AddType(DataTypeNames.Unknown, + static (options, mapping, _) => mapping.CreateInfo(options, new StringTextConverter(options.TextEncoding), preferredFormat: DataFormat.Text), + MatchRequirement.DataTypeName); + + // Void + mappings.AddType(DataTypeNames.Void, + static (options, mapping, _) => mapping.CreateInfo(options, new VoidConverter(), supportsWriting: 
false), + MatchRequirement.DataTypeName); + + // UInt internal types + foreach (var dataTypeName in new[] { DataTypeNames.Oid, DataTypeNames.Xid, DataTypeNames.Cid, DataTypeNames.RegType, DataTypeNames.RegConfig }) + { + mappings.AddStructType(dataTypeName, + static (options, mapping, _) => mapping.CreateInfo(options, new UInt32Converter()), + MatchRequirement.DataTypeName); + } + + // Char + mappings.AddStructType(DataTypeNames.Char, + static (options, mapping, _) => mapping.CreateInfo(options, new InternalCharConverter()), + MatchRequirement.DataTypeName); + mappings.AddStructType(DataTypeNames.Char, + static (options, mapping, _) => mapping.CreateInfo(options, new InternalCharConverter())); + + // Xid8 + mappings.AddStructType(DataTypeNames.Xid8, + static (options, mapping, _) => mapping.CreateInfo(options, new UInt64Converter()), + MatchRequirement.DataTypeName); + + // Oidvector + mappings.AddType( + DataTypeNames.OidVector, + static (options, mapping, _) => mapping.CreateInfo(options, + new ArrayBasedArrayConverter(new(new UInt32Converter(), new PgTypeId(DataTypeNames.Oid)), pgLowerBound: 0)), + MatchRequirement.DataTypeName); + + // Int2vector + mappings.AddType( + DataTypeNames.Int2Vector, + static (options, mapping, _) => mapping.CreateInfo(options, + new ArrayBasedArrayConverter(new(new Int2Converter(), new PgTypeId(DataTypeNames.Int2)), pgLowerBound: 0)), + MatchRequirement.DataTypeName); + + // Tid + mappings.AddStructType(DataTypeNames.Tid, + static (options, mapping, _) => mapping.CreateInfo(options, new TidConverter()), + MatchRequirement.DataTypeName); + + // PgLsn + mappings.AddStructType(DataTypeNames.PgLsn, + static (options, mapping, _) => mapping.CreateInfo(options, new PgLsnConverter()), + MatchRequirement.DataTypeName); + mappings.AddStructType(DataTypeNames.PgLsn, + static (options, mapping, _) => mapping.CreateInfo(options, new UInt64Converter())); + } + + protected static void AddArrayInfos(TypeInfoMappingCollection mappings) + { + // Bool 
+ mappings.AddStructArrayType(DataTypeNames.Bool); + + // Numeric + mappings.AddStructArrayType(DataTypeNames.Int2); + mappings.AddStructArrayType(DataTypeNames.Int4); + mappings.AddStructArrayType(DataTypeNames.Int8); + mappings.AddStructArrayType(DataTypeNames.Float4); + mappings.AddStructArrayType(DataTypeNames.Float8); + mappings.AddStructArrayType(DataTypeNames.Numeric); + mappings.AddStructArrayType(DataTypeNames.Money); + + // Text + mappings.AddArrayType(DataTypeNames.Text); + mappings.AddStructArrayType(DataTypeNames.Text); + mappings.AddArrayType(DataTypeNames.Text); + mappings.AddStructArrayType>(DataTypeNames.Text); + + // Alternative text types + foreach(var dataTypeName in new[] { "citext", DataTypeNames.Varchar, + DataTypeNames.Bpchar, DataTypeNames.Json, + DataTypeNames.Xml, DataTypeNames.Name, DataTypeNames.RefCursor }) + { + mappings.AddArrayType(dataTypeName); + mappings.AddStructArrayType(dataTypeName); + mappings.AddArrayType(dataTypeName); + mappings.AddStructArrayType>(dataTypeName); + } + + // Jsonb + mappings.AddArrayType(DataTypeNames.Jsonb); + mappings.AddStructArrayType(DataTypeNames.Jsonb); + mappings.AddArrayType(DataTypeNames.Jsonb); + mappings.AddStructArrayType>(DataTypeNames.Jsonb); + + // Jsonpath + mappings.AddArrayType(DataTypeNames.Jsonpath); + + // Bytea + mappings.AddArrayType(DataTypeNames.Bytea); + mappings.AddStructArrayType>(DataTypeNames.Bytea); + + // Varbit + // Object mapping first. 
+ mappings.AddPolymorphicResolverArrayType(DataTypeNames.Varbit, static options => resolution => resolution.Converter switch + { + BoolBitStringConverter => PgConverterFactory.CreatePolymorphicArrayConverter( + () => new ArrayBasedArrayConverter(resolution, typeof(Array)), + () => new ArrayBasedArrayConverter(new(new NullableConverter(resolution.GetConverter()), resolution.PgTypeId), typeof(Array)), + options), + BitArrayBitStringConverter => new ArrayBasedArrayConverter(resolution, typeof(Array)), + _ => throw new NotSupportedException() + }); + mappings.AddArrayType(DataTypeNames.Varbit); + mappings.AddStructArrayType(DataTypeNames.Varbit); + mappings.AddStructArrayType(DataTypeNames.Varbit); + + // Bit + // Object mapping first. + mappings.AddPolymorphicResolverArrayType(DataTypeNames.Bit, static options => resolution => resolution.Converter switch + { + BoolBitStringConverter => PgConverterFactory.CreatePolymorphicArrayConverter( + () => new ArrayBasedArrayConverter(resolution, typeof(Array)), + () => new ArrayBasedArrayConverter(new(new NullableConverter(resolution.GetConverter()), resolution.PgTypeId), typeof(Array)), + options), + BitArrayBitStringConverter => new ArrayBasedArrayConverter(resolution, typeof(Array)), + _ => throw new NotSupportedException() + }); + mappings.AddArrayType(DataTypeNames.Bit); + mappings.AddStructArrayType(DataTypeNames.Bit); + mappings.AddStructArrayType(DataTypeNames.Bit); + + // Timestamp + if (Statics.LegacyTimestampBehavior) + mappings.AddStructArrayType(DataTypeNames.Timestamp); + else + mappings.AddResolverStructArrayType(DataTypeNames.Timestamp); + mappings.AddStructArrayType(DataTypeNames.Timestamp); + + // TimestampTz + if (Statics.LegacyTimestampBehavior) + mappings.AddStructArrayType(DataTypeNames.TimestampTz); + else + mappings.AddResolverStructArrayType(DataTypeNames.TimestampTz); + mappings.AddStructArrayType(DataTypeNames.TimestampTz); + mappings.AddStructArrayType(DataTypeNames.TimestampTz); + + // Date + 
mappings.AddStructArrayType(DataTypeNames.Date); + mappings.AddStructArrayType(DataTypeNames.Date); +#if NET6_0_OR_GREATER + mappings.AddStructArrayType(DataTypeNames.Date); +#endif + + // Time + mappings.AddStructArrayType(DataTypeNames.Time); + mappings.AddStructArrayType(DataTypeNames.Time); +#if NET6_0_OR_GREATER + mappings.AddStructArrayType(DataTypeNames.Time); +#endif + + // TimeTz + mappings.AddStructArrayType(DataTypeNames.TimeTz); + + // Interval + mappings.AddStructArrayType(DataTypeNames.Interval); + mappings.AddStructArrayType(DataTypeNames.Interval); + + // Uuid + mappings.AddStructArrayType(DataTypeNames.Uuid); + + // Hstore + mappings.AddArrayType>("hstore"); + mappings.AddArrayType>("hstore"); + + // UInt internal types + foreach (var dataTypeName in new[] { DataTypeNames.Oid, DataTypeNames.Xid, DataTypeNames.Cid, DataTypeNames.RegType, (string)DataTypeNames.RegConfig }) + { + mappings.AddStructArrayType(dataTypeName); + } + + // Char + mappings.AddStructArrayType(DataTypeNames.Char); + mappings.AddStructArrayType(DataTypeNames.Char); + + // Xid8 + mappings.AddStructArrayType(DataTypeNames.Xid8); + + // Oidvector + mappings.AddArrayType(DataTypeNames.OidVector); + + // Int2vector + mappings.AddArrayType(DataTypeNames.Int2Vector); + } +} + +sealed class AdoArrayTypeInfoResolver : AdoTypeInfoResolver, IPgTypeInfoResolver +{ + new TypeInfoMappingCollection Mappings { get; } + + public AdoArrayTypeInfoResolver() + { + Mappings = new TypeInfoMappingCollection(base.Mappings); + var elementTypeCount = Mappings.Items.Count; + AddArrayInfos(Mappings); + // Make sure we have at least one mapping for each element type. + Debug.Assert(Mappings.Items.Count >= elementTypeCount * 2); + } + + public new PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? 
dataTypeName, PgSerializerOptions options) + { + var info = Mappings.Find(type, dataTypeName, options); + if (info is null && dataTypeName is not null) + info = GetEnumArrayTypeInfo(type, dataTypeName.GetValueOrDefault(), options); + return info; + } + + static PgTypeInfo? GetEnumArrayTypeInfo(Type? type, DataTypeName dataTypeName, PgSerializerOptions options) + { + if (type is not null && type != typeof(object) && (!TypeInfoMappingCollection.IsArrayLikeType(type, out var elementType) || elementType != typeof(string))) + return null; + + if (options.DatabaseInfo.GetPostgresType(dataTypeName) is not PostgresArrayType { Element: PostgresEnumType enumType }) + return null; + + var mappings = new TypeInfoMappingCollection(); + mappings.AddType(enumType.DataTypeName, (options, mapping, _) => mapping.CreateInfo(options, new StringTextConverter(options.TextEncoding)), MatchRequirement.DataTypeName); + mappings.AddArrayType(enumType.DataTypeName); + return mappings.Find(type, dataTypeName, options); + } +} diff --git a/src/Npgsql/Internal/Resolvers/ExtraConversionsResolver.cs b/src/Npgsql/Internal/Resolvers/ExtraConversionsResolver.cs new file mode 100644 index 0000000000..5f642daf80 --- /dev/null +++ b/src/Npgsql/Internal/Resolvers/ExtraConversionsResolver.cs @@ -0,0 +1,235 @@ +using System; +using System.Collections.Immutable; +using System.Numerics; +using Npgsql.Internal.Converters; +using Npgsql.Internal.Postgres; + +namespace Npgsql.Internal.Resolvers; + +class ExtraConversionsResolver : IPgTypeInfoResolver +{ + public ExtraConversionsResolver() => AddInfos(Mappings); + + protected TypeInfoMappingCollection Mappings { get; } = new(); + + public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? 
dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); + + static void AddInfos(TypeInfoMappingCollection mappings) + { + // Int2 + mappings.AddStructType(DataTypeNames.Int2, + static (options, mapping, _) => mapping.CreateInfo(options, new Int2Converter())); + mappings.AddStructType(DataTypeNames.Int2, + static (options, mapping, _) => mapping.CreateInfo(options, new Int2Converter())); + mappings.AddStructType(DataTypeNames.Int2, + static (options, mapping, _) => mapping.CreateInfo(options, new Int2Converter())); + mappings.AddStructType(DataTypeNames.Int2, + static (options, mapping, _) => mapping.CreateInfo(options, new Int2Converter())); + mappings.AddStructType(DataTypeNames.Int2, + static (options, mapping, _) => mapping.CreateInfo(options, new Int2Converter())); + mappings.AddStructType(DataTypeNames.Int2, + static (options, mapping, _) => mapping.CreateInfo(options, new Int2Converter())); + mappings.AddStructType(DataTypeNames.Int2, + static (options, mapping, _) => mapping.CreateInfo(options, new Int2Converter())); + + // Int4 + mappings.AddStructType(DataTypeNames.Int4, + static (options, mapping, _) => mapping.CreateInfo(options, new Int4Converter())); + mappings.AddStructType(DataTypeNames.Int4, + static (options, mapping, _) => mapping.CreateInfo(options, new Int4Converter())); + mappings.AddStructType(DataTypeNames.Int4, + static (options, mapping, _) => mapping.CreateInfo(options, new Int4Converter())); + mappings.AddStructType(DataTypeNames.Int4, + static (options, mapping, _) => mapping.CreateInfo(options, new Int4Converter())); + mappings.AddStructType(DataTypeNames.Int4, + static (options, mapping, _) => mapping.CreateInfo(options, new Int4Converter())); + mappings.AddStructType(DataTypeNames.Int4, + static (options, mapping, _) => mapping.CreateInfo(options, new Int4Converter())); + mappings.AddStructType(DataTypeNames.Int4, + static (options, mapping, _) => mapping.CreateInfo(options, new Int4Converter())); 
+ + // Int8 + mappings.AddStructType(DataTypeNames.Int8, + static (options, mapping, _) => mapping.CreateInfo(options, new Int8Converter())); + mappings.AddStructType(DataTypeNames.Int8, + static (options, mapping, _) => mapping.CreateInfo(options, new Int8Converter())); + mappings.AddStructType(DataTypeNames.Int8, + static (options, mapping, _) => mapping.CreateInfo(options, new Int8Converter())); + mappings.AddStructType(DataTypeNames.Int8, + static (options, mapping, _) => mapping.CreateInfo(options, new Int8Converter())); + mappings.AddStructType(DataTypeNames.Int8, + static (options, mapping, _) => mapping.CreateInfo(options, new Int8Converter())); + mappings.AddStructType(DataTypeNames.Int8, + static (options, mapping, _) => mapping.CreateInfo(options, new Int8Converter())); + mappings.AddStructType(DataTypeNames.Int8, + static (options, mapping, _) => mapping.CreateInfo(options, new Int8Converter())); + + // Float4 + mappings.AddStructType(DataTypeNames.Float4, + static (options, mapping, _) => mapping.CreateInfo(options, new RealConverter())); + + // Float8 + mappings.AddStructType(DataTypeNames.Float8, + static (options, mapping, _) => mapping.CreateInfo(options, new DoubleConverter())); + + // Numeric + mappings.AddStructType(DataTypeNames.Numeric, + static (options, mapping, _) => mapping.CreateInfo(options, new DecimalNumericConverter())); + mappings.AddStructType(DataTypeNames.Numeric, + static (options, mapping, _) => mapping.CreateInfo(options, new DecimalNumericConverter())); + mappings.AddStructType(DataTypeNames.Numeric, + static (options, mapping, _) => mapping.CreateInfo(options, new DecimalNumericConverter())); + mappings.AddStructType(DataTypeNames.Numeric, + static (options, mapping, _) => mapping.CreateInfo(options, new DecimalNumericConverter())); + mappings.AddStructType(DataTypeNames.Numeric, + static (options, mapping, _) => mapping.CreateInfo(options, new DecimalNumericConverter())); + mappings.AddStructType(DataTypeNames.Numeric, + 
static (options, mapping, _) => mapping.CreateInfo(options, new DecimalNumericConverter())); + mappings.AddStructType(DataTypeNames.Numeric, + static (options, mapping, _) => mapping.CreateInfo(options, new BigIntegerNumericConverter())); + + // Bytea + mappings.AddStructType>(DataTypeNames.Bytea, + static (options, mapping, _) => mapping.CreateInfo(options, new ArraySegmentByteaConverter())); + mappings.AddStructType>(DataTypeNames.Bytea, + static (options, mapping, _) => mapping.CreateInfo(options, new MemoryByteaConverter())); + + // Varbit + mappings.AddType(DataTypeNames.Varbit, + static (options, mapping, _) => mapping.CreateInfo(options, new StringBitStringConverter())); + + // Bit + mappings.AddType(DataTypeNames.Bit, + static (options, mapping, _) => mapping.CreateInfo(options, new StringBitStringConverter())); + + // Text + mappings.AddType(DataTypeNames.Text, + static (options, mapping, _) => mapping.CreateInfo(options, new CharArrayTextConverter(options.TextEncoding), preferredFormat: DataFormat.Text)); + mappings.AddStructType>(DataTypeNames.Text, + static (options, mapping, _) => mapping.CreateInfo(options, new ReadOnlyMemoryTextConverter(options.TextEncoding), preferredFormat: DataFormat.Text)); + mappings.AddStructType>(DataTypeNames.Text, + static (options, mapping, _) => mapping.CreateInfo(options, new CharArraySegmentTextConverter(options.TextEncoding), preferredFormat: DataFormat.Text)); + + // Alternative text types + foreach(var dataTypeName in new[] { "citext", DataTypeNames.Varchar, + DataTypeNames.Bpchar, DataTypeNames.Json, + DataTypeNames.Xml, DataTypeNames.Name, DataTypeNames.RefCursor }) + { + mappings.AddType(dataTypeName, + static (options, mapping, _) => mapping.CreateInfo(options, new CharArrayTextConverter(options.TextEncoding), + preferredFormat: DataFormat.Text)); + mappings.AddStructType>(dataTypeName, + static (options, mapping, _) => mapping.CreateInfo(options, new ReadOnlyMemoryTextConverter(options.TextEncoding), + 
preferredFormat: DataFormat.Text)); + mappings.AddStructType>(dataTypeName, + static (options, mapping, _) => mapping.CreateInfo(options, new CharArraySegmentTextConverter(options.TextEncoding), + preferredFormat: DataFormat.Text)); + } + + // Jsonb + const byte jsonbVersion = 1; + mappings.AddType(DataTypeNames.Jsonb, + static (options, mapping, _) => mapping.CreateInfo(options, new VersionPrefixedTextConverter(jsonbVersion, new CharArrayTextConverter(options.TextEncoding)))); + mappings.AddStructType>(DataTypeNames.Jsonb, + static (options, mapping, _) => mapping.CreateInfo(options, new VersionPrefixedTextConverter>(jsonbVersion, new ReadOnlyMemoryTextConverter(options.TextEncoding)))); + mappings.AddStructType>(DataTypeNames.Jsonb, + static (options, mapping, _) => mapping.CreateInfo(options, new VersionPrefixedTextConverter>(jsonbVersion, new CharArraySegmentTextConverter(options.TextEncoding)))); + + // Hstore + mappings.AddType>("hstore", + static (options, mapping, _) => mapping.CreateInfo(options, new HstoreConverter>(options.TextEncoding, result => result.ToImmutableDictionary()))); + } + + protected static void AddArrayInfos(TypeInfoMappingCollection mappings) + { + // Int2 + mappings.AddStructArrayType(DataTypeNames.Int2); + mappings.AddStructArrayType(DataTypeNames.Int2); + mappings.AddStructArrayType(DataTypeNames.Int2); + mappings.AddStructArrayType(DataTypeNames.Int2); + mappings.AddStructArrayType(DataTypeNames.Int2); + mappings.AddStructArrayType(DataTypeNames.Int2); + mappings.AddStructArrayType(DataTypeNames.Int2); + + // Int4 + mappings.AddStructArrayType(DataTypeNames.Int4); + mappings.AddStructArrayType(DataTypeNames.Int4); + mappings.AddStructArrayType(DataTypeNames.Int4); + mappings.AddStructArrayType(DataTypeNames.Int4); + mappings.AddStructArrayType(DataTypeNames.Int4); + mappings.AddStructArrayType(DataTypeNames.Int4); + mappings.AddStructArrayType(DataTypeNames.Int4); + + // Int8 + mappings.AddStructArrayType(DataTypeNames.Int8); + 
mappings.AddStructArrayType(DataTypeNames.Int8); + mappings.AddStructArrayType(DataTypeNames.Int8); + mappings.AddStructArrayType(DataTypeNames.Int8); + mappings.AddStructArrayType(DataTypeNames.Int8); + mappings.AddStructArrayType(DataTypeNames.Int8); + mappings.AddStructArrayType(DataTypeNames.Int8); + + // Float4 + mappings.AddStructArrayType(DataTypeNames.Float4); + + // Float8 + mappings.AddStructArrayType(DataTypeNames.Float8); + + // Numeric + mappings.AddStructArrayType(DataTypeNames.Numeric); + mappings.AddStructArrayType(DataTypeNames.Numeric); + mappings.AddStructArrayType(DataTypeNames.Numeric); + mappings.AddStructArrayType(DataTypeNames.Numeric); + mappings.AddStructArrayType(DataTypeNames.Numeric); + mappings.AddStructArrayType(DataTypeNames.Numeric); + mappings.AddStructArrayType(DataTypeNames.Numeric); + + // Bytea + mappings.AddStructArrayType>(DataTypeNames.Bytea); + mappings.AddStructArrayType>(DataTypeNames.Bytea); + + // Varbit + mappings.AddArrayType(DataTypeNames.Varbit); + + // Bit + mappings.AddArrayType(DataTypeNames.Bit); + + // Text + mappings.AddArrayType(DataTypeNames.Text); + mappings.AddStructArrayType>(DataTypeNames.Text); + mappings.AddStructArrayType>(DataTypeNames.Text); + + // Alternative text types + foreach(var dataTypeName in new[] { "citext", DataTypeNames.Varchar, + DataTypeNames.Bpchar, DataTypeNames.Json, + DataTypeNames.Xml, DataTypeNames.Name, DataTypeNames.RefCursor }) + { + mappings.AddArrayType(dataTypeName); + mappings.AddStructArrayType>(dataTypeName); + mappings.AddStructArrayType>(dataTypeName); + } + + // Jsonb + mappings.AddArrayType(DataTypeNames.Jsonb); + mappings.AddStructArrayType>(DataTypeNames.Jsonb); + mappings.AddStructArrayType>(DataTypeNames.Jsonb); + + // Hstore + mappings.AddArrayType>("hstore"); + } +} + +sealed class ExtraConversionsArrayTypeInfoResolver : ExtraConversionsResolver, IPgTypeInfoResolver +{ + public ExtraConversionsArrayTypeInfoResolver() + { + Mappings = new 
TypeInfoMappingCollection(base.Mappings.Items); + AddArrayInfos(Mappings); + } + + new TypeInfoMappingCollection Mappings { get; } + + public new PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); +} diff --git a/src/Npgsql/Internal/Resolvers/FullTextSearchTypeInfoResolver.cs b/src/Npgsql/Internal/Resolvers/FullTextSearchTypeInfoResolver.cs new file mode 100644 index 0000000000..f3b3a90d79 --- /dev/null +++ b/src/Npgsql/Internal/Resolvers/FullTextSearchTypeInfoResolver.cs @@ -0,0 +1,81 @@ +using System; +using Npgsql.Internal.Converters; +using Npgsql.Internal.Postgres; +using Npgsql.Properties; +using NpgsqlTypes; + +namespace Npgsql.Internal.Resolvers; + +sealed class FullTextSearchTypeInfoResolver : IPgTypeInfoResolver +{ + TypeInfoMappingCollection Mappings { get; } + + public FullTextSearchTypeInfoResolver() + { + Mappings = new TypeInfoMappingCollection(); + AddInfos(Mappings); + // TODO: Opt-in only + AddArrayInfos(Mappings); + } + + public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? 
dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); + + static void AddInfos(TypeInfoMappingCollection mappings) + { + // tsvector + mappings.AddType(DataTypeNames.TsVector, + static (options, mapping, _) => mapping.CreateInfo(options, new TsVectorConverter(options.TextEncoding)), isDefault: true); + + // tsquery + mappings.AddType(DataTypeNames.TsQuery, + static (options, mapping, _) => mapping.CreateInfo(options, new TsQueryConverter(options.TextEncoding)), isDefault: true); + mappings.AddType(DataTypeNames.TsQuery, + static (options, mapping, _) => mapping.CreateInfo(options, new TsQueryConverter(options.TextEncoding))); + mappings.AddType(DataTypeNames.TsQuery, + static (options, mapping, _) => mapping.CreateInfo(options, new TsQueryConverter(options.TextEncoding))); + mappings.AddType(DataTypeNames.TsQuery, + static (options, mapping, _) => mapping.CreateInfo(options, new TsQueryConverter(options.TextEncoding))); + mappings.AddType(DataTypeNames.TsQuery, + static (options, mapping, _) => mapping.CreateInfo(options, new TsQueryConverter(options.TextEncoding))); + mappings.AddType(DataTypeNames.TsQuery, + static (options, mapping, _) => mapping.CreateInfo(options, new TsQueryConverter(options.TextEncoding))); + mappings.AddType(DataTypeNames.TsQuery, + static (options, mapping, _) => mapping.CreateInfo(options, new TsQueryConverter(options.TextEncoding))); + } + + static void AddArrayInfos(TypeInfoMappingCollection mappings) + { + // tsvector + mappings.AddArrayType(DataTypeNames.TsVector); + + // tsquery + mappings.AddArrayType(DataTypeNames.TsQuery); + mappings.AddArrayType(DataTypeNames.TsQuery); + mappings.AddArrayType(DataTypeNames.TsQuery); + mappings.AddArrayType(DataTypeNames.TsQuery); + mappings.AddArrayType(DataTypeNames.TsQuery); + mappings.AddArrayType(DataTypeNames.TsQuery); + mappings.AddArrayType(DataTypeNames.TsQuery); + } + + public static void CheckUnsupported(Type? type, DataTypeName? 
dataTypeName, PgSerializerOptions options) + { + if (type != typeof(object) && (dataTypeName == DataTypeNames.TsQuery || dataTypeName == DataTypeNames.TsVector)) + throw new NotSupportedException( + string.Format(NpgsqlStrings.FullTextSearchNotEnabled, nameof(NpgsqlSlimDataSourceBuilder.EnableFullTextSearch), typeof(TBuilder).Name)); + + if (type is null) + return; + + if (TypeInfoMappingCollection.IsArrayLikeType(type, out var elementType)) + type = elementType; + + if (type is { IsConstructedGenericType: true } && type.GetGenericTypeDefinition() == typeof(Nullable<>)) + type = type.GetGenericArguments()[0]; + + if (type == typeof(NpgsqlTsVector) || typeof(NpgsqlTsQuery).IsAssignableFrom(type)) + throw new NotSupportedException( + string.Format(NpgsqlStrings.FullTextSearchNotEnabled, nameof(NpgsqlSlimDataSourceBuilder.EnableFullTextSearch), typeof(TBuilder).Name)); + } +} diff --git a/src/Npgsql/Internal/Resolvers/GeometricTypeInfoResolver.cs b/src/Npgsql/Internal/Resolvers/GeometricTypeInfoResolver.cs new file mode 100644 index 0000000000..6c24e1dcf9 --- /dev/null +++ b/src/Npgsql/Internal/Resolvers/GeometricTypeInfoResolver.cs @@ -0,0 +1,51 @@ +using System; +using Npgsql.Internal.Converters; +using Npgsql.Internal.Postgres; +using NpgsqlTypes; + +namespace Npgsql.Internal.Resolvers; + +sealed class GeometricTypeInfoResolver : IPgTypeInfoResolver +{ + TypeInfoMappingCollection Mappings { get; } + + public GeometricTypeInfoResolver() + { + Mappings = new TypeInfoMappingCollection(); + AddInfos(Mappings); + // TODO: Opt-in only + AddArrayInfos(Mappings); + } + + public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? 
dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); + + static void AddInfos(TypeInfoMappingCollection mappings) + { + mappings.AddStructType(DataTypeNames.Point, + static (options, mapping, _) => mapping.CreateInfo(options, new PointConverter()), isDefault: true); + mappings.AddStructType(DataTypeNames.Box, + static (options, mapping, _) => mapping.CreateInfo(options, new BoxConverter()), isDefault: true); + mappings.AddStructType(DataTypeNames.Polygon, + static (options, mapping, _) => mapping.CreateInfo(options, new PolygonConverter()), isDefault: true); + mappings.AddStructType(DataTypeNames.Line, + static (options, mapping, _) => mapping.CreateInfo(options, new LineConverter()), isDefault: true); + mappings.AddStructType(DataTypeNames.LSeg, + static (options, mapping, _) => mapping.CreateInfo(options, new LineSegmentConverter()), isDefault: true); + mappings.AddStructType(DataTypeNames.Path, + static (options, mapping, _) => mapping.CreateInfo(options, new PathConverter()), isDefault: true); + mappings.AddStructType(DataTypeNames.Circle, + static (options, mapping, _) => mapping.CreateInfo(options, new CircleConverter()), isDefault: true); + } + + static void AddArrayInfos(TypeInfoMappingCollection mappings) + { + mappings.AddStructArrayType(DataTypeNames.Point); + mappings.AddStructArrayType(DataTypeNames.Box); + mappings.AddStructArrayType(DataTypeNames.Polygon); + mappings.AddStructArrayType(DataTypeNames.Line); + mappings.AddStructArrayType(DataTypeNames.LSeg); + mappings.AddStructArrayType(DataTypeNames.Path); + mappings.AddStructArrayType(DataTypeNames.Circle); + } +} diff --git a/src/Npgsql/Internal/Resolvers/LTreeTypeInfoResolver.cs b/src/Npgsql/Internal/Resolvers/LTreeTypeInfoResolver.cs new file mode 100644 index 0000000000..129f73eecd --- /dev/null +++ b/src/Npgsql/Internal/Resolvers/LTreeTypeInfoResolver.cs @@ -0,0 +1,51 @@ +using System; +using Npgsql.Internal.Converters; +using Npgsql.Internal.Postgres; 
+using Npgsql.Properties; + +namespace Npgsql.Internal.Resolvers; + +sealed class LTreeTypeInfoResolver : IPgTypeInfoResolver +{ + const byte LTreeVersion = 1; + TypeInfoMappingCollection Mappings { get; } + + public LTreeTypeInfoResolver() + { + Mappings = new TypeInfoMappingCollection(); + AddInfos(Mappings); + // TODO: Opt-in only + AddArrayInfos(Mappings); + } + + public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); + + static void AddInfos(TypeInfoMappingCollection mappings) + { + mappings.AddType("ltree", + static (options, mapping, _) => mapping.CreateInfo(options, new VersionPrefixedTextConverter(LTreeVersion, new StringTextConverter(options.TextEncoding))), + MatchRequirement.DataTypeName); + mappings.AddType("lquery", + static (options, mapping, _) => mapping.CreateInfo(options, new VersionPrefixedTextConverter(LTreeVersion, new StringTextConverter(options.TextEncoding))), + MatchRequirement.DataTypeName); + mappings.AddType("ltxtquery", + static (options, mapping, _) => mapping.CreateInfo(options, new VersionPrefixedTextConverter(LTreeVersion, new StringTextConverter(options.TextEncoding))), + MatchRequirement.DataTypeName); + } + + static void AddArrayInfos(TypeInfoMappingCollection mappings) + { + mappings.AddArrayType("ltree"); + mappings.AddArrayType("lquery"); + mappings.AddArrayType("ltxtquery"); + } + + public static void CheckUnsupported(Type? type, DataTypeName? 
dataTypeName, PgSerializerOptions options) + { + if (type != typeof(object) && dataTypeName is { UnqualifiedName: "ltree" or "lquery" or "ltxtquery" }) + throw new NotSupportedException( + string.Format(NpgsqlStrings.LTreeNotEnabled, nameof(NpgsqlSlimDataSourceBuilder.EnableLTree), + typeof(TBuilder).Name)); + } +} diff --git a/src/Npgsql/Internal/Resolvers/NetworkTypeInfoResolver.cs b/src/Npgsql/Internal/Resolvers/NetworkTypeInfoResolver.cs new file mode 100644 index 0000000000..49ef7e8a5f --- /dev/null +++ b/src/Npgsql/Internal/Resolvers/NetworkTypeInfoResolver.cs @@ -0,0 +1,74 @@ +using System; +using System.Diagnostics.CodeAnalysis; +using System.Net; +using System.Net.NetworkInformation; +using Npgsql.Internal.Converters; +using Npgsql.Internal.Postgres; +using NpgsqlTypes; + +namespace Npgsql.Internal.Resolvers; + +sealed class NetworkTypeInfoResolver : IPgTypeInfoResolver +{ + TypeInfoMappingCollection Mappings { get; } + + public NetworkTypeInfoResolver() + { + Mappings = new TypeInfoMappingCollection(); + AddInfos(Mappings); + // TODO: Opt-in only + AddArrayInfos(Mappings); + } + + public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); + + static void AddInfos(TypeInfoMappingCollection mappings) + { + // macaddr + mappings.AddType(DataTypeNames.MacAddr, + static (options, mapping, _) => mapping.CreateInfo(options, new MacaddrConverter(macaddr8: false)), isDefault: true); + mappings.AddType(DataTypeNames.MacAddr8, + static (options, mapping, _) => mapping.CreateInfo(options, new MacaddrConverter(macaddr8: true)), + mapping => mapping with { MatchRequirement = MatchRequirement.DataTypeName }); + + // inet + // This is one of the rare mappings that force us to use reflection for a lack of any alternative. + // There are certain IPAddress values like Loopback or Any that return a *private* derived type (see https://github.com/dotnet/runtime/issues/27870). 
+ // However we still need to be able to resolve an exactly typed converter for those values. + // We do so by wrapping our converter in a casting converter constructed over the derived type. + // Finally we add a custom predicate to be able to match any type which values are assignable to IPAddress. + mappings.AddType(DataTypeNames.Inet, + [UnconditionalSuppressMessage("AOT", "IL3050", Justification = "MakeGenericType is safe because the target will only ever be a reference type.")] + static (options, resolvedMapping, _) => + { + var derivedType = resolvedMapping.Type != typeof(IPAddress); + PgConverter converter = new IPAddressConverter(); + if (derivedType) + // There is not much more we can do, the deriving type IPAddress+ReadOnlyIPAddress isn't public. + converter = (PgConverter)Activator.CreateInstance(typeof(CastingConverter<>).MakeGenericType(resolvedMapping.Type), converter)!; + + return resolvedMapping.CreateInfo(options, converter); + }, mapping => mapping with { MatchRequirement = MatchRequirement.Single, TypeMatchPredicate = type => type is null || typeof(IPAddress).IsAssignableFrom(type) }); + mappings.AddStructType(DataTypeNames.Inet, + static (options, mapping, _) => mapping.CreateInfo(options, new NpgsqlInetConverter())); + + // cidr + mappings.AddStructType(DataTypeNames.Cidr, + static (options, mapping, _) => mapping.CreateInfo(options, new NpgsqlCidrConverter()), isDefault: true); + } + + static void AddArrayInfos(TypeInfoMappingCollection mappings) + { + // macaddr + mappings.AddArrayType(DataTypeNames.MacAddr); + mappings.AddArrayType(DataTypeNames.MacAddr8); + + // inet + mappings.AddArrayType(DataTypeNames.Inet); + mappings.AddStructArrayType(DataTypeNames.Inet); + + // cidr + mappings.AddStructArrayType(DataTypeNames.Cidr); + } +} diff --git a/src/Npgsql/Internal/Resolvers/RangeTypeInfoResolver.cs b/src/Npgsql/Internal/Resolvers/RangeTypeInfoResolver.cs new file mode 100644 index 0000000000..57fc75e978 --- /dev/null +++ 
b/src/Npgsql/Internal/Resolvers/RangeTypeInfoResolver.cs @@ -0,0 +1,437 @@ +using System; +using System.Collections.Generic; +using System.Numerics; +using Npgsql.Internal.Converters; +using Npgsql.Internal.Postgres; +using Npgsql.Properties; +using Npgsql.TypeMapping; +using Npgsql.Util; +using NpgsqlTypes; +using static Npgsql.Internal.PgConverterFactory; + +namespace Npgsql.Internal.Resolvers; + +// TODO improve the ability to switch on server capability. +class RangeTypeInfoResolver : IPgTypeInfoResolver +{ + protected TypeInfoMappingCollection Mappings { get; } + protected TypeInfoMappingCollection MappingsWithMultiRanges { get; } + + public RangeTypeInfoResolver() + { + Mappings = new TypeInfoMappingCollection(); + AddInfos(Mappings, supportsMultiRange: false); + MappingsWithMultiRanges = new TypeInfoMappingCollection(); + AddInfos(MappingsWithMultiRanges, supportsMultiRange: true); + } + + public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + => (options.DatabaseInfo.SupportsMultirangeTypes ? 
MappingsWithMultiRanges : Mappings).Find(type, dataTypeName, options); + + static void AddInfos(TypeInfoMappingCollection mappings, bool supportsMultiRange) + { + // numeric ranges + mappings.AddStructType>(DataTypeNames.Int4Range, + static (options, mapping, _) => mapping.CreateInfo(options, CreateRangeConverter(new Int4Converter(), options)), + isDefault: true); + mappings.AddStructType>(DataTypeNames.Int8Range, + static (options, mapping, _) => mapping.CreateInfo(options, CreateRangeConverter(new Int8Converter(), options)), + isDefault: true); + mappings.AddStructType>(DataTypeNames.NumRange, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateRangeConverter(new DecimalNumericConverter(), options)), + isDefault: true); + mappings.AddStructType>(DataTypeNames.NumRange, + static (options, mapping, _) => mapping.CreateInfo(options, CreateRangeConverter(new BigIntegerNumericConverter(), options))); + + // tsrange + if (Statics.LegacyTimestampBehavior) + { + mappings.AddStructType>(DataTypeNames.TsRange, + static (options, mapping, _) => mapping.CreateInfo(options, + CreateRangeConverter(new LegacyDateTimeConverter(options.EnableDateTimeInfinityConversions, timestamp: true), options)), + isDefault: true); + } + else + { + mappings.AddResolverStructType>(DataTypeNames.TsRange, + static (options, mapping, dataTypeNameMatch) => mapping.CreateInfo(options, + DateTimeConverterResolver.CreateRangeResolver(options, + options.GetCanonicalTypeId(DataTypeNames.TsTzRange), + options.GetCanonicalTypeId(DataTypeNames.TsRange), + options.EnableDateTimeInfinityConversions), dataTypeNameMatch), + isDefault: true); + } + mappings.AddStructType>(DataTypeNames.TsRange, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateRangeConverter(new Int8Converter(), options))); + + // tstzrange + if (Statics.LegacyTimestampBehavior) + { + mappings.AddStructType>(DataTypeNames.TsTzRange, + static (options, mapping, _) => mapping.CreateInfo(options, + 
CreateRangeConverter(new LegacyDateTimeConverter(options.EnableDateTimeInfinityConversions, timestamp: false), options)), + isDefault: true); + mappings.AddStructType>(DataTypeNames.TsTzRange, + static (options, mapping, _) => mapping.CreateInfo(options, + CreateRangeConverter(new LegacyDateTimeOffsetConverter(options.EnableDateTimeInfinityConversions), options))); + } + else + { + mappings.AddResolverStructType>(DataTypeNames.TsTzRange, + static (options, mapping, dataTypeNameMatch) => mapping.CreateInfo(options, + DateTimeConverterResolver.CreateRangeResolver(options, + options.GetCanonicalTypeId(DataTypeNames.TsTzRange), + options.GetCanonicalTypeId(DataTypeNames.TsRange), + options.EnableDateTimeInfinityConversions), dataTypeNameMatch), + isDefault: true); + mappings.AddStructType>(DataTypeNames.TsTzRange, + static (options, mapping, _) => mapping.CreateInfo(options, + CreateRangeConverter(new DateTimeOffsetConverter(options.EnableDateTimeInfinityConversions), options))); + } + mappings.AddStructType>(DataTypeNames.TsTzRange, + static (options, mapping, _) => mapping.CreateInfo(options, CreateRangeConverter(new Int8Converter(), options))); + + // daterange + mappings.AddStructType>(DataTypeNames.DateRange, + static (options, mapping, _) => mapping.CreateInfo(options, + CreateRangeConverter(new DateTimeDateConverter(options.EnableDateTimeInfinityConversions), options)), + isDefault: true); + mappings.AddStructType>(DataTypeNames.DateRange, + static (options, mapping, _) => mapping.CreateInfo(options, CreateRangeConverter(new Int4Converter(), options))); +#if NET6_0_OR_GREATER + mappings.AddStructType>(DataTypeNames.DateRange, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateRangeConverter(new DateOnlyDateConverter(options.EnableDateTimeInfinityConversions), options))); +#endif + + if (supportsMultiRange) + { + // int4multirange + mappings.AddType[]>(DataTypeNames.Int4Multirange, + static (options, mapping, _) => + mapping.CreateInfo(options, 
CreateArrayMultirangeConverter(CreateRangeConverter(new Int4Converter(), options), options)), + isDefault: true); + mappings.AddType>>(DataTypeNames.Int4Multirange, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateListMultirangeConverter(CreateRangeConverter(new Int4Converter(), options), options))); + + // int8multirange + mappings.AddType[]>(DataTypeNames.Int8Multirange, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateArrayMultirangeConverter(CreateRangeConverter(new Int8Converter(), options), options)), + isDefault: true); + mappings.AddType>>(DataTypeNames.Int8Multirange, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateListMultirangeConverter(CreateRangeConverter(new Int8Converter(), options), options))); + + // nummultirange + mappings.AddType[]>(DataTypeNames.NumMultirange, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateArrayMultirangeConverter(CreateRangeConverter(new DecimalNumericConverter(), options), options)), + isDefault: true); + mappings.AddType>>(DataTypeNames.NumMultirange, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateListMultirangeConverter(CreateRangeConverter(new DecimalNumericConverter(), options), options))); + + // tsmultirange + if (Statics.LegacyTimestampBehavior) + { + mappings.AddType[]>(DataTypeNames.TsMultirange, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateArrayMultirangeConverter( + CreateRangeConverter(new LegacyDateTimeConverter(options.EnableDateTimeInfinityConversions, timestamp: true), options), options)), + isDefault: true); + mappings.AddType>>(DataTypeNames.TsMultirange, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateListMultirangeConverter( + CreateRangeConverter(new LegacyDateTimeConverter(options.EnableDateTimeInfinityConversions, timestamp: true), options), options))); + } + else + { + mappings.AddType[]>(DataTypeNames.TsMultirange, + static (options, mapping, 
dataTypeNameMatch) => mapping.CreateInfo(options, + DateTimeConverterResolver.CreateMultirangeResolver[], NpgsqlRange>(options, + options.GetCanonicalTypeId(DataTypeNames.TsTzMultirange), + options.GetCanonicalTypeId(DataTypeNames.TsMultirange), + options.EnableDateTimeInfinityConversions), dataTypeNameMatch), + isDefault: true); + mappings.AddType>>(DataTypeNames.TsMultirange, + static (options, mapping, dataTypeNameMatch) => mapping.CreateInfo(options, + DateTimeConverterResolver.CreateMultirangeResolver>, NpgsqlRange>(options, + options.GetCanonicalTypeId(DataTypeNames.TsTzMultirange), + options.GetCanonicalTypeId(DataTypeNames.TsMultirange), + options.EnableDateTimeInfinityConversions), dataTypeNameMatch)); + } + mappings.AddType[]>(DataTypeNames.TsMultirange, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateArrayMultirangeConverter(CreateRangeConverter(new Int8Converter(), options), options))); + mappings.AddType>>(DataTypeNames.TsMultirange, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateListMultirangeConverter(CreateRangeConverter(new Int8Converter(), options), options))); + + // tstzmultirange + if (Statics.LegacyTimestampBehavior) + { + mappings.AddType[]>(DataTypeNames.TsTzMultirange, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateArrayMultirangeConverter( + CreateRangeConverter(new LegacyDateTimeConverter(options.EnableDateTimeInfinityConversions, timestamp: false), options), options)), + isDefault: true); + mappings.AddType>>(DataTypeNames.TsTzMultirange, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateListMultirangeConverter( + CreateRangeConverter(new LegacyDateTimeConverter(options.EnableDateTimeInfinityConversions, timestamp: false), options), options))); + mappings.AddType[]>(DataTypeNames.TsTzMultirange, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateArrayMultirangeConverter( + CreateRangeConverter(new 
LegacyDateTimeOffsetConverter(options.EnableDateTimeInfinityConversions), options), options)), + isDefault: true); + mappings.AddType>>(DataTypeNames.TsTzMultirange, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateListMultirangeConverter( + CreateRangeConverter(new LegacyDateTimeOffsetConverter(options.EnableDateTimeInfinityConversions), options), options))); + } + else + { + mappings.AddType[]>(DataTypeNames.TsTzMultirange, + static (options, mapping, dataTypeNameMatch) => mapping.CreateInfo(options, + DateTimeConverterResolver.CreateMultirangeResolver[], NpgsqlRange>(options, + options.GetCanonicalTypeId(DataTypeNames.TsTzMultirange), + options.GetCanonicalTypeId(DataTypeNames.TsMultirange), + options.EnableDateTimeInfinityConversions), dataTypeNameMatch), + isDefault: true); + mappings.AddType>>(DataTypeNames.TsTzMultirange, + static (options, mapping, dataTypeNameMatch) => mapping.CreateInfo(options, + DateTimeConverterResolver.CreateMultirangeResolver>, NpgsqlRange>(options, + options.GetCanonicalTypeId(DataTypeNames.TsTzMultirange), + options.GetCanonicalTypeId(DataTypeNames.TsMultirange), + options.EnableDateTimeInfinityConversions), dataTypeNameMatch)); + mappings.AddType[]>(DataTypeNames.TsTzMultirange, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateArrayMultirangeConverter( + CreateRangeConverter(new DateTimeOffsetConverter(options.EnableDateTimeInfinityConversions), options), options)), + isDefault: true); + mappings.AddType>>(DataTypeNames.TsTzMultirange, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateListMultirangeConverter( + CreateRangeConverter(new DateTimeOffsetConverter(options.EnableDateTimeInfinityConversions), options), options))); + } + mappings.AddType[]>(DataTypeNames.TsTzMultirange, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateArrayMultirangeConverter(CreateRangeConverter(new Int8Converter(), options), options))); + 
mappings.AddType>>(DataTypeNames.TsTzMultirange, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateListMultirangeConverter(CreateRangeConverter(new Int8Converter(), options), options))); + + // datemultirange + mappings.AddType[]>(DataTypeNames.DateMultirange, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateArrayMultirangeConverter( + CreateRangeConverter(new DateTimeDateConverter(options.EnableDateTimeInfinityConversions), options), options)), + isDefault: true); + mappings.AddType>>(DataTypeNames.DateMultirange, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateListMultirangeConverter( + CreateRangeConverter(new DateTimeDateConverter(options.EnableDateTimeInfinityConversions), options), options))); +#if NET6_0_OR_GREATER + mappings.AddType[]>(DataTypeNames.DateMultirange, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateArrayMultirangeConverter( + CreateRangeConverter(new DateOnlyDateConverter(options.EnableDateTimeInfinityConversions), options), options)), + isDefault: true); + mappings.AddType>>(DataTypeNames.DateMultirange, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateListMultirangeConverter( + CreateRangeConverter(new DateOnlyDateConverter(options.EnableDateTimeInfinityConversions), options), options))); +#endif + } + } + + protected static void AddArrayInfos(TypeInfoMappingCollection mappings, bool supportsMultiRange) + { + // numeric ranges + mappings.AddStructArrayType>(DataTypeNames.Int4Range); + mappings.AddStructArrayType>(DataTypeNames.Int8Range); + mappings.AddStructArrayType>(DataTypeNames.NumRange); + mappings.AddStructArrayType>(DataTypeNames.NumRange); + + // tsrange + if (Statics.LegacyTimestampBehavior) + mappings.AddStructArrayType>(DataTypeNames.TsRange); + else + mappings.AddResolverStructArrayType>(DataTypeNames.TsRange); + mappings.AddStructArrayType>(DataTypeNames.TsRange); + + // tstzrange + if (Statics.LegacyTimestampBehavior) + { 
+ mappings.AddStructArrayType>(DataTypeNames.TsTzRange); + mappings.AddStructArrayType>(DataTypeNames.TsTzRange); + } + else + { + mappings.AddResolverStructArrayType>(DataTypeNames.TsTzRange); + mappings.AddStructArrayType>(DataTypeNames.TsTzRange); + } + mappings.AddStructArrayType>(DataTypeNames.TsTzRange); + + // daterange + mappings.AddStructArrayType>(DataTypeNames.DateRange); + mappings.AddStructArrayType>(DataTypeNames.DateRange); +#if NET6_0_OR_GREATER + mappings.AddStructArrayType>(DataTypeNames.DateRange); +#endif + + if (supportsMultiRange) + { + // int4multirange + mappings.AddArrayType[]>(DataTypeNames.Int4Multirange); + mappings.AddArrayType>>(DataTypeNames.Int4Multirange); + + // int8multirange + mappings.AddArrayType[]>(DataTypeNames.Int8Multirange); + mappings.AddArrayType>>(DataTypeNames.Int8Multirange); + + // nummultirange + mappings.AddArrayType[]>(DataTypeNames.NumMultirange); + mappings.AddArrayType>>(DataTypeNames.NumMultirange); + + // tsmultirange + if (Statics.LegacyTimestampBehavior) + { + mappings.AddArrayType[]>(DataTypeNames.TsMultirange); + mappings.AddArrayType>>(DataTypeNames.TsMultirange); + } + else + { + mappings.AddResolverArrayType[]>(DataTypeNames.TsMultirange); + mappings.AddResolverArrayType>>(DataTypeNames.TsMultirange); + } + mappings.AddArrayType[]>(DataTypeNames.TsMultirange); + mappings.AddArrayType>>(DataTypeNames.TsMultirange); + + // tstzmultirange + if (Statics.LegacyTimestampBehavior) + { + mappings.AddArrayType[]>(DataTypeNames.TsTzMultirange); + mappings.AddArrayType>>(DataTypeNames.TsTzMultirange); + mappings.AddArrayType[]>(DataTypeNames.TsTzMultirange); + mappings.AddArrayType>>(DataTypeNames.TsTzMultirange); + } + else + { + mappings.AddResolverArrayType[]>(DataTypeNames.TsTzMultirange); + mappings.AddResolverArrayType>>(DataTypeNames.TsTzMultirange); + mappings.AddArrayType[]>(DataTypeNames.TsTzMultirange); + mappings.AddArrayType>>(DataTypeNames.TsTzMultirange); + } + 
mappings.AddArrayType[]>(DataTypeNames.TsTzMultirange); + mappings.AddArrayType>>(DataTypeNames.TsTzMultirange); + + // datemultirange + mappings.AddArrayType[]>(DataTypeNames.DateMultirange); + mappings.AddArrayType>>(DataTypeNames.DateMultirange); +#if NET6_0_OR_GREATER + mappings.AddArrayType[]>(DataTypeNames.DateMultirange); + mappings.AddArrayType>>(DataTypeNames.DateMultirange); +#endif + } + } + + public static void ThrowIfUnsupported(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + { + var kind = CheckUnsupported(type, dataTypeName, options); + switch (kind) + { + case PgTypeKind.Range when kind.Value.HasFlag(PgTypeKind.Array): + throw new NotSupportedException( + string.Format(NpgsqlStrings.RangeArraysNotEnabled, nameof(NpgsqlSlimDataSourceBuilder.EnableArrays), typeof(TBuilder).Name)); + case PgTypeKind.Range: + throw new NotSupportedException( + string.Format(NpgsqlStrings.RangesNotEnabled, nameof(NpgsqlSlimDataSourceBuilder.EnableRanges), typeof(TBuilder).Name)); + case PgTypeKind.Multirange when kind.Value.HasFlag(PgTypeKind.Array): + throw new NotSupportedException( + string.Format(NpgsqlStrings.MultirangeArraysNotEnabled, nameof(NpgsqlSlimDataSourceBuilder.EnableArrays), typeof(TBuilder).Name)); + case PgTypeKind.Multirange: + throw new NotSupportedException( + string.Format(NpgsqlStrings.MultirangesNotEnabled, nameof(NpgsqlSlimDataSourceBuilder.EnableMultiranges), typeof(TBuilder).Name)); + default: + return; + } + } + + public static PgTypeKind? CheckUnsupported(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + { + // Only trigger on well known data type names. + var npgsqlDbType = dataTypeName?.ToNpgsqlDbType(); + if (type != typeof(object)) + { + if (npgsqlDbType?.HasFlag(NpgsqlDbType.Range) != true && npgsqlDbType?.HasFlag(NpgsqlDbType.Multirange) != true) + return null; + + if (npgsqlDbType.Value.HasFlag(NpgsqlDbType.Range)) + return dataTypeName?.IsArray == true + ? 
PgTypeKind.Array | PgTypeKind.Range + : PgTypeKind.Range; + + return dataTypeName?.IsArray == true + ? PgTypeKind.Array | PgTypeKind.Multirange + : PgTypeKind.Multirange; + } + + if (type == typeof(object)) + return null; + + var isArray = false; + if (TypeInfoMappingCollection.IsArrayLikeType(type, out var elementType)) + { + type = elementType; + isArray = true; + } + + if (type is { IsConstructedGenericType: true } && type.GetGenericTypeDefinition() == typeof(Nullable<>)) + type = type.GetGenericArguments()[0]; + + if (type is { IsConstructedGenericType: true } && type.GetGenericTypeDefinition() == typeof(NpgsqlRange<>)) + { + type = type.GetGenericArguments()[0]; + var matchingArguments = + new[] + { + typeof(int), typeof(long), typeof(decimal), typeof(DateTime), +# if NET6_0_OR_GREATER + typeof(DateOnly) +#endif + }; + + // If we don't know more than the clr type, default to a Multirange kind over Array as they share the same types. + foreach (var argument in matchingArguments) + if (argument == type) + return isArray ? PgTypeKind.Multirange : PgTypeKind.Range; + + if (type.AssemblyQualifiedName == "System.Numerics.BigInteger,System.Runtime.Numerics") + return isArray ? PgTypeKind.Multirange : PgTypeKind.Range; + } + + return null; + } +} + +sealed class RangeArrayTypeInfoResolver : RangeTypeInfoResolver, IPgTypeInfoResolver +{ + new TypeInfoMappingCollection Mappings { get; } + new TypeInfoMappingCollection MappingsWithMultiRanges { get; } + + public RangeArrayTypeInfoResolver() + { + Mappings = new TypeInfoMappingCollection(base.Mappings); + AddArrayInfos(Mappings, supportsMultiRange: false); + MappingsWithMultiRanges = new TypeInfoMappingCollection(base.MappingsWithMultiRanges); + AddArrayInfos(MappingsWithMultiRanges, supportsMultiRange: true); + } + + public new PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + => (options.DatabaseInfo.SupportsMultirangeTypes ? 
MappingsWithMultiRanges : Mappings).Find(type, dataTypeName, options); +} diff --git a/src/Npgsql/Internal/Resolvers/RecordTypeInfoResolvers.cs b/src/Npgsql/Internal/Resolvers/RecordTypeInfoResolvers.cs new file mode 100644 index 0000000000..f51aeac322 --- /dev/null +++ b/src/Npgsql/Internal/Resolvers/RecordTypeInfoResolvers.cs @@ -0,0 +1,137 @@ +using System; +using System.Diagnostics.CodeAnalysis; +using System.Reflection; +using Npgsql.Internal.Converters; +using Npgsql.Internal.Postgres; +using Npgsql.Properties; + +namespace Npgsql.Internal.Resolvers; + +class RecordTypeInfoResolver : IPgTypeInfoResolver +{ + protected TypeInfoMappingCollection Mappings { get; } = new(); + public RecordTypeInfoResolver() => AddInfos(Mappings); + + public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); + + static void AddInfos(TypeInfoMappingCollection mappings) + => mappings.AddType(DataTypeNames.Record, static (options, mapping, _) => + mapping.CreateInfo(options, new ObjectArrayRecordConverter(options), supportsWriting: false), + MatchRequirement.DataTypeName); + + protected static void AddArrayInfos(TypeInfoMappingCollection mappings) + => mappings.AddArrayType(DataTypeNames.Record); + + public static void CheckUnsupported(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + { + if (type != typeof(object) && dataTypeName == DataTypeNames.Record) + { + throw new NotSupportedException( + string.Format(NpgsqlStrings.RecordsNotEnabled, nameof(NpgsqlSlimDataSourceBuilder.EnableRecords), typeof(TBuilder).Name)); + } + } +} + +sealed class RecordArrayTypeInfoResolver : RecordTypeInfoResolver, IPgTypeInfoResolver +{ + public RecordArrayTypeInfoResolver() + { + Mappings = new TypeInfoMappingCollection(base.Mappings.Items); + AddArrayInfos(Mappings); + } + + new TypeInfoMappingCollection Mappings { get; } + + public new PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? 
dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); +} + +[RequiresUnreferencedCode("Tupled record resolver may perform reflection on trimmed tuple types.")] +[RequiresDynamicCode("Tupled records need to construct a generic converter for a statically unknown (value)tuple type.")] +class TupledRecordTypeInfoResolver : IPgTypeInfoResolver +{ + protected TypeInfoMappingCollection Mappings { get; } = new(); + public TupledRecordTypeInfoResolver() => AddInfos(Mappings); + + public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); + + // Stand-in type, type match predicate does the actual work. + static void AddInfos(TypeInfoMappingCollection mappings) + { + mappings.AddType>(DataTypeNames.Record, Factory, + mapping => mapping with + { + MatchRequirement = MatchRequirement.DataTypeName, + TypeMatchPredicate = type => type is null || (type is { IsConstructedGenericType: true, FullName: not null } + && type.FullName.StartsWith("System.Tuple", StringComparison.Ordinal)) + }); + + mappings.AddStructType>(DataTypeNames.Record, Factory, + mapping => mapping with + { + MatchRequirement = MatchRequirement.DataTypeName, + TypeMatchPredicate = type => type is null || (type is { IsConstructedGenericType: true, FullName: not null } + && type.FullName.StartsWith("System.ValueTuple", StringComparison.Ordinal)) + }); + } + + protected static void AddArrayInfos(TypeInfoMappingCollection mappings) + { + mappings.AddArrayType>(DataTypeNames.Record); + mappings.AddStructArrayType>(DataTypeNames.Record); + } + + static readonly TypeInfoFactory Factory = static (options, mapping, _) => + { + var constructors = mapping.Type.GetConstructors(); + ConstructorInfo? 
constructor = null; + if (constructors.Length is 1) + constructor = constructors[0]; + else + { + var args = mapping.Type.GenericTypeArguments.Length; + foreach (var ctor in constructors) + if (ctor.GetParameters().Length == args) + { + constructor = ctor; + break; + } + } + + if (constructor is null) + throw new InvalidOperationException($"Couldn't find a suitable constructor for record type: {mapping.Type.FullName}"); + + var factory = typeof(TupledRecordTypeInfoResolver).GetMethod(nameof(CreateFactory), BindingFlags.Static | BindingFlags.NonPublic)! + .MakeGenericMethod(mapping.Type) + .Invoke(null, new object[] { constructor, constructor.GetParameters().Length }); + + var converterType = typeof(ObjectArrayRecordConverter<>).MakeGenericType(mapping.Type); + var converter = (PgConverter)Activator.CreateInstance(converterType, options, factory)!; + return mapping.CreateInfo(options, converter, supportsWriting: false); + }; + + static Func CreateFactory(ConstructorInfo constructor, int constructorParameters) => array => + { + if (array.Length != constructorParameters) + throw new InvalidCastException($"Cannot read record type with {array.Length} fields as {typeof(T)}"); + return (T)constructor.Invoke(array); + }; +} + +[RequiresUnreferencedCode("Tupled record resolver may perform reflection on trimmed tuple types.")] +[RequiresDynamicCode("Tupled records need to construct a generic converter for a statically unknown (value)tuple type.")] +sealed class TupledRecordArrayTypeInfoResolver : TupledRecordTypeInfoResolver, IPgTypeInfoResolver +{ + public TupledRecordArrayTypeInfoResolver() + { + Mappings = new TypeInfoMappingCollection(base.Mappings.Items); + AddArrayInfos(Mappings); + } + + new TypeInfoMappingCollection Mappings { get; } + + public new PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? 
dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); +} diff --git a/src/Npgsql/Internal/Resolvers/SystemTextJsonPocoTypeInfoResolver.cs b/src/Npgsql/Internal/Resolvers/SystemTextJsonPocoTypeInfoResolver.cs new file mode 100644 index 0000000000..e513b29a86 --- /dev/null +++ b/src/Npgsql/Internal/Resolvers/SystemTextJsonPocoTypeInfoResolver.cs @@ -0,0 +1,123 @@ +using System; +using System.Diagnostics.CodeAnalysis; +using System.Text; +using System.Text.Json; +using System.Text.Json.Serialization.Metadata; +using Npgsql.Internal.Converters; +using Npgsql.Internal.Postgres; + +namespace Npgsql.Internal.Resolvers; + +[RequiresUnreferencedCode("Json serializer may perform reflection on trimmed types.")] +[RequiresDynamicCode("Serializing arbitary types to json can require creating new generic types or methods, which requires creating code at runtime. This may not work when AOT compiling.")] +class SystemTextJsonPocoTypeInfoResolver : DynamicTypeInfoResolver, IPgTypeInfoResolver +{ + protected TypeInfoMappingCollection Mappings { get; } = new(); + protected JsonSerializerOptions _serializerOptions; + + public SystemTextJsonPocoTypeInfoResolver(Type[]? jsonbClrTypes = null, Type[]? jsonClrTypes = null, JsonSerializerOptions? serializerOptions = null) + { +#if NET7_0_OR_GREATER + _serializerOptions = serializerOptions ??= JsonSerializerOptions.Default; +#else + _serializerOptions = serializerOptions ??= new JsonSerializerOptions(); +#endif + + AddMappings(Mappings, jsonbClrTypes ?? Array.Empty(), jsonClrTypes ?? Array.Empty(), serializerOptions); + } + + void AddMappings(TypeInfoMappingCollection mappings, Type[] jsonbClrTypes, Type[] jsonClrTypes, JsonSerializerOptions serializerOptions) + { + // We do GetTypeInfo calls directly so we need a resolver. 
+ serializerOptions.TypeInfoResolver ??= new DefaultJsonTypeInfoResolver(); + + AddUserMappings(jsonb: true, jsonbClrTypes); + AddUserMappings(jsonb: false, jsonClrTypes); + + void AddUserMappings(bool jsonb, Type[] clrTypes) + { + var dynamicMappings = CreateCollection(); + var dataTypeName = (string)(jsonb ? DataTypeNames.Jsonb : DataTypeNames.Json); + foreach (var jsonType in clrTypes) + { + var jsonTypeInfo = serializerOptions.GetTypeInfo(jsonType); + dynamicMappings.AddMapping(jsonTypeInfo.Type, dataTypeName, + factory: (options, mapping, _) => mapping.CreateInfo(options, + CreateSystemTextJsonConverter(mapping.Type, jsonb, options.TextEncoding, serializerOptions, jsonType))); + + if (!jsonType.IsValueType && jsonTypeInfo.PolymorphismOptions is not null) + { + foreach (var derived in jsonTypeInfo.PolymorphismOptions.DerivedTypes) + dynamicMappings.AddMapping(derived.DerivedType, dataTypeName, + factory: (options, mapping, _) => mapping.CreateInfo(options, + CreateSystemTextJsonConverter(mapping.Type, jsonb, options.TextEncoding, serializerOptions, jsonType))); + } + } + mappings.AddRange(dynamicMappings.ToTypeInfoMappingCollection()); + } + } + + protected void AddArrayInfos(TypeInfoMappingCollection mappings, TypeInfoMappingCollection baseMappings) + { + if (baseMappings.Items.Count == 0) + return; + + var dynamicMappings = CreateCollection(baseMappings); + foreach (var mapping in baseMappings.Items) + dynamicMappings.AddArrayMapping(mapping.Type, mapping.DataTypeName); + mappings.AddRange(dynamicMappings.ToTypeInfoMappingCollection()); + } + + public new PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options) ?? base.GetTypeInfo(type, dataTypeName, options); + + protected override DynamicMappingCollection? GetMappings(Type? 
type, DataTypeName dataTypeName, PgSerializerOptions options) + { + // Match all types except null, object and text types as long as DataTypeName (json/jsonb) is present. + if (type is null || type == typeof(object) || Array.IndexOf(PgSerializerOptions.WellKnownTextTypes, type) != -1 + || dataTypeName != DataTypeNames.Jsonb && dataTypeName != DataTypeNames.Json) + return null; + + return CreateCollection().AddMapping(type, dataTypeName, (options, mapping, _) => + { + var jsonb = dataTypeName == DataTypeNames.Jsonb; + + // For jsonb we can't properly support polymorphic serialization unless we do quite some additional work + // so we default to mapping.Type instead (exact types will never serialize their "$type" fields, essentially disabling the feature). + var baseType = jsonb ? mapping.Type : typeof(object); + + return mapping.CreateInfo(options, + CreateSystemTextJsonConverter(mapping.Type, jsonb, options.TextEncoding, _serializerOptions, baseType)); + }); + } + + static PgConverter CreateSystemTextJsonConverter(Type valueType, bool jsonb, Encoding textEncoding, JsonSerializerOptions serializerOptions, Type baseType) + => (PgConverter)Activator.CreateInstance( + typeof(SystemTextJsonConverter<,>).MakeGenericType(valueType, baseType), + jsonb, + textEncoding, + serializerOptions + )!; +} + +[RequiresUnreferencedCode("Json serializer may perform reflection on trimmed types.")] +[RequiresDynamicCode("Serializing arbitary types to json can require creating new generic types or methods, which requires creating code at runtime. This may not work when AOT compiling.")] +sealed class SystemTextJsonPocoArrayTypeInfoResolver : SystemTextJsonPocoTypeInfoResolver, IPgTypeInfoResolver +{ + new TypeInfoMappingCollection Mappings { get; } + + public SystemTextJsonPocoArrayTypeInfoResolver(Type[]? jsonbClrTypes = null, Type[]? jsonClrTypes = null, JsonSerializerOptions? 
serializerOptions = null) + : base(jsonbClrTypes, jsonClrTypes, serializerOptions) + { + Mappings = new TypeInfoMappingCollection(base.Mappings); + AddArrayInfos(Mappings, base.Mappings); + } + + public new PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options) ?? base.GetTypeInfo(type, dataTypeName, options); + + protected override DynamicMappingCollection? GetMappings(Type? type, DataTypeName dataTypeName, PgSerializerOptions options) + => type is not null && IsArrayLikeType(type, out var elementType) && IsArrayDataTypeName(dataTypeName, options, out var elementDataTypeName) + ? base.GetMappings(elementType, elementDataTypeName, options)?.AddArrayMapping(elementType, elementDataTypeName) + : null; +} diff --git a/src/Npgsql/Internal/Resolvers/SystemTextJsonTypeInfoResolvers.cs b/src/Npgsql/Internal/Resolvers/SystemTextJsonTypeInfoResolvers.cs new file mode 100644 index 0000000000..5650906ecb --- /dev/null +++ b/src/Npgsql/Internal/Resolvers/SystemTextJsonTypeInfoResolvers.cs @@ -0,0 +1,70 @@ +using System; +using System.Text.Json; +using System.Text.Json.Nodes; +using Npgsql.Internal.Converters; +using Npgsql.Internal.Postgres; + +namespace Npgsql.Internal.Resolvers; + +class SystemTextJsonTypeInfoResolver : IPgTypeInfoResolver +{ + protected TypeInfoMappingCollection Mappings { get; } = new(); + + public SystemTextJsonTypeInfoResolver(JsonSerializerOptions? serializerOptions = null) + => AddTypeInfos(Mappings, serializerOptions); + + static void AddTypeInfos(TypeInfoMappingCollection mappings, JsonSerializerOptions? 
serializerOptions = null) + { +#if NET7_0_OR_GREATER + serializerOptions ??= JsonSerializerOptions.Default; +#else + serializerOptions ??= new JsonSerializerOptions(); +#endif + + // Jsonb is the first default for JsonDocument + foreach (var dataTypeName in new[] { DataTypeNames.Jsonb, DataTypeNames.Json }) + { + var jsonb = dataTypeName == DataTypeNames.Jsonb; + mappings.AddType(dataTypeName, (options, mapping, _) => + mapping.CreateInfo(options, new SystemTextJsonConverter(jsonb, options.TextEncoding, serializerOptions)), + isDefault: true); + mappings.AddType(dataTypeName, (options, mapping, _) => + mapping.CreateInfo(options, new SystemTextJsonConverter(jsonb, options.TextEncoding, serializerOptions))); + mappings.AddType(dataTypeName, (options, mapping, _) => + mapping.CreateInfo(options, new SystemTextJsonConverter(jsonb, options.TextEncoding, serializerOptions))); + mappings.AddType(dataTypeName, (options, mapping, _) => + mapping.CreateInfo(options, new SystemTextJsonConverter(jsonb, options.TextEncoding, serializerOptions))); + mappings.AddType(dataTypeName, (options, mapping, _) => + mapping.CreateInfo(options, new SystemTextJsonConverter(jsonb, options.TextEncoding, serializerOptions))); + } + } + + protected static void AddArrayInfos(TypeInfoMappingCollection mappings) + { + foreach (var dataTypeName in new[] { DataTypeNames.Jsonb, DataTypeNames.Json }) + { + mappings.AddArrayType(dataTypeName); + mappings.AddArrayType(dataTypeName); + mappings.AddArrayType(dataTypeName); + mappings.AddArrayType(dataTypeName); + mappings.AddArrayType(dataTypeName); + } + } + + public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); +} + +sealed class SystemTextJsonArrayTypeInfoResolver : SystemTextJsonTypeInfoResolver, IPgTypeInfoResolver +{ + new TypeInfoMappingCollection Mappings { get; } + + public SystemTextJsonArrayTypeInfoResolver(JsonSerializerOptions? 
serializerOptions = null) : base(serializerOptions) + { + Mappings = new TypeInfoMappingCollection(base.Mappings); + AddArrayInfos(Mappings); + } + + public new PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); +} diff --git a/src/Npgsql/Internal/Resolvers/UnmappedEnumTypeInfoResolver.cs b/src/Npgsql/Internal/Resolvers/UnmappedEnumTypeInfoResolver.cs new file mode 100644 index 0000000000..b6ab437255 --- /dev/null +++ b/src/Npgsql/Internal/Resolvers/UnmappedEnumTypeInfoResolver.cs @@ -0,0 +1,51 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Reflection; +using Npgsql.Internal.Converters; +using Npgsql.Internal.Postgres; +using Npgsql.PostgresTypes; +using NpgsqlTypes; + +namespace Npgsql.Internal.Resolvers; + +[RequiresUnreferencedCode("Unmapped enum resolver may perform reflection on types with fields that were trimmed if not referenced directly.")] +[RequiresDynamicCode("Unmapped enums need to construct a generic converter for a statically unknown enum type.")] +class UnmappedEnumTypeInfoResolver : DynamicTypeInfoResolver +{ + protected override DynamicMappingCollection? GetMappings(Type? type, DataTypeName dataTypeName, PgSerializerOptions options) + { + if (type is null || !IsTypeOrNullableOfType(type, static type => type.IsEnum, out var matchedType) || options.DatabaseInfo.GetPostgresType(dataTypeName) is not PostgresEnumType) + return null; + + return CreateCollection().AddMapping(matchedType, dataTypeName, static (options, mapping, _) => + { + var enumToLabel = new Dictionary(); + var labelToEnum = new Dictionary(); + foreach (var field in mapping.Type.GetFields(BindingFlags.Static | BindingFlags.Public)) + { + var attribute = (PgNameAttribute?)field.GetCustomAttributes(typeof(PgNameAttribute), false).FirstOrDefault(); + var enumName = attribute?.PgName ?? 
options.DefaultNameTranslator.TranslateMemberName(field.Name); + var enumValue = (Enum)field.GetValue(null)!; + + enumToLabel[enumValue] = enumName; + labelToEnum[enumName] = enumValue; + } + + return mapping.CreateInfo(options, (PgConverter)Activator.CreateInstance(typeof(EnumConverter<>).MakeGenericType(mapping.Type), + enumToLabel, labelToEnum, + options.TextEncoding)!); + }); + } +} + +[RequiresUnreferencedCode("Unmapped enum resolver may perform reflection on types with fields that were trimmed if not referenced directly.")] +[RequiresDynamicCode("Unmapped enums need to construct a generic converter for a statically unknown enum type")] +sealed class UnmappedEnumArrayTypeInfoResolver : UnmappedEnumTypeInfoResolver +{ + protected override DynamicMappingCollection? GetMappings(Type? type, DataTypeName dataTypeName, PgSerializerOptions options) + => type is not null && IsArrayLikeType(type, out var elementType) && IsArrayDataTypeName(dataTypeName, options, out var elementDataTypeName) + ? 
base.GetMappings(elementType, elementDataTypeName, options)?.AddArrayMapping(elementType, elementDataTypeName) + : null; +} diff --git a/src/Npgsql/Internal/Resolvers/UnmappedMultirangeTypeInfoResolver.cs b/src/Npgsql/Internal/Resolvers/UnmappedMultirangeTypeInfoResolver.cs new file mode 100644 index 0000000000..d18b1421db --- /dev/null +++ b/src/Npgsql/Internal/Resolvers/UnmappedMultirangeTypeInfoResolver.cs @@ -0,0 +1,59 @@ +using System; +using System.Diagnostics.CodeAnalysis; +using Npgsql.Internal.Converters; +using Npgsql.Internal.Postgres; +using Npgsql.PostgresTypes; +using NpgsqlTypes; + +namespace Npgsql.Internal.Resolvers; + +[RequiresUnreferencedCode("A dynamic type info resolver may perform reflection on types that were trimmed if not referenced directly.")] +[RequiresDynamicCode("A dynamic type info resolver may need to construct a generic converter for a statically unknown type.")] +class UnmappedMultirangeTypeInfoResolver : DynamicTypeInfoResolver +{ + protected override DynamicMappingCollection? GetMappings(Type? type, DataTypeName dataTypeName, PgSerializerOptions options) + { + Type? elementType = null; + if (type is not null && !IsArrayLikeType(type, out elementType) + || elementType is not null && !IsTypeOrNullableOfType(elementType, + static type => type.IsConstructedGenericType && type.GetGenericTypeDefinition() == typeof(NpgsqlRange<>), out _) + || options.DatabaseInfo.GetPostgresType(dataTypeName) is not PostgresMultirangeType multirangeType) + return null; + + var subInfo = + elementType is null + ? options.GetDefaultTypeInfo(multirangeType.Subrange) + : options.GetTypeInfo(elementType, multirangeType.Subrange); + + // We have no generic MultirangeConverterResolver so we would not know how to compose a range mapping for such infos. 
+ // See https://github.com/npgsql/npgsql/issues/5268 + if (subInfo is not { IsResolverInfo: false }) + return null; + + subInfo = subInfo.ToNonBoxing(); + + type ??= subInfo.Type.MakeArrayType(); + + return CreateCollection().AddMapping(type, dataTypeName, + (options, mapping, _) => mapping.CreateInfo(options, + (PgConverter)Activator.CreateInstance(typeof(MultirangeConverter<,>).MakeGenericType(type, subInfo.Type), subInfo.GetConcreteResolution().Converter)!, + preferredFormat: subInfo.PreferredFormat, supportsWriting: subInfo.SupportsWriting), + mapping => mapping with { MatchRequirement = MatchRequirement.Single }); + } +} + +[RequiresUnreferencedCode("A dynamic type info resolver may perform reflection on types that were trimmed if not referenced directly.")] +[RequiresDynamicCode("A dynamic type info resolver may need to construct a generic converter for a statically unknown type.")] +sealed class UnmappedMultirangeArrayTypeInfoResolver : UnmappedMultirangeTypeInfoResolver +{ + protected override DynamicMappingCollection? GetMappings(Type? type, DataTypeName dataTypeName, PgSerializerOptions options) + { + Type? elementType = null; + if (!((type is null || IsArrayLikeType(type, out elementType)) && IsArrayDataTypeName(dataTypeName, options, out var elementDataTypeName))) + return null; + + var mappings = base.GetMappings(elementType, elementDataTypeName, options); + elementType ??= mappings?.Find(null, elementDataTypeName, options)?.Type; // Try to get the default mapping. + return elementType is null ? 
null : mappings?.AddArrayMapping(elementType, elementDataTypeName); + } +} diff --git a/src/Npgsql/Internal/Resolvers/UnmappedRangeTypeInfoResolver.cs b/src/Npgsql/Internal/Resolvers/UnmappedRangeTypeInfoResolver.cs new file mode 100644 index 0000000000..9e9ba0fb7d --- /dev/null +++ b/src/Npgsql/Internal/Resolvers/UnmappedRangeTypeInfoResolver.cs @@ -0,0 +1,59 @@ +using System; +using System.Diagnostics.CodeAnalysis; +using Npgsql.Internal.Converters; +using Npgsql.Internal.Postgres; +using Npgsql.PostgresTypes; +using NpgsqlTypes; + +namespace Npgsql.Internal.Resolvers; + +[RequiresUnreferencedCode("A dynamic type info resolver may perform reflection on types that were trimmed if not referenced directly.")] +[RequiresDynamicCode("A dynamic type info resolver may need to construct a generic converter for a statically unknown type.")] +class UnmappedRangeTypeInfoResolver : DynamicTypeInfoResolver +{ + protected override DynamicMappingCollection? GetMappings(Type? type, DataTypeName dataTypeName, PgSerializerOptions options) + { + var matchedType = type; + if (type is not null && !IsTypeOrNullableOfType(type, + static type => type.IsConstructedGenericType && type.GetGenericTypeDefinition() == typeof(NpgsqlRange<>), out matchedType) + || options.DatabaseInfo.GetPostgresType(dataTypeName) is not PostgresRangeType rangeType) + return null; + + var subInfo = + matchedType is null + ? options.GetDefaultTypeInfo(rangeType.Subtype) + // Input matchedType here as we don't want an NpgsqlRange over Nullable (it has its own nullability tracking, for better or worse) + : options.GetTypeInfo(matchedType.GetGenericArguments()[0], rangeType.Subtype); + + // We have no generic RangeConverterResolver so we would not know how to compose a range mapping for such infos. 
+ // See https://github.com/npgsql/npgsql/issues/5268 + if (subInfo is not { IsResolverInfo: false }) + return null; + + subInfo = subInfo.ToNonBoxing(); + + matchedType ??= typeof(NpgsqlRange<>).MakeGenericType(subInfo.Type); + + return CreateCollection().AddMapping(matchedType, dataTypeName, + (options, mapping, _) => mapping.CreateInfo(options, + (PgConverter)Activator.CreateInstance(typeof(RangeConverter<>).MakeGenericType(subInfo.Type), subInfo.GetConcreteResolution().Converter)!, + preferredFormat: subInfo.PreferredFormat, supportsWriting: subInfo.SupportsWriting), + mapping => mapping with { MatchRequirement = MatchRequirement.Single }); + } +} + +[RequiresUnreferencedCode("A dynamic type info resolver may perform reflection on types that were trimmed if not referenced directly.")] +[RequiresDynamicCode("A dynamic type info resolver may need to construct a generic converter for a statically unknown type.")] +sealed class UnmappedRangeArrayTypeInfoResolver : UnmappedRangeTypeInfoResolver +{ + protected override DynamicMappingCollection? GetMappings(Type? type, DataTypeName dataTypeName, PgSerializerOptions options) + { + Type? elementType = null; + if (!((type is null || IsArrayLikeType(type, out elementType)) && IsArrayDataTypeName(dataTypeName, options, out var elementDataTypeName))) + return null; + + var mappings = base.GetMappings(elementType, elementDataTypeName, options); + elementType ??= mappings?.Find(null, elementDataTypeName, options)?.Type; // Try to get the default mapping. + return elementType is null ? 
null : mappings?.AddArrayMapping(elementType, elementDataTypeName); + } +} diff --git a/src/Npgsql/Internal/Resolvers/UnsupportedTypeInfoResolver.cs b/src/Npgsql/Internal/Resolvers/UnsupportedTypeInfoResolver.cs new file mode 100644 index 0000000000..845816f7a7 --- /dev/null +++ b/src/Npgsql/Internal/Resolvers/UnsupportedTypeInfoResolver.cs @@ -0,0 +1,34 @@ +using System; +using System.Collections; +using Npgsql.Internal.Postgres; + +namespace Npgsql.Internal.Resolvers; + +sealed class UnsupportedTypeInfoResolver : IPgTypeInfoResolver +{ + public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + { + if (options.IntrospectionMode) + return null; + + RecordTypeInfoResolver.CheckUnsupported(type, dataTypeName, options); + RangeTypeInfoResolver.ThrowIfUnsupported(type, dataTypeName, options); + FullTextSearchTypeInfoResolver.CheckUnsupported(type, dataTypeName, options); + LTreeTypeInfoResolver.CheckUnsupported(type, dataTypeName, options); + + if (type is null) + return null; + + if (TypeInfoMappingCollection.IsArrayLikeType(type, out var elementType) && TypeInfoMappingCollection.IsArrayLikeType(elementType, out _)) + throw new NotSupportedException("Writing is not supported for jagged collections, use a multidimensional array instead."); + + if (typeof(IEnumerable).IsAssignableFrom(type) && !typeof(IList).IsAssignableFrom(type) && type != typeof(string) && (dataTypeName is null || dataTypeName.Value.IsArray)) + throw new NotSupportedException("Writing is not supported for IEnumerable parameters, use an array or List instead."); + + // TODO bring back json help message. + // $"Can't write CLR type {value.GetType()}. " + + // "You may need to use the System.Text.Json or Json.NET plugins, see the docs for more information." 
+ + return null; + } +} diff --git a/src/Npgsql/Internal/Size.cs b/src/Npgsql/Internal/Size.cs new file mode 100644 index 0000000000..f239453015 --- /dev/null +++ b/src/Npgsql/Internal/Size.cs @@ -0,0 +1,70 @@ +using System; +using System.Diagnostics; + +namespace Npgsql.Internal; + +public enum SizeKind : byte +{ + Unknown = 0, + Exact, + UpperBound +} + +[DebuggerDisplay("{DebuggerDisplay,nq}")] +public readonly struct Size : IEquatable +{ + readonly int _value; + readonly SizeKind _kind; + + Size(SizeKind kind, int value) + { + _value = value; + _kind = kind; + } + + public int Value + { + get + { + if (_kind is SizeKind.Unknown) + ThrowHelper.ThrowInvalidOperationException("Cannot get value from default or Unknown kind"); + return _value; + } + } + + internal int GetValueOrDefault() => _value; + + public SizeKind Kind => _kind; + + public static Size Create(int byteCount) => new(SizeKind.Exact, byteCount); + public static Size CreateUpperBound(int byteCount) => new(SizeKind.UpperBound, byteCount); + public static Size Unknown { get; } = new(SizeKind.Unknown, 0); + public static Size Zero { get; } = new(SizeKind.Exact, 0); + + public Size Combine(Size result) + { + if (_kind is SizeKind.Unknown || result._kind is SizeKind.Unknown) + return Unknown; + + if (_kind is SizeKind.UpperBound || result._kind is SizeKind.UpperBound) + return CreateUpperBound((int)Math.Min((long)(_value + result._value), int.MaxValue)); + + return Create((int)Math.Min((long)(_value + result._value), int.MaxValue)); + } + + public static implicit operator Size(int value) => Create(value); + + string DebuggerDisplay + => _kind switch + { + SizeKind.Exact or SizeKind.UpperBound => $"{_value} ({_kind})", + SizeKind.Unknown => "Unknown", + _ => throw new ArgumentOutOfRangeException() + }; + + public bool Equals(Size other) => _value == other._value && _kind == other.Kind; + public override bool Equals(object? 
obj) => obj is Size other && Equals(other); + public override int GetHashCode() => HashCode.Combine(_value, (int)_kind); + public static bool operator ==(Size left, Size right) => left.Equals(right); + public static bool operator !=(Size left, Size right) => !left.Equals(right); +} diff --git a/src/Npgsql/Internal/TypeHandlers/ArrayHandler.cs b/src/Npgsql/Internal/TypeHandlers/ArrayHandler.cs deleted file mode 100644 index bc1f100322..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/ArrayHandler.cs +++ /dev/null @@ -1,610 +0,0 @@ -using System; -using System.Collections; -using System.Collections.Concurrent; -using System.Collections.Generic; -using System.Diagnostics; -using System.Diagnostics.CodeAnalysis; -using System.Runtime.CompilerServices; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; - -namespace Npgsql.Internal.TypeHandlers; - -/// -/// Non-generic base class for all type handlers which handle PostgreSQL arrays. -/// -/// -/// https://www.postgresql.org/docs/current/static/arrays.html. -/// -/// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it -/// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. -/// Use it at your own risk. 
-/// -public class ArrayHandler : NpgsqlTypeHandler -{ - readonly Type _defaultArrayType; - readonly ConcurrentDictionary _concreteHandlers = new(); - protected int LowerBound { get; } - protected NpgsqlTypeHandler ElementHandler { get; } - protected ArrayNullabilityMode ArrayNullabilityMode { get; } - - public ArrayHandler(PostgresType arrayPostgresType, NpgsqlTypeHandler elementHandler, ArrayNullabilityMode arrayNullabilityMode, int lowerBound = 1) : base(arrayPostgresType) - { - LowerBound = lowerBound; - ElementHandler = elementHandler; - ArrayNullabilityMode = arrayNullabilityMode; - _defaultArrayType = elementHandler.GetFieldType().MakeArrayType(); - } - - public override Type GetFieldType(FieldDescription? fieldDescription = null) => typeof(Array); - - /// - public override NpgsqlTypeHandler CreateArrayHandler(PostgresArrayType pgArrayType, ArrayNullabilityMode arrayNullabilityMode) - => throw new NotSupportedException(); - - /// - public override NpgsqlTypeHandler CreateRangeHandler(PostgresType pgRangeType) - => throw new NotSupportedException(); - - /// - public override NpgsqlTypeHandler CreateMultirangeHandler(PostgresMultirangeType pgMultirangeType) - => throw new NotSupportedException(); - - ArrayHandlerCore CreateHandler(Type elementType) - => (ArrayHandlerCore)Activator.CreateInstance(typeof(ArrayHandlerCore<>).MakeGenericType(elementType), ElementHandler, ArrayNullabilityMode, LowerBound)!; - - /// - protected internal override async ValueTask ReadCustom(NpgsqlReadBuffer buf, int len, bool async, - FieldDescription? fieldDescription) - { - return (TArray)await ReadGenericAsObject(buf, async, fieldDescription); - - // Sync helper to keep the code size cost of ReadCustom low. - ValueTask ReadGenericAsObject(NpgsqlReadBuffer buf, bool async, FieldDescription? 
fieldDescription) - { - if (ArrayTypeInfo.IsArray) - return GetOrAddHandler().ReadArray(buf, async, ArrayTypeInfo.ArrayRank); - - if (ListTypeInfo.IsList) - return GetOrAddHandler().ReadList(buf, async); - - throw new InvalidCastException(fieldDescription == null - ? $"Can't cast database type to {typeof(TArray).Name}" - : $"Can't cast database type {fieldDescription.Handler.PgDisplayName} to {typeof(TArray).Name}" - ); - } - } - - /// - public override ValueTask ReadAsObject(NpgsqlReadBuffer buf, int len, bool async, - FieldDescription? fieldDescription = null) - => ReadAsObject(ElementHandler.GetFieldType(), buf, len, async, fieldDescription); - - protected async ValueTask ReadAsObject(Type elementType, NpgsqlReadBuffer buf, int len, bool async, - FieldDescription? fieldDescription = null) - { - if (!elementType.IsValueType || ArrayNullabilityMode is ArrayNullabilityMode.Never) - return await GetOrAddObjectHandler(elementType).ReadArrayAsObject(buf, async); - - if (ArrayNullabilityMode is ArrayNullabilityMode.Always) - return await GetOrAddObjectHandler(typeof(Nullable<>).MakeGenericType(elementType)).ReadArrayAsObject(buf, async); - - // We need to peek at the data to call into the right handler. - await buf.Ensure(sizeof(int) * 2, async); - var origPos = buf.ReadPosition; - var _ = buf.ReadInt32(); - var containsNulls = buf.ReadInt32() == 1; - buf.ReadPosition = origPos; - - return containsNulls - ? await GetOrAddObjectHandler(typeof(Nullable<>).MakeGenericType(elementType)).ReadArrayAsObject(buf, async) - : await GetOrAddObjectHandler(elementType).ReadArrayAsObject(buf, async); - } - - ArrayHandlerCore GetOrAddObjectHandler(Type elementType) - { - var arrayType = - elementType == ElementHandler.GetFieldType() - ? 
_defaultArrayType - : elementType.MakeArrayType(); - - return _concreteHandlers.GetOrAdd(arrayType, - static (t, instance) => instance.CreateHandler(t.GetElementType()!), this); - } - - /// - public override int ValidateObjectAndGetLength(object value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => GetOrAddObjectHandler(ElementHandler.GetFieldType()).ValidateAndGetElementLength(value, ref lengthCache); - - /// - protected internal override int ValidateAndGetLengthCustom([DisallowNull] TArray value, ref NpgsqlLengthCache? lengthCache, - NpgsqlParameter? parameter) - => GetOrAddHandler().ValidateAndGetElementLength(value, ref lengthCache); - - /// - public override Task WriteObjectWithLength(object? value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, - NpgsqlParameter? parameter, bool async, - CancellationToken cancellationToken = default) - { - if (value is null or DBNull) - { - buf.WriteInt32(-1); - return Task.CompletedTask; - } - return GetOrAddObjectHandler(ElementHandler.GetFieldType()).WriteElementWithLength(value, buf, lengthCache, async, cancellationToken); - } - - protected override Task WriteWithLengthCustom([DisallowNull]TArray value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, - CancellationToken cancellationToken) - => GetOrAddHandler().WriteElementWithLength(value, buf, lengthCache, async, cancellationToken); - - private protected ArrayHandlerCore GetOrAddHandler() - => _concreteHandlers.GetOrAdd(typeof(TArray), static (_, instance) => - { - if (ArrayTypeInfo.IsArray) - return instance.CreateHandler(ArrayTypeInfo.ElementType); - - if (ListTypeInfo.IsList) - return instance.CreateHandler(ListTypeInfo.ElementType); - - return null!; - }, this); - - static class ArrayTypeInfo - { - // ReSharper disable StaticMemberInGenericType - public static readonly Type? ElementType = typeof(TArray).IsArray ? 
typeof(TArray).GetElementType() : null; - public static readonly int ArrayRank = ElementType is not null ? typeof(TArray).GetArrayRank() : 0; - // ReSharper restore StaticMemberInGenericType - - [MemberNotNullWhen(true, nameof(ElementType))] - public static bool IsArray => ElementType is not null; - } - - static class ListTypeInfo - { - // ReSharper disable StaticMemberInGenericType - public static readonly Type? ElementType = typeof(TList).IsGenericType && typeof(TList).GetGenericTypeDefinition() == typeof(List<>) ? typeof(TList).GetGenericArguments()[0] : null; - // ReSharper restore StaticMemberInGenericType - - [MemberNotNullWhen(true, nameof(ElementType))] - public static bool IsList => ElementType is not null; - } -} - -abstract class ArrayHandlerCore -{ - internal const string ReadNonNullableCollectionWithNullsExceptionMessage = - "Cannot read a non-nullable collection of elements because the returned array contains nulls. " + - "Call GetFieldValue with a nullable array instead."; - - readonly int _lowerBound; - public ArrayNullabilityMode ArrayNullabilityMode { get; } - - protected ArrayHandlerCore(ArrayNullabilityMode arrayNullabilityMode, int lowerBound = 1) - { - ArrayNullabilityMode = arrayNullabilityMode; - _lowerBound = lowerBound; - } - - public ValueTask ReadArray(NpgsqlReadBuffer buf, bool async, int expectedDimensions = 0) - => ReadArray(buf, async, expectedDimensions, readAsObject: false); - - public ValueTask ReadArrayAsObject(NpgsqlReadBuffer buf, bool async, int expectedDimensions = 0) - => ReadArray(buf, async, expectedDimensions, readAsObject: true); - - protected abstract Type ElementType { get; } - protected abstract bool IsNonNullable { get; } - protected abstract bool IsGenericCollection(object value, out int count); - protected abstract NpgsqlTypeHandler ElementHandler { get; } - protected abstract object CreateCollection(bool isArray, int capacity); - protected abstract ValueTask ReadElement(bool isArray, object values, int index, 
NpgsqlReadBuffer buf, int length, bool async, - FieldDescription? fieldDescription = null); - protected abstract ValueTask ReadElement(Array array, int[] indices, NpgsqlReadBuffer buf, int length, bool async, - FieldDescription? fieldDescription = null); - protected abstract int ValidateAndGetElementLength(bool isArray, object values, int index, ref NpgsqlLengthCache? lengthCache, - NpgsqlParameter? parameter); - protected abstract ValueTask WriteElementWithLength(bool isArray, object values, int index, NpgsqlWriteBuffer buf, - NpgsqlLengthCache? lengthCache, - NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken); - - /// - /// Reads an array of element type from the given buffer . - /// - async ValueTask ReadArray(NpgsqlReadBuffer buf, bool async, int expectedDimensions, bool readAsObject) - { - await buf.Ensure(12, async); - var dimensions = buf.ReadInt32(); - var containsNulls = buf.ReadInt32() == 1; - buf.ReadUInt32(); // Element OID. Ignored. - - var nullableElementType = IsNonNullable - ? typeof(Nullable<>).MakeGenericType(ElementType) - : ElementType; - - var returnType = readAsObject - ? ArrayNullabilityMode switch - { - ArrayNullabilityMode.Never => IsNonNullable && containsNulls - ? throw new InvalidOperationException(ReadNonNullableCollectionWithNullsExceptionMessage) - : ElementType, - ArrayNullabilityMode.Always => nullableElementType, - ArrayNullabilityMode.PerInstance => containsNulls - ? nullableElementType - : ElementType, - _ => throw new ArgumentOutOfRangeException() - } - : IsNonNullable && containsNulls - ? throw new InvalidOperationException(ReadNonNullableCollectionWithNullsExceptionMessage) - : ElementType; - - if (dimensions == 0) - return expectedDimensions > 1 - ? 
Array.CreateInstance(returnType, new int[expectedDimensions]) - : CreateCollection(isArray: true, 0); - - if (expectedDimensions > 0 && dimensions != expectedDimensions) - throw new InvalidOperationException($"Cannot read an array with {expectedDimensions} dimension(s) from an array with {dimensions} dimension(s)"); - - if (dimensions == 1 && returnType == ElementType) - { - await buf.Ensure(8, async); - var arrayLength = buf.ReadInt32(); - - buf.ReadInt32(); // Lower bound - - var oneDimensional = CreateCollection(isArray: true, arrayLength); - for (var i = 0; i < arrayLength; i++) - { - await buf.Ensure(4, async); - var len = buf.ReadInt32(); - await ReadElement(isArray: true, oneDimensional, i, buf, len, async); - } - return oneDimensional; - } - - var dimLengths = new int[dimensions]; - await buf.Ensure(dimensions * 8, async); - - for (var i = 0; i < dimLengths.Length; i++) - { - dimLengths[i] = buf.ReadInt32(); - buf.ReadInt32(); // Lower bound - } - - var result = Array.CreateInstance(returnType, dimLengths); - - // Either multidimensional arrays or arrays of nullable value types requested as object - // We can't avoid boxing here - var indices = new int[dimensions]; - while (true) - { - await buf.Ensure(4, async); - var len = buf.ReadInt32(); - if (len == -1) - result.SetValue(null, indices); - else - await ReadElement(result, indices, buf, len, async); - - // TODO: Overly complicated/inefficient... - indices[dimensions - 1]++; - for (var dim = dimensions - 1; dim >= 0; dim--) - { - if (indices[dim] <= result.GetUpperBound(dim)) - continue; - - if (dim == 0) - return result; - - for (var j = dim; j < dimensions; j++) - indices[j] = result.GetLowerBound(j); - indices[dim - 1]++; - } - } - } - - /// - /// Reads a generic list containing elements from the given buffer . 
- /// - public async ValueTask ReadList(NpgsqlReadBuffer buf, bool async) - { - await buf.Ensure(12, async); - var dimensions = buf.ReadInt32(); - var containsNulls = buf.ReadInt32() == 1; - buf.ReadUInt32(); // Element OID. Ignored. - - if (dimensions == 0) - return CreateCollection(isArray: false, 0); - if (dimensions > 1) - throw new NotSupportedException($"Can't read multidimensional array as List<{ElementType.Name}>"); - - if (containsNulls && IsNonNullable) - throw new InvalidOperationException(ReadNonNullableCollectionWithNullsExceptionMessage); - - await buf.Ensure(8, async); - var length = buf.ReadInt32(); - buf.ReadInt32(); // We don't care about the lower bounds - - var list = CreateCollection(isArray: false, length); - for (var i = 0; i < length; i++) - { - var len = buf.ReadInt32(); - await ReadElement(isArray: false, list, i, buf, len, async); - } - return list; - } - - // Handle single-dimensional arrays and generic IList - public int ValidateAndGetElementLength(object value, int count, ref NpgsqlLengthCache lengthCache) - { - // Leave empty slot for the entire array length, and go ahead an populate the element slots - var pos = lengthCache.Position; - var len = - 4 + // dimensions - 4 + // has_nulls (unused) - 4 + // type OID - 1 * 8 + // number of dimensions (1) * (length + lower bound) - 4 * count; // sum of element lengths - - lengthCache.Set(0); - var elemLengthCache = lengthCache; - - var isArray = value is Array; - for (var i = 0; i < count; i++) - { - try - { - len += ValidateAndGetElementLength(isArray, value, i, ref elemLengthCache, null); - } - catch (Exception e) - { - throw MixedTypesOrJaggedArrayException(e); - } - } - - lengthCache.Lengths[pos] = len; - return len; - } - - // Take care of multi-dimensional arrays and non-generic IList, we have no choice but to box/unbox - public int ValidateAndGetLengthAsObject(ICollection value, ref NpgsqlLengthCache lengthCache) - { - var dimensions = (value as Array)?.Rank ?? 
1; - - // Leave empty slot for the entire array length, and go ahead an populate the element slots - var pos = lengthCache.Position; - var len = - 4 + // dimensions - 4 + // has_nulls (unused) - 4 + // type OID - dimensions * 8 + // number of dimensions * (length + lower bound) - 4 * value.Count; // sum of element lengths - - lengthCache.Set(0); - var elemLengthCache = lengthCache; - - var elementHandler = ElementHandler; - foreach (var element in value) - { - if (element is null) - continue; - - try - { - len += elementHandler.ValidateObjectAndGetLength(element, ref elemLengthCache, null); - } - catch (Exception e) - { - throw MixedTypesOrJaggedArrayException(e); - } - } - - lengthCache.Lengths[pos] = len; - return len; - } - - public async Task WriteAsObject(ICollection value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, bool async, CancellationToken cancellationToken = default) - { - var asArray = value as Array; - var dimensions = asArray?.Rank ?? 1; - - var len = - 4 + // ndim - 4 + // has_nulls - 4 + // element_oid - dimensions * 8; // dim (4) + lBound (4) - - if (buf.WriteSpaceLeft < len) - { - await buf.Flush(async, cancellationToken); - Debug.Assert(buf.WriteSpaceLeft >= len, "Buffer too small for header"); - } - - var elementHandler = ElementHandler; - buf.WriteInt32(dimensions); - buf.WriteInt32(1); // HasNulls=1. Not actually used by the backend. - buf.WriteUInt32(elementHandler.PostgresType.OID); - if (asArray != null) - { - for (var i = 0; i < dimensions; i++) - { - buf.WriteInt32(asArray.GetLength(i)); - buf.WriteInt32(_lowerBound); // We don't map .NET lower bounds to PG - } - } - else - { - buf.WriteInt32(value.Count); - buf.WriteInt32(_lowerBound); // We don't map .NET lower bounds to PG - } - - foreach (var element in value) - await elementHandler.WriteObjectWithLength(element, buf, lengthCache, null, async, cancellationToken); - } - - public async Task Write(object value, int count, NpgsqlWriteBuffer buf, NpgsqlLengthCache? 
lengthCache, bool async, CancellationToken cancellationToken = default) - { - var len = - 4 + // dimensions - 4 + // has_nulls (unused) - 4 + // type OID - 1 * 8; // number of dimensions (1) * (length + lower bound) - if (buf.WriteSpaceLeft < len) - { - await buf.Flush(async, cancellationToken); - Debug.Assert(buf.WriteSpaceLeft >= len, "Buffer too small for header"); - } - - var elementHandler = ElementHandler; - buf.WriteInt32(1); - buf.WriteInt32(1); // has_nulls = 1. Not actually used by the backend. - buf.WriteUInt32(elementHandler.PostgresType.OID); - buf.WriteInt32(count); - buf.WriteInt32(_lowerBound); // We don't map .NET lower bounds to PG - - var isArray = value is Array; - for (var i = 0; i < count; i++) - await WriteElementWithLength(isArray, value, i, buf, lengthCache, null, async, cancellationToken); - } - - static Exception MixedTypesOrJaggedArrayException(Exception innerException) - => new("While trying to write an array, one of its elements failed validation. " + - "You may be trying to mix types in a non-generic IList, or to write a jagged array.", innerException); - - public int ValidateAndGetElementLength(object value, ref NpgsqlLengthCache? lengthCache) - { - lengthCache ??= new NpgsqlLengthCache(1); - if (lengthCache.IsPopulated) - return lengthCache.Get(); - - return value switch - { - _ when IsGenericCollection(value, out var count) => ValidateAndGetElementLength(value, count, ref lengthCache), - ICollection nonGeneric => ValidateAndGetLengthAsObject(nonGeneric, ref lengthCache), - _ => throw CantWriteTypeException(value.GetType()) - }; - } - - public Task WriteElementWithLength(object value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? 
lengthCache, bool async, CancellationToken cancellationToken) - { - buf.WriteInt32(ValidateAndGetElementLength(value, ref lengthCache)); - return value switch - { - _ when IsGenericCollection(value, out var count) => Write(value, count, buf, lengthCache, async, cancellationToken), - ICollection nonGeneric => WriteAsObject(nonGeneric, buf, lengthCache, async, cancellationToken), - _ => throw CantWriteTypeException(value.GetType()) - }; - } - - InvalidCastException CantWriteTypeException(Type type) - => new($"Can't write type '{type}' as an array of {ElementType}"); -} - -sealed class ArrayHandlerCore : ArrayHandlerCore -{ - readonly NpgsqlTypeHandler _elementHandler; - - public ArrayHandlerCore(NpgsqlTypeHandler nonNullableElementHandler, ArrayNullabilityMode arrayNullabilityMode, int lowerBound = 1) - : base(arrayNullabilityMode, lowerBound) - => _elementHandler = nonNullableElementHandler; - - protected override Type ElementType => typeof(TElement); - protected override bool IsNonNullable => typeof(TElement).IsValueType && default(TElement) is not null; - - protected override bool IsGenericCollection(object value, out int count) - { - if (value is ICollection collection) - { - count = collection.Count; - return true; - } - - count = 0; - return false; - } - - protected override NpgsqlTypeHandler ElementHandler => _elementHandler; - - protected override object CreateCollection(bool isArray, int capacity) => isArray switch - { - true => capacity is 0 ? Array.Empty() : new TElement[capacity], - false => new List() - }; - - protected override ValueTask ReadElement(bool isArray, object values, int index, NpgsqlReadBuffer buf, int length, bool async, FieldDescription? fieldDescription = null) - { - // We want a generic mutation so we unfortunately need the null check on this side. - if (length == -1) - { - SetResult(isArray, values, index, (TElement?)(object?)null); - return new ValueTask(); - } - - var task = - NullableHandler.Exists - ? 
NullableHandler.ReadAsync(_elementHandler, buf, length, async, fieldDescription) - : _elementHandler.Read(buf, length, async, fieldDescription); - - if (!task.IsCompletedSuccessfully) - return Core(isArray, values, index, task); - - SetResult(isArray, values, index, task.GetAwaiter().GetResult()); - return new ValueTask(); - - static async ValueTask Core(bool isArray, object values, int index, ValueTask task) - => SetResult(isArray, values, index, await task); - - static void SetResult(bool isArray, object values, int index, TElement? result) - { - Debug.Assert(isArray ? values is TElement?[] : values is List); - if (isArray) - Unsafe.As(ref values)[index] = result; - else - Unsafe.As>(ref values).Add(result); - } - } - - protected override async ValueTask ReadElement(Array array, int[] indices, NpgsqlReadBuffer buf, int length, bool async, FieldDescription? fieldDescription = null) - { - // Null check is handled in ArrayHandlerOps to reduce code size. - var result = - NullableHandler.Exists - ? await NullableHandler.ReadAsync(_elementHandler, buf, length, async, fieldDescription) - : await _elementHandler.Read(buf, length, async, fieldDescription); - - array.SetValue(result, indices); - } - - protected override int ValidateAndGetElementLength(bool isArray, object values, int index, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - { - Debug.Assert(isArray ? values is TElement?[] : values is List); - var element = - isArray - ? Unsafe.As(ref values)[index] - : Unsafe.As>(ref values)[index]; - - return element is null - ? 0 - : NullableHandler.Exists - ? NullableHandler.ValidateAndGetLength(_elementHandler, element, ref lengthCache, parameter) - : _elementHandler.ValidateAndGetLength(element, ref lengthCache, parameter); - } - - protected override async ValueTask WriteElementWithLength(bool isArray, object values, int index, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, - NpgsqlParameter? 
parameter, bool async, CancellationToken cancellationToken) - { - Debug.Assert(isArray ? values is TElement?[] : values is List); - var element = - isArray - ? Unsafe.As(ref values)[index] - : Unsafe.As>(ref values)[index]; - - if (NullableHandler.Exists) - await NullableHandler.WriteAsync(_elementHandler, element!, buf, lengthCache, parameter, async, cancellationToken); - else - await _elementHandler.WriteWithLength(element!, buf, lengthCache, parameter, async, cancellationToken); - } -} diff --git a/src/Npgsql/Internal/TypeHandlers/BitStringHandler.cs b/src/Npgsql/Internal/TypeHandlers/BitStringHandler.cs deleted file mode 100644 index b448463343..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/BitStringHandler.cs +++ /dev/null @@ -1,271 +0,0 @@ -using System; -using System.Collections; -using System.Collections.Specialized; -using System.Diagnostics; -using System.Linq; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; - -namespace Npgsql.Internal.TypeHandlers; - -/// -/// A type handler for the PostgreSQL bit string data type. -/// -/// -/// See https://www.postgresql.org/docs/current/static/datatype-bit.html. -/// -/// Note that for BIT(1), this handler will return a bool by default, to align with SQLClient -/// (see discussion https://github.com/npgsql/npgsql/pull/362#issuecomment-59622101). -/// -/// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it -/// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. -/// Use it at your own risk. -/// -public partial class BitStringHandler : NpgsqlTypeHandler, - INpgsqlTypeHandler, INpgsqlTypeHandler, INpgsqlTypeHandler -{ - public BitStringHandler(PostgresType pgType) : base(pgType) {} - - public override Type GetFieldType(FieldDescription? 
fieldDescription = null) - => fieldDescription != null && fieldDescription.TypeModifier == 1 ? typeof(bool) : typeof(BitArray); - - // BitString requires a special array handler which returns bool or BitArray - /// - public override NpgsqlTypeHandler CreateArrayHandler(PostgresArrayType pgArrayType, ArrayNullabilityMode arrayNullabilityMode) - => new BitStringArrayHandler(pgArrayType, this, arrayNullabilityMode); - - #region Read - - /// - public override async ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - { - await buf.Ensure(4, async); - var numBits = buf.ReadInt32(); - var result = new BitArray(numBits); - var bytesLeft = len - 4; // Remove leading number of bits - if (bytesLeft == 0) - return result; - - var bitNo = 0; - while (true) - { - var iterationEndPos = bytesLeft > buf.ReadBytesLeft - ? bytesLeft - buf.ReadBytesLeft - : 1; - - for (; bytesLeft > iterationEndPos; bytesLeft--) - { - // ReSharper disable ShiftExpressionRealShiftCountIsZero - var chunk = buf.ReadByte(); - result[bitNo++] = (chunk & (1 << 7)) != 0; - result[bitNo++] = (chunk & (1 << 6)) != 0; - result[bitNo++] = (chunk & (1 << 5)) != 0; - result[bitNo++] = (chunk & (1 << 4)) != 0; - result[bitNo++] = (chunk & (1 << 3)) != 0; - result[bitNo++] = (chunk & (1 << 2)) != 0; - result[bitNo++] = (chunk & (1 << 1)) != 0; - result[bitNo++] = (chunk & (1 << 0)) != 0; - } - - if (bytesLeft == 1) - break; - - Debug.Assert(buf.ReadBytesLeft == 0); - await buf.Ensure(Math.Min(bytesLeft, buf.Size), async); - } - - if (bitNo < result.Length) - { - var remainder = result.Length - bitNo; - await buf.Ensure(1, async); - var lastChunk = buf.ReadByte(); - for (var i = 7; i >= 8 - remainder; i--) - result[bitNo++] = (lastChunk & (1 << i)) != 0; - } - - return result; - } - - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? 
fieldDescription) - { - if (len > 4 + 4) - throw new InvalidCastException("Can't read PostgreSQL bitstring with more than 32 bits into BitVector32"); - - await buf.Ensure(4 + 4, async); - - var numBits = buf.ReadInt32(); - return numBits == 0 - ? new BitVector32(0) - : new BitVector32(buf.ReadInt32()); - } - - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - { - await buf.Ensure(5, async); - var bitLen = buf.ReadInt32(); - if (bitLen != 1) - throw new InvalidCastException("Can't convert a BIT(N) type to bool, only BIT(1)"); - var b = buf.ReadByte(); - return (b & 128) != 0; - } - - ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => throw new NotSupportedException("Only writing string to PostgreSQL bitstring is supported, no reading."); - - public override async ValueTask ReadAsObject(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - => fieldDescription?.TypeModifier == 1 - ? await Read(buf, len, async, fieldDescription) - : await Read(buf, len, async, fieldDescription); - - #endregion - - #region Write - - /// - public override int ValidateAndGetLength(BitArray value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => 4 + (value.Length + 7) / 8; - - /// - public int ValidateAndGetLength(BitVector32 value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => value.Data == 0 ? 4 : 8; - - /// - public int ValidateAndGetLength(bool value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => 5; - - /// - public int ValidateAndGetLength(string value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter) - { - if (value.Any(c => c != '0' && c != '1')) - throw new FormatException("Cannot interpret as ASCII BitString: " + value); - return 4 + (value.Length + 7) / 8; - } - - /// - public override async Task Write(BitArray value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - // Initial bitlength byte - if (buf.WriteSpaceLeft < 4) - await buf.Flush(async, cancellationToken); - buf.WriteInt32(value.Length); - - var byteLen = (value.Length + 7) / 8; - var pos = 0; - while (true) - { - var endPos = pos + Math.Min(byteLen - pos, buf.WriteSpaceLeft); - for (; pos < endPos; pos++) - { - var bitPos = pos*8; - var b = 0; - for (var i = 0; i < Math.Min(8, value.Length - bitPos); i++) - b += (value[bitPos + i] ? 1 : 0) << (8 - i - 1); - buf.WriteByte((byte)b); - } - - if (pos == byteLen) - return; - await buf.Flush(async, cancellationToken); - } - } - - /// - public async Task Write(BitVector32 value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - if (buf.WriteSpaceLeft < 8) - await buf.Flush(async, cancellationToken); - - if (value.Data == 0) - buf.WriteInt32(0); - else - { - buf.WriteInt32(32); - buf.WriteInt32(value.Data); - } - } - - /// - public async Task Write(bool value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - if (buf.WriteSpaceLeft < 5) - await buf.Flush(async, cancellationToken); - buf.WriteInt32(1); - buf.WriteByte(value ? (byte)0x80 : (byte)0); - } - - /// - public async Task Write(string value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter, bool async, CancellationToken cancellationToken = default) - { - // Initial bitlength byte - if (buf.WriteSpaceLeft < 4) - await buf.Flush(async, cancellationToken); - buf.WriteInt32(value.Length); - - var pos = 0; - var byteLen = (value.Length + 7) / 8; - var bytePos = 0; - - while (true) - { - var endBytePos = bytePos + Math.Min(byteLen - bytePos - 1, buf.WriteSpaceLeft); - - for (; bytePos < endBytePos; bytePos++) - { - var b = 0; - b += (value[pos++] - '0') << 7; - b += (value[pos++] - '0') << 6; - b += (value[pos++] - '0') << 5; - b += (value[pos++] - '0') << 4; - b += (value[pos++] - '0') << 3; - b += (value[pos++] - '0') << 2; - b += (value[pos++] - '0') << 1; - b += (value[pos++] - '0'); - buf.WriteByte((byte)b); - } - - if (bytePos >= byteLen - 1) - break; - await buf.Flush(async, cancellationToken); - } - - if (pos < value.Length) - { - if (buf.WriteSpaceLeft < 1) - await buf.Flush(async, cancellationToken); - - var remainder = value.Length - pos; - var lastChunk = 0; - for (var i = 7; i >= 8 - remainder; i--) - lastChunk += (value[pos++] - '0') << i; - buf.WriteByte((byte)lastChunk); - } - } - - #endregion -} - -/// -/// A special handler for arrays of bit strings. -/// Differs from the standard array handlers in that it returns arrays of bool for BIT(1) and arrays -/// of BitArray otherwise (just like the scalar BitStringHandler does). -/// -/// -/// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it -/// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. -/// Use it at your own risk. 
-/// -public class BitStringArrayHandler : ArrayHandler -{ - /// - public BitStringArrayHandler(PostgresType postgresType, BitStringHandler elementHandler, ArrayNullabilityMode arrayNullabilityMode) - : base(postgresType, elementHandler, arrayNullabilityMode) - { } - - public override ValueTask ReadAsObject(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - => fieldDescription?.TypeModifier == 1 - ? base.ReadAsObject(typeof(bool), buf, len, async, fieldDescription) - : base.ReadAsObject(buf, len, async, fieldDescription); -} diff --git a/src/Npgsql/Internal/TypeHandlers/BoolHandler.cs b/src/Npgsql/Internal/TypeHandlers/BoolHandler.cs deleted file mode 100644 index c33004c701..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/BoolHandler.cs +++ /dev/null @@ -1,32 +0,0 @@ -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; - -namespace Npgsql.Internal.TypeHandlers; - -/// -/// A type handler for the PostgreSQL bool data type. -/// -/// -/// See https://www.postgresql.org/docs/current/static/datatype-boolean.html. -/// -/// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it -/// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. -/// Use it at your own risk. -/// -public partial class BoolHandler : NpgsqlSimpleTypeHandler -{ - public BoolHandler(PostgresType pgType) : base(pgType) {} - - /// - public override bool Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => buf.ReadByte() != 0; - - /// - public override int ValidateAndGetLength(bool value, NpgsqlParameter? parameter) - => 1; - - /// - public override void Write(bool value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => buf.WriteByte(value ? 
(byte)1 : (byte)0); -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/ByteaHandler.cs b/src/Npgsql/Internal/TypeHandlers/ByteaHandler.cs deleted file mode 100644 index 785250989e..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/ByteaHandler.cs +++ /dev/null @@ -1,148 +0,0 @@ -using System; -using System.IO; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; - -namespace Npgsql.Internal.TypeHandlers; - -/// -/// A type handler for the PostgreSQL bytea data type. -/// -/// -/// See https://www.postgresql.org/docs/current/static/datatype-binary.html. -/// -/// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it -/// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. -/// Use it at your own risk. -/// -public partial class ByteaHandler : NpgsqlTypeHandler, INpgsqlTypeHandler>, INpgsqlTypeHandler, INpgsqlTypeHandler>, INpgsqlTypeHandler> -{ - public ByteaHandler(PostgresType pgType) : base(pgType) {} - - /// - public override async ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - { - var bytes = new byte[len]; - var pos = 0; - while (true) - { - var toRead = Math.Min(len - pos, buf.ReadBytesLeft); - buf.ReadBytes(bytes, pos, toRead); - pos += toRead; - if (pos == len) - break; - await buf.ReadMore(async); - } - return bytes; - } - - ValueTask> INpgsqlTypeHandler>.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => throw new NotSupportedException("Only writing ArraySegment to PostgreSQL bytea is supported, no reading."); - - ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? 
fieldDescription) - => throw new NotSupportedException("Reading a PostgreSQL bytea as a Stream is unsupported, use NpgsqlDataReader.GetStream() instead.."); - - int ValidateAndGetLength(int bufferLen, NpgsqlParameter? parameter) - => parameter == null || parameter.Size <= 0 || parameter.Size >= bufferLen - ? bufferLen - : parameter.Size; - - int ValidateAndGetLength(Stream stream, NpgsqlParameter? parameter) - { - if (parameter != null && parameter.Size > 0) - return parameter.Size; - - if (!stream.CanSeek) - throw new NpgsqlException("Cannot write a stream of bytes. Either provide a positive size, or a seekable stream."); - - try - { - return (int)(stream.Length - stream.Position); - } - catch (Exception ex) - { - throw new NpgsqlException("The remaining bytes in the provided Stream exceed the maximum length. The vaule may be truncated by setting NpgsqlParameter.Size.", ex); - } - } - - /// - public override int ValidateAndGetLength(byte[] value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength(value.Length, parameter); - - /// - public int ValidateAndGetLength(ArraySegment value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength(value.Count, parameter); - - /// - public int ValidateAndGetLength(Stream value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength(value, parameter); - - /// - public override Task Write(byte[] value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => Write(value, buf, 0, ValidateAndGetLength(value.Length, parameter), async, cancellationToken); - - /// - public Task Write(ArraySegment value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => value.Array is null ? 
Task.CompletedTask : Write(value.Array, buf, value.Offset, ValidateAndGetLength(value.Count, parameter), async, cancellationToken); - - /// - public Task Write(Stream value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => Write(value, buf, ValidateAndGetLength(value, parameter), async, cancellationToken); - - async Task Write(byte[] value, NpgsqlWriteBuffer buf, int offset, int count, bool async, CancellationToken cancellationToken = default) - { - // The entire segment fits in our buffer, copy it as usual. - if (count <= buf.WriteSpaceLeft) - { - buf.WriteBytes(value, offset, count); - return; - } - - // The segment is larger than our buffer. Flush whatever is currently in the buffer and - // write the array directly to the socket. - await buf.Flush(async, cancellationToken); - await buf.DirectWrite(new ReadOnlyMemory(value, offset, count), async, cancellationToken); - } - - Task Write(Stream value, NpgsqlWriteBuffer buf, int count, bool async, CancellationToken cancellationToken = default) - => buf.WriteStreamRaw(value, count, async, cancellationToken); - - /// - public int ValidateAndGetLength(Memory value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength(value.Length, parameter); - - /// - public int ValidateAndGetLength(ReadOnlyMemory value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength(value.Length, parameter); - - /// - public async Task Write(ReadOnlyMemory value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - if (parameter != null && parameter.Size > 0 && parameter.Size < value.Length) - value = value.Slice(0, parameter.Size); - - // The entire segment fits in our buffer, copy it into the buffer as usual. 
- if (value.Length <= buf.WriteSpaceLeft) - { - buf.WriteBytes(value.Span); - return; - } - - // The segment is larger than our buffer. Perform a direct write, flushing whatever is currently in the buffer - // and then writing the array directly to the socket. - await buf.DirectWrite(value, async, cancellationToken); - } - - /// - public Task Write(Memory value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => Write((ReadOnlyMemory)value, buf, lengthCache, parameter, async, cancellationToken); - - ValueTask> INpgsqlTypeHandler>.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescriptioncancellationToken) - => throw new NotSupportedException("Only writing ReadOnlyMemory to PostgreSQL bytea is supported, no reading."); - - ValueTask> INpgsqlTypeHandler>.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => throw new NotSupportedException("Only writing Memory to PostgreSQL bytea is supported, no reading."); -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/CompositeHandlers/ByReference.cs b/src/Npgsql/Internal/TypeHandlers/CompositeHandlers/ByReference.cs deleted file mode 100644 index e5f02bddbe..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/CompositeHandlers/ByReference.cs +++ /dev/null @@ -1,10 +0,0 @@ - -// Only used for value types, but can't constrain because MappedCompositeHandler isn't constrained -#nullable disable - -namespace Npgsql.Internal.TypeHandlers.CompositeHandlers; - -sealed class ByReference -{ - public T Value; -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/CompositeHandlers/CompositeConstructorHandler.cs b/src/Npgsql/Internal/TypeHandlers/CompositeHandlers/CompositeConstructorHandler.cs deleted file mode 100644 index b1b633748b..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/CompositeHandlers/CompositeConstructorHandler.cs +++ 
/dev/null @@ -1,62 +0,0 @@ -using System; -using System.Reflection; -using System.Threading.Tasks; -using Npgsql.PostgresTypes; - -namespace Npgsql.Internal.TypeHandlers.CompositeHandlers; - -class CompositeConstructorHandler -{ - public PostgresType PostgresType { get; } - public ConstructorInfo ConstructorInfo { get; } - public CompositeParameterHandler[] Handlers { get; } - - protected CompositeConstructorHandler(PostgresType postgresType, ConstructorInfo constructorInfo, CompositeParameterHandler[] handlers) - { - PostgresType = postgresType; - ConstructorInfo = constructorInfo; - Handlers = handlers; - } - - public virtual async ValueTask Read(NpgsqlReadBuffer buffer, bool async) - { - await buffer.Ensure(sizeof(int), async); - - var fieldCount = buffer.ReadInt32(); - if (fieldCount != Handlers.Length) - throw new InvalidOperationException($"pg_attributes contains {Handlers.Length} fields for type {PostgresType.DisplayName}, but {fieldCount} fields were received."); - - var args = new object?[Handlers.Length]; - foreach (var handler in Handlers) - args[handler.ParameterPosition] = await handler.Read(buffer, async); - - return (TComposite)ConstructorInfo.Invoke(args); - } - - public static CompositeConstructorHandler Create(PostgresType postgresType, ConstructorInfo constructorInfo, CompositeParameterHandler[] parameterHandlers) - { - const int maxGenericParameters = 8; - - if (parameterHandlers.Length > maxGenericParameters) - return new CompositeConstructorHandler(postgresType, constructorInfo, parameterHandlers); - - var parameterTypes = new Type[1 + maxGenericParameters]; - foreach (var parameterHandler in parameterHandlers) - parameterTypes[1 + parameterHandler.ParameterPosition] = parameterHandler.ParameterType; - - for (var parameterIndex = 1; parameterIndex < parameterTypes.Length; parameterIndex++) - parameterTypes[parameterIndex] ??= typeof(Unused); - - parameterTypes[0] = typeof(TComposite); - return 
(CompositeConstructorHandler)Activator.CreateInstance( - typeof(CompositeConstructorHandler<,,,,,,,,>).MakeGenericType(parameterTypes), - BindingFlags.Instance | BindingFlags.Public, - binder: null, - args: new object[] { postgresType, constructorInfo, parameterHandlers }, - culture: null)!; - } - - readonly struct Unused - { - } -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/CompositeHandlers/CompositeConstructorHandler`.cs b/src/Npgsql/Internal/TypeHandlers/CompositeHandlers/CompositeConstructorHandler`.cs deleted file mode 100644 index b7d8a7b7b0..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/CompositeHandlers/CompositeConstructorHandler`.cs +++ /dev/null @@ -1,66 +0,0 @@ -using System; -using System.Linq; -using System.Linq.Expressions; -using System.Reflection; -using System.Threading.Tasks; -using Npgsql.PostgresTypes; - -namespace Npgsql.Internal.TypeHandlers.CompositeHandlers; - -sealed class CompositeConstructorHandler : CompositeConstructorHandler -{ - delegate TComposite CompositeConstructor(in Arguments args); - - readonly CompositeConstructor _constructor; - - public CompositeConstructorHandler(PostgresType postgresType, ConstructorInfo constructorInfo, CompositeParameterHandler[] parameterHandlers) - : base(postgresType, constructorInfo, parameterHandlers) - { - var parameter = Expression.Parameter(typeof(Arguments).MakeByRefType()); - var fields = Enumerable - .Range(1, parameterHandlers.Length) - .Select(i => Expression.Field(parameter, "Argument" + i)); - - _constructor = Expression - .Lambda(Expression.New(constructorInfo, fields), parameter) - .Compile(); - } - - public override async ValueTask Read(NpgsqlReadBuffer buffer, bool async) - { - await buffer.Ensure(sizeof(int), async); - - var fieldCount = buffer.ReadInt32(); - if (fieldCount != Handlers.Length) - throw new InvalidOperationException($"pg_attributes contains {Handlers.Length} fields for type {PostgresType.DisplayName}, but {fieldCount} fields were 
received."); - - var args = default(Arguments); - - foreach (var handler in Handlers) - switch (handler.ParameterPosition) - { - case 0: args.Argument1 = await handler.Read(buffer, async); break; - case 1: args.Argument2 = await handler.Read(buffer, async); break; - case 2: args.Argument3 = await handler.Read(buffer, async); break; - case 3: args.Argument4 = await handler.Read(buffer, async); break; - case 4: args.Argument5 = await handler.Read(buffer, async); break; - case 5: args.Argument6 = await handler.Read(buffer, async); break; - case 6: args.Argument7 = await handler.Read(buffer, async); break; - case 7: args.Argument8 = await handler.Read(buffer, async); break; - } - - return _constructor(args); - } - - struct Arguments - { - public T1 Argument1; - public T2 Argument2; - public T3 Argument3; - public T4 Argument4; - public T5 Argument5; - public T6 Argument6; - public T7 Argument7; - public T8 Argument8; - } -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/CompositeHandlers/CompositeHandler.cs b/src/Npgsql/Internal/TypeHandlers/CompositeHandlers/CompositeHandler.cs deleted file mode 100644 index 5079b24b1d..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/CompositeHandlers/CompositeHandler.cs +++ /dev/null @@ -1,282 +0,0 @@ -using System; -using System.Diagnostics.CodeAnalysis; -using System.Linq; -using System.Linq.Expressions; -using System.Reflection; -using System.Runtime.CompilerServices; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.Internal.TypeMapping; -using Npgsql.PostgresTypes; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -#region Trimming warning suppressions - -[module: UnconditionalSuppressMessage( - "Composite type mapping currently isn't trimming-safe, and warnings are generated at the MapComposite level.", - "IL2046", Scope = "type", Target = "Npgsql.Internal.TypeHandlers.CompositeHandlers.CompositeHandler")] 
-[module: UnconditionalSuppressMessage( - "Composite type mapping currently isn't trimming-safe, and warnings are generated at the MapComposite level.", - "IL2080", Scope = "type", Target = "Npgsql.Internal.TypeHandlers.CompositeHandlers.CompositeHandler")] -[module: UnconditionalSuppressMessage( - "Composite type mapping currently isn't trimming-safe, and warnings are generated at the MapComposite level.", - "IL2026", Scope = "type", Target = "Npgsql.Internal.TypeHandlers.CompositeHandlers.CompositeHandler")] -[module: UnconditionalSuppressMessage( - "Composite type mapping currently isn't trimming-safe, and warnings are generated at the MapComposite level.", - "IL2090", Scope = "type", Target = "Npgsql.Internal.TypeHandlers.CompositeHandlers.CompositeHandler")] -[module: UnconditionalSuppressMessage( - "Composite type mapping currently isn't trimming-safe, and warnings are generated at the MapComposite level.", - "IL2087", Scope = "type", Target = "Npgsql.Internal.TypeHandlers.CompositeHandlers.CompositeHandler")] -[module: UnconditionalSuppressMessage( - "Composite type mapping currently isn't trimming-safe, and warnings are generated at the MapComposite level.", - "IL2055", Scope = "type", Target = "Npgsql.Internal.TypeHandlers.CompositeHandlers.CompositeHandler")] -[module: UnconditionalSuppressMessage( - "Composite type mapping currently isn't trimming-safe, and warnings are generated at the MapComposite level.", - "IL2077", Scope = "type", Target = "Npgsql.Internal.TypeHandlers.CompositeHandlers.CompositeHandler")] - -#endregion - -namespace Npgsql.Internal.TypeHandlers.CompositeHandlers; - -sealed partial class CompositeHandler : NpgsqlTypeHandler, ICompositeHandler -{ - readonly TypeMapper _typeMapper; - readonly INpgsqlNameTranslator _nameTranslator; - - Func? _constructor; - CompositeConstructorHandler? 
_constructorHandler; - CompositeMemberHandler[] _memberHandlers = null!; - - public Type CompositeType => typeof(T); - - public CompositeHandler(PostgresCompositeType postgresType, TypeMapper typeMapper, INpgsqlNameTranslator nameTranslator) - : base(postgresType) - { - _typeMapper = typeMapper; - _nameTranslator = nameTranslator; - } - - public override ValueTask Read(NpgsqlReadBuffer buffer, int length, bool async, FieldDescription? fieldDescription = null) - { - Initialize(); - - return _constructorHandler is null - ? ReadUsingMemberHandlers(buffer, async) - : _constructorHandler.Read(buffer, async); - - async ValueTask ReadUsingMemberHandlers(NpgsqlReadBuffer buffer, bool async) - { - await buffer.Ensure(sizeof(int), async); - - var fieldCount = buffer.ReadInt32(); - if (fieldCount != _memberHandlers.Length) - throw new InvalidOperationException($"pg_attributes contains {_memberHandlers.Length} fields for type {PgDisplayName}, but {fieldCount} fields were received."); - - if (IsValueType.Value) - { - var composite = new ByReference { Value = _constructor!() }; - foreach (var member in _memberHandlers) - await member.Read(composite, buffer, async); - - return composite.Value; - } - else - { - var composite = _constructor!(); - foreach (var member in _memberHandlers) - await member.Read(composite, buffer, async); - - return composite; - } - } - } - - public override async Task Write(T value, NpgsqlWriteBuffer buffer, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - Initialize(); - - if (buffer.WriteSpaceLeft < sizeof(int)) - await buffer.Flush(async, cancellationToken); - - buffer.WriteInt32(_memberHandlers.Length); - - foreach (var member in _memberHandlers) - await member.Write(value, buffer, lengthCache, async, cancellationToken); - } - - public override int ValidateAndGetLength(T value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter) - { - Initialize(); - - lengthCache ??= new NpgsqlLengthCache(1); - if (lengthCache.IsPopulated) - return lengthCache.Get(); - - // Leave empty slot for the entire composite type, and go ahead an populate the element slots - var position = lengthCache.Position; - lengthCache.Set(0); - - // number of fields + (type oid + field length) * member count - var length = sizeof(int) + sizeof(int) * 2 * _memberHandlers.Length; - foreach (var member in _memberHandlers) - length += member.ValidateAndGetLength(value, ref lengthCache); - - return lengthCache.Lengths[position] = length; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - void Initialize() - { - if (_memberHandlers is null) - InitializeCore(); - - void InitializeCore() - { - var pgType = (PostgresCompositeType)PostgresType; - - _memberHandlers = CreateMemberHandlers(pgType, _typeMapper, _nameTranslator); - _constructorHandler = CreateConstructorHandler(pgType, _typeMapper, _nameTranslator); - _constructor = _constructorHandler is null - ? Expression - .Lambda>(Expression.New(typeof(T))) - .Compile() - : null; - } - } - - static CompositeConstructorHandler? CreateConstructorHandler(PostgresCompositeType pgType, TypeMapper typeMapper, INpgsqlNameTranslator nameTranslator) - { - var pgFields = pgType.Fields; - var clrType = typeof(T); - - ConstructorInfo? clrDefaultConstructor = null; - - foreach (var clrConstructor in clrType.GetConstructors()) - { - var clrParameters = clrConstructor.GetParameters(); - if (clrParameters.Length != pgFields.Count) - { - if (clrParameters.Length == 0) - clrDefaultConstructor = clrConstructor; - - continue; - } - - var clrParameterHandlerCount = 0; - var clrParametersMapped = new ParameterInfo[pgFields.Count]; - - foreach (var clrParameter in clrParameters) - { - var attr = clrParameter.GetCustomAttribute(); - var name = attr?.PgName ?? (clrParameter.Name is string clrName ? 
nameTranslator.TranslateMemberName(clrName) : null); - if (name is null) - break; - - for (var pgFieldIndex = pgFields.Count - 1; pgFieldIndex >= 0; --pgFieldIndex) - { - var pgField = pgFields[pgFieldIndex]; - if (pgField.Name != name) - continue; - - if (clrParametersMapped[pgFieldIndex] != null) - throw new AmbiguousMatchException($"Multiple constructor parameters are mapped to the '{pgField.Name}' field."); - - clrParameterHandlerCount++; - clrParametersMapped[pgFieldIndex] = clrParameter; - - break; - } - } - - if (clrParameterHandlerCount < pgFields.Count) - continue; - - var clrParameterHandlers = new CompositeParameterHandler[pgFields.Count]; - for (var pgFieldIndex = 0; pgFieldIndex < pgFields.Count; ++pgFieldIndex) - { - var pgField = pgFields[pgFieldIndex]; - - if (!typeMapper.TryResolveByOID(pgField.Type.OID, out var handler)) - throw new NpgsqlException($"PostgreSQL composite type {pgType.DisplayName} has field {pgField.Type.DisplayName} with an unknown type (OID = {pgField.Type.OID})."); - - var clrParameter = clrParametersMapped[pgFieldIndex]; - var clrParameterHandlerType = typeof(CompositeParameterHandler<>) - .MakeGenericType(clrParameter.ParameterType); - - clrParameterHandlers[pgFieldIndex] = (CompositeParameterHandler)Activator.CreateInstance( - clrParameterHandlerType, - BindingFlags.Instance | BindingFlags.Public, - binder: null, - args: new object[] { handler, clrParameter }, - culture: null)!; - } - - return CompositeConstructorHandler.Create(pgType, clrConstructor, clrParameterHandlers); - } - - if (clrDefaultConstructor is null && !clrType.IsValueType) - throw new InvalidOperationException($"No parameterless constructor defined for type '{clrType}'."); - - return null; - } - - static CompositeMemberHandler[] CreateMemberHandlers(PostgresCompositeType pgType, TypeMapper typeMapper, INpgsqlNameTranslator nameTranslator) - { - var pgFields = pgType.Fields; - - var clrType = typeof(T); - var clrMemberHandlers = new 
CompositeMemberHandler[pgFields.Count]; - var clrMemberHandlerCount = 0; - var clrMemberHandlerType = IsValueType.Value - ? typeof(CompositeStructMemberHandler<,>) - : typeof(CompositeClassMemberHandler<,>); - - foreach (var clrProperty in clrType.GetProperties(BindingFlags.Instance | BindingFlags.Public)) - CreateMemberHandler(clrProperty, clrProperty.PropertyType); - - foreach (var clrField in clrType.GetFields(BindingFlags.Instance | BindingFlags.Public)) - CreateMemberHandler(clrField, clrField.FieldType); - - if (clrMemberHandlerCount != pgFields.Count) - { - var notMappedFields = string.Join(", ", clrMemberHandlers - .Select((member, memberIndex) => member == null ? $"'{pgFields[memberIndex].Name}'" : null) - .Where(member => member != null)); - throw new InvalidOperationException($"PostgreSQL composite type {pgType.DisplayName} contains fields {notMappedFields} which could not match any on CLR type {clrType.Name}"); - } - - return clrMemberHandlers; - - void CreateMemberHandler(MemberInfo clrMember, Type clrMemberType) - { - var attr = clrMember.GetCustomAttribute(); - var name = attr?.PgName ?? 
nameTranslator.TranslateMemberName(clrMember.Name); - - for (var pgFieldIndex = pgFields.Count - 1; pgFieldIndex >= 0; --pgFieldIndex) - { - var pgField = pgFields[pgFieldIndex]; - if (pgField.Name != name) - continue; - - if (clrMemberHandlers[pgFieldIndex] != null) - throw new AmbiguousMatchException($"Multiple class members are mapped to the '{pgField.Name}' field."); - - if (!typeMapper.TryResolveByOID(pgField.Type.OID, out var handler)) - throw new NpgsqlException($"PostgreSQL composite type {pgType.DisplayName} has field {pgField.Type.DisplayName} with an unknown type (OID = {pgField.Type.OID})."); - - clrMemberHandlerCount++; - clrMemberHandlers[pgFieldIndex] = (CompositeMemberHandler)Activator.CreateInstance( - clrMemberHandlerType.MakeGenericType(clrType, clrMemberType), - BindingFlags.Instance | BindingFlags.Public, - binder: null, - args: new object[] { clrMember, pgField.Type, handler }, - culture: null)!; - - break; - } - } - } -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/CompositeHandlers/CompositeMemberHandler.cs b/src/Npgsql/Internal/TypeHandlers/CompositeHandlers/CompositeMemberHandler.cs deleted file mode 100644 index 48d57e9c82..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/CompositeHandlers/CompositeMemberHandler.cs +++ /dev/null @@ -1,28 +0,0 @@ -using System.Diagnostics.CodeAnalysis; -using System.Reflection; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; - -namespace Npgsql.Internal.TypeHandlers.CompositeHandlers; - -abstract class CompositeMemberHandler -{ - public MemberInfo MemberInfo { get; } - public PostgresType PostgresType { get; } - - protected CompositeMemberHandler(MemberInfo memberInfo, PostgresType postgresType) - { - MemberInfo = memberInfo; - PostgresType = postgresType; - } - - public abstract ValueTask Read(TComposite composite, NpgsqlReadBuffer buffer, bool async); - - public abstract ValueTask Read(ByReference 
composite, NpgsqlReadBuffer buffer, bool async); - - public abstract Task Write(TComposite composite, NpgsqlWriteBuffer buffer, NpgsqlLengthCache? lengthCache, bool async, CancellationToken cancellationToken = default); - - public abstract int ValidateAndGetLength(TComposite composite, [NotNullIfNotNull("lengthCache")] ref NpgsqlLengthCache? lengthCache); -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/CompositeHandlers/CompositeMemberHandlerOfClass.cs b/src/Npgsql/Internal/TypeHandlers/CompositeHandlers/CompositeMemberHandlerOfClass.cs deleted file mode 100644 index 0593e4d67e..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/CompositeHandlers/CompositeMemberHandlerOfClass.cs +++ /dev/null @@ -1,105 +0,0 @@ -using System; -using System.Diagnostics; -using System.Linq.Expressions; -using System.Reflection; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; - -namespace Npgsql.Internal.TypeHandlers.CompositeHandlers; - -sealed class CompositeClassMemberHandler : CompositeMemberHandler - where TComposite : class -{ - delegate TMember GetMember(TComposite composite); - delegate void SetMember(TComposite composite, TMember value); - - readonly GetMember? _get; - readonly SetMember? 
_set; - readonly NpgsqlTypeHandler _handler; - - public CompositeClassMemberHandler(FieldInfo fieldInfo, PostgresType postgresType, NpgsqlTypeHandler handler) - : base(fieldInfo, postgresType) - { - var composite = Expression.Parameter(typeof(TComposite), "composite"); - var value = Expression.Parameter(typeof(TMember), "value"); - - _get = Expression - .Lambda(Expression.Field(composite, fieldInfo), composite) - .Compile(); - _set = Expression - .Lambda(Expression.Assign(Expression.Field(composite, fieldInfo), value), composite, value) - .Compile(); - _handler = handler; - } - - public CompositeClassMemberHandler(PropertyInfo propertyInfo, PostgresType postgresType, NpgsqlTypeHandler handler) - : base(propertyInfo, postgresType) - { - var getMethod = propertyInfo.GetGetMethod(); - if (getMethod != null) - _get = (GetMember)Delegate.CreateDelegate(typeof(GetMember), getMethod); - - var setMethod = propertyInfo.GetSetMethod(); - if (setMethod != null) - _set = (SetMember)Delegate.CreateDelegate(typeof(SetMember), setMethod); - - Debug.Assert(setMethod != null || getMethod != null); - - _handler = handler; - } - - public override async ValueTask Read(TComposite composite, NpgsqlReadBuffer buffer, bool async) - { - if (_set == null) - ThrowHelper.ThrowInvalidOperationException_NoPropertySetter(typeof(TComposite), MemberInfo); - - await buffer.Ensure(sizeof(uint) + sizeof(int), async); - - var oid = buffer.ReadUInt32(); - Debug.Assert(oid == PostgresType.OID); - - var length = buffer.ReadInt32(); - if (length == -1) - return; - - var value = NullableHandler.Exists - ? await NullableHandler.ReadAsync(_handler, buffer, length, async) - : await _handler.Read(buffer, length, async); - - _set(composite, value); - } - - public override ValueTask Read(ByReference composite, NpgsqlReadBuffer buffer, bool async) - => throw new NotSupportedException(); - - public override async Task Write(TComposite composite, NpgsqlWriteBuffer buffer, NpgsqlLengthCache? 
lengthCache, bool async, CancellationToken cancellationToken = default) - { - if (_get == null) - ThrowHelper.ThrowInvalidOperationException_NoPropertyGetter(typeof(TComposite), MemberInfo); - - if (buffer.WriteSpaceLeft < sizeof(int)) - await buffer.Flush(async, cancellationToken); - - buffer.WriteUInt32(PostgresType.OID); - if (NullableHandler.Exists) - await NullableHandler.WriteAsync(_handler, _get(composite), buffer, lengthCache, null, async, cancellationToken); - else - await _handler.WriteWithLength(_get(composite), buffer, lengthCache, null, async, cancellationToken); - } - - public override int ValidateAndGetLength(TComposite composite, ref NpgsqlLengthCache? lengthCache) - { - if (_get == null) - ThrowHelper.ThrowInvalidOperationException_NoPropertyGetter(typeof(TComposite), MemberInfo); - - var value = _get(composite); - if (value is null) - return 0; - - return NullableHandler.Exists - ? NullableHandler.ValidateAndGetLength(_handler, value, ref lengthCache, null) - : _handler.ValidateAndGetLength(value, ref lengthCache, null); - } -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/CompositeHandlers/CompositeMemberHandlerOfStruct.cs b/src/Npgsql/Internal/TypeHandlers/CompositeHandlers/CompositeMemberHandlerOfStruct.cs deleted file mode 100644 index 2fa1d48ca3..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/CompositeHandlers/CompositeMemberHandlerOfStruct.cs +++ /dev/null @@ -1,109 +0,0 @@ -using System; -using System.Diagnostics; -using System.Linq.Expressions; -using System.Reflection; -using System.Runtime.CompilerServices; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; - -namespace Npgsql.Internal.TypeHandlers.CompositeHandlers; - -sealed class CompositeStructMemberHandler : CompositeMemberHandler - where TComposite : struct -{ - delegate TMember GetMember(ref TComposite composite); - delegate void SetMember(ref TComposite composite, TMember value); 
- - readonly GetMember? _get; - readonly SetMember? _set; - readonly NpgsqlTypeHandler _handler; - - public CompositeStructMemberHandler(FieldInfo fieldInfo, PostgresType postgresType, NpgsqlTypeHandler handler) - : base(fieldInfo, postgresType) - { - var composite = Expression.Parameter(typeof(TComposite).MakeByRefType(), "composite"); - var value = Expression.Parameter(typeof(TMember), "value"); - - _get = Expression - .Lambda(Expression.Field(composite, fieldInfo), composite) - .Compile(); - _set = Expression - .Lambda(Expression.Assign(Expression.Field(composite, fieldInfo), value), composite, value) - .Compile(); - _handler = handler; - } - - public CompositeStructMemberHandler(PropertyInfo propertyInfo, PostgresType postgresType, NpgsqlTypeHandler handler) - : base(propertyInfo, postgresType) - { - var getMethod = propertyInfo.GetGetMethod(); - if (getMethod != null) - _get = (GetMember)Delegate.CreateDelegate(typeof(GetMember), getMethod); - - var setMethod = propertyInfo.GetSetMethod(); - if (setMethod != null) - _set = (SetMember)Delegate.CreateDelegate(typeof(SetMember), setMethod); - - Debug.Assert(setMethod != null || getMethod != null); - - _handler = handler; - } - - public override ValueTask Read(TComposite composite, NpgsqlReadBuffer buffer, bool async) - => throw new NotSupportedException(); - - public override async ValueTask Read(ByReference composite, NpgsqlReadBuffer buffer, bool async) - { - if (_set == null) - ThrowHelper.ThrowInvalidOperationException_NoPropertySetter(typeof(TComposite), MemberInfo); - - await buffer.Ensure(sizeof(uint) + sizeof(int), async); - - var oid = buffer.ReadUInt32(); - Debug.Assert(oid == PostgresType.OID); - - var length = buffer.ReadInt32(); - if (length == -1) - return; - - var value = NullableHandler.Exists - ? 
await NullableHandler.ReadAsync(_handler, buffer, length, async) - : await _handler.Read(buffer, length, async); - - Set(composite, value); - } - - public override async Task Write(TComposite composite, NpgsqlWriteBuffer buffer, NpgsqlLengthCache? lengthCache, bool async, CancellationToken cancellationToken = default) - { - if (_get == null) - ThrowHelper.ThrowInvalidOperationException_NoPropertyGetter(typeof(TComposite), MemberInfo); - - if (buffer.WriteSpaceLeft < sizeof(int)) - await buffer.Flush(async, cancellationToken); - - buffer.WriteUInt32(PostgresType.OID); - await (NullableHandler.Exists - ? NullableHandler.WriteAsync(_handler, _get(ref composite), buffer, lengthCache, null, async, cancellationToken) - : _handler.WriteWithLength(_get(ref composite), buffer, lengthCache, null, async, cancellationToken)); - } - - public override int ValidateAndGetLength(TComposite composite, ref NpgsqlLengthCache? lengthCache) - { - if (_get == null) - ThrowHelper.ThrowInvalidOperationException_NoPropertyGetter(typeof(TComposite), MemberInfo); - - var value = _get(ref composite); - if (value is null) - return 0; - - return NullableHandler.Exists - ? 
NullableHandler.ValidateAndGetLength(_handler, value, ref lengthCache, null) - : _handler.ValidateAndGetLength(value, ref lengthCache, null); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - void Set(ByReference composite, TMember value) - => _set!(ref composite.Value, value); -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/CompositeHandlers/CompositeParameterHandler.cs b/src/Npgsql/Internal/TypeHandlers/CompositeHandlers/CompositeParameterHandler.cs deleted file mode 100644 index f99de18bba..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/CompositeHandlers/CompositeParameterHandler.cs +++ /dev/null @@ -1,36 +0,0 @@ -using System; -using System.Reflection; -using System.Threading.Tasks; -using Npgsql.Internal.TypeHandling; - -namespace Npgsql.Internal.TypeHandlers.CompositeHandlers; - -abstract class CompositeParameterHandler -{ - public NpgsqlTypeHandler Handler { get; } - public Type ParameterType { get; } - public int ParameterPosition { get; } - - public CompositeParameterHandler(NpgsqlTypeHandler handler, ParameterInfo parameterInfo) - { - Handler = handler; - ParameterType = parameterInfo.ParameterType; - ParameterPosition = parameterInfo.Position; - } - - public async ValueTask Read(NpgsqlReadBuffer buffer, bool async) - { - await buffer.Ensure(sizeof(uint) + sizeof(int), async); - - var oid = buffer.ReadUInt32(); - var length = buffer.ReadInt32(); - if (length == -1) - return default!; - - return NullableHandler.Exists - ? 
await NullableHandler.ReadAsync(Handler, buffer, length, async) - : await Handler.Read(buffer, length, async); - } - - public abstract ValueTask Read(NpgsqlReadBuffer buffer, bool async); -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/CompositeHandlers/CompositeParameterHandler`.cs b/src/Npgsql/Internal/TypeHandlers/CompositeHandlers/CompositeParameterHandler`.cs deleted file mode 100644 index 6c2d9dab8d..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/CompositeHandlers/CompositeParameterHandler`.cs +++ /dev/null @@ -1,21 +0,0 @@ -using System.Reflection; -using System.Threading.Tasks; -using Npgsql.Internal.TypeHandling; - -namespace Npgsql.Internal.TypeHandlers.CompositeHandlers; - -sealed class CompositeParameterHandler : CompositeParameterHandler -{ - public CompositeParameterHandler(NpgsqlTypeHandler handler, ParameterInfo parameterInfo) - : base(handler, parameterInfo) { } - - public override ValueTask Read(NpgsqlReadBuffer buffer, bool async) - { - var task = Read(buffer, async); - return task.IsCompleted - ? new ValueTask(task.Result) - : AwaitTask(task); - - static async ValueTask AwaitTask(ValueTask task) => await task; - } -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/CompositeHandlers/ICompositeHandler.cs b/src/Npgsql/Internal/TypeHandlers/CompositeHandlers/ICompositeHandler.cs deleted file mode 100644 index 5bb186233b..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/CompositeHandlers/ICompositeHandler.cs +++ /dev/null @@ -1,11 +0,0 @@ -using System; - -namespace Npgsql.Internal.TypeHandlers.CompositeHandlers; - -interface ICompositeHandler -{ - /// - /// The CLR type mapped to the PostgreSQL composite type. 
- /// - Type CompositeType { get; } -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/CompositeHandlers/IsValueType.cs b/src/Npgsql/Internal/TypeHandlers/CompositeHandlers/IsValueType.cs deleted file mode 100644 index 360cae915d..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/CompositeHandlers/IsValueType.cs +++ /dev/null @@ -1,6 +0,0 @@ -namespace Npgsql.Internal.TypeHandlers.CompositeHandlers; - -static class IsValueType -{ - public static readonly bool Value = typeof(T).IsValueType; -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/DateTimeHandlers/DateHandler.cs b/src/Npgsql/Internal/TypeHandlers/DateTimeHandlers/DateHandler.cs deleted file mode 100644 index 0831306a67..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/DateTimeHandlers/DateHandler.cs +++ /dev/null @@ -1,131 +0,0 @@ -using System; -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; -using Npgsql.Properties; -using NpgsqlTypes; -using static Npgsql.Util.Statics; - -namespace Npgsql.Internal.TypeHandlers.DateTimeHandlers; - -/// -/// A type handler for the PostgreSQL date data type. -/// -/// -/// See https://www.postgresql.org/docs/current/static/datatype-datetime.html. -/// -/// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it -/// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. -/// Use it at your own risk. -/// -public partial class DateHandler : NpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler -#if NET6_0_OR_GREATER - , INpgsqlSimpleTypeHandler -#endif -{ - static readonly DateTime BaseValueDateTime = new(2000, 1, 1, 0, 0, 0); - - /// - /// Constructs a - /// - public DateHandler(PostgresType postgresType) : base(postgresType) {} - - #region Read - - /// - public override DateTime Read(NpgsqlReadBuffer buf, int len, FieldDescription? 
fieldDescription = null) - => buf.ReadInt32() switch - { - int.MaxValue => DisableDateTimeInfinityConversions - ? throw new InvalidCastException(NpgsqlStrings.CannotReadInfinityValue) - : DateTime.MaxValue, - int.MinValue => DisableDateTimeInfinityConversions - ? throw new InvalidCastException(NpgsqlStrings.CannotReadInfinityValue) - : DateTime.MinValue, - var value => BaseValueDateTime + TimeSpan.FromDays(value) - }; - - int INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => buf.ReadInt32(); - - #endregion Read - - #region Write - - /// - public override int ValidateAndGetLength(DateTime value, NpgsqlParameter? parameter) => 4; - - /// - public int ValidateAndGetLength(int value, NpgsqlParameter? parameter) => 4; - - /// - public override void Write(DateTime value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - { - if (!DisableDateTimeInfinityConversions) - { - if (value == DateTime.MaxValue) - { - buf.WriteInt32(int.MaxValue); - return; - } - - if (value == DateTime.MinValue) - { - buf.WriteInt32(int.MinValue); - return; - } - } - - buf.WriteInt32((value.Date - BaseValueDateTime).Days); - } - - /// - public void Write(int value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => buf.WriteInt32(value); - - #endregion Write - -#if NET6_0_OR_GREATER - static readonly DateOnly BaseValueDateOnly = new(2000, 1, 1); - - DateOnly INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => buf.ReadInt32() switch - { - int.MaxValue => DisableDateTimeInfinityConversions - ? throw new InvalidCastException(NpgsqlStrings.CannotReadInfinityValue) - : DateOnly.MaxValue, - int.MinValue => DisableDateTimeInfinityConversions - ? throw new InvalidCastException(NpgsqlStrings.CannotReadInfinityValue) - : DateOnly.MinValue, - var value => BaseValueDateOnly.AddDays(value) - }; - - public int ValidateAndGetLength(DateOnly value, NpgsqlParameter? 
parameter) => 4; - - public void Write(DateOnly value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - { - if (!DisableDateTimeInfinityConversions) - { - if (value == DateOnly.MaxValue) - { - buf.WriteInt32(int.MaxValue); - return; - } - - if (value == DateOnly.MinValue) - { - buf.WriteInt32(int.MinValue); - return; - } - } - - buf.WriteInt32(value.DayNumber - BaseValueDateOnly.DayNumber); - } - - public override NpgsqlTypeHandler CreateRangeHandler(PostgresType pgRangeType) - => new RangeHandler(pgRangeType, this); - - public override NpgsqlTypeHandler CreateMultirangeHandler(PostgresMultirangeType pgRangeType) - => new MultirangeHandler(pgRangeType, new RangeHandler(pgRangeType, this)); -#endif -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/DateTimeHandlers/DateTimeUtils.cs b/src/Npgsql/Internal/TypeHandlers/DateTimeHandlers/DateTimeUtils.cs deleted file mode 100644 index 8b702aad12..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/DateTimeHandlers/DateTimeUtils.cs +++ /dev/null @@ -1,63 +0,0 @@ -using System; -using System.Runtime.CompilerServices; -using Npgsql.Properties; -using static Npgsql.Util.Statics; - -namespace Npgsql.Internal.TypeHandlers.DateTimeHandlers; - -static class DateTimeUtils -{ - const long PostgresTimestampOffsetTicks = 630822816000000000L; - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal static DateTime DecodeTimestamp(long value, DateTimeKind kind) - => new(value * 10 + PostgresTimestampOffsetTicks, kind); - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal static long EncodeTimestamp(DateTime value) - // Rounding here would cause problems because we would round up DateTime.MaxValue - // which would make it impossible to retrieve it back from the database, so we just drop the additional precision - => (value.Ticks - PostgresTimestampOffsetTicks) / 10; - - internal static DateTime ReadDateTime(NpgsqlReadBuffer buf, DateTimeKind kind) - { - try - { - return buf.ReadInt64() 
switch - { - long.MaxValue => DisableDateTimeInfinityConversions - ? throw new InvalidCastException(NpgsqlStrings.CannotReadInfinityValue) - : DateTime.MaxValue, - long.MinValue => DisableDateTimeInfinityConversions - ? throw new InvalidCastException(NpgsqlStrings.CannotReadInfinityValue) - : DateTime.MinValue, - var value => DecodeTimestamp(value, kind) - }; - } - catch (ArgumentOutOfRangeException e) - { - throw new InvalidCastException("Out of the range of DateTime (year must be between 1 and 9999)", e); - } - } - - internal static void WriteTimestamp(DateTime value, NpgsqlWriteBuffer buf) - { - if (!DisableDateTimeInfinityConversions) - { - if (value == DateTime.MaxValue) - { - buf.WriteInt64(long.MaxValue); - return; - } - - if (value == DateTime.MinValue) - { - buf.WriteInt64(long.MinValue); - return; - } - } - - var postgresTimestamp = EncodeTimestamp(value); - buf.WriteInt64(postgresTimestamp); - } -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/DateTimeHandlers/IntervalHandler.cs b/src/Npgsql/Internal/TypeHandlers/DateTimeHandlers/IntervalHandler.cs deleted file mode 100644 index 9cce23e486..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/DateTimeHandlers/IntervalHandler.cs +++ /dev/null @@ -1,70 +0,0 @@ -using System; -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; -using Npgsql.Properties; -using NpgsqlTypes; - -namespace Npgsql.Internal.TypeHandlers.DateTimeHandlers; - -/// -/// A type handler for the PostgreSQL date interval type. -/// -/// -/// See https://www.postgresql.org/docs/current/static/datatype-datetime.html. -/// -/// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it -/// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. -/// Use it at your own risk. 
-/// -public partial class IntervalHandler : NpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler -{ - /// - /// Constructs an - /// - public IntervalHandler(PostgresType postgresType) : base(postgresType) {} - - /// - public override TimeSpan Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - { - var microseconds = buf.ReadInt64(); - var days = buf.ReadInt32(); - var months = buf.ReadInt32(); - - if (months > 0) - throw new InvalidCastException(NpgsqlStrings.CannotReadIntervalWithMonthsAsTimeSpan); - - return new(microseconds * 10 + days * TimeSpan.TicksPerDay); - } - - NpgsqlInterval INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - { - var ticks = buf.ReadInt64(); - var day = buf.ReadInt32(); - var month = buf.ReadInt32(); - return new NpgsqlInterval(month, day, ticks); - } - - /// - public override int ValidateAndGetLength(TimeSpan value, NpgsqlParameter? parameter) => 16; - - /// - public int ValidateAndGetLength(NpgsqlInterval value, NpgsqlParameter? parameter) => 16; - - /// - public override void Write(TimeSpan value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - { - var ticksInDay = value.Ticks - TimeSpan.TicksPerDay * value.Days; - - buf.WriteInt64(ticksInDay / 10); - buf.WriteInt32(value.Days); - buf.WriteInt32(0); - } - - public void Write(NpgsqlInterval value, NpgsqlWriteBuffer buf, NpgsqlParameter? 
parameter) - { - buf.WriteInt64(value.Time); - buf.WriteInt32(value.Days); - buf.WriteInt32(value.Months); - } -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/DateTimeHandlers/TimeHandler.cs b/src/Npgsql/Internal/TypeHandlers/DateTimeHandlers/TimeHandler.cs deleted file mode 100644 index f4ec3b689b..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/DateTimeHandlers/TimeHandler.cs +++ /dev/null @@ -1,52 +0,0 @@ -using System; -using System.Data; -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -namespace Npgsql.Internal.TypeHandlers.DateTimeHandlers; - -/// -/// A type handler for the PostgreSQL time data type. -/// -/// -/// See https://www.postgresql.org/docs/current/static/datatype-datetime.html. -/// -/// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it -/// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. -/// Use it at your own risk. -/// -public partial class TimeHandler : NpgsqlSimpleTypeHandler -#if NET6_0_OR_GREATER - , INpgsqlSimpleTypeHandler -#endif -{ - /// - /// Constructs a . - /// - public TimeHandler(PostgresType postgresType) : base(postgresType) {} - - // PostgreSQL time resolution == 1 microsecond == 10 ticks - /// - public override TimeSpan Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => new(buf.ReadInt64() * 10); - - /// - public override int ValidateAndGetLength(TimeSpan value, NpgsqlParameter? parameter) => 8; - - /// - public override void Write(TimeSpan value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => buf.WriteInt64(value.Ticks / 10); - -#if NET6_0_OR_GREATER - TimeOnly INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? 
fieldDescription) - => new(buf.ReadInt64() * 10); - - public int ValidateAndGetLength(TimeOnly value, NpgsqlParameter? parameter) => 8; - - public void Write(TimeOnly value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => buf.WriteInt64(value.Ticks / 10); -#endif -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/DateTimeHandlers/TimeTzHandler.cs b/src/Npgsql/Internal/TypeHandlers/DateTimeHandlers/TimeTzHandler.cs deleted file mode 100644 index 464c4abd01..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/DateTimeHandlers/TimeTzHandler.cs +++ /dev/null @@ -1,53 +0,0 @@ -using System; -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; - -namespace Npgsql.Internal.TypeHandlers.DateTimeHandlers; - -/// -/// A type handler for the PostgreSQL timetz data type. -/// -/// -/// See https://www.postgresql.org/docs/current/static/datatype-datetime.html. -/// -/// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it -/// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. -/// Use it at your own risk. -/// -public partial class TimeTzHandler : NpgsqlSimpleTypeHandler -{ - // Binary Format: int64 expressing microseconds, int32 expressing timezone in seconds, negative - - /// - /// Constructs an . - /// - public TimeTzHandler(PostgresType postgresType) : base(postgresType) {} - - #region Read - - /// - public override DateTimeOffset Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - { - // Adjust from 1 microsecond to 100ns. Time zone (in seconds) is inverted. - var ticks = buf.ReadInt64() * 10; - var offset = new TimeSpan(0, 0, -buf.ReadInt32()); - return new DateTimeOffset(ticks + TimeSpan.TicksPerDay, offset); - } - - #endregion Read - - #region Write - - /// - public override int ValidateAndGetLength(DateTimeOffset value, NpgsqlParameter? 
parameter) => 12; - - /// - public override void Write(DateTimeOffset value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - { - buf.WriteInt64(value.TimeOfDay.Ticks / 10); - buf.WriteInt32(-(int)(value.Offset.Ticks / TimeSpan.TicksPerSecond)); - } - - #endregion Write -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/DateTimeHandlers/TimestampHandler.cs b/src/Npgsql/Internal/TypeHandlers/DateTimeHandlers/TimestampHandler.cs deleted file mode 100644 index 1887318b44..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/DateTimeHandlers/TimestampHandler.cs +++ /dev/null @@ -1,62 +0,0 @@ -using System; -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; -using static Npgsql.Util.Statics; -using static Npgsql.Internal.TypeHandlers.DateTimeHandlers.DateTimeUtils; - -namespace Npgsql.Internal.TypeHandlers.DateTimeHandlers; - -/// -/// A type handler for the PostgreSQL timestamp data type. -/// -/// -/// See https://www.postgresql.org/docs/current/static/datatype-datetime.html. -/// -/// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it -/// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. -/// Use it at your own risk. -/// -public partial class TimestampHandler : NpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler -{ - /// - /// Constructs a . - /// - public TimestampHandler(PostgresType postgresType) : base(postgresType) {} - - #region Read - - /// - public override DateTime Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => ReadDateTime(buf, DateTimeKind.Unspecified); - - long INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => buf.ReadInt64(); - - #endregion Read - - #region Write - - /// - public override int ValidateAndGetLength(DateTime value, NpgsqlParameter? 
parameter) - => value.Kind != DateTimeKind.Utc || LegacyTimestampBehavior - ? 8 - : throw new InvalidCastException( - "Cannot write DateTime with Kind=UTC to PostgreSQL type 'timestamp without time zone', " + - "consider using 'timestamp with time zone'. " + - "Note that it's not possible to mix DateTimes with different Kinds in an array/range. " + - "See the Npgsql.EnableLegacyTimestampBehavior AppContext switch to enable legacy behavior."); - - /// - public int ValidateAndGetLength(long value, NpgsqlParameter? parameter) => 8; - - /// - public override void Write(DateTime value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => WriteTimestamp(value, buf); - - /// - public void Write(long value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => buf.WriteInt64(value); - - #endregion Write -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/DateTimeHandlers/TimestampTzHandler.cs b/src/Npgsql/Internal/TypeHandlers/DateTimeHandlers/TimestampTzHandler.cs deleted file mode 100644 index 66b3397ecb..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/DateTimeHandlers/TimestampTzHandler.cs +++ /dev/null @@ -1,143 +0,0 @@ -using System; -using System.Diagnostics; -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; -using Npgsql.Properties; -using NpgsqlTypes; -using static Npgsql.Util.Statics; -using static Npgsql.Internal.TypeHandlers.DateTimeHandlers.DateTimeUtils; - -namespace Npgsql.Internal.TypeHandlers.DateTimeHandlers; - -/// -/// A type handler for the PostgreSQL timestamptz data type. -/// -/// -/// See https://www.postgresql.org/docs/current/static/datatype-datetime.html. -/// -/// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it -/// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. -/// Use it at your own risk. 
-/// -public partial class TimestampTzHandler : NpgsqlSimpleTypeHandler, - INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler -{ - /// - /// Constructs an . - /// - public TimestampTzHandler(PostgresType postgresType) : base(postgresType) {} - - /// - public override NpgsqlTypeHandler CreateRangeHandler(PostgresType pgRangeType) - => new RangeHandler(pgRangeType, this); - - #region Read - - /// - public override DateTime Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - { - var dateTime = ReadDateTime(buf, DateTimeKind.Utc); - return LegacyTimestampBehavior && (DisableDateTimeInfinityConversions || dateTime != DateTime.MaxValue && dateTime != DateTime.MinValue) - ? dateTime.ToLocalTime() - : dateTime; - } - - DateTimeOffset INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - { - try - { - var value = buf.ReadInt64(); - switch (value) - { - case long.MaxValue: - return DisableDateTimeInfinityConversions - ? throw new InvalidCastException(NpgsqlStrings.CannotReadInfinityValue) - : DateTimeOffset.MaxValue; - case long.MinValue: - return DisableDateTimeInfinityConversions - ? throw new InvalidCastException(NpgsqlStrings.CannotReadInfinityValue) - : DateTimeOffset.MinValue; - default: - var dateTime = DecodeTimestamp(value, DateTimeKind.Utc); - return LegacyTimestampBehavior ? dateTime.ToLocalTime() : dateTime; - } - } - catch (ArgumentOutOfRangeException e) - { - throw new InvalidCastException("Out of the range of DateTime (year must be between 1 and 9999)", e); - } - } - - long INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => buf.ReadInt64(); - - #endregion Read - - #region Write - - /// - public override int ValidateAndGetLength(DateTime value, NpgsqlParameter? parameter) - => value.Kind == DateTimeKind.Utc || - value == DateTime.MinValue || // Allowed since this is default(DateTime) - sent without any timezone conversion. 
- value == DateTime.MaxValue && !DisableDateTimeInfinityConversions || - LegacyTimestampBehavior - ? 8 - : throw new InvalidCastException( - $"Cannot write DateTime with Kind={value.Kind} to PostgreSQL type 'timestamp with time zone', only UTC is supported. " + - "Note that it's not possible to mix DateTimes with different Kinds in an array/range. " + - "See the Npgsql.EnableLegacyTimestampBehavior AppContext switch to enable legacy behavior."); - - /// - public int ValidateAndGetLength(DateTimeOffset value, NpgsqlParameter? parameter) - => value.Offset == TimeSpan.Zero || LegacyTimestampBehavior - ? 8 - : throw new InvalidCastException( - $"Cannot write DateTimeOffset with Offset={value.Offset} to PostgreSQL type 'timestamp with time zone', " + - "only offset 0 (UTC) is supported. " + - "Note that it's not possible to mix DateTimes with different Kinds in an array/range. " + - "See the Npgsql.EnableLegacyTimestampBehavior AppContext switch to enable legacy behavior."); - - /// - public int ValidateAndGetLength(long value, NpgsqlParameter? parameter) => 8; - - /// - public override void Write(DateTime value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - { - if (LegacyTimestampBehavior) - { - switch (value.Kind) - { - case DateTimeKind.Unspecified: - case DateTimeKind.Utc: - break; - case DateTimeKind.Local: - value = value.ToUniversalTime(); - break; - default: - throw new InvalidOperationException($"Internal Npgsql bug: unexpected value {value.Kind} of enum {nameof(DateTimeKind)}. Please file a bug."); - } - } - else - Debug.Assert(value.Kind == DateTimeKind.Utc || value == DateTime.MinValue || value == DateTime.MaxValue); - - WriteTimestamp(value, buf); - } - - /// - public void Write(DateTimeOffset value, NpgsqlWriteBuffer buf, NpgsqlParameter? 
parameter) - { - if (LegacyTimestampBehavior) - value = value.ToUniversalTime(); - - Debug.Assert(value.Offset == TimeSpan.Zero); - - WriteTimestamp(value.DateTime, buf); - } - - /// - public void Write(long value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => buf.WriteInt64(value); - - #endregion Write -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/EnumHandler.cs b/src/Npgsql/Internal/TypeHandlers/EnumHandler.cs deleted file mode 100644 index 2604563790..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/EnumHandler.cs +++ /dev/null @@ -1,74 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Diagnostics; -using System.Linq; -using System.Reflection; -using System.Text; -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; -using NpgsqlTypes; - -namespace Npgsql.Internal.TypeHandlers; - -/// -/// Interface implemented by all concrete handlers which handle enums -/// -interface IEnumHandler -{ - /// - /// The CLR enum type mapped to the PostgreSQL enum - /// - Type EnumType { get; } -} - -sealed partial class EnumHandler : NpgsqlSimpleTypeHandler, IEnumHandler where TEnum : struct, Enum -{ - readonly Dictionary _enumToLabel; - readonly Dictionary _labelToEnum; - - public Type EnumType => typeof(TEnum); - - #region Construction - - internal EnumHandler(PostgresEnumType postgresType, Dictionary enumToLabel, Dictionary labelToEnum) - : base(postgresType) - { - Debug.Assert(typeof(TEnum).GetTypeInfo().IsEnum, "EnumHandler instantiated for non-enum type"); - _enumToLabel = enumToLabel; - _labelToEnum = labelToEnum; - } - - #endregion - - #region Read - - public override TEnum Read(NpgsqlReadBuffer buf, int len, FieldDescription? 
fieldDescription = null) - { - var str = buf.ReadString(len); - var success = _labelToEnum.TryGetValue(str, out var value); - - if (!success) - throw new InvalidCastException($"Received enum value '{str}' from database which wasn't found on enum {typeof(TEnum)}"); - - return value; - } - - #endregion - - #region Write - - public override int ValidateAndGetLength(TEnum value, NpgsqlParameter? parameter) - => _enumToLabel.TryGetValue(value, out var str) - ? Encoding.UTF8.GetByteCount(str) - : throw new InvalidCastException($"Can't write value {value} as enum {typeof(TEnum)}"); - - public override void Write(TEnum value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - { - if (!_enumToLabel.TryGetValue(value, out var str)) - throw new InvalidCastException($"Can't write value {value} as enum {typeof(TEnum)}"); - buf.WriteString(str); - } - - #endregion -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/FullTextSearchHandlers/TsQueryHandler.cs b/src/Npgsql/Internal/TypeHandlers/FullTextSearchHandlers/TsQueryHandler.cs deleted file mode 100644 index 1fefb0f598..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/FullTextSearchHandlers/TsQueryHandler.cs +++ /dev/null @@ -1,291 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Text; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; -using NpgsqlTypes; - -// TODO: Need to work on the nullability here -#nullable disable -#pragma warning disable CS8632 -#pragma warning disable RS0041 - -namespace Npgsql.Internal.TypeHandlers.FullTextSearchHandlers; - -/// -/// A type handler for the PostgreSQL tsquery data type. -/// -/// -/// See https://www.postgresql.org/docs/current/static/datatype-textsearch.html. -/// -/// The type handler API allows customizing Npgsql's behavior in powerful ways. 
However, although it is public, it -/// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. -/// Use it at your own risk. -/// -public partial class TsQueryHandler : NpgsqlTypeHandler, - INpgsqlTypeHandler, INpgsqlTypeHandler, - INpgsqlTypeHandler, INpgsqlTypeHandler, - INpgsqlTypeHandler, INpgsqlTypeHandler -{ - // 1 (type) + 1 (weight) + 1 (is prefix search) + 2046 (max str len) + 1 (null terminator) - const int MaxSingleTokenBytes = 2050; - - public TsQueryHandler(PostgresType pgType) : base(pgType) {} - - #region Read - - /// - public override async ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - { - await buf.Ensure(4, async); - var numTokens = buf.ReadInt32(); - if (numTokens == 0) - return new NpgsqlTsQueryEmpty(); - - NpgsqlTsQuery? value = null; - var nodes = new Stack>(); - len -= 4; - - for (var tokenPos = 0; tokenPos < numTokens; tokenPos++) - { - await buf.Ensure(Math.Min(len, MaxSingleTokenBytes), async); - var readPos = buf.ReadPosition; - - var isOper = buf.ReadByte() == 2; - if (isOper) - { - var operKind = (NpgsqlTsQuery.NodeKind)buf.ReadByte(); - if (operKind == NpgsqlTsQuery.NodeKind.Not) - { - var node = new NpgsqlTsQueryNot(null); - InsertInTree(node, nodes, ref value); - nodes.Push(new Tuple(node, 0)); - } - else - { - var node = operKind switch - { - NpgsqlTsQuery.NodeKind.And => (NpgsqlTsQuery)new NpgsqlTsQueryAnd(null, null), - NpgsqlTsQuery.NodeKind.Or => new NpgsqlTsQueryOr(null, null), - NpgsqlTsQuery.NodeKind.Phrase => new NpgsqlTsQueryFollowedBy(null, buf.ReadInt16(), null), - _ => throw new InvalidOperationException($"Internal Npgsql bug: unexpected value {operKind} of enum {nameof(NpgsqlTsQuery.NodeKind)}. 
Please file a bug.") - }; - - InsertInTree(node, nodes, ref value); - - nodes.Push(new Tuple(node, 1)); - nodes.Push(new Tuple(node, 2)); - } - } - else - { - var weight = (NpgsqlTsQueryLexeme.Weight)buf.ReadByte(); - var prefix = buf.ReadByte() != 0; - var str = buf.ReadNullTerminatedString(); - InsertInTree(new NpgsqlTsQueryLexeme(str, weight, prefix), nodes, ref value); - } - - len -= buf.ReadPosition - readPos; - } - - if (nodes.Count != 0) - throw new InvalidOperationException("Internal Npgsql bug, please report."); - - return value!; - - static void InsertInTree(NpgsqlTsQuery node, Stack> nodes, ref NpgsqlTsQuery? value) - { - if (nodes.Count == 0) - value = node; - else - { - var parent = nodes.Pop(); - if (parent.Item2 == 0) - ((NpgsqlTsQueryNot)parent.Item1).Child = node; - else if (parent.Item2 == 1) - ((NpgsqlTsQueryBinOp)parent.Item1).Left = node; - else - ((NpgsqlTsQueryBinOp)parent.Item1).Right = node; - } - } - } - - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => (NpgsqlTsQueryEmpty)await Read(buf, len, async, fieldDescription); - - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => (NpgsqlTsQueryLexeme)await Read(buf, len, async, fieldDescription); - - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => (NpgsqlTsQueryNot)await Read(buf, len, async, fieldDescription); - - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => (NpgsqlTsQueryAnd)await Read(buf, len, async, fieldDescription); - - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? 
fieldDescription) - => (NpgsqlTsQueryOr)await Read(buf, len, async, fieldDescription); - - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => (NpgsqlTsQueryFollowedBy)await Read(buf, len, async, fieldDescription); - - #endregion Read - - #region Write - - /// - public override int ValidateAndGetLength(NpgsqlTsQuery value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => value.Kind == NpgsqlTsQuery.NodeKind.Empty - ? 4 - : 4 + GetNodeLength(value); - - int GetNodeLength(NpgsqlTsQuery node) - { - // TODO: Figure out the nullability strategy here - switch (node.Kind) - { - case NpgsqlTsQuery.NodeKind.Lexeme: - var strLen = Encoding.UTF8.GetByteCount(((NpgsqlTsQueryLexeme)node).Text); - if (strLen > 2046) - throw new InvalidCastException("Lexeme text too long. Must be at most 2046 bytes in UTF8."); - return 4 + strLen; - case NpgsqlTsQuery.NodeKind.And: - case NpgsqlTsQuery.NodeKind.Or: - return 2 + GetNodeLength(((NpgsqlTsQueryBinOp)node).Left) + GetNodeLength(((NpgsqlTsQueryBinOp)node).Right); - case NpgsqlTsQuery.NodeKind.Phrase: - // 2 additional bytes for uint16 phrase operator "distance" field. - return 4 + GetNodeLength(((NpgsqlTsQueryBinOp)node).Left) + GetNodeLength(((NpgsqlTsQueryBinOp)node).Right); - case NpgsqlTsQuery.NodeKind.Not: - return 2 + GetNodeLength(((NpgsqlTsQueryNot)node).Child); - case NpgsqlTsQuery.NodeKind.Empty: - throw new InvalidOperationException("Empty tsquery nodes must be top-level"); - default: - throw new InvalidOperationException("Illegal node kind: " + node.Kind); - } - } - - /// - public override async Task Write(NpgsqlTsQuery query, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter, bool async, CancellationToken cancellationToken = default) - { - var numTokens = GetTokenCount(query); - - if (buf.WriteSpaceLeft < 4) - await buf.Flush(async, cancellationToken); - buf.WriteInt32(numTokens); - - if (numTokens == 0) - return; - - await WriteCore(query, buf, async, cancellationToken); - - static async Task WriteCore(NpgsqlTsQuery node, NpgsqlWriteBuffer buf, bool async, CancellationToken cancellationToken = default) - { - if (buf.WriteSpaceLeft < 4) - await buf.Flush(async, cancellationToken); - - buf.WriteByte(node.Kind == NpgsqlTsQuery.NodeKind.Lexeme ? (byte)1 : (byte)2); - - if (node.Kind == NpgsqlTsQuery.NodeKind.Lexeme) - { - if (buf.WriteSpaceLeft < MaxSingleTokenBytes) - await buf.Flush(async, cancellationToken); - - var lexemeNode = (NpgsqlTsQueryLexeme)node; - buf.WriteByte((byte)lexemeNode.Weights); - buf.WriteByte(lexemeNode.IsPrefixSearch ? (byte)1 : (byte)0); - buf.WriteString(lexemeNode.Text); - buf.WriteByte(0); - return; - } - - buf.WriteByte((byte)node.Kind); - if (node.Kind == NpgsqlTsQuery.NodeKind.Not) - { - await WriteCore(((NpgsqlTsQueryNot)node).Child, buf, async, cancellationToken); - return; - } - - if (node.Kind == NpgsqlTsQuery.NodeKind.Phrase) - buf.WriteInt16(((NpgsqlTsQueryFollowedBy)node).Distance); - - await WriteCore(((NpgsqlTsQueryBinOp)node).Right, buf, async, cancellationToken); - await WriteCore(((NpgsqlTsQueryBinOp)node).Left, buf, async, cancellationToken); - } - } - - int GetTokenCount(NpgsqlTsQuery node) - { - switch (node.Kind) - { - case NpgsqlTsQuery.NodeKind.Lexeme: - return 1; - case NpgsqlTsQuery.NodeKind.And: - case NpgsqlTsQuery.NodeKind.Or: - case NpgsqlTsQuery.NodeKind.Phrase: - return 1 + GetTokenCount(((NpgsqlTsQueryBinOp)node).Left) + GetTokenCount(((NpgsqlTsQueryBinOp)node).Right); - case NpgsqlTsQuery.NodeKind.Not: - return 1 + GetTokenCount(((NpgsqlTsQueryNot)node).Child); - case NpgsqlTsQuery.NodeKind.Empty: - return 0; - } - return -1; - } - - /// - public int 
ValidateAndGetLength(NpgsqlTsQueryOr value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength((NpgsqlTsQuery)value, ref lengthCache, parameter); - - /// - public int ValidateAndGetLength(NpgsqlTsQueryAnd value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength((NpgsqlTsQuery)value, ref lengthCache, parameter); - - /// - public int ValidateAndGetLength(NpgsqlTsQueryNot value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength((NpgsqlTsQuery)value, ref lengthCache, parameter); - - /// - public int ValidateAndGetLength(NpgsqlTsQueryLexeme value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength((NpgsqlTsQuery)value, ref lengthCache, parameter); - - /// - public int ValidateAndGetLength(NpgsqlTsQueryEmpty value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength((NpgsqlTsQuery)value, ref lengthCache, parameter); - - /// - public int ValidateAndGetLength(NpgsqlTsQueryFollowedBy value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength((NpgsqlTsQuery)value, ref lengthCache, parameter); - - /// - public Task Write(NpgsqlTsQueryOr value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => Write((NpgsqlTsQuery)value, buf, lengthCache, parameter, async, cancellationToken); - - /// - public Task Write(NpgsqlTsQueryAnd value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => Write((NpgsqlTsQuery)value, buf, lengthCache, parameter, async, cancellationToken); - - /// - public Task Write(NpgsqlTsQueryNot value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter, bool async, CancellationToken cancellationToken = default) - => Write((NpgsqlTsQuery)value, buf, lengthCache, parameter, async, cancellationToken); - - /// - public Task Write(NpgsqlTsQueryLexeme value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => Write((NpgsqlTsQuery)value, buf, lengthCache, parameter, async, cancellationToken); - - /// - public Task Write(NpgsqlTsQueryEmpty value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => Write((NpgsqlTsQuery)value, buf, lengthCache, parameter, async, cancellationToken); - - /// - public Task Write( - NpgsqlTsQueryFollowedBy value, - NpgsqlWriteBuffer buf, - NpgsqlLengthCache? lengthCache, - NpgsqlParameter? parameter, - bool async, - CancellationToken cancellationToken = default) - => Write((NpgsqlTsQuery)value, buf, lengthCache, parameter, async, cancellationToken); - - #endregion Write -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/FullTextSearchHandlers/TsVectorHandler.cs b/src/Npgsql/Internal/TypeHandlers/FullTextSearchHandlers/TsVectorHandler.cs deleted file mode 100644 index 141e566fd1..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/FullTextSearchHandlers/TsVectorHandler.cs +++ /dev/null @@ -1,97 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; -using NpgsqlTypes; - -namespace Npgsql.Internal.TypeHandlers.FullTextSearchHandlers; - -/// -/// A type handler for the PostgreSQL tsvector data type. -/// -/// -/// See https://www.postgresql.org/docs/current/static/datatype-textsearch.html. -/// -/// The type handler API allows customizing Npgsql's behavior in powerful ways. 
However, although it is public, it -/// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. -/// Use it at your own risk. -/// -public partial class TsVectorHandler : NpgsqlTypeHandler -{ - // 2561 = 2046 (max length lexeme string) + (1) null terminator + - // 2 (num_pos) + sizeof(int16) * 256 (max_num_pos (positions/wegihts)) - const int MaxSingleLexemeBytes = 2561; - - public TsVectorHandler(PostgresType pgType) : base(pgType) {} - - #region Read - - /// - public override async ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - { - await buf.Ensure(4, async); - var numLexemes = buf.ReadInt32(); - len -= 4; - - var lexemes = new List(); - for (var lexemePos = 0; lexemePos < numLexemes; lexemePos++) - { - await buf.Ensure(Math.Min(len, MaxSingleLexemeBytes), async); - var posBefore = buf.ReadPosition; - - List? positions = null; - - var lexemeString = buf.ReadNullTerminatedString(); - int numPositions = buf.ReadInt16(); - for (var i = 0; i < numPositions; i++) - { - var wordEntryPos = buf.ReadInt16(); - if (positions == null) - positions = new List(); - positions.Add(new NpgsqlTsVector.Lexeme.WordEntryPos(wordEntryPos)); - } - - lexemes.Add(new NpgsqlTsVector.Lexeme(lexemeString, positions, true)); - - len -= buf.ReadPosition - posBefore; - } - - return new NpgsqlTsVector(lexemes, true); - } - - #endregion Read - - #region Write - - // TODO: Implement length cache - /// - public override int ValidateAndGetLength(NpgsqlTsVector value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => 4 + value.Sum(l => Encoding.UTF8.GetByteCount(l.Text) + 1 + 2 + l.Count * 2); - - /// - public override async Task Write(NpgsqlTsVector vector, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter, bool async, CancellationToken cancellationToken = default) - { - if (buf.WriteSpaceLeft < 4) - await buf.Flush(async, cancellationToken); - buf.WriteInt32(vector.Count); - - foreach (var lexeme in vector) - { - if (buf.WriteSpaceLeft < MaxSingleLexemeBytes) - await buf.Flush(async, cancellationToken); - - buf.WriteString(lexeme.Text); - buf.WriteByte(0); - buf.WriteInt16(lexeme.Count); - for (var i = 0; i < lexeme.Count; i++) - buf.WriteInt16(lexeme[i].Value); - } - } - - #endregion Write -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/GeometricHandlers/BoxHandler.cs b/src/Npgsql/Internal/TypeHandlers/GeometricHandlers/BoxHandler.cs deleted file mode 100644 index 6ff333f47e..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/GeometricHandlers/BoxHandler.cs +++ /dev/null @@ -1,41 +0,0 @@ -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; -using NpgsqlTypes; - -namespace Npgsql.Internal.TypeHandlers.GeometricHandlers; - -/// -/// A type handler for the PostgreSQL box data type. -/// -/// -/// See https://www.postgresql.org/docs/current/static/datatype-geometric.html. -/// -/// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it -/// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. -/// Use it at your own risk. -/// -public partial class BoxHandler : NpgsqlSimpleTypeHandler -{ - public BoxHandler(PostgresType pgType) : base(pgType) {} - - /// - public override NpgsqlBox Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => new( - new NpgsqlPoint(buf.ReadDouble(), buf.ReadDouble()), - new NpgsqlPoint(buf.ReadDouble(), buf.ReadDouble()) - ); - - /// - public override int ValidateAndGetLength(NpgsqlBox value, NpgsqlParameter? parameter) - => 32; - - /// - public override void Write(NpgsqlBox value, NpgsqlWriteBuffer buf, NpgsqlParameter? 
parameter) - { - buf.WriteDouble(value.Right); - buf.WriteDouble(value.Top); - buf.WriteDouble(value.Left); - buf.WriteDouble(value.Bottom); - } -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/GeometricHandlers/CircleHandler.cs b/src/Npgsql/Internal/TypeHandlers/GeometricHandlers/CircleHandler.cs deleted file mode 100644 index b450177cd3..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/GeometricHandlers/CircleHandler.cs +++ /dev/null @@ -1,37 +0,0 @@ -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; -using NpgsqlTypes; - -namespace Npgsql.Internal.TypeHandlers.GeometricHandlers; - -/// -/// A type handler for the PostgreSQL circle data type. -/// -/// -/// See https://www.postgresql.org/docs/current/static/datatype-geometric.html. -/// -/// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it -/// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. -/// Use it at your own risk. -/// -public partial class CircleHandler : NpgsqlSimpleTypeHandler -{ - public CircleHandler(PostgresType pgType) : base(pgType) {} - - /// - public override NpgsqlCircle Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => new(buf.ReadDouble(), buf.ReadDouble(), buf.ReadDouble()); - - /// - public override int ValidateAndGetLength(NpgsqlCircle value, NpgsqlParameter? parameter) - => 24; - - /// - public override void Write(NpgsqlCircle value, NpgsqlWriteBuffer buf, NpgsqlParameter? 
parameter) - { - buf.WriteDouble(value.X); - buf.WriteDouble(value.Y); - buf.WriteDouble(value.Radius); - } -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/GeometricHandlers/LineHandler.cs b/src/Npgsql/Internal/TypeHandlers/GeometricHandlers/LineHandler.cs deleted file mode 100644 index 8b16b68a67..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/GeometricHandlers/LineHandler.cs +++ /dev/null @@ -1,37 +0,0 @@ -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; -using NpgsqlTypes; - -namespace Npgsql.Internal.TypeHandlers.GeometricHandlers; - -/// -/// A type handler for the PostgreSQL line data type. -/// -/// -/// See https://www.postgresql.org/docs/current/static/datatype-geometric.html. -/// -/// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it -/// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. -/// Use it at your own risk. -/// -public partial class LineHandler : NpgsqlSimpleTypeHandler -{ - public LineHandler(PostgresType pgType) : base(pgType) {} - - /// - public override NpgsqlLine Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => new(buf.ReadDouble(), buf.ReadDouble(), buf.ReadDouble()); - - /// - public override int ValidateAndGetLength(NpgsqlLine value, NpgsqlParameter? parameter) - => 24; - - /// - public override void Write(NpgsqlLine value, NpgsqlWriteBuffer buf, NpgsqlParameter? 
parameter) - { - buf.WriteDouble(value.A); - buf.WriteDouble(value.B); - buf.WriteDouble(value.C); - } -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/GeometricHandlers/LineSegmentHandler.cs b/src/Npgsql/Internal/TypeHandlers/GeometricHandlers/LineSegmentHandler.cs deleted file mode 100644 index f34083602f..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/GeometricHandlers/LineSegmentHandler.cs +++ /dev/null @@ -1,38 +0,0 @@ -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; -using NpgsqlTypes; - -namespace Npgsql.Internal.TypeHandlers.GeometricHandlers; - -/// -/// A type handler for the PostgreSQL lseg data type. -/// -/// -/// See https://www.postgresql.org/docs/current/static/datatype-geometric.html. -/// -/// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it -/// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. -/// Use it at your own risk. -/// -public partial class LineSegmentHandler : NpgsqlSimpleTypeHandler -{ - public LineSegmentHandler(PostgresType pgType) : base(pgType) {} - - /// - public override NpgsqlLSeg Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => new(buf.ReadDouble(), buf.ReadDouble(), buf.ReadDouble(), buf.ReadDouble()); - - /// - public override int ValidateAndGetLength(NpgsqlLSeg value, NpgsqlParameter? parameter) - => 32; - - /// - public override void Write(NpgsqlLSeg value, NpgsqlWriteBuffer buf, NpgsqlParameter? 
parameter) - { - buf.WriteDouble(value.Start.X); - buf.WriteDouble(value.Start.Y); - buf.WriteDouble(value.End.X); - buf.WriteDouble(value.End.Y); - } -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/GeometricHandlers/PathHandler.cs b/src/Npgsql/Internal/TypeHandlers/GeometricHandlers/PathHandler.cs deleted file mode 100644 index 4b7aa4c8b5..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/GeometricHandlers/PathHandler.cs +++ /dev/null @@ -1,74 +0,0 @@ -using System; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; -using NpgsqlTypes; - -namespace Npgsql.Internal.TypeHandlers.GeometricHandlers; - -/// -/// A type handler for the PostgreSQL path data type. -/// -/// -/// See https://www.postgresql.org/docs/current/static/datatype-geometric.html. -/// -/// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it -/// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. -/// Use it at your own risk. -/// -public partial class PathHandler : NpgsqlTypeHandler -{ - public PathHandler(PostgresType pgType) : base(pgType) {} - - #region Read - - /// - public override async ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - { - await buf.Ensure(5, async); - var open = buf.ReadByte() switch - { - 1 => false, - 0 => true, - _ => throw new Exception("Error decoding binary geometric path: bad open byte") - }; - - var numPoints = buf.ReadInt32(); - var result = new NpgsqlPath(numPoints, open); - for (var i = 0; i < numPoints; i++) - { - await buf.Ensure(16, async); - result.Add(new NpgsqlPoint(buf.ReadDouble(), buf.ReadDouble())); - } - return result; - } - - #endregion - - #region Write - - /// - public override int ValidateAndGetLength(NpgsqlPath value, ref NpgsqlLengthCache? 
lengthCache, NpgsqlParameter? parameter) - => 5 + value.Count * 16; - - /// - public override async Task Write(NpgsqlPath value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - if (buf.WriteSpaceLeft < 5) - await buf.Flush(async, cancellationToken); - buf.WriteByte((byte)(value.Open ? 0 : 1)); - buf.WriteInt32(value.Count); - - foreach (var p in value) - { - if (buf.WriteSpaceLeft < 16) - await buf.Flush(async, cancellationToken); - buf.WriteDouble(p.X); - buf.WriteDouble(p.Y); - } - } - - #endregion -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/GeometricHandlers/PointHandler.cs b/src/Npgsql/Internal/TypeHandlers/GeometricHandlers/PointHandler.cs deleted file mode 100644 index d02bd67ec8..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/GeometricHandlers/PointHandler.cs +++ /dev/null @@ -1,36 +0,0 @@ -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; -using NpgsqlTypes; - -namespace Npgsql.Internal.TypeHandlers.GeometricHandlers; - -/// -/// A type handler for the PostgreSQL point data type. -/// -/// -/// See https://www.postgresql.org/docs/current/static/datatype-geometric.html. -/// -/// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it -/// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. -/// Use it at your own risk. -/// -public partial class PointHandler : NpgsqlSimpleTypeHandler -{ - public PointHandler(PostgresType pgType) : base(pgType) {} - - /// - public override NpgsqlPoint Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => new(buf.ReadDouble(), buf.ReadDouble()); - - /// - public override int ValidateAndGetLength(NpgsqlPoint value, NpgsqlParameter? 
parameter) - => 16; - - /// - public override void Write(NpgsqlPoint value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - { - buf.WriteDouble(value.X); - buf.WriteDouble(value.Y); - } -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/GeometricHandlers/PolygonHandler.cs b/src/Npgsql/Internal/TypeHandlers/GeometricHandlers/PolygonHandler.cs deleted file mode 100644 index 004bd3ebbc..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/GeometricHandlers/PolygonHandler.cs +++ /dev/null @@ -1,65 +0,0 @@ -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; -using NpgsqlTypes; - -namespace Npgsql.Internal.TypeHandlers.GeometricHandlers; - -/// -/// A type handler for the PostgreSQL polygon data type. -/// -/// -/// See https://www.postgresql.org/docs/current/static/datatype-geometric.html. -/// -/// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it -/// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. -/// Use it at your own risk. -/// -public partial class PolygonHandler : NpgsqlTypeHandler -{ - public PolygonHandler(PostgresType pgType) : base(pgType) {} - - #region Read - - /// - public override async ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - { - await buf.Ensure(4, async); - var numPoints = buf.ReadInt32(); - var result = new NpgsqlPolygon(numPoints); - for (var i = 0; i < numPoints; i++) - { - await buf.Ensure(16, async); - result.Add(new NpgsqlPoint(buf.ReadDouble(), buf.ReadDouble())); - } - return result; - } - - #endregion - - #region Write - - /// - public override int ValidateAndGetLength(NpgsqlPolygon value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter) - => 4 + value.Count * 16; - - /// - public override async Task Write(NpgsqlPolygon value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - if (buf.WriteSpaceLeft < 4) - await buf.Flush(async, cancellationToken); - buf.WriteInt32(value.Count); - - foreach (var p in value) - { - if (buf.WriteSpaceLeft < 16) - await buf.Flush(async, cancellationToken); - buf.WriteDouble(p.X); - buf.WriteDouble(p.Y); - } - } - - #endregion -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/HstoreHandler.cs b/src/Npgsql/Internal/TypeHandlers/HstoreHandler.cs deleted file mode 100644 index 0b8613d979..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/HstoreHandler.cs +++ /dev/null @@ -1,178 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Collections.Immutable; -using System.Diagnostics; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; - -namespace Npgsql.Internal.TypeHandlers; - -/// -/// A type handler for the PostgreSQL hstore extension data type, which stores sets of key/value pairs within a -/// single PostgreSQL value. -/// -/// -/// See https://www.postgresql.org/docs/current/hstore.html. -/// -/// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it -/// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. -/// Use it at your own risk. 
-/// -public class HstoreHandler : - NpgsqlTypeHandler>, - INpgsqlTypeHandler>, - INpgsqlTypeHandler> -{ - /// - /// The text handler to which we delegate encoding/decoding of the actual strings - /// - readonly TextHandler _textHandler; - - internal HstoreHandler(PostgresType postgresType, TextHandler textHandler) - : base(postgresType) - => _textHandler = textHandler; - - #region Write - - /// - public int ValidateAndGetLength(IDictionary value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - { - lengthCache ??= new NpgsqlLengthCache(1); - if (lengthCache.IsPopulated) - return lengthCache.Get(); - - // Leave empty slot for the entire hstore length, and go ahead an populate the individual string slots - var pos = lengthCache.Position; - lengthCache.Set(0); - - var totalLen = 4; // Number of key-value pairs - foreach (var kv in value) - { - totalLen += 8; // Key length + value length - if (kv.Key == null) - throw new FormatException("HSTORE doesn't support null keys"); - totalLen += _textHandler.ValidateAndGetLength(kv.Key, ref lengthCache, null); - if (kv.Value != null) - totalLen += _textHandler.ValidateAndGetLength(kv.Value!, ref lengthCache, null); - } - - return lengthCache.Lengths[pos] = totalLen; - } - - /// - public int ValidateAndGetLength( - ImmutableDictionary value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength((IDictionary)value, ref lengthCache, parameter); - - /// - public override int ValidateAndGetLength(Dictionary value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength(value, ref lengthCache, parameter); - - /// - public override int ValidateObjectAndGetLength(object? value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter) - => value switch - { - ImmutableDictionary converted => ValidateAndGetLength(converted, ref lengthCache, parameter), - Dictionary converted => ValidateAndGetLength(converted, ref lengthCache, parameter), - IDictionary converted => ValidateAndGetLength(converted, ref lengthCache, parameter), - - DBNull => 0, - null => 0, - _ => throw new InvalidCastException($"Can't write CLR type {value.GetType()} with handler type HstoreHandler") - }; - - /// - public override Task WriteObjectWithLength( - object? value, - NpgsqlWriteBuffer buf, - NpgsqlLengthCache? lengthCache, - NpgsqlParameter? parameter, - bool async, - CancellationToken cancellationToken = default) - => value switch - { - ImmutableDictionary converted => ((INpgsqlTypeHandler>)this).WriteWithLength(converted, buf, lengthCache, parameter, async, cancellationToken), - Dictionary converted => ((INpgsqlTypeHandler>)this).WriteWithLength(converted, buf, lengthCache, parameter, async, cancellationToken), - IDictionary converted => ((INpgsqlTypeHandler>)this).WriteWithLength(converted, buf, lengthCache, parameter, async, cancellationToken), - - DBNull => WriteWithLength(DBNull.Value, buf, lengthCache, parameter, async, cancellationToken), - null => WriteWithLength(DBNull.Value, buf, lengthCache, parameter, async, cancellationToken), - _ => throw new InvalidCastException($"Can't write CLR type {value.GetType()} with handler type BoolHandler") - }; - - /// - public async Task Write(IDictionary value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter, bool async, CancellationToken cancellationToken = default) - { - if (buf.WriteSpaceLeft < 4) - await buf.Flush(async, cancellationToken); - buf.WriteInt32(value.Count); - if (value.Count == 0) - return; - - foreach (var kv in value) - { - await ((INpgsqlTypeHandler)_textHandler).WriteWithLength(kv.Key, buf, lengthCache, parameter, async, cancellationToken); - await ((INpgsqlTypeHandler)_textHandler).WriteWithLength(kv.Value, buf, lengthCache, parameter, async, cancellationToken); - } - } - - /// - public Task Write(ImmutableDictionary value, - NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => Write((IDictionary)value, buf, lengthCache, parameter, async, cancellationToken); - - /// - public override Task Write(Dictionary value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => Write(value, buf, lengthCache, parameter, async, cancellationToken); - - #endregion - - #region Read - - async ValueTask ReadInto(T dictionary, int numElements, NpgsqlReadBuffer buf, bool async) - where T : IDictionary - { - for (var i = 0; i < numElements; i++) - { - await buf.Ensure(4, async); - var keyLen = buf.ReadInt32(); - Debug.Assert(keyLen != -1); - var key = await _textHandler.Read(buf, keyLen, async); - - await buf.Ensure(4, async); - var valueLen = buf.ReadInt32(); - - dictionary[key] = valueLen == -1 - ? null - : await _textHandler.Read(buf, valueLen, async); - } - return dictionary; - } - - /// - public override async ValueTask> Read(NpgsqlReadBuffer buf, int len, bool async, - FieldDescription? fieldDescription = null) - { - await buf.Ensure(4, async); - var numElements = buf.ReadInt32(); - return await ReadInto(new Dictionary(numElements), numElements, buf, async); - } - - ValueTask> INpgsqlTypeHandler>.Read( - NpgsqlReadBuffer buf, int len, bool async, FieldDescription? 
fieldDescription) - => new(Read(buf, len, async, fieldDescription).Result); - - async ValueTask> INpgsqlTypeHandler>.Read( - NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - { - await buf.Ensure(4, async); - var numElements = buf.ReadInt32(); - return (await ReadInto(ImmutableDictionary.Empty.ToBuilder(), numElements, buf, async)) - .ToImmutable(); - } - - #endregion -} diff --git a/src/Npgsql/Internal/TypeHandlers/InternalTypeHandlers/Int2VectorHandler.cs b/src/Npgsql/Internal/TypeHandlers/InternalTypeHandlers/Int2VectorHandler.cs deleted file mode 100644 index 1523b66d69..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/InternalTypeHandlers/Int2VectorHandler.cs +++ /dev/null @@ -1,18 +0,0 @@ -using Npgsql.Internal.TypeHandlers.NumericHandlers; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; - -namespace Npgsql.Internal.TypeHandlers.InternalTypeHandlers; - -/// -/// An int2vector is simply a regular array of shorts, with the sole exception that its lower bound must -/// be 0 (we send 1 for regular arrays). 
-/// -sealed class Int2VectorHandler : ArrayHandler -{ - public Int2VectorHandler(PostgresType arrayPostgresType, PostgresType postgresShortType) - : base(arrayPostgresType, new Int16Handler(postgresShortType), ArrayNullabilityMode.Never, 0) { } - - public override NpgsqlTypeHandler CreateArrayHandler(PostgresArrayType pgArrayType, ArrayNullabilityMode arrayNullabilityMode) - => new ArrayHandler(pgArrayType, this, arrayNullabilityMode); -} diff --git a/src/Npgsql/Internal/TypeHandlers/InternalTypeHandlers/InternalCharHandler.cs b/src/Npgsql/Internal/TypeHandlers/InternalTypeHandlers/InternalCharHandler.cs deleted file mode 100644 index 2131cc16c8..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/InternalTypeHandlers/InternalCharHandler.cs +++ /dev/null @@ -1,87 +0,0 @@ -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; - -namespace Npgsql.Internal.TypeHandlers.InternalTypeHandlers; - -/// -/// A type handler for the PostgreSQL "char" type, used only internally. -/// -/// -/// See https://www.postgresql.org/docs/current/static/datatype-character.html. -/// -/// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it -/// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. -/// Use it at your own risk. -/// -public partial class InternalCharHandler : NpgsqlSimpleTypeHandler, - INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler -{ - public InternalCharHandler(PostgresType pgType) : base(pgType) {} - - #region Read - - /// - public override char Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => (char)buf.ReadByte(); - - byte INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => buf.ReadByte(); - - short INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? 
fieldDescription) - => buf.ReadByte(); - - int INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => buf.ReadByte(); - - long INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => buf.ReadByte(); - - #endregion - - #region Write - - /// - public int ValidateAndGetLength(byte value, NpgsqlParameter? parameter) => 1; - - /// - public override int ValidateAndGetLength(char value, NpgsqlParameter? parameter) - { - _ = checked((byte)value); - return 1; - } - - /// - public int ValidateAndGetLength(short value, NpgsqlParameter? parameter) - { - _ = checked((byte)value); - return 1; - } - - /// - public int ValidateAndGetLength(int value, NpgsqlParameter? parameter) - { - _ = checked((byte)value); - return 1; - } - - /// - public int ValidateAndGetLength(long value, NpgsqlParameter? parameter) - { - _ = checked((byte)value); - return 1; - } - - /// - public override void Write(char value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteByte((byte)value); - /// - public void Write(byte value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteByte(value); - /// - public void Write(short value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteByte((byte)value); - /// - public void Write(int value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteByte((byte)value); - /// - public void Write(long value, NpgsqlWriteBuffer buf, NpgsqlParameter? 
parameter) => buf.WriteByte((byte)value); - - #endregion -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/InternalTypeHandlers/OIDVectorHandler.cs b/src/Npgsql/Internal/TypeHandlers/InternalTypeHandlers/OIDVectorHandler.cs deleted file mode 100644 index 00b3a57aa1..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/InternalTypeHandlers/OIDVectorHandler.cs +++ /dev/null @@ -1,18 +0,0 @@ -using Npgsql.Internal.TypeHandlers.NumericHandlers; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; - -namespace Npgsql.Internal.TypeHandlers.InternalTypeHandlers; - -/// -/// An OIDVector is simply a regular array of uints, with the sole exception that its lower bound must -/// be 0 (we send 1 for regular arrays). -/// -sealed class OIDVectorHandler : ArrayHandler -{ - public OIDVectorHandler(PostgresType oidvectorType, PostgresType oidType) - : base(oidvectorType, new UInt32Handler(oidType), ArrayNullabilityMode.Never, 0) { } - - public override NpgsqlTypeHandler CreateArrayHandler(PostgresArrayType pgArrayType, ArrayNullabilityMode arrayNullabilityMode) - => new ArrayHandler(pgArrayType, this, arrayNullabilityMode); -} diff --git a/src/Npgsql/Internal/TypeHandlers/InternalTypeHandlers/PgLsnHandler.cs b/src/Npgsql/Internal/TypeHandlers/InternalTypeHandlers/PgLsnHandler.cs deleted file mode 100644 index 75e85ab3e6..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/InternalTypeHandlers/PgLsnHandler.cs +++ /dev/null @@ -1,31 +0,0 @@ -using System.Diagnostics; -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; -using NpgsqlTypes; - -namespace Npgsql.Internal.TypeHandlers.InternalTypeHandlers; - -sealed partial class PgLsnHandler : NpgsqlSimpleTypeHandler -{ - public PgLsnHandler(PostgresType pgType) : base(pgType) {} - - #region Read - - public override NpgsqlLogSequenceNumber Read(NpgsqlReadBuffer buf, int len, FieldDescription? 
fieldDescription = null) - { - Debug.Assert(len == 8); - return new NpgsqlLogSequenceNumber(buf.ReadUInt64()); - } - - #endregion Read - - #region Write - - public override int ValidateAndGetLength(NpgsqlLogSequenceNumber value, NpgsqlParameter? parameter) => 8; - - public override void Write(NpgsqlLogSequenceNumber value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => buf.WriteUInt64((ulong)value); - - #endregion Write -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/InternalTypeHandlers/TidHandler.cs b/src/Npgsql/Internal/TypeHandlers/InternalTypeHandlers/TidHandler.cs deleted file mode 100644 index 0148fc1071..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/InternalTypeHandlers/TidHandler.cs +++ /dev/null @@ -1,39 +0,0 @@ -using System.Diagnostics; -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; -using NpgsqlTypes; - -namespace Npgsql.Internal.TypeHandlers.InternalTypeHandlers; - -sealed partial class TidHandler : NpgsqlSimpleTypeHandler -{ - public TidHandler(PostgresType pgType) : base(pgType) {} - - #region Read - - public override NpgsqlTid Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - { - Debug.Assert(len == 6); - - var blockNumber = buf.ReadUInt32(); - var offsetNumber = buf.ReadUInt16(); - - return new NpgsqlTid(blockNumber, offsetNumber); - } - - #endregion Read - - #region Write - - public override int ValidateAndGetLength(NpgsqlTid value, NpgsqlParameter? parameter) - => 6; - - public override void Write(NpgsqlTid value, NpgsqlWriteBuffer buf, NpgsqlParameter? 
parameter) - { - buf.WriteUInt32(value.BlockNumber); - buf.WriteUInt16(value.OffsetNumber); - } - - #endregion Write -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/JsonPathHandler.cs b/src/Npgsql/Internal/TypeHandlers/JsonPathHandler.cs deleted file mode 100644 index 7b2735fcd3..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/JsonPathHandler.cs +++ /dev/null @@ -1,74 +0,0 @@ -using System; -using System.IO; -using System.Text; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -namespace Npgsql.Internal.TypeHandlers; - -/// -/// A type handler for the PostgreSQL jsonpath data type. -/// -/// -/// See https://www.postgresql.org/docs/current/datatype-json.html#DATATYPE-JSONPATH. -/// -/// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it -/// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. -/// Use it at your own risk. -/// -public partial class JsonPathHandler : NpgsqlTypeHandler, ITextReaderHandler -{ - readonly TextHandler _textHandler; - - /// - /// Prepended to the string in the wire encoding - /// - const byte JsonPathVersion = 1; - - /// - protected internal JsonPathHandler(PostgresType postgresType, Encoding encoding) - : base(postgresType) - => _textHandler = new TextHandler(postgresType, encoding); - - /// - public override async ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? 
fieldDescription = null) - { - await buf.Ensure(1, async); - - var version = buf.ReadByte(); - if (version != JsonPathVersion) - throw new NotSupportedException($"Don't know how to decode JSONPATH with wire format {version}, your connection is now broken"); - - return await _textHandler.Read(buf, len - 1, async, fieldDescription); - } - - /// - public override int ValidateAndGetLength(string value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) => - 1 + _textHandler.ValidateAndGetLength(value, ref lengthCache, parameter); - - /// - public override async Task Write(string value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - if (buf.WriteSpaceLeft < 1) - await buf.Flush(async, cancellationToken); - - buf.WriteByte(JsonPathVersion); - - await _textHandler.Write(value, buf, lengthCache, parameter, async, cancellationToken); - } - - /// - public TextReader GetTextReader(Stream stream, NpgsqlReadBuffer buffer) - { - var version = stream.ReadByte(); - if (version != JsonPathVersion) - throw new NotSupportedException($"Don't know how to decode JSONPATH with wire format {version}, your connection is now broken"); - - return _textHandler.GetTextReader(stream, buffer); - } -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/JsonTextHandler.cs b/src/Npgsql/Internal/TypeHandlers/JsonTextHandler.cs deleted file mode 100644 index 2842370336..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/JsonTextHandler.cs +++ /dev/null @@ -1,209 +0,0 @@ -using System; -using System.Diagnostics.CodeAnalysis; -using System.IO; -using System.Text; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; - -namespace Npgsql.Internal.TypeHandlers; - -/// -/// A text-only type handler for the PostgreSQL json and jsonb data type. 
This handler does not support serialization/deserialization -/// with System.Text.Json or Json.NET. -/// -/// -/// See https://www.postgresql.org/docs/current/datatype-json.html. -/// -/// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it -/// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. -/// Use it at your own risk. -/// -public class JsonTextHandler : NpgsqlTypeHandler, ITextReaderHandler -{ - protected TextHandler TextHandler { get; } - readonly bool _isJsonb; - readonly int _headerLen; - - internal override bool PreferTextWrite => false; - - /// - /// Prepended to the string in the wire encoding - /// - const byte JsonbProtocolVersion = 1; - - /// - public JsonTextHandler(PostgresType postgresType, Encoding encoding, bool isJsonb) - : base(postgresType) - { - _isJsonb = isJsonb; - _headerLen = isJsonb ? 1 : 0; - TextHandler = new TextHandler(postgresType, encoding); - } - - protected bool IsSupportedAsText() - => typeof(T) == typeof(string) || - typeof(T) == typeof(char[]) || - typeof(T) == typeof(ArraySegment) || - typeof(T) == typeof(char) || - typeof(T) == typeof(byte[]) || - typeof(T) == typeof(ReadOnlyMemory); - - protected bool IsSupported(Type type) - => type == typeof(string) || - type == typeof(char[]) || - type == typeof(ArraySegment) || - type == typeof(char) || - type == typeof(byte[]) || - type == typeof(ReadOnlyMemory); - - protected bool TryValidateAndGetLengthCustom( - [DisallowNull] TAny value, - ref NpgsqlLengthCache? lengthCache, - NpgsqlParameter? parameter, - out int length) - { - if (IsSupportedAsText()) - { - length = TextHandler.ValidateAndGetLength(value, ref lengthCache, parameter) + _headerLen; - return true; - } - - length = 0; - return false; - } - - /// - protected internal override int ValidateAndGetLengthCustom([DisallowNull] TAny value, ref NpgsqlLengthCache? lengthCache, - NpgsqlParameter? 
parameter) - => IsSupportedAsText() - ? TextHandler.ValidateAndGetLength(value, ref lengthCache, parameter) + _headerLen - : throw new InvalidCastException( - $"Can't write CLR type {value.GetType()}. " + - "You may need to use the System.Text.Json or Json.NET plugins, see the docs for more information."); - - protected override async Task WriteWithLengthCustom( - [DisallowNull] TAny value, - NpgsqlWriteBuffer buf, - NpgsqlLengthCache? lengthCache, - NpgsqlParameter? parameter, - bool async, - CancellationToken cancellationToken) - { - var spaceRequired = _isJsonb ? 5 : 4; - - if (buf.WriteSpaceLeft < spaceRequired) - await buf.Flush(async, cancellationToken); - - buf.WriteInt32(ValidateAndGetLength(value, ref lengthCache, parameter)); - - if (_isJsonb) - buf.WriteByte(JsonbProtocolVersion); - - if (typeof(TAny) == typeof(string)) - await TextHandler.Write((string)(object)value, buf, lengthCache, parameter, async, cancellationToken); - else if (typeof(TAny) == typeof(char[])) - await TextHandler.Write((char[])(object)value, buf, lengthCache, parameter, async, cancellationToken); - else if (typeof(TAny) == typeof(ArraySegment)) - await TextHandler.Write((ArraySegment)(object)value, buf, lengthCache, parameter, async, cancellationToken); - else if (typeof(TAny) == typeof(char)) - await TextHandler.Write((char)(object)value, buf, lengthCache, parameter, async, cancellationToken); - else if (typeof(TAny) == typeof(byte[])) - await TextHandler.Write((byte[])(object)value, buf, lengthCache, parameter, async, cancellationToken); - else if (typeof(TAny) == typeof(ReadOnlyMemory)) - await TextHandler.Write((ReadOnlyMemory)(object)value, buf, lengthCache, parameter, async, cancellationToken); - else throw new InvalidCastException( - $"Can't write CLR type {value.GetType()}. " + - "You may need to use the System.Text.Json or Json.NET plugins, see the docs for more information."); - } - - /// - public override int ValidateAndGetLength(string value, ref NpgsqlLengthCache? 
lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLengthCustom(value, ref lengthCache, parameter); - - /// - public override async Task Write(string value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - if (_isJsonb) - { - if (buf.WriteSpaceLeft < 1) - await buf.Flush(async, cancellationToken); - buf.WriteByte(JsonbProtocolVersion); - } - - await TextHandler.Write(value, buf, lengthCache, parameter, async, cancellationToken); - } - - /// - public override int ValidateObjectAndGetLength(object value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => value switch - { - string s => ValidateAndGetLength(s, ref lengthCache, parameter), - char[] s => ValidateAndGetLength(s, ref lengthCache, parameter), - ArraySegment s => ValidateAndGetLength(s, ref lengthCache, parameter), - char s => ValidateAndGetLength(s, ref lengthCache, parameter), - byte[] s => ValidateAndGetLength(s, ref lengthCache, parameter), - ReadOnlyMemory s => ValidateAndGetLength(s, ref lengthCache, parameter), - - _ => throw new InvalidCastException( - $"Can't write CLR type {value.GetType()}. " + - "You may need to use the System.Text.Json or Json.NET plugins, see the docs for more information.") - }; - - /// - public override Task WriteObjectWithLength(object? value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter, bool async, CancellationToken cancellationToken = default) - => value switch - { - null => WriteWithLength(DBNull.Value, buf, lengthCache, parameter, async, cancellationToken), - DBNull => WriteWithLength(DBNull.Value, buf, lengthCache, parameter, async, cancellationToken), - string s => WriteWithLengthCustom(s, buf, lengthCache, parameter, async, cancellationToken), - char[] s => WriteWithLengthCustom(s, buf, lengthCache, parameter, async, cancellationToken), - ArraySegment s => WriteWithLengthCustom(s, buf, lengthCache, parameter, async, cancellationToken), - char s => WriteWithLengthCustom(s, buf, lengthCache, parameter, async, cancellationToken), - byte[] s => WriteWithLengthCustom(s, buf, lengthCache, parameter, async, cancellationToken), - ReadOnlyMemory s => WriteWithLengthCustom(s, buf, lengthCache, parameter, async, cancellationToken), - - _ => throw new InvalidCastException( - $"Can't write CLR type {value.GetType()}. " + - "You may need to use the System.Text.Json or Json.NET plugins, see the docs for more information.") - }; - - /// - protected internal override async ValueTask ReadCustom(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - { - if (_isJsonb) - { - await buf.Ensure(1, async); - var version = buf.ReadByte(); - if (version != JsonbProtocolVersion) - throw new NotSupportedException($"Don't know how to decode JSONB with wire format {version}, your connection is now broken"); - len--; - } - - if (IsSupportedAsText()) - return await TextHandler.Read(buf, len, async, fieldDescription); - - throw new InvalidCastException( - $"Can't read JSON as CLR type {typeof(T)}. " + - "You may need to use the System.Text.Json or Json.NET plugins, see the docs for more information."); - } - - /// - public override ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? 
fieldDescription = null) - => ReadCustom(buf, len, async, fieldDescription); - - /// - public TextReader GetTextReader(Stream stream, NpgsqlReadBuffer buffer) - { - if (_isJsonb) - { - var version = stream.ReadByte(); - if (version != JsonbProtocolVersion) - throw new NpgsqlException($"Don't know how to decode jsonb with wire format {version}, your connection is now broken"); - } - - return TextHandler.GetTextReader(stream, buffer); - } -} diff --git a/src/Npgsql/Internal/TypeHandlers/LTreeHandlers/LQueryHandler.cs b/src/Npgsql/Internal/TypeHandlers/LTreeHandlers/LQueryHandler.cs deleted file mode 100644 index 9f73a4fb97..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/LTreeHandlers/LQueryHandler.cs +++ /dev/null @@ -1,90 +0,0 @@ -using System; -using System.IO; -using System.Text; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; - -namespace Npgsql.Internal.TypeHandlers.LTreeHandlers; - -/// -/// LQuery binary encoding is a simple UTF8 string, but prepended with a version number. -/// -public class LQueryHandler : TextHandler -{ - /// - /// Prepended to the string in the wire encoding - /// - const byte LQueryProtocolVersion = 1; - - internal override bool PreferTextWrite => false; - - protected internal LQueryHandler(PostgresType postgresType, Encoding encoding) - : base(postgresType, encoding) {} - - #region Write - - public override int ValidateAndGetLength(string value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) => - base.ValidateAndGetLength(value, ref lengthCache, parameter) + 1; - - public override int ValidateAndGetLength(char[] value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) => - base.ValidateAndGetLength(value, ref lengthCache, parameter) + 1; - - public override int ValidateAndGetLength(ArraySegment value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter) => - base.ValidateAndGetLength(value, ref lengthCache, parameter) + 1; - - public override async Task Write(string value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - if (buf.WriteSpaceLeft < 1) - await buf.Flush(async, cancellationToken); - - buf.WriteByte(LQueryProtocolVersion); - await base.Write(value, buf, lengthCache, parameter, async, cancellationToken); - } - - public override async Task Write(char[] value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - if (buf.WriteSpaceLeft < 1) - await buf.Flush(async, cancellationToken); - - buf.WriteByte(LQueryProtocolVersion); - await base.Write(value, buf, lengthCache, parameter, async, cancellationToken); - } - - public override async Task Write(ArraySegment value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - if (buf.WriteSpaceLeft < 1) - await buf.Flush(async, cancellationToken); - - buf.WriteByte(LQueryProtocolVersion); - await base.Write(value, buf, lengthCache, parameter, async, cancellationToken); - } - - #endregion - - #region Read - - public override async ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? 
fieldDescription = null) - { - await buf.Ensure(1, async); - - var version = buf.ReadByte(); - if (version != LQueryProtocolVersion) - throw new NotSupportedException($"Don't know how to decode lquery with wire format {version}, your connection is now broken"); - - return await base.Read(buf, len - 1, async, fieldDescription); - } - - #endregion - - public override TextReader GetTextReader(Stream stream, NpgsqlReadBuffer buffer) - { - var version = stream.ReadByte(); - if (version != LQueryProtocolVersion) - throw new NpgsqlException($"Don't know how to decode lquery with wire format {version}, your connection is now broken"); - - return base.GetTextReader(stream, buffer); - } -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/LTreeHandlers/LTreeHandler.cs b/src/Npgsql/Internal/TypeHandlers/LTreeHandlers/LTreeHandler.cs deleted file mode 100644 index 4f43266d8f..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/LTreeHandlers/LTreeHandler.cs +++ /dev/null @@ -1,90 +0,0 @@ -using System; -using System.IO; -using System.Text; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; - -namespace Npgsql.Internal.TypeHandlers.LTreeHandlers; - -/// -/// Ltree binary encoding is a simple UTF8 string, but prepended with a version number. -/// -public class LTreeHandler : TextHandler -{ - /// - /// Prepended to the string in the wire encoding - /// - const byte LtreeProtocolVersion = 1; - - internal override bool PreferTextWrite => false; - - protected internal LTreeHandler(PostgresType postgresType, Encoding encoding) - : base(postgresType, encoding) {} - - #region Write - - public override int ValidateAndGetLength(string value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) => - base.ValidateAndGetLength(value, ref lengthCache, parameter) + 1; - - public override int ValidateAndGetLength(char[] value, ref NpgsqlLengthCache? 
lengthCache, NpgsqlParameter? parameter) => - base.ValidateAndGetLength(value, ref lengthCache, parameter) + 1; - - public override int ValidateAndGetLength(ArraySegment value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) => - base.ValidateAndGetLength(value, ref lengthCache, parameter) + 1; - - public override async Task Write(string value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - if (buf.WriteSpaceLeft < 1) - await buf.Flush(async, cancellationToken); - - buf.WriteByte(LtreeProtocolVersion); - await base.Write(value, buf, lengthCache, parameter, async, cancellationToken); - } - - public override async Task Write(char[] value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - if (buf.WriteSpaceLeft < 1) - await buf.Flush(async, cancellationToken); - - buf.WriteByte(LtreeProtocolVersion); - await base.Write(value, buf, lengthCache, parameter, async, cancellationToken); - } - - public override async Task Write(ArraySegment value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - if (buf.WriteSpaceLeft < 1) - await buf.Flush(async, cancellationToken); - - buf.WriteByte(LtreeProtocolVersion); - await base.Write(value, buf, lengthCache, parameter, async, cancellationToken); - } - - #endregion - - #region Read - - public override async ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? 
fieldDescription = null) - { - await buf.Ensure(1, async); - - var version = buf.ReadByte(); - if (version != LtreeProtocolVersion) - throw new NotSupportedException($"Don't know how to decode ltree with wire format {version}, your connection is now broken"); - - return await base.Read(buf, len - 1, async, fieldDescription); - } - - #endregion - - public override TextReader GetTextReader(Stream stream, NpgsqlReadBuffer buffer) - { - var version = stream.ReadByte(); - if (version != LtreeProtocolVersion) - throw new NpgsqlException($"Don't know how to decode ltree with wire format {version}, your connection is now broken"); - - return base.GetTextReader(stream, buffer); - } -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/LTreeHandlers/LTxtQueryHandler.cs b/src/Npgsql/Internal/TypeHandlers/LTreeHandlers/LTxtQueryHandler.cs deleted file mode 100644 index dcde2a1d73..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/LTreeHandlers/LTxtQueryHandler.cs +++ /dev/null @@ -1,93 +0,0 @@ -using System; -using System.IO; -using System.Text; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; - -namespace Npgsql.Internal.TypeHandlers.LTreeHandlers; - -/// -/// LTxtQuery binary encoding is a simple UTF8 string, but prepended with a version number. -/// -public class LTxtQueryHandler : TextHandler -{ - /// - /// Prepended to the string in the wire encoding - /// - const byte LTxtQueryProtocolVersion = 1; - - internal override bool PreferTextWrite => false; - - protected internal LTxtQueryHandler(PostgresType postgresType, Encoding encoding) - : base(postgresType, encoding) {} - - #region Write - - public override int ValidateAndGetLength(string value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter) => - base.ValidateAndGetLength(value, ref lengthCache, parameter) + 1; - - - public override int ValidateAndGetLength(char[] value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) => - base.ValidateAndGetLength(value, ref lengthCache, parameter) + 1; - - - public override int ValidateAndGetLength(ArraySegment value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) => - base.ValidateAndGetLength(value, ref lengthCache, parameter) + 1; - - - public override async Task Write(string value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - if (buf.WriteSpaceLeft < 1) - await buf.Flush(async, cancellationToken); - - buf.WriteByte(LTxtQueryProtocolVersion); - await base.Write(value, buf, lengthCache, parameter, async, cancellationToken); - } - - public override async Task Write(char[] value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - if (buf.WriteSpaceLeft < 1) - await buf.Flush(async, cancellationToken); - - buf.WriteByte(LTxtQueryProtocolVersion); - await base.Write(value, buf, lengthCache, parameter, async, cancellationToken); - } - - public override async Task Write(ArraySegment value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - if (buf.WriteSpaceLeft < 1) - await buf.Flush(async, cancellationToken); - - buf.WriteByte(LTxtQueryProtocolVersion); - await base.Write(value, buf, lengthCache, parameter, async, cancellationToken); - } - - #endregion - - #region Read - - public override async ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? 
fieldDescription = null) - { - await buf.Ensure(1, async); - - var version = buf.ReadByte(); - if (version != LTxtQueryProtocolVersion) - throw new NotSupportedException($"Don't know how to decode ltxtquery with wire format {version}, your connection is now broken"); - - return await base.Read(buf, len - 1, async, fieldDescription); - } - - #endregion - - public override TextReader GetTextReader(Stream stream, NpgsqlReadBuffer buffer) - { - var version = stream.ReadByte(); - if (version != LTxtQueryProtocolVersion) - throw new NpgsqlException($"Don't know how to decode ltxtquery with wire format {version}, your connection is now broken"); - - return base.GetTextReader(stream, buffer); - } -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/MultirangeHandler.cs b/src/Npgsql/Internal/TypeHandlers/MultirangeHandler.cs deleted file mode 100644 index f0f2c11827..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/MultirangeHandler.cs +++ /dev/null @@ -1,211 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; -using NpgsqlTypes; - -namespace Npgsql.Internal.TypeHandlers; - -// NOTE: This cannot inherit from NpgsqlTypeHandler[]>, since that triggers infinite generic recursion in Native AOT -public partial class MultirangeHandler : NpgsqlTypeHandler, - INpgsqlTypeHandler[]>, - INpgsqlTypeHandler>> -{ - /// - /// The type handler for the range that this multirange type holds - /// - protected RangeHandler RangeHandler { get; } - - /// - public MultirangeHandler(PostgresMultirangeType pgMultirangeType, RangeHandler rangeHandler) - : base(pgMultirangeType) - => RangeHandler = rangeHandler; - - public ValueTask[]> Read( - NpgsqlReadBuffer buf, int len, bool async, FieldDescription? 
fieldDescription = null) - => ReadMultirangeArray(buf, len, async, fieldDescription); - - protected async ValueTask[]> ReadMultirangeArray( - NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - { - await buf.Ensure(4, async); - var numRanges = buf.ReadInt32(); - var multirange = new NpgsqlRange[numRanges]; - - for (var i = 0; i < numRanges; i++) - { - await buf.Ensure(4, async); - var rangeLen = buf.ReadInt32(); - multirange[i] = await RangeHandler.ReadRange(buf, rangeLen, async, fieldDescription); - } - - return multirange; - } - - ValueTask>> INpgsqlTypeHandler>>.Read( - NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => ReadMultirangeList(buf, len, async, fieldDescription); - - protected async ValueTask>> ReadMultirangeList( - NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - { - await buf.Ensure(4, async); - var numRanges = buf.ReadInt32(); - var multirange = new List>(numRanges); - - for (var i = 0; i < numRanges; i++) - { - await buf.Ensure(4, async); - var rangeLen = buf.ReadInt32(); - multirange.Add(await RangeHandler.ReadRange(buf, rangeLen, async, fieldDescription)); - } - - return multirange; - } - - public override async ValueTask ReadAsObject(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - => await Read(buf, len, async, fieldDescription); - - public int ValidateAndGetLength(NpgsqlRange[] value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLengthMultirange(value, ref lengthCache, parameter); - - public int ValidateAndGetLength(List> value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLengthMultirange(value, ref lengthCache, parameter); - - protected int ValidateAndGetLengthMultirange( - IList> value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter) - { - lengthCache ??= new NpgsqlLengthCache(1); - if (lengthCache.IsPopulated) - return lengthCache.Get(); - - // Leave empty slot for the entire array length, and go ahead an populate the element slots - var pos = lengthCache.Position; - lengthCache.Set(0); - - var sum = 4 + 4 * value.Count; - for (var i = 0; i < value.Count; i++) - sum += RangeHandler.ValidateAndGetLength(value[i], ref lengthCache, parameter); - - lengthCache.Lengths[pos] = sum; - return sum; - } - - public Task Write( - NpgsqlRange[] value, - NpgsqlWriteBuffer buf, - NpgsqlLengthCache? lengthCache, - NpgsqlParameter? parameter, - bool async, - CancellationToken cancellationToken = default) - => WriteMultirange(value, buf, lengthCache, parameter, async, cancellationToken); - - public Task Write( - List> value, - NpgsqlWriteBuffer buf, - NpgsqlLengthCache? lengthCache, - NpgsqlParameter? parameter, - bool async, - CancellationToken cancellationToken = default) - => WriteMultirange(value, buf, lengthCache, parameter, async, cancellationToken); - - public async Task WriteMultirange( - IList> value, - NpgsqlWriteBuffer buf, - NpgsqlLengthCache? lengthCache, - NpgsqlParameter? parameter, - bool async, - CancellationToken cancellationToken = default) - { - if (buf.WriteSpaceLeft < 4) - await buf.Flush(async, cancellationToken); - - buf.WriteInt32(value.Count); - - for (var i = 0; i < value.Count; i++) - await ((INpgsqlTypeHandler>)RangeHandler).WriteWithLength(value[i], buf, lengthCache, parameter: null, async, cancellationToken); - } - - public override Type GetFieldType(FieldDescription? 
fieldDescription = null) => typeof(NpgsqlRange[]); - - /// - public override NpgsqlTypeHandler CreateArrayHandler(PostgresArrayType pgArrayType, ArrayNullabilityMode arrayNullabilityMode) - => throw new NotSupportedException(); - - /// - public override NpgsqlTypeHandler CreateRangeHandler(PostgresType pgRangeType) - => throw new NotSupportedException(); - - /// - public override NpgsqlTypeHandler CreateMultirangeHandler(PostgresMultirangeType pgMultirangeType) - => throw new NotSupportedException(); -} - -public class MultirangeHandler : MultirangeHandler, - INpgsqlTypeHandler[]>, INpgsqlTypeHandler>> -{ - /// - public MultirangeHandler(PostgresMultirangeType pgMultirangeType, RangeHandler rangeHandler) - : base(pgMultirangeType, rangeHandler) {} - - ValueTask[]> INpgsqlTypeHandler[]>.Read( - NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => ReadMultirangeArray(buf, len, async, fieldDescription); - - ValueTask>> INpgsqlTypeHandler>>.Read( - NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => ReadMultirangeList(buf, len, async, fieldDescription); - - public int ValidateAndGetLength(List> value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLengthMultirange(value, ref lengthCache, parameter); - - public int ValidateAndGetLength(NpgsqlRange[] value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLengthMultirange(value, ref lengthCache, parameter); - - public Task Write( - List> value, - NpgsqlWriteBuffer buf, - NpgsqlLengthCache? lengthCache, - NpgsqlParameter? parameter, - bool async, - CancellationToken cancellationToken = default) - => WriteMultirange(value, buf, lengthCache, parameter, async, cancellationToken); - - public Task Write( - NpgsqlRange[] value, - NpgsqlWriteBuffer buf, - NpgsqlLengthCache? lengthCache, - NpgsqlParameter? 
parameter, - bool async, - CancellationToken cancellationToken = default) - => WriteMultirange(value, buf, lengthCache, parameter, async, cancellationToken); - - public override int ValidateObjectAndGetLength(object? value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => value switch - { - NpgsqlRange[] converted => ((INpgsqlTypeHandler[]>)this).ValidateAndGetLength(converted, ref lengthCache, parameter), - NpgsqlRange[] converted => ((INpgsqlTypeHandler[]>)this).ValidateAndGetLength(converted, ref lengthCache, parameter), - List> converted => ((INpgsqlTypeHandler>>)this).ValidateAndGetLength(converted, ref lengthCache, parameter), - List> converted => ((INpgsqlTypeHandler>>)this).ValidateAndGetLength(converted, ref lengthCache, parameter), - - DBNull => 0, - null => 0, - _ => throw new InvalidCastException($"Can't write CLR type {value.GetType()} with handler type RangeHandler") - }; - - public override Task WriteObjectWithLength(object? value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter, bool async, CancellationToken cancellationToken = default) - => value switch - { - NpgsqlRange[] converted => ((INpgsqlTypeHandler[]>)this).WriteWithLength(converted, buf, lengthCache, parameter, async, cancellationToken), - NpgsqlRange[] converted => ((INpgsqlTypeHandler[]>)this).WriteWithLength(converted, buf, lengthCache, parameter, async, cancellationToken), - List> converted => ((INpgsqlTypeHandler>>)this).WriteWithLength(converted, buf, lengthCache, parameter, async, cancellationToken), - List> converted => ((INpgsqlTypeHandler>>)this).WriteWithLength(converted, buf, lengthCache, parameter, async, cancellationToken), - - DBNull => WriteWithLength(DBNull.Value, buf, lengthCache, parameter, async, cancellationToken), - null => WriteWithLength(DBNull.Value, buf, lengthCache, parameter, async, cancellationToken), - _ => throw new InvalidCastException($"Can't write CLR type {value.GetType()} with handler type RangeHandler") - }; -} diff --git a/src/Npgsql/Internal/TypeHandlers/NetworkHandlers/CidrHandler.cs b/src/Npgsql/Internal/TypeHandlers/NetworkHandlers/CidrHandler.cs deleted file mode 100644 index 6d5eb29f10..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/NetworkHandlers/CidrHandler.cs +++ /dev/null @@ -1,50 +0,0 @@ -using System.Net; -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; -using NpgsqlTypes; - -#pragma warning disable 618 - -namespace Npgsql.Internal.TypeHandlers.NetworkHandlers; - -/// -/// A type handler for the PostgreSQL cidr data type. -/// -/// -/// See https://www.postgresql.org/docs/current/static/datatype-net-types.html. -/// -/// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it -/// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. -/// Use it at your own risk. 
-/// -public partial class CidrHandler : NpgsqlSimpleTypeHandler<(IPAddress Address, int Subnet)>, INpgsqlSimpleTypeHandler -{ - public CidrHandler(PostgresType pgType) : base(pgType) {} - - /// - public override (IPAddress Address, int Subnet) Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => InetHandler.DoRead(buf, len, fieldDescription, true); - - NpgsqlInet INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - { - var (address, subnet) = Read(buf, len, fieldDescription); - return new NpgsqlInet(address, subnet); - } - - /// - public override int ValidateAndGetLength((IPAddress Address, int Subnet) value, NpgsqlParameter? parameter) - => InetHandler.GetLength(value.Address); - - /// - public int ValidateAndGetLength(NpgsqlInet value, NpgsqlParameter? parameter) - => InetHandler.GetLength(value.Address); - - /// - public override void Write((IPAddress Address, int Subnet) value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => InetHandler.DoWrite(value.Address, value.Subnet, buf, true); - - /// - public void Write(NpgsqlInet value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => InetHandler.DoWrite(value.Address, value.Netmask, buf, true); -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/NetworkHandlers/InetHandler.cs b/src/Npgsql/Internal/TypeHandlers/NetworkHandlers/InetHandler.cs deleted file mode 100644 index ed10be3ef8..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/NetworkHandlers/InetHandler.cs +++ /dev/null @@ -1,133 +0,0 @@ -using System; -using System.Diagnostics; -using System.Net; -using System.Net.Sockets; -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; -using NpgsqlTypes; - -#pragma warning disable 618 - -namespace Npgsql.Internal.TypeHandlers.NetworkHandlers; - -/// -/// A type handler for the PostgreSQL cidr data type. 
-/// -/// -/// See https://www.postgresql.org/docs/current/static/datatype-net-types.html. -/// -/// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it -/// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. -/// Use it at your own risk. -/// -public partial class InetHandler : NpgsqlSimpleTypeHandler, - INpgsqlSimpleTypeHandler<(IPAddress Address, int Subnet)>, - INpgsqlSimpleTypeHandler -{ - // ReSharper disable InconsistentNaming - const byte IPv4 = 2; - const byte IPv6 = 3; - // ReSharper restore InconsistentNaming - - public InetHandler(PostgresType pgType) : base(pgType) {} - - #region Read - - /// - public override IPAddress Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => DoRead(buf, len, fieldDescription, false).Address; - -#pragma warning disable CA1801 // Review unused parameters - internal static (IPAddress Address, int Subnet) DoRead( - NpgsqlReadBuffer buf, - int len, - FieldDescription? fieldDescription, - bool isCidrHandler) - { - buf.ReadByte(); // addressFamily - var mask = buf.ReadByte(); - var isCidr = buf.ReadByte() == 1; - Debug.Assert(isCidrHandler == isCidr); - var numBytes = buf.ReadByte(); - var bytes = new byte[numBytes]; - for (var i = 0; i < bytes.Length; i++) - bytes[i] = buf.ReadByte(); - - return (new IPAddress(bytes), mask); - } -#pragma warning restore CA1801 // Review unused parameters - - /// - (IPAddress Address, int Subnet) INpgsqlSimpleTypeHandler<(IPAddress Address, int Subnet)>.Read( - NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => DoRead(buf, len, fieldDescription, false); - - NpgsqlInet INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? 
fieldDescription) - { - var (address, subnet) = DoRead(buf, len, fieldDescription, false); - return new NpgsqlInet(address, subnet); - } - - #endregion Read - - #region Write - - /// - public override int ValidateAndGetLength(IPAddress value, NpgsqlParameter? parameter) - => GetLength(value); - - /// - public int ValidateAndGetLength((IPAddress Address, int Subnet) value, NpgsqlParameter? parameter) - => GetLength(value.Address); - - /// - public int ValidateAndGetLength(NpgsqlInet value, NpgsqlParameter? parameter) - => GetLength(value.Address); - - /// - public override void Write(IPAddress value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => DoWrite(value, -1, buf, false); - - /// - public void Write((IPAddress Address, int Subnet) value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => DoWrite(value.Address, value.Subnet, buf, false); - - /// - public void Write(NpgsqlInet value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => DoWrite(value.Address, value.Netmask, buf, false); - - internal static void DoWrite(IPAddress ip, int mask, NpgsqlWriteBuffer buf, bool isCidrHandler) - { - switch (ip.AddressFamily) { - case AddressFamily.InterNetwork: - buf.WriteByte(IPv4); - if (mask == -1) - mask = 32; - break; - case AddressFamily.InterNetworkV6: - buf.WriteByte(IPv6); - if (mask == -1) - mask = 128; - break; - default: - throw new InvalidCastException($"Can't handle IPAddress with AddressFamily {ip.AddressFamily}, only InterNetwork or InterNetworkV6!"); - } - - buf.WriteByte((byte)mask); - buf.WriteByte((byte)(isCidrHandler ? 
1 : 0)); // Ignored on server side - var bytes = ip.GetAddressBytes(); - buf.WriteByte((byte)bytes.Length); - buf.WriteBytes(bytes, 0, bytes.Length); - } - - internal static int GetLength(IPAddress value) - => value.AddressFamily switch - { - AddressFamily.InterNetwork => 8, - AddressFamily.InterNetworkV6 => 20, - _ => throw new InvalidCastException($"Can't handle IPAddress with AddressFamily {value.AddressFamily}, only InterNetwork or InterNetworkV6!") - }; - - #endregion Write -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/NetworkHandlers/MacaddrHandler.cs b/src/Npgsql/Internal/TypeHandlers/NetworkHandlers/MacaddrHandler.cs deleted file mode 100644 index 26ade3e22b..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/NetworkHandlers/MacaddrHandler.cs +++ /dev/null @@ -1,52 +0,0 @@ -using System.Diagnostics; -using System.Net.NetworkInformation; -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; - -namespace Npgsql.Internal.TypeHandlers.NetworkHandlers; - -/// -/// A type handler for the PostgreSQL macaddr and macaddr8 data types. -/// -/// -/// See https://www.postgresql.org/docs/current/static/datatype-net-types.html. -/// -/// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it -/// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. -/// Use it at your own risk. -/// -public partial class MacaddrHandler : NpgsqlSimpleTypeHandler -{ - public MacaddrHandler(PostgresType pgType) : base(pgType) {} - - #region Read - - /// - public override PhysicalAddress Read(NpgsqlReadBuffer buf, int len, FieldDescription? 
fieldDescription = null) - { - Debug.Assert(len == 6 || len == 8); - - var bytes = new byte[len]; - - buf.ReadBytes(bytes, 0, len); - return new PhysicalAddress(bytes); - } - - #endregion Read - - #region Write - - /// - public override int ValidateAndGetLength(PhysicalAddress value, NpgsqlParameter? parameter) - => value.GetAddressBytes().Length; - - /// - public override void Write(PhysicalAddress value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - { - var bytes = value.GetAddressBytes(); - buf.WriteBytes(bytes, 0, bytes.Length); - } - - #endregion Write -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/NumericHandlers/DecimalRaw.cs b/src/Npgsql/Internal/TypeHandlers/NumericHandlers/DecimalRaw.cs deleted file mode 100644 index 0115728f33..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/NumericHandlers/DecimalRaw.cs +++ /dev/null @@ -1,150 +0,0 @@ -using System; -using System.Runtime.InteropServices; - -namespace Npgsql.Internal.TypeHandlers.NumericHandlers; - -[StructLayout(LayoutKind.Explicit)] -struct DecimalRaw -{ - const int SignMask = unchecked((int)0x80000000); - const int ScaleMask = 0x00FF0000; - const int ScaleShift = 16; - - // Fast access for 10^n where n is 0-9 - internal static readonly uint[] Powers10 = - { - 1, - 10, - 100, - 1000, - 10000, - 100000, - 1000000, - 10000000, - 100000000, - 1000000000 - }; - - // The maximum power of 10 that a 32 bit unsigned integer can store - internal static readonly int MaxUInt32Scale = Powers10.Length - 1; - - // Do not change the order in which these fields are declared. It - // should be same as in the System.Decimal struct. 
- [FieldOffset(0)] - decimal _value; - [FieldOffset(0)] - int _flags; - [FieldOffset(4)] - uint _high; - [FieldOffset(8)] - uint _low; - [FieldOffset(12)] - uint _mid; - - public bool Negative => (_flags & SignMask) != 0; - - public int Scale - { - get => (_flags & ScaleMask) >> ScaleShift; - set => _flags = (_flags & SignMask) | ((value << ScaleShift) & ScaleMask); - } - - public uint High => _high; - public uint Mid => _mid; - public uint Low => _low; - public decimal Value => _value; - - public DecimalRaw(decimal value) : this() => _value = value; - - public DecimalRaw(long value) : this() - { - if (value >= 0) - _flags = 0; - else - { - _flags = SignMask; - value = -value; - } - - _low = (uint)value; - _mid = (uint)(value >> 32); - _high = 0; - } - - public static void Negate(ref DecimalRaw value) - => value._flags ^= SignMask; - - public static void Add(ref DecimalRaw value, uint addend) - { - uint integer; - uint sum; - - integer = value._low; - value._low = sum = integer + addend; - - if (sum >= integer && sum >= addend) - return; - - integer = value._mid; - value._mid = sum = integer + 1; - - if (sum >= integer && sum >= 1) - return; - - integer = value._high; - value._high = sum = integer + 1; - - if (sum < integer || sum < 1) - throw new OverflowException("Numeric value does not fit in a System.Decimal"); - } - - public static void Multiply(ref DecimalRaw value, uint multiplier) - { - ulong integer; - uint remainder; - - integer = (ulong)value._low * multiplier; - value._low = (uint)integer; - remainder = (uint)(integer >> 32); - - integer = (ulong)value._mid * multiplier + remainder; - value._mid = (uint)integer; - remainder = (uint)(integer >> 32); - - integer = (ulong)value._high * multiplier + remainder; - value._high = (uint)integer; - remainder = (uint)(integer >> 32); - - if (remainder != 0) - throw new OverflowException("Numeric value does not fit in a System.Decimal"); - } - - public static uint Divide(ref DecimalRaw value, uint divisor) - { - 
ulong integer; - uint remainder = 0; - - if (value._high != 0) - { - integer = value._high; - value._high = (uint)(integer / divisor); - remainder = (uint)(integer % divisor); - } - - if (value._mid != 0 || remainder != 0) - { - integer = ((ulong)remainder << 32) | value._mid; - value._mid = (uint)(integer / divisor); - remainder = (uint)(integer % divisor); - } - - if (value._low != 0 || remainder != 0) - { - integer = ((ulong)remainder << 32) | value._low; - value._low = (uint)(integer / divisor); - remainder = (uint)(integer % divisor); - } - - return remainder; - } -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/NumericHandlers/DoubleHandler.cs b/src/Npgsql/Internal/TypeHandlers/NumericHandlers/DoubleHandler.cs deleted file mode 100644 index 33b1bae14c..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/NumericHandlers/DoubleHandler.cs +++ /dev/null @@ -1,32 +0,0 @@ -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; - -namespace Npgsql.Internal.TypeHandlers.NumericHandlers; - -/// -/// A type handler for the PostgreSQL double precision data type. -/// -/// -/// See https://www.postgresql.org/docs/current/static/datatype-numeric.html. -/// -/// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it -/// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. -/// Use it at your own risk. -/// -public partial class DoubleHandler : NpgsqlSimpleTypeHandler -{ - public DoubleHandler(PostgresType pgType) : base(pgType) {} - - /// - public override double Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => buf.ReadDouble(); - - /// - public override int ValidateAndGetLength(double value, NpgsqlParameter? parameter) - => 8; - - /// - public override void Write(double value, NpgsqlWriteBuffer buf, NpgsqlParameter? 
parameter) - => buf.WriteDouble(value); -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/NumericHandlers/Int16Handler.cs b/src/Npgsql/Internal/TypeHandlers/NumericHandlers/Int16Handler.cs deleted file mode 100644 index 30c704e574..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/NumericHandlers/Int16Handler.cs +++ /dev/null @@ -1,109 +0,0 @@ -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; - -namespace Npgsql.Internal.TypeHandlers.NumericHandlers; - -/// -/// A type handler for the PostgreSQL smallint data type. -/// -/// -/// See https://www.postgresql.org/docs/current/static/datatype-numeric.html. -/// -/// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it -/// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. -/// Use it at your own risk. -/// -public partial class Int16Handler : NpgsqlSimpleTypeHandler, - INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler, - INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler -{ - public Int16Handler(PostgresType pgType) : base(pgType) {} - - #region Read - - /// - public override short Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => buf.ReadInt16(); - - byte INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => checked((byte)Read(buf, len, fieldDescription)); - - sbyte INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => checked((sbyte)Read(buf, len, fieldDescription)); - - int INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => Read(buf, len, fieldDescription); - - long INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? 
fieldDescription) - => Read(buf, len, fieldDescription); - - float INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => Read(buf, len, fieldDescription); - - double INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => Read(buf, len, fieldDescription); - - decimal INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => Read(buf, len, fieldDescription); - - #endregion Read - - #region Write - - /// - public override int ValidateAndGetLength(short value, NpgsqlParameter? parameter) => 2; - /// - public int ValidateAndGetLength(byte value, NpgsqlParameter? parameter) => 2; - /// - public int ValidateAndGetLength(sbyte value, NpgsqlParameter? parameter) => 2; - /// - public int ValidateAndGetLength(decimal value, NpgsqlParameter? parameter) => 2; - - /// - public int ValidateAndGetLength(int value, NpgsqlParameter? parameter) - { - _ = checked((short)value); - return 2; - } - - /// - public int ValidateAndGetLength(long value, NpgsqlParameter? parameter) - { - _ = checked((short)value); - return 2; - } - - /// - public int ValidateAndGetLength(float value, NpgsqlParameter? parameter) - { - _ = checked((short)value); - return 2; - } - - /// - public int ValidateAndGetLength(double value, NpgsqlParameter? parameter) - { - _ = checked((short)value); - return 2; - } - - /// - public override void Write(short value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteInt16(value); - /// - public void Write(int value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteInt16((short)value); - /// - public void Write(long value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteInt16((short)value); - /// - public void Write(byte value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteInt16(value); - /// - public void Write(sbyte value, NpgsqlWriteBuffer buf, NpgsqlParameter? 
parameter) => buf.WriteInt16(value); - /// - public void Write(decimal value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteInt16((short)value); - /// - public void Write(double value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteInt16((short)value); - /// - public void Write(float value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteInt16((short)value); - - #endregion Write -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/NumericHandlers/Int32Handler.cs b/src/Npgsql/Internal/TypeHandlers/NumericHandlers/Int32Handler.cs deleted file mode 100644 index 3b778d9a70..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/NumericHandlers/Int32Handler.cs +++ /dev/null @@ -1,96 +0,0 @@ -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; - -namespace Npgsql.Internal.TypeHandlers.NumericHandlers; - -/// -/// A type handler for the PostgreSQL integer data type. -/// -/// -/// See https://www.postgresql.org/docs/current/static/datatype-numeric.html. -/// -/// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it -/// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. -/// Use it at your own risk. -/// -public partial class Int32Handler : NpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler, - INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler, - INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler -{ - public Int32Handler(PostgresType pgType) : base(pgType) {} - - #region Read - - public override int Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => buf.ReadInt32(); - - byte INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? 
fieldDescription) - => checked((byte)Read(buf, len, fieldDescription)); - - short INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => checked((short)Read(buf, len, fieldDescription)); - - long INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => Read(buf, len, fieldDescription); - - float INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => Read(buf, len, fieldDescription); - - double INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => Read(buf, len, fieldDescription); - - decimal INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => Read(buf, len, fieldDescription); - - #endregion Read - - #region Write - - /// - public override int ValidateAndGetLength(int value, NpgsqlParameter? parameter) => 4; - /// - public int ValidateAndGetLength(short value, NpgsqlParameter? parameter) => 4; - /// - public int ValidateAndGetLength(byte value, NpgsqlParameter? parameter) => 4; - /// - public int ValidateAndGetLength(decimal value, NpgsqlParameter? parameter) => 4; - - /// - public int ValidateAndGetLength(long value, NpgsqlParameter? parameter) - { - _ = checked((int)value); - return 4; - } - - /// - public int ValidateAndGetLength(float value, NpgsqlParameter? parameter) - { - _ = checked((int)value); - return 4; - } - - /// - public int ValidateAndGetLength(double value, NpgsqlParameter? parameter) - { - _ = checked((int)value); - return 4; - } - - /// - public override void Write(int value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteInt32(value); - /// - public void Write(short value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteInt32(value); - /// - public void Write(long value, NpgsqlWriteBuffer buf, NpgsqlParameter? 
parameter) => buf.WriteInt32((int)value); - /// - public void Write(byte value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteInt32(value); - /// - public void Write(float value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteInt32((int)value); - /// - public void Write(double value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteInt32((int)value); - /// - public void Write(decimal value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteInt32((int)value); - - #endregion Write -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/NumericHandlers/Int64Handler.cs b/src/Npgsql/Internal/TypeHandlers/NumericHandlers/Int64Handler.cs deleted file mode 100644 index 7a39de1856..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/NumericHandlers/Int64Handler.cs +++ /dev/null @@ -1,92 +0,0 @@ -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; - -namespace Npgsql.Internal.TypeHandlers.NumericHandlers; - -/// -/// A type handler for the PostgreSQL bigint data type. -/// -/// -/// See https://www.postgresql.org/docs/current/static/datatype-numeric.html. -/// -/// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it -/// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. -/// Use it at your own risk. -/// -public partial class Int64Handler : NpgsqlSimpleTypeHandler, - INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler, - INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler -{ - public Int64Handler(PostgresType pgType) : base(pgType) {} - - #region Read - - /// - public override long Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => buf.ReadInt64(); - - byte INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? 
fieldDescription) - => checked((byte)Read(buf, len, fieldDescription)); - - short INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => checked((short)Read(buf, len, fieldDescription)); - - int INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => checked((int)Read(buf, len, fieldDescription)); - - float INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => Read(buf, len, fieldDescription); - - double INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => Read(buf, len, fieldDescription); - - decimal INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => Read(buf, len, fieldDescription); - - #endregion Read - - #region Write - - /// - public override int ValidateAndGetLength(long value, NpgsqlParameter? parameter) => 8; - /// - public int ValidateAndGetLength(int value, NpgsqlParameter? parameter) => 8; - /// - public int ValidateAndGetLength(short value, NpgsqlParameter? parameter) => 8; - /// - public int ValidateAndGetLength(byte value, NpgsqlParameter? parameter) => 8; - /// - public int ValidateAndGetLength(decimal value, NpgsqlParameter? parameter) => 8; - - /// - public int ValidateAndGetLength(float value, NpgsqlParameter? parameter) - { - _ = checked((long)value); - return 8; - } - - /// - public int ValidateAndGetLength(double value, NpgsqlParameter? parameter) - { - _ = checked((long)value); - return 8; - } - - /// - public override void Write(long value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteInt64(value); - /// - public void Write(short value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteInt64(value); - /// - public void Write(int value, NpgsqlWriteBuffer buf, NpgsqlParameter? 
parameter) => buf.WriteInt64(value); - /// - public void Write(byte value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteInt64(value); - /// - public void Write(float value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteInt64((long)value); - /// - public void Write(double value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteInt64((long)value); - /// - public void Write(decimal value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteInt64((long)value); - - #endregion Write -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/NumericHandlers/MoneyHandler.cs b/src/Npgsql/Internal/TypeHandlers/NumericHandlers/MoneyHandler.cs deleted file mode 100644 index ebab3d3fb9..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/NumericHandlers/MoneyHandler.cs +++ /dev/null @@ -1,52 +0,0 @@ -using System; -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; - -namespace Npgsql.Internal.TypeHandlers.NumericHandlers; - -/// -/// A type handler for the PostgreSQL money data type. -/// -/// -/// See https://www.postgresql.org/docs/current/static/datatype-money.html. -/// -/// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it -/// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. -/// Use it at your own risk. -/// -public partial class MoneyHandler : NpgsqlSimpleTypeHandler -{ - public MoneyHandler(PostgresType pgType) : base(pgType) {} - - const int MoneyScale = 2; - - /// - public override decimal Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => new DecimalRaw(buf.ReadInt64()) { Scale = MoneyScale }.Value; - - /// - public override int ValidateAndGetLength(decimal value, NpgsqlParameter? parameter) - => value < -92233720368547758.08M || value > 92233720368547758.07M - ? 
throw new OverflowException($"The supplied value ({value}) is outside the range for a PostgreSQL money value.") - : 8; - - /// - public override void Write(decimal value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - { - var raw = new DecimalRaw(value); - - var scaleDifference = MoneyScale - raw.Scale; - if (scaleDifference > 0) - DecimalRaw.Multiply(ref raw, DecimalRaw.Powers10[scaleDifference]); - else - { - value = Math.Round(value, MoneyScale, MidpointRounding.AwayFromZero); - raw = new DecimalRaw(value); - } - - var result = (long)raw.Mid << 32 | raw.Low; - if (raw.Negative) result = -result; - buf.WriteInt64(result); - } -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/NumericHandlers/NumericHandler.cs b/src/Npgsql/Internal/TypeHandlers/NumericHandlers/NumericHandler.cs deleted file mode 100644 index 1e624f86f7..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/NumericHandlers/NumericHandler.cs +++ /dev/null @@ -1,434 +0,0 @@ -using System; -using System.Globalization; -using System.Numerics; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; - -namespace Npgsql.Internal.TypeHandlers.NumericHandlers; - -/// -/// A type handler for the PostgreSQL numeric data type. -/// -/// -/// See https://www.postgresql.org/docs/current/static/datatype-numeric.html. -/// -/// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it -/// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. -/// Use it at your own risk. 
-/// -public partial class NumericHandler : NpgsqlTypeHandler, - INpgsqlTypeHandler, INpgsqlTypeHandler, INpgsqlTypeHandler, INpgsqlTypeHandler, - INpgsqlTypeHandler, INpgsqlTypeHandler, INpgsqlTypeHandler -{ - public NumericHandler(PostgresType pgType) : base(pgType) {} - - const int MaxDecimalScale = 28; - - const int SignPositive = 0x0000; - const int SignNegative = 0x4000; - const int SignNan = 0xC000; - const int SignPinf = 0xD000; - const int SignNinf = 0xF000; - const int SignSpecialMask = 0xC000; - - const int MaxGroupCount = 8; - const int MaxGroupScale = 4; - - static readonly uint MaxGroupSize = DecimalRaw.Powers10[MaxGroupScale]; - - #region Read - - /// - public override async ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - { - await buf.Ensure(4 * sizeof(short), async); - var result = new DecimalRaw(); - var groups = buf.ReadInt16(); - var weight = buf.ReadInt16() - groups + 1; - var sign = buf.ReadUInt16(); - - if ((sign & SignSpecialMask) == SignSpecialMask) - { - throw sign switch - { - SignNan => new InvalidCastException("Numeric NaN not supported by System.Decimal"), - SignPinf => new InvalidCastException("Numeric Infinity not supported by System.Decimal"), - SignNinf => new InvalidCastException("Numeric -Infinity not supported by System.Decimal"), - _ => new InvalidCastException($"Numeric special value {sign} not supported by System.Decimal") - }; - } - - if (sign == SignNegative) - DecimalRaw.Negate(ref result); - - var scale = buf.ReadInt16(); - if (scale < 0 is var exponential && exponential) - scale = (short)(-scale); - else - result.Scale = scale; - - if (scale > MaxDecimalScale) - throw new OverflowException("Numeric value does not fit in a System.Decimal"); - - var scaleDifference = exponential - ? 
weight * MaxGroupScale - : weight * MaxGroupScale + scale; - - if (groups > MaxGroupCount) - throw new OverflowException("Numeric value does not fit in a System.Decimal"); - - await buf.Ensure(groups * sizeof(ushort), async); - - if (groups == MaxGroupCount) - { - while (groups-- > 1) - { - DecimalRaw.Multiply(ref result, MaxGroupSize); - DecimalRaw.Add(ref result, buf.ReadUInt16()); - } - - var group = buf.ReadUInt16(); - var groupSize = DecimalRaw.Powers10[-scaleDifference]; - if (group % groupSize != 0) - throw new OverflowException("Numeric value does not fit in a System.Decimal"); - - DecimalRaw.Multiply(ref result, MaxGroupSize / groupSize); - DecimalRaw.Add(ref result, group / groupSize); - } - else - { - while (groups-- > 0) - { - DecimalRaw.Multiply(ref result, MaxGroupSize); - DecimalRaw.Add(ref result, buf.ReadUInt16()); - } - - if (scaleDifference < 0) - DecimalRaw.Divide(ref result, DecimalRaw.Powers10[-scaleDifference]); - else - while (scaleDifference > 0) - { - var scaleChunk = Math.Min(DecimalRaw.MaxUInt32Scale, scaleDifference); - DecimalRaw.Multiply(ref result, DecimalRaw.Powers10[scaleChunk]); - scaleDifference -= scaleChunk; - } - } - - return result.Value; - } - - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => (byte)await Read(buf, len, async, fieldDescription); - - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => (short)await Read(buf, len, async, fieldDescription); - - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => (int)await Read(buf, len, async, fieldDescription); - - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? 
fieldDescription) - => (long)await Read(buf, len, async, fieldDescription); - - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => (float)await Read(buf, len, async, fieldDescription); - - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => (double)await Read(buf, len, async, fieldDescription); - - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - { - await buf.Ensure(4 * sizeof(short), async); - - var groups = (int)buf.ReadUInt16(); - var weightLeft = (int)buf.ReadInt16(); - var weightRight = weightLeft - groups + 1; - var sign = buf.ReadUInt16(); - buf.ReadInt16(); // dscale - - if (groups == 0) - { - return sign switch - { - SignPositive or SignNegative => BigInteger.Zero, - SignNan => throw new InvalidCastException("Numeric NaN not supported by BigInteger"), - SignPinf => throw new InvalidCastException("Numeric Infinity not supported by BigInteger"), - SignNinf => throw new InvalidCastException("Numeric -Infinity not supported by BigInteger"), - _ => throw new InvalidCastException($"Numeric special value {sign} not supported") - }; - } - - if (weightRight < 0) - { - await buf.Skip(groups * sizeof(ushort), async); - throw new InvalidCastException("Numeric value with non-zero fractional digits not supported by BigInteger"); - } - - var digits = new ushort[groups]; - - for (var i = 0; i < groups; i++) - { - await buf.Ensure(sizeof(ushort), async); - digits[i] = buf.ReadUInt16(); - } - - // Calculate powers 10^8, 10^16, 10^32, ... - // We should have the last calculated power to be less than the input - var lenPow = 2; // 2 ushorts fit in one uint, represents 10^8 - var numPowers = 0; - while (lenPow < weightLeft + 1) - { - lenPow <<= 1; - ++numPowers; - } - var factors = numPowers > 0 ? 
new BigInteger[numPowers] : null; - if (numPowers > 0) - { - factors![0] = new BigInteger(100000000U); - for (var i = 1; i < numPowers; i++) - factors[i] = factors[i - 1] * factors[i - 1]; - } - - var result = ToBigIntegerInner(0, weightLeft + 1, digits, factors); - return sign == SignPositive ? result : -result; - - static BigInteger ToBigIntegerInner(int offset, int length, ushort[] digits, BigInteger[]? factors) - { - if (length <= 2) - { - var r = 0U; - for (var i = offset; i < offset + length; i++) - { - r *= 10000U; - r += i < digits.Length ? digits[i] : 0U; - } - return r; - } - else - { - // Split the input into two halves, the lower one should be a power of two in digit length, - // then multiply the higher part with a precomputed power of 10^8 and add the results. - var lenFirstHalf = 2 << 1; // 2 ushorts fit in one uint, skip 1 since we've already covered the base case. - var pos = 0; - while (lenFirstHalf < length) - { - lenFirstHalf <<= 1; - ++pos; - } - var factor = factors![pos]; - lenFirstHalf >>= 1; - var lo = ToBigIntegerInner(offset + length - lenFirstHalf, lenFirstHalf, digits, factors); - var hi = ToBigIntegerInner(offset, length - lenFirstHalf, digits, factors); - return hi * factor + lo; // .NET uses Karatsuba multiplication, so this will be fast. - } - } - } - - #endregion - - #region Write - - /// - public override int ValidateAndGetLength(decimal value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter) - { - lengthCache ??= new NpgsqlLengthCache(1); - if (lengthCache.IsPopulated) - return lengthCache.Get(); - - var groupCount = 0; - var raw = new DecimalRaw(value); - if (raw.Low != 0 || raw.Mid != 0 || raw.High != 0) - { - uint remainder = default; - var scaleChunk = raw.Scale % MaxGroupScale; - if (scaleChunk > 0) - { - var divisor = DecimalRaw.Powers10[scaleChunk]; - var multiplier = DecimalRaw.Powers10[MaxGroupScale - scaleChunk]; - remainder = DecimalRaw.Divide(ref raw, divisor) * multiplier; - } - - while (remainder == 0) - remainder = DecimalRaw.Divide(ref raw, MaxGroupSize); - - groupCount++; - - while (raw.Low != 0 || raw.Mid != 0 || raw.High != 0) - { - DecimalRaw.Divide(ref raw, MaxGroupSize); - groupCount++; - } - } - - return lengthCache.Set((4 + groupCount) * sizeof(short)); - } - - /// - public int ValidateAndGetLength(short value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength((decimal)value, ref lengthCache, parameter); - /// - public int ValidateAndGetLength(int value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength((decimal)value, ref lengthCache, parameter); - /// - public int ValidateAndGetLength(long value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength((decimal)value, ref lengthCache, parameter); - /// - public int ValidateAndGetLength(float value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength((decimal)value, ref lengthCache, parameter); - /// - public int ValidateAndGetLength(double value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength((decimal)value, ref lengthCache, parameter); - /// - public int ValidateAndGetLength(byte value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter) - => ValidateAndGetLength((decimal)value, ref lengthCache, parameter); - - public override async Task Write(decimal value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - if (buf.WriteSpaceLeft < (4 + MaxGroupCount) * sizeof(short)) - await buf.Flush(async, cancellationToken); - - WriteInner(new DecimalRaw(value), buf); - - static void WriteInner(DecimalRaw raw, NpgsqlWriteBuffer buf) - { - var weight = 0; - var groupCount = 0; - Span groups = stackalloc short[MaxGroupCount]; - groups.Fill(0); // SkipLocalsInit - - if (raw.Low != 0 || raw.Mid != 0 || raw.High != 0) - { - var scale = raw.Scale; - weight = -scale / MaxGroupScale - 1; - - uint remainder; - var scaleChunk = scale % MaxGroupScale; - if (scaleChunk > 0) - { - var divisor = DecimalRaw.Powers10[scaleChunk]; - var multiplier = DecimalRaw.Powers10[MaxGroupScale - scaleChunk]; - remainder = DecimalRaw.Divide(ref raw, divisor) * multiplier; - - if (remainder != 0) - { - weight--; - goto WriteGroups; - } - } - - while ((remainder = DecimalRaw.Divide(ref raw, MaxGroupSize)) == 0) - weight++; - - WriteGroups: - groups[groupCount++] = (short)remainder; - - while (raw.Low != 0 || raw.Mid != 0 || raw.High != 0) - groups[groupCount++] = (short)DecimalRaw.Divide(ref raw, MaxGroupSize); - } - - buf.WriteInt16(groupCount); - buf.WriteInt16(groupCount + weight); - buf.WriteInt16(raw.Negative ? SignNegative : SignPositive); - buf.WriteInt16(raw.Scale); - - while (groupCount > 0) - buf.WriteInt16(groups[--groupCount]); - } - } - - /// - public Task Write(short value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => Write((decimal)value, buf, lengthCache, parameter, async, cancellationToken); - /// - public Task Write(int value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter, bool async, CancellationToken cancellationToken = default) - => Write((decimal)value, buf, lengthCache, parameter, async, cancellationToken); - /// - public Task Write(long value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => Write((decimal)value, buf, lengthCache, parameter, async, cancellationToken); - /// - public Task Write(byte value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => Write((decimal)value, buf, lengthCache, parameter, async, cancellationToken); - /// - public Task Write(float value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => Write((decimal)value, buf, lengthCache, parameter, async, cancellationToken); - /// - public Task Write(double value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => Write((decimal)value, buf, lengthCache, parameter, async, cancellationToken); - - static ushort[] FromBigInteger(BigInteger value) - { - var str = value.ToString(CultureInfo.InvariantCulture); - if (str == "0") - return new ushort[4]; - - var negative = str[0] == '-'; - var strLen = str.Length; - var numGroups = (strLen - (negative ? 
1 : 0) + 3) / 4; - - if (numGroups > 131072 / 4) - throw new InvalidCastException("Cannot write a BigInteger with more than 131072 digits"); - - var result = new ushort[4 + numGroups]; - - var strPos = strLen - numGroups * 4; - - var firstDigit = 0; - for (var i = 0; i < 4; i++) - { - if (strPos >= 0 && str[strPos] != '-') - firstDigit = firstDigit * 10 + (str[strPos] - '0'); - strPos++; - } - - result[4] = (ushort)firstDigit; - - for (var i = 1; i < numGroups; i++) - { - result[4 + i] = (ushort)((((str[strPos++] - '0') * 10 + (str[strPos++] - '0')) * 10 + (str[strPos++] - '0')) * 10 + - (str[strPos++] - '0')); - - } - - var lastNonZeroDigitPos = numGroups - 1; - while (result[4 + lastNonZeroDigitPos] == 0) - lastNonZeroDigitPos--; - - result[0] = (ushort)(lastNonZeroDigitPos + 1); // number of items in array - result[1] = (ushort)(numGroups - 1); // weight - result[2] = (ushort)(negative ? SignNegative : SignPositive); - result[3] = 0; // dscale - - return result; - } - - public int ValidateAndGetLength(BigInteger value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - { - lengthCache ??= new NpgsqlLengthCache(1); - if (lengthCache.IsPopulated) - return lengthCache.Get(); - - var result = FromBigInteger(value); - if (parameter != null) - parameter.ConvertedValue = result; - - return lengthCache.Set((4 + result[0]) * sizeof(ushort)); - } - - public async Task Write(BigInteger value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, - CancellationToken cancellationToken = default) - { - var result = (ushort[])(parameter?.ConvertedValue ?? 
FromBigInteger(value))!; - var len = 4 + result[0]; - var pos = 0; - while (len-- > 0) - { - if (buf.WriteSpaceLeft < sizeof(ushort)) - await buf.Flush(async, cancellationToken); - buf.WriteUInt16(result[pos++]); - } - } - - #endregion -} diff --git a/src/Npgsql/Internal/TypeHandlers/NumericHandlers/SingleHandler.cs b/src/Npgsql/Internal/TypeHandlers/NumericHandlers/SingleHandler.cs deleted file mode 100644 index 09554db1e9..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/NumericHandlers/SingleHandler.cs +++ /dev/null @@ -1,45 +0,0 @@ -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; - -namespace Npgsql.Internal.TypeHandlers.NumericHandlers; - -/// -/// A type handler for the PostgreSQL real data type. -/// -/// -/// See https://www.postgresql.org/docs/current/static/datatype-numeric.html. -/// -/// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it -/// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. -/// Use it at your own risk. -/// -public partial class SingleHandler : NpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler -{ - public SingleHandler(PostgresType pgType) : base(pgType) {} - - #region Read - - /// - public override float Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => buf.ReadSingle(); - - double INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => Read(buf, len, fieldDescription); - - #endregion Read - - #region Write - - /// - public int ValidateAndGetLength(double value, NpgsqlParameter? parameter) => 4; - /// - public override int ValidateAndGetLength(float value, NpgsqlParameter? parameter) => 4; - - /// - public void Write(double value, NpgsqlWriteBuffer buf, NpgsqlParameter? 
parameter) => buf.WriteSingle((float)value); - /// - public override void Write(float value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteSingle(value); - - #endregion Write -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/NumericHandlers/UInt32Handler.cs b/src/Npgsql/Internal/TypeHandlers/NumericHandlers/UInt32Handler.cs deleted file mode 100644 index 1ea4633289..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/NumericHandlers/UInt32Handler.cs +++ /dev/null @@ -1,31 +0,0 @@ -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; - -namespace Npgsql.Internal.TypeHandlers.NumericHandlers; - -/// -/// A type handler for PostgreSQL unsigned 32-bit data types. This is only used for internal types. -/// -/// -/// See https://www.postgresql.org/docs/current/static/datatype-oid.html. -/// -/// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it -/// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. -/// Use it at your own risk. -/// -public partial class UInt32Handler : NpgsqlSimpleTypeHandler -{ - public UInt32Handler(PostgresType pgType) : base(pgType) {} - - /// - public override uint Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => buf.ReadUInt32(); - - /// - public override int ValidateAndGetLength(uint value, NpgsqlParameter? parameter) => 4; - - /// - public override void Write(uint value, NpgsqlWriteBuffer buf, NpgsqlParameter? 
parameter) - => buf.WriteUInt32(value); -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/NumericHandlers/UInt64Handler.cs b/src/Npgsql/Internal/TypeHandlers/NumericHandlers/UInt64Handler.cs deleted file mode 100644 index db6d00d1db..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/NumericHandlers/UInt64Handler.cs +++ /dev/null @@ -1,29 +0,0 @@ -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; - -namespace Npgsql.Internal.TypeHandlers.NumericHandlers; - -/// -/// A type handler for PostgreSQL unsigned 64-bit data types. This is only used for internal types. -/// -/// -/// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it -/// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. -/// Use it at your own risk. -/// -public partial class UInt64Handler : NpgsqlSimpleTypeHandler -{ - public UInt64Handler(PostgresType pgType) : base(pgType) {} - - /// - public override ulong Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => buf.ReadUInt64(); - - /// - public override int ValidateAndGetLength(ulong value, NpgsqlParameter? parameter) => 8; - - /// - public override void Write(ulong value, NpgsqlWriteBuffer buf, NpgsqlParameter? 
parameter) - => buf.WriteUInt64(value); -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/RangeHandler.cs b/src/Npgsql/Internal/TypeHandlers/RangeHandler.cs deleted file mode 100644 index 0c108696e1..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/RangeHandler.cs +++ /dev/null @@ -1,187 +0,0 @@ -using System; -using System.Diagnostics.CodeAnalysis; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; -using NpgsqlTypes; - -namespace Npgsql.Internal.TypeHandlers; - -/// -/// A type handler for PostgreSQL range types. -/// -/// -/// See https://www.postgresql.org/docs/current/static/rangetypes.html. -/// -/// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it -/// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. -/// Use it at your own risk. -/// -/// The range subtype. -// NOTE: This cannot inherit from NpgsqlTypeHandler>, since that triggers infinite generic recursion in Native AOT -public partial class RangeHandler : NpgsqlTypeHandler, INpgsqlTypeHandler> -{ - /// - /// The type handler for the subtype that this range type holds - /// - protected NpgsqlTypeHandler SubtypeHandler { get; } - - /// - public RangeHandler(PostgresType rangePostgresType, NpgsqlTypeHandler subtypeHandler) - : base(rangePostgresType) - => SubtypeHandler = subtypeHandler; - - public override Type GetFieldType(FieldDescription? 
fieldDescription = null) => typeof(NpgsqlRange); - - /// - public override NpgsqlTypeHandler CreateArrayHandler(PostgresArrayType pgArrayType, ArrayNullabilityMode arrayNullabilityMode) - => new ArrayHandler(pgArrayType, this, arrayNullabilityMode); - - /// - public override NpgsqlTypeHandler CreateRangeHandler(PostgresType pgRangeType) - => throw new NotSupportedException(); - - /// - public override NpgsqlTypeHandler CreateMultirangeHandler(PostgresMultirangeType pgMultirangeType) - => throw new NotSupportedException(); - - #region Read - - /// - public ValueTask> Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - => ReadRange(buf, len, async, fieldDescription); - - protected internal async ValueTask> ReadRange(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - { - await buf.Ensure(1, async); - - var flags = (RangeFlags)buf.ReadByte(); - if ((flags & RangeFlags.Empty) != 0) - return NpgsqlRange.Empty; - - var lowerBound = default(TAnySubtype); - var upperBound = default(TAnySubtype); - - if ((flags & RangeFlags.LowerBoundInfinite) == 0) - lowerBound = await SubtypeHandler.ReadWithLength(buf, async); - - if ((flags & RangeFlags.UpperBoundInfinite) == 0) - upperBound = await SubtypeHandler.ReadWithLength(buf, async); - - return new NpgsqlRange(lowerBound, upperBound, flags); - } - - public override async ValueTask ReadAsObject(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - => await Read(buf, len, async, fieldDescription); - - #endregion - - #region Write - - /// - public int ValidateAndGetLength(NpgsqlRange value, [NotNullIfNotNull("lengthCache")] ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLengthRange(value, ref lengthCache, parameter); - - protected internal int ValidateAndGetLengthRange(NpgsqlRange value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter) - { - var totalLen = 1; - var lengthCachePos = lengthCache?.Position ?? 0; - if (!value.IsEmpty) - { - if (!value.LowerBoundInfinite) - { - totalLen += 4; - if (value.LowerBound is not null) - totalLen += SubtypeHandler.ValidateAndGetLength(value.LowerBound, ref lengthCache, null); - } - - if (!value.UpperBoundInfinite) - { - totalLen += 4; - if (value.UpperBound is not null) - totalLen += SubtypeHandler.ValidateAndGetLength(value.UpperBound, ref lengthCache, null); - } - } - - // If we're traversing an already-populated length cache, rewind to first element slot so that - // the elements' handlers can access their length cache values - if (lengthCache != null && lengthCache.IsPopulated) - lengthCache.Position = lengthCachePos; - - return totalLen; - } - - /// - public Task Write(NpgsqlRange value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => WriteRange(value, buf, lengthCache, parameter, async, cancellationToken); - - protected internal async Task WriteRange(NpgsqlRange value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - if (buf.WriteSpaceLeft < 1) - await buf.Flush(async, cancellationToken); - - buf.WriteByte((byte)value.Flags); - - if (value.IsEmpty) - return; - - if (!value.LowerBoundInfinite) - await SubtypeHandler.WriteWithLength(value.LowerBound, buf, lengthCache, null, async, cancellationToken); - - if (!value.UpperBoundInfinite) - await SubtypeHandler.WriteWithLength(value.UpperBound, buf, lengthCache, null, async, cancellationToken); - } - - #endregion -} - -/// -/// Type handler for PostgreSQL range types. -/// -/// -/// Introduced in PostgreSQL 9.2. -/// https://www.postgresql.org/docs/current/static/rangetypes.html -/// -/// The main range subtype. -/// An alternative range subtype. 
-public class RangeHandler : RangeHandler, INpgsqlTypeHandler> -{ - /// - public RangeHandler(PostgresType rangePostgresType, NpgsqlTypeHandler subtypeHandler) - : base(rangePostgresType, subtypeHandler) {} - - ValueTask> INpgsqlTypeHandler>.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => ReadRange(buf, len, async, fieldDescription); - - /// - public int ValidateAndGetLength(NpgsqlRange value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLengthRange(value, ref lengthCache, parameter); - - /// - public Task Write(NpgsqlRange value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => WriteRange(value, buf, lengthCache, parameter, async, cancellationToken); - - public override int ValidateObjectAndGetLength(object? value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => value switch - { - NpgsqlRange converted => ValidateAndGetLength(converted, ref lengthCache, parameter), - NpgsqlRange converted => ValidateAndGetLength(converted, ref lengthCache, parameter), - - DBNull => 0, - null => 0, - _ => throw new InvalidCastException($"Can't write CLR type {value.GetType()} with handler type RangeHandler") - }; - - public override Task WriteObjectWithLength(object? value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter, bool async, CancellationToken cancellationToken = default) - => value switch - { - NpgsqlRange converted => ((INpgsqlTypeHandler>)this).WriteWithLength(converted, buf, lengthCache, parameter, async, cancellationToken), - NpgsqlRange converted => ((INpgsqlTypeHandler>)this).WriteWithLength(converted, buf, lengthCache, parameter, async, cancellationToken), - - DBNull => WriteWithLength(DBNull.Value, buf, lengthCache, parameter, async, cancellationToken), - null => WriteWithLength(DBNull.Value, buf, lengthCache, parameter, async, cancellationToken), - _ => throw new InvalidCastException($"Can't write CLR type {value.GetType()} with handler type RangeHandler") - }; -} diff --git a/src/Npgsql/Internal/TypeHandlers/RecordHandler.cs b/src/Npgsql/Internal/TypeHandlers/RecordHandler.cs deleted file mode 100644 index 9a255e02d3..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/RecordHandler.cs +++ /dev/null @@ -1,104 +0,0 @@ -using System; -using System.Linq; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.Internal.TypeMapping; -using Npgsql.PostgresTypes; - -namespace Npgsql.Internal.TypeHandlers; - -/// -/// Type handler for PostgreSQL record types. Defaults to returning object[], but can also return or . 
-/// -/// -/// https://www.postgresql.org/docs/current/static/datatype-pseudo.html -/// -/// Encoding (identical to composite): -/// A 32-bit integer with the number of columns, then for each column: -/// * An OID indicating the type of the column -/// * The length of the column(32-bit integer), or -1 if null -/// * The column data encoded as binary -/// -sealed partial class RecordHandler : NpgsqlTypeHandler -{ - readonly TypeMapper _typeMapper; - - public RecordHandler(PostgresType postgresType, TypeMapper typeMapper) - : base(postgresType) - => _typeMapper = typeMapper; - - #region Read - - protected internal override async ValueTask ReadCustom( - NpgsqlReadBuffer buf, - int len, - bool async, - FieldDescription? fieldDescription) - { - if (typeof(T) == typeof(object[])) - return (T)(object)await Read(buf, len, async, fieldDescription); - - if (typeof(T).FullName?.StartsWith("System.ValueTuple`", StringComparison.Ordinal) == true || - typeof(T).FullName?.StartsWith("System.Tuple`", StringComparison.Ordinal) == true) - { - var asArray = await Read(buf, len, async, fieldDescription); - if (typeof(T).GenericTypeArguments.Length != asArray.Length) - throw new InvalidCastException($"Cannot read record type with {asArray.Length} fields as {typeof(T)}"); - - var constructor = typeof(T).GetConstructors().Single(c => c.GetParameters().Length == asArray.Length); - return (T)constructor.Invoke(asArray); - } - - return await base.ReadCustom(buf, len, async, fieldDescription); - } - - public override async ValueTask ReadAsObject(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - => await Read(buf, len, async, fieldDescription); - - public override async ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? 
fieldDescription = null) - { - await buf.Ensure(4, async); - var fieldCount = buf.ReadInt32(); - var result = new object[fieldCount]; - - for (var i = 0; i < fieldCount; i++) - { - await buf.Ensure(8, async); - var typeOID = buf.ReadUInt32(); - var fieldLen = buf.ReadInt32(); - if (fieldLen == -1) // Null field, simply skip it and leave at default - continue; - result[i] = await _typeMapper.ResolveByOID(typeOID).ReadAsObject(buf, fieldLen, async); - } - - return result; - } - - /// - public override NpgsqlTypeHandler CreateRangeHandler(PostgresType pgRangeType) - => throw new NotSupportedException(); - - /// - public override NpgsqlTypeHandler CreateMultirangeHandler(PostgresMultirangeType pgMultirangeType) - => throw new NotSupportedException(); - - #endregion - - #region Write (unsupported) - - public override int ValidateAndGetLength(object[] value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => throw new NotSupportedException("Can't write record types"); - - public override Task Write( - object[] value, - NpgsqlWriteBuffer buf, - NpgsqlLengthCache? lengthCache, - NpgsqlParameter? 
parameter, - bool async, - CancellationToken cancellationToken = default) - => throw new NotSupportedException("Can't write record types"); - - #endregion -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/SystemTextJsonHandler.cs b/src/Npgsql/Internal/TypeHandlers/SystemTextJsonHandler.cs deleted file mode 100644 index 5ba3f03b3e..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/SystemTextJsonHandler.cs +++ /dev/null @@ -1,209 +0,0 @@ -using System; -using System.Diagnostics.CodeAnalysis; -using System.IO; -using System.Text; -using System.Text.Json; -using System.Text.Json.Nodes; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; - -namespace Npgsql.Internal.TypeHandlers; - -/// -/// A type handler for the PostgreSQL json and jsonb data type which uses System.Text.Json. -/// -/// -/// See https://www.postgresql.org/docs/current/datatype-json.html. -/// -/// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it -/// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. -/// Use it at your own risk. -/// -public class SystemTextJsonHandler : JsonTextHandler -{ - readonly JsonSerializerOptions _serializerOptions; - readonly bool _isJsonb; - readonly int _headerLen; - - /// - /// Prepended to the string in the wire encoding - /// - const byte JsonbProtocolVersion = 1; - - static readonly JsonSerializerOptions DefaultSerializerOptions = new(); - - /// - public SystemTextJsonHandler(PostgresType postgresType, Encoding encoding, bool isJsonb, JsonSerializerOptions? serializerOptions = null) - : base(postgresType, encoding, isJsonb) - { - _serializerOptions = serializerOptions ?? DefaultSerializerOptions; - _isJsonb = isJsonb; - _headerLen = isJsonb ? 
1 : 0; - } - - /// - protected internal override int ValidateAndGetLengthCustom([DisallowNull] TAny value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - { - if (IsSupportedAsText()) - return base.ValidateAndGetLengthCustom(value, ref lengthCache, parameter); - - if (typeof(TAny) == typeof(JsonDocument)) - { - lengthCache ??= new NpgsqlLengthCache(1); - if (lengthCache.IsPopulated) - return lengthCache.Get(); - - var data = SerializeJsonDocument((JsonDocument)(object)value); - if (parameter != null) - parameter.ConvertedValue = data; - return lengthCache.Set(data.Length + _headerLen); - } - - if (typeof(TAny) == typeof(JsonObject) || typeof(TAny) == typeof(JsonArray)) - { - lengthCache ??= new NpgsqlLengthCache(1); - if (lengthCache.IsPopulated) - return lengthCache.Get(); - - var data = SerializeJsonObject((JsonNode)(object)value); - if (parameter != null) - parameter.ConvertedValue = data; - return lengthCache.Set(data.Length + _headerLen); - } - - // User POCO, need to serialize. At least internally ArrayPool buffers are used... - var s = JsonSerializer.Serialize(value, _serializerOptions); - if (parameter != null) - parameter.ConvertedValue = s; - - return TextHandler.ValidateAndGetLength(s, ref lengthCache, parameter) + _headerLen; - } - - /// - protected override async Task WriteWithLengthCustom([DisallowNull] TAny value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken) - { - if (IsSupportedAsText()) - { - await base.WriteWithLengthCustom(value, buf, lengthCache, parameter, async, cancellationToken); - return; - } - - var spaceRequired = _isJsonb ? 
5 : 4; - - if (buf.WriteSpaceLeft < spaceRequired) - await buf.Flush(async, cancellationToken); - - buf.WriteInt32(ValidateAndGetLength(value, ref lengthCache, parameter)); - - if (_isJsonb) - buf.WriteByte(JsonbProtocolVersion); - - if (typeof(TAny) == typeof(JsonDocument)) - { - var data = parameter?.ConvertedValue != null - ? (byte[])parameter.ConvertedValue - : SerializeJsonDocument((JsonDocument)(object)value); - await buf.WriteBytesRaw(data, async, cancellationToken); - } - else if (typeof(TAny) == typeof(JsonObject) || typeof(TAny) == typeof(JsonArray)) - { - var data = parameter?.ConvertedValue != null - ? (byte[])parameter.ConvertedValue - : SerializeJsonObject((JsonNode)(object)value); - await buf.WriteBytesRaw(data, async, cancellationToken); - } - else - { - // User POCO, read serialized representation from the validation phase - var s = parameter?.ConvertedValue != null - ? (string)parameter.ConvertedValue - : JsonSerializer.Serialize(value, value.GetType(), _serializerOptions); - - await TextHandler.Write(s, buf, lengthCache, parameter, async, cancellationToken); - } - } - - /// - public override int ValidateObjectAndGetLength(object value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => IsSupported(value.GetType()) - ? base.ValidateObjectAndGetLength(value, ref lengthCache, parameter) - : value switch - { - JsonDocument jsonDocument => ValidateAndGetLengthCustom(jsonDocument, ref lengthCache, parameter), - JsonObject jsonObject => ValidateAndGetLengthCustom(jsonObject, ref lengthCache, parameter), - JsonArray jsonArray => ValidateAndGetLengthCustom(jsonArray, ref lengthCache, parameter), - _ => ValidateAndGetLengthCustom(value, ref lengthCache, parameter) - }; - - /// - public override Task WriteObjectWithLength(object? value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => value is null or DBNull || IsSupported(value.GetType()) - ? 
base.WriteObjectWithLength(value, buf, lengthCache, parameter, async, cancellationToken) - : value switch - { - JsonDocument jsonDocument => WriteWithLengthCustom(jsonDocument, buf, lengthCache, parameter, async, cancellationToken), - JsonObject jsonObject => WriteWithLengthCustom(jsonObject, buf, lengthCache, parameter, async, cancellationToken), - JsonArray jsonArray => WriteWithLengthCustom(jsonArray, buf, lengthCache, parameter, async, cancellationToken), - _ => WriteWithLengthCustom(value, buf, lengthCache, parameter, async, cancellationToken), - }; - - /// - protected internal override async ValueTask ReadCustom(NpgsqlReadBuffer buf, int byteLen, bool async, FieldDescription? fieldDescription) - { - if (IsSupportedAsText()) - { - return await base.ReadCustom(buf, byteLen, async, fieldDescription); - } - - if (_isJsonb) - { - await buf.Ensure(1, async); - var version = buf.ReadByte(); - if (version != JsonbProtocolVersion) - throw new NotSupportedException($"Don't know how to decode JSONB with wire format {version}, your connection is now broken"); - byteLen--; - } - - // JsonDocument is a view over its provided buffer, so we can't return one over our internal buffer (see #2811), so we deserialize - // a string and get a JsonDocument from that. #2818 tracks improving this. - if (typeof(T) == typeof(JsonDocument)) - return (T)(object)JsonDocument.Parse(await TextHandler.Read(buf, byteLen, async, fieldDescription)); - - // User POCO - if (buf.ReadBytesLeft >= byteLen) - return JsonSerializer.Deserialize(buf.ReadSpan(byteLen), _serializerOptions)!; - -#if NET6_0_OR_GREATER - return (async - ? 
await JsonSerializer.DeserializeAsync(buf.GetStream(byteLen, canSeek: false), _serializerOptions) - : JsonSerializer.Deserialize(buf.GetStream(byteLen, canSeek: false), _serializerOptions))!; -#else - return JsonSerializer.Deserialize(await TextHandler.Read(buf, byteLen, async, fieldDescription), _serializerOptions)!; -#endif - } - - byte[] SerializeJsonDocument(JsonDocument document) - { - // TODO: Writing is currently really inefficient - please don't criticize :) - // We need to implement one-pass writing to serialize directly to the buffer (or just switch to pipelines). - using var stream = new MemoryStream(); - using var writer = new Utf8JsonWriter(stream); - document.WriteTo(writer); - writer.Flush(); - return stream.ToArray(); - } - - byte[] SerializeJsonObject(JsonNode jsonObject) - { - // TODO: Writing is currently really inefficient - please don't criticize :) - // We need to implement one-pass writing to serialize directly to the buffer (or just switch to pipelines). - using var stream = new MemoryStream(); - using var writer = new Utf8JsonWriter(stream); - jsonObject.WriteTo(writer); - writer.Flush(); - return stream.ToArray(); - } -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/TextHandler.cs b/src/Npgsql/Internal/TypeHandlers/TextHandler.cs deleted file mode 100644 index a707c83efc..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/TextHandler.cs +++ /dev/null @@ -1,317 +0,0 @@ -using System; -using System.Buffers; -using System.Diagnostics; -using System.Diagnostics.CodeAnalysis; -using System.IO; -using System.Text; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; - -namespace Npgsql.Internal.TypeHandlers; - -/// -/// A type handler for PostgreSQL character data types (text, char, varchar, xml...). -/// -/// -/// See https://www.postgresql.org/docs/current/datatype-character.html. 
-/// -/// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it -/// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. -/// Use it at your own risk. -/// -public partial class TextHandler : NpgsqlTypeHandler, INpgsqlTypeHandler, INpgsqlTypeHandler>, - INpgsqlTypeHandler, INpgsqlTypeHandler, INpgsqlTypeHandler>, ITextReaderHandler -{ - // Text types are handled a bit more efficiently when sent as text than as binary - // see https://github.com/npgsql/npgsql/issues/1210#issuecomment-235641670 - internal override bool PreferTextWrite => true; - - readonly Encoding _encoding; - - /// - protected internal TextHandler(PostgresType postgresType, Encoding encoding) - : base(postgresType) - => _encoding = encoding; - - #region Read - - /// - public override ValueTask Read(NpgsqlReadBuffer buf, int byteLen, bool async, FieldDescription? fieldDescription = null) - { - return buf.ReadBytesLeft >= byteLen - ? new ValueTask(buf.ReadString(byteLen)) - : ReadLong(_encoding, buf, byteLen, async); - - static async ValueTask ReadLong(Encoding encoding, NpgsqlReadBuffer buf, int byteLen, bool async) - { - if (byteLen <= buf.Size) - { - // The string's byte representation can fit in our read buffer, read it. - await buf.Ensure(byteLen, async); - return buf.ReadString(byteLen); - } - - // Bad case: the string's byte representation doesn't fit in our buffer. - // This is rare - will only happen in CommandBehavior.Sequential mode (otherwise the - // entire row is in memory). 
Tweaking the buffer length via the connection string can - - var tempBuf = ArrayPool.Shared.Rent(byteLen); - - try - { - var pos = 0; - while (true) - { - var len = Math.Min(buf.ReadBytesLeft, byteLen - pos); - buf.ReadBytes(tempBuf, pos, len); - pos += len; - if (pos < byteLen) - { - await buf.ReadMore(async); - continue; - } - break; - } - return encoding.GetString(tempBuf, 0, byteLen); - } - finally - { - ArrayPool.Shared.Return(tempBuf); - } - } - } - - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int byteLen, bool async, FieldDescription? fieldDescription) - { - if (byteLen <= buf.Size) - { - // The string's byte representation can fit in our read buffer, read it. - await buf.Ensure(byteLen, async); - return buf.ReadChars(byteLen); - } - - var tempBuf = ArrayPool.Shared.Rent(byteLen); - - try - { - var pos = 0; - while (true) - { - var len = Math.Min(buf.ReadBytesLeft, byteLen - pos); - buf.ReadBytes(tempBuf, pos, len); - pos += len; - if (pos < byteLen) - { - await buf.ReadMore(async); - continue; - } - break; - } - return _encoding.GetChars(tempBuf, 0, byteLen); - } - finally - { - ArrayPool.Shared.Return(tempBuf); - } - } - - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? 
fieldDescription) - { - // Make sure we have enough bytes in the buffer for a single character - // We can get here a much bigger length in case it's a string - // while we want to read only its first character - var maxBytes = Math.Min(_encoding.GetMaxByteCount(1), len); - await buf.Ensure(maxBytes, async); - - var character = ReadCharCore(); - - // We've been requested to read 'len' bytes, which is why we're going to skip them - // This is important for NpgsqlDataReader with CommandBehavior.SequentialAccess - // which tracks how many bytes it has to skip for the next column - await buf.Skip(len, async); - return character; - - char ReadCharCore() - { - var charSpan = buf.Buffer.AsSpan(buf.ReadPosition, maxBytes); - var chars = _encoding.GetCharCount(charSpan); - if (chars < 1) - throw new NpgsqlException("Could not read char - string was empty"); - - Span destination = stackalloc char[chars]; - _encoding.GetChars(charSpan, destination); - return destination[0]; - } - } - - ValueTask> INpgsqlTypeHandler>.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => throw new NotSupportedException("Only writing ArraySegment to PostgreSQL text is supported, no reading."); - - ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int byteLen, bool async, FieldDescription? fieldDescription) - { - var bytes = new byte[byteLen]; - if (buf.ReadBytesLeft >= byteLen) - { - buf.ReadBytes(bytes, 0, byteLen); - return new ValueTask(bytes); - } - return ReadLong(buf, bytes, byteLen, async); - - static async ValueTask ReadLong(NpgsqlReadBuffer buf, byte[] bytes, int byteLen, bool async) - { - if (byteLen <= buf.Size) - { - // The bytes can fit in our read buffer, read it. - await buf.Ensure(byteLen, async); - buf.ReadBytes(bytes, 0, byteLen); - return bytes; - } - - // Bad case: the bytes don't fit in our buffer. - // This is rare - will only happen in CommandBehavior.Sequential mode (otherwise the - // entire row is in memory). 
Tweaking the buffer length via the connection string can - // help avoid this. - - var pos = 0; - while (true) - { - var len = Math.Min(buf.ReadBytesLeft, byteLen - pos); - buf.ReadBytes(bytes, pos, len); - pos += len; - if (pos < byteLen) - { - await buf.ReadMore(async); - continue; - } - break; - } - return bytes; - } - } - - ValueTask> INpgsqlTypeHandler>.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => throw new NotSupportedException("Only writing ReadOnlyMemory to PostgreSQL text is supported, no reading."); - - #endregion - - #region Write - - /// - public override unsafe int ValidateAndGetLength(string value, [NotNullIfNotNull("lengthCache")] ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - { - lengthCache ??= new NpgsqlLengthCache(1); - if (lengthCache.IsPopulated) - return lengthCache.Get(); - - if (parameter == null || parameter.Size <= 0 || parameter.Size >= value.Length) - return lengthCache.Set(_encoding.GetByteCount(value)); - fixed (char* p = value) - return lengthCache.Set(_encoding.GetByteCount(p, parameter.Size)); - } - - /// - public virtual int ValidateAndGetLength(char[] value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - { - lengthCache ??= new NpgsqlLengthCache(1); - if (lengthCache.IsPopulated) - return lengthCache.Get(); - - return lengthCache.Set( - parameter == null || parameter.Size <= 0 || parameter.Size >= value.Length - ? _encoding.GetByteCount(value) - : _encoding.GetByteCount(value, 0, parameter.Size) - ); - } - - /// - public virtual int ValidateAndGetLength(ArraySegment value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter) - { - lengthCache ??= new NpgsqlLengthCache(1); - if (lengthCache.IsPopulated) - return lengthCache.Get(); - - if (parameter?.Size > 0) - throw new ArgumentException($"Parameter {parameter.ParameterName} is of type ArraySegment and should not have its Size set", parameter.ParameterName); - - return lengthCache.Set(value.Array is null ? 0 : _encoding.GetByteCount(value.Array, value.Offset, value.Count)); - } - - /// - public int ValidateAndGetLength(char value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - { - Span singleCharArray = stackalloc char[1]; - singleCharArray[0] = value; - return _encoding.GetByteCount(singleCharArray); - } - - /// - public int ValidateAndGetLength(byte[] value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => value.Length; - - /// - public int ValidateAndGetLength(ReadOnlyMemory value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => value.Length; - - /// - public override Task Write(string value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => WriteString(value, buf, lengthCache!, parameter, async, cancellationToken); - - /// - public virtual Task Write(char[] value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - var charLen = parameter == null || parameter.Size <= 0 || parameter.Size >= value.Length - ? value.Length - : parameter.Size; - return buf.WriteChars(value, 0, charLen, lengthCache!.GetLast(), async, cancellationToken); - } - - /// - public virtual Task Write(ArraySegment value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => value.Array is null ? 
Task.CompletedTask : buf.WriteChars(value.Array, value.Offset, value.Count, lengthCache!.GetLast(), async, cancellationToken); - - Task WriteString(string str, NpgsqlWriteBuffer buf, NpgsqlLengthCache lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - var charLen = parameter == null || parameter.Size <= 0 || parameter.Size >= str.Length - ? str.Length - : parameter.Size; - return buf.WriteString(str, charLen, lengthCache.GetLast(), async, cancellationToken); - } - - /// - public async Task Write(char value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - if (buf.WriteSpaceLeft < _encoding.GetMaxByteCount(1)) - await buf.Flush(async, cancellationToken); - WriteCharCore(value, buf); - - static unsafe void WriteCharCore(char value, NpgsqlWriteBuffer buf) - { - Span singleCharArray = stackalloc char[1]; - singleCharArray[0] = value; - buf.WriteChars(singleCharArray); - } - } - - - public Task Write(byte[] value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, - CancellationToken cancellationToken = default) - => buf.WriteBytesRaw(value, async, cancellationToken); - - public Task Write(ReadOnlyMemory value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, - CancellationToken cancellationToken = default) - => buf.WriteBytesRaw(value, async, cancellationToken); - - #endregion - - /// - public virtual TextReader GetTextReader(Stream stream, NpgsqlReadBuffer buffer) - { - var byteLength = (int)(stream.Length - stream.Position); - return buffer.ReadBytesLeft >= byteLength - ? 
buffer.GetPreparedTextReader(_encoding.GetString(buffer.Buffer, buffer.ReadPosition, byteLength), stream) - : new StreamReader(stream, _encoding); - } -} diff --git a/src/Npgsql/Internal/TypeHandlers/UnknownTypeHandler.cs b/src/Npgsql/Internal/TypeHandlers/UnknownTypeHandler.cs deleted file mode 100644 index 43d0be10fc..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/UnknownTypeHandler.cs +++ /dev/null @@ -1,95 +0,0 @@ -using System; -using System.Diagnostics.CodeAnalysis; -using System.Text; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; - -namespace Npgsql.Internal.TypeHandlers; - -/// -/// Handles "conversions" for columns sent by the database with unknown OIDs. -/// This differs from TextHandler in that its a text-only handler (we don't want to receive binary -/// representations of the types registered here). -/// Note that this handler is also used in the very initial query that loads the OID mappings -/// (chicken and egg problem). -/// Also used for sending parameters with unknown types (OID=0) -/// -sealed class UnknownTypeHandler : TextHandler -{ - internal UnknownTypeHandler(Encoding encoding) - : base(UnknownBackendType.Instance, encoding) - { - } - - #region Read - - public override ValueTask Read(NpgsqlReadBuffer buf, int byteLen, bool async, FieldDescription? fieldDescription = null) - { - if (fieldDescription == null) - throw new Exception($"Received an unknown field but {nameof(fieldDescription)} is null (i.e. COPY mode)"); - - if (fieldDescription.IsBinaryFormat) - { - // At least get the name of the PostgreSQL type for the exception - throw new NotSupportedException( - buf.Connector.TypeMapper.DatabaseInfo.ByOID.TryGetValue(fieldDescription.TypeOID, out var pgType) - ? $"The field '{fieldDescription.Name}' has type '{pgType.DisplayName}', which is currently unknown to Npgsql. 
You can retrieve it as a string by marking it as unknown, please see the FAQ." - : $"The field '{fieldDescription.Name}' has a type currently unknown to Npgsql (OID {fieldDescription.TypeOID}). You can retrieve it as a string by marking it as unknown, please see the FAQ." - ); - } - - return base.Read(buf, byteLen, async, fieldDescription); - } - - #endregion Read - - #region Write - - // Allow writing anything that is a string or can be converted to one via the unknown type handler - - protected internal override int ValidateAndGetLengthCustom( - [DisallowNull] TAny value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateObjectAndGetLength(value, ref lengthCache, parameter); - - public override int ValidateObjectAndGetLength(object value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - { - if (value is string asString) - return ValidateAndGetLength(asString, ref lengthCache, parameter); - - if (parameter == null) - throw CreateConversionButNoParamException(value.GetType()); - - var converted = Convert.ToString(value)!; - parameter.ConvertedValue = converted; - - return ValidateAndGetLength(converted, ref lengthCache, parameter); - } - - public override Task WriteObjectWithLength(object? value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - if (value is null or DBNull) - return base.WriteObjectWithLength(value, buf, lengthCache, parameter, async, cancellationToken); - - var convertedValue = value is string asString - ? 
asString - : (string)parameter!.ConvertedValue!; - - if (buf.WriteSpaceLeft < 4) - return WriteWithLengthLong(value, convertedValue, buf, lengthCache, parameter, async, cancellationToken); - - buf.WriteInt32(ValidateObjectAndGetLength(value, ref lengthCache, parameter)); - return Write(convertedValue, buf, lengthCache, parameter, async, cancellationToken); - - async Task WriteWithLengthLong(object value, string convertedValue, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken) - { - await buf.Flush(async, cancellationToken); - buf.WriteInt32(ValidateObjectAndGetLength(value!, ref lengthCache, parameter)); - await Write(convertedValue, buf, lengthCache, parameter, async, cancellationToken); - } - } - - #endregion Write -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/UnmappedEnumHandler.cs b/src/Npgsql/Internal/TypeHandlers/UnmappedEnumHandler.cs deleted file mode 100644 index d8accea307..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/UnmappedEnumHandler.cs +++ /dev/null @@ -1,149 +0,0 @@ -using System; -using System.Collections.Concurrent; -using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; -using System.Linq; -using System.Reflection; -using System.Text; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; -using NpgsqlTypes; - -namespace Npgsql.Internal.TypeHandlers; - -sealed class UnmappedEnumHandler : TextHandler -{ - readonly INpgsqlNameTranslator _nameTranslator; - - // Note that a separate instance of UnmappedEnumHandler is created for each PG enum type, so concurrency isn't "really" needed. - // However, in theory multiple different CLR enums may be used with the same PG enum type, and even if there's only one, we only know - // about it late (after construction), when the user actually reads/writes with one. 
So this handler is fully thread-safe. - readonly ConcurrentDictionary _types = new(); - - internal UnmappedEnumHandler(PostgresEnumType pgType, INpgsqlNameTranslator nameTranslator, Encoding encoding) - : base(pgType, encoding) - => _nameTranslator = nameTranslator; - - #region Read - - protected internal override async ValueTask ReadCustom(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - { - var s = await base.Read(buf, len, async, fieldDescription); - if (typeof(TAny) == typeof(string)) - return (TAny)(object)s; - - var typeRecord = GetTypeRecord(typeof(TAny)); - - if (!typeRecord.LabelToEnum.TryGetValue(s, out var value)) - throw new InvalidCastException($"Received enum value '{s}' from database which wasn't found on enum {typeof(TAny)}"); - - // TODO: Avoid boxing - return (TAny)(object)value; - } - - public override ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - => base.Read(buf, len, async, fieldDescription); - - #endregion - - #region Write - - public override int ValidateObjectAndGetLength(object? value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => value is null || value is DBNull - ? 0 - : ValidateAndGetLength(value, ref lengthCache, parameter); - - protected internal override int ValidateAndGetLengthCustom(TAny value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength(value!, ref lengthCache, parameter); - - [UnconditionalSuppressMessage("Unmapped enums currently aren't trimming-safe.", "IL2072")] - int ValidateAndGetLength(object value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - { - var type = value.GetType(); - if (type == typeof(string)) - return base.ValidateAndGetLength((string)value, ref lengthCache, parameter); - - var typeRecord = GetTypeRecord(type); - - // TODO: Avoid boxing - return typeRecord.EnumToLabel.TryGetValue((Enum)value, out var str) - ? 
base.ValidateAndGetLength(str, ref lengthCache, parameter) - : throw new InvalidCastException($"Can't write value {value} as enum {type}"); - } - - // TODO: This boxes the enum (again) - protected override Task WriteWithLengthCustom(TAny value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken) - => WriteObjectWithLength(value!, buf, lengthCache, parameter, async, cancellationToken); - - public override Task WriteObjectWithLength(object? value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - if (value is null || value is DBNull) - return WriteWithLength(DBNull.Value, buf, lengthCache, parameter, async, cancellationToken); - - if (buf.WriteSpaceLeft < 4) - return WriteWithLengthLong(value, buf, lengthCache, parameter, async, cancellationToken); - - buf.WriteInt32(ValidateAndGetLength(value, ref lengthCache, parameter)); - return Write(value, buf, lengthCache, parameter, async, cancellationToken); - - async Task WriteWithLengthLong(object value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken) - { - await buf.Flush(async, cancellationToken); - buf.WriteInt32(ValidateAndGetLength(value, ref lengthCache, parameter)); - await Write(value, buf, lengthCache, parameter, async, cancellationToken); - } - } - - [UnconditionalSuppressMessage("Unmapped enums currently aren't trimming-safe.", "IL2072")] - internal Task Write(object value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter, bool async, CancellationToken cancellationToken = default) - { - var type = value.GetType(); - if (type == typeof(string)) - return base.Write((string)value, buf, lengthCache, parameter, async, cancellationToken); - - var typeRecord = GetTypeRecord(type); - - // TODO: Avoid boxing - if (!typeRecord.EnumToLabel.TryGetValue((Enum)value, out var str)) - throw new InvalidCastException($"Can't write value {value} as enum {type}"); - return base.Write(str, buf, lengthCache, parameter, async, cancellationToken); - } - - #endregion - - #region Misc - - TypeRecord GetTypeRecord(Type type) - { -#if NETSTANDARD2_0 - return _types.GetOrAdd(type, t => CreateTypeRecord(t, _nameTranslator)); -#else - return _types.GetOrAdd(type, static (t, translator) => CreateTypeRecord(t, translator), _nameTranslator); -#endif - } - - static TypeRecord CreateTypeRecord(Type type, INpgsqlNameTranslator nameTranslator) - { - var enumToLabel = new Dictionary(); - var labelToEnum = new Dictionary(); - - foreach (var field in type.GetFields(BindingFlags.Static | BindingFlags.Public)) - { - var attribute = (PgNameAttribute?)field.GetCustomAttributes(typeof(PgNameAttribute), false).FirstOrDefault(); - var enumName = attribute?.PgName ?? 
nameTranslator.TranslateMemberName(field.Name); - var enumValue = (Enum)field.GetValue(null)!; - - enumToLabel[enumValue] = enumName; - labelToEnum[enumName] = enumValue; - } - - return new(enumToLabel, labelToEnum); - } - - #endregion - - record struct TypeRecord(Dictionary EnumToLabel, Dictionary LabelToEnum); -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/UnsupportedHandler.cs b/src/Npgsql/Internal/TypeHandlers/UnsupportedHandler.cs deleted file mode 100644 index 2d1f22f893..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/UnsupportedHandler.cs +++ /dev/null @@ -1,48 +0,0 @@ -using System; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; - -namespace Npgsql.Internal.TypeHandlers; - -sealed class UnsupportedHandler : NpgsqlTypeHandler -{ - readonly string _exceptionMessage; - - public UnsupportedHandler(PostgresType postgresType, string exceptionMessage) : base(postgresType) - => _exceptionMessage = exceptionMessage; - - public override ValueTask ReadAsObject(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - => throw new NotSupportedException(_exceptionMessage); - - public override int ValidateObjectAndGetLength(object value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => throw new NotSupportedException(_exceptionMessage); - - public override Task WriteObjectWithLength(object? value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, - CancellationToken cancellationToken = default) - => throw new NotSupportedException(_exceptionMessage); - - protected internal override ValueTask ReadCustom(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => throw new NotSupportedException(_exceptionMessage); - - protected override Task WriteWithLengthCustom(TAny value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? 
lengthCache, NpgsqlParameter? parameter, bool async, - CancellationToken cancellationToken) - => throw new NotSupportedException(_exceptionMessage); - - protected internal override int ValidateAndGetLengthCustom(TAny value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => throw new NotSupportedException(_exceptionMessage); - - public override Type GetFieldType(FieldDescription? fieldDescription = null) - => throw new NotSupportedException(_exceptionMessage); - - public override NpgsqlTypeHandler CreateArrayHandler(PostgresArrayType pgArrayType, ArrayNullabilityMode arrayNullabilityMode) - => throw new NotSupportedException(_exceptionMessage); - - public override NpgsqlTypeHandler CreateRangeHandler(PostgresType pgRangeType) - => throw new NotSupportedException(_exceptionMessage); - - public override NpgsqlTypeHandler CreateMultirangeHandler(PostgresMultirangeType pgMultirangeType) - => throw new NotSupportedException(_exceptionMessage); -} diff --git a/src/Npgsql/Internal/TypeHandlers/UuidHandler.cs b/src/Npgsql/Internal/TypeHandlers/UuidHandler.cs deleted file mode 100644 index c70da8060d..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/UuidHandler.cs +++ /dev/null @@ -1,76 +0,0 @@ -using System; -using System.Runtime.InteropServices; -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; - -namespace Npgsql.Internal.TypeHandlers; - -/// -/// A type handler for the PostgreSQL uuid data type. -/// -/// -/// See https://www.postgresql.org/docs/current/static/datatype-uuid.html. -/// -/// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it -/// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. -/// Use it at your own risk. -/// -public partial class UuidHandler : NpgsqlSimpleTypeHandler -{ - // The following table shows .NET GUID vs Postgres UUID (RFC 4122) layouts. 
- // - // Note that the first fields are converted from/to native endianness (handled by the Read* - // and Write* methods), while the last field is always read/written in big-endian format. - // - // We're passing BitConverter.IsLittleEndian to prevent reversing endianness on little-endian systems. - // - // | Bits | Bytes | Name | Endianness (GUID) | Endianness (RFC 4122) | - // | ---- | ----- | ----- | ----------------- | --------------------- | - // | 32 | 4 | Data1 | Native | Big | - // | 16 | 2 | Data2 | Native | Big | - // | 16 | 2 | Data3 | Native | Big | - // | 64 | 8 | Data4 | Big | Big | - - public UuidHandler(PostgresType pgType) : base(pgType) {} - - /// - public override Guid Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - { - var raw = new GuidRaw - { - Data1 = buf.ReadInt32(), - Data2 = buf.ReadInt16(), - Data3 = buf.ReadInt16(), - Data4 = buf.ReadInt64(BitConverter.IsLittleEndian) - }; - - return raw.Value; - } - - /// - public override int ValidateAndGetLength(Guid value, NpgsqlParameter? parameter) - => 16; - - /// - public override void Write(Guid value, NpgsqlWriteBuffer buf, NpgsqlParameter? 
parameter) - { - var raw = new GuidRaw(value); - - buf.WriteInt32(raw.Data1); - buf.WriteInt16(raw.Data2); - buf.WriteInt16(raw.Data3); - buf.WriteInt64(raw.Data4, BitConverter.IsLittleEndian); - } - - [StructLayout(LayoutKind.Explicit)] - struct GuidRaw - { - [FieldOffset(00)] public Guid Value; - [FieldOffset(00)] public int Data1; - [FieldOffset(04)] public short Data2; - [FieldOffset(06)] public short Data3; - [FieldOffset(08)] public long Data4; - public GuidRaw(Guid value) : this() => Value = value; - } -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandlers/VoidHandler.cs b/src/Npgsql/Internal/TypeHandlers/VoidHandler.cs deleted file mode 100644 index da24b58c75..0000000000 --- a/src/Npgsql/Internal/TypeHandlers/VoidHandler.cs +++ /dev/null @@ -1,41 +0,0 @@ -using System; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; - -namespace Npgsql.Internal.TypeHandlers; - -/// -/// https://www.postgresql.org/docs/current/static/datatype-boolean.html -/// -sealed class VoidHandler : NpgsqlSimpleTypeHandler -{ - public VoidHandler(PostgresType pgType) : base(pgType) {} - - public override DBNull Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => DBNull.Value; - - public override int ValidateAndGetLength(DBNull value, NpgsqlParameter? parameter) - => throw new NotSupportedException(); - - public override void Write(DBNull value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => throw new NotSupportedException(); - - public override int ValidateObjectAndGetLength(object? value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => value switch - { - DBNull => 0, - null => 0, - _ => throw new InvalidCastException($"Can't write CLR type {value.GetType()} with handler type {nameof(VoidHandler)}") - }; - - public override Task WriteObjectWithLength(object? 
value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => value switch - { - DBNull => WriteWithLength(DBNull.Value, buf, lengthCache, parameter, async, cancellationToken), - null => WriteWithLength(DBNull.Value, buf, lengthCache, parameter, async, cancellationToken), - _ => throw new InvalidCastException($"Can't write CLR type {value.GetType()} with handler type {nameof(VoidHandler)}") - }; -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandling/INpgsqlSimpleTypeHandler.cs b/src/Npgsql/Internal/TypeHandling/INpgsqlSimpleTypeHandler.cs deleted file mode 100644 index 9a8ebf8cfa..0000000000 --- a/src/Npgsql/Internal/TypeHandling/INpgsqlSimpleTypeHandler.cs +++ /dev/null @@ -1,47 +0,0 @@ -using System.Diagnostics.CodeAnalysis; -using Npgsql.BackendMessages; - -namespace Npgsql.Internal.TypeHandling; - -/// -/// Type handlers that wish to support reading other types in additional to the main one can -/// implement this interface for all those types. -/// -public interface INpgsqlSimpleTypeHandler -{ - /// - /// Reads a value of type with the given length from the provided buffer, - /// with the assumption that it is entirely present in the provided memory buffer and no I/O will be - /// required. - /// - /// The buffer from which to read. - /// The byte length of the value. The buffer might not contain the full length, requiring I/O to be performed. - /// Additional PostgreSQL information about the type, such as the length in varchar(30). - /// The fully-read value. - T Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null); - - /// - /// Responsible for validating that a value represents a value of the correct and which can be - /// written for PostgreSQL - if the value cannot be written for any reason, an exception should be thrown. - /// Also returns the byte length needed to write the value. 
- /// - /// The value to be written to PostgreSQL - /// - /// The instance where this value resides. Can be used to access additional - /// information relevant to the write process (e.g. ). - /// - /// The number of bytes required to write the value. - int ValidateAndGetLength([DisallowNull] T value, NpgsqlParameter? parameter); - - /// - /// Writes a value to the provided buffer, with the assumption that there is enough space in the buffer - /// (no I/O will occur). The Npgsql core will have taken care of that. - /// - /// The value to write. - /// The buffer to which to write. - /// - /// The instance where this value resides. Can be used to access additional - /// information relevant to the write process (e.g. ). - /// - void Write([DisallowNull] T value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter); -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandling/INpgsqlTypeHandler.cs b/src/Npgsql/Internal/TypeHandling/INpgsqlTypeHandler.cs deleted file mode 100644 index e1a4dc125b..0000000000 --- a/src/Npgsql/Internal/TypeHandling/INpgsqlTypeHandler.cs +++ /dev/null @@ -1,75 +0,0 @@ -using System; -using System.Diagnostics.CodeAnalysis; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; - -namespace Npgsql.Internal.TypeHandling; - -/// -/// Type handlers that wish to support reading other types in additional to the main one can -/// implement this interface for all those types. -/// -public interface INpgsqlTypeHandler -{ - /// - /// Reads a value of type with the given length from the provided buffer, - /// using either sync or async I/O. - /// - /// The buffer from which to read. - /// The byte length of the value. The buffer might not contain the full length, requiring I/O to be performed. - /// If I/O is required to read the full length of the value, whether it should be performed synchronously or asynchronously. 
- /// Additional PostgreSQL information about the type, such as the length in varchar(30). - /// The fully-read value. - ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null); - - /// - /// Responsible for validating that a value represents a value of the correct and which can be - /// written for PostgreSQL - if the value cannot be written for any reason, an exception should be thrown. - /// Also returns the byte length needed to write the value. - /// - /// The value to be written to PostgreSQL - /// A cache where the length calculated during the validation phase can be stored for use at the writing phase. - /// - /// The instance where this value resides. Can be used to access additional - /// information relevant to the write process (e.g. ). - /// - /// The number of bytes required to write the value. - int ValidateAndGetLength([DisallowNull] T value, [NotNullIfNotNull(nameof(lengthCache))]ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter); - - /// - /// Writes a value to the provided buffer. - /// - /// The value to write. - /// The buffer to which to write. - /// A cache where the length calculated during the validation phase can be stored for use at the writing phase. - /// - /// The instance where this value resides. Can be used to access additional - /// information relevant to the write process (e.g. ). - /// - /// - /// If I/O will be necessary (i.e. the buffer is full), determines whether it will be done synchronously or asynchronously. - /// - /// - /// An optional token to cancel the asynchronous operation. The default value is . - /// - Task Write([DisallowNull] T value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default); -} - -static class INpgsqlTypeHandlerExtensions -{ - public static async Task WriteWithLength(this INpgsqlTypeHandler handler, T? value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? 
lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - if (buf.WriteSpaceLeft < 4) - await buf.Flush(async, cancellationToken); - - if (value is null or DBNull) - { - buf.WriteInt32(-1); - return; - } - - buf.WriteInt32(handler.ValidateAndGetLength(value, ref lengthCache, parameter)); - await handler.Write(value, buf, lengthCache, parameter, async, cancellationToken); - } -} diff --git a/src/Npgsql/Internal/TypeHandling/ITextReaderHandler.cs b/src/Npgsql/Internal/TypeHandling/ITextReaderHandler.cs deleted file mode 100644 index b55000fadf..0000000000 --- a/src/Npgsql/Internal/TypeHandling/ITextReaderHandler.cs +++ /dev/null @@ -1,13 +0,0 @@ -using System.Data.Common; -using System.IO; - -namespace Npgsql.Internal.TypeHandling; - -/// -/// Implemented by handlers which support , returns a standard -/// TextReader given a binary Stream. -/// -interface ITextReaderHandler -{ - TextReader GetTextReader(Stream stream, NpgsqlReadBuffer buffer); -} diff --git a/src/Npgsql/Internal/TypeHandling/NpgsqlLengthCache.cs b/src/Npgsql/Internal/TypeHandling/NpgsqlLengthCache.cs deleted file mode 100644 index b36381e9e8..0000000000 --- a/src/Npgsql/Internal/TypeHandling/NpgsqlLengthCache.cs +++ /dev/null @@ -1,65 +0,0 @@ -using System.Collections.Generic; -using System.Diagnostics; - -namespace Npgsql.Internal.TypeHandling; - -/// -/// An array of cached lengths for the parameters sending process. -/// -/// When sending parameters, lengths need to be calculated more than once (once for Bind, once for -/// an array, once for the string within that array). This cache optimizes that. Lengths are added -/// to the cache, and then retrieved in the same order. 
-/// -public sealed class NpgsqlLengthCache -{ - public bool IsPopulated; - public int Position; - public List Lengths; - - public NpgsqlLengthCache() => Lengths = new List(); - - public NpgsqlLengthCache(int capacity) => Lengths = new List(capacity); - - /// - /// Stores a length value in the cache, to be fetched later via . - /// Called at the phase. - /// - /// The length parameter. - public int Set(int len) - { - Debug.Assert(!IsPopulated); - Lengths.Add(len); - Position++; - return len; - } - - /// - /// Retrieves a length value previously stored in the cache via . - /// Called at the writing phase, after validation has already occurred and the length cache is populated. - /// - /// - public int Get() - { - Debug.Assert(IsPopulated); - return Lengths[Position++]; - } - - internal int GetLast() - { - Debug.Assert(IsPopulated); - return Lengths[Position-1]; - } - - internal void Rewind() - { - Position = 0; - IsPopulated = true; - } - - internal void Clear() - { - Lengths.Clear(); - Position = 0; - IsPopulated = false; - } -} diff --git a/src/Npgsql/Internal/TypeHandling/NpgsqlSimpleTypeHandler.cs b/src/Npgsql/Internal/TypeHandling/NpgsqlSimpleTypeHandler.cs deleted file mode 100644 index 5a9bbde2cf..0000000000 --- a/src/Npgsql/Internal/TypeHandling/NpgsqlSimpleTypeHandler.cs +++ /dev/null @@ -1,84 +0,0 @@ -using System; -using System.Data.Common; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; - -namespace Npgsql.Internal.TypeHandling; - -/// -/// Base class for all simple type handlers, which read and write short, non-arbitrary lengthed -/// values to PostgreSQL. Provides a simpler API to implement when compared to - -/// Npgsql takes care of all I/O before calling into this type, so no I/O needs to be performed by it. -/// -/// -/// The default CLR type that this handler will read and write. For example, calling -/// on a column with this handler will return a value with type . 
-/// Type handlers can support additional types by implementing . -/// -public abstract class NpgsqlSimpleTypeHandler : NpgsqlTypeHandler, INpgsqlSimpleTypeHandler -{ - protected NpgsqlSimpleTypeHandler(PostgresType postgresType) : base(postgresType) {} - - /// - /// Reads a value of type with the given length from the provided buffer, - /// with the assumption that it is entirely present in the provided memory buffer and no I/O will be - /// required. - /// - /// The buffer from which to read. - /// The byte length of the value. The buffer might not contain the full length, requiring I/O to be performed. - /// Additional PostgreSQL information about the type, such as the length in varchar(30). - /// The fully-read value. - public abstract TDefault Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null); - - public sealed override ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - => throw new NotSupportedException(); - - public override async ValueTask ReadAsObject(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - { - await buf.Ensure(len, async); - return Read(buf, len, fieldDescription)!; - } - - #region Write - - /// - /// Responsible for validating that a value represents a value of the correct and which can be - /// written for PostgreSQL - if the value cannot be written for any reason, an exception shold be thrown. - /// Also returns the byte length needed to write the value. - /// - /// The value to be written to PostgreSQL - /// - /// The instance where this value resides. Can be used to access additional - /// information relevant to the write process (e.g. ). - /// - /// The number of bytes required to write the value. - public abstract int ValidateAndGetLength(TDefault value, NpgsqlParameter? parameter); - - /// - /// Writes a value to the provided buffer, with the assumption that there is enough space in the buffer - /// (no I/O will occur). 
The Npgsql core will have taken care of that. - /// - /// The value to write. - /// The buffer to which to write. - /// - /// The instance where this value resides. Can be used to access additional - /// information relevant to the write process (e.g. ). - /// - public abstract void Write(TDefault value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter); - - /// - /// Simple type handlers override instead of this. - /// - public sealed override Task Write(TDefault value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => throw new NotSupportedException(); - - /// - /// Simple type handlers override instead of this. - /// - public sealed override int ValidateAndGetLength(TDefault value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => throw new NotSupportedException(); - - #endregion -} diff --git a/src/Npgsql/Internal/TypeHandling/NpgsqlTypeHandler.cs b/src/Npgsql/Internal/TypeHandling/NpgsqlTypeHandler.cs deleted file mode 100644 index e9cdf8dd4d..0000000000 --- a/src/Npgsql/Internal/TypeHandling/NpgsqlTypeHandler.cs +++ /dev/null @@ -1,273 +0,0 @@ -using System; -using System.Diagnostics; -using System.Diagnostics.CodeAnalysis; -using System.Runtime.CompilerServices; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandlers; -using Npgsql.PostgresTypes; - -namespace Npgsql.Internal.TypeHandling; - -/// -/// Base class for all type handlers, which read and write CLR types into their PostgreSQL -/// binary representation. -/// Type handler writers shouldn't inherit from this class, inherit -/// or instead. -/// -public abstract class NpgsqlTypeHandler -{ - protected NpgsqlTypeHandler(PostgresType postgresType) - => PostgresType = postgresType; - - /// - /// The PostgreSQL type handled by this type handler. 
- /// - public PostgresType PostgresType { get; } - - #region Read - - /// - /// Reads a value of type with the given length from the provided buffer, - /// using either sync or async I/O. - /// - /// The buffer from which to read. - /// The byte length of the value. The buffer might not contain the full length, requiring I/O to be performed. - /// If I/O is required to read the full length of the value, whether it should be performed synchronously or asynchronously. - /// Additional PostgreSQL information about the type, such as the length in varchar(30). - /// The fully-read value. - [MethodImpl(MethodImplOptions.AggressiveInlining)] - protected internal async ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - { - switch (this) - { - case INpgsqlSimpleTypeHandler simpleTypeHandler: - await buf.Ensure(len, async); - return simpleTypeHandler.Read(buf, len, fieldDescription); - case INpgsqlTypeHandler typeHandler: - return await typeHandler.Read(buf, len, async, fieldDescription); - default: - return await ReadCustom(buf, len, async, fieldDescription); - } - } - - /// - /// Version of that's called when we know the entire value - /// is already buffered in memory (i.e. in non-sequential mode). - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public TAny Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - { - Debug.Assert(buf.ReadBytesLeft >= len); - - return this switch - { - INpgsqlSimpleTypeHandler simpleTypeHandler => simpleTypeHandler.Read(buf, len, fieldDescription), - INpgsqlTypeHandler typeHandler => typeHandler.Read(buf, len, async: false, fieldDescription).Result, - _ => ReadCustom(buf, len, async: false, fieldDescription).Result - }; - } - - protected internal virtual ValueTask ReadCustom(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => throw new InvalidCastException(fieldDescription == null - ? 
$"Can't cast database type to {typeof(TAny).Name}" - : $"Can't cast database type {fieldDescription.Handler.PgDisplayName} to {typeof(TAny).Name}"); - - /// - /// Reads a column as the type handler's default read type. If it is not already entirely in - /// memory, sync or async I/O will be performed as specified by . - /// - public abstract ValueTask ReadAsObject(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null); - - /// - /// Version of that's called when we know the entire value - /// is already buffered in memory (i.e. in non-sequential mode). - /// - internal object ReadAsObject(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - { - Debug.Assert(buf.ReadBytesLeft >= len); - - return ReadAsObject(buf, len, async: false, fieldDescription).Result; - } - - /// - /// Reads a value from the buffer, assuming our read position is at the value's preceding length. - /// If the length is -1 (null), this method will return the default value. - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal async ValueTask ReadWithLength(NpgsqlReadBuffer buf, bool async, FieldDescription? fieldDescription = null) - { - await buf.Ensure(4, async); - var len = buf.ReadInt32(); - return len == -1 - ? default! - : NullableHandler.Exists - ? await NullableHandler.ReadAsync(this, buf, len, async, fieldDescription) - : await Read(buf, len, async, fieldDescription); - } - - #endregion - - #region Write - - /// - /// Called to validate and get the length of a value of a generic . - /// and must be handled before calling into this. - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - protected internal int ValidateAndGetLength( - [DisallowNull] TAny value, [NotNullIfNotNull(nameof(lengthCache))] ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter) - { - Debug.Assert(value is not DBNull); - - return this switch - { - INpgsqlSimpleTypeHandler simpleTypeHandler => simpleTypeHandler.ValidateAndGetLength(value, parameter), - INpgsqlTypeHandler typeHandler => typeHandler.ValidateAndGetLength(value, ref lengthCache, parameter), - _ => ValidateAndGetLengthCustom(value, ref lengthCache, parameter) - }; - } - - protected internal virtual int ValidateAndGetLengthCustom( - [DisallowNull] TAny value, [NotNullIfNotNull(nameof(lengthCache))] ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) => - ValidateAndGetLengthCustomCore(parameter, typeof(TAny), PgDisplayName); - - static int ValidateAndGetLengthCustomCore(NpgsqlParameter? parameter, Type type, string displayName) - { - var parameterName = parameter is null - ? null - : parameter.TrimmedName == string.Empty - ? parameter.Collection is { } paramCollection - ? $"${paramCollection.IndexOf(parameter) + 1}" - : null // in case of COPY operations parameter isn't bound to a collection - : parameter.TrimmedName; - - throw new InvalidCastException(parameterName is null - ? $"Cannot write a value of CLR type '{type}' as database type '{displayName}'." - : $"Cannot write a value of CLR type '{type}' as database type '{displayName}' for parameter '{parameterName}'."); - } - - /// - /// Called to write the value of a generic . - /// - /// - /// In the vast majority of cases writing a parameter to the buffer won't need to perform I/O. - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public async Task WriteWithLength(TAny? value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - // TODO: Possibly do a sync path when we don't do I/O (e.g. 
simple type handler, no flush) - if (buf.WriteSpaceLeft < 4) - await buf.Flush(async, cancellationToken); - - if (value is null or DBNull) - { - buf.WriteInt32(-1); - return; - } - - switch (this) - { - case INpgsqlSimpleTypeHandler simpleTypeHandler: - var len = simpleTypeHandler.ValidateAndGetLength(value, parameter); - buf.WriteInt32(len); - if (buf.WriteSpaceLeft < len) - await buf.Flush(async, cancellationToken); - simpleTypeHandler.Write(value, buf, parameter); - return; - case INpgsqlTypeHandler typeHandler: - buf.WriteInt32(typeHandler.ValidateAndGetLength(value, ref lengthCache, parameter)); - await typeHandler.Write(value, buf, lengthCache, parameter, async, cancellationToken); - return; - default: - await WriteWithLengthCustom(value, buf, lengthCache, parameter, async, cancellationToken); - return; - } - } - - /// - /// Typically does not need to be overridden by type handlers, but may be needed in some - /// cases (e.g. . - /// Note that this method assumes it can write 4 bytes of length (already verified by - /// ). - /// - protected virtual Task WriteWithLengthCustom( - [DisallowNull] TAny value, - NpgsqlWriteBuffer buf, - NpgsqlLengthCache? lengthCache, - NpgsqlParameter? parameter, - bool async, - CancellationToken cancellationToken) - => throw new InvalidCastException($"Can't write '{typeof(TAny).Name}' with type handler '{GetType().Name}'"); - - /// - /// Responsible for validating that a value represents a value of the correct and which can be - /// written for PostgreSQL - if the value cannot be written for any reason, an exception shold be thrown. - /// Also returns the byte length needed to write the value. - /// - /// The value to be written to PostgreSQL - /// - /// If the byte length calculation is costly (e.g. for UTF-8 strings), its result can be stored in the - /// length cache to be reused in the writing process, preventing recalculation. - /// - /// - /// The instance where this value resides. 
Can be used to access additional - /// information relevant to the write process (e.g. ). - /// - /// The number of bytes required to write the value. - // Source-generated - public abstract int ValidateObjectAndGetLength(object value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter); - - /// - /// Writes a value to the provided buffer, using either sync or async I/O. - /// - /// The value to write. - /// The buffer to which to write. - /// - /// - /// The instance where this value resides. Can be used to access additional - /// information relevant to the write process (e.g. ). - /// - /// If I/O is required to read the full length of the value, whether it should be performed synchronously or asynchronously. - /// - /// An optional token to cancel the asynchronous operation. The default value is . - /// - // Source-generated - public abstract Task WriteObjectWithLength(object? value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default); - - #endregion Write - - #region Misc - - public abstract Type GetFieldType(FieldDescription? fieldDescription = null); - - internal virtual bool PreferTextWrite => false; - - /// - /// Creates a type handler for arrays of this handler's type. - /// - public abstract NpgsqlTypeHandler CreateArrayHandler(PostgresArrayType pgArrayType, ArrayNullabilityMode arrayNullabilityMode); - - /// - /// Creates a type handler for ranges of this handler's type. - /// - public abstract NpgsqlTypeHandler CreateRangeHandler(PostgresType pgRangeType); - - /// - /// Creates a type handler for multiranges of this handler's type. - /// - public abstract NpgsqlTypeHandler CreateMultirangeHandler(PostgresMultirangeType pgMultirangeType); - - /// - /// Used to create an exception when the provided type can be converted and written, but an - /// instance of is required for caching of the converted value - /// (in . 
- /// - protected Exception CreateConversionButNoParamException(Type clrType) - => new InvalidCastException($"Can't convert .NET type '{clrType}' to PostgreSQL '{PgDisplayName}' within an array"); - - internal string PgDisplayName => PostgresType.DisplayName; - - #endregion Misc -} diff --git a/src/Npgsql/Internal/TypeHandling/NpgsqlTypeHandler`.cs b/src/Npgsql/Internal/TypeHandling/NpgsqlTypeHandler`.cs deleted file mode 100644 index ae1e0eee5c..0000000000 --- a/src/Npgsql/Internal/TypeHandling/NpgsqlTypeHandler`.cs +++ /dev/null @@ -1,78 +0,0 @@ -using System; -using System.Data.Common; -using System.Diagnostics; -using System.Diagnostics.CodeAnalysis; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.Internal.TypeHandlers; -using Npgsql.PostgresTypes; - -namespace Npgsql.Internal.TypeHandling; - -/// -/// Base class for all type handlers, which read and write CLR types into their PostgreSQL -/// binary representation. Unless your type is arbitrary-length, consider inheriting from -/// instead. -/// -/// -/// The default CLR type that this handler will read and write. For example, calling -/// on a column with this handler will return a value with type . -/// Type handlers can support additional types by implementing . -/// -public abstract class NpgsqlTypeHandler : NpgsqlTypeHandler, INpgsqlTypeHandler -{ - protected NpgsqlTypeHandler(PostgresType postgresType) : base(postgresType) {} - - #region Read - - /// - /// Reads a value of type with the given length from the provided buffer, - /// using either sync or async I/O. - /// - /// The buffer from which to read. - /// The byte length of the value. The buffer might not contain the full length, requiring I/O to be performed. - /// If I/O is required to read the full length of the value, whether it should be performed synchronously or asynchronously. - /// Additional PostgreSQL information about the type, such as the length in varchar(30). 
- /// The fully-read value. - public abstract ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null); - - // Since TAny isn't constrained to class? or struct (C# doesn't have a non-nullable constraint that doesn't limit us to either struct or class), - // we must use the bang operator here to tell the compiler that a null value will never returned. - public override async ValueTask ReadAsObject(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - => (await Read(buf, len, async, fieldDescription))!; - - #endregion Read - - #region Write - - /// - /// Called to validate and get the length of a value of a generic . - /// - public abstract int ValidateAndGetLength(TDefault value, [NotNullIfNotNull(nameof(lengthCache))] ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter); - - /// - /// Called to write the value of a generic . - /// - public abstract Task Write(TDefault value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default); - - #endregion Write - - #region Misc - - public override Type GetFieldType(FieldDescription? 
fieldDescription = null) => typeof(TDefault); - - /// - public override NpgsqlTypeHandler CreateArrayHandler(PostgresArrayType pgArrayType, ArrayNullabilityMode arrayNullabilityMode) - => new ArrayHandler(pgArrayType, this, arrayNullabilityMode); - - /// - public override NpgsqlTypeHandler CreateRangeHandler(PostgresType pgRangeType) - => new RangeHandler(pgRangeType, this); - - /// - public override NpgsqlTypeHandler CreateMultirangeHandler(PostgresMultirangeType pgMultirangeType) - => new MultirangeHandler(pgMultirangeType, (RangeHandler)CreateRangeHandler(pgMultirangeType.Subrange)); - - #endregion Misc -} diff --git a/src/Npgsql/Internal/TypeHandling/NullableHandler.cs b/src/Npgsql/Internal/TypeHandling/NullableHandler.cs deleted file mode 100644 index 89fd5a0cb4..0000000000 --- a/src/Npgsql/Internal/TypeHandling/NullableHandler.cs +++ /dev/null @@ -1,54 +0,0 @@ -using System; -using System.Diagnostics; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; - -namespace Npgsql.Internal.TypeHandling; - -abstract class NullableHandler -{ - static NullableHandler? _derivedInstance; - public static bool Exists => default(T) is null && typeof(T).IsValueType; - - static NullableHandler DerivedInstance - { - get - { - Debug.Assert(Exists); - return _derivedInstance ??= (NullableHandler?)Activator.CreateInstance(typeof(NullableHandler<,>).MakeGenericType(typeof(T), typeof(T).GenericTypeArguments[0]))!; - } - } - - public static T Read(NpgsqlTypeHandler handler, NpgsqlReadBuffer buffer, int columnLength, FieldDescription? fieldDescription = null) => - DerivedInstance.ReadImpl(handler, buffer, columnLength, fieldDescription); - public static ValueTask ReadAsync(NpgsqlTypeHandler handler, NpgsqlReadBuffer buffer, int columnLength, bool async, FieldDescription? 
fieldDescription = null) => - DerivedInstance.ReadAsyncImpl(handler, buffer, columnLength, async, fieldDescription); - public static int ValidateAndGetLength(NpgsqlTypeHandler handler, T value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) => - DerivedInstance.ValidateAndGetLengthImpl(handler, value, ref lengthCache, parameter); - public static Task WriteAsync(NpgsqlTypeHandler handler, T value, NpgsqlWriteBuffer buffer, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) => - DerivedInstance.WriteAsyncImpl(handler, value, buffer, lengthCache, parameter, async, cancellationToken); - - protected abstract T ReadImpl(NpgsqlTypeHandler handler, NpgsqlReadBuffer buffer, int columnLength, FieldDescription? fieldDescription = null); - protected abstract ValueTask ReadAsyncImpl(NpgsqlTypeHandler handler, NpgsqlReadBuffer buffer, int columnLen, bool async, FieldDescription? fieldDescription = null); - protected abstract int ValidateAndGetLengthImpl(NpgsqlTypeHandler handler, T value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter); - protected abstract Task WriteAsyncImpl(NpgsqlTypeHandler handler, T value, NpgsqlWriteBuffer buffer, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default); -} - -class NullableHandler : NullableHandler - where TUnderlying : struct -{ - protected override T ReadImpl(NpgsqlTypeHandler handler, NpgsqlReadBuffer buffer, int columnLength, FieldDescription? fieldDescription = null) - => (T)(object)handler.Read(buffer, columnLength, fieldDescription); - - protected override async ValueTask ReadAsyncImpl(NpgsqlTypeHandler handler, NpgsqlReadBuffer buffer, int columnLength, bool async, FieldDescription? 
fieldDescription = null) - => (T)(object)await handler.Read(buffer, columnLength, async, fieldDescription); - - protected override int ValidateAndGetLengthImpl(NpgsqlTypeHandler handler, T value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) => - value != null ? handler.ValidateAndGetLength(((TUnderlying?)(object)value).Value, ref lengthCache, parameter) : 0; - - protected override Task WriteAsyncImpl(NpgsqlTypeHandler handler, T value, NpgsqlWriteBuffer buffer, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => value != null - ? handler.WriteWithLength(((TUnderlying?)(object)value).Value, buffer, lengthCache, parameter, async, cancellationToken) - : handler.WriteWithLength(DBNull.Value, buffer, lengthCache, parameter, async, cancellationToken); -} diff --git a/src/Npgsql/Internal/TypeHandling/TypeHandlerResolver.cs b/src/Npgsql/Internal/TypeHandling/TypeHandlerResolver.cs deleted file mode 100644 index feb6a719e7..0000000000 --- a/src/Npgsql/Internal/TypeHandling/TypeHandlerResolver.cs +++ /dev/null @@ -1,37 +0,0 @@ -using System; -using NpgsqlTypes; -using Npgsql.PostgresTypes; - -namespace Npgsql.Internal.TypeHandling; - -/// -/// An Npgsql resolver for type handlers. Typically used by plugins to alter how Npgsql reads and writes values to PostgreSQL. -/// -public abstract class TypeHandlerResolver -{ - /// - /// Resolves a type handler given a PostgreSQL type name, corresponding to the typname column in the PostgreSQL pg_type catalog table. - /// - /// See . - public abstract NpgsqlTypeHandler? ResolveByDataTypeName(string typeName); - - /// - /// Resolves a type handler for a given NpgsqlDbType. - /// - public virtual NpgsqlTypeHandler? ResolveByNpgsqlDbType(NpgsqlDbType npgsqlDbType) => null; - - /// - /// Resolves a type handler given a .NET CLR type. - /// - public abstract NpgsqlTypeHandler? 
ResolveByClrType(Type type); - - /// - /// Resolves a type handler given a PostgreSQL type. - /// - public virtual NpgsqlTypeHandler? ResolveByPostgresType(PostgresType type) - => ResolveByDataTypeName(type.Name); - - public virtual NpgsqlTypeHandler? ResolveValueDependentValue(object value) => null; - - public virtual NpgsqlTypeHandler? ResolveValueTypeGenerically(T value) => null; -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandling/TypeHandlerResolverFactory.cs b/src/Npgsql/Internal/TypeHandling/TypeHandlerResolverFactory.cs deleted file mode 100644 index c1d5030b75..0000000000 --- a/src/Npgsql/Internal/TypeHandling/TypeHandlerResolverFactory.cs +++ /dev/null @@ -1,12 +0,0 @@ -using Npgsql.Internal.TypeMapping; - -namespace Npgsql.Internal.TypeHandling; - -public abstract class TypeHandlerResolverFactory -{ - public abstract TypeHandlerResolver Create(TypeMapper typeMapper, NpgsqlConnector connector); - - public virtual TypeMappingResolver? CreateMappingResolver() => null; - - public virtual TypeMappingResolver? CreateGlobalMappingResolver() => null; -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeHandling/TypeMappingInfo.cs b/src/Npgsql/Internal/TypeHandling/TypeMappingInfo.cs deleted file mode 100644 index d669739e6f..0000000000 --- a/src/Npgsql/Internal/TypeHandling/TypeMappingInfo.cs +++ /dev/null @@ -1,22 +0,0 @@ -using System; -using System.Data; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -namespace Npgsql.Internal.TypeHandling; - -public sealed class TypeMappingInfo -{ - public TypeMappingInfo(NpgsqlDbType? npgsqlDbType, string? dataTypeName, Type clrType) - => (NpgsqlDbType, DataTypeName, ClrTypes) = (npgsqlDbType, dataTypeName, new[] { clrType }); - - public TypeMappingInfo(NpgsqlDbType? npgsqlDbType, string? dataTypeName, params Type[] clrTypes) - => (NpgsqlDbType, DataTypeName, ClrTypes) = (npgsqlDbType, dataTypeName, clrTypes); - - public NpgsqlDbType? 
NpgsqlDbType { get; } - // Note that we can't cache the result due to nullable's assignment not being thread safe - public DbType DbType - => NpgsqlDbType is null ? DbType.Object : GlobalTypeMapper.NpgsqlDbTypeToDbType(NpgsqlDbType.Value); - public string? DataTypeName { get; } - public Type[] ClrTypes { get; } -} diff --git a/src/Npgsql/Internal/TypeInfoCache.cs b/src/Npgsql/Internal/TypeInfoCache.cs new file mode 100644 index 0000000000..2fce6eb585 --- /dev/null +++ b/src/Npgsql/Internal/TypeInfoCache.cs @@ -0,0 +1,169 @@ +using System; +using System.Collections.Concurrent; +using System.Runtime.CompilerServices; +using Npgsql.Internal.Postgres; + +namespace Npgsql.Internal; + +sealed class TypeInfoCache where TPgTypeId : struct +{ + readonly PgSerializerOptions _options; + readonly bool _validatePgTypeIds; + + // Mostly used for parameter writing, 8ns + readonly ConcurrentDictionary _cacheByClrType = new(); + + // Used for reading, occasionally for parameter writing where a db type was given. + // 8ns, about 10ns total to scan an array with 6, 7 different clr types under one pg type + readonly ConcurrentDictionary _cacheByPgTypeId = new(); + + static TypeInfoCache() + { + if (typeof(TPgTypeId) != typeof(Oid) && typeof(TPgTypeId) != typeof(DataTypeName)) + throw new InvalidOperationException("Cannot use this type argument."); + } + + public TypeInfoCache(PgSerializerOptions options, bool validatePgTypeIds = true) + { + _options = options; + _validatePgTypeIds = validatePgTypeIds; + } + + /// + /// + /// + /// + /// + /// + /// When this flag is true, and both type and pgTypeId are non null, a default info for the pgTypeId can be returned if an exact match + /// can't be found. + /// + /// + /// + public PgTypeInfo? GetOrAddInfo(Type? type, TPgTypeId? 
pgTypeId, bool defaultTypeFallback = false) + { + if (pgTypeId is { } id) + { + if (_cacheByPgTypeId.TryGetValue(id, out var infos)) + if (FindMatch(type, infos, defaultTypeFallback) is { } info) + return info; + + return AddEntryById(id, infos, defaultTypeFallback); + } + + if (type is not null) + return _cacheByClrType.TryGetValue(type, out var info) ? info : AddByType(type); + + return null; + + PgTypeInfo? FindMatch(Type? type, (Type? Type, PgTypeInfo? Info)[] infos, bool defaultTypeFallback) + { + PgTypeInfo? defaultInfo = null; + var negativeExactMatch = false; + for (var i = 0; i < infos.Length; i++) + { + ref var item = ref infos[i]; + if (item.Type == type) + { + if (item.Info is not null || !defaultTypeFallback) + return item.Info; + negativeExactMatch = true; + } + + if (defaultTypeFallback && item.Type is null) + defaultInfo = item.Info; + } + + // We can only return default info if we've seen a negative match (type: typeof(object), info: null) + // Otherwise we might return a previously requested default while the resolvers could produce the exact match. + return negativeExactMatch ? defaultInfo : null; + } + + PgTypeInfo? AddByType(Type type) + { + // We don't pass PgTypeId as we're interested in default converters here. + var info = CreateInfo(type, null, _options, defaultTypeFallback: false, _validatePgTypeIds); + + return info is null + ? null + : _cacheByClrType.TryAdd(type, info) // We never remove entries so either of these branches will always succeed. + ? info + : _cacheByClrType[type]; + } + + PgTypeInfo? AddEntryById(TPgTypeId pgTypeId, (Type? Type, PgTypeInfo? Info)[]? infos, bool defaultTypeFallback) + { + // We cache negatives (null info) to allow 'object or default' checks to never hit the resolvers after the first lookup. 
+ var info = CreateInfo(type, pgTypeId, _options, defaultTypeFallback, _validatePgTypeIds); + + var isDefaultInfo = type is null && info is not null; + if (infos is null) + { + // Also add defaults by their info type to save a future resolver lookup + resize. + infos = isDefaultInfo + ? new [] { (type, info), (info!.Type, info) } + : new [] { (type, info) }; + + if (_cacheByPgTypeId.TryAdd(pgTypeId, infos)) + return info; + } + + // We have to update it instead. + while (true) + { + infos = _cacheByPgTypeId[pgTypeId]; + if (FindMatch(type, infos, defaultTypeFallback) is { } racedInfo) + return racedInfo; + + // Also add defaults by their info type to save a future resolver lookup + resize. + var oldInfos = infos; + var hasExactType = false; + if (isDefaultInfo) + { + foreach (var oldInfo in oldInfos) + if (oldInfo.Type == info!.Type) + hasExactType = true; + } + Array.Resize(ref infos, oldInfos.Length + (isDefaultInfo && !hasExactType ? 2 : 1)); + infos[oldInfos.Length] = (type, info); + if (isDefaultInfo && !hasExactType) + infos[oldInfos.Length + 1] = (info!.Type, info); + + if (_cacheByPgTypeId.TryUpdate(pgTypeId, infos, oldInfos)) + return info; + } + } + + static PgTypeInfo? CreateInfo(Type? type, TPgTypeId? typeId, PgSerializerOptions options, bool defaultTypeFallback, bool validatePgTypeIds) + { + var pgTypeId = AsPgTypeId(typeId); + // Validate that we only pass data types that are supported by the backend. + var dataTypeName = pgTypeId is { } id ? 
(DataTypeName?)options.DatabaseInfo.GetDataTypeName(id, validate: validatePgTypeIds) : null; + var info = options.TypeInfoResolver.GetTypeInfo(type, dataTypeName, options); + if (info is null && defaultTypeFallback) + { + type = null; + info = options.TypeInfoResolver.GetTypeInfo(type, dataTypeName, options); + } + + if (info is null) + return null; + + if (pgTypeId is not null && info.PgTypeId != pgTypeId) + throw new InvalidOperationException("A Postgres type was passed but the resolved PgTypeInfo does not have an equal PgTypeId."); + + if (type is not null && !info.IsBoxing && info.Type != type) + throw new InvalidOperationException($"A CLR type '{type}' was passed but the resolved PgTypeInfo does not have an equal Type: {info.Type}."); + + return info; + } + + static PgTypeId? AsPgTypeId(TPgTypeId? pgTypeId) + => pgTypeId switch + { + { } id when typeof(TPgTypeId) == typeof(DataTypeName) => new PgTypeId(Unsafe.As(ref id)), + { } id => new PgTypeId(Unsafe.As(ref id)), + null => null + }; + } +} diff --git a/src/Npgsql/Internal/TypeInfoMapping.cs b/src/Npgsql/Internal/TypeInfoMapping.cs new file mode 100644 index 0000000000..b2dfd58377 --- /dev/null +++ b/src/Npgsql/Internal/TypeInfoMapping.cs @@ -0,0 +1,668 @@ +using System; +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; +using System.Text; +using Npgsql.Internal.Converters; +using Npgsql.Internal.Postgres; +using Npgsql.PostgresTypes; +using NpgsqlTypes; + +namespace Npgsql.Internal; + +/// +/// +/// +/// +/// +/// +/// Signals whether a resolver based TypeInfo can keep its PgTypeId undecided or whether it should follow mapping.DataTypeName. +/// +public delegate PgTypeInfo TypeInfoFactory(PgSerializerOptions options, TypeInfoMapping mapping, bool resolvedDataTypeName); + +public enum MatchRequirement +{ + /// Match when the clr type and datatype name both match. 
+ /// It's also the only requirement that participates in clr type fallback matching. + All, + /// Match when the datatype name or CLR type matches while the other also matches or is absent. + Single, + /// Match when the datatype name matches and the clr type also matches or is absent. + DataTypeName +} + +/// A factory for well-known PgConverters. +public static class PgConverterFactory +{ + public static PgConverter CreateArrayMultirangeConverter(PgConverter rangeConverter, PgSerializerOptions options) where T : notnull + => new MultirangeConverter(rangeConverter); + + public static PgConverter> CreateListMultirangeConverter(PgConverter rangeConverter, PgSerializerOptions options) where T : notnull + => new MultirangeConverter, T>(rangeConverter); + + public static PgConverter> CreateRangeConverter(PgConverter subTypeConverter, PgSerializerOptions options) + => new RangeConverter(subTypeConverter); + + public static PgConverter CreatePolymorphicArrayConverter(Func> arrayConverterFactory, Func> nullableArrayConverterFactory, PgSerializerOptions options) + => options.ArrayNullabilityMode switch + { + ArrayNullabilityMode.Never => arrayConverterFactory(), + ArrayNullabilityMode.Always => nullableArrayConverterFactory(), + ArrayNullabilityMode.PerInstance => new PolymorphicArrayConverter(arrayConverterFactory(), nullableArrayConverterFactory()), + _ => throw new ArgumentOutOfRangeException() + }; +} + +[DebuggerDisplay("{DebuggerDisplay,nq}")] +public readonly struct TypeInfoMapping +{ + public TypeInfoMapping(Type type, string dataTypeName, TypeInfoFactory factory) + { + Type = type; + // For objects it makes no sense to have clr type only matches by default, there are too many implementations. + MatchRequirement = type == typeof(object) ? 
MatchRequirement.DataTypeName : MatchRequirement.All; + DataTypeName = Postgres.DataTypeName.NormalizeName(dataTypeName); + Factory = factory; + } + + public TypeInfoFactory Factory { get; init; } + public Type Type { get; init; } + public string DataTypeName { get; init; } + + public MatchRequirement MatchRequirement { get; init; } + public Func? TypeMatchPredicate { get; init; } + + public bool TypeEquals(Type type) => TypeMatchPredicate?.Invoke(type) ?? Type == type; + public bool DataTypeNameEquals(string dataTypeName) + { + var span = DataTypeName.AsSpan(); + return Postgres.DataTypeName.IsFullyQualified(span) + ? span.Equals(dataTypeName.AsSpan(), StringComparison.Ordinal) + : span.Equals(Postgres.DataTypeName.ValidatedName(dataTypeName).UnqualifiedNameSpan, StringComparison.Ordinal); + } + + string DebuggerDisplay + { + get + { + var builder = new StringBuilder() + .Append(Type.Name) + .Append(" <-> ") + .Append(Postgres.DataTypeName.FromDisplayName(DataTypeName).DisplayName); + + if (MatchRequirement is not MatchRequirement.All) + builder.Append($" ({MatchRequirement.ToString().ToLowerInvariant()})"); + + return builder.ToString(); + } + } +} + +public sealed class TypeInfoMappingCollection +{ + readonly TypeInfoMappingCollection? _baseCollection; + readonly List _items; + + public TypeInfoMappingCollection(int capacity = 0) + => _items = new(capacity); + + public TypeInfoMappingCollection() : this(0) { } + + // Not used for resolving, only for composing (arrays that need to find the element mapping etc). + public TypeInfoMappingCollection(TypeInfoMappingCollection baseCollection) : this(0) + => _baseCollection = baseCollection; + + public TypeInfoMappingCollection(IEnumerable items) + => _items = new(items); + + public IReadOnlyList Items => _items; + + /// Returns the first default converter or the first converter that matches both type and dataTypeName. 
+ /// If just a type was passed and no default was found we return the first converter with a type match. + public PgTypeInfo? Find(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + { + TypeInfoMapping? fallback = null; + foreach (var mapping in _items) + { + var looseTypeMatch = mapping.TypeMatchPredicate is { } pred ? pred(type) : type is null || mapping.Type == type; + var typeMatch = type is not null && looseTypeMatch; + var dataTypeMatch = dataTypeName is not null && mapping.DataTypeNameEquals(dataTypeName.Value.Value); + + switch (mapping.MatchRequirement) + { + case var _ when dataTypeMatch && typeMatch: + case not MatchRequirement.All when dataTypeMatch && looseTypeMatch: + case MatchRequirement.Single when dataTypeName is null && looseTypeMatch: + var resolvedMapping = mapping with + { + Type = type ?? mapping.Type, + // Make sure plugins (which match on unqualified names) and resolvers get the fully qualified name to canonicalize. + DataTypeName = dataTypeName is not null ? dataTypeName.GetValueOrDefault().Value : mapping.DataTypeName + }; + return resolvedMapping.Factory(options, resolvedMapping, dataTypeName is not null); + // DataTypeName is explicitly requiring dataTypeName so it won't be used for a fallback, Single would have matched above already. + case MatchRequirement.All when fallback is null && dataTypeName is null && typeMatch: + fallback = mapping.TypeMatchPredicate is not null ? mapping with { Type = type! } : mapping; + break; + default: + continue; + } + } + + return fallback?.Factory(options, fallback.Value, dataTypeName is not null); + } + + bool TryFindMapping(Type type, string dataTypeName, out TypeInfoMapping value) + { + foreach (var mapping in _baseCollection?._items ?? _items) + { + // During mapping we just use look for the declared type, regardless of TypeMatchPredicate. 
+ if (mapping.Type == type && mapping.DataTypeNameEquals(dataTypeName)) + { + value = mapping; + return true; + } + } + + value = default; + return false; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + TypeInfoMapping FindMapping(Type type, string dataTypeName) + => TryFindMapping(type, dataTypeName, out var info) ? info : throw new InvalidOperationException($"Could not find mapping for {type} <-> {dataTypeName}"); + + // Helper to eliminate generic display class duplication. + static TypeInfoFactory CreateComposedFactory(Type mappingType, TypeInfoMapping innerMapping, Func mapper, bool copyPreferredFormat = false, bool supportsWriting = true) + => (options, mapping, dataTypeNameMatch) => + { + var innerInfo = innerMapping.Factory(options, innerMapping, dataTypeNameMatch); + var converter = mapper(mapping, innerInfo); + var preferredFormat = copyPreferredFormat ? innerInfo.PreferredFormat : null; + var writingSupported = supportsWriting && innerInfo.SupportsWriting && mapping.Type != typeof(object); + var unboxedType = ComputeUnboxedType(defaultType: mappingType, converter.TypeToConvert, mapping.Type); + + return new PgTypeInfo(options, converter, TypeInfoMappingHelpers.ResolveFullyQualifiedName(options, mapping.DataTypeName), unboxedType) + { + PreferredFormat = preferredFormat, + SupportsWriting = writingSupported + }; + }; + + // Helper to eliminate generic display class duplication. + static TypeInfoFactory CreateComposedFactory(Type mappingType, TypeInfoMapping innerMapping, Func mapper, bool copyPreferredFormat = false, bool supportsWriting = true) + => (options, mapping, dataTypeNameMatch) => + { + var innerInfo = (PgResolverTypeInfo)innerMapping.Factory(options, innerMapping, dataTypeNameMatch); + var resolver = mapper(mapping, innerInfo); + var preferredFormat = copyPreferredFormat ? 
innerInfo.PreferredFormat : null; + var writingSupported = supportsWriting && innerInfo.SupportsWriting && mapping.Type != typeof(object); + var unboxedType = ComputeUnboxedType(defaultType: mappingType, resolver.TypeToConvert, mapping.Type); + // We include the data type name if the inner info did so as well. + // This way we can rely on its logic around resolvedDataTypeName, including when it ignores that flag. + PgTypeId? pgTypeId = innerInfo.PgTypeId is not null + ? TypeInfoMappingHelpers.ResolveFullyQualifiedName(options, mapping.DataTypeName) + : null; + return new PgResolverTypeInfo(options, resolver, pgTypeId, unboxedType) + { + PreferredFormat = preferredFormat, + SupportsWriting = writingSupported + }; + }; + + static Type? ComputeUnboxedType(Type defaultType, Type converterType, Type matchedType) + { + // The minimal hierarchy that should hold for things to work is object < converterType < matchedType. + // Though these types could often be seen in a hierarchy: object < converterType < defaultType < matchedType. + // Some caveats with the latter being for instance Array being the matchedType while the defaultType is int[]. + Debug.Assert(converterType.IsAssignableFrom(matchedType) || matchedType == typeof(object)); + Debug.Assert(converterType.IsAssignableFrom(defaultType)); + + // A special case for object matches, where we return a more specific type than was matched. + // This is to report e.g. Array converters as Array when their matched type was object. + if (matchedType == typeof(object)) + return converterType; + + // This is to report e.g. Array converters as int[,,,] when their matched type was such. + if (matchedType != defaultType) + return matchedType; + + // If defaultType does not equal converterType we take defaultType as it's more specific. + // This is to report e.g. Array converters as int[] when their matched type was their default type. + if (defaultType != converterType) + return defaultType; + + // Keep the converter type. 
+ return null; + } + + public void Add(TypeInfoMapping mapping) => _items.Add(mapping); + + public void AddRange(TypeInfoMappingCollection collection) => _items.AddRange(collection._items); + + Func GetDefaultConfigure(bool isDefault) + => GetDefaultConfigure(isDefault ? MatchRequirement.Single : MatchRequirement.All); + Func GetDefaultConfigure(MatchRequirement matchRequirement) + => matchRequirement switch + { + MatchRequirement.All => static mapping => mapping with { MatchRequirement = MatchRequirement.All }, + MatchRequirement.DataTypeName => static mapping => mapping with { MatchRequirement = MatchRequirement.DataTypeName }, + MatchRequirement.Single => static mapping => mapping with { MatchRequirement = MatchRequirement.Single }, + _ => throw new ArgumentOutOfRangeException(nameof(matchRequirement), matchRequirement, null) + }; + + Func GetArrayTypeMatchPredicate(Func elementTypeMatchPredicate) + => type => type is null || (type.IsArray && elementTypeMatchPredicate.Invoke(type.GetElementType()!)); + Func GetListTypeMatchPredicate(Func elementTypeMatchPredicate) + => type => type is null || (type.IsConstructedGenericType && type.GetGenericTypeDefinition() == typeof(List<>) + && elementTypeMatchPredicate(type.GetGenericArguments()[0])); + + public void AddType(string dataTypeName, TypeInfoFactory createInfo, bool isDefault = false) where T : class + => AddType(dataTypeName, createInfo, GetDefaultConfigure(isDefault)); + + public void AddType(string dataTypeName, TypeInfoFactory createInfo, MatchRequirement matchRequirement) where T : class + => AddType(dataTypeName, createInfo, GetDefaultConfigure(matchRequirement)); + + public void AddType(string dataTypeName, TypeInfoFactory createInfo, Func? configure) where T : class + { + var mapping = new TypeInfoMapping(typeof(T), dataTypeName, createInfo); + _items.Add(configure?.Invoke(mapping) ?? mapping); + } + + // Aliased to AddType at this time. 
+ public void AddResolverType(string dataTypeName, TypeInfoFactory createInfo, bool isDefault = false) where T : class + => AddType(dataTypeName, createInfo, GetDefaultConfigure(isDefault)); + + // Aliased to AddType at this time. + public void AddResolverType(string dataTypeName, TypeInfoFactory createInfo, MatchRequirement matchRequirement) where T : class + => AddType(dataTypeName, createInfo, GetDefaultConfigure(matchRequirement)); + + // Aliased to AddType at this time. + public void AddResolverType(string dataTypeName, TypeInfoFactory createInfo, Func? configure) where T : class + => AddType(dataTypeName, createInfo, configure); + + public void AddArrayType(string elementDataTypeName) where TElement : class + => AddArrayType(FindMapping(typeof(TElement), elementDataTypeName)); + + public void AddArrayType(TypeInfoMapping elementMapping) where TElement : class + { + // Always use a predicate to match all dimensions. + var arrayTypeMatchPredicate = GetArrayTypeMatchPredicate(elementMapping.TypeMatchPredicate ?? (static type => type == typeof(TElement))); + var listTypeMatchPredicate = elementMapping.TypeMatchPredicate is not null ? GetListTypeMatchPredicate(elementMapping.TypeMatchPredicate) : null; + + AddArrayType(elementMapping, typeof(TElement[]), CreateArrayBasedConverter, arrayTypeMatchPredicate, suppressObjectMapping: TryFindMapping(typeof(object), elementMapping.DataTypeName, out _)); + AddArrayType(elementMapping, typeof(List), CreateListBasedConverter, listTypeMatchPredicate, suppressObjectMapping: true); + + void AddArrayType(TypeInfoMapping elementMapping, Type type, Func converter, Func? 
typeMatchPredicate = null, bool suppressObjectMapping = false) + { + var arrayDataTypeName = GetArrayDataTypeName(elementMapping.DataTypeName); + var arrayMapping = new TypeInfoMapping(type, arrayDataTypeName, CreateComposedFactory(type, elementMapping, converter)) + { + MatchRequirement = elementMapping.MatchRequirement, + TypeMatchPredicate = typeMatchPredicate + }; + _items.Add(arrayMapping); + suppressObjectMapping = suppressObjectMapping || arrayMapping.TypeEquals(typeof(object)); + if (!suppressObjectMapping && arrayMapping.MatchRequirement is MatchRequirement.DataTypeName or MatchRequirement.Single) + _items.Add(new TypeInfoMapping(typeof(object), arrayDataTypeName, (options, mapping, dataTypeNameMatch) => + { + if (!dataTypeNameMatch) + throw new InvalidOperationException("Should not happen, please file a bug."); + + return arrayMapping.Factory(options, mapping, dataTypeNameMatch); + })); + } + } + + public void AddResolverArrayType(string elementDataTypeName) where TElement : class + => AddResolverArrayType(FindMapping(typeof(TElement), elementDataTypeName)); + + public void AddResolverArrayType(TypeInfoMapping elementMapping) where TElement : class + { + // Always use a predicate to match all dimensions. + var arrayTypeMatchPredicate = GetArrayTypeMatchPredicate(elementMapping.TypeMatchPredicate ?? (static type => type == typeof(TElement))); + var listTypeMatchPredicate = elementMapping.TypeMatchPredicate is not null ? GetListTypeMatchPredicate(elementMapping.TypeMatchPredicate) : null; + + AddResolverArrayType(elementMapping, typeof(TElement[]), CreateArrayBasedConverterResolver, arrayTypeMatchPredicate, suppressObjectMapping: TryFindMapping(typeof(object), elementMapping.DataTypeName, out _)); + AddResolverArrayType(elementMapping, typeof(List), CreateListBasedConverterResolver, listTypeMatchPredicate, suppressObjectMapping: true); + + void AddResolverArrayType(TypeInfoMapping elementMapping, Type type, Func converter, Func? 
typeMatchPredicate = null, bool suppressObjectMapping = false) + { + var arrayDataTypeName = GetArrayDataTypeName(elementMapping.DataTypeName); + var arrayMapping = new TypeInfoMapping(type, arrayDataTypeName, CreateComposedFactory(type, elementMapping, converter)) + { + MatchRequirement = elementMapping.MatchRequirement, + TypeMatchPredicate = typeMatchPredicate + }; + _items.Add(arrayMapping); + suppressObjectMapping = suppressObjectMapping || arrayMapping.TypeEquals(typeof(object)); + if (!suppressObjectMapping && arrayMapping.MatchRequirement is MatchRequirement.DataTypeName or MatchRequirement.Single) + _items.Add(new TypeInfoMapping(typeof(object), arrayDataTypeName, (options, mapping, dataTypeNameMatch) => + { + if (!dataTypeNameMatch) + throw new InvalidOperationException("Should not happen, please file a bug."); + + return arrayMapping.Factory(options, mapping, dataTypeNameMatch); + })); + } + } + + public void AddStructType(string dataTypeName, TypeInfoFactory createInfo, bool isDefault = false) where T : struct + => AddStructType(typeof(T), typeof(T?), dataTypeName, createInfo, + static (_, innerInfo) => new NullableConverter(innerInfo.GetConcreteResolution().GetConverter()), GetDefaultConfigure(isDefault)); + + public void AddStructType(string dataTypeName, TypeInfoFactory createInfo, MatchRequirement matchRequirement) where T : struct + => AddStructType(typeof(T), typeof(T?), dataTypeName, createInfo, + static (_, innerInfo) => new NullableConverter(innerInfo.GetConcreteResolution().GetConverter()), GetDefaultConfigure(matchRequirement)); + + public void AddStructType(string dataTypeName, TypeInfoFactory createInfo, Func? configure) where T : struct + => AddStructType(typeof(T), typeof(T?), dataTypeName, createInfo, + static (_, innerInfo) => new NullableConverter(innerInfo.GetConcreteResolution().GetConverter()), configure); + + // Lives outside to prevent capture of T. 
+ void AddStructType(Type type, Type nullableType, string dataTypeName, TypeInfoFactory createInfo, + Func nullableConverter, Func? configure) + { + var mapping = new TypeInfoMapping(type, dataTypeName, createInfo); + mapping = configure?.Invoke(mapping) ?? mapping; + _items.Add(mapping); + _items.Add(new TypeInfoMapping(nullableType, dataTypeName, + CreateComposedFactory(nullableType, mapping, nullableConverter, copyPreferredFormat: true)) + { + MatchRequirement = mapping.MatchRequirement, + TypeMatchPredicate = mapping.TypeMatchPredicate is not null + ? type => type is null + ? mapping.TypeMatchPredicate(null) + : Nullable.GetUnderlyingType(type) is { } underlying && mapping.TypeMatchPredicate(underlying) + : null + }); + } + + public void AddStructArrayType(string elementDataTypeName) where TElement : struct + => AddStructArrayType(FindMapping(typeof(TElement), elementDataTypeName), FindMapping(typeof(TElement?), elementDataTypeName), null); + + public void AddStructArrayType(string elementDataTypeName, Func configure) where TElement : struct + => AddStructArrayType(FindMapping(typeof(TElement), elementDataTypeName), FindMapping(typeof(TElement?), elementDataTypeName), configure); + + public void AddStructArrayType(TypeInfoMapping elementMapping, TypeInfoMapping nullableElementMapping, + Func? configure) where TElement : struct + { + // Always use a predicate to match all dimensions. + var arrayTypeMatchPredicate = GetArrayTypeMatchPredicate(elementMapping.TypeMatchPredicate ?? (static type => type is null || type == typeof(TElement))); + var nullableArrayTypeMatchPredicate = GetArrayTypeMatchPredicate(nullableElementMapping.TypeMatchPredicate ?? (static type => + type is null || (Nullable.GetUnderlyingType(type) is { } underlying && underlying == typeof(TElement)))); + var listTypeMatchPredicate = elementMapping.TypeMatchPredicate is not null ? 
GetListTypeMatchPredicate(elementMapping.TypeMatchPredicate) : null; + var nullableListTypeMatchPredicate = nullableElementMapping.TypeMatchPredicate is not null ? GetListTypeMatchPredicate(nullableElementMapping.TypeMatchPredicate) : null; + + AddStructArrayType(elementMapping, nullableElementMapping, typeof(TElement[]), typeof(TElement?[]), + CreateArrayBasedConverter, CreateArrayBasedConverter, + arrayTypeMatchPredicate, nullableArrayTypeMatchPredicate, + configure, suppressObjectMapping: TryFindMapping(typeof(object), elementMapping.DataTypeName, out _)); + + // Don't add the object converter for the list based converter. + AddStructArrayType(elementMapping, nullableElementMapping, typeof(List), typeof(List), + CreateListBasedConverter, CreateListBasedConverter, + listTypeMatchPredicate, nullableListTypeMatchPredicate, + configure, suppressObjectMapping: true); + } + + // Lives outside to prevent capture of TElement. + void AddStructArrayType(TypeInfoMapping elementMapping, TypeInfoMapping nullableElementMapping, Type type, Type nullableType, + Func converter, Func nullableConverter, + Func? typeMatchPredicate, Func? nullableTypeMatchPredicate, Func? configure, bool suppressObjectMapping) + { + var arrayDataTypeName = GetArrayDataTypeName(elementMapping.DataTypeName); + + var arrayMapping = new TypeInfoMapping(type, arrayDataTypeName, CreateComposedFactory(type, elementMapping, converter)) + { + MatchRequirement = elementMapping.MatchRequirement, + TypeMatchPredicate = typeMatchPredicate + }; + arrayMapping = configure?.Invoke(arrayMapping) ?? 
arrayMapping; + var nullableArrayMapping = new TypeInfoMapping(nullableType, arrayDataTypeName, CreateComposedFactory(nullableType, nullableElementMapping, nullableConverter)) + { + MatchRequirement = arrayMapping.MatchRequirement, + TypeMatchPredicate = nullableTypeMatchPredicate + }; + + _items.Add(arrayMapping); + _items.Add(nullableArrayMapping); + suppressObjectMapping = suppressObjectMapping || arrayMapping.TypeEquals(typeof(object)); + if (!suppressObjectMapping && arrayMapping.MatchRequirement is MatchRequirement.DataTypeName or MatchRequirement.Single) + _items.Add(new TypeInfoMapping(typeof(object), arrayDataTypeName, (options, mapping, dataTypeNameMatch) => + { + return options.ArrayNullabilityMode switch + { + _ when !dataTypeNameMatch => throw new InvalidOperationException("Should not happen, please file a bug."), + ArrayNullabilityMode.Never => arrayMapping.Factory(options, mapping, dataTypeNameMatch), + ArrayNullabilityMode.Always => nullableArrayMapping.Factory(options, mapping, dataTypeNameMatch), + ArrayNullabilityMode.PerInstance => CreateComposedPerInstance( + arrayMapping.Factory(options, mapping, dataTypeNameMatch), + nullableArrayMapping.Factory(options, mapping, dataTypeNameMatch), + mapping.DataTypeName + ), + _ => throw new ArgumentOutOfRangeException() + }; + }) { MatchRequirement = MatchRequirement.DataTypeName }); + + PgTypeInfo CreateComposedPerInstance(PgTypeInfo innerTypeInfo, PgTypeInfo nullableInnerTypeInfo, string dataTypeName) + { + var converter = + new PolymorphicArrayConverter( + innerTypeInfo.GetConcreteResolution().GetConverter(), + nullableInnerTypeInfo.GetConcreteResolution().GetConverter()); + + return new PgTypeInfo(innerTypeInfo.Options, converter, + innerTypeInfo.Options.GetCanonicalTypeId(new DataTypeName(dataTypeName)), unboxedType: typeof(Array)) { SupportsWriting = false }; + } + } + + public void AddResolverStructType(string dataTypeName, TypeInfoFactory createInfo, bool isDefault = false) where T : struct + => 
AddResolverStructType(typeof(T), typeof(T?), dataTypeName, createInfo, + static (_, innerInfo) => new NullableConverterResolver(innerInfo), GetDefaultConfigure(isDefault)); + + public void AddResolverStructType(string dataTypeName, TypeInfoFactory createInfo, MatchRequirement matchRequirement) where T : struct + => AddResolverStructType(typeof(T), typeof(T?), dataTypeName, createInfo, + static (_, innerInfo) => new NullableConverterResolver(innerInfo), GetDefaultConfigure(matchRequirement)); + + public void AddResolverStructType(string dataTypeName, TypeInfoFactory createInfo, Func? configure) where T : struct + => AddResolverStructType(typeof(T), typeof(T?), dataTypeName, createInfo, + static (_, innerInfo) => new NullableConverterResolver(innerInfo), configure); + + // Lives outside to prevent capture of T. + void AddResolverStructType(Type type, Type nullableType, string dataTypeName, TypeInfoFactory createInfo, + Func nullableConverter, Func? configure) + { + var mapping = new TypeInfoMapping(type, dataTypeName, createInfo); + mapping = configure?.Invoke(mapping) ?? mapping; + _items.Add(mapping); + _items.Add(new TypeInfoMapping(nullableType, dataTypeName, + CreateComposedFactory(nullableType, mapping, nullableConverter, copyPreferredFormat: true)) + { + MatchRequirement = mapping.MatchRequirement, + TypeMatchPredicate = mapping.TypeMatchPredicate is not null + ? type => type is null || (Nullable.GetUnderlyingType(type) is { } underlying && mapping.TypeMatchPredicate(underlying)) + : null + }); + } + + public void AddResolverStructArrayType(string elementDataTypeName) where TElement : struct + => AddResolverStructArrayType(FindMapping(typeof(TElement), elementDataTypeName), FindMapping(typeof(TElement?), elementDataTypeName)); + + public void AddResolverStructArrayType(TypeInfoMapping elementMapping, TypeInfoMapping nullableElementMapping) where TElement : struct + { + // Always use a predicate to match all dimensions. 
+ var arrayTypeMatchPredicate = GetArrayTypeMatchPredicate(elementMapping.TypeMatchPredicate ?? (static type => type is null || type == typeof(TElement))); + var nullableArrayTypeMatchPredicate = GetArrayTypeMatchPredicate(nullableElementMapping.TypeMatchPredicate ?? (static type => + type is null || (Nullable.GetUnderlyingType(type) is { } underlying && underlying == typeof(TElement)))); + var listTypeMatchPredicate = elementMapping.TypeMatchPredicate is not null ? GetListTypeMatchPredicate(elementMapping.TypeMatchPredicate) : null; + + AddResolverStructArrayType(elementMapping, nullableElementMapping, typeof(TElement[]), typeof(TElement?[]), + CreateArrayBasedConverterResolver, + CreateArrayBasedConverterResolver, suppressObjectMapping: TryFindMapping(typeof(object), elementMapping.DataTypeName, out _), arrayTypeMatchPredicate, nullableArrayTypeMatchPredicate); + + // Don't add the object converter for the list based converter. + AddResolverStructArrayType(elementMapping, nullableElementMapping, typeof(List), typeof(List), + CreateListBasedConverterResolver, + CreateListBasedConverterResolver, suppressObjectMapping: true, listTypeMatchPredicate, nullableArrayTypeMatchPredicate); + } + + // Lives outside to prevent capture of TElement. + void AddResolverStructArrayType(TypeInfoMapping elementMapping, TypeInfoMapping nullableElementMapping, Type type, Type nullableType, + Func converter, Func nullableConverter, + bool suppressObjectMapping, Func? typeMatchPredicate, Func? 
nullableTypeMatchPredicate) + { + var arrayDataTypeName = GetArrayDataTypeName(elementMapping.DataTypeName); + + var arrayMapping = new TypeInfoMapping(type, arrayDataTypeName, CreateComposedFactory(type, elementMapping, converter)) + { + MatchRequirement = elementMapping.MatchRequirement, + TypeMatchPredicate = typeMatchPredicate + }; + var nullableArrayMapping = new TypeInfoMapping(nullableType, arrayDataTypeName, CreateComposedFactory(nullableType, nullableElementMapping, nullableConverter)) + { + MatchRequirement = elementMapping.MatchRequirement, + TypeMatchPredicate = nullableTypeMatchPredicate + }; + + _items.Add(arrayMapping); + _items.Add(nullableArrayMapping); + suppressObjectMapping = suppressObjectMapping || arrayMapping.TypeEquals(typeof(object)); + if (!suppressObjectMapping && arrayMapping.MatchRequirement is MatchRequirement.DataTypeName or MatchRequirement.Single) + _items.Add(new TypeInfoMapping(typeof(object), arrayDataTypeName, (options, mapping, dataTypeNameMatch) => options.ArrayNullabilityMode switch + { + _ when !dataTypeNameMatch => throw new InvalidOperationException("Should not happen, please file a bug."), + ArrayNullabilityMode.Never => arrayMapping.Factory(options, mapping, dataTypeNameMatch), + ArrayNullabilityMode.Always => nullableArrayMapping.Factory(options, mapping, dataTypeNameMatch), + ArrayNullabilityMode.PerInstance => CreateComposedPerInstance( + arrayMapping.Factory(options, mapping, dataTypeNameMatch), + nullableArrayMapping.Factory(options, mapping, dataTypeNameMatch), + mapping.DataTypeName + ), + _ => throw new ArgumentOutOfRangeException() + }) { MatchRequirement = MatchRequirement.DataTypeName }); + + PgTypeInfo CreateComposedPerInstance(PgTypeInfo innerTypeInfo, PgTypeInfo nullableInnerTypeInfo, string dataTypeName) + { + var resolver = + new PolymorphicArrayConverterResolver((PgResolverTypeInfo)innerTypeInfo, + (PgResolverTypeInfo)nullableInnerTypeInfo); + + return new PgResolverTypeInfo(innerTypeInfo.Options, 
resolver, + innerTypeInfo.Options.GetCanonicalTypeId(new DataTypeName(dataTypeName))) { SupportsWriting = false }; + } + } + + public void AddPolymorphicResolverArrayType(string elementDataTypeName, Func> elementToArrayConverterFactory) + => AddPolymorphicResolverArrayType(FindMapping(typeof(object), elementDataTypeName), elementToArrayConverterFactory); + + public void AddPolymorphicResolverArrayType(TypeInfoMapping elementMapping, Func> elementToArrayConverterFactory) + { + AddPolymorphicResolverArrayType(elementMapping, typeof(object), + (mapping, elemInfo) => new ArrayPolymorphicConverterResolver( + elemInfo.Options.GetCanonicalTypeId(new DataTypeName(mapping.DataTypeName)), elemInfo, elementToArrayConverterFactory(elemInfo.Options)) + , null); + + void AddPolymorphicResolverArrayType(TypeInfoMapping elementMapping, Type type, Func converter, Func? typeMatchPredicate) + { + var arrayDataTypeName = GetArrayDataTypeName(elementMapping.DataTypeName); + var mapping = new TypeInfoMapping(type, arrayDataTypeName, + CreateComposedFactory(typeof(Array), elementMapping, converter, supportsWriting: false)) + { + MatchRequirement = elementMapping.MatchRequirement, + TypeMatchPredicate = typeMatchPredicate + }; + _items.Add(mapping); + } + } + + /// Returns whether type matches any of the types we register pg arrays as. + public static bool IsArrayLikeType(Type type, [NotNullWhen(true)] out Type? elementType) + { + elementType = type switch + { + { IsArray: true } => type.GetElementType(), + { IsConstructedGenericType: true } when type.GetGenericTypeDefinition() == typeof(List<>) => type.GetGenericArguments()[0], + _ => null + }; + + return elementType is not null; + } + + static string GetArrayDataTypeName(string dataTypeName) + => DataTypeName.IsFullyQualified(dataTypeName.AsSpan()) + ? 
DataTypeName.ValidatedName(dataTypeName).ToArrayName().Value + : "_" + DataTypeName.FromDisplayName(dataTypeName).UnqualifiedName; + + static ArrayBasedArrayConverter CreateArrayBasedConverter(TypeInfoMapping mapping, PgTypeInfo elemInfo) + { + if (!elemInfo.IsBoxing) + return new ArrayBasedArrayConverter(elemInfo.GetConcreteResolution(), mapping.Type); + + ThrowBoxingNotSupported(resolver: false); + return default; + } + + static ListBasedArrayConverter CreateListBasedConverter(TypeInfoMapping mapping, PgTypeInfo elemInfo) + { + if (!elemInfo.IsBoxing) + return new ListBasedArrayConverter(elemInfo.GetConcreteResolution()); + + ThrowBoxingNotSupported(resolver: false); + return default; + } + + static ArrayConverterResolver CreateArrayBasedConverterResolver(TypeInfoMapping mapping, PgResolverTypeInfo elemInfo) + { + if (!elemInfo.IsBoxing) + return new ArrayConverterResolver(elemInfo, mapping.Type); + + ThrowBoxingNotSupported(resolver: true); + return default; + } + + static ArrayConverterResolver CreateListBasedConverterResolver(TypeInfoMapping mapping, PgResolverTypeInfo elemInfo) + { + if (!elemInfo.IsBoxing) + return new ArrayConverterResolver(elemInfo, mapping.Type); + + ThrowBoxingNotSupported(resolver: true); + return default; + } + + [DoesNotReturn] + static void ThrowBoxingNotSupported(bool resolver) + => throw new InvalidOperationException($"Boxing converters are not supported, manually construct a mapping over a casting converter{(resolver ? " resolver" : "")} instead."); +} + +public static class TypeInfoMappingHelpers +{ + internal static PgTypeId ResolveFullyQualifiedName(PgSerializerOptions options, string dataTypeName) + => !DataTypeName.IsFullyQualified(dataTypeName.AsSpan()) + ? 
options.ToCanonicalTypeId(options.DatabaseInfo.GetPostgresType(dataTypeName)) + : new(new DataTypeName(dataTypeName)); + + internal static PostgresType GetPgType(this TypeInfoMapping mapping, PgSerializerOptions options) + => !DataTypeName.IsFullyQualified(mapping.DataTypeName.AsSpan()) + ? options.DatabaseInfo.GetPostgresType(mapping.DataTypeName) + : options.DatabaseInfo.GetPostgresType(new DataTypeName(mapping.DataTypeName)); + + public static PgTypeInfo CreateInfo(this TypeInfoMapping mapping, PgSerializerOptions options, PgConverter converter, DataFormat? preferredFormat = null, bool supportsWriting = true) + => new(options, converter, ResolveFullyQualifiedName(options, mapping.DataTypeName)) + { + PreferredFormat = preferredFormat, + SupportsWriting = supportsWriting + }; + + public static PgResolverTypeInfo CreateInfo(this TypeInfoMapping mapping, PgSerializerOptions options, PgConverterResolver resolver, bool includeDataTypeName = true, DataFormat? preferredFormat = null, bool supportsWriting = true) + { + PgTypeId? pgTypeId = includeDataTypeName ? ResolveFullyQualifiedName(options, mapping.DataTypeName) : null; + return new(options, resolver, pgTypeId) + { + PreferredFormat = preferredFormat, + SupportsWriting = supportsWriting + }; + } +} diff --git a/src/Npgsql/Internal/TypeInfoResolverChain.cs b/src/Npgsql/Internal/TypeInfoResolverChain.cs new file mode 100644 index 0000000000..64e1f86e0d --- /dev/null +++ b/src/Npgsql/Internal/TypeInfoResolverChain.cs @@ -0,0 +1,25 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Npgsql.Internal.Postgres; + +namespace Npgsql.Internal; + +sealed class TypeInfoResolverChain : IPgTypeInfoResolver +{ + readonly IPgTypeInfoResolver[] _resolvers; + + public TypeInfoResolverChain(IEnumerable resolvers) + => _resolvers = resolvers.ToArray(); + + public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? 
dataTypeName, PgSerializerOptions options) + { + foreach (var resolver in _resolvers) + { + if (resolver.GetTypeInfo(type, dataTypeName, options) is { } info) + return info; + } + + return null; + } +} diff --git a/src/Npgsql/Internal/TypeMapping/IUserTypeMapping.cs b/src/Npgsql/Internal/TypeMapping/IUserTypeMapping.cs deleted file mode 100644 index aedc14e743..0000000000 --- a/src/Npgsql/Internal/TypeMapping/IUserTypeMapping.cs +++ /dev/null @@ -1,13 +0,0 @@ -using System; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; - -namespace Npgsql.Internal.TypeMapping; - -public interface IUserTypeMapping -{ - public string PgTypeName { get; } - public Type ClrType { get; } - - public NpgsqlTypeHandler CreateHandler(PostgresType pgType, NpgsqlConnector connector); -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeMapping/TypeMapper.cs b/src/Npgsql/Internal/TypeMapping/TypeMapper.cs deleted file mode 100644 index eb6bb75f48..0000000000 --- a/src/Npgsql/Internal/TypeMapping/TypeMapper.cs +++ /dev/null @@ -1,539 +0,0 @@ -using System; -using System.Collections; -using System.Collections.Concurrent; -using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; -using System.Linq; -using System.Reflection; -using Microsoft.Extensions.Logging; -using Npgsql.Internal.TypeHandlers; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -namespace Npgsql.Internal.TypeMapping; - -/// -/// Type mapper used to map types to type handlers. -/// -public sealed class TypeMapper -{ - internal NpgsqlConnector Connector { get; } - readonly object _writeLock = new(); - - NpgsqlDatabaseInfo? 
_databaseInfo; - - internal NpgsqlDatabaseInfo DatabaseInfo - { - get - { - var databaseInfo = _databaseInfo; - if (databaseInfo is null) - ThrowHelper.ThrowInvalidOperationException("Internal error: this type mapper hasn't yet been bound to a database info object"); - return databaseInfo; - } - } - - volatile TypeHandlerResolver[] _handlerResolvers; - volatile TypeMappingResolver[] _mappingResolvers; - internal NpgsqlTypeHandler UnrecognizedTypeHandler { get; } - - readonly ConcurrentDictionary _handlersByOID = new(); - readonly ConcurrentDictionary _handlersByNpgsqlDbType = new(); - readonly ConcurrentDictionary _handlersByClrType = new(); - readonly ConcurrentDictionary _handlersByDataTypeName = new(); - - readonly Dictionary _userTypeMappings = new(); - readonly INpgsqlNameTranslator _defaultNameTranslator; - - readonly ILogger _commandLogger; - - #region Construction - - internal TypeMapper(NpgsqlConnector connector, INpgsqlNameTranslator defaultNameTranslator) - { - Connector = connector; - _defaultNameTranslator = defaultNameTranslator; - UnrecognizedTypeHandler = new UnknownTypeHandler(Connector.TextEncoding); - _handlerResolvers = Array.Empty(); - _mappingResolvers = Array.Empty(); - _commandLogger = connector.LoggingConfiguration.CommandLogger; - } - - #endregion Constructors - - internal void Initialize( - NpgsqlDatabaseInfo databaseInfo, - List resolverFactories, - Dictionary userTypeMappings) - { - _databaseInfo = databaseInfo; - - var handlerResolvers = new TypeHandlerResolver[resolverFactories.Count]; - var mappingResolvers = new List(resolverFactories.Count); - for (var i = 0; i < resolverFactories.Count; i++) - { - handlerResolvers[i] = resolverFactories[i].Create(this, Connector); - var mappingResolver = resolverFactories[i].CreateMappingResolver(); - if (mappingResolver is not null) - mappingResolvers.Add(mappingResolver); - } - - // Add global mapper resolvers in backwards because they're inserted in the beginning - for (var i = 
resolverFactories.Count - 1; i >= 0; i--) - { - var globalMappingResolver = resolverFactories[i].CreateGlobalMappingResolver(); - if (globalMappingResolver is not null) - GlobalTypeMapper.Instance.TryAddMappingResolver(globalMappingResolver); - } - - _handlerResolvers = handlerResolvers; - _mappingResolvers = mappingResolvers.ToArray(); - - foreach (var userTypeMapping in userTypeMappings.Values) - { - if (DatabaseInfo.TryGetPostgresTypeByName(userTypeMapping.PgTypeName, out var pgType)) - { - _handlersByOID[pgType.OID] = - _handlersByDataTypeName[pgType.FullName] = - _handlersByDataTypeName[pgType.Name] = - _handlersByClrType[userTypeMapping.ClrType] = userTypeMapping.CreateHandler(pgType, Connector); - - _userTypeMappings[pgType.OID] = new(npgsqlDbType: null, pgType.Name, userTypeMapping.ClrType); - } - } - } - - #region Type handler lookup - - /// - /// Looks up a type handler by its PostgreSQL type's OID. - /// - /// A PostgreSQL type OID - /// A type handler that can be used to encode and decode values. - public NpgsqlTypeHandler ResolveByOID(uint oid) - => TryResolveByOID(oid, out var result) ? result : UnrecognizedTypeHandler; - - internal bool TryResolveByOID(uint oid, [NotNullWhen(true)] out NpgsqlTypeHandler? handler) - { - if (_handlersByOID.TryGetValue(oid, out handler)) - return true; - - return TryResolveLong(oid, out handler); - - bool TryResolveLong(uint oid, [NotNullWhen(true)] out NpgsqlTypeHandler? handler) - { - if (!DatabaseInfo.ByOID.TryGetValue(oid, out var pgType)) - { - handler = null; - return false; - } - - lock (_writeLock) - { - if ((handler = ResolveByPostgresType(pgType)) is not null) - { - _handlersByOID[oid] = handler; - return true; - } - - if ((handler = ResolveComplexTypeByDataTypeName(pgType.FullName, throwOnError: false)) is not null) - { - _handlersByOID[oid] = handler; - return true; - } - - handler = null; - return false; - } - } - } - - /// - /// Looks up a type handler by NpgsqlDbType. 
- /// - /// Parameter's NpgsqlDbType - /// A type handler that can be used to encode and decode values. - public NpgsqlTypeHandler ResolveByNpgsqlDbType(NpgsqlDbType npgsqlDbType) - { - if (_handlersByNpgsqlDbType.TryGetValue(npgsqlDbType, out var handler)) - return handler; - - return ResolveLong(npgsqlDbType); - - NpgsqlTypeHandler ResolveLong(NpgsqlDbType npgsqlDbType) - { - lock (_writeLock) - { - // First, try to resolve as a base type; translate the NpgsqlDbType to a PG data type name and look that up. - if (GlobalTypeMapper.NpgsqlDbTypeToDataTypeName(npgsqlDbType) is { } dataTypeName) - { - foreach (var resolver in _handlerResolvers) - { - try - { - if (resolver.ResolveByDataTypeName(dataTypeName) is { } handler) - return _handlersByNpgsqlDbType[npgsqlDbType] = handler; - } - catch (Exception e) - { - _commandLogger.LogError(e, - $"Type resolver {resolver.GetType().Name} threw exception while resolving NpgsqlDbType {npgsqlDbType}"); - } - } - } - - // Can't find (or translate) PG data type name by NpgsqlDbType. - // This might happen because of flags (like Array, Range or Multirange). 
- foreach (var resolver in _handlerResolvers) - { - try - { - if (resolver.ResolveByNpgsqlDbType(npgsqlDbType) is { } handler) - return _handlersByNpgsqlDbType[npgsqlDbType] = handler; - } - catch (Exception e) - { - _commandLogger.LogError(e, - $"Type resolver {resolver.GetType().Name} threw exception while resolving NpgsqlDbType {npgsqlDbType}"); - } - } - - if (npgsqlDbType.HasFlag(NpgsqlDbType.Array)) - { - var elementHandler = ResolveByNpgsqlDbType(npgsqlDbType & ~NpgsqlDbType.Array); - - if (elementHandler.PostgresType.Array is not { } pgArrayType) - throw new ArgumentException( - $"No array type could be found in the database for element {elementHandler.PostgresType}"); - - return _handlersByNpgsqlDbType[npgsqlDbType] = - elementHandler.CreateArrayHandler(pgArrayType, Connector.Settings.ArrayNullabilityMode); - } - - throw new NpgsqlException($"The NpgsqlDbType '{npgsqlDbType}' isn't present in your database. " + - "You may need to install an extension or upgrade to a newer version."); - } - } - } - - internal NpgsqlTypeHandler ResolveByDataTypeName(string typeName) - => ResolveByDataTypeNameCore(typeName) ?? ResolveComplexTypeByDataTypeName(typeName, throwOnError: true)!; - - NpgsqlTypeHandler? ResolveByDataTypeNameCore(string typeName) - { - if (_handlersByDataTypeName.TryGetValue(typeName, out var handler)) - return handler; - - return ResolveLong(typeName); - - NpgsqlTypeHandler? ResolveLong(string typeName) - { - lock (_writeLock) - { - foreach (var resolver in _handlerResolvers) - { - try - { - if (resolver.ResolveByDataTypeName(typeName) is { } handler) - return _handlersByDataTypeName[typeName] = handler; - } - catch (Exception e) - { - _commandLogger.LogError(e, $"Type resolver {resolver.GetType().Name} threw exception while resolving data type name {typeName}"); - } - } - - return null; - } - } - } - - NpgsqlTypeHandler? 
ResolveByPostgresType(PostgresType type) - { - if (_handlersByDataTypeName.TryGetValue(type.FullName, out var handler)) - return handler; - - return ResolveLong(type); - - NpgsqlTypeHandler? ResolveLong(PostgresType type) - { - lock (_writeLock) - { - foreach (var resolver in _handlerResolvers) - { - try - { - if (resolver.ResolveByPostgresType(type) is { } handler) - return _handlersByDataTypeName[type.FullName] = handler; - } - catch (Exception e) - { - _commandLogger.LogError(e, $"Type resolver {resolver.GetType().Name} threw exception while resolving data type name {type.FullName}"); - } - } - - return null; - } - } - } - - NpgsqlTypeHandler? ResolveComplexTypeByDataTypeName(string typeName, bool throwOnError) - { - lock (_writeLock) - { - var pgType = DatabaseInfo.GetPostgresTypeByName(typeName); - - switch (pgType) - { - case PostgresArrayType pgArrayType: - { - var elementHandler = ResolveByOID(pgArrayType.Element.OID); - return _handlersByDataTypeName[typeName] = - elementHandler.CreateArrayHandler(pgArrayType, Connector.Settings.ArrayNullabilityMode); - } - - case PostgresEnumType pgEnumType: - { - // A mapped enum would have been registered in _extraHandlersByDataTypeName and bound above - this is unmapped. - return _handlersByDataTypeName[typeName] = - new UnmappedEnumHandler(pgEnumType, _defaultNameTranslator, Connector.TextEncoding); - } - - case PostgresDomainType pgDomainType: - return _handlersByDataTypeName[typeName] = ResolveByOID(pgDomainType.BaseType.OID); - - case PostgresBaseType pgBaseType: - return throwOnError - ? throw new NotSupportedException($"PostgreSQL type '{pgBaseType}' isn't supported by Npgsql") - : null; - - case PostgresCompositeType pgCompositeType: - // We don't support writing unmapped composite types, but we do support reading unmapped composite types. - // So when we're invoked from ResolveOID (which is the read path), we don't want to raise an exception. - return throwOnError - ? 
throw new NotSupportedException( - $"Composite type '{pgCompositeType}' must be mapped with Npgsql before being used, see the docs.") - : null; - -#pragma warning disable CS0618 - case PostgresRangeType: - case PostgresMultirangeType: - return throwOnError - ? throw new NotSupportedException( - $"'{pgType}' is a range type; please call {nameof(NpgsqlSlimDataSourceBuilder.EnableRanges)} on {nameof(NpgsqlSlimDataSourceBuilder)} to enable ranges. " + - "See https://www.npgsql.org/doc/types/ranges.html for more information.") - : null; -#pragma warning restore CS0618 - - default: - throw new ArgumentOutOfRangeException($"Unhandled PostgreSQL type type: {pgType.GetType()}"); - } - } - } - - internal NpgsqlTypeHandler ResolveByValue(T value) - { - if (value is null) - return ResolveByClrType(typeof(T)); - - if (typeof(T).IsValueType) - { - // Attempt to resolve value types generically via the resolver. This is the efficient fast-path, where we don't even need to - // do a dictionary lookup (the JIT elides type checks in generic methods for value types) - NpgsqlTypeHandler? handler; - - foreach (var resolver in _handlerResolvers) - { - try - { - if ((handler = resolver.ResolveValueTypeGenerically(value)) is not null) - return handler; - } - catch (Exception e) - { - _commandLogger.LogError(e, $"Type resolver {resolver.GetType().Name} threw exception while resolving value with type {typeof(T)}"); - } - } - - // There may still be some value types not resolved by the above, e.g. NpgsqlRange - } - - // Value types would have been resolved above, so this is a reference type - no JIT optimizations. - // We go through the regular logic (and there's no boxing). - return ResolveByValue((object)value); - } - - internal NpgsqlTypeHandler ResolveByValue(object value) - { - // We resolve as follows: - // 1. Cached by-type lookup (fast path). This will work for almost all types after the very first resolution. - // 2. Value-dependent type lookup (e.g. 
DateTime by Kind) via the resolvers. This includes complex types (e.g. array/range - // over DateTime), and the results cannot be cached. - // 3. Uncached by-type lookup (for the very first resolution of a given type) - - var type = value.GetType(); - if (_handlersByClrType.TryGetValue(type, out var handler)) - return handler; - - return ResolveLong(value, type); - - NpgsqlTypeHandler ResolveLong(object value, Type type) - { - foreach (var resolver in _handlerResolvers) - { - try - { - if (resolver.ResolveValueDependentValue(value) is { } handler) - return handler; - } - catch (Exception e) - { - _commandLogger.LogError(e, $"Type resolver {resolver.GetType().Name} threw exception while resolving value with type {type}"); - } - } - - // ResolveByClrType either throws, or resolves a handler and caches it in _handlersByClrType (where it would be found above the - // next time we resolve this type) - return ResolveByClrType(type); - } - } - - // TODO: This is needed as a separate method only because of binary COPY, see #3957 - /// - /// Looks up a type handler by CLR Type. - /// - /// Parameter's CLR type - /// A type handler that can be used to encode and decode values. 
- public NpgsqlTypeHandler ResolveByClrType(Type type) - { - if (_handlersByClrType.TryGetValue(type, out var handler)) - return handler; - - return ResolveLong(type); - - NpgsqlTypeHandler ResolveLong(Type type) - { - lock (_writeLock) - { - foreach (var resolver in _handlerResolvers) - { - try - { - if (resolver.ResolveByClrType(type) is { } handler) - return _handlersByClrType[type] = handler; - } - catch (Exception e) - { - _commandLogger.LogError(e, $"Type resolver {resolver.GetType().Name} threw exception while resolving value with type {type}"); - } - } - - // Try to see if it is an array type - var arrayElementType = GetArrayListElementType(type); - if (arrayElementType is not null) - { - if (ResolveByClrType(arrayElementType) is not { } elementHandler) - throw new ArgumentException($"Array type over CLR type {arrayElementType.Name} isn't supported by Npgsql"); - - if (elementHandler.PostgresType.Array is not { } pgArrayType) - throw new ArgumentException( - $"No array type could be found in the database for element {elementHandler.PostgresType}"); - - return _handlersByClrType[type] = - elementHandler.CreateArrayHandler(pgArrayType, Connector.Settings.ArrayNullabilityMode); - } - - if (Nullable.GetUnderlyingType(type) is { } underlyingType && ResolveByClrType(underlyingType) is { } underlyingHandler) - return _handlersByClrType[type] = underlyingHandler; - - if (type.IsEnum) - { - return DatabaseInfo.TryGetPostgresTypeByName(GetPgName(type, _defaultNameTranslator), out var pgType) - && pgType is PostgresEnumType pgEnumType - ? _handlersByClrType[type] = new UnmappedEnumHandler(pgEnumType, _defaultNameTranslator, Connector.TextEncoding) - : throw new NotSupportedException( - $"Could not find a PostgreSQL enum type corresponding to {type.Name}. 
" + - "Consider mapping the enum before usage, refer to the documentation for more details."); - } - - if (typeof(IEnumerable).IsAssignableFrom(type)) - throw new NotSupportedException("IEnumerable parameters are not supported, pass an array or List instead"); - - throw new NotSupportedException($"The CLR type {type} isn't natively supported by Npgsql or your PostgreSQL. " + - $"To use it with a PostgreSQL composite you need to specify {nameof(NpgsqlParameter.DataTypeName)} or to map it, please refer to the documentation."); - } - - static Type? GetArrayListElementType(Type type) - { - var typeInfo = type.GetTypeInfo(); - if (typeInfo.IsArray) - return GetUnderlyingType(type.GetElementType()!); // The use of bang operator is justified here as Type.GetElementType() only returns null for the Array base class which can't be mapped in a useful way. - - var ilist = typeInfo.ImplementedInterfaces.FirstOrDefault(x => x.GetTypeInfo().IsGenericType && x.GetGenericTypeDefinition() == typeof(IList<>)); - if (ilist != null) - return GetUnderlyingType(ilist.GetGenericArguments()[0]); - - if (typeof(IList).IsAssignableFrom(type)) - throw new NotSupportedException("Non-generic IList is a supported parameter, but the NpgsqlDbType parameter must be set on the parameter"); - - return null; - - Type GetUnderlyingType(Type t) - => Nullable.GetUnderlyingType(t) ?? t; - } - } - } - - #endregion Type handler lookup - - internal bool TryGetMapping(PostgresType pgType, [NotNullWhen(true)] out TypeMappingInfo? 
mapping) - { - foreach (var resolver in _mappingResolvers) - if ((mapping = resolver.GetMappingByPostgresType(this, pgType)) is not null) - return true; - - switch (pgType) - { - case PostgresArrayType pgArrayType: - if (TryGetMapping(pgArrayType.Element, out var elementMapping)) - { - mapping = new(elementMapping.NpgsqlDbType | NpgsqlDbType.Array, pgType.DisplayName); - return true; - } - - break; - - case PostgresDomainType pgDomainType: - if (TryGetMapping(pgDomainType.BaseType, out var baseMapping)) - { - mapping = new(baseMapping.NpgsqlDbType, pgType.DisplayName, baseMapping.ClrTypes); - return true; - } - - break; - - case PostgresEnumType or PostgresCompositeType: - return _userTypeMappings.TryGetValue(pgType.OID, out mapping); - } - - mapping = null; - return false; - } - - internal (NpgsqlDbType? npgsqlDbType, PostgresType postgresType) GetTypeInfoByOid(uint oid) - { - if (!DatabaseInfo.ByOID.TryGetValue(oid, out var pgType)) - ThrowHelper.ThrowInvalidOperationException($"Couldn't find PostgreSQL type with OID {oid}"); - - if (TryGetMapping(pgType, out var mapping)) - return (mapping.NpgsqlDbType, pgType); - - return (null, pgType); - } - - static string GetPgName(Type clrType, INpgsqlNameTranslator nameTranslator) - => clrType.GetCustomAttribute()?.PgName - ?? nameTranslator.TranslateTypeName(clrType.Name); -} diff --git a/src/Npgsql/Internal/TypeMapping/TypeMappingResolver.cs b/src/Npgsql/Internal/TypeMapping/TypeMappingResolver.cs deleted file mode 100644 index af426e6f2f..0000000000 --- a/src/Npgsql/Internal/TypeMapping/TypeMappingResolver.cs +++ /dev/null @@ -1,25 +0,0 @@ -using System; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; - -namespace Npgsql.Internal.TypeMapping; - -public abstract class TypeMappingResolver -{ - public abstract string? GetDataTypeNameByClrType(Type clrType); - public virtual string? GetDataTypeNameByValueDependentValue(object value) => null; - public abstract TypeMappingInfo? 
GetMappingByDataTypeName(string dataTypeName); - - /// - /// Gets type mapping information for a given PostgreSQL type. - /// Invoked in scenarios when mapping information is required, rather than a type handler for reading or writing. - /// - public virtual TypeMappingInfo? GetMappingByPostgresType(TypeMapper typeMapper, PostgresType type) - => GetMappingByDataTypeName(type.Name); - - internal TypeMappingInfo? GetMappingByValueDependentValue(object value) - => GetDataTypeNameByValueDependentValue(value) is { } dataTypeName ? GetMappingByDataTypeName(dataTypeName) : null; - - internal TypeMappingInfo? GetMappingByClrType(Type clrType) - => GetDataTypeNameByClrType(clrType) is { } dataTypeName ? GetMappingByDataTypeName(dataTypeName) : null; -} \ No newline at end of file diff --git a/src/Npgsql/Internal/TypeMapping/UserCompositeTypeMappings.cs b/src/Npgsql/Internal/TypeMapping/UserCompositeTypeMappings.cs deleted file mode 100644 index 75d680b200..0000000000 --- a/src/Npgsql/Internal/TypeMapping/UserCompositeTypeMappings.cs +++ /dev/null @@ -1,24 +0,0 @@ -using System; -using Npgsql.Internal.TypeHandlers.CompositeHandlers; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; - -namespace Npgsql.Internal.TypeMapping; - -public interface IUserCompositeTypeMapping : IUserTypeMapping -{ - INpgsqlNameTranslator NameTranslator { get; } -} - -sealed class UserCompositeTypeMapping : IUserCompositeTypeMapping -{ - public string PgTypeName { get; } - public Type ClrType => typeof(T); - public INpgsqlNameTranslator NameTranslator { get; } - - public UserCompositeTypeMapping(string pgTypeName, INpgsqlNameTranslator nameTranslator) - => (PgTypeName, NameTranslator) = (pgTypeName, nameTranslator); - - public NpgsqlTypeHandler CreateHandler(PostgresType pgType, NpgsqlConnector connector) - => new CompositeHandler((PostgresCompositeType)pgType, connector.TypeMapper, NameTranslator); -} \ No newline at end of file diff --git 
a/src/Npgsql/Internal/TypeMapping/UserEnumTypeMappings.cs b/src/Npgsql/Internal/TypeMapping/UserEnumTypeMappings.cs deleted file mode 100644 index 9c2c3e35d7..0000000000 --- a/src/Npgsql/Internal/TypeMapping/UserEnumTypeMappings.cs +++ /dev/null @@ -1,46 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Reflection; -using Npgsql.Internal.TypeHandlers; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; -using NpgsqlTypes; - -namespace Npgsql.Internal.TypeMapping; - -public interface IUserEnumTypeMapping : IUserTypeMapping -{ - INpgsqlNameTranslator NameTranslator { get; } -} - -sealed class UserEnumTypeMapping : IUserEnumTypeMapping - where TEnum : struct, Enum -{ - public string PgTypeName { get; } - public Type ClrType => typeof(TEnum); - public INpgsqlNameTranslator NameTranslator { get; } - - readonly Dictionary _enumToLabel = new(); - readonly Dictionary _labelToEnum = new(); - - public UserEnumTypeMapping(string pgTypeName, INpgsqlNameTranslator nameTranslator) - { - (PgTypeName, NameTranslator) = (pgTypeName, nameTranslator); - - foreach (var field in typeof(TEnum).GetFields(BindingFlags.Static | BindingFlags.Public)) - { - var attribute = (PgNameAttribute?)field.GetCustomAttributes(typeof(PgNameAttribute), false).FirstOrDefault(); - var enumName = attribute is null - ? 
nameTranslator.TranslateMemberName(field.Name) - : attribute.PgName; - var enumValue = (TEnum)field.GetValue(null)!; - - _enumToLabel[enumValue] = enumName; - _labelToEnum[enumName] = enumValue; - } - } - - public NpgsqlTypeHandler CreateHandler(PostgresType postgresType, NpgsqlConnector connector) - => new EnumHandler((PostgresEnumType)postgresType, _enumToLabel, _labelToEnum); -} \ No newline at end of file diff --git a/src/Npgsql/Internal/ValueMetadata.cs b/src/Npgsql/Internal/ValueMetadata.cs new file mode 100644 index 0000000000..ff041a3060 --- /dev/null +++ b/src/Npgsql/Internal/ValueMetadata.cs @@ -0,0 +1,9 @@ +namespace Npgsql.Internal; + +public readonly struct ValueMetadata +{ + public required DataFormat Format { get; init; } + public required Size BufferRequirement { get; init; } + public required Size Size { get; init; } + public object? WriteState { get; init; } +} diff --git a/src/Npgsql/MultiplexingDataSource.cs b/src/Npgsql/MultiplexingDataSource.cs index 03c8718216..e9e7fa3069 100644 --- a/src/Npgsql/MultiplexingDataSource.cs +++ b/src/Npgsql/MultiplexingDataSource.cs @@ -262,7 +262,7 @@ bool WriteCommand(NpgsqlConnector connector, NpgsqlCommand command, ref Multiple if (t.IsFaulted) { - FailWrite(conn, t.Exception!.UnwrapAggregate()); + FailWrite(conn, t.Exception!.InnerException!); return; } @@ -314,7 +314,7 @@ void Flush(NpgsqlConnector connector, ref MultiplexingStats stats) var conn = (NpgsqlConnector)o!; if (t.IsFaulted) { - FailWrite(conn, t.Exception!.UnwrapAggregate()); + FailWrite(conn, t.Exception!.InnerException!); return; } diff --git a/src/Npgsql/NameTranslation/NpgsqlSnakeCaseNameTranslator.cs b/src/Npgsql/NameTranslation/NpgsqlSnakeCaseNameTranslator.cs index cdb9bb40a8..c4ba594ba7 100644 --- a/src/Npgsql/NameTranslation/NpgsqlSnakeCaseNameTranslator.cs +++ b/src/Npgsql/NameTranslation/NpgsqlSnakeCaseNameTranslator.cs @@ -11,6 +11,8 @@ namespace Npgsql.NameTranslation; /// public sealed class NpgsqlSnakeCaseNameTranslator : 
INpgsqlNameTranslator { + internal static NpgsqlSnakeCaseNameTranslator Instance { get; } = new(); + readonly CultureInfo _culture; /// diff --git a/src/Npgsql/Npgsql.csproj b/src/Npgsql/Npgsql.csproj index 0c1c0600ff..53a0e38377 100644 --- a/src/Npgsql/Npgsql.csproj +++ b/src/Npgsql/Npgsql.csproj @@ -25,14 +25,14 @@ - + - + @@ -55,5 +55,4 @@ NpgsqlStrings.resx - diff --git a/src/Npgsql/NpgsqlBatchCommandCollection.cs b/src/Npgsql/NpgsqlBatchCommandCollection.cs index 58227ac69a..a79afa359b 100644 --- a/src/Npgsql/NpgsqlBatchCommandCollection.cs +++ b/src/Npgsql/NpgsqlBatchCommandCollection.cs @@ -2,7 +2,6 @@ using System.Collections.Generic; using System.Data.Common; using System.Diagnostics.CodeAnalysis; -using System.Runtime.CompilerServices; namespace Npgsql; @@ -111,4 +110,4 @@ static NpgsqlBatchCommand Cast(DbBatchCommand? value) static void ThrowInvalidCastException(DbBatchCommand? value) => throw new InvalidCastException( $"The value \"{value}\" is not of type \"{nameof(NpgsqlBatchCommand)}\" and cannot be used in this batch command collection."); -} \ No newline at end of file +} diff --git a/src/Npgsql/NpgsqlBinaryExporter.cs b/src/Npgsql/NpgsqlBinaryExporter.cs index f772334feb..a7c0e395e9 100644 --- a/src/Npgsql/NpgsqlBinaryExporter.cs +++ b/src/Npgsql/NpgsqlBinaryExporter.cs @@ -6,9 +6,7 @@ using Microsoft.Extensions.Logging; using Npgsql.BackendMessages; using Npgsql.Internal; -using Npgsql.Internal.TypeHandling; -using Npgsql.Internal.TypeMapping; -using Npgsql.TypeMapping; +using Npgsql.Internal.Postgres; using NpgsqlTypes; using static Npgsql.Util.Statics; @@ -20,23 +18,27 @@ namespace Npgsql; /// public sealed class NpgsqlBinaryExporter : ICancelable { + const int BeforeRow = -2; + const int BeforeColumn = -1; + #region Fields and Properties NpgsqlConnector _connector; NpgsqlReadBuffer _buf; - TypeMapper _typeMapper; bool _isConsumed, _isDisposed; - int _leftToReadInDataMsg, _columnLen; + long _endOfMessagePos; short _column; ulong _rowsExported; 
+ PgReader PgReader => _buf.PgReader; + /// /// The number of columns, as returned from the backend in the CopyInResponse. /// internal int NumColumns { get; private set; } - NpgsqlTypeHandler?[] _typeHandlerCache; + PgConverterInfo[] _columnInfoCache; readonly ILogger _copyLogger; @@ -61,10 +63,8 @@ internal NpgsqlBinaryExporter(NpgsqlConnector connector) { _connector = connector; _buf = connector.ReadBuffer; - _typeMapper = connector.TypeMapper; - _columnLen = int.MinValue; // Mark that the (first) column length hasn't been read yet - _column = -1; - _typeHandlerCache = null!; + _column = BeforeRow; + _columnInfoCache = null!; _copyLogger = connector.LoggingConfiguration.CopyLogger; } @@ -80,7 +80,7 @@ internal async Task Init(string copyToCommand, bool async, CancellationToken can switch (msg.Code) { case BackendMessageCode.CopyOutResponse: - copyOutResponse = (CopyOutResponseMessage) msg; + copyOutResponse = (CopyOutResponseMessage)msg; if (!copyOutResponse.IsBinary) { throw _connector.Break( @@ -98,14 +98,16 @@ internal async Task Init(string copyToCommand, bool async, CancellationToken can } NumColumns = copyOutResponse.NumColumns; - _typeHandlerCache = new NpgsqlTypeHandler[NumColumns]; + _columnInfoCache = new PgConverterInfo[NumColumns]; _rowsExported = 0; + _endOfMessagePos = _buf.CumulativeReadPosition; await ReadHeader(async); } async Task ReadHeader(bool async) { - _leftToReadInDataMsg = Expect(await _connector.ReadMessage(async), _connector).Length; + var msg = await _connector.ReadMessage(async); + _endOfMessagePos = _buf.CumulativeReadPosition + Expect(msg, _connector).Length; var headerLen = NpgsqlRawCopyStream.BinarySignature.Length + 4 + 4; await _buf.Ensure(headerLen, async); @@ -117,7 +119,6 @@ async Task ReadHeader(bool async) throw new NotSupportedException("Unsupported flags in COPY operation (OID inclusion?)"); _buf.ReadInt32(); // Header extensions, currently unused - _leftToReadInDataMsg -= headerLen; } #endregion @@ -148,38 +149,44 @@ 
public ValueTask StartRowAsync(CancellationToken cancellationToken = defaul async ValueTask StartRow(bool async, CancellationToken cancellationToken = default) { + CheckDisposed(); if (_isConsumed) return -1; using var registration = _connector.StartNestedCancellableOperation(cancellationToken); + // Consume and advance any active column. + if (_column >= 0) + await Commit(async, resumableOp: false); + // The very first row (i.e. _column == -1) is included in the header's CopyData message. // Otherwise we need to read in a new CopyData row (the docs specify that there's a CopyData // message per row). if (_column == NumColumns) - _leftToReadInDataMsg = Expect(await _connector.ReadMessage(async), _connector).Length; - else if (_column != -1) + { + var msg = Expect(await _connector.ReadMessage(async), _connector); + _endOfMessagePos = _buf.CumulativeReadPosition + msg.Length; + } + else if (_column != BeforeRow) ThrowHelper.ThrowInvalidOperationException("Already in the middle of a row"); await _buf.Ensure(2, async); - _leftToReadInDataMsg -= 2; var numColumns = _buf.ReadInt16(); if (numColumns == -1) { - Debug.Assert(_leftToReadInDataMsg == 0); Expect(await _connector.ReadMessage(async), _connector); Expect(await _connector.ReadMessage(async), _connector); Expect(await _connector.ReadMessage(async), _connector); - _column = -1; + _column = BeforeRow; _isConsumed = true; return -1; } Debug.Assert(numColumns == NumColumns); - _column = 0; + _column = BeforeColumn; _rowsExported++; return NumColumns; } @@ -194,7 +201,7 @@ async ValueTask StartRow(bool async, CancellationToken cancellationToken = /// specify the type. /// /// The value of the column - public T Read() => Read(false).GetAwaiter().GetResult(); + public T Read() => Read(async: false).GetAwaiter().GetResult(); /// /// Reads the current column, returns its value and moves ahead to the next column. 
@@ -209,22 +216,33 @@ async ValueTask StartRow(bool async, CancellationToken cancellationToken = public ValueTask ReadAsync(CancellationToken cancellationToken = default) { using (NoSynchronizationContextScope.Enter()) - return Read(true, cancellationToken); + return Read(async: true, cancellationToken); } ValueTask Read(bool async, CancellationToken cancellationToken = default) - { - CheckDisposed(); - - if (_column == -1 || _column == NumColumns) - ThrowHelper.ThrowInvalidOperationException("Not reading a row"); + => Read(async, null, cancellationToken); - var type = typeof(T); - var handler = _typeHandlerCache[_column]; - if (handler == null) - handler = _typeHandlerCache[_column] = _typeMapper.ResolveByClrType(type); + PgConverterInfo CreateConverterInfo(Type type, NpgsqlDbType? npgsqlDbType = null) + { + var options = _connector.SerializerOptions; + PgTypeId? pgTypeId = null; + if (npgsqlDbType.HasValue) + { + pgTypeId = npgsqlDbType.Value.ToDataTypeName() is { } name + ? options.GetCanonicalTypeId(name) + // Handle plugin types via lookup. + : GetRepresentationalOrDefault(npgsqlDbType.Value.ToUnqualifiedDataTypeNameOrThrow()); + } + var info = options.GetTypeInfo(type, pgTypeId) + ?? throw new NotSupportedException($"Reading is not supported for type '{type}'{(npgsqlDbType is null ? "" : $" and NpgsqlDbType '{npgsqlDbType}'")}"); + // Binary export has no type info so we only do caller-directed interpretation of data. + return info.Bind(new Field("?", info.PgTypeId!.Value, -1), DataFormat.Binary); - return DoRead(handler, async, cancellationToken); + PgTypeId GetRepresentationalOrDefault(string dataTypeName) + { + var type = options.DatabaseInfo.GetPostgresType(dataTypeName); + return options.ToCanonicalTypeId(type.GetRepresentationalType()); + } } /// @@ -240,7 +258,7 @@ ValueTask Read(bool async, CancellationToken cancellationToken = default) /// /// The .NET type of the column to be read. 
/// The value of the column - public T Read(NpgsqlDbType type) => Read(type, false).GetAwaiter().GetResult(); + public T Read(NpgsqlDbType type) => Read(async: false, type, CancellationToken.None).GetAwaiter().GetResult(); /// /// Reads the current column, returns its value according to and @@ -261,58 +279,76 @@ ValueTask Read(bool async, CancellationToken cancellationToken = default) public ValueTask ReadAsync(NpgsqlDbType type, CancellationToken cancellationToken = default) { using (NoSynchronizationContextScope.Enter()) - return Read(type, true, cancellationToken); + return Read(async: true, type, cancellationToken); } - ValueTask Read(NpgsqlDbType type, bool async, CancellationToken cancellationToken = default) + async ValueTask Read(bool async, NpgsqlDbType? type, CancellationToken cancellationToken) { CheckDisposed(); - if (_column == -1 || _column == NumColumns) + if (_column is BeforeRow) ThrowHelper.ThrowInvalidOperationException("Not reading a row"); - var handler = _typeHandlerCache[_column]; - if (handler == null) - handler = _typeHandlerCache[_column] = _typeMapper.ResolveByNpgsqlDbType(type); + using var registration = _connector.StartNestedCancellableOperation(cancellationToken, attemptPgCancellation: false); - return DoRead(handler, async, cancellationToken); - } + // Allow one more read if the field is a db null. + // We cannot allow endless rereads otherwise it becomes quite unclear when a column advance happens. + if (PgReader is { Resumable: true, FieldSize: -1 }) + { + await Commit(async, resumableOp: false); + return DbNullOrThrow(); + } - async ValueTask DoRead(NpgsqlTypeHandler handler, bool async, CancellationToken cancellationToken = default) - { - try + // We must commit the current column before reading the next one unless it was an IsNull call. 
+ PgConverterInfo info; + if (!PgReader.Resumable || PgReader.CurrentRemaining != PgReader.FieldSize) { - using var registration = _connector.StartNestedCancellableOperation(cancellationToken); + await Commit(async, resumableOp: false); + info = GetInfo(); - await ReadColumnLenIfNeeded(async); + // We need to get info after potential I/O as we don't know beforehand at what column we're at. + var columnLen = await ReadColumnLenIfNeeded(async, resumableOp: false); + if (_column == NumColumns) + ThrowHelper.ThrowInvalidOperationException("No more columns left in the current row"); - if (_columnLen == -1) - { -#pragma warning disable CS8653 // A default expression introduces a null value when 'T' is a non-nullable reference type. - // When T is a Nullable, we support returning null - if (NullableHandler.Exists) - return default!; -#pragma warning restore CS8653 - throw new InvalidCastException("Column is null"); - } + if (columnLen is -1) + return DbNullOrThrow(); - // If we know the entire column is already in memory, use the code path without async - var result = NullableHandler.Exists - ? _columnLen <= _buf.ReadBytesLeft - ? NullableHandler.Read(handler, _buf, _columnLen) - : await NullableHandler.ReadAsync(handler, _buf, _columnLen, async) - : _columnLen <= _buf.ReadBytesLeft - ? handler.Read(_buf, _columnLen) - : await handler.Read(_buf, _columnLen, async); - - _leftToReadInDataMsg -= _columnLen; - _columnLen = int.MinValue; // Mark that the (next) column length hasn't been read yet - _column++; - return result; } - catch (Exception e) + else + info = GetInfo(); + + T result; + if (async) + { + await PgReader.StartReadAsync(info.BufferRequirement, cancellationToken); + result = info.AsObject + ? 
(T)await info.Converter.ReadAsObjectAsync(PgReader, cancellationToken) + : await info.GetConverter().ReadAsync(PgReader, cancellationToken); + await PgReader.EndReadAsync(); + } + else { - _connector.Break(e); - throw; + PgReader.StartRead(info.BufferRequirement); + result = info.AsObject + ? (T)info.Converter.ReadAsObject(PgReader) + : info.GetConverter().Read(PgReader); + PgReader.EndRead(); + } + + return result; + + PgConverterInfo GetInfo() + { + ref var cachedInfo = ref _columnInfoCache[_column]; + return cachedInfo.IsDefault ? cachedInfo = CreateConverterInfo(typeof(T), type) : cachedInfo; + } + + T DbNullOrThrow() + { + // When T is a Nullable, we support returning null + if (default(T) is null && typeof(T).IsValueType) + return default!; + throw new InvalidCastException("Column is null"); } } @@ -323,8 +359,8 @@ public bool IsNull { get { - ReadColumnLenIfNeeded(false).GetAwaiter().GetResult(); - return _columnLen == -1; + Commit(async: false, resumableOp: true); + return ReadColumnLenIfNeeded(async: false, resumableOp: true).GetAwaiter().GetResult() is -1; } } @@ -348,26 +384,34 @@ async Task Skip(bool async, CancellationToken cancellationToken = default) using var registration = _connector.StartNestedCancellableOperation(cancellationToken); - await ReadColumnLenIfNeeded(async); - if (_columnLen != -1) - await _buf.Skip(_columnLen, async); - - _columnLen = int.MinValue; - _column++; + // We allow IsNull to have been called before skip. 
+ if (PgReader.Initialized && PgReader is not { Resumable: true, FieldSize: -1 }) + await Commit(async, resumableOp: false); + await ReadColumnLenIfNeeded(async, resumableOp: false); + await PgReader.Consume(async, cancellationToken: cancellationToken); } #endregion #region Utilities - async Task ReadColumnLenIfNeeded(bool async) + ValueTask Commit(bool async, bool resumableOp) { - if (_columnLen == int.MinValue) - { - await _buf.Ensure(4, async); - _columnLen = _buf.ReadInt32(); - _leftToReadInDataMsg -= 4; - } + var resuming = PgReader is { Initialized: true, Resumable: true } && resumableOp; + if (!resuming) + _column++; + return PgReader.Commit(async, resuming); + } + + async ValueTask ReadColumnLenIfNeeded(bool async, bool resumableOp) + { + if (PgReader is { Resumable: true, FieldSize: -1 }) + return -1; + + await _buf.Ensure(4, async); + var columnLen = _buf.ReadInt32(); + PgReader.Init(columnLen, DataFormat.Binary, resumableOp); + return PgReader.FieldSize; } void CheckDisposed() @@ -423,8 +467,10 @@ async ValueTask DisposeAsync(bool async) try { using var registration = _connector.StartNestedCancellableOperation(attemptPgCancellation: false); + // Be sure to commit the reader. 
+ await PgReader.Commit(async, resuming: false); // Finish the current CopyData message - _buf.Skip(_leftToReadInDataMsg); + await _buf.Skip(checked((int)(_endOfMessagePos - _buf.CumulativeReadPosition)), async); // Read to the end _connector.SkipUntil(BackendMessageCode.CopyDone); // We intentionally do not pass a CancellationToken since we don't want to cancel cleanup @@ -458,7 +504,6 @@ void Cleanup() _connector = null; } - _typeMapper = null; _buf = null; _isDisposed = true; } diff --git a/src/Npgsql/NpgsqlBinaryImporter.cs b/src/Npgsql/NpgsqlBinaryImporter.cs index be963c1552..a57c071448 100644 --- a/src/Npgsql/NpgsqlBinaryImporter.cs +++ b/src/Npgsql/NpgsqlBinaryImporter.cs @@ -1,6 +1,4 @@ using System; -using System.Diagnostics; -using System.Diagnostics.CodeAnalysis; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.Logging; @@ -43,6 +41,7 @@ public sealed class NpgsqlBinaryImporter : ICancelable NpgsqlParameter?[] _params; readonly ILogger _copyLogger; + PgWriter _pgWriter = null!; // Setup in Init /// /// Current timeout @@ -82,7 +81,7 @@ internal async Task Init(string copyFromCommand, bool async, CancellationToken c switch (msg.Code) { case BackendMessageCode.CopyInResponse: - copyInResponse = (CopyInResponseMessage) msg; + copyInResponse = (CopyInResponseMessage)msg; if (!copyInResponse.IsBinary) { throw _connector.Break( @@ -104,6 +103,8 @@ internal async Task Init(string copyFromCommand, bool async, CancellationToken c _rowsImported = 0; _buf.StartCopyMode(); WriteHeader(); + // Only init after header. 
+ _pgWriter = _buf.GetWriter(_connector.DatabaseInfo); } void WriteHeader() @@ -144,6 +145,7 @@ async Task StartRow(bool async, CancellationToken cancellationToken = default) await _buf.Flush(async, cancellationToken); _buf.WriteInt16(NumColumns); + _pgWriter.Refresh(); _column = 0; _rowsImported++; } @@ -239,7 +241,7 @@ Task Write(T value, NpgsqlDbType npgsqlDbType, bool async, CancellationToken if (p == null) { // First row, create the parameter objects - _params[_column] = p = typeof(T) == typeof(object) + _params[_column] = p = typeof(T) == typeof(object) || typeof(T) == typeof(DBNull) ? new NpgsqlParameter() : new NpgsqlParameter(); p.NpgsqlDbType = npgsqlDbType; @@ -309,14 +311,14 @@ async Task Write(T value, NpgsqlParameter param, bool async, CancellationToke if (_column == -1) throw new InvalidOperationException("A row hasn't been started"); - if (value == null || value is DBNull) - { - await WriteNull(async, cancellationToken); - return; - } - - if (typeof(T) == typeof(object)) + if (typeof(T) == typeof(object) || typeof(T) == typeof(DBNull)) { + if (param.GetType() != typeof(NpgsqlParameter)) + { + var newParam = _params[_column] = new NpgsqlParameter(); + newParam.NpgsqlDbType = param.NpgsqlDbType; + param = newParam; + } param.Value = value; } else @@ -329,11 +331,17 @@ async Task Write(T value, NpgsqlParameter param, bool async, CancellationToke } typedParam.TypedValue = value; } - param.ResolveHandler(_connector.TypeMapper); - param.ValidateAndGetLength(); - param.LengthCache?.Rewind(); - await param.WriteWithLength(_buf, async, cancellationToken); - param.LengthCache?.Clear(); + param.ResolveTypeInfo(_connector.SerializerOptions); + param.Bind(out _, out _); + try + { + await param.Write(async, _pgWriter.WithFlushMode(async ? 
FlushMode.NonBlocking : FlushMode.Blocking), cancellationToken); + } + catch (Exception ex) + { + _connector.Break(ex); + throw; + } _column++; } @@ -363,6 +371,7 @@ async Task WriteNull(bool async, CancellationToken cancellationToken = default) await _buf.Flush(async, cancellationToken); _buf.WriteInt32(-1); + _pgWriter.Refresh(); _column++; } @@ -465,8 +474,8 @@ async ValueTask Complete(bool async, CancellationToken cancellationToken /// /// /// Note that if hasn't been invoked before calling this, the import will be cancelled and all changes will - /// be reverted. - /// + /// be reverted. + /// /// public void Dispose() => Close(); @@ -476,8 +485,8 @@ async ValueTask Complete(bool async, CancellationToken cancellationToken /// /// /// Note that if hasn't been invoked before calling this, the import will be cancelled and all changes will - /// be reverted. - /// + /// be reverted. + /// /// public ValueTask DisposeAsync() { @@ -513,8 +522,8 @@ async Task Cancel(bool async, CancellationToken cancellationToken = default) /// /// /// Note that if hasn't been invoked before calling this, the import will be cancelled and all changes will - /// be reverted. - /// + /// be reverted. + /// /// public void Close() => CloseAsync(false).GetAwaiter().GetResult(); @@ -524,8 +533,8 @@ async Task Cancel(bool async, CancellationToken cancellationToken = default) /// /// /// Note that if hasn't been invoked before calling this, the import will be cancelled and all changes will - /// be reverted. - /// + /// be reverted. 
+ /// /// public ValueTask CloseAsync(CancellationToken cancellationToken = default) { diff --git a/src/Npgsql/NpgsqlCommand.cs b/src/Npgsql/NpgsqlCommand.cs index 6f83b8f0ad..77c192e601 100644 --- a/src/Npgsql/NpgsqlCommand.cs +++ b/src/Npgsql/NpgsqlCommand.cs @@ -17,6 +17,7 @@ using System.Threading.Channels; using Microsoft.Extensions.Logging; using Npgsql.Internal; +using Npgsql.Internal.Postgres; using Npgsql.Properties; namespace Npgsql; @@ -483,14 +484,14 @@ void DeriveParametersForFunction() throw new InvalidOperationException($"{CommandText} does not exist in pg_proc"); } - var typeMapper = c.InternalConnection!.Connector!.TypeMapper; + var serializerOptions = c.InternalConnection!.Connector!.SerializerOptions; for (var i = 0; i < types.Length; i++) { var param = new NpgsqlParameter(); - var (npgsqlDbType, postgresType) = typeMapper.GetTypeInfoByOid(types[i]); - + var postgresType = serializerOptions.DatabaseInfo.GetPostgresType(types[i]); + var npgsqlDbType = postgresType.DataTypeName.ToNpgsqlDbType(); param.DataTypeName = postgresType.DisplayName; param.PostgresType = postgresType; if (npgsqlDbType.HasValue) @@ -560,8 +561,9 @@ void DeriveParametersForQuery(NpgsqlConnector connector) var param = batchCommand.PositionalParameters[i]; var paramOid = paramTypeOIDs[i]; - var (npgsqlDbType, postgresType) = connector.TypeMapper.GetTypeInfoByOid(paramOid); - + var postgresType = connector.SerializerOptions.DatabaseInfo.GetPostgresType(paramOid); + // We want to keep any domain types visible on the parameter, it will internally do a representational lookup again if necessary. + var npgsqlDbType = postgresType.GetRepresentationalType().DataTypeName.ToNpgsqlDbType(); if (param.NpgsqlDbType != NpgsqlDbType.Unknown && param.NpgsqlDbType != npgsqlDbType) throw new NpgsqlException( "The backend parser inferred different types for parameters with the same name. 
Please try explicit casting within your SQL statement or batch or use different placeholder names."); @@ -649,7 +651,7 @@ Task Prepare(bool async, CancellationToken cancellationToken = default) { foreach (var batchCommand in InternalBatchCommands) { - batchCommand.Parameters.ProcessParameters(connector.TypeMapper, validateValues: false, CommandType); + batchCommand.Parameters.ProcessParameters(connector.SerializerOptions, validateValues: false, CommandType); ProcessRawQuery(connector.SqlQueryParser, connector.UseConformingStrings, batchCommand); needToPrepare = batchCommand.ExplicitPrepare(connector) || needToPrepare; @@ -660,7 +662,7 @@ Task Prepare(bool async, CancellationToken cancellationToken = default) } else { - Parameters.ProcessParameters(connector.TypeMapper, validateValues: false, CommandType); + Parameters.ProcessParameters(connector.SerializerOptions, validateValues: false, CommandType); ProcessRawQuery(connector.SqlQueryParser, connector.UseConformingStrings, batchCommand: null); foreach (var batchCommand in InternalBatchCommands) @@ -1346,7 +1348,6 @@ internal virtual async ValueTask ExecuteReader(CommandBehavior { if (connector is not null) { - var dataSource = connector.DataSource; var logger = connector.CommandLogger; cancellationToken.ThrowIfCancellationRequested(); @@ -1378,7 +1379,7 @@ internal virtual async ValueTask ExecuteReader(CommandBehavior goto case false; } - batchCommand.Parameters.ProcessParameters(dataSource.TypeMapper, validateParameterValues, CommandType); + batchCommand.Parameters.ProcessParameters(connector.SerializerOptions, validateParameterValues, CommandType); } } else @@ -1391,7 +1392,7 @@ internal virtual async ValueTask ExecuteReader(CommandBehavior ResetPreparation(); goto case false; } - Parameters.ProcessParameters(dataSource.TypeMapper, validateParameterValues, CommandType); + Parameters.ProcessParameters(connector.SerializerOptions, validateParameterValues, CommandType); } 
NpgsqlEventSource.Log.CommandStartPrepared(); @@ -1407,7 +1408,7 @@ internal virtual async ValueTask ExecuteReader(CommandBehavior { var batchCommand = InternalBatchCommands[i]; - batchCommand.Parameters.ProcessParameters(dataSource.TypeMapper, validateParameterValues, CommandType); + batchCommand.Parameters.ProcessParameters(connector.SerializerOptions, validateParameterValues, CommandType); ProcessRawQuery(connector.SqlQueryParser, connector.UseConformingStrings, batchCommand); if (connector.Settings.MaxAutoPrepare > 0 && batchCommand.TryAutoPrepare(connector)) @@ -1419,7 +1420,7 @@ internal virtual async ValueTask ExecuteReader(CommandBehavior } else { - Parameters.ProcessParameters(dataSource.TypeMapper, validateParameterValues, CommandType); + Parameters.ProcessParameters(connector.SerializerOptions, validateParameterValues, CommandType); ProcessRawQuery(connector.SqlQueryParser, connector.UseConformingStrings, batchCommand: null); if (connector.Settings.MaxAutoPrepare > 0) @@ -1513,13 +1514,13 @@ internal virtual async ValueTask ExecuteReader(CommandBehavior { foreach (var batchCommand in InternalBatchCommands) { - batchCommand.Parameters.ProcessParameters(dataSource.TypeMapper, validateValues: true, CommandType); + batchCommand.Parameters.ProcessParameters(dataSource.SerializerOptions, validateValues: true, CommandType); ProcessRawQuery(null, standardConformingStrings: true, batchCommand); } } else { - Parameters.ProcessParameters(dataSource.TypeMapper, validateValues: true, CommandType); + Parameters.ProcessParameters(dataSource.SerializerOptions, validateValues: true, CommandType); ProcessRawQuery(null, standardConformingStrings: true, batchCommand: null); } @@ -1733,10 +1734,9 @@ internal void FixupRowDescription(RowDescriptionMessage rowDescription, bool isF for (var i = 0; i < rowDescription.Count; i++) { var field = rowDescription[i]; - field.FormatCode = (UnknownResultTypeList == null || !isFirst ? 
AllResultTypesAreUnknown : UnknownResultTypeList[i]) - ? FormatCode.Text - : FormatCode.Binary; - field.ResolveHandler(); + field.DataFormat = (UnknownResultTypeList == null || !isFirst ? AllResultTypesAreUnknown : UnknownResultTypeList[i]) + ? DataFormat.Text + : DataFormat.Binary; } } @@ -1818,7 +1818,11 @@ public virtual NpgsqlCommand Clone() { var clone = new NpgsqlCommand(CommandText, InternalConnection, Transaction) { - CommandTimeout = CommandTimeout, CommandType = CommandType, DesignTimeVisible = DesignTimeVisible, _allResultTypesAreUnknown = _allResultTypesAreUnknown, _unknownResultTypeList = _unknownResultTypeList + CommandTimeout = CommandTimeout, + CommandType = CommandType, + DesignTimeVisible = DesignTimeVisible, + _allResultTypesAreUnknown = _allResultTypesAreUnknown, + _unknownResultTypeList = _unknownResultTypeList }; _parameters.CloneTo(clone._parameters); return clone; diff --git a/src/Npgsql/NpgsqlConnection.cs b/src/Npgsql/NpgsqlConnection.cs index 627dcb1443..53e2afe5b0 100644 --- a/src/Npgsql/NpgsqlConnection.cs +++ b/src/Npgsql/NpgsqlConnection.cs @@ -884,7 +884,7 @@ async Task CloseAsync(bool async) } } - Debug.Assert(connector.IsReady || connector.IsBroken); + Debug.Assert(connector.IsReady || connector.IsBroken, $"Connector is not ready or broken during close, it's {connector.State}"); Debug.Assert(connector.CurrentReader == null); Debug.Assert(connector.CurrentCopyOperation == null); diff --git a/src/Npgsql/NpgsqlConnectionStringBuilder.cs b/src/Npgsql/NpgsqlConnectionStringBuilder.cs index b927807844..c1a11c34c3 100644 --- a/src/Npgsql/NpgsqlConnectionStringBuilder.cs +++ b/src/Npgsql/NpgsqlConnectionStringBuilder.cs @@ -9,7 +9,6 @@ using System.Linq; using Npgsql.Internal; using Npgsql.Netstandard20; -using Npgsql.Properties; using Npgsql.Replication; namespace Npgsql; diff --git a/src/Npgsql/NpgsqlDataReader.cs b/src/Npgsql/NpgsqlDataReader.cs index db30da551c..cc86de063e 100644 --- a/src/Npgsql/NpgsqlDataReader.cs +++ 
b/src/Npgsql/NpgsqlDataReader.cs @@ -1,4 +1,5 @@ using System; +using System.Buffers; using System.Collections; using System.Collections.Generic; using System.Collections.ObjectModel; @@ -10,17 +11,14 @@ using System.Linq; using System.Runtime.CompilerServices; using System.Runtime.ExceptionServices; -using System.Text; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.Logging; using Npgsql.BackendMessages; using Npgsql.Internal; -using Npgsql.Internal.TypeHandlers; -using Npgsql.Internal.TypeHandling; +using Npgsql.Internal.Converters; using Npgsql.PostgresTypes; using Npgsql.Schema; -using Npgsql.Util; using NpgsqlTypes; using static Npgsql.Util.Statics; @@ -34,6 +32,9 @@ namespace Npgsql; public sealed class NpgsqlDataReader : DbDataReader, IDbColumnSchemaGenerator #pragma warning restore CA1010 { + static readonly Task TrueTask = Task.FromResult(true); + static readonly Task FalseTask = Task.FromResult(false); + internal NpgsqlCommand Command { get; private set; } = default!; internal NpgsqlConnector Connector { get; } NpgsqlConnection? _connection; @@ -52,6 +53,7 @@ public sealed class NpgsqlDataReader : DbDataReader, IDbColumnSchemaGenerator internal ReaderState State = ReaderState.Disposed; internal NpgsqlReadBuffer Buffer = default!; + PgReader PgReader => Buffer.PgReader; /// /// Holds the list of statements being executed by this reader. @@ -81,14 +83,6 @@ public sealed class NpgsqlDataReader : DbDataReader, IDbColumnSchemaGenerator /// int _column; - /// - /// For streaming types (e.g. bytea), holds the byte length of the column. - /// Does not include the length prefix. - /// - internal int ColumnLen; - - internal int PosInColumn; - /// /// The position in the buffer at which the current data row message ends. /// Used only when the row is consumed non-sequentially. 
@@ -102,13 +96,16 @@ public sealed class NpgsqlDataReader : DbDataReader, IDbColumnSchemaGenerator /// bool _canConsumeRowNonSequentially; - int _charPos; - /// /// The RowDescription message for the current resultset being processed /// internal RowDescriptionMessage? RowDescription; + /// + /// Stores the last converter info resolved by column, to speed up repeated reading. + /// + PgConverterInfo[]? ColumnInfoCache { get; set; } + ulong? _recordsAffected; /// @@ -124,17 +121,6 @@ public sealed class NpgsqlDataReader : DbDataReader, IDbColumnSchemaGenerator bool _isSchemaOnly; bool _isSequential; - /// - /// A stream that has been opened on a column. - /// - NpgsqlReadBuffer.ColumnStream? _columnStream; - - /// - /// Used to keep track of every unique row this reader object ever traverses. - /// This is used to detect whether nested DbDataReaders are still valid. - /// - internal ulong UniqueRowId; - internal NpgsqlNestedDataReader? CachedFreeNestedDataReader; long _startTimestamp; @@ -153,6 +139,7 @@ internal void Init( long startTimestamp = 0, Task? sendTask = null) { + Debug.Assert(ColumnInfoCache is null); Command = command; _connection = command.InternalConnection; _behavior = behavior; @@ -179,7 +166,6 @@ public override bool Read() { CheckClosedOrDisposed(); - UniqueRowId++; var fastRead = TryFastRead(); return fastRead.HasValue ? fastRead.Value @@ -197,10 +183,9 @@ public override Task ReadAsync(CancellationToken cancellationToken) { CheckClosedOrDisposed(); - UniqueRowId++; var fastRead = TryFastRead(); if (fastRead.HasValue) - return fastRead.Value ? PGUtil.TrueTask : PGUtil.FalseTask; + return fastRead.Value ? 
TrueTask : FalseTask; using (NoSynchronizationContextScope.Enter()) return Read(true, cancellationToken); @@ -252,8 +237,7 @@ public override Task ReadAsync(CancellationToken cancellationToken) async Task Read(bool async, CancellationToken cancellationToken = default) { - var registration = Connector.StartNestedCancellableOperation(cancellationToken); - + using var registration = Connector.StartNestedCancellableOperation(cancellationToken); try { switch (State) @@ -304,13 +288,11 @@ async Task Read(bool async, CancellationToken cancellationToken = default) } catch { - State = ReaderState.Consumed; + // Break may have progressed the reader already. + if (State is not ReaderState.Closed) + State = ReaderState.Consumed; throw; } - finally - { - registration.Dispose(); - } } ValueTask ReadMessage(bool async) @@ -387,7 +369,11 @@ async Task NextResult(bool async, bool isConsuming = false, CancellationTo case BackendMessageCode.EmptyQueryResponse: ProcessMessage(completedMsg); - if (_statements[StatementIndex].AppendErrorBarrier ?? Command.EnableErrorBarriers) + var statement = _statements[StatementIndex]; + if (statement.IsPrepared && ColumnInfoCache is not null) + RowDescription!.SetConverterInfoCache(new(ColumnInfoCache, 0, _numColumns)); + + if (statement.AppendErrorBarrier ?? 
Command.EnableErrorBarriers) Expect(await Connector.ReadMessage(async), Connector); break; @@ -402,8 +388,11 @@ async Task NextResult(bool async, bool isConsuming = false, CancellationTo break; case ReaderState.BetweenResults: + { + if (StatementIndex >= 0 && _statements[StatementIndex].IsPrepared && ColumnInfoCache is not null) + RowDescription!.SetConverterInfoCache(new(ColumnInfoCache, 0, _numColumns)); break; - + } case ReaderState.Consumed: case ReaderState.Closed: case ReaderState.Disposed: @@ -474,7 +463,20 @@ async Task NextResult(bool async, bool isConsuming = false, CancellationTo }; } - if (RowDescription == null) + if (RowDescription is not null) + { + if (ColumnInfoCache?.Length >= RowDescription.Count) + Array.Clear(ColumnInfoCache, 0, RowDescription.Count); + else + { + if (ColumnInfoCache is { } cache) + ArrayPool.Shared.Return(cache, clearArray: true); + ColumnInfoCache = ArrayPool.Shared.Rent(RowDescription.Count); + } + if (statement.IsPrepared) + RowDescription.LoadConverterInfoCache(ColumnInfoCache); + } + else { // Statement did not generate a resultset (e.g. INSERT) // Read and process its completion message and move on to the next statement @@ -605,7 +607,9 @@ async Task NextResult(bool async, bool isConsuming = false, CancellationTo } } - State = ReaderState.Consumed; + // Break may have progressed the reader already. 
+ if (State is not ReaderState.Closed) + State = ReaderState.Consumed; throw; } } @@ -648,12 +652,11 @@ void PopulateOutputParameters() p.Value = pending.Dequeue(); } + PgReader.Commit(async: false, resuming: false).GetAwaiter().GetResult(); State = ReaderState.BeforeResult; // Set the state back Buffer.ReadPosition = currentPosition; // Restore position _column = -1; - ColumnLen = -1; - PosInColumn = 0; } /// @@ -739,18 +742,29 @@ async Task NextResultSchemaOnly(bool async, bool isConsuming = false, Canc } // Found a resultset - if (RowDescription != null) + if (RowDescription is not null) + { + if (ColumnInfoCache?.Length >= RowDescription.Count) + Array.Clear(ColumnInfoCache, 0, RowDescription.Count); + else + { + if (ColumnInfoCache is { } cache) + ArrayPool.Shared.Return(cache, clearArray: true); + ColumnInfoCache = ArrayPool.Shared.Rent(RowDescription.Count); + } return true; + } } - RowDescription = null; State = ReaderState.Consumed; - + RowDescription = null; return false; } catch (Exception e) { - State = ReaderState.Consumed; + // Break may have progressed the reader already. + if (State is not ReaderState.Closed) + State = ReaderState.Consumed; // Reference the triggering statement from the exception if (e is PostgresException postgresException && StatementIndex >= 0 && StatementIndex < _statements.Count) @@ -832,12 +846,11 @@ void ProcessDataRowMessage(DataRowMessage msg) // recapture the connector's buffer on each new DataRow. 
// Note that this can happen even in sequential mode, if the row description message is big // (see #2003) - Buffer = Connector.ReadBuffer; + if (!ReferenceEquals(Buffer, Connector.ReadBuffer)) + Buffer = Connector.ReadBuffer; _hasRows = true; _column = -1; - ColumnLen = -1; - PosInColumn = 0; // We assume that the row's number of columns is identical to the description's _numColumns = Buffer.ReadInt16(); @@ -1115,7 +1128,7 @@ internal async Task Close(bool connectionClosing, bool async, bool isDisposing) { await Consume(async); } - catch (Exception ex) when (ex is OperationCanceledException or NpgsqlException { InnerException : TimeoutException }) + catch (Exception ex) when (ex is OperationCanceledException or NpgsqlException { InnerException: TimeoutException }) { // Timeout/cancellation - completely normal, consume has basically completed. } @@ -1187,6 +1200,12 @@ internal async Task Cleanup(bool async, bool connectionClosing = false, bool isD } } + if (ColumnInfoCache is { } cache) + { + ColumnInfoCache = null; + ArrayPool.Shared.Return(cache, clearArray: true); + } + State = ReaderState.Closed; Command.State = CommandState.Idle; Connector.CurrentReader = null; @@ -1238,84 +1257,84 @@ internal async Task Cleanup(bool async, bool connectionClosing = false, bool isD /// /// The zero-based column ordinal. /// The value of the specified column. - public override bool GetBoolean(int ordinal) => GetFieldValue(ordinal); + public override bool GetBoolean(int ordinal) => GetFieldValueCore(ordinal); /// /// Gets the value of the specified column as a byte. /// /// The zero-based column ordinal. /// The value of the specified column. - public override byte GetByte(int ordinal) => GetFieldValue(ordinal); + public override byte GetByte(int ordinal) => GetFieldValueCore(ordinal); /// /// Gets the value of the specified column as a single character. /// /// The zero-based column ordinal. /// The value of the specified column. 
- public override char GetChar(int ordinal) => GetFieldValue(ordinal); + public override char GetChar(int ordinal) => GetFieldValueCore(ordinal); /// /// Gets the value of the specified column as a 16-bit signed integer. /// /// The zero-based column ordinal. /// The value of the specified column. - public override short GetInt16(int ordinal) => GetFieldValue(ordinal); + public override short GetInt16(int ordinal) => GetFieldValueCore(ordinal); /// /// Gets the value of the specified column as a 32-bit signed integer. /// /// The zero-based column ordinal. /// The value of the specified column. - public override int GetInt32(int ordinal) => GetFieldValue(ordinal); + public override int GetInt32(int ordinal) => GetFieldValueCore(ordinal); /// /// Gets the value of the specified column as a 64-bit signed integer. /// /// The zero-based column ordinal. /// The value of the specified column. - public override long GetInt64(int ordinal) => GetFieldValue(ordinal); + public override long GetInt64(int ordinal) => GetFieldValueCore(ordinal); /// /// Gets the value of the specified column as a object. /// /// The zero-based column ordinal. /// The value of the specified column. - public override DateTime GetDateTime(int ordinal) => GetFieldValue(ordinal); + public override DateTime GetDateTime(int ordinal) => GetFieldValueCore(ordinal); /// /// Gets the value of the specified column as an instance of . /// /// The zero-based column ordinal. /// The value of the specified column. - public override string GetString(int ordinal) => GetFieldValue(ordinal); + public override string GetString(int ordinal) => GetFieldValueCore(ordinal); /// /// Gets the value of the specified column as a object. /// /// The zero-based column ordinal. /// The value of the specified column. 
- public override decimal GetDecimal(int ordinal) => GetFieldValue(ordinal); + public override decimal GetDecimal(int ordinal) => GetFieldValueCore(ordinal); /// /// Gets the value of the specified column as a double-precision floating point number. /// /// The zero-based column ordinal. /// The value of the specified column. - public override double GetDouble(int ordinal) => GetFieldValue(ordinal); + public override double GetDouble(int ordinal) => GetFieldValueCore(ordinal); /// /// Gets the value of the specified column as a single-precision floating point number. /// /// The zero-based column ordinal. /// The value of the specified column. - public override float GetFloat(int ordinal) => GetFieldValue(ordinal); + public override float GetFloat(int ordinal) => GetFieldValueCore(ordinal); /// /// Gets the value of the specified column as a globally-unique identifier (GUID). /// /// The zero-based column ordinal. /// The value of the specified column. - public override Guid GetGuid(int ordinal) => GetFieldValue(ordinal); + public override Guid GetGuid(int ordinal) => GetFieldValueCore(ordinal); /// /// Populates an array of objects with the column values of the current row. @@ -1356,7 +1375,7 @@ public override int GetValues(object[] values) /// /// The zero-based column ordinal. /// The value of the specified column. - public TimeSpan GetTimeSpan(int ordinal) => GetFieldValue(ordinal); + public TimeSpan GetTimeSpan(int ordinal) => GetFieldValueCore(ordinal); /// protected override DbDataReader GetDbDataReader(int ordinal) => GetData(ordinal); @@ -1370,30 +1389,33 @@ public override int GetValues(object[] values) /// A data reader. public new NpgsqlNestedDataReader GetData(int ordinal) { + if (_isSequential) + throw new NotSupportedException("GetData() not supported in sequential mode."); + var field = CheckRowAndGetField(ordinal); var type = field.PostgresType; var isArray = type is PostgresArrayType; var elementType = isArray ? 
((PostgresArrayType)type).Element : type; var compositeType = elementType as PostgresCompositeType; - if (elementType.InternalName != "record" && compositeType == null) + if (field.DataFormat is DataFormat.Text || (elementType.InternalName != "record" && compositeType == null)) throw new InvalidCastException("GetData() not supported for type " + field.TypeDisplayName); - SeekToColumn(ordinal, false).GetAwaiter().GetResult(); - if (ColumnLen == -1) + var columnLength = SeekToColumn(async: false, ordinal, field, resumableOp: true).GetAwaiter().GetResult(); + if (columnLength is -1) ThrowHelper.ThrowInvalidCastException_NoValue(field); - if (_isSequential) - throw new NotSupportedException("GetData() not supported in sequential mode."); + if (PgReader.FieldOffset > 0) + PgReader.Rewind(PgReader.FieldOffset); var reader = CachedFreeNestedDataReader; if (reader != null) { CachedFreeNestedDataReader = null; - reader.Init(UniqueRowId, compositeType); + reader.Init(compositeType); } else { - reader = new NpgsqlNestedDataReader(this, null, UniqueRowId, 1, compositeType); + reader = new NpgsqlNestedDataReader(this, null, 1, compositeType); } if (isArray) reader.InitArray(); @@ -1425,34 +1447,22 @@ public override long GetBytes(int ordinal, long dataOffset, byte[]? 
buffer, int throw new IndexOutOfRangeException($"length must be between 0 and {buffer.Length - bufferOffset}"); var field = CheckRowAndGetField(ordinal); - var handler = field.Handler; - if (!(handler is ByteaHandler)) - throw new InvalidCastException("GetBytes() not supported for type " + field.Name); - - SeekToColumn(ordinal, false).GetAwaiter().GetResult(); - if (ColumnLen is -1) + var columnLength = SeekToColumn(async: false, ordinal, field, resumableOp: true).GetAwaiter().GetResult(); + if (columnLength == -1) ThrowHelper.ThrowInvalidCastException_NoValue(field); if (buffer is null) - return ColumnLen; - - var dataOffset2 = (int)dataOffset; - SeekInColumn(dataOffset2, false).GetAwaiter().GetResult(); + return columnLength; - // Attempt to read beyond the end of the column - if (dataOffset2 + length > ColumnLen) - length = Math.Max(ColumnLen - dataOffset2, 0); - - var left = length; - while (left > 0) - { - var read = Buffer.Read(new Span(buffer, bufferOffset, left)); - bufferOffset += read; - left -= read; - } + // Move to offset + if (_isSequential && PgReader.FieldOffset > dataOffset) + ThrowHelper.ThrowInvalidOperationException("Attempt to read a position in the column which has already been read"); - PosInColumn += length; + PgReader.Seek((int)dataOffset); + // At offset, read into buffer. + length = Math.Min(length, PgReader.FieldRemaining); + PgReader.ReadBytes(new Span(buffer, bufferOffset, length)); return length; } @@ -1461,7 +1471,8 @@ public override long GetBytes(int ordinal, long dataOffset, byte[]? buffer, int /// /// The zero-based column ordinal. /// The returned object. - public override Stream GetStream(int ordinal) => GetStream(ordinal, false).Result; + public override Stream GetStream(int ordinal) + => GetFieldValueCore(ordinal); /// /// Retrieves data as a . @@ -1472,31 +1483,7 @@ public override long GetBytes(int ordinal, long dataOffset, byte[]? buffer, int /// /// The returned object. 
public Task GetStreamAsync(int ordinal, CancellationToken cancellationToken = default) - { - using (NoSynchronizationContextScope.Enter()) - return GetStream(ordinal, true, cancellationToken).AsTask(); - } - - ValueTask GetStream(int ordinal, bool async, CancellationToken cancellationToken = default) => - GetStreamInternal(CheckRowAndGetField(ordinal), ordinal, async, cancellationToken); - - async ValueTask GetStreamInternal(FieldDescription field, int ordinal, bool async, CancellationToken cancellationToken = default) - { - if (_columnStream is { IsDisposed: false }) - ThrowHelper.ThrowInvalidOperationException("A stream is already open for this reader"); - - using var registration = Connector.StartNestedCancellableOperation(cancellationToken, attemptPgCancellation: false); - - await SeekToColumn(ordinal, async, cancellationToken); - if (_isSequential) - CheckColumnStart(); - - if (ColumnLen == -1) - ThrowHelper.ThrowInvalidCastException_NoValue(field); - - PosInColumn += ColumnLen; - return _columnStream = (NpgsqlReadBuffer.ColumnStream)Buffer.GetStream(ColumnLen, !_isSequential); - } + => GetFieldValueAsync(ordinal, cancellationToken); #endregion @@ -1520,96 +1507,30 @@ public override long GetChars(int ordinal, long dataOffset, char[]? buffer, int if (buffer != null && (length < 0 || length > buffer.Length - bufferOffset)) throw new IndexOutOfRangeException($"length must be between 0 and {buffer.Length - bufferOffset}"); - var field = CheckRowAndGetField(ordinal); - var handler = field.Handler as TextHandler; - if (handler == null) - throw new InvalidCastException("The GetChars method is not supported for type " + field.Name); + // Check whether we can do resumable reads. 
+ var field = GetInfo(ordinal, typeof(GetChars), out var converter, out var bufferRequirement, out var asObject); + if (converter is not IResumableRead { Supported: true }) + throw new NotSupportedException("The GetChars method is not supported for this column type"); - SeekToColumn(ordinal, false).GetAwaiter().GetResult(); - if (ColumnLen == -1) + var columnLength = SeekToColumn(async: false, ordinal, field, resumableOp: true).GetAwaiter().GetResult(); + if (columnLength == -1) ThrowHelper.ThrowInvalidCastException_NoValue(field); - if (PosInColumn == 0) - _charPos = 0; - - var decoder = Buffer.TextEncoding.GetDecoder(); + dataOffset = buffer is null ? 0 : dataOffset; + PgReader.InitCharsRead(checked((int)dataOffset), + buffer is not null ? new ArraySegment(buffer, bufferOffset, length) : (ArraySegment?)null, + out var previousDataOffset); - if (buffer == null) - { - // Note: Getting the length of a text column means decoding the entire field, - // very inefficient and also consumes the column in sequential mode. But this seems to - // be SqlClient's behavior as well. - var (bytesSkipped, charsSkipped) = SkipChars(decoder, int.MaxValue, ColumnLen - PosInColumn); - Debug.Assert(bytesSkipped == ColumnLen - PosInColumn); - PosInColumn += bytesSkipped; - _charPos += charsSkipped; - return _charPos; - } - - if (PosInColumn == ColumnLen || dataOffset < _charPos) - { - // Either the column has already been read (e.g. GetString()) or a previous GetChars() - // has positioned us in the column *after* the requested read start offset. 
Seek back - // (this will throw for sequential) - SeekInColumn(0, false).GetAwaiter().GetResult(); - _charPos = 0; - } - - if (dataOffset > _charPos) - { - var charsToSkip = (int)dataOffset - _charPos; - var (bytesSkipped, charsSkipped) = SkipChars(decoder, charsToSkip, ColumnLen - PosInColumn); - decoder.Reset(); - PosInColumn += bytesSkipped; - _charPos += charsSkipped; - if (charsSkipped < charsToSkip) // data offset is beyond the column's end - return 0; - } - - // We're now positioned at the start of the segment of characters we need to read. - if (length == 0) - return 0; - - var (bytesRead, charsRead) = DecodeChars(decoder, buffer.AsSpan(bufferOffset, length), ColumnLen - PosInColumn); - - PosInColumn += bytesRead; - _charPos += charsRead; - return charsRead; - } - - (int BytesRead, int CharsRead) DecodeChars(Decoder decoder, Span output, int byteCount) - { - var (bytesRead, charsRead) = (0, 0); - var outputLength = output.Length; - - while (true) - { - Buffer.Ensure(1); // Make sure we have at least some data - var maxBytes = Math.Min(byteCount - bytesRead, Buffer.ReadBytesLeft); - var bytes = Buffer.Buffer.AsSpan(Buffer.ReadPosition, maxBytes); - decoder.Convert(bytes, output, false, out var bytesUsed, out var charsUsed, out _); - Buffer.ReadPosition += bytesUsed; - bytesRead += bytesUsed; - charsRead += charsUsed; - if (charsRead == outputLength || bytesRead == byteCount) - break; - output = output.Slice(charsUsed); - } + if (_isSequential && previousDataOffset > dataOffset) + ThrowHelper.ThrowInvalidOperationException("Attempt to read a position in the column which has already been read"); - return (bytesRead, charsRead); - } - - internal (int BytesSkipped, int CharsSkipped) SkipChars(Decoder decoder, int charCount, int byteCount) - { - Span tempCharBuf = stackalloc char[512]; - var (charsSkipped, bytesSkipped) = (0, 0); - while (charsSkipped < charCount && bytesSkipped < byteCount) - { - var (bytesRead, charsRead) = DecodeChars(decoder, 
tempCharBuf.Slice(0, Math.Min(charCount, tempCharBuf.Length)), byteCount); - bytesSkipped += bytesRead; - charsSkipped += charsRead; - } - return (bytesSkipped, charsSkipped); + PgReader.StartRead(bufferRequirement); + var result = asObject + ? (GetChars)converter.ReadAsObject(PgReader) + : ((PgConverter)converter).Read(PgReader); + PgReader.AdvanceCharsRead(result.Read); + PgReader.EndRead(); + return result.Read; } /// @@ -1618,7 +1539,7 @@ public override long GetChars(int ordinal, long dataOffset, char[]? buffer, int /// The zero-based column ordinal. /// The returned object. public override TextReader GetTextReader(int ordinal) - => GetTextReader(ordinal, false).Result; + => GetFieldValueCore(ordinal); /// /// Retrieves data as a . @@ -1629,25 +1550,7 @@ public override TextReader GetTextReader(int ordinal) /// /// The returned object. public Task GetTextReaderAsync(int ordinal, CancellationToken cancellationToken = default) - { - using (NoSynchronizationContextScope.Enter()) - return GetTextReader(ordinal, true, cancellationToken).AsTask(); - } - - async ValueTask GetTextReader(int ordinal, bool async, CancellationToken cancellationToken = default) - { - var field = CheckRowAndGetField(ordinal); - - if (field.Handler is ITextReaderHandler handler) - { - var stream = async - ? 
await GetStreamInternal(field, ordinal, true, cancellationToken) - : GetStreamInternal(field, ordinal, false, CancellationToken.None).Result; - return handler.GetTextReader(stream, Buffer); - } - - throw new InvalidCastException($"The GetTextReader method is not supported for type {field.PostgresType.DisplayName}"); - } + => GetFieldValueAsync(ordinal, cancellationToken); #endregion @@ -1664,18 +1567,40 @@ async ValueTask GetTextReader(int ordinal, bool async, CancellationT /// public override Task GetFieldValueAsync(int ordinal, CancellationToken cancellationToken) { - if (typeof(T) == typeof(Stream)) - return (Task)(object)GetStreamAsync(ordinal, cancellationToken); - - if (typeof(T) == typeof(TextReader)) - return (Task)(object)GetTextReaderAsync(ordinal, cancellationToken); - // In non-sequential, we know that the column is already buffered - no I/O will take place if (!_isSequential) - return Task.FromResult(GetFieldValue(ordinal)); + return Task.FromResult(GetFieldValueCore(ordinal)); using (NoSynchronizationContextScope.Enter()) - return GetFieldValueSequential(ordinal, true, cancellationToken).AsTask(); + return Core(ordinal, cancellationToken).AsTask(); + + async ValueTask Core(int ordinal, CancellationToken cancellationToken) + { + using var registration = Connector.StartNestedCancellableOperation(cancellationToken, attemptPgCancellation: false); + var isStream = typeof(T) == typeof(Stream); + var field = GetInfo(ordinal, isStream ? null : typeof(T), out var converter, out var bufferRequirement, out var asObject); + + var columnLength = await SeekToColumn(async: true, ordinal, field); + if (columnLength == -1) + return DbNullValueOrThrow(field); + + if (isStream || typeof(T) == typeof(TextReader)) + { + PgReader.ThrowIfStreamActive(); + + // The only statically mapped converter, it always exists. 
+ if (isStream) + return (T)(object)PgReader.GetStream(canSeek: !_isSequential); + } + + Debug.Assert(asObject || converter is PgConverter); + await PgReader.StartReadAsync(bufferRequirement, cancellationToken); + var result = asObject + ? (T)await converter.ReadAsObjectAsync(PgReader, cancellationToken) + : await Unsafe.As>(converter).ReadAsync(PgReader, cancellationToken); + await PgReader.EndReadAsync(); + return result; + } } /// @@ -1684,93 +1609,40 @@ public override Task GetFieldValueAsync(int ordinal, CancellationToken can /// Synchronously gets the value of the specified column as a type. /// The column to be retrieved. /// The column to be retrieved. - public override T GetFieldValue(int ordinal) - { - if (typeof(T) == typeof(Stream)) - return (T)(object)GetStream(ordinal); + public override T GetFieldValue(int ordinal) => GetFieldValueCore(ordinal); - if (typeof(T) == typeof(TextReader)) - return (T)(object)GetTextReader(ordinal); - - if (_isSequential) - return GetFieldValueSequential(ordinal, false).GetAwaiter().GetResult(); - - // In non-sequential, we know that the column is already buffered - no I/O will take place - - var field = CheckRowAndGetField(ordinal); - SeekToColumnNonSequential(ordinal); - - if (ColumnLen == -1) - { - // When T is a Nullable (and only in that case), we support returning null - if (NullableHandler.Exists) - return default!; - - if (typeof(T) == typeof(object)) - return (T)(object)DBNull.Value; - - ThrowHelper.ThrowInvalidCastException_NoValue(field); - } - - // We don't handle exceptions or update PosInColumn - // As with non-sequential reads we always just move to the start/end of the column - return NullableHandler.Exists - ? NullableHandler.Read(field.Handler, Buffer, ColumnLen, field) - : typeof(T) == typeof(object) - ? 
(T)field.Handler.ReadAsObject(Buffer, ColumnLen, field) - : field.Handler.Read(Buffer, ColumnLen, field); - } - - async ValueTask GetFieldValueSequential(int column, bool async, CancellationToken cancellationToken = default) + T GetFieldValueCore(int ordinal) { - using var registration = Connector.StartNestedCancellableOperation(cancellationToken, attemptPgCancellation: false); - - var field = CheckRowAndGetField(column); - await SeekToColumnSequential(column, async, CancellationToken.None); - CheckColumnStart(); - - if (ColumnLen == -1) - { - // When T is a Nullable (and only in that case), we support returning null - if (NullableHandler.Exists) - return default!; + // The only statically mapped converter, it always exists. + if (typeof(T) == typeof(Stream)) + return GetStream(); - if (typeof(T) == typeof(object)) - return (T)(object)DBNull.Value; + var field = GetInfo(ordinal, typeof(T), out var converter, out var bufferRequirement, out var asObject); - ThrowHelper.ThrowInvalidCastException_NoValue(field); - } + if (typeof(T) == typeof(TextReader)) + PgReader.ThrowIfStreamActive(); + + var columnLength = SeekToColumn(async: false, ordinal, field).GetAwaiter().GetResult(); + if (columnLength == -1) + return DbNullValueOrThrow(field); + + Debug.Assert(asObject || converter is PgConverter); + PgReader.StartRead(bufferRequirement); + var result = asObject + ? (T)converter.ReadAsObject(PgReader) + : Unsafe.As>(converter).Read(PgReader); + PgReader.EndRead(); + return result; - var position = Buffer.ReadPosition; - try - { - return NullableHandler.Exists - ? ColumnLen <= Buffer.ReadBytesLeft - ? NullableHandler.Read(field.Handler, Buffer, ColumnLen, field) - : await NullableHandler.ReadAsync(field.Handler, Buffer, ColumnLen, async, field) - : typeof(T) == typeof(object) - ? ColumnLen <= Buffer.ReadBytesLeft - ? 
(T)field.Handler.ReadAsObject(Buffer, ColumnLen, field) - : (T)await field.Handler.ReadAsObject(Buffer, ColumnLen, async, field) - : ColumnLen <= Buffer.ReadBytesLeft - ? field.Handler.Read(Buffer, ColumnLen, field) - : await field.Handler.Read(Buffer, ColumnLen, async, field); - } - catch + [MethodImpl(MethodImplOptions.NoInlining)] + T GetStream() { - if (Connector.State != ConnectorState.Broken) - { - var writtenBytes = Buffer.ReadPosition - position; - var remainingBytes = ColumnLen - writtenBytes; - if (remainingBytes > 0) - await Buffer.Skip(remainingBytes, async); - } - throw; - } - finally - { - // Important: position must still be updated - PosInColumn += ColumnLen; + var field = GetInfo(ordinal, null, out _, out _, out _); + PgReader.ThrowIfStreamActive(); + var columnLength = SeekToColumn(async: false, ordinal, field).GetAwaiter().GetResult(); + if (columnLength == -1) + return DbNullValueOrThrow(field); + return (T)(object)PgReader.GetStream(canSeek: !_isSequential); } } @@ -1785,43 +1657,14 @@ async ValueTask GetFieldValueSequential(int column, bool async, Cancellati /// The value of the specified column. public override object GetValue(int ordinal) { - var fieldDescription = CheckRowAndGetField(ordinal); - - if (_isSequential) - { - SeekToColumnSequential(ordinal, false).GetAwaiter().GetResult(); - CheckColumnStart(); - } - else - SeekToColumnNonSequential(ordinal); - - if (ColumnLen == -1) + var field = GetInfo(ordinal, null, out var converter, out var bufferRequirement, out _); + var columnLength = SeekToColumn(async: false, ordinal, field).GetAwaiter().GetResult(); + if (columnLength == -1) return DBNull.Value; - object result; - var position = Buffer.ReadPosition; - try - { - result = _isSequential - ? 
fieldDescription.Handler.ReadAsObject(Buffer, ColumnLen, false, fieldDescription).GetAwaiter().GetResult() - : fieldDescription.Handler.ReadAsObject(Buffer, ColumnLen, fieldDescription); - } - catch - { - if (Connector.State != ConnectorState.Broken) - { - var writtenBytes = Buffer.ReadPosition - position; - var remainingBytes = ColumnLen - writtenBytes; - if (remainingBytes > 0) - Buffer.Skip(remainingBytes, false).GetAwaiter().GetResult(); - } - throw; - } - finally - { - // Important: position must still be updated - PosInColumn += ColumnLen; - } + PgReader.StartRead(bufferRequirement); + var result = converter.ReadAsObject(PgReader); + PgReader.EndRead(); return result; } @@ -1843,16 +1686,7 @@ public override object GetValue(int ordinal) /// The zero-based column ordinal. /// true if the specified column is equivalent to ; otherwise false. public override bool IsDBNull(int ordinal) - { - CheckRowAndGetField(ordinal); - - if (_isSequential) - SeekToColumnSequential(ordinal, false).GetAwaiter().GetResult(); - else - SeekToColumnNonSequential(ordinal); - - return ColumnLen == -1; - } + => SeekToColumn(async: false, ordinal, CheckRowAndGetField(ordinal), resumableOp: true).GetAwaiter().GetResult() is -1; /// /// An asynchronous version of , which gets a value that indicates whether the column contains non-existent or missing values. @@ -1865,21 +1699,16 @@ public override bool IsDBNull(int ordinal) /// true if the specified column value is equivalent to otherwise false. public override Task IsDBNullAsync(int ordinal, CancellationToken cancellationToken) { - CheckRowAndGetField(ordinal); - if (!_isSequential) - return IsDBNull(ordinal) ? PGUtil.TrueTask : PGUtil.FalseTask; + return IsDBNull(ordinal) ? 
TrueTask : FalseTask; using (NoSynchronizationContextScope.Enter()) - return IsDBNullAsyncInternal(ordinal, cancellationToken); + return Core(ordinal, cancellationToken); - // ReSharper disable once InconsistentNaming - async Task IsDBNullAsyncInternal(int ordinal, CancellationToken cancellationToken) + async Task Core(int ordinal, CancellationToken cancellationToken) { using var registration = Connector.StartNestedCancellableOperation(cancellationToken, attemptPgCancellation: false); - - await SeekToColumn(ordinal, true, cancellationToken); - return ColumnLen == -1; + return await SeekToColumn(async: true, ordinal, CheckRowAndGetField(ordinal), resumableOp: true) is -1; } } @@ -1932,6 +1761,7 @@ public override int GetOrdinal(string name) /// /// The zero-based column ordinal. /// The data type of the specified column. + [UnconditionalSuppressMessage("ILLink", "IL2093", Justification = "No members are dynamically accessed by Npgsql via GetFieldType")] public override Type GetFieldType(int ordinal) => GetField(ordinal).FieldType; @@ -2074,111 +1904,138 @@ Task> GetColumnSchema(bool async, Cancellatio #region Seeking - Task SeekToColumn(int column, bool async, CancellationToken cancellationToken = default) - { - if (_isSequential) - return SeekToColumnSequential(column, async, cancellationToken); - SeekToColumnNonSequential(column); - return Task.CompletedTask; - } + [MethodImpl(MethodImplOptions.AggressiveInlining)] + ValueTask SeekToColumn(bool async, int ordinal, FieldDescription field, bool resumableOp = false) + => _isSequential + ? 
SeekToColumnSequential(async, ordinal, field, resumableOp) + : new(SeekToColumnNonSequential(ordinal, field, resumableOp)); - void SeekToColumnNonSequential(int column) + int SeekToColumnNonSequential(int ordinal, FieldDescription field, bool resumableOp = false) { - // Shut down any streaming going on on the column - if (_columnStream != null) - { - _columnStream.Dispose(); - _columnStream = null; - } + PgReader.Commit(async: false, _column == ordinal && PgReader.Resumable && resumableOp).GetAwaiter().GetResult(); - for (var lastColumnRead = _columns.Count; column >= lastColumnRead; lastColumnRead++) + for (var lastColumnRead = _columns.Count; ordinal >= lastColumnRead; lastColumnRead++) { - int lastColumnLen; - (Buffer.ReadPosition, lastColumnLen) = _columns[lastColumnRead - 1]; + (Buffer.ReadPosition, var lastColumnLen) = _columns[lastColumnRead - 1]; if (lastColumnLen != -1) Buffer.ReadPosition += lastColumnLen; var len = Buffer.ReadInt32(); _columns.Add((Buffer.ReadPosition, len)); } - (Buffer.ReadPosition, ColumnLen) = _columns[column]; - _column = column; - PosInColumn = 0; + (Buffer.ReadPosition, var columnLength) = _columns[ordinal]; + PgReader.Init(columnLength, field.DataFormat, resumableOp); + _column = ordinal; + + return columnLength; } /// - /// Seeks to the given column. The 4-byte length is read and stored in . + /// Seeks to the given column. The 4-byte length is read and returned. /// - async Task SeekToColumnSequential(int column, bool async, CancellationToken cancellationToken = default) + ValueTask SeekToColumnSequential(bool async, int ordinal, FieldDescription field, bool resumableOp = false) { - if (column < 0 || column >= _numColumns) - throw new IndexOutOfRangeException("Column index out of range"); - - if (column < _column) - throw new InvalidOperationException($"Invalid attempt to read from column ordinal '{column}'. 
With CommandBehavior.SequentialAccess, you may only read from column ordinal '{_column}' or greater."); - - if (column == _column) - return; - - // Need to seek forward - - // Shut down any streaming going on on the column - if (_columnStream != null) + var reread = _column == ordinal; + // Column rereading rules for sequential mode: + // * We never allow rereading if the column didn't get initialized as resumable the previous time + // * If it did get initialized as resumable we only allow rereading when either of the following is true: + // - The op is a resumable one again + // - The op isn't resumable but the field is still entirely unconsumed + if (ordinal < _column || (reread && (!PgReader.Resumable || (!resumableOp && !PgReader.IsAtStart)))) + ThrowHelper.ThrowInvalidOperationException( + $"Invalid attempt to read from column ordinal '{ordinal}'. With CommandBehavior.SequentialAccess, " + + $"you may only read from column ordinal '{_column}' or greater."); + + var committed = false; + if (!PgReader.CommitHasIO(reread)) { - _columnStream.Dispose(); - _columnStream = null; - // Disposing the stream leaves us at the end of the column - PosInColumn = ColumnLen; + PgReader.Commit(async: false, reread).GetAwaiter().GetResult(); + committed = true; + if (TrySeekBuffered(ordinal, out var columnLength)) + { + PgReader.Init(columnLength, field.DataFormat, columnLength is -1 || resumableOp); + return new(columnLength); + } + + // If we couldn't consume the column TrySeekBuffered had to stop at, do so now. + if (columnLength > -1) + { + // Resumable: true causes commit to consume without error. + PgReader.Init(columnLength, field.DataFormat, resumable: true); + committed = false; + } } - // Skip to end of column if needed - // TODO: Simplify by better initializing _columnLen/_posInColumn - var remainingInColumn = ColumnLen == -1 ? 
0 : ColumnLen - PosInColumn; - if (remainingInColumn > 0) - await Buffer.Skip(remainingInColumn, async); + return Core(async, !committed, ordinal, field.DataFormat, resumableOp); - // Skip over unwanted fields - for (; _column < column - 1; _column++) +#if NET6_0_OR_GREATER + [AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder<>))] +#endif + async ValueTask Core(bool async, bool commit, int ordinal, DataFormat dataFormat, bool resumableOp) { - await Buffer.Ensure(4, async); - var len = Buffer.ReadInt32(); - if (len != -1) - await Buffer.Skip(len, async); - } + if (commit) + { + Debug.Assert(ordinal != _column); + await PgReader.Commit(async, reread); + } - await Buffer.Ensure(4, async); - ColumnLen = Buffer.ReadInt32(); - PosInColumn = 0; - _column = column; - } + if (ordinal == _column) + { + PgReader.Init(PgReader.FieldSize, field.DataFormat, PgReader.FieldSize is -1 || resumableOp); + return PgReader.FieldSize; + } - Task SeekInColumn(int dataOffset, bool async, CancellationToken cancellationToken = default) - { - if (_isSequential) - return SeekInColumnSequential(dataOffset, async); + // Seek to the requested column + var buffer = Buffer; + for (; _column < ordinal - 1; _column++) + { + await buffer.Ensure(4, async); + var len = buffer.ReadInt32(); + if (len != -1) + await buffer.Skip(len, async); + } - if (dataOffset >= ColumnLen) - ThrowHelper.ThrowArgumentOutOfRange_OutOfColumnBounds(nameof(dataOffset), ColumnLen); + await buffer.Ensure(4, async); + var columnLength = buffer.ReadInt32(); + _column = ordinal; - Buffer.ReadPosition = _columns[_column].Offset + dataOffset; - PosInColumn = dataOffset; - return Task.CompletedTask; + PgReader.Init(columnLength, dataFormat, resumableOp); + return columnLength; + } - async Task SeekInColumnSequential(int dataOffset, bool async) + bool TrySeekBuffered(int ordinal, out int columnLength) { - Debug.Assert(_column > -1); - - if (dataOffset < PosInColumn) - ThrowHelper.ThrowInvalidOperationException("Attempt to 
read a position in the column which has already been read"); + if (ordinal == _column) + { + columnLength = PgReader.FieldSize; + return true; + } - if (dataOffset >= ColumnLen) - ThrowHelper.ThrowArgumentOutOfRange_OutOfColumnBounds(nameof(dataOffset), ColumnLen); + // Skip over unwanted fields + columnLength = -1; + var buffer = Buffer; + for (; _column < ordinal - 1; _column++) + { + if (buffer.ReadBytesLeft < 4) + return false; + columnLength = buffer.ReadInt32(); + if (columnLength > 0) + { + if (buffer.ReadBytesLeft < columnLength) + return false; + buffer.Skip(columnLength); + } + } - if (dataOffset > PosInColumn) + if (buffer.ReadBytesLeft < 4) { - await Buffer.Skip(dataOffset - PosInColumn, async); - PosInColumn = dataOffset; + columnLength = -1; + return false; } + + columnLength = buffer.ReadInt32(); + _column = ordinal; + return true; } } @@ -2190,8 +2047,6 @@ Task ConsumeRow(bool async) { Debug.Assert(State == ReaderState.InResult || State == ReaderState.BeforeResult); - UniqueRowId++; - if (!_canConsumeRowNonSequentially) return ConsumeRowSequential(async); @@ -2201,19 +2056,7 @@ Task ConsumeRow(bool async) async Task ConsumeRowSequential(bool async) { - if (_columnStream != null) - { - _columnStream.Dispose(); - _columnStream = null; - // Disposing the stream leaves us at the end of the column - PosInColumn = ColumnLen; - } - - // TODO: Potential for code-sharing with ReadColumn above, which also skips - // Skip to end of column if needed - var remainingInColumn = ColumnLen == -1 ? 
0 : ColumnLen - PosInColumn; - if (remainingInColumn > 0) - await Buffer.Skip(remainingInColumn, async); + await PgReader.Commit(async, resuming: false); // Skip over the remaining columns in the row for (; _column < _numColumns - 1; _column++) @@ -2230,14 +2073,7 @@ async Task ConsumeRowSequential(bool async) void ConsumeRowNonSequential() { Debug.Assert(State == ReaderState.InResult || State == ReaderState.BeforeResult); - - if (_columnStream is not null) - { - _columnStream.Dispose(); - _columnStream = null; - // Disposing the stream leaves us at the end of the column - PosInColumn = ColumnLen; - } + PgReader.Commit(async: false, resuming: false).GetAwaiter().GetResult(); Buffer.ReadPosition = _dataMsgEnd; } @@ -2264,27 +2100,72 @@ void CheckResultSet() } } - FieldDescription CheckRowAndGetField(int column) + [MethodImpl(MethodImplOptions.NoInlining)] + static T DbNullValueOrThrow(FieldDescription field) { - switch (State) + // When T is a Nullable (and only in that case), we support returning null + if (default(T) is null && typeof(T).IsValueType) + return default!; + + if (typeof(T) == typeof(object)) + return (T)(object)DBNull.Value; + + ThrowHelper.ThrowInvalidCastException_NoValue(field); + return default; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + FieldDescription GetInfo(int ordinal, Type? 
type, out PgConverter converter, out Size bufferRequirement, out bool asObject) + { + var field = CheckRowAndGetField(ordinal); + + if (type is null) { - case ReaderState.InResult: - break; - case ReaderState.Closed: - ThrowHelper.ThrowInvalidOperationException("The reader is closed"); - break; - case ReaderState.Disposed: - ThrowHelper.ThrowObjectDisposedException(nameof(NpgsqlDataReader)); - break; - default: - ThrowHelper.ThrowInvalidOperationException("No row is available"); - break; + converter = field.ObjectOrDefaultInfo.Converter; + bufferRequirement = field.ObjectOrDefaultInfo.BufferRequirement; + asObject = field.ObjectOrDefaultInfo.AsObject; + return field; } - if (column < 0 || column >= RowDescription!.Count) - ThrowColumnOutOfRange(RowDescription!.Count); + ref var info = ref ColumnInfoCache![ordinal]; + field.GetInfo(type, ref info); + converter = info.Converter; + bufferRequirement = info.BufferRequirement; + asObject = info.AsObject; + return field; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + FieldDescription CheckRowAndGetField(int column) + { + var columns = RowDescription; + var state = State; + if (state is ReaderState.InResult && column >= 0 && column < columns!.Count) + return columns[column]; + + return HandleInvalidState(state, columns?.Count ?? 
0); - return RowDescription[column]; + [MethodImpl(MethodImplOptions.NoInlining)] + static FieldDescription HandleInvalidState(ReaderState state, int maxColumns) + { + switch (state) + { + case ReaderState.InResult: + break; + case ReaderState.Closed: + ThrowHelper.ThrowInvalidOperationException("The reader is closed"); + break; + case ReaderState.Disposed: + ThrowHelper.ThrowObjectDisposedException(nameof(NpgsqlDataReader)); + break; + default: + ThrowHelper.ThrowInvalidOperationException("No row is available"); + break; + } + + ThrowColumnOutOfRange(maxColumns); + return default!; + } } /// @@ -2296,17 +2177,11 @@ FieldDescription GetField(int column) if (RowDescription is null) ThrowHelper.ThrowInvalidOperationException("No resultset is currently being traversed"); - if (column < 0 || column >= RowDescription.Count) - ThrowColumnOutOfRange(RowDescription.Count); + var columns = RowDescription; + if (column < 0 || column >= columns.Count) + ThrowColumnOutOfRange(columns.Count); - return RowDescription[column]; - } - - void CheckColumnStart() - { - Debug.Assert(_isSequential); - if (PosInColumn != 0) - ThrowHelper.ThrowInvalidOperationException("Attempt to read a position in the column which has already been read"); + return columns[column]; } void CheckClosedOrDisposed() diff --git a/src/Npgsql/NpgsqlDataSource.cs b/src/Npgsql/NpgsqlDataSource.cs index 510513d0fb..ee3ec18eb5 100644 --- a/src/Npgsql/NpgsqlDataSource.cs +++ b/src/Npgsql/NpgsqlDataSource.cs @@ -10,8 +10,7 @@ using System.Transactions; using Microsoft.Extensions.Logging; using Npgsql.Internal; -using Npgsql.Internal.TypeHandling; -using Npgsql.Internal.TypeMapping; +using Npgsql.Internal.Resolvers; using Npgsql.Properties; using Npgsql.Util; @@ -32,11 +31,8 @@ public abstract class NpgsqlDataSource : DbDataSource internal NpgsqlDataSourceConfiguration Configuration { get; } internal NpgsqlLoggingConfiguration LoggingConfiguration { get; } - readonly List _resolverFactories; - readonly Dictionary 
_userTypeMappings; - readonly INpgsqlNameTranslator _defaultNameTranslator; - - internal TypeMapper TypeMapper { get; private set; } = null!; // Initialized at bootstrapping + readonly IPgTypeInfoResolver _resolver; + internal PgSerializerOptions SerializerOptions { get; private set; } = null!; // Initialized at bootstrapping /// /// Information about PostgreSQL and PostgreSQL-like databases (e.g. type definitions, capabilities...). @@ -81,6 +77,8 @@ private protected readonly Dictionary> _pendi /// readonly SemaphoreSlim _setupMappingsSemaphore = new(1); + readonly INpgsqlNameTranslator _defaultNameTranslator; + internal NpgsqlDataSource( NpgsqlConnectionStringBuilder settings, NpgsqlDataSourceConfiguration dataSourceConfig) @@ -100,14 +98,15 @@ internal NpgsqlDataSource( _periodicPasswordProvider, _periodicPasswordSuccessRefreshInterval, _periodicPasswordFailureRefreshInterval, - _resolverFactories, - _userTypeMappings, + var resolverChain, _defaultNameTranslator, ConnectionInitializer, ConnectionInitializerAsync) = dataSourceConfig; _connectionLogger = LoggingConfiguration.ConnectionLogger; + // TODO probably want this on the options so it can devirt unconditionally. 
+ _resolver = new TypeInfoResolverChain(resolverChain); _password = settings.Password; if (_periodicPasswordSuccessRefreshInterval != default) @@ -127,11 +126,11 @@ internal NpgsqlDataSource( MetricsReporter = new MetricsReporter(this); } - /// + /// public new NpgsqlConnection CreateConnection() => NpgsqlConnection.FromDataSource(this); - /// + /// public new NpgsqlConnection OpenConnection() { var connection = CreateConnection(); @@ -152,7 +151,7 @@ internal NpgsqlDataSource( protected override DbConnection OpenDbConnection() => OpenConnection(); - /// + /// public new async ValueTask OpenConnectionAsync(CancellationToken cancellationToken = default) { var connection = CreateConnection(); @@ -233,19 +232,29 @@ internal async Task Bootstrap( return; // The type loading below will need to send queries to the database, and that depends on a type mapper being set up (even if its - // empty). So we set up here, and then later inject the DatabaseInfo. - var typeMapper = new TypeMapper(connector, _defaultNameTranslator); - connector.TypeMapper = typeMapper; + // empty). So we set up a minimal version here, and then later inject the actual DatabaseInfo. 
+ connector.SerializerOptions = + new(PostgresMinimalDatabaseInfo.DefaultTypeCatalog) + { + TextEncoding = connector.TextEncoding, + TypeInfoResolver = AdoTypeInfoResolver.Instance + }; NpgsqlDatabaseInfo databaseInfo; using (connector.StartUserAction(ConnectorState.Executing, cancellationToken)) databaseInfo = await NpgsqlDatabaseInfo.Load(connector, timeout, async); - DatabaseInfo = databaseInfo; - connector.DatabaseInfo = databaseInfo; - typeMapper.Initialize(databaseInfo, _resolverFactories, _userTypeMappings); - TypeMapper = typeMapper; + connector.DatabaseInfo = DatabaseInfo = databaseInfo; + connector.SerializerOptions = SerializerOptions = + new(databaseInfo, CreateTimeZoneProvider(connector.Timezone)) + { + ArrayNullabilityMode = Settings.ArrayNullabilityMode, + EnableDateTimeInfinityConversions = !Statics.DisableDateTimeInfinityConversions, + TextEncoding = connector.TextEncoding, + TypeInfoResolver = _resolver, + DefaultNameTranslator = _defaultNameTranslator + }; _isBootstrapped = true; } @@ -253,6 +262,18 @@ internal async Task Bootstrap( { _setupMappingsSemaphore.Release(); } + + // Func in a static function to make sure we don't capture state that might not stay around, like a connector. + static Func CreateTimeZoneProvider(string postgresTimeZone) + => () => + { + if (string.Equals(postgresTimeZone, "localtime", StringComparison.OrdinalIgnoreCase)) + throw new TimeZoneNotFoundException( + "The special PostgreSQL timezone 'localtime' is not supported when reading values of type 'timestamp with time zone'. 
" + + "Please specify a real timezone in 'postgresql.conf' on the server, or set the 'PGTZ' environment variable on the client."); + + return postgresTimeZone; + }; } #region Password management @@ -478,7 +499,7 @@ sealed class DatabaseStateInfo // While the TimeStamp is not strictly required, it does lower the risk of overwriting the current state with an old value internal readonly DateTime TimeStamp; - public DatabaseStateInfo() : this(default, default, default) {} + public DatabaseStateInfo() : this(default, default, default) { } public DatabaseStateInfo(DatabaseState state, NpgsqlTimeout timeout, DateTime timeStamp) => (State, Timeout, TimeStamp) = (state, timeout, timeStamp); diff --git a/src/Npgsql/NpgsqlDataSourceBuilder.cs b/src/Npgsql/NpgsqlDataSourceBuilder.cs index 356fa48cb3..de87962d5c 100644 --- a/src/Npgsql/NpgsqlDataSourceBuilder.cs +++ b/src/Npgsql/NpgsqlDataSourceBuilder.cs @@ -1,12 +1,14 @@ using System; using System.Diagnostics.CodeAnalysis; +using System.Linq; using System.Net.Security; using System.Security.Cryptography.X509Certificates; using System.Text.Json; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.Logging; -using Npgsql.Internal.TypeHandling; +using Npgsql.Internal; +using Npgsql.Internal.Resolvers; using Npgsql.TypeMapping; using NpgsqlTypes; @@ -17,6 +19,8 @@ namespace Npgsql; /// public sealed class NpgsqlDataSourceBuilder : INpgsqlTypeMapper { + static UnsupportedTypeInfoResolver UnsupportedTypeInfoResolver { get; } = new(); + readonly NpgsqlSlimDataSourceBuilder _internalBuilder; /// @@ -45,14 +49,80 @@ public INpgsqlNameTranslator DefaultNameTranslator /// public string ConnectionString => _internalBuilder.ConnectionString; + internal static void ResetGlobalMappings(bool overwrite) + => GlobalTypeMapper.Instance.AddGlobalTypeMappingResolvers(new IPgTypeInfoResolver[] + { + overwrite ? 
new AdoTypeInfoResolver() : AdoTypeInfoResolver.Instance, + new ExtraConversionsResolver(), + new SystemTextJsonTypeInfoResolver(), + new SystemTextJsonPocoTypeInfoResolver(), + new RangeTypeInfoResolver(), + new RecordTypeInfoResolver(), + new TupledRecordTypeInfoResolver(), + new FullTextSearchTypeInfoResolver(), + new NetworkTypeInfoResolver(), + new GeometricTypeInfoResolver(), + new LTreeTypeInfoResolver(), + new UnmappedEnumTypeInfoResolver(), + new UnmappedRangeTypeInfoResolver(), + new UnmappedMultirangeTypeInfoResolver(), + // Arrays + new AdoArrayTypeInfoResolver(), + new ExtraConversionsArrayTypeInfoResolver(), + new SystemTextJsonArrayTypeInfoResolver(), + new SystemTextJsonPocoArrayTypeInfoResolver(), + new RangeArrayTypeInfoResolver(), + new RecordArrayTypeInfoResolver(), + new TupledRecordArrayTypeInfoResolver(), + new UnmappedEnumArrayTypeInfoResolver(), + new UnmappedRangeArrayTypeInfoResolver(), + new UnmappedMultirangeArrayTypeInfoResolver(), + }, overwrite); + + static NpgsqlDataSourceBuilder() + => ResetGlobalMappings(overwrite: false); + /// /// Constructs a new , optionally starting out from the given . /// public NpgsqlDataSourceBuilder(string? connectionString = null) { - _internalBuilder = new(connectionString); - + _internalBuilder = new(new NpgsqlConnectionStringBuilder(connectionString)); AddDefaultFeatures(); + + void AddDefaultFeatures() + { + _internalBuilder.EnableEncryption(); + AddTypeInfoResolver(UnsupportedTypeInfoResolver); + // Reverse order arrays. 
+ AddTypeInfoResolver(new UnmappedMultirangeArrayTypeInfoResolver()); + AddTypeInfoResolver(new UnmappedRangeArrayTypeInfoResolver()); + AddTypeInfoResolver(new UnmappedEnumArrayTypeInfoResolver()); + AddTypeInfoResolver(new TupledRecordArrayTypeInfoResolver()); + AddTypeInfoResolver(new RecordArrayTypeInfoResolver()); + AddTypeInfoResolver(new RangeArrayTypeInfoResolver()); + AddTypeInfoResolver(new SystemTextJsonPocoArrayTypeInfoResolver()); + AddTypeInfoResolver(new SystemTextJsonArrayTypeInfoResolver()); + AddTypeInfoResolver(new ExtraConversionsArrayTypeInfoResolver()); + AddTypeInfoResolver(new AdoArrayTypeInfoResolver()); + // Reverse order. + AddTypeInfoResolver(new UnmappedMultirangeTypeInfoResolver()); + AddTypeInfoResolver(new UnmappedRangeTypeInfoResolver()); + AddTypeInfoResolver(new UnmappedEnumTypeInfoResolver()); + AddTypeInfoResolver(new LTreeTypeInfoResolver()); + AddTypeInfoResolver(new GeometricTypeInfoResolver()); + AddTypeInfoResolver(new NetworkTypeInfoResolver()); + AddTypeInfoResolver(new FullTextSearchTypeInfoResolver()); + AddTypeInfoResolver(new TupledRecordTypeInfoResolver()); + AddTypeInfoResolver(new RecordTypeInfoResolver()); + AddTypeInfoResolver(new RangeTypeInfoResolver()); + AddTypeInfoResolver(new SystemTextJsonPocoTypeInfoResolver()); + AddTypeInfoResolver(new SystemTextJsonTypeInfoResolver()); + AddTypeInfoResolver(new ExtraConversionsResolver()); + AddTypeInfoResolver(AdoTypeInfoResolver.Instance); + foreach (var plugin in GlobalTypeMapper.Instance.GetPluginResolvers().Reverse()) + AddTypeInfoResolver(plugin); + } } /// @@ -208,8 +278,12 @@ public NpgsqlDataSourceBuilder UsePeriodicPasswordProvider( #region Type mapping /// - public void AddTypeResolverFactory(TypeHandlerResolverFactory resolverFactory) - => _internalBuilder.AddTypeResolverFactory(resolverFactory); + public void AddTypeInfoResolver(IPgTypeInfoResolver resolver) + => _internalBuilder.AddTypeInfoResolver(resolver); + + /// + void INpgsqlTypeMapper.Reset() + => 
_internalBuilder.ResetTypeMappings(); /// /// Sets up System.Text.Json mappings for the PostgreSQL json and jsonb types. @@ -226,7 +300,8 @@ public NpgsqlDataSourceBuilder UseSystemTextJson( Type[]? jsonbClrTypes = null, Type[]? jsonClrTypes = null) { - AddTypeResolverFactory(new SystemTextJsonTypeHandlerResolverFactory(jsonbClrTypes, jsonClrTypes, serializerOptions)); + AddTypeInfoResolver(new SystemTextJsonPocoArrayTypeInfoResolver(jsonbClrTypes, jsonClrTypes, serializerOptions)); + AddTypeInfoResolver(new SystemTextJsonPocoTypeInfoResolver(jsonbClrTypes, jsonClrTypes, serializerOptions)); return this; } @@ -269,13 +344,6 @@ public bool UnmapComposite(string? pgName = null, INpgsqlNameTranslator? name public bool UnmapComposite(Type clrType, string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) => _internalBuilder.UnmapComposite(clrType, pgName, nameTranslator); - void INpgsqlTypeMapper.Reset() - { - ((INpgsqlTypeMapper)_internalBuilder).Reset(); - - AddDefaultFeatures(); - } - #endregion Type mapping /// @@ -318,13 +386,4 @@ public NpgsqlDataSource Build() /// public NpgsqlMultiHostDataSource BuildMultiHost() => _internalBuilder.BuildMultiHost(); - - void AddDefaultFeatures() - { - _internalBuilder.EnableEncryption(); - _internalBuilder.AddDefaultTypeResolverFactory(new SystemTextJsonTypeHandlerResolverFactory()); - _internalBuilder.AddDefaultTypeResolverFactory(new RangeTypeHandlerResolverFactory()); - _internalBuilder.AddDefaultTypeResolverFactory(new RecordTypeHandlerResolverFactory()); - _internalBuilder.AddDefaultTypeResolverFactory(new FullTextSearchTypeHandlerResolverFactory()); - } } diff --git a/src/Npgsql/NpgsqlDataSourceConfiguration.cs b/src/Npgsql/NpgsqlDataSourceConfiguration.cs index 40aec62171..749ab7df7b 100644 --- a/src/Npgsql/NpgsqlDataSourceConfiguration.cs +++ b/src/Npgsql/NpgsqlDataSourceConfiguration.cs @@ -5,8 +5,6 @@ using System.Threading; using System.Threading.Tasks; using Npgsql.Internal; -using 
Npgsql.Internal.TypeHandling; -using Npgsql.Internal.TypeMapping; namespace Npgsql; @@ -19,8 +17,7 @@ sealed record NpgsqlDataSourceConfiguration( Func>? PeriodicPasswordProvider, TimeSpan PeriodicPasswordSuccessRefreshInterval, TimeSpan PeriodicPasswordFailureRefreshInterval, - List ResolverFactories, - Dictionary UserTypeMappings, + IEnumerable ResolverChain, INpgsqlNameTranslator DefaultNameTranslator, Action? ConnectionInitializer, Func? ConnectionInitializerAsync); diff --git a/src/Npgsql/NpgsqlLargeObjectManager.cs b/src/Npgsql/NpgsqlLargeObjectManager.cs index 4ec6cb002d..8f9b4cf6ea 100644 --- a/src/Npgsql/NpgsqlLargeObjectManager.cs +++ b/src/Npgsql/NpgsqlLargeObjectManager.cs @@ -1,5 +1,4 @@ using Npgsql.Util; -using System; using System.Data; using System.Text; using System.Threading; diff --git a/src/Npgsql/NpgsqlNestedDataReader.cs b/src/Npgsql/NpgsqlNestedDataReader.cs index 55234f5423..060592e312 100644 --- a/src/Npgsql/NpgsqlNestedDataReader.cs +++ b/src/Npgsql/NpgsqlNestedDataReader.cs @@ -1,8 +1,5 @@ using Npgsql.Internal; -using Npgsql.Internal.TypeHandlers; -using Npgsql.Internal.TypeHandling; using Npgsql.PostgresTypes; -using Npgsql.TypeMapping; using System; using System.Collections; using System.Collections.Generic; @@ -10,7 +7,7 @@ using System.Globalization; using System.IO; using System.Runtime.CompilerServices; -using Npgsql.Internal.TypeMapping; +using Npgsql.Internal.Postgres; namespace Npgsql; @@ -22,7 +19,6 @@ namespace Npgsql; public sealed class NpgsqlNestedDataReader : DbDataReader { readonly NpgsqlDataReader _outermostReader; - ulong _uniqueOutermostReaderRowId; readonly NpgsqlNestedDataReader? _outerNestedReader; NpgsqlNestedDataReader? _cachedFreeNestedDataReader; PostgresCompositeType? 
_compositeType; @@ -33,37 +29,53 @@ public sealed class NpgsqlNestedDataReader : DbDataReader ReaderState _readerState; readonly List _columns = new(); + long _startPos; - readonly struct ColumnInfo + DataFormat Format => DataFormat.Binary; + + struct ColumnInfo { - public readonly uint TypeOid; - public readonly int BufferPos; - public readonly NpgsqlTypeHandler TypeHandler; + readonly DataFormat _format; + public PostgresType PostgresType { get; } + public int BufferPos { get; } + public PgConverterInfo LastConverterInfo { get; private set; } + + public PgTypeInfo ObjectOrDefaultTypeInfo { get; } + public PgConverterInfo ObjectOrDefaultInfo => ObjectOrDefaultTypeInfo.Bind(Field, _format); + + Field Field => new("?", ObjectOrDefaultTypeInfo.Options.PortableTypeIds ? PostgresType.DataTypeName : (Oid)PostgresType.OID, -1); - public ColumnInfo(uint typeOid, int bufferPos, NpgsqlTypeHandler typeHandler) + public ColumnInfo SetConverterInfo(PgTypeInfo typeInfo) + => this with + { + LastConverterInfo = typeInfo.Bind(Field, _format) + }; + + public ColumnInfo(PostgresType postgresType, int bufferPos, PgTypeInfo objectOrDefaultTypeInfo, DataFormat format) { - TypeOid = typeOid; + _format = format; + PostgresType = postgresType; BufferPos = bufferPos; - TypeHandler = typeHandler; + ObjectOrDefaultTypeInfo = objectOrDefaultTypeInfo; } } - NpgsqlReadBuffer Buffer => _outermostReader.Buffer; - TypeMapper TypeMapper => _outermostReader.Connector.TypeMapper; + PgReader PgReader => _outermostReader.Buffer.PgReader; + PgSerializerOptions SerializerOptions => _outermostReader.Connector.SerializerOptions; internal NpgsqlNestedDataReader(NpgsqlDataReader outermostReader, NpgsqlNestedDataReader? outerNestedReader, - ulong uniqueOutermostReaderRowId, int depth, PostgresCompositeType? compositeType) + int depth, PostgresCompositeType? 
compositeType) { _outermostReader = outermostReader; _outerNestedReader = outerNestedReader; - _uniqueOutermostReaderRowId = uniqueOutermostReaderRowId; _depth = depth; _compositeType = compositeType; + _startPos = PgReader.FieldStartPos; } - internal void Init(ulong uniqueOutermostReaderRowId, PostgresCompositeType? compositeType) + internal void Init(PostgresCompositeType? compositeType) { - _uniqueOutermostReaderRowId = uniqueOutermostReaderRowId; + _startPos = PgReader.FieldStartPos; _columns.Clear(); _numRows = 0; _nextRowIndex = 0; @@ -74,9 +86,9 @@ internal void Init(ulong uniqueOutermostReaderRowId, PostgresCompositeType? comp internal void InitArray() { - var dimensions = Buffer.ReadInt32(); - var containsNulls = Buffer.ReadInt32() == 1; - Buffer.ReadUInt32(); // Element OID. Ignored. + var dimensions = PgReader.ReadInt32(); + var containsNulls = PgReader.ReadInt32() == 1; + PgReader.ReadUInt32(); // Element OID. Ignored. if (containsNulls) throw new InvalidOperationException("Record array contains null record"); @@ -87,19 +99,19 @@ internal void InitArray() if (dimensions != 1) throw new InvalidOperationException("Cannot read a multidimensional array with a nested DbDataReader"); - _numRows = Buffer.ReadInt32(); - Buffer.ReadInt32(); // Lower bound + _numRows = PgReader.ReadInt32(); + PgReader.ReadInt32(); // Lower bound if (_numRows > 0) - Buffer.ReadInt32(); // Length of first row + PgReader.ReadInt32(); // Length of first row - _nextRowBufferPos = Buffer.ReadPosition; + _nextRowBufferPos = PgReader.FieldOffset; } internal void InitSingleRow() { _numRows = 1; - _nextRowBufferPos = Buffer.ReadPosition; + _nextRowBufferPos = PgReader.FieldOffset; } /// @@ -141,7 +153,7 @@ public override bool HasRows /// public override bool IsClosed => _readerState == ReaderState.Closed || _readerState == ReaderState.Disposed - || _outermostReader.IsClosed || _uniqueOutermostReaderRowId != _outermostReader.UniqueRowId; + || _outermostReader.IsClosed || 
PgReader.FieldStartPos != _startPos; /// public override int RecordsAffected => -1; @@ -181,26 +193,22 @@ public override long GetBytes(int ordinal, long dataOffset, byte[]? buffer, int if (buffer != null && (length < 0 || length > buffer.Length - bufferOffset)) throw new IndexOutOfRangeException($"length must be between 0 and {buffer.Length - bufferOffset}"); - var field = CheckRowAndColumnAndSeek(ordinal); - var handler = field.Handler; - if (!(handler is ByteaHandler)) - throw new InvalidCastException("GetBytes() not supported for type " + field.Handler.PgDisplayName); + var columnLen = CheckRowAndColumnAndSeek(ordinal, out var column); + if (columnLen is -1) + ThrowHelper.ThrowInvalidCastException_NoValue(); - if (field.Length == -1) - throw new InvalidCastException("field is null"); + if (buffer is null) + return columnLen; - var dataOffset2 = (int)dataOffset; - if (dataOffset2 >= field.Length) - ThrowHelper.ThrowArgumentOutOfRange_OutOfColumnBounds(nameof(dataOffset), field.Length); + using var _ = PgReader.BeginNestedRead(columnLen, Size.Zero); - Buffer.ReadPosition += dataOffset2; + // Move to offset + PgReader.Seek((int)dataOffset); - length = Math.Min(length, field.Length - dataOffset2); - - if (buffer == null) - return length; - - return Buffer.Read(new Span(buffer, bufferOffset, length)); + // At offset, read into buffer. + length = Math.Min(length, PgReader.CurrentRemaining); + PgReader.ReadBytes(new Span(buffer, bufferOffset, length)); + return length; } /// public override long GetChars(int ordinal, long dataOffset, char[]? buffer, int bufferOffset, int length) @@ -217,26 +225,26 @@ public override long GetChars(int ordinal, long dataOffset, char[]? buffer, int /// A data reader. 
public new NpgsqlNestedDataReader GetData(int ordinal) { - var field = CheckRowAndColumnAndSeek(ordinal); - var type = field.Handler.PostgresType; + var valueLength = CheckRowAndColumnAndSeek(ordinal, out var column); + var type = column.PostgresType; var isArray = type is PostgresArrayType; var elementType = isArray ? ((PostgresArrayType)type).Element : type; var compositeType = elementType as PostgresCompositeType; if (elementType.InternalName != "record" && compositeType == null) throw new InvalidCastException("GetData() not supported for type " + type.DisplayName); - if (field.Length == -1) + if (valueLength == -1) throw new InvalidCastException("field is null"); var reader = _cachedFreeNestedDataReader; if (reader != null) { _cachedFreeNestedDataReader = null; - reader.Init(_uniqueOutermostReaderRowId, compositeType); + reader.Init(compositeType); } else { - reader = new NpgsqlNestedDataReader(_outermostReader, this, _uniqueOutermostReaderRowId, _depth + 1, compositeType); + reader = new NpgsqlNestedDataReader(_outermostReader, this, _depth + 1, compositeType); } if (isArray) reader.InitArray(); @@ -249,7 +257,7 @@ public override long GetChars(int ordinal, long dataOffset, char[]? 
buffer, int public override string GetDataTypeName(int ordinal) { var column = CheckRowAndColumn(ordinal); - return column.TypeHandler.PgDisplayName; + return column.PostgresType.DisplayName; } /// @@ -288,16 +296,19 @@ public override int GetOrdinal(string name) public override Type GetFieldType(int ordinal) { var column = CheckRowAndColumn(ordinal); - return column.TypeHandler.GetFieldType(); + return column.ObjectOrDefaultTypeInfo.Type; } /// public override object GetValue(int ordinal) { - var column = CheckRowAndColumnAndSeek(ordinal); - if (column.Length == -1) + var columnLength = CheckRowAndColumnAndSeek(ordinal, out var column); + var info = column.ObjectOrDefaultInfo; + if (columnLength == -1) return DBNull.Value; - return column.Handler.ReadAsObject(Buffer, column.Length); + + using var _ = PgReader.BeginNestedRead(columnLength, info.BufferRequirement); + return info.Converter.ReadAsObject(PgReader); } /// @@ -315,7 +326,7 @@ public override int GetValues(object[] values) /// public override bool IsDBNull(int ordinal) - => CheckRowAndColumnAndSeek(ordinal).Length == -1; + => CheckRowAndColumnAndSeek(ordinal, out _) == -1; /// public override T GetFieldValue(int ordinal) @@ -326,25 +337,25 @@ public override T GetFieldValue(int ordinal) if (typeof(T) == typeof(TextReader)) return (T)(object)GetTextReader(ordinal); - var field = CheckRowAndColumnAndSeek(ordinal); + var columnLength = CheckRowAndColumnAndSeek(ordinal, out var column); + var info = GetOrAddConverterInfo(typeof(T), column, ordinal); - if (field.Length == -1) + if (columnLength == -1) { // When T is a Nullable (and only in that case), we support returning null - if (NullableHandler.Exists) + if (default(T) is null && typeof(T).IsValueType) return default!; if (typeof(T) == typeof(object)) return (T)(object)DBNull.Value; - throw new InvalidCastException("field is null"); + ThrowHelper.ThrowInvalidCastException_NoValue(); } - return NullableHandler.Exists - ? 
NullableHandler.Read(field.Handler, Buffer, field.Length, fieldDescription: null) - : typeof(T) == typeof(object) - ? (T)field.Handler.ReadAsObject(Buffer, field.Length, fieldDescription: null) - : field.Handler.Read(Buffer, field.Length, fieldDescription: null); + using var _ = PgReader.BeginNestedRead(columnLength, info.BufferRequirement); + return info.AsObject + ? (T)info.Converter.ReadAsObject(PgReader)! + : info.GetConverter().Read(PgReader); } /// @@ -352,7 +363,7 @@ public override bool Read() { CheckResultSet(); - Buffer.ReadPosition = _nextRowBufferPos; + PgReader.Seek(_nextRowBufferPos); if (_nextRowIndex == _numRows) { _readerState = ReaderState.AfterRows; @@ -360,27 +371,34 @@ public override bool Read() } if (_nextRowIndex++ != 0) - Buffer.ReadInt32(); // Length of record + PgReader.ReadInt32(); // Length of record - var numColumns = Buffer.ReadInt32(); + var numColumns = PgReader.ReadInt32(); for (var i = 0; i < numColumns; i++) { - var typeOid = Buffer.ReadUInt32(); - var bufferPos = Buffer.ReadPosition; + var typeOid = PgReader.ReadUInt32(); + var bufferPos = PgReader.FieldOffset; if (i >= _columns.Count) - _columns.Add(new ColumnInfo(typeOid, bufferPos, TypeMapper.ResolveByOID(typeOid))); + { + var pgType = SerializerOptions.DatabaseInfo.GetPostgresType(typeOid); + _columns.Add(new ColumnInfo(pgType, bufferPos, AdoSerializerHelpers.GetTypeInfoForReading(typeof(object), pgType, SerializerOptions), Format)); + } else - _columns[i] = new ColumnInfo(typeOid, bufferPos, - _columns[i].TypeOid == typeOid ? _columns[i].TypeHandler : TypeMapper.ResolveByOID(typeOid)); + { + var pgType = _columns[i].PostgresType.OID == typeOid + ? 
_columns[i].PostgresType + : SerializerOptions.DatabaseInfo.GetPostgresType(typeOid); + _columns[i] = new ColumnInfo(pgType, bufferPos, AdoSerializerHelpers.GetTypeInfoForReading(typeof(object), pgType, SerializerOptions), Format); + } - var columnLen = Buffer.ReadInt32(); + var columnLen = PgReader.ReadInt32(); if (columnLen >= 0) - Buffer.Skip(columnLen); + PgReader.Consume(columnLen); } _columns.RemoveRange(numColumns, _columns.Count - numColumns); - _nextRowBufferPos = Buffer.ReadPosition; + _nextRowBufferPos = PgReader.FieldOffset; _readerState = ReaderState.OnRow; return true; @@ -465,12 +483,25 @@ ColumnInfo CheckRowAndColumn(int column) return _columns[column]; } - (NpgsqlTypeHandler Handler, int Length) CheckRowAndColumnAndSeek(int ordinal) + int CheckRowAndColumnAndSeek(int ordinal, out ColumnInfo column) { - var column = CheckRowAndColumn(ordinal); - Buffer.ReadPosition = column.BufferPos; - var len = Buffer.ReadInt32(); - return (column.TypeHandler, len); + column = CheckRowAndColumn(ordinal); + PgReader.Seek(column.BufferPos); + return PgReader.ReadInt32(); + } + + PgConverterInfo GetOrAddConverterInfo(Type type, ColumnInfo column, int ordinal) + { + PgConverterInfo info; + if (!column.LastConverterInfo.IsDefault && column.LastConverterInfo.TypeToConvert == type) + info = column.LastConverterInfo; + else + { + var columnInfo = column.SetConverterInfo(AdoSerializerHelpers.GetTypeInfoForReading(type, column.PostgresType, SerializerOptions)); + _columns[ordinal] = columnInfo; + info = columnInfo.LastConverterInfo; + } + return info; } enum ReaderState diff --git a/src/Npgsql/NpgsqlParameter.cs b/src/Npgsql/NpgsqlParameter.cs index 3be8758799..79e9ed8ccd 100644 --- a/src/Npgsql/NpgsqlParameter.cs +++ b/src/Npgsql/NpgsqlParameter.cs @@ -3,16 +3,15 @@ using System.Data; using System.Data.Common; using System.Diagnostics.CodeAnalysis; +using System.IO; using System.Threading; using System.Threading.Tasks; using Npgsql.Internal; -using 
Npgsql.Internal.TypeHandling; -using Npgsql.Internal.TypeMapping; +using Npgsql.Internal.Postgres; using Npgsql.PostgresTypes; using Npgsql.TypeMapping; using Npgsql.Util; using NpgsqlTypes; -using static Npgsql.Util.Statics; namespace Npgsql; @@ -27,29 +26,27 @@ public class NpgsqlParameter : DbParameter, IDbDataParameter, ICloneable private protected byte _scale; private protected int _size; - // ReSharper disable InconsistentNaming private protected NpgsqlDbType? _npgsqlDbType; private protected string? _dataTypeName; - // ReSharper restore InconsistentNaming - private protected string _name = string.Empty; - private protected object? _value; - private protected string _sourceColumn; + private protected string _name = string.Empty; + object? _value; + private protected bool _useSubStream; + private protected SubReadStream? _subStream; + private protected string _sourceColumn; internal string TrimmedName { get; private protected set; } = PositionalName; - internal const string PositionalName = ""; - - /// - /// Can be used to communicate a value from the validation phase to the writing phase. - /// To be used by type handlers only. - /// - public object? ConvertedValue { get; set; } + internal const string PositionalName = ""; - internal NpgsqlLengthCache? LengthCache { get; set; } + internal PgTypeInfo? TypeInfo { get; private set; } - internal NpgsqlTypeHandler? Handler { get; set; } + internal PgTypeId PgTypeId { get; set; } + internal PgConverter? Converter { get; private set; } - internal FormatCode FormatCode { get; private set; } + internal DataFormat Format { get; private protected set; } + private protected Size? WriteSize { get; set; } + private protected object? 
_writeState; + private protected Size _bufferRequirement; #endregion @@ -250,14 +247,14 @@ public sealed override string ParameterName { if (Collection is not null) Collection.ChangeParameterName(this, value); - else + else ChangeParameterName(value); } } internal void ChangeParameterName(string? value) { - if (value == null) + if (value is null) _name = TrimmedName = PositionalName; else if (value.Length > 0 && (value[0] == ':' || value[0] == '@')) TrimmedName = (_name = value).Substring(1); @@ -278,10 +275,9 @@ public override object? Value get => _value; set { - if (_value == null || value == null || _value.GetType() != value.GetType()) - Handler = null; + if (value is null || _value?.GetType() != value.GetType()) + ResetTypeInfo(); _value = value; - ConvertedValue = null; } } @@ -314,27 +310,25 @@ public sealed override DbType DbType { get { - if (_npgsqlDbType.HasValue) - return GlobalTypeMapper.NpgsqlDbTypeToDbType(_npgsqlDbType.Value); + if (_npgsqlDbType is { } npgsqlDbType) + return npgsqlDbType.ToDbType(); if (_dataTypeName is not null) - return GlobalTypeMapper.NpgsqlDbTypeToDbType(GlobalTypeMapper.DataTypeNameToNpgsqlDbType(_dataTypeName)); + return Internal.Postgres.DataTypeName.FromDisplayName(_dataTypeName).ToNpgsqlDbType()?.ToDbType() ?? DbType.Object; - if (Value is not null) // Infer from value but don't cache - { - return GlobalTypeMapper.Instance.TryResolveMappingByValue(Value, out var mapping) - ? mapping.DbType - : DbType.Object; - } + // Infer from value but don't cache + if (Value is not null) + // We pass ValueType here for the generic derived type, where we should respect T and not the runtime type. + return GlobalTypeMapper.Instance.TryGetDataTypeName(GetValueType(StaticValueType)!, Value)?.ToNpgsqlDbType()?.ToDbType() ?? DbType.Object; return DbType.Object; } set { - Handler = null; + ResetTypeInfo(); _npgsqlDbType = value == DbType.Object ? null - : GlobalTypeMapper.DbTypeToNpgsqlDbType(value) + : value.ToNpgsqlDbType() ?? 
throw new NotSupportedException($"The parameter type DbType.{value} isn't supported by PostgreSQL or Npgsql"); } } @@ -355,14 +349,12 @@ public NpgsqlDbType NpgsqlDbType return _npgsqlDbType.Value; if (_dataTypeName is not null) - return GlobalTypeMapper.DataTypeNameToNpgsqlDbType(_dataTypeName); + return Internal.Postgres.DataTypeName.FromDisplayName(_dataTypeName).ToNpgsqlDbType() ?? NpgsqlDbType.Unknown; - if (Value is not null) // Infer from value - { - return GlobalTypeMapper.Instance.TryResolveMappingByValue(Value, out var mapping) - ? mapping.NpgsqlDbType ?? NpgsqlDbType.Unknown - : throw new NotSupportedException("Can't infer NpgsqlDbType for type " + Value.GetType()); - } + // Infer from value but don't cache + if (Value is not null) + // We pass ValueType here for the generic derived type (NpgsqlParameter) where we should respect T and not the runtime type. + return GlobalTypeMapper.Instance.TryGetDataTypeName(GetValueType(StaticValueType)!, Value)?.ToNpgsqlDbType() ?? NpgsqlDbType.Unknown; return NpgsqlDbType.Unknown; } @@ -373,7 +365,7 @@ public NpgsqlDbType NpgsqlDbType if (value == NpgsqlDbType.Range) throw new ArgumentOutOfRangeException(nameof(value), "Cannot set NpgsqlDbType to just Range, Binary-Or with the element type (e.g. Range of integer is NpgsqlDbType.Range | NpgsqlDbType.Integer)"); - Handler = null; + ResetTypeInfo(); _npgsqlDbType = value; } } @@ -388,22 +380,25 @@ public string? DataTypeName if (_dataTypeName != null) return _dataTypeName; - if (_npgsqlDbType.HasValue) - return GlobalTypeMapper.NpgsqlDbTypeToDataTypeName(_npgsqlDbType.Value); - - if (Value != null) // Infer from value + // Map it to a display name. + if (_npgsqlDbType is { } npgsqlDbType) { - return GlobalTypeMapper.Instance.TryResolveMappingByValue(Value, out var mapping) - ? mapping.DataTypeName - : null; + var unqualifiedName = npgsqlDbType.ToUnqualifiedDataTypeName(); + return unqualifiedName is null ? 
null : Internal.Postgres.DataTypeName.ValidatedName( + "pg_catalog." + unqualifiedName).UnqualifiedDisplayName; } + // Infer from value but don't cache + if (Value is not null) + // We pass ValueType here for the generic derived type, where we should respect T and not the runtime type. + return GlobalTypeMapper.Instance.TryGetDataTypeName(GetValueType(StaticValueType)!, Value)?.DisplayName; + return null; } set { + ResetTypeInfo(); _dataTypeName = value; - Handler = null; } } @@ -431,11 +426,7 @@ public string? DataTypeName public new byte Precision { get => _precision; - set - { - _precision = value; - Handler = null; - } + set => _precision = value; } /// @@ -447,11 +438,7 @@ public string? DataTypeName public new byte Scale { get => _scale; - set - { - _scale = value; - Handler = null; - } + set => _scale = value; } #pragma warning restore CS0109 @@ -466,8 +453,8 @@ public sealed override int Size if (value < -1) throw new ArgumentException($"Invalid parameter Size value '{value}'. The value must be greater than or equal to 0."); + ResetBindingInfo(); _size = value; - Handler = null; } } @@ -506,60 +493,247 @@ public sealed override string SourceColumn #region Internals - internal virtual void ResolveHandler(TypeMapper typeMapper) - { - if (Handler is not null) - return; + private protected virtual Type StaticValueType => typeof(object); - Resolve(typeMapper); + Type? GetValueType(Type staticValueType) => staticValueType != typeof(object) ? staticValueType : Value?.GetType(); - void Resolve(TypeMapper typeMapper) + /// Attempt to resolve a type info based on available (postgres) type information on the parameter. + internal void ResolveTypeInfo(PgSerializerOptions options) + { + var previouslyBound = TypeInfo?.Options == options; + if (!previouslyBound) { - if (_npgsqlDbType.HasValue) - Handler = typeMapper.ResolveByNpgsqlDbType(_npgsqlDbType.Value); + var staticValueType = StaticValueType; + var valueType = GetValueType(StaticValueType); + + string? 
dataTypeName = null; + DataTypeName? builtinDataTypeName = null; + if (_npgsqlDbType is { } npgsqlDbType) + { + dataTypeName = npgsqlDbType.ToUnqualifiedDataTypeNameOrThrow(); + builtinDataTypeName = npgsqlDbType.ToDataTypeName(); + } else if (_dataTypeName is not null) - Handler = typeMapper.ResolveByDataTypeName(_dataTypeName); - else if (_value is not null) - Handler = typeMapper.ResolveByValue(_value); - else - ThrowInvalidOperationException(); + { + dataTypeName = Internal.Postgres.DataTypeName.NormalizeName(_dataTypeName); + // If we can find a match in an NpgsqlDbType we known we're dealing with a fully qualified built-in data type name. + builtinDataTypeName = NpgsqlDbTypeExtensions.ToNpgsqlDbType(dataTypeName)?.ToDataTypeName(); + } + + var pgTypeId = dataTypeName is null + ? (PgTypeId?)null + : TryGetRepresentationalTypeId(builtinDataTypeName ?? dataTypeName, out var id) + ? id + : throw new NotSupportedException(_npgsqlDbType is not null + ? $"The NpgsqlDbType '{_npgsqlDbType}' isn't present in your database. You may need to install an extension or upgrade to a newer version." + : $"The data type name '{builtinDataTypeName ?? dataTypeName}' isn't present in your database. You may need to install an extension or upgrade to a newer version."); + + if (staticValueType == typeof(object)) + { + if (valueType == null && pgTypeId is null) + { + var parameterName = !string.IsNullOrEmpty(ParameterName) ? ParameterName : $"${Collection?.IndexOf(this) + 1}"; + ThrowHelper.ThrowInvalidOperationException( + $"Parameter '{parameterName}' must have either its NpgsqlDbType or its DataTypeName or its Value set."); + return; + } + + // We treat object typed DBNull values as default info. 
+ if (valueType == typeof(DBNull)) + { + valueType = null; + pgTypeId ??= options.ToCanonicalTypeId(options.UnknownPgType); + } + } + + TypeInfo = AdoSerializerHelpers.GetTypeInfoForWriting(valueType, pgTypeId, options, _npgsqlDbType); + } + + // This step isn't part of BindValue because we need to know the PgTypeId beforehand for things like SchemaOnly with null values. + // We never reuse resolutions for resolvers across executions as a mutable value itself may influence the result. + // TODO we could expose a property on a Converter/TypeInfo to indicate whether it's immutable, at that point we can reuse. + if (!previouslyBound || TypeInfo is PgResolverTypeInfo) + { + ResetConverterResolution(); + var resolution = ResolveConverter(TypeInfo!); + Converter = resolution.Converter; + PgTypeId = resolution.PgTypeId; } - void ThrowInvalidOperationException() + bool TryGetRepresentationalTypeId(string dataTypeName, out PgTypeId pgTypeId) { - var parameterName = !string.IsNullOrEmpty(ParameterName) ? ParameterName : $"${Collection?.IndexOf(this) + 1}"; - ThrowHelper.ThrowInvalidOperationException($"Parameter '{parameterName}' must have either its NpgsqlDbType or its DataTypeName or its Value set"); + if (options.DatabaseInfo.TryGetPostgresTypeByName(dataTypeName, out var pgType)) + { + pgTypeId = options.ToCanonicalTypeId(pgType.GetRepresentationalType()); + return true; + } + + pgTypeId = default; + return false; } } - internal void Bind(TypeMapper typeMapper) + // Pull from Value so we also support object typed generic params. + private protected virtual PgConverterResolution ResolveConverter(PgTypeInfo typeInfo) => typeInfo.GetObjectResolution(Value); + + /// Bind the current value to the type info, truncate (if applicable), take its size, and do any final validation before writing. + internal void Bind(out DataFormat format, out Size size) { - ResolveHandler(typeMapper); - FormatCode = Handler!.PreferTextWrite ? 
FormatCode.Text : FormatCode.Binary; + if (TypeInfo is null) + ThrowHelper.ThrowInvalidOperationException($"Missing type info, {nameof(ResolveTypeInfo)} needs to be called before {nameof(Bind)}."); + + if (!TypeInfo.SupportsWriting) + ThrowHelper.ThrowNotSupportedException($"Cannot write values for parameters of type '{TypeInfo.Type}' and postgres type '{TypeInfo.Options.DatabaseInfo.GetDataTypeName(PgTypeId).DisplayName}'."); + + // We might call this twice, once during validation and once during WriteBind, only compute things once. + if (WriteSize is not null) + { + format = Format; + size = WriteSize.Value; + return; + } + + // Handle Size truncate behavior for a predetermined set of types and pg types. + // Doesn't matter if we 'box' Value, all supported types are reference types. + if (_size > 0 && Converter!.TypeToConvert is var type && + (type == typeof(string) || type == typeof(char[]) || type == typeof(byte[]) || type == typeof(Stream)) && + Value is { } value) + { + var dataTypeName = TypeInfo!.Options.GetDataTypeName(PgTypeId); + if (dataTypeName == DataTypeNames.Text || dataTypeName == DataTypeNames.Varchar || dataTypeName == DataTypeNames.Bpchar) + { + if (value is string s && s.Length > _size) + Value = s.Substring(0, _size); + else if (value is char[] chars && chars.Length > _size) + { + var truncated = new char[_size]; + Array.Copy(chars, truncated, _size); + Value = truncated; + } + } + else if (dataTypeName == DataTypeNames.Bytea) + { + if (value is byte[] bytes && bytes.Length > _size) + { + var truncated = new byte[_size]; + Array.Copy(bytes, truncated, _size); + Value = truncated; + } + else if (value is Stream) + _useSubStream = true; + } + } + + BindCore(); + format = Format; + size = WriteSize!.Value; + } + + private protected virtual void BindCore(bool allowNullReference = false) + { + // Pull from Value so we also support object typed generic params. 
+ var value = Value; + if (value is null && !allowNullReference) + ThrowHelper.ThrowInvalidOperationException($"Parameter '{ParameterName}' cannot be null, DBNull.Value should be used instead."); + + if (_useSubStream && value is not null) + value = _subStream = new SubReadStream((Stream)value, _size); + + if (TypeInfo!.BindObject(Converter!, value, out var size, out _writeState, out var dataFormat) is { } info) + { + WriteSize = size; + _bufferRequirement = info.BufferRequirement; + } + else + { + WriteSize = -1; + _bufferRequirement = default; + } + Format = dataFormat; } - internal virtual int ValidateAndGetLength() + internal async ValueTask Write(bool async, PgWriter writer, CancellationToken cancellationToken) { - if (_value is DBNull) - return 0; - if (_value == null) - ThrowHelper.ThrowInvalidCastException("Parameter {0} must be set", ParameterName); - - var lengthCache = LengthCache; - var len = Handler!.ValidateObjectAndGetLength(_value, ref lengthCache, this); - LengthCache = lengthCache; - return len; + if (WriteSize is not { } writeSize) + { + ThrowHelper.ThrowInvalidOperationException("Missing type info or binding info."); + return; + } + + try + { + if (writer.ShouldFlush(sizeof(int))) + await writer.Flush(async, cancellationToken); + + writer.WriteInt32(writeSize.Value); + if (writeSize.Value is -1) + { + writer.Commit(sizeof(int)); + return; + } + + var current = new ValueMetadata + { + Format = Format, + BufferRequirement = _bufferRequirement, + Size = writeSize, + WriteState = _writeState + }; + await writer.BeginWrite(async, current, cancellationToken).ConfigureAwait(false); + await WriteValue(async, writer, cancellationToken); + writer.Commit(writeSize.Value + sizeof(int)); + } + finally + { + ResetBindingInfo(); + } } - internal virtual Task WriteWithLength(NpgsqlWriteBuffer buf, bool async, CancellationToken cancellationToken = default) - => Handler!.WriteObjectWithLength(_value!, buf, LengthCache, this, async, cancellationToken); + private 
protected virtual ValueTask WriteValue(bool async, PgWriter writer, CancellationToken cancellationToken) + { + // Pull from Value so we also support base calls from generic parameters. + var value = (_useSubStream ? _subStream : Value)!; + if (async) + return Converter!.WriteAsObjectAsync(writer, value, cancellationToken); + + Converter!.WriteAsObject(writer, value); + return new(); + } /// public override void ResetDbType() { _npgsqlDbType = null; _dataTypeName = null; - Handler = null; + ResetTypeInfo(); + } + + private protected void ResetTypeInfo() + { + TypeInfo = null; + ResetConverterResolution(); + } + + void ResetConverterResolution() + { + Converter = null; + PgTypeId = default; + ResetBindingInfo(); + } + + void ResetBindingInfo() + { + if (_writeState is not null) + TypeInfo?.DisposeWriteState(_writeState); + if (_useSubStream) + { + _useSubStream = false; + _subStream?.Dispose(); + _subStream = null; + } + WriteSize = null; + Format = default; + _bufferRequirement = default; } internal bool IsInputDirection => Direction == ParameterDirection.InputOutput || Direction == ParameterDirection.Input; @@ -599,4 +773,4 @@ private protected virtual NpgsqlParameter CloneCore() => object ICloneable.Clone() => Clone(); #endregion -} \ No newline at end of file +} diff --git a/src/Npgsql/NpgsqlParameterCollection.cs b/src/Npgsql/NpgsqlParameterCollection.cs index 3f1e139b08..58c0315753 100644 --- a/src/Npgsql/NpgsqlParameterCollection.cs +++ b/src/Npgsql/NpgsqlParameterCollection.cs @@ -5,9 +5,7 @@ using System.Data.Common; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; -using System.Runtime.CompilerServices; -using Npgsql.Internal.TypeMapping; -using Npgsql.TypeMapping; +using Npgsql.Internal; using NpgsqlTypes; namespace Npgsql; @@ -38,7 +36,7 @@ static NpgsqlParameterCollection() /// /// Initializes a new instance of the NpgsqlParameterCollection class. 
/// - internal NpgsqlParameterCollection() {} + internal NpgsqlParameterCollection() { } bool LookupEnabled => InternalList.Count >= LookupThreshold; @@ -681,14 +679,15 @@ internal void CloneTo(NpgsqlParameterCollection other) } } - internal void ProcessParameters(TypeMapper typeMapper, bool validateValues, CommandType commandType) + internal void ProcessParameters(PgSerializerOptions options, bool validateValues, CommandType commandType) { HasOutputParameters = false; PlaceholderType = PlaceholderType.NoParameters; - for (var i = 0; i < InternalList.Count; i++) + var list = InternalList; + for (var i = 0; i < list.Count; i++) { - var p = InternalList[i]; + var p = list[i]; switch (PlaceholderType) { @@ -737,12 +736,11 @@ internal void ProcessParameters(TypeMapper typeMapper, bool validateValues, Comm break; } - p.Bind(typeMapper); + p.ResolveTypeInfo(options); if (validateValues) { - p.LengthCache?.Clear(); - p.ValidateAndGetLength(); + p.Bind(out _, out _); } } } diff --git a/src/Npgsql/NpgsqlParameter`.cs b/src/Npgsql/NpgsqlParameter`.cs index a0487a9aec..18ac5aff45 100644 --- a/src/Npgsql/NpgsqlParameter`.cs +++ b/src/Npgsql/NpgsqlParameter`.cs @@ -1,12 +1,11 @@ using System; using System.Data; +using System.Diagnostics; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; using Npgsql.Internal; -using Npgsql.Internal.TypeMapping; -using Npgsql.TypeMapping; using NpgsqlTypes; -using static Npgsql.Util.Statics; namespace Npgsql; @@ -17,10 +16,21 @@ namespace Npgsql; /// The type of the value that will be stored in the parameter. public sealed class NpgsqlParameter : NpgsqlParameter { + T? _typedValue; + /// /// Gets or sets the strongly-typed value of the parameter. /// - public T? TypedValue { get; set; } + public T? 
TypedValue + { + get => _typedValue; + set + { + if (typeof(T) == typeof(object) && (value is null || _typedValue?.GetType() != value.GetType())) + ResetTypeInfo(); + _typedValue = value; + } + } /// /// Gets or sets the value of the parameter. This delegates to . @@ -31,12 +41,14 @@ public override object? Value set => TypedValue = (T)value!; } + private protected override Type StaticValueType => typeof(T); + #region Constructors /// /// Initializes a new instance of . /// - public NpgsqlParameter() {} + public NpgsqlParameter() { } /// /// Initializes a new instance of with a parameter name and value. @@ -67,33 +79,45 @@ public NpgsqlParameter(string parameterName, DbType dbType) #endregion Constructors - internal override void ResolveHandler(TypeMapper typeMapper) + private protected override PgConverterResolution ResolveConverter(PgTypeInfo typeInfo) + => typeInfo.IsBoxing ? base.ResolveConverter(typeInfo) : typeInfo.GetResolution(TypedValue); + + private protected override void BindCore(bool allowNullReference = false) { - if (Handler is not null) + // If we're object typed we should support DBNull, call into base BindCore. + if (typeof(T) == typeof(object) || TypeInfo!.IsBoxing || _useSubStream) + { + base.BindCore(TypeInfo!.IsBoxing || _useSubStream || allowNullReference); return; + } - // TODO: Better exceptions in case of cast failure etc. 
- if (_npgsqlDbType.HasValue) - Handler = typeMapper.ResolveByNpgsqlDbType(_npgsqlDbType.Value); - else if (_dataTypeName is not null) - Handler = typeMapper.ResolveByDataTypeName(_dataTypeName); + var value = TypedValue; + Debug.Assert(Converter is PgConverter); + if (TypeInfo!.Bind(Unsafe.As>(Converter), value, out var size, out _writeState, out var dataFormat) is { } info) + { + WriteSize = size; + _bufferRequirement = info.BufferRequirement; + } else - Handler = typeMapper.ResolveByValue(TypedValue); + { + WriteSize = -1; + _bufferRequirement = default; + } + Format = dataFormat; } - internal override int ValidateAndGetLength() + private protected override ValueTask WriteValue(bool async, PgWriter writer, CancellationToken cancellationToken) { - if (TypedValue is null or DBNull) - return 0; + if (TypeInfo!.IsBoxing || _useSubStream) + return base.WriteValue(async, writer, cancellationToken); - var lengthCache = LengthCache; - var len = Handler!.ValidateAndGetLength(TypedValue, ref lengthCache, this); - LengthCache = lengthCache; - return len; - } + Debug.Assert(Converter is PgConverter); + if (async) + return Unsafe.As>(Converter!).WriteAsync(writer, TypedValue!, cancellationToken); - internal override Task WriteWithLength(NpgsqlWriteBuffer buf, bool async, CancellationToken cancellationToken = default) - => Handler!.WriteWithLength(TypedValue, buf, LengthCache, this, async, cancellationToken); + Unsafe.As>(Converter!).Write(writer, TypedValue!); + return new(); + } private protected override NpgsqlParameter CloneCore() => // use fields instead of properties @@ -114,4 +138,4 @@ private protected override NpgsqlParameter CloneCore() => TypedValue = TypedValue, SourceColumnNullMapping = SourceColumnNullMapping, }; -} \ No newline at end of file +} diff --git a/src/Npgsql/NpgsqlSchema.cs b/src/Npgsql/NpgsqlSchema.cs index e8d65ecbf1..461ae2e873 100644 --- a/src/Npgsql/NpgsqlSchema.cs +++ b/src/Npgsql/NpgsqlSchema.cs @@ -6,6 +6,7 @@ using System.Text; using 
System.Threading; using System.Threading.Tasks; +using Npgsql.Internal; using Npgsql.PostgresTypes; using NpgsqlTypes; @@ -556,106 +557,110 @@ static DataTable GetDataTypes(NpgsqlConnection conn) // Npgsql-specific table.Columns.Add("OID", typeof(uint)); - // TODO: Support type name restriction - foreach (var baseType in connector.DatabaseInfo.BaseTypes.Cast() - .Concat(connector.DatabaseInfo.EnumTypes) - .Concat(connector.DatabaseInfo.CompositeTypes)) + // TODO: Support type name restriction + try { - if (!connector.TypeMapper.TryGetMapping(baseType, out var mapping)) - continue; + PgSerializerOptions.IntrospectionCaller = true; + foreach (var baseType in connector.DatabaseInfo.BaseTypes.Cast() + .Concat(connector.DatabaseInfo.EnumTypes) + .Concat(connector.DatabaseInfo.CompositeTypes)) + { + if (connector.SerializerOptions.GetDefaultTypeInfo(baseType) is not { } info) + continue; - var row = table.Rows.Add(); + var row = table.Rows.Add(); - PopulateDefaultDataTypeInfo(row, baseType); - PopulateHardcodedDataTypeInfo(row, baseType); + PopulateDefaultDataTypeInfo(row, baseType); + PopulateHardcodedDataTypeInfo(row, baseType); - if (mapping.ClrTypes.Length > 0) - row["DataType"] = mapping.ClrTypes[0].FullName; - if (mapping.NpgsqlDbType.HasValue) - row["ProviderDbType"] = (int)mapping.NpgsqlDbType.Value; - } + row["DataType"] = info.Type.FullName; + if (baseType.DataTypeName.ToNpgsqlDbType() is { } npgsqlDbType) + row["ProviderDbType"] = (int)npgsqlDbType; + } - foreach (var arrayType in connector.DatabaseInfo.ArrayTypes) - { - if (!connector.TypeMapper.TryGetMapping(arrayType.Element, out var elementMapping)) - continue; - - var row = table.Rows.Add(); - - PopulateDefaultDataTypeInfo(row, arrayType.Element); - // Populate hardcoded values based on the element type (e.g. citext[] is case-insensitive). 
- PopulateHardcodedDataTypeInfo(row, arrayType.Element); - - row["TypeName"] = arrayType.DisplayName; - row["OID"] = arrayType.OID; - row["CreateFormat"] += "[]"; - if (elementMapping.ClrTypes.Length > 0) - row["DataType"] = elementMapping.ClrTypes[0].MakeArrayType().FullName; - if (elementMapping.NpgsqlDbType.HasValue) - row["ProviderDbType"] = (int)(elementMapping.NpgsqlDbType.Value | NpgsqlDbType.Array); - } + foreach (var arrayType in connector.DatabaseInfo.ArrayTypes) + { + if (connector.SerializerOptions.GetDefaultTypeInfo(arrayType) is not { } info) + continue; - foreach (var rangeType in connector.DatabaseInfo.RangeTypes) - { - if (!connector.TypeMapper.TryGetMapping(rangeType.Subtype, out var subtypeMapping)) - continue; - - var row = table.Rows.Add(); - - PopulateDefaultDataTypeInfo(row, rangeType.Subtype); - // Populate hardcoded values based on the subtype type (e.g. citext[] is case-insensitive). - PopulateHardcodedDataTypeInfo(row, rangeType.Subtype); - - row["TypeName"] = rangeType.DisplayName; - row["OID"] = rangeType.OID; - row["CreateFormat"] = rangeType.DisplayName.ToUpperInvariant(); - if (subtypeMapping.ClrTypes.Length > 0) - row["DataType"] = typeof(NpgsqlRange<>).MakeGenericType(subtypeMapping.ClrTypes[0]).FullName; - if (subtypeMapping.NpgsqlDbType.HasValue) - row["ProviderDbType"] = (int)(subtypeMapping.NpgsqlDbType.Value | NpgsqlDbType.Range); - } + var row = table.Rows.Add(); - foreach (var multirangeType in connector.DatabaseInfo.MultirangeTypes) - { - var subtypeType = multirangeType.Subrange.Subtype; - if (!connector.TypeMapper.TryGetMapping(subtypeType, out var subtypeMapping)) - continue; - - var row = table.Rows.Add(); - - PopulateDefaultDataTypeInfo(row, subtypeType); - // Populate hardcoded values based on the subtype type (e.g. citext[] is case-insensitive). 
- PopulateHardcodedDataTypeInfo(row, subtypeType); - - row["TypeName"] = multirangeType.DisplayName; - row["OID"] = multirangeType.OID; - row["CreateFormat"] = multirangeType.DisplayName.ToUpperInvariant(); - if (subtypeMapping.ClrTypes.Length > 0) - row["DataType"] = typeof(NpgsqlRange<>).MakeGenericType(subtypeMapping.ClrTypes[0]).FullName; - if (subtypeMapping.NpgsqlDbType.HasValue) - row["ProviderDbType"] = (int)(subtypeMapping.NpgsqlDbType.Value | NpgsqlDbType.Range); - } + PopulateDefaultDataTypeInfo(row, arrayType.Element); + // Populate hardcoded values based on the element type (e.g. citext[] is case-insensitive). + PopulateHardcodedDataTypeInfo(row, arrayType.Element); - foreach (var domainType in connector.DatabaseInfo.DomainTypes) - { - if (!connector.TypeMapper.TryGetMapping(domainType, out var baseMapping)) - continue; + row["TypeName"] = arrayType.DisplayName; + row["OID"] = arrayType.OID; + row["CreateFormat"] += "[]"; + row["DataType"] = info.Type.FullName; + if (arrayType.DataTypeName.ToNpgsqlDbType() is { } npgsqlDbType) + row["ProviderDbType"] = (int)npgsqlDbType; + } + + foreach (var rangeType in connector.DatabaseInfo.RangeTypes) + { + if (connector.SerializerOptions.GetDefaultTypeInfo(rangeType) is not { } info) + continue; - var row = table.Rows.Add(); + var row = table.Rows.Add(); - PopulateDefaultDataTypeInfo(row, domainType.BaseType); - // Populate hardcoded values based on the element type (e.g. citext[] is case-insensitive). - PopulateHardcodedDataTypeInfo(row, domainType.BaseType); - row["TypeName"] = domainType.DisplayName; - row["OID"] = domainType.OID; - // A domain is never the best match, since its underlying base type is - row["IsBestMatch"] = false; + PopulateDefaultDataTypeInfo(row, rangeType.Subtype); + // Populate hardcoded values based on the subtype type (e.g. citext[] is case-insensitive). 
+ PopulateHardcodedDataTypeInfo(row, rangeType.Subtype); - if (baseMapping.ClrTypes.Length > 0) - row["DataType"] = baseMapping.ClrTypes[0].FullName; - if (baseMapping.NpgsqlDbType.HasValue) - row["ProviderDbType"] = (int)baseMapping.NpgsqlDbType.Value; + row["TypeName"] = rangeType.DisplayName; + row["OID"] = rangeType.OID; + row["CreateFormat"] = rangeType.DisplayName.ToUpperInvariant(); + row["DataType"] = info.Type.FullName; + if (rangeType.DataTypeName.ToNpgsqlDbType() is { } npgsqlDbType) + row["ProviderDbType"] = (int)npgsqlDbType; + } + + foreach (var multirangeType in connector.DatabaseInfo.MultirangeTypes) + { + var subtypeType = multirangeType.Subrange.Subtype; + if (connector.SerializerOptions.GetDefaultTypeInfo(multirangeType) is not { } info) + continue; + + var row = table.Rows.Add(); + + PopulateDefaultDataTypeInfo(row, subtypeType); + // Populate hardcoded values based on the subtype type (e.g. citext[] is case-insensitive). + PopulateHardcodedDataTypeInfo(row, subtypeType); + + row["TypeName"] = multirangeType.DisplayName; + row["OID"] = multirangeType.OID; + row["CreateFormat"] = multirangeType.DisplayName.ToUpperInvariant(); + row["DataType"] = info.Type.FullName; + if (multirangeType.DataTypeName.ToNpgsqlDbType() is { } npgsqlDbType) + row["ProviderDbType"] = (int)npgsqlDbType; + } + + foreach (var domainType in connector.DatabaseInfo.DomainTypes) + { + var representationalType = domainType.GetRepresentationalType(); + if (connector.SerializerOptions.GetDefaultTypeInfo(representationalType) is not { } info) + continue; + + var row = table.Rows.Add(); + + PopulateDefaultDataTypeInfo(row, representationalType); + // Populate hardcoded values based on the element type (e.g. citext[] is case-insensitive). 
+ PopulateHardcodedDataTypeInfo(row, representationalType); + row["TypeName"] = domainType.DisplayName; + row["OID"] = domainType.OID; + // A domain is never the best match, since its underlying base type is + row["IsBestMatch"] = false; + + row["DataType"] = info.Type.FullName; + if (representationalType.DataTypeName.ToNpgsqlDbType() is { } npgsqlDbType) + row["ProviderDbType"] = (int)npgsqlDbType; + } + } + finally + { + PgSerializerOptions.IntrospectionCaller = false; } return table; diff --git a/src/Npgsql/NpgsqlSlimDataSourceBuilder.cs b/src/Npgsql/NpgsqlSlimDataSourceBuilder.cs index 97a7dd7a34..c074e29dec 100644 --- a/src/Npgsql/NpgsqlSlimDataSourceBuilder.cs +++ b/src/Npgsql/NpgsqlSlimDataSourceBuilder.cs @@ -1,16 +1,15 @@ using System; using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; +using System.Linq; using System.Net.Security; -using System.Reflection; using System.Security.Cryptography.X509Certificates; using System.Text.Json; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.Logging; using Npgsql.Internal; -using Npgsql.Internal.TypeHandling; -using Npgsql.Internal.TypeMapping; +using Npgsql.Internal.Resolvers; using Npgsql.Properties; using Npgsql.TypeMapping; using NpgsqlTypes; @@ -26,6 +25,8 @@ namespace Npgsql; /// public sealed class NpgsqlSlimDataSourceBuilder : INpgsqlTypeMapper { + static UnsupportedTypeInfoResolver UnsupportedTypeInfoResolver { get; } = new(); + ILoggerFactory? _loggerFactory; bool _sensitiveDataLoggingEnabled; @@ -36,11 +37,8 @@ public sealed class NpgsqlSlimDataSourceBuilder : INpgsqlTypeMapper Func>? 
_periodicPasswordProvider; TimeSpan _periodicPasswordSuccessRefreshInterval, _periodicPasswordFailureRefreshInterval; - readonly List _resolverFactories = new(); - readonly Dictionary _userTypeMappings = new(); - - /// - public INpgsqlNameTranslator DefaultNameTranslator { get; set; } = GlobalTypeMapper.Instance.DefaultNameTranslator; + readonly List _resolverChain = new(); + readonly UserTypeMapper _userTypeMapper; Action? _syncConnectionInitializer; Func? _asyncConnectionInitializer; @@ -55,6 +53,12 @@ public sealed class NpgsqlSlimDataSourceBuilder : INpgsqlTypeMapper /// public string ConnectionString => ConnectionStringBuilder.ToString(); + static NpgsqlSlimDataSourceBuilder() + => GlobalTypeMapper.Instance.AddGlobalTypeMappingResolvers(new [] + { + AdoTypeInfoResolver.Instance + }); + /// /// A diagnostics name used by Npgsql when generating tracing, logging and metrics. /// @@ -67,8 +71,19 @@ public sealed class NpgsqlSlimDataSourceBuilder : INpgsqlTypeMapper public NpgsqlSlimDataSourceBuilder(string? connectionString = null) { ConnectionStringBuilder = new NpgsqlConnectionStringBuilder(connectionString); + _userTypeMapper = new(); + // Reverse order + AddTypeInfoResolver(UnsupportedTypeInfoResolver); + AddTypeInfoResolver(new AdoTypeInfoResolver()); + // When used publicly we start off with our slim defaults. 
+ foreach (var plugin in GlobalTypeMapper.Instance.GetPluginResolvers().Reverse()) + AddTypeInfoResolver(plugin); + } - ResetTypeMappings(); + internal NpgsqlSlimDataSourceBuilder(NpgsqlConnectionStringBuilder connectionStringBuilder) + { + ConnectionStringBuilder = connectionStringBuilder; + _userTypeMapper = new(); } /// @@ -237,158 +252,105 @@ public NpgsqlSlimDataSourceBuilder UsePeriodicPasswordProvider( #region Type mapping /// - public void AddTypeResolverFactory(TypeHandlerResolverFactory resolverFactory) - { - var type = resolverFactory.GetType(); - - for (var i = 0; i < _resolverFactories.Count; i++) - { - if (_resolverFactories[i].GetType() == type) - { - _resolverFactories.RemoveAt(i); - break; - } - } - - _resolverFactories.Insert(0, resolverFactory); - } - - internal void AddDefaultTypeResolverFactory(TypeHandlerResolverFactory resolverFactory) - { - // For these "default" resolvers: - // 1. If they were already added in the global type mapper, we don't want to replace them (there may be custom user config, e.g. - // for JSON. - // 2. They can't be at the start, since then they'd override a user-added resolver in global (e.g. the range handler would override - // NodaTime, but NodaTime has special handling for tstzrange, mapping it to Interval in addition to NpgsqlRange). - // 3. They also can't be at the end, since then they'd be overridden by builtin (builtin has limited JSON handler, but we want - // the System.Text.Json handler to take precedence. - // So we (currently) add these at the end, but before the builtin resolver. 
- var type = resolverFactory.GetType(); - - // 1st pass to skip if the resolver already exists from the global type mapper - for (var i = 0; i < _resolverFactories.Count; i++) - if (_resolverFactories[i].GetType() == type) - return; - - for (var i = 0; i < _resolverFactories.Count; i++) - { - if (_resolverFactories[i] is BuiltInTypeHandlerResolverFactory) - { - _resolverFactories.Insert(i, resolverFactory); - return; - } - } - - throw new Exception("No built-in resolver factory found"); - } + public INpgsqlNameTranslator DefaultNameTranslator { get; set; } = GlobalTypeMapper.Instance.DefaultNameTranslator; /// public INpgsqlTypeMapper MapEnum(string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) where TEnum : struct, Enum { - if (pgName != null && pgName.Trim() == "") - throw new ArgumentException("pgName can't be empty", nameof(pgName)); - - nameTranslator ??= DefaultNameTranslator; - pgName ??= GetPgName(typeof(TEnum), nameTranslator); - - _userTypeMappings[pgName] = new UserEnumTypeMapping(pgName, nameTranslator); + _userTypeMapper.MapEnum(pgName, nameTranslator); return this; } /// public bool UnmapEnum(string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) where TEnum : struct, Enum - { - if (pgName != null && pgName.Trim() == "") - throw new ArgumentException("pgName can't be empty", nameof(pgName)); - - nameTranslator ??= DefaultNameTranslator; - pgName ??= GetPgName(typeof(TEnum), nameTranslator); - - return _userTypeMappings.Remove(pgName); - } + => _userTypeMapper.UnmapEnum(pgName, nameTranslator); /// - [RequiresUnreferencedCode("Composite type mapping currently isn't trimming-safe.")] + [RequiresUnreferencedCode("Composite type mapping isn't trimming-safe.")] public INpgsqlTypeMapper MapComposite(string? pgName = null, INpgsqlNameTranslator? 
nameTranslator = null) { - if (pgName != null && pgName.Trim() == "") - throw new ArgumentException("pgName can't be empty", nameof(pgName)); - - nameTranslator ??= DefaultNameTranslator; - pgName ??= GetPgName(typeof(T), nameTranslator); - - _userTypeMappings[pgName] = new UserCompositeTypeMapping(pgName, nameTranslator); + _userTypeMapper.MapComposite(typeof(T), pgName, nameTranslator); return this; } /// - [RequiresUnreferencedCode("Composite type mapping currently isn't trimming-safe.")] + [RequiresUnreferencedCode("Composite type mapping isn't trimming-safe.")] + public bool UnmapComposite(string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) + => _userTypeMapper.UnmapComposite(typeof(T), pgName, nameTranslator); + + /// + [RequiresUnreferencedCode("Composite type mapping isn't trimming-safe.")] public INpgsqlTypeMapper MapComposite(Type clrType, string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) { - var openMethod = typeof(NpgsqlSlimDataSourceBuilder).GetMethod(nameof(MapComposite), new[] { typeof(string), typeof(INpgsqlNameTranslator) })!; - var method = openMethod.MakeGenericMethod(clrType); - method.Invoke(this, new object?[] { pgName, nameTranslator }); - + _userTypeMapper.MapComposite(clrType, pgName, nameTranslator); return this; } /// - [RequiresUnreferencedCode("Composite type mapping currently isn't trimming-safe.")] - public bool UnmapComposite(string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) - => UnmapComposite(typeof(T), pgName, nameTranslator); - - /// - [RequiresUnreferencedCode("Composite type mapping currently isn't trimming-safe.")] + [RequiresUnreferencedCode("Composite type mapping isn't trimming-safe.")] public bool UnmapComposite(Type clrType, string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) + => _userTypeMapper.UnmapComposite(clrType, pgName, nameTranslator); + + /// + /// Adds a type info resolver which can add or modify support for PostgreSQL types. 
+ /// Typically used by plugins. + /// + /// The type resolver to be added. + public void AddTypeInfoResolver(IPgTypeInfoResolver resolver) { - if (pgName != null && pgName.Trim() == "") - throw new ArgumentException("pgName can't be empty", nameof(pgName)); + var type = resolver.GetType(); - nameTranslator ??= DefaultNameTranslator; - pgName ??= GetPgName(clrType, nameTranslator); + for (var i = 0; i < _resolverChain.Count; i++) + if (_resolverChain[i].GetType() == type) + { + _resolverChain.RemoveAt(i); + break; + } - return _userTypeMappings.Remove(pgName); + _resolverChain.Insert(0, resolver); } void INpgsqlTypeMapper.Reset() => ResetTypeMappings(); - void ResetTypeMappings() + internal void ResetTypeMappings() { - var globalMapper = GlobalTypeMapper.Instance; - globalMapper.Lock.EnterReadLock(); - try - { - _resolverFactories.Clear(); - foreach (var resolverFactory in globalMapper.HandlerResolverFactories) - _resolverFactories.Add(resolverFactory); - - _userTypeMappings.Clear(); - foreach (var kv in globalMapper.UserTypeMappings) - _userTypeMappings[kv.Key] = kv.Value; - } - finally - { - globalMapper.Lock.ExitReadLock(); - } + _resolverChain.Clear(); + _resolverChain.AddRange(GlobalTypeMapper.Instance.GetPluginResolvers()); } - static string GetPgName(Type clrType, INpgsqlNameTranslator nameTranslator) - => clrType.GetCustomAttribute()?.PgName - ?? nameTranslator.TranslateTypeName(clrType.Name); - #endregion Type mapping #region Optional opt-ins /// - /// Sets up mappings for the PostgreSQL range and multirange types. + /// Sets up mappings for the PostgreSQL array types. + /// + public NpgsqlSlimDataSourceBuilder EnableArrays() + { + AddTypeInfoResolver(new RangeArrayTypeInfoResolver()); + AddTypeInfoResolver(new ExtraConversionsArrayTypeInfoResolver()); + AddTypeInfoResolver(new AdoArrayTypeInfoResolver()); + return this; + } + + /// + /// Sets up mappings for the PostgreSQL range types. 
/// public NpgsqlSlimDataSourceBuilder EnableRanges() { - AddTypeResolverFactory(new RangeTypeHandlerResolverFactory()); + AddTypeInfoResolver(new RangeTypeInfoResolver()); + return this; + } + + /// + /// Sets up mappings for the PostgreSQL multirange types. + /// + public NpgsqlSlimDataSourceBuilder EnableMultiranges() + { + AddTypeInfoResolver(new RangeTypeInfoResolver()); return this; } @@ -407,7 +369,7 @@ public NpgsqlSlimDataSourceBuilder UseSystemTextJson( Type[]? jsonbClrTypes = null, Type[]? jsonClrTypes = null) { - AddTypeResolverFactory(new SystemTextJsonTypeHandlerResolverFactory(jsonbClrTypes, jsonClrTypes, serializerOptions)); + AddTypeInfoResolver(new SystemTextJsonPocoTypeInfoResolver(jsonbClrTypes, jsonClrTypes, serializerOptions)); return this; } @@ -416,7 +378,7 @@ public NpgsqlSlimDataSourceBuilder UseSystemTextJson( /// public NpgsqlSlimDataSourceBuilder EnableRecords() { - AddTypeResolverFactory(new RecordTypeHandlerResolverFactory()); + AddTypeInfoResolver(new RecordTypeInfoResolver()); return this; } @@ -425,7 +387,25 @@ public NpgsqlSlimDataSourceBuilder EnableRecords() /// public NpgsqlSlimDataSourceBuilder EnableFullTextSearch() { - AddTypeResolverFactory(new FullTextSearchTypeHandlerResolverFactory()); + AddTypeInfoResolver(new FullTextSearchTypeInfoResolver()); + return this; + } + + /// + /// Sets up mappings for the PostgreSQL ltree extension types. + /// + public NpgsqlSlimDataSourceBuilder EnableLTree() + { + AddTypeInfoResolver(new LTreeTypeInfoResolver()); + return this; + } + + /// + /// Sets up mappings for extra conversions from PostgreSQL to .NET types. 
+ /// + public NpgsqlSlimDataSourceBuilder EnableExtraConversions() + { + AddTypeInfoResolver(new ExtraConversionsResolver()); return this; } @@ -536,11 +516,25 @@ _loggerFactory is null _periodicPasswordProvider, _periodicPasswordSuccessRefreshInterval, _periodicPasswordFailureRefreshInterval, - _resolverFactories, - _userTypeMappings, + Resolvers(), DefaultNameTranslator, _syncConnectionInitializer, _asyncConnectionInitializer); + + IEnumerable Resolvers() + { + var resolvers = new List(); + + if (_userTypeMapper.Items.Count > 0) + resolvers.Add(_userTypeMapper.Build()); + + if (GlobalTypeMapper.Instance.GetUserMappingsResolver() is { } globalUserTypeMapper) + resolvers.Add(globalUserTypeMapper); + + resolvers.AddRange(_resolverChain); + + return resolvers; + } } void ValidateMultiHost() diff --git a/src/Npgsql/NpgsqlTypes/NpgsqlDbType.cs b/src/Npgsql/NpgsqlTypes/NpgsqlDbType.cs index 8df0ee874f..b05f623867 100644 --- a/src/Npgsql/NpgsqlTypes/NpgsqlDbType.cs +++ b/src/Npgsql/NpgsqlTypes/NpgsqlDbType.cs @@ -1,6 +1,8 @@ using System; +using System.Data; using Npgsql; -using Npgsql.TypeMapping; +using Npgsql.Internal.Postgres; +using static Npgsql.Util.Statics; #pragma warning disable CA1720 @@ -635,6 +637,420 @@ public enum NpgsqlDbType #endregion } +static class NpgsqlDbTypeExtensions +{ + internal static NpgsqlDbType? ToNpgsqlDbType(this DbType dbType) + => dbType switch + { + DbType.AnsiString => NpgsqlDbType.Text, + DbType.Binary => NpgsqlDbType.Bytea, + DbType.Byte => NpgsqlDbType.Smallint, + DbType.Boolean => NpgsqlDbType.Boolean, + DbType.Currency => NpgsqlDbType.Money, + DbType.Date => NpgsqlDbType.Date, + DbType.DateTime => LegacyTimestampBehavior ? 
NpgsqlDbType.Timestamp : NpgsqlDbType.TimestampTz, + DbType.Decimal => NpgsqlDbType.Numeric, + DbType.VarNumeric => NpgsqlDbType.Numeric, + DbType.Double => NpgsqlDbType.Double, + DbType.Guid => NpgsqlDbType.Uuid, + DbType.Int16 => NpgsqlDbType.Smallint, + DbType.Int32 => NpgsqlDbType.Integer, + DbType.Int64 => NpgsqlDbType.Bigint, + DbType.Single => NpgsqlDbType.Real, + DbType.String => NpgsqlDbType.Text, + DbType.Time => NpgsqlDbType.Time, + DbType.AnsiStringFixedLength => NpgsqlDbType.Text, + DbType.StringFixedLength => NpgsqlDbType.Text, + DbType.Xml => NpgsqlDbType.Xml, + DbType.DateTime2 => NpgsqlDbType.Timestamp, + DbType.DateTimeOffset => NpgsqlDbType.TimestampTz, + + DbType.Object => null, + DbType.SByte => null, + DbType.UInt16 => null, + DbType.UInt32 => null, + DbType.UInt64 => null, + + _ => throw new ArgumentOutOfRangeException(nameof(dbType), dbType, null) + }; + + public static DbType ToDbType(this NpgsqlDbType npgsqlDbType) + => npgsqlDbType switch + { + // Numeric types + NpgsqlDbType.Smallint => DbType.Int16, + NpgsqlDbType.Integer => DbType.Int32, + NpgsqlDbType.Bigint => DbType.Int64, + NpgsqlDbType.Real => DbType.Single, + NpgsqlDbType.Double => DbType.Double, + NpgsqlDbType.Numeric => DbType.Decimal, + NpgsqlDbType.Money => DbType.Currency, + + // Text types + NpgsqlDbType.Text => DbType.String, + NpgsqlDbType.Xml => DbType.Xml, + NpgsqlDbType.Varchar => DbType.String, + NpgsqlDbType.Char => DbType.String, + NpgsqlDbType.Name => DbType.String, + NpgsqlDbType.Citext => DbType.String, + NpgsqlDbType.Refcursor => DbType.Object, + NpgsqlDbType.Jsonb => DbType.Object, + NpgsqlDbType.Json => DbType.Object, + NpgsqlDbType.JsonPath => DbType.Object, + + // Date/time types + NpgsqlDbType.Timestamp => LegacyTimestampBehavior ? DbType.DateTime : DbType.DateTime2, + NpgsqlDbType.TimestampTz => LegacyTimestampBehavior ? 
DbType.DateTimeOffset : DbType.DateTime, + NpgsqlDbType.Date => DbType.Date, + NpgsqlDbType.Time => DbType.Time, + + // Misc data types + NpgsqlDbType.Bytea => DbType.Binary, + NpgsqlDbType.Boolean => DbType.Boolean, + NpgsqlDbType.Uuid => DbType.Guid, + + NpgsqlDbType.Unknown => DbType.Object, + + _ => DbType.Object + }; + + /// Can return null when a custom range type is used. + internal static string? ToUnqualifiedDataTypeName(this NpgsqlDbType npgsqlDbType) + => npgsqlDbType switch + { + // Numeric types + NpgsqlDbType.Smallint => "int2", + NpgsqlDbType.Integer => "int4", + NpgsqlDbType.Bigint => "int8", + NpgsqlDbType.Real => "float4", + NpgsqlDbType.Double => "float8", + NpgsqlDbType.Numeric => "numeric", + NpgsqlDbType.Money => "money", + + // Text types + NpgsqlDbType.Text => "text", + NpgsqlDbType.Xml => "xml", + NpgsqlDbType.Varchar => "varchar", + NpgsqlDbType.Char => "bpchar", + NpgsqlDbType.Name => "name", + NpgsqlDbType.Refcursor => "refcursor", + NpgsqlDbType.Jsonb => "jsonb", + NpgsqlDbType.Json => "json", + NpgsqlDbType.JsonPath => "jsonpath", + + // Date/time types + NpgsqlDbType.Timestamp => "timestamp", + NpgsqlDbType.TimestampTz => "timestamptz", + NpgsqlDbType.Date => "date", + NpgsqlDbType.Time => "time", + NpgsqlDbType.TimeTz => "timetz", + NpgsqlDbType.Interval => "interval", + + // Network types + NpgsqlDbType.Cidr => "cidr", + NpgsqlDbType.Inet => "inet", + NpgsqlDbType.MacAddr => "macaddr", + NpgsqlDbType.MacAddr8 => "macaddr8", + + // Full-text search types + NpgsqlDbType.TsQuery => "tsquery", + NpgsqlDbType.TsVector => "tsvector", + + // Geometry types + NpgsqlDbType.Box => "box", + NpgsqlDbType.Circle => "circle", + NpgsqlDbType.Line => "line", + NpgsqlDbType.LSeg => "lseg", + NpgsqlDbType.Path => "path", + NpgsqlDbType.Point => "point", + NpgsqlDbType.Polygon => "polygon", + + + // UInt types + NpgsqlDbType.Oid => "oid", + NpgsqlDbType.Xid => "xid", + NpgsqlDbType.Xid8 => "xid8", + NpgsqlDbType.Cid => "cid", + NpgsqlDbType.Regtype => 
"regtype", + NpgsqlDbType.Regconfig => "regconfig", + + // Misc types + NpgsqlDbType.Boolean => "bool", + NpgsqlDbType.Bytea => "bytea", + NpgsqlDbType.Uuid => "uuid", + NpgsqlDbType.Varbit => "varbit", + NpgsqlDbType.Bit => "bit", + + // Built-in range types + NpgsqlDbType.IntegerRange => "int4range", + NpgsqlDbType.BigIntRange => "int8range", + NpgsqlDbType.NumericRange => "numrange", + NpgsqlDbType.TimestampRange => "tsrange", + NpgsqlDbType.TimestampTzRange => "tstzrange", + NpgsqlDbType.DateRange => "daterange", + + // Built-in multirange types + NpgsqlDbType.IntegerMultirange => "int4multirange", + NpgsqlDbType.BigIntMultirange => "int8multirange", + NpgsqlDbType.NumericMultirange => "nummultirange", + NpgsqlDbType.TimestampMultirange => "tsmultirange", + NpgsqlDbType.TimestampTzMultirange => "tstzmultirange", + NpgsqlDbType.DateMultirange => "datemultirange", + + // Internal types + NpgsqlDbType.Int2Vector => "int2vector", + NpgsqlDbType.Oidvector => "oidvector", + NpgsqlDbType.PgLsn => "pg_lsn", + NpgsqlDbType.Tid => "tid", + NpgsqlDbType.InternalChar => "char", + + // Plugin types + NpgsqlDbType.Citext => "citext", + NpgsqlDbType.LQuery => "lquery", + NpgsqlDbType.LTree => "ltree", + NpgsqlDbType.LTxtQuery => "ltxtquery", + NpgsqlDbType.Hstore => "hstore", + NpgsqlDbType.Geometry => "geometry", + NpgsqlDbType.Geography => "geography", + + NpgsqlDbType.Unknown => "unknown", + + // Unknown cannot be composed + _ when npgsqlDbType.HasFlag(NpgsqlDbType.Array) && (npgsqlDbType & ~NpgsqlDbType.Array) == NpgsqlDbType.Unknown + => "unknown", + _ when npgsqlDbType.HasFlag(NpgsqlDbType.Range) && (npgsqlDbType & ~NpgsqlDbType.Range) == NpgsqlDbType.Unknown + => "unknown", + _ when npgsqlDbType.HasFlag(NpgsqlDbType.Multirange) && (npgsqlDbType & ~NpgsqlDbType.Multirange) == NpgsqlDbType.Unknown + => "unknown", + + _ => npgsqlDbType.HasFlag(NpgsqlDbType.Array) + ? ToUnqualifiedDataTypeName(npgsqlDbType & ~NpgsqlDbType.Array) is { } name ? 
"_" + name : null + : null // e.g. ranges + }; + + internal static string ToUnqualifiedDataTypeNameOrThrow(this NpgsqlDbType npgsqlDbType) + => npgsqlDbType.ToUnqualifiedDataTypeName() ?? throw new ArgumentOutOfRangeException(nameof(npgsqlDbType), npgsqlDbType, "Cannot convert NpgsqlDbType to DataTypeName"); + + /// Can return null when a plugin type or custom range type is used. + internal static DataTypeName? ToDataTypeName(this NpgsqlDbType npgsqlDbType) + => npgsqlDbType switch + { + // Numeric types + NpgsqlDbType.Smallint => DataTypeNames.Int2, + NpgsqlDbType.Integer => DataTypeNames.Int4, + NpgsqlDbType.Bigint => DataTypeNames.Int8, + NpgsqlDbType.Real => DataTypeNames.Float4, + NpgsqlDbType.Double => DataTypeNames.Float8, + NpgsqlDbType.Numeric => DataTypeNames.Numeric, + NpgsqlDbType.Money => DataTypeNames.Money, + + // Text types + NpgsqlDbType.Text => DataTypeNames.Text, + NpgsqlDbType.Xml => DataTypeNames.Xml, + NpgsqlDbType.Varchar => DataTypeNames.Varchar, + NpgsqlDbType.Char => DataTypeNames.Bpchar, + NpgsqlDbType.Name => DataTypeNames.Name, + NpgsqlDbType.Refcursor => DataTypeNames.RefCursor, + NpgsqlDbType.Jsonb => DataTypeNames.Jsonb, + NpgsqlDbType.Json => DataTypeNames.Json, + NpgsqlDbType.JsonPath => DataTypeNames.Jsonpath, + + // Date/time types + NpgsqlDbType.Timestamp => DataTypeNames.Timestamp, + NpgsqlDbType.TimestampTz => DataTypeNames.TimestampTz, + NpgsqlDbType.Date => DataTypeNames.Date, + NpgsqlDbType.Time => DataTypeNames.Time, + NpgsqlDbType.TimeTz => DataTypeNames.TimeTz, + NpgsqlDbType.Interval => DataTypeNames.Interval, + + // Network types + NpgsqlDbType.Cidr => DataTypeNames.Cidr, + NpgsqlDbType.Inet => DataTypeNames.Inet, + NpgsqlDbType.MacAddr => DataTypeNames.MacAddr, + NpgsqlDbType.MacAddr8 => DataTypeNames.MacAddr8, + + // Full-text search types + NpgsqlDbType.TsQuery => DataTypeNames.TsQuery, + NpgsqlDbType.TsVector => DataTypeNames.TsVector, + + // Geometry types + NpgsqlDbType.Box => DataTypeNames.Box, + 
NpgsqlDbType.Circle => DataTypeNames.Circle, + NpgsqlDbType.Line => DataTypeNames.Line, + NpgsqlDbType.LSeg => DataTypeNames.LSeg, + NpgsqlDbType.Path => DataTypeNames.Path, + NpgsqlDbType.Point => DataTypeNames.Point, + NpgsqlDbType.Polygon => DataTypeNames.Polygon, + + // UInt types + NpgsqlDbType.Oid => DataTypeNames.Oid, + NpgsqlDbType.Xid => DataTypeNames.Xid, + NpgsqlDbType.Xid8 => DataTypeNames.Xid8, + NpgsqlDbType.Cid => DataTypeNames.Cid, + NpgsqlDbType.Regtype => DataTypeNames.RegType, + NpgsqlDbType.Regconfig => DataTypeNames.RegConfig, + + // Misc types + NpgsqlDbType.Boolean => DataTypeNames.Bool, + NpgsqlDbType.Bytea => DataTypeNames.Bytea, + NpgsqlDbType.Uuid => DataTypeNames.Uuid, + NpgsqlDbType.Varbit => DataTypeNames.Varbit, + NpgsqlDbType.Bit => DataTypeNames.Bit, + + // Built-in range types + NpgsqlDbType.IntegerRange => DataTypeNames.Int4Range, + NpgsqlDbType.BigIntRange => DataTypeNames.Int8Range, + NpgsqlDbType.NumericRange => DataTypeNames.NumRange, + NpgsqlDbType.TimestampRange => DataTypeNames.TsRange, + NpgsqlDbType.TimestampTzRange => DataTypeNames.TsTzRange, + NpgsqlDbType.DateRange => DataTypeNames.DateRange, + + // Internal types + NpgsqlDbType.Int2Vector => DataTypeNames.Int2Vector, + NpgsqlDbType.Oidvector => DataTypeNames.OidVector, + NpgsqlDbType.PgLsn => DataTypeNames.PgLsn, + NpgsqlDbType.Tid => DataTypeNames.Tid, + NpgsqlDbType.InternalChar => DataTypeNames.Char, + + // Special types + NpgsqlDbType.Unknown => DataTypeNames.Unknown, + + // Unknown cannot be composed + _ when npgsqlDbType.HasFlag(NpgsqlDbType.Array) && (npgsqlDbType & ~NpgsqlDbType.Array) == NpgsqlDbType.Unknown + => DataTypeNames.Unknown, + _ when npgsqlDbType.HasFlag(NpgsqlDbType.Range) && (npgsqlDbType & ~NpgsqlDbType.Range) == NpgsqlDbType.Unknown + => DataTypeNames.Unknown, + _ when npgsqlDbType.HasFlag(NpgsqlDbType.Multirange) && (npgsqlDbType & ~NpgsqlDbType.Multirange) == NpgsqlDbType.Unknown + => DataTypeNames.Unknown, + + // If both multirange and array 
are set we first remove array, so array is added to the outermost datatypename. + _ when npgsqlDbType.HasFlag(NpgsqlDbType.Array) + => ToDataTypeName(npgsqlDbType & ~NpgsqlDbType.Array)?.ToArrayName(), + _ when npgsqlDbType.HasFlag(NpgsqlDbType.Multirange) + => ToDataTypeName((npgsqlDbType | NpgsqlDbType.Range) & ~NpgsqlDbType.Multirange)?.ToDefaultMultirangeName(), + + // Plugin types don't have a stable fully qualified name. + _ => null + }; + + internal static NpgsqlDbType? ToNpgsqlDbType(this DataTypeName dataTypeName) => ToNpgsqlDbType(dataTypeName.UnqualifiedName); + /// Should not be used with display names, first normalize it instead. + internal static NpgsqlDbType? ToNpgsqlDbType(string dataTypeName) + { + var unqualifiedName = dataTypeName; + if (dataTypeName.IndexOf(".", StringComparison.Ordinal) is not -1 and var index) + unqualifiedName = dataTypeName.Substring(0, index); + + return unqualifiedName switch + { + // Numeric types + "int2" => NpgsqlDbType.Smallint, + "int4" => NpgsqlDbType.Integer, + "int8" => NpgsqlDbType.Bigint, + "float4" => NpgsqlDbType.Real, + "float8" => NpgsqlDbType.Double, + "numeric" => NpgsqlDbType.Numeric, + "money" => NpgsqlDbType.Money, + + // Text types + "text" => NpgsqlDbType.Text, + "xml" => NpgsqlDbType.Xml, + "varchar" => NpgsqlDbType.Varchar, + "bpchar" => NpgsqlDbType.Char, + "name" => NpgsqlDbType.Name, + "refcursor" => NpgsqlDbType.Refcursor, + "jsonb" => NpgsqlDbType.Jsonb, + "json" => NpgsqlDbType.Json, + "jsonpath" => NpgsqlDbType.JsonPath, + + // Date/time types + "timestamp" => NpgsqlDbType.Timestamp, + "timestamptz" => NpgsqlDbType.TimestampTz, + "date" => NpgsqlDbType.Date, + "time" => NpgsqlDbType.Time, + "timetz" => NpgsqlDbType.TimeTz, + "interval" => NpgsqlDbType.Interval, + + // Network types + "cidr" => NpgsqlDbType.Cidr, + "inet" => NpgsqlDbType.Inet, + "macaddr" => NpgsqlDbType.MacAddr, + "macaddr8" => NpgsqlDbType.MacAddr8, + + // Full-text search types + "tsquery" => NpgsqlDbType.TsQuery, + 
"tsvector" => NpgsqlDbType.TsVector, + + // Geometry types + "box" => NpgsqlDbType.Box, + "circle" => NpgsqlDbType.Circle, + "line" => NpgsqlDbType.Line, + "lseg" => NpgsqlDbType.LSeg, + "path" => NpgsqlDbType.Path, + "point" => NpgsqlDbType.Point, + "polygon" => NpgsqlDbType.Polygon, + + // UInt types + "oid" => NpgsqlDbType.Oid, + "xid" => NpgsqlDbType.Xid, + "xid8" => NpgsqlDbType.Xid8, + "cid" => NpgsqlDbType.Cid, + "regtype" => NpgsqlDbType.Regtype, + "regconfig" => NpgsqlDbType.Regconfig, + + // Misc types + "bool" => NpgsqlDbType.Boolean, + "bytea" => NpgsqlDbType.Bytea, + "uuid" => NpgsqlDbType.Uuid, + "varbit" => NpgsqlDbType.Varbit, + "bit" => NpgsqlDbType.Bit, + + // Built-in range types + "int4range" => NpgsqlDbType.IntegerRange, + "int8range" => NpgsqlDbType.BigIntRange, + "numrange" => NpgsqlDbType.NumericRange, + "tsrange" => NpgsqlDbType.TimestampRange, + "tstzrange" => NpgsqlDbType.TimestampTzRange, + "daterange" => NpgsqlDbType.DateRange, + + // Built-in multirange types + "int4multirange" => NpgsqlDbType.IntegerMultirange, + "int8multirange" => NpgsqlDbType.BigIntMultirange, + "nummultirange" => NpgsqlDbType.NumericMultirange, + "tsmultirange" => NpgsqlDbType.TimestampMultirange, + "tstzmultirange" => NpgsqlDbType.TimestampTzMultirange, + "datemultirange" => NpgsqlDbType.DateMultirange, + + // Internal types + "int2vector" => NpgsqlDbType.Int2Vector, + "oidvector" => NpgsqlDbType.Oidvector, + "pg_lsn" => NpgsqlDbType.PgLsn, + "tid" => NpgsqlDbType.Tid, + "char" => NpgsqlDbType.InternalChar, + + // Plugin types + "citext" => NpgsqlDbType.Citext, + "lquery" => NpgsqlDbType.LQuery, + "ltree" => NpgsqlDbType.LTree, + "ltxtquery" => NpgsqlDbType.LTxtQuery, + "hstore" => NpgsqlDbType.Hstore, + "geometry" => NpgsqlDbType.Geometry, + "geography" => NpgsqlDbType.Geography, + + _ when unqualifiedName.Contains("unknown") + => !unqualifiedName.StartsWith("_", StringComparison.Ordinal) + ? 
NpgsqlDbType.Unknown + : null, + _ when unqualifiedName.StartsWith("_", StringComparison.Ordinal) + => ToNpgsqlDbType(unqualifiedName.Substring(1)) is { } elementNpgsqlDbType + ? elementNpgsqlDbType | NpgsqlDbType.Array + : null, + // e.g. custom ranges, plugin types etc. + _ => null + }; + } +} + /// /// Represents a built-in PostgreSQL type as it appears in pg_type, including its name and OID. /// Extension types with variable OIDs are not represented. @@ -669,4 +1085,4 @@ internal BuiltInPostgresType( MultirangeName = multirangeName; MultirangeOID = multirangeOID; } -} \ No newline at end of file +} diff --git a/src/Npgsql/NpgsqlTypes/NpgsqlInterval.cs b/src/Npgsql/NpgsqlTypes/NpgsqlInterval.cs index f3c1d49139..f4b51ba4a9 100644 --- a/src/Npgsql/NpgsqlTypes/NpgsqlInterval.cs +++ b/src/Npgsql/NpgsqlTypes/NpgsqlInterval.cs @@ -1,8 +1,4 @@ using System; -using System.Collections; -using System.Collections.Generic; -using System.Text; -using Npgsql; // ReSharper disable once CheckNamespace namespace NpgsqlTypes; diff --git a/src/Npgsql/NpgsqlTypes/NpgsqlTsQuery.cs b/src/Npgsql/NpgsqlTypes/NpgsqlTsQuery.cs index e83b69edad..cd68ecc3c8 100644 --- a/src/Npgsql/NpgsqlTypes/NpgsqlTsQuery.cs +++ b/src/Npgsql/NpgsqlTypes/NpgsqlTsQuery.cs @@ -89,7 +89,7 @@ public static NpgsqlTsQuery Parse(string value) var pos = 0; var expectingBinOp = false; - var lastFollowedByOpDistance = -1; + short lastFollowedByOpDistance = -1; NextToken: if (pos >= value.Length) @@ -125,7 +125,7 @@ public static NpgsqlTsQuery Parse(string value) { lastFollowedByOpDistance = 1; } - else if (!int.TryParse(followedByOpDistanceString, out lastFollowedByOpDistance) + else if (!short.TryParse(followedByOpDistanceString, out lastFollowedByOpDistance) || lastFollowedByOpDistance < 0) { throw new FormatException("Syntax error in tsquery. 
Malformed distance in 'followed by' operator."); @@ -172,7 +172,7 @@ public static NpgsqlTsQuery Parse(string value) var tsOp = opStack.Pop(); valStack.Push((char)tsOp switch { - '&' => (NpgsqlTsQuery)new NpgsqlTsQueryAnd(left, right), + '&' => new NpgsqlTsQueryAnd(left, right), '|' => new NpgsqlTsQueryOr(left, right), '<' => new NpgsqlTsQueryFollowedBy(left, tsOp.FollowedByDistance, right), _ => throw new FormatException("Syntax error in tsquery") @@ -383,9 +383,9 @@ public override bool Equals(object? obj) readonly struct NpgsqlTsQueryOperator { public readonly char Char; - public readonly int FollowedByDistance; + public readonly short FollowedByDistance; - public NpgsqlTsQueryOperator(char character, int followedByDistance) + public NpgsqlTsQueryOperator(char character, short followedByDistance) { Char = character; FollowedByDistance = followedByDistance; @@ -670,7 +670,7 @@ public sealed class NpgsqlTsQueryFollowedBy : NpgsqlTsQueryBinOp /// /// The distance between the 2 nodes, in lexemes. /// - public int Distance { get; set; } + public short Distance { get; set; } /// /// Creates a "followed by" operator, specifying 2 child nodes and the @@ -681,7 +681,7 @@ public sealed class NpgsqlTsQueryFollowedBy : NpgsqlTsQueryBinOp /// public NpgsqlTsQueryFollowedBy( NpgsqlTsQuery left, - int distance, + short distance, NpgsqlTsQuery right) : base(NodeKind.Phrase, left, right) { @@ -741,4 +741,4 @@ public override bool Equals(NpgsqlTsQuery? other) /// public override int GetHashCode() => Kind.GetHashCode(); -} \ No newline at end of file +} diff --git a/src/Npgsql/NpgsqlTypes/NpgsqlTypes.cs b/src/Npgsql/NpgsqlTypes/NpgsqlTypes.cs index 792a215774..c4a69a0c58 100644 --- a/src/Npgsql/NpgsqlTypes/NpgsqlTypes.cs +++ b/src/Npgsql/NpgsqlTypes/NpgsqlTypes.cs @@ -405,19 +405,17 @@ public override int GetHashCode() } /// -/// Represents a PostgreSQL inet type, which is a combination of an IPAddress and a -/// subnet mask. 
+/// Represents a PostgreSQL inet type, which is a combination of an IPAddress and a subnet mask. /// /// /// https://www.postgresql.org/docs/current/static/datatype-net-types.html /// -[Obsolete("Use ValueTuple instead")] -public struct NpgsqlInet : IEquatable +public readonly record struct NpgsqlInet { - public IPAddress Address { get; set; } - public int Netmask { get; set; } + public IPAddress Address { get; } + public byte Netmask { get; } - public NpgsqlInet(IPAddress address, int netmask) + public NpgsqlInet(IPAddress address, byte netmask) { if (address.AddressFamily != AddressFamily.InterNetwork && address.AddressFamily != AddressFamily.InterNetworkV6) throw new ArgumentException("Only IPAddress of InterNetwork or InterNetworkV6 address families are accepted", nameof(address)); @@ -427,76 +425,92 @@ public NpgsqlInet(IPAddress address, int netmask) } public NpgsqlInet(IPAddress address) + : this(address, (byte)(address.AddressFamily == AddressFamily.InterNetwork ? 32 : 128)) { - if (address.AddressFamily != AddressFamily.InterNetwork && address.AddressFamily != AddressFamily.InterNetworkV6) - throw new ArgumentException("Only IPAddress of InterNetwork or InterNetworkV6 address families are accepted", nameof(address)); - - Address = address; - Netmask = address.AddressFamily == AddressFamily.InterNetwork ? 
32 : 128; } public NpgsqlInet(string addr) { - if (addr.IndexOf('/') > 0) + switch (addr.Split('/')) { - var addrbits = addr.Split('/'); - if (addrbits.GetUpperBound(0) != 1) - { - throw new FormatException("Invalid number of parts in CIDR specification"); - } + case { Length: 2 } segments: + Address = IPAddress.Parse(segments[0]); + Netmask = byte.Parse(segments[1]); + return; - Address = IPAddress.Parse(addrbits[0]); - Netmask = int.Parse(addrbits[1]); - } - else - { - Address = IPAddress.Parse(addr); + case { Length: 1 } segments: + Address = IPAddress.Parse(segments[0]); Netmask = 32; + return; + + default: + throw new FormatException("Invalid number of parts in CIDR specification"); } } public override string ToString() - { - if ((Address.AddressFamily == AddressFamily.InterNetwork && Netmask == 32) || - (Address.AddressFamily == AddressFamily.InterNetworkV6 && Netmask == 128)) - { - return Address.ToString(); - } + => (Address.AddressFamily == AddressFamily.InterNetwork && Netmask == 32) || + (Address.AddressFamily == AddressFamily.InterNetworkV6 && Netmask == 128) + ? Address.ToString() + : $"{Address}/{Netmask}"; - return $"{Address}/{Netmask}"; - } + public static explicit operator IPAddress(NpgsqlInet inet) + => inet.Address; + + public static explicit operator NpgsqlInet(IPAddress ip) + => new(ip); - // ReSharper disable once InconsistentNaming - public static IPAddress ToIPAddress(NpgsqlInet inet) + public void Deconstruct(out IPAddress address, out byte netmask) { - if (inet.Netmask != 32) - throw new InvalidCastException("Cannot cast CIDR network to address"); - return inet.Address; + address = Address; + netmask = Netmask; } +} - public static explicit operator IPAddress(NpgsqlInet inet) => ToIPAddress(inet); +/// +/// Represents a PostgreSQL cidr type. 
+/// +/// +/// https://www.postgresql.org/docs/current/static/datatype-net-types.html +/// +public readonly record struct NpgsqlCidr +{ + public IPAddress Address { get; } + public byte Netmask { get; } - public static NpgsqlInet ToNpgsqlInet(IPAddress? ip) - => ip is null ? default : new NpgsqlInet(ip); + public NpgsqlCidr(IPAddress address, byte netmask) + { + if (address.AddressFamily != AddressFamily.InterNetwork && address.AddressFamily != AddressFamily.InterNetworkV6) + throw new ArgumentException("Only IPAddress of InterNetwork or InterNetworkV6 address families are accepted", nameof(address)); - public static implicit operator NpgsqlInet(IPAddress ip) => ToNpgsqlInet(ip); + Address = address; + Netmask = netmask; + } - public void Deconstruct(out IPAddress address, out int netmask) + public NpgsqlCidr(string addr) { - address = Address; - netmask = Netmask; + switch (addr.Split('/')) + { + case { Length: 2 } segments: + Address = IPAddress.Parse(segments[0]); + Netmask = byte.Parse(segments[1]); + return; + + case { Length: 1 } segments: + throw new FormatException("Missing netmask"); + default: + throw new FormatException("Invalid number of parts in CIDR specification"); + } } - public bool Equals(NpgsqlInet other) => Address.Equals(other.Address) && Netmask == other.Netmask; - - public override bool Equals(object? 
obj) - => obj is NpgsqlInet inet && Equals(inet); - - public override int GetHashCode() - => HashCode.Combine(Address, Netmask); + public override string ToString() + => $"{Address}/{Netmask}"; - public static bool operator ==(NpgsqlInet x, NpgsqlInet y) => x.Equals(y); - public static bool operator !=(NpgsqlInet x, NpgsqlInet y) => !(x == y); + public void Deconstruct(out IPAddress address, out byte netmask) + { + address = Address; + netmask = Netmask; + } } /// diff --git a/src/Npgsql/PoolManager.cs b/src/Npgsql/PoolManager.cs index adc3c75fa6..d1086b5196 100644 --- a/src/Npgsql/PoolManager.cs +++ b/src/Npgsql/PoolManager.cs @@ -1,7 +1,5 @@ using System; using System.Collections.Concurrent; -using System.Diagnostics.CodeAnalysis; -using System.Threading; namespace Npgsql; diff --git a/src/Npgsql/PoolingDataSource.cs b/src/Npgsql/PoolingDataSource.cs index 52f8adb2f4..a7a4863522 100644 --- a/src/Npgsql/PoolingDataSource.cs +++ b/src/Npgsql/PoolingDataSource.cs @@ -230,9 +230,9 @@ bool CheckIdleConnector([NotNullWhen(true)] NpgsqlConnector? connector) // The connector directly references the data source type mapper into the connector, to protect it against changes by a concurrent // ReloadTypes. We update them here before returning the connector from the pool. 
- Debug.Assert(TypeMapper is not null); + Debug.Assert(SerializerOptions is not null); Debug.Assert(DatabaseInfo is not null); - connector.TypeMapper = TypeMapper; + connector.SerializerOptions = SerializerOptions; connector.DatabaseInfo = DatabaseInfo; Debug.Assert(connector.State == ConnectorState.Ready, diff --git a/src/Npgsql/PostgresDatabaseInfo.cs b/src/Npgsql/PostgresDatabaseInfo.cs index 4d640fb261..a4e7a33462 100644 --- a/src/Npgsql/PostgresDatabaseInfo.cs +++ b/src/Npgsql/PostgresDatabaseInfo.cs @@ -3,14 +3,13 @@ using System.Diagnostics; using System.Globalization; using System.Linq; -using System.Runtime.CompilerServices; using System.Text; using System.Threading.Tasks; using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; using Npgsql.BackendMessages; using Npgsql.Internal; using Npgsql.PostgresTypes; -using Npgsql.TypeMapping; using Npgsql.Util; using static Npgsql.Util.Statics; @@ -80,6 +79,10 @@ internal PostgresDatabaseInfo(NpgsqlConnector conn) : base(conn.Host!, conn.Port, conn.Database!, conn.PostgresParameters["server_version"]) => _connectionLogger = conn.LoggingConfiguration.ConnectionLogger; + private protected PostgresDatabaseInfo(string host, int port, string databaseName, string serverVersion) + : base(host, port, databaseName, serverVersion) + => _connectionLogger = NullLogger.Instance; + /// /// Loads database information from the PostgreSQL database specified by . /// @@ -142,7 +145,7 @@ JOIN pg_namespace AS ns ON (ns.oid = typnamespace) WHERE typtype IN ('b', 'r', 'm', 'e', 'd') OR -- Base, range, multirange, enum, domain (typtype = 'c' AND {(loadTableComposites ? 
"ns.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')" : "relkind='c'")}) OR -- User-defined free-standing composites (not table composites) by default - (typtype = 'p' AND typname IN ('record', 'void')) OR -- Some special supported pseudo-types + (typtype = 'p' AND typname IN ('record', 'void', 'unknown')) OR -- Some special supported pseudo-types (typtype = 'a' AND ( -- Array of... elemtyptype IN ('b', 'r', 'm', 'e', 'd') OR -- Array of base, range, multirange, enum, domain (elemtyptype = 'p' AND elemtypname IN ('record', 'void')) OR -- Arrays of special supported pseudo-types @@ -543,4 +546,4 @@ static string SanitizeForReplicationConnection(string str) static string ReadNonNullableString(NpgsqlReadBuffer buffer) => buffer.ReadString(buffer.ReadInt32()); } -} \ No newline at end of file +} diff --git a/src/Npgsql/PostgresMinimalDatabaseInfo.cs b/src/Npgsql/PostgresMinimalDatabaseInfo.cs index 31b2d24f1d..94b76f541c 100644 --- a/src/Npgsql/PostgresMinimalDatabaseInfo.cs +++ b/src/Npgsql/PostgresMinimalDatabaseInfo.cs @@ -59,4 +59,24 @@ internal PostgresMinimalDatabaseInfo(NpgsqlConnector conn) HasIntegerDateTimes = !conn.PostgresParameters.TryGetValue("integer_datetimes", out var intDateTimes) || intDateTimes == "on"; } -} \ No newline at end of file + + // TODO, split database info and type catalog. + internal PostgresMinimalDatabaseInfo() + : base("minimal", 5432, "minimal", "14") + { + } + + static PostgresMinimalDatabaseInfo? 
_defaultTypeCatalog; + internal static PostgresMinimalDatabaseInfo DefaultTypeCatalog + { + get + { + if (_defaultTypeCatalog is not null) + return _defaultTypeCatalog; + + var catalog = new PostgresMinimalDatabaseInfo(); + catalog.ProcessTypes(); + return _defaultTypeCatalog = catalog; + } + } +} diff --git a/src/Npgsql/PostgresTypes/PostgresArrayType.cs b/src/Npgsql/PostgresTypes/PostgresArrayType.cs index cfeb89c736..7f0b2246d3 100644 --- a/src/Npgsql/PostgresTypes/PostgresArrayType.cs +++ b/src/Npgsql/PostgresTypes/PostgresArrayType.cs @@ -1,6 +1,4 @@ -using System.Diagnostics; - -namespace Npgsql.PostgresTypes; +namespace Npgsql.PostgresTypes; /// /// Represents a PostgreSQL array data type, which can hold several multiple values in a single column. @@ -18,10 +16,9 @@ public class PostgresArrayType : PostgresType /// /// Constructs a representation of a PostgreSQL array data type. /// - protected internal PostgresArrayType(string ns, string internalName, uint oid, PostgresType elementPostgresType) - : base(ns, elementPostgresType.Name + "[]", internalName, oid) + protected internal PostgresArrayType(string ns, string name, uint oid, PostgresType elementPostgresType) + : base(ns, name, oid) { - Debug.Assert(internalName == '_' + elementPostgresType.InternalName); Element = elementPostgresType; Element.Array = this; } @@ -34,4 +31,4 @@ internal override string GetPartialNameWithFacets(int typeModifier) internal override PostgresFacets GetFacets(int typeModifier) => Element.GetFacets(typeModifier); -} \ No newline at end of file +} diff --git a/src/Npgsql/PostgresTypes/PostgresBaseType.cs b/src/Npgsql/PostgresTypes/PostgresBaseType.cs index de9a7bc13e..a7cb0857cc 100644 --- a/src/Npgsql/PostgresTypes/PostgresBaseType.cs +++ b/src/Npgsql/PostgresTypes/PostgresBaseType.cs @@ -7,8 +7,8 @@ namespace Npgsql.PostgresTypes; public class PostgresBaseType : PostgresType { /// - protected internal PostgresBaseType(string ns, string internalName, uint oid) - : base(ns, 
TranslateInternalName(internalName), internalName, oid) + protected internal PostgresBaseType(string ns, string name, uint oid) + : base(ns, name, oid) {} /// @@ -68,27 +68,4 @@ internal override PostgresFacets GetFacets(int typeModifier) return PostgresFacets.None; } } - - // The type names returned by PostgreSQL are internal names (int4 instead of - // integer). We perform translation to the user-facing standard names. - // https://www.postgresql.org/docs/current/static/datatype.html#DATATYPE-TABLE - static string TranslateInternalName(string internalName) - => internalName switch - { - "bool" => "boolean", - "bpchar" => "character", - "decimal" => "numeric", - "float4" => "real", - "float8" => "double precision", - "int2" => "smallint", - "int4" => "integer", - "int8" => "bigint", - "time" => "time without time zone", - "timestamp" => "timestamp without time zone", - "timetz" => "time with time zone", - "timestamptz" => "timestamp with time zone", - "varbit" => "bit varying", - "varchar" => "character varying", - _ => internalName - }; -} \ No newline at end of file +} diff --git a/src/Npgsql/PostgresTypes/PostgresMultirangeType.cs b/src/Npgsql/PostgresTypes/PostgresMultirangeType.cs index 3d35783263..2e57075cb3 100644 --- a/src/Npgsql/PostgresTypes/PostgresMultirangeType.cs +++ b/src/Npgsql/PostgresTypes/PostgresMultirangeType.cs @@ -23,4 +23,4 @@ protected internal PostgresMultirangeType(string ns, string name, uint oid, Post Subrange = rangePostgresType; Subrange.Multirange = this; } -} \ No newline at end of file +} diff --git a/src/Npgsql/PostgresTypes/PostgresType.cs b/src/Npgsql/PostgresTypes/PostgresType.cs index 543cf3dcfd..8cc5fb7b63 100644 --- a/src/Npgsql/PostgresTypes/PostgresType.cs +++ b/src/Npgsql/PostgresTypes/PostgresType.cs @@ -1,5 +1,5 @@ using System; -using System.Linq; +using Npgsql.Internal.Postgres; namespace Npgsql.PostgresTypes; @@ -22,23 +22,11 @@ public abstract class PostgresType /// The data type's namespace (or schema). 
/// The data type's name. /// The data type's OID. - protected PostgresType(string ns, string name, uint oid) - : this(ns, name, name, oid) {} - - /// - /// Constructs a representation of a PostgreSQL data type. - /// - /// The data type's namespace (or schema). - /// The data type's name. - /// The data type's internal name (e.g. _int4 for integer[]). - /// The data type's OID. - protected PostgresType(string ns, string name, string internalName, uint oid) + private protected PostgresType(string ns, string name, uint oid) { - Namespace = ns; - Name = name; - FullName = Namespace + '.' + Name; - InternalName = internalName; + DataTypeName = DataTypeName.FromDisplayName(name, ns); OID = oid; + FullName = Namespace + "." + Name; } #endregion @@ -53,7 +41,7 @@ protected PostgresType(string ns, string name, string internalName, uint oid) /// /// The data type's namespace (or schema). /// - public string Namespace { get; } + public string Namespace => DataTypeName.Schema; /// /// The data type's name. @@ -62,24 +50,26 @@ protected PostgresType(string ns, string name, string internalName, uint oid) /// Note that this is the standard, user-displayable type name (e.g. integer[]) rather than the internal /// PostgreSQL name as it is in pg_type (_int4). See for the latter. /// - public string Name { get; } + public string Name => DataTypeName.UnqualifiedDisplayName; /// /// The full name of the backend type, including its namespace. /// public string FullName { get; } + internal DataTypeName DataTypeName { get; } + /// /// A display name for this backend type, including the namespace unless it is pg_catalog (the namespace /// for all built-in types). /// - public string DisplayName => Namespace == "pg_catalog" ? Name : FullName; + public string DisplayName => DataTypeName.DisplayName; /// /// The data type's internal PostgreSQL name (e.g. _int4 not integer[]). /// See for a more user-friendly name. 
/// - public string InternalName { get; } + public string InternalName => DataTypeName.UnqualifiedName; /// /// If a PostgreSQL array type exists for this type, it will be referenced here. @@ -111,4 +101,21 @@ internal string GetDisplayNameWithFacets(int typeModifier) /// Returns a string that represents the current object. /// public override string ToString() => DisplayName; -} \ No newline at end of file + + PostgresType? _representationalType; + + /// Canonizes (nested) domain types to underlying types, does not handle composites. + internal PostgresType GetRepresentationalType() + { + return _representationalType ??= Core(this) ?? throw new InvalidOperationException("Couldn't map type to representational type"); + + static PostgresType? Core(PostgresType? postgresType) + => (postgresType as PostgresDomainType)?.BaseType ?? postgresType switch + { + PostgresArrayType { Element: PostgresDomainType domain } => Core(domain.BaseType)?.Array, + PostgresMultirangeType { Subrange.Subtype: PostgresDomainType domain } => domain.BaseType.Range?.Multirange, + PostgresRangeType { Subtype: PostgresDomainType domain } => domain.Range, + var type => type + }; + } +} diff --git a/src/Npgsql/PostgresTypes/PostgresUnknownType.cs b/src/Npgsql/PostgresTypes/PostgresUnknownType.cs index a520df9696..bbe952726d 100644 --- a/src/Npgsql/PostgresTypes/PostgresUnknownType.cs +++ b/src/Npgsql/PostgresTypes/PostgresUnknownType.cs @@ -3,7 +3,7 @@ /// /// Represents a PostgreSQL data type that isn't known to Npgsql and cannot be handled. 
/// -public class UnknownBackendType : PostgresType +public sealed class UnknownBackendType : PostgresType { internal static readonly PostgresType Instance = new UnknownBackendType(); @@ -13,4 +13,4 @@ public class UnknownBackendType : PostgresType #pragma warning disable CA2222 // Do not decrease inherited member visibility UnknownBackendType() : base("", "", 0) { } #pragma warning restore CA2222 // Do not decrease inherited member visibility -} \ No newline at end of file +} diff --git a/src/Npgsql/PreparedStatement.cs b/src/Npgsql/PreparedStatement.cs index 015adc5dd3..9186df77c9 100644 --- a/src/Npgsql/PreparedStatement.cs +++ b/src/Npgsql/PreparedStatement.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Diagnostics; using Npgsql.BackendMessages; +using Npgsql.Internal.Postgres; namespace Npgsql; @@ -46,9 +47,7 @@ sealed class PreparedStatement /// Contains the handler types for a prepared statement's parameters, for overloaded cases (same SQL, different param types) /// Only populated after the statement has been prepared (i.e. null for candidates). /// - internal Type[]? HandlerParamTypes { get; private set; } - - static readonly Type[] EmptyParamTypes = Type.EmptyTypes; + PgTypeId[]? 
ConverterParamTypes { get; set; } internal static PreparedStatement CreateExplicit( PreparedStatementManager manager, @@ -81,22 +80,22 @@ internal void SetParamTypes(List parameters) { if (parameters.Count == 0) { - HandlerParamTypes = EmptyParamTypes; + ConverterParamTypes = Array.Empty(); return; } - HandlerParamTypes = new Type[parameters.Count]; + ConverterParamTypes = new PgTypeId[parameters.Count]; for (var i = 0; i < parameters.Count; i++) - HandlerParamTypes[i] = parameters[i].Handler!.GetType(); + ConverterParamTypes[i] = parameters[i].PgTypeId; } internal bool DoParametersMatch(List parameters) { - if (HandlerParamTypes!.Length != parameters.Count) + if (ConverterParamTypes!.Length != parameters.Count) return false; - for (var i = 0; i < HandlerParamTypes.Length; i++) - if (HandlerParamTypes[i] != parameters[i].Handler!.GetType()) + for (var i = 0; i < ConverterParamTypes.Length; i++) + if (ConverterParamTypes[i] != parameters[i].PgTypeId) return false; return true; @@ -170,4 +169,4 @@ enum PreparedState /// The statement was invalidated because e.g. table schema has changed since preparation. /// Invalidated -} \ No newline at end of file +} diff --git a/src/Npgsql/PreparedTextReader.cs b/src/Npgsql/PreparedTextReader.cs index 8a2cf806d2..8862daa3e7 100644 --- a/src/Npgsql/PreparedTextReader.cs +++ b/src/Npgsql/PreparedTextReader.cs @@ -27,7 +27,7 @@ public void Init(string str, NpgsqlReadBuffer.ColumnStream stream) public override int Peek() { CheckDisposed(); - + return _position < _str.Length ? _str[_position] : -1; @@ -36,7 +36,7 @@ public override int Peek() public override int Read() { CheckDisposed(); - + return _position < _str.Length ? 
_str[_position++] : -1; @@ -82,7 +82,7 @@ public override Task ReadAsync(char[] buffer, int index, int count) public #if !NETSTANDARD2_0 - override + override #endif ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) => new(Read(buffer.Span)); @@ -91,7 +91,7 @@ public override Task ReadAsync(char[] buffer, int index, int count) public override string ReadToEnd() { CheckDisposed(); - + if (_position == _str.Length) return string.Empty; @@ -108,6 +108,12 @@ void CheckDisposed() ThrowHelper.ThrowObjectDisposedException(nameof(PreparedTextReader)); } + public void Restart() + { + CheckDisposed(); + _position = 0; + } + protected override void Dispose(bool disposing) { base.Dispose(disposing); diff --git a/src/Npgsql/Properties/AssemblyInfo.cs b/src/Npgsql/Properties/AssemblyInfo.cs index e71a69a9dd..666ee3f170 100644 --- a/src/Npgsql/Properties/AssemblyInfo.cs +++ b/src/Npgsql/Properties/AssemblyInfo.cs @@ -33,7 +33,7 @@ "7aa16153bcea2ae9a471145624826f60d7c8e71cd025b554a0177bd935a78096" + "29f0a7afc778ebb4ad033e1bf512c1a9c6ceea26b077bc46cac93800435e77ee")] -[assembly: InternalsVisibleTo("Npgsql.NodaTime.Tests, PublicKey=" + +[assembly: InternalsVisibleTo("Npgsql.PluginTests, PublicKey=" + "0024000004800000940000000602000000240000525341310004000001000100" + "2b3c590b2a4e3d347e6878dc0ff4d21eb056a50420250c6617044330701d35c9" + "8078a5df97a62d83c9a2db2d072523a8fc491398254c6b89329b8c1dcef43a1e" + diff --git a/src/Npgsql/Properties/NpgsqlStrings.Designer.cs b/src/Npgsql/Properties/NpgsqlStrings.Designer.cs index 5f0847543f..707240754c 100644 --- a/src/Npgsql/Properties/NpgsqlStrings.Designer.cs +++ b/src/Npgsql/Properties/NpgsqlStrings.Designer.cs @@ -11,46 +11,32 @@ namespace Npgsql.Properties { using System; - /// - /// A strongly-typed resource class, for looking up localized strings, etc. - /// - // This class was auto-generated by the StronglyTypedResourceBuilder - // class via a tool like ResGen or Visual Studio. 
- // To add or remove a member, edit your .ResX file then rerun ResGen - // with the /str option, or rebuild your VS project. - [global::System.CodeDom.Compiler.GeneratedCodeAttribute("System.Resources.Tools.StronglyTypedResourceBuilder", "4.0.0.0")] - [global::System.Diagnostics.DebuggerNonUserCodeAttribute()] - [global::System.Runtime.CompilerServices.CompilerGeneratedAttribute()] + [System.CodeDom.Compiler.GeneratedCodeAttribute("System.Resources.Tools.StronglyTypedResourceBuilder", "4.0.0.0")] + [System.Diagnostics.DebuggerNonUserCodeAttribute()] + [System.Runtime.CompilerServices.CompilerGeneratedAttribute()] internal class NpgsqlStrings { - private static global::System.Resources.ResourceManager resourceMan; + private static System.Resources.ResourceManager resourceMan; - private static global::System.Globalization.CultureInfo resourceCulture; + private static System.Globalization.CultureInfo resourceCulture; - [global::System.Diagnostics.CodeAnalysis.SuppressMessageAttribute("Microsoft.Performance", "CA1811:AvoidUncalledPrivateCode")] + [System.Diagnostics.CodeAnalysis.SuppressMessageAttribute("Microsoft.Performance", "CA1811:AvoidUncalledPrivateCode")] internal NpgsqlStrings() { } - /// - /// Returns the cached ResourceManager instance used by this class. 
- /// - [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)] - internal static global::System.Resources.ResourceManager ResourceManager { + [System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Advanced)] + internal static System.Resources.ResourceManager ResourceManager { get { - if (object.ReferenceEquals(resourceMan, null)) { - global::System.Resources.ResourceManager temp = new global::System.Resources.ResourceManager("Npgsql.Properties.NpgsqlStrings", typeof(NpgsqlStrings).Assembly); + if (object.Equals(null, resourceMan)) { + System.Resources.ResourceManager temp = new System.Resources.ResourceManager("Npgsql.Properties.NpgsqlStrings", typeof(NpgsqlStrings).Assembly); resourceMan = temp; } return resourceMan; } } - /// - /// Overrides the current thread's CurrentUICulture property for all - /// resource lookups using this strongly typed resource class. - /// - [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)] - internal static global::System.Globalization.CultureInfo Culture { + [System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Advanced)] + internal static System.Globalization.CultureInfo Culture { get { return resourceCulture; } @@ -59,156 +45,147 @@ internal NpgsqlStrings() { } } - /// - /// Looks up a localized string similar to '{0}' must be positive.. - /// - internal static string ArgumentMustBePositive { + internal static string CannotUseSslVerifyWithUserCallback { get { - return ResourceManager.GetString("ArgumentMustBePositive", resourceCulture); + return ResourceManager.GetString("CannotUseSslVerifyWithUserCallback", resourceCulture); } } - /// - /// Looks up a localized string similar to Cannot read infinity value since Npgsql.DisableDateTimeInfinityConversions is enabled.. 
- /// - internal static string CannotReadInfinityValue { + internal static string CannotUseSslRootCertificateWithUserCallback { get { - return ResourceManager.GetString("CannotReadInfinityValue", resourceCulture); + return ResourceManager.GetString("CannotUseSslRootCertificateWithUserCallback", resourceCulture); } } - /// - /// Looks up a localized string similar to Cannot read interval values with non-zero months as TimeSpan, since that type doesn't support months. Consider using NodaTime Period which better corresponds to PostgreSQL interval, or read the value as NpgsqlInterval, or transform the interval to not contain months or years in PostgreSQL before reading it.. - /// - internal static string CannotReadIntervalWithMonthsAsTimeSpan { + internal static string EncryptionDisabled { get { - return ResourceManager.GetString("CannotReadIntervalWithMonthsAsTimeSpan", resourceCulture); + return ResourceManager.GetString("EncryptionDisabled", resourceCulture); + } + } + + internal static string NoMultirangeTypeFound { + get { + return ResourceManager.GetString("NoMultirangeTypeFound", resourceCulture); + } + } + + internal static string NotSupportedOnDataSourceCommand { + get { + return ResourceManager.GetString("NotSupportedOnDataSourceCommand", resourceCulture); + } + } + + internal static string NotSupportedOnDataSourceBatch { + get { + return ResourceManager.GetString("NotSupportedOnDataSourceBatch", resourceCulture); } } - /// - /// Looks up a localized string similar to When registering a password provider, a password or password file may not be set.. - /// internal static string CannotSetBothPasswordProviderAndPassword { get { return ResourceManager.GetString("CannotSetBothPasswordProviderAndPassword", resourceCulture); } } - /// - /// Looks up a localized string similar to When creating a multi-host data source, TargetSessionAttributes cannot be specified. Create without TargetSessionAttributes, and then obtain DataSource wrappers from it. 
Consult the docs for more information.. - /// + internal static string PasswordProviderMissing { + get { + return ResourceManager.GetString("PasswordProviderMissing", resourceCulture); + } + } + + internal static string ArgumentMustBePositive { + get { + return ResourceManager.GetString("ArgumentMustBePositive", resourceCulture); + } + } + internal static string CannotSpecifyTargetSessionAttributes { get { return ResourceManager.GetString("CannotSpecifyTargetSessionAttributes", resourceCulture); } } - /// - /// Looks up a localized string similar to RootCertificate cannot be used in conjunction with UserCertificateValidationCallback; when registering a validation callback, perform whatever validation you require in that callback.. - /// - internal static string CannotUseSslRootCertificateWithUserCallback { + internal static string CannotReadIntervalWithMonthsAsTimeSpan { get { - return ResourceManager.GetString("CannotUseSslRootCertificateWithUserCallback", resourceCulture); + return ResourceManager.GetString("CannotReadIntervalWithMonthsAsTimeSpan", resourceCulture); } } - /// - /// Looks up a localized string similar to SslMode.{0} cannot be used in conjunction with UserCertificateValidationCallback; when registering a validation callback, perform whatever validation you require in that callback.. 
- /// - internal static string CannotUseSslVerifyWithUserCallback { + internal static string PositionalParameterAfterNamed { get { - return ResourceManager.GetString("CannotUseSslVerifyWithUserCallback", resourceCulture); + return ResourceManager.GetString("PositionalParameterAfterNamed", resourceCulture); + } + } + + internal static string CannotReadInfinityValue { + get { + return ResourceManager.GetString("CannotReadInfinityValue", resourceCulture); + } + } + + internal static string SyncAndAsyncConnectionInitializersRequired { + get { + return ResourceManager.GetString("SyncAndAsyncConnectionInitializersRequired", resourceCulture); } } - /// - /// Looks up a localized string similar to ValidationRootCertificateCallback cannot be used in conjunction with UserCertificateValidationCallback; when registering a validation callback, perform whatever validation you require in that callback.. - /// internal static string CannotUseValidationRootCertificateCallbackWithUserCallback { get { return ResourceManager.GetString("CannotUseValidationRootCertificateCallbackWithUserCallback", resourceCulture); } } - /// - /// Looks up a localized string similar to NpgsqlSlimDataSourceBuilder is being used, and encryption hasn't been enabled, call EnableEncryption() on NpgsqlSlimDataSourceBuilder to enable it.. - /// - internal static string EncryptionDisabled { + internal static string RecordsNotEnabled { get { - return ResourceManager.GetString("EncryptionDisabled", resourceCulture); + return ResourceManager.GetString("RecordsNotEnabled", resourceCulture); } } - /// - /// Looks up a localized string similar to Full-text search isn't enabled; please call {0} on {1} to enable full-text search.. - /// internal static string FullTextSearchNotEnabled { get { return ResourceManager.GetString("FullTextSearchNotEnabled", resourceCulture); } } - /// - /// Looks up a localized string similar to No multirange type could be found in the database for subtype {0}.. 
- /// - internal static string NoMultirangeTypeFound { + internal static string LTreeNotEnabled { get { - return ResourceManager.GetString("NoMultirangeTypeFound", resourceCulture); + return ResourceManager.GetString("LTreeNotEnabled", resourceCulture); } } - /// - /// Looks up a localized string similar to Connection and transaction access is not supported on batches created from DbDataSource.. - /// - internal static string NotSupportedOnDataSourceBatch { + internal static string RangesNotEnabled { get { - return ResourceManager.GetString("NotSupportedOnDataSourceBatch", resourceCulture); + return ResourceManager.GetString("RangesNotEnabled", resourceCulture); } } - /// - /// Looks up a localized string similar to Connection and transaction access is not supported on commands created from DbDataSource.. - /// - internal static string NotSupportedOnDataSourceCommand { + internal static string RangeArraysNotEnabled { get { - return ResourceManager.GetString("NotSupportedOnDataSourceCommand", resourceCulture); + return ResourceManager.GetString("RangeArraysNotEnabled", resourceCulture); } } - /// - /// Looks up a localized string similar to The right type of password provider (sync or async) was not found.. - /// - internal static string PasswordProviderMissing { + internal static string MultirangesNotEnabled { get { - return ResourceManager.GetString("PasswordProviderMissing", resourceCulture); + return ResourceManager.GetString("MultirangesNotEnabled", resourceCulture); } } - /// - /// Looks up a localized string similar to When using CommandType.StoredProcedure, all positional parameters must come before named parameters.. 
- /// - internal static string PositionalParameterAfterNamed { + internal static string MultirangeArraysNotEnabled { get { - return ResourceManager.GetString("PositionalParameterAfterNamed", resourceCulture); + return ResourceManager.GetString("MultirangeArraysNotEnabled", resourceCulture); } } - /// - /// Looks up a localized string similar to Records aren't enabled; please call {0} on {1} to enable records.. - /// - internal static string RecordsNotEnabled { + internal static string TimestampTzNoDateTimeUnspecified { get { - return ResourceManager.GetString("RecordsNotEnabled", resourceCulture); + return ResourceManager.GetString("TimestampTzNoDateTimeUnspecified", resourceCulture); } } - /// - /// Looks up a localized string similar to Both sync and async connection initializers must be provided.. - /// - internal static string SyncAndAsyncConnectionInitializersRequired { + internal static string TimestampNoDateTimeUtc { get { - return ResourceManager.GetString("SyncAndAsyncConnectionInitializersRequired", resourceCulture); + return ResourceManager.GetString("TimestampNoDateTimeUtc", resourceCulture); } } } diff --git a/src/Npgsql/Properties/NpgsqlStrings.resx b/src/Npgsql/Properties/NpgsqlStrings.resx index 8df8e0b335..5ca209070f 100644 --- a/src/Npgsql/Properties/NpgsqlStrings.resx +++ b/src/Npgsql/Properties/NpgsqlStrings.resx @@ -69,4 +69,25 @@ Full-text search isn't enabled; please call {0} on {1} to enable full-text search. - \ No newline at end of file + + Ltree isn't enabled; please call {0} on {1} to enable LTree. + + + Ranges aren't enabled; please call {0} on {1} to enable ranges. + + + Range arrays aren't enabled; please call {0} on {1} to enable arrays for ranges. + + + Multiranges aren't enabled; please call {0} on {1} to enable multiranges. + + + Multirange arrays aren't enabled; please call {0} on {1} to enable arrays for multiranges. + + + Cannot write DateTime with Kind={0} to PostgreSQL type '{1}', only UTC is supported. 
Note that it's not possible to mix DateTimes with different Kinds in an array, range, or multirange. + + + Cannot write DateTime with Kind=UTC to PostgreSQL type '{0}', consider using '{1}'. Note that it's not possible to mix DateTimes with different Kinds in an array, range, or multirange. + + diff --git a/src/Npgsql/PublicAPI.Unshipped.txt b/src/Npgsql/PublicAPI.Unshipped.txt index aa795b81ff..37a23693b4 100644 --- a/src/Npgsql/PublicAPI.Unshipped.txt +++ b/src/Npgsql/PublicAPI.Unshipped.txt @@ -7,26 +7,31 @@ Npgsql.ChannelBinding.Require = 2 -> Npgsql.ChannelBinding Npgsql.NpgsqlBatch.CreateBatchCommand() -> Npgsql.NpgsqlBatchCommand! Npgsql.NpgsqlConnectionStringBuilder.ChannelBinding.get -> Npgsql.ChannelBinding Npgsql.NpgsqlConnectionStringBuilder.ChannelBinding.set -> void -Npgsql.NpgsqlSlimDataSourceBuilder.EnableFullTextSearch() -> Npgsql.NpgsqlSlimDataSourceBuilder! -Npgsql.NpgsqlSlimDataSourceBuilder.EnableRecords() -> Npgsql.NpgsqlSlimDataSourceBuilder! Npgsql.NpgsqlBinaryImporter.WriteRow(params object?[]! values) -> void Npgsql.NpgsqlBinaryImporter.WriteRowAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken), params object?[]! values) -> System.Threading.Tasks.Task! +Npgsql.NpgsqlDataSourceBuilder.AddTypeInfoResolver(Npgsql.Internal.IPgTypeInfoResolver! resolver) -> void Npgsql.NpgsqlDataSourceBuilder.Name.get -> string? Npgsql.NpgsqlDataSourceBuilder.Name.set -> void Npgsql.NpgsqlDataSourceBuilder.UseRootCertificate(System.Security.Cryptography.X509Certificates.X509Certificate2? rootCertificate) -> Npgsql.NpgsqlDataSourceBuilder! Npgsql.NpgsqlDataSourceBuilder.UseRootCertificateCallback(System.Func? rootCertificateCallback) -> Npgsql.NpgsqlDataSourceBuilder! Npgsql.NpgsqlDataSourceBuilder.UseSystemTextJson(System.Text.Json.JsonSerializerOptions? serializerOptions = null, System.Type![]? jsonbClrTypes = null, System.Type![]? jsonClrTypes = null) -> Npgsql.NpgsqlDataSourceBuilder! 
Npgsql.NpgsqlSlimDataSourceBuilder -Npgsql.NpgsqlSlimDataSourceBuilder.AddTypeResolverFactory(Npgsql.Internal.TypeHandling.TypeHandlerResolverFactory! resolverFactory) -> void +Npgsql.NpgsqlSlimDataSourceBuilder.AddTypeInfoResolver(Npgsql.Internal.IPgTypeInfoResolver! resolver) -> void Npgsql.NpgsqlSlimDataSourceBuilder.Build() -> Npgsql.NpgsqlDataSource! Npgsql.NpgsqlSlimDataSourceBuilder.BuildMultiHost() -> Npgsql.NpgsqlMultiHostDataSource! Npgsql.NpgsqlSlimDataSourceBuilder.ConnectionString.get -> string! Npgsql.NpgsqlSlimDataSourceBuilder.ConnectionStringBuilder.get -> Npgsql.NpgsqlConnectionStringBuilder! Npgsql.NpgsqlSlimDataSourceBuilder.DefaultNameTranslator.get -> Npgsql.INpgsqlNameTranslator! Npgsql.NpgsqlSlimDataSourceBuilder.DefaultNameTranslator.set -> void +Npgsql.NpgsqlSlimDataSourceBuilder.EnableArrays() -> Npgsql.NpgsqlSlimDataSourceBuilder! Npgsql.NpgsqlSlimDataSourceBuilder.EnableEncryption() -> Npgsql.NpgsqlSlimDataSourceBuilder! +Npgsql.NpgsqlSlimDataSourceBuilder.EnableExtraConversions() -> Npgsql.NpgsqlSlimDataSourceBuilder! +Npgsql.NpgsqlSlimDataSourceBuilder.EnableFullTextSearch() -> Npgsql.NpgsqlSlimDataSourceBuilder! +Npgsql.NpgsqlSlimDataSourceBuilder.EnableLTree() -> Npgsql.NpgsqlSlimDataSourceBuilder! +Npgsql.NpgsqlSlimDataSourceBuilder.EnableMultiranges() -> Npgsql.NpgsqlSlimDataSourceBuilder! Npgsql.NpgsqlSlimDataSourceBuilder.EnableParameterLogging(bool parameterLoggingEnabled = true) -> Npgsql.NpgsqlSlimDataSourceBuilder! Npgsql.NpgsqlSlimDataSourceBuilder.EnableRanges() -> Npgsql.NpgsqlSlimDataSourceBuilder! +Npgsql.NpgsqlSlimDataSourceBuilder.EnableRecords() -> Npgsql.NpgsqlSlimDataSourceBuilder! Npgsql.NpgsqlSlimDataSourceBuilder.MapComposite(System.Type! clrType, string? pgName = null, Npgsql.INpgsqlNameTranslator? nameTranslator = null) -> Npgsql.TypeMapping.INpgsqlTypeMapper! Npgsql.NpgsqlSlimDataSourceBuilder.MapComposite(string? pgName = null, Npgsql.INpgsqlNameTranslator? 
nameTranslator = null) -> Npgsql.TypeMapping.INpgsqlTypeMapper! Npgsql.NpgsqlSlimDataSourceBuilder.MapEnum(string? pgName = null, Npgsql.INpgsqlNameTranslator? nameTranslator = null) -> Npgsql.TypeMapping.INpgsqlTypeMapper! @@ -46,10 +51,28 @@ Npgsql.NpgsqlSlimDataSourceBuilder.UseRootCertificate(System.Security.Cryptograp Npgsql.NpgsqlSlimDataSourceBuilder.UseRootCertificateCallback(System.Func? rootCertificateCallback) -> Npgsql.NpgsqlSlimDataSourceBuilder! Npgsql.NpgsqlSlimDataSourceBuilder.UseSystemTextJson(System.Text.Json.JsonSerializerOptions? serializerOptions = null, System.Type![]? jsonbClrTypes = null, System.Type![]? jsonClrTypes = null) -> Npgsql.NpgsqlSlimDataSourceBuilder! Npgsql.NpgsqlSlimDataSourceBuilder.UseUserCertificateValidationCallback(System.Net.Security.RemoteCertificateValidationCallback! userCertificateValidationCallback) -> Npgsql.NpgsqlSlimDataSourceBuilder! +Npgsql.PostgresTypes.PostgresArrayType.PostgresArrayType(string! ns, string! name, uint oid, Npgsql.PostgresTypes.PostgresType! elementPostgresType) -> void +Npgsql.PostgresTypes.PostgresBaseType.PostgresBaseType(string! ns, string! name, uint oid) -> void Npgsql.Replication.PhysicalReplicationConnection.StartReplication(Npgsql.Replication.PhysicalReplicationSlot? slot, NpgsqlTypes.NpgsqlLogSequenceNumber walLocation, System.Threading.CancellationToken cancellationToken, uint timeline = 0) -> System.Collections.Generic.IAsyncEnumerable! Npgsql.Replication.PhysicalReplicationConnection.StartReplication(NpgsqlTypes.NpgsqlLogSequenceNumber walLocation, System.Threading.CancellationToken cancellationToken, uint timeline = 0) -> System.Collections.Generic.IAsyncEnumerable! Npgsql.Replication.PhysicalReplicationSlot.PhysicalReplicationSlot(string! slotName, NpgsqlTypes.NpgsqlLogSequenceNumber? restartLsn = null, uint? restartTimeline = null) -> void Npgsql.Replication.PhysicalReplicationSlot.RestartTimeline.get -> uint? 
+Npgsql.TypeMapping.INpgsqlTypeMapper.AddTypeInfoResolver(Npgsql.Internal.IPgTypeInfoResolver! resolver) -> void +Npgsql.TypeMapping.UserTypeMapping +Npgsql.TypeMapping.UserTypeMapping.ClrType.get -> System.Type! +Npgsql.TypeMapping.UserTypeMapping.PgTypeName.get -> string! +NpgsqlTypes.NpgsqlCidr +NpgsqlTypes.NpgsqlCidr.Address.get -> System.Net.IPAddress! +NpgsqlTypes.NpgsqlCidr.Deconstruct(out System.Net.IPAddress! address, out byte netmask) -> void +NpgsqlTypes.NpgsqlCidr.Netmask.get -> byte +NpgsqlTypes.NpgsqlCidr.NpgsqlCidr() -> void +NpgsqlTypes.NpgsqlCidr.NpgsqlCidr(string! addr) -> void +NpgsqlTypes.NpgsqlCidr.NpgsqlCidr(System.Net.IPAddress! address, byte netmask) -> void +NpgsqlTypes.NpgsqlInet.Deconstruct(out System.Net.IPAddress! address, out byte netmask) -> void +NpgsqlTypes.NpgsqlInet.NpgsqlInet(System.Net.IPAddress! address, byte netmask) -> void +NpgsqlTypes.NpgsqlInet.Netmask.get -> byte +NpgsqlTypes.NpgsqlTsQueryFollowedBy.Distance.get -> short +NpgsqlTypes.NpgsqlTsQueryFollowedBy.NpgsqlTsQueryFollowedBy(NpgsqlTypes.NpgsqlTsQuery! left, short distance, NpgsqlTypes.NpgsqlTsQuery! right) -> void override Npgsql.NpgsqlBatch.Dispose() -> void *REMOVED*static NpgsqlTypes.NpgsqlBox.Parse(string! s) -> NpgsqlTypes.NpgsqlBox *REMOVED*static NpgsqlTypes.NpgsqlCircle.Parse(string! s) -> NpgsqlTypes.NpgsqlCircle @@ -58,6 +81,14 @@ override Npgsql.NpgsqlBatch.Dispose() -> void *REMOVED*static NpgsqlTypes.NpgsqlPath.Parse(string! s) -> NpgsqlTypes.NpgsqlPath *REMOVED*static NpgsqlTypes.NpgsqlPoint.Parse(string! s) -> NpgsqlTypes.NpgsqlPoint *REMOVED*static NpgsqlTypes.NpgsqlPolygon.Parse(string! s) -> NpgsqlTypes.NpgsqlPolygon +*REMOVED*NpgsqlTypes.NpgsqlInet.Deconstruct(out System.Net.IPAddress! address, out int netmask) -> void +*REMOVED*NpgsqlTypes.NpgsqlInet.NpgsqlInet(System.Net.IPAddress! 
address, int netmask) -> void +*REMOVED*NpgsqlTypes.NpgsqlInet.Address.set -> void +*REMOVED*NpgsqlTypes.NpgsqlInet.Equals(NpgsqlTypes.NpgsqlInet other) -> bool +*REMOVED*NpgsqlTypes.NpgsqlInet.Netmask.get -> int +*REMOVED*NpgsqlTypes.NpgsqlInet.Netmask.set -> void +*REMOVED*NpgsqlTypes.NpgsqlTsQueryFollowedBy.Distance.get -> int +*REMOVED*NpgsqlTypes.NpgsqlTsQueryFollowedBy.NpgsqlTsQueryFollowedBy(NpgsqlTypes.NpgsqlTsQuery! left, int distance, NpgsqlTypes.NpgsqlTsQuery! right) -> void *REMOVED*Npgsql.NpgsqlBinaryImporter.WriteRow(params object![]! values) -> void *REMOVED*Npgsql.NpgsqlBinaryImporter.WriteRowAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken), params object![]! values) -> System.Threading.Tasks.Task! *REMOVED*override Npgsql.NpgsqlDataReader.GetProviderSpecificFieldType(int ordinal) -> System.Type! @@ -66,7 +97,24 @@ override Npgsql.NpgsqlBatch.Dispose() -> void *REMOVED*override Npgsql.NpgsqlNestedDataReader.GetProviderSpecificFieldType(int ordinal) -> System.Type! *REMOVED*override Npgsql.NpgsqlNestedDataReader.GetProviderSpecificValue(int ordinal) -> object! *REMOVED*override Npgsql.NpgsqlNestedDataReader.GetProviderSpecificValues(object![]! values) -> int +*REMOVED*override NpgsqlTypes.NpgsqlInet.Equals(object? obj) -> bool +*REMOVED*override NpgsqlTypes.NpgsqlInet.GetHashCode() -> int *REMOVED*Npgsql.Replication.PhysicalReplicationConnection.StartReplication(Npgsql.Replication.PhysicalReplicationSlot? slot, NpgsqlTypes.NpgsqlLogSequenceNumber walLocation, System.Threading.CancellationToken cancellationToken, ulong timeline = 0) -> System.Collections.Generic.IAsyncEnumerable! *REMOVED*Npgsql.Replication.PhysicalReplicationConnection.StartReplication(NpgsqlTypes.NpgsqlLogSequenceNumber walLocation, System.Threading.CancellationToken cancellationToken, ulong timeline = 0) -> System.Collections.Generic.IAsyncEnumerable! 
*REMOVED*Npgsql.Replication.PhysicalReplicationSlot.PhysicalReplicationSlot(string! slotName, NpgsqlTypes.NpgsqlLogSequenceNumber? restartLsn = null, ulong? restartTimeline = null) -> void *REMOVED*Npgsql.Replication.PhysicalReplicationSlot.RestartTimeline.get -> ulong? +override NpgsqlTypes.NpgsqlCidr.ToString() -> string! +*REMOVED*static NpgsqlTypes.NpgsqlInet.operator !=(NpgsqlTypes.NpgsqlInet x, NpgsqlTypes.NpgsqlInet y) -> bool +*REMOVED*static NpgsqlTypes.NpgsqlInet.operator ==(NpgsqlTypes.NpgsqlInet x, NpgsqlTypes.NpgsqlInet y) -> bool +*REMOVED*static NpgsqlTypes.NpgsqlInet.ToIPAddress(NpgsqlTypes.NpgsqlInet inet) -> System.Net.IPAddress! +*REMOVED*static NpgsqlTypes.NpgsqlInet.ToNpgsqlInet(System.Net.IPAddress? ip) -> NpgsqlTypes.NpgsqlInet +*REMOVED*Npgsql.NpgsqlDataSourceBuilder.AddTypeResolverFactory(Npgsql.Internal.TypeHandling.TypeHandlerResolverFactory! resolverFactory) -> void +static NpgsqlTypes.NpgsqlInet.explicit operator NpgsqlTypes.NpgsqlInet(System.Net.IPAddress! ip) -> NpgsqlTypes.NpgsqlInet +*REMOVED*Npgsql.NpgsqlParameter.ConvertedValue.get -> object? +*REMOVED*Npgsql.NpgsqlParameter.ConvertedValue.set -> void +*REMOVED*Npgsql.PostgresTypes.PostgresArrayType.PostgresArrayType(string! ns, string! internalName, uint oid, Npgsql.PostgresTypes.PostgresType! elementPostgresType) -> void +*REMOVED*Npgsql.PostgresTypes.PostgresBaseType.PostgresBaseType(string! ns, string! internalName, uint oid) -> void +*REMOVED*static NpgsqlTypes.NpgsqlInet.implicit operator NpgsqlTypes.NpgsqlInet(System.Net.IPAddress! ip) -> NpgsqlTypes.NpgsqlInet +*REMOVED*Npgsql.PostgresTypes.PostgresType.PostgresType(string! ns, string! name, string! internalName, uint oid) -> void +*REMOVED*Npgsql.PostgresTypes.PostgresType.PostgresType(string! ns, string! name, uint oid) -> void +*REMOVED*Npgsql.TypeMapping.INpgsqlTypeMapper.AddTypeResolverFactory(Npgsql.Internal.TypeHandling.TypeHandlerResolverFactory! 
resolverFactory) -> void diff --git a/src/Npgsql/Replication/PgDateTime.cs b/src/Npgsql/Replication/PgDateTime.cs new file mode 100644 index 0000000000..aa68bda7f6 --- /dev/null +++ b/src/Npgsql/Replication/PgDateTime.cs @@ -0,0 +1,16 @@ +using System; + +namespace Npgsql.Replication; + +static class PgDateTime +{ + const long PostgresTimestampOffsetTicks = 630822816000000000L; + + public static DateTime DecodeTimestamp(long value, DateTimeKind kind) + => new(value * 10 + PostgresTimestampOffsetTicks, kind); + + public static long EncodeTimestamp(DateTime value) + // Rounding here would cause problems because we would round up DateTime.MaxValue + // which would make it impossible to retrieve it back from the database, so we just drop the additional precision + => (value.Ticks - PostgresTimestampOffsetTicks) / 10; +} diff --git a/src/Npgsql/Replication/PgOutput/Messages/DefaultUpdateMessage.cs b/src/Npgsql/Replication/PgOutput/Messages/DefaultUpdateMessage.cs index 8a9a34741d..6fd36d7ea0 100644 --- a/src/Npgsql/Replication/PgOutput/Messages/DefaultUpdateMessage.cs +++ b/src/Npgsql/Replication/PgOutput/Messages/DefaultUpdateMessage.cs @@ -1,5 +1,4 @@ using System; -using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; using Npgsql.Internal; diff --git a/src/Npgsql/Replication/PgOutput/Messages/FullDeleteMessage.cs b/src/Npgsql/Replication/PgOutput/Messages/FullDeleteMessage.cs index 933b50ac68..a426a2b6ad 100644 --- a/src/Npgsql/Replication/PgOutput/Messages/FullDeleteMessage.cs +++ b/src/Npgsql/Replication/PgOutput/Messages/FullDeleteMessage.cs @@ -1,6 +1,5 @@ using NpgsqlTypes; using System; -using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; using Npgsql.Internal; diff --git a/src/Npgsql/Replication/PgOutput/Messages/FullUpdateMessage.cs b/src/Npgsql/Replication/PgOutput/Messages/FullUpdateMessage.cs index 7da8f77c68..814780cf37 100644 --- 
a/src/Npgsql/Replication/PgOutput/Messages/FullUpdateMessage.cs +++ b/src/Npgsql/Replication/PgOutput/Messages/FullUpdateMessage.cs @@ -1,5 +1,4 @@ using System; -using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; using Npgsql.Internal; diff --git a/src/Npgsql/Replication/PgOutput/Messages/IndexUpdateMessage.cs b/src/Npgsql/Replication/PgOutput/Messages/IndexUpdateMessage.cs index 14f31b1672..021458140d 100644 --- a/src/Npgsql/Replication/PgOutput/Messages/IndexUpdateMessage.cs +++ b/src/Npgsql/Replication/PgOutput/Messages/IndexUpdateMessage.cs @@ -1,5 +1,4 @@ using System; -using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; using Npgsql.Internal; diff --git a/src/Npgsql/Replication/PgOutput/Messages/InsertMessage.cs b/src/Npgsql/Replication/PgOutput/Messages/InsertMessage.cs index d0f67841e9..fe862ead1b 100644 --- a/src/Npgsql/Replication/PgOutput/Messages/InsertMessage.cs +++ b/src/Npgsql/Replication/PgOutput/Messages/InsertMessage.cs @@ -1,6 +1,5 @@ using NpgsqlTypes; using System; -using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; using Npgsql.Internal; diff --git a/src/Npgsql/Replication/PgOutput/Messages/KeyDeleteMessage.cs b/src/Npgsql/Replication/PgOutput/Messages/KeyDeleteMessage.cs index 9905d44753..9b30b3e1df 100644 --- a/src/Npgsql/Replication/PgOutput/Messages/KeyDeleteMessage.cs +++ b/src/Npgsql/Replication/PgOutput/Messages/KeyDeleteMessage.cs @@ -1,6 +1,5 @@ using NpgsqlTypes; using System; -using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; using Npgsql.Internal; diff --git a/src/Npgsql/Replication/PgOutput/Messages/PgOutputReplicationMessage.cs b/src/Npgsql/Replication/PgOutput/Messages/PgOutputReplicationMessage.cs index b93e27fa3c..24de9e201f 100644 --- a/src/Npgsql/Replication/PgOutput/Messages/PgOutputReplicationMessage.cs +++ 
b/src/Npgsql/Replication/PgOutput/Messages/PgOutputReplicationMessage.cs @@ -1,7 +1,4 @@ -using NpgsqlTypes; -using System; - -namespace Npgsql.Replication.PgOutput.Messages; +namespace Npgsql.Replication.PgOutput.Messages; /// /// The base class of all Logical Replication Protocol Messages diff --git a/src/Npgsql/Replication/PgOutput/Messages/RelationMessage.cs b/src/Npgsql/Replication/PgOutput/Messages/RelationMessage.cs index f9be4a1eeb..85d83debb7 100644 --- a/src/Npgsql/Replication/PgOutput/Messages/RelationMessage.cs +++ b/src/Npgsql/Replication/PgOutput/Messages/RelationMessage.cs @@ -1,7 +1,6 @@ using NpgsqlTypes; using System; using System.Collections.Generic; -using System.Collections.Immutable; using Npgsql.BackendMessages; namespace Npgsql.Replication.PgOutput.Messages; @@ -136,4 +135,4 @@ public enum ReplicaIdentitySetting : byte /// IndexWithIndIsReplIdent = (byte)'i' } -} \ No newline at end of file +} diff --git a/src/Npgsql/Replication/PgOutput/PgOutputAsyncEnumerable.cs b/src/Npgsql/Replication/PgOutput/PgOutputAsyncEnumerable.cs index 76d983a6ee..5b53e06bdf 100644 --- a/src/Npgsql/Replication/PgOutput/PgOutputAsyncEnumerable.cs +++ b/src/Npgsql/Replication/PgOutput/PgOutputAsyncEnumerable.cs @@ -5,10 +5,8 @@ using System.Threading.Tasks; using Npgsql.BackendMessages; using Npgsql.Internal; -using Npgsql.Internal.TypeHandlers.DateTimeHandlers; using Npgsql.Replication.Internal; using Npgsql.Replication.PgOutput.Messages; -using Npgsql.Util; using NpgsqlTypes; namespace Npgsql.Replication.PgOutput; @@ -91,7 +89,7 @@ async IAsyncEnumerator StartReplicationInternal(Canc _slot, cancellationToken, _walLocation, _options.GetOptionPairs(), bypassingStream: true); var buf = _connection.Connector!.ReadBuffer; var inStreamingTransaction = false; - var formatCode = _options.Binary ?? false ? FormatCode.Binary : FormatCode.Text; + var dataFormat = _options.Binary ?? false ? 
DataFormat.Binary : DataFormat.Text; await foreach (var xLogData in stream.WithCancellation(cancellationToken)) { @@ -104,7 +102,7 @@ async IAsyncEnumerator StartReplicationInternal(Canc await buf.EnsureAsync(20); yield return _beginMessage.Populate(xLogData.WalStart, xLogData.WalEnd, xLogData.ServerClock, transactionFinalLsn: new NpgsqlLogSequenceNumber(buf.ReadUInt64()), - transactionCommitTimestamp: DateTimeUtils.DecodeTimestamp(buf.ReadInt64(), DateTimeKind.Utc), + transactionCommitTimestamp: PgDateTime.DecodeTimestamp(buf.ReadInt64(), DateTimeKind.Utc), transactionXid: buf.ReadUInt32()); continue; } @@ -128,7 +126,7 @@ async IAsyncEnumerator StartReplicationInternal(Canc await buf.EnsureAsync(4); var length = buf.ReadUInt32(); var data = (NpgsqlReadBuffer.ColumnStream)xLogData.Data; - data.Init(checked((int)length), false); + data.Init(checked((int)length), canSeek: false, commandScoped: false); yield return _logicalDecodingMessage.Populate(xLogData.WalStart, xLogData.WalEnd, xLogData.ServerClock, transactionXid, flags, messageLsn, prefix, data); continue; @@ -141,7 +139,7 @@ async IAsyncEnumerator StartReplicationInternal(Canc (CommitMessage.CommitFlags)buf.ReadByte(), commitLsn: new NpgsqlLogSequenceNumber(buf.ReadUInt64()), transactionEndLsn: new NpgsqlLogSequenceNumber(buf.ReadUInt64()), - transactionCommitTimestamp: DateTimeUtils.DecodeTimestamp(buf.ReadInt64(), DateTimeKind.Utc)); + transactionCommitTimestamp: PgDateTime.DecodeTimestamp(buf.ReadInt64(), DateTimeKind.Utc)); continue; } case BackendReplicationMessageCode.Origin: @@ -193,7 +191,7 @@ async IAsyncEnumerator StartReplicationInternal(Canc } msg.RowDescription = RowDescriptionMessage.CreateForReplication( - _connection.Connector.TypeMapper, relationId, formatCode, columns); + _connection.Connector.SerializerOptions, relationId, dataFormat, columns); yield return msg; continue; @@ -397,7 +395,7 @@ async IAsyncEnumerator StartReplicationInternal(Canc yield return 
_streamCommitMessage.Populate(xLogData.WalStart, xLogData.WalEnd, xLogData.ServerClock, transactionXid: buf.ReadUInt32(), flags: buf.ReadByte(), commitLsn: new NpgsqlLogSequenceNumber(buf.ReadUInt64()), transactionEndLsn: new NpgsqlLogSequenceNumber(buf.ReadUInt64()), - transactionCommitTimestamp: DateTimeUtils.DecodeTimestamp(buf.ReadInt64(), DateTimeKind.Utc)); + transactionCommitTimestamp: PgDateTime.DecodeTimestamp(buf.ReadInt64(), DateTimeKind.Utc)); continue; } case BackendReplicationMessageCode.StreamAbort: @@ -413,7 +411,7 @@ async IAsyncEnumerator StartReplicationInternal(Canc yield return _beginPrepareMessage.Populate(xLogData.WalStart, xLogData.WalEnd, xLogData.ServerClock, prepareLsn: new NpgsqlLogSequenceNumber(buf.ReadUInt64()), prepareEndLsn: new NpgsqlLogSequenceNumber(buf.ReadUInt64()), - transactionPrepareTimestamp: DateTimeUtils.DecodeTimestamp(buf.ReadInt64(), DateTimeKind.Utc), + transactionPrepareTimestamp: PgDateTime.DecodeTimestamp(buf.ReadInt64(), DateTimeKind.Utc), transactionXid: buf.ReadUInt32(), transactionGid: buf.ReadNullTerminatedString()); continue; @@ -425,7 +423,7 @@ async IAsyncEnumerator StartReplicationInternal(Canc flags: (PrepareMessage.PrepareFlags)buf.ReadByte(), prepareLsn: new NpgsqlLogSequenceNumber(buf.ReadUInt64()), prepareEndLsn: new NpgsqlLogSequenceNumber(buf.ReadUInt64()), - transactionPrepareTimestamp: DateTimeUtils.DecodeTimestamp(buf.ReadInt64(), DateTimeKind.Utc), + transactionPrepareTimestamp: PgDateTime.DecodeTimestamp(buf.ReadInt64(), DateTimeKind.Utc), transactionXid: buf.ReadUInt32(), transactionGid: buf.ReadNullTerminatedString()); continue; @@ -437,7 +435,7 @@ async IAsyncEnumerator StartReplicationInternal(Canc flags: (CommitPreparedMessage.CommitPreparedFlags)buf.ReadByte(), commitPreparedLsn: new NpgsqlLogSequenceNumber(buf.ReadUInt64()), commitPreparedEndLsn: new NpgsqlLogSequenceNumber(buf.ReadUInt64()), - transactionCommitTimestamp: DateTimeUtils.DecodeTimestamp(buf.ReadInt64(), DateTimeKind.Utc), 
+ transactionCommitTimestamp: PgDateTime.DecodeTimestamp(buf.ReadInt64(), DateTimeKind.Utc), transactionXid: buf.ReadUInt32(), transactionGid: buf.ReadNullTerminatedString()); continue; @@ -449,8 +447,8 @@ async IAsyncEnumerator StartReplicationInternal(Canc flags: (RollbackPreparedMessage.RollbackPreparedFlags)buf.ReadByte(), preparedTransactionEndLsn: new NpgsqlLogSequenceNumber(buf.ReadUInt64()), rollbackPreparedEndLsn: new NpgsqlLogSequenceNumber(buf.ReadUInt64()), - transactionPrepareTimestamp: DateTimeUtils.DecodeTimestamp(buf.ReadInt64(), DateTimeKind.Utc), - transactionRollbackTimestamp: DateTimeUtils.DecodeTimestamp(buf.ReadInt64(), DateTimeKind.Utc), + transactionPrepareTimestamp: PgDateTime.DecodeTimestamp(buf.ReadInt64(), DateTimeKind.Utc), + transactionRollbackTimestamp: PgDateTime.DecodeTimestamp(buf.ReadInt64(), DateTimeKind.Utc), transactionXid: buf.ReadUInt32(), transactionGid: buf.ReadNullTerminatedString()); continue; @@ -462,7 +460,7 @@ async IAsyncEnumerator StartReplicationInternal(Canc flags: (StreamPrepareMessage.StreamPrepareFlags)buf.ReadByte(), prepareLsn: new NpgsqlLogSequenceNumber(buf.ReadUInt64()), prepareEndLsn: new NpgsqlLogSequenceNumber(buf.ReadUInt64()), - transactionPrepareTimestamp: DateTimeUtils.DecodeTimestamp(buf.ReadInt64(), DateTimeKind.Utc), + transactionPrepareTimestamp: PgDateTime.DecodeTimestamp(buf.ReadInt64(), DateTimeKind.Utc), transactionXid: buf.ReadUInt32(), transactionGid: buf.ReadNullTerminatedString()); continue; diff --git a/src/Npgsql/Replication/PgOutput/ReadonlyArrayBuffer.cs b/src/Npgsql/Replication/PgOutput/ReadonlyArrayBuffer.cs index 596dc471fb..df910af4d2 100644 --- a/src/Npgsql/Replication/PgOutput/ReadonlyArrayBuffer.cs +++ b/src/Npgsql/Replication/PgOutput/ReadonlyArrayBuffer.cs @@ -1,7 +1,6 @@ using System; using System.Collections; using System.Collections.Generic; -using Npgsql.Replication.PgOutput.Messages; namespace Npgsql.Replication.PgOutput; diff --git 
a/src/Npgsql/Replication/PgOutput/ReplicationValue.cs b/src/Npgsql/Replication/PgOutput/ReplicationValue.cs index 6ad0cbc6e1..7c5f104f3e 100644 --- a/src/Npgsql/Replication/PgOutput/ReplicationValue.cs +++ b/src/Npgsql/Replication/PgOutput/ReplicationValue.cs @@ -4,9 +4,7 @@ using System.Threading.Tasks; using Npgsql.BackendMessages; using Npgsql.Internal; -using Npgsql.Internal.TypeHandling; using Npgsql.PostgresTypes; -using Npgsql.Replication.PgOutput.Messages; namespace Npgsql.Replication.PgOutput; @@ -27,26 +25,21 @@ public class ReplicationValue /// public TupleDataKind Kind { get; private set; } - bool _columnConsumed; FieldDescription _fieldDescription = null!; + PgConverterInfo _lastInfo; + bool _isConsumed; - /// - /// A stream that has been opened on a column. - /// - readonly NpgsqlReadBuffer.ColumnStream _columnStream; + PgReader PgReader => _readBuffer.PgReader; - internal ReplicationValue(NpgsqlConnector connector) - { - _readBuffer = connector.ReadBuffer; - _columnStream = new NpgsqlReadBuffer.ColumnStream(connector, startCancellableOperations: false); - } + internal ReplicationValue(NpgsqlConnector connector) => _readBuffer = connector.ReadBuffer; internal void Reset(TupleDataKind kind, int length, FieldDescription fieldDescription) { Kind = kind; Length = length; _fieldDescription = fieldDescription; - _columnConsumed = false; + _lastInfo = default; + _isConsumed = false; } // ReSharper disable once InconsistentNaming @@ -93,13 +86,16 @@ public bool IsUnchangedToastedValue /// public ValueTask Get(CancellationToken cancellationToken = default) { - CheckAndMarkConsumed(); + CheckActive(); + + ref var info = ref _lastInfo; + _fieldDescription.GetInfo(typeof(T), ref info); switch (Kind) { case TupleDataKind.Null: // When T is a Nullable (and only in that case), we support returning null - if (NullableHandler.Exists) + if (default(T) is null && typeof(T).IsValueType) return default!; if (typeof(T) == typeof(object)) @@ -114,36 +110,19 @@ public 
ValueTask Get(CancellationToken cancellationToken = default) } using (NoSynchronizationContextScope.Enter()) - return GetCore(cancellationToken); + return GetCore(info, _fieldDescription.DataFormat, _readBuffer, Length, cancellationToken); - async ValueTask GetCore(CancellationToken cancellationToken) + static async ValueTask GetCore(PgConverterInfo info, DataFormat format, NpgsqlReadBuffer buffer, int length, CancellationToken cancellationToken) { - using var tokenRegistration = _readBuffer.ReadBytesLeft < Length - ? _readBuffer.Connector.StartNestedCancellableOperation(cancellationToken) - : default; - - var position = _readBuffer.ReadPosition; - - try - { - return NullableHandler.Exists - ? await NullableHandler.ReadAsync(_fieldDescription.Handler, _readBuffer, Length, async: true, _fieldDescription) - : typeof(T) == typeof(object) - ? (T)await _fieldDescription.Handler.ReadAsObject(_readBuffer, Length, async: true, _fieldDescription) - : await _fieldDescription.Handler.Read(_readBuffer, Length, async: true, _fieldDescription); - } - catch - { - if (_readBuffer.Connector.State != ConnectorState.Broken) - { - var writtenBytes = _readBuffer.ReadPosition - position; - var remainingBytes = Length - writtenBytes; - if (remainingBytes > 0) - _readBuffer.Skip(remainingBytes, false).GetAwaiter().GetResult(); - } - - throw; - } + using var registration = buffer.Connector.StartNestedCancellableOperation(cancellationToken, attemptPgCancellation: false); + + var reader = buffer.PgReader.Init(length, format); + await reader.StartReadAsync(info.BufferRequirement, cancellationToken); + var result = info.AsObject + ? (T)await info.Converter.ReadAsObjectAsync(reader, cancellationToken) + : await info.GetConverter().ReadAsync(reader, cancellationToken); + await reader.EndReadAsync(); + return result; } } @@ -154,56 +133,38 @@ async ValueTask GetCore(CancellationToken cancellationToken) /// An optional token to cancel the asynchronous operation. The default value is . 
/// /// - public ValueTask Get(CancellationToken cancellationToken = default) + public ValueTask Get(CancellationToken cancellationToken = default) => Get(cancellationToken); + + /// + /// Retrieves data as a . + /// + public Stream GetStream() { - CheckAndMarkConsumed(); + CheckActive(); switch (Kind) { case TupleDataKind.Null: - return new ValueTask(DBNull.Value); + ThrowHelper.ThrowInvalidCastException_NoValue(_fieldDescription); + break; case TupleDataKind.UnchangedToastedValue: - throw new InvalidCastException( - $"Column '{_fieldDescription.Name}' is an unchanged TOASTed value (actual value not sent)."); + throw new InvalidCastException($"Column '{_fieldDescription.Name}' is an unchanged TOASTed value (actual value not sent)."); } - using (NoSynchronizationContextScope.Enter()) - return GetCore(cancellationToken); - - async ValueTask GetCore(CancellationToken cancellationToken) - { - using var tokenRegistration = _readBuffer.ReadBytesLeft < Length - ? _readBuffer.Connector.StartNestedCancellableOperation(cancellationToken) - : default; - - var position = _readBuffer.ReadPosition; - - try - { - return await _fieldDescription.Handler.ReadAsObject(_readBuffer, Length, async: true, _fieldDescription); - } - catch - { - if (_readBuffer.Connector.State != ConnectorState.Broken) - { - var writtenBytes = _readBuffer.ReadPosition - position; - var remainingBytes = Length - writtenBytes; - if (remainingBytes > 0) - _readBuffer.Skip(remainingBytes, false).GetAwaiter().GetResult(); - } - - throw; - } - } + var reader = _readBuffer.PgReader.Init(Length, _fieldDescription.DataFormat); + return reader.GetStream(canSeek: false); } /// - /// Retrieves data as a . + /// Retrieves data as a . 
/// - public Stream GetStream() + public TextReader GetTextReader() { - CheckAndMarkConsumed(); + CheckActive(); + + ref var info = ref _lastInfo; + _fieldDescription.GetInfo(typeof(TextReader), ref info); switch (Kind) { @@ -215,44 +176,29 @@ public Stream GetStream() throw new InvalidCastException($"Column '{_fieldDescription.Name}' is an unchanged TOASTed value (actual value not sent)."); } - _columnStream.Init(Length, canSeek: false); - return _columnStream; + var reader = PgReader.Init(Length, _fieldDescription.DataFormat); + reader.StartRead(info.BufferRequirement); + var result = (TextReader)info.Converter.ReadAsObject(reader); + reader.EndRead(); + return result; } - /// - /// Retrieves data as a . - /// - public TextReader GetTextReader() - => _fieldDescription.Handler is ITextReaderHandler handler - ? handler.GetTextReader(GetStream(), _readBuffer) - : throw new InvalidCastException( - $"The GetTextReader method is not supported for type {_fieldDescription.Handler.PgDisplayName}"); - internal async Task Consume(CancellationToken cancellationToken) { - if (!_columnStream.IsDisposed) - await _columnStream.DisposeAsync(); + if (_isConsumed) + return; - if (!_columnConsumed) - { - if (_readBuffer.ReadBytesLeft < 4) - { - using var tokenRegistration = _readBuffer.Connector.StartNestedCancellableOperation(cancellationToken); - await _readBuffer.Skip(Length, async: true); - } - else - { - await _readBuffer.Skip(Length, async: true); - } - } + if (!PgReader.Initialized) + PgReader.Init(Length, _fieldDescription.DataFormat); + await PgReader.ConsumeAsync(cancellationToken: cancellationToken); + await PgReader.Commit(async: true, resuming: false); - _columnConsumed = true; + _isConsumed = true; } - void CheckAndMarkConsumed() + void CheckActive() { - if (_columnConsumed) + if (PgReader.Initialized) throw new InvalidOperationException("Column has already been consumed"); - _columnConsumed = true; } -} \ No newline at end of file +} diff --git 
a/src/Npgsql/Replication/PgOutput/TupleEnumerator.cs b/src/Npgsql/Replication/PgOutput/TupleEnumerator.cs index 95e5bfe293..dc54a92515 100644 --- a/src/Npgsql/Replication/PgOutput/TupleEnumerator.cs +++ b/src/Npgsql/Replication/PgOutput/TupleEnumerator.cs @@ -4,7 +4,6 @@ using System.Threading.Tasks; using Npgsql.BackendMessages; using Npgsql.Internal; -using Npgsql.Replication.PgOutput.Messages; namespace Npgsql.Replication.PgOutput; @@ -64,11 +63,7 @@ async ValueTask MoveNextCore() break; case TupleDataKind.TextValue: case TupleDataKind.BinaryValue: - if (_readBuffer.ReadBytesLeft < 4) - { - using var tokenRegistration = _readBuffer.Connector.StartNestedCancellableOperation(_cancellationToken); - await _readBuffer.Ensure(4, async: true); - } + await _readBuffer.Ensure(4, async: true); len = _readBuffer.ReadInt32(); break; default: @@ -96,4 +91,4 @@ public async ValueTask DisposeAsync() _tupleEnumerable.State = RowState.Consumed; } -} \ No newline at end of file +} diff --git a/src/Npgsql/Replication/ReplicationConnection.cs b/src/Npgsql/Replication/ReplicationConnection.cs index 903e6b7b28..5b0381afb0 100644 --- a/src/Npgsql/Replication/ReplicationConnection.cs +++ b/src/Npgsql/Replication/ReplicationConnection.cs @@ -11,7 +11,6 @@ using System.Threading.Tasks; using Microsoft.Extensions.Logging; using Npgsql.Internal; -using Npgsql.Internal.TypeHandlers.DateTimeHandlers; using static Npgsql.Util.Statics; using Npgsql.Util; @@ -242,6 +241,8 @@ public async Task Open(CancellationToken cancellationToken = default) SetTimeouts(CommandTimeout, CommandTimeout); + _npgsqlConnection.Connector!.LongRunningConnection = true; + ReplicationLogger = _npgsqlConnection.Connector!.LoggingConfiguration.ReplicationLogger; } @@ -449,7 +450,7 @@ internal async IAsyncEnumerator StartReplicationInternal( _replicationCancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - using var _ = Connector.StartUserAction( + using var _ = 
connector.StartUserAction( ConnectorState.Replication, _replicationCancellationTokenSource.Token, attemptPgCancellation: _pgCancellationSupported); NpgsqlReadBuffer.ColumnStream? columnStream = null; @@ -474,8 +475,7 @@ internal async IAsyncEnumerator StartReplicationInternal( var buf = connector.ReadBuffer; - // Cancellation is handled at the replication level - we don't want every ReadAsync - columnStream = new NpgsqlReadBuffer.ColumnStream(connector, startCancellableOperations: false); + columnStream = new NpgsqlReadBuffer.ColumnStream(connector); SetTimeouts(_walReceiverTimeout, CommandTimeout); @@ -484,7 +484,7 @@ internal async IAsyncEnumerator StartReplicationInternal( while (true) { - msg = await Connector.ReadMessage(async: true); + msg = await connector.ReadMessage(async: true); Expect(msg, Connector); // We received some message so there's no need to forcibly request feedback @@ -501,7 +501,7 @@ internal async IAsyncEnumerator StartReplicationInternal( await buf.EnsureAsync(24); var startLsn = buf.ReadUInt64(); var endLsn = buf.ReadUInt64(); - var sendTime = DateTimeUtils.DecodeTimestamp(buf.ReadInt64(), DateTimeKind.Utc); + var sendTime = PgDateTime.DecodeTimestamp(buf.ReadInt64(), DateTimeKind.Utc); if (unchecked((ulong)Interlocked.Read(ref _lastReceivedLsn)) < startLsn) Interlocked.Exchange(ref _lastReceivedLsn, unchecked((long)startLsn)); @@ -510,7 +510,7 @@ internal async IAsyncEnumerator StartReplicationInternal( // dataLen = msg.Length - (code = 1 + walStart = 8 + walEnd = 8 + serverClock = 8) var dataLen = messageLength - 25; - columnStream.Init(dataLen, canSeek: false); + columnStream.Init(dataLen, canSeek: false, commandScoped: false); _cachedXLogDataMessage.Populate(new NpgsqlLogSequenceNumber(startLsn), new NpgsqlLogSequenceNumber(endLsn), sendTime, columnStream); @@ -519,7 +519,7 @@ internal async IAsyncEnumerator StartReplicationInternal( // Our consumer may not have read the stream to the end, but it might as well have been us // ourselves 
bypassing the stream and reading directly from the buffer in StartReplication() if (!columnStream.IsDisposed && columnStream.Position < columnStream.Length && !bypassingStream) - await buf.Skip(columnStream.Length - columnStream.Position, true); + await buf.Skip(checked((int)(columnStream.Length - columnStream.Position)), true); continue; } @@ -532,7 +532,7 @@ internal async IAsyncEnumerator StartReplicationInternal( if (ReplicationLogger.IsEnabled(LogLevel.Trace)) { var endLsn = new NpgsqlLogSequenceNumber(end); - var timestamp = DateTimeUtils.DecodeTimestamp(buf.ReadInt64(), DateTimeKind.Utc); + var timestamp = PgDateTime.DecodeTimestamp(buf.ReadInt64(), DateTimeKind.Utc); LogMessages.ReceivedReplicationPrimaryKeepalive(ReplicationLogger, endLsn, timestamp, Connector.Id); } else @@ -679,7 +679,7 @@ async Task SendFeedback(bool waitOnSemaphore = false, bool requestReply = false, buf.WriteInt64(lastReceivedLsn); buf.WriteInt64(lastFlushedLsn); buf.WriteInt64(lastAppliedLsn); - buf.WriteInt64(DateTimeUtils.EncodeTimestamp(timestamp)); + buf.WriteInt64(PgDateTime.EncodeTimestamp(timestamp)); buf.WriteByte(requestReply ? 
(byte)1 : (byte)0); await connector.Flush(async: true, cancellationToken); diff --git a/src/Npgsql/Schema/DbColumnSchemaGenerator.cs b/src/Npgsql/Schema/DbColumnSchemaGenerator.cs index 835cfbe424..88cc775e00 100644 --- a/src/Npgsql/Schema/DbColumnSchemaGenerator.cs +++ b/src/Npgsql/Schema/DbColumnSchemaGenerator.cs @@ -8,9 +8,10 @@ using System.Transactions; using Npgsql.BackendMessages; using Npgsql.Internal; -using Npgsql.Internal.TypeHandlers; -using Npgsql.Internal.TypeHandlers.CompositeHandlers; +using Npgsql.Internal.Postgres; +using Npgsql.PostgresTypes; using Npgsql.Util; +using NpgsqlTypes; namespace Npgsql.Schema; @@ -115,18 +116,18 @@ internal async Task> GetColumnSchema(bool asy .Where(f => f.TableOID != 0) // Only column fields .Select(c => $"(attr.attrelid={c.TableOID} AND attr.attnum={c.ColumnAttributeNumber})") .Join(" OR "); - + if (columnFieldFilter != string.Empty) { var query = oldQueryMode ? GenerateOldColumnsQuery(columnFieldFilter) : GenerateColumnsQuery(_connection.PostgreSqlVersion, columnFieldFilter); - + using var scope = new TransactionScope( TransactionScopeOption.Suppress, async ? TransactionScopeAsyncFlowOption.Enabled : TransactionScopeAsyncFlowOption.Suppress); using var connection = (NpgsqlConnection)((ICloneable)_connection).Clone(); - + await connection.Open(async, cancellationToken); using var cmd = new NpgsqlCommand(query, connection); @@ -135,7 +136,7 @@ internal async Task> GetColumnSchema(bool asy { while (async ? 
await reader.ReadAsync(cancellationToken) : reader.Read()) { - var column = LoadColumnDefinition(reader, _connection.Connector!.TypeMapper.DatabaseInfo, oldQueryMode); + var column = LoadColumnDefinition(reader, _connection.Connector!.DatabaseInfo, oldQueryMode); for (var ordinal = 0; ordinal < numFields; ordinal++) { var field = _rowDescription[ordinal]; @@ -253,19 +254,16 @@ NpgsqlDbColumn SetUpNonColumnField(FieldDescription field) /// void ColumnPostConfig(NpgsqlDbColumn column, int typeModifier) { - var typeMapper = _connection.Connector!.TypeMapper; - - column.NpgsqlDbType = typeMapper.GetTypeInfoByOid(column.TypeOID).npgsqlDbType; - column.DataType = typeMapper.TryResolveByOID(column.TypeOID, out var handler) - ? handler.GetFieldType() - : null; + var serializerOptions = _connection.Connector!.SerializerOptions; - if (column.DataType != null) + column.NpgsqlDbType = column.PostgresType.DataTypeName.ToNpgsqlDbType(); + if (serializerOptions.GetObjectOrDefaultTypeInfo(column.PostgresType) is { } typeInfo) { - column.IsLong = handler is ByteaHandler; + column.DataType = typeInfo.Type; + column.IsLong = column.PostgresType.DataTypeName == DataTypeNames.Bytea; - if (handler is ICompositeHandler) - column.UdtAssemblyQualifiedName = column.DataType.AssemblyQualifiedName; + if (column.PostgresType is PostgresCompositeType) + column.UdtAssemblyQualifiedName = typeInfo.Type.AssemblyQualifiedName; } var facets = column.PostgresType.GetFacets(typeModifier); @@ -276,4 +274,4 @@ void ColumnPostConfig(NpgsqlDbColumn column, int typeModifier) if (facets.Scale != null) column.NumericScale = facets.Scale; } -} \ No newline at end of file +} diff --git a/src/Npgsql/Shims/ConcurrentDictionaryExtensions.cs b/src/Npgsql/Shims/ConcurrentDictionaryExtensions.cs index c752cf2199..02f5c2077c 100644 --- a/src/Npgsql/Shims/ConcurrentDictionaryExtensions.cs +++ b/src/Npgsql/Shims/ConcurrentDictionaryExtensions.cs @@ -1,6 +1,3 @@ -using System; -using System.Collections.Concurrent; - 
namespace System.Collections.Concurrent; #if NETSTANDARD2_0 diff --git a/src/Npgsql/Shims/MemoryExtensions.cs b/src/Npgsql/Shims/MemoryExtensions.cs new file mode 100644 index 0000000000..6247c6a21e --- /dev/null +++ b/src/Npgsql/Shims/MemoryExtensions.cs @@ -0,0 +1,18 @@ +#if !NET7_0_OR_GREATER +namespace System; + +static class MemoryExtensions +{ + public static int IndexOfAnyExcept(this ReadOnlySpan span, T value0, T value1) where T : IEquatable + { + for (var i = 0; i < span.Length; i++) + { + var v = span[i]; + if (!v.Equals(value0) && !v.Equals(value1)) + return i; + } + + return -1; + } +} +#endif diff --git a/src/Npgsql/Shims/ReadOnlySequenceExtensions.cs b/src/Npgsql/Shims/ReadOnlySequenceExtensions.cs new file mode 100644 index 0000000000..0370285a7d --- /dev/null +++ b/src/Npgsql/Shims/ReadOnlySequenceExtensions.cs @@ -0,0 +1,13 @@ +namespace System.Buffers; + +static class ReadOnlySequenceExtensions +{ + public static ReadOnlySpan GetFirstSpan(this ReadOnlySequence sequence) + { +#if NETSTANDARD + return sequence.First.Span; +# else + return sequence.FirstSpan; +#endif + } +} diff --git a/src/Npgsql/Shims/ReadOnlySpanOfCharExtensions.cs b/src/Npgsql/Shims/ReadOnlySpanOfCharExtensions.cs index c805e984a5..11a70c9793 100644 --- a/src/Npgsql/Shims/ReadOnlySpanOfCharExtensions.cs +++ b/src/Npgsql/Shims/ReadOnlySpanOfCharExtensions.cs @@ -1,7 +1,5 @@ using System; -using System.Collections.Generic; using System.Runtime.CompilerServices; -using System.Text; namespace Npgsql.Netstandard20; diff --git a/src/Npgsql/Shims/ReferenceEqualityComparer.cs b/src/Npgsql/Shims/ReferenceEqualityComparer.cs new file mode 100644 index 0000000000..38515ed90f --- /dev/null +++ b/src/Npgsql/Shims/ReferenceEqualityComparer.cs @@ -0,0 +1,48 @@ +using System.Runtime.CompilerServices; + +namespace System.Collections.Generic; + +#if NETSTANDARD +sealed class ReferenceEqualityComparer : IEqualityComparer, IEqualityComparer +{ + ReferenceEqualityComparer() { } + + /// + /// Gets the 
singleton instance. + /// + public static ReferenceEqualityComparer Instance { get; } = new(); + + /// + /// Determines whether two object references refer to the same object instance. + /// + /// The first object to compare. + /// The second object to compare. + /// + /// if both and refer to the same object instance + /// or if both are ; otherwise, . + /// + /// + /// This API is a wrapper around . + /// It is not necessarily equivalent to calling . + /// + public new bool Equals(object? x, object? y) => ReferenceEquals(x, y); + + /// + /// Returns a hash code for the specified object. The returned hash code is based on the object + /// identity, not on the contents of the object. + /// + /// The object for which to retrieve the hash code. + /// A hash code for the identity of . + /// + /// This API is a wrapper around . + /// It is not necessarily equivalent to calling . + /// + public int GetHashCode(object? obj) + { + // Depending on target framework, RuntimeHelpers.GetHashCode might not be annotated + // with the proper nullability attribute. We'll suppress any warning that might + // result. 
+ return RuntimeHelpers.GetHashCode(obj!); + } +} +#endif diff --git a/src/Npgsql/Shims/StreamExtensions.cs b/src/Npgsql/Shims/StreamExtensions.cs index 925061870d..5215b02ce0 100644 --- a/src/Npgsql/Shims/StreamExtensions.cs +++ b/src/Npgsql/Shims/StreamExtensions.cs @@ -1,7 +1,9 @@ -#if NETSTANDARD2_0 +#if NETSTANDARD2_0 || !NET7_0_OR_GREATER using System.Buffers; +using System.Diagnostics; using System.Threading; using System.Threading.Tasks; +using Npgsql; // ReSharper disable once CheckNamespace namespace System.IO @@ -9,6 +11,33 @@ namespace System.IO // Helpers to read/write Span/Memory to Stream before netstandard 2.1 static class StreamExtensions { + public static void ReadExactly(this Stream stream, Span buffer) + { + var totalRead = 0; + while (totalRead < buffer.Length) + { + var read = stream.Read(buffer.Slice(totalRead)); + if (read is 0) + throw new EndOfStreamException(); + + totalRead += read; + } + } + + public static async ValueTask ReadExactlyAsync(this Stream stream, Memory buffer, CancellationToken cancellationToken = default) + { + var totalRead = 0; + while (totalRead < buffer.Length) + { + var read = await stream.ReadAsync(buffer.Slice(totalRead), cancellationToken).ConfigureAwait(false); + if (read is 0) + throw new EndOfStreamException(); + + totalRead += read; + } + } + +#if NETSTANDARD2_0 public static int Read(this Stream stream, Span buffer) { var sharedBuffer = ArrayPool.Shared.Rent(buffer.Length); @@ -66,6 +95,7 @@ public static async ValueTask WriteAsync(this Stream stream, ReadOnlyMemory.Shared.Return(sharedBuffer); } } +#endif } } #endif diff --git a/src/Npgsql/Shims/UnreachableException.cs b/src/Npgsql/Shims/UnreachableException.cs new file mode 100644 index 0000000000..c45f3fd1d8 --- /dev/null +++ b/src/Npgsql/Shims/UnreachableException.cs @@ -0,0 +1,41 @@ +#if !NET7_0_OR_GREATER +using System; + +namespace System.Diagnostics; + +/// +/// Exception thrown when the program executes an instruction that was thought to be 
unreachable. +/// +sealed class UnreachableException : Exception +{ + /// + /// Initializes a new instance of the class with the default error message. + /// + public UnreachableException() + : base("The program executed an instruction that was thought to be unreachable.") + { + } + + /// + /// Initializes a new instance of the + /// class with a specified error message. + /// + /// The error message that explains the reason for the exception. + public UnreachableException(string? message) + : base(message) + { + } + + /// + /// Initializes a new instance of the + /// class with a specified error message and a reference to the inner exception that is the cause of + /// this exception. + /// + /// The error message that explains the reason for the exception. + /// The exception that is the cause of the current exception. + public UnreachableException(string? message, Exception? innerException) + : base(message, innerException) + { + } +} +#endif diff --git a/src/Npgsql/ThrowHelper.cs b/src/Npgsql/ThrowHelper.cs index 57eaa5cc42..d6666bd130 100644 --- a/src/Npgsql/ThrowHelper.cs +++ b/src/Npgsql/ThrowHelper.cs @@ -52,6 +52,14 @@ internal static void ThrowInvalidCastException(string message, object argument) internal static void ThrowInvalidCastException_NoValue(FieldDescription field) => throw new InvalidCastException($"Column '{field.Name}' is null."); + [DoesNotReturn] + internal static void ThrowInvalidCastException(string message) => + throw new InvalidCastException(message); + + [DoesNotReturn] + internal static void ThrowInvalidCastException_NoValue() => + throw new InvalidCastException("Field is null."); + [DoesNotReturn] internal static void ThrowArgumentOutOfRange_OutOfColumnBounds(string paramName, int columnLength) => throw new ArgumentOutOfRangeException(paramName, $"The value is out of bounds from the column data, dataOffset must be between 0 and {columnLength}"); @@ -96,6 +104,10 @@ internal static void ThrowArgumentException(string message, string 
paramName) internal static void ThrowArgumentNullException(string paramName) => throw new ArgumentNullException(paramName); + [DoesNotReturn] + internal static void ThrowArgumentNullException(string message, string paramName) + => throw new ArgumentNullException(paramName, message); + [DoesNotReturn] internal static void ThrowIndexOutOfRangeException(string message) => throw new IndexOutOfRangeException(message); diff --git a/src/Npgsql/TypeMapping/BuiltInTypeHandlerResolver.cs b/src/Npgsql/TypeMapping/BuiltInTypeHandlerResolver.cs deleted file mode 100644 index fcdbb626d1..0000000000 --- a/src/Npgsql/TypeMapping/BuiltInTypeHandlerResolver.cs +++ /dev/null @@ -1,449 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Collections.Specialized; -using System.IO; -using System.Net; -using System.Net.NetworkInformation; -using System.Numerics; -using Npgsql.Internal; -using Npgsql.Internal.TypeHandlers; -using Npgsql.Internal.TypeHandlers.DateTimeHandlers; -using Npgsql.Internal.TypeHandlers.FullTextSearchHandlers; -using Npgsql.Internal.TypeHandlers.GeometricHandlers; -using Npgsql.Internal.TypeHandlers.InternalTypeHandlers; -using Npgsql.Internal.TypeHandlers.LTreeHandlers; -using Npgsql.Internal.TypeHandlers.NetworkHandlers; -using Npgsql.Internal.TypeHandlers.NumericHandlers; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; -using Npgsql.Properties; -using NpgsqlTypes; -using static Npgsql.Util.Statics; - -namespace Npgsql.TypeMapping; - -sealed class BuiltInTypeHandlerResolver : TypeHandlerResolver -{ - readonly NpgsqlConnector _connector; - readonly NpgsqlDatabaseInfo _databaseInfo; - - #region Cached handlers - - // Numeric types - readonly Int16Handler _int16Handler; - readonly Int32Handler _int32Handler; - readonly Int64Handler _int64Handler; - SingleHandler? _singleHandler; - readonly DoubleHandler _doubleHandler; - readonly NumericHandler _numericHandler; - MoneyHandler? 
_moneyHandler; - - // Text types - readonly TextHandler _textHandler; - TextHandler? _xmlHandler; - TextHandler? _varcharHandler; - TextHandler? _charHandler; - TextHandler? _nameHandler; - TextHandler? _refcursorHandler; - TextHandler? _citextHandler; - - // Note that old versions of PG - as well as some PG-like databases (Redshift, CockroachDB) don't have json/jsonb, so we create - // these handlers lazily rather than eagerly. - JsonTextHandler? _jsonbHandler; - JsonTextHandler? _jsonHandler; - JsonPathHandler? _jsonPathHandler; - - // Date/time types - readonly TimestampHandler _timestampHandler; - readonly TimestampTzHandler _timestampTzHandler; - readonly DateHandler _dateHandler; - TimeHandler? _timeHandler; - TimeTzHandler? _timeTzHandler; - IntervalHandler? _intervalHandler; - - // Network types - CidrHandler? _cidrHandler; - InetHandler? _inetHandler; - MacaddrHandler? _macaddrHandler; - MacaddrHandler? _macaddr8Handler; - - // Geometry types - BoxHandler? _boxHandler; - CircleHandler? _circleHandler; - LineHandler? _lineHandler; - LineSegmentHandler? _lineSegmentHandler; - PathHandler? _pathHandler; - PointHandler? _pointHandler; - PolygonHandler? _polygonHandler; - - // LTree types - LQueryHandler? _lQueryHandler; - LTreeHandler? _lTreeHandler; - LTxtQueryHandler? _lTxtQueryHandler; - - // UInt types - UInt32Handler? _oidHandler; - UInt32Handler? _xidHandler; - UInt64Handler? _xid8Handler; - UInt32Handler? _cidHandler; - UInt32Handler? _regtypeHandler; - UInt32Handler? _regconfigHandler; - - // Misc types - readonly BoolHandler _boolHandler; - ByteaHandler? _byteaHandler; - UuidHandler? _uuidHandler; - BitStringHandler? _bitVaryingHandler; - BitStringHandler? _bitHandler; - VoidHandler? _voidHandler; - HstoreHandler? _hstoreHandler; - - // Internal types - Int2VectorHandler? _int2VectorHandler; - OIDVectorHandler? _oidVectorHandler; - PgLsnHandler? _pgLsnHandler; - TidHandler? _tidHandler; - InternalCharHandler? 
_internalCharHandler; - - // Special types - UnknownTypeHandler? _unknownHandler; - - // Complex type handlers over timestamp/timestamptz (because DateTime is value-dependent) - NpgsqlTypeHandler? _timestampArrayHandler; - NpgsqlTypeHandler? _timestampTzArrayHandler; - - #endregion Cached handlers - - internal BuiltInTypeHandlerResolver(NpgsqlConnector connector) - { - _connector = connector; - _databaseInfo = connector.DatabaseInfo; - - // Eagerly instantiate some handlers for very common types so we don't need to check later - _int16Handler = new Int16Handler(PgType("smallint")); - _int32Handler = new Int32Handler(PgType("integer")); - _int64Handler = new Int64Handler(PgType("bigint")); - _doubleHandler = new DoubleHandler(PgType("double precision")); - _numericHandler = new NumericHandler(PgType("numeric")); - _textHandler ??= new TextHandler(PgType("text"), _connector.TextEncoding); - _timestampHandler ??= new TimestampHandler(PgType("timestamp without time zone")); - _timestampTzHandler ??= new TimestampTzHandler(PgType("timestamp with time zone")); - _dateHandler ??= new DateHandler(PgType("date")); - _boolHandler ??= new BoolHandler(PgType("boolean")); - } - - public override NpgsqlTypeHandler? 
ResolveByDataTypeName(string typeName) - => typeName switch - { - // Numeric types - "smallint" => _int16Handler, - "integer" or "int" => _int32Handler, - "bigint" => _int64Handler, - "real" => SingleHandler(), - "double precision" => _doubleHandler, - "numeric" or "decimal" => _numericHandler, - "money" => MoneyHandler(), - - // Text types - "text" => _textHandler, - "xml" => XmlHandler(), - "varchar" or "character varying" => VarcharHandler(), - "character" => CharHandler(), - "name" => NameHandler(), - "refcursor" => RefcursorHandler(), - "citext" => CitextHandler(), - "jsonb" => JsonbHandler(), - "json" => JsonHandler(), - "jsonpath" => JsonPathHandler(), - - // Date/time types - "timestamp" or "timestamp without time zone" => _timestampHandler, - "timestamptz" or "timestamp with time zone" => _timestampTzHandler, - "date" => _dateHandler, - "time without time zone" => TimeHandler(), - "time with time zone" => TimeTzHandler(), - "interval" => IntervalHandler(), - - // Network types - "cidr" => CidrHandler(), - "inet" => InetHandler(), - "macaddr" => MacaddrHandler(), - "macaddr8" => Macaddr8Handler(), - - // Geometry types - "box" => BoxHandler(), - "circle" => CircleHandler(), - "line" => LineHandler(), - "lseg" => LineSegmentHandler(), - "path" => PathHandler(), - "point" => PointHandler(), - "polygon" => PolygonHandler(), - - // LTree types - "lquery" => LQueryHandler(), - "ltree" => LTreeHandler(), - "ltxtquery" => LTxtHandler(), - - // UInt types - "oid" => OidHandler(), - "xid" => XidHandler(), - "xid8" => Xid8Handler(), - "cid" => CidHandler(), - "regtype" => RegtypeHandler(), - "regconfig" => RegconfigHandler(), - - // Misc types - "bool" or "boolean" => _boolHandler, - "bytea" => ByteaHandler(), - "uuid" => UuidHandler(), - "bit varying" or "varbit" => BitVaryingHandler(), - "bit" => BitHandler(), - "hstore" => HstoreHandler(), - - // Internal types - "int2vector" => Int2VectorHandler(), - "oidvector" => OidVectorHandler(), - "pg_lsn" => 
PgLsnHandler(), - "tid" => TidHandler(), - "char" => InternalCharHandler(), - "void" => VoidHandler(), - - "unknown" => UnknownHandler(), - - // Types that are unsupported by default when using NpgsqlSlimDataSourceBuilder - "record" => UnsupportedRecordHandler(), - "tsvector" => UnsupportedTsVectorHandler(), - "tsquery" => UnsupportedTsQueryHandler(), - - _ => null - }; - - public override NpgsqlTypeHandler? ResolveByClrType(Type type) - { - if (BuiltInTypeMappingResolver.ClrTypeToDataTypeName(type) is { } dataTypeName) - return ResolveByDataTypeName(dataTypeName); - - if (type.IsSubclassOf(typeof(Stream))) - return ResolveByDataTypeName("bytea"); - - switch (type.FullName) - { - case "NpgsqlTypes.NpgsqlTsVector": - case "NpgsqlTypes.NpgsqlTsQueryLexeme": - case "NpgsqlTypes.NpgsqlTsQueryAnd": - case "NpgsqlTypes.NpgsqlTsQueryOr": - case "NpgsqlTypes.NpgsqlTsQueryNot": - case "NpgsqlTypes.NpgsqlTsQueryEmpty": - case "NpgsqlTypes.NpgsqlTsQueryFollowedBy": - return UnsupportedTsQueryHandler(); - - default: - return null; - } - } - - public override NpgsqlTypeHandler? ResolveValueDependentValue(object value) - { - // In LegacyTimestampBehavior, DateTime isn't value-dependent, and handled above in ClrTypeToDataTypeNameTable like other types - if (LegacyTimestampBehavior) - return null; - - return value switch - { - DateTime dateTime => dateTime.Kind == DateTimeKind.Utc ? _timestampTzHandler : _timestampHandler, - - // For arrays/lists, return timestamp or timestamptz based on the kind of the first DateTime; if the user attempts to - // mix incompatible Kinds, that will fail during validation. For empty arrays it doesn't matter. - IList array => ArrayHandler(array.Count == 0 ? DateTimeKind.Unspecified : array[0].Kind), - - _ => null - }; - - NpgsqlTypeHandler ArrayHandler(DateTimeKind kind) - => kind == DateTimeKind.Utc - ? 
_timestampTzArrayHandler ??= _timestampTzHandler.CreateArrayHandler( - (PostgresArrayType)PgType("timestamp with time zone[]"), _connector.Settings.ArrayNullabilityMode) - : _timestampArrayHandler ??= _timestampHandler.CreateArrayHandler( - (PostgresArrayType)PgType("timestamp without time zone[]"), _connector.Settings.ArrayNullabilityMode); - } - - public override NpgsqlTypeHandler? ResolveValueTypeGenerically(T value) - { - // This method only ever gets called for value types, and relies on the JIT specializing the method for T by eliding all the - // type checks below. - - // Numeric types - if (typeof(T) == typeof(byte)) - return _int16Handler; - if (typeof(T) == typeof(short)) - return _int16Handler; - if (typeof(T) == typeof(int)) - return _int32Handler; - if (typeof(T) == typeof(long)) - return _int64Handler; - if (typeof(T) == typeof(float)) - return SingleHandler(); - if (typeof(T) == typeof(double)) - return _doubleHandler; - if (typeof(T) == typeof(decimal)) - return _numericHandler; - if (typeof(T) == typeof(BigInteger)) - return _numericHandler; - - // Text types - if (typeof(T) == typeof(char)) - return _textHandler; - if (typeof(T) == typeof(ArraySegment)) - return _textHandler; - - // Date/time types - // No resolution for DateTime, since that's value-dependent (Kind) - if (typeof(T) == typeof(DateTimeOffset)) - return _timestampTzHandler; -#if NET6_0_OR_GREATER - if (typeof(T) == typeof(DateOnly)) - return _dateHandler; - if (typeof(T) == typeof(TimeOnly)) - return _timeHandler; -#endif - if (typeof(T) == typeof(TimeSpan)) - return _intervalHandler; - if (typeof(T) == typeof(NpgsqlInterval)) - return _intervalHandler; - - // Network types - if (typeof(T) == typeof(IPAddress)) - return InetHandler(); - if (typeof(T) == typeof(PhysicalAddress)) - return _macaddrHandler; - if (typeof(T) == typeof(TimeSpan)) - return _intervalHandler; - - // Geometry types - if (typeof(T) == typeof(NpgsqlBox)) - return BoxHandler(); - if (typeof(T) == 
typeof(NpgsqlCircle)) - return CircleHandler(); - if (typeof(T) == typeof(NpgsqlLine)) - return LineHandler(); - if (typeof(T) == typeof(NpgsqlLSeg)) - return LineSegmentHandler(); - if (typeof(T) == typeof(NpgsqlPath)) - return PathHandler(); - if (typeof(T) == typeof(NpgsqlPoint)) - return PointHandler(); - if (typeof(T) == typeof(NpgsqlPolygon)) - return PolygonHandler(); - - // Misc types - if (typeof(T) == typeof(bool)) - return _boolHandler; - if (typeof(T) == typeof(Guid)) - return UuidHandler(); - if (typeof(T) == typeof(BitVector32)) - return BitVaryingHandler(); - - // Internal types - if (typeof(T) == typeof(NpgsqlLogSequenceNumber)) - return PgLsnHandler(); - if (typeof(T) == typeof(NpgsqlTid)) - return TidHandler(); - if (typeof(T) == typeof(DBNull)) - return UnknownHandler(); - - return null; - } - - PostgresType PgType(string pgTypeName) => _databaseInfo.GetPostgresTypeByName(pgTypeName); - - #region Handler accessors - - // Numeric types - NpgsqlTypeHandler SingleHandler() => _singleHandler ??= new SingleHandler(PgType("real")); - NpgsqlTypeHandler MoneyHandler() => _moneyHandler ??= new MoneyHandler(PgType("money")); - - // Text types - NpgsqlTypeHandler XmlHandler() => _xmlHandler ??= new TextHandler(PgType("xml"), _connector.TextEncoding); - NpgsqlTypeHandler VarcharHandler() => _varcharHandler ??= new TextHandler(PgType("character varying"), _connector.TextEncoding); - NpgsqlTypeHandler CharHandler() => _charHandler ??= new TextHandler(PgType("character"), _connector.TextEncoding); - NpgsqlTypeHandler NameHandler() => _nameHandler ??= new TextHandler(PgType("name"), _connector.TextEncoding); - NpgsqlTypeHandler RefcursorHandler() => _refcursorHandler ??= new TextHandler(PgType("refcursor"), _connector.TextEncoding); - NpgsqlTypeHandler? CitextHandler() => _citextHandler ??= _databaseInfo.TryGetPostgresTypeByName("citext", out var pgType) - ? 
new TextHandler(pgType, _connector.TextEncoding) - : null; - NpgsqlTypeHandler JsonbHandler() => _jsonbHandler ??= new JsonTextHandler(PgType("jsonb"), _connector.TextEncoding, isJsonb: true); - NpgsqlTypeHandler JsonHandler() => _jsonHandler ??= new JsonTextHandler(PgType("json"), _connector.TextEncoding, isJsonb: false); - NpgsqlTypeHandler JsonPathHandler() => _jsonPathHandler ??= new JsonPathHandler(PgType("jsonpath"), _connector.TextEncoding); - - // Date/time types - NpgsqlTypeHandler TimeHandler() => _timeHandler ??= new TimeHandler(PgType("time without time zone")); - NpgsqlTypeHandler TimeTzHandler() => _timeTzHandler ??= new TimeTzHandler(PgType("time with time zone")); - NpgsqlTypeHandler IntervalHandler() => _intervalHandler ??= new IntervalHandler(PgType("interval")); - - // Network types - NpgsqlTypeHandler CidrHandler() => _cidrHandler ??= new CidrHandler(PgType("cidr")); - NpgsqlTypeHandler InetHandler() => _inetHandler ??= new InetHandler(PgType("inet")); - NpgsqlTypeHandler MacaddrHandler() => _macaddrHandler ??= new MacaddrHandler(PgType("macaddr")); - NpgsqlTypeHandler Macaddr8Handler() => _macaddr8Handler ??= new MacaddrHandler(PgType("macaddr8")); - - // Geometry types - NpgsqlTypeHandler BoxHandler() => _boxHandler ??= new BoxHandler(PgType("box")); - NpgsqlTypeHandler CircleHandler() => _circleHandler ??= new CircleHandler(PgType("circle")); - NpgsqlTypeHandler LineHandler() => _lineHandler ??= new LineHandler(PgType("line")); - NpgsqlTypeHandler LineSegmentHandler() => _lineSegmentHandler ??= new LineSegmentHandler(PgType("lseg")); - NpgsqlTypeHandler PathHandler() => _pathHandler ??= new PathHandler(PgType("path")); - NpgsqlTypeHandler PointHandler() => _pointHandler ??= new PointHandler(PgType("point")); - NpgsqlTypeHandler PolygonHandler() => _polygonHandler ??= new PolygonHandler(PgType("polygon")); - - // LTree types - NpgsqlTypeHandler? 
LQueryHandler() => _lQueryHandler ??= _databaseInfo.TryGetPostgresTypeByName("lquery", out var pgType) - ? new LQueryHandler(pgType, _connector.TextEncoding) - : null; - NpgsqlTypeHandler? LTreeHandler() => _lTreeHandler ??= _databaseInfo.TryGetPostgresTypeByName("ltree", out var pgType) - ? new LTreeHandler(pgType, _connector.TextEncoding) - : null; - NpgsqlTypeHandler? LTxtHandler() => _lTxtQueryHandler ??= _databaseInfo.TryGetPostgresTypeByName("ltxtquery", out var pgType) - ? new LTxtQueryHandler(pgType, _connector.TextEncoding) - : null; - - // UInt types - NpgsqlTypeHandler OidHandler() => _oidHandler ??= new UInt32Handler(PgType("oid")); - NpgsqlTypeHandler XidHandler() => _xidHandler ??= new UInt32Handler(PgType("xid")); - NpgsqlTypeHandler Xid8Handler() => _xid8Handler ??= new UInt64Handler(PgType("xid8")); - NpgsqlTypeHandler CidHandler() => _cidHandler ??= new UInt32Handler(PgType("cid")); - NpgsqlTypeHandler RegtypeHandler() => _regtypeHandler ??= new UInt32Handler(PgType("regtype")); - NpgsqlTypeHandler RegconfigHandler() => _regconfigHandler ??= new UInt32Handler(PgType("regconfig")); - - // Misc types - NpgsqlTypeHandler ByteaHandler() => _byteaHandler ??= new ByteaHandler(PgType("bytea")); - NpgsqlTypeHandler UuidHandler() => _uuidHandler ??= new UuidHandler(PgType("uuid")); - NpgsqlTypeHandler BitVaryingHandler() => _bitVaryingHandler ??= new BitStringHandler(PgType("bit varying")); - NpgsqlTypeHandler BitHandler() => _bitHandler ??= new BitStringHandler(PgType("bit")); - NpgsqlTypeHandler? HstoreHandler() => _hstoreHandler ??= _databaseInfo.TryGetPostgresTypeByName("hstore", out var pgType) - ? 
new HstoreHandler(pgType, _textHandler) - : null; - - // Internal types - NpgsqlTypeHandler Int2VectorHandler() => _int2VectorHandler ??= new Int2VectorHandler(PgType("int2vector"), PgType("smallint")); - NpgsqlTypeHandler OidVectorHandler() => _oidVectorHandler ??= new OIDVectorHandler(PgType("oidvector"), PgType("oid")); - NpgsqlTypeHandler PgLsnHandler() => _pgLsnHandler ??= new PgLsnHandler(PgType("pg_lsn")); - NpgsqlTypeHandler TidHandler() => _tidHandler ??= new TidHandler(PgType("tid")); - NpgsqlTypeHandler InternalCharHandler() => _internalCharHandler ??= new InternalCharHandler(PgType("char")); - NpgsqlTypeHandler VoidHandler() => _voidHandler ??= new VoidHandler(PgType("void")); - - // Types that are unsupported by default when using NpgsqlSlimDataSourceBuilder - NpgsqlTypeHandler UnsupportedRecordHandler() => new UnsupportedHandler(PgType("record"), string.Format( - NpgsqlStrings.RecordsNotEnabled, nameof(NpgsqlSlimDataSourceBuilder.EnableRecords), nameof(NpgsqlSlimDataSourceBuilder))); - - NpgsqlTypeHandler UnsupportedTsVectorHandler() => new UnsupportedHandler(PgType("tsvector"), string.Format( - NpgsqlStrings.FullTextSearchNotEnabled, nameof(NpgsqlSlimDataSourceBuilder.EnableFullTextSearch), - nameof(NpgsqlSlimDataSourceBuilder))); - - NpgsqlTypeHandler UnsupportedTsQueryHandler() => new UnsupportedHandler(PgType("tsquery"), string.Format( - NpgsqlStrings.FullTextSearchNotEnabled, nameof(NpgsqlSlimDataSourceBuilder.EnableFullTextSearch), - nameof(NpgsqlSlimDataSourceBuilder))); - - NpgsqlTypeHandler UnknownHandler() => _unknownHandler ??= new UnknownTypeHandler(_connector.TextEncoding); - - #endregion Handler accessors -} diff --git a/src/Npgsql/TypeMapping/BuiltInTypeHandlerResolverFactory.cs b/src/Npgsql/TypeMapping/BuiltInTypeHandlerResolverFactory.cs deleted file mode 100644 index 2912b97249..0000000000 --- a/src/Npgsql/TypeMapping/BuiltInTypeHandlerResolverFactory.cs +++ /dev/null @@ -1,13 +0,0 @@ -using Npgsql.Internal; -using 
Npgsql.Internal.TypeHandling; -using Npgsql.Internal.TypeMapping; - -namespace Npgsql.TypeMapping; - -sealed class BuiltInTypeHandlerResolverFactory : TypeHandlerResolverFactory -{ - public override TypeHandlerResolver Create(TypeMapper typeMapper, NpgsqlConnector connector) - => new BuiltInTypeHandlerResolver(connector); - - public override TypeMappingResolver CreateMappingResolver() => new BuiltInTypeMappingResolver(); -} \ No newline at end of file diff --git a/src/Npgsql/TypeMapping/BuiltInTypeMappingResolver.cs b/src/Npgsql/TypeMapping/BuiltInTypeMappingResolver.cs deleted file mode 100644 index 8a236a86f6..0000000000 --- a/src/Npgsql/TypeMapping/BuiltInTypeMappingResolver.cs +++ /dev/null @@ -1,237 +0,0 @@ -using System; -using System.Collections; -using System.Collections.Generic; -using System.Collections.Immutable; -using System.Collections.Specialized; -using System.Net; -using System.Net.NetworkInformation; -using System.Numerics; -using Npgsql.Internal.TypeHandling; -using Npgsql.Internal.TypeMapping; -using NpgsqlTypes; -using static Npgsql.Util.Statics; - -namespace Npgsql.TypeMapping; - -sealed class BuiltInTypeMappingResolver : TypeMappingResolver -{ - static readonly Type ReadOnlyIPAddressType = IPAddress.Loopback.GetType(); - - static readonly Dictionary Mappings = new() - { - // Numeric types - { "smallint", new(NpgsqlDbType.Smallint, "smallint", typeof(short), typeof(byte), typeof(sbyte)) }, - { "integer", new(NpgsqlDbType.Integer, "integer", typeof(int)) }, - { "int", new(NpgsqlDbType.Integer, "integer", typeof(int)) }, - { "bigint", new(NpgsqlDbType.Bigint, "bigint", typeof(long)) }, - { "real", new(NpgsqlDbType.Real, "real", typeof(float)) }, - { "double precision", new(NpgsqlDbType.Double, "double precision", typeof(double)) }, - { "numeric", new(NpgsqlDbType.Numeric, "numeric", typeof(decimal), typeof(BigInteger)) }, - { "decimal", new(NpgsqlDbType.Numeric, "numeric", typeof(decimal), typeof(BigInteger)) }, - { "money", 
new(NpgsqlDbType.Money, "money") }, - - // Text types - { "text", new(NpgsqlDbType.Text, "text", typeof(string), typeof(char[]), typeof(char), typeof(ArraySegment)) }, - { "xml", new(NpgsqlDbType.Xml, "xml") }, - { "character varying", new(NpgsqlDbType.Varchar, "character varying") }, - { "varchar", new(NpgsqlDbType.Varchar, "character varying") }, - { "character", new(NpgsqlDbType.Char, "character") }, - { "name", new(NpgsqlDbType.Name, "name") }, - { "refcursor", new(NpgsqlDbType.Refcursor, "refcursor") }, - { "citext", new(NpgsqlDbType.Citext, "citext") }, - { "jsonb", new(NpgsqlDbType.Jsonb, "jsonb") }, - { "json", new(NpgsqlDbType.Json, "json") }, - { "jsonpath", new(NpgsqlDbType.JsonPath, "jsonpath") }, - - // Date/time types - { "timestamp without time zone", new(NpgsqlDbType.Timestamp, "timestamp without time zone", typeof(DateTime)) }, - { "timestamp", new(NpgsqlDbType.Timestamp, "timestamp without time zone", typeof(DateTime)) }, - { "timestamp with time zone", new(NpgsqlDbType.TimestampTz, "timestamp with time zone", typeof(DateTimeOffset)) }, - { "timestamptz", new(NpgsqlDbType.TimestampTz, "timestamp with time zone", typeof(DateTimeOffset)) }, - { "date", new(NpgsqlDbType.Date, "date" -#if NET6_0_OR_GREATER - , typeof(DateOnly) -#endif - ) }, - { "time without time zone", new(NpgsqlDbType.Time, "time without time zone" -#if NET6_0_OR_GREATER - , typeof(TimeOnly) -#endif - ) }, - { "time", new(NpgsqlDbType.Time, "time without time zone" -#if NET6_0_OR_GREATER - , typeof(TimeOnly) -#endif - ) }, - { "time with time zone", new(NpgsqlDbType.TimeTz, "time with time zone") }, - { "timetz", new(NpgsqlDbType.TimeTz, "time with time zone") }, - { "interval", new(NpgsqlDbType.Interval, "interval", typeof(TimeSpan)) }, - - { "timestamp without time zone[]", new(NpgsqlDbType.Array | NpgsqlDbType.Timestamp, "timestamp without time zone[]") }, - { "timestamp with time zone[]", new(NpgsqlDbType.Array | NpgsqlDbType.TimestampTz, "timestamp with time zone[]") }, - - // 
Network types - { "cidr", new(NpgsqlDbType.Cidr, "cidr") }, -#pragma warning disable 618 - { "inet", new(NpgsqlDbType.Inet, "inet", typeof(IPAddress), typeof((IPAddress Address, int Subnet)), typeof(NpgsqlInet), ReadOnlyIPAddressType) }, -#pragma warning restore 618 - { "macaddr", new(NpgsqlDbType.MacAddr, "macaddr", typeof(PhysicalAddress)) }, - { "macaddr8", new(NpgsqlDbType.MacAddr8, "macaddr8") }, - - // Geometry types - { "box", new(NpgsqlDbType.Box, "box", typeof(NpgsqlBox)) }, - { "circle", new(NpgsqlDbType.Circle, "circle", typeof(NpgsqlCircle)) }, - { "line", new(NpgsqlDbType.Line, "line", typeof(NpgsqlLine)) }, - { "lseg", new(NpgsqlDbType.LSeg, "lseg", typeof(NpgsqlLSeg)) }, - { "path", new(NpgsqlDbType.Path, "path", typeof(NpgsqlPath)) }, - { "point", new(NpgsqlDbType.Point, "point", typeof(NpgsqlPoint)) }, - { "polygon", new(NpgsqlDbType.Polygon, "polygon", typeof(NpgsqlPolygon)) }, - - // LTree types - { "lquery", new(NpgsqlDbType.LQuery, "lquery") }, - { "ltree", new(NpgsqlDbType.LTree, "ltree") }, - { "ltxtquery", new(NpgsqlDbType.LTxtQuery, "ltxtquery") }, - - // UInt types - { "oid", new(NpgsqlDbType.Oid, "oid") }, - { "xid", new(NpgsqlDbType.Xid, "xid") }, - { "xid8", new(NpgsqlDbType.Xid8, "xid8") }, - { "cid", new(NpgsqlDbType.Cid, "cid") }, - { "regtype", new(NpgsqlDbType.Regtype, "regtype") }, - { "regconfig", new(NpgsqlDbType.Regconfig, "regconfig") }, - - // Misc types - { "boolean", new(NpgsqlDbType.Boolean, "boolean", typeof(bool)) }, - { "bool", new(NpgsqlDbType.Boolean, "boolean", typeof(bool)) }, - { "bytea", new(NpgsqlDbType.Bytea, "bytea", typeof(byte[]), typeof(ArraySegment) -#if !NETSTANDARD2_0 - , typeof(ReadOnlyMemory), typeof(Memory) -#endif - ) }, - { "uuid", new(NpgsqlDbType.Uuid, "uuid", typeof(Guid)) }, - { "bit varying", new(NpgsqlDbType.Varbit, "bit varying", typeof(BitArray), typeof(BitVector32)) }, - { "varbit", new(NpgsqlDbType.Varbit, "bit varying", typeof(BitArray), typeof(BitVector32)) }, - { "bit", 
new(NpgsqlDbType.Bit, "bit") }, - { "hstore", new(NpgsqlDbType.Hstore, "hstore", typeof(Dictionary), typeof(IDictionary), typeof(ImmutableDictionary)) }, - - // Internal types - { "int2vector", new(NpgsqlDbType.Int2Vector, "int2vector") }, - { "oidvector", new(NpgsqlDbType.Oidvector, "oidvector") }, - { "pg_lsn", new(NpgsqlDbType.PgLsn, "pg_lsn", typeof(NpgsqlLogSequenceNumber)) }, - { "tid", new(NpgsqlDbType.Tid, "tid", typeof(NpgsqlTid)) }, - { "char", new(NpgsqlDbType.InternalChar, "char") }, - - // Special types - { "unknown", new(NpgsqlDbType.Unknown, "unknown") }, - }; - - static readonly Dictionary ClrTypeToDataTypeNameTable; - - static BuiltInTypeMappingResolver() - { - ClrTypeToDataTypeNameTable = new() - { - // Numeric types - { typeof(byte), "smallint" }, - { typeof(short), "smallint" }, - { typeof(int), "integer" }, - { typeof(long), "bigint" }, - { typeof(float), "real" }, - { typeof(double), "double precision" }, - { typeof(decimal), "decimal" }, - { typeof(BigInteger), "decimal" }, - - // Text types - { typeof(string), "text" }, - { typeof(char[]), "text" }, - { typeof(char), "text" }, - { typeof(ArraySegment), "text" }, - - // Date/time types - // The DateTime entry is for LegacyTimestampBehavior mode only. 
In regular mode we resolve through - // ResolveValueDependentValue below - { typeof(DateTime), "timestamp without time zone" }, - { typeof(DateTimeOffset), "timestamp with time zone" }, -#if NET6_0_OR_GREATER - { typeof(DateOnly), "date" }, - { typeof(TimeOnly), "time without time zone" }, -#endif - { typeof(TimeSpan), "interval" }, - { typeof(NpgsqlInterval), "interval" }, - - // Network types - { typeof(IPAddress), "inet" }, - // See ReadOnlyIPAddress below - { typeof((IPAddress Address, int Subnet)), "inet" }, -#pragma warning disable 618 - { typeof(NpgsqlInet), "inet" }, -#pragma warning restore 618 - { typeof(PhysicalAddress), "macaddr" }, - - // Geometry types - { typeof(NpgsqlBox), "box" }, - { typeof(NpgsqlCircle), "circle" }, - { typeof(NpgsqlLine), "line" }, - { typeof(NpgsqlLSeg), "lseg" }, - { typeof(NpgsqlPath), "path" }, - { typeof(NpgsqlPoint), "point" }, - { typeof(NpgsqlPolygon), "polygon" }, - - // Misc types - { typeof(bool), "boolean" }, - { typeof(byte[]), "bytea" }, - { typeof(ArraySegment), "bytea" }, -#if !NETSTANDARD2_0 - { typeof(ReadOnlyMemory), "bytea" }, - { typeof(Memory), "bytea" }, -#endif - { typeof(Guid), "uuid" }, - { typeof(BitArray), "bit varying" }, - { typeof(BitVector32), "bit varying" }, - { typeof(Dictionary), "hstore" }, - { typeof(ImmutableDictionary), "hstore" }, - - // Internal types - { typeof(NpgsqlLogSequenceNumber), "pg_lsn" }, - { typeof(NpgsqlTid), "tid" }, - { typeof(DBNull), "unknown" } - }; - - // Recent versions of .NET Core have an internal ReadOnlyIPAddress type (returned e.g. for IPAddress.Loopback) - // But older versions don't have it - if (ReadOnlyIPAddressType != typeof(IPAddress)) - ClrTypeToDataTypeNameTable[ReadOnlyIPAddressType] = "inet"; - - if (LegacyTimestampBehavior) - ClrTypeToDataTypeNameTable[typeof(DateTime)] = "timestamp without time zone"; - } - - public override string? GetDataTypeNameByClrType(Type clrType) - => ClrTypeToDataTypeName(clrType); - - internal static string? 
ClrTypeToDataTypeName(Type clrType) - => ClrTypeToDataTypeNameTable.TryGetValue(clrType, out var dataTypeName) ? dataTypeName : null; - - public override string? GetDataTypeNameByValueDependentValue(object value) - { - // In LegacyTimestampBehavior, DateTime isn't value-dependent, and handled above in ClrTypeToDataTypeNameTable like other types - if (LegacyTimestampBehavior) - return null; - - return value switch - { - DateTime dateTime => dateTime.Kind == DateTimeKind.Utc ? "timestamp with time zone" : "timestamp without time zone", - - // For arrays/lists, return timestamp or timestamptz based on the kind of the first DateTime; if the user attempts to - // mix incompatible Kinds, that will fail during validation. For empty arrays it doesn't matter. - IList array => array.Count == 0 - ? "timestamp without time zone[]" - : array[0].Kind == DateTimeKind.Utc ? "timestamp with time zone[]" : "timestamp without time zone[]", - - _ => null - }; - } - - public override TypeMappingInfo? GetMappingByDataTypeName(string dataTypeName) - => Mappings.TryGetValue(dataTypeName, out var mapping) ? mapping : null; -} diff --git a/src/Npgsql/TypeMapping/DefaultPgTypes.cs b/src/Npgsql/TypeMapping/DefaultPgTypes.cs new file mode 100644 index 0000000000..015e338a2c --- /dev/null +++ b/src/Npgsql/TypeMapping/DefaultPgTypes.cs @@ -0,0 +1,191 @@ +using System; +using System.Collections.Generic; +using Npgsql.Internal.Postgres; +using static Npgsql.TypeMapping.PgTypeGroup; + +namespace Npgsql.TypeMapping; + +static class DefaultPgTypes +{ + static IEnumerable> GetIdentifiers() + { + var list = new List>(); + foreach (var group in Items) + { + list.Add(new(group.Oid, group.Name)); + list.Add(new(group.ArrayOid, group.ArrayName)); + if (group.TypeKind is PgTypeKind.Range) + { + list.Add(new(group.MultirangeOid!.Value, group.MultirangeName!.Value)); + list.Add(new(group.MultirangeArrayOid!.Value, group.MultirangeArrayName!.Value)); + } + } + + return list; + } + + static Dictionary? 
_oidMap; + public static IReadOnlyDictionary OidMap + { + get + { + if (_oidMap is not null) + return _oidMap; + + var dict = new Dictionary(); + foreach (var element in GetIdentifiers()) + dict.Add(element.Key, element.Value); + + return _oidMap = dict; + } + } + + static Dictionary? _dataTypeNameMap; + public static IReadOnlyDictionary DataTypeNameMap + { + get + { + if (_dataTypeNameMap is not null) + return _dataTypeNameMap; + + var dict = new Dictionary(); + foreach (var element in GetIdentifiers()) + dict.Add(element.Value, element.Key); + + return _dataTypeNameMap = dict; + } + } + + // We could also codegen this from pg_type.dat that lives in the postgres repo. + public static IEnumerable Items + => new[] + { + Create(DataTypeNames.Int2, oid: 21, arrayOid: 1005), + Create(DataTypeNames.Int4, oid: 23, arrayOid: 1007), + Create(DataTypeNames.Int4Range, oid: 3904, arrayOid: 3905, multirangeOid: 4451, multirangeArrayOid: 6150, typeKind: PgTypeKind.Range), + Create(DataTypeNames.Int8, oid: 20, arrayOid: 1016), + Create(DataTypeNames.Int8Range, oid: 3926, arrayOid: 3927, multirangeOid: 4536, multirangeArrayOid: 6157, typeKind: PgTypeKind.Range), + Create(DataTypeNames.Float4, oid: 700, arrayOid: 1021), + Create(DataTypeNames.Float8, oid: 701, arrayOid: 1022), + Create(DataTypeNames.Numeric, oid: 1700, arrayOid: 1231), + Create(DataTypeNames.NumRange, oid: 3906, arrayOid: 3907, multirangeOid: 4532, multirangeArrayOid: 6151, typeKind: PgTypeKind.Range), + Create(DataTypeNames.Money, oid: 790, arrayOid: 791), + Create(DataTypeNames.Bool, oid: 16, arrayOid: 1000), + Create(DataTypeNames.Box, oid: 603, arrayOid: 1020), + Create(DataTypeNames.Circle, oid: 718, arrayOid: 719), + Create(DataTypeNames.Line, oid: 628, arrayOid: 629), + Create(DataTypeNames.LSeg, oid: 601, arrayOid: 1018), + Create(DataTypeNames.Path, oid: 602, arrayOid: 1019), + Create(DataTypeNames.Point, oid: 600, arrayOid: 1017), + Create(DataTypeNames.Polygon, oid: 604, arrayOid: 1027), + 
Create(DataTypeNames.Bpchar, oid: 1042, arrayOid: 1014), + Create(DataTypeNames.Text, oid: 25, arrayOid: 1009), + Create(DataTypeNames.Varchar, oid: 1043, arrayOid: 1015), + Create(DataTypeNames.Name, oid: 19, arrayOid: 1003), + Create(DataTypeNames.Bytea, oid: 17, arrayOid: 1001), + Create(DataTypeNames.Date, oid: 1082, arrayOid: 1182), + Create(DataTypeNames.DateRange, oid: 3912, arrayOid: 3913, multirangeOid: 4535, multirangeArrayOid: 6155, typeKind: PgTypeKind.Range), + Create(DataTypeNames.Time, oid: 1083, arrayOid: 1183), + Create(DataTypeNames.Timestamp, oid: 1114, arrayOid: 1115), + Create(DataTypeNames.TsRange, oid: 3908, arrayOid: 3909, multirangeOid: 4533, multirangeArrayOid: 6152, typeKind: PgTypeKind.Range), + Create(DataTypeNames.TimestampTz, oid: 1184, arrayOid: 1185), + Create(DataTypeNames.TsTzRange, oid: 3910, arrayOid: 3911, multirangeOid: 4534, multirangeArrayOid: 6153, typeKind: PgTypeKind.Range), + Create(DataTypeNames.Interval, oid: 1186, arrayOid: 1187), + Create(DataTypeNames.TimeTz, oid: 1266, arrayOid: 1270), + Create(DataTypeNames.Inet, oid: 869, arrayOid: 1041), + Create(DataTypeNames.Cidr, oid: 650, arrayOid: 651), + Create(DataTypeNames.MacAddr, oid: 829, arrayOid: 1040), + Create(DataTypeNames.MacAddr8, oid: 774, arrayOid: 775), + Create(DataTypeNames.Bit, oid: 1560, arrayOid: 1561), + Create(DataTypeNames.Varbit, oid: 1562, arrayOid: 1563), + Create(DataTypeNames.TsVector, oid: 3614, arrayOid: 3643), + Create(DataTypeNames.TsQuery, oid: 3615, arrayOid: 3645), + Create(DataTypeNames.RegConfig, oid: 3734, arrayOid: 3735), + Create(DataTypeNames.Uuid, oid: 2950, arrayOid: 2951), + Create(DataTypeNames.Xml, oid: 142, arrayOid: 143), + Create(DataTypeNames.Json, oid: 114, arrayOid: 199), + Create(DataTypeNames.Jsonb, oid: 3802, arrayOid: 3807), + Create(DataTypeNames.Jsonpath, oid: 4072, arrayOid: 4073), + Create(DataTypeNames.RefCursor, oid: 1790, arrayOid: 2201), + Create(DataTypeNames.OidVector, oid: 30, arrayOid: 1013), + 
Create(DataTypeNames.Int2Vector, oid: 22, arrayOid: 1006), + Create(DataTypeNames.Oid, oid: 26, arrayOid: 1028), + Create(DataTypeNames.Xid, oid: 28, arrayOid: 1011), + Create(DataTypeNames.Xid8, oid: 5069, arrayOid: 271), + Create(DataTypeNames.Cid, oid: 29, arrayOid: 1012), + Create(DataTypeNames.RegType, oid: 2206, arrayOid: 2211), + Create(DataTypeNames.Tid, oid: 27, arrayOid: 1010), + Create(DataTypeNames.PgLsn, oid: 3220, arrayOid: 3221), + Create(DataTypeNames.Unknown, oid: 705, arrayOid: 0, typeKind: PgTypeKind.Pseudo), + Create(DataTypeNames.Void, oid: 2278, arrayOid: 0, typeKind: PgTypeKind.Pseudo), + }; +} + +enum PgTypeKind +{ + /// A base type. + Base, + /// An enum carying its variants. + Enum, + /// A pseudo type like anyarray. + Pseudo, + // An array carying its element type. + Array, + // A range carying its element type. + Range, + // A multi-range carying its element type. + Multirange, + // A domain carying its underlying type. + Domain, + // A composite carying its constituent fields. + Composite +} + +readonly struct PgTypeGroup +{ + public required PgTypeKind TypeKind { get; init; } + public required DataTypeName Name { get; init; } + public required Oid Oid { get; init; } + public required DataTypeName ArrayName { get; init; } + public required Oid ArrayOid { get; init; } + public DataTypeName? MultirangeName { get; init; } + public Oid? MultirangeOid { get; init; } + public DataTypeName? MultirangeArrayName { get; init; } + public Oid? MultirangeArrayOid { get; init; } + + public static PgTypeGroup Create(DataTypeName name, Oid oid, Oid arrayOid, string? multirangeName = null, Oid? multirangeOid = null, Oid? multirangeArrayOid = null, PgTypeKind typeKind = PgTypeKind.Base) + { + DataTypeName? 
multirangeDataTypeName = null; + if (typeKind is PgTypeKind.Range) + { + if (multirangeOid is null) + throw new ArgumentException("When a range is supplied its multirange oid cannot be omitted."); + if (multirangeArrayOid is null) + throw new ArgumentException("When a range is supplied its multirange array oid cannot be omitted."); + multirangeDataTypeName = multirangeName is not null ? DataTypeName.CreateFullyQualifiedName(multirangeName) : name.ToDefaultMultirangeName(); + } + else + { + if (multirangeName is not null || multirangeOid is not null) + throw new ArgumentException("Only range types can have a multirange oid or name."); + + if (multirangeArrayOid is not null) + throw new ArgumentException("Only range types can have a multirange array oid."); + } + + return new PgTypeGroup + { + TypeKind = typeKind, + Name = name, + Oid = oid, + + ArrayName = name.ToArrayName(), + ArrayOid = arrayOid, + + MultirangeName = multirangeDataTypeName, + MultirangeOid = multirangeOid, + MultirangeArrayName = multirangeDataTypeName?.ToArrayName(), + MultirangeArrayOid = multirangeArrayOid + }; + } +} diff --git a/src/Npgsql/TypeMapping/FullTextSearchTypeHandlerResolver.cs b/src/Npgsql/TypeMapping/FullTextSearchTypeHandlerResolver.cs deleted file mode 100644 index 38db435814..0000000000 --- a/src/Npgsql/TypeMapping/FullTextSearchTypeHandlerResolver.cs +++ /dev/null @@ -1,34 +0,0 @@ -using System; -using Npgsql.Internal; -using Npgsql.Internal.TypeHandlers.FullTextSearchHandlers; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; - -namespace Npgsql.TypeMapping; - -sealed class FullTextSearchTypeHandlerResolver : TypeHandlerResolver -{ - readonly NpgsqlDatabaseInfo _databaseInfo; - - public FullTextSearchTypeHandlerResolver(NpgsqlConnector connector) - => _databaseInfo = connector.DatabaseInfo; - - TsQueryHandler? _tsQueryHandler; - TsVectorHandler? _tsVectorHandler; - - public override NpgsqlTypeHandler? 
ResolveByDataTypeName(string typeName) => - typeName switch - { - "tsquery" => TsQueryHandler(), - "tsvector" => TsVectorHandler(), - _ => null - }; - - public override NpgsqlTypeHandler? ResolveByClrType(Type type) - => FullTextSearchTypeMappingResolver.ClrTypeToDataTypeName(type) is { } dataTypeName ? ResolveByDataTypeName(dataTypeName) : null; - - NpgsqlTypeHandler TsQueryHandler() => _tsQueryHandler ??= new TsQueryHandler(PgType("tsquery")); - NpgsqlTypeHandler TsVectorHandler() => _tsVectorHandler ??= new TsVectorHandler(PgType("tsvector")); - - PostgresType PgType(string pgTypeName) => _databaseInfo.GetPostgresTypeByName(pgTypeName); -} diff --git a/src/Npgsql/TypeMapping/FullTextSearchTypeHandlerResolverFactory.cs b/src/Npgsql/TypeMapping/FullTextSearchTypeHandlerResolverFactory.cs deleted file mode 100644 index cbfb8a9838..0000000000 --- a/src/Npgsql/TypeMapping/FullTextSearchTypeHandlerResolverFactory.cs +++ /dev/null @@ -1,15 +0,0 @@ -using Npgsql.Internal; -using Npgsql.Internal.TypeHandling; -using Npgsql.Internal.TypeMapping; - -namespace Npgsql.TypeMapping; - -sealed class FullTextSearchTypeHandlerResolverFactory : TypeHandlerResolverFactory -{ - public override TypeHandlerResolver Create(TypeMapper typeMapper, NpgsqlConnector connector) => - new FullTextSearchTypeHandlerResolver(connector); - - public override TypeMappingResolver CreateMappingResolver() => new FullTextSearchTypeMappingResolver(); - - public override TypeMappingResolver CreateGlobalMappingResolver() => new FullTextSearchTypeMappingResolver(); -} diff --git a/src/Npgsql/TypeMapping/FullTextSearchTypeMappingResolver.cs b/src/Npgsql/TypeMapping/FullTextSearchTypeMappingResolver.cs deleted file mode 100644 index 90185578c0..0000000000 --- a/src/Npgsql/TypeMapping/FullTextSearchTypeMappingResolver.cs +++ /dev/null @@ -1,41 +0,0 @@ -using System; -using System.Collections.Generic; -using Npgsql.Internal.TypeHandling; -using Npgsql.Internal.TypeMapping; -using NpgsqlTypes; - -namespace 
Npgsql.TypeMapping; - -sealed class FullTextSearchTypeMappingResolver : TypeMappingResolver -{ - static readonly TypeMappingInfo TsQueryMappingInfo = new(NpgsqlDbType.TsQuery, "tsquery", - typeof(NpgsqlTsQuery), typeof(NpgsqlTsQueryAnd), typeof(NpgsqlTsQueryEmpty), typeof(NpgsqlTsQueryFollowedBy), - typeof(NpgsqlTsQueryLexeme), typeof(NpgsqlTsQueryNot), typeof(NpgsqlTsQueryOr), typeof(NpgsqlTsQueryBinOp)); - - static readonly TypeMappingInfo TsVectorMappingInfo = new(NpgsqlDbType.TsVector, "tsvector", typeof(NpgsqlTsVector)); - - static readonly Dictionary ClrTypeToDataTypeNameTable = new() - { - { typeof(NpgsqlTsVector), "tsvector" }, - { typeof(NpgsqlTsQueryLexeme), "tsquery" }, - { typeof(NpgsqlTsQueryAnd), "tsquery" }, - { typeof(NpgsqlTsQueryOr), "tsquery" }, - { typeof(NpgsqlTsQueryNot), "tsquery" }, - { typeof(NpgsqlTsQueryEmpty), "tsquery" }, - { typeof(NpgsqlTsQueryFollowedBy), "tsquery" }, - }; - - public override TypeMappingInfo? GetMappingByDataTypeName(string dataTypeName) - => dataTypeName switch - { - "tsquery" => TsQueryMappingInfo, - "tsvector" => TsVectorMappingInfo, - _ => null - }; - - public override string? GetDataTypeNameByClrType(Type clrType) - => ClrTypeToDataTypeName(clrType); - - internal static string? ClrTypeToDataTypeName(Type clrType) - => ClrTypeToDataTypeNameTable.TryGetValue(clrType, out var dataTypeName) ? 
dataTypeName : null; -} diff --git a/src/Npgsql/TypeMapping/GlobalTypeMapper.cs b/src/Npgsql/TypeMapping/GlobalTypeMapper.cs index a13abf5ea0..fdaa340bd8 100644 --- a/src/Npgsql/TypeMapping/GlobalTypeMapper.cs +++ b/src/Npgsql/TypeMapping/GlobalTypeMapper.cs @@ -1,654 +1,240 @@ using System; -using System.Collections.Concurrent; using System.Collections.Generic; -using System.Data; using System.Diagnostics.CodeAnalysis; -using System.Linq; -using System.Reflection; using System.Threading; -using Npgsql.Internal.TypeHandling; -using Npgsql.Internal.TypeMapping; -using Npgsql.NameTranslation; -using NpgsqlTypes; -using static Npgsql.Util.Statics; +using Npgsql.Internal; +using Npgsql.Internal.Postgres; namespace Npgsql.TypeMapping; +/// sealed class GlobalTypeMapper : INpgsqlTypeMapper { - public static GlobalTypeMapper Instance { get; } + readonly UserTypeMapper _userTypeMapper = new(); + readonly List _pluginResolvers = new(); + readonly ReaderWriterLockSlim _lock = new(); + IPgTypeInfoResolver[] _typeMappingResolvers = Array.Empty(); - public INpgsqlNameTranslator DefaultNameTranslator { get; set; } = new NpgsqlSnakeCaseNameTranslator(); - - internal List HandlerResolverFactories { get; } = new(); - List MappingResolvers { get; } = new(); - public ConcurrentDictionary UserTypeMappings { get; } = new(); - - readonly ConcurrentDictionary _mappingsByClrType = new(); - - internal ReaderWriterLockSlim Lock { get; } - = new(LockRecursionPolicy.SupportsRecursion); - - static GlobalTypeMapper() - => Instance = new GlobalTypeMapper(); - - GlobalTypeMapper() - => Reset(); - - #region Mapping management - - public INpgsqlTypeMapper MapEnum(string? pgName = null, INpgsqlNameTranslator? 
nameTranslator = null) - where TEnum : struct, Enum + internal IEnumerable GetPluginResolvers() { - if (pgName != null && pgName.Trim() == "") - throw new ArgumentException("pgName can't be empty", nameof(pgName)); - - nameTranslator ??= DefaultNameTranslator; - pgName ??= GetPgName(typeof(TEnum), nameTranslator); - - Lock.EnterWriteLock(); + var resolvers = new List(); + _lock.EnterReadLock(); try { - UserTypeMappings[pgName] = new UserEnumTypeMapping(pgName, nameTranslator); - RecordChange(); - return this; + resolvers.AddRange(_pluginResolvers); } finally { - Lock.ExitWriteLock(); + _lock.ExitReadLock(); } + + return resolvers; } - public bool UnmapEnum(string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) - where TEnum : struct, Enum + internal IPgTypeInfoResolver? GetUserMappingsResolver() { - if (pgName != null && pgName.Trim() == "") - throw new ArgumentException("pgName can't be empty", nameof(pgName)); - - nameTranslator ??= DefaultNameTranslator; - pgName ??= GetPgName(typeof(TEnum), nameTranslator); - - Lock.EnterWriteLock(); + _lock.EnterReadLock(); try { - if (UserTypeMappings.TryRemove(pgName, out _)) - { - RecordChange(); - return true; - } - - return false; + return _userTypeMapper.Items.Count > 0 ? _userTypeMapper.Build() : null; } finally { - Lock.ExitWriteLock(); + _lock.ExitReadLock(); } } - [RequiresUnreferencedCode("Composite type mapping currently isn't trimming-safe.")] - public INpgsqlTypeMapper MapComposite(string? pgName = null, INpgsqlNameTranslator? 
nameTranslator = null) + internal void AddGlobalTypeMappingResolvers(IPgTypeInfoResolver[] resolvers, bool overwrite = false) { - if (pgName != null && pgName.Trim() == "") - throw new ArgumentException("pgName can't be empty", nameof(pgName)); - - nameTranslator ??= DefaultNameTranslator; - pgName ??= GetPgName(typeof(T), nameTranslator); - - Lock.EnterWriteLock(); - try - { - UserTypeMappings[pgName] = new UserCompositeTypeMapping(pgName, nameTranslator); - RecordChange(); - return this; - } - finally + // Good enough logic to prevent SlimBuilder overriding the normal Builder. + if (overwrite || resolvers.Length > _typeMappingResolvers.Length) { - Lock.ExitWriteLock(); + _typeMappingResolvers = resolvers; + ResetTypeMappingCache(); } } - [RequiresUnreferencedCode("Composite type mapping currently isn't trimming-safe.")] - public INpgsqlTypeMapper MapComposite(Type clrType, string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) - { - var openMethod = typeof(GlobalTypeMapper).GetMethod(nameof(MapComposite), new[] { typeof(string), typeof(INpgsqlNameTranslator) })!; - var method = openMethod.MakeGenericMethod(clrType); - method.Invoke(this, new object?[] { pgName, nameTranslator }); + void ResetTypeMappingCache() => _typeMappingOptions = null; - return this; - } - - [RequiresUnreferencedCode("Composite type mapping currently isn't trimming-safe.")] - public bool UnmapComposite(string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) - => UnmapComposite(typeof(T), pgName, nameTranslator); + PgSerializerOptions? _typeMappingOptions; - [RequiresUnreferencedCode("Composite type mapping currently isn't trimming-safe.")] - public bool UnmapComposite(Type clrType, string? pgName = null, INpgsqlNameTranslator? 
nameTranslator = null) + PgSerializerOptions TypeMappingOptions { - if (pgName != null && pgName.Trim() == "") - throw new ArgumentException("pgName can't be empty", nameof(pgName)); + get + { + if (_typeMappingOptions is not null) + return _typeMappingOptions; - nameTranslator ??= DefaultNameTranslator; - pgName ??= GetPgName(clrType, nameTranslator); + _lock.EnterReadLock(); + try + { + var resolvers = new List(); + resolvers.Add(_userTypeMapper.Build()); + resolvers.AddRange(_pluginResolvers); + resolvers.AddRange(_typeMappingResolvers); + return _typeMappingOptions = new(PostgresMinimalDatabaseInfo.DefaultTypeCatalog) + { + // This means we don't ever have a missing oid for a datatypename as our canonical format is datatypenames. + PortableTypeIds = true, + // Don't throw if our catalog doesn't know the datatypename. + IntrospectionMode = true, + TypeInfoResolver = new TypeInfoResolverChain(resolvers) + }; + } + finally + { + _lock.ExitReadLock(); + } + } + } - Lock.EnterWriteLock(); - try - { - if (UserTypeMappings.TryRemove(pgName, out _)) + internal DataTypeName? TryGetDataTypeName(Type type, object value) + { + var typeInfo = TypeMappingOptions.GetTypeInfo(type); + DataTypeName? dataTypeName; + if (typeInfo is PgResolverTypeInfo info) + try + { + dataTypeName = info.GetObjectResolution(value).PgTypeId.DataTypeName; + } + catch { - RecordChange(); - return true; + dataTypeName = null; } + else + dataTypeName = typeInfo?.GetConcreteResolution().PgTypeId.DataTypeName; - return false; - } - finally - { - Lock.ExitWriteLock(); - } + return dataTypeName; } - public void AddTypeResolverFactory(TypeHandlerResolverFactory resolverFactory) + internal static GlobalTypeMapper Instance { get; } + + static GlobalTypeMapper() + => Instance = new GlobalTypeMapper(); + + /// + /// Adds a type info resolver which can add or modify support for PostgreSQL types. + /// Typically used by plugins. + /// + /// The type resolver to be added. 
+ public void AddTypeInfoResolver(IPgTypeInfoResolver resolver) { - Lock.EnterWriteLock(); + _lock.EnterWriteLock(); try { - // Since EFCore.PG plugins (and possibly other users) repeatedly call NpgsqlConnection.GlobalTypeMapped.UseNodaTime, - // we replace an existing resolver of the same CLR type. - var type = resolverFactory.GetType(); + var type = resolver.GetType(); - if (HandlerResolverFactories[0].GetType() == type) - HandlerResolverFactories[0] = resolverFactory; - else + // Since EFCore.PG plugins (and possibly other users) repeatedly call NpgsqlConnection.GlobalTypeMapper.UseNodaTime, + // we replace an existing resolver of the same CLR type. + if (_pluginResolvers.Count > 0 && _pluginResolvers[0].GetType() == type) + _pluginResolvers[0] = resolver; + for (var i = 0; i < _pluginResolvers.Count; i++) { - for (var i = 0; i < HandlerResolverFactories.Count; i++) - if (HandlerResolverFactories[i].GetType() == type) - HandlerResolverFactories.RemoveAt(i); - - HandlerResolverFactories.Insert(0, resolverFactory); + if (_pluginResolvers[i].GetType() == type) + { + _pluginResolvers.RemoveAt(i); + break; + } } - var mappingResolver = resolverFactory.CreateMappingResolver(); - if (mappingResolver is not null) - AddMappingResolver(mappingResolver, overwrite: true); - - RecordChange(); + _pluginResolvers.Insert(0, resolver); + ResetTypeMappingCache(); } finally { - Lock.ExitWriteLock(); + _lock.ExitWriteLock(); } } - internal void TryAddMappingResolver(TypeMappingResolver resolver) + /// + public void Reset() { - Lock.EnterWriteLock(); + _lock.EnterWriteLock(); try { - // For global mapper resolvers we don't need to overwrite them in case we add another of the same type - // because they shouldn't have a state. - // The only exception is whenever a user adds a resolver factory to global type mapper specifically. - // In that case we create a local mapper resolver and always overwrite the one we already have - // as it can have settings (e.g. 
json serialization) - if (AddMappingResolver(resolver, overwrite: false)) - RecordChange(); + _pluginResolvers.Clear(); + _userTypeMapper.Items.Clear(); } finally { - Lock.ExitWriteLock(); + _lock.ExitWriteLock(); } } - bool AddMappingResolver(TypeMappingResolver resolver, bool overwrite) + /// + public INpgsqlNameTranslator DefaultNameTranslator { - // Since EFCore.PG plugins (and possibly other users) repeatedly call NpgsqlConnection.GlobalTypeMapped.UseNodaTime, - // we replace an existing resolver of the same CLR type. - var type = resolver.GetType(); + get => _userTypeMapper.DefaultNameTranslator; + set => _userTypeMapper.DefaultNameTranslator = value; + } - if (MappingResolvers[0].GetType() == type) + /// + public INpgsqlTypeMapper MapEnum(string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) where TEnum : struct, Enum + { + _lock.EnterWriteLock(); + try { - if (!overwrite) - return false; - MappingResolvers[0] = resolver; + _userTypeMapper.MapEnum(pgName, nameTranslator); + return this; } - else + finally { - for (var i = 0; i < MappingResolvers.Count; i++) - { - if (MappingResolvers[i].GetType() == type) - { - if (!overwrite) - return false; - MappingResolvers.RemoveAt(i); - break; - } - } - - MappingResolvers.Insert(0, resolver); + _lock.ExitWriteLock(); } - - return true; } - public void Reset() + /// + public bool UnmapEnum(string? pgName = null, INpgsqlNameTranslator? 
nameTranslator = null) where TEnum : struct, Enum { - Lock.EnterWriteLock(); + _lock.EnterWriteLock(); try { - HandlerResolverFactories.Clear(); - HandlerResolverFactories.Add(new BuiltInTypeHandlerResolverFactory()); - - MappingResolvers.Clear(); - MappingResolvers.Add(new BuiltInTypeMappingResolver()); - - UserTypeMappings.Clear(); - - RecordChange(); + return _userTypeMapper.UnmapEnum(pgName, nameTranslator); } finally { - Lock.ExitWriteLock(); + _lock.ExitWriteLock(); } } - internal void RecordChange() - => _mappingsByClrType.Clear(); - - static string GetPgName(Type clrType, INpgsqlNameTranslator nameTranslator) - => clrType.GetCustomAttribute()?.PgName - ?? nameTranslator.TranslateTypeName(clrType.Name); - - #endregion Mapping management + /// + [RequiresUnreferencedCode("Composite type mapping currently isn't trimming-safe.")] + public INpgsqlTypeMapper MapComposite(string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) + => MapComposite(typeof(T), pgName, nameTranslator); - #region NpgsqlDbType/DbType inference for NpgsqlParameter + /// + [RequiresUnreferencedCode("Composite type mapping currently isn't trimming-safe.")] + public bool UnmapComposite(string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) + => UnmapComposite(typeof(T), pgName, nameTranslator); - [RequiresUnreferencedCode("ToNpgsqlDbType uses interface-based reflection and isn't trimming-safe")] - internal bool TryResolveMappingByValue(object value, [NotNullWhen(true)] out TypeMappingInfo? typeMapping) + /// + [RequiresUnreferencedCode("Composite type mapping currently isn't trimming-safe.")] + public INpgsqlTypeMapper MapComposite(Type clrType, string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) { - Lock.EnterReadLock(); + _lock.EnterWriteLock(); try { - // We resolve as follows: - // 1. Cached by-type lookup (fast path). This will work for almost all types after the very first resolution. - // 2. Value-dependent type lookup (e.g. 
DateTime by Kind) via the resolvers. This includes complex types (e.g. array/range - // over DateTime), and the results cannot be cached. - // 3. Uncached by-type lookup (for the very first resolution of a given type) - - var type = value.GetType(); - if (_mappingsByClrType.TryGetValue(type, out typeMapping)) - return true; - - foreach (var resolver in MappingResolvers) - if ((typeMapping = resolver.GetMappingByValueDependentValue(value)) is not null) - return true; - - return TryResolveMappingByClrType(type, out typeMapping); + _userTypeMapper.MapComposite(clrType, pgName, nameTranslator); + return this; } finally { - Lock.ExitReadLock(); - } - - bool TryResolveMappingByClrType(Type clrType, [NotNullWhen(true)] out TypeMappingInfo? typeMapping) - { - if (_mappingsByClrType.TryGetValue(clrType, out typeMapping)) - return true; - - foreach (var resolver in MappingResolvers) - { - if ((typeMapping = resolver.GetMappingByClrType(clrType)) is not null) - { - _mappingsByClrType[clrType] = typeMapping; - return true; - } - } - - if (clrType.IsArray) - { - if (TryResolveMappingByClrType(clrType.GetElementType()!, out var elementMapping)) - { - _mappingsByClrType[clrType] = typeMapping = new( - NpgsqlDbType.Array | elementMapping.NpgsqlDbType, - elementMapping.DataTypeName + "[]"); - return true; - } - - typeMapping = null; - return false; - } - - var typeInfo = clrType.GetTypeInfo(); - - var ilist = typeInfo.ImplementedInterfaces.FirstOrDefault(x => - x.GetTypeInfo().IsGenericType && x.GetGenericTypeDefinition() == typeof(IList<>)); - if (ilist != null) - { - if (TryResolveMappingByClrType(ilist.GetGenericArguments()[0], out var elementMapping)) - { - _mappingsByClrType[clrType] = typeMapping = new( - NpgsqlDbType.Array | elementMapping.NpgsqlDbType, - elementMapping.DataTypeName + "[]"); - return true; - } - - typeMapping = null; - return false; - } - - if (typeInfo.IsGenericType && clrType.GetGenericTypeDefinition() == typeof(NpgsqlRange<>)) - { - if 
(TryResolveMappingByClrType(clrType.GetGenericArguments()[0], out var elementMapping)) - { - _mappingsByClrType[clrType] = typeMapping = new( - NpgsqlDbType.Range | elementMapping.NpgsqlDbType, - dataTypeName: null); - return true; - } - - typeMapping = null; - return false; - } - - typeMapping = null; - return false; + _lock.ExitWriteLock(); } } - #endregion NpgsqlDbType/DbType inference for NpgsqlParameter - - #region Static translation tables - - public static string? NpgsqlDbTypeToDataTypeName(NpgsqlDbType npgsqlDbType) - => npgsqlDbType switch - { - // Numeric types - NpgsqlDbType.Smallint => "smallint", - NpgsqlDbType.Integer => "integer", - NpgsqlDbType.Bigint => "bigint", - NpgsqlDbType.Real => "real", - NpgsqlDbType.Double => "double precision", - NpgsqlDbType.Numeric => "numeric", - NpgsqlDbType.Money => "money", - - // Text types - NpgsqlDbType.Text => "text", - NpgsqlDbType.Xml => "xml", - NpgsqlDbType.Varchar => "character varying", - NpgsqlDbType.Char => "character", - NpgsqlDbType.Name => "name", - NpgsqlDbType.Refcursor => "refcursor", - NpgsqlDbType.Citext => "citext", - NpgsqlDbType.Jsonb => "jsonb", - NpgsqlDbType.Json => "json", - NpgsqlDbType.JsonPath => "jsonpath", - - // Date/time types - NpgsqlDbType.Timestamp => "timestamp without time zone", - NpgsqlDbType.TimestampTz => "timestamp with time zone", - NpgsqlDbType.Date => "date", - NpgsqlDbType.Time => "time without time zone", - NpgsqlDbType.TimeTz => "time with time zone", - NpgsqlDbType.Interval => "interval", - - // Network types - NpgsqlDbType.Cidr => "cidr", - NpgsqlDbType.Inet => "inet", - NpgsqlDbType.MacAddr => "macaddr", - NpgsqlDbType.MacAddr8 => "macaddr8", - - // Full-text search types - NpgsqlDbType.TsQuery => "tsquery", - NpgsqlDbType.TsVector => "tsvector", - - // Geometry types - NpgsqlDbType.Box => "box", - NpgsqlDbType.Circle => "circle", - NpgsqlDbType.Line => "line", - NpgsqlDbType.LSeg => "lseg", - NpgsqlDbType.Path => "path", - NpgsqlDbType.Point => "point", - 
NpgsqlDbType.Polygon => "polygon", - - // LTree types - NpgsqlDbType.LQuery => "lquery", - NpgsqlDbType.LTree => "ltree", - NpgsqlDbType.LTxtQuery => "ltxtquery", - - // UInt types - NpgsqlDbType.Oid => "oid", - NpgsqlDbType.Xid => "xid", - NpgsqlDbType.Xid8 => "xid8", - NpgsqlDbType.Cid => "cid", - NpgsqlDbType.Regtype => "regtype", - NpgsqlDbType.Regconfig => "regconfig", - - // Misc types - NpgsqlDbType.Boolean => "boolean", - NpgsqlDbType.Bytea => "bytea", - NpgsqlDbType.Uuid => "uuid", - NpgsqlDbType.Varbit => "bit varying", - NpgsqlDbType.Bit => "bit", - NpgsqlDbType.Hstore => "hstore", - - NpgsqlDbType.Geometry => "geometry", - NpgsqlDbType.Geography => "geography", - - // Built-in range types - NpgsqlDbType.IntegerRange => "int4range", - NpgsqlDbType.BigIntRange => "int8range", - NpgsqlDbType.NumericRange => "numrange", - NpgsqlDbType.TimestampRange => "tsrange", - NpgsqlDbType.TimestampTzRange => "tstzrange", - NpgsqlDbType.DateRange => "daterange", - - // Built-in multirange types - NpgsqlDbType.IntegerMultirange => "int4multirange", - NpgsqlDbType.BigIntMultirange => "int8multirange", - NpgsqlDbType.NumericMultirange => "nummultirange", - NpgsqlDbType.TimestampMultirange => "tsmultirange", - NpgsqlDbType.TimestampTzMultirange => "tstzmultirange", - NpgsqlDbType.DateMultirange => "datemultirange", - - // Internal types - NpgsqlDbType.Int2Vector => "int2vector", - NpgsqlDbType.Oidvector => "oidvector", - NpgsqlDbType.PgLsn => "pg_lsn", - NpgsqlDbType.Tid => "tid", - NpgsqlDbType.InternalChar => "char", - - // Special types - NpgsqlDbType.Unknown => "unknown", - - _ => npgsqlDbType.HasFlag(NpgsqlDbType.Array) - ? NpgsqlDbTypeToDataTypeName(npgsqlDbType & ~NpgsqlDbType.Array) + "[]" - : null // e.g. ranges - }; - - public static NpgsqlDbType DataTypeNameToNpgsqlDbType(string typeName) + /// + [RequiresUnreferencedCode("Composite type mapping currently isn't trimming-safe.")] + public bool UnmapComposite(Type clrType, string? 
pgName = null, INpgsqlNameTranslator? nameTranslator = null) { - // Strip any facet information (length/precision/scale) - var parenIndex = typeName.IndexOf('('); - if (parenIndex > -1) - typeName = typeName.Substring(0, parenIndex); - - return typeName switch - { - // Numeric types - "smallint" => NpgsqlDbType.Smallint, - "integer" or "int" => NpgsqlDbType.Integer, - "bigint" => NpgsqlDbType.Bigint, - "real" => NpgsqlDbType.Real, - "double precision" => NpgsqlDbType.Double, - "numeric" => NpgsqlDbType.Numeric, - "money" => NpgsqlDbType.Money, - - // Text types - "text" => NpgsqlDbType.Text, - "xml" => NpgsqlDbType.Xml, - "character varying" or "varchar" => NpgsqlDbType.Varchar, - "character" => NpgsqlDbType.Char, - "name" => NpgsqlDbType.Name, - "refcursor" => NpgsqlDbType.Refcursor, - "citext" => NpgsqlDbType.Citext, - "jsonb" => NpgsqlDbType.Jsonb, - "json" => NpgsqlDbType.Json, - "jsonpath" => NpgsqlDbType.JsonPath, - - // Date/time types - "timestamp without time zone" or "timestamp" => NpgsqlDbType.Timestamp, - "timestamp with time zone" or "timestamptz" => NpgsqlDbType.TimestampTz, - "date" => NpgsqlDbType.Date, - "time without time zone" or "timetz" => NpgsqlDbType.Time, - "time with time zone" or "time" => NpgsqlDbType.TimeTz, - "interval" => NpgsqlDbType.Interval, - - // Network types - "cidr" => NpgsqlDbType.Cidr, - "inet" => NpgsqlDbType.Inet, - "macaddr" => NpgsqlDbType.MacAddr, - "macaddr8" => NpgsqlDbType.MacAddr8, - - // Full-text search types - "tsquery" => NpgsqlDbType.TsQuery, - "tsvector" => NpgsqlDbType.TsVector, - - // Geometry types - "box" => NpgsqlDbType.Box, - "circle" => NpgsqlDbType.Circle, - "line" => NpgsqlDbType.Line, - "lseg" => NpgsqlDbType.LSeg, - "path" => NpgsqlDbType.Path, - "point" => NpgsqlDbType.Point, - "polygon" => NpgsqlDbType.Polygon, - - // LTree types - "lquery" => NpgsqlDbType.LQuery, - "ltree" => NpgsqlDbType.LTree, - "ltxtquery" => NpgsqlDbType.LTxtQuery, - - // UInt types - "oid" => NpgsqlDbType.Oid, - "xid" => 
NpgsqlDbType.Xid, - "xid8" => NpgsqlDbType.Xid8, - "cid" => NpgsqlDbType.Cid, - "regtype" => NpgsqlDbType.Regtype, - "regconfig" => NpgsqlDbType.Regconfig, - - // Misc types - "boolean" or "bool" => NpgsqlDbType.Boolean, - "bytea" => NpgsqlDbType.Bytea, - "uuid" => NpgsqlDbType.Uuid, - "bit varying" or "varbit" => NpgsqlDbType.Varbit, - "bit" => NpgsqlDbType.Bit, - "hstore" => NpgsqlDbType.Hstore, - - "geometry" => NpgsqlDbType.Geometry, - "geography" => NpgsqlDbType.Geography, - - // Built-in range types - "int4range" => NpgsqlDbType.IntegerRange, - "int8range" => NpgsqlDbType.BigIntRange, - "numrange" => NpgsqlDbType.NumericRange, - "tsrange" => NpgsqlDbType.TimestampRange, - "tstzrange" => NpgsqlDbType.TimestampTzRange, - "daterange" => NpgsqlDbType.DateRange, - - // Built-in multirange types - "int4multirange" => NpgsqlDbType.IntegerMultirange, - "int8multirange" => NpgsqlDbType.BigIntMultirange, - "nummultirange" => NpgsqlDbType.NumericMultirange, - "tsmultirange" => NpgsqlDbType.TimestampMultirange, - "tstzmultirange" => NpgsqlDbType.TimestampTzMultirange, - "datemultirange" => NpgsqlDbType.DateMultirange, - - // Internal types - "int2vector" => NpgsqlDbType.Int2Vector, - "oidvector" => NpgsqlDbType.Oidvector, - "pg_lsn" => NpgsqlDbType.PgLsn, - "tid" => NpgsqlDbType.Tid, - "char" => NpgsqlDbType.InternalChar, - - _ => typeName.EndsWith("[]", StringComparison.Ordinal) && - DataTypeNameToNpgsqlDbType(typeName.Substring(0, typeName.Length - 2)) is { } elementNpgsqlDbType && - elementNpgsqlDbType != NpgsqlDbType.Unknown - ? elementNpgsqlDbType | NpgsqlDbType.Array - : NpgsqlDbType.Unknown // e.g. ranges - }; - } - - internal static NpgsqlDbType? 
DbTypeToNpgsqlDbType(DbType dbType) - => dbType switch + _lock.EnterWriteLock(); + try { - DbType.AnsiString => NpgsqlDbType.Text, - DbType.Binary => NpgsqlDbType.Bytea, - DbType.Byte => NpgsqlDbType.Smallint, - DbType.Boolean => NpgsqlDbType.Boolean, - DbType.Currency => NpgsqlDbType.Money, - DbType.Date => NpgsqlDbType.Date, - DbType.DateTime => LegacyTimestampBehavior ? NpgsqlDbType.Timestamp : NpgsqlDbType.TimestampTz, - DbType.Decimal => NpgsqlDbType.Numeric, - DbType.VarNumeric => NpgsqlDbType.Numeric, - DbType.Double => NpgsqlDbType.Double, - DbType.Guid => NpgsqlDbType.Uuid, - DbType.Int16 => NpgsqlDbType.Smallint, - DbType.Int32 => NpgsqlDbType.Integer, - DbType.Int64 => NpgsqlDbType.Bigint, - DbType.Single => NpgsqlDbType.Real, - DbType.String => NpgsqlDbType.Text, - DbType.Time => NpgsqlDbType.Time, - DbType.AnsiStringFixedLength => NpgsqlDbType.Text, - DbType.StringFixedLength => NpgsqlDbType.Text, - DbType.Xml => NpgsqlDbType.Xml, - DbType.DateTime2 => NpgsqlDbType.Timestamp, - DbType.DateTimeOffset => NpgsqlDbType.TimestampTz, - - DbType.Object => null, - DbType.SByte => null, - DbType.UInt16 => null, - DbType.UInt32 => null, - DbType.UInt64 => null, - - _ => throw new ArgumentOutOfRangeException(nameof(dbType), dbType, null) - }; - - internal static DbType NpgsqlDbTypeToDbType(NpgsqlDbType npgsqlDbType) - => npgsqlDbType switch + return _userTypeMapper.UnmapComposite(clrType, pgName, nameTranslator); + } + finally { - // Numeric types - NpgsqlDbType.Smallint => DbType.Int16, - NpgsqlDbType.Integer => DbType.Int32, - NpgsqlDbType.Bigint => DbType.Int64, - NpgsqlDbType.Real => DbType.Single, - NpgsqlDbType.Double => DbType.Double, - NpgsqlDbType.Numeric => DbType.Decimal, - NpgsqlDbType.Money => DbType.Currency, - - // Text types - NpgsqlDbType.Text => DbType.String, - NpgsqlDbType.Xml => DbType.Xml, - NpgsqlDbType.Varchar => DbType.String, - NpgsqlDbType.Char => DbType.String, - NpgsqlDbType.Name => DbType.String, - NpgsqlDbType.Refcursor => 
DbType.String, - NpgsqlDbType.Citext => DbType.String, - NpgsqlDbType.Jsonb => DbType.Object, - NpgsqlDbType.Json => DbType.Object, - NpgsqlDbType.JsonPath => DbType.String, - - // Date/time types - NpgsqlDbType.Timestamp => LegacyTimestampBehavior ? DbType.DateTime : DbType.DateTime2, - NpgsqlDbType.TimestampTz => LegacyTimestampBehavior ? DbType.DateTimeOffset : DbType.DateTime, - NpgsqlDbType.Date => DbType.Date, - NpgsqlDbType.Time => DbType.Time, - - // Misc data types - NpgsqlDbType.Bytea => DbType.Binary, - NpgsqlDbType.Boolean => DbType.Boolean, - NpgsqlDbType.Uuid => DbType.Guid, - - NpgsqlDbType.Unknown => DbType.Object, - - _ => DbType.Object - }; - - #endregion Static translation tables -} \ No newline at end of file + _lock.ExitWriteLock(); + } + } +} diff --git a/src/Npgsql/TypeMapping/INpgsqlTypeMapper.cs b/src/Npgsql/TypeMapping/INpgsqlTypeMapper.cs index 3d0c46dd92..2f4d7ff040 100644 --- a/src/Npgsql/TypeMapping/INpgsqlTypeMapper.cs +++ b/src/Npgsql/TypeMapping/INpgsqlTypeMapper.cs @@ -1,6 +1,6 @@ using System; using System.Diagnostics.CodeAnalysis; -using Npgsql.Internal.TypeHandling; +using Npgsql.Internal; using Npgsql.NameTranslation; using NpgsqlTypes; @@ -147,14 +147,14 @@ bool UnmapComposite( INpgsqlNameTranslator? nameTranslator = null); /// - /// Adds a type resolver factory, which produces resolvers that can add or modify support for PostgreSQL types. + /// Adds a type info resolver which can add or modify support for PostgreSQL types. /// Typically used by plugins. /// - /// The type resolver factory to be added. - void AddTypeResolverFactory(TypeHandlerResolverFactory resolverFactory); + /// The type resolver to be added. + void AddTypeInfoResolver(IPgTypeInfoResolver resolver); /// /// Resets all mapping changes performed on this type mapper and reverts it to its original, starting state. 
/// void Reset(); -} \ No newline at end of file +} diff --git a/src/Npgsql/TypeMapping/PostgresTypeOIDs.cs b/src/Npgsql/TypeMapping/PostgresTypeOIDs.cs deleted file mode 100644 index e3f0d72c4d..0000000000 --- a/src/Npgsql/TypeMapping/PostgresTypeOIDs.cs +++ /dev/null @@ -1,112 +0,0 @@ -#pragma warning disable RS0016 -#pragma warning disable 1591 - -namespace Npgsql.TypeMapping; - -/// -/// Holds well-known, built-in PostgreSQL type OIDs. -/// -/// -/// Source: -/// -static class PostgresTypeOIDs -{ - // Numeric - public const uint Int8 = 20; - public const uint Float8 = 701; - public const uint Int4 = 23; - public const uint Numeric = 1700; - public const uint Float4 = 700; - public const uint Int2 = 21; - public const uint Money = 790; - - // Boolean - public const uint Bool = 16; - - // Geometric - public const uint Box = 603; - public const uint Circle = 718; - public const uint Line = 628; - public const uint LSeg = 601; - public const uint Path = 602; - public const uint Point = 600; - public const uint Polygon = 604; - - // Character - public const uint BPChar = 1042; - public const uint Text = 25; - public const uint Varchar = 1043; - public const uint Name = 19; - public const uint Char = 18; - - // Binary data - public const uint Bytea = 17; - - // Date/Time - public const uint Date = 1082; - public const uint Time = 1083; - public const uint Timestamp = 1114; - public const uint TimestampTz = 1184; - public const uint Interval = 1186; - public const uint TimeTz = 1266; - public const uint Abstime = 702; - - // Network address - public const uint Inet = 869; - public const uint Cidr = 650; - public const uint Macaddr = 829; - public const uint Macaddr8 = 774; - - // Bit string - public const uint Bit = 1560; - public const uint Varbit = 1562; - - // Text search - public const uint TsVector = 3614; - public const uint TsQuery = 3615; - public const uint Regconfig = 3734; - - // UUID - public const uint Uuid = 2950; - - // XML - public const uint Xml = 
142; - - // JSON - public const uint Json = 114; - public const uint Jsonb = 3802; - public const uint JsonPath = 4072; - - // public - public const uint Refcursor = 1790; - public const uint Oidvector = 30; - public const uint Int2vector = 22; - public const uint Oid = 26; - public const uint Xid = 28; - public const uint Xid8 = 5069; - public const uint Cid = 29; - public const uint Regtype = 2206; - public const uint Tid = 27; - public const uint PgLsn = 3220; - - // Special - public const uint Record = 2249; - public const uint Void = 2278; - public const uint Unknown = 705; - - // Range types - public const uint Int4Range = 3904; - public const uint Int8Range = 3926; - public const uint NumRange = 3906; - public const uint TsRange = 3908; - public const uint TsTzRange = 3910; - public const uint DateRange = 3912; - - // Multirange types - public const uint Int4Multirange = 4451; - public const uint Int8Multirange = 4536; - public const uint NumMultirange = 4532; - public const uint TsMultirange = 4533; - public const uint TsTzMultirange = 4534; - public const uint DateMultirange = 4535; -} \ No newline at end of file diff --git a/src/Npgsql/TypeMapping/RangeTypeHandlerResolver.cs b/src/Npgsql/TypeMapping/RangeTypeHandlerResolver.cs deleted file mode 100644 index ba3f0f44d2..0000000000 --- a/src/Npgsql/TypeMapping/RangeTypeHandlerResolver.cs +++ /dev/null @@ -1,178 +0,0 @@ -using System; -using System.Collections; -using System.Collections.Generic; -using System.Linq; -using System.Reflection; -using Npgsql.Internal; -using Npgsql.Internal.TypeHandlers.DateTimeHandlers; -using Npgsql.Internal.TypeHandling; -using Npgsql.Internal.TypeMapping; -using Npgsql.PostgresTypes; -using Npgsql.Properties; -using Npgsql.Util; -using NpgsqlTypes; -using static Npgsql.Util.Statics; - -namespace Npgsql.TypeMapping; - -sealed class RangeTypeHandlerResolver : TypeHandlerResolver -{ - readonly TypeMapper _typeMapper; - readonly NpgsqlDatabaseInfo _databaseInfo; - - readonly 
TimestampHandler _timestampHandler; - readonly TimestampTzHandler _timestampTzHandler; - - NpgsqlTypeHandler? _timestampRangeHandler; - NpgsqlTypeHandler? _timestampTzRangeHandler; - NpgsqlTypeHandler? _timestampMultirangeHandler; - NpgsqlTypeHandler? _timestampTzMultirangeHandler; - - internal RangeTypeHandlerResolver(TypeMapper typeMapper, NpgsqlConnector connector) - { - _typeMapper = typeMapper; - _databaseInfo = connector.DatabaseInfo; - - _timestampHandler = new TimestampHandler(PgType("timestamp without time zone")); - _timestampTzHandler = new TimestampTzHandler(PgType("timestamp with time zone")); - } - - public override NpgsqlTypeHandler? ResolveByDataTypeName(string typeName) - { - if (!_databaseInfo.TryGetPostgresTypeByName(typeName, out var pgType)) - return null; - - return pgType switch - { - PostgresRangeType pgRangeType - => _typeMapper.ResolveByOID(pgRangeType.Subtype.OID).CreateRangeHandler(pgRangeType), - PostgresMultirangeType pgMultirangeType - => _typeMapper.ResolveByOID(pgMultirangeType.Subrange.Subtype.OID).CreateMultirangeHandler(pgMultirangeType), - _ => null - }; - } - - public override NpgsqlTypeHandler? 
ResolveByNpgsqlDbType(NpgsqlDbType npgsqlDbType) - { - if (npgsqlDbType.HasFlag(NpgsqlDbType.Range)) - { - var subtypeHandler = _typeMapper.ResolveByNpgsqlDbType(npgsqlDbType & ~NpgsqlDbType.Range); - - if (subtypeHandler.PostgresType.Range is not { } pgRangeType) - throw new ArgumentException( - $"No range type could be found in the database for subtype {subtypeHandler.PostgresType}"); - - return subtypeHandler.CreateRangeHandler(pgRangeType); - } - - if (npgsqlDbType.HasFlag(NpgsqlDbType.Multirange)) - { - var subtypeHandler = _typeMapper.ResolveByNpgsqlDbType(npgsqlDbType & ~NpgsqlDbType.Multirange); - - if (subtypeHandler.PostgresType.Range?.Multirange is not { } pgMultirangeType) - throw new ArgumentException(string.Format(NpgsqlStrings.NoMultirangeTypeFound, subtypeHandler.PostgresType)); - - return subtypeHandler.CreateMultirangeHandler(pgMultirangeType); - } - - // Not a range or multirange - return null; - } - - public override NpgsqlTypeHandler? ResolveByClrType(Type type) - { - // Try to see if it is an array type - var arrayElementType = GetArrayListElementType(type); - if (arrayElementType is not null) - { - // With PG14, we map arrays over range types to PG multiranges by default, not to regular arrays over ranges. - if (arrayElementType.IsGenericType && - arrayElementType.GetGenericTypeDefinition() == typeof(NpgsqlRange<>) && - _databaseInfo.Version.IsGreaterOrEqual(14)) - { - var arraySubtypeType = arrayElementType.GetGenericArguments()[0]; - - return _typeMapper.ResolveByClrType(arraySubtypeType) is - { PostgresType : { Range : { Multirange: { } pgMultirangeType } } } arraySubtypeHandler - ? arraySubtypeHandler.CreateMultirangeHandler(pgMultirangeType) - : throw new NotSupportedException($"The CLR range type {type} isn't supported by Npgsql or your PostgreSQL."); - } - } - - // TODO: We can make the following compatible with reflection-free mode by having NpgsqlRange implement some interface, and - // check for that. 
- if (!type.IsGenericType || type.GetGenericTypeDefinition() != typeof(NpgsqlRange<>)) - return null; - - var subtypeType = type.GetGenericArguments()[0]; - - return _typeMapper.ResolveByClrType(subtypeType) is { PostgresType : { Range : { } pgRangeType } } subtypeHandler - ? subtypeHandler.CreateRangeHandler(pgRangeType) - : throw new NotSupportedException($"The CLR range type {type} isn't supported by Npgsql or your PostgreSQL."); - - static Type? GetArrayListElementType(Type type) - { - var typeInfo = type.GetTypeInfo(); - if (typeInfo.IsArray) - return GetUnderlyingType(type.GetElementType()!); // The use of bang operator is justified here as Type.GetElementType() only returns null for the Array base class which can't be mapped in a useful way. - - var ilist = typeInfo.ImplementedInterfaces.FirstOrDefault(x => x.GetTypeInfo().IsGenericType && x.GetGenericTypeDefinition() == typeof(IList<>)); - if (ilist != null) - return GetUnderlyingType(ilist.GetGenericArguments()[0]); - - if (typeof(IList).IsAssignableFrom(type)) - throw new NotSupportedException("Non-generic IList is a supported parameter, but the NpgsqlDbType parameter must be set on the parameter"); - - return null; - - Type GetUnderlyingType(Type t) - => Nullable.GetUnderlyingType(t) ?? t; - } - } - - public override NpgsqlTypeHandler? ResolveValueDependentValue(object value) - { - // In LegacyTimestampBehavior, DateTime isn't value-dependent, and handled above in ClrTypeToDataTypeNameTable like other types - if (LegacyTimestampBehavior) - return null; - - return value switch - { - NpgsqlRange range => RangeHandler(!range.LowerBoundInfinite ? range.LowerBound.Kind : - !range.UpperBoundInfinite ? range.UpperBound.Kind : DateTimeKind.Unspecified), - - NpgsqlRange[] multirange => MultirangeHandler(GetMultirangeKind(multirange)), - List> multirange => MultirangeHandler(GetMultirangeKind(multirange)), - - _ => null - }; - - NpgsqlTypeHandler RangeHandler(DateTimeKind kind) - => kind == DateTimeKind.Utc - ? 
_timestampTzRangeHandler ??= _timestampTzHandler.CreateRangeHandler((PostgresRangeType)PgType("tstzrange")) - : _timestampRangeHandler ??= _timestampHandler.CreateRangeHandler((PostgresRangeType)PgType("tsrange")); - - NpgsqlTypeHandler MultirangeHandler(DateTimeKind kind) - => kind == DateTimeKind.Utc - ? _timestampTzMultirangeHandler ??= _timestampTzHandler.CreateMultirangeHandler((PostgresMultirangeType)PgType("tstzmultirange")) - : _timestampMultirangeHandler ??= _timestampHandler.CreateMultirangeHandler((PostgresMultirangeType)PgType("tsmultirange")); - } - - static DateTimeKind GetRangeKind(NpgsqlRange range) - => !range.LowerBoundInfinite - ? range.LowerBound.Kind - : !range.UpperBoundInfinite - ? range.UpperBound.Kind - : DateTimeKind.Unspecified; - - static DateTimeKind GetMultirangeKind(IList> multirange) - { - for (var i = 0; i < multirange.Count; i++) - if (!multirange[i].IsEmpty) - return GetRangeKind(multirange[i]); - - return DateTimeKind.Unspecified; - } - - PostgresType PgType(string pgTypeName) => _databaseInfo.GetPostgresTypeByName(pgTypeName); -} diff --git a/src/Npgsql/TypeMapping/RangeTypeHandlerResolverFactory.cs b/src/Npgsql/TypeMapping/RangeTypeHandlerResolverFactory.cs deleted file mode 100644 index bc7212eda8..0000000000 --- a/src/Npgsql/TypeMapping/RangeTypeHandlerResolverFactory.cs +++ /dev/null @@ -1,15 +0,0 @@ -using Npgsql.Internal; -using Npgsql.Internal.TypeHandling; -using Npgsql.Internal.TypeMapping; - -namespace Npgsql.TypeMapping; - -sealed class RangeTypeHandlerResolverFactory : TypeHandlerResolverFactory -{ - public override TypeHandlerResolver Create(TypeMapper typeMapper, NpgsqlConnector connector) - => new RangeTypeHandlerResolver(typeMapper, connector); - - public override TypeMappingResolver CreateMappingResolver() => new RangeTypeMappingResolver(); - - public override TypeMappingResolver CreateGlobalMappingResolver() => new RangeTypeMappingResolver(); -} diff --git a/src/Npgsql/TypeMapping/RangeTypeMappingResolver.cs 
b/src/Npgsql/TypeMapping/RangeTypeMappingResolver.cs deleted file mode 100644 index 5061a9780c..0000000000 --- a/src/Npgsql/TypeMapping/RangeTypeMappingResolver.cs +++ /dev/null @@ -1,118 +0,0 @@ -using System; -using System.Collections.Generic; -using Npgsql.Internal.TypeHandling; -using Npgsql.Internal.TypeMapping; -using Npgsql.PostgresTypes; -using NpgsqlTypes; -using static Npgsql.Util.Statics; - -namespace Npgsql.TypeMapping; - -sealed class RangeTypeMappingResolver : TypeMappingResolver -{ - static readonly Dictionary Mappings = new() - { - { "int4range", new(NpgsqlDbType.IntegerRange, "int4range") }, - { "int8range", new(NpgsqlDbType.BigIntRange, "int8range") }, - { "numrange", new(NpgsqlDbType.NumericRange, "numrange") }, - { "daterange", new(NpgsqlDbType.DateRange, "daterange") }, - { "tsrange", new(NpgsqlDbType.TimestampRange, "tsrange") }, - { "tstzrange", new(NpgsqlDbType.TimestampTzRange, "tstzrange") }, - - { "int4multirange", new(NpgsqlDbType.IntegerMultirange, "int4range") }, - { "int8multirange", new(NpgsqlDbType.BigIntMultirange, "int8range") }, - { "nummultirange", new(NpgsqlDbType.NumericMultirange, "numrange") }, - { "datemultirange", new(NpgsqlDbType.DateMultirange, "datemultirange") }, - { "tsmultirange", new(NpgsqlDbType.TimestampMultirange, "tsmultirange") }, - { "tstzmultirange", new(NpgsqlDbType.TimestampTzMultirange, "tstzmultirange") } - }; - - static readonly Dictionary ClrTypeToDataTypeNameTable = new() - { - // Built-in range types - { typeof(NpgsqlRange), "int4range" }, - { typeof(NpgsqlRange), "int8range" }, - { typeof(NpgsqlRange), "numrange" }, -#if NET6_0_OR_GREATER - { typeof(NpgsqlRange), "daterange" }, -#endif - - // Built-in multirange types - { typeof(NpgsqlRange[]), "int4multirange" }, - { typeof(List>), "int4multirange" }, - { typeof(NpgsqlRange[]), "int8multirange" }, - { typeof(List>), "int8multirange" }, - { typeof(NpgsqlRange[]), "nummultirange" }, - { typeof(List>), "nummultirange" }, -#if NET6_0_OR_GREATER - { 
typeof(NpgsqlRange[]), "datemultirange" }, - { typeof(List>), "datemultirange" }, -#endif - }; - - public override string? GetDataTypeNameByClrType(Type clrType) - => ClrTypeToDataTypeNameTable.TryGetValue(clrType, out var dataTypeName) ? dataTypeName : null; - - public override string? GetDataTypeNameByValueDependentValue(object value) - { - // In LegacyTimestampBehavior, DateTime isn't value-dependent, and handled above in ClrTypeToDataTypeNameTable like other types - if (LegacyTimestampBehavior) - return null; - - return value switch - { - NpgsqlRange range => GetRangeKind(range) == DateTimeKind.Utc ? "tstzrange" : "tsrange", - - NpgsqlRange[] multirange => GetMultirangeKind(multirange) == DateTimeKind.Utc ? "tstzmultirange" : "tsmultirange", - - _ => null - }; - } - - public override TypeMappingInfo? GetMappingByDataTypeName(string dataTypeName) - => Mappings.TryGetValue(dataTypeName, out var mapping) ? mapping : null; - - public override TypeMappingInfo? GetMappingByPostgresType(TypeMapper mapper, PostgresType type) - { - switch (type) - { - case PostgresRangeType pgRangeType: - { - if (mapper.TryGetMapping(pgRangeType.Subtype, out var subtypeMapping)) - { - return new(subtypeMapping.NpgsqlDbType | NpgsqlDbType.Range, type.DisplayName); - } - - break; - } - - case PostgresMultirangeType pgMultirangeType: - { - if (mapper.TryGetMapping(pgMultirangeType.Subrange.Subtype, out var subtypeMapping)) - { - return new(subtypeMapping.NpgsqlDbType | NpgsqlDbType.Multirange, type.DisplayName); - } - - break; - } - } - - return null; - } - - static DateTimeKind GetRangeKind(NpgsqlRange range) - => !range.LowerBoundInfinite - ? range.LowerBound.Kind - : !range.UpperBoundInfinite - ? 
range.UpperBound.Kind - : DateTimeKind.Unspecified; - - static DateTimeKind GetMultirangeKind(IList> multirange) - { - for (var i = 0; i < multirange.Count; i++) - if (!multirange[i].IsEmpty) - return GetRangeKind(multirange[i]); - - return DateTimeKind.Unspecified; - } -} diff --git a/src/Npgsql/TypeMapping/RecordTypeHandlerResolver.cs b/src/Npgsql/TypeMapping/RecordTypeHandlerResolver.cs deleted file mode 100644 index df0a44f4e4..0000000000 --- a/src/Npgsql/TypeMapping/RecordTypeHandlerResolver.cs +++ /dev/null @@ -1,29 +0,0 @@ -using System; -using Npgsql.Internal; -using Npgsql.Internal.TypeHandlers; -using Npgsql.Internal.TypeHandling; -using Npgsql.Internal.TypeMapping; -using Npgsql.PostgresTypes; - -namespace Npgsql.TypeMapping; - -sealed class RecordTypeHandlerResolver : TypeHandlerResolver -{ - readonly TypeMapper _typeMapper; - readonly NpgsqlDatabaseInfo _databaseInfo; - - RecordHandler? _recordHandler; - - public RecordTypeHandlerResolver(TypeMapper typeMapper, NpgsqlConnector connector) - { - _typeMapper = typeMapper; - _databaseInfo = connector.DatabaseInfo; - } - - public override NpgsqlTypeHandler? ResolveByDataTypeName(string typeName) - => typeName == "record" ? GetHandler() : null; - - public override NpgsqlTypeHandler? 
ResolveByClrType(Type type) => null; - - NpgsqlTypeHandler GetHandler() => _recordHandler ??= new RecordHandler(_databaseInfo.GetPostgresTypeByName("record"), _typeMapper); -} diff --git a/src/Npgsql/TypeMapping/RecordTypeHandlerResolverFactory.cs b/src/Npgsql/TypeMapping/RecordTypeHandlerResolverFactory.cs deleted file mode 100644 index e308fb03e4..0000000000 --- a/src/Npgsql/TypeMapping/RecordTypeHandlerResolverFactory.cs +++ /dev/null @@ -1,12 +0,0 @@ -using System; -using Npgsql.Internal; -using Npgsql.Internal.TypeHandling; -using Npgsql.Internal.TypeMapping; - -namespace Npgsql.TypeMapping; - -sealed class RecordTypeHandlerResolverFactory : TypeHandlerResolverFactory -{ - public override TypeHandlerResolver Create(TypeMapper typeMapper, NpgsqlConnector connector) - => new RecordTypeHandlerResolver(typeMapper, connector); -} diff --git a/src/Npgsql/TypeMapping/SystemTextJsonTypeHandlerResolver.cs b/src/Npgsql/TypeMapping/SystemTextJsonTypeHandlerResolver.cs deleted file mode 100644 index a60f53d9c5..0000000000 --- a/src/Npgsql/TypeMapping/SystemTextJsonTypeHandlerResolver.cs +++ /dev/null @@ -1,60 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Text.Json; -using System.Text.Json.Nodes; -using Npgsql.Internal; -using Npgsql.Internal.TypeHandlers; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; - -namespace Npgsql.TypeMapping; - -sealed class SystemTextJsonTypeHandlerResolver : TypeHandlerResolver -{ - readonly NpgsqlConnector _connector; - readonly NpgsqlDatabaseInfo _databaseInfo; - readonly JsonSerializerOptions _serializerOptions; - readonly Dictionary? _userClrTypes; - - // Note that old versions of PG - as well as some PG-like databases (Redshift, CockroachDB) don't have json/jsonb, so we create - // these handlers lazily rather than eagerly. - SystemTextJsonHandler? _jsonbHandler; - SystemTextJsonHandler? _jsonHandler; - - internal SystemTextJsonTypeHandlerResolver( - NpgsqlConnector connector, - Dictionary? 
userClrTypes, - JsonSerializerOptions serializerOptions) - { - _connector = connector; - _databaseInfo = connector.DatabaseInfo; - _serializerOptions = serializerOptions; - _userClrTypes = userClrTypes; - } - - public override NpgsqlTypeHandler? ResolveByDataTypeName(string typeName) - => typeName switch - { - "jsonb" => JsonbHandler(), - "json" => JsonHandler(), - _ => null - }; - - public override NpgsqlTypeHandler? ResolveByClrType(Type type) - => SystemTextJsonTypeMappingResolver.ClrTypeToDataTypeName(type, _userClrTypes) is { } dataTypeName && - ResolveByDataTypeName(dataTypeName) is { } handler - ? handler - : null; - - public override NpgsqlTypeHandler? ResolveValueTypeGenerically(T value) - => typeof(T) == typeof(JsonDocument) || typeof(T) == typeof(JsonObject) || typeof(T) == typeof(JsonArray) - ? JsonbHandler() - : null; - - NpgsqlTypeHandler JsonbHandler() - => _jsonbHandler ??= new SystemTextJsonHandler(PgType("jsonb"), _connector.TextEncoding, isJsonb: true, _serializerOptions); - NpgsqlTypeHandler JsonHandler() - => _jsonHandler ??= new SystemTextJsonHandler(PgType("json"), _connector.TextEncoding, isJsonb: false, _serializerOptions); - - PostgresType PgType(string pgTypeName) => _databaseInfo.GetPostgresTypeByName(pgTypeName); -} diff --git a/src/Npgsql/TypeMapping/SystemTextJsonTypeHandlerResolverFactory.cs b/src/Npgsql/TypeMapping/SystemTextJsonTypeHandlerResolverFactory.cs deleted file mode 100644 index 26f593933e..0000000000 --- a/src/Npgsql/TypeMapping/SystemTextJsonTypeHandlerResolverFactory.cs +++ /dev/null @@ -1,45 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Text.Json; -using Npgsql.Internal; -using Npgsql.Internal.TypeHandling; -using Npgsql.Internal.TypeMapping; - -namespace Npgsql.TypeMapping; - -sealed class SystemTextJsonTypeHandlerResolverFactory : TypeHandlerResolverFactory -{ - readonly JsonSerializerOptions _settings; - readonly Dictionary? 
_userClrTypes; - - public SystemTextJsonTypeHandlerResolverFactory( - Type[]? jsonbClrTypes = null, - Type[]? jsonClrTypes = null, - JsonSerializerOptions? settings = null) - { - _settings = settings ?? new JsonSerializerOptions(); - - if (jsonbClrTypes is not null) - { - _userClrTypes ??= new(); - - foreach (var type in jsonbClrTypes) - _userClrTypes[type] = "jsonb"; - } - - if (jsonClrTypes is not null) - { - _userClrTypes ??= new(); - - foreach (var type in jsonClrTypes) - _userClrTypes[type] = "json"; - } - } - - public override TypeHandlerResolver Create(TypeMapper typeMapper, NpgsqlConnector connector) - => new SystemTextJsonTypeHandlerResolver(connector, _userClrTypes, _settings); - - public override TypeMappingResolver CreateMappingResolver() => new SystemTextJsonTypeMappingResolver(_userClrTypes); - - public override TypeMappingResolver CreateGlobalMappingResolver() => new SystemTextJsonTypeMappingResolver(userClrTypes: null); -} diff --git a/src/Npgsql/TypeMapping/SystemTextJsonTypeMappingResolver.cs b/src/Npgsql/TypeMapping/SystemTextJsonTypeMappingResolver.cs deleted file mode 100644 index b76820f718..0000000000 --- a/src/Npgsql/TypeMapping/SystemTextJsonTypeMappingResolver.cs +++ /dev/null @@ -1,39 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Text.Json; -using System.Text.Json.Nodes; -using Npgsql.Internal.TypeHandling; -using Npgsql.Internal.TypeMapping; -using Npgsql.PostgresTypes; -using NpgsqlTypes; - -namespace Npgsql.TypeMapping; - -sealed class SystemTextJsonTypeMappingResolver : TypeMappingResolver -{ - readonly Dictionary? _userClrTypes; - - public SystemTextJsonTypeMappingResolver(Dictionary? userClrTypes) => _userClrTypes = userClrTypes; - - public override string? GetDataTypeNameByClrType(Type type) - => ClrTypeToDataTypeName(type, _userClrTypes); - - public override TypeMappingInfo? GetMappingByDataTypeName(string dataTypeName) - => DoGetMappingByDataTypeName(dataTypeName); - - internal static string? 
ClrTypeToDataTypeName(Type type, Dictionary? clrTypes) - => type == typeof(JsonDocument) - || type == typeof(JsonObject) || type == typeof(JsonArray) - ? "jsonb" - : clrTypes is not null && clrTypes.TryGetValue(type, out var dataTypeName) ? dataTypeName : null; - - static TypeMappingInfo? DoGetMappingByDataTypeName(string dataTypeName) - => dataTypeName switch - { - "jsonb" => new(NpgsqlDbType.Jsonb, "jsonb", typeof(JsonDocument) - , typeof(JsonObject), typeof(JsonArray) - ), - "json" => new(NpgsqlDbType.Json, "json"), - _ => null - }; -} diff --git a/src/Npgsql/TypeMapping/UserTypeMapper.cs b/src/Npgsql/TypeMapping/UserTypeMapper.cs new file mode 100644 index 0000000000..8524dfeb14 --- /dev/null +++ b/src/Npgsql/TypeMapping/UserTypeMapper.cs @@ -0,0 +1,216 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Reflection; +using Npgsql.Internal; +using Npgsql.Internal.Composites; +using Npgsql.Internal.Converters; +using Npgsql.Internal.Postgres; +using Npgsql.NameTranslation; +using Npgsql.PostgresTypes; +using NpgsqlTypes; + +namespace Npgsql.TypeMapping; + +/// +/// The base class for user type mappings. +/// +public abstract class UserTypeMapping +{ + /// + /// The name of the PostgreSQL type that this mapping is for. + /// + public string PgTypeName { get; } + /// + /// The CLR type that this mapping is for. 
+ /// + public Type ClrType { get; } + + internal UserTypeMapping(string pgTypeName, Type type) + => (PgTypeName, ClrType) = (pgTypeName, type); + + internal abstract void Build(TypeInfoMappingCollection mappings); +} + +sealed class UserTypeMapper +{ + readonly List _mappings; + public IList Items => _mappings; + + public INpgsqlNameTranslator DefaultNameTranslator { get; set; } = NpgsqlSnakeCaseNameTranslator.Instance; + + UserTypeMapper(IEnumerable mappings) => _mappings = new List(mappings); + public UserTypeMapper() => _mappings = new(); + + public UserTypeMapper Clone() => new(_mappings) { DefaultNameTranslator = DefaultNameTranslator }; + + public UserTypeMapper MapEnum(string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) + where TEnum : struct, Enum + { + Unmap(typeof(TEnum), out var resolvedName, pgName, nameTranslator); + Items.Add(new EnumMapping(resolvedName, nameTranslator ?? DefaultNameTranslator)); + return this; + } + + public bool UnmapEnum(string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) + where TEnum : struct, Enum + => Unmap(typeof(TEnum), out _, pgName, nameTranslator ?? DefaultNameTranslator); + + public UserTypeMapper MapComposite(string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) where T : class + { + Unmap(typeof(T), out var resolvedName, pgName, nameTranslator); + Items.Add(new CompositeMapping(resolvedName, nameTranslator ?? DefaultNameTranslator)); + return this; + } + + public UserTypeMapper MapStructComposite(string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) where T : struct + { + Unmap(typeof(T), out var resolvedName, pgName, nameTranslator); + Items.Add(new StructCompositeMapping(resolvedName, nameTranslator ?? DefaultNameTranslator)); + return this; + } + + [RequiresUnreferencedCode("Composite type mapping currently isn't trimming-safe.")] + public UserTypeMapper MapComposite(Type clrType, string? pgName = null, INpgsqlNameTranslator? 
nameTranslator = null) + { + if (clrType.IsConstructedGenericType && clrType.GetGenericTypeDefinition() == typeof(Nullable<>)) + throw new ArgumentException("Cannot map nullable.", nameof(clrType)); + + var openMethod = typeof(UserTypeMapper).GetMethod( + clrType.IsValueType ? nameof(MapStructComposite) : nameof(MapComposite), + new[] { typeof(string), typeof(INpgsqlNameTranslator) })!; + + var method = openMethod.MakeGenericMethod(clrType); + + method.Invoke(this, new object?[] { pgName, nameTranslator }); + + return this; + } + + public bool UnmapComposite(string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) where T : class + => UnmapComposite(typeof(T), pgName, nameTranslator); + + public bool UnmapStructComposite(string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) where T : struct + => UnmapComposite(typeof(T), pgName, nameTranslator); + + public bool UnmapComposite(Type clrType, string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) + => Unmap(clrType, out _, pgName, nameTranslator); + + bool Unmap(Type type, out string resolvedName, string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) + { + if (pgName != null && pgName.Trim() == "") + throw new ArgumentException("pgName can't be empty", nameof(pgName)); + + nameTranslator ??= DefaultNameTranslator; + resolvedName = pgName ??= GetPgName(type, nameTranslator); + + UserTypeMapping? toRemove = null; + foreach (var item in _mappings) + if (item.PgTypeName == pgName) + toRemove = item; + + return toRemove is not null && _mappings.Remove(toRemove); + } + + static string GetPgName(Type type, INpgsqlNameTranslator nameTranslator) + => type.GetCustomAttribute()?.PgName + ?? 
nameTranslator.TranslateTypeName(type.Name); + + public IPgTypeInfoResolver Build() + { + var infoMappings = new TypeInfoMappingCollection(); + foreach (var mapping in _mappings) + mapping.Build(infoMappings); + + return new UserMappingResolver(infoMappings); + } + + sealed class UserMappingResolver : IPgTypeInfoResolver + { + readonly TypeInfoMappingCollection _mappings; + public UserMappingResolver(TypeInfoMappingCollection mappings) => _mappings = mappings; + PgTypeInfo? IPgTypeInfoResolver.GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + => _mappings.Find(type, dataTypeName, options); + } + + sealed class CompositeMapping : UserTypeMapping where T : class + { + readonly INpgsqlNameTranslator _nameTranslator; + + public CompositeMapping(string pgTypeName, INpgsqlNameTranslator nameTranslator) + : base(pgTypeName, typeof(T)) + => _nameTranslator = nameTranslator; + + internal override void Build(TypeInfoMappingCollection mappings) + { + mappings.AddType(PgTypeName, (options, mapping, _) => + { + var pgType = mapping.GetPgType(options); + if (pgType is not PostgresCompositeType compositeType) + throw new InvalidOperationException("Composite mapping must be to a composite type"); + + return mapping.CreateInfo(options, new CompositeConverter( + ReflectionCompositeInfoFactory.CreateCompositeInfo(compositeType, _nameTranslator, options))); + }, isDefault: true); + // TODO this should be split out so we can enjoy EnableArray trimming. 
+ mappings.AddArrayType(PgTypeName); + } + } + + sealed class StructCompositeMapping : UserTypeMapping where T : struct + { + readonly INpgsqlNameTranslator _nameTranslator; + + public StructCompositeMapping(string pgTypeName, INpgsqlNameTranslator nameTranslator) + : base(pgTypeName, typeof(T)) + => _nameTranslator = nameTranslator; + + internal override void Build(TypeInfoMappingCollection mappings) + { + mappings.AddStructType(PgTypeName, (options, mapping, dataTypeNameMatch) => + { + var pgType = mapping.GetPgType(options); + if (pgType is not PostgresCompositeType compositeType) + throw new InvalidOperationException("Composite mapping must be to a composite type"); + + return mapping.CreateInfo(options, new CompositeConverter( + ReflectionCompositeInfoFactory.CreateCompositeInfo(compositeType, _nameTranslator, options))); + }, isDefault: true); + // TODO this should be split out so we can enjoy EnableArray trimming. + mappings.AddStructArrayType(PgTypeName); + } + } + + sealed class EnumMapping : UserTypeMapping + where TEnum : struct, Enum + { + readonly Dictionary _enumToLabel = new(); + readonly Dictionary _labelToEnum = new(); + + public EnumMapping(string pgTypeName, INpgsqlNameTranslator nameTranslator) + : base(pgTypeName, typeof(TEnum)) + { + foreach (var field in typeof(TEnum).GetFields(BindingFlags.Static | BindingFlags.Public)) + { + var attribute = (PgNameAttribute?)field.GetCustomAttribute(typeof(PgNameAttribute), false); + var enumName = attribute is null + ? 
nameTranslator.TranslateMemberName(field.Name) + : attribute.PgName; + var enumValue = (TEnum)field.GetValue(null)!; + + _enumToLabel[enumValue] = enumName; + _labelToEnum[enumName] = enumValue; + } + } + + internal override void Build(TypeInfoMappingCollection mappings) + { + mappings.AddStructType(PgTypeName, (options, mapping, _) => + mapping.CreateInfo(options, new EnumConverter(_enumToLabel, _labelToEnum, options.TextEncoding), preferredFormat: DataFormat.Text), isDefault: true); + + // TODO this should be split out so we can enjoy EnableArray trimming. + mappings.AddStructArrayType(PgTypeName); + } + } +} + diff --git a/src/Npgsql/UnpooledDataSource.cs b/src/Npgsql/UnpooledDataSource.cs index 8226524635..3e3cf5b019 100644 --- a/src/Npgsql/UnpooledDataSource.cs +++ b/src/Npgsql/UnpooledDataSource.cs @@ -1,7 +1,6 @@ using System.Diagnostics.CodeAnalysis; using System.Threading; using System.Threading.Tasks; -using System.Transactions; using Npgsql.Internal; using Npgsql.Util; @@ -48,4 +47,4 @@ internal override void Return(NpgsqlConnector connector) } internal override void Clear() {} -} \ No newline at end of file +} diff --git a/src/Npgsql/Util/NpgsqlTimeout.cs b/src/Npgsql/Util/NpgsqlTimeout.cs new file mode 100644 index 0000000000..eb4fb06aed --- /dev/null +++ b/src/Npgsql/Util/NpgsqlTimeout.cs @@ -0,0 +1,57 @@ +using System; +using System.Threading; +using Npgsql.Internal; + +namespace Npgsql.Util; + +/// +/// Represents a timeout that will expire at some point. +/// +public readonly struct NpgsqlTimeout +{ + readonly DateTime _expiration; + + internal static readonly NpgsqlTimeout Infinite = new(TimeSpan.Zero); + + internal NpgsqlTimeout(TimeSpan expiration) + => _expiration = expiration > TimeSpan.Zero + ? DateTime.UtcNow + expiration + : expiration == TimeSpan.Zero + ? 
DateTime.MaxValue + : DateTime.MinValue; + + internal void Check() + { + if (HasExpired) + ThrowHelper.ThrowNpgsqlExceptionWithInnerTimeoutException("The operation has timed out"); + } + + internal void CheckAndApply(NpgsqlConnector connector) + { + if (!IsSet) + return; + + var timeLeft = CheckAndGetTimeLeft(); + // Set the remaining timeout on the read and write buffers + connector.ReadBuffer.Timeout = connector.WriteBuffer.Timeout = timeLeft; + + // Note that we set UserTimeout as well, otherwise the read timeout will get overwritten in ReadMessage + // Note also that we must set the read buffer's timeout directly (above), since the SSL handshake + // reads data directly from the buffer, without going through ReadMessage. + connector.UserTimeout = (int) Math.Ceiling(timeLeft.TotalMilliseconds); + } + + internal bool IsSet => _expiration != DateTime.MaxValue; + + internal bool HasExpired => DateTime.UtcNow >= _expiration; + + internal TimeSpan CheckAndGetTimeLeft() + { + if (!IsSet) + return Timeout.InfiniteTimeSpan; + var timeLeft = _expiration - DateTime.UtcNow; + if (timeLeft <= TimeSpan.Zero) + Check(); + return timeLeft; + } +} diff --git a/src/Npgsql/Util/PGUtil.cs b/src/Npgsql/Util/PGUtil.cs deleted file mode 100644 index 1b9aa0ce2a..0000000000 --- a/src/Npgsql/Util/PGUtil.cs +++ /dev/null @@ -1,228 +0,0 @@ -using Npgsql.Internal; -using System; -using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; -using System.Reflection; -using System.Runtime.CompilerServices; -using System.Text; -using System.Threading; -using System.Threading.Tasks; - -namespace Npgsql.Util; - -static class Statics -{ -#if DEBUG - internal static bool LegacyTimestampBehavior; - internal static bool DisableDateTimeInfinityConversions; -#else - internal static readonly bool LegacyTimestampBehavior; - internal static readonly bool DisableDateTimeInfinityConversions; -#endif - - static Statics() - { - LegacyTimestampBehavior = 
AppContext.TryGetSwitch("Npgsql.EnableLegacyTimestampBehavior", out var enabled) && enabled; - DisableDateTimeInfinityConversions = AppContext.TryGetSwitch("Npgsql.DisableDateTimeInfinityConversions", out enabled) && enabled; - } - - internal static T Expect(IBackendMessage msg, NpgsqlConnector connector) - { - if (msg.GetType() != typeof(T)) - ThrowIfMsgWrongType(msg, connector); - - return (T)msg; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal static T ExpectAny(IBackendMessage msg, NpgsqlConnector connector) - { - if (msg is T t) - return t; - - ThrowIfMsgWrongType(msg, connector); - return default; - } - - [DoesNotReturn] - static void ThrowIfMsgWrongType(IBackendMessage msg, NpgsqlConnector connector) - => throw connector.Break( - new NpgsqlException($"Received backend message {msg.Code} while expecting {typeof(T).Name}. Please file a bug.")); - - internal static DeferDisposable Defer(Action action) => new(action); - internal static DeferDisposable Defer(Action action, T arg) => new(action, arg); - internal static DeferDisposable Defer(Action action, T1 arg1, T2 arg2) => new(action, arg1, arg2); - // internal static AsyncDeferDisposable DeferAsync(Func func) => new AsyncDeferDisposable(func); - internal static AsyncDeferDisposable DeferAsync(Func func) => new(func); - - internal readonly struct DeferDisposable : IDisposable - { - readonly Action _action; - public DeferDisposable(Action action) => _action = action; - public void Dispose() => _action(); - } - - internal readonly struct DeferDisposable : IDisposable - { - readonly Action _action; - readonly T _arg; - public DeferDisposable(Action action, T arg) - { - _action = action; - _arg = arg; - } - public void Dispose() => _action(_arg); - } - - internal readonly struct DeferDisposable : IDisposable - { - readonly Action _action; - readonly T1 _arg1; - readonly T2 _arg2; - public DeferDisposable(Action action, T1 arg1, T2 arg2) - { - _action = action; - _arg1 = arg1; - _arg2 = arg2; - 
} - public void Dispose() => _action(_arg1, _arg2); - } - - internal readonly struct AsyncDeferDisposable : IAsyncDisposable - { - readonly Func _func; - public AsyncDeferDisposable(Func func) => _func = func; - public async ValueTask DisposeAsync() => await _func(); - } -} - -// ReSharper disable once InconsistentNaming -static class PGUtil -{ - internal static readonly UTF8Encoding UTF8Encoding = new(false, true); - internal static readonly UTF8Encoding RelaxedUTF8Encoding = new(false, false); - - internal static void ValidateBackendMessageCode(BackendMessageCode code) - { - switch (code) - { - case BackendMessageCode.AuthenticationRequest: - case BackendMessageCode.BackendKeyData: - case BackendMessageCode.BindComplete: - case BackendMessageCode.CloseComplete: - case BackendMessageCode.CommandComplete: - case BackendMessageCode.CopyData: - case BackendMessageCode.CopyDone: - case BackendMessageCode.CopyBothResponse: - case BackendMessageCode.CopyInResponse: - case BackendMessageCode.CopyOutResponse: - case BackendMessageCode.DataRow: - case BackendMessageCode.EmptyQueryResponse: - case BackendMessageCode.ErrorResponse: - case BackendMessageCode.FunctionCall: - case BackendMessageCode.FunctionCallResponse: - case BackendMessageCode.NoData: - case BackendMessageCode.NoticeResponse: - case BackendMessageCode.NotificationResponse: - case BackendMessageCode.ParameterDescription: - case BackendMessageCode.ParameterStatus: - case BackendMessageCode.ParseComplete: - case BackendMessageCode.PasswordPacket: - case BackendMessageCode.PortalSuspended: - case BackendMessageCode.ReadyForQuery: - case BackendMessageCode.RowDescription: - return; - default: - ThrowUnknownMessageCode(code); - return; - } - - static void ThrowUnknownMessageCode(BackendMessageCode code) - => ThrowHelper.ThrowNpgsqlException($"Unknown message code: {code}"); - } - - internal static readonly Task TrueTask = Task.FromResult(true); - internal static readonly Task FalseTask = Task.FromResult(false); -} 
- -enum FormatCode : short -{ - Text = 0, - Binary = 1 -} - -static class EnumerableExtensions -{ - internal static string Join(this IEnumerable values, string separator) - => string.Join(separator, values); -} - -static class ExceptionExtensions -{ - internal static Exception UnwrapAggregate(this Exception exception) - => exception is AggregateException agg ? agg.InnerException! : exception; -} - -/// -/// Represents a timeout that will expire at some point. -/// -public readonly struct NpgsqlTimeout -{ - readonly DateTime _expiration; - - internal static readonly NpgsqlTimeout Infinite = new(TimeSpan.Zero); - - internal NpgsqlTimeout(TimeSpan expiration) - => _expiration = expiration > TimeSpan.Zero - ? DateTime.UtcNow + expiration - : expiration == TimeSpan.Zero - ? DateTime.MaxValue - : DateTime.MinValue; - - internal void Check() - { - if (HasExpired) - ThrowHelper.ThrowNpgsqlExceptionWithInnerTimeoutException("The operation has timed out"); - } - - internal void CheckAndApply(NpgsqlConnector connector) - { - if (!IsSet) - return; - - var timeLeft = CheckAndGetTimeLeft(); - // Set the remaining timeout on the read and write buffers - connector.ReadBuffer.Timeout = connector.WriteBuffer.Timeout = timeLeft; - - // Note that we set UserTimeout as well, otherwise the read timeout will get overwritten in ReadMessage - // Note also that we must set the read buffer's timeout directly (above), since the SSL handshake - // reads data directly from the buffer, without going through ReadMessage. 
- connector.UserTimeout = (int) Math.Ceiling(timeLeft.TotalMilliseconds); - } - - internal bool IsSet => _expiration != DateTime.MaxValue; - - internal bool HasExpired => DateTime.UtcNow >= _expiration; - - internal TimeSpan CheckAndGetTimeLeft() - { - if (!IsSet) - return Timeout.InfiniteTimeSpan; - var timeLeft = _expiration - DateTime.UtcNow; - if (timeLeft <= TimeSpan.Zero) - Check(); - return timeLeft; - } -} - -static class MethodInfos -{ - internal static readonly ConstructorInfo InvalidCastExceptionCtor = - typeof(InvalidCastException).GetConstructor(new[] { typeof(string) })!; - - internal static readonly MethodInfo StringFormat = - typeof(string).GetMethod(nameof(string.Format), new[] { typeof(string), typeof(object) })!; - - internal static readonly MethodInfo ObjectGetType = - typeof(object).GetMethod(nameof(GetType), new Type[0])!; -} \ No newline at end of file diff --git a/src/Npgsql/Util/ResettableCancellationTokenSource.cs b/src/Npgsql/Util/ResettableCancellationTokenSource.cs index 9bb507b1cb..0912ceb7b9 100644 --- a/src/Npgsql/Util/ResettableCancellationTokenSource.cs +++ b/src/Npgsql/Util/ResettableCancellationTokenSource.cs @@ -97,11 +97,8 @@ public void RestartTimeoutWithoutReset() /// in order make sure the next call to will not invalidate /// the cancellation token. /// - /// - /// An optional token to cancel the asynchronous operation. The default value is . 
- /// /// The from the wrapped - public CancellationToken Reset(CancellationToken cancellationToken = default) + public CancellationToken Reset() { _registration.Dispose(); lock (lockObject) @@ -124,8 +121,6 @@ public CancellationToken Reset(CancellationToken cancellationToken = default) _cts = new CancellationTokenSource(); } } - if (cancellationToken.CanBeCanceled) - _registration = cancellationToken.Register(cts => ((CancellationTokenSource)cts!).Cancel(), _cts); #if DEBUG _isRunning = false; #endif @@ -230,4 +225,4 @@ public void Dispose() isDisposed = true; } } -} \ No newline at end of file +} diff --git a/src/Npgsql/Util/Statics.cs b/src/Npgsql/Util/Statics.cs new file mode 100644 index 0000000000..b84cea4afd --- /dev/null +++ b/src/Npgsql/Util/Statics.cs @@ -0,0 +1,92 @@ +using Npgsql.Internal; +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; + +namespace Npgsql.Util; + +static class Statics +{ +#if DEBUG + internal static bool LegacyTimestampBehavior; + internal static bool DisableDateTimeInfinityConversions; +#else + internal static readonly bool LegacyTimestampBehavior; + internal static readonly bool DisableDateTimeInfinityConversions; +#endif + + static Statics() + { + LegacyTimestampBehavior = AppContext.TryGetSwitch("Npgsql.EnableLegacyTimestampBehavior", out var enabled) && enabled; + DisableDateTimeInfinityConversions = AppContext.TryGetSwitch("Npgsql.DisableDateTimeInfinityConversions", out enabled) && enabled; + } + + internal static T Expect(IBackendMessage msg, NpgsqlConnector connector) + { + if (msg.GetType() != typeof(T)) + ThrowIfMsgWrongType(msg, connector); + + return (T)msg; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal static T ExpectAny(IBackendMessage msg, NpgsqlConnector connector) + { + if (msg is T t) + return t; + + ThrowIfMsgWrongType(msg, connector); + return default; + } + + [DoesNotReturn] + static void 
ThrowIfMsgWrongType(IBackendMessage msg, NpgsqlConnector connector) + => throw connector.Break( + new NpgsqlException($"Received backend message {msg.Code} while expecting {typeof(T).Name}. Please file a bug.")); + + internal static void ValidateBackendMessageCode(BackendMessageCode code) + { + switch (code) + { + case BackendMessageCode.AuthenticationRequest: + case BackendMessageCode.BackendKeyData: + case BackendMessageCode.BindComplete: + case BackendMessageCode.CloseComplete: + case BackendMessageCode.CommandComplete: + case BackendMessageCode.CopyData: + case BackendMessageCode.CopyDone: + case BackendMessageCode.CopyBothResponse: + case BackendMessageCode.CopyInResponse: + case BackendMessageCode.CopyOutResponse: + case BackendMessageCode.DataRow: + case BackendMessageCode.EmptyQueryResponse: + case BackendMessageCode.ErrorResponse: + case BackendMessageCode.FunctionCall: + case BackendMessageCode.FunctionCallResponse: + case BackendMessageCode.NoData: + case BackendMessageCode.NoticeResponse: + case BackendMessageCode.NotificationResponse: + case BackendMessageCode.ParameterDescription: + case BackendMessageCode.ParameterStatus: + case BackendMessageCode.ParseComplete: + case BackendMessageCode.PasswordPacket: + case BackendMessageCode.PortalSuspended: + case BackendMessageCode.ReadyForQuery: + case BackendMessageCode.RowDescription: + return; + default: + ThrowUnknownMessageCode(code); + return; + } + + static void ThrowUnknownMessageCode(BackendMessageCode code) + => ThrowHelper.ThrowNpgsqlException($"Unknown message code: {code}"); + } +} + +static class EnumerableExtensions +{ + internal static string Join(this IEnumerable values, string separator) + => string.Join(separator, values); +} diff --git a/src/Npgsql/Util/StrongBox.cs b/src/Npgsql/Util/StrongBox.cs new file mode 100644 index 0000000000..d72c3140e0 --- /dev/null +++ b/src/Npgsql/Util/StrongBox.cs @@ -0,0 +1,41 @@ +using System.Diagnostics.CodeAnalysis; + +namespace Npgsql.Util; + +abstract 
class StrongBox +{ + private protected StrongBox() { } + public abstract bool HasValue { get; } + public abstract object? Value { get; set; } + public abstract void Clear(); +} + +sealed class StrongBox : StrongBox +{ + bool _hasValue; + + [MaybeNull] T _typedValue; + [MaybeNull] + public T TypedValue { + get => _typedValue; + set + { + _hasValue = true; + _typedValue = value; + } + } + + public override bool HasValue => _hasValue; + + public override object? Value + { + get => TypedValue; + set => TypedValue = (T)value!; + } + + public override void Clear() + { + _hasValue = false; + TypedValue = default!; + } +} diff --git a/src/Npgsql/Util/SubReadStream.cs b/src/Npgsql/Util/SubReadStream.cs new file mode 100644 index 0000000000..6aaee9651a --- /dev/null +++ b/src/Npgsql/Util/SubReadStream.cs @@ -0,0 +1,227 @@ +using System; +using System.Diagnostics; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace Npgsql.Util; + +// Adapted from https://github.com/dotnet/runtime/blob/83adfae6a6273d8fb4c69554aa3b1cc7cbf01c71/src/libraries/System.IO.Compression/src/System/IO/Compression/ZipCustomStreams.cs#L221 +sealed class SubReadStream : Stream +{ + readonly long _startInSuperStream; + long _positionInSuperStream; + readonly long _endInSuperStream; + readonly Stream _superStream; + readonly bool _canSeek; + bool _isDisposed; + + public SubReadStream(Stream superStream, long maxLength) + { + _startInSuperStream = -1; + _positionInSuperStream = 0; + _endInSuperStream = maxLength; + _superStream = superStream; + _canSeek = false; + _isDisposed = false; + } + + public SubReadStream(Stream superStream, long startPosition, long maxLength) + { + _startInSuperStream = startPosition; + _positionInSuperStream = startPosition; + _endInSuperStream = startPosition + maxLength; + _superStream = superStream; + _canSeek = superStream.CanSeek; + _isDisposed = false; + } + + public override long Length + { + get + { + ThrowIfDisposed(); + + if (!_canSeek) + 
throw new NotSupportedException(); + + return _endInSuperStream - _startInSuperStream; + } + } + + public override long Position + { + get + { + ThrowIfDisposed(); + + if (!_canSeek) + throw new NotSupportedException(); + + return _positionInSuperStream - _startInSuperStream; + } + set + { + ThrowIfDisposed(); + + throw new NotSupportedException(); + } + } + + public override bool CanRead => _superStream.CanRead && !_isDisposed; + + public override bool CanSeek => false; + + public override bool CanWrite => false; + + void ThrowIfDisposed() + { + if (_isDisposed) + throw new ObjectDisposedException(GetType().ToString()); + } + + void ThrowIfCantRead() + { + if (!CanRead) + throw new NotSupportedException(); + } + + public override int Read(byte[] buffer, int offset, int count) + { + // parameter validation sent to _superStream.Read + var origCount = count; + + ThrowIfDisposed(); + ThrowIfCantRead(); + + if (_canSeek && _superStream.Position != _positionInSuperStream) + _superStream.Seek(_positionInSuperStream, SeekOrigin.Begin); + if (_positionInSuperStream > _endInSuperStream - count) + count = (int)(_endInSuperStream - _positionInSuperStream); + + Debug.Assert(count >= 0); + Debug.Assert(count <= origCount); + + var ret = _superStream.Read(buffer, offset, count); + + _positionInSuperStream += ret; + return ret; + } + +#if !NETSTANDARD2_0 + public override int Read(Span destination) +#else + int Read(Span destination) +#endif + { + // parameter validation sent to _superStream.Read + var origCount = destination.Length; + var count = destination.Length; + + ThrowIfDisposed(); + ThrowIfCantRead(); + + if (_canSeek && _superStream.Position != _positionInSuperStream) + _superStream.Seek(_positionInSuperStream, SeekOrigin.Begin); + if (_positionInSuperStream + count > _endInSuperStream) + count = (int)(_endInSuperStream - _positionInSuperStream); + + Debug.Assert(count >= 0); + Debug.Assert(count <= origCount); + + var ret = _superStream.Read(destination.Slice(0, 
count)); + + _positionInSuperStream += ret; + return ret; + } + + public override int ReadByte() + { + Span b = stackalloc byte[1]; + return Read(b) == 1 ? b[0] : -1; + } + + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + ValidateBufferArguments(buffer, offset, count); + return ReadAsync(new Memory(buffer, offset, count), cancellationToken).AsTask(); + } + +#if !NETSTANDARD2_0 + public override ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) +#else + ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) +#endif + { + ThrowIfDisposed(); + ThrowIfCantRead(); + if (_canSeek && _superStream.Position != _positionInSuperStream) + { + _superStream.Seek(_positionInSuperStream, SeekOrigin.Begin); + } + + if (_positionInSuperStream > _endInSuperStream - buffer.Length) + { + buffer = buffer.Slice(0, (int)(_endInSuperStream - _positionInSuperStream)); + } + + return Core(buffer, cancellationToken); + + async ValueTask Core(Memory buffer, CancellationToken cancellationToken) + { + var ret = await _superStream.ReadAsync(buffer, cancellationToken).ConfigureAwait(false); + _positionInSuperStream += ret; + return ret; + } + } + + public override long Seek(long offset, SeekOrigin origin) + { + ThrowIfDisposed(); + throw new NotSupportedException(); + } + + public override void SetLength(long value) + { + ThrowIfDisposed(); + throw new NotSupportedException(); + } + + public override void Write(byte[] buffer, int offset, int count) + { + ThrowIfDisposed(); + throw new NotSupportedException(); + } + + public override void Flush() + { + ThrowIfDisposed(); + throw new NotSupportedException(); + } + + // Close the stream for reading. 
Note that this does NOT close the superStream (since + // the substream is just 'a chunk' of the super-stream + protected override void Dispose(bool disposing) + { + if (disposing && !_isDisposed) + { + _isDisposed = true; + } + base.Dispose(disposing); + } + +#if NETSTANDARD + void ValidateBufferArguments(byte[]? buffer, int offset, int count) + { + if (buffer is null) + ThrowHelper.ThrowArgumentNullException(nameof(buffer)); + + if (offset < 0) + ThrowHelper.ThrowArgumentOutOfRangeException(nameof(offset), "Offset is less than 0"); + + if ((uint)count > buffer.Length - offset) + ThrowHelper.ThrowArgumentOutOfRangeException(nameof(count), "Count larger than buffer minus offset"); + + } +#endif +} diff --git a/src/Npgsql/VolatileResourceManager.cs b/src/Npgsql/VolatileResourceManager.cs index 70afea0557..816cf15b32 100644 --- a/src/Npgsql/VolatileResourceManager.cs +++ b/src/Npgsql/VolatileResourceManager.cs @@ -1,5 +1,4 @@ using System; -using System.Diagnostics; using System.Threading; using System.Transactions; using Microsoft.Extensions.Logging; diff --git a/src/Shared/CodeAnalysis.cs b/src/Shared/CodeAnalysis.cs index 594d5bb5d0..8e8e3b3d9e 100644 --- a/src/Shared/CodeAnalysis.cs +++ b/src/Shared/CodeAnalysis.cs @@ -5,6 +5,54 @@ namespace System.Diagnostics.CodeAnalysis { +#if !NET7_0_OR_GREATER + /// + /// Indicates that the specified method requires the ability to generate new code at runtime, + /// for example through . + /// + /// + /// This allows tools to understand which methods are unsafe to call when compiling ahead of time. + /// + [AttributeUsage(AttributeTargets.Method | AttributeTargets.Constructor | AttributeTargets.Class, Inherited = false)] + sealed class RequiresDynamicCodeAttribute : Attribute + { + /// + /// Initializes a new instance of the class + /// with the specified message. + /// + /// + /// A message that contains information about the usage of dynamic code. 
+ /// + public RequiresDynamicCodeAttribute(string message) + { + Message = message; + } + + /// + /// Gets a message that contains information about the usage of dynamic code. + /// + public string Message { get; } + + /// + /// Gets or sets an optional URL that contains more information about the method, + /// why it requires dynamic code, and what options a consumer has to deal with it. + /// + public string? Url { get; set; } + } + + [AttributeUsage(AttributeTargets.Constructor, AllowMultiple = false, Inherited = false)] + sealed class SetsRequiredMembersAttribute : Attribute + { + } + [AttributeUsageAttribute(AttributeTargets.Method | AttributeTargets.Property | AttributeTargets.Parameter, AllowMultiple = false, Inherited = false)] + sealed class UnscopedRefAttribute : Attribute + { + /// + /// Initializes a new instance of the class. + /// + public UnscopedRefAttribute() { } + } +#endif #if NETSTANDARD2_0 [AttributeUsageAttribute(AttributeTargets.Field | AttributeTargets.Parameter | AttributeTargets.Property)] sealed class AllowNullAttribute : Attribute @@ -167,9 +215,43 @@ public UnconditionalSuppressMessageAttribute(string category, string checkId) #endif } -#if !NET5_0_OR_GREATER namespace System.Runtime.CompilerServices { - internal static class IsExternalInit {} +#if !NET5_0_OR_GREATER + static class IsExternalInit {} +#endif +#if !NET7_0_OR_GREATER + [AttributeUsage(AttributeTargets.Class | AttributeTargets.Struct | AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = false)] + sealed class RequiredMemberAttribute : Attribute + { } + + [AttributeUsage(AttributeTargets.All, AllowMultiple = true, Inherited = false)] + sealed class CompilerFeatureRequiredAttribute : Attribute + { + public CompilerFeatureRequiredAttribute(string featureName) + { + FeatureName = featureName; + } + + /// + /// The name of the compiler feature. 
+ /// + public string FeatureName { get; } + + /// + /// If true, the compiler can choose to allow access to the location where this attribute is applied if it does not understand . + /// + public bool IsOptional { get; init; } + + /// + /// The used for the ref structs C# feature. + /// + public const string RefStructs = nameof(RefStructs); + + /// + /// The used for the required members C# feature. + /// + public const string RequiredMembers = nameof(RequiredMembers); + } +#endif } -#endif \ No newline at end of file diff --git a/test/Npgsql.Benchmarks/Prepare.cs b/test/Npgsql.Benchmarks/Prepare.cs index 246b25e491..6b8d9b06bc 100644 --- a/test/Npgsql.Benchmarks/Prepare.cs +++ b/test/Npgsql.Benchmarks/Prepare.cs @@ -1,5 +1,4 @@ -using System.Diagnostics.CodeAnalysis; -using System.Linq; +using System.Linq; using System.Reflection; using System.Text; using BenchmarkDotNet.Attributes; diff --git a/test/Npgsql.Benchmarks/ReadArray.cs b/test/Npgsql.Benchmarks/ReadArray.cs index fecda03f43..e1e5b2d8de 100644 --- a/test/Npgsql.Benchmarks/ReadArray.cs +++ b/test/Npgsql.Benchmarks/ReadArray.cs @@ -1,10 +1,4 @@ using BenchmarkDotNet.Attributes; -using System; -using System.Collections.Generic; -using System.Diagnostics; -using System.IO; -using System.Linq; -using System.Runtime.CompilerServices; namespace Npgsql.Benchmarks; diff --git a/test/Npgsql.Benchmarks/ResolveHandler.cs b/test/Npgsql.Benchmarks/ResolveHandler.cs index 419b3e179c..86e5d20fbb 100644 --- a/test/Npgsql.Benchmarks/ResolveHandler.cs +++ b/test/Npgsql.Benchmarks/ResolveHandler.cs @@ -1,8 +1,6 @@ using BenchmarkDotNet.Attributes; -using Npgsql.Internal.TypeHandling; -using Npgsql.Internal.TypeMapping; -using Npgsql.TypeMapping; -using NpgsqlTypes; +using Npgsql.Internal; +using Npgsql.Internal.Postgres; namespace Npgsql.Benchmarks; @@ -10,7 +8,7 @@ namespace Npgsql.Benchmarks; public class ResolveHandler { NpgsqlDataSource? 
_dataSource; - TypeMapper _typeMapper = null!; + PgSerializerOptions _serializerOptions = null!; [Params(0, 1, 2)] public int NumPlugins { get; set; } @@ -24,29 +22,21 @@ public void Setup() if (NumPlugins > 1) dataSourceBuilder.UseNetTopologySuite(); _dataSource = dataSourceBuilder.Build(); - _typeMapper = _dataSource.TypeMapper; + _serializerOptions = _dataSource.SerializerOptions; } [GlobalCleanup] public void Cleanup() => _dataSource?.Dispose(); [Benchmark] - public NpgsqlTypeHandler ResolveOID() - => _typeMapper.ResolveByOID(23); // int4 + public PgTypeInfo? ResolveDefault() + => _serializerOptions.GetDefaultTypeInfo(new Oid(23)); // int4 [Benchmark] - public NpgsqlTypeHandler ResolveNpgsqlDbType() - => _typeMapper.ResolveByNpgsqlDbType(NpgsqlDbType.Integer); + public PgTypeInfo? ResolveType() + => _serializerOptions.GetTypeInfo(typeof(int)); [Benchmark] - public NpgsqlTypeHandler ResolveDataTypeName() - => _typeMapper.ResolveByDataTypeName("integer"); - - [Benchmark] - public NpgsqlTypeHandler ResolveClrTypeNonGeneric() - => _typeMapper.ResolveByValue((object)8); - - [Benchmark] - public NpgsqlTypeHandler ResolveClrTypeGeneric() - => _typeMapper.ResolveByValue(8); + public PgTypeInfo? 
ResolveBoth() + => _serializerOptions.GetTypeInfo(typeof(int), new Oid(23)); // int4 } diff --git a/test/Npgsql.Benchmarks/TypeHandlers/Composite.cs b/test/Npgsql.Benchmarks/TypeHandlers/Composite.cs index 496b51af6f..52418a7240 100644 --- a/test/Npgsql.Benchmarks/TypeHandlers/Composite.cs +++ b/test/Npgsql.Benchmarks/TypeHandlers/Composite.cs @@ -1,9 +1,4 @@ -using System.Collections.Generic; -using BenchmarkDotNet.Attributes; -using Npgsql.NameTranslation; -using Npgsql.PostgresTypes; -using Npgsql.TypeMapping; -using Npgsql.Util; + /* Disabling for now: unmapped composite support is probably going away, and there's a good chance this * class can be simplified to a certain extent diff --git a/test/Npgsql.Benchmarks/TypeHandlers/Numeric.cs b/test/Npgsql.Benchmarks/TypeHandlers/Numeric.cs index 19e044b0a4..42f5f3936a 100644 --- a/test/Npgsql.Benchmarks/TypeHandlers/Numeric.cs +++ b/test/Npgsql.Benchmarks/TypeHandlers/Numeric.cs @@ -1,43 +1,43 @@ using System.Collections.Generic; using BenchmarkDotNet.Attributes; -using Npgsql.Internal.TypeHandlers.NumericHandlers; +using Npgsql.Internal.Converters; namespace Npgsql.Benchmarks.TypeHandlers; [Config(typeof(Config))] public class Int16 : TypeHandlerBenchmarks { - public Int16() : base(new Int16Handler(GetPostgresType("smallint"))) { } + public Int16() : base(new Int2Converter()) { } } [Config(typeof(Config))] public class Int32 : TypeHandlerBenchmarks { - public Int32() : base(new Int32Handler(GetPostgresType("integer"))) { } + public Int32() : base(new Int4Converter()) { } } [Config(typeof(Config))] public class Int64 : TypeHandlerBenchmarks { - public Int64() : base(new Int64Handler(GetPostgresType("bigint"))) { } + public Int64() : base(new Int8Converter()) { } } [Config(typeof(Config))] public class Single : TypeHandlerBenchmarks { - public Single() : base(new SingleHandler(GetPostgresType("real"))) { } + public Single() : base(new RealConverter()) { } } [Config(typeof(Config))] public class Double : 
TypeHandlerBenchmarks { - public Double() : base(new DoubleHandler(GetPostgresType("double precision"))) { } + public Double() : base(new DoubleConverter()) { } } [Config(typeof(Config))] public class Numeric : TypeHandlerBenchmarks { - public Numeric() : base(new NumericHandler(GetPostgresType("numeric"))) { } + public Numeric() : base(new DecimalNumericConverter()) { } protected override IEnumerable ValuesOverride() => new[] { @@ -62,5 +62,5 @@ protected override IEnumerable ValuesOverride() => new[] [Config(typeof(Config))] public class Money : TypeHandlerBenchmarks { - public Money() : base(new MoneyHandler(GetPostgresType("money"))) { } -} \ No newline at end of file + public Money() : base(new MoneyConverter()) { } +} diff --git a/test/Npgsql.Benchmarks/TypeHandlers/Text.cs b/test/Npgsql.Benchmarks/TypeHandlers/Text.cs index 407a749240..80d5f6ce0c 100644 --- a/test/Npgsql.Benchmarks/TypeHandlers/Text.cs +++ b/test/Npgsql.Benchmarks/TypeHandlers/Text.cs @@ -1,18 +1,18 @@ using BenchmarkDotNet.Attributes; using System.Collections.Generic; using System.Text; -using Npgsql.Internal.TypeHandlers; +using Npgsql.Internal.Converters; namespace Npgsql.Benchmarks.TypeHandlers; [Config(typeof(Config))] public class Text : TypeHandlerBenchmarks { - public Text() : base(new TextHandler(GetPostgresType("text"), Encoding.UTF8)) { } + public Text() : base(new StringTextConverter(Encoding.UTF8)) { } protected override IEnumerable ValuesOverride() { for (var i = 1; i <= 10000; i *= 10) yield return new string('x', i); } -} \ No newline at end of file +} diff --git a/test/Npgsql.Benchmarks/TypeHandlers/TypeHandlerBenchmarks.cs b/test/Npgsql.Benchmarks/TypeHandlers/TypeHandlerBenchmarks.cs index 76cc862378..994839c219 100644 --- a/test/Npgsql.Benchmarks/TypeHandlers/TypeHandlerBenchmarks.cs +++ b/test/Npgsql.Benchmarks/TypeHandlers/TypeHandlerBenchmarks.cs @@ -5,11 +5,8 @@ using System; using System.Collections.Generic; using System.IO; -using System.Text; +using 
System.Threading; using Npgsql.Internal; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; -using Npgsql.Util; #nullable disable @@ -40,27 +37,26 @@ public override void SetLength(long value) { } public override void Write(byte[] buffer, int offset, int count) { } } - readonly EndlessStream _stream; - readonly NpgsqlTypeHandler _handler; - readonly NpgsqlReadBuffer _readBuffer; + readonly PgConverter _converter; + readonly PgReader _reader; + readonly PgWriter _writer; readonly NpgsqlWriteBuffer _writeBuffer; - T _value; - int _elementSize; + readonly NpgsqlReadBuffer _readBuffer; + readonly BufferRequirements _binaryRequirements; - protected TypeHandlerBenchmarks(NpgsqlTypeHandler handler) - { - _stream = new EndlessStream(); - _handler = handler ?? throw new ArgumentNullException(nameof(handler)); - _readBuffer = new NpgsqlReadBuffer(null, _stream, null, NpgsqlReadBuffer.MinimumSize, Encoding.UTF8, PGUtil.RelaxedUTF8Encoding); - _writeBuffer = new NpgsqlWriteBuffer(null, _stream, null, NpgsqlWriteBuffer.MinimumSize, Encoding.UTF8); - } + T _value; + Size _elementSize; - protected static PostgresType GetPostgresType(string pgType) + protected TypeHandlerBenchmarks(PgConverter handler) { - using (var conn = BenchmarkEnvironment.OpenConnection()) - using (var cmd = new NpgsqlCommand($"SELECT NULL::{pgType}", conn)) - using (var reader = cmd.ExecuteReader()) - return reader.GetPostgresType(0); + var stream = new EndlessStream(); + _converter = handler ?? 
throw new ArgumentNullException(nameof(handler)); + _readBuffer = new NpgsqlReadBuffer(null, stream, null, NpgsqlReadBuffer.MinimumSize, NpgsqlWriteBuffer.UTF8Encoding, NpgsqlWriteBuffer.RelaxedUTF8Encoding); + _writeBuffer = new NpgsqlWriteBuffer(null, stream, null, NpgsqlWriteBuffer.MinimumSize, NpgsqlWriteBuffer.UTF8Encoding); + _reader = new PgReader(_readBuffer); + _writer = new PgWriter(new NpgsqlBufferWriter(_writeBuffer)); + _writer.Init(new PostgresMinimalDatabaseInfo()); + _converter.CanConvert(DataFormat.Binary, out _binaryRequirements); } public IEnumerable Values() => ValuesOverride(); @@ -73,17 +69,16 @@ public T Value get => _value; set { - NpgsqlLengthCache cache = null; - _value = value; - _elementSize = _handler.ValidateAndGetLength(value, ref cache, null); - - cache.Rewind(); - - _handler.WriteWithLength(_value, _writeBuffer, cache, null, false); - Buffer.BlockCopy(_writeBuffer.Buffer, 0, _readBuffer.Buffer, 0, _elementSize); - - _readBuffer.FilledBytes = _elementSize; + object state = null; + var size = _elementSize = _converter.GetSizeAsObject(new(DataFormat.Binary, _binaryRequirements.Write), value, ref state); + var current = new ValueMetadata { Format = DataFormat.Binary, BufferRequirement = _binaryRequirements.Write, Size = size, WriteState = state }; + _writer.BeginWrite(async: false, current, CancellationToken.None).GetAwaiter().GetResult(); + _converter.WriteAsObject(_writer, value); + Buffer.BlockCopy(_writeBuffer.Buffer, 0, _readBuffer.Buffer, 0, size.Value); + + _writer.Commit(size.Value); + _readBuffer.FilledBytes = size.Value; _writeBuffer.WritePosition = 0; } } @@ -92,13 +87,18 @@ public T Value public T Read() { _readBuffer.ReadPosition = sizeof(int); - return _handler.Read(_readBuffer, _elementSize); + _reader.StartRead(_binaryRequirements.Read); + var value = ((PgConverter)_converter).Read(_reader); + _reader.EndRead(); + return value; } [Benchmark] public void Write() { _writeBuffer.WritePosition = 0; - 
_handler.WriteWithLength(_value, _writeBuffer, null, null, false); + var current = new ValueMetadata { Format = DataFormat.Binary, BufferRequirement = _binaryRequirements.Write, Size = _elementSize, WriteState = null }; + _writer.BeginWrite(async: false, current, CancellationToken.None).GetAwaiter().GetResult(); + ((PgConverter)_converter).Write(_writer, _value); } -} \ No newline at end of file +} diff --git a/test/Npgsql.Benchmarks/TypeHandlers/Uuid.cs b/test/Npgsql.Benchmarks/TypeHandlers/Uuid.cs index 78d4018dfd..7c229a3b57 100644 --- a/test/Npgsql.Benchmarks/TypeHandlers/Uuid.cs +++ b/test/Npgsql.Benchmarks/TypeHandlers/Uuid.cs @@ -1,11 +1,11 @@ using System; using BenchmarkDotNet.Attributes; -using Npgsql.Internal.TypeHandlers; +using Npgsql.Internal.Converters; namespace Npgsql.Benchmarks.TypeHandlers; [Config(typeof(Config))] public class Uuid : TypeHandlerBenchmarks { - public Uuid() : base(new UuidHandler(GetPostgresType("uuid"))) { } -} \ No newline at end of file + public Uuid() : base(new GuidUuidConverter()) { } +} diff --git a/test/Npgsql.NativeAotTests/Npgsql.NativeAotTests.csproj b/test/Npgsql.NativeAotTests/Npgsql.NativeAotTests.csproj index 3396a51a92..bc680c3052 100644 --- a/test/Npgsql.NativeAotTests/Npgsql.NativeAotTests.csproj +++ b/test/Npgsql.NativeAotTests/Npgsql.NativeAotTests.csproj @@ -5,12 +5,13 @@ net8.0 true + true true true true - false - Size + false true + Size diff --git a/test/Npgsql.NodaTime.Tests/LegacyNodaTimeTests.cs b/test/Npgsql.NodaTime.Tests/LegacyNodaTimeTests.cs deleted file mode 100644 index 67c4202ff4..0000000000 --- a/test/Npgsql.NodaTime.Tests/LegacyNodaTimeTests.cs +++ /dev/null @@ -1,104 +0,0 @@ -using System; -using System.Data; -using System.Threading.Tasks; -using NodaTime; -using Npgsql.Tests; -using NpgsqlTypes; -using NUnit.Framework; - -namespace Npgsql.NodaTime.Tests; - -// Since this test suite manipulates TimeZone, it is incompatible with multiplexing -[NonParallelizable] -public class 
LegacyNodaTimeTests : TestBase -{ - [Test] - public Task Timestamp_as_Instant() - => AssertType( - new LocalDateTime(1998, 4, 12, 13, 26, 38, 789).InUtc().ToInstant(), - "1998-04-12 13:26:38.789", - "timestamp without time zone", - NpgsqlDbType.Timestamp, - DbType.DateTime); - - [Test] - public Task Timestamp_as_LocalDateTime() - => AssertType( - new LocalDateTime(1998, 4, 12, 13, 26, 38, 789), - "1998-04-12 13:26:38.789", - "timestamp without time zone", - NpgsqlDbType.Timestamp, - DbType.DateTime, - isDefaultForReading: false); - - [Test] - public Task Timestamptz_as_Instant() - => AssertType( - new LocalDateTime(1998, 4, 12, 13, 26, 38, 789).InUtc().ToInstant(), - "1998-04-12 15:26:38.789+02", - "timestamp with time zone", - NpgsqlDbType.TimestampTz, - DbType.DateTimeOffset, - isDefault: false); - - [Test] - public Task Timestamptz_ZonedDateTime_infinite_values_are_not_supported() - => AssertTypeUnsupported(Instant.MaxValue.InZone(DateTimeZone.Utc), "infinity", "timestamptz"); - - [Test] - public Task Timestamptz_OffsetDateTime_infinite_values_are_not_supported() - => AssertTypeUnsupported(Instant.MaxValue.WithOffset(Offset.Zero), "infinity", "timestamptz"); - - #region Support - - protected override async ValueTask OpenConnectionAsync() - { - var conn = await base.OpenConnectionAsync(); - await conn.ExecuteNonQueryAsync("SET TimeZone='Europe/Berlin'"); - return conn; - } - - protected override NpgsqlConnection OpenConnection() - => throw new NotSupportedException(); - -#pragma warning disable CS1998 // Release code blocks below lack await -#pragma warning disable CS0618 // GlobalTypeMapper is obsolete - [OneTimeSetUp] - public async Task Setup() - { -#if DEBUG - Internal.NodaTimeUtils.LegacyTimestampBehavior = true; - Util.Statics.LegacyTimestampBehavior = true; - - // Clear any previous cached mappings/handlers in case tests were executed before the legacy flag was set. 
- NpgsqlConnection.GlobalTypeMapper.Reset(); - NpgsqlConnection.GlobalTypeMapper.UseNodaTime(); - await using var connection = await OpenConnectionAsync(); - await connection.ReloadTypesAsync(); -#else - Assert.Ignore( - "Legacy NodaTime tests rely on the Npgsql.EnableLegacyTimestampBehavior AppContext switch and can only be run in DEBUG builds"); -#endif - - } - - [OneTimeTearDown] - public async Task Teardown() - { -#if DEBUG - Internal.NodaTimeUtils.LegacyTimestampBehavior = false; - Util.Statics.LegacyTimestampBehavior = false; - - // Clear any previous cached mappings/handlers to not affect test which will run later without the legacy flag - NpgsqlConnection.GlobalTypeMapper.Reset(); - NpgsqlConnection.GlobalTypeMapper.UseNodaTime(); - - await using var connection = await OpenConnectionAsync(); - await connection.ReloadTypesAsync(); -#endif - } -#pragma warning restore CS1998 -#pragma warning restore CS0618 // GlobalTypeMapper is obsolete - - #endregion Support -} diff --git a/test/Npgsql.NodaTime.Tests/NodaTimeSetupFixture.cs b/test/Npgsql.NodaTime.Tests/NodaTimeSetupFixture.cs deleted file mode 100644 index 25ab4f58cd..0000000000 --- a/test/Npgsql.NodaTime.Tests/NodaTimeSetupFixture.cs +++ /dev/null @@ -1,18 +0,0 @@ -using NUnit.Framework; - -namespace Npgsql.NodaTime.Tests; - -// Note that we register NodaTime globally, rather than using the more standard data source mapping. -// We can do this since NUnit runs each test assembly in a different process, so we get isolation and don't interfere with other, -// non-NodaTime tests. This also allows us to test global type inference, which only works with global mappings. 
-[SetUpFixture] -public class NodaTimeSetupFixture -{ -#pragma warning disable CS0618 // GlobalTypeMapper is obsolete - [OneTimeSetUp] - public void OneTimeSetUp() => NpgsqlConnection.GlobalTypeMapper.UseNodaTime(); - - [OneTimeTearDown] - public void OneTimeTearDown() => NpgsqlConnection.GlobalTypeMapper.Reset(); -#pragma warning restore CS0618 // GlobalTypeMapper is obsolete -} diff --git a/test/Npgsql.NodaTime.Tests/Npgsql.NodaTime.Tests.csproj b/test/Npgsql.NodaTime.Tests/Npgsql.NodaTime.Tests.csproj deleted file mode 100644 index bfa9b74079..0000000000 --- a/test/Npgsql.NodaTime.Tests/Npgsql.NodaTime.Tests.csproj +++ /dev/null @@ -1,13 +0,0 @@ - - - - - - - - - - - - - diff --git a/test/Npgsql.PluginTests/LegacyNodaTimeTests.cs b/test/Npgsql.PluginTests/LegacyNodaTimeTests.cs new file mode 100644 index 0000000000..6af0afec24 --- /dev/null +++ b/test/Npgsql.PluginTests/LegacyNodaTimeTests.cs @@ -0,0 +1,106 @@ +using System; +using System.Data; +using System.Threading.Tasks; +using NodaTime; +using Npgsql.Tests; +using NpgsqlTypes; +using NUnit.Framework; +using Npgsql.NodaTime.Internal; + +namespace Npgsql.PluginTests; + +[NonParallelizable] // Since this test suite manipulates an AppContext switch +public class LegacyNodaTimeTests : TestBase, IDisposable +{ + const string TimeZone = "Europe/Berlin"; + + [Test] + public async Task Timestamp_as_ZonedDateTime() + { + await AssertType( + new LocalDateTime(1998, 4, 12, 13, 26, 38, 789).InZoneLeniently(DateTimeZoneProviders.Tzdb[TimeZone]), + "1998-04-12 13:26:38.789+02", + "timestamp with time zone", + NpgsqlDbType.TimestampTz, + DbType.DateTimeOffset, + isNpgsqlDbTypeInferredFromClrType: false, isDefault: false); + } + + [Test] + public Task Timestamp_as_Instant() + => AssertType( + new LocalDateTime(1998, 4, 12, 13, 26, 38, 789).InUtc().ToInstant(), + "1998-04-12 13:26:38.789", + "timestamp without time zone", + NpgsqlDbType.Timestamp, + DbType.DateTime, + isNpgsqlDbTypeInferredFromClrType: false); + + [Test] + 
public Task Timestamp_as_LocalDateTime() + => AssertType( + new LocalDateTime(1998, 4, 12, 13, 26, 38, 789), + "1998-04-12 13:26:38.789", + "timestamp without time zone", + NpgsqlDbType.Timestamp, + DbType.DateTime, + isDefaultForReading: false, + isNpgsqlDbTypeInferredFromClrType: false); + + [Test] + public Task Timestamptz_as_Instant() + => AssertType( + new LocalDateTime(1998, 4, 12, 13, 26, 38, 789).InUtc().ToInstant(), + "1998-04-12 15:26:38.789+02", + "timestamp with time zone", + NpgsqlDbType.TimestampTz, + DbType.DateTimeOffset, + isDefault: false, + isNpgsqlDbTypeInferredFromClrType: false); + + [Test] + public async Task Timestamptz_ZonedDateTime_infinite_values_are_not_supported() + { + await AssertTypeUnsupportedRead("infinity", "timestamptz"); + await AssertTypeUnsupportedWrite(Instant.MaxValue.WithOffset(Offset.Zero), "timestamptz"); + } + + [Test] + public async Task Timestamptz_OffsetDateTime_infinite_values_are_not_supported() + { + await AssertTypeUnsupportedRead("infinity", "timestamptz"); + await AssertTypeUnsupportedWrite(Instant.MaxValue.WithOffset(Offset.Zero), "timestamptz"); + } + + #region Support + + protected override NpgsqlDataSource DataSource { get; } + + public LegacyNodaTimeTests() + { +#if DEBUG + NodaTimeUtils.LegacyTimestampBehavior = true; + Util.Statics.LegacyTimestampBehavior = true; + + var builder = CreateDataSourceBuilder(); + builder.UseNodaTime(); + builder.ConnectionStringBuilder.Timezone = TimeZone; + DataSource = builder.Build(); +#else + Assert.Ignore( + "Legacy NodaTime tests rely on the Npgsql.EnableLegacyTimestampBehavior AppContext switch and can only be run in DEBUG builds"); +#endif + } + + public void Dispose() + { +#if DEBUG + NodaTimeUtils.LegacyTimestampBehavior = false; + Util.Statics.LegacyTimestampBehavior = false; + + DataSource.Dispose(); +#endif + } + + #endregion Support +} diff --git a/test/Npgsql.PluginTests/NetTopologySuiteTests.cs b/test/Npgsql.PluginTests/NetTopologySuiteTests.cs index 
20fc9f17a4..2fb33f678d 100644 --- a/test/Npgsql.PluginTests/NetTopologySuiteTests.cs +++ b/test/Npgsql.PluginTests/NetTopologySuiteTests.cs @@ -1,5 +1,4 @@ using System; -using System.Collections; using System.Collections.Concurrent; using System.Linq; using System.Threading.Tasks; @@ -14,26 +13,34 @@ namespace Npgsql.PluginTests; public class NetTopologySuiteTests : TestBase { - public struct TestData + static readonly TestCaseData[] TestCases = { - public Ordinates Ordinates; - public Geometry Geometry; - public string CommandText; - } + new TestCaseData(Ordinates.None, new Point(1d, 2500d), "st_makepoint(1,2500)") + .SetName("Point"), - public static IEnumerable TestCases { - get - { - // Two dimensional data - yield return new TestCaseData(Ordinates.None, new Point(1d, 2500d), "st_makepoint(1,2500)"); + new TestCaseData(Ordinates.None, new MultiPoint(new[] { new Point(new Coordinate(1d, 1d)) }), "st_multi(st_makepoint(1, 1))") + .SetName("MultiPoint"), - yield return new TestCaseData( + new TestCaseData( Ordinates.None, new LineString(new[] { new Coordinate(1d, 1d), new Coordinate(1d, 2500d) }), - "st_makeline(st_makepoint(1,1),st_makepoint(1,2500))" - ); + "st_makeline(st_makepoint(1,1),st_makepoint(1,2500))") + .SetName("LineString"), + + new TestCaseData( + Ordinates.None, + new MultiLineString(new[] + { + new LineString(new[] + { + new Coordinate(1d, 1d), + new Coordinate(1d, 2500d) + }) + }), + "st_multi(st_makeline(st_makepoint(1,1),st_makepoint(1,2500)))") + .SetName("MultiLineString"), - yield return new TestCaseData( + new TestCaseData( Ordinates.None, new Polygon( new LinearRing(new[] @@ -44,29 +51,10 @@ public static IEnumerable TestCases { new Coordinate(1d, 1d) }) ), - "st_makepolygon(st_makeline(ARRAY[st_makepoint(1,1),st_makepoint(2,2),st_makepoint(3,3),st_makepoint(1,1)]))" - ); - - yield return new TestCaseData( - Ordinates.None, - new MultiPoint(new[] { new Point(new Coordinate(1d, 1d)) }), - "st_multi(st_makepoint(1, 1))" - ); + 
"st_makepolygon(st_makeline(ARRAY[st_makepoint(1,1),st_makepoint(2,2),st_makepoint(3,3),st_makepoint(1,1)]))") + .SetName("Polygon"), - yield return new TestCaseData( - Ordinates.None, - new MultiLineString(new[] - { - new LineString(new[] - { - new Coordinate(1d, 1d), - new Coordinate(1d, 2500d) - }) - }), - "st_multi(st_makeline(st_makepoint(1,1),st_makepoint(1,2500)))" - ); - - yield return new TestCaseData( + new TestCaseData( Ordinates.None, new MultiPolygon(new[] { @@ -80,16 +68,13 @@ public static IEnumerable TestCases { }) ) }), - "st_multi(st_makepolygon(st_makeline(ARRAY[st_makepoint(1,1),st_makepoint(2,2),st_makepoint(3,3),st_makepoint(1,1)])))" - ); + "st_multi(st_makepolygon(st_makeline(ARRAY[st_makepoint(1,1),st_makepoint(2,2),st_makepoint(3,3),st_makepoint(1,1)])))") + .SetName("MultiPolygon"), - yield return new TestCaseData( - Ordinates.None, - GeometryCollection.Empty, - "st_geomfromtext('GEOMETRYCOLLECTION EMPTY')" - ); + new TestCaseData(Ordinates.None, GeometryCollection.Empty, "st_geomfromtext('GEOMETRYCOLLECTION EMPTY')") + .SetName("EmptyCollection"), - yield return new TestCaseData( + new TestCaseData( Ordinates.None, new GeometryCollection(new Geometry[] { @@ -107,10 +92,10 @@ public static IEnumerable TestCases { ) }) }), - "st_collect(st_makepoint(1,1),st_multi(st_makepolygon(st_makeline(ARRAY[st_makepoint(1,1),st_makepoint(2,2),st_makepoint(3,3),st_makepoint(1,1)]))))" - ); + "st_collect(st_makepoint(1,1),st_multi(st_makepolygon(st_makeline(ARRAY[st_makepoint(1,1),st_makepoint(2,2),st_makepoint(3,3),st_makepoint(1,1)]))))") + .SetName("Collection"), - yield return new TestCaseData( + new TestCaseData( Ordinates.None, new GeometryCollection(new Geometry[] { @@ -132,26 +117,26 @@ public static IEnumerable TestCases { }) }) }), - "st_collect(st_makepoint(1,1),st_collect(st_makepoint(1,1),st_multi(st_makepolygon(st_makeline(ARRAY[st_makepoint(1,1),st_makepoint(2,2),st_makepoint(3,3),st_makepoint(1,1)])))))" - ); + 
"st_collect(st_makepoint(1,1),st_collect(st_makepoint(1,1),st_multi(st_makepolygon(st_makeline(ARRAY[st_makepoint(1,1),st_makepoint(2,2),st_makepoint(3,3),st_makepoint(1,1)])))))") + .SetName("CollectionNested"), - yield return new TestCaseData(Ordinates.XYZ, new Point(1d, 2d, 3d), "st_makepoint(1,2,3)"); + new TestCaseData(Ordinates.XYZ, new Point(1d, 2d, 3d), "st_makepoint(1,2,3)") + .SetName("PointXYZ"), - yield return new TestCaseData( + new TestCaseData( Ordinates.XYZM, new Point( new DotSpatialAffineCoordinateSequence(new[] { 1d, 2d }, new[] { 3d }, new[] { 4d }), GeometryFactory.Default), - "st_makepoint(1,2,3,4)" - ); - } - } + "st_makepoint(1,2,3,4)") + .SetName("PointXYZM") + }; [Test, TestCaseSource(nameof(TestCases))] public async Task Read(Ordinates ordinates, Geometry geometry, string sqlRepresentation) { - using var conn = await OpenConnectionAsync(); - using var cmd = conn.CreateCommand(); + await using var conn = await OpenConnectionAsync(); + await using var cmd = conn.CreateCommand(); cmd.CommandText = $"SELECT {sqlRepresentation}"; Assert.That(Equals(cmd.ExecuteScalar(), geometry)); } @@ -159,8 +144,8 @@ public async Task Read(Ordinates ordinates, Geometry geometry, string sqlReprese [Test, TestCaseSource(nameof(TestCases))] public async Task Write(Ordinates ordinates, Geometry geometry, string sqlRepresentation) { - using var conn = await OpenConnectionAsync(handleOrdinates: ordinates); - using var cmd = conn.CreateCommand(); + await using var conn = await OpenConnectionAsync(handleOrdinates: ordinates); + await using var cmd = conn.CreateCommand(); cmd.Parameters.AddWithValue("p1", geometry); cmd.CommandText = $"SELECT st_asewkb(@p1) = st_asewkb({sqlRepresentation})"; Assert.That(cmd.ExecuteScalar(), Is.True); @@ -172,7 +157,7 @@ public async Task Array() var point = new Point(new Coordinate(1d, 1d)); await AssertType( - NtsDataSource, + DataSource, new Geometry[] { point }, '{' + GetSqlLiteral(point) + '}', "geometry[]", @@ -183,9 +168,9 @@ 
await AssertType( [Test] public async Task Read_as_concrete_type() { - using var conn = await OpenConnectionAsync(); - using var cmd = new NpgsqlCommand("SELECT st_makepoint(1,1)", conn); - using var reader = cmd.ExecuteReader(); + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT st_makepoint(1,1)", conn); + await using var reader = cmd.ExecuteReader(); reader.Read(); Assert.That(reader.GetFieldValue(0), Is.EqualTo(new Point(new Coordinate(1d, 1d)))); Assert.That(() => reader.GetFieldValue(0), Throws.Exception.TypeOf()); @@ -195,16 +180,16 @@ public async Task Read_as_concrete_type() public async Task Roundtrip_geometry_geography() { var point = new Point(new Coordinate(1d, 1d)); - using var conn = await OpenConnectionAsync(); - conn.ExecuteNonQuery("CREATE TEMP TABLE data (geom GEOMETRY, geog GEOGRAPHY)"); - using (var cmd = new NpgsqlCommand("INSERT INTO data (geom, geog) VALUES (@p, @p)", conn)) + await using var conn = await OpenConnectionAsync(); + await conn.ExecuteNonQueryAsync("CREATE TEMP TABLE data (geom GEOMETRY, geog GEOGRAPHY)"); + await using (var cmd = new NpgsqlCommand("INSERT INTO data (geom, geog) VALUES (@p, @p)", conn)) { cmd.Parameters.AddWithValue("@p", point); cmd.ExecuteNonQuery(); } - using (var cmd = new NpgsqlCommand("SELECT geom, geog FROM data", conn)) - using (var reader = cmd.ExecuteReader()) + await using (var cmd = new NpgsqlCommand("SELECT geom, geog FROM data", conn)) + await using (var reader = cmd.ExecuteReader()) { reader.Read(); Assert.That(reader[0], Is.EqualTo(point)); @@ -215,7 +200,7 @@ public async Task Roundtrip_geometry_geography() [Test, Explicit] public async Task Concurrency_test() { - await using var adminConnection = OpenConnection(); + await using var adminConnection = await OpenConnectionAsync(); var table = await CreateTempTable( adminConnection, "point GEOMETRY, linestring GEOMETRY, polygon GEOMETRY, " + @@ -324,7 +309,7 @@ protected ValueTask 
OpenConnectionAsync(string? connectionStri }); if (handleOrdinates == Ordinates.XY) - NtsDataSource = dataSource; + _xyDataSource ??= dataSource; return dataSource.OpenConnectionAsync(); } @@ -343,6 +328,8 @@ public async Task SetUp() public async Task Teardown() => await Task.WhenAll(NtsDataSources.Values.Select(async ds => await ds.DisposeAsync())); - NpgsqlDataSource NtsDataSource = default!; + protected override NpgsqlDataSource DataSource => _xyDataSource ?? throw new InvalidOperationException(); + NpgsqlDataSource? _xyDataSource; + ConcurrentDictionary NtsDataSources = new(); } diff --git a/test/Npgsql.NodaTime.Tests/NodaTimeInfinityTests.cs b/test/Npgsql.PluginTests/NodaTimeInfinityTests.cs similarity index 78% rename from test/Npgsql.NodaTime.Tests/NodaTimeInfinityTests.cs rename to test/Npgsql.PluginTests/NodaTimeInfinityTests.cs index b719449e1d..59f581e7de 100644 --- a/test/Npgsql.NodaTime.Tests/NodaTimeInfinityTests.cs +++ b/test/Npgsql.PluginTests/NodaTimeInfinityTests.cs @@ -7,32 +7,50 @@ using NUnit.Framework; using static Npgsql.NodaTime.Internal.NodaTimeUtils; -namespace Npgsql.NodaTime.Tests; +namespace Npgsql.PluginTests; [TestFixture(false)] #if DEBUG [TestFixture(true)] -[NonParallelizable] +[NonParallelizable] // Since this test suite manipulates an AppContext switch #endif -public class NodaTimeInfinityTests : TestBase +public class NodaTimeInfinityTests : TestBase, IDisposable { [Test] // #4715 public async Task DateRange_with_upper_bound_infinity() { - if (DisableDateTimeInfinityConversions) + if (Statics.DisableDateTimeInfinityConversions) return; await AssertType( new DateInterval(LocalDate.MinIsoValue, LocalDate.MaxIsoValue), "[-infinity,infinity]", "daterange", - NpgsqlDbType.DateRange); + NpgsqlDbType.DateRange, + isNpgsqlDbTypeInferredFromClrType: false, skipArrayCheck: true); // NpgsqlRange[] is mapped to multirange by default, not array; test separately + + await AssertType( + new [] {new DateInterval(LocalDate.MinIsoValue, 
LocalDate.MaxIsoValue)}, + """{"[-infinity,infinity]"}""", + "daterange[]", + NpgsqlDbType.DateRange | NpgsqlDbType.Array, + isDefault: false, skipArrayCheck: true); + + await using var conn = await OpenConnectionAsync(); + if (conn.PostgreSqlVersion < new Version(14, 0)) + return; + + await AssertType( + new [] {new DateInterval(LocalDate.MinIsoValue, LocalDate.MaxIsoValue)}, + """{[-infinity,infinity]}""", + "datemultirange", + NpgsqlDbType.DateMultirange, isNpgsqlDbTypeInferredFromClrType: false, skipArrayCheck: true); } [Test] public async Task Timestamptz_read_values() { - if (DisableDateTimeInfinityConversions) + if (Statics.DisableDateTimeInfinityConversions) return; await using var conn = await OpenConnectionAsync(); @@ -50,7 +68,7 @@ public async Task Timestamptz_read_values() [Test] public async Task Timestamptz_write_values() { - if (DisableDateTimeInfinityConversions) + if (Statics.DisableDateTimeInfinityConversions) return; await using var conn = await OpenConnectionAsync(); @@ -83,7 +101,7 @@ public async Task Timestamptz_write() Parameters = { new() { Value = Instant.MinValue, NpgsqlDbType = NpgsqlDbType.TimestampTz } } }; - if (DisableDateTimeInfinityConversions) + if (Statics.DisableDateTimeInfinityConversions) { // NodaTime Instant.MinValue is outside the PG timestamp range. Assert.That(async () => await cmd.ExecuteScalarAsync(), @@ -100,7 +118,7 @@ public async Task Timestamptz_write() Parameters = { new() { Value = Instant.MaxValue, NpgsqlDbType = NpgsqlDbType.TimestampTz } } }; - Assert.That(await cmd2.ExecuteScalarAsync(), Is.EqualTo(DisableDateTimeInfinityConversions ? "9999-12-31 23:59:59.999999" : "infinity")); + Assert.That(await cmd2.ExecuteScalarAsync(), Is.EqualTo(Statics.DisableDateTimeInfinityConversions ? 
"9999-12-31 23:59:59.999999" : "infinity")); } [Test] @@ -113,7 +131,7 @@ public async Task Timestamptz_read() await using var reader = await cmd.ExecuteReaderAsync(); await reader.ReadAsync(); - if (DisableDateTimeInfinityConversions) + if (Statics.DisableDateTimeInfinityConversions) { Assert.That(() => reader[0], Throws.Exception.TypeOf()); Assert.That(() => reader[1], Throws.Exception.TypeOf()); @@ -130,14 +148,12 @@ public async Task Timestamp_write() { await using var conn = await OpenConnectionAsync(); - // TODO: Switch to use LocalDateTime.MinMaxValue when available (#4061) - await using var cmd = new NpgsqlCommand("SELECT $1::text", conn) { - Parameters = { new() { Value = LocalDate.MinIsoValue + LocalTime.MinValue, NpgsqlDbType = NpgsqlDbType.Timestamp } } + Parameters = { new() { Value = LocalDateTime.MinIsoValue, NpgsqlDbType = NpgsqlDbType.Timestamp } } }; - if (DisableDateTimeInfinityConversions) + if (Statics.DisableDateTimeInfinityConversions) { // NodaTime LocalDateTime.MinValue is outside the PG timestamp range. Assert.That(async () => await cmd.ExecuteScalarAsync(), @@ -151,10 +167,10 @@ public async Task Timestamp_write() await using var cmd2 = new NpgsqlCommand("SELECT $1::text", conn) { - Parameters = { new() { Value = LocalDate.MaxIsoValue + LocalTime.MaxValue, NpgsqlDbType = NpgsqlDbType.Timestamp } } + Parameters = { new() { Value = LocalDateTime.MaxIsoValue, NpgsqlDbType = NpgsqlDbType.Timestamp } } }; - Assert.That(await cmd2.ExecuteScalarAsync(), Is.EqualTo(DisableDateTimeInfinityConversions + Assert.That(await cmd2.ExecuteScalarAsync(), Is.EqualTo(Statics.DisableDateTimeInfinityConversions ? 
"9999-12-31 23:59:59.999999" : "infinity")); } @@ -169,16 +185,15 @@ public async Task Timestamp_read() await using var reader = await cmd.ExecuteReaderAsync(); await reader.ReadAsync(); - if (DisableDateTimeInfinityConversions) + if (Statics.DisableDateTimeInfinityConversions) { Assert.That(() => reader[0], Throws.Exception.TypeOf()); Assert.That(() => reader[1], Throws.Exception.TypeOf()); } else { - // TODO: Switch to use LocalDateTime.MinMaxValue when available (#4061) - Assert.That(reader[0], Is.EqualTo(LocalDate.MinIsoValue + LocalTime.MinValue)); - Assert.That(reader[1], Is.EqualTo(LocalDate.MaxIsoValue + LocalTime.MaxValue)); + Assert.That(reader[0], Is.EqualTo(LocalDateTime.MinIsoValue)); + Assert.That(reader[1], Is.EqualTo(LocalDateTime.MaxIsoValue)); } } @@ -193,7 +208,7 @@ public async Task Date_write() }; // LocalDate.MinIsoValue is outside of the PostgreSQL date range - if (DisableDateTimeInfinityConversions) + if (Statics.DisableDateTimeInfinityConversions) Assert.That(async () => await cmd.ExecuteScalarAsync(), Throws.Exception.TypeOf() .With.Property(nameof(PostgresException.SqlState)).EqualTo(PostgresErrorCodes.DatetimeFieldOverflow)); @@ -202,7 +217,7 @@ public async Task Date_write() cmd.Parameters[0].Value = LocalDate.MaxIsoValue; - Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(DisableDateTimeInfinityConversions ? "9999-12-31" : "infinity")); + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(Statics.DisableDateTimeInfinityConversions ? 
"9999-12-31" : "infinity")); } [Test] @@ -215,7 +230,7 @@ public async Task Date_read() await using var reader = await cmd.ExecuteReaderAsync(); await reader.ReadAsync(); - if (DisableDateTimeInfinityConversions) + if (Statics.DisableDateTimeInfinityConversions) { Assert.That(() => reader[0], Throws.Exception.TypeOf()); Assert.That(() => reader[1], Throws.Exception.TypeOf()); @@ -230,7 +245,7 @@ public async Task Date_read() [Test, Description("Makes sure that when ConvertInfinityDateTime is true, infinity values are properly converted")] public async Task DateConvertInfinity() { - if (DisableDateTimeInfinityConversions) + if (Statics.DisableDateTimeInfinityConversions) return; await using var conn = await OpenConnectionAsync(); @@ -266,20 +281,11 @@ public async Task DateConvertInfinity() } } - protected override async ValueTask OpenConnectionAsync() - { - var conn = await base.OpenConnectionAsync(); - await conn.ExecuteNonQueryAsync("SET TimeZone='Europe/Berlin'"); - return conn; - } - - protected override NpgsqlConnection OpenConnection() - => throw new NotSupportedException(); + protected override NpgsqlDataSource DataSource { get; } public NodaTimeInfinityTests(bool disableDateTimeInfinityConversions) { #if DEBUG - DisableDateTimeInfinityConversions = disableDateTimeInfinityConversions; Statics.DisableDateTimeInfinityConversions = disableDateTimeInfinityConversions; #else if (disableDateTimeInfinityConversions) @@ -288,13 +294,19 @@ public NodaTimeInfinityTests(bool disableDateTimeInfinityConversions) "NodaTimeInfinityTests rely on the Npgsql.DisableDateTimeInfinityConversions AppContext switch and can only be run in DEBUG builds"); } #endif + + var builder = CreateDataSourceBuilder(); + builder.UseNodaTime(); + builder.ConnectionStringBuilder.Options = "-c TimeZone=Europe/Berlin"; + DataSource = builder.Build(); } public void Dispose() { #if DEBUG - DisableDateTimeInfinityConversions = false; Statics.DisableDateTimeInfinityConversions = false; #endif + + 
DataSource.Dispose(); } } diff --git a/test/Npgsql.NodaTime.Tests/NodaTimeTests.cs b/test/Npgsql.PluginTests/NodaTimeTests.cs similarity index 64% rename from test/Npgsql.NodaTime.Tests/NodaTimeTests.cs rename to test/Npgsql.PluginTests/NodaTimeTests.cs index 1aa6784261..adccd163cc 100644 --- a/test/Npgsql.NodaTime.Tests/NodaTimeTests.cs +++ b/test/Npgsql.PluginTests/NodaTimeTests.cs @@ -2,6 +2,7 @@ using System.Data; using System.Threading.Tasks; using NodaTime; +using Npgsql.NodaTime.Properties; using Npgsql.Tests; using NpgsqlTypes; using NUnit.Framework; @@ -10,10 +11,9 @@ // ReSharper disable AccessToModifiedClosure // ReSharper disable AccessToDisposedClosure -namespace Npgsql.NodaTime.Tests; +namespace Npgsql.PluginTests; -// Since this test suite manipulates TimeZone, it is incompatible with multiplexing -public class NodaTimeTests : TestBase +public class NodaTimeTests : MultiplexingTestBase, IDisposable { #region Timestamp without time zone @@ -29,7 +29,8 @@ public class NodaTimeTests : TestBase [Test, TestCaseSource(nameof(TimestampValues))] public Task Timestamp_as_LocalDateTime(LocalDateTime localDateTime, string sqlLiteral) - => AssertType(localDateTime, sqlLiteral, "timestamp without time zone", NpgsqlDbType.Timestamp, DbType.DateTime2); + => AssertType(localDateTime, sqlLiteral, "timestamp without time zone", NpgsqlDbType.Timestamp, DbType.DateTime2, + isNpgsqlDbTypeInferredFromClrType: false); [Test] public Task Timestamp_as_unspecified_DateTime() @@ -81,17 +82,41 @@ public Task Timestamp_cannot_use_as_DateTimeOffset() [Test] public Task Timestamp_cannot_write_utc_DateTime() - => AssertTypeUnsupportedWrite(new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Utc), "timestamp without time zone"); + => AssertTypeUnsupportedWrite(new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Utc), "timestamp without time zone"); [Test] - public Task Tsrange_as_NpgsqlRange_of_LocalDateTime() - => AssertType( + public async Task 
Tsrange_as_NpgsqlRange_of_LocalDateTime() + { + await AssertType( new NpgsqlRange( new(1998, 4, 12, 13, 26, 38), new(1998, 4, 12, 15, 26, 38)), - @"[""1998-04-12 13:26:38"",""1998-04-12 15:26:38""]", + """["1998-04-12 13:26:38","1998-04-12 15:26:38"]""", "tsrange", - NpgsqlDbType.TimestampRange); + NpgsqlDbType.TimestampRange, + isNpgsqlDbTypeInferredFromClrType: false, skipArrayCheck: true); // NpgsqlRange[] is mapped to multirange by default, not array; test separately + + await AssertType( + new [] { new NpgsqlRange( + new(1998, 4, 12, 13, 26, 38), + new(1998, 4, 12, 15, 26, 38)), }, + """{"[\"1998-04-12 13:26:38\",\"1998-04-12 15:26:38\"]"}""", + "tsrange[]", + NpgsqlDbType.TimestampRange | NpgsqlDbType.Array, + isDefault: false, skipArrayCheck: true); + + await using var conn = await OpenConnectionAsync(); + if (conn.PostgreSqlVersion < new Version(14, 0)) + return; + + await AssertType( + new [] { new NpgsqlRange( + new(1998, 4, 12, 13, 26, 38), + new(1998, 4, 12, 15, 26, 38)), }, + """{["1998-04-12 13:26:38","1998-04-12 15:26:38"]}""", + "tsmultirange", + NpgsqlDbType.TimestampMultirange, isNpgsqlDbTypeInferredFromClrType: false, skipArrayCheck: true); + } [Test] public async Task Tsmultirange_as_array_of_NpgsqlRange_of_LocalDateTime() @@ -109,9 +134,10 @@ await AssertType( new(1998, 4, 13, 13, 26, 38), new(1998, 4, 13, 15, 26, 38)), }, - @"{[""1998-04-12 13:26:38"",""1998-04-12 15:26:38""],[""1998-04-13 13:26:38"",""1998-04-13 15:26:38""]}", + """{["1998-04-12 13:26:38","1998-04-12 15:26:38"],["1998-04-13 13:26:38","1998-04-13 15:26:38"]}""", "tsmultirange", - NpgsqlDbType.TimestampMultirange); + NpgsqlDbType.TimestampMultirange, + isNpgsqlDbTypeInferredFromClrType: false); } #endregion Timestamp without time zone @@ -132,7 +158,8 @@ await AssertType( [Test, TestCaseSource(nameof(TimestamptzValues))] public Task Timestamptz_as_Instant(Instant instant, string sqlLiteral) - => AssertType(instant, sqlLiteral, "timestamp with time zone", 
NpgsqlDbType.TimestampTz, DbType.DateTime); + => AssertType(instant, sqlLiteral, "timestamp with time zone", NpgsqlDbType.TimestampTz, DbType.DateTime, + isNpgsqlDbTypeInferredFromClrType: false); [Test] public Task Timestamptz_as_ZonedDateTime() @@ -142,6 +169,7 @@ public Task Timestamptz_as_ZonedDateTime() "timestamp with time zone", NpgsqlDbType.TimestampTz, DbType.DateTime, + isNpgsqlDbTypeInferredFromClrType: false, isDefaultForReading: false); [Test] @@ -152,6 +180,7 @@ public Task Timestamptz_as_OffsetDateTime() "timestamp with time zone", NpgsqlDbType.TimestampTz, DbType.DateTime, + isNpgsqlDbTypeInferredFromClrType: false, isDefaultForReading: false); [Test] @@ -190,56 +219,81 @@ public Task Timestamptz_cannot_use_as_LocalDateTime() [Test] public async Task Timestamptz_cannot_write_non_utc_ZonedDateTime() - => await AssertTypeUnsupportedWrite( + => await AssertTypeUnsupportedWrite( new LocalDateTime().InUtc().ToInstant().InZone(DateTimeZoneProviders.Tzdb["Europe/Berlin"]), "timestamp with time zone"); [Test] public async Task Timestamptz_cannot_write_non_utc_OffsetDateTime() - => await AssertTypeUnsupportedWrite(new LocalDateTime().WithOffset(Offset.FromHours(2)), "timestamp with time zone"); + => await AssertTypeUnsupportedWrite(new LocalDateTime().WithOffset(Offset.FromHours(2)), "timestamp with time zone"); [Test] public async Task Timestamptz_cannot_write_non_utc_DateTime() { - await AssertTypeUnsupportedWrite(new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Unspecified), "timestamp with time zone"); - await AssertTypeUnsupportedWrite(new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Local), "timestamp with time zone"); + await AssertTypeUnsupportedWrite(new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Unspecified), "timestamp with time zone"); + await AssertTypeUnsupportedWrite(new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Local), "timestamp with time zone"); } [Test] - public Task Tstzrange_as_Interval() - => AssertType( + public async Task 
Tstzrange_as_Interval() + { + await AssertType( new Interval( new LocalDateTime(1998, 4, 12, 13, 26, 38).InUtc().ToInstant(), new LocalDateTime(1998, 4, 12, 15, 26, 38).InUtc().ToInstant()), - @"[""1998-04-12 15:26:38+02"",""1998-04-12 17:26:38+02"")", + """["1998-04-12 15:26:38+02","1998-04-12 17:26:38+02")""", "tstzrange", - NpgsqlDbType.TimestampTzRange); + NpgsqlDbType.TimestampTzRange, + isNpgsqlDbTypeInferredFromClrType: false, skipArrayCheck: true); // NpgsqlRange[] is mapped to multirange by default, not array; test separately + + await AssertType( + new [] { new Interval( + new LocalDateTime(1998, 4, 12, 13, 26, 38).InUtc().ToInstant(), + new LocalDateTime(1998, 4, 12, 15, 26, 38).InUtc().ToInstant()), }, + """{"[\"1998-04-12 15:26:38+02\",\"1998-04-12 17:26:38+02\")"}""", + "tstzrange[]", + NpgsqlDbType.TimestampTzRange | NpgsqlDbType.Array, + isDefault: false, skipArrayCheck: true); + + await using var conn = await OpenConnectionAsync(); + if (conn.PostgreSqlVersion < new Version(14, 0)) + return; + + await AssertType( + new [] { new Interval( + new LocalDateTime(1998, 4, 12, 13, 26, 38).InUtc().ToInstant(), + new LocalDateTime(1998, 4, 12, 15, 26, 38).InUtc().ToInstant()), }, + """{["1998-04-12 15:26:38+02","1998-04-12 17:26:38+02")}""", + "tstzmultirange", + NpgsqlDbType.TimestampTzMultirange, isNpgsqlDbTypeInferredFromClrType: false, skipArrayCheck: true); + } [Test] public Task Tstzrange_with_no_end_as_Interval() => AssertType( - new Interval( - new LocalDateTime(1998, 4, 12, 13, 26, 38).InUtc().ToInstant(), null), - @"[""1998-04-12 15:26:38+02"",)", + new Interval(new LocalDateTime(1998, 4, 12, 13, 26, 38).InUtc().ToInstant(), null), + """["1998-04-12 15:26:38+02",)""", "tstzrange", - NpgsqlDbType.TimestampTzRange); + NpgsqlDbType.TimestampTzRange, + isNpgsqlDbTypeInferredFromClrType: false, skipArrayCheck: true); [Test] public Task Tstzrange_with_no_start_as_Interval() => AssertType( - new Interval( null, - new LocalDateTime(1998, 4, 12, 13, 26, 
38).InUtc().ToInstant()), - @"(,""1998-04-12 15:26:38+02"")", + new Interval(null, new LocalDateTime(1998, 4, 12, 13, 26, 38).InUtc().ToInstant()), + """(,"1998-04-12 15:26:38+02")""", "tstzrange", - NpgsqlDbType.TimestampTzRange); + NpgsqlDbType.TimestampTzRange, + isNpgsqlDbTypeInferredFromClrType: false, skipArrayCheck: true); [Test] public Task Tstzrange_with_no_start_or_end_as_Interval() => AssertType( new Interval(null, null), - @"(,)", + """(,)""", "tstzrange", - NpgsqlDbType.TimestampTzRange); + NpgsqlDbType.TimestampTzRange, + isNpgsqlDbTypeInferredFromClrType: false, skipArrayCheck: true); [Test] public Task Tstzrange_as_NpgsqlRange_of_Instant() @@ -247,10 +301,11 @@ public Task Tstzrange_as_NpgsqlRange_of_Instant() new NpgsqlRange( new LocalDateTime(1998, 4, 12, 13, 26, 38).InUtc().ToInstant(), new LocalDateTime(1998, 4, 12, 15, 26, 38).InUtc().ToInstant()), - @"[""1998-04-12 15:26:38+02"",""1998-04-12 17:26:38+02""]", + """["1998-04-12 15:26:38+02","1998-04-12 17:26:38+02"]""", "tstzrange", NpgsqlDbType.TimestampTzRange, - isDefaultForReading: false); + isNpgsqlDbTypeInferredFromClrType: false, + isDefaultForReading: false, skipArrayCheck: true); [Test] public Task Tstzrange_as_NpgsqlRange_of_ZonedDateTime() @@ -258,10 +313,11 @@ public Task Tstzrange_as_NpgsqlRange_of_ZonedDateTime() new NpgsqlRange( new LocalDateTime(1998, 4, 12, 13, 26, 38).InUtc(), new LocalDateTime(1998, 4, 12, 15, 26, 38).InUtc()), - @"[""1998-04-12 15:26:38+02"",""1998-04-12 17:26:38+02""]", + """["1998-04-12 15:26:38+02","1998-04-12 17:26:38+02"]""", "tstzrange", NpgsqlDbType.TimestampTzRange, - isDefaultForReading: false); + isNpgsqlDbTypeInferredFromClrType: false, + isDefaultForReading: false, skipArrayCheck: true); [Test] public Task Tstzrange_as_NpgsqlRange_of_OffsetDateTime() @@ -269,10 +325,11 @@ public Task Tstzrange_as_NpgsqlRange_of_OffsetDateTime() new NpgsqlRange( new LocalDateTime(1998, 4, 12, 13, 26, 38).WithOffset(Offset.Zero), new LocalDateTime(1998, 4, 12, 15, 
26, 38).WithOffset(Offset.Zero)), - @"[""1998-04-12 15:26:38+02"",""1998-04-12 17:26:38+02""]", + """["1998-04-12 15:26:38+02","1998-04-12 17:26:38+02"]""", "tstzrange", NpgsqlDbType.TimestampTzRange, - isDefaultForReading: false); + isNpgsqlDbTypeInferredFromClrType: false, + isDefaultForReading: false, skipArrayCheck: true); [Test] public async Task Tstzmultirange_as_array_of_Interval() @@ -290,9 +347,10 @@ await AssertType( new LocalDateTime(1998, 4, 13, 13, 26, 38).InUtc().ToInstant(), new LocalDateTime(1998, 4, 13, 15, 26, 38).InUtc().ToInstant()), }, - @"{[""1998-04-12 15:26:38+02"",""1998-04-12 17:26:38+02""),[""1998-04-13 15:26:38+02"",""1998-04-13 17:26:38+02"")}", + """{["1998-04-12 15:26:38+02","1998-04-12 17:26:38+02"),["1998-04-13 15:26:38+02","1998-04-13 17:26:38+02")}""", "tstzmultirange", - NpgsqlDbType.TimestampTzMultirange); + NpgsqlDbType.TimestampTzMultirange, + isNpgsqlDbTypeInferredFromClrType: false); } [Test] @@ -311,9 +369,10 @@ await AssertType( new LocalDateTime(1998, 4, 13, 13, 26, 38).InUtc().ToInstant(), new LocalDateTime(1998, 4, 13, 15, 26, 38).InUtc().ToInstant()), }, - @"{[""1998-04-12 15:26:38+02"",""1998-04-12 17:26:38+02""],[""1998-04-13 15:26:38+02"",""1998-04-13 17:26:38+02""]}", + """{["1998-04-12 15:26:38+02","1998-04-12 17:26:38+02"],["1998-04-13 15:26:38+02","1998-04-13 17:26:38+02"]}""", "tstzmultirange", NpgsqlDbType.TimestampTzMultirange, + isNpgsqlDbTypeInferredFromClrType: false, isDefaultForReading: false); } @@ -333,9 +392,10 @@ await AssertType( new LocalDateTime(1998, 4, 13, 13, 26, 38).InUtc(), new LocalDateTime(1998, 4, 13, 15, 26, 38).InUtc()), }, - @"{[""1998-04-12 15:26:38+02"",""1998-04-12 17:26:38+02""],[""1998-04-13 15:26:38+02"",""1998-04-13 17:26:38+02""]}", + """{["1998-04-12 15:26:38+02","1998-04-12 17:26:38+02"],["1998-04-13 15:26:38+02","1998-04-13 17:26:38+02"]}""", "tstzmultirange", NpgsqlDbType.TimestampTzMultirange, + isNpgsqlDbTypeInferredFromClrType: false, isDefaultForReading: false); } @@ 
-355,9 +415,10 @@ await AssertType( new LocalDateTime(1998, 4, 13, 13, 26, 38).WithOffset(Offset.Zero), new LocalDateTime(1998, 4, 13, 15, 26, 38).WithOffset(Offset.Zero)), }, - @"{[""1998-04-12 15:26:38+02"",""1998-04-12 17:26:38+02""],[""1998-04-13 15:26:38+02"",""1998-04-13 17:26:38+02""]}", + """{["1998-04-12 15:26:38+02","1998-04-12 17:26:38+02"],["1998-04-13 15:26:38+02","1998-04-13 17:26:38+02"]}""", "tstzmultirange", NpgsqlDbType.TimestampTzMultirange, + isNpgsqlDbTypeInferredFromClrType: false, isDefaultForReading: false); } @@ -385,9 +446,10 @@ await AssertType( null, null) }, - @"{""[\""1998-04-12 15:26:38+02\"",\""1998-04-12 17:26:38+02\"")"",""[\""1998-04-13 15:26:38+02\"",\""1998-04-13 17:26:38+02\"")"",""[\""1998-04-13 15:26:38+02\"",)"",""(,\""1998-04-13 15:26:38+02\"")"",""(,)""}", + """{"[\"1998-04-12 15:26:38+02\",\"1998-04-12 17:26:38+02\")","[\"1998-04-13 15:26:38+02\",\"1998-04-13 17:26:38+02\")","[\"1998-04-13 15:26:38+02\",)","(,\"1998-04-13 15:26:38+02\")","(,)"}""", "tstzrange[]", NpgsqlDbType.TimestampTzRange | NpgsqlDbType.Array, + isNpgsqlDbTypeInferredFromClrType: false, isDefaultForWriting: false); } @@ -406,9 +468,10 @@ await AssertType( new LocalDateTime(1998, 4, 13, 13, 26, 38).InUtc().ToInstant(), new LocalDateTime(1998, 4, 13, 15, 26, 38).InUtc().ToInstant()), }, - @"{""[\""1998-04-12 15:26:38+02\"",\""1998-04-12 17:26:38+02\""]"",""[\""1998-04-13 15:26:38+02\"",\""1998-04-13 17:26:38+02\""]""}", + """{"[\"1998-04-12 15:26:38+02\",\"1998-04-12 17:26:38+02\"]","[\"1998-04-13 15:26:38+02\",\"1998-04-13 17:26:38+02\"]"}""", "tstzrange[]", NpgsqlDbType.TimestampTzRange | NpgsqlDbType.Array, + isNpgsqlDbTypeInferredFromClrType: false, isDefault: false); } @@ -418,7 +481,8 @@ await AssertType( [Test] public Task Date_as_LocalDate() - => AssertType(new LocalDate(2020, 10, 1), "2020-10-01", "date", NpgsqlDbType.Date, DbType.Date); + => AssertType(new LocalDate(2020, 10, 1), "2020-10-01", "date", NpgsqlDbType.Date, DbType.Date, + 
isNpgsqlDbTypeInferredFromClrType: false); [Test] public Task Date_as_DateTime() @@ -429,21 +493,61 @@ public Task Date_as_int() => AssertType(7579, "2020-10-01", "date", NpgsqlDbType.Date, DbType.Date, isDefault: false); [Test] - public Task Daterange_as_DateInterval() - => AssertType( + public async Task Daterange_as_DateInterval() + { + await AssertType( new DateInterval(new(2002, 3, 4), new(2002, 3, 6)), "[2002-03-04,2002-03-07)", "daterange", - NpgsqlDbType.DateRange); + NpgsqlDbType.DateRange, + isNpgsqlDbTypeInferredFromClrType: false, skipArrayCheck: true); // DateInterval[] is mapped to multirange by default, not array; test separately + + await AssertType( + new [] {new DateInterval(new(2002, 3, 4), new(2002, 3, 6))}, + """{"[2002-03-04,2002-03-07)"}""", + "daterange[]", + NpgsqlDbType.DateRange | NpgsqlDbType.Array, + isDefault: false, skipArrayCheck: true); + + await using var conn = await OpenConnectionAsync(); + if (conn.PostgreSqlVersion < new Version(14, 0)) + return; + + await AssertType( + new [] {new DateInterval(new(2002, 3, 4), new(2002, 3, 6))}, + """{[2002-03-04,2002-03-07)}""", + "datemultirange", + NpgsqlDbType.DateMultirange, isNpgsqlDbTypeInferredFromClrType: false, skipArrayCheck: true); + } [Test] - public Task Daterange_as_NpgsqlRange_of_LocalDate() - => AssertType( + public async Task Daterange_as_NpgsqlRange_of_LocalDate() + { + await AssertType( new NpgsqlRange(new(2002, 3, 4), true, new(2002, 3, 6), false), "[2002-03-04,2002-03-06)", "daterange", NpgsqlDbType.DateRange, - isDefaultForReading: false); + isNpgsqlDbTypeInferredFromClrType: false, + isDefaultForReading: false, skipArrayCheck: true); // NpgsqlRange[] is mapped to multirange by default, not array; test separately + + await AssertType( + new [] { new NpgsqlRange(new(2002, 3, 4), true, new(2002, 3, 6), false) }, + """{"[2002-03-04,2002-03-06)"}""", + "daterange[]", + NpgsqlDbType.DateRange | NpgsqlDbType.Array, + isDefault: false, skipArrayCheck: true); + + await using var 
conn = await OpenConnectionAsync(); + if (conn.PostgreSqlVersion < new Version(14, 0)) + return; + + await AssertType( + new [] { new NpgsqlRange(new(2002, 3, 4), true, new(2002, 3, 6), false) }, + """{[2002-03-04,2002-03-06)}""", + "datemultirange", + NpgsqlDbType.DateMultirange, isDefault: false, skipArrayCheck: true); + } [Test] public async Task Datemultirange_as_array_of_DateInterval() @@ -459,7 +563,8 @@ await AssertType( }, "{[2002-03-04,2002-03-06),[2002-03-08,2002-03-11)}", "datemultirange", - NpgsqlDbType.DateMultirange); + NpgsqlDbType.DateMultirange, + isNpgsqlDbTypeInferredFromClrType: false); } [Test] @@ -477,7 +582,8 @@ await AssertType( "{[2002-03-04,2002-03-06),[2002-03-08,2002-03-11)}", "datemultirange", NpgsqlDbType.DateMultirange, - isDefaultForReading: false); + isDefaultForReading: false, + isNpgsqlDbTypeInferredFromClrType: false); } #if NET6_0_OR_GREATER @@ -486,13 +592,32 @@ public Task Date_as_DateOnly() => AssertType(new DateOnly(2020, 10, 1), "2020-10-01", "date", NpgsqlDbType.Date, DbType.Date, isDefaultForReading: false); [Test] - public Task Daterange_as_NpgsqlRange_of_DateOnly() - => AssertType( + public async Task Daterange_as_NpgsqlRange_of_DateOnly() + { + await AssertType( new NpgsqlRange(new(2002, 3, 4), true, new(2002, 3, 6), false), "[2002-03-04,2002-03-06)", "daterange", NpgsqlDbType.DateRange, - isDefaultForReading: false); + isDefaultForReading: false, skipArrayCheck: true); + + await AssertType( + new [] { new NpgsqlRange(new(2002, 3, 4), true, new(2002, 3, 6), false) }, + """{"[2002-03-04,2002-03-06)"}""", + "daterange[]", + NpgsqlDbType.DateRange | NpgsqlDbType.Array, + isDefault: false, skipArrayCheck: true); + + await using var conn = await OpenConnectionAsync(); + if (conn.PostgreSqlVersion < new Version(14, 0)) + return; + + await AssertType( + new [] { new NpgsqlRange(new(2002, 3, 4), true, new(2002, 3, 6), false) }, + """{[2002-03-04,2002-03-06)}""", + "datemultirange", + NpgsqlDbType.DateMultirange, isDefault: 
false, skipArrayCheck: true); + } #endif [Test] @@ -506,7 +631,7 @@ await AssertType( new DateInterval(new(2002, 3, 4), new(2002, 3, 5)), new DateInterval(new(2002, 3, 8), new(2002, 3, 10)) }, - @"{""[2002-03-04,2002-03-06)"",""[2002-03-08,2002-03-11)""}", + """{"[2002-03-04,2002-03-06)","[2002-03-08,2002-03-11)"}""", "daterange[]", NpgsqlDbType.DateRange | NpgsqlDbType.Array, isDefaultForWriting: false); @@ -523,7 +648,7 @@ await AssertType( new NpgsqlRange(new(2002, 3, 4), true, new(2002, 3, 6), false), new NpgsqlRange(new(2002, 3, 8), true, new(2002, 3, 11), false) }, - @"{""[2002-03-04,2002-03-06)"",""[2002-03-08,2002-03-11)""}", + """{"[2002-03-04,2002-03-06)","[2002-03-08,2002-03-11)"}""", "daterange[]", NpgsqlDbType.DateRange | NpgsqlDbType.Array, isDefault: false); @@ -535,7 +660,8 @@ await AssertType( [Test] public Task Time_as_LocalTime() - => AssertType(new LocalTime(10, 45, 34, 500), "10:45:34.5", "time without time zone", NpgsqlDbType.Time, DbType.Time); + => AssertType(new LocalTime(10, 45, 34, 500), "10:45:34.5", "time without time zone", NpgsqlDbType.Time, DbType.Time, + isNpgsqlDbTypeInferredFromClrType: false); [Test] public Task Time_as_TimeSpan() @@ -569,7 +695,8 @@ public Task TimeTz_as_OffsetTime() new OffsetTime(new LocalTime(1, 2, 3, 4).PlusNanoseconds(5000), Offset.FromHoursAndMinutes(3, 30) + Offset.FromSeconds(5)), "01:02:03.004005+03:30:05", "time with time zone", - NpgsqlDbType.TimeTz); + NpgsqlDbType.TimeTz, + isNpgsqlDbTypeInferredFromClrType: false); [Test] public async Task TimeTz_as_DateTimeOffset() @@ -608,7 +735,8 @@ public Task Interval_as_Period() }.Build().Normalize(), "1 year 2 mons 25 days 05:06:07.008009", "interval", - NpgsqlDbType.Interval); + NpgsqlDbType.Interval, + isNpgsqlDbTypeInferredFromClrType: false); [Test] public Task Interval_as_Duration() @@ -618,24 +746,28 @@ public Task Interval_as_Duration() "5 days 00:04:03.002001", "interval", NpgsqlDbType.Interval, - isDefaultForReading: false); + isDefaultForReading: 
false, + isNpgsqlDbTypeInferredFromClrType: false); [Test] - public Task Interval_as_Duration_with_months_fails() - => AssertTypeUnsupportedRead("2 months", "interval"); + public async Task Interval_as_Duration_with_months_fails() + { + var exception = await AssertTypeUnsupportedRead("2 months", "interval"); + Assert.That(exception.Message, Is.EqualTo(NpgsqlNodaTimeStrings.CannotReadIntervalWithMonthsAsDuration)); + } [Test, IssueLink("https://github.com/npgsql/npgsql/issues/3438")] public async Task Bug3438() { await using var conn = await OpenConnectionAsync(); - using var cmd = new NpgsqlCommand("SELECT @p1, @p2", conn); + await using var cmd = new NpgsqlCommand("SELECT @p1, @p2", conn); var expected = Duration.FromSeconds(2148); cmd.Parameters.Add(new NpgsqlParameter("p1", NpgsqlDbType.Interval) { Value = expected }); cmd.Parameters.AddWithValue("p2", expected); - using var reader = cmd.ExecuteReader(); - reader.Read(); + await using var reader = await cmd.ExecuteReaderAsync(); + await reader.ReadAsync(); for (var i = 0; i < 2; i++) { Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(Period))); @@ -646,15 +778,19 @@ public async Task Bug3438() #region Support - protected override async ValueTask OpenConnectionAsync() + protected override NpgsqlDataSource DataSource { get; } + + public NodaTimeTests(MultiplexingMode multiplexingMode) + : base(multiplexingMode) { - var conn = await base.OpenConnectionAsync(); - await conn.ExecuteNonQueryAsync("SET TimeZone='Europe/Berlin'"); - return conn; + var builder = CreateDataSourceBuilder(); + builder.UseNodaTime(); + builder.ConnectionStringBuilder.Options = "-c TimeZone=Europe/Berlin"; + DataSource = builder.Build(); } - protected override NpgsqlConnection OpenConnection() - => throw new NotSupportedException(); + public void Dispose() + => DataSource.Dispose(); #endregion Support } diff --git a/test/Npgsql.PluginTests/Npgsql.PluginTests.csproj b/test/Npgsql.PluginTests/Npgsql.PluginTests.csproj index 
b7e0b21a09..30dfb8ea16 100644 --- a/test/Npgsql.PluginTests/Npgsql.PluginTests.csproj +++ b/test/Npgsql.PluginTests/Npgsql.PluginTests.csproj @@ -1,11 +1,16 @@  + + + + + diff --git a/test/Npgsql.Specification.Tests/NpgsqlDataReaderTests.cs b/test/Npgsql.Specification.Tests/NpgsqlDataReaderTests.cs index 45bfb1a197..356d1da966 100644 --- a/test/Npgsql.Specification.Tests/NpgsqlDataReaderTests.cs +++ b/test/Npgsql.Specification.Tests/NpgsqlDataReaderTests.cs @@ -1,5 +1,4 @@ using AdoNet.Specification.Tests; -using Xunit; namespace Npgsql.Specification.Tests; diff --git a/test/Npgsql.Specification.Tests/Utility.cs b/test/Npgsql.Specification.Tests/Utility.cs index 9e91767d55..51bdc18dcd 100644 --- a/test/Npgsql.Specification.Tests/Utility.cs +++ b/test/Npgsql.Specification.Tests/Utility.cs @@ -1,4 +1,3 @@ -using System; using AdoNet.Specification.Tests; namespace Npgsql.Specification.Tests; diff --git a/test/Npgsql.Tests/AuthenticationTests.cs b/test/Npgsql.Tests/AuthenticationTests.cs index b0173e36a2..487bc5457c 100644 --- a/test/Npgsql.Tests/AuthenticationTests.cs +++ b/test/Npgsql.Tests/AuthenticationTests.cs @@ -7,7 +7,6 @@ using Npgsql.Properties; using Npgsql.Tests.Support; using NUnit.Framework; -using static Npgsql.Util.Statics; using static Npgsql.Tests.TestUtil; namespace Npgsql.Tests; @@ -292,6 +291,15 @@ public void Password_source_precedence() Assert.That(() => dataSource4.OpenConnection(), Throws.Nothing); } + + static DeferDisposable Defer(Action action) => new(action); + } + + readonly struct DeferDisposable : IDisposable + { + readonly Action _action; + public DeferDisposable(Action action) => _action = action; + public void Dispose() => _action(); } [Test, Description("Connects with a bad password to ensure the proper error is thrown")] diff --git a/test/Npgsql.Tests/BatchTests.cs b/test/Npgsql.Tests/BatchTests.cs index 2983285f85..96013a9676 100644 --- a/test/Npgsql.Tests/BatchTests.cs +++ b/test/Npgsql.Tests/BatchTests.cs @@ -1,4 +1,3 @@ -using 
Npgsql.Util; using NUnit.Framework; using System; using System.Collections.Generic; diff --git a/test/Npgsql.Tests/BugTests.cs b/test/Npgsql.Tests/BugTests.cs index d702f8d0b2..6dac813475 100644 --- a/test/Npgsql.Tests/BugTests.cs +++ b/test/Npgsql.Tests/BugTests.cs @@ -9,12 +9,15 @@ using System.Threading; using System.Threading.Tasks; using System.Transactions; +using Npgsql.Internal.Postgres; using static Npgsql.Tests.TestUtil; namespace Npgsql.Tests; public class BugTests : TestBase { + static uint ByteaOid => DefaultPgTypes.DataTypeNameMap[DataTypeNames.Bytea].Value; + #region Sequential reader bugs [Test, Description("In sequential access, performing a null check on a non-first field would check the first field")] @@ -71,18 +74,6 @@ public void Many_parameters_with_mixed_FormatCode() .Or.EqualTo(PostgresErrorCodes.TooManyColumns)); // PostgreSQL 14.5, 13.8, 12.12, 11.17 and 10.22 changed the returned error } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1238")] - public void Record_with_non_int_field() - { - using var conn = OpenConnection(); - using var cmd = new NpgsqlCommand("SELECT ('one'::TEXT, 2)", conn); - using var reader = cmd.ExecuteReader(); - reader.Read(); - var record = reader.GetFieldValue(0); - Assert.That(record[0], Is.EqualTo("one")); - Assert.That(record[1], Is.EqualTo(2)); - } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1450")] public void Bug1450() { @@ -1201,7 +1192,7 @@ LANGUAGE plpgsql AS END; $$;"); - Assert.ThrowsAsync(async () => await connection.ExecuteScalarAsync($"SELECT {func}(0)")); + Assert.ThrowsAsync(async () => await connection.ExecuteScalarAsync($"SELECT {func}(0)")); } [Test] @@ -1370,7 +1361,7 @@ public async Task Bug4099() await server .WriteParseComplete() .WriteBindComplete() - .WriteRowDescription(new FieldDescription(PostgresTypeOIDs.Bytea)) + .WriteRowDescription(new FieldDescription(ByteaOid)) .WriteDataRowWithFlush(data); var otherData = new byte[10]; @@ -1379,7 +1370,7 @@ await 
server .WriteReadyForQuery() .WriteParseComplete() .WriteBindComplete() - .WriteRowDescription(new FieldDescription(PostgresTypeOIDs.Bytea)) + .WriteRowDescription(new FieldDescription(ByteaOid)) .WriteDataRow(otherData) .WriteCommandComplete() .WriteReadyForQuery() diff --git a/test/Npgsql.Tests/CommandParameterTests.cs b/test/Npgsql.Tests/CommandParameterTests.cs new file mode 100644 index 0000000000..aa2cb0ee15 --- /dev/null +++ b/test/Npgsql.Tests/CommandParameterTests.cs @@ -0,0 +1,207 @@ +using System; +using System.Data; +using System.Threading.Tasks; +using NpgsqlTypes; +using NUnit.Framework; + +namespace Npgsql.Tests; + +public class CommandParameterTests : MultiplexingTestBase +{ + [Test] + [TestCase(CommandBehavior.Default)] + [TestCase(CommandBehavior.SequentialAccess)] + public async Task Input_and_output_parameters(CommandBehavior behavior) + { + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand("SELECT @c-1 AS c, @a+2 AS b", conn); + cmd.Parameters.Add(new NpgsqlParameter("a", 3)); + var b = new NpgsqlParameter { ParameterName = "b", Direction = ParameterDirection.Output }; + cmd.Parameters.Add(b); + var c = new NpgsqlParameter { ParameterName = "c", Direction = ParameterDirection.InputOutput, Value = 4 }; + cmd.Parameters.Add(c); + using (await cmd.ExecuteReaderAsync(behavior)) + { + Assert.AreEqual(5, b.Value); + Assert.AreEqual(3, c.Value); + } + } + + [Test] + public async Task Send_NpgsqlDbType_Unknown([Values(PrepareOrNot.NotPrepared, PrepareOrNot.Prepared)] PrepareOrNot prepare) + { + if (prepare == PrepareOrNot.Prepared && IsMultiplexing) + return; + + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand("SELECT @p::TIMESTAMP", conn); + cmd.CommandText = "SELECT @p::TIMESTAMP"; + cmd.Parameters.Add(new NpgsqlParameter("p", NpgsqlDbType.Unknown) { Value = "2008-1-1" }); + if (prepare == PrepareOrNot.Prepared) + cmd.Prepare(); + using var reader = await cmd.ExecuteReaderAsync(); + 
reader.Read(); + Assert.That(reader.GetValue(0), Is.EqualTo(new DateTime(2008, 1, 1))); + } + + [Test] + public async Task Positional_parameter() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT $1", conn); + cmd.Parameters.Add(new NpgsqlParameter { NpgsqlDbType = NpgsqlDbType.Integer, Value = 8 }); + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(8)); + } + + [Test] + public async Task Positional_parameters_are_not_supported_with_legacy_batching() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT $1; SELECT $1", conn); + cmd.Parameters.Add(new NpgsqlParameter { NpgsqlDbType = NpgsqlDbType.Integer, Value = 8 }); + Assert.That(async () => await cmd.ExecuteScalarAsync(), Throws.Exception.TypeOf() + .With.Property(nameof(PostgresException.SqlState)).EqualTo(PostgresErrorCodes.SyntaxError)); + } + + [Test] + public async Task Unreferenced_named_parameter_works() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT 1", conn); + cmd.Parameters.AddWithValue("not_used", 8); + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(1)); + } + + [Test] + public async Task Unreferenced_positional_parameter_works() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT 1", conn); + cmd.Parameters.Add(new NpgsqlParameter { Value = 8 }); + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(1)); + } + + [Test] + public async Task Mixing_positional_and_named_parameters_is_not_supported() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT $1, @p", conn); + cmd.Parameters.Add(new NpgsqlParameter { Value = 8 }); + cmd.Parameters.Add(new NpgsqlParameter { ParameterName = "p", Value = 9 }); + Assert.That(() => cmd.ExecuteNonQueryAsync(), Throws.Exception.TypeOf()); + } + + [Test] + 
[IssueLink("https://github.com/npgsql/npgsql/issues/4171")] + public async Task Reuse_command_with_different_parameter_placeholder_types() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = conn.CreateCommand(); + + cmd.CommandText = "SELECT @p1"; + cmd.Parameters.AddWithValue("@p1", 8); + _ = await cmd.ExecuteScalarAsync(); + + cmd.CommandText = "SELECT $1"; + cmd.Parameters[0].ParameterName = null; + _ = await cmd.ExecuteScalarAsync(); + } + + [Test] + public async Task Positional_output_parameters_are_not_supported() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT $1", conn); + cmd.Parameters.Add(new NpgsqlParameter { Value = 8, Direction = ParameterDirection.InputOutput }); + Assert.That(() => cmd.ExecuteNonQueryAsync(), Throws.Exception.TypeOf()); + } + + [Test] + public void Parameters_get_name() + { + var command = new NpgsqlCommand(); + + // Add parameters. + command.Parameters.Add(new NpgsqlParameter(":Parameter1", DbType.Boolean)); + command.Parameters.Add(new NpgsqlParameter(":Parameter2", DbType.Int32)); + command.Parameters.Add(new NpgsqlParameter(":Parameter3", DbType.DateTime)); + command.Parameters.Add(new NpgsqlParameter("Parameter4", DbType.DateTime)); + + var idbPrmtr = command.Parameters["Parameter1"]; + Assert.IsNotNull(idbPrmtr); + command.Parameters[0].Value = 1; + + // Get by indexers. + + Assert.AreEqual(":Parameter1", command.Parameters["Parameter1"].ParameterName); + Assert.AreEqual(":Parameter2", command.Parameters["Parameter2"].ParameterName); + Assert.AreEqual(":Parameter3", command.Parameters["Parameter3"].ParameterName); + Assert.AreEqual("Parameter4", command.Parameters["Parameter4"].ParameterName); //Should this work? 
+ + Assert.AreEqual(":Parameter1", command.Parameters[0].ParameterName); + Assert.AreEqual(":Parameter2", command.Parameters[1].ParameterName); + Assert.AreEqual(":Parameter3", command.Parameters[2].ParameterName); + Assert.AreEqual("Parameter4", command.Parameters[3].ParameterName); + } + + [Test] + public async Task Same_param_multiple_times() + { + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand("SELECT @p1, @p1", conn); + cmd.Parameters.AddWithValue("@p1", 8); + using var reader = await cmd.ExecuteReaderAsync(); + reader.Read(); + Assert.That(reader[0], Is.EqualTo(8)); + Assert.That(reader[1], Is.EqualTo(8)); + } + + [Test] + public async Task Generic_parameter() + { + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand("SELECT @p1, @p2, @p3, @p4", conn); + cmd.Parameters.Add(new NpgsqlParameter("p1", 8)); + cmd.Parameters.Add(new NpgsqlParameter("p2", 8) { NpgsqlDbType = NpgsqlDbType.Integer }); + cmd.Parameters.Add(new NpgsqlParameter("p3", "hello")); + cmd.Parameters.Add(new NpgsqlParameter("p4", new[] { 'f', 'o', 'o' })); + using var reader = await cmd.ExecuteReaderAsync(); + reader.Read(); + Assert.That(reader.GetInt32(0), Is.EqualTo(8)); + Assert.That(reader.GetInt32(1), Is.EqualTo(8)); + Assert.That(reader.GetString(2), Is.EqualTo("hello")); + Assert.That(reader.GetString(3), Is.EqualTo("foo")); + } + + [Test] + [TestCase(false)] + [TestCase(true)] + public async Task Parameter_must_be_set(bool genericParam) + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT @p1::TEXT", conn); + cmd.Parameters.Add( + genericParam + ? 
new NpgsqlParameter("p1", null) + : new NpgsqlParameter("p1", null) + ); + + Assert.That(async () => await cmd.ExecuteReaderAsync(), + Throws.Exception + .TypeOf() + .With.Message.EqualTo("Parameter 'p1' must have either its NpgsqlDbType or its DataTypeName or its Value set.")); + } + + [Test] + public async Task Object_generic_param_does_runtime_lookup() + { + await AssertTypeWrite(1, "1", "integer", NpgsqlDbType.Integer, DbType.Int32, DbType.Int32, isDefault: false, + isNpgsqlDbTypeInferredFromClrType: true, skipArrayCheck: true); + await AssertTypeWrite(new[] {1, 1}, "{1,1}", "integer[]", NpgsqlDbType.Integer | NpgsqlDbType.Array, isDefault: false, + isNpgsqlDbTypeInferredFromClrType: true, skipArrayCheck: true); + } + + public CommandParameterTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) + { + } +} diff --git a/test/Npgsql.Tests/CommandTests.cs b/test/Npgsql.Tests/CommandTests.cs index 68c559509f..6133e100c5 100644 --- a/test/Npgsql.Tests/CommandTests.cs +++ b/test/Npgsql.Tests/CommandTests.cs @@ -11,12 +11,16 @@ using System.Text; using System.Threading; using System.Threading.Tasks; +using Npgsql.Internal.Postgres; using static Npgsql.Tests.TestUtil; namespace Npgsql.Tests; public class CommandTests : MultiplexingTestBase { + static uint Int4Oid => DefaultPgTypes.DataTypeNameMap[DataTypeNames.Int4].Value; + static uint TextOid => DefaultPgTypes.DataTypeNameMap[DataTypeNames.Text].Value; + #region Legacy batching [Test] @@ -126,7 +130,7 @@ public async Task Multiple_statements_large_first_command() [NonParallelizable] // Disables sql rewriting public async Task Legacy_batching_is_not_supported_when_EnableSqlParsing_is_disabled() { - using var _ = DisableSqlRewriting(); + using var _ = DisableSqlRewriting(ClearDataSources); using var conn = await OpenConnectionAsync(); using var cmd = new NpgsqlCommand("SELECT 1; SELECT 2", conn); @@ -134,6 +138,30 @@ public async Task Legacy_batching_is_not_supported_when_EnableSqlParsing_is_disa 
.With.Property(nameof(PostgresException.SqlState)).EqualTo(PostgresErrorCodes.SyntaxError)); } + [Test] + [NonParallelizable] // Disables sql rewriting + public async Task Positional_parameters_are_supported_when_EnableSqlParsing_is_disabled() + { + using var _ = DisableSqlRewriting(ClearDataSources); + + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand("SELECT $1", conn); + cmd.Parameters.Add(new NpgsqlParameter { NpgsqlDbType = NpgsqlDbType.Integer, Value = 8 }); + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(8)); + } + + [Test] + [NonParallelizable] // Disables sql rewriting + public async Task Named_parameters_are_not_supported_when_EnableSqlParsing_is_disabled() + { + using var _ = DisableSqlRewriting(ClearDataSources); + + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand("SELECT @p", conn); + cmd.Parameters.Add(new NpgsqlParameter("p", 8)); + Assert.That(async () => await cmd.ExecuteScalarAsync(), Throws.Exception.TypeOf()); + } + #endregion #region Timeout @@ -402,7 +430,7 @@ public async Task Bug3466([Values(false, true)] bool isBroken) await serverMock .WriteParseComplete() .WriteBindComplete() - .WriteRowDescription(new FieldDescription(PostgresTypeOIDs.Int4)) + .WriteRowDescription(new FieldDescription(Int4Oid)) .WriteDataRow(BitConverter.GetBytes(BinaryPrimitives.ReverseEndianness(1))) .WriteCommandComplete() .WriteReadyForQuery() @@ -537,197 +565,6 @@ public async Task SingleRow([Values(PrepareOrNot.NotPrepared, PrepareOrNot.Prepa Assert.That(reader.Read(), Is.False); } - #region Parameters - - [Test] - public async Task Positional_parameter() - { - await using var conn = await OpenConnectionAsync(); - await using var cmd = new NpgsqlCommand("SELECT $1", conn); - cmd.Parameters.Add(new NpgsqlParameter { NpgsqlDbType = NpgsqlDbType.Integer, Value = 8 }); - Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(8)); - } - - [Test] - public async Task 
Positional_parameters_are_not_supported_with_legacy_batching() - { - await using var conn = await OpenConnectionAsync(); - await using var cmd = new NpgsqlCommand("SELECT $1; SELECT $1", conn); - cmd.Parameters.Add(new NpgsqlParameter { NpgsqlDbType = NpgsqlDbType.Integer, Value = 8 }); - Assert.That(async () => await cmd.ExecuteScalarAsync(), Throws.Exception.TypeOf() - .With.Property(nameof(PostgresException.SqlState)).EqualTo(PostgresErrorCodes.SyntaxError)); - } - - [Test] - [NonParallelizable] // Disables sql rewriting - public async Task Positional_parameters_are_supported_when_EnableSqlParsing_is_disabled() - { - using var _ = DisableSqlRewriting(); - - using var conn = await OpenConnectionAsync(); - using var cmd = new NpgsqlCommand("SELECT $1", conn); - cmd.Parameters.Add(new NpgsqlParameter { NpgsqlDbType = NpgsqlDbType.Integer, Value = 8 }); - Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(8)); - } - - [Test] - [NonParallelizable] // Disables sql rewriting - public async Task Named_parameters_are_not_supported_when_EnableSqlParsing_is_disabled() - { - using var _ = DisableSqlRewriting(); - - using var conn = await OpenConnectionAsync(); - using var cmd = new NpgsqlCommand("SELECT @p", conn); - cmd.Parameters.Add(new NpgsqlParameter("p", 8)); - Assert.That(async () => await cmd.ExecuteScalarAsync(), Throws.Exception.TypeOf()); - } - - [Test, Description("Makes sure writing an unset parameter isn't allowed")] - public async Task Parameter_without_Value() - { - using var conn = await OpenConnectionAsync(); - using var cmd = new NpgsqlCommand("SELECT @p", conn); - cmd.Parameters.Add(new NpgsqlParameter("@p", NpgsqlDbType.Integer)); - Assert.That(() => cmd.ExecuteScalarAsync(), Throws.Exception.TypeOf()); - } - - [Test] - public async Task Unreferenced_named_parameter_works() - { - await using var conn = await OpenConnectionAsync(); - await using var cmd = new NpgsqlCommand("SELECT 1", conn); - cmd.Parameters.AddWithValue("not_used", 8); - 
Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(1)); - } - - [Test] - public async Task Unreferenced_positional_parameter_works() - { - await using var conn = await OpenConnectionAsync(); - await using var cmd = new NpgsqlCommand("SELECT 1", conn); - cmd.Parameters.Add(new NpgsqlParameter { Value = 8 }); - Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(1)); - } - - [Test] - public async Task Mixing_positional_and_named_parameters_is_not_supported() - { - await using var conn = await OpenConnectionAsync(); - await using var cmd = new NpgsqlCommand("SELECT $1, @p", conn); - cmd.Parameters.Add(new NpgsqlParameter { Value = 8 }); - cmd.Parameters.Add(new NpgsqlParameter { ParameterName = "p", Value = 9 }); - Assert.That(() => cmd.ExecuteNonQueryAsync(), Throws.Exception.TypeOf()); - } - - [Test] - [IssueLink("https://github.com/npgsql/npgsql/issues/4171")] - public async Task Cached_command_clears_parameters_placeholder_type() - { - await using var conn = await OpenConnectionAsync(); - - await using (var cmd1 = conn.CreateCommand()) - { - cmd1.CommandText = "SELECT @p1"; - cmd1.Parameters.AddWithValue("@p1", 8); - await using var reader1 = await cmd1.ExecuteReaderAsync(); - reader1.Read(); - Assert.That(reader1[0], Is.EqualTo(8)); - } - - await using (var cmd2 = conn.CreateCommand()) - { - cmd2.CommandText = "SELECT $1"; - cmd2.Parameters.AddWithValue(8); - await using var reader2 = await cmd2.ExecuteReaderAsync(); - reader2.Read(); - Assert.That(reader2[0], Is.EqualTo(8)); - } - } - - [Test] - [IssueLink("https://github.com/npgsql/npgsql/issues/4171")] - public async Task Reuse_command_with_different_parameter_placeholder_types() - { - await using var conn = await OpenConnectionAsync(); - await using var cmd = conn.CreateCommand(); - - cmd.CommandText = "SELECT @p1"; - cmd.Parameters.AddWithValue("@p1", 8); - _ = await cmd.ExecuteScalarAsync(); - - cmd.CommandText = "SELECT $1"; - cmd.Parameters[0].ParameterName = null; - _ = await 
cmd.ExecuteScalarAsync(); - } - - [Test] - public async Task Positional_output_parameters_are_not_supported() - { - await using var conn = await OpenConnectionAsync(); - await using var cmd = new NpgsqlCommand("SELECT $1", conn); - cmd.Parameters.Add(new NpgsqlParameter { Value = 8, Direction = ParameterDirection.InputOutput }); - Assert.That(() => cmd.ExecuteNonQueryAsync(), Throws.Exception.TypeOf()); - } - - [Test] - public void Parameters_get_name() - { - var command = new NpgsqlCommand(); - - // Add parameters. - command.Parameters.Add(new NpgsqlParameter(":Parameter1", DbType.Boolean)); - command.Parameters.Add(new NpgsqlParameter(":Parameter2", DbType.Int32)); - command.Parameters.Add(new NpgsqlParameter(":Parameter3", DbType.DateTime)); - command.Parameters.Add(new NpgsqlParameter("Parameter4", DbType.DateTime)); - - var idbPrmtr = command.Parameters["Parameter1"]; - Assert.IsNotNull(idbPrmtr); - command.Parameters[0].Value = 1; - - // Get by indexers. - - Assert.AreEqual(":Parameter1", command.Parameters["Parameter1"].ParameterName); - Assert.AreEqual(":Parameter2", command.Parameters["Parameter2"].ParameterName); - Assert.AreEqual(":Parameter3", command.Parameters["Parameter3"].ParameterName); - Assert.AreEqual("Parameter4", command.Parameters["Parameter4"].ParameterName); //Should this work? 
- - Assert.AreEqual(":Parameter1", command.Parameters[0].ParameterName); - Assert.AreEqual(":Parameter2", command.Parameters[1].ParameterName); - Assert.AreEqual(":Parameter3", command.Parameters[2].ParameterName); - Assert.AreEqual("Parameter4", command.Parameters[3].ParameterName); - } - - [Test] - public async Task Same_param_multiple_times() - { - using var conn = await OpenConnectionAsync(); - using var cmd = new NpgsqlCommand("SELECT @p1, @p1", conn); - cmd.Parameters.AddWithValue("@p1", 8); - using var reader = await cmd.ExecuteReaderAsync(); - reader.Read(); - Assert.That(reader[0], Is.EqualTo(8)); - Assert.That(reader[1], Is.EqualTo(8)); - } - - [Test] - public async Task Generic_parameter() - { - using var conn = await OpenConnectionAsync(); - using var cmd = new NpgsqlCommand("SELECT @p1, @p2, @p3, @p4", conn); - cmd.Parameters.Add(new NpgsqlParameter("p1", 8)); - cmd.Parameters.Add(new NpgsqlParameter("p2", 8) { NpgsqlDbType = NpgsqlDbType.Integer }); - cmd.Parameters.Add(new NpgsqlParameter("p3", "hello")); - cmd.Parameters.Add(new NpgsqlParameter("p4", new[] { 'f', 'o', 'o' })); - using var reader = await cmd.ExecuteReaderAsync(); - reader.Read(); - Assert.That(reader.GetInt32(0), Is.EqualTo(8)); - Assert.That(reader.GetInt32(1), Is.EqualTo(8)); - Assert.That(reader.GetString(2), Is.EqualTo("hello")); - Assert.That(reader.GetString(3), Is.EqualTo("foo")); - } - - #endregion Parameters - [Test] public async Task CommandText_not_set() { @@ -834,6 +671,31 @@ public async Task Parameter_and_operator_unclear() Assert.AreEqual(rdr.GetInt32(0), 4); } + [Test] + [IssueLink("https://github.com/npgsql/npgsql/issues/4171")] + public async Task Cached_command_clears_parameters_placeholder_type() + { + await using var conn = await OpenConnectionAsync(); + + await using (var cmd1 = conn.CreateCommand()) + { + cmd1.CommandText = "SELECT @p1"; + cmd1.Parameters.AddWithValue("@p1", 8); + await using var reader1 = await cmd1.ExecuteReaderAsync(); + reader1.Read(); + 
Assert.That(reader1[0], Is.EqualTo(8)); + } + + await using (var cmd2 = conn.CreateCommand()) + { + cmd2.CommandText = "SELECT $1"; + cmd2.Parameters.AddWithValue(8); + await using var reader2 = await cmd2.ExecuteReaderAsync(); + reader2.Read(); + Assert.That(reader2[0], Is.EqualTo(8)); + } + } + [Test] [TestCase(CommandBehavior.Default)] [TestCase(CommandBehavior.SequentialAccess)] @@ -937,41 +799,6 @@ public async Task TableDirect() Assert.That(rdr["name"], Is.EqualTo("foo")); } - [Test] - [TestCase(CommandBehavior.Default)] - [TestCase(CommandBehavior.SequentialAccess)] - public async Task Input_and_output_parameters(CommandBehavior behavior) - { - using var conn = await OpenConnectionAsync(); - using var cmd = new NpgsqlCommand("SELECT @c-1 AS c, @a+2 AS b", conn); - cmd.Parameters.Add(new NpgsqlParameter("a", 3)); - var b = new NpgsqlParameter { ParameterName = "b", Direction = ParameterDirection.Output }; - cmd.Parameters.Add(b); - var c = new NpgsqlParameter { ParameterName = "c", Direction = ParameterDirection.InputOutput, Value = 4 }; - cmd.Parameters.Add(c); - using (await cmd.ExecuteReaderAsync(behavior)) - { - Assert.AreEqual(5, b.Value); - Assert.AreEqual(3, c.Value); - } - } - - [Test] - public async Task Send_NpgsqlDbType_Unknown([Values(PrepareOrNot.NotPrepared, PrepareOrNot.Prepared)] PrepareOrNot prepare) - { - if (prepare == PrepareOrNot.Prepared && IsMultiplexing) - return; - - using var conn = await OpenConnectionAsync(); - using var cmd = new NpgsqlCommand("SELECT @p::TIMESTAMP", conn); - cmd.CommandText = "SELECT @p::TIMESTAMP"; - cmd.Parameters.Add(new NpgsqlParameter("p", NpgsqlDbType.Unknown) { Value = "2008-1-1" }); - if (prepare == PrepareOrNot.Prepared) - cmd.Prepare(); - using var reader = await cmd.ExecuteReaderAsync(); - reader.Read(); - Assert.That(reader.GetValue(0), Is.EqualTo(new DateTime(2008, 1, 1))); - } [Test, IssueLink("https://github.com/npgsql/npgsql/issues/503")] public async Task Invalid_UTF8() @@ -1532,7 +1359,7 @@ 
public async Task Oversize_buffer_lost_messages() await server .WriteParseComplete() .WriteBindComplete() - .WriteRowDescription(new FieldDescription(PostgresTypeOIDs.Text)) + .WriteRowDescription(new FieldDescription(TextOid)) .WriteDataRowWithFlush(Encoding.ASCII.GetBytes(new string('a', connection.Settings.ReadBufferSize * 2))); // Just to make sure we have enough space await server.FlushAsync(); @@ -1557,7 +1384,7 @@ await server await server .WriteParseComplete() .WriteBindComplete() - .WriteRowDescription(new FieldDescription(PostgresTypeOIDs.Text)) + .WriteRowDescription(new FieldDescription(TextOid)) .WriteDataRow(Encoding.ASCII.GetBytes("abc")) .WriteCommandComplete() .WriteReadyForQuery() diff --git a/test/Npgsql.Tests/ConnectionTests.cs b/test/Npgsql.Tests/ConnectionTests.cs index 19fb21b693..01ce93b1d2 100644 --- a/test/Npgsql.Tests/ConnectionTests.cs +++ b/test/Npgsql.Tests/ConnectionTests.cs @@ -13,7 +13,6 @@ using System.Threading.Tasks; using Npgsql.Internal; using Npgsql.PostgresTypes; -using Npgsql.Properties; using Npgsql.Util; using NpgsqlTypes; using NUnit.Framework; @@ -1176,7 +1175,7 @@ public async Task NoTypeLoading() }; Assert.That(async () => await cmd.ExecuteScalarAsync(), - Throws.Exception.TypeOf() + Throws.Exception.TypeOf() .With.Message.EqualTo("The NpgsqlDbType 'IntegerMultirange' isn't present in your database. 
You may need to install an extension or upgrade to a newer version.")); } } diff --git a/test/Npgsql.Tests/CopyTests.cs b/test/Npgsql.Tests/CopyTests.cs index 41fabe6ddc..1ab6405956 100644 --- a/test/Npgsql.Tests/CopyTests.cs +++ b/test/Npgsql.Tests/CopyTests.cs @@ -1,7 +1,9 @@ using System; using System.Collections; using System.Collections.Generic; +using System.Collections.Specialized; using System.Data; +using System.Diagnostics; using System.IO; using System.Numerics; using System.Text; @@ -65,10 +67,9 @@ public async Task Raw_binary_roundtrip([Values(false, true)] bool async) const int iterations = 500; var table = await GetTempTableName(conn); - + await conn.ExecuteNonQueryAsync($@"CREATE TABLE {table} (field_text TEXT, field_int2 SMALLINT, field_int4 INTEGER)"); using (var tx = conn.BeginTransaction()) { - await conn.ExecuteNonQueryAsync($@"CREATE TABLE {table} (field_text TEXT, field_int2 SMALLINT, field_int4 INTEGER)"); // Preload some data into the table using (var cmd = @@ -159,14 +160,15 @@ public async Task Cancel_raw_binary_import() using var conn = await OpenConnectionAsync(); var table = await GetTempTableName(conn); await conn.ExecuteNonQueryAsync($@"CREATE TABLE {table} (field_text TEXT, field_int2 SMALLINT, field_int4 INTEGER)"); - - var garbage = new byte[] {1, 2, 3, 4}; - using (var s = conn.BeginRawBinaryCopy($"COPY {table} (field_text, field_int4) FROM STDIN BINARY")) + await using (var tx = await conn.BeginTransactionAsync()) { - s.Write(garbage, 0, garbage.Length); - s.Cancel(); + var garbage = new byte[] {1, 2, 3, 4}; + using (var s = conn.BeginRawBinaryCopy($"COPY {table} (field_text, field_int4) FROM STDIN BINARY")) + { + s.Write(garbage, 0, garbage.Length); + s.Cancel(); + } } - Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(0)); } @@ -294,6 +296,7 @@ public async Task Binary_roundtrip([Values(false, true)] bool async) Assert.That(reader.StartRow(), Is.EqualTo(2)); Assert.That(reader.Read(), 
Is.EqualTo(longString)); Assert.That(reader.IsNull, Is.True); + Assert.That(reader.IsNull, Is.True); reader.Skip(); Assert.That(reader.StartRow(), Is.EqualTo(-1)); @@ -307,13 +310,15 @@ public async Task Cancel_binary_import() { using var conn = await OpenConnectionAsync(); var table = await CreateTempTable(conn, "field_text TEXT, field_int2 SMALLINT, field_int4 INTEGER"); - - using (var writer = conn.BeginBinaryImport($"COPY {table} (field_text, field_int4) FROM STDIN BINARY")) + await using (var tx = await conn.BeginTransactionAsync()) { - writer.StartRow(); - writer.Write("Hello"); - writer.Write(8); - // No commit should rollback + using (var writer = conn.BeginBinaryImport($"COPY {table} (field_text, field_int4) FROM STDIN BINARY")) + { + writer.StartRow(); + writer.Write("Hello"); + writer.Write(8); + // No commit should rollback + } } Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(0)); } @@ -525,12 +530,21 @@ public async Task Export_long_string() using (var reader = conn.BeginBinaryExport($"COPY {table} (foo1, foo2, foo3, foo4, foo5) TO STDIN BINARY")) { - for (var row = 0; row < iterations; row++) + int row, col = 0; + for (row = 0; row < iterations; row++) { Assert.That(reader.StartRow(), Is.EqualTo(5)); - for (var col = 0; col < 5; col++) - Assert.That(reader.Read().Length, Is.EqualTo(len)); + for (col = 0; col < 5; col++) + { + var str = reader.Read(); + Assert.That(str.Length, Is.EqualTo(len)); +#if NET6_0_OR_GREATER + Assert.True(str.AsSpan().IndexOfAnyExcept('x') is -1); +#endif + } } + Assert.That(row, Is.EqualTo(100)); + Assert.That(col, Is.EqualTo(5)); } } @@ -541,12 +555,13 @@ public async Task Read_bit_string() var table = await GetTempTableName(conn); await conn.ExecuteNonQueryAsync($@" -CREATE TABLE {table} (bits BIT(3), bitarray BIT(3)[]); -INSERT INTO {table} (bits, bitarray) VALUES (B'101', ARRAY[B'101', B'111'])"); +CREATE TABLE {table} (bits BIT(11), bitvector BIT(11), bitarray BIT(3)[]); +INSERT 
INTO {table} (bits, bitvector, bitarray) VALUES (B'00000001101', B'00000001101', ARRAY[B'101', B'111'])"); - using var reader = conn.BeginBinaryExport($"COPY {table} (bits, bitarray) TO STDIN BINARY"); + using var reader = conn.BeginBinaryExport($"COPY {table} (bits, bitvector, bitarray) TO STDIN BINARY"); reader.StartRow(); - Assert.That(reader.Read(), Is.EqualTo(new BitArray(new[] { true, false, true }))); + Assert.That(reader.Read(), Is.EqualTo(new BitArray(new[] { false, false, false, false, false, false, false, true, true, false, true }))); + Assert.That(reader.Read(), Is.EqualTo(new BitVector32(0b00000001101000000000000000000000))); Assert.That(reader.Read(), Is.EqualTo(new[] { new BitArray(new[] { true, false, true }), @@ -744,12 +759,15 @@ public async Task Write_column_out_of_bounds_throws() public async Task Cancel_raw_binary_export_when_not_consumed_and_then_Dispose() { await using var conn = await OpenConnectionAsync(); - // This must be large enough to cause Postgres to queue up CopyData messages. - var stream = conn.BeginRawBinaryCopy("COPY (select md5(random()::text) as id from generate_series(1, 100000)) TO STDOUT BINARY"); - var buffer = new byte[32]; - await stream.ReadAsync(buffer, 0, buffer.Length); - stream.Cancel(); - Assert.DoesNotThrowAsync(async () => await stream.DisposeAsync()); + await using (var tx = await conn.BeginTransactionAsync()) + { + // This must be large enough to cause Postgres to queue up CopyData messages. 
+ var stream = conn.BeginRawBinaryCopy("COPY (select md5(random()::text) as id from generate_series(1, 100000)) TO STDOUT BINARY"); + var buffer = new byte[32]; + await stream.ReadAsync(buffer, 0, buffer.Length); + stream.Cancel(); + Assert.DoesNotThrowAsync(async () => await stream.DisposeAsync()); + } Assert.That(async () => await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1), "The connection is still OK"); } @@ -757,28 +775,18 @@ public async Task Cancel_raw_binary_export_when_not_consumed_and_then_Dispose() public async Task Cancel_binary_export_when_not_consumed_and_then_Dispose() { await using var conn = await OpenConnectionAsync(); - // This must be large enough to cause Postgres to queue up CopyData messages. - var exporter = conn.BeginBinaryExport("COPY (select md5(random()::text) as id from generate_series(1, 100000)) TO STDOUT BINARY"); - await exporter.StartRowAsync(); - await exporter.ReadAsync(); - exporter.Cancel(); - Assert.DoesNotThrowAsync(async () => await exporter.DisposeAsync()); + await using (var tx = await conn.BeginTransactionAsync()) + { + // This must be large enough to cause Postgres to queue up CopyData messages. + var exporter = conn.BeginBinaryExport("COPY (select md5(random()::text) as id from generate_series(1, 100000)) TO STDOUT BINARY"); + await exporter.StartRowAsync(); + await exporter.ReadAsync(); + exporter.Cancel(); + Assert.DoesNotThrowAsync(async () => await exporter.DisposeAsync()); + } Assert.That(async () => await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1), "The connection is still OK"); } - [Test] - [IssueLink("https://github.com/npgsql/npgsql/issues/4417")] - public async Task Binary_copy_throws_for_nullable() - { - await using var conn = await OpenConnectionAsync(); - var tableName = await CreateTempTable(conn, "house_number integer"); - - await using var writer = await conn.BeginBinaryImportAsync($"COPY {tableName}(house_number) FROM STDIN BINARY"); - int? 
value = 1; - await writer.StartRowAsync(); - Assert.ThrowsAsync(async () => await writer.WriteAsync(value, NpgsqlDbType.Integer)); - } - [Test] [IssueLink("https://github.com/npgsql/npgsql/issues/5110")] public async Task Binary_copy_read_char_column() @@ -836,10 +844,12 @@ public async Task Cancel_text_import() { using var conn = await OpenConnectionAsync(); var table = await CreateTempTable(conn, "field_text TEXT, field_int2 SMALLINT, field_int4 INTEGER"); - - var writer = (NpgsqlCopyTextWriter)conn.BeginTextImport($"COPY {table} (field_text, field_int4) FROM STDIN"); - writer.Write("HELLO\t1\n"); - writer.Cancel(); + await using (var tx = await conn.BeginTransactionAsync()) + { + var writer = (NpgsqlCopyTextWriter)conn.BeginTextImport($"COPY {table} (field_text, field_int4) FROM STDIN"); + writer.Write("HELLO\t1\n"); + writer.Cancel(); + } Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(0)); } @@ -944,12 +954,15 @@ public async Task Wrong_format_text_export() public async Task Cancel_text_export_when_not_consumed_and_then_Dispose() { await using var conn = await OpenConnectionAsync(); - // This must be large enough to cause Postgres to queue up CopyData messages. - var reader = (NpgsqlCopyTextReader) conn.BeginTextExport("COPY (select md5(random()::text) as id from generate_series(1, 100000)) TO STDOUT"); - var buffer = new char[32]; - await reader.ReadAsync(buffer, 0, buffer.Length); - reader.Cancel(); - Assert.DoesNotThrow(reader.Dispose); + await using (var tx = await conn.BeginTransactionAsync()) + { + // This must be large enough to cause Postgres to queue up CopyData messages. 
+ var reader = (NpgsqlCopyTextReader) conn.BeginTextExport("COPY (select md5(random()::text) as id from generate_series(1, 100000)) TO STDOUT"); + var buffer = new char[32]; + await reader.ReadAsync(buffer, 0, buffer.Length); + reader.Cancel(); + Assert.DoesNotThrow(reader.Dispose); + } Assert.That(async () => await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1), "The connection is still OK"); } @@ -1029,7 +1042,7 @@ public async Task Write_null_values() { writer.StartRow(); writer.Write(DBNull.Value, NpgsqlDbType.Integer); - writer.Write((string?)null, NpgsqlDbType.Uuid); + writer.Write(null, NpgsqlDbType.Uuid); writer.Write(DBNull.Value); writer.Write((string?)null); var rowsWritten = writer.Complete(); @@ -1054,7 +1067,7 @@ public async Task Write_different_types() { writer.StartRow(); writer.Write(3.0, NpgsqlDbType.Integer); - writer.Write((object)new[] { 1, 2, 3 }); + writer.Write(new[] { 1, 2, 3 }); writer.StartRow(); writer.Write(3, NpgsqlDbType.Integer); writer.Write((object)new List { 4, 5, 6 }); diff --git a/test/Npgsql.Tests/FunctionTests.cs b/test/Npgsql.Tests/FunctionTests.cs index 6ca3c2db6d..37f203b812 100644 --- a/test/Npgsql.Tests/FunctionTests.cs +++ b/test/Npgsql.Tests/FunctionTests.cs @@ -4,7 +4,6 @@ using Npgsql.PostgresTypes; using NpgsqlTypes; using NUnit.Framework; -using static Npgsql.Util.Statics; using static Npgsql.Tests.TestUtil; namespace Npgsql.Tests; diff --git a/test/Npgsql.Tests/GlobalTypeMapperTests.cs b/test/Npgsql.Tests/GlobalTypeMapperTests.cs new file mode 100644 index 0000000000..32c647731c --- /dev/null +++ b/test/Npgsql.Tests/GlobalTypeMapperTests.cs @@ -0,0 +1,85 @@ +using System; +using System.Threading.Tasks; +using Npgsql.Internal; +using Npgsql.Internal.Postgres; +using NUnit.Framework; +using static Npgsql.Tests.TestUtil; + +namespace Npgsql.Tests; + +#pragma warning disable CS0618 // GlobalTypeMapper is obsolete + +[NonParallelizable] +public class GlobalTypeMapperTests : TestBase +{ + [Test] + public async Task 
MapEnum() + { + await using var adminConnection = await OpenConnectionAsync(); + var type = await GetTempTypeName(adminConnection); + NpgsqlConnection.GlobalTypeMapper.MapEnum(type); + + await using var dataSource1 = CreateDataSource(); + + await using (var connection = await dataSource1.OpenConnectionAsync()) + { + await connection.ExecuteNonQueryAsync($"CREATE TYPE {type} AS ENUM ('sad', 'ok', 'happy')"); + await connection.ReloadTypesAsync(); + + await AssertType(connection, Mood.Happy, "happy", type, npgsqlDbType: null); + } + + NpgsqlConnection.GlobalTypeMapper.UnmapEnum(type); + + // Global mapping changes have no effect on already-built data sources + await AssertType(dataSource1, Mood.Happy, "happy", type, npgsqlDbType: null); + + // But they do affect on new data sources + await using var dataSource2 = CreateDataSource(); + await AssertType(dataSource2, "happy", "happy", type, npgsqlDbType: null, isDefault: false); + } + + [Test] + public async Task Reset() + { + await using var adminConnection = await OpenConnectionAsync(); + var type = await GetTempTypeName(adminConnection); + NpgsqlConnection.GlobalTypeMapper.MapEnum(type); + + await using var dataSource1 = CreateDataSource(); + + await using (var connection = await dataSource1.OpenConnectionAsync()) + { + await connection.ExecuteNonQueryAsync($"CREATE TYPE {type} AS ENUM ('sad', 'ok', 'happy')"); + await connection.ReloadTypesAsync(); + } + + // A global mapping change has no effects on data sources which have already been built + NpgsqlConnection.GlobalTypeMapper.Reset(); + + // Global mapping changes have no effect on already-built data sources + await AssertType(dataSource1, Mood.Happy, "happy", type, npgsqlDbType: null); + + // But they do affect on new data sources + await using var dataSource2 = CreateDataSource(); + await AssertType(dataSource2, "happy", "happy", type, npgsqlDbType: null, isDefault: false); + } + + [Test] + public void Reset_and_add_resolver() + { + 
NpgsqlConnection.GlobalTypeMapper.Reset(); + NpgsqlConnection.GlobalTypeMapper.AddTypeInfoResolver(new DummyResolver()); + } + + [TearDown] + public void Teardown() + => NpgsqlConnection.GlobalTypeMapper.Reset(); + + enum Mood { Sad, Ok, Happy } + + class DummyResolver : IPgTypeInfoResolver + { + public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) => null; + } +} diff --git a/test/Npgsql.Tests/MultipleHostsTests.cs b/test/Npgsql.Tests/MultipleHostsTests.cs index 2b2c3f5304..5de416672b 100644 --- a/test/Npgsql.Tests/MultipleHostsTests.cs +++ b/test/Npgsql.Tests/MultipleHostsTests.cs @@ -957,7 +957,7 @@ async Task Query(NpgsqlDataSource dataSource) [NonParallelizable] // Disables sql rewriting public async Task Multiple_hosts_with_disabled_sql_rewriting() { - using var _ = DisableSqlRewriting(); + using var _ = DisableSqlRewriting(ClearDataSources); var dataSourceBuilder = new NpgsqlDataSourceBuilder(ConnectionString) { diff --git a/test/Npgsql.Tests/NotificationTests.cs b/test/Npgsql.Tests/NotificationTests.cs index 8f3810a779..9df9aba44d 100644 --- a/test/Npgsql.Tests/NotificationTests.cs +++ b/test/Npgsql.Tests/NotificationTests.cs @@ -3,7 +3,6 @@ using System.Data; using System.Threading; using System.Threading.Tasks; -using Npgsql.Internal; using static Npgsql.Tests.TestUtil; namespace Npgsql.Tests; diff --git a/test/Npgsql.Tests/NpgsqlParameterTests.cs b/test/Npgsql.Tests/NpgsqlParameterTests.cs index fe9f5f96b5..1678b3b37e 100644 --- a/test/Npgsql.Tests/NpgsqlParameterTests.cs +++ b/test/Npgsql.Tests/NpgsqlParameterTests.cs @@ -1,7 +1,6 @@ using NpgsqlTypes; using NUnit.Framework; using System; -using System.Collections.Generic; using System.Data; using System.Data.Common; @@ -109,8 +108,8 @@ public void Cannot_infer_data_type_name_from_NpgsqlDbType_for_unknown_range() [Test] public void Infer_data_type_name_from_ClrType() { - var p = new NpgsqlParameter("p1", new Dictionary()); - Assert.That(p.DataTypeName, 
Is.EqualTo("hstore")); + var p = new NpgsqlParameter("p1", Array.Empty()); + Assert.That(p.DataTypeName, Is.EqualTo("bytea")); } [Test] diff --git a/test/Npgsql.Tests/PoolTests.cs b/test/Npgsql.Tests/PoolTests.cs index eda0bbedf7..469e98e7f6 100644 --- a/test/Npgsql.Tests/PoolTests.cs +++ b/test/Npgsql.Tests/PoolTests.cs @@ -1,11 +1,9 @@ using System; -using System.Collections.Generic; using System.Linq; using System.Net.Sockets; using System.Threading; using System.Threading.Tasks; using NUnit.Framework; -using static Npgsql.Tests.TestUtil; namespace Npgsql.Tests; diff --git a/test/Npgsql.Tests/PostgresTypeTests.cs b/test/Npgsql.Tests/PostgresTypeTests.cs index 644d839697..056830cf32 100644 --- a/test/Npgsql.Tests/PostgresTypeTests.cs +++ b/test/Npgsql.Tests/PostgresTypeTests.cs @@ -1,7 +1,6 @@ using System.Linq; using System.Threading.Tasks; using Npgsql.Internal; -using Npgsql.TypeMapping; using NUnit.Framework; namespace Npgsql.Tests; @@ -70,6 +69,6 @@ public async Task Multirange() async Task GetDatabaseInfo() { await using var conn = await OpenConnectionAsync(); - return conn.NpgsqlDataSource.TypeMapper.DatabaseInfo; + return conn.NpgsqlDataSource.DatabaseInfo; } } diff --git a/test/Npgsql.Tests/ReadBufferTests.cs b/test/Npgsql.Tests/ReadBufferTests.cs index 9246479355..b9ace59606 100644 --- a/test/Npgsql.Tests/ReadBufferTests.cs +++ b/test/Npgsql.Tests/ReadBufferTests.cs @@ -1,5 +1,4 @@ using Npgsql.Internal; -using Npgsql.Util; using NUnit.Framework; using System; using System.IO; @@ -17,14 +16,14 @@ public void Skip() for (byte i = 0; i < 50; i++) Writer.WriteByte(i); - ReadBuffer.Ensure(10); + ReadBuffer.Ensure(10, async: false).GetAwaiter().GetResult(); ReadBuffer.Skip(7); Assert.That(ReadBuffer.ReadByte(), Is.EqualTo(7)); ReadBuffer.Skip(10); - ReadBuffer.Ensure(1); + ReadBuffer.Ensure(1, async: false).GetAwaiter().GetResult(); Assert.That(ReadBuffer.ReadByte(), Is.EqualTo(18)); ReadBuffer.Skip(20); - ReadBuffer.Ensure(1); + ReadBuffer.Ensure(1, async: 
false).GetAwaiter().GetResult(); Assert.That(ReadBuffer.ReadByte(), Is.EqualTo(39)); } @@ -36,7 +35,7 @@ public void ReadSingle() Array.Reverse(bytes); Writer.Write(bytes); - ReadBuffer.Ensure(4); + ReadBuffer.Ensure(4, async: false).GetAwaiter().GetResult(); Assert.That(ReadBuffer.ReadSingle(), Is.EqualTo(expected)); } @@ -48,7 +47,7 @@ public void ReadDouble() Array.Reverse(bytes); Writer.Write(bytes); - ReadBuffer.Ensure(8); + ReadBuffer.Ensure(8, async: false).GetAwaiter().GetResult(); Assert.That(ReadBuffer.ReadDouble(), Is.EqualTo(expected)); } @@ -56,12 +55,12 @@ public void ReadDouble() public void ReadNullTerminatedString_buffered_only() { Writer - .Write(PGUtil.UTF8Encoding.GetBytes(new string("foo"))) + .Write(NpgsqlWriteBuffer.UTF8Encoding.GetBytes(new string("foo"))) .WriteByte(0) - .Write(PGUtil.UTF8Encoding.GetBytes(new string("bar"))) + .Write(NpgsqlWriteBuffer.UTF8Encoding.GetBytes(new string("bar"))) .WriteByte(0); - ReadBuffer.Ensure(1); + ReadBuffer.Ensure(1, async: false); Assert.That(ReadBuffer.ReadNullTerminatedString(), Is.EqualTo("foo")); Assert.That(ReadBuffer.ReadNullTerminatedString(), Is.EqualTo("bar")); @@ -70,15 +69,15 @@ public void ReadNullTerminatedString_buffered_only() [Test] public async Task ReadNullTerminatedString_with_io() { - Writer.Write(PGUtil.UTF8Encoding.GetBytes(new string("Chunked "))); - ReadBuffer.Ensure(1); + Writer.Write(NpgsqlWriteBuffer.UTF8Encoding.GetBytes(new string("Chunked "))); + await ReadBuffer.Ensure(1, async: true); var task = ReadBuffer.ReadNullTerminatedString(async: true); Assert.That(!task.IsCompleted); Writer - .Write(PGUtil.UTF8Encoding.GetBytes(new string("string"))) + .Write(NpgsqlWriteBuffer.UTF8Encoding.GetBytes(new string("string"))) .WriteByte(0) - .Write(PGUtil.UTF8Encoding.GetBytes(new string("bar"))) + .Write(NpgsqlWriteBuffer.UTF8Encoding.GetBytes(new string("bar"))) .WriteByte(0); Assert.That(task.IsCompleted); Assert.That(await task, Is.EqualTo("Chunked string")); @@ -90,7 +89,7 @@ 
public async Task ReadNullTerminatedString_with_io() public void SetUp() { var stream = new MockStream(); - ReadBuffer = new NpgsqlReadBuffer(null, stream, null, NpgsqlReadBuffer.DefaultSize, PGUtil.UTF8Encoding, PGUtil.RelaxedUTF8Encoding); + ReadBuffer = new NpgsqlReadBuffer(null, stream, null, NpgsqlReadBuffer.DefaultSize, NpgsqlWriteBuffer.UTF8Encoding, NpgsqlWriteBuffer.RelaxedUTF8Encoding); Writer = stream.Writer; } #pragma warning restore CS8625 diff --git a/test/Npgsql.Tests/ReaderNewSchemaTests.cs b/test/Npgsql.Tests/ReaderNewSchemaTests.cs index d70f772e37..7489bae711 100644 --- a/test/Npgsql.Tests/ReaderNewSchemaTests.cs +++ b/test/Npgsql.Tests/ReaderNewSchemaTests.cs @@ -1,5 +1,4 @@ -using System; -using System.Collections.ObjectModel; +using System.Collections.ObjectModel; using System.Data; using System.Linq; using System.Threading.Tasks; diff --git a/test/Npgsql.Tests/ReaderTests.cs b/test/Npgsql.Tests/ReaderTests.cs index 8f47f53aab..790b1b48e0 100644 --- a/test/Npgsql.Tests/ReaderTests.cs +++ b/test/Npgsql.Tests/ReaderTests.cs @@ -10,8 +10,7 @@ using System.Threading.Tasks; using Npgsql.BackendMessages; using Npgsql.Internal; -using Npgsql.Internal.TypeHandling; -using Npgsql.Internal.TypeMapping; +using Npgsql.Internal.Postgres; using Npgsql.PostgresTypes; using Npgsql.Tests.Support; using Npgsql.TypeMapping; @@ -28,6 +27,24 @@ namespace Npgsql.Tests; [TestFixture(MultiplexingMode.Multiplexing, CommandBehavior.SequentialAccess)] public class ReaderTests : MultiplexingTestBase { + static uint Int4Oid => DefaultPgTypes.DataTypeNameMap[DataTypeNames.Int4].Value; + static uint ByteaOid => DefaultPgTypes.DataTypeNameMap[DataTypeNames.Bytea].Value; + + [Test] + public async Task Resumable_non_consumed_to_non_resumable() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand( "SELECT 'aaaaaaaa', 1", conn); + await using var reader = await cmd.ExecuteReaderAsync(Behavior); + await reader.ReadAsync(); + + await 
reader.IsDBNullAsync(0); // resumable, no consumption + _ = reader.IsDBNull(0); // resumable, no consumption + await using var stream = await reader.GetStreamAsync(0); // non-resumable + if (IsSequential) + Assert.That(() => reader.GetString(0), Throws.Exception.TypeOf()); + } + [Test] public async Task Seek_columns() { @@ -1167,7 +1184,7 @@ public async Task Bug3772() pgMock .WriteParseComplete() .WriteBindComplete() - .WriteRowDescription(new FieldDescription(PostgresTypeOIDs.Int4), new FieldDescription(PostgresTypeOIDs.Bytea)); + .WriteRowDescription(new FieldDescription(Int4Oid), new FieldDescription(ByteaOid)); var intValue = new byte[] { 0, 0, 0, 1 }; var byteValue = new byte[] { 1, 2, 3, 4 }; @@ -1209,13 +1226,19 @@ public async Task Dispose_does_not_swallow_exceptions([Values(true, false)] bool await using var postmasterMock = PgPostmasterMock.Start(ConnectionString); await using var dataSource = CreateDataSource(postmasterMock.ConnectionString); await using var conn = await dataSource.OpenConnectionAsync(); + await using var tx = IsMultiplexing ? 
await conn.BeginTransactionAsync() : null; var pgMock = await postmasterMock.WaitForServerConnection(); + if (IsMultiplexing) + pgMock + .WriteEmptyQueryResponse() + .WriteReadyForQuery(TransactionStatus.InTransactionBlock); + // Write responses for the query, but break the connection before sending CommandComplete/ReadyForQuery await pgMock .WriteParseComplete() .WriteBindComplete() - .WriteRowDescription(new FieldDescription(PostgresTypeOIDs.Int4)) + .WriteRowDescription(new FieldDescription(DefaultPgTypes.DataTypeNameMap[DataTypeNames.Int4].Value)) .WriteDataRow(BitConverter.GetBytes(BinaryPrimitives.ReverseEndianness(1))) .FlushAsync(); @@ -1283,11 +1306,6 @@ public async Task GetBytes() Assert.That(actual, Is.EqualTo(expected)); Assert.That(reader.GetBytes(0, 0, null, 0, 0), Is.EqualTo(expected.Length), "Bad column length"); - Assert.That(() => reader.GetBytes(1, 0, null, 0, 0), Throws.Exception.TypeOf(), - "GetBytes on non-bytea"); - Assert.That(() => reader.GetBytes(1, 0, actual, 0, 1), - Throws.Exception.TypeOf(), - "GetBytes on non-bytea"); Assert.That(reader.GetString(1), Is.EqualTo("foo")); reader.GetBytes(2, 0, actual, 0, 2); // Jump to another column from the middle of the column @@ -1533,7 +1551,8 @@ public async Task GetChars() Assert.That(reader.GetChars(0, 0, actual, 0, 2), Is.EqualTo(2)); Assert.That(actual[0], Is.EqualTo(expected[0])); Assert.That(actual[1], Is.EqualTo(expected[1])); - Assert.That(reader.GetChars(0, 0, null, 0, 0), Is.EqualTo(expected.Length), "Bad column length"); + if (!IsSequential) + Assert.That(reader.GetChars(0, 0, null, 0, 0), Is.EqualTo(expected.Length), "Bad column length"); // Note: Unlike with bytea, finding out the length of the column consumes it (variable-width // UTF8 encoding) Assert.That(reader.GetChars(2, 0, actual, 0, 2), Is.EqualTo(2)); @@ -1728,7 +1747,7 @@ public async Task SafeReadException() { var dataSourceBuilder = CreateDataSourceBuilder(); // Temporarily reroute integer to go to a type handler which 
generates SafeReadExceptions - dataSourceBuilder.AddTypeResolverFactory(new ExplodingTypeHandlerResolverFactory(safe: true)); + dataSourceBuilder.AddTypeInfoResolver(new ExplodingTypeHandlerResolver(safe: true)); await using var dataSource = dataSourceBuilder.Build(); await using var connection = await dataSource.OpenConnectionAsync(); @@ -1745,14 +1764,14 @@ public async Task Non_SafeReadException() { var dataSourceBuilder = CreateDataSourceBuilder(); // Temporarily reroute integer to go to a type handler which generates some exception - dataSourceBuilder.AddTypeResolverFactory(new ExplodingTypeHandlerResolverFactory(safe: false)); + dataSourceBuilder.AddTypeInfoResolver(new ExplodingTypeHandlerResolver(safe: false)); await using var dataSource = dataSourceBuilder.Build(); await using var connection = await dataSource.OpenConnectionAsync(); await using var cmd = new NpgsqlCommand(@"SELECT 1, 'hello'", connection); await using var reader = await cmd.ExecuteReaderAsync(Behavior); await reader.ReadAsync(); - Assert.That(() => reader.GetInt32(0), Throws.Exception.With.Message.EqualTo("Non-safe read exception as requested")); + Assert.That(() => reader.GetInt32(0), Throws.Exception.With.Message.EqualTo("Broken")); Assert.That(connection.FullState, Is.EqualTo(ConnectionState.Broken)); Assert.That(connection.State, Is.EqualTo(ConnectionState.Closed)); } @@ -1774,7 +1793,7 @@ public async Task ReadAsync_cancel_command_soft() await pgMock .WriteParseComplete() .WriteBindComplete() - .WriteRowDescription(new FieldDescription(PostgresTypeOIDs.Int4)) + .WriteRowDescription(new FieldDescription(Int4Oid)) .WriteDataRow(BitConverter.GetBytes(BinaryPrimitives.ReverseEndianness(1))) .FlushAsync(); @@ -1823,7 +1842,7 @@ public async Task ReadAsync_cancel_soft() await pgMock .WriteParseComplete() .WriteBindComplete() - .WriteRowDescription(new FieldDescription(PostgresTypeOIDs.Int4)) + .WriteRowDescription(new FieldDescription(Int4Oid)) 
.WriteDataRow(BitConverter.GetBytes(BinaryPrimitives.ReverseEndianness(1))) .FlushAsync(); @@ -1874,7 +1893,7 @@ public async Task NextResult_cancel_soft() await pgMock .WriteParseComplete() .WriteBindComplete() - .WriteRowDescription(new FieldDescription(PostgresTypeOIDs.Int4)) + .WriteRowDescription(new FieldDescription(Int4Oid)) .WriteDataRow(BitConverter.GetBytes(BinaryPrimitives.ReverseEndianness(1))) .WriteCommandComplete() .FlushAsync(); @@ -1926,7 +1945,7 @@ public async Task ReadAsync_cancel_hard([Values(true, false)] bool passCancelled await pgMock .WriteParseComplete() .WriteBindComplete() - .WriteRowDescription(new FieldDescription(PostgresTypeOIDs.Int4)) + .WriteRowDescription(new FieldDescription(Int4Oid)) .WriteDataRow(BitConverter.GetBytes(BinaryPrimitives.ReverseEndianness(1))) .FlushAsync(); @@ -1970,7 +1989,7 @@ public async Task NextResultAsync_cancel_hard([Values(true, false)] bool passCan await pgMock .WriteParseComplete() .WriteBindComplete() - .WriteRowDescription(new FieldDescription(PostgresTypeOIDs.Int4)) + .WriteRowDescription(new FieldDescription(Int4Oid)) .WriteDataRow(BitConverter.GetBytes(BinaryPrimitives.ReverseEndianness(1))) .WriteCommandComplete() .FlushAsync(); @@ -2018,7 +2037,7 @@ public async Task GetFieldValueAsync_sequential_cancel([Values(true, false)] boo await pgMock .WriteParseComplete() .WriteBindComplete() - .WriteRowDescription(new FieldDescription(PostgresTypeOIDs.Bytea)) + .WriteRowDescription(new FieldDescription(ByteaOid)) .WriteDataRowWithFlush(new byte[10000]); using var cmd = new NpgsqlCommand("SELECT some_bytea FROM some_table", conn); @@ -2056,7 +2075,7 @@ public async Task IsDBNullAsync_sequential_cancel([Values(true, false)] bool pas await pgMock .WriteParseComplete() .WriteBindComplete() - .WriteRowDescription(new FieldDescription(PostgresTypeOIDs.Bytea), new FieldDescription(PostgresTypeOIDs.Int4)) + .WriteRowDescription(new FieldDescription(ByteaOid), new FieldDescription(Int4Oid)) 
.WriteDataRowWithFlush(new byte[10000], new byte[4]); using var cmd = new NpgsqlCommand("SELECT some_bytea, some_int FROM some_table", conn); @@ -2122,7 +2141,7 @@ public async Task GetFieldValueAsync_sequential_timeout() await pgMock .WriteParseComplete() .WriteBindComplete() - .WriteRowDescription(new FieldDescription(PostgresTypeOIDs.Bytea)) + .WriteRowDescription(new FieldDescription(ByteaOid)) .WriteDataRowWithFlush(new byte[10000]); using var cmd = new NpgsqlCommand("SELECT some_bytea FROM some_table", conn); @@ -2162,7 +2181,7 @@ public async Task IsDBNullAsync_sequential_timeout() await pgMock .WriteParseComplete() .WriteBindComplete() - .WriteRowDescription(new FieldDescription(PostgresTypeOIDs.Bytea), new FieldDescription(PostgresTypeOIDs.Int4)) + .WriteRowDescription(new FieldDescription(ByteaOid), new FieldDescription(Int4Oid)) .WriteDataRowWithFlush(new byte[10000], new byte[4]); using var cmd = new NpgsqlCommand("SELECT some_bytea, some_int FROM some_table", conn); @@ -2192,7 +2211,7 @@ public async Task Bug3446() await pgMock .WriteParseComplete() .WriteBindComplete() - .WriteRowDescription(new FieldDescription(PostgresTypeOIDs.Int4)) + .WriteRowDescription(new FieldDescription(Int4Oid)) .WriteDataRow(new byte[4]) .FlushAsync(); @@ -2231,52 +2250,43 @@ public ReaderTests(MultiplexingMode multiplexingMode, CommandBehavior behavior) #region Mock Type Handlers -class ExplodingTypeHandlerResolverFactory : TypeHandlerResolverFactory +class ExplodingTypeHandlerResolver : IPgTypeInfoResolver { readonly bool _safe; - public ExplodingTypeHandlerResolverFactory(bool safe) => _safe = safe; - public override TypeHandlerResolver Create(TypeMapper typeMapper, NpgsqlConnector connector) => new ExplodingTypeHandlerResolver(_safe); + public ExplodingTypeHandlerResolver(bool safe) => _safe = safe; - class ExplodingTypeHandlerResolver : TypeHandlerResolver + public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? 
dataTypeName, PgSerializerOptions options) { - readonly bool _safe; - - public ExplodingTypeHandlerResolver(bool safe) => _safe = safe; + if (dataTypeName == DataTypeNames.Int4 && (type == typeof(int) || type is null)) + return new(options, new ExplodingTypeHandler(_safe), DataTypeNames.Int4); - public override NpgsqlTypeHandler? ResolveByDataTypeName(string typeName) => - typeName == "integer" ? new ExplodingTypeHandler(null!, _safe) : null; - public override NpgsqlTypeHandler? ResolveByClrType(Type type) => null; + return null; } } -class ExplodingTypeHandler : NpgsqlSimpleTypeHandler +class ExplodingTypeHandler : PgBufferedConverter { readonly bool _safe; - internal ExplodingTypeHandler(PostgresType postgresType, bool safe) : base(postgresType) => _safe = safe; + internal ExplodingTypeHandler(bool safe) => _safe = safe; - public override int Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - { - buf.ReadInt32(); + public override Size GetSize(SizeContext context, int value, ref object? writeState) + => throw new NotSupportedException(); - throw _safe - ? new Exception("Safe read exception as requested") - : buf.Connector.Break(new Exception("Non-safe read exception as requested")); - } + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + => CanConvertBufferedDefault(format, out bufferRequirements); - public override int ValidateAndGetLength(int value, NpgsqlParameter? parameter) => throw new NotSupportedException(); - public override int ValidateObjectAndGetLength(object? value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => throw new NotSupportedException(); - public override void Write(int value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => throw new NotSupportedException(); - - public override Task WriteObjectWithLength( - object? value, - NpgsqlWriteBuffer buf, - NpgsqlLengthCache? lengthCache, - NpgsqlParameter? 
parameter, - bool async, - CancellationToken cancellationToken = default) + protected override void WriteCore(PgWriter writer, int value) => throw new NotSupportedException(); + + protected override int ReadCore(PgReader reader) + { + if (_safe) + throw new Exception("Safe read exception as requested"); + + reader.BreakConnection(); + return default; + } } #endregion diff --git a/test/Npgsql.Tests/Replication/CommonReplicationTests.cs b/test/Npgsql.Tests/Replication/CommonReplicationTests.cs index 9033d14e31..36a11b434a 100644 --- a/test/Npgsql.Tests/Replication/CommonReplicationTests.cs +++ b/test/Npgsql.Tests/Replication/CommonReplicationTests.cs @@ -1,5 +1,4 @@ using System; -using System.Collections.Concurrent; using System.Collections.Generic; using System.IO; using System.Runtime.CompilerServices; diff --git a/test/Npgsql.Tests/Replication/PgOutputReplicationTests.cs b/test/Npgsql.Tests/Replication/PgOutputReplicationTests.cs index d8fd2ed3a2..8497646f9d 100644 --- a/test/Npgsql.Tests/Replication/PgOutputReplicationTests.cs +++ b/test/Npgsql.Tests/Replication/PgOutputReplicationTests.cs @@ -33,7 +33,6 @@ namespace Npgsql.Tests.Replication; // [TestFixture(ProtocolVersion.V3, ReplicationDataMode.TextReplicationDataMode, TransactionMode.NonStreamingTransactionMode)] // [TestFixture(ProtocolVersion.V3, ReplicationDataMode.BinaryReplicationDataMode, TransactionMode.DefaultTransactionMode)] // [TestFixture(ProtocolVersion.V3, ReplicationDataMode.BinaryReplicationDataMode, TransactionMode.StreamingTransactionMode)] -[Platform(Exclude = "MacOsX", Reason = "Replication tests are flaky in CI on Mac")] [NonParallelizable] // These tests aren't designed to be parallelizable public class PgOutputReplicationTests : SafeReplicationTestBase { diff --git a/test/Npgsql.Tests/Replication/TestDecodingReplicationTests.cs b/test/Npgsql.Tests/Replication/TestDecodingReplicationTests.cs index 732c1a3e67..5d7c633f6c 100644 --- 
a/test/Npgsql.Tests/Replication/TestDecodingReplicationTests.cs +++ b/test/Npgsql.Tests/Replication/TestDecodingReplicationTests.cs @@ -1,5 +1,4 @@ -using System; -using System.Collections.Generic; +using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; using NUnit.Framework; diff --git a/test/Npgsql.Tests/SchemaTests.cs b/test/Npgsql.Tests/SchemaTests.cs index 5deee67a0d..2a143cccfd 100644 --- a/test/Npgsql.Tests/SchemaTests.cs +++ b/test/Npgsql.Tests/SchemaTests.cs @@ -535,6 +535,14 @@ await conn.ExecuteNonQueryAsync($@" Assert.That(row["data_type"], Is.EqualTo($"{schema}.{enumName}")); } + [Test] + public async Task SlimBuilder_introspection_without_unsupported_type_exceptions() + { + await using var dataSource = new NpgsqlSlimDataSourceBuilder(ConnectionString).Build(); + await using var conn = await dataSource.OpenConnectionAsync(); + Assert.That(() => GetSchema(conn, DbMetaDataCollectionNames.DataTypes), Throws.Nothing); + } + public SchemaTests(SyncOrAsync syncOrAsync) : base(syncOrAsync) { } // ReSharper disable MethodHasAsyncOverload diff --git a/test/Npgsql.Tests/SecurityTests.cs b/test/Npgsql.Tests/SecurityTests.cs index 9aa2ee7d50..8600942969 100644 --- a/test/Npgsql.Tests/SecurityTests.cs +++ b/test/Npgsql.Tests/SecurityTests.cs @@ -1,4 +1,6 @@ using System; +using System.IO; +using System.Runtime.InteropServices; using System.Security.Authentication; using System.Threading; using System.Threading.Tasks; @@ -153,6 +155,7 @@ public void Bug1718() csb.SslMode = SslMode.Require; }); using var conn = dataSource.OpenConnection(); + using var tx = conn.BeginTransaction(); using var cmd = CreateSleepCommand(conn, 10000); var cts = new CancellationTokenSource(1000).Token; Assert.That(async () => await cmd.ExecuteNonQueryAsync(cts), Throws.Exception @@ -276,6 +279,13 @@ public async Task Connect_with_only_non_ssl_allowed_user([Values] bool multiplex await using var conn = await dataSource.OpenConnectionAsync(); 
Assert.IsFalse(conn.IsSecure); } + catch (NpgsqlException ex) when (RuntimeInformation.IsOSPlatform(OSPlatform.Windows) && ex.InnerException is IOException) + { + // Windows server to windows client invites races that can cause the socket to be reset before all data can be read. + // https://www.postgresql.org/message-id/flat/90b34057-4176-7bb0-0dbb-9822a5f6425b%40greiz-reinsdorf.de + // https://www.postgresql.org/message-id/flat/16678-253e48d34dc0c376@postgresql.org + Assert.Ignore(); + } catch (Exception e) when (!IsOnBuildServer) { Console.WriteLine(e); @@ -387,20 +397,23 @@ public async Task Bug4305_Secure([Values] bool async) } await using var __ = conn; - var originalConnector = conn.Connector; - await using var cmd = conn.CreateCommand(); - cmd.CommandText = "select pg_sleep(30)"; - cmd.CommandTimeout = 3; - var ex = async - ? Assert.ThrowsAsync(() => cmd.ExecuteNonQueryAsync())! - : Assert.Throws(() => cmd.ExecuteNonQuery())!; - Assert.That(ex.InnerException, Is.TypeOf()); + await using (var tx = await conn.BeginTransactionAsync()) + { + var originalConnector = conn.Connector; - await conn.CloseAsync(); - await conn.OpenAsync(); + cmd.CommandText = "select pg_sleep(30)"; + cmd.CommandTimeout = 3; + var ex = async + ? Assert.ThrowsAsync(() => cmd.ExecuteNonQueryAsync())! 
+ : Assert.Throws(() => cmd.ExecuteNonQuery())!; + Assert.That(ex.InnerException, Is.TypeOf()); - Assert.AreSame(originalConnector, conn.Connector); + await conn.CloseAsync(); + await conn.OpenAsync(); + + Assert.AreSame(originalConnector, conn.Connector); + } cmd.CommandText = "SELECT 1"; if (async) diff --git a/test/Npgsql.Tests/SqlQueryParserTests.cs b/test/Npgsql.Tests/SqlQueryParserTests.cs index d161823114..1044b707fc 100644 --- a/test/Npgsql.Tests/SqlQueryParserTests.cs +++ b/test/Npgsql.Tests/SqlQueryParserTests.cs @@ -1,5 +1,4 @@ -using System; -using System.Collections.Generic; +using System.Collections.Generic; using System.Data; using System.Linq; using NUnit.Framework; diff --git a/test/Npgsql.Tests/Support/AssemblySetUp.cs b/test/Npgsql.Tests/Support/AssemblySetUp.cs index 851e452acb..f1619ecec4 100644 --- a/test/Npgsql.Tests/Support/AssemblySetUp.cs +++ b/test/Npgsql.Tests/Support/AssemblySetUp.cs @@ -1,7 +1,5 @@ -using Microsoft.Extensions.Logging; -using Npgsql; +using Npgsql; using Npgsql.Tests; -using Npgsql.Tests.Support; using NUnit.Framework; using System; using System.Threading; diff --git a/test/Npgsql.Tests/Support/MultiplexingTestBase.cs b/test/Npgsql.Tests/Support/MultiplexingTestBase.cs index c7483390e0..892dd79f5e 100644 --- a/test/Npgsql.Tests/Support/MultiplexingTestBase.cs +++ b/test/Npgsql.Tests/Support/MultiplexingTestBase.cs @@ -34,4 +34,4 @@ public enum MultiplexingMode { NonMultiplexing, Multiplexing -} \ No newline at end of file +} diff --git a/test/Npgsql.Tests/Support/PgPostmasterMock.cs b/test/Npgsql.Tests/Support/PgPostmasterMock.cs index 7cc33c1877..e45c1a7f28 100644 --- a/test/Npgsql.Tests/Support/PgPostmasterMock.cs +++ b/test/Npgsql.Tests/Support/PgPostmasterMock.cs @@ -7,7 +7,6 @@ using System.Threading.Channels; using System.Threading.Tasks; using Npgsql.Internal; -using Npgsql.Util; namespace Npgsql.Tests.Support; @@ -18,8 +17,8 @@ class PgPostmasterMock : IAsyncDisposable const int CancelRequestCode = 1234 << 16 | 
5678; const int SslRequest = 80877103; - static readonly Encoding Encoding = PGUtil.UTF8Encoding; - static readonly Encoding RelaxedEncoding = PGUtil.RelaxedUTF8Encoding; + static readonly Encoding Encoding = NpgsqlWriteBuffer.UTF8Encoding; + static readonly Encoding RelaxedEncoding = NpgsqlWriteBuffer.RelaxedUTF8Encoding; readonly Socket _socket; readonly List _allServers = new(); diff --git a/test/Npgsql.Tests/Support/PgServerMock.cs b/test/Npgsql.Tests/Support/PgServerMock.cs index 6a83cc0248..0135059d0d 100644 --- a/test/Npgsql.Tests/Support/PgServerMock.cs +++ b/test/Npgsql.Tests/Support/PgServerMock.cs @@ -7,15 +7,19 @@ using System.Threading.Tasks; using Npgsql.BackendMessages; using Npgsql.Internal; +using Npgsql.Internal.Postgres; using Npgsql.TypeMapping; -using Npgsql.Util; using NUnit.Framework; namespace Npgsql.Tests.Support; class PgServerMock : IDisposable { - static readonly Encoding Encoding = PGUtil.UTF8Encoding; + static uint BoolOid => DefaultPgTypes.DataTypeNameMap[DataTypeNames.Bool].Value; + static uint Int4Oid => DefaultPgTypes.DataTypeNameMap[DataTypeNames.Int4].Value; + static uint TextOid => DefaultPgTypes.DataTypeNameMap[DataTypeNames.Text].Value; + + static readonly Encoding Encoding = NpgsqlWriteBuffer.UTF8Encoding; readonly NetworkStream _stream; readonly NpgsqlReadBuffer _readBuffer; @@ -90,12 +94,12 @@ internal Task SendMockState(MockState state) return WriteParseComplete() .WriteBindComplete() - .WriteRowDescription(new FieldDescription(PostgresTypeOIDs.Bool)) + .WriteRowDescription(new FieldDescription(BoolOid)) .WriteDataRow(BitConverter.GetBytes(isStandby)) .WriteCommandComplete() .WriteParseComplete() .WriteBindComplete() - .WriteRowDescription(new FieldDescription(PostgresTypeOIDs.Text)) + .WriteRowDescription(new FieldDescription(TextOid)) .WriteDataRow(Encoding.ASCII.GetBytes(transactionReadOnly)) .WriteCommandComplete() .WriteReadyForQuery() @@ -159,7 +163,7 @@ internal Task FlushAsync() internal Task 
WriteScalarResponseAndFlush(int value) => WriteParseComplete() .WriteBindComplete() - .WriteRowDescription(new FieldDescription(PostgresTypeOIDs.Int4)) + .WriteRowDescription(new FieldDescription(Int4Oid)) .WriteDataRow(BitConverter.GetBytes(BinaryPrimitives.ReverseEndianness(value))) .WriteCommandComplete() .WriteReadyForQuery() @@ -168,7 +172,7 @@ internal Task WriteScalarResponseAndFlush(int value) internal Task WriteScalarResponseAndFlush(bool value) => WriteParseComplete() .WriteBindComplete() - .WriteRowDescription(new FieldDescription(PostgresTypeOIDs.Bool)) + .WriteRowDescription(new FieldDescription(BoolOid)) .WriteDataRow(BitConverter.GetBytes(value)) .WriteCommandComplete() .WriteReadyForQuery() @@ -177,7 +181,7 @@ internal Task WriteScalarResponseAndFlush(bool value) internal Task WriteScalarResponseAndFlush(string value) => WriteParseComplete() .WriteBindComplete() - .WriteRowDescription(new FieldDescription(PostgresTypeOIDs.Text)) + .WriteRowDescription(new FieldDescription(TextOid)) .WriteDataRow(Encoding.ASCII.GetBytes(value)) .WriteCommandComplete() .WriteReadyForQuery() @@ -219,7 +223,7 @@ internal PgServerMock WriteRowDescription(params FieldDescription[] fields) _writeBuffer.WriteUInt32(field.TypeOID); _writeBuffer.WriteInt16(field.TypeSize); _writeBuffer.WriteInt32(field.TypeModifier); - _writeBuffer.WriteInt16((short)field.FormatCode); + _writeBuffer.WriteInt16(field.DataFormat.ToFormatCode()); } return this; @@ -233,6 +237,14 @@ internal PgServerMock WriteNoData() return this; } + internal PgServerMock WriteEmptyQueryResponse() + { + CheckDisposed(); + _writeBuffer.WriteByte((byte)BackendMessageCode.EmptyQueryResponse); + _writeBuffer.WriteInt32(4); + return this; + } + internal PgServerMock WriteDataRow(params byte[][] columnValues) { CheckDisposed(); diff --git a/test/Npgsql.Tests/Support/TestBase.cs b/test/Npgsql.Tests/Support/TestBase.cs index 126a3575fd..81bac44b3e 100644 --- a/test/Npgsql.Tests/Support/TestBase.cs +++ 
b/test/Npgsql.Tests/Support/TestBase.cs @@ -40,12 +40,13 @@ public async Task AssertType( bool isDefaultForWriting = true, bool? isDefault = null, bool isNpgsqlDbTypeInferredFromClrType = true, - Func? comparer = null) + Func? comparer = null, + bool skipArrayCheck = false) { await using var connection = await OpenConnectionAsync(); return await AssertType( connection, value, sqlLiteral, pgTypeName, npgsqlDbType, dbType, inferredDbType, isDefaultForReading, isDefaultForWriting, - isDefault, isNpgsqlDbTypeInferredFromClrType, comparer); + isDefault, isNpgsqlDbTypeInferredFromClrType, comparer, skipArrayCheck); } public async Task AssertType( @@ -60,12 +61,13 @@ public async Task AssertType( bool isDefaultForWriting = true, bool? isDefault = null, bool isNpgsqlDbTypeInferredFromClrType = true, - Func? comparer = null) + Func? comparer = null, + bool skipArrayCheck = false) { await using var connection = await dataSource.OpenConnectionAsync(); return await AssertType(connection, value, sqlLiteral, pgTypeName, npgsqlDbType, dbType, inferredDbType, isDefaultForReading, - isDefaultForWriting, isDefault, isNpgsqlDbTypeInferredFromClrType, comparer); + isDefaultForWriting, isDefault, isNpgsqlDbTypeInferredFromClrType, comparer, skipArrayCheck); } public async Task AssertType( @@ -80,19 +82,27 @@ public async Task AssertType( bool isDefaultForWriting = true, bool? isDefault = null, bool isNpgsqlDbTypeInferredFromClrType = true, - Func? comparer = null) + Func? 
comparer = null, + bool skipArrayCheck = false) { if (isDefault is not null) isDefaultForReading = isDefaultForWriting = isDefault.Value; - await AssertTypeWrite(connection, () => value, sqlLiteral, pgTypeName, npgsqlDbType, dbType, inferredDbType, isDefaultForWriting, isNpgsqlDbTypeInferredFromClrType); - return await AssertTypeRead(connection, sqlLiteral, pgTypeName, value, isDefaultForReading, comparer); + await AssertTypeWrite(connection, () => value, sqlLiteral, pgTypeName, npgsqlDbType, dbType, inferredDbType, isDefaultForWriting, isNpgsqlDbTypeInferredFromClrType, skipArrayCheck); + return await AssertTypeRead(connection, sqlLiteral, pgTypeName, value, isDefaultForReading, comparer, fieldType: null, skipArrayCheck); } - public async Task AssertTypeRead(string sqlLiteral, string pgTypeName, T expected, bool isDefault = true) + public async Task AssertTypeRead(string sqlLiteral, string pgTypeName, T expected, bool isDefault = true, bool skipArrayCheck = false) { await using var connection = await OpenConnectionAsync(); - return await AssertTypeRead(connection, sqlLiteral, pgTypeName, expected, isDefault); + return await AssertTypeRead(connection, sqlLiteral, pgTypeName, expected, isDefault, comparer: null, fieldType: null, skipArrayCheck); + } + + public async Task AssertTypeRead(NpgsqlDataSource dataSource, string sqlLiteral, string pgTypeName, T expected, + bool isDefault = true, Func? comparer = null, Type? fieldType = null, bool skipArrayCheck = false) + { + await using var connection = await dataSource.OpenConnectionAsync(); + return await AssertTypeRead(connection, sqlLiteral, pgTypeName, expected, isDefault, comparer, fieldType, skipArrayCheck); } public async Task AssertTypeWrite( @@ -104,12 +114,13 @@ public async Task AssertTypeWrite( DbType? dbType = null, DbType? 
inferredDbType = null, bool isDefault = true, - bool isNpgsqlDbTypeInferredFromClrType = true) + bool isNpgsqlDbTypeInferredFromClrType = true, + bool skipArrayCheck = false) { await using var connection = await dataSource.OpenConnectionAsync(); await AssertTypeWrite(connection, () => value, expectedSqlLiteral, pgTypeName, npgsqlDbType, dbType, inferredDbType, isDefault, - isNpgsqlDbTypeInferredFromClrType); + isNpgsqlDbTypeInferredFromClrType, skipArrayCheck); } public Task AssertTypeWrite( @@ -120,9 +131,10 @@ public Task AssertTypeWrite( DbType? dbType = null, DbType? inferredDbType = null, bool isDefault = true, - bool isNpgsqlDbTypeInferredFromClrType = true) + bool isNpgsqlDbTypeInferredFromClrType = true, + bool skipArrayCheck = false) => AssertTypeWrite(() => value, expectedSqlLiteral, pgTypeName, npgsqlDbType, dbType, inferredDbType, isDefault, - isNpgsqlDbTypeInferredFromClrType); + isNpgsqlDbTypeInferredFromClrType, skipArrayCheck); public async Task AssertTypeWrite( Func valueFactory, @@ -132,10 +144,11 @@ public async Task AssertTypeWrite( DbType? dbType = null, DbType? inferredDbType = null, bool isDefault = true, - bool isNpgsqlDbTypeInferredFromClrType = true) + bool isNpgsqlDbTypeInferredFromClrType = true, + bool skipArrayCheck = false) { await using var connection = await OpenConnectionAsync(); - await AssertTypeWrite(connection, valueFactory, expectedSqlLiteral, pgTypeName, npgsqlDbType, dbType, inferredDbType, isDefault, isNpgsqlDbTypeInferredFromClrType); + await AssertTypeWrite(connection, valueFactory, expectedSqlLiteral, pgTypeName, npgsqlDbType, dbType, inferredDbType, isDefault, isNpgsqlDbTypeInferredFromClrType, skipArrayCheck); } internal static async Task AssertTypeRead( @@ -144,7 +157,35 @@ internal static async Task AssertTypeRead( string pgTypeName, T expected, bool isDefault = true, - Func? comparer = null) + Func? comparer = null, + Type? 
fieldType = null, + bool skipArrayCheck = false) + { + var result = await AssertTypeReadCore(connection, sqlLiteral, pgTypeName, expected, isDefault, comparer); + + // Check the corresponding array type as well + if (!skipArrayCheck && !pgTypeName.EndsWith("[]", StringComparison.Ordinal)) + { + await AssertTypeReadCore( + connection, + ArrayLiteral(sqlLiteral), + pgTypeName + "[]", + new[] { expected, expected }, + isDefault, + comparer is null ? null : (array1, array2) => comparer(array1[0], array2[0]) && comparer(array1[1], array2[1])); + } + + return result; + } + + internal static async Task AssertTypeReadCore( + NpgsqlConnection connection, + string sqlLiteral, + string pgTypeName, + T expected, + bool isDefault = true, + Func? comparer = null, + Type? fieldType = null) { if (sqlLiteral.Contains('\'')) sqlLiteral = sqlLiteral.Replace("'", "''"); @@ -166,7 +207,7 @@ internal static async Task AssertTypeRead( if (isDefault) { // For arrays, GetFieldType always returns typeof(Array), since PG arrays can have arbitrary dimensionality - Assert.That(reader.GetFieldType(0), Is.EqualTo(dataTypeName.EndsWith("[]") ? typeof(Array) : typeof(T)), + Assert.That(reader.GetFieldType(0), Is.EqualTo(dataTypeName.EndsWith("[]") ? typeof(Array) : fieldType ?? typeof(T)), $"Got wrong result from GetFieldType when reading '{truncatedSqlLiteral}'"); } @@ -179,6 +220,38 @@ internal static async Task AssertTypeRead( } internal static async Task AssertTypeWrite( + NpgsqlConnection connection, + Func valueFactory, + string expectedSqlLiteral, + string pgTypeName, + NpgsqlDbType? npgsqlDbType, + DbType? dbType = null, + DbType? 
inferredDbType = null, + bool isDefault = true, + bool isNpgsqlDbTypeInferredFromClrType = true, + bool skipArrayCheck = false) + { + await AssertTypeWriteCore( + connection, valueFactory, expectedSqlLiteral, pgTypeName, npgsqlDbType, dbType, inferredDbType, isDefault, + isNpgsqlDbTypeInferredFromClrType); + + // Check the corresponding array type as well + if (!skipArrayCheck && !pgTypeName.EndsWith("[]", StringComparison.Ordinal)) + { + await AssertTypeWriteCore( + connection, + () => new[] { valueFactory(), valueFactory() }, + ArrayLiteral(expectedSqlLiteral), + pgTypeName + "[]", + npgsqlDbType | NpgsqlDbType.Array, + dbType: null, + inferredDbType: null, + isDefault, + isNpgsqlDbTypeInferredFromClrType); + } + } + + internal static async Task AssertTypeWriteCore( NpgsqlConnection connection, Func valueFactory, string expectedSqlLiteral, @@ -198,7 +271,10 @@ internal static async Task AssertTypeWrite( // Strip any facet information (length/precision/scale) var parenIndex = pgTypeName.IndexOf('('); - var pgTypeNameWithoutFacets = parenIndex > -1 ? pgTypeName[..parenIndex] : pgTypeName; + // var pgTypeNameWithoutFacets = parenIndex > -1 ? pgTypeName[..parenIndex] : pgTypeName; + var pgTypeNameWithoutFacets = parenIndex > -1 + ? pgTypeName[..parenIndex] + pgTypeName[(pgTypeName.IndexOf(')') + 1)..] + : pgTypeName; // We test the following scenarios (between 2 and 5 in total): // 1. 
With NpgsqlDbType explicitly set @@ -241,14 +317,14 @@ internal static async Task AssertTypeWrite( // With (non-generic) value only p = new NpgsqlParameter { Value = valueFactory() }; cmd.Parameters.Add(p); - errorIdentifier[++errorIdentifierIndex] = "Value only (non-generic)"; + errorIdentifier[++errorIdentifierIndex] = $"Value only (type {p.Value!.GetType().Name}, non-generic)"; if (isNpgsqlDbTypeInferredFromClrType) CheckInference(); // With (generic) value only p = new NpgsqlParameter { TypedValue = valueFactory() }; cmd.Parameters.Add(p); - errorIdentifier[++errorIdentifierIndex] = "Value only (generic)"; + errorIdentifier[++errorIdentifierIndex] = $"Value only (type {p.Value!.GetType().Name}, generic)"; if (isNpgsqlDbTypeInferredFromClrType) CheckInference(); } @@ -294,6 +370,8 @@ public async Task AssertTypeUnsupportedRead(string sqlLiteral, string pgTypeName dataSource ??= DefaultDataSource; await using var conn = await dataSource.OpenConnectionAsync(); + // Make sure we don't poison the connection with a fault, potentially terminating other perfectly passing tests as well. + await using var tx = dataSource.Settings.Multiplexing ? await conn.BeginTransactionAsync() : null; await using var cmd = new NpgsqlCommand($"SELECT '{sqlLiteral}'::{pgTypeName}", conn); await using var reader = await cmd.ExecuteReaderAsync(); await reader.ReadAsync(); @@ -307,7 +385,7 @@ public Task AssertTypeUnsupportedRead(string sqlLiteral public async Task AssertTypeUnsupportedRead(string sqlLiteral, string pgTypeName, NpgsqlDataSource? dataSource = null) where TException : Exception { - dataSource ??= DefaultDataSource; + dataSource ??= DataSource; await using var conn = await dataSource.OpenConnectionAsync(); await using var cmd = new NpgsqlCommand($"SELECT '{sqlLiteral}'::{pgTypeName}", conn); @@ -323,9 +401,11 @@ public Task AssertTypeUnsupportedWrite(T value, string? public async Task AssertTypeUnsupportedWrite(T value, string? pgTypeName = null, NpgsqlDataSource? 
dataSource = null) where TException : Exception { - dataSource ??= DefaultDataSource; + dataSource ??= DataSource; await using var conn = await dataSource.OpenConnectionAsync(); + // Make sure we don't poison the connection with a fault, potentially terminating other perfectly passing tests as well. + await using var tx = dataSource.Settings.Multiplexing ? await conn.BeginTransactionAsync() : null; await using var cmd = new NpgsqlCommand("SELECT $1", conn) { Parameters = { new() { Value = value } } @@ -352,6 +432,31 @@ public bool Equals(T? x, T? y) public int GetHashCode(T obj) => throw new NotSupportedException(); } + // For array quoting rules, see array_out in https://github.com/postgres/postgres/blob/master/src/backend/utils/adt/arrayfuncs.c + static string ArrayLiteral(string elementLiteral) + { + switch (elementLiteral) + { + case "": + elementLiteral = "\"\""; + break; + case "NULL": + elementLiteral = "\"NULL\""; + break; + default: + // Escape quotes and backslashes, quote for special chars + elementLiteral = elementLiteral.Replace("\\", "\\\\").Replace("\"", "\\\""); + if (elementLiteral.Any(c => c is '{' or '}' or ',' or '"' or '\\' || char.IsWhiteSpace(c))) + { + elementLiteral = '"' + elementLiteral + '"'; + } + + break; + } + + return $"{{{elementLiteral},{elementLiteral}}}"; + } + #endregion Type testing #region Utilities for use by tests @@ -364,16 +469,23 @@ protected virtual NpgsqlDataSourceBuilder CreateDataSourceBuilder() protected virtual NpgsqlDataSource CreateDataSource() => CreateDataSource(ConnectionString); - protected virtual NpgsqlDataSource CreateDataSource(string connectionString) + protected NpgsqlDataSource CreateDataSource(string connectionString) => NpgsqlDataSource.Create(connectionString); - protected virtual NpgsqlDataSource CreateDataSource(Action connectionStringBuilderAction) + protected NpgsqlDataSource CreateDataSource(Action connectionStringBuilderAction) { var connectionStringBuilder = new 
NpgsqlConnectionStringBuilder(ConnectionString); connectionStringBuilderAction(connectionStringBuilder); return NpgsqlDataSource.Create(connectionStringBuilder); } + protected NpgsqlDataSource CreateDataSource(Action configure) + { + var builder = new NpgsqlDataSourceBuilder(ConnectionString); + configure(builder); + return builder.Build(); + } + protected static NpgsqlDataSource GetDataSource(string connectionString) { if (!DataSources.TryGetValue(connectionString, out var dataSource)) @@ -412,8 +524,12 @@ protected virtual NpgsqlDataSource CreateLoggingDataSource( protected NpgsqlDataSource DefaultDataSource => GetDataSource(ConnectionString); + protected virtual NpgsqlDataSource DataSource => DefaultDataSource; + + protected void ClearDataSources() => DataSources.Clear(); + protected virtual NpgsqlConnection CreateConnection() - => DefaultDataSource.CreateConnection(); + => DataSource.CreateConnection(); protected virtual NpgsqlConnection OpenConnection() { diff --git a/test/Npgsql.Tests/TestUtil.cs b/test/Npgsql.Tests/TestUtil.cs index 1fa69cb6e1..35df4e6e4c 100644 --- a/test/Npgsql.Tests/TestUtil.cs +++ b/test/Npgsql.Tests/TestUtil.cs @@ -57,7 +57,7 @@ public static void MinimumPgVersion(NpgsqlDataSource dataSource, string minVersi MinimumPgVersion(connection, minVersion, ignoreText); } - public static void MinimumPgVersion(NpgsqlConnection conn, string minVersion, string? ignoreText = null) + public static bool MinimumPgVersion(NpgsqlConnection conn, string minVersion, string? ignoreText = null) { var min = new Version(minVersion); if (conn.PostgreSqlVersion < min) @@ -66,7 +66,10 @@ public static void MinimumPgVersion(NpgsqlConnection conn, string minVersion, st if (ignoreText != null) msg += ": " + ignoreText; Assert.Ignore(msg); + return false; } + + return true; } public static void MaximumPgVersionExclusive(NpgsqlConnection conn, string maxVersion, string? 
ignoreText = null) @@ -105,16 +108,24 @@ public static Task EnsureExtensionAsync(NpgsqlConnection conn, string extension, static async Task EnsureExtension(NpgsqlConnection conn, string extension, string? minVersion, bool async) { - if (minVersion != null) - MinimumPgVersion(conn, minVersion, $"The extension '{extension}' only works for PostgreSQL {minVersion} and higher."); + if (minVersion != null && !MinimumPgVersion(conn, minVersion, $"The extension '{extension}' only works for PostgreSQL {minVersion} and higher.")) + return; if (conn.PostgreSqlVersion < MinCreateExtensionVersion) Assert.Ignore($"The 'CREATE EXTENSION' command only works for PostgreSQL {MinCreateExtensionVersion} and higher."); - if (async) - await conn.ExecuteNonQueryAsync($"CREATE EXTENSION IF NOT EXISTS {extension}"); - else - conn.ExecuteNonQuery($"CREATE EXTENSION IF NOT EXISTS {extension}"); + try + { + if (async) + await conn.ExecuteNonQueryAsync($"CREATE EXTENSION IF NOT EXISTS {extension}"); + else + conn.ExecuteNonQuery($"CREATE EXTENSION IF NOT EXISTS {extension}"); + } + catch (PostgresException ex) when (ex.ConstraintName == "pg_extension_name_index") + { + // The extension is already installed, but we can race across threads. 
+ // https://stackoverflow.com/questions/63104126/create-extention-if-not-exists-doesnt-really-check-if-extention-does-not-exis + } conn.ReloadTypes(); } @@ -154,6 +165,7 @@ static async Task IgnoreIfFeatureNotSupported(NpgsqlConnection conn, string test public static async Task EnsurePostgis(NpgsqlConnection conn) { + var isPreRelease = IsPgPrerelease(conn); try { await EnsureExtensionAsync(conn, "postgis"); @@ -161,10 +173,14 @@ public static async Task EnsurePostgis(NpgsqlConnection conn) catch (PostgresException e) when (e.SqlState == PostgresErrorCodes.UndefinedFile) { // PostGIS packages aren't available for PostgreSQL prereleases - if (IsPgPrerelease(conn)) + if (isPreRelease) { Assert.Ignore($"PostGIS could not be installed, but PostgreSQL is prerelease ({conn.ServerVersion}), ignoring test suite."); } + else + { + throw; + } } } @@ -360,12 +376,10 @@ internal static IDisposable SetCurrentCulture(CultureInfo culture) return new DeferredExecutionDisposable(() => CultureInfo.CurrentCulture = oldCulture); } - internal static IDisposable DisableSqlRewriting() + internal static IDisposable DisableSqlRewriting(Action clearDataSources) { #if DEBUG - // We clear the pools to make sure we don't accidentally reuse a pool - // Since EnableSqlRewriting is a global change - PoolManager.Reset(); + clearDataSources(); NpgsqlCommand.EnableSqlRewriting = false; return new DeferredExecutionDisposable(() => NpgsqlCommand.EnableSqlRewriting = true); #else diff --git a/test/Npgsql.Tests/TypeMapperTests.cs b/test/Npgsql.Tests/TypeMapperTests.cs index 55db858600..92db0bdea1 100644 --- a/test/Npgsql.Tests/TypeMapperTests.cs +++ b/test/Npgsql.Tests/TypeMapperTests.cs @@ -1,87 +1,15 @@ using Npgsql.Internal; -using Npgsql.Internal.TypeHandlers; -using Npgsql.Internal.TypeHandling; -using Npgsql.PostgresTypes; -using Npgsql.TypeMapping; using NUnit.Framework; using System; using System.Threading.Tasks; -using Npgsql.Internal.TypeMapping; +using Npgsql.Internal.Converters; +using 
Npgsql.Internal.Postgres; using static Npgsql.Tests.TestUtil; namespace Npgsql.Tests; public class TypeMapperTests : TestBase { -#pragma warning disable CS0618 // GlobalTypeMapper is obsolete - [Test, NonParallelizable] - public async Task Global_mapping() - { - await using var adminConnection = await OpenConnectionAsync(); - var type = await GetTempTypeName(adminConnection); - NpgsqlConnection.GlobalTypeMapper.MapEnum(type); - - try - { - await using var dataSource1 = CreateDataSource(); - - await using (var connection = await dataSource1.OpenConnectionAsync()) - { - await connection.ExecuteNonQueryAsync($"CREATE TYPE {type} AS ENUM ('sad', 'ok', 'happy')"); - await connection.ReloadTypesAsync(); - - await AssertType(connection, Mood.Happy, "happy", type, npgsqlDbType: null); - } - - NpgsqlConnection.GlobalTypeMapper.UnmapEnum(type); - - // Global mapping changes have no effect on already-built data sources - await AssertType(dataSource1, Mood.Happy, "happy", type, npgsqlDbType: null); - - // But they do affect on new data sources - await using var dataSource2 = CreateDataSource(); - Assert.ThrowsAsync(() => AssertType(dataSource2, Mood.Happy, "happy", type, npgsqlDbType: null)); - } - finally - { - NpgsqlConnection.GlobalTypeMapper.UnmapEnum(type); - } - } - - [Test, NonParallelizable] - public async Task Global_mapping_reset() - { - await using var adminConnection = await OpenConnectionAsync(); - var type = await GetTempTypeName(adminConnection); - NpgsqlConnection.GlobalTypeMapper.MapEnum(type); - - try - { - await using var dataSource1 = CreateDataSource(); - - await using (var connection = await dataSource1.OpenConnectionAsync()) - { - await connection.ExecuteNonQueryAsync($"CREATE TYPE {type} AS ENUM ('sad', 'ok', 'happy')"); - await connection.ReloadTypesAsync(); - } - - // A global mapping change has no effects on data sources which have already been built - NpgsqlConnection.GlobalTypeMapper.Reset(); - - // Global mapping changes have no effect on 
already-built data sources - await AssertType(dataSource1, Mood.Happy, "happy", type, npgsqlDbType: null); - - // But they do affect on new data sources - await using var dataSource2 = CreateDataSource(); - Assert.ThrowsAsync(() => AssertType(dataSource2, Mood.Happy, "happy", type, npgsqlDbType: null)); - } - finally - { - NpgsqlConnection.GlobalTypeMapper.Reset(); - } - } -#pragma warning restore CS0618 // GlobalTypeMapper is obsolete - [Test] public async Task ReloadTypes_across_connections_in_data_source() { @@ -91,7 +19,7 @@ public async Task ReloadTypes_across_connections_in_data_source() // via the data source. var dataSourceBuilder = CreateDataSourceBuilder(); - dataSourceBuilder.MapEnum(); + dataSourceBuilder.MapEnum(type); await using var dataSource = dataSourceBuilder.Build(); await using var connection1 = await dataSource.OpenConnectionAsync(); await using var connection2 = await dataSource.OpenConnectionAsync(); @@ -101,8 +29,8 @@ public async Task ReloadTypes_across_connections_in_data_source() // The data source type mapper has been replaced and connection1 should have the new mapper, but connection2 should retain the older // type mapper - where there's no mapping - as long as it's still open + Assert.ThrowsAsync(async () => await connection2.ExecuteScalarAsync($"SELECT 'happy'::{type}")); Assert.DoesNotThrowAsync(async () => await connection1.ExecuteScalarAsync($"SELECT 'happy'::{type}")); - Assert.ThrowsAsync(async () => await connection2.ExecuteScalarAsync($"SELECT 'happy'::{type}")); // Close connection2 and reopen to make sure it picks up the new type and mapping from the data source var connId = connection2.ProcessID; @@ -121,7 +49,7 @@ public async Task String_to_citext() await EnsureExtensionAsync(adminConnection, "citext"); var dataSourceBuilder = CreateDataSourceBuilder(); - dataSourceBuilder.AddTypeResolverFactory(new CitextToStringTypeHandlerResolverFactory()); + dataSourceBuilder.AddTypeInfoResolver(new 
CitextToStringTypeHandlerResolverFactory()); await using var dataSource = dataSourceBuilder.Build(); await using var connection = await dataSource.OpenConnectionAsync(); @@ -162,25 +90,15 @@ await conn.ExecuteNonQueryAsync(@$" #region Support - class CitextToStringTypeHandlerResolverFactory : TypeHandlerResolverFactory + class CitextToStringTypeHandlerResolverFactory : IPgTypeInfoResolver { - public override TypeHandlerResolver Create(TypeMapper typeMapper, NpgsqlConnector connector) - => new CitextToStringTypeHandlerResolver(connector); - - class CitextToStringTypeHandlerResolver : TypeHandlerResolver + public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) { - readonly NpgsqlConnector _connector; - readonly PostgresType _pgCitextType; - - public CitextToStringTypeHandlerResolver(NpgsqlConnector connector) - { - _connector = connector; - _pgCitextType = connector.DatabaseInfo.GetPostgresTypeByName("citext"); - } - - public override NpgsqlTypeHandler? ResolveByClrType(Type type) - => type == typeof(string) ? new TextHandler(_pgCitextType, _connector.TextEncoding) : null; - public override NpgsqlTypeHandler? 
ResolveByDataTypeName(string typeName) => null; + if (type == typeof(string) || dataTypeName?.UnqualifiedName == "citext") + if (options.DatabaseInfo.TryGetPostgresTypeByName("citext", out var pgType)) + return new(options, new StringTextConverter(options.TextEncoding), options.ToCanonicalTypeId(pgType)); + + return null; } } diff --git a/test/Npgsql.Tests/Types/ArrayTests.cs b/test/Npgsql.Tests/Types/ArrayTests.cs index 5e56c75c50..6c929c07a7 100644 --- a/test/Npgsql.Tests/Types/ArrayTests.cs +++ b/test/Npgsql.Tests/Types/ArrayTests.cs @@ -1,11 +1,10 @@ using System; -using System.Collections; using System.Collections.Generic; using System.Data; using System.Linq; using System.Text; using System.Threading.Tasks; -using Npgsql.Internal.TypeHandlers; +using Npgsql.Internal.Converters; using NpgsqlTypes; using NUnit.Framework; using static Npgsql.Tests.TestUtil; @@ -75,6 +74,18 @@ public async Task Array_resolution() } } + [Test] + public async Task Throws_too_many_dimensions() + { + await using var conn = CreateConnection(); + await conn.OpenAsync(); + await using var cmd = new NpgsqlCommand("SELECT 1", conn); + cmd.Parameters.AddWithValue("p", new int[1, 1, 1, 1, 1, 1, 1, 1, 1]); // 9 dimensions + Assert.That( + () => cmd.ExecuteScalarAsync(), + Throws.Exception.TypeOf().With.Message.EqualTo("values (Parameter 'Postgres arrays can have at most 8 dimensions.')")); + } + [Test] public async Task Bind_int_then_array_of_int() { @@ -150,9 +161,9 @@ public async Task Nullable_ints_cannot_be_read_as_non_nullable() await using var reader = await cmd.ExecuteReaderAsync(); reader.Read(); - Assert.That(() => reader.GetFieldValue(0), Throws.Exception.TypeOf()); - Assert.That(() => reader.GetFieldValue>(0), Throws.Exception.TypeOf()); - Assert.That(() => reader.GetValue(0), Throws.Exception.TypeOf()); + Assert.That(() => reader.GetFieldValue(0), Throws.Exception.TypeOf()); + Assert.That(() => reader.GetFieldValue>(0), Throws.Exception.TypeOf()); + Assert.That(() => 
reader.GetValue(0), Throws.Exception.TypeOf()); } [Test, Description("Checks that PG arrays containing nulls are returned as set via ValueTypeArrayMode.")] @@ -184,9 +195,9 @@ public async Task Value_type_array_nullabilities(ArrayNullabilityMode mode) Assert.That(reader.GetValue(1), Is.EqualTo(new [,]{{1, 2}, {3, 4}})); reader.Read(); Assert.That(reader.GetFieldType(0), Is.EqualTo(typeof(Array))); - Assert.That(() => reader.GetValue(0), Throws.Exception.TypeOf()); + Assert.That(() => reader.GetValue(0), Throws.Exception.TypeOf()); Assert.That(reader.GetFieldType(1), Is.EqualTo(typeof(Array))); - Assert.That(() => reader.GetValue(1), Throws.Exception.TypeOf()); + Assert.That(() => reader.GetValue(1), Throws.Exception.TypeOf()); break; case ArrayNullabilityMode.Always: reader.Read(); @@ -271,8 +282,8 @@ public async Task Wrong_array_dimensions_throws() var reader = await cmd.ExecuteReaderAsync(); reader.Read(); - var ex = Assert.Throws(() => reader.GetFieldValue(0))!; - Assert.That(ex.Message, Is.EqualTo("Cannot read an array with 1 dimension(s) from an array with 2 dimension(s)")); + var ex = Assert.Throws(() => reader.GetFieldValue(0))!; + Assert.That(ex.Message, Does.StartWith("Cannot read an array value with 2 dimensions into a collection type with 1 dimension")); } [Test, Description("Verifies that an attempt to read an Array of value types that contains null values as array of a non-nullable type fails.")] @@ -289,8 +300,8 @@ public async Task Read_null_as_non_nullable_array_throws() Assert.That( () => reader.GetFieldValue(0), - Throws.Exception.TypeOf() - .With.Message.EqualTo(ArrayHandlerCore.ReadNonNullableCollectionWithNullsExceptionMessage)); + Throws.Exception.TypeOf() + .With.Message.EqualTo(PgArrayConverter.ReadNonNullableCollectionWithNullsExceptionMessage)); } @@ -308,8 +319,8 @@ public async Task Read_null_as_non_nullable_list_throws() Assert.That( () => reader.GetFieldValue>(0), - Throws.Exception.TypeOf() - 
.With.Message.EqualTo(ArrayHandlerCore.ReadNonNullableCollectionWithNullsExceptionMessage)); + Throws.Exception.TypeOf() + .With.Message.EqualTo(PgArrayConverter.ReadNonNullableCollectionWithNullsExceptionMessage)); } [Test, Description("Roundtrips a large, one-dimensional array of ints that will be chunked")] @@ -435,19 +446,6 @@ public async Task Array_of_byte_arrays() Assert.That(reader.GetProviderSpecificFieldType(0), Is.EqualTo(typeof(Array))); } - - [Test, Description("Roundtrips a non-generic IList as an array")] - // ReSharper disable once InconsistentNaming - public async Task IList_non_generic() - { - await using var conn = await OpenConnectionAsync(); - await using var cmd = new NpgsqlCommand("SELECT @p", conn); - var expected = new ArrayList(new[] { 1, 2, 3 }); - var p = new NpgsqlParameter("p", NpgsqlDbType.Array | NpgsqlDbType.Integer) { Value = expected }; - cmd.Parameters.Add(p); - Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(expected.ToArray())); - } - [Test, Description("Roundtrips a generic List as an array")] // ReSharper disable once InconsistentNaming public async Task IList_generic() @@ -477,11 +475,11 @@ public async Task IList_generic_fails_for_multidimensional_array() await using var reader = await cmd.ExecuteReaderAsync(); reader.Read(); Assert.That(reader.GetValue(0), Is.EqualTo(expected)); - var exception = Assert.Throws(() => + var exception = Assert.Throws(() => { reader.GetFieldValue>(0); })!; - Assert.That(exception.Message, Is.EqualTo("Can't read multidimensional array as List")); + Assert.That(exception.Message, Does.StartWith("Cannot read an array value with 2 dimensions")); } [Test, IssueLink("https://github.com/npgsql/npgsql/issues/844")] @@ -490,19 +488,7 @@ public async Task IEnumerable_throws_friendly_exception() await using var conn = await OpenConnectionAsync(); await using var cmd = new NpgsqlCommand("SELECT @p1", conn); cmd.Parameters.AddWithValue("p1", Enumerable.Range(1, 3)); - Assert.That(async () => await 
cmd.ExecuteScalarAsync(), Throws.Exception.TypeOf().With.Message.Contains("array or List")); - } - - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/960")] - public async Task Mixed_element_types() - { - var mixedList = new ArrayList { 1, "yo" }; - await using var conn = await OpenConnectionAsync(); - await using var cmd = new NpgsqlCommand("SELECT @p1", conn); - cmd.Parameters.AddWithValue("p1", NpgsqlDbType.Array | NpgsqlDbType.Integer, mixedList); - Assert.That(async () => await cmd.ExecuteNonQueryAsync(), Throws.Exception - .TypeOf() - .With.Message.Contains("mix")); + Assert.That(async () => await cmd.ExecuteScalarAsync(), Throws.Exception.TypeOf().With.Property("InnerException").Message.Contains("array or List")); } [Test, IssueLink("https://github.com/npgsql/npgsql/issues/960")] @@ -515,17 +501,8 @@ public async Task Jagged_arrays_not_supported() await using var cmd = new NpgsqlCommand("SELECT @p1", conn); cmd.Parameters.AddWithValue("p1", NpgsqlDbType.Array | NpgsqlDbType.Integer, jagged); Assert.That(async () => await cmd.ExecuteNonQueryAsync(), Throws.Exception - .TypeOf() - .With.Message.Contains("jagged")); - } - - [Test, Description("Checks that ILists are properly serialized as arrays of their underlying types")] - public async Task List_type_resolution() - { - await using var conn = await OpenConnectionAsync(); - await AssertIListRoundtrips(conn, new[] { 1, 2, 3 }); - await AssertIListRoundtrips(conn, new IntList { 1, 2, 3 }); - await AssertIListRoundtrips(conn, new MisleadingIntList() { 1, 2, 3 }); + .TypeOf() + .With.Property("InnerException").Message.Contains("jagged")); } [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1546")] @@ -618,17 +595,6 @@ public async Task Read_two_empty_arrays() Assert.AreNotSame(reader.GetFieldValue>(0), reader.GetFieldValue>(1)); } - async Task AssertIListRoundtrips(NpgsqlConnection conn, IEnumerable value) - { - await using var cmd = new NpgsqlCommand("SELECT @p", conn); - cmd.Parameters.Add(new 
NpgsqlParameter { ParameterName = "p", Value = value }); - - await using var reader = await cmd.ExecuteReaderAsync(); - reader.Read(); - Assert.That(reader.GetDataTypeName(0), Is.EqualTo("integer[]")); - Assert.That(reader[0], Is.EqualTo(value.ToArray())); - } - class IntList : List { } // ReSharper disable UnusedTypeParameter class MisleadingIntList : List { } diff --git a/test/Npgsql.Tests/Types/BitStringTests.cs b/test/Npgsql.Tests/Types/BitStringTests.cs index 4a22a2a9e6..95c81ffb41 100644 --- a/test/Npgsql.Tests/Types/BitStringTests.cs +++ b/test/Npgsql.Tests/Types/BitStringTests.cs @@ -51,7 +51,7 @@ public Task BitVector32() [Test] public Task BitVector32_too_long() - => AssertTypeUnsupportedRead(new string('0', 34), "bit varying"); + => AssertTypeUnsupportedRead(new string('0', 34), "bit varying"); [Test] public Task Bool() @@ -60,8 +60,8 @@ public Task Bool() [Test] public async Task Bitstring_with_multiple_bits_as_bool_throws() { - await AssertTypeUnsupportedRead("01", "varbit"); - await AssertTypeUnsupportedRead("01", "bit(2)"); + await AssertTypeUnsupportedRead("01", "varbit"); + await AssertTypeUnsupportedRead("01", "bit(2)"); } [Test] @@ -117,16 +117,12 @@ public async Task Array_of_single_bits_and_null() } [Test] - public Task Write_as_string() - => AssertTypeWrite("010101", "010101", "bit varying", NpgsqlDbType.Varbit, isDefault: false); + public Task As_string() + => AssertType("010101", "010101", "bit varying", NpgsqlDbType.Varbit, isDefault: false); [Test] public Task Write_as_string_validation() - => AssertTypeUnsupportedWrite("001q0", "bit varying"); - - [Test] - public Task Read_as_string_is_not_supported() - => AssertTypeUnsupportedRead("010101", "bit varying"); + => AssertTypeUnsupportedWrite("001q0", "bit varying"); public BitStringTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} } diff --git a/test/Npgsql.Tests/Types/ByteaTests.cs b/test/Npgsql.Tests/Types/ByteaTests.cs index f29f6e490b..2db7aca492 100644 --- 
a/test/Npgsql.Tests/Types/ByteaTests.cs +++ b/test/Npgsql.Tests/Types/ByteaTests.cs @@ -35,36 +35,26 @@ public async Task Bytea_long() } [Test] - public Task Write_as_Memory() - => AssertTypeWrite( - new Memory(new byte[] { 1, 2, 3 }), "\\x010203", "bytea", NpgsqlDbType.Bytea, DbType.Binary, isDefault: false); - - [Test] - public Task Read_as_Memory_not_supported() - => AssertTypeUnsupportedRead, NotSupportedException>("\\x010203", "bytea"); + public Task AsMemory() + => AssertType( + new Memory(new byte[] { 1, 2, 3 }), "\\x010203", "bytea", NpgsqlDbType.Bytea, DbType.Binary, isDefault: false, + comparer: (left, right) => left.Span.SequenceEqual(right.Span)); [Test] - public Task Write_as_ReadOnlyMemory() - => AssertTypeWrite( - new ReadOnlyMemory(new byte[] { 1, 2, 3 }), "\\x010203", "bytea", NpgsqlDbType.Bytea, DbType.Binary, isDefault: false); + public Task AsReadOnlyMemory() + => AssertType( + new ReadOnlyMemory(new byte[] { 1, 2, 3 }), "\\x010203", "bytea", NpgsqlDbType.Bytea, DbType.Binary, isDefault: false, + comparer: (left, right) => left.Span.SequenceEqual(right.Span)); [Test] - public Task Read_as_ReadOnlyMemory_not_supported() - => AssertTypeUnsupportedRead, NotSupportedException>("\\x010203", "bytea"); - - [Test] - public Task Write_as_ArraySegment() - => AssertTypeWrite( + public Task AsArraySegment() + => AssertType( new ArraySegment(new byte[] { 1, 2, 3 }), "\\x010203", "bytea", NpgsqlDbType.Bytea, DbType.Binary, isDefault: false); - [Test] - public Task Read_as_ArraySegment_not_supported() - => AssertTypeUnsupportedRead, NotSupportedException>("\\x010203", "bytea"); - [Test] public Task Write_as_MemoryStream() => AssertTypeWrite( - () => new MemoryStream(new byte[] { 1, 2, 3 }), "\\x010203", "bytea", NpgsqlDbType.Bytea, DbType.Binary, isDefault: false); + () => new MemoryStream(new byte[] { 1, 2, 3 }), "\\x010203", "bytea", NpgsqlDbType.Bytea, DbType.Binary, isDefault: false, skipArrayCheck: true); [Test] public Task 
Write_as_MemoryStream_truncated() @@ -77,7 +67,7 @@ public Task Write_as_MemoryStream_truncated() }; return AssertTypeWrite( - msFactory, "\\x020304", "bytea", NpgsqlDbType.Bytea, DbType.Binary, isDefault: false); + msFactory, "\\x020304", "bytea", NpgsqlDbType.Bytea, DbType.Binary, isDefault: false, skipArrayCheck: true); } [Test] @@ -89,7 +79,7 @@ public async Task Write_as_MemoryStream_long() var expectedSql = "\\x" + ToHex(bytes); await AssertTypeWrite( - () => new MemoryStream(bytes), expectedSql, "bytea", NpgsqlDbType.Bytea, DbType.Binary, isDefault: false); + () => new MemoryStream(bytes), expectedSql, "bytea", NpgsqlDbType.Bytea, DbType.Binary, isDefault: false, skipArrayCheck: true); } [Test] @@ -102,7 +92,7 @@ public async Task Write_as_FileStream() await File.WriteAllBytesAsync(filePath, new byte[] { 1, 2, 3 }); await AssertTypeWrite( - () => FileStreamFactory(filePath, fsList), "\\x010203", "bytea", NpgsqlDbType.Bytea, DbType.Binary, isDefault: false); + () => FileStreamFactory(filePath, fsList), "\\x010203", "bytea", NpgsqlDbType.Bytea, DbType.Binary, isDefault: false, skipArrayCheck: true); } finally { @@ -138,7 +128,7 @@ public async Task Write_as_FileStream_long() var expectedSql = "\\x" + ToHex(bytes); await AssertTypeWrite( - () => FileStreamFactory(filePath, fsList), expectedSql, "bytea", NpgsqlDbType.Bytea, DbType.Binary, isDefault: false); + () => FileStreamFactory(filePath, fsList), expectedSql, "bytea", NpgsqlDbType.Bytea, DbType.Binary, isDefault: false, skipArrayCheck: true); } finally { @@ -187,6 +177,7 @@ public async Task Truncate_array() var p = new NpgsqlParameter("p", data) { Size = 4 }; cmd.Parameters.Add(p); Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(new byte[] { 1, 2, 3, 4 })); + Assert.That(p.Value, Is.EqualTo(new byte[] { 1, 2, 3, 4 }), "Truncated parameter value should be persisted on the parameter per DbParameter.Size docs"); // NpgsqlParameter.Size needs to persist when value is changed byte[] data2 = { 11, 12, 13, 
14, 15, 16 }; @@ -194,6 +185,7 @@ public async Task Truncate_array() Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(new byte[] { 11, 12, 13, 14 })); // NpgsqlParameter.Size larger than the value size should mean the value size, as well as 0 and -1 + p.Value = data2; p.Size = data2.Length + 10; Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(data2)); p.Size = 0; @@ -205,7 +197,6 @@ public async Task Truncate_array() } [Test, Description("Tests that bytea stream values are truncated when the NpgsqlParameter's Size is set")] - [NonParallelizable] // The last check will break the connection, which can fail other unrelated queries in multiplexing public async Task Truncate_stream() { await using var conn = await OpenConnectionAsync(); @@ -235,13 +226,9 @@ public async Task Truncate_stream() Assert.That(() => p.Size = -2, Throws.Exception.TypeOf()); - // NpgsqlParameter.Size larger than the value size should throw - p.Size = data2.Length + 10; p.Value = new MemoryStream(data2); - var ex = Assert.ThrowsAsync(async () => await cmd.ExecuteScalarAsync())!; - Assert.That(ex.InnerException, Is.TypeOf()); - if (!IsMultiplexing) - Assert.That(conn.State, Is.EqualTo(ConnectionState.Closed)); + p.Size = data2.Length + 10; + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(data2)); } [Test] @@ -261,7 +248,7 @@ public async Task Write_as_NonSeekable_stream() p.Value = new NonSeekableStream(data); p.Size = 0; - Assert.ThrowsAsync(async () => await cmd.ExecuteScalarAsync()); + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(data)); Assert.That(conn.State, Is.EqualTo(ConnectionState.Open)); } diff --git a/test/Npgsql.Tests/Types/CompositeHandlerTests.Read.cs b/test/Npgsql.Tests/Types/CompositeHandlerTests.Read.cs index 9e252e2d1b..bf9cb38241 100644 --- a/test/Npgsql.Tests/Types/CompositeHandlerTests.Read.cs +++ b/test/Npgsql.Tests/Types/CompositeHandlerTests.Read.cs @@ -62,7 +62,8 @@ public Task Read_type_with_two_properties_inverted() => [Test] public Task 
Read_type_with_private_property_throws() => - Read(new TypeWithPrivateProperty(), (execute, expected) => Assert.Throws(() => execute())); + Read(new TypeWithPrivateProperty(), (execute, expected) => + Assert.That(() => execute(), Throws.Exception.TypeOf().With.Property("InnerException").TypeOf())); [Test] public Task Read_type_with_private_getter() => @@ -99,15 +100,18 @@ public Task Read_type_with_more_properties_than_attributes() => [Test] public Task Read_type_with_less_properties_than_attributes_throws() => - Read(new TypeWithLessPropertiesThanAttributes(), (execute, expected) => Assert.Throws(() => execute())); + Read(new TypeWithLessPropertiesThanAttributes(), (execute, expected) => + Assert.That(() => execute(), Throws.Exception.TypeOf().With.Property("InnerException").TypeOf())); [Test] public Task Read_type_with_less_parameters_than_attributes_throws() => - Read(new TypeWithLessParametersThanAttributes(TheAnswer), (execute, expected) => Assert.Throws(() => execute())); + Read(new TypeWithLessParametersThanAttributes(TheAnswer), (execute, expected) => + Assert.That(() => execute(), Throws.Exception.TypeOf().With.Property("InnerException").TypeOf())); [Test] public Task Read_type_with_more_parameters_than_attributes_throws() => - Read(new TypeWithMoreParametersThanAttributes(TheAnswer, HelloSlonik), (execute, expected) => Assert.Throws(() => execute())); + Read(new TypeWithMoreParametersThanAttributes(TheAnswer, HelloSlonik), (execute, expected) => + Assert.That(() => execute(), Throws.Exception.TypeOf().With.Property("InnerException").TypeOf())); [Test] public Task Read_type_with_one_parameter() => diff --git a/test/Npgsql.Tests/Types/CompositeHandlerTests.Write.cs b/test/Npgsql.Tests/Types/CompositeHandlerTests.Write.cs index 938ac9f01a..a251cdd4ed 100644 --- a/test/Npgsql.Tests/Types/CompositeHandlerTests.Write.cs +++ b/test/Npgsql.Tests/Types/CompositeHandlerTests.Write.cs @@ -63,7 +63,9 @@ public Task Write_type_with_two_properties_inverted() [Test] 
public void Write_type_with_private_property_throws() - => Assert.ThrowsAsync(async () => await Write(new TypeWithPrivateProperty())); + => Assert.ThrowsAsync( + Is.TypeOf().With.Property("InnerException").TypeOf(), + async () => await Write(new TypeWithPrivateProperty())); [Test] public void Write_type_with_private_getter_throws() @@ -95,13 +97,19 @@ public Task Write_type_with_more_properties_than_attributes() [Test] public void Write_type_with_less_properties_than_attributes_throws() - => Assert.ThrowsAsync(async () => await Write(new TypeWithLessPropertiesThanAttributes())); + => Assert.ThrowsAsync( + Is.TypeOf().With.Property("InnerException").TypeOf(), + async () => await Write(new TypeWithLessPropertiesThanAttributes())); [Test] public void Write_type_with_less_parameters_than_attributes_throws() - => Assert.ThrowsAsync(async () => await Write(new TypeWithMoreParametersThanAttributes(TheAnswer, HelloSlonik))); + => Assert.ThrowsAsync( + Is.TypeOf().With.Property("InnerException").TypeOf(), + async () => await Write(new TypeWithMoreParametersThanAttributes(TheAnswer, HelloSlonik))); [Test] public void Write_type_with_more_parameters_than_attributes_throws() - => Assert.ThrowsAsync(async () => await Write(new TypeWithLessParametersThanAttributes(TheAnswer))); + => Assert.ThrowsAsync( + Is.TypeOf().With.Property("InnerException").TypeOf(), + async () => await Write(new TypeWithLessParametersThanAttributes(TheAnswer))); } diff --git a/test/Npgsql.Tests/Types/CompositeTests.cs b/test/Npgsql.Tests/Types/CompositeTests.cs index 2795968470..11f7739158 100644 --- a/test/Npgsql.Tests/Types/CompositeTests.cs +++ b/test/Npgsql.Tests/Types/CompositeTests.cs @@ -138,7 +138,8 @@ await AssertType( new SomeCompositeContainer { A = 8, Containee = new() { SomeText = "foo", X = 9 } }, @"(8,""(9,foo)"")", $"{secondSchemaName}.container", - npgsqlDbType: null); + npgsqlDbType: null, + isDefaultForWriting: false); await AssertType( connection, @@ -146,7 +147,7 @@ await AssertType( 
@"(8,""(9,foo)"")", $"{firstSchemaName}.container", npgsqlDbType: null, - isDefaultForWriting: false); + isDefaultForWriting: true); } [Test] @@ -237,6 +238,29 @@ await AssertType( npgsqlDbType: null); } + [Test] + public async Task Composite_containing_array_type() + { + await using var adminConnection = await OpenConnectionAsync(); + var compositeType = await GetTempTypeName(adminConnection); + + await adminConnection.ExecuteNonQueryAsync($@" +CREATE TYPE {compositeType} AS (ints int4[])"); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapComposite(compositeType); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); + + await AssertType( + connection, + new SomeCompositeWithArray { Ints = new[] { 1, 2, 3, 4 } }, + @"(""{1,2,3,4}"")", + compositeType, + npgsqlDbType: null, + comparer: (actual, expected) => actual.Ints!.SequenceEqual(expected.Ints!)); + } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/990")] public async Task Table_as_composite([Values] bool enabled) { @@ -254,7 +278,7 @@ public async Task Table_as_composite([Values] bool enabled) await DoAssertion(); else { - Assert.ThrowsAsync(DoAssertion); + Assert.ThrowsAsync(DoAssertion); // Start a transaction specifically for multiplexing (to bind a connector to the connection) await using var tx = await connection.BeginTransactionAsync(); Assert.Null(connection.Connector!.DatabaseInfo.CompositeTypes.SingleOrDefault(c => c.Name.Contains(table))); @@ -402,6 +426,11 @@ struct SomeCompositeStruct public string SomeText { get; set; } } + class SomeCompositeWithArray + { + public int[]? 
Ints { get; set; } + } + record NameTranslationComposite { public int Simple { get; set; } diff --git a/test/Npgsql.Tests/Types/DateTimeInfinityTests.cs b/test/Npgsql.Tests/Types/DateTimeInfinityTests.cs index fc316adfee..bf2e0d0e65 100644 --- a/test/Npgsql.Tests/Types/DateTimeInfinityTests.cs +++ b/test/Npgsql.Tests/Types/DateTimeInfinityTests.cs @@ -22,16 +22,20 @@ public async Task TimestampTz_write() { Parameters = { - new() { Value = DateTime.MinValue, NpgsqlDbType = NpgsqlDbType.TimestampTz }, + new() + { + Value = DisableDateTimeInfinityConversions ? DateTime.MinValue.ToUniversalTime().AddYears(1) : DateTime.MinValue, + NpgsqlDbType = NpgsqlDbType.TimestampTz + }, } }; - Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(DisableDateTimeInfinityConversions ? "0001-01-01 00:00:00" : "-infinity")); + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(DisableDateTimeInfinityConversions ? "0002-01-01 00:00:00" : "-infinity")); cmd.Parameters[0].Value = DateTime.MaxValue; if (DisableDateTimeInfinityConversions) - Assert.That(async () => await cmd.ExecuteScalarAsync(), Throws.Exception.TypeOf()); + Assert.That(async () => await cmd.ExecuteScalarAsync(), Throws.Exception.TypeOf()); else Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo("infinity")); } @@ -205,6 +209,8 @@ public DateTimeInfinityTests(bool disableDateTimeInfinityConversions) "DateTimeInfinityTests rely on the Npgsql.DisableDateTimeInfinityConversions AppContext switch and can only be run in DEBUG builds"); } #endif + // The switch is baked into the serializer options, so clear the sources on change here. 
+ ClearDataSources(); } public void Dispose() diff --git a/test/Npgsql.Tests/Types/DateTimeTests.cs b/test/Npgsql.Tests/Types/DateTimeTests.cs index f387387dcc..7382891cf5 100644 --- a/test/Npgsql.Tests/Types/DateTimeTests.cs +++ b/test/Npgsql.Tests/Types/DateTimeTests.cs @@ -64,7 +64,21 @@ public Task Daterange_as_NpgsqlRange_of_DateOnly() "[2002-03-04,2002-03-06)", "daterange", NpgsqlDbType.DateRange, - isDefaultForReading: false); + isDefaultForReading: false, + skipArrayCheck: true); // NpgsqlRange[] is mapped to multirange by default, not array; test separately + + [Test] + public Task Daterange_array_as_NpgsqlRange_of_DateOnly_array() + => AssertType( + new[] + { + new NpgsqlRange(new(2002, 3, 4), true, new(2002, 3, 6), false), + new NpgsqlRange(new(2002, 3, 8), true, new(2002, 3, 9), false) + }, + """{"[2002-03-04,2002-03-06)","[2002-03-08,2002-03-09)"}""", + "daterange[]", + NpgsqlDbType.DateRange | NpgsqlDbType.Array, + isDefault: false); [Test] public async Task Datemultirange_as_array_of_NpgsqlRange_of_DateOnly() @@ -72,7 +86,7 @@ public async Task Datemultirange_as_array_of_NpgsqlRange_of_DateOnly() await using var conn = await OpenConnectionAsync(); MinimumPgVersion(conn, "14.0", "Multirange types were introduced in PostgreSQL 14"); - await AssertType( + await AssertType( new[] { new NpgsqlRange(new(2002, 3, 4), true, new(2002, 3, 6), false), @@ -150,11 +164,13 @@ public Task TimeTz_before_utc_zero() [Test, TestCaseSource(nameof(TimestampValues))] public Task Timestamp_as_DateTime(DateTime dateTime, string sqlLiteral) - => AssertType(dateTime, sqlLiteral, "timestamp without time zone", NpgsqlDbType.Timestamp, DbType.DateTime2); + => AssertType(dateTime, sqlLiteral, "timestamp without time zone", NpgsqlDbType.Timestamp, DbType.DateTime2, + // Explicitly check kind as well. 
+ comparer: (actual, expected) => actual.Kind == expected.Kind && actual.Equals(expected)); [Test] public Task Timestamp_cannot_write_utc_DateTime() - => AssertTypeUnsupportedWrite(new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Utc), "timestamp without time zone"); + => AssertTypeUnsupportedWrite(new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Utc), "timestamp without time zone"); [Test] public Task Timestamp_as_long() @@ -181,7 +197,25 @@ public Task Tsrange_as_NpgsqlRange_of_DateTime() new(1998, 4, 12, 15, 26, 38, DateTimeKind.Local)), @"[""1998-04-12 13:26:38"",""1998-04-12 15:26:38""]", "tsrange", - NpgsqlDbType.TimestampRange); + NpgsqlDbType.TimestampRange, + skipArrayCheck: true); // NpgsqlRange[] is mapped to multirange by default, not array; test separately + + [Test] + public Task Tsrange_array_as_NpgsqlRange_of_DateTime_array() + => AssertType( + new[] + { + new NpgsqlRange( + new(1998, 4, 12, 13, 26, 38, DateTimeKind.Local), + new(1998, 4, 12, 15, 26, 38, DateTimeKind.Local)), + new NpgsqlRange( + new(1998, 4, 13, 13, 26, 38, DateTimeKind.Local), + new(1998, 4, 13, 15, 26, 38, DateTimeKind.Local)), + }, + """{"[\"1998-04-12 13:26:38\",\"1998-04-12 15:26:38\"]","[\"1998-04-13 13:26:38\",\"1998-04-13 15:26:38\"]"}""", + "tsrange[]", + NpgsqlDbType.TimestampRange | NpgsqlDbType.Array, + isDefault: false); [Test] public async Task Tsmultirange_as_array_of_NpgsqlRange_of_DateTime() @@ -222,7 +256,9 @@ await AssertType( [Test, TestCaseSource(nameof(TimestampTzWriteValues))] public Task Timestamptz_as_DateTime(DateTime dateTime, string sqlLiteral) - => AssertType(dateTime, sqlLiteral, "timestamp with time zone", NpgsqlDbType.TimestampTz, DbType.DateTime); + => AssertType(dateTime, sqlLiteral, "timestamp with time zone", NpgsqlDbType.TimestampTz, DbType.DateTime, + // Explicitly check kind as well. 
+ comparer: (actual, expected) => actual.Kind == expected.Kind && actual.Equals(expected)); [Test] public async Task Timestamptz_infinity_as_DateTime() @@ -236,8 +272,8 @@ await AssertType(DateTime.MaxValue, "infinity", "timestamp with time zone", Npgs [Test] public async Task Timestamptz_cannot_write_non_utc_DateTime() { - await AssertTypeUnsupportedWrite(new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Unspecified), "timestamp with time zone"); - await AssertTypeUnsupportedWrite(new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Local), "timestamp with time zone"); + await AssertTypeUnsupportedWrite(new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Unspecified), "timestamp with time zone"); + await AssertTypeUnsupportedWrite(new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Local), "timestamp with time zone"); } [Test] @@ -267,7 +303,7 @@ public Task Timestamptz_as_DateTimeOffset_utc_with_DbType_DateTimeOffset() [Test] public Task Timestamptz_cannot_write_non_utc_DateTimeOffset() - => AssertTypeUnsupportedWrite(new DateTimeOffset(1998, 4, 12, 13, 26, 38, TimeSpan.FromHours(2))); + => AssertTypeUnsupportedWrite(new DateTimeOffset(1998, 4, 12, 13, 26, 38, TimeSpan.FromHours(2))); [Test] public Task Timestamptz_as_long() @@ -279,6 +315,24 @@ public Task Timestamptz_as_long() DbType.DateTime, isDefault: false); + [Test] + public async Task Timestamptz_array_as_DateTimeOffset_array() + { + var dateTimeOffsets = await AssertType( + new[] + { + new DateTimeOffset(1998, 4, 12, 13, 26, 38, TimeSpan.Zero), + new DateTimeOffset(1999, 4, 12, 13, 26, 38, TimeSpan.Zero) + }, + """{"1998-04-12 15:26:38+02","1999-04-12 15:26:38+02"}""", + "timestamp with time zone[]", + NpgsqlDbType.TimestampTz | NpgsqlDbType.Array, + isDefaultForReading: false); + + Assert.That(dateTimeOffsets[0].Offset, Is.EqualTo(TimeSpan.Zero)); + Assert.That(dateTimeOffsets[1].Offset, Is.EqualTo(TimeSpan.Zero)); + } + [Test] public Task Tstzrange_as_NpgsqlRange_of_DateTime() => AssertType( @@ -287,7 
+341,25 @@ public Task Tstzrange_as_NpgsqlRange_of_DateTime() new DateTime(1998, 4, 12, 15, 26, 38, DateTimeKind.Utc)), @"[""1998-04-12 15:26:38+02"",""1998-04-12 17:26:38+02""]", "tstzrange", - NpgsqlDbType.TimestampTzRange); + NpgsqlDbType.TimestampTzRange, + skipArrayCheck: true); // NpgsqlRange[] is mapped to multirange by default, not array; test separately + + [Test] + public Task Tstzrange_array_as_NpgsqlRange_of_DateTime_array() + => AssertType( + new[] + { + new NpgsqlRange( + new(1998, 4, 12, 13, 26, 38, DateTimeKind.Utc), + new(1998, 4, 12, 15, 26, 38, DateTimeKind.Utc)), + new NpgsqlRange( + new(1998, 4, 13, 13, 26, 38, DateTimeKind.Utc), + new(1998, 4, 13, 15, 26, 38, DateTimeKind.Utc)), + }, + """{"[\"1998-04-12 15:26:38+02\",\"1998-04-12 17:26:38+02\"]","[\"1998-04-13 15:26:38+02\",\"1998-04-13 17:26:38+02\"]"}""", + "tstzrange[]", + NpgsqlDbType.TimestampTzRange | NpgsqlDbType.Array, + isDefault: false); [Test] public async Task Tstzmultirange_as_array_of_NpgsqlRange_of_DateTime() @@ -312,7 +384,7 @@ await AssertType( [Test] public Task Cannot_mix_DateTime_Kinds_in_array() - => AssertTypeUnsupportedWrite(new[] + => AssertTypeUnsupportedWrite(new[] { new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Utc), new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Local), @@ -321,7 +393,7 @@ public Task Cannot_mix_DateTime_Kinds_in_array() [Test] public Task Cannot_mix_DateTime_Kinds_in_range() - => AssertTypeUnsupportedWrite(new NpgsqlRange( + => AssertTypeUnsupportedWrite, ArgumentException>(new NpgsqlRange( new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Utc), new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Local))); @@ -331,8 +403,29 @@ public async Task Cannot_mix_DateTime_Kinds_in_multirange() await using var conn = await OpenConnectionAsync(); MinimumPgVersion(conn, "14.0", "Multirange types were introduced in PostgreSQL 14"); - await AssertTypeUnsupportedWrite(new[] + await AssertTypeUnsupportedWrite[], ArgumentException>(new[] { + new 
NpgsqlRange( + new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Utc), + new DateTime(1998, 4, 12, 15, 26, 38, DateTimeKind.Utc)), + new NpgsqlRange( + new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Utc), + new DateTime(1998, 4, 12, 15, 26, 38, DateTimeKind.Utc)), + new NpgsqlRange( + new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Utc), + new DateTime(1998, 4, 12, 15, 26, 38, DateTimeKind.Utc)), + new NpgsqlRange( + new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Utc), + new DateTime(1998, 4, 12, 15, 26, 38, DateTimeKind.Utc)), + new NpgsqlRange( + new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Utc), + new DateTime(1998, 4, 12, 15, 26, 38, DateTimeKind.Utc)), + new NpgsqlRange( + new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Utc), + new DateTime(1998, 4, 12, 15, 26, 38, DateTimeKind.Utc)), + new NpgsqlRange( + new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Utc), + new DateTime(1998, 4, 12, 15, 26, 38, DateTimeKind.Utc)), new NpgsqlRange( new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Utc), new DateTime(1998, 4, 12, 15, 26, 38, DateTimeKind.Utc)), @@ -375,6 +468,19 @@ public void NpgsqlParameterNpgsqlDbType_is_value_dependent_timestamp_or_timestam Assert.AreEqual(NpgsqlDbType.TimestampTz, dtotimestamptz.NpgsqlDbType); } + [Test] + public async Task Array_of_nullable_timestamptz() + => await AssertType( + new DateTime?[] + { + new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Utc), + null + }, + @"{""1998-04-12 15:26:38+02"",NULL}", + "timestamp with time zone[]", + NpgsqlDbType.TimestampTz | NpgsqlDbType.Array, + isDefault: false); + #endregion #region Interval @@ -415,7 +521,7 @@ public Task Interval_as_NpgsqlInterval() [Test] public Task Interval_with_months_cannot_read_as_TimeSpan() - => AssertTypeUnsupportedRead("1 month 2 days", "interval"); + => AssertTypeUnsupportedRead("1 month 2 days", "interval"); #endregion diff --git a/test/Npgsql.Tests/Types/FullTextSearchTests.cs b/test/Npgsql.Tests/Types/FullTextSearchTests.cs index 
ae759a295c..5ffe8b2880 100644 --- a/test/Npgsql.Tests/Types/FullTextSearchTests.cs +++ b/test/Npgsql.Tests/Types/FullTextSearchTests.cs @@ -4,7 +4,6 @@ using Npgsql.Properties; using NpgsqlTypes; using NUnit.Framework; -using NUnit.Framework.Constraints; namespace Npgsql.Tests.Types; @@ -78,15 +77,15 @@ public async Task Full_text_search_supported_only_with_EnableFullTextSearch([Val } else { - var exception = await AssertTypeUnsupportedRead("a", "tsquery", dataSource); - Assert.AreEqual(errorMessage, exception.Message); - exception = await AssertTypeUnsupportedWrite(new NpgsqlTsQueryLexeme("a"), pgTypeName: null, dataSource); - Assert.AreEqual(errorMessage, exception.Message); + var exception = await AssertTypeUnsupportedRead("a", "tsquery", dataSource); + Assert.AreEqual(errorMessage, exception.InnerException!.Message); + exception = await AssertTypeUnsupportedWrite(new NpgsqlTsQueryLexeme("a"), pgTypeName: null, dataSource); + Assert.AreEqual(errorMessage, exception.InnerException!.Message); - exception = await AssertTypeUnsupportedRead("1", "tsvector", dataSource); - Assert.AreEqual(errorMessage, exception.Message); - exception = await AssertTypeUnsupportedWrite(NpgsqlTsVector.Parse("'1'"), pgTypeName: null, dataSource); - Assert.AreEqual(errorMessage, exception.Message); + exception = await AssertTypeUnsupportedRead("1", "tsvector", dataSource); + Assert.AreEqual(errorMessage, exception.InnerException!.Message); + exception = await AssertTypeUnsupportedWrite(NpgsqlTsVector.Parse("'1'"), pgTypeName: null, dataSource); + Assert.AreEqual(errorMessage, exception.InnerException!.Message); } } } diff --git a/test/Npgsql.Tests/Types/GeometricTypeTests.cs b/test/Npgsql.Tests/Types/GeometricTypeTests.cs index b4101cca14..d84218bd12 100644 --- a/test/Npgsql.Tests/Types/GeometricTypeTests.cs +++ b/test/Npgsql.Tests/Types/GeometricTypeTests.cs @@ -26,12 +26,25 @@ public Task LineSegment() [Test] public Task Box() - => AssertType(new NpgsqlBox(3, 4, 1, 2), "(4,3),(2,1)", 
"box", NpgsqlDbType.Box); + => AssertType(new NpgsqlBox(3, 4, 1, 2), "(4,3),(2,1)", "box", NpgsqlDbType.Box, + skipArrayCheck: true); // Uses semicolon instead of comma as separator + + [Test] + public Task Box_array() + => AssertType( + new[] + { + new NpgsqlBox(3, 4, 1, 2), + new NpgsqlBox(5, 6, 3, 4) + }, + "{(4,3),(2,1);(6,5),(4,3)}", + "box[]", + NpgsqlDbType.Box | NpgsqlDbType.Array); [Test] public Task Path_closed() => AssertType( - new NpgsqlPath(new[] {new NpgsqlPoint(1, 2), new NpgsqlPoint(3, 4)}, false), + new NpgsqlPath(new[] { new NpgsqlPoint(1, 2), new NpgsqlPoint(3, 4) }, false), "((1,2),(3,4))", "path", NpgsqlDbType.Path); diff --git a/test/Npgsql.Tests/Types/HstoreTests.cs b/test/Npgsql.Tests/Types/HstoreTests.cs index ab1ee2ad6c..5696cad98b 100644 --- a/test/Npgsql.Tests/Types/HstoreTests.cs +++ b/test/Npgsql.Tests/Types/HstoreTests.cs @@ -6,7 +6,6 @@ namespace Npgsql.Tests.Types; -[NonParallelizable] public class HstoreTests : MultiplexingTestBase { [Test] @@ -20,11 +19,11 @@ public Task Hstore() }, @"""a""=>""3"", ""b""=>NULL, ""cd""=>""hello""", "hstore", - NpgsqlDbType.Hstore); + NpgsqlDbType.Hstore, isNpgsqlDbTypeInferredFromClrType: false); [Test] public Task Hstore_empty() - => AssertType(new Dictionary(), @"", "hstore", NpgsqlDbType.Hstore); + => AssertType(new Dictionary(), @"", "hstore", NpgsqlDbType.Hstore, isNpgsqlDbTypeInferredFromClrType: false); [Test] public Task Hstore_as_ImmutableDictionary() @@ -40,7 +39,7 @@ public Task Hstore_as_ImmutableDictionary() @"""a""=>""3"", ""b""=>NULL, ""cd""=>""hello""", "hstore", NpgsqlDbType.Hstore, - isDefaultForReading: false); + isDefaultForReading: false, isNpgsqlDbTypeInferredFromClrType: false); } [Test] @@ -55,7 +54,7 @@ public Task Hstore_as_IDictionary() @"""a""=>""3"", ""b""=>NULL, ""cd""=>""hello""", "hstore", NpgsqlDbType.Hstore, - isDefaultForReading: false); + isDefaultForReading: false, isNpgsqlDbTypeInferredFromClrType: false); [OneTimeSetUp] public async Task SetUp() diff --git 
a/test/Npgsql.Tests/Types/JsonPathTests.cs b/test/Npgsql.Tests/Types/JsonPathTests.cs index 3d068aa3d2..de49a631e0 100644 --- a/test/Npgsql.Tests/Types/JsonPathTests.cs +++ b/test/Npgsql.Tests/Types/JsonPathTests.cs @@ -1,4 +1,5 @@ -using System.Threading.Tasks; +using System.Data; +using System.Threading.Tasks; using NpgsqlTypes; using NUnit.Framework; using static Npgsql.Tests.TestUtil; @@ -16,6 +17,18 @@ public JsonPathTests(MultiplexingMode multiplexingMode) new object[] { "'$\"varname\"'", "$\"varname\"" }, }; + [Test] + [TestCase("$")] + [TestCase("$\"varname\"")] + public async Task JsonPath(string jsonPath) + { + using var conn = await OpenConnectionAsync(); + MinimumPgVersion(conn, "12.0", "The jsonpath type was introduced in PostgreSQL 12"); + await AssertType( + jsonPath, jsonPath, "jsonpath", NpgsqlDbType.JsonPath, isDefaultForWriting: false, isNpgsqlDbTypeInferredFromClrType: false, + inferredDbType: DbType.Object); + } + [Test] [TestCaseSource(nameof(ReadWriteCases))] public async Task Read(string query, string expected) @@ -43,4 +56,4 @@ public async Task Write(string query, string expected) Assert.True(rdr.Read()); } -} \ No newline at end of file +} diff --git a/test/Npgsql.Tests/Types/JsonTests.cs b/test/Npgsql.Tests/Types/JsonTests.cs index 2323430773..3460a88a5c 100644 --- a/test/Npgsql.Tests/Types/JsonTests.cs +++ b/test/Npgsql.Tests/Types/JsonTests.cs @@ -2,6 +2,7 @@ using System.Text; using System.Text.Json; using System.Text.Json.Nodes; +using System.Text.Json.Serialization; using System.Threading.Tasks; using NpgsqlTypes; using NUnit.Framework; @@ -166,6 +167,54 @@ await AssertTypeUnsupported( slimDataSource); } + [Test] + public async Task Poco_does_not_stomp_GetValue_string() + { + var dataSourceBuilder = CreateDataSourceBuilder(); + var dataSource = dataSourceBuilder.UseSystemTextJson(null, new[] {typeof(WeatherForecast)}, new[] {typeof(WeatherForecast)}).Build(); + var sqlLiteral = + IsJsonb + ? 
"""{"Date": "2019-09-01T00:00:00", "Summary": "Partly cloudy", "TemperatureC": 10}""" + : """{"Date":"2019-09-01T00:00:00","TemperatureC":10,"Summary":"Partly cloudy"}"""; + await using var conn = await dataSource.OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand($"SELECT '{sqlLiteral}'::{(IsJsonb ? "jsonb" : "json")}", conn); + await using var reader = await cmd.ExecuteReaderAsync(); + await reader.ReadAsync(); + + Assert.That(reader.GetValue(0), Is.TypeOf()); + } + + [Test] + public Task Roundtrip_string() + => AssertType( + @"{""p"": 1}", + @"{""p"": 1}", + PostgresType, + NpgsqlDbType, + isDefault: false, + isNpgsqlDbTypeInferredFromClrType: false); + + [Test] + public Task Roundtrip_char_array() + => AssertType( + @"{""p"": 1}".ToCharArray(), + @"{""p"": 1}", + PostgresType, + NpgsqlDbType, + isDefault: false, + isNpgsqlDbTypeInferredFromClrType: false); + + [Test] + public Task Roundtrip_byte_array() + => AssertType( + Encoding.ASCII.GetBytes(@"{""p"": 1}"), + @"{""p"": 1}", + PostgresType, + NpgsqlDbType, + isDefault: false, + isNpgsqlDbTypeInferredFromClrType: false); + + [JsonDerivedType(typeof(ExtendedDerivedWeatherForecast), typeDiscriminator: "extended")] record WeatherForecast { public DateTime Date { get; set; } @@ -173,6 +222,15 @@ record WeatherForecast public string Summary { get; set; } = ""; } + record DerivedWeatherForecast : WeatherForecast + { + } + + record ExtendedDerivedWeatherForecast : DerivedWeatherForecast + { + public int TemperatureF => 32 + (int)(TemperatureC / 0.5556); + } + [Test] [IssueLink("https://github.com/npgsql/npgsql/issues/2811")] [IssueLink("https://github.com/npgsql/efcore.pg/issues/1177")] @@ -256,7 +314,7 @@ await AssertTypeWrite( isDefault: false); } - [Test] + [Test, Ignore("TODO We should not change the default type for json/jsonb, it makes little sense.")] public async Task Poco_default_mapping() { var dataSourceBuilder = CreateDataSourceBuilder(); @@ -266,7 +324,7 @@ public async Task 
Poco_default_mapping() dataSourceBuilder.UseSystemTextJson(jsonClrTypes: new[] { typeof(WeatherForecast) }); await using var dataSource = dataSourceBuilder.Build(); - await AssertTypeWrite( + await AssertType( dataSource, new WeatherForecast { @@ -279,9 +337,138 @@ await AssertTypeWrite( : """{"Date":"2019-09-01T00:00:00","TemperatureC":10,"Summary":"Partly cloudy"}""", PostgresType, NpgsqlDbType, + isDefaultForReading: false, + isNpgsqlDbTypeInferredFromClrType: false); + } + + [Test] + public async Task Poco_polymorphic_mapping() + { + // We don't yet support polymorphic deserialization for jsonb as $type does not come back as the first property. + // This could be fixed by detecting PolymorphicOptions types, always buffering their values and modifying the text. + if (IsJsonb) + return; + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.UseSystemTextJson(jsonClrTypes: new[] { typeof(WeatherForecast) }); + await using var dataSource = dataSourceBuilder.Build(); + + await AssertType( + dataSource, + new ExtendedDerivedWeatherForecast() + { + Date = new DateTime(2019, 9, 1), + Summary = "Partly cloudy", + TemperatureC = 10 + }, + """{"$type":"extended","TemperatureF":49,"Date":"2019-09-01T00:00:00","TemperatureC":10,"Summary":"Partly cloudy"}""", + PostgresType, + NpgsqlDbType, + isDefaultForReading: false, isNpgsqlDbTypeInferredFromClrType: false); } + [Test] + public async Task Poco_polymorphic_mapping_read_parents() + { + // We don't yet support polymorphic deserialization for jsonb as $type does not come back as the first property. + // This could be fixed by detecting PolymorphicOptions types, always buffering their values and modifying the text. 
+ if (IsJsonb) + return; + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.UseSystemTextJson(jsonClrTypes: new[] { typeof(WeatherForecast) }); + await using var dataSource = dataSourceBuilder.Build(); + + var value = new ExtendedDerivedWeatherForecast() + { + Date = new DateTime(2019, 9, 1), + Summary = "Partly cloudy", + TemperatureC = 10 + }; + + var sql = """{"$type":"extended","TemperatureF":49,"Date":"2019-09-01T00:00:00","TemperatureC":10,"Summary":"Partly cloudy"}"""; + + await AssertTypeWrite( + dataSource, + value, + sql, + PostgresType, + NpgsqlDbType, + isNpgsqlDbTypeInferredFromClrType: false); + + // GetFieldValue + await AssertTypeRead(dataSource, sql, PostgresType, value, + comparer: (_, actual) => actual.GetType() == typeof(ExtendedDerivedWeatherForecast), + isDefault: false); + + await AssertTypeRead(dataSource, sql, PostgresType, value, + comparer: (_, actual) => actual.GetType() == typeof(DerivedWeatherForecast), isDefault: false); + + await AssertTypeRead(dataSource, sql, PostgresType, value, isDefault: false); + } + + + [Test] + public async Task Poco_exact_polymorphic_mapping() + { + // We don't yet support polymorphic deserialization for jsonb as $type does not come back as the first property. + // This could be fixed by detecting PolymorphicOptions types, always buffering their values and modifying the text. 
+ if (IsJsonb) + return; + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.UseSystemTextJson(jsonClrTypes: new[] { typeof(ExtendedDerivedWeatherForecast) }); + await using var dataSource = dataSourceBuilder.Build(); + + await AssertType( + dataSource, + new ExtendedDerivedWeatherForecast() + { + Date = new DateTime(2019, 9, 1), + Summary = "Partly cloudy", + TemperatureC = 10 + }, + """{"TemperatureF":49,"Date":"2019-09-01T00:00:00","TemperatureC":10,"Summary":"Partly cloudy"}""", + PostgresType, + NpgsqlDbType, + isDefaultForReading: false, + isNpgsqlDbTypeInferredFromClrType: false); + } + + [Test] + public async Task Poco_unspecified_polymorphic_mapping() + { + // We don't yet support polymorphic deserialization for jsonb as $type does not come back as the first property. + // This could be fixed by detecting PolymorphicOptions types, always buffering their values and modifying the text. + // In this case we don't have any statically mapped base type to check its PolymorphicOptions on. + // Detecting whether the type could be polymorphic would require us to duplicate STJ's nearest polymorphic ancestor search. 
+ if (IsJsonb) + return; + + var value = new ExtendedDerivedWeatherForecast() + { + Date = new DateTime(2019, 9, 1), + Summary = "Partly cloudy", + TemperatureC = 10 + }; + + var sql = """{"$type":"extended","TemperatureF":49,"Date":"2019-09-01T00:00:00","TemperatureC":10,"Summary":"Partly cloudy"}"""; + + await AssertType( + value, + sql, + PostgresType, + NpgsqlDbType, + isDefault: false); + + await AssertTypeRead(DataSource, sql, PostgresType, value, + comparer: (_, actual) => actual.GetType() == typeof(DerivedWeatherForecast), isDefault: false); + + await AssertTypeRead(DataSource, sql, PostgresType, value, + comparer: (_, actual) => actual.GetType() == typeof(ExtendedDerivedWeatherForecast), isDefault: false); + } + public JsonTests(MultiplexingMode multiplexingMode, NpgsqlDbType npgsqlDbType) : base(multiplexingMode) { diff --git a/test/Npgsql.Tests/Types/LTreeTests.cs b/test/Npgsql.Tests/Types/LTreeTests.cs index 48f7d950c9..5d104a4c54 100644 --- a/test/Npgsql.Tests/Types/LTreeTests.cs +++ b/test/Npgsql.Tests/Types/LTreeTests.cs @@ -1,10 +1,10 @@ using System.Threading.Tasks; +using Npgsql.Properties; using NpgsqlTypes; using NUnit.Framework; namespace Npgsql.Tests.Types; -[NonParallelizable] public class LTreeTests : MultiplexingTestBase { [Test] @@ -19,6 +19,32 @@ public Task LTree() public Task LTxtQuery() => AssertType("Science & Astronomy", "Science & Astronomy", "ltxtquery", NpgsqlDbType.LTxtQuery, isDefaultForWriting: false); + [Test] + public async Task LTree_not_supported_by_default_on_NpgsqlSlimSourceBuilder() + { + var errorMessage = string.Format( + NpgsqlStrings.LTreeNotEnabled, nameof(NpgsqlSlimDataSourceBuilder.EnableLTree), nameof(NpgsqlSlimDataSourceBuilder)); + + var dataSourceBuilder = new NpgsqlSlimDataSourceBuilder(ConnectionString); + await using var dataSource = dataSourceBuilder.Build(); + + var exception = + await AssertTypeUnsupportedRead>("Top.Science.Astronomy", "ltree", dataSource); + 
Assert.That(exception.InnerException!.Message, Is.EqualTo(errorMessage)); + exception = await AssertTypeUnsupportedWrite("Top.Science.Astronomy", "ltree", dataSource); + Assert.That(exception.InnerException!.Message, Is.EqualTo(errorMessage)); + } + + [Test] + public async Task NpgsqlSlimSourceBuilder_EnableLTree() + { + var dataSourceBuilder = new NpgsqlSlimDataSourceBuilder(ConnectionString); + dataSourceBuilder.EnableLTree(); + await using var dataSource = dataSourceBuilder.Build(); + + await AssertType(dataSource, "Top.Science.Astronomy", "Top.Science.Astronomy", "ltree", NpgsqlDbType.LTree, isDefaultForWriting: false); + } + [OneTimeSetUp] public async Task SetUp() { diff --git a/test/Npgsql.Tests/Types/LegacyDateTimeTests.cs b/test/Npgsql.Tests/Types/LegacyDateTimeTests.cs index fbf1b537d1..2b9ae54813 100644 --- a/test/Npgsql.Tests/Types/LegacyDateTimeTests.cs +++ b/test/Npgsql.Tests/Types/LegacyDateTimeTests.cs @@ -1,7 +1,7 @@ using System; using System.Data; using System.Threading.Tasks; -using Npgsql.TypeMapping; +using Npgsql.Internal.Resolvers; using NpgsqlTypes; using NUnit.Framework; using static Npgsql.Util.Statics; @@ -49,21 +49,22 @@ public Task Timestamptz_local_DateTime_converts() isDefaultForWriting: false); } - protected override async ValueTask OpenConnectionAsync() - { - var conn = await base.OpenConnectionAsync(); - await conn.ExecuteNonQueryAsync("SET TimeZone='Europe/Berlin'"); - return conn; - } - - protected override NpgsqlConnection OpenConnection() - => throw new NotSupportedException(); + NpgsqlDataSource _dataSource = null!; + protected override NpgsqlDataSource DataSource => _dataSource; [OneTimeSetUp] public void Setup() { #if DEBUG LegacyTimestampBehavior = true; + _dataSource = CreateDataSource(builder => + { + // Can't use the static AdoTypeInfoResolver instance, it already captured the feature flag. 
+ builder.AddTypeInfoResolver(new AdoTypeInfoResolver()); + builder.AddTypeInfoResolver(new AdoArrayTypeInfoResolver()); + builder.ConnectionStringBuilder.Timezone = "Europe/Berlin"; + }); + NpgsqlDataSourceBuilder.ResetGlobalMappings(overwrite: true); #else Assert.Ignore( "Legacy DateTime tests rely on the Npgsql.EnableLegacyTimestampBehavior AppContext switch and can only be run in DEBUG builds"); @@ -72,6 +73,10 @@ public void Setup() #if DEBUG [OneTimeTearDown] - public void Teardown() => LegacyTimestampBehavior = false; + public void Teardown() + { + LegacyTimestampBehavior = false; + NpgsqlDataSourceBuilder.ResetGlobalMappings(overwrite: true); + } #endif } diff --git a/test/Npgsql.Tests/Types/MiscTypeTests.cs b/test/Npgsql.Tests/Types/MiscTypeTests.cs index 41555e776e..57d241a811 100644 --- a/test/Npgsql.Tests/Types/MiscTypeTests.cs +++ b/test/Npgsql.Tests/Types/MiscTypeTests.cs @@ -1,10 +1,8 @@ using System; using System.Data; using System.Threading.Tasks; -using Npgsql.Properties; using NpgsqlTypes; using NUnit.Framework; -using NUnit.Framework.Constraints; namespace Npgsql.Tests.Types; @@ -16,8 +14,11 @@ class MiscTypeTests : MultiplexingTestBase [Test] public async Task Boolean() { - await AssertType(true, "true", "boolean", NpgsqlDbType.Boolean, DbType.Boolean); - await AssertType(false, "false", "boolean", NpgsqlDbType.Boolean, DbType.Boolean); + await AssertType(true, "true", "boolean", NpgsqlDbType.Boolean, DbType.Boolean, skipArrayCheck: true); + await AssertType(false, "false", "boolean", NpgsqlDbType.Boolean, DbType.Boolean, skipArrayCheck: true); + + // The literal representations for bools inside array are different ({t,f} instead of true/false, so we check separately. 
+ await AssertType(new[] { true, false }, "{t,f}", "boolean[]", NpgsqlDbType.Boolean | NpgsqlDbType.Array); } [Test] @@ -49,7 +50,7 @@ public async Task Null() { cmd.Parameters.AddWithValue("p1", DBNull.Value); cmd.Parameters.Add(new NpgsqlParameter("p2", null)); - cmd.Parameters.Add(new NpgsqlParameter("p3", DBNull.Value)); + cmd.Parameters.Add(new NpgsqlParameter("p3", DBNull.Value)); await using var reader = await cmd.ExecuteReaderAsync(); reader.Read(); @@ -60,107 +61,21 @@ public async Task Null() } } - // Setting non-generic NpgsqlParameter.Value is not allowed, only DBNull.Value + // Setting non-generic NpgsqlParameter.Value to null is not allowed, only DBNull.Value await using (var cmd = new NpgsqlCommand("SELECT @p::TEXT", conn)) { cmd.Parameters.AddWithValue("p4", NpgsqlDbType.Text, null!); - Assert.That(async () => await cmd.ExecuteReaderAsync(), Throws.Exception.TypeOf()); + Assert.That(async () => await cmd.ExecuteReaderAsync(), Throws.Exception.TypeOf()); } - } - - #region Record - - [Test] - [IssueLink("https://github.com/npgsql/npgsql/issues/724")] - [IssueLink("https://github.com/npgsql/npgsql/issues/1980")] - public async Task Read_Record_as_object_array() - { - var recordLiteral = "(1,'foo'::text)::record"; - await using var conn = await OpenConnectionAsync(); - await using var cmd = new NpgsqlCommand($"SELECT {recordLiteral}, ARRAY[{recordLiteral}, {recordLiteral}]", conn); - await using var reader = await cmd.ExecuteReaderAsync(); - reader.Read(); - - var record = (object[])reader[0]; - Assert.That(record[0], Is.EqualTo(1)); - Assert.That(record[1], Is.EqualTo("foo")); - - var array = (object[][])reader[1]; - Assert.That(array.Length, Is.EqualTo(2)); - Assert.That(array[0][0], Is.EqualTo(1)); - Assert.That(array[1][0], Is.EqualTo(1)); - } - - [Test] - public async Task Read_Record_as_ValueTuple() - { - var recordLiteral = "(1,'foo'::text)::record"; - await using var conn = await OpenConnectionAsync(); - await using var cmd = new 
NpgsqlCommand($"SELECT {recordLiteral}, ARRAY[{recordLiteral}, {recordLiteral}]", conn); - await using var reader = await cmd.ExecuteReaderAsync(); - reader.Read(); - - var record = reader.GetFieldValue<(int, string)>(0); - Assert.That(record.Item1, Is.EqualTo(1)); - Assert.That(record.Item2, Is.EqualTo("foo")); - - var array = (object[][])reader[1]; - Assert.That(array.Length, Is.EqualTo(2)); - Assert.That(array[0][0], Is.EqualTo(1)); - Assert.That(array[1][0], Is.EqualTo(1)); - } - [Test] - public async Task Read_Record_as_Tuple() - { - var recordLiteral = "(1,'foo'::text)::record"; - await using var conn = await OpenConnectionAsync(); - await using var cmd = new NpgsqlCommand($"SELECT {recordLiteral}, ARRAY[{recordLiteral}, {recordLiteral}]", conn); - await using var reader = await cmd.ExecuteReaderAsync(); - reader.Read(); - - var record = reader.GetFieldValue>(0); - Assert.That(record.Item1, Is.EqualTo(1)); - Assert.That(record.Item2, Is.EqualTo("foo")); - - var array = (object[][])reader[1]; - Assert.That(array.Length, Is.EqualTo(2)); - Assert.That(array[0][0], Is.EqualTo(1)); - Assert.That(array[1][0], Is.EqualTo(1)); - } - - [Test] - public Task Write_Record_is_not_supported() - => AssertTypeUnsupportedWrite(new object[] { 1, "foo" }, "record"); - - [Test] - public async Task Records_supported_only_with_EnableRecords([Values] bool withMappings) - { - Func assertExpr = () => withMappings - ? 
Throws.Nothing - : Throws.Exception - .TypeOf() - .With.Property("Message") - .EqualTo(string.Format(NpgsqlStrings.RecordsNotEnabled, "EnableRecords", "NpgsqlSlimDataSourceBuilder")); - - var dataSourceBuilder = new NpgsqlSlimDataSourceBuilder(ConnectionString); - if (withMappings) - dataSourceBuilder.EnableRecords(); - await using var dataSource = dataSourceBuilder.Build(); - await using var conn = await dataSource.OpenConnectionAsync(); - await using var cmd = conn.CreateCommand(); - - // RecordHandler doesn't support writing, so we only check for reading - cmd.CommandText = "SELECT ('one'::text, 2)"; - await using var reader = await cmd.ExecuteReaderAsync(); - await reader.ReadAsync(); - - Assert.That(() => reader.GetValue(0), assertExpr()); - Assert.That(() => reader.GetFieldValue(0), assertExpr()); + // Setting generic NpgsqlParameter.Value to null is not allowed, only DBNull.Value + await using (var cmd = new NpgsqlCommand("SELECT @p::TEXT", conn)) + { + cmd.Parameters.Add(new NpgsqlParameter("p4", NpgsqlDbType.Text) { Value = null! 
}); + Assert.That(async () => await cmd.ExecuteReaderAsync(), Throws.Exception.TypeOf()); + } } - #endregion Record - [Test, Description("Makes sure that setting DbType.Object makes Npgsql infer the type")] [IssueLink("https://github.com/npgsql/npgsql/issues/694")] public async Task DbType_causes_inference() @@ -250,7 +165,7 @@ public Task Oidvector() public async Task Void() { await using var conn = await OpenConnectionAsync(); - Assert.That(await conn.ExecuteScalarAsync("SELECT pg_sleep(0)"), Is.SameAs(DBNull.Value)); + Assert.That(await conn.ExecuteScalarAsync("SELECT pg_sleep(0)"), Is.SameAs(null)); } [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1364")] diff --git a/test/Npgsql.Tests/Types/MoneyTests.cs b/test/Npgsql.Tests/Types/MoneyTests.cs index 8aceb03dac..4c38f3d111 100644 --- a/test/Npgsql.Tests/Types/MoneyTests.cs +++ b/test/Npgsql.Tests/Types/MoneyTests.cs @@ -1,5 +1,4 @@ -using System; -using System.Data; +using System.Data; using System.Threading.Tasks; using NpgsqlTypes; using NUnit.Framework; diff --git a/test/Npgsql.Tests/Types/MultirangeTests.cs b/test/Npgsql.Tests/Types/MultirangeTests.cs index 0162fc78ed..84f815c63c 100644 --- a/test/Npgsql.Tests/Types/MultirangeTests.cs +++ b/test/Npgsql.Tests/Types/MultirangeTests.cs @@ -1,5 +1,7 @@ using System; using System.Collections.Generic; +using System.Data; +using System.Linq; using System.Threading.Tasks; using NpgsqlTypes; using NUnit.Framework; @@ -9,189 +11,130 @@ namespace Npgsql.Tests.Types; public class MultirangeTests : TestBase { - [Test] - public async Task Read() - { - await using var conn = await OpenConnectionAsync(); - await using var cmd = new NpgsqlCommand("SELECT '{[3,7), (8,]}'::int4multirange", conn); - await using var reader = await cmd.ExecuteReaderAsync(); - await reader.ReadAsync(); - - Assert.That(reader.GetDataTypeName(0), Is.EqualTo("int4multirange")); - - var multirangeArray = (NpgsqlRange[])reader[0]; - Assert.That(multirangeArray.Length, Is.EqualTo(2)); - 
Assert.That(multirangeArray[0], Is.EqualTo(new NpgsqlRange(3, true, false, 7, false, false))); - Assert.That(multirangeArray[1], Is.EqualTo(new NpgsqlRange(9, true, false, 0, false, true))); - - var multirangeList = reader.GetFieldValue>>(0); - Assert.That(multirangeList.Count, Is.EqualTo(2)); - Assert.That(multirangeList[0], Is.EqualTo(new NpgsqlRange(3, true, false, 7, false, false))); - Assert.That(multirangeList[1], Is.EqualTo(new NpgsqlRange(9, true, false, 0, false, true))); - } - - [Test] - public async Task Write() - { - var multirangeArray = new NpgsqlRange[] - { - new(3, true, false, 7, false, false), - new(8, false, false, 0, false, true) - }; - - var multirangeList = new List>(multirangeArray); - - await using var conn = await OpenConnectionAsync(); - await using var cmd = new NpgsqlCommand("SELECT $1::text", conn); - - await WriteInternal(multirangeArray); - await WriteInternal(multirangeList); - - async Task WriteInternal(IList> multirange) - { - await conn.ReloadTypesAsync(); - cmd.Parameters.Add(new() { Value = multirange }); - Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo("{[3,7),[9,)}")); - - await conn.ReloadTypesAsync(); - cmd.Parameters[0] = new() { Value = multirange, NpgsqlDbType = NpgsqlDbType.IntegerMultirange }; - Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo("{[3,7),[9,)}")); - - await conn.ReloadTypesAsync(); - cmd.Parameters[0] = new() { Value = multirange, DataTypeName = "int4multirange" }; - Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo("{[3,7),[9,)}")); - } - } - - [Test] - public async Task Write_nummultirange() - { - var multirangeArray = new NpgsqlRange[] - { - new(3, true, false, 7, false, false), - new(8, false, false, 0, false, true) - }; - - var multirangeList = new List>(multirangeArray); - - await using var conn = await OpenConnectionAsync(); - await using var cmd = new NpgsqlCommand("SELECT $1::text", conn); - - await WriteInternal(multirangeArray); - await WriteInternal(multirangeList); - - async 
Task WriteInternal(IList> multirange) - { - conn.ReloadTypes(); - cmd.Parameters.Add(new() { Value = multirange }); - Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo("{[3,7),(8,)}")); - - conn.ReloadTypes(); - cmd.Parameters[0] = new() { Value = multirange, NpgsqlDbType = NpgsqlDbType.NumericMultirange }; - Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo("{[3,7),(8,)}")); - - conn.ReloadTypes(); - cmd.Parameters[0] = new() { Value = multirange, DataTypeName = "nummultirange" }; - Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo("{[3,7),(8,)}")); - } - } - - [Test] - public async Task Read_Datemultirange() + static readonly TestCaseData[] MultirangeTestCases = { - await using var conn = await OpenConnectionAsync(); - await using var cmd = new NpgsqlCommand("SELECT '{[2020-01-01,2020-01-05), (2020-01-10,]}'::datemultirange", conn); - await using var reader = await cmd.ExecuteReaderAsync(); - await reader.ReadAsync(); - - Assert.That(reader.GetDataTypeName(0), Is.EqualTo("datemultirange")); - - var multirangeDateTimeArray = (NpgsqlRange[])reader[0]; - Assert.That(multirangeDateTimeArray.Length, Is.EqualTo(2)); - Assert.That(multirangeDateTimeArray[0], Is.EqualTo(new NpgsqlRange(new(2020, 1, 1), true, false, new(2020, 1, 5), false, false))); - Assert.That(multirangeDateTimeArray[1], Is.EqualTo(new NpgsqlRange(new(2020, 1, 11), true, false, default, false, true))); - - var multirangeDateTimeList = reader.GetFieldValue>>(0); - Assert.That(multirangeDateTimeList.Count, Is.EqualTo(2)); - Assert.That(multirangeDateTimeList[0], Is.EqualTo(new NpgsqlRange(new(2020, 1, 1), true, false, new(2020, 1, 5), false, false))); - Assert.That(multirangeDateTimeList[1], Is.EqualTo(new NpgsqlRange(new(2020, 1, 11), true, false, default, false, true))); + // int4multirange + new TestCaseData( + new NpgsqlRange[] + { + new(3, true, false, 7, false, false), + new(9, true, false, 0, false, true) + }, + "{[3,7),[9,)}", "int4multirange", NpgsqlDbType.IntegerMultirange, true, 
true, default(NpgsqlRange)) + .SetName("Int"), + + // int8multirange + new TestCaseData( + new NpgsqlRange[] + { + new(3, true, false, 7, false, false), + new(9, true, false, 0, false, true) + }, + "{[3,7),[9,)}", "int8multirange", NpgsqlDbType.BigIntMultirange, true, true, default(NpgsqlRange)) + .SetName("Long"), + + // nummultirange + // numeric is non-discrete so doesn't undergo normalization, use that to test bound scenarios which otherwise get normalized + new TestCaseData( + new NpgsqlRange[] + { + new(3, true, false, 7, true, false), + new(9, false, false, 0, false, true) + }, + "{[3,7],(9,)}", "nummultirange", NpgsqlDbType.NumericMultirange, true, true, default(NpgsqlRange)) + .SetName("Decimal"), + + // daterange + new TestCaseData( + new NpgsqlRange[] + { + new(new(2020, 1, 1), true, false, new(2020, 1, 5), false, false), + new(new(2020, 1, 10), true, false, default, false, true) + }, + "{[2020-01-01,2020-01-05),[2020-01-10,)}", "datemultirange", NpgsqlDbType.DateMultirange, true, false, default(NpgsqlRange)) + .SetName("DateTime DateMultirange"), + + // tsmultirange + new TestCaseData( + new NpgsqlRange[] + { + new(new(2020, 1, 1), true, false, new(2020, 1, 5), false, false), + new(new(2020, 1, 10), true, false, default, false, true) + }, + """{["2020-01-01 00:00:00","2020-01-05 00:00:00"),["2020-01-10 00:00:00",)}""", "tsmultirange", NpgsqlDbType.TimestampMultirange, true, true, default(NpgsqlRange)) + .SetName("DateTime TimestampMultirange"), + + // tstzmultirange + new TestCaseData( + new NpgsqlRange[] + { + new(new(2020, 1, 1, 0, 0, 0, kind: DateTimeKind.Utc), true, false, new(2020, 1, 5, 0, 0, 0, kind: DateTimeKind.Utc), false, false), + new(new(2020, 1, 10, 0, 0, 0, kind: DateTimeKind.Utc), true, false, default, false, true) + }, + """{["2020-01-01 01:00:00+01","2020-01-05 01:00:00+01"),["2020-01-10 01:00:00+01",)}""", "tstzmultirange", NpgsqlDbType.TimestampTzMultirange, true, true, default(NpgsqlRange)) + .SetName("DateTime 
TimestampTzMultirange"), #if NET6_0_OR_GREATER - var multirangeDateOnlyArray = reader.GetFieldValue[]>(0); - Assert.That(multirangeDateOnlyArray.Length, Is.EqualTo(2)); - Assert.That(multirangeDateOnlyArray[0], Is.EqualTo(new NpgsqlRange(new(2020, 1, 1), true, false, new(2020, 1, 5), false, false))); - Assert.That(multirangeDateOnlyArray[1], Is.EqualTo(new NpgsqlRange(new(2020, 1, 11), true, false, default, false, true))); - - var multirangeDateOnlyList = reader.GetFieldValue>>(0); - Assert.That(multirangeDateOnlyList.Count, Is.EqualTo(2)); - Assert.That(multirangeDateOnlyList[0], Is.EqualTo(new NpgsqlRange(new(2020, 1, 1), true, false, new(2020, 1, 5), false, false))); - Assert.That(multirangeDateOnlyList[1], Is.EqualTo(new NpgsqlRange(new(2020, 1, 11), true, false, default, false, true))); + new TestCaseData( + new NpgsqlRange[] + { + new(new(2020, 1, 1), true, false, new(2020, 1, 5), false, false), + new(new(2020, 1, 10), true, false, default, false, true) + }, + "{[2020-01-01,2020-01-05),[2020-01-10,)}", "datemultirange", NpgsqlDbType.DateMultirange, false, false, default(NpgsqlRange)) + .SetName("DateOnly"), #endif - } + }; + + [Test, TestCaseSource(nameof(MultirangeTestCases))] + public Task Multirange_as_array( + T multirangeAsArray, string sqlLiteral, string pgTypeName, NpgsqlDbType? npgsqlDbType, bool isDefaultForReading, bool isDefaultForWriting, TRange _) + => AssertType(multirangeAsArray, sqlLiteral, pgTypeName, npgsqlDbType, isDefaultForReading: isDefaultForReading, + isDefaultForWriting: isDefaultForWriting); + + [Test, TestCaseSource(nameof(MultirangeTestCases))] + public Task Multirange_as_list( + T multirangeAsArray, string sqlLiteral, string pgTypeName, NpgsqlDbType? 
npgsqlDbType, bool isDefaultForReading, bool isDefaultForWriting, TRange _) + where T : IList + => AssertType( + new List(multirangeAsArray), + sqlLiteral, pgTypeName, npgsqlDbType, isDefaultForReading: false, isDefaultForWriting: isDefaultForWriting); -#if NET6_0_OR_GREATER [Test] - public async Task Write_Datemultirange_DateOnly() + [NonParallelizable] + public async Task Unmapped_multirange_with_mapped_subtype() { - var multirangeArray = new NpgsqlRange[] - { - new(new(2020, 1, 1), true, false, new(2020, 1, 5), false, false), - new(new(2020, 1, 10), false, false, default, false, true) - }; - - var multirangeList = new List>(multirangeArray); - - await using var conn = await OpenConnectionAsync(); - await using var cmd = new NpgsqlCommand("SELECT $1::text", conn); - - await WriteInternal(multirangeArray); - await WriteInternal(multirangeList); + await using var dataSource = CreateDataSource(csb => csb.MaxPoolSize = 1); + await using var conn = await dataSource.OpenConnectionAsync(); + + var typeName = await GetTempTypeName(conn); + await conn.ExecuteNonQueryAsync($"CREATE TYPE {typeName} AS RANGE(subtype=text)"); + await Task.Yield(); // TODO: fix multiplexing deadlock bug + conn.ReloadTypes(); + Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + + var value = new[] {new NpgsqlRange( + new string('a', conn.Settings.WriteBufferSize + 10).ToCharArray(), + new string('z', conn.Settings.WriteBufferSize + 10).ToCharArray() + )}; + + await using var cmd = new NpgsqlCommand("SELECT @p", conn); + cmd.Parameters.Add(new NpgsqlParameter { DataTypeName = typeName + "_multirange", ParameterName = "p", Value = value }); + await using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SequentialAccess); + await reader.ReadAsync(); - async Task WriteInternal(IList> multirange) - { - conn.ReloadTypes(); - cmd.Parameters.Add(new() { Value = multirange }); - Assert.That(await cmd.ExecuteScalarAsync(), 
Is.EqualTo("{[2020-01-01,2020-01-05),[2020-01-11,)}")); - - conn.ReloadTypes(); - cmd.Parameters[0] = new() { Value = multirange, NpgsqlDbType = NpgsqlDbType.DateMultirange }; - Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo("{[2020-01-01,2020-01-05),[2020-01-11,)}")); - - conn.ReloadTypes(); - cmd.Parameters[0] = new() { Value = multirange, DataTypeName = "datemultirange" }; - Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo("{[2020-01-01,2020-01-05),[2020-01-11,)}")); - } + Assert.That(reader.GetFieldType(0), Is.EqualTo(typeof(NpgsqlRange[]))); + var result = reader.GetFieldValue[]>(0); + Assert.That(result, Is.EqualTo(value).Using[]>((actual, expected) => + actual[0].LowerBound!.SequenceEqual(expected[0].LowerBound!) && actual[0].UpperBound!.SequenceEqual(expected[0].UpperBound!))); } -#endif - - [Test] - public async Task Write_Datemultirange_DateTime() - { - var multirangeArray = new NpgsqlRange[] - { - new(new(2020, 1, 1), true, false, new(2020, 1, 5), false, false), - new(new(2020, 1, 10), false, false, default, false, true) - }; - var multirangeList = new List>(multirangeArray); - - await using var conn = await OpenConnectionAsync(); - await using var cmd = new NpgsqlCommand("SELECT $1::text", conn); + protected override NpgsqlDataSource DataSource { get; } - await WriteInternal(multirangeArray); - await WriteInternal(multirangeList); - - async Task WriteInternal(IList> multirange) + public MultirangeTests() => DataSource = CreateDataSource(builder => { - conn.ReloadTypes(); - cmd.Parameters.Add(new() { Value = multirange, NpgsqlDbType = NpgsqlDbType.DateMultirange }); - Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo("{[2020-01-01,2020-01-05),[2020-01-11,)}")); - - conn.ReloadTypes(); - cmd.Parameters[0] = new() { Value = multirange, DataTypeName = "datemultirange" }; - Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo("{[2020-01-01,2020-01-05),[2020-01-11,)}")); - } - } + builder.ConnectionStringBuilder.Timezone = 
"Europe/Berlin"; + }); [OneTimeSetUp] public async Task Setup() @@ -199,7 +142,4 @@ public async Task Setup() await using var conn = await OpenConnectionAsync(); MinimumPgVersion(conn, "14.0", "Multirange types were introduced in PostgreSQL 14"); } - - protected override NpgsqlConnection OpenConnection() - => throw new NotSupportedException(); } diff --git a/test/Npgsql.Tests/Types/NetworkTypeTests.cs b/test/Npgsql.Tests/Types/NetworkTypeTests.cs index 63d456b9cd..994fdd45e4 100644 --- a/test/Npgsql.Tests/Types/NetworkTypeTests.cs +++ b/test/Npgsql.Tests/Types/NetworkTypeTests.cs @@ -17,46 +17,45 @@ class NetworkTypeTests : MultiplexingTestBase { [Test] public Task Inet_v4_as_IPAddress() - => AssertType(IPAddress.Parse("192.168.1.1"), "192.168.1.1/32", "inet", NpgsqlDbType.Inet); + => AssertType(IPAddress.Parse("192.168.1.1"), "192.168.1.1/32", "inet", NpgsqlDbType.Inet, skipArrayCheck: true); [Test] - public Task Inet_v6_as_IPAddress() + public Task Inet_v4_array_as_IPAddress_array() => AssertType( - IPAddress.Parse("2001:1db8:85a3:1142:1000:8a2e:1370:7334"), - "2001:1db8:85a3:1142:1000:8a2e:1370:7334/128", - "inet", - NpgsqlDbType.Inet); + new[] + { + IPAddress.Parse("192.168.1.1"), + IPAddress.Parse("192.168.1.2") + }, + "{192.168.1.1,192.168.1.2}", "inet[]", NpgsqlDbType.Inet | NpgsqlDbType.Array); [Test] - public Task Inet_v4_as_tuple() - => AssertType((IPAddress.Parse("192.168.1.1"), 24), "192.168.1.1/24", "inet", NpgsqlDbType.Inet, isDefaultForReading: false); - - [Test] - public Task Inet_v6_as_tuple() + public Task Inet_v6_as_IPAddress() => AssertType( - (IPAddress.Parse("2001:1db8:85a3:1142:1000:8a2e:1370:7334"), 24), - "2001:1db8:85a3:1142:1000:8a2e:1370:7334/24", + IPAddress.Parse("2001:1db8:85a3:1142:1000:8a2e:1370:7334"), + "2001:1db8:85a3:1142:1000:8a2e:1370:7334/128", "inet", NpgsqlDbType.Inet, - isDefaultForReading: false); + skipArrayCheck: true); [Test] - public Task Inet_v6_array_as_tuple() + public Task Inet_v6_array_as_IPAddress_array() => 
AssertType( - new[] { (IPAddress.Parse("2001:1db8:85a3:1142:1000:8a2e:1370:7334"), 24) }, - "{2001:1db8:85a3:1142:1000:8a2e:1370:7334/24}", - "inet[]", - NpgsqlDbType.Inet | NpgsqlDbType.Array, - isDefaultForReading: false); + new[] + { + IPAddress.Parse("2001:1db8:85a3:1142:1000:8a2e:1370:7334"), + IPAddress.Parse("2001:1db8:85a3:1142:1000:8a2e:1370:7335") + }, + "{2001:1db8:85a3:1142:1000:8a2e:1370:7334,2001:1db8:85a3:1142:1000:8a2e:1370:7335}", "inet[]", NpgsqlDbType.Inet | NpgsqlDbType.Array); [Test, IssueLink("https://github.com/dotnet/corefx/issues/33373")] public Task IPAddress_Any() - => AssertTypeWrite(IPAddress.Any, "0.0.0.0/32", "inet", NpgsqlDbType.Inet); + => AssertTypeWrite(IPAddress.Any, "0.0.0.0/32", "inet", NpgsqlDbType.Inet, skipArrayCheck: true); [Test] public Task Cidr() => AssertType( - (Address: IPAddress.Parse("192.168.1.0"), Subnet: 24), + new NpgsqlCidr(IPAddress.Parse("192.168.1.0"), netmask: 24), "192.168.1.0/24", "cidr", NpgsqlDbType.Cidr, @@ -129,10 +128,7 @@ public async Task Macaddr_write_validation() if (conn.PostgreSqlVersion < new Version(10, 0)) Assert.Ignore("macaddr8 only supported on PostgreSQL 10 and above"); - var exception = await AssertTypeUnsupportedWrite( - PhysicalAddress.Parse("08-00-2B-01-02-03-04-05"), "macaddr"); - - Assert.That(exception.Message, Does.StartWith("22P03:").And.Contain("1")); + await AssertTypeUnsupportedWrite(PhysicalAddress.Parse("08-00-2B-01-02-03-04-05"), "macaddr"); } public NetworkTypeTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} diff --git a/test/Npgsql.Tests/Types/NumericTypeTests.cs b/test/Npgsql.Tests/Types/NumericTypeTests.cs index 9c5c13c027..78dc2f7fa7 100644 --- a/test/Npgsql.Tests/Types/NumericTypeTests.cs +++ b/test/Npgsql.Tests/Types/NumericTypeTests.cs @@ -1,11 +1,9 @@ using System; -using System.Collections.Generic; using System.Data; using System.Globalization; using System.Threading.Tasks; using NpgsqlTypes; using NUnit.Framework; -using 
NUnit.Framework.Internal; using static Npgsql.Tests.TestUtil; namespace Npgsql.Tests.Types; @@ -110,4 +108,4 @@ public Task Read_overflow(T _, double value, string pgTypeName) => AssertTypeUnsupportedRead(value.ToString(CultureInfo.InvariantCulture), pgTypeName); public NumericTypeTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} -} \ No newline at end of file +} diff --git a/test/Npgsql.Tests/Types/RangeTests.cs b/test/Npgsql.Tests/Types/RangeTests.cs index 33a1365ec0..db57f0d78d 100644 --- a/test/Npgsql.Tests/Types/RangeTests.cs +++ b/test/Npgsql.Tests/Types/RangeTests.cs @@ -2,122 +2,69 @@ using System.ComponentModel; using System.Data; using System.Globalization; +using System.Linq; using System.Threading.Tasks; +using Npgsql.Properties; using Npgsql.Util; using NpgsqlTypes; using NUnit.Framework; - using static Npgsql.Tests.TestUtil; namespace Npgsql.Tests.Types; -/// -/// https://www.postgresql.org/docs/current/static/rangetypes.html -/// class RangeTests : MultiplexingTestBase { - [Test, NUnit.Framework.Description("Resolves a range type handler via the different pathways")] - public async Task Range_resolution() + static readonly TestCaseData[] RangeTestCases = { - if (IsMultiplexing) - Assert.Ignore("Multiplexing, ReloadTypes"); - - await using var dataSource = CreateDataSource(csb => csb.Pooling = false); - await using var conn = await OpenConnectionAsync(); - - // Resolve type by NpgsqlDbType - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - cmd.Parameters.AddWithValue("p", NpgsqlDbType.Range | NpgsqlDbType.Integer, DBNull.Value); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader.GetDataTypeName(0), Is.EqualTo("int4range")); - } - } - - // Resolve type by ClrType (type inference) - conn.ReloadTypes(); - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - cmd.Parameters.Add(new NpgsqlParameter { ParameterName = "p", Value = new NpgsqlRange(3, 5) }); - using (var reader 
= await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader.GetDataTypeName(0), Is.EqualTo("int4range")); - } - } - - // Resolve type by DataTypeName - conn.ReloadTypes(); - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - cmd.Parameters.Add(new NpgsqlParameter { ParameterName="p", DataTypeName = "int4range", Value = DBNull.Value }); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader.GetDataTypeName(0), Is.EqualTo("int4range")); - } - } - - // Resolve type by OID (read) - conn.ReloadTypes(); - using (var cmd = new NpgsqlCommand("SELECT int4range(3, 5)", conn)) - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader.GetDataTypeName(0), Is.EqualTo("int4range")); - Assert.That(reader.GetFieldValue>(0), Is.EqualTo(new NpgsqlRange(3, true, 5, false))); - } - } - - [Test] - public async Task Range() - { - using var conn = await OpenConnectionAsync(); - using var cmd = new NpgsqlCommand("SELECT @p1, @p2, @p3, @p4", conn); - var p1 = new NpgsqlParameter("p1", NpgsqlDbType.Range | NpgsqlDbType.Integer) { Value = NpgsqlRange.Empty }; - var p2 = new NpgsqlParameter { ParameterName = "p2", Value = new NpgsqlRange(1, 10) }; - var p3 = new NpgsqlParameter { ParameterName = "p3", Value = new NpgsqlRange(1, false, 10, false) }; - var p4 = new NpgsqlParameter { ParameterName = "p4", Value = new NpgsqlRange(0, false, true, 10, false, false) }; - Assert.That(p2.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Range | NpgsqlDbType.Integer)); - cmd.Parameters.Add(p1); - cmd.Parameters.Add(p2); - cmd.Parameters.Add(p3); - cmd.Parameters.Add(p4); - using var reader = await cmd.ExecuteReaderAsync(); - reader.Read(); - - Assert.That(reader[0].ToString(), Is.EqualTo("empty")); - Assert.That(reader[1].ToString(), Is.EqualTo("[1,11)")); - Assert.That(reader[2].ToString(), Is.EqualTo("[2,10)")); - Assert.That(reader[3].ToString(), Is.EqualTo("(,10)")); - } - - [Test] - [NonParallelizable] - 
public async Task Range_with_long_subtype() - { - await using var dataSource = CreateDataSource(csb => csb.MaxPoolSize = 1); - await using var conn = await dataSource.OpenConnectionAsync(); - - var typeName = await GetTempTypeName(conn); - await conn.ExecuteNonQueryAsync($"CREATE TYPE {typeName} AS RANGE(subtype=text)"); - await Task.Yield(); // TODO: fix multiplexing deadlock bug - conn.ReloadTypes(); - Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); - - var value = new NpgsqlRange( - new string('a', conn.Settings.WriteBufferSize + 10), - new string('z', conn.Settings.WriteBufferSize + 10) - ); - - await using var cmd = new NpgsqlCommand("SELECT @p", conn); - cmd.Parameters.Add(new NpgsqlParameter("p", NpgsqlDbType.Range | NpgsqlDbType.Text) { Value = value }); - await using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SequentialAccess); - await reader.ReadAsync(); - Assert.That(reader[0], Is.EqualTo(value)); - } + new TestCaseData(new NpgsqlRange(1, true, 10, false), "[1,10)", "int4range", NpgsqlDbType.IntegerRange) + .SetName("IntegerRange"), + new TestCaseData(new NpgsqlRange(1, true, 10, false), "[1,10)", "int8range", NpgsqlDbType.BigIntRange) + .SetName("BigIntRange"), + new TestCaseData(new NpgsqlRange(1, true, 10, false), "[1,10)", "numrange", NpgsqlDbType.NumericRange) + .SetName("NumericRange"), + new TestCaseData(new NpgsqlRange( + new DateTime(2020, 1, 1, 12, 0, 0), true, + new DateTime(2020, 1, 3, 13, 0, 0), false), + """["2020-01-01 12:00:00","2020-01-03 13:00:00")""", "tsrange", NpgsqlDbType.TimestampRange) + .SetName("TimestampRange"), + // Note that the below text representations are local (according to TimeZone, which is set to Europe/Berlin in this test class), + // because that's how PG does timestamptz *text* representation. 
+ new TestCaseData(new NpgsqlRange( + new DateTime(2020, 1, 1, 12, 0, 0, DateTimeKind.Utc), true, + new DateTime(2020, 1, 3, 13, 0, 0, DateTimeKind.Utc), false), + """["2020-01-01 13:00:00+01","2020-01-03 14:00:00+01")""", "tstzrange", NpgsqlDbType.TimestampTzRange) + .SetName("TimestampTzRange"), + + // Note that numrange is a non-discrete range, and therefore doesn't undergo normalization to inclusive/exclusive in PG + new TestCaseData(NpgsqlRange.Empty, "empty", "numrange", NpgsqlDbType.NumericRange) + .SetName("EmptyRange"), + new TestCaseData(new NpgsqlRange(1, true, 10, true), "[1,10]", "numrange", NpgsqlDbType.NumericRange) + .SetName("Inclusive"), + new TestCaseData(new NpgsqlRange(1, false, 10, false), "(1,10)", "numrange", NpgsqlDbType.NumericRange) + .SetName("Exclusive"), + new TestCaseData(new NpgsqlRange(1, true, 10, false), "[1,10)", "numrange", NpgsqlDbType.NumericRange) + .SetName("InclusiveExclusive"), + new TestCaseData(new NpgsqlRange(1, false, 10, true), "(1,10]", "numrange", NpgsqlDbType.NumericRange) + .SetName("ExclusiveInclusive"), + new TestCaseData(new NpgsqlRange(1, false, true, 10, false, false), "(,10)", "numrange", NpgsqlDbType.NumericRange) + .SetName("InfiniteLowerBound"), + new TestCaseData(new NpgsqlRange(1, true, false, 10, false, true), "[1,)", "numrange", NpgsqlDbType.NumericRange) + .SetName("InfiniteUpperBound") + }; + + // See more test cases in DateTimeTests + [Test, TestCaseSource(nameof(RangeTestCases))] + public Task Range(T range, string sqlLiteral, string pgTypeName, NpgsqlDbType? npgsqlDbType) + => AssertType(range, sqlLiteral, pgTypeName, npgsqlDbType, + // NpgsqlRange[] is mapped to multirange by default, not array, so the built-in AssertType testing for arrays fails + // (see below) + skipArrayCheck: true); + + // This re-executes the same scenario as above, but with isDefaultForWriting: false and without skipArrayCheck: true. + // This tests coverage of range arrays (as opposed to multiranges). 
+ [Test, TestCaseSource(nameof(RangeTestCases))] + public Task Range_array(T range, string sqlLiteral, string pgTypeName, NpgsqlDbType? npgsqlDbType) + => AssertType(range, sqlLiteral, pgTypeName, npgsqlDbType, isDefaultForWriting: false); [Test] public void Equality_finite() @@ -217,6 +164,35 @@ public async Task TimestampTz_range_with_DateTimeOffset() Assert.That(actual, Is.EqualTo(range)); } + [Test] + [NonParallelizable] + public async Task Unmapped_range_with_mapped_subtype() + { + await using var dataSource = CreateDataSource(csb => csb.MaxPoolSize = 1); + await using var conn = await dataSource.OpenConnectionAsync(); + + var typeName = await GetTempTypeName(conn); + await conn.ExecuteNonQueryAsync($"CREATE TYPE {typeName} AS RANGE(subtype=text)"); + await Task.Yield(); // TODO: fix multiplexing deadlock bug + conn.ReloadTypes(); + Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + + var value = new NpgsqlRange( + new string('a', conn.Settings.WriteBufferSize + 10).ToCharArray(), + new string('z', conn.Settings.WriteBufferSize + 10).ToCharArray() + ); + + await using var cmd = new NpgsqlCommand("SELECT @p", conn); + cmd.Parameters.Add(new NpgsqlParameter { DataTypeName = typeName, ParameterName = "p", Value = value }); + await using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SequentialAccess); + await reader.ReadAsync(); + + Assert.That(reader.GetFieldType(0), Is.EqualTo(typeof(NpgsqlRange))); + var result = reader.GetFieldValue>(0); + Assert.That(result, Is.EqualTo(value).Using>((actual, expected) => + actual.LowerBound!.SequenceEqual(expected.LowerBound!) 
&& actual.UpperBound!.SequenceEqual(expected.UpperBound!))); + } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/4441")] public async Task Array_of_range() { @@ -240,13 +216,41 @@ await AssertType( new(3, lowerBoundIsInclusive: true, 4, upperBoundIsInclusive: false), new(5, lowerBoundIsInclusive: true, 6, upperBoundIsInclusive: false) }, - @"{""[3,4)"",""[5,6)""}", + """{"[3,4)","[5,6)"}""", "int4range[]", NpgsqlDbType.IntegerRange | NpgsqlDbType.Array, isDefaultForWriting: !supportsMultirange, isNpgsqlDbTypeInferredFromClrType: false); } + [Test] + public async Task Ranges_not_supported_by_default_on_NpgsqlSlimSourceBuilder() + { + var errorMessage = string.Format( + NpgsqlStrings.RangesNotEnabled, nameof(NpgsqlSlimDataSourceBuilder.EnableRanges), nameof(NpgsqlSlimDataSourceBuilder)); + + var dataSourceBuilder = new NpgsqlSlimDataSourceBuilder(ConnectionString); + await using var dataSource = dataSourceBuilder.Build(); + + var exception = await AssertTypeUnsupportedRead>("[1,10)", "int4range", dataSource); + Assert.That(exception.InnerException!.Message, Is.EqualTo(errorMessage)); + exception = await AssertTypeUnsupportedWrite>( + new NpgsqlRange(1, true, 10, false), "int4range", dataSource); + Assert.That(exception.InnerException!.Message, Is.EqualTo(errorMessage)); + } + + [Test] + public async Task NpgsqlSlimSourceBuilder_EnableRanges() + { + var dataSourceBuilder = new NpgsqlSlimDataSourceBuilder(ConnectionString); + dataSourceBuilder.EnableRanges(); + await using var dataSource = dataSourceBuilder.Build(); + + await AssertType( + dataSource, + new NpgsqlRange(1, true, 10, false), "[1,10)", "int4range", NpgsqlDbType.IntegerRange, skipArrayCheck: true); + } + protected override NpgsqlConnection OpenConnection() => throw new NotSupportedException(); @@ -434,5 +438,11 @@ public override object ConvertFrom(ITypeDescriptorContext? context, CultureInfo? 
#endregion - public RangeTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} + protected override NpgsqlDataSource DataSource { get; } + + public RangeTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) + => DataSource = CreateDataSource(builder => + { + builder.ConnectionStringBuilder.Timezone = "Europe/Berlin"; + }); } diff --git a/test/Npgsql.Tests/Types/RecordTests.cs b/test/Npgsql.Tests/Types/RecordTests.cs new file mode 100644 index 0000000000..54a56baa4a --- /dev/null +++ b/test/Npgsql.Tests/Types/RecordTests.cs @@ -0,0 +1,109 @@ +using System; +using System.Threading.Tasks; +using Npgsql.Properties; +using NUnit.Framework; +using NUnit.Framework.Constraints; + +namespace Npgsql.Tests.Types; + +public class RecordTests : MultiplexingTestBase +{ + [Test] + [IssueLink("https://github.com/npgsql/npgsql/issues/724")] + [IssueLink("https://github.com/npgsql/npgsql/issues/1980")] + public async Task Read_Record_as_object_array() + { + var recordLiteral = "(1,'foo'::text)::record"; + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand($"SELECT {recordLiteral}, ARRAY[{recordLiteral}, {recordLiteral}]", conn); + await using var reader = await cmd.ExecuteReaderAsync(); + reader.Read(); + + var record = (object[])reader[0]; + Assert.That(record[0], Is.EqualTo(1)); + Assert.That(record[1], Is.EqualTo("foo")); + + var array = (object[][])reader[1]; + Assert.That(array.Length, Is.EqualTo(2)); + Assert.That(array[0][0], Is.EqualTo(1)); + Assert.That(array[1][0], Is.EqualTo(1)); + } + + [Test] + public async Task Read_Record_as_ValueTuple() + { + var recordLiteral = "(1,'foo'::text)::record"; + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand($"SELECT {recordLiteral}, ARRAY[{recordLiteral}, {recordLiteral}]", conn); + await using var reader = await cmd.ExecuteReaderAsync(); + reader.Read(); + + var record = reader.GetFieldValue<(int, string)>(0); + 
Assert.That(record.Item1, Is.EqualTo(1)); + Assert.That(record.Item2, Is.EqualTo("foo")); + + var array = (object[][])reader[1]; + Assert.That(array.Length, Is.EqualTo(2)); + Assert.That(array[0][0], Is.EqualTo(1)); + Assert.That(array[1][0], Is.EqualTo(1)); + } + + [Test] + public async Task Read_Record_as_Tuple() + { + var recordLiteral = "(1,'foo'::text)::record"; + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand($"SELECT {recordLiteral}, ARRAY[{recordLiteral}, {recordLiteral}]", conn); + await using var reader = await cmd.ExecuteReaderAsync(); + reader.Read(); + + var record = reader.GetFieldValue>(0); + Assert.That(record.Item1, Is.EqualTo(1)); + Assert.That(record.Item2, Is.EqualTo("foo")); + + var array = (object[][])reader[1]; + Assert.That(array.Length, Is.EqualTo(2)); + Assert.That(array[0][0], Is.EqualTo(1)); + Assert.That(array[1][0], Is.EqualTo(1)); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1238")] + public async Task Record_with_non_int_field() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT ('one'::TEXT, 2)", conn); + await using var reader = await cmd.ExecuteReaderAsync(); + reader.Read(); + var record = reader.GetFieldValue(0); + Assert.That(record[0], Is.EqualTo("one")); + Assert.That(record[1], Is.EqualTo(2)); + } + + [Test] + public async Task Records_supported_only_with_EnableRecords([Values] bool withMappings) + { + Func assertExpr = () => withMappings + ? 
Throws.Nothing + : Throws.Exception + .TypeOf() + .With.Property("InnerException").Property("Message") + .EqualTo(string.Format(NpgsqlStrings.RecordsNotEnabled, "EnableRecords", "NpgsqlSlimDataSourceBuilder")); + + var dataSourceBuilder = new NpgsqlSlimDataSourceBuilder(ConnectionString); + if (withMappings) + dataSourceBuilder.EnableRecords(); + await using var dataSource = dataSourceBuilder.Build(); + await using var conn = await dataSource.OpenConnectionAsync(); + await using var cmd = conn.CreateCommand(); + + // RecordHandler doesn't support writing, so we only check for reading + cmd.CommandText = "SELECT ('one'::text, 2)"; + await using var reader = await cmd.ExecuteReaderAsync(); + await reader.ReadAsync(); + + Assert.That(() => reader.GetValue(0), assertExpr()); + Assert.That(() => reader.GetFieldValue(0), assertExpr()); + } + + public RecordTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} +} diff --git a/test/Npgsql.Tests/Types/TextTests.cs b/test/Npgsql.Tests/Types/TextTests.cs index aa2e7d69a3..c4583151b4 100644 --- a/test/Npgsql.Tests/Types/TextTests.cs +++ b/test/Npgsql.Tests/Types/TextTests.cs @@ -79,6 +79,7 @@ public async Task Truncate() Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(data2.Substring(0, 4))); // NpgsqlParameter.Size larger than the value size should mean the value size, as well as 0 and -1 + p.Value = data2; p.Size = data2.Length + 10; Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(data2)); p.Size = 0; diff --git a/test/Npgsql.Tests/TypesTests.cs b/test/Npgsql.Tests/TypesTests.cs index 5dbfa844f3..de2b1beed0 100644 --- a/test/Npgsql.Tests/TypesTests.cs +++ b/test/Npgsql.Tests/TypesTests.cs @@ -1,8 +1,5 @@ using System; -using System.Diagnostics; -using System.Globalization; using System.Net; -using Npgsql.Util; using NpgsqlTypes; using NUnit.Framework; @@ -208,10 +205,6 @@ public void NpgsqlInet() { var v = new NpgsqlInet(IPAddress.Parse("2001:1db8:85a3:1142:1000:8a2e:1370:7334"), 32); 
Assert.That(v.ToString(), Is.EqualTo("2001:1db8:85a3:1142:1000:8a2e:1370:7334/32")); - -#pragma warning disable CS8625 - Assert.That(v != null); // #776 -#pragma warning disable CS8625 } #pragma warning restore 618 } diff --git a/test/Npgsql.Tests/WriteBufferTests.cs b/test/Npgsql.Tests/WriteBufferTests.cs index 19603b1741..5bd6cdf5a1 100644 --- a/test/Npgsql.Tests/WriteBufferTests.cs +++ b/test/Npgsql.Tests/WriteBufferTests.cs @@ -1,6 +1,5 @@ using System.IO; using Npgsql.Internal; -using Npgsql.Util; using NUnit.Framework; namespace Npgsql.Tests; @@ -8,6 +7,16 @@ namespace Npgsql.Tests; [FixtureLifeCycle(LifeCycle.InstancePerTestCase)] // Parallel access to a single buffer class WriteBufferTests { + [Test] + public void GetWriter_Full_Buffer() + { + WriteBuffer.WritePosition += WriteBuffer.WriteSpaceLeft; + var writer = WriteBuffer.GetWriter(null!, FlushMode.Blocking); + Assert.That(writer.ShouldFlush(sizeof(byte)), Is.True); + writer.Flush(); + Assert.That(writer.ShouldFlush(sizeof(byte)), Is.False); + } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1275")] public void Write_zero_characters() { @@ -88,7 +97,7 @@ public void Chunked_char_array_encoding_fits_with_surrogates() public void SetUp() { Underlying = new MemoryStream(); - WriteBuffer = new NpgsqlWriteBuffer(null, Underlying, null, NpgsqlReadBuffer.DefaultSize, PGUtil.UTF8Encoding); + WriteBuffer = new NpgsqlWriteBuffer(null, Underlying, null, NpgsqlReadBuffer.DefaultSize, NpgsqlWriteBuffer.UTF8Encoding); } #pragma warning restore CS8625