8000 Verify server ID before KILL QUERY to prevent cross-server cancellation by Copilot · Pull Request #1575 · mysql-net/MySqlConnector · GitHub
[go: up one dir, main page]

Skip to content

Verify server ID before KILL QUERY to prevent cross-server cancellation #1575

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 7 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions src/MySqlConnector/Core/ServerSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ public ServerSession(ILogger logger, IConnectionPoolMetadata pool)
public int ActiveCommandId { get; private set; }
public int CancellationTimeout { get; private set; }
public int ConnectionId { get; set; }
public string? ServerHostname { get; set; }
public byte[]? AuthPluginData { get; set; }
public long CreatedTimestamp { get; }
public ConnectionPool? Pool { get; }
Expand Down Expand Up @@ -117,6 +118,24 @@ public void DoCancel(ICancellableCommand commandToCancel, MySqlCommand killComma
return;
}

// Verify server identity before executing KILL QUERY to prevent cancelling on the wrong server
var killSession = killCommand.Connection!.Session;
if (!string.IsNullOrEmpty(ServerHostname) && !string.IsNullOrEmpty(killSession.ServerHostname))
{
if (!string.Equals(ServerHostname, killSession.ServerHostname, StringComparison.Ordinal))
{
Log.IgnoringCancellationForDifferentServer(m_logger, Id, killSession.Id, ServerHostname, killSession.ServerHostname);
return;
}
}
else if (!string.IsNullOrEmpty(ServerHostname) || !string.IsNullOrEmpty(killSession.ServerHostname))
{
// One session has hostname, the other doesn't - this is a potential mismatch
Log.IgnoringCancellationForDifferentServer(m_logger, Id, killSession.Id, ServerHostname, killSession.ServerHostname);
return;
}
// If both sessions have no hostname, allow the operation for backward compatibility

// NOTE: This command is executed while holding the lock to prevent race conditions during asynchronous cancellation.
// For example, if the lock weren't held, the current command could finish and the other thread could set ActiveCommandId
// to zero, then start executing a new command. By the time this "KILL QUERY" command reached the server, the wrong
Expand Down Expand Up @@ -635,6 +654,9 @@ public async Task DisposeAsync(IOBehavior ioBehavior, CancellationToken cancella
ConnectionId = newConnectionId;
}

// Get server hostname for KILL QUERY verification
await GetServerHostnameAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false);

m_payloadHandler.ByteHandler.RemainingTimeout = Constants.InfiniteTimeout;
return redirectionUrl;
}
Expand Down Expand Up @@ -1951,6 +1973,52 @@ private async Task GetRealServerDetailsAsync(IOBehavior ioBehavior, Cancellation
}
}

private async Task GetServerHostnameAsync(IOBehavior ioBehavior, CancellationToken cancellationToken)
{
Log.GettingServerHostname(m_logger, Id);
try
{
var payload = SupportsQueryAttributes ? s_selectHostnameWithAttributesPayload : s_selectHostnameNoAttributesPayload;
await SendAsync(payload, ioBehavior, cancellationToken).ConfigureAwait(false);

// column count: 1
_ = await ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false);

// @@hostname column
_ = await ReceiveReplyAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false);

if (!SupportsDeprecateEof)
{
payload = await ReceiveReplyAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false);
_ = EofPayload.Create(payload.Span);
}

// first (and only) row
payload = await ReceiveReplyAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false);

var reader = new ByteArrayReader(payload.Span);
var length = reader.ReadLengthEncodedIntegerOrNull();
var hostname = length > 0 ? Encoding.UTF8.GetString(reader.ReadByteString(length)) : null;

ServerHostname = hostname;

Log.RetrievedServerHostname(m_logger, Id, hostname);

// OK/EOF payload
payload = await ReceiveReplyAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false);
if (OkPayload.IsOk(payload.Span, this))
OkPayload.Verify(payload.Span, this);
else
EofPayload.Create(payload.Span);
}
catch (MySqlException ex)
{
Log.FailedToGetServerHostname(m_logger, ex, Id);
// Set fallback value to ensure operation can continue
ServerHostname = null;
}
}

private void ShutdownSocket()
{
Log.ClosingStreamSocket(m_logger, Id);
Expand Down Expand Up @@ -2182,6 +2250,8 @@ protected override void OnStatementBegin(int index)
private static readonly PayloadData s_sleepWithAttributesPayload = QueryPayload.Create(true, "SELECT SLEEP(0) INTO @__MySqlConnector__Sleep;"u8);
private static readonly PayloadData s_selectConnectionIdVersionNoAttributesPayload = QueryPayload.Create(false, "SELECT CONNECTION_ID(), VERSION();"u8);
private static readonly PayloadData s_selectConnectionIdVersionWithAttributesPayload = QueryPayload.Create(true, "SELECT CONNECTION_ID(), VERSION();"u8);
private static readonly PayloadData s_selectHostnameNoAttributesPayload = QueryPayload.Create(false, "SELECT @@hostname;"u8);
private static readonly PayloadData s_selectHostnameWithAttributesPayload = QueryPayload.Create(true, "SELECT @@hostname;"u8);

private readonly ILogger m_logger;
#if NET9_0_OR_GREATER
Expand Down
4 changes: 4 additions & 0 deletions src/MySqlConnector/Logging/EventIds.cs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ internal static class EventIds
public const int CertificateErrorUnixSocket = 2158;
public const int CertificateErrorNoPassword = 2159;
public const int CertificateErrorValidThumbprint = 2160;
public const int GettingServerHostname = 2161;
public const int RetrievedServerHostname = 2162;
public const int FailedToGetServerHostname = 2163;

// Command execution events, 2200-2299
public const int CannotExecuteNewCommandInState = 2200;
Expand All @@ -108,6 +111,7 @@ internal static class EventIds
public const int IgnoringCancellationForInactiveCommand = 2306;
public const int CancelingCommand = 2307;
public const int SendingSleepToClearPendingCancellation = 2308;
public const int IgnoringCancellationForDifferentServer = 2309;

// Cached procedure events, 2400-2499
public const int GettingCachedProcedure = 2400;
Expand Down
12 changes: 12 additions & 0 deletions src/MySqlConnector/Logging/Log.cs
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,18 @@ internal static partial class Log
[LoggerMessage(EventIds.FailedToGetConnectionId, LogLevel.Information, "Session {SessionId} failed to get CONNECTION_ID(), VERSION()")]
public static partial void FailedToGetConnectionId(ILogger logger, Exception exception, string sessionId);

[LoggerMessage(EventIds.GettingServerHostname, LogLevel.Debug, "Session {SessionId} getting server hostname")]
public static partial void GettingServerHostname(ILogger logger, string sessionId);

[LoggerMessage(EventIds.RetrievedServerHostname, LogLevel.Debug, "Session {SessionId} retrieved server hostname: {ServerHostname}")]
public static partial void RetrievedServerHostname(ILogger logger, string sessionId, string? serverHostname);

[LoggerMessage(EventIds.FailedToGetServerHostname, LogLevel.Information, "Session {SessionId} failed to get server hostname")]
public static partial void FailedToGetServerHostname(ILogger logger, Exception exception, string sessionId);

[LoggerMessage(EventIds.IgnoringCancellationForDifferentServer, LogLevel.Warning, "Session {SessionId} ignoring cancellation from session {KillSessionId}: server hostname mismatch (this hostname={ServerHostname}, kill hostname={KillServerHostname})")]
public static partial void IgnoringCancellationForDifferentServer(ILogger logger, string sessionId, string killSessionId, string? serverHostname, string? killServerHostname);

[LoggerMessage(EventIds.ClosingStreamSocket, LogLevel.Debug, "Session {SessionId} closing stream/socket")]
public static partial void ClosingStreamSocket(ILogger logger, string sessionId);

Expand Down
53 changes: 53 additions & 0 deletions tests/IntegrationTests/ServerIdentificationTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
using System.Diagnostics;

namespace IntegrationTests;

public class ServerIdentificationTests : IClassFixture<DatabaseFixture>, IDisposable
{
public ServerIdentificationTests(DatabaseFixture database)
{
m_database = database;
}

public void Dispose()
{
}

[SkippableFact(ServerFeatures.Timeout)]
public void CancelCommand_WithServerVerification()
{
// This test verifies that cancellation still works with server verification
using var connection = new MySqlConnection(AppConfig.ConnectionString);
connection.Open();

Check failure on line 22 in tests/IntegrationTests/ServerIdentificationTests.cs

View check run for this annotation

Azure Pipelines / mysql-net.MySqlConnector (Linux Build)

tests/IntegrationTests/ServerIdentificationTests.cs#L22

tests/IntegrationTests/ServerIdentificationTests.cs(22,1): Error SA1028: Code should not contain trailing whitespace (https://github.com/DotNetAnalyzers/StyleCopAnalyzers/blob/master/documentation/SA1028.md)

Check failure on line 22 in tests/IntegrationTests/ServerIdentificationTests.cs

View check run for this annotation

Azure Pipelines / mysql-net.MySqlConnector

tests/IntegrationTests/ServerIdentificationTests.cs#L22

tests/IntegrationTests/ServerIdentificationTests.cs(22,1): Error SA1028: Code should not contain trailing whitespace (https://github.com/DotNetAnalyzers/StyleCopAnalyzers/blob/master/documentation/SA1028.md)
using var cmd = new MySqlCommand("SELECT SLEEP(5)", connection);
var task = Task.Run(async () =>
{
await Task.Delay(TimeSpan.FromSeconds(0.5));
cmd.Cancel();
});

var stopwatch = Stopwatch.StartNew();
TestUtilities.AssertExecuteScalarReturnsOneOrIsCanceled(cmd);
Assert.InRange(stopwatch.ElapsedMilliseconds, 250, 2500);

#pragma warning disable xUnit1031 // Do not use blocking task operations in test method
task.Wait(); // shouldn't throw
#pragma warning restore xUnit1031 // Do not use blocking task operations in test method
}

[SkippableFact(ServerFeatures.KnownCertificateAuthority)]

Check failure on line 39 in tests/IntegrationTests/ServerIdentificationTests.cs

View check run for this annotation

Azure Pipelines / mysql-net.MySqlConnector (Linux Build)

tests/IntegrationTests/ServerIdentificationTests.cs#L39

tests/IntegrationTests/ServerIdentificationTests.cs(39,59): Error SA1028: Code should not contain trailing whitespace (https://github.com/DotNetAnalyzers/StyleCopAnalyzers/blob/master/documentation/SA1028.md)

Check failure on line 39 in tests/IntegrationTests/ServerIdentificationTests.cs

View check run for this annotation

Azure Pipelines / mysql-net.MySqlConnector

tests/IntegrationTests/ServerIdentificationTests.cs#L39

tests/IntegrationTests/ServerIdentificationTests.cs(39,59): Error SA1028: Code should not contain trailing whitespace (https://github.com/DotNetAnalyzers/StyleCopAnalyzers/blob/master/documentation/SA1028.md)
public void ServerHasServerHostname()
{
using var connection = new MySqlConnection(AppConfig.ConnectionString);
connection.Open();

// Test that we can query server hostname
using var cmd = new MySqlCommand("SELECT @@hostname", connection);
var hostname = cmd.ExecuteScalar();

Check failure on line 48 in tests/IntegrationTests/ServerIdentificationTests.cs

View check run for this annotation

Azure Pipelines / mysql-net.MySqlConnector (Linux Build)

tests/IntegrationTests/ServerIdentificationTests.cs#L48

tests/IntegrationTests/ServerIdentificationTests.cs(48,1): Error SA1028: Code should not contain trailing whitespace (https://github.com/DotNetAnalyzers/StyleCopAnalyzers/blob/master/documentation/SA1028.md)

Check failure on line 48 in tests/IntegrationTests/ServerIdentificationTests.cs

View check run for this annotation

Azure Pipelines / mysql-net.MySqlConnector

tests/IntegrationTests/ServerIdentificationTests.cs#L48

tests/IntegrationTests/ServerIdentificationTests.cs(48,1): Error SA1028: Code should not contain trailing whitespace (https://github.com/DotNetAnalyzers/StyleCopAnalyzers/blob/master/documentation/SA1028.md)
// Hostname might be null on some server configurations, but the query should succeed
}

private readonly DatabaseFixture m_database;
}

Check failure on line 53 in tests/IntegrationTests/ServerIdentificationTests.cs

View check run for this annotation

Azure Pipelines / mysql-net.MySqlConnector (Linux Build)

tests/IntegrationTests/ServerIdentificationTests.cs#L53

tests/IntegrationTests/ServerIdentificationTests.cs(53,2): Error SA1518: File is required to end with a single newline character (https://github.com/DotNetAnalyzers/StyleCopAnalyzers/blob/master/documentation/SA1518.md)

Check failure on line 53 in tests/IntegrationTests/ServerIdentificationTests.cs

View check run for this annotation

Azure Pipelines / mysql-net.MySqlConnector (Linux Build)

tests/IntegrationTests/ServerIdentificationTests.cs#L53

tests/IntegrationTests/ServerIdentificationTests.cs(53,2): Error SA1518: File is required to end with a single newline character (https://github.com/DotNetAnalyzers/StyleCopAnalyzers/blob/master/documentation/SA1518.md)

Check failure on line 53 in tests/IntegrationTests/ServerIdentificationTests.cs

View check run for this annotation

Azure Pipelines / mysql-net.MySqlConnector

tests/IntegrationTests/ServerIdentificationTests.cs#L53

tests/IntegrationTests/ServerIdentificationTests.cs(53,2): Error SA1518: File is required to end with a single newline character (https://github.com/DotNetAnalyzers/StyleCopAnalyzers/blob/master/documentation/SA1518.md)
Loading
0