F#で逆誤差伝播法(ミニバッチ対応版)

下記記事にて逆誤差伝播法をF#で実装してみました。
F#で逆誤差伝播法 - 何でもプログラミング

1データ/教師データ毎にネットワークを更新していましたが、今回はある程度の数学習してその変位の平均でネットワークを更新する、ミニバッチ法に対応してみたいと思います。

前回のものを流用して、ネットワーク更新の時にAffine層の平均を計算するのでもよいのですが、今回はそもそも入力でMatrix(列方向に複数データが入る)を受け取れるよう実装してみます。

特に記載のないものは、上記記事を参照してみてください。

伝播&逆伝播関数

let forward (input : Matrix<double>) (layer : Layer) : Matrix<double> =
    match layer with
    | Affine(weight, bias) -> 
        input * weight |> Matrix.mapRows (fun _ x -> x + bias)            
    | ReLU ->
        input |> Matrix.map (max 0.0)

let forwardAndCreateBackward 
    (rate : double) (input : Matrix<double>) (layer : Layer) 
    : (Matrix<double> -> Layer * Matrix<double>) * Matrix<double> =
    let output = forward input layer
    let backward =
        match layer with
        | Affine(weight, bias) ->
            (fun (dy : Matrix<double>) ->
                let dx = dy * weight.Transpose()
                let dw = input.Transpose() * dy                    
                Affine(weight - rate * dw, bias - rate * (Matrix.sumCols dy)), dx
            )
        | ReLU ->
            (fun (dy : Matrix<double>) ->
                let dx = dy |> Matrix.mapi (fun i j dy -> if output.[i, j] = 0.0 then 0.0 else dy)
                layer, dx
            )
    backward, output


学習関数

let softmaxRows (x : Matrix<double>) : Matrix<double> =            
    x |> Matrix.mapRows (fun _ x -> softmax x)

let learn (rate : double) (network : Network) (input : Matrix<double>) (teacher : Matrix<double>) : Network =
    let backwards, y = network.Layers |> Array.mapFold (forwardAndCreateBackward rate) input
    let dy = 
        match network.LastLayer with
        | SoftmaxCrossEntropy -> 
            ((softmaxRows y) - teacher) / (double y.RowCount)
    let layers, _ = backwards |> Array.rev |> Array.mapFold (|>) dy
    { network with Layers = layers |> Array.rev }


評価関数

let predict (network : Network) (input : Matrix<double>) : Matrix<double> =
    let y = network.Layers |> Array.fold forward input
    match network.LastLayer with
    | SoftmaxCrossEntropy -> softmaxRows y        

let accuracy (network : Network) (input : Matrix<double>) (teacher : Matrix<double>) : double =
    let output = predict network input
    Seq.map2 
        (=) 
        (output  |> Matrix.toRowSeq |> Seq.map Vector.maxIndex)
        (teacher |> Matrix.toRowSeq |> Seq.map Vector.maxIndex)
    |> Seq.averageBy (fun x -> if x then 1.0 else 0.0)


MNISTを学習

10000データ学習ごとの正答率は下記のように推移しました。
[ 0.0947, 0.8723, 0.8893, 0.9092, 0.9154, 0.9154, 0.9183 ]

let shuffle (ary : 'a[]) : 'a[] =
    let random = System.Random()
    ary |> Array.sortBy (fun _ -> random.Next())

let trainImages = Mnist.loadImageVectors "train-images.idx3-ubyte"
let trainLabels = Mnist.loadLabelVectors "train-labels.idx1-ubyte"
let testImages  = Mnist.loadImageVectors "t10k-images.idx3-ubyte" |> Matrix.Build.DenseOfRowVectors
let testLabels  = Mnist.loadLabelVectors "t10k-labels.idx1-ubyte" |> Matrix.Build.DenseOfRowVectors

let initialNetwork =
    { Layers =
        [| createAffineHe 784 50
            ReLU
            createAffineHe 50 10
        |]
        LastLayer = SoftmaxCrossEntropy
    }

let batchSize = 100
seq { 1..trainImages.Length / batchSize }
|> Seq.scan
    (fun net i -> 
        let indices = [| 0..trainImages.Length - 1 |] |> shuffle |> Array.take batchSize
        let images = indices |> Array.map (fun i -> trainImages.[i]) |> Matrix.Build.DenseOfRowVectors
        let labels = indices |> Array.map (fun i -> trainLabels.[i]) |> Matrix.Build.DenseOfRowVectors
        learn 0.1 net images labels 
    )
    initialNetwork
|> Seq.iter (fun network -> printf "accuracy %f\n" (accuracy network testImages testLabels))






F#で逆誤差伝播法

今回はニューラルネットワークで利用される、逆誤差伝播法をF#で実装してみたいと思います。

実装をするに際し、Math.NETライブラリを利用しています。

レイヤーの定義

今回は、全結合のAffine層、ReLU活性化層、Softmax最終活性化層を定義しました。

その他の層が欲しい場合は、ここに定義を追加していく形となります。

またAffine層の初期化として、He初期値を利用する関数も定義しました。

type Layer =
    | Affine of weight : Matrix<double> * bias : Vector<double>
    | ReLU

type LastLayer =
    | SoftmaxCrossEntropy        

type Network =
    { Layers    : Layer[]
      LastLayer : LastLayer
    }

let createAffineHe (inputCount : int) (outputCount : int) : Layer =
    let weight = Matrix<double>.Build.Random(inputCount, outputCount) * (sqrt (2.0 / double inputCount))
    let bias   = Vector<double>.Build.Dense(outputCount)
    Affine(weight, bias)


伝播&逆伝播関数

純粋な伝播を定義するforward関数と、伝播しながら逆伝播関数を生成するforwardAndCreateBackward関数を定義します。

let forward (input : Vector<double>) (layer : Layer) : Vector<double> =
    match layer with
    | Affine(weight, bias) -> 
        input * weight + bias            
    | ReLU ->
        input |> Vector.map (max 0.0)

let forwardAndCreateBackward 
    (rate : double) (input : Vector<double>) (layer : Layer) 
    : (Vector<double> -> Layer * Vector<double>) * Vector<double> =
    let output = forward input layer
    let backward =
        match layer with
        | Affine(weight, bias) ->
            (fun (dy : Vector<double>) ->
                let dx = dy * weight.Transpose()
                let dw = input.ToColumnMatrix() * dy.ToRowMatrix()
                Affine(weight - rate * dw, bias - rate * dy), dx
            )
        | ReLU ->
            (fun dy ->
                let dx = Vector.map2 (fun y dy -> if y = 0.0 then 0.0 else dy) output dy
                layer, dx
            )
    backward, output


学習関数

順伝播しながら逆伝播関数を生成し、最終層から逆伝播させ、更新された新しいNetworkを生成しています。

let softmax (x : Vector<double>) : Vector<double> =
    let c = Vector.max x
    let e = x |> Vector.map (fun x -> exp (x - c))
    e / (Vector.sum e)

let learn (rate : double) (network : Network) (input : Vector<double>) (teacher : Vector<double>) : Network =
    let backwards, y = network.Layers |> Array.mapFold (forwardAndCreateBackward rate) input
    let dy = 
        match network.LastLayer with
        | SoftmaxCrossEntropy -> (softmax y) - teacher
    let layers, _ = backwards |> Array.rev |> Array.mapFold (|>) dy
    { network with Layers = layers |> Array.rev }


評価関数

入力と教師データから正答率を算出しています。(教師データは、どれか一つの値が活性化するものと想定しています。)

let predict (network : Network) (input : Vector<double>) : Vector<double> =
    let y = network.Layers |> Array.fold forward input
    match network.LastLayer with
    | SoftmaxCrossEntropy -> softmax y        

let accuracy (network : Network) (inputs : Vector<double>[]) (teachers : Vector<double>[]) : double =
    let outputs = inputs |> Array.map (predict network)
    Seq.map2
        (=)
        (outputs  |> Seq.map Vector.maxIndex)
        (teachers |> Seq.map Vector.maxIndex)
    |> Seq.averageBy (fun x -> if x then 1.0 else 0.0)


MNISTを学習してみる

MNISTの読み込みに関しては、下記記事を参照してみてください。
MNISTの読み込み(F#) - 何でもプログラミング

10000データ学習ごとの正答率は下記のように推移しました。
[ 0.0931, 0.913, 0.9321, 0.9297, 0.9441, 0.9471, 0.9488 ]

let trainImages = Mnist.loadImageVectors "train-images.idx3-ubyte"
let trainLabels = Mnist.loadLabelVectors "train-labels.idx1-ubyte"
let testImages  = Mnist.loadImageVectors "t10k-images.idx3-ubyte"
let testLabels  = Mnist.loadLabelVectors "t10k-labels.idx1-ubyte"

let initialNetwork =
    { Layers =
        [| createAffineHe 784 50
            ReLU
            createAffineHe 50 10
        |]
        LastLayer = SoftmaxCrossEntropy
    }

Seq.zip trainImages trainLabels
|> Seq.scan (fun network (image, label) -> learn 0.01 network image label) initialNetwork
|> Seq.indexed
|> Seq.iter 
    (fun (i, network) -> 
        if i % 10000 = 0 then 
            printf "accuracy %f\n" (accuracy network testImages testLabels)
    )






MNISTの読み込み(F#)

機械学習のデータとして、手書き数字の画像がまとめられた下記のサイトを利用することがあります。
MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges

訓練データとして60000画像、テストデータとして10000画像用意されています。

今回はこのデータをF#で利用できるようパースしてみたいと思います。(フォーマットは上記サイトに記載されています。)

読み込み関数

データはbig endianで保存されています。

let readInt32BigEndian (reader : BinaryReader) : int =
    BitConverter.ToInt32(reader.ReadBytes(4) |> Array.rev, 0)

let loadLabels (path : string) : byte[] =
    use reader = new BinaryReader(File.OpenRead(path))
    assert (readInt32BigEndian reader = 2049)
    let count = readInt32BigEndian reader
    reader.ReadBytes(count)

let loadImages (path : string) : byte[][] =
    use reader = new BinaryReader(File.OpenRead(path))
    assert (readInt32BigEndian reader = 2051)
    let count  = readInt32BigEndian reader
    let height = readInt32BigEndian reader
    let width  = readInt32BigEndian reader
    [| 1..count |] |> Array.map (fun _ -> reader.ReadBytes(width * height))


動作確認

実際にpngで保存してみて中身を確認してみます。

let savePng8 (width : int) (height : int) (pixels : byte[]) (path : string) : unit =
   use stream = new FileStream(path, FileMode.Create)
   let encoder = PngBitmapEncoder()
   let bmp = BitmapSource.Create(width, height, 96.0, 96.0, PixelFormats.Gray8, null, pixels, width)
   encoder.Frames.Add(BitmapFrame.Create(bmp))
   encoder.Save(stream)

let main argv = 
    let images = loadImages "train-images.idx3-ubyte"
    let labels = loadLabels "train-labels.idx1-ubyte"

    Array.iteri2
        (fun i image label ->
            let path = sprintf "image%d(%d).png" i label
            File1.savePng8 28 28 image path
        )
        (images |> Array.take 3)
        (labels |> Array.take 3)
f:id:any-programming:20180219125849p:plain f:id:any-programming:20180219125852p:plain f:id:any-programming:20180219125854p:plain
5 0 4


学習用に変形

実際にデータを利用する際には、数学ライブラリのデータで取得したほうが便利です。

今回はMath.NETのVector形式に変換してみます。
f:id:any-programming:20180219131050p:plain

また、画像データを255で割って正規化し、ラベルデータを10要素のVectorに変換します。(例:3 → [0, 0, 0, 1, 0, 0, 0, 0, 0, 0])

open MathNet.Numerics.LinearAlgebra

let loadLabelVectors (path : string) : Vector<double>[] =
    loadLabels path
    |> Array.map 
        (fun label -> 
            [| 0uy..9uy |] 
            |> Array.map (fun x -> if x = label then 1.0 else 0.0) 
            |> Vector.Build.Dense
        )

let loadImageVectors (path : string) : Vector<double>[] =
    loadImages path
    |> Array.map (Array.map (fun x -> (double x) / 255.0) >> Vector.Build.Dense)






Boost Preprocessorでコンストラクタ生成

今回は、下記のように記述すると、その下のコードのように展開してくれるマクロを、Boost Preprocessorを利用して作成してみたいと思います。(実際には改行は生成されません。)

struct Person {
    std::string FirstName;
    std::string LastName;
    int Age;

    CONSTRUCTOR(Person, FirstName, LastName, Age)
};
struct Person {
    std::string FirstName;
    std::string LastName;
    int Age;

    Person(decltype(FirstName) a_FirstName, decltype(LastName) a_LastName, decltype(Age) a_Age) 
        : FirstName(std::move(a_FirstName)), LastName(std::move(a_LastName)), Age(std::move(a_Age)) 
    {}
};


マクロ定義

Variadic引数をSEQに変換し、各要素をBOOST_PP_SEQ_TRANSFORMで変換したのちBOOST_PP_SEQ_ENUMで出力しています。(BOOST_PP_SEQ_ENUMはカンマ区切りで出力されます。)

通常のマクロでは結合はa_##memberのように記述しますが、Boost Preprocessorを利用している場合はBOOST_PP_CATを利用します。(同様に#memberはBOOST_PP_STRINGIZE(member)と記述します。)

#include <boost\preprocessor.hpp>

#define _CONSTRUCTOR_PARAM(s, data, member) \
    decltype(member) BOOST_PP_CAT(a_, member)

#define _CONSTRUCTOR_PARAMS(members) \
    BOOST_PP_SEQ_ENUM(BOOST_PP_SEQ_TRANSFORM(_CONSTRUCTOR_PARAM, , members))

#define _CONSTRUCTOR_INIT(s, data, member) \
    member(std::move(BOOST_PP_CAT(a_, member)))

#define _CONSTRUCTOR_INITS(members) \
    BOOST_PP_SEQ_ENUM(BOOST_PP_SEQ_TRANSFORM(_CONSTRUCTOR_INIT, , members))

#define _CONSTRUCTOR(cls, members) \
    cls(_CONSTRUCTOR_PARAMS(members)) : _CONSTRUCTOR_INITS(members) {}

#define CONSTRUCTOR(cls, ...) \
    _CONSTRUCTOR(cls, BOOST_PP_VARIADIC_TO_SEQ(__VA_ARGS__))


展開結果の確認

cppファイルのプロパティで、Process to a Fileを有効化してファイルをコンパイルすると、出力フォルダに .i ファイルが生成されます。(.objファイルが出力されなくなるため、ビルド時には無効化してください。)

f:id:any-programming:20180116153641p:plain





C++ AMPで画像処理

今回はC++ AMPを利用して画像処理を行ってみたいと思います。

下図のように、赤と青を入れ替える処理を実装していきます。

f:id:any-programming:20170319005947j:plain f:id:any-programming:20170319010003j:plain


アプリケーションコード

concurrency::arrayを利用するとint配列しか受け付けてくれないため、今回はconcurrency::graphics::textureを利用しています。

BMPの読み書きは下記記事のものを利用しています。
Bitmap読み書き - 何でもプログラミング

#include <amp.h>
#include <amp_graphics.h>

int main()
{
    int width, height;
    std::vector<byte> srcPixels;
    LoadBitmap24("Parrots.bmp", &width, &height, &srcPixels);

    // bits_per_scalar_elementに8Uを指定した場合、texture_view<const int>経由でアクセスしないとエラーになります。
    concurrency::graphics::texture<int, 2> srcTexture(height, 3 * width, srcPixels.data(), srcPixels.size(), 8U);
    concurrency::graphics::texture_view<const int, 2> srcView(srcTexture);

    concurrency::graphics::texture<int, 2> dstTexture(height, 3 * width, 8U);
    
    concurrency::extent<2> extent(height, width);
    concurrency::parallel_for_each(extent, [&, srcView](concurrency::index<2> idx) restrict(amp) {
        concurrency::index<2> idx1(idx[0], 3 * idx[1]);
        concurrency::index<2> idx2(idx[0], 3 * idx[1] + 1);
        concurrency::index<2> idx3(idx[0], 3 * idx[1] + 2);
        dstTexture.set(idx1, srcView[idx3]);
        dstTexture.set(idx2, srcView[idx2]);
        dstTexture.set(idx3, srcView[idx1]);
    });

    std::vector<byte> dstPixels(width * height * 3);
    concurrency::graphics::copy(dstTexture, dstPixels.data(), dstPixels.size());

    SaveBitmap24("Parrots2.bmp", width, height, dstPixels.data());

    return 0;
}






OpenCLで画像処理

今回はOpenCLを用いて画像処理を行ってみたいと思います。

下図のように、赤と青を入れ替える処理を実装していきます。

f:id:any-programming:20170319005947j:plain f:id:any-programming:20170319010003j:plain


アプリケーションコード

clCreateImageでもいいですが、今回はclCreateBufferで実装してみました。

BMPの読み書きは下記記事のものを利用しています。
Bitmap読み書き - 何でもプログラミング

#include <cl/cl.h>
#define ASSERT_CL(expr) if (expr != CL_SUCCESS) { throw std::exception(#expr); }

int main()
{
    // 画像読み込み
    int width, height;
    std::vector<byte> srcPixels;
    LoadBitmap24("Parrots.bmp", &width, &height, &srcPixels);

    // device, context作成
    cl_platform_id platform;
    ASSERT_CL(clGetPlatformIDs(1, &platform, nullptr));

    cl_device_id device;
    ASSERT_CL(clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, 1, &device, nullptr));

    cl_context context = clCreateContext(nullptr, 1, &device, nullptr, nullptr, nullptr);
    assert(context);

    // kernel作成
    const char* source =
        "__kernel void main(__global const uchar* src, __global uchar *dst) {     \n"
        "   int i = get_global_id(0);                                             \n"
        "   dst[3 * i]     = src[3 * i + 2];                                      \n"
        "   dst[3 * i + 1] = src[3 * i + 1];                                      \n"
        "   dst[3 * i + 2] = src[3 * i];                                          \n"
        "}                                                                        \n";

    cl_program program = clCreateProgramWithSource(context, 1, &source, nullptr, nullptr);
    assert(program);

    ASSERT_CL(clBuildProgram(program, 1, &device, nullptr, nullptr, nullptr));
    char buildLog[1024];
    ASSERT_CL(clGetProgramBuildInfo(program, device, CL_PROGRAM_BUILD_LOG, 1024, buildLog, nullptr));
    printf("build %s\n", buildLog);

    cl_kernel kernel = clCreateKernel(program, "main", nullptr);
    assert(kernel);

    // 引数設定
    int size = (int)srcPixels.size();
    cl_mem srcBuffer = clCreateBuffer(context, CL_MEM_READ_ONLY, size, nullptr, nullptr);
    assert(srcBuffer);
    cl_mem dstBuffer = clCreateBuffer(context, CL_MEM_WRITE_ONLY, size, nullptr, nullptr);
    assert(dstBuffer);

    ASSERT_CL(clSetKernelArg(kernel, 0, sizeof(srcBuffer), &srcBuffer));
    ASSERT_CL(clSetKernelArg(kernel, 1, sizeof(dstBuffer), &dstBuffer));

    // Queue作成
    cl_command_queue queue = clCreateCommandQueue(context, device, 0, nullptr);
    assert(queue);

    ASSERT_CL(clEnqueueWriteBuffer(queue, srcBuffer, CL_TRUE, 0, size, srcPixels.data(), 0, nullptr, nullptr));

    size_t global_work_size = size / 3;
    ASSERT_CL(clEnqueueNDRangeKernel(queue, kernel, 1, nullptr, &global_work_size, nullptr, 0, nullptr, nullptr));

    std::vector<byte> dstPixels(size);
    ASSERT_CL(clEnqueueReadBuffer(queue, dstBuffer, CL_TRUE, 0, size, dstPixels.data(), 0, nullptr, nullptr));

    // 実行
    ASSERT_CL(clFlush(queue));
    ASSERT_CL(clFinish(queue));

    // 画像保存
    SaveBitmap24("Parrots2.bmp", width, height, dstPixels.data());

    return 0;
}


SDKなしでの開発

OpenCLSDKは、IntellやNVIDIAなどから提供されています。

ちょっとしたものであれば、ヘッダファイルだけ取得して動的にdllをロードすることで開発できます。

ヘッダは下記よりダウンロードできます。(最低限、cl.h と cl_platform.h で大丈夫です。)
Khronos OpenCL Registry - The Khronos Group Inc

dllのロードは、下記コードのように行っています。

ヘッダに関数郡がすでに定義されているため、一段namespaceで包んで関数をロードしています。

#include <cl/cl.h>
namespace App {
    HMODULE hOpenCL = LoadLibraryA("opencl.dll");

#define GET_PROC(name) decltype(::name)* name = (decltype(::name)*)GetProcAddress(hOpenCL, #name)
    GET_PROC(clGetPlatformIDs);
    GET_PROC(clGetDeviceIDs);
    GET_PROC(clCreateContext);
    ...

    int main()
    {
        ....
    }
}

int main() { return App::main(); }






WebGL2.0でボリュームレンダリング(TypeScript)

下記記事にて、Direct3D11を用いてボリュームレンダリングを実装してみました。
Direct3D11でボリュームレンダリング - 何でもプログラミング

今回はWebGL2.0を利用して同様のものを実装してみたいと思います。

ボリュームレンダリングの手法やデータは上記記事を参照してみてください。

※現時点ではWebGL2.0はChromeFirefoxでしかサポートされていません。

HTML

表示のため、下記のhtmlを用意しました。

canvasWebGLレンタリング用で、file inputはデータを読み込むためのものです。

最終的に下図のような表示になります。
f:id:any-programming:20171226133944p:plain

<!DOCTYPE html>
<html lang="ja">
  <head>
    <meta charset="utf-8">
    <title>webGL</title>
  </head>
  <body>
    <canvas id="gl" width="200" height="200"></canvas>
    <input id="file" type="file">
    <script src="script.js"></script>
  </body>
</html>


Canvasのサイズ

htmlのcanvas要素のwidthとheightはバックバッファのサイズ、cssで設定するwidthとheightは実際の表示サイズとして利用されます。

WebGL2.d.ts

WebGL2の定義ファイルはデフォルトで含まれていないため、下記のものをダウンロードして利用しました。
GitHub - MaxGraey/WebGL2-TypeScript: WebGL2 bindings for TypeScript

いくつかWebGLRenderingContextの定義とかち合う関数があるので、適宜コメントアウトして利用してください。

main関数

Direct3Dの時と同様に、三方向のスライスセットを準備し適宜切り替えています。

また陰影もボリューム値の勾配を法線として算出しています。

function main() : void {
    // context取得
    const canvas = <HTMLCanvasElement>document.getElementById("gl");
    const gl = canvas.getContext("webgl2")!;

    // Shader作成
    const vertexCode = `#version 300 es
        layout(location = 0) in vec4 position;
        layout(location = 1) in vec3 texcoord;
        uniform mat4 matrix;
        out vec3 v_texcoord;
        void main() {
            gl_Position = matrix * position;
            v_texcoord = texcoord;
        }    
    `;

    const fragmentCode = `#version 300 es
        precision mediump float;
        precision mediump sampler3D;
        in vec3 v_texcoord;
        uniform sampler3D tex;
        out vec4 fragColor;
        void main() {
            if (texture(tex, v_texcoord).r < 0.5)
                discard;
            
            float dx = textureOffset(tex, v_texcoord, ivec3(1, 0, 0)).r - textureOffset(tex, v_texcoord, ivec3(-1,  0,  0)).r;
            float dy = textureOffset(tex, v_texcoord, ivec3(0, 1, 0)).r - textureOffset(tex, v_texcoord, ivec3( 0, -1,  0)).r;
            float dz = textureOffset(tex, v_texcoord, ivec3(0, 0, 1)).r - textureOffset(tex, v_texcoord, ivec3( 0,  0, -1)).r;
            vec3 normal = normalize(vec3(dx, dy, dz));

            vec3 light = normalize(vec3(1, 1, 1));
            float gray = abs(dot(normal, light));
            fragColor = vec4(gray, gray, gray, 1);
        }
    `;

    const program = createProgram(gl, vertexCode, fragmentCode);

    // Texture3D、VertexBuffer、IndexBuffer作成(Fileが読み込まれた際)
    let texture : WebGLTexture | undefined;
    let slices : VolumeSlices | undefined;
    const fileInput = <HTMLInputElement>document.getElementById("file");
    fileInput.onchange = ev =>
    {
        const file = (<HTMLInputElement>ev.target).files![0];
        const reader = new FileReader();
        reader.onload = _ev =>
        {
            const buf = new Uint8Array(<ArrayBuffer>reader.result);
            texture = createTexture3D(gl, 512, 512, 360, buf);
            slices = new VolumeSlices(gl, 512, 512, 360);
        };
        reader.readAsArrayBuffer(file);
    };

    // 固定の設定
    gl.useProgram(program);
    gl.enableVertexAttribArray(0);
    gl.enableVertexAttribArray(1);
    const matrixUniform = gl.getUniformLocation(program, "matrix")!;
    gl.enable(gl.DEPTH_TEST);

    // 描画関数
    function draw(rotX : number, rotY : number) : void {
        gl.clearColor(0.5, 0.6, 1.0, 1);
        gl.clear(gl.COLOR_BUFFER_BIT | gl.DEPTH_BUFFER_BIT);

        if (slices == undefined || texture == undefined)
            return;

        // 行列設定
        const mat = Matrix4.rotate(rotY, new Vector3(0, 1, 0))
            .mul(Matrix4.rotate(rotX, new Vector3(1, 0, 0)))
            .mul(Matrix4.scale(1.5 / 512))
            .mul(Matrix4.translate(new Vector3(-512.0 / 2, -512.0 / 2, -360.0 / 2)));
        gl.uniformMatrix4fv(matrixUniform, false, mat.transpose().data);

        // VertexBuffer、IndexBuffer設定
        const { vertexBuffer, indexBuffer } = slices.getBuffer(mat);
        gl.bindBuffer(gl.ARRAY_BUFFER, vertexBuffer);
        gl.vertexAttribPointer(0, 3, gl.FLOAT, false, 24, 0);
        gl.vertexAttribPointer(1, 3, gl.FLOAT, false, 24, 12);
        gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, indexBuffer);

        // 描画
        const indexCount = gl.getBufferParameter(gl.ELEMENT_ARRAY_BUFFER, gl.BUFFER_SIZE) / 2;
        gl.drawElements(gl.TRIANGLES, indexCount, gl.UNSIGNED_SHORT, 0);
    }

    // MouseMoveで描画
    canvas.onmousemove = ev => {
        const rotX = 2 * Math.PI * ev.y / canvas.height;
        const rotY = 2 * Math.PI * ev.x / canvas.width;
        draw(rotX, rotY);
    };
}


Slices

三方向のスライスセットを作成、適切なものを取得できるクラスになります。

class VolumeSlices {
    private vertexBufferX : WebGLBuffer;
    private indexBufferX  : WebGLBuffer;
    private vertexBufferY : WebGLBuffer;
    private indexBufferY  : WebGLBuffer;
    private vertexBufferZ : WebGLBuffer;
    private indexBufferZ  : WebGLBuffer;
    constructor(gl : WebGLRenderingContext, width : number, height : number, depth : number) {
        const verticesZ = normalizedRange(depth, 0, 1)
            .concatMap(t => [
                0,     0,      t * depth, 0, 0, t,
                width, 0,      t * depth, 1, 0, t,
                0,     height, t * depth, 0, 1, t,
                width, height, t * depth, 1, 1, t,
            ]);

        const verticesY = normalizedRange(height, 0, 1)
            .concatMap(t => [
                0,     t * height, 0,     0, t, 0,
                width, t * height, 0,     1, t, 0,
                0,     t * height, depth, 0, t, 1,
                width, t * height, depth, 1, t, 1,
            ]);

        const verticesX = normalizedRange(width, 0, 1)
            .concatMap(t => [
                t * width, 0,      0,     t, 0, 0,
                t * width, height, 0,     t, 1, 0,
                t * width, 0,      depth, t, 0, 1,
                t * width, height, depth, t, 1, 1,
            ]);

        const createIndices = (count : number) =>
            range(count).concatMap<number>(i => [0, 2, 1, 1, 2, 3].map<number>(x => x + 4 * i));

        this.vertexBufferX = createVertexBuffer(gl, new Float32Array(verticesX));
        this.vertexBufferY = createVertexBuffer(gl, new Float32Array(verticesY));
        this.vertexBufferZ = createVertexBuffer(gl, new Float32Array(verticesZ));
        this.indexBufferX = createIndexBuffer(gl, new Uint16Array(createIndices(width)));
        this.indexBufferY = createIndexBuffer(gl, new Uint16Array(createIndices(height)));
        this.indexBufferZ = createIndexBuffer(gl, new Uint16Array(createIndices(depth)));
    }
    getBuffer(m : Matrix4) : { vertexBuffer : WebGLBuffer, indexBuffer : WebGLBuffer } {
        const dotX = Math.abs(m.transform3x3(Vector3.ex).z);
        const dotY = Math.abs(m.transform3x3(Vector3.ey).z);
        const dotZ = Math.abs(m.transform3x3(Vector3.ez).z);
        if (dotX < dotZ && dotY < dotZ)
            return { vertexBuffer : this.vertexBufferZ, indexBuffer : this.indexBufferZ };
        else if (dotX < dotY)
            return { vertexBuffer : this.vertexBufferY, indexBuffer : this.indexBufferY };
        else
            return { vertexBuffer : this.vertexBufferX, indexBuffer : this.indexBufferX };
    }
}


WebGL関数群

main関数内で利用されているWebGL関連の関数は下記のように実装しています。

function compileShader(gl : WebGLRenderingContext, shader : WebGLShader, code : string) : void {
    gl.shaderSource(shader, code);
    gl.compileShader(shader);

    console.log(gl.getShaderInfoLog(shader));
    if (!gl.getShaderParameter(shader, gl.COMPILE_STATUS))
        throw new Error("compile error");
}

function createProgram(gl : WebGLRenderingContext, vertexCode : string, fragmentCode : string) : WebGLProgram {
    const vertexShader = gl.createShader(gl.VERTEX_SHADER)!;
    compileShader(gl, vertexShader, vertexCode);

    const fragmentShader = gl.createShader(gl.FRAGMENT_SHADER)!;
    compileShader(gl, fragmentShader, fragmentCode);

    const program = gl.createProgram()!;
    gl.attachShader(program, vertexShader);
    gl.attachShader(program, fragmentShader);
    gl.linkProgram(program);

    console.log(gl.getProgramInfoLog(program));
    if (!gl.getProgramParameter(program, gl.LINK_STATUS))
        throw new Error("program error");

    return program;
}
function createVertexBuffer(gl : WebGLRenderingContext, vertices : Float32Array) : WebGLBuffer {
    const buffer = gl.createBuffer()!;

    gl.bindBuffer(gl.ARRAY_BUFFER, buffer);
    gl.bufferData(gl.ARRAY_BUFFER, vertices, gl.STATIC_DRAW);

    return buffer;
}
function createIndexBuffer(gl : WebGLRenderingContext, indices : Uint16Array) : WebGLBuffer {
    const buffer = gl.createBuffer()!;

    gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, buffer);
    gl.bufferData(gl.ELEMENT_ARRAY_BUFFER, indices, gl.STATIC_DRAW);

    return buffer;
}
function createTexture3D(gl : WebGL2RenderingContext, width : number, height : number, depth : number, voxels : Uint8Array) : WebGLTexture {
    const texture = gl.createTexture()!;

    gl.bindTexture(gl.TEXTURE_3D, texture);
    gl.pixelStorei(gl.UNPACK_ALIGNMENT, 1);
    gl.texImage3D(gl.TEXTURE_3D, 0, gl.R8, width, height, depth, 0, gl.RED, gl.UNSIGNED_BYTE, voxels);
    gl.texParameteri(gl.TEXTURE_3D, gl.TEXTURE_MAG_FILTER, gl.NEAREST);
    gl.texParameteri(gl.TEXTURE_3D, gl.TEXTURE_MIN_FILTER, gl.NEAREST);
    gl.texParameteri(gl.TEXTURE_3D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
    gl.texParameteri(gl.TEXTURE_3D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
    gl.texParameteri(gl.TEXTURE_3D, gl.TEXTURE_WRAP_R, gl.CLAMP_TO_EDGE);

    return texture;
}


Vector3、Matrix4

動作確認レベルの実装ですので、お好みの行列ライブラリで置き換えてください。

class Vector3 {
    constructor(
        readonly x : number,
        readonly y : number,
        readonly z : number,
    ) {}
    add(v : Vector3) : Vector3 {
        return new Vector3(this.x + v.x, this.y + v.y, this.z + v.z);
    }
    sub(v : Vector3) : Vector3 {
        return new Vector3(this.x - v.x, this.y - v.y, this.z - v.z);
    }
    mul(s : number) : Vector3 {
        return new Vector3(this.x * s, this.y * s, this.z * s);
    }
    div(s : number) : Vector3 {
        return new Vector3(this.x / s, this.y / s, this.z / s);
    }
    length() : number {
        return Math.sqrt(this.x * this.x + this.y * this.y + this.z * this.z);
    }
    normalize() : Vector3 {
        return this.div(this.length());
    }
    dot(v : Vector3) : number {
        return this.x * v.x + this.y * v.y + this.z * v.z;
    }
    cross(v : Vector3) : Vector3 {
        return new Vector3(this.y * v.z - this.z * v.y, this.z * v.x - this.x * v.z, this.x * v.y - this.y * v.x);
    }
    static readonly ex = new Vector3(1, 0, 0);
    static readonly ey = new Vector3(0, 1, 0);
    static readonly ez = new Vector3(0, 0, 1);
}
class Matrix4 {
    constructor(readonly data: number[])
    { }
    mul(m : Matrix4) : Matrix4 {
        return new Matrix4([
            this.data[0] * m.data[0] + this.data[1] * m.data[4] + this.data[2] * m.data[8]  + this.data[3] * m.data[12],
            this.data[0] * m.data[1] + this.data[1] * m.data[5] + this.data[2] * m.data[9]  + this.data[3] * m.data[13],
            this.data[0] * m.data[2] + this.data[1] * m.data[6] + this.data[2] * m.data[10] + this.data[3] * m.data[14],
            this.data[0] * m.data[3] + this.data[1] * m.data[7] + this.data[2] * m.data[11] + this.data[3] * m.data[15],

            this.data[4] * m.data[0] + this.data[5] * m.data[4] + this.data[6] * m.data[8]  + this.data[7] * m.data[12],
            this.data[4] * m.data[1] + this.data[5] * m.data[5] + this.data[6] * m.data[9]  + this.data[7] * m.data[13],
            this.data[4] * m.data[2] + this.data[5] * m.data[6] + this.data[6] * m.data[10] + this.data[7] * m.data[14],
            this.data[4] * m.data[3] + this.data[5] * m.data[7] + this.data[6] * m.data[11] + this.data[7] * m.data[15],

            this.data[8] * m.data[0] + this.data[9] * m.data[4] + this.data[10] * m.data[8]  + this.data[11] * m.data[12],
            this.data[8] * m.data[1] + this.data[9] * m.data[5] + this.data[10] * m.data[9]  + this.data[11] * m.data[13],
            this.data[8] * m.data[2] + this.data[9] * m.data[6] + this.data[10] * m.data[10] + this.data[11] * m.data[14],
            this.data[8] * m.data[3] + this.data[9] * m.data[7] + this.data[10] * m.data[11] + this.data[11] * m.data[15],

            this.data[12] * m.data[0] + this.data[13] * m.data[4] + this.data[14] * m.data[8]  + this.data[15] * m.data[12],
            this.data[12] * m.data[1] + this.data[13] * m.data[5] + this.data[14] * m.data[9]  + this.data[15] * m.data[13],
            this.data[12] * m.data[2] + this.data[13] * m.data[6] + this.data[14] * m.data[10] + this.data[15] * m.data[14],
            this.data[12] * m.data[3] + this.data[13] * m.data[7] + this.data[14] * m.data[11] + this.data[15] * m.data[15],
        ]);
    }
    transform(v : Vector3) : Vector3 {
        const x = this.data[0]  * v.x + this.data[1]  * v.y + this.data[2]  * v.z + this.data[3];
        const y = this.data[4]  * v.x + this.data[5]  * v.y + this.data[6]  * v.z + this.data[7];
        const z = this.data[8]  * v.x + this.data[9]  * v.y + this.data[10] * v.z + this.data[11];
        const w = this.data[12] * v.x + this.data[13] * v.y + this.data[14] * v.z + this.data[15];
        return new Vector3(x / w, y / w, z / w);
    }
    transform3x3(v : Vector3) : Vector3 {
        return new Vector3(
            this.data[0] * v.x + this.data[1] * v.y + this.data[2]  * v.z,
            this.data[4] * v.x + this.data[5] * v.y + this.data[6]  * v.z,
            this.data[8] * v.x + this.data[9] * v.y + this.data[10] * v.z,
        );
    }
    transpose() : Matrix4 {
        return new Matrix4([
            this.data[0], this.data[4], this.data[8],  this.data[12],
            this.data[1], this.data[5], this.data[9],  this.data[13],
            this.data[2], this.data[6], this.data[10], this.data[14],
            this.data[3], this.data[7], this.data[11], this.data[15],
        ]);
    }
    getMatrix3Data() : number[] {
        return [
            this.data[0], this.data[1], this.data[2],
            this.data[4], this.data[5], this.data[6],
            this.data[8], this.data[9], this.data[10],
        ];
    }
    static readonly identity =
        new Matrix4([
            1, 0, 0, 0,
            0, 1, 0, 0,
            0, 0, 1, 0,
            0, 0, 0, 1
        ]);
    static translate(v : Vector3) : Matrix4 {
        return new Matrix4([
            1, 0, 0, v.x,
            0, 1, 0, v.y,
            0, 0, 1, v.z,
            0, 0, 0, 1
        ]);
    }
    static rotate(radian : number, axis : Vector3) : Matrix4 {
        const x = axis.x, y = axis.y, z = axis.z;
        const s = Math.sin(radian);
        const c = Math.cos(radian);
        return new Matrix4([
            x * x * (1 - c) + c,     x * y * (1 - c) - z * s, z * x * (1 - c) + y * s, 0,
            x * y * (1 - c) + z * s, y * y * (1 - c) + c,     y * z * (1 - c) - x * s, 0,
            z * x * (1 - c) - y * s, y * z * (1 - c) + x * s, z * z * (1 - c) + c,     0,
            0,                       0,                       0,                       1
        ]);
    }
    static scale(s : number) : Matrix4 {
        return new Matrix4([
            s, 0, 0, 0,
            0, s, 0, 0,
            0, 0, s, 0,
            0, 0, 0, 1
        ]);
    }
    static ortho(width : number, height : number, near : number, far : number) : Matrix4 {
        return new Matrix4([
            2.0 / width, 0,            0,                   0,
            0,           2.0 / height, 0,                   0,
            0,           0,            -2.0 / (far - near), -(far + near) / (far - near),
            0,           0,            0,                   1
        ]);
    }
    static perspective(fovy : number, aspect : number, near : number, far : number) : Matrix4 {
        const f = 1.0 / Math.tan(fovy / 2.0);
        return new Matrix4([
            f / aspect, 0, 0,                           0,
            0,          f, 0,                           0,
            0,          0, (far + near) / (near - far), 2 * far * near / (near - far),
            0,          0, -1,                          0
        ]);
    }
}


Array関連の拡張

Slices作成内で利用しているArray関連の関数は下記のように実装してあります。

interface Array<T> {
    concatMap<U>(f : (value : T) => U[]) : U[];
}
Array.prototype.concatMap = function(f : any) {
    return this.reduce((dst, x) => dst.concat(f(x)), []);
};

function range(count : number) : number[] {
    const ary = Array<number>(count);
    for (let i = 0; i < count; ++i)
        ary[i] = i;
    return ary;
}
function normalizedRange(count : number, min : number, max : number) : number[] {
    return range(count).map<number>(i => ((count - 1 - i) * min + i * max) / (count - 1));
}