8000 Fix handling of multi-message TSIG responses by ibauersachs · Pull Request #300 · dnsjava/dnsjava · GitHub
[go: up one dir, main page]

Skip to content

Fix handling of multi-message TSIG responses #300

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Nov 4, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Verify that the last message is signed as required
  • Loading branch information
ibauersachs committed Nov 2, 2023
commit a9ed360ca759116fcea8940575cc1a5abc049ac0
39 changes: 31 additions & 8 deletions src/main/java/org/xbill/DNS/TSIG.java
Original file line number Diff line number Diff line change
Expand Up @@ -839,36 +839,59 @@ public StreamVerifier(TSIG tsig, TSIGRecord queryTsig) {
* TSIG records must be present on the first and last messages, and at least every 100 records
* in between. After calling this routine, Message.isVerified() may be called on this message.
*
* @param m The message
* @param b The message in unparsed form
* <p>This overload assumes that the verified message is not the last one, which is required to
* have a {@link TSIGRecord}. Use {@link #verify(Message, byte[], boolean)} to explicitly
* specify the last message.
*
* @param message The message
* @param messageBytes The message in unparsed form
* @return The result of the verification (as an Rcode)
* @see Rcode
*/
public int verify(Message message, byte[] messageBytes) {
return verify(message, messageBytes, false);
}

/**
* Verifies a TSIG record on an incoming message that is part of a multiple message response.
* TSIG records must be present on the first and last messages, and at least every 100 records
* in between. After calling this routine, Message.isVerified() may be called on this message.
*
* @param message The message
* @param messageBytes The message in unparsed form
* @param isLastMessage If true, verifies that the {@link Message} has an {@link TSIGRecord}.
* @return The result of the verification (as an Rcode)
* @see Rcode
*/
public int verify(Message m, byte[] b) {
TSIGRecord tsig = m.getTSIG();
public int verify(Message message, byte[] messageBytes, boolean isLastMessage) {
TSIGRecord tsig = message.getTSIG();

nresponses++;
if (nresponses == 1) {
int result = key.verify(m, b, queryTsig, true, sharedHmac);
int result = key.verify(message, messageBytes, queryTsig, true, sharedHmac);
hmacAddSignature(sharedHmac, tsig);
lastsigned = nresponses;
return result;
}

if (tsig != null) {
int result = key.verify(m, b, null, false, sharedHmac);
int result = key.verify(message, messageBytes, null, false, sharedHmac);
lastsigned = nresponses;
hmacAddSignature(sharedHmac, tsig);
return result;
} else {
boolean required = nresponses - lastsigned >= 100;
if (required) {
log.debug("FORMERR: missing required signature on {}th message", nresponses);
m.tsigState = Message.TSIG_FAILED;
message.tsigState = Message.TSIG_FAILED;
return Rcode.FORMERR;
} else if (isLastMessage) {
log.debug("FORMERR: missing required signature on last message");
message.tsigState = Message.TSIG_FAILED;
return Rcode.FORMERR;
} else {
log.trace("Intermediate message {} without signature", nresponses);
addUnsignedMessageToMac(m, b, sharedHmac);
addUnsignedMessageToMac(message, messageBytes, sharedHmac);
return Rcode.NOERROR;
}
}
Expand Down
40 changes: 33 additions & 7 deletions src/test/java/org/xbill/DNS/TSIGTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ void testTSIGStreamVerifierMissingMinimumTsig() throws Exception {

byte[] query = client.createQuery();
int numResponses = 200;
List<Message> response = server.handleQuery(query, numResponses, 200);
List<Message> response = server.handleQuery(query, numResponses, 200, false);
Map<Integer, Integer> expectedRcodes = new HashMap<>();
for (int i = 0; i < numResponses; i++) {
expectedRcodes.put(i, i < 100 ? Rcode.NOERROR : Rcode.FORMERR);
Expand All @@ -356,14 +356,35 @@ void testTSIGStreamVerifier(int numResponses, int signEvery) throws Exception {
MockMessageServer server = new MockMessageServer(defaultKey);

byte[] query = client.createQuery();
List<Message> response = server.handleQuery(query, numResponses, signEvery);
List<Message> response = server.handleQuery(query, numResponses, signEvery, false);
Map<Integer, Integer> expectedRcodes = new HashMap<>();
for (int i = 0; i < numResponses; i++) {
expectedRcodes.put(i, Rcode.NOERROR);
}
client.validateResponse(query, response, expectedRcodes);
}

@ParameterizedTest(name = "testTSIGStreamVerifierLastMessage(numResponses: {0}, signEvery: {1})")
@CsvSource({
"53,6",
"105,7",
"1000,100",
})
void testTSIGStreamVerifierLastMessage(int numResponses, int signEvery) throws Exception {
MockMessageClient client = new MockMessageClient(defaultKey);
MockMessageServer server = new MockMessageServer(defaultKey);

byte[] query = client.createQuery();
List<Message> response = server.handleQuery(query, numResponses, signEvery, true);
Map<Integer, Integer> expectedRcodes = new HashMap<>();
for (int i = 0; i < numResponses; i++) {
expectedRcodes.put(i, Rcode.NOERROR);
}

expectedRcodes.put(numResponses - 1, Rcode.FORMERR);
client.validateResponse(query, response, expectedRcodes);
}

private static class MockMessageClient {
private final TSIG key;

Expand All @@ -387,9 +408,13 @@ public void validateResponse(

Map<Integer, Integer> actualRcodes = new HashMap<>();
for (int i = 0; i < responses.size(); i++) {
Message response = responses.get(i);
byte[] renderedMessage = response.toWire(Message.MAXLENGTH);
actualRcodes.put(i, verifier.verify(new Message(renderedMessage), renderedMessage));
boolean isLastMessage = i == responses.size() - 1;
byte[] renderedMessage = responses.get(i).toWire(Message.MAXLENGTH);
Message messageFromWire = new Message(renderedMessage);
actualRcodes.put(i, verifier.verify(messageFromWire, renderedMessage, isLastMessage));
if (isLastMessage) {
assertFalse(messageFromWire.isVerified());
}
}

assertEquals(expectedRcodes, actualRcodes);
Expand All @@ -403,7 +428,8 @@ private static class MockMessageServer {
this.key = key;
}

List<Message> handleQuery(byte[] queryMessageBytes, int responseMessageCount, int signEvery)
List<Message> handleQuery(
byte[] queryMessageBytes, int responseMessageCount, int signEvery, boolean skipLast)
throws Exception {
Message parsedQueryMessage = new Message(queryMessageBytes);
assertNotNull(parsedQueryMessage.getTSIG());
Expand All @@ -423,7 +449,7 @@ List<Message> handleQuery(byte[] queryMessageBytes, int responseMessageCount, in
InetAddress.getByAddress(ByteBuffer.allocate(4).putInt(i).array()));
response.addRecord(answer, Section.ANSWER);

generator.generate(response, i == responseMessageCount - 1);
generator.generate(response, !skipLast && i == responseMessageCount - 1);
responseMessageList.add(response);
}

Expand Down
0