1717 ColScaleCSC ,
1818 CSMProperties ,
1919 DenseFromSparse ,
20+ GetItem2Lists ,
21+ GetItem2ListsGrad ,
2022 GetItemList ,
2123 GetItemListGrad ,
2224 HStack ,
@@ -357,7 +359,7 @@ def get_item_list_csc(x, idx):
357359@register_funcify_default_op_cache_key (GetItemListGrad )
358360def numba_funcify_GetItemListGrad (op , node , ** kwargs ):
359361 output_format = node .outputs [0 ].type .format
360- out_dtype = np . dtype ( node .outputs [0 ].type .dtype )
362+ out_dtype = node .outputs [0 ].type .dtype
361363
362364 @numba_basic .numba_njit
363365 def get_item_list_grad_csr (x , idxs , gz ):
@@ -461,47 +463,50 @@ def get_item_list_grad_csc(x, idx, gz):
461463 return get_item_list_grad_csr (x , idx , gz ).tocsc ()
462464
463465 return get_item_list_grad_csc
464- < << << << HEAD
465- == == == =
466466
467467
468468@register_funcify_default_op_cache_key (GetItem2Lists )
469469def numba_funcify_GetItem2Lists (op , node , ** kwargs ):
470- out_dtype = np . dtype ( node .outputs [0 ].type .dtype )
470+ out_dtype = node .outputs [0 ].type .dtype
471471
472472 @numba_basic .numba_njit
473473 def get_item_2lists (x , ind1 , ind2 ):
474- x_csr = x . tocsr ()
475- n_rows , n_cols = x_csr . shape
474+ # Reproduces SciPy and NumPy when running:
475+ # np.asarray(x[ind1, ind2]).flatten()
476476
477477 if ind1 .shape != ind2 .shape :
478478 raise ValueError ("shape mismatch in row/column indices" )
479479
480+ # Output vector contains as many elements as the length of the index lists.
480481 out_size = ind1 .shape [0 ]
481482 out = np .zeros (out_size , dtype = out_dtype )
482483
483- x_data = x_csr . data
484+ x_csr = x . tocsr ()
484485 x_indices = x_csr .indices .view (np .uint32 )
485486 x_indptr = x_csr .indptr .view (np .uint32 )
487+ n_rows , n_cols = x_csr .shape
486488
487489 for i in range (out_size ):
490+ # Normalize row index
488491 row_idx = ind1 [i ]
489492 if row_idx < 0 :
490493 row_idx += n_rows
491494 if row_idx < 0 or row_idx >= n_rows :
492495 raise IndexError ("row index out of bounds" )
493496
497+ # Normalize column index
494498 col_idx = ind2 [i ]
495499 if col_idx < 0 :
496500 col_idx += n_cols
497501 if col_idx < 0 or col_idx >= n_cols :
498502 raise IndexError ("column index out of bounds" )
499503
500- col_idx_u32 = np .uint32 (col_idx )
504+ row_idx = np .uint32 (row_idx )
505+ col_idx = np .uint32 (col_idx )
501506 for data_idx in range (x_indptr [row_idx ], x_indptr [row_idx + 1 ]):
502- if x_indices [data_idx ] == col_idx_u32 :
507+ if x_indices [data_idx ] == col_idx :
503508 # Duplicate sparse entries must accumulate like scipy indexing.
504- out [i ] += x_data [data_idx ]
509+ out [i ] += x_csr . data [data_idx ]
505510
506511 return out
507512
@@ -511,31 +516,42 @@ def get_item_2lists(x, ind1, ind2):
511516@register_funcify_default_op_cache_key (GetItem2ListsGrad )
512517def numba_funcify_GetItem2ListsGrad (op , node , ** kwargs ):
513518 output_format = node .outputs [0 ].type .format
514- out_dtype = np . dtype ( node .outputs [0 ].type .dtype )
519+ out_dtype = node .outputs [0 ].type .dtype
515520
516521 @numba_basic .numba_njit
517522 def get_item_2lists_grad_csr (x , ind1 , ind2 , gz ):
518- n_rows , n_cols = x .shape
519- n_assignments = ind1 .shape [0 ]
523+ # Reproduces SciPy when running:
524+ # y = [csc|csr]_matrix(x.shape)
525+ # for i in range(len(ind1)):
526+ # y[(ind1[i], ind2[i])] = gz[i]
527+ #
528+ # Note that gz is a dense vector.
520529
521- if ind2 .shape [ 0 ] != n_assignments :
530+ if ind1 .shape != ind2 . shape :
522531 raise ValueError ("shape mismatch in row/column indices" )
532+
533+ n_assignments = ind1 .shape [0 ]
523534 if gz .shape [0 ] < n_assignments :
524535 raise IndexError ("gradient index out of bounds" )
525536
526- norm_row = np .empty (n_assignments , dtype = np .int32 )
527- norm_col = np .empty (n_assignments , dtype = np .int32 )
537+ # Vectors with normalized (non-negative) row and column indices
538+ norm_row = np .empty (n_assignments , dtype = np .uint32 )
539+ norm_col = np .empty (n_assignments , dtype = np .uint32 )
528540
541+ n_rows , n_cols = x .shape
542+ # Maps original rows to values in [0, ..., touched_n_rows - 1]
529543 row_to_pos = np .full (n_rows , - 1 , dtype = np .int32 )
530544 touched_n_rows = 0
531545
532546 for i in range (n_assignments ):
547+ # Normalize row idx
533548 row_idx = ind1 [i ]
534549 if row_idx < 0 :
535550 row_idx += n_rows
536551 if row_idx < 0 or row_idx >= n_rows :
537552 raise IndexError ("row index out of bounds" )
538553
554+ # Normalize column idx
539555 col_idx = ind2 [i ]
540556 if col_idx < 0 :
541557 col_idx += n_cols
@@ -552,40 +568,46 @@ def get_item_2lists_grad_csr(x, ind1, ind2, gz):
552568 # Build row-wise buffers for touched rows. Repeated writes overwrite values.
553569 row_data = np .zeros ((touched_n_rows , n_cols ), dtype = out_dtype )
554570 row_mask = np .zeros ((touched_n_rows , n_cols ), dtype = np .bool_ )
571+ row_nnz = np .zeros (touched_n_rows , dtype = np .int32 )
555572
556573 for i in range (n_assignments ):
557574 row_pos = row_to_pos [norm_row [i ]]
558575 col_idx = norm_col [i ]
576+ if not row_mask [row_pos , col_idx ]:
577+ row_nnz [row_pos ] += 1
578+ row_mask [row_pos , col_idx ] = True
559579 row_data [row_pos , col_idx ] = gz [i ]
560- row_mask [row_pos , col_idx ] = True
561580
581+ # Build output indptr.
582+ # For touched rows add row_nnz[row_pos] to total_nnz.
583+ # For untouched rows, carry forward the previous total_nnz count.
562584 out_indptr = np .empty (n_rows + 1 , dtype = np .int32 )
563585 out_indptr [0 ] = 0
564586
565587 total_nnz = 0
566588 for row_idx in range (n_rows ):
567589 row_pos = row_to_pos [row_idx ]
568590 if row_pos >= 0 :
569- row_nnz = 0
570- for col_idx in range (n_cols ):
571- if row_mask [row_pos , col_idx ]:
572- row_nnz += 1
573- total_nnz += row_nnz
591+ total_nnz += row_nnz [row_pos ]
574592 out_indptr [row_idx + 1 ] = total_nnz
575593
594+ # Build output data and indices, which need the total number of non-zero elements.
576595 out_data = np .empty (total_nnz , dtype = out_dtype )
577596 out_indices = np .empty (total_nnz , dtype = np .int32 )
578- out_pos = 0
579597
598+ # Populate indices and data by storing col_idx and value (row_data[row_pos, col_idx])
599+ # for touched rows/columns.
580600 for row_idx in range (n_rows ):
581601 row_pos = row_to_pos [row_idx ]
582602 if row_pos < 0 :
583603 continue
604+
605+ dst = out_indptr [row_idx ]
584606 for col_idx in range (n_cols ):
585607 if row_mask [row_pos , col_idx ]:
586- out_indices [out_pos ] = col_idx
587- out_data [out_pos ] = row_data [row_pos , col_idx ]
588- out_pos += 1
608+ out_indices [dst ] = col_idx
609+ out_data [dst ] = row_data [row_pos , col_idx ]
610+ dst += 1
589611
590612 return sp .sparse .csr_matrix (
591613 (out_data , out_indices , out_indptr ), shape = (n_rows , n_cols )
@@ -599,4 +621,3 @@ def get_item_2lists_grad_csc(x, ind1, ind2, gz):
599621 return get_item_2lists_grad_csr (x , ind1 , ind2 , gz ).tocsc ()
600622
601623 return get_item_2lists_grad_csc
602- > >> >> >> fb1d09134 (Better comments for GetItemList and GetItemListGrad )
0 commit comments