Monkey Brain Machine Interface


This demo is based on and uses data from the following Tensor Toolbox demo: https://gitlab.com/tensors/tensor_data_monkey_bmi

Please see the license here: https://gitlab.com/tensors/tensor_data_monkey_bmi/-/blob/4870db135b362b2de499c63b48533abdd5185228/LICENSE.

Relevant papers:

  1. S. Vyas, N. Even-Chen, S. D. Stavisky, S. I. Ryu, P. Nuyujukian, and K. V. Shenoy, Neural Population Dynamics Underlying Motor Learning Transfer, Neuron, Elsevier BV, Vol. 97, No. 5, pp. 1177-1186.e3, March 2018, https://doi.org/10.1016/j.neuron.2018.01.040.

  2. S. Vyas, D. J. O'Shea, S. I. Ryu, and K. V. Shenoy, Causal Role of Motor Preparation during Error-Driven Learning, Neuron, Elsevier BV, Vol. 106, No. 2, pp. 329-339.e4, April 2020, https://doi.org/10.1016/j.neuron.2020.01.019.

  3. A. H. Williams, T. H. Kim, F. Wang, S. Vyas, S. I. Ryu, K. V. Shenoy, M. Schnitzer, T. G. Kolda, and S. Ganguli, Unsupervised Discovery of Demixed, Low-Dimensional Neural Dynamics across Multiple Timescales through Tensor Component Analysis, Neuron, Elsevier BV, Vol. 98, No. 6, pp. 1099-1115, 2018, https://doi.org/10.1016/j.neuron.2018.05.015.

using CairoMakie, GCPDecompositions, LinearAlgebra, Statistics

Loading the data

The following code downloads the data file, extracts the data, and caches it.

using Downloads: download
using CacheVariables, MAT
data = cache(joinpath("monkey-bmi-cache", "data.bson")) do
    # Download file
    url = "https://gitlab.com/tensors/tensor_data_monkey_bmi/-/raw/main/data.mat"
    path = download(url, tempname(@__DIR__))

    # Extract data
    data = matread(path)

    # Clean up and output data
    rm(path)
    data
end
Dict{String, Any} with 5 entries:
  "angle_list"  => [-90; 0; 90; 180;;]
  "X"           => [0.0252503 0.0258503 … 0.0722463 0.0735006; 0.0579854 0.0583162 … 0.…
  "loc"         => [-1.3754 -1.4 … -80.3915 -79.93; -2.984 -3.4072 … -11.5703 -11.45;;;…
  "angle"       => [180; 0; … ; 180; 90;;]
  "angle_xyloc" => [0.0 -80.0; 80.0 0.0; 0.0 80.0; -80.0 0.0]

The data tensor X is 43×200×88 and consists of measurements across 43 neurons, 200 time steps, and 88 trials.

X = data["X"]
43×200×88 Array{Float64, 3}:
[:, :, 1] =
 0.0252503  0.0258503  0.0264512  0.0270512  …  0.0710843   0.0722463   0.0735006
 0.0579854  0.0583162  0.0586282  0.0589141     0.0670181   0.0667575   0.0670088
 0.0415487  0.0418862  0.0427734  0.0436908     0.0445177   0.0432959   0.0426969
 0.0359435  0.0366685  0.0368237  0.0369573     0.0400855   0.0402347   0.0403619
 0.0279543  0.0281149  0.0282441  0.0288642     0.0312721   0.030951    0.0306436
 0.0        0.0        0.0        0.0        …  0.0234634   0.024461    0.0254708
 0.0261938  0.0263678  0.0260585  0.0257498     0.0349308   0.0352453   0.0355543
 ⋮                                           ⋱                          
 0.0154549  0.0150709  0.0146989  0.014339      0.0         0.0         0.0
 0.0        0.0        0.0        0.0           0.0         0.0         0.0
 0.0        0.0        0.0        0.0           0.00182208  0.00164106  0.0014742
 0.0        0.0        0.0        0.0        …  0.0         0.0         0.0
 0.0        0.0        0.0        0.0           0.0         0.0         0.0
 0.0        0.0        0.0        0.0           0.0         0.0         0.0

[:, :, 2] =
 0.0680895    0.0686456   0.0691604    …  0.0356086   0.0350154   0.0338763
 0.0215888    0.0224091   0.0232479       0.0331028   0.033491    0.0338228
 0.0512025    0.0499419   0.0486987       0.0694201   0.0705065   0.0716077
 0.0346605    0.0351354   0.0355906       0.0271583   0.0271947   0.0271703
 0.00268048   0.0029609   0.00326414      0.0257025   0.0263651   0.0270168
 0.00155825   0.0017528   0.00196778   …  0.0311654   0.0317876   0.032349
 0.0282606    0.0282435   0.0282038       0.0293163   0.0293868   0.0294493
 ⋮                                     ⋱                          
 0.000703103  0.00079422  0.000895353     0.0         0.0         0.0
 0.0          0.0         0.0             0.00529259  0.00503994  0.00478694
 0.0119871    0.0116158   0.0112317       0.0         0.0         0.0
 0.00350601   0.00373047  0.00396087   …  0.0         0.0         0.0
 0.00462999   0.00485932  0.0050891       0.0         0.0         0.0
 0.00828468   0.00849865  0.00870713      0.0         0.0         0.0

[:, :, 3] =
 0.0300831    0.0298822   0.0296841   …  0.0293035   0.0286333   0.0279737
 0.0248244    0.0259678   0.0276347      0.0351669   0.034704    0.0337091
 0.0351928    0.0358708   0.0361082      0.0359025   0.0359008   0.0358891
 0.00619878   0.00661573  0.00704468     0.0390072   0.0396213   0.0402089
 0.0199387    0.0192099   0.01849        0.032983    0.0329447   0.032904
 0.00162596   0.00182567  0.00204314  …  0.0104352   0.00970968  0.00845601
 0.0117967    0.0116142   0.0114182      0.024702    0.0249506   0.0252274
 ⋮                                    ⋱                          
 0.0          0.0         0.0            0.0         0.0         0.0
 0.0114445    0.0110998   0.0107294      0.0         0.0         0.0
 0.000914655  0.00105263  0.00120745     0.0         0.0         0.0
 0.0          0.0         0.0         …  0.0         0.0         0.0
 0.0          0.0         0.0            0.00170744  0.00153353  0.00137377
 0.000880271  0.00101462  0.00116565     0.0         0.0         0.0

;;; … 

[:, :, 86] =
 0.0505504   0.0517011   0.0533712    …  0.0587105   0.057458    0.056728
 0.0276506   0.0276121   0.0280878       0.0313349   0.0313622   0.0313908
 0.0486812   0.0494814   0.0497367       0.0472109   0.0471953   0.0477024
 0.0390944   0.0389782   0.0388526       0.0457143   0.0463767   0.0464754
 0.0313674   0.031608    0.0318505       0.0260392   0.0249495   0.0243747
 0.0         0.0         0.000567613  …  0.0351996   0.0363748   0.0380602
 0.0399333   0.0406247   0.0418846       0.0135598   0.0132747   0.0129711
 ⋮                                    ⋱                          
 0.00211271  0.00239722  0.00271361      0.0         0.0         0.0
 0.0141189   0.0142276   0.0143198       0.0         0.0         0.0
 0.0113172   0.0110655   0.0107943       0.00355169  0.00330299  0.00306376
 0.0124426   0.011916    0.0113852    …  0.0         0.0         0.0
 0.0         0.0         0.000522812     0.0         0.0         0.0
 0.0         0.0         0.0             0.0         0.0         0.0

[:, :, 87] =
 0.0727801   0.0718963   0.0715868   …  0.0265111  0.0266063    0.0267094
 0.022145    0.021518    0.0214768      0.0746487  0.0756809    0.076685
 0.0556517   0.0570774   0.0580086      0.027572   0.0269332    0.026887
 0.046503    0.0458768   0.0454267      0.0262182  0.025823     0.025974
 0.0257984   0.0261197   0.0264118      0.0015628  0.000856775  0.000750598
 0.0133426   0.014055    0.0153178   …  0.0        0.0          0.0
 0.0451843   0.0457879   0.0475244      0.0272005  0.0276764    0.0281518
 ⋮                                   ⋱                          
 0.0         0.0         0.0            0.0        0.0          0.0
 0.00244836  0.00294107  0.00363711     0.0        0.0          0.0
 0.00442345  0.00466451  0.00490712     0.0        0.0          0.0
 0.0128039   0.0131198   0.0134187   …  0.0        0.0          0.0
 0.0         0.0         0.0            0.0        0.0          0.0
 0.00960898  0.01033     0.0110863      0.0        0.0          0.0

[:, :, 88] =
 0.0367077   0.0381974   0.0396966   …  0.0972722   0.0966024   0.0959118
 0.00511397  0.00557908  0.00654481     0.0185589   0.0189553   0.0193387
 0.0592377   0.0597916   0.0598491      0.0136062   0.0133903   0.0131727
 0.0197038   0.0195294   0.0193657      0.0154254   0.0149737   0.0145147
 0.0177937   0.0177489   0.0177452      0.0417118   0.0425089   0.0433053
 0.00229669  0.00251443  0.00274525  …  0.0302213   0.0303329   0.03045
 0.011038    0.0115438   0.0120566      0.0381511   0.0391631   0.0401913
 ⋮                                   ⋱                          
 0.00510878  0.00536217  0.00561265     0.00392946  0.00367472  0.00342761
 0.00202583  0.00223041  0.00244888     0.0         0.0         0.0
 0.0         0.0         0.0            0.00373788  0.00348584  0.00324238
 0.00334882  0.00359993  0.00385921  …  0.0         0.0         0.0
 0.0         0.0         0.0            0.0         0.0         0.0
 0.0         0.0         0.0            0.00597703  0.00573948  0.00549713
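The printed slices contain many exact zeros and no negative entries. A quick way to check these properties over the whole tensor is sketched below; summarize_tensor is a hypothetical helper written for this illustration, not part of any package used here.

```julia
# Summarize a neurons × time steps × trials tensor: its size, its smallest
# entry, and the fraction of entries that are exactly zero.
# `summarize_tensor` is a hypothetical helper for illustration only.
function summarize_tensor(X::AbstractArray{<:Real,3})
    return (dims = size(X),
            min_value = minimum(X),
            zero_fraction = count(iszero, X) / length(X))
end

# Small synthetic tensor standing in for the real data:
# half of its entries are 0.5, the other half are zero.
Xtoy = zeros(2, 3, 2)
Xtoy[1, :, :] .= 0.5
s = summarize_tensor(Xtoy)
```

Calling summarize_tensor(X) on the loaded data reports its size together with its smallest entry and its share of zeros.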

Each trial also has an associated target angle (described below).

angles = dropdims(data["angle"]; dims=2)
88-element Vector{Int64}:
 180
   0
 -90
  90
   0
 180
 -90
   ⋮
  90
 180
  90
   0
 180
  90
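To see how the 88 trials are split across targets, the angles can be tallied. This is a minimal sketch; trial_counts is a hypothetical helper, and the input in the example is a short made-up vector rather than the real data.

```julia
# Count how many trials were recorded for each target angle.
# `trial_counts` is a hypothetical helper for illustration only.
function trial_counts(angles::AbstractVector{<:Integer})
    return Dict(a => count(==(a), angles) for a in unique(angles))
end

counts = trial_counts([180, 0, -90, 90, 0, 180])
# counts == Dict(-90 => 1, 0 => 2, 90 => 1, 180 => 2)
```

On the real data, trial_counts(angles) returns one count per target angle, and the counts sum to 88.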

Understanding and visualizing the data

[Figure: Monkey BMI graphic and Monkey BMI cursors]
Image Credit: https://gitlab.com/tensors/tensor_data_monkey_bmi

The data tensor contains pre-processed neural data from a Brain-Machine Interface (BMI) experiment (illustrated above). In this experiment, a monkey uses the BMI to:

  1. move the cursor to one of the four targets, then

  2. hold the cursor on the target.

The targets are identified by their positions along a circle: 0, 90, 180, and -90 degrees.
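The loaded file also includes data["angle_list"] and data["angle_xyloc"]. Assuming (from the printed values) that each row of angle_xyloc gives the x-y position of the corresponding entry of angle_list, every target lies on a circle of radius 80 centered at the origin. The sketch below checks that reading with a hypothetical helper target_xy; the units of the positions are not stated in the data.

```julia
# Convert a target angle (in degrees) to (x, y) coordinates on a circle.
# The default radius of 80 is read off data["angle_xyloc"]; units are unknown.
# `target_xy` is a hypothetical helper for illustration only.
target_xy(angle; radius = 80) = (radius * cosd(angle), radius * sind(angle))

# Values copied from the printed contents of `data`
angle_list  = [-90, 0, 90, 180]
angle_xyloc = [0.0 -80.0; 80.0 0.0; 0.0 80.0; -80.0 0.0]

# Check each row of angle_xyloc against the polar formula
for (angle, row) in zip(angle_list, eachrow(angle_xyloc))
    x, y = target_xy(angle)
    @assert isapprox(x, row[1]; atol = 1e-12) && isapprox(y, row[2]; atol = 1e-12)
end
```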

While the monkey performs these two tasks, the BMI records neural spike data for many neurons over time. This is then repeated for many trials. After pre-processing, the result is a 43×200×88 data tensor of measurements across 43 neurons, 200 time steps, and 88 trials.

The first 100 time steps correspond to the first task (acquire a target) and the second 100 time steps correspond to the second task (hold the target).
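The two phases can be compared by slicing the time mode. The sketch below assumes X is laid out as neurons × time steps × trials with 200 time steps; phase_means is a hypothetical helper that averages each neuron over time and trials within each phase.

```julia
using Statistics

# Mean activity of each neuron during the acquire phase (time steps 1-100)
# and the hold phase (time steps 101-200), averaged over time and trials.
# `phase_means` is a hypothetical helper for illustration only.
function phase_means(X::AbstractArray{<:Real,3})
    @assert size(X, 2) == 200 "expected 200 time steps"
    acquire = vec(mean(view(X, :, 1:100, :); dims = (2, 3)))
    hold    = vec(mean(view(X, :, 101:200, :); dims = (2, 3)))
    return acquire, hold
end

# Synthetic example: one neuron that is twice as active while acquiring
Xtoy = ones(1, 200, 3)
Xtoy[:, 1:100, :] .= 2.0
acq, hld = phase_means(Xtoy)
# acq == [2.0], hld == [1.0]
```

For example, acquire, hold = phase_means(X) followed by count(acquire .> hold) gives the number of neurons that are, on average, more active while acquiring the target than while holding it.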

The following figure plots the time series in the data tensor. Each subplot shows the time series from a single neuron (each thin curve is the time series for a single trial, colored by target). The thick curves show the average time series for each target.

angle_colors = Dict(0 => :tomato1, 90 => :gold, 180 => :darkorchid3, -90 => :cyan3);
with_theme() do
    fig = Figure(; size = (800, 800))

    # Plot time series
    for (idx, neuron_data) in enumerate(eachslice(X; dims=1))
        # Lay out neurons in a 5-column grid, skipping the last two slots of
        # the first row (columns 4-5) to leave room for the legend
        ax = Axis(fig[fldmod1(idx < 4 ? idx : idx+2, 5)...]; title="Neuron $idx",
            xlabel="Time Steps", xticks=0:100:200,
            ylabel="Activity", yticks=LinearTicks(3)
        )

        # Individual time series (one row per trial after permutedims)
        series!(ax, permutedims(neuron_data);
            color=[(angle_colors[angle], 0.7) for angle in angles], linewidth=0.2)

        # Average time series over the trials for each target
        for angle in -90:90:180
            lines!(ax, mean(eachcol(neuron_data)[angles .== angle]);
                color=angle_colors[angle], linewidth=1.5)
        end
    end

    # Tweak formatting
    linkxaxes!(contents(fig.layout)...)
    hidexdecorations!.(contents(fig[1:end-1, :]); ticks=false, grid=false)
    hideydecorations!.(contents(fig[:, 2:end]);
        ticklabels=false, ticks=false, grid=false)
    rowgap!(fig.layout, 10)
    colgap!(fig.layout, 10)

    # Add legend
    Legend(fig[1, 4:5],
        [LineElement(; color=angle_colors[angle]) for angle in [0, 90, 180, -90]],
        ["$(angle)°" for angle in [0, 90, 180, -90]],
        "Target Path Trajectory";
        orientation = :horizontal
    )

    fig
end
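The thick curves in the figure are per-target averages over trials. The same averages can be collected into a single neurons × time steps × targets array, sketched below with a hypothetical helper target_means (not part of CairoMakie or GCPDecompositions).

```julia
using Statistics

# Average the trials for each target, giving a neurons × time steps × targets array.
# `target_means` is a hypothetical helper for illustration only.
function target_means(X::AbstractArray{<:Real,3}, angles::AbstractVector,
                      targets = [-90, 0, 90, 180])
    out = zeros(float(eltype(X)), size(X, 1), size(X, 2), length(targets))
    for (k, t) in enumerate(targets)
        trials = findall(==(t), angles)          # trial indices for this target
        out[:, :, k] = dropdims(mean(X[:, :, trials]; dims = 3); dims = 3)
    end
    return out
end

# Synthetic example: four trials, two per target
Xtoy = zeros(2, 3, 4)
Xtoy[:, :, 1] .= 1.0; Xtoy[:, :, 3] .= 3.0   # trials at 0°
Xtoy[:, :, 2] .= 5.0; Xtoy[:, :, 4] .= 7.0   # trials at 90°
m = target_means(Xtoy, [0, 90, 0, 90], [0, 90])
```

With the real data, target_means(X, angles) returns a 43×200×4 array whose k-th slice holds the averages for the k-th entry of targets.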

Note that the neurons have significantly varying overall levels of activity and are roughly in decreasing order of activity (compare, e.g., neuron 1 with neuron 43). Likewise, there generally appears to be more activity when acquiring the target (the first 100 time steps) than when holding it (the second 100 time steps).
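The ordering of neurons by activity can also be checked numerically. The sketch below uses two hypothetical helpers: neuron_activity averages each neuron over all time steps and trials, and fraction_decreasing reports the share of adjacent neuron pairs whose activity does not increase. This is a simple heuristic, not a formal test of monotonicity.

```julia
using Statistics

# Overall activity of each neuron: the mean over all time steps and trials.
# Both helpers here are hypothetical, written for illustration only.
neuron_activity(X::AbstractArray{<:Real,3}) = vec(mean(X; dims = (2, 3)))

# Fraction of adjacent neuron pairs (n, n+1) where neuron n is at least as
# active as neuron n+1. A value near 1 suggests a roughly decreasing order.
function fraction_decreasing(activity::AbstractVector{<:Real})
    pairs = zip(activity[1:end-1], activity[2:end])
    return count(((a, b),) -> a >= b, pairs) / (length(activity) - 1)
end
```

A value of fraction_decreasing(neuron_activity(X)) close to 1 would support the visual impression from the figure.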

Run GCP Decomposition

A generalized CP (GCP) decomposition with respect to the nonnegative least-squares loss can be computed with gcp by setting the loss keyword argument. Here we compute a decomposition with 10 components.
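For reference, the model and objective can be written out as follows. This is a sketch of the GCP formulation as we understand it, with the least-squares loss f(x, m) = (x - m)^2 and the factors constrained to be nonnegative; here I = 43, J = 200, K = 88, R = 10, and u^{(n)}_{ir} corresponds to the entry M.U[n][i, r] of the fitted model.

```latex
m_{ijk} = \sum_{r=1}^{R} \lambda_r \, u^{(1)}_{ir} \, u^{(2)}_{jr} \, u^{(3)}_{kr},
\qquad
\min_{\lambda,\, U^{(1)},\, U^{(2)},\, U^{(3)} \,\ge\, 0}
\; \sum_{i=1}^{I} \sum_{j=1}^{J} \sum_{k=1}^{K} \left( x_{ijk} - m_{ijk} \right)^2
```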

M = gcp(X, 10; loss = GCPLosses.NonnegativeLeastSquares())
43×200×88 CPD{Float64, 3, Vector{Float64}, Matrix{Float64}} with 10 components
λ weights:
10-element Vector{Float64}: …
U[1] factor matrix:
43×10 Matrix{Float64}: …
U[2] factor matrix:
200×10 Matrix{Float64}: …
U[3] factor matrix:
88×10 Matrix{Float64}: …
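To connect the printed λ weights and factor matrices U[1], U[2], U[3] to the objective, the full model tensor can be rebuilt entry by entry. The sketch below uses hypothetical helpers cp_full and nnls_loss; it favors clarity over speed, using a plain loop rather than an optimized contraction.

```julia
# Rebuild the full CP model tensor from weights λ and factor matrices
# U1 (n1×R), U2 (n2×R), U3 (n3×R). Hypothetical helper for illustration.
function cp_full(λ::AbstractVector, U1::AbstractMatrix, U2::AbstractMatrix,
                 U3::AbstractMatrix)
    n1, n2, n3 = size(U1, 1), size(U2, 1), size(U3, 1)
    T = promote_type(eltype(λ), eltype(U1), eltype(U2), eltype(U3))
    model = zeros(T, n1, n2, n3)
    # Sum the rank-one components; i is the innermost (fastest) loop
    for r in eachindex(λ), k in 1:n3, j in 1:n2, i in 1:n1
        model[i, j, k] += λ[r] * U1[i, r] * U2[j, r] * U3[k, r]
    end
    return model
end

# Least-squares error between the data X and the CP model
nnls_loss(X, λ, U1, U2, U3) = sum(abs2, X .- cp_full(λ, U1, U2, U3))
```

Assuming the fitted model stores its weights in the field M.λ (the U field is used by the plotting code below), nnls_loss(X, M.λ, M.U...) evaluates the least-squares error of the fit.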

Next, we plot the factors, each normalized by its maximum absolute entry.
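The plotting code uses normalize from LinearAlgebra with p = Inf, which divides a vector by its largest absolute entry. Because the factors from a nonnegative decomposition have no negative entries, each plotted factor then has a maximum of 1. A minimal illustration:

```julia
using LinearAlgebra

# normalize(v, Inf) divides v by its ∞-norm (its largest absolute entry),
# so a nonnegative vector is rescaled to have maximum 1.
v = [1.0, 4.0, 2.0]
w = normalize(v, Inf)   # [0.25, 1.0, 0.5]
```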

with_theme() do
    fig = Figure(; size = (700, 700))

    # Plot factors (normalized by max)
    for row in 1:ncomps(M)
        barplot(fig[row,1], normalize(M.U[1][:,row], Inf); color = :orange)
        lines(fig[row,2], normalize(M.U[2][:,row], Inf); linewidth = 4)
        scatter(fig[row,3], normalize(M.U[3][:,row], Inf);
            color = [angle_colors[angle] for angle in angles])
    end

    # Link and hide x axes
    linkxaxes!(contents(fig[:,1])...)
    linkxaxes!(contents(fig[:,2])...)
    linkxaxes!(contents(fig[:,3])...)
    hidexdecorations!.(contents(fig[1:end-1,:]); ticks=false, grid=false)

    # Link and hide y axes
    linkyaxes!(contents(fig.layout)...)
    hideydecorations!.(contents(fig.layout); ticks=false, grid=false)

    # Add legend
    Legend(fig[:, 4],
        [MarkerElement(; color=angle_colors[angle], marker=:circle)
            for angle in [0, 90, 180, -90]],
        ["$(angle)°" for angle in [0, 90, 180, -90]],
    )

    # Add labels
    Label(fig[0,1], "Neurons"; tellwidth=false, fontsize=20)
    Label(fig[0,2], "Time"; tellwidth=false, fontsize=20)
    Label(fig[0,3], "Trials"; tellwidth=false, fontsize=20)

    # Tweak layout
    rowgap!(fig.layout, 10)
    colgap!(fig.layout, 10)
    colsize!(fig.layout, 2, Relative(1/4))

    fig
end

Note that the factors in the neuron mode reflect our earlier observation that the neurons are roughly in decreasing order of activity. Moreover, several factors in the trial mode reflect which target was selected even though the tensor decomposition was not given that information.
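The observation about the trial mode can be quantified with a simple variance ratio: for each trial factor, the share of its variance explained by grouping trials by target (often called eta squared). The sketch below uses a hypothetical helper target_selectivity; values near 1 mean a factor is nearly constant within each target and differs between targets, while values near 0 mean the grouping explains little.

```julia
using Statistics

# Fraction of a trial factor's variance explained by the target angle
# (eta squared of a one-way grouping). Hypothetical helper for illustration.
function target_selectivity(w::AbstractVector{<:Real}, angles::AbstractVector)
    total = sum(abs2, w .- mean(w))          # total sum of squares
    total == 0 && return 0.0                 # constant factor: nothing to explain
    between = sum(unique(angles)) do a       # between-target sum of squares
        group = w[angles .== a]
        length(group) * (mean(group) - mean(w))^2
    end
    return between / total
end
```

Applied to the fitted model, [target_selectivity(M.U[3][:, r], angles) for r in 1:ncomps(M)] gives one score per component.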

Built with Julia 1.10.4 and the following packages:

CacheVariables 0.1.4
CairoMakie 0.12.6
Downloads 1.6.0
GCPDecompositions 0.3.0
MAT 0.10.7
Statistics 1.10.0