Commit c643290

vishwakftw authored and facebook-github-bot committed
Add derivative for cholesky_inverse (pytorch#26451)
Summary:
Changelog:
- Add derivative of cholesky_inverse. The equations are derived akin to the derivatives of the solve methods, using the technique detailed [here](https://www.google.com/url?sa=t&rct=j&q=&esrc=s&source=web&cd=1&cad=rja&uact=8&ved=2ahUKEwiXrOjIyM7kAhWstlkKHRxqCDgQFjAAegQIAhAC&url=https%3A%2F%2Fpeople.maths.ox.ac.uk%2Fgilesm%2Ffiles%2FNA-08-01.pdf&usg=AOvVaw0BNISOvM_I9KjPrl0xv1R_).

Pull Request resolved: pytorch#26451

Test Plan:
- Added tests for cholesky_inverse in test_autograd.py

Closes pytorch#4669.

Differential Revision: D17548526

Pulled By: ezyang

fbshipit-source-id: 51aa8b900a8dc4012b01a73d432606f216f62c9d
1 parent 7bdc0c1 commit c643290
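
For reference, a minimal sketch of the derivation for the lower-triangular case (not spelled out in the commit message), assuming `cholesky_inverse` computes \(P = (L L^{\top})^{-1}\) from a lower-triangular factor \(L\); the upper-triangular case, \(P = (U^{\top} U)^{-1}\), is analogous. With incoming gradient \(\bar{G} = \partial \ell / \partial P\):

\[
dP = -P\,\bigl(dL\,L^{\top} + L\,dL^{\top}\bigr)\,P,
\qquad
d\ell = \operatorname{tr}\bigl(\bar{G}^{\top}\,dP\bigr)
      = \operatorname{tr}\Bigl(\bigl(-P(\bar{G} + \bar{G}^{\top})P\,L\bigr)^{\top} dL\Bigr),
\]

so \(\bar{L} = -P(\bar{G} + \bar{G}^{\top})P\,L\) for the lower case and \(\bar{U} = -U\,P(\bar{G} + \bar{G}^{\top})P\) for the upper case, which is the `common_term` expression implemented in `cholesky_inverse_backward` below.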

File tree

3 files changed: +40, -1 lines changed

test/test_autograd.py

Lines changed: 23 additions & 0 deletions
@@ -2442,6 +2442,29 @@ def run_test(upper, dims):
         for upper, dims in product([True, False], [(3, 3), (5, 3, 3), (4, 3, 2, 2)]):
             run_test(upper, dims)
 
+    @skipIfNoLapack
+    def test_cholesky_inverse(self):
+        def _test_with_size(upper, dims):
+            # We need to create a Cholesky factor, which requires the diagonal elements to be positive.
+            # Initializing the diagonal with values that are too small could cause issues when it is
+            # perturbed to obtain the numerical Jacobian, leading to an inconsistent gradcheck.
+            A = torch.randn(*dims)
+            A.diagonal().uniform_(0.1, 5.0)
+            A.requires_grad_()
+
+            def func(A, upper):
+                if upper:
+                    root = A.triu()
+                else:
+                    root = A.tril()
+                return torch.cholesky_inverse(root, upper)
+
+            gradcheck(func, [A, upper])
+            gradgradcheck(func, [A, upper])
+
+        for upper, dims in product([True, False], [(3, 3), (5, 5)]):
+            _test_with_size(upper, dims)
+
     @skipIfNoLapack
     def test_triangular_solve(self):
         def _test_with_size(A_dims, B_dims):
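
The same check can be reproduced outside the test suite; a minimal standalone sketch (not part of the commit), using double precision since gradcheck's default tolerances are tuned for it:

```python
import torch
from torch.autograd import gradcheck, gradgradcheck

A = torch.randn(3, 3, dtype=torch.double)
A.diagonal().uniform_(0.1, 5.0)   # keep the diagonal well away from zero
A.requires_grad_()

def func(A, upper):
    root = A.triu() if upper else A.tril()
    return torch.cholesky_inverse(root, upper)

for upper in (False, True):
    assert gradcheck(lambda A: func(A, upper), (A,))
    assert gradgradcheck(lambda A: func(A, upper), (A,))
```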

tools/autograd/derivatives.yaml

Lines changed: 1 addition & 1 deletion
@@ -208,7 +208,7 @@
   self, input2: cholesky_solve_backward(grad, self, input2, result, upper)
 
 - name: cholesky_inverse(Tensor self, bool upper=False) -> Tensor
-  self: not_implemented("cholesky_inverse")
+  self: cholesky_inverse_backward(grad, self, upper, result)
 
 - name: fbgemm_linear_int8_weight_fp32_activation(Tensor input, Tensor weight, Tensor packed, Tensor col_offsets, Scalar weight_scale, Scalar weight_zero_point, Tensor bias) -> Tensor
   self: not_implemented("fbgemm_linear_int8_weight_fp32_activation only supported for inference")

tools/autograd/templates/Functions.cpp

Lines changed: 16 additions & 0 deletions
@@ -751,6 +751,22 @@ Tensor cholesky_backward(Tensor grad, bool upper, Tensor L) {
   return grad_input.add(grad_input.transpose(-1, -2)).mul_(0.5); // Symmetrizing the gradient
 }
 
+Tensor cholesky_inverse_backward(Tensor grad, Tensor L, bool upper, Tensor inverse) {
+  Tensor grad_L;
+  if (grad.defined()) {
+    Tensor common_term = grad + grad.transpose(-2, -1);
+    common_term = at::matmul(inverse, at::matmul(common_term, inverse));
+    if (upper) {
+      grad_L = -at::matmul(L, common_term);
+    } else {
+      grad_L = -at::matmul(common_term, L);
+    }
+  } else {
+    grad_L = at::zeros({1}, L.options()).expand_as(L);
+  }
+  return grad_L;
+}
+
 Tensor split_with_sizes_backward(const std::vector<torch::autograd::Variable> &grads,
     IntArrayRef split_sizes, int64_t dim, IntArrayRef sizes, const at::TensorOptions &options) {
   dim = at::maybe_wrap_dim(dim, sizes.size());
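
As a sanity check (not part of the commit), the closed-form gradient above can be compared against what autograd produces end to end; the `tril` in the graph mirrors how the test builds the factor:

```python
import torch

torch.manual_seed(0)
A = torch.randn(4, 4, dtype=torch.double)
A.diagonal().uniform_(0.1, 5.0)
A.requires_grad_()

L = A.tril()                               # lower-triangular factor, part of the graph
P = torch.cholesky_inverse(L)              # P = (L L^T)^{-1}
G = torch.randn_like(P)                    # incoming gradient \bar{G}
auto_grad, = torch.autograd.grad(P, A, grad_outputs=G)

with torch.no_grad():
    manual = -(P @ (G + G.T) @ P) @ L      # \bar{L} = -P (\bar{G} + \bar{G}^T) P L
    manual = manual.tril()                 # project through the tril used in the graph
print(torch.allclose(auto_grad, manual))   # expected: True
```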
