Skip to content

Commit 2203f49

Browse files
Heterotopic MO kernel helper (#353)
* Heterotopic helper * Bump patch * Tweak docstring * Update docs * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Bump patch Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent aed40f0 commit 2203f49

File tree

5 files changed

+82
-4
lines changed

5 files changed

+82
-4
lines changed

Diff for: Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.10.12"
3+
version = "0.10.13"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

Diff for: docs/src/api.md

+8-2
Original file line numberDiff line numberDiff line change
@@ -58,13 +58,19 @@ For an explanation of this design choice, see [the design notes on multi-output
5858
An input to a multi-output `Kernel` should be a `Tuple{T, Int}`, whose first element specifies a location in the domain of the multi-output GP, and whose second element specifies which output the inputs corresponds to.
5959
The type of collections of inputs for multi-output GPs is therefore `AbstractVector{<:Tuple{T, Int}}`.
6060

61-
KernelFunctions.jl provides the following helper function for situations in which all outputs are observed all of the time:
61+
KernelFunctions.jl provides the following helper functions to reduce the cognitive load
62+
associated with working with multi-output kernels by dealing with transforming data from the
63+
formats in which it is commonly found into the format required by KernelFunctions.
64+
The intention is that users can pass their data to these functions, and use the returned
65+
values throughout their code, without having to worry further about correctly formatting
66+
their data for KernelFunctions' sake:
6267
```@docs
6368
prepare_isotopic_multi_output_data(x::AbstractVector, y::ColVecs)
6469
prepare_isotopic_multi_output_data(x::AbstractVector, y::RowVecs)
70+
prepare_heterotopic_multi_output_data
6571
```
6672

67-
The input types that it constructs can also be constructed manually:
73+
The input types returned by `prepare_isotopic_multi_output_data` can also be constructed manually:
6874
```@docs
6975
MOInput
7076
```

Diff for: src/KernelFunctions.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ export spectral_mixture_kernel, spectral_mixture_product_kernel
3636

3737
export ColVecs, RowVecs
3838

39-
export MOInput, prepare_isotopic_multi_output_data
39+
export MOInput, prepare_isotopic_multi_output_data, prepare_heterotopic_multi_output_data
4040
export IndependentMOKernel,
4141
LatentFactorMOKernel, IntrinsicCoregionMOKernel, LinearMixingModelKernel
4242

Diff for: src/mokernels/moinput.jl

+59
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,8 @@ julia> outputs
143143
3.1
144144
3.2
145145
```
146+
147+
See also [`prepare_heterotopic_multi_output_data`](@ref).
146148
"""
147149
function prepare_isotopic_multi_output_data(x::AbstractVector, y::ColVecs)
148150
length(x) == length(y) || throw(ArgumentError("length(x) not equal to length(y)."))
@@ -189,8 +191,65 @@ julia> outputs
189191
2.2
190192
3.2
191193
```
194+
195+
See also [`prepare_heterotopic_multi_output_data`](@ref).
192196
"""
193197
function prepare_isotopic_multi_output_data(x::AbstractVector, y::RowVecs)
194198
length(x) == length(y) || throw(ArgumentError("length(x) not equal to length(y)."))
195199
return MOInputIsotopicByOutputs(x, size(y.X, 2)), vec(y.X)
196200
end
201+
202+
"""
203+
prepare_heterotopic_multi_output_data(
204+
x::AbstractVector, y::AbstractVector{<:Real}, output_indices::AbstractVector{Int},
205+
)
206+
207+
Utility functionality to convert a collection of inputs `x`, observations `y`, and
208+
`output_indices` into a format suitable for use with multi-output kernels.
209+
Handles the situation in which only one (or a subset) of outputs are observed at each
210+
feature.
211+
Ensures that all arguments are compatible with one another, and returns a vector of inputs
212+
and a vector of outputs.
213+
214+
`y[n]` should be the observed value associated with output `output_indices[n]` at feature
215+
`x[n]`.
216+
217+
```jldoctest
218+
julia> x = [1.0, 2.0, 3.0];
219+
220+
julia> y = [-1.0, 0.0, 1.0];
221+
222+
julia> output_indices = [3, 2, 1];
223+
224+
julia> inputs, outputs = prepare_heterotopic_multi_output_data(x, y, output_indices);
225+
226+
julia> inputs
227+
3-element Vector{Tuple{Float64, Int64}}:
228+
(1.0, 3)
229+
(2.0, 2)
230+
(3.0, 1)
231+
232+
julia> outputs
233+
3-element Vector{Float64}:
234+
-1.0
235+
0.0
236+
1.0
237+
```
238+
239+
See also [`prepare_isotopic_multi_output_data`](@ref).
240+
"""
241+
function prepare_heterotopic_multi_output_data(
242+
x::AbstractVector, y::AbstractVector{<:Real}, output_indices::AbstractVector{Int}
243+
)
244+
# Ensure validity of arguments.
245+
if length(x) != length(y)
246+
throw(ArgumentError("length(x) != length(y)"))
247+
end
248+
if length(x) != length(output_indices)
249+
throw(ArgumentError("length(x) != length(output_indices)"))
250+
end
251+
252+
# Construct inputs and outputs for multi-output kernel.
253+
x_mogp = map(tuple, x, output_indices)
254+
return x_mogp, y
255+
end

Diff for: test/mokernels/moinput.jl

+13
Original file line numberDiff line numberDiff line change
@@ -75,4 +75,17 @@
7575
@test length(y_canon) == length(x_canon)
7676
end
7777
end
78+
@testset "prepare_heterotopic_multi_output_data" begin
79+
x = randn(3)
80+
y = randn(3)
81+
output_indices = [3, 1, 2]
82+
x_canon, y_canon = prepare_heterotopic_multi_output_data(x, y, output_indices)
83+
@test x_canon isa AbstractVector{<:Tuple{<:Real,Int}}
84+
@test y isa AbstractVector{<:Real}
85+
86+
@test_throws ArgumentError prepare_heterotopic_multi_output_data(x, y, [3, 1])
87+
@test_throws(
88+
ArgumentError, prepare_heterotopic_multi_output_data(x, [1.0], output_indices),
89+
)
90+
end
7891
end

0 commit comments

Comments
 (0)