
Expand API: gemm, by-key reductions, meanVar, index ops, type fixes #68

Open
dmjio wants to merge 3 commits into master from feature/api-improvements-and-new-functions

dmjio commented Jun 5, 2026

Summary

This PR adds several new functions, fixes type errors and bugs, hardens the FFI layer, and expands test coverage.

New API surface

  • gemm (BLAS): General matrix multiply C = α·op(A)·op(B) + β·C, bound to af_gemm. Useful for iterative eigenvalue algorithms (Jacobi rotations, power iteration) where accumulated orthogonal transformations need scaling.
  • Key-value (segmented) reductions (Algorithm): sumByKey, sumByKeyNaN, productByKey, productByKeyNaN, minByKey, maxByKey, allTrueByKey, anyTrueByKey, countByKey — all bound to their af_*_by_key C counterparts. These enable sparse tensor contractions and grouped reductions needed for MPO sweeps in tensor network methods.
  • meanVar / meanVarWeighted (Statistics): simultaneous mean+variance in one pass via af_meanvar. Introduces the VarBias type (VarianceDefault | VarianceSample | VariancePopulation).
  • assignSeq, indexGen, assignGen (Index): three functions that were previously error "Not implemented" stubs, now fully implemented via af_assign_seq, af_index_gen, af_assign_gen.

Type corrections and bug fixes

  • imin, imax, sortIndex, topk: index output changed from Array a to Array Word32 (matching ArrayFire's u32 contract).
  • afBackendCpu was bound to AF_BACKEND_DEFAULT instead of AF_BACKEND_CPU.
  • toConnectivity: AFConnectivity 8 mapped to Conn4 instead of Conn8.
  • AFIndex Storable peek: Left/Right branches were swapped (seq vs array pointer).
  • histogram: spurious double-cast removed.
  • span renamed to afSpan to stop shadowing Prelude.span.
  • op1 generalised from Array a -> ... -> Array a to Array a -> ... -> Array b; op1d removed.
  • op2p return type generalised to (Array a, Array b).
  • af_cast qualified in Arith.hs to resolve GHC 9.10 ambiguous occurrence error.

FFI hardening

  • All unsafePerformIO helpers in FFI.hs now use mask_ to prevent async exceptions from leaving output pointers uninitialised.
  • op2p2kv added for the key-value two-output calling convention (handles the Int s32/s64 key casting transparently).

Num/Floating fixes (Orphans.hs)

  • negate simplified to scalar (-1) `mul` arr.
  • Eq compares dimensions before broadcasting.
  • pi uses full IEEE 754 precision via realToFrac Prelude.pi.
  • NFData (Array a) instance added.

Documentation

Haddock constructor comments added to all major sum types in Internal/Types.hsc. Fixed stale parameter docs in drawVectorField2d.

Tests

Full test coverage added or corrected for all new and fixed functions. shouldBeApprox rewritten to use numpy-compatible tolerances (rtol=1e-5, atol=1e-8).

Test plan

  • cabal test passes (Algorithm, BLAS, Index, LAPACK, Statistics specs)
  • New gemm tests cover identity, alpha-scaling, and transpose cases
  • New *ByKey tests cover sum, product, min, max, count, allTrue, anyTrue
  • New meanVar tests cover population variance, sample variance, and weighted variant
  • New assignSeq/indexGen/assignGen tests cover 1D and 2D cases
  • topk and imin/imax index outputs are now correctly typed as Word32

🤖 Generated with Claude Code

…gnGen, index type fixes

## New functions

### BLAS: `gemm`
Adds `gemm :: AFType a => MatProp -> MatProp -> a -> Array a -> Array a -> a -> Array a`,
the general matrix multiply C = alpha * op(A) * op(B) + beta * C_prev.  This is more
expressive than the existing `matmul`: it supports in-place accumulation and scalar
scaling, making it directly useful for iterative eigenvalue algorithms (e.g. Jacobi
rotations) that accumulate orthogonal transformations in Q.  Implemented via the C FFI
binding `af_gemm`; scalars are passed through `Storable` alloca/poke so any `AFType`
element type is supported.  Three new unit tests cover identity scaling, alpha-scaling,
and transposition.
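
As a reading aid, the formula above can be written out on plain lists. This is only a reference sketch of the arithmetic, not the `af_gemm` binding: `Op`, `Matrix`, and `gemmRef` are hypothetical names, and `C_prev` is passed explicitly here as an assumption so the `beta` term is visible.

```haskell
import Data.List (transpose)

-- Hypothetical two-case stand-in for the MatProp arguments.
data Op = NoTrans | Trans

type Matrix = [[Double]]

applyOp :: Op -> Matrix -> Matrix
applyOp NoTrans m = m
applyOp Trans   m = transpose m

-- Row-by-column product; assumes the inner dimensions agree.
matMul :: Matrix -> Matrix -> Matrix
matMul a b = [[sum (zipWith (*) row col) | col <- transpose b] | row <- a]

-- C = alpha * op(A) * op(B) + beta * C_prev, with C_prev passed explicitly.
gemmRef :: Op -> Op -> Double -> Matrix -> Matrix -> Double -> Matrix -> Matrix
gemmRef opA opB alpha a b beta cPrev =
  zipWith (zipWith combine) (matMul (applyOp opA a) (applyOp opB b)) cPrev
  where
    combine ab c = alpha * ab + beta * c
```

With `a = [[1,2],[3,4]]` and the 2x2 identity `i`, `gemmRef NoTrans NoTrans 2 a i 1 a` evaluates to `[[3,6],[9,12]]`, i.e. `2*a + a`.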

### Algorithm: key-value (segmented) reductions
Adds nine new functions mirroring ArrayFire's `af_*_by_key` family:
  `sumByKey`, `sumByKeyNaN`, `productByKey`, `productByKeyNaN`,
  `minByKey`, `maxByKey`, `allTrueByKey`, `anyTrueByKey`, `countByKey`
Each takes a keys `Array Int` and a values `Array a`, performs the named reduction over
contiguous equal-key runs along a given dimension, and returns `(Array Int, Array a)`.
These are essential for sparse tensor contractions that arise in many-body quantum
systems and tensor network methods (e.g. grouping indices in an MPO sweep).
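
The "contiguous equal-key runs" semantics can be illustrated on ordinary lists. The sketch below is a pure reference model under that reading, not the ArrayFire-backed functions; `reduceByKey` is a hypothetical name and the dimension argument is omitted.

```haskell
-- Reduce each contiguous run of equal keys with the given binary operator.
-- Keys are not sorted first: only adjacent equal keys form one segment.
reduceByKey :: Eq k => (v -> v -> v) -> [k] -> [v] -> ([k], [v])
reduceByKey f keys vals = unzip (go (zip keys vals))
  where
    go [] = []
    go ((k, v) : rest) =
      let (run, others) = span ((== k) . fst) rest
      in (k, foldl f v (map snd run)) : go others
```

For example, `reduceByKey (+) [0,0,1,1,1,0] [1,2,3,4,5,6]` gives `([0,1,0],[3,12,6])`: the trailing `0` key starts a new segment rather than merging with the first one.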

A new internal FFI helper `op2p2kv` handles the keys–values two-output calling
convention.  Because ArrayFire requires the key array to be `s32` (C int) while
Haskell uses `Int` (typically `s64`), the helper casts input keys to `s32` before
calling the C function and casts the output keys back to `s64`, keeping the Haskell
API uniform at `Array Int`.
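
One consequence of the `s64` to `s32` round trip is that keys outside the 32-bit range cannot survive a plain narrowing cast. Whether `op2p2kv` checks for this is not stated here; the sketch below only shows what a checked narrowing could look like on the Haskell side, with all three function names being hypothetical.

```haskell
import Data.Int (Int32)

-- Narrow one key to 32 bits, returning Nothing instead of silently wrapping.
keyToS32 :: Int -> Maybe Int32
keyToS32 k
  | k < lo || k > hi = Nothing
  | otherwise        = Just (fromIntegral k)
  where
    lo = fromIntegral (minBound :: Int32)
    hi = fromIntegral (maxBound :: Int32)

-- Widening back to Int is always lossless.
keyFromS32 :: Int32 -> Int
keyFromS32 = fromIntegral

-- Convert a whole key column, failing if any single key is out of range.
keysToS32 :: [Int] -> Maybe [Int32]
keysToS32 = traverse keyToS32
```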

### Statistics: `meanVar` and `meanVarWeighted`
Adds `meanVar :: AFType a => Array a -> VarBias -> Int -> (Array a, Array a)` and its
weighted variant, bound to `af_meanvar`.  Computing mean and variance in a single pass
is both more accurate and more efficient than calling them separately, which matters
for normalisation steps in quantum state tomography and Hamiltonian learning.

Introduces the `VarBias` high-level type (`VarianceDefault | VarianceSample |
VariancePopulation`) backed by the previously-commented-out `AFVarBias` newtype in
`Internal/Defines.hsc` (now uncommented and given a `Storable` instance).  `VarBias`
and its conversion `fromVarBias` are exported from `ArrayFire.Types`.
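
To make the single-pass idea concrete, here is a list-based sketch using Welford's update. It is not a description of what `af_meanvar` does internally; the local `VarBias` copy, `Acc`, `step`, and `meanVarRef` are illustrative names, and treating `VarianceDefault` as population variance is an assumption.

```haskell
import Data.List (foldl')

-- Local stand-in mirroring the constructors named above.
data VarBias = VarianceDefault | VarianceSample | VariancePopulation

-- Running state: count, mean, and sum of squared deviations (M2).
data Acc = Acc !Int !Double !Double

-- Welford's update: fold one observation into the running state.
step :: Acc -> Double -> Acc
step (Acc n mean m2) x = Acc n' mean' m2'
  where
    n'    = n + 1
    delta = x - mean
    mean' = mean + delta / fromIntegral n'
    m2'   = m2 + delta * (x - mean')

-- Assumption: VarianceDefault is treated as population variance here.
meanVarRef :: VarBias -> [Double] -> Maybe (Double, Double)
meanVarRef bias xs = case foldl' step (Acc 0 0 0) xs of
  Acc 0 _ _ -> Nothing
  Acc n mean m2 -> case bias of
    VarianceSample
      | n < 2     -> Nothing
      | otherwise -> Just (mean, m2 / fromIntegral (n - 1))
    _ -> Just (mean, m2 / fromIntegral n)
```

For the input `[1,2,3,4]` the mean is 2.5, the population variance is 1.25, and the sample variance is 5/3.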

### Index: `assignSeq`, `indexGen`, `assignGen`; rename `span` → `afSpan`
Implements three functions that were previously stubs (`error "Not implemented"`):

- `assignSeq :: Array a -> [Seq] -> Array a -> Array a` — write a source array into a
  sequential slice of a destination array, bound to `af_assign_seq`.
- `indexGen :: Array a -> [Index] -> Array a` — generalised indexing by a list of
  `Index` values (sequence or array), bound to `af_index_gen`.
- `assignGen :: Array a -> [Index] -> Array a -> Array a` — generalised slice
  assignment, bound to `af_assign_gen`.

These are needed for constructing sparse interaction terms (e.g. projecting onto a
subspace defined by an index set).
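
The read and write halves of sequence indexing can be modelled on a 1D list. This is a sketch of the intended semantics only: `Range` is a hypothetical stand-in for `Seq`, with an inclusive end and a positive step assumed, and neither function touches the FFI.

```haskell
-- Hypothetical stand-in for Seq: first index, last index (inclusive), step.
-- The step is assumed to be positive.
data Range = Range Int Int Int

positions :: Range -> [Int]
positions (Range b e s) = [b, b + s .. e]

-- Read the elements selected by the range (indexing analogue).
indexRange :: [a] -> Range -> [a]
indexRange xs r = [xs !! i | i <- positions r]

-- Write src into the selected positions of dst (assignment analogue);
-- positions not covered by the range keep their old value.
assignRange :: [a] -> Range -> [a] -> [a]
assignRange dst r src = [maybe x id (lookup i updates) | (i, x) <- zip [0 ..] dst]
  where
    updates = zip (positions r) src
```

`assignRange [0,0,0,0,0] (Range 1 3 1) [7,8,9]` yields `[0,7,8,9,0]`, and `indexRange [10,20,30,40,50] (Range 0 4 2)` yields `[10,30,50]`.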

`span` is renamed to `afSpan` to avoid shadowing `Prelude.span`, which caused silent
import errors in downstream modules.

## Type corrections and bug fixes

### `Index` type redesign (`Internal/Types.hsc`)
The `Index a` type (which was parameterised over the array element type) is replaced by a
simpler unparameterised sum type:
  `data Index = SeqIndex Bool Seq | ArrIndex Bool (Array Int)`
This removes a phantom type parameter that was never meaningful (index arrays are
always integral), and fixes the `toAFIndex` implementation which was using
`unsafeForeignPtrToPtr` incorrectly — the old version passed a pointer whose lifetime
was not guaranteed by `withForeignPtr`.  The new version stores the raw pointer and
relies on `touchForeignPtr` calls at the use site to keep the ForeignPtr alive.

The `Storable` peek instance for `AFIndex` also had the `Left`/`Right` branches swapped
(`isSeq == True` should produce a sequence, not an array pointer); this is fixed.
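
The lifetime rule behind the `toAFIndex` fix can be shown with `base` alone. The sketch below contrasts scoped access with raw-pointer access followed by `touchForeignPtr`; the `Int` cell is a stand-in for an array handle, and `newCell`, `readScoped`, and `readRaw` are hypothetical names.

```haskell
import Foreign.ForeignPtr (ForeignPtr, mallocForeignPtr, touchForeignPtr, withForeignPtr)
import Foreign.ForeignPtr.Unsafe (unsafeForeignPtrToPtr)
import Foreign.Storable (peek, poke)

-- Allocate a managed cell holding one Int (stand-in for an array handle).
newCell :: Int -> IO (ForeignPtr Int)
newCell x = do
  fp <- mallocForeignPtr
  withForeignPtr fp (\p -> poke p x)
  pure fp

-- Scoped access: the ForeignPtr is kept alive for the whole callback.
readScoped :: ForeignPtr Int -> IO Int
readScoped fp = withForeignPtr fp peek

-- Raw access: valid only because touchForeignPtr follows the last use of
-- the raw pointer. Without it the finalizer could run before the peek.
readRaw :: ForeignPtr Int -> IO Int
readRaw fp = do
  let raw = unsafeForeignPtrToPtr fp
  x <- peek raw
  touchForeignPtr fp
  pure x
```

The raw form moves the burden to every use site, which is why each one needs its own `touchForeignPtr`.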

### Return types for index-returning operations
`imin`, `imax`, `sortIndex`, and `topk` all return an index array.  Their return types
are corrected from `(Array a, Array a)` to `(Array a, Array Word32)`, matching
ArrayFire's documented `u32` output for index arrays.  The corresponding `op2p` helper
in `FFI.hs` is generalised from `(Array a, Array a)` to `(Array a, Array b)`.
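
The point of the type change is that the index component no longer follows the element type. A list-based sketch of an `imin`-like function shows the shape; `iminRef` is a hypothetical name, and keeping the first occurrence on ties is an assumption of this sketch.

```haskell
import Data.Word (Word32)

-- Minimum value together with its position. The index is always Word32,
-- whatever the element type is. Ties keep the first occurrence.
iminRef :: Ord a => [a] -> Maybe (a, Word32)
iminRef [] = Nothing
iminRef (x : xs) = Just (foldl pick (x, 0) (zip xs [1 ..]))
  where
    pick best@(bv, _) (v, i) = if v < bv then (v, i) else best
```

Both `iminRef [3.5, 1.25, 2.0]` and `iminRef "cab"` return an index of `1`, typed as `Word32` in each case.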

### `afBackendCpu` constant (`Internal/Defines.hsc`)
Fixed: `afBackendCpu` was mistakenly bound to `AF_BACKEND_DEFAULT` instead of
`AF_BACKEND_CPU`.

### `toConnectivity` (`Internal/Types.hsc`)
Fixed: `AFConnectivity 8` was mapped to `Conn4` instead of `Conn8`.

### `histogram` (`Image.hs`)
Removed a spurious `cast` wrapping around the `af_histogram` call; the C function
already returns `u32`, so double-casting was wrong.

## FFI infrastructure

### `op1d` removed; `op1` generalised
`op1d :: Array a -> (...) -> Array b` was an alias for `op1` but with the output type
fixed to `Array b` (different from input).  All call sites that used `op1d` (`not`,
`real`, `imag`, `count`) are migrated to `op1`.  `op1` itself is generalised from
`Array a -> ... -> Array a` to `Array a -> ... -> Array b`, making `op1d` redundant.

### `mask_` added to all `unsafePerformIO` helpers
Every `op*` helper in `FFI.hs` now wraps its `unsafePerformIO` block with `mask_`.
Without `mask_`, an asynchronous exception arriving during the FFI call can leave the
output `AFArray` pointer uninitialised, producing a segfault or a garbage `ForeignPtr`
finalization.
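
The pattern described above looks roughly like the following. This is a sketch of the shape only, not the code in `FFI.hs`: `fakeDouble` is a hypothetical pure-Haskell stand-in for a C function that writes through an output pointer, and `op1Like` is an illustrative name.

```haskell
import Control.Exception (mask_)
import Foreign.Marshal.Alloc (alloca)
import Foreign.Ptr (Ptr)
import Foreign.Storable (peek, poke)
import System.IO.Unsafe (unsafePerformIO)

-- Hypothetical stand-in for a C function that writes its result
-- through an output pointer.
fakeDouble :: Ptr Int -> Int -> IO ()
fakeDouble out x = poke out (2 * x)

-- Shape of an op-style helper: allocate the output slot, make the call,
-- and read the result back, all with asynchronous exceptions masked so
-- the slot is never read before it has been written.
op1Like :: (Ptr Int -> Int -> IO ()) -> Int -> Int
op1Like f x = unsafePerformIO . mask_ $
  alloca $ \out -> do
    f out x
    peek out
{-# NOINLINE op1Like #-}
```

`op1Like fakeDouble 21` evaluates to `42`.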

### `af_cast` disambiguation (`Arith.hs`)
`af_cast` is now qualified as `ArrayFire.Internal.Arith.af_cast` at its call site in
`cast` because `FFI.hs` also imports the same C symbol (needed for `op2p2kv`), creating
an ambiguous occurrence error under GHC 9.10.

## `Num` / `Floating` instance fixes (`Orphans.hs`)
- `negate` is simplified from an allocate-a-zero-constant approach to
  `scalar (-1) \`mul\` arr`, removing a dependency on dimension information.
- `Eq` checks now compare dimensions first before invoking `allTrueAll`,
  avoiding a broadcast-induced wrong answer when shapes differ.
- `pi` now uses `realToFrac (Prelude.pi :: Double)` instead of the hard-coded
  literal `3.14159`, gaining full IEEE 754 double precision.
- Added `NFData (Array a)` instance (shallow: evaluates the `ForeignPtr` to WHNF).
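
Two of the points above can be checked in miniature. `Arr` is a hypothetical toy array (a shape plus flat elements), with `zipWith` truncation standing in for broadcasting; it is not the `Array` type, and `piLiteralError` simply measures the old literal against `Prelude.pi`.

```haskell
-- Hypothetical miniature array: a shape plus a flat list of elements.
data Arr = Arr [Int] [Double]

-- Compare shapes first; only equal shapes reach the elementwise check.
-- Without the guard, zipWith would stop at the shorter list and report
-- a one-element array as equal to a longer array of the same value.
instance Eq Arr where
  Arr d1 xs == Arr d2 ys = d1 == d2 && and (zipWith (==) xs ys)

-- Gap between the old hard-coded literal and pi at Double precision.
piLiteralError :: Double
piLiteralError = abs (3.14159 - pi)
```

Here `Arr [1] [1] == Arr [3] [1,1,1]` is `False`, and `piLiteralError` is roughly `2.65e-6`.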

## Documentation
- Haddock constructor comments added to all sum types: `Backend`, `MatProp`,
  `BinaryOp`, `Storage`, `InterpType`, `CSpace`, `YccStd`, `MomentType`,
  `CannyThreshold`, `FluxFunction`, `DiffusionEq`, `IterativeDeconvAlgo`,
  `InverseDeconvAlgo`, `Cell`, `ColorMap`, `MarkerType`, `MatchType`, `TopK`,
  `HomographyType`, and the new `VarBias`.
- Fixed stale parameter documentation in `drawVectorField2d` (previously all four
  array parameters were labelled "is the window handle").

## Tests
- `AlgorithmSpec`: seven new tests covering the `*ByKey` reductions (sum, product, min, max, count, allTrue, anyTrue).
- `BLASSpec`: three new tests for `gemm` (identity, alpha-scaling, transpose).
- `IndexSpec`: complete rewrite — `index`, `afSpan`, `lookup`, `assignSeq`,
  `indexGen`, `assignGen` each covered with multiple cases.
- `LAPACKSpec`: variable names corrected (`s,v,d` → `l,u,piv` / `q,r,tau`);
  `det` test split into real and complex cases with exact expected values;
  `inverse`, `rank`, and `norm` tests added.
- `StatisticsSpec`: `topk` index type updated to `Word32`; three new tests for
  `meanVar` (population, sample) and `meanVarWeighted`.
- `ArraySpec`: placeholder `1+1==2` replaced with a real `Array` addition test.
- `ApproxExpect`: `shouldBeApprox` rewritten to use numpy-compatible
  `|a-b| <= atol + rtol * max(|a|, |b|)` (rtol=1e-5, atol=1e-8) instead of the
  fragile scale-and-compare hack; signature now requires `Ord` and is exported cleanly.
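
The tolerance formula stated for `shouldBeApprox` can be written as a plain predicate. This is a sketch of the comparison only, not the test helper itself; `approxEq` is a hypothetical name, and the constants are the ones quoted above.

```haskell
-- |a - b| <= atol + rtol * max(|a|, |b|), with rtol = 1e-5 and atol = 1e-8.
approxEq :: (Ord a, Fractional a) => a -> a -> Bool
approxEq a b = abs (a - b) <= atol + rtol * max (abs a) (abs b)
  where
    rtol = 1e-5
    atol = 1e-8
```

The absolute term matters near zero: `approxEq 0 1e-9` holds, while `approxEq 0 1e-7` does not, because the relative term contributes almost nothing there.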

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
dmjio force-pushed the feature/api-improvements-and-new-functions branch 2 times, most recently from 9373e43 to a99e153, June 5, 2026 21:21
dmjio changed the title Expand API: gemm, by-key reductions, meanVar, index ops, type fixes Expand API: gemm, by-key reductions, meanVar, index ops, type fixes Jun 5, 2026
dmjio force-pushed the feature/api-improvements-and-new-functions branch from a99e153 to 723c64a, June 5, 2026 21:40
dmjio force-pushed the feature/api-improvements-and-new-functions branch from e43c610 to c44d1f7, June 5, 2026 23:51