|
156 | 156 | ), |
157 | 157 | 0, |
158 | 158 | ), |
| 159 | + "bool_mask_scalar": ( |
| 160 | + lambda: ( |
| 161 | + torch.randn((2, 3, 4), dtype=torch.float32), |
| 162 | + ( |
| 163 | + torch.arange(3).expand(2, 3) |
| 164 | + >= torch.tensor([3, 2], dtype=torch.int64)[:, None], |
| 165 | + ), |
| 166 | + torch.tensor(0.0, dtype=torch.float32), |
| 167 | + False, |
| 168 | + ), |
| 169 | + 0, |
| 170 | + ), |
159 | 171 | "none_indices": ( |
160 | 172 | lambda: ( |
161 | 173 | torch.ones((5, 3, 2, 2), dtype=torch.float32), |
|
210 | 222 | ), |
211 | 223 | 0, |
212 | 224 | ), |
| 225 | + "none_and_bool_indices_scalar": ( |
| 226 | + lambda: ( |
| 227 | + torch.randn((2, 3, 4), dtype=torch.float32), |
| 228 | + (None, torch.tensor([True, False, True]), None), |
| 229 | + torch.tensor(0.0, dtype=torch.float32), |
| 230 | + False, |
| 231 | + ), |
| 232 | + 0, |
| 233 | + ), |
| 234 | +} |
| 235 | +mixed_indices_not_supported = { |
| 236 | + "bool_and_tensor_indices_scalar": ( |
| 237 | + lambda: ( |
| 238 | + torch.randn((2, 3, 4), dtype=torch.float32), |
| 239 | + ( |
| 240 | + torch.tensor([True, False]), |
| 241 | + torch.tensor([1, 2], dtype=torch.int64), |
| 242 | + ), |
| 243 | + torch.tensor(0.0, dtype=torch.float32), |
| 244 | + False, |
| 245 | + ), |
| 246 | + 0, |
| 247 | + ), |
| 248 | + "bool_mask_tensor": ( |
| 249 | + lambda: ( |
| 250 | + torch.randn((2, 3, 4), dtype=torch.float32), |
| 251 | + (torch.tensor([True, False]),), |
| 252 | + torch.rand((1, 3, 4), dtype=torch.float32), |
| 253 | + False, |
| 254 | + ), |
| 255 | + 0, |
| 256 | + ), |
| 257 | + "two_bool_mask_scalar": ( |
| 258 | + lambda: ( |
| 259 | + torch.randn((2, 3, 4), dtype=torch.float32), |
| 260 | + ( |
| 261 | + torch.tensor([False, True]), |
| 262 | + torch.tensor([True, False, False]), |
| 263 | + ), |
| 264 | + torch.tensor(0.0, dtype=torch.float32), |
| 265 | + False, |
| 266 | + ), |
| 267 | + 0, |
| 268 | + ), |
| 269 | + "two_bool_mask_tensor": ( |
| 270 | + lambda: ( |
| 271 | + torch.randn((2, 3, 4), dtype=torch.float32), |
| 272 | + ( |
| 273 | + torch.tensor([False, True]), |
| 274 | + torch.tensor([True, False, False]), |
| 275 | + ), |
| 276 | + torch.rand((1, 4), dtype=torch.float32), |
| 277 | + False, |
| 278 | + ), |
| 279 | + 0, |
| 280 | + ), |
213 | 281 | } |
214 | 282 | test_data_int = { |
215 | 283 | "rank3_zeros_int8": ( |
@@ -385,3 +453,28 @@ def test_index_put_vgf_quant(test_module: input_t): |
385 | 453 | exir_op=IndexPut.exir_op, |
386 | 454 | ) |
387 | 455 | pipeline.run() |
| 456 | + |
| 457 | + |
| 458 | +@common.parametrize("test_module", mixed_indices_not_supported) |
| 459 | +def test_index_put_tosa_FP_not_delegated(test_module: input_t): |
| 460 | + pipeline = OpNotSupportedPipeline[input_t]( |
| 461 | + IndexPut(), |
| 462 | + test_module[0](), |
| 463 | + {IndexPut.exir_op: 1}, |
| 464 | + quantize=False, |
| 465 | + u55_subset=False, |
| 466 | + n_expected_delegates=0, |
| 467 | + ) |
| 468 | + pipeline.run() |
| 469 | + |
| 470 | + |
| 471 | +@common.parametrize("test_module", mixed_indices_not_supported) |
| 472 | +def test_index_put_tosa_INT_not_delegated(test_module: input_t): |
| 473 | + pipeline = OpNotSupportedPipeline[input_t]( |
| 474 | + IndexPut(), |
| 475 | + test_module[0](), |
| 476 | + {IndexPut.exir_op: 1}, |
| 477 | + quantize=True, |
| 478 | + n_expected_delegates=0, |
| 479 | + ) |
| 480 | + pipeline.run() |
0 commit comments