Quantcast
Channel: OpenCV Q&A Forum - RSS feed
Viewing all articles
Browse latest Browse all 4615

Trouble with OpenCV Implementation of EM Algorithm using Spherical GMM

$
0
0
Hi, I've been trying to use the OpenCV EM implementation and have run into a problem. By setting the maxiter parameter to 1 and calling trainE, I allow only the first expectation (E) step to run, and I then check the log-likelihood output. The problem is that if I change the covariance model from COV_MAT_GENERIC to COV_MAT_SPHERICAL, I get different log-likelihood values for the same initialization. However, the initial covariances I supply (2·I and 3·I) are already spherical, so the E-step equations are the same for both models and the log-likelihoods after the first E-step should be identical. I am unable to figure out where the discrepancy comes from. Here is my code: #include #define NumObs 100 #define Dim 2 #define numClusters 2 #define maxiter 1 int main(int argc, char** argv) { cv::Mat X = cv::Mat(NumObs, Dim, CV_64F); cv::Mat mean = cv::Mat(numClusters,Dim,CV_64F); std::vector covar; cv::Mat mixfrac = cv::Mat(numClusters, 1, CV_64F); cv::Mat logLikelihoods = cv::Mat(NumObs, 1, CV_64F); cv::Mat labels = cv::Mat(NumObs, 1, CV_32SC1); cv::Mat probs = cv::Mat(NumObs, numClusters, CV_64F); int i; mean = (cv::Mat_(numClusters, Dim) << 2, 3, 4, 5); mixfrac = (cv::Mat_(numClusters, 1) << 0.5, 0.5); cv::Mat temp(2, 2, CV_64F); temp = (cv::Mat_(Dim, Dim) << 2, 0, 0, 2); covar.push_back(temp); temp = (cv::Mat_(Dim, Dim) << 3, 0, 0, 3); covar.push_back(temp); X = (cv::Mat_(NumObs, Dim) << 1.497295712046560, 2.389842806384741 , -1.152388866819642, 2.324051537951996 , 0.173091890138541, 3.534436744900519 , 4.013635589903881, 2.635470732249428 , 0.343404593485856, 2.710797002553179 , 1.921946750771822, 2.230097059251492 , 2.923105482852674, 1.968687377364361 , 0.833763816549866, 1.621119740290876 , 1.000974973520235, 0.501003642054062 , 2.722581031828046, 2.179151938477484 , 2.061657164408961, 1.870114601526011 , 2.020337185029648, 1.822954976448330 , 1.076363394410326, 2.640385082400831 , 1.543478612011669, 2.241597508998520 , 1.485372235285237, 2.893724906106137 , -0.715927044291090, 1.369537907446789 , 1.864588596758692, 1.688875700397673 
, -0.335577615523311, 2.183975485869516 , -1.456483320753825, 1.493965172927608 , -0.712160818106508, 1.912240385882839 , 0.233365021273562, 1.954626227148290 , 1.998344873833923, 0.343780671699337 , 0.880159284239810, 0.926396318575558 , 0.800239863676473, 2.841303191042556 , 2.509835594907550, 0.767610261936824 , -0.806015998003402, 2.440467562347095 , 2.672846813458981, 1.899891020249070 , 2.396696676145479, 3.621937374575406 , 0.151374345266610, 2.569943677660735 , 2.775466044119043, 1.781063533049695 , 1.191366838548084, 1.702684062103669 , -0.232621783623176, 1.631979353155736 , -1.174537262824991, 1.738099976739696 , -2.550748638571660, 1.816963733157464 , 0.039014544632652, 1.808195242782888 , -2.078943086646864, 2.084038835619226 , 1.081089945552743, 2.576115695603576 , -0.916439364072319, 0.708922705869511 , 0.903297610096863, 2.791378047921652 , -0.561947988394385, 2.656510856158910 , 1.216211768410659, 1.219031885881044 , 0.693337555617687, 1.359035671153648 , 1.417702588339615, 2.313005179476850 , 0.430424593609415, 1.132064410966174 , -0.261023655491497, 1.981218019474763 , 2.268819458983812, 0.391842019202139 , 0.806386480042758, 3.199864590240754 , -0.985989695126394, 0.893059534741550 , -2.442654072043812, 1.689554199357084 , 1.210751074224775, 1.499372753495019 , -0.672567018407857, 2.970910199573346 , -0.507381574772955, 1.967926272242394 , 0.247969979069518, 0.947965104732303 , 2.891713920836473, 0.861177994177589 , 3.354602636199656, 1.942017768571913 , -1.693061842976859, 1.760474923494942 , 3.997953758133590, 1.297952400955861 , 1.512000798331720, 0.853595479991537 , 1.423345018469304, 3.082589664236521 , -0.666695083234841, 1.125720158829322 , 3.464220247867981, 1.038778673164228 , 2.933373891043690, 1.628066445794965 , 2.438787742568557, 1.882463318993010 , 1.156530661730840, 2.009751864529454 , 1.471382343384891, 2.299778193525170 , 0.710205836898082, 2.982463815839716 , 0.574655988307198, 1.897536569003219 , -0.481180738023749, 
3.243258184346350 , 2.057634665518593, 1.158138790043992 , 2.764066483839555, 1.749966682737925 , 1.188742536016829, 1.830731044279766 , 1.062643063683092, 2.464027653634071 , 2.274832425103863, 2.976150757847411 , -0.391925728913642, 2.459189741355192 , 0.214293806085435, 1.173608789474463 , 0.457619276357005, 2.215723842685201 , -3.276595636919820, 2.027190449572088 , 1.136813051985737, 2.411476105137650 , 2.754286393796612, 1.525510310208986 , 1.772550672306315, 1.678744622863505 , 1.151056973603011, 0.934580048074344 , 1.572929476032799, 1.934085057470423 , 2.070875710998629, 2.577539460790769 , -0.449497770648108, 1.397218305043158 , 0.271960009043126, 2.461090941688480 , 1.502451262202109, 2.258525296739507 , 1.867786756275275, 1.401950731642538 , -0.397193560292116, 1.959501415433482 , -0.799417427242084, 1.323993855920221 , 1.063400687978604, 1.690226733204147 , 1.270024853701466, 2.875733585500125 , -1.772562991551423, 0.986750552965217 , 0.532253741838827, 0.619813291703453 , 1.138823786767506, 2.845567676726764 , 1.285833310664554, 1.354502560988730 , 3.353413094595020, 2.320556360323725 , -1.158616580303092, 1.591324520299605 , 0.300520863016906, 0.893745059745460 , 1.133378827555012, 1.534905699580405 , 0.040235892761628, 2.368829212292840); cv::EM emModel(numClusters, cv::EM::COV_MAT_SPHERICAL, cv::TermCriteria(cv::TermCriteria::COUNT + cv::TermCriteria::EPS, maxiter, FLT_EPSILON)); bool success = emModel.cv::EM::trainE(X, mean, covar, mixfrac, logLikelihoods,labels,probs); if (success) printf("Success\n"); else printf("Failed\n"); cv::Mat newmean = emModel.cv::EM::get("means"); cv::Mat newweight = emModel.cv::EM::get("weights"); std::vector newvar = emModel.cv::EM::get>("covs"); std::cout << "LLH\n" << logLikelihoods << "\n\n"; std::cout <<"NEWMEAN\n"<< newmean<<"\n\n"; std::cout << "NEWWEIGHTS\n" << newweight << "\n\n"; std::cout << "NEWSIGMA1\n" << newvar[0] << "\n"; std::cout << "NEWSIGMA2\n" << newvar[1] << "\n\n"; scanf("%d", &i); } Output for 
Spherical GMM LLH [-2.931318011752626; -5.437115860016975; -3.667420897566784; -3.274239241308711; -3.512773886210777; -2.906378240768077; -3.145533448424205; -3.642481763141506; -4.650852173801083; -2.975512838378632; -3.086635213907651; -3.118147350422762; -3.02853906103177; -2.976176233676393; -2.81071576186187; -5.359239062666107; -3.221342073874848; -4.364890315806004; -6.406040105709825; -4.978136266645567; -3.885209083574705; -4.577679428989389; -4.226199670930593; -3.149979577537311; -4.088024534030816; -4.880127907834185; -3.122802132137833; -2.646131497925848; -3.716520697504866; -3.224961706370468; -3.399575606628337; -4.557644951387831; -5.765981658468872; -8.37554003304065; -4.15407503628638; -7.21426852540733; -3.042169136121539; -6.295443523476872; -3.093842961677683; -4.495265457885608; -3.772140198103474; -3.935028312978259; -2.986138732258714; -4.33067591541676; -4.375573985871765; -4.519102272986751; -3.12565345021085; -6.194705958136484; -8.213811782800219; -3.538572963594731; -4.601005773126856; -4.678804407509512; -4.666930234789473; -4.083218428518389; -3.326859772310369; -6.643334152388976; -4.217625141712781; -4.034734535109596; -2.814646585841023; -5.508686183211898; -4.14839280421732; -3.381874480419066; -3.094404661798319; -3.230164061642246; -2.973679171864843; -3.195153092832868; -3.638373707120748; -4.353325915681707; -3.643177605071091; -3.243476559240886; -3.318103507338534; -3.083978027938285; -2.668655332006318; -4.331983635797021; -4.47589600534346; -3.569193241927018; -10.03869961506342; -3.064692854541199; -3.408620622130172; -3.24166993975683; -4.078595023654192; -3.122174103654394; -2.764010652665211; -4.990765765921384; -3.636211257765106; -2.980988844038545; -3.439623639120886; -4.547175745026653; -5.512926199458751; -3.467745941915222; -2.894697423739033; -7.428294826501615; -4.801229141025504; -2.958713722999831; -3.625296272969862; -3.108524035685613; -5.840522469845191; -4.677873214788853; -3.545697817254149; 
-3.884322355072641] NEWMEAN [2, 3; 4, 5] NEWWEIGHTS [0.5, 0.5] NEWSIGMA1 [2, 0; 0, 2] NEWSIGMA2 [3, 0; 0, 3] Output for Generic GMM LLH [-3.295935746117501; -5.790750393412513; -4.03429419104581; -3.71263347108744; -3.872144489126807; -3.274075193172064; -3.527637461113017; -3.998128057592525; -5.004203357427852; -3.356081006983058; -3.452688438712195; -3.483334605887541; -3.391826118348322; -3.33991303502579; -3.180885949172912; -5.710717495080189; -3.583938119252925; -4.719201739742172; -6.757228800928109; -5.330902409373262; -4.240038592584679; -4.935658714571719; -4.579970929182056; -3.513114639026275; -4.4522906425319; -5.234662136130952; -3.498568777684983; -3.050795422255073; -4.074129455698702; -3.60134679297809; -3.757286012239146; -4.910366860830981; -6.117838563989898; -8.727422737804261; -4.507867855539753; -7.566730910097011; -3.404909140030045; -6.645804917417026; -3.457241856183214; -4.851307959586312; -4.128055700550207; -4.289284804578601; -3.349349646254201; -4.683600556531062; -4.729290711442341; -4.879400262194479; -3.492881334973137; -6.545230968603729; -8.565356892970549; -3.895492494864605; -4.958752084391844; -5.032035254354657; -5.01908185049789; -4.453534794387545; -3.72062903465443; -6.99499565346089; -4.624250701398458; -4.391029813440906; -3.186754423099298; -5.859793331816577; -4.532911965106766; -3.759785100598588; -3.465822699838461; -3.589297731185657; -3.337252862476405; -3.55907539338088; -3.994110028844047; -4.713786842080629; -4.004537247619449; -3.619271902441214; -3.67643636032134; -3.4456342295298; -3.053353732793124; -4.687381509300202; -4.828411093949869; -3.925970540180974; -10.3918013626343; -3.426464016898089; -3.781854557415382; -3.603359371297283; -4.433394247345532; -3.483824967527195; -3.138004621822025; -5.342610735990292; -3.993620083098697; -3.344492754280415; -3.800540886464314; -4.900564570171441; -5.864239423435192; -3.824719561169336; -3.262235974218792; -7.778626718537899; -5.153466442034869; 
-3.324629479023928; -3.982041010852076; -3.50977175586426; -6.192064850807197; -5.030038737191805; -3.902365100307677; -4.240420295976542] NEWMEAN [2, 3; 4, 5] NEWWEIGHTS [0.5, 0.5] NEWSIGMA1 [2, 0; 0, 2] NEWSIGMA2 [3, 0; 0, 3] Thank you for your time.

Viewing all articles
Browse latest Browse all 4615


<script src="https://jsc.adskeeper.com/r/s/rssing.com.1596347.js" async> </script>