matlabで並列処理を伴う高次方程式の解計算とグラフ作成

なにがしたいか

  • 5次以上の高次方程式の数値解を網羅したい
    • 一般解の表現が存在しない以上は数値的に求めるしかない
    • パラメータを弄るとどんな感じに変化するかを確認したい
    • でもちょっと面倒なので簡単な2次方程式で実装してみるよ
  • 国の金で買った多コアCPUなんだから使い潰したいね
    • 簡単な並列処理の導入と通信手法のテンプレート化を図る

注意

  • matlabの並列化にはParallel Computing Toolboxのライセンスが必要だぞ
    • 対応していないと並列化は(多分)できない,設定も表示されないかも

jp.mathworks.com

方程式を定義しよう

.mlxって知ってる?
ライブスクリプトを使うと方程式の作成と編集が楽になるよ

clear;

syms a b x;
symfx = x^2 + a*x + b;

save("testEquation.mat");

.mファイルに保存すれば同じ方程式をいつでも呼び出せる

並列処理の環境設定

ライセンスがあれば左下のアイコンが表示されるので,並列基本設定に進む

手元のPCだけならlocalを選ぶ(普通はそれ以外は出てこないはず),ワーカー数を指定する(物理コアよりも少なめで設定)

OKを押したら並列プールを起動してみる,謎の模様が青く光ってればOKだ

実装と実験

簡単にルールを説明すると,

  • parforforループの体をしながらも処理を1つずつ並列ワーカーに送信する
  • parforの中にparforを入れてはいけない
  • parforの内から外に変数をダイレクトに持ち出してはいけない
    • 外から内だとメモリ空間を共用できるけど内から外だとできるわけないからね
    • 代わりにparforの中からparallel.pool.DataQueuesendすることはできる
  • 以上の条件から実装したい処理は以下の通りになる
    • 親プロセスで方程式,パラメータ,グラフ,キュー,ちょっとした関数などの環境を整える
    • 子プロセスにターゲットの方程式とそれに適用するパラメータを与えて解かせる
    • 解を出したらデータセットを整形して子プロセスがキューを更新(事実上のプッシュ)する
    • 親プロセスはキューを監視してデータセットを元にグラフを(リアルタイムで)反映させる
  • 2種類のパラメータの総当たりを1つのparforで振り分けるので,そのへんの工夫は必要になる
    • 1次元の配列に2次元のデータをマッピングする競プロ感のある実装になってる,しゃーない
  • 実数解は青十字,虚数解は赤正円で表示を分けてみる
    • このビジュアライズとても便利なのでちょっと布教したい

コード

clear;
load('testEquation.mat');

brStr = convertCharsToStrings(char(10));


arrayA = -10:1:10;
lengtA = [-10,10];
arrayB = -10:1:10;
lengtB = [-10,10];


stDateTime = datetime();
messg = "Launched Analysis: testEquation"+brStr+"Start at: "+datestr(stDateTime);
disp(messg);
sendSlackTxt('simulinkMonitor',messg);



try
    image = imread(title+".png");
    clear("image");
    messg = "Already Analysed"+brStr+"Terminate Situation";
    disp(messg);
    sendSlackTxt('simulinkMonitor',messg);
catch


    title = "testEquation";
    disp(title);

    index = "testEquationSolver"+brStr+"with Parallel Computing";
    graph = initGraph(index);
    graph.WindowState="maximized";


    titleSolve = "Solutions Mapping";
    tableSolve = initTable(titleSolve,1,2,1);
    xlabel(tableSolve,"$\rm{Re}$",'Interpreter','latex');
    ylabel(tableSolve,"$\rm{Im}$",'Interpreter','latex');
    xticks(tableSolve,-10:1:10);
    xlim(tableSolve,[-10,10]);
    yticks(tableSolve,-10:1:10);
    ylim(tableSolve,[-10,10]);
    tableSolve.XAxis.Exponent = 0;
    tableSolve.YAxis.Exponent = 0;

    titleParam = "Parameter Mapping";
    tableParam = initTable(titleParam,1,2,2);
    xlabel(tableParam,"$a$",'Interpreter','latex')
    ylabel(tableParam,"$b$",'Interpreter','latex')
    xticks(tableParam,arrayA);
    xlim(tableParam,lengtA);
    yticks(tableParam,arrayB);
    ylim(tableParam,lengtB);
    tableParam.XAxis.Exponent = 0;
    tableParam.YAxis.Exponent = 0;


    pListSolve = parallel.pool.DataQueue;
    pListSolve.afterEach(@(data) addPlot(tableSolve,data));

    pListParam = parallel.pool.DataQueue;
    pListParam.afterEach(@(data) addPlot(tableParam,data));



    parfor countPattr = 1:(length(arrayA)*length(arrayB))


        [countA,countB] = addAxis(length(arrayA),length(arrayB),countPattr);
        paramA = arrayA(countA);
        paramB = arrayB(countB);


        numfx = subs(symfx,a,paramA);
        numfx = subs(numfx,b,paramB);

        poles = vpasolve(numfx,x);
        reals = real(poles);
        imgns = imag(poles);
        color = 'b';
        shape = '+';
        diamt = 8;
        for check=imgns.'
            if (check~=0)
                color='r';
                shape='o';
                diamt=16;
            else
            end
        end


        send(pListSolve,{reals,imgns,diamt,shape,color});
        send(pListParam,{paramA,paramB,diamt,shape,color});


    end

    %saveas(graph,title+".png");
    %close(graph);

    edDateTime = datetime();
    messg = "Finished Analysis: testEquation"+brStr+"End at: "+datestr(edDateTime);
    disp(messg);
    sendSlackTxt('simulinkMonitor',messg);


end



function [] = addPlot(table,data)
scatter(table,data{1},data{2},data{3}^2,data{4},data{5});
hold on;
drawnow('limitrate');
end



function [countA,countB] = addAxis(lengthA,lengthB,countPattr)
ii = countPattr;
jj = 0;
while ii>0
    ii = ii-lengthA;
    jj = jj+1;
end
countA = ii+lengthA;
countB = jj;
end

ちなみにグラフのプリセットはこれで生成している

function graph = initGraph(index)

    graph = figure;
    set(graph,'color','white');
    
    if index==""
    else
        sgtitle(graph,index,'FontName','Times New Roman','FontSize',48,'interpreter','latex');
    end

end
function [table] = initTable(index,heigt,width,order)

    table = subplot(heigt,width,order);
    set(table,'FontName','Times New Roman','FontSize',24);

    if index==""
    else
        title(index,'Interpreter','latex','FontSize',36);
    end

    set(table,'Box','on','LineWidth',1,'GridLineStyle','-','MinorGridLineStyle','none');
    set(table,'XColor','k','YColor','k','ZColor','k','GridColor','k');
    set(table,'GridColor','k','MinorGridColor','k','GridAlpha',0.5);
    set(table,'TickLabelInterpreter','latex')
    grid on;
    hold on;
    
end

結果

あー,中学校の数学でこういうのをやった記憶,みなさんにあるはずっすよねぇ
$a2-4b<0$で虚数になるから$b>\dfrac{a2}{4}$で青十字から赤正円に切り替わる

ちなみに私は普段はvpasolverで8次方程式を解かせています
2次方程式だと計算スピードがプロットスピードを超えるのでちょっともたついて感じるかもしれん