Skip to content

Commit

Permalink
Update libunxip failure handling to throw errors
Browse files Browse the repository at this point in the history
  • Loading branch information
saagarjha committed Oct 29, 2023
1 parent 91c6da3 commit 946d97d
Showing 1 changed file with 105 additions and 69 deletions.
174 changes: 105 additions & 69 deletions unxip.swift
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,41 @@ extension option {

// MARK: - Public API

enum UnxipError: Error {
case truncated
case invalid

// https://github.com/apple/swift/issues/60318
static func async_throw<T>(_ error: @autoclosure () -> Self, ifNil expression: @autoclosure () async throws -> T?) async throws -> T {
if let value = try await expression() {
return value
} else {
throw error()
}
}

static func `throw`<T>(_ error: @autoclosure () -> Self, ifNil expression: @autoclosure () throws -> T?) throws -> T {
if let value = try expression() {
return value
} else {
throw error()
}
}

// https://github.com/apple/swift/issues/60318
static func async_throw(_ error: @autoclosure () -> Self, if expression: @autoclosure () async throws -> Bool) async throws {
if try await expression() {
throw error()
}
}

static func `throw`(_ error: @autoclosure () -> Self, if expression: @autoclosure () throws -> Bool) throws {
if try expression() {
throw error()
}
}
}

public struct DataReader<S: AsyncSequence> where S.Element: RandomAccessCollection, S.Element.Element == UInt8 {
public var position: Int = 0 {
didSet {
Expand All @@ -456,7 +491,7 @@ public struct DataReader<S: AsyncSequence> where S.Element: RandomAccessCollecti
self.iterator = data.makeAsyncIterator()
}

mutating func read(upTo n: Int) async throws -> [UInt8]? {
mutating func read(upTo n: Int) async throws -> [UInt8] {
var data = [UInt8]()
var index = 0
while index != n {
Expand All @@ -483,8 +518,8 @@ public struct DataReader<S: AsyncSequence> where S.Element: RandomAccessCollecti
}

mutating func read(_ n: Int) async throws -> [UInt8] {
let data = try await read(upTo: n)!
precondition(data.count == n)
let data = try await read(upTo: n)
try UnxipError.throw(.truncated, if: data.count != n)
return data
}

Expand Down Expand Up @@ -569,9 +604,9 @@ public struct Chunk: Sendable {
public let buffer: [UInt8]
public let decompressed: Bool

init(data: [UInt8], decompressedSize: Int?, lzmaDecompressor: ([UInt8], Int) -> [UInt8]) {
init(data: [UInt8], decompressedSize: Int?, lzmaDecompressor: ([UInt8], Int) throws -> [UInt8]) rethrows {
if let decompressedSize = decompressedSize {
buffer = lzmaDecompressor(data, decompressedSize)
buffer = try lzmaDecompressor(data, decompressedSize)
decompressed = true
} else {
buffer = data
Expand Down Expand Up @@ -789,49 +824,42 @@ public protocol StreamAperture {
static func transform(_: Input, options: Options?) -> Next.Input
}

extension StreamAperture {
static func async_precondition(_ condition: @autoclosure () async throws -> Bool) async rethrows {
let result = try await condition()
precondition(result)
}
}

protocol Decompressor {
static func decompress(data: [UInt8], decompressedSize: Int) -> [UInt8]
static func decompress(data: [UInt8], decompressedSize: Int) throws -> [UInt8]
}

public enum DefaultDecompressor {
enum Zlib: Decompressor {
static func decompress(data: [UInt8], decompressedSize: Int) -> [UInt8] {
return [UInt8](unsafeUninitializedCapacity: decompressedSize) { buffer, count in
static func decompress(data: [UInt8], decompressedSize: Int) throws -> [UInt8] {
return try [UInt8](unsafeUninitializedCapacity: decompressedSize) { buffer, count in
#if canImport(Compression)
let zlibSkip = 2 // Apple's decoder doesn't want to see CMF/FLG (see RFC 1950)
data[data.index(data.startIndex, offsetBy: zlibSkip)...].withUnsafeBufferPointer {
precondition(compression_decode_buffer(buffer.baseAddress!, decompressedSize, $0.baseAddress!, $0.count, nil, COMPRESSION_ZLIB) == decompressedSize)
try data[data.index(data.startIndex, offsetBy: zlibSkip)...].withUnsafeBufferPointer {
try UnxipError.throw(.invalid, if: compression_decode_buffer(buffer.baseAddress!, decompressedSize, $0.baseAddress!, $0.count, nil, COMPRESSION_ZLIB) != decompressedSize)
}
#else
var size = decompressedSize
precondition(uncompress(buffer.baseAddress!, &size, data, UInt(data.count)) == Z_OK)
precondition(size == decompressedSize)
try UnxipError.throw(.invalid, if: uncompress(buffer.baseAddress!, &size, data, UInt(data.count)) != Z_OK)
try UnxipError.throw(.invalid, if: size != decompressedSize)
#endif
count = decompressedSize
}
}
}

enum LZMA: Decompressor {
static func decompress(data: [UInt8], decompressedSize: Int) -> [UInt8] {
static func decompress(data: [UInt8], decompressedSize: Int) throws -> [UInt8] {
let magic = [0xfd] + "7zX".utf8
precondition(data.prefix(magic.count).elementsEqual(magic))
return [UInt8](unsafeUninitializedCapacity: decompressedSize) { buffer, count in
try UnxipError.throw(.invalid, if: !data.prefix(magic.count).elementsEqual(magic))
return try [UInt8](unsafeUninitializedCapacity: decompressedSize) { buffer, count in
#if canImport(Compression)
precondition(compression_decode_buffer(buffer.baseAddress!, decompressedSize, data, data.count, nil, COMPRESSION_LZMA) == decompressedSize)
try UnxipError.throw(.invalid, if: compression_decode_buffer(buffer.baseAddress!, decompressedSize, data, data.count, nil, COMPRESSION_LZMA) != decompressedSize)
#else
var memlimit = UInt64.max
var inIndex = 0
var outIndex = 0
precondition(lzma_stream_buffer_decode(&memlimit, 0, nil, data, &inIndex, data.count, buffer.baseAddress, &outIndex, decompressedSize) == LZMA_OK)
precondition(inIndex == data.count && outIndex == decompressedSize)
try UnxipError.throw(.invalid, if: lzma_stream_buffer_decode(&memlimit, 0, nil, data, &inIndex, data.count, buffer.baseAddress, &outIndex, decompressedSize) != LZMA_OK)
try UnxipError.throw(.invalid, if: inIndex != data.count || outIndex != decompressedSize)
#endif
count = decompressedSize
}
Expand All @@ -844,8 +872,8 @@ public enum XIP<S: AsyncSequence>: StreamAperture where S.Element: RandomAccessC
public typealias Next = Chunks

public struct Options {
let zlibDecompressor: ([UInt8], Int) -> [UInt8]
let lzmaDecompressor: ([UInt8], Int) -> [UInt8]
let zlibDecompressor: ([UInt8], Int) throws -> [UInt8]
let lzmaDecompressor: ([UInt8], Int) throws -> [UInt8]

init<Zlib: Decompressor, LZMA: Decompressor>(zlibDecompressor: Zlib.Type, lzmaDecompressor: LZMA.Type) {
self.zlibDecompressor = Zlib.decompress
Expand All @@ -861,17 +889,17 @@ public enum XIP<S: AsyncSequence>: StreamAperture where S.Element: RandomAccessC
let fileStart = file.position

let magic = "xar!".utf8
try await async_precondition(await file.read(magic.count).elementsEqual(magic))
try await UnxipError.async_throw(.invalid, if: await !file.read(magic.count).elementsEqual(magic))
let headerSize = try await file.read(UInt16.self)
try await async_precondition(await file.read(UInt16.self) == 1) // version
try await UnxipError.async_throw(.invalid, if: await file.read(UInt16.self) != 1) // version
let tocCompressedSize = try await file.read(UInt64.self)
let tocDecompressedSize = try await file.read(UInt64.self)
_ = try await file.read(UInt32.self) // checksum

_ = try await file.read(fileStart + Int(headerSize) - file.position)

let compressedTOC = try await file.read(Int(tocCompressedSize))
let toc = options.zlibDecompressor(compressedTOC, Int(tocDecompressedSize))
let toc = try options.zlibDecompressor(compressedTOC, Int(tocDecompressedSize))

#if canImport(UIKit)
let document = xmlReadMemory(toc, CInt(toc.count), "", nil, 0)
Expand All @@ -883,38 +911,42 @@ public enum XIP<S: AsyncSequence>: StreamAperture where S.Element: RandomAccessC
xmlXPathFreeContext(context)
}

func evaluateXPath(node: xmlNodePtr!, xpath: String) -> String {
let result = xmlXPathNodeEval(node, xpath, context)!
func evaluateXPath(node: xmlNodePtr!, xpath: String) throws -> String {
let result = try UnxipError.throw(.invalid, ifNil: xmlXPathNodeEval(node, xpath, context))
defer {
xmlXPathFreeObject(result)
}
precondition(result.pointee.type == XPATH_NODESET && result.pointee.nodesetval.pointee.nodeNr == 1)
try UnxipError.throw(.invalid, if: result.pointee.type != XPATH_NODESET || result.pointee.nodesetval.pointee.nodeNr != 1)
let string = xmlNodeListGetString(document, result.pointee.nodesetval.pointee.nodeTab.pointee!.pointee.children, 1)!
defer {
xmlFree(string)
}
return String(cString: string)
}

let result = xmlXPathEvalExpression("/xar/toc/file", context)!
let result = try UnxipError.throw(.invalid, ifNil: xmlXPathEvalExpression("/xar/toc/file", context))
defer {
xmlXPathFreeObject(result)
}
precondition(result.pointee.type == XPATH_NODESET)
let content = result.pointee.nodesetval.pointee.nodeTab[
(0..<Int(result.pointee.nodesetval.pointee.nodeNr)).first {
evaluateXPath(node: result.pointee.nodesetval.pointee.nodeTab[$0], xpath: "name") == "Content"
}!]

let contentOffset = Int(evaluateXPath(node: content, xpath: "data/offset"))!
let contentSize = Int(evaluateXPath(node: content, xpath: "data/length"))!
try UnxipError.throw(.invalid, if: result.pointee.type != XPATH_NODESET)
let content = try UnxipError.throw(
.invalid,
ifNil: result.pointee.nodesetval.pointee.nodeTab[
(0..<Int(result.pointee.nodesetval.pointee.nodeNr)).first {
try evaluateXPath(node: result.pointee.nodesetval.pointee.nodeTab[$0], xpath: "name") == "Content"
}!])

let contentOffset = try UnxipError.throw(.invalid, ifNil: Int(evaluateXPath(node: content, xpath: "data/offset")))
let contentSize = try UnxipError.throw(.invalid, ifNil: Int(evaluateXPath(node: content, xpath: "data/length")))
#else
let document = try! XMLDocument(data: Data(toc))
let content = try! document.nodes(forXPath: "/xar/toc/file").first {
try! $0.nodes(forXPath: "name").first!.stringValue! == "Content"
}!
let contentOffset = Int(try! content.nodes(forXPath: "data/offset").first!.stringValue!)!
let contentSize = Int(try! content.nodes(forXPath: "data/length").first!.stringValue!)!
let document = try XMLDocument(data: Data(toc))
let content = try UnxipError.throw(
.invalid,
ifNil: document.nodes(forXPath: "/xar/toc/file").first {
try $0.nodes(forXPath: "name").first?.stringValue == "Content"
})
let contentOffset = try UnxipError.throw(.invalid, ifNil: content.nodes(forXPath: "data/offset").first?.stringValue.map(Int.init) ?? nil)
let contentSize = try UnxipError.throw(.invalid, ifNil: content.nodes(forXPath: "data/length").first?.stringValue.map(Int.init) ?? nil)
#endif

_ = try await file.read(fileStart + Int(headerSize) + Int(tocCompressedSize) + contentOffset - file.position)
Expand All @@ -937,7 +969,7 @@ public enum XIP<S: AsyncSequence>: StreamAperture where S.Element: RandomAccessC
}

let magic = "pbzx".utf8
try await async_precondition(try await content.read(magic.count).elementsEqual(magic))
try await UnxipError.async_throw(.invalid, if: await !content.read(magic.count).elementsEqual(magic))
let chunkSize = try await content.read(UInt64.self)
var decompressedSize: UInt64 = 0
var previousYield: Task<Void, Error>?
Expand All @@ -960,7 +992,7 @@ public enum XIP<S: AsyncSequence>: StreamAperture where S.Element: RandomAccessC
let id = OSSignpostID(log: decompressionLog)
os_signpost(.begin, log: decompressionLog, name: "Decompress", signpostID: id, compressed ? "Starting %td (compressed size = %td)" : "Starting %td (uncompressed size = %td)", chunkNumber, compressedSize)
#endif
let chunk = Chunk(data: block, decompressedSize: compressed ? Int(decompressedSize) : nil, lzmaDecompressor: options.lzmaDecompressor)
let chunk = try Chunk(data: block, decompressedSize: compressed ? Int(decompressedSize) : nil, lzmaDecompressor: options.lzmaDecompressor)
#if PROFILING
os_signpost(.end, log: decompressionLog, name: "Decompress", signpostID: id, "Ended %td (decompressed size = %td)", chunkNumber, decompressedSize)
#endif
Expand Down Expand Up @@ -989,14 +1021,14 @@ public enum Chunks: StreamAperture {
let fileStream = BackpressureStream(backpressure: FileBackpressure(maxSize: 1_000_000_000), of: File.self)
Task {
var iterator = chunks.makeAsyncIterator()
var chunk = try! await iterator.next()!
var chunk = try await UnxipError.async_throw(.truncated, ifNil: await iterator.next())
var position = chunk.buffer.startIndex

func read(size: Int) async -> [UInt8] {
func read(size: Int) async throws -> [UInt8] {
var result = [UInt8]()
while result.count < size {
if position >= chunk.buffer.endIndex {
chunk = try! await iterator.next()!
chunk = try await UnxipError.async_throw(.truncated, ifNil: await iterator.next())
position = 0
}
result.append(chunk.buffer[chunk.buffer.startIndex + position])
Expand All @@ -1005,30 +1037,34 @@ public enum Chunks: StreamAperture {
return result
}

func readOctal(from bytes: [UInt8]) -> Int {
Int(String(data: Data(bytes), encoding: .utf8)!, radix: 8)!
func readOctal(from bytes: [UInt8]) throws -> Int {
try UnxipError.throw(.invalid, ifNil: String(data: Data(bytes), encoding: .utf8).map {
Int($0, radix: 8)
} ?? nil)
}

while true {
let magic = await read(size: 6)
let magic = try await read(size: 6)
// Yes, cpio.h really defines this global macro
precondition(magic.elementsEqual(MAGIC.utf8))
let dev = readOctal(from: await read(size: 6))
let ino = readOctal(from: await read(size: 6))
let mode = readOctal(from: await read(size: 6))
let _ = await read(size: 6) // uid
let _ = await read(size: 6) // gid
let _ = await read(size: 6) // nlink
let _ = await read(size: 6) // rdev
let _ = await read(size: 11) // mtime
let namesize = readOctal(from: await read(size: 6))
var filesize = readOctal(from: await read(size: 11))
let name = String(cString: await read(size: namesize))
try UnxipError.throw(.invalid, if: !magic.elementsEqual(MAGIC.utf8))
let dev = try readOctal(from: await read(size: 6))
let ino = try readOctal(from: await read(size: 6))
let mode = try readOctal(from: await read(size: 6))
let _ = try await read(size: 6) // uid
let _ = try await read(size: 6) // gid
let _ = try await read(size: 6) // nlink
let _ = try await read(size: 6) // rdev
let _ = try await read(size: 11) // mtime
let namesize = try readOctal(from: await read(size: 6))
var filesize = try readOctal(from: await read(size: 11))
let _name = try await read(size: namesize)
try UnxipError.throw(.invalid, if: _name.last != 0)
let name = String(cString: _name)
var file = File(dev: dev, ino: ino, mode: mode, name: name)

while filesize > 0 {
if position >= chunk.buffer.endIndex {
chunk = try! await iterator.next()!
chunk = try await UnxipError.async_throw(.truncated, ifNil: await iterator.next())
position = chunk.buffer.startIndex
}
let end = chunk.buffer.index(position, offsetBy: filesize, limitedBy: chunk.buffer.endIndex) ?? chunk.buffer.endIndex
Expand Down

0 comments on commit 946d97d

Please sign in to comment.