From 43f9a8cb9a3b0cf9c4264591e87c59a874c12946 Mon Sep 17 00:00:00 2001 From: Saurabh Sinha Date: Wed, 17 Jun 2026 09:56:30 -0400 Subject: [PATCH 1/4] feat: initial implementation of codeanalyzer-go Adds the static code analysis engine for Go including syntactic analysis, schema definitions, CLI entrypoint, and test fixtures. Semantic analysis via CodeQL is scaffolded but not yet implemented. Signed-off-by: Saurabh Sinha --- .gitignore | 19 + LICENSE | 201 +++++ README.md | 209 +++++ cmd/codeanalyzer/main.go | 104 +++ go.mod | 15 + go.sum | 18 + internal/analysis/pass.go | 38 + internal/analysis/registry.go | 90 ++ internal/core/analyzer.go | 201 +++++ internal/core/analyzer_test.go | 363 ++++++++ internal/core/realistic_test.go | 326 +++++++ internal/frameworks/base.go | 38 + internal/options/options.go | 36 + internal/schema/schema.go | 190 ++++ internal/semantic_analysis/call_graph.go | 289 ++++++ internal/semantic_analysis/codeql/codeql.go | 48 + internal/semantic_analysis/codeql/errors.go | 15 + internal/semantic_analysis/codeql/loader.go | 22 + internal/semantic_analysis/codeql/runner.go | 27 + internal/syntactic_analysis/export.go | 15 + internal/syntactic_analysis/signature.go | 86 ++ internal/syntactic_analysis/symbol_table.go | 944 ++++++++++++++++++++ internal/utils/fs.go | 81 ++ internal/utils/logging.go | 30 + testdata/fixture/go.mod | 3 + testdata/fixture/main.go | 15 + testdata/fixture/pkg/greeter/greeter.go | 29 + testdata/realistic/go.mod | 3 + testdata/realistic/main.go | 26 + testdata/realistic/server/middleware.go | 19 + testdata/realistic/server/server.go | 53 ++ testdata/realistic/worker/worker.go | 65 ++ 32 files changed, 3618 insertions(+) create mode 100644 .gitignore create mode 100644 LICENSE create mode 100644 README.md create mode 100644 cmd/codeanalyzer/main.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 internal/analysis/pass.go create mode 100644 internal/analysis/registry.go create mode 100644 
internal/core/analyzer.go create mode 100644 internal/core/analyzer_test.go create mode 100644 internal/core/realistic_test.go create mode 100644 internal/frameworks/base.go create mode 100644 internal/options/options.go create mode 100644 internal/schema/schema.go create mode 100644 internal/semantic_analysis/call_graph.go create mode 100644 internal/semantic_analysis/codeql/codeql.go create mode 100644 internal/semantic_analysis/codeql/errors.go create mode 100644 internal/semantic_analysis/codeql/loader.go create mode 100644 internal/semantic_analysis/codeql/runner.go create mode 100644 internal/syntactic_analysis/export.go create mode 100644 internal/syntactic_analysis/signature.go create mode 100644 internal/syntactic_analysis/symbol_table.go create mode 100644 internal/utils/fs.go create mode 100644 internal/utils/logging.go create mode 100644 testdata/fixture/go.mod create mode 100644 testdata/fixture/main.go create mode 100644 testdata/fixture/pkg/greeter/greeter.go create mode 100644 testdata/realistic/go.mod create mode 100644 testdata/realistic/main.go create mode 100644 testdata/realistic/server/middleware.go create mode 100644 testdata/realistic/server/server.go create mode 100644 testdata/realistic/worker/worker.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..169d632 --- /dev/null +++ b/.gitignore @@ -0,0 +1,19 @@ +# Binaries +/codeanalyzer +/codeanalyzer-go +*.exe + +# Build output +/dist/ +/bin/ + +# Claude Code session data +.claude/ + +# macOS +.DS_Store + +# Go test cache / coverage +*.test +*.out +coverage.txt diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. 
+ + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of 
the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 0000000..608e3a3 --- /dev/null +++ b/README.md @@ -0,0 +1,209 @@ +# codeanalyzer-go + +Static analysis for Go using `golang.org/x/tools/go/packages` (AST + type resolution). + +Produces `analysis.json` (symbol table + call graph) in the [CLDK canonical schema](https://github.com/codellm-devkit/python-sdk), consumable by the Python SDK via `CLDK(language="go").analysis(project_path=...)`. + +## Prerequisites + +- **Go 1.25+** — the only required runtime. Install from [go.dev/dl](https://go.dev/dl/). Developed and tested on Go 1.26.4. + +Verify: +```bash +go version +``` + +The binary is self-contained. No other tools are required for Level 1 analysis. 
+ +## Building + +```bash +git clone https://github.com/codellm-devkit/codeanalyzer-go +cd codeanalyzer-go +go build -o codeanalyzer-go ./cmd/codeanalyzer +``` + +This produces a single static binary `codeanalyzer-go` with no runtime dependencies other than the Go toolchain listed under Prerequisites (it runs `go mod download` before loading packages). + +## Usage + +```bash +codeanalyzer-go -i /path/to/go/project +``` + +### Command-line options + +``` +codeanalyzer-go produces analysis.json (symbol table + call graph) for Go projects. + +Usage: + codeanalyzer-go [flags] + +Flags: + -a, --analysis-level int Analysis level: 1=symbol table only, 2=+resolver call graph (default 1) + -c, --cache-dir string Cache directory (default: ~/.cldk/go-cache) + --codeql Enable CodeQL framework-based call graph (level 2, stub) + --eager Force clean rebuild (ignore cache) + -f, --format string Output format: json (default "json"); msgpack is not yet implemented + -h, --help help for codeanalyzer-go + -i, --input string Project root to analyze (required) + -o, --output string Output directory for analysis.json (default: stdout) + --skip-tests Skip *_test.go files (default true) + -t, --target-files strings Restrict analysis to specific files (incremental mode) + -v, --verbose count Verbosity (repeat for more detail) + --version Print version and exit +``` + +### Examples + +**Symbol table only (Level 1, default):** +```bash +codeanalyzer-go -i ./my-go-project +``` +Prints `analysis.json` to stdout. 
+ +**Symbol table + call graph (Level 2):** +```bash +codeanalyzer-go -i ./my-go-project -a 2 +``` + +**Write output to a directory:** +```bash +codeanalyzer-go -i ./my-go-project -a 2 -o /path/to/output/ +# Writes: /path/to/output/analysis.json +``` + +**Incremental analysis (specific files only):** +```bash +codeanalyzer-go -i ./my-go-project -t pkg/server/server.go -t pkg/server/handler.go +``` + +**Force rebuild, ignore cache:** +```bash +codeanalyzer-go -i ./my-go-project --eager +``` + +**Verbose output:** +```bash +codeanalyzer-go -i ./my-go-project -a 2 -vv +``` + +## Analysis levels + +| Level | Flag | What runs | Status | +|-------|------|-----------|--------| +| 1 | `-a 1` (default) | Symbol table only — types, functions, call sites | Implemented | +| 2 | `-a 2` | Level 1 + resolver-based call graph via `go/types` | Implemented | +| — | `--codeql` | CodeQL framework-based call graph (merged with Level 2 edges) | Stub (not yet implemented) | + +**Level 1** loads each package with `packages.NeedSyntax | NeedTypes | NeedTypesInfo` and walks the AST file by file. Call sites are recorded with `callee_signature = null` at this stage. + +**Level 2** adds a resolver pass: for each call site, `go/types` resolves the callee to its full import-path signature (`pkgImportPath.TypeName.MethodName`). Only project-internal edges (both endpoints present in the symbol table) are emitted. `callee_signature` is backfilled on all successfully resolved sites. + +## Output schema + +The root object is `GoApplication`: + +```json +{ + "symbol_table": { + "pkg/greeter/greeter.go": { + "file_path": "pkg/greeter/greeter.go", + "module_name": "greeter", + "imports": [...], + "classes": { + "Greeter": { + "name": "Greeter", + "signature": "example.com/pkg/greeter.Greeter", + "is_interface": false, + "fields": [{ "name": "Prefix", "type": "string", "tags": {"json": "prefix"} }], + "methods": { ... } + } + }, + "functions": { ... 
} + } + }, + "call_graph": [ + { + "source": "example.com/main.main", + "target": "example.com/pkg/greeter.Greeter.Greet", + "type": "CALL_DEP", + "weight": 1, + "provenance": ["go/types"] + } + ], + "entrypoints": {} +} +``` + +Key schema properties: +- `symbol_table` — keyed by **file path relative to the project root** (never absolute) +- `classes` — JSON key for types (spine compatibility with Java/Python schemas); value is `GoType` +- `module_name` — JSON key for the Go package name (spine compatibility) +- `GoType.is_interface: bool` — unified type model; structs and interfaces are both `GoType` +- `GoCallable.receiver_type / receiver_name` — non-empty for methods, empty for package-level functions +- `GoCallable.return_types: List[str]` — individual return types (Go-specific extension) +- `GoCallsite.is_goroutine: bool` — true when the call is preceded by the `go` keyword +- `GoCallEdge.provenance: List[str]` — resolver identifiers, e.g. `["go/types"]` or `["go/types","codeql"]` +- Call edges are **identity-only**: source and target are `GoCallable.signature` strings that exist in the symbol table + +## Python SDK (CLDK) integration + +```python +from cldk import CLDK + +analysis = CLDK(language="go").analysis(project_path="/path/to/go/project") +for file_path, go_file in analysis.get_symbol_table().items(): + print(file_path, go_file.module_name) +``` + +See [python-sdk](https://github.com/codellm-devkit/python-sdk) for full API documentation. 
+ +## Architecture & Tooling + +| Slot | Choice | Rationale | +|------|--------|-----------| +| Runtime | Go binary | Self-contained; no runtime dep for SDK users | +| Structural parser | `go/ast` (stdlib) | Part of the standard toolchain; no external dep | +| Type resolver | `golang.org/x/tools/go/packages` | Single API for both AST + full type resolution; handles modules natively | +| Optional enrichment | CodeQL (stubbed) | Same enrichment path as Python/Java analyzers; stubbed for Level 1 | +| Build/dep materialization | `go mod download` | Required before `packages.Load` so the module cache is warm; result cached by `go.sum` hash | +| Packaging | Native binary (`go build`) | Zero-runtime-dep distribution; matches Rust/C++ analyzers | +| Analysis depth | Level 1 (rapid) | Symbol table + resolver call graph; CodeQL stub wired but not implemented | +| Call-graph dispatch | Declared-type resolution via `go/types.Selections` | CHA-equivalent; sufficient for cross-package reachability at Level 1 | + +### Package structure + +``` +codeanalyzer-go/ +├── cmd/codeanalyzer/ # CLI entry point (cobra) +├── internal/ +│ ├── core/ # Orchestrator — delegates only, no inlined analysis +│ ├── schema/ # GoApplication, GoFile, GoType, GoCallable, … (schema.go) +│ ├── options/ # AnalysisOptions + AnalysisLevel constants +│ ├── syntactic_analysis/ # SymbolTableBuilder (packages.Load → AST walk) +│ ├── semantic_analysis/ # CallGraphBuilder (go/types resolver) +│ │ └── codeql/ # CodeQL backend subpackage (stubbed) +│ ├── analysis/ # Pluggable pass interface + registry (topo-ordered pipeline) +│ ├── frameworks/ # BaseEntrypointFinder — extension seam for framework passes +│ └── utils/ # DiscoverGoFiles, IsVendored, IsTestFile, logging +└── testdata/ # Go fixtures used by tests: fixture/ (minimal), realistic/ (multi-package) +``` + +The `core` package is a pure orchestrator: it calls `syntactic_analysis` → `semantic_analysis` → `analysis.RunPipeline` → optional CodeQL in sequence, with no inlined parsing logic. 
Framework-specific analysis extends through the `analysis/` + `frameworks/` layer without touching `core`. + +## Development + +### Running tests + +```bash +go test ./... +``` + +Tests run against `testdata/fixture/` and `testdata/realistic/` — a minimal two-package and a richer multi-package Go module. All 33 tests cover symbol table correctness, call graph edges, JSON round-trip, output format validation, and caching/incremental behaviour. + +### Running from source + +```bash +go run ./cmd/codeanalyzer -i /path/to/project -a 2 +``` diff --git a/cmd/codeanalyzer/main.go b/cmd/codeanalyzer/main.go new file mode 100644 index 0000000..7425781 --- /dev/null +++ b/cmd/codeanalyzer/main.go @@ -0,0 +1,104 @@ +// codeanalyzer-go is the CLI entry point for the Go language analyzer. +// +// It exposes the standard CLDK CLI surface (cli-contract.md) so the Python SDK +// facade can shell out to it uniformly alongside Java and Python backends. +package main + +import ( + "fmt" + "os" + + "github.com/spf13/cobra" + + "github.com/codellm-devkit/codeanalyzer-go/internal/core" + "github.com/codellm-devkit/codeanalyzer-go/internal/options" + "github.com/codellm-devkit/codeanalyzer-go/internal/utils" +) + +const version = "0.1.0" + +func main() { + if err := rootCmd().Execute(); err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } +} + +func rootCmd() *cobra.Command { + var ( + inputPath string + outputDir string + format string + level int + targetFiles []string + skipTests bool + eager bool + cacheDir string + useCodeQL bool + verbosity int + showVersion bool + ) + + cmd := &cobra.Command{ + Use: "codeanalyzer-go", + Short: "Static analysis for Go — symbol table and call graph via go/types", + Long: `codeanalyzer-go produces analysis.json (symbol table + call graph) for Go projects. 
+ +The output conforms to the CLDK canonical schema so the Python SDK can load it +via CLDK(language="go").analysis(project_path=...).`, + SilenceUsage: true, + RunE: func(cmd *cobra.Command, args []string) error { + if showVersion { + fmt.Println("codeanalyzer-go " + version) + return nil + } + if inputPath == "" { + return fmt.Errorf("--input / -i is required") + } + utils.SetVerbosity(verbosity) + + if cacheDir == "" { + home, _ := os.UserHomeDir() + cacheDir = home + "/.cldk/go-cache" + } + + opts := options.AnalysisOptions{ + InputPath: inputPath, + OutputDir: outputDir, + Format: format, + Level: options.AnalysisLevel(level), + TargetFiles: targetFiles, + SkipTests: skipTests, + Eager: eager, + CacheDir: cacheDir, + UseCodeQL: useCodeQL, + Verbose: verbosity > 0, + } + + analyzer := core.New(opts) + app, err := analyzer.Analyze() + if err != nil { + return err + } + + return core.WriteOutput(app, outputDir, format) + }, + } + + f := cmd.Flags() + f.StringVarP(&inputPath, "input", "i", "", "Project root to analyze (required)") + f.StringVarP(&outputDir, "output", "o", "", "Output directory for analysis.json (default: stdout)") + f.StringVarP(&format, "format", "f", "json", "Output format: json|msgpack") + f.IntVarP(&level, "analysis-level", "a", 1, + "Analysis level: 1=symbol table only, 2=+resolver call graph") + f.StringSliceVarP(&targetFiles, "target-files", "t", nil, + "Restrict analysis to specific files (incremental mode)") + f.BoolVar(&skipTests, "skip-tests", true, "Skip *_test.go files") + f.BoolVar(&eager, "eager", false, "Force clean rebuild (ignore cache)") + f.StringVarP(&cacheDir, "cache-dir", "c", "", "Cache directory (default: ~/.cldk/go-cache)") + f.BoolVar(&useCodeQL, "codeql", false, "Enable CodeQL framework-based call graph (level 2, stub)") + f.CountVarP(&verbosity, "verbose", "v", "Verbosity (repeat for more detail)") + f.BoolVar(&showVersion, "version", false, "Print version and exit") + + return cmd +} diff --git a/go.mod b/go.mod new 
file mode 100644 index 0000000..c913971 --- /dev/null +++ b/go.mod @@ -0,0 +1,15 @@ +module github.com/codellm-devkit/codeanalyzer-go + +go 1.25.0 + +require ( + github.com/spf13/cobra v1.8.0 + golang.org/x/tools v0.46.0 +) + +require ( + github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/spf13/pflag v1.0.5 // indirect + golang.org/x/mod v0.37.0 // indirect + golang.org/x/sync v0.21.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..88530cf --- /dev/null +++ b/go.sum @@ -0,0 +1,18 @@ +github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0= +github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +golang.org/x/mod v0.37.0 h1:vF1DjpVEshcIqoEaauuHebaLk1O1forxjxBaVn884JQ= +golang.org/x/mod v0.37.0/go.mod h1:m8S8VeM9r4dzDwjrKO0a1sZP3YjeMamRRlD+fmR2Q/0= +golang.org/x/sync v0.21.0 h1:HLII4xRRTtCRkxYp4HNFF0Js/Og6q2i++KXbg0gHCwM= +golang.org/x/sync v0.21.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= +golang.org/x/tools v0.46.0 h1:7jTurBkPZu4moS/Uy4OQT1M+QBlsj3wejyZwsT8Z7rk= +golang.org/x/tools v0.46.0/go.mod h1:FrD85F8l+NWL+9XWBSyVSHO6Ne4jutsfIFba7AWQ5Ys= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod 
h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/analysis/pass.go b/internal/analysis/pass.go new file mode 100644 index 0000000..3142774 --- /dev/null +++ b/internal/analysis/pass.go @@ -0,0 +1,38 @@ +// Package analysis defines the pluggable pass layer for codeanalyzer-go. +// +// An AnalysisPass enriches a GoApplication after the base analysis (symbol table +// + call graph) is built. Passes declare capability tokens in Provides/Requires +// and are ordered topologically by the registry before running. +// +// This mirrors codeanalyzer-python's analysis/_pass.py. The seam exists so that +// codeanalyzer-extension-builder can register out-of-tree passes. +package analysis + +import "github.com/codellm-devkit/codeanalyzer-go/internal/schema" + +// AnalysisContext carries the shared context available to every pass. +type AnalysisContext struct { + ProjectDir string + CacheDir string +} + +// AnalysisResult holds the output of a single pass run. +type AnalysisResult struct { + // Entrypoints discovered by this pass, keyed by framework name. + Entrypoints map[string][]schema.GoEntrypoint + // SyntheticEdges are additional call-graph edges contributed by this pass. + SyntheticEdges []schema.GoCallEdge +} + +// AnalysisPass is the interface every built-in and out-of-tree pass implements. +type AnalysisPass interface { + // Name is the unique identifier for this pass (e.g. "gin-entrypoints"). + Name() string + // Provides is the set of capability tokens this pass adds to the application. + Provides() []string + // Requires is the set of capability tokens that must have been provided before + // this pass runs. The registry hard-errors on unsatisfied dependencies. + Requires() []string + // Run performs the analysis and returns its contributions. 
+ Run(app *schema.GoApplication, ctx AnalysisContext) (AnalysisResult, error) +} diff --git a/internal/analysis/registry.go b/internal/analysis/registry.go new file mode 100644 index 0000000..2503578 --- /dev/null +++ b/internal/analysis/registry.go @@ -0,0 +1,90 @@ +// Package analysis — registry discovers, orders, and runs passes. +package analysis + +import ( + "fmt" + + "github.com/codellm-devkit/codeanalyzer-go/internal/schema" + "github.com/codellm-devkit/codeanalyzer-go/internal/utils" +) + +var registeredPasses []AnalysisPass + +// RegisterPass adds a pass to the built-in registry. Call from init() in each +// pass file to register without touching the registry directly. +func RegisterPass(p AnalysisPass) { + registeredPasses = append(registeredPasses, p) +} + +// orderPasses performs a topological sort by Requires/Provides. +// Returns an error if a dependency is unsatisfied or a cycle exists. +func orderPasses(passes []AnalysisPass) ([]AnalysisPass, error) { + provided := map[string]bool{} + var ordered []AnalysisPass + remaining := make([]AnalysisPass, len(passes)) + copy(remaining, passes) + + for len(remaining) > 0 { + progress := false + var next []AnalysisPass + for _, p := range remaining { + ready := true + for _, req := range p.Requires() { + if !provided[req] { + ready = false + break + } + } + if ready { + ordered = append(ordered, p) + for _, cap := range p.Provides() { + provided[cap] = true + } + progress = true + } else { + next = append(next, p) + } + } + if !progress { + return nil, fmt.Errorf("unsatisfied pass dependencies or cycle among: %v", + func() []string { + names := make([]string, len(remaining)) + for i, p := range remaining { + names[i] = p.Name() + } + return names + }()) + } + remaining = next + } + return ordered, nil +} + +// RunPipeline runs all registered passes over app in dependency order, +// merging each result into the running application before the next pass. 
+// Pass output is deliberately not cached — out-of-tree enrichment must not go stale. +func RunPipeline(app *schema.GoApplication, ctx AnalysisContext) error { + ordered, err := orderPasses(registeredPasses) + if err != nil { + return err + } + if len(ordered) == 0 { + utils.Debug("no registered analysis passes; skipping pipeline") + return nil + } + for _, p := range ordered { + utils.Info("running pass: %s", p.Name()) + result, err := p.Run(app, ctx) + if err != nil { + utils.Warn("pass %s failed: %v (continuing)", p.Name(), err) + continue + } + // Merge entrypoints + for framework, eps := range result.Entrypoints { + app.Entrypoints[framework] = append(app.Entrypoints[framework], eps...) + } + // Merge synthetic edges + app.CallGraph = append(app.CallGraph, result.SyntheticEdges...) + } + return nil +} diff --git a/internal/core/analyzer.go b/internal/core/analyzer.go new file mode 100644 index 0000000..2b9f410 --- /dev/null +++ b/internal/core/analyzer.go @@ -0,0 +1,201 @@ +// Package core is the ORCHESTRATOR for codeanalyzer-go analysis. +// +// Analyzer.Analyze() delegates each phase to its own package; it inlines no +// analysis logic and never hardcodes entrypoints. This mirrors the structural +// discipline of codeanalyzer-python/codeanalyzer/core.py. +// +// Phase order: +// 1. Project materialization (go mod download) +// 2. Symbol table construction (syntactic_analysis) +// 3. Resolver-based call graph (semantic_analysis) — if level >= 2 +// 4. Pass pipeline (analysis/registry) +// 5. 
Optional CodeQL enrichment (semantic_analysis/codeql) — if --codeql +package core + +import ( + "encoding/json" + "fmt" + "os" + "os/exec" + "path/filepath" + + "github.com/codellm-devkit/codeanalyzer-go/internal/analysis" + "github.com/codellm-devkit/codeanalyzer-go/internal/options" + "github.com/codellm-devkit/codeanalyzer-go/internal/schema" + "github.com/codellm-devkit/codeanalyzer-go/internal/semantic_analysis" + "github.com/codellm-devkit/codeanalyzer-go/internal/semantic_analysis/codeql" + "github.com/codellm-devkit/codeanalyzer-go/internal/syntactic_analysis" + "github.com/codellm-devkit/codeanalyzer-go/internal/utils" +) + +// Analyzer is the top-level analysis driver. Construct with New() and call Analyze(). +type Analyzer struct { + opts options.AnalysisOptions +} + +// New creates an Analyzer for the given options. +func New(opts options.AnalysisOptions) *Analyzer { + return &Analyzer{opts: opts} +} + +// Analyze runs the full analysis pipeline and returns a GoApplication. +func (a *Analyzer) Analyze() (*schema.GoApplication, error) { + // Resolve to absolute path so filepath.Rel works correctly in all cases. + abs, err := filepath.Abs(a.opts.InputPath) + if err != nil { + return nil, fmt.Errorf("resolving input path: %w", err) + } + a.opts.InputPath = abs + utils.Info("analyzing project: %s", a.opts.InputPath) + + // ── Phase 1: Project materialization ────────────────────────────────────── + if err := a.materialize(); err != nil { + // Degrade gracefully — log but don't abort. Partial types are better than nothing. 
+ utils.Warn("dependency materialization failed: %v (continuing with partial types)", err) + } + + // ── Phase 2: Symbol table construction ─────────────────────────────────── + builder := syntactic_analysis.NewSymbolTableBuilder(a.opts.InputPath) + symbolTable, err := builder.Build(a.opts.TargetFiles, a.opts.SkipTests) + if err != nil { + return nil, fmt.Errorf("symbol table construction failed: %w", err) + } + utils.Info("symbol table: %d files", len(symbolTable)) + + app := &schema.GoApplication{ + SymbolTable: symbolTable, + CallGraph: []schema.GoCallEdge{}, + Entrypoints: map[string][]schema.GoEntrypoint{}, + } + + if a.opts.Level < options.LevelCallGraph { + // Level 1 (symbol-table only) — skip call graph and passes. + return a.finalizeAndCache(app) + } + + // ── Phase 3: Resolver-based call graph ──────────────────────────────────── + cgBuilder := semantic_analysis.NewCallGraphBuilder( + a.opts.InputPath, builder.Fset(), builder.Pkgs(), + ) + edges := cgBuilder.Build(symbolTable) + app.CallGraph = edges + utils.Info("call graph: %d edges", len(edges)) + + // ── Phase 4: Pass pipeline ──────────────────────────────────────────────── + ctx := analysis.AnalysisContext{ + ProjectDir: a.opts.InputPath, + CacheDir: a.opts.CacheDir, + } + if err := analysis.RunPipeline(app, ctx); err != nil { + utils.Warn("pass pipeline error: %v", err) + } + + // ── Phase 5: Optional CodeQL enrichment ────────────────────────────────── + if a.opts.UseCodeQL { + cq, err := codeql.New(a.opts.CacheDir, true) + if err != nil { + utils.Warn("CodeQL unavailable: %v", err) + } else { + if err := cq.Build(a.opts.InputPath); err != nil { + utils.Warn("CodeQL build failed: %v", err) + } else { + cqlEdges, err := cq.Edges() + if err != nil { + utils.Warn("CodeQL edge extraction failed: %v", err) + } else { + app.CallGraph = semantic_analysis.MergeEdges(app.CallGraph, cqlEdges) + utils.Info("merged %d CodeQL edges", len(cqlEdges)) + } + } + } + } + + return a.finalizeAndCache(app) +} + +// 
materialize runs `go mod download` to ensure the module graph is available +// for go/packages to resolve imports and types. Idempotent and cached. +func (a *Analyzer) materialize() error { + goModPath := filepath.Join(a.opts.InputPath, "go.mod") + if _, err := os.Stat(goModPath); os.IsNotExist(err) { + utils.Debug("no go.mod found at %s; skipping go mod download", a.opts.InputPath) + return nil + } + + // Check cache: if the go.sum hasn't changed, skip download. + if !a.opts.Eager { + goSumPath := filepath.Join(a.opts.InputPath, "go.sum") + cacheKey := filepath.Join(a.opts.CacheDir, "go_mod_hash") + if currentHash, err := utils.FileHash(goSumPath); err == nil { + if cachedHash, err := os.ReadFile(cacheKey); err == nil && string(cachedHash) == currentHash { + utils.Debug("go mod download: cache hit, skipping") + return nil + } + } + defer func() { + goSumPath := filepath.Join(a.opts.InputPath, "go.sum") + if currentHash, err := utils.FileHash(goSumPath); err == nil { + _ = utils.EnsureDir(a.opts.CacheDir) + _ = os.WriteFile(cacheKey, []byte(currentHash), 0o644) + } + }() + } + + utils.Info("running go mod download...") + cmd := exec.Command("go", "mod", "download") + cmd.Dir = a.opts.InputPath + cmd.Stdout = os.Stderr + cmd.Stderr = os.Stderr + return cmd.Run() +} + +// finalizeAndCache caches the application and returns it. +func (a *Analyzer) finalizeAndCache(app *schema.GoApplication) (*schema.GoApplication, error) { + if a.opts.CacheDir != "" { + _ = a.saveCache(app) + } + return app, nil +} + +// saveCache persists the application to cache as analysis_cache.json. 
+func (a *Analyzer) saveCache(app *schema.GoApplication) error { + if err := utils.EnsureDir(a.opts.CacheDir); err != nil { + return err + } + cachePath := filepath.Join(a.opts.CacheDir, "analysis_cache.json") + data, err := json.Marshal(app) + if err != nil { + return err + } + return os.WriteFile(cachePath, data, 0o644) +} + +// WriteOutput writes the GoApplication to outputDir/analysis.json (or stdout +// when outputDir is empty). Only "json" is supported; "msgpack" and other +// values return an explicit error rather than silently falling back to JSON. +func WriteOutput(app *schema.GoApplication, outputDir, format string) error { + if format == "" { + format = "json" + } + switch format { + case "json": + // only supported format + case "msgpack": + return fmt.Errorf("msgpack output is not yet implemented; use --format json") + default: + return fmt.Errorf("unsupported output format %q; supported: json", format) + } + + data, err := json.Marshal(app) + if err != nil { + return err + } + if outputDir == "" { + _, err = os.Stdout.Write(data) + return err + } + if err := utils.EnsureDir(outputDir); err != nil { + return err + } + return os.WriteFile(filepath.Join(outputDir, "analysis.json"), data, 0o644) +} diff --git a/internal/core/analyzer_test.go b/internal/core/analyzer_test.go new file mode 100644 index 0000000..8d322ba --- /dev/null +++ b/internal/core/analyzer_test.go @@ -0,0 +1,363 @@ +package core_test + +import ( + "encoding/json" + "os" + "path/filepath" + "runtime" + "testing" + "time" + + "github.com/codellm-devkit/codeanalyzer-go/internal/core" + "github.com/codellm-devkit/codeanalyzer-go/internal/options" + "github.com/codellm-devkit/codeanalyzer-go/internal/schema" +) + +// fixtureDir returns the absolute path to testdata/fixture. +func fixtureDir(t *testing.T) string { + t.Helper() + _, file, _, ok := runtime.Caller(0) + if !ok { + t.Fatal("cannot determine source file path") + } + // internal/core/analyzer_test.go → ../.. 
→ codeanalyzer-go root → testdata/fixture + root := filepath.Join(filepath.Dir(file), "..", "..") + abs, err := filepath.Abs(filepath.Join(root, "testdata", "fixture")) + if err != nil { + t.Fatalf("resolving fixture dir: %v", err) + } + return abs +} + +func runAnalysis(t *testing.T, level options.AnalysisLevel) *schema.GoApplication { + t.Helper() + dir := fixtureDir(t) + outDir := t.TempDir() + opts := options.AnalysisOptions{ + InputPath: dir, + OutputDir: outDir, + Level: level, + SkipTests: true, + CacheDir: t.TempDir(), + } + app, err := core.New(opts).Analyze() + if err != nil { + t.Fatalf("Analyze() failed: %v", err) + } + return app +} + +// ── Symbol table tests ──────────────────────────────────────────────────────── + +func TestSymbolTable_NonEmpty(t *testing.T) { + app := runAnalysis(t, options.LevelSymbolTable) + if len(app.SymbolTable) == 0 { + t.Fatal("symbol table is empty") + } +} + +func TestSymbolTable_PathKeysAreRelative(t *testing.T) { + app := runAnalysis(t, options.LevelSymbolTable) + for key := range app.SymbolTable { + if filepath.IsAbs(key) { + t.Errorf("symbol_table key is absolute path: %s", key) + } + } +} + +func TestSymbolTable_KnownType(t *testing.T) { + app := runAnalysis(t, options.LevelSymbolTable) + const wantFile = "pkg/greeter/greeter.go" + f, ok := app.SymbolTable[wantFile] + if !ok { + t.Fatalf("file %q not in symbol table; got keys: %v", wantFile, keys(app.SymbolTable)) + } + if _, ok := f.Types["Greeter"]; !ok { + t.Errorf("GoType 'Greeter' not found in %s", wantFile) + } +} + +func TestSymbolTable_KnownInterface(t *testing.T) { + app := runAnalysis(t, options.LevelSymbolTable) + f := app.SymbolTable["pkg/greeter/greeter.go"] + gt, ok := f.Types["Logger"] + if !ok { + t.Fatal("GoType 'Logger' not found") + } + if !gt.IsInterface { + t.Error("Logger.is_interface should be true") + } +} + +func TestSymbolTable_StructFields(t *testing.T) { + app := runAnalysis(t, options.LevelSymbolTable) + f := 
// TestCallGraph_CallSitesBackfilled inspects the call sites of every function
// in main.go after a call-graph-level run. It rejects only a callee_signature
// that is non-nil yet empty; call sites whose signature is still nil are
// accepted, so this test alone does not prove that any site was backfilled.
func TestCallGraph_CallSitesBackfilled(t *testing.T) {
	app := runAnalysis(t, options.LevelCallGraph)
	f := app.SymbolTable["main.go"]
	for _, callable := range f.Functions {
		for _, cs := range callable.CallSites {
			// A signature that was filled in must never be the empty string;
			// nil (not filled in) is allowed here.
			if cs.CalleeSignature != nil && *cs.CalleeSignature == "" {
				t.Errorf("callable %s: call site %q has empty string callee_signature", callable.Signature, cs.MethodName)
			}
		}
	}
}
───────────────────────────────────────────────────────────── + +func TestCaching_SecondRunReuses(t *testing.T) { + dir := fixtureDir(t) + cacheDir := t.TempDir() + outDir := t.TempDir() + opts := options.AnalysisOptions{ + InputPath: dir, + OutputDir: outDir, + Level: options.LevelCallGraph, + SkipTests: true, + CacheDir: cacheDir, + } + // First run — populates cache. + app1, err := core.New(opts).Analyze() + if err != nil { + t.Fatalf("first run: %v", err) + } + // Second run — must not error and must return identical key count. + app2, err := core.New(opts).Analyze() + if err != nil { + t.Fatalf("second run: %v", err) + } + if len(app2.SymbolTable) == 0 { + t.Error("second run returned empty symbol table") + } + if len(app2.SymbolTable) != len(app1.SymbolTable) { + t.Errorf("symbol table key count changed between runs: %d → %d", + len(app1.SymbolTable), len(app2.SymbolTable)) + } +} + +func TestCaching_CacheFileWritten(t *testing.T) { + dir := fixtureDir(t) + cacheDir := t.TempDir() + opts := options.AnalysisOptions{ + InputPath: dir, + Level: options.LevelSymbolTable, + SkipTests: true, + CacheDir: cacheDir, + } + if _, err := core.New(opts).Analyze(); err != nil { + t.Fatalf("Analyze: %v", err) + } + cachePath := filepath.Join(cacheDir, "analysis_cache.json") + if _, err := os.Stat(cachePath); err != nil { + t.Fatalf("analysis_cache.json not written to CacheDir: %v", err) + } +} + +func TestCaching_CacheContentsRoundTrip(t *testing.T) { + dir := fixtureDir(t) + cacheDir := t.TempDir() + opts := options.AnalysisOptions{ + InputPath: dir, + Level: options.LevelSymbolTable, + SkipTests: true, + CacheDir: cacheDir, + } + app, err := core.New(opts).Analyze() + if err != nil { + t.Fatalf("Analyze: %v", err) + } + data, err := os.ReadFile(filepath.Join(cacheDir, "analysis_cache.json")) + if err != nil { + t.Fatalf("reading analysis_cache.json: %v", err) + } + var cached schema.GoApplication + if err := json.Unmarshal(data, &cached); err != nil { + t.Fatalf("cache 
JSON round-trip failed: %v", err) + } + if len(cached.SymbolTable) != len(app.SymbolTable) { + t.Errorf("cache symbol table key count %d != in-memory %d", + len(cached.SymbolTable), len(app.SymbolTable)) + } +} + +func TestCaching_EagerForcesRebuild(t *testing.T) { + dir := fixtureDir(t) + cacheDir := t.TempDir() + opts := options.AnalysisOptions{ + InputPath: dir, + Level: options.LevelSymbolTable, + SkipTests: true, + CacheDir: cacheDir, + } + // First run (non-eager) — seeds go_mod_hash. + if _, err := core.New(opts).Analyze(); err != nil { + t.Fatalf("first run: %v", err) + } + cachePath := filepath.Join(cacheDir, "analysis_cache.json") + info1, err := os.Stat(cachePath) + if err != nil { + t.Fatalf("cache not written after first run: %v", err) + } + + time.Sleep(10 * time.Millisecond) + + // Second run with Eager=true — must rewrite cache even when go_mod_hash matches. + opts.Eager = true + if _, err := core.New(opts).Analyze(); err != nil { + t.Fatalf("eager run: %v", err) + } + info2, err := os.Stat(cachePath) + if err != nil { + t.Fatalf("cache not found after eager run: %v", err) + } + // saveCache always writes, so mtime must advance. 
// keys collects the keys of m into a freshly allocated, non-nil slice.
// The order follows Go's map iteration order and is therefore unspecified.
func keys[K comparable, V any](m map[K]V) []K {
	collected := make([]K, len(m))
	next := 0
	for key := range m {
		collected[next] = key
		next++
	}
	return collected
}
// realisticDir returns the absolute path of the testdata/realistic fixture.
// The path is derived from this source file's own location (two directories
// up, then testdata/realistic), so it does not depend on the working
// directory. The test is aborted if the location cannot be determined.
func realisticDir(t *testing.T) string {
	t.Helper()
	_, thisFile, _, found := runtime.Caller(0)
	if !found {
		t.Fatal("cannot determine source file path")
	}
	fixture, err := filepath.Abs(
		filepath.Join(filepath.Dir(thisFile), "..", "..", "testdata", "realistic"),
	)
	if err != nil {
		t.Fatalf("resolving realistic fixture dir: %v", err)
	}
	return fixture
}
+ mw := app.SymbolTable["server/middleware.go"] + if findCallableByName(mw, "Tags") == nil { + t.Error("Tags function not found in server/middleware.go") + } +} + +// ── Embedded struct field ───────────────────────────────────────────────────── + +func TestRealistic_EmbeddedField(t *testing.T) { + app := runRealistic(t, options.LevelSymbolTable) + srv := app.SymbolTable["server/server.go"] + server, ok := srv.Types["Server"] + if !ok { + t.Fatal("GoType 'Server' not found in server/server.go") + } + for _, f := range server.Fields { + if f.IsEmbedded { + return // pass + } + } + t.Errorf("Server has no embedded field; fields: %+v", server.Fields) +} + +// ── Multiple return types — (T, error) pattern ──────────────────────────────── + +func TestRealistic_MultipleReturnTypes(t *testing.T) { + app := runRealistic(t, options.LevelSymbolTable) + srv := app.SymbolTable["server/server.go"] + newFn := findCallableByName(srv, "New") + if newFn == nil { + t.Fatal("function 'New' not found in server/server.go") + } + if len(newFn.ReturnTypes) < 2 { + t.Fatalf("New() should have >= 2 return types; got %v", newFn.ReturnTypes) + } + hasError := false + for _, rt := range newFn.ReturnTypes { + if rt == "error" { + hasError = true + } + } + if !hasError { + t.Errorf("New() return_types should include 'error'; got %v", newFn.ReturnTypes) + } +} + +func TestRealistic_ValidateReturnTypes(t *testing.T) { + app := runRealistic(t, options.LevelSymbolTable) + srv := app.SymbolTable["server/server.go"] + validate := findCallableByName(srv, "Validate") + if validate == nil { + t.Fatal("method 'Validate' not found in server/server.go") + } + if len(validate.ReturnTypes) != 2 { + t.Fatalf("Validate() should have 2 return types; got %v", validate.ReturnTypes) + } +} + +// ── Unexported callables ────────────────────────────────────────────────────── + +func TestRealistic_UnexportedMethod(t *testing.T) { + app := runRealistic(t, options.LevelSymbolTable) + srv := 
// TestRealistic_ReceiverType checks that the Addr method found in
// server/server.go carries a non-empty receiver_type and receiver_name, and
// that the receiver type mentions Server.
func TestRealistic_ReceiverType(t *testing.T) {
	app := runRealistic(t, options.LevelSymbolTable)
	srv := app.SymbolTable["server/server.go"]
	addr := findCallableByName(srv, "Addr")
	if addr == nil {
		t.Fatal("method 'Addr' not found in server/server.go")
	}
	if addr.ReceiverType == "" {
		t.Error("Addr().receiver_type should be non-empty")
	}
	if addr.ReceiverName == "" {
		t.Error("Addr().receiver_name should be non-empty")
	}
	// Only the type name is checked; whether the receiver is a pointer ('*')
	// is not asserted here.
	if !strings.Contains(addr.ReceiverType, "Server") {
		t.Errorf("Addr().receiver_type %q should reference Server", addr.ReceiverType)
	}
}
+ if strings.Contains(describe.ReceiverType, "*") { + t.Errorf("Describe().receiver_type %q should be a value receiver (no '*')", describe.ReceiverType) + } + // Path should still record the physical definition file. + if !strings.Contains(describe.Path, "middleware.go") { + t.Errorf("Describe().path %q should point to middleware.go", describe.Path) + } +} + +// ── Variadic parameters ─────────────────────────────────────────────────────── + +func TestRealistic_VariadicParamTags(t *testing.T) { + app := runRealistic(t, options.LevelSymbolTable) + mw := app.SymbolTable["server/middleware.go"] + tags := findCallableByName(mw, "Tags") + if tags == nil { + t.Fatal("function 'Tags' not found in server/middleware.go") + } + for _, p := range tags.Parameters { + if p.IsVariadic { + return // pass + } + } + t.Errorf("Tags() has no variadic parameter; params: %+v", tags.Parameters) +} + +func TestRealistic_VariadicParamCombine(t *testing.T) { + app := runRealistic(t, options.LevelSymbolTable) + wkr := app.SymbolTable["worker/worker.go"] + combine := findCallableByName(wkr, "Combine") + if combine == nil { + t.Fatal("function 'Combine' not found in worker/worker.go") + } + for _, p := range combine.Parameters { + if p.IsVariadic { + return // pass + } + } + t.Errorf("Combine() has no variadic parameter; params: %+v", combine.Parameters) +} + +// ── Goroutine call site ─────────────────────────────────────────────────────── + +func TestRealistic_GoroutineCallsite(t *testing.T) { + app := runRealistic(t, options.LevelSymbolTable) + wkr := app.SymbolTable["worker/worker.go"] + run := findCallableByName(wkr, "Run") + if run == nil { + t.Fatal("method 'Run' not found in worker/worker.go") + } + for _, cs := range run.CallSites { + if cs.IsGoroutine { + return // pass + } + } + t.Errorf("Run() has no goroutine call site; sites: %+v", run.CallSites) +} + +// ── Cyclomatic complexity ───────────────────────────────────────────────────── + +func TestRealistic_CyclomaticComplexity(t 
*testing.T) { + app := runRealistic(t, options.LevelSymbolTable) + wkr := app.SymbolTable["worker/worker.go"] + execute := findCallableByName(wkr, "execute") + if execute == nil { + t.Fatal("method 'execute' not found in worker/worker.go") + } + // execute() has an `if err != nil` branch → CC >= 2. + if execute.CyclomaticComplexity < 2 { + t.Errorf("execute().cyclomatic_complexity should be >= 2; got %d", execute.CyclomaticComplexity) + } +} + +// ── Interface detection ─────────────────────────────────────────────────────── + +func TestRealistic_InterfaceType(t *testing.T) { + app := runRealistic(t, options.LevelSymbolTable) + wkr := app.SymbolTable["worker/worker.go"] + proc, ok := wkr.Types["Processor"] + if !ok { + t.Fatal("GoType 'Processor' not found in worker/worker.go") + } + if !proc.IsInterface { + t.Error("Processor.is_interface should be true") + } +} + +// ── Specific call-graph edge ────────────────────────────────────────────────── + +func TestRealistic_SpecificCallEdge(t *testing.T) { + app := runRealistic(t, options.LevelCallGraph) + // main() calls server.New() — this is a cross-package project-internal edge. + const wantTarget = "example.com/realistic/server.New" + for _, e := range app.CallGraph { + if e.Target == wantTarget { + return // pass + } + } + t.Errorf("call graph missing expected edge to %s; edges: %v", wantTarget, edgeTargets(app)) +} + +func TestRealistic_CrossPackageEdges(t *testing.T) { + app := runRealistic(t, options.LevelCallGraph) + // At least one edge must cross the main→server boundary and one main→worker boundary. 
+ var serverEdge, workerEdge bool + for _, e := range app.CallGraph { + if strings.Contains(e.Target, "realistic/server.") { + serverEdge = true + } + if strings.Contains(e.Target, "realistic/worker.") { + workerEdge = true + } + } + if !serverEdge { + t.Error("no call-graph edge into the server package") + } + if !workerEdge { + t.Error("no call-graph edge into the worker package") + } +} + +// ── Helpers ─────────────────────────────────────────────────────────────────── + +func edgeTargets(app *schema.GoApplication) []string { + out := make([]string, 0, len(app.CallGraph)) + for _, e := range app.CallGraph { + out = append(out, e.Target) + } + return out +} diff --git a/internal/frameworks/base.go b/internal/frameworks/base.go new file mode 100644 index 0000000..9d17251 --- /dev/null +++ b/internal/frameworks/base.go @@ -0,0 +1,38 @@ +// Package frameworks provides the base for entrypoint-finder passes. +// +// Concrete finders (gin-router, net/http handler, etc.) embed BaseEntrypointFinder +// and override FindEntrypoints. They register themselves via analysis.RegisterPass. +// +// This mirrors codeanalyzer-python's frameworks/_base.py. +package frameworks + +import ( + "github.com/codellm-devkit/codeanalyzer-go/internal/analysis" + "github.com/codellm-devkit/codeanalyzer-go/internal/schema" +) + +// BaseEntrypointFinder is the abstract base for framework-specific entrypoint finders. +// Concrete implementations embed this struct and override FindEntrypoints. +type BaseEntrypointFinder struct { + name string + framework string +} + +// NewBaseEntrypointFinder constructs a finder with the given name and framework label. +func NewBaseEntrypointFinder(name, framework string) BaseEntrypointFinder { + return BaseEntrypointFinder{name: name, framework: framework} +} + +// Name implements analysis.AnalysisPass. 
// Run implements analysis.AnalysisPass. The base implementation is a no-op:
// it ignores app and ctx and returns an empty AnalysisResult with a nil error.
// NOTE(review): the package comment describes finders overriding
// FindEntrypoints, but this file declares no such method and Run does not
// call one — presumably a concrete finder has to define its own Run so that
// it takes precedence over this promoted method; confirm the intended contract.
func (b BaseEntrypointFinder) Run(app *schema.GoApplication, ctx analysis.AnalysisContext) (analysis.AnalysisResult, error) {
	return analysis.AnalysisResult{}, nil
}
+ Eager bool + // CacheDir is where per-file caches and intermediate data are stored. + CacheDir string + // UseCodeQL enables the framework-based (Tier-2) CodeQL call graph. + UseCodeQL bool + // Verbose enables verbose logging. + Verbose bool +} diff --git a/internal/schema/schema.go b/internal/schema/schema.go new file mode 100644 index 0000000..6eedd79 --- /dev/null +++ b/internal/schema/schema.go @@ -0,0 +1,190 @@ +// Package schema defines the canonical data contract for codeanalyzer-go output. +// +// The root object is GoApplication{symbol_table, call_graph}. Every field uses +// snake_case JSON keys so the Python SDK's Pydantic models parse it without +// transformation. Design decisions are recorded in .claude/SCHEMA_DECISIONS.md. +package schema + +// ─── Leaf models ───────────────────────────────────────────────────────────── + +// GoImport represents a single import declaration in a Go source file. +type GoImport struct { + Module string `json:"module"` + Alias string `json:"alias,omitempty"` + StartLine int `json:"start_line"` + EndLine int `json:"end_line"` +} + +// GoComment represents a comment (line, block, or doc comment). +type GoComment struct { + Content string `json:"content"` + StartLine int `json:"start_line"` + EndLine int `json:"end_line"` + StartColumn int `json:"start_column"` + EndColumn int `json:"end_column"` + IsDocComment bool `json:"is_doc_comment"` +} + +// GoParameter represents a single parameter of a callable. +type GoParameter struct { + Name string `json:"name"` + Type string `json:"type"` + IsVariadic bool `json:"is_variadic"` + StartLine int `json:"start_line"` + EndLine int `json:"end_line"` +} + +// GoVariableDeclaration represents a variable declaration (var/short-assign). 
+type GoVariableDeclaration struct { + Name string `json:"name"` + Type string `json:"type,omitempty"` + Initializer string `json:"initializer,omitempty"` + Scope string `json:"scope"` // "package" | "function" + StartLine int `json:"start_line"` + EndLine int `json:"end_line"` + StartColumn int `json:"start_column"` + EndColumn int `json:"end_column"` +} + +// GoField represents a struct field, including its struct tags. +type GoField struct { + Name string `json:"name"` + Type string `json:"type"` + Comments []GoComment `json:"comments"` + Tags map[string]string `json:"tags"` // parsed struct tags, e.g. {"json": "name,omitempty"} + IsExported bool `json:"is_exported"` + IsEmbedded bool `json:"is_embedded"` // anonymous/embedded field + StartLine int `json:"start_line"` + EndLine int `json:"end_line"` +} + +// GoSymbol represents a symbol accessed inside a callable. +type GoSymbol struct { + Name string `json:"name"` + Scope string `json:"scope"` // "local" | "package" | "external" + Kind string `json:"kind"` // "variable" | "function" | "type" | "constant" + Type string `json:"type,omitempty"` + QualifiedName string `json:"qualified_name,omitempty"` + IsBuiltin bool `json:"is_builtin"` + Lineno int `json:"lineno"` + ColOffset int `json:"col_offset"` +} + +// ─── Call site ──────────────────────────────────────────────────────────────── + +// GoCallsite represents a single call expression inside a callable. +// callee_signature is null when first recorded; the resolver backfills it +// during call-graph construction (never during symbol-table build). 
+type GoCallsite struct { + MethodName string `json:"method_name"` + ReceiverExpr string `json:"receiver_expr,omitempty"` + ReceiverType string `json:"receiver_type,omitempty"` + ArgumentTypes []string `json:"argument_types"` + ReturnType string `json:"return_type,omitempty"` + CalleeSignature *string `json:"callee_signature"` // null until resolved + IsConstructorCall bool `json:"is_constructor_call"` + IsGoroutine bool `json:"is_goroutine"` // true when preceded by `go` keyword + StartLine int `json:"start_line"` + StartColumn int `json:"start_column"` + EndLine int `json:"end_line"` + EndColumn int `json:"end_column"` +} + +// ─── Callable ───────────────────────────────────────────────────────────────── + +// GoCallable represents a function, method, or function literal in Go. +// receiver_type / receiver_name are non-empty for methods; empty for functions. +type GoCallable struct { + Name string `json:"name"` + Path string `json:"path"` + Signature string `json:"signature"` // signatureOf() output — edge id + Comments []GoComment `json:"comments"` + Parameters []GoParameter `json:"parameters"` + ReturnType string `json:"return_type"` // joined, e.g. "(int, error)" + ReturnTypes []string `json:"return_types"` // Go extension: individual return types + Code string `json:"code,omitempty"` + StartLine int `json:"start_line"` + EndLine int `json:"end_line"` + CodeStartLine int `json:"code_start_line"` + AccessedSymbols []GoSymbol `json:"accessed_symbols"` + CallSites []GoCallsite `json:"call_sites"` + InnerCallables map[string]GoCallable `json:"inner_callables"` + LocalVariables []GoVariableDeclaration `json:"local_variables"` + CyclomaticComplexity int `json:"cyclomatic_complexity"` + IsEntrypoint bool `json:"is_entrypoint"` + EntrypointFramework string `json:"entrypoint_framework,omitempty"` + ReceiverType string `json:"receiver_type,omitempty"` // e.g. "*MyStruct" + ReceiverName string `json:"receiver_name,omitempty"` // e.g. 
"r" + IsExported bool `json:"is_exported"` +} + +// ─── Type (struct or interface) ─────────────────────────────────────────────── + +// GoType represents a named Go type — either a struct (is_interface=false) or +// an interface (is_interface=true). This unified model mirrors Go's native type +// system where both are types.Named with different Underlying() values. +// +// base_types carries: embedded struct types (for structs) and the method-set +// signatures of satisfied interfaces (for both) — the Go analog of base_classes. +type GoType struct { + Name string `json:"name"` + Signature string `json:"signature"` // signatureOf() output + Comments []GoComment `json:"comments"` + Code string `json:"code,omitempty"` + IsInterface bool `json:"is_interface"` + IsExported bool `json:"is_exported"` + Fields []GoField `json:"fields"` // empty for interfaces + Methods map[string]GoCallable `json:"methods"` // sig → callable + BaseTypes []string `json:"base_types"` // embedded types + satisfied interface sigs + InnerTypes map[string]GoType `json:"inner_types"` // Go doesn't nest, but preserves spine + StartLine int `json:"start_line"` + EndLine int `json:"end_line"` +} + +// ─── File (Module analog) ───────────────────────────────────────────────────── + +// GoFile is the Module analog for Go: one compilation unit (source file). +// symbol_table is keyed by file path relative to the project root. 
+type GoFile struct { + FilePath string `json:"file_path"` + PackageName string `json:"module_name"` // JSON key = module_name for spine compat + Imports []GoImport `json:"imports"` + Comments []GoComment `json:"comments"` + Types map[string]GoType `json:"classes"` // JSON key = classes for spine compat + Functions map[string]GoCallable `json:"functions"` + Variables []GoVariableDeclaration `json:"variables"` + // Caching metadata + ContentHash *string `json:"content_hash"` + LastModified *float64 `json:"last_modified"` + FileSize *int64 `json:"file_size"` +} + +// ─── Call graph edge ────────────────────────────────────────────────────────── + +// GoCallEdge is an identity-only call-graph edge. source and target are +// GoCallable.signature strings that must exist in the symbol table. +type GoCallEdge struct { + Source string `json:"source"` + Target string `json:"target"` + Type string `json:"type"` // always "CALL_DEP" + Weight int `json:"weight"` // accumulated when merging backends + Provenance []string `json:"provenance"` // e.g. ["go/types"], ["go/types","codeql"] + Tags map[string]string `json:"tags"` +} + +// ─── Root object ────────────────────────────────────────────────────────────── + +// GoApplication is the root of analysis.json. The SDK facade deserializes this. +type GoApplication struct { + SymbolTable map[string]GoFile `json:"symbol_table"` // file_path → GoFile + CallGraph []GoCallEdge `json:"call_graph"` // identity-only edges + Entrypoints map[string][]GoEntrypoint `json:"entrypoints"` // optional, default {} +} + +// GoEntrypoint marks a callable as a framework entry point. 
+type GoEntrypoint struct { + Signature string `json:"signature"` + Framework string `json:"framework"` + DetectionSource string `json:"detection_source"` + Tags map[string]string `json:"tags"` +} diff --git a/internal/semantic_analysis/call_graph.go b/internal/semantic_analysis/call_graph.go new file mode 100644 index 0000000..7dd4adc --- /dev/null +++ b/internal/semantic_analysis/call_graph.go @@ -0,0 +1,289 @@ +// Package semantic_analysis builds the resolver-based call graph (Level 1, Tier 1). +// +// This stage uses the same golang.org/x/tools/go/packages load that built the +// symbol table. For each recorded call site it resolves the callee to a +// *types.Func, derives its signature via signatureOf(), backfills +// callee_signature in place, and emits an identity-only GoCallEdge. +// +// Precision choice: declared-type dispatch (CHA-style). Pointer-receiver +// methods are followed; interface dispatch records the interface method +// signature and falls back gracefully when the concrete type is unknown. +package semantic_analysis + +import ( + "go/ast" + "go/token" + "go/types" + + "golang.org/x/tools/go/packages" + + "github.com/codellm-devkit/codeanalyzer-go/internal/schema" + "github.com/codellm-devkit/codeanalyzer-go/internal/utils" +) + +// CallGraphBuilder resolves call sites and builds the call graph. +type CallGraphBuilder struct { + projectDir string + fset *token.FileSet + pkgs map[string]*packages.Package +} + +// NewCallGraphBuilder creates a builder using the same pkgs/fset loaded for the symbol table. +func NewCallGraphBuilder(projectDir string, fset *token.FileSet, pkgs map[string]*packages.Package) *CallGraphBuilder { + return &CallGraphBuilder{projectDir: projectDir, fset: fset, pkgs: pkgs} +} + +// Build resolves all call sites in symbolTable, backfills callee_signature, and +// returns the identity-only edge list. Never crashes on unresolved sites — it logs +// and skips the edge while leaving callee_signature nil. 
+func (cg *CallGraphBuilder) Build(symbolTable map[string]schema.GoFile) []schema.GoCallEdge { + var edges []schema.GoCallEdge + seen := map[[2]string]bool{} + + // Build the known-sig set once. Only emit edges where the target is in the + // project's symbol table (drop stdlib / external callees, same as Python/Jedi). + knownSigs := buildKnownSigs(symbolTable) + + for fileKey, goFile := range symbolTable { + // Find the package for this file. + pkg := cg.packageForFile(fileKey) + if pkg == nil || pkg.TypesInfo == nil { + continue + } + + // Resolve methods on types. + for typeSig, goType := range goFile.Types { + for methSig, callable := range goType.Methods { + newCallable, newEdges := cg.resolveCallable(pkg, callable, seen, knownSigs) + goType.Methods[methSig] = newCallable + edges = append(edges, newEdges...) + } + goFile.Types[typeSig] = goType + } + + // Resolve package-level functions. + for fnSig, callable := range goFile.Functions { + newCallable, newEdges := cg.resolveCallable(pkg, callable, seen, knownSigs) + goFile.Functions[fnSig] = newCallable + edges = append(edges, newEdges...) + } + + symbolTable[fileKey] = goFile + } + + return edges +} + +// buildKnownSigs returns the set of all callable signatures present in the symbol table. +func buildKnownSigs(symbolTable map[string]schema.GoFile) map[string]bool { + sigs := make(map[string]bool) + for _, f := range symbolTable { + for sig := range f.Functions { + sigs[sig] = true + } + for _, t := range f.Types { + for sig := range t.Methods { + sigs[sig] = true + } + } + } + return sigs +} + +// resolveCallable backfills callee_signature on each call site and produces edges. +// Only emits edges where the target is in knownSigs (project-internal callees). +// External/stdlib callees have callee_signature backfilled but no edge emitted — +// matching Python/Jedi's behavior of dropping unresolved external sites. 
+func (cg *CallGraphBuilder) resolveCallable( + pkg *packages.Package, + callable schema.GoCallable, + seen map[[2]string]bool, + knownSigs map[string]bool, +) (schema.GoCallable, []schema.GoCallEdge) { + var edges []schema.GoCallEdge + + for i := range callable.CallSites { + site := &callable.CallSites[i] + if site.CalleeSignature != nil { + continue // already resolved + } + + calleeSig := cg.resolveCallSite(pkg, site) + if calleeSig == "" { + utils.Debug("unresolved call site: %s in %s", site.MethodName, callable.Signature) + continue + } + + // Backfill the site regardless of whether the callee is in-project. + site.CalleeSignature = &calleeSig + + // Only emit an edge when the callee is in the project's symbol table. + if !knownSigs[calleeSig] { + utils.Debug("external callee (no edge): %s", calleeSig) + continue + } + + key := [2]string{callable.Signature, calleeSig} + if seen[key] { + continue + } + seen[key] = true + + edges = append(edges, schema.GoCallEdge{ + Source: callable.Signature, + Target: calleeSig, + Type: "CALL_DEP", + Weight: 1, + Provenance: []string{"go/types"}, + Tags: map[string]string{}, + }) + } + + return callable, edges +} + +// resolveCallSite resolves a single call site to a callee signature string. +// Returns "" when the site cannot be resolved (graceful fallback). +func (cg *CallGraphBuilder) resolveCallSite(pkg *packages.Package, site *schema.GoCallsite) string { + if pkg.TypesInfo == nil { + return "" + } + + // Walk the package's syntax to find the call expression at this location. 
+ for _, astFile := range pkg.Syntax { + var result string + ast.Inspect(astFile, func(n ast.Node) bool { + if result != "" { + return false + } + call, ok := n.(*ast.CallExpr) + if !ok { + return true + } + pos := cg.fset.Position(call.Pos()) + if pos.Line != site.StartLine || pos.Column != site.StartColumn { + return true + } + result = cg.resolveCallExpr(pkg, call) + return false + }) + if result != "" { + return result + } + } + return "" +} + +// resolveCallExpr resolves a *ast.CallExpr to a callee signature using go/types. +func (cg *CallGraphBuilder) resolveCallExpr(pkg *packages.Package, call *ast.CallExpr) string { + info := pkg.TypesInfo + + switch fn := call.Fun.(type) { + case *ast.Ident: + if obj := info.ObjectOf(fn); obj != nil { + if f, ok := obj.(*types.Func); ok { + return calleeSignatureOf(f) + } + // Type conversion — normalize as constructor. + if tn, ok := obj.(*types.TypeName); ok { + return calleeSignatureOf(tn) + ".__new__" + } + } + case *ast.SelectorExpr: + sel, ok := info.Selections[fn] + if ok { + if f, ok := sel.Obj().(*types.Func); ok { + return calleeSignatureOf(f) + } + } + // Package-level function via qualified identifier. + if obj := info.ObjectOf(fn.Sel); obj != nil { + if f, ok := obj.(*types.Func); ok { + return calleeSignatureOf(f) + } + } + } + return "" +} + +// calleeSignatureOf is the call-site mirror of signatureOf — same canonicalization. +// Must produce byte-identical strings to signatureOf() in syntactic_analysis. +func calleeSignatureOf(obj types.Object) string { + if obj == nil { + return "" + } + pkg := obj.Pkg() + pkgPath := "" + if pkg != nil { + pkgPath = pkg.Path() + } + + switch o := obj.(type) { + case *types.Func: + sig := o.Type().(*types.Signature) + recv := sig.Recv() + if recv != nil { + recvType := recv.Type() + if ptr, ok := recvType.(*types.Pointer); ok { + recvType = ptr.Elem() + } + if named, ok := recvType.(*types.Named); ok { + typeName := named.Obj().Name() + return pkgPath + "." + typeName + "." 
+ o.Name() + } + } + return pkgPath + "." + o.Name() + case *types.TypeName: + return pkgPath + "." + o.Name() + default: + return pkgPath + "." + obj.Name() + } +} + +// packageForFile returns the *packages.Package that contains the given relative file path. +func (cg *CallGraphBuilder) packageForFile(relPath string) *packages.Package { + for _, pkg := range cg.pkgs { + for _, f := range pkg.GoFiles { + if utils.RelativePath(cg.projectDir, f) == relPath { + return pkg + } + } + } + return nil +} + +// MergeEdges merges two edge lists, unioning provenance and accumulating weight +// for duplicate (source, target) pairs. Mirrors Python's merge_edges(). +func MergeEdges(primary, secondary []schema.GoCallEdge) []schema.GoCallEdge { + type key struct{ src, tgt string } + index := map[key]int{} + result := make([]schema.GoCallEdge, 0, len(primary)+len(secondary)) + + add := func(e schema.GoCallEdge) { + k := key{e.Source, e.Target} + if idx, exists := index[k]; exists { + // Merge provenance (union) and accumulate weight. + provSet := map[string]bool{} + for _, p := range result[idx].Provenance { + provSet[p] = true + } + for _, p := range e.Provenance { + if !provSet[p] { + result[idx].Provenance = append(result[idx].Provenance, p) + } + } + result[idx].Weight += e.Weight + } else { + index[k] = len(result) + result = append(result, e) + } + } + + for _, e := range primary { + add(e) + } + for _, e := range secondary { + add(e) + } + return result +} diff --git a/internal/semantic_analysis/codeql/codeql.go b/internal/semantic_analysis/codeql/codeql.go new file mode 100644 index 0000000..88d1d29 --- /dev/null +++ b/internal/semantic_analysis/codeql/codeql.go @@ -0,0 +1,48 @@ +package codeql + +import ( + "github.com/codellm-devkit/codeanalyzer-go/internal/schema" + "github.com/codellm-devkit/codeanalyzer-go/internal/utils" +) + +// CodeQL is the top-level handle for CodeQL-backed analysis. 
core.go talks only +// to this type; it never touches the binary, database, or query strings directly. +// TODO(level-2): implement Build() and Edges(). +type CodeQL struct { + loader *Loader + runner *Runner + enabled bool +} + +// New probes for the CodeQL binary and returns a CodeQL handle. +// If the binary is absent and enabled=true, returns an error. +func New(cacheDir string, enabled bool) (*CodeQL, error) { + if !enabled { + return &CodeQL{enabled: false}, nil + } + loader, err := NewLoader() + if err != nil { + return nil, err + } + runner := NewRunner(loader, cacheDir+"/codeql-db") + return &CodeQL{loader: loader, runner: runner, enabled: true}, nil +} + +// Build creates the CodeQL database for projectDir. +// No-op when CodeQL is disabled. +func (c *CodeQL) Build(projectDir string) error { + if !c.enabled { + return nil + } + utils.Info("building CodeQL database (stub — TODO level-2)") + return c.runner.BuildDatabase(projectDir) +} + +// Edges returns Tier-2 call-graph edges. +// Returns empty slice when disabled; returns ErrCodeQLNotImplemented when enabled (stub). +func (c *CodeQL) Edges() ([]schema.GoCallEdge, error) { + if !c.enabled { + return nil, nil + } + return c.runner.QueryCallGraph() +} diff --git a/internal/semantic_analysis/codeql/errors.go b/internal/semantic_analysis/codeql/errors.go new file mode 100644 index 0000000..a4affd1 --- /dev/null +++ b/internal/semantic_analysis/codeql/errors.go @@ -0,0 +1,15 @@ +// Package codeql is the isolated framework-backend subpackage for CodeQL analysis. +// It provides Tier-2 (framework-based) call-graph edges beyond what go/types resolves. +// +// The seams (loader, driver, query runner, errors) are scaffolded here even though +// the implementation is stubbed — dropping in the full implementation later requires +// no refactor. Mirrors codeanalyzer-python's semantic_analysis/codeql/ split. 
+package codeql + +import "errors" + +// ErrCodeQLNotFound is returned when the CodeQL CLI binary cannot be located. +var ErrCodeQLNotFound = errors.New("codeql: CLI binary not found; install from https://github.com/github/codeql-cli-binaries") + +// ErrCodeQLNotImplemented is returned when CodeQL analysis is requested but not yet implemented. +var ErrCodeQLNotImplemented = errors.New("codeql: Go backend is a wired stub — implementation TODO (level-2 analysis)") diff --git a/internal/semantic_analysis/codeql/loader.go b/internal/semantic_analysis/codeql/loader.go new file mode 100644 index 0000000..fce359d --- /dev/null +++ b/internal/semantic_analysis/codeql/loader.go @@ -0,0 +1,22 @@ +package codeql + +import ( + "os/exec" +) + +// Loader resolves the CodeQL CLI binary path. +type Loader struct { + binaryPath string +} + +// NewLoader creates a Loader, probing the PATH for the codeql binary. +func NewLoader() (*Loader, error) { + path, err := exec.LookPath("codeql") + if err != nil { + return nil, ErrCodeQLNotFound + } + return &Loader{binaryPath: path}, nil +} + +// BinaryPath returns the resolved CodeQL binary path. +func (l *Loader) BinaryPath() string { return l.binaryPath } diff --git a/internal/semantic_analysis/codeql/runner.go b/internal/semantic_analysis/codeql/runner.go new file mode 100644 index 0000000..9abbd3e --- /dev/null +++ b/internal/semantic_analysis/codeql/runner.go @@ -0,0 +1,27 @@ +package codeql + +import "github.com/codellm-devkit/codeanalyzer-go/internal/schema" + +// Runner builds a CodeQL database and runs queries to produce call-graph edges. +// TODO(level-2): implement database creation, query execution, and result parsing. +type Runner struct { + loader *Loader + dbDir string +} + +// NewRunner creates a Runner for the given project and database directory. 
+func NewRunner(loader *Loader, dbDir string) *Runner { + return &Runner{loader: loader, dbDir: dbDir} +} + +// BuildDatabase creates a CodeQL database for the Go project at projectDir. +// TODO(level-2): run `codeql database create --language=go`. +func (r *Runner) BuildDatabase(projectDir string) error { + return ErrCodeQLNotImplemented +} + +// QueryCallGraph runs the CodeQL call-graph query and returns edges. +// TODO(level-2): run query, parse SARIF/CSV, produce GoCallEdge list. +func (r *Runner) QueryCallGraph() ([]schema.GoCallEdge, error) { + return nil, ErrCodeQLNotImplemented +} diff --git a/internal/syntactic_analysis/export.go b/internal/syntactic_analysis/export.go new file mode 100644 index 0000000..fad49bb --- /dev/null +++ b/internal/syntactic_analysis/export.go @@ -0,0 +1,15 @@ +package syntactic_analysis + +import ( + "go/token" + + "golang.org/x/tools/go/packages" +) + +// Fset returns the token.FileSet used during package loading. +// Used by CallGraphBuilder to resolve source positions. +func (b *SymbolTableBuilder) Fset() *token.FileSet { return b.fset } + +// Pkgs returns the map of loaded packages keyed by package path. +// Used by CallGraphBuilder for type-info lookups. +func (b *SymbolTableBuilder) Pkgs() map[string]*packages.Package { return b.pkgs } diff --git a/internal/syntactic_analysis/signature.go b/internal/syntactic_analysis/signature.go new file mode 100644 index 0000000..6f538bf --- /dev/null +++ b/internal/syntactic_analysis/signature.go @@ -0,0 +1,86 @@ +// Package syntactic_analysis builds the symbol table from Go source files. +package syntactic_analysis + +import ( + "fmt" + "go/types" + "strings" +) + +// signatureOf is the single canonicalizer for all signature strings in the analyzer. +// It produces the edge id used in GoCallable.signature, GoType.signature, and every +// call-graph edge source/target. All callers must use this function — never build +// signatures ad-hoc. 
// Caller-side and callee-side ids are then guaranteed identical.
//
// Format:
//   - Package function: "pkg/path.FuncName"
//   - Method:           "pkg/path.TypeName.MethodName"
//   - Interface method: "pkg/path.InterfaceName.MethodName"
//   - Type:             "pkg/path.TypeName"
//
// A nil object yields "". Objects without a package (for example universe
// scope entries) get an empty package path, i.e. ".Name".
func signatureOf(obj types.Object) string {
	if obj == nil {
		return ""
	}

	var importPath string
	if p := obj.Pkg(); p != nil {
		importPath = p.Path()
	}

	fn, isFunc := obj.(*types.Func)
	if !isFunc {
		// Type names and every other kind of object share the two-part form.
		return fmt.Sprintf("%s.%s", importPath, obj.Name())
	}

	if recv := fn.Type().(*types.Signature).Recv(); recv != nil {
		// Method: qualify with the receiver's named type; *T and T map to the same id.
		base := recv.Type()
		if ptr, isPtr := base.(*types.Pointer); isPtr {
			base = ptr.Elem()
		}
		if named, isNamed := base.(*types.Named); isNamed {
			return fmt.Sprintf("%s.%s.%s", importPath, named.Obj().Name(), fn.Name())
		}
	}

	// Plain function, or a method whose receiver is not a named type.
	return fmt.Sprintf("%s.%s", importPath, fn.Name())
}

// signatureOfNamed builds a type signature from a *types.Named directly.
// Used when we have the named type but not a types.Object. A nil argument
// yields "".
func signatureOfNamed(named *types.Named) string {
	if named == nil {
		return ""
	}
	// The type's own TypeName object carries the package and name, so the
	// canonical formatter produces the same "pkg/path.TypeName" string.
	return signatureOf(named.Obj())
}

// signatureForCall builds a callee signature from a *types.Func resolved at a
// call site. It is a thin alias of signatureOf so both sides of an edge share
// one canonicalizer.
func signatureForCall(fn *types.Func) string {
	return signatureOf(fn)
}

// normalizeReturnType joins multiple return types into a single parenthesized string.
// A single return type is returned as-is; multiple returns become "(t1, t2, ...)".
+func normalizeReturnType(results *types.Tuple) (joined string, parts []string) { + if results == nil || results.Len() == 0 { + return "", nil + } + parts = make([]string, results.Len()) + for i := 0; i < results.Len(); i++ { + parts[i] = results.At(i).Type().String() + } + if len(parts) == 1 { + return parts[0], parts + } + return "(" + strings.Join(parts, ", ") + ")", parts +} diff --git a/internal/syntactic_analysis/symbol_table.go b/internal/syntactic_analysis/symbol_table.go new file mode 100644 index 0000000..f6a7838 --- /dev/null +++ b/internal/syntactic_analysis/symbol_table.go @@ -0,0 +1,944 @@ +package syntactic_analysis + +import ( + "fmt" + "go/ast" + "go/token" + "go/types" + "os" + "reflect" + "strings" + "unicode" + + "golang.org/x/tools/go/packages" + + "github.com/codellm-devkit/codeanalyzer-go/internal/schema" + "github.com/codellm-devkit/codeanalyzer-go/internal/utils" +) + +// SymbolTableBuilder constructs a symbol table by loading packages with full type +// information via golang.org/x/tools/go/packages. One builder per analysis run. +// +// Architecture mirrors codeanalyzer-python's SymbolTableBuilder: a cohesive struct +// with per-node-kind private methods, sharing the loaded package context on self. +type SymbolTableBuilder struct { + projectDir string + fset *token.FileSet + // pkgs is the flat list of loaded packages, keyed by package path. + pkgs map[string]*packages.Package +} + +// NewSymbolTableBuilder creates a builder for projectDir. +func NewSymbolTableBuilder(projectDir string) *SymbolTableBuilder { + return &SymbolTableBuilder{ + projectDir: projectDir, + fset: token.NewFileSet(), + pkgs: map[string]*packages.Package{}, + } +} + +// Build loads all packages under projectDir (or only targetFiles if non-empty), +// walks each file, and returns the symbol table keyed by relative file path. 
+func (b *SymbolTableBuilder) Build(targetFiles []string, skipTests bool) (map[string]schema.GoFile, error) { + cfg := &packages.Config{ + Mode: packages.NeedName | + packages.NeedFiles | + packages.NeedSyntax | + packages.NeedTypes | + packages.NeedTypesInfo | + packages.NeedImports | + packages.NeedDeps, + Dir: b.projectDir, + Fset: b.fset, + // Silence go vet; we only need type info, not a full build. + BuildFlags: []string{}, + } + + // Build pattern(s) for packages.Load. + patterns := []string{"./..."} + if len(targetFiles) > 0 { + // Map each target file to a file= pattern; packages.Load accepts multiple patterns. + patterns = make([]string, len(targetFiles)) + for i, f := range targetFiles { + patterns[i] = "file=" + f + } + } + + pkgList, err := packages.Load(cfg, patterns...) + if err != nil { + return nil, err + } + + // Collect packages, warn on errors but don't abort (graceful partial analysis). + for _, pkg := range pkgList { + if len(pkg.Errors) > 0 { + for _, e := range pkg.Errors { + utils.Warn("package %s: %v", pkg.PkgPath, e) + } + } + if pkg.Types != nil { + b.pkgs[pkg.PkgPath] = pkg + } + } + + symbolTable := map[string]schema.GoFile{} + + for _, pkg := range b.pkgs { + for i, astFile := range pkg.Syntax { + if i >= len(pkg.GoFiles) { + continue + } + filePath := pkg.GoFiles[i] + relPath := utils.RelativePath(b.projectDir, filePath) + if skipTests && utils.IsTestFile(relPath) { + continue + } + if utils.IsVendored(relPath) { + continue + } + goFile := b.buildGoFile(pkg, astFile, filePath, relPath) + symbolTable[relPath] = goFile + } + } + + // In Go a method can be defined in any file of the package; the main loop + // above only attaches a method when its receiver type is in the same file. + // This pass finds methods whose type lives in a sibling file and attaches them. 
+ b.reconcileCrossFileMethods(symbolTable) + + return symbolTable, nil +} + +// reconcileCrossFileMethods attaches methods to their type's owner file when +// the method and its receiver type are declared in different files of the same package. +func (b *SymbolTableBuilder) reconcileCrossFileMethods(symbolTable map[string]schema.GoFile) { + // Build index: (pkgPath, shortTypeName) → relPath of the file that owns the type. + type typeKey struct{ pkgPath, typeName string } + typeOwner := make(map[typeKey]string) + for relPath, gf := range symbolTable { + pkgPath := b.filePkgPath(relPath) + for typeName := range gf.Types { + typeOwner[typeKey{pkgPath, typeName}] = relPath + } + } + + for _, pkg := range b.pkgs { + for i, astFile := range pkg.Syntax { + if i >= len(pkg.GoFiles) { + continue + } + filePath := pkg.GoFiles[i] + relPath := utils.RelativePath(b.projectDir, filePath) + + for _, decl := range astFile.Decls { + fd, ok := decl.(*ast.FuncDecl) + if !ok || fd.Recv == nil { + continue + } + typeName := b.receiverTypeName(fd.Recv) + if typeName == "" { + continue + } + ownerRelPath, ok := typeOwner[typeKey{pkg.PkgPath, typeName}] + if !ok || ownerRelPath == relPath { + continue // type not found or already handled by the main loop + } + + callable := b.buildCallable(pkg, astFile, fd) + if callable == nil { + continue + } + ownerFile, found := symbolTable[ownerRelPath] + if !found { + continue + } + gt, found := ownerFile.Types[typeName] + if !found { + continue + } + if _, alreadyPresent := gt.Methods[callable.Signature]; !alreadyPresent { + gt.Methods[callable.Signature] = *callable + ownerFile.Types[typeName] = gt + symbolTable[ownerRelPath] = ownerFile + } + } + } + } +} + +// filePkgPath returns the package import path for the file at relPath. 
+func (b *SymbolTableBuilder) filePkgPath(relPath string) string { + for _, pkg := range b.pkgs { + for _, absFile := range pkg.GoFiles { + if utils.RelativePath(b.projectDir, absFile) == relPath { + return pkg.PkgPath + } + } + } + return "" +} + +// ─── Per-file builder ────────────────────────────────────────────────────────── + +func (b *SymbolTableBuilder) buildGoFile( + pkg *packages.Package, + astFile *ast.File, + absPath, relPath string, +) schema.GoFile { + info, _ := os.Stat(absPath) + hash, _ := utils.FileHash(absPath) + + var lastMod *float64 + var fileSize *int64 + if info != nil { + lm := float64(info.ModTime().Unix()) + float64(info.ModTime().Nanosecond())/1e9 + lastMod = &lm + sz := info.Size() + fileSize = &sz + } + var contentHash *string + if hash != "" { + contentHash = &hash + } + + gf := schema.GoFile{ + FilePath: relPath, + PackageName: astFile.Name.Name, + Imports: b.buildImports(astFile), + Comments: b.buildFileComments(astFile), + Types: map[string]schema.GoType{}, + Functions: map[string]schema.GoCallable{}, + Variables: b.buildPackageVars(pkg, astFile), + ContentHash: contentHash, + LastModified: lastMod, + FileSize: fileSize, + } + + // Walk top-level declarations. + for _, decl := range astFile.Decls { + switch d := decl.(type) { + case *ast.GenDecl: + b.processGenDecl(pkg, astFile, d, &gf) + case *ast.FuncDecl: + callable := b.buildCallable(pkg, astFile, d) + if callable == nil { + continue + } + if d.Recv != nil { + // Method — attach to its type. 
+ if typeName := b.receiverTypeName(d.Recv); typeName != "" { + if gt, ok := gf.Types[typeName]; ok { + gt.Methods[callable.Signature] = *callable + gf.Types[typeName] = gt + } + } + } else { + gf.Functions[callable.Signature] = *callable + } + } + } + + return gf +} + +// ─── Imports ────────────────────────────────────────────────────────────────── + +func (b *SymbolTableBuilder) buildImports(astFile *ast.File) []schema.GoImport { + var imports []schema.GoImport + for _, imp := range astFile.Imports { + path := strings.Trim(imp.Path.Value, `"`) + alias := "" + if imp.Name != nil { + alias = imp.Name.Name + } + pos := b.fset.Position(imp.Pos()) + end := b.fset.Position(imp.End()) + imports = append(imports, schema.GoImport{ + Module: path, + Alias: alias, + StartLine: pos.Line, + EndLine: end.Line, + }) + } + return imports +} + +// ─── Comments ───────────────────────────────────────────────────────────────── + +func (b *SymbolTableBuilder) buildFileComments(astFile *ast.File) []schema.GoComment { + var comments []schema.GoComment + for _, cg := range astFile.Comments { + for _, c := range cg.List { + pos := b.fset.Position(c.Pos()) + end := b.fset.Position(c.End()) + comments = append(comments, schema.GoComment{ + Content: c.Text, + StartLine: pos.Line, + EndLine: end.Line, + StartColumn: pos.Column, + EndColumn: end.Column, + IsDocComment: strings.HasPrefix(c.Text, "//") || strings.HasPrefix(c.Text, "/*"), + }) + } + } + return comments +} + +func (b *SymbolTableBuilder) docComments(doc *ast.CommentGroup) []schema.GoComment { + if doc == nil { + return nil + } + var comments []schema.GoComment + for _, c := range doc.List { + pos := b.fset.Position(c.Pos()) + end := b.fset.Position(c.End()) + comments = append(comments, schema.GoComment{ + Content: c.Text, + StartLine: pos.Line, + EndLine: end.Line, + StartColumn: pos.Column, + EndColumn: end.Column, + IsDocComment: true, + }) + } + return comments +} + +// ─── GenDecl processor (type/var/const declarations) 
───────────────────────── + +func (b *SymbolTableBuilder) processGenDecl( + pkg *packages.Package, + astFile *ast.File, + decl *ast.GenDecl, + gf *schema.GoFile, +) { + for _, spec := range decl.Specs { + switch s := spec.(type) { + case *ast.TypeSpec: + gt := b.buildType(pkg, astFile, decl, s) + if gt != nil { + gf.Types[gt.Name] = *gt + } + } + } +} + +// ─── Type builder ───────────────────────────────────────────────────────────── + +func (b *SymbolTableBuilder) buildType( + pkg *packages.Package, + astFile *ast.File, + decl *ast.GenDecl, + spec *ast.TypeSpec, +) *schema.GoType { + typeName := spec.Name.Name + isExported := unicode.IsUpper(rune(typeName[0])) + + // Resolve via type info if available. + var typSig string + if pkg.TypesInfo != nil { + if obj, ok := pkg.TypesInfo.Defs[spec.Name]; ok && obj != nil { + typSig = signatureOf(obj) + } + } + if typSig == "" { + typSig = pkg.Types.Path() + "." + typeName + } + + pos := b.fset.Position(spec.Pos()) + end := b.fset.Position(spec.End()) + + gt := &schema.GoType{ + Name: typeName, + Signature: typSig, + Comments: b.docComments(decl.Doc), + IsExported: isExported, + Fields: []schema.GoField{}, + Methods: map[string]schema.GoCallable{}, + BaseTypes: []string{}, + InnerTypes: map[string]schema.GoType{}, + StartLine: pos.Line, + EndLine: end.Line, + } + + switch t := spec.Type.(type) { + case *ast.StructType: + gt.IsInterface = false + gt.Fields = b.buildStructFields(pkg, t) + gt.BaseTypes = b.embeddedTypes(pkg, t) + case *ast.InterfaceType: + gt.IsInterface = true + // Interface methods are collected when we process FuncDecl with receivers, + // but interface method signatures within the interface type are also recorded here. 
+ b.collectInterfaceMethods(pkg, t, gt) + } + + return gt +} + +func (b *SymbolTableBuilder) buildStructFields(pkg *packages.Package, st *ast.StructType) []schema.GoField { + var fields []schema.GoField + if st.Fields == nil { + return fields + } + for _, field := range st.Fields.List { + typStr := b.typeString(pkg, field.Type) + tags := parseStructTags(field.Tag) + isEmbedded := len(field.Names) == 0 + + if isEmbedded { + pos := b.fset.Position(field.Pos()) + end := b.fset.Position(field.End()) + fields = append(fields, schema.GoField{ + Name: typStr, + Type: typStr, + Tags: tags, + IsExported: true, + IsEmbedded: true, + StartLine: pos.Line, + EndLine: end.Line, + }) + continue + } + for _, name := range field.Names { + pos := b.fset.Position(name.Pos()) + end := b.fset.Position(field.End()) + fields = append(fields, schema.GoField{ + Name: name.Name, + Type: typStr, + Tags: tags, + IsExported: unicode.IsUpper(rune(name.Name[0])), + IsEmbedded: false, + StartLine: pos.Line, + EndLine: end.Line, + }) + } + } + return fields +} + +func (b *SymbolTableBuilder) embeddedTypes(pkg *packages.Package, st *ast.StructType) []string { + var embedded []string + if st.Fields == nil { + return embedded + } + for _, field := range st.Fields.List { + if len(field.Names) == 0 { + embedded = append(embedded, b.typeString(pkg, field.Type)) + } + } + return embedded +} + +func (b *SymbolTableBuilder) collectInterfaceMethods(pkg *packages.Package, it *ast.InterfaceType, gt *schema.GoType) { + if it.Methods == nil { + return + } + for _, method := range it.Methods.List { + if len(method.Names) == 0 { + // Embedded interface — add to base_types. 
+ gt.BaseTypes = append(gt.BaseTypes, b.typeString(pkg, method.Type)) + continue + } + for _, name := range method.Names { + pos := b.fset.Position(name.Pos()) + end := b.fset.Position(method.End()) + + var retType string + var retTypes []string + if ft, ok := method.Type.(*ast.FuncType); ok && ft.Results != nil { + for _, r := range ft.Results.List { + retTypes = append(retTypes, b.typeString(pkg, r.Type)) + } + retType = b.joinReturnTypes(retTypes) + } + + sig := pkg.Types.Path() + "." + gt.Name + "." + name.Name + callable := schema.GoCallable{ + Name: name.Name, + Path: "", + Signature: sig, + Parameters: b.buildFuncTypeParams(pkg, method.Type), + ReturnType: retType, + ReturnTypes: retTypes, + IsExported: unicode.IsUpper(rune(name.Name[0])), + ReceiverType: gt.Signature, + CallSites: []schema.GoCallsite{}, + InnerCallables: map[string]schema.GoCallable{}, + StartLine: pos.Line, + EndLine: end.Line, + } + gt.Methods[sig] = callable + } + } +} + +// ─── Callable builder ───────────────────────────────────────────────────────── + +func (b *SymbolTableBuilder) buildCallable( + pkg *packages.Package, + astFile *ast.File, + decl *ast.FuncDecl, +) *schema.GoCallable { + name := decl.Name.Name + isExported := unicode.IsUpper(rune(name[0])) + pos := b.fset.Position(decl.Pos()) + end := b.fset.Position(decl.End()) + + var sig string + if pkg.TypesInfo != nil { + if obj, ok := pkg.TypesInfo.Defs[decl.Name]; ok && obj != nil { + sig = signatureOf(obj) + } + } + if sig == "" { + if decl.Recv != nil { + if recvName := b.receiverTypeName(decl.Recv); recvName != "" { + sig = pkg.Types.Path() + "." + recvName + "." + name + } + } + if sig == "" { + sig = pkg.Types.Path() + "." 
+ name + } + } + + var recvType, recvName string + if decl.Recv != nil && len(decl.Recv.List) > 0 { + rf := decl.Recv.List[0] + recvType = b.typeString(pkg, rf.Type) + if len(rf.Names) > 0 { + recvName = rf.Names[0].Name + } + } + + retType, retTypes := b.buildReturnTypes(pkg, decl.Type) + bodyStart := pos.Line + if decl.Body != nil { + bodyStart = b.fset.Position(decl.Body.Pos()).Line + } + + callable := &schema.GoCallable{ + Name: name, + Path: utils.RelativePath(b.projectDir, b.fset.File(decl.Pos()).Name()), + Signature: sig, + Comments: b.docComments(decl.Doc), + Parameters: b.buildParams(pkg, decl.Type), + ReturnType: retType, + ReturnTypes: retTypes, + IsExported: isExported, + ReceiverType: recvType, + ReceiverName: recvName, + CallSites: []schema.GoCallsite{}, + InnerCallables: map[string]schema.GoCallable{}, + LocalVariables: []schema.GoVariableDeclaration{}, + AccessedSymbols: []schema.GoSymbol{}, + StartLine: pos.Line, + EndLine: end.Line, + CodeStartLine: bodyStart, + CyclomaticComplexity: b.cyclomaticComplexity(decl), + } + + if decl.Body != nil { + callable.Code = b.nodeSource(decl) + callable.CallSites = b.buildCallSites(pkg, decl.Body) + callable.LocalVariables = b.buildLocalVars(pkg, decl.Body) + } + + return callable +} + +// ─── Parameters ─────────────────────────────────────────────────────────────── + +func (b *SymbolTableBuilder) buildParams(pkg *packages.Package, ft *ast.FuncType) []schema.GoParameter { + if ft == nil || ft.Params == nil { + return nil + } + var params []schema.GoParameter + for _, field := range ft.Params.List { + typStr := b.typeString(pkg, field.Type) + isVariadic := false + if _, ok := field.Type.(*ast.Ellipsis); ok { + isVariadic = true + } + pos := b.fset.Position(field.Pos()) + end := b.fset.Position(field.End()) + if len(field.Names) == 0 { + params = append(params, schema.GoParameter{ + Name: "_", Type: typStr, IsVariadic: isVariadic, + StartLine: pos.Line, EndLine: end.Line, + }) + continue + } + for _, name := 
range field.Names { + params = append(params, schema.GoParameter{ + Name: name.Name, Type: typStr, IsVariadic: isVariadic, + StartLine: pos.Line, EndLine: end.Line, + }) + } + } + return params +} + +func (b *SymbolTableBuilder) buildFuncTypeParams(pkg *packages.Package, expr ast.Expr) []schema.GoParameter { + if ft, ok := expr.(*ast.FuncType); ok { + return b.buildParams(pkg, ft) + } + return nil +} + +// ─── Return types ───────────────────────────────────────────────────────────── + +func (b *SymbolTableBuilder) buildReturnTypes(pkg *packages.Package, ft *ast.FuncType) (string, []string) { + if ft == nil || ft.Results == nil { + return "", nil + } + var parts []string + for _, field := range ft.Results.List { + typStr := b.typeString(pkg, field.Type) + if len(field.Names) == 0 { + parts = append(parts, typStr) + } else { + for range field.Names { + parts = append(parts, typStr) + } + } + } + return b.joinReturnTypes(parts), parts +} + +func (b *SymbolTableBuilder) joinReturnTypes(parts []string) string { + if len(parts) == 0 { + return "" + } + if len(parts) == 1 { + return parts[0] + } + return "(" + strings.Join(parts, ", ") + ")" +} + +// ─── Call sites ─────────────────────────────────────────────────────────────── + +// buildCallSites walks the function body and records every call expression. +// callee_signature is left nil here — the resolver backfills it in the call-graph stage. +func (b *SymbolTableBuilder) buildCallSites(pkg *packages.Package, body *ast.BlockStmt) []schema.GoCallsite { + var sites []schema.GoCallsite + ast.Inspect(body, func(n ast.Node) bool { + switch node := n.(type) { + case *ast.GoStmt: + // Goroutine launch — record with is_goroutine=true. 
+ if call, ok := node.Call.Fun.(*ast.CallExpr); ok { + _ = call + } + site := b.callExprToSite(pkg, node.Call, true) + if site != nil { + sites = append(sites, *site) + } + return false + case *ast.CallExpr: + site := b.callExprToSite(pkg, node, false) + if site != nil { + sites = append(sites, *site) + } + } + return true + }) + return sites +} + +func (b *SymbolTableBuilder) callExprToSite(pkg *packages.Package, call *ast.CallExpr, isGoroutine bool) *schema.GoCallsite { + pos := b.fset.Position(call.Pos()) + end := b.fset.Position(call.End()) + + var methodName, receiverExpr, receiverType string + isConstructor := false + + switch fn := call.Fun.(type) { + case *ast.Ident: + methodName = fn.Name + // Check if it's a type conversion / constructor call. + if pkg.TypesInfo != nil { + if obj := pkg.TypesInfo.ObjectOf(fn); obj != nil { + if _, ok := obj.(*types.TypeName); ok { + isConstructor = true + } + } + } + case *ast.SelectorExpr: + methodName = fn.Sel.Name + receiverExpr = b.exprString(fn.X) + if pkg.TypesInfo != nil { + if t := pkg.TypesInfo.TypeOf(fn.X); t != nil { + receiverType = t.String() + } + } + default: + methodName = b.exprString(call.Fun) + } + + // Collect argument types. 
+ var argTypes []string + for _, arg := range call.Args { + if pkg.TypesInfo != nil { + if t := pkg.TypesInfo.TypeOf(arg); t != nil { + argTypes = append(argTypes, t.String()) + continue + } + } + argTypes = append(argTypes, "") + } + + return &schema.GoCallsite{ + MethodName: methodName, + ReceiverExpr: receiverExpr, + ReceiverType: receiverType, + ArgumentTypes: argTypes, + IsConstructorCall: isConstructor, + IsGoroutine: isGoroutine, + CalleeSignature: nil, // backfilled by the call-graph stage + StartLine: pos.Line, + StartColumn: pos.Column, + EndLine: end.Line, + EndColumn: end.Column, + } +} + +// ─── Local variables ────────────────────────────────────────────────────────── + +func (b *SymbolTableBuilder) buildLocalVars(pkg *packages.Package, body *ast.BlockStmt) []schema.GoVariableDeclaration { + var vars []schema.GoVariableDeclaration + ast.Inspect(body, func(n ast.Node) bool { + switch node := n.(type) { + case *ast.AssignStmt: + if node.Tok.String() == ":=" { + for i, lhs := range node.Lhs { + if ident, ok := lhs.(*ast.Ident); ok { + pos := b.fset.Position(ident.Pos()) + typStr := "" + if pkg.TypesInfo != nil { + if obj := pkg.TypesInfo.ObjectOf(ident); obj != nil { + typStr = obj.Type().String() + } + } + init := "" + if i < len(node.Rhs) { + init = b.exprString(node.Rhs[i]) + } + vars = append(vars, schema.GoVariableDeclaration{ + Name: ident.Name, Type: typStr, Initializer: init, + Scope: "function", StartLine: pos.Line, EndLine: pos.Line, + }) + } + } + } + case *ast.DeclStmt: + if gen, ok := node.Decl.(*ast.GenDecl); ok { + for _, spec := range gen.Specs { + if vs, ok := spec.(*ast.ValueSpec); ok { + for i, name := range vs.Names { + pos := b.fset.Position(name.Pos()) + typStr := "" + if vs.Type != nil { + typStr = b.typeString(pkg, vs.Type) + } else if pkg.TypesInfo != nil { + if obj := pkg.TypesInfo.ObjectOf(name); obj != nil { + typStr = obj.Type().String() + } + } + init := "" + if i < len(vs.Values) { + init = b.exprString(vs.Values[i]) + } + 
vars = append(vars, schema.GoVariableDeclaration{ + Name: name.Name, Type: typStr, Initializer: init, + Scope: "function", StartLine: pos.Line, EndLine: pos.Line, + }) + } + } + } + } + } + return true + }) + return vars +} + +// ─── Package-level variables ────────────────────────────────────────────────── + +func (b *SymbolTableBuilder) buildPackageVars(pkg *packages.Package, astFile *ast.File) []schema.GoVariableDeclaration { + var vars []schema.GoVariableDeclaration + for _, decl := range astFile.Decls { + gen, ok := decl.(*ast.GenDecl) + if !ok { + continue + } + if gen.Tok.String() != "var" && gen.Tok.String() != "const" { + continue + } + for _, spec := range gen.Specs { + vs, ok := spec.(*ast.ValueSpec) + if !ok { + continue + } + for i, name := range vs.Names { + pos := b.fset.Position(name.Pos()) + typStr := "" + if vs.Type != nil { + typStr = b.typeString(pkg, vs.Type) + } else if pkg.TypesInfo != nil { + if obj := pkg.TypesInfo.ObjectOf(name); obj != nil { + typStr = obj.Type().String() + } + } + init := "" + if i < len(vs.Values) { + init = b.exprString(vs.Values[i]) + } + vars = append(vars, schema.GoVariableDeclaration{ + Name: name.Name, Type: typStr, Initializer: init, + Scope: "package", StartLine: pos.Line, EndLine: pos.Line, + }) + } + } + } + return vars +} + +// ─── Cyclomatic complexity ──────────────────────────────────────────────────── + +// cyclomaticComplexity computes McCabe complexity: 1 + decision points. 
+func (b *SymbolTableBuilder) cyclomaticComplexity(decl *ast.FuncDecl) int { + if decl.Body == nil { + return 0 + } + complexity := 1 + ast.Inspect(decl.Body, func(n ast.Node) bool { + switch n.(type) { + case *ast.IfStmt, *ast.ForStmt, *ast.RangeStmt, *ast.SwitchStmt, + *ast.TypeSwitchStmt, *ast.SelectStmt, *ast.CaseClause, + *ast.CommClause: + complexity++ + } + return true + }) + return complexity +} + +// ─── Helpers ────────────────────────────────────────────────────────────────── + +// receiverTypeName extracts the base type name from a receiver field list. +func (b *SymbolTableBuilder) receiverTypeName(recv *ast.FieldList) string { + if recv == nil || len(recv.List) == 0 { + return "" + } + expr := recv.List[0].Type + // Strip pointer. + if star, ok := expr.(*ast.StarExpr); ok { + expr = star.X + } + if ident, ok := expr.(*ast.Ident); ok { + return ident.Name + } + return "" +} + +// typeString returns a human-readable string for an ast.Expr type node. +func (b *SymbolTableBuilder) typeString(pkg *packages.Package, expr ast.Expr) string { + if pkg.TypesInfo != nil { + if t := pkg.TypesInfo.TypeOf(expr); t != nil { + return t.String() + } + } + // Fallback: print the expression. + return b.exprString(expr) +} + +// exprString returns a best-effort source representation of an expression. +func (b *SymbolTableBuilder) exprString(expr ast.Expr) string { + if expr == nil { + return "" + } + switch e := expr.(type) { + case *ast.Ident: + return e.Name + case *ast.SelectorExpr: + return b.exprString(e.X) + "." + e.Sel.Name + case *ast.StarExpr: + return "*" + b.exprString(e.X) + case *ast.ArrayType: + return "[]" + b.exprString(e.Elt) + case *ast.MapType: + return "map[" + b.exprString(e.Key) + "]" + b.exprString(e.Value) + case *ast.InterfaceType: + return "interface{}" + case *ast.BasicLit: + return e.Value + default: + return fmt.Sprint(expr) + } +} + +// nodeSource extracts the raw source text of a node (best effort). 
+func (b *SymbolTableBuilder) nodeSource(node ast.Node) string { + pos := b.fset.Position(node.Pos()) + if pos.Filename == "" { + return "" + } + data, err := os.ReadFile(pos.Filename) + if err != nil { + return "" + } + startOff := b.fset.File(node.Pos()).Offset(node.Pos()) + endOff := b.fset.File(node.End()).Offset(node.End()) + if startOff < 0 || endOff > len(data) || startOff >= endOff { + return "" + } + return string(data[startOff:endOff]) +} + +// parseStructTags parses a struct tag literal into a key→value map. +// e.g. `json:"name,omitempty" db:"name"` → {"json": "name,omitempty", "db": "name"} +func parseStructTags(lit *ast.BasicLit) map[string]string { + tags := map[string]string{} + if lit == nil { + return tags + } + raw := strings.Trim(lit.Value, "`") + st := reflect.StructTag(raw) + // Iterate common tag keys; for a full parse, walk the raw string. + for _, key := range extractTagKeys(raw) { + if v := st.Get(key); v != "" { + tags[key] = v + } + } + return tags +} + +// extractTagKeys extracts tag key names from a raw struct tag string. +func extractTagKeys(raw string) []string { + var keys []string + for len(raw) > 0 { + raw = strings.TrimLeft(raw, " \t") + if raw == "" { + break + } + idx := strings.IndexByte(raw, ':') + if idx < 0 { + break + } + keys = append(keys, raw[:idx]) + // Skip past the value. + rest := raw[idx+1:] + if len(rest) == 0 || rest[0] != '"' { + break + } + end := strings.IndexByte(rest[1:], '"') + if end < 0 { + break + } + raw = rest[end+2:] + } + return keys +} + +// ensure fmt is used (used in exprString fallback and signature.go). +var _ = fmt.Sprintf diff --git a/internal/utils/fs.go b/internal/utils/fs.go new file mode 100644 index 0000000..e46da2e --- /dev/null +++ b/internal/utils/fs.go @@ -0,0 +1,81 @@ +// Package utils provides filesystem helpers and logging utilities. 
+package utils + +import ( + "crypto/sha256" + "fmt" + "io" + "os" + "path/filepath" + "strings" +) + +// IsTestFile reports whether path is a Go test file (_test.go suffix). +func IsTestFile(path string) bool { + return strings.HasSuffix(path, "_test.go") +} + +// IsVendored reports whether path is under a vendored or generated directory. +func IsVendored(path string) bool { + for _, seg := range strings.Split(filepath.ToSlash(path), "/") { + switch seg { + case "vendor", "testdata", ".git": + return true + } + } + return false +} + +// RelativePath returns path relative to root, or path itself on error. +func RelativePath(root, path string) string { + rel, err := filepath.Rel(root, path) + if err != nil { + return path + } + return rel +} + +// FileHash returns the SHA-256 hex digest of the file at path. +func FileHash(path string) (string, error) { + f, err := os.Open(path) + if err != nil { + return "", err + } + defer f.Close() + h := sha256.New() + if _, err := io.Copy(h, f); err != nil { + return "", err + } + return fmt.Sprintf("%x", h.Sum(nil)), nil +} + +// DiscoverGoFiles returns all *.go files under root, skipping vendored dirs +// and optionally test files. +func DiscoverGoFiles(root string, skipTests bool) ([]string, error) { + var files []string + err := filepath.WalkDir(root, func(path string, d os.DirEntry, err error) error { + if err != nil { + return nil // skip unreadable entries gracefully + } + if d.IsDir() { + if IsVendored(path) { + return filepath.SkipDir + } + return nil + } + if !strings.HasSuffix(path, ".go") { + return nil + } + if skipTests && IsTestFile(path) { + return nil + } + files = append(files, path) + return nil + }) + return files, err +} + +// EnsureDir creates dir and all parents if they don't exist. 
+func EnsureDir(dir string) error { + return os.MkdirAll(dir, 0o755) +} diff --git a/internal/utils/logging.go b/internal/utils/logging.go new file mode 100644 index 0000000..a5c7674 --- /dev/null +++ b/internal/utils/logging.go @@ -0,0 +1,30 @@ +package utils + +import ( + "fmt" + "os" +) + +var verbosity int + +// SetVerbosity sets the global log level (0=quiet, 1=info, 2=debug). +func SetVerbosity(v int) { verbosity = v } + +// Info logs an informational message when verbosity >= 1. +func Info(format string, args ...any) { + if verbosity >= 1 { + fmt.Fprintf(os.Stderr, "[codeanalyzer-go] "+format+"\n", args...) + } +} + +// Debug logs a debug message when verbosity >= 2. +func Debug(format string, args ...any) { + if verbosity >= 2 { + fmt.Fprintf(os.Stderr, "[codeanalyzer-go DEBUG] "+format+"\n", args...) + } +} + +// Warn always prints a warning to stderr. +func Warn(format string, args ...any) { + fmt.Fprintf(os.Stderr, "[codeanalyzer-go WARN] "+format+"\n", args...) +} diff --git a/testdata/fixture/go.mod b/testdata/fixture/go.mod new file mode 100644 index 0000000..30818c1 --- /dev/null +++ b/testdata/fixture/go.mod @@ -0,0 +1,3 @@ +module example.com/fixture + +go 1.21 diff --git a/testdata/fixture/main.go b/testdata/fixture/main.go new file mode 100644 index 0000000..f20e3f9 --- /dev/null +++ b/testdata/fixture/main.go @@ -0,0 +1,15 @@ +package main + +import ( + "fmt" + + "example.com/fixture/pkg/greeter" +) + +func main() { + g := greeter.New("Hello") + msg := g.Greet("World") + fmt.Println(msg) + loud := greeter.Shout(msg) + fmt.Println(loud) +} diff --git a/testdata/fixture/pkg/greeter/greeter.go b/testdata/fixture/pkg/greeter/greeter.go new file mode 100644 index 0000000..85f19c6 --- /dev/null +++ b/testdata/fixture/pkg/greeter/greeter.go @@ -0,0 +1,29 @@ +// Package greeter provides simple greeting functionality. +package greeter + +import "fmt" + +// Greeter holds a greeting prefix. 
+type Greeter struct { + Prefix string `json:"prefix"` +} + +// New creates a Greeter with the given prefix. +func New(prefix string) *Greeter { + return &Greeter{Prefix: prefix} +} + +// Greet returns a greeting for name. +func (g *Greeter) Greet(name string) string { + return fmt.Sprintf("%s, %s!", g.Prefix, name) +} + +// Logger is a simple logging interface. +type Logger interface { + Log(msg string) +} + +// Shout returns the message in a louder form. +func Shout(msg string) string { + return msg + "!!!" +} diff --git a/testdata/realistic/go.mod b/testdata/realistic/go.mod new file mode 100644 index 0000000..f1da8db --- /dev/null +++ b/testdata/realistic/go.mod @@ -0,0 +1,3 @@ +module example.com/realistic + +go 1.21 diff --git a/testdata/realistic/main.go b/testdata/realistic/main.go new file mode 100644 index 0000000..81ae136 --- /dev/null +++ b/testdata/realistic/main.go @@ -0,0 +1,26 @@ +package main + +import ( + "fmt" + "log" + + "example.com/realistic/server" + "example.com/realistic/worker" +) + +func main() { + cfg := server.Config{Host: "localhost", Port: 8080} + srv, err := server.New(cfg) + if err != nil { + log.Fatal(err) + } + + fmt.Println(srv.Addr()) + fmt.Println(server.Tags("env", "prod", "region", "us-east-1")) + + w := worker.New() + w.Run(nil, worker.Task{ID: 1, Payload: "hello"}) + + combined := worker.Combine(worker.Result{TaskID: 1, Output: "a"}, worker.Result{TaskID: 2, Output: "b"}) + fmt.Println(combined.Output) +} diff --git a/testdata/realistic/server/middleware.go b/testdata/realistic/server/middleware.go new file mode 100644 index 0000000..70ae722 --- /dev/null +++ b/testdata/realistic/server/middleware.go @@ -0,0 +1,19 @@ +// Second file in the server package — exercises multi-file package detection. +package server + +import ( + "fmt" + "strings" +) + +// Tags formats key-value pairs into a single string. +// Exercises: variadic parameter (pairs ...string), is_variadic=true. 
+func Tags(pairs ...string) string { + return fmt.Sprintf("[%s]", strings.Join(pairs, ", ")) +} + +// Describe returns a human-readable description of the server. +// Exercises: value receiver (s Server) vs pointer receiver in server.go. +func (s Server) Describe() string { + return fmt.Sprintf("server at %s", s.Addr()) +} diff --git a/testdata/realistic/server/server.go b/testdata/realistic/server/server.go new file mode 100644 index 0000000..4034e77 --- /dev/null +++ b/testdata/realistic/server/server.go @@ -0,0 +1,53 @@ +// Package server provides a minimal configurable server. +package server + +import ( + "errors" + "fmt" +) + +// Config holds server configuration. +type Config struct { + Host string `json:"host" validate:"required"` + Port int `json:"port"` +} + +// Server wraps a Config with lifecycle state. +// It embeds Config directly so callers can access Host and Port without indirection. +type Server struct { + Config // embedded — exercises GoField.is_embedded + ready bool // unexported field — exercises is_exported=false on GoField +} + +// New creates a Server, returning an error if the config is invalid. +// Exercises: multiple return types (*Server, error), (T, error) idiom. +func New(cfg Config) (*Server, error) { + if cfg.Host == "" { + return nil, errors.New("host required") + } + return &Server{Config: cfg}, nil +} + +// Addr returns the host:port address string. +// Exercises: pointer receiver (*Server), non-empty receiver_type / receiver_name. +func (s *Server) Addr() string { + return fmt.Sprintf("%s:%d", s.Host, s.Port) +} + +// Validate checks the config fields and returns any validation errors. +// Exercises: named return, multiple return types (bool, error). +func (s *Server) Validate() (bool, error) { + if s.Host == "" { + return false, errors.New("host is empty") + } + if s.Port <= 0 { + return false, fmt.Errorf("invalid port: %d", s.Port) + } + return true, nil +} + +// shutdown performs internal cleanup. 
+// Exercises: unexported method — is_exported=false. +func (s *Server) shutdown() { + s.ready = false +} diff --git a/testdata/realistic/worker/worker.go b/testdata/realistic/worker/worker.go new file mode 100644 index 0000000..0559946 --- /dev/null +++ b/testdata/realistic/worker/worker.go @@ -0,0 +1,65 @@ +// Package worker runs tasks concurrently. +package worker + +import ( + "fmt" + "sync" +) + +// Task is a unit of work. +type Task struct { + ID int `json:"id"` + Payload string `json:"payload"` +} + +// Result holds the outcome of processing a Task. +type Result struct { + TaskID int `json:"task_id"` + Output string `json:"output"` +} + +// Processor is the processing interface. +// Exercises: interface type, is_interface=true. +type Processor interface { + Process(t Task) (Result, error) +} + +// Worker runs tasks in background goroutines. +type Worker struct { + mu sync.Mutex + done bool +} + +// New creates a Worker. +func New() *Worker { + return &Worker{} +} + +// Run launches a goroutine to process t. +// Exercises: goroutine launch — GoCallsite.is_goroutine=true for the w.execute call. +func (w *Worker) Run(p Processor, t Task) { + go w.execute(p, t) +} + +// Combine merges multiple results into a single Result. +// Exercises: variadic parameter (results ...Result), is_variadic=true. +func Combine(results ...Result) Result { + out := Result{} + for _, r := range results { + out.Output += r.Output + } + return out +} + +// execute processes a task under the mutex. +// Exercises: unexported method (is_exported=false), cyclomatic_complexity > 1 (if branch). 
+func (w *Worker) execute(p Processor, t Task) { + w.mu.Lock() + defer w.mu.Unlock() + r, err := p.Process(t) + if err != nil { + _ = fmt.Errorf("task %d: %w", t.ID, err) + return + } + _ = r +} From fe277cd34fea64360312ba31cbda066cea12149a Mon Sep 17 00:00:00 2001 From: Saurabh Sinha Date: Wed, 17 Jun 2026 09:58:22 -0400 Subject: [PATCH 2/4] chore: merge standard Go .gitignore with project-specific entries Incorporates the GitHub Go template (*.dll, *.so, go.work, .env, etc.) alongside project-specific ignores for built binaries and .claude/. Signed-off-by: Saurabh Sinha --- .gitignore | 32 +++++++++++++++++++++++++------- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/.gitignore b/.gitignore index 169d632..08088c3 100644 --- a/.gitignore +++ b/.gitignore @@ -1,19 +1,37 @@ -# Binaries +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Project-specific binaries /codeanalyzer /codeanalyzer-go -*.exe # Build output /dist/ /bin/ +# Test binary, built with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out +coverage.txt + +# Dependency directories (remove the comment below to include it) +# vendor/ + +# Go workspace file +go.work +go.work.sum + +# env file +.env + # Claude Code session data .claude/ # macOS .DS_Store - -# Go test cache / coverage -*.test -*.out -coverage.txt From b7a7c9b1efa7af7d34c56de0da0e3648bc2e53c5 Mon Sep 17 00:00:00 2001 From: Saurabh Sinha Date: Wed, 17 Jun 2026 16:20:20 -0400 Subject: [PATCH 3/4] test: add functional test coverage and restructure testdata fixtures Bug fix: populate InnerCallables by walking ast.FuncLit nodes in buildInnerCallables; add *ast.FuncLit: return false to buildCallSites so closure call sites are not double-counted in the outer function. 
New test files: - cmd/codeanalyzer/main_test.go: CLI integration tests covering --version, --format validation, --output, --analysis-level, --skip-tests, --target-files - internal/analysis/registry_test.go: orderPasses topo-sort tests and RunPipeline smoke test - internal/semantic_analysis/call_graph_test.go: MergeEdges unit tests - internal/utils/fs_test.go: table-driven tests for IsTestFile, IsVendored, FileHash, EnsureDir, DiscoverGoFiles - internal/core/skip_tests_test.go: SkipTests true/false behaviour - internal/core/incremental_test.go: TargetFiles single/multi package and nil Test additions to existing files: - chi_test.go: InnerCallables populated, IsConstructorCall for methodTyp() type conversion, init() presence and is_exported - multipackage_test.go (renamed from realistic_test.go): LocalVariables present with correct type and scope Testdata restructure: - testdata/fixture/ -> testdata/greeter/ - testdata/realistic/ -> testdata/multipackage/ - testdata/chi/: replace single-file wrapper with chi v5 library source (35 files) - testdata/generics/: add generics fixture (fn + set packages) - testdata/multipackage/server/server_test.go: minimal test file for --skip-tests=false coverage Signed-off-by: Saurabh Sinha --- README.md | 32 +- cmd/codeanalyzer/main.go | 21 +- cmd/codeanalyzer/main_test.go | 205 ++++ internal/analysis/registry_test.go | 137 +++ internal/core/analyzer_test.go | 119 +-- internal/core/chi_test.go | 196 ++++ internal/core/errors_test.go | 76 ++ internal/core/generics_test.go | 127 +++ internal/core/incremental_test.go | 120 +++ ...realistic_test.go => multipackage_test.go} | 149 ++- internal/core/skip_tests_test.go | 70 ++ internal/core/testsetup_test.go | 99 ++ internal/semantic_analysis/call_graph_test.go | 116 +++ internal/syntactic_analysis/symbol_table.go | 68 ++ internal/utils/fs_test.go | 206 ++++ testdata/chi/chain.go | 49 + testdata/chi/chi.go | 137 +++ testdata/chi/context.go | 166 ++++ testdata/chi/go.mod | 3 + 
testdata/chi/middleware/basic_auth.go | 33 + testdata/chi/middleware/clean_path.go | 28 + testdata/chi/middleware/client_ip.go | 263 ++++++ testdata/chi/middleware/compress.go | 392 ++++++++ testdata/chi/middleware/content_charset.go | 45 + testdata/chi/middleware/content_encoding.go | 34 + testdata/chi/middleware/content_type.go | 45 + testdata/chi/middleware/get_head.go | 39 + testdata/chi/middleware/heartbeat.go | 26 + testdata/chi/middleware/logger.go | 178 ++++ testdata/chi/middleware/maybe.go | 18 + testdata/chi/middleware/middleware.go | 23 + testdata/chi/middleware/nocache.go | 59 ++ testdata/chi/middleware/page_route.go | 20 + testdata/chi/middleware/path_rewrite.go | 16 + testdata/chi/middleware/profiler.go | 49 + testdata/chi/middleware/realip.go | 53 ++ testdata/chi/middleware/recoverer.go | 203 ++++ testdata/chi/middleware/request_id.go | 96 ++ testdata/chi/middleware/request_size.go | 18 + testdata/chi/middleware/route_headers.go | 146 +++ testdata/chi/middleware/strip.go | 77 ++ testdata/chi/middleware/sunset.go | 25 + testdata/chi/middleware/supress_notfound.go | 27 + testdata/chi/middleware/terminal.go | 63 ++ testdata/chi/middleware/throttle.go | 151 +++ testdata/chi/middleware/timeout.go | 48 + testdata/chi/middleware/url_format.go | 77 ++ testdata/chi/middleware/value.go | 17 + testdata/chi/middleware/wrap_writer.go | 243 +++++ testdata/chi/mux.go | 526 +++++++++++ testdata/chi/tree.go | 877 ++++++++++++++++++ testdata/fixture/go.mod | 3 - testdata/generics/fn/fn.go | 46 + testdata/generics/go.mod | 3 + testdata/generics/main.go | 22 + testdata/generics/set/set.go | 36 + testdata/greeter/go.mod | 3 + testdata/{fixture => greeter}/main.go | 2 +- .../pkg/greeter/greeter.go | 0 testdata/multipackage/go.mod | 3 + testdata/{realistic => multipackage}/main.go | 4 +- .../server/middleware.go | 0 .../server/server.go | 0 testdata/multipackage/server/server_test.go | 16 + .../worker/worker.go | 0 testdata/realistic/go.mod | 3 - 66 files changed, 5968 
insertions(+), 184 deletions(-) create mode 100644 cmd/codeanalyzer/main_test.go create mode 100644 internal/analysis/registry_test.go create mode 100644 internal/core/chi_test.go create mode 100644 internal/core/errors_test.go create mode 100644 internal/core/generics_test.go create mode 100644 internal/core/incremental_test.go rename internal/core/{realistic_test.go => multipackage_test.go} (69%) create mode 100644 internal/core/skip_tests_test.go create mode 100644 internal/core/testsetup_test.go create mode 100644 internal/semantic_analysis/call_graph_test.go create mode 100644 internal/utils/fs_test.go create mode 100644 testdata/chi/chain.go create mode 100644 testdata/chi/chi.go create mode 100644 testdata/chi/context.go create mode 100644 testdata/chi/go.mod create mode 100644 testdata/chi/middleware/basic_auth.go create mode 100644 testdata/chi/middleware/clean_path.go create mode 100644 testdata/chi/middleware/client_ip.go create mode 100644 testdata/chi/middleware/compress.go create mode 100644 testdata/chi/middleware/content_charset.go create mode 100644 testdata/chi/middleware/content_encoding.go create mode 100644 testdata/chi/middleware/content_type.go create mode 100644 testdata/chi/middleware/get_head.go create mode 100644 testdata/chi/middleware/heartbeat.go create mode 100644 testdata/chi/middleware/logger.go create mode 100644 testdata/chi/middleware/maybe.go create mode 100644 testdata/chi/middleware/middleware.go create mode 100644 testdata/chi/middleware/nocache.go create mode 100644 testdata/chi/middleware/page_route.go create mode 100644 testdata/chi/middleware/path_rewrite.go create mode 100644 testdata/chi/middleware/profiler.go create mode 100644 testdata/chi/middleware/realip.go create mode 100644 testdata/chi/middleware/recoverer.go create mode 100644 testdata/chi/middleware/request_id.go create mode 100644 testdata/chi/middleware/request_size.go create mode 100644 testdata/chi/middleware/route_headers.go create mode 100644 
testdata/chi/middleware/strip.go create mode 100644 testdata/chi/middleware/sunset.go create mode 100644 testdata/chi/middleware/supress_notfound.go create mode 100644 testdata/chi/middleware/terminal.go create mode 100644 testdata/chi/middleware/throttle.go create mode 100644 testdata/chi/middleware/timeout.go create mode 100644 testdata/chi/middleware/url_format.go create mode 100644 testdata/chi/middleware/value.go create mode 100644 testdata/chi/middleware/wrap_writer.go create mode 100644 testdata/chi/mux.go create mode 100644 testdata/chi/tree.go delete mode 100644 testdata/fixture/go.mod create mode 100644 testdata/generics/fn/fn.go create mode 100644 testdata/generics/go.mod create mode 100644 testdata/generics/main.go create mode 100644 testdata/generics/set/set.go create mode 100644 testdata/greeter/go.mod rename testdata/{fixture => greeter}/main.go (82%) rename testdata/{fixture => greeter}/pkg/greeter/greeter.go (100%) create mode 100644 testdata/multipackage/go.mod rename testdata/{realistic => multipackage}/main.go (87%) rename testdata/{realistic => multipackage}/server/middleware.go (100%) rename testdata/{realistic => multipackage}/server/server.go (100%) create mode 100644 testdata/multipackage/server/server_test.go rename testdata/{realistic => multipackage}/worker/worker.go (100%) delete mode 100644 testdata/realistic/go.mod diff --git a/README.md b/README.md index 608e3a3..9310f1c 100644 --- a/README.md +++ b/README.md @@ -187,7 +187,11 @@ codeanalyzer-go/ │ ├── analysis/ # Pluggable pass interface + registry (topo-ordered pipeline) │ ├── frameworks/ # BaseEntrypointFinder — extension seam for framework passes │ └── utils/ # DiscoverGoFiles, IsVendored, IsTestFile, logging -└── testdata/fixture/ # Minimal Go fixture used by tests +├── testdata/ +│ ├── greeter/ # Minimal two-package fixture (basic struct/interface/call sites) +│ ├── multipackage/ # Richer fixture covering embedded fields, variadic params, goroutines, … +│ ├── generics/ # Go 1.18+ 
generics fixture (Set[T], union-constraint interfaces, Map[T,U]) +│ └── chi/ # chi v5 library source (checked in and analyzed as the project under test) for HTTP handler patterns ``` The `core` package is a pure orchestrator: it calls `syntactic_analysis` → `semantic_analysis` → `analysis.RunPipeline` → optional CodeQL in sequence, with no inlined parsing logic. Framework-specific analysis extends through the `analysis/` + `frameworks/` layer without touching `core`. @@ -200,7 +204,31 @@ The `core` package is a pure orchestrator: it calls `syntactic_analysis` → `se go test ./... ``` -Tests run against `testdata/fixture/` and `testdata/realistic/` — a minimal two-package and a richer multi-package Go module. All 33 tests cover symbol table correctness, call graph edges, JSON round-trip, output format validation, and caching/incremental behaviour. +Tests run against four fixtures: `testdata/greeter/` (basic), `testdata/multipackage/` (multi-file packages, goroutines, variadic params), `testdata/generics/` (Go 1.18+ generics — `Set[T]`, union constraints, multi-type-param functions), and `testdata/chi/` (chi v5 library source checked in as the analyzed project, HTTP handler patterns). All 57 tests cover symbol table correctness, generic receiver attribution, call graph edges, JSON round-trip, output format validation, caching behaviour, and error paths. + +`go test` caches passing results by source hash. To force a full re-run: + +```bash +go clean -testcache && go test ./... +``` + +The analyzer's own `CacheDir` (used inside tests for `analysis_cache.json` and `go_mod_hash`) is written to OS temp directories that are wiped automatically when the test binary exits — there is no persistent on-disk state between test runs. The chi fixture's source is checked in under `testdata/chi/` and has no external dependencies, so tests never require network access. + +### Clearing the production cache + +By default the CLI writes its cache to `~/.cldk/go-cache`. 
To bypass it for a single run: + +```bash +codeanalyzer-go -i ./my-project --eager +``` + +To delete it entirely: + +```bash +rm -rf ~/.cldk/go-cache +``` + +If you pass a custom `--cache-dir`, remove that directory instead. ### Running from source diff --git a/cmd/codeanalyzer/main.go b/cmd/codeanalyzer/main.go index 7425781..6968fd8 100644 --- a/cmd/codeanalyzer/main.go +++ b/cmd/codeanalyzer/main.go @@ -5,6 +5,7 @@ package main import ( + "encoding/json" "fmt" "os" @@ -49,12 +50,20 @@ via CLDK(language="go").analysis(project_path=...).`, SilenceUsage: true, RunE: func(cmd *cobra.Command, args []string) error { if showVersion { - fmt.Println("codeanalyzer-go " + version) + cmd.Println("codeanalyzer-go " + version) return nil } if inputPath == "" { return fmt.Errorf("--input / -i is required") } + switch format { + case "", "json": + // valid + case "msgpack": + return fmt.Errorf("msgpack output is not yet implemented; use --format json") + default: + return fmt.Errorf("unsupported output format %q; supported: json", format) + } utils.SetVerbosity(verbosity) if cacheDir == "" { @@ -81,6 +90,16 @@ via CLDK(language="go").analysis(project_path=...).`, return err } + // When no --output dir is given, write JSON to cobra's output + // writer so tests can capture it via cmd.SetOut. + if outputDir == "" { + data, err := json.Marshal(app) + if err != nil { + return err + } + _, err = cmd.OutOrStdout().Write(data) + return err + } return core.WriteOutput(app, outputDir, format) }, } diff --git a/cmd/codeanalyzer/main_test.go b/cmd/codeanalyzer/main_test.go new file mode 100644 index 0000000..67fdefd --- /dev/null +++ b/cmd/codeanalyzer/main_test.go @@ -0,0 +1,205 @@ +package main + +// CLI integration tests. These call rootCmd().Execute() directly (same +// package, so the unexported function is accessible) with controlled args and +// capture cobra's output buffer. No subprocess or binary required. 
+ +import ( + "bytes" + "encoding/json" + "os" + "path/filepath" + "runtime" + "strings" + "testing" +) + +// cliTestdataDir returns the absolute path to the repo-level testdata directory. +func cliTestdataDir() string { + _, thisFile, _, _ := runtime.Caller(0) + abs, _ := filepath.Abs(filepath.Join(filepath.Dir(thisFile), "..", "..", "testdata")) + return abs +} + +// runCmd executes rootCmd with the given args and returns (stdout, stderr, error). +func runCmd(args ...string) (stdout, stderr string, err error) { + cmd := rootCmd() + var outBuf, errBuf bytes.Buffer + cmd.SetOut(&outBuf) + cmd.SetErr(&errBuf) + cmd.SetArgs(args) + err = cmd.Execute() + return outBuf.String(), errBuf.String(), err +} + +// ── Flag validation ─────────────────────────────────────────────────────────── + +func TestRootCmd_MissingInputReturnsError(t *testing.T) { + _, _, err := runCmd() + if err == nil { + t.Fatal("expected error when --input is missing, got nil") + } + if !strings.Contains(err.Error(), "required") { + t.Errorf("error should mention 'required'; got %q", err.Error()) + } +} + +func TestRootCmd_NonExistentInputReturnsError(t *testing.T) { + _, _, err := runCmd("--input", filepath.Join(t.TempDir(), "does_not_exist")) + if err == nil { + t.Fatal("expected error for non-existent --input path, got nil") + } +} + +func TestRootCmd_UnknownFormatReturnsError(t *testing.T) { + td := cliTestdataDir() + _, _, err := runCmd("--input", filepath.Join(td, "greeter"), "--format", "csv") + if err == nil { + t.Fatal("expected error for unknown --format value, got nil") + } +} + +// ── --version ──────────────────────────────────────────────────────────────── + +func TestRootCmd_VersionFlag(t *testing.T) { + out, _, err := runCmd("--version") + if err != nil { + t.Fatalf("--version returned unexpected error: %v", err) + } + if !strings.Contains(out, "codeanalyzer-go") { + t.Errorf("--version output should contain 'codeanalyzer-go'; got %q", out) + } + if !strings.Contains(out, version) { + 
t.Errorf("--version output should contain version %q; got %q", version, out) + } +} + +// ── --output writes analysis.json ──────────────────────────────────────────── + +func TestRootCmd_OutputDirWritesFile(t *testing.T) { + td := cliTestdataDir() + outDir := t.TempDir() + + _, _, err := runCmd( + "--input", filepath.Join(td, "greeter"), + "--output", outDir, + "--cache-dir", t.TempDir(), + ) + if err != nil { + t.Fatalf("command failed: %v", err) + } + + analysisPath := filepath.Join(outDir, "analysis.json") + if _, statErr := os.Stat(analysisPath); statErr != nil { + t.Fatalf("analysis.json not created in output dir: %v", statErr) + } +} + +func TestRootCmd_NoOutputWritesToStdout(t *testing.T) { + td := cliTestdataDir() + + out, _, err := runCmd( + "--input", filepath.Join(td, "greeter"), + "--cache-dir", t.TempDir(), + ) + if err != nil { + t.Fatalf("command failed: %v", err) + } + if out == "" { + t.Fatal("expected JSON on stdout when --output is omitted, got empty string") + } + var v interface{} + if jsonErr := json.Unmarshal([]byte(out), &v); jsonErr != nil { + t.Errorf("stdout is not valid JSON: %v\noutput: %s", jsonErr, out) + } +} + +// ── --analysis-level ───────────────────────────────────────────────────────── + +func TestRootCmd_Level1ProducesNoCallGraph(t *testing.T) { + td := cliTestdataDir() + + out, _, err := runCmd( + "--input", filepath.Join(td, "greeter"), + "--analysis-level", "1", + "--cache-dir", t.TempDir(), + ) + if err != nil { + t.Fatalf("command failed: %v", err) + } + + var result struct { + CallGraph []interface{} `json:"call_graph"` + } + if jsonErr := json.Unmarshal([]byte(out), &result); jsonErr != nil { + t.Fatalf("stdout is not valid JSON: %v", jsonErr) + } + if len(result.CallGraph) != 0 { + t.Errorf("level 1 should produce no call graph edges; got %d", len(result.CallGraph)) + } +} + +func TestRootCmd_Level2ProducesCallGraph(t *testing.T) { + td := cliTestdataDir() + + out, _, err := runCmd( + "--input", filepath.Join(td, 
"greeter"), + "--analysis-level", "2", + "--cache-dir", t.TempDir(), + ) + if err != nil { + t.Fatalf("command failed: %v", err) + } + + var result struct { + CallGraph []interface{} `json:"call_graph"` + } + if jsonErr := json.Unmarshal([]byte(out), &result); jsonErr != nil { + t.Fatalf("stdout is not valid JSON: %v", jsonErr) + } + if len(result.CallGraph) == 0 { + t.Error("level 2 should produce call graph edges; got none") + } +} + +// ── --skip-tests ───────────────────────────────────────────────────────────── + +func TestRootCmd_SkipTestsFalseIncludesTestFiles(t *testing.T) { + td := cliTestdataDir() + + out, _, err := runCmd( + "--input", filepath.Join(td, "multipackage"), + "--skip-tests=false", + "--cache-dir", t.TempDir(), + ) + if err != nil { + t.Fatalf("command failed: %v", err) + } + + if !strings.Contains(out, "server_test.go") { + t.Error("--skip-tests=false: server_test.go should appear in JSON output") + } +} + +// ── --target-files ──────────────────────────────────────────────────────────── + +func TestRootCmd_TargetFilesRestrictsOutput(t *testing.T) { + td := cliTestdataDir() + serverFile := filepath.Join(td, "multipackage", "server", "server.go") + + out, _, err := runCmd( + "--input", filepath.Join(td, "multipackage"), + "--target-files", serverFile, + "--cache-dir", t.TempDir(), + ) + if err != nil { + t.Fatalf("command failed: %v", err) + } + + if strings.Contains(out, `"worker/worker.go"`) { + t.Error("--target-files: worker/worker.go should not appear when only server is targeted") + } + if !strings.Contains(out, `"server/server.go"`) { + t.Error("--target-files: server/server.go should appear in output") + } +} diff --git a/internal/analysis/registry_test.go b/internal/analysis/registry_test.go new file mode 100644 index 0000000..68bd55d --- /dev/null +++ b/internal/analysis/registry_test.go @@ -0,0 +1,137 @@ +package analysis + +// Tests for orderPasses (unexported — must be in the same package). 
+// +// We use lightweight stub passes so these tests have no external dependencies +// and run without loading any Go source files. + +import ( + "strings" + "testing" + + "github.com/codellm-devkit/codeanalyzer-go/internal/schema" +) + +// stubPass is a minimal AnalysisPass for testing orderPasses. +type stubPass struct { + name string + provides []string + requires []string +} + +func (s *stubPass) Name() string { return s.name } +func (s *stubPass) Provides() []string { return s.provides } +func (s *stubPass) Requires() []string { return s.requires } +func (s *stubPass) Run(_ *schema.GoApplication, _ AnalysisContext) (AnalysisResult, error) { + return AnalysisResult{}, nil +} + +func mkPass(name string, provides, requires []string) AnalysisPass { + return &stubPass{name: name, provides: provides, requires: requires} +} + +// ── orderPasses ─────────────────────────────────────────────────────────────── + +func TestOrderPasses_Empty(t *testing.T) { + ordered, err := orderPasses(nil) + if err != nil { + t.Fatalf("empty passes: unexpected error: %v", err) + } + if len(ordered) != 0 { + t.Errorf("got %d passes, want 0", len(ordered)) + } +} + +func TestOrderPasses_SingleNoDeps(t *testing.T) { + p := mkPass("solo", []string{"x"}, nil) + ordered, err := orderPasses([]AnalysisPass{p}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(ordered) != 1 || ordered[0].Name() != "solo" { + t.Errorf("expected [solo]; got %v", names(ordered)) + } +} + +// A → B: A provides "feat", B requires "feat". A must come before B. +func TestOrderPasses_LinearDependency(t *testing.T) { + a := mkPass("a", []string{"feat"}, nil) + b := mkPass("b", nil, []string{"feat"}) + // Deliver in reverse order to stress the sort. 
+ ordered, err := orderPasses([]AnalysisPass{b, a}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(ordered) != 2 { + t.Fatalf("got %d passes, want 2", len(ordered)) + } + if ordered[0].Name() != "a" || ordered[1].Name() != "b" { + t.Errorf("wrong order: got %v, want [a b]", names(ordered)) + } +} + +// A → C ← B: two independent passes both provide something C needs. +func TestOrderPasses_DiamondDependency(t *testing.T) { + a := mkPass("a", []string{"x"}, nil) + b := mkPass("b", []string{"y"}, nil) + c := mkPass("c", nil, []string{"x", "y"}) + ordered, err := orderPasses([]AnalysisPass{c, b, a}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(ordered) != 3 { + t.Fatalf("got %d passes, want 3", len(ordered)) + } + // c must be last. + if ordered[len(ordered)-1].Name() != "c" { + t.Errorf("c should be last; got %v", names(ordered)) + } +} + +func TestOrderPasses_UnsatisfiedRequirement(t *testing.T) { + p := mkPass("needy", nil, []string{"missing-cap"}) + _, err := orderPasses([]AnalysisPass{p}) + if err == nil { + t.Fatal("expected error for unsatisfied requirement, got nil") + } + if !strings.Contains(err.Error(), "needy") { + t.Errorf("error should mention the blocked pass name; got %q", err.Error()) + } +} + +func TestOrderPasses_Cycle(t *testing.T) { + // A requires "b", B requires "a" — neither can run. + a := mkPass("a", []string{"a-cap"}, []string{"b-cap"}) + b := mkPass("b", []string{"b-cap"}, []string{"a-cap"}) + _, err := orderPasses([]AnalysisPass{a, b}) + if err == nil { + t.Fatal("expected error for cycle, got nil") + } +} + +// ── RunPipeline with empty registry ────────────────────────────────────────── + +func TestRunPipeline_EmptyRegistry(t *testing.T) { + // Save and restore to avoid affecting other tests. 
+ old := registeredPasses + registeredPasses = nil + defer func() { registeredPasses = old }() + + app := &schema.GoApplication{ + Entrypoints: map[string][]schema.GoEntrypoint{}, + CallGraph: []schema.GoCallEdge{}, + } + if err := RunPipeline(app, AnalysisContext{}); err != nil { + t.Fatalf("RunPipeline with empty registry: %v", err) + } +} + +// ── helpers ─────────────────────────────────────────────────────────────────── + +func names(passes []AnalysisPass) []string { + out := make([]string, len(passes)) + for i, p := range passes { + out[i] = p.Name() + } + return out +} diff --git a/internal/core/analyzer_test.go b/internal/core/analyzer_test.go index 8d322ba..1e95011 100644 --- a/internal/core/analyzer_test.go +++ b/internal/core/analyzer_test.go @@ -4,7 +4,6 @@ import ( "encoding/json" "os" "path/filepath" - "runtime" "testing" "time" @@ -13,52 +12,23 @@ import ( "github.com/codellm-devkit/codeanalyzer-go/internal/schema" ) -// fixtureDir returns the absolute path to testdata/fixture. -func fixtureDir(t *testing.T) string { +// greeterDir returns the absolute path to testdata/greeter. +// Still used by the caching tests, which must run fresh analysis. +func greeterDir(t *testing.T) string { t.Helper() - _, file, _, ok := runtime.Caller(0) - if !ok { - t.Fatal("cannot determine source file path") - } - // internal/core/analyzer_test.go → ../.. 
→ codeanalyzer-go root → testdata/fixture - root := filepath.Join(filepath.Dir(file), "..", "..") - abs, err := filepath.Abs(filepath.Join(root, "testdata", "fixture")) - if err != nil { - t.Fatalf("resolving fixture dir: %v", err) - } - return abs -} - -func runAnalysis(t *testing.T, level options.AnalysisLevel) *schema.GoApplication { - t.Helper() - dir := fixtureDir(t) - outDir := t.TempDir() - opts := options.AnalysisOptions{ - InputPath: dir, - OutputDir: outDir, - Level: level, - SkipTests: true, - CacheDir: t.TempDir(), - } - app, err := core.New(opts).Analyze() - if err != nil { - t.Fatalf("Analyze() failed: %v", err) - } - return app + return filepath.Join(testdataDir(), "greeter") } // ── Symbol table tests ──────────────────────────────────────────────────────── func TestSymbolTable_NonEmpty(t *testing.T) { - app := runAnalysis(t, options.LevelSymbolTable) - if len(app.SymbolTable) == 0 { + if len(sharedGreeterL1.SymbolTable) == 0 { t.Fatal("symbol table is empty") } } func TestSymbolTable_PathKeysAreRelative(t *testing.T) { - app := runAnalysis(t, options.LevelSymbolTable) - for key := range app.SymbolTable { + for key := range sharedGreeterL1.SymbolTable { if filepath.IsAbs(key) { t.Errorf("symbol_table key is absolute path: %s", key) } @@ -66,11 +36,10 @@ func TestSymbolTable_PathKeysAreRelative(t *testing.T) { } func TestSymbolTable_KnownType(t *testing.T) { - app := runAnalysis(t, options.LevelSymbolTable) const wantFile = "pkg/greeter/greeter.go" - f, ok := app.SymbolTable[wantFile] + f, ok := sharedGreeterL1.SymbolTable[wantFile] if !ok { - t.Fatalf("file %q not in symbol table; got keys: %v", wantFile, keys(app.SymbolTable)) + t.Fatalf("file %q not in symbol table; got keys: %v", wantFile, keys(sharedGreeterL1.SymbolTable)) } if _, ok := f.Types["Greeter"]; !ok { t.Errorf("GoType 'Greeter' not found in %s", wantFile) @@ -78,8 +47,7 @@ func TestSymbolTable_KnownType(t *testing.T) { } func TestSymbolTable_KnownInterface(t *testing.T) { - app := 
runAnalysis(t, options.LevelSymbolTable) - f := app.SymbolTable["pkg/greeter/greeter.go"] + f := sharedGreeterL1.SymbolTable["pkg/greeter/greeter.go"] gt, ok := f.Types["Logger"] if !ok { t.Fatal("GoType 'Logger' not found") @@ -90,8 +58,7 @@ func TestSymbolTable_KnownInterface(t *testing.T) { } func TestSymbolTable_StructFields(t *testing.T) { - app := runAnalysis(t, options.LevelSymbolTable) - f := app.SymbolTable["pkg/greeter/greeter.go"] + f := sharedGreeterL1.SymbolTable["pkg/greeter/greeter.go"] gt := f.Types["Greeter"] if len(gt.Fields) == 0 { t.Fatal("Greeter has no fields") @@ -105,8 +72,7 @@ func TestSymbolTable_StructFields(t *testing.T) { } func TestSymbolTable_CallSitesRecorded(t *testing.T) { - app := runAnalysis(t, options.LevelSymbolTable) - f := app.SymbolTable["main.go"] + f := sharedGreeterL1.SymbolTable["main.go"] var mainFn *schema.GoCallable for _, c := range f.Functions { c := c @@ -121,7 +87,6 @@ func TestSymbolTable_CallSitesRecorded(t *testing.T) { if len(mainFn.CallSites) == 0 { t.Error("main() has no recorded call sites") } - // All call sites must start with callee_signature == nil (pre-resolution). 
for _, cs := range mainFn.CallSites { if cs.CalleeSignature != nil { t.Errorf("call site %q has callee_signature pre-filled during symbol-table build", cs.MethodName) @@ -132,16 +97,14 @@ func TestSymbolTable_CallSitesRecorded(t *testing.T) { // ── Call graph tests ────────────────────────────────────────────────────────── func TestCallGraph_NonEmpty(t *testing.T) { - app := runAnalysis(t, options.LevelCallGraph) - if len(app.CallGraph) == 0 { + if len(sharedGreeterL2.CallGraph) == 0 { t.Fatal("call graph is empty") } } func TestCallGraph_NoDanglingEdges(t *testing.T) { - app := runAnalysis(t, options.LevelCallGraph) - sigs := allSignatures(app) - for _, e := range app.CallGraph { + sigs := allSignatures(sharedGreeterL2) + for _, e := range sharedGreeterL2.CallGraph { if !sigs[e.Source] { t.Errorf("dangling edge source: %s", e.Source) } @@ -152,8 +115,7 @@ func TestCallGraph_NoDanglingEdges(t *testing.T) { } func TestCallGraph_Provenance(t *testing.T) { - app := runAnalysis(t, options.LevelCallGraph) - for _, e := range app.CallGraph { + for _, e := range sharedGreeterL2.CallGraph { if len(e.Provenance) == 0 { t.Errorf("edge %s→%s has empty provenance", e.Source, e.Target) } @@ -161,11 +123,9 @@ func TestCallGraph_Provenance(t *testing.T) { } func TestCallGraph_CallSitesBackfilled(t *testing.T) { - app := runAnalysis(t, options.LevelCallGraph) - f := app.SymbolTable["main.go"] + f := sharedGreeterL2.SymbolTable["main.go"] for _, callable := range f.Functions { for _, cs := range callable.CallSites { - // Sites that resolved to a project-internal callee must be backfilled. 
if cs.CalleeSignature != nil && *cs.CalleeSignature == "" { t.Errorf("callable %s: call site %q has empty string callee_signature", callable.Signature, cs.MethodName) } @@ -176,9 +136,8 @@ func TestCallGraph_CallSitesBackfilled(t *testing.T) { // ── JSON output tests ───────────────────────────────────────────────────────── func TestWriteOutput_ValidJSON(t *testing.T) { - app := runAnalysis(t, options.LevelCallGraph) outDir := t.TempDir() - if err := core.WriteOutput(app, outDir, "json"); err != nil { + if err := core.WriteOutput(sharedGreeterL2, outDir, "json"); err != nil { t.Fatalf("WriteOutput: %v", err) } data, err := os.ReadFile(filepath.Join(outDir, "analysis.json")) @@ -195,9 +154,8 @@ func TestWriteOutput_ValidJSON(t *testing.T) { } func TestWriteOutput_EmptyFormatDefaultsToJSON(t *testing.T) { - app := runAnalysis(t, options.LevelSymbolTable) outDir := t.TempDir() - if err := core.WriteOutput(app, outDir, ""); err != nil { + if err := core.WriteOutput(sharedGreeterL1, outDir, ""); err != nil { t.Fatalf("WriteOutput with empty format: %v", err) } if _, err := os.Stat(filepath.Join(outDir, "analysis.json")); err != nil { @@ -206,42 +164,36 @@ func TestWriteOutput_EmptyFormatDefaultsToJSON(t *testing.T) { } func TestWriteOutput_MsgpackNotImplemented(t *testing.T) { - app := runAnalysis(t, options.LevelSymbolTable) outDir := t.TempDir() - err := core.WriteOutput(app, outDir, "msgpack") - if err == nil { + if err := core.WriteOutput(sharedGreeterL1, outDir, "msgpack"); err == nil { t.Fatal("expected error for --format msgpack, got nil") } } func TestWriteOutput_UnknownFormatErrors(t *testing.T) { - app := runAnalysis(t, options.LevelSymbolTable) outDir := t.TempDir() - err := core.WriteOutput(app, outDir, "csv") - if err == nil { + if err := core.WriteOutput(sharedGreeterL1, outDir, "csv"); err == nil { t.Fatal("expected error for unknown format, got nil") } } // ── Caching tests ───────────────────────────────────────────────────────────── +// These tests 
must run their own analysis to exercise the caching machinery. func TestCaching_SecondRunReuses(t *testing.T) { - dir := fixtureDir(t) + dir := greeterDir(t) cacheDir := t.TempDir() - outDir := t.TempDir() opts := options.AnalysisOptions{ InputPath: dir, - OutputDir: outDir, + OutputDir: t.TempDir(), Level: options.LevelCallGraph, SkipTests: true, CacheDir: cacheDir, } - // First run — populates cache. app1, err := core.New(opts).Analyze() if err != nil { t.Fatalf("first run: %v", err) } - // Second run — must not error and must return identical key count. app2, err := core.New(opts).Analyze() if err != nil { t.Fatalf("second run: %v", err) @@ -256,10 +208,9 @@ func TestCaching_SecondRunReuses(t *testing.T) { } func TestCaching_CacheFileWritten(t *testing.T) { - dir := fixtureDir(t) cacheDir := t.TempDir() opts := options.AnalysisOptions{ - InputPath: dir, + InputPath: greeterDir(t), Level: options.LevelSymbolTable, SkipTests: true, CacheDir: cacheDir, @@ -267,17 +218,15 @@ func TestCaching_CacheFileWritten(t *testing.T) { if _, err := core.New(opts).Analyze(); err != nil { t.Fatalf("Analyze: %v", err) } - cachePath := filepath.Join(cacheDir, "analysis_cache.json") - if _, err := os.Stat(cachePath); err != nil { + if _, err := os.Stat(filepath.Join(cacheDir, "analysis_cache.json")); err != nil { t.Fatalf("analysis_cache.json not written to CacheDir: %v", err) } } func TestCaching_CacheContentsRoundTrip(t *testing.T) { - dir := fixtureDir(t) cacheDir := t.TempDir() opts := options.AnalysisOptions{ - InputPath: dir, + InputPath: greeterDir(t), Level: options.LevelSymbolTable, SkipTests: true, CacheDir: cacheDir, @@ -301,15 +250,13 @@ func TestCaching_CacheContentsRoundTrip(t *testing.T) { } func TestCaching_EagerForcesRebuild(t *testing.T) { - dir := fixtureDir(t) cacheDir := t.TempDir() opts := options.AnalysisOptions{ - InputPath: dir, + InputPath: greeterDir(t), Level: options.LevelSymbolTable, SkipTests: true, CacheDir: cacheDir, } - // First run (non-eager) — 
seeds go_mod_hash. if _, err := core.New(opts).Analyze(); err != nil { t.Fatalf("first run: %v", err) } @@ -319,9 +266,12 @@ func TestCaching_EagerForcesRebuild(t *testing.T) { t.Fatalf("cache not written after first run: %v", err) } - time.Sleep(10 * time.Millisecond) + // Backdate the cache file so the mtime delta is unambiguous — no sleep needed. + past := info1.ModTime().Add(-time.Second) + if err := os.Chtimes(cachePath, past, past); err != nil { + t.Fatalf("backdating cache mtime: %v", err) + } - // Second run with Eager=true — must rewrite cache even when go_mod_hash matches. opts.Eager = true if _, err := core.New(opts).Analyze(); err != nil { t.Fatalf("eager run: %v", err) @@ -330,10 +280,9 @@ func TestCaching_EagerForcesRebuild(t *testing.T) { if err != nil { t.Fatalf("cache not found after eager run: %v", err) } - // saveCache always writes, so mtime must advance. - if !info2.ModTime().After(info1.ModTime()) { + if !info2.ModTime().After(past) { t.Errorf("analysis_cache.json mtime did not advance on eager=true run: %v vs %v", - info1.ModTime(), info2.ModTime()) + past, info2.ModTime()) } } diff --git a/internal/core/chi_test.go b/internal/core/chi_test.go new file mode 100644 index 0000000..48742fe --- /dev/null +++ b/internal/core/chi_test.go @@ -0,0 +1,196 @@ +package core_test + +// Tests for the chi fixture — chi v5 (github.com/go-chi/chi/v5) analyzed as the +// project under test, not as a dependency. +// +// Goals: +// 1. Multi-package library (root + middleware) is fully indexed. +// 2. Interface types (chi.Router) and struct types (chi.Mux) are both captured. +// 3. Methods on *Mux (Get, Post, Route, …) appear in the symbol table. +// 4. Vendor files are absent (chi has no external deps; nothing to exclude). +// 5. Call graph edges (Level 2) are internally consistent — no dangling endpoints. 
+ +import ( + "strings" + "testing" +) + +// ── File coverage ───────────────────────────────────────────────────────────── + +func TestChi_SymbolTableNonEmpty(t *testing.T) { + if len(sharedChiL2.SymbolTable) == 0 { + t.Fatal("chi symbol table is empty — analysis may have failed silently") + } +} + +// chi v5 has exactly 35 non-test Go source files (5 root + 30 middleware). +func TestChi_SymbolTableFileCount(t *testing.T) { + const want = 35 + if got := len(sharedChiL2.SymbolTable); got != want { + t.Errorf("symbol table: got %d file(s), want %d; keys: %v", got, want, keys(sharedChiL2.SymbolTable)) + } +} + +func TestChi_PathKeysAreRelative(t *testing.T) { + for key := range sharedChiL2.SymbolTable { + if strings.HasPrefix(key, "/") { + t.Errorf("symbol_table key is absolute: %s", key) + } + } +} + +// ── Root package files ──────────────────────────────────────────────────────── + +func TestChi_RootFilesPresent(t *testing.T) { + for _, name := range []string{"chi.go", "mux.go", "context.go", "chain.go", "tree.go"} { + t.Run(name, func(t *testing.T) { + if _, ok := sharedChiL2.SymbolTable[name]; !ok { + t.Errorf("%s not in symbol table; keys: %v", name, keys(sharedChiL2.SymbolTable)) + } + }) + } +} + +// ── Middleware package files ────────────────────────────────────────────────── + +func TestChi_MiddlewareFilesPresent(t *testing.T) { + for _, name := range []string{ + "middleware/logger.go", + "middleware/recoverer.go", + "middleware/middleware.go", + } { + t.Run(name, func(t *testing.T) { + if _, ok := sharedChiL2.SymbolTable[name]; !ok { + t.Errorf("%s not in symbol table", name) + } + }) + } +} + +// ── Interface and struct types ──────────────────────────────────────────────── + +// chi.go declares the Router interface. 
+func TestChi_RouterIsInterface(t *testing.T) { + f, ok := sharedChiL2.SymbolTable["chi.go"] + if !ok { + t.Fatal("chi.go not in symbol table") + } + router, ok := f.Types["Router"] + if !ok { + t.Fatal("Router type not found in chi.go") + } + if !router.IsInterface { + t.Error("Router should be an interface, got is_interface=false") + } +} + +// mux.go declares the Mux struct (not an interface). +func TestChi_MuxIsStruct(t *testing.T) { + f, ok := sharedChiL2.SymbolTable["mux.go"] + if !ok { + t.Fatal("mux.go not in symbol table") + } + mux, ok := f.Types["Mux"] + if !ok { + t.Fatal("Mux type not found in mux.go") + } + if mux.IsInterface { + t.Error("Mux should be a struct, got is_interface=true") + } +} + +// ── Methods on *Mux ─────────────────────────────────────────────────────────── + +func TestChi_MuxHasRoutingMethods(t *testing.T) { + f, ok := sharedChiL2.SymbolTable["mux.go"] + if !ok { + t.Fatal("mux.go not in symbol table") + } + for _, method := range []string{"Get", "Post", "Put", "Delete", "Route", "Use", "With"} { + t.Run(method, func(t *testing.T) { + if findCallableByName(f, method) == nil { + t.Errorf("method %q not found on Mux in mux.go", method) + } + }) + } +} + +// ── Call graph: no dangling edges ───────────────────────────────────────────── + +func TestChi_NoDanglingEdges(t *testing.T) { + sigs := allSignatures(sharedChiL2) + for _, e := range sharedChiL2.CallGraph { + if !sigs[e.Source] { + t.Errorf("dangling edge source: %s", e.Source) + } + if !sigs[e.Target] { + t.Errorf("dangling edge target: %s", e.Target) + } + } +} + +// ── H1: InnerCallables populated for functions with closures ────────────────── + +// middleware/logger.go RequestLogger returns a closure-based middleware; its +// outer function body should have at least one inner callable after the fix. 
+func TestChi_RequestLoggerHasInnerCallables(t *testing.T) { + f, ok := sharedChiL2.SymbolTable["middleware/logger.go"] + if !ok { + t.Fatal("middleware/logger.go not in symbol table") + } + rl := findCallableByName(f, "RequestLogger") + if rl == nil { + t.Fatal("RequestLogger not found in middleware/logger.go") + } + if len(rl.InnerCallables) == 0 { + t.Error("RequestLogger should have at least one inner callable (closure), got none") + } +} + +// ── H2: IsConstructorCall for type-conversion call sites ────────────────────── + +// tree.go RegisterMethod contains `mt := methodTyp(2 << n)`. +// methodTyp is a named type, so the call is a type-conversion (constructor) call. +func TestChi_RegisterMethodHasConstructorCallSite(t *testing.T) { + f, ok := sharedChiL2.SymbolTable["tree.go"] + if !ok { + t.Fatal("tree.go not in symbol table") + } + rm := findCallableByName(f, "RegisterMethod") + if rm == nil { + t.Fatal("RegisterMethod not found in tree.go") + } + for _, site := range rm.CallSites { + if site.IsConstructorCall { + return // found + } + } + t.Error("RegisterMethod should have at least one IsConstructorCall=true site (methodTyp(...))") +} + +// ── H7: init() functions captured ──────────────────────────────────────────── + +// middleware/terminal.go, middleware/logger.go, and middleware/request_id.go +// each declare an init() function that should appear in the symbol table. 
+func TestChi_InitFunctionsPresent(t *testing.T) { + for _, file := range []string{ + "middleware/terminal.go", + "middleware/logger.go", + "middleware/request_id.go", + } { + t.Run(file, func(t *testing.T) { + f, ok := sharedChiL2.SymbolTable[file] + if !ok { + t.Fatalf("%s not in symbol table", file) + } + initFn := findCallableByName(f, "init") + if initFn == nil { + t.Errorf("init() not found in %s", file) + return + } + if initFn.IsExported { + t.Errorf("init() in %s should not be exported", file) + } + }) + } +} diff --git a/internal/core/errors_test.go b/internal/core/errors_test.go new file mode 100644 index 0000000..f66e24e --- /dev/null +++ b/internal/core/errors_test.go @@ -0,0 +1,76 @@ +package core_test + +// Error-path tests: verify that the analyzer returns meaningful errors (not +// panics or silent empty results) when given bad inputs. + +import ( + "os" + "path/filepath" + "testing" + + "github.com/codellm-devkit/codeanalyzer-go/internal/core" + "github.com/codellm-devkit/codeanalyzer-go/internal/options" +) + +func TestAnalyze_NonExistentPath(t *testing.T) { + opts := options.AnalysisOptions{ + InputPath: filepath.Join(t.TempDir(), "does_not_exist"), + OutputDir: t.TempDir(), + Level: options.LevelSymbolTable, + SkipTests: true, + } + _, err := core.New(opts).Analyze() + if err == nil { + t.Fatal("expected error for non-existent InputPath, got nil") + } +} + +func TestAnalyze_EmptyDirectory(t *testing.T) { + // A real directory with no Go files: analyzer should succeed but produce an + // empty symbol table (graceful degradation, not a hard error). 
+ emptyDir := t.TempDir() + opts := options.AnalysisOptions{ + InputPath: emptyDir, + OutputDir: t.TempDir(), + Level: options.LevelSymbolTable, + SkipTests: true, + } + app, err := core.New(opts).Analyze() + if err != nil { + t.Logf("Analyze returned error (acceptable): %v", err) + return + } + if len(app.SymbolTable) != 0 { + t.Errorf("expected empty symbol table for empty directory; got %d entries", len(app.SymbolTable)) + } +} + +func TestAnalyze_MissingGoMod(t *testing.T) { + // A directory with a .go file but no go.mod — not a valid module. + // The analyzer either returns an error or an empty symbol table; both are acceptable. + dir := t.TempDir() + if err := os.WriteFile(filepath.Join(dir, "hello.go"), []byte("package main\nfunc main(){}\n"), 0o644); err != nil { + t.Fatalf("writing hello.go: %v", err) + } + opts := options.AnalysisOptions{ + InputPath: dir, + OutputDir: t.TempDir(), + Level: options.LevelSymbolTable, + SkipTests: true, + } + app, err := core.New(opts).Analyze() + if err != nil { + t.Logf("Analyze returned error (acceptable): %v", err) + return + } + if len(app.SymbolTable) != 0 { + t.Errorf("expected empty symbol table for module with no go.mod; got %d entries", len(app.SymbolTable)) + } +} + +func TestAnalyze_LevelOneDoesNotProduceCallGraph(t *testing.T) { + // Level-1 analysis must never populate the call graph. + if len(sharedGreeterL1.CallGraph) != 0 { + t.Errorf("LevelSymbolTable produced %d call-graph edges; expected 0", len(sharedGreeterL1.CallGraph)) + } +} diff --git a/internal/core/generics_test.go b/internal/core/generics_test.go new file mode 100644 index 0000000..53e6788 --- /dev/null +++ b/internal/core/generics_test.go @@ -0,0 +1,127 @@ +package core_test + +// Tests for Go 1.18+ generic constructs — type parameters, union-constraint +// interfaces, multi-type-parameter functions, and methods on generic types. 
+// These exercise AST paths (IndexExpr receivers, TypeParams) that the greeter +// and realistic fixtures never reach. + +import ( + "strings" + "testing" +) + +// ── Symbol table completeness ───────────────────────────────────────────────── + +func TestGenerics_SymbolTableNonEmpty(t *testing.T) { + if len(sharedGenericsL1.SymbolTable) == 0 { + t.Fatal("generics symbol table is empty") + } +} + +func TestGenerics_PathKeysAreRelative(t *testing.T) { + for key := range sharedGenericsL1.SymbolTable { + if strings.HasPrefix(key, "/") { + t.Errorf("symbol_table key is absolute: %s", key) + } + } +} + +// ── Type name integrity ─────────────────────────────────────────────────────── + +func TestGenerics_SetTypePresentInSymbolTable(t *testing.T) { + const wantFile = "set/set.go" + f, ok := sharedGenericsL1.SymbolTable[wantFile] + if !ok { + t.Fatalf("%s not in symbol table; keys: %v", wantFile, keys(sharedGenericsL1.SymbolTable)) + } + if _, ok := f.Types["Set"]; !ok { + t.Errorf("GoType 'Set' not found in %s", wantFile) + } +} + +// The type name must be the base identifier only, not the parameterised form. 
+func TestGenerics_TypeNameHasNoTypeParams(t *testing.T) { + for _, f := range sharedGenericsL1.SymbolTable { + for name := range f.Types { + if strings.ContainsAny(name, "[]") { + t.Errorf("type name %q contains type-parameter brackets — should be stripped", name) + } + } + } +} + +// ── Methods on generic types ────────────────────────────────────────────────── + +func TestGenerics_SetMethods(t *testing.T) { + f := sharedGenericsL1.SymbolTable["set/set.go"] + for _, want := range []string{"Add", "Remove", "Contains", "Len"} { + t.Run(want, func(t *testing.T) { + if findCallableByName(f, want) == nil { + t.Errorf("method %q not found on Set", want) + } + }) + } +} + +func TestGenerics_UnexportedMethodOnGenericType(t *testing.T) { + f := sharedGenericsL1.SymbolTable["set/set.go"] + snapshot := findCallableByName(f, "snapshot") + if snapshot == nil { + t.Fatal("unexported method 'snapshot' not found on Set") + } + if snapshot.IsExported { + t.Error("snapshot.is_exported should be false") + } +} + +// ── Union-constraint interfaces ─────────────────────────────────────────────── + +func TestGenerics_OrderedIsInterface(t *testing.T) { + f, ok := sharedGenericsL1.SymbolTable["fn/fn.go"] + if !ok { + t.Fatal("fn/fn.go not in symbol table") + } + ordered, ok := f.Types["Ordered"] + if !ok { + t.Fatal("GoType 'Ordered' not found in fn/fn.go") + } + if !ordered.IsInterface { + t.Error("Ordered.is_interface should be true (union constraint)") + } +} + +func TestGenerics_NumericIsInterface(t *testing.T) { + f := sharedGenericsL1.SymbolTable["fn/fn.go"] + numeric, ok := f.Types["Numeric"] + if !ok { + t.Fatal("GoType 'Numeric' not found in fn/fn.go") + } + if !numeric.IsInterface { + t.Error("Numeric.is_interface should be true") + } +} + +// ── Generic functions ───────────────────────────────────────────────────────── + +func TestGenerics_SingleTypeParamFunctions(t *testing.T) { + f := sharedGenericsL1.SymbolTable["fn/fn.go"] + for _, name := range []string{"Min", "Max", 
"Filter"} { + t.Run(name, func(t *testing.T) { + if findCallableByName(f, name) == nil { + t.Errorf("generic function %q not found in fn/fn.go", name) + } + }) + } +} + +func TestGenerics_MapHasMultipleParams(t *testing.T) { + f := sharedGenericsL1.SymbolTable["fn/fn.go"] + mapFn := findCallableByName(f, "Map") + if mapFn == nil { + t.Fatal("generic function 'Map' not found in fn/fn.go") + } + // Map[T, U any](in []T, f func(T) U) []U — two declared parameters. + if len(mapFn.Parameters) < 2 { + t.Errorf("Map() should have >= 2 parameters; got %d", len(mapFn.Parameters)) + } +} diff --git a/internal/core/incremental_test.go b/internal/core/incremental_test.go new file mode 100644 index 0000000..7fc85e8 --- /dev/null +++ b/internal/core/incremental_test.go @@ -0,0 +1,120 @@ +package core_test + +// Tests for --target-files / TargetFiles incremental analysis mode. +// +// When TargetFiles is non-empty each value is passed as a "file=" +// pattern to packages.Load, which loads only the package(s) containing those +// files. Other packages in the project are not loaded and must not appear in +// the symbol table. + +import ( + "path/filepath" + "testing" + + "github.com/codellm-devkit/codeanalyzer-go/internal/core" + "github.com/codellm-devkit/codeanalyzer-go/internal/options" +) + +func multipackageDir() string { + return filepath.Join(testdataDir(), "multipackage") +} + +// TestTargetFiles_SinglePackage: targeting one file restricts the symbol table +// to the package containing that file. The multipackage fixture has three +// packages (main, server, worker); targeting server/server.go should exclude +// main.go and worker/worker.go. 
+func TestTargetFiles_SinglePackage(t *testing.T) { + td := multipackageDir() + serverFile := filepath.Join(td, "server", "server.go") + + app, err := core.New(options.AnalysisOptions{ + InputPath: td, + OutputDir: t.TempDir(), + Level: options.LevelSymbolTable, + SkipTests: true, + CacheDir: t.TempDir(), + TargetFiles: []string{serverFile}, + }).Analyze() + if err != nil { + t.Fatalf("Analyze: %v", err) + } + + if len(app.SymbolTable) == 0 { + t.Fatal("symbol table is empty — analysis may have failed silently") + } + if _, ok := app.SymbolTable["server/server.go"]; !ok { + t.Errorf("server/server.go should be in symbol table; got keys: %v", keys(app.SymbolTable)) + } + if _, ok := app.SymbolTable["main.go"]; ok { + t.Error("main.go must not be in symbol table when only server package is targeted") + } + if _, ok := app.SymbolTable["worker/worker.go"]; ok { + t.Error("worker/worker.go must not be in symbol table when only server package is targeted") + } +} + +// TestTargetFiles_MultiplePackages: targeting files in two separate packages +// includes both packages but still excludes the third. 
+func TestTargetFiles_MultiplePackages(t *testing.T) { + td := multipackageDir() + + app, err := core.New(options.AnalysisOptions{ + InputPath: td, + OutputDir: t.TempDir(), + Level: options.LevelSymbolTable, + SkipTests: true, + CacheDir: t.TempDir(), + TargetFiles: []string{ + filepath.Join(td, "server", "server.go"), + filepath.Join(td, "worker", "worker.go"), + }, + }).Analyze() + if err != nil { + t.Fatalf("Analyze: %v", err) + } + + for _, want := range []string{"server/server.go", "worker/worker.go"} { + if _, ok := app.SymbolTable[want]; !ok { + t.Errorf("%s should be in symbol table; got keys: %v", want, keys(app.SymbolTable)) + } + } + if _, ok := app.SymbolTable["main.go"]; ok { + t.Error("main.go must not be in symbol table when not targeted") + } +} + +// TestTargetFiles_NilMeansAllFiles: nil TargetFiles produces a full analysis, +// matching the file count of the pre-computed sharedMultipackageL1. +func TestTargetFiles_NilMeansAllFiles(t *testing.T) { + const want = 4 // main.go + server/server.go + server/middleware.go + worker/worker.go + if got := len(sharedMultipackageL1.SymbolTable); got != want { + t.Errorf("multipackage fixture with nil TargetFiles: got %d files, want %d; keys: %v", + got, want, keys(sharedMultipackageL1.SymbolTable)) + } +} + +// TestTargetFiles_SiblingFilesIncluded: when a package has multiple source files +// (server.go + middleware.go), targeting any one file loads the entire package, +// so sibling files are also present in the symbol table. 
+func TestTargetFiles_SiblingFilesIncluded(t *testing.T) { + td := multipackageDir() + serverFile := filepath.Join(td, "server", "server.go") + + app, err := core.New(options.AnalysisOptions{ + InputPath: td, + OutputDir: t.TempDir(), + Level: options.LevelSymbolTable, + SkipTests: true, + CacheDir: t.TempDir(), + TargetFiles: []string{serverFile}, + }).Analyze() + if err != nil { + t.Fatalf("Analyze: %v", err) + } + + // middleware.go is in the same package as server.go; it must appear too. + if _, ok := app.SymbolTable["server/middleware.go"]; !ok { + t.Errorf("server/middleware.go (sibling file) should be in symbol table; got keys: %v", + keys(app.SymbolTable)) + } +} diff --git a/internal/core/realistic_test.go b/internal/core/multipackage_test.go similarity index 69% rename from internal/core/realistic_test.go rename to internal/core/multipackage_test.go index 3f651f4..b1070bc 100644 --- a/internal/core/realistic_test.go +++ b/internal/core/multipackage_test.go @@ -5,48 +5,12 @@ package core_test // is_variadic, is_embedded, multi-file package, cyclomatic_complexity, specific edges. 
import ( - "path/filepath" - "runtime" "strings" "testing" - "github.com/codellm-devkit/codeanalyzer-go/internal/core" - "github.com/codellm-devkit/codeanalyzer-go/internal/options" "github.com/codellm-devkit/codeanalyzer-go/internal/schema" ) -func realisticDir(t *testing.T) string { - t.Helper() - _, file, _, ok := runtime.Caller(0) - if !ok { - t.Fatal("cannot determine source file path") - } - root := filepath.Join(filepath.Dir(file), "..", "..") - abs, err := filepath.Abs(filepath.Join(root, "testdata", "realistic")) - if err != nil { - t.Fatalf("resolving realistic fixture dir: %v", err) - } - return abs -} - -func runRealistic(t *testing.T, level options.AnalysisLevel) *schema.GoApplication { - t.Helper() - dir := realisticDir(t) - outDir := t.TempDir() - opts := options.AnalysisOptions{ - InputPath: dir, - OutputDir: outDir, - Level: level, - SkipTests: true, - CacheDir: t.TempDir(), - } - app, err := core.New(opts).Analyze() - if err != nil { - t.Fatalf("Analyze() failed: %v", err) - } - return app -} - // findCallableByName searches all functions and methods in a GoFile by short name. func findCallableByName(f schema.GoFile, name string) *schema.GoCallable { for _, c := range f.Functions { @@ -69,17 +33,15 @@ func findCallableByName(f schema.GoFile, name string) *schema.GoCallable { // ── Multi-file package ──────────────────────────────────────────────────────── func TestRealistic_MultiFilePkg(t *testing.T) { - app := runRealistic(t, options.LevelSymbolTable) - _, hasServer := app.SymbolTable["server/server.go"] - _, hasMiddleware := app.SymbolTable["server/middleware.go"] + _, hasServer := sharedMultipackageL1.SymbolTable["server/server.go"] + _, hasMiddleware := sharedMultipackageL1.SymbolTable["server/middleware.go"] if !hasServer { t.Error("server/server.go missing from symbol table") } if !hasMiddleware { t.Error("server/middleware.go missing from symbol table") } - // Tags must live in middleware.go, not server.go. 
- mw := app.SymbolTable["server/middleware.go"] + mw := sharedMultipackageL1.SymbolTable["server/middleware.go"] if findCallableByName(mw, "Tags") == nil { t.Error("Tags function not found in server/middleware.go") } @@ -88,15 +50,14 @@ func TestRealistic_MultiFilePkg(t *testing.T) { // ── Embedded struct field ───────────────────────────────────────────────────── func TestRealistic_EmbeddedField(t *testing.T) { - app := runRealistic(t, options.LevelSymbolTable) - srv := app.SymbolTable["server/server.go"] + srv := sharedMultipackageL1.SymbolTable["server/server.go"] server, ok := srv.Types["Server"] if !ok { t.Fatal("GoType 'Server' not found in server/server.go") } for _, f := range server.Fields { if f.IsEmbedded { - return // pass + return } } t.Errorf("Server has no embedded field; fields: %+v", server.Fields) @@ -105,8 +66,7 @@ func TestRealistic_EmbeddedField(t *testing.T) { // ── Multiple return types — (T, error) pattern ──────────────────────────────── func TestRealistic_MultipleReturnTypes(t *testing.T) { - app := runRealistic(t, options.LevelSymbolTable) - srv := app.SymbolTable["server/server.go"] + srv := sharedMultipackageL1.SymbolTable["server/server.go"] newFn := findCallableByName(srv, "New") if newFn == nil { t.Fatal("function 'New' not found in server/server.go") @@ -126,8 +86,7 @@ func TestRealistic_MultipleReturnTypes(t *testing.T) { } func TestRealistic_ValidateReturnTypes(t *testing.T) { - app := runRealistic(t, options.LevelSymbolTable) - srv := app.SymbolTable["server/server.go"] + srv := sharedMultipackageL1.SymbolTable["server/server.go"] validate := findCallableByName(srv, "Validate") if validate == nil { t.Fatal("method 'Validate' not found in server/server.go") @@ -140,8 +99,7 @@ func TestRealistic_ValidateReturnTypes(t *testing.T) { // ── Unexported callables ────────────────────────────────────────────────────── func TestRealistic_UnexportedMethod(t *testing.T) { - app := runRealistic(t, options.LevelSymbolTable) - srv := 
app.SymbolTable["server/server.go"] + srv := sharedMultipackageL1.SymbolTable["server/server.go"] shutdown := findCallableByName(srv, "shutdown") if shutdown == nil { t.Fatal("method 'shutdown' not found in server/server.go") @@ -152,8 +110,7 @@ func TestRealistic_UnexportedMethod(t *testing.T) { } func TestRealistic_UnexportedWorkerMethod(t *testing.T) { - app := runRealistic(t, options.LevelSymbolTable) - wkr := app.SymbolTable["worker/worker.go"] + wkr := sharedMultipackageL1.SymbolTable["worker/worker.go"] execute := findCallableByName(wkr, "execute") if execute == nil { t.Fatal("method 'execute' not found in worker/worker.go") @@ -166,8 +123,7 @@ func TestRealistic_UnexportedWorkerMethod(t *testing.T) { // ── Receiver type / name ────────────────────────────────────────────────────── func TestRealistic_ReceiverType(t *testing.T) { - app := runRealistic(t, options.LevelSymbolTable) - srv := app.SymbolTable["server/server.go"] + srv := sharedMultipackageL1.SymbolTable["server/server.go"] addr := findCallableByName(srv, "Addr") if addr == nil { t.Fatal("method 'Addr' not found in server/server.go") @@ -178,26 +134,20 @@ func TestRealistic_ReceiverType(t *testing.T) { if addr.ReceiverName == "" { t.Error("Addr().receiver_name should be non-empty") } - // Pointer receiver — type should contain '*' or 'Server'. if !strings.Contains(addr.ReceiverType, "Server") { t.Errorf("Addr().receiver_type %q should reference Server", addr.ReceiverType) } } func TestRealistic_ValueReceiver(t *testing.T) { - app := runRealistic(t, options.LevelSymbolTable) - // Describe is defined in middleware.go but its receiver type (Server) lives in - // server.go — the reconcileCrossFileMethods pass attaches it to server.go's type. 
- srv := app.SymbolTable["server/server.go"] + srv := sharedMultipackageL1.SymbolTable["server/server.go"] describe := findCallableByName(srv, "Describe") if describe == nil { t.Fatal("method 'Describe' not found attached to Server in server/server.go") } - // Value receiver — ReceiverType should not contain '*'. if strings.Contains(describe.ReceiverType, "*") { t.Errorf("Describe().receiver_type %q should be a value receiver (no '*')", describe.ReceiverType) } - // Path should still record the physical definition file. if !strings.Contains(describe.Path, "middleware.go") { t.Errorf("Describe().path %q should point to middleware.go", describe.Path) } @@ -206,30 +156,28 @@ func TestRealistic_ValueReceiver(t *testing.T) { // ── Variadic parameters ─────────────────────────────────────────────────────── func TestRealistic_VariadicParamTags(t *testing.T) { - app := runRealistic(t, options.LevelSymbolTable) - mw := app.SymbolTable["server/middleware.go"] + mw := sharedMultipackageL1.SymbolTable["server/middleware.go"] tags := findCallableByName(mw, "Tags") if tags == nil { t.Fatal("function 'Tags' not found in server/middleware.go") } for _, p := range tags.Parameters { if p.IsVariadic { - return // pass + return } } t.Errorf("Tags() has no variadic parameter; params: %+v", tags.Parameters) } func TestRealistic_VariadicParamCombine(t *testing.T) { - app := runRealistic(t, options.LevelSymbolTable) - wkr := app.SymbolTable["worker/worker.go"] + wkr := sharedMultipackageL1.SymbolTable["worker/worker.go"] combine := findCallableByName(wkr, "Combine") if combine == nil { t.Fatal("function 'Combine' not found in worker/worker.go") } for _, p := range combine.Parameters { if p.IsVariadic { - return // pass + return } } t.Errorf("Combine() has no variadic parameter; params: %+v", combine.Parameters) @@ -238,15 +186,14 @@ func TestRealistic_VariadicParamCombine(t *testing.T) { // ── Goroutine call site ─────────────────────────────────────────────────────── func 
TestRealistic_GoroutineCallsite(t *testing.T) { - app := runRealistic(t, options.LevelSymbolTable) - wkr := app.SymbolTable["worker/worker.go"] + wkr := sharedMultipackageL1.SymbolTable["worker/worker.go"] run := findCallableByName(wkr, "Run") if run == nil { t.Fatal("method 'Run' not found in worker/worker.go") } for _, cs := range run.CallSites { if cs.IsGoroutine { - return // pass + return } } t.Errorf("Run() has no goroutine call site; sites: %+v", run.CallSites) @@ -255,13 +202,11 @@ func TestRealistic_GoroutineCallsite(t *testing.T) { // ── Cyclomatic complexity ───────────────────────────────────────────────────── func TestRealistic_CyclomaticComplexity(t *testing.T) { - app := runRealistic(t, options.LevelSymbolTable) - wkr := app.SymbolTable["worker/worker.go"] + wkr := sharedMultipackageL1.SymbolTable["worker/worker.go"] execute := findCallableByName(wkr, "execute") if execute == nil { t.Fatal("method 'execute' not found in worker/worker.go") } - // execute() has an `if err != nil` branch → CC >= 2. 
if execute.CyclomaticComplexity < 2 { t.Errorf("execute().cyclomatic_complexity should be >= 2; got %d", execute.CyclomaticComplexity) } @@ -270,8 +215,7 @@ func TestRealistic_CyclomaticComplexity(t *testing.T) { // ── Interface detection ─────────────────────────────────────────────────────── func TestRealistic_InterfaceType(t *testing.T) { - app := runRealistic(t, options.LevelSymbolTable) - wkr := app.SymbolTable["worker/worker.go"] + wkr := sharedMultipackageL1.SymbolTable["worker/worker.go"] proc, ok := wkr.Types["Processor"] if !ok { t.Fatal("GoType 'Processor' not found in worker/worker.go") @@ -281,29 +225,25 @@ func TestRealistic_InterfaceType(t *testing.T) { } } -// ── Specific call-graph edge ────────────────────────────────────────────────── +// ── Specific call-graph edges ───────────────────────────────────────────────── func TestRealistic_SpecificCallEdge(t *testing.T) { - app := runRealistic(t, options.LevelCallGraph) - // main() calls server.New() — this is a cross-package project-internal edge. - const wantTarget = "example.com/realistic/server.New" - for _, e := range app.CallGraph { + const wantTarget = "example.com/multipackage/server.New" + for _, e := range sharedMultipackageL2.CallGraph { if e.Target == wantTarget { - return // pass + return } } - t.Errorf("call graph missing expected edge to %s; edges: %v", wantTarget, edgeTargets(app)) + t.Errorf("call graph missing expected edge to %s; edges: %v", wantTarget, edgeTargets(sharedMultipackageL2)) } func TestRealistic_CrossPackageEdges(t *testing.T) { - app := runRealistic(t, options.LevelCallGraph) - // At least one edge must cross the main→server boundary and one main→worker boundary. 
var serverEdge, workerEdge bool - for _, e := range app.CallGraph { - if strings.Contains(e.Target, "realistic/server.") { + for _, e := range sharedMultipackageL2.CallGraph { + if strings.Contains(e.Target, "multipackage/server.") { serverEdge = true } - if strings.Contains(e.Target, "realistic/worker.") { + if strings.Contains(e.Target, "multipackage/worker.") { workerEdge = true } } @@ -315,6 +255,41 @@ func TestRealistic_CrossPackageEdges(t *testing.T) { } } +// ── H6: LocalVariables assertions ──────────────────────────────────────────── + +// worker.Combine has `out := Result{}` — a local variable with a known type. +func TestRealistic_LocalVariablesPresent(t *testing.T) { + wkr := sharedMultipackageL1.SymbolTable["worker/worker.go"] + combine := findCallableByName(wkr, "Combine") + if combine == nil { + t.Fatal("function 'Combine' not found in worker/worker.go") + } + if len(combine.LocalVariables) == 0 { + t.Fatal("Combine() should have at least one local variable; got none") + } +} + +// worker.execute has `r, err := p.Process(t)` — two local variables. 
+func TestRealistic_LocalVariablesHaveType(t *testing.T) { + wkr := sharedMultipackageL1.SymbolTable["worker/worker.go"] + execute := findCallableByName(wkr, "execute") + if execute == nil { + t.Fatal("method 'execute' not found in worker/worker.go") + } + for _, v := range execute.LocalVariables { + if v.Name == "err" { + if v.Type == "" { + t.Error("local variable 'err' should have a non-empty type") + } + if v.Scope != "function" { + t.Errorf("local variable 'err' scope should be 'function'; got %q", v.Scope) + } + return + } + } + t.Errorf("local variable 'err' not found in execute(); vars: %+v", execute.LocalVariables) +} + // ── Helpers ─────────────────────────────────────────────────────────────────── func edgeTargets(app *schema.GoApplication) []string { diff --git a/internal/core/skip_tests_test.go b/internal/core/skip_tests_test.go new file mode 100644 index 0000000..227039e --- /dev/null +++ b/internal/core/skip_tests_test.go @@ -0,0 +1,70 @@ +package core_test + +// Tests for the SkipTests option. +// +// testdata/multipackage/server/server_test.go is a minimal test file whose sole +// purpose is to give these tests something to look for. It is never included +// in the shared fixtures (all use SkipTests: true). + +import ( + "path/filepath" + "testing" + + "github.com/codellm-devkit/codeanalyzer-go/internal/core" + "github.com/codellm-devkit/codeanalyzer-go/internal/options" +) + +const serverTestFile = "server/server_test.go" + +// TestSkipTests_TrueExcludesTestFiles: the default (SkipTests=true) must not +// include any *_test.go file in the symbol table. +func TestSkipTests_TrueExcludesTestFiles(t *testing.T) { + // sharedMultipackageL1 is built with SkipTests: true — re-use it. 
+ for key := range sharedMultipackageL1.SymbolTable { + if len(key) >= 8 && key[len(key)-8:] == "_test.go" { + t.Errorf("SkipTests=true: found test file in symbol table: %s", key) + } + } +} + +// TestSkipTests_FalseIncludesTestFiles: with SkipTests=false the analyzer must +// include *_test.go files in the symbol table. +func TestSkipTests_FalseIncludesTestFiles(t *testing.T) { + app, err := core.New(options.AnalysisOptions{ + InputPath: filepath.Join(testdataDir(), "multipackage"), + OutputDir: t.TempDir(), + Level: options.LevelSymbolTable, + SkipTests: false, + CacheDir: t.TempDir(), + }).Analyze() + if err != nil { + t.Fatalf("Analyze: %v", err) + } + + if _, ok := app.SymbolTable[serverTestFile]; !ok { + t.Errorf("SkipTests=false: %s not in symbol table; got keys: %v", + serverTestFile, keys(app.SymbolTable)) + } +} + +// TestSkipTests_FalseIncreasesFileCount: the symbol table with SkipTests=false +// must have more files than the same analysis with SkipTests=true. +func TestSkipTests_FalseIncreasesFileCount(t *testing.T) { + withSkip := len(sharedMultipackageL1.SymbolTable) + + app, err := core.New(options.AnalysisOptions{ + InputPath: filepath.Join(testdataDir(), "multipackage"), + OutputDir: t.TempDir(), + Level: options.LevelSymbolTable, + SkipTests: false, + CacheDir: t.TempDir(), + }).Analyze() + if err != nil { + t.Fatalf("Analyze: %v", err) + } + + if len(app.SymbolTable) <= withSkip { + t.Errorf("SkipTests=false: expected more files than %d (SkipTests=true count); got %d", + withSkip, len(app.SymbolTable)) + } +} diff --git a/internal/core/testsetup_test.go b/internal/core/testsetup_test.go new file mode 100644 index 0000000..85f3b72 --- /dev/null +++ b/internal/core/testsetup_test.go @@ -0,0 +1,99 @@ +package core_test + +import ( + "fmt" + "os" + "path/filepath" + "runtime" + "testing" + + "github.com/codellm-devkit/codeanalyzer-go/internal/core" + "github.com/codellm-devkit/codeanalyzer-go/internal/options" + 
"github.com/codellm-devkit/codeanalyzer-go/internal/schema" +) + +// Shared analysis results, populated once in TestMain and reused across all tests. +// Caching tests are excluded: they must exercise the caching machinery themselves. +var ( + sharedGreeterL1 *schema.GoApplication // greeter, symbol-table only + sharedGreeterL2 *schema.GoApplication // greeter, full call-graph + sharedMultipackageL1 *schema.GoApplication // multipackage, symbol-table only + sharedMultipackageL2 *schema.GoApplication // multipackage, full call-graph + sharedGenericsL1 *schema.GoApplication // generics, symbol-table only + sharedChiL2 *schema.GoApplication // chi (external dep), full call-graph +) + +func TestMain(m *testing.M) { + os.Exit(runTestMain(m)) +} + +// runTestMain wraps m.Run so that deferred cleanup runs before os.Exit. +func runTestMain(m *testing.M) int { + tdRoot := testdataDir() + + tmpRoot, err := os.MkdirTemp("", "codeanalyzer-test-*") + if err != nil { + fmt.Fprintf(os.Stderr, "testsetup: MkdirTemp: %v\n", err) + return 1 + } + defer os.RemoveAll(tmpRoot) + + type fixture struct { + name string + path string + level options.AnalysisLevel + dst **schema.GoApplication + } + + for _, f := range []fixture{ + {"greeter/L1", filepath.Join(tdRoot, "greeter"), options.LevelSymbolTable, &sharedGreeterL1}, + {"greeter/L2", filepath.Join(tdRoot, "greeter"), options.LevelCallGraph, &sharedGreeterL2}, + {"multipackage/L1", filepath.Join(tdRoot, "multipackage"), options.LevelSymbolTable, &sharedMultipackageL1}, + {"multipackage/L2", filepath.Join(tdRoot, "multipackage"), options.LevelCallGraph, &sharedMultipackageL2}, + {"generics/L1", filepath.Join(tdRoot, "generics"), options.LevelSymbolTable, &sharedGenericsL1}, + {"chi/L2", filepath.Join(tdRoot, "chi"), options.LevelCallGraph, &sharedChiL2}, + } { + outDir, err := os.MkdirTemp(tmpRoot, "out-*") + if err != nil { + fmt.Fprintf(os.Stderr, "testsetup %s: MkdirTemp out: %v\n", f.name, err) + return 1 + } + cacheDir, err := 
os.MkdirTemp(tmpRoot, "cache-*") + if err != nil { + fmt.Fprintf(os.Stderr, "testsetup %s: MkdirTemp cache: %v\n", f.name, err) + return 1 + } + opts := options.AnalysisOptions{ + InputPath: f.path, + OutputDir: outDir, + Level: f.level, + SkipTests: true, + CacheDir: cacheDir, + } + app, err := core.New(opts).Analyze() + if err != nil { + fmt.Fprintf(os.Stderr, "testsetup %s: Analyze: %v\n", f.name, err) + return 1 + } + if err := core.WriteOutput(app, outDir, "json"); err != nil { + fmt.Fprintf(os.Stderr, "testsetup %s: WriteOutput: %v\n", f.name, err) + return 1 + } + if _, err := os.Stat(filepath.Join(outDir, "analysis.json")); err != nil { + fmt.Fprintf(os.Stderr, "testsetup %s: analysis.json not created: %v\n", f.name, err) + return 1 + } + *f.dst = app + } + + return m.Run() +} + +// testdataDir returns the absolute path to the testdata directory. +// Uses runtime.Caller so it resolves correctly regardless of working directory. +func testdataDir() string { + _, thisFile, _, _ := runtime.Caller(0) + abs, _ := filepath.Abs(filepath.Join(filepath.Dir(thisFile), "..", "..", "testdata")) + return abs +} + diff --git a/internal/semantic_analysis/call_graph_test.go b/internal/semantic_analysis/call_graph_test.go new file mode 100644 index 0000000..60b89ee --- /dev/null +++ b/internal/semantic_analysis/call_graph_test.go @@ -0,0 +1,116 @@ +package semantic_analysis_test + +import ( + "testing" + + "github.com/codellm-devkit/codeanalyzer-go/internal/schema" + "github.com/codellm-devkit/codeanalyzer-go/internal/semantic_analysis" +) + +func edge(src, tgt string, weight int, prov ...string) schema.GoCallEdge { + return schema.GoCallEdge{Source: src, Target: tgt, Weight: weight, Provenance: prov} +} + +// ── MergeEdges ──────────────────────────────────────────────────────────────── + +func TestMergeEdges_EmptyBoth(t *testing.T) { + result := semantic_analysis.MergeEdges(nil, nil) + if len(result) != 0 { + t.Errorf("got %d edges, want 0", len(result)) + } +} + +func 
TestMergeEdges_PrimaryOnly(t *testing.T) { + primary := []schema.GoCallEdge{edge("a", "b", 1.0, "resolver")} + result := semantic_analysis.MergeEdges(primary, nil) + if len(result) != 1 { + t.Fatalf("got %d edges, want 1", len(result)) + } + if result[0].Source != "a" || result[0].Target != "b" { + t.Errorf("unexpected edge: %+v", result[0]) + } +} + +func TestMergeEdges_SecondaryOnly(t *testing.T) { + secondary := []schema.GoCallEdge{edge("x", "y", 2.0, "codeql")} + result := semantic_analysis.MergeEdges(nil, secondary) + if len(result) != 1 { + t.Fatalf("got %d edges, want 1", len(result)) + } + if result[0].Source != "x" || result[0].Target != "y" { + t.Errorf("unexpected edge: %+v", result[0]) + } +} + +func TestMergeEdges_DisjointEdges(t *testing.T) { + primary := []schema.GoCallEdge{edge("a", "b", 1.0, "resolver")} + secondary := []schema.GoCallEdge{edge("c", "d", 1.0, "codeql")} + result := semantic_analysis.MergeEdges(primary, secondary) + if len(result) != 2 { + t.Errorf("got %d edges, want 2", len(result)) + } +} + +func TestMergeEdges_DuplicateAccumulatesWeight(t *testing.T) { + primary := []schema.GoCallEdge{edge("a", "b", 3, "resolver")} + secondary := []schema.GoCallEdge{edge("a", "b", 5, "codeql")} + result := semantic_analysis.MergeEdges(primary, secondary) + if len(result) != 1 { + t.Fatalf("duplicate (a→b) should collapse to 1 edge; got %d", len(result)) + } + if result[0].Weight != 8 { + t.Errorf("weight: got %v, want 8", result[0].Weight) + } +} + +func TestMergeEdges_DuplicateUnionsProvenance(t *testing.T) { + primary := []schema.GoCallEdge{edge("a", "b", 1.0, "resolver")} + secondary := []schema.GoCallEdge{edge("a", "b", 1.0, "codeql")} + result := semantic_analysis.MergeEdges(primary, secondary) + if len(result) != 1 { + t.Fatalf("got %d edges, want 1", len(result)) + } + provSet := map[string]bool{} + for _, p := range result[0].Provenance { + provSet[p] = true + } + if !provSet["resolver"] || !provSet["codeql"] { + t.Errorf("provenance 
union failed; got %v", result[0].Provenance) + } +} + +func TestMergeEdges_DuplicateProvenanceNotDuplicated(t *testing.T) { + primary := []schema.GoCallEdge{edge("a", "b", 1.0, "resolver")} + secondary := []schema.GoCallEdge{edge("a", "b", 1.0, "resolver")} + result := semantic_analysis.MergeEdges(primary, secondary) + if len(result) != 1 { + t.Fatalf("got %d edges, want 1", len(result)) + } + count := 0 + for _, p := range result[0].Provenance { + if p == "resolver" { + count++ + } + } + if count != 1 { + t.Errorf("duplicate provenance should appear once; got %d times", count) + } +} + +func TestMergeEdges_OrderPreserved(t *testing.T) { + primary := []schema.GoCallEdge{ + edge("a", "b", 1.0), + edge("c", "d", 1.0), + } + secondary := []schema.GoCallEdge{ + edge("e", "f", 1.0), + } + result := semantic_analysis.MergeEdges(primary, secondary) + if len(result) != 3 { + t.Fatalf("got %d edges, want 3", len(result)) + } + // Primary edges come first, then secondary. + if result[0].Source != "a" || result[1].Source != "c" || result[2].Source != "e" { + t.Errorf("order not preserved: %v", result) + } +} diff --git a/internal/syntactic_analysis/symbol_table.go b/internal/syntactic_analysis/symbol_table.go index f6a7838..c7f48ce 100644 --- a/internal/syntactic_analysis/symbol_table.go +++ b/internal/syntactic_analysis/symbol_table.go @@ -50,6 +50,9 @@ func (b *SymbolTableBuilder) Build(targetFiles []string, skipTests bool) (map[st packages.NeedDeps, Dir: b.projectDir, Fset: b.fset, + // Include test packages when the caller wants test files. Without this, + // go/packages never presents *_test.go files to the loader. + Tests: !skipTests, // Silence go vet; we only need type info, not a full build. 
BuildFlags: []string{}, } @@ -90,6 +93,11 @@ func (b *SymbolTableBuilder) Build(targetFiles []string, skipTests bool) (map[st } filePath := pkg.GoFiles[i] relPath := utils.RelativePath(b.projectDir, filePath) + // Paths that escape the project root (generated test runners in the + // Go build cache, stdlib, etc.) are never project files. + if strings.HasPrefix(relPath, "..") { + continue + } if skipTests && utils.IsTestFile(relPath) { continue } @@ -129,6 +137,9 @@ func (b *SymbolTableBuilder) reconcileCrossFileMethods(symbolTable map[string]sc } filePath := pkg.GoFiles[i] relPath := utils.RelativePath(b.projectDir, filePath) + if strings.HasPrefix(relPath, "..") { + continue + } for _, decl := range astFile.Decls { fd, ok := decl.(*ast.FuncDecl) @@ -542,6 +553,7 @@ func (b *SymbolTableBuilder) buildCallable( callable.Code = b.nodeSource(decl) callable.CallSites = b.buildCallSites(pkg, decl.Body) callable.LocalVariables = b.buildLocalVars(pkg, decl.Body) + callable.InnerCallables = b.buildInnerCallables(pkg, sig, decl.Body) } return callable @@ -634,6 +646,9 @@ func (b *SymbolTableBuilder) buildCallSites(pkg *packages.Package, body *ast.Blo sites = append(sites, *site) } return false + case *ast.FuncLit: + // Closure bodies are handled by buildInnerCallables; don't double-count. + return false case *ast.CallExpr: site := b.callExprToSite(pkg, node, false) if site != nil { @@ -762,6 +777,49 @@ func (b *SymbolTableBuilder) buildLocalVars(pkg *packages.Package, body *ast.Blo return vars } +// buildInnerCallables walks body and collects each FuncLit as a named closure. +// Only the top level is captured; nested closures appear in the closure's own +// InnerCallables (populated when buildCallSites recurses into lit.Body). 
+func (b *SymbolTableBuilder) buildInnerCallables(pkg *packages.Package, outerSig string, body *ast.BlockStmt) map[string]schema.GoCallable { + inner := map[string]schema.GoCallable{} + n := 0 + ast.Inspect(body, func(node ast.Node) bool { + lit, ok := node.(*ast.FuncLit) + if !ok { + return true + } + n++ + name := fmt.Sprintf("closure_%d", n) + sig := outerSig + "." + name + pos := b.fset.Position(lit.Pos()) + endPos := b.fset.Position(lit.End()) + _, retTypes := b.buildReturnTypes(pkg, lit.Type) + retType := b.joinReturnTypes(retTypes) + + ic := schema.GoCallable{ + Name: name, + Signature: sig, + Parameters: b.buildParams(pkg, lit.Type), + ReturnType: retType, + ReturnTypes: retTypes, + CallSites: []schema.GoCallsite{}, + InnerCallables: map[string]schema.GoCallable{}, + LocalVariables: []schema.GoVariableDeclaration{}, + StartLine: pos.Line, + EndLine: endPos.Line, + } + if lit.Body != nil { + ic.Code = b.nodeSource(lit) + ic.CallSites = b.buildCallSites(pkg, lit.Body) + ic.LocalVariables = b.buildLocalVars(pkg, lit.Body) + ic.InnerCallables = b.buildInnerCallables(pkg, sig, lit.Body) + } + inner[name] = ic + return false // don't recurse; nested closures are handled above + }) + return inner +} + // ─── Package-level variables ────────────────────────────────────────────────── func (b *SymbolTableBuilder) buildPackageVars(pkg *packages.Package, astFile *ast.File) []schema.GoVariableDeclaration { @@ -826,6 +884,8 @@ func (b *SymbolTableBuilder) cyclomaticComplexity(decl *ast.FuncDecl) int { // ─── Helpers ────────────────────────────────────────────────────────────────── // receiverTypeName extracts the base type name from a receiver field list. +// It handles pointer receivers (*T), generic single-param receivers (T[A]), +// pointer-to-generic (*T[A]), and multi-param generic receivers (*T[A, B]). 
func (b *SymbolTableBuilder) receiverTypeName(recv *ast.FieldList) string { if recv == nil || len(recv.List) == 0 { return "" @@ -835,6 +895,14 @@ func (b *SymbolTableBuilder) receiverTypeName(recv *ast.FieldList) string { if star, ok := expr.(*ast.StarExpr); ok { expr = star.X } + // Generic single type param: Set[T] → IndexExpr{X: Ident("Set")} + if idx, ok := expr.(*ast.IndexExpr); ok { + expr = idx.X + } + // Generic multi type param: Map[K, V] → IndexListExpr{X: Ident("Map")} + if idx, ok := expr.(*ast.IndexListExpr); ok { + expr = idx.X + } if ident, ok := expr.(*ast.Ident); ok { return ident.Name } diff --git a/internal/utils/fs_test.go b/internal/utils/fs_test.go new file mode 100644 index 0000000..a349d07 --- /dev/null +++ b/internal/utils/fs_test.go @@ -0,0 +1,206 @@ +package utils_test + +import ( + "os" + "path/filepath" + "testing" + + "github.com/codellm-devkit/codeanalyzer-go/internal/utils" +) + +// ── IsTestFile ──────────────────────────────────────────────────────────────── + +func TestIsTestFile(t *testing.T) { + tests := []struct { + path string + want bool + }{ + {"foo_test.go", true}, + {"server_test.go", true}, + {"foo.go", false}, + {"test.go", false}, // doesn't end with _test.go + {"_test.go", true}, // edge case: file is literally "_test.go" + {"", false}, + {"foo_test.go.bak", false}, + } + for _, tc := range tests { + if got := utils.IsTestFile(tc.path); got != tc.want { + t.Errorf("IsTestFile(%q) = %v, want %v", tc.path, got, tc.want) + } + } +} + +// ── IsVendored ──────────────────────────────────────────────────────────────── + +func TestIsVendored(t *testing.T) { + tests := []struct { + path string + want bool + }{ + {"vendor/github.com/foo/bar/baz.go", true}, + {"pkg/vendor/something.go", true}, + {"testdata/greeter/main.go", true}, + {".git/config", true}, + {"internal/core/analyzer.go", false}, + {"main.go", false}, + {"", false}, + {"vendored/not-vendor.go", false}, // "vendored" ≠ "vendor" + } + for _, tc := range tests { + 
if got := utils.IsVendored(tc.path); got != tc.want { + t.Errorf("IsVendored(%q) = %v, want %v", tc.path, got, tc.want) + } + } +} + +// ── FileHash ────────────────────────────────────────────────────────────────── + +func TestFileHash_Deterministic(t *testing.T) { + dir := t.TempDir() + f := filepath.Join(dir, "file.txt") + if err := os.WriteFile(f, []byte("hello world"), 0o644); err != nil { + t.Fatal(err) + } + + h1, err := utils.FileHash(f) + if err != nil { + t.Fatalf("FileHash: %v", err) + } + h2, err := utils.FileHash(f) + if err != nil { + t.Fatalf("FileHash second call: %v", err) + } + if h1 != h2 { + t.Errorf("FileHash is not deterministic: %q != %q", h1, h2) + } + if len(h1) != 64 { + t.Errorf("FileHash should return 64-char hex SHA-256; got len %d: %q", len(h1), h1) + } +} + +func TestFileHash_DifferentContent(t *testing.T) { + dir := t.TempDir() + a := filepath.Join(dir, "a.txt") + b := filepath.Join(dir, "b.txt") + os.WriteFile(a, []byte("aaa"), 0o644) + os.WriteFile(b, []byte("bbb"), 0o644) + + ha, _ := utils.FileHash(a) + hb, _ := utils.FileHash(b) + if ha == hb { + t.Error("different files should have different hashes") + } +} + +func TestFileHash_NonExistentFile(t *testing.T) { + _, err := utils.FileHash(filepath.Join(t.TempDir(), "no-such-file")) + if err == nil { + t.Error("expected error for non-existent file, got nil") + } +} + +// ── EnsureDir ───────────────────────────────────────────────────────────────── + +func TestEnsureDir_CreatesDirectory(t *testing.T) { + dir := filepath.Join(t.TempDir(), "a", "b", "c") + if err := utils.EnsureDir(dir); err != nil { + t.Fatalf("EnsureDir: %v", err) + } + if fi, err := os.Stat(dir); err != nil || !fi.IsDir() { + t.Errorf("directory %s was not created", dir) + } +} + +func TestEnsureDir_Idempotent(t *testing.T) { + dir := t.TempDir() + if err := utils.EnsureDir(dir); err != nil { + t.Errorf("EnsureDir on existing dir: %v", err) + } +} + +// ── DiscoverGoFiles 
─────────────────────────────────────────────────────────── + +func TestDiscoverGoFiles_FindsGoFiles(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "main.go", "package main") + writeFile(t, dir, "util.go", "package main") + + files, err := utils.DiscoverGoFiles(dir, true) + if err != nil { + t.Fatalf("DiscoverGoFiles: %v", err) + } + if len(files) != 2 { + t.Errorf("got %d files, want 2: %v", len(files), files) + } +} + +func TestDiscoverGoFiles_SkipsTestFiles(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "main.go", "package main") + writeFile(t, dir, "main_test.go", "package main") + + files, err := utils.DiscoverGoFiles(dir, true) + if err != nil { + t.Fatalf("DiscoverGoFiles: %v", err) + } + for _, f := range files { + if utils.IsTestFile(f) { + t.Errorf("skipTests=true: found test file: %s", f) + } + } +} + +func TestDiscoverGoFiles_IncludesTestFilesWhenNotSkipped(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "main.go", "package main") + writeFile(t, dir, "main_test.go", "package main") + + files, err := utils.DiscoverGoFiles(dir, false) + if err != nil { + t.Fatalf("DiscoverGoFiles: %v", err) + } + if len(files) != 2 { + t.Errorf("got %d files, want 2: %v", len(files), files) + } +} + +func TestDiscoverGoFiles_SkipsVendorDir(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "main.go", "package main") + vendor := filepath.Join(dir, "vendor", "pkg") + os.MkdirAll(vendor, 0o755) + writeFile(t, vendor, "lib.go", "package pkg") + + files, err := utils.DiscoverGoFiles(dir, true) + if err != nil { + t.Fatalf("DiscoverGoFiles: %v", err) + } + if len(files) != 1 { + t.Errorf("got %d files, want 1 (vendor should be skipped): %v", len(files), files) + } +} + +func TestDiscoverGoFiles_IgnoresNonGoFiles(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "main.go", "package main") + writeFile(t, dir, "readme.md", "# readme") + writeFile(t, dir, "config.yaml", "key: val") + + files, err := utils.DiscoverGoFiles(dir, true) + if err 
!= nil { + t.Fatalf("DiscoverGoFiles: %v", err) + } + if len(files) != 1 { + t.Errorf("got %d files, want 1: %v", len(files), files) + } +} + +// ── helpers ─────────────────────────────────────────────────────────────────── + +func writeFile(t *testing.T, dir, name, content string) { + t.Helper() + path := filepath.Join(dir, name) + if err := os.WriteFile(path, []byte(content), 0o644); err != nil { + t.Fatalf("writeFile %s: %v", name, err) + } +} diff --git a/testdata/chi/chain.go b/testdata/chi/chain.go new file mode 100644 index 0000000..a227841 --- /dev/null +++ b/testdata/chi/chain.go @@ -0,0 +1,49 @@ +package chi + +import "net/http" + +// Chain returns a Middlewares type from a slice of middleware handlers. +func Chain(middlewares ...func(http.Handler) http.Handler) Middlewares { + return Middlewares(middlewares) +} + +// Handler builds and returns a http.Handler from the chain of middlewares, +// with `h http.Handler` as the final handler. +func (mws Middlewares) Handler(h http.Handler) http.Handler { + return &ChainHandler{h, chain(mws, h), mws} +} + +// HandlerFunc builds and returns a http.Handler from the chain of middlewares, +// with `h http.Handler` as the final handler. +func (mws Middlewares) HandlerFunc(h http.HandlerFunc) http.Handler { + return &ChainHandler{h, chain(mws, h), mws} +} + +// ChainHandler is a http.Handler with support for handler composition and +// execution. +type ChainHandler struct { + Endpoint http.Handler + chain http.Handler + Middlewares Middlewares +} + +func (c *ChainHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + c.chain.ServeHTTP(w, r) +} + +// chain builds a http.Handler composed of an inline middleware stack and endpoint +// handler in the order they are passed. 
+func chain(middlewares []func(http.Handler) http.Handler, endpoint http.Handler) http.Handler { + // Return ahead of time if there aren't any middlewares for the chain + if len(middlewares) == 0 { + return endpoint + } + + // Wrap the end handler with the middleware chain + h := middlewares[len(middlewares)-1](endpoint) + for i := len(middlewares) - 2; i >= 0; i-- { + h = middlewares[i](h) + } + + return h +} diff --git a/testdata/chi/chi.go b/testdata/chi/chi.go new file mode 100644 index 0000000..ad0ca74 --- /dev/null +++ b/testdata/chi/chi.go @@ -0,0 +1,137 @@ +// Package chi is a small, idiomatic and composable router for building HTTP services. +// +// chi supports the four most recent major versions of Go. +// +// Example: +// +// package main +// +// import ( +// "net/http" +// +// "github.com/go-chi/chi/v5" +// "github.com/go-chi/chi/v5/middleware" +// ) +// +// func main() { +// r := chi.NewRouter() +// r.Use(middleware.Logger) +// r.Use(middleware.Recoverer) +// +// r.Get("/", func(w http.ResponseWriter, r *http.Request) { +// w.Write([]byte("root.")) +// }) +// +// http.ListenAndServe(":3333", r) +// } +// +// See github.com/go-chi/chi/_examples/ for more in-depth examples. +// +// URL patterns allow for easy matching of path components in HTTP +// requests. The matching components can then be accessed using +// chi.URLParam(). All patterns must begin with a slash. +// +// A simple named placeholder {name} matches any sequence of characters +// up to the next / or the end of the URL. Trailing slashes on paths must +// be handled explicitly. +// +// A placeholder with a name followed by a colon allows a regular +// expression match, for example {number:\\d+}. The regular expression +// syntax is Go's normal regexp RE2 syntax, except that / will never be +// matched. 
An anonymous regexp pattern is allowed, using an empty string +// before the colon in the placeholder, such as {:\\d+} +// +// The special placeholder of asterisk matches the rest of the requested +// URL. Any trailing characters in the pattern are ignored. This is the only +// placeholder which will match / characters. +// +// Examples: +// +// "/user/{name}" matches "/user/jsmith" but not "/user/jsmith/info" or "/user/jsmith/" +// "/user/{name}/info" matches "/user/jsmith/info" +// "/page/*" matches "/page/intro/latest" +// "/page/{other}/latest" also matches "/page/intro/latest" +// "/date/{yyyy:\\d\\d\\d\\d}/{mm:\\d\\d}/{dd:\\d\\d}" matches "/date/2017/04/01" +package chi + +import "net/http" + +// NewRouter returns a new Mux object that implements the Router interface. +func NewRouter() *Mux { + return NewMux() +} + +// Router consisting of the core routing methods used by chi's Mux, +// using only the standard net/http. +type Router interface { + http.Handler + Routes + + // Use appends one or more middlewares onto the Router stack. + Use(middlewares ...func(http.Handler) http.Handler) + + // With adds inline middlewares for an endpoint handler. + With(middlewares ...func(http.Handler) http.Handler) Router + + // Group adds a new inline-Router along the current routing + // path, with a fresh middleware stack for the inline-Router. + Group(fn func(r Router)) Router + + // Route mounts a sub-Router along a `pattern` string. + Route(pattern string, fn func(r Router)) Router + + // Mount attaches another http.Handler along ./pattern/* + Mount(pattern string, h http.Handler) + + // Handle and HandleFunc adds routes for `pattern` that matches + // all HTTP methods. + Handle(pattern string, h http.Handler) + HandleFunc(pattern string, h http.HandlerFunc) + + // Method and MethodFunc adds routes for `pattern` that matches + // the `method` HTTP method. 
+ Method(method, pattern string, h http.Handler) + MethodFunc(method, pattern string, h http.HandlerFunc) + + // HTTP-method routing along `pattern` + Connect(pattern string, h http.HandlerFunc) + Delete(pattern string, h http.HandlerFunc) + Get(pattern string, h http.HandlerFunc) + Head(pattern string, h http.HandlerFunc) + Options(pattern string, h http.HandlerFunc) + Patch(pattern string, h http.HandlerFunc) + Post(pattern string, h http.HandlerFunc) + Put(pattern string, h http.HandlerFunc) + Trace(pattern string, h http.HandlerFunc) + + // NotFound defines a handler to respond whenever a route could + // not be found. + NotFound(h http.HandlerFunc) + + // MethodNotAllowed defines a handler to respond whenever a method is + // not allowed. + MethodNotAllowed(h http.HandlerFunc) +} + +// Routes interface adds two methods for router traversal, which is also +// used by the `docgen` subpackage to generation documentation for Routers. +type Routes interface { + // Routes returns the routing tree in an easily traversable structure. + Routes() []Route + + // Middlewares returns the list of middlewares in use by the router. + Middlewares() Middlewares + + // Match searches the routing tree for a handler that matches + // the method/path - similar to routing a http request, but without + // executing the handler thereafter. + Match(rctx *Context, method, path string) bool + + // Find searches the routing tree for the pattern that matches + // the method/path. + Find(rctx *Context, method, path string) string +} + +// Middlewares type is a slice of standard middleware handlers with methods +// to compose middleware chains and http.Handler's. 
+type Middlewares []func(http.Handler) http.Handler diff --git a/testdata/chi/context.go b/testdata/chi/context.go new file mode 100644 index 0000000..8222073 --- /dev/null +++ b/testdata/chi/context.go @@ -0,0 +1,166 @@ +package chi + +import ( + "context" + "net/http" + "strings" +) + +// URLParam returns the url parameter from a http.Request object. +func URLParam(r *http.Request, key string) string { + if rctx := RouteContext(r.Context()); rctx != nil { + return rctx.URLParam(key) + } + return "" +} + +// URLParamFromCtx returns the url parameter from a http.Request Context. +func URLParamFromCtx(ctx context.Context, key string) string { + if rctx := RouteContext(ctx); rctx != nil { + return rctx.URLParam(key) + } + return "" +} + +// RouteContext returns chi's routing Context object from a +// http.Request Context. +func RouteContext(ctx context.Context) *Context { + val, _ := ctx.Value(RouteCtxKey).(*Context) + return val +} + +// NewRouteContext returns a new routing Context object. +func NewRouteContext() *Context { + return &Context{} +} + +var ( + // RouteCtxKey is the context.Context key to store the request context. + RouteCtxKey = &contextKey{"RouteContext"} +) + +// Context is the default routing context set on the root node of a +// request context to track route patterns, URL parameters and +// an optional routing path. +type Context struct { + Routes Routes + + // parentCtx is the parent of this one, for using Context as a + // context.Context directly. This is an optimization that saves + // 1 allocation. + parentCtx context.Context + + // Routing path/method override used during the route search. + // See Mux#routeHTTP method. + RoutePath string + RouteMethod string + + // URLParams are the stack of routeParams captured during the + // routing lifecycle across a stack of sub-routers. + URLParams RouteParams + + // Route parameters matched for the current sub-router. It is + // intentionally unexported so it can't be tampered. 
+ routeParams RouteParams + + // The endpoint routing pattern that matched the request URI path + // or `RoutePath` of the current sub-router. This value will update + // during the lifecycle of a request passing through a stack of + // sub-routers. + routePattern string + + // Routing pattern stack throughout the lifecycle of the request, + // across all connected routers. It is a record of all matching + // patterns across a stack of sub-routers. + RoutePatterns []string + + methodsAllowed []methodTyp // allowed methods in case of a 405 + methodNotAllowed bool +} + +// Reset a routing context to its initial state. +func (x *Context) Reset() { + x.Routes = nil + x.RoutePath = "" + x.RouteMethod = "" + x.RoutePatterns = x.RoutePatterns[:0] + x.URLParams.Keys = x.URLParams.Keys[:0] + x.URLParams.Values = x.URLParams.Values[:0] + + x.routePattern = "" + x.routeParams.Keys = x.routeParams.Keys[:0] + x.routeParams.Values = x.routeParams.Values[:0] + x.methodNotAllowed = false + x.methodsAllowed = x.methodsAllowed[:0] + x.parentCtx = nil +} + +// URLParam returns the corresponding URL parameter value from the request +// routing context. +func (x *Context) URLParam(key string) string { + for k := len(x.URLParams.Keys) - 1; k >= 0; k-- { + if x.URLParams.Keys[k] == key { + return x.URLParams.Values[k] + } + } + return "" +} + +// RoutePattern builds the routing pattern string for the particular +// request, at the particular point during routing. This means, the value +// will change throughout the execution of a request in a router. That is +// why it's advised to only use this value after calling the next handler. 
+// +// For example, +// +// func Instrument(next http.Handler) http.Handler { +// return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +// next.ServeHTTP(w, r) +// routePattern := chi.RouteContext(r.Context()).RoutePattern() +// measure(w, r, routePattern) +// }) +// } +func (x *Context) RoutePattern() string { + if x == nil { + return "" + } + routePattern := strings.Join(x.RoutePatterns, "") + routePattern = replaceWildcards(routePattern) + if routePattern != "/" { + routePattern = strings.TrimSuffix(routePattern, "//") + routePattern = strings.TrimSuffix(routePattern, "/") + } + return routePattern +} + +// replaceWildcards takes a route pattern and replaces all occurrences of +// "/*/" with "/". It iteratively runs until no wildcards remain to +// correctly handle consecutive wildcards. +func replaceWildcards(p string) string { + for strings.Contains(p, "/*/") { + p = strings.ReplaceAll(p, "/*/", "/") + } + return p +} + +// RouteParams is a structure to track URL routing parameters efficiently. +type RouteParams struct { + Keys, Values []string +} + +// Add will append a URL parameter to the end of the route param +func (s *RouteParams) Add(key, value string) { + s.Keys = append(s.Keys, key) + s.Values = append(s.Values, value) +} + +// contextKey is a value for use with context.WithValue. It's used as +// a pointer so it fits in an interface{} without allocation. This technique +// for defining context keys was copied from Go 1.7's new use of context in net/http. 
+type contextKey struct { + name string +} + +func (k *contextKey) String() string { + return "chi context value " + k.name +} diff --git a/testdata/chi/go.mod b/testdata/chi/go.mod new file mode 100644 index 0000000..4c47b09 --- /dev/null +++ b/testdata/chi/go.mod @@ -0,0 +1,3 @@ +module github.com/go-chi/chi/v5 + +go 1.23 diff --git a/testdata/chi/middleware/basic_auth.go b/testdata/chi/middleware/basic_auth.go new file mode 100644 index 0000000..a546c9e --- /dev/null +++ b/testdata/chi/middleware/basic_auth.go @@ -0,0 +1,33 @@ +package middleware + +import ( + "crypto/subtle" + "fmt" + "net/http" +) + +// BasicAuth implements a simple middleware handler for adding basic http auth to a route. +func BasicAuth(realm string, creds map[string]string) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + user, pass, ok := r.BasicAuth() + if !ok { + basicAuthFailed(w, realm) + return + } + + credPass, credUserOk := creds[user] + if !credUserOk || subtle.ConstantTimeCompare([]byte(pass), []byte(credPass)) != 1 { + basicAuthFailed(w, realm) + return + } + + next.ServeHTTP(w, r) + }) + } +} + +func basicAuthFailed(w http.ResponseWriter, realm string) { + w.Header().Add("WWW-Authenticate", fmt.Sprintf(`Basic realm="%s"`, realm)) + w.WriteHeader(http.StatusUnauthorized) +} diff --git a/testdata/chi/middleware/clean_path.go b/testdata/chi/middleware/clean_path.go new file mode 100644 index 0000000..adeba42 --- /dev/null +++ b/testdata/chi/middleware/clean_path.go @@ -0,0 +1,28 @@ +package middleware + +import ( + "net/http" + "path" + + "github.com/go-chi/chi/v5" +) + +// CleanPath middleware will clean out double slash mistakes from a user's request path. 
+// For example, if a user requests /users//1 or //users////1 will both be treated as: /users/1 +func CleanPath(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rctx := chi.RouteContext(r.Context()) + + routePath := rctx.RoutePath + if routePath == "" { + if r.URL.RawPath != "" { + routePath = r.URL.RawPath + } else { + routePath = r.URL.Path + } + rctx.RoutePath = path.Clean(routePath) + } + + next.ServeHTTP(w, r) + }) +} diff --git a/testdata/chi/middleware/client_ip.go b/testdata/chi/middleware/client_ip.go new file mode 100644 index 0000000..1495a86 --- /dev/null +++ b/testdata/chi/middleware/client_ip.go @@ -0,0 +1,263 @@ +package middleware + +import ( + "context" + "net" + "net/http" + "net/netip" + "strings" +) + +// clientIPCtxKey stores the client IP set by any of the ClientIPFrom* middlewares. +var clientIPCtxKey = &contextKey{"clientIP"} + +// xForwardedForHeader is the canonical form of the X-Forwarded-For header +// name, used by the XFF-based middlewares. +const xForwardedForHeader = "X-Forwarded-For" + +// ClientIPFromHeader stores the client IP from a single-IP header set by +// your reverse proxy. Read it with [GetClientIP]. +// +// Only safe with headers your proxy unconditionally OVERWRITES on every +// request, e.g.: +// +// - X-Real-IP — Nginx with ngx_http_realip_module +// - X-Client-IP — Apache with mod_remoteip +// - CF-Connecting-IP — Cloudflare +// +// True-Client-IP, X-Azure-ClientIP, and Fastly-Client-IP look similar but +// pass through from the client by default in those products; don't use them +// unless your edge strips the inbound value. +// +// If the header reaches us with multiple values (misconfigured proxy that +// appends, or a downstream proxy not stripping a client-supplied value), +// the LAST value wins — that's the one set by the hop closest to us, and +// therefore the most trusted. 
Fail-closed if the last value doesn't parse: +// no client IP is set rather than falling back to earlier (less-trusted) +// values. +// +// v4-mapped IPv6 (::ffff:a.b.c.d) folds to plain v4 and IPv6 zones are +// stripped before storage. +func ClientIPFromHeader(trustedHeader string) func(http.Handler) http.Handler { + header := http.CanonicalHeaderKey(trustedHeader) + return func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + values := r.Header.Values(header) + if len(values) > 0 { + if ip, ok := parseHeaderAddr(values[len(values)-1]); ok { + r = r.WithContext(context.WithValue(r.Context(), clientIPCtxKey, ip)) + } + } + h.ServeHTTP(w, r) + }) + } +} + +// ClientIPFromXFF stores the client IP read from the X-Forwarded-For header, +// walking the chain right-to-left and skipping any IP that falls within one +// of the given trusted CIDR prefixes. The first IP that is not trusted is +// the client. Read it with [GetClientIP]. +// +// An unparseable entry mid-chain aborts the walk and leaves no client IP +// set (fail-closed) — we can't safely trust anything left of garbage. +// +// Use this when you sit behind one or more reverse proxies whose IP ranges +// you can enumerate as CIDRs: +// +// r.Use(middleware.ClientIPFromXFF( +// "13.32.0.0/15", // CloudFront IPv4 +// "52.46.0.0/18", // CloudFront IPv4 +// "2600:9000::/28", // CloudFront IPv6 +// )) +// +// Calling with no arguments returns the rightmost XFF entry, or no IP if +// that entry doesn't parse (fail-closed) — safe only if you have exactly +// one trusted hop directly in front of this server (e.g., nginx on localhost). +// +// v4-mapped IPv6 (::ffff:a.b.c.d) folds to plain v4 and IPv6 zones are +// stripped before the prefix check and storage; otherwise an attacker +// could use either notation to alias a trusted IP past the check. +// +// If you know the number of trusted proxies but not their IPs, use +// [ClientIPFromXFFTrustedProxies] instead. 
+// +// Panics at startup if any prefix is invalid. +func ClientIPFromXFF(trustedIPPrefixes ...string) func(http.Handler) http.Handler { + prefixes := make([]netip.Prefix, len(trustedIPPrefixes)) + for i, p := range trustedIPPrefixes { + prefixes[i] = netip.MustParsePrefix(p) + } + return func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var found netip.Addr + walkXFF(r.Header[xForwardedForHeader], func(v string) bool { + ip, ok := parseHeaderAddr(v) + if !ok { + return true // fail-closed; leave found unset + } + if inAnyPrefix(ip, prefixes) { + return false // trusted hop; keep walking left + } + found = ip + return true + }) + if found.IsValid() { + r = r.WithContext(context.WithValue(r.Context(), clientIPCtxKey, found)) + } + h.ServeHTTP(w, r) + }) + } +} + +// ClientIPFromXFFTrustedProxies stores the client IP read from the +// X-Forwarded-For header, given the exact number of trusted reverse proxies +// between this server and the public internet. It returns the IP at position +// len(xff) - numTrustedProxies in the merged X-Forwarded-For list — the IP +// added by the outermost of your trusted proxies, the only IP in the chain +// that none of your proxies have allowed an attacker to forge. Read it with +// [GetClientIP]. +// +// Use this when: +// - You know exactly how many proxies you sit behind, AND +// - Their IP addresses are dynamic (autoscaling proxy pools, ephemeral +// containers, dynamic CDN edges) so listing CIDRs with [ClientIPFromXFF] +// is impractical. +// +// WARNING: This variant is brittle to network architecture changes. If you +// add or remove a proxy level, numTrustedProxies silently becomes wrong and +// you may start trusting an attacker-supplied IP. Prefer [ClientIPFromXFF] +// with explicit trusted CIDRs whenever you can. 
+// +// If the XFF chain has fewer than numTrustedProxies entries (header missing +// or architecture changed), no client IP is set and [GetClientIP] returns "". +// +// Like [ClientIPFromXFF], v4-mapped IPv6 folds to plain v4 and IPv6 zones +// are stripped before storage. +// +// Panics at startup if numTrustedProxies < 1. +func ClientIPFromXFFTrustedProxies(numTrustedProxies int) func(http.Handler) http.Handler { + if numTrustedProxies < 1 { + panic("middleware.ClientIPFromXFFTrustedProxies: numTrustedProxies must be >= 1") + } + return func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := numTrustedProxies + var entry string + walkXFF(r.Header[xForwardedForHeader], func(v string) bool { + n-- + if n == 0 { + entry = v + return true + } + return false + }) + if entry != "" { + if ip, ok := parseHeaderAddr(entry); ok { + r = r.WithContext(context.WithValue(r.Context(), clientIPCtxKey, ip)) + } + } + h.ServeHTTP(w, r) + }) + } +} + +// ClientIPFromRemoteAddr stores the client IP read from the TCP RemoteAddr +// of the incoming request — the IP address of whoever opened the connection +// to this server. Read it with [GetClientIP]. +// +// Use this when this server is directly connected to the public internet +// with NO reverse proxy in front of it. Behind a reverse proxy, RemoteAddr +// is the proxy's IP, not the client's — use [ClientIPFromHeader] or +// [ClientIPFromXFF] instead. +// +// IPv4 clients on a dual-stack listener surface as ::ffff:a.b.c.d; they +// fold to plain v4 before storage so one logical client maps to one key. +// IPv6 zones are preserved (link-local connections may legitimately have one). +func ClientIPFromRemoteAddr(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + host = r.RemoteAddr // RemoteAddr may already be a bare IP (e.g. in tests). 
+ } + if ip, err := netip.ParseAddr(host); err == nil { + r = r.WithContext(context.WithValue(r.Context(), clientIPCtxKey, ip.Unmap())) + } + h.ServeHTTP(w, r) + }) +} + +// GetClientIP returns the client IP as a string, as set by one of the +// ClientIPFrom* middlewares. Returns "" if no valid IP was set. +// Convenient for logging, rate-limit keys, etc. +func GetClientIP(ctx context.Context) string { + ip := GetClientIPAddr(ctx) + if !ip.IsValid() { + return "" + } + return ip.String() +} + +// GetClientIPAddr returns the client IP as a [netip.Addr], as set by one of +// the ClientIPFrom* middlewares. The returned Addr is the zero value if not +// set; use [netip.Addr.IsValid] to check. Useful when you need typed work — +// prefix containment, Is4/Is6, etc. — without re-parsing the string. +func GetClientIPAddr(ctx context.Context) netip.Addr { + ip, _ := ctx.Value(clientIPCtxKey).(netip.Addr) + return ip +} + +// walkXFF walks the entries of the merged X-Forwarded-For chain +// RIGHT-TO-LEFT, invoking visit on each trimmed non-empty entry. visit +// returns true to stop the walk. Lazy walk, zero allocations (entries +// are substrings of the input headers). +// +// Multiple XFF headers are merged per RFC 2616 — each header's +// comma-separated entries in order received — so an attacker cannot pick +// which value security logic sees by sending a duplicate header. +func walkXFF(headers []string, visit func(entry string) bool) { + for hi := len(headers) - 1; hi >= 0; hi-- { + h := headers[hi] + for h != "" { + var v string + if i := strings.LastIndexByte(h, ','); i >= 0 { + v, h = h[i+1:], h[:i] + } else { + v, h = h, "" + } + v = strings.TrimSpace(v) + if v == "" { + continue + } + if visit(v) { + return + } + } + } +} + +// inAnyPrefix reports whether ip falls within any of the given prefixes. 
+func inAnyPrefix(ip netip.Addr, prefixes []netip.Prefix) bool { + for _, p := range prefixes { + if p.Contains(ip) { + return true + } + } + return false +} + +// parseHeaderAddr parses s and normalizes for storage: v4-mapped IPv6 +// (::ffff:a.b.c.d) folds to plain v4, IPv6 zone is stripped. Both defend the +// trust-prefix check against attacker-injected aliases — [netip.Prefix.Contains] +// returns false for v4-mapped addresses vs v4 prefixes and for any zoned +// address, so without folding/stripping an attacker could escape an +// otherwise valid trust list. +// +// Header-sourced IPs only. [ClientIPFromRemoteAddr] normalizes inline +// (Unmap, but zone preserved for legitimate link-local connections). +func parseHeaderAddr(s string) (netip.Addr, bool) { + ip, err := netip.ParseAddr(s) + if err != nil { + return netip.Addr{}, false + } + return ip.Unmap().WithZone(""), true +} diff --git a/testdata/chi/middleware/compress.go b/testdata/chi/middleware/compress.go new file mode 100644 index 0000000..4e46f70 --- /dev/null +++ b/testdata/chi/middleware/compress.go @@ -0,0 +1,392 @@ +package middleware + +import ( + "bufio" + "compress/flate" + "compress/gzip" + "errors" + "fmt" + "io" + "net" + "net/http" + "strings" + "sync" +) + +var defaultCompressibleContentTypes = []string{ + "text/html", + "text/css", + "text/plain", + "text/javascript", + "application/javascript", + "application/x-javascript", + "application/json", + "application/atom+xml", + "application/rss+xml", + "image/svg+xml", +} + +// Compress is a middleware that compresses response +// body of a given content types to a data format based +// on Accept-Encoding request header. It uses a given +// compression level. +// +// NOTE: make sure to set the Content-Type header on your response +// otherwise this middleware will not compress the response body. For ex, in +// your handler you should set w.Header().Set("Content-Type", http.DetectContentType(yourBody)) +// or set it manually. 
+// +// Passing a compression level of 5 is sensible value +func Compress(level int, types ...string) func(next http.Handler) http.Handler { + compressor := NewCompressor(level, types...) + return compressor.Handler +} + +// Compressor represents a set of encoding configurations. +type Compressor struct { + // The mapping of encoder names to encoder functions. + encoders map[string]EncoderFunc + // The mapping of pooled encoders to pools. + pooledEncoders map[string]*sync.Pool + // The set of content types allowed to be compressed. + allowedTypes map[string]struct{} + allowedWildcards map[string]struct{} + // The list of encoders in order of decreasing precedence. + encodingPrecedence []string + level int // The compression level. +} + +// NewCompressor creates a new Compressor that will handle encoding responses. +// +// The level should be one of the ones defined in the flate package. +// The types are the content types that are allowed to be compressed. +func NewCompressor(level int, types ...string) *Compressor { + // If types are provided, set those as the allowed types. If none are + // provided, use the default list. + allowedTypes := make(map[string]struct{}) + allowedWildcards := make(map[string]struct{}) + if len(types) > 0 { + for _, t := range types { + if strings.Contains(strings.TrimSuffix(t, "/*"), "*") { + panic(fmt.Sprintf("middleware/compress: Unsupported content-type wildcard pattern '%s'. Only '/*' supported", t)) + } + if before, ok := strings.CutSuffix(t, "/*"); ok { + allowedWildcards[before] = struct{}{} + } else { + allowedTypes[t] = struct{}{} + } + } + } else { + for _, t := range defaultCompressibleContentTypes { + allowedTypes[t] = struct{}{} + } + } + + c := &Compressor{ + level: level, + encoders: make(map[string]EncoderFunc), + pooledEncoders: make(map[string]*sync.Pool), + allowedTypes: allowedTypes, + allowedWildcards: allowedWildcards, + } + + // Set the default encoders. 
The precedence order uses the reverse + // ordering that the encoders were added. This means adding new encoders + // will move them to the front of the order. + // + // TODO: + // lzma: Opera. + // sdch: Chrome, Android. Gzip output + dictionary header. + // br: Brotli, see https://github.com/go-chi/chi/pull/326 + + // HTTP 1.1 "deflate" (RFC 2616) stands for DEFLATE data (RFC 1951) + // wrapped with zlib (RFC 1950). The zlib wrapper uses Adler-32 + // checksum compared to CRC-32 used in "gzip" and thus is faster. + // + // But.. some old browsers (MSIE, Safari 5.1) incorrectly expect + // raw DEFLATE data only, without the mentioned zlib wrapper. + // Because of this major confusion, most modern browsers try it + // both ways, first looking for zlib headers. + // Quote by Mark Adler: http://stackoverflow.com/a/9186091/385548 + // + // The list of browsers having problems is quite big, see: + // http://zoompf.com/blog/2012/02/lose-the-wait-http-compression + // https://web.archive.org/web/20120321182910/http://www.vervestudios.co/projects/compression-tests/results + // + // That's why we prefer gzip over deflate. It's just more reliable + // and not significantly slower than deflate. + c.SetEncoder("deflate", encoderDeflate) + + // TODO: Exception for old MSIE browsers that can't handle non-HTML? + // https://zoompf.com/blog/2012/02/lose-the-wait-http-compression + c.SetEncoder("gzip", encoderGzip) + + // NOTE: Not implemented, intentionally: + // case "compress": // LZW. Deprecated. + // case "bzip2": // Too slow on-the-fly. + // case "zopfli": // Too slow on-the-fly. + // case "xz": // Too slow on-the-fly. + return c +} + +// SetEncoder can be used to set the implementation of a compression algorithm. +// +// The encoding should be a standardised identifier. 
See: +// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept-Encoding +// +// For example, add the Brotli algorithm: +// +// import brotli_enc "gopkg.in/kothar/brotli-go.v0/enc" +// +// compressor := middleware.NewCompressor(5, "text/html") +// compressor.SetEncoder("br", func(w io.Writer, level int) io.Writer { +// params := brotli_enc.NewBrotliParams() +// params.SetQuality(level) +// return brotli_enc.NewBrotliWriter(params, w) +// }) +func (c *Compressor) SetEncoder(encoding string, fn EncoderFunc) { + encoding = strings.ToLower(encoding) + if encoding == "" { + panic("the encoding can not be empty") + } + if fn == nil { + panic("attempted to set a nil encoder function") + } + + // If we are adding a new encoder that is already registered, we have to + // clear that one out first. + delete(c.pooledEncoders, encoding) + delete(c.encoders, encoding) + + // If the encoder supports Resetting (IoReseterWriter), then it can be pooled. + encoder := fn(io.Discard, c.level) + if _, ok := encoder.(ioResetterWriter); ok { + pool := &sync.Pool{ + New: func() interface{} { + return fn(io.Discard, c.level) + }, + } + c.pooledEncoders[encoding] = pool + } + // If the encoder is not in the pooledEncoders, add it to the normal encoders. + if _, ok := c.pooledEncoders[encoding]; !ok { + c.encoders[encoding] = fn + } + + for i, v := range c.encodingPrecedence { + if v == encoding { + c.encodingPrecedence = append(c.encodingPrecedence[:i], c.encodingPrecedence[i+1:]...) + } + } + + c.encodingPrecedence = append([]string{encoding}, c.encodingPrecedence...) +} + +// Handler returns a new middleware that will compress the response based on the +// current Compressor. 
+func (c *Compressor) Handler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + encoder, encoding, cleanup := c.selectEncoder(r.Header, w) + + cw := &compressResponseWriter{ + ResponseWriter: w, + w: w, + contentTypes: c.allowedTypes, + contentWildcards: c.allowedWildcards, + encoding: encoding, + compressible: false, // determined in post-handler + } + if encoder != nil { + cw.w = encoder + } + // Re-add the encoder to the pool if applicable. + defer cleanup() + defer cw.Close() + + next.ServeHTTP(cw, r) + }) +} + +// selectEncoder returns the encoder, the name of the encoder, and a closer function. +func (c *Compressor) selectEncoder(h http.Header, w io.Writer) (io.Writer, string, func()) { + header := h.Get("Accept-Encoding") + + // Parse the names of all accepted algorithms from the header. + accepted := strings.Split(strings.ToLower(header), ",") + + // Find supported encoder by accepted list by precedence + for _, name := range c.encodingPrecedence { + if matchAcceptEncoding(accepted, name) { + if pool, ok := c.pooledEncoders[name]; ok { + encoder := pool.Get().(ioResetterWriter) + cleanup := func() { + pool.Put(encoder) + } + encoder.Reset(w) + return encoder, name, cleanup + + } + if fn, ok := c.encoders[name]; ok { + return fn(w, c.level), name, func() {} + } + } + + } + + // No encoder found to match the accepted encoding + return nil, "", func() {} +} + +func matchAcceptEncoding(accepted []string, encoding string) bool { + for _, v := range accepted { + if strings.Contains(v, encoding) { + return true + } + } + return false +} + +// An EncoderFunc is a function that wraps the provided io.Writer with a +// streaming compression algorithm and returns it. +// +// In case of failure, the function should return nil. +type EncoderFunc func(w io.Writer, level int) io.Writer + +// Interface for types that allow resetting io.Writers. 
+type ioResetterWriter interface { + io.Writer + Reset(w io.Writer) +} + +type compressResponseWriter struct { + http.ResponseWriter + + // The streaming encoder writer to be used if there is one. Otherwise, + // this is just the normal writer. + w io.Writer + contentTypes map[string]struct{} + contentWildcards map[string]struct{} + encoding string + wroteHeader bool + compressible bool +} + +func (cw *compressResponseWriter) isCompressible() bool { + // Parse the first part of the Content-Type response header. + contentType := cw.Header().Get("Content-Type") + contentType, _, _ = strings.Cut(contentType, ";") + + // Is the content type compressible? + if _, ok := cw.contentTypes[contentType]; ok { + return true + } + if contentType, _, hadSlash := strings.Cut(contentType, "/"); hadSlash { + _, ok := cw.contentWildcards[contentType] + return ok + } + return false +} + +func (cw *compressResponseWriter) WriteHeader(code int) { + if cw.wroteHeader { + cw.ResponseWriter.WriteHeader(code) // Allow multiple calls to propagate. + return + } + cw.wroteHeader = true + defer cw.ResponseWriter.WriteHeader(code) + + // Already compressed data? 
+ if cw.Header().Get("Content-Encoding") != "" { + return + } + + if !cw.isCompressible() { + cw.compressible = false + return + } + + if cw.encoding != "" { + cw.compressible = true + cw.Header().Set("Content-Encoding", cw.encoding) + cw.Header().Add("Vary", "Accept-Encoding") + + // The content-length after compression is unknown + cw.Header().Del("Content-Length") + } +} + +func (cw *compressResponseWriter) Write(p []byte) (int, error) { + if !cw.wroteHeader { + cw.WriteHeader(http.StatusOK) + } + + return cw.writer().Write(p) +} + +func (cw *compressResponseWriter) writer() io.Writer { + if cw.compressible { + return cw.w + } + return cw.ResponseWriter +} + +type compressFlusher interface { + Flush() error +} + +func (cw *compressResponseWriter) Flush() { + if f, ok := cw.writer().(http.Flusher); ok { + f.Flush() + } + // If the underlying writer has a compression flush signature, + // call this Flush() method instead + if f, ok := cw.writer().(compressFlusher); ok { + f.Flush() + + // Also flush the underlying response writer + if f, ok := cw.ResponseWriter.(http.Flusher); ok { + f.Flush() + } + } +} + +func (cw *compressResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if hj, ok := cw.writer().(http.Hijacker); ok { + return hj.Hijack() + } + return nil, nil, errors.New("chi/middleware: http.Hijacker is unavailable on the writer") +} + +func (cw *compressResponseWriter) Push(target string, opts *http.PushOptions) error { + if ps, ok := cw.writer().(http.Pusher); ok { + return ps.Push(target, opts) + } + return errors.New("chi/middleware: http.Pusher is unavailable on the writer") +} + +func (cw *compressResponseWriter) Close() error { + if c, ok := cw.writer().(io.WriteCloser); ok { + return c.Close() + } + return errors.New("chi/middleware: io.WriteCloser is unavailable on the writer") +} + +func (cw *compressResponseWriter) Unwrap() http.ResponseWriter { + return cw.ResponseWriter +} + +func encoderGzip(w io.Writer, level int) io.Writer { + 
gw, err := gzip.NewWriterLevel(w, level) + if err != nil { + return nil + } + return gw +} + +func encoderDeflate(w io.Writer, level int) io.Writer { + dw, err := flate.NewWriter(w, level) + if err != nil { + return nil + } + return dw +} diff --git a/testdata/chi/middleware/content_charset.go b/testdata/chi/middleware/content_charset.go new file mode 100644 index 0000000..8e75fe8 --- /dev/null +++ b/testdata/chi/middleware/content_charset.go @@ -0,0 +1,45 @@ +package middleware + +import ( + "net/http" + "slices" + "strings" +) + +// ContentCharset generates a handler that writes a 415 Unsupported Media Type response if none of the charsets match. +// An empty charset will allow requests with no Content-Type header or no specified charset. +func ContentCharset(charsets ...string) func(next http.Handler) http.Handler { + for i, c := range charsets { + charsets[i] = strings.ToLower(c) + } + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !contentEncoding(r.Header.Get("Content-Type"), charsets...) { + w.WriteHeader(http.StatusUnsupportedMediaType) + return + } + + next.ServeHTTP(w, r) + }) + } +} + +// Check the content encoding against a list of acceptable values. +func contentEncoding(ce string, charsets ...string) bool { + _, ce = split(strings.ToLower(ce), ";") + _, ce = split(ce, "charset=") + ce, _ = split(ce, ";") + return slices.Contains(charsets, ce) +} + +// Split a string in two parts, cleaning any whitespace. 
+func split(str, sep string) (string, string) { + a, b, found := strings.Cut(str, sep) + a = strings.TrimSpace(a) + if found { + b = strings.TrimSpace(b) + } + + return a, b +} diff --git a/testdata/chi/middleware/content_encoding.go b/testdata/chi/middleware/content_encoding.go new file mode 100644 index 0000000..e0b9ccc --- /dev/null +++ b/testdata/chi/middleware/content_encoding.go @@ -0,0 +1,34 @@ +package middleware + +import ( + "net/http" + "strings" +) + +// AllowContentEncoding enforces a whitelist of request Content-Encoding otherwise responds +// with a 415 Unsupported Media Type status. +func AllowContentEncoding(contentEncoding ...string) func(next http.Handler) http.Handler { + allowedEncodings := make(map[string]struct{}, len(contentEncoding)) + for _, encoding := range contentEncoding { + allowedEncodings[strings.TrimSpace(strings.ToLower(encoding))] = struct{}{} + } + return func(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + requestEncodings := r.Header["Content-Encoding"] + // skip check for empty content body or no Content-Encoding + if r.ContentLength == 0 { + next.ServeHTTP(w, r) + return + } + // All encodings in the request must be allowed + for _, encoding := range requestEncodings { + if _, ok := allowedEncodings[strings.TrimSpace(strings.ToLower(encoding))]; !ok { + w.WriteHeader(http.StatusUnsupportedMediaType) + return + } + } + next.ServeHTTP(w, r) + } + return http.HandlerFunc(fn) + } +} diff --git a/testdata/chi/middleware/content_type.go b/testdata/chi/middleware/content_type.go new file mode 100644 index 0000000..cdfc21e --- /dev/null +++ b/testdata/chi/middleware/content_type.go @@ -0,0 +1,45 @@ +package middleware + +import ( + "net/http" + "strings" +) + +// SetHeader is a convenience handler to set a response header key/value +func SetHeader(key, value string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w 
http.ResponseWriter, r *http.Request) { + w.Header().Set(key, value) + next.ServeHTTP(w, r) + }) + } +} + +// AllowContentType enforces a whitelist of request Content-Types otherwise responds +// with a 415 Unsupported Media Type status. +func AllowContentType(contentTypes ...string) func(http.Handler) http.Handler { + allowedContentTypes := make(map[string]struct{}, len(contentTypes)) + for _, ctype := range contentTypes { + allowedContentTypes[strings.TrimSpace(strings.ToLower(ctype))] = struct{}{} + } + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.ContentLength == 0 { + // Skip check for empty content body + next.ServeHTTP(w, r) + return + } + + s, _, _ := strings.Cut(r.Header.Get("Content-Type"), ";") + s = strings.ToLower(strings.TrimSpace(s)) + + if _, ok := allowedContentTypes[s]; ok { + next.ServeHTTP(w, r) + return + } + + w.WriteHeader(http.StatusUnsupportedMediaType) + }) + } +} diff --git a/testdata/chi/middleware/get_head.go b/testdata/chi/middleware/get_head.go new file mode 100644 index 0000000..d4606d8 --- /dev/null +++ b/testdata/chi/middleware/get_head.go @@ -0,0 +1,39 @@ +package middleware + +import ( + "net/http" + + "github.com/go-chi/chi/v5" +) + +// GetHead automatically route undefined HEAD requests to GET handlers. +func GetHead(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "HEAD" { + rctx := chi.RouteContext(r.Context()) + routePath := rctx.RoutePath + if routePath == "" { + if r.URL.RawPath != "" { + routePath = r.URL.RawPath + } else { + routePath = r.URL.Path + } + } + + // Temporary routing context to look-ahead before routing the request + tctx := chi.NewRouteContext() + + // Attempt to find a HEAD handler for the routing path, if not found, traverse + // the router as through its a GET route, but proceed with the request + // with the HEAD method. 
+ if !rctx.Routes.Match(tctx, "HEAD", routePath) { + rctx.RouteMethod = "GET" + rctx.RoutePath = routePath + next.ServeHTTP(w, r) + return + } + } + + next.ServeHTTP(w, r) + }) +} diff --git a/testdata/chi/middleware/heartbeat.go b/testdata/chi/middleware/heartbeat.go new file mode 100644 index 0000000..f36e8cc --- /dev/null +++ b/testdata/chi/middleware/heartbeat.go @@ -0,0 +1,26 @@ +package middleware + +import ( + "net/http" + "strings" +) + +// Heartbeat endpoint middleware useful to setting up a path like +// `/ping` that load balancers or uptime testing external services +// can make a request before hitting any routes. It's also convenient +// to place this above ACL middlewares as well. +func Heartbeat(endpoint string) func(http.Handler) http.Handler { + f := func(h http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + if (r.Method == "GET" || r.Method == "HEAD") && strings.EqualFold(r.URL.Path, endpoint) { + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusOK) + w.Write([]byte(".")) + return + } + h.ServeHTTP(w, r) + } + return http.HandlerFunc(fn) + } + return f +} diff --git a/testdata/chi/middleware/logger.go b/testdata/chi/middleware/logger.go new file mode 100644 index 0000000..4d30a9a --- /dev/null +++ b/testdata/chi/middleware/logger.go @@ -0,0 +1,178 @@ +package middleware + +import ( + "bytes" + "context" + "log" + "net/http" + "os" + "runtime" + "time" +) + +var ( + // LogEntryCtxKey is the context.Context key to store the request log entry. + LogEntryCtxKey = &contextKey{"LogEntry"} + + // DefaultLogger is called by the Logger middleware handler to log each request. + // Its made a package-level variable so that it can be reconfigured for custom + // logging configurations. 
+ DefaultLogger func(next http.Handler) http.Handler +) + +// Logger is a middleware that logs the start and end of each request, along +// with some useful data about what was requested, what the response status was, +// and how long it took to return. When standard output is a TTY, Logger will +// print in color, otherwise it will print in black and white. Logger prints a +// request ID if one is provided. +// +// Alternatively, look at https://github.com/goware/httplog for a more in-depth +// http logger with structured logging support. +// +// IMPORTANT NOTE: Logger should go before any other middleware that may change +// the response, such as middleware.Recoverer. Example: +// +// r := chi.NewRouter() +// r.Use(middleware.Logger) // <--<< Logger should come before Recoverer +// r.Use(middleware.Recoverer) +// r.Get("/", handler) +func Logger(next http.Handler) http.Handler { + return DefaultLogger(next) +} + +// RequestLogger returns a logger handler using a custom LogFormatter. +func RequestLogger(f LogFormatter) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + entry := f.NewLogEntry(r) + ww := NewWrapResponseWriter(w, r.ProtoMajor) + + t1 := time.Now() + defer func() { + entry.Write(ww.Status(), ww.BytesWritten(), ww.Header(), time.Since(t1), nil) + }() + + next.ServeHTTP(ww, WithLogEntry(r, entry)) + } + return http.HandlerFunc(fn) + } +} + +// LogFormatter initiates the beginning of a new LogEntry per request. +// See DefaultLogFormatter for an example implementation. +type LogFormatter interface { + NewLogEntry(r *http.Request) LogEntry +} + +// LogEntry records the final log when a request completes. +// See defaultLogEntry for an example implementation. 
+type LogEntry interface { + Write(status, bytes int, header http.Header, elapsed time.Duration, extra interface{}) + Panic(v interface{}, stack []byte) +} + +// GetLogEntry returns the in-context LogEntry for a request. +func GetLogEntry(r *http.Request) LogEntry { + entry, _ := r.Context().Value(LogEntryCtxKey).(LogEntry) + return entry +} + +// WithLogEntry sets the in-context LogEntry for a request. +func WithLogEntry(r *http.Request, entry LogEntry) *http.Request { + r = r.WithContext(context.WithValue(r.Context(), LogEntryCtxKey, entry)) + return r +} + +// LoggerInterface accepts printing to stdlib logger or compatible logger. +type LoggerInterface interface { + Print(v ...interface{}) +} + +// DefaultLogFormatter is a simple logger that implements a LogFormatter. +type DefaultLogFormatter struct { + Logger LoggerInterface + NoColor bool +} + +// NewLogEntry creates a new LogEntry for the request. +func (l *DefaultLogFormatter) NewLogEntry(r *http.Request) LogEntry { + ctx := r.Context() + + useColor := !l.NoColor + entry := &defaultLogEntry{ + DefaultLogFormatter: l, + request: r, + buf: &bytes.Buffer{}, + useColor: useColor, + } + + reqID := GetReqID(ctx) + if reqID != "" { + cW(entry.buf, useColor, nYellow, "[%s] ", reqID) + } + cW(entry.buf, useColor, nCyan, "\"") + cW(entry.buf, useColor, bMagenta, "%s ", r.Method) + + scheme := "http" + if r.TLS != nil { + scheme = "https" + } + cW(entry.buf, useColor, nCyan, "%s://%s%s %s\" ", scheme, r.Host, r.RequestURI, r.Proto) + + entry.buf.WriteString("from ") + clientIP := GetClientIP(ctx) + if clientIP == "" { + clientIP = r.RemoteAddr + } + entry.buf.WriteString(clientIP) + entry.buf.WriteString(" - ") + + return entry +} + +type defaultLogEntry struct { + *DefaultLogFormatter + request *http.Request + buf *bytes.Buffer + useColor bool +} + +func (l *defaultLogEntry) Write(status, bytes int, header http.Header, elapsed time.Duration, extra interface{}) { + switch { + case status < 200: + cW(l.buf, l.useColor, 
bBlue, "%03d", status) + case status < 300: + cW(l.buf, l.useColor, bGreen, "%03d", status) + case status < 400: + cW(l.buf, l.useColor, bCyan, "%03d", status) + case status < 500: + cW(l.buf, l.useColor, bYellow, "%03d", status) + default: + cW(l.buf, l.useColor, bRed, "%03d", status) + } + + cW(l.buf, l.useColor, bBlue, " %dB", bytes) + + l.buf.WriteString(" in ") + if elapsed < 500*time.Millisecond { + cW(l.buf, l.useColor, nGreen, "%s", elapsed) + } else if elapsed < 5*time.Second { + cW(l.buf, l.useColor, nYellow, "%s", elapsed) + } else { + cW(l.buf, l.useColor, nRed, "%s", elapsed) + } + + l.Logger.Print(l.buf.String()) +} + +func (l *defaultLogEntry) Panic(v interface{}, stack []byte) { + PrintPrettyStack(v) +} + +func init() { + color := true + if runtime.GOOS == "windows" { + color = false + } + DefaultLogger = RequestLogger(&DefaultLogFormatter{Logger: log.New(os.Stdout, "", log.LstdFlags), NoColor: !color}) +} diff --git a/testdata/chi/middleware/maybe.go b/testdata/chi/middleware/maybe.go new file mode 100644 index 0000000..eabca00 --- /dev/null +++ b/testdata/chi/middleware/maybe.go @@ -0,0 +1,18 @@ +package middleware + +import "net/http" + +// Maybe middleware will allow you to change the flow of the middleware stack execution depending on return +// value of maybeFn(request). This is useful for example if you'd like to skip a middleware handler if +// a request does not satisfy the maybeFn logic. 
+func Maybe(mw func(http.Handler) http.Handler, maybeFn func(r *http.Request) bool) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if maybeFn(r) { + mw(next).ServeHTTP(w, r) + } else { + next.ServeHTTP(w, r) + } + }) + } +} diff --git a/testdata/chi/middleware/middleware.go b/testdata/chi/middleware/middleware.go new file mode 100644 index 0000000..cc371e0 --- /dev/null +++ b/testdata/chi/middleware/middleware.go @@ -0,0 +1,23 @@ +package middleware + +import "net/http" + +// New will create a new middleware handler from a http.Handler. +func New(h http.Handler) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h.ServeHTTP(w, r) + }) + } +} + +// contextKey is a value for use with context.WithValue. It's used as +// a pointer so it fits in an interface{} without allocation. This technique +// for defining context keys was copied from Go 1.7's new use of context in net/http. 
+type contextKey struct { + name string +} + +func (k *contextKey) String() string { + return "chi/middleware context value " + k.name +} diff --git a/testdata/chi/middleware/nocache.go b/testdata/chi/middleware/nocache.go new file mode 100644 index 0000000..9308d40 --- /dev/null +++ b/testdata/chi/middleware/nocache.go @@ -0,0 +1,59 @@ +package middleware + +// Ported from Goji's middleware, source: +// https://github.com/zenazn/goji/tree/master/web/middleware + +import ( + "net/http" + "time" +) + +// Unix epoch time +var epoch = time.Unix(0, 0).UTC().Format(http.TimeFormat) + +// Taken from https://github.com/mytrile/nocache +var noCacheHeaders = map[string]string{ + "Expires": epoch, + "Cache-Control": "no-cache, no-store, no-transform, must-revalidate, private, max-age=0", + "Pragma": "no-cache", + "X-Accel-Expires": "0", +} + +var etagHeaders = []string{ + "ETag", + "If-Modified-Since", + "If-Match", + "If-None-Match", + "If-Range", + "If-Unmodified-Since", +} + +// NoCache is a simple piece of middleware that sets a number of HTTP headers to prevent +// a router (or subrouter) from being cached by an upstream proxy and/or client. 
+// +// As per http://wiki.nginx.org/HttpProxyModule - NoCache sets: +// +// Expires: Thu, 01 Jan 1970 00:00:00 UTC +// Cache-Control: no-cache, private, max-age=0 +// X-Accel-Expires: 0 +// Pragma: no-cache (for HTTP/1.0 proxies/clients) +func NoCache(h http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + + // Delete any ETag headers that may have been set + for _, v := range etagHeaders { + if r.Header.Get(v) != "" { + r.Header.Del(v) + } + } + + // Set our NoCache headers + for k, v := range noCacheHeaders { + w.Header().Set(k, v) + } + + h.ServeHTTP(w, r) + } + + return http.HandlerFunc(fn) +} diff --git a/testdata/chi/middleware/page_route.go b/testdata/chi/middleware/page_route.go new file mode 100644 index 0000000..32871b7 --- /dev/null +++ b/testdata/chi/middleware/page_route.go @@ -0,0 +1,20 @@ +package middleware + +import ( + "net/http" + "strings" +) + +// PageRoute is a simple middleware which allows you to route a static GET request +// at the middleware stack level. +func PageRoute(path string, handler http.Handler) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "GET" && strings.EqualFold(r.URL.Path, path) { + handler.ServeHTTP(w, r) + return + } + next.ServeHTTP(w, r) + }) + } +} diff --git a/testdata/chi/middleware/path_rewrite.go b/testdata/chi/middleware/path_rewrite.go new file mode 100644 index 0000000..99af62c --- /dev/null +++ b/testdata/chi/middleware/path_rewrite.go @@ -0,0 +1,16 @@ +package middleware + +import ( + "net/http" + "strings" +) + +// PathRewrite is a simple middleware which allows you to rewrite the request URL path. 
+func PathRewrite(old, new string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.URL.Path = strings.Replace(r.URL.Path, old, new, 1) + next.ServeHTTP(w, r) + }) + } +} diff --git a/testdata/chi/middleware/profiler.go b/testdata/chi/middleware/profiler.go new file mode 100644 index 0000000..0ad6a99 --- /dev/null +++ b/testdata/chi/middleware/profiler.go @@ -0,0 +1,49 @@ +//go:build !tinygo +// +build !tinygo + +package middleware + +import ( + "expvar" + "net/http" + "net/http/pprof" + + "github.com/go-chi/chi/v5" +) + +// Profiler is a convenient subrouter used for mounting net/http/pprof. ie. +// +// func MyService() http.Handler { +// r := chi.NewRouter() +// // ..middlewares +// r.Mount("/debug", middleware.Profiler()) +// // ..routes +// return r +// } +func Profiler() http.Handler { + r := chi.NewRouter() + r.Use(NoCache) + + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, r.RequestURI+"/pprof/", http.StatusMovedPermanently) + }) + r.HandleFunc("/pprof", func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, r.RequestURI+"/", http.StatusMovedPermanently) + }) + + r.HandleFunc("/pprof/*", pprof.Index) + r.HandleFunc("/pprof/cmdline", pprof.Cmdline) + r.HandleFunc("/pprof/profile", pprof.Profile) + r.HandleFunc("/pprof/symbol", pprof.Symbol) + r.HandleFunc("/pprof/trace", pprof.Trace) + r.Handle("/vars", expvar.Handler()) + + r.Handle("/pprof/goroutine", pprof.Handler("goroutine")) + r.Handle("/pprof/threadcreate", pprof.Handler("threadcreate")) + r.Handle("/pprof/mutex", pprof.Handler("mutex")) + r.Handle("/pprof/heap", pprof.Handler("heap")) + r.Handle("/pprof/block", pprof.Handler("block")) + r.Handle("/pprof/allocs", pprof.Handler("allocs")) + + return r +} diff --git a/testdata/chi/middleware/realip.go b/testdata/chi/middleware/realip.go new file mode 100644 index 0000000..349f168 --- /dev/null +++ 
// Canonical names of the request headers consulted by realIP.
var (
	trueClientIP  = http.CanonicalHeaderKey("True-Client-IP")
	xForwardedFor = http.CanonicalHeaderKey("X-Forwarded-For")
	xRealIP       = http.CanonicalHeaderKey("X-Real-IP")
)

// RealIP is a middleware that sets a http.Request's RemoteAddr to the results
// of parsing either the True-Client-IP, X-Real-IP or the X-Forwarded-For headers
// (in that order).
//
// Deprecated: RealIP is vulnerable to IP spoofing — it mutates r.RemoteAddr
// to the leftmost X-Forwarded-For value, or to True-Client-IP / X-Real-IP
// whether or not your infrastructure actually sets them. See
// GHSA-3fxj-6jh8-hvhx, GHSA-rjr7-jggh-pgcp, GHSA-9g5q-2w5x-hmxf.
//
// Use [ClientIPFromHeader], [ClientIPFromXFF], [ClientIPFromXFFTrustedProxies]
// or [ClientIPFromRemoteAddr] and read the IP with [GetClientIP] instead.
// These never mutate r.RemoteAddr.
func RealIP(h http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Leave RemoteAddr untouched when no usable address was supplied.
		if rip := realIP(r); rip != "" {
			r.RemoteAddr = rip
		}
		h.ServeHTTP(w, r)
	})
}

// realIP returns the client address announced by the request headers, or the
// empty string when none of them carries a parseable IP address.
//
// Only the first non-empty header in the order True-Client-IP, X-Real-IP,
// X-Forwarded-For is considered; if that header holds an invalid address the
// later headers are NOT consulted. For X-Forwarded-For only the leftmost
// entry is used.
func realIP(r *http.Request) string {
	var ip string

	if tcip := r.Header.Get(trueClientIP); tcip != "" {
		ip = tcip
	} else if xrip := r.Header.Get(xRealIP); xrip != "" {
		ip = xrip
	} else if xff := r.Header.Get(xForwardedFor); xff != "" {
		ip, _, _ = strings.Cut(xff, ",")
	}

	// List entries are commonly written with optional whitespace around the
	// comma ("a , b"); net.ParseIP rejects any surrounding space.
	ip = strings.TrimSpace(ip)
	if ip == "" || net.ParseIP(ip) == nil {
		return ""
	}
	return ip
}
+func Recoverer(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + defer func() { + if rvr := recover(); rvr != nil { + if rvr == http.ErrAbortHandler { + // we don't recover http.ErrAbortHandler so the response + // to the client is aborted, this should not be logged + panic(rvr) + } + + logEntry := GetLogEntry(r) + if logEntry != nil { + logEntry.Panic(rvr, debug.Stack()) + } else { + PrintPrettyStack(rvr) + } + + if r.Header.Get("Connection") != "Upgrade" { + w.WriteHeader(http.StatusInternalServerError) + } + } + }() + + next.ServeHTTP(w, r) + } + + return http.HandlerFunc(fn) +} + +// for ability to test the PrintPrettyStack function +var recovererErrorWriter io.Writer = os.Stderr + +func PrintPrettyStack(rvr interface{}) { + debugStack := debug.Stack() + s := prettyStack{} + out, err := s.parse(debugStack, rvr) + if err == nil { + recovererErrorWriter.Write(out) + } else { + // print stdlib output as a fallback + os.Stderr.Write(debugStack) + } +} + +type prettyStack struct { +} + +func (s prettyStack) parse(debugStack []byte, rvr interface{}) ([]byte, error) { + var err error + useColor := true + buf := &bytes.Buffer{} + + cW(buf, false, bRed, "\n") + cW(buf, useColor, bCyan, " panic: ") + cW(buf, useColor, bBlue, "%v", rvr) + cW(buf, false, bWhite, "\n \n") + + // process debug stack info + stack := strings.Split(string(debugStack), "\n") + lines := []string{} + + // locate panic line, as we may have nested panics + for i := len(stack) - 1; i > 0; i-- { + lines = append(lines, stack[i]) + if strings.HasPrefix(stack[i], "panic(") { + lines = lines[0 : len(lines)-2] // remove boilerplate + break + } + } + + // reverse + for i := len(lines)/2 - 1; i >= 0; i-- { + opp := len(lines) - 1 - i + lines[i], lines[opp] = lines[opp], lines[i] + } + + // decorate + for i, line := range lines { + lines[i], err = s.decorateLine(line, useColor, i) + if err != nil { + return nil, err + } + } + + for _, l := range lines { + 
fmt.Fprintf(buf, "%s", l) + } + return buf.Bytes(), nil +} + +func (s prettyStack) decorateLine(line string, useColor bool, num int) (string, error) { + line = strings.TrimSpace(line) + if strings.HasPrefix(line, "\t") || strings.Contains(line, ".go:") { + return s.decorateSourceLine(line, useColor, num) + } + if strings.HasSuffix(line, ")") { + return s.decorateFuncCallLine(line, useColor, num) + } + if strings.HasPrefix(line, "\t") { + return strings.Replace(line, "\t", " ", 1), nil + } + return fmt.Sprintf(" %s\n", line), nil +} + +func (s prettyStack) decorateFuncCallLine(line string, useColor bool, num int) (string, error) { + idx := strings.LastIndex(line, "(") + if idx < 0 { + return "", errors.New("not a func call line") + } + + buf := &bytes.Buffer{} + pkg := line[0:idx] + // addr := line[idx:] + method := "" + + if idx := strings.LastIndex(pkg, string(os.PathSeparator)); idx < 0 { + if idx := strings.Index(pkg, "."); idx > 0 { + method = pkg[idx:] + pkg = pkg[0:idx] + } + } else { + method = pkg[idx+1:] + pkg = pkg[0 : idx+1] + if idx := strings.Index(method, "."); idx > 0 { + pkg += method[0:idx] + method = method[idx:] + } + } + pkgColor := nYellow + methodColor := bGreen + + if num == 0 { + cW(buf, useColor, bRed, " -> ") + pkgColor = bMagenta + methodColor = bRed + } else { + cW(buf, useColor, bWhite, " ") + } + cW(buf, useColor, pkgColor, "%s", pkg) + cW(buf, useColor, methodColor, "%s\n", method) + // cW(buf, useColor, nBlack, "%s", addr) + return buf.String(), nil +} + +func (s prettyStack) decorateSourceLine(line string, useColor bool, num int) (string, error) { + idx := strings.LastIndex(line, ".go:") + if idx < 0 { + return "", errors.New("not a source line") + } + + buf := &bytes.Buffer{} + path := line[0 : idx+3] + lineno := line[idx+3:] + + idx = strings.LastIndex(path, string(os.PathSeparator)) + dir := path[0 : idx+1] + file := path[idx+1:] + + idx = strings.Index(lineno, " ") + if idx > 0 { + lineno = lineno[0:idx] + } + fileColor := bCyan 
+ lineColor := bGreen + + if num == 1 { + cW(buf, useColor, bRed, " -> ") + fileColor = bRed + lineColor = bMagenta + } else { + cW(buf, false, bWhite, " ") + } + cW(buf, useColor, bWhite, "%s", dir) + cW(buf, useColor, fileColor, "%s", file) + cW(buf, useColor, lineColor, "%s", lineno) + if num == 1 { + cW(buf, false, bWhite, "\n") + } + cW(buf, false, bWhite, "\n") + + return buf.String(), nil +} diff --git a/testdata/chi/middleware/request_id.go b/testdata/chi/middleware/request_id.go new file mode 100644 index 0000000..e1d4ccb --- /dev/null +++ b/testdata/chi/middleware/request_id.go @@ -0,0 +1,96 @@ +package middleware + +// Ported from Goji's middleware, source: +// https://github.com/zenazn/goji/tree/master/web/middleware + +import ( + "context" + "crypto/rand" + "encoding/base64" + "fmt" + "net/http" + "os" + "strings" + "sync/atomic" +) + +// Key to use when setting the request ID. +type ctxKeyRequestID int + +// RequestIDKey is the key that holds the unique request ID in a request context. +const RequestIDKey ctxKeyRequestID = 0 + +// RequestIDHeader is the name of the HTTP Header which contains the request id. +// Exported so that it can be changed by developers +var RequestIDHeader = "X-Request-Id" + +var prefix string +var reqid atomic.Uint64 + +// A quick note on the statistics here: we're trying to calculate the chance that +// two randomly generated base62 prefixes will collide. 
We use the formula from +// http://en.wikipedia.org/wiki/Birthday_problem +// +// P[m, n] \approx 1 - e^{-m^2/2n} +// +// We ballpark an upper bound for $m$ by imagining (for whatever reason) a server +// that restarts every second over 10 years, for $m = 86400 * 365 * 10 = 315360000$ +// +// For a $k$ character base-62 identifier, we have $n(k) = 62^k$ +// +// Plugging this in, we find $P[m, n(10)] \approx 5.75%$, which is good enough for +// our purposes, and is surely more than anyone would ever need in practice -- a +// process that is rebooted a handful of times a day for a hundred years has less +// than a millionth of a percent chance of generating two colliding IDs. + +func init() { + hostname, err := os.Hostname() + if hostname == "" || err != nil { + hostname = "localhost" + } + var buf [12]byte + var b64 string + for len(b64) < 10 { + rand.Read(buf[:]) + b64 = base64.StdEncoding.EncodeToString(buf[:]) + b64 = strings.NewReplacer("+", "", "/", "").Replace(b64) + } + + prefix = fmt.Sprintf("%s/%s", hostname, b64[0:10]) +} + +// RequestID is a middleware that injects a request ID into the context of each +// request. A request ID is a string of the form "host.example.com/random-0001", +// where "random" is a base62 random string that uniquely identifies this go +// process, and where the last number is an atomically incremented request +// counter. +func RequestID(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + requestID := r.Header.Get(RequestIDHeader) + if requestID == "" { + myid := reqid.Add(1) + requestID = fmt.Sprintf("%s-%06d", prefix, myid) + } + ctx = context.WithValue(ctx, RequestIDKey, requestID) + next.ServeHTTP(w, r.WithContext(ctx)) + } + return http.HandlerFunc(fn) +} + +// GetReqID returns a request ID from the given context if one is present. +// Returns the empty string if a request ID cannot be found. 
+func GetReqID(ctx context.Context) string { + if ctx == nil { + return "" + } + if reqID, ok := ctx.Value(RequestIDKey).(string); ok { + return reqID + } + return "" +} + +// NextRequestID generates the next request ID in the sequence. +func NextRequestID() uint64 { + return reqid.Add(1) +} diff --git a/testdata/chi/middleware/request_size.go b/testdata/chi/middleware/request_size.go new file mode 100644 index 0000000..678248c --- /dev/null +++ b/testdata/chi/middleware/request_size.go @@ -0,0 +1,18 @@ +package middleware + +import ( + "net/http" +) + +// RequestSize is a middleware that will limit request sizes to a specified +// number of bytes. It uses MaxBytesReader to do so. +func RequestSize(bytes int64) func(http.Handler) http.Handler { + f := func(h http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, bytes) + h.ServeHTTP(w, r) + } + return http.HandlerFunc(fn) + } + return f +} diff --git a/testdata/chi/middleware/route_headers.go b/testdata/chi/middleware/route_headers.go new file mode 100644 index 0000000..1c3334d --- /dev/null +++ b/testdata/chi/middleware/route_headers.go @@ -0,0 +1,146 @@ +package middleware + +import ( + "net/http" + "strings" +) + +// RouteHeaders is a neat little header-based router that allows you to direct +// the flow of a request through a middleware stack based on a request header. +// +// For example, lets say you'd like to setup multiple routers depending on the +// request Host header, you could then do something as so: +// +// r := chi.NewRouter() +// rSubdomain := chi.NewRouter() +// r.Use(middleware.RouteHeaders(). +// Route("Host", "example.com", middleware.New(r)). +// Route("Host", "*.example.com", middleware.New(rSubdomain)). 
// RouteHeaders is a neat little header-based router that allows you to direct
// the flow of a request through a middleware stack based on a request header.
//
// For example, lets say you'd like to setup multiple routers depending on the
// request Host header, you could then do something as so:
//
//	r := chi.NewRouter()
//	rSubdomain := chi.NewRouter()
//	r.Use(middleware.RouteHeaders().
//		Route("Host", "example.com", middleware.New(r)).
//		Route("Host", "*.example.com", middleware.New(rSubdomain)).
//		Handler)
//	r.Get("/", h)
//	rSubdomain.Get("/", h2)
//
// Another example, imagine you want to setup multiple CORS handlers, where for
// your origin servers you allow authorized requests, but for third-party public
// requests, authorization is disabled.
//
//	r := chi.NewRouter()
//	r.Use(middleware.RouteHeaders().
//		Route("Origin", "https://app.skyweaver.net", cors.Handler(cors.Options{
//			AllowedOrigins:   []string{"https://api.skyweaver.net"},
//			AllowedMethods:   []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
//			AllowedHeaders:   []string{"Accept", "Authorization", "Content-Type"},
//			AllowCredentials: true, // <----------<<< allow credentials
//		})).
//		Route("Origin", "*", cors.Handler(cors.Options{
//			AllowedOrigins:   []string{"*"},
//			AllowedMethods:   []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
//			AllowedHeaders:   []string{"Accept", "Content-Type"},
//			AllowCredentials: false, // <----------<<< do not allow credentials
//		})).
//		Handler)
func RouteHeaders() HeaderRouter {
	return HeaderRouter{}
}

// HeaderRouter maps a lower-cased header name to the routes registered for
// that header. The special key "*" holds the default route.
type HeaderRouter map[string][]HeaderRoute

// Route registers middlewareHandler for requests whose header value matches
// the single pattern match (see NewPattern). It returns hr to allow chaining.
//
// Handler lower-cases header values before matching, so the pattern is
// lower-cased here as well; otherwise a pattern containing upper-case
// letters could never match.
func (hr HeaderRouter) Route(header, match string, middlewareHandler func(next http.Handler) http.Handler) HeaderRouter {
	header = strings.ToLower(header)
	hr[header] = append(hr[header], HeaderRoute{
		MatchOne:   NewPattern(strings.ToLower(match)),
		Middleware: middlewareHandler,
	})
	return hr
}

// RouteAny registers middlewareHandler for requests whose header value
// matches at least one of the patterns in match. Patterns are lower-cased for
// the same reason as in Route. It returns hr to allow chaining.
func (hr HeaderRouter) RouteAny(header string, match []string, middlewareHandler func(next http.Handler) http.Handler) HeaderRouter {
	header = strings.ToLower(header)
	patterns := make([]Pattern, 0, len(match))
	for _, m := range match {
		patterns = append(patterns, NewPattern(strings.ToLower(m)))
	}
	hr[header] = append(hr[header], HeaderRoute{MatchAny: patterns, Middleware: middlewareHandler})
	return hr
}

// RouteDefault sets the middleware used when no registered header route
// matches, replacing any previous default. It returns hr to allow chaining.
func (hr HeaderRouter) RouteDefault(handler func(next http.Handler) http.Handler) HeaderRouter {
	hr["*"] = []HeaderRoute{{Middleware: handler}}
	return hr
}

// Handler is the middleware entry point. For each request it looks for a
// registered route whose header is present and whose pattern matches the
// lower-cased header value, and runs that route's middleware around next.
// When nothing matches, the default route is used if one is set; otherwise
// next is called directly.
//
// Routes for different headers are visited in Go map iteration order, so when
// several headers could match, which one wins is unspecified.
func (hr HeaderRouter) Handler(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if len(hr) == 0 {
			// skip if no routes set
			next.ServeHTTP(w, r)
			return
		}

		// find first matching header route, and continue
		for header, matchers := range hr {
			headerValue := r.Header.Get(header)
			if headerValue == "" {
				continue
			}
			headerValue = strings.ToLower(headerValue)
			for _, matcher := range matchers {
				if !matcher.IsMatch(headerValue) {
					continue
				}
				serveVia(matcher.Middleware, next, w, r)
				return
			}
		}

		// if no match, check for "*" default route; an absent or empty
		// entry means there is no default.
		if def := hr["*"]; len(def) > 0 {
			serveVia(def[0].Middleware, next, w, r)
			return
		}
		next.ServeHTTP(w, r)
	})
}

// serveVia serves the request through mw wrapped around next, or through next
// alone when mw is nil.
func serveVia(mw func(http.Handler) http.Handler, next http.Handler, w http.ResponseWriter, r *http.Request) {
	if mw == nil {
		next.ServeHTTP(w, r)
		return
	}
	mw(next).ServeHTTP(w, r)
}

// HeaderRoute pairs a middleware with the pattern(s) that select it.
// When MatchAny is non-empty it takes precedence over MatchOne.
type HeaderRoute struct {
	Middleware func(next http.Handler) http.Handler
	MatchOne   Pattern
	MatchAny   []Pattern
}

// IsMatch reports whether value matches this route: any pattern of MatchAny
// when that list is non-empty, otherwise MatchOne.
func (r HeaderRoute) IsMatch(value string) bool {
	if len(r.MatchAny) == 0 {
		return r.MatchOne.Match(value)
	}
	for _, m := range r.MatchAny {
		if m.Match(value) {
			return true
		}
	}
	return false
}

// Pattern is a literal string or a string with a single "*" wildcard, stored
// as the text before and after the first "*".
type Pattern struct {
	prefix   string
	suffix   string
	wildcard bool
}

// NewPattern parses value into a Pattern, splitting it at the first "*".
// A value without "*" matches only itself.
func NewPattern(value string) Pattern {
	p := Pattern{}
	p.prefix, p.suffix, p.wildcard = strings.Cut(value, "*")
	return p
}

// Match reports whether v matches the pattern. Without a wildcard v must
// equal the pattern exactly. With one, v must start with the prefix, end with
// the suffix, and be long enough that the two do not overlap.
func (p Pattern) Match(v string) bool {
	if !p.wildcard {
		return p.prefix == v
	}
	return len(v) >= len(p.prefix)+len(p.suffix) && strings.HasPrefix(v, p.prefix) && strings.HasSuffix(v, p.suffix)
}
+// slash, strip it from the path and continue routing through the mux, if a route +// matches, then it will serve the handler. +func StripSlashes(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + var path string + rctx := chi.RouteContext(r.Context()) + if rctx != nil && rctx.RoutePath != "" { + path = rctx.RoutePath + } else { + path = r.URL.Path + } + if len(path) > 1 && path[len(path)-1] == '/' { + newPath := path[:len(path)-1] + if rctx == nil { + r.URL.Path = newPath + } else { + rctx.RoutePath = newPath + } + } + next.ServeHTTP(w, r) + } + return http.HandlerFunc(fn) +} + +// RedirectSlashes is a middleware that will match request paths with a trailing +// slash and redirect to the same path, less the trailing slash. +// +// NOTE: RedirectSlashes middleware is *incompatible* with http.FileServer, +// see https://github.com/go-chi/chi/issues/343 +func RedirectSlashes(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + var path string + rctx := chi.RouteContext(r.Context()) + if rctx != nil && rctx.RoutePath != "" { + path = rctx.RoutePath + } else { + path = r.URL.Path + } + + if len(path) > 1 && path[len(path)-1] == '/' { + // Normalize backslashes to forward slashes to prevent "/\evil.com" style redirects + // that some clients may interpret as protocol-relative. + path = strings.ReplaceAll(path, `\`, `/`) + + // Collapse leading/trailing slashes and force a single leading slash. + path := "/" + strings.Trim(path, "/") + + if r.URL.RawQuery != "" { + path = fmt.Sprintf("%s?%s", path, r.URL.RawQuery) + } + http.Redirect(w, r, path, 301) + return + } + + next.ServeHTTP(w, r) + } + return http.HandlerFunc(fn) +} + +// StripPrefix is a middleware that will strip the provided prefix from the +// request path before handing the request over to the next handler. 
// StripPrefix is a middleware that will strip the provided prefix from the
// request path before handing the request over to the next handler.
// It is a thin adapter around http.StripPrefix.
func StripPrefix(prefix string) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.StripPrefix(prefix, next)
	}
}

// Sunset set Deprecation/Sunset header to response
// This can be used to enable Sunset in a route or a route group
// For more: https://www.rfc-editor.org/rfc/rfc8594.html
//
// When sunsetAt is the zero time no headers are added. Otherwise the Sunset
// and Deprecation headers are both set to sunsetAt as an HTTP date, and each
// entry of links is added as a separate Link header.
func Sunset(sunsetAt time.Time, links ...string) func(http.Handler) http.Handler {
	// http.TimeFormat hard-codes the "GMT" zone name, so the time must be
	// converted to UTC first or the emitted instant would be wrong for any
	// non-UTC sunsetAt. The value never changes, so format it once.
	var stamp string
	if !sunsetAt.IsZero() {
		stamp = sunsetAt.UTC().Format(http.TimeFormat)
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if stamp != "" {
				hdr := w.Header()
				hdr.Set("Sunset", stamp)
				hdr.Set("Deprecation", stamp)

				for _, link := range links {
					hdr.Add("Link", link)
				}
			}
			next.ServeHTTP(w, r)
		})
	}
}
+func SupressNotFound(router *chi.Mux) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rctx := chi.RouteContext(r.Context()) + match := rctx.Routes.Match(rctx, r.Method, r.URL.Path) + if !match { + router.NotFoundHandler().ServeHTTP(w, r) + return + } + next.ServeHTTP(w, r) + }) + } +} diff --git a/testdata/chi/middleware/terminal.go b/testdata/chi/middleware/terminal.go new file mode 100644 index 0000000..5ead7b9 --- /dev/null +++ b/testdata/chi/middleware/terminal.go @@ -0,0 +1,63 @@ +package middleware + +// Ported from Goji's middleware, source: +// https://github.com/zenazn/goji/tree/master/web/middleware + +import ( + "fmt" + "io" + "os" +) + +var ( + // Normal colors + nBlack = []byte{'\033', '[', '3', '0', 'm'} + nRed = []byte{'\033', '[', '3', '1', 'm'} + nGreen = []byte{'\033', '[', '3', '2', 'm'} + nYellow = []byte{'\033', '[', '3', '3', 'm'} + nBlue = []byte{'\033', '[', '3', '4', 'm'} + nMagenta = []byte{'\033', '[', '3', '5', 'm'} + nCyan = []byte{'\033', '[', '3', '6', 'm'} + nWhite = []byte{'\033', '[', '3', '7', 'm'} + // Bright colors + bBlack = []byte{'\033', '[', '3', '0', ';', '1', 'm'} + bRed = []byte{'\033', '[', '3', '1', ';', '1', 'm'} + bGreen = []byte{'\033', '[', '3', '2', ';', '1', 'm'} + bYellow = []byte{'\033', '[', '3', '3', ';', '1', 'm'} + bBlue = []byte{'\033', '[', '3', '4', ';', '1', 'm'} + bMagenta = []byte{'\033', '[', '3', '5', ';', '1', 'm'} + bCyan = []byte{'\033', '[', '3', '6', ';', '1', 'm'} + bWhite = []byte{'\033', '[', '3', '7', ';', '1', 'm'} + + reset = []byte{'\033', '[', '0', 'm'} +) + +var IsTTY bool + +func init() { + // This is sort of cheating: if stdout is a character device, we assume + // that means it's a TTY. Unfortunately, there are many non-TTY + // character devices, but fortunately stdout is rarely set to any of + // them. 
+ // + // We could solve this properly by pulling in a dependency on + // code.google.com/p/go.crypto/ssh/terminal, for instance, but as a + // heuristic for whether to print in color or in black-and-white, I'd + // really rather not. + fi, err := os.Stdout.Stat() + if err == nil { + m := os.ModeDevice | os.ModeCharDevice + IsTTY = fi.Mode()&m == m + } +} + +// colorWrite +func cW(w io.Writer, useColor bool, color []byte, s string, args ...interface{}) { + if IsTTY && useColor { + w.Write(color) + } + fmt.Fprintf(w, s, args...) + if IsTTY && useColor { + w.Write(reset) + } +} diff --git a/testdata/chi/middleware/throttle.go b/testdata/chi/middleware/throttle.go new file mode 100644 index 0000000..7ea482b --- /dev/null +++ b/testdata/chi/middleware/throttle.go @@ -0,0 +1,151 @@ +package middleware + +import ( + "net/http" + "strconv" + "time" +) + +const ( + errCapacityExceeded = "Server capacity exceeded." + errTimedOut = "Timed out while waiting for a pending request to complete." + errContextCanceled = "Context was canceled." +) + +var ( + defaultBacklogTimeout = time.Second * 60 +) + +// ThrottleOpts represents a set of throttling options. +type ThrottleOpts struct { + RetryAfterFn func(ctxDone bool) time.Duration + Limit int + BacklogLimit int + BacklogTimeout time.Duration + StatusCode int +} + +// Throttle is a middleware that limits number of currently processed requests +// at a time across all users. Note: Throttle is not a rate-limiter per user, +// instead it just puts a ceiling on the number of current in-flight requests +// being processed from the point from where the Throttle middleware is mounted. +func Throttle(limit int) func(http.Handler) http.Handler { + return ThrottleWithOpts(ThrottleOpts{Limit: limit, BacklogTimeout: defaultBacklogTimeout}) +} + +// ThrottleBacklog is a middleware that limits number of currently processed +// requests at a time and provides a backlog for holding a finite number of +// pending requests. 
+func ThrottleBacklog(limit, backlogLimit int, backlogTimeout time.Duration) func(http.Handler) http.Handler { + return ThrottleWithOpts(ThrottleOpts{Limit: limit, BacklogLimit: backlogLimit, BacklogTimeout: backlogTimeout}) +} + +// ThrottleWithOpts is a middleware that limits number of currently processed requests using passed ThrottleOpts. +func ThrottleWithOpts(opts ThrottleOpts) func(http.Handler) http.Handler { + if opts.Limit < 1 { + panic("chi/middleware: Throttle expects limit > 0") + } + + if opts.BacklogLimit < 0 { + panic("chi/middleware: Throttle expects backlogLimit to be positive") + } + + statusCode := opts.StatusCode + if statusCode == 0 { + statusCode = http.StatusTooManyRequests + } + + t := throttler{ + tokens: make(chan token, opts.Limit), + backlogTokens: make(chan token, opts.Limit+opts.BacklogLimit), + backlogTimeout: opts.BacklogTimeout, + statusCode: statusCode, + retryAfterFn: opts.RetryAfterFn, + } + + // Filling tokens. + for i := 0; i < opts.Limit+opts.BacklogLimit; i++ { + if i < opts.Limit { + t.tokens <- token{} + } + t.backlogTokens <- token{} + } + + return func(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + select { + + case <-ctx.Done(): + t.setRetryAfterHeaderIfNeeded(w, true) + http.Error(w, errContextCanceled, t.statusCode) + return + + case btok := <-t.backlogTokens: + defer func() { + t.backlogTokens <- btok + }() + + // Try to get a processing token immediately first + select { + case tok := <-t.tokens: + defer func() { + t.tokens <- tok + }() + next.ServeHTTP(w, r) + return + default: + // No immediate token available, need to wait with timer + } + + timer := time.NewTimer(t.backlogTimeout) + select { + case <-timer.C: + t.setRetryAfterHeaderIfNeeded(w, false) + http.Error(w, errTimedOut, t.statusCode) + return + case <-ctx.Done(): + timer.Stop() + t.setRetryAfterHeaderIfNeeded(w, true) + http.Error(w, errContextCanceled, t.statusCode) + return + case tok 
:= <-t.tokens: + defer func() { + timer.Stop() + t.tokens <- tok + }() + next.ServeHTTP(w, r) + } + return + + default: + t.setRetryAfterHeaderIfNeeded(w, false) + http.Error(w, errCapacityExceeded, t.statusCode) + return + } + } + + return http.HandlerFunc(fn) + } +} + +// token represents a request that is being processed. +type token struct{} + +// throttler limits number of currently processed requests at a time. +type throttler struct { + tokens chan token + backlogTokens chan token + retryAfterFn func(ctxDone bool) time.Duration + backlogTimeout time.Duration + statusCode int +} + +// setRetryAfterHeaderIfNeeded sets Retry-After HTTP header if corresponding retryAfterFn option of throttler is initialized. +func (t throttler) setRetryAfterHeaderIfNeeded(w http.ResponseWriter, ctxDone bool) { + if t.retryAfterFn == nil { + return + } + w.Header().Set("Retry-After", strconv.Itoa(int(t.retryAfterFn(ctxDone).Seconds()))) +} diff --git a/testdata/chi/middleware/timeout.go b/testdata/chi/middleware/timeout.go new file mode 100644 index 0000000..add596d --- /dev/null +++ b/testdata/chi/middleware/timeout.go @@ -0,0 +1,48 @@ +package middleware + +import ( + "context" + "net/http" + "time" +) + +// Timeout is a middleware that cancels ctx after a given timeout and return +// a 504 Gateway Timeout error to the client. +// +// It's required that you select the ctx.Done() channel to check for the signal +// if the context has reached its deadline and return, otherwise the timeout +// signal will be just ignored. +// +// ie. a route/handler may look like: +// +// r.Get("/long", func(w http.ResponseWriter, r *http.Request) { +// ctx := r.Context() +// processTime := time.Duration(rand.Intn(4)+1) * time.Second +// +// select { +// case <-ctx.Done(): +// return +// +// case <-time.After(processTime): +// // The above channel simulates some hard work. 
+// } +// +// w.Write([]byte("done")) +// }) +func Timeout(timeout time.Duration) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + ctx, cancel := context.WithTimeout(r.Context(), timeout) + defer func() { + cancel() + if ctx.Err() == context.DeadlineExceeded { + w.WriteHeader(http.StatusGatewayTimeout) + } + }() + + r = r.WithContext(ctx) + next.ServeHTTP(w, r) + } + return http.HandlerFunc(fn) + } +} diff --git a/testdata/chi/middleware/url_format.go b/testdata/chi/middleware/url_format.go new file mode 100644 index 0000000..2ec6657 --- /dev/null +++ b/testdata/chi/middleware/url_format.go @@ -0,0 +1,77 @@ +package middleware + +import ( + "context" + "net/http" + "strings" + + "github.com/go-chi/chi/v5" +) + +var ( + // URLFormatCtxKey is the context.Context key to store the URL format data + // for a request. + URLFormatCtxKey = &contextKey{"URLFormat"} +) + +// URLFormat is a middleware that parses the url extension from a request path and stores it +// on the context as a string under the key `middleware.URLFormatCtxKey`. The middleware will +// trim the suffix from the routing path and continue routing. +// +// Routers should not include a url parameter for the suffix when using this middleware. 
+// +// Sample usage for url paths `/articles/1`, `/articles/1.json` and `/articles/1.xml`: +// +// func routes() http.Handler { +// r := chi.NewRouter() +// r.Use(middleware.URLFormat) +// +// r.Get("/articles/{id}", ListArticles) +// +// return r +// } +// +// func ListArticles(w http.ResponseWriter, r *http.Request) { +// urlFormat, _ := r.Context().Value(middleware.URLFormatCtxKey).(string) +// +// switch urlFormat { +// case "json": +// render.JSON(w, r, articles) +// case "xml:" +// render.XML(w, r, articles) +// default: +// render.JSON(w, r, articles) +// } +// } +func URLFormat(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + var format string + path := r.URL.Path + + rctx := chi.RouteContext(r.Context()) + if rctx != nil && rctx.RoutePath != "" { + path = rctx.RoutePath + } + + if strings.Index(path, ".") > 0 { + base := strings.LastIndex(path, "/") + idx := strings.LastIndex(path[base:], ".") + + if idx > 0 { + idx += base + format = path[idx+1:] + + if rctx != nil { + rctx.RoutePath = path[:idx] + } + } + } + + r = r.WithContext(context.WithValue(ctx, URLFormatCtxKey, format)) + + next.ServeHTTP(w, r) + } + return http.HandlerFunc(fn) +} diff --git a/testdata/chi/middleware/value.go b/testdata/chi/middleware/value.go new file mode 100644 index 0000000..a9dfd43 --- /dev/null +++ b/testdata/chi/middleware/value.go @@ -0,0 +1,17 @@ +package middleware + +import ( + "context" + "net/http" +) + +// WithValue is a middleware that sets a given key/value in a context chain. 
+func WithValue(key, val interface{}) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + r = r.WithContext(context.WithValue(r.Context(), key, val)) + next.ServeHTTP(w, r) + } + return http.HandlerFunc(fn) + } +} diff --git a/testdata/chi/middleware/wrap_writer.go b/testdata/chi/middleware/wrap_writer.go new file mode 100644 index 0000000..b2de875 --- /dev/null +++ b/testdata/chi/middleware/wrap_writer.go @@ -0,0 +1,243 @@ +package middleware + +// The original work was derived from Goji's middleware, source: +// https://github.com/zenazn/goji/tree/master/web/middleware + +import ( + "bufio" + "io" + "net" + "net/http" +) + +// NewWrapResponseWriter wraps an http.ResponseWriter, returning a proxy that allows you to +// hook into various parts of the response process. +func NewWrapResponseWriter(w http.ResponseWriter, protoMajor int) WrapResponseWriter { + _, fl := w.(http.Flusher) + + bw := basicWriter{ResponseWriter: w} + + if protoMajor == 2 { + _, ps := w.(http.Pusher) + if fl && ps { + return &http2FancyWriter{bw} + } + } else { + _, hj := w.(http.Hijacker) + _, rf := w.(io.ReaderFrom) + if fl && hj && rf { + return &httpFancyWriter{bw} + } + if fl && hj { + return &flushHijackWriter{bw} + } + if hj { + return &hijackWriter{bw} + } + } + + if fl { + return &flushWriter{bw} + } + + return &bw +} + +// WrapResponseWriter is a proxy around an http.ResponseWriter that allows you to hook +// into various parts of the response process. +type WrapResponseWriter interface { + http.ResponseWriter + // Status returns the HTTP status of the request, or 0 if one has not + // yet been sent. + Status() int + // BytesWritten returns the total number of bytes sent to the client. + BytesWritten() int + // Tee causes the response body to be written to the given io.Writer in + // addition to proxying the writes through. 
Only one io.Writer can be + // tee'd to at once: setting a second one will overwrite the first. + // Writes will be sent to the proxy before being written to this + // io.Writer. It is illegal for the tee'd writer to be modified + // concurrently with writes. + Tee(io.Writer) + // Unwrap returns the original proxied target. + Unwrap() http.ResponseWriter + // Discard causes all writes to the original ResponseWriter be discarded, + // instead writing only to the tee'd writer if it's set. + // The caller is responsible for calling WriteHeader and Write on the + // original ResponseWriter once the processing is done. + Discard() +} + +// basicWriter wraps a http.ResponseWriter that implements the minimal +// http.ResponseWriter interface. +type basicWriter struct { + http.ResponseWriter + tee io.Writer + code int + bytes int + wroteHeader bool + discard bool +} + +func (b *basicWriter) WriteHeader(code int) { + if code >= 100 && code <= 199 && code != http.StatusSwitchingProtocols { + if !b.discard { + b.ResponseWriter.WriteHeader(code) + } + } else if !b.wroteHeader { + b.code = code + b.wroteHeader = true + if !b.discard { + b.ResponseWriter.WriteHeader(code) + } + } +} + +func (b *basicWriter) Write(buf []byte) (n int, err error) { + b.maybeWriteHeader() + if !b.discard { + n, err = b.ResponseWriter.Write(buf) + if b.tee != nil { + _, err2 := b.tee.Write(buf[:n]) + // Prefer errors generated by the proxied writer. 
+ if err == nil { + err = err2 + } + } + } else if b.tee != nil { + n, err = b.tee.Write(buf) + } else { + n, err = io.Discard.Write(buf) + } + b.bytes += n + return n, err +} + +func (b *basicWriter) maybeWriteHeader() { + if !b.wroteHeader { + b.WriteHeader(http.StatusOK) + } +} + +func (b *basicWriter) Status() int { + return b.code +} + +func (b *basicWriter) BytesWritten() int { + return b.bytes +} + +func (b *basicWriter) Tee(w io.Writer) { + b.tee = w +} + +func (b *basicWriter) Unwrap() http.ResponseWriter { + return b.ResponseWriter +} + +func (b *basicWriter) Discard() { + b.discard = true +} + +// flushWriter ... +type flushWriter struct { + basicWriter +} + +func (f *flushWriter) Flush() { + f.wroteHeader = true + fl := f.basicWriter.ResponseWriter.(http.Flusher) + fl.Flush() +} + +var _ http.Flusher = &flushWriter{} + +// hijackWriter ... +type hijackWriter struct { + basicWriter +} + +func (f *hijackWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + hj := f.basicWriter.ResponseWriter.(http.Hijacker) + return hj.Hijack() +} + +var _ http.Hijacker = &hijackWriter{} + +// flushHijackWriter ... +type flushHijackWriter struct { + basicWriter +} + +func (f *flushHijackWriter) Flush() { + f.wroteHeader = true + fl := f.basicWriter.ResponseWriter.(http.Flusher) + fl.Flush() +} + +func (f *flushHijackWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + hj := f.basicWriter.ResponseWriter.(http.Hijacker) + return hj.Hijack() +} + +var _ http.Flusher = &flushHijackWriter{} +var _ http.Hijacker = &flushHijackWriter{} + +// httpFancyWriter is a HTTP writer that additionally satisfies +// http.Flusher, http.Hijacker, and io.ReaderFrom. It exists for the common case +// of wrapping the http.ResponseWriter that package http gives you, in order to +// make the proxied object support the full method set of the proxied object. 
+type httpFancyWriter struct { + basicWriter +} + +func (f *httpFancyWriter) Flush() { + f.wroteHeader = true + fl := f.basicWriter.ResponseWriter.(http.Flusher) + fl.Flush() +} + +func (f *httpFancyWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + hj := f.basicWriter.ResponseWriter.(http.Hijacker) + return hj.Hijack() +} + +func (f *http2FancyWriter) Push(target string, opts *http.PushOptions) error { + return f.basicWriter.ResponseWriter.(http.Pusher).Push(target, opts) +} + +func (f *httpFancyWriter) ReadFrom(r io.Reader) (int64, error) { + if f.basicWriter.tee != nil { + // Route through basicWriter.Write so that data is also written to the + // tee writer. basicWriter.Write already increments basicWriter.bytes, + // so we must NOT add n again here (that would double-count). + n, err := io.Copy(&f.basicWriter, r) + return n, err + } + rf := f.basicWriter.ResponseWriter.(io.ReaderFrom) + f.basicWriter.maybeWriteHeader() + n, err := rf.ReadFrom(r) + f.basicWriter.bytes += int(n) + return n, err +} + +var _ http.Flusher = &httpFancyWriter{} +var _ http.Hijacker = &httpFancyWriter{} +var _ http.Pusher = &http2FancyWriter{} +var _ io.ReaderFrom = &httpFancyWriter{} + +// http2FancyWriter is a HTTP2 writer that additionally satisfies +// http.Flusher and http.Pusher. It exists for the common case +// of wrapping the http.ResponseWriter that package http gives you, in order to +// make the proxied object support the full method set of the proxied object.
+type http2FancyWriter struct { + basicWriter +} + +func (f *http2FancyWriter) Flush() { + f.wroteHeader = true + fl := f.basicWriter.ResponseWriter.(http.Flusher) + fl.Flush() +} + +var _ http.Flusher = &http2FancyWriter{} diff --git a/testdata/chi/mux.go b/testdata/chi/mux.go new file mode 100644 index 0000000..3da7f3f --- /dev/null +++ b/testdata/chi/mux.go @@ -0,0 +1,526 @@ +package chi + +import ( + "context" + "fmt" + "net/http" + "strings" + "sync" +) + +var _ Router = &Mux{} + +// Mux is a simple HTTP route multiplexer that parses a request path, +// records any URL params, and executes an end handler. It implements +// the http.Handler interface and is friendly with the standard library. +// +// Mux is designed to be fast, minimal and offer a powerful API for building +// modular and composable HTTP services with a large set of handlers. It's +// particularly useful for writing large REST API services that break a handler +// into many smaller parts composed of middlewares and end handlers. +type Mux struct { + // The computed mux handler made of the chained middleware stack and + // the tree router + handler http.Handler + + // The radix trie router + tree *node + + // Custom method not allowed handler + methodNotAllowedHandler http.HandlerFunc + + // A reference to the parent mux used by subrouters when mounting + // to a parent mux + parent *Mux + + // Routing context pool + pool *sync.Pool + + // Custom route not found handler + notFoundHandler http.HandlerFunc + + // The middleware stack + middlewares []func(http.Handler) http.Handler + + // Controls the behaviour of middleware chain generation when a mux + // is registered as an inline group inside another mux. + inline bool +} + +// NewMux returns a newly initialized Mux object that implements the Router +// interface. 
+func NewMux() *Mux { + mux := &Mux{tree: &node{}, pool: &sync.Pool{}} + mux.pool.New = func() interface{} { + return NewRouteContext() + } + return mux +} + +// ServeHTTP is the single method of the http.Handler interface that makes +// Mux interoperable with the standard library. It uses a sync.Pool to get and +// reuse routing contexts for each request. +func (mx *Mux) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Ensure the mux has some routes defined on the mux + if mx.handler == nil { + mx.NotFoundHandler().ServeHTTP(w, r) + return + } + + // Check if a routing context already exists from a parent router. + rctx, _ := r.Context().Value(RouteCtxKey).(*Context) + if rctx != nil { + mx.handler.ServeHTTP(w, r) + return + } + + // Fetch a RouteContext object from the sync pool, and call the computed + // mx.handler that is comprised of mx.middlewares + mx.routeHTTP. + // Once the request is finished, reset the routing context and put it back + // into the pool for reuse from another request. + rctx = mx.pool.Get().(*Context) + rctx.Reset() + rctx.Routes = mx + rctx.parentCtx = r.Context() + + // NOTE: r.WithContext() causes 2 allocations and context.WithValue() causes 1 allocation + r = r.WithContext(context.WithValue(r.Context(), RouteCtxKey, rctx)) + + // Serve the request and once its done, put the request context back in the sync pool + mx.handler.ServeHTTP(w, r) + mx.pool.Put(rctx) +} + +// Use appends a middleware handler to the Mux middleware stack. +// +// The middleware stack for any Mux will execute before searching for a matching +// route to a specific handler, which provides opportunity to respond early, +// change the course of the request execution, or set request-scoped values for +// the next http.Handler. +func (mx *Mux) Use(middlewares ...func(http.Handler) http.Handler) { + if mx.handler != nil { + panic("chi: all middlewares must be defined before routes on a mux") + } + mx.middlewares = append(mx.middlewares, middlewares...) 
+} + +// Handle adds the route `pattern` that matches any http method to +// execute the `handler` http.Handler. +func (mx *Mux) Handle(pattern string, handler http.Handler) { + if i := strings.IndexAny(pattern, " \t"); i >= 0 { + method, rest := pattern[:i], strings.TrimLeft(pattern[i+1:], " \t") + mx.Method(method, rest, handler) + return + } + + mx.handle(mALL, pattern, handler) +} + +// HandleFunc adds the route `pattern` that matches any http method to +// execute the `handlerFn` http.HandlerFunc. +func (mx *Mux) HandleFunc(pattern string, handlerFn http.HandlerFunc) { + mx.Handle(pattern, handlerFn) +} + +// Method adds the route `pattern` that matches `method` http method to +// execute the `handler` http.Handler. +func (mx *Mux) Method(method, pattern string, handler http.Handler) { + m, ok := methodMap[strings.ToUpper(method)] + if !ok { + panic(fmt.Sprintf("chi: '%s' http method is not supported.", method)) + } + mx.handle(m, pattern, handler) +} + +// MethodFunc adds the route `pattern` that matches `method` http method to +// execute the `handlerFn` http.HandlerFunc. +func (mx *Mux) MethodFunc(method, pattern string, handlerFn http.HandlerFunc) { + mx.Method(method, pattern, handlerFn) +} + +// Connect adds the route `pattern` that matches a CONNECT http method to +// execute the `handlerFn` http.HandlerFunc. +func (mx *Mux) Connect(pattern string, handlerFn http.HandlerFunc) { + mx.handle(mCONNECT, pattern, handlerFn) +} + +// Delete adds the route `pattern` that matches a DELETE http method to +// execute the `handlerFn` http.HandlerFunc. +func (mx *Mux) Delete(pattern string, handlerFn http.HandlerFunc) { + mx.handle(mDELETE, pattern, handlerFn) +} + +// Get adds the route `pattern` that matches a GET http method to +// execute the `handlerFn` http.HandlerFunc. 
+func (mx *Mux) Get(pattern string, handlerFn http.HandlerFunc) { + mx.handle(mGET, pattern, handlerFn) +} + +// Head adds the route `pattern` that matches a HEAD http method to +// execute the `handlerFn` http.HandlerFunc. +func (mx *Mux) Head(pattern string, handlerFn http.HandlerFunc) { + mx.handle(mHEAD, pattern, handlerFn) +} + +// Options adds the route `pattern` that matches an OPTIONS http method to +// execute the `handlerFn` http.HandlerFunc. +func (mx *Mux) Options(pattern string, handlerFn http.HandlerFunc) { + mx.handle(mOPTIONS, pattern, handlerFn) +} + +// Patch adds the route `pattern` that matches a PATCH http method to +// execute the `handlerFn` http.HandlerFunc. +func (mx *Mux) Patch(pattern string, handlerFn http.HandlerFunc) { + mx.handle(mPATCH, pattern, handlerFn) +} + +// Post adds the route `pattern` that matches a POST http method to +// execute the `handlerFn` http.HandlerFunc. +func (mx *Mux) Post(pattern string, handlerFn http.HandlerFunc) { + mx.handle(mPOST, pattern, handlerFn) +} + +// Put adds the route `pattern` that matches a PUT http method to +// execute the `handlerFn` http.HandlerFunc. +func (mx *Mux) Put(pattern string, handlerFn http.HandlerFunc) { + mx.handle(mPUT, pattern, handlerFn) +} + +// Trace adds the route `pattern` that matches a TRACE http method to +// execute the `handlerFn` http.HandlerFunc. +func (mx *Mux) Trace(pattern string, handlerFn http.HandlerFunc) { + mx.handle(mTRACE, pattern, handlerFn) +} + +// NotFound sets a custom http.HandlerFunc for routing paths that could +// not be found. The default 404 handler is `http.NotFound`. 
+func (mx *Mux) NotFound(handlerFn http.HandlerFunc) { + // Build NotFound handler chain + m := mx + hFn := handlerFn + if mx.inline && mx.parent != nil { + m = mx.parent + hFn = Chain(mx.middlewares...).HandlerFunc(hFn).ServeHTTP + } + + // Update the notFoundHandler from this point forward + m.notFoundHandler = hFn + m.updateSubRoutes(func(subMux *Mux) { + if subMux.notFoundHandler == nil { + subMux.NotFound(hFn) + } + }) +} + +// MethodNotAllowed sets a custom http.HandlerFunc for routing paths where the +// method is unresolved. The default handler returns a 405 with an empty body. +func (mx *Mux) MethodNotAllowed(handlerFn http.HandlerFunc) { + // Build MethodNotAllowed handler chain + m := mx + hFn := handlerFn + if mx.inline && mx.parent != nil { + m = mx.parent + hFn = Chain(mx.middlewares...).HandlerFunc(hFn).ServeHTTP + } + + // Update the methodNotAllowedHandler from this point forward + m.methodNotAllowedHandler = hFn + m.updateSubRoutes(func(subMux *Mux) { + if subMux.methodNotAllowedHandler == nil { + subMux.MethodNotAllowed(hFn) + } + }) +} + +// With adds inline middlewares for an endpoint handler. +func (mx *Mux) With(middlewares ...func(http.Handler) http.Handler) Router { + // Similarly as in handle(), we must build the mux handler once additional + // middleware registration isn't allowed for this stack, like now. + if !mx.inline && mx.handler == nil { + mx.updateRouteHandler() + } + + // Copy middlewares from parent inline muxs + var mws Middlewares + if mx.inline { + mws = make(Middlewares, len(mx.middlewares)) + copy(mws, mx.middlewares) + } + mws = append(mws, middlewares...) + + im := &Mux{ + pool: mx.pool, inline: true, parent: mx, tree: mx.tree, middlewares: mws, + notFoundHandler: mx.notFoundHandler, methodNotAllowedHandler: mx.methodNotAllowedHandler, + } + + return im +} + +// Group creates a new inline-Mux with a copy of middleware stack. 
It's useful +// for a group of handlers along the same routing path that use an additional +// set of middlewares. See _examples/. +func (mx *Mux) Group(fn func(r Router)) Router { + im := mx.With() + if fn != nil { + fn(im) + } + return im +} + +// Route creates a new Mux and mounts it along the `pattern` as a subrouter. +// Effectively, this is a short-hand call to Mount. See _examples/. +func (mx *Mux) Route(pattern string, fn func(r Router)) Router { + if fn == nil { + panic(fmt.Sprintf("chi: attempting to Route() a nil subrouter on '%s'", pattern)) + } + subRouter := NewRouter() + fn(subRouter) + mx.Mount(pattern, subRouter) + return subRouter +} + +// Mount attaches another http.Handler or chi Router as a subrouter along a routing +// path. It's very useful to split up a large API as many independent routers and +// compose them as a single service using Mount. See _examples/. +// +// Note that Mount() simply sets a wildcard along the `pattern` that will continue +// routing at the `handler`, which in most cases is another chi.Router. As a result, +// if you define two Mount() routes on the exact same pattern the mount will panic. +func (mx *Mux) Mount(pattern string, handler http.Handler) { + if handler == nil { + panic(fmt.Sprintf("chi: attempting to Mount() a nil handler on '%s'", pattern)) + } + + // Provide runtime safety for ensuring a pattern isn't mounted on an existing + // routing pattern. + if mx.tree.findPattern(pattern+"*") || mx.tree.findPattern(pattern+"/*") { + panic(fmt.Sprintf("chi: attempting to Mount() a handler on an existing path, '%s'", pattern)) + } + + // Assign sub-Router's with the parent not found & method not allowed handler if not specified. 
+ subr, ok := handler.(*Mux) + if ok && subr.notFoundHandler == nil && mx.notFoundHandler != nil { + subr.NotFound(mx.notFoundHandler) + } + if ok && subr.methodNotAllowedHandler == nil && mx.methodNotAllowedHandler != nil { + subr.MethodNotAllowed(mx.methodNotAllowedHandler) + } + + mountHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rctx := RouteContext(r.Context()) + + // shift the url path past the previous subrouter + rctx.RoutePath = mx.nextRoutePath(rctx) + + // reset the wildcard URLParam which connects the subrouter + n := len(rctx.URLParams.Keys) - 1 + if n >= 0 && rctx.URLParams.Keys[n] == "*" && len(rctx.URLParams.Values) > n { + rctx.URLParams.Values[n] = "" + } + + handler.ServeHTTP(w, r) + }) + + if pattern == "" || pattern[len(pattern)-1] != '/' { + mx.handle(mALL|mSTUB, pattern, mountHandler) + mx.handle(mALL|mSTUB, pattern+"/", mountHandler) + pattern += "/" + } + + method := mALL + subroutes, _ := handler.(Routes) + if subroutes != nil { + method |= mSTUB + } + n := mx.handle(method, pattern+"*", mountHandler) + + if subroutes != nil { + n.subroutes = subroutes + } +} + +// Routes returns a slice of routing information from the tree, +// useful for traversing available routes of a router. +func (mx *Mux) Routes() []Route { + return mx.tree.routes() +} + +// Middlewares returns a slice of middleware handler functions. +func (mx *Mux) Middlewares() Middlewares { + return mx.middlewares +} + +// Match searches the routing tree for a handler that matches the method/path. +// It's similar to routing a http request, but without executing the handler +// thereafter. +// +// Note: the *Context state is updated during execution, so manage +// the state carefully or make a NewRouteContext(). +func (mx *Mux) Match(rctx *Context, method, path string) bool { + return mx.Find(rctx, method, path) != "" +} + +// Find searches the routing tree for the pattern that matches +// the method/path. 
+// +// Note: the *Context state is updated during execution, so manage +// the state carefully or make a NewRouteContext(). +func (mx *Mux) Find(rctx *Context, method, path string) string { + m, ok := methodMap[method] + if !ok { + return "" + } + + node, _, _ := mx.tree.FindRoute(rctx, m, path) + pattern := rctx.routePattern + + if node != nil { + if node.subroutes == nil { + e := node.endpoints[m] + return e.pattern + } + + rctx.RoutePath = mx.nextRoutePath(rctx) + subPattern := node.subroutes.Find(rctx, method, rctx.RoutePath) + if subPattern == "" { + return "" + } + + pattern = strings.TrimSuffix(pattern, "/*") + pattern += subPattern + } + + return pattern +} + +// NotFoundHandler returns the default Mux 404 responder whenever a route +// cannot be found. +func (mx *Mux) NotFoundHandler() http.HandlerFunc { + if mx.notFoundHandler != nil { + return mx.notFoundHandler + } + return http.NotFound +} + +// MethodNotAllowedHandler returns the default Mux 405 responder whenever +// a method cannot be resolved for a route. +func (mx *Mux) MethodNotAllowedHandler(methodsAllowed ...methodTyp) http.HandlerFunc { + if mx.methodNotAllowedHandler != nil { + return mx.methodNotAllowedHandler + } + return methodNotAllowedHandler(methodsAllowed...) +} + +// handle registers a http.Handler in the routing tree for a particular http method +// and routing pattern. +func (mx *Mux) handle(method methodTyp, pattern string, handler http.Handler) *node { + if len(pattern) == 0 || pattern[0] != '/' { + panic(fmt.Sprintf("chi: routing pattern must begin with '/' in '%s'", pattern)) + } + + // Build the computed routing handler for this routing pattern. 
+ if !mx.inline && mx.handler == nil { + mx.updateRouteHandler() + } + + // Build endpoint handler with inline middlewares for the route + var h http.Handler + if mx.inline { + mx.handler = http.HandlerFunc(mx.routeHTTP) + h = Chain(mx.middlewares...).Handler(handler) + } else { + h = handler + } + + // Add the endpoint to the tree and return the node + return mx.tree.InsertRoute(method, pattern, h) +} + +// routeHTTP routes a http.Request through the Mux routing tree to serve +// the matching handler for a particular http method. +func (mx *Mux) routeHTTP(w http.ResponseWriter, r *http.Request) { + // Grab the route context object + rctx := r.Context().Value(RouteCtxKey).(*Context) + + // The request routing path + routePath := rctx.RoutePath + if routePath == "" { + if r.URL.RawPath != "" { + routePath = r.URL.RawPath + } else { + routePath = r.URL.Path + } + if routePath == "" { + routePath = "/" + } + } + + // Check if method is supported by chi + if rctx.RouteMethod == "" { + rctx.RouteMethod = r.Method + } + method, ok := methodMap[rctx.RouteMethod] + if !ok { + mx.MethodNotAllowedHandler().ServeHTTP(w, r) + return + } + + // Find the route + if _, _, h := mx.tree.FindRoute(rctx, method, routePath); h != nil { + // Set http.Request path values from our request context + for i, key := range rctx.URLParams.Keys { + value := rctx.URLParams.Values[i] + r.SetPathValue(key, value) + } + r.Pattern = rctx.RoutePattern() + + h.ServeHTTP(w, r) + return + } + if rctx.methodNotAllowed { + mx.MethodNotAllowedHandler(rctx.methodsAllowed...).ServeHTTP(w, r) + } else { + mx.NotFoundHandler().ServeHTTP(w, r) + } +} + +func (mx *Mux) nextRoutePath(rctx *Context) string { + routePath := "/" + nx := len(rctx.routeParams.Keys) - 1 // index of last param in list + if nx >= 0 && rctx.routeParams.Keys[nx] == "*" && len(rctx.routeParams.Values) > nx { + routePath = "/" + rctx.routeParams.Values[nx] + } + return routePath +} + +// Recursively update data on child routers. 
+func (mx *Mux) updateSubRoutes(fn func(subMux *Mux)) { + for _, r := range mx.tree.routes() { + subMux, ok := r.SubRoutes.(*Mux) + if !ok { + continue + } + fn(subMux) + } +} + +// updateRouteHandler builds the single mux handler that is a chain of the middleware +// stack, as defined by calls to Use(), and the tree router (Mux) itself. After this +// point, no other middlewares can be registered on this Mux's stack. But you can still +// compose additional middlewares via Group()'s or using a chained middleware handler. +func (mx *Mux) updateRouteHandler() { + mx.handler = chain(mx.middlewares, http.HandlerFunc(mx.routeHTTP)) +} + +// methodNotAllowedHandler is a helper function to respond with a 405, +// method not allowed. It sets the Allow header with the list of allowed +// methods for the route. +func methodNotAllowedHandler(methodsAllowed ...methodTyp) func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + for _, m := range methodsAllowed { + w.Header().Add("Allow", reverseMethodMap[m]) + } + w.WriteHeader(405) + w.Write(nil) + } +} diff --git a/testdata/chi/tree.go b/testdata/chi/tree.go new file mode 100644 index 0000000..95f31d4 --- /dev/null +++ b/testdata/chi/tree.go @@ -0,0 +1,877 @@ +package chi + +// Radix tree implementation below is based on the original work by +// Armon Dadgar in https://github.com/armon/go-radix/blob/master/radix.go +// (MIT licensed). It's been heavily modified for use as a HTTP routing tree.
+ +import ( + "fmt" + "net/http" + "regexp" + "slices" + "sort" + "strconv" + "strings" +) + +type methodTyp uint + +const ( + mSTUB methodTyp = 1 << iota + mCONNECT + mDELETE + mGET + mHEAD + mOPTIONS + mPATCH + mPOST + mPUT + mTRACE +) + +var mALL = mCONNECT | mDELETE | mGET | mHEAD | + mOPTIONS | mPATCH | mPOST | mPUT | mTRACE + +var methodMap = map[string]methodTyp{ + http.MethodConnect: mCONNECT, + http.MethodDelete: mDELETE, + http.MethodGet: mGET, + http.MethodHead: mHEAD, + http.MethodOptions: mOPTIONS, + http.MethodPatch: mPATCH, + http.MethodPost: mPOST, + http.MethodPut: mPUT, + http.MethodTrace: mTRACE, +} + +var reverseMethodMap = map[methodTyp]string{ + mCONNECT: http.MethodConnect, + mDELETE: http.MethodDelete, + mGET: http.MethodGet, + mHEAD: http.MethodHead, + mOPTIONS: http.MethodOptions, + mPATCH: http.MethodPatch, + mPOST: http.MethodPost, + mPUT: http.MethodPut, + mTRACE: http.MethodTrace, +} + +// RegisterMethod adds support for custom HTTP method handlers, available +// via Router#Method and Router#MethodFunc +func RegisterMethod(method string) { + if method == "" { + return + } + method = strings.ToUpper(method) + if _, ok := methodMap[method]; ok { + return + } + n := len(methodMap) + if n > strconv.IntSize-2 { + panic(fmt.Sprintf("chi: max number of methods reached (%d)", strconv.IntSize)) + } + mt := methodTyp(2 << n) + methodMap[method] = mt + reverseMethodMap[mt] = method + mALL |= mt +} + +type nodeTyp uint8 + +const ( + ntStatic nodeTyp = iota // /home + ntRegexp // /{id:[0-9]+} + ntParam // /{user} + ntCatchAll // /api/v1/* +) + +type node struct { + // subroutes on the leaf node + subroutes Routes + + // regexp matcher for regexp nodes + rex *regexp.Regexp + + // HTTP handler endpoints on the leaf node + endpoints endpoints + + // prefix is the common prefix we ignore + prefix string + + // child nodes should be stored in-order for iteration, + // in groups of the node type. 
+ children [ntCatchAll + 1]nodes + + // first byte of the child prefix + tail byte + + // node type: static, regexp, param, catchAll + typ nodeTyp + + // first byte of the prefix + label byte +} + +// endpoints is a mapping of http method constants to handlers +// for a given route. +type endpoints map[methodTyp]*endpoint + +type endpoint struct { + // endpoint handler + handler http.Handler + + // pattern is the routing pattern for handler nodes + pattern string + + // parameter keys recorded on handler nodes + paramKeys []string +} + +func (s endpoints) Value(method methodTyp) *endpoint { + mh, ok := s[method] + if !ok { + mh = &endpoint{} + s[method] = mh + } + return mh +} + +func (n *node) InsertRoute(method methodTyp, pattern string, handler http.Handler) *node { + var parent *node + search := pattern + + for { + // Handle key exhaustion + if len(search) == 0 { + // Insert or update the node's leaf handler + n.setEndpoint(method, handler, pattern) + return n + } + + // We're going to be searching for a wild node next, + // in this case, we need to get the tail + var label = search[0] + var segTail byte + var segEndIdx int + var segTyp nodeTyp + var segRexpat string + if label == '{' || label == '*' { + segTyp, _, segRexpat, segTail, _, segEndIdx = patNextSegment(search) + } + + var prefix string + if segTyp == ntRegexp { + prefix = segRexpat + } + + // Look for the edge to attach to + parent = n + n = n.getEdge(segTyp, label, segTail, prefix) + + // No edge, create one + if n == nil { + child := &node{label: label, tail: segTail, prefix: search} + hn := parent.addChild(child, search) + hn.setEndpoint(method, handler, pattern) + + return hn + } + + // Found an edge to match the pattern + + if n.typ > ntStatic { + // We found a param node, trim the param from the search path and continue. + // This param/wild pattern segment would already be on the tree from a previous + // call to addChild when creating a new node. 
+ search = search[segEndIdx:] + continue + } + + // Static nodes fall below here. + // Determine longest prefix of the search key on match. + commonPrefix := longestPrefix(search, n.prefix) + if commonPrefix == len(n.prefix) { + // the common prefix is as long as the current node's prefix we're attempting to insert. + // keep the search going. + search = search[commonPrefix:] + continue + } + + // Split the node + child := &node{ + typ: ntStatic, + prefix: search[:commonPrefix], + } + parent.replaceChild(search[0], segTail, child) + + // Restore the existing node + n.label = n.prefix[commonPrefix] + n.prefix = n.prefix[commonPrefix:] + child.addChild(n, n.prefix) + + // If the new key is a subset, set the method/handler on this node and finish. + search = search[commonPrefix:] + if len(search) == 0 { + child.setEndpoint(method, handler, pattern) + return child + } + + // Create a new edge for the node + subchild := &node{ + typ: ntStatic, + label: search[0], + prefix: search, + } + hn := child.addChild(subchild, search) + hn.setEndpoint(method, handler, pattern) + return hn + } +} + +// addChild appends the new `child` node to the tree using the `pattern` as the trie key. +// For a URL router like chi's, we split the static, param, regexp and wildcard segments +// into different nodes. In addition, addChild will recursively call itself until every +// pattern segment is added to the url pattern tree as individual nodes, depending on type. +func (n *node) addChild(child *node, prefix string) *node { + search := prefix + + // handler leaf node added to the tree is the child. 
+ // this may be overridden later down the flow + hn := child + + // Parse next segment + segTyp, _, segRexpat, segTail, segStartIdx, segEndIdx := patNextSegment(search) + + // Add child depending on next up segment + switch segTyp { + + case ntStatic: + // Search prefix is all static (that is, has no params in path) + // noop + + default: + // Search prefix contains a param, regexp or wildcard + + if segTyp == ntRegexp { + rex, err := regexp.Compile(segRexpat) + if err != nil { + panic(fmt.Sprintf("chi: invalid regexp pattern '%s' in route param", segRexpat)) + } + child.prefix = segRexpat + child.rex = rex + } + + if segStartIdx == 0 { + // Route starts with a param + child.typ = segTyp + + if segTyp == ntCatchAll { + segStartIdx = -1 + } else { + segStartIdx = segEndIdx + } + if segStartIdx < 0 { + segStartIdx = len(search) + } + child.tail = segTail // for params, we set the tail + + if segStartIdx != len(search) { + // add static edge for the remaining part, split the end. + // it's not possible to have adjacent param nodes, so it's certainly + // going to be a static node next.
+ + search = search[segStartIdx:] // advance search position + + nn := &node{ + typ: ntStatic, + label: search[0], + prefix: search, + } + hn = child.addChild(nn, search) + } + + } else if segStartIdx > 0 { + // Route has some param + + // starts with a static segment + child.typ = ntStatic + child.prefix = search[:segStartIdx] + child.rex = nil + + // add the param edge node + search = search[segStartIdx:] + + nn := &node{ + typ: segTyp, + label: search[0], + tail: segTail, + } + hn = child.addChild(nn, search) + + } + } + + n.children[child.typ] = append(n.children[child.typ], child) + n.children[child.typ].Sort() + return hn +} + +func (n *node) replaceChild(label, tail byte, child *node) { + for i := 0; i < len(n.children[child.typ]); i++ { + if n.children[child.typ][i].label == label && n.children[child.typ][i].tail == tail { + n.children[child.typ][i] = child + n.children[child.typ][i].label = label + n.children[child.typ][i].tail = tail + return + } + } + panic("chi: replacing missing child") +} + +func (n *node) getEdge(ntyp nodeTyp, label, tail byte, prefix string) *node { + nds := n.children[ntyp] + for i := range nds { + if nds[i].label == label && nds[i].tail == tail { + if ntyp == ntRegexp && nds[i].prefix != prefix { + continue + } + return nds[i] + } + } + return nil +} + +func (n *node) setEndpoint(method methodTyp, handler http.Handler, pattern string) { + // Set the handler for the method type on the node + if n.endpoints == nil { + n.endpoints = make(endpoints) + } + + paramKeys := patParamKeys(pattern) + + if method&mSTUB == mSTUB { + n.endpoints.Value(mSTUB).handler = handler + } + if method&mALL == mALL { + h := n.endpoints.Value(mALL) + h.handler = handler + h.pattern = pattern + h.paramKeys = paramKeys + for _, m := range methodMap { + h := n.endpoints.Value(m) + h.handler = handler + h.pattern = pattern + h.paramKeys = paramKeys + } + } else { + h := n.endpoints.Value(method) + h.handler = handler + h.pattern = pattern + h.paramKeys = 
paramKeys + } +} + +func (n *node) FindRoute(rctx *Context, method methodTyp, path string) (*node, endpoints, http.Handler) { + // Reset the context routing pattern and params + rctx.routePattern = "" + rctx.routeParams.Keys = rctx.routeParams.Keys[:0] + rctx.routeParams.Values = rctx.routeParams.Values[:0] + + // Find the routing handlers for the path + rn := n.findRoute(rctx, method, path) + if rn == nil { + return nil, nil, nil + } + + // Record the routing params in the request lifecycle + rctx.URLParams.Keys = append(rctx.URLParams.Keys, rctx.routeParams.Keys...) + rctx.URLParams.Values = append(rctx.URLParams.Values, rctx.routeParams.Values...) + + // Record the routing pattern in the request lifecycle + if rn.endpoints[method].pattern != "" { + rctx.routePattern = rn.endpoints[method].pattern + rctx.RoutePatterns = append(rctx.RoutePatterns, rctx.routePattern) + } + + return rn, rn.endpoints, rn.endpoints[method].handler +} + +// Recursive edge traversal by checking all nodeTyp groups along the way. +// It's like searching through a multi-dimensional radix trie. 
+func (n *node) findRoute(rctx *Context, method methodTyp, path string) *node { + nn := n + search := path + + for t, nds := range nn.children { + ntyp := nodeTyp(t) + if len(nds) == 0 { + continue + } + + var xn *node + xsearch := search + + var label byte + if search != "" { + label = search[0] + } + + switch ntyp { + case ntStatic: + xn = nds.findEdge(label) + if xn == nil || !strings.HasPrefix(xsearch, xn.prefix) { + continue + } + xsearch = xsearch[len(xn.prefix):] + + case ntParam, ntRegexp: + // short-circuit and return no matching route for empty param values + if xsearch == "" { + continue + } + + // serially loop through each node grouped by the tail delimiter + for _, xn = range nds { + // label for param nodes is the delimiter byte + p := strings.IndexByte(xsearch, xn.tail) + + if p < 0 { + if xn.tail == '/' { + p = len(xsearch) + } else { + continue + } + } else if ntyp == ntRegexp && p == 0 { + continue + } + + if ntyp == ntRegexp && xn.rex != nil { + if !xn.rex.MatchString(xsearch[:p]) { + continue + } + } else if strings.IndexByte(xsearch[:p], '/') != -1 { + // avoid a match across path segments + continue + } + + prevlen := len(rctx.routeParams.Values) + rctx.routeParams.Values = append(rctx.routeParams.Values, xsearch[:p]) + xsearch = xsearch[p:] + + if len(xsearch) == 0 { + if xn.isLeaf() { + h := xn.endpoints[method] + if h != nil && h.handler != nil { + rctx.routeParams.Keys = append(rctx.routeParams.Keys, h.paramKeys...) 
+ return xn + } + + for endpoints := range xn.endpoints { + if endpoints == mALL || endpoints == mSTUB { + continue + } + rctx.methodsAllowed = append(rctx.methodsAllowed, endpoints) + } + + // flag that the routing context found a route, but not a corresponding + // supported method + rctx.methodNotAllowed = true + } + } + + // recursively find the next node on this branch + fin := xn.findRoute(rctx, method, xsearch) + if fin != nil { + return fin + } + + // not found on this branch, reset vars + rctx.routeParams.Values = rctx.routeParams.Values[:prevlen] + xsearch = search + } + + rctx.routeParams.Values = append(rctx.routeParams.Values, "") + + default: + // catch-all nodes + rctx.routeParams.Values = append(rctx.routeParams.Values, search) + xn = nds[0] + xsearch = "" + } + + if xn == nil { + continue + } + + // did we find it yet? + if len(xsearch) == 0 { + if xn.isLeaf() { + h := xn.endpoints[method] + if h != nil && h.handler != nil { + rctx.routeParams.Keys = append(rctx.routeParams.Keys, h.paramKeys...) + return xn + } + + for endpoints := range xn.endpoints { + if endpoints == mALL || endpoints == mSTUB { + continue + } + rctx.methodsAllowed = append(rctx.methodsAllowed, endpoints) + } + + // flag that the routing context found a route, but not a corresponding + // supported method + rctx.methodNotAllowed = true + } + } + + // recursively find the next node.. 
+ fin := xn.findRoute(rctx, method, xsearch) + if fin != nil { + return fin + } + + // Did not find final handler, let's remove the param here if it was set + if xn.typ > ntStatic { + if len(rctx.routeParams.Values) > 0 { + rctx.routeParams.Values = rctx.routeParams.Values[:len(rctx.routeParams.Values)-1] + } + } + + } + + return nil +} + +func (n *node) findEdge(ntyp nodeTyp, label byte) *node { + nds := n.children[ntyp] + num := len(nds) + idx := 0 + + switch ntyp { + case ntStatic, ntParam, ntRegexp: + i, j := 0, num-1 + for i <= j { + idx = i + (j-i)/2 + if label > nds[idx].label { + i = idx + 1 + } else if label < nds[idx].label { + j = idx - 1 + } else { + i = num // breaks cond + } + } + if nds[idx].label != label { + return nil + } + return nds[idx] + + default: // catch all + return nds[idx] + } +} + +func (n *node) isLeaf() bool { + return n.endpoints != nil +} + +func (n *node) findPattern(pattern string) bool { + nn := n + for _, nds := range nn.children { + if len(nds) == 0 { + continue + } + + n = nn.findEdge(nds[0].typ, pattern[0]) + if n == nil { + continue + } + + var idx int + var xpattern string + + switch n.typ { + case ntStatic: + idx = longestPrefix(pattern, n.prefix) + if idx < len(n.prefix) { + continue + } + + case ntParam, ntRegexp: + idx = strings.IndexByte(pattern, '}') + 1 + + case ntCatchAll: + idx = longestPrefix(pattern, "*") + + default: + panic("chi: unknown node type") + } + + xpattern = pattern[idx:] + if len(xpattern) == 0 { + return true + } + + return n.findPattern(xpattern) + } + return false +} + +func (n *node) routes() []Route { + rts := []Route{} + + n.walk(func(eps endpoints, subroutes Routes) bool { + if eps[mSTUB] != nil && eps[mSTUB].handler != nil && subroutes == nil { + return false + } + + // Group methodHandlers by unique patterns + pats := make(map[string]endpoints) + + for mt, h := range eps { + if h.pattern == "" { + continue + } + p, ok := pats[h.pattern] + if !ok { + p = endpoints{} + pats[h.pattern] = p + } 
+ p[mt] = h + } + + for p, mh := range pats { + hs := make(map[string]http.Handler) + if mh[mALL] != nil && mh[mALL].handler != nil { + hs["*"] = mh[mALL].handler + } + + for mt, h := range mh { + if h.handler == nil { + continue + } + if m, ok := reverseMethodMap[mt]; ok { + hs[m] = h.handler + } + } + + rt := Route{subroutes, hs, p} + rts = append(rts, rt) + } + + return false + }) + + return rts +} + +func (n *node) walk(fn func(eps endpoints, subroutes Routes) bool) bool { + // Visit the leaf values if any + if (n.endpoints != nil || n.subroutes != nil) && fn(n.endpoints, n.subroutes) { + return true + } + + // Recurse on the children + for _, ns := range n.children { + for _, cn := range ns { + if cn.walk(fn) { + return true + } + } + } + return false +} + +// patNextSegment returns the next segment details from a pattern: +// node type, param key, regexp string, param tail byte, param starting index, param ending index +func patNextSegment(pattern string) (nodeTyp, string, string, byte, int, int) { + ps := strings.Index(pattern, "{") + ws := strings.Index(pattern, "*") + + if ps < 0 && ws < 0 { + return ntStatic, "", "", 0, 0, len(pattern) // we return the entire thing + } + + // Sanity check + if ps >= 0 && ws >= 0 && ws < ps { + panic("chi: wildcard '*' must be the last pattern in a route, otherwise use a '{param}'") + } + + var tail byte = '/' // Default endpoint tail to / byte + + if ps >= 0 { + // Param/Regexp pattern is next + nt := ntParam + + // Read to closing } taking into account opens and closes in curl count (cc) + cc := 0 + pe := ps + for i, c := range pattern[ps:] { + if c == '{' { + cc++ + } else if c == '}' { + cc-- + if cc == 0 { + pe = ps + i + break + } + } + } + if pe == ps { + panic("chi: route param closing delimiter '}' is missing") + } + + key := pattern[ps+1 : pe] + pe++ // set end to next position + + if pe < len(pattern) { + tail = pattern[pe] + } + + key, rexpat, isRegexp := strings.Cut(key, ":") + if isRegexp { + nt = ntRegexp + 
} + + if len(rexpat) > 0 { + if rexpat[0] != '^' { + rexpat = "^" + rexpat + } + if rexpat[len(rexpat)-1] != '$' { + rexpat += "$" + } + } + + return nt, key, rexpat, tail, ps, pe + } + + // Wildcard pattern as finale + if ws < len(pattern)-1 { + panic("chi: wildcard '*' must be the last value in a route. trim trailing text or use a '{param}' instead") + } + return ntCatchAll, "*", "", 0, ws, len(pattern) +} + +func patParamKeys(pattern string) []string { + pat := pattern + paramKeys := []string{} + for { + ptyp, paramKey, _, _, _, e := patNextSegment(pat) + if ptyp == ntStatic { + return paramKeys + } + for i := 0; i < len(paramKeys); i++ { + if paramKeys[i] == paramKey { + panic(fmt.Sprintf("chi: routing pattern '%s' contains duplicate param key, '%s'", pattern, paramKey)) + } + } + paramKeys = append(paramKeys, paramKey) + pat = pat[e:] + } +} + +// longestPrefix finds the length of the shared prefix of two strings +func longestPrefix(k1, k2 string) (i int) { + for i = 0; i < min(len(k1), len(k2)); i++ { + if k1[i] != k2[i] { + break + } + } + return +} + +type nodes []*node + +// Sort the list of nodes by label +func (ns nodes) Sort() { sort.Sort(ns); ns.tailSort() } +func (ns nodes) Len() int { return len(ns) } +func (ns nodes) Swap(i, j int) { ns[i], ns[j] = ns[j], ns[i] } +func (ns nodes) Less(i, j int) bool { return ns[i].label < ns[j].label } + +// tailSort pushes nodes with '/' as the tail to the end of the list for param nodes. +// The list order determines the traversal order. 
+func (ns nodes) tailSort() { + for i := len(ns) - 1; i >= 0; i-- { + if ns[i].typ > ntStatic && ns[i].tail == '/' { + ns.Swap(i, len(ns)-1) + return + } + } +} + +func (ns nodes) findEdge(label byte) *node { + num := len(ns) + idx := 0 + i, j := 0, num-1 + for i <= j { + idx = i + (j-i)/2 + if label > ns[idx].label { + i = idx + 1 + } else if label < ns[idx].label { + j = idx - 1 + } else { + i = num // breaks cond + } + } + if ns[idx].label != label { + return nil + } + return ns[idx] +} + +// Route describes the details of a routing handler. +// Handlers map key is an HTTP method +type Route struct { + SubRoutes Routes + Handlers map[string]http.Handler + Pattern string +} + +// WalkFunc is the type of the function called for each method and route visited by Walk. +type WalkFunc func(method string, route string, handler http.Handler, middlewares ...func(http.Handler) http.Handler) error + +// Walk walks any router tree that implements Routes interface. +func Walk(r Routes, walkFn WalkFunc) error { + return walk(r, walkFn, "") +} + +func walk(r Routes, walkFn WalkFunc, parentRoute string, parentMw ...func(http.Handler) http.Handler) error { + for _, route := range r.Routes() { + mws := slices.Concat(parentMw, r.Middlewares()) + + if route.SubRoutes != nil { + if handler, ok := route.Handlers["*"]; ok { + if chain, ok := handler.(*ChainHandler); ok { + mws = append(mws, chain.Middlewares...) + } + } + + if err := walk(route.SubRoutes, walkFn, parentRoute+route.Pattern, mws...); err != nil { + return err + } + continue + } + + for method, handler := range route.Handlers { + if method == "*" { + // Ignore a "catchAll" method, since we pass down all the specific methods for each route. 
+ continue + } + + fullRoute := parentRoute + route.Pattern + fullRoute = strings.ReplaceAll(fullRoute, "/*/", "/") + + if chain, ok := handler.(*ChainHandler); ok { + if err := walkFn(method, fullRoute, chain.Endpoint, append(mws, chain.Middlewares...)...); err != nil { + return err + } + } else { + if err := walkFn(method, fullRoute, handler, mws...); err != nil { + return err + } + } + } + } + + return nil +} diff --git a/testdata/fixture/go.mod b/testdata/fixture/go.mod deleted file mode 100644 index 30818c1..0000000 --- a/testdata/fixture/go.mod +++ /dev/null @@ -1,3 +0,0 @@ -module example.com/fixture - -go 1.21 diff --git a/testdata/generics/fn/fn.go b/testdata/generics/fn/fn.go new file mode 100644 index 0000000..a18a43b --- /dev/null +++ b/testdata/generics/fn/fn.go @@ -0,0 +1,46 @@ +package fn + +// Ordered is a union-constraint interface for comparable ordered primitives. +type Ordered interface { + ~int | ~int64 | ~float64 | ~string +} + +// Numeric extends Ordered with unsigned integer kinds. +type Numeric interface { + Ordered + ~uint | ~uint64 +} + +func Min[T Ordered](a, b T) T { + if a < b { + return a + } + return b +} + +func Max[T Ordered](a, b T) T { + if a > b { + return a + } + return b +} + +// Map applies f to every element of in and returns the results. +func Map[T, U any](in []T, f func(T) U) []U { + out := make([]U, len(in)) + for i, v := range in { + out[i] = f(v) + } + return out +} + +// Filter returns the elements of in for which keep returns true. 
+func Filter[T any](in []T, keep func(T) bool) []T { + var out []T + for _, v := range in { + if keep(v) { + out = append(out, v) + } + } + return out +} diff --git a/testdata/generics/go.mod b/testdata/generics/go.mod new file mode 100644 index 0000000..3ad2259 --- /dev/null +++ b/testdata/generics/go.mod @@ -0,0 +1,3 @@ +module example.com/generics + +go 1.21 diff --git a/testdata/generics/main.go b/testdata/generics/main.go new file mode 100644 index 0000000..0ba0512 --- /dev/null +++ b/testdata/generics/main.go @@ -0,0 +1,22 @@ +package main + +import ( + "fmt" + + "example.com/generics/fn" + "example.com/generics/set" +) + +func main() { + s := set.New[string]() + s.Add("hello") + s.Add("world") + fmt.Println(s.Contains("hello"), s.Len()) + + fmt.Println(fn.Min(3, 7)) + fmt.Println(fn.Max(3.14, 2.72)) + + nums := fn.Map([]int{1, 2, 3}, func(x int) string { return fmt.Sprintf("%d", x) }) + evens := fn.Filter([]int{1, 2, 3, 4}, func(x int) bool { return x%2 == 0 }) + fmt.Println(nums, evens) +} diff --git a/testdata/generics/set/set.go b/testdata/generics/set/set.go new file mode 100644 index 0000000..c5039d4 --- /dev/null +++ b/testdata/generics/set/set.go @@ -0,0 +1,36 @@ +package set + +// Set is a generic hash set. +type Set[T comparable] struct { + items map[T]struct{} +} + +func New[T comparable]() *Set[T] { + return &Set[T]{items: make(map[T]struct{})} +} + +func (s *Set[T]) Add(v T) { + s.items[v] = struct{}{} +} + +func (s *Set[T]) Remove(v T) { + delete(s.items, v) +} + +func (s *Set[T]) Contains(v T) bool { + _, ok := s.items[v] + return ok +} + +func (s *Set[T]) Len() int { + return len(s.items) +} + +// Snapshot returns all elements as a slice. Unexported helper for internal use. 
+func (s *Set[T]) snapshot() []T { + out := make([]T, 0, len(s.items)) + for k := range s.items { + out = append(out, k) + } + return out +} diff --git a/testdata/greeter/go.mod b/testdata/greeter/go.mod new file mode 100644 index 0000000..162252e --- /dev/null +++ b/testdata/greeter/go.mod @@ -0,0 +1,3 @@ +module example.com/greeter + +go 1.21 diff --git a/testdata/fixture/main.go b/testdata/greeter/main.go similarity index 82% rename from testdata/fixture/main.go rename to testdata/greeter/main.go index f20e3f9..1436fd2 100644 --- a/testdata/fixture/main.go +++ b/testdata/greeter/main.go @@ -3,7 +3,7 @@ package main import ( "fmt" - "example.com/fixture/pkg/greeter" + "example.com/greeter/pkg/greeter" ) func main() { diff --git a/testdata/fixture/pkg/greeter/greeter.go b/testdata/greeter/pkg/greeter/greeter.go similarity index 100% rename from testdata/fixture/pkg/greeter/greeter.go rename to testdata/greeter/pkg/greeter/greeter.go diff --git a/testdata/multipackage/go.mod b/testdata/multipackage/go.mod new file mode 100644 index 0000000..6a64ad1 --- /dev/null +++ b/testdata/multipackage/go.mod @@ -0,0 +1,3 @@ +module example.com/multipackage + +go 1.21 diff --git a/testdata/realistic/main.go b/testdata/multipackage/main.go similarity index 87% rename from testdata/realistic/main.go rename to testdata/multipackage/main.go index 81ae136..d3d6f89 100644 --- a/testdata/realistic/main.go +++ b/testdata/multipackage/main.go @@ -4,8 +4,8 @@ import ( "fmt" "log" - "example.com/realistic/server" - "example.com/realistic/worker" + "example.com/multipackage/server" + "example.com/multipackage/worker" ) func main() { diff --git a/testdata/realistic/server/middleware.go b/testdata/multipackage/server/middleware.go similarity index 100% rename from testdata/realistic/server/middleware.go rename to testdata/multipackage/server/middleware.go diff --git a/testdata/realistic/server/server.go b/testdata/multipackage/server/server.go similarity index 100% rename from 
testdata/realistic/server/server.go rename to testdata/multipackage/server/server.go diff --git a/testdata/multipackage/server/server_test.go b/testdata/multipackage/server/server_test.go new file mode 100644 index 0000000..351ae34 --- /dev/null +++ b/testdata/multipackage/server/server_test.go @@ -0,0 +1,16 @@ +package server + +import "testing" + +// TestServer_Addr is a minimal test in the multipackage fixture's server package. +// Its only purpose is to give the --skip-tests=false integration test a +// _test.go file to look for in the symbol table. +func TestServer_Addr(t *testing.T) { + s, err := New(Config{Host: "localhost", Port: 8080}) + if err != nil { + t.Fatalf("New: %v", err) + } + if s.Addr() == "" { + t.Error("Addr() returned empty string") + } +} diff --git a/testdata/realistic/worker/worker.go b/testdata/multipackage/worker/worker.go similarity index 100% rename from testdata/realistic/worker/worker.go rename to testdata/multipackage/worker/worker.go diff --git a/testdata/realistic/go.mod b/testdata/realistic/go.mod deleted file mode 100644 index f1da8db..0000000 --- a/testdata/realistic/go.mod +++ /dev/null @@ -1,3 +0,0 @@ -module example.com/realistic - -go 1.21 From 80f7407e908e2fbb7bacaebd32aa07a30d0bb05c Mon Sep 17 00:00:00 2001 From: Rahul Krishna Date: Wed, 1 Jul 2026 19:27:59 -0400 Subject: [PATCH 4/4] chore: add CLAUDE.md agent guidance and AGENTS.md/GEMINI.md symlinks Closes #1 --- AGENTS.md | 1 + CLAUDE.md | 132 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ GEMINI.md | 1 + 3 files changed, 134 insertions(+) create mode 120000 AGENTS.md create mode 100644 CLAUDE.md create mode 120000 GEMINI.md diff --git a/AGENTS.md b/AGENTS.md new file mode 120000 index 0000000..681311e --- /dev/null +++ b/AGENTS.md @@ -0,0 +1 @@ +CLAUDE.md \ No newline at end of file diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..16816aa --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,132 @@ +# CLAUDE.md + +Agent guidance for
`codellm-devkit/codeanalyzer-go` (`codeanalyzer-go`). + +Respect the global `~/.claude/CLAUDE.md` instructions strictly. + +## What this project is + +`codeanalyzer-go` is the CLDK Go static analyzer. It emits the canonical CLDK +`analysis.json` — a **symbol table** plus a **call graph** — consumable by the Python SDK +via `CLDK(language="go").analysis(project_path=...)`. It mirrors its +[Python](https://github.com/codellm-devkit/codeanalyzer-python) (`canpy`), +[TypeScript](https://github.com/codellm-devkit/codeanalyzer-typescript) (`cants`), and +[Java](https://github.com/codellm-devkit/codeanalyzer-java) sibling analyzers, so +output-shape parity with them is a first-class concern. + +It builds on **`golang.org/x/tools/go/packages`** (loaded with syntax + types + deps) plus +stdlib `go/ast`, `go/token`, and `go/types`. The call graph is a hand-rolled, CHA-style +**resolver over `go/types`** (declared-type dispatch) — it deliberately does *not* use +`go/ssa` or `x/tools/go/callgraph`. Edges are emitted only for project-internal callees; +external/stdlib callees get their `callee_signature` backfilled but no edge (mirroring +Python/Jedi). + +> **Status — read this first.** This is the newest backend. Implemented: the level-1 +> symbol table, the level-2 resolver call graph, `go mod` materialization with caching, the +> cobra CLI, incremental `--target-files`, and a pluggable pass framework. **Not yet +> implemented** (be honest about these; don't describe them as working): the **CodeQL** +> provider (`--codeql` is wired but `codeql.*` returns `ErrCodeQLNotImplemented`), +> **msgpack** output, **framework entrypoint finders** (no passes are registered, so +> `entrypoints` in the output is effectively always `{}`), and **Neo4j** projection +> (there is no Neo4j code here at all — JSON is the only output). The implementation +> currently lives on `feat/initial-implementation`; `main` is a stub. 
+ +## Architecture — follow the pipeline + +The whole analyzer is one orchestrator: `Analyzer.Analyze()` in +`internal/core/analyzer.go` (a pure delegator, mirroring Python's `core.py`). Read it +first; everything else is a phase it calls, in order: + +1. **materialize** — `Analyzer.materialize()` runs `go mod download` (skipped without a + `go.mod`; cached by SHA-256 of `go.sum`; failures degrade gracefully). +2. **symbol table** (`internal/syntactic_analysis`) — + `NewSymbolTableBuilder(input).Build(targetFiles, skipTests)`. +3. **call graph** (`internal/semantic_analysis`, `Level >= 2` only) — + `NewCallGraphBuilder(...).Build(symbolTable)` resolves each call site via `go/types`, + backfills `callee_signature`, and emits internal `GoCallEdge`s. +4. **pass pipeline** — `analysis.RunPipeline(app, ctx)` runs topologically-ordered + pluggable passes (none registered yet). +5. **optional CodeQL** (`--codeql` only) — currently a stub; would merge via `MergeEdges`. + +Then `finalizeAndCache()` writes `{cache-dir}/analysis_cache.json`, and +`core.WriteOutput()` writes `{output-dir}/analysis.json` (or stdout). + +The output shape is the **structs in `internal/schema/schema.go`** (`GoApplication` is the +top type; JSON keys are snake_case for Pydantic parity).
+ +## Directory map + +| Path | Responsibility | +|------|----------------| +| `cmd/codeanalyzer/main.go` | Entry point + cobra CLI (`rootCmd`), flag parsing | +| `internal/core/analyzer.go` | `Analyzer.Analyze()` orchestrator — the spine; `WriteOutput` | +| `internal/options/options.go` | `AnalysisOptions` + `AnalysisLevel` (`LevelSymbolTable=1`, `LevelCallGraph=2`) | +| `internal/schema/schema.go` | `GoApplication` structs (the output contract) | +| `internal/syntactic_analysis` | Symbol table (`go/packages` + `go/ast`); `signature.go` = canonical signatures; `export.go` = `Fset()`/`Pkgs()` | +| `internal/semantic_analysis` | Resolver call graph (`call_graph.go`, `go/types`); `codeql/` = CodeQL backend (stub) | +| `internal/analysis` | Pluggable pass framework: `pass.go` (interface), `registry.go` (`RegisterPass`, topo-ordered `RunPipeline`) | +| `internal/frameworks` | Entrypoint-finder base (no concrete finders yet) | +| `internal/utils` | `fs.go` (file discovery, hashing), `logging.go` | +| `testdata/{greeter,multipackage,generics,chi}` | Test fixtures, each with its own `go.mod` | + +## Commands + +Module `github.com/codellm-devkit/codeanalyzer-go`, **Go 1.25+**. No Makefile, no +golangci-lint config. + +- `go build -o codeanalyzer-go ./cmd/codeanalyzer` — build the binary. +- `go run ./cmd/codeanalyzer -i /path/to/project -a 2` — run from source + (`-a 1` = symbol table only, `-a 2` adds the call graph; `-o` outdir, `-t` target files, + `--eager`, `-v`). Default cache dir `~/.cldk/go-cache`. +- `go test ./...` — run tests (force re-run: `go clean -testcache && go test ./...`). +- `go vet ./...` — the only static-check wired up (no linter configured). + +## I implement features myself — you assist + +For feature work, **I write the implementation** to stay fluent in my own analyzer. +Act as a helper, not the author: + +- **Don't write the feature code** or apply edits to implement it unless I explicitly + ask ("write this", "implement X", "apply it"). 
Default to guiding, not doing. +- **Do** move me fast: explain the relevant phase, point at prior art (e.g. the Python or + Java backend's equivalent stage, or the resolver in `semantic_analysis/call_graph.go`), + sketch signatures/types, outline an approach, and answer questions about the codebase. +- **Review on request:** when I share a diff or push, critique it — correctness, + **parity with the Python/Java/TypeScript backends**, schema shape, missing tests, edge + cases — and suggest concrete improvements. +- Scaffolding like tests or boilerplate is fine **when I ask**; otherwise leave the + keyboard to me. +- If you think I'm about to go wrong, say so briefly and let me decide — don't pre-empt + by implementing the fix. + +## Rules + +1. **Think before coding.** State assumptions explicitly; ask rather than guess. Push + back when a simpler approach exists. Stop when confused. +2. **Simplicity first.** Guide me toward the minimum idiomatic code that solves the + problem. Nothing speculative; no abstractions for single-use code. +3. **Issue → branch → work → PR.** Every change starts as an issue, on a branch named + `feat/issue-XXX`, `fix/issue-XXX`, `chore/issue-XXX`, and lands via a PR. +4. **Guard the contract.** Changes to `internal/schema` must keep the JSON shape (snake_case + keys, `CALL_DEP` edges, `provenance`) in parity with the sibling analyzers so the Python + SDK can consume Go output interchangeably. + +## Goal-driven execution, as a teaching loop + +Success is measured by the sole fact that **I understand it**. The success criterion: +I can point to the exact line of code where any feature lives, however remote or +obscure, and explain why it's there and how it behaves. + +To that end, be my teacher and a Socratic one — not an answer key: + +- Lead with questions that make me derive the answer; don't hand me the solution. 
+- Verify understanding, not just behavior — have me locate and explain the relevant + LOC, walk edge cases, and predict what a change would do before running it. +- Teach, help improve, and strengthen the weak spots you surface; circle back to them. +- The loop closes when I can **teach it back** and place every feature on a line, not + merely when the tests pass. +- Over the session, frequently — but not so much that I am stymied — ask spaced + repetition questions so concepts are internalized. + +Learning progress is tracked globally, not per-repo: see the SRS deck and the +"continual learning" defaults in `~/.claude/CLAUDE.md`. diff --git a/GEMINI.md b/GEMINI.md new file mode 120000 index 0000000..681311e --- /dev/null +++ b/GEMINI.md @@ -0,0 +1 @@ +CLAUDE.md \ No newline at end of file