Skip to content

chore: add network extension manager #18

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 14, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions Coder Desktop/Coder Desktop.xcodeproj/project.pbxproj
Original file line number Diff line number Diff line change
@@ -832,7 +832,7 @@
DEAD_CODE_STRIPPING = YES;
DEVELOPMENT_TEAM = 4399GN35BJ;
GENERATE_INFOPLIST_FILE = YES;
MACOSX_DEPLOYMENT_TARGET = 15.0;
MACOSX_DEPLOYMENT_TARGET = 14.6;
MARKETING_VERSION = 1.0;
PRODUCT_BUNDLE_IDENTIFIER = "com.coder.Coder-DesktopTests";
PRODUCT_NAME = "$(TARGET_NAME)";
@@ -851,7 +851,7 @@
DEAD_CODE_STRIPPING = YES;
DEVELOPMENT_TEAM = 4399GN35BJ;
GENERATE_INFOPLIST_FILE = YES;
MACOSX_DEPLOYMENT_TARGET = 15.0;
MACOSX_DEPLOYMENT_TARGET = 14.6;
MARKETING_VERSION = 1.0;
PRODUCT_BUNDLE_IDENTIFIER = "com.coder.Coder-DesktopTests";
PRODUCT_NAME = "$(TARGET_NAME)";
@@ -869,7 +869,7 @@
DEAD_CODE_STRIPPING = YES;
DEVELOPMENT_TEAM = 4399GN35BJ;
GENERATE_INFOPLIST_FILE = YES;
MACOSX_DEPLOYMENT_TARGET = 15.0;
MACOSX_DEPLOYMENT_TARGET = 14.6;
MARKETING_VERSION = 1.0;
PRODUCT_BUNDLE_IDENTIFIER = "com.coder.Coder-DesktopUITests";
PRODUCT_NAME = "$(TARGET_NAME)";
@@ -887,7 +887,7 @@
DEAD_CODE_STRIPPING = YES;
DEVELOPMENT_TEAM = 4399GN35BJ;
GENERATE_INFOPLIST_FILE = YES;
MACOSX_DEPLOYMENT_TARGET = 15.0;
MACOSX_DEPLOYMENT_TARGET = 14.6;
MARKETING_VERSION = 1.0;
PRODUCT_BUNDLE_IDENTIFIER = "com.coder.Coder-DesktopUITests";
PRODUCT_NAME = "$(TARGET_NAME)";
@@ -1038,7 +1038,7 @@
CURRENT_PROJECT_VERSION = 1;
DEVELOPMENT_TEAM = 4399GN35BJ;
GENERATE_INFOPLIST_FILE = YES;
MACOSX_DEPLOYMENT_TARGET = 15.0;
MACOSX_DEPLOYMENT_TARGET = 14.6;
MARKETING_VERSION = 1.0;
PRODUCT_BUNDLE_IDENTIFIER = "com.coder.Coder-Desktop.VPNLibTests";
PRODUCT_NAME = "$(TARGET_NAME)";
@@ -1055,7 +1055,7 @@
CURRENT_PROJECT_VERSION = 1;
DEVELOPMENT_TEAM = 4399GN35BJ;
GENERATE_INFOPLIST_FILE = YES;
MACOSX_DEPLOYMENT_TARGET = 15.0;
MACOSX_DEPLOYMENT_TARGET = 14.6;
MARKETING_VERSION = 1.0;
PRODUCT_BUNDLE_IDENTIFIER = "com.coder.Coder-Desktop.VPNLibTests";
PRODUCT_NAME = "$(TARGET_NAME)";
Original file line number Diff line number Diff line change
@@ -23,7 +23,7 @@ struct PreviewClient: Client {
roles: []
)
} catch {
throw ClientError.reqError(AFError.explicitlyCancelled)
throw .reqError(.explicitlyCancelled)
}
}
}
8 changes: 4 additions & 4 deletions Coder Desktop/Coder Desktop/SDK/Client.swift
Original file line number Diff line number Diff line change
@@ -39,7 +39,7 @@ struct CoderClient: Client {
case let .success(data):
return HTTPResponse(resp: out.response!, data: data, req: out.request)
case let .failure(error):
throw ClientError.reqError(error)
throw .reqError(error)
}
}

@@ -58,7 +58,7 @@ struct CoderClient: Client {
case let .success(data):
return HTTPResponse(resp: out.response!, data: data, req: out.request)
case let .failure(error):
throw ClientError.reqError(error)
throw .reqError(error)
}
}

@@ -71,9 +71,9 @@ struct CoderClient: Client {
method: resp.req?.httpMethod,
url: resp.req?.url
)
return ClientError.apiError(out)
return .apiError(out)
} catch {
return ClientError.unexpectedResponse(resp.data[...1024])
return .unexpectedResponse(resp.data[...1024])
}
}

2 changes: 1 addition & 1 deletion Coder Desktop/Coder Desktop/SDK/User.swift
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@ extension CoderClient {
do {
return try CoderClient.decoder.decode(User.self, from: res.data)
} catch {
throw ClientError.unexpectedResponse(res.data[...1024])
throw .unexpectedResponse(res.data[...1024])
}
}
}
2 changes: 1 addition & 1 deletion Coder Desktop/Coder Desktop/Views/LoginForm.swift
Original file line number Diff line number Diff line change
@@ -70,7 +70,7 @@ struct LoginForm<C: Client, S: Session>: View {
loading = true
defer { loading = false }
let client = C(url: url, token: sessionToken)
do throws(ClientError) {
do {
_ = try await client.user("me")
} catch {
loginError = .failedAuth(error)
2 changes: 1 addition & 1 deletion Coder Desktop/Coder DesktopTests/Util.swift
Original file line number Diff line number Diff line change
@@ -68,7 +68,7 @@ struct MockClient: Client {
struct MockErrorClient: Client {
init(url _: URL, token _: String?) {}
func user(_: String) async throws(ClientError) -> Coder_Desktop.User {
throw ClientError.reqError(.explicitlyCancelled)
throw .reqError(.explicitlyCancelled)
}
}

193 changes: 190 additions & 3 deletions Coder Desktop/VPN/Manager.swift
Original file line number Diff line number Diff line change
@@ -4,16 +4,203 @@ import VPNLib

actor Manager {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is worth unit testing, at least right now. We'd need to mock the PTP, the TunnelHandle, the validator, and eventually the XPC speaker, all for a relatively simple abstraction that's mostly error handling.

let ptp: PacketTunnelProvider
let cfg: ManagerConfig

var tunnelHandle: TunnelHandle?
var speaker: Speaker<Vpn_ManagerMessage, Vpn_TunnelMessage>?
let tunnelHandle: TunnelHandle
let speaker: Speaker<Vpn_ManagerMessage, Vpn_TunnelMessage>
var readLoop: Task<Void, any Error>!
// TODO: XPC Speaker

private let dest = FileManager.default.urls(for: .documentDirectory, in: .userDomainMask)
.first!.appending(path: "coder-vpn.dylib")
private let logger = Logger(subsystem: Bundle.main.bundleIdentifier!, category: "manager")

init(with: PacketTunnelProvider) {
init(with: PacketTunnelProvider, cfg: ManagerConfig) async throws(ManagerError) {
ptp = with
self.cfg = cfg
#if arch(arm64)
let dylibPath = cfg.serverUrl.appending(path: "bin/coder-vpn-arm64.dylib")
#elseif arch(x86_64)
let dylibPath = cfg.serverUrl.appending(path: "bin/coder-vpn-amd64.dylib")
#else
fatalError("unknown architecture")
#endif
do {
try await download(src: dylibPath, dest: dest)
} catch {
throw .download(error)
}
do {
try SignatureValidator.validate(path: dest)
} catch {
throw .validation(error)
}
do {
try tunnelHandle = TunnelHandle(dylibPath: dest)
} catch {
throw .tunnelSetup(error)
}
speaker = await Speaker<Vpn_ManagerMessage, Vpn_TunnelMessage>(
writeFD: tunnelHandle.writeHandle,
readFD: tunnelHandle.readHandle
)
do {
try await speaker.handshake()
} catch {
throw .handshake(error)
}
readLoop = Task { try await run() }
}

func run() async throws {
do {
for try await m in speaker {
switch m {
case let .message(msg):
handleMessage(msg)
case let .RPC(rpc):
handleRPC(rpc)
}
}
} catch {
logger.error("tunnel read loop failed: \(error)")
try await tunnelHandle.close()
// TODO: Notify app over XPC
return
}
logger.info("tunnel read loop exited")
try await tunnelHandle.close()
// TODO: Notify app over XPC
}

func handleMessage(_ msg: Vpn_TunnelMessage) {
guard let msgType = msg.msg else {
logger.critical("received message with no type")
return
}
switch msgType {
case .peerUpdate:
{}() // TODO: Send over XPC
case let .log(logMsg):
writeVpnLog(logMsg)
case .networkSettings, .start, .stop:
logger.critical("received unexpected message: `\(String(describing: msgType))`")
}
}

func handleRPC(_ rpc: RPCRequest<Vpn_ManagerMessage, Vpn_TunnelMessage>) {
guard let msgType = rpc.msg.msg else {
logger.critical("received rpc with no type")
return
}
switch msgType {
case let .networkSettings(ns):
let neSettings = convertNetworkSettingsRequest(ns)
ptp.setTunnelNetworkSettings(neSettings)
case .log, .peerUpdate, .start, .stop:
logger.critical("received unexpected rpc: `\(String(describing: msgType))`")
}
}

// TODO: Call via XPC
func startVPN() async throws(ManagerError) {
logger.info("sending start rpc")
guard let tunFd = ptp.tunnelFileDescriptor else {
throw .noTunnelFileDescriptor
}
let resp: Vpn_TunnelMessage
do {
resp = try await speaker.unaryRPC(.with { msg in
msg.start = .with { req in
req.tunnelFileDescriptor = tunFd
req.apiToken = cfg.apiToken
req.coderURL = cfg.serverUrl.absoluteString
}
})
} catch {
throw .failedRPC(error)
}
guard case let .start(startResp) = resp.msg else {
throw .incorrectResponse(resp)
}
if !startResp.success {
throw .errorResponse(msg: startResp.errorMessage)
}
// TODO: notify app over XPC
}

// TODO: Call via XPC
func stopVPN() async throws(ManagerError) {
logger.info("sending stop rpc")
let resp: Vpn_TunnelMessage
do {
resp = try await speaker.unaryRPC(.with { msg in
msg.stop = .init()
})
} catch {
throw .failedRPC(error)
}
guard case let .stop(stopResp) = resp.msg else {
throw .incorrectResponse(resp)
}
if !stopResp.success {
throw .errorResponse(msg: stopResp.errorMessage)
}
// TODO: notify app over XPC
}

// TODO: Call via XPC
// Retrieves the current state of all peers,
// as required when starting the app whilst the network extension is already running
func getPeerInfo() async throws(ManagerError) {
logger.info("sending peer state request")
let resp: Vpn_TunnelMessage
do {
resp = try await speaker.unaryRPC(.with { msg in
msg.getPeerUpdate = .init()
})
} catch {
throw .failedRPC(error)
}
guard case .peerUpdate = resp.msg else {
throw .incorrectResponse(resp)
}
// TODO: pass to app over XPC
}
}

public struct ManagerConfig {
let apiToken: String
let serverUrl: URL
}

enum ManagerError: Error {
case download(DownloadError)
case tunnelSetup(TunnelHandleError)
case handshake(HandshakeError)
case validation(ValidationError)
case incorrectResponse(Vpn_TunnelMessage)
case failedRPC(any Error)
case errorResponse(msg: String)
case noTunnelFileDescriptor
}

func writeVpnLog(_ log: Vpn_Log) {
let level: OSLogType = switch log.level {
case .info: .info
case .debug: .debug
// warn == error
case .warn: .error
case .error: .error
// critical == fatal == fault
case .critical: .fault
case .fatal: .fault
case .UNRECOGNIZED: .info
}
let logger = Logger(
subsystem: "\(Bundle.main.bundleIdentifier!).dylib",
category: log.loggerNames.joined(separator: ".")
)
let fields = log.fields.map { "\($0.name): \($0.value)" }.joined(separator: ", ")
logger.log(level: level, "\(log.message): \(fields)")
}
12 changes: 9 additions & 3 deletions Coder Desktop/VPN/PacketTunnelProvider.swift
Original file line number Diff line number Diff line change
@@ -5,10 +5,10 @@ import os
let CTLIOCGINFO: UInt = 0xC064_4E03

class PacketTunnelProvider: NEPacketTunnelProvider, @unchecked Sendable {
private let logger = Logger(subsystem: Bundle.main.bundleIdentifier!, category: "network-extension")
private let logger = Logger(subsystem: Bundle.main.bundleIdentifier!, category: "packet-tunnel-provider")
private var manager: Manager?

private var tunnelFileDescriptor: Int32? {
public var tunnelFileDescriptor: Int32? {
var ctlInfo = ctl_info()
withUnsafeMutablePointer(to: &ctlInfo.ctl_name) {
$0.withMemoryRebound(to: CChar.self, capacity: MemoryLayout.size(ofValue: $0.pointee)) {
@@ -46,7 +46,13 @@ class PacketTunnelProvider: NEPacketTunnelProvider, @unchecked Sendable {
completionHandler(nil)
return
}
manager = Manager(with: self)
Task {
// TODO: Retrieve access URL & Token via Keychain
manager = try await Manager(
with: self,
cfg: .init(apiToken: "fake-token", serverUrl: .init(string: "https://dev.coder.com")!)
)
}
completionHandler(nil)
}

60 changes: 60 additions & 0 deletions Coder Desktop/VPNLib/Convert.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import NetworkExtension
import os

// swiftlint:disable:next function_body_length
public func convertNetworkSettingsRequest(_ req: Vpn_NetworkSettingsRequest) -> NEPacketTunnelNetworkSettings {
let networkSettings = NEPacketTunnelNetworkSettings(tunnelRemoteAddress: req.tunnelRemoteAddress)
networkSettings.tunnelOverheadBytes = NSNumber(value: req.tunnelOverheadBytes)
networkSettings.mtu = NSNumber(value: req.mtu)

if req.hasDnsSettings {
let dnsSettings = NEDNSSettings(servers: req.dnsSettings.servers)
dnsSettings.searchDomains = req.dnsSettings.searchDomains
dnsSettings.domainName = req.dnsSettings.domainName
dnsSettings.matchDomains = req.dnsSettings.matchDomains
dnsSettings.matchDomainsNoSearch = req.dnsSettings.matchDomainsNoSearch
networkSettings.dnsSettings = dnsSettings
}

if req.hasIpv4Settings {
let ipv4Settings = NEIPv4Settings(addresses: req.ipv4Settings.addrs, subnetMasks: req.ipv4Settings.subnetMasks)
ipv4Settings.router = req.ipv4Settings.router
ipv4Settings.includedRoutes = req.ipv4Settings.includedRoutes.map {
let route = NEIPv4Route(destinationAddress: $0.destination, subnetMask: $0.mask)
route.gatewayAddress = $0.router
return route
}
ipv4Settings.excludedRoutes = req.ipv4Settings.excludedRoutes.map {
let route = NEIPv4Route(destinationAddress: $0.destination, subnetMask: $0.mask)
route.gatewayAddress = $0.router
return route
}
networkSettings.ipv4Settings = ipv4Settings
}

if req.hasIpv6Settings {
let ipv6Settings = NEIPv6Settings(
addresses: req.ipv6Settings.addrs,
networkPrefixLengths: req.ipv6Settings.prefixLengths.map { NSNumber(value: $0)
}
)
ipv6Settings.includedRoutes = req.ipv6Settings.includedRoutes.map {
let route = NEIPv6Route(
destinationAddress: $0.destination,
networkPrefixLength: NSNumber(value: $0.prefixLength)
)
route.gatewayAddress = $0.router
return route
}
ipv6Settings.excludedRoutes = req.ipv6Settings.excludedRoutes.map {
let route = NEIPv6Route(
destinationAddress: $0.destination,
networkPrefixLength: NSNumber(value: $0.prefixLength)
)
route.gatewayAddress = $0.router
return route
}
networkSettings.ipv6Settings = ipv6Settings
}
return networkSettings
}
2 changes: 1 addition & 1 deletion Coder Desktop/VPNLib/Receiver.swift
Original file line number Diff line number Diff line change
@@ -59,7 +59,7 @@ actor Receiver<RecvMsg: Message> {
/// On read or decoding error, it logs and closes the stream.
func messages() throws(ReceiveError) -> AsyncStream<RecvMsg> {
if running {
throw ReceiveError.alreadyRunning
throw .alreadyRunning
}
running = true
return AsyncStream(
69 changes: 42 additions & 27 deletions Coder Desktop/VPNLib/Speaker.swift
Original file line number Diff line number Diff line change
@@ -22,27 +22,27 @@ enum ProtoRole: String {
}

/// A version of the VPN protocol that can be negotiated.
struct ProtoVersion: CustomStringConvertible, Equatable, Codable {
public struct ProtoVersion: CustomStringConvertible, Equatable, Codable, Sendable {
let major: Int
let minor: Int

var description: String { "\(major).\(minor)" }
public var description: String { "\(major).\(minor)" }

init(_ major: Int, _ minor: Int) {
self.major = major
self.minor = minor
}

init(parse str: String) throws {
init(parse str: String) throws(HandshakeError) {
let parts = str.split(separator: ".").map { Int($0) }
if parts.count != 2 {
throw HandshakeError.invalidVersion(str)
throw .invalidVersion(str)
}
guard let major = parts[0] else {
throw HandshakeError.invalidVersion(str)
throw .invalidVersion(str)
}
guard let minor = parts[1] else {
throw HandshakeError.invalidVersion(str)
throw .invalidVersion(str)
}
self.major = major
self.minor = minor
@@ -87,14 +87,14 @@ public actor Speaker<SendMsg: RPCMessage & Message, RecvMsg: RPCMessage & Messag
}

/// Does the VPN Protocol handshake and validates the result
func handshake() async throws {
public func handshake() async throws(HandshakeError) {
let hndsh = Handshaker(writeFD: writeFD, dispatch: dispatch, queue: queue, role: role)
// ignore the version for now because we know it can only be 1.0
try _ = await hndsh.handshake()
}

/// Send a unary RPC message and handle the response
func unaryRPC(_ req: SendMsg) async throws -> RecvMsg {
public func unaryRPC(_ req: SendMsg) async throws -> RecvMsg {
return try await withCheckedThrowingContinuation { continuation in
Task { [sender, secretary, logger] in
let msgID = await secretary.record(continuation: continuation)
@@ -114,15 +114,15 @@ public actor Speaker<SendMsg: RPCMessage & Message, RecvMsg: RPCMessage & Messag
}
}

func closeWrite() {
public func closeWrite() {
do {
try writeFD.close()
} catch {
logger.error("failed to close write file handle: \(error)")
}
}

func closeRead() {
public func closeRead() {
do {
try readFD.close()
} catch {
@@ -153,8 +153,8 @@ extension Speaker: AsyncSequence, AsyncIteratorProtocol {
}
guard msg.rpc.responseTo == 0 else {
logger.debug("got RPC reply for msgID \(msg.rpc.responseTo)")
do throws(RPCError) {
try await self.secretary.route(reply: msg)
do {
try await secretary.route(reply: msg)
} catch {
logger.error(
"couldn't route RPC reply for \(msg.rpc.responseTo): \(error)")
@@ -188,7 +188,7 @@ actor Handshaker {
}

/// Performs the initial VPN protocol handshake, returning the negotiated `ProtoVersion` that we should use.
func handshake() async throws -> ProtoVersion {
func handshake() async throws(HandshakeError) -> ProtoVersion {
// kick off the read async before we try to write, synchronously, so we don't deadlock, both
// waiting to write with nobody reading.
let readTask = Task {
@@ -201,9 +201,22 @@ actor Handshaker {

let vStr = versions.map { $0.description }.joined(separator: ",")
let ours = String(format: "\(headerPreamble) \(role) \(vStr)\n")
try writeFD.write(contentsOf: ours.data(using: .utf8)!)
do {
try writeFD.write(contentsOf: ours.data(using: .utf8)!)
} catch {
throw HandshakeError.writeError(error)
}

do {
theirData = try await readTask.value
} catch let error as HandshakeError {
throw error
} catch {
// This can't be checked at compile-time, as both Tasks & Continuations can only ever throw
// a type-erased `Error`
Comment on lines +215 to +216
Copy link
Member Author

@ethanndickson ethanndickson Jan 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a little frustrating - theTask type itself is generic on two type parameters, Success and Failure, but the only APIs you can await on a task are either value with an untyped throws or result which returns a type-erased any Error error variant. It's very similar for continuations, the type itself is generic but there's no way to return that typed error.

As an aside, I think I'm becoming more of a fan, for binaries at least, of just having a single opaque error type, propagating it with context, and then just displaying it, ala Go or Rust's anyhow - it's so rare that one or more of the error variants are recoverable, and that the caller needs to know which one.

fatalError("handleRead must always throw HandshakeError")
}

let theirData = try await readTask.value
guard let theirsString = String(bytes: theirData, encoding: .utf8) else {
throw HandshakeError.invalidHeader("<unparsable: \(theirData)")
}
@@ -216,6 +229,7 @@ actor Handshaker {
}
}

// resumes must only ever throw HandshakeError
private func handleRead(_: Bool, _ data: DispatchData?, _ error: Int32) {
guard error == 0 else {
let errStrPtr = strerror(error)
@@ -235,7 +249,7 @@ actor Handshaker {
dispatch.read(offset: 0, length: 1, queue: queue, ioHandler: handleRead)
}

private func validateHeader(_ header: String) throws -> ProtoVersion {
private func validateHeader(_ header: String) throws(HandshakeError) -> ProtoVersion {
let parts = header.split(separator: " ")
guard parts.count == 3 else {
throw HandshakeError.invalidHeader("expected 3 parts: \(header)")
@@ -252,12 +266,12 @@ actor Handshaker {
}
let theirVersions = try parts[2]
.split(separator: ",")
.map { try ProtoVersion(parse: String($0)) }
.map { v throws(HandshakeError) in try ProtoVersion(parse: String(v)) }
return try pickVersion(ours: versions, theirs: theirVersions)
}
}

func pickVersion(ours: [ProtoVersion], theirs: [ProtoVersion]) throws -> ProtoVersion {
func pickVersion(ours: [ProtoVersion], theirs: [ProtoVersion]) throws(HandshakeError) -> ProtoVersion {
for our in ours.reversed() {
for their in theirs.reversed() where our.major == their.major {
if our.minor < their.minor {
@@ -266,27 +280,28 @@ func pickVersion(ours: [ProtoVersion], theirs: [ProtoVersion]) throws -> ProtoVe
return their
}
}
throw HandshakeError.unsupportedVersion(theirs)
throw .unsupportedVersion(theirs)
}

enum HandshakeError: Error {
public enum HandshakeError: Error {
case readError(String)
case writeError(any Error)
case invalidHeader(String)
case wrongRole(String)
case invalidVersion(String)
case unsupportedVersion([ProtoVersion])
}

public struct RPCRequest<SendMsg: RPCMessage & Message, RecvMsg: RPCMessage & Sendable>: Sendable {
let msg: RecvMsg
public let msg: RecvMsg
private let sender: Sender<SendMsg>

public init(req: RecvMsg, sender: Sender<SendMsg>) {
msg = req
self.sender = sender
}

func sendReply(_ reply: SendMsg) async throws {
public func sendReply(_ reply: SendMsg) async throws {
var reply = reply
reply.rpc.responseTo = msg.rpc.msgID
try await sender.send(reply)
@@ -303,10 +318,10 @@ enum RPCError: Error {

/// An actor to record outgoing RPCs and route their replies to the original sender
actor RPCSecretary<RecvMsg: RPCMessage & Sendable> {
private var continuations: [UInt64: CheckedContinuation<RecvMsg, Error>] = [:]
private var continuations: [UInt64: CheckedContinuation<RecvMsg, any Error>] = [:]
private var nextMsgID: UInt64 = 1

func record(continuation: CheckedContinuation<RecvMsg, Error>) -> UInt64 {
func record(continuation: CheckedContinuation<RecvMsg, any Error>) -> UInt64 {
let id = nextMsgID
nextMsgID += 1
continuations[id] = continuation
@@ -326,13 +341,13 @@ actor RPCSecretary<RecvMsg: RPCMessage & Sendable> {

func route(reply: RecvMsg) throws(RPCError) {
guard reply.hasRpc else {
throw RPCError.missingRPC
throw .missingRPC
}
guard reply.rpc.responseTo != 0 else {
throw RPCError.notAResponse
throw .notAResponse
}
guard let cont = continuations[reply.rpc.responseTo] else {
throw RPCError.unknownResponseID(reply.rpc.responseTo)
throw .unknownResponseID(reply.rpc.responseTo)
}
continuations[reply.rpc.responseTo] = nil
cont.resume(returning: reply)
99 changes: 99 additions & 0 deletions Coder Desktop/VPNLibTests/ConvertTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import Testing
@testable import VPNLib

@Suite(.timeLimit(.minutes(1)))
struct ConvertTests {
@Test
// swiftlint:disable:next function_body_length
func convertProtoNetworkSettingsRequest() async throws {
let req: Vpn_NetworkSettingsRequest = .with { req in
req.tunnelRemoteAddress = "10.0.0.1"
req.tunnelOverheadBytes = 20
req.mtu = 1400

req.dnsSettings = .with { dns in
dns.servers = ["8.8.8.8"]
dns.searchDomains = ["example.com"]
dns.domainName = "example.com"
dns.matchDomains = ["example.com"]
dns.matchDomainsNoSearch = false
}

req.ipv4Settings = .with { ipv4 in
ipv4.addrs = ["192.168.1.1"]
ipv4.subnetMasks = ["255.255.255.0"]
ipv4.router = "192.168.1.254"
ipv4.includedRoutes = [
.with { route in
route.destination = "10.0.0.0"
route.mask = "255.0.0.0"
route.router = "192.168.1.254"
},
]
ipv4.excludedRoutes = [
.with { route in
route.destination = "172.16.0.0"
route.mask = "255.240.0.0"
route.router = "192.168.1.254"
},
]
}

req.ipv6Settings = .with { ipv6 in
ipv6.addrs = ["2001:db8::1"]
ipv6.prefixLengths = [64]
ipv6.includedRoutes = [
.with { route in
route.destination = "2001:db8::"
route.router = "2001:db8::1"
route.prefixLength = 64
},
]
ipv6.excludedRoutes = [
.with { route in
route.destination = "2001:0db8:85a3::"
route.router = "2001:db8::1"
route.prefixLength = 128
},
]
}
}

let result = convertNetworkSettingsRequest(req)
#expect(result.tunnelRemoteAddress == req.tunnelRemoteAddress)
#expect(result.dnsSettings!.servers == req.dnsSettings.servers)
#expect(result.dnsSettings!.domainName == req.dnsSettings.domainName)
#expect(result.ipv4Settings!.addresses == req.ipv4Settings.addrs)
#expect(result.ipv4Settings!.subnetMasks == req.ipv4Settings.subnetMasks)
#expect(result.ipv6Settings!.addresses == req.ipv6Settings.addrs)
#expect(result.ipv6Settings!.networkPrefixLengths == [64])

try #require(result.ipv4Settings!.includedRoutes?.count == req.ipv4Settings.includedRoutes.count)
let ipv4IncludedRoute = result.ipv4Settings!.includedRoutes![0]
let expectedIpv4IncludedRoute = req.ipv4Settings.includedRoutes[0]
#expect(ipv4IncludedRoute.destinationAddress == expectedIpv4IncludedRoute.destination)
#expect(ipv4IncludedRoute.destinationSubnetMask == expectedIpv4IncludedRoute.mask)
#expect(ipv4IncludedRoute.gatewayAddress == expectedIpv4IncludedRoute.router)

try #require(result.ipv4Settings!.excludedRoutes?.count == req.ipv4Settings.excludedRoutes.count)
let ipv4ExcludedRoute = result.ipv4Settings!.excludedRoutes![0]
let expectedIpv4ExcludedRoute = req.ipv4Settings.excludedRoutes[0]
#expect(ipv4ExcludedRoute.destinationAddress == expectedIpv4ExcludedRoute.destination)
#expect(ipv4ExcludedRoute.destinationSubnetMask == expectedIpv4ExcludedRoute.mask)
#expect(ipv4ExcludedRoute.gatewayAddress == expectedIpv4ExcludedRoute.router)

try #require(result.ipv6Settings!.includedRoutes?.count == req.ipv6Settings.includedRoutes.count)
let ipv6IncludedRoute = result.ipv6Settings!.includedRoutes![0]
let expectedIpv6IncludedRoute = req.ipv6Settings.includedRoutes[0]
#expect(ipv6IncludedRoute.destinationAddress == expectedIpv6IncludedRoute.destination)
#expect(ipv6IncludedRoute.destinationNetworkPrefixLength == 64)
#expect(ipv6IncludedRoute.gatewayAddress == expectedIpv6IncludedRoute.router)

try #require(result.ipv6Settings!.excludedRoutes?.count == req.ipv6Settings.excludedRoutes.count)
let ipv6ExcludedRoute = result.ipv6Settings!.excludedRoutes![0]
let expectedIpv6ExcludedRoute = req.ipv6Settings.excludedRoutes[0]
#expect(ipv6ExcludedRoute.destinationAddress == expectedIpv6ExcludedRoute.destination)
#expect(ipv6ExcludedRoute.destinationNetworkPrefixLength == 128)
#expect(ipv6ExcludedRoute.gatewayAddress == expectedIpv6ExcludedRoute.router)
}
}