Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions tests/lax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import numpy as np

import jax
from jax._src import core
from jax import export
from jax import jvp, grad
from jax import lax
Expand All @@ -38,6 +37,8 @@
from jax.interpreters import batching
from jax._src import array
from jax._src import config
from jax._src import core
from jax._src import deprecations
from jax._src import dtypes
from jax._src import lax_reference
from jax._src import test_util as jtu
Expand Down Expand Up @@ -1106,17 +1107,24 @@ def testDotPositionalArgumentDeprecation(self):
lhs = jnp.arange(5.0)
rhs = jnp.arange(5.0)
msg = "jax.lax.dot: passing precision or preferred_element_type by position"
multiple_args_msg = "jax.lax.dot got multiple values for argument"

with self.assertWarnsRegex(DeprecationWarning, msg):
with self.assertDeprecationWarnsOrRaises("jax-lax-dot-positional-args", msg):
lax.dot(lhs, rhs, lax.Precision.DEFAULT, jnp.float32)

with self.assertWarnsRegex(DeprecationWarning, msg):
with self.assertDeprecationWarnsOrRaises("jax-lax-dot-positional-args", msg):
with self.assertRaises(TypeError):
lax.dot(lhs, rhs, lax.Precision.DEFAULT, precision=lax.Precision.DEFAULT)

with self.assertWarnsRegex(DeprecationWarning, msg):
with self.assertRaises(TypeError):
lax.dot(lhs, rhs, lax.Precision.DEFAULT, jnp.float32, preferred_element_type=jnp.float32)
if deprecations.is_accelerated("jax-lax-dot-positional-args"):
with self.assertRaisesRegex(ValueError, msg):
lax.dot(lhs, rhs, lax.Precision.DEFAULT, jnp.float32,
preferred_element_type=jnp.float32)
else:
with self.assertWarnsRegex(DeprecationWarning, msg):
with self.assertRaisesRegex(TypeError, multiple_args_msg):
lax.dot(lhs, rhs, lax.Precision.DEFAULT, jnp.float32,
preferred_element_type=jnp.float32)

@parameterized.parameters([
(algorithm, dtype)
Expand Down
Loading