diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 00000000..b24bb2da --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +bin/* linguist-language=Ruby diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 46959146..9f77688a 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -4,3 +4,7 @@ updates: directory: "/" schedule: interval: "daily" + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "daily" diff --git a/.github/workflows/auto-merge.yml b/.github/workflows/auto-merge.yml new file mode 100644 index 00000000..5468e6d0 --- /dev/null +++ b/.github/workflows/auto-merge.yml @@ -0,0 +1,22 @@ +name: Dependabot auto-merge +on: pull_request + +permissions: + contents: write + pull-requests: write + +jobs: + dependabot: + runs-on: ubuntu-latest + if: ${{ github.actor == 'dependabot[bot]' }} + steps: + - name: Dependabot metadata + id: metadata + uses: dependabot/fetch-metadata@v2.4.0 + with: + github-token: "${{ secrets.GITHUB_TOKEN }}" + - name: Enable auto-merge for Dependabot PRs + run: gh pr merge --auto --merge "$PR_URL" + env: + PR_URL: ${{github.event.pull_request.html_url}} + GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}} diff --git a/.github/workflows/gh-pages.yml b/.github/workflows/gh-pages.yml index 8adc1b45..a9fce541 100644 --- a/.github/workflows/gh-pages.yml +++ b/.github/workflows/gh-pages.yml @@ -1,26 +1,54 @@ -name: Github Pages (rdoc) -on: [push] +name: Deploy rdoc to GitHub Pages + +on: + push: + branches: + - main + + # Allows you to run this workflow manually from the Actions tab + workflow_dispatch: + +# Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages +permissions: + contents: read + pages: write + id-token: write + +# Allow one concurrent deployment +concurrency: + group: "pages" + cancel-in-progress: true + jobs: - build-and-deploy: + # Build job + build: runs-on: ubuntu-latest steps: - - name: Checkout 🛎️ - uses: actions/checkout@master - - - name: Set up Ruby 💎 + - name: Checkout + uses: actions/checkout@v6 + - name: Setup Pages + uses: actions/configure-pages@v5 + - name: Set up Ruby uses: ruby/setup-ruby@v1 with: bundler-cache: true ruby-version: '3.1' - - - name: Install rdoc and generate docs 🔧 + - name: Generate docs run: | gem install rdoc - rdoc --main README.md --op rdocs --exclude={Gemfile,Rakefile,"coverage/*","vendor/*","bin/*","test/*","tmp/*"} - cp -r doc rdocs/doc + rdoc --main README.md --op _site --exclude={Gemfile,Rakefile,"coverage/*","vendor/*","bin/*","test/*","tmp/*"} + cp -r doc _site/doc + - name: Upload artifact + uses: actions/upload-pages-artifact@v4 - - name: Deploy 🚀 - uses: peaceiris/actions-gh-pages@v3 - with: - github_token: ${{ secrets.GITHUB_TOKEN }} - publish_dir: ./rdocs + # Deployment job + deploy: + environment: + name: github-pages + url: ${{ steps.deployment.outputs.page_url }} + runs-on: ubuntu-latest + needs: build + steps: + - name: Deploy to GitHub Pages + id: deployment + uses: actions/deploy-pages@v4 diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 7f5ac15c..468591bd 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -1,21 +1,27 @@ name: Main + on: - push -- pull_request_target +- pull_request + jobs: ci: strategy: fail-fast: false matrix: ruby: - - '2.7' + - '2.7.0' - '3.0' - '3.1' - - head + - '3.2' + - '3.3' + - '3.4' + - truffleruby-head name: CI runs-on: ubuntu-latest env: CI: true + # TESTOPTS: --verbose steps: - uses: actions/checkout@master - uses: ruby/setup-ruby@v1 @@ -35,25 +41,8 @@ jobs: - uses: ruby/setup-ruby@v1 with: bundler-cache: true - ruby-version: '3.1' + ruby-version: '3.2' - name: Check run: | - bundle exec rake check + bundle exec rake stree:check bundle exec rubocop - - automerge: - name: AutoMerge - needs: - - ci - - check - runs-on: ubuntu-latest - if: github.event_name == 'pull_request_target' && github.actor == 'dependabot[bot]' - steps: - - uses: actions/github-script@v3 - with: - script: | - github.pulls.merge({ - owner: context.payload.repository.owner.login, - repo: context.payload.repository.name, - pull_number: context.payload.pull_request.number - }) diff --git a/.gitignore b/.gitignore index 2838e82b..3ce1e327 100644 --- a/.gitignore +++ b/.gitignore @@ -4,8 +4,10 @@ /coverage/ /pkg/ /rdocs/ +/sorbet/ /spec/reports/ /tmp/ /vendor/ test.rb +query.txt diff --git a/.rubocop.yml b/.rubocop.yml index 8c1bc99e..1b81a535 100644 --- a/.rubocop.yml +++ b/.rubocop.yml @@ -7,12 +7,30 @@ AllCops: SuggestExtensions: false TargetRubyVersion: 2.7 Exclude: - - '{bin,coverage,pkg,test/fixtures,vendor,tmp}/**/*' + - '{.git,.github,.ruby-lsp,bin,coverage,doc,pkg,sorbet,spec,test/fixtures,vendor,tmp}/**/*' - test.rb +Gemspec/DevelopmentDependencies: + Enabled: false + Layout/LineLength: Max: 80 +Lint/AmbiguousBlockAssociation: + Enabled: false + +Lint/AmbiguousOperatorPrecedence: + Enabled: false + +Lint/AmbiguousRange: + Enabled: false + +Lint/BooleanSymbol: + Enabled: false + +Lint/Debugger: + Enabled: false + Lint/DuplicateBranch: Enabled: false @@ -25,6 +43,21 @@ Lint/InterpolationCheck: Lint/MissingSuper: Enabled: false +Lint/NonLocalExitFromIterator: + Enabled: false + +Lint/RedundantRequireStatement: + Enabled: false + +Lint/RescueException: + Enabled: false + +Lint/SuppressedException: + Enabled: false + +Lint/UnderscorePrefixedVariableName: + Enabled: false + Lint/UnusedMethodArgument: AllowUnusedKeywordArguments: true @@ -40,33 +73,81 @@ Naming/MethodParameterName: Naming/RescuedExceptionsVariableName: PreferredName: error +Naming/VariableNumber: + Enabled: false + +Security/Eval: + Enabled: false + +Style/AccessorGrouping: + Enabled: false + +Style/Alias: + Enabled: false + +Style/CaseEquality: + Enabled: false + +Style/CaseLikeIf: + Enabled: false + +Style/ClassVars: + Enabled: false + +Style/CombinableLoops: + Enabled: false + +Style/DocumentDynamicEvalDefinition: + Enabled: false + +Style/Documentation: + Enabled: false + +Style/EndBlock: + Enabled: false + Style/ExplicitBlockArgument: Enabled: false Style/FormatString: - EnforcedStyle: percent + Enabled: false + +Style/FormatStringToken: + Enabled: false Style/GuardClause: Enabled: false +Style/HashLikeCase: + Enabled: false + Style/IdenticalConditionalBranches: Enabled: false Style/IfInsideElse: Enabled: false +Style/IfWithBooleanLiteralBranches: + Enabled: false + Style/KeywordParametersOrder: Enabled: false Style/MissingRespondToMissing: Enabled: false +Style/MultipleComparison: + Enabled: false + Style/MutableConstant: Enabled: false Style/NegatedIfElseCondition: Enabled: false +Style/Next: + Enabled: false + Style/NumericPredicate: Enabled: false @@ -76,5 +157,20 @@ Style/ParallelAssignment: Style/PerlBackrefs: Enabled: false +Style/RedundantArrayConstructor: + Enabled: false + +Style/RedundantParentheses: + Enabled: false + +Style/SafeNavigation: + Enabled: false + Style/SpecialGlobalVars: Enabled: false + +Style/StructInheritance: + Enabled: false + +Style/YodaExpression: + Enabled: false diff --git a/CHANGELOG.md b/CHANGELOG.md index 26e538c3..4ad42fc8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,424 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) a ## [Unreleased] +## [6.3.0] - 2025-07-16 + +### Added + +- The `--extension` command line option has been added to the CLI to specify what type of content is coming from stdin. +- The `--config` command line option has been added to the CLI to specify the path to the configuration file. + +### Changed + +- Fix formatting of character literals when single quotes is enabled. +- Pass ignore files option to the language server. +- Hash keys should remain unchanged when there are any omitted values in the hash. +- We now properly handle compilation errors in the parser. + +## [6.2.0] - 2023-09-20 + +### Added + +- Fix `WithScope` for destructured post arguments. + +### Changed + +- Always use `do`/`end` for multi-line lambdas. + +## [6.1.1] - 2023-03-21 + +### Changed + +- Fixed a bug where the call chain formatter was incorrectly looking at call messages. + +## [6.1.0] - 2023-03-20 + +### Added + +- The `stree ctags` command for generating ctags like `universal-ctags` or `ripper-tags` would. +- The `definedivar` YARV instruction has been added to match CRuby's implementation. +- We now generate better Sorbet RBI files for the nodes in the tree and the visitors. +- `SyntaxTree::Reflection.nodes` now includes the visitor method. + +### Changed + +- We now explicitly require `pp` in environments that need it. + +## [6.0.2] - 2023-03-03 + +### Added + +- The `WithScope` visitor mixin will now additionally report local variables defined through regular expression named captures. +- The `WithScope` visitor mixin now properly handles destructured splat arguments in required positions. + +### Changed + +- Fixed the AST output by adding blocks to `Command` and `CommandCall` nodes in the `FieldVisitor`. +- Fixed the location of lambda local variables (e.g., `->(; a) {}`). + +## [6.0.1] - 2023-02-26 + +### Added + +- The class declarations returned as the result of the indexing operation now have their superclass as a field. It is returned as an array of constants. If the superclass is anything other than a constant lookup, then it raises an error. + +### Changed + +- The `nesting` field on the results of the indexing operation is no longer a single flat array. Instead it is an array of arrays, where each array is a single nesting level. This more accurately reflects the nesting of the nodes in the tree. For example, `class Foo::Bar::Baz; end` would result in `[Foo, Bar, Baz]`, but that incorrectly implies that you can see constants at each of those levels. Now this would result in `[[Foo, Bar, Baz]]` to indicate that it can see either the top level or constants within the scope of `Foo::Bar::Baz` only. +- When formatting hashes that have omitted values and mixed hash rockets with labels, the formatting now maintains whichever delimiter was used in the source. This is because forcing the use of hash rockets with omitted values results in a syntax error. +- Handle the case where a bare hash is used after the `break`, `next`, or `return` keywords. Previously this would result in hash labels which is not valid syntax. Now it maintains the delimiters used in the source. +- The `<<` operator will now break on chained `<<` expressions. Previously it would always stay flat. + +## [6.0.0] - 2023-02-10 + +### Added + +- `SyntaxTree::BasicVisitor::visit_methods` has been added to allow you to check multiple visit methods inside of a block. There _was_ a method called `visit_methods` previously, but it was undocumented because it was meant as a private API. That method has been renamed to `valid_visit_methods`. +- `rake sorbet:rbi` has been added as a task within the repository to generate an RBI file corresponding to the nodes in the tree. This can be used to help aid consumers of Syntax Tree that are using Sorbet. +- `SyntaxTree::Reflection` has been added to allow you to get information about the nodes in the tree. It is not required by default, since it takes a small amount of time to parse `node.rb` and get all of the information. +- `SyntaxTree::Node#to_mermaid` has been added to allow you to generate a Mermaid diagram of the node and its children. This is useful for debugging and understanding the structure of the tree. +- `SyntaxTree::Translation` has been added as an experimental API to transform the Syntax Tree syntax tree into the syntax trees represented by the whitequark/parser and rubocop/rubocop-ast gems. + - `SyntaxTree::Translation.to_parser(node, buffer)` will return a `Parser::AST::Node` object. + - `SyntaxTree::Translation.to_rubocop_ast(node, buffer)` will return a `RuboCop::AST::Node` object. +- `SyntaxTree::index` and `SyntaxTree::index_file` have been added to allow you to get a list of all of the classes, modules, and methods defined in a given source string or file. +- Various convenience methods have been added: + - `SyntaxTree::format_file` - which calls format with the result of reading the file + - `SyntaxTree::format_node` - which formats the node directly + - `SyntaxTree::parse_file` - which calls parse with the result of reading the file + - `SyntaxTree::search_file` - which calls search with the result of reading the file + - `SyntaxTree::Node#start_char` - which is the same as calling `node.location.start_char` + - `SyntaxTree::Node#end_char` - which is the same as calling `node.location.end_char` +- `SyntaxTree::Assoc` nodes can now be formatted on their own without a parent hash node. +- `SyntaxTree::BlockVar#arg0?` has been added to check if a single required block parameter is present and would potentially be expanded. +- More experimental APIs have been added to the `SyntaxTree::YARV` module, including: + - `SyntaxTree::YARV::ControlFlowGraph` + - `SyntaxTree::YARV::DataFlowGraph` + - `SyntaxTree::YARV::SeaOfNodes` + +### Changed + +#### Major changes + +- *BREAKING* Updates to `WithEnvironment`: + - The `WithEnvironment` module has been renamed to `WithScope`. + - The `current_environment` method has been renamed to `current_scope`. + - The `with_current_environment` method has been removed. + - Previously scopes were always able to look up the tree, as in: `a = 1; def foo; a = 2; end` would see only a single `a` variable. That has been corrected. + - Previously accessing variables from inside of blocks that were not shadowed would mark them as being local to the block only. This has been correct. +- *BREAKING* Lots of constants moved out of `SyntaxTree::Visitor` to just `SyntaxTree`: + * `SyntaxTree::Visitor::FieldVisitor` is now `SyntaxTree::FieldVisitor` + * `SyntaxTree::Visitor::JSONVisitor` is now `SyntaxTree::JSONVisitor` + * `SyntaxTree::Visitor::MatchVisitor` is now `SyntaxTree::MatchVisitor` + * `SyntaxTree::Visitor::MutationVisitor` is now `SyntaxTree::MutationVisitor` + * `SyntaxTree::Visitor::PrettyPrintVisitor` is now `SyntaxTree::PrettyPrintVisitor` +- *BREAKING* Lots of constants are now autoloaded instead of required by default. This is only particularly relevant if you are in a forking environment and want to preload constants before forking for better memory usage with copy-on-write. +- *BREAKING* The `SyntaxTree::Statements#initialize` method no longer accepts a parser as the first argument. It now mirrors the other nodes in that it accepts its children and location. As a result, Syntax Tree nodes are now marshalable (and therefore can be sent over DRb). Previously the `Statements` node was not able to be marshaled because it held a reference to the parser. + +#### Minor changes + +- Many places where embedded documents (`=begin` to `=end`) were being treated as real comments have been fixed for formatting. +- Dynamic symbols in keyword pattern matching now have better formatting. +- Endless method definitions used to have a `SyntaxTree::BodyStmt` node that had any kind of node as its `statements` field. That has been corrected to be more consistent such that now going from `def_node.bodystmt.statements` always returns a `SyntaxTree::Statements` node, which is more consistent. +- We no longer assume that `fiddle` is able to be required, and only require it when it is actually needed. + +#### Tiny changes + +- Empty parameter nodes within blocks now have more accurate location information. +- Pinned variables have more correct location information now. (Previously the location was just around the variable itself, but it now includes the pin.) +- Array patterns in pattern matching now have more accurate location information when they are using parentheses with a constant present. +- Find patterns in pattern matching now have more correct location information for their `left` and `right` fields. +- Lots of nodes have more correct types in the comments on their attributes. +- The expressions `break foo.bar :baz do |qux| qux end` and `next fun foo do end` now correctly parses as a control-flow statement with a method call that has a block attached, as opposed to a control-flow statement with a block attached. +- The expression `self::a, b = 1, 2` would previously yield a `SyntaxTree::ConstPathField` node for the first element of the left-hand-side of the multiple assignment. Semantically this is incorrect, and we have fixed this to now be a `SyntaxTree::Field` node instead. + +## [5.3.0] - 2023-01-26 + +### Added + +- `#arity` has been added to `DefNode`, `BlockNode`, and `Params`. The method returns a range where the lower bound is the minimum and the upper bound is the maximum number of arguments that can be used to invoke that block/method definition. +- `#arity` has been added to `CallNode`, `Command`, `CommandCall`, and `VCall` nodes. The method returns the number of arguments included in the invocation. For splats, double splats, or argument forwards, this method returns `Float::INFINITY`. +- `SyntaxTree::index` and `SyntaxTree::index_file` APIs have been added to collect a list of classes, modules, and methods defined in a given source string or file, respectively. These APIs are experimental and subject to change. +- A `plugin/disable_auto_ternary` plugin has been added the disables the formatted that automatically changes permissable `if/else` clauses into ternaries. + +### Changed + +- Files are now only written from the CLI if the content of them changes, which should match watching files less chaotic. +- In the case that `rb_iseq_load` cannot be found, `Fiddle::DLError` is now rescued. +- Previously if there were invalid UTF-8 byte sequences after the `__END__` keyword the parser could potentially have crashed when parsing comments. This has been fixed. +- Previously there was special formatting for array literals that contained only variable references (either locals, method calls, or constants). For consistency, this has been removed and all array literals are now formatted the same way. + +## [5.2.0] - 2023-01-04 + +### Added + +- An experiment in evaluating compiled instruction sequences has been added to Syntax Tree. This is subject to change, so it will not be well documented or testing at the moment. It does not impact other functionality. + +### Changed + +- Empty parentheses on method calls will now be left in place. Previously they were left in place if the method being called looked like a constant. Now they are left in place for all method calls since the method name can mirror the name of a local variable, in which case the parentheses are required. + +## [5.1.0] - 2022-12-28 + +### Added + +- An experiment in working with instruction sequences has been added to Syntax Tree. This is subject to change, so it is not well documented or tested at the moment. It does not impact other functionality. +- You can now format at a different base layer of indentation. This is an optional third argument to `SyntaxTree::format`. + +### Changed + +- Support forwarding anonymous keyword arguments with `**`. +- The `BodyStmt` node now has a more correct location information. +- Ignore the `textDocument/documentColor` request coming into the language server to support clients that require that request be received. +- Do not attempt to convert `if..else` into ternaries if the predicate has a `Binary` node. +- Properly handle nested pattern matching when a rightward assignment is inside a `when` clause. + +## [5.0.1] - 2022-11-10 + +### Changed + +- Fix the plugin parsing on the CLI so that they are respected. + +## [5.0.0] - 2022-11-09 + +### Added + +- Every node now implements the `#copy(**)` method, which provides a copy of the node with the given attributes replaced. +- Every node now implements the `#===(other)` method, which checks if the given node matches the current node for all attributes except for comments and location. +- There is a new `SyntaxTree::Visitor::MutationVisitor` and its convenience method `SyntaxTree.mutation` which can be used to mutate a syntax tree. For details on how to use this visitor, check the README. + +### Changed + +- Nodes no longer have a `comments:` keyword on their initializers. By default, they initialize to an empty array. If you were previously passing comments into the initializer, you should now create the node first, then call `node.comments.concat` to add your comments. +- A lot of nodes have been folded into other nodes to make it easier to interact with the AST. This means that a lot of visit methods have been removed from the visitor and a lot of class definitions are no longer present. This also means that the nodes that received more function now have additional methods or fields to be able to differentiate them. Note that none of these changes have resulted in different formatting. The changes are listed below: + - `IfMod`, `UnlessMod`, `WhileMod`, `UntilMod` have been folded into `IfNode`, `UnlessNode`, `WhileNode`, and `UntilNode`. Each of the nodes now have a `modifier?` method to tell if it was originally in the modifier form. Consequently, the `visit_if_mod`, `visit_unless_mod`, `visit_while_mod`, and `visit_until_mod` methods have been removed from the visitor. + - `VarAlias` is no longer a node, and the `Alias` node has been renamed. They have been folded into the `AliasNode` node. The `AliasNode` node now has a `var_alias?` method to tell you if it is aliasing a global variable. Consequently, the `visit_var_alias` method has been removed from the visitor interface. If you were previously using this method, you should now use `visit_alias` instead. + - `Yield0` is no longer a node, and the `Yield` node has been renamed. They has been folded into the `YieldNode` node. The `YieldNode` node can now have its `arguments` field be `nil`. Consequently, the `visit_yield0` method has been removed from the visitor interface. If you were previously using this method, you should now use `visit_yield` instead. + - `FCall` is no longer a node, and the `Call` node has been renamed. They have been folded into the `CallNode` node. The `CallNode` node can now have its `receiver` and `operator` fields be `nil`. Consequently, the `visit_fcall` method has been removed from the visitor interface. If you were previously using this method, you should now use `visit_call` instead. + - `Dot2` and `Dot3` are no longer nodes. Instead they have become a single new `RangeNode` node. This node looks the same as `Dot2` and `Dot3`, except that it additionally has an `operator` field that contains the operator that created the node. Consequently, the `visit_dot2` and `visit_dot3` methods have been removed from the visitor interface. If you were previously using these methods, you should now use `visit_range` instead. + - `Def`, `DefEndless`, and `Defs` have been folded into the `DefNode` node. The `DefNode` node now has the `target` and `operator` fields which originally came from `Defs` which can both be `nil`. It also now has an `endless?` method on it to tell if the original node was found in the endless form. Finally the `bodystmt` field can now either be a `BodyStmt` as it was or any other kind of node since that was the body of the `DefEndless` node. The `visit_defs` and `visit_def_endless` methods on the visitor have therefore been removed. + - `DoBlock` and `BraceBlock` have now been folded into a `BlockNode` node. The `BlockNode` node now has a `keywords?` method on it that returns true if the block was constructed with the `do`..`end` keywords. The `visit_do_block` and `visit_brace_block` methods on the visitor have therefore been removed and replaced with the `visit_block` method. + - `Return0` is no longer a node, and the `Return` node has been renamed. They have been folded into the `ReturnNode` node. The `ReturnNode` node can now have its `arguments` field be `nil`. Consequently, the `visit_return0` method has been removed from the visitor interface. If you were previously using this method, you should now use `visit_return` instead. +- The `ArgsForward`, `Redo`, `Retry`, and `ZSuper` nodes no longer have `value` fields associated with them (which were always string literals corresponding to the keyword being used). +- The `Command` and `CommandCall` nodes now has `block` attributes on them. These attributes are used in the place where you would previously have had a `MethodAddBlock` structure. Where before the `MethodAddBlock` would have the command and block as its two children, you now just have one command node with the `block` attribute set to the `Block` node. +- Previously the formatting options were defined on an unfrozen hash called `SyntaxTree::Formatter::OPTIONS`. It was globally mutable, which made it impossible to reference from within a Ractor. As such, it has now been replaced with `SyntaxTree::Formatter::Options.new` which creates a new options object instance that can be modified without impacting global state. As a part of this change, formatting can now be performed from within a non-main Ractor. In order to check if the `plugin/single_quotes` plugin has been loaded, check if `SyntaxTree::Formatter::SINGLE_QUOTES` is defined. In order to check if the `plugin/trailing_comma` plugin has been loaded, check if `SyntaxTree::Formatter::TRAILING_COMMA` is defined. + +## [4.3.0] - 2022-10-28 + +### Added + +- [#183](https://github.com/ruby-syntax-tree/syntax_tree/pull/183) - Support TruffleRuby by eliminating internal pattern matching in some places and stopping some tests from running in other places. +- [#184](https://github.com/ruby-syntax-tree/syntax_tree/pull/184) - Remove internal pattern matching entirely. + +### Changed + +- [#183](https://github.com/ruby-syntax-tree/syntax_tree/pull/183) - Pattern matching works against dynamic symbols now. +- [#184](https://github.com/ruby-syntax-tree/syntax_tree/pull/184) - Exit with the correct exit status within the rake tasks. + +## [4.2.0] - 2022-10-25 + +### Added + +- [#182](https://github.com/ruby-syntax-tree/syntax_tree/pull/182) - The new `stree expr` CLI command will function similarly to the `stree match` CLI command except that it only outputs the first expression of the program. +- [#182](https://github.com/ruby-syntax-tree/syntax_tree/pull/182) - Added the `SyntaxTree::Pattern` class for compiling `in` expressions into procs. + +### Changed + +- [#182](https://github.com/ruby-syntax-tree/syntax_tree/pull/182) - Much more syntax is now supported by the search command. + +## [4.1.0] - 2022-10-24 + +### Added + +- [#180](https://github.com/ruby-syntax-tree/syntax_tree/pull/180) - The new `stree search` CLI command and the corresponding `SyntaxTree::Search` class for searching for a pattern against a given syntax tree. + +## [4.0.2] - 2022-10-19 + +### Changed + +- [#177](https://github.com/ruby-syntax-tree/syntax_tree/pull/177) - Fix up various other issues with the environment visitor addition. + +## [4.0.1] - 2022-10-18 + +### Changed + +- [#172](https://github.com/ruby-syntax-tree/syntax_tree/pull/172) - Use a refinement for `Symbol#name` addition so that other runtimes or tools don't get confused by its availability. +- [#173](https://github.com/ruby-syntax-tree/syntax_tree/pull/173) - Fix the `current_environment` usage to use the method instead of the instance variable. +- [#175](https://github.com/ruby-syntax-tree/syntax_tree/pull/175) - Update `prettier_print` requirement since v1.0.0 had a bug with `#breakable_return`. + +## [4.0.0] - 2022-10-17 + +### Added + +- [#169](https://github.com/ruby-syntax-tree/syntax_tree/pull/169) - You can now pass `--ignore-files` multiple times. +- [#157](https://github.com/ruby-syntax-tree/syntax_tree/pull/157) - We now support tracking local variable definitions throughout the visitor. This allows you to access scope information while visiting the tree. +- [#170](https://github.com/ruby-syntax-tree/syntax_tree/pull/170) - There is now an undocumented `STREE_FAST_FORMAT` environment variable checked when formatting. It has the effect of turning _off_ formatting call chains and ternaries in special ways. This improves performance quite a bit. I'm leaving it undocumented because ideally we just improve the performance as a whole. This is meant as a stopgap until we get there. + +### Changed + +- [#170](https://github.com/ruby-syntax-tree/syntax_tree/pull/170) - We now require at least version `1.0.0` of `prettier_print`. This is to take advantage of the first-class string support in the doc tree. +- [#170](https://github.com/ruby-syntax-tree/syntax_tree/pull/170) - Pattern matching has been removed from usage internal to this library (excluding the language server). This should hopefully enable runtimes that don't have pattern matching fully implemented yet (e.g., TruffleRuby) to run this gem. + +## [3.6.3] - 2022-10-11 + +### Changed + +- [#167](https://github.com/ruby-syntax-tree/syntax_tree/pull/167) - Change the error encountered when an `else` node does not have an associated `end` token to be a parse error. + +## [3.6.2] - 2022-10-04 + +### Changed + +- [#165](https://github.com/ruby-syntax-tree/syntax_tree/pull/165) - Conditionals (`if`/`unless`), loops (`for`/`while`/`until`) and lambdas all had issues when comments immediately succeeded the declaration of the node where the comment could potentially be dropped. That has now been fixed. +- [#166](https://github.com/ruby-syntax-tree/syntax_tree/pull/166) - Blocks can now be formatted even if they are the top node of the tree. Previously they were looking to their parent for some additional metadata, so we now handle the case where the parent is nil. + +## [3.6.1] - 2022-09-28 + +### Changed + +- [#161](https://github.com/ruby-syntax-tree/syntax_tree/pull/161) - Previously, we were checking if STDIN was a TTY to determine if there was content to be read. Instead, we now check if no filenames were passed, and in that case we attempt to read from STDIN. This should fix errors users were experiencing in non-TTY environments like CI. +- [#162](https://github.com/ruby-syntax-tree/syntax_tree/pull/162) - Parse errors shouldn't crash the language server anymore. + +## [3.6.0] - 2022-09-19 + +### Added + +- [#158](https://github.com/ruby-syntax-tree/syntax_tree/pull/158) - Support the ability to pass `--ignore-files` to the CLI and the Rake tasks to ignore a certain pattern of files. + +## [3.5.0] - 2022-08-26 + +### Added + +- [#148](https://github.com/ruby-syntax-tree/syntax_tree/pull/148) - Support Ruby 2.7.0 (previously we only supported back to 2.7.3). +- [#152](https://github.com/ruby-syntax-tree/syntax_tree/pull/152) - Support the `-e` inline script option for the `stree` CLI. + +### Changed + +- [#141](https://github.com/ruby-syntax-tree/syntax_tree/pull/141) - Use `q.format` for `SyntaxTree.format` so that the main node gets pushed onto the stack for checking parent nodes. +- [#147](https://github.com/ruby-syntax-tree/syntax_tree/pull/147) - Fix rightward assignment token management such that `in` and `=>` stay the same regardless of their context. + +## [3.4.0] - 2022-08-19 + +### Added + +- [#127](https://github.com/ruby-syntax-tree/syntax_tree/pull/127) - Allow the language server to handle other file extensions if it is activated for those extensions. +- [#133](https://github.com/ruby-syntax-tree/syntax_tree/pull/133) - Add documentation on supporting vim and neovim. + +### Changed + +- [#132](https://github.com/ruby-syntax-tree/syntax_tree/pull/132) - Provide better error messages when end quotes and end keywords are missing from tokens. +- [#134](https://github.com/ruby-syntax-tree/syntax_tree/pull/134) - Ensure the correct `end` keyword is getting removed by `begin..rescue` clauses. +- [#137](https://github.com/ruby-syntax-tree/syntax_tree/pull/137) - Better support regular expressions with no ending token. + +## [3.3.0] - 2022-08-02 + +### Added + +- [#123](https://github.com/ruby-syntax-tree/syntax_tree/pull/123) - Allow the rake tasks to configure print width. +- [#125](https://github.com/ruby-syntax-tree/syntax_tree/pull/125) - Add support for an `.streerc` file in the current working directory to configure the CLI. + +## [3.2.1] - 2022-07-22 + +### Changed + +- [#119](https://github.com/ruby-syntax-tree/syntax_tree/pull/119) - If there are conditionals in the assignment we cannot convert it to the modifier form. There was a bug where it would stop checking for assignment nodes if there were any optional child nodes. + +## [3.2.0] - 2022-07-19 + +### Added + +- [#116](https://github.com/ruby-syntax-tree/syntax_tree/pull/116) - Pass the `--print-width` option in the CLI to the language server. + +## [3.1.0] - 2022-07-19 + +### Added + +- [#115](https://github.com/ruby-syntax-tree/syntax_tree/pull/115) - Support the `--print-width` option in the CLI for the actions that support it. + +## [3.0.1] - 2022-07-15 + +### Changed + +- [#112](https://github.com/ruby-syntax-tree/syntax_tree/pull/112) - Fix parallel CLI execution by not short-circuiting with the `||` operator. + +## [3.0.0] - 2022-07-04 + +### Changed + +- [#102](https://github.com/ruby-syntax-tree/syntax_tree/issues/102) - Handle requests to the language server for files that do not yet exist on disk. + +### Removed + +- [#108](https://github.com/ruby-syntax-tree/syntax_tree/pull/108) - Remove old inlay hints code. + +## [2.9.0] - 2022-07-04 + +### Added + +- [#106](https://github.com/ruby-syntax-tree/syntax_tree/pull/106) - Add inlay hint support to match the LSP specification. + +## [2.8.0] - 2022-06-21 + +### Added + +- [#95](https://github.com/ruby-syntax-tree/syntax_tree/pull/95) - The `HeredocEnd` node has been added which effectively results in the ability to determine the location of the ending of a heredoc from source. +- [#99](https://github.com/ruby-syntax-tree/syntax_tree/pull/99) - The LSP now allows you to pass the same configuration options as the other CLI commands which allows formatting to be modified in the VSCode extension. +- [#100](https://github.com/ruby-syntax-tree/syntax_tree/pull/100) - The LSP now explicitly responds to the shutdown request so that VSCode never deadlocks. + +### Changed + +- [#96](https://github.com/ruby-syntax-tree/syntax_tree/pull/96) - The CLI now runs in parallel by default. There is a worker created for each processor on the running machine (as determined by `Etc.nprocessors`). +- [#97](https://github.com/ruby-syntax-tree/syntax_tree/pull/97) - Syntax Tree now handles the case where `DidYouMean` is not available for whatever reason, as well as handles the newer `detailed_message` API for errors. + +## [2.7.1] - 2022-05-25 + +### Added + +- [#92](https://github.com/ruby-syntax-tree/syntax_tree/pull/92) - (Internal) Drastically increase test coverage, including many more tests for the language server and the CLI. + +### Changed + +- [#87](https://github.com/ruby-syntax-tree/syntax_tree/pull/87) - Don't convert quotes on strings if it would result in more escapes. +- [#91](https://github.com/ruby-syntax-tree/syntax_tree/pull/91) - Always use `[]` with array patterns. There are just too many edge cases where you have to use them anyway. This simplifies the look and makes it more consistent. +- [#92](https://github.com/ruby-syntax-tree/syntax_tree/pull/92) - Remodel the currently shipped plugins such that they're modifying an options hash instead of overriding methods. This should make it easier for other plugins to reference the already loaded plugins, e.g., the RBS plugin referencing the quotes. +- [#92](https://github.com/ruby-syntax-tree/syntax_tree/pull/92) - Fix up the language server inlay hints to continue walking the tree once a pattern is found. This should increase useability. + +## [2.7.0] - 2022-05-19 + +### Added + +- [#88](https://github.com/ruby-syntax-tree/syntax_tree/pull/88) - Provide a `SyntaxTree::BasicVisitor` that has no visit methods implemented. + +### Changed + +- [#90](https://github.com/ruby-syntax-tree/syntax_tree/pull/90) - Provide better formatting for `SyntaxTree::AryPtn` when its nested inside a `SyntaxTree::RAssign`. + +## [2.6.0] - 2022-05-16 + +### Added + +- [#74](https://github.com/ruby-syntax-tree/syntax_tree/pull/74) - Add Rake test to run check and format commands. +- [#83](https://github.com/ruby-syntax-tree/syntax_tree/pull/83) - Add a trailing commas plugin. +- [#84](https://github.com/ruby-syntax-tree/syntax_tree/pull/84) - Handle lambda block-local variables. + +### Changed + +- [#85](https://github.com/ruby-syntax-tree/syntax_tree/pull/85) - Better handle trailing operators on command calls. + +## [2.5.0] - 2022-05-13 + +### Added + +- [#79](https://github.com/ruby-syntax-tree/syntax_tree/pull/79) - Support an optional `maxwidth` second argument to `SyntaxTree.format`. + +### Changed + +- [#77](https://github.com/ruby-syntax-tree/syntax_tree/pull/77) - Correct the pattern for checking if a dynamic symbol can be converted into a label as a hash key. +- [#72](https://github.com/ruby-syntax-tree/syntax_tree/pull/72) - Disallow conditionals with `not` without parentheses in the predicate from turning into a ternary. + +## [2.4.1] - 2022-05-10 + +- [#73](https://github.com/ruby-syntax-tree/syntax_tree/pull/73) - Fix nested hash patterns from accidentally adding a `then` to their output. + ## [2.4.0] - 2022-05-07 ### Added @@ -209,7 +627,43 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) a - 🎉 Initial release! 🎉 -[unreleased]: https://github.com/ruby-syntax-tree/syntax_tree/compare/v2.4.0...HEAD +[unreleased]: https://github.com/ruby-syntax-tree/syntax_tree/compare/v6.2.0...HEAD +[6.2.0]: https://github.com/ruby-syntax-tree/syntax_tree/compare/v6.1.1...v6.2.0 +[6.1.1]: https://github.com/ruby-syntax-tree/syntax_tree/compare/v6.1.0...v6.1.1 +[6.1.0]: https://github.com/ruby-syntax-tree/syntax_tree/compare/v6.0.2...v6.1.0 +[6.0.2]: https://github.com/ruby-syntax-tree/syntax_tree/compare/v6.0.1...v6.0.2 +[6.0.1]: https://github.com/ruby-syntax-tree/syntax_tree/compare/v6.0.0...v6.0.1 +[6.0.0]: https://github.com/ruby-syntax-tree/syntax_tree/compare/v5.3.0...v6.0.0 +[5.3.0]: https://github.com/ruby-syntax-tree/syntax_tree/compare/v5.2.0...v5.3.0 +[5.2.0]: https://github.com/ruby-syntax-tree/syntax_tree/compare/v5.1.0...v5.2.0 +[5.1.0]: https://github.com/ruby-syntax-tree/syntax_tree/compare/v5.0.1...v5.1.0 +[5.0.1]: https://github.com/ruby-syntax-tree/syntax_tree/compare/v5.0.0...v5.0.1 +[5.0.0]: https://github.com/ruby-syntax-tree/syntax_tree/compare/v4.3.0...v5.0.0 +[4.3.0]: https://github.com/ruby-syntax-tree/syntax_tree/compare/v4.2.0...v4.3.0 +[4.2.0]: https://github.com/ruby-syntax-tree/syntax_tree/compare/v4.1.0...v4.2.0 +[4.1.0]: https://github.com/ruby-syntax-tree/syntax_tree/compare/v4.0.2...v4.1.0 +[4.0.2]: https://github.com/ruby-syntax-tree/syntax_tree/compare/v4.0.1...v4.0.2 +[4.0.1]: https://github.com/ruby-syntax-tree/syntax_tree/compare/v4.0.0...v4.0.1 +[4.0.0]: https://github.com/ruby-syntax-tree/syntax_tree/compare/v3.6.3...v4.0.0 +[3.6.3]: https://github.com/ruby-syntax-tree/syntax_tree/compare/v3.6.2...v3.6.3 +[3.6.2]: https://github.com/ruby-syntax-tree/syntax_tree/compare/v3.6.1...v3.6.2 +[3.6.1]: https://github.com/ruby-syntax-tree/syntax_tree/compare/v3.6.0...v3.6.1 +[3.6.0]: https://github.com/ruby-syntax-tree/syntax_tree/compare/v3.5.0...v3.6.0 +[3.5.0]: https://github.com/ruby-syntax-tree/syntax_tree/compare/v3.4.0...v3.5.0 +[3.4.0]: https://github.com/ruby-syntax-tree/syntax_tree/compare/v3.3.0...v3.4.0 +[3.3.0]: https://github.com/ruby-syntax-tree/syntax_tree/compare/v3.2.1...v3.3.0 +[3.2.1]: https://github.com/ruby-syntax-tree/syntax_tree/compare/v3.2.0...v3.2.1 +[3.2.0]: https://github.com/ruby-syntax-tree/syntax_tree/compare/v3.1.0...v3.2.0 +[3.1.0]: https://github.com/ruby-syntax-tree/syntax_tree/compare/v3.0.1...v3.1.0 +[3.0.1]: https://github.com/ruby-syntax-tree/syntax_tree/compare/v3.0.0...v3.0.1 +[3.0.0]: https://github.com/ruby-syntax-tree/syntax_tree/compare/v2.9.0...v3.0.0 +[2.9.0]: https://github.com/ruby-syntax-tree/syntax_tree/compare/v2.8.0...v2.9.0 +[2.8.0]: https://github.com/ruby-syntax-tree/syntax_tree/compare/v2.7.1...v2.8.0 +[2.7.1]: https://github.com/ruby-syntax-tree/syntax_tree/compare/v2.7.0...v2.7.1 +[2.7.0]: https://github.com/ruby-syntax-tree/syntax_tree/compare/v2.6.0...v2.7.0 +[2.6.0]: https://github.com/ruby-syntax-tree/syntax_tree/compare/v2.5.0...v2.6.0 +[2.5.0]: https://github.com/ruby-syntax-tree/syntax_tree/compare/v2.4.1...v2.5.0 +[2.4.1]: https://github.com/ruby-syntax-tree/syntax_tree/compare/v2.4.0...v2.4.1 [2.4.0]: https://github.com/ruby-syntax-tree/syntax_tree/compare/v2.3.1...v2.4.0 [2.3.1]: https://github.com/ruby-syntax-tree/syntax_tree/compare/v2.3.0...v2.3.1 [2.3.0]: https://github.com/ruby-syntax-tree/syntax_tree/compare/v2.2.0...v2.3.0 diff --git a/Gemfile b/Gemfile index 73418542..b4252fb5 100644 --- a/Gemfile +++ b/Gemfile @@ -4,4 +4,4 @@ source "https://rubygems.org" gemspec -gem "rubocop" +gem "fiddle" diff --git a/Gemfile.lock b/Gemfile.lock index 8357fd92..a5d1d974 100644 --- a/Gemfile.lock +++ b/Gemfile.lock @@ -1,40 +1,53 @@ PATH remote: . specs: - syntax_tree (2.4.0) + syntax_tree (6.3.0) + prettier_print (>= 1.2.0) GEM remote: https://rubygems.org/ specs: - ast (2.4.2) - docile (1.4.0) - minitest (5.15.0) - parallel (1.22.1) - parser (3.1.2.0) + ast (2.4.3) + docile (1.4.1) + fiddle (1.1.8) + json (2.15.2) + language_server-protocol (3.17.0.5) + lint_roller (1.1.0) + minitest (5.26.1) + parallel (1.27.0) + parser (3.3.10.0) ast (~> 2.4.1) + racc + prettier_print (1.2.1) + prism (1.6.0) + racc (1.8.1) rainbow (3.1.1) - rake (13.0.6) - regexp_parser (2.3.1) - rexml (3.2.5) - rubocop (1.29.0) + rake (13.3.1) + regexp_parser (2.11.3) + rubocop (1.81.7) + json (~> 2.3) + language_server-protocol (~> 3.17.0.2) + lint_roller (~> 1.1.0) parallel (~> 1.10) - parser (>= 3.1.0.0) + parser (>= 3.3.0.2) rainbow (>= 2.2.2, < 4.0) - regexp_parser (>= 1.8, < 3.0) - rexml (>= 3.2.5, < 4.0) - rubocop-ast (>= 1.17.0, < 2.0) + regexp_parser (>= 2.9.3, < 3.0) + rubocop-ast (>= 1.47.1, < 2.0) ruby-progressbar (~> 1.7) - unicode-display_width (>= 1.4.0, < 3.0) - rubocop-ast (1.17.0) - parser (>= 3.1.1.0) - ruby-progressbar (1.11.0) - simplecov (0.21.2) + unicode-display_width (>= 2.4.0, < 4.0) + rubocop-ast (1.47.1) + parser (>= 3.3.7.2) + prism (~> 1.4) + ruby-progressbar (1.13.0) + simplecov (0.22.0) docile (~> 1.1) simplecov-html (~> 0.11) simplecov_json_formatter (~> 0.1) - simplecov-html (0.12.3) + simplecov-html (0.13.1) simplecov_json_formatter (0.1.4) - unicode-display_width (2.1.0) + unicode-display_width (3.2.0) + unicode-emoji (~> 4.1) + unicode-emoji (4.1.0) PLATFORMS arm64-darwin-21 @@ -45,6 +58,7 @@ PLATFORMS DEPENDENCIES bundler + fiddle minitest rake rubocop diff --git a/README.md b/README.md index 8955a310..c238620e 100644 --- a/README.md +++ b/README.md @@ -15,31 +15,49 @@ It is built with only standard library dependencies. It additionally ships with - [CLI](#cli) - [ast](#ast) - [check](#check) + - [ctags](#ctags) + - [expr](#expr) - [format](#format) - [json](#json) - [match](#match) + - [search](#search) - [write](#write) + - [Configuration](#configuration) + - [Globbing](#globbing) - [Library](#library) - [SyntaxTree.read(filepath)](#syntaxtreereadfilepath) - [SyntaxTree.parse(source)](#syntaxtreeparsesource) - [SyntaxTree.format(source)](#syntaxtreeformatsource) + - [SyntaxTree.mutation(&block)](#syntaxtreemutationblock) + - [SyntaxTree.search(source, query, &block)](#syntaxtreesearchsource-query-block) + - [SyntaxTree.index(source)](#syntaxtreeindexsource) - [Nodes](#nodes) - [child_nodes](#child_nodes) + - [copy(**attrs)](#copyattrs) - [Pattern matching](#pattern-matching) - [pretty_print(q)](#pretty_printq) - [to_json(*opts)](#to_jsonopts) - [format(q)](#formatq) + - [===(other)](#other) - [construct_keys](#construct_keys) - [Visitor](#visitor) - [visit_method](#visit_method) + - [visit_methods](#visit_methods) + - [BasicVisitor](#basicvisitor) + - [MutationVisitor](#mutationvisitor) + - [WithScope](#withscope) - [Language server](#language-server) - [textDocument/formatting](#textdocumentformatting) - - [textDocument/inlayHints](#textdocumentinlayhints) + - [textDocument/inlayHint](#textdocumentinlayhint) - [syntaxTree/visualizing](#syntaxtreevisualizing) -- [Plugins](#plugins) +- [Customization](#customization) + - [Ignoring code](#ignoring-code) + - [Plugins](#plugins) + - [Languages](#languages) - [Integration](#integration) + - [Rake](#rake) - [RuboCop](#rubocop) - - [VSCode](#vscode) + - [Editors](#editors) - [Contributing](#contributing) - [License](#license) @@ -77,7 +95,9 @@ bundle exec stree version ## CLI -Syntax Tree ships with the `stree` CLI, which can be used to inspect and manipulate Ruby code. Below are listed all of the commands built into the CLI that you can use. Note that for all commands that operate on files, you can also pass in content through STDIN. +Syntax Tree ships with the `stree` CLI, which can be used to inspect and manipulate Ruby code. Below are listed all of the commands built into the CLI that you can use. + +For many commands, file paths are accepted after the configuration options. For all of these commands, you can alternatively pass in content through STDIN or through the `-e` option to specify an inline script. ### ast @@ -114,9 +134,60 @@ If there are files with unformatted code, you will receive: The listed files did not match the expected format. ``` +To change the print width that you are checking against, specify the `--print-width` option, as in: + +```sh +stree check --print-width=100 path/to/file.rb +``` + +### ctags + +This command will output to stdout a set of tags suitable for usage with [ctags](https://github.com/universal-ctags/ctags). + +```sh +stree ctags path/to/file.rb +``` + +For a file containing the following Ruby code: + +```ruby +class Foo +end + +class Bar < Foo +end +``` + +you will receive: + +``` +!_TAG_FILE_FORMAT 2 /extended format; --format=1 will not append ;" to lines/ +!_TAG_FILE_SORTED 1 /0=unsorted, 1=sorted, 2=foldcase/ +Bar test.rb /^class Bar < Foo$/;" c inherits:Foo +Foo test.rb /^class Foo$/;" c +``` + +### expr + +This command will output a Ruby case-match expression that would match correctly against the first expression of the input. + +```sh +stree expr path/to/file.rb +``` + +For a file that contains `1 + 1`, you will receive: + +```ruby +SyntaxTree::Binary[ + left: SyntaxTree::Int[value: "1"], + operator: :+, + right: SyntaxTree::Int[value: "1"] +] +``` + ### format -This command will output the formatted version of each of the listed files. Importantly, it will not write that content back to the source files. It is meant to display the formatted version only. +This command will output the formatted version of each of the listed files to stdout. Importantly, it will not write that content back to the source files – for that, you want [`write`](#write). ```sh stree format path/to/file.rb @@ -128,6 +199,12 @@ For a file that contains `1 + 1`, you will receive: 1 + 1 ``` +To change the print width that you are formatting with, specify the `--print-width` option, as in: + +```sh +stree format --print-width=100 path/to/file.rb +``` + ### json This command will output a JSON representation of the syntax tree that is functionally equivalent to the input. This is mostly used in contexts where you need to access the tree from JavaScript or serialize it over a network. @@ -195,9 +272,32 @@ SyntaxTree::Program[ ] ``` +### search + +This command will search the given filepaths against the specified pattern to find nodes that match. The pattern is a Ruby pattern-matching expression that is matched against each node in the tree. It can optionally be loaded from a file if you specify a filepath as the pattern argument. + +```sh +stree search VarRef path/to/file.rb +``` + +For a file that contains `Foo + Bar` you will receive: + +``` +path/to/file.rb:1:0: Foo + Bar +path/to/file.rb:1:6: Foo + Bar +``` + +If you put `VarRef` into a file instead (for example, `query.txt`), you would instead run: + +```sh +stree search query.txt path/to/file.rb +``` + +Note that the output of the `match` CLI command creates a valid pattern that can be used as the input for this command. + ### write -This command will format the listed files and write that formatted version back to the source files. Note that this overwrites the original content, to be sure to be using a version control system. +This command will format the listed files and write that formatted version back to the source files. Note that this overwrites the original content, so be sure to be using a version control system. ```sh stree write path/to/file.rb @@ -209,13 +309,56 @@ This will list every file that is being formatted. It will output light gray if path/to/file.rb 0ms ``` +To change the print width that you are writing with, specify the `--print-width` option, as in: + +```sh +stree write --print-width=100 path/to/file.rb +``` + +To ignore certain files from a glob (in order to make it easier to specify the filepaths), you can pass the `--ignore-files` option as an additional glob, as in: + +```sh +stree write --ignore-files='db/**/*.rb' '**/*.rb' +``` + +### Configuration + +Any of the above CLI commands can also read configuration options from a `.streerc` file in the directory where the commands are executed. + +This should be a text file with each argument on a separate line. + +```txt +--print-width=100 +--plugins=plugin/trailing_comma +``` + +If this file is present, it will _always_ be used for CLI commands. You can also pass options from the command line as in the examples above. The options in the `.streerc` file are passed to the CLI first, then the arguments from the command line. In the case of exclusive options (e.g. `--print-width`), this means that the command line options override what's in the config file. In the case of options that can take multiple inputs (e.g. `--plugins`), the effect is additive. That is, the plugins passed from the command line will be loaded _in addition to_ the plugins in the config file. + +### Globbing + +When running commands with `stree`, it's common to pass in lists of files. For example: + +```sh +stree write 'lib/*.rb' 'test/*.rb' +``` + +The commands in the CLI accept any number of arguments. This means you _could_ pass `**/*.rb` (note the lack of quotes). This would make your shell expand out the file paths listed according to its own rules. (For example, [here](https://www.gnu.org/software/bash/manual/html_node/Filename-Expansion.html) are the rules for GNU bash.) + +However, it's recommended to instead use quotes, which means that Ruby is responsible for performing the file path expansion instead. This ensures a consistent experience across different environments and shells. The globs must follow the Ruby-specific globbing syntax as specified in the documentation for [Dir](https://ruby-doc.org/core-3.1.1/Dir.html#method-c-glob). + +Baked into this syntax is the ability to provide exceptions to file name patterns as well. For example, if you are in a Rails app and want to exclude files named `schema.rb` but write all other Ruby files, you can use the following syntax: + +```shell +stree write "**/{[!schema]*,*}.rb" +``` + ## Library -Syntax Tree can be used as a library to access the syntax tree underlying Ruby source code. +Syntax Tree can be used as a library to access the syntax tree underlying Ruby source code. The API is described below. For the full library documentation, see the [RDoc documentation](https://ruby-syntax-tree.github.io/syntax_tree/). ### SyntaxTree.read(filepath) -This function takes a filepath and returns a string associated with the content of that file. It is similar in functionality to `File.read`, except htat it takes into account Ruby-level file encoding (through magic comments at the top of the file). +This function takes a filepath and returns a string associated with the content of that file. It is similar in functionality to `File.read`, except that it takes into account Ruby-level file encoding (through magic comments at the top of the file). ### SyntaxTree.parse(source) @@ -223,7 +366,19 @@ This function takes an input string containing Ruby code and returns the syntax ### SyntaxTree.format(source) -This function takes an input string containing Ruby code, parses it into its underlying syntax tree, and formats it back out to a string. +This function takes an input string containing Ruby code, parses it into its underlying syntax tree, and formats it back out to a string. You can optionally pass a second argument to this method as well that is the maximum width to print. It defaults to `80`. + +### SyntaxTree.mutation(&block) + +This function yields a new mutation visitor to the block, and then returns the initialized visitor. It's effectively a shortcut for creating a `SyntaxTree::MutationVisitor` without having to remember the class name. For more information on that visitor, see the definition below. + +### SyntaxTree.search(source, query, &block) + +This function takes an input string containing Ruby code, an input string containing a valid Ruby `in` clause expression that can be used to match against nodes in the tree (can be generated using `stree expr`, `stree match`, or `Node#construct_keys`), and a block. Each node that matches the given query will be yielded to the block. The block will receive the node as its only argument. + +### SyntaxTree.index(source) + +This function takes an input string containing Ruby code and returns a list of all of the class declarations, module declarations, and method definitions within a file. Each of the entries also has access to its associated comments. This is useful for generating documentation or index information for a file to support something like go-to-definition. ## Nodes @@ -239,6 +394,20 @@ program.child_nodes.first.child_nodes.first # => (binary (int "1") :+ (int "1")) ``` +### copy(**attrs) + +This method returns a copy of the node, with the given attributes replaced. + +```ruby +program = SyntaxTree.parse("1 + 1") + +binary = program.statements.body.first +# => (binary (int "1") + (int "1")) + +binary.copy(operator: :-) +# => (binary (int "1") - (int "1")) +``` + ### Pattern matching Pattern matching is another way to descend the tree which is more specific than using `child_nodes`. Using Ruby's built-in pattern matching, you can extract the same information but be as specific about your constraints as you like. For example, with minimal constraints: @@ -296,6 +465,18 @@ formatter.output.join # => "1 + 1" ``` +### ===(other) + +Every node responds to `===`, which is used to check if the given other node matches all of the attributes of the current node except for location and comments. For example: + +```ruby +program1 = SyntaxTree.parse("1 + 1") +program2 = SyntaxTree.parse("1 + 1") + +program1 === program2 +# => true +``` + ### construct_keys Every node responds to `construct_keys`, which will return a string that contains a Ruby pattern-matching expression that could be used to match against the current node. It's meant to be used in tooling and through the CLI mostly. @@ -305,16 +486,16 @@ program = SyntaxTree.parse("1 + 1") puts program.construct_keys # SyntaxTree::Program[ -# statements: SyntaxTree::Statements[ -# body: [ -# SyntaxTree::Binary[ -# left: SyntaxTree::Int[value: "1"], -# operator: :+, -# right: SyntaxTree::Int[value: "1"] -# ] -# ] -# ] -# ] +# statements: SyntaxTree::Statements[ +# body: [ +# SyntaxTree::Binary[ +# left: SyntaxTree::Int[value: "1"], +# operator: :+, +# right: SyntaxTree::Int[value: "1"] +# ] +# ] +# ] +# ] ``` ## Visitor @@ -344,7 +525,7 @@ With visitors, you only define handlers for the nodes that you need. You can fin * call `visit(child)` with each child that you want to visit * call nothing if you're sure you don't want to descend further -There are a couple of visitors that ship with Syntax Tree that can be used as examples. They live in the [lib/syntax_tree/visitor](lib/syntax_tree/visitor) directory. +There are a couple of visitors that ship with Syntax Tree that can be used as examples. They live in the [lib/syntax_tree](lib/syntax_tree) directory. ### visit_method @@ -370,6 +551,96 @@ Did you mean? visit_binary from bin/console:8:in `
' ``` +### visit_methods + +Similar to `visit_method`, `visit_methods` also checks that methods defined are valid visit methods. This variation however accepts a block and checks that all methods defined within that block are valid visit methods. It's meant to be used like: + +```ruby +class ArithmeticVisitor < SyntaxTree::Visitor + visit_methods do + def visit_binary(node) + # ... + end + + def visit_int(node) + # ... + end + end +end +``` + +This is only checked when the methods are defined and does not impose any kind of runtime overhead after that. It is very useful for upgrading versions of Syntax Tree in case these methods names change. + +### BasicVisitor + +When you're defining your own visitor, by default it will walk down the tree even if you don't define `visit_*` methods. This is to ensure you can define a subset of the necessary methods in order to only interact with the nodes you're interested in. If you'd like to change this default to instead raise an error if you visit a node you haven't explicitly handled, you can instead inherit from `BasicVisitor`. + +```ruby +class MyVisitor < SyntaxTree::BasicVisitor + def visit_int(node) + # ... + end +end +``` + +The visitor defined above will error out unless it's only visiting a `SyntaxTree::Int` node. This is useful in a couple of ways, e.g., if you're trying to define a visitor to handle the whole tree but it's currently a work-in-progress. + +### MutationVisitor + +The `MutationVisitor` is a visitor that can be used to mutate the tree. It works by defining a default `visit_*` method that returns a copy of the given node with all of its attributes visited. This new node will replace the old node in the tree. Typically, you use the `#mutate` method on it to define mutations using patterns. For example: + +```ruby +# Create a new visitor +visitor = SyntaxTree::MutationVisitor.new + +# Specify that it should mutate If nodes with assignments in their predicates +visitor.mutate("IfNode[predicate: Assign | OpAssign]") do |node| + # Get the existing If's predicate node + predicate = node.predicate + + # Create a new predicate node that wraps the existing predicate node + # in parentheses + predicate = + SyntaxTree::Paren.new( + lparen: SyntaxTree::LParen.default, + contents: predicate, + location: predicate.location + ) + + # Return a copy of this node with the new predicate + node.copy(predicate: predicate) +end + +source = "if a = 1; end" +program = SyntaxTree.parse(source) + +SyntaxTree::Formatter.format(source, program) +# => "if a = 1\nend\n" + +SyntaxTree::Formatter.format(source, program.accept(visitor)) +# => "if (a = 1)\nend\n" +``` + +### WithScope + +The `WithScope` module can be included in visitors to automatically keep track of local variables and arguments defined inside each scope. A `current_scope` accessor is made available to the request, allowing it to find all usages and definitions of a local. + +```ruby +class MyVisitor < Visitor + prepend WithScope + + def visit_ident(node) + # find_local will return a Local for any local variables or arguments + # present in the current environment or nil if the identifier is not a local + local = current_scope.find_local(node) + + puts local.type # the type of the local (:variable or :argument) + puts local.definitions # the array of locations where this local is defined + puts local.usages # the array of locations where this local occurs + end +end +``` + ## Language server Syntax Tree additionally ships with a language server conforming to the [language server protocol](https://microsoft.github.io/language-server-protocol/). It can be invoked through the CLI by running: @@ -384,7 +655,7 @@ By default, the language server is relatively minimal, mostly meant to provide a As mentioned above, the language server responds to formatting requests with the formatted document. It typically responds on the order of tens of milliseconds, so it should be fast enough for any IDE. -### textDocument/inlayHints +### textDocument/inlayHint The language server also responds to the relatively new inlay hints request. This request allows the language server to define additional information that should exist in the source code as helpful hints to the developer. In our case we use it to display things like implicit parentheses. For example, if you had the following code: @@ -392,7 +663,7 @@ The language server also responds to the relatively new inlay hints request. Thi 1 + 2 * 3 ``` -Implicity, the `2 * 3` is going to be executed first because the `*` operator has higher precedence than the `+` operator. However, to ease mental overhead, our language server includes small parentheses to make this explicit, as in: +Implicitly, the `2 * 3` is going to be executed first because the `*` operator has higher precedence than the `+` operator. To ease mental overhead, our language server includes small parentheses to make this explicit, as in: ```ruby 1 + ₍2 * 3₎ @@ -402,15 +673,46 @@ Implicity, the `2 * 3` is going to be executed first because the `*` operator ha The language server additionally includes this custom request to return a textual representation of the syntax tree underlying the source code of a file. Language server clients can use this to (for example) open an additional tab with this information displayed. -## Plugins +## Customization -You can register additional configuration and additional languages that can flow through the same CLI with Syntax Tree's plugin system. When invoking the CLI, you pass through the list of plugins with the `--plugins` options to the commands that accept them. They should be a comma-delimited list. When the CLI first starts, it will require the files corresponding to those names. +There are multiple ways to customize Syntax Tree's behavior when parsing and formatting code. You can ignore certain sections of the source code, you can register plugins to provide custom formatting behavior, and you can register additional languages to be parsed and formatted. -### Configuration +### Ignoring code + +To ignore a section of source code, you can use a special `# stree-ignore` comment. This comment should be placed immediately above the code that you want to ignore. For example: + +```ruby +numbers = [ + 10000, + 20000, + 30000 +] +``` + +Normally the snippet above would be formatted as `numbers = [10_000, 20_000, 30_000]`. However, sometimes you want to keep the original formatting to improve readability or maintainability. In that case, you can put the ignore comment before it, as in: + +```ruby +# stree-ignore +numbers = [ + 10000, + 20000, + 30000 +] +``` + +Now when Syntax Tree goes to format that code, it will copy the source code exactly as it is, including the newlines and indentation. + +### Plugins -To register additional configuration, define a file somewhere in your load path named `syntax_tree/my_plugin` directory. Then when invoking the CLI, you will pass `--plugins=my_plugin`. That will get required. In this way, you can modify Syntax Tree however you would like. Some plugins ship with Syntax Tree itself. They are: +You can register additional customization that can flow through the same CLI with Syntax Tree's plugin system. When invoking the CLI, you pass through the list of plugins with the `--plugins` options to the commands that accept them. They should be a comma-delimited list. When the CLI first starts, it will require the files corresponding to those names. + +To register plugins, define a file somewhere in your load path named `syntax_tree/my_plugin`. Then when invoking the CLI, you will pass `--plugins=my_plugin`. To require multiple, separate them by a comma. In this way, you can modify Syntax Tree however you would like. Some plugins ship with Syntax Tree itself. They are: * `plugin/single_quotes` - This will change all of your string literals to use single quotes instead of the default double quotes. +* `plugin/trailing_comma` - This will put trailing commas into multiline array literals, hash literals, and method calls that can support trailing commas. +* `plugin/disable_auto_ternary` - This will prevent the automatic conversion of `if ... else` to ternary expressions. + +If you're using Syntax Tree as a library, you can require those files directly or manually pass those options to the formatter initializer through the `SyntaxTree::Formatter::Options` class. ### Languages @@ -428,13 +730,77 @@ In this case, whenever the CLI encounters a filepath that ends with the given ex Below are listed all of the "official" language plugins hosted under the same GitHub organization, which can be used as references for how to implement other plugins. +* [bf](https://github.com/ruby-syntax-tree/syntax_tree-bf) for the [brainf*** language](https://esolangs.org/wiki/Brainfuck). +* [css](https://github.com/ruby-syntax-tree/syntax_tree-css) for the [CSS stylesheet language](https://www.w3.org/Style/CSS/). * [haml](https://github.com/ruby-syntax-tree/syntax_tree-haml) for the [Haml template language](https://haml.info/). -* [json](https://github.com/ruby-syntax-tree/syntax_tree-json) for JSON. +* [json](https://github.com/ruby-syntax-tree/syntax_tree-json) for the [JSON notation language](https://www.json.org/). * [rbs](https://github.com/ruby-syntax-tree/syntax_tree-rbs) for the [RBS type language](https://github.com/ruby/rbs). +* [xml](https://github.com/ruby-syntax-tree/syntax_tree-xml) for the [XML markup language](https://www.w3.org/XML/). ## Integration -Syntax Tree's goal is to seemlessly integrate into your workflow. To this end, it provides a couple of additional tools beyond the CLI and the Ruby library. +Syntax Tree's goal is to seamlessly integrate into your workflow. To this end, it provides a couple of additional tools beyond the CLI and the Ruby library. + +### Rake + +Syntax Tree ships with the ability to define [rake](https://github.com/ruby/rake) tasks that will trigger runs of the CLI. To define them in your application, add the following configuration to your `Rakefile`: + +```ruby +require "syntax_tree/rake_tasks" +SyntaxTree::Rake::CheckTask.new +SyntaxTree::Rake::WriteTask.new +``` + +These calls will define `rake stree:check` and `rake stree:write` (equivalent to calling `stree check` and `stree write` with the CLI respectively). You can configure them by either passing arguments to the `new` method or by using a block. + +#### `name` + +If you'd like to change the default name of the rake task, you can pass that as the first argument, as in: + +```ruby +SyntaxTree::Rake::WriteTask.new(:format) +``` + +#### `source_files` + +If you wanted to configure Syntax Tree to check or write different files than the default (`lib/**/*.rb`), you can set the `source_files` field, as in: + +```ruby +SyntaxTree::Rake::WriteTask.new do |t| + t.source_files = FileList[%w[Gemfile Rakefile lib/**/*.rb test/**/*.rb]] +end +``` + +#### `ignore_files` + +If you want to ignore certain file patterns when running the command, you can pass the `ignore_files` option. This will be checked with `File.fnmatch?` against each filepath that the command would be run against. For example: + +```ruby +SyntaxTree::Rake::WriteTask.new do |t| + t.source_files = "**/*.rb" + t.ignore_files = "db/**/*.rb" +end +``` + +#### `print_width` + +If you want to use a different print width from the default (80), you can pass that to the `print_width` field, as in: + +```ruby +SyntaxTree::Rake::WriteTask.new do |t| + t.print_width = 100 +end +``` + +#### `plugins` + +If you're running Syntax Tree with plugins (either your own or the pre-built ones), you can pass that to the `plugins` field, as in: + +```ruby +SyntaxTree::Rake::WriteTask.new do |t| + t.plugins = ["plugin/single_quotes"] +end +``` ### RuboCop @@ -445,9 +811,12 @@ inherit_gem: syntax_tree: config/rubocop.yml ``` -### VSCode +### Editors -To integrate Syntax Tree into VSCode, you should use the official VSCode extension [ruby-syntax-tree/vscode-syntax-tree](https://github.com/ruby-syntax-tree/vscode-syntax-tree). +* [Neovim](https://neovim.io/) - [neovim/nvim-lspconfig](https://github.com/neovim/nvim-lspconfig). +* [Vim](https://www.vim.org/) - [dense-analysis/ale](https://github.com/dense-analysis/ale). +* [VSCode](https://code.visualstudio.com/) - [ruby-syntax-tree/vscode-syntax-tree](https://github.com/ruby-syntax-tree/vscode-syntax-tree). +* [Emacs](https://www.gnu.org/software/emacs/) - [emacs-format-all-the-code](https://github.com/lassik/emacs-format-all-the-code). ## Contributing diff --git a/Rakefile b/Rakefile index 4b3de39a..fb4f8847 100644 --- a/Rakefile +++ b/Rakefile @@ -2,6 +2,9 @@ require "bundler/gem_tasks" require "rake/testtask" +require "syntax_tree/rake_tasks" + +Rake.add_rakelib "tasks" Rake::TestTask.new(:test) do |t| t.libs << "test" @@ -11,24 +14,26 @@ end task default: :test -FILEPATHS = %w[ - Gemfile - Rakefile - syntax_tree.gemspec - lib/**/*.rb - test/*.rb -].freeze - -task :syntax_tree do - $:.unshift File.expand_path("lib", __dir__) - require "syntax_tree" - require "syntax_tree/cli" -end +configure = ->(task) do + task.source_files = + FileList[ + %w[ + Gemfile + Rakefile + syntax_tree.gemspec + lib/**/*.rb + tasks/*.rake + test/*.rb + ] + ] -task check: :syntax_tree do - exit SyntaxTree::CLI.run(["check"] + FILEPATHS) + # Since Syntax Tree supports back to Ruby 2.7.0, we need to make sure that we + # format our code such that it's compatible with that version. This actually + # has very little effect on the output, the only change at the moment is that + # Ruby < 2.7.3 didn't allow a newline before the closing brace of a hash + # pattern. + task.target_ruby_version = Gem::Version.new("2.7.0") end -task format: :syntax_tree do - exit SyntaxTree::CLI.run(["write"] + FILEPATHS) -end +SyntaxTree::Rake::CheckTask.new(&configure) +SyntaxTree::Rake::WriteTask.new(&configure) diff --git a/bin/console b/bin/console index 1c18bd62..6f35f1ec 100755 --- a/bin/console +++ b/bin/console @@ -3,6 +3,7 @@ require "bundler/setup" require "syntax_tree" +require "syntax_tree/reflection" require "irb" IRB.start(__FILE__) diff --git a/bin/profile b/bin/profile index 0a1b6ade..15bd28ae 100755 --- a/bin/profile +++ b/bin/profile @@ -6,22 +6,21 @@ require "bundler/inline" gemfile do source "https://rubygems.org" gem "stackprof" + gem "prettier_print" end $:.unshift(File.expand_path("../lib", __dir__)) require "syntax_tree" -GC.disable - StackProf.run(mode: :cpu, out: "tmp/profile.dump", raw: true) do - filepath = File.expand_path("../lib/syntax_tree/node.rb", __dir__) - SyntaxTree.format(File.read(filepath)) + Dir[File.join(RbConfig::CONFIG["libdir"], "**/*.rb")].each do |filepath| + SyntaxTree.format(SyntaxTree.read(filepath)) + end end -GC.enable - File.open("tmp/flamegraph.html", "w") do |file| report = Marshal.load(IO.binread("tmp/profile.dump")) + StackProf::Report.new(report).print_text StackProf::Report.new(report).print_d3_flamegraph(file) end diff --git a/lib/syntax_tree.rb b/lib/syntax_tree.rb index c5e2d913..90fb7fe7 100644 --- a/lib/syntax_tree.rb +++ b/lib/syntax_tree.rb @@ -1,59 +1,119 @@ # frozen_string_literal: true -require "json" +require "prettier_print" require "pp" -require "prettyprint" require "ripper" -require "stringio" -require_relative "syntax_tree/formatter" require_relative "syntax_tree/node" -require_relative "syntax_tree/parser" -require_relative "syntax_tree/version" +require_relative "syntax_tree/basic_visitor" require_relative "syntax_tree/visitor" -require_relative "syntax_tree/visitor/field_visitor" -require_relative "syntax_tree/visitor/json_visitor" -require_relative "syntax_tree/visitor/match_visitor" -require_relative "syntax_tree/visitor/pretty_print_visitor" - -# If PrettyPrint::Align isn't defined, then we haven't gotten the updated -# version of prettyprint. In that case we'll define our own. This is going to -# overwrite a bunch of methods, so silencing them as well. -unless PrettyPrint.const_defined?(:Align) - verbose = $VERBOSE - $VERBOSE = nil - - begin - require_relative "syntax_tree/prettyprint" - ensure - $VERBOSE = verbose - end -end -# When PP is running, it expects that everything that interacts with it is going -# to flow through PP.pp, since that's the main entry into the module from the -# perspective of its uses in core Ruby. In doing so, it calls guard_inspect_key -# at the top of the PP.pp method, which establishes some thread-local hashes to -# check for cycles in the pretty printed tree. This means that if you want to -# manually call pp on some object _before_ you have established these hashes, -# you're going to break everything. So this call ensures that those hashes have -# been set up before anything uses pp manually. -PP.new(+"", 0).guard_inspect_key {} +require_relative "syntax_tree/formatter" +require_relative "syntax_tree/parser" +require_relative "syntax_tree/version" # Syntax Tree is a suite of tools built on top of the internal CRuby parser. It # provides the ability to generate a syntax tree from source, as well as the # tools necessary to inspect and manipulate that syntax tree. It can be used to # build formatters, linters, language servers, and more. module SyntaxTree + # Syntax Tree the library has many features that aren't always used by the + # CLI. Requiring those features takes time, so we autoload as many constants + # as possible in order to keep the CLI as fast as possible. + + autoload :Database, "syntax_tree/database" + autoload :DSL, "syntax_tree/dsl" + autoload :FieldVisitor, "syntax_tree/field_visitor" + autoload :Index, "syntax_tree/index" + autoload :JSONVisitor, "syntax_tree/json_visitor" + autoload :LanguageServer, "syntax_tree/language_server" + autoload :MatchVisitor, "syntax_tree/match_visitor" + autoload :Mermaid, "syntax_tree/mermaid" + autoload :MermaidVisitor, "syntax_tree/mermaid_visitor" + autoload :MutationVisitor, "syntax_tree/mutation_visitor" + autoload :Pattern, "syntax_tree/pattern" + autoload :PrettyPrintVisitor, "syntax_tree/pretty_print_visitor" + autoload :Search, "syntax_tree/search" + autoload :WithScope, "syntax_tree/with_scope" + # This holds references to objects that respond to both #parse and #format # so that we can use them in the CLI. HANDLERS = {} HANDLERS.default = SyntaxTree - # This is a hook provided so that plugins can register themselves as the - # handler for a particular file type. - def self.register_handler(extension, handler) - HANDLERS[extension] = handler + # This is the default print width when formatting. It can be overridden in the + # CLI by passing the --print-width option or here in the API by passing the + # optional second argument to ::format. + DEFAULT_PRINT_WIDTH = 80 + + # This is the default ruby version that we're going to target for formatting. + # It shouldn't really be changed except in very niche circumstances. + DEFAULT_RUBY_VERSION = Formatter::SemanticVersion.new(RUBY_VERSION).freeze + + # The default indentation level for formatting. We allow changing this so + # that Syntax Tree can format arbitrary parts of a document. + DEFAULT_INDENTATION = 0 + + # Parses the given source and returns the formatted source. + def self.format( + source, + maxwidth = DEFAULT_PRINT_WIDTH, + base_indentation = DEFAULT_INDENTATION, + options: Formatter::Options.new + ) + format_node( + source, + parse(source), + maxwidth, + base_indentation, + options: options + ) + end + + # Parses the given file and returns the formatted source. + def self.format_file( + filepath, + maxwidth = DEFAULT_PRINT_WIDTH, + base_indentation = DEFAULT_INDENTATION, + options: Formatter::Options.new + ) + format(read(filepath), maxwidth, base_indentation, options: options) + end + + # Accepts a node in the tree and returns the formatted source. + def self.format_node( + source, + node, + maxwidth = DEFAULT_PRINT_WIDTH, + base_indentation = DEFAULT_INDENTATION, + options: Formatter::Options.new + ) + formatter = Formatter.new(source, [], maxwidth, options: options) + node.format(formatter) + + formatter.flush(base_indentation) + formatter.output.join + end + + # Indexes the given source code to return a list of all class, module, and + # method definitions. Used to quickly provide indexing capability for IDEs or + # documentation generation. + def self.index(source) + Index.index(source) + end + + # Indexes the given file to return a list of all class, module, and method + # definitions. Used to quickly provide indexing capability for IDEs or + # documentation generation. + def self.index_file(filepath) + Index.index_file(filepath) + end + + # A convenience method for creating a new mutation visitor. + def self.mutation + visitor = MutationVisitor.new + yield visitor + visitor end # Parses the given source and returns the syntax tree. @@ -63,13 +123,9 @@ def self.parse(source) response unless parser.error? end - # Parses the given source and returns the formatted source. - def self.format(source) - formatter = Formatter.new(source, []) - parse(source).format(formatter) - - formatter.flush - formatter.output.join + # Parses the given file and returns the syntax tree. + def self.parse_file(filepath) + parse(read(filepath)) end # Returns the source from the given filepath taking into account any potential @@ -86,4 +142,25 @@ def self.read(filepath) File.read(filepath, encoding: encoding) end + + # This is a hook provided so that plugins can register themselves as the + # handler for a particular file type. + def self.register_handler(extension, handler) + HANDLERS[extension] = handler + end + + # Searches through the given source using the given pattern and yields each + # node in the tree that matches the pattern to the given block. + def self.search(source, query, &block) + pattern = Pattern.new(query).compile + program = parse(source) + + Search.new(pattern).scan(program, &block) + end + + # Searches through the given file using the given pattern and yields each + # node in the tree that matches the pattern to the given block. + def self.search_file(filepath, query, &block) + search(read(filepath), query, &block) + end end diff --git a/lib/syntax_tree/basic_visitor.rb b/lib/syntax_tree/basic_visitor.rb new file mode 100644 index 00000000..bd8ea5f2 --- /dev/null +++ b/lib/syntax_tree/basic_visitor.rb @@ -0,0 +1,117 @@ +# frozen_string_literal: true + +module SyntaxTree + # BasicVisitor is the parent class of the Visitor class that provides the + # ability to walk down the tree. It does not define any handlers, so you + # should extend this class if you want your visitor to raise an error if you + # attempt to visit a node that you don't handle. + class BasicVisitor + # This is raised when you use the Visitor.visit_method method and it fails. + # It is correctable to through DidYouMean. + class VisitMethodError < StandardError + attr_reader :visit_method + + def initialize(visit_method) + @visit_method = visit_method + super("Invalid visit method: #{visit_method}") + end + end + + # This class is used by DidYouMean to offer corrections to invalid visit + # method names. + class VisitMethodChecker + attr_reader :visit_method + + def initialize(error) + @visit_method = error.visit_method + end + + def corrections + @corrections ||= + DidYouMean::SpellChecker.new( + dictionary: BasicVisitor.valid_visit_methods + ).correct(visit_method) + end + + # In some setups with Ruby you can turn off DidYouMean, so we're going to + # respect that setting here. + if defined?(DidYouMean.correct_error) + DidYouMean.correct_error(VisitMethodError, self) + end + end + + # This module is responsible for checking all of the methods defined within + # a given block to ensure that they are valid visit methods. + class VisitMethodsChecker < Module + Status = Struct.new(:checking) + + # This is the status of the checker. It's used to determine whether or not + # we should be checking the methods that are defined. It is kept as an + # instance variable so that it can be disabled later. + attr_reader :status + + def initialize + # We need the status to be an instance variable so that it can be + # accessed by the disable! method, but also a local variable so that it + # can be captured by the define_method block. + status = @status = Status.new(true) + + define_method(:method_added) do |name| + BasicVisitor.visit_method(name) if status.checking + super(name) + end + end + + def disable! + status.checking = false + end + end + + class << self + # This is the list of all of the valid visit methods. + def valid_visit_methods + @valid_visit_methods ||= + Visitor.instance_methods.grep(/^visit_(?!child_nodes)/) + end + + # This method is here to help folks write visitors. + # + # It's not always easy to ensure you're writing the correct method name in + # the visitor since it's perfectly valid to define methods that don't + # override these parent methods. + # + # If you use this method, you can ensure you're writing the correct method + # name. It will raise an error if the visit method you're defining isn't + # actually a method on the parent visitor. + def visit_method(method_name) + return if valid_visit_methods.include?(method_name) + + raise VisitMethodError, method_name + end + + # This method is here to help folks write visitors. + # + # Within the given block, every method that is defined will be checked to + # ensure it's a valid visit method using the BasicVisitor::visit_method + # method defined above. + def visit_methods + checker = VisitMethodsChecker.new + extend(checker) + yield + checker.disable! + end + end + + def visit(node) + node&.accept(self) + end + + def visit_all(nodes) + nodes.map { |node| visit(node) } + end + + def visit_child_nodes(node) + visit_all(node.child_nodes) + end + end +end diff --git a/lib/syntax_tree/cli.rb b/lib/syntax_tree/cli.rb index 64848ca6..e3bac8f1 100644 --- a/lib/syntax_tree/cli.rb +++ b/lib/syntax_tree/cli.rb @@ -1,5 +1,8 @@ # frozen_string_literal: true +require "etc" +require "optparse" + module SyntaxTree # Syntax Tree ships with the `stree` CLI, which can be used to inspect and # manipulate Ruby code. This module is responsible for powering that CLI. @@ -34,9 +37,82 @@ def self.yellow(value) end end + # An item of work that corresponds to a file to be processed. + class FileItem + attr_reader :filepath + + def initialize(filepath) + @filepath = filepath + end + + def handler + HANDLERS[File.extname(filepath)] + end + + def source + handler.read(filepath) + end + + def writable? + File.writable?(filepath) + end + end + + # An item of work that corresponds to a script content passed via the + # command line. + class ScriptItem + attr_reader :source + + def initialize(source, extension) + @source = source + @extension = extension + end + + def handler + HANDLERS[@extension] + end + + def filepath + :script + end + + def writable? + false + end + end + + # An item of work that correspond to the content passed in via stdin. + class STDINItem + def initialize(extension) + @extension = extension + end + + def handler + HANDLERS[@extension] + end + + def filepath + :stdin + end + + def source + $stdin.read + end + + def writable? + false + end + end + # The parent action class for the CLI that implements the basics. class Action - def run(handler, filepath, source) + attr_reader :options + + def initialize(options) + @options = options + end + + def run(item) end def success @@ -48,8 +124,8 @@ def failure # An action of the CLI that prints out the AST for the given source. class AST < Action - def run(handler, _filepath, source) - pp handler.parse(source) + def run(item) + pp item.handler.parse(item.source) end end @@ -59,10 +135,18 @@ class Check < Action class UnformattedError < StandardError end - def run(handler, filepath, source) - raise UnformattedError if source != handler.format(source) + def run(item) + source = item.source + formatted = + item.handler.format( + source, + options.print_width, + options: options.formatter_options + ) + + raise UnformattedError if source != formatted rescue StandardError - warn("[#{Color.yellow("warn")}] #{filepath}") + warn("[#{Color.yellow("warn")}] #{item.filepath}") raise end @@ -75,17 +159,117 @@ def failure end end + # An action of the CLI that generates ctags for the given source. + class CTags < Action + attr_reader :entries + + def initialize(options) + super + @entries = [] + end + + def run(item) + lines = item.source.lines(chomp: true) + + SyntaxTree + .index(item.source) + .each do |entry| + line = lines[entry.location.line - 1] + pattern = "/^#{line.gsub("\\", "\\\\\\\\").gsub("/", "\\/")}$/;\"" + + entries << case entry + when SyntaxTree::Index::ModuleDefinition + parts = [entry.name, item.filepath, pattern, "m"] + + if entry.nesting != [[entry.name]] + parts << "class:#{entry.nesting.flatten.tap(&:pop).join(".")}" + end + + parts.join("\t") + when SyntaxTree::Index::ClassDefinition + parts = [entry.name, item.filepath, pattern, "c"] + + if entry.nesting != [[entry.name]] + parts << "class:#{entry.nesting.flatten.tap(&:pop).join(".")}" + end + + unless entry.superclass.empty? + inherits = entry.superclass.join(".").delete_prefix(".") + parts << "inherits:#{inherits}" + end + + parts.join("\t") + when SyntaxTree::Index::MethodDefinition + parts = [entry.name, item.filepath, pattern, "f"] + + unless entry.nesting.empty? + parts << "class:#{entry.nesting.flatten.join(".")}" + end + + parts.join("\t") + when SyntaxTree::Index::SingletonMethodDefinition + parts = [entry.name, item.filepath, pattern, "F"] + + unless entry.nesting.empty? + parts << "class:#{entry.nesting.flatten.join(".")}" + end + + parts.join("\t") + when SyntaxTree::Index::AliasMethodDefinition + parts = [entry.name, item.filepath, pattern, "a"] + + unless entry.nesting.empty? + parts << "class:#{entry.nesting.flatten.join(".")}" + end + + parts.join("\t") + when SyntaxTree::Index::ConstantDefinition + parts = [entry.name, item.filepath, pattern, "C"] + + unless entry.nesting.empty? + parts << "class:#{entry.nesting.flatten.join(".")}" + end + + parts.join("\t") + end + end + end + + def success + puts(<<~HEADER) + !_TAG_FILE_FORMAT 2 /extended format; --format=1 will not append ;" to lines/ + !_TAG_FILE_SORTED 1 /0=unsorted, 1=sorted, 2=foldcase/ + HEADER + + entries.sort.each { |entry| puts(entry) } + end + end + # An action of the CLI that formats the source twice to check if the first # format is not idempotent. class Debug < Action class NonIdempotentFormatError < StandardError end - def run(handler, filepath, source) - warning = "[#{Color.yellow("warn")}] #{filepath}" - formatted = handler.format(source) - - raise NonIdempotentFormatError if formatted != handler.format(formatted) + def run(item) + handler = item.handler + warning = "[#{Color.yellow("warn")}] #{item.filepath}" + + formatted = + handler.format( + item.source, + options.print_width, + options: options.formatter_options + ) + + double_formatted = + handler.format( + formatted, + options.print_width, + options: options.formatter_options + ) + + raise NonIdempotentFormatError if formatted != double_formatted rescue StandardError warn(warning) raise @@ -102,25 +286,51 @@ def failure # An action of the CLI that prints out the doc tree IR for the given source. class Doc < Action - def run(handler, _filepath, source) - formatter = Formatter.new(source, []) - handler.parse(source).format(formatter) + def run(item) + source = item.source + + formatter_options = options.formatter_options + formatter = Formatter.new(source, [], options: formatter_options) + + item.handler.parse(source).format(formatter) pp formatter.groups.first end end + # An action of the CLI that outputs a pattern-matching Ruby expression that + # would match the first expression of the input given. + class Expr < Action + def run(item) + program = item.handler.parse(item.source) + + if (expressions = program.statements.body) && expressions.size == 1 + puts expressions.first.construct_keys + else + warn("The input to `stree expr` must be a single expression.") + exit(1) + end + end + end + # An action of the CLI that formats the input source and prints it out. class Format < Action - def run(handler, _filepath, source) - puts handler.format(source) + def run(item) + formatted = + item.handler.format( + item.source, + options.print_width, + options: options.formatter_options + ) + + puts formatted end end # An action of the CLI that converts the source into its equivalent JSON # representation. class Json < Action - def run(handler, _filepath, source) - object = Visitor::JSONVisitor.new.visit(handler.parse(source)) + def run(item) + object = item.handler.parse(item.source).accept(JSONVisitor.new) puts JSON.pretty_generate(object) end end @@ -128,27 +338,73 @@ def run(handler, _filepath, source) # An action of the CLI that outputs a pattern-matching Ruby expression that # would match the input given. class Match < Action - def run(handler, _filepath, source) - puts handler.parse(source).construct_keys + def run(item) + puts item.handler.parse(item.source).construct_keys + end + end + + # An action of the CLI that searches for the given pattern matching pattern + # in the given files. + class Search < Action + attr_reader :search + + def initialize(query) + query = File.read(query) if File.readable?(query) + pattern = + begin + Pattern.new(query).compile + rescue Pattern::CompilationError => error + warn(error.message) + exit(1) + end + + @search = SyntaxTree::Search.new(pattern) + end + + def run(item) + search.scan(item.handler.parse(item.source)) do |node| + location = node.location + line = location.start_line + + bold_range = + if line == location.end_line + location.start_column...location.end_column + else + location.start_column.. + end + + source = item.source.lines[line - 1].chomp + source[bold_range] = Color.bold(source[bold_range]).to_s + + puts("#{item.filepath}:#{line}:#{location.start_column}: #{source}") + end end end # An action of the CLI that formats the input source and writes the # formatted output back to the file. class Write < Action - def run(handler, filepath, source) - print filepath + def run(item) + filepath = item.filepath start = Time.now - formatted = handler.format(source) - File.write(filepath, formatted) if filepath != :stdin + source = item.source + formatted = + item.handler.format( + source, + options.print_width, + options: options.formatter_options + ) + changed = source != formatted + + File.write(filepath, formatted) if item.writable? && changed - color = source == formatted ? Color.gray(filepath) : filepath + color = changed ? filepath : Color.gray(filepath) delta = ((Time.now - start) * 1000).round - puts "\r#{color} #{delta}ms" + puts "#{color} #{delta}ms" rescue StandardError - puts "\r#{filepath}" + puts filepath raise end end @@ -156,127 +412,287 @@ def run(handler, filepath, source) # The help message displayed if the input arguments are not correctly # ordered or formatted. HELP = <<~HELP - #{Color.bold("stree ast [OPTIONS] [FILE]")} + #{Color.bold("stree ast [--plugins=...] [--print-width=NUMBER] [-e SCRIPT] FILE")} Print out the AST corresponding to the given files - #{Color.bold("stree check [OPTIONS] [FILE]")} + #{Color.bold("stree check [--plugins=...] [--print-width=NUMBER] [-e SCRIPT] FILE")} Check that the given files are formatted as syntax tree would format them - #{Color.bold("stree debug [OPTIONS] [FILE]")} + #{Color.bold("stree ctags [-e SCRIPT] FILE")} + Print out a ctags-compatible index of the given files + + #{Color.bold("stree debug [--plugins=...] [--print-width=NUMBER] [-e SCRIPT] FILE")} Check that the given files can be formatted idempotently - #{Color.bold("stree doc [OPTIONS] [FILE]")} + #{Color.bold("stree doc [--plugins=...] [-e SCRIPT] FILE")} Print out the doc tree that would be used to format the given files - #{Color.bold("stree format [OPTIONS] [FILE]")} + #{Color.bold("stree expr [-e SCRIPT] FILE")} + Print out a pattern-matching Ruby expression that would match the first + expression of the given files + + #{Color.bold("stree format [--plugins=...] [--print-width=NUMBER] [-e SCRIPT] FILE")} Print out the formatted version of the given files - #{Color.bold("stree json [OPTIONS] [FILE]")} + #{Color.bold("stree json [--plugins=...] [-e SCRIPT] FILE")} Print out the JSON representation of the given files - #{Color.bold("stree match [OPTIONS] [FILE]")} + #{Color.bold("stree match [--plugins=...] [-e SCRIPT] FILE")} Print out a pattern-matching Ruby expression that would match the given files #{Color.bold("stree help")} Display this help message - #{Color.bold("stree lsp")} + #{Color.bold("stree lsp [--plugins=...] [--print-width=NUMBER]")} Run syntax tree in language server mode + #{Color.bold("stree search PATTERN [-e SCRIPT] FILE")} + Search for the given pattern in the given files + #{Color.bold("stree version")} Output the current version of syntax tree - #{Color.bold("stree write [OPTIONS] [FILE]")} + #{Color.bold("stree write [--plugins=...] [--print-width=NUMBER] [-e SCRIPT] FILE")} Read, format, and write back the source of the given files - [OPTIONS] + --ignore-files=... + A glob pattern to ignore files when processing. This can be specified + multiple times to ignore multiple patterns. --plugins=... A comma-separated list of plugins to load. + + --print-width=... + The maximum line width to use when formatting. + + -e ... + Parse an inline string. + + --extension=... + A file extension matching the content passed in via STDIN or -e. + Defaults to '.rb'. + + --config=... + Path to a configuration file. Defaults to .streerc in the current + working directory. HELP + # This represents all of the options that can be passed to the CLI. It is + # responsible for parsing the list and then returning the file paths at the + # end. + class Options + attr_reader :ignore_files, + :plugins, + :print_width, + :scripts, + :extension, + :target_ruby_version + + def initialize + @ignore_files = [] + @plugins = [] + @print_width = DEFAULT_PRINT_WIDTH + @scripts = [] + @extension = ".rb" + @target_ruby_version = DEFAULT_RUBY_VERSION + end + + def formatter_options + @formatter_options ||= + Formatter::Options.new(target_ruby_version: target_ruby_version) + end + + def parse(arguments) + parser.parse!(arguments) + end + + private + + def parser + OptionParser.new do |opts| + # If there is a glob specified to ignore, then we'll track that here. + # Any of the CLI commands that operate on filenames will then ignore + # this set of files. + opts.on("--ignore-files=GLOB") do |glob| + @ignore_files << (glob.match(/\A'(.*)'\z/) ? $1 : glob) + end + + # If there are any plugins specified on the command line, then load + # them by requiring them here. We do this by transforming something + # like + # + # stree format --plugins=haml template.haml + # + # into + # + # require "syntax_tree/haml" + # + opts.on("--plugins=PLUGINS") do |plugins| + @plugins = plugins.split(",") + @plugins.each { |plugin| require "syntax_tree/#{plugin}" } + end + + # If there is a print width specified on the command line, then + # parse that out here and use it when formatting. + opts.on("--print-width=NUMBER", Integer) do |print_width| + @print_width = print_width + end + + # If there is a script specified on the command line, then parse + # it and add it to the list of scripts to run. + opts.on("-e SCRIPT") { |script| @scripts << script } + + # If there is a extension specified, then parse it and use it for + # STDIN and scripts. + opts.on("--extension=EXTENSION") do |extension| + # Both ".rb" and "rb" are going to work + @extension = ".#{extension.delete_prefix(".")}" + end + + # If there is a target ruby version specified on the command line, + # parse that out and use it when formatting. + opts.on("--target-ruby-version=VERSION") do |version| + @target_ruby_version = Formatter::SemanticVersion.new(version) + end + end + end + end + + # We allow a minimal configuration file to act as additional command line + # arguments to the CLI. Each line of the config file should be a new + # argument, as in: + # + # --plugins=plugin/single_quote + # --print-width=100 + # + # When invoking the CLI, we will read this config file and then parse it if + # it exists in the current working directory. + class ConfigFile + FILENAME = ".streerc" + + attr_reader :filepath + + def initialize(filepath = nil) + if filepath + if File.readable?(filepath) + @filepath = filepath + else + raise ArgumentError, "Invalid configuration file: #{filepath}" + end + else + @filepath = File.join(Dir.pwd, FILENAME) + end + end + + def exists? + File.readable?(filepath) + end + + def arguments + exists? ? File.readlines(filepath, chomp: true) : [] + end + end + class << self # Run the CLI over the given array of strings that make up the arguments # passed to the invocation. def run(argv) name, *arguments = argv - case name - when "help" - puts HELP - return 0 - when "lsp" - require "syntax_tree/language_server" - LanguageServer.new.run - return 0 - when "version" - puts SyntaxTree::VERSION - return 0 + # First, we need to check if there's a --config option specified + # so we can use the custom config file path. + config_filepath = nil + arguments.each_with_index do |arg, index| + if arg.start_with?("--config=") + config_filepath = arg.split("=", 2)[1] + arguments.delete_at(index) + break + elsif arg == "--config" && arguments[index + 1] + config_filepath = arguments[index + 1] + arguments.delete_at(index + 1) + arguments.delete_at(index) + break + end end + config_file = ConfigFile.new(config_filepath) + arguments = config_file.arguments.concat(arguments) + + options = Options.new + options.parse(arguments) + action = case name when "a", "ast" - AST.new + AST.new(options) when "c", "check" - Check.new + Check.new(options) + when "ctags" + CTags.new(options) when "debug" - Debug.new + Debug.new(options) when "doc" - Doc.new + Doc.new(options) + when "e", "expr" + Expr.new(options) + when "f", "format" + Format.new(options) + when "help" + puts HELP + return 0 when "j", "json" - Json.new + Json.new(options) + when "lsp" + LanguageServer.new( + print_width: options.print_width, + ignore_files: options.ignore_files + ).run + return 0 when "m", "match" - Match.new - when "f", "format" - Format.new + Match.new(options) + when "s", "search" + Search.new(arguments.shift) + when "version" + puts SyntaxTree::VERSION + return 0 when "w", "write" - Write.new + Write.new(options) else warn(HELP) return 1 end - # If we're not reading from stdin and the user didn't supply and - # filepaths to be read, then we exit with the usage message. - if $stdin.tty? && arguments.empty? - warn(HELP) - return 1 - end + # We're going to build up a queue of items to process. + queue = Queue.new - # If there are any plugins specified on the command line, then load them - # by requiring them here. We do this by transforming something like - # - # stree format --plugins=haml template.haml - # - # into - # - # require "syntax_tree/haml" - # - if arguments.first&.start_with?("--plugins=") - plugins = arguments.shift[/^--plugins=(.*)$/, 1] - plugins.split(",").each { |plugin| require "syntax_tree/#{plugin}" } - end + # If there are any arguments or scripts, then we'll add those to the + # queue. Otherwise we'll read the content off STDIN. + if arguments.any? || options.scripts.any? + arguments.each do |pattern| + Dir + .glob(pattern) + .each do |filepath| + # Skip past invalid filepaths by default. + next unless File.readable?(filepath) + + # Skip past any ignored filepaths. + next if options.ignore_files.any? { File.fnmatch(_1, filepath) } - # Track whether or not there are any errors from any of the files that - # we take action on so that we can properly clean up and exit. - errored = false - - each_file(arguments) do |handler, filepath, source| - action.run(handler, filepath, source) - rescue Parser::ParseError => error - warn("Error: #{error.message}") - highlight_error(error, source) - errored = true - rescue Check::UnformattedError, Debug::NonIdempotentFormatError - errored = true - rescue StandardError => error - warn(error.message) - warn(error.backtrace) - errored = true + # Otherwise, a new file item for the given filepath to the list. + queue << FileItem.new(filepath) + end + end + + options.scripts.each do |script| + queue << ScriptItem.new(script, options.extension) + end + else + queue << STDINItem.new(options.extension) end - if errored + # At the end, we're going to return whether or not this worker ever + # encountered an error. + if process_queue(queue, action) action.failure 1 else @@ -287,22 +703,49 @@ def run(argv) private - def each_file(arguments) - if $stdin.tty? || arguments.any? - arguments.each do |pattern| - Dir - .glob(pattern) - .each do |filepath| - next unless File.file?(filepath) - - handler = HANDLERS[File.extname(filepath)] - source = handler.read(filepath) - yield handler, filepath, source + # Processes each item in the queue with the given action. Returns whether + # or not any errors were encountered. + def process_queue(queue, action) + workers = + [Etc.nprocessors, queue.size].min.times.map do + Thread.new do + # Propagate errors in the worker threads up to the parent thread. + Thread.current.abort_on_exception = true + + # Track whether or not there are any errors from any of the files + # that we take action on so that we can properly clean up and + # exit. + errored = false + + # While there is still work left to do, shift off the queue and + # process the item. + until queue.empty? + item = queue.shift + errored |= + begin + action.run(item) + false + rescue Parser::ParseError => error + warn("Error: #{error.message}") + highlight_error(error, item.source) + true + rescue Check::UnformattedError, + Debug::NonIdempotentFormatError + true + rescue StandardError => error + warn(error.message) + warn(error.backtrace) + true + end end + + # At the end, we're going to return whether or not this worker + # ever encountered an error. + errored + end end - else - yield HANDLERS[".rb"], :stdin, $stdin.read - end + + workers.map(&:value).inject(:|) end # Highlights a snippet from a source and parse error. diff --git a/lib/syntax_tree/database.rb b/lib/syntax_tree/database.rb new file mode 100644 index 00000000..c9981f35 --- /dev/null +++ b/lib/syntax_tree/database.rb @@ -0,0 +1,331 @@ +# frozen_string_literal: true + +module SyntaxTree + # Provides the ability to index source files into a database, then query for + # the nodes. + module Database + class IndexingVisitor < SyntaxTree::FieldVisitor + attr_reader :database, :filepath, :node_id + + def initialize(database, filepath) + @database = database + @filepath = filepath + @node_id = nil + end + + private + + def comments(node) + end + + def field(name, value) + return unless value.is_a?(SyntaxTree::Node) + + binds = [node_id, visit(value), name] + database.execute(<<~SQL, binds) + INSERT INTO edges (from_id, to_id, name) + VALUES (?, ?, ?) + SQL + end + + def list(name, values) + values.each_with_index do |value, index| + binds = [node_id, visit(value), name, index] + database.execute(<<~SQL, binds) + INSERT INTO edges (from_id, to_id, name, list_index) + VALUES (?, ?, ?, ?) + SQL + end + end + + def node(node, _name) + previous = node_id + binds = [ + node.class.name.delete_prefix("SyntaxTree::"), + filepath, + node.location.start_line, + node.location.start_column + ] + + database.execute(<<~SQL, binds) + INSERT INTO nodes (type, path, line, column) + VALUES (?, ?, ?, ?) + SQL + + begin + @node_id = database.last_insert_row_id + yield + @node_id + ensure + @node_id = previous + end + end + + def text(name, value) + end + + def pairs(name, values) + values.each_with_index do |(key, value), index| + binds = [node_id, visit(key), "#{name}[0]", index] + database.execute(<<~SQL, binds) + INSERT INTO edges (from_id, to_id, name, list_index) + VALUES (?, ?, ?, ?) + SQL + + binds = [node_id, visit(value), "#{name}[1]", index] + database.execute(<<~SQL, binds) + INSERT INTO edges (from_id, to_id, name, list_index) + VALUES (?, ?, ?, ?) + SQL + end + end + end + + # Query for a specific type of node. + class TypeQuery + attr_reader :type + + def initialize(type) + @type = type + end + + def each(database, &block) + sql = "SELECT * FROM nodes WHERE type = ?" + database.execute(sql, type).each(&block) + end + end + + # Query for the attributes of a node, optionally also filtering by type. + class AttrQuery + attr_reader :type, :attrs + + def initialize(type, attrs) + @type = type + @attrs = attrs + end + + def each(database, &block) + joins = [] + binds = [] + + attrs.each do |name, query| + ids = query.each(database).map { |row| row[0] } + joins << <<~SQL + JOIN edges AS #{name} + ON #{name}.from_id = nodes.id + AND #{name}.name = ? + AND #{name}.to_id IN (#{(["?"] * ids.size).join(", ")}) + SQL + + binds.push(name).concat(ids) + end + + sql = +"SELECT nodes.* FROM nodes, edges #{joins.join(" ")}" + + if type + sql << " WHERE nodes.type = ?" + binds << type + end + + sql << " GROUP BY nodes.id" + database.execute(sql, binds).each(&block) + end + end + + # Query for the results of either query. + class OrQuery + attr_reader :left, :right + + def initialize(left, right) + @left = left + @right = right + end + + def each(database, &block) + left.each(database, &block) + right.each(database, &block) + end + end + + # A lazy query result. + class QueryResult + attr_reader :database, :query + + def initialize(database, query) + @database = database + @query = query + end + + def each(&block) + return enum_for(__method__) unless block_given? + query.each(database, &block) + end + end + + # A pattern matching expression that will be compiled into a query. + class Pattern + class CompilationError < StandardError + end + + attr_reader :query + + def initialize(query) + @query = query + end + + def compile + program = + begin + SyntaxTree.parse("case nil\nin #{query}\nend") + rescue Parser::ParseError + raise CompilationError, query + end + + compile_node(program.statements.body.first.consequent.pattern) + end + + private + + def compile_error(node) + raise CompilationError, PP.pp(node, +"").chomp + end + + # Shortcut for combining two queries into one that returns the results of + # if either query matches. + def combine_or(left, right) + OrQuery.new(left, right) + end + + # in foo | bar + def compile_binary(node) + compile_error(node) if node.operator != :| + + combine_or(compile_node(node.left), compile_node(node.right)) + end + + # in Ident + def compile_const(node) + value = node.value + + if SyntaxTree.const_defined?(value, false) + clazz = SyntaxTree.const_get(value) + TypeQuery.new(clazz.name.delete_prefix("SyntaxTree::")) + else + compile_error(node) + end + end + + # in SyntaxTree::Ident + def compile_const_path_ref(node) + parent = node.parent + if !parent.is_a?(SyntaxTree::VarRef) || + !parent.value.is_a?(SyntaxTree::Const) + compile_error(node) + end + + if parent.value.value == "SyntaxTree" + compile_node(node.constant) + else + compile_error(node) + end + end + + # in Ident[value: String] + def compile_hshptn(node) + compile_error(node) unless node.keyword_rest.nil? + + attrs = {} + node.keywords.each do |keyword, value| + compile_error(node) unless keyword.is_a?(SyntaxTree::Label) + attrs[keyword.value.chomp(":")] = compile_node(value) + end + + type = node.constant ? compile_node(node.constant).type : nil + AttrQuery.new(type, attrs) + end + + # in Foo + def compile_var_ref(node) + value = node.value + + if value.is_a?(SyntaxTree::Const) + compile_node(value) + else + compile_error(node) + end + end + + def compile_node(node) + case node + when SyntaxTree::Binary + compile_binary(node) + when SyntaxTree::Const + compile_const(node) + when SyntaxTree::ConstPathRef + compile_const_path_ref(node) + when SyntaxTree::HshPtn + compile_hshptn(node) + when SyntaxTree::VarRef + compile_var_ref(node) + else + compile_error(node) + end + end + end + + class Connection + attr_reader :raw_connection + + def initialize(raw_connection) + @raw_connection = raw_connection + end + + def execute(query, binds = []) + raw_connection.execute(query, binds) + end + + def index_file(filepath) + program = SyntaxTree.parse(SyntaxTree.read(filepath)) + program.accept(IndexingVisitor.new(self, filepath)) + end + + def last_insert_row_id + raw_connection.last_insert_row_id + end + + def prepare + raw_connection.execute(<<~SQL) + CREATE TABLE nodes ( + id integer primary key, + type varchar(20), + path varchar(200), + line integer, + column integer + ); + SQL + + raw_connection.execute(<<~SQL) + CREATE INDEX nodes_type ON nodes (type); + SQL + + raw_connection.execute(<<~SQL) + CREATE TABLE edges ( + id integer primary key, + from_id integer, + to_id integer, + name varchar(20), + list_index integer + ); + SQL + + raw_connection.execute(<<~SQL) + CREATE INDEX edges_name ON edges (name); + SQL + end + + def search(query) + QueryResult.new(self, Pattern.new(query).compile) + end + end + end +end diff --git a/lib/syntax_tree/dsl.rb b/lib/syntax_tree/dsl.rb new file mode 100644 index 00000000..4506aa04 --- /dev/null +++ b/lib/syntax_tree/dsl.rb @@ -0,0 +1,1016 @@ +# frozen_string_literal: true + +module SyntaxTree + # This module provides shortcuts for creating AST nodes. + module DSL + # Create a new BEGINBlock node. + def BEGINBlock(lbrace, statements) + BEGINBlock.new( + lbrace: lbrace, + statements: statements, + location: Location.default + ) + end + + # Create a new CHAR node. + def CHAR(value) + CHAR.new(value: value, location: Location.default) + end + + # Create a new ENDBlock node. + def ENDBlock(lbrace, statements) + ENDBlock.new( + lbrace: lbrace, + statements: statements, + location: Location.default + ) + end + + # Create a new EndContent node. + def EndContent(value) + EndContent.new(value: value, location: Location.default) + end + + # Create a new AliasNode node. + def AliasNode(left, right) + AliasNode.new(left: left, right: right, location: Location.default) + end + + # Create a new ARef node. + def ARef(collection, index) + ARef.new(collection: collection, index: index, location: Location.default) + end + + # Create a new ARefField node. + def ARefField(collection, index) + ARefField.new( + collection: collection, + index: index, + location: Location.default + ) + end + + # Create a new ArgParen node. + def ArgParen(arguments) + ArgParen.new(arguments: arguments, location: Location.default) + end + + # Create a new Args node. + def Args(parts) + Args.new(parts: parts, location: Location.default) + end + + # Create a new ArgBlock node. + def ArgBlock(value) + ArgBlock.new(value: value, location: Location.default) + end + + # Create a new ArgStar node. + def ArgStar(value) + ArgStar.new(value: value, location: Location.default) + end + + # Create a new ArgsForward node. + def ArgsForward + ArgsForward.new(location: Location.default) + end + + # Create a new ArrayLiteral node. + def ArrayLiteral(lbracket, contents) + ArrayLiteral.new( + lbracket: lbracket, + contents: contents, + location: Location.default + ) + end + + # Create a new AryPtn node. + def AryPtn(constant, requireds, rest, posts) + AryPtn.new( + constant: constant, + requireds: requireds, + rest: rest, + posts: posts, + location: Location.default + ) + end + + # Create a new Assign node. + def Assign(target, value) + Assign.new(target: target, value: value, location: Location.default) + end + + # Create a new Assoc node. + def Assoc(key, value) + Assoc.new(key: key, value: value, location: Location.default) + end + + # Create a new AssocSplat node. + def AssocSplat(value) + AssocSplat.new(value: value, location: Location.default) + end + + # Create a new Backref node. + def Backref(value) + Backref.new(value: value, location: Location.default) + end + + # Create a new Backtick node. + def Backtick(value) + Backtick.new(value: value, location: Location.default) + end + + # Create a new BareAssocHash node. + def BareAssocHash(assocs) + BareAssocHash.new(assocs: assocs, location: Location.default) + end + + # Create a new Begin node. + def Begin(bodystmt) + Begin.new(bodystmt: bodystmt, location: Location.default) + end + + # Create a new PinnedBegin node. + def PinnedBegin(statement) + PinnedBegin.new(statement: statement, location: Location.default) + end + + # Create a new Binary node. + def Binary(left, operator, right) + Binary.new( + left: left, + operator: operator, + right: right, + location: Location.default + ) + end + + # Create a new BlockVar node. + def BlockVar(params, locals) + BlockVar.new(params: params, locals: locals, location: Location.default) + end + + # Create a new BlockArg node. + def BlockArg(name) + BlockArg.new(name: name, location: Location.default) + end + + # Create a new BodyStmt node. + def BodyStmt( + statements, + rescue_clause, + else_keyword, + else_clause, + ensure_clause + ) + BodyStmt.new( + statements: statements, + rescue_clause: rescue_clause, + else_keyword: else_keyword, + else_clause: else_clause, + ensure_clause: ensure_clause, + location: Location.default + ) + end + + # Create a new Break node. + def Break(arguments) + Break.new(arguments: arguments, location: Location.default) + end + + # Create a new CallNode node. + def CallNode(receiver, operator, message, arguments) + CallNode.new( + receiver: receiver, + operator: operator, + message: message, + arguments: arguments, + location: Location.default + ) + end + + # Create a new Case node. + def Case(keyword, value, consequent) + Case.new( + keyword: keyword, + value: value, + consequent: consequent, + location: Location.default + ) + end + + # Create a new RAssign node. + def RAssign(value, operator, pattern) + RAssign.new( + value: value, + operator: operator, + pattern: pattern, + location: Location.default + ) + end + + # Create a new ClassDeclaration node. + def ClassDeclaration( + constant, + superclass, + bodystmt, + location = Location.default + ) + ClassDeclaration.new( + constant: constant, + superclass: superclass, + bodystmt: bodystmt, + location: location + ) + end + + # Create a new Comma node. + def Comma(value) + Comma.new(value: value, location: Location.default) + end + + # Create a new Command node. + def Command(message, arguments, block, location = Location.default) + Command.new( + message: message, + arguments: arguments, + block: block, + location: location + ) + end + + # Create a new CommandCall node. + def CommandCall(receiver, operator, message, arguments, block) + CommandCall.new( + receiver: receiver, + operator: operator, + message: message, + arguments: arguments, + block: block, + location: Location.default + ) + end + + # Create a new Comment node. + def Comment(value, inline, location = Location.default) + Comment.new(value: value, inline: inline, location: location) + end + + # Create a new Const node. + def Const(value) + Const.new(value: value, location: Location.default) + end + + # Create a new ConstPathField node. + def ConstPathField(parent, constant) + ConstPathField.new( + parent: parent, + constant: constant, + location: Location.default + ) + end + + # Create a new ConstPathRef node. + def ConstPathRef(parent, constant) + ConstPathRef.new( + parent: parent, + constant: constant, + location: Location.default + ) + end + + # Create a new ConstRef node. + def ConstRef(constant) + ConstRef.new(constant: constant, location: Location.default) + end + + # Create a new CVar node. + def CVar(value) + CVar.new(value: value, location: Location.default) + end + + # Create a new DefNode node. + def DefNode( + target, + operator, + name, + params, + bodystmt, + location = Location.default + ) + DefNode.new( + target: target, + operator: operator, + name: name, + params: params, + bodystmt: bodystmt, + location: location + ) + end + + # Create a new Defined node. + def Defined(value) + Defined.new(value: value, location: Location.default) + end + + # Create a new BlockNode node. + def BlockNode(opening, block_var, bodystmt) + BlockNode.new( + opening: opening, + block_var: block_var, + bodystmt: bodystmt, + location: Location.default + ) + end + + # Create a new RangeNode node. + def RangeNode(left, operator, right) + RangeNode.new( + left: left, + operator: operator, + right: right, + location: Location.default + ) + end + + # Create a new DynaSymbol node. + def DynaSymbol(parts, quote) + DynaSymbol.new(parts: parts, quote: quote, location: Location.default) + end + + # Create a new Else node. + def Else(keyword, statements) + Else.new( + keyword: keyword, + statements: statements, + location: Location.default + ) + end + + # Create a new Elsif node. + def Elsif(predicate, statements, consequent) + Elsif.new( + predicate: predicate, + statements: statements, + consequent: consequent, + location: Location.default + ) + end + + # Create a new EmbDoc node. + def EmbDoc(value) + EmbDoc.new(value: value, location: Location.default) + end + + # Create a new EmbExprBeg node. + def EmbExprBeg(value) + EmbExprBeg.new(value: value, location: Location.default) + end + + # Create a new EmbExprEnd node. + def EmbExprEnd(value) + EmbExprEnd.new(value: value, location: Location.default) + end + + # Create a new EmbVar node. + def EmbVar(value) + EmbVar.new(value: value, location: Location.default) + end + + # Create a new Ensure node. + def Ensure(keyword, statements) + Ensure.new( + keyword: keyword, + statements: statements, + location: Location.default + ) + end + + # Create a new ExcessedComma node. + def ExcessedComma(value) + ExcessedComma.new(value: value, location: Location.default) + end + + # Create a new Field node. + def Field(parent, operator, name) + Field.new( + parent: parent, + operator: operator, + name: name, + location: Location.default + ) + end + + # Create a new FloatLiteral node. + def FloatLiteral(value) + FloatLiteral.new(value: value, location: Location.default) + end + + # Create a new FndPtn node. + def FndPtn(constant, left, values, right) + FndPtn.new( + constant: constant, + left: left, + values: values, + right: right, + location: Location.default + ) + end + + # Create a new For node. + def For(index, collection, statements) + For.new( + index: index, + collection: collection, + statements: statements, + location: Location.default + ) + end + + # Create a new GVar node. + def GVar(value) + GVar.new(value: value, location: Location.default) + end + + # Create a new HashLiteral node. + def HashLiteral(lbrace, assocs) + HashLiteral.new( + lbrace: lbrace, + assocs: assocs, + location: Location.default + ) + end + + # Create a new Heredoc node. + def Heredoc(beginning, ending, dedent, parts) + Heredoc.new( + beginning: beginning, + ending: ending, + dedent: dedent, + parts: parts, + location: Location.default + ) + end + + # Create a new HeredocBeg node. + def HeredocBeg(value) + HeredocBeg.new(value: value, location: Location.default) + end + + # Create a new HeredocEnd node. + def HeredocEnd(value) + HeredocEnd.new(value: value, location: Location.default) + end + + # Create a new HshPtn node. + def HshPtn(constant, keywords, keyword_rest) + HshPtn.new( + constant: constant, + keywords: keywords, + keyword_rest: keyword_rest, + location: Location.default + ) + end + + # Create a new Ident node. + def Ident(value) + Ident.new(value: value, location: Location.default) + end + + # Create a new IfNode node. + def IfNode(predicate, statements, consequent) + IfNode.new( + predicate: predicate, + statements: statements, + consequent: consequent, + location: Location.default + ) + end + + # Create a new IfOp node. + def IfOp(predicate, truthy, falsy) + IfOp.new( + predicate: predicate, + truthy: truthy, + falsy: falsy, + location: Location.default + ) + end + + # Create a new Imaginary node. + def Imaginary(value) + Imaginary.new(value: value, location: Location.default) + end + + # Create a new In node. + def In(pattern, statements, consequent) + In.new( + pattern: pattern, + statements: statements, + consequent: consequent, + location: Location.default + ) + end + + # Create a new Int node. + def Int(value) + Int.new(value: value, location: Location.default) + end + + # Create a new IVar node. + def IVar(value) + IVar.new(value: value, location: Location.default) + end + + # Create a new Kw node. + def Kw(value) + Kw.new(value: value, location: Location.default) + end + + # Create a new KwRestParam node. + def KwRestParam(name) + KwRestParam.new(name: name, location: Location.default) + end + + # Create a new Label node. + def Label(value) + Label.new(value: value, location: Location.default) + end + + # Create a new LabelEnd node. + def LabelEnd(value) + LabelEnd.new(value: value, location: Location.default) + end + + # Create a new Lambda node. + def Lambda(params, statements) + Lambda.new( + params: params, + statements: statements, + location: Location.default + ) + end + + # Create a new LambdaVar node. + def LambdaVar(params, locals) + LambdaVar.new(params: params, locals: locals, location: Location.default) + end + + # Create a new LBrace node. + def LBrace(value) + LBrace.new(value: value, location: Location.default) + end + + # Create a new LBracket node. + def LBracket(value) + LBracket.new(value: value, location: Location.default) + end + + # Create a new LParen node. + def LParen(value) + LParen.new(value: value, location: Location.default) + end + + # Create a new MAssign node. + def MAssign(target, value) + MAssign.new(target: target, value: value, location: Location.default) + end + + # Create a new MethodAddBlock node. + def MethodAddBlock(call, block, location = Location.default) + MethodAddBlock.new(call: call, block: block, location: location) + end + + # Create a new MLHS node. + def MLHS(parts, comma) + MLHS.new(parts: parts, comma: comma, location: Location.default) + end + + # Create a new MLHSParen node. + def MLHSParen(contents, comma) + MLHSParen.new( + contents: contents, + comma: comma, + location: Location.default + ) + end + + # Create a new ModuleDeclaration node. + def ModuleDeclaration(constant, bodystmt) + ModuleDeclaration.new( + constant: constant, + bodystmt: bodystmt, + location: Location.default + ) + end + + # Create a new MRHS node. + def MRHS(parts) + MRHS.new(parts: parts, location: Location.default) + end + + # Create a new Next node. + def Next(arguments) + Next.new(arguments: arguments, location: Location.default) + end + + # Create a new Op node. + def Op(value) + Op.new(value: value, location: Location.default) + end + + # Create a new OpAssign node. + def OpAssign(target, operator, value) + OpAssign.new( + target: target, + operator: operator, + value: value, + location: Location.default + ) + end + + # Create a new Params node. + def Params(requireds, optionals, rest, posts, keywords, keyword_rest, block) + Params.new( + requireds: requireds, + optionals: optionals, + rest: rest, + posts: posts, + keywords: keywords, + keyword_rest: keyword_rest, + block: block, + location: Location.default + ) + end + + # Create a new Paren node. + def Paren(lparen, contents) + Paren.new(lparen: lparen, contents: contents, location: Location.default) + end + + # Create a new Period node. + def Period(value) + Period.new(value: value, location: Location.default) + end + + # Create a new Program node. + def Program(statements) + Program.new(statements: statements, location: Location.default) + end + + # Create a new QSymbols node. + def QSymbols(beginning, elements) + QSymbols.new( + beginning: beginning, + elements: elements, + location: Location.default + ) + end + + # Create a new QSymbolsBeg node. + def QSymbolsBeg(value) + QSymbolsBeg.new(value: value, location: Location.default) + end + + # Create a new QWords node. + def QWords(beginning, elements) + QWords.new( + beginning: beginning, + elements: elements, + location: Location.default + ) + end + + # Create a new QWordsBeg node. + def QWordsBeg(value) + QWordsBeg.new(value: value, location: Location.default) + end + + # Create a new RationalLiteral node. + def RationalLiteral(value) + RationalLiteral.new(value: value, location: Location.default) + end + + # Create a new RBrace node. + def RBrace(value) + RBrace.new(value: value, location: Location.default) + end + + # Create a new RBracket node. + def RBracket(value) + RBracket.new(value: value, location: Location.default) + end + + # Create a new Redo node. + def Redo + Redo.new(location: Location.default) + end + + # Create a new RegexpContent node. + def RegexpContent(beginning, parts) + RegexpContent.new( + beginning: beginning, + parts: parts, + location: Location.default + ) + end + + # Create a new RegexpBeg node. + def RegexpBeg(value) + RegexpBeg.new(value: value, location: Location.default) + end + + # Create a new RegexpEnd node. + def RegexpEnd(value) + RegexpEnd.new(value: value, location: Location.default) + end + + # Create a new RegexpLiteral node. + def RegexpLiteral(beginning, ending, parts) + RegexpLiteral.new( + beginning: beginning, + ending: ending, + parts: parts, + location: Location.default + ) + end + + # Create a new RescueEx node. + def RescueEx(exceptions, variable) + RescueEx.new( + exceptions: exceptions, + variable: variable, + location: Location.default + ) + end + + # Create a new Rescue node. + def Rescue(keyword, exception, statements, consequent) + Rescue.new( + keyword: keyword, + exception: exception, + statements: statements, + consequent: consequent, + location: Location.default + ) + end + + # Create a new RescueMod node. + def RescueMod(statement, value) + RescueMod.new( + statement: statement, + value: value, + location: Location.default + ) + end + + # Create a new RestParam node. + def RestParam(name) + RestParam.new(name: name, location: Location.default) + end + + # Create a new Retry node. + def Retry + Retry.new(location: Location.default) + end + + # Create a new ReturnNode node. + def ReturnNode(arguments) + ReturnNode.new(arguments: arguments, location: Location.default) + end + + # Create a new RParen node. + def RParen(value) + RParen.new(value: value, location: Location.default) + end + + # Create a new SClass node. + def SClass(target, bodystmt) + SClass.new(target: target, bodystmt: bodystmt, location: Location.default) + end + + # Create a new Statements node. + def Statements(body) + Statements.new(body: body, location: Location.default) + end + + # Create a new StringContent node. + def StringContent(parts) + StringContent.new(parts: parts, location: Location.default) + end + + # Create a new StringConcat node. + def StringConcat(left, right) + StringConcat.new(left: left, right: right, location: Location.default) + end + + # Create a new StringDVar node. + def StringDVar(variable) + StringDVar.new(variable: variable, location: Location.default) + end + + # Create a new StringEmbExpr node. + def StringEmbExpr(statements) + StringEmbExpr.new(statements: statements, location: Location.default) + end + + # Create a new StringLiteral node. + def StringLiteral(parts, quote) + StringLiteral.new(parts: parts, quote: quote, location: Location.default) + end + + # Create a new Super node. + def Super(arguments) + Super.new(arguments: arguments, location: Location.default) + end + + # Create a new SymBeg node. + def SymBeg(value) + SymBeg.new(value: value, location: Location.default) + end + + # Create a new SymbolContent node. + def SymbolContent(value) + SymbolContent.new(value: value, location: Location.default) + end + + # Create a new SymbolLiteral node. + def SymbolLiteral(value) + SymbolLiteral.new(value: value, location: Location.default) + end + + # Create a new Symbols node. + def Symbols(beginning, elements) + Symbols.new( + beginning: beginning, + elements: elements, + location: Location.default + ) + end + + # Create a new SymbolsBeg node. + def SymbolsBeg(value) + SymbolsBeg.new(value: value, location: Location.default) + end + + # Create a new TLambda node. + def TLambda(value) + TLambda.new(value: value, location: Location.default) + end + + # Create a new TLamBeg node. + def TLamBeg(value) + TLamBeg.new(value: value, location: Location.default) + end + + # Create a new TopConstField node. + def TopConstField(constant) + TopConstField.new(constant: constant, location: Location.default) + end + + # Create a new TopConstRef node. + def TopConstRef(constant) + TopConstRef.new(constant: constant, location: Location.default) + end + + # Create a new TStringBeg node. + def TStringBeg(value) + TStringBeg.new(value: value, location: Location.default) + end + + # Create a new TStringContent node. + def TStringContent(value) + TStringContent.new(value: value, location: Location.default) + end + + # Create a new TStringEnd node. + def TStringEnd(value) + TStringEnd.new(value: value, location: Location.default) + end + + # Create a new Not node. + def Not(statement, parentheses) + Not.new( + statement: statement, + parentheses: parentheses, + location: Location.default + ) + end + + # Create a new Unary node. + def Unary(operator, statement) + Unary.new( + operator: operator, + statement: statement, + location: Location.default + ) + end + + # Create a new Undef node. + def Undef(symbols) + Undef.new(symbols: symbols, location: Location.default) + end + + # Create a new UnlessNode node. + def UnlessNode(predicate, statements, consequent) + UnlessNode.new( + predicate: predicate, + statements: statements, + consequent: consequent, + location: Location.default + ) + end + + # Create a new UntilNode node. + def UntilNode(predicate, statements) + UntilNode.new( + predicate: predicate, + statements: statements, + location: Location.default + ) + end + + # Create a new VarField node. + def VarField(value) + VarField.new(value: value, location: Location.default) + end + + # Create a new VarRef node. + def VarRef(value) + VarRef.new(value: value, location: Location.default) + end + + # Create a new PinnedVarRef node. + def PinnedVarRef(value) + PinnedVarRef.new(value: value, location: Location.default) + end + + # Create a new VCall node. + def VCall(value) + VCall.new(value: value, location: Location.default) + end + + # Create a new VoidStmt node. + def VoidStmt + VoidStmt.new(location: Location.default) + end + + # Create a new When node. + def When(arguments, statements, consequent) + When.new( + arguments: arguments, + statements: statements, + consequent: consequent, + location: Location.default + ) + end + + # Create a new WhileNode node. + def WhileNode(predicate, statements) + WhileNode.new( + predicate: predicate, + statements: statements, + location: Location.default + ) + end + + # Create a new Word node. + def Word(parts) + Word.new(parts: parts, location: Location.default) + end + + # Create a new Words node. + def Words(beginning, elements) + Words.new( + beginning: beginning, + elements: elements, + location: Location.default + ) + end + + # Create a new WordsBeg node. + def WordsBeg(value) + WordsBeg.new(value: value, location: Location.default) + end + + # Create a new XString node. + def XString(parts) + XString.new(parts: parts, location: Location.default) + end + + # Create a new XStringLiteral node. + def XStringLiteral(parts) + XStringLiteral.new(parts: parts, location: Location.default) + end + + # Create a new YieldNode node. + def YieldNode(arguments) + YieldNode.new(arguments: arguments, location: Location.default) + end + + # Create a new ZSuper node. + def ZSuper + ZSuper.new(location: Location.default) + end + end +end diff --git a/lib/syntax_tree/visitor/field_visitor.rb b/lib/syntax_tree/field_visitor.rb similarity index 82% rename from lib/syntax_tree/visitor/field_visitor.rb rename to lib/syntax_tree/field_visitor.rb index 631084e8..f5607c67 100644 --- a/lib/syntax_tree/visitor/field_visitor.rb +++ b/lib/syntax_tree/field_visitor.rb @@ -1,57 +1,54 @@ # frozen_string_literal: true module SyntaxTree - class Visitor - # This is the parent class of a lot of built-in visitors for Syntax Tree. It - # reflects visiting each of the fields on every node in turn. It itself does - # not do anything with these fields, it leaves that behavior up to the - # subclass to implement. - # - # In order to properly use this class, you will need to subclass it and - # implement #comments, #field, #list, #node, #pairs, and #text. Those are - # documented here. - # - # == comments(node) - # - # This accepts the node that is being visited and does something depending - # on the comments attached to the node. - # - # == field(name, value) - # - # This accepts the name of the field being visited as a string (like - # "value") and the actual value of that field. The value can be a subclass - # of Node or any other type that can be held within the tree. - # - # == list(name, values) - # - # This accepts the name of the field being visited as well as a list of - # values. This is used, for example, when visiting something like the body - # of a Statements node. - # - # == node(name, node) - # - # This is the parent serialization method for each node. It is called with - # the node itself, as well as the type of the node as a string. The type - # is an internally used value that usually resembles the name of the - # ripper event that generated the node. The method should yield to the - # given block which then calls through to visit each of the fields on the - # node. - # - # == text(name, value) - # - # This accepts the name of the field being visited as well as a string - # value representing the value of the field. - # - # == pairs(name, values) - # - # This accepts the name of the field being visited as well as a list of - # pairs that represent the value of the field. It is used only in a couple - # of circumstances, like when visiting the list of optional parameters - # defined on a method. - # - class FieldVisitor < Visitor - attr_reader :q - + # This is the parent class of a lot of built-in visitors for Syntax Tree. It + # reflects visiting each of the fields on every node in turn. It itself does + # not do anything with these fields, it leaves that behavior up to the + # subclass to implement. + # + # In order to properly use this class, you will need to subclass it and + # implement #comments, #field, #list, #node, #pairs, and #text. Those are + # documented here. + # + # == comments(node) + # + # This accepts the node that is being visited and does something depending on + # the comments attached to the node. + # + # == field(name, value) + # + # This accepts the name of the field being visited as a string (like "value") + # and the actual value of that field. The value can be a subclass of Node or + # any other type that can be held within the tree. + # + # == list(name, values) + # + # This accepts the name of the field being visited as well as a list of + # values. This is used, for example, when visiting something like the body of + # a Statements node. + # + # == node(name, node) + # + # This is the parent serialization method for each node. It is called with the + # node itself, as well as the type of the node as a string. The type is an + # internally used value that usually resembles the name of the ripper event + # that generated the node. The method should yield to the given block which + # then calls through to visit each of the fields on the node. + # + # == text(name, value) + # + # This accepts the name of the field being visited as well as a string value + # representing the value of the field. + # + # == pairs(name, values) + # + # This accepts the name of the field being visited as well as a list of pairs + # that represent the value of the field. It is used only in a couple of + # circumstances, like when visiting the list of optional parameters defined on + # a method. + # + class FieldVisitor < BasicVisitor + visit_methods do def visit_aref(node) node(node, "aref") do field("collection", node.collection) @@ -105,7 +102,7 @@ def visit_args(node) end def visit_args_forward(node) - visit_token(node, "args_forward") + node(node, "args_forward") { comments(node) } end def visit_array(node) @@ -186,6 +183,14 @@ def visit_binary(node) end end + def visit_block(node) + node(node, "block") do + field("block_var", node.block_var) if node.block_var + field("bodystmt", node.bodystmt) + comments(node) + end + end + def visit_blockarg(node) node(node, "blockarg") do field("name", node.name) if node.name @@ -211,14 +216,6 @@ def visit_bodystmt(node) end end - def visit_brace_block(node) - node(node, "brace_block") do - field("block_var", node.block_var) if node.block_var - field("statements", node.statements) - comments(node) - end - end - def visit_break(node) node(node, "break") do field("arguments", node.arguments) @@ -266,6 +263,7 @@ def visit_command(node) node(node, "command") do field("message", node.message) field("arguments", node.arguments) + field("block", node.block) if node.block comments(node) end end @@ -276,6 +274,7 @@ def visit_command_call(node) field("operator", node.operator) field("message", node.message) field("arguments", node.arguments) if node.arguments + field("block", node.block) if node.block comments(node) end end @@ -317,6 +316,8 @@ def visit_cvar(node) def visit_def(node) node(node, "def") do + field("target", node.target) + field("operator", node.operator) field("name", node.name) field("params", node.params) field("bodystmt", node.bodystmt) @@ -324,20 +325,6 @@ def visit_def(node) end end - def visit_def_endless(node) - node(node, "def_endless") do - if node.target - field("target", node.target) - field("operator", node.operator) - end - - field("name", node.name) - field("paren", node.paren) if node.paren - field("statement", node.statement) - comments(node) - end - end - def visit_defined(node) node(node, "defined") do field("value", node.value) @@ -345,41 +332,6 @@ def visit_defined(node) end end - def visit_defs(node) - node(node, "defs") do - field("target", node.target) - field("operator", node.operator) - field("name", node.name) - field("params", node.params) - field("bodystmt", node.bodystmt) - comments(node) - end - end - - def visit_do_block(node) - node(node, "do_block") do - field("block_var", node.block_var) if node.block_var - field("bodystmt", node.bodystmt) - comments(node) - end - end - - def visit_dot2(node) - node(node, "dot2") do - field("left", node.left) if node.left - field("right", node.right) if node.right - comments(node) - end - end - - def visit_dot3(node) - node(node, "dot3") do - field("left", node.left) if node.left - field("right", node.right) if node.right - comments(node) - end - end - def visit_dyna_symbol(node) node(node, "dyna_symbol") do list("parts", node.parts) @@ -437,14 +389,6 @@ def visit_excessed_comma(node) visit_token(node, "excessed_comma") end - def visit_fcall(node) - node(node, "fcall") do - field("value", node.value) - field("arguments", node.arguments) if node.arguments - comments(node) - end - end - def visit_field(node) node(node, "field") do field("parent", node.parent) @@ -499,6 +443,10 @@ def visit_heredoc_beg(node) visit_token(node, "heredoc_beg") end + def visit_heredoc_end(node) + visit_token(node, "heredoc_end") + end + def visit_hshptn(node) node(node, "hshptn") do field("constant", node.constant) if node.constant @@ -521,14 +469,6 @@ def visit_if(node) end end - def visit_if_mod(node) - node(node, "if_mod") do - field("statement", node.statement) - field("predicate", node.predicate) - comments(node) - end - end - def visit_if_op(node) node(node, "if_op") do field("predicate", node.predicate) @@ -586,6 +526,14 @@ def visit_lambda(node) end end + def visit_lambda_var(node) + node(node, "lambda_var") do + field("params", node.params) + list("locals", node.locals) if node.locals.any? + comments(node) + end + end + def visit_lbrace(node) visit_token(node, "lbrace") end @@ -737,6 +685,15 @@ def visit_qwords_beg(node) node(node, "qwords_beg") { field("value", node.value) } end + def visit_range(node) + node(node, "range") do + field("left", node.left) if node.left + field("operator", node.operator) + field("right", node.right) if node.right + comments(node) + end + end + def visit_rassign(node) node(node, "rassign") do field("value", node.value) @@ -759,7 +716,7 @@ def visit_rbracket(node) end def visit_redo(node) - visit_token(node, "redo") + node(node, "redo") { comments(node) } end def visit_regexp_beg(node) @@ -815,7 +772,7 @@ def visit_rest_param(node) end def visit_retry(node) - visit_token(node, "retry") + node(node, "retry") { comments(node) } end def visit_return(node) @@ -825,10 +782,6 @@ def visit_return(node) end end - def visit_return0(node) - visit_token(node, "return0") - end - def visit_rparen(node) node(node, "rparen") { field("value", node.value) } end @@ -972,14 +925,6 @@ def visit_unless(node) end end - def visit_unless_mod(node) - node(node, "unless_mod") do - field("statement", node.statement) - field("predicate", node.predicate) - comments(node) - end - end - def visit_until(node) node(node, "until") do field("predicate", node.predicate) @@ -988,22 +933,6 @@ def visit_until(node) end end - def visit_until_mod(node) - node(node, "until_mod") do - field("statement", node.statement) - field("predicate", node.predicate) - comments(node) - end - end - - def visit_var_alias(node) - node(node, "var_alias") do - field("left", node.left) - field("right", node.right) - comments(node) - end - end - def visit_var_field(node) node(node, "var_field") do field("value", node.value) @@ -1046,14 +975,6 @@ def visit_while(node) end end - def visit_while_mod(node) - node(node, "while_mod") do - field("statement", node.statement) - field("predicate", node.predicate) - comments(node) - end - end - def visit_word(node) node(node, "word") do list("parts", node.parts) @@ -1090,25 +1011,21 @@ def visit_yield(node) end end - def visit_yield0(node) - visit_token(node, "yield0") - end - def visit_zsuper(node) - visit_token(node, "zsuper") + node(node, "zsuper") { comments(node) } end def visit___end__(node) visit_token(node, "__end__") end + end - private + private - def visit_token(node, type) - node(node, type) do - field("value", node.value) - comments(node) - end + def visit_token(node, type) + node(node, type) do + field("value", node.value) + comments(node) end end end diff --git a/lib/syntax_tree/formatter.rb b/lib/syntax_tree/formatter.rb index 9959421a..2b229885 100644 --- a/lib/syntax_tree/formatter.rb +++ b/lib/syntax_tree/formatter.rb @@ -3,25 +3,113 @@ module SyntaxTree # A slightly enhanced PP that knows how to format recursively including # comments. - class Formatter < PP + class Formatter < PrettierPrint + # Unfortunately, Gem::Version.new is not ractor-safe because it performs + # global caching using a class variable. This works around that by just + # setting the instance variables directly. + class SemanticVersion < ::Gem::Version + def initialize(version) + @version = version + @segments = nil + end + end + + # We want to minimize as much as possible the number of options that are + # available in syntax tree. For the most part, if users want non-default + # formatting, they should override the format methods on the specific nodes + # themselves. However, because of some history with prettier and the fact + # that folks have become entrenched in their ways, we decided to provide a + # small amount of configurability. + class Options + attr_reader :quote, + :trailing_comma, + :disable_auto_ternary, + :target_ruby_version + + def initialize( + quote: :default, + trailing_comma: :default, + disable_auto_ternary: :default, + target_ruby_version: :default + ) + @quote = + if quote == :default + # We ship with a single quotes plugin that will define this + # constant. That constant is responsible for determining the default + # quote style. If it's defined, we default to single quotes, + # otherwise we default to double quotes. + defined?(SINGLE_QUOTES) ? "'" : "\"" + else + quote + end + + @trailing_comma = + if trailing_comma == :default + # We ship with a trailing comma plugin that will define this + # constant. That constant is responsible for determining the default + # trailing comma value. If it's defined, then we default to true. + # Otherwise we default to false. + defined?(TRAILING_COMMA) + else + trailing_comma + end + + @disable_auto_ternary = + if disable_auto_ternary == :default + # We ship with a disable ternary plugin that will define this + # constant. That constant is responsible for determining the default + # disable ternary value. If it's defined, then we default to true. + # Otherwise we default to false. + defined?(DISABLE_AUTO_TERNARY) + else + disable_auto_ternary + end + + @target_ruby_version = + if target_ruby_version == :default + # The default target Ruby version is the current version of Ruby. + # This is really only used for very niche cases, and it shouldn't be + # used by most users. + SemanticVersion.new(RUBY_VERSION) + else + target_ruby_version + end + end + end + COMMENT_PRIORITY = 1 HEREDOC_PRIORITY = 2 - attr_reader :source, :stack, :quote + attr_reader :source, :stack + + # These options are overridden in plugins to we need to make sure they are + # available here. + attr_reader :quote, + :trailing_comma, + :disable_auto_ternary, + :target_ruby_version + + alias trailing_comma? trailing_comma + alias disable_auto_ternary? disable_auto_ternary - def initialize(source, ...) - super(...) + def initialize(source, *args, options: Options.new) + super(*args) @source = source @stack = [] - @quote = "\"" + + # Memoizing these values to make access faster. + @quote = options.quote + @trailing_comma = options.trailing_comma + @disable_auto_ternary = options.disable_auto_ternary + @target_ruby_version = options.target_ruby_version end - def self.format(source, node) - formatter = new(source, []) - node.format(formatter) - formatter.flush - formatter.output.join + def self.format(source, node, base_indentation = 0) + q = new(source, []) + q.format(node) + q.flush(base_indentation) + q.output.join end def format(node, stackable: true) @@ -31,19 +119,39 @@ def format(node, stackable: true) # If there are comments, then we're going to format them around the node # so that they get printed properly. if node.comments.any? - leading, trailing = node.comments.partition(&:leading?) + trailing = [] + last_leading = nil - # Print all comments that were found before the node. - leading.each do |comment| - comment.format(self) - breakable(force: true) + # First, we're going to print all of the comments that were found before + # the node. We'll also gather up any trailing comments that we find. + node.comments.each do |comment| + if comment.leading? + comment.format(self) + breakable(force: true) + last_leading = comment + else + trailing << comment + end end # If the node has a stree-ignore comment right before it, then we're # going to just print out the node as it was seen in the source. doc = - if leading.last&.ignore? - text(source[node.location.start_char...node.location.end_char]) + if last_leading&.ignore? + range = source[node.start_char...node.end_char] + first = true + + range.each_line(chomp: true) do |line| + if first + first = false + else + breakable_return + end + + text(line) + end + + breakable_return if range.end_with?("\n") else node.format(self) end @@ -68,6 +176,10 @@ def format_each(nodes) nodes.each { |node| format(node) } end + def grandparent + stack[-3] + end + def parent stack[-2] end @@ -75,5 +187,42 @@ def parent def parents stack[0...-1].reverse_each end + + # This is a simplified version of prettyprint's group. It doesn't provide + # any of the more advanced options because we don't need them and they take + # up expensive computation time. + def group + contents = [] + doc = Group.new(0, contents: contents) + + groups << doc + target << doc + + with_target(contents) { yield } + groups.pop + doc + end + + # A similar version to the super, except that it calls back into the + # separator proc with the instance of `self`. + def seplist(list, sep = nil, iter_method = :each) + first = true + list.__send__(iter_method) do |*v| + if first + first = false + elsif sep + sep.call(self) + else + comma_breakable + end + yield(*v) + end + end + + # This is a much simplified version of prettyprint's text. It avoids + # calculating width by pushing the string directly onto the target. + def text(string) + target << string + end end end diff --git a/lib/syntax_tree/formatter/single_quotes.rb b/lib/syntax_tree/formatter/single_quotes.rb deleted file mode 100644 index 4d1f41b3..00000000 --- a/lib/syntax_tree/formatter/single_quotes.rb +++ /dev/null @@ -1,13 +0,0 @@ -# frozen_string_literal: true - -module SyntaxTree - class Formatter - # This module overrides the quote method on the formatter to use single - # quotes for everything instead of double quotes. - module SingleQuotes - def quote - "'" - end - end - end -end diff --git a/lib/syntax_tree/index.rb b/lib/syntax_tree/index.rb new file mode 100644 index 00000000..0280749f --- /dev/null +++ b/lib/syntax_tree/index.rb @@ -0,0 +1,683 @@ +# frozen_string_literal: true + +module SyntaxTree + # This class can be used to build an index of the structure of Ruby files. We + # define an index as the list of constants and methods defined within a file. + # + # This index strives to be as fast as possible to better support tools like + # IDEs. Because of that, it has different backends depending on what + # functionality is available. + module Index + # This is a location for an index entry. + class Location + attr_reader :line, :column + + def initialize(line, column) + @line = line + @column = column + end + end + + # This entry represents a class definition using the class keyword. + class ClassDefinition + attr_reader :nesting, :name, :superclass, :location, :comments + + def initialize(nesting, name, superclass, location, comments) + @nesting = nesting + @name = name + @superclass = superclass + @location = location + @comments = comments + end + end + + # This entry represents a constant assignment. + class ConstantDefinition + attr_reader :nesting, :name, :location, :comments + + def initialize(nesting, name, location, comments) + @nesting = nesting + @name = name + @location = location + @comments = comments + end + end + + # This entry represents a module definition using the module keyword. + class ModuleDefinition + attr_reader :nesting, :name, :location, :comments + + def initialize(nesting, name, location, comments) + @nesting = nesting + @name = name + @location = location + @comments = comments + end + end + + # This entry represents a method definition using the def keyword. + class MethodDefinition + attr_reader :nesting, :name, :location, :comments + + def initialize(nesting, name, location, comments) + @nesting = nesting + @name = name + @location = location + @comments = comments + end + end + + # This entry represents a singleton method definition using the def keyword + # with a specified target. + class SingletonMethodDefinition + attr_reader :nesting, :name, :location, :comments + + def initialize(nesting, name, location, comments) + @nesting = nesting + @name = name + @location = location + @comments = comments + end + end + + # This entry represents a method definition that was created using the alias + # keyword. + class AliasMethodDefinition + attr_reader :nesting, :name, :location, :comments + + def initialize(nesting, name, location, comments) + @nesting = nesting + @name = name + @location = location + @comments = comments + end + end + + # When you're using the instruction sequence backend, this class is used to + # lazily parse comments out of the source code. + class FileComments + # We use the ripper library to pull out source comments. + class Parser < Ripper + attr_reader :comments + + def initialize(*) + super + @comments = {} + end + + def on_comment(value) + comments[lineno] = value.chomp + end + end + + # This represents the Ruby source in the form of a file. When it needs to + # be read we'll read the file. + class FileSource + attr_reader :filepath + + def initialize(filepath) + @filepath = filepath + end + + def source + File.read(filepath) + end + end + + # This represents the Ruby source in the form of a string. When it needs + # to be read the string is returned. + class StringSource + attr_reader :source + + def initialize(source) + @source = source + end + end + + attr_reader :source + + def initialize(source) + @source = source + end + + def comments + @comments ||= Parser.new(source.source).tap(&:parse).comments + end + end + + # This class handles parsing comments from Ruby source code in the case that + # we use the instruction sequence backend. Because the instruction sequence + # backend doesn't provide comments (since they are dropped) we provide this + # interface to lazily parse them out. + class EntryComments + include Enumerable + attr_reader :file_comments, :location + + def initialize(file_comments, location) + @file_comments = file_comments + @location = location + end + + def each(&block) + line = location.line - 1 + result = [] + + while line >= 0 && (comment = file_comments.comments[line]) + result.unshift(comment) + line -= 1 + end + + result.each(&block) + end + end + + # This backend creates the index using RubyVM::InstructionSequence, which is + # faster than using the Syntax Tree parser, but is not available on all + # runtimes. + class ISeqBackend + VM_DEFINECLASS_TYPE_CLASS = 0x00 + VM_DEFINECLASS_TYPE_SINGLETON_CLASS = 0x01 + VM_DEFINECLASS_TYPE_MODULE = 0x02 + VM_DEFINECLASS_FLAG_SCOPED = 0x08 + VM_DEFINECLASS_FLAG_HAS_SUPERCLASS = 0x10 + + def index(source) + index_iseq( + RubyVM::InstructionSequence.compile(source).to_a, + FileComments.new(FileComments::StringSource.new(source)) + ) + end + + def index_file(filepath) + index_iseq( + RubyVM::InstructionSequence.compile_file(filepath).to_a, + FileComments.new(FileComments::FileSource.new(filepath)) + ) + end + + private + + def location_for(iseq) + code_location = iseq[4][:code_location] + Location.new(code_location[0], code_location[1]) + end + + def find_constant_path(insns, index) + index -= 1 while index >= 0 && + ( + insns[index].is_a?(Integer) || + ( + insns[index].is_a?(Array) && + %i[swap topn].include?(insns[index][0]) + ) + ) + insn = insns[index] + + if insn.is_a?(Array) && insn[0] == :opt_getconstant_path + # In this case we're on Ruby 3.2+ and we have an opt_getconstant_path + # instruction, so we already know all of the symbols in the nesting. + [index - 1, insn[1]] + elsif insn.is_a?(Symbol) && insn.match?(/\Alabel_\d+/) + # Otherwise, if we have a label then this is very likely the + # destination of an opt_getinlinecache instruction, in which case + # we'll walk backwards to grab up all of the constants. + names = [] + + index -= 1 + until insns[index].is_a?(Array) && + insns[index][0] == :opt_getinlinecache + if insns[index].is_a?(Array) && insns[index][0] == :getconstant + names.unshift(insns[index][1]) + end + + index -= 1 + end + + [index - 1, names] + else + [index, []] + end + end + + def find_attr_arguments(insns, index) + orig_argc = insns[index][1][:orig_argc] + names = [] + + current = index - 1 + while current >= 0 && names.length < orig_argc + if insns[current].is_a?(Array) && insns[current][0] == :putobject + names.unshift(insns[current][1]) + end + + current -= 1 + end + + names if insns[current] == [:putself] && names.length == orig_argc + end + + def method_definition(nesting, name, location, file_comments) + comments = EntryComments.new(file_comments, location) + + if nesting.last == [:singletonclass] + SingletonMethodDefinition.new( + nesting[0...-1], + name, + location, + comments + ) + else + MethodDefinition.new(nesting, name, location, comments) + end + end + + def index_iseq(iseq, file_comments) + results = [] + queue = [[iseq, []]] + + while (current_iseq, current_nesting = queue.shift) + file = current_iseq[5] + line = current_iseq[8] + insns = current_iseq[13] + + insns.each_with_index do |insn, index| + case insn + when Integer + line = insn + next + when Array + # continue on + else + # skip everything else + next + end + + case insn[0] + when :defineclass + _, name, class_iseq, flags = insn + next_nesting = current_nesting.dup + + # This is the index we're going to search for the nested constant + # path within the declaration name. + constant_index = index - 2 + + # This is the superclass of the class being defined. + superclass = [] + + # If there is a superclass, then we're going to find it here and + # then update the constant_index as necessary. + if flags & VM_DEFINECLASS_FLAG_HAS_SUPERCLASS > 0 + constant_index, superclass = + find_constant_path(insns, index - 1) + + if superclass.empty? + warn("#{file}:#{line}: superclass with non constant path") + next + end + end + + if (_, nesting = find_constant_path(insns, constant_index)) + # If there is a constant path in the class name, then we need to + # handle that by updating the nesting. + next_nesting << (nesting << name) + else + # Otherwise we'll add the class name to the nesting. + next_nesting << [name] + end + + if flags == VM_DEFINECLASS_TYPE_SINGLETON_CLASS + # At the moment, we don't support singletons that aren't + # defined on self. We could, but it would require more + # emulation. + if insns[index - 2] != [:putself] + warn( + "#{file}:#{line}: singleton class with non-self receiver" + ) + next + end + elsif flags & VM_DEFINECLASS_TYPE_MODULE > 0 + location = location_for(class_iseq) + results << ModuleDefinition.new( + next_nesting, + name, + location, + EntryComments.new(file_comments, location) + ) + else + location = location_for(class_iseq) + results << ClassDefinition.new( + next_nesting, + name, + superclass, + location, + EntryComments.new(file_comments, location) + ) + end + + queue << [class_iseq, next_nesting] + when :definemethod + location = location_for(insn[2]) + results << method_definition( + current_nesting, + insn[1], + location, + file_comments + ) + when :definesmethod + if insns[index - 1] != [:putself] + warn("#{file}:#{line}: singleton method with non-self receiver") + next + end + + location = location_for(insn[2]) + results << SingletonMethodDefinition.new( + current_nesting, + insn[1], + location, + EntryComments.new(file_comments, location) + ) + when :setconstant + next_nesting = current_nesting.dup + name = insn[1] + + _, nesting = find_constant_path(insns, index - 1) + next_nesting << nesting if nesting.any? + + location = Location.new(line, :unknown) + results << ConstantDefinition.new( + next_nesting, + name, + location, + EntryComments.new(file_comments, location) + ) + when :opt_send_without_block, :send + case insn[1][:mid] + when :attr_reader, :attr_writer, :attr_accessor + attr_names = find_attr_arguments(insns, index) + next unless attr_names + + location = Location.new(line, :unknown) + attr_names.each do |attr_name| + if insn[1][:mid] != :attr_writer + results << method_definition( + current_nesting, + attr_name, + location, + file_comments + ) + end + + if insn[1][:mid] != :attr_reader + results << method_definition( + current_nesting, + :"#{attr_name}=", + location, + file_comments + ) + end + end + when :"core#set_method_alias" + # Now we have to validate that the alias is happening with a + # non-interpolated value. To do this we'll match the specific + # pattern we're expecting. + values = + insns[(index - 4)...index].map do |previous| + previous.is_a?(Array) ? previous[0] : previous + end + if values != + %i[putspecialobject putspecialobject putobject putobject] + next + end + + # Now that we know it's in the structure we want it, we can use + # the values of the putobject to determine the alias. + location = Location.new(line, :unknown) + results << AliasMethodDefinition.new( + current_nesting, + insns[index - 2][1], + location, + EntryComments.new(file_comments, location) + ) + end + end + end + end + + results + end + end + + # This backend creates the index using the Syntax Tree parser and a visitor. + # It is not as fast as using the instruction sequences directly, but is + # supported on all runtimes. + class ParserBackend + class ConstantNameVisitor < Visitor + def visit_const_ref(node) + [node.constant.value.to_sym] + end + + def visit_const_path_ref(node) + visit(node.parent) << node.constant.value.to_sym + end + + def visit_var_ref(node) + [node.value.value.to_sym] + end + end + + class IndexVisitor < Visitor + attr_reader :results, :nesting, :statements + + def initialize + @results = [] + @nesting = [] + @statements = nil + end + + visit_methods do + def visit_alias(node) + if node.left.is_a?(SymbolLiteral) && node.right.is_a?(SymbolLiteral) + location = + Location.new( + node.location.start_line, + node.location.start_column + ) + + results << AliasMethodDefinition.new( + nesting.dup, + node.left.value.value.to_sym, + location, + comments_for(node) + ) + end + + super + end + + def visit_assign(node) + if node.target.is_a?(VarField) && node.target.value.is_a?(Const) + location = + Location.new( + node.location.start_line, + node.location.start_column + ) + + results << ConstantDefinition.new( + nesting.dup, + node.target.value.value.to_sym, + location, + comments_for(node) + ) + end + + super + end + + def visit_class(node) + names = node.constant.accept(ConstantNameVisitor.new) + nesting << names + + location = + Location.new(node.location.start_line, node.location.start_column) + + superclass = + if node.superclass + visited = node.superclass.accept(ConstantNameVisitor.new) + + if visited == [[]] + raise NotImplementedError, "superclass with non constant path" + end + + visited + else + [] + end + + results << ClassDefinition.new( + nesting.dup, + names.last, + superclass, + location, + comments_for(node) + ) + + super + nesting.pop + end + + def visit_command(node) + case node.message.value + when "attr_reader", "attr_writer", "attr_accessor" + comments = comments_for(node) + location = + Location.new( + node.location.start_line, + node.location.start_column + ) + + node.arguments.parts.each do |argument| + next unless argument.is_a?(SymbolLiteral) + name = argument.value.value.to_sym + + if node.message.value != "attr_writer" + results << MethodDefinition.new( + nesting.dup, + name, + location, + comments + ) + end + + if node.message.value != "attr_reader" + results << MethodDefinition.new( + nesting.dup, + :"#{name}=", + location, + comments + ) + end + end + end + + super + end + + def visit_def(node) + name = node.name.value.to_sym + location = + Location.new(node.location.start_line, node.location.start_column) + + results << if node.target.nil? + MethodDefinition.new( + nesting.dup, + name, + location, + comments_for(node) + ) + else + SingletonMethodDefinition.new( + nesting.dup, + name, + location, + comments_for(node) + ) + end + + super + end + + def visit_module(node) + names = node.constant.accept(ConstantNameVisitor.new) + nesting << names + + location = + Location.new(node.location.start_line, node.location.start_column) + + results << ModuleDefinition.new( + nesting.dup, + names.last, + location, + comments_for(node) + ) + + super + nesting.pop + end + + def visit_program(node) + super + results + end + + def visit_statements(node) + @statements = node + super + end + end + + private + + def comments_for(node) + comments = [] + + body = statements.body + line = node.location.start_line - 1 + index = body.index(node) + return comments if index.nil? + + index -= 1 + while index >= 0 && body[index].is_a?(Comment) && + (line - body[index].location.start_line < 2) + comments.unshift(body[index].value) + line = body[index].location.start_line + index -= 1 + end + + comments + end + end + + def index(source) + SyntaxTree.parse(source).accept(IndexVisitor.new) + end + + def index_file(filepath) + index(SyntaxTree.read(filepath)) + end + end + + # The class defined here is used to perform the indexing, depending on what + # functionality is available from the runtime. + INDEX_BACKEND = + defined?(RubyVM::InstructionSequence) ? ISeqBackend : ParserBackend + + # This method accepts source code and then indexes it. + def self.index(source, backend: INDEX_BACKEND.new) + backend.index(source) + end + + # This method accepts a filepath and then indexes it. + def self.index_file(filepath, backend: INDEX_BACKEND.new) + backend.index_file(filepath) + end + end +end diff --git a/lib/syntax_tree/json_visitor.rb b/lib/syntax_tree/json_visitor.rb new file mode 100644 index 00000000..7ad3fba0 --- /dev/null +++ b/lib/syntax_tree/json_visitor.rb @@ -0,0 +1,55 @@ +# frozen_string_literal: true + +require "json" + +module SyntaxTree + # This visitor transforms the AST into a hash that contains only primitives + # that can be easily serialized into JSON. + class JSONVisitor < FieldVisitor + attr_reader :target + + def initialize + @target = nil + end + + private + + def comments(node) + target[:comments] = visit_all(node.comments) + end + + def field(name, value) + target[name] = value.is_a?(Node) ? visit(value) : value + end + + def list(name, values) + target[name] = visit_all(values) + end + + def node(node, type) + previous = @target + @target = { type: type, location: visit_location(node.location) } + yield + @target + ensure + @target = previous + end + + def pairs(name, values) + target[name] = values.map { |(key, value)| [visit(key), visit(value)] } + end + + def text(name, value) + target[name] = value + end + + def visit_location(location) + [ + location.start_line, + location.start_char, + location.end_line, + location.end_char + ] + end + end +end diff --git a/lib/syntax_tree/language_server.rb b/lib/syntax_tree/language_server.rb index 1e305cca..aaa64e9a 100644 --- a/lib/syntax_tree/language_server.rb +++ b/lib/syntax_tree/language_server.rb @@ -2,10 +2,9 @@ require "cgi" require "json" +require "pp" require "uri" -require_relative "language_server/inlay_hints" - module SyntaxTree # Syntax Tree additionally ships with a language server conforming to the # language server protocol. It can be invoked through the CLI by running: @@ -13,79 +12,283 @@ module SyntaxTree # stree lsp # class LanguageServer - attr_reader :input, :output + # This class provides inlay hints for the language server. For more + # information, see the spec here: + # https://github.com/microsoft/language-server-protocol/issues/956. + class InlayHints < Visitor + # This represents a hint that is going to be displayed in the editor. + class Hint + attr_reader :line, :character, :label + + def initialize(line:, character:, label:) + @line = line + @character = character + @label = label + end + + # This is the shape that the LSP expects. + def to_json(*opts) + { + position: { + line: line, + character: character + }, + label: label + }.to_json(*opts) + end + end + + attr_reader :stack, :hints + + def initialize + @stack = [] + @hints = [] + end + + def visit(node) + stack << node + result = super + stack.pop + result + end + + visit_methods do + # Adds parentheses around assignments contained within the default + # values of parameters. For example, + # + # def foo(a = b = c) + # end + # + # becomes + # + # def foo(a = ₍b = c₎) + # end + # + def visit_assign(node) + parentheses(node.location) if stack[-2].is_a?(Params) + super + end + + # Adds parentheses around binary expressions to make it clear which + # subexpression will be evaluated first. For example, + # + # a + b * c + # + # becomes + # + # a + ₍b * c₎ + # + def visit_binary(node) + case stack[-2] + when Assign, OpAssign + parentheses(node.location) + when Binary + parentheses(node.location) if stack[-2].operator != node.operator + end + + super + end + + # Adds parentheses around ternary operators contained within certain + # expressions where it could be confusing which subexpression will get + # evaluated first. For example, + # + # a ? b : c ? d : e + # + # becomes + # + # a ? b : ₍c ? d : e₎ + # + def visit_if_op(node) + case stack[-2] + when Assign, Binary, IfOp, OpAssign + parentheses(node.location) + end + + super + end + + # Adds the implicitly rescued StandardError into a bare rescue clause. + # For example, + # + # begin + # rescue + # end + # + # becomes + # + # begin + # rescue StandardError + # end + # + def visit_rescue(node) + if node.exception.nil? + hints << Hint.new( + line: node.location.start_line - 1, + character: node.location.start_column + "rescue".length, + label: " StandardError" + ) + end + + super + end + + # Adds parentheses around unary statements using the - operator that are + # contained within Binary nodes. For example, + # + # -a + b + # + # becomes + # + # ₍-a₎ + b + # + def visit_unary(node) + if stack[-2].is_a?(Binary) && (node.operator == "-") + parentheses(node.location) + end + + super + end + end + + private + + def parentheses(location) + hints << Hint.new( + line: location.start_line - 1, + character: location.start_column, + label: "₍" + ) + + hints << Hint.new( + line: location.end_line - 1, + character: location.end_column, + label: "₎" + ) + end + end + + # This is a small module that effectively mirrors pattern matching. We're + # using it so that we can support truffleruby without having to ignore the + # language server. + module Request + # Represents a hash pattern. + class Shape + attr_reader :values + + def initialize(values) + @values = values + end - def initialize(input: $stdin, output: $stdout) + def ===(other) + values.all? do |key, value| + value == :any ? other.key?(key) : value === other[key] + end + end + end + + # Represents an array pattern. + class Tuple + attr_reader :values + + def initialize(values) + @values = values + end + + def ===(other) + values.each_with_index.all? { |value, index| value === other[index] } + end + end + + def self.[](value) + case value + when Array + Tuple.new(value.map { |child| self[child] }) + when Hash + Shape.new(value.transform_values { |child| self[child] }) + else + value + end + end + end + + attr_reader :input, :output, :print_width + + def initialize( + input: $stdin, + output: $stdout, + print_width: DEFAULT_PRINT_WIDTH, + ignore_files: [] + ) @input = input.binmode @output = output.binmode + @print_width = print_width + @ignore_files = ignore_files end + # rubocop:disable Layout/LineLength def run store = Hash.new do |hash, uri| - hash[uri] = File.binread(CGI.unescape(URI.parse(uri).path)) + filepath = CGI.unescape(URI.parse(uri).path) + File.exist?(filepath) ? (hash[uri] = File.read(filepath)) : nil end while (headers = input.gets("\r\n\r\n")) source = input.read(headers[/Content-Length: (\d+)/i, 1].to_i) request = JSON.parse(source, symbolize_names: true) + # stree-ignore case request - in { method: "initialize", id: } + when Request[method: "initialize", id: :any] store.clear - write(id: id, result: { capabilities: capabilities }) - in method: "initialized" + write(id: request[:id], result: { capabilities: capabilities }) + when Request[method: "initialized"] # ignored - in method: "shutdown" + when Request[method: "shutdown"] # tolerate missing ID to be a good citizen store.clear + write(id: request[:id], result: {}) return - in { - method: "textDocument/didChange", - params: { textDocument: { uri: }, contentChanges: [{ text: }, *] } - } - store[uri] = text - in { - method: "textDocument/didOpen", - params: { textDocument: { uri:, text: } } - } - store[uri] = text - in { - method: "textDocument/didClose", params: { textDocument: { uri: } } - } - store.delete(uri) - in { - method: "textDocument/formatting", - id:, - params: { textDocument: { uri: } } - } - write(id: id, result: [format(store[uri])]) - in { - method: "textDocument/inlayHints", - id:, - params: { textDocument: { uri: } } - } - write(id: id, result: inlay_hints(store[uri])) - in { - method: "syntaxTree/visualizing", - id:, - params: { textDocument: { uri: } } - } - output = [] - PP.pp(SyntaxTree.parse(store[uri]), output) - write(id: id, result: output.join) - in method: %r{\$/.+} + when Request[method: "textDocument/didChange", params: { textDocument: { uri: :any }, contentChanges: [{ text: :any }] }] + store[request.dig(:params, :textDocument, :uri)] = request.dig(:params, :contentChanges, 0, :text) + when Request[method: "textDocument/didOpen", params: { textDocument: { uri: :any, text: :any } }] + store[request.dig(:params, :textDocument, :uri)] = request.dig(:params, :textDocument, :text) + when Request[method: "textDocument/didClose", params: { textDocument: { uri: :any } }] + store.delete(request.dig(:params, :textDocument, :uri)) + when Request[method: "textDocument/formatting", id: :any, params: { textDocument: { uri: :any } }] + uri = request.dig(:params, :textDocument, :uri) + filepath = uri.split("///").last + ignore = @ignore_files.any? do |glob| + File.fnmatch(glob, filepath) + end + contents = store[uri] + write(id: request[:id], result: contents && !ignore ? format(contents, uri.split(".").last) : nil) + when Request[method: "textDocument/inlayHint", id: :any, params: { textDocument: { uri: :any } }] + uri = request.dig(:params, :textDocument, :uri) + contents = store[uri] + write(id: request[:id], result: contents ? inlay_hints(contents) : nil) + when Request[method: "syntaxTree/visualizing", id: :any, params: { textDocument: { uri: :any } }] + uri = request.dig(:params, :textDocument, :uri) + write(id: request[:id], result: PP.pp(SyntaxTree.parse(store[uri]), +"")) + when Request[method: %r{\$/.+}] + # ignored + when Request[method: "textDocument/documentColor", params: { textDocument: { uri: :any } }] # ignored else - raise "Unhandled: #{request}" + raise ArgumentError, "Unhandled: #{request}" end end end + # rubocop:enable Layout/LineLength private def capabilities { documentFormattingProvider: true, + inlayHintProvider: { + resolveProvider: false + }, textDocumentSync: { change: 1, openClose: true @@ -93,37 +296,38 @@ def capabilities } end - def format(source) - { - range: { - start: { - line: 0, - character: 0 - }, - end: { - line: source.lines.size + 1, - character: 0 - } - }, - newText: SyntaxTree.format(source) - } - end + def format(source, extension) + text = SyntaxTree::HANDLERS[".#{extension}"].format(source, print_width) - def log(message) - write(method: "window/logMessage", params: { type: 4, message: message }) + [ + { + range: { + start: { + line: 0, + character: 0 + }, + end: { + line: source.lines.size + 1, + character: 0 + } + }, + newText: text + } + ] + rescue Parser::ParseError + # If there is a parse error, then we're not going to return any formatting + # changes for this source. + nil end def inlay_hints(source) - inlay_hints = InlayHints.find(SyntaxTree.parse(source)) - serialize = ->(position, text) { { position: position, text: text } } - - { - before: inlay_hints.before.map(&serialize), - after: inlay_hints.after.map(&serialize) - } + visitor = InlayHints.new + SyntaxTree.parse(source).accept(visitor) + visitor.hints rescue Parser::ParseError # If there is a parse error, then we're not going to return any inlay # hints for this source. + [] end def write(value) @@ -131,5 +335,9 @@ def write(value) output.print("Content-Length: #{response.bytesize}\r\n\r\n#{response}") output.flush end + + def log(message) + write(method: "window/logMessage", params: { type: 4, message: message }) + end end end diff --git a/lib/syntax_tree/language_server/inlay_hints.rb b/lib/syntax_tree/language_server/inlay_hints.rb deleted file mode 100644 index 69fc5ce4..00000000 --- a/lib/syntax_tree/language_server/inlay_hints.rb +++ /dev/null @@ -1,126 +0,0 @@ -# frozen_string_literal: true - -module SyntaxTree - class LanguageServer - # This class provides inlay hints for the language server. It is loosely - # designed around the LSP spec, but existed before the spec was finalized so - # is a little different for now. - # - # For more information, see the spec here: - # https://github.com/microsoft/language-server-protocol/issues/956. - # - class InlayHints < Visitor - attr_reader :stack, :before, :after - - def initialize - @stack = [] - @before = Hash.new { |hash, key| hash[key] = +"" } - @after = Hash.new { |hash, key| hash[key] = +"" } - end - - def visit(node) - stack << node - result = super - stack.pop - result - end - - # Adds parentheses around assignments contained within the default values - # of parameters. For example, - # - # def foo(a = b = c) - # end - # - # becomes - # - # def foo(a = ₍b = c₎) - # end - # - def visit_assign(node) - parentheses(node.location) if stack[-2].is_a?(Params) - end - - # Adds parentheses around binary expressions to make it clear which - # subexpression will be evaluated first. For example, - # - # a + b * c - # - # becomes - # - # a + ₍b * c₎ - # - def visit_binary(node) - case stack[-2] - in Assign | OpAssign - parentheses(node.location) - in Binary[operator: operator] if operator != node.operator - parentheses(node.location) - else - end - end - - # Adds parentheses around ternary operators contained within certain - # expressions where it could be confusing which subexpression will get - # evaluated first. For example, - # - # a ? b : c ? d : e - # - # becomes - # - # a ? b : ₍c ? d : e₎ - # - def visit_if_op(node) - if stack[-2] in Assign | Binary | IfOp | OpAssign - parentheses(node.location) - end - end - - # Adds the implicitly rescued StandardError into a bare rescue clause. For - # example, - # - # begin - # rescue - # end - # - # becomes - # - # begin - # rescue StandardError - # end - # - def visit_rescue(node) - if node.exception.nil? - after[node.location.start_char + "rescue".length] << " StandardError" - end - end - - # Adds parentheses around unary statements using the - operator that are - # contained within Binary nodes. For example, - # - # -a + b - # - # becomes - # - # ₍-a₎ + b - # - def visit_unary(node) - if stack[-2].is_a?(Binary) && (node.operator == "-") - parentheses(node.location) - end - end - - def self.find(program) - visitor = new - visitor.visit(program) - visitor - end - - private - - def parentheses(location) - before[location.start_char] << "₍" - after[location.end_char] << "₎" - end - end - end -end diff --git a/lib/syntax_tree/match_visitor.rb b/lib/syntax_tree/match_visitor.rb new file mode 100644 index 00000000..ca5bf234 --- /dev/null +++ b/lib/syntax_tree/match_visitor.rb @@ -0,0 +1,120 @@ +# frozen_string_literal: true + +module SyntaxTree + # This visitor transforms the AST into a Ruby pattern matching expression that + # would match correctly against the AST. + class MatchVisitor < FieldVisitor + attr_reader :q + + def initialize(q) + @q = q + end + + def visit(node) + case node + when Node + super + when String + # pp will split up a string on newlines and concat them together using a + # "+" operator. This breaks the pattern matching expression. So instead + # we're going to check here for strings and manually put the entire + # value into the output buffer. + q.text(node.inspect) + else + node.pretty_print(q) + end + end + + private + + def comments(node) + return if node.comments.empty? + + q.nest(0) do + q.text("comments: [") + q.indent do + q.breakable("") + q.seplist(node.comments) { |comment| visit(comment) } + end + q.breakable("") + q.text("]") + end + end + + def field(name, value) + q.nest(0) do + q.text(name) + q.text(": ") + visit(value) + end + end + + def list(name, values) + q.group do + q.text(name) + q.text(": [") + q.indent do + q.breakable("") + q.seplist(values) { |value| visit(value) } + end + q.breakable("") + q.text("]") + end + end + + def node(node, _type) + items = [] + q.with_target(items) { yield } + + if items.empty? + q.text(node.class.name) + return + end + + q.group do + q.text(node.class.name) + q.text("[") + q.indent do + q.breakable("") + q.seplist(items) { |item| q.target << item } + end + q.breakable("") + q.text("]") + end + end + + def pairs(name, values) + q.group do + q.text(name) + q.text(": [") + q.indent do + q.breakable("") + q.seplist(values) do |(key, value)| + q.group do + q.text("[") + q.indent do + q.breakable("") + visit(key) + q.text(",") + q.breakable + visit(value || nil) + end + q.breakable("") + q.text("]") + end + end + end + q.breakable("") + q.text("]") + end + end + + def text(name, value) + q.nest(0) do + q.text(name) + q.text(": ") + value.pretty_print(q) + end + end + end +end diff --git a/lib/syntax_tree/mermaid.rb b/lib/syntax_tree/mermaid.rb new file mode 100644 index 00000000..68ea4734 --- /dev/null +++ b/lib/syntax_tree/mermaid.rb @@ -0,0 +1,177 @@ +# frozen_string_literal: true + +require "cgi" +require "stringio" + +module SyntaxTree + # This module is responsible for rendering mermaid (https://mermaid.js.org/) + # flow charts. + module Mermaid + # This is the main class that handles rendering a flowchart. It keeps track + # of its nodes and links and renders them according to the mermaid syntax. + class FlowChart + attr_reader :output, :prefix, :nodes, :links + + def initialize + @output = StringIO.new + @output.puts("flowchart TD") + @prefix = " " + + @nodes = {} + @links = [] + end + + # Retrieve a node that has already been added to the flowchart by its id. + def fetch(id) + nodes.fetch(id) + end + + # Add a link to the flowchart between two nodes with an optional label. + def link(from, to, label = nil, type: :directed, color: nil) + link = Link.new(from, to, label, type, color) + links << link + + output.puts("#{prefix}#{link.render}") + link + end + + # Add a node to the flowchart with an optional label. + def node(id, label = " ", shape: :rectangle) + node = Node.new(id, label, shape) + nodes[id] = node + + output.puts("#{prefix}#{nodes[id].render}") + node + end + + # Add a subgraph to the flowchart. Within the given block, all of the + # nodes will be rendered within the subgraph. + def subgraph(label) + output.puts("#{prefix}subgraph #{Mermaid.escape(label)}") + + previous = prefix + @prefix = "#{prefix} " + + begin + yield + ensure + @prefix = previous + output.puts("#{prefix}end") + end + end + + # Return the rendered flowchart. + def render + links.each_with_index do |link, index| + if link.color + output.puts("#{prefix}linkStyle #{index} stroke:#{link.color}") + end + end + + output.string + end + end + + # This class represents a link between two nodes in a flowchart. It is not + # meant to be interacted with directly, but rather used as a data structure + # by the FlowChart class. + class Link + TYPES = %i[directed dotted].freeze + COLORS = %i[green red].freeze + + attr_reader :from, :to, :label, :type, :color + + def initialize(from, to, label, type, color) + raise unless TYPES.include?(type) + raise if color && !COLORS.include?(color) + + @from = from + @to = to + @label = label + @type = type + @color = color + end + + def render + left_side, right_side, full_side = sides + + if label + escaped = Mermaid.escape(label) + "#{from.id} #{left_side} #{escaped} #{right_side} #{to.id}" + else + "#{from.id} #{full_side} #{to.id}" + end + end + + private + + def sides + case type + when :directed + %w[-- --> -->] + when :dotted + %w[-. .-> -.->] + end + end + end + + # This class represents a node in a flowchart. Unlike the Link class, it can + # be used directly. It is the return value of the #node method, and is meant + # to be passed around to #link methods to create links between nodes. + class Node + SHAPES = %i[circle rectangle rounded stadium].freeze + + attr_reader :id, :label, :shape + + def initialize(id, label, shape) + raise unless SHAPES.include?(shape) + + @id = id + @label = label + @shape = shape + end + + def render + left_bound, right_bound = bounds + "#{id}#{left_bound}#{Mermaid.escape(label)}#{right_bound}" + end + + private + + def bounds + case shape + when :circle + %w[(( ))] + when :rectangle + ["[", "]"] + when :rounded + %w[( )] + when :stadium + ["([", "])"] + end + end + end + + class << self + # Escape a label to be used in the mermaid syntax. This is used to escape + # HTML entities such that they render properly within the quotes. + def escape(label) + "\"#{CGI.escapeHTML(label)}\"" + end + + # Create a new flowchart. If a block is given, it will be yielded to and + # the flowchart will be rendered. Otherwise, the flowchart will be + # returned. + def flowchart + flowchart = FlowChart.new + + if block_given? + yield flowchart + flowchart.render + else + flowchart + end + end + end + end +end diff --git a/lib/syntax_tree/mermaid_visitor.rb b/lib/syntax_tree/mermaid_visitor.rb new file mode 100644 index 00000000..fc9f6706 --- /dev/null +++ b/lib/syntax_tree/mermaid_visitor.rb @@ -0,0 +1,69 @@ +# frozen_string_literal: true + +module SyntaxTree + # This visitor transforms the AST into a mermaid flow chart. + class MermaidVisitor < FieldVisitor + attr_reader :flowchart, :target + + def initialize + @flowchart = Mermaid.flowchart + @target = nil + end + + def visit_program(node) + super + flowchart.render + end + + private + + def comments(node) + # Ignore + end + + def field(name, value) + case value + when nil + # skip + when Node + flowchart.link(target, visit(value), name) + else + to = + flowchart.node("#{target.id}_#{name}", value.inspect, shape: :stadium) + flowchart.link(target, to, name) + end + end + + def list(name, values) + values.each_with_index do |value, index| + field("#{name}[#{index}]", value) + end + end + + def node(node, type) + previous_target = target + + begin + @target = flowchart.node("node_#{node.object_id}", type) + yield + @target + ensure + @target = previous_target + end + end + + def pairs(name, values) + values.each_with_index do |(key, value), index| + to = flowchart.node("#{target.id}_#{name}_#{index}", shape: :circle) + + flowchart.link(target, to, "#{name}[#{index}]") + flowchart.link(to, visit(key), "[0]") + flowchart.link(to, visit(value), "[1]") if value + end + end + + def text(name, value) + field(name, value) + end + end +end diff --git a/lib/syntax_tree/mutation_visitor.rb b/lib/syntax_tree/mutation_visitor.rb new file mode 100644 index 00000000..0b4b9357 --- /dev/null +++ b/lib/syntax_tree/mutation_visitor.rb @@ -0,0 +1,924 @@ +# frozen_string_literal: true + +module SyntaxTree + # This visitor walks through the tree and copies each node as it is being + # visited. This is useful for mutating the tree before it is formatted. + class MutationVisitor < BasicVisitor + attr_reader :mutations + + def initialize + @mutations = [] + end + + # Create a new mutation based on the given query that will mutate the node + # using the given block. The block should return a new node that will take + # the place of the given node in the tree. These blocks frequently make use + # of the `copy` method on nodes to create a new node with the same + # properties as the original node. + def mutate(query, &block) + mutations << [Pattern.new(query).compile, block] + end + + # This is the base visit method for each node in the tree. It first creates + # a copy of the node using the visit_* methods defined below. Then it checks + # each mutation in sequence and calls it if it finds a match. + def visit(node) + return unless node + result = node.accept(self) + + mutations.each do |(pattern, mutation)| + result = mutation.call(result) if pattern.call(result) + end + + result + end + + visit_methods do + # Visit a BEGINBlock node. + def visit_BEGIN(node) + node.copy( + lbrace: visit(node.lbrace), + statements: visit(node.statements) + ) + end + + # Visit a CHAR node. + def visit_CHAR(node) + node.copy + end + + # Visit a ENDBlock node. + def visit_END(node) + node.copy( + lbrace: visit(node.lbrace), + statements: visit(node.statements) + ) + end + + # Visit a EndContent node. + def visit___end__(node) + node.copy + end + + # Visit a AliasNode node. + def visit_alias(node) + node.copy(left: visit(node.left), right: visit(node.right)) + end + + # Visit a ARef node. + def visit_aref(node) + node.copy(index: visit(node.index)) + end + + # Visit a ARefField node. + def visit_aref_field(node) + node.copy(index: visit(node.index)) + end + + # Visit a ArgParen node. + def visit_arg_paren(node) + node.copy(arguments: visit(node.arguments)) + end + + # Visit a Args node. + def visit_args(node) + node.copy(parts: visit_all(node.parts)) + end + + # Visit a ArgBlock node. + def visit_arg_block(node) + node.copy(value: visit(node.value)) + end + + # Visit a ArgStar node. + def visit_arg_star(node) + node.copy(value: visit(node.value)) + end + + # Visit a ArgsForward node. + def visit_args_forward(node) + node.copy + end + + # Visit a ArrayLiteral node. + def visit_array(node) + node.copy( + lbracket: visit(node.lbracket), + contents: visit(node.contents) + ) + end + + # Visit a AryPtn node. + def visit_aryptn(node) + node.copy( + constant: visit(node.constant), + requireds: visit_all(node.requireds), + rest: visit(node.rest), + posts: visit_all(node.posts) + ) + end + + # Visit a Assign node. + def visit_assign(node) + node.copy(target: visit(node.target)) + end + + # Visit a Assoc node. + def visit_assoc(node) + node.copy + end + + # Visit a AssocSplat node. + def visit_assoc_splat(node) + node.copy + end + + # Visit a Backref node. + def visit_backref(node) + node.copy + end + + # Visit a Backtick node. + def visit_backtick(node) + node.copy + end + + # Visit a BareAssocHash node. + def visit_bare_assoc_hash(node) + node.copy(assocs: visit_all(node.assocs)) + end + + # Visit a Begin node. + def visit_begin(node) + node.copy(bodystmt: visit(node.bodystmt)) + end + + # Visit a PinnedBegin node. + def visit_pinned_begin(node) + node.copy + end + + # Visit a Binary node. + def visit_binary(node) + node.copy + end + + # Visit a BlockVar node. + def visit_block_var(node) + node.copy(params: visit(node.params), locals: visit_all(node.locals)) + end + + # Visit a BlockArg node. + def visit_blockarg(node) + node.copy(name: visit(node.name)) + end + + # Visit a BodyStmt node. + def visit_bodystmt(node) + node.copy( + statements: visit(node.statements), + rescue_clause: visit(node.rescue_clause), + else_clause: visit(node.else_clause), + ensure_clause: visit(node.ensure_clause) + ) + end + + # Visit a Break node. + def visit_break(node) + node.copy(arguments: visit(node.arguments)) + end + + # Visit a Call node. + def visit_call(node) + node.copy( + receiver: visit(node.receiver), + operator: node.operator == :"::" ? :"::" : visit(node.operator), + message: node.message == :call ? :call : visit(node.message), + arguments: visit(node.arguments) + ) + end + + # Visit a Case node. + def visit_case(node) + node.copy( + keyword: visit(node.keyword), + value: visit(node.value), + consequent: visit(node.consequent) + ) + end + + # Visit a RAssign node. + def visit_rassign(node) + node.copy(operator: visit(node.operator)) + end + + # Visit a ClassDeclaration node. + def visit_class(node) + node.copy( + constant: visit(node.constant), + superclass: visit(node.superclass), + bodystmt: visit(node.bodystmt) + ) + end + + # Visit a Comma node. + def visit_comma(node) + node.copy + end + + # Visit a Command node. + def visit_command(node) + node.copy( + message: visit(node.message), + arguments: visit(node.arguments), + block: visit(node.block) + ) + end + + # Visit a CommandCall node. + def visit_command_call(node) + node.copy( + operator: node.operator == :"::" ? :"::" : visit(node.operator), + message: visit(node.message), + arguments: visit(node.arguments), + block: visit(node.block) + ) + end + + # Visit a Comment node. + def visit_comment(node) + node.copy + end + + # Visit a Const node. + def visit_const(node) + node.copy + end + + # Visit a ConstPathField node. + def visit_const_path_field(node) + node.copy(constant: visit(node.constant)) + end + + # Visit a ConstPathRef node. + def visit_const_path_ref(node) + node.copy(constant: visit(node.constant)) + end + + # Visit a ConstRef node. + def visit_const_ref(node) + node.copy(constant: visit(node.constant)) + end + + # Visit a CVar node. + def visit_cvar(node) + node.copy + end + + # Visit a Def node. + def visit_def(node) + node.copy( + target: visit(node.target), + operator: visit(node.operator), + name: visit(node.name), + params: visit(node.params), + bodystmt: visit(node.bodystmt) + ) + end + + # Visit a Defined node. + def visit_defined(node) + node.copy + end + + # Visit a Block node. + def visit_block(node) + node.copy( + opening: visit(node.opening), + block_var: visit(node.block_var), + bodystmt: visit(node.bodystmt) + ) + end + + # Visit a RangeNode node. + def visit_range(node) + node.copy( + left: visit(node.left), + operator: visit(node.operator), + right: visit(node.right) + ) + end + + # Visit a DynaSymbol node. + def visit_dyna_symbol(node) + node.copy(parts: visit_all(node.parts)) + end + + # Visit a Else node. + def visit_else(node) + node.copy( + keyword: visit(node.keyword), + statements: visit(node.statements) + ) + end + + # Visit a Elsif node. + def visit_elsif(node) + node.copy( + statements: visit(node.statements), + consequent: visit(node.consequent) + ) + end + + # Visit a EmbDoc node. + def visit_embdoc(node) + node.copy + end + + # Visit a EmbExprBeg node. + def visit_embexpr_beg(node) + node.copy + end + + # Visit a EmbExprEnd node. + def visit_embexpr_end(node) + node.copy + end + + # Visit a EmbVar node. + def visit_embvar(node) + node.copy + end + + # Visit a Ensure node. + def visit_ensure(node) + node.copy( + keyword: visit(node.keyword), + statements: visit(node.statements) + ) + end + + # Visit a ExcessedComma node. + def visit_excessed_comma(node) + node.copy + end + + # Visit a Field node. + def visit_field(node) + node.copy( + operator: node.operator == :"::" ? :"::" : visit(node.operator), + name: visit(node.name) + ) + end + + # Visit a FloatLiteral node. + def visit_float(node) + node.copy + end + + # Visit a FndPtn node. + def visit_fndptn(node) + node.copy( + constant: visit(node.constant), + left: visit(node.left), + values: visit_all(node.values), + right: visit(node.right) + ) + end + + # Visit a For node. + def visit_for(node) + node.copy(index: visit(node.index), statements: visit(node.statements)) + end + + # Visit a GVar node. + def visit_gvar(node) + node.copy + end + + # Visit a HashLiteral node. + def visit_hash(node) + node.copy(lbrace: visit(node.lbrace), assocs: visit_all(node.assocs)) + end + + # Visit a Heredoc node. + def visit_heredoc(node) + node.copy( + beginning: visit(node.beginning), + ending: visit(node.ending), + parts: visit_all(node.parts) + ) + end + + # Visit a HeredocBeg node. + def visit_heredoc_beg(node) + node.copy + end + + # Visit a HeredocEnd node. + def visit_heredoc_end(node) + node.copy + end + + # Visit a HshPtn node. + def visit_hshptn(node) + node.copy( + constant: visit(node.constant), + keywords: + node.keywords.map { |label, value| [visit(label), visit(value)] }, + keyword_rest: visit(node.keyword_rest) + ) + end + + # Visit a Ident node. + def visit_ident(node) + node.copy + end + + # Visit a IfNode node. + def visit_if(node) + node.copy( + predicate: visit(node.predicate), + statements: visit(node.statements), + consequent: visit(node.consequent) + ) + end + + # Visit a IfOp node. + def visit_if_op(node) + node.copy + end + + # Visit a Imaginary node. + def visit_imaginary(node) + node.copy + end + + # Visit a In node. + def visit_in(node) + node.copy( + statements: visit(node.statements), + consequent: visit(node.consequent) + ) + end + + # Visit a Int node. + def visit_int(node) + node.copy + end + + # Visit a IVar node. + def visit_ivar(node) + node.copy + end + + # Visit a Kw node. + def visit_kw(node) + node.copy + end + + # Visit a KwRestParam node. + def visit_kwrest_param(node) + node.copy(name: visit(node.name)) + end + + # Visit a Label node. + def visit_label(node) + node.copy + end + + # Visit a LabelEnd node. + def visit_label_end(node) + node.copy + end + + # Visit a Lambda node. + def visit_lambda(node) + node.copy( + params: visit(node.params), + statements: visit(node.statements) + ) + end + + # Visit a LambdaVar node. + def visit_lambda_var(node) + node.copy(params: visit(node.params), locals: visit_all(node.locals)) + end + + # Visit a LBrace node. + def visit_lbrace(node) + node.copy + end + + # Visit a LBracket node. + def visit_lbracket(node) + node.copy + end + + # Visit a LParen node. + def visit_lparen(node) + node.copy + end + + # Visit a MAssign node. + def visit_massign(node) + node.copy(target: visit(node.target)) + end + + # Visit a MethodAddBlock node. + def visit_method_add_block(node) + node.copy(call: visit(node.call), block: visit(node.block)) + end + + # Visit a MLHS node. + def visit_mlhs(node) + node.copy(parts: visit_all(node.parts)) + end + + # Visit a MLHSParen node. + def visit_mlhs_paren(node) + node.copy(contents: visit(node.contents)) + end + + # Visit a ModuleDeclaration node. + def visit_module(node) + node.copy( + constant: visit(node.constant), + bodystmt: visit(node.bodystmt) + ) + end + + # Visit a MRHS node. + def visit_mrhs(node) + node.copy(parts: visit_all(node.parts)) + end + + # Visit a Next node. + def visit_next(node) + node.copy(arguments: visit(node.arguments)) + end + + # Visit a Op node. + def visit_op(node) + node.copy + end + + # Visit a OpAssign node. + def visit_opassign(node) + node.copy(target: visit(node.target), operator: visit(node.operator)) + end + + # Visit a Params node. + def visit_params(node) + node.copy( + requireds: visit_all(node.requireds), + optionals: + node.optionals.map { |ident, value| [visit(ident), visit(value)] }, + rest: visit(node.rest), + posts: visit_all(node.posts), + keywords: + node.keywords.map { |ident, value| [visit(ident), visit(value)] }, + keyword_rest: + node.keyword_rest == :nil ? :nil : visit(node.keyword_rest), + block: visit(node.block) + ) + end + + # Visit a Paren node. + def visit_paren(node) + node.copy(lparen: visit(node.lparen), contents: visit(node.contents)) + end + + # Visit a Period node. + def visit_period(node) + node.copy + end + + # Visit a Program node. + def visit_program(node) + node.copy(statements: visit(node.statements)) + end + + # Visit a QSymbols node. + def visit_qsymbols(node) + node.copy( + beginning: visit(node.beginning), + elements: visit_all(node.elements) + ) + end + + # Visit a QSymbolsBeg node. + def visit_qsymbols_beg(node) + node.copy + end + + # Visit a QWords node. + def visit_qwords(node) + node.copy( + beginning: visit(node.beginning), + elements: visit_all(node.elements) + ) + end + + # Visit a QWordsBeg node. + def visit_qwords_beg(node) + node.copy + end + + # Visit a RationalLiteral node. + def visit_rational(node) + node.copy + end + + # Visit a RBrace node. + def visit_rbrace(node) + node.copy + end + + # Visit a RBracket node. + def visit_rbracket(node) + node.copy + end + + # Visit a Redo node. + def visit_redo(node) + node.copy + end + + # Visit a RegexpContent node. + def visit_regexp_content(node) + node.copy(parts: visit_all(node.parts)) + end + + # Visit a RegexpBeg node. + def visit_regexp_beg(node) + node.copy + end + + # Visit a RegexpEnd node. + def visit_regexp_end(node) + node.copy + end + + # Visit a RegexpLiteral node. + def visit_regexp_literal(node) + node.copy(parts: visit_all(node.parts)) + end + + # Visit a RescueEx node. + def visit_rescue_ex(node) + node.copy(variable: visit(node.variable)) + end + + # Visit a Rescue node. + def visit_rescue(node) + node.copy( + keyword: visit(node.keyword), + exception: visit(node.exception), + statements: visit(node.statements), + consequent: visit(node.consequent) + ) + end + + # Visit a RescueMod node. + def visit_rescue_mod(node) + node.copy + end + + # Visit a RestParam node. + def visit_rest_param(node) + node.copy(name: visit(node.name)) + end + + # Visit a Retry node. + def visit_retry(node) + node.copy + end + + # Visit a Return node. + def visit_return(node) + node.copy(arguments: visit(node.arguments)) + end + + # Visit a RParen node. + def visit_rparen(node) + node.copy + end + + # Visit a SClass node. + def visit_sclass(node) + node.copy(bodystmt: visit(node.bodystmt)) + end + + # Visit a Statements node. + def visit_statements(node) + node.copy(body: visit_all(node.body)) + end + + # Visit a StringContent node. + def visit_string_content(node) + node.copy(parts: visit_all(node.parts)) + end + + # Visit a StringConcat node. + def visit_string_concat(node) + node.copy(left: visit(node.left), right: visit(node.right)) + end + + # Visit a StringDVar node. + def visit_string_dvar(node) + node.copy(variable: visit(node.variable)) + end + + # Visit a StringEmbExpr node. + def visit_string_embexpr(node) + node.copy(statements: visit(node.statements)) + end + + # Visit a StringLiteral node. + def visit_string_literal(node) + node.copy(parts: visit_all(node.parts)) + end + + # Visit a Super node. + def visit_super(node) + node.copy(arguments: visit(node.arguments)) + end + + # Visit a SymBeg node. + def visit_symbeg(node) + node.copy + end + + # Visit a SymbolContent node. + def visit_symbol_content(node) + node.copy(value: visit(node.value)) + end + + # Visit a SymbolLiteral node. + def visit_symbol_literal(node) + node.copy(value: visit(node.value)) + end + + # Visit a Symbols node. + def visit_symbols(node) + node.copy( + beginning: visit(node.beginning), + elements: visit_all(node.elements) + ) + end + + # Visit a SymbolsBeg node. + def visit_symbols_beg(node) + node.copy + end + + # Visit a TLambda node. + def visit_tlambda(node) + node.copy + end + + # Visit a TLamBeg node. + def visit_tlambeg(node) + node.copy + end + + # Visit a TopConstField node. + def visit_top_const_field(node) + node.copy(constant: visit(node.constant)) + end + + # Visit a TopConstRef node. + def visit_top_const_ref(node) + node.copy(constant: visit(node.constant)) + end + + # Visit a TStringBeg node. + def visit_tstring_beg(node) + node.copy + end + + # Visit a TStringContent node. + def visit_tstring_content(node) + node.copy + end + + # Visit a TStringEnd node. + def visit_tstring_end(node) + node.copy + end + + # Visit a Not node. + def visit_not(node) + node.copy(statement: visit(node.statement)) + end + + # Visit a Unary node. + def visit_unary(node) + node.copy + end + + # Visit a Undef node. + def visit_undef(node) + node.copy(symbols: visit_all(node.symbols)) + end + + # Visit a UnlessNode node. + def visit_unless(node) + node.copy( + predicate: visit(node.predicate), + statements: visit(node.statements), + consequent: visit(node.consequent) + ) + end + + # Visit a UntilNode node. + def visit_until(node) + node.copy( + predicate: visit(node.predicate), + statements: visit(node.statements) + ) + end + + # Visit a VarField node. + def visit_var_field(node) + node.copy(value: visit(node.value)) + end + + # Visit a VarRef node. + def visit_var_ref(node) + node.copy(value: visit(node.value)) + end + + # Visit a PinnedVarRef node. + def visit_pinned_var_ref(node) + node.copy(value: visit(node.value)) + end + + # Visit a VCall node. + def visit_vcall(node) + node.copy(value: visit(node.value)) + end + + # Visit a VoidStmt node. + def visit_void_stmt(node) + node.copy + end + + # Visit a When node. + def visit_when(node) + node.copy( + arguments: visit(node.arguments), + statements: visit(node.statements), + consequent: visit(node.consequent) + ) + end + + # Visit a WhileNode node. + def visit_while(node) + node.copy( + predicate: visit(node.predicate), + statements: visit(node.statements) + ) + end + + # Visit a Word node. + def visit_word(node) + node.copy(parts: visit_all(node.parts)) + end + + # Visit a Words node. + def visit_words(node) + node.copy( + beginning: visit(node.beginning), + elements: visit_all(node.elements) + ) + end + + # Visit a WordsBeg node. + def visit_words_beg(node) + node.copy + end + + # Visit a XString node. + def visit_xstring(node) + node.copy(parts: visit_all(node.parts)) + end + + # Visit a XStringLiteral node. + def visit_xstring_literal(node) + node.copy(parts: visit_all(node.parts)) + end + + # Visit a YieldNode node. + def visit_yield(node) + node.copy(arguments: visit(node.arguments)) + end + + # Visit a ZSuper node. + def visit_zsuper(node) + node.copy + end + end + end +end diff --git a/lib/syntax_tree/node.rb b/lib/syntax_tree/node.rb index d153ef78..96241bb1 100644 --- a/lib/syntax_tree/node.rb +++ b/lib/syntax_tree/node.rb @@ -83,6 +83,20 @@ def self.fixed(line:, char:, column:) end_column: column ) end + + # A convenience method that is typically used when you don't care about the + # location of a node, but need to create a Location instance to pass to a + # constructor. + def self.default + new( + start_line: 1, + start_char: 0, + start_column: 0, + end_line: 1, + end_char: 0, + end_column: 0 + ) + end end # This is the parent node of all of the syntax tree nodes. It's pretty much @@ -112,18 +126,40 @@ def format(q) raise NotImplementedError end + def start_char + location.start_char + end + + def end_char + location.end_char + end + def pretty_print(q) - visitor = Visitor::PrettyPrintVisitor.new(q) - visitor.visit(self) + accept(PrettyPrintVisitor.new(q)) end def to_json(*opts) - visitor = Visitor::JSONVisitor.new - visitor.visit(self).to_json(*opts) + accept(JSONVisitor.new).to_json(*opts) + end + + def to_mermaid + accept(MermaidVisitor.new) end def construct_keys - PP.format(+"") { |q| Visitor::MatchVisitor.new(q).visit(self) } + PrettierPrint.format(+"") { |q| accept(MatchVisitor.new(q)) } + end + end + + # When we're implementing the === operator for a node, we oftentimes need to + # compare two arrays. We want to skip over the === definition of array and use + # our own here, so we do that using this module. + module ArrayMatch + def self.call(left, right) + left.length === right.length && + left + .zip(right) + .all? { |left_value, right_value| left_value === right_value } end end @@ -146,11 +182,11 @@ class BEGINBlock < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(lbrace:, statements:, location:, comments: []) + def initialize(lbrace:, statements:, location:) @lbrace = lbrace @statements = statements @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -161,6 +197,18 @@ def child_nodes [lbrace, statements] end + def copy(lbrace: nil, statements: nil, location: nil) + node = + BEGINBlock.new( + lbrace: lbrace || self.lbrace, + statements: statements || self.statements, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -177,13 +225,18 @@ def format(q) q.text("BEGIN ") q.format(lbrace) q.indent do - q.breakable + q.breakable_space q.format(statements) end - q.breakable + q.breakable_space q.text("}") end end + + def ===(other) + other.is_a?(BEGINBlock) && lbrace === other.lbrace && + statements === other.statements + end end # CHAR irepresents a single codepoint in the script encoding. @@ -199,10 +252,10 @@ class CHAR < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(value:, location:, comments: []) + def initialize(value:, location:) @value = value @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -213,6 +266,17 @@ def child_nodes [] end + def copy(value: nil, location: nil) + node = + CHAR.new( + value: value || self.value, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -224,10 +288,14 @@ def format(q) q.text(value) else q.text(q.quote) - q.text(value[1] == "\"" ? "\\\"" : value[1]) + q.text(value[1] == q.quote ? "\\#{q.quote}" : value[1]) q.text(q.quote) end end + + def ===(other) + other.is_a?(CHAR) && value === other.value + end end # ENDBlock represents the use of the +END+ keyword, which hooks into the @@ -249,11 +317,11 @@ class ENDBlock < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(lbrace:, statements:, location:, comments: []) + def initialize(lbrace:, statements:, location:) @lbrace = lbrace @statements = statements @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -264,6 +332,18 @@ def child_nodes [lbrace, statements] end + def copy(lbrace: nil, statements: nil, location: nil) + node = + ENDBlock.new( + lbrace: lbrace || self.lbrace, + statements: statements || self.statements, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -280,13 +360,18 @@ def format(q) q.text("END ") q.format(lbrace) q.indent do - q.breakable + q.breakable_space q.format(statements) end - q.breakable + q.breakable_space q.text("}") end end + + def ===(other) + other.is_a?(ENDBlock) && lbrace === other.lbrace && + statements === other.statements + end end # EndContent represents the use of __END__ syntax, which allows individual @@ -305,10 +390,10 @@ class EndContent < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(value:, location:, comments: []) + def initialize(value:, location:) @value = value @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -319,6 +404,17 @@ def child_nodes [] end + def copy(value: nil, location: nil) + node = + EndContent.new( + value: value || self.value, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -327,10 +423,24 @@ def deconstruct_keys(_keys) def format(q) q.text("__END__") - q.breakable(force: true) + q.breakable_force + + first = true + value.each_line(chomp: true) do |line| + if first + first = false + else + q.breakable_return + end - separator = -> { q.breakable(indent: false, force: true) } - q.seplist(value.split(/\r?\n/, -1), separator) { |line| q.text(line) } + q.text(line) + end + + q.breakable_return if value.end_with?("\n") + end + + def ===(other) + other.is_a?(EndContent) && value === other.value end end @@ -345,11 +455,12 @@ def format(q) # can either provide bare words (like the example above) or you can provide # symbols (note that this includes dynamic symbols like # :"left-#{middle}-right"). - class Alias < Node + class AliasNode < Node # Formats an argument to the alias keyword. For symbol literals it uses the # value of the symbol directly to look like bare words. class AliasArgumentFormatter - # [DynaSymbol | SymbolLiteral] the argument being passed to alias + # [Backref | DynaSymbol | GVar | SymbolLiteral] the argument being passed + # to alias attr_reader :argument def initialize(argument) @@ -373,20 +484,20 @@ def format(q) end end - # [DynaSymbol | SymbolLiteral] the new name of the method + # [DynaSymbol | GVar | SymbolLiteral] the new name of the method attr_reader :left - # [DynaSymbol | SymbolLiteral] the old name of the method + # [Backref | DynaSymbol | GVar | SymbolLiteral] the old name of the method attr_reader :right # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(left:, right:, location:, comments: []) + def initialize(left:, right:, location:) @left = left @right = right @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -397,6 +508,18 @@ def child_nodes [left, right] end + def copy(left: nil, right: nil, location: nil) + node = + AliasNode.new( + left: left || self.left, + right: right || self.right, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -412,12 +535,20 @@ def format(q) q.format(left_argument, stackable: false) q.group do q.nest(keyword.length) do - q.breakable(force: left_argument.comments.any?) + left_argument.comments.any? ? q.breakable_force : q.breakable_space q.format(AliasArgumentFormatter.new(right), stackable: false) end end end end + + def ===(other) + other.is_a?(AliasNode) && left === other.left && right === other.right + end + + def var_alias? + left.is_a?(GVar) + end end # ARef represents when you're pulling a value out of a collection at a @@ -434,7 +565,7 @@ def format(q) # collection[] # class ARef < Node - # [untyped] the value being indexed + # [Node] the value being indexed attr_reader :collection # [nil | Args] the value being passed within the brackets @@ -443,11 +574,11 @@ class ARef < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(collection:, index:, location:, comments: []) + def initialize(collection:, index:, location:) @collection = collection @index = index @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -458,6 +589,18 @@ def child_nodes [collection, index] end + def copy(collection: nil, index: nil, location: nil) + node = + ARef.new( + collection: collection || self.collection, + index: index || self.index, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -476,15 +619,20 @@ def format(q) if index q.indent do - q.breakable("") + q.breakable_empty q.format(index) end - q.breakable("") + q.breakable_empty end q.text("]") end end + + def ===(other) + other.is_a?(ARef) && collection === other.collection && + index === other.index + end end # ARefField represents assigning values into collections at specific indices. @@ -495,7 +643,7 @@ def format(q) # collection[index] = value # class ARefField < Node - # [untyped] the value being indexed + # [Node] the value being indexed attr_reader :collection # [nil | Args] the value being passed within the brackets @@ -504,11 +652,11 @@ class ARefField < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(collection:, index:, location:, comments: []) + def initialize(collection:, index:, location:) @collection = collection @index = index @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -519,6 +667,18 @@ def child_nodes [collection, index] end + def copy(collection: nil, index: nil, location: nil) + node = + ARefField.new( + collection: collection || self.collection, + index: index || self.index, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -537,15 +697,20 @@ def format(q) if index q.indent do - q.breakable("") + q.breakable_empty q.format(index) end - q.breakable("") + q.breakable_empty end q.text("]") end end + + def ===(other) + other.is_a?(ARefField) && collection === other.collection && + index === other.index + end end # ArgParen represents wrapping arguments to a method inside a set of @@ -567,10 +732,10 @@ class ArgParen < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(arguments:, location:, comments: []) + def initialize(arguments:, location:) @arguments = arguments @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -581,6 +746,17 @@ def child_nodes [arguments] end + def copy(arguments: nil, location: nil) + node = + ArgParen.new( + arguments: arguments || self.arguments, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -593,12 +769,46 @@ def format(q) return end - q.group(0, "(", ")") do + q.text("(") + q.group do q.indent do - q.breakable("") + q.breakable_empty q.format(arguments) + q.if_break { q.text(",") } if q.trailing_comma? && trailing_comma? end - q.breakable("") + q.breakable_empty + end + q.text(")") + end + + def ===(other) + other.is_a?(ArgParen) && arguments === other.arguments + end + + def arity + arguments&.arity || 0 + end + + private + + def trailing_comma? + arguments = self.arguments + return false unless arguments.is_a?(Args) + + parts = arguments.parts + if parts.last.is_a?(ArgBlock) + # If the last argument is a block, then we can't put a trailing comma + # after it without resulting in a syntax error. + false + elsif (parts.length == 1) && (part = parts.first) && + (part.is_a?(Command) || part.is_a?(CommandCall)) + # If the only argument is a command or command call, then a trailing + # comma would be parsed as part of that expression instead of on this + # one, so we don't want to add a trailing comma. + false + else + # Otherwise, we should be okay to add a trailing comma. + true end end end @@ -609,16 +819,16 @@ def format(q) # method(first, second, third) # class Args < Node - # [Array[ untyped ]] the arguments that this node wraps + # [Array[ Node ]] the arguments that this node wraps attr_reader :parts # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(parts:, location:, comments: []) + def initialize(parts:, location:) @parts = parts @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -629,6 +839,17 @@ def child_nodes parts end + def copy(parts: nil, location: nil) + node = + Args.new( + parts: parts || self.parts, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -638,6 +859,25 @@ def deconstruct_keys(_keys) def format(q) q.seplist(parts) { |part| q.format(part) } end + + def ===(other) + other.is_a?(Args) && ArrayMatch.call(parts, other.parts) + end + + def arity + parts.sum do |part| + case part + when ArgStar, ArgsForward + Float::INFINITY + when BareAssocHash + part.assocs.sum do |assoc| + assoc.is_a?(AssocSplat) ? Float::INFINITY : 1 + end + else + 1 + end + end + end end # ArgBlock represents using a block operator on an expression. @@ -645,16 +885,16 @@ def format(q) # method(&expression) # class ArgBlock < Node - # [nil | untyped] the expression being turned into a block + # [nil | Node] the expression being turned into a block attr_reader :value # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(value:, location:, comments: []) + def initialize(value:, location:) @value = value @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -665,6 +905,17 @@ def child_nodes [value] end + def copy(value: nil, location: nil) + node = + ArgBlock.new( + value: value || self.value, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -675,6 +926,10 @@ def format(q) q.text("&") q.format(value) if value end + + def ===(other) + other.is_a?(ArgBlock) && value === other.value + end end # Star represents using a splat operator on an expression. @@ -682,16 +937,16 @@ def format(q) # method(*arguments) # class ArgStar < Node - # [nil | untyped] the expression being splatted + # [nil | Node] the expression being splatted attr_reader :value # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(value:, location:, comments: []) + def initialize(value:, location:) @value = value @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -702,6 +957,17 @@ def child_nodes [value] end + def copy(value: nil, location: nil) + node = + ArgStar.new( + value: value || self.value, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -712,6 +978,10 @@ def format(q) q.text("*") q.format(value) if value end + + def ===(other) + other.is_a?(ArgStar) && value === other.value + end end # ArgsForward represents forwarding all kinds of arguments onto another method @@ -732,16 +1002,12 @@ def format(q) # The ArgsForward node appears in both the caller (the request method calls) # and the callee (the get and post definitions). class ArgsForward < Node - # [String] the value of the operator - attr_reader :value - # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(value:, location:, comments: []) - @value = value + def initialize(location:) @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -752,14 +1018,29 @@ def child_nodes [] end + def copy(location: nil) + node = ArgsForward.new(location: location || self.location) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) - { value: value, location: location, comments: comments } + { location: location, comments: comments } end def format(q) - q.text(value) + q.text("...") + end + + def ===(other) + other.is_a?(ArgsForward) + end + + def arity + Float::INFINITY end end @@ -770,6 +1051,17 @@ def format(q) # [one, two, three] # class ArrayLiteral < Node + # It's very common to use seplist with ->(q) { q.breakable_space }. We wrap + # that pattern into an object to cut down on having to create a bunch of + # lambdas all over the place. + class BreakableSpaceSeparator + def call(q) + q.breakable_space + end + end + + BREAKABLE_SPACE_SEPARATOR = BreakableSpaceSeparator.new.freeze + # Formats an array of multiple simple string literals into the %w syntax. class QWordsFormatter # [Args] the contents of the array @@ -780,10 +1072,11 @@ def initialize(contents) end def format(q) - q.group(0, "%w[", "]") do + q.text("%w[") + q.group do q.indent do - q.breakable("") - q.seplist(contents.parts, -> { q.breakable }) do |part| + q.breakable_empty + q.seplist(contents.parts, BREAKABLE_SPACE_SEPARATOR) do |part| if part.is_a?(StringLiteral) q.format(part.parts.first) else @@ -791,8 +1084,9 @@ def format(q) end end end - q.breakable("") + q.breakable_empty end + q.text("]") end end @@ -806,62 +1100,17 @@ def initialize(contents) end def format(q) - q.group(0, "%i[", "]") do + q.text("%i[") + q.group do q.indent do - q.breakable("") - q.seplist(contents.parts, -> { q.breakable }) do |part| + q.breakable_empty + q.seplist(contents.parts, BREAKABLE_SPACE_SEPARATOR) do |part| q.format(part.value) end end - q.breakable("") - end - end - end - - # Formats an array that contains only a list of variable references. To make - # things simpler, if there are a bunch, we format them all using the "fill" - # algorithm as opposed to breaking them into a ton of lines. For example, - # - # [foo, bar, baz] - # - # instead of becoming: - # - # [ - # foo, - # bar, - # baz - # ] - # - # would instead become: - # - # [ - # foo, bar, - # baz - # ] - # - # provided the line length was hit between `bar` and `baz`. - class VarRefsFormatter - # [Args] the contents of the array - attr_reader :contents - - def initialize(contents) - @contents = contents - end - - def format(q) - q.group(0, "[", "]") do - q.indent do - q.breakable("") - - separator = -> do - q.text(",") - q.fill_breakable - end - - q.seplist(contents.parts, separator) { |part| q.format(part) } - end - q.breakable("") + q.breakable_empty end + q.text("]") end end @@ -881,17 +1130,18 @@ def format(q) q.text("[") q.indent do lbracket.comments.each do |comment| - q.breakable(force: true) + q.breakable_force comment.format(q) end end - q.breakable(force: true) + q.breakable_force q.text("]") end end end - # [LBracket] the bracket that opens this array + # [nil | LBracket | QSymbolsBeg | QWordsBeg | SymbolsBeg | WordsBeg] the + # bracket that opens this array attr_reader :lbracket # [nil | Args] the contents of the array @@ -900,11 +1150,11 @@ def format(q) # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(lbracket:, contents:, location:, comments: []) + def initialize(lbracket:, contents:, location:) @lbracket = lbracket @contents = contents @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -915,6 +1165,18 @@ def child_nodes [lbracket, contents] end + def copy(lbracket: nil, contents: nil, location: nil) + node = + ArrayLiteral.new( + lbracket: lbracket || self.lbracket, + contents: contents || self.contents, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -927,19 +1189,20 @@ def deconstruct_keys(_keys) end def format(q) - if qwords? - QWordsFormatter.new(contents).format(q) - return - end + lbracket = self.lbracket + contents = self.contents - if qsymbols? - QSymbolsFormatter.new(contents).format(q) - return - end + if lbracket.is_a?(LBracket) && lbracket.comments.empty? && contents && + contents.comments.empty? && contents.parts.length > 1 + if qwords? + QWordsFormatter.new(contents).format(q) + return + end - if var_refs?(q) - VarRefsFormatter.new(contents).format(q) - return + if qsymbols? + QSymbolsFormatter.new(contents).format(q) + return + end end if empty_with_comments? @@ -952,52 +1215,43 @@ def format(q) if contents q.indent do - q.breakable("") + q.breakable_empty q.format(contents) + q.if_break { q.text(",") } if q.trailing_comma? end end - q.breakable("") + q.breakable_empty q.text("]") end end + def ===(other) + other.is_a?(ArrayLiteral) && lbracket === other.lbracket && + contents === other.contents + end + private def qwords? - lbracket.comments.empty? && contents && contents.comments.empty? && - contents.parts.length > 1 && - contents.parts.all? do |part| - case part - when StringLiteral - part.comments.empty? && part.parts.length == 1 && - part.parts.first.is_a?(TStringContent) && - !part.parts.first.value.match?(/[\s\[\]\\]/) - when CHAR - !part.value.match?(/[\[\]\\]/) - else - false - end + contents.parts.all? do |part| + case part + when StringLiteral + part.comments.empty? && part.parts.length == 1 && + part.parts.first.is_a?(TStringContent) && + !part.parts.first.value.match?(/[\s\[\]\\]/) + when CHAR + !part.value.match?(/[\[\]\\]/) + else + false end + end end def qsymbols? - lbracket.comments.empty? && contents && contents.comments.empty? && - contents.parts.length > 1 && - contents.parts.all? do |part| - part.is_a?(SymbolLiteral) && part.comments.empty? - end - end - - def var_refs?(q) - lbracket.comments.empty? && contents && contents.comments.empty? && - contents.parts.all? do |part| - part.is_a?(VarRef) && part.comments.empty? - end && - ( - contents.parts.sum { |part| part.value.value.length + 2 } > - q.maxwidth * 2 - ) + contents.parts.all? do |part| + part.is_a?(SymbolLiteral) && part.comments.empty? + end end # If we have an empty array that contains only comments, then we're going @@ -1045,10 +1299,10 @@ def format(q) end end - # [nil | VarRef] the optional constant wrapper + # [nil | VarRef | ConstPathRef] the optional constant wrapper attr_reader :constant - # [Array[ untyped ]] the regular positional arguments that this array + # [Array[ Node ]] the regular positional arguments that this array # pattern is matching against attr_reader :requireds @@ -1056,27 +1310,20 @@ def format(q) # positional arguments attr_reader :rest - # [Array[ untyped ]] the list of positional arguments occurring after the + # [Array[ Node ]] the list of positional arguments occurring after the # optional star if there is one attr_reader :posts # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize( - constant:, - requireds:, - rest:, - posts:, - location:, - comments: [] - ) + def initialize(constant:, requireds:, rest:, posts:, location:) @constant = constant @requireds = requireds @rest = rest @posts = posts @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -1087,6 +1334,26 @@ def child_nodes [constant, *requireds, rest, *posts] end + def copy( + constant: nil, + requireds: nil, + rest: nil, + posts: nil, + location: nil + ) + node = + AryPtn.new( + constant: constant || self.constant, + requireds: requireds || self.requireds, + rest: rest || self.rest, + posts: posts || self.posts, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -1101,45 +1368,41 @@ def deconstruct_keys(_keys) end def format(q) - parts = [*requireds] - parts << RestFormatter.new(rest) if rest - parts += posts + q.group do + q.format(constant) if constant + q.text("[") + q.indent do + q.breakable_empty + + parts = [*requireds] + parts << RestFormatter.new(rest) if rest + parts += posts - if constant - q.group do - q.format(constant) - q.text("[") q.seplist(parts) { |part| q.format(part) } - q.text("]") end - - return - end - - parent = q.parent - if parts.length == 1 || PATTERNS.include?(parent.class) - q.text("[") - q.seplist(parts) { |part| q.format(part) } + q.breakable_empty q.text("]") - elsif parts.empty? - q.text("[]") - else - q.group { q.seplist(parts) { |part| q.format(part) } } end end + + def ===(other) + other.is_a?(AryPtn) && constant === other.constant && + ArrayMatch.call(requireds, other.requireds) && rest === other.rest && + ArrayMatch.call(posts, other.posts) + end end # Determins if the following value should be indented or not. module AssignFormatting def self.skip_indent?(value) case value - in ArrayLiteral | HashLiteral | Heredoc | Lambda | QSymbols | QWords | - Symbols | Words + when ArrayLiteral, HashLiteral, Heredoc, Lambda, QSymbols, QWords, + Symbols, Words true - in Call[receiver:] - skip_indent?(receiver) - in DynaSymbol[quote:] - quote.start_with?("%s") + when CallNode + skip_indent?(value.receiver) + when DynaSymbol + value.quote.start_with?("%s") else false end @@ -1157,17 +1420,17 @@ class Assign < Node # to assign the result of the expression to attr_reader :target - # [untyped] the expression to be assigned + # [Node] the expression to be assigned attr_reader :value # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(target:, value:, location:, comments: []) + def initialize(target:, value:, location:) @target = target @value = value @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -1178,6 +1441,18 @@ def child_nodes [target, value] end + def copy(target: nil, value: nil, location: nil) + node = + Assign.new( + target: target || self.target, + value: value || self.value, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -1194,13 +1469,17 @@ def format(q) q.format(value) else q.indent do - q.breakable + q.breakable_space q.format(value) end end end end + def ===(other) + other.is_a?(Assign) && target === other.target && value === other.value + end + private def skip_indent? @@ -1214,22 +1493,22 @@ def skip_indent? # # { key1: value1, key2: value2 } # - # In the above example, the would be two AssocNew nodes. + # In the above example, the would be two Assoc nodes. class Assoc < Node - # [untyped] the key of this pair + # [Node] the key of this pair attr_reader :key - # [untyped] the value of this pair + # [nil | Node] the value of this pair attr_reader :value # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(key:, value:, location:, comments: []) + def initialize(key:, value:, location:) @key = key @value = value @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -1240,6 +1519,18 @@ def child_nodes [key, value] end + def copy(key: nil, value: nil, location: nil) + node = + Assoc.new( + key: key || self.key, + value: value || self.value, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -1254,10 +1545,14 @@ def format(q) end end + def ===(other) + other.is_a?(Assoc) && key === other.key && value === other.value + end + private def format_contents(q) - q.parent.format_key(q, key) + (q.parent || HashKeyFormatter::Identity.new).format_key(q, key) return unless value if key.comments.empty? && AssignFormatting.skip_indent?(value) @@ -1265,7 +1560,7 @@ def format_contents(q) q.format(value) else q.indent do - q.breakable + q.breakable_space q.format(value) end end @@ -1278,16 +1573,16 @@ def format_contents(q) # { **pairs } # class AssocSplat < Node - # [untyped] the expression that is being splatted + # [nil | Node] the expression that is being splatted attr_reader :value # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(value:, location:, comments: []) + def initialize(value:, location:) @value = value @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -1298,6 +1593,17 @@ def child_nodes [value] end + def copy(value: nil, location: nil) + node = + AssocSplat.new( + value: value || self.value, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -1306,7 +1612,11 @@ def deconstruct_keys(_keys) def format(q) q.text("**") - q.format(value) + q.format(value) if value + end + + def ===(other) + other.is_a?(AssocSplat) && value === other.value end end @@ -1322,10 +1632,10 @@ class Backref < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(value:, location:, comments: []) + def initialize(value:, location:) @value = value @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -1336,6 +1646,17 @@ def child_nodes [] end + def copy(value: nil, location: nil) + node = + Backref.new( + value: value || self.value, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -1345,6 +1666,10 @@ def deconstruct_keys(_keys) def format(q) q.text(value) end + + def ===(other) + other.is_a?(Backref) && value === other.value + end end # Backtick represents the use of the ` operator. It's usually found being used @@ -1357,10 +1682,10 @@ class Backtick < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(value:, location:, comments: []) + def initialize(value:, location:) @value = value @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -1371,6 +1696,17 @@ def child_nodes [] end + def copy(value: nil, location: nil) + node = + Backtick.new( + value: value || self.value, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -1380,6 +1716,10 @@ def deconstruct_keys(_keys) def format(q) q.text(value) end + + def ===(other) + other.is_a?(Backtick) && value === other.value + end end # This module is responsible for formatting the assocs contained within a @@ -1388,21 +1728,26 @@ def format(q) module HashKeyFormatter # Formats the keys of a hash literal using labels. class Labels - LABEL = /^[@$_A-Za-z]([_A-Za-z0-9]*)?([!_=?A-Za-z0-9])?$/ + LABEL = /\A[A-Za-z_](\w*[\w!?])?\z/.freeze def format_key(q, key) case key - in Label + when Label q.format(key) - in SymbolLiteral + when SymbolLiteral q.format(key.value) q.text(":") - in DynaSymbol[parts: [TStringContent[value: LABEL] => part]] - q.format(part) - q.text(":") - in DynaSymbol - q.format(key) - q.text(":") + when DynaSymbol + parts = key.parts + + if parts.length == 1 && (part = parts.first) && + part.is_a?(TStringContent) && part.value.match?(LABEL) + q.format(part) + q.text(":") + else + q.format(key) + q.text(":") + end end end end @@ -1412,8 +1757,7 @@ class Rockets def format_key(q, key) case key when Label - q.text(":") - q.text(key.value.chomp(":")) + q.text(":#{key.value.chomp(":")}") when DynaSymbol q.text(":") q.format(key) @@ -1425,30 +1769,74 @@ def format_key(q, key) end end - def self.for(container) - labels = - container.assocs.all? do |assoc| - next true if assoc.is_a?(AssocSplat) - - case assoc.key - when Label - true - when SymbolLiteral - # When attempting to convert a hash rocket into a hash label, - # you need to take care because only certain patterns are - # allowed. Ruby source says that they have to match keyword - # arguments to methods, but don't specify what that is. After - # some experimentation, it looks like it's: - value = assoc.key.value.value - value.match?(/^[_A-Za-z]/) && !value.end_with?("=") - when DynaSymbol - true - else - false - end + # When formatting a single assoc node without the context of the parent + # hash, this formatter is used. It uses whatever is present in the node, + # because there is nothing to be consistent with. + class Identity + def format_key(q, key) + if key.is_a?(Label) + q.format(key) + else + q.format(key) + q.text(" =>") + end + end + end + + class << self + def for(container) + (assocs = container.assocs).each_with_index do |assoc, index| + if assoc.is_a?(AssocSplat) + # Splat nodes do not impact the formatting choice. + elsif assoc.value.nil? + # If the value is nil, then it has been omitted. In this case we + # have to match the existing formatting because standardizing would + # potentially break the code. For example: + # + # { first:, "second" => "value" } + # + return Identity.new + else + # Otherwise, we need to check the type of the key. If it's a label + # or dynamic symbol, we can use labels. If it's a symbol literal + # then it needs to match a certain pattern to be used as a label. If + # it's anything else, then we need to use hash rockets. + case assoc.key + when Label, DynaSymbol + # Here labels can be used. + when SymbolLiteral + # When attempting to convert a hash rocket into a hash label, + # you need to take care because only certain patterns are + # allowed. Ruby source says that they have to match keyword + # arguments to methods, but don't specify what that is. After + # some experimentation, it looks like it's: + value = assoc.key.value.value + + if !value.match?(/^[_A-Za-z]/) || value.end_with?("=") + if omitted_value?(assocs[(index + 1)..]) + return Identity.new + else + return Rockets.new + end + end + else + if omitted_value?(assocs[(index + 1)..]) + return Identity.new + else + return Rockets.new + end + end + end end - (labels ? Labels : Rockets).new + Labels.new + end + + private + + def omitted_value?(assocs) + assocs.any? { |assoc| !assoc.is_a?(AssocSplat) && assoc.value.nil? } + end end end @@ -1465,10 +1853,10 @@ class BareAssocHash < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(assocs:, location:, comments: []) + def initialize(assocs:, location:) @assocs = assocs @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -1479,6 +1867,17 @@ def child_nodes assocs end + def copy(assocs: nil, location: nil) + node = + BareAssocHash.new( + assocs: assocs || self.assocs, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -1489,8 +1888,20 @@ def format(q) q.seplist(assocs) { |assoc| q.format(assoc) } end + def ===(other) + other.is_a?(BareAssocHash) && ArrayMatch.call(assocs, other.assocs) + end + def format_key(q, key) - (@key_formatter ||= HashKeyFormatter.for(self)).format_key(q, key) + @key_formatter ||= + case q.parents.take(3).last + when Break, Next, ReturnNode + HashKeyFormatter::Identity.new + else + HashKeyFormatter.for(self) + end + + @key_formatter.format_key(q, key) end end @@ -1507,10 +1918,10 @@ class Begin < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(bodystmt:, location:, comments: []) + def initialize(bodystmt:, location:) @bodystmt = bodystmt @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -1521,6 +1932,17 @@ def child_nodes [bodystmt] end + def copy(bodystmt: nil, location: nil) + node = + Begin.new( + bodystmt: bodystmt || self.bodystmt, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -1532,14 +1954,18 @@ def format(q) unless bodystmt.empty? q.indent do - q.breakable(force: true) unless bodystmt.statements.empty? + q.breakable_force unless bodystmt.statements.empty? q.format(bodystmt) end end - q.breakable(force: true) + q.breakable_force q.text("end") end + + def ===(other) + other.is_a?(Begin) && bodystmt === other.bodystmt + end end # PinnedBegin represents a pinning a nested statement within pattern matching. @@ -1549,16 +1975,16 @@ def format(q) # end # class PinnedBegin < Node - # [untyped] the expression being pinned + # [Node] the expression being pinned attr_reader :statement # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(statement:, location:, comments: []) + def initialize(statement:, location:) @statement = statement @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -1569,6 +1995,17 @@ def child_nodes [statement] end + def copy(statement: nil, location: nil) + node = + PinnedBegin.new( + statement: statement || self.statement, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -1580,14 +2017,18 @@ def format(q) q.text("^(") q.nest(1) do q.indent do - q.breakable("") + q.breakable_empty q.format(statement) end - q.breakable("") + q.breakable_empty q.text(")") end end end + + def ===(other) + other.is_a?(PinnedBegin) && statement === other.statement + end end # Binary represents any expression that involves two sub-expressions with an @@ -1601,24 +2042,38 @@ def format(q) # array << value # class Binary < Node - # [untyped] the left-hand side of the expression + # Since Binary's operator is a symbol, it's better to use the `name` method + # than to allocate a new string every time. This is a tiny performance + # optimization, but enough that it shows up in the profiler. Adding this in + # for older Ruby versions. + unless :+.respond_to?(:name) + using Module.new { + refine Symbol do + def name + to_s.freeze + end + end + } + end + + # [Node] the left-hand side of the expression attr_reader :left # [Symbol] the operator used between the two expressions attr_reader :operator - # [untyped] the right-hand side of the expression + # [Node] the right-hand side of the expression attr_reader :right # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(left:, operator:, right:, location:, comments: []) + def initialize(left:, operator:, right:, location:) @left = left @operator = operator @right = right @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -1629,6 +2084,19 @@ def child_nodes [left, right] end + def copy(left: nil, operator: nil, right: nil, location: nil) + node = + Binary.new( + left: left || self.left, + operator: operator || self.operator, + right: right || self.right, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -1642,74 +2110,40 @@ def deconstruct_keys(_keys) end def format(q) + left = self.left power = operator == :** q.group do q.group { q.format(left) } q.text(" ") unless power - if operator == :<< - q.text(operator.to_s) - q.text(" ") - q.format(right) - else + if operator != :<< q.group do - q.text(operator.to_s) - + q.text(operator.name) q.indent do - q.breakable(power ? "" : " ") + power ? q.breakable_empty : q.breakable_space q.format(right) end end - end - end - end - end - - # This module will remove any breakables from the list of contents so that no - # newlines are present in the output. - module RemoveBreaks - class << self - def call(doc) - marker = Object.new - stack = [doc] - - while stack.any? - doc = stack.pop - - if doc == marker - stack.pop - next - end - - stack += [doc, marker] - - case doc - when PrettyPrint::Align, PrettyPrint::Indent, PrettyPrint::Group - doc.contents.map! { |child| remove_breaks(child) } - stack += doc.contents.reverse - when PrettyPrint::IfBreak - doc.flat_contents.map! { |child| remove_breaks(child) } - stack += doc.flat_contents.reverse + elsif left.is_a?(Binary) && left.operator == :<< + q.group do + q.text(operator.name) + q.indent do + power ? q.breakable_empty : q.breakable_space + q.format(right) + end end - end - end - - private - - def remove_breaks(doc) - case doc - when PrettyPrint::Breakable - text = PrettyPrint::Text.new - text.add(object: doc.force? ? "; " : doc.separator, width: doc.width) - text - when PrettyPrint::IfBreak - PrettyPrint::Align.new(indent: 0, contents: doc.flat_contents) else - doc + q.text("<< ") + q.format(right) end end end + + def ===(other) + other.is_a?(Binary) && left === other.left && + operator === other.operator && right === other.right + end end # BlockVar represents the parameters being declared for a block. Effectively @@ -1729,11 +2163,11 @@ class BlockVar < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(params:, locals:, location:, comments: []) + def initialize(params:, locals:, location:) @params = params @locals = locals @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -1744,22 +2178,60 @@ def child_nodes [params, *locals] end + def copy(params: nil, locals: nil, location: nil) + node = + BlockVar.new( + params: params || self.params, + locals: locals || self.locals, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) { params: params, locals: locals, location: location, comments: comments } end + # Within the pipes of the block declaration, we don't want any spaces. So + # we'll separate the parameters with a comma and space but no breakables. + class Separator + def call(q) + q.text(", ") + end + end + + # We'll keep a single instance of this separator around for all block vars + # to cut down on allocations. + SEPARATOR = Separator.new.freeze + def format(q) - q.group(0, "|", "|") do - doc = q.format(params) - RemoveBreaks.call(doc) + q.text("|") + q.group do + q.remove_breaks(q.format(params)) if locals.any? q.text("; ") - q.seplist(locals, -> { q.text(", ") }) { |local| q.format(local) } + q.seplist(locals, SEPARATOR) { |local| q.format(local) } end end + q.text("|") + end + + def ===(other) + other.is_a?(BlockVar) && params === other.params && + ArrayMatch.call(locals, other.locals) + end + + # When a single required parameter is declared for a block, it gets + # automatically expanded if the values being yielded into it are an array. + def arg0? + params.requireds.length == 1 && params.optionals.empty? && + params.rest.nil? && params.posts.empty? && params.keywords.empty? && + params.keyword_rest.nil? && params.block.nil? end end @@ -1774,10 +2246,10 @@ class BlockArg < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(name:, location:, comments: []) + def initialize(name:, location:) @name = name @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -1788,6 +2260,17 @@ def child_nodes [name] end + def copy(name: nil, location: nil) + node = + BlockArg.new( + name: name || self.name, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -1798,6 +2281,10 @@ def format(q) q.text("&") q.format(name) if name end + + def ===(other) + other.is_a?(BlockArg) && name === other.name + end end # bodystmt can't actually determine its bounds appropriately because it @@ -1828,8 +2315,7 @@ def initialize( else_keyword:, else_clause:, ensure_clause:, - location:, - comments: [] + location: ) @statements = statements @rescue_clause = rescue_clause @@ -1837,10 +2323,12 @@ def initialize( @else_clause = else_clause @ensure_clause = ensure_clause @location = location - @comments = comments + @comments = [] end - def bind(start_char, start_column, end_char, end_column) + def bind(parser, start_char, start_column, end_char, end_column) + rescue_clause = self.rescue_clause + @location = Location.new( start_line: location.start_line, @@ -1851,11 +2339,10 @@ def bind(start_char, start_column, end_char, end_column) end_column: end_column ) - parts = [rescue_clause, else_clause, ensure_clause] - # Here we're going to determine the bounds for the statements - consequent = parts.compact.first + consequent = rescue_clause || else_clause || ensure_clause statements.bind( + parser, start_char, start_column, consequent ? consequent.location.start_char : end_char, @@ -1864,7 +2351,8 @@ def bind(start_char, start_column, end_char, end_column) # Next we're going to determine the rescue clause if there is one if rescue_clause - consequent = parts.drop(1).compact.first + consequent = else_clause || ensure_clause + rescue_clause.bind_end( consequent ? consequent.location.start_char : end_char, consequent ? consequent.location.start_column : end_column @@ -1884,12 +2372,35 @@ def child_nodes [statements, rescue_clause, else_keyword, else_clause, ensure_clause] end + def copy( + statements: nil, + rescue_clause: nil, + else_keyword: nil, + else_clause: nil, + ensure_clause: nil, + location: nil + ) + node = + BodyStmt.new( + statements: statements || self.statements, + rescue_clause: rescue_clause || self.rescue_clause, + else_keyword: else_keyword || self.else_keyword, + else_clause: else_clause || self.else_clause, + ensure_clause: ensure_clause || self.ensure_clause, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) { statements: statements, rescue_clause: rescue_clause, + else_keyword: else_keyword, else_clause: else_clause, ensure_clause: ensure_clause, location: location, @@ -1903,423 +2414,268 @@ def format(q) if rescue_clause q.nest(-2) do - q.breakable(force: true) + q.breakable_force q.format(rescue_clause) end end if else_clause q.nest(-2) do - q.breakable(force: true) + q.breakable_force q.format(else_keyword) end unless else_clause.empty? - q.breakable(force: true) + q.breakable_force q.format(else_clause) end end if ensure_clause q.nest(-2) do - q.breakable(force: true) + q.breakable_force q.format(ensure_clause) end end end end - end - - # Responsible for formatting either a BraceBlock or a DoBlock. - class BlockFormatter - # Formats the opening brace or keyword of a block. - class BlockOpenFormatter - # [String] the actual output that should be printed - attr_reader :text - - # [LBrace | Keyword] the node that is being represented - attr_reader :node - - def initialize(text, node) - @text = text - @node = node - end - - def comments - node.comments - end - def format(q) - q.text(text) - end + def ===(other) + other.is_a?(BodyStmt) && statements === other.statements && + rescue_clause === other.rescue_clause && + else_keyword === other.else_keyword && + else_clause === other.else_clause && + ensure_clause === other.ensure_clause end + end - # [BraceBlock | DoBlock] the block node to be formatted - attr_reader :node - - # [LBrace | Keyword] the node that opens the block - attr_reader :block_open - - # [String] the string that closes the block - attr_reader :block_close + # Formats either a Break, Next, or Return node. + class FlowControlFormatter + # [String] the keyword to print + attr_reader :keyword - # [BodyStmt | Statements] the statements inside the block - attr_reader :statements + # [Break | Next | Return] the node being formatted + attr_reader :node - def initialize(node, block_open, block_close, statements) + def initialize(keyword, node) + @keyword = keyword @node = node - @block_open = block_open - @block_close = block_close - @statements = statements end def format(q) - # If this is nested anywhere inside of a Command or CommandCall node, then - # we can't change which operators we're using for the bounds of the block. - break_opening, break_closing, flat_opening, flat_closing = - if unchangeable_bounds?(q) - [block_open.value, block_close, block_open.value, block_close] - elsif forced_do_end_bounds?(q) - %w[do end do end] - elsif forced_brace_bounds?(q) - %w[{ } { }] - else - %w[do end { }] - end - - # If the receiver of this block a Command or CommandCall node, then there - # are no parentheses around the arguments to that command, so we need to - # break the block. - receiver = q.parent.call - if receiver.is_a?(Command) || receiver.is_a?(CommandCall) - q.break_parent - format_break(q, break_opening, break_closing) + # If there are no arguments associated with this flow control, then we can + # safely just print the keyword and return. + if node.arguments.nil? + q.text(keyword) return end q.group do - q - .if_break { format_break(q, break_opening, break_closing) } - .if_flat { format_flat(q, flat_opening, flat_closing) } - end - end + q.text(keyword) - private + parts = node.arguments.parts + length = parts.length - # If this is nested anywhere inside certain nodes, then we can't change - # which operators/keywords we're using for the bounds of the block. - def unchangeable_bounds?(q) - q.parents.any? do |parent| - # If we hit a statements, then we're safe to use whatever since we - # know for certain we're going to get split over multiple lines - # anyway. - break false if parent.is_a?(Statements) + if length == 0 + # Here there are no arguments at all, so we're not going to print + # anything. This would be like if we had: + # + # break + # + elsif length >= 2 + # If there are multiple arguments, format them all. If the line is + # going to break into multiple, then use brackets to start and end the + # expression. + format_arguments(q, " [", "]") + else + # If we get here, then we're formatting a single argument to the flow + # control keyword. + part = parts.first - [Command, CommandCall].include?(parent.class) + case part + when Paren + statements = part.contents.body + + if statements.length == 1 + statement = statements.first + + if statement.is_a?(ArrayLiteral) + contents = statement.contents + + if contents && contents.parts.length >= 2 + # Here we have a single argument that is a set of parentheses + # wrapping an array literal that has at least 2 elements. + # We're going to print the contents of the array directly. + # This would be like if we had: + # + # break([1, 2, 3]) + # + # which we will print as: + # + # break 1, 2, 3 + # + q.text(" ") + format_array_contents(q, statement) + else + # Here we have a single argument that is a set of parentheses + # wrapping an array literal that has 0 or 1 elements. We're + # going to skip the parentheses but print the array itself. + # This would be like if we had: + # + # break([1]) + # + # which we will print as: + # + # break [1] + # + q.text(" ") + q.format(statement) + end + elsif skip_parens?(statement) + # Here we have a single argument that is a set of parentheses + # that themselves contain a single statement. That statement is + # a simple value that we can skip the parentheses for. This + # would be like if we had: + # + # break(1) + # + # which we will print as: + # + # break 1 + # + q.text(" ") + q.format(statement) + else + # Here we have a single argument that is a set of parentheses. + # We're going to print the parentheses themselves as if they + # were the set of arguments. This would be like if we had: + # + # break(foo.bar) + # + q.format(part) + end + else + q.format(part) + end + when ArrayLiteral + contents = part.contents + + if contents && contents.parts.length >= 2 + # Here there is a single argument that is an array literal with at + # least two elements. We skip directly into the array literal's + # elements in order to print the contents. This would be like if + # we had: + # + # break [1, 2, 3] + # + # which we will print as: + # + # break 1, 2, 3 + # + q.text(" ") + format_array_contents(q, part) + else + # Here there is a single argument that is an array literal with 0 + # or 1 elements. In this case we're going to print the array as it + # is because skipping the brackets would change the remaining. + # This would be like if we had: + # + # break [] + # break [1] + # + q.text(" ") + q.format(part) + end + else + # Here there is a single argument that hasn't matched one of our + # previous cases. We're going to print the argument as it is. This + # would be like if we had: + # + # break foo + # + format_arguments(q, "(", ")") + end + end end end - # If we're a sibling of a control-flow keyword, then we're going to have to - # use the do..end bounds. - def forced_do_end_bounds?(q) - [Break, Next, Return, Super].include?(q.parent.call.class) - end - - # If we're the predicate of a loop or conditional, then we're going to have - # to go with the {..} bounds. - def forced_brace_bounds?(q) - parents = q.parents.to_a - parents.each_with_index.any? do |parent, index| - # If we hit certain breakpoints then we know we're safe. - break false if [Paren, Statements].include?(parent.class) - - [ - If, - IfMod, - IfOp, - Unless, - UnlessMod, - While, - WhileMod, - Until, - UntilMod - ].include?(parent.class) && parent.predicate == parents[index - 1] - end - end - - def format_break(q, opening, closing) - q.text(" ") - q.format(BlockOpenFormatter.new(opening, block_open), stackable: false) - - if node.block_var - q.text(" ") - q.format(node.block_var) - end + private - unless statements.empty? - q.indent do - q.breakable - q.format(statements) - end + def format_array_contents(q, array) + q.if_break { q.text("[") } + q.indent do + q.breakable_empty + q.format(array.contents) end - - q.breakable - q.text(closing) + q.breakable_empty + q.if_break { q.text("]") } end - def format_flat(q, opening, closing) - q.text(" ") - q.format(BlockOpenFormatter.new(opening, block_open), stackable: false) - - if node.block_var - q.breakable - q.format(node.block_var) - q.breakable + def format_arguments(q, opening, closing) + q.if_break { q.text(opening) } + q.indent do + q.breakable_space + q.format(node.arguments) end + q.breakable_empty + q.if_break { q.text(closing) } + end - if statements.empty? - q.text(" ") if opening == "do" + def skip_parens?(node) + case node + when FloatLiteral, Imaginary, Int, RationalLiteral + true + when VarRef + case node.value + when Const, CVar, GVar, IVar, Kw + true + else + false + end else - q.breakable unless node.block_var - q.format(statements) - q.breakable + false end - - q.text(closing) end end - # BraceBlock represents passing a block to a method call using the { } - # operators. + # Break represents using the +break+ keyword. # - # method { |variable| variable + 1 } + # break # - class BraceBlock < Node - # [LBrace] the left brace that opens this block - attr_reader :lbrace - - # [nil | BlockVar] the optional set of parameters to the block - attr_reader :block_var - - # [Statements] the list of expressions to evaluate within the block - attr_reader :statements + # It can also optionally accept arguments, as in: + # + # break 1 + # + class Break < Node + # [Args] the arguments being sent to the keyword + attr_reader :arguments # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(lbrace:, block_var:, statements:, location:, comments: []) - @lbrace = lbrace - @block_var = block_var - @statements = statements + def initialize(arguments:, location:) + @arguments = arguments @location = location - @comments = comments + @comments = [] end def accept(visitor) - visitor.visit_brace_block(self) + visitor.visit_break(self) end def child_nodes - [lbrace, block_var, statements] + [arguments] end - alias deconstruct child_nodes + def copy(arguments: nil, location: nil) + node = + Break.new( + arguments: arguments || self.arguments, + location: location || self.location + ) - def deconstruct_keys(_keys) - { - lbrace: lbrace, - block_var: block_var, - statements: statements, - location: location, - comments: comments - } - end - - def format(q) - BlockFormatter.new(self, lbrace, "}", statements).format(q) - end - end - - # Formats either a Break, Next, or Return node. - class FlowControlFormatter - # [String] the keyword to print - attr_reader :keyword - - # [Break | Next | Return] the node being formatted - attr_reader :node - - def initialize(keyword, node) - @keyword = keyword - @node = node - end - - def format(q) - q.group do - q.text(keyword) - - case node.arguments.parts - in [] - # Here there are no arguments at all, so we're not going to print - # anything. This would be like if we had: - # - # break - # - in [Paren[ - contents: { - body: [ArrayLiteral[contents: { parts: [_, _, *] }] => array] - } - ]] - # Here we have a single argument that is a set of parentheses wrapping - # an array literal that has at least 2 elements. We're going to print - # the contents of the array directly. This would be like if we had: - # - # break([1, 2, 3]) - # - # which we will print as: - # - # break 1, 2, 3 - # - q.text(" ") - format_array_contents(q, array) - in [Paren[contents: { body: [ArrayLiteral => statement] }]] - # Here we have a single argument that is a set of parentheses wrapping - # an array literal that has 0 or 1 elements. We're going to skip the - # parentheses but print the array itself. This would be like if we - # had: - # - # break([1]) - # - # which we will print as: - # - # break [1] - # - q.text(" ") - q.format(statement) - in [Paren[contents: { body: [statement] }]] if skip_parens?(statement) - # Here we have a single argument that is a set of parentheses that - # themselves contain a single statement. That statement is a simple - # value that we can skip the parentheses for. This would be like if we - # had: - # - # break(1) - # - # which we will print as: - # - # break 1 - # - q.text(" ") - q.format(statement) - in [Paren => part] - # Here we have a single argument that is a set of parentheses. We're - # going to print the parentheses themselves as if they were the set of - # arguments. This would be like if we had: - # - # break(foo.bar) - # - q.format(part) - in [ArrayLiteral[contents: { parts: [_, _, *] }] => array] - # Here there is a single argument that is an array literal with at - # least two elements. We skip directly into the array literal's - # elements in order to print the contents. This would be like if we - # had: - # - # break [1, 2, 3] - # - # which we will print as: - # - # break 1, 2, 3 - # - q.text(" ") - format_array_contents(q, array) - in [ArrayLiteral => part] - # Here there is a single argument that is an array literal with 0 or 1 - # elements. In this case we're going to print the array as it is - # because skipping the brackets would change the remaining. This would - # be like if we had: - # - # break [] - # break [1] - # - q.text(" ") - q.format(part) - in [_] - # Here there is a single argument that hasn't matched one of our - # previous cases. We're going to print the argument as it is. This - # would be like if we had: - # - # break foo - # - format_arguments(q, "(", ")") - else - # If there are multiple arguments, format them all. If the line is - # going to break into multiple, then use brackets to start and end the - # expression. - format_arguments(q, " [", "]") - end - end - end - - private - - def format_array_contents(q, array) - q.if_break { q.text("[") } - q.indent do - q.breakable("") - q.format(array.contents) - end - q.breakable("") - q.if_break { q.text("]") } - end - - def format_arguments(q, opening, closing) - q.if_break { q.text(opening) } - q.indent do - q.breakable(" ") - q.format(node.arguments) - end - q.breakable("") - q.if_break { q.text(closing) } - end - - def skip_parens?(node) - case node - in FloatLiteral | Imaginary | Int | RationalLiteral - true - in VarRef[value: Const | CVar | GVar | IVar | Kw] - true - else - false - end - end - end - - # Break represents using the +break+ keyword. - # - # break - # - # It can also optionally accept arguments, as in: - # - # break 1 - # - class Break < Node - # [Args] the arguments being sent to the keyword - attr_reader :arguments - - # [Array[ Comment | EmbDoc ]] the comments attached to this node - attr_reader :comments - - def initialize(arguments:, location:, comments: []) - @arguments = arguments - @location = location - @comments = comments - end - - def accept(visitor) - visitor.visit_break(self) - end - - def child_nodes - [arguments] + node.comments.concat(comments.map(&:copy)) + node end alias deconstruct child_nodes @@ -2331,6 +2687,10 @@ def deconstruct_keys(_keys) def format(q) FlowControlFormatter.new("break", self).format(q) end + + def ===(other) + other.is_a?(Break) && arguments === other.arguments + end end # Wraps a call operator (which can be a string literal :: or an Op node or a @@ -2348,8 +2708,11 @@ def comments end def format(q) - if operator == :"::" || (operator.is_a?(Op) && operator.value == "::") + case operator + when :"::" q.text(".") + when Op + operator.value == "::" ? q.text(".") : operator.format(q) else operator.format(q) end @@ -2371,7 +2734,7 @@ def format(q) # Of course there are a lot of caveats to that, including trailing operators # when necessary, where comments are places, how blocks are aligned, etc. class CallChainFormatter - # [Call | MethodAddBlock] the top of the call chain + # [CallNode | MethodAddBlock] the top of the call chain attr_reader :node def initialize(node) @@ -2385,13 +2748,30 @@ def format(q) # First, walk down the chain until we get to the point where we're not # longer at a chainable node. loop do - case children.last - in Call[receiver: Call] - children << children.last.receiver - in Call[receiver: MethodAddBlock[call: Call]] - children << children.last.receiver - in MethodAddBlock[call: Call] - children << children.last.call + case (child = children.last) + when CallNode + case (receiver = child.receiver) + when CallNode + if receiver.receiver.nil? + break + else + children << receiver + end + when MethodAddBlock + if (call = receiver.call).is_a?(CallNode) && !call.receiver.nil? + children << receiver + else + break + end + else + break + end + when MethodAddBlock + if (call = child.call).is_a?(CallNode) && !call.receiver.nil? + children << call + else + break + end else break end @@ -2405,15 +2785,14 @@ def format(q) # https://github.com/prettier/plugin-ruby/issues/863. parents = q.parents.take(4) if (parent = parents[2]) - # If we're at a do_block, then we want to go one more level up. This is - # because do blocks have BodyStmt nodes instead of just Statements - # nodes. - parent = parents[3] if parent.is_a?(DoBlock) + # If we're at a block with the `do` keywords, then we want to go one + # more level up. This is because do blocks have BodyStmt nodes instead + # of just Statements nodes. + parent = parents[3] if parent.is_a?(BlockNode) && parent.keywords? - case parent - in MethodAddBlock[call: FCall[value: { value: "sig" }]] + if parent.is_a?(MethodAddBlock) && + (call = parent.call).is_a?(CallNode) && call.message.value == "sig" threshold = 2 - else end end @@ -2435,13 +2814,13 @@ def format_chain(q, children) empty_except_last = children .drop(1) - .all? { |child| child.is_a?(Call) && child.arguments.nil? } + .all? { |child| child.is_a?(CallNode) && child.arguments.nil? } # Here, we're going to add all of the children onto the stack of the # formatter so it's as if we had descending normally into them. This is # necessary so they can check their parents as normal. q.stack.concat(children) - q.format(children.last.receiver) + q.format(children.last.receiver) if children.last.receiver q.group do if attach_directly?(children.last) @@ -2456,20 +2835,21 @@ def format_chain(q, children) skip_operator = false while (child = children.pop) - case child - in Call[ - receiver: Call[message: { value: "where" }], - message: { value: "not" } - ] - # This is very specialized behavior wherein we group - # .where.not calls together because it looks better. For more - # information, see - # https://github.com/prettier/plugin-ruby/issues/862. - in Call - # If we're at a Call node and not a MethodAddBlock node in the - # chain then we're going to add a newline so it indents properly. - q.breakable("") - else + if child.is_a?(CallNode) + if (receiver = child.receiver).is_a?(CallNode) && + (receiver.message != :call) && + (receiver.message.value == "where") && + (child.message != :call && child.message.value == "not") + # This is very specialized behavior wherein we group + # .where.not calls together because it looks better. For more + # information, see + # https://github.com/prettier/plugin-ruby/issues/862. + else + # If we're at a Call node and not a MethodAddBlock node in the + # chain then we're going to add a newline so it indents + # properly. + q.breakable_empty + end end format_child( @@ -2482,9 +2862,13 @@ def format_chain(q, children) # If the parent call node has a comment on the message then we need # to print the operator trailing in order to keep it working. - case children.last - in Call[message: { comments: [_, *] }, operator:] - q.format(CallOperatorFormatter.new(operator)) + last_child = children.last + if last_child.is_a?(CallNode) && last_child.message != :call && + ( + (last_child.message.comments.any? && last_child.operator) || + (last_child.operator && last_child.operator.comments.any?) + ) + q.format(CallOperatorFormatter.new(last_child.operator)) skip_operator = true else skip_operator = false @@ -2499,18 +2883,23 @@ def format_chain(q, children) if empty_except_last case node - in Call + when CallNode node.format_arguments(q) - in MethodAddBlock[block:] - q.format(block) + when MethodAddBlock + q.format(node.block) end end end def self.chained?(node) + return false if ENV["STREE_FAST_FORMAT"] + case node - in Call | MethodAddBlock[call: Call] - true + when CallNode + !node.receiver.nil? + when MethodAddBlock + call = node.call + call.is_a?(CallNode) && !call.receiver.nil? else false end @@ -2522,9 +2911,13 @@ def self.chained?(node) # want to indent the first call. So we'll pop off the first children and # format it separately here. def attach_directly?(node) - [ArrayLiteral, HashLiteral, Heredoc, If, Unless, XStringLiteral].include?( - node.receiver.class - ) + case node.receiver + when ArrayLiteral, HashLiteral, Heredoc, IfNode, UnlessNode, + XStringLiteral + true + else + false + end end def format_child( @@ -2536,15 +2929,15 @@ def format_child( ) # First, format the actual contents of the child. case child - in Call + when CallNode q.group do - unless skip_operator + if !skip_operator && child.operator q.format(CallOperatorFormatter.new(child.operator)) end q.format(child.message) if child.message != :call child.format_arguments(q) unless skip_attached end - in MethodAddBlock + when MethodAddBlock q.format(child.block) unless skip_attached end @@ -2552,7 +2945,7 @@ def format_child( # them out here since we're bypassing the normal comment printing. if child.comments.any? && !skip_comments child.comments.each do |comment| - comment.inline? ? q.text(" ") : q.breakable + comment.inline? ? q.text(" ") : q.breakable_space comment.format(q) end @@ -2561,15 +2954,15 @@ def format_child( end end - # Call represents a method call. + # CallNode represents a method call. # # receiver.message # - class Call < Node - # [untyped] the receiver of the method call + class CallNode < Node + # [nil | Node] the receiver of the method call attr_reader :receiver - # [:"::" | Op | Period] the operator being used to send the message + # [nil | :"::" | Op | Period] the operator being used to send the message attr_reader :operator # [:call | Backtick | Const | Ident | Op] the message being sent @@ -2581,20 +2974,13 @@ class Call < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize( - receiver:, - operator:, - message:, - arguments:, - location:, - comments: [] - ) + def initialize(receiver:, operator:, message:, arguments:, location:) @receiver = receiver @operator = operator @message = message @arguments = arguments @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -2610,6 +2996,26 @@ def child_nodes ] end + def copy( + receiver: nil, + operator: nil, + message: nil, + arguments: nil, + location: nil + ) + node = + CallNode.new( + receiver: receiver || self.receiver, + operator: operator || self.operator, + message: message || self.message, + arguments: arguments || self.arguments, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -2624,30 +3030,60 @@ def deconstruct_keys(_keys) end def format(q) - # If we're at the top of a call chain, then we're going to do some - # specialized printing in case we can print it nicely. We _only_ do this - # at the top of the chain to avoid weird recursion issues. - if !CallChainFormatter.chained?(q.parent) && - CallChainFormatter.chained?(receiver) - q.group do - q - .if_break { CallChainFormatter.new(self).format(q) } - .if_flat { format_contents(q) } + if receiver + # If we're at the top of a call chain, then we're going to do some + # specialized printing in case we can print it nicely. We _only_ do this + # at the top of the chain to avoid weird recursion issues. + if CallChainFormatter.chained?(receiver) && + !CallChainFormatter.chained?(q.parent) + q.group do + q + .if_break { CallChainFormatter.new(self).format(q) } + .if_flat { format_contents(q) } + end + else + format_contents(q) end else - format_contents(q) + q.format(message) + + # Note that this explicitly leaves parentheses in place even if they are + # empty. There are two reasons we would need to do this. The first is if + # we're calling something that looks like a constant, as in: + # + # Foo() + # + # In this case if we remove the parentheses then this becomes a constant + # reference and not a method call. The second is if we're calling a + # method that is the same name as a local variable that is in scope, as + # in: + # + # foo = foo() + # + # In this case we have to keep the parentheses or else it treats this + # like assigning nil to the local variable. Note that we could attempt + # to be smarter about this by tracking the local variables that are in + # scope, but for now it's simpler and more efficient to just leave the + # parentheses in place. + q.format(arguments) if arguments end end + def ===(other) + other.is_a?(CallNode) && receiver === other.receiver && + operator === other.operator && message === other.message && + arguments === other.arguments + end + + # Print out the arguments to this call. If there are no arguments, then do + # nothing. def format_arguments(q) case arguments - in ArgParen + when ArgParen q.format(arguments) - in Args + when Args q.text(" ") q.format(arguments) - else - # Do nothing if there are no arguments. end end @@ -2664,7 +3100,7 @@ def format_contents(q) q.group do q.indent do if receiver.comments.any? || call_operator.comments.any? - q.breakable(force: true) + q.breakable_force end if call_operator.comments.empty? @@ -2678,6 +3114,10 @@ def format_contents(q) end end end + + def arity + arguments&.arity || 0 + end end # Case represents the beginning of a case chain. @@ -2695,7 +3135,7 @@ class Case < Node # [Kw] the keyword that opens this expression attr_reader :keyword - # [nil | untyped] optional value being switched on + # [nil | Node] optional value being switched on attr_reader :value # [In | When] the next clause in the chain @@ -2704,12 +3144,12 @@ class Case < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(keyword:, value:, consequent:, location:, comments: []) + def initialize(keyword:, value:, consequent:, location:) @keyword = keyword @value = value @consequent = consequent @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -2720,6 +3160,19 @@ def child_nodes [keyword, value, consequent] end + def copy(keyword: nil, value: nil, consequent: nil, location: nil) + node = + Case.new( + keyword: keyword || self.keyword, + value: value || self.value, + consequent: consequent || self.consequent, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -2741,13 +3194,18 @@ def format(q) q.format(value) end - q.breakable(force: true) + q.breakable_force q.format(consequent) - q.breakable(force: true) + q.breakable_force q.text("end") end end + + def ===(other) + other.is_a?(Case) && keyword === other.keyword && value === other.value && + consequent === other.consequent + end end # RAssign represents a single-line pattern match. @@ -2756,25 +3214,25 @@ def format(q) # value => pattern # class RAssign < Node - # [untyped] the left-hand expression + # [Node] the left-hand expression attr_reader :value # [Kw | Op] the operator being used to match against the pattern, which is # either => or in attr_reader :operator - # [untyped] the pattern on the right-hand side of the expression + # [Node] the pattern on the right-hand side of the expression attr_reader :pattern # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(value:, operator:, pattern:, location:, comments: []) + def initialize(value:, operator:, pattern:, location:) @value = value @operator = operator @pattern = pattern @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -2785,6 +3243,19 @@ def child_nodes [value, operator, pattern] end + def copy(value: nil, operator: nil, pattern: nil, location: nil) + node = + RAssign.new( + value: value || self.value, + operator: operator || self.operator, + pattern: pattern || self.pattern, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -2802,14 +3273,26 @@ def format(q) q.format(value) q.text(" ") q.format(operator) - q.group do - q.indent do - q.breakable - q.format(pattern) + + case pattern + when AryPtn, FndPtn, HshPtn + q.text(" ") + q.format(pattern) + else + q.group do + q.indent do + q.breakable_space + q.format(pattern) + end end end end end + + def ===(other) + other.is_a?(RAssign) && value === other.value && + operator === other.operator && pattern === other.pattern + end end # Class represents defining a class using the +class+ keyword. @@ -2849,7 +3332,7 @@ class ClassDeclaration < Node # defined attr_reader :constant - # [nil | untyped] the optional superclass declaration + # [nil | Node] the optional superclass declaration attr_reader :superclass # [BodyStmt] the expressions to execute within the context of the class @@ -2858,12 +3341,12 @@ class ClassDeclaration < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(constant:, superclass:, bodystmt:, location:, comments: []) + def initialize(constant:, superclass:, bodystmt:, location:) @constant = constant @superclass = superclass @bodystmt = bodystmt @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -2874,6 +3357,19 @@ def child_nodes [constant, superclass, bodystmt] end + def copy(constant: nil, superclass: nil, bodystmt: nil, location: nil) + node = + ClassDeclaration.new( + constant: constant || self.constant, + superclass: superclass || self.superclass, + bodystmt: bodystmt || self.bodystmt, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -2887,38 +3383,45 @@ def deconstruct_keys(_keys) end def format(q) - declaration = -> do - q.group do - q.text("class ") - q.format(constant) - - if superclass - q.text(" < ") - q.format(superclass) - end - end - end - if bodystmt.empty? q.group do - declaration.call - q.breakable(force: true) + format_declaration(q) + q.breakable_force q.text("end") end else q.group do - declaration.call + format_declaration(q) q.indent do - q.breakable(force: true) + q.breakable_force q.format(bodystmt) end - q.breakable(force: true) + q.breakable_force q.text("end") end end end + + def ===(other) + other.is_a?(ClassDeclaration) && constant === other.constant && + superclass === other.superclass && bodystmt === other.bodystmt + end + + private + + def format_declaration(q) + q.group do + q.text("class ") + q.format(constant) + + if superclass + q.text(" < ") + q.format(superclass) + end + end + end end # Comma represents the use of the , operator. @@ -2939,11 +3442,19 @@ def child_nodes [] end + def copy(value: nil, location: nil) + Comma.new(value: value || self.value, location: location || self.location) + end + alias deconstruct child_nodes def deconstruct_keys(_keys) { value: value, location: location } end + + def ===(other) + other.is_a?(Comma) && value === other.value + end end # Command represents a method call with arguments and no parentheses. Note @@ -2959,14 +3470,18 @@ class Command < Node # [Args] the arguments being sent with the message attr_reader :arguments + # [nil | BlockNode] the optional block being passed to the method + attr_reader :block + # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(message:, arguments:, location:, comments: []) + def initialize(message:, arguments:, block:, location:) @message = message @arguments = arguments + @block = block @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -2974,7 +3489,20 @@ def accept(visitor) end def child_nodes - [message, arguments] + [message, arguments, block] + end + + def copy(message: nil, arguments: nil, block: nil, location: nil) + node = + Command.new( + message: message || self.message, + arguments: arguments || self.arguments, + block: block || self.block, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node end alias deconstruct child_nodes @@ -2983,6 +3511,7 @@ def deconstruct_keys(_keys) { message: message, arguments: arguments, + block: block, location: location, comments: comments } @@ -2993,20 +3522,47 @@ def format(q) q.format(message) align(q, self) { q.format(arguments) } end + + q.format(block) if block + end + + def ===(other) + other.is_a?(Command) && message === other.message && + arguments === other.arguments && block === other.block + end + + def arity + arguments.arity end private def align(q, node, &block) - case node.arguments - in Args[parts: [Def | Defs | DefEndless]] - q.text(" ") - yield - in Args[parts: [IfOp]] - q.if_flat { q.text(" ") } - yield - in Args[parts: [Command => command]] - align(q, command, &block) + arguments = node.arguments + + if arguments.is_a?(Args) + parts = arguments.parts + + if parts.size == 1 + part = parts.first + + case part + when DefNode + q.text(" ") + yield + when IfOp + q.if_flat { q.text(" ") } + yield + when Command + align(q, part, &block) + else + q.text(" ") + q.nest(message.value.length + 1) { yield } + end + else + q.text(" ") + q.nest(message.value.length + 1) { yield } + end else q.text(" ") q.nest(message.value.length + 1) { yield } @@ -3020,18 +3576,21 @@ def align(q, node, &block) # object.method argument # class CommandCall < Node - # [untyped] the receiver of the message + # [nil | Node] the receiver of the message attr_reader :receiver - # [:"::" | Op | Period] the operator used to send the message + # [nil | :"::" | Op | Period] the operator used to send the message attr_reader :operator - # [Const | Ident | Op] the message being send + # [:call | Const | Ident | Op] the message being send attr_reader :message - # [nil | Args] the arguments going along with the message + # [nil | Args | ArgParen] the arguments going along with the message attr_reader :arguments + # [nil | BlockNode] the block associated with this method call + attr_reader :block + # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments @@ -3040,15 +3599,16 @@ def initialize( operator:, message:, arguments:, - location:, - comments: [] + block:, + location: ) @receiver = receiver @operator = operator @message = message @arguments = arguments + @block = block @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -3056,7 +3616,29 @@ def accept(visitor) end def child_nodes - [receiver, message, arguments] + [receiver, message, arguments, block] + end + + def copy( + receiver: nil, + operator: nil, + message: nil, + arguments: nil, + block: nil, + location: nil + ) + node = + CommandCall.new( + receiver: receiver || self.receiver, + operator: operator || self.operator, + message: message || self.message, + arguments: arguments || self.arguments, + block: block || self.block, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node end alias deconstruct child_nodes @@ -3067,60 +3649,67 @@ def deconstruct_keys(_keys) operator: operator, message: message, arguments: arguments, + block: block, location: location, comments: comments } end def format(q) + message = self.message + arguments = self.arguments + block = self.block + q.group do doc = q.nest(0) do q.format(receiver) - q.format(CallOperatorFormatter.new(operator), stackable: false) - q.format(message) + + # If there are leading comments on the message then we know we have + # a newline in the source that is forcing these things apart. In + # this case we will have to use a trailing operator. + if message != :call && message.comments.any?(&:leading?) + q.format(CallOperatorFormatter.new(operator), stackable: false) + q.indent do + q.breakable_empty + q.format(message) + end + else + q.format(CallOperatorFormatter.new(operator), stackable: false) + q.format(message) + end end - case arguments - in Args[parts: [IfOp]] - q.if_flat { q.text(" ") } - q.format(arguments) - in Args - q.text(" ") - q.nest(argument_alignment(q, doc)) { q.format(arguments) } - else - # If there are no arguments, print nothing. + # Format the arguments for this command call here. If there are no + # arguments, then print nothing. + if arguments + parts = arguments.parts + + if parts.length == 1 && parts.first.is_a?(IfOp) + q.if_flat { q.text(" ") } + q.format(arguments) + else + q.text(" ") + q.nest(argument_alignment(q, doc)) { q.format(arguments) } + end end end - end - - private - # This is a somewhat naive method that is attempting to sum up the width of - # the doc nodes that make up the given doc node. This is used to align - # content. - def doc_width(parent) - queue = [parent] - width = 0 - - until queue.empty? - doc = queue.shift + q.format(block) if block + end - case doc - when PrettyPrint::Text - width += doc.width - when PrettyPrint::Indent, PrettyPrint::Align, PrettyPrint::Group - queue = doc.contents + queue - when PrettyPrint::IfBreak - queue = doc.break_contents + queue - when PrettyPrint::Breakable - width = 0 - end - end + def ===(other) + other.is_a?(CommandCall) && receiver === other.receiver && + operator === other.operator && message === other.message && + arguments === other.arguments && block === other.block + end - width + def arity + arguments&.arity || 0 end + private + def argument_alignment(q, doc) # Very special handling case for rspec matchers. In general with rspec # matchers you expect to see something like: @@ -3138,7 +3727,7 @@ def argument_alignment(q, doc) if %w[to not_to to_not].include?(message.value) 0 else - width = doc_width(doc) + 1 + width = q.last_position(doc) + 1 width > (q.maxwidth / 2) ? 0 : width end end @@ -3183,7 +3772,7 @@ def trailing? end def ignore? - value[1..].strip == "stree-ignore" + value.match?(/\A#\s*stree-ignore\s*\z/) end def comments @@ -3198,6 +3787,14 @@ def child_nodes [] end + def copy(value: nil, inline: nil, location: nil) + Comment.new( + value: value || self.value, + inline: inline || self.inline, + location: location || self.location + ) + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -3207,6 +3804,10 @@ def deconstruct_keys(_keys) def format(q) q.text(value) end + + def ===(other) + other.is_a?(Comment) && value === other.value && inline === other.inline + end end # Const represents a literal value that _looks_ like a constant. This could @@ -3230,10 +3831,10 @@ class Const < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(value:, location:, comments: []) + def initialize(value:, location:) @value = value @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -3244,6 +3845,17 @@ def child_nodes [] end + def copy(value: nil, location: nil) + node = + Const.new( + value: value || self.value, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -3253,6 +3865,10 @@ def deconstruct_keys(_keys) def format(q) q.text(value) end + + def ===(other) + other.is_a?(Const) && value === other.value + end end # ConstPathField represents the child node of some kind of assignment. It @@ -3262,7 +3878,7 @@ def format(q) # object::Const = value # class ConstPathField < Node - # [untyped] the source of the constant + # [Node] the source of the constant attr_reader :parent # [Const] the constant itself @@ -3271,11 +3887,11 @@ class ConstPathField < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(parent:, constant:, location:, comments: []) + def initialize(parent:, constant:, location:) @parent = parent @constant = constant @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -3286,6 +3902,18 @@ def child_nodes [parent, constant] end + def copy(parent: nil, constant: nil, location: nil) + node = + ConstPathField.new( + parent: parent || self.parent, + constant: constant || self.constant, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -3302,6 +3930,11 @@ def format(q) q.text("::") q.format(constant) end + + def ===(other) + other.is_a?(ConstPathField) && parent === other.parent && + constant === other.constant + end end # ConstPathRef represents referencing a constant by a path. @@ -3309,7 +3942,7 @@ def format(q) # object::Const # class ConstPathRef < Node - # [untyped] the source of the constant + # [Node] the source of the constant attr_reader :parent # [Const] the constant itself @@ -3318,11 +3951,11 @@ class ConstPathRef < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(parent:, constant:, location:, comments: []) + def initialize(parent:, constant:, location:) @parent = parent @constant = constant @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -3333,6 +3966,18 @@ def child_nodes [parent, constant] end + def copy(parent: nil, constant: nil, location: nil) + node = + ConstPathRef.new( + parent: parent || self.parent, + constant: constant || self.constant, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -3349,6 +3994,11 @@ def format(q) q.text("::") q.format(constant) end + + def ===(other) + other.is_a?(ConstPathRef) && parent === other.parent && + constant === other.constant + end end # ConstRef represents the name of the constant being used in a class or module @@ -3364,10 +4014,10 @@ class ConstRef < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(constant:, location:, comments: []) + def initialize(constant:, location:) @constant = constant @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -3378,6 +4028,17 @@ def child_nodes [constant] end + def copy(constant: nil, location: nil) + node = + ConstRef.new( + constant: constant || self.constant, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -3387,6 +4048,10 @@ def deconstruct_keys(_keys) def format(q) q.format(constant) end + + def ===(other) + other.is_a?(ConstRef) && constant === other.constant + end end # CVar represents the use of a class variable. @@ -3400,10 +4065,10 @@ class CVar < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(value:, location:, comments: []) + def initialize(value:, location:) @value = value @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -3414,6 +4079,17 @@ def child_nodes [] end + def copy(value: nil, location: nil) + node = + CVar.new( + value: value || self.value, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -3423,31 +4099,44 @@ def deconstruct_keys(_keys) def format(q) q.text(value) end + + def ===(other) + other.is_a?(CVar) && value === other.value + end end # Def represents defining a regular method on the current self object. # # def method(param) result end + # def object.method(param) result end # - class Def < Node + class DefNode < Node + # [nil | Node] the target where the method is being defined + attr_reader :target + + # [nil | Op | Period] the operator being used to declare the method + attr_reader :operator + # [Backtick | Const | Ident | Kw | Op] the name of the method attr_reader :name - # [Params | Paren] the parameter declaration for the method + # [nil | Params | Paren] the parameter declaration for the method attr_reader :params - # [BodyStmt] the expressions to be executed by the method + # [BodyStmt | Node] the expressions to be executed by the method attr_reader :bodystmt # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(name:, params:, bodystmt:, location:, comments: []) + def initialize(target:, operator:, name:, params:, bodystmt:, location:) + @target = target + @operator = operator @name = name @params = params @bodystmt = bodystmt @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -3455,13 +4144,37 @@ def accept(visitor) end def child_nodes - [name, params, bodystmt] + [target, operator, name, params, bodystmt] + end + + def copy( + target: nil, + operator: nil, + name: nil, + params: nil, + bodystmt: nil, + location: nil + ) + node = + DefNode.new( + target: target || self.target, + operator: operator || self.operator, + name: name || self.name, + params: params || self.params, + bodystmt: bodystmt || self.bodystmt, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node end alias deconstruct child_nodes def deconstruct_keys(_keys) { + target: target, + operator: operator, name: name, params: params, bodystmt: bodystmt, @@ -3471,116 +4184,74 @@ def deconstruct_keys(_keys) end def format(q) + params = self.params + bodystmt = self.bodystmt + q.group do q.group do - q.text("def ") + q.text("def") + q.text(" ") if target || name.comments.empty? + + if target + q.format(target) + q.format(CallOperatorFormatter.new(operator), stackable: false) + end + q.format(name) - if !params.is_a?(Params) || !params.empty? || params.comments.any? + case params + when Paren q.format(params) + when Params + q.format(params) if !params.empty? || params.comments.any? end end - unless bodystmt.empty? - q.indent do - q.breakable(force: true) - q.format(bodystmt) + if endless? + q.text(" =") + q.group do + q.indent do + q.breakable_space + q.format(bodystmt) + end + end + else + unless bodystmt.empty? + q.indent do + q.breakable_force + q.format(bodystmt) + end end - end - q.breakable(force: true) - q.text("end") + q.breakable_force + q.text("end") + end end end - end - - # DefEndless represents defining a single-line method since Ruby 3.0+. - # - # def method = result - # - class DefEndless < Node - # [untyped] the target where the method is being defined - attr_reader :target - - # [Op | Period] the operator being used to declare the method - attr_reader :operator - - # [Backtick | Const | Ident | Kw | Op] the name of the method - attr_reader :name - - # [nil | Params | Paren] the parameter declaration for the method - attr_reader :paren - - # [untyped] the expression to be executed by the method - attr_reader :statement - - # [Array[ Comment | EmbDoc ]] the comments attached to this node - attr_reader :comments - - def initialize( - target:, - operator:, - name:, - paren:, - statement:, - location:, - comments: [] - ) - @target = target - @operator = operator - @name = name - @paren = paren - @statement = statement - @location = location - @comments = comments - end - - def accept(visitor) - visitor.visit_def_endless(self) - end - def child_nodes - [target, operator, name, paren, statement] + def ===(other) + other.is_a?(DefNode) && target === other.target && + operator === other.operator && name === other.name && + params === other.params && bodystmt === other.bodystmt end - alias deconstruct child_nodes - - def deconstruct_keys(_keys) - { - target: target, - operator: operator, - name: name, - paren: paren, - statement: statement, - location: location, - comments: comments - } + # Returns true if the method was found in the source in the "endless" form, + # i.e. where the method body is defined using the `=` operator after the + # method name and parameters. + def endless? + !bodystmt.is_a?(BodyStmt) end - def format(q) - q.group do - q.text("def ") - - if target - q.format(target) - q.format(CallOperatorFormatter.new(operator), stackable: false) - end - - q.format(name) - - if paren - params = paren - params = params.contents if params.is_a?(Paren) - q.format(paren) unless params.empty? - end + def arity + params = self.params - q.text(" =") - q.group do - q.indent do - q.breakable - q.format(statement) - end - end + case params + when Params + params.arity + when Paren + params.contents.arity + else + 0..0 end end end @@ -3591,16 +4262,16 @@ def format(q) # defined?(variable) # class Defined < Node - # [untyped] the value being sent to the keyword + # [Node] the value being sent to the keyword attr_reader :value # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(value:, location:, comments: []) + def initialize(value:, location:) @value = value @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -3611,6 +4282,17 @@ def child_nodes [value] end + def copy(value: nil, location: nil) + node = + Defined.new( + value: value || self.value, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -3618,73 +4300,100 @@ def deconstruct_keys(_keys) end def format(q) - q.group(0, "defined?(", ")") do + q.text("defined?(") + q.group do q.indent do - q.breakable("") + q.breakable_empty q.format(value) end - q.breakable("") + q.breakable_empty end + q.text(")") + end + + def ===(other) + other.is_a?(Defined) && value === other.value end end - # Defs represents defining a singleton method on an object. + # Block represents passing a block to a method call using the +do+ and +end+ + # keywords or the +{+ and +}+ operators. # - # def object.method(param) result end + # method do |value| + # end # - class Defs < Node - # [untyped] the target where the method is being defined - attr_reader :target + # method { |value| } + # + class BlockNode < Node + # Formats the opening brace or keyword of a block. + class BlockOpenFormatter + # [String] the actual output that should be printed + attr_reader :text - # [Op | Period] the operator being used to declare the method - attr_reader :operator + # [LBrace | Keyword] the node that is being represented + attr_reader :node - # [Backtick | Const | Ident | Kw | Op] the name of the method - attr_reader :name + def initialize(text, node) + @text = text + @node = node + end - # [Params | Paren] the parameter declaration for the method - attr_reader :params + def comments + node.comments + end + + def format(q) + q.text(text) + end + end + + # [LBrace | Kw] the left brace or the do keyword that opens this block + attr_reader :opening + + # [nil | BlockVar] the optional variable declaration within this block + attr_reader :block_var - # [BodyStmt] the expressions to be executed by the method + # [BodyStmt | Statements] the expressions to be executed within this block attr_reader :bodystmt # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize( - target:, - operator:, - name:, - params:, - bodystmt:, - location:, - comments: [] - ) - @target = target - @operator = operator - @name = name - @params = params + def initialize(opening:, block_var:, bodystmt:, location:) + @opening = opening + @block_var = block_var @bodystmt = bodystmt @location = location - @comments = comments + @comments = [] end def accept(visitor) - visitor.visit_defs(self) + visitor.visit_block(self) end def child_nodes - [target, operator, name, params, bodystmt] + [opening, block_var, bodystmt] + end + + def copy(opening: nil, block_var: nil, bodystmt: nil, location: nil) + node = + BlockNode.new( + opening: opening || self.opening, + block_var: block_var || self.block_var, + bodystmt: bodystmt || self.bodystmt, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node end alias deconstruct child_nodes def deconstruct_keys(_keys) { - target: target, - operator: operator, - name: name, - params: params, + opening: opening, + block_var: block_var, bodystmt: bodystmt, location: location, comments: comments @@ -3692,112 +4401,148 @@ def deconstruct_keys(_keys) end def format(q) - q.group do - q.group do - q.text("def ") - q.format(target) - q.format(CallOperatorFormatter.new(operator), stackable: false) - q.format(name) - - if !params.is_a?(Params) || !params.empty? || params.comments.any? - q.format(params) - end + # If this is nested anywhere inside of a Command or CommandCall node, then + # we can't change which operators we're using for the bounds of the block. + break_opening, break_closing, flat_opening, flat_closing = + if unchangeable_bounds?(q) + block_close = keywords? ? "end" : "}" + [opening.value, block_close, opening.value, block_close] + elsif forced_do_end_bounds?(q) + %w[do end do end] + elsif forced_brace_bounds?(q) + %w[{ } { }] + else + %w[do end { }] end - unless bodystmt.empty? - q.indent do - q.breakable(force: true) - q.format(bodystmt) - end - end + # If the receiver of this block a Command or CommandCall node, then there + # are no parentheses around the arguments to that command, so we need to + # break the block. + case q.parent + when nil, Command, CommandCall + q.break_parent + format_break(q, break_opening, break_closing) + return + end - q.breakable(force: true) - q.text("end") + q.group do + q + .if_break { format_break(q, break_opening, break_closing) } + .if_flat { format_flat(q, flat_opening, flat_closing) } end end - end - - # DoBlock represents passing a block to a method call using the +do+ and +end+ - # keywords. - # - # method do |value| - # end - # - class DoBlock < Node - # [Kw] the do keyword that opens this block - attr_reader :keyword - - # [nil | BlockVar] the optional variable declaration within this block - attr_reader :block_var - - # [BodyStmt] the expressions to be executed within this block - attr_reader :bodystmt - - # [Array[ Comment | EmbDoc ]] the comments attached to this node - attr_reader :comments - def initialize(keyword:, block_var:, bodystmt:, location:, comments: []) - @keyword = keyword - @block_var = block_var - @bodystmt = bodystmt - @location = location - @comments = comments + def ===(other) + other.is_a?(BlockNode) && opening === other.opening && + block_var === other.block_var && bodystmt === other.bodystmt end - def accept(visitor) - visitor.visit_do_block(self) + def keywords? + opening.is_a?(Kw) end - def child_nodes - [keyword, block_var, bodystmt] + def arity + case block_var + when BlockVar + block_var.params.arity + else + 0..0 + end end - alias deconstruct child_nodes + private - def deconstruct_keys(_keys) - { - keyword: keyword, - block_var: block_var, - bodystmt: bodystmt, - location: location, - comments: comments - } + # If this is nested anywhere inside certain nodes, then we can't change + # which operators/keywords we're using for the bounds of the block. + def unchangeable_bounds?(q) + q.parents.any? do |parent| + # If we hit a statements, then we're safe to use whatever since we + # know for certain we're going to get split over multiple lines + # anyway. + case parent + when Statements, ArgParen + break false + when Command, CommandCall + true + else + false + end + end + end + + # If we're a sibling of a control-flow keyword, then we're going to have to + # use the do..end bounds. + def forced_do_end_bounds?(q) + case q.parent&.call + when Break, Next, ReturnNode, Super + true + else + false + end end - def format(q) - BlockFormatter.new(self, keyword, "end", bodystmt).format(q) + # If we're the predicate of a loop or conditional, then we're going to have + # to go with the {..} bounds. + def forced_brace_bounds?(q) + previous = nil + q.parents.any? do |parent| + case parent + when Paren, Statements + # If we hit certain breakpoints then we know we're safe. + return false + when IfNode, IfOp, UnlessNode, WhileNode, UntilNode + return true if parent.predicate == previous + end + + previous = parent + false + end end - end - # Responsible for formatting Dot2 and Dot3 nodes. - class DotFormatter - # [String] the operator to display - attr_reader :operator + def format_break(q, break_opening, break_closing) + q.text(" ") + q.format(BlockOpenFormatter.new(break_opening, opening), stackable: false) - # [Dot2 | Dot3] the node that is being formatter - attr_reader :node + if block_var + q.text(" ") + q.format(block_var) + end - def initialize(operator, node) - @operator = operator - @node = node + unless bodystmt.empty? + q.indent do + q.breakable_space + q.format(bodystmt) + end + end + + q.breakable_space + q.text(break_closing) end - def format(q) - space = [If, IfMod, Unless, UnlessMod].include?(q.parent.class) + def format_flat(q, flat_opening, flat_closing) + q.text(" ") + q.format(BlockOpenFormatter.new(flat_opening, opening), stackable: false) - left = node.left - right = node.right + if block_var + q.breakable_space + q.format(block_var) + q.breakable_space + end - q.format(left) if left - q.text(" ") if space - q.text(operator) - q.text(" ") if space - q.format(right) if right + if bodystmt.empty? + q.text(" ") if flat_opening == "do" + else + q.breakable_space unless block_var + q.format(bodystmt) + q.breakable_space + end + + q.text(flat_closing) end end - # Dot2 represents using the .. operator between two expressions. Usually this - # is to create a range object. + # RangeNode represents using the .. or the ... operator between two + # expressions. Usually this is to create a range object. # # 1..2 # @@ -3807,87 +4552,76 @@ def format(q) # end # # One of the sides of the expression may be nil, but not both. - class Dot2 < Node - # [nil | untyped] the left side of the expression + class RangeNode < Node + # [nil | Node] the left side of the expression attr_reader :left - # [nil | untyped] the right side of the expression + # [Op] the operator used for this range + attr_reader :operator + + # [nil | Node] the right side of the expression attr_reader :right # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(left:, right:, location:, comments: []) + def initialize(left:, operator:, right:, location:) @left = left + @operator = operator @right = right @location = location - @comments = comments + @comments = [] end def accept(visitor) - visitor.visit_dot2(self) + visitor.visit_range(self) end def child_nodes [left, right] end - alias deconstruct child_nodes - - def deconstruct_keys(_keys) - { left: left, right: right, location: location, comments: comments } - end + def copy(left: nil, operator: nil, right: nil, location: nil) + node = + RangeNode.new( + left: left || self.left, + operator: operator || self.operator, + right: right || self.right, + location: location || self.location + ) - def format(q) - DotFormatter.new("..", self).format(q) + node.comments.concat(comments.map(&:copy)) + node end - end - - # Dot3 represents using the ... operator between two expressions. Usually this - # is to create a range object. It's effectively the same event as the Dot2 - # node but with this operator you're asking Ruby to omit the final value. - # - # 1...2 - # - # Like Dot2 it can also be used to create a flip-flop. - # - # if value == 5 ... value == 10 - # end - # - # One of the sides of the expression may be nil, but not both. - class Dot3 < Node - # [nil | untyped] the left side of the expression - attr_reader :left - - # [nil | untyped] the right side of the expression - attr_reader :right - # [Array[ Comment | EmbDoc ]] the comments attached to this node - attr_reader :comments - - def initialize(left:, right:, location:, comments: []) - @left = left - @right = right - @location = location - @comments = comments - end + alias deconstruct child_nodes - def accept(visitor) - visitor.visit_dot3(self) + def deconstruct_keys(_keys) + { + left: left, + operator: operator, + right: right, + location: location, + comments: comments + } end - def child_nodes - [left, right] - end + def format(q) + q.format(left) if left - alias deconstruct child_nodes + case q.parent + when IfNode, UnlessNode + q.text(" #{operator.value} ") + else + q.text(operator.value) + end - def deconstruct_keys(_keys) - { left: left, right: right, location: location, comments: comments } + q.format(right) if right end - def format(q) - DotFormatter.new("...", self).format(q) + def ===(other) + other.is_a?(RangeNode) && left === other.left && + operator === other.operator && right === other.right end end @@ -3902,9 +4636,9 @@ module Quotes # whichever quote the user chose. (If they chose single quotes, then double # quoting would activate the escape sequence, and if they chose double # quotes, then single quotes would deactivate it.) - def self.locked?(node) + def self.locked?(node, quote) node.parts.any? do |part| - !part.is_a?(TStringContent) || part.value.match?(/\\|#[@${]/) + !part.is_a?(TStringContent) || part.value.match?(/\\|#[@${]|#{quote}/) end end @@ -3946,17 +4680,17 @@ class DynaSymbol < Node # dynamic symbol attr_reader :parts - # [String] the quote used to delimit the dynamic symbol + # [nil | String] the quote used to delimit the dynamic symbol attr_reader :quote # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(parts:, quote:, location:, comments: []) + def initialize(parts:, quote:, location:) @parts = parts @quote = quote @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -3967,6 +4701,18 @@ def child_nodes parts end + def copy(parts: nil, quote: nil, location: nil) + node = + DynaSymbol.new( + parts: parts || self.parts, + quote: quote || self.quote, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -3976,19 +4722,35 @@ def deconstruct_keys(_keys) def format(q) opening_quote, closing_quote = quotes(q) - q.group(0, opening_quote, closing_quote) do + q.text(opening_quote) + q.group do parts.each do |part| if part.is_a?(TStringContent) value = Quotes.normalize(part.value, closing_quote) - separator = -> { q.breakable(force: true, indent: false) } - q.seplist(value.split(/\r?\n/, -1), separator) do |text| - q.text(text) + first = true + + value.each_line(chomp: true) do |line| + if first + first = false + else + q.breakable_return + end + + q.text(line) end + + q.breakable_return if value.end_with?("\n") else q.format(part) end end end + q.text(closing_quote) + end + + def ===(other) + other.is_a?(DynaSymbol) && ArrayMatch.call(parts, other.parts) && + quote === other.quote end private @@ -4019,12 +4781,12 @@ def quotes(q) if matched [quote, matching] - elsif Quotes.locked?(self) + elsif Quotes.locked?(self, q.quote) ["#{":" unless hash_key}'", "'"] else ["#{":" unless hash_key}#{q.quote}", q.quote] end - elsif Quotes.locked?(self) + elsif Quotes.locked?(self, q.quote) if quote.start_with?(":") [hash_key ? quote[1..] : quote, quote[1..]] else @@ -4052,11 +4814,11 @@ class Else < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(keyword:, statements:, location:, comments: []) + def initialize(keyword:, statements:, location:) @keyword = keyword @statements = statements @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -4067,6 +4829,18 @@ def child_nodes [keyword, statements] end + def copy(keyword: nil, statements: nil, location: nil) + node = + Else.new( + keyword: keyword || self.keyword, + statements: statements || self.statements, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -4084,12 +4858,17 @@ def format(q) unless statements.empty? q.indent do - q.breakable(force: true) + q.breakable_force q.format(statements) end end end end + + def ===(other) + other.is_a?(Else) && keyword === other.keyword && + statements === other.statements + end end # Elsif represents another clause in an +if+ or +unless+ chain. @@ -4099,7 +4878,7 @@ def format(q) # end # class Elsif < Node - # [untyped] the expression to be checked + # [Node] the expression to be checked attr_reader :predicate # [Statements] the expressions to be executed @@ -4111,18 +4890,12 @@ class Elsif < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize( - predicate:, - statements:, - consequent:, - location:, - comments: [] - ) + def initialize(predicate:, statements:, consequent:, location:) @predicate = predicate @statements = statements @consequent = consequent @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -4133,6 +4906,19 @@ def child_nodes [predicate, statements, consequent] end + def copy(predicate: nil, statements: nil, consequent: nil, location: nil) + node = + Elsif.new( + predicate: predicate || self.predicate, + statements: statements || self.statements, + consequent: consequent || self.consequent, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -4154,19 +4940,24 @@ def format(q) unless statements.empty? q.indent do - q.breakable(force: true) + q.breakable_force q.format(statements) end end if consequent q.group do - q.breakable(force: true) + q.breakable_force q.format(consequent) end end end end + + def ===(other) + other.is_a?(Elsif) && predicate === other.predicate && + statements === other.statements && consequent === other.consequent + end end # EmbDoc represents a multi-line comment. @@ -4183,6 +4974,25 @@ class EmbDoc < Node def initialize(value:, location:) @value = value @location = location + + @leading = false + @trailing = false + end + + def leading! + @leading = true + end + + def leading? + @leading + end + + def trailing! + @trailing = true + end + + def trailing? + @trailing end def inline? @@ -4205,6 +5015,13 @@ def child_nodes [] end + def copy(value: nil, location: nil) + EmbDoc.new( + value: value || self.value, + location: location || self.location + ) + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -4212,9 +5029,19 @@ def deconstruct_keys(_keys) end def format(q) - q.trim + if (q.parent.is_a?(DefNode) && q.parent.endless?) || + q.parent.is_a?(Statements) + q.trim + else + q.breakable_return + end + q.text(value) end + + def ===(other) + other.is_a?(EmbDoc) && value === other.value + end end # EmbExprBeg represents the beginning token for using interpolation inside of @@ -4240,11 +5067,22 @@ def child_nodes [] end + def copy(value: nil, location: nil) + EmbExprBeg.new( + value: value || self.value, + location: location || self.location + ) + end + alias deconstruct child_nodes def deconstruct_keys(_keys) { value: value, location: location } end + + def ===(other) + other.is_a?(EmbExprBeg) && value === other.value + end end # EmbExprEnd represents the ending token for using interpolation inside of a @@ -4270,11 +5108,22 @@ def child_nodes [] end + def copy(value: nil, location: nil) + EmbExprEnd.new( + value: value || self.value, + location: location || self.location + ) + end + alias deconstruct child_nodes def deconstruct_keys(_keys) { value: value, location: location } end + + def ===(other) + other.is_a?(EmbExprEnd) && value === other.value + end end # EmbVar represents the use of shorthand interpolation for an instance, class, @@ -4302,11 +5151,22 @@ def child_nodes [] end + def copy(value: nil, location: nil) + EmbVar.new( + value: value || self.value, + location: location || self.location + ) + end + alias deconstruct child_nodes def deconstruct_keys(_keys) { value: value, location: location } end + + def ===(other) + other.is_a?(EmbVar) && value === other.value + end end # Ensure represents the use of the +ensure+ keyword and its subsequent @@ -4326,11 +5186,11 @@ class Ensure < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(keyword:, statements:, location:, comments: []) + def initialize(keyword:, statements:, location:) @keyword = keyword @statements = statements @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -4341,6 +5201,18 @@ def child_nodes [keyword, statements] end + def copy(keyword: nil, statements: nil, location: nil) + node = + Ensure.new( + keyword: keyword || self.keyword, + statements: statements || self.statements, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -4357,11 +5229,16 @@ def format(q) unless statements.empty? q.indent do - q.breakable(force: true) + q.breakable_force q.format(statements) end end end + + def ===(other) + other.is_a?(Ensure) && keyword === other.keyword && + statements === other.statements + end end # ExcessedComma represents a trailing comma in a list of block parameters. It @@ -4381,10 +5258,10 @@ class ExcessedComma < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(value:, location:, comments: []) + def initialize(value:, location:) @value = value @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -4395,72 +5272,29 @@ def child_nodes [] end - alias deconstruct child_nodes - - def deconstruct_keys(_keys) - { value: value, location: location, comments: comments } - end - - def format(q) - q.text(value) - end - end - - # FCall represents the piece of a method call that comes before any arguments - # (i.e., just the name of the method). It is used in places where the parser - # is sure that it is a method call and not potentially a local variable. - # - # method(argument) - # - # In the above example, it's referring to the +method+ segment. - class FCall < Node - # [Const | Ident] the name of the method - attr_reader :value - - # [nil | ArgParen | Args] the arguments to the method call - attr_reader :arguments - - # [Array[ Comment | EmbDoc ]] the comments attached to this node - attr_reader :comments - - def initialize(value:, arguments:, location:, comments: []) - @value = value - @arguments = arguments - @location = location - @comments = comments - end - - def accept(visitor) - visitor.visit_fcall(self) - end + def copy(value: nil, location: nil) + node = + ExcessedComma.new( + value: value || self.value, + location: location || self.location + ) - def child_nodes - [value, arguments] + node.comments.concat(comments.map(&:copy)) + node end alias deconstruct child_nodes def deconstruct_keys(_keys) - { - value: value, - arguments: arguments, - location: location, - comments: comments - } + { value: value, location: location, comments: comments } end def format(q) - q.format(value) + q.text(value) + end - if arguments.is_a?(ArgParen) && arguments.arguments.nil? && - !value.is_a?(Const) - # If you're using an explicit set of parentheses on something that looks - # like a constant, then we need to match that in order to maintain valid - # Ruby. For example, you could do something like Foo(), on which we - # would need to keep the parentheses to make it look like a method call. - else - q.format(arguments) - end + def ===(other) + other.is_a?(ExcessedComma) && value === other.value end end @@ -4470,7 +5304,7 @@ def format(q) # object.variable = value # class Field < Node - # [untyped] the parent object that owns the field being assigned + # [Node] the parent object that owns the field being assigned attr_reader :parent # [:"::" | Op | Period] the operator being used for the assignment @@ -4482,12 +5316,12 @@ class Field < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(parent:, operator:, name:, location:, comments: []) + def initialize(parent:, operator:, name:, location:) @parent = parent @operator = operator @name = name @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -4495,9 +5329,23 @@ def accept(visitor) end def child_nodes + operator = self.operator [parent, (operator if operator != :"::"), name] end + def copy(parent: nil, operator: nil, name: nil, location: nil) + node = + Field.new( + parent: parent || self.parent, + operator: operator || self.operator, + name: name || self.name, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -4517,6 +5365,11 @@ def format(q) q.format(name) end end + + def ===(other) + other.is_a?(Field) && parent === other.parent && + operator === other.operator && name === other.name + end end # FloatLiteral represents a floating point number literal. @@ -4530,10 +5383,10 @@ class FloatLiteral < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(value:, location:, comments: []) + def initialize(value:, location:) @value = value @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -4544,6 +5397,17 @@ def child_nodes [] end + def copy(value: nil, location: nil) + node = + FloatLiteral.new( + value: value || self.value, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -4553,6 +5417,10 @@ def deconstruct_keys(_keys) def format(q) q.text(value) end + + def ===(other) + other.is_a?(FloatLiteral) && value === other.value + end end # FndPtn represents matching against a pattern where you find a pattern in an @@ -4563,13 +5431,13 @@ def format(q) # end # class FndPtn < Node - # [nil | untyped] the optional constant wrapper + # [nil | VarRef | ConstPathRef] the optional constant wrapper attr_reader :constant # [VarField] the splat on the left-hand side attr_reader :left - # [Array[ untyped ]] the list of positional expressions in the pattern that + # [Array[ Node ]] the list of positional expressions in the pattern that # are being matched attr_reader :values @@ -4579,13 +5447,13 @@ class FndPtn < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(constant:, left:, values:, right:, location:, comments: []) + def initialize(constant:, left:, values:, right:, location:) @constant = constant @left = left @values = values @right = right @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -4596,6 +5464,20 @@ def child_nodes [constant, left, *values, right] end + def copy(constant: nil, left: nil, values: nil, right: nil, location: nil) + node = + FndPtn.new( + constant: constant || self.constant, + left: left || self.left, + values: values || self.values, + right: right || self.right, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -4611,18 +5493,34 @@ def deconstruct_keys(_keys) def format(q) q.format(constant) if constant - q.group(0, "[", "]") do - q.text("*") - q.format(left) - q.comma_breakable - q.seplist(values) { |value| q.format(value) } - q.comma_breakable + q.group do + q.text("[") - q.text("*") - q.format(right) + q.indent do + q.breakable_empty + + q.text("*") + q.format(left) + q.comma_breakable + + q.seplist(values) { |value| q.format(value) } + q.comma_breakable + + q.text("*") + q.format(right) + end + + q.breakable_empty + q.text("]") end end + + def ===(other) + other.is_a?(FndPtn) && constant === other.constant && + left === other.left && ArrayMatch.call(values, other.values) && + right === other.right + end end # For represents using a +for+ loop. @@ -4635,7 +5533,7 @@ class For < Node # pull values out of the object being enumerated attr_reader :index - # [untyped] the object being enumerated in the loop + # [Node] the object being enumerated in the loop attr_reader :collection # [Statements] the statements to be executed @@ -4644,12 +5542,12 @@ class For < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(index:, collection:, statements:, location:, comments: []) + def initialize(index:, collection:, statements:, location:) @index = index @collection = collection @statements = statements @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -4660,6 +5558,19 @@ def child_nodes [index, collection, statements] end + def copy(index: nil, collection: nil, statements: nil, location: nil) + node = + For.new( + index: index || self.index, + collection: collection || self.collection, + statements: statements || self.statements, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -4681,15 +5592,20 @@ def format(q) unless statements.empty? q.indent do - q.breakable(force: true) + q.breakable_force q.format(statements) end end - q.breakable(force: true) + q.breakable_force q.text("end") end end + + def ===(other) + other.is_a?(For) && index === other.index && + collection === other.collection && statements === other.statements + end end # GVar represents a global variable literal. @@ -4703,10 +5619,10 @@ class GVar < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(value:, location:, comments: []) + def initialize(value:, location:) @value = value @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -4717,6 +5633,17 @@ def child_nodes [] end + def copy(value: nil, location: nil) + node = + GVar.new( + value: value || self.value, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -4726,6 +5653,10 @@ def deconstruct_keys(_keys) def format(q) q.text(value) end + + def ===(other) + other.is_a?(GVar) && value === other.value + end end # HashLiteral represents a hash literal. @@ -4749,11 +5680,11 @@ def format(q) q.text("{") q.indent do lbrace.comments.each do |comment| - q.breakable(force: true) + q.breakable_force comment.format(q) end end - q.breakable(force: true) + q.breakable_force q.text("}") end end @@ -4762,25 +5693,37 @@ def format(q) # [LBrace] the left brace that opens this hash attr_reader :lbrace - # [Array[ AssocNew | AssocSplat ]] the optional contents of the hash + # [Array[ Assoc | AssocSplat ]] the optional contents of the hash attr_reader :assocs # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(lbrace:, assocs:, location:, comments: []) + def initialize(lbrace:, assocs:, location:) @lbrace = lbrace @assocs = assocs @location = location - @comments = comments + @comments = [] end def accept(visitor) visitor.visit_hash(self) end - def child_nodes - [lbrace] + assocs + def child_nodes + [lbrace].concat(assocs) + end + + def copy(lbrace: nil, assocs: nil, location: nil) + node = + HashLiteral.new( + lbrace: lbrace || self.lbrace, + assocs: assocs || self.assocs, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node end alias deconstruct child_nodes @@ -4797,6 +5740,11 @@ def format(q) end end + def ===(other) + other.is_a?(HashLiteral) && lbrace === other.lbrace && + ArrayMatch.call(assocs, other.assocs) + end + def format_key(q, key) (@key_formatter ||= HashKeyFormatter.for(self)).format_key(q, key) end @@ -4818,13 +5766,14 @@ def format_contents(q) q.format(lbrace) if assocs.empty? - q.breakable("") + q.breakable_empty else q.indent do - q.breakable + q.breakable_space q.seplist(assocs) { |assoc| q.format(assoc) } + q.if_break { q.text(",") } if q.trailing_comma? end - q.breakable + q.breakable_space end q.text("}") @@ -4841,7 +5790,7 @@ class Heredoc < Node # [HeredocBeg] the opening of the heredoc attr_reader :beginning - # [String] the ending of the heredoc + # [HeredocEnd] the ending of the heredoc attr_reader :ending # [Integer] how far to dedent the heredoc @@ -4854,20 +5803,13 @@ class Heredoc < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize( - beginning:, - ending: nil, - dedent: 0, - parts: [], - location:, - comments: [] - ) + def initialize(beginning:, location:, ending: nil, dedent: 0, parts: []) @beginning = beginning @ending = ending @dedent = dedent @parts = parts @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -4875,7 +5817,20 @@ def accept(visitor) end def child_nodes - [beginning, *parts] + [beginning, *parts, ending] + end + + def copy(beginning: nil, location: nil, ending: nil, parts: nil) + node = + Heredoc.new( + beginning: beginning || self.beginning, + location: location || self.location, + ending: ending || self.ending, + parts: parts || self.parts + ) + + node.comments.concat(comments.map(&:copy)) + node end alias deconstruct child_nodes @@ -4890,40 +5845,50 @@ def deconstruct_keys(_keys) } end - def format(q) - # This is a very specific behavior that should probably be included in the - # prettyprint module. It's when you want to force a newline, but don't - # want to force the break parent. - breakable = -> do - q.target << PrettyPrint::Breakable.new( - " ", - 1, - indent: false, - force: true - ) - end + # This is a very specific behavior where you want to force a newline, but + # don't want to force the break parent. + SEPARATOR = + PrettierPrint::Breakable.new(" ", 1, indent: false, force: true).freeze + def format(q) q.group do q.format(beginning) q.line_suffix(priority: Formatter::HEREDOC_PRIORITY) do q.group do - breakable.call + q.target << SEPARATOR parts.each do |part| if part.is_a?(TStringContent) - texts = part.value.split(/\r?\n/, -1) - q.seplist(texts, breakable) { |text| q.text(text) } + value = part.value + first = true + + value.each_line(chomp: true) do |line| + if first + first = false + else + q.target << SEPARATOR + end + + q.text(line) + end + + q.target << SEPARATOR if value.end_with?("\n") else q.format(part) end end - q.text(ending) + q.format(ending) end end end end + + def ===(other) + other.is_a?(Heredoc) && beginning === other.beginning && + ending === other.ending && ArrayMatch.call(parts, other.parts) + end end # HeredocBeg represents the beginning declaration of a heredoc. @@ -4940,10 +5905,10 @@ class HeredocBeg < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(value:, location:, comments: []) + def initialize(value:, location:) @value = value @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -4954,6 +5919,71 @@ def child_nodes [] end + def copy(value: nil, location: nil) + node = + HeredocBeg.new( + value: value || self.value, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + + alias deconstruct child_nodes + + def deconstruct_keys(_keys) + { value: value, location: location, comments: comments } + end + + def format(q) + q.text(value) + end + + def ===(other) + other.is_a?(HeredocBeg) && value === other.value + end + end + + # HeredocEnd represents the closing declaration of a heredoc. + # + # <<~DOC + # contents + # DOC + # + # In the example above the HeredocEnd node represents the closing DOC. + class HeredocEnd < Node + # [String] the closing declaration of the heredoc + attr_reader :value + + # [Array[ Comment | EmbDoc ]] the comments attached to this node + attr_reader :comments + + def initialize(value:, location:) + @value = value + @location = location + @comments = [] + end + + def accept(visitor) + visitor.visit_heredoc_end(self) + end + + def child_nodes + [] + end + + def copy(value: nil, location: nil) + node = + HeredocEnd.new( + value: value || self.value, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -4963,6 +5993,10 @@ def deconstruct_keys(_keys) def format(q) q.text(value) end + + def ===(other) + other.is_a?(HeredocEnd) && value === other.value + end end # HshPtn represents matching against a hash pattern using the Ruby 2.7+ @@ -4978,7 +6012,7 @@ class KeywordFormatter # [Label] the keyword being used attr_reader :key - # [untyped] the optional value for the keyword + # [Node] the optional value for the keyword attr_reader :value def initialize(key, value) @@ -4991,7 +6025,7 @@ def comments end def format(q) - q.format(key) + HashKeyFormatter::Labels.new.format_key(q, key) if value q.text(" ") @@ -5019,11 +6053,11 @@ def format(q) end end - # [nil | untyped] the optional constant wrapper + # [nil | VarRef | ConstPathRef] the optional constant wrapper attr_reader :constant - # [Array[ [Label, untyped] ]] the set of tuples representing the keywords - # that should be matched against in the pattern + # [Array[ [DynaSymbol | Label, nil | Node] ]] the set of tuples + # representing the keywords that should be matched against in the pattern attr_reader :keywords # [nil | VarField] an optional parameter to gather up all remaining keywords @@ -5032,12 +6066,12 @@ def format(q) # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(constant:, keywords:, keyword_rest:, location:, comments: []) + def initialize(constant:, keywords:, keyword_rest:, location:) @constant = constant @keywords = keywords @keyword_rest = keyword_rest @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -5048,6 +6082,19 @@ def child_nodes [constant, *keywords.flatten(1), keyword_rest] end + def copy(constant: nil, keywords: nil, keyword_rest: nil, location: nil) + node = + HshPtn.new( + constant: constant || self.constant, + keywords: keywords || self.keywords, + keyword_rest: keyword_rest || self.keyword_rest, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -5063,15 +6110,7 @@ def deconstruct_keys(_keys) def format(q) parts = keywords.map { |(key, value)| KeywordFormatter.new(key, value) } parts << KeywordRestFormatter.new(keyword_rest) if keyword_rest - - contents = -> do - q.group { q.seplist(parts) { |part| q.format(part, stackable: false) } } - - # If there isn't a constant, and there's a blank keyword_rest, then we - # have an plain ** that needs to have a `then` after it in order to - # parse correctly on the next parse. - q.text(" then") if !constant && keyword_rest && keyword_rest.value.nil? - end + nested = PATTERNS.include?(q.parent.class) # If there is a constant, we're going to format to have the constant name # first and then use brackets. @@ -5080,10 +6119,10 @@ def format(q) q.format(constant) q.text("[") q.indent do - q.breakable("") - contents.call + q.breakable_empty + format_contents(q, parts, nested) end - q.breakable("") + q.breakable_empty q.text("]") end return @@ -5097,8 +6136,8 @@ def format(q) # If there's only one pair, then we'll just print the contents provided # we're not inside another pattern. - if !PATTERNS.include?(q.parent.class) && parts.size == 1 - contents.call + if !nested && parts.size == 1 + format_contents(q, parts, nested) return end @@ -5107,11 +6146,40 @@ def format(q) q.group do q.text("{") q.indent do - q.breakable - contents.call + q.breakable_space + format_contents(q, parts, nested) end - q.breakable - q.text("}") + + if q.target_ruby_version < Formatter::SemanticVersion.new("2.7.3") + q.text(" }") + else + q.breakable_space + q.text("}") + end + end + end + + def ===(other) + other.is_a?(HshPtn) && constant === other.constant && + keywords.length == other.keywords.length && + keywords + .zip(other.keywords) + .all? { |left, right| ArrayMatch.call(left, right) } && + keyword_rest === other.keyword_rest + end + + private + + def format_contents(q, parts, nested) + keyword_rest = self.keyword_rest + + q.group { q.seplist(parts) { |part| q.format(part, stackable: false) } } + + # If there isn't a constant, and there's a blank keyword_rest, then we + # have an plain ** that needs to have a `then` after it in order to + # parse correctly on the next parse. + if !constant && keyword_rest && keyword_rest.value.nil? && !nested + q.text(" then") end end end @@ -5132,10 +6200,10 @@ class Ident < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(value:, location:, comments: []) + def initialize(value:, location:) @value = value @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -5146,6 +6214,17 @@ def child_nodes [] end + def copy(value: nil, location: nil) + node = + Ident.new( + value: value || self.value, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -5155,6 +6234,10 @@ def deconstruct_keys(_keys) def format(q) q.text(value) end + + def ===(other) + other.is_a?(Ident) && value === other.value + end end # If the predicate of a conditional or loop contains an assignment (in which @@ -5166,8 +6249,12 @@ def self.call(parent) queue = [parent] while (node = queue.shift) - return true if [Assign, MAssign, OpAssign].include?(node.class) - queue += node.child_nodes + case node + when Assign, MAssign, OpAssign + return true + else + node.child_nodes.each { |child| queue << child if child } + end end false @@ -5182,27 +6269,36 @@ def self.call(parent) module Ternaryable class << self def call(q, node) - case q.parents.take(2)[1] - in Paren[contents: Statements[body: [node]]] - # If this is a conditional inside of a parentheses as the only - # content, then we don't want to transform it into a ternary. - # Presumably the user wanted it to be an explicit conditional because - # there are parentheses around it. So we'll just leave it in place. - false - else - # Otherwise, we're going to check the conditional for certain cases. - case node - in predicate: Assign | Command | CommandCall | MAssign | OpAssign - false - in { - statements: { body: [truthy] }, - consequent: Else[statements: { body: [falsy] }] - } - ternaryable?(truthy) && ternaryable?(falsy) - else - false - end + return false if ENV["STREE_FAST_FORMAT"] || q.disable_auto_ternary? + + # If this is a conditional inside of a parentheses as the only content, + # then we don't want to transform it into a ternary. Presumably the user + # wanted it to be an explicit conditional because there are parentheses + # around it. So we'll just leave it in place. + grandparent = q.grandparent + if grandparent.is_a?(Paren) && (body = grandparent.contents.body) && + body.length == 1 && body.first == node + return false end + + # Otherwise, we'll check the type of predicate. For certain nodes we + # want to force it to not be a ternary, like if the predicate is an + # assignment because it's hard to read. + case node.predicate + when Assign, Binary, Command, CommandCall, MAssign, OpAssign + return false + when Not + return false unless node.predicate.parentheses? + end + + # If there's no Else, then this can't be represented as a ternary. + return false unless node.consequent.is_a?(Else) + + truthy_body = node.statements.body + falsy_body = node.consequent.statements.body + + (truthy_body.length == 1) && ternaryable?(truthy_body.first) && + (falsy_body.length == 1) && ternaryable?(falsy_body.first) end private @@ -5211,24 +6307,23 @@ def call(q, node) # parentheses around them. In this case we say they cannot be ternaried # and default instead to breaking them into multiple lines. def ternaryable?(statement) - # This is a list of nodes that should not be allowed to be a part of a - # ternary clause. - no_ternary = [ - Alias, Assign, Break, Command, CommandCall, Heredoc, If, IfMod, IfOp, - Lambda, MAssign, Next, OpAssign, RescueMod, Return, Return0, Super, - Undef, Unless, UnlessMod, Until, UntilMod, VarAlias, VoidStmt, While, - WhileMod, Yield, Yield0, ZSuper - ] - - # Here we're going to check that the only statement inside the - # statements node is no a part of our denied list of nodes that can be - # ternaries. - # - # If the user is using one of the lower precedence "and" or "or" - # operators, then we can't use a ternary expression as it would break - # the flow control. - !no_ternary.include?(statement.class) && - !(statement.is_a?(Binary) && %i[and or].include?(statement.operator)) + case statement + when AliasNode, Assign, Break, Command, CommandCall, Defined, Heredoc, + IfNode, IfOp, Lambda, MAssign, Next, OpAssign, RescueMod, + ReturnNode, Super, Undef, UnlessNode, UntilNode, VoidStmt, + WhileNode, YieldNode, ZSuper + # This is a list of nodes that should not be allowed to be a part of a + # ternary clause. + false + when Binary + # If the user is using one of the lower precedence "and" or "or" + # operators, then we can't use a ternary expression as it would break + # the flow control. + operator = statement.operator + operator != :and && operator != :or + else + true + end end end end @@ -5247,58 +6342,81 @@ def initialize(keyword, node) end def format(q) - # If we can transform this node into a ternary, then we're going to print - # a special version that uses the ternary operator if it fits on one line. - if Ternaryable.call(q, node) - format_ternary(q) - return - end - - # If the predicate of the conditional contains an assignment (in which - # case we can't know for certain that that assignment doesn't impact the - # statements inside the conditional) then we can't use the modifier form - # and we must use the block form. - if ContainsAssignment.call(node.predicate) - format_break(q, force: true) - return - end + if node.modifier? + statement = node.statements.body[0] - if node.consequent || node.statements.empty? || contains_conditional? - q.group { format_break(q, force: true) } + if ContainsAssignment.call(statement) || q.parent.is_a?(In) + q.group { format_flat(q) } + else + q.group do + q + .if_break { format_break(q, force: false) } + .if_flat { format_flat(q) } + end + end else - q.group do - q - .if_break { format_break(q, force: false) } - .if_flat do - Parentheses.flat(q) do - q.format(node.statements) - q.text(" #{keyword} ") - q.format(node.predicate) + # If we can transform this node into a ternary, then we're going to + # print a special version that uses the ternary operator if it fits on + # one line. + if Ternaryable.call(q, node) + format_ternary(q) + return + end + + # If the predicate of the conditional contains an assignment (in which + # case we can't know for certain that that assignment doesn't impact the + # statements inside the conditional) then we can't use the modifier form + # and we must use the block form. + if ContainsAssignment.call(node.predicate) + format_break(q, force: true) + return + end + + if node.consequent || node.statements.empty? || contains_conditional? + q.group { format_break(q, force: true) } + else + q.group do + q + .if_break { format_break(q, force: false) } + .if_flat do + Parentheses.flat(q) do + q.format(node.statements) + q.text(" #{keyword} ") + q.format(node.predicate) + end end - end + end end end end private + def format_flat(q) + Parentheses.flat(q) do + q.format(node.statements.body[0]) + q.text(" #{keyword} ") + q.format(node.predicate) + end + end + def format_break(q, force:) q.text("#{keyword} ") q.nest(keyword.length + 1) { q.format(node.predicate) } unless node.statements.empty? q.indent do - q.breakable(force: force) + force ? q.breakable_force : q.breakable_space q.format(node.statements) end end if node.consequent - q.breakable(force: force) + force ? q.breakable_force : q.breakable_space q.format(node.consequent) end - q.breakable(force: force) + force ? q.breakable_force : q.breakable_space q.text("end") end @@ -5310,11 +6428,11 @@ def format_ternary(q) q.nest(keyword.length + 1) { q.format(node.predicate) } q.indent do - q.breakable + q.breakable_space q.format(node.statements) end - q.breakable + q.breakable_space q.group do q.format(node.consequent.keyword) q.indent do @@ -5322,14 +6440,13 @@ def format_ternary(q) # force it into the output but we _don't_ want to explicitly # break the parent. If a break-parent shows up in the tree, then # it's going to force it all the way up to the tree, which is - # going to negate the ternary. Maybe this should be an option in - # prettyprint? As in force: :no_break_parent or something. - q.target << PrettyPrint::Breakable.new(" ", 1, force: true) + # going to negate the ternary. + q.breakable(force: :skip_break_parent) q.format(node.consequent.statements) end end - q.breakable + q.breakable_space q.text("end") end .if_flat do @@ -5349,8 +6466,11 @@ def format_ternary(q) end def contains_conditional? - case node - in statements: { body: [If | IfMod | IfOp | Unless | UnlessMod] } + statements = node.statements.body + return false if statements.length != 1 + + case statements.first + when IfNode, IfOp, UnlessNode true else false @@ -5363,31 +6483,25 @@ def contains_conditional? # if predicate # end # - class If < Node - # [untyped] the expression to be checked + class IfNode < Node + # [Node] the expression to be checked attr_reader :predicate # [Statements] the expressions to be executed attr_reader :statements - # [nil, Elsif, Else] the next clause in the chain + # [nil | Elsif | Else] the next clause in the chain attr_reader :consequent # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize( - predicate:, - statements:, - consequent:, - location:, - comments: [] - ) + def initialize(predicate:, statements:, consequent:, location:) @predicate = predicate @statements = statements @consequent = consequent @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -5398,6 +6512,19 @@ def child_nodes [predicate, statements, consequent] end + def copy(predicate: nil, statements: nil, consequent: nil, location: nil) + node = + IfNode.new( + predicate: predicate || self.predicate, + statements: statements || self.statements, + consequent: consequent || self.consequent, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -5413,6 +6540,16 @@ def deconstruct_keys(_keys) def format(q) ConditionalFormatter.new("if", self).format(q) end + + def ===(other) + other.is_a?(IfNode) && predicate === other.predicate && + statements === other.statements && consequent === other.consequent + end + + # Checks if the node was originally found in the modifier form. + def modifier? + predicate.location.start_char > statements.location.start_char + end end # IfOp represents a ternary clause. @@ -5420,24 +6557,24 @@ def format(q) # predicate ? truthy : falsy # class IfOp < Node - # [untyped] the expression to be checked + # [Node] the expression to be checked attr_reader :predicate - # [untyped] the expression to be executed if the predicate is truthy + # [Node] the expression to be executed if the predicate is truthy attr_reader :truthy - # [untyped] the expression to be executed if the predicate is falsy + # [Node] the expression to be executed if the predicate is falsy attr_reader :falsy # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(predicate:, truthy:, falsy:, location:, comments: []) + def initialize(predicate:, truthy:, falsy:, location:) @predicate = predicate @truthy = truthy @falsy = falsy @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -5448,6 +6585,19 @@ def child_nodes [predicate, truthy, falsy] end + def copy(predicate: nil, truthy: nil, falsy: nil, location: nil) + node = + IfOp.new( + predicate: predicate || self.predicate, + truthy: truthy || self.truthy, + falsy: falsy || self.falsy, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -5462,10 +6612,26 @@ def deconstruct_keys(_keys) def format(q) force_flat = [ - Alias, Assign, Break, Command, CommandCall, Heredoc, If, IfMod, IfOp, - Lambda, MAssign, Next, OpAssign, RescueMod, Return, Return0, Super, - Undef, Unless, UnlessMod, UntilMod, VarAlias, VoidStmt, WhileMod, Yield, - Yield0, ZSuper + AliasNode, + Assign, + Break, + Command, + CommandCall, + Heredoc, + IfNode, + IfOp, + Lambda, + MAssign, + Next, + OpAssign, + RescueMod, + ReturnNode, + Super, + Undef, + UnlessNode, + VoidStmt, + YieldNode, + ZSuper ] if q.parent.is_a?(Paren) || force_flat.include?(truthy.class) || @@ -5477,6 +6643,11 @@ def format(q) q.group { q.if_break { format_break(q) }.if_flat { format_flat(q) } } end + def ===(other) + other.is_a?(IfOp) && predicate === other.predicate && + truthy === other.truthy && falsy === other.falsy + end + private def format_break(q) @@ -5485,19 +6656,19 @@ def format_break(q) q.nest("if ".length) { q.format(predicate) } q.indent do - q.breakable + q.breakable_space q.format(truthy) end - q.breakable + q.breakable_space q.text("else") q.indent do - q.breakable + q.breakable_space q.format(falsy) end - q.breakable + q.breakable_space q.text("end") end end @@ -5506,103 +6677,17 @@ def format_flat(q) q.format(predicate) q.text(" ?") - q.breakable - q.format(truthy) - q.text(" :") - - q.breakable - q.format(falsy) - end - end - - # Formats an IfMod or UnlessMod node. - class ConditionalModFormatter - # [String] the keyword associated with this conditional - attr_reader :keyword - - # [IfMod | UnlessMod] the node that is being formatted - attr_reader :node - - def initialize(keyword, node) - @keyword = keyword - @node = node - end - - def format(q) - if ContainsAssignment.call(node.statement) || q.parent.is_a?(In) - q.group { format_flat(q) } - else - q.group { q.if_break { format_break(q) }.if_flat { format_flat(q) } } - end - end - - private - - def format_break(q) - q.text("#{keyword} ") - q.nest(keyword.length + 1) { q.format(node.predicate) } q.indent do - q.breakable - q.format(node.statement) - end - q.breakable - q.text("end") - end + q.breakable_space + q.format(truthy) + q.text(" :") - def format_flat(q) - Parentheses.flat(q) do - q.format(node.statement) - q.text(" #{keyword} ") - q.format(node.predicate) + q.breakable_space + q.format(falsy) end end end - # IfMod represents the modifier form of an +if+ statement. - # - # expression if predicate - # - class IfMod < Node - # [untyped] the expression to be executed - attr_reader :statement - - # [untyped] the expression to be checked - attr_reader :predicate - - # [Array[ Comment | EmbDoc ]] the comments attached to this node - attr_reader :comments - - def initialize(statement:, predicate:, location:, comments: []) - @statement = statement - @predicate = predicate - @location = location - @comments = comments - end - - def accept(visitor) - visitor.visit_if_mod(self) - end - - def child_nodes - [statement, predicate] - end - - alias deconstruct child_nodes - - def deconstruct_keys(_keys) - { - statement: statement, - predicate: predicate, - location: location, - comments: comments - } - end - - def format(q) - ConditionalModFormatter.new("if", self).format(q) - end - end - # Imaginary represents an imaginary number literal. # # 1i @@ -5614,10 +6699,10 @@ class Imaginary < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(value:, location:, comments: []) + def initialize(value:, location:) @value = value @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -5628,6 +6713,17 @@ def child_nodes [] end + def copy(value: nil, location: nil) + node = + Imaginary.new( + value: value || self.value, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -5637,6 +6733,10 @@ def deconstruct_keys(_keys) def format(q) q.text(value) end + + def ===(other) + other.is_a?(Imaginary) && value === other.value + end end # In represents using the +in+ keyword within the Ruby 2.7+ pattern matching @@ -5647,7 +6747,7 @@ def format(q) # end # class In < Node - # [untyped] the pattern to check against + # [Node] the pattern to check against attr_reader :pattern # [Statements] the expressions to execute if the pattern matched @@ -5659,12 +6759,12 @@ class In < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(pattern:, statements:, consequent:, location:, comments: []) + def initialize(pattern:, statements:, consequent:, location:) @pattern = pattern @statements = statements @consequent = consequent @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -5675,6 +6775,19 @@ def child_nodes [pattern, statements, consequent] end + def copy(pattern: nil, statements: nil, consequent: nil, location: nil) + node = + In.new( + pattern: pattern || self.pattern, + statements: statements || self.statements, + consequent: consequent || self.consequent, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -5689,24 +6802,32 @@ def deconstruct_keys(_keys) def format(q) keyword = "in " + pattern = self.pattern + consequent = self.consequent q.group do q.text(keyword) q.nest(keyword.length) { q.format(pattern) } + q.text(" then") if pattern.is_a?(RangeNode) && pattern.right.nil? unless statements.empty? q.indent do - q.breakable(force: true) + q.breakable_force q.format(statements) end end if consequent - q.breakable(force: true) + q.breakable_force q.format(consequent) end end end + + def ===(other) + other.is_a?(In) && pattern === other.pattern && + statements === other.statements && consequent === other.consequent + end end # Int represents an integer number literal. @@ -5720,10 +6841,10 @@ class Int < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(value:, location:, comments: []) + def initialize(value:, location:) @value = value @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -5734,6 +6855,14 @@ def child_nodes [] end + def copy(value: nil, location: nil) + node = + Int.new(value: value || self.value, location: location || self.location) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -5751,6 +6880,10 @@ def format(q) q.text(value) end end + + def ===(other) + other.is_a?(Int) && value === other.value + end end # IVar represents an instance variable literal. @@ -5764,10 +6897,10 @@ class IVar < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(value:, location:, comments: []) + def initialize(value:, location:) @value = value @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -5778,6 +6911,17 @@ def child_nodes [] end + def copy(value: nil, location: nil) + node = + IVar.new( + value: value || self.value, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -5787,6 +6931,10 @@ def deconstruct_keys(_keys) def format(q) q.text(value) end + + def ===(other) + other.is_a?(IVar) && value === other.value + end end # Kw represents the use of a keyword. It can be almost anywhere in the syntax @@ -5806,13 +6954,17 @@ class Kw < Node # [String] the value of the keyword attr_reader :value + # [Symbol] the symbol version of the value + attr_reader :name + # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(value:, location:, comments: []) + def initialize(value:, location:) @value = value + @name = value.to_sym @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -5823,6 +6975,14 @@ def child_nodes [] end + def copy(value: nil, location: nil) + node = + Kw.new(value: value || self.value, location: location || self.location) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -5832,6 +6992,10 @@ def deconstruct_keys(_keys) def format(q) q.text(value) end + + def ===(other) + other.is_a?(Kw) && value === other.value + end end # KwRestParam represents defining a parameter in a method definition that @@ -5846,10 +7010,10 @@ class KwRestParam < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(name:, location:, comments: []) + def initialize(name:, location:) @name = name @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -5860,6 +7024,17 @@ def child_nodes [name] end + def copy(name: nil, location: nil) + node = + KwRestParam.new( + name: name || self.name, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -5870,6 +7045,10 @@ def format(q) q.text("**") q.format(name) if name end + + def ===(other) + other.is_a?(KwRestParam) && name === other.name + end end # Label represents the use of an identifier to associate with an object. You @@ -5892,10 +7071,10 @@ class Label < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(value:, location:, comments: []) + def initialize(value:, location:) @value = value @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -5906,6 +7085,17 @@ def child_nodes [] end + def copy(value: nil, location: nil) + node = + Label.new( + value: value || self.value, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -5915,6 +7105,10 @@ def deconstruct_keys(_keys) def format(q) q.text(value) end + + def ===(other) + other.is_a?(Label) && value === other.value + end end # LabelEnd represents the end of a dynamic symbol. @@ -5941,11 +7135,22 @@ def child_nodes [] end + def copy(value: nil, location: nil) + LabelEnd.new( + value: value || self.value, + location: location || self.location + ) + end + alias deconstruct child_nodes def deconstruct_keys(_keys) { value: value, location: location } end + + def ===(other) + other.is_a?(LabelEnd) && value === other.value + end end # Lambda represents using a lambda literal (not the lambda method call). @@ -5953,7 +7158,7 @@ def deconstruct_keys(_keys) # ->(value) { value * 2 } # class Lambda < Node - # [Params | Paren] the parameter declaration for this lambda + # [LambdaVar | Paren] the parameter declaration for this lambda attr_reader :params # [BodyStmt | Statements] the expressions to be executed in this lambda @@ -5962,11 +7167,11 @@ class Lambda < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(params:, statements:, location:, comments: []) + def initialize(params:, statements:, location:) @params = params @statements = statements @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -5977,6 +7182,18 @@ def child_nodes [params, statements] end + def copy(params: nil, statements: nil, location: nil) + node = + Lambda.new( + params: params || self.params, + statements: statements || self.statements, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -5989,9 +7206,14 @@ def deconstruct_keys(_keys) end def format(q) - q.group(0, "->") do + params = self.params + + q.text("->") + q.group do if params.is_a?(Paren) q.format(params) unless params.contents.empty? + elsif params.empty? && params.comments.any? + q.format(params) elsif !params.empty? q.group do q.text("(") @@ -6003,27 +7225,106 @@ def format(q) q.text(" ") q .if_break do - force_parens = - q.parents.any? do |node| - node.is_a?(Command) || node.is_a?(CommandCall) - end + q.text("do") - q.text(force_parens ? "{" : "do") - q.indent do - q.breakable - q.format(statements) + unless statements.empty? + q.indent do + q.breakable_space + q.format(statements) + end end - q.breakable - q.text(force_parens ? "}" : "end") + q.breakable_space + q.text("end") end .if_flat do - q.text("{ ") - q.format(statements) - q.text(" }") + q.text("{") + + unless statements.empty? + q.text(" ") + q.format(statements) + q.text(" ") + end + + q.text("}") end end end + + def ===(other) + other.is_a?(Lambda) && params === other.params && + statements === other.statements + end + end + + # LambdaVar represents the parameters being declared for a lambda. Effectively + # this node is everything contained within the parentheses. This includes all + # of the various parameter types, as well as block-local variable + # declarations. + # + # -> (positional, optional = value, keyword:, █ local) do + # end + # + class LambdaVar < Node + # [Params] the parameters being declared with the block + attr_reader :params + + # [Array[ Ident ]] the list of block-local variable declarations + attr_reader :locals + + # [Array[ Comment | EmbDoc ]] the comments attached to this node + attr_reader :comments + + def initialize(params:, locals:, location:) + @params = params + @locals = locals + @location = location + @comments = [] + end + + def accept(visitor) + visitor.visit_lambda_var(self) + end + + def child_nodes + [params, *locals] + end + + def copy(params: nil, locals: nil, location: nil) + node = + LambdaVar.new( + params: params || self.params, + locals: locals || self.locals, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + + alias deconstruct child_nodes + + def deconstruct_keys(_keys) + { params: params, locals: locals, location: location, comments: comments } + end + + def empty? + params.empty? && locals.empty? + end + + def format(q) + q.format(params) + + if locals.any? + q.text("; ") + q.seplist(locals, BlockVar::SEPARATOR) { |local| q.format(local) } + end + end + + def ===(other) + other.is_a?(LambdaVar) && params === other.params && + ArrayMatch.call(locals, other.locals) + end end # LBrace represents the use of a left brace, i.e., {. @@ -6034,10 +7335,10 @@ class LBrace < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(value:, location:, comments: []) + def initialize(value:, location:) @value = value @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -6048,6 +7349,17 @@ def child_nodes [] end + def copy(value: nil, location: nil) + node = + LBrace.new( + value: value || self.value, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -6057,6 +7369,19 @@ def deconstruct_keys(_keys) def format(q) q.text(value) end + + def ===(other) + other.is_a?(LBrace) && value === other.value + end + + # Because some nodes keep around a { token so that comments can be attached + # to it if they occur in the source, oftentimes an LBrace is a child of + # another node. This means it's required at initialization time. To make it + # easier to create LBrace nodes without any specific value, this method + # provides a default node. + def self.default + new(value: "{", location: Location.default) + end end # LBracket represents the use of a left bracket, i.e., [. @@ -6067,10 +7392,10 @@ class LBracket < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(value:, location:, comments: []) + def initialize(value:, location:) @value = value @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -6081,6 +7406,17 @@ def child_nodes [] end + def copy(value: nil, location: nil) + node = + LBracket.new( + value: value || self.value, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -6090,6 +7426,19 @@ def deconstruct_keys(_keys) def format(q) q.text(value) end + + def ===(other) + other.is_a?(LBracket) && value === other.value + end + + # Because some nodes keep around a [ token so that comments can be attached + # to it if they occur in the source, oftentimes an LBracket is a child of + # another node. This means it's required at initialization time. To make it + # easier to create LBracket nodes without any specific value, this method + # provides a default node. + def self.default + new(value: "[", location: Location.default) + end end # LParen represents the use of a left parenthesis, i.e., (. @@ -6100,10 +7449,10 @@ class LParen < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(value:, location:, comments: []) + def initialize(value:, location:) @value = value @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -6114,6 +7463,17 @@ def child_nodes [] end + def copy(value: nil, location: nil) + node = + LParen.new( + value: value || self.value, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -6123,6 +7483,19 @@ def deconstruct_keys(_keys) def format(q) q.text(value) end + + def ===(other) + other.is_a?(LParen) && value === other.value + end + + # Because some nodes keep around a ( token so that comments can be attached + # to it if they occur in the source, oftentimes an LParen is a child of + # another node. This means it's required at initialization time. To make it + # easier to create LParen nodes without any specific value, this method + # provides a default node. + def self.default + new(value: "(", location: Location.default) + end end # MAssign is a parent node of any kind of multiple assignment. This includes @@ -6143,17 +7516,17 @@ class MAssign < Node # [MLHS | MLHSParen] the target of the multiple assignment attr_reader :target - # [untyped] the value being assigned + # [Node] the value being assigned attr_reader :value # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(target:, value:, location:, comments: []) + def initialize(target:, value:, location:) @target = target @value = value @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -6164,6 +7537,18 @@ def child_nodes [target, value] end + def copy(target: nil, value: nil, location: nil) + node = + MAssign.new( + target: target || self.target, + value: value || self.value, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -6175,11 +7560,15 @@ def format(q) q.group { q.format(target) } q.text(" =") q.indent do - q.breakable + q.breakable_space q.format(value) end end end + + def ===(other) + other.is_a?(MAssign) && target === other.target && value === other.value + end end # MethodAddBlock represents a method call with a block argument. @@ -6187,20 +7576,20 @@ def format(q) # method {} # class MethodAddBlock < Node - # [Call | Command | CommandCall | FCall] the method call + # [ARef | CallNode | Command | CommandCall | Super | ZSuper] the method call attr_reader :call - # [BraceBlock | DoBlock] the block being sent with the method call + # [BlockNode] the block being sent with the method call attr_reader :block # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(call:, block:, location:, comments: []) + def initialize(call:, block:, location:) @call = call @block = block @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -6211,6 +7600,18 @@ def child_nodes [call, block] end + def copy(call: nil, block: nil, location: nil) + node = + MethodAddBlock.new( + call: call || self.call, + block: block || self.block, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -6221,8 +7622,8 @@ def format(q) # If we're at the top of a call chain, then we're going to do some # specialized printing in case we can print it nicely. We _only_ do this # at the top of the chain to avoid weird recursion issues. - if !CallChainFormatter.chained?(q.parent) && - CallChainFormatter.chained?(call) + if CallChainFormatter.chained?(call) && + !CallChainFormatter.chained?(q.parent) q.group do q .if_break { CallChainFormatter.new(self).format(q) } @@ -6233,6 +7634,11 @@ def format(q) end end + def ===(other) + other.is_a?(MethodAddBlock) && call === other.call && + block === other.block + end + def format_contents(q) q.format(call) q.format(block) @@ -6245,8 +7651,12 @@ def format_contents(q) # first, second, third = value # class MLHS < Node - # Array[ARefField | ArgStar | Field | Ident | MLHSParen | VarField] the - # parts of the left-hand side of a multiple assignment + # [ + # Array[ + # ARefField | ArgStar | ConstPathField | Field | Ident | MLHSParen | + # TopConstField | VarField + # ] + # ] the parts of the left-hand side of a multiple assignment attr_reader :parts # [boolean] whether or not there is a trailing comma at the end of this @@ -6257,11 +7667,11 @@ class MLHS < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(parts:, comma: false, location:, comments: []) + def initialize(parts:, location:, comma: false) @parts = parts @comma = comma @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -6272,6 +7682,18 @@ def child_nodes parts end + def copy(parts: nil, location: nil, comma: nil) + node = + MLHS.new( + parts: parts || self.parts, + location: location || self.location, + comma: comma || self.comma + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -6282,6 +7704,11 @@ def format(q) q.seplist(parts) { |part| q.format(part) } q.text(",") if comma end + + def ===(other) + other.is_a?(MLHS) && ArrayMatch.call(parts, other.parts) && + comma === other.comma + end end # MLHSParen represents parentheses being used to destruct values in a multiple @@ -6301,11 +7728,11 @@ class MLHSParen < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(contents:, comma: false, location:, comments: []) + def initialize(contents:, location:, comma: false) @contents = contents @comma = comma @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -6316,6 +7743,17 @@ def child_nodes [contents] end + def copy(contents: nil, location: nil) + node = + MLHSParen.new( + contents: contents || self.contents, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -6329,17 +7767,23 @@ def format(q) q.format(contents) q.text(",") if comma else - q.group(0, "(", ")") do + q.text("(") + q.group do q.indent do - q.breakable("") + q.breakable_empty q.format(contents) end q.text(",") if comma - q.breakable("") + q.breakable_empty end + q.text(")") end end + + def ===(other) + other.is_a?(MLHSParen) && contents === other.contents + end end # ModuleDeclaration represents defining a module using the +module+ keyword. @@ -6357,11 +7801,11 @@ class ModuleDeclaration < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(constant:, bodystmt:, location:, comments: []) + def initialize(constant:, bodystmt:, location:) @constant = constant @bodystmt = bodystmt @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -6372,6 +7816,18 @@ def child_nodes [constant, bodystmt] end + def copy(constant: nil, bodystmt: nil, location: nil) + node = + ModuleDeclaration.new( + constant: constant || self.constant, + bodystmt: bodystmt || self.bodystmt, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -6384,33 +7840,40 @@ def deconstruct_keys(_keys) end def format(q) - declaration = -> do - q.group do - q.text("module ") - q.format(constant) - end - end - if bodystmt.empty? q.group do - declaration.call - q.breakable(force: true) + format_declaration(q) + q.breakable_force q.text("end") end else q.group do - declaration.call + format_declaration(q) q.indent do - q.breakable(force: true) + q.breakable_force q.format(bodystmt) end - q.breakable(force: true) + q.breakable_force q.text("end") end end end + + def ===(other) + other.is_a?(ModuleDeclaration) && constant === other.constant && + bodystmt === other.bodystmt + end + + private + + def format_declaration(q) + q.group do + q.text("module ") + q.format(constant) + end + end end # MRHS represents the values that are being assigned on the right-hand side of @@ -6419,16 +7882,16 @@ def format(q) # values = first, second, third # class MRHS < Node - # Array[untyped] the parts that are being assigned + # [Array[Node]] the parts that are being assigned attr_reader :parts # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(parts:, location:, comments: []) + def initialize(parts:, location:) @parts = parts @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -6439,6 +7902,17 @@ def child_nodes parts end + def copy(parts: nil, location: nil) + node = + MRHS.new( + parts: parts || self.parts, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -6448,6 +7922,10 @@ def deconstruct_keys(_keys) def format(q) q.seplist(parts) { |part| q.format(part) } end + + def ===(other) + other.is_a?(MRHS) && ArrayMatch.call(parts, other.parts) + end end # Next represents using the +next+ keyword. @@ -6474,10 +7952,10 @@ class Next < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(arguments:, location:, comments: []) + def initialize(arguments:, location:) @arguments = arguments @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -6488,6 +7966,17 @@ def child_nodes [arguments] end + def copy(arguments: nil, location: nil) + node = + Next.new( + arguments: arguments || self.arguments, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -6497,6 +7986,10 @@ def deconstruct_keys(_keys) def format(q) FlowControlFormatter.new("next", self).format(q) end + + def ===(other) + other.is_a?(Next) && arguments === other.arguments + end end # Op represents an operator literal in the source. @@ -6508,13 +8001,17 @@ class Op < Node # [String] the operator attr_reader :value + # [Symbol] the symbol version of the value + attr_reader :name + # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(value:, location:, comments: []) + def initialize(value:, location:) @value = value + @name = value.to_sym @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -6525,6 +8022,14 @@ def child_nodes [] end + def copy(value: nil, location: nil) + node = + Op.new(value: value || self.value, location: location || self.location) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -6534,6 +8039,10 @@ def deconstruct_keys(_keys) def format(q) q.text(value) end + + def ===(other) + other.is_a?(Op) && value === other.value + end end # OpAssign represents assigning a value to a variable or constant using an @@ -6549,18 +8058,18 @@ class OpAssign < Node # [Op] the operator being used for the assignment attr_reader :operator - # [untyped] the expression to be assigned + # [Node] the expression to be assigned attr_reader :value # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(target:, operator:, value:, location:, comments: []) + def initialize(target:, operator:, value:, location:) @target = target @operator = operator @value = value @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -6571,6 +8080,19 @@ def child_nodes [target, operator, value] end + def copy(target: nil, operator: nil, value: nil, location: nil) + node = + OpAssign.new( + target: target || self.target, + operator: operator || self.operator, + value: value || self.value, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -6594,13 +8116,18 @@ def format(q) q.format(value) else q.indent do - q.breakable + q.breakable_space q.format(value) end end end end + def ===(other) + other.is_a?(OpAssign) && target === other.target && + operator === other.operator && value === other.value + end + private def skip_indent? @@ -6646,7 +8173,7 @@ module Parentheses Assign, Assoc, Binary, - Call, + CallNode, Defined, MAssign, OpAssign @@ -6665,10 +8192,10 @@ def self.break(q) q.text("(") q.indent do - q.breakable("") + q.breakable_empty yield end - q.breakable("") + q.breakable_empty q.text(")") end end @@ -6688,7 +8215,7 @@ class OptionalFormatter # [Ident] the name of the parameter attr_reader :name - # [untyped] the value of the parameter + # [Node] the value of the parameter attr_reader :value def initialize(name, value) @@ -6713,7 +8240,7 @@ class KeywordFormatter # [Ident] the name of the parameter attr_reader :name - # [nil | untyped] the value of the parameter + # [nil | Node] the value of the parameter attr_reader :value def initialize(name, value) @@ -6754,10 +8281,10 @@ def format(q) end end - # [Array[ Ident ]] any required parameters + # [Array[ Ident | MLHSParen ]] any required parameters attr_reader :requireds - # [Array[ [ Ident, untyped ] ]] any optional parameters and their default + # [Array[ [ Ident, Node ] ]] any optional parameters and their default # values attr_reader :optionals @@ -6765,15 +8292,16 @@ def format(q) # parameter attr_reader :rest - # [Array[ Ident ]] any positional parameters that exist after a rest - # parameter + # [Array[ Ident | MLHSParen ]] any positional parameters that exist after a + # rest parameter attr_reader :posts - # [Array[ [ Ident, nil | untyped ] ]] any keyword parameters and their + # [Array[ [ Label, nil | Node ] ]] any keyword parameters and their # optional default values attr_reader :keywords - # [nil | :nil | KwRestParam] the optional keyword rest parameter + # [nil | :nil | ArgsForward | KwRestParam] the optional keyword rest + # parameter attr_reader :keyword_rest # [nil | BlockArg] the optional block parameter @@ -6783,15 +8311,14 @@ def format(q) attr_reader :comments def initialize( + location:, requireds: [], optionals: [], rest: nil, posts: [], keywords: [], keyword_rest: nil, - block: nil, - location:, - comments: [] + block: nil ) @requireds = requireds @optionals = optionals @@ -6801,7 +8328,7 @@ def initialize( @keyword_rest = keyword_rest @block = block @location = location - @comments = comments + @comments = [] end # Params nodes are the most complicated in the tree. Occasionally you want @@ -6818,6 +8345,8 @@ def accept(visitor) end def child_nodes + keyword_rest = self.keyword_rest + [ *requireds, *optionals.flatten(1), @@ -6829,6 +8358,32 @@ def child_nodes ] end + def copy( + location: nil, + requireds: nil, + optionals: nil, + rest: nil, + posts: nil, + keywords: nil, + keyword_rest: nil, + block: nil + ) + node = + Params.new( + location: location || self.location, + requireds: requireds || self.requireds, + optionals: optionals || self.optionals, + rest: rest || self.rest, + posts: posts || self.posts, + keywords: keywords || self.keywords, + keyword_rest: keyword_rest || self.keyword_rest, + block: block || self.block + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -6846,37 +8401,88 @@ def deconstruct_keys(_keys) end def format(q) + rest = self.rest + keyword_rest = self.keyword_rest + parts = [ *requireds, *optionals.map { |(name, value)| OptionalFormatter.new(name, value) } ] parts << rest if rest && !rest.is_a?(ExcessedComma) - parts += [ - *posts, - *keywords.map { |(name, value)| KeywordFormatter.new(name, value) } - ] + parts.concat(posts) + parts.concat( + keywords.map { |(name, value)| KeywordFormatter.new(name, value) } + ) parts << KeywordRestFormatter.new(keyword_rest) if keyword_rest parts << block if block - contents = -> do - q.seplist(parts) { |part| q.format(part) } - q.format(rest) if rest.is_a?(ExcessedComma) + if parts.empty? + q.nest(0) { format_contents(q, parts) } + return end - if ![Def, Defs, DefEndless].include?(q.parent.class) || parts.empty? - q.nest(0, &contents) - else - q.group(0, "(", ")") do - q.indent do - q.breakable("") - contents.call + if q.parent.is_a?(DefNode) + q.nest(0) do + q.text("(") + q.group do + q.indent do + q.breakable_empty + format_contents(q, parts) + end + q.breakable_empty end - q.breakable("") + q.text(")") end + else + q.nest(0) { format_contents(q, parts) } end end + + def ===(other) + other.is_a?(Params) && ArrayMatch.call(requireds, other.requireds) && + optionals.length == other.optionals.length && + optionals + .zip(other.optionals) + .all? { |left, right| ArrayMatch.call(left, right) } && + rest === other.rest && ArrayMatch.call(posts, other.posts) && + keywords.length == other.keywords.length && + keywords + .zip(other.keywords) + .all? { |left, right| ArrayMatch.call(left, right) } && + keyword_rest === other.keyword_rest && block === other.block + end + + # Returns a range representing the possible number of arguments accepted + # by this params node not including the block. For example: + # + # def foo(a, b = 1, c:, d: 2, &block) + # ... + # end + # + # has arity 2..4. + # + def arity + optional_keywords = keywords.count { |_label, value| value } + + lower_bound = + requireds.length + posts.length + keywords.length - optional_keywords + + upper_bound = + if keyword_rest.nil? && rest.nil? + lower_bound + optionals.length + optional_keywords + end + + lower_bound..upper_bound + end + + private + + def format_contents(q, parts) + q.seplist(parts) { |part| q.format(part) } + q.format(rest) if rest.is_a?(ExcessedComma) + end end # Paren represents using balanced parentheses in a couple places in a Ruby @@ -6889,17 +8495,17 @@ class Paren < Node # [LParen] the left parenthesis that opened this statement attr_reader :lparen - # [nil | untyped] the expression inside the parentheses + # [nil | Node] the expression inside the parentheses attr_reader :contents # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(lparen:, contents:, location:, comments: []) + def initialize(lparen:, contents:, location:) @lparen = lparen @contents = contents @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -6910,6 +8516,18 @@ def child_nodes [lparen, contents] end + def copy(lparen: nil, contents: nil, location: nil) + node = + Paren.new( + lparen: lparen || self.lparen, + contents: contents || self.contents, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -6922,20 +8540,27 @@ def deconstruct_keys(_keys) end def format(q) + contents = self.contents + q.group do q.format(lparen) if contents && (!contents.is_a?(Params) || !contents.empty?) q.indent do - q.breakable("") + q.breakable_empty q.format(contents) end end - q.breakable("") + q.breakable_empty q.text(")") end end + + def ===(other) + other.is_a?(Paren) && lparen === other.lparen && + contents === other.contents + end end # Period represents the use of the +.+ operator. It is usually found in method @@ -6947,10 +8572,10 @@ class Period < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(value:, location:, comments: []) + def initialize(value:, location:) @value = value @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -6961,6 +8586,17 @@ def child_nodes [] end + def copy(value: nil, location: nil) + node = + Period.new( + value: value || self.value, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -6970,6 +8606,10 @@ def deconstruct_keys(_keys) def format(q) q.text(value) end + + def ===(other) + other.is_a?(Period) && value === other.value + end end # Program represents the overall syntax tree. @@ -6980,10 +8620,10 @@ class Program < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(statements:, location:, comments: []) + def initialize(statements:, location:) @statements = statements @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -6994,6 +8634,17 @@ def child_nodes [statements] end + def copy(statements: nil, location: nil) + node = + Program.new( + statements: statements || self.statements, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -7006,7 +8657,11 @@ def format(q) # We're going to put a newline on the end so that it always has one unless # it ends with the special __END__ syntax. In that case we want to # replicate the text exactly so we will just let it be. - q.breakable(force: true) unless statements.body.last.is_a?(EndContent) + q.breakable_force unless statements.body.last.is_a?(EndContent) + end + + def ===(other) + other.is_a?(Program) && statements === other.statements end end @@ -7024,19 +8679,31 @@ class QSymbols < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(beginning:, elements:, location:, comments: []) + def initialize(beginning:, elements:, location:) @beginning = beginning @elements = elements @location = location - @comments = comments + @comments = [] end def accept(visitor) visitor.visit_qsymbols(self) end - def child_nodes - [] + def child_nodes + [] + end + + def copy(beginning: nil, elements: nil, location: nil) + node = + QSymbols.new( + beginning: beginning || self.beginning, + elements: elements || self.elements, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node end alias deconstruct child_nodes @@ -7058,15 +8725,23 @@ def format(q) closing = Quotes.matching(opening[2]) end - q.group(0, opening, closing) do + q.text(opening) + q.group do q.indent do - q.breakable("") - q.seplist(elements, -> { q.breakable }) do |element| - q.format(element) - end + q.breakable_empty + q.seplist( + elements, + ArrayLiteral::BREAKABLE_SPACE_SEPARATOR + ) { |element| q.format(element) } end - q.breakable("") + q.breakable_empty end + q.text(closing) + end + + def ===(other) + other.is_a?(QSymbols) && beginning === other.beginning && + ArrayMatch.call(elements, other.elements) end end @@ -7094,11 +8769,22 @@ def child_nodes [] end + def copy(value: nil, location: nil) + QSymbolsBeg.new( + value: value || self.value, + location: location || self.location + ) + end + alias deconstruct child_nodes def deconstruct_keys(_keys) { value: value, location: location } end + + def ===(other) + other.is_a?(QSymbolsBeg) && value === other.value + end end # QWords represents a string literal array without interpolation. @@ -7115,11 +8801,11 @@ class QWords < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(beginning:, elements:, location:, comments: []) + def initialize(beginning:, elements:, location:) @beginning = beginning @elements = elements @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -7130,6 +8816,14 @@ def child_nodes [] end + def copy(beginning: nil, elements: nil, location: nil) + QWords.new( + beginning: beginning || self.beginning, + elements: elements || self.elements, + location: location || self.location + ) + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -7149,15 +8843,23 @@ def format(q) closing = Quotes.matching(opening[2]) end - q.group(0, opening, closing) do + q.text(opening) + q.group do q.indent do - q.breakable("") - q.seplist(elements, -> { q.breakable }) do |element| - q.format(element) - end + q.breakable_empty + q.seplist( + elements, + ArrayLiteral::BREAKABLE_SPACE_SEPARATOR + ) { |element| q.format(element) } end - q.breakable("") + q.breakable_empty end + q.text(closing) + end + + def ===(other) + other.is_a?(QWords) && beginning === other.beginning && + ArrayMatch.call(elements, other.elements) end end @@ -7185,11 +8887,22 @@ def child_nodes [] end + def copy(value: nil, location: nil) + QWordsBeg.new( + value: value || self.value, + location: location || self.location + ) + end + alias deconstruct child_nodes def deconstruct_keys(_keys) { value: value, location: location } end + + def ===(other) + other.is_a?(QWordsBeg) && value === other.value + end end # RationalLiteral represents the use of a rational number literal. @@ -7203,10 +8916,10 @@ class RationalLiteral < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(value:, location:, comments: []) + def initialize(value:, location:) @value = value @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -7217,6 +8930,17 @@ def child_nodes [] end + def copy(value: nil, location: nil) + node = + RationalLiteral.new( + value: value || self.value, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -7226,6 +8950,10 @@ def deconstruct_keys(_keys) def format(q) q.text(value) end + + def ===(other) + other.is_a?(RationalLiteral) && value === other.value + end end # RBrace represents the use of a right brace, i.e., +++. @@ -7246,11 +8974,22 @@ def child_nodes [] end + def copy(value: nil, location: nil) + RBrace.new( + value: value || self.value, + location: location || self.location + ) + end + alias deconstruct child_nodes def deconstruct_keys(_keys) { value: value, location: location } end + + def ===(other) + other.is_a?(RBrace) && value === other.value + end end # RBracket represents the use of a right bracket, i.e., +]+. @@ -7271,11 +9010,22 @@ def child_nodes [] end + def copy(value: nil, location: nil) + RBracket.new( + value: value || self.value, + location: location || self.location + ) + end + alias deconstruct child_nodes def deconstruct_keys(_keys) { value: value, location: location } end + + def ===(other) + other.is_a?(RBracket) && value === other.value + end end # Redo represents the use of the +redo+ keyword. @@ -7283,16 +9033,12 @@ def deconstruct_keys(_keys) # redo # class Redo < Node - # [String] the value of the keyword - attr_reader :value - # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(value:, location:, comments: []) - @value = value + def initialize(location:) @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -7303,14 +9049,25 @@ def child_nodes [] end + def copy(location: nil) + node = Redo.new(location: location || self.location) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) - { value: value, location: location, comments: comments } + { location: location, comments: comments } end def format(q) - q.text(value) + q.text("redo") + end + + def ===(other) + other.is_a?(Redo) end end @@ -7342,11 +9099,24 @@ def child_nodes parts end + def copy(beginning: nil, parts: nil, location: nil) + RegexpContent.new( + beginning: beginning || self.beginning, + parts: parts || self.parts, + location: location || self.location + ) + end + alias deconstruct child_nodes def deconstruct_keys(_keys) { beginning: beginning, parts: parts, location: location } end + + def ===(other) + other.is_a?(RegexpContent) && beginning === other.beginning && + parts === other.parts + end end # RegexpBeg represents the start of a regular expression literal. @@ -7375,11 +9145,22 @@ def child_nodes [] end + def copy(value: nil, location: nil) + RegexpBeg.new( + value: value || self.value, + location: location || self.location + ) + end + alias deconstruct child_nodes def deconstruct_keys(_keys) { value: value, location: location } end + + def ===(other) + other.is_a?(RegexpBeg) && value === other.value + end end # RegexpEnd represents the end of a regular expression literal. @@ -7409,11 +9190,22 @@ def child_nodes [] end + def copy(value: nil, location: nil) + RegexpEnd.new( + value: value || self.value, + location: location || self.location + ) + end + alias deconstruct child_nodes def deconstruct_keys(_keys) { value: value, location: location } end + + def ===(other) + other.is_a?(RegexpEnd) && value === other.value + end end # RegexpLiteral represents a regular expression literal. @@ -7434,12 +9226,12 @@ class RegexpLiteral < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(beginning:, ending:, parts:, location:, comments: []) + def initialize(beginning:, ending:, parts:, location:) @beginning = beginning @ending = ending @parts = parts @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -7450,6 +9242,19 @@ def child_nodes parts end + def copy(beginning: nil, ending: nil, parts: nil, location: nil) + node = + RegexpLiteral.new( + beginning: beginning || self.beginning, + ending: ending || self.ending, + parts: parts || self.parts, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -7503,6 +9308,12 @@ def format(q) end end + def ===(other) + other.is_a?(RegexpLiteral) && beginning === other.beginning && + ending === other.ending && options === other.options && + ArrayMatch.call(parts, other.parts) + end + def options ending[1..] end @@ -7535,7 +9346,7 @@ def ambiguous?(q) # end # class RescueEx < Node - # [untyped] the list of exceptions being rescued + # [nil | Node] the list of exceptions being rescued attr_reader :exceptions # [nil | Field | VarField] the expression being used to capture the raised @@ -7545,11 +9356,11 @@ class RescueEx < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(exceptions:, variable:, location:, comments: []) + def initialize(exceptions:, variable:, location:) @exceptions = exceptions @variable = variable @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -7560,6 +9371,18 @@ def child_nodes [*exceptions, variable] end + def copy(exceptions: nil, variable: nil, location: nil) + node = + RescueEx.new( + exceptions: exceptions || self.exceptions, + variable: variable || self.variable, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -7584,6 +9407,11 @@ def format(q) end end end + + def ===(other) + other.is_a?(RescueEx) && exceptions === other.exceptions && + variable === other.variable + end end # Rescue represents the use of the rescue keyword inside of a BodyStmt node. @@ -7596,7 +9424,7 @@ class Rescue < Node # [Kw] the rescue keyword attr_reader :keyword - # [RescueEx] the exceptions being rescued + # [nil | RescueEx] the exceptions being rescued attr_reader :exception # [Statements] the expressions to evaluate when an error is rescued @@ -7608,20 +9436,13 @@ class Rescue < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize( - keyword:, - exception:, - statements:, - consequent:, - location:, - comments: [] - ) + def initialize(keyword:, exception:, statements:, consequent:, location:) @keyword = keyword @exception = exception @statements = statements @consequent = consequent @location = location - @comments = comments + @comments = [] end def bind_end(end_char, end_column) @@ -7635,11 +9456,11 @@ def bind_end(end_char, end_column) end_column: end_column ) - if consequent - consequent.bind_end(end_char, end_column) + if (next_node = consequent) + next_node.bind_end(end_char, end_column) statements.bind_end( - consequent.location.start_char, - consequent.location.start_column + next_node.location.start_char, + next_node.location.start_column ) else statements.bind_end(end_char, end_column) @@ -7654,6 +9475,26 @@ def child_nodes [keyword, exception, statements, consequent] end + def copy( + keyword: nil, + exception: nil, + statements: nil, + consequent: nil, + location: nil + ) + node = + Rescue.new( + keyword: keyword || self.keyword, + exception: exception || self.exception, + statements: statements || self.statements, + consequent: consequent || self.consequent, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -7679,17 +9520,23 @@ def format(q) unless statements.empty? q.indent do - q.breakable(force: true) + q.breakable_force q.format(statements) end end if consequent - q.breakable(force: true) + q.breakable_force q.format(consequent) end end end + + def ===(other) + other.is_a?(Rescue) && keyword === other.keyword && + exception === other.exception && statements === other.statements && + consequent === other.consequent + end end # RescueMod represents the use of the modifier form of a +rescue+ clause. @@ -7697,20 +9544,20 @@ def format(q) # expression rescue value # class RescueMod < Node - # [untyped] the expression to execute + # [Node] the expression to execute attr_reader :statement - # [untyped] the value to use if the executed expression raises an error + # [Node] the value to use if the executed expression raises an error attr_reader :value # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(statement:, value:, location:, comments: []) + def initialize(statement:, value:, location:) @statement = statement @value = value @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -7721,6 +9568,18 @@ def child_nodes [statement, value] end + def copy(statement: nil, value: nil, location: nil) + node = + RescueMod.new( + statement: statement || self.statement, + value: value || self.value, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -7733,19 +9592,26 @@ def deconstruct_keys(_keys) end def format(q) - q.group(0, "begin", "end") do + q.text("begin") + q.group do q.indent do - q.breakable(force: true) + q.breakable_force q.format(statement) end - q.breakable(force: true) + q.breakable_force q.text("rescue StandardError") q.indent do - q.breakable(force: true) + q.breakable_force q.format(value) end - q.breakable(force: true) + q.breakable_force end + q.text("end") + end + + def ===(other) + other.is_a?(RescueMod) && statement === other.statement && + value === other.value end end @@ -7761,10 +9627,10 @@ class RestParam < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(name:, location:, comments: []) + def initialize(name:, location:) @name = name @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -7775,6 +9641,17 @@ def child_nodes [name] end + def copy(name: nil, location: nil) + node = + RestParam.new( + name: name || self.name, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -7785,6 +9662,10 @@ def format(q) q.text("*") q.format(name) if name end + + def ===(other) + other.is_a?(RestParam) && name === other.name + end end # Retry represents the use of the +retry+ keyword. @@ -7792,16 +9673,12 @@ def format(q) # retry # class Retry < Node - # [String] the value of the keyword - attr_reader :value - # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(value:, location:, comments: []) - @value = value + def initialize(location:) @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -7812,14 +9689,25 @@ def child_nodes [] end + def copy(location: nil) + node = Retry.new(location: location || self.location) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) - { value: value, location: location, comments: comments } + { location: location, comments: comments } end def format(q) - q.text(value) + q.text("retry") + end + + def ===(other) + other.is_a?(Retry) end end @@ -7827,17 +9715,17 @@ def format(q) # # return value # - class Return < Node - # [Args] the arguments being passed to the keyword + class ReturnNode < Node + # [nil | Args] the arguments being passed to the keyword attr_reader :arguments # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(arguments:, location:, comments: []) + def initialize(arguments:, location:) @arguments = arguments @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -7848,6 +9736,17 @@ def child_nodes [arguments] end + def copy(arguments: nil, location: nil) + node = + ReturnNode.new( + arguments: arguments || self.arguments, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -7857,41 +9756,9 @@ def deconstruct_keys(_keys) def format(q) FlowControlFormatter.new("return", self).format(q) end - end - - # Return0 represents the bare +return+ keyword with no arguments. - # - # return - # - class Return0 < Node - # [String] the value of the keyword - attr_reader :value - - # [Array[ Comment | EmbDoc ]] the comments attached to this node - attr_reader :comments - - def initialize(value:, location:, comments: []) - @value = value - @location = location - @comments = comments - end - - def accept(visitor) - visitor.visit_return0(self) - end - - def child_nodes - [] - end - - alias deconstruct child_nodes - - def deconstruct_keys(_keys) - { value: value, location: location, comments: comments } - end - def format(q) - q.text(value) + def ===(other) + other.is_a?(ReturnNode) && arguments === other.arguments end end @@ -7913,11 +9780,22 @@ def child_nodes [] end + def copy(value: nil, location: nil) + RParen.new( + value: value || self.value, + location: location || self.location + ) + end + alias deconstruct child_nodes def deconstruct_keys(_keys) { value: value, location: location } end + + def ===(other) + other.is_a?(RParen) && value === other.value + end end # SClass represents a block of statements that should be evaluated within the @@ -7928,7 +9806,7 @@ def deconstruct_keys(_keys) # end # class SClass < Node - # [untyped] the target of the singleton class to enter + # [Node] the target of the singleton class to enter attr_reader :target # [BodyStmt] the expressions to be executed @@ -7937,11 +9815,11 @@ class SClass < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(target:, bodystmt:, location:, comments: []) + def initialize(target:, bodystmt:, location:) @target = target @bodystmt = bodystmt @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -7952,6 +9830,18 @@ def child_nodes [target, bodystmt] end + def copy(target: nil, bodystmt: nil, location: nil) + node = + SClass.new( + target: target || self.target, + bodystmt: bodystmt || self.bodystmt, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -7964,14 +9854,21 @@ def deconstruct_keys(_keys) end def format(q) - q.group(0, "class << ", "end") do + q.text("class << ") + q.group do q.format(target) q.indent do - q.breakable(force: true) + q.breakable_force q.format(bodystmt) end - q.breakable(force: true) + q.breakable_force end + q.text("end") + end + + def ===(other) + other.is_a?(SClass) && target === other.target && + bodystmt === other.bodystmt end end @@ -7983,23 +9880,19 @@ def format(q) # propagate that onto void_stmt nodes inside the stmts in order to make sure # all comments get printed appropriately. class Statements < Node - # [SyntaxTree] the parser that is generating this node - attr_reader :parser - - # [Array[ untyped ]] the list of expressions contained within this node + # [Array[ Node ]] the list of expressions contained within this node attr_reader :body # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(parser, body:, location:, comments: []) - @parser = parser + def initialize(body:, location:) @body = body @location = location - @comments = comments + @comments = [] end - def bind(start_char, start_column, end_char, end_column) + def bind(parser, start_char, start_column, end_char, end_column) @location = Location.new( start_line: location.start_line, @@ -8010,8 +9903,8 @@ def bind(start_char, start_column, end_char, end_column) end_column: end_column ) - if body[0].is_a?(VoidStmt) - location = body[0].location + if (void_stmt = body[0]).is_a?(VoidStmt) + location = void_stmt.location location = Location.new( start_line: location.start_line, @@ -8025,7 +9918,7 @@ def bind(start_char, start_column, end_char, end_column) body[0] = VoidStmt.new(location: location) end - attach_comments(start_char, end_char) + attach_comments(parser, start_char, end_char) end def bind_end(end_char, end_column) @@ -8054,10 +9947,21 @@ def child_nodes body end + def copy(body: nil, location: nil) + node = + Statements.new( + body: body || self.body, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) - { parser: parser, body: body, location: location, comments: comments } + { body: body, location: location, comments: comments } end def format(q) @@ -8077,30 +9981,26 @@ def format(q) end end - access_controls = - Hash.new do |hash, node| - hash[node] = node.is_a?(VCall) && - %w[private protected public].include?(node.value.value) - end - - body.each_with_index do |statement, index| + previous = nil + body.each do |statement| next if statement.is_a?(VoidStmt) if line.nil? q.format(statement) elsif (statement.location.start_line - line) > 1 - q.breakable(force: true) - q.breakable(force: true) + q.breakable_force + q.breakable_force q.format(statement) - elsif access_controls[statement] || access_controls[body[index - 1]] - q.breakable(force: true) - q.breakable(force: true) + elsif (statement.is_a?(VCall) && statement.access_control?) || + (previous.is_a?(VCall) && previous.access_control?) + q.breakable_force + q.breakable_force q.format(statement) elsif statement.location.start_line != line - q.breakable(force: true) + q.breakable_force q.format(statement) elsif !q.parent.is_a?(StringEmbExpr) - q.breakable(force: true) + q.breakable_force q.format(statement) else q.text("; ") @@ -8108,15 +10008,20 @@ def format(q) end line = statement.location.end_line + previous = statement end end + def ===(other) + other.is_a?(Statements) && ArrayMatch.call(body, other.body) + end + private # As efficiently as possible, gather up all of the comments that have been # found while this statements list was being parsed and add them into the # body. - def attach_comments(start_char, end_char) + def attach_comments(parser, start_char, end_char) parser_comments = parser.comments comment_index = 0 @@ -8163,9 +10068,13 @@ class StringContent < Node # string attr_reader :parts + # [Array[ Comment | EmbDoc ]] the comments attached to this node + attr_reader :comments + def initialize(parts:, location:) @parts = parts @location = location + @comments = [] end def accept(visitor) @@ -8176,11 +10085,49 @@ def child_nodes parts end + def copy(parts: nil, location: nil) + StringContent.new( + parts: parts || self.parts, + location: location || self.location + ) + end + alias deconstruct child_nodes def deconstruct_keys(_keys) { parts: parts, location: location } end + + def ===(other) + other.is_a?(StringContent) && ArrayMatch.call(parts, other.parts) + end + + def format(q) + q.text(q.quote) + q.group do + parts.each do |part| + if part.is_a?(TStringContent) + value = Quotes.normalize(part.value, q.quote) + first = true + + value.each_line(chomp: true) do |line| + if first + first = false + else + q.breakable_return + end + + q.text(line) + end + + q.breakable_return if value.end_with?("\n") + else + q.format(part) + end + end + end + q.text(q.quote) + end end # StringConcat represents concatenating two strings together using a backward @@ -8190,7 +10137,8 @@ def deconstruct_keys(_keys) # "second" # class StringConcat < Node - # [StringConcat | StringLiteral] the left side of the concatenation + # [Heredoc | StringConcat | StringLiteral] the left side of the + # concatenation attr_reader :left # [StringLiteral] the right side of the concatenation @@ -8199,11 +10147,11 @@ class StringConcat < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(left:, right:, location:, comments: []) + def initialize(left:, right:, location:) @left = left @right = right @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -8214,6 +10162,18 @@ def child_nodes [left, right] end + def copy(left: nil, right: nil, location: nil) + node = + StringConcat.new( + left: left || self.left, + right: right || self.right, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -8225,11 +10185,15 @@ def format(q) q.format(left) q.text(" \\") q.indent do - q.breakable(force: true) + q.breakable_force q.format(right) end end end + + def ===(other) + other.is_a?(StringConcat) && left === other.left && right === other.right + end end # StringDVar represents shorthand interpolation of a variable into a string. @@ -8245,10 +10209,10 @@ class StringDVar < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(variable:, location:, comments: []) + def initialize(variable:, location:) @variable = variable @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -8259,6 +10223,17 @@ def child_nodes [variable] end + def copy(variable: nil, location: nil) + node = + StringDVar.new( + variable: variable || self.variable, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -8270,6 +10245,10 @@ def format(q) q.format(variable) q.text("}") end + + def ===(other) + other.is_a?(StringDVar) && variable === other.variable + end end # StringEmbExpr represents interpolated content. It can be contained within a @@ -8285,10 +10264,10 @@ class StringEmbExpr < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(statements:, location:, comments: []) + def initialize(statements:, location:) @statements = statements @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -8299,6 +10278,17 @@ def child_nodes [statements] end + def copy(statements: nil, location: nil) + node = + StringEmbExpr.new( + statements: statements || self.statements, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -8311,20 +10301,29 @@ def format(q) # same line in the source, then we're going to leave them in place and # assume that's the way the developer wanted this expression # represented. - doc = q.group(0, '#{', "}") { q.format(statements) } - RemoveBreaks.call(doc) + q.remove_breaks( + q.group do + q.text('#{') + q.format(statements) + q.text("}") + end + ) else q.group do q.text('#{') q.indent do - q.breakable("") + q.breakable_empty q.format(statements) end - q.breakable("") + q.breakable_empty q.text("}") end end end + + def ===(other) + other.is_a?(StringEmbExpr) && statements === other.statements + end end # StringLiteral represents a string literal. @@ -8336,17 +10335,17 @@ class StringLiteral < Node # string literal attr_reader :parts - # [String] which quote was used by the string literal + # [nil | String] which quote was used by the string literal attr_reader :quote # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(parts:, quote:, location:, comments: []) + def initialize(parts:, quote:, location:) @parts = parts @quote = quote @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -8357,6 +10356,18 @@ def child_nodes parts end + def copy(parts: nil, quote: nil, location: nil) + node = + StringLiteral.new( + parts: parts || self.parts, + quote: quote || self.quote, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -8370,27 +10381,43 @@ def format(q) end opening_quote, closing_quote = - if !Quotes.locked?(self) + if !Quotes.locked?(self, q.quote) [q.quote, q.quote] - elsif quote.start_with?("%") + elsif quote&.start_with?("%") [quote, Quotes.matching(quote[/%[qQ]?(.)/, 1])] else [quote, quote] end - q.group(0, opening_quote, closing_quote) do + q.text(opening_quote) + q.group do parts.each do |part| if part.is_a?(TStringContent) value = Quotes.normalize(part.value, closing_quote) - separator = -> { q.breakable(force: true, indent: false) } - q.seplist(value.split(/\r?\n/, -1), separator) do |text| - q.text(text) + first = true + + value.each_line(chomp: true) do |line| + if first + first = false + else + q.breakable_return + end + + q.text(line) end + + q.breakable_return if value.end_with?("\n") else q.format(part) end end end + q.text(closing_quote) + end + + def ===(other) + other.is_a?(StringLiteral) && ArrayMatch.call(parts, other.parts) && + quote === other.quote end end @@ -8406,10 +10433,10 @@ class Super < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(arguments:, location:, comments: []) + def initialize(arguments:, location:) @arguments = arguments @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -8420,6 +10447,17 @@ def child_nodes [arguments] end + def copy(arguments: nil, location: nil) + node = + Super.new( + arguments: arguments || self.arguments, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -8438,6 +10476,10 @@ def format(q) end end end + + def ===(other) + other.is_a?(Super) && arguments === other.arguments + end end # SymBeg represents the beginning of a symbol literal. @@ -8473,11 +10515,22 @@ def child_nodes [] end + def copy(value: nil, location: nil) + SymBeg.new( + value: value || self.value, + location: location || self.location + ) + end + alias deconstruct child_nodes def deconstruct_keys(_keys) { value: value, location: location } end + + def ===(other) + other.is_a?(SymBeg) && value === other.value + end end # SymbolContent represents symbol contents and is always the child of a @@ -8503,11 +10556,22 @@ def child_nodes [] end + def copy(value: nil, location: nil) + SymbolContent.new( + value: value || self.value, + location: location || self.location + ) + end + alias deconstruct child_nodes def deconstruct_keys(_keys) { value: value, location: location } end + + def ===(other) + other.is_a?(SymbolContent) && value === other.value + end end # SymbolLiteral represents a symbol in the system with no interpolation @@ -8516,25 +10580,36 @@ def deconstruct_keys(_keys) # :symbol # class SymbolLiteral < Node - # [Backtick | Const | CVar | GVar | Ident | IVar | Kw | Op] the value of the - # symbol + # [Backtick | Const | CVar | GVar | Ident | IVar | Kw | Op | TStringContent] + # the value of the symbol attr_reader :value # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(value:, location:, comments: []) + def initialize(value:, location:) @value = value @location = location - @comments = comments + @comments = [] end def accept(visitor) visitor.visit_symbol_literal(self) end - def child_nodes - [value] + def child_nodes + [value] + end + + def copy(value: nil, location: nil) + node = + SymbolLiteral.new( + value: value || self.value, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node end alias deconstruct child_nodes @@ -8545,8 +10620,13 @@ def deconstruct_keys(_keys) def format(q) q.text(":") + q.text("\\") if value.comments.any? q.format(value) end + + def ===(other) + other.is_a?(SymbolLiteral) && value === other.value + end end # Symbols represents a symbol array literal with interpolation. @@ -8563,11 +10643,11 @@ class Symbols < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(beginning:, elements:, location:, comments: []) + def initialize(beginning:, elements:, location:) @beginning = beginning @elements = elements @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -8578,6 +10658,14 @@ def child_nodes [] end + def copy(beginning: nil, elements: nil, location: nil) + Symbols.new( + beginning: beginning || self.beginning, + elements: elements || self.elements, + location: location || self.location + ) + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -8597,15 +10685,23 @@ def format(q) closing = Quotes.matching(opening[2]) end - q.group(0, opening, closing) do + q.text(opening) + q.group do q.indent do - q.breakable("") - q.seplist(elements, -> { q.breakable }) do |element| - q.format(element) - end + q.breakable_empty + q.seplist( + elements, + ArrayLiteral::BREAKABLE_SPACE_SEPARATOR + ) { |element| q.format(element) } end - q.breakable("") + q.breakable_empty end + q.text(closing) + end + + def ===(other) + other.is_a?(Symbols) && beginning === other.beginning && + ArrayMatch.call(elements, other.elements) end end @@ -8634,11 +10730,22 @@ def child_nodes [] end + def copy(value: nil, location: nil) + SymbolsBeg.new( + value: value || self.value, + location: location || self.location + ) + end + alias deconstruct child_nodes def deconstruct_keys(_keys) { value: value, location: location } end + + def ===(other) + other.is_a?(SymbolsBeg) && value === other.value + end end # TLambda represents the beginning of a lambda literal. @@ -8663,11 +10770,22 @@ def child_nodes [] end + def copy(value: nil, location: nil) + TLambda.new( + value: value || self.value, + location: location || self.location + ) + end + alias deconstruct child_nodes def deconstruct_keys(_keys) { value: value, location: location } end + + def ===(other) + other.is_a?(TLambda) && value === other.value + end end # TLamBeg represents the beginning of the body of a lambda literal using @@ -8693,11 +10811,22 @@ def child_nodes [] end + def copy(value: nil, location: nil) + TLamBeg.new( + value: value || self.value, + location: location || self.location + ) + end + alias deconstruct child_nodes def deconstruct_keys(_keys) { value: value, location: location } end + + def ===(other) + other.is_a?(TLamBeg) && value === other.value + end end # TopConstField is always the child node of some kind of assignment. It @@ -8713,10 +10842,10 @@ class TopConstField < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(constant:, location:, comments: []) + def initialize(constant:, location:) @constant = constant @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -8727,6 +10856,17 @@ def child_nodes [constant] end + def copy(constant: nil, location: nil) + node = + TopConstField.new( + constant: constant || self.constant, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -8737,6 +10877,10 @@ def format(q) q.text("::") q.format(constant) end + + def ===(other) + other.is_a?(TopConstField) && constant === other.constant + end end # TopConstRef is very similar to TopConstField except that it is not involved @@ -8751,10 +10895,10 @@ class TopConstRef < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(constant:, location:, comments: []) + def initialize(constant:, location:) @constant = constant @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -8765,6 +10909,17 @@ def child_nodes [constant] end + def copy(constant: nil, location: nil) + node = + TopConstRef.new( + constant: constant || self.constant, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -8775,6 +10930,10 @@ def format(q) q.text("::") q.format(constant) end + + def ===(other) + other.is_a?(TopConstRef) && constant === other.constant + end end # TStringBeg represents the beginning of a string literal. @@ -8804,11 +10963,22 @@ def child_nodes [] end + def copy(value: nil, location: nil) + TStringBeg.new( + value: value || self.value, + location: location || self.location + ) + end + alias deconstruct child_nodes def deconstruct_keys(_keys) { value: value, location: location } end + + def ===(other) + other.is_a?(TStringBeg) && value === other.value + end end # TStringContent represents plain characters inside of an entity that accepts @@ -8826,10 +10996,10 @@ class TStringContent < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(value:, location:, comments: []) + def initialize(value:, location:) @value = value @location = location - @comments = comments + @comments = [] end def match?(pattern) @@ -8844,6 +11014,17 @@ def child_nodes [] end + def copy(value: nil, location: nil) + node = + TStringContent.new( + value: value || self.value, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -8853,6 +11034,10 @@ def deconstruct_keys(_keys) def format(q) q.text(value) end + + def ===(other) + other.is_a?(TStringContent) && value === other.value + end end # TStringEnd represents the end of a string literal. @@ -8882,11 +11067,22 @@ def child_nodes [] end + def copy(value: nil, location: nil) + TStringEnd.new( + value: value || self.value, + location: location || self.location + ) + end + alias deconstruct child_nodes def deconstruct_keys(_keys) { value: value, location: location } end + + def ===(other) + other.is_a?(TStringEnd) && value === other.value + end end # Not represents the unary +not+ method being called on an expression. @@ -8894,20 +11090,21 @@ def deconstruct_keys(_keys) # not value # class Not < Node - # [nil | untyped] the statement on which to operate + # [nil | Node] the statement on which to operate attr_reader :statement # [boolean] whether or not parentheses were used attr_reader :parentheses + alias parentheses? parentheses # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(statement:, parentheses:, location:, comments: []) + def initialize(statement:, parentheses:, location:) @statement = statement @parentheses = parentheses @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -8918,6 +11115,18 @@ def child_nodes [statement] end + def copy(statement: nil, parentheses: nil, location: nil) + node = + Not.new( + statement: statement || self.statement, + parentheses: parentheses || self.parentheses, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -8930,28 +11139,32 @@ def deconstruct_keys(_keys) end def format(q) - parent = q.parents.take(2)[1] - ternary = - (parent.is_a?(If) || parent.is_a?(Unless)) && - Ternaryable.call(q, parent) - q.text("not") if parentheses q.text("(") - elsif ternary - q.if_break { q.text(" ") }.if_flat { q.text("(") } + q.format(statement) if statement + q.text(")") else - q.text(" ") + grandparent = q.grandparent + ternary = + (grandparent.is_a?(IfNode) || grandparent.is_a?(UnlessNode)) && + Ternaryable.call(q, grandparent) + + if ternary + q.if_break { q.text(" ") }.if_flat { q.text("(") } + q.format(statement) if statement + q.if_flat { q.text(")") } if ternary + else + q.text(" ") + q.format(statement) if statement + end end + end - q.format(statement) if statement - - if parentheses - q.text(")") - elsif ternary - q.if_flat { q.text(")") } - end + def ===(other) + other.is_a?(Not) && statement === other.statement && + parentheses === other.parentheses end end @@ -8964,17 +11177,17 @@ class Unary < Node # [String] the operator being used attr_reader :operator - # [untyped] the statement on which to operate + # [Node] the statement on which to operate attr_reader :statement # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(operator:, statement:, location:, comments: []) + def initialize(operator:, statement:, location:) @operator = operator @statement = statement @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -8985,6 +11198,18 @@ def child_nodes [statement] end + def copy(operator: nil, statement: nil, location: nil) + node = + Unary.new( + operator: operator || self.operator, + statement: statement || self.statement, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -9000,6 +11225,11 @@ def format(q) q.text(operator) q.format(statement) end + + def ===(other) + other.is_a?(Unary) && operator === other.operator && + statement === other.statement + end end # Undef represents the use of the +undef+ keyword. @@ -9037,10 +11267,10 @@ def format(q) # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(symbols:, location:, comments: []) + def initialize(symbols:, location:) @symbols = symbols @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -9051,6 +11281,17 @@ def child_nodes symbols end + def copy(symbols: nil, location: nil) + node = + Undef.new( + symbols: symbols || self.symbols, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -9068,6 +11309,10 @@ def format(q) end end end + + def ===(other) + other.is_a?(Undef) && ArrayMatch.call(symbols, other.symbols) + end end # Unless represents the first clause in an +unless+ chain. @@ -9075,31 +11320,25 @@ def format(q) # unless predicate # end # - class Unless < Node - # [untyped] the expression to be checked + class UnlessNode < Node + # [Node] the expression to be checked attr_reader :predicate # [Statements] the expressions to be executed attr_reader :statements - # [nil, Elsif, Else] the next clause in the chain + # [nil | Elsif | Else] the next clause in the chain attr_reader :consequent # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize( - predicate:, - statements:, - consequent:, - location:, - comments: [] - ) + def initialize(predicate:, statements:, consequent:, location:) @predicate = predicate @statements = statements @consequent = consequent @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -9110,6 +11349,19 @@ def child_nodes [predicate, statements, consequent] end + def copy(predicate: nil, statements: nil, consequent: nil, location: nil) + node = + UnlessNode.new( + predicate: predicate || self.predicate, + statements: statements || self.statements, + consequent: consequent || self.consequent, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -9125,269 +11377,154 @@ def deconstruct_keys(_keys) def format(q) ConditionalFormatter.new("unless", self).format(q) end - end - - # UnlessMod represents the modifier form of an +unless+ statement. - # - # expression unless predicate - # - class UnlessMod < Node - # [untyped] the expression to be executed - attr_reader :statement - - # [untyped] the expression to be checked - attr_reader :predicate - - # [Array[ Comment | EmbDoc ]] the comments attached to this node - attr_reader :comments - - def initialize(statement:, predicate:, location:, comments: []) - @statement = statement - @predicate = predicate - @location = location - @comments = comments - end - - def accept(visitor) - visitor.visit_unless_mod(self) - end - - def child_nodes - [statement, predicate] - end - - alias deconstruct child_nodes - def deconstruct_keys(_keys) - { - statement: statement, - predicate: predicate, - location: location, - comments: comments - } + def ===(other) + other.is_a?(UnlessNode) && predicate === other.predicate && + statements === other.statements && consequent === other.consequent end - def format(q) - ConditionalModFormatter.new("unless", self).format(q) + # Checks if the node was originally found in the modifier form. + def modifier? + predicate.location.start_char > statements.location.start_char end end - # Formats an Until, UntilMod, While, or WhileMod node. + # Formats an Until or While node. class LoopFormatter # [String] the name of the keyword used for this loop attr_reader :keyword - # [Until | UntilMod | While | WhileMod] the node that is being formatted + # [Until | While] the node that is being formatted attr_reader :node - # [untyped] the statements associated with the node - attr_reader :statements - - def initialize(keyword, node, statements) + def initialize(keyword, node) @keyword = keyword @node = node - @statements = statements end def format(q) - if ContainsAssignment.call(node.predicate) - format_break(q) - q.break_parent - return - end - - q.group do - q - .if_break { format_break(q) } - .if_flat do - Parentheses.flat(q) do - q.format(statements) - q.text(" #{keyword} ") - q.format(node.predicate) - end - end - end - end - - private - - def format_break(q) - q.text("#{keyword} ") - q.nest(keyword.length + 1) { q.format(node.predicate) } - q.indent do - q.breakable("") - q.format(statements) - end - q.breakable("") - q.text("end") - end - end - - # Until represents an +until+ loop. - # - # until predicate - # end - # - class Until < Node - # [untyped] the expression to be checked - attr_reader :predicate - - # [Statements] the expressions to be executed - attr_reader :statements - - # [Array[ Comment | EmbDoc ]] the comments attached to this node - attr_reader :comments - - def initialize(predicate:, statements:, location:, comments: []) - @predicate = predicate - @statements = statements - @location = location - @comments = comments - end - - def accept(visitor) - visitor.visit_until(self) - end - - def child_nodes - [predicate, statements] - end - - alias deconstruct child_nodes - - def deconstruct_keys(_keys) - { - predicate: predicate, - statements: statements, - location: location, - comments: comments - } - end - - def format(q) - if statements.empty? - keyword = "until " - + # If we're in the modifier form and we're modifying a `begin`, then this + # is a special case where we need to explicitly use the modifier form + # because otherwise the semantic meaning changes. This looks like: + # + # begin + # foo + # end while bar + # + # Also, if the statement of the modifier includes an assignment, then we + # can't know for certain that it won't impact the predicate, so we need to + # force it to stay as it is. This looks like: + # + # foo = bar while foo + # + if node.modifier? && (statement = node.statements.body.first) && + (statement.is_a?(Begin) || ContainsAssignment.call(statement)) + q.format(statement) + q.text(" #{keyword} ") + q.format(node.predicate) + elsif node.statements.empty? q.group do - q.text(keyword) - q.nest(keyword.length) { q.format(predicate) } - q.breakable(force: true) + q.text("#{keyword} ") + q.nest(keyword.length + 1) { q.format(node.predicate) } + q.breakable_force q.text("end") end + elsif ContainsAssignment.call(node.predicate) + format_break(q) + q.break_parent else - LoopFormatter.new("until", self, statements).format(q) + q.group do + q + .if_break { format_break(q) } + .if_flat do + Parentheses.flat(q) do + q.format(node.statements) + q.text(" #{keyword} ") + q.format(node.predicate) + end + end + end end end - end - - # UntilMod represents the modifier form of a +until+ loop. - # - # expression until predicate - # - class UntilMod < Node - # [untyped] the expression to be executed - attr_reader :statement - - # [untyped] the expression to be checked - attr_reader :predicate - - # [Array[ Comment | EmbDoc ]] the comments attached to this node - attr_reader :comments - - def initialize(statement:, predicate:, location:, comments: []) - @statement = statement - @predicate = predicate - @location = location - @comments = comments - end - - def accept(visitor) - visitor.visit_until_mod(self) - end - - def child_nodes - [statement, predicate] - end - - alias deconstruct child_nodes - - def deconstruct_keys(_keys) - { - statement: statement, - predicate: predicate, - location: location, - comments: comments - } - end - def format(q) - # If we're in the modifier form and we're modifying a `begin`, then this - # is a special case where we need to explicitly use the modifier form - # because otherwise the semantic meaning changes. This looks like: - # - # begin - # foo - # end until bar - # - # Also, if the statement of the modifier includes an assignment, then we - # can't know for certain that it won't impact the predicate, so we need to - # force it to stay as it is. This looks like: - # - # foo = bar until foo - # - if statement.is_a?(Begin) || ContainsAssignment.call(statement) - q.format(statement) - q.text(" until ") - q.format(predicate) - else - LoopFormatter.new("until", self, statement).format(q) + private + + def format_break(q) + q.text("#{keyword} ") + q.nest(keyword.length + 1) { q.format(node.predicate) } + q.indent do + q.breakable_empty + q.format(node.statements) end + q.breakable_empty + q.text("end") end end - # VarAlias represents when you're using the +alias+ keyword with global - # variable arguments. + # Until represents an +until+ loop. # - # alias $new $old + # until predicate + # end # - class VarAlias < Node - # [GVar] the new alias of the variable - attr_reader :left + class UntilNode < Node + # [Node] the expression to be checked + attr_reader :predicate - # [Backref | GVar] the current name of the variable to be aliased - attr_reader :right + # [Statements] the expressions to be executed + attr_reader :statements # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(left:, right:, location:, comments: []) - @left = left - @right = right + def initialize(predicate:, statements:, location:) + @predicate = predicate + @statements = statements @location = location - @comments = comments + @comments = [] end def accept(visitor) - visitor.visit_var_alias(self) + visitor.visit_until(self) end def child_nodes - [left, right] + [predicate, statements] + end + + def copy(predicate: nil, statements: nil, location: nil) + node = + UntilNode.new( + predicate: predicate || self.predicate, + statements: statements || self.statements, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node end alias deconstruct child_nodes def deconstruct_keys(_keys) - { left: left, right: right, location: location, comments: comments } + { + predicate: predicate, + statements: statements, + location: location, + comments: comments + } end def format(q) - keyword = "alias " + LoopFormatter.new("until", self).format(q) + end - q.text(keyword) - q.format(left) - q.text(" ") - q.format(right) + def ===(other) + other.is_a?(UntilNode) && predicate === other.predicate && + statements === other.statements + end + + def modifier? + predicate.location.start_char > statements.location.start_char end end @@ -9398,16 +11535,16 @@ def format(q) # # In the example above, the VarField node represents the +variable+ token. class VarField < Node - # [nil | Const | CVar | GVar | Ident | IVar] the target of this node + # [nil | :nil | Const | CVar | GVar | Ident | IVar] the target of this node attr_reader :value # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(value:, location:, comments: []) + def initialize(value:, location:) @value = value @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -9415,7 +11552,18 @@ def accept(visitor) end def child_nodes - [value] + value == :nil ? [] : [value] + end + + def copy(value: nil, location: nil) + node = + VarField.new( + value: value || self.value, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node end alias deconstruct child_nodes @@ -9431,6 +11579,10 @@ def format(q) q.format(value) end end + + def ===(other) + other.is_a?(VarField) && value === other.value + end end # VarRef represents a variable reference. @@ -9448,10 +11600,10 @@ class VarRef < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(value:, location:, comments: []) + def initialize(value:, location:) @value = value @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -9462,6 +11614,17 @@ def child_nodes [value] end + def copy(value: nil, location: nil) + node = + VarRef.new( + value: value || self.value, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -9471,6 +11634,38 @@ def deconstruct_keys(_keys) def format(q) q.format(value) end + + def ===(other) + other.is_a?(VarRef) && value === other.value + end + + # Oh man I hate this so much. Basically, ripper doesn't provide enough + # functionality to actually know where pins are within an expression. So we + # have to walk the tree ourselves and insert more information. In doing so, + # we have to replace this node by a pinned node when necessary. + # + # To be clear, this method should just not exist. It's not good. It's a + # place of shame. But it's necessary for now, so I'm keeping it. + def pin(parent, pin) + replace = + PinnedVarRef.new(value: value, location: pin.location.to(location)) + + parent + .deconstruct_keys([]) + .each do |key, value| + if value == self + parent.instance_variable_set(:"@#{key}", replace) + break + elsif value.is_a?(Array) && (index = value.index(self)) + parent.public_send(key)[index] = replace + break + elsif value.is_a?(Array) && + (index = value.index { |(_k, v)| v == self }) + parent.public_send(key)[index][1] = replace + break + end + end + end end # PinnedVarRef represents a pinned variable reference within a pattern @@ -9483,16 +11678,16 @@ def format(q) # This can be a plain local variable like the example above. It can also be a # a class variable, a global variable, or an instance variable. class PinnedVarRef < Node - # [VarRef] the value of this node + # [Const | CVar | GVar | Ident | IVar] the value of this node attr_reader :value # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(value:, location:, comments: []) + def initialize(value:, location:) @value = value @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -9503,6 +11698,17 @@ def child_nodes [value] end + def copy(value: nil, location: nil) + node = + PinnedVarRef.new( + value: value || self.value, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -9515,6 +11721,10 @@ def format(q) q.format(value) end end + + def ===(other) + other.is_a?(PinnedVarRef) && value === other.value + end end # VCall represent any plain named object with Ruby that could be either a @@ -9529,10 +11739,10 @@ class VCall < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(value:, location:, comments: []) + def initialize(value:, location:) @value = value @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -9543,6 +11753,17 @@ def child_nodes [value] end + def copy(value: nil, location: nil) + node = + VCall.new( + value: value || self.value, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -9552,6 +11773,18 @@ def deconstruct_keys(_keys) def format(q) q.format(value) end + + def ===(other) + other.is_a?(VCall) && value === other.value + end + + def access_control? + @access_control ||= %w[private protected public].include?(value.value) + end + + def arity + 0 + end end # VoidStmt represents an empty lexical block of code. @@ -9559,15 +11792,12 @@ def format(q) # ;; # class VoidStmt < Node - # [Location] the location of this node - attr_reader :location - # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(location:, comments: []) + def initialize(location:) @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -9578,6 +11808,13 @@ def child_nodes [] end + def copy(location: nil) + node = VoidStmt.new(location: location || self.location) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -9586,6 +11823,10 @@ def deconstruct_keys(_keys) def format(q) end + + def ===(other) + other.is_a?(VoidStmt) + end end # When represents a +when+ clause in a +case+ chain. @@ -9607,18 +11848,12 @@ class When < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize( - arguments:, - statements:, - consequent:, - location:, - comments: [] - ) + def initialize(arguments:, statements:, consequent:, location:) @arguments = arguments @statements = statements @consequent = consequent @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -9629,6 +11864,19 @@ def child_nodes [arguments, statements, consequent] end + def copy(arguments: nil, statements: nil, consequent: nil, location: nil) + node = + When.new( + arguments: arguments || self.arguments, + statements: statements || self.statements, + consequent: consequent || self.consequent, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -9641,6 +11889,22 @@ def deconstruct_keys(_keys) } end + # We have a special separator here for when clauses which causes them to + # fill as much of the line as possible as opposed to everything breaking + # into its own line as soon as you hit the print limit. + class Separator + def call(q) + q.group do + q.text(",") + q.breakable_space + end + end + end + + # We're going to keep a single instance of this separator around so we don't + # have to allocate a new one every time we format a when clause. + SEPARATOR = Separator.new.freeze + def format(q) keyword = "when " @@ -9651,33 +11915,35 @@ def format(q) if arguments.comments.any? q.format(arguments) else - separator = -> { q.group { q.comma_breakable } } - q.seplist(arguments.parts, separator) { |part| q.format(part) } + q.seplist(arguments.parts, SEPARATOR) { |part| q.format(part) } end # Very special case here. If you're inside of a when clause and the # last argument to the predicate is and endless range, then you are # forced to use the "then" keyword to make it parse properly. last = arguments.parts.last - if (last.is_a?(Dot2) || last.is_a?(Dot3)) && !last.right - q.text(" then") - end + q.text(" then") if last.is_a?(RangeNode) && !last.right end end unless statements.empty? q.indent do - q.breakable(force: true) + q.breakable_force q.format(statements) end end if consequent - q.breakable(force: true) + q.breakable_force q.format(consequent) end end end + + def ===(other) + other.is_a?(When) && arguments === other.arguments && + statements === other.statements && consequent === other.consequent + end end # While represents a +while+ loop. @@ -9685,8 +11951,8 @@ def format(q) # while predicate # end # - class While < Node - # [untyped] the expression to be checked + class WhileNode < Node + # [Node] the expression to be checked attr_reader :predicate # [Statements] the expressions to be executed @@ -9695,11 +11961,11 @@ class While < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(predicate:, statements:, location:, comments: []) + def initialize(predicate:, statements:, location:) @predicate = predicate @statements = statements @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -9710,6 +11976,18 @@ def child_nodes [predicate, statements] end + def copy(predicate: nil, statements: nil, location: nil) + node = + WhileNode.new( + predicate: predicate || self.predicate, + statements: statements || self.statements, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -9722,83 +12000,16 @@ def deconstruct_keys(_keys) end def format(q) - if statements.empty? - keyword = "while " - - q.group do - q.text(keyword) - q.nest(keyword.length) { q.format(predicate) } - q.breakable(force: true) - q.text("end") - end - else - LoopFormatter.new("while", self, statements).format(q) - end + LoopFormatter.new("while", self).format(q) end - end - - # WhileMod represents the modifier form of a +while+ loop. - # - # expression while predicate - # - class WhileMod < Node - # [untyped] the expression to be executed - attr_reader :statement - - # [untyped] the expression to be checked - attr_reader :predicate - - # [Array[ Comment | EmbDoc ]] the comments attached to this node - attr_reader :comments - - def initialize(statement:, predicate:, location:, comments: []) - @statement = statement - @predicate = predicate - @location = location - @comments = comments - end - - def accept(visitor) - visitor.visit_while_mod(self) - end - - def child_nodes - [statement, predicate] - end - - alias deconstruct child_nodes - def deconstruct_keys(_keys) - { - statement: statement, - predicate: predicate, - location: location, - comments: comments - } + def ===(other) + other.is_a?(WhileNode) && predicate === other.predicate && + statements === other.statements end - def format(q) - # If we're in the modifier form and we're modifying a `begin`, then this - # is a special case where we need to explicitly use the modifier form - # because otherwise the semantic meaning changes. This looks like: - # - # begin - # foo - # end while bar - # - # Also, if the statement of the modifier includes an assignment, then we - # can't know for certain that it won't impact the predicate, so we need to - # force it to stay as it is. This looks like: - # - # foo = bar while foo - # - if statement.is_a?(Begin) || ContainsAssignment.call(statement) - q.format(statement) - q.text(" while ") - q.format(predicate) - else - LoopFormatter.new("while", self, statement).format(q) - end + def modifier? + predicate.location.start_char > statements.location.start_char end end @@ -9817,10 +12028,10 @@ class Word < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(parts:, location:, comments: []) + def initialize(parts:, location:) @parts = parts @location = location - @comments = comments + @comments = [] end def match?(pattern) @@ -9835,6 +12046,17 @@ def child_nodes parts end + def copy(parts: nil, location: nil) + node = + Word.new( + parts: parts || self.parts, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -9844,6 +12066,10 @@ def deconstruct_keys(_keys) def format(q) q.format_each(parts) end + + def ===(other) + other.is_a?(Word) && ArrayMatch.call(parts, other.parts) + end end # Words represents a string literal array with interpolation. @@ -9860,11 +12086,11 @@ class Words < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(beginning:, elements:, location:, comments: []) + def initialize(beginning:, elements:, location:) @beginning = beginning @elements = elements @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -9875,6 +12101,14 @@ def child_nodes [] end + def copy(beginning: nil, elements: nil, location: nil) + Words.new( + beginning: beginning || self.beginning, + elements: elements || self.elements, + location: location || self.location + ) + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -9894,15 +12128,23 @@ def format(q) closing = Quotes.matching(opening[2]) end - q.group(0, opening, closing) do + q.text(opening) + q.group do q.indent do - q.breakable("") - q.seplist(elements, -> { q.breakable }) do |element| - q.format(element) - end + q.breakable_empty + q.seplist( + elements, + ArrayLiteral::BREAKABLE_SPACE_SEPARATOR + ) { |element| q.format(element) } end - q.breakable("") + q.breakable_empty end + q.text(closing) + end + + def ===(other) + other.is_a?(Words) && beginning === other.beginning && + ArrayMatch.call(elements, other.elements) end end @@ -9931,11 +12173,22 @@ def child_nodes [] end + def copy(value: nil, location: nil) + WordsBeg.new( + value: value || self.value, + location: location || self.location + ) + end + alias deconstruct child_nodes def deconstruct_keys(_keys) { value: value, location: location } end + + def ===(other) + other.is_a?(WordsBeg) && value === other.value + end end # XString represents the contents of an XStringLiteral. @@ -9960,11 +12213,22 @@ def child_nodes parts end + def copy(parts: nil, location: nil) + XString.new( + parts: parts || self.parts, + location: location || self.location + ) + end + alias deconstruct child_nodes def deconstruct_keys(_keys) { parts: parts, location: location } end + + def ===(other) + other.is_a?(XString) && ArrayMatch.call(parts, other.parts) + end end # XStringLiteral represents a string that gets executed. @@ -9979,10 +12243,10 @@ class XStringLiteral < Node # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(parts:, location:, comments: []) + def initialize(parts:, location:) @parts = parts @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -9993,6 +12257,17 @@ def child_nodes parts end + def copy(parts: nil, location: nil) + node = + XStringLiteral.new( + parts: parts || self.parts, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -10004,23 +12279,27 @@ def format(q) q.format_each(parts) q.text("`") end + + def ===(other) + other.is_a?(XStringLiteral) && ArrayMatch.call(parts, other.parts) + end end # Yield represents using the +yield+ keyword with arguments. # # yield value # - class Yield < Node - # [Args | Paren] the arguments passed to the yield + class YieldNode < Node + # [nil | Args | Paren] the arguments passed to the yield attr_reader :arguments # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(arguments:, location:, comments: []) + def initialize(arguments:, location:) @arguments = arguments @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -10031,6 +12310,17 @@ def child_nodes [arguments] end + def copy(arguments: nil, location: nil) + node = + YieldNode.new( + arguments: arguments || self.arguments, + location: location || self.location + ) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) @@ -10038,6 +12328,11 @@ def deconstruct_keys(_keys) end def format(q) + if arguments.nil? + q.text("yield") + return + end + q.group do q.text("yield") @@ -10046,49 +12341,17 @@ def format(q) else q.if_break { q.text("(") }.if_flat { q.text(" ") } q.indent do - q.breakable("") + q.breakable_empty q.format(arguments) end - q.breakable("") + q.breakable_empty q.if_break { q.text(")") } end end end - end - - # Yield0 represents the bare +yield+ keyword with no arguments. - # - # yield - # - class Yield0 < Node - # [String] the value of the keyword - attr_reader :value - - # [Array[ Comment | EmbDoc ]] the comments attached to this node - attr_reader :comments - - def initialize(value:, location:, comments: []) - @value = value - @location = location - @comments = comments - end - - def accept(visitor) - visitor.visit_yield0(self) - end - - def child_nodes - [] - end - - alias deconstruct child_nodes - - def deconstruct_keys(_keys) - { value: value, location: location, comments: comments } - end - def format(q) - q.text(value) + def ===(other) + other.is_a?(YieldNode) && arguments === other.arguments end end @@ -10097,16 +12360,12 @@ def format(q) # super # class ZSuper < Node - # [String] the value of the keyword - attr_reader :value - # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments - def initialize(value:, location:, comments: []) - @value = value + def initialize(location:) @location = location - @comments = comments + @comments = [] end def accept(visitor) @@ -10117,14 +12376,25 @@ def child_nodes [] end + def copy(location: nil) + node = ZSuper.new(location: location || self.location) + + node.comments.concat(comments.map(&:copy)) + node + end + alias deconstruct child_nodes def deconstruct_keys(_keys) - { value: value, location: location, comments: comments } + { location: location, comments: comments } end def format(q) - q.text(value) + q.text("super") + end + + def ===(other) + other.is_a?(ZSuper) end end end diff --git a/lib/syntax_tree/parser.rb b/lib/syntax_tree/parser.rb index 75d3c322..ace077ee 100644 --- a/lib/syntax_tree/parser.rb +++ b/lib/syntax_tree/parser.rb @@ -53,16 +53,53 @@ def initialize(start, line) # there's a BOM at the beginning of the file, which is the reason we need # to compare it to 0 here. def [](byteindex) - indices[byteindex < 0 ? 0 : byteindex] + indices[[byteindex, 0].max] + end + end + + # This represents all of the tokens coming back from the lexer. It is + # replacing a simple array because it keeps track of the last deleted token + # from the list for better error messages. + class TokenList + attr_reader :tokens, :last_deleted + + def initialize + @tokens = [] + @last_deleted = nil + end + + def <<(token) + tokens << token + end + + def [](index) + tokens[index] + end + + def any?(&block) + tokens.any?(&block) + end + + def reverse_each(&block) + tokens.reverse_each(&block) + end + + def rindex(&block) + tokens.rindex(&block) + end + + def delete(value) + @last_deleted = tokens.delete(value) || @last_deleted + end + + def delete_at(index) + @last_deleted = tokens.delete_at(index) end end # [String] the source being parsed attr_reader :source - # [Array[ String ]] the list of lines in the source - attr_reader :lines - # [Array[ SingleByteString | MultiByteString ]] the list of objects that # represent the start of each line in character offsets attr_reader :line_counts @@ -85,12 +122,6 @@ def initialize(source, *) # example. @source = source - # Similarly, we keep the lines of the source string around to be able to - # check if certain lines contain certain characters. For example, we'll - # use this to generate the content that goes after the __END__ keyword. - # Or we'll use this to check if a comment has other content on its line. - @lines = source.split(/\r?\n/) - # This is the full set of comments that have been found by the parser. # It's a running list. At the end of every block of statements, they will # go in and attempt to grab any comments that are on their own line and @@ -124,7 +155,7 @@ def initialize(source, *) # Most of the time, when a parser event consumes one of these events, it # will be deleted from the list. So ideally, this list stays pretty short # over the course of parsing a source string. - @tokens = [] + @tokens = TokenList.new # Here we're going to build up a list of SingleByteString or # MultiByteString objects. They're each going to represent a string in the @@ -133,7 +164,7 @@ def initialize(source, *) @line_counts = [] last_index = 0 - @source.lines.each do |line| + @source.each_line do |line| @line_counts << if line.size == line.bytesize SingleByteString.new(last_index) else @@ -174,6 +205,33 @@ def current_column line[column].to_i - line.start end + # Returns the current location that is being looked at for the parser for + # the purpose of locating the error. + def find_token_error(location) + if location + # If we explicitly passed a location into this find_token_error method, + # that means that's the source of the error, so we'll use that + # information for our error object. + lineno = location.start_line + [lineno, location.start_char - line_counts[lineno - 1].start] + elsif lineno && column + # If there is a line number associated with the current ripper state, + # then we'll use that information to generate the error. + [lineno, column] + elsif (location = tokens.last_deleted&.location) + # If we've already deleted a token from the list of tokens that we are + # consuming, then we'll fall back to that token's location. + lineno = location.start_line + [lineno, location.start_char - line_counts[lineno - 1].start] + else + # Finally, it's possible that when we hit this error the parsing thread + # for ripper has died. In that case, lineno and column both return nil. + # So we're just going to set it to line 1, column 0 in the hopes that + # that makes any sense. + [1, 0] + end + end + # As we build up a list of tokens, we'll periodically need to go backwards # and find the ones that we've already hit in order to determine the # location information for nodes that use them. For example, if you have a @@ -186,35 +244,81 @@ def current_column # "module" (which would happen to be the innermost keyword). Then the outer # one would only be able to grab the first one. In this way all of the # tokens act as their own stack. - def find_token(type, value = :any, consume: true, location: nil) + # + # If we're expecting to be able to find a token and consume it, but can't + # actually find it, then we need to raise an error. This is _usually_ caused + # by a syntax error in the source that we're printing. It could also be + # caused by accidentally attempting to consume a token twice by two + # different parser event handlers. + + def find_token(type) + index = tokens.rindex { |token| token.is_a?(type) } + tokens[index] if index + end + + def find_token_between(type, left, right) + bounds = left.location.end_char...right.location.start_char index = tokens.rindex do |token| - token.is_a?(type) && (value == :any || (token.value == value)) + char = token.location.start_char + break if char < bounds.begin + + token.is_a?(type) && bounds.cover?(char) end - if consume - # If we're expecting to be able to find a token and consume it, but - # can't actually find it, then we need to raise an error. This is - # _usually_ caused by a syntax error in the source that we're printing. - # It could also be caused by accidentally attempting to consume a token - # twice by two different parser event handlers. - unless index - token = value == :any ? type.name.split("::", 2).last : value - message = "Cannot find expected #{token}" - - if location - lineno = location.start_line - column = location.start_char - line_counts[lineno - 1].start - raise ParseError.new(message, lineno, column) - else - raise ParseError.new(message, lineno, column) - end + tokens[index] if index + end + + def find_keyword(name) + index = tokens.rindex { |token| token.is_a?(Kw) && (token.name == name) } + tokens[index] if index + end + + def find_keyword_between(name, left, right) + bounds = left.end_char...right.start_char + index = + tokens.rindex do |token| + char = token.location.start_char + break if char < bounds.begin + + token.is_a?(Kw) && (token.name == name) && bounds.cover?(char) end - tokens.delete_at(index) - elsif index - tokens[index] - end + tokens[index] if index + end + + def find_operator(name) + index = tokens.rindex { |token| token.is_a?(Op) && (token.name == name) } + tokens[index] if index + end + + def consume_error(name, location) + message = "Cannot find expected #{name}" + raise ParseError.new(message, *find_token_error(location)) + end + + def consume_token(type) + index = tokens.rindex { |token| token.is_a?(type) } + consume_error(type.name.split("::", 2).last, nil) unless index + tokens.delete_at(index) + end + + def consume_tstring_end(location) + index = tokens.rindex { |token| token.is_a?(TStringEnd) } + consume_error("string ending", location) unless index + tokens.delete_at(index) + end + + def consume_keyword(name) + index = tokens.rindex { |token| token.is_a?(Kw) && (token.name == name) } + consume_error(name, nil) unless index + tokens.delete_at(index) + end + + def consume_operator(name) + index = tokens.rindex { |token| token.is_a?(Op) && (token.name == name) } + consume_error(name, nil) unless index + tokens.delete_at(index) end # A helper function to find a :: operator. We do special handling instead of @@ -243,13 +347,18 @@ def find_colon2_before(const) # By finding the next non-space character, we can make sure that the bounds # of the statement list are correct. def find_next_statement_start(position) - remaining = source[position..] - - if remaining.sub(/\A +/, "")[0] == "#" - return position + remaining.index("\n") + maximum = source.length + + position.upto(maximum) do |pound_index| + case source[pound_index] + when "#" + return source.index("\n", pound_index + 1) || maximum + when " " + # continue + else + return position + end end - - position end # -------------------------------------------------------------------------- @@ -260,18 +369,19 @@ def find_next_statement_start(position) # :call-seq: # on_BEGIN: (Statements statements) -> BEGINBlock def on_BEGIN(statements) - lbrace = find_token(LBrace) - rbrace = find_token(RBrace) - start_char = find_next_statement_start(lbrace.location.end_char) + lbrace = consume_token(LBrace) + rbrace = consume_token(RBrace) + start_char = find_next_statement_start(lbrace.location.end_char) statements.bind( + self, start_char, start_char - line_counts[lbrace.location.start_line - 1].start, rbrace.location.start_char, rbrace.location.start_column ) - keyword = find_token(Kw, "BEGIN") + keyword = consume_keyword(:BEGIN) BEGINBlock.new( lbrace: lbrace, @@ -298,18 +408,19 @@ def on_CHAR(value) # :call-seq: # on_END: (Statements statements) -> ENDBlock def on_END(statements) - lbrace = find_token(LBrace) - rbrace = find_token(RBrace) - start_char = find_next_statement_start(lbrace.location.end_char) + lbrace = consume_token(LBrace) + rbrace = consume_token(RBrace) + start_char = find_next_statement_start(lbrace.location.end_char) statements.bind( + self, start_char, start_char - line_counts[lbrace.location.start_line - 1].start, rbrace.location.start_char, rbrace.location.start_column ) - keyword = find_token(Kw, "END") + keyword = consume_keyword(:END) ENDBlock.new( lbrace: lbrace, @@ -338,11 +449,11 @@ def on___end__(value) # on_alias: ( # (DynaSymbol | SymbolLiteral) left, # (DynaSymbol | SymbolLiteral) right - # ) -> Alias + # ) -> AliasNode def on_alias(left, right) - keyword = find_token(Kw, "alias") + keyword = consume_keyword(:alias) - Alias.new( + AliasNode.new( left: left, right: right, location: keyword.location.to(right.location) @@ -352,8 +463,8 @@ def on_alias(left, right) # :call-seq: # on_aref: (untyped collection, (nil | Args) index) -> ARef def on_aref(collection, index) - find_token(LBracket) - rbracket = find_token(RBracket) + consume_token(LBracket) + rbracket = consume_token(RBracket) ARef.new( collection: collection, @@ -368,8 +479,8 @@ def on_aref(collection, index) # (nil | Args) index # ) -> ARefField def on_aref_field(collection, index) - find_token(LBracket) - rbracket = find_token(RBracket) + consume_token(LBracket) + rbracket = consume_token(RBracket) ARefField.new( collection: collection, @@ -387,8 +498,8 @@ def on_aref_field(collection, index) # (nil | Args | ArgsForward) arguments # ) -> ArgParen def on_arg_paren(arguments) - lparen = find_token(LParen) - rparen = find_token(RParen) + lparen = consume_token(LParen) + rparen = consume_token(RParen) # If the arguments exceed the ending of the parentheses, then we know we # have a heredoc in the arguments, and we need to use the bounds of the @@ -430,23 +541,26 @@ def on_args_add(arguments, argument) # (false | untyped) block # ) -> Args def on_args_add_block(arguments, block) + end_char = arguments.parts.any? && arguments.location.end_char + # First, see if there is an & operator that could potentially be # associated with the block part of this args_add_block. If there is not, # then just return the arguments. - operator = find_token(Op, "&", consume: false) - return arguments unless operator - - # If there are any arguments and the operator we found from the list is - # not after them, then we're going to return the arguments as-is because - # we're looking at an & that occurs before the arguments are done. - if arguments.parts.any? && - operator.location.start_char < arguments.location.end_char - return arguments - end + index = + tokens.rindex do |token| + # If there are any arguments and the operator we found from the list + # is not after them, then we're going to return the arguments as-is + # because we're looking at an & that occurs before the arguments are + # done. + return arguments if end_char && token.location.start_char < end_char + token.is_a?(Op) && (token.name == :&) + end + + return arguments unless index # Now we know we have an & operator, so we're going to delete it from the # list of tokens to make sure it doesn't get confused with anything else. - tokens.delete(operator) + operator = tokens.delete_at(index) # Construct the location that represents the block argument. location = operator.location @@ -465,7 +579,7 @@ def on_args_add_block(arguments, block) # :call-seq: # on_args_add_star: (Args arguments, untyped star) -> Args def on_args_add_star(arguments, argument) - beginning = find_token(Op, "*") + beginning = consume_operator(:*) ending = argument || beginning location = @@ -487,9 +601,9 @@ def on_args_add_star(arguments, argument) # :call-seq: # on_args_forward: () -> ArgsForward def on_args_forward - op = find_token(Op, "...") + op = consume_operator(:"...") - ArgsForward.new(value: op.value, location: op.location) + ArgsForward.new(location: op.location) end # :call-seq: @@ -507,8 +621,8 @@ def on_args_new # ArrayLiteral | QSymbols | QWords | Symbols | Words def on_array(contents) if !contents || contents.is_a?(Args) - lbracket = find_token(LBracket) - rbracket = find_token(RBracket) + lbracket = consume_token(LBracket) + rbracket = consume_token(RBracket) ArrayLiteral.new( lbracket: lbracket, @@ -516,8 +630,7 @@ def on_array(contents) location: lbracket.location.to(rbracket.location) ) else - tstring_end = - find_token(TStringEnd, location: contents.beginning.location) + tstring_end = consume_tstring_end(contents.beginning.location) contents.class.new( beginning: contents.beginning, @@ -527,6 +640,61 @@ def on_array(contents) end end + # Ugh... I really do not like this class. Basically, ripper doesn't provide + # enough information about where pins are located in the tree. It only gives + # events for ^ ops and var_ref nodes. You have to piece it together + # yourself. + # + # Note that there are edge cases here that we straight up do not address, + # because I honestly think it's going to be faster to write a new parser + # than to address them. For example, this will not work properly: + # + # foo in ^((bar = 0; bar; baz)) + # + # If someone actually does something like that, we'll have to find another + # way to make this work. + class PinVisitor < Visitor + attr_reader :pins, :stack + + def initialize(pins) + @pins = pins + @stack = [] + end + + def visit(node) + return if pins.empty? + stack << node + super + stack.pop + end + + visit_methods do + def visit_var_ref(node) + if node.start_char > pins.first.start_char + node.pin(stack[-2], pins.shift) + else + super + end + end + end + + def self.visit(node, tokens) + start_char = node.start_char + allocated = [] + + tokens.reverse_each do |token| + char = token.location.start_char + break if char <= start_char + + if token.is_a?(Op) && token.value == "^" + allocated.unshift(tokens.delete(token)) + end + end + + new(allocated).visit(node) if allocated.any? + end + end + # :call-seq: # on_aryptn: ( # (nil | VarRef) constant, @@ -535,25 +703,22 @@ def on_array(contents) # (nil | Array[untyped]) posts # ) -> AryPtn def on_aryptn(constant, requireds, rest, posts) - parts = [constant, *requireds, rest, *posts].compact + lbracket = find_token(LBracket) + lbracket ||= find_token(LParen) if constant - # If there aren't any parts (no constant, no positional arguments), then - # we're matching an empty array. In this case, we're going to look for the - # left and right brackets explicitly. Otherwise, we'll just use the bounds - # of the various parts. - location = - if parts.empty? - find_token(LBracket).location.to(find_token(RBracket).location) - else - parts[0].location.to(parts[-1].location) - end + rbracket = find_token(RBracket) + rbracket ||= find_token(RParen) if constant - # If there's the optional then keyword, then we'll delete that and use it - # as the end bounds of the location. - if (token = find_token(Kw, "then", consume: false)) - tokens.delete(token) - location = location.to(token.location) - end + parts = [constant, lbracket, *requireds, rest, *posts, rbracket].compact + + # The location is going to be determined by the first part to the last + # part. This includes potential brackets. + location = parts[0].location.to(parts[-1].location) + + # Now that we have the location calculated, we can remove the brackets + # from the list of tokens. + tokens.delete(lbracket) if lbracket + tokens.delete(rbracket) if rbracket # If there is a plain *, then we're going to fix up the location of it # here because it currently doesn't have anything to use for its precise @@ -561,12 +726,13 @@ def on_aryptn(constant, requireds, rest, posts) if rest.is_a?(VarField) && rest.value.nil? tokens.rindex do |rtoken| case rtoken - in Op[value: "*"] - rest = VarField.new(value: nil, location: rtoken.location) - break - in Comma + when Comma break - else + when Op + if rtoken.value == "*" + rest = VarField.new(value: nil, location: rtoken.location) + break + end end end end @@ -611,11 +777,11 @@ def on_assoc_new(key, value) # :call-seq: # on_assoc_splat: (untyped value) -> AssocSplat def on_assoc_splat(value) - operator = find_token(Op, "**") + operator = consume_operator(:**) AssocSplat.new( value: value, - location: operator.location.to(value.location) + location: operator.location.to((value || operator).location) ) end @@ -671,34 +837,34 @@ def on_bare_assoc_hash(assocs) # :call-seq: # on_begin: (untyped bodystmt) -> Begin | PinnedBegin def on_begin(bodystmt) - pin = find_token(Op, "^", consume: false) + pin = find_operator(:^) if pin && pin.location.start_char < bodystmt.location.start_char tokens.delete(pin) - find_token(LParen) + consume_token(LParen) - rparen = find_token(RParen) + rparen = consume_token(RParen) location = pin.location.to(rparen.location) PinnedBegin.new(statement: bodystmt, location: location) else - keyword = find_token(Kw, "begin") + keyword = consume_keyword(:begin) end_location = - if bodystmt.rescue_clause || bodystmt.ensure_clause || - bodystmt.else_clause + if bodystmt.else_clause bodystmt.location else - find_token(Kw, "end").location + consume_keyword(:end).location end bodystmt.bind( - keyword.location.end_char, + self, + find_next_statement_start(keyword.location.end_char), keyword.location.end_column, end_location.end_char, end_location.end_column ) - location = keyword.location.to(bodystmt.location) + location = keyword.location.to(end_location) Begin.new(bodystmt: bodystmt, location: location) end end @@ -714,13 +880,11 @@ def on_binary(left, operator, right) # Here, we're going to search backward for the token that's between the # two operands that matches the operator so we can delete it from the # list. + range = (left.location.end_char + 1)...right.location.start_char index = tokens.rindex do |token| - location = token.location - - token.is_a?(Op) && token.value == operator.to_s && - location.start_char > left.location.end_char && - location.end_char < right.location.start_char + token.is_a?(Op) && token.name == operator && + range.cover?(token.location.start_char) end tokens.delete_at(index) if index @@ -745,13 +909,34 @@ def on_binary(left, operator, right) # on_block_var: (Params params, (nil | Array[Ident]) locals) -> BlockVar def on_block_var(params, locals) index = - tokens.rindex do |node| - node.is_a?(Op) && %w[| ||].include?(node.value) && - node.location.start_char < params.location.start_char - end + tokens.rindex { |node| node.is_a?(Op) && %w[| ||].include?(node.value) } + + ending = tokens.delete_at(index) + beginning = ending.value == "||" ? ending : consume_operator(:|) + + # If there are no parameters, then we didn't have anything to base the + # location information of off. Now that we have an opening of the + # block, we can correct this. + if params.empty? + start_line = params.location.start_line + start_char = + ( + if beginning.value == "||" + beginning.location.start_char + else + find_next_statement_start(beginning.location.end_char) + end + ) - beginning = tokens[index] - ending = tokens[-1] + location = + Location.fixed( + line: start_line, + char: start_char, + column: start_char - line_counts[start_line - 1].start + ) + + params = params.copy(location: location) + end BlockVar.new( params: params, @@ -763,7 +948,7 @@ def on_block_var(params, locals) # :call-seq: # on_blockarg: (Ident name) -> BlockArg def on_blockarg(name) - operator = find_token(Op, "&") + operator = consume_operator(:&) location = operator.location location = location.to(name.location) if name @@ -779,14 +964,23 @@ def on_blockarg(name) # (nil | Ensure) ensure_clause # ) -> BodyStmt def on_bodystmt(statements, rescue_clause, else_clause, ensure_clause) + # In certain versions of Ruby, the `statements` argument can be any node + # in the case that we're inside of an endless method definition. In this + # case we'll wrap it in a Statements node to be consistent. + unless statements.is_a?(Statements) + statements = + Statements.new(body: [statements], location: statements.location) + end + + parts = [statements, rescue_clause, else_clause, ensure_clause].compact + BodyStmt.new( statements: statements, rescue_clause: rescue_clause, - else_keyword: else_clause && find_token(Kw, "else"), + else_keyword: else_clause && consume_keyword(:else), else_clause: else_clause, ensure_clause: ensure_clause, - location: - Location.fixed(line: lineno, char: char_pos, column: current_column) + location: parts.first.location.to(parts.last.location) ) end @@ -794,14 +988,15 @@ def on_bodystmt(statements, rescue_clause, else_clause, ensure_clause) # on_brace_block: ( # (nil | BlockVar) block_var, # Statements statements - # ) -> BraceBlock + # ) -> BlockNode def on_brace_block(block_var, statements) - lbrace = find_token(LBrace) - rbrace = find_token(RBrace) + lbrace = consume_token(LBrace) + rbrace = consume_token(RBrace) location = (block_var || lbrace).location - start_char = find_next_statement_start(location.end_char) + start_char = find_next_statement_start(location.end_char) statements.bind( + self, start_char, start_char - line_counts[location.start_line - 1].start, rbrace.location.start_char, @@ -821,10 +1016,10 @@ def on_brace_block(block_var, statements) end_column: rbrace.location.end_column ) - BraceBlock.new( - lbrace: lbrace, + BlockNode.new( + opening: lbrace, block_var: block_var, - statements: statements, + bodystmt: statements, location: location ) end @@ -832,7 +1027,7 @@ def on_brace_block(block_var, statements) # :call-seq: # on_break: (Args arguments) -> Break def on_break(arguments) - keyword = find_token(Kw, "break") + keyword = consume_keyword(:break) location = keyword.location location = location.to(arguments.location) if arguments.parts.any? @@ -845,7 +1040,7 @@ def on_break(arguments) # untyped receiver, # (:"::" | Op | Period) operator, # (:call | Backtick | Const | Ident | Op) message - # ) -> Call + # ) -> CallNode def on_call(receiver, operator, message) ending = if message != :call @@ -856,7 +1051,7 @@ def on_call(receiver, operator, message) receiver end - Call.new( + CallNode.new( receiver: receiver, operator: operator, message: message, @@ -868,8 +1063,24 @@ def on_call(receiver, operator, message) # :call-seq: # on_case: (untyped value, untyped consequent) -> Case | RAssign def on_case(value, consequent) - if (keyword = find_token(Kw, "case", consume: false)) - tokens.delete(keyword) + if value && (operator = find_keyword(:in) || find_operator(:"=>")) && + (value.location.end_char...consequent.location.start_char).cover?( + operator.location.start_char + ) + tokens.delete(operator) + + node = + RAssign.new( + value: value, + operator: operator, + pattern: consequent, + location: value.location.to(consequent.location) + ) + + PinVisitor.visit(node, tokens) + node + else + keyword = consume_keyword(:case) Case.new( keyword: keyword, @@ -877,15 +1088,6 @@ def on_case(value, consequent) consequent: consequent, location: keyword.location.to(consequent.location) ) - else - operator = find_token(Kw, "in", consume: false) || find_token(Op, "=>") - - RAssign.new( - value: value, - operator: operator, - pattern: consequent, - location: value.location.to(consequent.location) - ) end end @@ -896,12 +1098,13 @@ def on_case(value, consequent) # BodyStmt bodystmt # ) -> ClassDeclaration def on_class(constant, superclass, bodystmt) - beginning = find_token(Kw, "class") - ending = find_token(Kw, "end") + beginning = consume_keyword(:class) + ending = consume_keyword(:end) location = (superclass || constant).location start_char = find_next_statement_start(location.end_char) bodystmt.bind( + self, start_char, start_char - line_counts[location.start_line - 1].start, ending.location.start_char, @@ -941,6 +1144,7 @@ def on_command(message, arguments) Command.new( message: message, arguments: arguments, + block: nil, location: message.location.to(arguments.location) ) end @@ -960,6 +1164,7 @@ def on_command_call(receiver, operator, message, arguments) operator: operator, message: message, arguments: arguments, + block: nil, location: receiver.location.to(ending.location) ) end @@ -967,20 +1172,37 @@ def on_command_call(receiver, operator, message, arguments) # :call-seq: # on_comment: (String value) -> Comment def on_comment(value) - line = lineno - comment = - Comment.new( - value: value.chomp, - inline: value.strip != lines[line - 1].strip, - location: - Location.token( - line: line, - char: char_pos, - column: current_column, - size: value.size - 1 - ) + # char is the index of the # character in the source. + char = char_pos + location = + Location.token( + line: lineno, + char: char, + column: current_column, + size: value.size - 1 ) + # Loop backward in the source string, starting from the beginning of the + # comment, and find the first character that is not a space or a tab. If + # index is -1, this indicates that we've checked all of the characters + # back to the start of the source, so this comment must be at the + # beginning of the file. + # + # We are purposefully not using rindex or regular expressions here because + # they check if there are invalid characters, which is actually possible + # with the use of __END__ syntax. + index = char - 1 + while index > -1 && (source[index] == "\t" || source[index] == " ") + index -= 1 + end + + # If we found a character that was not a space or a tab before the comment + # and it's a newline, then this comment is inline. Otherwise, it stands on + # its own and can be attached as its own node in the tree. + inline = index != -1 && source[index] != "\n" + comment = + Comment.new(value: value.chomp, inline: inline, location: location) + @comments << comment comment end @@ -1001,13 +1223,23 @@ def on_const(value) end # :call-seq: - # on_const_path_field: (untyped parent, Const constant) -> ConstPathField + # on_const_path_field: (untyped parent, Const constant) -> + # ConstPathField | Field def on_const_path_field(parent, constant) - ConstPathField.new( - parent: parent, - constant: constant, - location: parent.location.to(constant.location) - ) + if constant.is_a?(Const) + ConstPathField.new( + parent: parent, + constant: constant, + location: parent.location.to(constant.location) + ) + else + Field.new( + parent: parent, + operator: consume_operator(:"::"), + name: constant, + location: parent.location.to(constant.location) + ) + end end # :call-seq: @@ -1046,7 +1278,7 @@ def on_cvar(value) # (Backtick | Const | Ident | Kw | Op) name, # (nil | Params | Paren) params, # untyped bodystmt - # ) -> Def | DefEndless + # ) -> DefNode def on_def(name, params, bodystmt) # Make sure to delete this token in case you're defining something like # def class which would lead to this being a kw and causing all kinds of @@ -1055,7 +1287,7 @@ def on_def(name, params, bodystmt) # Find the beginning of the method definition, which works for single-line # and normal method definitions. - beginning = find_token(Kw, "def") + beginning = consume_keyword(:def) # If there aren't any params then we need to correct the params node # location information @@ -1075,20 +1307,23 @@ def on_def(name, params, bodystmt) params = Params.new(location: location) end - ending = find_token(Kw, "end", consume: false) + ending = find_keyword(:end) if ending tokens.delete(ending) start_char = find_next_statement_start(params.location.end_char) bodystmt.bind( + self, start_char, start_char - line_counts[params.location.start_line - 1].start, ending.location.start_char, ending.location.start_column ) - Def.new( + DefNode.new( + target: nil, + operator: nil, name: name, params: params, bodystmt: bodystmt, @@ -1099,12 +1334,12 @@ def on_def(name, params, bodystmt) # the statements list. Before, it was just the individual statement. statement = bodystmt.is_a?(BodyStmt) ? bodystmt.statements : bodystmt - DefEndless.new( + DefNode.new( target: nil, operator: nil, name: name, - paren: params, - statement: statement, + params: params, + bodystmt: statement, location: beginning.location.to(bodystmt.location) ) end @@ -1113,13 +1348,13 @@ def on_def(name, params, bodystmt) # :call-seq: # on_defined: (untyped value) -> Defined def on_defined(value) - beginning = find_token(Kw, "defined?") + beginning = consume_keyword(:defined?) ending = value range = beginning.location.end_char...value.location.start_char if source[range].include?("(") - find_token(LParen) - ending = find_token(RParen) + consume_token(LParen) + ending = consume_token(RParen) end Defined.new( @@ -1135,7 +1370,7 @@ def on_defined(value) # (Backtick | Const | Ident | Kw | Op) name, # (Params | Paren) params, # BodyStmt bodystmt - # ) -> Defs + # ) -> DefNode def on_defs(target, operator, name, params, bodystmt) # Make sure to delete this token in case you're defining something # like def class which would lead to this being a kw and causing all kinds @@ -1160,21 +1395,22 @@ def on_defs(target, operator, name, params, bodystmt) params = Params.new(location: location) end - beginning = find_token(Kw, "def") - ending = find_token(Kw, "end", consume: false) + beginning = consume_keyword(:def) + ending = find_keyword(:end) if ending tokens.delete(ending) start_char = find_next_statement_start(params.location.end_char) bodystmt.bind( + self, start_char, start_char - line_counts[params.location.start_line - 1].start, ending.location.start_char, ending.location.start_column ) - Defs.new( + DefNode.new( target: target, operator: operator, name: name, @@ -1187,34 +1423,35 @@ def on_defs(target, operator, name, params, bodystmt) # the statements list. Before, it was just the individual statement. statement = bodystmt.is_a?(BodyStmt) ? bodystmt.statements : bodystmt - DefEndless.new( + DefNode.new( target: target, operator: operator, name: name, - paren: params, - statement: statement, + params: params, + bodystmt: statement, location: beginning.location.to(bodystmt.location) ) end end # :call-seq: - # on_do_block: (BlockVar block_var, BodyStmt bodystmt) -> DoBlock + # on_do_block: (BlockVar block_var, BodyStmt bodystmt) -> BlockNode def on_do_block(block_var, bodystmt) - beginning = find_token(Kw, "do") - ending = find_token(Kw, "end") + beginning = consume_keyword(:do) + ending = consume_keyword(:end) location = (block_var || beginning).location start_char = find_next_statement_start(location.end_char) bodystmt.bind( + self, start_char, start_char - line_counts[location.start_line - 1].start, ending.location.start_char, ending.location.start_column ) - DoBlock.new( - keyword: beginning, + BlockNode.new( + opening: beginning, block_var: block_var, bodystmt: bodystmt, location: beginning.location.to(ending.location) @@ -1222,30 +1459,32 @@ def on_do_block(block_var, bodystmt) end # :call-seq: - # on_dot2: ((nil | untyped) left, (nil | untyped) right) -> Dot2 + # on_dot2: ((nil | untyped) left, (nil | untyped) right) -> RangeNode def on_dot2(left, right) - operator = find_token(Op, "..") + operator = consume_operator(:"..") beginning = left || operator ending = right || operator - Dot2.new( + RangeNode.new( left: left, + operator: operator, right: right, location: beginning.location.to(ending.location) ) end # :call-seq: - # on_dot3: ((nil | untyped) left, (nil | untyped) right) -> Dot3 + # on_dot3: ((nil | untyped) left, (nil | untyped) right) -> RangeNode def on_dot3(left, right) - operator = find_token(Op, "...") + operator = consume_operator(:"...") beginning = left || operator ending = right || operator - Dot3.new( + RangeNode.new( left: left, + operator: operator, right: right, location: beginning.location.to(ending.location) ) @@ -1254,10 +1493,10 @@ def on_dot3(left, right) # :call-seq: # on_dyna_symbol: (StringContent string_content) -> DynaSymbol def on_dyna_symbol(string_content) - if find_token(SymBeg, consume: false) + if (symbeg = find_token(SymBeg)) # A normal dynamic symbol - symbeg = find_token(SymBeg) - tstring_end = find_token(TStringEnd, location: symbeg.location) + tokens.delete(symbeg) + tstring_end = consume_tstring_end(symbeg.location) DynaSymbol.new( quote: symbeg.value, @@ -1266,8 +1505,8 @@ def on_dyna_symbol(string_content) ) else # A dynamic symbol as a hash key - tstring_beg = find_token(TStringBeg) - label_end = find_token(LabelEnd) + tstring_beg = consume_token(TStringBeg) + label_end = consume_token(LabelEnd) DynaSymbol.new( parts: string_content.parts, @@ -1280,7 +1519,7 @@ def on_dyna_symbol(string_content) # :call-seq: # on_else: (Statements statements) -> Else def on_else(statements) - keyword = find_token(Kw, "else") + keyword = consume_keyword(:else) # else can either end with an end keyword (in which case we'll want to # consume that event) or it can end with an ensure keyword (in which case @@ -1290,11 +1529,17 @@ def on_else(statements) token.is_a?(Kw) && %w[end ensure].include?(token.value) end + if index.nil? + message = "Cannot find expected else ending" + raise ParseError.new(message, *find_token_error(keyword.location)) + end + node = tokens[index] ending = node.value == "end" ? tokens.delete_at(index) : node - start_char = find_next_statement_start(keyword.location.end_char) + start_char = find_next_statement_start(keyword.location.end_char) statements.bind( + self, start_char, start_char - line_counts[keyword.location.start_line - 1].start, ending.location.start_char, @@ -1315,12 +1560,21 @@ def on_else(statements) # (nil | Elsif | Else) consequent # ) -> Elsif def on_elsif(predicate, statements, consequent) - beginning = find_token(Kw, "elsif") - ending = consequent || find_token(Kw, "end") + beginning = consume_keyword(:elsif) + ending = consequent || consume_keyword(:end) + + delimiter = + find_keyword_between(:then, predicate, statements) || + find_token_between(Semicolon, predicate, statements) + + tokens.delete(delimiter) if delimiter + start_char = + find_next_statement_start((delimiter || predicate).location.end_char) statements.bind( - predicate.location.end_char, - predicate.location.end_column, + self, + start_char, + start_char - line_counts[predicate.location.start_line - 1].start, ending.location.start_char, ending.location.start_column ) @@ -1435,13 +1689,14 @@ def on_embvar(value) # :call-seq: # on_ensure: (Statements statements) -> Ensure def on_ensure(statements) - keyword = find_token(Kw, "ensure") + keyword = consume_keyword(:ensure) # We don't want to consume the :@kw event, because that would break # def..ensure..end chains. - ending = find_token(Kw, "end", consume: false) + ending = find_keyword(:end) start_char = find_next_statement_start(keyword.location.end_char) statements.bind( + self, start_char, start_char - line_counts[keyword.location.start_line - 1].start, ending.location.start_char, @@ -1461,27 +1716,33 @@ def on_ensure(statements) # :call-seq: # on_excessed_comma: () -> ExcessedComma def on_excessed_comma(*) - comma = find_token(Comma) + comma = consume_token(Comma) ExcessedComma.new(value: comma.value, location: comma.location) end # :call-seq: - # on_fcall: ((Const | Ident) value) -> FCall + # on_fcall: ((Const | Ident) value) -> CallNode def on_fcall(value) - FCall.new(value: value, arguments: nil, location: value.location) + CallNode.new( + receiver: nil, + operator: nil, + message: value, + arguments: nil, + location: value.location + ) end # :call-seq: # on_field: ( # untyped parent, - # (:"::" | Op | Period) operator + # (:"::" | Op | Period | 73) operator # (Const | Ident) name # ) -> Field def on_field(parent, operator, name) Field.new( parent: parent, - operator: operator, + operator: operator == 73 ? :"::" : operator, name: name, location: parent.location.to(name.location) ) @@ -1510,24 +1771,38 @@ def on_float(value) # VarField right # ) -> FndPtn def on_fndptn(constant, left, values, right) + # The left and right of a find pattern are always going to be splats, so + # we're going to consume the * operators and use their location + # information to extend the location of the splats. + right, left = + [right, left].map do |node| + operator = consume_operator(:*) + location = + if node.value + operator.location.to(node.location) + else + operator.location + end + + node.copy(location: location) + end + # The opening of this find pattern is either going to be a left bracket, a # right left parenthesis, or the left splat. We're going to use this to # determine how to find the closing of the pattern, as well as determining # the location of the node. - opening = - find_token(LBracket, consume: false) || - find_token(LParen, consume: false) || left + opening = find_token(LBracket) || find_token(LParen) || left # The closing is based on the opening, which is either the matched # punctuation or the right splat. closing = case opening - in LBracket + when LBracket tokens.delete(opening) - find_token(RBracket) - in LParen + consume_token(RBracket) + when LParen tokens.delete(opening) - find_token(RParen) + consume_token(RParen) else right end @@ -1548,22 +1823,24 @@ def on_fndptn(constant, left, values, right) # Statements statements # ) -> For def on_for(index, collection, statements) - beginning = find_token(Kw, "for") - in_keyword = find_token(Kw, "in") - ending = find_token(Kw, "end") - - # Consume the do keyword if it exists so that it doesn't get confused for - # some other block - keyword = find_token(Kw, "do", consume: false) - if keyword && - keyword.location.start_char > collection.location.end_char && - keyword.location.end_char < ending.location.start_char - tokens.delete(keyword) - end + beginning = consume_keyword(:for) + in_keyword = consume_keyword(:in) + ending = consume_keyword(:end) + + delimiter = + find_keyword_between(:do, collection, ending) || + find_token_between(Semicolon, collection, ending) + + tokens.delete(delimiter) if delimiter + + start_char = + find_next_statement_start((delimiter || collection).location.end_char) statements.bind( - (keyword || collection).location.end_char, - (keyword || collection).location.end_column, + self, + start_char, + start_char - + line_counts[(delimiter || collection).location.end_line - 1].start, ending.location.start_char, ending.location.start_column ) @@ -1599,8 +1876,8 @@ def on_gvar(value) # :call-seq: # on_hash: ((nil | Array[AssocNew | AssocSplat]) assocs) -> HashLiteral def on_hash(assocs) - lbrace = find_token(LBrace) - rbrace = find_token(RBrace) + lbrace = consume_token(LBrace) + rbrace = consume_token(RBrace) HashLiteral.new( lbrace: lbrace, @@ -1617,7 +1894,7 @@ def on_heredoc_beg(value) line: lineno, char: char_pos, column: current_column, - size: value.size + 1 + size: value.size ) # Here we're going to artificially create an extra node type so that if @@ -1647,9 +1924,19 @@ def on_heredoc_dedent(string, width) def on_heredoc_end(value) heredoc = @heredocs[-1] + location = + Location.token( + line: lineno, + char: char_pos, + column: current_column, + size: value.size + ) + + heredoc_end = HeredocEnd.new(value: value.chomp, location: location) + @heredocs[-1] = Heredoc.new( beginning: heredoc.beginning, - ending: value.chomp, + ending: heredoc_end, dedent: heredoc.dedent, parts: heredoc.parts, location: @@ -1657,9 +1944,9 @@ def on_heredoc_end(value) start_line: heredoc.location.start_line, start_char: heredoc.location.start_char, start_column: heredoc.location.start_column, - end_line: lineno, - end_char: char_pos, - end_column: current_column + end_line: location.end_line, + end_char: location.end_char, + end_column: location.end_column ) ) end @@ -1667,23 +1954,61 @@ def on_heredoc_end(value) # :call-seq: # on_hshptn: ( # (nil | untyped) constant, - # Array[[Label, untyped]] keywords, + # Array[[Label | StringContent, untyped]] keywords, # (nil | VarField) keyword_rest # ) -> HshPtn def on_hshptn(constant, keywords, keyword_rest) - # Create an artificial VarField if we find an extra ** on the end - if !keyword_rest && (token = find_token(Op, "**", consume: false)) + keywords = + (keywords || []).map do |(label, value)| + if label.is_a?(Label) + [label, value] + else + tstring_beg_index = + tokens.rindex do |token| + token.is_a?(TStringBeg) && + token.location.start_char < label.location.start_char + end + + tstring_beg = tokens.delete_at(tstring_beg_index) + + label_end_index = + tokens.rindex do |token| + token.is_a?(LabelEnd) && + token.location.start_char == label.location.end_char + end + + label_end = tokens.delete_at(label_end_index) + + [ + DynaSymbol.new( + parts: label.parts, + quote: label_end.value[0], + location: tstring_beg.location.to(label_end.location) + ), + value + ] + end + end + + if keyword_rest + # We're doing this to delete the token from the list so that it doesn't + # confuse future patterns by thinking they have an extra ** on the end. + consume_operator(:**) + elsif (token = find_operator(:**)) tokens.delete(token) + + # Create an artificial VarField if we find an extra ** on the end. This + # means the formatting will be a little more consistent. keyword_rest = VarField.new(value: nil, location: token.location) end - parts = [constant, *keywords&.flatten(1), keyword_rest].compact + parts = [constant, *keywords.flatten(1), keyword_rest].compact # If there's no constant, there may be braces, so we're going to look for # those to get our bounds. unless constant - lbrace = find_token(LBrace, consume: false) - rbrace = find_token(RBrace, consume: false) + lbrace = find_token(LBrace) + rbrace = find_token(RBrace) if lbrace && rbrace parts = [lbrace, *parts, rbrace] @@ -1692,15 +2017,9 @@ def on_hshptn(constant, keywords, keyword_rest) end end - # Delete the optional then keyword - if (token = find_token(Kw, "then", consume: false)) - parts << token - tokens.delete(token) - end - HshPtn.new( constant: constant, - keywords: keywords || [], + keywords: keywords, keyword_rest: keyword_rest, location: parts[0].location.to(parts[-1].location) ) @@ -1726,19 +2045,27 @@ def on_ident(value) # untyped predicate, # Statements statements, # (nil | Elsif | Else) consequent - # ) -> If + # ) -> IfNode def on_if(predicate, statements, consequent) - beginning = find_token(Kw, "if") - ending = consequent || find_token(Kw, "end") + beginning = consume_keyword(:if) + ending = consequent || consume_keyword(:end) + + if (keyword = find_keyword_between(:then, predicate, ending)) + tokens.delete(keyword) + end + + start_char = + find_next_statement_start((keyword || predicate).location.end_char) statements.bind( - predicate.location.end_char, - predicate.location.end_column, + self, + start_char, + start_char - line_counts[predicate.location.end_line - 1].start, ending.location.start_char, ending.location.start_column ) - If.new( + IfNode.new( predicate: predicate, statements: statements, consequent: consequent, @@ -1758,13 +2085,15 @@ def on_ifop(predicate, truthy, falsy) end # :call-seq: - # on_if_mod: (untyped predicate, untyped statement) -> IfMod + # on_if_mod: (untyped predicate, untyped statement) -> IfNode def on_if_mod(predicate, statement) - find_token(Kw, "if") + consume_keyword(:if) - IfMod.new( - statement: statement, + IfNode.new( predicate: predicate, + statements: + Statements.new(body: [statement], location: statement.location), + consequent: nil, location: statement.location.to(predicate.location) ) end @@ -1803,17 +2132,26 @@ def on_in(pattern, statements, consequent) # Here we have a rightward assignment return pattern unless statements - beginning = find_token(Kw, "in") - ending = consequent || find_token(Kw, "end") + beginning = consume_keyword(:in) + ending = consequent || consume_keyword(:end) statements_start = pattern - if (token = find_token(Kw, "then", consume: false)) + if (token = find_keyword_between(:then, pattern, statements)) tokens.delete(token) statements_start = token end - start_char = find_next_statement_start(statements_start.location.end_char) + start_char = + find_next_statement_start((token || statements_start).location.end_char) + + # Ripper ignores parentheses on patterns, so we need to do the same in + # order to attach comments correctly to the pattern. + if source[start_char] == ")" + start_char = find_next_statement_start(start_char + 1) + end + statements.bind( + self, start_char, start_char - line_counts[statements_start.location.start_line - 1].start, @@ -1821,12 +2159,16 @@ def on_in(pattern, statements, consequent) ending.location.start_column ) - In.new( - pattern: pattern, - statements: statements, - consequent: consequent, - location: beginning.location.to(ending.location) - ) + node = + In.new( + pattern: pattern, + statements: statements, + consequent: consequent, + location: beginning.location.to(ending.location) + ) + + PinVisitor.visit(node, tokens) + node end # :call-seq: @@ -1881,7 +2223,7 @@ def on_kw(value) # :call-seq: # on_kwrest_param: ((nil | Ident) name) -> KwRestParam def on_kwrest_param(name) - location = find_token(Op, "**").location + location = consume_operator(:**).location location = location.to(name.location) if name KwRestParam.new(name: name, location: location) @@ -1927,7 +2269,7 @@ def on_label_end(value) # (BodyStmt | Statements) statements # ) -> Lambda def on_lambda(params, statements) - beginning = find_token(TLambda) + beginning = consume_token(TLambda) braces = tokens.any? do |token| token.is_a?(TLamBeg) && @@ -1935,16 +2277,66 @@ def on_lambda(params, statements) end if braces - opening = find_token(TLamBeg) - closing = find_token(RBrace) + opening = consume_token(TLamBeg) + closing = consume_token(RBrace) else - opening = find_token(Kw, "do") - closing = find_token(Kw, "end") + opening = consume_keyword(:do) + closing = consume_keyword(:end) end + # We need to do some special mapping here. Since ripper doesn't support + # capturing lambda vars, we need to normalize all of that here. + params = + if params.is_a?(Paren) + # In this case we've gotten to the parentheses wrapping a set of + # parameters case. Here we need to manually scan for lambda locals. + range = (params.location.start_char + 1)...params.location.end_char + locals = lambda_locals(source[range]) + + location = params.contents.location + location = location.to(locals.last.location) if locals.any? + + node = + Paren.new( + lparen: params.lparen, + contents: + LambdaVar.new( + params: params.contents, + locals: locals, + location: location + ), + location: params.location + ) + + node.comments.concat(params.comments) + node + else + # If there are no parameters, then we didn't have anything to base the + # location information of off. Now that we have an opening of the + # block, we can correct this. + if params.empty? + opening_location = opening.location + location = + Location.fixed( + line: opening_location.start_line, + char: opening_location.start_char, + column: opening_location.start_column + ) + + params = params.copy(location: location) + end + + # In this case we've gotten to the plain set of parameters. In this + # case there cannot be lambda locals, so we will wrap the parameters + # into a lambda var that has no locals. + LambdaVar.new(params: params, locals: [], location: params.location) + end + + start_char = find_next_statement_start(opening.location.end_char) statements.bind( - opening.location.end_char, - opening.location.end_column, + self, + start_char, + start_char - line_counts[opening.location.end_line - 1].start, closing.location.start_char, closing.location.start_column ) @@ -1956,6 +2348,89 @@ def on_lambda(params, statements) ) end + # :call-seq: + # on_lambda_var: (Params params, Array[ Ident ] locals) -> LambdaVar + def on_lambda_var(params, locals) + location = params.location + location = location.to(locals.last.location) if locals.any? + + LambdaVar.new(params: params, locals: locals || [], location: location) + end + + # Ripper doesn't support capturing lambda local variables until 3.2. To + # mitigate this, we have to parse that code for ourselves. We use the range + # from the parentheses to find where we _should_ be looking. Then we check + # if the resulting tokens match a pattern that we determine means that the + # declaration has block-local variables. Once it does, we parse those out + # and convert them into Ident nodes. + def lambda_locals(source) + tokens = Ripper.lex(source) + + # First, check that we have a semi-colon. If we do, then we can start to + # parse the tokens _after_ the semicolon. + index = tokens.rindex { |token| token[1] == :on_semicolon } + return [] unless index + + # Next, map over the tokens and convert them into Ident nodes. Bail out + # midway through if we encounter a token we didn't expect. Basically we're + # making our own mini-parser here. To do that we'll walk through a small + # state machine: + # + # ┌────────┐ ┌────────┐ ┌─────────┐ + # │ │ │ │ │┌───────┐│ + # ──> │ item │ ─── ident ──> │ next │ ─── rparen ──> ││ final ││ + # │ │ <── comma ─── │ │ │└───────┘│ + # └────────┘ └────────┘ └─────────┘ + # │ ^ │ ^ + # └──┘ └──┘ + # ignored_nl, sp nl, sp + # + state = :item + transitions = { + item: { + on_ignored_nl: :item, + on_sp: :item, + on_ident: :next + }, + next: { + on_nl: :next, + on_sp: :next, + on_comma: :item, + on_rparen: :final + }, + final: { + } + } + + parent_line = lineno - 1 + parent_column = + consume_token(Semicolon).location.start_column - tokens[index][0][1] + + tokens[(index + 1)..].each_with_object([]) do |token, locals| + (lineno, column), type, value, = token + column += parent_column if lineno == 1 + lineno += parent_line + + # Make the state transition for the parser. If there isn't a transition + # from the current state to a new state for this type, then we're in a + # pattern that isn't actually locals. In that case we can return []. + state = transitions[state].fetch(type) { return [] } + + # If we hit an identifier, then add it to our list. + next if type != :on_ident + + location = + Location.token( + line: lineno, + char: line_counts[lineno - 1][column], + column: column, + size: value.size + ) + + locals << Ident.new(value: value, location: location) + end + end + # :call-seq: # on_lbrace: (String value) -> LBrace def on_lbrace(value) @@ -2032,37 +2507,49 @@ def on_massign(target, value) # :call-seq: # on_method_add_arg: ( - # (Call | FCall) call, + # CallNode call, # (ArgParen | Args) arguments - # ) -> Call | FCall + # ) -> CallNode def on_method_add_arg(call, arguments) location = call.location location = location.to(arguments.location) if arguments.is_a?(ArgParen) - if call.is_a?(FCall) - FCall.new(value: call.value, arguments: arguments, location: location) - else - Call.new( - receiver: call.receiver, - operator: call.operator, - message: call.message, - arguments: arguments, - location: location - ) - end + CallNode.new( + receiver: call.receiver, + operator: call.operator, + message: call.message, + arguments: arguments, + location: location + ) end # :call-seq: # on_method_add_block: ( - # (Call | Command | CommandCall | FCall) call, - # (BraceBlock | DoBlock) block - # ) -> MethodAddBlock + # (Break | Call | Command | CommandCall, Next) call, + # Block block + # ) -> Break | MethodAddBlock def on_method_add_block(call, block) - MethodAddBlock.new( - call: call, - block: block, - location: call.location.to(block.location) - ) + location = call.location.to(block.location) + + case call + when Break, Next, ReturnNode + parts = call.arguments.parts + + node = parts.pop + copied = + node.copy(block: block, location: node.location.to(block.location)) + + copied.comments.concat(call.comments) + parts << copied + + call.copy(location: location) + when Command, CommandCall + node = call.copy(block: block, location: location) + node.comments.concat(call.comments) + node + else + MethodAddBlock.new(call: call, block: block, location: location) + end end # :call-seq: @@ -2092,7 +2579,7 @@ def on_mlhs_add_post(left, right) # (nil | ARefField | Field | Ident | VarField) part # ) -> MLHS def on_mlhs_add_star(mlhs, part) - beginning = find_token(Op, "*") + beginning = consume_operator(:*) ending = part || beginning location = beginning.location.to(ending.location) @@ -2115,8 +2602,8 @@ def on_mlhs_new # :call-seq: # on_mlhs_paren: ((MLHS | MLHSParen) contents) -> MLHSParen def on_mlhs_paren(contents) - lparen = find_token(LParen) - rparen = find_token(RParen) + lparen = consume_token(LParen) + rparen = consume_token(RParen) comma_range = lparen.location.end_char...rparen.location.start_char contents.comma = true if source[comma_range].strip.end_with?(",") @@ -2133,11 +2620,12 @@ def on_mlhs_paren(contents) # BodyStmt bodystmt # ) -> ModuleDeclaration def on_module(constant, bodystmt) - beginning = find_token(Kw, "module") - ending = find_token(Kw, "end") + beginning = consume_keyword(:module) + ending = consume_keyword(:end) start_char = find_next_statement_start(constant.location.end_char) bodystmt.bind( + self, start_char, start_char - line_counts[constant.location.start_line - 1].start, ending.location.start_char, @@ -2173,7 +2661,7 @@ def on_mrhs_add(mrhs, part) # :call-seq: # on_mrhs_add_star: (MRHS mrhs, untyped value) -> MRHS def on_mrhs_add_star(mrhs, value) - beginning = find_token(Op, "*") + beginning = consume_operator(:*) ending = value || beginning arg_star = @@ -2201,7 +2689,7 @@ def on_mrhs_new_from_args(arguments) # :call-seq: # on_next: (Args arguments) -> Next def on_next(arguments) - keyword = find_token(Kw, "next") + keyword = consume_keyword(:next) location = keyword.location location = location.to(arguments.location) if arguments.parts.any? @@ -2284,19 +2772,40 @@ def on_params( # have a `nil` for the value instead of a `false`. keywords&.map! { |(key, value)| [key, value || nil] } - parts = [ - *requireds, - *optionals&.flatten(1), - rest, - *posts, - *keywords&.flatten(1), - (keyword_rest if keyword_rest != :nil), - (block if block != :&) - ].compact + # Here we're going to build up a list of all of the params so that we can + # determine our location information. + parts = [] + + requireds&.each { |required| parts << required.location } + optionals&.each do |(key, value)| + parts << key.location + parts << value.location if value + end + + parts << rest.location if rest + posts&.each { |post| parts << post.location } + + keywords&.each do |(key, value)| + parts << key.location + parts << value.location if value + end + + if keyword_rest == :nil + # When we get a :nil here, it means that we have **nil syntax, which + # means this set of parameters accepts no more keyword arguments. In + # this case we need to go and find the location of these two tokens. + operator = consume_operator(:**) + parts << operator.location.to(consume_keyword(:nil).location) + elsif keyword_rest + parts << keyword_rest.location + end + + parts << block.location if block && block != :& + parts = parts.compact location = if parts.any? - parts[0].location.to(parts[-1].location) + parts[0].to(parts[-1]) else Location.fixed(line: lineno, char: char_pos, column: current_column) end @@ -2316,8 +2825,8 @@ def on_params( # :call-seq: # on_paren: (untyped contents) -> Paren def on_paren(contents) - lparen = find_token(LParen) - rparen = find_token(RParen) + lparen = consume_token(LParen) + rparen = consume_token(RParen) if contents.is_a?(Params) location = contents.location @@ -2362,6 +2871,7 @@ def on_parse_error(error, *) alias on_assign_error on_parse_error alias on_class_name_error on_parse_error alias on_param_error on_parse_error + alias compile_error on_parse_error # :call-seq: # on_period: (String value) -> Period @@ -2381,19 +2891,19 @@ def on_period(value) # :call-seq: # on_program: (Statements statements) -> Program def on_program(statements) - last_column = source.length - line_counts[lines.length - 1].start + last_column = source.length - line_counts.last.start location = Location.new( start_line: 1, start_char: 0, start_column: 0, - end_line: lines.length, + end_line: line_counts.length - 1, end_char: source.length, end_column: last_column ) statements.body << @__end__ if @__end__ - statements.bind(0, 0, source.length, last_column) + statements.bind(self, 0, 0, source.length, last_column) program = Program.new(statements: statements, location: location) attach_comments(program, @comments) @@ -2522,7 +3032,7 @@ def on_qsymbols_beg(value) # :call-seq: # on_qsymbols_new: () -> QSymbols def on_qsymbols_new - beginning = find_token(QSymbolsBeg) + beginning = consume_token(QSymbolsBeg) QSymbols.new( beginning: beginning, @@ -2563,7 +3073,7 @@ def on_qwords_beg(value) # :call-seq: # on_qwords_new: () -> QWords def on_qwords_new - beginning = find_token(QWordsBeg) + beginning = consume_token(QWordsBeg) QWords.new( beginning: beginning, @@ -2628,9 +3138,9 @@ def on_rbracket(value) # :call-seq: # on_redo: () -> Redo def on_redo - keyword = find_token(Kw, "redo") + keyword = consume_keyword(:redo) - Redo.new(value: keyword.value, location: keyword.location) + Redo.new(location: keyword.location) end # :call-seq: @@ -2683,21 +3193,28 @@ def on_regexp_end(value) # :call-seq: # on_regexp_literal: ( # RegexpContent regexp_content, - # RegexpEnd ending + # (nil | RegexpEnd) ending # ) -> RegexpLiteral def on_regexp_literal(regexp_content, ending) + location = regexp_content.location + + if ending.nil? + message = "Cannot find expected regular expression ending" + raise ParseError.new(message, *find_token_error(location)) + end + RegexpLiteral.new( beginning: regexp_content.beginning, ending: ending.value, parts: regexp_content.parts, - location: regexp_content.location.to(ending.location) + location: location.to(ending.location) ) end # :call-seq: # on_regexp_new: () -> RegexpContent def on_regexp_new - regexp_beg = find_token(RegexpBeg) + regexp_beg = consume_token(RegexpBeg) RegexpContent.new( beginning: regexp_beg.value, @@ -2714,12 +3231,13 @@ def on_regexp_new # (nil | Rescue) consequent # ) -> Rescue def on_rescue(exceptions, variable, statements, consequent) - keyword = find_token(Kw, "rescue") + keyword = consume_keyword(:rescue) exceptions = exceptions[0] if exceptions.is_a?(Array) last_node = variable || exceptions || keyword - start_char = find_next_statement_start(last_node.location.end_char) + start_char = find_next_statement_start(last_node.end_char) statements.bind( + self, start_char, start_char - line_counts[last_node.location.start_line - 1].start, char_pos, @@ -2740,7 +3258,7 @@ def on_rescue(exceptions, variable, statements, consequent) start_char: keyword.location.end_char + 1, start_column: keyword.location.end_column + 1, end_line: last_node.location.end_line, - end_char: last_node.location.end_char, + end_char: last_node.end_char, end_column: last_node.location.end_column ) ) @@ -2766,7 +3284,7 @@ def on_rescue(exceptions, variable, statements, consequent) # :call-seq: # on_rescue_mod: (untyped statement, untyped value) -> RescueMod def on_rescue_mod(statement, value) - find_token(Kw, "rescue") + consume_keyword(:rescue) RescueMod.new( statement: statement, @@ -2778,7 +3296,7 @@ def on_rescue_mod(statement, value) # :call-seq: # on_rest_param: ((nil | Ident) name) -> RestParam def on_rest_param(name) - location = find_token(Op, "*").location + location = consume_operator(:*).location location = location.to(name.location) if name RestParam.new(name: name, location: location) @@ -2787,28 +3305,28 @@ def on_rest_param(name) # :call-seq: # on_retry: () -> Retry def on_retry - keyword = find_token(Kw, "retry") + keyword = consume_keyword(:retry) - Retry.new(value: keyword.value, location: keyword.location) + Retry.new(location: keyword.location) end # :call-seq: - # on_return: (Args arguments) -> Return + # on_return: (Args arguments) -> ReturnNode def on_return(arguments) - keyword = find_token(Kw, "return") + keyword = consume_keyword(:return) - Return.new( + ReturnNode.new( arguments: arguments, location: keyword.location.to(arguments.location) ) end # :call-seq: - # on_return0: () -> Return0 + # on_return0: () -> ReturnNode def on_return0 - keyword = find_token(Kw, "return") + keyword = consume_keyword(:return) - Return0.new(value: keyword.value, location: keyword.location) + ReturnNode.new(arguments: nil, location: keyword.location) end # :call-seq: @@ -2833,11 +3351,12 @@ def on_rparen(value) # :call-seq: # on_sclass: (untyped target, BodyStmt bodystmt) -> SClass def on_sclass(target, bodystmt) - beginning = find_token(Kw, "class") - ending = find_token(Kw, "end") + beginning = consume_keyword(:class) + ending = consume_keyword(:end) start_char = find_next_statement_start(target.location.end_char) bodystmt.bind( + self, start_char, start_char - line_counts[target.location.start_line - 1].start, ending.location.start_char, @@ -2851,9 +3370,29 @@ def on_sclass(target, bodystmt) ) end - # def on_semicolon(value) - # value - # end + # Semicolons are tokens that get added to the token list but never get + # attached to the AST. Because of this they only need to track their + # associated location so they can be used for computing bounds. + class Semicolon + attr_reader :location + + def initialize(location) + @location = location + end + end + + # :call-seq: + # on_semicolon: (String value) -> Semicolon + def on_semicolon(value) + tokens << Semicolon.new( + Location.token( + line: lineno, + char: char_pos, + column: current_column, + size: value.size + ) + ) + end # def on_sp(value) # value @@ -2871,18 +3410,13 @@ def on_stmts_add(statements, statement) statements.location.to(statement.location) end - Statements.new( - self, - body: statements.body << statement, - location: location - ) + Statements.new(body: statements.body << statement, location: location) end # :call-seq: # on_stmts_new: () -> Statements def on_stmts_new Statements.new( - self, body: [], location: Location.fixed(line: lineno, char: char_pos, column: current_column) @@ -2895,6 +3429,11 @@ def on_stmts_new # (StringEmbExpr | StringDVar | TStringContent) part # ) -> StringContent def on_string_add(string, part) + # Due to some eccentricities in how ripper works, you need this here in + # case you have a syntax error with an embedded expression that doesn't + # finish, as in: "#{" + return string if part.is_a?(String) + location = string.parts.any? ? string.location.to(part.location) : part.location @@ -2927,7 +3466,7 @@ def on_string_content # :call-seq: # on_string_dvar: ((Backref | VarRef) variable) -> StringDVar def on_string_dvar(variable) - embvar = find_token(EmbVar) + embvar = consume_token(EmbVar) StringDVar.new( variable: variable, @@ -2938,10 +3477,11 @@ def on_string_dvar(variable) # :call-seq: # on_string_embexpr: (Statements statements) -> StringEmbExpr def on_string_embexpr(statements) - embexpr_beg = find_token(EmbExprBeg) - embexpr_end = find_token(EmbExprEnd) + embexpr_beg = consume_token(EmbExprBeg) + embexpr_end = consume_token(EmbExprEnd) statements.bind( + self, embexpr_beg.location.end_char, embexpr_beg.location.end_column, embexpr_end.location.start_char, @@ -2980,8 +3520,8 @@ def on_string_literal(string) location: heredoc.location ) else - tstring_beg = find_token(TStringBeg) - tstring_end = find_token(TStringEnd, location: tstring_beg.location) + tstring_beg = consume_token(TStringBeg) + tstring_end = consume_tstring_end(tstring_beg.location) location = Location.new( @@ -3007,7 +3547,7 @@ def on_string_literal(string) # :call-seq: # on_super: ((ArgParen | Args) arguments) -> Super def on_super(arguments) - keyword = find_token(Kw, "super") + keyword = consume_keyword(:super) Super.new( arguments: arguments, @@ -3054,7 +3594,7 @@ def on_symbol(value) # ) -> SymbolLiteral def on_symbol_literal(value) if value.is_a?(SymbolContent) - symbeg = find_token(SymBeg) + symbeg = consume_token(SymBeg) SymbolLiteral.new( value: value.value, @@ -3098,7 +3638,7 @@ def on_symbols_beg(value) # :call-seq: # on_symbols_new: () -> Symbols def on_symbols_new - beginning = find_token(SymbolsBeg) + beginning = consume_token(SymbolsBeg) Symbols.new( beginning: beginning, @@ -3228,13 +3768,13 @@ def on_unary(operator, statement) # We have somewhat special handling of the not operator since if it has # parentheses they don't get reported as a paren node for some reason. - beginning = find_token(Kw, "not") + beginning = consume_keyword(:not) ending = statement || beginning parentheses = source[beginning.location.end_char] == "(" if parentheses - find_token(LParen) - ending = find_token(RParen) + consume_token(LParen) + ending = consume_token(RParen) end Not.new( @@ -3267,7 +3807,7 @@ def on_unary(operator, statement) # :call-seq: # on_undef: (Array[DynaSymbol | SymbolLiteral] symbols) -> Undef def on_undef(symbols) - keyword = find_token(Kw, "undef") + keyword = consume_keyword(:undef) Undef.new( symbols: symbols, @@ -3280,19 +3820,27 @@ def on_undef(symbols) # untyped predicate, # Statements statements, # ((nil | Elsif | Else) consequent) - # ) -> Unless + # ) -> UnlessNode def on_unless(predicate, statements, consequent) - beginning = find_token(Kw, "unless") - ending = consequent || find_token(Kw, "end") + beginning = consume_keyword(:unless) + ending = consequent || consume_keyword(:end) + + if (keyword = find_keyword_between(:then, predicate, ending)) + tokens.delete(keyword) + end + + start_char = + find_next_statement_start((keyword || predicate).location.end_char) statements.bind( - predicate.location.end_char, - predicate.location.end_column, + self, + start_char, + start_char - line_counts[predicate.location.end_line - 1].start, ending.location.start_char, ending.location.start_column ) - Unless.new( + UnlessNode.new( predicate: predicate, statements: statements, consequent: consequent, @@ -3301,40 +3849,44 @@ def on_unless(predicate, statements, consequent) end # :call-seq: - # on_unless_mod: (untyped predicate, untyped statement) -> UnlessMod + # on_unless_mod: (untyped predicate, untyped statement) -> UnlessNode def on_unless_mod(predicate, statement) - find_token(Kw, "unless") + consume_keyword(:unless) - UnlessMod.new( - statement: statement, + UnlessNode.new( predicate: predicate, + statements: + Statements.new(body: [statement], location: statement.location), + consequent: nil, location: statement.location.to(predicate.location) ) end # :call-seq: - # on_until: (untyped predicate, Statements statements) -> Until + # on_until: (untyped predicate, Statements statements) -> UntilNode def on_until(predicate, statements) - beginning = find_token(Kw, "until") - ending = find_token(Kw, "end") - - # Consume the do keyword if it exists so that it doesn't get confused for - # some other block - keyword = find_token(Kw, "do", consume: false) - if keyword && keyword.location.start_char > predicate.location.end_char && - keyword.location.end_char < ending.location.start_char - tokens.delete(keyword) - end + beginning = consume_keyword(:until) + ending = consume_keyword(:end) + + delimiter = + find_keyword_between(:do, predicate, statements) || + find_token_between(Semicolon, predicate, statements) + + tokens.delete(delimiter) if delimiter # Update the Statements location information + start_char = + find_next_statement_start((delimiter || predicate).location.end_char) + statements.bind( - predicate.location.end_char, - predicate.location.end_column, + self, + start_char, + start_char - line_counts[predicate.location.end_line - 1].start, ending.location.start_char, ending.location.start_column ) - Until.new( + UntilNode.new( predicate: predicate, statements: statements, location: beginning.location.to(ending.location) @@ -3342,23 +3894,24 @@ def on_until(predicate, statements) end # :call-seq: - # on_until_mod: (untyped predicate, untyped statement) -> UntilMod + # on_until_mod: (untyped predicate, untyped statement) -> UntilNode def on_until_mod(predicate, statement) - find_token(Kw, "until") + consume_keyword(:until) - UntilMod.new( - statement: statement, + UntilNode.new( predicate: predicate, + statements: + Statements.new(body: [statement], location: statement.location), location: statement.location.to(predicate.location) ) end # :call-seq: - # on_var_alias: (GVar left, (Backref | GVar) right) -> VarAlias + # on_var_alias: (GVar left, (Backref | GVar) right) -> AliasNode def on_var_alias(left, right) - keyword = find_token(Kw, "alias") + keyword = consume_keyword(:alias) - VarAlias.new( + AliasNode.new( left: left, right: right, location: keyword.location.to(right.location) @@ -3385,17 +3938,7 @@ def on_var_field(value) # :call-seq: # on_var_ref: ((Const | CVar | GVar | Ident | IVar | Kw) value) -> VarRef def on_var_ref(value) - pin = find_token(Op, "^", consume: false) - - if pin && pin.location.start_char == value.location.start_char - 1 - tokens.delete(pin) - PinnedVarRef.new( - value: value, - location: pin.location.to(value.location) - ) - else - VarRef.new(value: value, location: value.location) - end + VarRef.new(value: value, location: value.location) end # :call-seq: @@ -3420,18 +3963,20 @@ def on_void_stmt # (nil | Else | When) consequent # ) -> When def on_when(arguments, statements, consequent) - beginning = find_token(Kw, "when") - ending = consequent || find_token(Kw, "end") + beginning = consume_keyword(:when) + ending = consequent || consume_keyword(:end) statements_start = arguments - if (token = find_token(Kw, "then", consume: false)) + if (token = find_keyword(:then)) tokens.delete(token) statements_start = token end - start_char = find_next_statement_start(statements_start.location.end_char) + start_char = + find_next_statement_start((token || statements_start).location.end_char) statements.bind( + self, start_char, start_char - line_counts[statements_start.location.start_line - 1].start, @@ -3448,28 +3993,30 @@ def on_when(arguments, statements, consequent) end # :call-seq: - # on_while: (untyped predicate, Statements statements) -> While + # on_while: (untyped predicate, Statements statements) -> WhileNode def on_while(predicate, statements) - beginning = find_token(Kw, "while") - ending = find_token(Kw, "end") - - # Consume the do keyword if it exists so that it doesn't get confused for - # some other block - keyword = find_token(Kw, "do", consume: false) - if keyword && keyword.location.start_char > predicate.location.end_char && - keyword.location.end_char < ending.location.start_char - tokens.delete(keyword) - end + beginning = consume_keyword(:while) + ending = consume_keyword(:end) + + delimiter = + find_keyword_between(:do, predicate, statements) || + find_token_between(Semicolon, predicate, statements) + + tokens.delete(delimiter) if delimiter # Update the Statements location information + start_char = + find_next_statement_start((delimiter || predicate).location.end_char) + statements.bind( - predicate.location.end_char, - predicate.location.end_column, + self, + start_char, + start_char - line_counts[predicate.location.end_line - 1].start, ending.location.start_char, ending.location.start_column ) - While.new( + WhileNode.new( predicate: predicate, statements: statements, location: beginning.location.to(ending.location) @@ -3477,13 +4024,14 @@ def on_while(predicate, statements) end # :call-seq: - # on_while_mod: (untyped predicate, untyped statement) -> WhileMod + # on_while_mod: (untyped predicate, untyped statement) -> WhileNode def on_while_mod(predicate, statement) - find_token(Kw, "while") + consume_keyword(:while) - WhileMod.new( - statement: statement, + WhileNode.new( predicate: predicate, + statements: + Statements.new(body: [statement], location: statement.location), location: statement.location.to(predicate.location) ) end @@ -3542,7 +4090,7 @@ def on_words_beg(value) # :call-seq: # on_words_new: () -> Words def on_words_new - beginning = find_token(WordsBeg) + beginning = consume_token(WordsBeg) Words.new( beginning: beginning, @@ -3576,7 +4124,7 @@ def on_xstring_new if heredoc && heredoc.beginning.value.include?("`") heredoc.location else - find_token(Backtick).location + consume_token(Backtick).location end XString.new(parts: [], location: location) @@ -3596,7 +4144,7 @@ def on_xstring_literal(xstring) location: heredoc.location ) else - ending = find_token(TStringEnd, location: xstring.location) + ending = consume_tstring_end(xstring.location) XStringLiteral.new( parts: xstring.parts, @@ -3606,30 +4154,30 @@ def on_xstring_literal(xstring) end # :call-seq: - # on_yield: ((Args | Paren) arguments) -> Yield + # on_yield: ((Args | Paren) arguments) -> YieldNode def on_yield(arguments) - keyword = find_token(Kw, "yield") + keyword = consume_keyword(:yield) - Yield.new( + YieldNode.new( arguments: arguments, location: keyword.location.to(arguments.location) ) end # :call-seq: - # on_yield0: () -> Yield0 + # on_yield0: () -> YieldNode def on_yield0 - keyword = find_token(Kw, "yield") + keyword = consume_keyword(:yield) - Yield0.new(value: keyword.value, location: keyword.location) + YieldNode.new(arguments: nil, location: keyword.location) end # :call-seq: # on_zsuper: () -> ZSuper def on_zsuper - keyword = find_token(Kw, "super") + keyword = consume_keyword(:super) - ZSuper.new(value: keyword.value, location: keyword.location) + ZSuper.new(location: keyword.location) end end end diff --git a/lib/syntax_tree/pattern.rb b/lib/syntax_tree/pattern.rb new file mode 100644 index 00000000..a5e88bfa --- /dev/null +++ b/lib/syntax_tree/pattern.rb @@ -0,0 +1,288 @@ +# frozen_string_literal: true + +module SyntaxTree + # A pattern is an object that wraps a Ruby pattern matching expression. The + # expression would normally be passed to an `in` clause within a `case` + # expression or a rightward assignment expression. For example, in the + # following snippet: + # + # case node + # in Const[value: "SyntaxTree"] + # end + # + # the pattern is the `Const[value: "SyntaxTree"]` expression. Within Syntax + # Tree, every node generates these kinds of expressions using the + # #construct_keys method. + # + # The pattern gets compiled into an object that responds to call by running + # the #compile method. This method itself will run back through Syntax Tree to + # parse the expression into a tree, then walk the tree to generate the + # necessary callable objects. For example, if you wanted to compile the + # expression above into a callable, you would: + # + # callable = SyntaxTree::Pattern.new("Const[value: 'SyntaxTree']").compile + # callable.call(node) + # + # The callable object returned by #compile is guaranteed to respond to #call + # with a single argument, which is the node to match against. It also is + # guaranteed to respond to #===, which means it itself can be used in a `case` + # expression, as in: + # + # case node + # when callable + # end + # + # If the query given to the initializer cannot be compiled into a valid + # matcher (either because of a syntax error or because it is using syntax we + # do not yet support) then a SyntaxTree::Pattern::CompilationError will be + # raised. + class Pattern + # Raised when the query given to a pattern is either invalid Ruby syntax or + # is using syntax that we don't yet support. + class CompilationError < StandardError + def initialize(repr) + super(<<~ERROR) + Syntax Tree was unable to compile the pattern you provided to search + into a usable expression. It failed on to understand the node + represented by: + + #{repr} + + Note that not all syntax supported by Ruby's pattern matching syntax + is also supported by Syntax Tree's code search. If you're using some + syntax that you believe should be supported, please open an issue on + GitHub at https://github.com/ruby-syntax-tree/syntax_tree/issues/new. + ERROR + end + end + + attr_reader :query + + def initialize(query) + @query = query + end + + def compile + program = + begin + SyntaxTree.parse("case nil\nin #{query}\nend") + rescue Parser::ParseError + raise CompilationError, query + end + + raise CompilationError, query if program.nil? + compile_node(program.statements.body.first.consequent.pattern) + end + + private + + # Shortcut for combining two procs into one that returns true if both return + # true. + def combine_and(left, right) + ->(other) { left.call(other) && right.call(other) } + end + + # Shortcut for combining two procs into one that returns true if either + # returns true. + def combine_or(left, right) + ->(other) { left.call(other) || right.call(other) } + end + + # Raise an error because the given node is not supported. + def compile_error(node) + raise CompilationError, PP.pp(node, +"").chomp + end + + # There are a couple of nodes (string literals, dynamic symbols, and regexp) + # that contain list of parts. This can include plain string content, + # interpolated expressions, and interpolated variables. We only support + # plain string content, so this method will extract out the plain string + # content if it is the only element in the list. + def extract_string(node) + parts = node.parts + + if parts.length == 1 && (part = parts.first) && part.is_a?(TStringContent) + part.value + end + end + + # in [foo, bar, baz] + def compile_aryptn(node) + compile_error(node) if !node.rest.nil? || node.posts.any? + + constant = node.constant + compiled_constant = compile_node(constant) if constant + + preprocessed = node.requireds.map { |required| compile_node(required) } + + compiled_requireds = ->(other) do + deconstructed = other.deconstruct + + deconstructed.length == preprocessed.length && + preprocessed + .zip(deconstructed) + .all? { |(matcher, value)| matcher.call(value) } + end + + if compiled_constant + combine_and(compiled_constant, compiled_requireds) + else + compiled_requireds + end + end + + # in foo | bar + def compile_binary(node) + compile_error(node) if node.operator != :| + + combine_or(compile_node(node.left), compile_node(node.right)) + end + + # in Ident + # in String + def compile_const(node) + value = node.value + + if SyntaxTree.const_defined?(value, false) + clazz = SyntaxTree.const_get(value) + + ->(other) { clazz === other } + elsif Object.const_defined?(value, false) + clazz = Object.const_get(value) + + ->(other) { clazz === other } + else + compile_error(node) + end + end + + # in SyntaxTree::Ident + def compile_const_path_ref(node) + parent = node.parent + compile_error(node) if !parent.is_a?(VarRef) || !parent.value.is_a?(Const) + + if parent.value.value == "SyntaxTree" + compile_node(node.constant) + else + compile_error(node) + end + end + + # in :"" + # in :"foo" + def compile_dyna_symbol(node) + if node.parts.empty? + symbol = :"" + + ->(other) { symbol === other } + elsif (value = extract_string(node)) + symbol = value.to_sym + + ->(other) { symbol === other } + else + compile_error(node) + end + end + + # in Ident[value: String] + # in { value: String } + def compile_hshptn(node) + compile_error(node) unless node.keyword_rest.nil? + compiled_constant = compile_node(node.constant) if node.constant + + preprocessed = + node.keywords.to_h do |keyword, value| + compile_error(node) unless keyword.is_a?(Label) + [keyword.value.chomp(":").to_sym, compile_node(value)] + end + + compiled_keywords = ->(other) do + deconstructed = other.deconstruct_keys(preprocessed.keys) + + preprocessed.all? do |keyword, matcher| + matcher.call(deconstructed[keyword]) + end + end + + if compiled_constant + combine_and(compiled_constant, compiled_keywords) + else + compiled_keywords + end + end + + # in /foo/ + def compile_regexp_literal(node) + if (value = extract_string(node)) + regexp = /#{value}/ + + ->(attribute) { regexp === attribute } + else + compile_error(node) + end + end + + # in "" + # in "foo" + def compile_string_literal(node) + if node.parts.empty? + ->(attribute) { "" === attribute } + elsif (value = extract_string(node)) + ->(attribute) { value === attribute } + else + compile_error(node) + end + end + + # in :+ + # in :foo + def compile_symbol_literal(node) + symbol = node.value.value.to_sym + + ->(attribute) { symbol === attribute } + end + + # in Foo + # in nil + def compile_var_ref(node) + value = node.value + + if value.is_a?(Const) + compile_node(value) + elsif value.is_a?(Kw) && value.value.nil? + ->(attribute) { nil === attribute } + else + compile_error(node) + end + end + + # Compile any kind of node. Dispatch out to the individual compilation + # methods based on the type of node. + def compile_node(node) + case node + when AryPtn + compile_aryptn(node) + when Binary + compile_binary(node) + when Const + compile_const(node) + when ConstPathRef + compile_const_path_ref(node) + when DynaSymbol + compile_dyna_symbol(node) + when HshPtn + compile_hshptn(node) + when RegexpLiteral + compile_regexp_literal(node) + when StringLiteral + compile_string_literal(node) + when SymbolLiteral + compile_symbol_literal(node) + when VarRef + compile_var_ref(node) + else + compile_error(node) + end + end + end +end diff --git a/lib/syntax_tree/plugin/disable_auto_ternary.rb b/lib/syntax_tree/plugin/disable_auto_ternary.rb new file mode 100644 index 00000000..dd38c783 --- /dev/null +++ b/lib/syntax_tree/plugin/disable_auto_ternary.rb @@ -0,0 +1,7 @@ +# frozen_string_literal: true + +module SyntaxTree + class Formatter + DISABLE_AUTO_TERNARY = true + end +end diff --git a/lib/syntax_tree/plugin/single_quotes.rb b/lib/syntax_tree/plugin/single_quotes.rb index d8034084..c7405e2c 100644 --- a/lib/syntax_tree/plugin/single_quotes.rb +++ b/lib/syntax_tree/plugin/single_quotes.rb @@ -1,4 +1,7 @@ # frozen_string_literal: true -require "syntax_tree/formatter/single_quotes" -SyntaxTree::Formatter.prepend(SyntaxTree::Formatter::SingleQuotes) +module SyntaxTree + class Formatter + SINGLE_QUOTES = true + end +end diff --git a/lib/syntax_tree/plugin/trailing_comma.rb b/lib/syntax_tree/plugin/trailing_comma.rb new file mode 100644 index 00000000..1ae2b96d --- /dev/null +++ b/lib/syntax_tree/plugin/trailing_comma.rb @@ -0,0 +1,7 @@ +# frozen_string_literal: true + +module SyntaxTree + class Formatter + TRAILING_COMMA = true + end +end diff --git a/lib/syntax_tree/pretty_print_visitor.rb b/lib/syntax_tree/pretty_print_visitor.rb new file mode 100644 index 00000000..894e0cf4 --- /dev/null +++ b/lib/syntax_tree/pretty_print_visitor.rb @@ -0,0 +1,83 @@ +# frozen_string_literal: true + +module SyntaxTree + # This visitor pretty-prints the AST into an equivalent s-expression. + class PrettyPrintVisitor < FieldVisitor + attr_reader :q + + def initialize(q) + @q = q + end + + # This is here because we need to make sure the operator is cast to a string + # before we print it out. + def visit_binary(node) + node(node, "binary") do + field("left", node.left) + text("operator", node.operator.to_s) + field("right", node.right) + comments(node) + end + end + + # This is here to make it a little nicer to look at labels since they + # typically have their : at the end of the value. + def visit_label(node) + node(node, "label") do + q.breakable + q.text(":") + q.text(node.value[0...-1]) + comments(node) + end + end + + private + + def comments(node) + return if node.comments.empty? + + q.breakable + q.group(2, "(", ")") do + q.seplist(node.comments) { |comment| q.pp(comment) } + end + end + + def field(_name, value) + q.breakable + q.pp(value) + end + + def list(_name, values) + q.breakable + q.group(2, "(", ")") { q.seplist(values) { |value| q.pp(value) } } + end + + def node(_node, type) + q.group(2, "(", ")") do + q.text(type) + yield + end + end + + def pairs(_name, values) + q.group(2, "(", ")") do + q.seplist(values) do |(key, value)| + q.pp(key) + + if value + q.text("=") + q.group(2) do + q.breakable("") + q.pp(value) + end + end + end + end + end + + def text(_name, value) + q.breakable + q.text(value) + end + end +end diff --git a/lib/syntax_tree/prettyprint.rb b/lib/syntax_tree/prettyprint.rb deleted file mode 100644 index 7fe64a56..00000000 --- a/lib/syntax_tree/prettyprint.rb +++ /dev/null @@ -1,1159 +0,0 @@ -# frozen_string_literal: true -# -# This class implements a pretty printing algorithm. It finds line breaks and -# nice indentations for grouped structure. -# -# By default, the class assumes that primitive elements are strings and each -# byte in the strings is a single column in width. But it can be used for other -# situations by giving suitable arguments for some methods: -# -# * newline object and space generation block for PrettyPrint.new -# * optional width argument for PrettyPrint#text -# * PrettyPrint#breakable -# -# There are several candidate uses: -# * text formatting using proportional fonts -# * multibyte characters which has columns different to number of bytes -# * non-string formatting -# -# == Usage -# -# To use this module, you will need to generate a tree of print nodes that -# represent indentation and newline behavior before it gets sent to the printer. -# Each node has different semantics, depending on the desired output. -# -# The most basic node is a Text node. This represents plain text content that -# cannot be broken up even if it doesn't fit on one line. You would create one -# of those with the text method, as in: -# -# PrettyPrint.format { |q| q.text('my content') } -# -# No matter what the desired output width is, the output for the snippet above -# will always be the same. -# -# If you want to allow the printer to break up the content on the space -# character when there isn't enough width for the full string on the same line, -# you can use the Breakable and Group nodes. For example: -# -# PrettyPrint.format do |q| -# q.group do -# q.text('my') -# q.breakable -# q.text('content') -# end -# end -# -# Now, if everything fits on one line (depending on the maximum width specified) -# then it will be the same output as the first example. If, however, there is -# not enough room on the line, then you will get two lines of output, one for -# the first string and one for the second. -# -# There are other nodes for the print tree as well, described in the -# documentation below. They control alignment, indentation, conditional -# formatting, and more. -# -# == Bugs -# * Box based formatting? -# -# Report any bugs at http://bugs.ruby-lang.org -# -# == References -# Christian Lindig, Strictly Pretty, March 2000, -# https://lindig.github.io/papers/strictly-pretty-2000.pdf -# -# Philip Wadler, A prettier printer, March 1998, -# https://homepages.inf.ed.ac.uk/wadler/papers/prettier/prettier.pdf -# -# == Author -# Tanaka Akira -# -class PrettyPrint - # A node in the print tree that represents aligning nested nodes to a certain - # prefix width or string. - class Align - attr_reader :indent, :contents - - def initialize(indent:, contents: []) - @indent = indent - @contents = contents - end - - def pretty_print(q) - q.group(2, "align#{indent}([", "])") do - q.seplist(contents) { |content| q.pp(content) } - end - end - end - - # A node in the print tree that represents a place in the buffer that the - # content can be broken onto multiple lines. - class Breakable - attr_reader :separator, :width - - def initialize( - separator = " ", - width = separator.length, - force: false, - indent: true - ) - @separator = separator - @width = width - @force = force - @indent = indent - end - - def force? - @force - end - - def indent? - @indent - end - - def pretty_print(q) - q.text("breakable") - - attributes = [ - ("force=true" if force?), - ("indent=false" unless indent?) - ].compact - - if attributes.any? - q.text("(") - q.seplist(attributes, -> { q.text(", ") }) do |attribute| - q.text(attribute) - end - q.text(")") - end - end - end - - # A node in the print tree that forces the surrounding group to print out in - # the "break" mode as opposed to the "flat" mode. Useful for when you need to - # force a newline into a group. - class BreakParent - def pretty_print(q) - q.text("break-parent") - end - end - - # A node in the print tree that represents a group of items which the printer - # should try to fit onto one line. This is the basic command to tell the - # printer when to break. Groups are usually nested, and the printer will try - # to fit everything on one line, but if it doesn't fit it will break the - # outermost group first and try again. It will continue breaking groups until - # everything fits (or there are no more groups to break). - class Group - attr_reader :depth, :contents - - def initialize(depth, contents: []) - @depth = depth - @contents = contents - @break = false - end - - def break - @break = true - end - - def break? - @break - end - - def pretty_print(q) - q.group(2, break? ? "breakGroup([" : "group([", "])") do - q.seplist(contents) { |content| q.pp(content) } - end - end - end - - # A node in the print tree that represents printing one thing if the - # surrounding group node is broken and another thing if the surrounding group - # node is flat. - class IfBreak - attr_reader :break_contents, :flat_contents - - def initialize(break_contents: [], flat_contents: []) - @break_contents = break_contents - @flat_contents = flat_contents - end - - def pretty_print(q) - q.group(2, "if-break(", ")") do - q.breakable("") - q.group(2, "[", "],") do - q.seplist(break_contents) { |content| q.pp(content) } - end - q.breakable - q.group(2, "[", "]") do - q.seplist(flat_contents) { |content| q.pp(content) } - end - end - end - end - - # A node in the print tree that is a variant of the Align node that indents - # its contents by one level. - class Indent - attr_reader :contents - - def initialize(contents: []) - @contents = contents - end - - def pretty_print(q) - q.group(2, "indent([", "])") do - q.seplist(contents) { |content| q.pp(content) } - end - end - end - - # A node in the print tree that has its own special buffer for implementing - # content that should flush before any newline. - # - # Useful for implementating trailing content, as it's not always practical to - # constantly check where the line ends to avoid accidentally printing some - # content after a line suffix node. - class LineSuffix - DEFAULT_PRIORITY = 1 - - attr_reader :priority, :contents - - def initialize(priority: DEFAULT_PRIORITY, contents: []) - @priority = priority - @contents = contents - end - - def pretty_print(q) - q.group(2, "line-suffix([", "])") do - q.seplist(contents) { |content| q.pp(content) } - end - end - end - - # A node in the print tree that represents plain content that cannot be broken - # up (by default this assumes strings, but it can really be anything). - class Text - attr_reader :objects, :width - - def initialize - @objects = [] - @width = 0 - end - - def add(object: "", width: object.length) - @objects << object - @width += width - end - - def pretty_print(q) - q.group(2, "text([", "])") do - q.seplist(objects) { |object| q.pp(object) } - end - end - end - - # A node in the print tree that represents trimming all of the indentation of - # the current line, in the rare case that you need to ignore the indentation - # that you've already created. This node should be placed after a Breakable. - class Trim - def pretty_print(q) - q.text("trim") - end - end - - # When building up the contents in the output buffer, it's convenient to be - # able to trim trailing whitespace before newlines. If the output object is a - # string or array or strings, then we can do this with some gsub calls. If - # not, then this effectively just wraps the output object and forwards on - # calls to <<. - module Buffer - # This is the default output buffer that provides a base implementation of - # trim! that does nothing. It's effectively a wrapper around whatever output - # object was given to the format command. - class DefaultBuffer - attr_reader :output - - def initialize(output = []) - @output = output - end - - def <<(object) - @output << object - end - - def trim! - 0 - end - end - - # This is an output buffer that wraps a string output object. It provides a - # trim! method that trims off trailing whitespace from the string using - # gsub!. - class StringBuffer < DefaultBuffer - def initialize(output = "".dup) - super(output) - end - - def trim! - length = output.length - output.gsub!(/[\t ]*\z/, "") - length - output.length - end - end - - # This is an output buffer that wraps an array output object. It provides a - # trim! method that trims off trailing whitespace from the last element in - # the array if it's an unfrozen string using the same method as the - # StringBuffer. - class ArrayBuffer < DefaultBuffer - def initialize(output = []) - super(output) - end - - def trim! - return 0 if output.empty? - - trimmed = 0 - - while output.any? && output.last.is_a?(String) && - output.last.match?(/\A[\t ]*\z/) - trimmed += output.pop.length - end - - if output.any? && output.last.is_a?(String) && !output.last.frozen? - length = output.last.length - output.last.gsub!(/[\t ]*\z/, "") - trimmed += length - output.last.length - end - - trimmed - end - end - - # This is a switch for building the correct output buffer wrapper class for - # the given output object. - def self.for(output) - case output - when String - StringBuffer.new(output) - when Array - ArrayBuffer.new(output) - else - DefaultBuffer.new(output) - end - end - end - - # PrettyPrint::SingleLine is used by PrettyPrint.singleline_format - # - # It is passed to be similar to a PrettyPrint object itself, by responding to - # all of the same print tree node builder methods, as well as the #flush - # method. - # - # The significant difference here is that there are no line breaks in the - # output. If an IfBreak node is used, only the flat contents are printed. - # LineSuffix nodes are printed at the end of the buffer when #flush is called. - class SingleLine - # The output object. It stores rendered text and should respond to <<. - attr_reader :output - - # The current array of contents that the print tree builder methods should - # append to. - attr_reader :target - - # A buffer output that wraps any calls to line_suffix that will be flushed - # at the end of printing. - attr_reader :line_suffixes - - # Create a PrettyPrint::SingleLine object - # - # Arguments: - # * +output+ - String (or similar) to store rendered text. Needs to respond - # to '<<'. - # * +maxwidth+ - Argument position expected to be here for compatibility. - # This argument is a noop. - # * +newline+ - Argument position expected to be here for compatibility. - # This argument is a noop. - def initialize(output, _maxwidth = nil, _newline = nil) - @output = Buffer.for(output) - @target = @output - @line_suffixes = Buffer::ArrayBuffer.new - end - - # Flushes the line suffixes onto the output buffer. - def flush - line_suffixes.output.each { |doc| output << doc } - end - - # -------------------------------------------------------------------------- - # Markers node builders - # -------------------------------------------------------------------------- - - # Appends +separator+ to the text to be output. By default +separator+ is - # ' ' - # - # The +width+, +indent+, and +force+ arguments are here for compatibility. - # They are all noop arguments. - def breakable( - separator = " ", - _width = separator.length, - indent: nil, - force: nil - ) - target << separator - end - - # Here for compatibility, does nothing. - def break_parent - end - - # Appends +separator+ to the output buffer. +width+ is a noop here for - # compatibility. - def fill_breakable(separator = " ", _width = separator.length) - target << separator - end - - # Immediately trims the output buffer. - def trim - target.trim! - end - - # -------------------------------------------------------------------------- - # Container node builders - # -------------------------------------------------------------------------- - - # Opens a block for grouping objects to be pretty printed. - # - # Arguments: - # * +indent+ - noop argument. Present for compatibility. - # * +open_obj+ - text appended before the &block. Default is '' - # * +close_obj+ - text appended after the &block. Default is '' - # * +open_width+ - noop argument. Present for compatibility. - # * +close_width+ - noop argument. Present for compatibility. - def group( - _indent = nil, - open_object = "", - close_object = "", - _open_width = nil, - _close_width = nil - ) - target << open_object - yield - target << close_object - end - - # A class that wraps the ability to call #if_flat. The contents of the - # #if_flat block are executed immediately, so effectively this class and the - # #if_break method that triggers it are unnecessary, but they're here to - # maintain compatibility. - class IfBreakBuilder - def if_flat - yield - end - end - - # Effectively unnecessary, but here for compatibility. - def if_break - IfBreakBuilder.new - end - - # Also effectively unnecessary, but here for compatibility. - def if_flat - end - - # A noop that immediately yields. - def indent - yield - end - - # Changes the target output buffer to the line suffix output buffer which - # will get flushed at the end of printing. - def line_suffix - previous_target, @target = @target, line_suffixes - yield - @target = previous_target - end - - # Takes +indent+ arg, but does nothing with it. - # - # Yields to a block. - def nest(_indent) - yield - end - - # Add +object+ to the text to be output. - # - # +width+ argument is here for compatibility. It is a noop argument. - def text(object = "", _width = nil) - target << object - end - end - - # This object represents the current level of indentation within the printer. - # It has the ability to generate new levels of indentation through the #align - # and #indent methods. - class IndentLevel - IndentPart = Object.new - DedentPart = Object.new - - StringAlignPart = Struct.new(:n) - NumberAlignPart = Struct.new(:n) - - attr_reader :genspace, :value, :length, :queue, :root - - def initialize( - genspace:, - value: genspace.call(0), - length: 0, - queue: [], - root: nil - ) - @genspace = genspace - @value = value - @length = length - @queue = queue - @root = root - end - - # This can accept a whole lot of different kinds of objects, due to the - # nature of the flexibility of the Align node. - def align(n) - case n - when NilClass - self - when String - indent(StringAlignPart.new(n)) - else - indent(n < 0 ? DedentPart : NumberAlignPart.new(n)) - end - end - - def indent(part = IndentPart) - next_value = genspace.call(0) - next_length = 0 - next_queue = (part == DedentPart ? queue[0...-1] : [*queue, part]) - - last_spaces = 0 - - add_spaces = ->(count) do - next_value << genspace.call(count) - next_length += count - end - - flush_spaces = -> do - add_spaces[last_spaces] if last_spaces > 0 - last_spaces = 0 - end - - next_queue.each do |next_part| - case next_part - when IndentPart - flush_spaces.call - add_spaces.call(2) - when StringAlignPart - flush_spaces.call - next_value += next_part.n - next_length += next_part.n.length - when NumberAlignPart - last_spaces += next_part.n - end - end - - flush_spaces.call - - IndentLevel.new( - genspace: genspace, - value: next_value, - length: next_length, - queue: next_queue, - root: root - ) - end - end - - # When printing, you can optionally specify the value that should be used - # whenever a group needs to be broken onto multiple lines. In this case the - # default is \n. - DEFAULT_NEWLINE = "\n" - - # When generating spaces after a newline for indentation, by default we - # generate one space per character needed for indentation. You can change this - # behavior (for instance to use tabs) by passing a different genspace - # procedure. - DEFAULT_GENSPACE = ->(n) { " " * n } - - # There are two modes in printing, break and flat. When we're in break mode, - # any lines will use their newline, any if-breaks will use their break - # contents, etc. - MODE_BREAK = 1 - - # This is another print mode much like MODE_BREAK. When we're in flat mode, we - # attempt to print everything on one line until we either hit a broken group, - # a forced line, or the maximum width. - MODE_FLAT = 2 - - # This is a convenience method which is same as follows: - # - # begin - # q = PrettyPrint.new(output, maxwidth, newline, &genspace) - # ... - # q.flush - # output - # end - # - def self.format( - output = "".dup, - maxwidth = 80, - newline = DEFAULT_NEWLINE, - genspace = DEFAULT_GENSPACE - ) - q = new(output, maxwidth, newline, &genspace) - yield q - q.flush - output - end - - # This is similar to PrettyPrint::format but the result has no breaks. - # - # +maxwidth+, +newline+ and +genspace+ are ignored. - # - # The invocation of +breakable+ in the block doesn't break a line and is - # treated as just an invocation of +text+. - # - def self.singleline_format( - output = "".dup, - _maxwidth = nil, - _newline = nil, - _genspace = nil - ) - q = SingleLine.new(output) - yield q - output - end - - # The output object. It represents the final destination of the contents of - # the print tree. It should respond to <<. - # - # This defaults to "".dup - attr_reader :output - - # This is an output buffer that wraps the output object and provides - # additional functionality depending on its type. - # - # This defaults to Buffer::StringBuffer.new("".dup) - attr_reader :buffer - - # The maximum width of a line, before it is separated in to a newline - # - # This defaults to 80, and should be an Integer - attr_reader :maxwidth - - # The value that is appended to +output+ to add a new line. - # - # This defaults to "\n", and should be String - attr_reader :newline - - # An object that responds to call that takes one argument, of an Integer, and - # returns the corresponding number of spaces. - # - # By default this is: ->(n) { ' ' * n } - attr_reader :genspace - - # The stack of groups that are being printed. - attr_reader :groups - - # The current array of contents that calls to methods that generate print tree - # nodes will append to. - attr_reader :target - - # Creates a buffer for pretty printing. - # - # +output+ is an output target. If it is not specified, '' is assumed. It - # should have a << method which accepts the first argument +obj+ of - # PrettyPrint#text, the first argument +separator+ of PrettyPrint#breakable, - # the first argument +newline+ of PrettyPrint.new, and the result of a given - # block for PrettyPrint.new. - # - # +maxwidth+ specifies maximum line length. If it is not specified, 80 is - # assumed. However actual outputs may overflow +maxwidth+ if long - # non-breakable texts are provided. - # - # +newline+ is used for line breaks. "\n" is used if it is not specified. - # - # The block is used to generate spaces. ->(n) { ' ' * n } is used if it is not - # given. - def initialize( - output = "".dup, - maxwidth = 80, - newline = DEFAULT_NEWLINE, - &genspace - ) - @output = output - @buffer = Buffer.for(output) - @maxwidth = maxwidth - @newline = newline - @genspace = genspace || DEFAULT_GENSPACE - reset - end - - # Returns the group most recently added to the stack. - # - # Contrived example: - # out = "" - # => "" - # q = PrettyPrint.new(out) - # => # - # q.group { - # q.text q.current_group.inspect - # q.text q.newline - # q.group(q.current_group.depth + 1) { - # q.text q.current_group.inspect - # q.text q.newline - # q.group(q.current_group.depth + 1) { - # q.text q.current_group.inspect - # q.text q.newline - # q.group(q.current_group.depth + 1) { - # q.text q.current_group.inspect - # q.text q.newline - # } - # } - # } - # } - # => 284 - # puts out - # # - # # - # # - # # - def current_group - groups.last - end - - # Flushes all of the generated print tree onto the output buffer, then clears - # the generated tree from memory. - def flush - # First, get the root group, since we placed one at the top to begin with. - doc = groups.first - - # This represents how far along the current line we are. It gets reset - # back to 0 when we encounter a newline. - position = 0 - - # This is our command stack. A command consists of a triplet of an - # indentation level, the mode (break or flat), and a doc node. - commands = [[IndentLevel.new(genspace: genspace), MODE_BREAK, doc]] - - # This is a small optimization boolean. It keeps track of whether or not - # when we hit a group node we should check if it fits on the same line. - should_remeasure = false - - # This is a separate command stack that includes the same kind of triplets - # as the commands variable. It is used to keep track of things that should - # go at the end of printed lines once the other doc nodes are accounted for. - # Typically this is used to implement comments. - line_suffixes = [] - - # This is a special sort used to order the line suffixes by both the - # priority set on the line suffix and the index it was in the original - # array. - line_suffix_sort = ->(line_suffix) do - [-line_suffix.last, -line_suffixes.index(line_suffix)] - end - - # This is a linear stack instead of a mutually recursive call defined on - # the individual doc nodes for efficiency. - while (indent, mode, doc = commands.pop) - case doc - when Text - doc.objects.each { |object| buffer << object } - position += doc.width - when Array - doc.reverse_each { |part| commands << [indent, mode, part] } - when Indent - commands << [indent.indent, mode, doc.contents] - when Align - commands << [indent.align(doc.indent), mode, doc.contents] - when Trim - position -= buffer.trim! - when Group - if mode == MODE_FLAT && !should_remeasure - commands << [ - indent, - doc.break? ? MODE_BREAK : MODE_FLAT, - doc.contents - ] - else - should_remeasure = false - next_cmd = [indent, MODE_FLAT, doc.contents] - commands << if !doc.break? && - fits?(next_cmd, commands, maxwidth - position) - next_cmd - else - [indent, MODE_BREAK, doc.contents] - end - end - when IfBreak - if mode == MODE_BREAK && doc.break_contents.any? - commands << [indent, mode, doc.break_contents] - elsif mode == MODE_FLAT && doc.flat_contents.any? - commands << [indent, mode, doc.flat_contents] - end - when LineSuffix - line_suffixes << [indent, mode, doc.contents, doc.priority] - when Breakable - if mode == MODE_FLAT - if doc.force? - # This line was forced into the output even if we were in flat mode, - # so we need to tell the next group that no matter what, it needs to - # remeasure because the previous measurement didn't accurately - # capture the entire expression (this is necessary for nested - # groups). - should_remeasure = true - else - buffer << doc.separator - position += doc.width - next - end - end - - # If there are any commands in the line suffix buffer, then we're going - # to flush them now, as we are about to add a newline. - if line_suffixes.any? - commands << [indent, mode, doc] - commands += line_suffixes.sort_by(&line_suffix_sort) - line_suffixes = [] - next - end - - if !doc.indent? - buffer << newline - - if indent.root - buffer << indent.root.value - position = indent.root.length - else - position = 0 - end - else - position -= buffer.trim! - buffer << newline - buffer << indent.value - position = indent.length - end - when BreakParent - # do nothing - else - # Special case where the user has defined some way to get an extra doc - # node that we don't explicitly support into the list. In this case - # we're going to assume it's 0-width and just append it to the output - # buffer. - # - # This is useful behavior for putting marker nodes into the list so that - # you can know how things are getting mapped before they get printed. - buffer << doc - end - - if commands.empty? && line_suffixes.any? - commands += line_suffixes.sort_by(&line_suffix_sort) - line_suffixes = [] - end - end - - # Reset the group stack and target array so that this pretty printer object - # can continue to be used before calling flush again if desired. - reset - end - - # ---------------------------------------------------------------------------- - # Markers node builders - # ---------------------------------------------------------------------------- - - # This says "you can break a line here if necessary", and a +width+\-column - # text +separator+ is inserted if a line is not broken at the point. - # - # If +separator+ is not specified, ' ' is used. - # - # If +width+ is not specified, +separator.length+ is used. You will have to - # specify this when +separator+ is a multibyte character, for example. - # - # By default, if the surrounding group is broken and a newline is inserted, - # the printer will indent the subsequent line up to the current level of - # indentation. You can disable this behavior with the +indent+ argument if - # that's not desired (rare). - # - # By default, when you insert a Breakable into the print tree, it only breaks - # the surrounding group when the group's contents cannot fit onto the - # remaining space of the current line. You can force it to break the - # surrounding group instead if you always want the newline with the +force+ - # argument. - def breakable( - separator = " ", - width = separator.length, - indent: true, - force: false - ) - doc = Breakable.new(separator, width, indent: indent, force: force) - - target << doc - break_parent if force - - doc - end - - # This inserts a BreakParent node into the print tree which forces the - # surrounding and all parent group nodes to break. - def break_parent - doc = BreakParent.new - target << doc - - groups.reverse_each do |group| - break if group.break? - group.break - end - - doc - end - - # This is similar to #breakable except the decision to break or not is - # determined individually. - # - # Two #fill_breakable under a group may cause 4 results: - # (break,break), (break,non-break), (non-break,break), (non-break,non-break). - # This is different to #breakable because two #breakable under a group - # may cause 2 results: (break,break), (non-break,non-break). - # - # The text +separator+ is inserted if a line is not broken at this point. - # - # If +separator+ is not specified, ' ' is used. - # - # If +width+ is not specified, +separator.length+ is used. You will have to - # specify this when +separator+ is a multibyte character, for example. - def fill_breakable(separator = " ", width = separator.length) - group { breakable(separator, width) } - end - - # This inserts a Trim node into the print tree which, when printed, will clear - # all whitespace at the end of the output buffer. This is useful for the rare - # case where you need to delete printed indentation and force the next node - # to start at the beginning of the line. - def trim - doc = Trim.new - target << doc - - doc - end - - # ---------------------------------------------------------------------------- - # Container node builders - # ---------------------------------------------------------------------------- - - # Groups line break hints added in the block. The line break hints are all to - # be used or not. - # - # If +indent+ is specified, the method call is regarded as nested by - # nest(indent) { ... }. - # - # If +open_object+ is specified, text(open_object, open_width) is - # called before grouping. If +close_object+ is specified, - # text(close_object, close_width) is called after grouping. - def group( - indent = 0, - open_object = "", - close_object = "", - open_width = open_object.length, - close_width = close_object.length - ) - text(open_object, open_width) if open_object != "" - - doc = Group.new(groups.last.depth + 1) - groups << doc - target << doc - - with_target(doc.contents) do - if indent != 0 - nest(indent) { yield } - else - yield - end - end - - groups.pop - text(close_object, close_width) if close_object != "" - - doc - end - - # A small DSL-like object used for specifying the alternative contents to be - # printed if the surrounding group doesn't break for an IfBreak node. - class IfBreakBuilder - attr_reader :builder, :if_break - - def initialize(builder, if_break) - @builder = builder - @if_break = if_break - end - - def if_flat(&block) - builder.with_target(if_break.flat_contents, &block) - end - end - - # Inserts an IfBreak node with the contents of the block being added to its - # list of nodes that should be printed if the surrounding node breaks. If it - # doesn't, then you can specify the contents to be printed with the #if_flat - # method used on the return object from this method. For example, - # - # q.if_break { q.text('do') }.if_flat { q.text('{') } - # - # In the example above, if the surrounding group is broken it will print 'do' - # and if it is not it will print '{'. - def if_break - doc = IfBreak.new - target << doc - - with_target(doc.break_contents) { yield } - IfBreakBuilder.new(self, doc) - end - - # This is similar to if_break in that it also inserts an IfBreak node into the - # print tree, however it's starting from the flat contents, and cannot be used - # to build the break contents. - def if_flat - doc = IfBreak.new - target << doc - - with_target(doc.flat_contents) { yield } - end - - # Very similar to the #nest method, this indents the nested content by one - # level by inserting an Indent node into the print tree. The contents of the - # node are determined by the block. - def indent - doc = Indent.new - target << doc - - with_target(doc.contents) { yield } - doc - end - - # Inserts a LineSuffix node into the print tree. The contents of the node are - # determined by the block. - def line_suffix(priority: LineSuffix::DEFAULT_PRIORITY) - doc = LineSuffix.new(priority: priority) - target << doc - - with_target(doc.contents) { yield } - doc - end - - # Increases left margin after newline with +indent+ for line breaks added in - # the block. - def nest(indent) - doc = Align.new(indent: indent) - target << doc - - with_target(doc.contents) { yield } - doc - end - - # This adds +object+ as a text of +width+ columns in width. - # - # If +width+ is not specified, object.length is used. - def text(object = "", width = object.length) - doc = target.last - - unless doc.is_a?(Text) - doc = Text.new - target << doc - end - - doc.add(object: object, width: width) - doc - end - - # ---------------------------------------------------------------------------- - # Internal APIs - # ---------------------------------------------------------------------------- - - # A convenience method used by a lot of the print tree node builders that - # temporarily changes the target that the builders will append to. - def with_target(target) - previous_target, @target = @target, target - yield - @target = previous_target - end - - private - - # This method returns a boolean as to whether or not the remaining commands - # fit onto the remaining space on the current line. If we finish printing - # all of the commands or if we hit a newline, then we return true. Otherwise - # if we continue printing past the remaining space, we return false. - def fits?(next_command, rest_commands, remaining) - # This is the index in the remaining commands that we've handled so far. - # We reverse through the commands and add them to the stack if we've run - # out of nodes to handle. - rest_index = rest_commands.length - - # This is our stack of commands, very similar to the commands list in the - # print method. - commands = [next_command] - - # This is our output buffer, really only necessary to keep track of - # because we could encounter a Trim doc node that would actually add - # remaining space. - fit_buffer = buffer.class.new - - while remaining >= 0 - if commands.empty? - return true if rest_index == 0 - - rest_index -= 1 - commands << rest_commands[rest_index] - next - end - - indent, mode, doc = commands.pop - - case doc - when Text - doc.objects.each { |object| fit_buffer << object } - remaining -= doc.width - when Array - doc.reverse_each { |part| commands << [indent, mode, part] } - when Indent - commands << [indent.indent, mode, doc.contents] - when Align - commands << [indent.align(doc.indent), mode, doc.contents] - when Trim - remaining += fit_buffer.trim! - when Group - commands << [indent, doc.break? ? MODE_BREAK : mode, doc.contents] - when IfBreak - if mode == MODE_BREAK && doc.break_contents.any? - commands << [indent, mode, doc.break_contents] - elsif mode == MODE_FLAT && doc.flat_contents.any? - commands << [indent, mode, doc.flat_contents] - end - when Breakable - if mode == MODE_FLAT && !doc.force? - fit_buffer << doc.separator - remaining -= doc.width - next - end - - return true - end - end - - false - end - - # Resets the group stack and target array so that this pretty printer object - # can continue to be used before calling flush again if desired. - def reset - @groups = [Group.new(0)] - @target = @groups.last.contents - end -end diff --git a/lib/syntax_tree/rake/check_task.rb b/lib/syntax_tree/rake/check_task.rb new file mode 100644 index 00000000..5b441a5b --- /dev/null +++ b/lib/syntax_tree/rake/check_task.rb @@ -0,0 +1,29 @@ +# frozen_string_literal: true + +require_relative "task" + +module SyntaxTree + module Rake + # A Rake task that runs check on a set of source files. + # + # Example: + # + # require "syntax_tree/rake/check_task" + # + # SyntaxTree::Rake::CheckTask.new do |t| + # t.source_files = "{app,config,lib}/**/*.rb" + # end + # + # This will create task that can be run with: + # + # rake stree:check + # + class CheckTask < Task + private + + def command + "check" + end + end + end +end diff --git a/lib/syntax_tree/rake/task.rb b/lib/syntax_tree/rake/task.rb new file mode 100644 index 00000000..e9a20433 --- /dev/null +++ b/lib/syntax_tree/rake/task.rb @@ -0,0 +1,85 @@ +# frozen_string_literal: true + +require "rake" +require "rake/tasklib" + +require "syntax_tree" +require "syntax_tree/cli" + +module SyntaxTree + module Rake + # A parent Rake task that runs a command on a set of source files. + class Task < ::Rake::TaskLib + # Name of the task. + attr_accessor :name + + # Glob pattern to match source files. + # Defaults to 'lib/**/*.rb'. + attr_accessor :source_files + + # The set of plugins to require. + # Defaults to []. + attr_accessor :plugins + + # Max line length. + # Defaults to 80. + attr_accessor :print_width + + # The target Ruby version to use for formatting. + # Defaults to Gem::Version.new(RUBY_VERSION). + attr_accessor :target_ruby_version + + # Glob pattern to ignore source files. + # Defaults to ''. + attr_accessor :ignore_files + + def initialize( + name = :"stree:#{command}", + source_files = ::Rake::FileList["lib/**/*.rb"], + plugins = [], + print_width = DEFAULT_PRINT_WIDTH, + target_ruby_version = Gem::Version.new(RUBY_VERSION), + ignore_files = "" + ) + @name = name + @source_files = source_files + @plugins = plugins + @print_width = print_width + @target_ruby_version = target_ruby_version + @ignore_files = ignore_files + + yield self if block_given? + define_task + end + + private + + # This method needs to be overridden in the child tasks. + def command + raise NotImplementedError + end + + def define_task + desc "Runs `stree #{command}` over source files" + task(name) { run_task } + end + + def run_task + arguments = [command] + arguments << "--plugins=#{plugins.join(",")}" if plugins.any? + + if print_width != DEFAULT_PRINT_WIDTH + arguments << "--print-width=#{print_width}" + end + + if target_ruby_version != Gem::Version.new(RUBY_VERSION) + arguments << "--target-ruby-version=#{target_ruby_version}" + end + + arguments << "--ignore-files=#{ignore_files}" if ignore_files != "" + + abort if SyntaxTree::CLI.run(arguments + Array(source_files)) != 0 + end + end + end +end diff --git a/lib/syntax_tree/rake/write_task.rb b/lib/syntax_tree/rake/write_task.rb new file mode 100644 index 00000000..8037792e --- /dev/null +++ b/lib/syntax_tree/rake/write_task.rb @@ -0,0 +1,29 @@ +# frozen_string_literal: true + +require_relative "task" + +module SyntaxTree + module Rake + # A Rake task that runs write on a set of source files. + # + # Example: + # + # require "syntax_tree/rake/write_task" + # + # SyntaxTree::Rake::WriteTask.new do |t| + # t.source_files = "{app,config,lib}/**/*.rb" + # end + # + # This will create task that can be run with: + # + # rake stree:write + # + class WriteTask < Task + private + + def command + "write" + end + end + end +end diff --git a/lib/syntax_tree/rake_tasks.rb b/lib/syntax_tree/rake_tasks.rb new file mode 100644 index 00000000..b53743e5 --- /dev/null +++ b/lib/syntax_tree/rake_tasks.rb @@ -0,0 +1,4 @@ +# frozen_string_literal: true + +require_relative "rake/check_task" +require_relative "rake/write_task" diff --git a/lib/syntax_tree/reflection.rb b/lib/syntax_tree/reflection.rb new file mode 100644 index 00000000..6955aa21 --- /dev/null +++ b/lib/syntax_tree/reflection.rb @@ -0,0 +1,257 @@ +# frozen_string_literal: true + +module SyntaxTree + # This module is used to provide some reflection on the various types of nodes + # and their attributes. As soon as it is required it collects all of its + # information. + module Reflection + # This module represents the type of the values being passed to attributes + # of nodes. It is used as part of the documentation of the attributes. + module Type + CONSTANTS = SyntaxTree.constants.to_h { [_1, SyntaxTree.const_get(_1)] } + + # Represents an array type that holds another type. + class ArrayType + attr_reader :type + + def initialize(type) + @type = type + end + + def ===(value) + value.is_a?(Array) && value.all? { type === _1 } + end + + def inspect + "Array<#{type.inspect}>" + end + end + + # Represents a tuple type that holds a number of types in order. + class TupleType + attr_reader :types + + def initialize(types) + @types = types + end + + def ===(value) + value.is_a?(Array) && value.length == types.length && + value.zip(types).all? { |item, type| type === item } + end + + def inspect + "[#{types.map(&:inspect).join(", ")}]" + end + end + + # Represents a union type that can be one of a number of types. + class UnionType + attr_reader :types + + def initialize(types) + @types = types + end + + def ===(value) + types.any? { _1 === value } + end + + def inspect + types.map(&:inspect).join(" | ") + end + end + + class << self + def parse(comment) + comment = comment.gsub("\n", " ") + + unless comment.start_with?("[") + raise "Comment does not start with a bracket: #{comment.inspect}" + end + + count = 1 + found = + comment.chars[1..] + .find + .with_index(1) do |char, index| + count += { "[" => 1, "]" => -1 }.fetch(char, 0) + break index if count == 0 + end + + # If we weren't able to find the end of the balanced brackets, then + # the comment is malformed. + if found.nil? + raise "Comment does not have balanced brackets: #{comment.inspect}" + end + + parse_type(comment[1...found].strip) + end + + private + + def parse_type(value) + case value + when "Integer" + Integer + when "String" + String + when "Symbol" + Symbol + when "boolean" + UnionType.new([TrueClass, FalseClass]) + when "nil" + NilClass + when ":\"::\"" + :"::" + when ":call" + :call + when ":nil" + :nil + when /\AArray\[(.+)\]\z/ + ArrayType.new(parse_type($1.strip)) + when /\A\[(.+)\]\z/ + TupleType.new($1.strip.split(/\s*,\s*/).map { parse_type(_1) }) + else + if value.include?("|") + UnionType.new(value.split(/\s*\|\s*/).map { parse_type(_1) }) + else + CONSTANTS.fetch(value.to_sym) + end + end + end + end + end + + # This class represents one of the attributes on a node in the tree. + class Attribute + attr_reader :name, :comment, :type + + def initialize(name, comment) + @name = name + @comment = comment + @type = Type.parse(comment) + end + end + + # This class represents one of our nodes in the tree. We're going to use it + # as a placeholder for collecting all of the various places that nodes are + # used. + class Node + attr_reader :name, :comment, :attributes, :visitor_method + + def initialize(name, comment, attributes, visitor_method) + @name = name + @comment = comment + @attributes = attributes + @visitor_method = visitor_method + end + end + + class << self + # This is going to hold a hash of all of the nodes in the tree. The keys + # are the names of the nodes as symbols. + attr_reader :nodes + + # This expects a node name as a symbol and returns the node object for + # that node. + def node(name) + nodes.fetch(name) + end + + private + + def parse_comments(statements, index) + statements[0...index] + .reverse_each + .take_while { _1.is_a?(SyntaxTree::Comment) } + .reverse_each + .map { _1.value[2..] } + end + end + + @nodes = {} + + # For each node, we're going to parse out its attributes and other metadata. + # We'll use this as the basis for our report. + program = + SyntaxTree.parse(SyntaxTree.read(File.expand_path("node.rb", __dir__))) + + program_statements = program.statements + main_statements = program_statements.body.last.bodystmt.statements.body + main_statements.each_with_index do |main_statement, main_statement_index| + # Ensure we are only looking at class declarations. + next unless main_statement.is_a?(SyntaxTree::ClassDeclaration) + + # Ensure we're looking at class declarations with superclasses. + superclass = main_statement.superclass + next unless superclass.is_a?(SyntaxTree::VarRef) + + # Ensure we're looking at class declarations that inherit from Node. + next unless superclass.value.value == "Node" + + # All child nodes inherit the location attr_reader from Node, so we'll add + # that to the list of attributes first. + attributes = { + location: + Attribute.new(:location, "[Location] the location of this node") + } + + # This is the name of the method tha gets called on the given visitor when + # the accept method is called on this node. + visitor_method = nil + + statements = main_statement.bodystmt.statements.body + statements.each_with_index do |statement, statement_index| + case statement + when SyntaxTree::Command + # We only use commands in node classes to define attributes. So, we + # can safely assume that we're looking at an attribute definition. + unless %w[attr_reader attr_accessor].include?(statement.message.value) + raise "Unexpected command: #{statement.message.value.inspect}" + end + + # The arguments to the command are the attributes that we're defining. + # We want to ensure that we're only defining one at a time. + if statement.arguments.parts.length != 1 + raise "Declaring more than one attribute at a time is not permitted" + end + + attribute = + Attribute.new( + statement.arguments.parts.first.value.value.to_sym, + "#{parse_comments(statements, statement_index).join("\n")}\n" + ) + + # Ensure that we don't already have an attribute named the same as + # this one, and then add it to the list of attributes. + if attributes.key?(attribute.name) + raise "Duplicate attribute: #{attribute.name}" + end + + attributes[attribute.name] = attribute + when SyntaxTree::DefNode + if statement.name.value == "accept" + call_node = statement.bodystmt.statements.body.first + visitor_method = call_node.message.value.to_sym + end + end + end + + # If we never found a visitor method, then we have an error. + raise if visitor_method.nil? + + # Finally, set it up in the hash of nodes so that we can use it later. + comments = parse_comments(main_statements, main_statement_index) + node = + Node.new( + main_statement.constant.constant.value.to_sym, + "#{comments.join("\n")}\n", + attributes, + visitor_method + ) + + @nodes[node.name] = node + end + end +end diff --git a/lib/syntax_tree/search.rb b/lib/syntax_tree/search.rb new file mode 100644 index 00000000..9fd52ba1 --- /dev/null +++ b/lib/syntax_tree/search.rb @@ -0,0 +1,26 @@ +# frozen_string_literal: true + +module SyntaxTree + # Provides an interface for searching for a pattern of nodes against a + # subtree of an AST. + class Search + attr_reader :pattern + + def initialize(pattern) + @pattern = pattern + end + + def scan(root) + return to_enum(__method__, root) unless block_given? + queue = [root] + + until queue.empty? + node = queue.shift + next unless node + + yield node if pattern.call(node) + queue += node.child_nodes + end + end + end +end diff --git a/lib/syntax_tree/version.rb b/lib/syntax_tree/version.rb index 894ff1b7..9e80fa7b 100644 --- a/lib/syntax_tree/version.rb +++ b/lib/syntax_tree/version.rb @@ -1,5 +1,5 @@ # frozen_string_literal: true module SyntaxTree - VERSION = "2.4.0" + VERSION = "6.3.0" end diff --git a/lib/syntax_tree/visitor.rb b/lib/syntax_tree/visitor.rb index 57794ddb..eb57acd2 100644 --- a/lib/syntax_tree/visitor.rb +++ b/lib/syntax_tree/visitor.rb @@ -4,79 +4,14 @@ module SyntaxTree # Visitor is a parent class that provides the ability to walk down the tree # and handle a subset of nodes. By defining your own subclass, you can # explicitly handle a node type by defining a visit_* method. - class Visitor - # This is raised when you use the Visitor.visit_method method and it fails. - # It is correctable to through DidYouMean. - class VisitMethodError < StandardError - attr_reader :visit_method - - def initialize(visit_method) - @visit_method = visit_method - super("Invalid visit method: #{visit_method}") - end - end - - # This class is used by DidYouMean to offer corrections to invalid visit - # method names. - class VisitMethodChecker - attr_reader :visit_method - - def initialize(error) - @visit_method = error.visit_method - end - - def corrections - @corrections ||= - DidYouMean::SpellChecker.new( - dictionary: Visitor.visit_methods - ).correct(visit_method) - end - - DidYouMean.correct_error(VisitMethodError, self) - end - - class << self - # This method is here to help folks write visitors. - # - # It's not always easy to ensure you're writing the correct method name in - # the visitor since it's perfectly valid to define methods that don't - # override these parent methods. - # - # If you use this method, you can ensure you're writing the correct method - # name. It will raise an error if the visit method you're defining isn't - # actually a method on the parent visitor. - def visit_method(method_name) - return if visit_methods.include?(method_name) - - raise VisitMethodError, method_name - end - - # This is the list of all of the valid visit methods. - def visit_methods - @visit_methods ||= - Visitor.instance_methods.grep(/^visit_(?!child_nodes)/) - end - end - - def visit(node) - node&.accept(self) - end - - def visit_all(nodes) - nodes.map { |node| visit(node) } - end - - def visit_child_nodes(node) - visit_all(node.child_nodes) - end - + class Visitor < BasicVisitor # Visit an ARef node. alias visit_aref visit_child_nodes # Visit an ARefField node. alias visit_aref_field visit_child_nodes - # Visit an Alias node. + # Visit an AliasNode node. alias visit_alias visit_child_nodes # Visit an ArgBlock node. @@ -127,6 +62,9 @@ def visit_child_nodes(node) # Visit a Binary node. alias visit_binary visit_child_nodes + # Visit a Block node. + alias visit_block visit_child_nodes + # Visit a BlockArg node. alias visit_blockarg visit_child_nodes @@ -136,9 +74,6 @@ def visit_child_nodes(node) # Visit a BodyStmt node. alias visit_bodystmt visit_child_nodes - # Visit a BraceBlock node. - alias visit_brace_block visit_child_nodes - # Visit a Break node. alias visit_break visit_child_nodes @@ -184,24 +119,9 @@ def visit_child_nodes(node) # Visit a Def node. alias visit_def visit_child_nodes - # Visit a DefEndless node. - alias visit_def_endless visit_child_nodes - # Visit a Defined node. alias visit_defined visit_child_nodes - # Visit a Defs node. - alias visit_defs visit_child_nodes - - # Visit a DoBlock node. - alias visit_do_block visit_child_nodes - - # Visit a Dot2 node. - alias visit_dot2 visit_child_nodes - - # Visit a Dot3 node. - alias visit_dot3 visit_child_nodes - # Visit a DynaSymbol node. alias visit_dyna_symbol visit_child_nodes @@ -232,9 +152,6 @@ def visit_child_nodes(node) # Visit an ExcessedComma node. alias visit_excessed_comma visit_child_nodes - # Visit a FCall node. - alias visit_fcall visit_child_nodes - # Visit a Field node. alias visit_field visit_child_nodes @@ -259,18 +176,18 @@ def visit_child_nodes(node) # Visit a HeredocBeg node. alias visit_heredoc_beg visit_child_nodes + # Visit a HeredocEnd node. + alias visit_heredoc_end visit_child_nodes + # Visit a HshPtn node. alias visit_hshptn visit_child_nodes # Visit an Ident node. alias visit_ident visit_child_nodes - # Visit an If node. + # Visit an IfNode node. alias visit_if visit_child_nodes - # Visit an IfMod node. - alias visit_if_mod visit_child_nodes - # Visit an IfOp node. alias visit_if_op visit_child_nodes @@ -301,6 +218,9 @@ def visit_child_nodes(node) # Visit a Lambda node. alias visit_lambda visit_child_nodes + # Visit a LambdaVar node. + alias visit_lambda_var visit_child_nodes + # Visit a LBrace node. alias visit_lbrace visit_child_nodes @@ -370,6 +290,9 @@ def visit_child_nodes(node) # Visit a QWordsBeg node. alias visit_qwords_beg visit_child_nodes + # Visit a RangeNode node + alias visit_range visit_child_nodes + # Visit a RAssign node. alias visit_rassign visit_child_nodes @@ -415,9 +338,6 @@ def visit_child_nodes(node) # Visit a Return node. alias visit_return visit_child_nodes - # Visit a Return0 node. - alias visit_return0 visit_child_nodes - # Visit a RParen node. alias visit_rparen visit_child_nodes @@ -487,21 +407,12 @@ def visit_child_nodes(node) # Visit an Undef node. alias visit_undef visit_child_nodes - # Visit an Unless node. + # Visit an UnlessNode node. alias visit_unless visit_child_nodes - # Visit an UnlessMod node. - alias visit_unless_mod visit_child_nodes - - # Visit an Until node. + # Visit an UntilNode node. alias visit_until visit_child_nodes - # Visit an UntilMod node. - alias visit_until_mod visit_child_nodes - - # Visit a VarAlias node. - alias visit_var_alias visit_child_nodes - # Visit a VarField node. alias visit_var_field visit_child_nodes @@ -517,12 +428,9 @@ def visit_child_nodes(node) # Visit a When node. alias visit_when visit_child_nodes - # Visit a While node. + # Visit a WhileNode node. alias visit_while visit_child_nodes - # Visit a WhileMod node. - alias visit_while_mod visit_child_nodes - # Visit a Word node. alias visit_word visit_child_nodes @@ -538,12 +446,9 @@ def visit_child_nodes(node) # Visit a XStringLiteral node. alias visit_xstring_literal visit_child_nodes - # Visit a Yield node. + # Visit a YieldNode node. alias visit_yield visit_child_nodes - # Visit a Yield0 node. - alias visit_yield0 visit_child_nodes - # Visit a ZSuper node. alias visit_zsuper visit_child_nodes diff --git a/lib/syntax_tree/visitor/json_visitor.rb b/lib/syntax_tree/visitor/json_visitor.rb deleted file mode 100644 index b516980c..00000000 --- a/lib/syntax_tree/visitor/json_visitor.rb +++ /dev/null @@ -1,55 +0,0 @@ -# frozen_string_literal: true - -module SyntaxTree - class Visitor - # This visitor transforms the AST into a hash that contains only primitives - # that can be easily serialized into JSON. - class JSONVisitor < FieldVisitor - attr_reader :target - - def initialize - @target = nil - end - - private - - def comments(node) - target[:comments] = visit_all(node.comments) - end - - def field(name, value) - target[name] = value.is_a?(Node) ? visit(value) : value - end - - def list(name, values) - target[name] = visit_all(values) - end - - def node(node, type) - previous = @target - @target = { type: type, location: visit_location(node.location) } - yield - @target - ensure - @target = previous - end - - def pairs(name, values) - target[name] = values.map { |(key, value)| [visit(key), visit(value)] } - end - - def text(name, value) - target[name] = value - end - - def visit_location(location) - [ - location.start_line, - location.start_char, - location.end_line, - location.end_char - ] - end - end - end -end diff --git a/lib/syntax_tree/visitor/match_visitor.rb b/lib/syntax_tree/visitor/match_visitor.rb deleted file mode 100644 index 205f2b90..00000000 --- a/lib/syntax_tree/visitor/match_visitor.rb +++ /dev/null @@ -1,122 +0,0 @@ -# frozen_string_literal: true - -module SyntaxTree - class Visitor - # This visitor transforms the AST into a Ruby pattern matching expression - # that would match correctly against the AST. - class MatchVisitor < FieldVisitor - attr_reader :q - - def initialize(q) - @q = q - end - - def visit(node) - case node - when Node - super - when String - # pp will split up a string on newlines and concat them together using - # a "+" operator. This breaks the pattern matching expression. So - # instead we're going to check here for strings and manually put the - # entire value into the output buffer. - q.text(node.inspect) - else - q.pp(node) - end - end - - private - - def comments(node) - return if node.comments.empty? - - q.nest(0) do - q.text("comments: [") - q.indent do - q.breakable("") - q.seplist(node.comments) { |comment| visit(comment) } - end - q.breakable("") - q.text("]") - end - end - - def field(name, value) - q.nest(0) do - q.text(name) - q.text(": ") - visit(value) - end - end - - def list(name, values) - q.group do - q.text(name) - q.text(": [") - q.indent do - q.breakable("") - q.seplist(values) { |value| visit(value) } - end - q.breakable("") - q.text("]") - end - end - - def node(node, _type) - items = [] - q.with_target(items) { yield } - - if items.empty? - q.text(node.class.name) - return - end - - q.group do - q.text(node.class.name) - q.text("[") - q.indent do - q.breakable("") - q.seplist(items) { |item| q.target << item } - end - q.breakable("") - q.text("]") - end - end - - def pairs(name, values) - q.group do - q.text(name) - q.text(": [") - q.indent do - q.breakable("") - q.seplist(values) do |(key, value)| - q.group do - q.text("[") - q.indent do - q.breakable("") - visit(key) - q.text(",") - q.breakable - visit(value || nil) - end - q.breakable("") - q.text("]") - end - end - end - q.breakable("") - q.text("]") - end - end - - def text(name, value) - q.nest(0) do - q.text(name) - q.text(": ") - q.pp(value) - end - end - end - end -end diff --git a/lib/syntax_tree/visitor/pretty_print_visitor.rb b/lib/syntax_tree/visitor/pretty_print_visitor.rb deleted file mode 100644 index 674e3aac..00000000 --- a/lib/syntax_tree/visitor/pretty_print_visitor.rb +++ /dev/null @@ -1,85 +0,0 @@ -# frozen_string_literal: true - -module SyntaxTree - class Visitor - # This visitor pretty-prints the AST into an equivalent s-expression. - class PrettyPrintVisitor < FieldVisitor - attr_reader :q - - def initialize(q) - @q = q - end - - # This is here because we need to make sure the operator is cast to a - # string before we print it out. - def visit_binary(node) - node(node, "binary") do - field("left", node.left) - text("operator", node.operator.to_s) - field("right", node.right) - comments(node) - end - end - - # This is here to make it a little nicer to look at labels since they - # typically have their : at the end of the value. - def visit_label(node) - node(node, "label") do - q.breakable - q.text(":") - q.text(node.value[0...-1]) - comments(node) - end - end - - private - - def comments(node) - return if node.comments.empty? - - q.breakable - q.group(2, "(", ")") do - q.seplist(node.comments) { |comment| q.pp(comment) } - end - end - - def field(_name, value) - q.breakable - q.pp(value) - end - - def list(_name, values) - q.breakable - q.group(2, "(", ")") { q.seplist(values) { |value| q.pp(value) } } - end - - def node(_node, type) - q.group(2, "(", ")") do - q.text(type) - yield - end - end - - def pairs(_name, values) - q.group(2, "(", ")") do - q.seplist(values) do |(key, value)| - q.pp(key) - - if value - q.text("=") - q.group(2) do - q.breakable("") - q.pp(value) - end - end - end - end - end - - def text(_name, value) - q.breakable - q.text(value) - end - end - end -end diff --git a/lib/syntax_tree/with_scope.rb b/lib/syntax_tree/with_scope.rb new file mode 100644 index 00000000..8c4908f3 --- /dev/null +++ b/lib/syntax_tree/with_scope.rb @@ -0,0 +1,311 @@ +# frozen_string_literal: true + +module SyntaxTree + # WithScope is a module intended to be included in classes inheriting from + # Visitor. The module overrides a few visit methods to automatically keep + # track of local variables and arguments defined in the current scope. + # Example usage: + # + # class MyVisitor < Visitor + # include WithScope + # + # def visit_ident(node) + # # Check if we're visiting an identifier for an argument, a local + # # variable or something else + # local = current_scope.find_local(node) + # + # if local.type == :argument + # # handle identifiers for arguments + # elsif local.type == :variable + # # handle identifiers for variables + # else + # # handle other identifiers, such as method names + # end + # end + # end + # + module WithScope + # The scope class is used to keep track of local variables and arguments + # inside a particular scope. + class Scope + # This class tracks the occurrences of a local variable or argument. + class Local + # [Symbol] The type of the local (e.g. :argument, :variable) + attr_reader :type + + # [Array[Location]] The locations of all definitions and assignments of + # this local + attr_reader :definitions + + # [Array[Location]] The locations of all usages of this local + attr_reader :usages + + def initialize(type) + @type = type + @definitions = [] + @usages = [] + end + + def add_definition(location) + @definitions << location + end + + def add_usage(location) + @usages << location + end + end + + # [Integer] a unique identifier for this scope + attr_reader :id + + # [scope | nil] The parent scope + attr_reader :parent + + # [Hash[String, Local]] The local variables and arguments defined in this + # scope + attr_reader :locals + + def initialize(id, parent = nil) + @id = id + @parent = parent + @locals = {} + end + + # Adding a local definition will either insert a new entry in the locals + # hash or append a new definition location to an existing local. Notice + # that it's not possible to change the type of a local after it has been + # registered. + def add_local_definition(identifier, type) + name = identifier.value.delete_suffix(":") + + local = + if type == :argument + locals[name] ||= Local.new(type) + else + resolve_local(name, type) + end + + local.add_definition(identifier.location) + end + + # Adding a local usage will either insert a new entry in the locals + # hash or append a new usage location to an existing local. Notice that + # it's not possible to change the type of a local after it has been + # registered. + def add_local_usage(identifier, type) + name = identifier.value.delete_suffix(":") + resolve_local(name, type).add_usage(identifier.location) + end + + # Try to find the local given its name in this scope or any of its + # parents. + def find_local(name) + locals[name] || parent&.find_local(name) + end + + private + + def resolve_local(name, type) + local = find_local(name) + + unless local + local = Local.new(type) + locals[name] = local + end + + local + end + end + + attr_reader :current_scope + + def initialize(*args, **kwargs, &block) + super + + @current_scope = Scope.new(0) + @next_scope_id = 0 + end + + # Visits for nodes that create new scopes, such as classes, modules + # and method definitions. + def visit_class(node) + with_scope { super } + end + + def visit_module(node) + with_scope { super } + end + + # When we find a method invocation with a block, only the code that happens + # inside of the block needs a fresh scope. The method invocation + # itself happens in the same scope. + def visit_method_add_block(node) + visit(node.call) + with_scope(current_scope) { visit(node.block) } + end + + def visit_def(node) + with_scope { super } + end + + # Visit for keeping track of local arguments, such as method and block + # arguments. + def visit_params(node) + add_argument_definitions(node.requireds) + add_argument_definitions(node.posts) + + node.keywords.each do |param| + current_scope.add_local_definition(param.first, :argument) + end + + node.optionals.each do |param| + current_scope.add_local_definition(param.first, :argument) + end + + super + end + + def visit_rest_param(node) + name = node.name + current_scope.add_local_definition(name, :argument) if name + + super + end + + def visit_kwrest_param(node) + name = node.name + current_scope.add_local_definition(name, :argument) if name + + super + end + + def visit_blockarg(node) + name = node.name + current_scope.add_local_definition(name, :argument) if name + + super + end + + def visit_block_var(node) + node.locals.each do |local| + current_scope.add_local_definition(local, :variable) + end + + super + end + alias visit_lambda_var visit_block_var + + # Visit for keeping track of local variable definitions + def visit_var_field(node) + value = node.value + current_scope.add_local_definition(value, :variable) if value.is_a?(Ident) + + super + end + + # Visit for keeping track of local variable definitions + def visit_pinned_var_ref(node) + value = node.value + current_scope.add_local_usage(value, :variable) if value.is_a?(Ident) + + super + end + + # Visits for keeping track of variable and argument usages + def visit_var_ref(node) + value = node.value + + if value.is_a?(Ident) + definition = current_scope.find_local(value.value) + current_scope.add_local_usage(value, definition.type) if definition + end + + super + end + + # When using regex named capture groups, vcalls might actually be a variable + def visit_vcall(node) + value = node.value + definition = current_scope.find_local(value.value) + current_scope.add_local_usage(value, definition.type) if definition + + super + end + + # Visit for capturing local variables defined in regex named capture groups + def visit_binary(node) + if node.operator == :=~ + left = node.left + + if left.is_a?(RegexpLiteral) && left.parts.length == 1 && + left.parts.first.is_a?(TStringContent) + content = left.parts.first + + value = content.value + location = content.location + start_line = location.start_line + + Regexp + .new(value, Regexp::FIXEDENCODING) + .names + .each do |name| + offset = value.index(/\(\?<#{Regexp.escape(name)}>/) + line = start_line + value[0...offset].count("\n") + + # We need to add 3 to account for these three characters + # prefixing a named capture (?< + column = location.start_column + offset + 3 + if value[0...offset].include?("\n") + column = + value[0...offset].length - value[0...offset].rindex("\n") + + 3 - 1 + end + + ident_location = + Location.new( + start_line: line, + start_char: location.start_char + offset, + start_column: column, + end_line: line, + end_char: location.start_char + offset + name.length, + end_column: column + name.length + ) + + identifier = Ident.new(value: name, location: ident_location) + current_scope.add_local_definition(identifier, :variable) + end + end + end + + super + end + + private + + def add_argument_definitions(list) + list.each do |param| + case param + when ArgStar + value = param.value + current_scope.add_local_definition(value, :argument) if value + when MLHSParen + add_argument_definitions(param.contents.parts) + else + current_scope.add_local_definition(param, :argument) + end + end + end + + def next_scope_id + @next_scope_id += 1 + end + + def with_scope(parent_scope = nil) + previous_scope = @current_scope + @current_scope = Scope.new(next_scope_id, parent_scope) + yield + ensure + @current_scope = previous_scope + end + end +end diff --git a/syntax_tree.gemspec b/syntax_tree.gemspec index 06a7ed78..f6c4a734 100644 --- a/syntax_tree.gemspec +++ b/syntax_tree.gemspec @@ -19,14 +19,17 @@ Gem::Specification.new do |spec| .reject { |f| f.match(%r{^(test|spec|features)/}) } end - spec.required_ruby_version = ">= 2.7.3" + spec.required_ruby_version = ">= 2.7.0" spec.bindir = "exe" spec.executables = spec.files.grep(%r{^exe/}) { |f| File.basename(f) } spec.require_paths = %w[lib] + spec.add_dependency "prettier_print", ">= 1.2.0" + spec.add_development_dependency "bundler" spec.add_development_dependency "minitest" spec.add_development_dependency "rake" + spec.add_development_dependency "rubocop" spec.add_development_dependency "simplecov" end diff --git a/tasks/sorbet.rake b/tasks/sorbet.rake new file mode 100644 index 00000000..05f48874 --- /dev/null +++ b/tasks/sorbet.rake @@ -0,0 +1,373 @@ +# frozen_string_literal: true + +module SyntaxTree + class RBI + include DSL + + attr_reader :body, :line + + def initialize + @body = [] + @line = 1 + end + + def generate + require "syntax_tree/reflection" + + body << Comment("# typed: strict", false, location) + @line += 2 + + generate_parent + Reflection.nodes.sort.each { |(_, node)| generate_node(node) } + + body << ClassDeclaration( + ConstPathRef(VarRef(Const("SyntaxTree")), Const("BasicVisitor")), + nil, + BodyStmt( + Statements(generate_visitor("overridable")), + nil, + nil, + nil, + nil + ), + location + ) + + body << ClassDeclaration( + ConstPathRef(VarRef(Const("SyntaxTree")), Const("Visitor")), + ConstPathRef(VarRef(Const("SyntaxTree")), Const("BasicVisitor")), + BodyStmt(Statements(generate_visitor("override")), nil, nil, nil, nil), + location + ) + + Formatter.format(nil, Program(Statements(body))) + end + + private + + def generate_comments(comment) + comment + .lines(chomp: true) + .map { |line| Comment("# #{line}", false, location).tap { @line += 1 } } + end + + def generate_parent + attribute = Reflection.nodes[:Program].attributes[:location] + class_location = location + + node_body = generate_comments(attribute.comment) + node_body << sig_block { sig_returns { sig_type_for(attribute.type) } } + @line += 1 + + node_body << Command( + Ident("attr_reader"), + Args([SymbolLiteral(Ident("location"))]), + nil, + location + ) + @line += 1 + + body << ClassDeclaration( + ConstPathRef(VarRef(Const("SyntaxTree")), Const("Node")), + nil, + BodyStmt(Statements(node_body), nil, nil, nil, nil), + class_location + ) + @line += 2 + end + + def generate_node(node) + body.concat(generate_comments(node.comment)) + class_location = location + @line += 2 + + body << ClassDeclaration( + ConstPathRef(VarRef(Const("SyntaxTree")), Const(node.name.to_s)), + ConstPathRef(VarRef(Const("SyntaxTree")), Const("Node")), + BodyStmt(Statements(generate_node_body(node)), nil, nil, nil, nil), + class_location + ) + + @line += 2 + end + + def generate_node_body(node) + node_body = [] + node.attributes.sort.each do |(name, attribute)| + next if name == :location + + node_body.concat(generate_comments(attribute.comment)) + node_body << sig_block { sig_returns { sig_type_for(attribute.type) } } + @line += 1 + + node_body << Command( + Ident("attr_reader"), + Args([SymbolLiteral(Ident(attribute.name.to_s))]), + nil, + location + ) + @line += 2 + end + + node_body.concat(generate_initialize(node)) + + node_body << sig_block do + CallNode( + sig_params do + BareAssocHash( + [Assoc(Label("visitor:"), sig_type_for(BasicVisitor))] + ) + end, + Period("."), + Ident("returns"), + ArgParen( + Args( + [CallNode(VarRef(Const("T")), Period("."), Ident("untyped"), nil)] + ) + ) + ) + end + @line += 1 + + node_body << generate_def_node( + "accept", + Paren( + LParen("("), + Params.new(requireds: [Ident("visitor")], location: location) + ) + ) + @line += 2 + + node_body << generate_child_nodes + @line += 1 + + node_body << generate_def_node("child_nodes", nil) + @line += 2 + + node_body << sig_block do + CallNode( + sig_params do + BareAssocHash( + [ + Assoc( + Label("other:"), + CallNode( + VarRef(Const("T")), + Period("."), + Ident("untyped"), + nil + ) + ) + ] + ) + end, + Period("."), + sig_returns { ConstPathRef(VarRef(Const("T")), Const("Boolean")) }, + nil + ) + end + @line += 1 + + node_body << generate_def_node( + "==", + Paren( + LParen("("), + Params.new(location: location, requireds: [Ident("other")]) + ) + ) + @line += 2 + + node_body + end + + def generate_initialize(node) + parameters = + SyntaxTree.const_get(node.name).instance_method(:initialize).parameters + + assocs = + parameters.map do |(_, name)| + Assoc(Label("#{name}:"), sig_type_for(node.attributes[name].type)) + end + + node_body = [] + node_body << sig_block do + CallNode( + sig_params { BareAssocHash(assocs) }, + Period("."), + Ident("void"), + nil + ) + end + @line += 1 + + params = Params.new(location: location) + parameters.each do |(type, name)| + case type + when :req + params.requireds << Ident(name.to_s) + when :keyreq + params.keywords << [Label("#{name}:"), nil] + when :key + params.keywords << [ + Label("#{name}:"), + CallNode( + VarRef(Const("T")), + Period("."), + Ident("unsafe"), + ArgParen(Args([VarRef(Kw("nil"))])) + ) + ] + else + raise + end + end + + node_body << generate_def_node("initialize", Paren(LParen("("), params)) + @line += 2 + + node_body + end + + def generate_child_nodes + type = + Reflection::Type::ArrayType.new( + Reflection::Type::UnionType.new([NilClass, Node]) + ) + + sig_block { sig_returns { sig_type_for(type) } } + end + + def generate_def_node(name, params) + DefNode( + nil, + nil, + Ident(name), + params, + BodyStmt(Statements([VoidStmt()]), nil, nil, nil, nil), + location + ) + end + + def generate_visitor(override) + body = [] + + Reflection.nodes.each do |name, node| + body << sig_block do + CallNode( + CallNode( + Ident(override), + Period("."), + sig_params do + BareAssocHash( + [ + Assoc( + Label("node:"), + sig_type_for(SyntaxTree.const_get(name)) + ) + ] + ) + end, + nil + ), + Period("."), + sig_returns do + CallNode(VarRef(Const("T")), Period("."), Ident("untyped"), nil) + end, + nil + ) + end + + body << generate_def_node( + node.visitor_method, + Paren( + LParen("("), + Params.new(requireds: [Ident("node")], location: location) + ) + ) + + @line += 2 + end + + body + end + + def sig_block + MethodAddBlock( + CallNode(nil, nil, Ident("sig"), nil), + BlockNode( + LBrace("{"), + nil, + BodyStmt(Statements([yield]), nil, nil, nil, nil) + ), + location + ) + end + + def sig_params + CallNode(nil, nil, Ident("params"), ArgParen(Args([yield]))) + end + + def sig_returns + CallNode(nil, nil, Ident("returns"), ArgParen(Args([yield]))) + end + + def sig_type_for(type) + case type + when Reflection::Type::ArrayType + ARef( + ConstPathRef(VarRef(Const("T")), Const("Array")), + sig_type_for(type.type) + ) + when Reflection::Type::TupleType + ArrayLiteral(LBracket("["), Args(type.types.map { sig_type_for(_1) })) + when Reflection::Type::UnionType + if type.types.include?(NilClass) + selected = type.types.reject { _1 == NilClass } + subtype = + if selected.size == 1 + selected.first + else + Reflection::Type::UnionType.new(selected) + end + + CallNode( + VarRef(Const("T")), + Period("."), + Ident("nilable"), + ArgParen(Args([sig_type_for(subtype)])) + ) + else + CallNode( + VarRef(Const("T")), + Period("."), + Ident("any"), + ArgParen(Args(type.types.map { sig_type_for(_1) })) + ) + end + when Symbol + ConstRef(Const("Symbol")) + else + *parents, constant = type.name.split("::").map { Const(_1) } + + if parents.empty? + ConstRef(constant) + else + [*parents[1..], constant].inject( + VarRef(parents.first) + ) { |accum, const| ConstPathRef(accum, const) } + end + end + end + + def location + Location.fixed(line: line, char: 0, column: 0) + end + end +end + +namespace :sorbet do + desc "Generate RBI files for Sorbet" + task :rbi do + puts SyntaxTree::RBI.new.generate + end +end diff --git a/test/cli_test.rb b/test/cli_test.rb index ade1485c..a0d6001d 100644 --- a/test/cli_test.rb +++ b/test/cli_test.rb @@ -1,6 +1,7 @@ # frozen_string_literal: true require_relative "test_helper" +require "securerandom" module SyntaxTree class CLITest < Minitest::Test @@ -9,6 +10,10 @@ def parse(source) source * 2 end + def format(source, _print_width, **) + "Formatted #{source}" + end + def read(filepath) File.read(filepath) end @@ -20,7 +25,7 @@ def test_handler file = Tempfile.new(%w[test- .test]) file.puts("test") - result = run_cli("ast", file: file) + result = run_cli("ast", contents: file) assert_equal("\"test\\n\" + \"test\\n\"\n", result.stdio) ensure SyntaxTree::HANDLERS.delete(".test") @@ -31,12 +36,16 @@ def test_ast assert_includes(result.stdio, "ident \"test\"") end - def test_ast_syntax_error - file = Tempfile.new(%w[test- .rb]) - file.puts("foo\n<>\nbar\n") + def test_ast_ignore + result = run_cli("ast", "--ignore-files='*/test*'") + assert_equal(0, result.status) + assert_empty(result.stdio) + end - result = run_cli("ast", file: file) + def test_ast_syntax_error + result = run_cli("ast", contents: "foo\n<>\nbar\n") assert_includes(result.stderr, "syntax error") + refute_equal(0, result.status) end def test_check @@ -45,11 +54,20 @@ def test_check end def test_check_unformatted - file = Tempfile.new(%w[test- .rb]) - file.write("foo") - - result = run_cli("check", file: file) + result = run_cli("check", contents: "foo") assert_includes(result.stderr, "expected") + refute_equal(0, result.status) + end + + def test_check_print_width + contents = "#{"a" * 40} + #{"b" * 40}\n" + result = run_cli("check", "--print-width=100", contents: contents) + assert_includes(result.stdio, "match") + end + + def test_check_target_ruby_version + result = run_cli("check", "--target-ruby-version=2.6.0") + assert_includes(result.stdio, "match") end def test_debug @@ -64,6 +82,7 @@ def test_debug_non_idempotent_format SyntaxTree.stub(:format, formatting) do result = run_cli("debug") assert_includes(result.stderr, "idempotently") + refute_equal(0, result.status) end end @@ -72,6 +91,17 @@ def test_doc assert_includes(result.stdio, "test") end + def test_expr + result = run_cli("expr") + assert_includes(result.stdio, "SyntaxTree::Ident") + end + + def test_expr_more_than_one + result = run_cli("expr", contents: "1; 2") + assert_includes(result.stderr, "single expression") + refute_equal(0, result.status) + end + def test_format result = run_cli("format") assert_equal("test\n", result.stdio) @@ -87,6 +117,22 @@ def test_match assert_includes(result.stdio, "SyntaxTree::Program") end + def test_search + result = run_cli("search", "VarRef", contents: "Foo + Bar") + assert_equal(2, result.stdio.lines.length) + end + + def test_search_multi_line + result = run_cli("search", "Binary", contents: "1 +\n2") + assert_equal(1, result.stdio.lines.length) + end + + def test_search_invalid + result = run_cli("search", "FooBar") + assert_includes(result.stderr, "unable") + refute_equal(0, result.status) + end + def test_version result = run_cli("version") assert_includes(result.stdio, SyntaxTree::VERSION.to_s) @@ -96,16 +142,36 @@ def test_write file = Tempfile.new(%w[test- .test]) filepath = file.path - result = run_cli("write", file: file) + result = run_cli("write", contents: file) assert_includes(result.stdio, filepath) end def test_write_syntax_tree - file = Tempfile.new(%w[test- .rb]) - file.write("<>") - - result = run_cli("write", file: file) + result = run_cli("write", contents: "<>") assert_includes(result.stderr, "syntax error") + refute_equal(0, result.status) + end + + def test_write_script + args = ["write", "-e", "1 + 2"] + stdout, stderr = capture_io { SyntaxTree::CLI.run(args) } + + assert_includes stdout, "script" + assert_empty stderr + end + + def test_write_stdin + previous = $stdin + $stdin = StringIO.new("1 + 2") + + begin + stdout, stderr = capture_io { SyntaxTree::CLI.run(["write"]) } + + assert_includes stdout, "stdin" + assert_empty stderr + ensure + $stdin = previous + end end def test_help @@ -114,18 +180,13 @@ def test_help end def test_help_default - *, stderr = capture_io { SyntaxTree::CLI.run(["foobar"]) } + status = 0 + *, stderr = capture_io { status = SyntaxTree::CLI.run(["foobar"]) } assert_includes(stderr, "stree help") + refute_equal(0, status) end def test_no_arguments - $stdin.stub(:tty?, true) do - *, stderr = capture_io { SyntaxTree::CLI.run(["check"]) } - assert_includes(stderr, "stree help") - end - end - - def test_no_arguments_no_tty stdin = $stdin $stdin = StringIO.new("1+1") @@ -135,33 +196,226 @@ def test_no_arguments_no_tty $stdin = stdin end + def test_inline_script + stdio, = capture_io { SyntaxTree::CLI.run(%w[format -e 1+1]) } + assert_equal("1 + 1\n", stdio) + end + + def test_multiple_inline_scripts + stdio, = capture_io { SyntaxTree::CLI.run(%w[format -e 1+1 -e 2+2]) } + assert_equal(["1 + 1", "2 + 2"], stdio.split("\n").sort) + end + + def test_format_script_with_custom_handler + SyntaxTree.register_handler(".test", TestHandler.new) + stdio, = + capture_io do + SyntaxTree::CLI.run(%w[format --extension=test -e ]) + end + assert_equal("Formatted \n", stdio) + ensure + SyntaxTree::HANDLERS.delete(".test") + end + + def test_format_stdin_with_custom_handler + SyntaxTree.register_handler(".test", TestHandler.new) + stdin = $stdin + $stdin = StringIO.new("") + stdio, = capture_io { SyntaxTree::CLI.run(%w[format --extension=test]) } + assert_equal("Formatted \n", stdio) + ensure + $stdin = stdin + SyntaxTree::HANDLERS.delete(".test") + end + def test_generic_error SyntaxTree.stub(:format, ->(*) { raise }) do result = run_cli("format") + refute_equal(0, result.status) end end - private + def test_plugins + with_plugin_directory do |directory| + plugin = directory.plugin("puts 'Hello, world!'") + result = run_cli("format", "--plugins=#{plugin}") - Result = Struct.new(:status, :stdio, :stderr, keyword_init: true) + assert_equal("Hello, world!\ntest\n", result.stdio) + end + end + + def test_language_server + prev_stdin = $stdin + prev_stdout = $stdout + + request = { method: "shutdown" }.merge(jsonrpc: "2.0").to_json + $stdin = + StringIO.new("Content-Length: #{request.bytesize}\r\n\r\n#{request}") + $stdout = StringIO.new + + assert_equal(0, SyntaxTree::CLI.run(["lsp"])) + ensure + $stdin = prev_stdin + $stdout = prev_stdout + end + + def test_config_file + with_plugin_directory do |directory| + plugin = directory.plugin("puts 'Hello, world!'") + config = <<~TXT + --print-width=100 + --plugins=#{plugin} + TXT + + with_config_file(config) do + contents = "#{"a" * 40} + #{"b" * 40}\n" + result = run_cli("format", contents: contents) + + assert_equal("Hello, world!\n#{contents}", result.stdio) + end + end + end + + def test_print_width_args_with_config_file + with_config_file("--print-width=100") do + result = run_cli("check", contents: "#{"a" * 40} + #{"b" * 40}\n") + + assert_includes(result.stdio, "match") + end + end + + def test_print_width_args_with_config_file_override + with_config_file("--print-width=100") do + contents = "#{"a" * 40} + #{"b" * 40}\n" + result = run_cli("check", "--print-width=82", contents: contents) + + assert_includes(result.stderr, "expected") + refute_equal(0, result.status) + end + end + + def test_plugin_args_with_config_file + with_plugin_directory do |directory| + plugin1 = directory.plugin("puts 'Hello, world!'") + + with_config_file("--plugins=#{plugin1}") do + plugin2 = directory.plugin("puts 'Bye, world!'") + result = run_cli("format", "--plugins=#{plugin2}") + + assert_equal("Hello, world!\nBye, world!\ntest\n", result.stdio) + end + end + end - def run_cli(command, file: nil) - if file.nil? - file = Tempfile.new(%w[test- .rb]) - file.puts("test") + def test_config_file_custom_path + with_plugin_directory do |directory| + plugin = directory.plugin("puts 'Custom config!'") + config = <<~TXT + --print-width=80 + --plugins=#{plugin} + TXT + + filepath = File.join(Dir.tmpdir, "#{SecureRandom.hex}.streerc") + with_config_file(config, filepath) do + contents = "#{"a" * 30} + #{"b" * 30}\n" + result = run_cli("format", "--config=#{filepath}", contents: contents) + + assert_equal("Custom config!\n#{contents}", result.stdio) + end end + end + + def test_config_file_custom_path_space_separated + with_plugin_directory do |directory| + plugin = directory.plugin("puts 'Custom config space!'") + config = <<~TXT + --print-width=80 + --plugins=#{plugin} + TXT + + filepath = File.join(Dir.tmpdir, "#{SecureRandom.hex}.streerc") + with_config_file(config, filepath) do + contents = "#{"a" * 30} + #{"b" * 30}\n" + result = run_cli("format", "--config", filepath, contents: contents) + + assert_equal("Custom config space!\n#{contents}", result.stdio) + end + end + end + + def test_config_file_nonexistent_path + assert_raises(ArgumentError) do + run_cli("format", "--config=/nonexistent/path.streerc") + end + end - file.rewind + Result = Struct.new(:status, :stdio, :stderr, keyword_init: true) + + private + + def run_cli(command, *args, contents: :default) + tempfile = + case contents + when :default + Tempfile.new(%w[test- .rb]).tap { |file| file.puts("test") } + when String + Tempfile.new(%w[test- .rb]).tap { |file| file.write(contents) } + else + contents + end + + tempfile.rewind status = nil stdio, stderr = - capture_io { status = SyntaxTree::CLI.run([command, file.path]) } + capture_io do + status = + begin + SyntaxTree::CLI.run([command, *args, tempfile.path]) + rescue SystemExit => error + error.status + end + end Result.new(status: status, stdio: stdio, stderr: stderr) ensure - file.close - file.unlink + tempfile.close + tempfile.unlink + end + + def with_config_file(contents, filepath = nil) + filepath ||= File.join(Dir.pwd, SyntaxTree::CLI::ConfigFile::FILENAME) + File.write(filepath, contents) + + yield + ensure + FileUtils.rm(filepath) + end + + class PluginDirectory + attr_reader :directory + + def initialize(directory) + @directory = directory + end + + def plugin(contents) + name = SecureRandom.hex + File.write(File.join(directory, "#{name}.rb"), contents) + name + end + end + + def with_plugin_directory + Dir.mktmpdir do |directory| + $:.unshift(directory) + + plugin_directory = File.join(directory, "syntax_tree") + Dir.mkdir(plugin_directory) + + yield PluginDirectory.new(plugin_directory) + end end end end diff --git a/test/fixtures/arg_paren.rb b/test/fixtures/arg_paren.rb index 0e01e208..0816af6a 100644 --- a/test/fixtures/arg_paren.rb +++ b/test/fixtures/arg_paren.rb @@ -2,8 +2,6 @@ foo(bar) % foo() -- -foo % foo(barrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrr) - diff --git a/test/fixtures/args_forward.rb b/test/fixtures/args_forward.rb index 5ba618a8..cc538f44 100644 --- a/test/fixtures/args_forward.rb +++ b/test/fixtures/args_forward.rb @@ -1,4 +1,4 @@ -% +% # >= 2.7.3 def foo(...) bar(:baz, ...) end diff --git a/test/fixtures/array_literal.rb b/test/fixtures/array_literal.rb index df807728..391d2eae 100644 --- a/test/fixtures/array_literal.rb +++ b/test/fixtures/array_literal.rb @@ -24,9 +24,16 @@ - fooooooooooooooooo = 1 [ - fooooooooooooooooo, fooooooooooooooooo, fooooooooooooooooo, - fooooooooooooooooo, fooooooooooooooooo, fooooooooooooooooo, - fooooooooooooooooo, fooooooooooooooooo, fooooooooooooooooo, fooooooooooooooooo + fooooooooooooooooo, + fooooooooooooooooo, + fooooooooooooooooo, + fooooooooooooooooo, + fooooooooooooooooo, + fooooooooooooooooo, + fooooooooooooooooo, + fooooooooooooooooo, + fooooooooooooooooo, + fooooooooooooooooo ] % [ diff --git a/test/fixtures/aryptn.rb b/test/fixtures/aryptn.rb index c5562305..64d5d9d0 100644 --- a/test/fixtures/aryptn.rb +++ b/test/fixtures/aryptn.rb @@ -4,53 +4,110 @@ end % case foo +in [] then +end +- +case foo +in [] +end +% +case foo +in * then +end +- +case foo +in [*] +end +% +case foo in _, _ end +- +case foo +in [_, _] +end % case foo in bar, baz end +- +case foo +in [bar, baz] +end % case foo in [bar] end % case foo -in [bar, baz] +in [bar] +in [baz] end -- +% case foo -in bar, baz +in [bar, baz] end % case foo in bar, *baz end +- +case foo +in [bar, *baz] +end % case foo in *bar, baz end +- +case foo +in [*bar, baz] +end % case foo in bar, *, baz end +- +case foo +in [bar, *, baz] +end % case foo in *, bar, baz end +- +case foo +in [*, bar, baz] +end % case foo in Constant[bar] end % case foo +in Constant(bar) +end +- +case foo +in Constant[bar] +end +% +case foo in Constant[bar, baz] end % case foo in bar, [baz, _] => qux end +- +case foo +in [bar, [baz, _] => qux] +end % case foo in bar, baz if bar == baz end +- +case foo +in [bar, baz] if bar == baz +end diff --git a/test/fixtures/assoc.rb b/test/fixtures/assoc.rb index cd3e5ed1..83a4887a 100644 --- a/test/fixtures/assoc.rb +++ b/test/fixtures/assoc.rb @@ -46,3 +46,9 @@ { foo: "bar" } % { "foo #{bar}": "baz" } +% +{ "foo=": "baz" } +% # >= 3.1.0 +{ bar => 1, baz: } +% # >= 3.1.0 +{ baz:, bar => 1 } diff --git a/test/fixtures/assoc_splat.rb b/test/fixtures/assoc_splat.rb index 2182c2ed..8b595ce9 100644 --- a/test/fixtures/assoc_splat.rb +++ b/test/fixtures/assoc_splat.rb @@ -12,3 +12,7 @@ } - { **foo } +% # >= 3.2.0 +def foo(**) + bar(**) +end diff --git a/test/fixtures/binary.rb b/test/fixtures/binary.rb index f8833cdc..4cb56cbf 100644 --- a/test/fixtures/binary.rb +++ b/test/fixtures/binary.rb @@ -3,6 +3,11 @@ % foo << bar % +foo << barrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrr << barrrrrrrrrrrrr << barrrrrrrrrrrrrrrrrr +- +foo << barrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrr << barrrrrrrrrrrrr << + barrrrrrrrrrrrrrrrrr +% foo**bar % foo * barrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrr diff --git a/test/fixtures/bodystmt.rb b/test/fixtures/bodystmt.rb index 4cbb8f5e..5999fdba 100644 --- a/test/fixtures/bodystmt.rb +++ b/test/fixtures/bodystmt.rb @@ -36,6 +36,7 @@ end % begin +rescue StandardError else # else end % diff --git a/test/fixtures/break.rb b/test/fixtures/break.rb index a77c6b35..23277f6b 100644 --- a/test/fixtures/break.rb +++ b/test/fixtures/break.rb @@ -1,29 +1,45 @@ % -break +tap { break } % -break foo +tap { break foo } % -break foo, bar +tap { break foo, bar } % -break(foo) +tap { break(foo) } % -break fooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo +tap { break fooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo } - -break( - fooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo -) +tap do + break( + fooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo + ) +end % -break(fooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo) +tap { break(fooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo) } - -break( - fooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo -) -% -break (foo), bar -% -break( - foo - bar -) -% -break foo.bar :baz do |qux| qux end +tap do + break( + fooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo + ) +end +% +tap { break (foo), bar } +% +tap do + break( + foo + bar + ) +end +% +tap { break foo.bar :baz do |qux| qux end } +- +tap do + break( + foo.bar :baz do |qux| + qux + end + ) +end +% +tap { break :foo => "bar" } diff --git a/test/fixtures/call.rb b/test/fixtures/call.rb index f3333276..eec717f0 100644 --- a/test/fixtures/call.rb +++ b/test/fixtures/call.rb @@ -1,6 +1,8 @@ % foo.bar % +foo.bar(baz) +% foo.() % foo::() @@ -21,3 +23,52 @@ .barrrrrrrrrrrrrrrrrrr {} .bazzzzzzzzzzzzzzzzzzzzzzzzzz .quxxxxxxxxx +% +foo. # comment + bar +% +foo + .bar + .baz # comment + .qux + .quux +% +foo + .bar + .baz. + # comment + qux + .quux +% +{ a: 1, b: 2 }.fooooooooooooooooo.barrrrrrrrrrrrrrrrrrr.bazzzzzzzzzzzz.quxxxxxxxxxxxx +- +{ a: 1, b: 2 }.fooooooooooooooooo + .barrrrrrrrrrrrrrrrrrr + .bazzzzzzzzzzzz + .quxxxxxxxxxxxx +% +fooooooooooooooooo.barrrrrrrrrrrrrrrrrrr.bazzzzzzzzzzzz.quxxxxxxxxxxxx.each { block } +- +fooooooooooooooooo.barrrrrrrrrrrrrrrrrrr.bazzzzzzzzzzzz.quxxxxxxxxxxxx.each do + block +end +% +foo.bar.baz.each do + block1 + block2 +end +% +a b do +end.c d +% +self. +=begin +=end + to_s +% +fooooooooooooooooooooooooooooooooooo.barrrrrrrrrrrrrrrrrrrrrrrrrrrrrr.where.not(:id).order(:id) +- +fooooooooooooooooooooooooooooooooooo + .barrrrrrrrrrrrrrrrrrrrrrrrrrrrrr + .where.not(:id) + .order(:id) diff --git a/test/fixtures/command_call.rb b/test/fixtures/command_call.rb index 5060ffa4..7c055e8d 100644 --- a/test/fixtures/command_call.rb +++ b/test/fixtures/command_call.rb @@ -28,3 +28,49 @@ % foo.bar baz do end +% +foo. + # comment + bar baz +% +foo.bar baz ? qux : qaz +% +expect foo, bar.map { |i| { quux: bazzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz } } +- +expect foo, + bar.map { |i| + { + quux: + bazzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz + } + } +% +expect(foo, bar.map { |i| {quux: bazzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz} }) +- +expect( + foo, + bar.map do |i| + { + quux: + bazzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz + } + end +) +% +expect(foo.map { |i| { bar: i.bazzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz } } ).to match(baz.map { |i| { bar: i.bazzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz } }) +- +expect( + foo.map do |i| + { + bar: + i.bazzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz + } + end +).to match( + baz.map do |i| + { + bar: + i.bazzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz + } + end +) diff --git a/test/fixtures/def.rb b/test/fixtures/def.rb index a827adfe..0cc49e0a 100644 --- a/test/fixtures/def.rb +++ b/test/fixtures/def.rb @@ -23,3 +23,9 @@ def foo() # comment def foo( # comment ) end +% +def +=begin +=end +a +end diff --git a/test/fixtures/def_endless.rb b/test/fixtures/def_endless.rb index dbac88bb..8d1f9d33 100644 --- a/test/fixtures/def_endless.rb +++ b/test/fixtures/def_endless.rb @@ -4,8 +4,6 @@ def foo = bar def foo(bar) = baz % def foo() = bar -- -def foo = bar % # >= 3.1.0 def foo = bar baz % # >= 3.1.0 @@ -14,7 +12,23 @@ def self.foo = bar def self.foo(bar) = baz % # >= 3.1.0 def self.foo() = bar -- -def self.foo = bar % # >= 3.1.0 def self.foo = bar baz +% +begin + true +rescue StandardError + false +end + +def foo? = true +% +def a() +=begin +=end +=1 +- +def a() = +=begin +=end + 1 diff --git a/test/fixtures/do_block.rb b/test/fixtures/do_block.rb index 016f27b2..8ea4f75f 100644 --- a/test/fixtures/do_block.rb +++ b/test/fixtures/do_block.rb @@ -14,3 +14,15 @@ foo :bar do baz end +% +sig do + override.params(contacts: Contact::ActiveRecord_Relation).returns( + Customer::ActiveRecord_Relation + ) +end +- +sig do + override + .params(contacts: Contact::ActiveRecord_Relation) + .returns(Customer::ActiveRecord_Relation) +end diff --git a/test/fixtures/elsif.rb b/test/fixtures/elsif.rb index 2e4cd831..e0dd2bd6 100644 --- a/test/fixtures/elsif.rb +++ b/test/fixtures/elsif.rb @@ -17,3 +17,8 @@ else qyz end +% +if true +elsif false # comment1 + # comment2 +end diff --git a/test/fixtures/for.rb b/test/fixtures/for.rb index 62b207ee..1346a367 100644 --- a/test/fixtures/for.rb +++ b/test/fixtures/for.rb @@ -38,3 +38,7 @@ for foo, in [[foo, bar]] foo end +% +for foo in bar # comment1 + # comment2 +end diff --git a/test/fixtures/hash.rb b/test/fixtures/hash.rb index 9c43a4fe..70e89f69 100644 --- a/test/fixtures/hash.rb +++ b/test/fixtures/hash.rb @@ -29,3 +29,5 @@ { # comment } +% # >= 3.1.0 +{ foo:, "bar" => "baz" } diff --git a/test/fixtures/hshptn.rb b/test/fixtures/hshptn.rb index 2935f9c1..02d1cf75 100644 --- a/test/fixtures/hshptn.rb +++ b/test/fixtures/hshptn.rb @@ -30,7 +30,7 @@ case foo in **bar end -% +% # >= 2.7.3 case foo in { foo:, # comment1 @@ -64,5 +64,23 @@ end % case foo +in {} then +end +- +case foo +in {} +end +% +case foo in **nil end +% +case foo +in bar, { baz:, **nil } +in qux: +end +- +case foo +in [bar, { baz:, **nil }] +in qux: +end diff --git a/test/fixtures/if.rb b/test/fixtures/if.rb index cabea4c3..b25386b9 100644 --- a/test/fixtures/if.rb +++ b/test/fixtures/if.rb @@ -35,3 +35,42 @@ % if foo {} end +% +if not a + b +else + c +end +% +if not(a) + b +else + c +end +- +not(a) ? b : c +% +(if foo then bar else baz end) +- +( + if foo + bar + else + baz + end +) +% +if (x = x + 1).to_i + x +end +% +if true # comment1 + # comment2 +end +% +result = + if false && val = 1 + "A" + else + "B" + end diff --git a/test/fixtures/ifop.rb b/test/fixtures/ifop.rb index 541e667e..f7504658 100644 --- a/test/fixtures/ifop.rb +++ b/test/fixtures/ifop.rb @@ -10,3 +10,11 @@ end % foo bar ? 1 : 2 +% +tap { foooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo ? break : baz } +- +tap do + foooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo ? + break : + baz +end diff --git a/test/fixtures/in.rb b/test/fixtures/in.rb index 1e1b2282..59102505 100644 --- a/test/fixtures/in.rb +++ b/test/fixtures/in.rb @@ -14,8 +14,10 @@ end - case foo -in fooooooooooooooooooooooooooooooooooooo, - barrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrr +in [ + fooooooooooooooooooooooooooooooooooooo, + barrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrr + ] baz end % diff --git a/test/fixtures/lambda.rb b/test/fixtures/lambda.rb index 043ceb5a..8b922ef0 100644 --- a/test/fixtures/lambda.rb +++ b/test/fixtures/lambda.rb @@ -1,4 +1,6 @@ % +-> {} +% -> { foo } % ->(foo, bar) { baz } @@ -40,3 +42,69 @@ -> { -> foo do bar end.baz }.qux - -> { ->(foo) { bar }.baz }.qux +% +->(;a) {} +- +->(; a) {} +% +->(; a) {} +% +->(; a,b) {} +- +->(; a, b) {} +% +->(; a, b) {} +% +->(; +a +) {} +- +->(; a) {} +% +->(; a , +b +) {} +- +->(; a, b) {} +% +->(a = (b; c)) {} +- +->( + a = ( + b + c + ) +) do +end +% +-> do # comment1 + # comment2 +end +% # multiline lambda in a command +command "arg" do + -> { + multi + line + } +end +- +command "arg" do + -> do + multi + line + end +end +% # multiline lambda in a command call +command.call "arg" do + -> { + multi + line + } +end +- +command.call "arg" do + -> do + multi + line + end +end diff --git a/test/fixtures/next.rb b/test/fixtures/next.rb index be667951..dc159488 100644 --- a/test/fixtures/next.rb +++ b/test/fixtures/next.rb @@ -1,67 +1,82 @@ % -next +tap { next } % -next foo +tap { next foo } % -next foo, bar +tap { next foo, bar } % -next(foo) +tap { next(foo) } % -next fooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo +tap { next fooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo } - -next( - fooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo -) +tap do + next( + fooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo + ) +end % -next(fooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo) +tap { next(fooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo) } - -next( - fooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo -) -% -next (foo), bar -% -next( - foo - bar -) +tap do + next( + fooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo + ) +end +% +tap { next (foo), bar } +% +tap do + next( + foo + bar + ) +end +% +tap { next(1) } +- +tap { next 1 } % -next(1) +tap { next(1.0) } - -next 1 +tap { next 1.0 } % -next(1.0) +tap { next($a) } - -next 1.0 +tap { next $a } % -next($a) +tap { next(@@a) } - -next $a +tap { next @@a } % -next(@@a) +tap { next(self) } - -next @@a +tap { next self } % -next(self) +tap { next(@a) } - -next self +tap { next @a } % -next(@a) +tap { next(A) } - -next @a +tap { next A } % -next(A) +tap { next([]) } - -next A +tap { next [] } % -next([]) +tap { next([1]) } - -next [] +tap { next [1] } % -next([1]) +tap { next([1, 2]) } - -next [1] +tap { next 1, 2 } % -next([1, 2]) +tap { next fun foo do end } - -next 1, 2 +tap do + next( + fun foo do + end + ) +end diff --git a/test/fixtures/params.rb b/test/fixtures/params.rb index 67b6ec90..551aa9a5 100644 --- a/test/fixtures/params.rb +++ b/test/fixtures/params.rb @@ -16,7 +16,7 @@ def foo(*) % def foo(*rest) end -% +% # >= 2.7.3 def foo(...) end % diff --git a/test/fixtures/rassign.rb b/test/fixtures/rassign.rb index 882ce890..3d357351 100644 --- a/test/fixtures/rassign.rb +++ b/test/fixtures/rassign.rb @@ -12,3 +12,20 @@ - foooooooooooooooooooooooooooooooooooooo => barrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrr +% +foo => [ + ConstantConstantConstant, + ConstantConstantConstant, + ConstantConstantConstant, + ConstantConstantConstant, + ConstantConstantConstant +] +% +a in Integer +b => [Integer => c] +% +case [0] +when 0 + { a: 0 } => { a: } + puts a +end diff --git a/test/fixtures/redo.rb b/test/fixtures/redo.rb index 8ab087a2..962af3d0 100644 --- a/test/fixtures/redo.rb +++ b/test/fixtures/redo.rb @@ -1,4 +1,6 @@ % -redo +tap { redo } % -redo # comment +tap do + redo # comment +end diff --git a/test/fixtures/retry.rb b/test/fixtures/retry.rb index 2b14d21a..47b6be51 100644 --- a/test/fixtures/retry.rb +++ b/test/fixtures/retry.rb @@ -1,4 +1,10 @@ % -retry +begin +rescue StandardError + retry +end % -retry # comment +begin +rescue StandardError + retry # comment +end diff --git a/test/fixtures/return.rb b/test/fixtures/return.rb index 8f7d0aa3..7092464f 100644 --- a/test/fixtures/return.rb +++ b/test/fixtures/return.rb @@ -37,3 +37,5 @@ return [] % return [1] +% +return :foo => "bar" diff --git a/test/fixtures/string_literal.rb b/test/fixtures/string_literal.rb index ebe56a40..d8ee0cdb 100644 --- a/test/fixtures/string_literal.rb +++ b/test/fixtures/string_literal.rb @@ -41,4 +41,8 @@ % '"foo"' - -"\"foo\"" +'"foo"' +% +"'foo'" +- +"'foo'" diff --git a/test/fixtures/symbols.rb b/test/fixtures/symbols.rb index 5e2673f3..12f0a22f 100644 --- a/test/fixtures/symbols.rb +++ b/test/fixtures/symbols.rb @@ -19,3 +19,8 @@ %I[foo] # comment % %I{foo[]} +% +:\ +=begin +=end +symbol diff --git a/test/fixtures/unless.rb b/test/fixtures/unless.rb index c66b16bf..2d5038c1 100644 --- a/test/fixtures/unless.rb +++ b/test/fixtures/unless.rb @@ -32,3 +32,7 @@ unless foo a ? b : c end +% +unless true # comment1 + # comment2 +end diff --git a/test/fixtures/until.rb b/test/fixtures/until.rb index 778e3fb0..f3ef5202 100644 --- a/test/fixtures/until.rb +++ b/test/fixtures/until.rb @@ -23,3 +23,7 @@ until (foo += 1) foo end +% +until true # comment1 + # comment2 +end diff --git a/test/fixtures/var_field_rassign.rb b/test/fixtures/var_field_rassign.rb index 3e019c5c..aa5ec379 100644 --- a/test/fixtures/var_field_rassign.rb +++ b/test/fixtures/var_field_rassign.rb @@ -1,6 +1,7 @@ % foo in bar % +bar = 1 foo in ^bar % foo in ^@bar diff --git a/test/fixtures/while.rb b/test/fixtures/while.rb index 1404f07d..9415135a 100644 --- a/test/fixtures/while.rb +++ b/test/fixtures/while.rb @@ -23,3 +23,7 @@ while (foo += 1) foo end +% +while true # comment1 + # comment2 +end diff --git a/test/fixtures/yield.rb b/test/fixtures/yield.rb index f3f023f8..3cf1e5f1 100644 --- a/test/fixtures/yield.rb +++ b/test/fixtures/yield.rb @@ -1,16 +1,30 @@ % -yield foo +def foo + yield foo +end % -yield(foo) +def foo + yield(foo) +end % -yield foo, bar +def foo + yield foo, bar +end % -yield(foo, bar) +def foo + yield(foo, bar) +end % -yield foo # comment +def foo + yield foo # comment +end % -yield(foo) # comment +def foo + yield(foo) # comment +end % -yield( # comment - foo -) +def foo + yield( # comment + foo + ) +end diff --git a/test/fixtures/yield0.rb b/test/fixtures/yield0.rb index a168c4aa..c1833bb5 100644 --- a/test/fixtures/yield0.rb +++ b/test/fixtures/yield0.rb @@ -1,4 +1,8 @@ % -yield +def foo + yield +end % -yield # comment +def foo + yield # comment +end diff --git a/test/formatting_test.rb b/test/formatting_test.rb index eff7ef71..5e5f9e9f 100644 --- a/test/formatting_test.rb +++ b/test/formatting_test.rb @@ -7,6 +7,7 @@ class FormattingTest < Minitest::Test Fixtures.each_fixture do |fixture| define_method(:"test_formatted_#{fixture.name}") do assert_equal(fixture.formatted, SyntaxTree.format(fixture.source)) + assert_syntax_tree(SyntaxTree.parse(fixture.source)) end end @@ -27,5 +28,37 @@ def test_stree_ignore assert_equal(source, SyntaxTree.format(source)) end + + def test_formatting_with_different_indentation_level + source = <<~SOURCE + def foo + puts "a" + end + SOURCE + + # Default indentation + assert_equal(source, SyntaxTree.format(source)) + + # Level 2 + assert_equal(<<-EXPECTED.chomp, SyntaxTree.format(source, 80, 2).rstrip) + def foo + puts "a" + end + EXPECTED + + # Level 4 + assert_equal(<<-EXPECTED.chomp, SyntaxTree.format(source, 80, 4).rstrip) + def foo + puts "a" + end + EXPECTED + + # Level 6 + assert_equal(<<-EXPECTED.chomp, SyntaxTree.format(source, 80, 6).rstrip) + def foo + puts "a" + end + EXPECTED + end end end diff --git a/test/idempotency_test.rb b/test/idempotency_test.rb index 1f560db2..32d9d196 100644 --- a/test/idempotency_test.rb +++ b/test/idempotency_test.rb @@ -1,6 +1,6 @@ # frozen_string_literal: true -return unless ENV["CI"] +return if !ENV["CI"] || RUBY_ENGINE == "truffleruby" require_relative "test_helper" module SyntaxTree diff --git a/test/index_test.rb b/test/index_test.rb new file mode 100644 index 00000000..1e2a7fc7 --- /dev/null +++ b/test/index_test.rb @@ -0,0 +1,183 @@ +# frozen_string_literal: true + +require_relative "test_helper" + +module SyntaxTree + class IndexTest < Minitest::Test + def test_module + index_each("module Foo; end") do |entry| + assert_equal :Foo, entry.name + assert_equal [[:Foo]], entry.nesting + end + end + + def test_module_nested + index_each("module Foo; module Bar; end; end") do |entry| + assert_equal :Bar, entry.name + assert_equal [[:Foo], [:Bar]], entry.nesting + end + end + + def test_module_comments + index_each("# comment1\n# comment2\nmodule Foo; end") do |entry| + assert_equal :Foo, entry.name + assert_equal ["# comment1", "# comment2"], entry.comments.to_a + end + end + + def test_class + index_each("class Foo; end") do |entry| + assert_equal :Foo, entry.name + assert_equal [[:Foo]], entry.nesting + end + end + + def test_class_paths_2 + index_each("class Foo::Bar; end") do |entry| + assert_equal :Bar, entry.name + assert_equal [%i[Foo Bar]], entry.nesting + end + end + + def test_class_paths_3 + index_each("class Foo::Bar::Baz; end") do |entry| + assert_equal :Baz, entry.name + assert_equal [%i[Foo Bar Baz]], entry.nesting + end + end + + def test_class_nested + index_each("class Foo; class Bar; end; end") do |entry| + assert_equal :Bar, entry.name + assert_equal [[:Foo], [:Bar]], entry.nesting + end + end + + def test_class_paths_nested + index_each("class Foo; class Bar::Baz::Qux; end; end") do |entry| + assert_equal :Qux, entry.name + assert_equal [[:Foo], %i[Bar Baz Qux]], entry.nesting + end + end + + def test_class_superclass + index_each("class Foo < Bar; end") do |entry| + assert_equal :Foo, entry.name + assert_equal [[:Foo]], entry.nesting + assert_equal [:Bar], entry.superclass + end + end + + def test_class_path_superclass + index_each("class Foo::Bar < Baz::Qux; end") do |entry| + assert_equal :Bar, entry.name + assert_equal [%i[Foo Bar]], entry.nesting + assert_equal %i[Baz Qux], entry.superclass + end + end + + def test_class_comments + index_each("# comment1\n# comment2\nclass Foo; end") do |entry| + assert_equal :Foo, entry.name + assert_equal ["# comment1", "# comment2"], entry.comments.to_a + end + end + + def test_method + index_each("def foo; end") do |entry| + assert_equal :foo, entry.name + assert_empty entry.nesting + end + end + + def test_method_nested + index_each("class Foo; def foo; end; end") do |entry| + assert_equal :foo, entry.name + assert_equal [[:Foo]], entry.nesting + end + end + + def test_method_comments + index_each("# comment1\n# comment2\ndef foo; end") do |entry| + assert_equal :foo, entry.name + assert_equal ["# comment1", "# comment2"], entry.comments.to_a + end + end + + def test_singleton_method + index_each("def self.foo; end") do |entry| + assert_equal :foo, entry.name + assert_empty entry.nesting + end + end + + def test_singleton_method_nested + index_each("class Foo; def self.foo; end; end") do |entry| + assert_equal :foo, entry.name + assert_equal [[:Foo]], entry.nesting + end + end + + def test_singleton_method_comments + index_each("# comment1\n# comment2\ndef self.foo; end") do |entry| + assert_equal :foo, entry.name + assert_equal ["# comment1", "# comment2"], entry.comments.to_a + end + end + + def test_alias_method + index_each("alias foo bar") do |entry| + assert_equal :foo, entry.name + assert_empty entry.nesting + end + end + + def test_attr_reader + index_each("attr_reader :foo") do |entry| + assert_equal :foo, entry.name + assert_empty entry.nesting + end + end + + def test_attr_writer + index_each("attr_writer :foo") do |entry| + assert_equal :foo=, entry.name + assert_empty entry.nesting + end + end + + def test_attr_accessor + index_each("attr_accessor :foo") do |entry| + assert_equal :foo=, entry.name + assert_empty entry.nesting + end + end + + def test_constant + index_each("FOO = 1") do |entry| + assert_equal :FOO, entry.name + assert_empty entry.nesting + end + end + + def test_this_file + entries = Index.index_file(__FILE__, backend: Index::ParserBackend.new) + + if defined?(RubyVM::InstructionSequence) + entries += Index.index_file(__FILE__, backend: Index::ISeqBackend.new) + end + + entries.map { |entry| entry.comments.to_a } + end + + private + + def index_each(source) + yield Index.index(source, backend: Index::ParserBackend.new).last + + if defined?(RubyVM::InstructionSequence) + yield Index.index(source, backend: Index::ISeqBackend.new).last + end + end + end +end diff --git a/test/interface_test.rb b/test/interface_test.rb deleted file mode 100644 index 49a74e92..00000000 --- a/test/interface_test.rb +++ /dev/null @@ -1,68 +0,0 @@ -# frozen_string_literal: true - -require_relative "test_helper" - -module SyntaxTree - class InterfaceTest < Minitest::Test - ObjectSpace.each_object(Node.singleton_class) do |klass| - next if klass == Node - - define_method(:"test_instantiate_#{klass.name}") do - assert_syntax_tree(instantiate(klass)) - end - end - - Fixtures.each_fixture do |fixture| - define_method(:"test_#{fixture.name}") do - assert_syntax_tree(SyntaxTree.parse(fixture.source)) - end - end - - private - - # This method is supposed to instantiate a new instance of the given class. - # The class is always a descendant from SyntaxTree::Node, so we can make - # certain assumptions about the way the initialize method is set up. If it - # needs to be special-cased, it's done so at the end of this method. - def instantiate(klass) - params = {} - - # Set up all of the keyword parameters for the class. - klass - .instance_method(:initialize) - .parameters - .each { |(type, name)| params[name] = nil if type.start_with?("key") } - - # Set up any default values that have to be arrays. - %i[ - assocs - comments - elements - keywords - locals - optionals - parts - posts - requireds - symbols - values - ].each { |key| params[key] = [] if params.key?(key) } - - # Set up a default location for the node. - params[:location] = Location.fixed(line: 0, char: 0, column: 0) - - case klass.name - when "SyntaxTree::Binary" - klass.new(**params, operator: :+) - when "SyntaxTree::Label" - klass.new(**params, value: "label:") - when "SyntaxTree::RegexpLiteral" - klass.new(**params, ending: "/") - when "SyntaxTree::Statements" - klass.new(nil, **params, body: []) - else - klass.new(**params) - end - end - end -end diff --git a/test/language_server/inlay_hints_test.rb b/test/language_server/inlay_hints_test.rb new file mode 100644 index 00000000..d3741894 --- /dev/null +++ b/test/language_server/inlay_hints_test.rb @@ -0,0 +1,43 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require "syntax_tree/language_server" + +module SyntaxTree + class LanguageServer + class InlayHintsTest < Minitest::Test + def test_assignments_in_parameters + assert_hints(2, "def foo(a = b = c); end") + end + + def test_operators_in_binaries + assert_hints(2, "1 + 2 * 3") + end + + def test_binaries_in_assignments + assert_hints(2, "a = 1 + 2") + end + + def test_nested_ternaries + assert_hints(2, "a ? b : c ? d : e") + end + + def test_bare_rescue + assert_hints(1, "begin; rescue; end") + end + + def test_unary_in_binary + assert_hints(2, "-a + b") + end + + private + + def assert_hints(expected, source) + visitor = InlayHints.new + SyntaxTree.parse(source).accept(visitor) + + assert_equal(expected, visitor.hints.length) + end + end + end +end diff --git a/test/language_server_test.rb b/test/language_server_test.rb new file mode 100644 index 00000000..54455c95 --- /dev/null +++ b/test/language_server_test.rb @@ -0,0 +1,357 @@ +# frozen_string_literal: true + +require_relative "test_helper" +require "syntax_tree/language_server" + +module SyntaxTree + # stree-ignore + class LanguageServerTest < Minitest::Test + class Initialize + attr_reader :id + + def initialize(id) + @id = id + end + + def to_hash + { method: "initialize", id: id } + end + end + + class Shutdown + attr_reader :id + + def initialize(id) + @id = id + end + + def to_hash + { method: "shutdown", id: id } + end + end + + class TextDocumentDidOpen + attr_reader :uri, :text + + def initialize(uri, text) + @uri = uri + @text = text + end + + def to_hash + { + method: "textDocument/didOpen", + params: { textDocument: { uri: uri, text: text } } + } + end + end + + class TextDocumentDidChange + attr_reader :uri, :text + + def initialize(uri, text) + @uri = uri + @text = text + end + + def to_hash + { + method: "textDocument/didChange", + params: { + textDocument: { uri: uri }, + contentChanges: [{ text: text }] + } + } + end + end + + class TextDocumentDidClose + attr_reader :uri + + def initialize(uri) + @uri = uri + end + + def to_hash + { + method: "textDocument/didClose", + params: { textDocument: { uri: uri } } + } + end + end + + class TextDocumentFormatting + attr_reader :id, :uri + + def initialize(id, uri) + @id = id + @uri = uri + end + + def to_hash + { + method: "textDocument/formatting", + id: id, + params: { textDocument: { uri: uri } } + } + end + end + + class TextDocumentInlayHint + attr_reader :id, :uri + + def initialize(id, uri) + @id = id + @uri = uri + end + + def to_hash + { + method: "textDocument/inlayHint", + id: id, + params: { textDocument: { uri: uri } } + } + end + end + + class SyntaxTreeVisualizing + attr_reader :id, :uri + + def initialize(id, uri) + @id = id + @uri = uri + end + + def to_hash + { + method: "syntaxTree/visualizing", + id: id, + params: { textDocument: { uri: uri } } + } + end + end + + def test_formatting + responses = run_server([ + Initialize.new(1), + TextDocumentDidOpen.new("file:///path/to/file.rb", "class Foo; end"), + TextDocumentDidChange.new("file:///path/to/file.rb", "class Bar; end"), + TextDocumentFormatting.new(2, "file:///path/to/file.rb"), + TextDocumentDidClose.new("file:///path/to/file.rb"), + Shutdown.new(3) + ]) + + shape = LanguageServer::Request[[ + { id: 1, result: { capabilities: Hash } }, + { id: 2, result: [{ newText: :any }] }, + { id: 3, result: {} } + ]] + + assert_operator(shape, :===, responses) + assert_equal("class Bar\nend\n", responses.dig(1, :result, 0, :newText)) + end + + def test_formatting_ignore + responses = run_server([ + Initialize.new(1), + TextDocumentDidOpen.new("file:///path/to/file.rb", "class Foo; end"), + TextDocumentFormatting.new(2, "file:///path/to/file.rb"), + Shutdown.new(3) + ], ignore_files: ["path/**/*.rb"]) + + shape = LanguageServer::Request[[ + { id: 1, result: { capabilities: Hash } }, + { id: 2, result: :any }, + { id: 3, result: {} } + ]] + + assert_operator(shape, :===, responses) + assert_nil(responses.dig(1, :result)) + end + + def test_formatting_failure + responses = run_server([ + Initialize.new(1), + TextDocumentDidOpen.new("file:///path/to/file.rb", "<>"), + TextDocumentFormatting.new(2, "file:///path/to/file.rb"), + Shutdown.new(3) + ]) + + shape = LanguageServer::Request[[ + { id: 1, result: { capabilities: Hash } }, + { id: 2, result: :any }, + { id: 3, result: {} } + ]] + + assert_operator(shape, :===, responses) + assert_nil(responses.dig(1, :result)) + end + + def test_formatting_print_width + contents = "#{"a" * 40} + #{"b" * 40}\n" + responses = run_server([ + Initialize.new(1), + TextDocumentDidOpen.new("file:///path/to/file.rb", contents), + TextDocumentFormatting.new(2, "file:///path/to/file.rb"), + TextDocumentDidClose.new("file:///path/to/file.rb"), + Shutdown.new(3) + ], print_width: 100) + + shape = LanguageServer::Request[[ + { id: 1, result: { capabilities: Hash } }, + { id: 2, result: [{ newText: :any }] }, + { id: 3, result: {} } + ]] + + assert_operator(shape, :===, responses) + assert_equal(contents, responses.dig(1, :result, 0, :newText)) + end + + def test_inlay_hint + responses = run_server([ + Initialize.new(1), + TextDocumentDidOpen.new("file:///path/to/file.rb", <<~RUBY), + begin + 1 + 2 * 3 + rescue + end + RUBY + TextDocumentInlayHint.new(2, "file:///path/to/file.rb"), + Shutdown.new(3) + ]) + + shape = LanguageServer::Request[[ + { id: 1, result: { capabilities: Hash } }, + { id: 2, result: :any }, + { id: 3, result: {} } + ]] + + assert_operator(shape, :===, responses) + assert_equal(3, responses.dig(1, :result).size) + end + + def test_inlay_hint_invalid + responses = run_server([ + Initialize.new(1), + TextDocumentDidOpen.new("file:///path/to/file.rb", "<>"), + TextDocumentInlayHint.new(2, "file:///path/to/file.rb"), + Shutdown.new(3) + ]) + + shape = LanguageServer::Request[[ + { id: 1, result: { capabilities: Hash } }, + { id: 2, result: :any }, + { id: 3, result: {} } + ]] + + assert_operator(shape, :===, responses) + assert_equal(0, responses.dig(1, :result).size) + end + + def test_visualizing + responses = run_server([ + Initialize.new(1), + TextDocumentDidOpen.new("file:///path/to/file.rb", "1 + 2"), + SyntaxTreeVisualizing.new(2, "file:///path/to/file.rb"), + Shutdown.new(3) + ]) + + shape = LanguageServer::Request[[ + { id: 1, result: { capabilities: Hash } }, + { id: 2, result: :any }, + { id: 3, result: {} } + ]] + + assert_operator(shape, :===, responses) + assert_equal( + "(program (statements ((binary (int \"1\") + (int \"2\")))))\n", + responses.dig(1, :result) + ) + end + + def test_reading_file + Tempfile.open(%w[test- .rb]) do |file| + file.write("class Foo; end") + file.rewind + + responses = run_server([ + Initialize.new(1), + TextDocumentFormatting.new(2, "file://#{file.path}"), + Shutdown.new(3) + ]) + + shape = LanguageServer::Request[[ + { id: 1, result: { capabilities: Hash } }, + { id: 2, result: [{ newText: :any }] }, + { id: 3, result: {} } + ]] + + assert_operator(shape, :===, responses) + assert_equal("class Foo\nend\n", responses.dig(1, :result, 0, :newText)) + end + end + + def test_bogus_request + assert_raises(ArgumentError) do + run_server([{ method: "textDocument/bogus" }]) + end + end + + def test_clean_shutdown + responses = run_server([Initialize.new(1), Shutdown.new(2)]) + + shape = LanguageServer::Request[[ + { id: 1, result: { capabilities: Hash } }, + { id: 2, result: {} } + ]] + + assert_operator(shape, :===, responses) + end + + def test_file_that_does_not_exist + responses = run_server([ + Initialize.new(1), + TextDocumentFormatting.new(2, "file:///path/to/file.rb"), + Shutdown.new(3) + ]) + + shape = LanguageServer::Request[[ + { id: 1, result: { capabilities: Hash } }, + { id: 2, result: :any }, + { id: 3, result: {} } + ]] + + assert_operator(shape, :===, responses) + end + + private + + def write(content) + request = content.to_hash.merge(jsonrpc: "2.0").to_json + "Content-Length: #{request.bytesize}\r\n\r\n#{request}" + end + + def read(content) + [].tap do |messages| + while (headers = content.gets("\r\n\r\n")) + source = content.read(headers[/Content-Length: (\d+)/i, 1].to_i) + messages << JSON.parse(source, symbolize_names: true) + end + end + end + + def run_server(messages, print_width: DEFAULT_PRINT_WIDTH, ignore_files: []) + input = StringIO.new(messages.map { |message| write(message) }.join) + output = StringIO.new + + LanguageServer.new( + input: input, + output: output, + print_width: print_width, + ignore_files: ignore_files + ).run + + read(output.tap(&:rewind)) + end + end +end diff --git a/test/location_test.rb b/test/location_test.rb new file mode 100644 index 00000000..26831fb1 --- /dev/null +++ b/test/location_test.rb @@ -0,0 +1,28 @@ +# frozen_string_literal: true + +require_relative "test_helper" + +module SyntaxTree + class LocationTest < Minitest::Test + def test_lines + location = Location.fixed(line: 1, char: 0, column: 0) + location = location.to(Location.fixed(line: 3, char: 3, column: 3)) + + assert_equal(1..3, location.lines) + end + + def test_deconstruct + location = Location.fixed(line: 1, char: 0, column: 0) + + assert_equal(1, location.start_line) + assert_equal(0, location.start_char) + assert_equal(0, location.start_column) + end + + def test_deconstruct_keys + location = Location.fixed(line: 1, char: 0, column: 0) + + assert_equal(1, location.start_line) + end + end +end diff --git a/test/mutation_test.rb b/test/mutation_test.rb new file mode 100644 index 00000000..ab9dd019 --- /dev/null +++ b/test/mutation_test.rb @@ -0,0 +1,47 @@ +# frozen_string_literal: true + +require_relative "test_helper" + +module SyntaxTree + class MutationTest < Minitest::Test + def test_mutates_based_on_patterns + source = <<~RUBY + if a = b + c + end + RUBY + + expected = <<~RUBY + if (a = b) + c + end + RUBY + + program = SyntaxTree.parse(source).accept(build_mutation) + assert_equal(expected, SyntaxTree::Formatter.format(source, program)) + end + + private + + def build_mutation + SyntaxTree.mutation do |mutation| + mutation.mutate("IfNode[predicate: Assign | OpAssign]") do |node| + # Get the existing If's predicate node + predicate = node.predicate + + # Create a new predicate node that wraps the existing predicate node + # in parentheses + predicate = + SyntaxTree::Paren.new( + lparen: SyntaxTree::LParen.default, + contents: predicate, + location: predicate.location + ) + + # Return a copy of this node with the new predicate + node.copy(predicate: predicate) + end + end + end + end +end diff --git a/test/node_test.rb b/test/node_test.rb index 6bde39bc..f2706b2c 100644 --- a/test/node_test.rb +++ b/test/node_test.rb @@ -32,7 +32,7 @@ def test___end__ end def test_alias - assert_node(Alias, "alias left right") + assert_node(AliasNode, "alias left right") end def test_aref @@ -60,7 +60,7 @@ def test_arg_paren_heredoc ARGUMENT SOURCE - at = location(lines: 1..3, chars: 6..28) + at = location(lines: 1..3, chars: 6..37) assert_node(ArgParen, source, at: at, &:arguments) end @@ -104,16 +104,18 @@ def test_arg_star end end - def test_args_forward - source = <<~SOURCE - def get(...) - request(:GET, ...) - end - SOURCE + guard_version("2.7.3") do + def test_args_forward + source = <<~SOURCE + def get(...) + request(:GET, ...) + end + SOURCE - at = location(lines: 2..2, chars: 29..32) - assert_node(ArgsForward, source, at: at) do |node| - node.bodystmt.statements.body.first.arguments.arguments.parts.last + at = location(lines: 2..2, chars: 29..32) + assert_node(ArgsForward, source, at: at) do |node| + node.bodystmt.statements.body.first.arguments.arguments.parts.last + end end end @@ -129,7 +131,7 @@ def test_aryptn end SOURCE - at = location(lines: 2..2, chars: 18..47) + at = location(lines: 2..2, chars: 18..48) assert_node(AryPtn, source, at: at) { |node| node.consequent.pattern } end @@ -266,7 +268,7 @@ def test_bodystmt end SOURCE - at = location(lines: 9..9, chars: 5..64) + at = location(lines: 2..9, chars: 5..64) assert_node(BodyStmt, source, at: at, &:bodystmt) end @@ -274,15 +276,18 @@ def test_brace_block source = "method { |variable| variable + 1 }" at = location(chars: 7..34) - assert_node(BraceBlock, source, at: at, &:block) + assert_node(BlockNode, source, at: at, &:block) end def test_break - assert_node(Break, "break value") + at = location(chars: 6..17) + assert_node(Break, "tap { break value }", at: at) do |node| + node.block.bodystmt.body.first + end end def test_call - assert_node(Call, "receiver.message") + assert_node(CallNode, "receiver.message") end def test_case @@ -363,7 +368,7 @@ def test_cvar end def test_def - assert_node(Def, "def method(param) result end") + assert_node(DefNode, "def method(param) result end") end def test_def_paramless @@ -372,18 +377,18 @@ def method end SOURCE - assert_node(Def, source) + assert_node(DefNode, source) end guard_version("3.0.0") do def test_def_endless - assert_node(DefEndless, "def method = result") + assert_node(DefNode, "def method = result") end end guard_version("3.1.0") do def test_def_endless_command - assert_node(DefEndless, "def method = result argument") + assert_node(DefNode, "def method = result argument") end end @@ -392,7 +397,7 @@ def test_defined end def test_defs - assert_node(Defs, "def object.method(param) result end") + assert_node(DefNode, "def object.method(param) result end") end def test_defs_paramless @@ -401,22 +406,22 @@ def object.method end SOURCE - assert_node(Defs, source) + assert_node(DefNode, source) end def test_do_block source = "method do |variable| variable + 1 end" at = location(chars: 7..37) - assert_node(DoBlock, source, at: at, &:block) + assert_node(BlockNode, source, at: at, &:block) end def test_dot2 - assert_node(Dot2, "1..3") + assert_node(RangeNode, "1..3") end def test_dot3 - assert_node(Dot3, "1...3") + assert_node(RangeNode, "1...3") end def test_dyna_symbol @@ -485,7 +490,7 @@ def test_excessed_comma end def test_fcall - assert_node(FCall, "method(argument)") + assert_node(CallNode, "method(argument)") end def test_field @@ -531,7 +536,7 @@ def test_heredoc HEREDOC SOURCE - at = location(lines: 1..3, chars: 0..22) + at = location(lines: 1..3, chars: 0..30) assert_node(Heredoc, source, at: at) end @@ -542,10 +547,21 @@ def test_heredoc_beg HEREDOC SOURCE - at = location(chars: 0..11) + at = location(chars: 0..10) assert_node(HeredocBeg, source, at: at, &:beginning) end + def test_heredoc_end + source = <<~SOURCE + <<~HEREDOC + contents + HEREDOC + SOURCE + + at = location(lines: 3..3, chars: 22..30, columns: 0..8) + assert_node(HeredocEnd, source, at: at, &:ending) + end + def test_hshptn source = <<~SOURCE case value @@ -562,7 +578,7 @@ def test_ident end def test_if - assert_node(If, "if value then else end") + assert_node(IfNode, "if value then else end") end def test_if_op @@ -570,7 +586,7 @@ def test_if_op end def test_if_mod - assert_node(IfMod, "expression if predicate") + assert_node(IfNode, "expression if predicate") end def test_imaginary @@ -634,7 +650,7 @@ def test_lbrace source = "method {}" at = location(chars: 7..8) - assert_node(LBrace, source, at: at) { |node| node.block.lbrace } + assert_node(LBrace, source, at: at) { |node| node.block.opening } end def test_lparen @@ -697,7 +713,10 @@ def test_mrhs_add_star end def test_next - assert_node(Next, "next(value)") + at = location(chars: 6..17) + assert_node(Next, "tap { next(value) }", at: at) do |node| + node.block.bodystmt.body.first + end end def test_op @@ -746,10 +765,9 @@ def test_program program = parser.parse refute(parser.error?) - case program - in statements: { body: [statement] } - assert_kind_of(VCall, statement) - end + statements = program.statements.body + assert_equal 1, statements.size + assert_kind_of(VCall, statements.first) json = JSON.parse(program.to_json) io = StringIO.new @@ -774,7 +792,9 @@ def test_rational end def test_redo - assert_node(Redo, "redo") + assert_node(Redo, "tap { redo }", at: location(chars: 6..10)) do |node| + node.block.bodystmt.body.first + end end def test_regexp_literal @@ -821,15 +841,18 @@ def test_rest_param end def test_retry - assert_node(Retry, "retry") + at = location(chars: 15..20) + assert_node(Retry, "begin; rescue; retry; end", at: at) do |node| + node.bodystmt.rescue_clause.statements.body.first + end end def test_return - assert_node(Return, "return value") + assert_node(ReturnNode, "return value") end def test_return0 - assert_node(Return0, "return") + assert_node(ReturnNode, "return") end def test_sclass @@ -911,23 +934,23 @@ def test_undef end def test_unless - assert_node(Unless, "unless value then else end") + assert_node(UnlessNode, "unless value then else end") end def test_unless_mod - assert_node(UnlessMod, "expression unless predicate") + assert_node(UnlessNode, "expression unless predicate") end def test_until - assert_node(Until, "until value do end") + assert_node(UntilNode, "until value do end") end def test_until_mod - assert_node(UntilMod, "expression until predicate") + assert_node(UntilNode, "expression until predicate") end def test_var_alias - assert_node(VarAlias, "alias $new $old") + assert_node(AliasNode, "alias $new $old") end def test_var_field @@ -937,8 +960,8 @@ def test_var_field guard_version("3.1.0") do def test_pinned_var_ref - source = "foo in ^bar" - at = location(chars: 7..11) + source = "bar = 1; foo in ^bar" + at = location(chars: 16..20) assert_node(PinnedVarRef, source, at: at, &:pattern) end @@ -969,11 +992,11 @@ def test_when end def test_while - assert_node(While, "while value do end") + assert_node(WhileNode, "while value do end") end def test_while_mod - assert_node(WhileMod, "expression while predicate") + assert_node(WhileNode, "expression while predicate") end def test_word @@ -996,16 +1019,22 @@ def test_xstring_heredoc HEREDOC SOURCE - at = location(lines: 1..3, chars: 0..18) + at = location(lines: 1..3, chars: 0..26) assert_node(Heredoc, source, at: at) end def test_yield - assert_node(Yield, "yield value") + at = location(lines: 2..2, chars: 10..21) + assert_node(YieldNode, "def foo\n yield value\nend\n", at: at) do |node| + node.bodystmt.statements.body.first + end end def test_yield0 - assert_node(Yield0, "yield") + at = location(lines: 2..2, chars: 10..15) + assert_node(YieldNode, "def foo\n yield\nend\n", at: at) do |node| + node.bodystmt.statements.body.first + end end def test_zsuper @@ -1032,6 +1061,356 @@ def test_multibyte_column_positions assert_node(Command, source, at: at) end + def test_root_class_raises_not_implemented_errors + { + accept: [nil], + child_nodes: [], + deconstruct: [], + deconstruct_keys: [[]], + format: [nil] + }.each do |method, arguments| + assert_raises(NotImplementedError) do + Node.new.public_send(method, *arguments) + end + end + end + + def test_arity_no_args + source = <<~SOURCE + def foo + end + SOURCE + + at = location(chars: 0..11, columns: 0..3, lines: 1..2) + assert_node(DefNode, source, at: at) do |node| + assert_equal(0..0, node.arity) + node + end + end + + def test_arity_positionals + source = <<~SOURCE + def foo(a, b = 1) + end + SOURCE + + at = location(chars: 0..21, columns: 0..3, lines: 1..2) + assert_node(DefNode, source, at: at) do |node| + assert_equal(1..2, node.arity) + node + end + end + + def test_arity_rest + source = <<~SOURCE + def foo(a, *b) + end + SOURCE + + at = location(chars: 0..18, columns: 0..3, lines: 1..2) + assert_node(DefNode, source, at: at) do |node| + assert_equal(1.., node.arity) + node + end + end + + def test_arity_keyword_rest + source = <<~SOURCE + def foo(a, **b) + end + SOURCE + + at = location(chars: 0..19, columns: 0..3, lines: 1..2) + assert_node(DefNode, source, at: at) do |node| + assert_equal(1.., node.arity) + node + end + end + + def test_arity_keywords + source = <<~SOURCE + def foo(a:, b: 1) + end + SOURCE + + at = location(chars: 0..21, columns: 0..3, lines: 1..2) + assert_node(DefNode, source, at: at) do |node| + assert_equal(1..2, node.arity) + node + end + end + + def test_arity_mixed + source = <<~SOURCE + def foo(a, b = 1, c:, d: 2) + end + SOURCE + + at = location(chars: 0..31, columns: 0..3, lines: 1..2) + assert_node(DefNode, source, at: at) do |node| + assert_equal(2..4, node.arity) + node + end + end + + guard_version("2.7.3") do + def test_arity_arg_forward + source = <<~SOURCE + def foo(...) + end + SOURCE + + at = location(chars: 0..16, columns: 0..3, lines: 1..2) + assert_node(DefNode, source, at: at) do |node| + assert_equal(0.., node.arity) + node + end + end + end + + guard_version("3.0.0") do + def test_arity_positional_and_arg_forward + source = <<~SOURCE + def foo(a, ...) + end + SOURCE + + at = location(chars: 0..19, columns: 0..3, lines: 1..2) + assert_node(DefNode, source, at: at) do |node| + assert_equal(1.., node.arity) + node + end + end + end + + def test_arity_no_parenthesis + source = <<~SOURCE + def foo a, b = 1 + end + SOURCE + + at = location(chars: 0..20, columns: 0..3, lines: 1..2) + assert_node(DefNode, source, at: at) do |node| + assert_equal(1..2, node.arity) + node + end + end + + def test_block_arity_positionals + source = <<~SOURCE + [].each do |a, b, c| + end + SOURCE + + at = location(chars: 8..24, columns: 8..3, lines: 1..2) + assert_node(BlockNode, source, at: at) do |node| + block = node.block + assert_equal(3..3, block.arity) + block + end + end + + def test_block_arity_with_optional + source = <<~SOURCE + [].each do |a, b = 1| + end + SOURCE + + at = location(chars: 8..25, columns: 8..3, lines: 1..2) + assert_node(BlockNode, source, at: at) do |node| + block = node.block + assert_equal(1..2, block.arity) + block + end + end + + def test_block_arity_with_optional_keyword + source = <<~SOURCE + [].each do |a, b: 2| + end + SOURCE + + at = location(chars: 8..24, columns: 8..3, lines: 1..2) + assert_node(BlockNode, source, at: at) do |node| + block = node.block + assert_equal(1..2, block.arity) + block + end + end + + def test_call_node_arity_positional_arguments + source = <<~SOURCE + foo(1, 2, 3) + SOURCE + + at = location(chars: 0..12, columns: 0..3, lines: 1..1) + assert_node(CallNode, source, at: at) do |node| + assert_equal(3, node.arity) + node + end + end + + def test_call_node_arity_keyword_arguments + source = <<~SOURCE + foo(bar, something: 123) + SOURCE + + at = location(chars: 0..24, columns: 0..24, lines: 1..1) + assert_node(CallNode, source, at: at) do |node| + assert_equal(2, node.arity) + node + end + end + + def test_call_node_arity_splat_arguments + source = <<~SOURCE + foo(*bar) + SOURCE + + at = location(chars: 0..9, columns: 0..9, lines: 1..1) + assert_node(CallNode, source, at: at) do |node| + assert_equal(Float::INFINITY, node.arity) + node + end + end + + def test_call_node_arity_keyword_rest_arguments + source = <<~SOURCE + foo(**bar) + SOURCE + + at = location(chars: 0..10, columns: 0..10, lines: 1..1) + assert_node(CallNode, source, at: at) do |node| + assert_equal(Float::INFINITY, node.arity) + node + end + end + + guard_version("2.7.3") do + def test_call_node_arity_arg_forward_arguments + source = <<~SOURCE + def foo(...) + bar(...) + end + SOURCE + + at = location(chars: 15..23, columns: 2..10, lines: 2..2) + assert_node(CallNode, source, at: at) do |node| + call = node.bodystmt.statements.body.first + assert_equal(Float::INFINITY, call.arity) + call + end + end + end + + def test_command_arity_positional_arguments + source = <<~SOURCE + foo 1, 2, 3 + SOURCE + + at = location(chars: 0..11, columns: 0..3, lines: 1..1) + assert_node(Command, source, at: at) do |node| + assert_equal(3, node.arity) + node + end + end + + def test_command_arity_keyword_arguments + source = <<~SOURCE + foo bar, something: 123 + SOURCE + + at = location(chars: 0..23, columns: 0..23, lines: 1..1) + assert_node(Command, source, at: at) do |node| + assert_equal(2, node.arity) + node + end + end + + def test_command_arity_splat_arguments + source = <<~SOURCE + foo *bar + SOURCE + + at = location(chars: 0..8, columns: 0..8, lines: 1..1) + assert_node(Command, source, at: at) do |node| + assert_equal(Float::INFINITY, node.arity) + node + end + end + + def test_command_arity_keyword_rest_arguments + source = <<~SOURCE + foo **bar + SOURCE + + at = location(chars: 0..9, columns: 0..9, lines: 1..1) + assert_node(Command, source, at: at) do |node| + assert_equal(Float::INFINITY, node.arity) + node + end + end + + def test_command_call_arity_positional_arguments + source = <<~SOURCE + object.foo 1, 2, 3 + SOURCE + + at = location(chars: 0..18, columns: 0..3, lines: 1..1) + assert_node(CommandCall, source, at: at) do |node| + assert_equal(3, node.arity) + node + end + end + + def test_command_call_arity_keyword_arguments + source = <<~SOURCE + object.foo bar, something: 123 + SOURCE + + at = location(chars: 0..30, columns: 0..30, lines: 1..1) + assert_node(CommandCall, source, at: at) do |node| + assert_equal(2, node.arity) + node + end + end + + def test_command_call_arity_splat_arguments + source = <<~SOURCE + object.foo *bar + SOURCE + + at = location(chars: 0..15, columns: 0..15, lines: 1..1) + assert_node(CommandCall, source, at: at) do |node| + assert_equal(Float::INFINITY, node.arity) + node + end + end + + def test_command_call_arity_keyword_rest_arguments + source = <<~SOURCE + object.foo **bar + SOURCE + + at = location(chars: 0..16, columns: 0..16, lines: 1..1) + assert_node(CommandCall, source, at: at) do |node| + assert_equal(Float::INFINITY, node.arity) + node + end + end + + def test_vcall_arity + source = <<~SOURCE + foo + SOURCE + + at = location(chars: 0..3, columns: 0..3, lines: 1..1) + assert_node(VCall, source, at: at) do |node| + assert_equal(0, node.arity) + node + end + end + private def location(lines: 1..1, chars: 0..0, columns: 0..0) diff --git a/test/parser_test.rb b/test/parser_test.rb index 8aadbfc2..169d5b46 100644 --- a/test/parser_test.rb +++ b/test/parser_test.rb @@ -30,5 +30,97 @@ def test_parses_ripper_methods # Finally, assert that we have no remaining events. assert_empty(events) end + + def test_errors_on_missing_token_with_location + error = assert_raises(Parser::ParseError) { SyntaxTree.parse("f+\"foo") } + assert_equal(3, error.column) + end + + def test_errors_on_missing_end_with_location + error = assert_raises(Parser::ParseError) { SyntaxTree.parse("foo do 1") } + assert_equal(4, error.column) + end + + def test_errors_on_missing_regexp_ending + error = + assert_raises(Parser::ParseError) { SyntaxTree.parse("a =~ /foo") } + + assert_equal(6, error.column) + end + + def test_errors_on_missing_token_without_location + assert_raises(Parser::ParseError) { SyntaxTree.parse(":\"foo") } + end + + def test_handles_strings_with_non_terminated_embedded_expressions + assert_raises(Parser::ParseError) { SyntaxTree.parse('"#{"') } + end + + def test_errors_on_else_missing_two_ends + assert_raises(Parser::ParseError) { SyntaxTree.parse(<<~RUBY) } + def foo + if something + else + call do + end + RUBY + end + + def test_does_not_choke_on_invalid_characters_in_source_string + SyntaxTree.parse(<<~RUBY) + # comment + # comment + __END__ + \xC5 + RUBY + end + + def test_lambda_vars_with_parameters_location + tree = SyntaxTree.parse(<<~RUBY) + # comment + # comment + ->(_i; a) { a } + RUBY + + local_location = + tree.statements.body.last.params.contents.locals.first.location + + assert_equal(3, local_location.start_line) + assert_equal(3, local_location.end_line) + assert_equal(7, local_location.start_column) + assert_equal(8, local_location.end_column) + end + + def test_lambda_vars_location + tree = SyntaxTree.parse(<<~RUBY) + # comment + # comment + ->(; a) { a } + RUBY + + local_location = + tree.statements.body.last.params.contents.locals.first.location + + assert_equal(3, local_location.start_line) + assert_equal(3, local_location.end_line) + assert_equal(5, local_location.start_column) + assert_equal(6, local_location.end_column) + end + + def test_multiple_lambda_vars_location + tree = SyntaxTree.parse(<<~RUBY) + # comment + # comment + ->(; a, b, c) { a } + RUBY + + local_location = + tree.statements.body.last.params.contents.locals.last.location + + assert_equal(3, local_location.start_line) + assert_equal(3, local_location.end_line) + assert_equal(11, local_location.start_column) + assert_equal(12, local_location.end_column) + end end end diff --git a/test/plugin/disable_auto_ternary_test.rb b/test/plugin/disable_auto_ternary_test.rb new file mode 100644 index 00000000..b2af9d35 --- /dev/null +++ b/test/plugin/disable_auto_ternary_test.rb @@ -0,0 +1,32 @@ +# frozen_string_literal: true + +require_relative "../test_helper" + +module SyntaxTree + class DisableTernaryTest < Minitest::Test + def test_short_if_else_unchanged + assert_format(<<~RUBY) + if true + 1 + else + 2 + end + RUBY + end + + def test_short_ternary_unchanged + assert_format("true ? 1 : 2\n") + end + + private + + def assert_format(expected, source = expected) + options = Formatter::Options.new(disable_auto_ternary: true) + formatter = Formatter.new(source, [], options: options) + SyntaxTree.parse(source).format(formatter) + + formatter.flush + assert_equal(expected, formatter.output.join) + end + end +end diff --git a/test/formatter/single_quotes_test.rb b/test/plugin/single_quotes_test.rb similarity index 65% rename from test/formatter/single_quotes_test.rb rename to test/plugin/single_quotes_test.rb index 8bf82cb8..b1359ac7 100644 --- a/test/formatter/single_quotes_test.rb +++ b/test/plugin/single_quotes_test.rb @@ -1,18 +1,21 @@ # frozen_string_literal: true require_relative "../test_helper" -require "syntax_tree/formatter/single_quotes" module SyntaxTree - class Formatter - class TestFormatter < Formatter - prepend Formatter::SingleQuotes - end - + class SingleQuotesTest < Minitest::Test def test_empty_string_literal assert_format("''\n", "\"\"") end + def test_character_literal_with_double_quote + assert_format("'\"'\n", "?\"") + end + + def test_character_literal_with_singlee_quote + assert_format("'\\''\n", "?'") + end + def test_string_literal assert_format("'string'\n", "\"string\"") end @@ -25,6 +28,10 @@ def test_dyna_symbol assert_format(":'symbol'\n", ":\"symbol\"") end + def test_single_quote_in_string + assert_format("\"str'ing\"\n") + end + def test_label assert_format( "{ foo => foo, :'bar' => bar }\n", @@ -35,7 +42,8 @@ def test_label private def assert_format(expected, source = expected) - formatter = TestFormatter.new(source, []) + options = Formatter::Options.new(quote: "'") + formatter = Formatter.new(source, [], options: options) SyntaxTree.parse(source).format(formatter) formatter.flush diff --git a/test/plugin/trailing_comma_test.rb b/test/plugin/trailing_comma_test.rb new file mode 100644 index 00000000..7f6e49a8 --- /dev/null +++ b/test/plugin/trailing_comma_test.rb @@ -0,0 +1,91 @@ +# frozen_string_literal: true + +require_relative "../test_helper" + +module SyntaxTree + class TrailingCommaTest < Minitest::Test + def test_arg_paren_flat + assert_format("foo(a)\n") + end + + def test_arg_paren_break + assert_format(<<~EXPECTED, <<~SOURCE) + foo( + #{"a" * 80}, + ) + EXPECTED + foo(#{"a" * 80}) + SOURCE + end + + def test_arg_paren_block + assert_format(<<~EXPECTED, <<~SOURCE) + foo( + &#{"a" * 80} + ) + EXPECTED + foo(&#{"a" * 80}) + SOURCE + end + + def test_arg_paren_command + assert_format(<<~EXPECTED, <<~SOURCE) + foo( + bar #{"a" * 80} + ) + EXPECTED + foo(bar #{"a" * 80}) + SOURCE + end + + def test_arg_paren_command_call + assert_format(<<~EXPECTED, <<~SOURCE) + foo( + bar.baz #{"a" * 80} + ) + EXPECTED + foo(bar.baz #{"a" * 80}) + SOURCE + end + + def test_array_literal_flat + assert_format("[a]\n") + end + + def test_array_literal_break + assert_format(<<~EXPECTED, <<~SOURCE) + [ + #{"a" * 80}, + ] + EXPECTED + [#{"a" * 80}] + SOURCE + end + + def test_hash_literal_flat + assert_format("{ a: a }\n") + end + + def test_hash_literal_break + assert_format(<<~EXPECTED, <<~SOURCE) + { + a: + #{"a" * 80}, + } + EXPECTED + { a: #{"a" * 80} } + SOURCE + end + + private + + def assert_format(expected, source = expected) + options = Formatter::Options.new(trailing_comma: true) + formatter = Formatter.new(source, [], options: options) + SyntaxTree.parse(source).format(formatter) + + formatter.flush + assert_equal(expected, formatter.output.join) + end + end +end diff --git a/test/quotes_test.rb b/test/quotes_test.rb new file mode 100644 index 00000000..2e2e0243 --- /dev/null +++ b/test/quotes_test.rb @@ -0,0 +1,15 @@ +# frozen_string_literal: true + +require_relative "test_helper" + +module SyntaxTree + class QuotesTest < Minitest::Test + def test_normalize + content = "'aaa' \"bbb\" \\'ccc\\' \\\"ddd\\\"" + enclosing = "\"" + + result = Quotes.normalize(content, enclosing) + assert_equal "'aaa' \\\"bbb\\\" \\'ccc\\' \\\"ddd\\\"", result + end + end +end diff --git a/test/ractor_test.rb b/test/ractor_test.rb new file mode 100644 index 00000000..7e0201ca --- /dev/null +++ b/test/ractor_test.rb @@ -0,0 +1,52 @@ +# frozen_string_literal: true + +# Don't run this test if we're in a version of Ruby that doesn't have Ractors. +return unless defined?(Ractor) + +# Don't run this version on Ruby 3.0.0. For some reason it just hangs within the +# main Ractor waiting for this children. Not going to investigate it since it's +# already been fixed in 3.1.0. +return if Gem::Version.new(RUBY_VERSION) < Gem::Version.new("3.1.0") + +require_relative "test_helper" + +module SyntaxTree + class RactorTest < Minitest::Test + def test_formatting + ractors = + filepaths.map do |filepath| + # At the moment we have to parse in the main Ractor because Ripper is + # not marked as a Ractor-safe extension. + source = SyntaxTree.read(filepath) + program = SyntaxTree.parse(source) + + with_silenced_warnings do + Ractor.new(source, program, name: filepath) do |source, program| + SyntaxTree::Formatter.format(source, program) + end + end + end + + ractors.each { |ractor| assert_kind_of String, ractor.take } + end + + private + + def filepaths + Dir.glob(File.expand_path("../lib/syntax_tree/plugin/*.rb", __dir__)) + end + + # Ractors still warn about usage, so I'm disabling that warning here just to + # have clean test output. + def with_silenced_warnings + previous = $VERBOSE + + begin + $VERBOSE = nil + yield + ensure + $VERBOSE = previous + end + end + end +end diff --git a/test/rake_test.rb b/test/rake_test.rb new file mode 100644 index 00000000..90662519 --- /dev/null +++ b/test/rake_test.rb @@ -0,0 +1,57 @@ +# frozen_string_literal: true + +require_relative "test_helper" +require "syntax_tree/rake_tasks" + +module SyntaxTree + module Rake + class CheckTaskTest < Minitest::Test + Invocation = Struct.new(:args) + + def test_task_command + assert_raises(NotImplementedError) { Task.new.command } + end + + def test_check_task + source_files = "{app,config,lib}/**/*.rb" + + CheckTask.new do |t| + t.source_files = source_files + t.print_width = 100 + t.target_ruby_version = Gem::Version.new("2.6.0") + end + + expected = [ + "check", + "--print-width=100", + "--target-ruby-version=2.6.0", + source_files + ] + + invocation = invoke("stree:check") + assert_equal(expected, invocation.args) + end + + def test_write_task + source_files = "{app,config,lib}/**/*.rb" + WriteTask.new { |t| t.source_files = source_files } + + invocation = invoke("stree:write") + assert_equal(["write", source_files], invocation.args) + end + + private + + def invoke(task_name) + invocation = nil + stub = ->(args) { invocation = Invocation.new(args) } + + assert_raises SystemExit do + SyntaxTree::CLI.stub(:run, stub) { ::Rake::Task[task_name].invoke } + end + + invocation + end + end + end +end diff --git a/test/search_test.rb b/test/search_test.rb new file mode 100644 index 00000000..9f7d89b8 --- /dev/null +++ b/test/search_test.rb @@ -0,0 +1,127 @@ +# frozen_string_literal: true + +require_relative "test_helper" + +module SyntaxTree + class SearchTest < Minitest::Test + def test_search_invalid_syntax + assert_raises(Pattern::CompilationError) { search("", "<>") } + end + + def test_search_invalid_constant + assert_raises(Pattern::CompilationError) { search("", "Foo") } + end + + def test_search_invalid_nested_constant + assert_raises(Pattern::CompilationError) { search("", "Foo::Bar") } + end + + def test_search_regexp_with_interpolation + assert_raises(Pattern::CompilationError) { search("", "/\#{foo}/") } + end + + def test_search_string_with_interpolation + assert_raises(Pattern::CompilationError) { search("", '"#{foo}"') } + end + + def test_search_symbol_with_interpolation + assert_raises(Pattern::CompilationError) { search("", ":\"\#{foo}\"") } + end + + def test_search_invalid_node + assert_raises(Pattern::CompilationError) { search("", "Int[^foo]") } + end + + def test_search_self + assert_raises(Pattern::CompilationError) { search("", "self") } + end + + def test_search_array_pattern_no_constant + results = search("1 + 2", "[Int, Int]") + + assert_equal 1, results.length + end + + def test_search_array_pattern + results = search("1 + 2", "Binary[Int, Int]") + + assert_equal 1, results.length + end + + def test_search_binary_or + results = search("Foo + Bar + 1", "VarRef | Int") + + assert_equal 3, results.length + assert_equal "1", results.min_by { |node| node.class.name }.value + end + + def test_search_const + results = search("Foo + Bar + Baz", "VarRef") + + assert_equal 3, results.length + assert_equal %w[Bar Baz Foo], results.map { |node| node.value.value }.sort + end + + def test_search_object_const + results = search("1 + 2 + 3", "Int[value: String]") + + assert_equal 3, results.length + end + + def test_search_syntax_tree_const + results = search("Foo + Bar + Baz", "SyntaxTree::VarRef") + + assert_equal 3, results.length + end + + def test_search_hash_pattern_no_constant + results = search("Foo + Bar + Baz", "{ value: Const }") + + assert_equal 3, results.length + end + + def test_search_hash_pattern_string + results = search("Foo + Bar + Baz", "VarRef[value: Const[value: 'Foo']]") + + assert_equal 1, results.length + assert_equal "Foo", results.first.value.value + end + + def test_search_hash_pattern_regexp + results = search("Foo + Bar + Baz", "VarRef[value: Const[value: /^Ba/]]") + + assert_equal 2, results.length + assert_equal %w[Bar Baz], results.map { |node| node.value.value }.sort + end + + def test_search_string_empty + results = search("", "''") + + assert_empty results + end + + def test_search_symbol_empty + results = search("", ":''") + + assert_empty results + end + + def test_search_symbol_plain + results = search("1 + 2", "Binary[operator: :'+']") + + assert_equal 1, results.length + end + + def test_search_symbol + results = search("1 + 2", "Binary[operator: :+]") + + assert_equal 1, results.length + end + + private + + def search(source, query) + SyntaxTree.search(source, query).to_a + end + end +end diff --git a/test/syntax_tree_test.rb b/test/syntax_tree_test.rb index 3d5ae90e..27aa6851 100644 --- a/test/syntax_tree_test.rb +++ b/test/syntax_tree_test.rb @@ -22,13 +22,22 @@ def method # comment SOURCE bodystmt = SyntaxTree.parse(source).statements.body.first.bodystmt - assert_equal(20, bodystmt.location.start_char) + assert_equal(20, bodystmt.start_char) end def test_parse_error assert_raises(Parser::ParseError) { SyntaxTree.parse("<>") } end + def test_marshalable + node = SyntaxTree.parse("1 + 2") + assert_operator(node, :===, Marshal.load(Marshal.dump(node))) + end + + def test_maxwidth_format + assert_equal("foo +\n bar\n", SyntaxTree.format("foo + bar", 5)) + end + def test_read source = SyntaxTree.read(__FILE__) assert_equal(Encoding.default_external, source.encoding) diff --git a/test/test_helper.rb b/test/test_helper.rb index ce75aeb2..787f819d 100644 --- a/test/test_helper.rb +++ b/test/test_helper.rb @@ -1,12 +1,11 @@ # frozen_string_literal: true -require "simplecov" -SimpleCov.start do - add_filter("prettyprint.rb") - - unless ENV["CI"] - add_filter("accept_methods_test.rb") - add_filter("idempotency_test.rb") +unless RUBY_ENGINE == "truffleruby" + require "simplecov" + SimpleCov.start do + add_filter("idempotency_test.rb") unless ENV["CI"] + add_group("lib", "lib") + add_group("test", "test") end end @@ -14,6 +13,41 @@ require "syntax_tree" require "syntax_tree/cli" +unless RUBY_ENGINE == "truffleruby" + # Here we are going to establish type verification whenever a new node is + # created. We do this through the reflection module, which in turn parses the + # source code of the node classes. + require "syntax_tree/reflection" + SyntaxTree::Reflection.nodes.each do |name, node| + next if name == :Statements + + clazz = SyntaxTree.const_get(name) + parameters = clazz.instance_method(:initialize).parameters + + # First, verify that all of the parameters listed in the list of attributes. + # If there are any parameters that aren't listed in the attributes, then + # something went wrong with the parsing in the reflection module. + raise unless (parameters.map(&:last) - node.attributes.keys).empty? + + # Now we're going to use an alias chain to redefine the initialize method to + # include type checking. + clazz.alias_method(:initialize_without_verify, :initialize) + clazz.define_method(:initialize) do |**kwargs| + kwargs.each do |kwarg, value| + attribute = node.attributes.fetch(kwarg) + + unless attribute.type === value + raise TypeError, + "invalid type for #{name}##{kwarg}, expected " \ + "#{attribute.type.inspect}, got #{value.inspect}" + end + end + + initialize_without_verify(**kwargs) + end + end +end + require "json" require "tempfile" require "pp" @@ -28,7 +62,7 @@ def initialize @called = nil end - def method_missing(called, ...) + def method_missing(called, *, **) @called = called end end @@ -63,6 +97,9 @@ def assert_syntax_tree(node) refute_includes(pretty, "#<") assert_includes(pretty, type) + # Assert that we can get back a new tree by using the mutation visitor. + assert_operator node, :===, node.accept(MutationVisitor.new) + # Serialize the node to JSON, parse it back out, and assert that we have # found the expected type. json = node.to_json @@ -79,11 +116,11 @@ def assert_syntax_tree(node) end RUBY end + + Minitest::Test.include(self) end end -Minitest::Test.include(SyntaxTree::Assertions) - # There are a bunch of fixtures defined in test/fixtures. They exercise every # possible combination of syntax that leads to variations in the types of nodes. # They are used for testing various parts of Syntax Tree, including formatting, @@ -131,9 +168,8 @@ def self.each_fixture # If there's a comment starting with >= that starts after the % that # delineates the test, then we're going to check if the version # satisfies that constraint. - if comment&.start_with?(">=") && - (ruby_version < Gem::Version.new(comment.split[1])) - next + if comment&.start_with?(">=") + next if ruby_version < Gem::Version.new(comment.split[1]) end name = :"#{fixture}_#{index}" diff --git a/test/visitor_test.rb b/test/visitor_test.rb index 5e4f134d..d9637df0 100644 --- a/test/visitor_test.rb +++ b/test/visitor_test.rb @@ -30,19 +30,44 @@ def initialize @visited_nodes = [] end - visit_method def visit_class(node) - @visited_nodes << node.constant.constant.value - super + visit_methods do + def visit_class(node) + @visited_nodes << node.constant.constant.value + super + end + + def visit_def(node) + @visited_nodes << node.name.value + end end + end - visit_method def visit_def(node) - @visited_nodes << node.name.value + if defined?(DidYouMean.correct_error) + def test_visit_method_correction + error = assert_raises { Visitor.visit_method(:visit_binar) } + message = + if Exception.method_defined?(:detailed_message) + error.detailed_message + else + error.message + end + + assert_match(/visit_binary/, message) end end - def test_visit_method_correction - error = assert_raises { Visitor.visit_method(:visit_binar) } - assert_match(/visit_binary/, error.message) + class VisitMethodsTestVisitor < BasicVisitor + end + + def test_visit_methods + VisitMethodsTestVisitor.visit_methods do + assert_raises(BasicVisitor::VisitMethodError) do + # In reality, this would be a method defined using the def keyword, + # but we're using method_added here to trigger the checker so that we + # aren't defining methods dynamically in the test suite. + VisitMethodsTestVisitor.method_added(:visit_foo) + end + end end end end diff --git a/test/with_scope_test.rb b/test/with_scope_test.rb new file mode 100644 index 00000000..6b48d17d --- /dev/null +++ b/test/with_scope_test.rb @@ -0,0 +1,567 @@ +# frozen_string_literal: true + +require_relative "test_helper" + +module SyntaxTree + class WithScopeTest < Minitest::Test + class Collector < Visitor + prepend WithScope + + attr_reader :arguments, :variables + + def initialize + @arguments = {} + @variables = {} + end + + def self.collect(source) + new.tap { SyntaxTree.parse(source).accept(_1) } + end + + visit_methods do + def visit_ident(node) + value = node.value.delete_suffix(":") + local = current_scope.find_local(node.value) + + case local&.type + when :argument + arguments[[current_scope.id, value]] = local + when :variable + variables[[current_scope.id, value]] = local + end + end + + def visit_label(node) + value = node.value.delete_suffix(":") + local = current_scope.find_local(value) + + if local&.type == :argument + arguments[[current_scope.id, value]] = node + end + end + + def visit_vcall(node) + local = current_scope.find_local(node.value) + variables[[current_scope.id, value]] = local if local + + super + end + end + end + + def test_collecting_simple_variables + collector = Collector.collect(<<~RUBY) + def foo + a = 1 + a + end + RUBY + + assert_equal(1, collector.variables.length) + assert_variable(collector, "a", definitions: [2], usages: [3]) + end + + def test_collecting_aref_variables + collector = Collector.collect(<<~RUBY) + def foo + a = [] + a[1] + end + RUBY + + assert_equal(1, collector.variables.length) + assert_variable(collector, "a", definitions: [2], usages: [3]) + end + + def test_collecting_multi_assign_variables + collector = Collector.collect(<<~RUBY) + def foo + a, b = [1, 2] + puts a + puts b + end + RUBY + + assert_equal(2, collector.variables.length) + assert_variable(collector, "a", definitions: [2], usages: [3]) + assert_variable(collector, "b", definitions: [2], usages: [4]) + end + + def test_collecting_pattern_matching_variables + collector = Collector.collect(<<~RUBY) + def foo + case [1, 2] + in Integer => a, Integer + puts a + end + end + RUBY + + # There are two occurrences, one on line 3 for pinning and one on line 4 + # for reference + assert_equal(1, collector.variables.length) + assert_variable(collector, "a", definitions: [3], usages: [4]) + end + + def test_collecting_pinned_variables + collector = Collector.collect(<<~RUBY) + def foo + a = 18 + case [1, 2] + in ^a, *rest + puts a + puts rest + end + end + RUBY + + assert_equal(2, collector.variables.length) + assert_variable(collector, "a", definitions: [2], usages: [4, 5]) + assert_variable(collector, "rest", definitions: [4], usages: [6]) + end + + if RUBY_VERSION >= "3.1" + def test_collecting_one_line_pattern_matching_variables + collector = Collector.collect(<<~RUBY) + def foo + [1] => a + puts a + end + RUBY + + assert_equal(1, collector.variables.length) + assert_variable(collector, "a", definitions: [2], usages: [3]) + end + + def test_collecting_endless_method_arguments + collector = Collector.collect(<<~RUBY) + def foo(a) = puts a + RUBY + + assert_equal(1, collector.arguments.length) + assert_argument(collector, "a", definitions: [1], usages: [1]) + end + end + + def test_collecting_method_arguments + collector = Collector.collect(<<~RUBY) + def foo(a) + puts a + end + RUBY + + assert_equal(1, collector.arguments.length) + assert_argument(collector, "a", definitions: [1], usages: [2]) + end + + def test_collecting_methods_with_destructured_post_arguments + collector = Collector.collect(<<~RUBY) + def foo(optional = 1, (bin, bag)) + end + RUBY + + assert_equal(3, collector.arguments.length) + assert_argument(collector, "optional", definitions: [1], usages: []) + assert_argument(collector, "bin", definitions: [1], usages: []) + assert_argument(collector, "bag", definitions: [1], usages: []) + end + + def test_collecting_methods_with_desctructured_post_using_splat + collector = Collector.collect(<<~RUBY) + def foo(optional = 1, (bin, bag, *)) + end + RUBY + + assert_equal(3, collector.arguments.length) + assert_argument(collector, "optional", definitions: [1], usages: []) + assert_argument(collector, "bin", definitions: [1], usages: []) + assert_argument(collector, "bag", definitions: [1], usages: []) + end + + def test_collecting_methods_with_nested_desctructured + collector = Collector.collect(<<~RUBY) + def foo(optional = 1, (bin, (bag))) + end + RUBY + + assert_equal(3, collector.arguments.length) + assert_argument(collector, "optional", definitions: [1], usages: []) + assert_argument(collector, "bin", definitions: [1], usages: []) + assert_argument(collector, "bag", definitions: [1], usages: []) + end + + def test_collecting_singleton_method_arguments + collector = Collector.collect(<<~RUBY) + def self.foo(a) + puts a + end + RUBY + + assert_equal(1, collector.arguments.length) + assert_argument(collector, "a", definitions: [1], usages: [2]) + end + + def test_collecting_method_arguments_all_types + collector = Collector.collect(<<~RUBY) + def foo(a, b = 1, *c, d, e: 1, **f, &block) + puts a + puts b + puts c + puts d + puts e + puts f + block.call + end + RUBY + + assert_equal(7, collector.arguments.length) + assert_argument(collector, "a", definitions: [1], usages: [2]) + assert_argument(collector, "b", definitions: [1], usages: [3]) + assert_argument(collector, "c", definitions: [1], usages: [4]) + assert_argument(collector, "d", definitions: [1], usages: [5]) + assert_argument(collector, "e", definitions: [1], usages: [6]) + assert_argument(collector, "f", definitions: [1], usages: [7]) + assert_argument(collector, "block", definitions: [1], usages: [8]) + end + + def test_collecting_block_arguments + collector = Collector.collect(<<~RUBY) + def foo + [].each do |i| + puts i + end + end + RUBY + + assert_equal(1, collector.arguments.length) + assert_argument(collector, "i", definitions: [2], usages: [3]) + end + + def test_collecting_destructured_block_arguments + collector = Collector.collect(<<~RUBY) + [].each do |(a, *b)| + end + RUBY + + assert_equal(2, collector.arguments.length) + assert_argument(collector, "b", definitions: [1]) + end + + def test_collecting_anonymous_destructured_block_arguments + collector = Collector.collect(<<~RUBY) + [].each do |(a, *)| + end + RUBY + + assert_equal(1, collector.arguments.length) + end + + def test_collecting_one_line_block_arguments + collector = Collector.collect(<<~RUBY) + def foo + [].each { |i| puts i } + end + RUBY + + assert_equal(1, collector.arguments.length) + assert_argument(collector, "i", definitions: [2], usages: [2]) + end + + def test_collecting_shadowed_block_arguments + collector = Collector.collect(<<~RUBY) + def foo + i = "something" + + [].each do |i| + puts i + end + + i + end + RUBY + + assert_equal(1, collector.arguments.length) + assert_argument(collector, "i", definitions: [4], usages: [5]) + + assert_equal(1, collector.variables.length) + assert_variable(collector, "i", definitions: [2], usages: [8]) + end + + def test_collecting_shadowed_local_variables + collector = Collector.collect(<<~RUBY) + def foo(a) + puts a + a = 123 + a + end + RUBY + + # All occurrences are considered arguments, despite overriding the + # argument value + assert_equal(1, collector.arguments.length) + assert_equal(0, collector.variables.length) + assert_argument(collector, "a", definitions: [1, 3], usages: [2, 4]) + end + + def test_variables_in_the_top_level + collector = Collector.collect(<<~RUBY) + a = 123 + a + RUBY + + assert_equal(0, collector.arguments.length) + assert_equal(1, collector.variables.length) + assert_variable(collector, "a", definitions: [1], usages: [2]) + end + + def test_aref_field + collector = Collector.collect(<<~RUBY) + object = {} + object["name"] = "something" + RUBY + + assert_equal(0, collector.arguments.length) + assert_equal(1, collector.variables.length) + assert_variable(collector, "object", definitions: [1], usages: [2]) + end + + def test_aref_on_a_method_call + collector = Collector.collect(<<~RUBY) + object = MyObject.new + object.attributes["name"] = "something" + RUBY + + assert_equal(0, collector.arguments.length) + assert_equal(1, collector.variables.length) + assert_variable(collector, "object", definitions: [1], usages: [2]) + end + + def test_aref_with_two_accesses + collector = Collector.collect(<<~RUBY) + object = MyObject.new + object["first"]["second"] ||= [] + RUBY + + assert_equal(0, collector.arguments.length) + assert_equal(1, collector.variables.length) + assert_variable(collector, "object", definitions: [1], usages: [2]) + end + + def test_aref_on_a_method_call_with_arguments + collector = Collector.collect(<<~RUBY) + object = MyObject.new + object.instance_variable_get(:@attributes)[:something] = :other_thing + RUBY + + assert_equal(0, collector.arguments.length) + assert_equal(1, collector.variables.length) + assert_variable(collector, "object", definitions: [1], usages: [2]) + end + + def test_double_aref_on_method_call + collector = Collector.collect(<<~RUBY) + object = MyObject.new + object["attributes"].find { |a| a["field"] == "expected" }["value"] = "changed" + RUBY + + assert_equal(1, collector.arguments.length) + assert_argument(collector, "a", definitions: [2], usages: [2]) + + assert_equal(1, collector.variables.length) + assert_variable(collector, "object", definitions: [1], usages: [2]) + end + + def test_nested_arguments + collector = Collector.collect(<<~RUBY) + [[1, [2, 3]]].each do |one, (two, three)| + one + two + three + end + RUBY + + assert_equal(3, collector.arguments.length) + assert_equal(0, collector.variables.length) + + assert_argument(collector, "one", definitions: [1], usages: [2]) + assert_argument(collector, "two", definitions: [1], usages: [3]) + assert_argument(collector, "three", definitions: [1], usages: [4]) + end + + def test_double_nested_arguments + collector = Collector.collect(<<~RUBY) + [[1, [2, 3]]].each do |one, (two, (three, four))| + one + two + three + four + end + RUBY + + assert_equal(4, collector.arguments.length) + assert_equal(0, collector.variables.length) + + assert_argument(collector, "one", definitions: [1], usages: [2]) + assert_argument(collector, "two", definitions: [1], usages: [3]) + assert_argument(collector, "three", definitions: [1], usages: [4]) + assert_argument(collector, "four", definitions: [1], usages: [5]) + end + + def test_block_locals + collector = Collector.collect(<<~RUBY) + [].each do |; a| + end + RUBY + + assert_equal(1, collector.variables.length) + + assert_variable(collector, "a", definitions: [1]) + end + + def test_lambda_locals + collector = Collector.collect(<<~RUBY) + ->(;a) { } + RUBY + + assert_equal(1, collector.variables.length) + + assert_variable(collector, "a", definitions: [1]) + end + + def test_regex_named_capture_groups + collector = Collector.collect(<<~RUBY) + if /(?\\w+)-(?\\w+)/ =~ "something-else" + one + two + end + RUBY + + assert_equal(2, collector.variables.length) + + assert_variable(collector, "one", definitions: [1], usages: [2]) + assert_variable(collector, "two", definitions: [1], usages: [3]) + end + + def test_multiline_regex_named_capture_groups + collector = Collector.collect(<<~RUBY) + if %r{ + (?\\w+)- + (?\\w+) + } =~ "something-else" + one + two + end + RUBY + + assert_equal(2, collector.variables.length) + + assert_variable(collector, "one", definitions: [2], usages: [5]) + assert_variable(collector, "two", definitions: [3], usages: [6]) + end + + class Resolver < Visitor + prepend WithScope + + attr_reader :locals + + def initialize + @locals = [] + end + + visit_methods do + def visit_assign(node) + super.tap do + level = 0 + name = node.target.value.value + + scope = current_scope + while !scope.locals.key?(name) && !scope.parent.nil? + level += 1 + scope = scope.parent + end + + locals << [name, level] + end + end + end + end + + def test_resolver + source = <<~RUBY + module Level0 + level0 = 0 + + class Level1 + level1 = 1 + + def level2 + level2 = 2 + + tap do |level3| + level2 = 2 + level3 = 3 + + tap do |level4| + level2 = 2 + level4 = 4 + end + end + end + end + end + RUBY + + resolver = Resolver.new + SyntaxTree.parse(source).accept(resolver) + + expected = [ + ["level0", 0], + ["level1", 0], + ["level2", 0], + ["level2", 1], + ["level3", 0], + ["level2", 2], + ["level4", 0] + ] + + assert_equal expected, resolver.locals + end + + private + + def assert_collected(field, name, definitions: [], usages: []) + keys = field.keys.select { |key| key[1] == name } + assert_equal(1, keys.length) + + variable = field[keys.first] + + assert_equal(definitions.length, variable.definitions.length) + definitions.each_with_index do |definition, index| + assert_equal(definition, variable.definitions[index].start_line) + end + + assert_equal(usages.length, variable.usages.length) + usages.each_with_index do |usage, index| + assert_equal(usage, variable.usages[index].start_line) + end + end + + def assert_argument(collector, name, definitions: [], usages: []) + assert_collected( + collector.arguments, + name, + definitions: definitions, + usages: usages + ) + end + + def assert_variable(collector, name, definitions: [], usages: []) + assert_collected( + collector.variables, + name, + definitions: definitions, + usages: usages + ) + end + end +end