Files
psat/psat.cl
2022-07-17 14:03:47 -04:00

151 lines
4.1 KiB
OpenCL C

/* Enable the cl_khr_byte_addressable_store extension, which permits stores
 * through pointers to types narrower than 32 bits (the kernel below writes
 * single uchar flags into the scratchpad buffer). */
#pragma OPENCL EXTENSION cl_khr_byte_addressable_store : enable
/* DEBUG is switched on unconditionally, so DBG(X) expands to a device-side
 * printf of the string X followed by a newline.
 * NOTE(review): DBG is not referenced in the code visible in this file. */
#define DEBUG
#ifdef DEBUG
#define DBG(X) printf("%s\n", (X))
#else
/* Disabled form: a single no-op statement that does not evaluate X. */
#define DBG(X) do {} while (0)
#endif
/*
 * Debug helper: print the 32 bits of `a` as '0'/'1' characters, most
 * significant bit first, with no separator and no trailing newline.
 */
static inline void printbits(uint a) {
    /* Slide a single-bit mask from bit 31 down to bit 0. */
    for (uint mask = 0x80000000U; mask != 0U; mask >>= 1) {
        printf("%u", (a & mask) ? 1U : 0U);
    }
}
/*
 * Add 2^bitpos to a multi-word unsigned counter held in 32-bit words, least
 * significant word first (state[0] holds bits 0..31, state[1] bits 32..63, ...).
 *
 * wcnt    number of words in `state`.
 * state   counter words; updated in place.
 * bitpos  exponent of the power of two to add: word bitpos / 32, bit bitpos % 32.
 *
 * A carry out of the most significant word (state[wcnt - 1]) is discarded.
 * If bitpos lies beyond the counter (bitpos / 32 >= wcnt) the call is a no-op
 * rather than a write past the end of `state`.
 *
 * `unsigned int` is the same type as OpenCL `uint`.
 * NOTE(review): `state` is an unqualified pointer and the kernel passes a
 * __global pointer, which relies on the OpenCL 2.0 generic address space.
 */
static inline void stateaddpow(unsigned int wcnt, unsigned int* state, unsigned int bitpos) {
    const unsigned int word = bitpos >> 5U;
    if (word >= wcnt) {
        return; /* nothing to add inside the counter */
    }

    const unsigned int addend = 1U << (bitpos & 31U);
    state[word] += addend;

    /* Unsigned addition wrapped iff the sum is now smaller than the addend. */
    if (state[word] >= addend) {
        return;
    }

    /* Ripple the carry upward until some word does not wrap to zero. */
    for (unsigned int i = word + 1U; i < wcnt; ++i) {
        if (++state[i] != 0U) {
            break;
        }
    }
}
__kernel void vectorSAT(__global const uint* cnfheader, __global const uint* lvars, __global const uint* vars, __global const uint* clauses, __global const uchar* pars, __global uint* output, __global uchar* scratchpad, __local uint* maxvals) {
output[0] = 2;
__local uint setmax;
uint cnt = cnfheader[0];
uint vcnt = cnfheader[1];
uint ccnt = cnfheader[2];
uint wcnt = 1 + (vcnt >> 5U);
uint maxctr = 1U << (vcnt & 0b11111U);
//uint glbid = get_global_id(0);
//uint glbsz = get_global_size(0);
uint locid = get_local_id(0);
uint locsz = get_local_size(0);
// uint grpid = get_group_id(0);
// uint grpcn = get_num_groups(0);
// Zero out the counter
for (uint i = 0; i < wcnt; ++i) output[i + 1] = 0;
// Set all scratchpad clauses to true
for (uint j = 0; j < ccnt; j += locsz) {
uchar cond = (j + locid) < ccnt;
j = j * cond + (!cond) * (ccnt - locid - 1);
scratchpad[j + locid] = 1;
}
__local uint firstind[1];
while (output[0] == 2) {
firstind[0] = ccnt;
setmax = 0;
uint maxnumx = 0;
for (uint j = 0; j < cnt; j += locsz) {
uchar cond = (j + locid) < cnt;
// Last element cap
j = j * cond + (!cond) * (cnt - locid - 1);
uint varind = vars[j + locid];
varind = (vcnt - 1) - varind;
uint iind = varind >> 5U;
uint bind = varind & 0b11111U;
uchar cpar = (output[iind + 1] >> bind) & 1U;
if (cpar != pars[j + locid]) {
scratchpad[clauses[j + locid]] = 0;
}
}
barrier(CLK_LOCAL_MEM_FENCE | CLK_GLOBAL_MEM_FENCE);
for (uint j = 0; j < ccnt; j += locsz) {
if (scratchpad[j + locid] == 1 && (j + locid) < ccnt) {
setmax = 1;
}
}
barrier(CLK_LOCAL_MEM_FENCE);
if (setmax) {
// Set maxval array to zero
maxvals[locid] = 0;
// Accumulate and reduce the maximums
for (uint j = 0; j < ccnt; j += locsz) {
//uint a = maxvals[locid];
//uint b = lvars[j + locid];
// uint c = max(a, b);
if ((j + locid) < ccnt && scratchpad[j + locid] == 1) {
//maxvals[locid] = c;
atomic_min(firstind, (j + locid));
}
}
barrier(CLK_LOCAL_MEM_FENCE);
uint maxj = lvars[firstind[0]];
// Set all scratchpad clauses to true
for (uint j = 0; j < ccnt; j += locsz) {
uchar cond = (j + locid) < ccnt;
j = j * cond + (!cond) * (ccnt - locid - 1);
scratchpad[j + locid] = 1;
}
// Final reduction pass
/*
uint maxj = maxvals[0];
for (uint j = 1; j < locsz; ++j) {
maxj = max(maxj, maxvals[j]);
}
*/
// Add to the counter
if (locid == 0) {
stateaddpow(wcnt, output + 1, maxj);
}
if (output[wcnt] >= maxctr) {
output[0] = 1;
}
} else {
output[0] = 0;
if (locid == 0) {
for (uint i = 0; i < wcnt; ++i) output[i + 1] = ~output[i + 1];
}
}
barrier(CLK_LOCAL_MEM_FENCE);
}
}