
Commit adda08e

TrainingLoop: refactor progress printer and add CSVLogger (#668)
1 parent fafa4f0 commit adda08e

15 files changed: +556 −274 lines


Examples/LeNet-MNIST/main.swift

Lines changed: 19 additions & 10 deletions
@@ -31,24 +31,33 @@ let dataset = MNIST(batchSize: batchSize, on: device)
 
 // The LeNet-5 model, equivalent to `LeNet` in `ImageClassificationModels`.
 var classifier = Sequential {
-  Conv2D<Float>(filterShape: (5, 5, 1, 6), padding: .same, activation: relu)
-  AvgPool2D<Float>(poolSize: (2, 2), strides: (2, 2))
-  Conv2D<Float>(filterShape: (5, 5, 6, 16), activation: relu)
-  AvgPool2D<Float>(poolSize: (2, 2), strides: (2, 2))
-  Flatten<Float>()
-  Dense<Float>(inputSize: 400, outputSize: 120, activation: relu)
-  Dense<Float>(inputSize: 120, outputSize: 84, activation: relu)
-  Dense<Float>(inputSize: 84, outputSize: 10)
+  Conv2D<Float>(filterShape: (5, 5, 1, 6), padding: .same, activation: relu)
+  AvgPool2D<Float>(poolSize: (2, 2), strides: (2, 2))
+  Conv2D<Float>(filterShape: (5, 5, 6, 16), activation: relu)
+  AvgPool2D<Float>(poolSize: (2, 2), strides: (2, 2))
+  Flatten<Float>()
+  Dense<Float>(inputSize: 400, outputSize: 120, activation: relu)
+  Dense<Float>(inputSize: 120, outputSize: 84, activation: relu)
+  Dense<Float>(inputSize: 84, outputSize: 10)
 }
 
 var optimizer = SGD(for: classifier, learningRate: 0.1)
 
-let trainingProgress = TrainingProgress()
 var trainingLoop = TrainingLoop(
   training: dataset.training,
   validation: dataset.validation,
   optimizer: optimizer,
   lossFunction: softmaxCrossEntropy,
-  callbacks: [trainingProgress.update])
+  metrics: [.accuracy],
+  callbacks: [try! CSVLogger().log])
+
+// Compute statistics only when last batch ends.
+trainingLoop.statisticsRecorder!.shouldCompute = {
+  (
+    _ batchIndex: Int, _ batchCount: Int, _ epochIndex: Int, _ epochCount: Int,
+    _ event: TrainingLoopEvent
+  ) -> Bool in
+  return event == .batchEnd && batchIndex + 1 == batchCount
+}
 
 try! trainingLoop.fit(&classifier, epochs: epochCount, on: device)
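The `shouldCompute` hook above restricts metric aggregation to the last batch of each epoch. As an illustrative variant (not part of this commit), the same hook could keep per-batch statistics by accepting every `batchEnd` event:

// Hypothetical alternative: recompute the logged statistics after every batch,
// at the cost of extra work per step.
trainingLoop.statisticsRecorder!.shouldCompute = {
  (
    _ batchIndex: Int, _ batchCount: Int, _ epochIndex: Int, _ epochCount: Int,
    _ event: TrainingLoopEvent
  ) -> Bool in
  return event == .batchEnd
}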

Examples/MobileNetV1-Imagenette/main.swift

Lines changed: 1 addition & 2 deletions
@@ -29,12 +29,11 @@ let dataset = Imagenette(batchSize: 64, inputSize: .resized320, outputSize: 224,
 var model = MobileNetV1(classCount: 10)
 let optimizer = SGD(for: model, learningRate: 0.02, momentum: 0.9)
 
-let trainingProgress = TrainingProgress()
 var trainingLoop = TrainingLoop(
   training: dataset.training,
   validation: dataset.validation,
   optimizer: optimizer,
   lossFunction: softmaxCrossEntropy,
-  callbacks: [trainingProgress.update])
+  metrics: [.accuracy])
 
 try! trainingLoop.fit(&model, epochs: 10, on: device)

Examples/MobileNetV2-Imagenette/main.swift

Lines changed: 1 addition & 2 deletions
@@ -29,12 +29,11 @@ let dataset = Imagenette(batchSize: 64, inputSize: .resized320, outputSize: 224,
 var model = MobileNetV2(classCount: 10)
 let optimizer = SGD(for: model, learningRate: 0.002, momentum: 0.9)
 
-let trainingProgress = TrainingProgress()
 var trainingLoop = TrainingLoop(
   training: dataset.training,
   validation: dataset.validation,
   optimizer: optimizer,
   lossFunction: softmaxCrossEntropy,
-  callbacks: [trainingProgress.update])
+  metrics: [.accuracy])
 
 try! trainingLoop.fit(&model, epochs: 10, on: device)

Examples/ResNet-CIFAR10/main.swift

Lines changed: 1 addition & 2 deletions
@@ -29,12 +29,11 @@ let dataset = CIFAR10(batchSize: 10, on: device)
 var model = ResNet(classCount: 10, depth: .resNet56, downsamplingInFirstStage: false)
 var optimizer = SGD(for: model, learningRate: 0.001)
 
-let trainingProgress = TrainingProgress()
 var trainingLoop = TrainingLoop(
   training: dataset.training,
   validation: dataset.validation,
   optimizer: optimizer,
   lossFunction: softmaxCrossEntropy,
-  callbacks: [trainingProgress.update])
+  metrics: [.accuracy])
 
 try! trainingLoop.fit(&model, epochs: 10, on: device)

Examples/VGG-Imagewoof/main.swift

Lines changed: 2 additions & 2 deletions
@@ -39,12 +39,12 @@ public func scheduleLearningRate<L: TrainingLoopProtocol>(
   }
 }
 
-let trainingProgress = TrainingProgress()
 var trainingLoop = TrainingLoop(
   training: dataset.training,
   validation: dataset.validation,
   optimizer: optimizer,
   lossFunction: softmaxCrossEntropy,
-  callbacks: [trainingProgress.update, scheduleLearningRate])
+  metrics: [.accuracy],
+  callbacks: [scheduleLearningRate])
 
 try! trainingLoop.fit(&model, epochs: 90, on: device)

Support/FileSystem.swift

Lines changed: 1 addition & 0 deletions
@@ -39,4 +39,5 @@ public protocol File {
   func read(position: Int, count: Int) throws -> Data
   func write(_ value: Data) throws
   func write(_ value: Data, position: Int) throws
+  func append(_ value: Data) throws
 }

Support/FoundationFileSystem.swift

Lines changed: 10 additions & 0 deletions
@@ -58,4 +58,14 @@ public struct FoundationFile: File {
     // TODO: Incorporate file offset.
     try value.write(to: location)
   }
+
+  /// Append data to the file.
+  ///
+  /// - Parameter value: data to be appended at the end.
+  public func append(_ value: Data) throws {
+    let fileHandler = try FileHandle(forUpdating: location)
+    try fileHandler.seekToEnd()
+    try fileHandler.write(contentsOf: value)
+    try fileHandler.close()
+  }
 }
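A minimal usage sketch of the new `append(_:)` (the path and row contents here are hypothetical; the file must already exist, because `FileHandle(forUpdating:)` does not create it):

import Foundation
import ModelSupport

// Assumes "run/log.csv" was created earlier, e.g. by an initial write(Data()).
let file = FoundationFile(path: "run/log.csv")
try file.append("epoch, batch, loss\n".data(using: .utf8)!)
try file.append("1/12, 468/468, 0.4819\n".data(using: .utf8)!)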

TrainingLoop/CMakeLists.txt

Lines changed: 4 additions & 2 deletions
@@ -1,8 +1,10 @@
 add_library(TrainingLoop
   LossFunctions.swift
+  Metrics.swift
   TrainingLoop.swift
-  TrainingProgress.swift
-  TrainingStatistics.swift)
+  Callbacks/StatisticsRecorder.swift
+  Callbacks/ProgressPrinter.swift
+  Callbacks/CSVLogger.swift)
 target_link_libraries(TrainingLoop PUBLIC
   ModelSupport)
 set_target_properties(TrainingLoop PROPERTIES
TrainingLoop/Callbacks/CSVLogger.swift

Lines changed: 74 additions & 0 deletions

@@ -0,0 +1,74 @@
import Foundation
import ModelSupport

public enum CSVLoggerError: Error {
  case InvalidPath
}

/// A handler for logging training and validation statistics to a CSV file.
public class CSVLogger {
  /// The path of the file that statistics are logged to.
  public var path: String

  // True iff the header of the CSV file has been written.
  fileprivate var headerWritten: Bool

  /// Creates an instance that logs to a file with the given path.
  ///
  /// Throws: File system errors.
  public init(path: String = "run/log.csv") throws {
    self.path = path

    // Validate the path.
    let url = URL(fileURLWithPath: path)
    if url.pathExtension != "csv" {
      throw CSVLoggerError.InvalidPath
    }
    // Create the containing directory if it is missing.
    try FoundationFileSystem().createDirectoryIfMissing(at: url.deletingLastPathComponent().path)
    // Initialize the file with empty string.
    try FoundationFile(path: path).write(Data())

    self.headerWritten = false
  }

  /// Logs the statistics for the 'loop' when 'batchEnd' event happens;
  /// ignoring other events.
  ///
  /// Throws: File system errors.
  public func log<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws {
    switch event {
    case .batchEnd:
      guard let epochIndex = loop.epochIndex, let epochCount = loop.epochCount,
        let batchIndex = loop.batchIndex, let batchCount = loop.batchCount,
        let stats = loop.lastStatsLog
      else {
        // No-Op if trainingLoop doesn't set the required values for stats logging.
        return
      }

      if !headerWritten {
        try writeHeader(stats: stats)
        headerWritten = true
      }

      try writeDataRow(
        epoch: "\(epochIndex + 1)/\(epochCount)",
        batch: "\(batchIndex + 1)/\(batchCount)",
        stats: stats)
    default:
      return
    }
  }

  func writeHeader(stats: [(name: String, value: Float)]) throws {
    let header = (["epoch", "batch"] + stats.lazy.map { $0.name }).joined(separator: ", ") + "\n"
    try FoundationFile(path: path).append(header.data(using: .utf8)!)
  }

  func writeDataRow(epoch: String, batch: String, stats: [(name: String, value: Float)]) throws {
    let dataRow = ([epoch, batch] + stats.lazy.map { String($0.value) }).joined(separator: ", ")
      + "\n"
    try FoundationFile(path: path).append(dataRow.data(using: .utf8)!)
  }
}
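For orientation, a minimal sketch of how the logger slots into a training loop and the kind of file it produces. The path and the statistic values below are hypothetical; the column names come from whatever `lastStatsLog` reports:

let logger = try CSVLogger(path: "run/mnist.csv")
var trainingLoop = TrainingLoop(
  training: dataset.training,
  validation: dataset.validation,
  optimizer: optimizer,
  lossFunction: softmaxCrossEntropy,
  metrics: [.accuracy],
  callbacks: [logger.log])

// Resulting run/mnist.csv, assuming loss and accuracy statistics:
// epoch, batch, loss, accuracy
// 1/12, 468/468, 0.4819, 0.8513
// 2/12, 468/468, 0.152, 0.9521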
TrainingLoop/Callbacks/ProgressPrinter.swift

Lines changed: 85 additions & 0 deletions

@@ -0,0 +1,85 @@
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import Foundation

let progressBarLength = 30

/// A handler for printing the training and validation progress.
public class ProgressPrinter {
  /// Print training or validation progress in response of the 'event'.
  ///
  /// An example of the progress would be:
  /// Epoch 1/12
  /// 468/468 [==============================] - loss: 0.4819 - accuracy: 0.8513
  /// 79/79 [==============================] - loss: 0.1520 - accuracy: 0.9521
  public func print<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws {
    switch event {
    case .epochStart:
      guard let epochIndex = loop.epochIndex, let epochCount = loop.epochCount else {
        // No-Op if trainingLoop doesn't set the required values for progress printing.
        return
      }

      Swift.print("Epoch \(epochIndex + 1)/\(epochCount)")
    case .batchEnd:
      guard let batchIndex = loop.batchIndex, let batchCount = loop.batchCount else {
        // No-Op if trainingLoop doesn't set the required values for progress printing.
        return
      }

      let progressBar = formatProgressBar(
        progress: Float(batchIndex + 1) / Float(batchCount), length: progressBarLength)
      var stats: String = ""
      if let lastStatsLog = loop.lastStatsLog {
        stats = formatStats(lastStatsLog)
      }

      Swift.print(
        "\r\(batchIndex + 1)/\(batchCount) \(progressBar)\(stats)",
        terminator: ""
      )
      fflush(stdout)
    case .epochEnd:
      Swift.print("")
    case .validationStart:
      Swift.print("")
    default:
      return
    }
  }

  func formatProgressBar(progress: Float, length: Int) -> String {
    let progressSteps = Int(round(Float(length) * progress))
    let leading = String(repeating: "=", count: progressSteps)
    let separator: String
    let trailing: String
    if progressSteps < progressBarLength {
      separator = ">"
      trailing = String(repeating: ".", count: progressBarLength - progressSteps - 1)
    } else {
      separator = ""
      trailing = ""
    }
    return "[\(leading)\(separator)\(trailing)]"
  }

  func formatStats(_ stats: [(String, Float)]) -> String {
    var result = ""
    for stat in stats {
      result += " - \(stat.0): \(String(format: "%.4f", stat.1))"
    }
    return result
  }
}
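As a quick check of the bar formatting (illustrative only; `formatProgressBar` has internal access, so this compiles only inside the TrainingLoop module or its tests), progress 0.5 over the default length of 30 yields a half-filled bar:

let printer = ProgressPrinter()
// Prints: [===============>..............]
// (15 '=' characters, a '>' separator, then 14 '.' characters)
print(printer.formatProgressBar(progress: 0.5, length: progressBarLength))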
