mukan-ignite/ignite/pkg/xast/function.go
Mukan Erkin Törük c32551b6f7
Some checks failed
Docs Deploy / build_and_deploy (push) Has been cancelled
Generate Docs / cli (push) Has been cancelled
Generate Config Doc / cli (push) Has been cancelled
Go formatting / go-formatting (push) Has been cancelled
Check links / markdown-link-check (push) Has been cancelled
Integration / pre-test (push) Has been cancelled
Integration / test on (push) Has been cancelled
Integration / status (push) Has been cancelled
Lint / Lint Go code (push) Has been cancelled
Test / test (ubuntu-latest) (push) Has been cancelled
refactor: replace all github.com upstream refs with git.cw.tr/mukan-network
2026-05-11 03:36:24 +03:00

1035 lines
29 KiB
Go

package xast
import (
"bytes"
"fmt"
"go/ast"
"go/format"
"go/parser"
"go/token"
"sort"
"strings"
"git.cw.tr/mukan-network/mukan-ignite/ignite/pkg/errors"
)
type (
// functionOpts collects every modification requested for a single function.
// It is filled by applying FunctionOptions values to it.
functionOpts struct {
newParams []functionParam // Parameters to add to the function.
body string // New function body content.
newLines []functionLine // Lines to insert at specific positions.
insideCall functionCalls // Function calls to modify.
insideStruct functionStructs // Struct literals to modify.
appendTestCase []string // Test cases to append.
appendCode []string // Code to append at the end.
returnVars []string // Return variables to modify.
appendSwitch functionSwitches // Switch cases to append.
removeCalls []string // Function calls to remove.
}
// FunctionOptions is a functional option that records one modification
// in a functionOpts value.
FunctionOptions func(*functionOpts)
// functionSwitch represents a case to add to a switch statement.
functionSwitch struct {
condition string // Condition to find.
switchCase string // Switch case to insert.
switchBody string // Code to insert.
}
// functionSwitches is a list of switch cases to add.
functionSwitches []functionSwitch
// functionSwitchesMap groups switch cases by switch condition.
functionSwitchesMap map[string]functionSwitches
// functionStruct represents a struct literal to modify.
functionStruct struct {
name string // Name of the struct type.
param string // Name of the struct field.
code string // Code to insert.
}
// functionStructs is a list of struct literal modifications.
functionStructs []functionStruct
// functionStructsMap groups struct literal modifications by struct name.
functionStructsMap map[string]functionStructs
// functionCall represents a function call to modify.
functionCall struct {
name string // Name of the function.
code string // Code to insert.
index int // Position to insert at.
}
// functionCalls is a list of function call modifications.
functionCalls []functionCall
// functionCallsMap groups function call modifications by call name.
functionCallsMap map[string]functionCalls
// functionParam represents a parameter to add to a function.
functionParam struct {
name string // Parameter name.
varType string // Parameter type.
index int // Position to insert at.
}
// functionLine represents a line of code to insert.
functionLine struct {
code string // Code to insert.
number uint64 // Line number to insert at.
}
)
// Field builds the *ast.Field for this parameter: one name identifier paired
// with an identifier that carries the type text verbatim.
func (p functionParam) Field() *ast.Field {
	paramName := ast.NewIdent(p.name)
	paramType := ast.NewIdent(p.varType)
	return &ast.Field{
		Names: []*ast.Ident{paramName},
		Type:  paramType,
	}
}
// Map groups the struct literal modifications by struct name. Entries that
// share a name keep the order they have in the slice.
func (s functionStructs) Map() functionStructsMap {
	grouped := make(functionStructsMap)
	for i := range s {
		entry := s[i]
		// Appending to a missing key starts a new group.
		grouped[entry.name] = append(grouped[entry.name], entry)
	}
	return grouped
}
// Map groups the switch cases by switch condition. Entries that share a
// condition keep the order they have in the slice.
func (s functionSwitches) Map() functionSwitchesMap {
	grouped := make(functionSwitchesMap)
	for i := range s {
		entry := s[i]
		// Appending to a missing key starts a new group.
		grouped[entry.condition] = append(grouped[entry.condition], entry)
	}
	return grouped
}
// Map groups the function call modifications by call name. Entries that share
// a name keep the order they have in the slice.
func (c functionCalls) Map() functionCallsMap {
	grouped := make(functionCallsMap)
	for i := range c {
		call := c[i]
		// Appending to a missing key starts a new group.
		grouped[call.name] = append(grouped[call.name], call)
	}
	return grouped
}
// cloneFunctionCallsMap returns a new map with the same keys and values as
// src. The copy is shallow: slice values are shared with src.
func cloneFunctionCallsMap(src functionCallsMap) functionCallsMap {
	clone := make(functionCallsMap, len(src))
	for name := range src {
		clone[name] = src[name]
	}
	return clone
}
// cloneFunctionStructsMap returns a new map with the same keys and values as
// src. The copy is shallow: slice values are shared with src.
func cloneFunctionStructsMap(src functionStructsMap) functionStructsMap {
	clone := make(functionStructsMap, len(src))
	for name := range src {
		clone[name] = src[name]
	}
	return clone
}
// cloneFunctionSwitchesMap returns a new map with the same keys and values as
// src. The copy is shallow: slice values are shared with src.
func cloneFunctionSwitchesMap(src functionSwitchesMap) functionSwitchesMap {
	clone := make(functionSwitchesMap, len(src))
	for condition := range src {
		clone[condition] = src[condition]
	}
	return clone
}
// AppendFuncParams records a new parameter to add to the function signature.
// An index of -1 appends the parameter at the end; an index between 0 and the
// current parameter count inserts it at that position.
func AppendFuncParams(name, varType string, index int) FunctionOptions {
	param := functionParam{name: name, varType: varType, index: index}
	return func(opts *functionOpts) {
		opts.newParams = append(opts.newParams, param)
	}
}
// ReplaceFuncBody records a complete replacement for the function body. The
// body is swapped in first; line insertions, appended code, return changes
// and the remaining options are then applied to the new body.
func ReplaceFuncBody(body string) FunctionOptions {
	return func(opts *functionOpts) { opts.body = body }
}
// AppendFuncTestCase records a test case to append to a test function. The
// case is added to the composite literal assigned to a variable named "tests"
// in the function body; nothing is added when no such assignment exists.
func AppendFuncTestCase(testCase string) FunctionOptions {
	return func(opts *functionOpts) {
		cases := append(opts.appendTestCase, testCase)
		opts.appendTestCase = cases
	}
}
// AppendFuncCode records code to add at the end of the function. When the
// function's last statement is a return, the code is placed just before it.
func AppendFuncCode(code string) FunctionOptions {
	return func(opts *functionOpts) {
		snippets := append(opts.appendCode, code)
		opts.appendCode = snippets
	}
}
// AppendFuncCodeAtLine is an alias of AppendFuncAtLine; both names refer to
// the same option constructor.
var AppendFuncCodeAtLine = AppendFuncAtLine
// AppendFuncAtLine records code to insert into the function body. lineNumber
// is a zero-based index into the body's top-level statement list, not a
// source line: the parsed code is placed before the statement at that index.
func AppendFuncAtLine(code string, lineNumber uint64) FunctionOptions {
	line := functionLine{code: code, number: lineNumber}
	return func(opts *functionOpts) {
		opts.newLines = append(opts.newLines, line)
	}
}
// AppendInsideFuncCall records an argument to add to a function call. For
// instance, given the call 'New(param1, param2)', adding 'param3' at the end
// results in 'New(param1, param2, param3)'.
//
// The callName parameter can be either:
//   - Simple name: "NewKeeper" matches any call to NewKeeper regardless of package/receiver
//   - Qualified name: "foo.NewKeeper" matches only calls to NewKeeper with foo as the package/receiver
//
// The code parameter is the argument to insert, and index specifies the position:
//   - index >= 0: insert at the specified position
//   - index == -1: append at the end
func AppendInsideFuncCall(callName, code string, index int) FunctionOptions {
	call := functionCall{name: callName, code: code, index: index}
	return func(opts *functionOpts) {
		opts.insideCall = append(opts.insideCall, call)
	}
}
// AppendFuncStruct records a field to add to a struct literal. For instance,
// given 'Params{Param1: param1}', adding param2 results in
// 'Params{Param1: param1, Param2: param2}'. When param is empty the code is
// added as a plain element without a key.
//
// The name parameter can be either:
//   - Simple name: "Keeper" matches any struct literal of type Keeper regardless of package
//   - Qualified name: "keeper.Keeper" matches only struct literals with keeper as the package
func AppendFuncStruct(name, param, code string) FunctionOptions {
	field := functionStruct{name: name, param: param, code: code}
	return func(opts *functionOpts) {
		opts.insideStruct = append(opts.insideStruct, field)
	}
}
// NewFuncReturn records the expressions that replace the results of the
// function's final return statement. Applying the option fails when the
// function body does not end with a return statement.
func NewFuncReturn(returnVars ...string) FunctionOptions {
	return func(opts *functionOpts) {
		for i := range returnVars {
			opts.returnVars = append(opts.returnVars, returnVars[i])
		}
	}
}
// AppendSwitchCase records a new case to add to a switch statement. The
// condition is compared with the formatted tag (or type-switch assignment) of
// the switch statements found directly in the function body; switchCase is
// the case expression and switchBody the code run for that case.
func AppendSwitchCase(condition, switchCase, switchBody string) FunctionOptions {
	entry := functionSwitch{
		condition:  condition,
		switchCase: switchCase,
		switchBody: switchBody,
	}
	return func(opts *functionOpts) {
		opts.appendSwitch = append(opts.appendSwitch, entry)
	}
}
// RemoveFuncCall records a function call to delete from the function body.
// callName is either a simple function name such as "doSomething" or a
// qualified name such as "pkg.DoSomething".
func RemoveFuncCall(callName string) FunctionOptions {
	return func(opts *functionOpts) {
		names := append(opts.removeCalls, callName)
		opts.removeCalls = names
	}
}
// newFunctionOptions returns a functionOpts whose list fields start as empty,
// non-nil slices. body stays empty and appendSwitch is left at its zero value
// (nil), which the append-based options handle.
func newFunctionOptions() functionOpts {
	var opts functionOpts
	opts.newParams = make([]functionParam, 0)
	opts.newLines = make([]functionLine, 0)
	opts.insideCall = make(functionCalls, 0)
	opts.insideStruct = make(functionStructs, 0)
	opts.appendTestCase = make([]string, 0)
	opts.appendCode = make([]string, 0)
	opts.returnVars = make([]string, 0)
	opts.removeCalls = make([]string, 0)
	return opts
}
// ModifyFunction parses content as a Go file, applies the given options to
// the function named funcName and returns the re-formatted source.
//
// It returns an error when content does not parse, when no function with that
// name is declared at the top level of the file, or when an option cannot be
// applied.
func ModifyFunction(content string, funcName string, functions ...FunctionOptions) (string, error) {
	// Fold every option into one functionOpts value.
	opts := newFunctionOptions()
	for i := range functions {
		functions[i](&opts)
	}

	fileSet := token.NewFileSet()
	parsed, parseErr := parser.ParseFile(fileSet, "", content, parser.ParseComments)
	if parseErr != nil {
		return "", errors.Errorf("failed to parse file (%s): %w", funcName, parseErr)
	}
	// Build the comment map before mutating the tree; it is used below to
	// recompute the file's comment list after the modifications.
	comments := ast.NewCommentMap(fileSet, parsed, parsed.Comments)

	target := findFuncDecl(parsed, funcName)
	if target == nil {
		return "", errors.Errorf("function %q not found", funcName)
	}

	if applyErr := applyFunctionOptions(fileSet, target, &opts); applyErr != nil {
		return "", applyErr
	}

	parsed.Comments = comments.Filter(parsed).Comments()
	return formatNode(fileSet, parsed)
}
// findFuncDecl returns the first top-level function declaration in file whose
// name equals name, or nil when there is none. Only the name is compared, so
// methods match as well as plain functions.
func findFuncDecl(file *ast.File, name string) *ast.FuncDecl {
	for i := range file.Decls {
		candidate, isFunc := file.Decls[i].(*ast.FuncDecl)
		if !isFunc {
			continue
		}
		if candidate.Name.Name == name {
			return candidate
		}
	}
	return nil
}
// addCode parses each snippet in appendCode and returns all resulting
// statements as one flat list, in input order. The first snippet that fails
// to parse aborts the conversion and its error is returned.
func addCode(fileSet *token.FileSet, appendCode []string) ([]ast.Stmt, error) {
	stmts := make([]ast.Stmt, 0, len(appendCode))
	for i := range appendCode {
		block, err := codeToBlockStmt(fileSet, appendCode[i])
		if err != nil {
			return nil, err
		}
		stmts = append(stmts, block.List...)
	}
	return stmts, nil
}
// modifyReturnVars parses every entry of returnVars as a Go expression and
// returns the expressions in input order. The first entry that fails to parse
// aborts the conversion and its error is returned.
func modifyReturnVars(fileSet *token.FileSet, returnVars []string) ([]ast.Expr, error) {
	exprs := make([]ast.Expr, len(returnVars))
	for i := range returnVars {
		parsed, err := parser.ParseExprFrom(fileSet, "", []byte(returnVars[i]), parser.ParseComments)
		if err != nil {
			return nil, err
		}
		exprs[i] = parsed
	}
	return exprs, nil
}
// appendSwitchCase adds one case clause per entry of fs to the switch
// statement stmt, which must be an *ast.SwitchStmt or *ast.TypeSwitchStmt.
// Each entry's case expression and body are parsed from text; a parse failure
// or an unsupported statement type is returned as an error.
func appendSwitchCase(fileSet *token.FileSet, stmt ast.Node, fs functionSwitches) error {
	for i := range fs {
		caseExpr, err := parser.ParseExprFrom(fileSet, "", []byte(fs[i].switchCase), parser.ParseComments)
		if err != nil {
			return err
		}
		caseBody, err := codeToBlockStmt(fileSet, fs[i].switchBody)
		if err != nil {
			return err
		}

		// The clause carries no explicit positions for its keyword and colon.
		clause := &ast.CaseClause{
			List:  []ast.Expr{caseExpr},
			Body:  caseBody.List,
			Case:  token.NoPos,
			Colon: token.NoPos,
		}

		// Both switch kinds keep their clauses in a block statement.
		var clauses *ast.BlockStmt
		switch target := stmt.(type) {
		case *ast.TypeSwitchStmt:
			clauses = target.Body
		case *ast.SwitchStmt:
			clauses = target.Body
		default:
			return errors.Errorf("unsupported switch statement type: %T", stmt)
		}
		clauses.List = appendCaseToList(clauses.List, clause)
	}
	return nil
}
// appendCaseToList returns list with newCase added. When the last entry of
// list is a default clause (a case clause with no expressions), newCase is
// placed just before it; otherwise newCase goes at the end.
func appendCaseToList(list []ast.Stmt, newCase *ast.CaseClause) []ast.Stmt {
	if len(list) == 0 {
		return append(list, newCase)
	}
	tailIdx := len(list) - 1
	tail, isCase := list[tailIdx].(*ast.CaseClause)
	if !isCase || len(tail.List) != 0 {
		return append(list, newCase)
	}
	// Keep the default clause last.
	return append(list[:tailIdx], newCase, list[tailIdx])
}
// addParams inserts each entry of newParams into the parameter list of
// funcDecl, in order. An index of -1 appends; an index from 0 up to the
// current parameter count inserts at that position; any other index yields an
// error. Each insertion sees the list as left by the previous one.
func addParams(funcDecl *ast.FuncDecl, newParams []functionParam) error {
	for _, param := range newParams {
		fields := funcDecl.Type.Params.List
		if param.index == -1 {
			funcDecl.Type.Params.List = append(fields, param.Field())
			continue
		}
		if param.index < 0 || param.index > len(fields) {
			return errors.Errorf("params index %d out of range", param.index)
		}
		// Build the tail in a fresh slice so the original elements are not
		// overwritten while they are still being read.
		tail := append([]*ast.Field{param.Field()}, fields[param.index:]...)
		funcDecl.Type.Params.List = append(fields[:param.index], tail...)
	}
	return nil
}
// addNewLine inserts parsed code into the body of funcDecl. Each entry's
// number is a zero-based index into the body's top-level statement list; the
// parsed statements are placed before the statement at that index. Entries
// are applied in order, each against the list left by the previous one.
//
// The highest accepted index is len(body)-1, or 0 for an empty body. The
// bound is computed without unsigned subtraction: the previous
// uint64(len)-1 wrapped around on an empty body, which let any index through
// validation and then panicked in the slice expression.
//
// It returns an error when an index is out of range, when the code does not
// parse, or when lines are requested for a function that has no body.
func addNewLine(fileSet *token.FileSet, funcDecl *ast.FuncDecl, newLines []functionLine) error {
	if len(newLines) == 0 {
		return nil
	}
	if funcDecl.Body == nil {
		return errors.New("function has no body to insert lines into")
	}
	for _, newLine := range newLines {
		var maxLine uint64
		if n := len(funcDecl.Body.List); n > 0 {
			maxLine = uint64(n - 1)
		}
		if newLine.number > maxLine {
			return errors.Errorf("line number %d out of range (max %d)", newLine.number, maxLine)
		}

		insertion, err := codeToBlockStmt(fileSet, newLine.code)
		if err != nil {
			return err
		}

		// Build the tail first so the statements after the insertion point
		// are copied before the head slice is extended over them.
		tail := append(insertion.List, funcDecl.Body.List[newLine.number:]...)
		funcDecl.Body.List = append(funcDecl.Body.List[:newLine.number], tail...)
	}
	return nil
}
// modifyReturn updates the tail of the function body.
//
// When the body ends with a return statement, its results are replaced by
// returnStmts (if any) and appendCode (if any) is placed just before it.
// Otherwise appendCode is added at the end of the body, and asking for new
// return values is an error because there is no return statement to change.
func modifyReturn(funcDecl *ast.FuncDecl, returnStmts []ast.Expr, appendCode []ast.Stmt) error {
	body := funcDecl.Body

	var (
		ret      *ast.ReturnStmt
		isReturn bool
	)
	if n := len(body.List); n > 0 {
		ret, isReturn = body.List[n-1].(*ast.ReturnStmt)
	}

	if !isReturn {
		if len(returnStmts) > 0 {
			return errors.New("return statement not found")
		}
		if len(appendCode) > 0 {
			body.List = append(body.List, appendCode...)
		}
		return nil
	}

	if len(returnStmts) > 0 {
		ret.Results = returnStmts
	}
	if len(appendCode) > 0 {
		// Re-attach the return after the new code so it stays last.
		appendCode = append(appendCode, ret)
		body.List = append(body.List[:len(body.List)-1], appendCode...)
	}
	return nil
}
// addTestCase appends the given test cases to every top-level assignment in
// the function body whose first left-hand side is the identifier "tests" and
// whose first right-hand side is a composite literal. Statements that do not
// match are skipped, so a function without such an assignment is left
// unchanged. A test case that fails to parse is returned as an error.
func addTestCase(fSet *token.FileSet, funcDecl *ast.FuncDecl, testCase []string) error {
	for _, stmt := range funcDecl.Body.List {
		assign, isAssign := stmt.(*ast.AssignStmt)
		if !isAssign || len(assign.Lhs) == 0 {
			continue
		}
		target, isIdent := assign.Lhs[0].(*ast.Ident)
		if !isIdent || target.Name != "tests" {
			continue
		}
		table, isLiteral := assign.Rhs[0].(*ast.CompositeLit)
		if !isLiteral {
			continue
		}
		for i := range testCase {
			element, err := structToBlockStmt(fSet, testCase[i])
			if err != nil {
				return err
			}
			table.Elts = append(table.Elts, element)
		}
	}
	return nil
}
// structToBlockStmt parses code as the single element of a composite literal
// and returns that element as an expression. The code is wrapped with
// toStruct before parsing; an error is returned when the wrapped text does
// not parse or does not contain exactly one element.
func structToBlockStmt(fSet *token.FileSet, code string) (ast.Expr, error) {
	parsed, err := parser.ParseExprFrom(fSet, "", toStruct(code), parser.AllErrors)
	if err != nil {
		return nil, err
	}
	literal, isLiteral := parsed.(*ast.CompositeLit)
	switch {
	case !isLiteral:
		return nil, errors.New("not a composite literal")
	case len(literal.Elts) != 1:
		return nil, errors.New("composite literal has more than one element or zero")
	}
	return literal.Elts[0], nil
}
// toStruct trims code, strips one leading tab and then one leading space from
// each continuation line, and wraps the result in an anonymous struct literal
// so it can be parsed as an expression.
func toStruct(code string) string {
	trimmed := strings.TrimSpace(code)
	// The replacements run one after the other, tab first.
	for _, indent := range []string{"\n\t", "\n "} {
		trimmed = strings.ReplaceAll(trimmed, indent, "\n")
	}
	return fmt.Sprintf("struct {}{ %s }", trimmed)
}
// exprName returns the name carried by expr and whether one was found.
// An identifier yields its own name. A selector yields "X.Sel" when X is an
// identifier and just "Sel" otherwise. Any other expression yields "", false.
func exprName(expr ast.Expr) (string, bool) {
	if ident, isIdent := expr.(*ast.Ident); isIdent {
		return ident.Name, true
	}
	selector, isSelector := expr.(*ast.SelectorExpr)
	if !isSelector {
		return "", false
	}
	if qualifier, isIdent := selector.X.(*ast.Ident); isIdent {
		return qualifier.Name + "." + selector.Sel.Name, true
	}
	// X is something more complex, so only the selected name is reported.
	return selector.Sel.Name, true
}
// addFunctionCall adds one argument per entry of calls to the call expression
// expr. Each argument is an identifier whose name is the entry's raw code
// text. An index of -1 appends; an index from 0 up to the current argument
// count inserts at that position; any other index yields an error.
func addFunctionCall(expr *ast.CallExpr, calls functionCalls) error {
	for _, call := range calls {
		// NOTE(review): the insertion index is reused as the identifier's
		// position, presumably to influence printer layout — confirm.
		arg := &ast.Ident{Name: call.code, NamePos: token.Pos(call.index)}

		if call.index == -1 {
			expr.Args = append(expr.Args, arg)
			continue
		}
		if call.index < 0 || call.index > len(expr.Args) {
			return errors.Errorf("function call index %d out of range", call.index)
		}
		// Build the tail in a fresh slice before extending the head over it.
		tail := append([]ast.Expr{arg}, expr.Args[call.index:]...)
		expr.Args = append(expr.Args[:call.index], tail...)
	}
	return nil
}
// addStructs appends one element per entry of structs to the composite
// literal expr. Each new element is an identifier holding the entry's raw
// code text, wrapped in a key/value pair when the entry names a field.
// The elements and the closing brace are given synthetic positions past the
// end of the existing content.
//
// NOTE(review): assumes expr has a valid position inside fileSet; otherwise
// the file lookup below would not return a usable file — confirm with callers.
func addStructs(fileSet *token.FileSet, expr *ast.CompositeLit, structs functionStructs) {
// Find the current max offset to avoid reused positions: start from the
// closing brace and raise it to the furthest end offset of any element.
file := fileSet.File(expr.Pos())
maxOffset := file.Offset(expr.Rbrace)
for _, elt := range expr.Elts {
if pos := elt.End(); pos.IsValid() {
offset := file.Offset(pos)
if offset > maxOffset {
maxOffset = offset
}
}
}
for i, s := range structs {
// Advance position: the i-th new element sits i bytes past maxOffset.
insertOffset := maxOffset + i
insertPos := file.Pos(insertOffset)
// The value is an identifier whose name is the raw code text.
value := ast.NewIdent(s.code)
value.NamePos = insertPos
var newArg ast.Expr = value
if s.param != "" {
// Keyed element: the key position is shifted by a further i, and the
// colon shares the value's position.
key := ast.NewIdent(s.param)
key.NamePos = insertPos + token.Pos(i)
newArg = &ast.KeyValueExpr{
Key: key,
Value: value,
Colon: insertPos,
}
}
expr.Elts = append(expr.Elts, newArg)
// Move the closing brace forward by i+1 for each added element.
expr.Rbrace += token.Pos(i + 1)
}
// Ensure closing brace is on a new line
if len(expr.Elts) > 0 {
last := expr.Elts[len(expr.Elts)-1]
if file.Line(expr.Rbrace) == file.Line(last.End()) {
// Force a new line before Rbrace by registering a line start at its offset.
file.AddLine(file.Offset(expr.Rbrace))
}
}
}
// codeToBlockStmt parses code as the body of a function and returns that
// body. The code is wrapped with toCode, parsed as a file in fileSet, and the
// body of the first declaration (the wrapper function) is returned.
func codeToBlockStmt(fileSet *token.FileSet, code string) (*ast.BlockStmt, error) {
	wrapped, err := parser.ParseFile(fileSet, "", toCode(code), parser.ParseComments)
	if err != nil {
		return nil, err
	}
	wrapper := wrapped.Decls[0].(*ast.FuncDecl)
	return wrapper.Body, nil
}
// toCode trims code and wraps it in a minimal package with one function
// named "_", so that statements can be parsed as a file.
func toCode(code string) string {
	body := strings.TrimSpace(code)
	return fmt.Sprintf("package p; func _() { %s }", body)
}
// formatNode prints the AST node n as formatted Go source, using fileSet for
// position information, and returns the text with leading and trailing
// whitespace removed. A node the formatter cannot print yields an error.
func formatNode(fileSet *token.FileSet, n ast.Node) (string, error) {
	out := new(bytes.Buffer)
	err := format.Node(out, fileSet, n)
	if err != nil {
		return "", err
	}
	return strings.TrimSpace(out.String()), nil
}
// applyFunctionOptions applies every modification collected in opts to the
// function declaration f, in this order:
//
//  1. remove the requested calls from the current body;
//  2. add the new parameters;
//  3. replace the body, when a replacement was given;
//  4. insert code at the requested statement positions;
//  5. update the final return statement and append code;
//  6. add cases to switch statements found directly in the body;
//  7. add arguments to matching calls and fields to matching struct literals
//     anywhere inside the function;
//  8. append test cases.
//
// A switch is matched by its formatted tag (or type-switch assignment). A
// switch without a tag is matched by the empty condition; its nil tag is not
// handed to the formatter, which rejects nil nodes.
//
// It returns an error when any snippet fails to parse, when an index is out
// of range, when the function has no body, or when a requested call, struct
// literal or switch condition does not occur in the function.
func applyFunctionOptions(fileSet *token.FileSet, f *ast.FuncDecl, opts *functionOpts) (err error) {
	var newFunctionBody *ast.BlockStmt
	if opts.body != "" {
		newFunctionBody, err = codeToBlockStmt(fileSet, opts.body)
		if err != nil {
			return err
		}
	}

	appendCode, err := addCode(fileSet, opts.appendCode)
	if err != nil {
		return err
	}

	returnStmts, err := modifyReturnVars(fileSet, opts.returnVars)
	if err != nil {
		return err
	}

	// Each *Check map starts as a copy of the requested targets; entries are
	// deleted as targets are found, so leftovers are targets that never
	// occurred in the function.
	callMap := opts.insideCall.Map()
	callMapCheck := cloneFunctionCallsMap(callMap)
	structMap := opts.insideStruct.Map()
	structMapCheck := cloneFunctionStructsMap(structMap)
	switchesCasesMap := opts.appendSwitch.Map()
	switchesCasesMapCheck := cloneFunctionSwitchesMap(switchesCasesMap)

	if len(opts.removeCalls) > 0 {
		removeFunctionCalls(f, opts.removeCalls)
	}

	if err := addParams(f, opts.newParams); err != nil {
		return err
	}

	if newFunctionBody != nil {
		f.Body = newFunctionBody
		f.Body.Rbrace = f.Body.Pos()
	}

	// Everything below reads or rewrites the body; a declaration without one
	// cannot be modified.
	if f.Body == nil {
		return errors.Errorf("function %q has no body", f.Name.Name)
	}

	if err := addNewLine(fileSet, f, opts.newLines); err != nil {
		return err
	}

	if err := modifyReturn(f, returnStmts, appendCode); err != nil {
		return err
	}

	for _, bodyStmt := range f.Body.List {
		var (
			stmt ast.Stmt
			key  string
		)
		switch sw := bodyStmt.(type) {
		case *ast.TypeSwitchStmt:
			stmt = sw
			key, err = formatNode(fileSet, sw.Assign)
		case *ast.SwitchStmt:
			stmt = sw
			// A tagless switch keeps the empty key.
			if sw.Tag != nil {
				key, err = formatNode(fileSet, sw.Tag)
			}
		default:
			continue
		}
		if err != nil {
			return err
		}

		switchCase, ok := switchesCasesMap[key]
		if !ok {
			continue
		}
		if err := appendSwitchCase(fileSet, stmt, switchCase); err != nil {
			return err
		}
		delete(switchesCasesMapCheck, key)
	}

	ast.Inspect(f, func(n ast.Node) bool {
		// Returning false only prunes the current subtree, so once an error
		// is recorded every later node is skipped to keep it from being
		// overwritten by a subsequent successful call.
		if err != nil {
			return false
		}

		switch expr := n.(type) {
		case *ast.CallExpr:
			name, exist := exprName(expr.Fun)
			if !exist {
				return true
			}

			// Collect modifications registered under the full name and, for
			// selector calls, under the bare function name as well.
			var allCalls functionCalls
			if calls, ok := callMap[name]; ok {
				allCalls = append(allCalls, calls...)
				delete(callMapCheck, name)
			}
			if sel, isSel := expr.Fun.(*ast.SelectorExpr); isSel {
				simpleName := sel.Sel.Name
				if calls, ok := callMap[simpleName]; ok {
					allCalls = append(allCalls, calls...)
					delete(callMapCheck, simpleName)
				}
			}
			if len(allCalls) == 0 {
				return true
			}

			if err = addFunctionCall(expr, allCalls); err != nil {
				return false
			}

		case *ast.CompositeLit:
			name, exist := exprName(expr.Type)
			if !exist {
				return true
			}

			// Same lookup as for calls: full name first, then the bare type
			// name for qualified literals.
			var allStructs functionStructs
			if structs, ok := structMap[name]; ok {
				allStructs = append(allStructs, structs...)
				delete(structMapCheck, name)
			}
			if sel, isSel := expr.Type.(*ast.SelectorExpr); isSel {
				simpleName := sel.Sel.Name
				if structs, ok := structMap[simpleName]; ok {
					allStructs = append(allStructs, structs...)
					delete(structMapCheck, simpleName)
				}
			}
			if len(allStructs) == 0 {
				return true
			}

			addStructs(fileSet, expr, allStructs)
		}
		return true
	})
	if err != nil {
		return err
	}

	if err := addTestCase(fileSet, f, opts.appendTestCase); err != nil {
		return err
	}

	if len(callMapCheck) > 0 {
		return errors.Errorf("function calls not found: %v", callMapCheck)
	}
	if len(structMapCheck) > 0 {
		return errors.Errorf("function structs not found: %v", structMapCheck)
	}
	if len(switchesCasesMapCheck) > 0 {
		return errors.Errorf("function switch not found: %v", switchesCasesMapCheck)
	}
	return nil
}
// ModifyCaller rewrites the argument list of every call to callerExpr found
// in content and returns the updated source text.
//
// callerExpr is either "FuncName" or "pkgname.FuncName"; any other number of
// dot-separated parts is rejected. For each matching call, modifiers receives
// the current arguments as formatted strings and returns the new arguments,
// which are joined with ", " and written between the call's parentheses. The
// rest of content is left byte-for-byte as it was.
//
// An error is returned when content does not parse, when modifiers fails, or
// when no matching call exists.
//
// NOTE(review): replacements use offsets taken from the original content, so
// a matching call nested inside another matching call's arguments may not be
// rewritten as expected — confirm if that case matters.
func ModifyCaller(content, callerExpr string, modifiers func([]string) ([]string, error)) (string, error) {
	var pkgName, funcName string
	switch parts := strings.Split(callerExpr, "."); len(parts) {
	case 1:
		funcName = parts[0]
	case 2:
		pkgName, funcName = parts[0], parts[1]
	default:
		return "", errors.New("invalid caller expression format, use 'pkgname.FuncName' or 'FuncName'")
	}

	fileSet := token.NewFileSet()
	// NOTE(review): this pre-registered file is never referenced afterwards;
	// the parser registers its own file below — confirm whether it is needed.
	fileSet.AddFile("", fileSet.Base(), len(content))

	f, err := parser.ParseFile(fileSet, "", content, parser.ParseComments)
	if err != nil {
		return "", err
	}

	// isTarget reports whether call invokes the requested function.
	isTarget := func(call *ast.CallExpr) bool {
		switch fun := call.Fun.(type) {
		case *ast.Ident:
			return pkgName == "" && fun.Name == funcName
		case *ast.SelectorExpr:
			qualifier, isIdent := fun.X.(*ast.Ident)
			return isIdent && qualifier.Name == pkgName && fun.Sel.Name == funcName
		}
		return false
	}

	// callModification is one pending replacement of the text between a
	// call's parentheses.
	type callModification struct {
		newArgs  []string
		startPos token.Pos
		endPos   token.Pos
	}
	var modifications []callModification

	errInspect := Inspect(f, func(n ast.Node) error {
		call, isCall := n.(*ast.CallExpr)
		if !isCall || !isTarget(call) {
			return nil
		}

		// Render the existing arguments as source text for the modifier.
		currentArgs := make([]string, len(call.Args))
		for i := range call.Args {
			var rendered bytes.Buffer
			if err := format.Node(&rendered, fileSet, call.Args[i]); err != nil {
				return err
			}
			currentArgs[i] = rendered.String()
		}

		newArgs, err := modifiers(currentArgs)
		if err != nil {
			return err
		}

		modifications = append(modifications, callModification{
			newArgs:  newArgs,
			startPos: call.Lparen + 1, // first position after "("
			endPos:   call.Rparen,     // position of ")"
		})
		return nil
	})
	if errInspect != nil {
		return "", errInspect
	}
	if len(modifications) == 0 {
		return "", errors.Errorf("function call %s not found in file content", callerExpr)
	}

	// Apply from the end of the file backwards so earlier offsets stay valid.
	sort.Slice(modifications, func(i, j int) bool {
		return modifications[i].startPos > modifications[j].startPos
	})

	result := content
	for _, mod := range modifications {
		start := fileSet.Position(mod.startPos).Offset
		end := fileSet.Position(mod.endPos).Offset
		result = result[:start] + strings.Join(mod.newArgs, ", ") + result[end:]
	}
	return result, nil
}
// RemoveFunction parses content, drops every top-level function declaration
// named funcName together with the comments mapped to it, and returns the
// re-formatted source. Only the name is compared, so methods with that name
// are dropped too. An error is returned when content does not parse or when
// no declaration has that name.
func RemoveFunction(content, funcName string) (string, error) {
	fileSet := token.NewFileSet()
	parsed, err := parser.ParseFile(fileSet, "", content, parser.ParseComments)
	if err != nil {
		return "", errors.Errorf("failed to parse file: %w", err)
	}
	comments := ast.NewCommentMap(fileSet, parsed, parsed.Comments)

	// Keep every declaration except the matching functions.
	var kept []ast.Decl
	for _, decl := range parsed.Decls {
		fn, isFunc := decl.(*ast.FuncDecl)
		if !isFunc || fn.Name.Name != funcName {
			kept = append(kept, decl)
			continue
		}
		delete(comments, decl)
	}

	// Nothing was dropped, so the function does not exist.
	if len(kept) == len(parsed.Decls) {
		return "", errors.Errorf("function %q not found", funcName)
	}

	parsed.Decls = kept
	parsed.Comments = comments.Filter(parsed).Comments()
	return formatNode(fileSet, parsed)
}
// removeFunctionCalls deletes, from the body of f, every expression statement
// that is a call to one of callNames. A name is either a plain function name
// ("doSomething") or a qualified one ("pkg.DoSomething").
//
// The search descends into nested blocks, if / else-if / else chains of any
// depth, for and range loops, the clauses of switch, type switch and select
// statements, and statements behind a label. Calls that are not standalone
// statements (for example the right-hand side of an assignment) and calls
// inside function literals are left untouched.
//
// A function without a body is left as is.
func removeFunctionCalls(f *ast.FuncDecl, callNames []string) {
	if f.Body == nil {
		return
	}

	targets := make(map[string]bool, len(callNames))
	for _, name := range callNames {
		targets[name] = true
	}

	// isTargetCall reports whether stmt is a standalone call to a target.
	isTargetCall := func(stmt ast.Stmt) bool {
		exprStmt, ok := stmt.(*ast.ExprStmt)
		if !ok {
			return false
		}
		call, ok := exprStmt.X.(*ast.CallExpr)
		if !ok {
			return false
		}
		switch fun := call.Fun.(type) {
		case *ast.Ident:
			return targets[fun.Name]
		case *ast.SelectorExpr:
			if qualifier, ok := fun.X.(*ast.Ident); ok {
				return targets[qualifier.Name+"."+fun.Sel.Name]
			}
		}
		return false
	}

	var (
		filterStmts  func([]ast.Stmt) []ast.Stmt
		filterNested func(ast.Stmt)
	)

	// filterBlock filters the statements of a block in place.
	filterBlock := func(block *ast.BlockStmt) {
		if block != nil {
			block.List = filterStmts(block.List)
		}
	}

	// filterClauses filters the bodies of the case / comm clauses in body.
	filterClauses := func(body *ast.BlockStmt) {
		if body == nil {
			return
		}
		for _, clause := range body.List {
			switch c := clause.(type) {
			case *ast.CaseClause:
				c.Body = filterStmts(c.Body)
			case *ast.CommClause:
				c.Body = filterStmts(c.Body)
			}
		}
	}

	// filterNested descends into the statement lists owned by stmt.
	filterNested = func(stmt ast.Stmt) {
		switch s := stmt.(type) {
		case *ast.BlockStmt:
			filterBlock(s)
		case *ast.IfStmt:
			filterBlock(s.Body)
			// Else is either a block or another if statement; recursing
			// covers else-if chains of any length.
			if s.Else != nil {
				filterNested(s.Else)
			}
		case *ast.ForStmt:
			filterBlock(s.Body)
		case *ast.RangeStmt:
			filterBlock(s.Body)
		case *ast.SwitchStmt:
			filterClauses(s.Body)
		case *ast.TypeSwitchStmt:
			filterClauses(s.Body)
		case *ast.SelectStmt:
			filterClauses(s.Body)
		case *ast.LabeledStmt:
			filterNested(s.Stmt)
		}
	}

	filterStmts = func(stmts []ast.Stmt) []ast.Stmt {
		kept := make([]ast.Stmt, 0, len(stmts))
		for _, stmt := range stmts {
			if isTargetCall(stmt) {
				continue
			}
			filterNested(stmt)
			kept = append(kept, stmt)
		}
		return kept
	}

	f.Body.List = filterStmts(f.Body.List)
}