@@ -3673,6 +3673,22 @@ def func(x):
3673
3673
return tf .identity (picks , name = _TFOUTPUT )
3674
3674
self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val })
3675
3675
3676
+ @check_opset_min_version (10 , "IsInf" )
3677
+ def test_where_with_isinf_condition (self ):
3678
+ def func (x , y , z ):
3679
+ # Use is_inf as condition to trigger the IsInf code path
3680
+ condition = tf .math .is_inf (x )
3681
+ result = tf .where (condition , y , z )
3682
+ return tf .identity (result , name = _TFOUTPUT )
3683
+
3684
+ # Create test data with some infinite values
3685
+ x_val = np .array ([1.0 , np .inf , 3.0 , - np .inf , 5.0 ], dtype = np .float32 )
3686
+ y_val = np .array ([0.0 , 0.0 , 0.0 , 0.0 , 0.0 ], dtype = np .float32 )
3687
+ z_val = np .array ([100.0 , 200.0 , 300.0 , 400.0 , 500.0 ], dtype = np .float32 )
3688
+
3689
+ self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val , _INPUT1 : y_val , _INPUT2 : z_val })
3690
+
3691
+
3676
3692
@check_opset_min_version (9 , "IsNaN" )
3677
3693
def test_where_isnan (self ):
3678
3694
x_val = np .array ([1 , 2 , - 3 , float ('nan' ), - 5 , - 6 , float ('nan' ), 8 , 9 , 0 ], dtype = np .float32 )
0 commit comments