mukan-ignite/ignite/templates/module/migration/module.go
Mukan Erkin Törük c32551b6f7
Some checks failed
Docs Deploy / build_and_deploy (push) Has been cancelled
Generate Docs / cli (push) Has been cancelled
Generate Config Doc / cli (push) Has been cancelled
Go formatting / go-formatting (push) Has been cancelled
Check links / markdown-link-check (push) Has been cancelled
Integration / pre-test (push) Has been cancelled
Integration / test on (push) Has been cancelled
Integration / status (push) Has been cancelled
Lint / Lint Go code (push) Has been cancelled
Test / test (ubuntu-latest) (push) Has been cancelled
refactor: replace all github.com upstream refs with git.cw.tr/mukan-network
2026-05-11 03:36:24 +03:00

344 lines
8.5 KiB
Go

package modulemigration
import (
"bytes"
"go/ast"
"go/format"
"go/parser"
"go/token"
"strconv"
"github.com/gobuffalo/genny/v2"
"git.cw.tr/mukan-network/mukan-ignite/ignite/pkg/errors"
"git.cw.tr/mukan-network/mukan-ignite/ignite/pkg/xast"
)
// moduleModify builds the genny run function that rewrites the module file.
//
// The returned function looks up opts.ModuleFile() on the runner's disk,
// passes its content through updateModule, and registers the rewritten
// content with the runner under the same path. Lookup and rewrite errors are
// returned unchanged.
func moduleModify(opts *Options) genny.RunFn {
	return func(runner *genny.Runner) error {
		moduleFile, findErr := runner.Disk.Find(opts.ModuleFile())
		if findErr != nil {
			return findErr
		}

		updated, updateErr := updateModule(moduleFile.String(), opts)
		if updateErr != nil {
			return updateErr
		}

		rewritten := genny.NewFileS(opts.ModuleFile(), updated)
		return runner.File(rewritten)
	}
}
// updateModule applies the migration changes to the module source held in
// content and returns the rewritten source.
//
// Steps, in order: verify that the consensus version found in content equals
// opts.FromVersion, append the named migration import, set the consensus
// version to opts.ToVersion, and append the migration registration to
// RegisterServices. The first failing step aborts the update and an empty
// string is returned together with the error.
func updateModule(content string, opts *Options) (string, error) {
	found, err := ConsensusVersion(content)
	switch {
	case err != nil:
		return "", err
	case found != opts.FromVersion:
		return "", errors.Errorf("expected module consensus version %d, got %d", opts.FromVersion, found)
	}

	migrationImport := xast.WithNamedImport(opts.MigrationImportAlias(), opts.MigrationImportPath())
	withImport, err := xast.AppendImports(content, migrationImport)
	if err != nil {
		return "", err
	}

	bumped, err := setConsensusVersion(withImport, opts.ToVersion)
	if err != nil {
		return "", err
	}

	return addMigrationRegistration(bumped, opts)
}
// ConsensusVersion parses content as Go source and returns the consensus
// version it declares, i.e. the value of the expression returned by the
// ConsensusVersion function resolved to an unsigned integer.
//
// An error is returned when content does not parse, when the return
// expression cannot be located, or when it cannot be resolved to an integer.
func ConsensusVersion(content string) (uint64, error) {
	parsed, err := parser.ParseFile(token.NewFileSet(), "", content, parser.ParseComments)
	if err != nil {
		return 0, err
	}

	versionExpr, err := consensusVersionExpr(parsed)
	if err != nil {
		return 0, err
	}

	return parseConsensusVersionExpr(parsed, versionExpr)
}
// addMigrationRegistration appends the migration registration snippet to the
// RegisterServices function found in content and returns the modified source.
//
// When the inspected function has no usable configurator variable, the
// configurator setup snippet is appended first so the registration snippet
// can refer to it.
func addMigrationRegistration(content string, opts *Options) (string, error) {
	info, err := registerServicesInfoFromContent(content)
	if err != nil {
		return "", err
	}

	// At most two snippets are appended: optional setup, then registration.
	modifiers := make([]xast.FunctionOptions, 0, 2)
	if info.needsConfiguratorSetup {
		setup := configuratorSetupCode(info)
		modifiers = append(modifiers, xast.AppendFuncCode(setup))
	}
	registration := migrationRegistrationCode(info, opts)
	modifiers = append(modifiers, xast.AppendFuncCode(registration))

	return xast.ModifyFunction(content, "RegisterServices", modifiers...)
}
// configuratorSetupCode returns the Go snippet that type-asserts the
// RegisterServices parameter to module.Configurator, storing the result in
// info.cfgVar, and leaves the function when the assertion fails.
//
// The early exit is "return nil" when the function returns an error and a
// bare "return" otherwise.
func configuratorSetupCode(info registerServicesInfo) string {
	var bailOut string
	if info.returnsError {
		bailOut = "return nil"
	} else {
		bailOut = "return"
	}

	assertion := info.cfgVar + ", ok := " + info.parameterName + ".(module.Configurator)\n"
	guard := "if !ok {\n\t" + bailOut + "\n}"
	return assertion + guard
}
// migrationRegistrationCode returns the Go snippet that calls
// RegisterMigration on info.cfgVar for types.ModuleName, passing
// opts.FromVersion and the migration handler
// <MigrationImportAlias>.<MigrationFunc>.
//
// A registration failure is handled with "return err" when the enclosing
// function returns an error and with "panic(err)" otherwise.
func migrationRegistrationCode(info registerServicesInfo, opts *Options) string {
	fromVersion := strconv.FormatUint(opts.FromVersion, 10)
	handler := opts.MigrationImportAlias() + "." + opts.MigrationFunc()
	call := info.cfgVar + ".RegisterMigration(types.ModuleName, " + fromVersion + ", " + handler + ")"

	onError := "panic(err)"
	if info.returnsError {
		onError = "return err"
	}

	return "if err := " + call + "; err != nil {\n\t" + onError + "\n}"
}
// setConsensusVersion rewrites content so that the ConsensusVersion function
// yields version, and returns the formatted source.
//
// If the function returns a literal, the literal's text is replaced in place.
// If it returns an identifier, the identifier's const/var declaration is
// located and its initializer is replaced with an integer literal. Any other
// return expression is rejected with an error.
func setConsensusVersion(content string, version uint64) (string, error) {
	fset := token.NewFileSet()
	parsed, err := parser.ParseFile(fset, "", content, parser.ParseComments)
	if err != nil {
		return "", err
	}

	// Record comment associations before the tree is modified.
	comments := ast.NewCommentMap(fset, parsed, parsed.Comments)

	target, err := consensusVersionExpr(parsed)
	if err != nil {
		return "", err
	}

	newValue := strconv.FormatUint(version, 10)
	switch node := target.(type) {
	case *ast.BasicLit:
		node.Value = newValue
	case *ast.Ident:
		spec, idx, lookupErr := findValueSpec(parsed, node.Name)
		if lookupErr != nil {
			return "", lookupErr
		}
		spec.Values[idx] = &ast.BasicLit{Kind: token.INT, Value: newValue}
	default:
		return "", errors.Errorf("unsupported consensus version expression %T", target)
	}

	// Rebuild the comment list from the map, filtered against the updated tree.
	parsed.Comments = comments.Filter(parsed).Comments()
	return formatFile(fset, parsed)
}
// registerServicesInfo describes the RegisterServices function of a module
// file as far as the generated migration snippets need to know it.
type registerServicesInfo struct {
	// cfgVar is the variable name the generated code calls
	// RegisterMigration on.
	cfgVar string
	// needsConfiguratorSetup is true when the configurator type assertion
	// snippet has to be appended before the registration snippet.
	needsConfiguratorSetup bool
	// parameterName is the name of the first parameter of RegisterServices.
	parameterName string
	// returnsError is the result of functionReturnsError for
	// RegisterServices; it selects how generated code bails out.
	returnsError bool
}
// registerServicesInfoFromContent parses content and inspects its
// RegisterServices function.
//
// The configurator variable is resolved in this order:
//  1. the first parameter itself, when its type is named Configurator;
//  2. a variable in the function body assigned from a Configurator type
//     assertion on that parameter;
//  3. otherwise the name "cfg", with needsConfiguratorSetup set so that the
//     caller appends the assertion itself.
//
// An error is returned when content does not parse, when RegisterServices is
// missing, or when it has no named first parameter.
func registerServicesInfoFromContent(content string) (registerServicesInfo, error) {
	var none registerServicesInfo

	parsed, err := parser.ParseFile(token.NewFileSet(), "", content, parser.ParseComments)
	if err != nil {
		return none, err
	}

	registerServices := findFuncDecl(parsed, "RegisterServices")
	if registerServices == nil {
		return none, errors.New("function \"RegisterServices\" not found")
	}

	params := registerServices.Type.Params
	if params == nil || len(params.List) == 0 || len(params.List[0].Names) == 0 {
		return none, errors.New("RegisterServices must have a named parameter")
	}
	firstParam := params.List[0]
	paramName := firstParam.Names[0].Name

	info := registerServicesInfo{
		parameterName: paramName,
		returnsError:  functionReturnsError(registerServices),
	}

	if isModuleConfiguratorType(firstParam.Type) {
		info.cfgVar = paramName
		return info, nil
	}

	if existing := findConfiguratorVar(registerServices, paramName); existing != "" {
		info.cfgVar = existing
		return info, nil
	}

	info.cfgVar, info.needsConfiguratorSetup = "cfg", true
	return info, nil
}
// functionReturnsError reports whether funcDecl declares exactly one result
// whose type is written as the identifier "error".
//
// Results are counted per declared name, so a single field that introduces
// several named results, such as (a, b error), counts as two results and
// yields false. A function without results also yields false.
func functionReturnsError(funcDecl *ast.FuncDecl) bool {
	results := funcDecl.Type.Results
	// NumFields is nil-safe and counts every name in a multi-name field.
	if results.NumFields() != 1 {
		return false
	}
	ident, ok := results.List[0].Type.(*ast.Ident)
	return ok && ident.Name == "error"
}
// findConfiguratorVar scans the top-level statements of funcDecl for an
// assignment whose single right-hand side is a type assertion of the
// identifier parameterName to a Configurator type, and returns the name of
// the first variable on the left-hand side.
//
// It returns "" when no such assignment exists, when funcDecl has no body,
// or when the asserted value is discarded into the blank identifier (there
// is then no variable that generated code could reference).
func findConfiguratorVar(funcDecl *ast.FuncDecl, parameterName string) string {
	// A declaration without a body has no statements to inspect.
	if funcDecl.Body == nil {
		return ""
	}

	for _, stmt := range funcDecl.Body.List {
		assignStmt, ok := stmt.(*ast.AssignStmt)
		if !ok || len(assignStmt.Lhs) < 1 || len(assignStmt.Rhs) != 1 {
			continue
		}

		typeAssert, ok := assignStmt.Rhs[0].(*ast.TypeAssertExpr)
		if !ok || !isModuleConfiguratorType(typeAssert.Type) {
			continue
		}

		// Only assertions made directly on the function parameter qualify.
		ident, ok := typeAssert.X.(*ast.Ident)
		if !ok || ident.Name != parameterName {
			continue
		}

		cfgVar, ok := assignStmt.Lhs[0].(*ast.Ident)
		// "_" cannot be used as a value, so keep looking for a real variable.
		if !ok || cfgVar.Name == "_" {
			continue
		}

		return cfgVar.Name
	}

	return ""
}
// isModuleConfiguratorType reports whether expr names a type called
// Configurator, either as a bare identifier or as the selected name of a
// qualified identifier (for example module.Configurator). The qualifier is
// not inspected. Any other expression yields false.
func isModuleConfiguratorType(expr ast.Expr) bool {
	const configuratorName = "Configurator"

	if ident, isIdent := expr.(*ast.Ident); isIdent {
		return ident.Name == configuratorName
	}
	if selector, isSelector := expr.(*ast.SelectorExpr); isSelector {
		return selector.Sel.Name == configuratorName
	}
	return false
}
// consensusVersionExpr returns the expression returned by the last statement
// of the ConsensusVersion function declared in file.
//
// It fails when the function is missing, when its body is absent or empty,
// or when the final statement is not a return of exactly one value.
func consensusVersionExpr(file *ast.File) (ast.Expr, error) {
	decl := findFuncDecl(file, "ConsensusVersion")
	switch {
	case decl == nil:
		return nil, errors.New("function \"ConsensusVersion\" not found")
	case decl.Body == nil || len(decl.Body.List) == 0:
		return nil, errors.New("ConsensusVersion has an empty body")
	}

	stmts := decl.Body.List
	ret, isReturn := stmts[len(stmts)-1].(*ast.ReturnStmt)
	if !isReturn || len(ret.Results) != 1 {
		return nil, errors.New("ConsensusVersion must return exactly one value")
	}

	return ret.Results[0], nil
}
// parseConsensusVersionExpr resolves expr to a consensus version number.
//
// A literal is parsed directly. An identifier is looked up among the
// top-level const/var declarations of file and its initializer is resolved
// recursively. Every other kind of expression is reported as unsupported.
func parseConsensusVersionExpr(file *ast.File, expr ast.Expr) (uint64, error) {
	if lit, isLit := expr.(*ast.BasicLit); isLit {
		return parseConsensusVersionLiteral(lit)
	}

	ident, isIdent := expr.(*ast.Ident)
	if !isIdent {
		return 0, errors.Errorf("unsupported consensus version expression %T", expr)
	}

	spec, idx, err := findValueSpec(file, ident.Name)
	if err != nil {
		return 0, err
	}

	return parseConsensusVersionExpr(file, spec.Values[idx])
}
// parseConsensusVersionLiteral converts an integer literal to a uint64.
//
// Literals of any kind other than token.INT are rejected. The text is parsed
// with base 0 so that every spelling Go accepts for an integer literal is
// understood: decimal, hexadecimal (0x), octal (0o or a leading 0), binary
// (0b), and digit-separating underscores. Values that do not fit in 64 bits
// produce the strconv error.
func parseConsensusVersionLiteral(lit *ast.BasicLit) (uint64, error) {
	if lit.Kind != token.INT {
		return 0, errors.Errorf("unsupported consensus version literal kind %v", lit.Kind)
	}

	// Base 0 makes strconv follow Go literal syntax instead of assuming decimal.
	version, err := strconv.ParseUint(lit.Value, 0, 64)
	if err != nil {
		return 0, err
	}

	return version, nil
}
// findValueSpec searches the top-level const and var declarations of file
// for name and returns the value spec declaring it together with the index
// of the initializer to use.
//
// The index equals the position of name within the spec, clamped to the last
// initializer when the spec lists fewer values than names. An error is
// returned when the matching spec has no values at all or when name is not
// declared.
func findValueSpec(file *ast.File, name string) (*ast.ValueSpec, int, error) {
	for _, decl := range file.Decls {
		gen, isGen := decl.(*ast.GenDecl)
		if !isGen {
			continue
		}
		if gen.Tok != token.CONST && gen.Tok != token.VAR {
			continue
		}

		for _, spec := range gen.Specs {
			values, isValueSpec := spec.(*ast.ValueSpec)
			if !isValueSpec {
				continue
			}

			for pos := range values.Names {
				if values.Names[pos].Name != name {
					continue
				}

				count := len(values.Values)
				if count == 0 {
					return nil, 0, errors.Errorf("%s has no value", name)
				}
				if pos < count {
					return values, pos, nil
				}
				// Fewer initializers than names: fall back to the last one.
				return values, count - 1, nil
			}
		}
	}

	return nil, 0, errors.Errorf("%s value not found", name)
}
// findFuncDecl returns the first top-level function or method declaration in
// file whose name equals name, or nil when there is none. Receivers are not
// taken into account.
func findFuncDecl(file *ast.File, name string) *ast.FuncDecl {
	for i := range file.Decls {
		candidate, isFunc := file.Decls[i].(*ast.FuncDecl)
		if !isFunc {
			continue
		}
		if candidate.Name.Name == name {
			return candidate
		}
	}
	return nil
}
// formatFile renders file to Go source text.
//
// The tree is first printed with format.Node using fileSet for position
// information, and the printed bytes are then passed through format.Source.
// The error of whichever step fails is returned with an empty string.
func formatFile(fileSet *token.FileSet, file *ast.File) (string, error) {
	rendered := new(bytes.Buffer)
	if err := format.Node(rendered, fileSet, file); err != nil {
		return "", err
	}

	out, err := format.Source(rendered.Bytes())
	if err != nil {
		return "", err
	}

	return string(out), nil
}