Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions ext/SparseMatrixColoringsCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,15 @@ function SMC.StarSetColoringResult(
A::CuSparseMatrixCSC,
ag::SMC.AdjacencyGraph{T},
color::Vector{<:Integer},
star_set::SMC.StarSet{<:Integer},
star_set::SMC.StarSet{<:Integer};
decompression_uplo::Symbol=:F,
) where {T<:Integer}
@assert decompression_uplo == :F
group = SMC.group_by_color(T, color)
compressed_indices = SMC.star_csc_indices(ag, color, star_set)
compressed_indices = SMC.star_csc_indices(ag, color, star_set, decompression_uplo)
additional_info = (; compressed_indices_gpu_csc=CuVector(compressed_indices))
return SMC.StarSetColoringResult(
A, ag, color, group, compressed_indices, additional_info
A, ag, color, group, compressed_indices, decompression_uplo, additional_info
)
end

Expand Down Expand Up @@ -85,13 +87,15 @@ function SMC.StarSetColoringResult(
A::CuSparseMatrixCSR,
ag::SMC.AdjacencyGraph{T},
color::Vector{<:Integer},
star_set::SMC.StarSet{<:Integer},
star_set::SMC.StarSet{<:Integer};
decompression_uplo::Symbol=:F,
) where {T<:Integer}
@assert decompression_uplo == :F
group = SMC.group_by_color(T, color)
compressed_indices = SMC.star_csc_indices(ag, color, star_set)
compressed_indices = SMC.star_csc_indices(ag, color, star_set, decompression_uplo)
additional_info = (; compressed_indices_gpu_csr=CuVector(compressed_indices))
return SMC.StarSetColoringResult(
A, ag, color, group, compressed_indices, additional_info
A, ag, color, group, compressed_indices, decompression_uplo, additional_info
)
end

Expand Down
50 changes: 48 additions & 2 deletions src/decompression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -453,12 +453,18 @@ function decompress!(
check_compatible_pattern(A, ag, uplo)
fill!(A, zero(eltype(A)))

l = 0
rvS = rowvals(S)
for j in axes(S, 2)
for k in nzrange(S, j)
i = rvS[k]
if in_triangle(i, j, uplo)
A[i, j] = B[compressed_indices[k]]
if result.decompression_uplo == :F
A[i, j] = B[compressed_indices[k]]
else
l += 1
A[i, j] = B[compressed_indices[l]]
end
end
end
end
Expand All @@ -472,6 +478,7 @@ function decompress_single_color!(
result::StarSetColoringResult,
uplo::Symbol=:F,
)
@assert result.decompression_uplo == :F
(; ag, compressed_indices, group) = result
(; S) = ag
check_compatible_pattern(A, ag, uplo)
Expand Down Expand Up @@ -509,11 +516,12 @@ function decompress!(
(; S) = ag
nzA = nonzeros(A)
check_compatible_pattern(A, ag, uplo)
if uplo == :F
if result.decompression_uplo == uplo
for k in eachindex(nzA, compressed_indices)
nzA[k] = B[compressed_indices[k]]
end
else
@assert result.decompression_uplo == :F
rvS = rowvals(S)
l = 0 # assume A has the same pattern as the triangle
for j in axes(S, 2)
Expand All @@ -529,6 +537,44 @@ function decompress!(
return A
end

"""
    decompress_single_color!(
        A::SparseMatrixCSC, b::AbstractVector, c::Integer,
        result::StarSetColoringResult, uplo::Symbol=:F
    )

Overwrite, in place, the stored values of `A` that are recovered from color `c`,
reading them from the vector `b`, and return `A`.

An entry is considered to belong to color `c` when its compressed index lies in
`(c - 1) * S.n + 1 : c * S.n`, where `S = result.ag.S`. The position inside `b` is
the offset of that compressed index within this range.
NOTE(review): this presumes `b` is column `c` of a compressed matrix with `S.n`
rows stored column-major - confirm against the callers.

Two situations are handled:

- `result.decompression_uplo == uplo`: `compressed_indices` is used in lockstep with
  `nonzeros(A)`. The pattern of `A` is only checked (with `check_same_pattern`) when
  `uplo == :F`.
- otherwise `result.decompression_uplo` must be `:F` (enforced by `@assert`): the
  stored entries of `S` are traversed column by column and only those selected by
  `in_triangle(i, j, uplo)` are counted, assuming `A` stores exactly those entries in
  the same order.
"""
function decompress_single_color!(
    A::SparseMatrixCSC,
    b::AbstractVector,
    c::Integer,
    result::StarSetColoringResult,
    uplo::Symbol=:F,
)
    (; ag, compressed_indices) = result
    (; S) = ag
    # Range of compressed indices associated with color `c`.
    lower_index = (c - 1) * S.n + 1
    upper_index = c * S.n
    nzA = nonzeros(A)
    if result.decompression_uplo == uplo
        uplo == :F && check_same_pattern(A, S)
        # `eachindex` with two arguments also verifies that both vectors share
        # the same indices.
        for k in eachindex(nzA, compressed_indices)
            index = compressed_indices[k]
            if lower_index <= index <= upper_index
                nzA[k] = b[index - lower_index + 1]
            end
        end
    else
        @assert result.decompression_uplo == :F
        rvS = rowvals(S)
        l = 0  # assume A has the same pattern as the triangle
        for j in axes(S, 2)
            for k in nzrange(S, j)
                i = rvS[k]
                if in_triangle(i, j, uplo)
                    l += 1
                    index = compressed_indices[k]
                    if lower_index <= index <= upper_index
                        # The position in `b` is given by the compressed index,
                        # not by the row `i` of the entry: this loop visits every
                        # column, so the compressed index may point at either
                        # endpoint of the edge (i, j).
                        nzA[l] = b[index - lower_index + 1]
                    end
                end
            end
        end
    end
    return A
end

## TreeSetColoringResult

function decompress!(
Expand Down
35 changes: 23 additions & 12 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,10 +190,11 @@ function coloring(
A::AbstractMatrix,
problem::ColoringProblem,
algo::GreedyColoringAlgorithm;
decompression_eltype::Type{R}=Float64,
symmetric_pattern::Bool=false,
decompression_eltype::Type{R}=Float64,
decompression_uplo::Symbol=:F,
) where {R}
return _coloring(WithResult(), A, problem, algo, R, symmetric_pattern)
return _coloring(WithResult(), A, problem, algo, symmetric_pattern, R, decompression_uplo)
end

"""
Expand Down Expand Up @@ -229,8 +230,9 @@ function _coloring(
A::AbstractMatrix,
::ColoringProblem{:nonsymmetric,:column},
algo::GreedyColoringAlgorithm,
symmetric_pattern::Bool,
decompression_eltype::Type,
symmetric_pattern::Bool;
decompression_uplo::Symbol;
forced_colors::Union{AbstractVector{<:Integer},Nothing}=nothing,
)
symmetric_pattern = symmetric_pattern || A isa Union{Symmetric,Hermitian}
Expand All @@ -252,8 +254,9 @@ function _coloring(
A::AbstractMatrix,
::ColoringProblem{:nonsymmetric,:row},
algo::GreedyColoringAlgorithm,
symmetric_pattern::Bool,
decompression_eltype::Type,
symmetric_pattern::Bool;
decompression_uplo::Symbol;
forced_colors::Union{AbstractVector{<:Integer},Nothing}=nothing,
)
symmetric_pattern = symmetric_pattern || A isa Union{Symmetric,Hermitian}
Expand All @@ -275,8 +278,9 @@ function _coloring(
A::AbstractMatrix,
::ColoringProblem{:symmetric,:column},
algo::GreedyColoringAlgorithm{:direct},
symmetric_pattern::Bool,
decompression_eltype::Type,
symmetric_pattern::Bool;
decompression_uplo::Symbol;
forced_colors::Union{AbstractVector{<:Integer},Nothing}=nothing,
)
ag = AdjacencyGraph(A; augmented_graph=false)
Expand All @@ -286,7 +290,7 @@ function _coloring(
end
color, star_set = argmin(maximum ∘ first, color_and_star_set_by_order)
if speed_setting isa WithResult
return StarSetColoringResult(A, ag, color, star_set)
return StarSetColoringResult(A, ag, color, star_set; decompression_uplo)
else
return color
end
Expand All @@ -297,8 +301,9 @@ function _coloring(
A::AbstractMatrix,
::ColoringProblem{:symmetric,:column},
algo::GreedyColoringAlgorithm{:substitution},
decompression_eltype::Type{R},
symmetric_pattern::Bool,
decompression_eltype::Type{R},
decompression_uplo::Symbol,
) where {R}
ag = AdjacencyGraph(A; augmented_graph=false)
color_and_tree_set_by_order = map(algo.orders) do order
Expand All @@ -307,7 +312,7 @@ function _coloring(
end
color, tree_set = argmin(maximum ∘ first, color_and_tree_set_by_order)
if speed_setting isa WithResult
return TreeSetColoringResult(A, ag, color, tree_set, R)
return TreeSetColoringResult(A, ag, color, tree_set, R; decompression_uplo)
else
return color
end
Expand All @@ -318,8 +323,9 @@ function _coloring(
A::AbstractMatrix,
::ColoringProblem{:nonsymmetric,:bidirectional},
algo::GreedyColoringAlgorithm{:direct},
symmetric_pattern::Bool,
decompression_eltype::Type{R},
symmetric_pattern::Bool;
decompression_uplo::Symbol;
forced_colors::Union{AbstractVector{<:Integer},Nothing}=nothing,
) where {R}
A_and_Aᵀ, edge_to_index = bidirectional_pattern(A; symmetric_pattern)
Expand All @@ -345,7 +351,9 @@ function _coloring(
t -> maximum(t[3]) + maximum(t[4]), outputs_by_order
) # can't use ncolors without computing the full result
if speed_setting isa WithResult
symmetric_result = StarSetColoringResult(A_and_Aᵀ, ag, color, star_set)
symmetric_result = StarSetColoringResult(
A_and_Aᵀ, ag, color, star_set; decompression_uplo=:L
)
return BicoloringResult(
A,
ag,
Expand All @@ -366,8 +374,9 @@ function _coloring(
A::AbstractMatrix,
::ColoringProblem{:nonsymmetric,:bidirectional},
algo::GreedyColoringAlgorithm{:substitution},
decompression_eltype::Type{R},
symmetric_pattern::Bool,
decompression_eltype::Type{R},
decompression_uplo::Symbol,
) where {R}
A_and_Aᵀ, edge_to_index = bidirectional_pattern(A; symmetric_pattern)
ag = AdjacencyGraph(A_and_Aᵀ, edge_to_index, 0; augmented_graph=true)
Expand All @@ -390,7 +399,9 @@ function _coloring(
t -> maximum(t[3]) + maximum(t[4]), outputs_by_order
) # can't use ncolors without computing the full result
if speed_setting isa WithResult
symmetric_result = TreeSetColoringResult(A_and_Aᵀ, ag, color, tree_set, R)
symmetric_result = TreeSetColoringResult(
A_and_Aᵀ, ag, color, tree_set, R; decompression_uplo=:L
)
return BicoloringResult(
A,
ag,
Expand Down
Loading
Loading