-
Notifications
You must be signed in to change notification settings - Fork 3
chore: add network extension manager #18
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import NetworkExtension | ||
import os | ||
|
||
// swiftlint:disable:next function_body_length | ||
public func convertNetworkSettingsRequest(_ req: Vpn_NetworkSettingsRequest) -> NEPacketTunnelNetworkSettings { | ||
let networkSettings = NEPacketTunnelNetworkSettings(tunnelRemoteAddress: req.tunnelRemoteAddress) | ||
networkSettings.tunnelOverheadBytes = NSNumber(value: req.tunnelOverheadBytes) | ||
networkSettings.mtu = NSNumber(value: req.mtu) | ||
|
||
if req.hasDnsSettings { | ||
let dnsSettings = NEDNSSettings(servers: req.dnsSettings.servers) | ||
dnsSettings.searchDomains = req.dnsSettings.searchDomains | ||
dnsSettings.domainName = req.dnsSettings.domainName | ||
dnsSettings.matchDomains = req.dnsSettings.matchDomains | ||
dnsSettings.matchDomainsNoSearch = req.dnsSettings.matchDomainsNoSearch | ||
networkSettings.dnsSettings = dnsSettings | ||
} | ||
|
||
if req.hasIpv4Settings { | ||
let ipv4Settings = NEIPv4Settings(addresses: req.ipv4Settings.addrs, subnetMasks: req.ipv4Settings.subnetMasks) | ||
ipv4Settings.router = req.ipv4Settings.router | ||
ipv4Settings.includedRoutes = req.ipv4Settings.includedRoutes.map { | ||
let route = NEIPv4Route(destinationAddress: $0.destination, subnetMask: $0.mask) | ||
route.gatewayAddress = $0.router | ||
return route | ||
} | ||
ipv4Settings.excludedRoutes = req.ipv4Settings.excludedRoutes.map { | ||
let route = NEIPv4Route(destinationAddress: $0.destination, subnetMask: $0.mask) | ||
route.gatewayAddress = $0.router | ||
return route | ||
} | ||
networkSettings.ipv4Settings = ipv4Settings | ||
} | ||
|
||
if req.hasIpv6Settings { | ||
let ipv6Settings = NEIPv6Settings( | ||
addresses: req.ipv6Settings.addrs, | ||
networkPrefixLengths: req.ipv6Settings.prefixLengths.map { NSNumber(value: $0) | ||
} | ||
) | ||
ipv6Settings.includedRoutes = req.ipv6Settings.includedRoutes.map { | ||
let route = NEIPv6Route( | ||
destinationAddress: $0.destination, | ||
networkPrefixLength: NSNumber(value: $0.prefixLength) | ||
) | ||
route.gatewayAddress = $0.router | ||
return route | ||
} | ||
ipv6Settings.excludedRoutes = req.ipv6Settings.excludedRoutes.map { | ||
let route = NEIPv6Route( | ||
destinationAddress: $0.destination, | ||
networkPrefixLength: NSNumber(value: $0.prefixLength) | ||
) | ||
route.gatewayAddress = $0.router | ||
return route | ||
} | ||
networkSettings.ipv6Settings = ipv6Settings | ||
} | ||
return networkSettings | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,27 +22,27 @@ enum ProtoRole: String { | |
} | ||
|
||
/// A version of the VPN protocol that can be negotiated. | ||
struct ProtoVersion: CustomStringConvertible, Equatable, Codable { | ||
public struct ProtoVersion: CustomStringConvertible, Equatable, Codable, Sendable { | ||
let major: Int | ||
let minor: Int | ||
|
||
var description: String { "\(major).\(minor)" } | ||
public var description: String { "\(major).\(minor)" } | ||
|
||
init(_ major: Int, _ minor: Int) { | ||
self.major = major | ||
self.minor = minor | ||
} | ||
|
||
init(parse str: String) throws { | ||
init(parse str: String) throws(HandshakeError) { | ||
let parts = str.split(separator: ".").map { Int($0) } | ||
if parts.count != 2 { | ||
throw HandshakeError.invalidVersion(str) | ||
throw .invalidVersion(str) | ||
} | ||
guard let major = parts[0] else { | ||
throw HandshakeError.invalidVersion(str) | ||
throw .invalidVersion(str) | ||
} | ||
guard let minor = parts[1] else { | ||
throw HandshakeError.invalidVersion(str) | ||
throw .invalidVersion(str) | ||
} | ||
self.major = major | ||
self.minor = minor | ||
|
@@ -87,14 +87,14 @@ public actor Speaker<SendMsg: RPCMessage & Message, RecvMsg: RPCMessage & Messag | |
} | ||
|
||
/// Does the VPN Protocol handshake and validates the result | ||
func handshake() async throws { | ||
public func handshake() async throws(HandshakeError) { | ||
let hndsh = Handshaker(writeFD: writeFD, dispatch: dispatch, queue: queue, role: role) | ||
// ignore the version for now because we know it can only be 1.0 | ||
try _ = await hndsh.handshake() | ||
} | ||
|
||
/// Send a unary RPC message and handle the response | ||
func unaryRPC(_ req: SendMsg) async throws -> RecvMsg { | ||
public func unaryRPC(_ req: SendMsg) async throws -> RecvMsg { | ||
return try await withCheckedThrowingContinuation { continuation in | ||
Task { [sender, secretary, logger] in | ||
let msgID = await secretary.record(continuation: continuation) | ||
|
@@ -114,15 +114,15 @@ public actor Speaker<SendMsg: RPCMessage & Message, RecvMsg: RPCMessage & Messag | |
} | ||
} | ||
|
||
func closeWrite() { | ||
public func closeWrite() { | ||
do { | ||
try writeFD.close() | ||
} catch { | ||
logger.error("failed to close write file handle: \(error)") | ||
} | ||
} | ||
|
||
func closeRead() { | ||
public func closeRead() { | ||
do { | ||
try readFD.close() | ||
} catch { | ||
|
@@ -153,8 +153,8 @@ extension Speaker: AsyncSequence, AsyncIteratorProtocol { | |
} | ||
guard msg.rpc.responseTo == 0 else { | ||
logger.debug("got RPC reply for msgID \(msg.rpc.responseTo)") | ||
do throws(RPCError) { | ||
try await self.secretary.route(reply: msg) | ||
do { | ||
try await secretary.route(reply: msg) | ||
} catch { | ||
logger.error( | ||
"couldn't route RPC reply for \(msg.rpc.responseTo): \(error)") | ||
|
@@ -188,7 +188,7 @@ actor Handshaker { | |
} | ||
|
||
/// Performs the initial VPN protocol handshake, returning the negotiated `ProtoVersion` that we should use. | ||
func handshake() async throws -> ProtoVersion { | ||
func handshake() async throws(HandshakeError) -> ProtoVersion { | ||
// kick off the read async before we try to write, synchronously, so we don't deadlock, both | ||
// waiting to write with nobody reading. | ||
let readTask = Task { | ||
|
@@ -201,9 +201,22 @@ actor Handshaker { | |
|
||
let vStr = versions.map { $0.description }.joined(separator: ",") | ||
let ours = String(format: "\(headerPreamble) \(role) \(vStr)\n") | ||
try writeFD.write(contentsOf: ours.data(using: .utf8)!) | ||
do { | ||
try writeFD.write(contentsOf: ours.data(using: .utf8)!) | ||
} catch { | ||
throw HandshakeError.writeError(error) | ||
} | ||
|
||
do { | ||
theirData = try await readTask.value | ||
} catch let error as HandshakeError { | ||
throw error | ||
} catch { | ||
// This can't be checked at compile-time, as both Tasks & Continuations can only ever throw | ||
// a type-erased `Error` | ||
Comment on lines
+215
to
+216
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a little frustrating - the As an aside, I think I'm becoming more of a fan, for binaries at least, of just having a single opaque error type, propagating it with context, and then just displaying it, ala Go or Rust's |
||
fatalError("handleRead must always throw HandshakeError") | ||
} | ||
|
||
let theirData = try await readTask.value | ||
guard let theirsString = String(bytes: theirData, encoding: .utf8) else { | ||
throw HandshakeError.invalidHeader("<unparsable: \(theirData)") | ||
} | ||
|
@@ -216,6 +229,7 @@ actor Handshaker { | |
} | ||
} | ||
|
||
// resumes must only ever throw HandshakeError | ||
private func handleRead(_: Bool, _ data: DispatchData?, _ error: Int32) { | ||
guard error == 0 else { | ||
let errStrPtr = strerror(error) | ||
|
@@ -235,7 +249,7 @@ actor Handshaker { | |
dispatch.read(offset: 0, length: 1, queue: queue, ioHandler: handleRead) | ||
} | ||
|
||
private func validateHeader(_ header: String) throws -> ProtoVersion { | ||
private func validateHeader(_ header: String) throws(HandshakeError) -> ProtoVersion { | ||
let parts = header.split(separator: " ") | ||
guard parts.count == 3 else { | ||
throw HandshakeError.invalidHeader("expected 3 parts: \(header)") | ||
|
@@ -252,12 +266,12 @@ actor Handshaker { | |
} | ||
let theirVersions = try parts[2] | ||
.split(separator: ",") | ||
.map { try ProtoVersion(parse: String($0)) } | ||
.map { v throws(HandshakeError) in try ProtoVersion(parse: String(v)) } | ||
return try pickVersion(ours: versions, theirs: theirVersions) | ||
} | ||
} | ||
|
||
func pickVersion(ours: [ProtoVersion], theirs: [ProtoVersion]) throws -> ProtoVersion { | ||
func pickVersion(ours: [ProtoVersion], theirs: [ProtoVersion]) throws(HandshakeError) -> ProtoVersion { | ||
for our in ours.reversed() { | ||
for their in theirs.reversed() where our.major == their.major { | ||
if our.minor < their.minor { | ||
|
@@ -266,27 +280,28 @@ func pickVersion(ours: [ProtoVersion], theirs: [ProtoVersion]) throws -> ProtoVe | |
return their | ||
} | ||
} | ||
throw HandshakeError.unsupportedVersion(theirs) | ||
throw .unsupportedVersion(theirs) | ||
} | ||
|
||
enum HandshakeError: Error { | ||
public enum HandshakeError: Error { | ||
case readError(String) | ||
case writeError(any Error) | ||
case invalidHeader(String) | ||
case wrongRole(String) | ||
case invalidVersion(String) | ||
case unsupportedVersion([ProtoVersion]) | ||
} | ||
|
||
public struct RPCRequest<SendMsg: RPCMessage & Message, RecvMsg: RPCMessage & Sendable>: Sendable { | ||
let msg: RecvMsg | ||
public let msg: RecvMsg | ||
private let sender: Sender<SendMsg> | ||
|
||
public init(req: RecvMsg, sender: Sender<SendMsg>) { | ||
msg = req | ||
self.sender = sender | ||
} | ||
|
||
func sendReply(_ reply: SendMsg) async throws { | ||
public func sendReply(_ reply: SendMsg) async throws { | ||
var reply = reply | ||
reply.rpc.responseTo = msg.rpc.msgID | ||
try await sender.send(reply) | ||
|
@@ -303,10 +318,10 @@ enum RPCError: Error { | |
|
||
/// An actor to record outgoing RPCs and route their replies to the original sender | ||
actor RPCSecretary<RecvMsg: RPCMessage & Sendable> { | ||
private var continuations: [UInt64: CheckedContinuation<RecvMsg, Error>] = [:] | ||
private var continuations: [UInt64: CheckedContinuation<RecvMsg, any Error>] = [:] | ||
private var nextMsgID: UInt64 = 1 | ||
|
||
func record(continuation: CheckedContinuation<RecvMsg, Error>) -> UInt64 { | ||
func record(continuation: CheckedContinuation<RecvMsg, any Error>) -> UInt64 { | ||
let id = nextMsgID | ||
nextMsgID += 1 | ||
continuations[id] = continuation | ||
|
@@ -326,13 +341,13 @@ actor RPCSecretary<RecvMsg: RPCMessage & Sendable> { | |
|
||
func route(reply: RecvMsg) throws(RPCError) { | ||
guard reply.hasRpc else { | ||
throw RPCError.missingRPC | ||
throw .missingRPC | ||
} | ||
guard reply.rpc.responseTo != 0 else { | ||
throw RPCError.notAResponse | ||
throw .notAResponse | ||
} | ||
guard let cont = continuations[reply.rpc.responseTo] else { | ||
throw RPCError.unknownResponseID(reply.rpc.responseTo) | ||
throw .unknownResponseID(reply.rpc.responseTo) | ||
} | ||
continuations[reply.rpc.responseTo] = nil | ||
cont.resume(returning: reply) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
import Testing | ||
@testable import VPNLib | ||
|
||
@Suite(.timeLimit(.minutes(1))) | ||
struct ConvertTests { | ||
@Test | ||
// swiftlint:disable:next function_body_length | ||
func convertProtoNetworkSettingsRequest() async throws { | ||
let req: Vpn_NetworkSettingsRequest = .with { req in | ||
req.tunnelRemoteAddress = "10.0.0.1" | ||
req.tunnelOverheadBytes = 20 | ||
req.mtu = 1400 | ||
|
||
req.dnsSettings = .with { dns in | ||
dns.servers = ["8.8.8.8"] | ||
dns.searchDomains = ["example.com"] | ||
dns.domainName = "example.com" | ||
dns.matchDomains = ["example.com"] | ||
dns.matchDomainsNoSearch = false | ||
} | ||
|
||
req.ipv4Settings = .with { ipv4 in | ||
ipv4.addrs = ["192.168.1.1"] | ||
ipv4.subnetMasks = ["255.255.255.0"] | ||
ipv4.router = "192.168.1.254" | ||
ipv4.includedRoutes = [ | ||
.with { route in | ||
route.destination = "10.0.0.0" | ||
route.mask = "255.0.0.0" | ||
route.router = "192.168.1.254" | ||
}, | ||
] | ||
ipv4.excludedRoutes = [ | ||
.with { route in | ||
route.destination = "172.16.0.0" | ||
route.mask = "255.240.0.0" | ||
route.router = "192.168.1.254" | ||
}, | ||
] | ||
} | ||
|
||
req.ipv6Settings = .with { ipv6 in | ||
ipv6.addrs = ["2001:db8::1"] | ||
ipv6.prefixLengths = [64] | ||
ipv6.includedRoutes = [ | ||
.with { route in | ||
route.destination = "2001:db8::" | ||
route.router = "2001:db8::1" | ||
route.prefixLength = 64 | ||
}, | ||
] | ||
ipv6.excludedRoutes = [ | ||
.with { route in | ||
route.destination = "2001:0db8:85a3::" | ||
route.router = "2001:db8::1" | ||
route.prefixLength = 128 | ||
}, | ||
] | ||
} | ||
} | ||
|
||
let result = convertNetworkSettingsRequest(req) | ||
#expect(result.tunnelRemoteAddress == req.tunnelRemoteAddress) | ||
#expect(result.dnsSettings!.servers == req.dnsSettings.servers) | ||
#expect(result.dnsSettings!.domainName == req.dnsSettings.domainName) | ||
#expect(result.ipv4Settings!.addresses == req.ipv4Settings.addrs) | ||
#expect(result.ipv4Settings!.subnetMasks == req.ipv4Settings.subnetMasks) | ||
#expect(result.ipv6Settings!.addresses == req.ipv6Settings.addrs) | ||
#expect(result.ipv6Settings!.networkPrefixLengths == [64]) | ||
|
||
try #require(result.ipv4Settings!.includedRoutes?.count == req.ipv4Settings.includedRoutes.count) | ||
let ipv4IncludedRoute = result.ipv4Settings!.includedRoutes![0] | ||
let expectedIpv4IncludedRoute = req.ipv4Settings.includedRoutes[0] | ||
#expect(ipv4IncludedRoute.destinationAddress == expectedIpv4IncludedRoute.destination) | ||
#expect(ipv4IncludedRoute.destinationSubnetMask == expectedIpv4IncludedRoute.mask) | ||
#expect(ipv4IncludedRoute.gatewayAddress == expectedIpv4IncludedRoute.router) | ||
|
||
try #require(result.ipv4Settings!.excludedRoutes?.count == req.ipv4Settings.excludedRoutes.count) | ||
let ipv4ExcludedRoute = result.ipv4Settings!.excludedRoutes![0] | ||
let expectedIpv4ExcludedRoute = req.ipv4Settings.excludedRoutes[0] | ||
#expect(ipv4ExcludedRoute.destinationAddress == expectedIpv4ExcludedRoute.destination) | ||
#expect(ipv4ExcludedRoute.destinationSubnetMask == expectedIpv4ExcludedRoute.mask) | ||
#expect(ipv4ExcludedRoute.gatewayAddress == expectedIpv4ExcludedRoute.router) | ||
|
||
try #require(result.ipv6Settings!.includedRoutes?.count == req.ipv6Settings.includedRoutes.count) | ||
let ipv6IncludedRoute = result.ipv6Settings!.includedRoutes![0] | ||
let expectedIpv6IncludedRoute = req.ipv6Settings.includedRoutes[0] | ||
#expect(ipv6IncludedRoute.destinationAddress == expectedIpv6IncludedRoute.destination) | ||
#expect(ipv6IncludedRoute.destinationNetworkPrefixLength == 64) | ||
#expect(ipv6IncludedRoute.gatewayAddress == expectedIpv6IncludedRoute.router) | ||
|
||
try #require(result.ipv6Settings!.excludedRoutes?.count == req.ipv6Settings.excludedRoutes.count) | ||
let ipv6ExcludedRoute = result.ipv6Settings!.excludedRoutes![0] | ||
let expectedIpv6ExcludedRoute = req.ipv6Settings.excludedRoutes[0] | ||
#expect(ipv6ExcludedRoute.destinationAddress == expectedIpv6ExcludedRoute.destination) | ||
#expect(ipv6ExcludedRoute.destinationNetworkPrefixLength == 128) | ||
#expect(ipv6ExcludedRoute.gatewayAddress == expectedIpv6ExcludedRoute.router) | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think this is worth unit testing, at least right now. We'd need to mock the PTP, the TunnelHandle, the validator, and eventually the XPC speaker, all for a relatively simple abstraction that's mostly error handling.