mukan-ignite/ignite/pkg/xast/function_test.go
Mukan Erkin Törük 26b204bd04
Some checks are pending
Docs Deploy / build_and_deploy (push) Waiting to run
Generate Docs / cli (push) Waiting to run
Generate Config Doc / cli (push) Waiting to run
Go formatting / go-formatting (push) Waiting to run
Check links / markdown-link-check (push) Waiting to run
Integration / pre-test (push) Waiting to run
Integration / test on (push) Blocked by required conditions
Integration / status (push) Blocked by required conditions
Lint / Lint Go code (push) Waiting to run
Test / test (ubuntu-latest) (push) Waiting to run
feat: fork Ignite CLI v29 as Mukan Ignite — remove cosmos-sdk restrictions
2026-05-11 03:31:37 +03:00

1487 lines
29 KiB
Go

package xast
import (
"strconv"
"strings"
"testing"
"github.com/stretchr/testify/require"
"github.com/ignite/cli/v29/ignite/pkg/errors"
)
// TestModifyFunction exercises ModifyFunction with each FunctionOptions
// modifier used below (AppendSwitchCase, AppendFuncParams, ReplaceFuncBody,
// AppendFuncAtLine, AppendFuncCode, NewFuncReturn, AppendInsideFuncCall,
// AppendFuncStruct and AppendFuncTestCase). It checks the exact rewritten
// source on success and the exact error message on invalid input.
func TestModifyFunction(t *testing.T) {
	existingContent := `package main
import (
"fmt"
)
// main function
func main() {
// print hello world
fmt.Println("Hello, world!")
// call new param function
New(param1, param2)
}
// anotherFunction another function
func anotherFunction() bool {
// init param
p := bla.NewParam()
// start to call something
p.CallSomething("Another call")
// return always true
return true
}
// TestValidate test the validations
func TestValidate(t *testing.T) {
tests := []struct {
desc string
genState types.GenesisState
}{
{
desc: "default is valid",
genState: types.DefaultGenesis(),
},
{
desc: "valid genesis state",
genState: types.GenesisState{},
},
}
for _, tc := range tests {
t.Run(tc.desc, func(t *testing.T) {
err := tc.genState.Validate()
require.NoError(t, err)
})
}
}`
	type args struct {
		fileContent  string
		functionName string
		functions    []FunctionOptions
	}
	tests := []struct {
		name string
		args args
		want string
		// err, when non-nil, is the exact error ModifyFunction must return.
		err error
	}{
		{
			name: "add a case to switch statement",
			args: args{
				fileContent: `package test
func processPacket(packet interface{}) error {
switch packet := packet.(type) {
default:
return fmt.Errorf("unknown packet type: %T", packet)
}
}`,
				functionName: "processPacket",
				functions: []FunctionOptions{
					AppendSwitchCase(
						"packet := packet.(type)",
						"*types.FooPacket",
						"return handleFooPacket(packet)",
					),
				},
			},
			want: `package test
func processPacket(packet interface{}) error {
switch packet := packet.(type) {
case *types.FooPacket:
return handleFooPacket(packet)
default:
return fmt.Errorf("unknown packet type: %T", packet)
}
}`,
		},
		{
			name: "add multiple cases to switch statement",
			args: args{
				fileContent: `package test
func handlePacket(data interface{}) error {
switch v := data.(type) {
case string:
return processString(v)
default:
return fmt.Errorf("unsupported type: %T", v)
}
}`,
				functionName: "handlePacket",
				functions: []FunctionOptions{
					AppendSwitchCase(
						"v := data.(type)",
						"int",
						"return processInt(v)",
					),
					AppendSwitchCase(
						"v := data.(type)",
						"bool",
						"return processBool(v)",
					),
				},
			},
			want: `package test
func handlePacket(data interface{}) error {
switch v := data.(type) {
case string:
return processString(v)
case int:
return processInt(v)
case bool:
return processBool(v)
default:
return fmt.Errorf("unsupported type: %T", v)
}
}`,
		},
		{
			name: "add multiple cases to two switch statement",
			args: args{
				fileContent: `package test
func handlePacket(data interface{}) error {
switch v := data.(type) {
case string:
return processString(v)
default:
return fmt.Errorf("unsupported type: %T", v)
}
switch x {
case 1:
return "one"
default:
return "unknown"
}
}`,
				functionName: "handlePacket",
				functions: []FunctionOptions{
					AppendSwitchCase(
						"v := data.(type)",
						"int",
						"return processInt(v)",
					),
					AppendSwitchCase(
						"x",
						"2",
						`return "two"`,
					),
				},
			},
			want: `package test
func handlePacket(data interface{}) error {
switch v := data.(type) {
case string:
return processString(v)
case int:
return processInt(v)
default:
return fmt.Errorf("unsupported type: %T", v)
}
switch x {
case 1:
return "one"
case 2:
return "two"
default:
return "unknown"
}
}`,
		},
		{
			name: "add case to switch with non-matching condition",
			args: args{
				fileContent: `package test
func process(x int) string {
switch x {
case 1:
return "one"
default:
return "unknown"
}
}`,
				functionName: "process",
				functions: []FunctionOptions{
					AppendSwitchCase(
						"wrongCondition",
						"2",
						`return "two"`,
					),
				},
			},
			err: errors.New("function switch not found: map[wrongCondition:[{wrongCondition 2 return \"two\"}]]"),
		},
		{
			name: "add all modifications type",
			args: args{
				fileContent:  existingContent,
				functionName: "anotherFunction",
				functions: []FunctionOptions{
					AppendFuncParams("param1", "string", 0),
					ReplaceFuncBody(`return false`),
					AppendFuncAtLine(`fmt.Println("Appended at line 0.")`, 0),
					AppendFuncAtLine(`SimpleCall(foo, bar)`, 1),
					AppendFuncAtLine(`if param1 == "" {
return false
}`, 2),
					AppendFuncCode(`fmt.Println("Appended code.")`),
					AppendFuncCode(`Param{
Baz: baz,
Foo: foo,
}`),
					NewFuncReturn("1"),
					AppendInsideFuncCall("SimpleCall", "baz", 0),
					AppendInsideFuncCall("SimpleCall", "bla", -1),
					AppendInsideFuncCall("Println", strconv.Quote("test"), -1),
					AppendFuncStruct("Param", "Bar", strconv.Quote("bar")),
					AppendFuncTestCase(`{
desc: "valid first genesis state",
genState: GenesisState{},
}`),
				},
			},
			want: `package main
import (
"fmt"
)
// main function
func main() {
// print hello world
fmt.Println("Hello, world!")
// call new param function
New(param1, param2)
}
// anotherFunction another function
func anotherFunction(param1 string) bool {
fmt.Println("Appended at line 0.", "test")
SimpleCall(baz, foo, bar, bla)
if param1 == "" {
return false
}
fmt.Println("Appended code.", "test")
Param{
Baz: baz,
Foo: foo,
Bar: "bar",
}
return 1
}
// TestValidate test the validations
func TestValidate(t *testing.T) {
tests := []struct {
desc string
genState types.GenesisState
}{
{
desc: "default is valid",
genState: types.DefaultGenesis(),
},
{
desc: "valid genesis state",
genState: types.GenesisState{},
},
}
for _, tc := range tests {
t.Run(tc.desc, func(t *testing.T) {
err := tc.genState.Validate()
require.NoError(t, err)
})
}
}`,
		},
		{
			name: "add the replace body",
			args: args{
				fileContent:  existingContent,
				functionName: "anotherFunction",
				functions:    []FunctionOptions{ReplaceFuncBody(`return false`)},
			},
			want: `package main
import (
"fmt"
)
// main function
func main() {
// print hello world
fmt.Println("Hello, world!")
// call new param function
New(param1, param2)
}
// anotherFunction another function
func anotherFunction() bool { return false }
// TestValidate test the validations
func TestValidate(t *testing.T) {
tests := []struct {
desc string
genState types.GenesisState
}{
{
desc: "default is valid",
genState: types.DefaultGenesis(),
},
{
desc: "valid genesis state",
genState: types.GenesisState{},
},
}
for _, tc := range tests {
t.Run(tc.desc, func(t *testing.T) {
err := tc.genState.Validate()
require.NoError(t, err)
})
}
}`,
		},
		{
			name: "add a new test case",
			args: args{
				fileContent:  existingContent,
				functionName: "TestValidate",
				functions: []FunctionOptions{
					AppendFuncTestCase(`{
desc: "valid genesis state",
genState: GenesisState{},
}`),
				},
			},
			want: `package main
import (
"fmt"
)
// main function
func main() {
// print hello world
fmt.Println("Hello, world!")
// call new param function
New(param1, param2)
}
// anotherFunction another function
func anotherFunction() bool {
// init param
p := bla.NewParam()
// start to call something
p.CallSomething("Another call")
// return always true
return true
}
// TestValidate test the validations
func TestValidate(t *testing.T) {
tests := []struct {
desc string
genState types.GenesisState
}{
{
desc: "default is valid",
genState: types.DefaultGenesis(),
},
{
desc: "valid genesis state",
genState: types.GenesisState{},
}, {
desc: "valid genesis state",
genState: GenesisState{},
},
}
for _, tc := range tests {
t.Run(tc.desc, func(t *testing.T) {
err := tc.genState.Validate()
require.NoError(t, err)
})
}
}`,
		},
		{
			name: "add two test cases",
			args: args{
				fileContent:  existingContent,
				functionName: "TestValidate",
				functions: []FunctionOptions{
					AppendFuncTestCase(`
{
desc: "valid first genesis state",
genState: GenesisState{},
}`),
					AppendFuncTestCase(`
{
desc: "valid second genesis state",
genState: GenesisState{},
}`),
				},
			},
			want: `package main
import (
"fmt"
)
// main function
func main() {
// print hello world
fmt.Println("Hello, world!")
// call new param function
New(param1, param2)
}
// anotherFunction another function
func anotherFunction() bool {
// init param
p := bla.NewParam()
// start to call something
p.CallSomething("Another call")
// return always true
return true
}
// TestValidate test the validations
func TestValidate(t *testing.T) {
tests := []struct {
desc string
genState types.GenesisState
}{
{
desc: "default is valid",
genState: types.DefaultGenesis(),
},
{
desc: "valid genesis state",
genState: types.GenesisState{},
}, {
desc: "valid first genesis state",
genState: GenesisState{},
}, {
desc: "valid second genesis state",
genState: GenesisState{},
},
}
for _, tc := range tests {
t.Run(tc.desc, func(t *testing.T) {
err := tc.genState.Validate()
require.NoError(t, err)
})
}
}`,
		},
		{
			name: "add append line and code modification",
			args: args{
				fileContent:  existingContent,
				functionName: "anotherFunction",
				functions: []FunctionOptions{
					AppendFuncAtLine(`fmt.Println("Appended at line 0.")`, 0),
					AppendFuncAtLine(`SimpleCall(foo, bar)`, 1),
					AppendFuncCode(`fmt.Println("Appended code.")`),
				},
			},
			want: `package main
import (
"fmt"
)
// main function
func main() {
// print hello world
fmt.Println("Hello, world!")
// call new param function
New(param1, param2)
}
// anotherFunction another function
func anotherFunction() bool {
fmt.Println("Appended at line 0.")
SimpleCall(foo, bar)
// init param
p := bla.NewParam()
// start to call something
p.CallSomething("Another call")
fmt.Println("Appended code.")
// return always true
return true
}
// TestValidate test the validations
func TestValidate(t *testing.T) {
tests := []struct {
desc string
genState types.GenesisState
}{
{
desc: "default is valid",
genState: types.DefaultGenesis(),
},
{
desc: "valid genesis state",
genState: types.GenesisState{},
},
}
for _, tc := range tests {
t.Run(tc.desc, func(t *testing.T) {
err := tc.genState.Validate()
require.NoError(t, err)
})
}
}`,
		},
		{
			// Distinct name: "add all modifications type" is already used above,
			// and duplicate subtest names get silently suffixed with "#01".
			name: "replace return value",
			args: args{
				fileContent:  existingContent,
				functionName: "anotherFunction",
				functions:    []FunctionOptions{NewFuncReturn("1")},
			},
			want: strings.ReplaceAll(existingContent, "return true", "return 1\n"),
		},
		{
			name: "add inside call modifications",
			args: args{
				fileContent:  existingContent,
				functionName: "anotherFunction",
				functions: []FunctionOptions{
					AppendInsideFuncCall("NewParam", "baz", 0),
					AppendInsideFuncCall("NewParam", "bla", -1),
					AppendInsideFuncCall("CallSomething", strconv.Quote("test1"), -1),
					AppendInsideFuncCall("CallSomething", strconv.Quote("test2"), 0),
				},
			},
			want: `package main
import (
"fmt"
)
// main function
func main() {
// print hello world
fmt.Println("Hello, world!")
// call new param function
New(param1, param2)
}
// anotherFunction another function
func anotherFunction() bool {
// init param
p := bla.NewParam(baz, bla)
// start to call something
p.CallSomething("test2", "Another call", "test1")
// return always true
return true
}
// TestValidate test the validations
func TestValidate(t *testing.T) {
tests := []struct {
desc string
genState types.GenesisState
}{
{
desc: "default is valid",
genState: types.DefaultGenesis(),
},
{
desc: "valid genesis state",
genState: types.GenesisState{},
},
}
for _, tc := range tests {
t.Run(tc.desc, func(t *testing.T) {
err := tc.genState.Validate()
require.NoError(t, err)
})
}
}`,
		},
		{
			name: "add inside call modifications with qualified package name",
			args: args{
				fileContent:  existingContent,
				functionName: "anotherFunction",
				functions: []FunctionOptions{
					AppendInsideFuncCall("bla.NewParam", "baz", 0),
					AppendInsideFuncCall("bla.NewParam", "bla", -1),
					AppendInsideFuncCall("CallSomething", strconv.Quote("test1"), -1),
				},
			},
			want: `package main
import (
"fmt"
)
// main function
func main() {
// print hello world
fmt.Println("Hello, world!")
// call new param function
New(param1, param2)
}
// anotherFunction another function
func anotherFunction() bool {
// init param
p := bla.NewParam(baz, bla)
// start to call something
p.CallSomething("Another call", "test1")
// return always true
return true
}
// TestValidate test the validations
func TestValidate(t *testing.T) {
tests := []struct {
desc string
genState types.GenesisState
}{
{
desc: "default is valid",
genState: types.DefaultGenesis(),
},
{
desc: "valid genesis state",
genState: types.GenesisState{},
},
}
for _, tc := range tests {
t.Run(tc.desc, func(t *testing.T) {
err := tc.genState.Validate()
require.NoError(t, err)
})
}
}`,
		},
		{
			name: "add inside call modifications with mixed qualified and unqualified names",
			args: args{
				fileContent:  existingContent,
				functionName: "anotherFunction",
				functions: []FunctionOptions{
					AppendInsideFuncCall("bla.NewParam", "ctx", 0),
					AppendInsideFuncCall("NewParam", "baz", -1),
					AppendInsideFuncCall("p.CallSomething", strconv.Quote("test1"), 0),
					AppendInsideFuncCall("CallSomething", strconv.Quote("test2"), -1),
				},
			},
			want: `package main
import (
"fmt"
)
// main function
func main() {
// print hello world
fmt.Println("Hello, world!")
// call new param function
New(param1, param2)
}
// anotherFunction another function
func anotherFunction() bool {
// init param
p := bla.NewParam(ctx, baz)
// start to call something
p.CallSomething("test1", "Another call", "test2")
// return always true
return true
}
// TestValidate test the validations
func TestValidate(t *testing.T) {
tests := []struct {
desc string
genState types.GenesisState
}{
{
desc: "default is valid",
genState: types.DefaultGenesis(),
},
{
desc: "valid genesis state",
genState: types.GenesisState{},
},
}
for _, tc := range tests {
t.Run(tc.desc, func(t *testing.T) {
err := tc.genState.Validate()
require.NoError(t, err)
})
}
}`,
		},
		{
			name: "add inside struct modifications",
			args: args{
				fileContent: `package main
import (
"fmt"
)
// anotherFunction another function
func anotherFunction() bool {
Param{
Baz: baz,
Foo: foo,
}
Client{baz, foo}
// return always true
return true
}
// TestValidate test the validations
func TestValidate(t *testing.T) {
tests := []struct {
desc string
genState types.GenesisState
}{
{
desc: "default is valid",
genState: types.DefaultGenesis(),
},
{
desc: "valid genesis state",
genState: types.GenesisState{},
},
}
for _, tc := range tests {
t.Run(tc.desc, func(t *testing.T) {
err := tc.genState.Validate()
require.NoError(t, err)
})
}
}`,
				functionName: "anotherFunction",
				functions: []FunctionOptions{
					AppendFuncStruct("Param", "Bar", "bar"),
					AppendFuncStruct("Param", "Bla", "bla"),
					AppendFuncStruct("Client", "", "bar"),
				},
			},
			want: `package main
import (
"fmt"
)
// anotherFunction another function
func anotherFunction() bool {
Param{
Baz: baz,
Foo: foo,
Bar: bar,
Bla: bla,
}
Client{baz, foo, bar}
// return always true
return true
}
// TestValidate test the validations
func TestValidate(t *testing.T) {
tests := []struct {
desc string
genState types.GenesisState
}{
{
desc: "default is valid",
genState: types.DefaultGenesis(),
},
{
desc: "valid genesis state",
genState: types.GenesisState{},
},
}
for _, tc := range tests {
t.Run(tc.desc, func(t *testing.T) {
err := tc.genState.Validate()
require.NoError(t, err)
})
}
}`,
		},
		{
			name: "function without test case assertion",
			args: args{
				fileContent:  existingContent,
				functionName: "anotherFunction",
				functions: []FunctionOptions{
					AppendFuncTestCase(`{
desc: "valid second genesis state",
genState: GenesisState{},
}`),
				},
			},
			want: existingContent,
		},
		{
			name: "params out of range",
			args: args{
				fileContent:  existingContent,
				functionName: "anotherFunction",
				functions:    []FunctionOptions{AppendFuncParams("param1", "string", 1)},
			},
			err: errors.New("params index 1 out of range"),
		},
		{
			name: "invalid params",
			args: args{
				fileContent:  existingContent,
				functionName: "anotherFunction",
				functions:    []FunctionOptions{AppendFuncParams("9#.(c", "string", 0)},
			},
			err: errors.New("format.Node internal error (16:22: expected ')', found 9 (and 1 more errors))"),
		},
		{
			name: "invalid content for replace body",
			args: args{
				fileContent:  existingContent,
				functionName: "anotherFunction",
				functions:    []FunctionOptions{ReplaceFuncBody("9#.(c")},
			},
			err: errors.New("1:24: illegal character U+0023 '#'"),
		},
		{
			name: "line number out of range",
			args: args{
				fileContent:  existingContent,
				functionName: "anotherFunction",
				functions:    []FunctionOptions{AppendFuncAtLine(`fmt.Println("")`, 4)},
			},
			err: errors.New("line number 4 out of range (max 2)"),
		},
		{
			name: "invalid code for append at line",
			args: args{
				fileContent:  existingContent,
				functionName: "anotherFunction",
				functions:    []FunctionOptions{AppendFuncAtLine("9#.(c", 0)},
			},
			err: errors.New("1:24: illegal character U+0023 '#'"),
		},
		{
			name: "invalid code append",
			args: args{
				fileContent:  existingContent,
				functionName: "anotherFunction",
				functions:    []FunctionOptions{AppendFuncCode("9#.(c")},
			},
			err: errors.New("1:24: illegal character U+0023 '#'"),
		},
		{
			name: "invalid new return",
			args: args{
				fileContent:  existingContent,
				functionName: "anotherFunction",
				functions:    []FunctionOptions{NewFuncReturn("9#.(c")},
			},
			err: errors.New("1:2: illegal character U+0023 '#'"),
		},
		{
			name: "call name not found",
			args: args{
				fileContent:  existingContent,
				functionName: "anotherFunction",
				functions:    []FunctionOptions{AppendInsideFuncCall("FooFunction", "baz", 0)},
			},
			err: errors.New("function calls not found: map[FooFunction:[{FooFunction baz 0}]]"),
		},
		{
			name: "invalid call param",
			args: args{
				fileContent:  existingContent,
				functionName: "anotherFunction",
				functions:    []FunctionOptions{AppendInsideFuncCall("NewParam", "9#.(c", 0)},
			},
			err: errors.New("format.Node internal error (18:21: illegal character U+0023 '#' (and 4 more errors))"),
		},
		{
			name: "call params out of range",
			args: args{
				fileContent:  existingContent,
				functionName: "anotherFunction",
				functions:    []FunctionOptions{AppendInsideFuncCall("NewParam", "baz", 1)},
			},
			err: errors.New("function call index 1 out of range"),
		},
		{
			name: "empty modifications",
			args: args{
				fileContent:  existingContent,
				functionName: "anotherFunction",
				functions:    []FunctionOptions{},
			},
			want: existingContent,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := ModifyFunction(tt.args.fileContent, tt.args.functionName, tt.args.functions...)
			if tt.err != nil {
				// EqualError fails on a nil error and then compares the message.
				require.EqualError(t, err, tt.err.Error())
				return
			}
			require.NoError(t, err)
			require.Equal(t, tt.want, got)
		})
	}
}
// TestModifyCaller verifies that ModifyCaller replaces the argument list of
// every call matching the caller expression with whatever the modifier
// returns. It also checks the errors reported for an unknown call, a failing
// modifier and a malformed caller expression.
func TestModifyCaller(t *testing.T) {
	source := `package main
import (
"context"
"fmt"
)
// main function
func main() {
// Simple function call
// print hello world
fmt.Println("Hello, world!")
// Call with multiple arguments
server.Foo(param1, param2, 42)
// Call with no arguments
EmptyFunc()
// Call with complex arguments
ComplexFunc([]string{"a", "b"}, map[string]int{"a": 1})
// Multiple calls to the same function
fmt.Println("First call")
fmt.Println("Second call")
}
`
	cases := []struct {
		name     string
		content  string
		caller   string
		modifier func([]string) ([]string, error)
		want     string
		// wantErr, when non-empty, must be a substring of the returned error.
		wantErr string
	}{
		{
			name:    "replace arguments in fmt.Println",
			content: source,
			caller:  "fmt.Println",
			modifier: func([]string) ([]string, error) {
				return []string{`"Modified output"`}, nil
			},
			want: `package main
import (
"context"
"fmt"
)
// main function
func main() {
// Simple function call
// print hello world
fmt.Println("Modified output")
// Call with multiple arguments
server.Foo(param1, param2, 42)
// Call with no arguments
EmptyFunc()
// Call with complex arguments
ComplexFunc([]string{"a", "b"}, map[string]int{"a": 1})
// Multiple calls to the same function
fmt.Println("Modified output")
fmt.Println("Modified output")
}
`,
		},
		{
			name:    "replace server.Foo arguments",
			content: source,
			caller:  "server.Foo",
			modifier: func([]string) ([]string, error) {
				return []string{"context.Background()", "newParam", "123"}, nil
			},
			want: `package main
import (
"context"
"fmt"
)
// main function
func main() {
// Simple function call
// print hello world
fmt.Println("Hello, world!")
// Call with multiple arguments
server.Foo(context.Background(), newParam, 123)
// Call with no arguments
EmptyFunc()
// Call with complex arguments
ComplexFunc([]string{"a", "b"}, map[string]int{"a": 1})
// Multiple calls to the same function
fmt.Println("First call")
fmt.Println("Second call")
}
`,
		},
		{
			name:    "add argument to EmptyFunc",
			content: source,
			caller:  "EmptyFunc",
			modifier: func([]string) ([]string, error) {
				return []string{`"new argument"`}, nil
			},
			want: `package main
import (
"context"
"fmt"
)
// main function
func main() {
// Simple function call
// print hello world
fmt.Println("Hello, world!")
// Call with multiple arguments
server.Foo(param1, param2, 42)
// Call with no arguments
EmptyFunc("new argument")
// Call with complex arguments
ComplexFunc([]string{"a", "b"}, map[string]int{"a": 1})
// Multiple calls to the same function
fmt.Println("First call")
fmt.Println("Second call")
}
`,
		},
		{
			name:    "modify complex arguments",
			content: source,
			caller:  "ComplexFunc",
			modifier: func([]string) ([]string, error) {
				return []string{`[]string{"x", "y", "z"}`, `map[string]int{"x": 10}`}, nil
			},
			want: `package main
import (
"context"
"fmt"
)
// main function
func main() {
// Simple function call
// print hello world
fmt.Println("Hello, world!")
// Call with multiple arguments
server.Foo(param1, param2, 42)
// Call with no arguments
EmptyFunc()
// Call with complex arguments
ComplexFunc([]string{"x", "y", "z"}, map[string]int{"x": 10})
// Multiple calls to the same function
fmt.Println("First call")
fmt.Println("Second call")
}
`,
		},
		{
			name:    "function not found",
			content: source,
			caller:  "NonExistentFunc",
			modifier: func([]string) ([]string, error) {
				return []string{`"test"`}, nil
			},
			wantErr: "function call NonExistentFunc not found in file content",
		},
		{
			name:    "error in modifier function",
			content: source,
			caller:  "fmt.Println",
			modifier: func([]string) ([]string, error) {
				return nil, errors.New("custom error in modifier")
			},
			wantErr: "custom error in modifier",
		},
		{
			name:    "invalid caller expression",
			content: source,
			caller:  "pkg.sub.Function",
			modifier: func([]string) ([]string, error) {
				return []string{`"test"`}, nil
			},
			wantErr: "invalid caller expression format, use 'pkgname.FuncName' or 'FuncName'",
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := ModifyCaller(tc.content, tc.caller, tc.modifier)
			if tc.wantErr == "" {
				require.NoError(t, err)
				require.Equal(t, tc.want, got)
				return
			}
			require.Error(t, err)
			require.Contains(t, err.Error(), tc.wantErr)
		})
	}
}
// TestRemoveFunction checks that RemoveFunction drops the named top-level
// function (including its doc comment) wherever it appears in the file. It
// also checks that an error is returned for a missing function or an
// unparsable file.
func TestRemoveFunction(t *testing.T) {
	cases := []struct {
		name    string
		src     string
		target  string
		want    string
		wantErr bool
	}{
		{
			name: "remove a simple function",
			src: `package main
func main() {
println("hello")
}
func anotherFunction() {
println("another")
}
func thirdFunction() {
println("third")
}
`,
			target: "anotherFunction",
			want: `package main
func main() {
println("hello")
}
func thirdFunction() {
println("third")
}`,
		},
		{
			name: "remove first function",
			src: `package main
func first() {
println("first")
}
func second() {
println("second")
}
`,
			target: "first",
			want: `package main
func second() {
println("second")
}`,
		},
		{
			name: "remove last function",
			src: `package main
func first() {
println("first")
}
func second() {
println("second")
}
`,
			target: "second",
			want: `package main
func first() {
println("first")
}`,
		},
		{
			name: "remove function with comments",
			src: `package main
// main is the entry point
func main() {
println("main")
}
// helperFunc does something
func helperFunc() {
println("helper")
}
`,
			target: "helperFunc",
			want: `package main
// main is the entry point
func main() {
println("main")
}`,
		},
		{
			name: "function not found",
			src: `package main
func main() {
println("hello")
}
`,
			target:  "notFound",
			wantErr: true,
		},
		{
			name:    "invalid source file",
			src:     `package main func`,
			target:  "main",
			wantErr: true,
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := RemoveFunction(tc.src, tc.target)
			if !tc.wantErr {
				require.NoError(t, err)
				require.Equal(t, tc.want, got)
				return
			}
			require.Error(t, err)
		})
	}
}
func TestRemoveFuncCall(t *testing.T) {
tests := []struct {
name string
content string
funcName string
callName string
expected string
}{
{
name: "remove a function call",
content: `package main
func main() {
fmt.Println("before")
doSomething()
fmt.Println("after")
}
`,
funcName: "main",
callName: "doSomething",
expected: `package main
func main() {
fmt.Println("before")
fmt.Println("after")
}`,
},
{
name: "remove qualified function call",
content: `package main
func main() {
fmt.Println("hello")
pkg.DoSomething()
fmt.Println("world")
}
`,
funcName: "main",
callName: "pkg.DoSomething",
expected: `package main
func main() {
fmt.Println("hello")
fmt.Println("world")
}`,
},
{
name: "remove multiple calls to same function",
content: `package main
func main() {
doSomething()
fmt.Println("middle")
doSomething()
}
`,
funcName: "main",
callName: "doSomething",
expected: `package main
func main() {
fmt.Println("middle")
}`,
},
{
name: "remove call with arguments",
content: `package main
func process() {
validate(arg1, arg2)
execute()
}
`,
funcName: "process",
callName: "validate",
expected: `package main
func process() {
execute()
}`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := ModifyFunction(tt.content, tt.funcName, RemoveFuncCall(tt.callName))
require.NoError(t, err)
require.Equal(t, tt.expected, result)
})
}
}
// TestModifyFunctionMissingTargets covers the two lookup failures of
// ModifyFunction: source that cannot be parsed, and a function name that is
// not declared in otherwise valid source.
func TestModifyFunctionMissingTargets(t *testing.T) {
	const target = "anotherFunction"

	t.Run("invalid source content", func(t *testing.T) {
		src := "package main\nfunc"
		_, err := ModifyFunction(src, target)
		require.Error(t, err)
		require.Contains(t, err.Error(), "failed to parse file (anotherFunction)")
	})

	t.Run("function not found", func(t *testing.T) {
		src := `package main
func main() {}
`
		_, err := ModifyFunction(src, target)
		require.EqualError(t, err, `function "anotherFunction" not found`)
	})
}
// TestModifyFunctionReturnStatementNotFound checks that NewFuncReturn reports
// "return statement not found" when the target function has no return
// statement, whether its body is empty or not.
func TestModifyFunctionReturnStatementNotFound(t *testing.T) {
	cases := []struct {
		name, src, target string
	}{
		{
			name: "non-empty body without return",
			src: `package main
func noReturn() {
doSomething()
}
`,
			target: "noReturn",
		},
		{
			name: "empty body without return",
			src: `package main
func empty() {}
`,
			target: "empty",
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			_, err := ModifyFunction(tc.src, tc.target, NewFuncReturn("1"))
			require.EqualError(t, err, "return statement not found")
		})
	}
}
// TestRemoveFuncCallNestedStatements checks that RemoveFuncCall reaches calls
// nested inside if/else-if/else branches, for and range loops, and the cases
// of type and expression switches. Unrelated calls and the return statement
// must survive.
func TestRemoveFuncCallNestedStatements(t *testing.T) {
	src := `package main
func process(values []int, anyValue interface{}) int {
if len(values) > 0 {
doRemove()
} else if len(values) == 0 {
doRemove()
} else {
doRemove()
}
for i := 0; i < len(values); i++ {
doRemove()
}
for _, v := range values {
_ = v
doRemove()
}
switch value := anyValue.(type) {
case int:
_ = value
doRemove()
default:
doKeep()
}
switch len(values) {
case 1:
doRemove()
default:
doKeep()
}
return 1
}
`
	out, err := ModifyFunction(src, "process", RemoveFuncCall("doRemove"))
	require.NoError(t, err)
	require.NotContains(t, out, "doRemove(")
	for _, kept := range []string{"doKeep()", "return 1"} {
		require.Contains(t, out, kept)
	}
}