@@ -20,41 +20,79 @@ import Foundation
20
20
import TensorFlow
21
21
import Batcher
22
22
23
- public struct KuzushijiMNIST : ImageClassificationDataset {
24
- public typealias SourceDataSet = [ TensorPair < Float , Int32 > ]
25
- public let training : Batcher < SourceDataSet >
26
- public let test : Batcher < SourceDataSet >
23
+ public struct KuzushijiMNIST < Entropy: RandomNumberGenerator > {
24
+ /// Type of the collection of non-collated batches.
25
+ public typealias Batches = Slices < Sampling < [ ( data: [ UInt8 ] , label: Int32 ) ] , ArraySlice < Int > > >
26
+ /// The type of the training data, represented as a sequence of epochs, which
27
+ /// are collection of batches.
28
+ public typealias Training = LazyMapSequence <
29
+ TrainingEpochs < [ ( data: [ UInt8 ] , label: Int32 ) ] , Entropy > ,
30
+ LazyMapSequence < Batches , LabeledImage >
31
+ >
32
+ /// The type of the validation data, represented as a collection of batches.
33
+ public typealias Validation = LazyMapSequence < Slices < [ ( data: [ UInt8 ] , label: Int32 ) ] > , LabeledImage >
34
+ /// The training epochs.
35
+ public let training : Training
36
+ /// The validation batches.
37
+ public let validation : Validation
27
38
28
- public init ( batchSize: Int ) {
29
- self . init ( batchSize: batchSize, flattening: false , normalizing: false )
30
- }
31
-
32
- public init (
33
- batchSize: Int , flattening: Bool = false , normalizing: Bool = false ,
34
- localStorageDirectory: URL = DatasetUtilities . defaultDirectory
35
- . appendingPathComponent ( " KuzushijiMNIST " , isDirectory: true )
36
- ) {
37
- training = Batcher < SourceDataSet > (
38
- on: fetchMNISTDataset (
39
- localStorageDirectory: localStorageDirectory,
40
- remoteBaseDirectory: " https://storage.googleapis.com/s4tf-hosted-binaries/datasets/KMNIST " ,
41
- imagesFilename: " train-images-idx3-ubyte " ,
42
- labelsFilename: " train-labels-idx1-ubyte " ,
43
- flattening: flattening,
44
- normalizing: normalizing) ,
45
- batchSize: batchSize,
46
- numWorkers: 1 , //No need to use parallelism since everything is loaded in memory
47
- shuffle: true )
39
+ /// Creates an instance with `batchSize`.
40
+ ///
41
+ /// - Parameter entropy: a source of randomness used to shuffle sample
42
+ /// ordering. It will be stored in `self`, so if it is only pseudorandom
43
+ /// and has value semantics, the sequence of epochs is deterministic and not
44
+ /// dependent on other operations.
45
+ public init ( batchSize: Int , entropy: Entropy ) {
46
+ self . init ( batchSize: batchSize, device: Device . default, entropy: entropy,
47
+ flattening: false , normalizing: false )
48
+ }
48
49
49
- test = Batcher < SourceDataSet > (
50
- on: fetchMNISTDataset (
51
- localStorageDirectory: localStorageDirectory,
52
- remoteBaseDirectory: " https://storage.googleapis.com/s4tf-hosted-binaries/datasets/KMNIST " ,
53
- imagesFilename: " t10k-images-idx3-ubyte " ,
54
- labelsFilename: " t10k-labels-idx1-ubyte " ,
55
- flattening: flattening,
56
- normalizing: normalizing) ,
57
- batchSize: batchSize,
58
- numWorkers: 1 ) //No need to use parallelism since everything is loaded in memory
50
+ /// Creates an instance with `batchSize` on `device`.
51
+ ///
52
+ /// - Parameters:
53
+ /// - entropy: a source of randomness used to shuffle sample ordering. It
54
+ /// will be stored in `self`, so if it is only pseudorandom and has value
55
+ /// semantics, the sequence of epochs is deterministic and not dependent
56
+ /// on other operations.
57
+ /// - flattening: flattens the data to be a 2d-tensor iff `true. The default value
58
+ /// is `false`.
59
+ /// - normalizing: normalizes the batches to have values from -1.0 to 1.0 iff `true`.
60
+ /// The default value is `false`.
61
+ /// - localStorageDirectory: the directory in which the dataset is stored.
62
+ public init (
63
+ batchSize: Int , device: Device , entropy: Entropy , flattening: Bool = false ,
64
+ normalizing: Bool = false ,
65
+ localStorageDirectory: URL = DatasetUtilities . defaultDirectory
66
+ . appendingPathComponent ( " KuzushijiMNIST " , isDirectory: true )
67
+ ) {
68
+ training = TrainingEpochs (
69
+ samples: fetchMNISTDataset (
70
+ localStorageDirectory: localStorageDirectory,
71
+ remoteBaseDirectory: " https://storage.googleapis.com/s4tf-hosted-binaries/datasets/KMNIST " ,
72
+ imagesFilename: " train-images-idx3-ubyte " ,
73
+ labelsFilename: " train-labels-idx1-ubyte " ) ,
74
+ batchSize: batchSize, entropy: entropy
75
+ ) . lazy. map { ( batches: Batches ) -> LazyMapSequence < Batches , LabeledImage > in
76
+ return batches. lazy. map { makeMNISTBatch (
77
+ samples: $0, flattening: flattening, normalizing: normalizing, device: device
78
+ ) }
59
79
}
80
+
81
+ validation = fetchMNISTDataset (
82
+ localStorageDirectory: localStorageDirectory,
83
+ remoteBaseDirectory: " https://storage.googleapis.com/s4tf-hosted-binaries/datasets/KMNIST " ,
84
+ imagesFilename: " t10k-images-idx3-ubyte " ,
85
+ labelsFilename: " t10k-labels-idx1-ubyte "
86
+ ) . inBatches ( of: batchSize) . lazy. map {
87
+ makeMNISTBatch ( samples: $0, flattening: flattening, normalizing: normalizing,
88
+ device: device)
89
+ }
90
+ }
60
91
}
92
+
93
+ extension KuzushijiMNIST : ImageClassificationData where Entropy == SystemRandomNumberGenerator {
94
+ /// Creates an instance with `batchSize`.
95
+ public init ( batchSize: Int ) {
96
+ self . init ( batchSize: batchSize, entropy: SystemRandomNumberGenerator ( ) )
97
+ }
98
+ }
0 commit comments