diff --git a/unxip.swift b/unxip.swift
index 0841ef1..ded8850 100644
--- a/unxip.swift
+++ b/unxip.swift
@@ -439,6 +439,41 @@ extension option {
 
 // MARK: - Public API
 
+enum UnxipError: Error {
+	case truncated
+	case invalid
+
+	// https://github.com/apple/swift/issues/60318
+	static func async_throw<T>(_ error: @autoclosure () -> Self, ifNil expression: @autoclosure () async throws -> T?) async throws -> T {
+		if let value = try await expression() {
+			return value
+		} else {
+			throw error()
+		}
+	}
+
+	static func `throw`<T>(_ error: @autoclosure () -> Self, ifNil expression: @autoclosure () throws -> T?) throws -> T {
+		if let value = try expression() {
+			return value
+		} else {
+			throw error()
+		}
+	}
+
+	// https://github.com/apple/swift/issues/60318
+	static func async_throw(_ error: @autoclosure () -> Self, if expression: @autoclosure () async throws -> Bool) async throws {
+		if try await expression() {
+			throw error()
+		}
+	}
+
+	static func `throw`(_ error: @autoclosure () -> Self, if expression: @autoclosure () throws -> Bool) throws {
+		if try expression() {
+			throw error()
+		}
+	}
+}
+
 public struct DataReader<S: AsyncSequence> where S.Element: RandomAccessCollection, S.Element.Element == UInt8 {
 	public var position: Int = 0 {
 		didSet {
@@ -456,7 +491,7 @@ public struct DataReader<S: AsyncSequence> where S.Element: RandomAccessCollecti
 		self.iterator = data.makeAsyncIterator()
 	}
 
-	mutating func read(upTo n: Int) async throws -> [UInt8]? {
+	mutating func read(upTo n: Int) async throws -> [UInt8] {
 		var data = [UInt8]()
 		var index = 0
 		while index != n {
@@ -483,8 +518,8 @@ public struct DataReader<S: AsyncSequence> where S.Element: RandomAccessCollecti
 	}
 
 	mutating func read(_ n: Int) async throws -> [UInt8] {
-		let data = try await read(upTo: n)!
-		precondition(data.count == n)
+		let data = try await read(upTo: n)
+		try UnxipError.throw(.truncated, if: data.count != n)
 		return data
 	}
 
@@ -569,9 +604,9 @@ public struct Chunk: Sendable {
 	public let buffer: [UInt8]
 	public let decompressed: Bool
 
-	init(data: [UInt8], decompressedSize: Int?, lzmaDecompressor: ([UInt8], Int) -> [UInt8]) {
+	init(data: [UInt8], decompressedSize: Int?, lzmaDecompressor: ([UInt8], Int) throws -> [UInt8]) rethrows {
 		if let decompressedSize = decompressedSize {
-			buffer = lzmaDecompressor(data, decompressedSize)
+			buffer = try lzmaDecompressor(data, decompressedSize)
 			decompressed = true
 		} else {
 			buffer = data
@@ -789,30 +824,23 @@ public protocol StreamAperture {
 	static func transform(_: Input, options: Options?) -> Next.Input
 }
 
-extension StreamAperture {
-	static func async_precondition(_ condition: @autoclosure () async throws -> Bool) async rethrows {
-		let result = try await condition()
-		precondition(result)
-	}
-}
-
 protocol Decompressor {
-	static func decompress(data: [UInt8], decompressedSize: Int) -> [UInt8]
+	static func decompress(data: [UInt8], decompressedSize: Int) throws -> [UInt8]
 }
 
 public enum DefaultDecompressor {
 	enum Zlib: Decompressor {
-		static func decompress(data: [UInt8], decompressedSize: Int) -> [UInt8] {
-			return [UInt8](unsafeUninitializedCapacity: decompressedSize) { buffer, count in
+		static func decompress(data: [UInt8], decompressedSize: Int) throws -> [UInt8] {
+			return try [UInt8](unsafeUninitializedCapacity: decompressedSize) { buffer, count in
 				#if canImport(Compression)
 					let zlibSkip = 2 // Apple's decoder doesn't want to see CMF/FLG (see RFC 1950)
-					data[data.index(data.startIndex, offsetBy: zlibSkip)...].withUnsafeBufferPointer {
-						precondition(compression_decode_buffer(buffer.baseAddress!, decompressedSize, $0.baseAddress!, $0.count, nil, COMPRESSION_ZLIB) == decompressedSize)
+					try data[data.index(data.startIndex, offsetBy: zlibSkip)...].withUnsafeBufferPointer {
+						try UnxipError.throw(.invalid, if: compression_decode_buffer(buffer.baseAddress!, decompressedSize, $0.baseAddress!, $0.count, nil, COMPRESSION_ZLIB) != decompressedSize)
 					}
 				#else
 					var size = decompressedSize
-					precondition(uncompress(buffer.baseAddress!, &size, data, UInt(data.count)) == Z_OK)
-					precondition(size == decompressedSize)
+					try UnxipError.throw(.invalid, if: uncompress(buffer.baseAddress!, &size, data, UInt(data.count)) != Z_OK)
+					try UnxipError.throw(.invalid, if: size != decompressedSize)
 				#endif
 				count = decompressedSize
 			}
@@ -820,18 +848,18 @@ public enum DefaultDecompressor {
 	}
 
 	enum LZMA: Decompressor {
-		static func decompress(data: [UInt8], decompressedSize: Int) -> [UInt8] {
+		static func decompress(data: [UInt8], decompressedSize: Int) throws -> [UInt8] {
 			let magic = [0xfd] + "7zX".utf8
-			precondition(data.prefix(magic.count).elementsEqual(magic))
-			return [UInt8](unsafeUninitializedCapacity: decompressedSize) { buffer, count in
+			try UnxipError.throw(.invalid, if: !data.prefix(magic.count).elementsEqual(magic))
+			return try [UInt8](unsafeUninitializedCapacity: decompressedSize) { buffer, count in
 				#if canImport(Compression)
-					precondition(compression_decode_buffer(buffer.baseAddress!, decompressedSize, data, data.count, nil, COMPRESSION_LZMA) == decompressedSize)
+					try UnxipError.throw(.invalid, if: compression_decode_buffer(buffer.baseAddress!, decompressedSize, data, data.count, nil, COMPRESSION_LZMA) != decompressedSize)
 				#else
 					var memlimit = UInt64.max
 					var inIndex = 0
 					var outIndex = 0
-					precondition(lzma_stream_buffer_decode(&memlimit, 0, nil, data, &inIndex, data.count, buffer.baseAddress, &outIndex, decompressedSize) == LZMA_OK)
-					precondition(inIndex == data.count && outIndex == decompressedSize)
+					try UnxipError.throw(.invalid, if: lzma_stream_buffer_decode(&memlimit, 0, nil, data, &inIndex, data.count, buffer.baseAddress, &outIndex, decompressedSize) != LZMA_OK)
+					try UnxipError.throw(.invalid, if: inIndex != data.count || outIndex != decompressedSize)
 				#endif
 				count = decompressedSize
 			}
@@ -844,8 +872,8 @@ public enum XIP<S: AsyncSequence>: StreamAperture where S.Element: RandomAccessC
 	public typealias Next = Chunks
 
 	public struct Options {
-		let zlibDecompressor: ([UInt8], Int) -> [UInt8]
-		let lzmaDecompressor: ([UInt8], Int) -> [UInt8]
+		let zlibDecompressor: ([UInt8], Int) throws -> [UInt8]
+		let lzmaDecompressor: ([UInt8], Int) throws -> [UInt8]
 
 		init<Zlib: Decompressor, LZMA: Decompressor>(zlibDecompressor: Zlib.Type, lzmaDecompressor: LZMA.Type) {
 			self.zlibDecompressor = Zlib.decompress
@@ -861,9 +889,9 @@ public enum XIP<S: AsyncSequence>: StreamAperture where S.Element: RandomAccessC
 		let fileStart = file.position
 
 		let magic = "xar!".utf8
-		try await async_precondition(await file.read(magic.count).elementsEqual(magic))
+		try await UnxipError.async_throw(.invalid, if: await !file.read(magic.count).elementsEqual(magic))
 		let headerSize = try await file.read(UInt16.self)
-		try await async_precondition(await file.read(UInt16.self) == 1) // version
+		try await UnxipError.async_throw(.invalid, if: await file.read(UInt16.self) != 1) // version
 		let tocCompressedSize = try await file.read(UInt64.self)
 		let tocDecompressedSize = try await file.read(UInt64.self)
 		_ = try await file.read(UInt32.self) // checksum
@@ -871,7 +899,7 @@ public enum XIP<S: AsyncSequence>: StreamAperture where S.Element: RandomAccessC
 
 		_ = try await file.read(fileStart + Int(headerSize) - file.position)
 		let compressedTOC = try await file.read(Int(tocCompressedSize))
-		let toc = options.zlibDecompressor(compressedTOC, Int(tocDecompressedSize))
+		let toc = try options.zlibDecompressor(compressedTOC, Int(tocDecompressedSize))
 
 		#if canImport(UIKit)
 			let document = xmlReadMemory(toc, CInt(toc.count), "", nil, 0)
@@ -883,12 +911,12 @@ public enum XIP<S: AsyncSequence>: StreamAperture where S.Element: RandomAccessC
 			xmlXPathFreeContext(context)
 		}
 
-		func evaluateXPath(node: xmlNodePtr!, xpath: String) -> String {
-			let result = xmlXPathNodeEval(node, xpath, context)!
+		func evaluateXPath(node: xmlNodePtr!, xpath: String) throws -> String {
+			let result = try UnxipError.throw(.invalid, ifNil: xmlXPathNodeEval(node, xpath, context))
 			defer {
 				xmlXPathFreeObject(result)
 			}
-			precondition(result.pointee.type == XPATH_NODESET && result.pointee.nodesetval.pointee.nodeNr == 1)
+			try UnxipError.throw(.invalid, if: result.pointee.type != XPATH_NODESET || result.pointee.nodesetval.pointee.nodeNr != 1)
 			let string = xmlNodeListGetString(document, result.pointee.nodesetval.pointee.nodeTab.pointee!.pointee.children, 1)!
 			defer {
 				xmlFree(string)
@@ -896,25 +924,29 @@ public enum XIP<S: AsyncSequence>: StreamAperture where S.Element: RandomAccessC
 			return String(cString: string)
 		}
 
-		let result = xmlXPathEvalExpression("/xar/toc/file", context)!
+		let result = try UnxipError.throw(.invalid, ifNil: xmlXPathEvalExpression("/xar/toc/file", context))
 		defer {
 			xmlXPathFreeObject(result)
 		}
-		precondition(result.pointee.type == XPATH_NODESET)
-		let content = result.pointee.nodesetval.pointee.nodeTab[ (0..
@@ -…,… +…,… @@ public enum XIP<S: AsyncSequence>: StreamAperture where S.Element: RandomAccessC
 		}
 
 		let magic = "pbzx".utf8
-		try await async_precondition(try await content.read(magic.count).elementsEqual(magic))
+		try await UnxipError.async_throw(.invalid, if: await !content.read(magic.count).elementsEqual(magic))
 		let chunkSize = try await content.read(UInt64.self)
 		var decompressedSize: UInt64 = 0
 		var previousYield: Task<Void, Error>?
@@ -960,7 +992,7 @@ public enum XIP<S: AsyncSequence>: StreamAperture where S.Element: RandomAccessC
 				let id = OSSignpostID(log: decompressionLog)
 				os_signpost(.begin, log: decompressionLog, name: "Decompress", signpostID: id, compressed ? "Starting %td (compressed size = %td)" : "Starting %td (uncompressed size = %td)", chunkNumber, compressedSize)
 			#endif
-			let chunk = Chunk(data: block, decompressedSize: compressed ? Int(decompressedSize) : nil, lzmaDecompressor: options.lzmaDecompressor)
+			let chunk = try Chunk(data: block, decompressedSize: compressed ? Int(decompressedSize) : nil, lzmaDecompressor: options.lzmaDecompressor)
 			#if PROFILING
 				os_signpost(.end, log: decompressionLog, name: "Decompress", signpostID: id, "Ended %td (decompressed size = %td)", chunkNumber, decompressedSize)
 			#endif
@@ -989,14 +1021,14 @@ public enum Chunks: StreamAperture {
 		let fileStream = BackpressureStream(backpressure: FileBackpressure(maxSize: 1_000_000_000), of: File.self)
 		Task {
 			var iterator = chunks.makeAsyncIterator()
-			var chunk = try! await iterator.next()!
+			var chunk = try await UnxipError.async_throw(.truncated, ifNil: await iterator.next())
 			var position = chunk.buffer.startIndex
 
-			func read(size: Int) async -> [UInt8] {
+			func read(size: Int) async throws -> [UInt8] {
 				var result = [UInt8]()
 				while result.count < size {
 					if position >= chunk.buffer.endIndex {
-						chunk = try! await iterator.next()!
+						chunk = try await UnxipError.async_throw(.truncated, ifNil: await iterator.next())
 						position = 0
 					}
 					result.append(chunk.buffer[chunk.buffer.startIndex + position])
@@ -1005,30 +1037,34 @@ public enum Chunks: StreamAperture {
 				return result
 			}
 
-			func readOctal(from bytes: [UInt8]) -> Int {
-				Int(String(data: Data(bytes), encoding: .utf8)!, radix: 8)!
+			func readOctal(from bytes: [UInt8]) throws -> Int {
+				try UnxipError.throw(.invalid, ifNil: String(data: Data(bytes), encoding: .utf8).map {
+					Int($0, radix: 8)
+				} ?? nil)
 			}
 
 			while true {
-				let magic = await read(size: 6)
+				let magic = try await read(size: 6)
 				// Yes, cpio.h really defines this global macro
-				precondition(magic.elementsEqual(MAGIC.utf8))
-				let dev = readOctal(from: await read(size: 6))
-				let ino = readOctal(from: await read(size: 6))
-				let mode = readOctal(from: await read(size: 6))
-				let _ = await read(size: 6) // uid
-				let _ = await read(size: 6) // gid
-				let _ = await read(size: 6) // nlink
-				let _ = await read(size: 6) // rdev
-				let _ = await read(size: 11) // mtime
-				let namesize = readOctal(from: await read(size: 6))
-				var filesize = readOctal(from: await read(size: 11))
-				let name = String(cString: await read(size: namesize))
+				try UnxipError.throw(.invalid, if: !magic.elementsEqual(MAGIC.utf8))
+				let dev = try readOctal(from: await read(size: 6))
+				let ino = try readOctal(from: await read(size: 6))
+				let mode = try readOctal(from: await read(size: 6))
+				let _ = try await read(size: 6) // uid
+				let _ = try await read(size: 6) // gid
+				let _ = try await read(size: 6) // nlink
+				let _ = try await read(size: 6) // rdev
+				let _ = try await read(size: 11) // mtime
+				let namesize = try readOctal(from: await read(size: 6))
+				var filesize = try readOctal(from: await read(size: 11))
+				let _name = try await read(size: namesize)
+				try UnxipError.throw(.invalid, if: _name.last != 0)
+				let name = String(cString: _name)
 				var file = File(dev: dev, ino: ino, mode: mode, name: name)
 
 				while filesize > 0 {
					if position >= chunk.buffer.endIndex {
-						chunk = try! await iterator.next()!
+						chunk = try await UnxipError.async_throw(.truncated, ifNil: await iterator.next())
 						position = chunk.buffer.startIndex
 					}
 					let end = chunk.buffer.index(position, offsetBy: filesize, limitedBy: chunk.buffer.endIndex) ?? chunk.buffer.endIndex
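
For reference, a minimal sketch (not part of the patch, using hypothetical helper functions) of how the new UnxipError helpers read at a call site, replacing the old precondition/force-unwrap style:

// Hypothetical call sites, mirroring the pattern used throughout the patch.
func checkMagic(_ header: [UInt8]) throws {
	let magic = "xar!".utf8
	// Was: precondition(header.prefix(magic.count).elementsEqual(magic))
	try UnxipError.throw(.invalid, if: !header.prefix(magic.count).elementsEqual(magic))
}

func requireChunk(_ chunk: [UInt8]?) throws -> [UInt8] {
	// Was: chunk! — a missing chunk now surfaces as UnxipError.truncated.
	try UnxipError.throw(.truncated, ifNil: chunk)
}

func parse(_ header: [UInt8]) {
	do {
		try checkMagic(header)
	} catch {
		// Callers can now recover from bad input instead of trapping.
		print("unxip failed: \(error)")
	}
}

Because the helper arguments are autoclosures, these call sites stay as compact as the original precondition calls while reporting .invalid or .truncated instead of crashing.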