Skip to content

Commit 145dcc2

Browse files
author
Maximilian Szengel
committed
Init
0 parents  commit 145dcc2

File tree

5 files changed

+329
-0
lines changed

5 files changed

+329
-0
lines changed

.gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
example/example
2+
.DS_Store
3+
example/example_curried.go

README.md

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
> I like my chicken like my functions of multiple arguments: curried. - Michael Snoyman
2+
3+
curry
4+
=======
5+
6+
Function curring for [Go](https://golang.org).
7+
8+
Usage
9+
-----
10+
11+
1. Install `go get github.com/maxsz/curry`
12+
2. Add `//go:generate curry` as a first line of any go file within a project.
13+
3. Run `go generate` before `go build`
14+
4. Use generated curried versions of any functions in your code. Example:
15+
16+
```go
17+
// Example function definition
18+
func example(first, second string) { return first + " " + second }
19+
20+
// Use generated, curried version of the function
21+
hello := exampleC("hello")
22+
hello("world")
23+
```
24+
25+
Motivation
26+
----------
27+
28+
Although I really like go, I think it's biggest draw-back is the missing
29+
generics support. I wanted to explore and learn about the go type system and
30+
code generation tools and how they could help mitigate the missing generics
31+
support. This is what came out of it.
32+
33+
References
34+
----------
35+
[The Go Blog - Generating code](https://blog.golang.org/generate)

curry.go

+224
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
package main
2+
3+
import (
4+
"bytes"
5+
"flag"
6+
"fmt"
7+
"go/ast"
8+
"go/build"
9+
"go/format"
10+
"go/importer"
11+
"go/parser"
12+
"go/token"
13+
"go/types"
14+
"io"
15+
"io/ioutil"
16+
"log"
17+
"path/filepath"
18+
"strings"
19+
20+
"golang.org/x/tools/go/ast/astutil"
21+
)
22+
23+
const (
24+
fileHeader = "// Code generated by \"curry\" " +
25+
"(http://github.com/maxsz/curry) - DO NOT EDIT\n\n"
26+
)
27+
28+
type file struct {
29+
Name string
30+
Directory string
31+
}
32+
33+
func (f file) Path() string {
34+
return filepath.Join(f.Directory, f.Name)
35+
}
36+
37+
func main() {
38+
var (
39+
fileSuffix string
40+
functionModifier string
41+
)
42+
flag.StringVar(&fileSuffix, "file-suffix", "_curry.go",
43+
"Curried file suffix.",
44+
)
45+
flag.StringVar(&functionModifier, "function-modifier", "C",
46+
"The string appended to the curried function name.",
47+
)
48+
flag.Parse()
49+
args := flag.Args()
50+
if len(args) == 0 {
51+
args = []string{"."}
52+
}
53+
54+
files, err := parsePackageDir(args[0], fileSuffix)
55+
if err != nil {
56+
log.Fatal("Failed to parse package dir:", err)
57+
}
58+
fs := token.NewFileSet()
59+
for _, f := range files {
60+
src, err := curryFile(fs, f, functionModifier)
61+
if err != nil {
62+
log.Fatal(err)
63+
}
64+
src, err = removeUnusedImports(fs, src)
65+
if err != nil {
66+
log.Fatal(err)
67+
}
68+
curriedFile := file{
69+
Name: strings.Replace(f.Name, ".go", fileSuffix, -1),
70+
Directory: f.Directory,
71+
}
72+
ioutil.WriteFile(curriedFile.Path(), src, 0644)
73+
}
74+
}
75+
76+
// parsePackageDir checks the `directory` and collects all relevant go files
77+
// that do not have the `fileSuffix`.
78+
func parsePackageDir(directory, fileSuffix string) ([]file, error) {
79+
pkg, err := build.Default.ImportDir(directory, 0)
80+
if err != nil {
81+
return nil, err
82+
}
83+
84+
var files []file
85+
addFiles := func(x []string) {
86+
for _, filename := range x {
87+
if !strings.HasSuffix(filename, ".go") ||
88+
strings.HasSuffix(filename, fileSuffix) {
89+
continue
90+
}
91+
files = append(files, file{Name: filename, Directory: directory})
92+
}
93+
}
94+
addFiles(pkg.GoFiles)
95+
addFiles(pkg.CgoFiles)
96+
addFiles(pkg.SFiles)
97+
98+
return files, nil
99+
}
100+
101+
// removeUnusedImports removes all unused imports from `src`.
102+
func removeUnusedImports(fs *token.FileSet, src []byte) ([]byte, error) {
103+
f, err := parser.ParseFile(fs, "", src, 0)
104+
if err != nil {
105+
return nil, err
106+
}
107+
for _, imp := range f.Imports {
108+
path := strings.Trim(imp.Path.Value, "\"")
109+
if !astutil.UsesImport(f, path) {
110+
astutil.DeleteImport(fs, f, path)
111+
}
112+
}
113+
114+
var srcOut bytes.Buffer
115+
err = format.Node(&srcOut, fs, f)
116+
if err != nil {
117+
return nil, err
118+
}
119+
120+
return srcOut.Bytes(), nil
121+
}
122+
123+
// curryFile analyzes the given `file` from fileset `fs` and returns a string
124+
// containing all curried function versions. The `modifier` is appended to the
125+
// function names, so `func bla(arg1, arg2 string)` becomes
126+
// `func blaC(arg1 string)... `if modifier is "C".
127+
func curryFile(fs *token.FileSet, file file, modifier string) ([]byte, error) {
128+
f, err := parser.ParseFile(fs, file.Path(), nil, 0)
129+
if err != nil {
130+
return nil, err
131+
}
132+
info := types.Info{
133+
Types: make(map[ast.Expr]types.TypeAndValue),
134+
Defs: make(map[*ast.Ident]types.Object),
135+
Uses: make(map[*ast.Ident]types.Object),
136+
}
137+
conf := types.Config{Importer: importer.Default()}
138+
conf.Check(file.Directory, fs, []*ast.File{f}, &info)
139+
140+
var src bytes.Buffer
141+
fmt.Fprint(&src, fileHeader)
142+
fmt.Fprintf(&src, "package %s\n\n", f.Name.String())
143+
144+
for _, imp := range f.Imports {
145+
fmt.Fprintf(&src, "import %s\n", imp.Path.Value)
146+
}
147+
148+
nameMod := strings.TrimSpace(modifier)
149+
if len(modifier) == 0 {
150+
nameMod = "C"
151+
}
152+
ast.Inspect(f, func(n ast.Node) bool {
153+
switch n := n.(type) {
154+
case *ast.FuncDecl:
155+
curryFunction(&src, n.Name.String(), nameMod, fs,
156+
info.Defs[n.Name].Type().(*types.Signature),
157+
)
158+
}
159+
return true
160+
})
161+
return format.Source(src.Bytes())
162+
}
163+
164+
// curryFunction writes all function versions (one for each argument) to `src`.
165+
func curryFunction(src io.Writer, name, modifier string,
166+
fs *token.FileSet, funcSig *types.Signature,
167+
) error {
168+
if funcSig.Params().Len() < 2 {
169+
return nil
170+
}
171+
typeName := func(v *types.Var) string {
172+
var buf bytes.Buffer
173+
types.WriteType(&buf, v.Type(), func(p *types.Package) string {
174+
if p.Name() != "main" {
175+
return p.Name()
176+
}
177+
return ""
178+
})
179+
return buf.String()
180+
}
181+
params := funcSig.Params()
182+
for i := 0; i < params.Len(); i++ {
183+
param := params.At(i)
184+
if i == 0 {
185+
if funcSig.Recv() != nil {
186+
recv := funcSig.Recv()
187+
fmt.Fprintf(src, "func (%s %s) %s%s",
188+
recv.Name(),
189+
typeName(recv),
190+
name,
191+
modifier,
192+
)
193+
} else {
194+
fmt.Fprintf(src, "func %s%s", name, modifier)
195+
}
196+
} else {
197+
fmt.Fprintf(src, "return func ")
198+
}
199+
fmt.Fprintf(src, "(%s %s) ", param.Name(), typeName(param))
200+
for j := i + 1; j < params.Len(); j++ {
201+
remaining := params.At(j)
202+
fmt.Fprintf(src, "func (%s) ", typeName(remaining))
203+
}
204+
fmt.Fprintf(src, "%s{\n", funcSig.Results().String())
205+
}
206+
receiver := ""
207+
if funcSig.Recv() != nil {
208+
receiver = fmt.Sprintf("%s.", funcSig.Recv().Name())
209+
}
210+
fmt.Fprintf(src, "return %s%s(", receiver, name)
211+
for i := 0; i < params.Len(); i++ {
212+
param := params.At(i)
213+
fmt.Fprintf(src, "%s", param.Name())
214+
if i < params.Len()-1 {
215+
fmt.Fprint(src, ", ")
216+
} else {
217+
fmt.Fprint(src, ")\n")
218+
}
219+
}
220+
for i := 0; i < params.Len(); i++ {
221+
fmt.Fprintf(src, "}\n")
222+
}
223+
return nil
224+
}

example/README.md

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# Example curry
2+
3+
To run the example:
4+
```shell
5+
go generate
6+
go build
7+
./example
8+
```

example/example.go

+59
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
//go:generate curry -file-suffix=_curried.go -function-modifier=C
2+
3+
package main
4+
5+
import "fmt"
6+
7+
// The generator will generate a curried version of the function:
8+
//
9+
// func appendC(seperator string) func(string) string {
10+
// return func(str string) string {
11+
// return append(seperator, str)
12+
// }
13+
// }
14+
func append(seperator, str string) string {
15+
return str + seperator
16+
}
17+
18+
type Superhero struct {
19+
Name string
20+
}
21+
22+
// The generator will generate a curried version of the function:
23+
//
24+
// func (s Superhero) AbilityC(against Superhero) func(bool) string {
25+
// return func(destructive bool) string {
26+
// return s.Ability(against, destructive)
27+
// }
28+
// }
29+
func (s Superhero) Ability(against Superhero, destructive bool) string {
30+
w := "wins"
31+
if !destructive {
32+
w = "loses"
33+
}
34+
return fmt.Sprintf("%s uses ability agains %s and %s the fight.",
35+
s.Name, against.Name, w,
36+
)
37+
}
38+
39+
func main() {
40+
sentences := []string{
41+
"Programmer",
42+
"A person who fixed a problem",
43+
"That you don't know you have",
44+
"In a way you don't understand",
45+
}
46+
appendNewline := appendC(".\n")
47+
48+
var result string
49+
for _, sentence := range sentences {
50+
result += appendNewline(sentence)
51+
}
52+
fmt.Println(result)
53+
54+
batman := Superhero{Name: "Batman"}
55+
superman := Superhero{Name: "Superman"}
56+
batmanVsSuperman := batman.AbilityC(superman)
57+
fmt.Println(batmanVsSuperman(true))
58+
fmt.Println(batmanVsSuperman(false))
59+
}

0 commit comments

Comments
 (0)