Skip to content

Commit

Permalink
Add support for reading from standard input
Browse files Browse the repository at this point in the history
  • Loading branch information
saagarjha committed Apr 17, 2023
1 parent 8f43e89 commit 3bb59a5
Showing 1 changed file with 168 additions and 94 deletions.
262 changes: 168 additions & 94 deletions unxip.swift
Original file line number Diff line number Diff line change
Expand Up @@ -347,27 +347,76 @@ actor ConcurrentStream<Element> {
}
}

final class Chunk: Sendable {
let buffer: UnsafeBufferPointer<UInt8>
let owned: Bool
struct DataStream<S: AsyncSequence> where S.Element: RandomAccessCollection, S.Element.Element == UInt8 {
var position: Int = 0 {
didSet {
if let cap = cap {
precondition(position <= cap)
}
}
}
var current: (S.Element.Index, S.Element)?
var iterator: S.AsyncIterator

init(buffer: UnsafeBufferPointer<UInt8>, decompressedSize: Int?) {
if let decompressedSize = decompressedSize {
let magic = [0xfd] + "7zX".utf8
precondition(buffer.prefix(magic.count).elementsEqual(magic))
let result = UnsafeMutableBufferPointer<UInt8>.allocate(capacity: Int(decompressedSize))
precondition(compression_decode_buffer(result.baseAddress!, result.count, buffer.baseAddress!, buffer.count, nil, COMPRESSION_LZMA) == decompressedSize)
self.buffer = UnsafeBufferPointer(result)
owned = true
} else {
self.buffer = buffer
owned = false
var cap: Int?

init(data: S) {
self.iterator = data.makeAsyncIterator()
}

/// Reads up to `n` bytes from the underlying async sequence of byte
/// collections.
///
/// Bytes are served from the partially consumed buffered element (`current`)
/// first; additional elements are pulled from `iterator` only as needed. If
/// the sequence ends before `n` bytes are accumulated, the bytes read so far
/// are returned — possibly fewer than `n`, possibly empty. As written the
/// return value is never `nil`; the optional type mirrors the iterator API.
mutating func read(upTo n: Int) async throws -> [UInt8]? {
	var data = [UInt8]()
	var index = 0
	while index != n {
		// Reuse the buffered element if it still has unread bytes; otherwise
		// fetch the next element from the sequence.
		let current: (S.Element.Index, S.Element)
		if let _current = self.current,
			_current.0 != _current.1.endIndex
		{
			current = _current
		} else {
			let new = try await iterator.next()
			guard let new = new else {
				// Stream exhausted: return the short (or empty) result.
				return data
			}
			current = (new.startIndex, new)
		}
		// Take as many bytes as the current element can supply without
		// exceeding the caller's remaining request.
		let count = min(n - index, current.1.distance(from: current.0, to: current.1.endIndex))
		let end = current.1.index(current.0, offsetBy: count)
		data.append(contentsOf: current.1[current.0..<end])
		// Remember how far into this element we have consumed.
		self.current = (end, current.1)
		index += count
		// `position` feeds the `cap` precondition in its `didSet`, guarding
		// against reading past a declared end of content.
		position += count
	}
	return data
}

/// Reads exactly `n` bytes; the process traps if the stream cannot supply
/// them all.
mutating func read(_ n: Int) async throws -> [UInt8] {
	guard let bytes = try await read(upTo: n), bytes.count == n else {
		preconditionFailure()
	}
	return bytes
}

deinit {
if owned {
buffer.deallocate()
/// Reads a binary integer of `Integer`'s in-memory size from the stream,
/// interpreting the bytes as big-endian (first byte is most significant).
mutating func read<Integer: BinaryInteger>(_ type: Integer.Type) async throws -> Integer {
	var value = Integer.zero
	for byte in try await read(MemoryLayout<Integer>.size) {
		value = value << 8 | Integer(byte)
	}
	return value
}
}

/// One chunk of archive content, held fully decompressed in memory.
struct Chunk: Sendable {
	// The chunk's bytes, after decompression if any was required.
	let buffer: [UInt8]

	/// Wraps `data`, LZMA-decoding it when a decompressed size is provided.
	/// - Parameters:
	///   - data: The raw chunk bytes as read from the archive.
	///   - decompressedSize: The expected size after LZMA decompression, or
	///     `nil` when `data` is already uncompressed and is used as-is.
	init(data: [UInt8], decompressedSize: Int?) {
		if let decompressedSize = decompressedSize {
			// Sanity-check the leading bytes of the XZ stream magic
			// (0xfd "7zX"…) before handing the data to the decoder.
			let magic = [0xfd] + "7zX".utf8
			precondition(data.prefix(magic.count).elementsEqual(magic))
			buffer = [UInt8](unsafeUninitializedCapacity: decompressedSize) { buffer, count in
				// compression_decode_buffer returns the number of bytes it
				// decoded; any result other than the expected size is treated
				// as corruption and traps.
				precondition(compression_decode_buffer(buffer.baseAddress!, decompressedSize, data, data.count, nil, COMPRESSION_LZMA) == decompressedSize)
				count = decompressedSize
			}
		} else {
			buffer = data
		}
	}
}
Expand All @@ -377,9 +426,7 @@ struct File {
let ino: Int
let mode: Int
let name: String
var data = [UnsafeBufferPointer<UInt8>]()
// For keeping the data alive
var chunks = [Chunk]()
var data = [ArraySlice<UInt8>]()

struct Identifier: Hashable {
let dev: Int
Expand Down Expand Up @@ -527,8 +574,8 @@ struct Options {
]
static let version = "2.0"

var input: URL
var output: URL?
var input: String?
var output: String?
var compress: Bool = true
var dryRun: Bool = false
var verbose: Bool = false
Expand Down Expand Up @@ -567,13 +614,8 @@ struct Options {
Self.printUsage(nominally: false)
}

self.input = URL(fileURLWithPath: input)

guard let output = arguments.dropFirst().first else {
return
}

self.output = URL(fileURLWithPath: output)
self.input = input == "-" ? nil : input
self.output = arguments.dropFirst().first
}

static func printVersion() -> Never {
Expand Down Expand Up @@ -608,46 +650,80 @@ struct Options {
struct Main {
static let options = Options()

static func read<Integer: BinaryInteger, Buffer: RandomAccessCollection>(_ type: Integer.Type, from buffer: inout Buffer) -> Integer where Buffer.Element == UInt8, Buffer.SubSequence == Buffer {
defer {
buffer = buffer[fromOffset: MemoryLayout<Integer>.size]
/// Async-compatible analogue of `precondition(_:)`: awaits the condition,
/// then traps if it did not hold.
static func async_precondition(_ condition: @autoclosure () async throws -> Bool) async rethrows {
	let satisfied = try await condition()
	precondition(satisfied)
}

static func dataStream(descriptor: CInt) -> DataStream<BackpressureStream<[UInt8], CountedBackpressure<[UInt8]>>> {
let stream = BackpressureStream(backpressure: CountedBackpressure(max: 16), of: [UInt8].self)
let io = DispatchIO(type: .stream, fileDescriptor: descriptor, queue: .main) { _ in
}
var result: Integer = 0
var iterator = buffer.makeIterator()
for _ in 0..<MemoryLayout<Integer>.size {
result <<= 8
result |= Integer(iterator.next()!)

Task {
while await withCheckedContinuation({ continuation in
var chunk = DispatchData.empty
io.read(offset: 0, length: Int(PIPE_SIZE * 16), queue: .main) { done, data, error in
guard error == 0 else {
stream.finish(throwing: NSError(domain: NSPOSIXErrorDomain, code: Int(error)))
continuation.resume(returning: false)
return
}

chunk.append(data!)

if done {
if chunk.isEmpty {
stream.finish()
continuation.resume(returning: false)
} else {
let chunk = chunk
Task {
await stream.yield(
[UInt8](unsafeUninitializedCapacity: chunk.count) { buffer, count in
_ = chunk.copyBytes(to: buffer, from: nil)
count = chunk.count
})
continuation.resume(returning: true)
}
}
}
}
}) {
}
}
return result

return DataStream(data: stream)
}

static func chunks(from content: UnsafeBufferPointer<UInt8>) -> BackpressureStream<Chunk, CountedBackpressure<Chunk>> {
static func chunks(from content: DataStream<some AsyncSequence>) -> BackpressureStream<Chunk, CountedBackpressure<Chunk>> {
let decompressionStream = ConcurrentStream<Void>(consumeResults: true)
let chunkStream = BackpressureStream(backpressure: CountedBackpressure(max: 16), of: Chunk.self)

// A consuming reference, but alas we can't express this right now
let _content = content
Task {
var remaining = content[fromOffset: 4]
let chunkSize = read(UInt64.self, from: &remaining)
var content = _content
let magic = "pbzx".utf8
try await async_precondition(try await content.read(magic.count).elementsEqual(magic))
let chunkSize = try await content.read(UInt64.self)
var decompressedSize: UInt64 = 0
var previousYield: Task<Void, Error>?

repeat {
decompressedSize = read(UInt64.self, from: &remaining)
let compressedSize = read(UInt64.self, from: &remaining)
decompressedSize = try await content.read(UInt64.self)
let compressedSize = try await content.read(UInt64.self)

let _remaining = remaining
let block = try await content.read(Int(compressedSize))
let _decompressedSize = decompressedSize
let _previousYield = previousYield
previousYield = await decompressionStream.addTask {
let remaining = _remaining
let decompressedSize = _decompressedSize
let previousYield = _previousYield
let chunk = Chunk(buffer: UnsafeBufferPointer(rebasing: remaining[fromOffset: 0, size: Int(compressedSize)]), decompressedSize: compressedSize == decompressedSize ? nil : Int(decompressedSize))
let chunk = Chunk(data: block, decompressedSize: compressedSize == chunkSize ? nil : Int(decompressedSize))
_ = await previousYield?.result
await chunkStream.yield(chunk)
}

remaining = remaining[fromOffset: Int(compressedSize)]
} while decompressedSize == chunkSize
await decompressionStream.finish()
}
Expand Down Expand Up @@ -702,8 +778,7 @@ struct Main {
position = 0
}
let size = min(filesize, chunk.buffer.endIndex - position)
file.chunks.append(chunk)
file.data.append(UnsafeBufferPointer(rebasing: chunk.buffer[fromOffset: position, size: size]))
file.data.append(chunk.buffer[fromOffset: position, size: size])
filesize -= size
position += size
}
Expand All @@ -719,7 +794,7 @@ struct Main {
return fileStream
}

static func parseContent(_ content: UnsafeBufferPointer<UInt8>) async throws {
static func parseContent(_ content: DataStream<some AsyncSequence>) async throws {
let taskStream = ConcurrentStream<Void>()
let compressionStream = ConcurrentStream<[UInt8]?>(consumeResults: true)

Expand Down Expand Up @@ -834,25 +909,21 @@ struct Main {
return
}

// pwritev requires the vector count to be positive
if file.data.count == 0 {
return
}

var vector = file.data.map {
iovec(iov_base: UnsafeMutableRawPointer(mutating: $0.baseAddress), iov_len: $0.count)
}
let total = file.data.map(\.count).reduce(0, +)
var written = 0

repeat {
var position = 0
outer: for data in file.data {
var written = 0
// TODO: handle partial writes smarter
written = pwritev(fd, &vector, CInt(vector.count), 0)
if written < 0 {
warn(-1, "writing chunk to")
break
}
} while written != total
repeat {
written = data.withUnsafeBytes {
pwrite(fd, $0.baseAddress, data.count, off_t(position))
}
if written < 0 {
warn(-1, "writing chunk to")
break outer
}
} while written != data.count
position += written
}
}
)
default:
Expand All @@ -867,18 +938,25 @@ struct Main {
}
}

static func locateContent(in file: UnsafeBufferPointer<UInt8>) -> UnsafeBufferPointer<UInt8> {
precondition(file.starts(with: "xar!".utf8)) // magic
var header = file[4...]
let headerSize = read(UInt16.self, from: &header)
precondition(read(UInt16.self, from: &header) == 1) // version
let tocCompressedSize = read(UInt64.self, from: &header)
let tocDecompressedSize = read(UInt64.self, from: &header)
_ = read(UInt32.self, from: &header) // checksum
static func locateContent(in file: inout DataStream<some AsyncSequence>) async throws {
let fileStart = file.position

let magic = "xar!".utf8
try await async_precondition(await file.read(magic.count).elementsEqual(magic))
let headerSize = try await file.read(UInt16.self)
try await async_precondition(await file.read(UInt16.self) == 1) // version
let tocCompressedSize = try await file.read(UInt64.self)
let tocDecompressedSize = try await file.read(UInt64.self)
_ = try await file.read(UInt32.self) // checksum

_ = try await file.read(fileStart + Int(headerSize) - file.position)

let zlibSkip = 2 // Apple's decoder doesn't want to see CMF/FLG (see RFC 1950)
_ = try await file.read(2)
var compressedTOC = try await file.read(Int(tocCompressedSize) - zlibSkip)

let toc = [UInt8](unsafeUninitializedCapacity: Int(tocDecompressedSize)) { buffer, count in
let zlibSkip = 2 // Apple's decoder doesn't want to see CMF/FLG (see RFC 1950)
count = compression_decode_buffer(buffer.baseAddress!, Int(tocDecompressedSize), file.baseAddress! + Int(headerSize) + zlibSkip, Int(tocCompressedSize) - zlibSkip, nil, COMPRESSION_ZLIB)
count = compression_decode_buffer(buffer.baseAddress!, Int(tocDecompressedSize), &compressedTOC, compressedTOC.count, nil, COMPRESSION_ZLIB)
precondition(count == Int(tocDecompressedSize))
}

Expand All @@ -888,30 +966,26 @@ struct Main {
}!
let contentOffset = Int(try! content.nodes(forXPath: "data/offset").first!.stringValue!)!
let contentSize = Int(try! content.nodes(forXPath: "data/length").first!.stringValue!)!
let contentBase = Int(headerSize) + Int(tocCompressedSize) + contentOffset

let slice = file[fromOffset: contentBase, size: contentSize]
precondition(slice.starts(with: "pbzx".utf8))
return UnsafeBufferPointer(rebasing: slice)
_ = try await file.read(fileStart + Int(headerSize) + Int(tocCompressedSize) + contentOffset - file.position)
file.cap = file.position + contentSize
}

static func main() async throws {
let handle = try FileHandle(forReadingFrom: options.input)
try handle.seekToEnd()
let length = Int(try handle.offset())
let file = UnsafeBufferPointer(start: mmap(nil, length, PROT_READ, MAP_PRIVATE, handle.fileDescriptor, 0).bindMemory(to: UInt8.self, capacity: length), count: length)
precondition(UnsafeMutableRawPointer(mutating: file.baseAddress) != MAP_FAILED)
defer {
munmap(UnsafeMutableRawPointer(mutating: file.baseAddress), length)
}
let handle =
try options.input.flatMap {
try FileHandle(forReadingFrom: URL(fileURLWithPath: $0))
} ?? FileHandle.standardInput

if let output = options.output {
guard chdir(output.path) == 0 else {
fputs("Failed to access output directory at \(output.path): \(String(cString: strerror(errno)))", stderr)
guard chdir(output) == 0 else {
fputs("Failed to access output directory at \(output): \(String(cString: strerror(errno)))", stderr)
exit(EXIT_FAILURE)
}
}

try await parseContent(locateContent(in: file))
var file = dataStream(descriptor: handle.fileDescriptor)
try await locateContent(in: &file)
try await parseContent(file)
}
}

0 comments on commit 3bb59a5

Please sign in to comment.