// Copyright (C) 2014 Jakob Borg. All rights reserved. Use of this source code
// is governed by an MIT-style license that can be found in the LICENSE file.

// Command genxdr reads a Go source file and, for every struct type declared in
// it, emits XDRSize/MarshalXDR/MustMarshalXDR/MarshalXDRInto/UnmarshalXDR/
// UnmarshalXDRFrom methods built on github.com/calmh/xdr. Each struct's
// methods are preceded by a comment holding a bit-layout diagram and an
// XDR-language description of the encoding.
package main

import (
	"bytes"
	"flag"
	"fmt"
	"go/ast"
	"go/format"
	"go/parser"
	"go/token"
	"io"
	"log"
	"os"
	"regexp"
	"strconv"
	"strings"
	"text/template"
)

// fieldInfo describes one struct field as seen by the code generator.
type fieldInfo struct {
	Name      string
	IsBasic   bool   // handled by one of the native xdr Marshal/Unmarshal methods (MarshalUint64 etc.)
	IsSlice   bool   // field is a slice of FieldType
	FieldType string // original type of field, i.e. "int"
	Encoder   string // the encoder name, i.e. "Uint64" for Marshal/UnmarshalUint64
	Convert   string // what to convert to when encoding, i.e. "uint64"
	Max       int    // max size for slices and strings; 0 means no limit
	Submax    int    // max size for strings inside slices; 0 means no limit
}

// structInfo is a struct type to generate code for: its name and its
// encodable fields in declaration order.
type structInfo struct {
	Name   string
	Fields []fieldInfo
}

// xdrSizes is the fixed encoded size in bytes of each basic type that has
// one. Types narrower than 32 bits still occupy a full 4-byte word.
var xdrSizes = map[string]int{
	"int8":   4,
	"uint8":  4,
	"int16":  4,
	"uint16": 4,
	"int32":  4,
	"uint32": 4,
	"int64":  8,
	"uint64": 8,
	"int":    8,
	"bool":   4,
}

// SizeExpr returns a Go expression computing the encoded size of a value of
// this struct, for the body of the generated XDRSize method (receiver "o").
// It is exported only so that the template can call it as {{.SizeExpr}}.
// Every variable-size term after the first field is prefixed with a newline,
// which keeps long expressions spread over several generated lines.
func (i structInfo) SizeExpr() string {
	var terms []string
	nl := ""
	for _, f := range i.Fields {
		size := xdrSizes[f.FieldType]
		switch {
		case size > 0 && f.IsSlice:
			// Length word plus fixed-size elements.
			terms = append(terms, nl+"4+len(o."+f.Name+")*"+strconv.Itoa(size))
		case size > 0:
			terms = append(terms, strconv.Itoa(size))
		case f.IsSlice:
			// Length word plus variable-size elements, summed at run time.
			terms = append(terms, nl+"4+xdr.SizeOfSlice(o."+f.Name+")")
		case f.FieldType == "string" || f.FieldType == "[]byte":
			// Length word plus data padded to a 4-byte boundary.
			terms = append(terms, nl+"4+len(o."+f.Name+")+xdr.Padding(len(o."+f.Name+"))")
		default:
			terms = append(terms, nl+"o."+f.Name+".XDRSize()")
		}
		nl = "\n"
	}
	return strings.Join(terms, "+")
}

// headerData is the template for the top of the generated file. It is
// executed with a map holding the target "Package" name.
var headerData = `// ************************************************************
// This file is automatically generated by genxdr. Do not edit.
// ************************************************************

package {{.Package}}

import (
	"github.com/calmh/xdr"
)
`

// encoderData is the template for the methods of a struct with at least one
// field; it is executed with a structInfo. A literal "//+n" marks a place
// where a blank line must survive: generateCode first collapses all runs of
// blank lines and then turns each marker into an extra newline.
var encoderData = `
func (o {{.Name}}) XDRSize() int {
	return {{.SizeExpr}}
}//+n

func (o {{.Name}}) MarshalXDR() ([]byte, error) {
	buf:= make([]byte, o.XDRSize())
	m := &xdr.Marshaller{Data: buf}
	return buf, o.MarshalXDRInto(m)
}//+n

func (o {{.Name}}) MustMarshalXDR() []byte {
	bs, err := o.MarshalXDR()
	if err != nil {
		panic(err)
	}
	return bs
}//+n

func (o {{.Name}}) MarshalXDRInto(m *xdr.Marshaller) error {
	{{range $fi := .Fields}}
		{{if $fi.IsSlice}}
			{{template "marshalSlice" $fi}}
		{{else}}
			{{template "marshalValue" $fi}}
		{{end}}
	{{end}}
	return m.Error
}//+n

{{define "marshalValue"}}
	{{if ne .Convert ""}}
		m.Marshal{{.Encoder}}({{.Convert}}(o.{{.Name}}))
	{{else if .IsBasic}}
		{{if ge .Max 1}}
			if l := len(o.{{.Name}}); l > {{.Max}} {
				return xdr.ElementSizeExceeded("{{.Name}}", l, {{.Max}})
			}
		{{end}}
		m.Marshal{{.Encoder}}(o.{{.Name}})
	{{else}}
		if err := o.{{.Name}}.MarshalXDRInto(m); err != nil {
			return err
		}
	{{end}}
{{end}}

{{define "marshalSlice"}}
	{{if ge .Max 1}}
		if l := len(o.{{.Name}}); l > {{.Max}} {
			return xdr.ElementSizeExceeded("{{.Name}}", l, {{.Max}})
		}
	{{end}}
	m.MarshalUint32(uint32(len(o.{{.Name}})))
	for i := range o.{{.Name}} {
		{{if ne .Convert ""}}
			m.Marshal{{.Encoder}}({{.Convert}}(o.{{.Name}}[i]))
		{{else if .IsBasic}}
			m.Marshal{{.Encoder}}(o.{{.Name}}[i])
		{{else}}
			if err := o.{{.Name}}[i].MarshalXDRInto(m); err != nil {
				return err
			}
		{{end}}
	}
{{end}}

func (o *{{.Name}}) UnmarshalXDR(bs []byte) error {
	u := &xdr.Unmarshaller{Data: bs}
	return o.UnmarshalXDRFrom(u)
}

func (o *{{.Name}}) UnmarshalXDRFrom(u *xdr.Unmarshaller) error {
	{{range $fi := .Fields}}
		{{if $fi.IsSlice}}
			{{template "unmarshalSlice" $fi}}
		{{else}}
			{{template "unmarshalValue" $fi}}
		{{end}}
	{{end}}
	return u.Error
}//+n

{{define "unmarshalValue"}}
	{{if ne .Convert ""}}
		o.{{.Name}} = {{.FieldType}}(u.Unmarshal{{.Encoder}}())
	{{else if .IsBasic}}
		{{if ge .Max 1}}
			o.{{.Name}} = u.Unmarshal{{.Encoder}}Max({{.Max}})
		{{else}}
			o.{{.Name}} = u.Unmarshal{{.Encoder}}()
		{{end}}
	{{else}}
		(&o.{{.Name}}).UnmarshalXDRFrom(u)
	{{end}}
{{end}}

{{define "unmarshalSlice"}}
	_{{.Name}}Size := int(u.UnmarshalUint32())
	if _{{.Name}}Size < 0 {
		return xdr.ElementSizeExceeded("{{.Name}}", _{{.Name}}Size, {{.Max}})
	} else if _{{.Name}}Size == 0 {
		o.{{.Name}} = nil
	} else {
		{{if ge .Max 1}}
			if _{{.Name}}Size > {{.Max}} {
				return xdr.ElementSizeExceeded("{{.Name}}", _{{.Name}}Size, {{.Max}})
			}
		{{end}}
		if _{{.Name}}Size <= len(o.{{.Name}}) {
			{{if eq .FieldType "string"}}
				for i := _{{.Name}}Size; i < len(o.{{.Name}}); i++ {
					o.{{.Name}}[i] = ""
				}
			{{end}}
			{{if eq .FieldType "[]byte"}}
				for i := _{{.Name}}Size; i < len(o.{{.Name}}); i++ {
					o.{{.Name}}[i] = nil
				}
			{{end}}
			o.{{.Name}} = o.{{.Name}}[:_{{.Name}}Size]
		} else {
			o.{{.Name}} = make([]{{.FieldType}}, _{{.Name}}Size)
		}
		for i := range o.{{.Name}} {
			{{if ne .Convert ""}}
				o.{{.Name}}[i] = {{.FieldType}}(u.Unmarshal{{.Encoder}}())
			{{else if .IsBasic}}
				{{if ge .Submax 1}}
					o.{{.Name}}[i] = u.Unmarshal{{.Encoder}}Max({{.Submax}})
				{{else}}
					o.{{.Name}}[i] = u.Unmarshal{{.Encoder}}()
				{{end}}
			{{else}}
				(&o.{{.Name}}[i]).UnmarshalXDRFrom(u)
			{{end}}
		}
	}
{{end}}
`

var (
	encodeTpl = template.Must(template.New("encoder").Parse(encoderData))
	headerTpl = template.Must(template.New("header").Parse(headerData))
)

// emptyTypeTpl generates no-op methods for a struct without encodable fields.
var emptyTypeTpl = template.Must(template.New("encoder").Parse(`
func (o {{.Name}}) XDRSize() int {
	return 0
}//+n

func (o {{.Name}}) MarshalXDR() ([]byte, error) {
	return nil, nil
}//+n

func (o {{.Name}}) MustMarshalXDR() []byte {
	return nil
}//+n

func (o {{.Name}}) MarshalXDRInto(m *xdr.Marshaller) error {
	return nil
}//+n

func (o *{{.Name}}) UnmarshalXDR(bs []byte) error {
	return nil
}//+n

func (o *{{.Name}}) UnmarshalXDRFrom(u *xdr.Unmarshaller) error {
	return nil
}//+n
`))

// maxRe extracts size limits from a field's trailing comment, e.g.
// "// max:64" or "// max:100,256". The first number becomes fieldInfo.Max
// and the optional second number becomes fieldInfo.Submax.
var maxRe = regexp.MustCompile(`(?:\Wmax:)(\d+)(?:\s*,\s*(\d+))?`)

var (
	// camelRe matches a lower-case letter followed by an upper-case one.
	camelRe = regexp.MustCompile("[a-z][A-Z]")
	// blankLinesRe matches a run of whitespace that ends in a newline.
	blankLinesRe = regexp.MustCompile(`(\s*\n)+`)
)

// typeSet says how a basic Go type is encoded. Type is the type the value is
// converted to before encoding ("" when no conversion is needed). Encoder is
// the suffix of the xdr Marshal*/Unmarshal* method to call.
type typeSet struct {
	Type    string
	Encoder string
}

// xdrEncoders maps the basic Go types the generator understands to their
// encoding.
var xdrEncoders = map[string]typeSet{
	"int8":   {"uint8", "Uint8"},
	"uint8":  {"", "Uint8"},
	"int16":  {"uint16", "Uint16"},
	"uint16": {"", "Uint16"},
	"int32":  {"uint32", "Uint32"},
	"uint32": {"", "Uint32"},
	"int64":  {"uint64", "Uint64"},
	"uint64": {"", "Uint64"},
	"int":    {"uint64", "Uint64"},
	"string": {"", "String"},
	"[]byte": {"", "Bytes"},
	"bool":   {"", "Bool"},
}

// handleStruct converts the fields of a struct type into fieldInfos, in
// declaration order. A declaration such as "A, B int32" yields one fieldInfo
// per name. Embedded fields, fields whose trailing comment contains
// "noencode", and fields of types the generator cannot encode are skipped.
// Size limits are taken from a "max:N[,M]" trailing comment.
func handleStruct(t *ast.StructType) []fieldInfo {
	var fs []fieldInfo
	for _, sf := range t.Fields.List {
		if len(sf.Names) == 0 {
			// We don't handle anonymous (embedded) fields.
			continue
		}

		var max1, max2 int
		if sf.Comment != nil {
			c := sf.Comment.List[0].Text
			if strings.Contains(c, "noencode") {
				continue
			}
			// The regexp guarantees digits, so conversion only "fails" for an
			// absent optional group, which leaves 0 (= no limit).
			m := maxRe.FindStringSubmatch(c)
			if len(m) >= 2 {
				max1, _ = strconv.Atoi(m[1])
			}
			if len(m) >= 3 {
				max2, _ = strconv.Atoi(m[2])
			}
		}

		f, ok := classifyType(sf.Type, max1, max2)
		if !ok {
			for _, name := range sf.Names {
				log.Printf("skipping field %s: unsupported type %T", name.Name, sf.Type)
			}
			continue
		}
		for _, name := range sf.Names {
			f.Name = name.Name
			fs = append(fs, f)
		}
	}
	return fs
}

// classifyType builds the fieldInfo, minus its Name, for a field of type
// expr with the given limits. It reports false for types the generator cannot
// encode (fixed-size arrays, pointers, maps, slices of unsupported elements...).
func classifyType(expr ast.Expr, max1, max2 int) (fieldInfo, bool) {
	switch ft := expr.(type) {
	case *ast.Ident:
		tn := ft.Name
		if enc, ok := xdrEncoders[tn]; ok {
			return fieldInfo{
				IsBasic:   true,
				FieldType: tn,
				Encoder:   enc.Encoder,
				Convert:   enc.Type,
				Max:       max1,
				Submax:    max2,
			}, true
		}
		// Anything else is assumed to implement the generated XDR methods.
		return fieldInfo{FieldType: tn, Max: max1, Submax: max2}, true

	case *ast.ArrayType:
		if ft.Len != nil {
			// We don't handle arrays.
			return fieldInfo{}, false
		}
		tn, ok := sliceElemName(ft.Elt)
		if !ok {
			return fieldInfo{}, false
		}
		if enc, ok := xdrEncoders["[]"+tn]; ok {
			// []byte is a single opaque value, not a slice of bytes.
			return fieldInfo{
				IsBasic:   true,
				FieldType: "[]" + tn,
				Encoder:   enc.Encoder,
				Convert:   enc.Type,
				Max:       max1,
				Submax:    max2,
			}, true
		}
		if enc, ok := xdrEncoders[tn]; ok {
			return fieldInfo{
				IsBasic:   true,
				IsSlice:   true,
				FieldType: tn,
				Encoder:   enc.Encoder,
				Convert:   enc.Type,
				Max:       max1,
				Submax:    max2,
			}, true
		}
		return fieldInfo{IsSlice: true, FieldType: tn, Max: max1, Submax: max2}, true

	case *ast.SelectorExpr:
		return fieldInfo{FieldType: ft.Sel.Name, Max: max1, Submax: max2}, true
	}
	return fieldInfo{}, false
}

// sliceElemName returns the type name of a slice element for the element
// expressions the generator supports: a plain identifier, or []byte (so that
// [][]byte is a slice of opaque values).
func sliceElemName(elt ast.Expr) (string, bool) {
	switch e := elt.(type) {
	case *ast.Ident:
		return e.Name, true
	case *ast.ArrayType:
		if id, ok := e.Elt.(*ast.Ident); ok && e.Len == nil && id.Name == "byte" {
			return "[]byte", true
		}
	}
	return "", false
}

// generateCode writes the XDR methods for s to output. Structs without fields
// get trivial no-op methods. Runs of blank lines in the template output are
// collapsed and each "//+n" marker becomes a newline, so the template's own
// layout does not leak into the generated code.
func generateCode(output io.Writer, s structInfo) {
	tpl := encodeTpl
	if len(s.Fields) == 0 {
		// This is an empty type. We can create a quite simple codec for it.
		tpl = emptyTypeTpl
	}

	var buf bytes.Buffer
	if err := tpl.Execute(&buf, s); err != nil {
		// The templates are fixed; a failure here is a bug in this program.
		panic(err)
	}

	bs := blankLinesRe.ReplaceAll(buf.Bytes(), []byte("\n"))
	bs = bytes.Replace(bs, []byte("//+n"), []byte("\n"), -1)
	output.Write(bs)
}

// uncamelize inserts a space at every lower-to-upper case transition, turning
// e.g. "FieldName" into "Field Name".
func uncamelize(s string) string {
	return camelRe.ReplaceAllStringFunc(s, func(camel string) string {
		return camel[:1] + " " + camel[1:]
	})
}

// generateDiagram writes an ASCII bit-layout diagram of the encoding of s to
// output, drawn against a 32-bit ruler. Variable-length parts are framed
// with "/" and "\" borders.
func generateDiagram(output io.Writer, s structInfo) {
	sn := s.Name
	fs := s.Fields

	fmt.Fprintln(output, sn+" Structure:")

	if len(fs) == 0 {
		fmt.Fprintln(output, "(contains no fields)")
		fmt.Fprintln(output)
		fmt.Fprintln(output)
		return
	}

	fmt.Fprintln(output)
	fmt.Fprintln(output, " 0                   1                   2                   3")
	fmt.Fprintln(output, " 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1")
	line := "+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+"
	fmt.Fprintln(output, line)

	for _, f := range fs {
		tn := f.FieldType
		name := uncamelize(f.Name)
		suffix := ""

		if f.IsSlice {
			fmt.Fprintf(output, "| %s |\n", center("Number of "+name, 61))
			fmt.Fprintln(output, line)
			suffix = " (n items)"
			fmt.Fprintf(output, "/ %s /\n", center("", 61))
		}

		switch tn {
		case "bool":
			fmt.Fprintf(output, "| %s |V|\n", center(name+" (V=0 or 1)", 59))
		case "int16", "uint16":
			fmt.Fprintf(output, "| %s | %s |\n", center("16 zero bits", 29), center(name, 29))
		case "int8", "uint8":
			fmt.Fprintf(output, "| %s | %s |\n", center("24 zero bits", 45), center(name, 13))
		case "int32", "uint32":
			fmt.Fprintf(output, "| %s |\n", center(name+suffix, 61))
		case "int64", "uint64", "int":
			// Go int is encoded as a 64-bit value (see xdrEncoders).
			fmt.Fprintf(output, "| %-61s |\n", "")
			fmt.Fprintf(output, "+ %s +\n", center(name+" (64 bits)", 61))
			fmt.Fprintf(output, "| %-61s |\n", "")
		case "string", "[]byte":
			fmt.Fprintf(output, "/ %61s /\n", "")
			fmt.Fprintf(output, "\\ %s \\\n", center(name+" (length + padded data)", 61))
			fmt.Fprintf(output, "/ %61s /\n", "")
		default:
			if f.IsSlice {
				tn = "Zero or more " + tn + " Structures"
				fmt.Fprintf(output, "\\ %s \\\n", center(tn, 61))
			} else {
				tn = tn + " Structure"
				fmt.Fprintf(output, "/ %s /\n", center("", 61))
				fmt.Fprintf(output, "\\ %s \\\n", center(tn, 61))
				fmt.Fprintf(output, "/ %s /\n", center("", 61))
			}
		}

		if f.IsSlice {
			fmt.Fprintf(output, "/ %s /\n", center("", 61))
		}
		fmt.Fprintln(output, line)
	}
	fmt.Fprintln(output)
	fmt.Fprintln(output)
}

// generateXdr writes an XDR-language struct declaration describing the
// encoding of s to output. Limits from "max:" comments appear as <N>.
func generateXdr(output io.Writer, s structInfo) {
	fmt.Fprintf(output, "struct %s {\n", s.Name)

	for _, f := range s.Fields {
		tn := f.FieldType
		fn := f.Name
		suf := ""
		l := ""
		if f.Max > 0 {
			l = strconv.Itoa(f.Max)
		}
		if f.IsSlice {
			suf = "<" + l + ">"
		}

		switch tn {
		case "int8", "int16", "int32":
			fmt.Fprintf(output, "\tint %s%s;\n", fn, suf)
		case "uint8", "uint16", "uint32":
			fmt.Fprintf(output, "\tunsigned int %s%s;\n", fn, suf)
		case "int64", "int":
			// Go int is encoded as 64 bits, i.e. an XDR hyper, not an XDR int.
			fmt.Fprintf(output, "\thyper %s%s;\n", fn, suf)
		case "uint64":
			fmt.Fprintf(output, "\tunsigned hyper %s%s;\n", fn, suf)
		case "string":
			fmt.Fprintf(output, "\tstring %s<%s>;\n", fn, l)
		case "[]byte":
			fmt.Fprintf(output, "\topaque %s<%s>;\n", fn, l)
		default:
			fmt.Fprintf(output, "\t%s %s%s;\n", tn, fn, suf)
		}
	}
	fmt.Fprintln(output, "}")
	fmt.Fprintln(output)
}

// center pads s with spaces to width w (in bytes), putting any odd extra
// space on the right. Strings already wider than w are returned unchanged.
func center(s string, w int) string {
	w -= len(s)
	if w < 0 {
		// Nothing to pad; strings.Repeat would panic on a negative count.
		return s
	}
	l := w / 2
	r := w - l
	return strings.Repeat(" ", l) + s + strings.Repeat(" ", r)
}

// inspector returns an ast.Inspect visitor that appends a structInfo to
// *structs for every struct type declaration it visits. The children of a
// type declaration are not visited.
func inspector(structs *[]structInfo) func(ast.Node) bool {
	return func(n ast.Node) bool {
		ts, ok := n.(*ast.TypeSpec)
		if !ok {
			return true
		}
		if st, ok := ts.Type.(*ast.StructType); ok {
			*structs = append(*structs, structInfo{ts.Name.Name, handleStruct(st)})
		}
		return false
	}
}

// generate parses the Go file fname (or src, when non-nil; see
// parser.ParseFile) and returns the unformatted generated source. This is
// the header, then for each struct a comment with its diagram and XDR
// description, followed by its methods.
func generate(fname string, src interface{}) ([]byte, error) {
	fset := token.NewFileSet()
	f, err := parser.ParseFile(fset, fname, src, parser.ParseComments)
	if err != nil {
		return nil, err
	}

	var structs []structInfo
	ast.Inspect(f, inspector(&structs))

	var buf bytes.Buffer
	if err := headerTpl.Execute(&buf, map[string]string{"Package": f.Name.Name}); err != nil {
		return nil, err
	}
	for _, s := range structs {
		fmt.Fprintf(&buf, "\n/*\n\n")
		generateDiagram(&buf, s)
		generateXdr(&buf, s)
		fmt.Fprintf(&buf, "*/\n")
		generateCode(&buf, s)
	}
	return buf.Bytes(), nil
}

// main generates code for the Go file named by the first argument and
// writes the gofmt-ed result to the file given by -o, or to standard output.
func main() {
	outputFile := flag.String("o", "", "Output file, blank for stdout")
	flag.Parse()

	src, err := generate(flag.Arg(0), nil)
	if err != nil {
		log.Fatal(err)
	}

	bs, err := format.Source(src)
	if err != nil {
		// Dump the unformatted code so template problems can be diagnosed.
		log.Print(string(src))
		log.Fatal(err)
	}

	if *outputFile == "" {
		if _, err := os.Stdout.Write(bs); err != nil {
			log.Fatal(err)
		}
		return
	}

	fd, err := os.Create(*outputFile)
	if err != nil {
		log.Fatal(err)
	}
	if _, err := fd.Write(bs); err != nil {
		fd.Close()
		log.Fatal(err)
	}
	if err := fd.Close(); err != nil {
		log.Fatal(err)
	}
}