
Commit ee31b7f

tanmayb123 authored and rxwei committed
Add Sigmoid Cross Entropy Loss (#24)
* Add Sigmoid Cross Entropy Loss, a.k.a. "binary_crossentropy" from Keras.
* Add documentation to each loss function.
1 parent 39114bb commit ee31b7f

File tree

1 file changed: 24 additions, 0 deletions


Sources/DeepLearning/Loss.swift

@@ -16,14 +16,38 @@
 import TensorFlow
 #endif
 
+/// Computes the mean squared error between predictions and expected values.
+///
+/// - Parameters:
+///   - predicted: One-hot encoded outputs from a neural network.
+///   - expected: One-hot encoded values that correspond to the correct output.
 @differentiable
 public func meanSquaredError<Scalar: TensorFlowFloatingPoint>(
     predicted: Tensor<Scalar>, expected: Tensor<Scalar>) -> Tensor<Scalar> {
     return (expected - predicted).squared().mean()
 }
 
+/// Computes the softmax cross entropy (categorical cross entropy) between logits and labels.
+///
+/// - Parameters:
+///   - logits: One-hot encoded outputs from a neural network.
+///   - labels: One-hot encoded values that correspond to the correct output.
 @differentiable
 public func softmaxCrossEntropy<Scalar: TensorFlowFloatingPoint>(
     logits: Tensor<Scalar>, labels: Tensor<Scalar>) -> Tensor<Scalar> {
     return -(labels * logSoftmax(logits)).mean(alongAxes: 0).sum()
 }
+
+/// Computes the sigmoid cross entropy (binary cross entropy) between logits and labels.
+///
+/// - Parameters:
+///   - logits: Continuous values from `0` to `1`, such as sigmoid outputs.
+///   - labels: Binary values (`0` or `1`) that correspond to the correct output.
+@differentiable
+public func sigmoidCrossEntropy<Scalar: TensorFlowFloatingPoint>(
+    logits: Tensor<Scalar>, labels: Tensor<Scalar>
+) -> Tensor<Scalar> {
+    let loss = labels * log(logits) +
+        (Tensor<Scalar>(1) - labels) * log(Tensor<Scalar>(1) - logits)
+    return -loss.mean(alongAxes: 0).sum()
+}
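
Below is a minimal usage sketch of the new loss next to the existing meanSquaredError. It is not part of the commit: the tensor shapes and values are assumptions chosen for illustration, and it presumes the module defining these functions has been imported.

import TensorFlow

// Assumed example data: two samples with sigmoid-style probabilities in (0, 1)
// and matching binary labels. The values are made up for illustration.
let probabilities = Tensor<Float>(shape: [2, 1], scalars: [0.9, 0.2])
let targets = Tensor<Float>(shape: [2, 1], scalars: [1, 0])

// Binary cross entropy as added in this commit. The implementation takes
// log(logits) directly, so inputs must already lie strictly between 0 and 1.
let bce = sigmoidCrossEntropy(logits: probabilities, labels: targets)

// Mean squared error on the same tensors, for comparison.
let mse = meanSquaredError(predicted: probabilities, expected: targets)

print(bce, mse)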
