done
This commit is contained in:
		| @ -0,0 +1,44 @@ | ||||
| """ | ||||
| Use cffi to access any of the underlying C functions from distributions.h | ||||
| """ | ||||
| import os | ||||
|  | ||||
| import cffi | ||||
|  | ||||
| import numpy as np | ||||
|  | ||||
| from .parse import parse_distributions_h | ||||
|  | ||||
| ffi = cffi.FFI() | ||||
|  | ||||
| inc_dir = os.path.join(np.get_include(), 'numpy') | ||||
|  | ||||
| # Basic numpy types | ||||
| ffi.cdef(''' | ||||
|     typedef intptr_t npy_intp; | ||||
|     typedef unsigned char npy_bool; | ||||
|  | ||||
| ''') | ||||
|  | ||||
| parse_distributions_h(ffi, inc_dir) | ||||
|  | ||||
| lib = ffi.dlopen(np.random._generator.__file__) | ||||
|  | ||||
| # Compare the distributions.h random_standard_normal_fill to | ||||
| # Generator.standard_random | ||||
| bit_gen = np.random.PCG64() | ||||
| rng = np.random.Generator(bit_gen) | ||||
| state = bit_gen.state | ||||
|  | ||||
| interface = rng.bit_generator.cffi | ||||
| n = 100 | ||||
| vals_cffi = ffi.new('double[%d]' % n) | ||||
| lib.random_standard_normal_fill(interface.bit_generator, n, vals_cffi) | ||||
|  | ||||
| # reset the state | ||||
| bit_gen.state = state | ||||
|  | ||||
| vals = rng.standard_normal(n) | ||||
|  | ||||
| for i in range(n): | ||||
|     assert vals[i] == vals_cffi[i] | ||||
| @ -0,0 +1,53 @@ | ||||
| import os | ||||
|  | ||||
|  | ||||
| def parse_distributions_h(ffi, inc_dir): | ||||
|     """ | ||||
|     Parse distributions.h located in inc_dir for CFFI, filling in the ffi.cdef | ||||
|  | ||||
|     Read the function declarations without the "#define ..." macros that will | ||||
|     be filled in when loading the library. | ||||
|     """ | ||||
|  | ||||
|     with open(os.path.join(inc_dir, 'random', 'bitgen.h')) as fid: | ||||
|         s = [] | ||||
|         for line in fid: | ||||
|             # massage the include file | ||||
|             if line.strip().startswith('#'): | ||||
|                 continue | ||||
|             s.append(line) | ||||
|         ffi.cdef('\n'.join(s)) | ||||
|  | ||||
|     with open(os.path.join(inc_dir, 'random', 'distributions.h')) as fid: | ||||
|         s = [] | ||||
|         in_skip = 0 | ||||
|         ignoring = False | ||||
|         for line in fid: | ||||
|             # check for and remove extern "C" guards | ||||
|             if ignoring: | ||||
|                 if line.strip().startswith('#endif'): | ||||
|                     ignoring = False | ||||
|                 continue | ||||
|             if line.strip().startswith('#ifdef __cplusplus'): | ||||
|                 ignoring = True | ||||
|  | ||||
|             # massage the include file | ||||
|             if line.strip().startswith('#'): | ||||
|                 continue | ||||
|  | ||||
|             # skip any inlined function definition | ||||
|             # which starts with 'static inline xxx(...) {' | ||||
|             # and ends with a closing '}' | ||||
|             if line.strip().startswith('static inline'): | ||||
|                 in_skip += line.count('{') | ||||
|                 continue | ||||
|             elif in_skip > 0: | ||||
|                 in_skip += line.count('{') | ||||
|                 in_skip -= line.count('}') | ||||
|                 continue | ||||
|  | ||||
|             # replace defines with their value or remove them | ||||
|             line = line.replace('DECLDIR', '') | ||||
|             line = line.replace('RAND_INT_TYPE', 'int64_t') | ||||
|             s.append(line) | ||||
|         ffi.cdef('\n'.join(s)) | ||||
		Reference in New Issue
	
	Block a user