Skip to content

Commit

Permalink
More work on zone mutation logic.
Browse files Browse the repository at this point in the history
  • Loading branch information
alanedwardes committed Feb 13, 2024
1 parent 86e45b8 commit 386690b
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 12 deletions.
15 changes: 12 additions & 3 deletions src/Ae.Dns.Client/DnsZoneUpdateClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,15 @@ public async Task<DnsMessage> Query(DnsMessage query, CancellationToken token =
var hostnames = query.Nameservers.Select(x => x.Host.ToString()).ToArray();
var addresses = query.Nameservers.Select(x => x.Resource).OfType<DnsIpAddressResource>().Select(x => x.IPAddress).ToArray();

void ChangeRecords(ICollection<DnsResourceRecord> records)
DnsResponseCode ChangeRecords(ICollection<DnsResourceRecord> records)
{
var preRequisitesResponseCode = _dnsZone.TestZoneUpdatePreRequisites(query);
if (preRequisitesResponseCode != DnsResponseCode.NoError)
{
_logger.LogWarning("Pre-requisites check resulted in {ResponseCode} for {Update}", preRequisitesResponseCode, query);
return preRequisitesResponseCode;
}

foreach (var recordToRemove in records.Where(x => hostnames.Contains(x.Host.ToString())).ToArray())
{
records.Remove(recordToRemove);
Expand All @@ -63,12 +70,14 @@ void ChangeRecords(ICollection<DnsResourceRecord> records)
{
records.Add(nameserver);
}

return DnsResponseCode.NoError;
};

if (query.Nameservers.Count > 0 && hostnames.All(x => !Regex.IsMatch(x, @"\s")) && hostnames.All(x => x.ToString().EndsWith(_dnsZone.Origin)))
{
await _dnsZone.Update(ChangeRecords);
return query.CreateAnswerMessage(DnsResponseCode.NoError, ToString());
var responseCode = await _dnsZone.Update(ChangeRecords);
return query.CreateAnswerMessage(responseCode, ToString());
}

return query.CreateAnswerMessage(DnsResponseCode.Refused, ToString());
Expand Down
5 changes: 3 additions & 2 deletions src/Ae.Dns.Protocol/Zone/DnsZone.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,14 @@ public DnsZone(IEnumerable<DnsResourceRecord> records)
public override string ToString() => Origin;

/// <inheritdoc/>
public async Task Update(Action<IList<DnsResourceRecord>> modification)
public async Task<TResult> Update<TResult>(Func<IList<DnsResourceRecord>, TResult> modification)
{
await _semaphore.WaitAsync();
try
{
modification(_records);
var result = modification(_records);
await ZoneUpdated(this);
return result;
}
finally
{
Expand Down
60 changes: 59 additions & 1 deletion src/Ae.Dns.Protocol/Zone/DnsZoneExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
using Ae.Dns.Protocol.Enums;
using Ae.Dns.Protocol.Records;
using System.Collections;
using System.Collections.Generic;
using System.Linq;

namespace Ae.Dns.Protocol.Zone
Expand Down Expand Up @@ -73,7 +76,7 @@ public static DnsResponseCode TestZoneUpdatePreRequisites(this IDnsZone zone, Dn
}
}
}
else if (rr.Class == zone.Records.First().Class)
else if (rr.Class == zone.GetZoneClass())
{
if (!zone.Records.Any(x => x.Host == rr.Host && x.Type == rr.Type && Equals(x.Resource, rr.Resource)))
{
Expand All @@ -88,5 +91,60 @@ public static DnsResponseCode TestZoneUpdatePreRequisites(this IDnsZone zone, Dn

return DnsResponseCode.NoError;
}

/// <summary>
/// Perform record updates for the specified <see cref="IDnsZone"/>.
/// </summary>
/// <param name="zone"></param>
/// <param name="records"></param>
/// <param name="updateMessage"></param>
public static void PerformZoneUpdates(this IDnsZone zone, ICollection<DnsResourceRecord> records, DnsMessage updateMessage)
{
foreach (var rr in updateMessage.Nameservers)
{
if (rr.Class == zone.GetZoneClass())
{
var existingRecords = zone.Records.Where(x => x.Host == rr.Host && x.Type == rr.Type).ToArray();
if (existingRecords.Any())
{
foreach (var existingRecord in existingRecords)
{
existingRecord.Resource = rr.Resource;
existingRecord.TimeToLive = rr.TimeToLive;
}
}
else
{
records.Add(rr);
}
}

if (rr.Class == DnsQueryClass.QCLASS_ANY)
{

}

if (rr.Class == DnsQueryClass.QCLASS_NONE)
{

}
}
}

/// <summary>
/// Get the class of the <see cref="IDnsZone"/>.
/// </summary>
/// <param name="zone"></param>
/// <returns></returns>
public static DnsQueryClass GetZoneClass(this IDnsZone zone)
{
// TODO: force SOA in IDnsZone and use that...
if (zone.Records.Any())
{
return zone.Records.First().Class;
}

return DnsQueryClass.IN;
}
}
}
2 changes: 1 addition & 1 deletion src/Ae.Dns.Protocol/Zone/IDnsZone.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public interface IDnsZone
/// </summary>
/// <param name="modification"></param>
/// <returns></returns>
Task Update(Action<IList<DnsResourceRecord>> modification);
Task<TResult> Update<TResult>(Func<IList<DnsResourceRecord>, TResult> modification);

/// <summary>
/// Serialize the zone to a string.
Expand Down
5 changes: 2 additions & 3 deletions tests/Ae.Dns.Tests/Client/DnsUpdateClientTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,9 @@ public sealed class TestDnsZone : IDnsZone
public string FromFormattedHost(string host) => throw new NotImplementedException();
public string SerializeZone() => throw new NotImplementedException();
public string ToFormattedHost(string host) => throw new NotImplementedException();
public Task Update(Action<IList<DnsResourceRecord>> modification)
public Task<TResult> Update<TResult>(Func<IList<DnsResourceRecord>, TResult> modification)
{
modification(_records);
return Task.CompletedTask;
return Task.FromResult(modification(_records));
}
}

Expand Down
4 changes: 2 additions & 2 deletions tests/Ae.Dns.Tests/Client/Lookup/DnsZoneLookupTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public string ToFormattedHost(string host)
throw new NotImplementedException();
}

public Task Update(Action<IList<DnsResourceRecord>> modification)
public Task<TResult> Update<TResult>(Func<IList<DnsResourceRecord>, TResult> modification)
{
throw new NotImplementedException();
}
Expand Down Expand Up @@ -125,7 +125,7 @@ public string ToFormattedHost(string host)
throw new NotImplementedException();
}

public Task Update(Action<IList<DnsResourceRecord>> modification)
public Task<TResult> Update<TResult>(Func<IList<DnsResourceRecord>, TResult> modification)
{
throw new NotImplementedException();
}
Expand Down
2 changes: 2 additions & 0 deletions tests/Ae.Dns.Tests/Zone/DnsZoneTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,8 @@ await originalZone.Update(recordsToUpdate =>
{
recordsToUpdate.Add(record);
}
return true;
});

var serialized = originalZone.SerializeZone();
Expand Down

0 comments on commit 386690b

Please sign in to comment.