Skip to content

Commit daf920b

Browse files
committed
Finish with gobrain
1 parent 7b6c28e commit daf920b

File tree

2 files changed

+63
-27
lines changed

2 files changed

+63
-27
lines changed

scripts/aistart/loadcsv.go

Lines changed: 56 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,33 @@ type trainingRow struct {
2727
// HourOfDay ranges from 0 to 23
2828
HourOfDay int `csv:"hour"`
2929
// DayOfWeek ranges from 0 (Sunday) to 6 (Saturday)
30-
DayOfWeek int `csv:"day"`
31-
Used int `csv:"used"`
30+
DayOfWeek int `csv:"day"`
31+
HoursSinceUsed int `csv:"hours_since_used"`
32+
Used int `csv:"used"`
33+
}
34+
35+
func (t trainingRow) vectorizeHourOfDay() []float64 {
36+
var fs []float64
37+
for i := 0; i < 24; i++ {
38+
if i == t.HourOfDay {
39+
fs = append(fs, 1)
40+
} else {
41+
fs = append(fs, 0)
42+
}
43+
}
44+
return fs
3245
}
3346

3447
func (t trainingRow) vectorize() vector {
3548
return [][]float64{
36-
{float64(t.HourOfDay) / 23, float64(t.DayOfWeek / 6)},
49+
append(
50+
[]float64{
51+
float64(t.HoursSinceUsed) / 61,
52+
float64(t.HourOfDay) / 23,
53+
float64(t.DayOfWeek),
54+
},
55+
t.vectorizeHourOfDay()...,
56+
),
3757
{float64(t.Used)},
3858
}
3959
}
@@ -43,21 +63,15 @@ type dbRow struct {
4363
WorkspaceID string
4464
}
4565

46-
func (db dbRow) convert(used int) trainingRow {
47-
return trainingRow{
48-
WorkspaceID: db.WorkspaceID,
49-
HourOfDay: db.Time.Hour(),
50-
DayOfWeek: int(db.Time.Weekday()),
51-
Used: used,
52-
}
53-
}
54-
5566
// generateTrainingRows accepts sparse input data from the DB and creates
5667
// trainingRows suitable to enter a prediction model.
5768
func generateTrainingRows(rs []dbRow) []trainingRow {
58-
workspaceIDs := make(map[string]struct{})
69+
// workspaceIDs maps each workspace ID to the time it was last seen.
70+
// We first map every ID to the zero time so we can easily fill in
71+
// missing hours.
72+
workspaceIDs := make(map[string]time.Time)
5973
for _, r := range rs {
60-
workspaceIDs[r.WorkspaceID] = struct{}{}
74+
workspaceIDs[r.WorkspaceID] = time.Time{}
6175
}
6276

6377
var trainingRows []trainingRow
@@ -66,19 +80,40 @@ func generateTrainingRows(rs []dbRow) []trainingRow {
6680
for _, r := range rs {
6781
if !r.Time.Equal(last) && !last.IsZero() {
6882
// We just skipped a time-slot, so we must fill in the blanks.
69-
for last.Before(r.Time) {
83+
for {
7084
last = last.Add(time.Hour)
85+
if !last.Before(r.Time) {
86+
break
87+
}
7188
for wid := range workspaceIDs {
89+
var hoursSinceLastUsed int
90+
if !workspaceIDs[wid].IsZero() {
91+
hoursSinceLastUsed = int(last.Sub(workspaceIDs[wid]) / time.Hour)
92+
}
7293
trainingRows = append(trainingRows, trainingRow{
73-
WorkspaceID: wid,
74-
HourOfDay: last.Hour(),
75-
DayOfWeek: int(last.Weekday()),
76-
Used: 0,
94+
WorkspaceID: wid,
95+
HourOfDay: last.Hour(),
96+
DayOfWeek: int(last.Weekday()),
97+
HoursSinceUsed: hoursSinceLastUsed,
98+
Used: 0,
7799
})
78100
}
79101
}
80102
}
81-
trainingRows = append(trainingRows, r.convert(1))
103+
workspaceLastSeen := workspaceIDs[r.WorkspaceID]
104+
workspaceIDs[r.WorkspaceID] = r.Time
105+
106+
var hoursSinceLastSeen int
107+
if !workspaceLastSeen.IsZero() {
108+
hoursSinceLastSeen = int(r.Time.Sub(workspaceLastSeen) / time.Hour)
109+
}
110+
trainingRows = append(trainingRows, trainingRow{
111+
WorkspaceID: r.WorkspaceID,
112+
HourOfDay: r.Time.Hour(),
113+
DayOfWeek: int(r.Time.Weekday()),
114+
HoursSinceUsed: hoursSinceLastSeen,
115+
Used: 1,
116+
})
82117
}
83118

84119
return trainingRows
@@ -101,7 +136,7 @@ func loadTrainingCSV() *cobra.Command {
101136
JOIN workspaces w ON
102137
w.id = ag.workspace_id
103138
WHERE
104-
NOT w.deleted
139+
NOT w.deleted AND w.id = '0170be1c-735f-4a69-8223-8ef86af56ef5'
105140
GROUP BY
106141
workspace_id,
107142
user_id,

scripts/aistart/train.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ func vectorizeTrainingRows(rs []trainingRow) pattern {
3535
}
3636

3737
func splitTrainTest(rat float64, p pattern) (train, test pattern) {
38-
perms := rand.Perm(len(p))
38+
rng := rand.New(rand.NewSource(1))
39+
perms := rng.Perm(len(p))
3940
for i, v := range p {
4041
if float64(perms[i])/float64(len(p)) > rat {
4142
test = append(test, v)
@@ -64,14 +65,14 @@ func train() *cobra.Command {
6465
flog.Info("split train test: %v/%v", len(train), len(test))
6566

6667
ff := &gobrain.FeedForward{}
67-
ff.Init(2, 2, 1)
68-
ff.Train(train.floats(), 50, 0.001, 0.4, true)
68+
ff.Init(len((trainingRow{}).vectorize()[0]), 4, 1)
69+
ff.Train(train.floats(), 3000, 0.01, 0.4, true)
6970
var (
7071
// confusionMatrix has actual values in the first index with
7172
// predicted values in the second.
7273
confusionMatrix [2][2]int
7374
)
74-
for _, v := range train {
75+
for _, v := range test {
7576
want := v[1][0]
7677
gotArr := ff.Update(v[0])
7778
got := gotArr[0]
@@ -82,8 +83,8 @@ func train() *cobra.Command {
8283
_, _ = fmt.Fprintf(twr, "-\tOff\tOn\n")
8384
_, _ = fmt.Fprintf(twr, "Actual\t%v\t%v\n", confusionMatrix[0][0], confusionMatrix[0][1])
8485
_, _ = fmt.Fprintf(twr, "Predicted\t%v\t%v\n", confusionMatrix[1][0], confusionMatrix[1][1])
85-
twr.Flush()
86-
return nil
86+
err = twr.Flush()
87+
return err
8788
},
8889
}
8990
}

0 commit comments

Comments
 (0)