Skip to content

Commit 7b6c28e

Browse files
committed
Confusion matrix!
1 parent a400cce commit 7b6c28e

File tree

5 files changed

+107
-7
lines changed

5 files changed

+107
-7
lines changed

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ require (
174174
github.com/coder/flog v1.0.0 // indirect
175175
github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect
176176
github.com/gocarina/gocsv v0.0.0-20230123225133-763e25b40669 // indirect
177+
github.com/goml/gobrain v0.0.0-20201212123421-2e2d98ca8249 // indirect
177178
github.com/hashicorp/go-cty v1.4.1-0.20200414143053-d3edf31b6320 // indirect
178179
github.com/hashicorp/go-hclog v1.2.1 // indirect
179180
github.com/hashicorp/go-plugin v1.4.4 // indirect

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -860,6 +860,8 @@ github.com/golangci/maligned v0.0.0-20180506175553-b1d89398deca/go.mod h1:tvlJhZ
860860
github.com/golangci/misspell v0.3.5/go.mod h1:dEbvlSfYbMQDtrpRMQU675gSDLDNa8sCPPChZ7PhiVA=
861861
github.com/golangci/revgrep v0.0.0-20210930125155-c22e5001d4f2/go.mod h1:LK+zW4MpyytAWQRz0M4xnzEk50lSvqDQKfx304apFkY=
862862
github.com/golangci/unconvert v0.0.0-20180507085042-28b1c447d1f4/go.mod h1:Izgrg8RkN3rCIMLGE9CyYmU9pY2Jer6DgANEnZ/L/cQ=
863+
github.com/goml/gobrain v0.0.0-20201212123421-2e2d98ca8249 h1:Xst86cFqcrNSUxId+2A/o3NZUwm7H+H6bUcehmD2t4o=
864+
github.com/goml/gobrain v0.0.0-20201212123421-2e2d98ca8249/go.mod h1:imJK2QRE3080lm54nE96nRHXv7EvvVMCbF5N2QdGa9I=
863865
github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
864866
github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
865867
github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=

scripts/aistart/loadcsv.go

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,23 @@ type trainingRow struct {
2727
// HourOfDay ranges from 0 to 23
2828
HourOfDay int `csv:"hour"`
2929
// Day of Week ranges from 0 to 6
30-
DayOfWeek int `csv:"day"`
31-
Used bool `csv:"used"`
30+
DayOfWeek int `csv:"day"`
31+
Used int `csv:"used"`
32+
}
33+
34+
func (t trainingRow) vectorize() vector {
35+
return [][]float64{
36+
{float64(t.HourOfDay) / 23, float64(t.DayOfWeek / 6)},
37+
{float64(t.Used)},
38+
}
3239
}
3340

3441
type dbRow struct {
3542
Time time.Time
3643
WorkspaceID string
3744
}
3845

39-
func (db dbRow) convert(used bool) trainingRow {
46+
func (db dbRow) convert(used int) trainingRow {
4047
return trainingRow{
4148
WorkspaceID: db.WorkspaceID,
4249
HourOfDay: db.Time.Hour(),
@@ -66,12 +73,12 @@ func generateTrainingRows(rs []dbRow) []trainingRow {
6673
WorkspaceID: wid,
6774
HourOfDay: last.Hour(),
6875
DayOfWeek: int(last.Weekday()),
69-
Used: false,
76+
Used: 0,
7077
})
7178
}
7279
}
7380
}
74-
trainingRows = append(trainingRows, r.convert(true))
81+
trainingRows = append(trainingRows, r.convert(1))
7582
}
7683

7784
return trainingRows
@@ -124,8 +131,8 @@ func loadTrainingCSV() *cobra.Command {
124131
flog.Info("loaded %v rows", len(rs))
125132
trainingRows := generateTrainingRows(rs)
126133
flog.Info("generated %v training rows", len(trainingRows))
127-
gocsv.Marshal(trainingRows, os.Stdout)
128-
return nil
134+
err = gocsv.Marshal(trainingRows, os.Stdout)
135+
return err
129136
},
130137
}
131138
}

scripts/aistart/main.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ func main() {
1616
}
1717

1818
cmd.AddCommand(loadTrainingCSV())
19+
cmd.AddCommand(train())
1920

2021
cmd, err := cmd.ExecuteC()
2122
if err != nil {

scripts/aistart/train.go

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
package main
2+
3+
import (
4+
"fmt"
5+
"math"
6+
"math/rand"
7+
"os"
8+
"text/tabwriter"
9+
10+
"github.com/gocarina/gocsv"
11+
"github.com/goml/gobrain"
12+
"github.com/spf13/cobra"
13+
14+
"github.com/coder/flog"
15+
)
16+
17+
type vector [][]float64
18+
19+
type pattern []vector
20+
21+
func (p pattern) floats() [][][]float64 {
22+
var r [][][]float64
23+
for _, v := range p {
24+
r = append(r, [][]float64(v))
25+
}
26+
return r
27+
}
28+
29+
func vectorizeTrainingRows(rs []trainingRow) pattern {
30+
var p pattern
31+
for _, r := range rs {
32+
p = append(p, r.vectorize())
33+
}
34+
return p
35+
}
36+
37+
func splitTrainTest(rat float64, p pattern) (train, test pattern) {
38+
perms := rand.Perm(len(p))
39+
for i, v := range p {
40+
if float64(perms[i])/float64(len(p)) > rat {
41+
test = append(test, v)
42+
} else {
43+
train = append(train, v)
44+
}
45+
}
46+
return train, test
47+
}
48+
49+
func train() *cobra.Command {
50+
return &cobra.Command{
51+
Use: "train",
52+
RunE: func(cmd *cobra.Command, _ []string) error {
53+
var rs []trainingRow
54+
55+
err := gocsv.Unmarshal(os.Stdin, &rs)
56+
if err != nil {
57+
return err
58+
}
59+
60+
all := vectorizeTrainingRows(rs)
61+
62+
train, test := splitTrainTest(0.5, all)
63+
64+
flog.Info("split train test: %v/%v", len(train), len(test))
65+
66+
ff := &gobrain.FeedForward{}
67+
ff.Init(2, 2, 1)
68+
ff.Train(train.floats(), 50, 0.001, 0.4, true)
69+
var (
70+
// confusionMatrix has actual values in the first index with
71+
// predicted values in the second.
72+
confusionMatrix [2][2]int
73+
)
74+
for _, v := range train {
75+
want := v[1][0]
76+
gotArr := ff.Update(v[0])
77+
got := gotArr[0]
78+
confusionMatrix[0][int(math.Round(want))]++
79+
confusionMatrix[1][int(math.Round(got))]++
80+
}
81+
twr := tabwriter.NewWriter(os.Stderr, 0, 4, 3, ' ', 0)
82+
_, _ = fmt.Fprintf(twr, "-\tOff\tOn\n")
83+
_, _ = fmt.Fprintf(twr, "Actual\t%v\t%v\n", confusionMatrix[0][0], confusionMatrix[0][1])
84+
_, _ = fmt.Fprintf(twr, "Predicted\t%v\t%v\n", confusionMatrix[1][0], confusionMatrix[1][1])
85+
twr.Flush()
86+
return nil
87+
},
88+
}
89+
}

0 commit comments

Comments
 (0)