
Commit a9bde68

mcremon-meta authored and facebook-github-bot committed
Add int32 quant/dequant back (#14269)
Summary: Pull Request resolved: #14269. Now that the previous diff exists, we can add the int32 case back without adding to the code size of deployed models.

Reviewed By: hsharma35

Differential Revision: D82282481
1 parent e92a2fc commit a9bde68
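
For reference, the asym32s ("asymmetric, signed 32-bit") variants added here implement the standard affine per-tensor scheme: q = clamp(round(x / scale) + zero_point, quant_min, quant_max), with dequantization recovering x ≈ (q - zero_point) * scale. A minimal sketch in plain PyTorch (illustrative only; the function names and torch-level code are not from the commit, though the sketch mirrors the kernels' use of the precomputed inverse scale, 1. / scale):

import torch

def quantize_per_tensor_asym32s_ref(x, scale, zero_point, quant_min, quant_max):
    # The kernels pass 1. / scale to the inner quantize routine once per tensor.
    inv_scale = 1.0 / scale
    q = torch.round(x * inv_scale) + zero_point
    return torch.clamp(q, quant_min, quant_max).to(torch.int32)

def dequantize_per_tensor_asym32s_ref(q, scale, zero_point):
    return (q.to(torch.float32) - zero_point) * scale

x = torch.tensor([-1.5, 0.0, 2.75])
q = quantize_per_tensor_asym32s_ref(x, 0.01, 3, -(2**31), 2**31 - 1)
assert torch.allclose(dequantize_per_tensor_asym32s_ref(q, 0.01, 3), x, atol=0.01)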

File tree

10 files changed: +148 −14

backends/cadence/aot/functions.yaml

Lines changed: 12 additions & 0 deletions
@@ -208,6 +208,12 @@
     - arg_meta: null
       kernel_name: impl::generic::quantize_per_tensor_asym16u_out
 
+- func: cadence::quantize_per_tensor_asym32s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::generic::quantize_per_tensor_asym32s_out
+
 - func: cadence::dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
   variants: function
   kernels:
@@ -238,6 +244,12 @@
     - arg_meta: null
       kernel_name: impl::generic::dequantize_per_tensor_asym16u_out
 
+- func: cadence::dequantize_per_tensor_asym32s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::generic::dequantize_per_tensor_asym32s_out
+
 - func: cadence::quantized_conv_nchw.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!)
   kernels:
     - arg_meta: null
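
The .out schemas registered above follow the ExecuTorch convention: the kernel writes into a caller-provided tensor (Tensor(a!) out) instead of allocating, which is what keeps the runtime allocation-free. A hedged illustration of the calling convention (this assumes the cadence op library from this commit is loaded and a matching kernel is registered for your build; a sketch, not a tested snippet):

import torch

x = torch.randn(4)
out = torch.empty(4, dtype=torch.int32)
# quant_min/quant_max span the full int32 range for the asym32s variant.
torch.ops.cadence.quantize_per_tensor_asym32s(
    x, 0.05, 0, -(2**31), 2**31 - 1, torch.int32, out=out
)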

backends/cadence/aot/functions_hifi.yaml

Lines changed: 11 additions & 0 deletions
@@ -308,6 +308,11 @@
     - arg_meta: null
       kernel_name: impl::HiFi::quantize_per_tensor_asym16s_out
 
+- func: cadence::quantize_per_tensor_asym32s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::HiFi::quantize_per_tensor_asym32s_out
 
 - func: cadence::dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
   variants: function
@@ -339,6 +344,12 @@
     - arg_meta: null
       kernel_name: impl::HiFi::dequantize_per_tensor_asym16u_out
 
+- func: cadence::dequantize_per_tensor_asym32s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::HiFi::dequantize_per_tensor_asym32s_out
+
 - func: cadence::quantized_conv_nchw.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!)
   kernels:
     - arg_meta: null

backends/cadence/aot/ops_registrations.py

Lines changed: 38 additions & 0 deletions
@@ -56,6 +56,13 @@
     "quantize_per_tensor_asym16u.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)"
 )
 
+lib.define(
+    "quantize_per_tensor_asym32s(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)"
+)
+lib.define(
+    "quantize_per_tensor_asym32s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)"
+)
+
 lib.define(
     "dequantize_per_tensor(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)"
 )
@@ -87,6 +94,13 @@
     "dequantize_per_tensor_asym16u.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)"
 )
 
+lib.define(
+    "dequantize_per_tensor_asym32s(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)"
+)
+lib.define(
+    "dequantize_per_tensor_asym32s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)"
+)
+
 lib.define(
     "quantized_layer_norm(Tensor X, Tensor X_scale, Tensor X_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point) -> (Tensor Y)"
 )
@@ -617,6 +631,18 @@ def quantize_per_tensor_asym16u_meta(
     return input.new_empty(input.size(), dtype=dtype)
 
 
+@register_fake("cadence::quantize_per_tensor_asym32s")
+def quantize_per_tensor_asym32s_meta(
+    input: torch.Tensor,
+    scale: float,
+    zero_point: int,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    return input.new_empty(input.size(), dtype=dtype)
+
+
 @register_fake("cadence::dequantize_per_tensor")
 def dequantize_per_tensor_meta(
     input: torch.Tensor,
@@ -677,6 +703,18 @@ def dequantize_per_tensor_asym16u_meta(
     return input.new_empty(input.size(), dtype=torch.float)
 
 
+@register_fake("cadence::dequantize_per_tensor_asym32s")
+def dequantize_per_tensor_asym32s_meta(
+    input: torch.Tensor,
+    scale: float,
+    zero_point: int,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    return input.new_empty(input.size(), dtype=torch.float)
+
+
 @register_fake("cadence::quantized_add")
 def quantized_add_meta(
     X: torch.Tensor,
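
Each new op is given a fake ("meta") kernel so torch.export can trace shapes and dtypes without executing a real kernel; the fake implementation only allocates an empty tensor with the propagated size and dtype. A self-contained sketch of the same define-plus-register_fake pattern, using a hypothetical demo namespace so it runs without the cadence library:

import torch

lib = torch.library.Library("demo", "DEF")  # hypothetical namespace, not cadence
lib.define(
    "quantize_per_tensor_asym32s(Tensor input, float scale, int zero_point, "
    "int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)"
)

@torch.library.register_fake("demo::quantize_per_tensor_asym32s")
def _(input, scale, zero_point, quant_min, quant_max, dtype):
    # Same shape/dtype propagation as the meta kernels above.
    return input.new_empty(input.size(), dtype=dtype)

with torch._subclasses.FakeTensorMode():
    x = torch.empty(2, 3)
    q = torch.ops.demo.quantize_per_tensor_asym32s(
        x, 0.05, 0, -(2**31), 2**31 - 1, torch.int32
    )
    print(q.shape, q.dtype)  # torch.Size([2, 3]) torch.int32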

backends/cadence/aot/type_dispatch.py

Lines changed: 2 additions & 0 deletions
@@ -108,6 +108,7 @@ class CompileTimeTypeDispatchPass(ExportPass):
             (torch.uint8,): "asym8u",
             (torch.int16,): "asym16s",
             (torch.uint16,): "asym16s",
+            (torch.int32,): "asym32s",
         },
         variant="default",
         is_quant_op=True,
@@ -119,6 +120,7 @@ class CompileTimeTypeDispatchPass(ExportPass):
             (torch.uint8,): "asym8u",
             (torch.int16,): "asym16s",
             (torch.uint16,): "asym16s",
+            (torch.int32,): "asym32s",
         },
         variant="default",
     ),
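
This pass rewrites each untyped quantize/dequantize call to the typed variant selected by its input dtypes, so a deployed model links only the kernels it actually uses; that selective dispatch is why the int32 case can return "without adding to the code size of deployed models". A simplified sketch of the lookup (structure assumed from the visible table, not the pass's actual code):

import torch

DTYPE_TO_SUFFIX = {
    (torch.int8,): "asym8s",   # assumed, by analogy with the other entries
    (torch.uint8,): "asym8u",
    (torch.int16,): "asym16s",
    (torch.uint16,): "asym16s",
    (torch.int32,): "asym32s",  # the entry this commit adds
}

def typed_variant(base_op: str, input_dtypes: tuple) -> str:
    # e.g. ("quantize_per_tensor", (torch.int32,)) -> "quantize_per_tensor_asym32s"
    return f"{base_op}_{DTYPE_TO_SUFFIX[input_dtypes]}"

print(typed_variant("quantize_per_tensor", (torch.int32,)))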

backends/cadence/generic/kernels/kernels.cpp

Lines changed: 4 additions & 0 deletions
@@ -73,6 +73,7 @@ typed_quantize_val(int8_t);
 typed_quantize_val(uint8_t);
 typed_quantize_val(int16_t);
 typed_quantize_val(uint16_t);
+typed_quantize_val(int32_t);
 #undef typed_quantize_val
 
 #define typed_quantize_vec(dtype) \
@@ -86,6 +87,7 @@ typed_quantize_vec(int8_t);
 typed_quantize_vec(uint8_t);
 typed_quantize_vec(int16_t);
 typed_quantize_vec(uint16_t);
+typed_quantize_vec(int32_t);
 #undef typed_quantize_vec
 
 #define typed_dequantize_val(dtype) \
@@ -94,6 +96,7 @@ typed_dequantize_val(int8_t);
 typed_dequantize_val(uint8_t);
 typed_dequantize_val(int16_t);
 typed_dequantize_val(uint16_t);
+typed_dequantize_val(int32_t);
 #undef typed_dequantize_val
 
 #define typed_dequantize_vec(dtype) \
@@ -107,6 +110,7 @@ typed_dequantize_vec(int8_t);
 typed_dequantize_vec(uint8_t);
 typed_dequantize_vec(int16_t);
 typed_dequantize_vec(uint16_t);
+typed_dequantize_vec(int32_t);
 #undef typed_dequantize_vec
 
 } // namespace kernels

backends/cadence/generic/operators/dequantize_per_tensor.cpp

Lines changed: 18 additions & 0 deletions
@@ -44,6 +44,9 @@ void dequantize_per_tensor_out(
   } else if (input.scalar_type() == ScalarType::Short) {
     const int16_t* input_data = input.const_data_ptr<int16_t>();
     dequantize<int16_t>(out_data, input_data, scale, zero_point, numel);
+  } else if (input.scalar_type() == ScalarType::Int) {
+    const int32_t* input_data = input.const_data_ptr<int32_t>();
+    dequantize<int32_t>(out_data, input_data, scale, zero_point, numel);
   } else {
     ET_CHECK_MSG(
         false,
@@ -112,6 +115,21 @@ void dequantize_per_tensor_asym16u_out(
   dequantize<uint16_t>(out_data, input_data, scale, zero_point, numel);
 }
 
+void dequantize_per_tensor_asym32s_out(
+    KernelRuntimeContext& context,
+    const Tensor& input,
+    double scale,
+    int64_t zero_point,
+    int64_t quant_min,
+    int64_t quant_max,
+    ScalarType dtype,
+    Tensor& out) {
+  float* out_data = out.mutable_data_ptr<float>();
+  size_t numel = out.numel();
+  const int32_t* input_data = input.const_data_ptr<int32_t>();
+  dequantize<int32_t>(out_data, input_data, scale, zero_point, numel);
+}
+
 } // namespace native
 } // namespace generic
 } // namespace impl

backends/cadence/generic/operators/quantize_per_tensor.cpp

Lines changed: 25 additions & 14 deletions
@@ -34,22 +34,22 @@ void quantize_per_tensor_out(
 
   if (out.scalar_type() == ScalarType::Byte) {
     uint8_t* out_data = out.mutable_data_ptr<uint8_t>();
-    quantize<uint8_t>(
-        out_data, input_data, 1. / scale, zero_point, numel);
+    quantize<uint8_t>(out_data, input_data, 1. / scale, zero_point, numel);
   } else if (out.scalar_type() == ScalarType::Char) {
     int8_t* out_data = out.mutable_data_ptr<int8_t>();
-    quantize<int8_t>(
-        out_data, input_data, 1. / scale, zero_point, numel);
+    quantize<int8_t>(out_data, input_data, 1. / scale, zero_point, numel);
   } else if (
       out.scalar_type() == ScalarType::Bits16 ||
       out.scalar_type() == ScalarType::UInt16) {
     uint16_t* out_data = out.mutable_data_ptr<uint16_t>();
-    quantize<uint16_t>(
-        out_data, input_data, 1. / scale, zero_point, numel);
+    quantize<uint16_t>(out_data, input_data, 1. / scale, zero_point, numel);
   } else if (out.scalar_type() == ScalarType::Short) {
     int16_t* out_data = out.mutable_data_ptr<int16_t>();
     quantize<int16_t>(
         out_data, input_data, 1. / scale, zero_point, numel);
+  } else if (out.scalar_type() == ScalarType::Int) {
+    int32_t* out_data = out.mutable_data_ptr<int32_t>();
+    quantize<int32_t>(out_data, input_data, 1. / scale, zero_point, numel);
   } else {
     ET_CHECK_MSG(
         false,
@@ -70,8 +70,7 @@ void quantize_per_tensor_asym8s_out(
   const float* input_data = input.const_data_ptr<float>();
   size_t numel = out.numel();
   int8_t* out_data = out.mutable_data_ptr<int8_t>();
-  impl::generic::kernels::quantize<int8_t>(
-      out_data, input_data, 1. / scale, zero_point, numel);
+  quantize<int8_t>(out_data, input_data, 1. / scale, zero_point, numel);
 }
 
 void quantize_per_tensor_asym8u_out(
@@ -86,8 +85,7 @@ void quantize_per_tensor_asym8u_out(
   const float* input_data = input.const_data_ptr<float>();
   size_t numel = out.numel();
   uint8_t* out_data = out.mutable_data_ptr<uint8_t>();
-  impl::generic::kernels::quantize<uint8_t>(
-      out_data, input_data, 1. / scale, zero_point, numel);
+  quantize<uint8_t>(out_data, input_data, 1. / scale, zero_point, numel);
 }
 
 void quantize_per_tensor_asym16s_out(
@@ -102,8 +100,7 @@ void quantize_per_tensor_asym16s_out(
   const float* input_data = input.const_data_ptr<float>();
   size_t numel = out.numel();
   int16_t* out_data = out.mutable_data_ptr<int16_t>();
-  impl::generic::kernels::quantize<int16_t>(
-      out_data, input_data, 1. / scale, zero_point, numel);
+  quantize<int16_t>(out_data, input_data, 1. / scale, zero_point, numel);
 }
 
 void quantize_per_tensor_asym16u_out(
@@ -118,8 +115,22 @@ void quantize_per_tensor_asym16u_out(
   const float* input_data = input.const_data_ptr<float>();
   size_t numel = out.numel();
   uint16_t* out_data = out.mutable_data_ptr<uint16_t>();
-  impl::generic::kernels::quantize<uint16_t>(
-      out_data, input_data, 1. / scale, zero_point, numel);
+  quantize<uint16_t>(out_data, input_data, 1. / scale, zero_point, numel);
+}
+
+void quantize_per_tensor_asym32s_out(
+    KernelRuntimeContext& context,
+    const Tensor& input,
+    double scale,
+    int64_t zero_point,
+    int64_t quant_min,
+    int64_t quant_max,
+    ScalarType dtype,
+    Tensor& out) {
+  const float* input_data = input.const_data_ptr<float>();
+  size_t numel = out.numel();
+  int32_t* out_data = out.mutable_data_ptr<int32_t>();
+  quantize<int32_t>(out_data, input_data, 1. / scale, zero_point, numel);
 }
 
 } // namespace native
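
Because asym32s is ordinary affine quantization, expected kernel outputs can be cross-checked against stock PyTorch, which ships a qint32 quantized dtype (a sanity-check sketch, assuming the kernel rounds the same way torch does for these values):

import torch

x = torch.tensor([-1.5, 0.0, 2.75])
qt = torch.quantize_per_tensor(x, scale=0.01, zero_point=3, dtype=torch.qint32)
print(qt.int_repr())    # tensor([-147,    3,  278], dtype=torch.int32)
print(qt.dequantize())  # recovers x to within one quantization step (0.01)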

backends/cadence/hifi/kernels/kernels.cpp

Lines changed: 2 additions & 0 deletions
@@ -135,6 +135,7 @@ typed_quantize_val(int8_t);
 typed_quantize_val(uint8_t);
 typed_quantize_val(int16_t);
 typed_quantize_val(uint16_t);
+typed_quantize_val(int32_t);
 #undef typed_quantize_val
 
 #define typed_quantize_vec(dtype) \
@@ -158,6 +159,7 @@ typed_dequantize_val(int8_t);
 typed_dequantize_val(uint8_t);
 typed_dequantize_val(int16_t);
 typed_dequantize_val(uint16_t);
+typed_dequantize_val(int32_t);
 #undef typed_dequantize_val
 
 #define typed_dequantize_vec(dtype) \

backends/cadence/hifi/operators/op_dequantize_per_tensor.cpp

Lines changed: 18 additions & 0 deletions
@@ -45,6 +45,9 @@ void dequantize_per_tensor_out(
       input.scalar_type() == ScalarType::UInt16) {
     const uint16_t* input_data = input.const_data_ptr<uint16_t>();
     dequantize<uint16_t>(out_data, input_data, scale, zero_point, numel);
+  } else if (input.scalar_type() == ScalarType::Int) {
+    const int32_t* input_data = input.const_data_ptr<int32_t>();
+    dequantize<int32_t>(out_data, input_data, scale, zero_point, numel);
   } else {
     ET_CHECK_MSG(
         false,
@@ -98,6 +101,21 @@ void dequantize_per_tensor_asym16u_out(
   dequantize<uint16_t>(out_data, input_data, scale, zero_point, numel);
 }
 
+void dequantize_per_tensor_asym32s_out(
+    KernelRuntimeContext& context,
+    const Tensor& input,
+    double scale,
+    int64_t zero_point,
+    int64_t quant_min,
+    int64_t quant_max,
+    ScalarType dtype,
+    Tensor& out) {
+  float* out_data = out.mutable_data_ptr<float>();
+  size_t numel = out.numel();
+  const int32_t* input_data = input.const_data_ptr<int32_t>();
+  dequantize<int32_t>(out_data, input_data, scale, zero_point, numel);
+}
+
 } // namespace native
 } // namespace HiFi
 } // namespace impl

backends/cadence/hifi/operators/op_quantize_per_tensor.cpp

Lines changed: 18 additions & 0 deletions
@@ -105,6 +105,9 @@ void quantize_per_tensor_out(
       out.scalar_type() == ScalarType::UInt16) {
     uint16_t* out_data = out.mutable_data_ptr<uint16_t>();
     quantize<uint16_t>(out_data, input_data, 1. / scale, zero_point, numel);
+  } else if (out.scalar_type() == ScalarType::Int) {
+    int32_t* out_data = out.mutable_data_ptr<int32_t>();
+    quantize<int32_t>(out_data, input_data, 1. / scale, zero_point, numel);
   } else {
     ET_KERNEL_CHECK_MSG(
         ctx,
@@ -161,6 +164,21 @@ void quantize_per_tensor_asym16u_out(
   quantize<uint16_t>(out_data, input_data, 1. / scale, zero_point, numel);
 }
 
+void quantize_per_tensor_asym32s_out(
+    KernelRuntimeContext& context,
+    const Tensor& input,
+    double scale,
+    int64_t zero_point,
+    int64_t quant_min,
+    int64_t quant_max,
+    ScalarType dtype,
+    Tensor& out) {
+  const float* input_data = input.const_data_ptr<float>();
+  size_t numel = out.numel();
+  int32_t* out_data = out.mutable_data_ptr<int32_t>();
+  quantize<int32_t>(out_data, input_data, 1. / scale, zero_point, numel);
+}
+
 } // namespace native
 } // namespace HiFi
 } // namespace impl
