
Keys can be lost in mapreduce(..., cat, ...) #87

Open
bencottier opened this issue Oct 8, 2021 · 2 comments


@bencottier

I want to slice a 3D array into matrices, multiply each matrix by another matrix, and then cat the result back into a 3D array.

The reason to do it this way is: the dimension shared by the matrices has non-overlapping keys, and I want to find the overlapping keys (that have non-missing values) for each slice.

The problem is that the final concatenated array does not preserve the axiskeys of the dimension I sliced over; they default to OneTo.

MWE:

julia> KA1 = KeyedArray(ones(2, 3), w=['a', 'b'], x=[:a, :b, :c]);

julia> KA2 = KeyedArray(ones(3, 4, 2), x=[:a, :b, :c], y=0.:3., z=["foo", "bar"]);

julia> mapreduce((x, y) -> cat(x, y; dims=:z), axiskeys(KA2, :z)) do z
           KA2_slice = KA2(z=z)
           return KA1 * KA2_slice
       end
3-dimensional KeyedArray(NamedDimsArray(...)) with keys:
   w  2-element Vector{Char}
   y  4-element StepRangeLen{Float64,...}
◪   z  2-element OneTo{Int}
And data, 2×4×2 Array{Float64, 3}:
[:, :, 1] ~ (:, :, 1):
         (0.0)  (1.0)  (2.0)  (3.0)
  ('a')    3.0    3.0    3.0    3.0
  ('b')    3.0    3.0    3.0    3.0

[:, :, 2] ~ (:, :, 2):
         (0.0)  (1.0)  (2.0)  (3.0)
  ('a')    3.0    3.0    3.0    3.0
  ('b')    3.0    3.0    3.0    3.0

I tried using an Interval at the KA2_slice = KA2(z=z) step, but Julia doesn't seem to support multiplying arrays with different numbers of dimensions, even when the extra one is a trailing singleton dimension:

julia> ones(2, 3) * ones(3, 2, 1)
ERROR: MethodError: no method matching *(::Matrix{Float64}, ::Array{Float64, 3})
Closest candidates are:
  *(::Any, ::Any, ::Any, ::Any...) at operators.jl:560
  *(::StridedMatrix{T}, ::StridedVector{S}) where {T<:Union{Float32, Float64, ComplexF32, ComplexF64}, S<:Real} at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:44
  *(::StridedMatrix{var"#s832"} where var"#s832"<:Union{Float32, Float64}, ::StridedMatrix{var"#s831"} where var"#s831"<:Union{Float32, Float64, ComplexF32, ComplexF64}) at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:158
  ...
Stacktrace:
 [1] top-level scope
   @ REPL[21]:1

My current workaround is to use wrapdims afterward to re-key the array.
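A minimal sketch of what such a re-keying step might look like, assuming the arrays from the MWE above. The variable names `slices`, `raw` and `rekeyed` are illustrative only, and the plain `cat` over unwrapped matrices is a stand-in for the `mapreduce` call, not necessarily how the workaround was actually written:

```julia
using AxisKeys

KA1 = KeyedArray(ones(2, 3), w=['a', 'b'], x=[:a, :b, :c])
KA2 = KeyedArray(ones(3, 4, 2), x=[:a, :b, :c], y=0.:3., z=["foo", "bar"])

# Multiply slice by slice, as in the MWE; each product is a 2D KeyedArray over (:w, :y).
slices = [KA1 * KA2(z=z) for z in axiskeys(KA2, :z)]

# Unwrap each product to a plain Matrix and concatenate along a new third dimension.
# After this step no names or keys are left at all.
raw = cat((parent(parent(s)) for s in slices)...; dims=3)

# Re-attach names and keys explicitly, taking them from the inputs.
rekeyed = wrapdims(raw; w=axiskeys(KA1, :w), y=axiskeys(KA2, :y), z=axiskeys(KA2, :z))
```

This keeps the slicing loop, so any per-slice key filtering can still happen inside the comprehension, but the keys of the result have to be supplied by hand.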

@mcabbott
Owner

mcabbott commented Oct 8, 2021

This is tricky. My first thought here is that, without thinking about this package, you probably want to just reshape and call * instead of making slices. One package which wraps this neatly is:

julia> using TensorCore

julia> KA1 ⊡ KA2
2×4×2 Array{Float64, 3}:
[:, :, 1] =
 3.0  3.0  3.0  3.0
 3.0  3.0  3.0  3.0

[:, :, 2] =
 3.0  3.0  3.0  3.0
 3.0  3.0  3.0  3.0
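For reference, the reshape-and-multiply idea can be written out by hand in plain Julia. This is only a sketch with unkeyed stand-in arrays `A`, `B` and `C` (hypothetical names of the same sizes as KA1 and KA2), not TensorCore's actual implementation:

```julia
A = ones(2, 3)       # stands in for KA1, dimensions (w, x)
B = ones(3, 4, 2)    # stands in for KA2, dimensions (x, y, z)

# Flatten the trailing (y, z) dimensions of B into one, so that B becomes a 3×8 matrix,
# multiply as ordinary matrices, then restore the (y, z) dimensions on the result.
C = reshape(A * reshape(B, size(B, 1), :), size(A, 1), size(B, 2), size(B, 3))
```

The contraction runs over the last dimension of `A` and the first dimension of `B`, which is why no slicing or concatenation is needed.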

These packages are unaware of each other, but the reshaping done there uses axes, and thus with #6 it almost succeeds:

julia> KA1 ⊡ KA2
3-dimensional KeyedArray(...) with keys:
↓   2-element Vector{Char}
→   4-element StepRangeLen{Float64,...}
◪   2-element Vector{String}
And data, 2×4×2 reshape(::NamedDimsArray{(:w, :_), Float64, 2, Matrix{Float64}}, 2, 4, 2) with eltype Float64:
[:, :, 1] ~ (:, :, "foo"):
     → _
↓ w          (0.0)  (1.0)  (2.0)  (3.0)
      ('a')    3.0    3.0    3.0    3.0
      ('b')    3.0    3.0    3.0    3.0

[:, :, 2] ~ (:, :, "bar"):
     → _
↓ w          (0.0)  (1.0)  (2.0)  (3.0)
      ('a')    3.0    3.0    3.0    3.0
      ('b')    3.0    3.0    3.0    3.0

The second thought is that, regardless of where slices come from, it would be nice to better propagate properties. Instead of cat(xs...; dims=:z) you almost want cat(xs...; z = KA2.z)? Not sure that can be done. The closest which works now is to wrap a comprehension and call stack, which has methods for keys etc:

julia> [KA1 * KA2(z=z) for z in KA2.z]  # KA2.z === axiskeys(KA2, :z)
2-element Vector{KeyedArray{Float64, 2, NamedDimsArray{(:w, :y), Float64, 2, Matrix{Float64}}, Tuple{Vector{Char}, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}}}}:
 [3.0 3.0 3.0 3.0; 3.0 3.0 3.0 3.0]
 [3.0 3.0 3.0 3.0; 3.0 3.0 3.0 3.0]

julia> comp = [KA1 * KA2[z=i] for i in axes(KA2, :z)]
2-element Vector{KeyedArray{Float64, 2, NamedDimsArray{(:w, :y), Float64, 2, Matrix{Float64}}, Tuple{Vector{Char}, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}}}}:
 [3.0 3.0 3.0 3.0; 3.0 3.0 3.0 3.0]
 [3.0 3.0 3.0 3.0; 3.0 3.0 3.0 3.0]

julia> using LazyStack  # a package AxisKeys knows about

julia> stack(comp)  # 3rd axis still _ ∈ 2-element OneTo
3-dimensional KeyedArray(NamedDimsArray(...)) with keys:
   w  2-element Vector{Char}
   y  4-element StepRangeLen{Float64,...}
◪   _  2-element OneTo{Int}
And data, 2×4×2 stack(::Vector{KeyedArray{Float64, 2, NamedDimsArray{(:w, :y), Float64, 2, Matrix{Float64}}, Tuple{Vector{Char}, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}}}}) with eltype Float64:
[:, :, 1] ~ (:, :, 1):
...

julia> stack(wrapdims(comp, z=KA2.z))
3-dimensional KeyedArray(NamedDimsArray(...)) with keys:
   w  2-element Vector{Char}
   y  4-element StepRangeLen{Float64,...}
◪   z  2-element Vector{String}
And data, 2×4×2 stack(::Vector{KeyedArray{Float64, 2, NamedDimsArray{(:w, :y), Float64, 2, Matrix{Float64}}, Tuple{Vector{Char}, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}}}}) with eltype Float64:
[:, :, 1] ~ (:, :, "foo"):

Again thinking about #6, I think comp could plausibly be made to automatically wrap like this, since it makes axes(KA2, :z) a special type.

@bencottier
Author

Thanks for the suggestions!

To clarify, is there currently no support (on master in any package) for tensor multiplication that preserves axiskeys? And is #6 (or similar) your preferred solution?
