ecp.cpp

00001 // ecp.cpp - written and placed in the public domain by Wei Dai
00002 
00003 #include "pch.h"
00004 
00005 #ifndef CRYPTOPP_IMPORTS
00006 
00007 #include "ecp.h"
00008 #include "asn.h"
00009 #include "nbtheory.h"
00010 
00011 #include "algebra.cpp"
00012 
00013 NAMESPACE_BEGIN(CryptoPP)
00014 
00015 ANONYMOUS_NAMESPACE_BEGIN
00016 static inline ECP::Point ToMontgomery(const ModularArithmetic &mr, const ECP::Point &P)
00017 {
00018         return P.identity ? P : ECP::Point(mr.ConvertIn(P.x), mr.ConvertIn(P.y));
00019 }
00020 
00021 static inline ECP::Point FromMontgomery(const ModularArithmetic &mr, const ECP::Point &P)
00022 {
00023         return P.identity ? P : ECP::Point(mr.ConvertOut(P.x), mr.ConvertOut(P.y));
00024 }
00025 NAMESPACE_END
00026 
00027 ECP::ECP(const ECP &ecp, bool convertToMontgomeryRepresentation)
00028 {
00029         if (convertToMontgomeryRepresentation && !ecp.GetField().IsMontgomeryRepresentation())
00030         {
00031                 m_fieldPtr.reset(new MontgomeryRepresentation(ecp.GetField().GetModulus()));
00032                 m_a = GetField().ConvertIn(ecp.m_a);
00033                 m_b = GetField().ConvertIn(ecp.m_b);
00034         }
00035         else
00036                 operator=(ecp);
00037 }
00038 
00039 ECP::ECP(BufferedTransformation &bt)
00040         : m_fieldPtr(new Field(bt))
00041 {
00042         BERSequenceDecoder seq(bt);
00043         GetField().BERDecodeElement(seq, m_a);
00044         GetField().BERDecodeElement(seq, m_b);
00045         // skip optional seed
00046         if (!seq.EndReached())
00047                 BERDecodeOctetString(seq, TheBitBucket());
00048         seq.MessageEnd();
00049 }
00050 
00051 void ECP::DEREncode(BufferedTransformation &bt) const
00052 {
00053         GetField().DEREncode(bt);
00054         DERSequenceEncoder seq(bt);
00055         GetField().DEREncodeElement(seq, m_a);
00056         GetField().DEREncodeElement(seq, m_b);
00057         seq.MessageEnd();
00058 }
00059 
00060 bool ECP::DecodePoint(ECP::Point &P, const byte *encodedPoint, unsigned int encodedPointLen) const
00061 {
00062         StringStore store(encodedPoint, encodedPointLen);
00063         return DecodePoint(P, store, encodedPointLen);
00064 }
00065 
00066 bool ECP::DecodePoint(ECP::Point &P, BufferedTransformation &bt, unsigned int encodedPointLen) const
00067 {
00068         byte type;
00069         if (encodedPointLen < 1 || !bt.Get(type))
00070                 return false;
00071 
00072         switch (type)
00073         {
00074         case 0:
00075                 P.identity = true;
00076                 return true;
00077         case 2:
00078         case 3:
00079         {
00080                 if (encodedPointLen != EncodedPointSize(true))
00081                         return false;
00082 
00083                 Integer p = FieldSize();
00084 
00085                 P.identity = false;
00086                 P.x.Decode(bt, GetField().MaxElementByteLength()); 
00087                 P.y = ((P.x*P.x+m_a)*P.x+m_b) % p;
00088 
00089                 if (Jacobi(P.y, p) !=1)
00090                         return false;
00091 
00092                 P.y = ModularSquareRoot(P.y, p);
00093 
00094                 if ((type & 1) != P.y.GetBit(0))
00095                         P.y = p-P.y;
00096 
00097                 return true;
00098         }
00099         case 4:
00100         {
00101                 if (encodedPointLen != EncodedPointSize(false))
00102                         return false;
00103 
00104                 unsigned int len = GetField().MaxElementByteLength();
00105                 P.identity = false;
00106                 P.x.Decode(bt, len);
00107                 P.y.Decode(bt, len);
00108                 return true;
00109         }
00110         default:
00111                 return false;
00112         }
00113 }
00114 
00115 void ECP::EncodePoint(BufferedTransformation &bt, const Point &P, bool compressed) const
00116 {
00117         if (P.identity)
00118                 NullStore().TransferTo(bt, EncodedPointSize(compressed));
00119         else if (compressed)
00120         {
00121                 bt.Put(2 + P.y.GetBit(0));
00122                 P.x.Encode(bt, GetField().MaxElementByteLength());
00123         }
00124         else
00125         {
00126                 unsigned int len = GetField().MaxElementByteLength();
00127                 bt.Put(4);      // uncompressed
00128                 P.x.Encode(bt, len);
00129                 P.y.Encode(bt, len);
00130         }
00131 }
00132 
00133 void ECP::EncodePoint(byte *encodedPoint, const Point &P, bool compressed) const
00134 {
00135         ArraySink sink(encodedPoint, EncodedPointSize(compressed));
00136         EncodePoint(sink, P, compressed);
00137         assert(sink.TotalPutLength() == EncodedPointSize(compressed));
00138 }
00139 
00140 ECP::Point ECP::BERDecodePoint(BufferedTransformation &bt) const
00141 {
00142         SecByteBlock str;
00143         BERDecodeOctetString(bt, str);
00144         Point P;
00145         if (!DecodePoint(P, str, str.size()))
00146                 BERDecodeError();
00147         return P;
00148 }
00149 
00150 void ECP::DEREncodePoint(BufferedTransformation &bt, const Point &P, bool compressed) const
00151 {
00152         SecByteBlock str(EncodedPointSize(compressed));
00153         EncodePoint(str, P, compressed);
00154         DEREncodeOctetString(bt, str);
00155 }
00156 
00157 bool ECP::ValidateParameters(RandomNumberGenerator &rng, unsigned int level) const
00158 {
00159         Integer p = FieldSize();
00160 
00161         bool pass = p.IsOdd();
00162         pass = pass && !m_a.IsNegative() && m_a<p && !m_b.IsNegative() && m_b<p;
00163 
00164         if (level >= 1)
00165                 pass = pass && ((4*m_a*m_a*m_a+27*m_b*m_b)%p).IsPositive();
00166 
00167         if (level >= 2)
00168                 pass = pass && VerifyPrime(rng, p);
00169 
00170         return pass;
00171 }
00172 
00173 bool ECP::VerifyPoint(const Point &P) const
00174 {
00175         const FieldElement &x = P.x, &y = P.y;
00176         Integer p = FieldSize();
00177         return P.identity ||
00178                 (!x.IsNegative() && x<p && !y.IsNegative() && y<p
00179                 && !(((x*x+m_a)*x+m_b-y*y)%p));
00180 }
00181 
00182 bool ECP::Equal(const Point &P, const Point &Q) const
00183 {
00184         if (P.identity && Q.identity)
00185                 return true;
00186 
00187         if (P.identity && !Q.identity)
00188                 return false;
00189 
00190         if (!P.identity && Q.identity)
00191                 return false;
00192 
00193         return (GetField().Equal(P.x,Q.x) && GetField().Equal(P.y,Q.y));
00194 }
00195 
00196 const ECP::Point& ECP::Identity() const
00197 {
00198         return Singleton<Point>().Ref();
00199 }
00200 
00201 const ECP::Point& ECP::Inverse(const Point &P) const
00202 {
00203         if (P.identity)
00204                 return P;
00205         else
00206         {
00207                 m_R.identity = false;
00208                 m_R.x = P.x;
00209                 m_R.y = GetField().Inverse(P.y);
00210                 return m_R;
00211         }
00212 }
00213 
00214 const ECP::Point& ECP::Add(const Point &P, const Point &Q) const
00215 {
00216         if (P.identity) return Q;
00217         if (Q.identity) return P;
00218         if (GetField().Equal(P.x, Q.x))
00219                 return GetField().Equal(P.y, Q.y) ? Double(P) : Identity();
00220 
00221         FieldElement t = GetField().Subtract(Q.y, P.y);
00222         t = GetField().Divide(t, GetField().Subtract(Q.x, P.x));
00223         FieldElement x = GetField().Subtract(GetField().Subtract(GetField().Square(t), P.x), Q.x);
00224         m_R.y = GetField().Subtract(GetField().Multiply(t, GetField().Subtract(P.x, x)), P.y);
00225 
00226         m_R.x.swap(x);
00227         m_R.identity = false;
00228         return m_R;
00229 }
00230 
00231 const ECP::Point& ECP::Double(const Point &P) const
00232 {
00233         if (P.identity || P.y==GetField().Identity()) return Identity();
00234 
00235         FieldElement t = GetField().Square(P.x);
00236         t = GetField().Add(GetField().Add(GetField().Double(t), t), m_a);
00237         t = GetField().Divide(t, GetField().Double(P.y));
00238         FieldElement x = GetField().Subtract(GetField().Subtract(GetField().Square(t), P.x), P.x);
00239         m_R.y = GetField().Subtract(GetField().Multiply(t, GetField().Subtract(P.x, x)), P.y);
00240 
00241         m_R.x.swap(x);
00242         m_R.identity = false;
00243         return m_R;
00244 }
00245 
// Replaces every element of [begin, end) with its multiplicative inverse in `ring`,
// using Montgomery's simultaneous-inversion trick:
//  - adjacent elements a, b are multiplied into vec[i] = a*b;
//  - the half-length vector of products is inverted recursively;
//  - each pair's inverses are recovered as 1/a = b*(1/(ab)) and 1/b = a*(1/(ab)).
// If an inverted pair product evaluates false (zero), the two elements are
// instead inverted one at a time with ring.MultiplicativeInverse.
// NOTE(review): what MultiplicativeInverse yields for a zero element depends on
// the ring. SimultaneousMultiply below tests z for zero after this call, so it
// appears to expect zero z values to stay zero.
00246 template <class T, class Iterator> void ParallelInvert(const AbstractRing<T> &ring, Iterator begin, Iterator end)
00247 {
00248         unsigned int n = end-begin;
00249         if (n == 1)
00250                 *begin = ring.MultiplicativeInverse(*begin);
00251         else if (n > 1)
00252         {
00253                 std::vector<T> vec((n+1)/2);
00254                 unsigned int i;
00255                 Iterator it;
00256 
// vec[i] = element[2i] * element[2i+1]; an unpaired last element is copied through unchanged
00257                 for (i=0, it=begin; i<n/2; i++, it+=2)
00258                         vec[i] = ring.Multiply(*it, *(it+1));
00259                 if (n%2 == 1)
00260                         vec[n/2] = *it;
00261 
// invert all pair products (and the unpaired element) in one recursive pass
00262                 ParallelInvert(ring, vec.begin(), vec.end());
00263 
00264                 for (i=0, it=begin; i<n/2; i++, it+=2)
00265                 {
00266                         if (!vec[i])
00267                         {
00268                                 *it = ring.MultiplicativeInverse(*it);
00269                                 *(it+1) = ring.MultiplicativeInverse(*(it+1));
00270                         }
00271                         else
00272                         {
// swap a and b, then scale both by 1/(ab): giving 1/a and 1/b in their original slots
00273                                 std::swap(*it, *(it+1));
00274                                 *it = ring.Multiply(*it, vec[i]);
00275                                 *(it+1) = ring.Multiply(*(it+1), vec[i]);
00276                         }
00277                 }
// `it` now points at the unpaired last element; its inverse came back directly in vec[n/2]
00278                 if (n%2 == 1)
00279                         *it = vec[n/2];
00280         }
00281 }
00282 
00283 struct ProjectivePoint
00284 {
00285         ProjectivePoint() {}
00286         ProjectivePoint(const Integer &x, const Integer &y, const Integer &z)
00287                 : x(x), y(y), z(z)      {}
00288 
00289         Integer x,y,z;
00290 };
00291 
// Repeatedly doubles a point held in projective coordinates P = (X, Y, Z) without
// field inversions. SimultaneousMultiply converts back with x = X/Z^2, y = Y/Z^3.
// Besides P it caches aZ4 and sixteenY4. On entry to Double(), aZ4*sixteenY4
// equals a*Z^4 for the current Z, so a*Z^4 is maintained with one multiplication
// per step.
// NOTE(review): firstDoubling, negated and the m_b constructor argument are not
// used by any code visible here.
00292 class ProjectiveDoubling
00293 {
00294 public:
// mr: field to compute in (the caller passes a Montgomery-representation field)
// m_a: curve coefficient a in mr's representation; Q: starting point in affine form
00295         ProjectiveDoubling(const ModularArithmetic &mr, const Integer &m_a, const Integer &m_b, const ECPPoint &Q)
00296                 : mr(mr), firstDoubling(true), negated(false)
00297         {
00298                 if (Q.identity)
00299                 {
// identity is encoded as (1, 1, 0); aZ4 = 0 matches Z = 0
00300                         sixteenY4 = P.x = P.y = mr.MultiplicativeIdentity();
00301                         aZ4 = P.z = mr.Identity();
00302                 }
00303                 else
00304                 {
// affine (x, y) embeds as (x, y, 1), so a*Z^4 = a and the sixteenY4 factor starts at 1
00305                         P.x = Q.x;
00306                         P.y = Q.y;
00307                         sixteenY4 = P.z = mr.MultiplicativeIdentity();
00308                         aZ4 = m_a;
00309                 }
00310         }
00311 
// One doubling step P <- 2P. With X, Y, Z the values on entry:
//   Z' = 2*Y*Z,  S = 4*X*Y^2,  M = 3*X^2 + a*Z^4,
//   X' = M^2 - 2*S,  Y' = M*(S - X') - 8*Y^4.
// The statement order matters: mr's operations may return references to shared
// internal buffers, and several values are updated in place.
// NOTE(review): assumes mr.Reduce(u, v) computes u = u - v (mod p) and
// mr.Half(v) computes v/2 (mod p); confirm against ModularArithmetic.
00312         void Double()
00313         {
00314                 twoY = mr.Double(P.y);
00315                 P.z = mr.Multiply(P.z, twoY);
00316                 fourY2 = mr.Square(twoY);
00317                 S = mr.Multiply(fourY2, P.x);
// bring aZ4 up to date for the current Z (lazy update, see class comment)
00318                 aZ4 = mr.Multiply(aZ4, sixteenY4);
00319                 M = mr.Square(P.x);
00320                 M = mr.Add(mr.Add(mr.Double(M), M), aZ4);
00321                 P.x = mr.Square(M);
00322                 mr.Reduce(P.x, S);
00323                 mr.Reduce(P.x, S);
00324                 mr.Reduce(S, P.x);
00325                 P.y = mr.Multiply(M, S);
// 16*Y^4 from the pre-step Y: used here halved (8*Y^4) and by the next step's aZ4 update
00326                 sixteenY4 = mr.Square(fourY2);
00327                 mr.Reduce(P.y, mr.Half(sixteenY4));
00328         }
00329 
00330         const ModularArithmetic &mr;
00331         ProjectivePoint P;
00332         bool firstDoubling, negated;
// twoY, fourY2, S, M are per-step temporaries kept as members
00333         Integer sixteenY4, aZ4, twoY, fourY2, S, M;
00334 };
00335 
00336 struct ZIterator
00337 {
00338         ZIterator() {}
00339         ZIterator(std::vector<ProjectivePoint>::iterator it) : it(it) {}
00340         Integer& operator*() {return it->z;}
00341         int operator-(ZIterator it2) {return it-it2.it;}
00342         ZIterator operator+(int i) {return ZIterator(it+i);}
00343         ZIterator& operator+=(int i) {it+=i; return *this;}
00344         std::vector<ProjectivePoint>::iterator it;
00345 };
00346 
00347 ECP::Point ECP::ScalarMultiply(const Point &P, const Integer &k) const
00348 {
00349         Element result;
00350         if (k.BitCount() <= 5)
00351                 AbstractGroup<ECPPoint>::SimultaneousMultiply(&result, P, &k, 1);
00352         else
00353                 ECP::SimultaneousMultiply(&result, P, &k, 1);
00354         return result;
00355 }
00356 
// Computes results[i] = expBegin[i] * P for i in [0, expCount). Exponents must be
// non-negative (asserted).
// If this curve's field is not in Montgomery representation, the whole
// computation is rerun on a Montgomery copy of the curve and the results are
// converted back.
// Otherwise:
//  1. Each exponent is scanned with a WindowSlider. The InversionIsFast() and 5
//     arguments presumably enable negated windows (cf. negateNext) and set the
//     window width.
//  2. P is doubled once per bit position in projective form (ProjectiveDoubling).
//     Every multiple 2^pos * P at which some exponent's window begins is saved
//     once in `bases`.
//  3. All saved bases are converted to affine with a single batched inversion of
//     their z values (ParallelInvert).
//  4. Each result is assembled by GeneralCascadeMultiplication over that
//     exponent's (base, window value) pairs.
00357 void ECP::SimultaneousMultiply(ECP::Point *results, const ECP::Point &P, const Integer *expBegin, unsigned int expCount) const
00358 {
00359         if (!GetField().IsMontgomeryRepresentation())
00360         {
00361                 ECP ecpmr(*this, true);
00362                 const ModularArithmetic &mr = ecpmr.GetField();
00363                 ecpmr.SimultaneousMultiply(results, ToMontgomery(mr, P), expBegin, expCount);
00364                 for (unsigned int i=0; i<expCount; i++)
00365                         results[i] = FromMontgomery(mr, results[i]);
00366                 return;
00367         }
00368 
00369         ProjectiveDoubling rd(GetField(), m_a, m_b, P);
// bases: distinct projective multiples of P that start at least one window.
// For exponent i: baseIndices[i][j] indexes into bases, negateBase[i][j] says
// whether to use -base, and exponentWindows[i][j] is that window's value.
00370         std::vector<ProjectivePoint> bases;
00371         std::vector<WindowSlider> exponents;
00372         exponents.reserve(expCount);
00373         std::vector<std::vector<unsigned int> > baseIndices(expCount);
00374         std::vector<std::vector<bool> > negateBase(expCount);
00375         std::vector<std::vector<unsigned int> > exponentWindows(expCount);
00376         unsigned int i;
00377 
00378         for (i=0; i<expCount; i++)
00379         {
00380                 assert(expBegin->NotNegative());
00381                 exponents.push_back(WindowSlider(*expBegin++, InversionIsFast(), 5));
00382                 exponents[i].FindNextWindow();
00383         }
00384 
00385         unsigned int expBitPosition = 0;
00386         bool notDone = true;
00387 
// Walk bit positions upward. rd.P always holds 2^expBitPosition * P. Record it
// (once per position) for every exponent whose next window starts here, and stop
// when all exponents are exhausted.
00388         while (notDone)
00389         {
00390                 notDone = false;
00391                 bool baseAdded = false;
00392                 for (i=0; i<expCount; i++)
00393                 {
00394                         if (!exponents[i].finished && expBitPosition == exponents[i].windowBegin)
00395                         {
00396                                 if (!baseAdded)
00397                                 {
00398                                         bases.push_back(rd.P);
00399                                         baseAdded =true;
00400                                 }
00401 
00402                                 exponentWindows[i].push_back(exponents[i].expWindow);
00403                                 baseIndices[i].push_back(bases.size()-1);
00404                                 negateBase[i].push_back(exponents[i].negateNext);
00405 
00406                                 exponents[i].FindNextWindow();
00407                         }
00408                         notDone = notDone || !exponents[i].finished;
00409                 }
00410 
00411                 if (notDone)
00412                 {
00413                         rd.Double();
00414                         expBitPosition++;
00415                 }
00416         }
00417 
00418         // convert from projective to affine coordinates
// after ParallelInvert each z holds 1/Z; then y = Y/Z^3 and x = X/Z^2.
// Bases with z == 0 (identity) are skipped here and handled below.
00419         ParallelInvert(GetField(), ZIterator(bases.begin()), ZIterator(bases.end()));
00420         for (i=0; i<bases.size(); i++)
00421         {
00422                 if (bases[i].z.NotZero())
00423                 {
00424                         bases[i].y = GetField().Multiply(bases[i].y, bases[i].z);
00425                         bases[i].z = GetField().Square(bases[i].z);
00426                         bases[i].x = GetField().Multiply(bases[i].x, bases[i].z);
00427                         bases[i].y = GetField().Multiply(bases[i].y, bases[i].z);
00428                 }
00429         }
00430 
// finalCascade is reused (resized) for each exponent
00431         std::vector<BaseAndExponent<Point, word> > finalCascade;
00432         for (i=0; i<expCount; i++)
00433         {
00434                 finalCascade.resize(baseIndices[i].size());
00435                 for (unsigned int j=0; j<baseIndices[i].size(); j++)
00436                 {
00437                         ProjectivePoint &base = bases[baseIndices[i][j]];
00438                         if (base.z.IsZero())
00439                                 finalCascade[j].base.identity = true;
00440                         else
00441                         {
00442                                 finalCascade[j].base.identity = false;
00443                                 finalCascade[j].base.x = base.x;
// negated window: use -base, i.e. (x, field.Inverse(y)) as in ECP::Inverse
00444                                 if (negateBase[i][j])
00445                                         finalCascade[j].base.y = GetField().Inverse(base.y);
00446                                 else
00447                                         finalCascade[j].base.y = base.y;
00448                         }
00449                         finalCascade[j].exponent = exponentWindows[i][j];
00450                 }
00451                 results[i] = GeneralCascadeMultiplication(*this, finalCascade.begin(), finalCascade.end());
00452         }
00453 }
00454 
00455 ECP::Point ECP::CascadeScalarMultiply(const Point &P, const Integer &k1, const Point &Q, const Integer &k2) const
00456 {
00457         if (!GetField().IsMontgomeryRepresentation())
00458         {
00459                 ECP ecpmr(*this, true);
00460                 const ModularArithmetic &mr = ecpmr.GetField();
00461                 return FromMontgomery(mr, ecpmr.CascadeScalarMultiply(ToMontgomery(mr, P), k1, ToMontgomery(mr, Q), k2));
00462         }
00463         else
00464                 return AbstractGroup<Point>::CascadeScalarMultiply(P, k1, Q, k2);
00465 }
00466 
00467 NAMESPACE_END
00468 
00469 #endif

Generated on Thu Jun 22 03:36:16 2006 for Crypto++ by  doxygen 1.4.6