forked from gorgonia/gorgonia
-
Notifications
You must be signed in to change notification settings - Fork 0
/
templates.go
96 lines (83 loc) · 2.79 KB
/
templates.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
package gorgonia
import (
"fmt"
"strings"
"text/template"
)
// exprNodeTemplText is the text/template source that init parses into
// exprNodeTempl. Executed with a *Node as dot, it produces a graphviz
// HTML-like label: the leading "<" and trailing ">" delimit the label and the
// body is a single <TABLE>.
//
// Border COLOR is chosen by an if/else-if chain: green for leaves, otherwise
// red for roots, otherwise blue for nodes marked as of interest. A leaf that is
// also a root is therefore drawn green. BGCOLOR is lightyellow for inputs,
// otherwise lightblue for statement nodes.
//
// Rows, top to bottom:
//   - node ID in hex, then Name :: node type
//   - Op :: op type, only when printOp is true
//   - the shape, only when hasShape is true (getShape underlines inferred shapes)
//   - the result of overwritesInput (-1 when the node has no op)
//   - Value and Grad side by side when hasGrad is true, otherwise Value alone
//
// NOTE(review): the COLOR values carry a trailing ";" (e.g. "#00FF00;"). Confirm
// the targeted graphviz version accepts this form. Also, .Value, getGrad and
// getShape output are only passed through dotEscape (not html), so a "<" or "&"
// in them reaches the label unescaped.
const exprNodeTemplText = `<
<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0" PORT="anchor" {{if isLeaf .}} COLOR="#00FF00;"{{else if isRoot . }} COLOR="#FF0000;" {{else if isMarked .}} COLOR="#0000FF;" {{end}}{{if isInput .}} BGCOLOR="lightyellow"{{else if isStmt .}} BGCOLOR="lightblue"{{end}}>
<TR><TD>{{printf "%x" .ID}}</TD><TD>{{printf "%v" .Name | html | dotEscape}} :: {{nodeType . | html | dotEscape }}</TD></TR>
{{if printOp . }}<TR><TD>Op</TD><TD>{{ opStr . | html | dotEscape }} :: {{ opType . | html | dotEscape }}</TD></TR>{{end}}
{{if hasShape .}}<TR><TD COLSPAN="2">{{ getShape .}}</TD></TR>{{end}}
<TR><TD COLSPAN="2">{{overwritesInput . }}</TD></TR>
{{if hasGrad .}}<TR><TD>Value</TD><TD>Grad</TD></TR>
<TR><TD>{{printf "%+3.3s" .Value | dotEscape}}</TD><TD>{{getGrad . | dotEscape }} </TD></TR>
{{else}}
<TR><TD>Value</TD><TD>{{printf "%+3.3s" .Value | dotEscape}}</TD></TR>
{{end}}
</TABLE>
>`
// dotEscaper performs every dotEscape substitution in a single pass over the
// input. A strings.Replacer is safe for concurrent use, so one shared
// instance serves all template executions.
var dotEscaper = strings.NewReplacer(
	"\n", "<BR />", // newline -> HTML-like label line break
	"<nil>", "NIL", // fmt's "<nil>" would otherwise be read as a markup tag
)

// dotEscape makes s safe to embed in a graphviz HTML-like label. It turns
// newlines into <BR /> line breaks and replaces the literal "<nil>" with
// "NIL". All other text is returned unchanged.
//
// The two patterns cannot overlap or produce one another, so this single
// pass gives exactly the result of applying the two replacements one after
// the other.
func dotEscape(s string) string {
	return dotEscaper.Replace(s)
}
// printOp reports whether the template should render an Op row for n.
// Statement nodes never get one, and neither do nodes without an op.
func printOp(n *Node) bool {
	if n.op == nil {
		return false
	}
	return !n.isStmt
}

// isLeaf reports whether n has no children.
func isLeaf(n *Node) bool {
	return len(n.children) == 0
}

// isInput delegates to the node's own isInput method.
func isInput(n *Node) bool {
	return n.isInput()
}

// isMarked reports the node's ofInterest flag.
func isMarked(n *Node) bool {
	return n.ofInterest
}

// isRoot delegates to the node's own isRoot method.
func isRoot(n *Node) bool {
	return n.isRoot()
}

// isStmt reports the node's isStmt flag.
func isStmt(n *Node) bool {
	return n.isStmt
}

// hasShape reports whether n carries a non-nil shape.
// A non-nil but empty shape still counts as present.
func hasShape(n *Node) bool {
	return n.shape != nil
}

// hasGrad reports whether n.Grad() succeeds, i.e. returns a nil error.
func hasGrad(n *Node) bool {
	if _, err := n.Grad(); err != nil {
		return false
	}
	return true
}
// opStr renders n's op using its String method.
// n.op must be non-nil; the node template only calls this under printOp.
func opStr(n *Node) string {
	op := n.op
	return op.String()
}

// opType renders the type reported by n's op.
// n.op must be non-nil; the node template only calls this under printOp.
func opType(n *Node) string {
	typ := n.op.Type()
	return typ.String()
}
// nodeType renders n's type for the label. It returns "NIL" when the node
// has no type assigned.
func nodeType(n *Node) string {
	if typ := n.t; typ != nil {
		return typ.String()
	}
	return "NIL"
}
// overwritesInput returns the value reported by n's op OverwritesInput
// method, or -1 when n has no op.
// NOTE(review): presumably the index of an input the op writes into, with -1
// meaning "none". Confirm against the Op interface.
func overwritesInput(n *Node) int {
	if op := n.op; op != nil {
		return op.OverwritesInput()
	}
	return -1
}
// getShape formats n's shape with %v. When the shape was inferred rather
// than set explicitly, it is wrapped in <U>…</U> so the label underlines it.
//
// NOTE(review): the original remark here read "graphviz 2.38+ only supports
// <O>". Confirm which of <U>/<O> the targeted graphviz version renders.
func getShape(n *Node) string {
	shape := fmt.Sprintf("%v", n.shape)
	if n.inferredShape {
		return "<U>" + shape + "</U>"
	}
	return shape
}
// getGrad renders n's gradient followed by its pointer in parentheses. It
// returns "" when n.Grad() reports an error.
func getGrad(n *Node) string {
	grad, err := n.Grad()
	if err != nil {
		return ""
	}
	return fmt.Sprintf("%+3.3s (%p)", grad, grad)
}
// funcMap exposes the node helpers to the node template. Each key is the
// name exprNodeTemplText uses to call the helper.
var funcMap = template.FuncMap{
	// text escaping
	"dotEscape": dotEscape,

	// boolean predicates used in {{if}} conditions
	"hasGrad":  hasGrad,
	"hasShape": hasShape,
	"isInput":  isInput,
	"isLeaf":   isLeaf,
	"isMarked": isMarked,
	"isRoot":   isRoot,
	"isStmt":   isStmt,
	"printOp":  printOp,

	// value renderers
	"getGrad":         getGrad,
	"getShape":        getShape,
	"nodeType":        nodeType,
	"opStr":           opStr,
	"opType":          opType,
	"overwritesInput": overwritesInput,
}
var (
	// exprNodeJSONTempl is declared here but never assigned in this file.
	// NOTE(review): confirm whether another file in the package initializes
	// it, or whether it is dead.
	exprNodeJSONTempl *template.Template

	// exprNodeTempl is the node label template, parsed from
	// exprNodeTemplText in init.
	exprNodeTempl *template.Template
)
// init parses the node label template once, at package load. template.Must
// panics if exprNodeTemplText does not parse, which would be a programming
// error in the template source.
func init() {
	tmpl := template.New("node")
	tmpl = tmpl.Funcs(funcMap)
	exprNodeTempl = template.Must(tmpl.Parse(exprNodeTemplText))
}