Skip to content

Commit 7b53c2f

Browse files
hawkinspGoogle-ML-Automation
authored andcommitted
Add jax.errors.JaxRuntimeError as a public alias for the XlaRuntimeError class.
Deprecate jax.lib.xla_client.XlaRuntimeError, which is not a public API. PiperOrigin-RevId: 679163106
1 parent 5cef547 commit 7b53c2f

File tree

6 files changed

+43
-15
lines changed

6 files changed

+43
-15
lines changed

CHANGELOG.md

+4
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
1515
* New Functionality
1616
* This release includes wheels for Python 3.13. Free-threading mode is not yet
1717
supported.
18+
* `jax.errors.JaxRuntimeError` has been added as a public alias for the
19+
formerly private `XlaRuntimeError` type.
1820

1921
* Breaking changes
2022
* `jax_pmap_no_rank_reduction` flag is set to `True` by default.
@@ -32,6 +34,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
3234
in an error.
3335
* Internal pretty-printing tools `jax.core.pp_*` have been removed, after
3436
being deprecated in JAX v0.4.30.
37+
* `jax.lib.xla_client.XlaRuntimeError` has been deprecated. Use
38+
`jax.errors.JaxRuntimeError` instead.
3539

3640
* Deletion:
3741
* `jax.xla_computation` is deleted. It's been 3 months since it's deprecation

docs/errors.rst

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ along with representative examples of how one might fix them.
99
.. currentmodule:: jax.errors
1010
.. autoclass:: ConcretizationTypeError
1111
.. autoclass:: KeyReuseError
12+
.. autoclass:: JaxRuntimeError
1213
.. autoclass:: NonConcreteBooleanIndexError
1314
.. autoclass:: TracerArrayConversionError
1415
.. autoclass:: TracerBoolConversionError

jax/errors.py

+5
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,9 @@
2626
UnexpectedTracerError as UnexpectedTracerError,
2727
KeyReuseError as KeyReuseError,
2828
)
29+
30+
from jax._src.lib import xla_client as _xc
31+
JaxRuntimeError = _xc.XlaRuntimeError
32+
del _xc
33+
2934
from jax._src.traceback_util import SimplifiedTraceback as SimplifiedTraceback

jax/lib/xla_client.py

+20-10
Original file line numberDiff line numberDiff line change
@@ -37,26 +37,36 @@
3737
Traceback = _xc.Traceback
3838
XlaBuilder = _xc.XlaBuilder
3939
XlaComputation = _xc.XlaComputation
40-
XlaRuntimeError = _xc.XlaRuntimeError
4140

4241
_deprecations = {
43-
# Added Aug 5 2024
44-
"_xla" : (
45-
"jax.lib.xla_client._xla is deprecated; use jax.lib.xla_extension.",
46-
_xc._xla
47-
),
48-
"bfloat16" : (
49-
"jax.lib.xla_client.bfloat16 is deprecated; use ml_dtypes.bfloat16.",
50-
_xc.bfloat16
51-
),
42+
# Added Aug 5 2024
43+
"_xla": (
44+
"jax.lib.xla_client._xla is deprecated; use jax.lib.xla_extension.",
45+
_xc._xla,
46+
),
47+
"bfloat16": (
48+
"jax.lib.xla_client.bfloat16 is deprecated; use ml_dtypes.bfloat16.",
49+
_xc.bfloat16,
50+
),
51+
# Added Sep 26 2024
52+
"XlaRuntimeError": (
53+
(
54+
"jax.lib.xla_client.XlaRuntimeError is deprecated; use"
55+
" jax.errors.JaxRuntimeError."
56+
),
57+
_xc.XlaRuntimeError,
58+
),
5259
}
5360

5461
import typing as _typing
62+
5563
if _typing.TYPE_CHECKING:
5664
_xla = _xc._xla
5765
bfloat16 = _xc.bfloat16
66+
XlaRuntimeError = _xc.XlaRuntimeError
5867
else:
5968
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
69+
6070
__getattr__ = _deprecation_getattr(__name__, _deprecations)
6171
del _deprecation_getattr
6272
del _typing

tests/errors_test.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -394,11 +394,19 @@ def test_grad_norm(self):
394394

395395

396396
class CustomErrorsTest(jtu.JaxTestCase):
397+
397398
@jtu.sample_product(
398-
errorclass=[
399-
errorclass for errorclass in dir(jax.errors)
400-
if errorclass.endswith('Error') and errorclass not in ['JaxIndexError', 'JAXTypeError']
401-
],
399+
errorclass=[
400+
errorclass
401+
for errorclass in dir(jax.errors)
402+
if errorclass.endswith('Error')
403+
and errorclass
404+
not in [
405+
'JaxIndexError',
406+
'JAXTypeError',
407+
'JaxRuntimeError',
408+
]
409+
],
402410
)
403411
def testErrorsURL(self, errorclass):
404412
class FakeTracer(core.Tracer):

tests/package_structure_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class PackageStructureTest(jtu.JaxTestCase):
3131

3232
@parameterized.parameters([
3333
# TODO(jakevdp): expand test to other public modules.
34-
_mod("jax.errors"),
34+
_mod("jax.errors", exclude=["JaxRuntimeError"]),
3535
_mod("jax.nn.initializers"),
3636
_mod(
3737
"jax.tree_util",

0 commit comments

Comments
 (0)