@@ -13,6 +13,7 @@ def __init__(self):
13
13
self .dicPlatform2Devices = {}
14
14
15
15
self .context = None
16
+ self .curDevice = None
16
17
self .queue = None
17
18
self .program = None
18
19
self .mem_pool = None
@@ -42,10 +43,10 @@ def setupProgramAndDataStructure(self, program, lstIPath=[], dicName2DS={}):
42
43
# ('open', numpy.float32)]}
43
44
assert (self .context != None ), "Setup context seems incorrectly !!"
44
45
assert (len (self .context .devices ) > 0 ), "Error, No device for context !!"
45
- device = self .context .devices [0 ]
46
+ self . curDevice = self .context .devices [0 ]
46
47
dicReturnStruct = {}
47
48
for k , v in dicName2DS .iteritems ():
48
- kObj , k_c_decl = cl .tools .match_dtype_to_c_struct (device , k , v )
49
+ kObj , k_c_decl = cl .tools .match_dtype_to_c_struct (self . curDevice , k , v )
49
50
retV = cl .tools .get_or_register_dtype (k , kObj )
50
51
51
52
dicReturnStruct [k ] = retV
@@ -64,8 +65,15 @@ def setupProgramAndDataStructure(self, program, lstIPath=[], dicName2DS={}):
64
65
def callFuncFromProgram(self, strMethodName, *args, **argd):
    """Look up a kernel by name on ``self.program`` and invoke it.

    The kernel is called as ``kernel(self.queue, *args, **argd)``.

    When a second positional argument is given and it is a non-empty
    tuple, it is treated as the local work size: the product of its
    entries is compared against the kernel's ``WORK_GROUP_SIZE`` limit
    on ``self.curDevice`` before the launch is attempted.

    Args:
        strMethodName: Attribute name of the kernel on ``self.program``.
        *args: Positional arguments forwarded after ``self.queue``.
            NOTE(review): presumably (global_size, local_size,
            kernel args...) as in PyOpenCL's kernel call -- confirm
            against callers.
        **argd: Keyword arguments forwarded to the kernel call.

    Returns:
        Whatever the kernel call returns, or ``None`` when the looked-up
        attribute is falsy.

    Raises:
        AssertionError: If the requested local work size exceeds the
            work-group size reported for ``self.curDevice``.
    """
    methodCall = getattr(self.program, strMethodName)
    if not methodCall:
        return None

    # None or an empty tuple as args[1] means no explicit local size,
    # so there is nothing to validate in that case.
    if len(args) >= 2 and isinstance(args[1], tuple) and args[1]:
        # Local imports keep this working on both Python 2 and 3
        # without touching the module's import block.
        from functools import reduce
        import operator

        wgs = cl.Kernel(self.program, strMethodName).get_work_group_info(
            cl.kernel_work_group_info.WORK_GROUP_SIZE, self.curDevice)
        local_worksize = reduce(operator.mul, args[1])
        # Single-argument form prints the same text under Python 2 and 3.
        print('local size :  %s' % (local_worksize,))
        if local_worksize > wgs:
            # Raised explicitly (rather than via `assert`) so the check
            # is not stripped when Python runs with -O.
            raise AssertionError(
                'Out of capability, please reduce the local work size for %s()' % (strMethodName))

    # Forward keyword arguments too; they used to be silently dropped.
    return methodCall(self.queue, *args, **argd)
69
77
70
78
def getContext (self , device = PREFERRED_GPU ):
71
79
assert len (self .dicIdx2Platform ) > 0 , 'No platform for OCL operation'
@@ -104,8 +112,8 @@ def createOCLArrayEmpty(self, stDType, size):
104
112
assert size > 0 , "Can NOT create array size <= 0"
105
113
assert (self .queue != None ), " Make sure setup correctly"
106
114
# Create a list which contains elements initialized with structure stDType
107
- npArrData = np .zeros (size , dtype = stDType , allocator = self . mem_pool )
108
- clArrData = cl .array .to_device (self .queue , npArrData )
115
+ npArrData = np .zeros (size , dtype = stDType )
116
+ clArrData = cl .array .to_device (self .queue , npArrData , allocator = self . mem_pool )
109
117
return clArrData
110
118
111
119
def createOCLArrayForInput (self , stDType , lstData ):
@@ -114,6 +122,6 @@ def createOCLArrayForInput(self, stDType, lstData):
114
122
assert len (lstData ) > 0 , "Size of input data list = 0"
115
123
assert (self .queue != None ), " Make sure setup correctly"
116
124
117
- arrayData = np .array (lstData , dtype = stDType , allocator = self . mem_pool )
118
- clArrayData = cl .array .to_device (self .queue , arrayData )
125
+ arrayData = np .array (lstData , dtype = stDType )
126
+ clArrayData = cl .array .to_device (self .queue , arrayData , allocator = self . mem_pool )
119
127
return clArrayData
0 commit comments