diff --git a/CMakeLists.txt b/CMakeLists.txt index c0aa84e..d46f0cd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2,8 +2,8 @@ cmake_minimum_required(VERSION 3.22) project(psat C) set(CMAKE_C_STANDARD 99) -# set(CMAKE_C_FLAGS "-mavx2 -O3 -ftree-loop-linear -ftree-loop-im -ftree-loop-ivcanon -fivopts -ftree-vectorize -ftracer -funroll-all-loops ") +set(CMAKE_C_FLAGS "-mavx2 -O3 -ftree-loop-linear -ftree-loop-im -ftree-loop-ivcanon -fivopts -ftree-vectorize") -add_executable(psat main.c cnf.c cnf.h time.c time.h types.h gpusolver.c gpusolver.h cpusolver.c cpusolver.h tests/masterTest.c tests/masterTest.h ncnf.c ncnf.h rng.h rng.c) +add_executable(psat main.c cnf.c cnf.h time.c time.h types.h gpusolver.c gpusolver.h cpusolver.c cpusolver.h tests/masterTest.c tests/masterTest.h ncnf.c ncnf.h rng.h rng.c csflocref.c csflocref.h) target_link_libraries(psat -lOpenCL -lgmp) \ No newline at end of file diff --git a/cpusolver.c b/cpusolver.c index b77e88d..ded6d11 100644 --- a/cpusolver.c +++ b/cpusolver.c @@ -1 +1,664 @@ #include "cpusolver.h" +#include "csflocref.h" + +static inline void stateaddpow(uint wcnt, uint* state, uint pow) { + uint corpow = pow & 0b11111U; + uint startind = pow >> 5U; + uint tr = 1U << corpow; + uint tval = state[startind] + tr; + bool carry = tval < state[startind]; + state[startind] = tval; + if (carry) { + uint i = startind + 1; + for ( ; i < wcnt; ++i) { + state[i]++; + if (state[i]) break; + } + } +} + +#define DEBUGSAT + +#define ADD (0) +#define CMP (1) +#define BCT (2) +#define CHK (3) + +void ctrthings3(cnf* c, u32* state, u32* ctr, const u32* max) { + u32 varcnt = c->cnts[0]; + u32 clausecnt = c->cnts[1]; + u32 wcnt = 1U + (varcnt >> 5U); + u32* mode = state; + u32* index = state + 1; + u32* addval = state + 2; + + // printf("> %u %u %u %u %u\n", *mode, *index, *addval, *ctr, *max); + + + u8 addmode = (*mode == ADD); + u8 cmpmode = (*mode == CMP); + u8 bctmode = (*mode == BCT); + u8 submode = (*mode == CHK); + + // 1 mask for if in subcheck mode, 1 for if not + u32 submask = 0xFFFFFFFFU * submode; + u32 nsubmask = ~submask; + u32 addmask = 0xFFFFFFFFU * addmode; + u32 naddmask = ~addmask; + + /* subcheck */ + // used to find correct ctr index + u32 varcor = varcnt - 1; + // Mask to prevent oob accesses for when (varcnt / 32) > clausecnt (or smth like that) + u32 currclause = *index & submask; + // clausedat records are 3 words long, multiply index by 3 + u32 currclauseshifted = currclause * 3U; + // Retrieve beginning index of clause, add current var index + u32 currvarind = c->clausedat[currclauseshifted] + (*addval & submask); + // Retrieve current var & parity + u32 currvar = c->variables[currvarind]; + u8 corrpar = c->parities[currvarind]; + // Calculate index in ctr + u32 currvarcorr = (varcor - currvar); + u32 currvarword = currvarcorr >> 5U; + u32 currvarbit = currvarcorr & 0b11111U; + + /* Can we make there be 1 ctr read per iteration? */ + // If in sub mode, retrieve current var's bit, else return ctr[*index] + u32 ctrind = (currvarword & submask) | (*index & nsubmask); + u32 ctrval = ctr[ctrind]; + u32 maxval = max[*index & nsubmask]; + + // Extract parity bit from ctr + u32 currpar = (ctrval >> currvarbit) & 1U; + // Check if assignment parity matches clause parity + u8 subvalid = currpar == corrpar; + // Don't mask addval because it's not being used for a lookup + u8 islvar = c->clausedat[currclauseshifted + 1U] == (*addval + 1); + u32 jval = c->clausedat[currclauseshifted + 2U]; + u32 jword = jval >> 5U; + u32 jbit = 1U << (jval & 0b11111U); + + u8 islclause = (clausecnt - 1) <= *index; + + + u8 endOfCtr = (ctrind == (wcnt - 1)); + u8 begOfCtr = (ctrind == 0); + + /* add */ + u32 nctr = ctrval + *addval; + u8 addoverflow = (nctr < ctrval); + + /* cmp */ + u8 cmpresult = -(ctrval < maxval); + cmpresult += (ctrval > maxval); + u8 cmpnores = cmpresult == 0; + u8 cmpisless = cmpresult == 255U; + + /* bitcnt */ + u32 bitcntval = (ctrval & -ctrval); + u32 bitcnt = ((bitcntval & 0x0000FFFFU) != 0) << 4U; + bitcnt |= ((bitcntval & 0x00FF00FF) != 0) << 3U; + bitcnt |= ((bitcntval & 0x0F0F0F0F) != 0) << 2U; + bitcnt |= ((bitcntval & 0x33333333) != 0) << 1U; + bitcnt |= (bitcntval & 0x55555555) != 0; + bitcnt += bitcntval != 0; + bitcntval = 32 - bitcnt; + u8 ctrWordIsZero = ctrval == 0; + + // Set current ctr word + // If in add mode, update ctr val, otherwise leave unchanged + ctr[ctrind] = (nctr & addmask) | (ctr[ctrind] & naddmask); + // Set current index + + + /* + * jword : (subcheck & valid & lastvar) + * index : (subcheck & valid & !lastvar) | (add & (!addoverflow | endOfCtr)) + * index + 1 : (subcheck & !valid) | (add & (addoverflow & !endOfCtr)) | (bitcnt & ctrWordIsZero & !endOfCtr) + * index - 1 : (cmp & cmpnores & !begOfCtr) + * 0 : (cmp & (!cmpnores | begOfCtr)) + * index[addoverflow + bitcntval] : (bitcnt & (!ctrWordIsZero | endOfCtr) + */ + u32 tempA = submode & (subvalid & islvar); + // if (submode & (subvalid & islvar)) fprintf(iofile, "j: %u\n", jval); + u32 tempB = (submode & (subvalid & !islvar)); + u32 tempC = (submode & !subvalid) | (addmode & (addoverflow & !endOfCtr)) | (bctmode & (ctrWordIsZero & !endOfCtr)); + u32 tempD = (cmpmode & (cmpnores & !begOfCtr)); + u32 tempE = (addmode & (!addoverflow | endOfCtr)); + u32 tempF = (bctmode & (!ctrWordIsZero | endOfCtr)); + tempA *= jword; + tempB *= *index; + tempC *= (*index + 1); + tempD *= (*index - 1); + tempE *= (wcnt - 1); + u32 indexval = (*addval + bitcntval); + u8 indexcmp = indexval < varcnt; + indexval = indexval * indexcmp + (varcnt - 1) * (!indexcmp); + tempF *= c->index[((indexval) * bctmode)]; + // if (bctmode & (!ctrWordIsZero | endOfCtr)) fprintf(iofile,"z: %u %u\n", (indexval), c->index[((indexval) * bctmode)]); + *index = tempA | tempB | tempC | tempD | tempE | tempF; + + tempA = submode & (subvalid & islvar); + tempB = (submode & (subvalid & !islvar)); + tempC = (addmode & (addoverflow & !endOfCtr)); + tempD = (bctmode & (ctrWordIsZero & !endOfCtr)); + tempA *= (jbit); + tempB *= (*addval + 1); + tempC *= (addoverflow); + tempD *= (*addval + 32); + *addval = tempA | tempB | tempC | tempD; + /* + * 1 << jbit : (subcheck & valid & lastvar) + * addval + 1 : (subcheck & valid & !lastvar) + * 0 : (subcheck & !valid) | (add & (!addoverflow | endOfCtr)) | (cmp) | (bitcnt & (!ctrWordIsZero | endOfCtr) + * addoverflow : (add & (addoverflow & !endOfCtr)) + * addoverflow + 32 : (bitcnt & ctrWordIsZero & !endOfCtr) + * + */ + + /* + * SAT : (subcheck & !valid & lastclause) + * UNSAT : (cmp & (!cmpnores | begOfCtr)) & !cmpisless + */ + u8 issat = (submode & ((!subvalid) & islclause)); + u8 isusat = (cmpmode & ((!cmpnores) | begOfCtr)) & (!cmpisless); + u8 isdone = issat | isusat; + + *mode += isdone << 2U; + *index ^= (*index * isdone); + *index += issat; + /* + if (isdone) { + *mode = 9; + *index = issat; + // printf("%u\n", issat); + } + */ + + /* + * ADD : (add & (!addoverflow | endOfCtr)) + * CMP : (cmp & (!cmpnores | begOfCtr)) + * BCT : (bitcnt & (!ctrWordIsZero | endOfCtr)) + * CHK : (subcheck & valid & lastvar) + */ + tempA = (addmode & (!addoverflow | endOfCtr)); + tempB = (cmpmode & (!cmpnores | begOfCtr)); + tempC = (bctmode & (!ctrWordIsZero | endOfCtr)); + tempD = (submode & (subvalid & islvar)); + *mode += tempA | tempB | tempC; + *mode -= (tempD * 3); +} + +void ctrthings4(cnf* c, u32* state, u32* ctr, const u32* max, FILE* iofile) { + u32 varcnt = c->cnts[0]; + u32 clausecnt = c->cnts[1]; + u32 wcnt = 1U + (varcnt >> 5U); + u32* mode = state; + u32* index = state + 1; + u32* addval = state + 2; + + // printf("> %u %u %u %u %u\n", *mode, *index, *addval, *ctr, *max); + + + u8 addmode = (*mode == ADD); + u8 cmpmode = (*mode == CMP); + u8 bctmode = (*mode == BCT); + u8 submode = (*mode == CHK); + + // 1 mask for if in subcheck mode, 1 for if not + u32 submask = -submode; + u32 nsubmask = ~submask; + u32 addmask = -addmode; + u32 naddmask = ~addmask; + + /* subcheck */ + // used to find correct ctr index + u32 varcor = varcnt - 1; + // Mask to prevent oob accesses for when (varcnt / 32) > clausecnt (or smth like that) + u32 currclause = *index & submask; + // clausedat records are 3 words long, multiply index by 3 + u32 currclauseshifted = currclause * 3U; + // Retrieve beginning index of clause, add current var index + u32 currvarind = c->clausedat[currclauseshifted] + (*addval & submask); + // Retrieve current var & parity + u32 currvar = c->variables[currvarind]; + u8 corrpar = c->parities[currvarind]; + // Calculate index in ctr + u32 currvarcorr = (varcor - currvar); + u32 currvarword = currvarcorr >> 5U; + u32 currvarbit = currvarcorr & 0b11111U; + + /* Can we make there be 1 ctr read per iteration? */ + // If in sub mode, retrieve current var's bit, else return ctr[*index] + u32 ctrind = (currvarword & submask) | (*index & nsubmask); + u32 ctrval = ctr[ctrind]; + u32 maxval = max[*index & nsubmask]; + + // Extract parity bit from ctr + u32 currpar = (ctrval >> currvarbit) & 1U; + // Check if assignment parity matches clause parity + u8 subvalid = currpar == corrpar; + // Don't mask addval because it's not being used for a lookup + u8 islvar = c->clausedat[currclauseshifted + 1U] == (*addval + 1); + u32 jval = c->clausedat[currclauseshifted + 2U]; + u32 jword = jval >> 5U; + u32 jbit = 1U << (jval & 0b11111U); + + u8 islclause = (clausecnt - 1) <= *index; + + + u8 endOfCtr = (ctrind == (wcnt - 1)); + u8 begOfCtr = (ctrind == 0); + + /* add */ + u32 nctr = ctrval + *addval; + u8 addoverflow = (nctr < ctrval); + + /* cmp */ + u8 cmpresult = -(ctrval < maxval); + cmpresult += (ctrval > maxval); + u8 cmpnores = cmpresult == 0; + u8 cmpisless = cmpresult == 255U; + + /* bitcnt */ + u32 bitcntval = (ctrval & -ctrval); + u32 bitcnt = ((bitcntval & 0x0000FFFFU) != 0) << 4U; + bitcnt |= ((bitcntval & 0x00FF00FF) != 0) << 3U; + bitcnt |= ((bitcntval & 0x0F0F0F0F) != 0) << 2U; + bitcnt |= ((bitcntval & 0x33333333) != 0) << 1U; + bitcnt |= (bitcntval & 0x55555555) != 0; + bitcnt += bitcntval != 0; + bitcntval = 32 - bitcnt; + u8 ctrWordIsZero = ctrval == 0; + + // Set current ctr word + // If in add mode, update ctr val, otherwise leave unchanged + ctr[ctrind] = (nctr & addmask) | (ctr[ctrind] & naddmask); + // Set current index + + + /* + * jword : (subcheck & valid & lastvar) + * index : (subcheck & valid & !lastvar) | (add & (!addoverflow | endOfCtr)) + * index + 1 : (subcheck & !valid) | (add & (addoverflow & !endOfCtr)) | (bitcnt & ctrWordIsZero & !endOfCtr) + * index - 1 : (cmp & cmpnores & !begOfCtr) + * 0 : (cmp & (!cmpnores | begOfCtr)) + * index[addoverflow + bitcntval] : (bitcnt & (!ctrWordIsZero | endOfCtr) + */ + u32 tempA = -(submode & (subvalid & islvar)); + if (submode & (subvalid & islvar)) fprintf(iofile, "j: %u\n", jval); + u32 tempB = -(submode & (subvalid & !islvar)); + u32 tempC = -((submode & !subvalid) | (addmode & (addoverflow & !endOfCtr)) | (bctmode & (ctrWordIsZero & !endOfCtr))); + u32 tempD = -(cmpmode & (cmpnores & !begOfCtr)); + u32 tempE = -(addmode & (!addoverflow | endOfCtr)); + u32 tempF = -(bctmode & (!ctrWordIsZero | endOfCtr)); + tempA &= jword; + tempB &= *index; + tempC &= (*index + 1); + tempD &= (*index - 1); + tempE &= (wcnt - 1); + u32 indexval = (*addval + bitcntval); + u32 indexcmp = -(indexval < varcnt); + indexval = (indexval & indexcmp) | ((varcnt - 1) & (~indexcmp)); + tempF &= c->index[((indexval) * bctmode)]; + if (bctmode & (!ctrWordIsZero | endOfCtr)) fprintf(iofile,"z: %u %u\n", (indexval), c->index[((indexval) * bctmode)]); + *index = tempA | tempB | tempC | tempD | tempE | tempF; + + tempA = -(submode & (subvalid & islvar)); + tempB = -(submode & (subvalid & !islvar)); + tempC = -(addmode & (addoverflow & !endOfCtr)); + tempD = -(bctmode & (ctrWordIsZero & !endOfCtr)); + tempA &= (jbit); + tempB &= (*addval + 1); + tempC &= (addoverflow); + tempD &= (*addval + 32); + *addval = tempA | tempB | tempC | tempD; + /* + * 1 << jbit : (subcheck & valid & lastvar) + * addval + 1 : (subcheck & valid & !lastvar) + * 0 : (subcheck & !valid) | (add & (!addoverflow | endOfCtr)) | (cmp) | (bitcnt & (!ctrWordIsZero | endOfCtr) + * addoverflow : (add & (addoverflow & !endOfCtr)) + * addoverflow + 32 : (bitcnt & ctrWordIsZero & !endOfCtr) + * + */ + + /* + * SAT : (subcheck & !valid & lastclause) + * UNSAT : (cmp & (!cmpnores | begOfCtr)) & !cmpisless + */ + u8 issat = (submode & ((!subvalid) & islclause)); + u8 isusat = (cmpmode & ((!cmpnores) | begOfCtr)) & (!cmpisless); + u32 isdone = -(issat | isusat); + + *mode |= isdone; + *index ^= (*index & isdone); + *index += issat; + /* + if (isdone) { + *mode = 9; + *index = issat; + // printf("%u\n", issat); + } + */ + + /* + * ADD : (add & (!addoverflow | endOfCtr)) + * CMP : (cmp & (!cmpnores | begOfCtr)) + * BCT : (bitcnt & (!ctrWordIsZero | endOfCtr)) + * CHK : (subcheck & valid & lastvar) + */ + tempA = (addmode & (!addoverflow | endOfCtr)); + tempB = (cmpmode & (!cmpnores | begOfCtr)); + tempC = (bctmode & (!ctrWordIsZero | endOfCtr)); + tempD = (submode & (subvalid & islvar)); + *mode += tempA | tempB | tempC; + *mode -= (tempD * 3); +} + +void ctrthings5(cnf* c, u32* state, u32* ctr, const u32* max) { + u32 varcnt = c->cnts[0]; + u32 clausecnt = c->cnts[1]; + u32 wcnt = 1U + (varcnt >> 5U); + u32* mode = state; + u32* index = state + 1; + u32* addval = state + 2; + + // printf("> %u %u %u %u %u\n", *mode, *index, *addval, *ctr, *max); + + + u8 addmode = (*mode == ADD); + u8 cmpmode = (*mode == CMP); + u8 bctmode = (*mode == BCT); + u8 submode = (*mode == CHK); + + // 1 mask for if in subcheck mode, 1 for if not + u32 submask = -submode; + u32 nsubmask = ~submask; + u32 addmask = -addmode; + u32 naddmask = ~addmask; + + /* subcheck */ + // used to find correct ctr index + u32 varcor = varcnt - 1; + // Mask to prevent oob accesses for when (varcnt / 32) > clausecnt (or smth like that) + u32 currclause = *index & submask; + // clausedat records are 3 words long, multiply index by 3 + u32 currclauseshifted = currclause * 3U; + // Retrieve beginning index of clause, add current var index + u32 currvarind = c->clausedat[currclauseshifted] + (*addval & submask); + // Retrieve current var & parity + u32 currvar = c->variables[currvarind]; + u8 corrpar = c->parities[currvarind]; + // Calculate index in ctr + u32 currvarcorr = (varcor - currvar); + u32 currvarword = currvarcorr >> 5U; + u32 currvarbit = currvarcorr & 0b11111U; + + /* Can we make there be 1 ctr read per iteration? */ + // If in sub mode, retrieve current var's bit, else return ctr[*index] + u32 ctrind = (currvarword & submask) | (*index & nsubmask); + u32 ctrval = ctr[ctrind]; + u32 maxval = max[*index & nsubmask]; + + // Extract parity bit from ctr + u32 currpar = (ctrval >> currvarbit) & 1U; + // Check if assignment parity matches clause parity + u8 subvalid = currpar == corrpar; + // Don't mask addval because it's not being used for a lookup + u8 islvar = c->clausedat[currclauseshifted + 1U] == (*addval + 1); + u32 jval = c->clausedat[currclauseshifted + 2U]; + u32 jword = jval >> 5U; + u32 jbit = 1U << (jval & 0b11111U); + + u8 islclause = (clausecnt - 1) <= *index; + + + u8 endOfCtr = (ctrind == (wcnt - 1)); + u8 begOfCtr = (ctrind == 0); + + /* add */ + u32 nctr = ctrval + *addval; + u8 addoverflow = (nctr < ctrval); + + /* cmp */ + u8 cmpresult = -(ctrval < maxval); + cmpresult += (ctrval > maxval); + u8 cmpnores = cmpresult == 0; + u8 cmpisless = cmpresult == 255U; + + /* bitcnt */ + u32 bitcntval = (ctrval & -ctrval); + u32 bitcnt = ((bitcntval & 0x0000FFFFU) != 0) << 4U; + bitcnt |= ((bitcntval & 0x00FF00FF) != 0) << 3U; + bitcnt |= ((bitcntval & 0x0F0F0F0F) != 0) << 2U; + bitcnt |= ((bitcntval & 0x33333333) != 0) << 1U; + bitcnt |= (bitcntval & 0x55555555) != 0; + bitcnt += bitcntval != 0; + bitcntval = 32 - bitcnt; + u8 ctrWordIsZero = ctrval == 0; + + // Set current ctr word + // If in add mode, update ctr val, otherwise leave unchanged + ctr[ctrind] = (nctr & addmask) | (ctr[ctrind] & naddmask); + // Set current index + + u32 tempA = -(submode & (subvalid & islvar)); + u32 tempB = -(submode & (subvalid & !islvar)); + u32 tempC = -((submode & !subvalid) | (addmode & (addoverflow & !endOfCtr)) | (bctmode & (ctrWordIsZero & !endOfCtr))); + u32 tempD = -(cmpmode & (cmpnores & !begOfCtr)); + u32 tempE = -(addmode & (!addoverflow | endOfCtr)); + u32 tempF = -(bctmode & (!ctrWordIsZero | endOfCtr)); + tempA &= jword; + tempB &= *index; + tempC &= (*index + 1); + tempD &= (*index - 1); + tempE &= (wcnt - 1); + u32 indexval = (*addval + bitcntval); + u32 indexcmp = -(indexval < varcnt); + indexval = (indexval & indexcmp) | ((varcnt - 1) & (~indexcmp)); + tempF &= c->index[((indexval) * bctmode)]; + *index = tempA | tempB | tempC | tempD | tempE | tempF; + + tempA = -(submode & (subvalid & islvar)); + tempB = -(submode & (subvalid & !islvar)); + tempC = -(addmode & (addoverflow & !endOfCtr)); + tempD = -(bctmode & (ctrWordIsZero & !endOfCtr)); + tempA &= (jbit); + tempB &= (*addval + 1); + tempC &= (addoverflow); + tempD &= (*addval + 32); + *addval = tempA | tempB | tempC | tempD; + + /* completion detection */ + u8 issat = (submode & ((!subvalid) & islclause)); + u8 isusat = (cmpmode & ((!cmpnores) | begOfCtr)) & (!cmpisless); + u32 isdone = -(issat | isusat); + + /* exit */ + *mode |= isdone; + *index ^= (*index & isdone); + *index += issat; + + tempA = (addmode & (!addoverflow | endOfCtr)); + tempB = (cmpmode & (!cmpnores | begOfCtr)); + tempC = (bctmode & (!ctrWordIsZero | endOfCtr)); + tempD = (submode & (subvalid & islvar)); + *mode += tempA | tempB | tempC; + *mode -= (tempD * 3); +} +/** NOTES: + * First thing's first is indexing, let's get started with that. + * + * + * + * + * + */ + +void ctrthings2(cnf* c, u32* state, u32* ctr, const u32* max, FILE* iofile) { + // Pre-actual work setup stuff + u32 wcnt = 1U + (c->cnts[0] >> 5U); // number of words we need (bitcnt + 1 ceiling divided by 32) + u32* mode = state; // Make pointers to array values for convenience so I don't need to refer to everything via array indices + u32* index = state + 1; // tbh the names aren't much better + u32* addval = state + 2; + + u32 varcnt = c->cnts[0] - 1; // (varcnt - 1) - v gives us the bit index of variable v in the ctr + u32 chkmsk = 0xFFFFFFFFU * (*mode == (2)); // Create a mask that zeros anything if we aren't in subsumption-check mode, to prevent out of bounds accesses + u32 chkcls = *index & chkmsk; // Use said mask to zero the index if we're not in check mode + u32 chkind = c->clausedat[3 * chkcls] + (*addval & chkmsk); // potentially zeroed index to lookup clause start, add potentially zeroed variable index in that clause + u32 var = c->variables[chkind]; // Retrieve variable at that index + u8 par = c->parities[chkind]; // Retrieve the variable's CNF parity + u32 vword = (varcnt - var) >> 5U; // As mentioned in the varcnt line, compute word index of var in ctr + u32 vbit = (varcnt - var) & 0b11111U; // Compute bit index of var in above ctr word + u8 corpar = (ctr[vword] >> vbit) & 1U; // Extract correct parity + u8 isvalid = (par == corpar); // Check if CNF parity equals assignment parity from ctr + u8 islvar = ((*addval + 1) == c->clausedat[3 * chkcls + 1]); // Check if this is the last variable of the clause + u8 isbchk0 = (*mode == (2)); // If in check mode + u8 isbchk1 = isbchk0 & isvalid; // If in check mode & valid + u8 isbchk2 = isbchk1 & islvar; // If in check mode, valid, and on the last var + u32 j = c->clausedat[3 * chkcls + 2]; // retrieve j val for current clause + *mode -= 2 * isbchk2; // If isbchk2, move to add mode +#ifdef DEBUGSAT + if (isbchk2) fprintf(iofile,"j: %u\n", j); +#endif + *index = (j >> 5U) * isbchk2 + *index * (!isbchk2); // if isbchk2, set index for the jval to be added + *addval = (1U << (j & 0b11111U)) * isbchk2 + *addval * (!isbchk2); // if isbchk, set addval for the jval to be added; + *addval += ((isbchk1) & (!islvar)); // If in check mode, valid, and not on the last var, move to the next var + u8 isbchk3 = (isbchk0 & (!isvalid)); // if in check mode and not valid, move to the next clause + *addval *= (!isbchk3); // if moving to next clause, go to 0th var of that clause + *index += (isbchk3); // actually moving to the next clause + u8 issat = (*index == c->cnts[1]) * (isbchk3); // if we just passed the last clause, none were valid, so our assignment is SAT + + u32 cmpaddind = *index * (*mode != (2)); // Zero index into ctr if in check mode, to prevent out of bounds accesses + u32 nval = ctr[cmpaddind] + *addval; // Find the result of the current step if it was addition + *addval = (nval < ctr[cmpaddind]) * (*mode == (0)) + (*addval) * (*mode == (2)); // If in add mode, set addval to carry. If in cmp mode, set to 0. If in check mode, leave alone. + ctr[cmpaddind] = nval * ((*mode == (0)) & !issat) + ctr[cmpaddind] * ((*mode != (0)) | issat); // If in add mode, set new ctr val, otherwise leave unchanged + *addval -= (ctr[cmpaddind] < max[cmpaddind]) * (*mode == (1)); // If in comparison mode, decrement addval if less than + *addval += (ctr[cmpaddind] > max[cmpaddind]) * (*mode == (1)); // If in comparison mode, increment addval if greater than + u8 addcond = (*addval == 0) | (cmpaddind == (wcnt - 1)); // Exit condition for the ADD state: If addval is zero (no carry) or we're at the last word + u8 cmpcond = (*addval != 0) | (cmpaddind == 0); // Exit condition for the CMP state: if addval is nonzero (lt or gt) or we're at the least significant word + u8 exittime = (*mode == (1)) & cmpcond & (*addval != -1); + exittime |= issat; + if (exittime) { // If in cmpmode and the comparison result is not less than, unsat + printf("%u\n", issat); + *mode = 3; + return; + } + u8 cmpdone = cmpcond & (*mode == (1)); // if comparison completion conditions are satisfied and in CMP mode + u32 addindex = (cmpaddind + 1) * !addcond + (cmpaddind) * addcond; // if add completion is satisfied, set index to most significant word, else increment by 1 + *index = addindex * (*mode == (0)) + (*index - (*mode == (1))) * (*mode != (0)); // If in add mode, use addindex; if in cmp mode, decrement index by 1 + *index *= !cmpdone; + *addval *= !(((addcond) & (*mode == (0))) | cmpdone); // If add is complete, or cmp is complete, zero. Else leave unchanged. + *mode += addcond * (*mode == (0)) + cmpdone; // If in add mode and add completion is reached, increment mode. If in cmp mode and cmp completion reached, increment mode. +} + +i32 cpusolve(cnf* c) { + u32 state[3]; + u32 wcnt = 1 + (c->cnts[0] >> 5U); + + u32* ctr = calloc((wcnt), sizeof(u32)); + if (ctr == NULL) { + printf("Failed to allocate solution buffer\n"); + exit(1); + } + + u32* max = calloc((wcnt), sizeof(u32)); + if (max == NULL) { + printf("Failed to allocate solution buffer\n"); + exit(1); + } + + stateaddpow(wcnt, max, c->cnts[0]); + + state[0] = 2; + state[1] = state[2] = 0; + + while (state[0] < 4) { + ctrthings5(c, state, ctr, max); + } + + free(ctr); + free(max); + return state[1]; +} + +i32 compareFiles(FILE *file1, FILE *file2){ + char ch1 = getc(file1); + char ch2 = getc(file2); + int error = 0, pos = 0, line = 1; + while (ch1 != EOF && ch2 != EOF){ + pos++; + if (ch1 == '\n' && ch2 == '\n'){ + line++; + pos = 0; + } + if (ch1 != ch2){ + error++; + printf("Line Number : %d \tError" + " Position : %d \n", line, pos); + } + ch1 = getc(file1); + ch2 = getc(file2); + } + if (ch1 != EOF || ch2 != EOF) error += 1; + printf("Total Errors : %d\t\n", error); + return error; +} + +void debugsolve(char* path) { + + FILE* reffile = fopen("temp.txt", "w"); + FILE* cmpfile = fopen("temp2.txt", "w"); + + runcsfloc(path, reffile); + + cnf* c = readDIMACS(path); + sortlastnum(c); + + + u32* state = malloc(sizeof(u32) * 3); + u32 wcnt = 1 + (c->cnts[0] >> 5U); + + u32* ctr = calloc((wcnt), sizeof(u32)); + if (ctr == NULL) { + printf("Failed to allocate solution buffer\n"); + exit(1); + } + + u32* max = calloc((wcnt), sizeof(u32)); + if (max == NULL) { + printf("Failed to allocate solution buffer\n"); + exit(1); + } + + stateaddpow(wcnt, max, c->cnts[0]); + + state[0] = 2; + state[1] = state[2] = 0; + + while (state[0] < 4) { + ctrthings4(c, state, ctr, max, cmpfile); + } + + free(ctr); + free(max); + + freecnf(c); + + fclose(reffile); + fclose(cmpfile); + + reffile = fopen("temp.txt", "r"); + cmpfile = fopen("temp2.txt", "r"); + + i32 res = compareFiles(reffile, cmpfile); + if (res != 0) { + printf("%s\n", path); + exit(1); + } + fclose(reffile); + fclose(cmpfile); + free(state); + +} \ No newline at end of file diff --git a/cpusolver.h b/cpusolver.h index 1067199..78ed6d6 100644 --- a/cpusolver.h +++ b/cpusolver.h @@ -1,31 +1,9 @@ #pragma once #include "types.h" - -typedef struct { - u32 varcnt; - u32 clausecnt; - u32* clauseinds; - u32* clausevals; - - u32* watchlist; -} cpusolver; - -void cpusolve() { - cpusolver s; +#include "ncnf.h" -} -/* - * Read in DIMACs - * A clause is a list of variables and their assignments - * - * Create watchlists: - * 2 entries for each list - flag bit for if its a binary clause - * - * 2-watch: - * If one of them is marked false, remove and replace - * if there are no other literals in the clause that are true, unitprop the last literal - * if there are no literals that evaluate to true and the other watch is false, UNSAT - * if a literal is true, do nothing - * if 2 literals are unassigned, clause is unsat - */ \ No newline at end of file + +i32 cpusolve(cnf* c); + +void debugsolve(char* path); \ No newline at end of file diff --git a/gpusolver.c b/gpusolver.c index 1a84f01..831f89b 100644 --- a/gpusolver.c +++ b/gpusolver.c @@ -152,6 +152,7 @@ i32 gpusolve2(gpusolver* gs, cnf* c) { exit(1); } + gs->gpuCUs = 1024; mpz_t gmpmax; mpz_init(gmpmax); mpz_ui_pow_ui(gmpmax, 2, c->cnts[0]); @@ -161,6 +162,8 @@ i32 gpusolve2(gpusolver* gs, cnf* c) { // printf("\n\n"); mpz_clear(gmpmax); + solution[0] = 0; + cl_int res = 2; cl_mem gpuheader = clCreateBuffer(gs->ctx, CL_MEM_READ_ONLY, 2 * sizeof(cl_uint), NULL, &res); @@ -233,7 +236,7 @@ i32 gpusolve2(gpusolver* gs, cnf* c) { res = clSetKernelArg(gs->kernel, 5, sizeof(cl_mem), (void*) &gpuscratchpad); - size_t deploySize[2] = { gs->gpuCUs, 1 }; + size_t deploySize[2] = { gs->gpuCUs, 64 }; res = clEnqueueNDRangeKernel(gs->commqueue, gs->kernel, 1, NULL, &(gs->gpuCUs), &(gs->gpuCUs), 0, NULL, NULL); if (res != CL_SUCCESS) { printf("Failed to queue kernel for execution\n"); diff --git a/main.c b/main.c index 7c9567c..a97c575 100644 --- a/main.c +++ b/main.c @@ -2,67 +2,13 @@ #include "gpusolver.h" #include "time.h" #include "tests/masterTest.h" + #include "gmp.h" #include "rng.h" #include "ncnf.h" - -#define ADD (0) -#define CMP (1) -#define CHK (2) +#include "cpusolver.h" -void ctrthings2(cnf* c, u32* state, u32* ctr, u32* max) { - u32 wcnt = 1U + (c->cnts[0] >> 5U); - u32* mode = state; - u32* index = state + 1; - u32* addval = state + 2; - - u32 varcnt = c->cnts[0] - 1; - u32 chkmsk = 0xFFFFFFFFU * (*mode == CHK); - u32 chkcls = *index & chkmsk; - u32 chkind = c->clausedat[3 * chkcls] + (*addval & chkmsk); - u32 var = c->variables[chkind]; - u8 par = c->parities[chkind]; - u32 vword = (varcnt - var) >> 5U; - u32 vbit = (varcnt - var) & 0b11111U; - u8 corpar = (ctr[vword] >> vbit) & 1U; - u8 isvalid = (par == corpar); - u8 islvar = ((*addval + 1) == c->clausedat[3 * chkcls + 1]); - u8 isbchk0 = (*mode == CHK); - u8 isbchk1 = isbchk0 & isvalid; - u8 isbchk2 = isbchk1 & islvar; - u32 j = c->clausedat[3 * chkcls + 2]; - *mode -= 2 * isbchk2; - *index = (j >> 5U) * isbchk2 + *index * (!isbchk2); - *addval = (1U << (j & 0b11111U)) * isbchk2 + *addval * (!isbchk2); - *addval += ((isbchk1) & (!islvar)); - u8 isbchk3 = (isbchk0 & (!isvalid)); - *addval *= (!isbchk3); - *index += (isbchk3); - u8 issat = (*index == c->cnts[1]) * (isbchk3); - - u32 cmpaddind = *index * (*mode != CHK); - u32 nval = ctr[cmpaddind] + *addval; // Find the result of the current step if it was addition - *addval = (nval < ctr[cmpaddind]) * (*mode == ADD) + (*addval) * (*mode == CHK); // If in add mode, set addval to carry. If in cmp mode, set to 0. If in check mode, leave alone. - ctr[cmpaddind] = nval * ((*mode == ADD) & !issat) + ctr[cmpaddind] * ((*mode != ADD) | issat); // If in add mode, set new ctr val, otherwise leave unchanged - *addval -= (ctr[cmpaddind] < max[cmpaddind]) * (*mode == CMP); // If in comparison mode, decrement addval if less than - *addval += (ctr[cmpaddind] > max[cmpaddind]) * (*mode == CMP); // If in comparison mode, increment addval if greater than - u8 addcond = (*addval == 0) | (cmpaddind == (wcnt - 1)); // Exit condition for the ADD state: If addval is zero (no carry) or we're at the last word - u8 cmpcond = (*addval != 0) | (cmpaddind == 0); // Exit condition for the CMP state: if addval is nonzero (lt or gt) or we're at the least significant word - u8 exittime = (*mode == CMP) & cmpcond & (*addval != -1); - exittime |= issat; - if (exittime) { // If in cmpmode and the comparison result is not less than, unsat - printf("Result: %u\n", issat); - *mode = 4; - return; - } - u8 cmpdone = cmpcond & (*mode == CMP); // if comparison completion conditions are satisfied and in CMP mode - u32 addindex = (cmpaddind + 1) * !addcond + (wcnt - 1) * addcond; // if add completion is satisfied, set index to most significant word, else increment by 1 - *index = addindex * (*mode == ADD) + (*index - (*mode == CMP)) * (*mode != ADD); // If in add mode, use addindex; if in cmp mode, decrement index by 1 - *index *= !cmpdone; - *addval *= !(((addcond) & (*mode == ADD)) | cmpdone); // If add is complete, or cmp is complete, zero. Else leave unchanged. - *mode += addcond * (*mode == ADD) + cmpdone; // If in add mode and add completion is reached, increment mode. If in cmp mode and cmp completion reached, increment mode. -} void printbits(unsigned a) { for (unsigned i = 0; i < 32; ++i) { @@ -105,123 +51,41 @@ void mul(u32* c, u32 len, u32* a, u32 b) { } } +i32 runuf20lol() { + u32 passed = 0; + u64 tottime = 0; + for (u32 i = 0; i < 1000; ++i) { + char buf[128]; + i32 len = sprintf(buf, "/home/lev/Downloads/uf20/uf20-0%u.cnf", i + 1); + + debugsolve(buf); + + passed++; + } + return 1; +} + +i32 runuf50lol() { + u32 passed = 0; + u64 tottime = 0; + for (u32 i = 0; i < 1000; ++i) { + char buf[128]; + i32 len = sprintf(buf, "/home/lev/Downloads/uf50/uf50-0%u.cnf", i + 1); + + debugsolve(buf); + + passed++; + } + return 1; +} + int main() { - /* - printf("Tests: %lu\n", TESTS); - rngstate rng; - u64 rseed = utime(); - seed(&rng, rseed); - printf("Seed: %lu\n", rseed); - - mpz_t a, b, c, d, e; - - mpz_inits(a, b, c, d, e, NULL); - - u32 ctrp[CSZE]; - u32 maxp[CSZE]; - - u32* ctr = ctrp; - u32* max = maxp; - - u32 state[3]; - u32 hdr[2]; - - char buf[4096]; - - - for (u64 i = 0; i < TESTS; ++i) { - memset(ctr, 0, sizeof(u32) * CSZE); - memset(max, 0, sizeof(u32) * CSZE); - - //u32 lenval = ru32(&rng) % CSZE; - u32 lenval = CSZE; - if (lenval < 2) lenval = 2; - u32 jval = ru32(&rng) % ((lenval - 1) * 32); - - for (u32 j = 0; j < lenval - 2; ++j) { - ctr[j] = ru32(&rng); - max[j] = ru32(&rng); - } - - mpz_import(a, lenval, -1, sizeof(u32), 0, 0, ctr); - mpz_import(b, lenval, -1, sizeof(u32), 0, 0, max); - - state[1] = jval >> 5U; - state[2] = jval & 0b11111U; - state[2] = 1U << state[2]; - mpz_ui_pow_ui(c, 2, jval); - - - if (rf32(&rng) < eqprob && mpz_cmp) { - mpz_sub(a, b, c); - if (mpz_sgn(a) != -1) { - mpz_export(ctr, NULL, -1, sizeof(u32), 0, 0, a); - } else { - mpz_import(a, lenval, -1, sizeof(u32), 0, 0, ctr); - } - } - - mpz_add(d, a, c); - - state[0] = 0; - - while (state[0] < 2U) { - ctrthings2(hdr, lenval, state, ctr, max); - } - - mpz_import(c, lenval, -1, sizeof(u32), 0, 0, ctr); - - - - i32 res = mpz_cmp(d, b); - - if (res == -1) { - if (state[2] != -1) { - printf("Fuck2 %lu\n", i); - printf("d: "); - mpz_out_str(stdout, 10, d); - - printf("\nb: "); - mpz_out_str(stdout, 10, b); - printf("\n"); - printf("mode: %u\n", state[2]); - exit(0); - } - } else { - if (state[2] != 2) { - printf("Fuck3 %lu\n", i); - printf("d: "); - mpz_out_str(stdout, 10, d); - printf("\nc: "); - mpz_out_str(stdout, 10, c); - printf("\nb: "); - mpz_out_str(stdout, 10, b); - printf("\n"); - printf("mode: %u\n", state[2]); - exit(0); - } - } - - res = mpz_cmp(d, c); - if (res != 0) { - printf("Fuck %lu\na: ", i); - mpz_out_str(stdout, 10, a); - printf("\nd: "); - mpz_out_str(stdout, 10, d); - printf("\nc: "); - mpz_out_str(stdout, 10, c); - printf("\n%u %u\n", lenval, jval); - exit(0); - } - } - - - - mpz_clears(a, b, c, d, e, NULL); - */ + // debugsolve("/home/lev/Downloads/uf50/uf50-01.cnf"); + // runuf20lol(); + runTests(); /* srand( utime()); @@ -271,8 +135,8 @@ int main() { return 0; */ - runTests(); - return 0; + // runTests(); + // return 0; /* // printf("%u\n", c->litcnt); diff --git a/ncnf.c b/ncnf.c index e9e8f05..e69abe0 100644 --- a/ncnf.c +++ b/ncnf.c @@ -60,6 +60,8 @@ cnf* readDIMACS(char* path) { c->clausedat = calloc(*clausecnt, sizeof(u32) * 3); CHECK(c->clausedat, "Failed to allocate clause data\n") + c->index = calloc(*varcnt, sizeof(u32)); + CHECK(c->index, "Failed to allocate clause index\n") c->variables = calloc(cap, sizeof(u32)); CHECK(c->variables, "Failed to allocate literal variables\n") c->parities = calloc(cap, sizeof(u8)); @@ -208,4 +210,29 @@ void sortlastnum(cnf* c) { } free(d); + + // TODO: Rewrite, the following is copied from Gábor Kusper's implementation + u32 lastNum = 0; + for (u32 i = c->cnts[1] - 1; i < c->cnts[1]; --i) { + if (c->clausedat[3 * i + 2] > lastNum) { + while (lastNum < c->clausedat[3 * i + 2]) { + c->index[lastNum] = i + 1; + lastNum++; + } + } + } + + + u32 corrInd = 0; + while (c->index[corrInd] == c->cnts[1]) { + corrInd++; + } + + if (corrInd != 0) { + u32 goodVal = c->index[corrInd]; + do { + corrInd--; + c->index[corrInd] = goodVal; + } while (corrInd != 0); + } } diff --git a/ncnf.h b/ncnf.h index 6e84f26..6b4fc04 100644 --- a/ncnf.h +++ b/ncnf.h @@ -12,6 +12,7 @@ typedef struct { u32 cnts[3]; // { varcnt, clausecnt } u32* clausedat; // { ind, len, jval } + u32* index; u32* variables; u8* parities; } cnf; @@ -23,3 +24,7 @@ void printcnf(cnf* c); void sortlastnum(cnf* c); void freecnf(cnf* c); + +/* -mavx2 -O3 -ftree-loop-linear -ftree-loop-im -ftree-loop-ivcanon -fivopts -ftree-vectorize -ftracer -funroll-all-loops + * + */ \ No newline at end of file diff --git a/psat.cl b/psat.cl index 9305baa..d71b4b5 100644 --- a/psat.cl +++ b/psat.cl @@ -62,8 +62,6 @@ void mul(uint* c, uint len, uint* a, uint b) { } __kernel void vectorSAT(__global const uint* cnfhdr, __global const uint* clausedat, __global const uint* vars, __global const uchar* pars, __global uint* output, __global uint* lctrs) { - uint locid = get_local_id(0); - uint locsz = get_local_size(0); // uint grpid = get_group_id(0); // uint grpcn = get_num_groups(0); uint globid = get_global_id(0); @@ -75,10 +73,7 @@ __kernel void vectorSAT(__global const uint* cnfhdr, __global const uint* clause uint index = 0; uint addval = 0; - - - output[0] = 0; - + //uint ctroff = uint* ctr = lctrs + wcnt * 2 * globid; uint* max = lctrs + wcnt * (2 * globid + 1); @@ -92,7 +87,8 @@ __kernel void vectorSAT(__global const uint* cnfhdr, __global const uint* clause } else { mul(max, wcnt, output + 1, globid + 1); } - // printf("%u %u\n", ctr[0], max[0]); + //printf("%u %u\n", ctr[0], max[0]); + //printf("%u %u\n", wcnt * 2 * globid, wcnt * (2 * globid + 1)); barrier(CLK_LOCAL_MEM_FENCE | CLK_GLOBAL_MEM_FENCE); @@ -115,7 +111,7 @@ __kernel void vectorSAT(__global const uint* cnfhdr, __global const uint* clause uchar isbchk2 = isbchk1 & islvar; uint j = clausedat[3 * chkcls + 2]; mode -= 2 * isbchk2; - // if (isbchk2) printf("j: %u\n", j); + //if (isbchk2 && globid == 0) printf("%u j: %u\n", globid, j); index = (j >> 5U) * isbchk2 + index * (!isbchk2); addval = (1U << (j & 0b11111U)) * isbchk2 + addval * (!isbchk2); addval += ((isbchk1) & (!islvar)); @@ -142,8 +138,10 @@ __kernel void vectorSAT(__global const uint* cnfhdr, __global const uint* clause } output[0] = 1; } + return; } - return; + //if (globid == 0) printf("fuck %u %u\n", ctr[0], max[0]); + break; } uchar cmpdone = cmpcond & (mode == 1); // if comparison completion conditions are satisfied and in CMP mode uint addindex = (cmpaddind + 1) * !addcond + (wcnt - 1) * addcond; // if add completion is satisfied, set index to most significant word, else increment by 1 diff --git a/tests/masterTest.c b/tests/masterTest.c index c079b1e..d7bb2ca 100644 --- a/tests/masterTest.c +++ b/tests/masterTest.c @@ -3,6 +3,7 @@ #include "../cnf.h" #include "../gpusolver.h" #include "../time.h" +#include "../cpusolver.h" i32 runTests() { i32 res = runuf20(); @@ -24,8 +25,7 @@ i32 runTests() { i32 runuf20() { - gpusolver* gs = initSolver(); - // printf("Running against uf20\n"); + printf("Running against uf20\n"); u32 passed = 0; u64 tottime = 0; for (u32 i = 0; i < 1000; ++i) { @@ -37,23 +37,21 @@ i32 runuf20() { sortlastnum(c); u64 start = utime(); - i32 res = gpusolve2(gs, c); + i32 res = cpusolve(c); u64 stop = utime(); tottime += (stop - start); freecnf(c); if (res == 1) passed++; } - // printf("Passed %u / 1000 tests\n", passed); - // printf("Took %lf s total, %lf s on avg\n", ((f64) tottime) / 1000000.0, ((f64) tottime) / 1000000000.0); + printf("Passed %u / 1000 tests\n", passed); + printf("Took %lf s total, %lf s on avg\n", ((f64) tottime) / 1000000.0, ((f64) tottime) / 1000000000.0); if (passed == 1000) return 0; - freeSolver(gs); return 1; } i32 runuf50() { - gpusolver* gs = initSolver(); - // printf("Running against uf50\n"); + printf("Running against uf50\n"); u32 passed = 0; u64 tottime = 0; for (u32 i = 0; i < 1000; ++i) { @@ -65,23 +63,21 @@ i32 runuf50() { sortlastnum(c); u64 start = utime(); - i32 res = gpusolve2(gs, c); + i32 res = cpusolve(c); u64 stop = utime(); tottime += (stop - start); freecnf(c); if (res == 1) passed++; } - // printf("Passed %u / 1000 tests\n", passed); - // printf("Took %lf s total, %lf s on avg\n", ((f64) tottime) / 1000000.0, ((f64) tottime) / 1000000000.0); + printf("Passed %u / 1000 tests\n", passed); + printf("Took %lf s total, %lf s on avg\n", ((f64) tottime) / 1000000.0, ((f64) tottime) / 1000000000.0); if (passed == 1000) return 0; - freeSolver(gs); return 1; } i32 runuuf50() { - gpusolver* gs = initSolver(); - // printf("Running against uuf50\n"); + printf("Running against uuf50\n"); u32 passed = 0; u64 tottime = 0; for (u32 i = 0; i < 1000; ++i) { @@ -93,16 +89,15 @@ i32 runuuf50() { sortlastnum(c); u64 start = utime(); - i32 res = gpusolve2(gs, c); + i32 res = cpusolve(c); u64 stop = utime(); tottime += (stop - start); freecnf(c); if (res == 0) passed++; } - // printf("Passed %u / 1000 tests\n", passed); - // printf("Took %lf s total, %lf s on avg\n", ((f64) tottime) / 1000000.0, ((f64) tottime) / 1000000000.0); + printf("Passed %u / 1000 tests\n", passed); + printf("Took %lf s total, %lf s on avg\n", ((f64) tottime) / 1000000.0, ((f64) tottime) / 1000000000.0); if (passed == 1000) return 0; - freeSolver(gs); return 1; } \ No newline at end of file