|
/// A tree describing the type of a collection of captures.
public enum CaptureStructure: Equatable {
  /// A single capture, optionally carrying a name.
  case atom(name: String? = nil)
  /// An array wrapping the given structure.
  indirect case array(CaptureStructure)
  /// An optional wrapping the given structure.
  indirect case optional(CaptureStructure)
  /// An ordered list of child structures.
  indirect case tuple([CaptureStructure])

  /// Builds a tuple structure from a variadic list of children.
  public static func tuple(_ children: CaptureStructure...) -> Self {
    return Self.tuple(children)
  }

  /// The structure with no captures, represented as a tuple with no children.
  public static var empty: Self {
    let noChildren: [CaptureStructure] = []
    return .tuple(noChildren)
  }
}
| 16 | + |
extension AST {
  /// The capture structure produced by this AST node.
  ///
  /// Alternations and concatenations combine the structures of their children
  /// in order; an alternation additionally wraps each resulting capture in an
  /// optional. Capturing groups contribute one atom followed by the captures
  /// of their child. Quantifications wrap each capture of their child in an
  /// optional (for `.zeroOrOne`) or in an array (for every other amount).
  public var captureStructure: CaptureStructure {
    // Note: This implementation could be more optimized.
    switch self {
    case .alternation(let alternation):
      assert(alternation.children.count > 1)
      // Gather the captures of every branch, then mark each one optional.
      var combined = CaptureStructure.empty
      for branch in alternation.children {
        combined = combined + branch.captureStructure
      }
      return combined.map(CaptureStructure.optional)

    case .concatenation(let concatenation):
      assert(concatenation.children.count > 1)
      var combined = CaptureStructure.empty
      for component in concatenation.children {
        combined = combined + component.captureStructure
      }
      return combined

    case .group(let group):
      let nested = group.child.captureStructure
      switch group.kind.value {
      case .capture:
        // The group's own capture comes first, followed by nested captures.
        return CaptureStructure.atom().concatenating(with: nested)
      case .namedCapture(let name):
        return CaptureStructure.atom(name: name.value)
          .concatenating(with: nested)
      default:
        return nested
      }

    case .quantification(let quantification):
      let wrap: (CaptureStructure) -> CaptureStructure
      if quantification.amount.value == .zeroOrOne {
        wrap = CaptureStructure.optional
      } else {
        wrap = CaptureStructure.array
      }
      return quantification.child.captureStructure.map(wrap)

    case .groupTransform:
      fatalError("Unreachable. Case will be removed later.")

    case .quote, .trivia, .atom, .customCharacterClass, .empty:
      return .empty
    }
  }
}
| 52 | + |
| 53 | +// MARK: - Combination and transformation |
| 54 | + |
extension CaptureStructure {
  /// Returns a capture structure by concatenating any tuples in `self` and
  /// `other`.
  ///
  /// The rules are:
  /// ```
  /// T + ()            ==> T
  /// () + T            ==> T
  /// (T...) + (U...)   ==> (T..., U...)
  /// (T...) + U        ==> (T..., U)
  /// T + (U...)        ==> (T, U...)
  /// T + U             ==> (T, U)
  /// ```
  func concatenating(with other: CaptureStructure) -> CaptureStructure {
    /// Returns the children of a tuple, or the structure itself as a
    /// one-element list when it is not a tuple.
    func flattened(_ structure: CaptureStructure) -> [CaptureStructure] {
      if case .tuple(let children) = structure {
        return children
      }
      return [structure]
    }

    // The empty tuple acts as an identity on either side.
    if case .tuple(let rhs) = other, rhs.isEmpty {
      return self
    }
    if case .tuple(let lhs) = self, lhs.isEmpty {
      return other
    }
    // Otherwise splice both sides into a single flat tuple.
    return .tuple(flattened(self) + flattened(other))
  }

  /// Concatenates two capture structures; see `concatenating(with:)`.
  static func + (
    lhs: CaptureStructure, rhs: CaptureStructure
  ) -> CaptureStructure {
    return lhs.concatenating(with: rhs)
  }

  /// Returns a capture structure by transforming any tuple element of `self`
  /// or transforming `self` directly if it is not a tuple.
  func map(
    _ transform: (CaptureStructure) -> CaptureStructure
  ) -> CaptureStructure {
    switch self {
    case .tuple(let children):
      return .tuple(children.map(transform))
    case .atom, .array, .optional:
      return transform(self)
    }
  }
}
| 98 | + |
| 99 | +// MARK: - Common properties |
| 100 | + |
extension CaptureStructure {
  /// A Boolean value indicating whether the structure does not contain any
  /// captures, i.e. whether it is a tuple with no elements.
  public var isEmpty: Bool {
    guard case .tuple(let elements) = self else {
      return false
    }
    return elements.isEmpty
  }
}
| 111 | + |
| 112 | +// MARK: - Serialization |
| 113 | + |
extension CaptureStructure {
  /// A byte-sized serialization code.
  ///
  /// The raw value of each case is the byte that appears in the serialized
  /// representation.
  private enum Code: UInt8 {
    case end = 0
    case atom = 1
    case namedAtom = 2
    case formArray = 3
    case formOptional = 4
    case beginTuple = 5
    case endTuple = 6
  }

  private typealias SerializationVersion = UInt16
  private static let currentSerializationVersion: SerializationVersion = 1

  /// Returns a buffer size computed from the UTF-8 code unit count of an
  /// input: the size of the version header, plus `inputUTF8CodeUnitCount`,
  /// plus one byte.
  ///
  /// - Parameter inputUTF8CodeUnitCount: The number of UTF-8 code units in the
  ///   input.
  /// - Returns: The computed byte count.
  public static func serializationBufferSize(
    forInputUTF8CodeUnitCount inputUTF8CodeUnitCount: Int
  ) -> Int {
    MemoryLayout<SerializationVersion>.stride + inputUTF8CodeUnitCount + 1
  }

  /// Encode the capture structure to the given buffer as a serialized
  /// representation.
  ///
  /// The encoding rules are as follows:
  /// ```
  /// encode(〚`T`〛) ==> <version>, 〚`T`〛, .end
  /// 〚`T` (atom)〛 ==> .atom
  /// 〚`name: T` (atom)〛 ==> .namedAtom, `name`, '\0'
  /// 〚`[T]`〛 ==> 〚`T`〛, .formArray
  /// 〚`T?`〛 ==> 〚`T`〛, .formOptional
  /// 〚`(T0, T1, ...)`〛 ==> .beginTuple, 〚`T0`〛, 〚`T1`〛, ..., .endTuple
  /// ```
  ///
  /// An empty structure is encoded as just `<version>, .end`.
  ///
  /// - Parameter buffer: A buffer large enough to hold the version, the
  ///   encoded structure and the trailing `.end` code.
  ///   `serializationBufferSize(forInputUTF8CodeUnitCount:)` is intended to
  ///   compute a sufficient size from the UTF-8 length of the regular
  ///   expression string that produced this capture structure.
  /// - Precondition: Every write must fit in `buffer`; running out of space
  ///   traps instead of writing out of bounds.
  public func encode(to buffer: UnsafeMutableRawBufferPointer) {
    let versionSize = MemoryLayout<SerializationVersion>.stride
    let codeSize = MemoryLayout<Code.RawValue>.stride
    precondition(
      buffer.count >= versionSize + codeSize,
      "Buffer must have room for the version and the end code")
    // Encode version. A byte-wise copy is used so that the buffer's base
    // address does not need to be aligned for `SerializationVersion`.
    withUnsafeBytes(of: Self.currentSerializationVersion) { versionBytes in
      buffer.copyMemory(from: versionBytes)
    }
    // Encode contents.
    var offset = versionSize
    /// Appends a code to the buffer, advancing the offset to the next position.
    func append(_ code: Code) {
      precondition(
        offset + codeSize <= buffer.count,
        "Buffer is too small for the serialized capture structure")
      // Store the raw value (not the enum itself) so the bytes written are
      // exactly what the decoder interprets via `Code(rawValue:)`.
      buffer.storeBytes(
        of: code.rawValue, toByteOffset: offset, as: Code.RawValue.self)
      offset += codeSize
    }
    /// Recursively encode the node to the buffer.
    func encodeNode(_ node: CaptureStructure) {
      switch node {
      // 〚`T` (atom)〛 ==> .atom
      case .atom(name: nil):
        append(.atom)
      // 〚`name: T` (atom)〛 ==> .namedAtom, `name`, '\0'
      case .atom(name: let name?):
        append(.namedAtom)
        // `utf8CString` includes the terminating NUL byte.
        let nameCString = name.utf8CString
        let nameEnd = offset + nameCString.count
        precondition(
          nameEnd <= buffer.count,
          "Buffer is too small for the serialized capture structure")
        let nameSlot = UnsafeMutableRawBufferPointer(
          rebasing: buffer[offset ..< nameEnd])
        nameCString.withUnsafeBytes(nameSlot.copyMemory(from:))
        offset = nameEnd
      // 〚`[T]`〛 ==> 〚`T`〛, .formArray
      case .array(let child):
        encodeNode(child)
        append(.formArray)
      // 〚`T?`〛 ==> 〚`T`〛, .formOptional
      case .optional(let child):
        encodeNode(child)
        append(.formOptional)
      // 〚`(T0, T1, ...)`〛 ==> .beginTuple, 〚`T0`〛, 〚`T1`〛, ..., .endTuple
      case .tuple(let children):
        append(.beginTuple)
        for child in children {
          encodeNode(child)
        }
        append(.endTuple)
      }
    }
    if !isEmpty {
      encodeNode(self)
    }
    append(.end)
  }

  /// Creates a capture structure by decoding a serialized representation from
  /// the given buffer.
  ///
  /// Decoding stops at the first `.end` code or at the end of the buffer,
  /// whichever comes first. The initializer returns `nil` when the buffer is
  /// too small to hold a version and one code, when the version does not match
  /// the current serialization version, or when the contents are malformed:
  /// an unknown code, a name without a NUL terminator inside the buffer, a
  /// `.formArray`/`.formOptional` with nothing to wrap, or unbalanced tuple
  /// markers.
  ///
  /// - Parameter buffer: The bytes to decode, in the layout produced by
  ///   `encode(to:)`.
  public init?(decoding buffer: UnsafeRawBufferPointer) {
    let versionSize = MemoryLayout<SerializationVersion>.stride
    let codeSize = MemoryLayout<Code.RawValue>.stride
    guard buffer.count >= versionSize + codeSize else {
      return nil
    }
    // Decode version. A byte-wise copy avoids requiring an aligned buffer.
    var version: SerializationVersion = 0
    withUnsafeMutableBytes(of: &version) { versionBytes in
      versionBytes.copyMemory(
        from: UnsafeRawBufferPointer(rebasing: buffer[0 ..< versionSize]))
    }
    guard version == Self.currentSerializationVersion else {
      return nil
    }
    // Decode contents. `scopes` is a stack of partially built tuples; the
    // last element is the scope currently receiving decoded structures.
    var scopes: [[CaptureStructure]] = [[]]
    var offset = versionSize
    decoding: while offset < buffer.count {
      guard let code = Code(rawValue: buffer[offset]) else {
        return nil
      }
      offset += codeSize
      switch code {
      case .end:
        break decoding
      case .atom:
        scopes[scopes.count - 1].append(.atom())
      case .namedAtom:
        // The name runs up to (not including) the next NUL byte, which must
        // lie inside the buffer.
        let remaining = buffer[offset ..< buffer.count]
        guard let terminator = remaining.firstIndex(of: 0) else {
          return nil
        }
        let name = String(
          decoding: buffer[offset ..< terminator], as: UTF8.self)
        // Advance by the bytes actually consumed, including the NUL.
        offset = terminator + 1
        scopes[scopes.count - 1].append(.atom(name: name))
      case .formArray:
        guard let element = scopes[scopes.count - 1].popLast() else {
          return nil
        }
        scopes[scopes.count - 1].append(.array(element))
      case .formOptional:
        guard let wrapped = scopes[scopes.count - 1].popLast() else {
          return nil
        }
        scopes[scopes.count - 1].append(.optional(wrapped))
      case .beginTuple:
        scopes.append([])
      case .endTuple:
        // The outermost scope is not a tuple and cannot be closed.
        guard scopes.count > 1 else {
          return nil
        }
        let children = scopes.removeLast()
        scopes[scopes.count - 1].append(.tuple(children))
      }
    }
    guard scopes.count == 1 else {
      return nil // Malformed serialization.
    }
    let topLevel = scopes[0]
    self = topLevel.count == 1 ? topLevel[0] : .tuple(topLevel)
  }
}
0 commit comments