/* foop[i] >>= 1; */
#include "qs.h"
extern char *malloc();
/*
 * Bit-vector helpers.  v is indexed in words that are treated as 32 bits
 * wide: bit n lives in word n >> 5 at position n & 31.
 *	ison(n, v)	non-zero when bit n is set
 *	on(n, v)	set bit n
 *	off(n, v)	clear bit n
 * The mask is built from an unsigned 1 so that selecting bit 31 never
 * shifts into the sign bit of an int, and v is parenthesized so that any
 * pointer expression (not just a plain name) may be passed.
 */
#define ison(n, v) ((v)[(n) >> 5] & ((unsigned)1 << ((n) & 31)))
#define on(n, v)   ((v)[(n) >> 5] |= ((unsigned)1 << ((n) & 31)))
#define off(n, v)  ((v)[(n) >> 5] &= ~((unsigned)1 << ((n) & 31)))

/* Copied into accumulators with move(one, ...) before a running product is
 * built; presumably the mint constant 1 -- it is set up outside this view. */
mint *one;

/* Table of stored rows; svvec() fills slot nvecs and then advances nvecs. */
vec *vecs;	/* a few extra ones to make sure */
/* nvecs: number of rows already stored (also the slot currently being filled).
 * vsize: number of words in each row's rval bit vector (see elim()). */
int nvecs, vsize;
/* Not referenced anywhere in this part of the file. */
unsigned int *matrix;
/* Exponent sums indexed by column number (0 .. np-1); cleared by doit() and
 * checkrow(), accumulated by refactor(). */
int *foop;

/* pivot[i] is the index in vecs of the row that owns column i, or -1. */
int *pivot;

/*
 * svvec -- store a new relation in vecs[nvecs] and run it through lu().
 *
 *	n	number of entries in a[] and b[]
 *	a	column numbers (later used as bit indices into rval and as
 *		indices into foop by refactor())
 *	b	exponent belonging to each entry of a[]; only entries with an
 *		odd exponent turn a bit on in rval
 *	q	copied into the row's aval
 *	orig	copied into the row's qval
 *
 * Bits in rval are only ever turned on here, so the row's rval is
 * presumably handed over already cleared -- TODO confirm with whoever
 * allocates vecs.
 *
 * Exits the program when the factor lists cannot be allocated, or once the
 * table has reached np + 100 rows.
 */
svvec(n, a, b, q, orig)
int *a, *b;
mint *q, *orig;
{	int i, p;
	for(i = 0; i < n; i++) {
		if((b[i] & 1) == 0)
			continue;
		p = a[i];
		on(p, vecs[nvecs].rval);
	}
	vecs[nvecs].iam = -1;	/* owns no pivot column yet */
	vecs[nvecs].aval = itom(0);
	move(q, vecs[nvecs].aval);
	vecs[nvecs].qval = itom(0);
	move(orig, vecs[nvecs].qval);
	vecs[nvecs].ps = (int *)malloc((n + 2) * sizeof(int));
	vecs[nvecs].exps = (short *)malloc((n + 2) * sizeof(short));
	/* the copies below write through both pointers, so stop on failure */
	if(vecs[nvecs].ps == 0 || vecs[nvecs].exps == 0) {
		printf("svvec: out of memory\n");
		exit(1);
	}
	vecs[nvecs].vlen = n;
	for(i = 0; i < n; i++) {
		vecs[nvecs].ps[i] = a[i];
		vecs[nvecs].exps[i] = b[i];
	}
#ifdef DEBUG
	checkvec(nvecs);
#endif
	lu();	/* works on vecs[nvecs], so nvecs is bumped only afterwards */
	nvecs++;
	if(nvecs >= np + 100) {
		printf("hey, tried 100 extra times?\n");
		exit(0);
	}
}

/* lu decomposition, maybe */
lu()
{	int i;
	/* too much work */
	for(i = 0; i < nvecs; i++) {
		if(vecs[i].iam < 0)
			continue;
		if(!ison(vecs[i].iam, vecs[nvecs].rval))
			continue;
		elim(vecs[i].iam, i, nvecs);
	}
	for(i = np - 1; i >= 0; i--)
		if(pivot[i] == -1 && ison(i, vecs[nvecs].rval)) {
			pivot[i] = nvecs;
			vecs[nvecs].iam = i;
			break;
		}
#ifdef DEBUG
	checkrow(nvecs);
#endif
	if(vecs[nvecs].iam != -1)
		return;
	for(i = 0; i < np; i++)
		if(ison(i, vecs[nvecs].rval) && pivot[i] == -1)
			return;
	doit(nvecs);
}

/*
 * elim -- XOR row a into row b, word by word over vsize words.
 *
 *	piv	the column row a is the pivot for
 *	a, b	indices into vecs
 *
 * lu() calls this only when column piv is set in row b, so the XOR is
 * expected to clear that bit; if it is still set the program aborts.
 * The bit is then turned back on: doit() and checkrow() treat a set bit i
 * as "row pivot[i] takes part in this combination".
 */
elim(piv, a, b)	/* add row a to row b */
{	int i;
	for(i = 0; i < vsize; i++)
		vecs[b].rval[i] ^= vecs[a].rval[i];
	if(ison(piv, vecs[b].rval)) {
		/* abort() takes no argument, so print the diagnostic ourselves */
		printf("elim: column %d still set after adding row %d to row %d\n",
			piv, a, b);
		fflush(stdout);
		abort();
	}
	on(piv, vecs[b].rval);
}


/* Scratch numbers for doit(); allocated on its first call. */
mint *xa, *xb, *xc, *xd, *xe, *xbb;

/*
 * doit -- row n has no set column without a pivot row, so try to split num.
 *
 * xe becomes the product (mod num) of the aval of row n and of every pivot
 * row flagged in row n's rval, while refactor() adds the matching exponents
 * into foop[].  Every sum must be even (otherwise the program aborts); the
 * sums are halved and xd is built as the product of mintp[i]^foop[i]
 * (mod num), leaving index 0 out.  gcd(xe - xd, num) is printed.  When it
 * is neither num nor the one-word value 1, it and its cofactor are handed
 * to newfactor().  When it is 1, (xe + xd) mod num is expected to be zero;
 * if it is not, a dump is printed and the program aborts.
 *
 * NOTE(review): mdiv(a, b, q, r) is taken to leave the remainder in r, as
 * its uses in this file suggest -- confirm against the mint library.
 */
doit(n)	/* row n is now a winner, we hope */
{	static flg = 0;
	int i, j;
#ifdef DEBUG
	checkrow(n);
#endif
	if(flg++ == 0) {
		xa = itom(1), xb = itom(0), xc = itom(0), xd = itom(0);
		xe = itom(0), xbb = itom(0);
	}
	for(i = 0; i < np; i++)
		foop[i] = 0;
	/* xa stays 1; use it to reset the accumulators */
	move(xa, xb);
	move(xa, xe);
	move(xa, xbb);
	for(i = 0; i < np; i++)
		if(ison(i, vecs[n].rval)) {
			refactor(pivot[i]);
			mult(xe, vecs[pivot[i]].aval, xe);
			mdiv(xe, num, xc, xe);
		}
	refactor(n);
	mult(xe, vecs[n].aval, xe);
	mdiv(xe, num, xc, xe);
	move(one, xd);
	for(i = 0; i < np; i++) {
		if(foop[i] & 1) {
			printf("hey, not a square foop[%d] = %d\n", i, foop[i]);
			fflush(stdout);
			abort();
		}
		foop[i] >>= 1;	/* exponent of the square root */
	}
	for(i = 1; i < np; i++)	/* skip -1 */
		if(foop[i] > 0) {
			for(j = 0; j < foop[i]; j++)
				mult(xd, mintp[i], xd);
			mdiv(xd, num, xc, xd);
		}
	msub(xe, xd, xb);
	gcd(xb, num, xc);
	printf("gcd(%d) ", nvecs), mout(xc);
	if(mcmp(xc, num) != 0 && (xc->len != 1 || xc->val[0] != 1)) {
		newfactor(xc);
		mdiv(num, xc, xb, xe);	/* xb = num / xc, the cofactor */
		newfactor(xb);
		return;
	}
	if(mcmp(xc, num) == 0)
		return;
	/* gcd was 1: cross-check using xe + xd */
	madd(xe, xd, xb);
	mult(xe, xe, xe);
	mult(xd, xd, xd);
	msub(xe, xd, xbb);
	mdiv(xbb, num, xc, xbb);
	mdiv(xb, num, xc, xb);
	if(xb->len != 0) {
		for(i = 0; i < (np<100?np:100); i++)
			if(foop[i] != 0)
				printf("p %d e %d\n", goodp[i], foop[i]);
		printf("...\n");
#ifdef DEBUG
		/* checkrow() is only compiled into DEBUG builds */
		for(i = 0; i <= n; i++)
			checkrow(i);
#endif
		abort();
	}
}

refactor(n)
{	int i;
	if(n < 0)
		printf("refactor(%d)?\n", n);
	for(i = 0; i < vecs[n].vlen; i++) {
		if(vecs[n].ps[i] >= np || vecs[n].ps[i] < 0)
			printf("oops n %d vlen %d ps[%d] %d\n", n, vecs[n].vlen, i, vecs[n].ps[i]);
		foop[vecs[n].ps[i]] += vecs[n].exps[i];
	}
}

#ifdef DEBUG
/*
 * checkrow -- DEBUG consistency check for row n.
 *
 * Three values are computed mod num and must agree:
 *	lhs	square of the product of the aval of row n and of every
 *		other pivot row flagged in its rval
 *	rhs	product of the matching qval values (num is added if the
 *		result is negative)
 *	arhs	product of mintp[i]^foop[i] for i >= 1, with foop[] summed
 *		by refactor() over the same rows, replaced by num - arhs
 *		when foop[0] is odd
 * On a mismatch the three values are printed and the program exits.
 * Overwrites foop[].
 */
checkrow(n)
{	static mint *lhs, *rhs, *arhs, *scratch;
	static flg = 0;
	vec *v = vecs + n;
	int col, rep;

	if(!flg++) {
		arhs = itom(0);
		rhs = itom(0);
		scratch = itom(0);
		lhs = itom(0);
	}
	for(col = 0; col < np; col++)
		foop[col] = 0;

	/* left side: product of the aval values, then squared */
	move(v->aval, lhs);
	for(col = 0; col < np; col++) {
		if(!ison(col, v->rval) || pivot[col] < 0 || pivot[col] == n)
			continue;
		mult(lhs, vecs[pivot[col]].aval, lhs);
		mdiv(lhs, num, scratch, lhs);
	}
	mult(lhs, lhs, lhs);
	mdiv(lhs, num, scratch, lhs);

	/* right side: product of the qval values, collecting exponents */
	move(v->qval, rhs);
	refactor(n);
	for(col = 0; col < np; col++) {
		if(!ison(col, v->rval) || pivot[col] < 0 || pivot[col] == n)
			continue;
		mult(rhs, vecs[pivot[col]].qval, rhs);
		mdiv(rhs, num, scratch, rhs);
		refactor(pivot[col]);
	}
	if(rhs->len < 0)
		madd(rhs, num, rhs);

	/* the same right side rebuilt from the collected exponents */
	move(one, arhs);
	for(col = 1; col < np; col++) {
		if(foop[col] <= 0)
			continue;
		for(rep = foop[col]; rep > 0; rep--)
			mult(arhs, mintp[col], arhs);
		mdiv(arhs, num, scratch, arhs);
	}
	if(foop[0] & 1)
		msub(num, arhs, arhs);

	if(mcmp(lhs, rhs) != 0 || mcmp(rhs, arhs) != 0) {
		printf("row %d failed!\n", n);
		printf("lhs ");
		mout(lhs);
		printf("rhs ");
		mout(rhs);
		printf("arhs ");
		mout(arhs);
		exit(1);
	}
}

/*
 * checkvec -- DEBUG check of a freshly stored row n.
 *
 * First, aval squared and qval must leave the same remainder on division
 * by num.  Second, the product of mintp[ps[i]]^exps[i] over the row's
 * entries, reduced mod num after every factor, must equal qval (with num
 * added when qval is negative).  Either failure prints a message and exits.
 */
checkvec(n)
{	static mint *ta, *tb, *tc;
	static flg = 0;
	vec *v = vecs + n;
	int k, rep;

	if(flg++ == 0) {
		ta = itom(0);
		tb = itom(0);
		tc = itom(0);
	}

	/* aval * aval against qval, both reduced by num */
	mult(v->aval, v->aval, ta);
	mdiv(ta, num, ta, tb);
	mdiv(v->qval, num, ta, tc);
	if(mcmp(tb, tc) != 0) {
		printf("a * a != q mod num\n");
		exit(1);
	}

	/* rebuild qval from the stored factor list */
	move(one, ta);
	for(k = 0; k < v->vlen; k++)
		for(rep = v->exps[k]; rep > 0; rep--) {
			mult(ta, mintp[v->ps[k]], ta);
			mdiv(ta, num, tb, ta);
		}
	move(v->qval, tb);
	if(tb->len < 0)
		madd(tb, num, tb);
	if(mcmp(ta, tb) != 0) {
		printf("vec not q! ");
		mout(ta);
		mout(v->qval);
		exit(1);
	}
}
#endif
