分析
先简单介绍一下阶。设 $\gcd(a,p)=1$,则满足 $a^x\equiv 1\pmod p$ 的最小正整数 $x$ 被称为 $a$ 在模 $p$ 意义下的阶,记做 $\operatorname{ord}_pa=x$
阶有这样一些性质(这里没有证明,可以自行查找相关资料):
- $\operatorname{ord}_p a|\varphi(p)$。于是我们可以枚举 $\varphi(p)$ 的质因子,如果能除掉就除掉,从而求出 $\operatorname{ord}_p a$。
- 如果 $a\equiv b\pmod p$,那么 $\operatorname{ord}_p a=\operatorname{ord}_p b$。
- $\operatorname{ord}_p a^k=\frac{\operatorname{ord}_p a}{\gcd(\operatorname{ord}_p a,k)}$。
现在考虑这道题。如果 $a_i\!^k=a_j$,我们从 $a_i$ 向 $a_j$ 连一条边。那么每个数被选择当且仅当能到达它的所有数都未被选择。设有 $cnt$ 个数能到达它,则它的贡献为 $2^{n-cnt-1}$。
把所有数分成两种考虑。对于 $\gcd(a_i,p)\neq 1$ 的数,我们可以暴力计算它能到达哪些点(显然它只会到达 $\gcd(a_j,p)\neq 1$ 的数),复杂度 $\mathcal{O}(n\log p)$。对于 $\gcd(a_i,p)=1$ 的数,可以发现 $a_i$ 能到达 $a_j$ 当且仅当 $\operatorname{ord}_p a_j|\operatorname{ord}_p a_i$(由上面的性质不难推出),所以我们 $\mathcal{O}(n^2)$ 算一下就好了。
然而可能会有很多阶相同的数也就是构成了一个强连通分量,可以缩成一个点,将其贡献乘上 $2^{sz}-1$ 即可。
代码
// ====================================
// author: M_sea
// website: https://m-sea-blog.com/
// ====================================
#include <bits/stdc++.h>
#define file(x) freopen(#x".in","r",stdin); freopen(#x".out","w",stdout)
using namespace std;
typedef long long ll;
int read() {
int X=0,w=1; char c=getchar();
while (c<'0'||c>'9') { if (c=='-') w=-1; c=getchar(); }
while (c>='0'&&c<='9') X=X*10+c-'0',c=getchar();
return X*w;
}
const int N=5000+10;
const int mod=998244353;
int qpow(int a,int b,int p) { int c=1;
for (;b;b>>=1,a=1ll*a*a%p) if (b&1) c=1ll*c*a%p;
return c;
}
int n,p,phi,pw[N];
vector<int> d;
int v1[N],sz1[N],c1[N],s1=0,v2[N],c2[N],s2=0;
map<int,int> id1,id2;
void calc() {
int q=p;
for (int i=2;i*i<=p;++i)
if (p%i==0) { q=i; break; }
phi=p-p/q; int x=phi;
for (int i=2;i*i<=x;++i) {
if (x%i) continue;
d.emplace_back(i);
while (x%i==0) x/=i;
}
if (x!=1) d.emplace_back(x);
}
int getord(int a) {
int x=phi;
for (int i:d)
while (x%i==0&&qpow(a,x/i,p)==1) x/=i;
return x;
}
int main() {
n=read(),p=read(); calc();
pw[0]=1;
for (int i=1;i<=n;++i) pw[i]=2ll*pw[i-1]%mod;
for (int i=1;i<=n;++i) {
int a=read();
if (__gcd(a,p)!=1) v2[++s2]=a,id2[a]=s2;
else {
int x=getord(a);
if (!id1.count(x)) v1[++s1]=x,id1[x]=s1,sz1[s1]=1;
else ++sz1[id1[x]];
}
}
for (int i=1;i<=s1;++i) c1[i]=n-sz1[i];
for (int i=1;i<=s1;++i)
for (int j=1;j<=s1;++j)
if (i!=j&&v1[j]%v1[i]==0) c1[i]-=sz1[j];
for (int i=1;i<=s2;++i) c2[i]=n-1;
for (int i=1;i<=s2;++i)
for (int j=1ll*v2[i]*v2[i]%p;j;j=1ll*j*v2[i]%p)
if (id2.count(j)) --c2[id2[j]];
int ans=0;
for (int i=1;i<=s1;++i) ans=(ans+1ll*(pw[sz1[i]]-1)*pw[c1[i]])%mod;
for (int i=1;i<=s2;++i) ans=(ans+pw[c2[i]])%mod;
printf("%d\n",ans);
return 0;
}