Some checks are pending
Docs Deploy / build_and_deploy (push) Waiting to run
Generate Docs / cli (push) Waiting to run
Generate Config Doc / cli (push) Waiting to run
Go formatting / go-formatting (push) Waiting to run
Check links / markdown-link-check (push) Waiting to run
Integration / pre-test (push) Waiting to run
Integration / test on (push) Blocked by required conditions
Integration / status (push) Blocked by required conditions
Lint / Lint Go code (push) Waiting to run
Test / test (ubuntu-latest) (push) Waiting to run
306 lines
6.9 KiB
Go
306 lines
6.9 KiB
Go
package protoanalysis
|
|
|
|
import (
|
|
"fmt"
|
|
"regexp"
|
|
"slices"
|
|
"strings"
|
|
|
|
"github.com/emicklei/proto"
|
|
)
|
|
|
|
// builder converts a low-level protoPackage into the high-level types of
// this package (Package, File, Message, Service, RPCFunc, HTTPRule) through
// its methods.
type builder struct {
	// p is the low-level package being converted.
	p protoPackage
}
|
|
|
|
// build turns a low level proto pkg into a high level Package.
|
|
func build(p protoPackage) Package {
|
|
br := builder{p}
|
|
|
|
pk := Package{
|
|
Name: p.name,
|
|
Path: p.dir,
|
|
Files: br.buildFiles(),
|
|
Messages: br.buildMessages(),
|
|
Services: br.toServices(p.services()),
|
|
}
|
|
|
|
for _, option := range p.options() {
|
|
if option.Name == optionGoPkg {
|
|
pk.GoImportName = option.Constant.Source
|
|
break
|
|
}
|
|
}
|
|
|
|
return pk
|
|
}
|
|
|
|
func (b builder) buildFiles() (files []File) {
|
|
for _, f := range b.p.files {
|
|
files = append(files, File{f.path, f.imports})
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (b builder) buildMessages() (messages []Message) {
|
|
for _, f := range b.p.files {
|
|
for _, message := range f.messages {
|
|
// Keep track of the message fields and types
|
|
fields := make(map[string]string)
|
|
|
|
// Find the highest field number
|
|
var highestFieldNumber int
|
|
for _, elem := range message.Elements {
|
|
field, ok := elem.(*proto.NormalField)
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
if field.Sequence > highestFieldNumber {
|
|
highestFieldNumber = field.Sequence
|
|
}
|
|
|
|
fields[field.Name] = field.Type
|
|
}
|
|
|
|
// some proto messages might be defined inside another proto messages.
|
|
// to represents these types, an underscore is used.
|
|
// e.g. if C message inside B, and B inside A: A_B_C.
|
|
var (
|
|
name = message.Name
|
|
parent = message.Parent
|
|
)
|
|
for {
|
|
if parent == nil {
|
|
break
|
|
}
|
|
|
|
parentMessage, ok := parent.(*proto.Message)
|
|
if !ok {
|
|
break
|
|
}
|
|
|
|
name = fmt.Sprintf("%s_%s", parentMessage.Name, name)
|
|
parent = parentMessage.Parent
|
|
}
|
|
|
|
messages = append(messages, Message{
|
|
Name: name,
|
|
Path: f.path,
|
|
HighestFieldNumber: highestFieldNumber,
|
|
Fields: fields,
|
|
})
|
|
}
|
|
}
|
|
|
|
return messages
|
|
}
|
|
|
|
func (b builder) toServices(ps []*proto.Service) (services []Service) {
|
|
for _, service := range ps {
|
|
s := Service{
|
|
Name: service.Name,
|
|
RPCFuncs: b.elementsToRPCFunc(service.Elements),
|
|
}
|
|
|
|
services = append(services, s)
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (b builder) elementsToRPCFunc(elems []proto.Visitee) (rpcFuncs []RPCFunc) {
|
|
for _, el := range elems {
|
|
rpc, ok := el.(*proto.RPC)
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
requestMessage := findProtoMessageByTypeName(b.p.name, b.p.messages(), rpc.RequestType)
|
|
if requestMessage == nil {
|
|
continue
|
|
}
|
|
|
|
rf := RPCFunc{
|
|
Name: rpc.Name,
|
|
RequestType: rpc.RequestType,
|
|
ReturnsType: rpc.ReturnsType,
|
|
HTTPRules: b.elementsToHTTPRules(requestMessage, rpc.Elements),
|
|
}
|
|
|
|
rpcFuncs = append(rpcFuncs, rf)
|
|
}
|
|
|
|
return rpcFuncs
|
|
}
|
|
|
|
func (b builder) elementsToHTTPRules(requestMessage *proto.Message, elems []proto.Visitee) (httpRules []HTTPRule) {
|
|
for _, el := range elems {
|
|
option, ok := el.(*proto.Option)
|
|
if !ok {
|
|
continue
|
|
}
|
|
if !strings.Contains(option.Name, "google.api.http") {
|
|
continue
|
|
}
|
|
|
|
httpRules = append(httpRules, b.constantToHTTPRules(requestMessage, option.Constant)...)
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func findProtoMessageByTypeName(pkgName string, messages []*proto.Message, typeName string) *proto.Message {
|
|
var exactMatch *proto.Message
|
|
|
|
canonicalTypeName := canonicalMessageName(pkgName, typeName)
|
|
for _, message := range messages {
|
|
if message.Name == typeName {
|
|
exactMatch = message
|
|
}
|
|
|
|
if flattenProtoMessageName(message) == canonicalTypeName {
|
|
return message
|
|
}
|
|
}
|
|
|
|
return exactMatch
|
|
}
|
|
|
|
func flattenProtoMessageName(message *proto.Message) string {
|
|
name := message.Name
|
|
for parent := message.Parent; parent != nil; {
|
|
parentMessage, ok := parent.(*proto.Message)
|
|
if !ok {
|
|
break
|
|
}
|
|
|
|
name = fmt.Sprintf("%s_%s", parentMessage.Name, name)
|
|
parent = parentMessage.Parent
|
|
}
|
|
|
|
return name
|
|
}
|
|
|
|
var (
	// urlParamRe captures the names of the path parameters in an HTTP rule
	// URL template. A parameter is written either as "{name}" or as
	// "{name=pattern}". Only the part before "=" is captured in group 1, so
	// "{param=**}" yields "param".
	urlParamRe = regexp.MustCompile(`(?m){([^=]+?)(?:=.+?)?}`)
)
|
|
|
|
// constantToHTTPRules converts the value of a google.api.http option into
// HTTP rules. It returns one rule for the binding itself, followed by the
// rules of every additional_bindings entry it contains.
//
// Each rule records the endpoint template and the names of its URL path
// parameters. It also records how the remaining fields of requestMessage
// are transported:
//
//   - body "*": every field that is not a path parameter is a body field.
//   - body naming a field: only that field is a body field, and the other
//     non-path fields are query parameters.
//   - no (or empty) body: every non-path field is a query parameter.
func (b builder) constantToHTTPRules(requestMessage *proto.Message, constant proto.Literal) (httpRules []HTTPRule) {
	endpoint := httpRuleEndpoint(constant)

	// Collect the URL path parameter names, e.g. "id" for "/v1/{id}".
	var params []string
	for _, match := range urlParamRe.FindAllStringSubmatch(endpoint, -1) {
		params = append(params, match[1])
	}

	messageFields, messageFieldsCount := b.messageFieldsCount(requestMessage)
	paramsCount := len(params)

	// Work out how many fields travel in the body. bodyField is set only when
	// the body selects a single named field rather than the whole message.
	var (
		bodyField       string
		bodyFieldsCount int
	)
	if body, ok := constant.OrderedMap.Get("body"); ok {
		switch body.Source {
		case "":
			// An empty body selector means there is no body.
		case "*":
			// Everything not bound to the path goes into the body, so there
			// are no query params per the spec.
			bodyFieldsCount = messageFieldsCount - paramsCount
		default:
			bodyField = body.Source
			bodyFieldsCount = 1
		}
	}

	queryParamsCount := messageFieldsCount - paramsCount - bodyFieldsCount

	var queryFields, bodyFields map[string]string
	for name, t := range messageFields {
		if slices.Contains(params, name) {
			// Already carried by the URL path.
			continue
		}

		// With body "*" every remaining field is a body field; with a named
		// body only that field is, and the rest become query parameters.
		inBody := bodyFieldsCount > 0 && (bodyField == "" || name == bodyField)
		if inBody {
			if bodyFields == nil {
				bodyFields = make(map[string]string)
			}
			bodyFields[name] = t
			continue
		}

		if queryFields == nil {
			queryFields = make(map[string]string)
		}
		queryFields[name] = t
	}

	httpRules = append(httpRules, HTTPRule{
		Endpoint:    endpoint,
		Params:      params,
		HasQuery:    queryParamsCount > 0,
		QueryFields: queryFields,
		HasBody:     bodyFieldsCount > 0,
		BodyFields:  bodyFields,
	})

	// additional_bindings is a repeated field: follow every entry, not only
	// the first one, since each entry is a complete rule of its own.
	for _, each := range constant.OrderedMap {
		if each.Name != "additional_bindings" || each.Literal == nil {
			continue
		}
		httpRules = append(httpRules, b.constantToHTTPRules(requestMessage, *each.Literal)...)
	}

	return httpRules
}

// httpRuleEndpoint returns the URL template of a google.api.http rule.
// A literal with a non-empty Source is used as-is. Otherwise the first
// non-empty get/post/put/patch/delete entry, or the path of a custom
// pattern, is returned. It returns "" when no template is found.
func httpRuleEndpoint(constant proto.Literal) string {
	if constant.Source != "" {
		return constant.Source
	}

	for _, each := range constant.OrderedMap {
		switch each.Name {
		case "get", "post", "put", "patch", "delete":
			if each.Source != "" {
				return each.Source
			}
		case "custom":
			// custom { kind: "HEAD" path: "/v1/..." }
			if customPath, ok := each.OrderedMap.Get("path"); ok && customPath.Source != "" {
				return customPath.Source
			}
		}
	}

	return ""
}
|
|
|
|
// messageFieldsCount returns the fields of message, keyed by field name,
// together with the number of fields found. Each value is a textual type:
// "repeated T" for repeated regular fields, "map<K, V>" for map fields, and
// the plain type otherwise.
//
// Members of a oneof block are included. emicklei/proto nests them under a
// *proto.Oneof element instead of listing them directly in
// message.Elements, so that element is visited recursively.
func (b builder) messageFieldsCount(message *proto.Message) (messageFields map[string]string, count int) {
	messageFields = make(map[string]string)

	var collect func(elems []proto.Visitee)
	collect = func(elems []proto.Visitee) {
		for _, el := range elems {
			switch el := el.(type) {
			case *proto.NormalField:
				count++
				if el.Repeated {
					messageFields[el.Name] = fmt.Sprintf("repeated %s", el.Type)
				} else {
					messageFields[el.Name] = el.Type
				}
			case *proto.MapField:
				count++
				messageFields[el.Name] = fmt.Sprintf("map<%s, %s>", el.KeyType, el.Type)
			case *proto.OneOfField:
				count++
				messageFields[el.Name] = el.Type
			case *proto.Oneof:
				collect(el.Elements)
			}
		}
	}
	collect(message.Elements)

	return messageFields, count
}
|