@@ -34,22 +34,22 @@ void quantize_per_tensor_out(
 
   if (out.scalar_type() == ScalarType::Byte) {
     uint8_t* out_data = out.mutable_data_ptr<uint8_t>();
-    quantize<uint8_t>(
-        out_data, input_data, 1. / scale, zero_point, numel);
+    quantize<uint8_t>(out_data, input_data, 1. / scale, zero_point, numel);
   } else if (out.scalar_type() == ScalarType::Char) {
     int8_t* out_data = out.mutable_data_ptr<int8_t>();
-    quantize<int8_t>(
-        out_data, input_data, 1. / scale, zero_point, numel);
+    quantize<int8_t>(out_data, input_data, 1. / scale, zero_point, numel);
   } else if (
       out.scalar_type() == ScalarType::Bits16 ||
       out.scalar_type() == ScalarType::UInt16) {
     uint16_t* out_data = out.mutable_data_ptr<uint16_t>();
-    quantize<uint16_t>(
-        out_data, input_data, 1. / scale, zero_point, numel);
+    quantize<uint16_t>(out_data, input_data, 1. / scale, zero_point, numel);
   } else if (out.scalar_type() == ScalarType::Short) {
     int16_t* out_data = out.mutable_data_ptr<int16_t>();
     quantize<int16_t>(
         out_data, input_data, 1. / scale, zero_point, numel);
+  } else if (out.scalar_type() == ScalarType::Int) {
+    int32_t* out_data = out.mutable_data_ptr<int32_t>();
+    quantize<int32_t>(out_data, input_data, 1. / scale, zero_point, numel);
   } else {
     ET_CHECK_MSG(
         false,
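The templated `quantize` helper these branches call is defined elsewhere (the later hunks drop its `impl::generic::kernels::` qualification at the call sites). For reference, here is a minimal sketch of what such an affine per-tensor quantizer could look like; the rounding mode, saturation, and parameter types below are assumptions, not the kernel's actual code:

// Hypothetical sketch only: an affine per-tensor quantizer with the same
// signature as the call sites above. Assumes round-to-nearest and
// saturation to T's range; the real kernel may differ.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <limits>

template <typename T>
void quantize(
    T* out_data,
    const float* input_data,
    float inv_scale, // callers pass 1. / scale to trade a divide for a multiply
    int64_t zero_point,
    size_t numel) {
  constexpr int64_t kMin = std::numeric_limits<T>::min();
  constexpr int64_t kMax = std::numeric_limits<T>::max();
  for (size_t i = 0; i < numel; ++i) {
    // q = round(x / scale) + zero_point, saturated to T's representable range.
    const int64_t q =
        static_cast<int64_t>(std::nearbyint(input_data[i] * inv_scale)) +
        zero_point;
    out_data[i] = static_cast<T>(std::clamp(q, kMin, kMax));
  }
}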
@@ -70,8 +70,7 @@ void quantize_per_tensor_asym8s_out(
   const float* input_data = input.const_data_ptr<float>();
   size_t numel = out.numel();
   int8_t* out_data = out.mutable_data_ptr<int8_t>();
-  impl::generic::kernels::quantize<int8_t>(
-      out_data, input_data, 1. / scale, zero_point, numel);
+  quantize<int8_t>(out_data, input_data, 1. / scale, zero_point, numel);
 }
 
 void quantize_per_tensor_asym8u_out(
@@ -86,8 +85,7 @@ void quantize_per_tensor_asym8u_out(
   const float* input_data = input.const_data_ptr<float>();
   size_t numel = out.numel();
   uint8_t* out_data = out.mutable_data_ptr<uint8_t>();
-  impl::generic::kernels::quantize<uint8_t>(
-      out_data, input_data, 1. / scale, zero_point, numel);
+  quantize<uint8_t>(out_data, input_data, 1. / scale, zero_point, numel);
 }
 
 void quantize_per_tensor_asym16s_out(
@@ -102,8 +100,7 @@ void quantize_per_tensor_asym16s_out(
   const float* input_data = input.const_data_ptr<float>();
   size_t numel = out.numel();
   int16_t* out_data = out.mutable_data_ptr<int16_t>();
-  impl::generic::kernels::quantize<int16_t>(
-      out_data, input_data, 1. / scale, zero_point, numel);
+  quantize<int16_t>(out_data, input_data, 1. / scale, zero_point, numel);
 }
 
 void quantize_per_tensor_asym16u_out(
@@ -118,8 +115,22 @@ void quantize_per_tensor_asym16u_out(
   const float* input_data = input.const_data_ptr<float>();
   size_t numel = out.numel();
   uint16_t* out_data = out.mutable_data_ptr<uint16_t>();
-  impl::generic::kernels::quantize<uint16_t>(
-      out_data, input_data, 1. / scale, zero_point, numel);
+  quantize<uint16_t>(out_data, input_data, 1. / scale, zero_point, numel);
+}
+
+void quantize_per_tensor_asym32s_out(
+    KernelRuntimeContext& context,
+    const Tensor& input,
+    double scale,
+    int64_t zero_point,
+    int64_t quant_min,
+    int64_t quant_max,
+    ScalarType dtype,
+    Tensor& out) {
+  const float* input_data = input.const_data_ptr<float>();
+  size_t numel = out.numel();
+  int32_t* out_data = out.mutable_data_ptr<int32_t>();
+  quantize<int32_t>(out_data, input_data, 1. / scale, zero_point, numel);
 }
 
 } // namespace native
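As a quick numeric check of the arithmetic exercised by the new `ScalarType::Int` branch and `quantize_per_tensor_asym32s_out`: with the made-up values scale = 0.5 and zero_point = 10, an input of 3.0f maps to round(3.0 * (1 / 0.5)) + 10 = 16. A standalone demo of that mapping (values and program are illustrative only, not part of the commit):

// Illustrative only: reproduces the float -> int32 affine mapping used by
// the new int32 path, with made-up scale/zero_point values.
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  const double scale = 0.5;      // hypothetical
  const int64_t zero_point = 10; // hypothetical
  const float input[] = {3.0f, -1.2f, 0.0f};
  for (float x : input) {
    // Same arithmetic as the kernel call: round(x * (1. / scale)) + zero_point.
    const int32_t q =
        static_cast<int32_t>(std::llround(x * (1. / scale)) + zero_point);
    std::printf("%6.2f -> %ld\n", x, static_cast<long>(q));
  }
  // Prints: 3.00 -> 16, -1.20 -> 8, 0.00 -> 10
}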