Skip to content
This repository was archived by the owner on Feb 13, 2025. It is now read-only.

Commit 6189d7a

Browse files
authored
add top 5 accuracy metric (#718)
1 parent d279552 commit 6189d7a

File tree

1 file changed

+21
-7
lines changed

1 file changed

+21
-7
lines changed

Diff for: TrainingLoop/Metrics.swift

+21-7
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import TensorFlow
44
public enum TrainingMetrics {
55
case loss
66
case accuracy
7+
case top5Accuracy
78
case matthewsCorrelationCoefficient
89
case perplexity
910

@@ -13,6 +14,8 @@ public enum TrainingMetrics {
1314
return "loss"
1415
case .accuracy:
1516
return "accuracy"
17+
case .top5Accuracy:
18+
return "top5Accuracy"
1619
case .matthewsCorrelationCoefficient:
1720
return "mcc"
1821
case .perplexity:
@@ -25,7 +28,9 @@ public enum TrainingMetrics {
2528
case .loss:
2629
return LossMeasurer(self.name)
2730
case .accuracy:
28-
return AccuracyMeasurer(self.name)
31+
return TopKAccuracyMeasurer(self.name)
32+
case .top5Accuracy:
33+
return TopKAccuracyMeasurer(self.name, n: 5)
2934
case .matthewsCorrelationCoefficient:
3035
return MCCMeasurer(self.name)
3136
case .perplexity:
@@ -89,20 +94,22 @@ public struct LossMeasurer: MetricsMeasurer {
8994
}
9095
}
9196

92-
/// A measurer for measuring accuracy
93-
public struct AccuracyMeasurer: MetricsMeasurer {
97+
/// A measurer for measuring accuracy (top k, default k=1)
98+
public struct TopKAccuracyMeasurer: MetricsMeasurer {
9499
/// Name of the AccuracyMeasurer.
95100
public var name: String
101+
public var k: Int32 = 1
96102

97103
/// Count of correct guesses.
98104
private var correctGuessCount: Int32 = 0
99105

100106
/// Count of total guesses.
101107
private var totalGuessCount: Int32 = 0
102108

103-
/// Creates an instance with the AccuracyMeasurer named `name`.
104-
public init(_ name: String = "accuracy") {
109+
/// Creates an instance with the TopKAccuracyMeasurer named `name`.
110+
public init(_ name: String = "accuracy", n: Int32 = 1) {
105111
self.name = name
112+
self.k = n
106113
}
107114

108115
/// Resets correctGuessCount and totalGuessCount to zero.
@@ -123,8 +130,15 @@ public struct AccuracyMeasurer: MetricsMeasurer {
123130
"For accuracy measurements, the model output must be Tensor<Float>, and the labels must be Tensor<Int>."
124131
)
125132
}
126-
correctGuessCount += Tensor<Int32>(predictions.argmax(squeezingAxis: -1) .== labels).sum()
127-
.scalarized()
133+
let predictionsReshaped = predictions.reshaped(
134+
to: [predictions.shape.dropLast().reduce(1, *), predictions.shape.last!])
135+
let labelsReshaped = labels.reshaped(to: [labels.shape.reduce(1, *)])
136+
137+
correctGuessCount += Int32(
138+
Tensor<Int32>(
139+
_Raw.inTopKV2(
140+
predictions: predictionsReshaped, targets: labelsReshaped, k: Tensor<Int32>(k))).sum()
141+
.scalar ?? 0)
128142
totalGuessCount += Int32(labels.shape.reduce(1, *))
129143
}
130144

0 commit comments

Comments
 (0)