@@ -4,6 +4,7 @@ import TensorFlow
4
4
public enum TrainingMetrics {
5
5
case loss
6
6
case accuracy
7
+ case top5Accuracy
7
8
case matthewsCorrelationCoefficient
8
9
case perplexity
9
10
@@ -13,6 +14,8 @@ public enum TrainingMetrics {
13
14
return " loss "
14
15
case . accuracy:
15
16
return " accuracy "
17
+ case . top5Accuracy:
18
+ return " top5Accuracy "
16
19
case . matthewsCorrelationCoefficient:
17
20
return " mcc "
18
21
case . perplexity:
@@ -25,7 +28,9 @@ public enum TrainingMetrics {
25
28
case . loss:
26
29
return LossMeasurer ( self . name)
27
30
case . accuracy:
28
- return AccuracyMeasurer ( self . name)
31
+ return TopKAccuracyMeasurer ( self . name)
32
+ case . top5Accuracy:
33
+ return TopKAccuracyMeasurer ( self . name, n: 5 )
29
34
case . matthewsCorrelationCoefficient:
30
35
return MCCMeasurer ( self . name)
31
36
case . perplexity:
@@ -89,20 +94,22 @@ public struct LossMeasurer: MetricsMeasurer {
89
94
}
90
95
}
91
96
92
- /// A measurer for measuring accuracy
93
- public struct AccuracyMeasurer : MetricsMeasurer {
97
+ /// A measurer for measuring accuracy (top k, default k=1)
98
+ public struct TopKAccuracyMeasurer : MetricsMeasurer {
94
99
/// Name of the AccuracyMeasurer.
95
100
public var name : String
101
+ public var k : Int32 = 1
96
102
97
103
/// Count of correct guesses.
98
104
private var correctGuessCount : Int32 = 0
99
105
100
106
/// Count of total guesses.
101
107
private var totalGuessCount : Int32 = 0
102
108
103
- /// Creates an instance with the AccuracyMeasurer named `name`.
104
- public init ( _ name: String = " accuracy " ) {
109
+ /// Creates an instance with the TopKAccuracyMeasurer named `name`.
110
+ public init ( _ name: String = " accuracy " , n : Int32 = 1 ) {
105
111
self . name = name
112
+ self . k = n
106
113
}
107
114
108
115
/// Resets correctGuessCount and totalGuessCount to zero.
@@ -123,8 +130,15 @@ public struct AccuracyMeasurer: MetricsMeasurer {
123
130
" For accuracy measurements, the model output must be Tensor<Float>, and the labels must be Tensor<Int>. "
124
131
)
125
132
}
126
- correctGuessCount += Tensor < Int32 > ( predictions. argmax ( squeezingAxis: - 1 ) .== labels) . sum ( )
127
- . scalarized ( )
133
+ let predictionsReshaped = predictions. reshaped (
134
+ to: [ predictions. shape. dropLast ( ) . reduce ( 1 , * ) , predictions. shape. last!] )
135
+ let labelsReshaped = labels. reshaped ( to: [ labels. shape. reduce ( 1 , * ) ] )
136
+
137
+ correctGuessCount += Int32 (
138
+ Tensor < Int32 > (
139
+ _Raw. inTopKV2 (
140
+ predictions: predictionsReshaped, targets: labelsReshaped, k: Tensor < Int32 > ( k) ) ) . sum ( )
141
+ . scalar ?? 0 )
128
142
totalGuessCount += Int32 ( labels. shape. reduce ( 1 , * ) )
129
143
}
130
144
0 commit comments