Skip to content

Commit d91cf59

Browse files
committed
Minimal working with allocs
1 parent c2367cf commit d91cf59

File tree

4 files changed

+215
-95
lines changed

4 files changed

+215
-95
lines changed

src/MatrixFields/field_name.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,9 @@ extract_first(::FieldName{name_chain}) where {name_chain} = first(name_chain)
5050
drop_first(::FieldName{name_chain}) where {name_chain} =
5151
FieldName(Base.tail(name_chain)...)
5252

53-
extract_last(::FieldName{name_chain}) where {name_chain} = last(name_chain)
54-
drop_last(::FieldName{name_chain}) where {name_chain} =
55-
FieldName(name_chain[1:end-1]...)
53+
extract_last(::FieldName{name_chain}) where {name_chain} = name_chain[length(name_chain)]
54+
# drop_last(::FieldName{name_chain}) where {name_chain} =
55+
# FieldName(name_chain[1:(end - 1)]...)
5656

5757
has_field(x, ::FieldName{()}) = true
5858
has_field(x, name::FieldName) =

src/MatrixFields/field_name_dict.jl

Lines changed: 193 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -165,10 +165,14 @@ function get_internal_entry(
165165
name_pair::FieldNamePair,
166166
key_error,
167167
)
168+
# TODO: Add other cases
168169
if name_pair[1] == name_pair[2]
169170
entry
170-
elseif name_pair[2] == @name() && has_field(entry, name_pair[1])
171-
DiagonalMatrixRow(get_field(entry, name_pair[1]))
171+
elseif name_pair[2] == @name() &&
172+
broadcasted_has_field(eltype(entry), name_pair[1])
173+
DiagonalMatrixRow(
174+
broadcasted_get_field(entry.entries.:(1), name_pair[1]),
175+
)
172176
elseif is_overlapping_name(name_pair[1], name_pair[2])
173177
throw(key_error)
174178
else
@@ -185,13 +189,32 @@ function get_internal_entry(
185189
T = eltype(eltype(entry))
186190
if name_pair == (@name(), @name())
187191
entry
188-
elseif name_pair[1] == name_pair[2] && !broadcasted_has_field(T, name_pair[1])
189-
# multiplication case 3 or 4, first argument
190-
@assert T <: Geometry.SingleValue
192+
elseif name_pair[1] == name_pair[2] &&
193+
!broadcasted_has_field(T, name_pair[1])
194+
# @show "aa"
195+
# # multiplication case 3 or 4, first argument
196+
@assert T <: Number
191197
entry
192-
elseif name_pair[2] == @name() && broadcasted_has_field(T, name_pair[1])
193-
# multiplication case 2 or 4, second argument
194-
target_field_eltype = broadcasted_get_field_type(T, name_pair[1])
198+
elseif name_pair[1] == @name() || name_pair[2] == @name()
199+
200+
target_chain = if name_pair[1] == @name()
201+
if broadcasted_has_field(T, name_pair[2])
202+
# this case should be dscalar/dvec with T isa vec
203+
name_pair[2]
204+
else
205+
# this should be dscalar/dvec with T isa adjoint
206+
append_internal_name(@name(parent), name_pair[2])
207+
end
208+
else
209+
if broadcasted_has_field(T, name_pair[1])
210+
# this case should be dtuple/dscalar or dvec/dscalar with T isa vec
211+
name_pair[1]
212+
else
213+
# this should be dvec/dscalar with T isa adjoint
214+
append_internal_name(@name(parent), name_pair[1])
215+
end
216+
end
217+
target_field_eltype = broadcasted_get_field_type(T, target_chain)
195218
if target_field_eltype == eltype(parent(entry))
196219
T_band = eltype(entry)
197220
singleton_datalayout =
@@ -203,7 +226,7 @@ function get_internal_entry(
203226
)
204227
field_dim_size = DataLayouts.ncomponents(Fields.field_values(entry))
205228
scalar_field_offset = get_field_first_index_offset(
206-
name_pair[1],
229+
target_chain,
207230
target_field_eltype,
208231
T,
209232
)
@@ -231,57 +254,131 @@ function get_internal_entry(
231254
else
232255
Base.broadcasted(entry) do matrix_row
233256
map(matrix_row) do matrix_row_entry
234-
broadcasted_get_field(matrix_row_entry, name_pair[1])
257+
broadcasted_get_field(matrix_row_entry, target_chain)
235258
end
236-
end # Note: This assumes that the entry is in a FieldMatrixBroadcasted.
259+
end
260+
end
261+
elseif name_pair[2] != @name() && name_pair[1] != @name()
262+
# this should only be the case with dvec/dvec or dNTuple/dvec
263+
if T <: Geometry.SingleValue
264+
# @assert drop_last(name_pair[1]) ==
265+
# drop_last(name_pair[2]) ==
266+
# @name(components.data)
267+
row_index = extract_last(name_pair[1])
268+
col_index = extract_last(name_pair[2])
269+
(n_rows, n_cols) = map(length, axes(T))
270+
@assert row_index <= n_rows && col_index <= n_cols
271+
flattened_index = n_rows * (col_index - 1) + row_index
272+
elseif eltype(T) <: Geometry.SingleValue #TODO: nested tuples?
273+
# @assert drop_last(name_pair[2]) == @name(components.data)
274+
modified_first_name =
275+
broadcasted_has_field(T, name_pair[1]) ? name_pair[1] :
276+
append_internal_name(@name(parent), name_pair[1])
277+
flattened_index =
278+
get_field_first_index_offset(
279+
name_pair[1],
280+
broadcasted_get_field_type(T, name_pair[1]),
281+
T,
282+
) + extract_last(name_pair[2])
283+
else
284+
error("Cannot get entry for key $name_pair")
237285
end
238-
elseif broadcasted_has_field(T, name_pair[1]) && broadcasted_has_field(T, name_pair[2])
239-
# this should only be the case when both independent and dependent var are axisvectors
240-
@assert T <: Geometry.SingleValue && !(T <: Number)
241-
@assert drop_last(name_pair[1]) == drop_last(name_pair[2]) == @name(components.data)
242-
row_index = extract_last(name_pair[1])
243-
col_index = extract_last(name_pair[2])
244-
(n_rows, n_cols) = map(length, axes(T))
245-
@assert row_index <= n_rows && col_index <= n_cols
246-
flattened_index = n_rows * (col_index - 1) + row_index
247286
band_element_size = div(sizeof(T), sizeof(eltype(T)))
248287
T_band = eltype(entry)
249-
singleton_datalayout =
250-
DataLayouts.singleton(Fields.field_values(entry))
288+
singleton_datalayout = DataLayouts.singleton(Fields.field_values(entry))
251289
# BandMatrixRow with same lowest diagonal and bandwidth as `entry`, with a scalar eltype
252-
scalar_band_type = band_matrix_row_type(
253-
outer_diagonals(T_band)...,
254-
eltype(T),
255-
)
290+
scalar_band_type =
291+
band_matrix_row_type(outer_diagonals(T_band)..., eltype(eltype(T)))
256292
field_dim_size = DataLayouts.ncomponents(Fields.field_values(entry))
257293
band_element_size = div(sizeof(T), sizeof(eltype(T)))
258-
@assert band_element_size == n_rows * n_cols
259294
parent_indices = DataLayouts.to_data_specific_field(
260-
singleton_datalayout,
261-
(
262-
:,
263-
:,
264-
flattened_index:band_element_size:field_dim_size,
265-
:,
266-
:,
267-
),
268-
)
269-
# Main.@infiltrate
295+
singleton_datalayout,
296+
(:, :, flattened_index:band_element_size:field_dim_size, :, :),
297+
)
298+
270299
scalar_data = view(parent(entry), parent_indices...)
271-
values = DataLayouts.union_all(singleton_datalayout){
272-
scalar_band_type,
273-
Base.tail(
274-
DataLayouts.type_params(Fields.field_values(entry)),
275-
)...,
276-
}(
277-
scalar_data,
278-
)
279-
Fields.Field(values, axes(entry))
300+
301+
values = DataLayouts.union_all(singleton_datalayout){
302+
scalar_band_type,
303+
Base.tail(DataLayouts.type_params(Fields.field_values(entry)))...,
304+
}(
305+
scalar_data,
306+
)
307+
Fields.Field(values, axes(entry))
280308
else
281309
throw(key_error)
282310
end
283311
end
284312

313+
function get_scalar_keys(dict::FieldMatrix)
314+
keys_tuple = unrolled_flatmap(keys(dict).values) do key
315+
entry = dict[unrolled_filter(isequal(key), keys(dict).values)[1]]
316+
entry =
317+
entry isa ColumnwiseBandMatrixField ? entry.entries.:(1) : entry
318+
unrolled_map(filtered_names(entry) do field
319+
if field isa UniformScaling
320+
true
321+
elseif field isa Fields.Field
322+
eltype(field) == eltype(eltype(field))
323+
else
324+
eltype(field) == typeof(field)
325+
end
326+
end) do name
327+
(append_internal_name(key[1], name), key[2])
328+
end
329+
end
330+
return FieldMatrixKeys(keys_tuple)
331+
end
332+
# function combine_name_pair(name_pair::Tuple{FieldName, FieldName{()}}, T)
333+
# end
334+
335+
# function combine_name_pair(name_pair::Tuple{FieldName{()}, FieldName}, ::Type{T}) where {T}
336+
# T <: NamedTuple && error("Cannot return ")
337+
# # @assert eltype()
338+
# end
339+
340+
# function combine_name_pair(name_pair::FieldNamePair, T)
341+
# end
342+
343+
# function foobarbaz(combined_name_chain, T, entry, target_field_eltype)
344+
# band_element_size = div(sizeof(T), sizeof(eltype(T)))
345+
# T_band = eltype(entry)
346+
# singleton_datalayout =
347+
# DataLayouts.singleton(Fields.field_values(entry))
348+
# # BandMatrixRow with same lowest diagonal and bandwidth as `entry`, with a scalar eltype
349+
# scalar_band_type = band_matrix_row_type(
350+
# outer_diagonals(T_band)...,
351+
# eltype(T),
352+
# )
353+
# field_dim_size = DataLayouts.ncomponents(Fields.field_values(entry))
354+
# band_element_size = div(sizeof(T), sizeof(eltype(target_field_eltype)))
355+
# first_index = get_field_first_index_offset(
356+
# combined_name_chain,
357+
# target_field_eltype,
358+
# T,
359+
# )
360+
# parent_indices = DataLayouts.to_data_specific_field(
361+
# singleton_datalayout,
362+
# (
363+
# :,
364+
# :,
365+
# first_index:band_element_size:field_dim_size,
366+
# :,
367+
# :,
368+
# ),
369+
# )
370+
# target_data = view(parent(entry), parent_indices...)
371+
# values = DataLayouts.union_all(singleton_datalayout){
372+
# target_data,
373+
# Base.tail(
374+
# DataLayouts.type_params(Fields.field_values(entry)),
375+
# )...,
376+
# }(
377+
# scalar_data,
378+
# )
379+
# Fields.Field(values, axes(entry))
380+
# end
381+
285382
# Similar behavior to indexing an array with a slice.
286383
function Base.getindex(dict::FieldNameDict, new_keys::FieldNameSet)
287384
common_keys = intersect(keys(dict), new_keys)
@@ -368,10 +465,15 @@ function get_scalar_keys(dict::FieldMatrix, Y::Fields.FieldVector)
368465
entry = dict[unrolled_filter(isequal(key), keys(dict).values)[1]]
369466
if entry isa UniformScaling # uniformscalings can only contain numbers
370467
(key,)
371-
elseif entry isa ColumnwiseBandMatrixField
468+
elseif entry isa ColumnwiseBandMatrixField ||
469+
entry isa DiagonalMatrixRow
372470
first_band = entry.entries.:(1)
373471
target_eltype = eltype(parent(first_band))
374-
if eltype(first_band) == target_eltype
472+
if entry isa ColumnwiseBandMatrixField &&
473+
eltype(first_band) <: target_eltype
474+
(key,)
475+
elseif entry isa DiagonalMatrixRow &&
476+
typeof(first_band) <: target_eltype
375477
(key,)
376478
else
377479
dependent_var = get_field(Y, key[1])
@@ -380,52 +482,55 @@ function get_scalar_keys(dict::FieldMatrix, Y::Fields.FieldVector)
380482
independent_type = eltype(independent_var)
381483
# @Main.infiltrate
382484
@assert dependent_type <: Geometry.SingleValue ||
383-
independent_type <: Geometry.SingleValue ||
384-
"cannot get scalar keys for key $key"
485+
independent_type <: Geometry.SingleValue ||
486+
"cannot get scalar keys for key $key"
385487

386488
# figure out if we need to drill into key[1] or key[2], or both
387489
# @show key
388-
unrolled_flatmap(filtered_names(x -> eltype(x) <: target_eltype, dependent_var)) do dependent_name
389-
unrolled_map(filtered_names(x -> eltype(x) <: target_eltype, independent_var)) do independent_name
390-
(append_internal_name(key[1], dependent_name), append_internal_name(key[2], independent_name))
490+
unrolled_flatmap(
491+
filtered_names(
492+
x -> eltype(x) <: target_eltype,
493+
dependent_var,
494+
),
495+
) do dependent_name
496+
unrolled_map(
497+
filtered_names(
498+
x -> eltype(x) <: target_eltype,
499+
independent_var,
500+
),
501+
) do independent_name
502+
(
503+
append_internal_name(key[1], dependent_name),
504+
append_internal_name(key[2], independent_name),
505+
)
391506
end
392507
# @Main.infiltrate
393508
# key
394509
end
395510
# (key,)
396511
end
512+
# elseif entry isa DiagonalMatrixRow
513+
# target_eltype = eltype(parent(get_field(Y, key[1])))
514+
# # TODO: unify target_eltype
515+
# (key,)
397516
else
398-
# TODO: Fix me
399-
(key,)
517+
error("Cannot get scalar keys for key $key")
400518
end
401519

402-
# entry =
403-
# entry isa ColumnwiseBandMatrixField ? entry.entries.:(1) : entry
404-
# unrolled_map(filtered_names(entry) do field
405-
# if field isa UniformScaling
406-
# true
407-
# elseif field isa Fields.Field
408-
# eltype(field) == eltype(eltype(field))
409-
# else
410-
# eltype(field) == typeof(field)
411-
# end
412-
# end) do name
413-
# (append_internal_name(key[1], name), key[2])
414-
# end
415520
end
416521
return FieldMatrixKeys(keys_tuple)
417522
end
418523

419-
function new_get_scalar_keys(dict::FieldMatrix, Y::Fields.FieldVector)
420-
scalar_field_vector_keys = MatrixFields.filtered_names(Y) do field
421-
field isa Fields.Field && eltype(field) == eltype(parent(field))
422-
end
423-
map(keys(dict).values) do key
424-
first_key_is_scalar = unrolled_any(isequal(key[1]), scalar_field_vector_keys)
425-
second_key_is_scalar = unrolled_any(isequal(key[2]), scalar_field_vector_keys)
426-
@assert first_key_is_scalar || second_key_is_scalar "$key"
427-
end
428-
end
524+
# function new_get_scalar_keys(dict::FieldMatrix, Y::Fields.FieldVector)
525+
# scalar_field_vector_keys = MatrixFields.filtered_names(Y) do field
526+
# field isa Fields.Field && eltype(field) == eltype(parent(field))
527+
# end
528+
# map(keys(dict).values) do key
529+
# first_key_is_scalar = unrolled_any(isequal(key[1]), scalar_field_vector_keys)
530+
# second_key_is_scalar = unrolled_any(isequal(key[2]), scalar_field_vector_keys)
531+
# @assert first_key_is_scalar || second_key_is_scalar "$key"
532+
# end
533+
# end
429534

430535
"""
431536
scalar_fieldmatrix(field_matrix::FieldMatrix)
@@ -467,6 +572,14 @@ function scalar_fieldmatrix(field_matrix::FieldMatrix)
467572
return FieldNameDict(scalar_keys, entries)
468573
end
469574

575+
function scalar_fieldmatrix(field_matrix::FieldMatrix, Y::Fields.FieldVector)
576+
scalar_keys = get_scalar_keys(field_matrix, Y)
577+
entries = unrolled_map(scalar_keys.values) do key
578+
field_matrix[key]
579+
end
580+
return FieldNameDict(scalar_keys, entries)
581+
end
582+
470583
replace_name_tree(dict::FieldNameDict, name_tree) =
471584
FieldNameDict(replace_name_tree(keys(dict), name_tree), values(dict))
472585

@@ -776,8 +889,8 @@ function Base.Broadcast.broadcasted(
776889
)
777890
product_value = scaling_value(entry1) * scaling_value(entry2)
778891
product_value isa Number ?
779-
UniformScaling(product_value) :
780-
DiagonalMatrixRow(product_value)
892+
(UniformScaling(product_value),) :
893+
(DiagonalMatrixRow(product_value),)
781894
elseif entry1 isa ScalingFieldMatrixEntry
782895
Base.Broadcast.broadcasted(*, (scaling_value(entry1),), entry2)
783896
elseif entry2 isa ScalingFieldMatrixEntry

0 commit comments

Comments
 (0)