1515from .index import *
1616
1717def create_array (buf , numdims , idims , dtype ):
18- out_arr = ct .c_longlong (0 )
19- ct . c_dims = dim4 (idims [0 ], idims [1 ], idims [2 ], idims [3 ])
20- safe_call (clib .af_create_array (ct .pointer (out_arr ), ct .c_longlong (buf ),\
21- numdims , ct .pointer (ct . c_dims ), dtype ))
18+ out_arr = ct .c_void_p (0 )
19+ c_dims = dim4 (idims [0 ], idims [1 ], idims [2 ], idims [3 ])
20+ safe_call (clib .af_create_array (ct .pointer (out_arr ), ct .c_void_p (buf ),
21+ numdims , ct .pointer (c_dims ), dtype ))
2222 return out_arr
2323
2424def constant_array (val , d0 , d1 = None , d2 = None , d3 = None , dtype = f32 ):
@@ -29,7 +29,7 @@ def constant_array(val, d0, d1=None, d2=None, d3=None, dtype=f32):
2929 else :
3030 raise TypeError ("Invalid dtype" )
3131
32- out = ct .c_longlong (0 )
32+ out = ct .c_void_p (0 )
3333 dims = dim4 (d0 , d1 , d2 , d3 )
3434
3535 if isinstance (val , complex ):
@@ -39,7 +39,7 @@ def constant_array(val, d0, d1=None, d2=None, d3=None, dtype=f32):
3939 if (dtype != c32 and dtype != c64 ):
4040 dtype = c32
4141
42- safe_call (clib .af_constant_complex (ct .pointer (out ), c_real , c_imag ,\
42+ safe_call (clib .af_constant_complex (ct .pointer (out ), c_real , c_imag ,
4343 4 , ct .pointer (dims ), dtype ))
4444 elif dtype == s64 :
4545 c_val = ct .c_longlong (val .real )
@@ -55,39 +55,39 @@ def constant_array(val, d0, d1=None, d2=None, d3=None, dtype=f32):
5555
5656
5757def binary_func (lhs , rhs , c_func ):
58- out = array ()
58+ out = Array ()
5959 other = rhs
6060
6161 if (is_number (rhs )):
6262 ldims = dim4_tuple (lhs .dims ())
6363 rty = implicit_dtype (rhs , lhs .type ())
64- other = array ()
64+ other = Array ()
6565 other .arr = constant_array (rhs , ldims [0 ], ldims [1 ], ldims [2 ], ldims [3 ], rty )
66- elif not isinstance (rhs , array ):
66+ elif not isinstance (rhs , Array ):
6767 raise TypeError ("Invalid parameter to binary function" )
6868
6969 safe_call (c_func (ct .pointer (out .arr ), lhs .arr , other .arr , bcast .get ()))
7070
7171 return out
7272
7373def binary_funcr (lhs , rhs , c_func ):
74- out = array ()
74+ out = Array ()
7575 other = lhs
7676
7777 if (is_number (lhs )):
7878 rdims = dim4_tuple (rhs .dims ())
7979 lty = implicit_dtype (lhs , rhs .type ())
80- other = array ()
80+ other = Array ()
8181 other .arr = constant_array (lhs , rdims [0 ], rdims [1 ], rdims [2 ], rdims [3 ], lty )
82- elif not isinstance (lhs , array ):
82+ elif not isinstance (lhs , Array ):
8383 raise TypeError ("Invalid parameter to binary function" )
8484
8585 c_func (ct .pointer (out .arr ), other .arr , rhs .arr , bcast .get ())
8686
8787 return out
8888
8989def transpose (a , conj = False ):
90- out = array ()
90+ out = Array ()
9191 safe_call (clib .af_transpose (ct .pointer (out .arr ), a .arr , conj ))
9292 return out
9393
@@ -124,11 +124,11 @@ def get_info(dims, buf_len):
124124 return numdims , idims
125125
126126
127- class array ( base_array ):
127+ class Array ( BaseArray ):
128128
129129 def __init__ (self , src = None , dims = (0 ,), type_char = None ):
130130
131- super (array , self ).__init__ ()
131+ super (Array , self ).__init__ ()
132132
133133 buf = None
134134 buf_len = 0
@@ -137,7 +137,7 @@ def __init__(self, src=None, dims=(0,), type_char=None):
137137
138138 if src is not None :
139139
140- if (isinstance (src , array )):
140+ if (isinstance (src , Array )):
141141 safe_call (clib .af_retain_array (ct .pointer (self .arr ), src .arr ))
142142 return
143143
@@ -178,7 +178,7 @@ def __init__(self, src=None, dims=(0,), type_char=None):
178178 self .arr = create_array (buf , numdims , idims , to_dtype [_type_char ])
179179
180180 def copy (self ):
181- out = array ()
181+ out = Array ()
182182 safe_call (clib .af_copy_array (ct .pointer (out .arr ), self .arr ))
183183 return out
184184
@@ -187,7 +187,7 @@ def __del__(self):
187187 clib .af_release_array (self .arr )
188188
189189 def device_ptr (self ):
190- ptr = ctypes .c_void_p (0 )
190+ ptr = ct .c_void_p (0 )
191191 clib .af_get_device_ptr (ct .pointer (ptr ), self .arr )
192192 return ptr .value
193193
@@ -206,7 +206,7 @@ def dims(self):
206206 d1 = ct .c_longlong (0 )
207207 d2 = ct .c_longlong (0 )
208208 d3 = ct .c_longlong (0 )
209- safe_call (clib .af_get_dims (ct .pointer (d0 ), ct .pointer (d1 ),\
209+ safe_call (clib .af_get_dims (ct .pointer (d0 ), ct .pointer (d1 ),
210210 ct .pointer (d2 ), ct .pointer (d3 ), self .arr ))
211211 dims = (d0 .value ,d1 .value ,d2 .value ,d3 .value )
212212 return dims [:self .numdims ()]
@@ -424,11 +424,11 @@ def __nonzero__(self):
424424
425425 def __getitem__ (self , key ):
426426 try :
427- out = array ()
427+ out = Array ()
428428 n_dims = self .numdims ()
429429 inds = get_indices (key , n_dims )
430430
431- safe_call (clib .af_index_gen (ct .pointer (out .arr ),\
431+ safe_call (clib .af_index_gen (ct .pointer (out .arr ),
432432 self .arr , ct .c_longlong (n_dims ), ct .pointer (inds )))
433433 return out
434434 except RuntimeError as e :
@@ -445,11 +445,11 @@ def __setitem__(self, key, val):
445445 else :
446446 other_arr = val .arr
447447
448- out_arr = ct .c_longlong (0 )
448+ out_arr = ct .c_void_p (0 )
449449 inds = get_indices (key , n_dims )
450450
451- safe_call (clib .af_assign_gen (ct .pointer (out_arr ),\
452- self .arr , ct .c_longlong (n_dims ), ct .pointer (inds ),\
451+ safe_call (clib .af_assign_gen (ct .pointer (out_arr ),
452+ self .arr , ct .c_longlong (n_dims ), ct .pointer (inds ),
453453 other_arr ))
454454 safe_call (clib .af_release_array (self .arr ))
455455 self .arr = out_arr
0 commit comments