mirror of
https://github.com/IBM/fp-go.git
synced 2026-01-13 00:44:11 +02:00
Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7f2e76dd94 | ||
|
|
77965a12ff | ||
|
|
ed77bd7971 | ||
|
|
f154790d88 | ||
|
|
e010f13dce | ||
|
|
86a260a204 |
623
v2/circuitbreaker/circuitbreaker.go
Normal file
623
v2/circuitbreaker/circuitbreaker.go
Normal file
@@ -0,0 +1,623 @@
|
||||
package circuitbreaker
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/IBM/fp-go/v2/either"
|
||||
"github.com/IBM/fp-go/v2/function"
|
||||
F "github.com/IBM/fp-go/v2/function"
|
||||
"github.com/IBM/fp-go/v2/identity"
|
||||
"github.com/IBM/fp-go/v2/io"
|
||||
"github.com/IBM/fp-go/v2/ioref"
|
||||
"github.com/IBM/fp-go/v2/lazy"
|
||||
"github.com/IBM/fp-go/v2/optics/lens"
|
||||
"github.com/IBM/fp-go/v2/option"
|
||||
"github.com/IBM/fp-go/v2/pair"
|
||||
"github.com/IBM/fp-go/v2/reader"
|
||||
"github.com/IBM/fp-go/v2/retry"
|
||||
)
|
||||
|
||||
var (
|
||||
canaryRequestLens = lens.MakeLensWithName(
|
||||
func(os openState) bool { return os.canaryRequest },
|
||||
func(os openState, flag bool) openState {
|
||||
os.canaryRequest = flag
|
||||
return os
|
||||
},
|
||||
"openState.CanaryRequest",
|
||||
)
|
||||
|
||||
retryStatusLens = lens.MakeLensWithName(
|
||||
func(os openState) retry.RetryStatus { return os.retryStatus },
|
||||
func(os openState, status retry.RetryStatus) openState {
|
||||
os.retryStatus = status
|
||||
return os
|
||||
},
|
||||
"openState.RetryStatus",
|
||||
)
|
||||
|
||||
resetAtLens = lens.MakeLensWithName(
|
||||
func(os openState) time.Time { return os.resetAt },
|
||||
func(os openState, tm time.Time) openState {
|
||||
os.resetAt = tm
|
||||
return os
|
||||
},
|
||||
"openState.ResetAt",
|
||||
)
|
||||
|
||||
openedAtLens = lens.MakeLensWithName(
|
||||
func(os openState) time.Time { return os.openedAt },
|
||||
func(os openState, tm time.Time) openState {
|
||||
os.openedAt = tm
|
||||
return os
|
||||
},
|
||||
"openState.OpenedAt",
|
||||
)
|
||||
|
||||
createClosedCircuit = either.Right[openState, ClosedState]
|
||||
createOpenCircuit = either.Left[ClosedState, openState]
|
||||
|
||||
// MakeClosedIORef creates an IORef containing a closed circuit breaker state.
|
||||
// It wraps the provided ClosedState in a Right (closed) BreakerState and creates
|
||||
// a mutable reference to it.
|
||||
//
|
||||
// Parameters:
|
||||
// - closedState: The initial closed state configuration
|
||||
//
|
||||
// Returns:
|
||||
// - An IO operation that creates an IORef[BreakerState] initialized to closed state
|
||||
//
|
||||
// Thread Safety: The returned IORef[BreakerState] is thread-safe. It uses atomic
|
||||
// operations for all read/write/modify operations. The BreakerState itself is immutable.
|
||||
MakeClosedIORef = F.Flow2(
|
||||
createClosedCircuit,
|
||||
ioref.MakeIORef,
|
||||
)
|
||||
|
||||
// IsOpen checks if a BreakerState is in the open state.
|
||||
// Returns true if the circuit breaker is open (blocking requests), false otherwise.
|
||||
IsOpen = either.IsLeft[openState, ClosedState]
|
||||
|
||||
// IsClosed checks if a BreakerState is in the closed state.
|
||||
// Returns true if the circuit breaker is closed (allowing requests), false otherwise.
|
||||
IsClosed = either.IsRight[openState, ClosedState]
|
||||
|
||||
// modifyV creates a Reader that sequences an IORef modification operation.
|
||||
// It takes an IORef[BreakerState] and returns a Reader that, when given an endomorphism
|
||||
// (a function from BreakerState to BreakerState), produces an IO operation that modifies
|
||||
// the IORef and returns the new state.
|
||||
//
|
||||
// This is used internally to create state modification operations that can be composed
|
||||
// with other Reader-based operations in the circuit breaker logic.
|
||||
//
|
||||
// Thread Safety: The IORef modification is atomic. Multiple concurrent calls will be
|
||||
// serialized by the IORef's atomic operations.
|
||||
//
|
||||
// Type signature: Reader[IORef[BreakerState], IO[Endomorphism[BreakerState]]]
|
||||
modifyV = reader.Sequence(ioref.Modify[BreakerState])
|
||||
|
||||
initialRetry = retry.DefaultRetryStatus
|
||||
|
||||
// testCircuit sets the canaryRequest flag to true in an openState.
|
||||
// This is used to mark that the circuit breaker is in half-open state,
|
||||
// allowing a single test request (canary) to check if the service has recovered.
|
||||
//
|
||||
// When canaryRequest is true:
|
||||
// - One request is allowed through to test the service
|
||||
// - If the canary succeeds, the circuit closes
|
||||
// - If the canary fails, the circuit remains open with an extended reset time
|
||||
//
|
||||
// Thread Safety: This is a pure function that returns a new openState; it does not
|
||||
// modify its input. Safe for concurrent use.
|
||||
//
|
||||
// Type signature: Endomorphism[openState]
|
||||
testCircuit = canaryRequestLens.Set(true)
|
||||
)
|
||||
|
||||
// makeOpenCircuitFromPolicy creates a function that constructs an openState from a retry policy.
|
||||
// This is a curried function that takes a retry policy and returns a function that takes a retry status
|
||||
// and current time to produce an openState with calculated reset time.
|
||||
//
|
||||
// The function applies the retry policy to determine the next retry delay and calculates
|
||||
// the resetAt time by adding the delay to the current time. If no previous delay exists
|
||||
// (first failure), the resetAt is set to the current time.
|
||||
//
|
||||
// Parameters:
|
||||
// - policy: The retry policy that determines backoff strategy (e.g., exponential backoff)
|
||||
//
|
||||
// Returns:
|
||||
// - A curried function that takes:
|
||||
// 1. rs (retry.RetryStatus): The current retry status containing retry count and previous delay
|
||||
// 2. ct (time.Time): The current time when the circuit is opening
|
||||
// And returns an openState with:
|
||||
// - openedAt: Set to the current time (ct)
|
||||
// - resetAt: Current time plus the delay from the retry policy
|
||||
// - retryStatus: The updated retry status from applying the policy
|
||||
// - canaryRequest: false (will be set to true when reset time is reached)
|
||||
//
|
||||
// Thread Safety: This is a pure function that creates new openState instances.
|
||||
// Safe for concurrent use.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// policy := retry.ExponentialBackoff(1*time.Second, 2.0, 10)
|
||||
// makeOpen := makeOpenCircuitFromPolicy(policy)
|
||||
// openState := makeOpen(retry.DefaultRetryStatus)(time.Now())
|
||||
// // openState.resetAt will be approximately 1 second from now
|
||||
func makeOpenCircuitFromPolicy(policy retry.RetryPolicy) func(rs retry.RetryStatus) func(ct time.Time) openState {
|
||||
|
||||
return func(rs retry.RetryStatus) func(ct time.Time) openState {
|
||||
|
||||
retryStatus := retry.ApplyPolicy(policy, rs)
|
||||
|
||||
return func(ct time.Time) openState {
|
||||
|
||||
resetTime := F.Pipe2(
|
||||
retryStatus,
|
||||
retry.PreviousDelayLens.Get,
|
||||
option.Fold(
|
||||
F.Pipe1(
|
||||
ct,
|
||||
lazy.Of,
|
||||
),
|
||||
ct.Add,
|
||||
),
|
||||
)
|
||||
|
||||
return openState{openedAt: ct, resetAt: resetTime, retryStatus: retryStatus}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// extendOpenCircuitFromMakeCircuit creates a function that extends the open state of a circuit breaker
|
||||
// when a canary request fails. It takes a circuit maker function and returns a function that,
|
||||
// given the current time, produces an endomorphism that updates an openState.
|
||||
//
|
||||
// This function is used when a canary request (test request in half-open state) fails.
|
||||
// It extends the circuit breaker's open period by:
|
||||
// 1. Extracting the current retry status from the open state
|
||||
// 2. Using the makeCircuit function to calculate a new open state with updated retry status
|
||||
// 3. Applying the current time to get the new state
|
||||
// 4. Setting the canaryRequest flag to true to allow another test request later
|
||||
//
|
||||
// Parameters:
|
||||
// - makeCircuit: A function that creates an openState from a retry status and current time.
|
||||
// This is typically created by makeOpenCircuitFromPolicy.
|
||||
//
|
||||
// Returns:
|
||||
// - A curried function that takes:
|
||||
// 1. ct (time.Time): The current time when extending the circuit
|
||||
// And returns an Endomorphism[openState] that:
|
||||
// - Increments the retry count
|
||||
// - Calculates a new resetAt time based on the retry policy (typically with exponential backoff)
|
||||
// - Sets canaryRequest to true for the next test attempt
|
||||
//
|
||||
// Thread Safety: This is a pure function that returns new openState instances.
|
||||
// Safe for concurrent use.
|
||||
//
|
||||
// Usage Context:
|
||||
// - Called when a canary request fails in the half-open state
|
||||
// - Extends the open period with increased backoff delay
|
||||
// - Prepares the circuit for another canary attempt at the new resetAt time
|
||||
func extendOpenCircuitFromMakeCircuit(
|
||||
makeCircuit func(rs retry.RetryStatus) func(ct time.Time) openState,
|
||||
) func(time.Time) Endomorphism[openState] {
|
||||
return func(ct time.Time) Endomorphism[openState] {
|
||||
return F.Flow4(
|
||||
retryStatusLens.Get,
|
||||
makeCircuit,
|
||||
identity.Flap[openState](ct),
|
||||
testCircuit,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// isResetTimeExceeded checks if the reset time for an open circuit has been exceeded.
|
||||
// This is used to determine if the circuit breaker should transition from open to half-open state
|
||||
// by allowing a canary request.
|
||||
//
|
||||
// The function returns an option.Kleisli that succeeds (returns Some) only when:
|
||||
// 1. The circuit is not already in canary mode (canaryRequest is false)
|
||||
// 2. The current time is after the resetAt time
|
||||
//
|
||||
// Parameters:
|
||||
// - ct: The current time to compare against the reset time
|
||||
//
|
||||
// Returns:
|
||||
// - An option.Kleisli[openState, openState] that:
|
||||
// - Returns Some(openState) if the reset time has been exceeded and no canary is active
|
||||
// - Returns None if the reset time has not been exceeded or a canary request is already active
|
||||
//
|
||||
// Thread Safety: This is a pure function that does not modify its input.
|
||||
// Safe for concurrent use.
|
||||
//
|
||||
// Usage Context:
|
||||
// - Called when the circuit is open to check if it's time to attempt a canary request
|
||||
// - If this returns Some, the circuit transitions to half-open state (canary mode)
|
||||
// - If this returns None, the circuit remains fully open and requests are blocked
|
||||
func isResetTimeExceeded(ct time.Time) option.Kleisli[openState, openState] {
|
||||
return option.FromPredicate(func(open openState) bool {
|
||||
return !open.canaryRequest && ct.After(resetAtLens.Get(open))
|
||||
})
|
||||
}
|
||||
|
||||
// handleSuccessOnClosed handles a successful request when the circuit breaker is in closed state.
|
||||
// It updates the closed state by recording the success and returns an IO operation that
|
||||
// modifies the breaker state.
|
||||
//
|
||||
// This function is part of the circuit breaker's state management for the closed state.
|
||||
// When a request succeeds in closed state:
|
||||
// 1. The current time is obtained
|
||||
// 2. The addSuccess function is called with the current time to update the ClosedState
|
||||
// 3. The updated ClosedState is wrapped in a Right (closed) BreakerState
|
||||
// 4. The breaker state is modified with the new state
|
||||
//
|
||||
// Parameters:
|
||||
// - currentTime: An IO operation that provides the current time
|
||||
// - addSuccess: A Reader that takes a time and returns an endomorphism for ClosedState,
|
||||
// typically resetting failure counters or history
|
||||
//
|
||||
// Returns:
|
||||
// - An io.Kleisli that takes another io.Kleisli and chains them together.
|
||||
// The outer Kleisli takes an Endomorphism[BreakerState] and returns BreakerState.
|
||||
// This allows composing the success handling with other state modifications.
|
||||
//
|
||||
// Thread Safety: This function creates IO operations that will atomically modify the
|
||||
// IORef[BreakerState] when executed. The state modifications are thread-safe.
|
||||
//
|
||||
// Type signature:
|
||||
//
|
||||
// io.Kleisli[io.Kleisli[Endomorphism[BreakerState], BreakerState], BreakerState]
|
||||
//
|
||||
// Usage Context:
|
||||
// - Called when a request succeeds while the circuit is closed
|
||||
// - Resets failure tracking (counter or history) in the ClosedState
|
||||
// - Keeps the circuit in closed state
|
||||
func handleSuccessOnClosed(
|
||||
currentTime IO[time.Time],
|
||||
addSuccess Reader[time.Time, Endomorphism[ClosedState]],
|
||||
) io.Kleisli[io.Kleisli[Endomorphism[BreakerState], BreakerState], BreakerState] {
|
||||
return F.Flow2(
|
||||
io.Chain,
|
||||
identity.Flap[IO[BreakerState]](F.Pipe1(
|
||||
currentTime,
|
||||
io.Map(F.Flow2(
|
||||
addSuccess,
|
||||
either.Map[openState],
|
||||
)))),
|
||||
)
|
||||
}
|
||||
|
||||
// handleFailureOnClosed handles a failed request when the circuit breaker is in closed state.
|
||||
// It updates the closed state by recording the failure and checks if the circuit should open.
|
||||
//
|
||||
// This function is part of the circuit breaker's state management for the closed state.
|
||||
// When a request fails in closed state:
|
||||
// 1. The current time is obtained
|
||||
// 2. The addError function is called to record the failure in the ClosedState
|
||||
// 3. The checkClosedState function is called to determine if the failure threshold is exceeded
|
||||
// 4. If the threshold is exceeded (Check returns None):
|
||||
// - The circuit transitions to open state using openCircuit
|
||||
// - A new openState is created with resetAt time calculated from the retry policy
|
||||
// 5. If the threshold is not exceeded (Check returns Some):
|
||||
// - The circuit remains closed with the updated failure tracking
|
||||
//
|
||||
// Parameters:
|
||||
// - currentTime: An IO operation that provides the current time
|
||||
// - addError: A Reader that takes a time and returns an endomorphism for ClosedState,
|
||||
// recording a failure (incrementing counter or adding to history)
|
||||
// - checkClosedState: A Reader that takes a time and returns an option.Kleisli that checks
|
||||
// if the ClosedState should remain closed. Returns Some if circuit stays closed, None if it should open.
|
||||
// - openCircuit: A Reader that takes a time and returns an openState with calculated resetAt time
|
||||
//
|
||||
// Returns:
|
||||
// - An io.Kleisli that takes another io.Kleisli and chains them together.
|
||||
// The outer Kleisli takes an Endomorphism[BreakerState] and returns BreakerState.
|
||||
// This allows composing the failure handling with other state modifications.
|
||||
//
|
||||
// Thread Safety: This function creates IO operations that will atomically modify the
|
||||
// IORef[BreakerState] when executed. The state modifications are thread-safe.
|
||||
//
|
||||
// Type signature:
|
||||
//
|
||||
// io.Kleisli[io.Kleisli[Endomorphism[BreakerState], BreakerState], BreakerState]
|
||||
//
|
||||
// State Transitions:
|
||||
// - Closed -> Closed: When failure threshold is not exceeded (Some from checkClosedState)
|
||||
// - Closed -> Open: When failure threshold is exceeded (None from checkClosedState)
|
||||
//
|
||||
// Usage Context:
|
||||
// - Called when a request fails while the circuit is closed
|
||||
// - Records the failure in the ClosedState (counter or history)
|
||||
// - May trigger transition to open state if threshold is exceeded
|
||||
func handleFailureOnClosed(
|
||||
currentTime IO[time.Time],
|
||||
addError Reader[time.Time, Endomorphism[ClosedState]],
|
||||
checkClosedState Reader[time.Time, option.Kleisli[ClosedState, ClosedState]],
|
||||
openCircuit Reader[time.Time, openState],
|
||||
) io.Kleisli[io.Kleisli[Endomorphism[BreakerState], BreakerState], BreakerState] {
|
||||
|
||||
return F.Flow2(
|
||||
io.Chain,
|
||||
identity.Flap[IO[BreakerState]](F.Pipe1(
|
||||
currentTime,
|
||||
io.Map(func(ct time.Time) either.Operator[openState, ClosedState, ClosedState] {
|
||||
return either.Chain(F.Flow3(
|
||||
addError(ct),
|
||||
checkClosedState(ct),
|
||||
option.Fold(
|
||||
F.Pipe2(
|
||||
ct,
|
||||
lazy.Of,
|
||||
lazy.Map(F.Flow2(
|
||||
openCircuit,
|
||||
createOpenCircuit,
|
||||
)),
|
||||
),
|
||||
createClosedCircuit,
|
||||
),
|
||||
))
|
||||
}))),
|
||||
)
|
||||
|
||||
}
|
||||
|
||||
// MakeCircuitBreaker creates a circuit breaker implementation for a higher-kinded type.
|
||||
//
|
||||
// This is a generic circuit breaker factory that works with any monad-like type (HKTT).
|
||||
// It implements the circuit breaker pattern by wrapping operations and managing state transitions
|
||||
// between closed, open, and half-open states based on failure rates and retry policies.
|
||||
//
|
||||
// Type Parameters:
|
||||
// - E: The error type
|
||||
// - T: The success value type
|
||||
// - HKTT: The higher-kinded type representing the computation (e.g., IO[T], ReaderIO[R, T])
|
||||
// - HKTOP: The higher-kinded type for operators (e.g., IO[func(HKTT) HKTT])
|
||||
// - HKTHKTT: The nested higher-kinded type (e.g., IO[IO[T]])
|
||||
//
|
||||
// Parameters:
|
||||
// - left: Constructs an error result in HKTT from an error value
|
||||
// - chainFirstIOK: Chains an IO operation that runs after success, preserving the original value
|
||||
// - chainFirstLeftIOK: Chains an IO operation that runs after error, preserving the original error
|
||||
// - fromIO: Lifts an IO operation into HKTOP
|
||||
// - flap: Applies a value to a function wrapped in a higher-kinded type
|
||||
// - flatten: Flattens nested higher-kinded types (join operation)
|
||||
// - currentTime: IO operation that provides the current time
|
||||
// - closedState: The initial closed state configuration
|
||||
// - makeError: Creates an error from a reset time when the circuit is open
|
||||
// - checkError: Predicate to determine if an error should trigger circuit breaker logic
|
||||
// - policy: Retry policy for determining reset times when circuit opens
|
||||
// - logger: Logging function for circuit breaker events
|
||||
//
|
||||
// Thread Safety: The returned State monad creates operations that are thread-safe when
|
||||
// executed. The IORef[BreakerState] uses atomic operations for all state modifications.
|
||||
// Multiple concurrent requests will be properly serialized at the IORef level.
|
||||
//
|
||||
// Returns:
|
||||
// - A State monad that transforms a pair of (IORef[BreakerState], HKTT) into HKTT,
|
||||
// applying circuit breaker logic to the computation
|
||||
func MakeCircuitBreaker[E, T, HKTT, HKTOP, HKTHKTT any](
|
||||
|
||||
left func(E) HKTT,
|
||||
chainFirstIOK func(io.Kleisli[T, BreakerState]) func(HKTT) HKTT,
|
||||
chainFirstLeftIOK func(io.Kleisli[E, BreakerState]) func(HKTT) HKTT,
|
||||
|
||||
fromIO func(IO[func(HKTT) HKTT]) HKTOP,
|
||||
flap func(HKTT) func(HKTOP) HKTHKTT,
|
||||
flatten func(HKTHKTT) HKTT,
|
||||
|
||||
currentTime IO[time.Time],
|
||||
closedState ClosedState,
|
||||
makeError Reader[time.Time, E],
|
||||
checkError option.Kleisli[E, E],
|
||||
policy retry.RetryPolicy,
|
||||
metrics Metrics,
|
||||
) State[Pair[IORef[BreakerState], HKTT], HKTT] {
|
||||
|
||||
type Operator = func(HKTT) HKTT
|
||||
|
||||
addSuccess := reader.From1(ClosedState.AddSuccess)
|
||||
addError := reader.From1(ClosedState.AddError)
|
||||
checkClosedState := reader.From1(ClosedState.Check)
|
||||
|
||||
closedCircuit := createClosedCircuit(closedState.Empty())
|
||||
makeOpenCircuit := makeOpenCircuitFromPolicy(policy)
|
||||
|
||||
openCircuit := F.Pipe1(
|
||||
initialRetry,
|
||||
makeOpenCircuit,
|
||||
)
|
||||
|
||||
extendOpenCircuit := extendOpenCircuitFromMakeCircuit(makeOpenCircuit)
|
||||
|
||||
failWithError := F.Flow4(
|
||||
resetAtLens.Get,
|
||||
makeError,
|
||||
left,
|
||||
reader.Of[HKTT],
|
||||
)
|
||||
|
||||
handleSuccess := handleSuccessOnClosed(currentTime, addSuccess)
|
||||
handleFailure := handleFailureOnClosed(currentTime, addError, checkClosedState, openCircuit)
|
||||
|
||||
onClosed := func(modify io.Kleisli[Endomorphism[BreakerState], BreakerState]) Operator {
|
||||
|
||||
return F.Flow2(
|
||||
// error case
|
||||
chainFirstLeftIOK(F.Flow3(
|
||||
checkError,
|
||||
option.Fold(
|
||||
// the error is not applicable, handle as success
|
||||
F.Pipe2(
|
||||
modify,
|
||||
handleSuccess,
|
||||
lazy.Of,
|
||||
),
|
||||
// the error is relevant, record it
|
||||
F.Pipe2(
|
||||
modify,
|
||||
handleFailure,
|
||||
reader.Of[E],
|
||||
),
|
||||
),
|
||||
// metering
|
||||
io.ChainFirst(either.Fold(
|
||||
F.Flow2(
|
||||
openedAtLens.Get,
|
||||
metrics.Open,
|
||||
),
|
||||
func(c ClosedState) IO[Void] {
|
||||
return io.Of(function.VOID)
|
||||
},
|
||||
)),
|
||||
)),
|
||||
// good case
|
||||
chainFirstIOK(F.Pipe2(
|
||||
modify,
|
||||
handleSuccess,
|
||||
reader.Of[T],
|
||||
)),
|
||||
)
|
||||
}
|
||||
|
||||
onCanary := func(modify io.Kleisli[Endomorphism[BreakerState], BreakerState]) Operator {
|
||||
|
||||
handleSuccess := F.Pipe2(
|
||||
closedCircuit,
|
||||
reader.Of[BreakerState],
|
||||
modify,
|
||||
)
|
||||
|
||||
return F.Flow2(
|
||||
// the canary request fails
|
||||
chainFirstLeftIOK(F.Flow2(
|
||||
checkError,
|
||||
option.Fold(
|
||||
// the canary request succeeds, we close the circuit
|
||||
F.Pipe1(
|
||||
handleSuccess,
|
||||
lazy.Of,
|
||||
),
|
||||
// the canary request fails, we extend the circuit
|
||||
F.Pipe1(
|
||||
F.Pipe1(
|
||||
currentTime,
|
||||
io.Chain(func(ct time.Time) IO[BreakerState] {
|
||||
return F.Pipe1(
|
||||
F.Flow2(
|
||||
either.Fold(
|
||||
extendOpenCircuit(ct),
|
||||
F.Pipe1(
|
||||
openCircuit(ct),
|
||||
reader.Of[ClosedState],
|
||||
),
|
||||
),
|
||||
createOpenCircuit,
|
||||
),
|
||||
modify,
|
||||
)
|
||||
}),
|
||||
),
|
||||
reader.Of[E],
|
||||
),
|
||||
),
|
||||
)),
|
||||
// the canary request succeeds, we'll close the circuit
|
||||
chainFirstIOK(F.Pipe1(
|
||||
handleSuccess,
|
||||
reader.Of[T],
|
||||
)),
|
||||
)
|
||||
}
|
||||
|
||||
onOpen := func(ref IORef[BreakerState]) Operator {
|
||||
|
||||
modify := modifyV(ref)
|
||||
|
||||
return F.Pipe3(
|
||||
currentTime,
|
||||
io.Chain(func(ct time.Time) IO[Operator] {
|
||||
return F.Pipe1(
|
||||
ref,
|
||||
ioref.ModifyWithResult(either.Fold(
|
||||
func(open openState) Pair[BreakerState, Operator] {
|
||||
return option.Fold(
|
||||
func() Pair[BreakerState, Operator] {
|
||||
return pair.MakePair(createOpenCircuit(open), failWithError(open))
|
||||
},
|
||||
func(open openState) Pair[BreakerState, Operator] {
|
||||
return pair.MakePair(createOpenCircuit(testCircuit(open)), onCanary(modify))
|
||||
},
|
||||
)(isResetTimeExceeded(ct)(open))
|
||||
},
|
||||
func(closed ClosedState) Pair[BreakerState, Operator] {
|
||||
return pair.MakePair(createClosedCircuit(closed), onClosed(modify))
|
||||
},
|
||||
)),
|
||||
)
|
||||
}),
|
||||
fromIO,
|
||||
func(src HKTOP) Operator {
|
||||
return func(rdr HKTT) HKTT {
|
||||
return F.Pipe2(
|
||||
src,
|
||||
flap(rdr),
|
||||
flatten,
|
||||
)
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
return func(e Pair[IORef[BreakerState], HKTT]) Pair[Pair[IORef[BreakerState], HKTT], HKTT] {
|
||||
return pair.MakePair(e, onOpen(pair.Head(e))(pair.Tail(e)))
|
||||
}
|
||||
}
|
||||
|
||||
// MakeSingletonBreaker creates a singleton circuit breaker operator for a higher-kinded type.
|
||||
//
|
||||
// This function creates a circuit breaker that maintains its own internal state reference.
|
||||
// It's called "singleton" because it creates a single, self-contained circuit breaker instance
|
||||
// with its own IORef for state management. The returned function can be used to wrap
|
||||
// computations with circuit breaker protection.
|
||||
//
|
||||
// Type Parameters:
|
||||
// - HKTT: The higher-kinded type representing the computation (e.g., IO[T], ReaderIO[R, T])
|
||||
//
|
||||
// Parameters:
|
||||
// - cb: The circuit breaker State monad created by MakeCircuitBreaker
|
||||
// - closedState: The initial closed state configuration for the circuit breaker
|
||||
//
|
||||
// Returns:
|
||||
// - A function that wraps a computation (HKTT) with circuit breaker logic.
|
||||
// The circuit breaker state is managed internally and persists across invocations.
|
||||
//
|
||||
// Thread Safety: The returned function is thread-safe. The internal IORef[BreakerState]
|
||||
// uses atomic operations to manage state. Multiple concurrent calls to the returned function
|
||||
// will be properly serialized at the state modification level.
|
||||
//
|
||||
// Example Usage:
|
||||
//
|
||||
// // Create a circuit breaker for IO operations
|
||||
// breaker := MakeSingletonBreaker(
|
||||
// MakeCircuitBreaker(...),
|
||||
// MakeClosedStateCounter(3),
|
||||
// )
|
||||
//
|
||||
// // Use it to wrap operations
|
||||
// protectedOp := breaker(myIOOperation)
|
||||
func MakeSingletonBreaker[HKTT any](
|
||||
cb State[Pair[IORef[BreakerState], HKTT], HKTT],
|
||||
closedState ClosedState,
|
||||
) func(HKTT) HKTT {
|
||||
return F.Flow3(
|
||||
F.Pipe3(
|
||||
closedState,
|
||||
MakeClosedIORef,
|
||||
io.Run,
|
||||
pair.FromHead[HKTT],
|
||||
),
|
||||
cb,
|
||||
pair.Tail,
|
||||
)
|
||||
}
|
||||
579
v2/circuitbreaker/circuitbreaker_test.go
Normal file
579
v2/circuitbreaker/circuitbreaker_test.go
Normal file
@@ -0,0 +1,579 @@
|
||||
package circuitbreaker
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/IBM/fp-go/v2/function"
|
||||
F "github.com/IBM/fp-go/v2/function"
|
||||
"github.com/IBM/fp-go/v2/io"
|
||||
"github.com/IBM/fp-go/v2/ioref"
|
||||
"github.com/IBM/fp-go/v2/option"
|
||||
"github.com/IBM/fp-go/v2/reader"
|
||||
"github.com/IBM/fp-go/v2/retry"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type testMetrics struct {
|
||||
accepts int
|
||||
rejects int
|
||||
opens int
|
||||
closes int
|
||||
canary int
|
||||
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (m *testMetrics) Accept(_ time.Time) IO[Void] {
|
||||
return func() Void {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.accepts++
|
||||
return function.VOID
|
||||
}
|
||||
}
|
||||
|
||||
func (m *testMetrics) Open(_ time.Time) IO[Void] {
|
||||
return func() Void {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.opens++
|
||||
return function.VOID
|
||||
}
|
||||
}
|
||||
|
||||
func (m *testMetrics) Close(_ time.Time) IO[Void] {
|
||||
return func() Void {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.closes++
|
||||
return function.VOID
|
||||
}
|
||||
}
|
||||
|
||||
func (m *testMetrics) Reject(_ time.Time) IO[Void] {
|
||||
return func() Void {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.rejects++
|
||||
return function.VOID
|
||||
}
|
||||
}
|
||||
|
||||
func (m *testMetrics) Canary(_ time.Time) IO[Void] {
|
||||
return func() Void {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.canary++
|
||||
return function.VOID
|
||||
}
|
||||
}
|
||||
|
||||
// VirtualTimer provides a controllable time source for testing
|
||||
type VirtualTimer struct {
|
||||
mu sync.Mutex
|
||||
current time.Time
|
||||
}
|
||||
|
||||
func NewMockMetrics() Metrics {
|
||||
return &testMetrics{}
|
||||
}
|
||||
|
||||
// NewVirtualTimer creates a new virtual timer starting at the given time
|
||||
func NewVirtualTimer(start time.Time) *VirtualTimer {
|
||||
return &VirtualTimer{current: start}
|
||||
}
|
||||
|
||||
// Now returns the current virtual time
|
||||
func (vt *VirtualTimer) Now() time.Time {
|
||||
vt.mu.Lock()
|
||||
defer vt.mu.Unlock()
|
||||
return vt.current
|
||||
}
|
||||
|
||||
// Advance moves the virtual time forward by the given duration
|
||||
func (vt *VirtualTimer) Advance(d time.Duration) {
|
||||
vt.mu.Lock()
|
||||
defer vt.mu.Unlock()
|
||||
vt.current = vt.current.Add(d)
|
||||
}
|
||||
|
||||
// Set sets the virtual time to a specific value
|
||||
func (vt *VirtualTimer) Set(t time.Time) {
|
||||
vt.mu.Lock()
|
||||
defer vt.mu.Unlock()
|
||||
vt.current = t
|
||||
}
|
||||
|
||||
// TestModifyV tests the modifyV variable
|
||||
func TestModifyV(t *testing.T) {
|
||||
t.Run("modifyV creates a Reader that modifies IORef", func(t *testing.T) {
|
||||
vt := NewVirtualTimer(time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC))
|
||||
|
||||
// Create initial state
|
||||
initialState := createClosedCircuit(MakeClosedStateCounter(3))
|
||||
ref := io.Run(ioref.MakeIORef(initialState))
|
||||
|
||||
// Create an endomorphism that opens the circuit
|
||||
now := vt.Now()
|
||||
openState := openState{
|
||||
openedAt: now,
|
||||
resetAt: now.Add(1 * time.Minute),
|
||||
retryStatus: retry.DefaultRetryStatus,
|
||||
canaryRequest: false,
|
||||
}
|
||||
endomorphism := func(bs BreakerState) BreakerState {
|
||||
return createOpenCircuit(openState)
|
||||
}
|
||||
|
||||
// Apply modifyV
|
||||
modifyOp := modifyV(ref)
|
||||
result := io.Run(modifyOp(endomorphism))
|
||||
|
||||
// Verify the state was modified
|
||||
assert.True(t, IsOpen(result), "state should be open after modification")
|
||||
})
|
||||
|
||||
t.Run("modifyV returns the new state", func(t *testing.T) {
|
||||
initialState := createClosedCircuit(MakeClosedStateCounter(3))
|
||||
ref := io.Run(ioref.MakeIORef(initialState))
|
||||
|
||||
// Create a simple endomorphism
|
||||
endomorphism := F.Identity[BreakerState]
|
||||
|
||||
modifyOp := modifyV(ref)
|
||||
result := io.Run(modifyOp(endomorphism))
|
||||
|
||||
assert.True(t, IsClosed(result), "state should remain closed")
|
||||
})
|
||||
}
|
||||
|
||||
// TestTestCircuit tests the testCircuit variable
|
||||
func TestTestCircuit(t *testing.T) {
|
||||
t.Run("testCircuit sets canaryRequest to true", func(t *testing.T) {
|
||||
vt := NewVirtualTimer(time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC))
|
||||
now := vt.Now()
|
||||
|
||||
openState := openState{
|
||||
openedAt: now,
|
||||
resetAt: now.Add(1 * time.Minute),
|
||||
retryStatus: retry.DefaultRetryStatus,
|
||||
canaryRequest: false,
|
||||
}
|
||||
|
||||
result := testCircuit(openState)
|
||||
|
||||
assert.True(t, result.canaryRequest, "canaryRequest should be set to true")
|
||||
assert.Equal(t, openState.openedAt, result.openedAt, "openedAt should be unchanged")
|
||||
assert.Equal(t, openState.resetAt, result.resetAt, "resetAt should be unchanged")
|
||||
})
|
||||
|
||||
t.Run("testCircuit is idempotent", func(t *testing.T) {
|
||||
vt := NewVirtualTimer(time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC))
|
||||
now := vt.Now()
|
||||
|
||||
openState := openState{
|
||||
openedAt: now,
|
||||
resetAt: now.Add(1 * time.Minute),
|
||||
retryStatus: retry.DefaultRetryStatus,
|
||||
canaryRequest: true, // already true
|
||||
}
|
||||
|
||||
result := testCircuit(openState)
|
||||
|
||||
assert.True(t, result.canaryRequest, "canaryRequest should remain true")
|
||||
})
|
||||
|
||||
t.Run("testCircuit preserves other fields", func(t *testing.T) {
|
||||
vt := NewVirtualTimer(time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC))
|
||||
now := vt.Now()
|
||||
resetTime := now.Add(2 * time.Minute)
|
||||
retryStatus := retry.RetryStatus{
|
||||
IterNumber: 5,
|
||||
PreviousDelay: option.Some(30 * time.Second),
|
||||
}
|
||||
|
||||
openState := openState{
|
||||
openedAt: now,
|
||||
resetAt: resetTime,
|
||||
retryStatus: retryStatus,
|
||||
canaryRequest: false,
|
||||
}
|
||||
|
||||
result := testCircuit(openState)
|
||||
|
||||
assert.Equal(t, now, result.openedAt, "openedAt should be preserved")
|
||||
assert.Equal(t, resetTime, result.resetAt, "resetAt should be preserved")
|
||||
assert.Equal(t, retryStatus.IterNumber, result.retryStatus.IterNumber, "retryStatus should be preserved")
|
||||
assert.True(t, result.canaryRequest, "canaryRequest should be set to true")
|
||||
})
|
||||
}
|
||||
|
||||
// TestMakeOpenCircuitFromPolicy tests the makeOpenCircuitFromPolicy function
|
||||
func TestMakeOpenCircuitFromPolicy(t *testing.T) {
|
||||
t.Run("creates openState with calculated reset time", func(t *testing.T) {
|
||||
policy := retry.LimitRetries(5)
|
||||
makeOpen := makeOpenCircuitFromPolicy(policy)
|
||||
|
||||
currentTime := time.Date(2026, 1, 9, 12, 0, 0, 0, time.UTC)
|
||||
result := makeOpen(retry.DefaultRetryStatus)(currentTime)
|
||||
|
||||
assert.Equal(t, currentTime, result.openedAt, "openedAt should be current time")
|
||||
assert.False(t, result.canaryRequest, "canaryRequest should be false initially")
|
||||
assert.NotNil(t, result.retryStatus, "retryStatus should be set")
|
||||
})
|
||||
|
||||
t.Run("applies retry policy to calculate delay", func(t *testing.T) {
|
||||
// Use exponential backoff policy with limit and cap
|
||||
policy := retry.Monoid.Concat(
|
||||
retry.LimitRetries(10),
|
||||
retry.CapDelay(10*time.Second, retry.ExponentialBackoff(1*time.Second)),
|
||||
)
|
||||
makeOpen := makeOpenCircuitFromPolicy(policy)
|
||||
|
||||
currentTime := time.Date(2026, 1, 9, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
// First retry (iter 0)
|
||||
result1 := makeOpen(retry.DefaultRetryStatus)(currentTime)
|
||||
|
||||
// The first delay should be approximately 1 second
|
||||
expectedResetTime1 := currentTime.Add(1 * time.Second)
|
||||
assert.WithinDuration(t, expectedResetTime1, result1.resetAt, 100*time.Millisecond,
|
||||
"first reset time should be ~1 second from now")
|
||||
|
||||
// Second retry (iter 1) - should double
|
||||
result2 := makeOpen(result1.retryStatus)(currentTime)
|
||||
expectedResetTime2 := currentTime.Add(2 * time.Second)
|
||||
assert.WithinDuration(t, expectedResetTime2, result2.resetAt, 100*time.Millisecond,
|
||||
"second reset time should be ~2 seconds from now")
|
||||
})
|
||||
|
||||
t.Run("handles first failure with no previous delay", func(t *testing.T) {
|
||||
policy := retry.LimitRetries(3)
|
||||
makeOpen := makeOpenCircuitFromPolicy(policy)
|
||||
|
||||
currentTime := time.Date(2026, 1, 9, 12, 0, 0, 0, time.UTC)
|
||||
result := makeOpen(retry.DefaultRetryStatus)(currentTime)
|
||||
|
||||
// With no previous delay, resetAt should be current time
|
||||
assert.Equal(t, currentTime, result.resetAt, "resetAt should be current time when no previous delay")
|
||||
})
|
||||
|
||||
t.Run("increments retry iteration number", func(t *testing.T) {
|
||||
vt := NewVirtualTimer(time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC))
|
||||
policy := retry.LimitRetries(10)
|
||||
makeOpen := makeOpenCircuitFromPolicy(policy)
|
||||
|
||||
currentTime := vt.Now()
|
||||
initialStatus := retry.DefaultRetryStatus
|
||||
|
||||
result := makeOpen(initialStatus)(currentTime)
|
||||
|
||||
assert.Greater(t, result.retryStatus.IterNumber, initialStatus.IterNumber,
|
||||
"retry iteration should be incremented")
|
||||
})
|
||||
|
||||
t.Run("curried function can be partially applied", func(t *testing.T) {
|
||||
vt := NewVirtualTimer(time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC))
|
||||
policy := retry.LimitRetries(5)
|
||||
makeOpen := makeOpenCircuitFromPolicy(policy)
|
||||
|
||||
// Partially apply with retry status
|
||||
makeOpenWithStatus := makeOpen(retry.DefaultRetryStatus)
|
||||
|
||||
currentTime := vt.Now()
|
||||
result := makeOpenWithStatus(currentTime)
|
||||
|
||||
assert.NotNil(t, result, "partially applied function should work")
|
||||
assert.Equal(t, currentTime, result.openedAt)
|
||||
})
|
||||
}
|
||||
|
||||
// TestExtendOpenCircuitFromMakeCircuit tests the extendOpenCircuitFromMakeCircuit function
|
||||
func TestExtendOpenCircuitFromMakeCircuit(t *testing.T) {
|
||||
t.Run("extends open circuit with new retry status", func(t *testing.T) {
|
||||
policy := retry.Monoid.Concat(
|
||||
retry.LimitRetries(10),
|
||||
retry.ExponentialBackoff(1*time.Second),
|
||||
)
|
||||
makeCircuit := makeOpenCircuitFromPolicy(policy)
|
||||
extendCircuit := extendOpenCircuitFromMakeCircuit(makeCircuit)
|
||||
|
||||
currentTime := time.Date(2026, 1, 9, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
// Create initial open state
|
||||
initialOpen := openState{
|
||||
openedAt: currentTime.Add(-1 * time.Minute),
|
||||
resetAt: currentTime,
|
||||
retryStatus: retry.DefaultRetryStatus,
|
||||
canaryRequest: false,
|
||||
}
|
||||
|
||||
// Extend the circuit
|
||||
extendOp := extendCircuit(currentTime)
|
||||
result := extendOp(initialOpen)
|
||||
|
||||
assert.True(t, result.canaryRequest, "canaryRequest should be set to true")
|
||||
assert.Greater(t, result.retryStatus.IterNumber, initialOpen.retryStatus.IterNumber,
|
||||
"retry iteration should be incremented")
|
||||
assert.True(t, result.resetAt.After(currentTime), "resetAt should be in the future")
|
||||
})
|
||||
|
||||
t.Run("sets canaryRequest to true for next test", func(t *testing.T) {
|
||||
vt := NewVirtualTimer(time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC))
|
||||
policy := retry.LimitRetries(5)
|
||||
makeCircuit := makeOpenCircuitFromPolicy(policy)
|
||||
extendCircuit := extendOpenCircuitFromMakeCircuit(makeCircuit)
|
||||
|
||||
currentTime := vt.Now()
|
||||
initialOpen := openState{
|
||||
openedAt: currentTime.Add(-30 * time.Second),
|
||||
resetAt: currentTime,
|
||||
retryStatus: retry.DefaultRetryStatus,
|
||||
canaryRequest: false,
|
||||
}
|
||||
|
||||
result := extendCircuit(currentTime)(initialOpen)
|
||||
|
||||
assert.True(t, result.canaryRequest, "canaryRequest must be true after extension")
|
||||
})
|
||||
|
||||
t.Run("applies exponential backoff on successive extensions", func(t *testing.T) {
|
||||
policy := retry.Monoid.Concat(
|
||||
retry.LimitRetries(10),
|
||||
retry.ExponentialBackoff(1*time.Second),
|
||||
)
|
||||
makeCircuit := makeOpenCircuitFromPolicy(policy)
|
||||
extendCircuit := extendOpenCircuitFromMakeCircuit(makeCircuit)
|
||||
|
||||
currentTime := time.Date(2026, 1, 9, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
// First extension
|
||||
state1 := openState{
|
||||
openedAt: currentTime,
|
||||
resetAt: currentTime,
|
||||
retryStatus: retry.DefaultRetryStatus,
|
||||
canaryRequest: false,
|
||||
}
|
||||
result1 := extendCircuit(currentTime)(state1)
|
||||
delay1 := result1.resetAt.Sub(currentTime)
|
||||
|
||||
// Second extension (should have longer delay)
|
||||
result2 := extendCircuit(currentTime)(result1)
|
||||
delay2 := result2.resetAt.Sub(currentTime)
|
||||
|
||||
assert.Greater(t, delay2, delay1, "second extension should have longer delay due to exponential backoff")
|
||||
})
|
||||
}
|
||||
|
||||
// TestIsResetTimeExceeded tests the isResetTimeExceeded function
|
||||
func TestIsResetTimeExceeded(t *testing.T) {
|
||||
t.Run("returns Some when reset time is exceeded and no canary active", func(t *testing.T) {
|
||||
currentTime := time.Date(2026, 1, 9, 12, 0, 0, 0, time.UTC)
|
||||
resetTime := currentTime.Add(-1 * time.Second) // in the past
|
||||
|
||||
openState := openState{
|
||||
openedAt: currentTime.Add(-1 * time.Minute),
|
||||
resetAt: resetTime,
|
||||
retryStatus: retry.DefaultRetryStatus,
|
||||
canaryRequest: false,
|
||||
}
|
||||
|
||||
result := isResetTimeExceeded(currentTime)(openState)
|
||||
|
||||
assert.True(t, option.IsSome(result), "should return Some when reset time exceeded")
|
||||
})
|
||||
|
||||
t.Run("returns None when reset time not yet exceeded", func(t *testing.T) {
|
||||
currentTime := time.Date(2026, 1, 9, 12, 0, 0, 0, time.UTC)
|
||||
resetTime := currentTime.Add(1 * time.Minute) // in the future
|
||||
|
||||
openState := openState{
|
||||
openedAt: currentTime.Add(-30 * time.Second),
|
||||
resetAt: resetTime,
|
||||
retryStatus: retry.DefaultRetryStatus,
|
||||
canaryRequest: false,
|
||||
}
|
||||
|
||||
result := isResetTimeExceeded(currentTime)(openState)
|
||||
|
||||
assert.True(t, option.IsNone(result), "should return None when reset time not exceeded")
|
||||
})
|
||||
|
||||
t.Run("returns None when canary request is already active", func(t *testing.T) {
|
||||
currentTime := time.Date(2026, 1, 9, 12, 0, 0, 0, time.UTC)
|
||||
resetTime := currentTime.Add(-1 * time.Second) // in the past
|
||||
|
||||
openState := openState{
|
||||
openedAt: currentTime.Add(-1 * time.Minute),
|
||||
resetAt: resetTime,
|
||||
retryStatus: retry.DefaultRetryStatus,
|
||||
canaryRequest: true, // canary already active
|
||||
}
|
||||
|
||||
result := isResetTimeExceeded(currentTime)(openState)
|
||||
|
||||
assert.True(t, option.IsNone(result), "should return None when canary is already active")
|
||||
})
|
||||
|
||||
t.Run("returns Some at exact reset time boundary", func(t *testing.T) {
|
||||
currentTime := time.Date(2026, 1, 9, 12, 0, 0, 0, time.UTC)
|
||||
resetTime := currentTime.Add(-1 * time.Nanosecond) // just passed
|
||||
|
||||
openState := openState{
|
||||
openedAt: currentTime.Add(-1 * time.Minute),
|
||||
resetAt: resetTime,
|
||||
retryStatus: retry.DefaultRetryStatus,
|
||||
canaryRequest: false,
|
||||
}
|
||||
|
||||
result := isResetTimeExceeded(currentTime)(openState)
|
||||
|
||||
assert.True(t, option.IsSome(result), "should return Some when current time is after reset time")
|
||||
})
|
||||
|
||||
t.Run("returns None when current time equals reset time", func(t *testing.T) {
|
||||
currentTime := time.Date(2026, 1, 9, 12, 0, 0, 0, time.UTC)
|
||||
resetTime := currentTime // exactly equal
|
||||
|
||||
openState := openState{
|
||||
openedAt: currentTime.Add(-1 * time.Minute),
|
||||
resetAt: resetTime,
|
||||
retryStatus: retry.DefaultRetryStatus,
|
||||
canaryRequest: false,
|
||||
}
|
||||
|
||||
result := isResetTimeExceeded(currentTime)(openState)
|
||||
|
||||
assert.True(t, option.IsNone(result), "should return None when times are equal (not After)")
|
||||
})
|
||||
}
|
||||
|
||||
// TestHandleSuccessOnClosed tests the handleSuccessOnClosed function
|
||||
func TestHandleSuccessOnClosed(t *testing.T) {
|
||||
t.Run("resets failure count on success", func(t *testing.T) {
|
||||
vt := NewVirtualTimer(time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC))
|
||||
currentTime := vt.Now
|
||||
addSuccess := reader.From1(ClosedState.AddSuccess)
|
||||
|
||||
// Create initial state with some failures
|
||||
now := vt.Now()
|
||||
initialClosed := MakeClosedStateCounter(3)
|
||||
initialClosed = initialClosed.AddError(now)
|
||||
initialClosed = initialClosed.AddError(now)
|
||||
initialState := createClosedCircuit(initialClosed)
|
||||
|
||||
ref := io.Run(ioref.MakeIORef(initialState))
|
||||
modify := modifyV(ref)
|
||||
|
||||
handler := handleSuccessOnClosed(currentTime, addSuccess)
|
||||
|
||||
// Apply the handler
|
||||
result := io.Run(handler(modify))
|
||||
|
||||
// Verify state is still closed and failures are reset
|
||||
assert.True(t, IsClosed(result), "circuit should remain closed after success")
|
||||
})
|
||||
|
||||
t.Run("keeps circuit closed", func(t *testing.T) {
|
||||
vt := NewVirtualTimer(time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC))
|
||||
currentTime := vt.Now
|
||||
addSuccess := reader.From1(ClosedState.AddSuccess)
|
||||
|
||||
initialState := createClosedCircuit(MakeClosedStateCounter(3))
|
||||
ref := io.Run(ioref.MakeIORef(initialState))
|
||||
modify := modifyV(ref)
|
||||
|
||||
handler := handleSuccessOnClosed(currentTime, addSuccess)
|
||||
result := io.Run(handler(modify))
|
||||
|
||||
assert.True(t, IsClosed(result), "circuit should remain closed")
|
||||
})
|
||||
}
|
||||
|
||||
// TestHandleFailureOnClosed tests the handleFailureOnClosed function
|
||||
func TestHandleFailureOnClosed(t *testing.T) {
|
||||
t.Run("keeps circuit closed when threshold not exceeded", func(t *testing.T) {
|
||||
vt := NewVirtualTimer(time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC))
|
||||
currentTime := vt.Now
|
||||
addError := reader.From1(ClosedState.AddError)
|
||||
checkClosedState := reader.From1(ClosedState.Check)
|
||||
openCircuit := func(ct time.Time) openState {
|
||||
return openState{
|
||||
openedAt: ct,
|
||||
resetAt: ct.Add(1 * time.Minute),
|
||||
retryStatus: retry.DefaultRetryStatus,
|
||||
canaryRequest: false,
|
||||
}
|
||||
}
|
||||
|
||||
// Create initial state with room for more failures
|
||||
now := vt.Now()
|
||||
initialClosed := MakeClosedStateCounter(5) // threshold is 5
|
||||
initialClosed = initialClosed.AddError(now)
|
||||
initialState := createClosedCircuit(initialClosed)
|
||||
|
||||
ref := io.Run(ioref.MakeIORef(initialState))
|
||||
modify := modifyV(ref)
|
||||
|
||||
handler := handleFailureOnClosed(currentTime, addError, checkClosedState, openCircuit)
|
||||
result := io.Run(handler(modify))
|
||||
|
||||
assert.True(t, IsClosed(result), "circuit should remain closed when threshold not exceeded")
|
||||
})
|
||||
|
||||
t.Run("opens circuit when threshold exceeded", func(t *testing.T) {
|
||||
vt := NewVirtualTimer(time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC))
|
||||
currentTime := vt.Now
|
||||
addError := reader.From1(ClosedState.AddError)
|
||||
checkClosedState := reader.From1(ClosedState.Check)
|
||||
openCircuit := func(ct time.Time) openState {
|
||||
return openState{
|
||||
openedAt: ct,
|
||||
resetAt: ct.Add(1 * time.Minute),
|
||||
retryStatus: retry.DefaultRetryStatus,
|
||||
canaryRequest: false,
|
||||
}
|
||||
}
|
||||
|
||||
// Create initial state at threshold
|
||||
now := vt.Now()
|
||||
initialClosed := MakeClosedStateCounter(2) // threshold is 2
|
||||
initialClosed = initialClosed.AddError(now)
|
||||
initialState := createClosedCircuit(initialClosed)
|
||||
|
||||
ref := io.Run(ioref.MakeIORef(initialState))
|
||||
modify := modifyV(ref)
|
||||
|
||||
handler := handleFailureOnClosed(currentTime, addError, checkClosedState, openCircuit)
|
||||
result := io.Run(handler(modify))
|
||||
|
||||
assert.True(t, IsOpen(result), "circuit should open when threshold exceeded")
|
||||
})
|
||||
|
||||
t.Run("records failure in closed state", func(t *testing.T) {
|
||||
vt := NewVirtualTimer(time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC))
|
||||
currentTime := vt.Now
|
||||
addError := reader.From1(ClosedState.AddError)
|
||||
checkClosedState := reader.From1(ClosedState.Check)
|
||||
openCircuit := func(ct time.Time) openState {
|
||||
return openState{
|
||||
openedAt: ct,
|
||||
resetAt: ct.Add(1 * time.Minute),
|
||||
retryStatus: retry.DefaultRetryStatus,
|
||||
canaryRequest: false,
|
||||
}
|
||||
}
|
||||
|
||||
initialState := createClosedCircuit(MakeClosedStateCounter(10))
|
||||
ref := io.Run(ioref.MakeIORef(initialState))
|
||||
modify := modifyV(ref)
|
||||
|
||||
handler := handleFailureOnClosed(currentTime, addError, checkClosedState, openCircuit)
|
||||
result := io.Run(handler(modify))
|
||||
|
||||
// Should still be closed but with failure recorded
|
||||
assert.True(t, IsClosed(result), "circuit should remain closed")
|
||||
})
|
||||
}
|
||||
329
v2/circuitbreaker/closed.go
Normal file
329
v2/circuitbreaker/closed.go
Normal file
@@ -0,0 +1,329 @@
|
||||
package circuitbreaker
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
A "github.com/IBM/fp-go/v2/array"
|
||||
F "github.com/IBM/fp-go/v2/function"
|
||||
N "github.com/IBM/fp-go/v2/number"
|
||||
"github.com/IBM/fp-go/v2/optics/lens"
|
||||
"github.com/IBM/fp-go/v2/option"
|
||||
"github.com/IBM/fp-go/v2/ord"
|
||||
)
|
||||
|
||||
type (
|
||||
// ClosedState represents the closed state of a circuit breaker.
|
||||
// In the closed state, requests are allowed to pass through, but failures are tracked.
|
||||
// If a failure condition is met, the circuit breaker transitions to an open state.
|
||||
//
|
||||
// # Thread Safety
|
||||
//
|
||||
// All ClosedState implementations MUST be thread-safe. The recommended approach is to
|
||||
// make all methods return new copies rather than modifying the receiver, which provides
|
||||
// automatic thread safety through immutability.
|
||||
//
|
||||
// Implementations should ensure that:
|
||||
// - Empty() returns a new instance with cleared state
|
||||
// - AddError() returns a new instance with the error recorded
|
||||
// - AddSuccess() returns a new instance with success recorded
|
||||
// - Check() does not modify the receiver
|
||||
//
|
||||
// Both provided implementations (closedStateWithErrorCount and closedStateWithHistory)
|
||||
// follow this pattern and are safe for concurrent use.
|
||||
ClosedState interface {
|
||||
// Empty returns a new ClosedState with all tracked failures cleared.
|
||||
// This is used when transitioning back to a closed state from an open state.
|
||||
//
|
||||
// Thread Safety: Returns a new instance; safe for concurrent use.
|
||||
Empty() ClosedState
|
||||
|
||||
// AddError records a failure at the given time.
|
||||
// Returns an updated ClosedState reflecting the recorded failure.
|
||||
//
|
||||
// Thread Safety: Returns a new instance; safe for concurrent use.
|
||||
// The original ClosedState is not modified.
|
||||
AddError(time.Time) ClosedState
|
||||
|
||||
// AddSuccess records a successful request at the given time.
|
||||
// Returns an updated ClosedState reflecting the successful request.
|
||||
//
|
||||
// Thread Safety: Returns a new instance; safe for concurrent use.
|
||||
// The original ClosedState is not modified.
|
||||
AddSuccess(time.Time) ClosedState
|
||||
|
||||
// Check verifies if the circuit breaker should remain closed at the given time.
|
||||
// Returns Some(ClosedState) if the circuit should stay closed,
|
||||
// or None if the circuit should open due to exceeding the failure threshold.
|
||||
//
|
||||
// Thread Safety: Does not modify the receiver; safe for concurrent use.
|
||||
Check(time.Time) Option[ClosedState]
|
||||
}
|
||||
|
||||
// closedStateWithErrorCount is a counter-based implementation of ClosedState.
|
||||
// It tracks the number of consecutive failures and opens the circuit when
|
||||
// the failure count exceeds a configured threshold.
|
||||
//
|
||||
// Thread Safety: This implementation is immutable. All methods return new instances
|
||||
// rather than modifying the receiver, making it safe for concurrent use without locks.
|
||||
closedStateWithErrorCount struct {
|
||||
// checkFailures is a Kleisli arrow that checks if the failure count exceeds the threshold.
|
||||
// Returns Some(count) if threshold is exceeded, None otherwise.
|
||||
checkFailures option.Kleisli[uint, uint]
|
||||
// failureCount tracks the current number of consecutive failures.
|
||||
failureCount uint
|
||||
}
|
||||
|
||||
// closedStateWithHistory is a time-window-based implementation of ClosedState.
|
||||
// It tracks failures within a sliding time window and opens the circuit when
|
||||
// the failure count within the window exceeds a configured threshold.
|
||||
//
|
||||
// Thread Safety: This implementation is immutable. All methods return new instances
|
||||
// with new slices rather than modifying the receiver, making it safe for concurrent
|
||||
// use without locks. The history slice is never modified in place; addToSlice always
|
||||
// creates a new slice.
|
||||
closedStateWithHistory struct {
|
||||
ordTime Ord[time.Time]
|
||||
// maxFailures is the maximum number of failures allowed within the time window.
|
||||
checkFailures option.Kleisli[int, int]
|
||||
timeWindow time.Duration
|
||||
history []time.Time
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
failureCountLens = lens.MakeLensStrictWithName(
|
||||
func(s *closedStateWithErrorCount) uint { return s.failureCount },
|
||||
func(s *closedStateWithErrorCount, c uint) *closedStateWithErrorCount {
|
||||
s.failureCount = c
|
||||
return s
|
||||
},
|
||||
"closeStateWithErrorCount.failureCount",
|
||||
)
|
||||
|
||||
historyLens = lens.MakeLensRefWithName(
|
||||
func(s *closedStateWithHistory) []time.Time { return s.history },
|
||||
func(s *closedStateWithHistory, c []time.Time) *closedStateWithHistory {
|
||||
s.history = c
|
||||
return s
|
||||
},
|
||||
"closedStateWithHistory.history",
|
||||
)
|
||||
|
||||
resetHistory = historyLens.Set(A.Empty[time.Time]())
|
||||
resetFailureCount = failureCountLens.Set(0)
|
||||
incFailureCount = lens.Modify[*closedStateWithErrorCount](N.Add(uint(1)))(failureCountLens)
|
||||
)
|
||||
|
||||
// Empty returns a new closedStateWithErrorCount with the failure count reset to zero.
|
||||
//
|
||||
// Thread Safety: Returns a new instance; the original is not modified.
|
||||
// Safe for concurrent use.
|
||||
func (s *closedStateWithErrorCount) Empty() ClosedState {
|
||||
return resetFailureCount(s)
|
||||
}
|
||||
|
||||
// AddError increments the failure count and returns a new closedStateWithErrorCount.
|
||||
// The time parameter is ignored in this counter-based implementation.
|
||||
//
|
||||
// Thread Safety: Returns a new instance; the original is not modified.
|
||||
// Safe for concurrent use.
|
||||
func (s *closedStateWithErrorCount) AddError(_ time.Time) ClosedState {
|
||||
return incFailureCount(s)
|
||||
}
|
||||
|
||||
// AddSuccess resets the failure count to zero and returns a new closedStateWithErrorCount.
|
||||
// The time parameter is ignored in this counter-based implementation.
|
||||
//
|
||||
// Thread Safety: Returns a new instance; the original is not modified.
|
||||
// Safe for concurrent use.
|
||||
func (s *closedStateWithErrorCount) AddSuccess(_ time.Time) ClosedState {
|
||||
return resetFailureCount(s)
|
||||
}
|
||||
|
||||
// Check verifies if the failure count is below the threshold.
|
||||
// Returns Some(ClosedState) if below threshold, None if at or above threshold.
|
||||
// The time parameter is ignored in this counter-based implementation.
|
||||
//
|
||||
// Thread Safety: Does not modify the receiver; safe for concurrent use.
|
||||
func (s *closedStateWithErrorCount) Check(_ time.Time) Option[ClosedState] {
|
||||
return F.Pipe3(
|
||||
s,
|
||||
failureCountLens.Get,
|
||||
s.checkFailures,
|
||||
option.MapTo[uint](ClosedState(s)),
|
||||
)
|
||||
}
|
||||
|
||||
// MakeClosedStateCounter creates a counter-based ClosedState implementation.
|
||||
// The circuit breaker will open when the number of consecutive failures reaches maxFailures.
|
||||
//
|
||||
// Parameters:
|
||||
// - maxFailures: The threshold for consecutive failures. The circuit opens when
|
||||
// failureCount >= maxFailures (greater than or equal to).
|
||||
//
|
||||
// Returns:
|
||||
// - A ClosedState that tracks failures using a simple counter.
|
||||
//
|
||||
// Example:
|
||||
// - If maxFailures is 3, the circuit will open on the 3rd consecutive failure.
|
||||
// - Each AddError call increments the counter.
|
||||
// - Each AddSuccess call resets the counter to 0 (only consecutive failures count).
|
||||
// - Empty resets the counter to 0.
|
||||
//
|
||||
// Behavior:
|
||||
// - Check returns Some(ClosedState) when failureCount < maxFailures (circuit stays closed)
|
||||
// - Check returns None when failureCount >= maxFailures (circuit should open)
|
||||
// - AddSuccess resets the failure count, so only consecutive failures trigger circuit opening
|
||||
//
|
||||
// Thread Safety: The returned ClosedState is safe for concurrent use. All methods
|
||||
// return new instances rather than modifying the receiver.
|
||||
func MakeClosedStateCounter(maxFailures uint) ClosedState {
|
||||
return &closedStateWithErrorCount{
|
||||
checkFailures: option.FromPredicate(N.LessThan(maxFailures)),
|
||||
}
|
||||
}
|
||||
|
||||
// Empty returns a new closedStateWithHistory with an empty failure history.
|
||||
//
|
||||
// Thread Safety: Returns a new instance with a new empty slice; the original is not modified.
|
||||
// Safe for concurrent use.
|
||||
func (s *closedStateWithHistory) Empty() ClosedState {
|
||||
return resetHistory(s)
|
||||
}
|
||||
|
||||
// addToSlice creates a new sorted slice by adding an item to an existing slice.
|
||||
// This function does not modify the input slice; it creates a new slice with the item added
|
||||
// and returns it in sorted order.
|
||||
//
|
||||
// Parameters:
|
||||
// - o: An Ord instance for comparing time.Time values to determine sort order
|
||||
// - ar: The existing slice of time.Time values (assumed to be sorted)
|
||||
// - item: The new time.Time value to add to the slice
|
||||
//
|
||||
// Returns:
|
||||
// - A new slice containing all elements from ar plus the new item, sorted in ascending order
|
||||
//
|
||||
// Implementation Details:
|
||||
// - Creates a new slice with capacity len(ar)+1
|
||||
// - Copies all elements from ar to the new slice
|
||||
// - Appends the new item
|
||||
// - Sorts the entire slice using the provided Ord comparator
|
||||
//
|
||||
// Thread Safety: This function is pure and does not modify its inputs. It always returns
|
||||
// a new slice, making it safe for concurrent use. This is a key component of the immutable
|
||||
// design of closedStateWithHistory.
|
||||
//
|
||||
// Note: This function is used internally by closedStateWithHistory.AddError to maintain
|
||||
// a sorted history of failure timestamps for efficient binary search operations.
|
||||
func addToSlice(o ord.Ord[time.Time], ar []time.Time, item time.Time) []time.Time {
|
||||
cpy := make([]time.Time, len(ar)+1)
|
||||
cpy[copy(cpy, ar)] = item
|
||||
slices.SortFunc(cpy, o.Compare)
|
||||
return cpy
|
||||
}
|
||||
|
||||
// AddError records a failure at the given time and returns a new closedStateWithHistory.
|
||||
// The new instance contains the failure in its history, with old failures outside the
|
||||
// time window automatically pruned.
|
||||
//
|
||||
// Thread Safety: Returns a new instance with a new history slice; the original is not modified.
|
||||
// Safe for concurrent use. The addToSlice function creates a new slice, ensuring immutability.
|
||||
func (s *closedStateWithHistory) AddError(currentTime time.Time) ClosedState {
|
||||
|
||||
addFailureToHistory := F.Pipe1(
|
||||
historyLens,
|
||||
lens.Modify[*closedStateWithHistory](func(old []time.Time) []time.Time {
|
||||
// oldest valid entry
|
||||
idx, _ := slices.BinarySearchFunc(old, currentTime.Add(-s.timeWindow), s.ordTime.Compare)
|
||||
return addToSlice(s.ordTime, old[idx:], currentTime)
|
||||
}),
|
||||
)
|
||||
|
||||
return addFailureToHistory(s)
|
||||
}
|
||||
|
||||
// AddSuccess purges the entire failure history and returns a new closedStateWithHistory.
|
||||
// The time parameter is ignored; any success clears all tracked failures.
|
||||
//
|
||||
// Thread Safety: Returns a new instance with a new empty slice; the original is not modified.
|
||||
// Safe for concurrent use.
|
||||
func (s *closedStateWithHistory) AddSuccess(_ time.Time) ClosedState {
|
||||
return resetHistory(s)
|
||||
}
|
||||
|
||||
// Check verifies if the number of failures in the history is below the threshold.
|
||||
// Returns Some(ClosedState) if below threshold, None if at or above threshold.
|
||||
// The time parameter is ignored; the check is based on the current history size.
|
||||
//
|
||||
// Thread Safety: Does not modify the receiver; safe for concurrent use.
|
||||
func (s *closedStateWithHistory) Check(_ time.Time) Option[ClosedState] {
|
||||
|
||||
return F.Pipe4(
|
||||
s,
|
||||
historyLens.Get,
|
||||
A.Size,
|
||||
s.checkFailures,
|
||||
option.MapTo[int](ClosedState(s)),
|
||||
)
|
||||
}
|
||||
|
||||
// MakeClosedStateHistory creates a time-window-based ClosedState implementation.
|
||||
// The circuit breaker will open when the number of failures within a sliding time window reaches maxFailures.
|
||||
//
|
||||
// Unlike MakeClosedStateCounter which tracks consecutive failures, this implementation tracks
|
||||
// all failures within a time window. However, any successful request will purge the entire history,
|
||||
// effectively resetting the failure tracking.
|
||||
//
|
||||
// Parameters:
|
||||
// - timeWindow: The duration of the sliding time window. Failures older than this are automatically
|
||||
// discarded from the history when new failures are added.
|
||||
// - maxFailures: The threshold for failures within the time window. The circuit opens when
|
||||
// the number of failures in the window reaches this value (failureCount >= maxFailures).
|
||||
//
|
||||
// Returns:
|
||||
// - A ClosedState that tracks failures using a time-based sliding window.
|
||||
//
|
||||
// Example:
|
||||
// - If timeWindow is 1 minute and maxFailures is 5, the circuit will open when 5 failures
|
||||
// occur within any 1-minute period.
|
||||
// - Failures older than 1 minute are automatically removed from the history when AddError is called.
|
||||
// - Any successful request immediately purges all tracked failures from the history.
|
||||
//
|
||||
// Behavior:
|
||||
// - AddError records the failure timestamp and removes failures outside the time window
|
||||
// (older than currentTime - timeWindow).
|
||||
// - AddSuccess purges the entire failure history (all tracked failures are removed).
|
||||
// - Check returns Some(ClosedState) when failureCount < maxFailures (circuit stays closed).
|
||||
// - Check returns None when failureCount >= maxFailures (circuit should open).
|
||||
// - Empty purges the entire failure history.
|
||||
//
|
||||
// Time Window Management:
|
||||
// - The history is automatically pruned on each AddError call to remove failures older than
|
||||
// currentTime - timeWindow.
|
||||
// - The history is kept sorted by time for efficient binary search and pruning.
|
||||
//
|
||||
// Important Note:
|
||||
// - A successful request resets everything by purging the entire history. This means that
|
||||
// unlike a pure sliding window, a single success will clear all tracked failures, even
|
||||
// those within the time window. This behavior is similar to MakeClosedStateCounter but
|
||||
// with time-based tracking for failures.
|
||||
//
|
||||
// Thread Safety: The returned ClosedState is safe for concurrent use. All methods return
|
||||
// new instances with new slices rather than modifying the receiver. The history slice is
|
||||
// never modified in place.
|
||||
//
|
||||
// Use Cases:
|
||||
// - Systems where a successful request indicates recovery and past failures should be forgotten.
|
||||
// - Rate limiting with success-based reset: Allow bursts of failures but reset on success.
|
||||
// - Hybrid approach: Time-based failure tracking with success-based recovery.
|
||||
func MakeClosedStateHistory(
|
||||
timeWindow time.Duration,
|
||||
maxFailures uint) ClosedState {
|
||||
return &closedStateWithHistory{
|
||||
checkFailures: option.FromPredicate(N.LessThan(int(maxFailures))),
|
||||
ordTime: ord.OrdTime(),
|
||||
history: A.Empty[time.Time](),
|
||||
timeWindow: timeWindow,
|
||||
}
|
||||
}
|
||||
934
v2/circuitbreaker/closed_test.go
Normal file
934
v2/circuitbreaker/closed_test.go
Normal file
@@ -0,0 +1,934 @@
|
||||
package circuitbreaker
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/IBM/fp-go/v2/option"
|
||||
"github.com/IBM/fp-go/v2/ord"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestMakeClosedStateCounter(t *testing.T) {
|
||||
t.Run("creates a valid ClosedState", func(t *testing.T) {
|
||||
maxFailures := uint(3)
|
||||
state := MakeClosedStateCounter(maxFailures)
|
||||
|
||||
assert.NotNil(t, state, "MakeClosedStateCounter should return a non-nil ClosedState")
|
||||
})
|
||||
|
||||
t.Run("initial state passes Check", func(t *testing.T) {
|
||||
vt := NewVirtualTimer(time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC))
|
||||
maxFailures := uint(3)
|
||||
state := MakeClosedStateCounter(maxFailures)
|
||||
now := vt.Now()
|
||||
|
||||
result := state.Check(now)
|
||||
|
||||
assert.True(t, option.IsSome(result), "initial state should pass Check (return Some, circuit stays closed)")
|
||||
})
|
||||
|
||||
t.Run("Empty resets failure count", func(t *testing.T) {
|
||||
vt := NewVirtualTimer(time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC))
|
||||
maxFailures := uint(2)
|
||||
state := MakeClosedStateCounter(maxFailures)
|
||||
now := vt.Now()
|
||||
|
||||
// Add some errors
|
||||
state = state.AddError(now)
|
||||
state = state.AddError(now)
|
||||
|
||||
// Reset the state
|
||||
state = state.Empty()
|
||||
|
||||
// Should pass check after reset
|
||||
result := state.Check(now)
|
||||
assert.True(t, option.IsSome(result), "state should pass Check after Empty")
|
||||
})
|
||||
|
||||
t.Run("AddSuccess resets failure count", func(t *testing.T) {
|
||||
vt := NewVirtualTimer(time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC))
|
||||
maxFailures := uint(3)
|
||||
state := MakeClosedStateCounter(maxFailures)
|
||||
|
||||
// Add errors
|
||||
state = state.AddError(vt.Now())
|
||||
vt.Advance(1 * time.Second)
|
||||
state = state.AddError(vt.Now())
|
||||
vt.Advance(1 * time.Second)
|
||||
|
||||
// Add success (should reset counter)
|
||||
state = state.AddSuccess(vt.Now())
|
||||
vt.Advance(1 * time.Second)
|
||||
|
||||
// Add another error (this is now the first consecutive error)
|
||||
state = state.AddError(vt.Now())
|
||||
|
||||
// Should still pass check (only 1 consecutive error, threshold is 3)
|
||||
result := state.Check(vt.Now())
|
||||
assert.True(t, option.IsSome(result), "AddSuccess should reset failure count")
|
||||
})
|
||||
|
||||
t.Run("circuit opens when failures reach threshold", func(t *testing.T) {
|
||||
vt := NewVirtualTimer(time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC))
|
||||
maxFailures := uint(3)
|
||||
state := MakeClosedStateCounter(maxFailures)
|
||||
now := vt.Now()
|
||||
|
||||
// Add errors up to but not including threshold
|
||||
state = state.AddError(now)
|
||||
state = state.AddError(now)
|
||||
|
||||
// Should still pass before threshold
|
||||
result := state.Check(now)
|
||||
assert.True(t, option.IsSome(result), "should pass Check before threshold")
|
||||
|
||||
// Add one more error to reach threshold
|
||||
state = state.AddError(now)
|
||||
|
||||
// Should fail check at threshold
|
||||
result = state.Check(now)
|
||||
assert.True(t, option.IsNone(result), "should fail Check when reaching threshold")
|
||||
})
|
||||
|
||||
t.Run("circuit opens exactly at maxFailures", func(t *testing.T) {
|
||||
vt := NewVirtualTimer(time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC))
|
||||
maxFailures := uint(5)
|
||||
state := MakeClosedStateCounter(maxFailures)
|
||||
now := vt.Now()
|
||||
|
||||
// Add exactly maxFailures - 1 errors
|
||||
for i := uint(0); i < maxFailures-1; i++ {
|
||||
state = state.AddError(now)
|
||||
}
|
||||
|
||||
// Should still pass
|
||||
result := state.Check(now)
|
||||
assert.True(t, option.IsSome(result), "should pass Check before maxFailures")
|
||||
|
||||
// Add one more to reach maxFailures
|
||||
state = state.AddError(now)
|
||||
|
||||
// Should fail now
|
||||
result = state.Check(now)
|
||||
assert.True(t, option.IsNone(result), "should fail Check at maxFailures")
|
||||
})
|
||||
|
||||
t.Run("zero maxFailures means circuit is always open", func(t *testing.T) {
|
||||
vt := NewVirtualTimer(time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC))
|
||||
maxFailures := uint(0)
|
||||
state := MakeClosedStateCounter(maxFailures)
|
||||
now := vt.Now()
|
||||
|
||||
// Initial state should already fail (0 >= 0)
|
||||
result := state.Check(now)
|
||||
assert.True(t, option.IsNone(result), "initial state should fail Check with maxFailures=0")
|
||||
|
||||
// Add one error
|
||||
state = state.AddError(now)
|
||||
|
||||
// Should still fail
|
||||
result = state.Check(now)
|
||||
assert.True(t, option.IsNone(result), "should fail Check after error with maxFailures=0")
|
||||
})
|
||||
|
||||
t.Run("AddSuccess resets counter between errors", func(t *testing.T) {
|
||||
vt := NewVirtualTimer(time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC))
|
||||
maxFailures := uint(3)
|
||||
state := MakeClosedStateCounter(maxFailures)
|
||||
|
||||
// Add errors
|
||||
state = state.AddError(vt.Now())
|
||||
vt.Advance(1 * time.Second)
|
||||
state = state.AddError(vt.Now())
|
||||
vt.Advance(1 * time.Second)
|
||||
|
||||
// Add success (resets counter)
|
||||
state = state.AddSuccess(vt.Now())
|
||||
vt.Advance(1 * time.Second)
|
||||
|
||||
// Add more errors
|
||||
state = state.AddError(vt.Now())
|
||||
vt.Advance(1 * time.Second)
|
||||
state = state.AddError(vt.Now())
|
||||
|
||||
// Should still pass (only 2 consecutive errors after reset)
|
||||
result := state.Check(vt.Now())
|
||||
assert.True(t, option.IsSome(result), "should pass with 2 consecutive errors after reset")
|
||||
|
||||
// Add one more to reach threshold
|
||||
vt.Advance(1 * time.Second)
|
||||
state = state.AddError(vt.Now())
|
||||
|
||||
// Should fail at threshold
|
||||
result = state.Check(vt.Now())
|
||||
assert.True(t, option.IsNone(result), "should fail after reaching threshold")
|
||||
})
|
||||
|
||||
t.Run("Empty can be called multiple times", func(t *testing.T) {
|
||||
vt := NewVirtualTimer(time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC))
|
||||
maxFailures := uint(2)
|
||||
state := MakeClosedStateCounter(maxFailures)
|
||||
now := vt.Now()
|
||||
|
||||
// Add errors
|
||||
state = state.AddError(now)
|
||||
state = state.AddError(now)
|
||||
state = state.AddError(now)
|
||||
|
||||
// Reset multiple times
|
||||
state = state.Empty()
|
||||
state = state.Empty()
|
||||
state = state.Empty()
|
||||
|
||||
// Should still pass
|
||||
result := state.Check(now)
|
||||
assert.True(t, option.IsSome(result), "state should pass Check after multiple Empty calls")
|
||||
})
|
||||
|
||||
t.Run("time parameter is ignored in counter implementation", func(t *testing.T) {
|
||||
vt := NewVirtualTimer(time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC))
|
||||
maxFailures := uint(3)
|
||||
state := MakeClosedStateCounter(maxFailures)
|
||||
|
||||
// Use different times for each operation
|
||||
time1 := vt.Now()
|
||||
time2 := time1.Add(1 * time.Hour)
|
||||
|
||||
state = state.AddError(time1)
|
||||
state = state.AddError(time2)
|
||||
|
||||
// Check with yet another time
|
||||
time3 := time1.Add(2 * time.Hour)
|
||||
result := state.Check(time3)
|
||||
|
||||
// Should still pass (2 errors, threshold is 3, not reached yet)
|
||||
assert.True(t, option.IsSome(result), "time parameter should not affect counter behavior")
|
||||
|
||||
// Add one more to reach threshold
|
||||
state = state.AddError(time1)
|
||||
result = state.Check(time1)
|
||||
assert.True(t, option.IsNone(result), "should fail after reaching threshold regardless of time")
|
||||
})
|
||||
|
||||
t.Run("large maxFailures value", func(t *testing.T) {
|
||||
vt := NewVirtualTimer(time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC))
|
||||
maxFailures := uint(1000)
|
||||
state := MakeClosedStateCounter(maxFailures)
|
||||
now := vt.Now()
|
||||
|
||||
// Add many errors but not reaching threshold
|
||||
for i := uint(0); i < maxFailures-1; i++ {
|
||||
state = state.AddError(now)
|
||||
}
|
||||
|
||||
// Should still pass
|
||||
result := state.Check(now)
|
||||
assert.True(t, option.IsSome(result), "should pass Check with large maxFailures before threshold")
|
||||
|
||||
// Add one more to reach threshold
|
||||
state = state.AddError(now)
|
||||
|
||||
// Should fail
|
||||
result = state.Check(now)
|
||||
assert.True(t, option.IsNone(result), "should fail Check with large maxFailures at threshold")
|
||||
})
|
||||
|
||||
t.Run("state is immutable - original unchanged after AddError", func(t *testing.T) {
|
||||
vt := NewVirtualTimer(time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC))
|
||||
maxFailures := uint(2)
|
||||
originalState := MakeClosedStateCounter(maxFailures)
|
||||
now := vt.Now()
|
||||
|
||||
// Create new state by adding error
|
||||
newState := originalState.AddError(now)
|
||||
|
||||
// Original should still pass check
|
||||
result := originalState.Check(now)
|
||||
assert.True(t, option.IsSome(result), "original state should be unchanged")
|
||||
|
||||
// New state should reach threshold (2 errors total, threshold is 2)
|
||||
newState = newState.AddError(now)
|
||||
|
||||
result = newState.Check(now)
|
||||
assert.True(t, option.IsNone(result), "new state should fail after reaching threshold")
|
||||
|
||||
// Original should still pass
|
||||
result = originalState.Check(now)
|
||||
assert.True(t, option.IsSome(result), "original state should still be unchanged")
|
||||
})
|
||||
|
||||
t.Run("state is immutable - original unchanged after Empty", func(t *testing.T) {
|
||||
vt := NewVirtualTimer(time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC))
|
||||
maxFailures := uint(2)
|
||||
state := MakeClosedStateCounter(maxFailures)
|
||||
now := vt.Now()
|
||||
|
||||
// Add errors to original
|
||||
state = state.AddError(now)
|
||||
state = state.AddError(now)
|
||||
stateWithErrors := state
|
||||
|
||||
// Create new state by calling Empty
|
||||
emptyState := stateWithErrors.Empty()
|
||||
|
||||
// Original with errors should reach threshold (2 errors total, threshold is 2)
|
||||
result := stateWithErrors.Check(now)
|
||||
assert.True(t, option.IsNone(result), "state with errors should fail after reaching threshold")
|
||||
|
||||
// Empty state should pass
|
||||
result = emptyState.Check(now)
|
||||
assert.True(t, option.IsSome(result), "empty state should pass Check")
|
||||
})
|
||||
|
||||
t.Run("AddSuccess prevents circuit from opening", func(t *testing.T) {
|
||||
vt := NewVirtualTimer(time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC))
|
||||
maxFailures := uint(3)
|
||||
state := MakeClosedStateCounter(maxFailures)
|
||||
|
||||
// Add errors close to threshold
|
||||
state = state.AddError(vt.Now())
|
||||
vt.Advance(1 * time.Second)
|
||||
state = state.AddError(vt.Now())
|
||||
vt.Advance(1 * time.Second)
|
||||
|
||||
// Add success before reaching threshold
|
||||
state = state.AddSuccess(vt.Now())
|
||||
vt.Advance(1 * time.Second)
|
||||
|
||||
// Add more errors
|
||||
state = state.AddError(vt.Now())
|
||||
vt.Advance(1 * time.Second)
|
||||
state = state.AddError(vt.Now())
|
||||
|
||||
// Should still pass (only 2 consecutive errors)
|
||||
result := state.Check(vt.Now())
|
||||
assert.True(t, option.IsSome(result), "circuit should stay closed after success reset")
|
||||
})
|
||||
|
||||
t.Run("multiple AddSuccess calls keep counter at zero", func(t *testing.T) {
|
||||
vt := NewVirtualTimer(time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC))
|
||||
maxFailures := uint(2)
|
||||
state := MakeClosedStateCounter(maxFailures)
|
||||
|
||||
// Add error
|
||||
state = state.AddError(vt.Now())
|
||||
vt.Advance(1 * time.Second)
|
||||
|
||||
// Multiple successes
|
||||
state = state.AddSuccess(vt.Now())
|
||||
vt.Advance(1 * time.Second)
|
||||
state = state.AddSuccess(vt.Now())
|
||||
vt.Advance(1 * time.Second)
|
||||
state = state.AddSuccess(vt.Now())
|
||||
vt.Advance(1 * time.Second)
|
||||
|
||||
// Should still pass
|
||||
result := state.Check(vt.Now())
|
||||
assert.True(t, option.IsSome(result), "multiple AddSuccess should keep counter at zero")
|
||||
|
||||
// Add errors to reach threshold
|
||||
state = state.AddError(vt.Now())
|
||||
vt.Advance(1 * time.Second)
|
||||
state = state.AddError(vt.Now())
|
||||
|
||||
// Should fail
|
||||
result = state.Check(vt.Now())
|
||||
assert.True(t, option.IsNone(result), "should fail after reaching threshold")
|
||||
})
|
||||
|
||||
t.Run("alternating errors and successes never opens circuit", func(t *testing.T) {
|
||||
vt := NewVirtualTimer(time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC))
|
||||
maxFailures := uint(3)
|
||||
state := MakeClosedStateCounter(maxFailures)
|
||||
|
||||
// Alternate errors and successes
|
||||
for i := 0; i < 10; i++ {
|
||||
state = state.AddError(vt.Now())
|
||||
vt.Advance(500 * time.Millisecond)
|
||||
state = state.AddSuccess(vt.Now())
|
||||
vt.Advance(500 * time.Millisecond)
|
||||
}
|
||||
|
||||
// Should still pass (never had consecutive failures)
|
||||
result := state.Check(vt.Now())
|
||||
assert.True(t, option.IsSome(result), "alternating errors and successes should never open circuit")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAddToSlice(t *testing.T) {
|
||||
ordTime := ord.OrdTime()
|
||||
|
||||
t.Run("adds item to empty slice and returns sorted result", func(t *testing.T) {
|
||||
input := []time.Time{}
|
||||
item := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
result := addToSlice(ordTime, input, item)
|
||||
|
||||
assert.Len(t, result, 1, "result should have 1 element")
|
||||
assert.Equal(t, item, result[0], "result should contain the added item")
|
||||
})
|
||||
|
||||
t.Run("adds item and maintains sorted order", func(t *testing.T) {
|
||||
baseTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
input := []time.Time{
|
||||
baseTime,
|
||||
baseTime.Add(20 * time.Second),
|
||||
baseTime.Add(40 * time.Second),
|
||||
}
|
||||
item := baseTime.Add(30 * time.Second)
|
||||
|
||||
result := addToSlice(ordTime, input, item)
|
||||
|
||||
assert.Len(t, result, 4, "result should have 4 elements")
|
||||
// Verify sorted order
|
||||
assert.Equal(t, baseTime, result[0])
|
||||
assert.Equal(t, baseTime.Add(20*time.Second), result[1])
|
||||
assert.Equal(t, baseTime.Add(30*time.Second), result[2])
|
||||
assert.Equal(t, baseTime.Add(40*time.Second), result[3])
|
||||
})
|
||||
|
||||
t.Run("adds item at beginning when it's earliest", func(t *testing.T) {
|
||||
baseTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
input := []time.Time{
|
||||
baseTime.Add(20 * time.Second),
|
||||
baseTime.Add(40 * time.Second),
|
||||
}
|
||||
item := baseTime
|
||||
|
||||
result := addToSlice(ordTime, input, item)
|
||||
|
||||
assert.Len(t, result, 3, "result should have 3 elements")
|
||||
assert.Equal(t, baseTime, result[0], "earliest item should be first")
|
||||
assert.Equal(t, baseTime.Add(20*time.Second), result[1])
|
||||
assert.Equal(t, baseTime.Add(40*time.Second), result[2])
|
||||
})
|
||||
|
||||
t.Run("adds item at end when it's latest", func(t *testing.T) {
|
||||
baseTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
input := []time.Time{
|
||||
baseTime,
|
||||
baseTime.Add(20 * time.Second),
|
||||
}
|
||||
item := baseTime.Add(40 * time.Second)
|
||||
|
||||
result := addToSlice(ordTime, input, item)
|
||||
|
||||
assert.Len(t, result, 3, "result should have 3 elements")
|
||||
assert.Equal(t, baseTime, result[0])
|
||||
assert.Equal(t, baseTime.Add(20*time.Second), result[1])
|
||||
assert.Equal(t, baseTime.Add(40*time.Second), result[2], "latest item should be last")
|
||||
})
|
||||
|
||||
t.Run("does not modify original slice", func(t *testing.T) {
|
||||
baseTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
input := []time.Time{
|
||||
baseTime,
|
||||
baseTime.Add(20 * time.Second),
|
||||
}
|
||||
originalLen := len(input)
|
||||
originalFirst := input[0]
|
||||
originalLast := input[1]
|
||||
item := baseTime.Add(10 * time.Second)
|
||||
|
||||
result := addToSlice(ordTime, input, item)
|
||||
|
||||
// Verify original slice is unchanged
|
||||
assert.Len(t, input, originalLen, "original slice length should be unchanged")
|
||||
assert.Equal(t, originalFirst, input[0], "original slice first element should be unchanged")
|
||||
assert.Equal(t, originalLast, input[1], "original slice last element should be unchanged")
|
||||
|
||||
// Verify result is different and has correct length
|
||||
assert.Len(t, result, 3, "result should have new length")
|
||||
// Verify the result contains the new item in sorted order
|
||||
assert.Equal(t, baseTime, result[0])
|
||||
assert.Equal(t, baseTime.Add(10*time.Second), result[1])
|
||||
assert.Equal(t, baseTime.Add(20*time.Second), result[2])
|
||||
})
|
||||
|
||||
t.Run("handles duplicate timestamps", func(t *testing.T) {
|
||||
baseTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
input := []time.Time{
|
||||
baseTime,
|
||||
baseTime.Add(20 * time.Second),
|
||||
}
|
||||
item := baseTime // duplicate of first element
|
||||
|
||||
result := addToSlice(ordTime, input, item)
|
||||
|
||||
assert.Len(t, result, 3, "result should have 3 elements including duplicate")
|
||||
// Both instances of baseTime should be present
|
||||
count := 0
|
||||
for _, t := range result {
|
||||
if t.Equal(baseTime) {
|
||||
count++
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 2, count, "should have 2 instances of the duplicate timestamp")
|
||||
})
|
||||
|
||||
t.Run("maintains sort order with unsorted input", func(t *testing.T) {
|
||||
baseTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
// Input is intentionally unsorted
|
||||
input := []time.Time{
|
||||
baseTime.Add(40 * time.Second),
|
||||
baseTime,
|
||||
baseTime.Add(20 * time.Second),
|
||||
}
|
||||
item := baseTime.Add(30 * time.Second)
|
||||
|
||||
result := addToSlice(ordTime, input, item)
|
||||
|
||||
assert.Len(t, result, 4, "result should have 4 elements")
|
||||
// Verify result is sorted regardless of input order
|
||||
for i := 0; i < len(result)-1; i++ {
|
||||
assert.True(t, result[i].Before(result[i+1]) || result[i].Equal(result[i+1]),
|
||||
"result should be sorted: element %d (%v) should be <= element %d (%v)",
|
||||
i, result[i], i+1, result[i+1])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("works with nanosecond precision", func(t *testing.T) {
|
||||
baseTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
input := []time.Time{
|
||||
baseTime,
|
||||
baseTime.Add(2 * time.Nanosecond),
|
||||
}
|
||||
item := baseTime.Add(1 * time.Nanosecond)
|
||||
|
||||
result := addToSlice(ordTime, input, item)
|
||||
|
||||
assert.Len(t, result, 3, "result should have 3 elements")
|
||||
assert.Equal(t, baseTime, result[0])
|
||||
assert.Equal(t, baseTime.Add(1*time.Nanosecond), result[1])
|
||||
assert.Equal(t, baseTime.Add(2*time.Nanosecond), result[2])
|
||||
})
|
||||
}
|
||||
|
||||
func TestMakeClosedStateHistory(t *testing.T) {
|
||||
t.Run("creates a valid ClosedState", func(t *testing.T) {
|
||||
timeWindow := 1 * time.Minute
|
||||
maxFailures := uint(3)
|
||||
state := MakeClosedStateHistory(timeWindow, maxFailures)
|
||||
|
||||
assert.NotNil(t, state, "MakeClosedStateHistory should return a non-nil ClosedState")
|
||||
})
|
||||
|
||||
t.Run("initial state passes Check", func(t *testing.T) {
|
||||
timeWindow := 1 * time.Minute
|
||||
maxFailures := uint(3)
|
||||
state := MakeClosedStateHistory(timeWindow, maxFailures)
|
||||
now := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
result := state.Check(now)
|
||||
|
||||
assert.True(t, option.IsSome(result), "initial state should pass Check (return Some, circuit stays closed)")
|
||||
})
|
||||
|
||||
t.Run("Empty purges failure history", func(t *testing.T) {
|
||||
timeWindow := 1 * time.Minute
|
||||
maxFailures := uint(2)
|
||||
state := MakeClosedStateHistory(timeWindow, maxFailures)
|
||||
baseTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
// Add some errors
|
||||
state = state.AddError(baseTime)
|
||||
state = state.AddError(baseTime.Add(10 * time.Second))
|
||||
|
||||
// Reset the state
|
||||
state = state.Empty()
|
||||
|
||||
// Should pass check after reset
|
||||
result := state.Check(baseTime.Add(20 * time.Second))
|
||||
assert.True(t, option.IsSome(result), "state should pass Check after Empty")
|
||||
})
|
||||
|
||||
t.Run("AddSuccess purges entire failure history", func(t *testing.T) {
|
||||
timeWindow := 1 * time.Minute
|
||||
maxFailures := uint(3)
|
||||
state := MakeClosedStateHistory(timeWindow, maxFailures)
|
||||
baseTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
// Add errors
|
||||
state = state.AddError(baseTime)
|
||||
state = state.AddError(baseTime.Add(10 * time.Second))
|
||||
|
||||
// Add success (should purge all history)
|
||||
state = state.AddSuccess(baseTime.Add(20 * time.Second))
|
||||
|
||||
// Add another error (this is now the first error in history)
|
||||
state = state.AddError(baseTime.Add(30 * time.Second))
|
||||
|
||||
// Should still pass check (only 1 error in history, threshold is 3)
|
||||
result := state.Check(baseTime.Add(30 * time.Second))
|
||||
assert.True(t, option.IsSome(result), "AddSuccess should purge entire failure history")
|
||||
})
|
||||
|
||||
t.Run("circuit opens when failures reach threshold within time window", func(t *testing.T) {
|
||||
timeWindow := 1 * time.Minute
|
||||
maxFailures := uint(3)
|
||||
state := MakeClosedStateHistory(timeWindow, maxFailures)
|
||||
baseTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
// Add errors within time window but not reaching threshold
|
||||
state = state.AddError(baseTime)
|
||||
state = state.AddError(baseTime.Add(10 * time.Second))
|
||||
|
||||
// Should still pass before threshold
|
||||
result := state.Check(baseTime.Add(20 * time.Second))
|
||||
assert.True(t, option.IsSome(result), "should pass Check before threshold")
|
||||
|
||||
// Add one more error to reach threshold
|
||||
state = state.AddError(baseTime.Add(30 * time.Second))
|
||||
|
||||
// Should fail check at threshold
|
||||
result = state.Check(baseTime.Add(30 * time.Second))
|
||||
assert.True(t, option.IsNone(result), "should fail Check when reaching threshold")
|
||||
})
|
||||
|
||||
t.Run("old failures outside time window are automatically removed", func(t *testing.T) {
|
||||
timeWindow := 1 * time.Minute
|
||||
maxFailures := uint(3)
|
||||
state := MakeClosedStateHistory(timeWindow, maxFailures)
|
||||
baseTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
// Add errors that will be outside the time window
|
||||
state = state.AddError(baseTime)
|
||||
state = state.AddError(baseTime.Add(10 * time.Second))
|
||||
|
||||
// Add error after time window has passed (this should remove old errors)
|
||||
state = state.AddError(baseTime.Add(2 * time.Minute))
|
||||
|
||||
// Should pass check (only 1 error in window, old ones removed)
|
||||
result := state.Check(baseTime.Add(2 * time.Minute))
|
||||
assert.True(t, option.IsSome(result), "old failures should be removed from history")
|
||||
})
|
||||
|
||||
t.Run("failures within time window are retained", func(t *testing.T) {
|
||||
timeWindow := 1 * time.Minute
|
||||
maxFailures := uint(3)
|
||||
state := MakeClosedStateHistory(timeWindow, maxFailures)
|
||||
baseTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
// Add errors within time window
|
||||
state = state.AddError(baseTime)
|
||||
state = state.AddError(baseTime.Add(30 * time.Second))
|
||||
state = state.AddError(baseTime.Add(50 * time.Second))
|
||||
|
||||
// All errors are within 1 minute window, should fail check
|
||||
result := state.Check(baseTime.Add(50 * time.Second))
|
||||
assert.True(t, option.IsNone(result), "failures within time window should be retained")
|
||||
})
|
||||
|
||||
t.Run("sliding window behavior - errors slide out over time", func(t *testing.T) {
|
||||
timeWindow := 1 * time.Minute
|
||||
maxFailures := uint(3)
|
||||
state := MakeClosedStateHistory(timeWindow, maxFailures)
|
||||
baseTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
// Add 3 errors to reach threshold
|
||||
state = state.AddError(baseTime)
|
||||
state = state.AddError(baseTime.Add(10 * time.Second))
|
||||
state = state.AddError(baseTime.Add(20 * time.Second))
|
||||
|
||||
// Circuit should be open
|
||||
result := state.Check(baseTime.Add(20 * time.Second))
|
||||
assert.True(t, option.IsNone(result), "circuit should be open with 3 failures")
|
||||
|
||||
// Add error after first failure has expired (> 1 minute from first error)
|
||||
// This should remove the first error, leaving only 3 in window
|
||||
state = state.AddError(baseTime.Add(70 * time.Second))
|
||||
|
||||
// Should still fail check (3 errors in window after pruning)
|
||||
result = state.Check(baseTime.Add(70 * time.Second))
|
||||
assert.True(t, option.IsNone(result), "circuit should remain open with 3 failures in window")
|
||||
})
|
||||
|
||||
t.Run("zero maxFailures means circuit is always open", func(t *testing.T) {
|
||||
timeWindow := 1 * time.Minute
|
||||
maxFailures := uint(0)
|
||||
state := MakeClosedStateHistory(timeWindow, maxFailures)
|
||||
baseTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
// Initial state should already fail (0 >= 0)
|
||||
result := state.Check(baseTime)
|
||||
assert.True(t, option.IsNone(result), "initial state should fail Check with maxFailures=0")
|
||||
|
||||
// Add one error
|
||||
state = state.AddError(baseTime)
|
||||
|
||||
// Should still fail
|
||||
result = state.Check(baseTime)
|
||||
assert.True(t, option.IsNone(result), "should fail Check after error with maxFailures=0")
|
||||
})
|
||||
|
||||
t.Run("success purges history even with failures in time window", func(t *testing.T) {
|
||||
timeWindow := 1 * time.Minute
|
||||
maxFailures := uint(3)
|
||||
state := MakeClosedStateHistory(timeWindow, maxFailures)
|
||||
baseTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
// Add errors within time window
|
||||
state = state.AddError(baseTime)
|
||||
state = state.AddError(baseTime.Add(10 * time.Second))
|
||||
|
||||
// Add success (purges all history)
|
||||
state = state.AddSuccess(baseTime.Add(20 * time.Second))
|
||||
|
||||
// Add more errors
|
||||
state = state.AddError(baseTime.Add(30 * time.Second))
|
||||
state = state.AddError(baseTime.Add(40 * time.Second))
|
||||
|
||||
// Should still pass (only 2 errors after purge)
|
||||
result := state.Check(baseTime.Add(40 * time.Second))
|
||||
assert.True(t, option.IsSome(result), "success should purge all history")
|
||||
|
||||
// Add one more to reach threshold
|
||||
state = state.AddError(baseTime.Add(50 * time.Second))
|
||||
|
||||
// Should fail at threshold
|
||||
result = state.Check(baseTime.Add(50 * time.Second))
|
||||
assert.True(t, option.IsNone(result), "should fail after reaching threshold")
|
||||
})
|
||||
|
||||
t.Run("multiple successes keep history empty", func(t *testing.T) {
|
||||
timeWindow := 1 * time.Minute
|
||||
maxFailures := uint(2)
|
||||
state := MakeClosedStateHistory(timeWindow, maxFailures)
|
||||
baseTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
// Add error
|
||||
state = state.AddError(baseTime)
|
||||
|
||||
// Multiple successes
|
||||
state = state.AddSuccess(baseTime.Add(10 * time.Second))
|
||||
state = state.AddSuccess(baseTime.Add(20 * time.Second))
|
||||
state = state.AddSuccess(baseTime.Add(30 * time.Second))
|
||||
|
||||
// Should still pass
|
||||
result := state.Check(baseTime.Add(30 * time.Second))
|
||||
assert.True(t, option.IsSome(result), "multiple AddSuccess should keep history empty")
|
||||
|
||||
// Add errors to reach threshold
|
||||
state = state.AddError(baseTime.Add(40 * time.Second))
|
||||
state = state.AddError(baseTime.Add(50 * time.Second))
|
||||
|
||||
// Should fail
|
||||
result = state.Check(baseTime.Add(50 * time.Second))
|
||||
assert.True(t, option.IsNone(result), "should fail after reaching threshold")
|
||||
})
|
||||
|
||||
t.Run("state is immutable - original unchanged after AddError", func(t *testing.T) {
|
||||
timeWindow := 1 * time.Minute
|
||||
maxFailures := uint(2)
|
||||
originalState := MakeClosedStateHistory(timeWindow, maxFailures)
|
||||
baseTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
// Create new state by adding error
|
||||
newState := originalState.AddError(baseTime)
|
||||
|
||||
// Original should still pass check
|
||||
result := originalState.Check(baseTime)
|
||||
assert.True(t, option.IsSome(result), "original state should be unchanged")
|
||||
|
||||
// New state should reach threshold after another error
|
||||
newState = newState.AddError(baseTime.Add(10 * time.Second))
|
||||
|
||||
result = newState.Check(baseTime.Add(10 * time.Second))
|
||||
assert.True(t, option.IsNone(result), "new state should fail after reaching threshold")
|
||||
|
||||
// Original should still pass
|
||||
result = originalState.Check(baseTime)
|
||||
assert.True(t, option.IsSome(result), "original state should still be unchanged")
|
||||
})
|
||||
|
||||
t.Run("state is immutable - original unchanged after Empty", func(t *testing.T) {
|
||||
timeWindow := 1 * time.Minute
|
||||
maxFailures := uint(2)
|
||||
state := MakeClosedStateHistory(timeWindow, maxFailures)
|
||||
baseTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
// Add errors to original
|
||||
state = state.AddError(baseTime)
|
||||
state = state.AddError(baseTime.Add(10 * time.Second))
|
||||
stateWithErrors := state
|
||||
|
||||
// Create new state by calling Empty
|
||||
emptyState := stateWithErrors.Empty()
|
||||
|
||||
// Original with errors should fail check
|
||||
result := stateWithErrors.Check(baseTime.Add(10 * time.Second))
|
||||
assert.True(t, option.IsNone(result), "state with errors should fail after reaching threshold")
|
||||
|
||||
// Empty state should pass
|
||||
result = emptyState.Check(baseTime.Add(10 * time.Second))
|
||||
assert.True(t, option.IsSome(result), "empty state should pass Check")
|
||||
})
|
||||
|
||||
t.Run("exact time window boundary behavior", func(t *testing.T) {
|
||||
timeWindow := 1 * time.Minute
|
||||
maxFailures := uint(3)
|
||||
state := MakeClosedStateHistory(timeWindow, maxFailures)
|
||||
baseTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
// Add error at baseTime
|
||||
state = state.AddError(baseTime)
|
||||
|
||||
// Add error exactly at time window boundary
|
||||
state = state.AddError(baseTime.Add(1 * time.Minute))
|
||||
|
||||
// The first error should be removed (it's now outside the window)
|
||||
// Only 1 error should remain
|
||||
result := state.Check(baseTime.Add(1 * time.Minute))
|
||||
assert.True(t, option.IsSome(result), "error at exact window boundary should remove older errors")
|
||||
})
|
||||
|
||||
t.Run("multiple errors at same timestamp", func(t *testing.T) {
|
||||
timeWindow := 1 * time.Minute
|
||||
maxFailures := uint(3)
|
||||
state := MakeClosedStateHistory(timeWindow, maxFailures)
|
||||
baseTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
// Add multiple errors at same time
|
||||
state = state.AddError(baseTime)
|
||||
state = state.AddError(baseTime)
|
||||
state = state.AddError(baseTime)
|
||||
|
||||
// Should fail check (3 errors at same time)
|
||||
result := state.Check(baseTime)
|
||||
assert.True(t, option.IsNone(result), "multiple errors at same timestamp should count separately")
|
||||
})
|
||||
|
||||
t.Run("errors added out of chronological order are sorted", func(t *testing.T) {
|
||||
timeWindow := 1 * time.Minute
|
||||
maxFailures := uint(4)
|
||||
state := MakeClosedStateHistory(timeWindow, maxFailures)
|
||||
baseTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
// Add errors out of order
|
||||
state = state.AddError(baseTime.Add(30 * time.Second))
|
||||
state = state.AddError(baseTime.Add(5 * time.Second))
|
||||
state = state.AddError(baseTime.Add(50 * time.Second))
|
||||
|
||||
// Add error that should trigger pruning
|
||||
state = state.AddError(baseTime.Add(70 * time.Second))
|
||||
|
||||
// The error at 5s should be removed (> 1 minute from 70s: 70-5=65 > 60)
|
||||
// Should have 3 errors remaining (30s, 50s, 70s)
|
||||
result := state.Check(baseTime.Add(70 * time.Second))
|
||||
assert.True(t, option.IsSome(result), "errors should be sorted and pruned correctly")
|
||||
})
|
||||
|
||||
t.Run("large time window with many failures", func(t *testing.T) {
|
||||
timeWindow := 24 * time.Hour
|
||||
maxFailures := uint(100)
|
||||
state := MakeClosedStateHistory(timeWindow, maxFailures)
|
||||
baseTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
// Add many failures within the window
|
||||
for i := 0; i < 99; i++ {
|
||||
state = state.AddError(baseTime.Add(time.Duration(i) * time.Minute))
|
||||
}
|
||||
|
||||
// Should still pass (99 < 100)
|
||||
result := state.Check(baseTime.Add(99 * time.Minute))
|
||||
assert.True(t, option.IsSome(result), "should pass with 99 failures when threshold is 100")
|
||||
|
||||
// Add one more to reach threshold
|
||||
state = state.AddError(baseTime.Add(100 * time.Minute))
|
||||
|
||||
// Should fail
|
||||
result = state.Check(baseTime.Add(100 * time.Minute))
|
||||
assert.True(t, option.IsNone(result), "should fail at threshold with large window")
|
||||
})
|
||||
|
||||
t.Run("very short time window", func(t *testing.T) {
|
||||
timeWindow := 100 * time.Millisecond
|
||||
maxFailures := uint(3)
|
||||
state := MakeClosedStateHistory(timeWindow, maxFailures)
|
||||
baseTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
// Add errors within short window
|
||||
state = state.AddError(baseTime)
|
||||
state = state.AddError(baseTime.Add(50 * time.Millisecond))
|
||||
state = state.AddError(baseTime.Add(90 * time.Millisecond))
|
||||
|
||||
// Should fail (3 errors within 100ms)
|
||||
result := state.Check(baseTime.Add(90 * time.Millisecond))
|
||||
assert.True(t, option.IsNone(result), "should fail with errors in short time window")
|
||||
|
||||
// Add error after window expires
|
||||
state = state.AddError(baseTime.Add(200 * time.Millisecond))
|
||||
|
||||
// Should pass (old errors removed, only 1 in window)
|
||||
result = state.Check(baseTime.Add(200 * time.Millisecond))
|
||||
assert.True(t, option.IsSome(result), "should pass after short window expires")
|
||||
})
|
||||
|
||||
t.Run("success prevents circuit from opening", func(t *testing.T) {
|
||||
timeWindow := 1 * time.Minute
|
||||
maxFailures := uint(3)
|
||||
state := MakeClosedStateHistory(timeWindow, maxFailures)
|
||||
baseTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
// Add errors close to threshold
|
||||
state = state.AddError(baseTime)
|
||||
state = state.AddError(baseTime.Add(10 * time.Second))
|
||||
|
||||
// Add success before reaching threshold
|
||||
state = state.AddSuccess(baseTime.Add(20 * time.Second))
|
||||
|
||||
// Add more errors
|
||||
state = state.AddError(baseTime.Add(30 * time.Second))
|
||||
state = state.AddError(baseTime.Add(40 * time.Second))
|
||||
|
||||
// Should still pass (only 2 errors after success purge)
|
||||
result := state.Check(baseTime.Add(40 * time.Second))
|
||||
assert.True(t, option.IsSome(result), "circuit should stay closed after success purge")
|
||||
})
|
||||
|
||||
t.Run("Empty can be called multiple times", func(t *testing.T) {
|
||||
timeWindow := 1 * time.Minute
|
||||
maxFailures := uint(2)
|
||||
state := MakeClosedStateHistory(timeWindow, maxFailures)
|
||||
baseTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
// Add errors
|
||||
state = state.AddError(baseTime)
|
||||
state = state.AddError(baseTime.Add(10 * time.Second))
|
||||
state = state.AddError(baseTime.Add(20 * time.Second))
|
||||
|
||||
// Reset multiple times
|
||||
state = state.Empty()
|
||||
state = state.Empty()
|
||||
state = state.Empty()
|
||||
|
||||
// Should still pass
|
||||
result := state.Check(baseTime.Add(30 * time.Second))
|
||||
assert.True(t, option.IsSome(result), "state should pass Check after multiple Empty calls")
|
||||
})
|
||||
|
||||
t.Run("gradual failure accumulation within window", func(t *testing.T) {
|
||||
timeWindow := 1 * time.Minute
|
||||
maxFailures := uint(5)
|
||||
state := MakeClosedStateHistory(timeWindow, maxFailures)
|
||||
baseTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
// Add failures gradually
|
||||
state = state.AddError(baseTime)
|
||||
state = state.AddError(baseTime.Add(15 * time.Second))
|
||||
state = state.AddError(baseTime.Add(30 * time.Second))
|
||||
state = state.AddError(baseTime.Add(45 * time.Second))
|
||||
|
||||
// Should still pass (4 < 5)
|
||||
result := state.Check(baseTime.Add(45 * time.Second))
|
||||
assert.True(t, option.IsSome(result), "should pass before threshold")
|
||||
|
||||
// Add one more within window
|
||||
state = state.AddError(baseTime.Add(55 * time.Second))
|
||||
|
||||
// Should fail (5 >= 5)
|
||||
result = state.Check(baseTime.Add(55 * time.Second))
|
||||
assert.True(t, option.IsNone(result), "should fail at threshold")
|
||||
})
|
||||
}
|
||||
335
v2/circuitbreaker/error.go
Normal file
335
v2/circuitbreaker/error.go
Normal file
@@ -0,0 +1,335 @@
|
||||
// Package circuitbreaker provides error types and utilities for circuit breaker implementations.
|
||||
package circuitbreaker
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
E "github.com/IBM/fp-go/v2/errors"
|
||||
FH "github.com/IBM/fp-go/v2/http"
|
||||
"github.com/IBM/fp-go/v2/option"
|
||||
)
|
||||
|
||||
// CircuitBreakerError represents an error that occurs when a circuit breaker is in the open state.
|
||||
//
|
||||
// When a circuit breaker opens due to too many failures, it prevents further operations
|
||||
// from executing until a reset time is reached. This error type communicates that state
|
||||
// and provides information about when the circuit breaker will attempt to close again.
|
||||
//
|
||||
// Fields:
|
||||
// - Name: The name identifying this circuit breaker instance
|
||||
// - ResetAt: The time at which the circuit breaker will transition from open to half-open state
|
||||
//
|
||||
// Thread Safety: This type is immutable and safe for concurrent use.
|
||||
type CircuitBreakerError struct {
|
||||
Name string
|
||||
ResetAt time.Time
|
||||
}
|
||||
|
||||
// Error implements the error interface for CircuitBreakerError.
|
||||
//
|
||||
// Returns a formatted error message indicating that the circuit breaker is open
|
||||
// and when it will attempt to close.
|
||||
//
|
||||
// Returns:
|
||||
// - A string describing the circuit breaker state and reset time
|
||||
//
|
||||
// Thread Safety: This method is safe for concurrent use as it only reads immutable fields.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// err := &CircuitBreakerError{Name: "API", ResetAt: time.Now().Add(30 * time.Second)}
|
||||
// fmt.Println(err.Error())
|
||||
// // Output: circuit breaker is open [API], will close at 2026-01-09 12:20:47.123 +0100 CET
|
||||
func (e *CircuitBreakerError) Error() string {
|
||||
return fmt.Sprintf("circuit breaker is open [%s], will close at %s", e.Name, e.ResetAt)
|
||||
}
|
||||
|
||||
// MakeCircuitBreakerErrorWithName creates a circuit breaker error constructor with a custom name.
|
||||
//
|
||||
// This function returns a constructor that creates CircuitBreakerError instances with a specific
|
||||
// circuit breaker name. This is useful when you have multiple circuit breakers in your system
|
||||
// and want to identify which one is open in error messages.
|
||||
//
|
||||
// Parameters:
|
||||
// - name: The name to identify this circuit breaker in error messages
|
||||
//
|
||||
// Returns:
|
||||
// - A function that takes a reset time and returns a CircuitBreakerError with the specified name
|
||||
//
|
||||
// Thread Safety: The returned function is safe for concurrent use as it creates new error
|
||||
// instances on each call.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// makeDBError := MakeCircuitBreakerErrorWithName("Database Circuit Breaker")
|
||||
// err := makeDBError(time.Now().Add(30 * time.Second))
|
||||
// fmt.Println(err.Error())
|
||||
// // Output: circuit breaker is open [Database Circuit Breaker], will close at 2026-01-09 12:20:47.123 +0100 CET
|
||||
func MakeCircuitBreakerErrorWithName(name string) func(time.Time) error {
|
||||
return func(resetTime time.Time) error {
|
||||
return &CircuitBreakerError{Name: name, ResetAt: resetTime}
|
||||
}
|
||||
}
|
||||
|
||||
// MakeCircuitBreakerError creates a new CircuitBreakerError with the specified reset time.
|
||||
//
|
||||
// This constructor function creates a circuit breaker error that indicates when the
|
||||
// circuit breaker will transition from the open state to the half-open state, allowing
|
||||
// test requests to determine if the underlying service has recovered.
|
||||
//
|
||||
// Parameters:
|
||||
// - resetTime: The time at which the circuit breaker will attempt to close
|
||||
//
|
||||
// Returns:
|
||||
// - An error representing the circuit breaker open state
|
||||
//
|
||||
// Thread Safety: This function is safe for concurrent use as it creates new error
|
||||
// instances on each call.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// resetTime := time.Now().Add(30 * time.Second)
|
||||
// err := MakeCircuitBreakerError(resetTime)
|
||||
// if cbErr, ok := err.(*CircuitBreakerError); ok {
|
||||
// fmt.Printf("Circuit breaker will reset at: %s\n", cbErr.ResetAt)
|
||||
// }
|
||||
var MakeCircuitBreakerError = MakeCircuitBreakerErrorWithName("Generic Circuit Breaker")
|
||||
|
||||
// AnyError converts an error to an Option, wrapping non-nil errors in Some and nil errors in None.
|
||||
//
|
||||
// This variable provides a functional way to handle errors by converting them to Option types.
|
||||
// It's particularly useful in functional programming contexts where you want to treat errors
|
||||
// as optional values rather than using traditional error handling patterns.
|
||||
//
|
||||
// Behavior:
|
||||
// - If the error is non-nil, returns Some(error)
|
||||
// - If the error is nil, returns None
|
||||
//
|
||||
// Thread Safety: This function is pure and safe for concurrent use.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// err := errors.New("something went wrong")
|
||||
// optErr := AnyError(err) // Some(error)
|
||||
//
|
||||
// var noErr error = nil
|
||||
// optNoErr := AnyError(noErr) // None
|
||||
//
|
||||
// // Using in functional pipelines
|
||||
// result := F.Pipe2(
|
||||
// someOperation(),
|
||||
// AnyError,
|
||||
// O.Map(func(e error) string { return e.Error() }),
|
||||
// )
|
||||
var AnyError = option.FromPredicate(E.IsNonNil)
|
||||
|
||||
// shouldOpenCircuit determines if an error should cause a circuit breaker to open.
|
||||
//
|
||||
// This function checks if an error represents an infrastructure or server problem
|
||||
// that indicates the service is unhealthy and should trigger circuit breaker protection.
|
||||
// It examines both the error type and, for HTTP errors, the status code.
|
||||
//
|
||||
// Errors that should open the circuit include:
|
||||
// - HTTP 5xx server errors (500-599) indicating server-side problems
|
||||
// - Network errors (connection refused, connection reset, timeouts)
|
||||
// - DNS resolution errors
|
||||
// - TLS/certificate errors
|
||||
// - Other infrastructure-related errors
|
||||
//
|
||||
// Errors that should NOT open the circuit include:
|
||||
// - HTTP 4xx client errors (bad request, unauthorized, not found, etc.)
|
||||
// - Application-level validation errors
|
||||
// - Business logic errors
|
||||
//
|
||||
// The function unwraps error chains to find the root cause, making it compatible
|
||||
// with wrapped errors created by fmt.Errorf with %w or errors.Join.
|
||||
//
|
||||
// Parameters:
|
||||
// - err: The error to evaluate (may be nil)
|
||||
//
|
||||
// Returns:
|
||||
// - true if the error should cause the circuit to open, false otherwise
|
||||
//
|
||||
// Thread Safety: This function is pure and safe for concurrent use. It does not
|
||||
// modify any state.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// // HTTP 500 error - should open circuit
|
||||
// httpErr := &FH.HttpError{...} // status 500
|
||||
// if shouldOpenCircuit(httpErr) {
|
||||
// // Open circuit breaker
|
||||
// }
|
||||
//
|
||||
// // HTTP 404 error - should NOT open circuit (client error)
|
||||
// notFoundErr := &FH.HttpError{...} // status 404
|
||||
// if !shouldOpenCircuit(notFoundErr) {
|
||||
// // Don't open circuit, this is a client error
|
||||
// }
|
||||
//
|
||||
// // Network timeout - should open circuit
|
||||
// timeoutErr := &net.OpError{Op: "dial", Err: syscall.ETIMEDOUT}
|
||||
// if shouldOpenCircuit(timeoutErr) {
|
||||
// // Open circuit breaker
|
||||
// }
|
||||
func shouldOpenCircuit(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for HTTP errors with server status codes (5xx)
|
||||
var httpErr *FH.HttpError
|
||||
if errors.As(err, &httpErr) {
|
||||
statusCode := httpErr.StatusCode()
|
||||
// Only 5xx errors should open the circuit
|
||||
// 4xx errors are client errors and shouldn't affect circuit state
|
||||
return statusCode >= http.StatusInternalServerError && statusCode < 600
|
||||
}
|
||||
|
||||
// Check for network operation errors
|
||||
var opErr *net.OpError
|
||||
if errors.As(err, &opErr) {
|
||||
// Network timeouts should open the circuit
|
||||
if opErr.Timeout() {
|
||||
return true
|
||||
}
|
||||
// Check the underlying error
|
||||
if opErr.Err != nil {
|
||||
return isInfrastructureError(opErr.Err)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for DNS errors
|
||||
var dnsErr *net.DNSError
|
||||
if errors.As(err, &dnsErr) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for URL errors (often wrap network errors)
|
||||
var urlErr *url.Error
|
||||
if errors.As(err, &urlErr) {
|
||||
if urlErr.Timeout() {
|
||||
return true
|
||||
}
|
||||
// Recursively check the wrapped error
|
||||
return shouldOpenCircuit(urlErr.Err)
|
||||
}
|
||||
|
||||
// Check for specific syscall errors that indicate infrastructure problems
|
||||
return isInfrastructureError(err) || isTLSError(err)
|
||||
}
|
||||
|
||||
// isInfrastructureError checks if an error is a low-level infrastructure error
|
||||
// that should cause the circuit to open.
|
||||
//
|
||||
// This function examines syscall errors to identify network and system-level failures
|
||||
// that indicate the service is unavailable or unreachable.
|
||||
//
|
||||
// Infrastructure errors include:
|
||||
// - ECONNREFUSED: Connection refused (service not listening)
|
||||
// - ECONNRESET: Connection reset by peer (service crashed or network issue)
|
||||
// - ECONNABORTED: Connection aborted (network issue)
|
||||
// - ENETUNREACH: Network unreachable (routing problem)
|
||||
// - EHOSTUNREACH: Host unreachable (host down or network issue)
|
||||
// - EPIPE: Broken pipe (connection closed unexpectedly)
|
||||
// - ETIMEDOUT: Operation timed out (service not responding)
|
||||
//
|
||||
// Parameters:
|
||||
// - err: The error to check
|
||||
//
|
||||
// Returns:
|
||||
// - true if the error is an infrastructure error, false otherwise
|
||||
//
|
||||
// Thread Safety: This function is pure and safe for concurrent use.
|
||||
func isInfrastructureError(err error) bool {
|
||||
|
||||
var syscallErr *syscall.Errno
|
||||
|
||||
if errors.As(err, &syscallErr) {
|
||||
switch *syscallErr {
|
||||
case syscall.ECONNREFUSED,
|
||||
syscall.ECONNRESET,
|
||||
syscall.ECONNABORTED,
|
||||
syscall.ENETUNREACH,
|
||||
syscall.EHOSTUNREACH,
|
||||
syscall.EPIPE,
|
||||
syscall.ETIMEDOUT:
|
||||
return true
|
||||
}
|
||||
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isTLSError checks if an error is a TLS/certificate error that should cause the circuit to open.
|
||||
//
|
||||
// TLS errors typically indicate infrastructure or configuration problems that prevent
|
||||
// secure communication with the service. These errors suggest the service is not properly
|
||||
// configured or accessible.
|
||||
//
|
||||
// TLS errors include:
|
||||
// - Certificate verification failures (invalid, expired, or malformed certificates)
|
||||
// - Unknown certificate authority errors (untrusted CA)
|
||||
//
|
||||
// Parameters:
|
||||
// - err: The error to check
|
||||
//
|
||||
// Returns:
|
||||
// - true if the error is a TLS/certificate error, false otherwise
|
||||
//
|
||||
// Thread Safety: This function is pure and safe for concurrent use.
|
||||
func isTLSError(err error) bool {
|
||||
// Certificate verification failed
|
||||
var certErr *x509.CertificateInvalidError
|
||||
if errors.As(err, &certErr) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Unknown authority
|
||||
var unknownAuthErr *x509.UnknownAuthorityError
|
||||
if errors.As(err, &unknownAuthErr) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// InfrastructureError is a predicate that converts errors to Options based on whether
|
||||
// they should trigger circuit breaker opening.
|
||||
//
|
||||
// This variable provides a functional way to filter errors that represent infrastructure
|
||||
// failures (network issues, server errors, timeouts, etc.) from application-level errors
|
||||
// (validation errors, business logic errors, client errors).
|
||||
//
|
||||
// Behavior:
|
||||
// - Returns Some(error) if the error should open the circuit (infrastructure failure)
|
||||
// - Returns None if the error should not open the circuit (application error)
|
||||
//
|
||||
// Thread Safety: This function is pure and safe for concurrent use.
|
||||
//
|
||||
// Use this in circuit breaker configurations to determine which errors should count
|
||||
// toward the failure threshold.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// // In a circuit breaker configuration
|
||||
// breaker := MakeCircuitBreaker(
|
||||
// ...,
|
||||
// checkError: InfrastructureError, // Only infrastructure errors open the circuit
|
||||
// ...,
|
||||
// )
|
||||
//
|
||||
// // HTTP 500 error - returns Some(error)
|
||||
// result := InfrastructureError(&FH.HttpError{...}) // Some(error)
|
||||
//
|
||||
// // HTTP 404 error - returns None
|
||||
// result := InfrastructureError(&FH.HttpError{...}) // None
|
||||
var InfrastructureError = option.FromPredicate(shouldOpenCircuit)
|
||||
503
v2/circuitbreaker/error_test.go
Normal file
503
v2/circuitbreaker/error_test.go
Normal file
@@ -0,0 +1,503 @@
|
||||
package circuitbreaker
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
FH "github.com/IBM/fp-go/v2/http"
|
||||
"github.com/IBM/fp-go/v2/option"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestCircuitBreakerError tests the CircuitBreakerError type
|
||||
func TestCircuitBreakerError(t *testing.T) {
|
||||
t.Run("Error returns formatted message with reset time", func(t *testing.T) {
|
||||
resetTime := time.Date(2026, 1, 9, 12, 30, 0, 0, time.UTC)
|
||||
err := &CircuitBreakerError{ResetAt: resetTime}
|
||||
|
||||
result := err.Error()
|
||||
|
||||
assert.Contains(t, result, "circuit breaker is open")
|
||||
assert.Contains(t, result, "will close at")
|
||||
assert.Contains(t, result, resetTime.String())
|
||||
})
|
||||
|
||||
t.Run("Error message includes full timestamp", func(t *testing.T) {
|
||||
resetTime := time.Now().Add(30 * time.Second)
|
||||
err := &CircuitBreakerError{ResetAt: resetTime}
|
||||
|
||||
result := err.Error()
|
||||
|
||||
assert.NotEmpty(t, result)
|
||||
assert.Contains(t, result, "circuit breaker is open")
|
||||
})
|
||||
}
|
||||
|
||||
// TestMakeCircuitBreakerError tests the constructor function
|
||||
func TestMakeCircuitBreakerError(t *testing.T) {
|
||||
t.Run("creates CircuitBreakerError with correct reset time", func(t *testing.T) {
|
||||
resetTime := time.Date(2026, 1, 9, 13, 0, 0, 0, time.UTC)
|
||||
|
||||
err := MakeCircuitBreakerError(resetTime)
|
||||
|
||||
assert.NotNil(t, err)
|
||||
cbErr, ok := err.(*CircuitBreakerError)
|
||||
assert.True(t, ok, "should return *CircuitBreakerError type")
|
||||
assert.Equal(t, resetTime, cbErr.ResetAt)
|
||||
})
|
||||
|
||||
t.Run("returns error interface", func(t *testing.T) {
|
||||
resetTime := time.Now().Add(1 * time.Minute)
|
||||
|
||||
err := MakeCircuitBreakerError(resetTime)
|
||||
|
||||
// Should be assignable to error interface
|
||||
var _ error = err
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
|
||||
t.Run("created error can be type asserted", func(t *testing.T) {
|
||||
resetTime := time.Now().Add(45 * time.Second)
|
||||
|
||||
err := MakeCircuitBreakerError(resetTime)
|
||||
|
||||
cbErr, ok := err.(*CircuitBreakerError)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, resetTime, cbErr.ResetAt)
|
||||
})
|
||||
}
|
||||
|
||||
// TestAnyError tests the AnyError function
|
||||
func TestAnyError(t *testing.T) {
|
||||
t.Run("returns Some for non-nil error", func(t *testing.T) {
|
||||
err := errors.New("test error")
|
||||
|
||||
result := AnyError(err)
|
||||
|
||||
assert.True(t, option.IsSome(result), "should return Some for non-nil error")
|
||||
value := option.GetOrElse(func() error { return nil })(result)
|
||||
assert.Equal(t, err, value)
|
||||
})
|
||||
|
||||
t.Run("returns None for nil error", func(t *testing.T) {
|
||||
var err error = nil
|
||||
|
||||
result := AnyError(err)
|
||||
|
||||
assert.True(t, option.IsNone(result), "should return None for nil error")
|
||||
})
|
||||
|
||||
t.Run("works with different error types", func(t *testing.T) {
|
||||
err1 := fmt.Errorf("wrapped: %w", errors.New("inner"))
|
||||
err2 := &CircuitBreakerError{ResetAt: time.Now()}
|
||||
|
||||
result1 := AnyError(err1)
|
||||
result2 := AnyError(err2)
|
||||
|
||||
assert.True(t, option.IsSome(result1))
|
||||
assert.True(t, option.IsSome(result2))
|
||||
})
|
||||
}
|
||||
|
||||
// TestShouldOpenCircuit tests the shouldOpenCircuit function
|
||||
func TestShouldOpenCircuit(t *testing.T) {
|
||||
t.Run("returns false for nil error", func(t *testing.T) {
|
||||
result := shouldOpenCircuit(nil)
|
||||
assert.False(t, result)
|
||||
})
|
||||
|
||||
t.Run("HTTP 5xx errors should open circuit", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
expected bool
|
||||
}{
|
||||
{"500 Internal Server Error", 500, true},
|
||||
{"501 Not Implemented", 501, true},
|
||||
{"502 Bad Gateway", 502, true},
|
||||
{"503 Service Unavailable", 503, true},
|
||||
{"504 Gateway Timeout", 504, true},
|
||||
{"599 Custom Server Error", 599, true},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
testURL, _ := url.Parse("http://example.com")
|
||||
resp := &http.Response{
|
||||
StatusCode: tc.statusCode,
|
||||
Request: &http.Request{URL: testURL},
|
||||
Body: http.NoBody,
|
||||
}
|
||||
httpErr := FH.StatusCodeError(resp)
|
||||
|
||||
result := shouldOpenCircuit(httpErr)
|
||||
|
||||
assert.Equal(t, tc.expected, result)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("HTTP 4xx errors should NOT open circuit", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
expected bool
|
||||
}{
|
||||
{"400 Bad Request", 400, false},
|
||||
{"401 Unauthorized", 401, false},
|
||||
{"403 Forbidden", 403, false},
|
||||
{"404 Not Found", 404, false},
|
||||
{"422 Unprocessable Entity", 422, false},
|
||||
{"429 Too Many Requests", 429, false},
|
||||
{"499 Custom Client Error", 499, false},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
testURL, _ := url.Parse("http://example.com")
|
||||
resp := &http.Response{
|
||||
StatusCode: tc.statusCode,
|
||||
Request: &http.Request{URL: testURL},
|
||||
Body: http.NoBody,
|
||||
}
|
||||
httpErr := FH.StatusCodeError(resp)
|
||||
|
||||
result := shouldOpenCircuit(httpErr)
|
||||
|
||||
assert.Equal(t, tc.expected, result)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("HTTP 2xx and 3xx should NOT open circuit", func(t *testing.T) {
|
||||
testCases := []int{200, 201, 204, 301, 302, 304}
|
||||
|
||||
for _, statusCode := range testCases {
|
||||
t.Run(fmt.Sprintf("Status %d", statusCode), func(t *testing.T) {
|
||||
testURL, _ := url.Parse("http://example.com")
|
||||
resp := &http.Response{
|
||||
StatusCode: statusCode,
|
||||
Request: &http.Request{URL: testURL},
|
||||
Body: http.NoBody,
|
||||
}
|
||||
httpErr := FH.StatusCodeError(resp)
|
||||
|
||||
result := shouldOpenCircuit(httpErr)
|
||||
|
||||
assert.False(t, result)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("network timeout errors should open circuit", func(t *testing.T) {
|
||||
opErr := &net.OpError{
|
||||
Op: "dial",
|
||||
Err: &timeoutError{},
|
||||
}
|
||||
|
||||
result := shouldOpenCircuit(opErr)
|
||||
|
||||
assert.True(t, result)
|
||||
})
|
||||
|
||||
t.Run("DNS errors should open circuit", func(t *testing.T) {
|
||||
dnsErr := &net.DNSError{
|
||||
Err: "no such host",
|
||||
Name: "example.com",
|
||||
}
|
||||
|
||||
result := shouldOpenCircuit(dnsErr)
|
||||
|
||||
assert.True(t, result)
|
||||
})
|
||||
|
||||
t.Run("URL timeout errors should open circuit", func(t *testing.T) {
|
||||
urlErr := &url.Error{
|
||||
Op: "Get",
|
||||
URL: "http://example.com",
|
||||
Err: &timeoutError{},
|
||||
}
|
||||
|
||||
result := shouldOpenCircuit(urlErr)
|
||||
|
||||
assert.True(t, result)
|
||||
})
|
||||
|
||||
t.Run("URL errors with nested network timeout should open circuit", func(t *testing.T) {
|
||||
urlErr := &url.Error{
|
||||
Op: "Get",
|
||||
URL: "http://example.com",
|
||||
Err: &net.OpError{
|
||||
Op: "dial",
|
||||
Err: &timeoutError{},
|
||||
},
|
||||
}
|
||||
|
||||
result := shouldOpenCircuit(urlErr)
|
||||
|
||||
assert.True(t, result)
|
||||
})
|
||||
|
||||
t.Run("OpError with nil Err should open circuit", func(t *testing.T) {
|
||||
opErr := &net.OpError{
|
||||
Op: "dial",
|
||||
Err: nil,
|
||||
}
|
||||
|
||||
result := shouldOpenCircuit(opErr)
|
||||
|
||||
assert.True(t, result)
|
||||
})
|
||||
|
||||
t.Run("wrapped HTTP 5xx error should open circuit", func(t *testing.T) {
|
||||
testURL, _ := url.Parse("http://example.com")
|
||||
resp := &http.Response{
|
||||
StatusCode: 503,
|
||||
Request: &http.Request{URL: testURL},
|
||||
Body: http.NoBody,
|
||||
}
|
||||
httpErr := FH.StatusCodeError(resp)
|
||||
wrappedErr := fmt.Errorf("service error: %w", httpErr)
|
||||
|
||||
result := shouldOpenCircuit(wrappedErr)
|
||||
|
||||
assert.True(t, result)
|
||||
})
|
||||
|
||||
t.Run("wrapped HTTP 4xx error should NOT open circuit", func(t *testing.T) {
|
||||
testURL, _ := url.Parse("http://example.com")
|
||||
resp := &http.Response{
|
||||
StatusCode: 404,
|
||||
Request: &http.Request{URL: testURL},
|
||||
Body: http.NoBody,
|
||||
}
|
||||
httpErr := FH.StatusCodeError(resp)
|
||||
wrappedErr := fmt.Errorf("not found: %w", httpErr)
|
||||
|
||||
result := shouldOpenCircuit(wrappedErr)
|
||||
|
||||
assert.False(t, result)
|
||||
})
|
||||
|
||||
t.Run("generic application error should NOT open circuit", func(t *testing.T) {
|
||||
err := errors.New("validation failed")
|
||||
|
||||
result := shouldOpenCircuit(err)
|
||||
|
||||
assert.False(t, result)
|
||||
})
|
||||
}
|
||||
|
||||
// TestIsInfrastructureError tests infrastructure error detection through shouldOpenCircuit
|
||||
func TestIsInfrastructureError(t *testing.T) {
|
||||
t.Run("network timeout is infrastructure error", func(t *testing.T) {
|
||||
opErr := &net.OpError{Op: "dial", Err: &timeoutError{}}
|
||||
result := shouldOpenCircuit(opErr)
|
||||
assert.True(t, result)
|
||||
})
|
||||
|
||||
t.Run("OpError with nil Err is infrastructure error", func(t *testing.T) {
|
||||
opErr := &net.OpError{Op: "dial", Err: nil}
|
||||
result := shouldOpenCircuit(opErr)
|
||||
assert.True(t, result)
|
||||
})
|
||||
|
||||
t.Run("generic error returns false", func(t *testing.T) {
|
||||
err := errors.New("generic error")
|
||||
result := shouldOpenCircuit(err)
|
||||
assert.False(t, result)
|
||||
})
|
||||
|
||||
t.Run("wrapped network timeout is detected", func(t *testing.T) {
|
||||
opErr := &net.OpError{Op: "dial", Err: &timeoutError{}}
|
||||
wrappedErr := fmt.Errorf("connection failed: %w", opErr)
|
||||
result := shouldOpenCircuit(wrappedErr)
|
||||
assert.True(t, result)
|
||||
})
|
||||
}
|
||||
|
||||
// TestIsTLSError tests the isTLSError function
|
||||
func TestIsTLSError(t *testing.T) {
|
||||
t.Run("certificate invalid error is TLS error", func(t *testing.T) {
|
||||
certErr := &x509.CertificateInvalidError{
|
||||
Reason: x509.Expired,
|
||||
}
|
||||
|
||||
result := isTLSError(certErr)
|
||||
|
||||
assert.True(t, result)
|
||||
})
|
||||
|
||||
t.Run("unknown authority error is TLS error", func(t *testing.T) {
|
||||
authErr := &x509.UnknownAuthorityError{}
|
||||
|
||||
result := isTLSError(authErr)
|
||||
|
||||
assert.True(t, result)
|
||||
})
|
||||
|
||||
t.Run("generic error is not TLS error", func(t *testing.T) {
|
||||
err := errors.New("generic error")
|
||||
|
||||
result := isTLSError(err)
|
||||
|
||||
assert.False(t, result)
|
||||
})
|
||||
|
||||
t.Run("wrapped certificate error is detected", func(t *testing.T) {
|
||||
certErr := &x509.CertificateInvalidError{
|
||||
Reason: x509.Expired,
|
||||
}
|
||||
wrappedErr := fmt.Errorf("TLS handshake failed: %w", certErr)
|
||||
|
||||
result := isTLSError(wrappedErr)
|
||||
|
||||
assert.True(t, result)
|
||||
})
|
||||
|
||||
t.Run("wrapped unknown authority error is detected", func(t *testing.T) {
|
||||
authErr := &x509.UnknownAuthorityError{}
|
||||
wrappedErr := fmt.Errorf("certificate verification failed: %w", authErr)
|
||||
|
||||
result := isTLSError(wrappedErr)
|
||||
|
||||
assert.True(t, result)
|
||||
})
|
||||
}
|
||||
|
||||
// TestInfrastructureError tests the InfrastructureError variable
|
||||
func TestInfrastructureError(t *testing.T) {
|
||||
t.Run("returns Some for infrastructure errors", func(t *testing.T) {
|
||||
testURL, _ := url.Parse("http://example.com")
|
||||
resp := &http.Response{
|
||||
StatusCode: 503,
|
||||
Request: &http.Request{URL: testURL},
|
||||
Body: http.NoBody,
|
||||
}
|
||||
httpErr := FH.StatusCodeError(resp)
|
||||
|
||||
result := InfrastructureError(httpErr)
|
||||
|
||||
assert.True(t, option.IsSome(result))
|
||||
})
|
||||
|
||||
t.Run("returns None for non-infrastructure errors", func(t *testing.T) {
|
||||
testURL, _ := url.Parse("http://example.com")
|
||||
resp := &http.Response{
|
||||
StatusCode: 404,
|
||||
Request: &http.Request{URL: testURL},
|
||||
Body: http.NoBody,
|
||||
}
|
||||
httpErr := FH.StatusCodeError(resp)
|
||||
|
||||
result := InfrastructureError(httpErr)
|
||||
|
||||
assert.True(t, option.IsNone(result))
|
||||
})
|
||||
|
||||
t.Run("returns None for nil error", func(t *testing.T) {
|
||||
result := InfrastructureError(nil)
|
||||
|
||||
assert.True(t, option.IsNone(result))
|
||||
})
|
||||
|
||||
t.Run("returns Some for network timeout", func(t *testing.T) {
|
||||
opErr := &net.OpError{
|
||||
Op: "dial",
|
||||
Err: &timeoutError{},
|
||||
}
|
||||
|
||||
result := InfrastructureError(opErr)
|
||||
|
||||
assert.True(t, option.IsSome(result))
|
||||
})
|
||||
}
|
||||
|
||||
// TestComplexErrorScenarios tests complex real-world error scenarios
|
||||
func TestComplexErrorScenarios(t *testing.T) {
|
||||
t.Run("deeply nested URL error with HTTP 5xx", func(t *testing.T) {
|
||||
testURL, _ := url.Parse("http://api.example.com")
|
||||
resp := &http.Response{
|
||||
StatusCode: 502,
|
||||
Request: &http.Request{URL: testURL},
|
||||
Body: http.NoBody,
|
||||
}
|
||||
httpErr := FH.StatusCodeError(resp)
|
||||
urlErr := &url.Error{
|
||||
Op: "Get",
|
||||
URL: "http://api.example.com",
|
||||
Err: httpErr,
|
||||
}
|
||||
wrappedErr := fmt.Errorf("API call failed: %w", urlErr)
|
||||
|
||||
result := shouldOpenCircuit(wrappedErr)
|
||||
|
||||
assert.True(t, result, "should detect HTTP 5xx through multiple layers")
|
||||
})
|
||||
|
||||
t.Run("URL error with timeout nested in OpError", func(t *testing.T) {
|
||||
opErr := &net.OpError{
|
||||
Op: "dial",
|
||||
Err: &timeoutError{},
|
||||
}
|
||||
urlErr := &url.Error{
|
||||
Op: "Post",
|
||||
URL: "http://api.example.com",
|
||||
Err: opErr,
|
||||
}
|
||||
|
||||
result := shouldOpenCircuit(urlErr)
|
||||
|
||||
assert.True(t, result, "should detect timeout through URL error")
|
||||
})
|
||||
|
||||
t.Run("multiple wrapped errors with infrastructure error at core", func(t *testing.T) {
|
||||
coreErr := &net.OpError{Op: "dial", Err: &timeoutError{}}
|
||||
layer1 := fmt.Errorf("connection attempt failed: %w", coreErr)
|
||||
layer2 := fmt.Errorf("retry exhausted: %w", layer1)
|
||||
layer3 := fmt.Errorf("service unavailable: %w", layer2)
|
||||
|
||||
result := shouldOpenCircuit(layer3)
|
||||
|
||||
assert.True(t, result, "should unwrap to find infrastructure error")
|
||||
})
|
||||
|
||||
t.Run("OpError with nil Err should open circuit", func(t *testing.T) {
|
||||
opErr := &net.OpError{
|
||||
Op: "dial",
|
||||
Err: nil,
|
||||
}
|
||||
|
||||
result := shouldOpenCircuit(opErr)
|
||||
|
||||
assert.True(t, result, "OpError with nil Err should be treated as infrastructure error")
|
||||
})
|
||||
|
||||
t.Run("mixed error types - HTTP 4xx with network error", func(t *testing.T) {
|
||||
// This tests that we correctly identify the error type
|
||||
testURL, _ := url.Parse("http://example.com")
|
||||
resp := &http.Response{
|
||||
StatusCode: 400,
|
||||
Request: &http.Request{URL: testURL},
|
||||
Body: http.NoBody,
|
||||
}
|
||||
httpErr := FH.StatusCodeError(resp)
|
||||
|
||||
result := shouldOpenCircuit(httpErr)
|
||||
|
||||
assert.False(t, result, "HTTP 4xx should not open circuit even if wrapped")
|
||||
})
|
||||
}
|
||||
|
||||
// Helper type for testing timeout errors
|
||||
type timeoutError struct{}
|
||||
|
||||
func (e *timeoutError) Error() string { return "timeout" }
|
||||
func (e *timeoutError) Timeout() bool { return true }
|
||||
func (e *timeoutError) Temporary() bool { return true }
|
||||
208
v2/circuitbreaker/metrics.go
Normal file
208
v2/circuitbreaker/metrics.go
Normal file
@@ -0,0 +1,208 @@
|
||||
// Package circuitbreaker provides metrics collection for circuit breaker state transitions and events.
|
||||
package circuitbreaker
|
||||
|
||||
import (
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/IBM/fp-go/v2/function"
|
||||
)
|
||||
|
||||
type (
|
||||
// Metrics defines the interface for collecting circuit breaker metrics and events.
|
||||
// Implementations can use this interface to track circuit breaker behavior for
|
||||
// monitoring, alerting, and debugging purposes.
|
||||
//
|
||||
// All methods accept a time.Time parameter representing when the event occurred,
|
||||
// and return an IO[Void] operation that performs the metric recording when executed.
|
||||
//
|
||||
// Thread Safety: Implementations must be thread-safe as circuit breakers may be
|
||||
// accessed concurrently from multiple goroutines.
|
||||
//
|
||||
// Example Usage:
|
||||
//
|
||||
// logger := log.New(os.Stdout, "[CircuitBreaker] ", log.LstdFlags)
|
||||
// metrics := MakeMetricsFromLogger("API-Service", logger)
|
||||
//
|
||||
// // In circuit breaker implementation
|
||||
// io.Run(metrics.Accept(time.Now())) // Record accepted request
|
||||
// io.Run(metrics.Reject(time.Now())) // Record rejected request
|
||||
// io.Run(metrics.Open(time.Now())) // Record circuit opening
|
||||
// io.Run(metrics.Close(time.Now())) // Record circuit closing
|
||||
// io.Run(metrics.Canary(time.Now())) // Record canary request
|
||||
Metrics interface {
|
||||
// Accept records that a request was accepted and allowed through the circuit breaker.
|
||||
// This is called when the circuit is closed or in half-open state (canary request).
|
||||
//
|
||||
// Parameters:
|
||||
// - time.Time: The timestamp when the request was accepted
|
||||
//
|
||||
// Returns:
|
||||
// - IO[Void]: An IO operation that records the acceptance when executed
|
||||
//
|
||||
// Thread Safety: Must be safe to call concurrently.
|
||||
Accept(time.Time) IO[Void]
|
||||
|
||||
// Reject records that a request was rejected because the circuit breaker is open.
|
||||
// This is called when a request is blocked due to the circuit being in open state
|
||||
// and the reset time has not been reached.
|
||||
//
|
||||
// Parameters:
|
||||
// - time.Time: The timestamp when the request was rejected
|
||||
//
|
||||
// Returns:
|
||||
// - IO[Void]: An IO operation that records the rejection when executed
|
||||
//
|
||||
// Thread Safety: Must be safe to call concurrently.
|
||||
Reject(time.Time) IO[Void]
|
||||
|
||||
// Open records that the circuit breaker transitioned to the open state.
|
||||
// This is called when the failure threshold is exceeded and the circuit opens
|
||||
// to prevent further requests from reaching the failing service.
|
||||
//
|
||||
// Parameters:
|
||||
// - time.Time: The timestamp when the circuit opened
|
||||
//
|
||||
// Returns:
|
||||
// - IO[Void]: An IO operation that records the state transition when executed
|
||||
//
|
||||
// Thread Safety: Must be safe to call concurrently.
|
||||
Open(time.Time) IO[Void]
|
||||
|
||||
// Close records that the circuit breaker transitioned to the closed state.
|
||||
// This is called when:
|
||||
// - A canary request succeeds in half-open state
|
||||
// - The circuit is manually reset
|
||||
// - The circuit breaker is initialized
|
||||
//
|
||||
// Parameters:
|
||||
// - time.Time: The timestamp when the circuit closed
|
||||
//
|
||||
// Returns:
|
||||
// - IO[Void]: An IO operation that records the state transition when executed
|
||||
//
|
||||
// Thread Safety: Must be safe to call concurrently.
|
||||
Close(time.Time) IO[Void]
|
||||
|
||||
// Canary records that a canary (test) request is being attempted.
|
||||
// This is called when the circuit is in half-open state and a single test request
|
||||
// is allowed through to check if the service has recovered.
|
||||
//
|
||||
// Parameters:
|
||||
// - time.Time: The timestamp when the canary request was initiated
|
||||
//
|
||||
// Returns:
|
||||
// - IO[Void]: An IO operation that records the canary attempt when executed
|
||||
//
|
||||
// Thread Safety: Must be safe to call concurrently.
|
||||
Canary(time.Time) IO[Void]
|
||||
}
|
||||
|
||||
// loggingMetrics is a simple implementation of the Metrics interface that logs
|
||||
// circuit breaker events using Go's standard log.Logger.
|
||||
//
|
||||
// This implementation is thread-safe as log.Logger is safe for concurrent use.
|
||||
//
|
||||
// Fields:
|
||||
// - name: A human-readable name identifying the circuit breaker instance
|
||||
// - logger: The log.Logger instance used for writing log messages
|
||||
loggingMetrics struct {
|
||||
name string
|
||||
logger *log.Logger
|
||||
}
|
||||
)
|
||||
|
||||
// doLog is a helper method that creates an IO operation for logging a circuit breaker event.
|
||||
// It formats the log message with the event prefix, circuit breaker name, and timestamp.
|
||||
//
|
||||
// Parameters:
|
||||
// - prefix: The event type (e.g., "Accept", "Reject", "Open", "Close", "Canary")
|
||||
// - ct: The timestamp when the event occurred
|
||||
//
|
||||
// Returns:
|
||||
// - IO[Void]: An IO operation that logs the event when executed
|
||||
//
|
||||
// Thread Safety: Safe for concurrent use as log.Logger is thread-safe.
|
||||
//
|
||||
// Log Format: "<prefix>: <name>, <timestamp>"
|
||||
// Example: "Open: API-Service, 2026-01-09 15:30:45.123 +0100 CET"
|
||||
func (m *loggingMetrics) doLog(prefix string, ct time.Time) IO[Void] {
|
||||
return func() Void {
|
||||
m.logger.Printf("%s: %s, %s\n", prefix, m.name, ct)
|
||||
return function.VOID
|
||||
}
|
||||
}
|
||||
|
||||
// Accept implements the Metrics interface for loggingMetrics.
|
||||
// Logs when a request is accepted through the circuit breaker.
|
||||
//
|
||||
// Thread Safety: Safe for concurrent use.
|
||||
func (m *loggingMetrics) Accept(ct time.Time) IO[Void] {
|
||||
return m.doLog("Accept", ct)
|
||||
}
|
||||
|
||||
// Open implements the Metrics interface for loggingMetrics.
|
||||
// Logs when the circuit breaker transitions to open state.
|
||||
//
|
||||
// Thread Safety: Safe for concurrent use.
|
||||
func (m *loggingMetrics) Open(ct time.Time) IO[Void] {
|
||||
return m.doLog("Open", ct)
|
||||
}
|
||||
|
||||
// Close implements the Metrics interface for loggingMetrics.
|
||||
// Logs when the circuit breaker transitions to closed state.
|
||||
//
|
||||
// Thread Safety: Safe for concurrent use.
|
||||
func (m *loggingMetrics) Close(ct time.Time) IO[Void] {
|
||||
return m.doLog("Close", ct)
|
||||
}
|
||||
|
||||
// Reject implements the Metrics interface for loggingMetrics.
|
||||
// Logs when a request is rejected because the circuit breaker is open.
|
||||
//
|
||||
// Thread Safety: Safe for concurrent use.
|
||||
func (m *loggingMetrics) Reject(ct time.Time) IO[Void] {
|
||||
return m.doLog("Reject", ct)
|
||||
}
|
||||
|
||||
// Canary implements the Metrics interface for loggingMetrics.
|
||||
// Logs when a canary (test) request is attempted in half-open state.
|
||||
//
|
||||
// Thread Safety: Safe for concurrent use.
|
||||
func (m *loggingMetrics) Canary(ct time.Time) IO[Void] {
|
||||
return m.doLog("Canary", ct)
|
||||
}
|
||||
|
||||
// MakeMetricsFromLogger creates a Metrics implementation that logs circuit breaker events
|
||||
// using the provided log.Logger.
|
||||
//
|
||||
// This is a simple metrics implementation suitable for development, debugging, and
|
||||
// basic production monitoring. For more sophisticated metrics collection (e.g., Prometheus,
|
||||
// StatsD), implement the Metrics interface with a custom type.
|
||||
//
|
||||
// Parameters:
|
||||
// - name: A human-readable name identifying the circuit breaker instance.
|
||||
// This name appears in all log messages to distinguish between multiple circuit breakers.
|
||||
// - logger: The log.Logger instance to use for writing log messages.
|
||||
// If nil, this will panic when metrics are recorded.
|
||||
//
|
||||
// Returns:
|
||||
// - Metrics: A thread-safe Metrics implementation that logs events
|
||||
//
|
||||
// Thread Safety: The returned Metrics implementation is safe for concurrent use
|
||||
// as log.Logger is thread-safe.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// logger := log.New(os.Stdout, "[CB] ", log.LstdFlags)
|
||||
// metrics := MakeMetricsFromLogger("UserService", logger)
|
||||
//
|
||||
// // Use with circuit breaker
|
||||
// io.Run(metrics.Open(time.Now()))
|
||||
// // Output: [CB] 2026/01/09 15:30:45 Open: UserService, 2026-01-09 15:30:45.123 +0100 CET
|
||||
//
|
||||
// io.Run(metrics.Reject(time.Now()))
|
||||
// // Output: [CB] 2026/01/09 15:30:46 Reject: UserService, 2026-01-09 15:30:46.456 +0100 CET
|
||||
func MakeMetricsFromLogger(name string, logger *log.Logger) Metrics {
|
||||
return &loggingMetrics{name: name, logger: logger}
|
||||
}
|
||||
506
v2/circuitbreaker/metrics_test.go
Normal file
506
v2/circuitbreaker/metrics_test.go
Normal file
@@ -0,0 +1,506 @@
|
||||
package circuitbreaker
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"log"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/IBM/fp-go/v2/io"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestMakeMetricsFromLogger tests the MakeMetricsFromLogger constructor
|
||||
func TestMakeMetricsFromLogger(t *testing.T) {
|
||||
t.Run("creates valid Metrics implementation", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := log.New(&buf, "", 0)
|
||||
|
||||
metrics := MakeMetricsFromLogger("TestCircuit", logger)
|
||||
|
||||
assert.NotNil(t, metrics, "MakeMetricsFromLogger should return non-nil Metrics")
|
||||
})
|
||||
|
||||
t.Run("returns loggingMetrics type", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := log.New(&buf, "", 0)
|
||||
|
||||
metrics := MakeMetricsFromLogger("TestCircuit", logger)
|
||||
|
||||
_, ok := metrics.(*loggingMetrics)
|
||||
assert.True(t, ok, "should return *loggingMetrics type")
|
||||
})
|
||||
|
||||
t.Run("stores name correctly", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := log.New(&buf, "", 0)
|
||||
name := "MyCircuitBreaker"
|
||||
|
||||
metrics := MakeMetricsFromLogger(name, logger).(*loggingMetrics)
|
||||
|
||||
assert.Equal(t, name, metrics.name, "name should be stored correctly")
|
||||
})
|
||||
|
||||
t.Run("stores logger correctly", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := log.New(&buf, "", 0)
|
||||
|
||||
metrics := MakeMetricsFromLogger("TestCircuit", logger).(*loggingMetrics)
|
||||
|
||||
assert.Equal(t, logger, metrics.logger, "logger should be stored correctly")
|
||||
})
|
||||
}
|
||||
|
||||
// TestLoggingMetricsAccept tests the Accept method
|
||||
func TestLoggingMetricsAccept(t *testing.T) {
|
||||
t.Run("logs accept event with correct format", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := log.New(&buf, "", 0)
|
||||
metrics := MakeMetricsFromLogger("TestCircuit", logger)
|
||||
timestamp := time.Date(2026, 1, 9, 15, 30, 45, 0, time.UTC)
|
||||
|
||||
io.Run(metrics.Accept(timestamp))
|
||||
|
||||
output := buf.String()
|
||||
assert.Contains(t, output, "Accept:", "should contain Accept prefix")
|
||||
assert.Contains(t, output, "TestCircuit", "should contain circuit name")
|
||||
assert.Contains(t, output, timestamp.String(), "should contain timestamp")
|
||||
})
|
||||
|
||||
t.Run("returns IO[Void] that can be executed", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := log.New(&buf, "", 0)
|
||||
metrics := MakeMetricsFromLogger("TestCircuit", logger)
|
||||
timestamp := time.Now()
|
||||
|
||||
ioOp := metrics.Accept(timestamp)
|
||||
|
||||
assert.NotNil(t, ioOp, "should return non-nil IO operation")
|
||||
result := io.Run(ioOp)
|
||||
assert.NotNil(t, result, "IO operation should execute successfully")
|
||||
})
|
||||
|
||||
t.Run("logs multiple accept events", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := log.New(&buf, "", 0)
|
||||
metrics := MakeMetricsFromLogger("TestCircuit", logger)
|
||||
time1 := time.Date(2026, 1, 9, 15, 30, 0, 0, time.UTC)
|
||||
time2 := time.Date(2026, 1, 9, 15, 31, 0, 0, time.UTC)
|
||||
|
||||
io.Run(metrics.Accept(time1))
|
||||
io.Run(metrics.Accept(time2))
|
||||
|
||||
output := buf.String()
|
||||
lines := strings.Split(strings.TrimSpace(output), "\n")
|
||||
assert.Len(t, lines, 2, "should have 2 log lines")
|
||||
assert.Contains(t, lines[0], time1.String())
|
||||
assert.Contains(t, lines[1], time2.String())
|
||||
})
|
||||
}
|
||||
|
||||
// TestLoggingMetricsReject tests the Reject method
|
||||
func TestLoggingMetricsReject(t *testing.T) {
|
||||
t.Run("logs reject event with correct format", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := log.New(&buf, "", 0)
|
||||
metrics := MakeMetricsFromLogger("TestCircuit", logger)
|
||||
timestamp := time.Date(2026, 1, 9, 15, 30, 45, 0, time.UTC)
|
||||
|
||||
io.Run(metrics.Reject(timestamp))
|
||||
|
||||
output := buf.String()
|
||||
assert.Contains(t, output, "Reject:", "should contain Reject prefix")
|
||||
assert.Contains(t, output, "TestCircuit", "should contain circuit name")
|
||||
assert.Contains(t, output, timestamp.String(), "should contain timestamp")
|
||||
})
|
||||
|
||||
t.Run("returns IO[Void] that can be executed", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := log.New(&buf, "", 0)
|
||||
metrics := MakeMetricsFromLogger("TestCircuit", logger)
|
||||
timestamp := time.Now()
|
||||
|
||||
ioOp := metrics.Reject(timestamp)
|
||||
|
||||
assert.NotNil(t, ioOp, "should return non-nil IO operation")
|
||||
result := io.Run(ioOp)
|
||||
assert.NotNil(t, result, "IO operation should execute successfully")
|
||||
})
|
||||
}
|
||||
|
||||
// TestLoggingMetricsOpen tests the Open method
|
||||
func TestLoggingMetricsOpen(t *testing.T) {
|
||||
t.Run("logs open event with correct format", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := log.New(&buf, "", 0)
|
||||
metrics := MakeMetricsFromLogger("TestCircuit", logger)
|
||||
timestamp := time.Date(2026, 1, 9, 15, 30, 45, 0, time.UTC)
|
||||
|
||||
io.Run(metrics.Open(timestamp))
|
||||
|
||||
output := buf.String()
|
||||
assert.Contains(t, output, "Open:", "should contain Open prefix")
|
||||
assert.Contains(t, output, "TestCircuit", "should contain circuit name")
|
||||
assert.Contains(t, output, timestamp.String(), "should contain timestamp")
|
||||
})
|
||||
|
||||
t.Run("returns IO[Void] that can be executed", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := log.New(&buf, "", 0)
|
||||
metrics := MakeMetricsFromLogger("TestCircuit", logger)
|
||||
timestamp := time.Now()
|
||||
|
||||
ioOp := metrics.Open(timestamp)
|
||||
|
||||
assert.NotNil(t, ioOp, "should return non-nil IO operation")
|
||||
result := io.Run(ioOp)
|
||||
assert.NotNil(t, result, "IO operation should execute successfully")
|
||||
})
|
||||
}
|
||||
|
||||
// TestLoggingMetricsClose tests the Close method
|
||||
func TestLoggingMetricsClose(t *testing.T) {
|
||||
t.Run("logs close event with correct format", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := log.New(&buf, "", 0)
|
||||
metrics := MakeMetricsFromLogger("TestCircuit", logger)
|
||||
timestamp := time.Date(2026, 1, 9, 15, 30, 45, 0, time.UTC)
|
||||
|
||||
io.Run(metrics.Close(timestamp))
|
||||
|
||||
output := buf.String()
|
||||
assert.Contains(t, output, "Close:", "should contain Close prefix")
|
||||
assert.Contains(t, output, "TestCircuit", "should contain circuit name")
|
||||
assert.Contains(t, output, timestamp.String(), "should contain timestamp")
|
||||
})
|
||||
|
||||
t.Run("returns IO[Void] that can be executed", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := log.New(&buf, "", 0)
|
||||
metrics := MakeMetricsFromLogger("TestCircuit", logger)
|
||||
timestamp := time.Now()
|
||||
|
||||
ioOp := metrics.Close(timestamp)
|
||||
|
||||
assert.NotNil(t, ioOp, "should return non-nil IO operation")
|
||||
result := io.Run(ioOp)
|
||||
assert.NotNil(t, result, "IO operation should execute successfully")
|
||||
})
|
||||
}
|
||||
|
||||
// TestLoggingMetricsCanary tests the Canary method
|
||||
func TestLoggingMetricsCanary(t *testing.T) {
|
||||
t.Run("logs canary event with correct format", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := log.New(&buf, "", 0)
|
||||
metrics := MakeMetricsFromLogger("TestCircuit", logger)
|
||||
timestamp := time.Date(2026, 1, 9, 15, 30, 45, 0, time.UTC)
|
||||
|
||||
io.Run(metrics.Canary(timestamp))
|
||||
|
||||
output := buf.String()
|
||||
assert.Contains(t, output, "Canary:", "should contain Canary prefix")
|
||||
assert.Contains(t, output, "TestCircuit", "should contain circuit name")
|
||||
assert.Contains(t, output, timestamp.String(), "should contain timestamp")
|
||||
})
|
||||
|
||||
t.Run("returns IO[Void] that can be executed", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := log.New(&buf, "", 0)
|
||||
metrics := MakeMetricsFromLogger("TestCircuit", logger)
|
||||
timestamp := time.Now()
|
||||
|
||||
ioOp := metrics.Canary(timestamp)
|
||||
|
||||
assert.NotNil(t, ioOp, "should return non-nil IO operation")
|
||||
result := io.Run(ioOp)
|
||||
assert.NotNil(t, result, "IO operation should execute successfully")
|
||||
})
|
||||
}
|
||||
|
||||
// TestLoggingMetricsDoLog tests the doLog helper method
|
||||
func TestLoggingMetricsDoLog(t *testing.T) {
|
||||
t.Run("formats log message correctly", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := log.New(&buf, "", 0)
|
||||
metrics := &loggingMetrics{name: "TestCircuit", logger: logger}
|
||||
timestamp := time.Date(2026, 1, 9, 15, 30, 45, 0, time.UTC)
|
||||
|
||||
io.Run(metrics.doLog("CustomEvent", timestamp))
|
||||
|
||||
output := buf.String()
|
||||
assert.Contains(t, output, "CustomEvent:", "should contain custom prefix")
|
||||
assert.Contains(t, output, "TestCircuit", "should contain circuit name")
|
||||
assert.Contains(t, output, timestamp.String(), "should contain timestamp")
|
||||
})
|
||||
|
||||
t.Run("handles different prefixes", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := log.New(&buf, "", 0)
|
||||
metrics := &loggingMetrics{name: "TestCircuit", logger: logger}
|
||||
timestamp := time.Now()
|
||||
|
||||
prefixes := []string{"Accept", "Reject", "Open", "Close", "Canary", "Custom"}
|
||||
for _, prefix := range prefixes {
|
||||
buf.Reset()
|
||||
io.Run(metrics.doLog(prefix, timestamp))
|
||||
output := buf.String()
|
||||
assert.Contains(t, output, prefix+":", "should contain prefix: "+prefix)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestMetricsIntegration tests integration scenarios
|
||||
func TestMetricsIntegration(t *testing.T) {
|
||||
t.Run("logs complete circuit breaker lifecycle", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := log.New(&buf, "", 0)
|
||||
metrics := MakeMetricsFromLogger("APICircuit", logger)
|
||||
baseTime := time.Date(2026, 1, 9, 15, 30, 0, 0, time.UTC)
|
||||
|
||||
// Simulate circuit breaker lifecycle
|
||||
io.Run(metrics.Accept(baseTime)) // Request accepted
|
||||
io.Run(metrics.Accept(baseTime.Add(1 * time.Second))) // Another request
|
||||
io.Run(metrics.Open(baseTime.Add(2 * time.Second))) // Circuit opens
|
||||
io.Run(metrics.Reject(baseTime.Add(3 * time.Second))) // Request rejected
|
||||
io.Run(metrics.Canary(baseTime.Add(30 * time.Second))) // Canary attempt
|
||||
io.Run(metrics.Close(baseTime.Add(31 * time.Second))) // Circuit closes
|
||||
|
||||
output := buf.String()
|
||||
lines := strings.Split(strings.TrimSpace(output), "\n")
|
||||
assert.Len(t, lines, 6, "should have 6 log lines")
|
||||
|
||||
assert.Contains(t, lines[0], "Accept:")
|
||||
assert.Contains(t, lines[1], "Accept:")
|
||||
assert.Contains(t, lines[2], "Open:")
|
||||
assert.Contains(t, lines[3], "Reject:")
|
||||
assert.Contains(t, lines[4], "Canary:")
|
||||
assert.Contains(t, lines[5], "Close:")
|
||||
})
|
||||
|
||||
t.Run("distinguishes between multiple circuit breakers", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := log.New(&buf, "", 0)
|
||||
metrics1 := MakeMetricsFromLogger("Circuit1", logger)
|
||||
metrics2 := MakeMetricsFromLogger("Circuit2", logger)
|
||||
timestamp := time.Now()
|
||||
|
||||
io.Run(metrics1.Accept(timestamp))
|
||||
io.Run(metrics2.Accept(timestamp))
|
||||
|
||||
output := buf.String()
|
||||
assert.Contains(t, output, "Circuit1", "should contain first circuit name")
|
||||
assert.Contains(t, output, "Circuit2", "should contain second circuit name")
|
||||
})
|
||||
}
|
||||
|
||||
// TestMetricsThreadSafety tests concurrent access to metrics
|
||||
func TestMetricsThreadSafety(t *testing.T) {
|
||||
t.Run("handles concurrent metric recording", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := log.New(&buf, "", 0)
|
||||
metrics := MakeMetricsFromLogger("ConcurrentCircuit", logger)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
numGoroutines := 100
|
||||
wg.Add(numGoroutines)
|
||||
|
||||
// Launch multiple goroutines recording metrics concurrently
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
timestamp := time.Now()
|
||||
io.Run(metrics.Accept(timestamp))
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
output := buf.String()
|
||||
lines := strings.Split(strings.TrimSpace(output), "\n")
|
||||
assert.Len(t, lines, numGoroutines, "should have logged all events")
|
||||
})
|
||||
|
||||
t.Run("handles concurrent different event types", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := log.New(&buf, "", 0)
|
||||
metrics := MakeMetricsFromLogger("ConcurrentCircuit", logger)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
numIterations := 20
|
||||
wg.Add(numIterations * 5) // 5 event types
|
||||
|
||||
timestamp := time.Now()
|
||||
|
||||
for i := 0; i < numIterations; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
io.Run(metrics.Accept(timestamp))
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
io.Run(metrics.Reject(timestamp))
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
io.Run(metrics.Open(timestamp))
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
io.Run(metrics.Close(timestamp))
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
io.Run(metrics.Canary(timestamp))
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
output := buf.String()
|
||||
lines := strings.Split(strings.TrimSpace(output), "\n")
|
||||
assert.Len(t, lines, numIterations*5, "should have logged all events")
|
||||
})
|
||||
}
|
||||
|
||||
// TestMetricsEdgeCases tests edge cases and special scenarios
|
||||
func TestMetricsEdgeCases(t *testing.T) {
|
||||
t.Run("handles empty circuit breaker name", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := log.New(&buf, "", 0)
|
||||
metrics := MakeMetricsFromLogger("", logger)
|
||||
timestamp := time.Now()
|
||||
|
||||
io.Run(metrics.Accept(timestamp))
|
||||
|
||||
output := buf.String()
|
||||
assert.NotEmpty(t, output, "should still log even with empty name")
|
||||
})
|
||||
|
||||
t.Run("handles very long circuit breaker name", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := log.New(&buf, "", 0)
|
||||
longName := strings.Repeat("VeryLongCircuitBreakerName", 100)
|
||||
metrics := MakeMetricsFromLogger(longName, logger)
|
||||
timestamp := time.Now()
|
||||
|
||||
io.Run(metrics.Accept(timestamp))
|
||||
|
||||
output := buf.String()
|
||||
assert.Contains(t, output, longName, "should handle long names")
|
||||
})
|
||||
|
||||
t.Run("handles special characters in name", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := log.New(&buf, "", 0)
|
||||
specialName := "Circuit-Breaker_123!@#$%^&*()"
|
||||
metrics := MakeMetricsFromLogger(specialName, logger)
|
||||
timestamp := time.Now()
|
||||
|
||||
io.Run(metrics.Accept(timestamp))
|
||||
|
||||
output := buf.String()
|
||||
assert.Contains(t, output, specialName, "should handle special characters")
|
||||
})
|
||||
|
||||
t.Run("handles zero time", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := log.New(&buf, "", 0)
|
||||
metrics := MakeMetricsFromLogger("TestCircuit", logger)
|
||||
zeroTime := time.Time{}
|
||||
|
||||
io.Run(metrics.Accept(zeroTime))
|
||||
|
||||
output := buf.String()
|
||||
assert.NotEmpty(t, output, "should handle zero time")
|
||||
assert.Contains(t, output, "Accept:")
|
||||
})
|
||||
|
||||
t.Run("handles far future time", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := log.New(&buf, "", 0)
|
||||
metrics := MakeMetricsFromLogger("TestCircuit", logger)
|
||||
futureTime := time.Date(9999, 12, 31, 23, 59, 59, 0, time.UTC)
|
||||
|
||||
io.Run(metrics.Accept(futureTime))
|
||||
|
||||
output := buf.String()
|
||||
assert.NotEmpty(t, output, "should handle far future time")
|
||||
assert.Contains(t, output, "9999")
|
||||
})
|
||||
}
|
||||
|
||||
// TestMetricsWithCustomLogger tests metrics with different logger configurations
|
||||
func TestMetricsWithCustomLogger(t *testing.T) {
|
||||
t.Run("works with logger with custom prefix", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := log.New(&buf, "[CB] ", 0)
|
||||
metrics := MakeMetricsFromLogger("TestCircuit", logger)
|
||||
timestamp := time.Now()
|
||||
|
||||
io.Run(metrics.Accept(timestamp))
|
||||
|
||||
output := buf.String()
|
||||
assert.Contains(t, output, "[CB]", "should include custom prefix")
|
||||
assert.Contains(t, output, "Accept:")
|
||||
})
|
||||
|
||||
t.Run("works with logger with flags", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := log.New(&buf, "", log.Ldate|log.Ltime)
|
||||
metrics := MakeMetricsFromLogger("TestCircuit", logger)
|
||||
timestamp := time.Now()
|
||||
|
||||
io.Run(metrics.Accept(timestamp))
|
||||
|
||||
output := buf.String()
|
||||
assert.NotEmpty(t, output, "should log with flags")
|
||||
assert.Contains(t, output, "Accept:")
|
||||
})
|
||||
}
|
||||
|
||||
// TestMetricsIOOperations tests IO operation behavior
|
||||
func TestMetricsIOOperations(t *testing.T) {
|
||||
t.Run("IO operations are lazy", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := log.New(&buf, "", 0)
|
||||
metrics := MakeMetricsFromLogger("TestCircuit", logger)
|
||||
timestamp := time.Now()
|
||||
|
||||
// Create IO operation but don't execute it
|
||||
_ = metrics.Accept(timestamp)
|
||||
|
||||
// Buffer should be empty because IO wasn't executed
|
||||
assert.Empty(t, buf.String(), "IO operation should be lazy")
|
||||
})
|
||||
|
||||
t.Run("IO operations execute when run", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := log.New(&buf, "", 0)
|
||||
metrics := MakeMetricsFromLogger("TestCircuit", logger)
|
||||
timestamp := time.Now()
|
||||
|
||||
ioOp := metrics.Accept(timestamp)
|
||||
io.Run(ioOp)
|
||||
|
||||
assert.NotEmpty(t, buf.String(), "IO operation should execute when run")
|
||||
})
|
||||
|
||||
t.Run("same IO operation can be executed multiple times", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := log.New(&buf, "", 0)
|
||||
metrics := MakeMetricsFromLogger("TestCircuit", logger)
|
||||
timestamp := time.Now()
|
||||
|
||||
ioOp := metrics.Accept(timestamp)
|
||||
io.Run(ioOp)
|
||||
io.Run(ioOp)
|
||||
io.Run(ioOp)
|
||||
|
||||
output := buf.String()
|
||||
lines := strings.Split(strings.TrimSpace(output), "\n")
|
||||
assert.Len(t, lines, 3, "should execute multiple times")
|
||||
})
|
||||
}
|
||||
118
v2/circuitbreaker/types.go
Normal file
118
v2/circuitbreaker/types.go
Normal file
@@ -0,0 +1,118 @@
|
||||
// Package circuitbreaker provides a functional implementation of the circuit breaker pattern.
|
||||
// A circuit breaker prevents cascading failures by temporarily blocking requests to a failing service,
|
||||
// allowing it time to recover before retrying.
|
||||
//
|
||||
// # Thread Safety
|
||||
//
|
||||
// All data structures in this package are immutable except for IORef[BreakerState].
|
||||
// The IORef provides thread-safe mutable state through atomic operations.
|
||||
//
|
||||
// Immutable types (safe for concurrent use):
|
||||
// - BreakerState (Either[openState, ClosedState])
|
||||
// - openState
|
||||
// - ClosedState implementations (closedStateWithErrorCount, closedStateWithHistory)
|
||||
// - All function types and readers
|
||||
//
|
||||
// Mutable types (thread-safe through atomic operations):
|
||||
// - IORef[BreakerState] - provides atomic read/write/modify operations
|
||||
//
|
||||
// ClosedState implementations must be thread-safe. The recommended approach is to
|
||||
// return new copies for all operations (Empty, AddError, AddSuccess, Check), which
|
||||
// provides automatic thread safety through immutability.
|
||||
package circuitbreaker
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/IBM/fp-go/v2/either"
|
||||
"github.com/IBM/fp-go/v2/endomorphism"
|
||||
"github.com/IBM/fp-go/v2/function"
|
||||
"github.com/IBM/fp-go/v2/io"
|
||||
"github.com/IBM/fp-go/v2/ioref"
|
||||
"github.com/IBM/fp-go/v2/option"
|
||||
"github.com/IBM/fp-go/v2/ord"
|
||||
"github.com/IBM/fp-go/v2/pair"
|
||||
"github.com/IBM/fp-go/v2/predicate"
|
||||
"github.com/IBM/fp-go/v2/reader"
|
||||
"github.com/IBM/fp-go/v2/retry"
|
||||
"github.com/IBM/fp-go/v2/state"
|
||||
)
|
||||
|
||||
type (
|
||||
// Ord is a type alias for ord.Ord, representing a total ordering on type A.
|
||||
// Used for comparing values in a consistent way.
|
||||
Ord[A any] = ord.Ord[A]
|
||||
|
||||
// Option is a type alias for option.Option, representing an optional value.
|
||||
// It can be either Some(value) or None, used for safe handling of nullable values.
|
||||
Option[A any] = option.Option[A]
|
||||
|
||||
// Endomorphism is a type alias for endomorphism.Endomorphism, representing a function from A to A.
|
||||
// Used for transformations that preserve the type.
|
||||
Endomorphism[A any] = endomorphism.Endomorphism[A]
|
||||
|
||||
// IO is a type alias for io.IO, representing a lazy computation that produces a value of type T.
|
||||
// Used for side-effectful operations that are deferred until execution.
|
||||
IO[T any] = io.IO[T]
|
||||
|
||||
// Pair is a type alias for pair.Pair, representing a tuple of two values.
|
||||
// Used for grouping related values together.
|
||||
Pair[L, R any] = pair.Pair[L, R]
|
||||
|
||||
// IORef is a type alias for ioref.IORef, representing a mutable reference to a value of type T.
|
||||
// Used for managing mutable state in a functional way with IO operations.
|
||||
IORef[T any] = ioref.IORef[T]
|
||||
|
||||
// State is a type alias for state.State, representing a stateful computation.
|
||||
// It transforms a state of type T and produces a result of type R.
|
||||
State[T, R any] = state.State[T, R]
|
||||
|
||||
// Either is a type alias for either.Either, representing a value that can be one of two types.
|
||||
// Left[E] represents an error or alternative path, Right[A] represents the success path.
|
||||
Either[E, A any] = either.Either[E, A]
|
||||
|
||||
// Predicate is a type alias for predicate.Predicate, representing a function that tests a value.
|
||||
// Returns true if the value satisfies the predicate condition, false otherwise.
|
||||
Predicate[A any] = predicate.Predicate[A]
|
||||
|
||||
// Reader is a type alias for reader.Reader, representing a computation that depends on an environment R
|
||||
// and produces a value of type A. Used for dependency injection and configuration.
|
||||
Reader[R, A any] = reader.Reader[R, A]
|
||||
|
||||
// openState represents the internal state when the circuit breaker is open.
|
||||
// In the open state, requests are blocked to give the failing service time to recover.
|
||||
// The circuit breaker will transition to a half-open state (canary request) after resetAt.
|
||||
openState struct {
|
||||
openedAt time.Time
|
||||
|
||||
// resetAt is the time when the circuit breaker should attempt a canary request
|
||||
// to test if the service has recovered. Calculated based on the retry policy.
|
||||
resetAt time.Time
|
||||
|
||||
// retryStatus tracks the current retry attempt information, including the number
|
||||
// of retries and the delay between attempts. Used by the retry policy to calculate
|
||||
// exponential backoff or other retry strategies.
|
||||
retryStatus retry.RetryStatus
|
||||
|
||||
// canaryRequest indicates whether the circuit is in half-open state, allowing
|
||||
// a single test request (canary) to check if the service has recovered.
|
||||
// If true, one request is allowed through to test the service.
|
||||
// If the canary succeeds, the circuit closes; if it fails, the circuit remains open
|
||||
// with an extended reset time.
|
||||
canaryRequest bool
|
||||
}
|
||||
|
||||
// BreakerState represents the current state of the circuit breaker.
|
||||
// It is an Either type where:
|
||||
// - Left[openState] represents an open circuit (requests are blocked)
|
||||
// - Right[ClosedState] represents a closed circuit (requests are allowed through)
|
||||
//
|
||||
// State Transitions:
|
||||
// - Closed -> Open: When failure threshold is exceeded in ClosedState
|
||||
// - Open -> Half-Open: When resetAt is reached (canaryRequest = true)
|
||||
// - Half-Open -> Closed: When canary request succeeds
|
||||
// - Half-Open -> Open: When canary request fails (with extended resetAt)
|
||||
BreakerState = Either[openState, ClosedState]
|
||||
|
||||
Void = function.Void
|
||||
)
|
||||
@@ -29,8 +29,8 @@ import "github.com/IBM/fp-go/v2/io"
|
||||
// )
|
||||
//
|
||||
//go:inline
|
||||
func ChainConsumer[A any](c Consumer[A]) Operator[A, struct{}] {
|
||||
return ChainIOK(io.FromConsumerK(c))
|
||||
func ChainConsumer[A any](c Consumer[A]) Operator[A, Void] {
|
||||
return ChainIOK(io.FromConsumer(c))
|
||||
}
|
||||
|
||||
// ChainFirstConsumer chains a consumer function into a ReaderIO computation, preserving the original value.
|
||||
@@ -61,5 +61,5 @@ func ChainConsumer[A any](c Consumer[A]) Operator[A, struct{}] {
|
||||
//
|
||||
//go:inline
|
||||
func ChainFirstConsumer[A any](c Consumer[A]) Operator[A, A] {
|
||||
return ChainFirstIOK(io.FromConsumerK(c))
|
||||
return ChainFirstIOK(io.FromConsumer(c))
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
|
||||
"github.com/IBM/fp-go/v2/consumer"
|
||||
"github.com/IBM/fp-go/v2/either"
|
||||
"github.com/IBM/fp-go/v2/function"
|
||||
"github.com/IBM/fp-go/v2/io"
|
||||
"github.com/IBM/fp-go/v2/lazy"
|
||||
"github.com/IBM/fp-go/v2/predicate"
|
||||
@@ -78,4 +79,6 @@ type (
|
||||
Trampoline[B, L any] = tailrec.Trampoline[B, L]
|
||||
|
||||
Predicate[A any] = predicate.Predicate[A]
|
||||
|
||||
Void = function.Void
|
||||
)
|
||||
|
||||
@@ -402,7 +402,125 @@ result := pipeline(db)(ctx)()
|
||||
|
||||
## Practical Benefits
|
||||
|
||||
### 1. **Improved Testability**
|
||||
### 1. **Performance: Eager Construction, Lazy Execution**
|
||||
|
||||
One of the most important but often overlooked benefits of point-free style is its performance characteristic: **the program structure is constructed eagerly (at definition time), but execution happens lazily (at runtime)**.
|
||||
|
||||
#### Construction Happens Once
|
||||
|
||||
When you define a pipeline using point-free style with `F.Flow`, `F.Pipe`, or function composition, the composition structure is built immediately at definition time:
|
||||
|
||||
```go
|
||||
// Point-free style - composition built ONCE at definition time
|
||||
var processUser = F.Flow3(
|
||||
getDatabase,
|
||||
SequenceReader[DatabaseConfig, Database],
|
||||
applyConfig(dbConfig),
|
||||
)
|
||||
// The pipeline structure is now fixed in memory
|
||||
```
|
||||
|
||||
#### Execution Happens on Demand
|
||||
|
||||
The actual computation only runs when you provide the final parameters and invoke the result:
|
||||
|
||||
```go
|
||||
// Execute multiple times - only execution cost, no re-composition
|
||||
result1 := processUser(ctx1)() // Fast - reuses pre-built pipeline
|
||||
result2 := processUser(ctx2)() // Fast - reuses pre-built pipeline
|
||||
result3 := processUser(ctx3)() // Fast - reuses pre-built pipeline
|
||||
```
|
||||
|
||||
#### Performance Benefit for Repeated Execution
|
||||
|
||||
If a flow is executed multiple times, the point-free style is significantly more efficient because:
|
||||
|
||||
1. **Composition overhead is paid once** - The function composition happens at definition time
|
||||
2. **No re-interpretation** - Each execution doesn't need to rebuild the pipeline
|
||||
3. **Memory efficiency** - The composed function is created once and reused
|
||||
4. **Better for hot paths** - Ideal for high-frequency operations
|
||||
|
||||
#### Comparison: Point-Free vs. Imperative
|
||||
|
||||
```go
|
||||
// Imperative style - reconstruction on EVERY call
|
||||
func processUserImperative(ctx context.Context) Either[error, Database] {
|
||||
// This function body is re-interpreted/executed every time
|
||||
dbComp := getDatabase()(ctx)()
|
||||
if dbReader, err := either.Unwrap(dbComp); err != nil {
|
||||
return Left[Database](err)
|
||||
}
|
||||
db := dbReader(dbConfig)
|
||||
// ... manual composition happens on every invocation
|
||||
return Right[error](db)
|
||||
}
|
||||
|
||||
// Point-free style - composition built ONCE
|
||||
var processUserPointFree = F.Flow3(
|
||||
getDatabase,
|
||||
SequenceReader[DatabaseConfig, Database],
|
||||
applyConfig(dbConfig),
|
||||
)
|
||||
|
||||
// Benchmark scenario: 1000 executions
|
||||
for i := 0; i < 1000; i++ {
|
||||
// Imperative: pays composition cost 1000 times
|
||||
result := processUserImperative(ctx)()
|
||||
|
||||
// Point-free: pays composition cost once, execution cost 1000 times
|
||||
result := processUserPointFree(ctx)()
|
||||
}
|
||||
```
|
||||
|
||||
#### When This Matters Most
|
||||
|
||||
The performance benefit of eager construction is particularly important for:
|
||||
|
||||
- **High-frequency operations** - APIs, event handlers, request processors
|
||||
- **Batch processing** - Same pipeline processes many items
|
||||
- **Long-running services** - Pipelines defined once at startup, executed millions of times
|
||||
- **Hot code paths** - Performance-critical sections that run repeatedly
|
||||
- **Stream processing** - Processing continuous data streams
|
||||
|
||||
#### Example: API Handler
|
||||
|
||||
```go
|
||||
// Define pipeline once at application startup
|
||||
var handleUserRequest = F.Flow4(
|
||||
parseRequest,
|
||||
SequenceReader[Database, UserRequest],
|
||||
applyDatabase(db),
|
||||
Chain(validateAndProcess),
|
||||
)
|
||||
|
||||
// Execute thousands of times per second
|
||||
func apiHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// No composition overhead - just execution
|
||||
result := handleUserRequest(r.Context())()
|
||||
// ... handle result
|
||||
}
|
||||
```
|
||||
|
||||
#### Memory and CPU Efficiency
|
||||
|
||||
```go
|
||||
// Point-free: O(1) composition overhead
|
||||
var pipeline = F.Flow5(step1, step2, step3, step4, step5)
|
||||
// Composed once, stored in memory
|
||||
|
||||
// Execute N times: O(N) execution cost only
|
||||
for i := 0; i < N; i++ {
|
||||
result := pipeline(input[i])
|
||||
}
|
||||
|
||||
// Imperative: O(N) composition + execution cost
|
||||
for i := 0; i < N; i++ {
|
||||
// Composition logic runs every iteration
|
||||
result := step5(step4(step3(step2(step1(input[i])))))
|
||||
}
|
||||
```
|
||||
|
||||
### 2. **Improved Testability**
|
||||
|
||||
Inject test dependencies easily:
|
||||
|
||||
@@ -418,7 +536,7 @@ testQuery := queryWithDB(testDB)
|
||||
// Same computation, different dependencies
|
||||
```
|
||||
|
||||
### 2. **Better Separation of Concerns**
|
||||
### 3. **Better Separation of Concerns**
|
||||
|
||||
Separate configuration from execution:
|
||||
|
||||
@@ -431,7 +549,7 @@ computation := sequenced(cfg)
|
||||
result := computation(ctx)()
|
||||
```
|
||||
|
||||
### 3. **Enhanced Composability**
|
||||
### 4. **Enhanced Composability**
|
||||
|
||||
Build complex pipelines from simple pieces:
|
||||
|
||||
@@ -444,7 +562,7 @@ var processUser = F.Flow4(
|
||||
)
|
||||
```
|
||||
|
||||
### 4. **Reduced Boilerplate**
|
||||
### 5. **Reduced Boilerplate**
|
||||
|
||||
No need to manually thread parameters:
|
||||
|
||||
@@ -651,6 +769,7 @@ var processUser = func(userID string) ReaderIOResult[ProcessedUser] {
|
||||
5. **Reusability** increases as computations can be specialized early
|
||||
6. **Testability** improves through easy dependency injection
|
||||
7. **Separation of concerns** is clearer (configuration vs. execution)
|
||||
8. **Performance benefit**: Eager construction (once) + lazy execution (many times) = efficiency for repeated operations
|
||||
|
||||
## When to Use Sequence
|
||||
|
||||
|
||||
60
v2/context/readerioresult/circuitbreaker.go
Normal file
60
v2/context/readerioresult/circuitbreaker.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package readerioresult
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/IBM/fp-go/v2/circuitbreaker"
|
||||
"github.com/IBM/fp-go/v2/option"
|
||||
"github.com/IBM/fp-go/v2/retry"
|
||||
)
|
||||
|
||||
type (
|
||||
ClosedState = circuitbreaker.ClosedState
|
||||
|
||||
Env[T any] = Pair[IORef[circuitbreaker.BreakerState], ReaderIOResult[T]]
|
||||
|
||||
CircuitBreaker[T any] = State[Env[T], ReaderIOResult[T]]
|
||||
)
|
||||
|
||||
func MakeCircuitBreaker[T any](
|
||||
currentTime IO[time.Time],
|
||||
closedState ClosedState,
|
||||
checkError option.Kleisli[error, error],
|
||||
policy retry.RetryPolicy,
|
||||
metrics circuitbreaker.Metrics,
|
||||
) CircuitBreaker[T] {
|
||||
return circuitbreaker.MakeCircuitBreaker[error, T](
|
||||
Left,
|
||||
ChainFirstIOK,
|
||||
ChainFirstLeftIOK,
|
||||
FromIO,
|
||||
Flap,
|
||||
Flatten,
|
||||
|
||||
currentTime,
|
||||
closedState,
|
||||
circuitbreaker.MakeCircuitBreakerError,
|
||||
checkError,
|
||||
policy,
|
||||
metrics,
|
||||
)
|
||||
}
|
||||
|
||||
func MakeSingletonBreaker[T any](
|
||||
currentTime IO[time.Time],
|
||||
closedState ClosedState,
|
||||
checkError option.Kleisli[error, error],
|
||||
policy retry.RetryPolicy,
|
||||
metrics circuitbreaker.Metrics,
|
||||
) Operator[T, T] {
|
||||
return circuitbreaker.MakeSingletonBreaker(
|
||||
MakeCircuitBreaker[T](
|
||||
currentTime,
|
||||
closedState,
|
||||
checkError,
|
||||
policy,
|
||||
metrics,
|
||||
),
|
||||
closedState,
|
||||
)
|
||||
}
|
||||
246
v2/context/readerioresult/circuitbreaker_doc.md
Normal file
246
v2/context/readerioresult/circuitbreaker_doc.md
Normal file
@@ -0,0 +1,246 @@
|
||||
# Circuit Breaker Documentation
|
||||
|
||||
## Overview
|
||||
|
||||
The `circuitbreaker.go` file provides a circuit breaker implementation for the `readerioresult` package. A circuit breaker is a design pattern used to detect failures and prevent cascading failures in distributed systems by temporarily blocking operations that are likely to fail.
|
||||
|
||||
## Package
|
||||
|
||||
```go
|
||||
package readerioresult
|
||||
```
|
||||
|
||||
This is part of the `context/readerioresult` package, which provides functional programming abstractions for operations that:
|
||||
- Depend on a `context.Context` (Reader aspect)
|
||||
- Perform side effects (IO aspect)
|
||||
- Can fail with an `error` (Result/Either aspect)
|
||||
|
||||
## Type Definitions
|
||||
|
||||
### ClosedState
|
||||
|
||||
```go
|
||||
type ClosedState = circuitbreaker.ClosedState
|
||||
```
|
||||
|
||||
A type alias for the circuit breaker's closed state. When the circuit is closed, requests are allowed to pass through normally. The closed state tracks success and failure counts to determine when to open the circuit.
|
||||
|
||||
### Env[T any]
|
||||
|
||||
```go
|
||||
type Env[T any] = Pair[IORef[circuitbreaker.BreakerState], ReaderIOResult[T]]
|
||||
```
|
||||
|
||||
The environment type for the circuit breaker state machine. It contains:
|
||||
- `IORef[circuitbreaker.BreakerState]`: A mutable reference to the current breaker state
|
||||
- `ReaderIOResult[T]`: The computation to be protected by the circuit breaker
|
||||
|
||||
### CircuitBreaker[T any]
|
||||
|
||||
```go
|
||||
type CircuitBreaker[T any] = State[Env[T], ReaderIOResult[T]]
|
||||
```
|
||||
|
||||
The main circuit breaker type. It's a state monad that:
|
||||
- Takes an environment containing the breaker state and the protected computation
|
||||
- Returns a new environment and a wrapped computation that respects the circuit breaker logic
|
||||
|
||||
## Functions
|
||||
|
||||
### MakeCircuitBreaker
|
||||
|
||||
```go
|
||||
func MakeCircuitBreaker[T any](
|
||||
currentTime IO[time.Time],
|
||||
closedState ClosedState,
|
||||
checkError option.Kleisli[error, error],
|
||||
policy retry.RetryPolicy,
|
||||
logger io.Kleisli[string, string],
|
||||
) CircuitBreaker[T]
|
||||
```
|
||||
|
||||
Creates a new circuit breaker with the specified configuration.
|
||||
|
||||
#### Parameters
|
||||
|
||||
- **currentTime** `IO[time.Time]`: A function that returns the current time. This can be a virtual timer for testing purposes, allowing you to control time progression in tests.
|
||||
|
||||
- **closedState** `ClosedState`: The initial closed state configuration. This defines:
|
||||
- Maximum number of failures before opening the circuit
|
||||
- Time window for counting failures
|
||||
- Other closed state parameters
|
||||
|
||||
- **checkError** `option.Kleisli[error, error]`: A function that determines whether an error should be counted as a failure. Returns:
|
||||
- `Some(error)`: The error should be counted as a failure
|
||||
- `None`: The error should be ignored (not counted as a failure)
|
||||
|
||||
This allows you to distinguish between transient errors (that should trigger circuit breaking) and permanent errors (that shouldn't).
|
||||
|
||||
- **policy** `retry.RetryPolicy`: The retry policy that determines:
|
||||
- How long to wait before attempting to close the circuit (reset time)
|
||||
- Exponential backoff or other delay strategies
|
||||
- Maximum number of retry attempts
|
||||
|
||||
- **logger** `io.Kleisli[string, string]`: A logging function for circuit breaker events. Receives log messages and performs side effects (like writing to a log file or console).
|
||||
|
||||
#### Returns
|
||||
|
||||
A `CircuitBreaker[T]` that wraps computations with circuit breaker logic.
|
||||
|
||||
#### Circuit Breaker States
|
||||
|
||||
The circuit breaker operates in three states:
|
||||
|
||||
1. **Closed**: Normal operation. Requests pass through. Failures are counted.
|
||||
- If failure threshold is exceeded, transitions to Open state
|
||||
|
||||
2. **Open**: Circuit is broken. Requests fail immediately without executing.
|
||||
- After reset time expires, transitions to Half-Open state
|
||||
|
||||
3. **Half-Open** (Canary): Testing if the service has recovered.
|
||||
- Allows a single test request (canary request)
|
||||
- If canary succeeds, transitions to Closed state
|
||||
- If canary fails, transitions back to Open state with extended reset time
|
||||
|
||||
#### Implementation Details
|
||||
|
||||
The function delegates to the generic `circuitbreaker.MakeCircuitBreaker` function, providing the necessary type-specific operations:
|
||||
|
||||
- **Left**: Creates a failed computation from an error
|
||||
- **ChainFirstIOK**: Chains an IO operation that runs for side effects on success
|
||||
- **ChainFirstLeftIOK**: Chains an IO operation that runs for side effects on failure
|
||||
- **FromIO**: Lifts an IO computation into ReaderIOResult
|
||||
- **Flap**: Applies a computation to a function
|
||||
- **Flatten**: Flattens nested ReaderIOResult structures
|
||||
|
||||
These operations allow the generic circuit breaker to work with the `ReaderIOResult` monad.
|
||||
|
||||
## Usage Example
|
||||
|
||||
```go
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/IBM/fp-go/v2/circuitbreaker"
|
||||
"github.com/IBM/fp-go/v2/context/readerioresult"
|
||||
"github.com/IBM/fp-go/v2/io"
|
||||
"github.com/IBM/fp-go/v2/ioref"
|
||||
"github.com/IBM/fp-go/v2/option"
|
||||
"github.com/IBM/fp-go/v2/retry"
|
||||
)
|
||||
|
||||
// Create a circuit breaker configuration
|
||||
func createCircuitBreaker() readerioresult.CircuitBreaker[string] {
|
||||
// Use real time
|
||||
currentTime := func() time.Time { return time.Now() }
|
||||
|
||||
// Configure closed state: open after 5 failures in 10 seconds
|
||||
closedState := circuitbreaker.MakeClosedState(5, 10*time.Second)
|
||||
|
||||
// Check all errors (count all as failures)
|
||||
checkError := func(err error) option.Option[error] {
|
||||
return option.Some(err)
|
||||
}
|
||||
|
||||
// Retry policy: exponential backoff with max 5 retries
|
||||
policy := retry.Monoid.Concat(
|
||||
retry.LimitRetries(5),
|
||||
retry.ExponentialBackoff(100*time.Millisecond),
|
||||
)
|
||||
|
||||
// Simple logger
|
||||
logger := func(msg string) io.IO[string] {
|
||||
return func() string {
|
||||
fmt.Println("Circuit Breaker:", msg)
|
||||
return msg
|
||||
}
|
||||
}
|
||||
|
||||
return readerioresult.MakeCircuitBreaker[string](
|
||||
currentTime,
|
||||
closedState,
|
||||
checkError,
|
||||
policy,
|
||||
logger,
|
||||
)
|
||||
}
|
||||
|
||||
// Use the circuit breaker
|
||||
func main() {
|
||||
cb := createCircuitBreaker()
|
||||
|
||||
// Create initial state
|
||||
stateRef := ioref.NewIORef(circuitbreaker.InitialState())
|
||||
|
||||
// Your protected operation
|
||||
operation := func(ctx context.Context) readerioresult.IOResult[string] {
|
||||
return func() readerioresult.Result[string] {
|
||||
// Your actual operation here
|
||||
return result.Of("success")
|
||||
}
|
||||
}
|
||||
|
||||
// Apply circuit breaker
|
||||
env := pair.MakePair(stateRef, operation)
|
||||
result := cb(env)
|
||||
|
||||
// Execute the protected operation
|
||||
ctx := context.Background()
|
||||
protectedOp := pair.Tail(result)
|
||||
outcome := protectedOp(ctx)()
|
||||
}
|
||||
```
|
||||
|
||||
## Testing with Virtual Timer
|
||||
|
||||
For testing, you can provide a virtual timer instead of `time.Now()`:
|
||||
|
||||
```go
|
||||
// Virtual timer for testing
|
||||
type VirtualTimer struct {
|
||||
current time.Time
|
||||
}
|
||||
|
||||
func (vt *VirtualTimer) Now() time.Time {
|
||||
return vt.current
|
||||
}
|
||||
|
||||
func (vt *VirtualTimer) Advance(d time.Duration) {
|
||||
vt.current = vt.current.Add(d)
|
||||
}
|
||||
|
||||
// Use in tests
|
||||
func TestCircuitBreaker(t *testing.T) {
|
||||
vt := &VirtualTimer{current: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)}
|
||||
|
||||
currentTime := func() time.Time { return vt.Now() }
|
||||
|
||||
cb := readerioresult.MakeCircuitBreaker[string](
|
||||
currentTime,
|
||||
closedState,
|
||||
checkError,
|
||||
policy,
|
||||
logger,
|
||||
)
|
||||
|
||||
// Test circuit breaker behavior
|
||||
// Advance time as needed
|
||||
vt.Advance(5 * time.Second)
|
||||
}
|
||||
```
|
||||
|
||||
## Related Types
|
||||
|
||||
- `circuitbreaker.BreakerState`: The internal state of the circuit breaker (closed or open)
|
||||
- `circuitbreaker.ClosedState`: Configuration for the closed state
|
||||
- `retry.RetryPolicy`: Policy for retry delays and limits
|
||||
- `option.Kleisli[error, error]`: Function type for error checking
|
||||
- `io.Kleisli[string, string]`: Function type for logging
|
||||
|
||||
## See Also
|
||||
|
||||
- `circuitbreaker` package: Generic circuit breaker implementation
|
||||
- `retry` package: Retry policies and strategies
|
||||
- `readerioresult` package: Core ReaderIOResult monad operations
|
||||
974
v2/context/readerioresult/circuitbreaker_test.go
Normal file
974
v2/context/readerioresult/circuitbreaker_test.go
Normal file
@@ -0,0 +1,974 @@
|
||||
// Copyright (c) 2023 - 2025 IBM Corp.
|
||||
// All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package readerioresult
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/IBM/fp-go/v2/array"
|
||||
"github.com/IBM/fp-go/v2/circuitbreaker"
|
||||
"github.com/IBM/fp-go/v2/ioref"
|
||||
"github.com/IBM/fp-go/v2/option"
|
||||
"github.com/IBM/fp-go/v2/pair"
|
||||
"github.com/IBM/fp-go/v2/result"
|
||||
"github.com/IBM/fp-go/v2/retry"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// VirtualTimer provides a controllable time source for testing
|
||||
type VirtualTimer struct {
|
||||
mu sync.Mutex
|
||||
current time.Time
|
||||
}
|
||||
|
||||
// NewVirtualTimer creates a new virtual timer starting at the given time
|
||||
func NewVirtualTimer(start time.Time) *VirtualTimer {
|
||||
return &VirtualTimer{current: start}
|
||||
}
|
||||
|
||||
// Now returns the current virtual time
|
||||
func (vt *VirtualTimer) Now() time.Time {
|
||||
vt.mu.Lock()
|
||||
defer vt.mu.Unlock()
|
||||
return vt.current
|
||||
}
|
||||
|
||||
// Advance moves the virtual time forward by the given duration
|
||||
func (vt *VirtualTimer) Advance(d time.Duration) {
|
||||
vt.mu.Lock()
|
||||
defer vt.mu.Unlock()
|
||||
vt.current = vt.current.Add(d)
|
||||
}
|
||||
|
||||
// Set sets the virtual time to a specific value
|
||||
func (vt *VirtualTimer) Set(t time.Time) {
|
||||
vt.mu.Lock()
|
||||
defer vt.mu.Unlock()
|
||||
vt.current = t
|
||||
}
|
||||
|
||||
// Helper function to create a test logger that collects messages
|
||||
func testMetrics(_ *[]string) circuitbreaker.Metrics {
|
||||
return circuitbreaker.MakeMetricsFromLogger("testMetrics", log.Default())
|
||||
}
|
||||
|
||||
// Helper function to create a simple closed state
|
||||
func testCBClosedState() circuitbreaker.ClosedState {
|
||||
return circuitbreaker.MakeClosedStateCounter(3)
|
||||
}
|
||||
|
||||
// Helper function to create a test retry policy
|
||||
func testCBRetryPolicy() retry.RetryPolicy {
|
||||
return retry.Monoid.Concat(
|
||||
retry.LimitRetries(3),
|
||||
retry.ExponentialBackoff(100*time.Millisecond),
|
||||
)
|
||||
}
|
||||
|
||||
// Helper function that checks all errors
|
||||
func checkAllErrors(err error) option.Option[error] {
|
||||
return option.Some(err)
|
||||
}
|
||||
|
||||
// Helper function that ignores specific errors
|
||||
func ignoreSpecificError(ignoredMsg string) func(error) option.Option[error] {
|
||||
return func(err error) option.Option[error] {
|
||||
if err.Error() == ignoredMsg {
|
||||
return option.None[error]()
|
||||
}
|
||||
return option.Some(err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_SuccessfulOperation tests that successful operations
|
||||
// pass through the circuit breaker without issues
|
||||
func TestCircuitBreaker_SuccessfulOperation(t *testing.T) {
|
||||
vt := NewVirtualTimer(time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC))
|
||||
var logMessages []string
|
||||
|
||||
cb := MakeCircuitBreaker[string](
|
||||
vt.Now,
|
||||
testCBClosedState(),
|
||||
checkAllErrors,
|
||||
testCBRetryPolicy(),
|
||||
testMetrics(&logMessages),
|
||||
)
|
||||
|
||||
// Create initial state
|
||||
stateRef := circuitbreaker.MakeClosedIORef(testCBClosedState())()
|
||||
|
||||
// Successful operation
|
||||
operation := Of("success")
|
||||
|
||||
// Apply circuit breaker
|
||||
env := pair.MakePair(stateRef, operation)
|
||||
resultEnv := cb(env)
|
||||
|
||||
// Execute
|
||||
ctx := t.Context()
|
||||
protectedOp := pair.Tail(resultEnv)
|
||||
outcome := protectedOp(ctx)()
|
||||
|
||||
assert.Equal(t, result.Of("success"), outcome)
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_SingleFailure tests that a single failure is handled
|
||||
// but doesn't open the circuit
|
||||
func TestCircuitBreaker_SingleFailure(t *testing.T) {
|
||||
vt := NewVirtualTimer(time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC))
|
||||
var logMessages []string
|
||||
|
||||
cb := MakeCircuitBreaker[string](
|
||||
vt.Now,
|
||||
testCBClosedState(),
|
||||
checkAllErrors,
|
||||
testCBRetryPolicy(),
|
||||
testMetrics(&logMessages),
|
||||
)
|
||||
|
||||
stateRef := circuitbreaker.MakeClosedIORef(testCBClosedState())()
|
||||
|
||||
expError := errors.New("operation failed")
|
||||
|
||||
// Failing operation
|
||||
operation := Left[string](expError)
|
||||
|
||||
env := pair.MakePair(stateRef, operation)
|
||||
resultEnv := cb(env)
|
||||
|
||||
ctx := t.Context()
|
||||
protectedOp := pair.Tail(resultEnv)
|
||||
outcome := protectedOp(ctx)()
|
||||
|
||||
assert.Equal(t, result.Left[string](expError), outcome)
|
||||
|
||||
// Circuit should still be closed after one failure
|
||||
state := ioref.Read(stateRef)()
|
||||
assert.True(t, circuitbreaker.IsClosed(state))
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_OpensAfterThreshold tests that the circuit opens
|
||||
// after exceeding the failure threshold
|
||||
func TestCircuitBreaker_OpensAfterThreshold(t *testing.T) {
|
||||
vt := NewVirtualTimer(time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC))
|
||||
var logMessages []string
|
||||
|
||||
cb := MakeCircuitBreaker[string](
|
||||
vt.Now,
|
||||
testCBClosedState(), // Opens after 3 failures
|
||||
checkAllErrors,
|
||||
testCBRetryPolicy(),
|
||||
testMetrics(&logMessages),
|
||||
)
|
||||
|
||||
stateRef := circuitbreaker.MakeClosedIORef(testCBClosedState())()
|
||||
|
||||
expError := errors.New("operation failed")
|
||||
|
||||
// Failing operation
|
||||
operation := Left[string](expError)
|
||||
|
||||
ctx := t.Context()
|
||||
|
||||
// Execute 3 failures to open the circuit
|
||||
for range 3 {
|
||||
env := pair.MakePair(stateRef, operation)
|
||||
resultEnv := cb(env)
|
||||
protectedOp := pair.Tail(resultEnv)
|
||||
outcome := protectedOp(ctx)()
|
||||
assert.Equal(t, result.Left[string](expError), outcome)
|
||||
}
|
||||
|
||||
// Circuit should now be open
|
||||
state := ioref.Read(stateRef)()
|
||||
assert.True(t, circuitbreaker.IsOpen(state))
|
||||
|
||||
// Next request should fail immediately with circuit breaker error
|
||||
env := pair.MakePair(stateRef, operation)
|
||||
resultEnv := cb(env)
|
||||
protectedOp := pair.Tail(resultEnv)
|
||||
outcome := protectedOp(ctx)()
|
||||
|
||||
assert.True(t, result.IsLeft(outcome))
|
||||
_, err := result.Unwrap(outcome)
|
||||
var cbErr *circuitbreaker.CircuitBreakerError
|
||||
assert.ErrorAs(t, err, &cbErr)
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_HalfOpenAfterResetTime tests that the circuit
|
||||
// transitions to half-open state after the reset time
|
||||
func TestCircuitBreaker_HalfOpenAfterResetTime(t *testing.T) {
|
||||
vt := NewVirtualTimer(time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC))
|
||||
var logMessages []string
|
||||
|
||||
cb := MakeCircuitBreaker[string](
|
||||
vt.Now,
|
||||
testCBClosedState(),
|
||||
checkAllErrors,
|
||||
testCBRetryPolicy(),
|
||||
testMetrics(&logMessages),
|
||||
)
|
||||
|
||||
stateRef := circuitbreaker.MakeClosedIORef(testCBClosedState())()
|
||||
|
||||
expError := errors.New("operation failed")
|
||||
|
||||
// Failing operation
|
||||
failingOp := Left[string](expError)
|
||||
|
||||
ctx := t.Context()
|
||||
|
||||
// Open the circuit with 3 failures
|
||||
for range 3 {
|
||||
env := pair.MakePair(stateRef, failingOp)
|
||||
resultEnv := cb(env)
|
||||
protectedOp := pair.Tail(resultEnv)
|
||||
outcome := protectedOp(ctx)()
|
||||
|
||||
assert.Equal(t, result.Left[string](expError), outcome)
|
||||
}
|
||||
|
||||
// Verify circuit is open
|
||||
state := ioref.Read(stateRef)()
|
||||
assert.True(t, circuitbreaker.IsOpen(state))
|
||||
|
||||
// Advance time past the reset time (exponential backoff starts at 100ms)
|
||||
vt.Advance(200 * time.Millisecond)
|
||||
|
||||
// Now create a successful operation for the canary request
|
||||
successOp := Of("success")
|
||||
|
||||
// Next request should be a canary request
|
||||
env := pair.MakePair(stateRef, successOp)
|
||||
resultEnv := cb(env)
|
||||
protectedOp := pair.Tail(resultEnv)
|
||||
outcome := protectedOp(ctx)()
|
||||
|
||||
// Canary should succeed
|
||||
assert.Equal(t, result.Of("success"), outcome)
|
||||
|
||||
// Circuit should now be closed again
|
||||
state = ioref.Read(stateRef)()
|
||||
assert.True(t, circuitbreaker.IsClosed(state))
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_CanaryFailureExtendsOpenTime tests that a failed
|
||||
// canary request extends the open time
|
||||
func TestCircuitBreaker_CanaryFailureExtendsOpenTime(t *testing.T) {
|
||||
vt := NewVirtualTimer(time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC))
|
||||
var logMessages []string
|
||||
|
||||
cb := MakeCircuitBreaker[string](
|
||||
vt.Now,
|
||||
testCBClosedState(),
|
||||
checkAllErrors,
|
||||
testCBRetryPolicy(),
|
||||
testMetrics(&logMessages),
|
||||
)
|
||||
|
||||
stateRef := circuitbreaker.MakeClosedIORef(testCBClosedState())()
|
||||
|
||||
expError := errors.New("operation failed")
|
||||
|
||||
// Failing operation
|
||||
failingOp := Left[string](expError)
|
||||
|
||||
ctx := t.Context()
|
||||
|
||||
// Open the circuit
|
||||
for range 3 {
|
||||
env := pair.MakePair(stateRef, failingOp)
|
||||
resultEnv := cb(env)
|
||||
protectedOp := pair.Tail(resultEnv)
|
||||
outcome := protectedOp(ctx)()
|
||||
assert.Equal(t, result.Left[string](expError), outcome)
|
||||
}
|
||||
|
||||
// Advance time to trigger canary
|
||||
vt.Advance(200 * time.Millisecond)
|
||||
|
||||
// Canary request fails
|
||||
env := pair.MakePair(stateRef, failingOp)
|
||||
resultEnv := cb(env)
|
||||
protectedOp := pair.Tail(resultEnv)
|
||||
outcome := protectedOp(ctx)()
|
||||
|
||||
assert.True(t, result.IsLeft(outcome))
|
||||
|
||||
// Circuit should still be open
|
||||
state := ioref.Read(stateRef)()
|
||||
assert.True(t, circuitbreaker.IsOpen(state))
|
||||
|
||||
// Immediate next request should fail with circuit breaker error
|
||||
env = pair.MakePair(stateRef, failingOp)
|
||||
resultEnv = cb(env)
|
||||
protectedOp = pair.Tail(resultEnv)
|
||||
outcome = protectedOp(ctx)()
|
||||
|
||||
assert.True(t, result.IsLeft(outcome))
|
||||
_, err := result.Unwrap(outcome)
|
||||
var cbErr *circuitbreaker.CircuitBreakerError
|
||||
assert.ErrorAs(t, err, &cbErr)
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_IgnoredErrorsDoNotCount tests that errors filtered
|
||||
// by checkError don't count toward opening the circuit
|
||||
func TestCircuitBreaker_IgnoredErrorsDoNotCount(t *testing.T) {
|
||||
vt := NewVirtualTimer(time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC))
|
||||
var logMessages []string
|
||||
|
||||
// Ignore "ignorable error"
|
||||
checkError := ignoreSpecificError("ignorable error")
|
||||
|
||||
cb := MakeCircuitBreaker[string](
|
||||
vt.Now,
|
||||
testCBClosedState(),
|
||||
checkError,
|
||||
testCBRetryPolicy(),
|
||||
testMetrics(&logMessages),
|
||||
)
|
||||
|
||||
stateRef := circuitbreaker.MakeClosedIORef(testCBClosedState())()
|
||||
|
||||
ctx := t.Context()
|
||||
ignorableError := errors.New("ignorable error")
|
||||
|
||||
// Execute 5 ignorable errors
|
||||
ignorableOp := Left[string](ignorableError)
|
||||
|
||||
for range 5 {
|
||||
env := pair.MakePair(stateRef, ignorableOp)
|
||||
resultEnv := cb(env)
|
||||
protectedOp := pair.Tail(resultEnv)
|
||||
outcome := protectedOp(ctx)()
|
||||
assert.Equal(t, result.Left[string](ignorableError), outcome)
|
||||
}
|
||||
|
||||
// Circuit should still be closed
|
||||
state := ioref.Read(stateRef)()
|
||||
assert.True(t, circuitbreaker.IsClosed(state))
|
||||
|
||||
realError := errors.New("real error")
|
||||
|
||||
// Now send a real error
|
||||
realErrorOp := Left[string](realError)
|
||||
|
||||
env := pair.MakePair(stateRef, realErrorOp)
|
||||
resultEnv := cb(env)
|
||||
protectedOp := pair.Tail(resultEnv)
|
||||
outcome := protectedOp(ctx)()
|
||||
|
||||
assert.Equal(t, result.Left[string](realError), outcome)
|
||||
|
||||
// Circuit should still be closed (only 1 counted error)
|
||||
state = ioref.Read(stateRef)()
|
||||
assert.True(t, circuitbreaker.IsClosed(state))
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_MixedSuccessAndFailure tests the circuit behavior
|
||||
// with a mix of successful and failed operations
|
||||
func TestCircuitBreaker_MixedSuccessAndFailure(t *testing.T) {
|
||||
vt := NewVirtualTimer(time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC))
|
||||
var logMessages []string
|
||||
|
||||
cb := MakeCircuitBreaker[string](
|
||||
vt.Now,
|
||||
testCBClosedState(),
|
||||
checkAllErrors,
|
||||
testCBRetryPolicy(),
|
||||
testMetrics(&logMessages),
|
||||
)
|
||||
|
||||
stateRef := circuitbreaker.MakeClosedIORef(testCBClosedState())()
|
||||
|
||||
ctx := t.Context()
|
||||
|
||||
successOp := Of("success")
|
||||
expError := errors.New("failure")
|
||||
|
||||
failOp := Left[string](expError)
|
||||
|
||||
// Pattern: fail, fail, success, fail
|
||||
ops := array.From(failOp, failOp, successOp, failOp)
|
||||
|
||||
for _, op := range ops {
|
||||
env := pair.MakePair(stateRef, op)
|
||||
resultEnv := cb(env)
|
||||
protectedOp := pair.Tail(resultEnv)
|
||||
_ = protectedOp(ctx)()
|
||||
}
|
||||
|
||||
// Circuit should still be closed (success resets the count)
|
||||
state := ioref.Read(stateRef)()
|
||||
assert.True(t, circuitbreaker.IsClosed(state))
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_ConcurrentOperations tests that the circuit breaker
|
||||
// handles concurrent operations correctly
|
||||
func TestCircuitBreaker_ConcurrentOperations(t *testing.T) {
|
||||
vt := NewVirtualTimer(time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC))
|
||||
var logMessages []string
|
||||
|
||||
cb := MakeCircuitBreaker[int](
|
||||
vt.Now,
|
||||
testCBClosedState(),
|
||||
checkAllErrors,
|
||||
testCBRetryPolicy(),
|
||||
testMetrics(&logMessages),
|
||||
)
|
||||
|
||||
stateRef := circuitbreaker.MakeClosedIORef(testCBClosedState())()
|
||||
|
||||
ctx := t.Context()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
results := make([]Result[int], 10)
|
||||
|
||||
// Launch 10 concurrent operations
|
||||
for i := range 10 {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
|
||||
op := Of(idx)
|
||||
|
||||
env := pair.MakePair(stateRef, op)
|
||||
resultEnv := cb(env)
|
||||
protectedOp := pair.Tail(resultEnv)
|
||||
results[idx] = protectedOp(ctx)()
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// All operations should succeed
|
||||
for i, res := range results {
|
||||
assert.True(t, result.IsRight(res), "Operation %d should succeed", i)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_DifferentTypes tests that the circuit breaker works
|
||||
// with different result types
|
||||
func TestCircuitBreaker_DifferentTypes(t *testing.T) {
|
||||
vt := NewVirtualTimer(time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC))
|
||||
var logMessages []string
|
||||
|
||||
// Test with int
|
||||
cbInt := MakeCircuitBreaker[int](
|
||||
vt.Now,
|
||||
testCBClosedState(),
|
||||
checkAllErrors,
|
||||
testCBRetryPolicy(),
|
||||
testMetrics(&logMessages),
|
||||
)
|
||||
|
||||
stateRefInt := circuitbreaker.MakeClosedIORef(testCBClosedState())()
|
||||
|
||||
opInt := Of(42)
|
||||
|
||||
ctx := t.Context()
|
||||
envInt := pair.MakePair(stateRefInt, opInt)
|
||||
resultEnvInt := cbInt(envInt)
|
||||
protectedOpInt := pair.Tail(resultEnvInt)
|
||||
outcomeInt := protectedOpInt(ctx)()
|
||||
|
||||
assert.Equal(t, result.Of(42), outcomeInt)
|
||||
|
||||
// Test with struct
|
||||
type User struct {
|
||||
ID int
|
||||
Name string
|
||||
}
|
||||
|
||||
cbUser := MakeCircuitBreaker[User](
|
||||
vt.Now,
|
||||
testCBClosedState(),
|
||||
checkAllErrors,
|
||||
testCBRetryPolicy(),
|
||||
testMetrics(&logMessages),
|
||||
)
|
||||
|
||||
stateRefUser := circuitbreaker.MakeClosedIORef(testCBClosedState())()
|
||||
|
||||
opUser := Of(User{ID: 1, Name: "Alice"})
|
||||
|
||||
envUser := pair.MakePair(stateRefUser, opUser)
|
||||
resultEnvUser := cbUser(envUser)
|
||||
protectedOpUser := pair.Tail(resultEnvUser)
|
||||
outcomeUser := protectedOpUser(ctx)()
|
||||
|
||||
require.Equal(t, result.Of(User{ID: 1, Name: "Alice"}), outcomeUser)
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_VirtualTimerAdvancement tests that the virtual timer
|
||||
// correctly controls time-based behavior
|
||||
func TestCircuitBreaker_VirtualTimerAdvancement(t *testing.T) {
|
||||
startTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
vt := NewVirtualTimer(startTime)
|
||||
|
||||
// Verify initial time
|
||||
assert.Equal(t, startTime, vt.Now())
|
||||
|
||||
// Advance by 1 hour
|
||||
vt.Advance(1 * time.Hour)
|
||||
assert.Equal(t, startTime.Add(1*time.Hour), vt.Now())
|
||||
|
||||
// Advance by 30 minutes
|
||||
vt.Advance(30 * time.Minute)
|
||||
assert.Equal(t, startTime.Add(90*time.Minute), vt.Now())
|
||||
|
||||
// Set to specific time
|
||||
newTime := time.Date(2024, 6, 15, 10, 30, 0, 0, time.UTC)
|
||||
vt.Set(newTime)
|
||||
assert.Equal(t, newTime, vt.Now())
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_InitialState tests that the circuit starts in closed state
|
||||
func TestCircuitBreaker_InitialState(t *testing.T) {
|
||||
vt := NewVirtualTimer(time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC))
|
||||
var logMessages []string
|
||||
|
||||
cb := MakeCircuitBreaker[string](
|
||||
vt.Now,
|
||||
testCBClosedState(),
|
||||
checkAllErrors,
|
||||
testCBRetryPolicy(),
|
||||
testMetrics(&logMessages),
|
||||
)
|
||||
|
||||
stateRef := circuitbreaker.MakeClosedIORef(testCBClosedState())()
|
||||
|
||||
// Check initial state is closed
|
||||
state := ioref.Read(stateRef)()
|
||||
assert.True(t, circuitbreaker.IsClosed(state), "Circuit should start in closed state")
|
||||
|
||||
// First operation should execute normally
|
||||
op := Of("first operation")
|
||||
|
||||
ctx := t.Context()
|
||||
env := pair.MakePair(stateRef, op)
|
||||
resultEnv := cb(env)
|
||||
protectedOp := pair.Tail(resultEnv)
|
||||
outcome := protectedOp(ctx)()
|
||||
|
||||
assert.Equal(t, result.Of("first operation"), outcome)
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_ErrorMessageFormat tests that circuit breaker errors
|
||||
// have appropriate error messages
|
||||
func TestCircuitBreaker_ErrorMessageFormat(t *testing.T) {
|
||||
vt := NewVirtualTimer(time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC))
|
||||
var logMessages []string
|
||||
|
||||
cb := MakeCircuitBreaker[string](
|
||||
vt.Now,
|
||||
testCBClosedState(),
|
||||
checkAllErrors,
|
||||
testCBRetryPolicy(),
|
||||
testMetrics(&logMessages),
|
||||
)
|
||||
|
||||
stateRef := circuitbreaker.MakeClosedIORef(testCBClosedState())()
|
||||
|
||||
ctx := t.Context()
|
||||
|
||||
expError := errors.New("service unavailable")
|
||||
|
||||
failOp := Left[string](expError)
|
||||
|
||||
// Open the circuit
|
||||
for range 3 {
|
||||
env := pair.MakePair(stateRef, failOp)
|
||||
resultEnv := cb(env)
|
||||
protectedOp := pair.Tail(resultEnv)
|
||||
_ = protectedOp(ctx)()
|
||||
}
|
||||
|
||||
// Next request should fail with circuit breaker error
|
||||
env := pair.MakePair(stateRef, failOp)
|
||||
resultEnv := cb(env)
|
||||
protectedOp := pair.Tail(resultEnv)
|
||||
outcome := protectedOp(ctx)()
|
||||
|
||||
assert.True(t, result.IsLeft[string](outcome))
|
||||
|
||||
// Error message should indicate circuit breaker is open
|
||||
_, err := result.Unwrap(outcome)
|
||||
errMsg := err.Error()
|
||||
assert.Contains(t, errMsg, "circuit", "Error should mention circuit breaker")
|
||||
}
|
||||
|
||||
// RequestSpec defines a virtual request with timing and outcome information
|
||||
type RequestSpec struct {
|
||||
ID int // Unique identifier for the request
|
||||
StartTime time.Duration // Virtual start time relative to test start
|
||||
Duration time.Duration // How long the request takes to execute
|
||||
ShouldFail bool // Whether this request should fail
|
||||
}
|
||||
|
||||
// RequestResult captures the outcome of a request execution
|
||||
type RequestResult struct {
|
||||
ID int
|
||||
StartTime time.Time
|
||||
EndTime time.Time
|
||||
Success bool
|
||||
Error error
|
||||
CircuitBreakerError bool // True if failed due to circuit breaker being open
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_ConcurrentBatchWithThresholdExceeded tests a complex
|
||||
// concurrent scenario where:
|
||||
// 1. Initial requests succeed
|
||||
// 2. A batch of failures exceeds the threshold, opening the circuit
|
||||
// 3. Subsequent requests fail immediately due to open circuit
|
||||
// 4. After timeout, a canary request succeeds
|
||||
// 5. Following requests succeed again
|
||||
func TestCircuitBreaker_ConcurrentBatchWithThresholdExceeded(t *testing.T) {
|
||||
startTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
vt := NewVirtualTimer(startTime)
|
||||
var logMessages []string
|
||||
|
||||
// Circuit opens after 3 failures, with exponential backoff starting at 100ms
|
||||
cb := MakeCircuitBreaker[string](
|
||||
vt.Now,
|
||||
testCBClosedState(), // Opens after 3 failures
|
||||
checkAllErrors,
|
||||
testCBRetryPolicy(), // 100ms initial backoff
|
||||
testMetrics(&logMessages),
|
||||
)
|
||||
|
||||
stateRef := circuitbreaker.MakeClosedIORef(testCBClosedState())()
|
||||
ctx := t.Context()
|
||||
|
||||
// Define the request sequence
|
||||
// Phase 1: Initial successes (0-100ms)
|
||||
// Phase 2: Failures that exceed threshold (100-200ms) - should open circuit
|
||||
// Phase 3: Requests during open circuit (200-300ms) - should fail immediately
|
||||
// Phase 4: After timeout (400ms+) - canary succeeds, then more successes
|
||||
requests := []RequestSpec{
|
||||
// Phase 1: Initial successful requests
|
||||
{ID: 1, StartTime: 0 * time.Millisecond, Duration: 10 * time.Millisecond, ShouldFail: false},
|
||||
{ID: 2, StartTime: 20 * time.Millisecond, Duration: 10 * time.Millisecond, ShouldFail: false},
|
||||
|
||||
// Phase 2: Sequential failures that exceed threshold (3 failures)
|
||||
{ID: 3, StartTime: 100 * time.Millisecond, Duration: 5 * time.Millisecond, ShouldFail: true},
|
||||
{ID: 4, StartTime: 110 * time.Millisecond, Duration: 5 * time.Millisecond, ShouldFail: true},
|
||||
{ID: 5, StartTime: 120 * time.Millisecond, Duration: 5 * time.Millisecond, ShouldFail: true},
|
||||
{ID: 6, StartTime: 130 * time.Millisecond, Duration: 5 * time.Millisecond, ShouldFail: true},
|
||||
|
||||
// Phase 3: Requests during open circuit - should fail with circuit breaker error
|
||||
{ID: 7, StartTime: 200 * time.Millisecond, Duration: 5 * time.Millisecond, ShouldFail: false},
|
||||
{ID: 8, StartTime: 210 * time.Millisecond, Duration: 5 * time.Millisecond, ShouldFail: false},
|
||||
{ID: 9, StartTime: 220 * time.Millisecond, Duration: 5 * time.Millisecond, ShouldFail: false},
|
||||
|
||||
// Phase 4: After reset timeout (100ms backoff from last failure at ~125ms = ~225ms)
|
||||
// Wait longer to ensure we're past the reset time
|
||||
{ID: 10, StartTime: 400 * time.Millisecond, Duration: 5 * time.Millisecond, ShouldFail: false}, // Canary succeeds
|
||||
{ID: 11, StartTime: 410 * time.Millisecond, Duration: 5 * time.Millisecond, ShouldFail: false},
|
||||
{ID: 12, StartTime: 420 * time.Millisecond, Duration: 5 * time.Millisecond, ShouldFail: false},
|
||||
}
|
||||
|
||||
results := make([]RequestResult, len(requests))
|
||||
|
||||
// Execute requests sequentially but model them as if they were concurrent
|
||||
// by advancing the virtual timer to each request's start time
|
||||
for i, req := range requests {
|
||||
// Set virtual time to request start time
|
||||
vt.Set(startTime.Add(req.StartTime))
|
||||
|
||||
// Create the operation based on spec
|
||||
var op ReaderIOResult[string]
|
||||
if req.ShouldFail {
|
||||
op = Left[string](errors.New("operation failed"))
|
||||
} else {
|
||||
op = Of("success")
|
||||
}
|
||||
|
||||
// Apply circuit breaker
|
||||
env := pair.MakePair(stateRef, op)
|
||||
resultEnv := cb(env)
|
||||
protectedOp := pair.Tail(resultEnv)
|
||||
|
||||
// Record start time
|
||||
execStartTime := vt.Now()
|
||||
|
||||
// Execute the operation
|
||||
outcome := protectedOp(ctx)()
|
||||
|
||||
// Advance time by operation duration
|
||||
vt.Advance(req.Duration)
|
||||
execEndTime := vt.Now()
|
||||
|
||||
// Analyze the result
|
||||
isSuccess := result.IsRight(outcome)
|
||||
var err error
|
||||
var isCBError bool
|
||||
|
||||
if !isSuccess {
|
||||
_, err = result.Unwrap(outcome)
|
||||
var cbErr *circuitbreaker.CircuitBreakerError
|
||||
isCBError = errors.As(err, &cbErr)
|
||||
}
|
||||
|
||||
results[i] = RequestResult{
|
||||
ID: req.ID,
|
||||
StartTime: execStartTime,
|
||||
EndTime: execEndTime,
|
||||
Success: isSuccess,
|
||||
Error: err,
|
||||
CircuitBreakerError: isCBError,
|
||||
}
|
||||
}
|
||||
|
||||
// Verify Phase 1: Initial requests should succeed
|
||||
assert.True(t, results[0].Success, "Request 1 should succeed")
|
||||
assert.True(t, results[1].Success, "Request 2 should succeed")
|
||||
|
||||
// Verify Phase 2: Failures should be recorded (first 3 fail with actual error)
|
||||
// The 4th might fail with CB error if circuit opened fast enough
|
||||
assert.False(t, results[2].Success, "Request 3 should fail")
|
||||
assert.False(t, results[3].Success, "Request 4 should fail")
|
||||
assert.False(t, results[4].Success, "Request 5 should fail")
|
||||
|
||||
// At least the first 3 failures should be actual operation failures, not CB errors
|
||||
actualFailures := 0
|
||||
for i := 2; i <= 4; i++ {
|
||||
if !results[i].CircuitBreakerError {
|
||||
actualFailures++
|
||||
}
|
||||
}
|
||||
assert.GreaterOrEqual(t, actualFailures, 3, "At least 3 actual operation failures should occur")
|
||||
|
||||
// Verify Phase 3: Requests during open circuit should fail with circuit breaker error
|
||||
for i := 6; i <= 8; i++ {
|
||||
assert.False(t, results[i].Success, "Request %d should fail during open circuit", results[i].ID)
|
||||
assert.True(t, results[i].CircuitBreakerError, "Request %d should fail with circuit breaker error", results[i].ID)
|
||||
}
|
||||
|
||||
// Verify Phase 4: After timeout, canary and subsequent requests should succeed
|
||||
assert.True(t, results[9].Success, "Request 10 (canary) should succeed")
|
||||
assert.True(t, results[10].Success, "Request 11 should succeed after circuit closes")
|
||||
assert.True(t, results[11].Success, "Request 12 should succeed after circuit closes")
|
||||
|
||||
// Verify final state is closed
|
||||
finalState := ioref.Read(stateRef)()
|
||||
assert.True(t, circuitbreaker.IsClosed(finalState), "Circuit should be closed at the end")
|
||||
|
||||
// Log summary for debugging
|
||||
t.Logf("Test completed with %d requests", len(results))
|
||||
successCount := 0
|
||||
cbErrorCount := 0
|
||||
actualErrorCount := 0
|
||||
|
||||
for _, r := range results {
|
||||
if r.Success {
|
||||
successCount++
|
||||
} else if r.CircuitBreakerError {
|
||||
cbErrorCount++
|
||||
} else {
|
||||
actualErrorCount++
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("Summary: %d successes, %d circuit breaker errors, %d actual errors",
|
||||
successCount, cbErrorCount, actualErrorCount)
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_ConcurrentHighLoad tests circuit breaker behavior
|
||||
// under high concurrent load with mixed success/failure patterns
|
||||
func TestCircuitBreaker_ConcurrentHighLoad(t *testing.T) {
|
||||
startTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
vt := NewVirtualTimer(startTime)
|
||||
var logMessages []string
|
||||
|
||||
cb := MakeCircuitBreaker[int](
|
||||
vt.Now,
|
||||
testCBClosedState(),
|
||||
checkAllErrors,
|
||||
testCBRetryPolicy(),
|
||||
testMetrics(&logMessages),
|
||||
)
|
||||
|
||||
stateRef := circuitbreaker.MakeClosedIORef(testCBClosedState())()
|
||||
ctx := t.Context()
|
||||
|
||||
// Create a large batch of 50 requests
|
||||
// Pattern: success, success, fail, fail, fail, fail, success, success, ...
|
||||
// This ensures we have initial successes, then failures to open circuit,
|
||||
// then more requests that hit the open circuit
|
||||
numRequests := 50
|
||||
|
||||
results := make([]bool, numRequests)
|
||||
cbErrors := make([]bool, numRequests)
|
||||
|
||||
// Execute requests with controlled timing
|
||||
for i := range numRequests {
|
||||
// Advance time slightly for each request
|
||||
vt.Advance(10 * time.Millisecond)
|
||||
|
||||
// Pattern: 2 success, 4 failures, repeat
|
||||
// This ensures we exceed the threshold (3 failures) early on
|
||||
shouldFail := (i%6) >= 2 && (i%6) < 6
|
||||
|
||||
var op ReaderIOResult[int]
|
||||
if shouldFail {
|
||||
op = Left[int](errors.New("simulated failure"))
|
||||
} else {
|
||||
op = Of(i)
|
||||
}
|
||||
|
||||
env := pair.MakePair(stateRef, op)
|
||||
resultEnv := cb(env)
|
||||
protectedOp := pair.Tail(resultEnv)
|
||||
outcome := protectedOp(ctx)()
|
||||
|
||||
results[i] = result.IsRight(outcome)
|
||||
|
||||
if !results[i] {
|
||||
_, err := result.Unwrap(outcome)
|
||||
var cbErr *circuitbreaker.CircuitBreakerError
|
||||
cbErrors[i] = errors.As(err, &cbErr)
|
||||
}
|
||||
}
|
||||
|
||||
// Count outcomes
|
||||
successCount := 0
|
||||
failureCount := 0
|
||||
cbErrorCount := 0
|
||||
|
||||
for i := range numRequests {
|
||||
if results[i] {
|
||||
successCount++
|
||||
} else {
|
||||
failureCount++
|
||||
if cbErrors[i] {
|
||||
cbErrorCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("High load test: %d total requests", numRequests)
|
||||
t.Logf("Results: %d successes, %d failures (%d circuit breaker errors)",
|
||||
successCount, failureCount, cbErrorCount)
|
||||
|
||||
// Verify that circuit breaker activated (some requests failed due to open circuit)
|
||||
assert.Greater(t, cbErrorCount, 0, "Circuit breaker should have opened and blocked some requests")
|
||||
|
||||
// Verify that not all requests failed (some succeeded before circuit opened)
|
||||
assert.Greater(t, successCount, 0, "Some requests should have succeeded")
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_TrueConcurrentRequests tests actual concurrent execution
|
||||
// with proper synchronization
|
||||
func TestCircuitBreaker_TrueConcurrentRequests(t *testing.T) {
|
||||
startTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
vt := NewVirtualTimer(startTime)
|
||||
var logMessages []string
|
||||
|
||||
cb := MakeCircuitBreaker[int](
|
||||
vt.Now,
|
||||
testCBClosedState(),
|
||||
checkAllErrors,
|
||||
testCBRetryPolicy(),
|
||||
testMetrics(&logMessages),
|
||||
)
|
||||
|
||||
stateRef := circuitbreaker.MakeClosedIORef(testCBClosedState())()
|
||||
ctx := t.Context()
|
||||
|
||||
// Launch 20 concurrent requests
|
||||
numRequests := 20
|
||||
var wg sync.WaitGroup
|
||||
results := make([]bool, numRequests)
|
||||
cbErrors := make([]bool, numRequests)
|
||||
|
||||
// First, send some successful requests
|
||||
for i := range 5 {
|
||||
op := Of(i)
|
||||
env := pair.MakePair(stateRef, op)
|
||||
resultEnv := cb(env)
|
||||
protectedOp := pair.Tail(resultEnv)
|
||||
outcome := protectedOp(ctx)()
|
||||
results[i] = result.IsRight(outcome)
|
||||
}
|
||||
|
||||
// Now send concurrent failures to open the circuit
|
||||
for i := 5; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
op := Left[int](errors.New("concurrent failure"))
|
||||
env := pair.MakePair(stateRef, op)
|
||||
resultEnv := cb(env)
|
||||
protectedOp := pair.Tail(resultEnv)
|
||||
outcome := protectedOp(ctx)()
|
||||
results[idx] = result.IsRight(outcome)
|
||||
if !results[idx] {
|
||||
_, err := result.Unwrap(outcome)
|
||||
var cbErr *circuitbreaker.CircuitBreakerError
|
||||
cbErrors[idx] = errors.As(err, &cbErr)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Now send more requests that should hit the open circuit
|
||||
for i := 10; i < numRequests; i++ {
|
||||
op := Of(i)
|
||||
env := pair.MakePair(stateRef, op)
|
||||
resultEnv := cb(env)
|
||||
protectedOp := pair.Tail(resultEnv)
|
||||
outcome := protectedOp(ctx)()
|
||||
results[i] = result.IsRight(outcome)
|
||||
if !results[i] {
|
||||
_, err := result.Unwrap(outcome)
|
||||
var cbErr *circuitbreaker.CircuitBreakerError
|
||||
cbErrors[i] = errors.As(err, &cbErr)
|
||||
}
|
||||
}
|
||||
|
||||
// Count outcomes
|
||||
successCount := 0
|
||||
failureCount := 0
|
||||
cbErrorCount := 0
|
||||
|
||||
for i := range numRequests {
|
||||
if results[i] {
|
||||
successCount++
|
||||
} else {
|
||||
failureCount++
|
||||
if cbErrors[i] {
|
||||
cbErrorCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("Concurrent test: %d total requests", numRequests)
|
||||
t.Logf("Results: %d successes, %d failures (%d circuit breaker errors)",
|
||||
successCount, failureCount, cbErrorCount)
|
||||
|
||||
// Verify initial successes
|
||||
assert.Equal(t, 5, successCount, "First 5 requests should succeed")
|
||||
|
||||
// Verify that circuit breaker opened and blocked some requests
|
||||
assert.Greater(t, cbErrorCount, 0, "Circuit breaker should have opened and blocked some requests")
|
||||
}
|
||||
@@ -28,7 +28,7 @@ import "github.com/IBM/fp-go/v2/io"
|
||||
//
|
||||
//go:inline
|
||||
func ChainConsumer[A any](c Consumer[A]) Operator[A, struct{}] {
|
||||
return ChainIOK(io.FromConsumerK(c))
|
||||
return ChainIOK(io.FromConsumer(c))
|
||||
}
|
||||
|
||||
// ChainFirstConsumer chains a consumer function into a ReaderIOResult computation, preserving the original value.
|
||||
@@ -59,5 +59,5 @@ func ChainConsumer[A any](c Consumer[A]) Operator[A, struct{}] {
|
||||
//
|
||||
//go:inline
|
||||
func ChainFirstConsumer[A any](c Consumer[A]) Operator[A, A] {
|
||||
return ChainFirstIOK(io.FromConsumerK(c))
|
||||
return ChainFirstIOK(io.FromConsumer(c))
|
||||
}
|
||||
|
||||
@@ -966,6 +966,16 @@ func TapLeft[A, B any](f Kleisli[error, B]) Operator[A, A] {
|
||||
return RIOR.TapLeft[A](WithContextK(f))
|
||||
}
|
||||
|
||||
//go:inline
|
||||
func ChainFirstLeftIOK[A, B any](f io.Kleisli[error, B]) Operator[A, A] {
|
||||
return RIOR.ChainFirstLeftIOK[A, context.Context](f)
|
||||
}
|
||||
|
||||
//go:inline
|
||||
func TapLeftIOK[A, B any](f io.Kleisli[error, B]) Operator[A, A] {
|
||||
return RIOR.TapLeftIOK[A, context.Context](f)
|
||||
}
|
||||
|
||||
// Local transforms the context.Context environment before passing it to a ReaderIOResult computation.
|
||||
//
|
||||
// This is the Reader's local operation, which allows you to modify the environment
|
||||
|
||||
@@ -25,10 +25,12 @@ import (
|
||||
"github.com/IBM/fp-go/v2/endomorphism"
|
||||
"github.com/IBM/fp-go/v2/io"
|
||||
"github.com/IBM/fp-go/v2/ioeither"
|
||||
"github.com/IBM/fp-go/v2/ioref"
|
||||
"github.com/IBM/fp-go/v2/lazy"
|
||||
"github.com/IBM/fp-go/v2/optics/lens"
|
||||
"github.com/IBM/fp-go/v2/optics/prism"
|
||||
"github.com/IBM/fp-go/v2/option"
|
||||
"github.com/IBM/fp-go/v2/pair"
|
||||
"github.com/IBM/fp-go/v2/predicate"
|
||||
"github.com/IBM/fp-go/v2/reader"
|
||||
"github.com/IBM/fp-go/v2/readereither"
|
||||
@@ -36,6 +38,7 @@ import (
|
||||
RIOR "github.com/IBM/fp-go/v2/readerioresult"
|
||||
"github.com/IBM/fp-go/v2/readeroption"
|
||||
"github.com/IBM/fp-go/v2/result"
|
||||
"github.com/IBM/fp-go/v2/state"
|
||||
"github.com/IBM/fp-go/v2/tailrec"
|
||||
)
|
||||
|
||||
@@ -143,4 +146,10 @@ type (
|
||||
Trampoline[B, L any] = tailrec.Trampoline[B, L]
|
||||
|
||||
Predicate[A any] = predicate.Predicate[A]
|
||||
|
||||
Pair[A, B any] = pair.Pair[A, B]
|
||||
|
||||
IORef[A any] = ioref.IORef[A]
|
||||
|
||||
State[S, A any] = state.State[S, A]
|
||||
)
|
||||
|
||||
@@ -56,3 +56,77 @@ func AltMonoid[E, A any](zero Lazy[Either[E, A]]) Monoid[E, A] {
|
||||
MonadAlt[E, A],
|
||||
)
|
||||
}
|
||||
|
||||
// takeFirst is a helper function that returns the first Right value, or the second if the first is Left.
|
||||
func takeFirst[E, A any](l, r Either[E, A]) Either[E, A] {
|
||||
if IsRight(l) {
|
||||
return l
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// FirstMonoid creates a Monoid for Either[E, A] that returns the first Right value.
|
||||
// This monoid prefers the left operand when it is Right, otherwise returns the right operand.
|
||||
// The empty value is provided as a lazy computation.
|
||||
//
|
||||
// This is equivalent to AltMonoid but implemented more directly.
|
||||
//
|
||||
// Truth table:
|
||||
//
|
||||
// | x | y | concat(x, y) |
|
||||
// | --------- | --------- | ------------ |
|
||||
// | left(e1) | left(e2) | left(e2) |
|
||||
// | right(a) | left(e) | right(a) |
|
||||
// | left(e) | right(b) | right(b) |
|
||||
// | right(a) | right(b) | right(a) |
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// import "errors"
|
||||
// zero := func() either.Either[error, int] { return either.Left[int](errors.New("empty")) }
|
||||
// m := either.FirstMonoid[error, int](zero)
|
||||
// m.Concat(either.Right[error](2), either.Right[error](3)) // Right(2) - returns first Right
|
||||
// m.Concat(either.Left[int](errors.New("err")), either.Right[error](3)) // Right(3)
|
||||
// m.Concat(either.Right[error](2), either.Left[int](errors.New("err"))) // Right(2)
|
||||
// m.Empty() // Left(error("empty"))
|
||||
//
|
||||
//go:inline
|
||||
func FirstMonoid[E, A any](zero Lazy[Either[E, A]]) M.Monoid[Either[E, A]] {
|
||||
return M.MakeMonoid(takeFirst[E, A], zero())
|
||||
}
|
||||
|
||||
// takeLast is a helper function that returns the last Right value, or the first if the last is Left.
|
||||
func takeLast[E, A any](l, r Either[E, A]) Either[E, A] {
|
||||
if IsRight(r) {
|
||||
return r
|
||||
}
|
||||
return l
|
||||
}
|
||||
|
||||
// LastMonoid creates a Monoid for Either[E, A] that returns the last Right value.
|
||||
// This monoid prefers the right operand when it is Right, otherwise returns the left operand.
|
||||
// The empty value is provided as a lazy computation.
|
||||
//
|
||||
// Truth table:
|
||||
//
|
||||
// | x | y | concat(x, y) |
|
||||
// | --------- | --------- | ------------ |
|
||||
// | left(e1) | left(e2) | left(e1) |
|
||||
// | right(a) | left(e) | right(a) |
|
||||
// | left(e) | right(b) | right(b) |
|
||||
// | right(a) | right(b) | right(b) |
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// import "errors"
|
||||
// zero := func() either.Either[error, int] { return either.Left[int](errors.New("empty")) }
|
||||
// m := either.LastMonoid[error, int](zero)
|
||||
// m.Concat(either.Right[error](2), either.Right[error](3)) // Right(3) - returns last Right
|
||||
// m.Concat(either.Left[int](errors.New("err")), either.Right[error](3)) // Right(3)
|
||||
// m.Concat(either.Right[error](2), either.Left[int](errors.New("err"))) // Right(2)
|
||||
// m.Empty() // Left(error("empty"))
|
||||
//
|
||||
//go:inline
|
||||
func LastMonoid[E, A any](zero Lazy[Either[E, A]]) M.Monoid[Either[E, A]] {
|
||||
return M.MakeMonoid(takeLast[E, A], zero())
|
||||
}
|
||||
|
||||
402
v2/either/monoid_test.go
Normal file
402
v2/either/monoid_test.go
Normal file
@@ -0,0 +1,402 @@
|
||||
// Copyright (c) 2025 IBM Corp.
|
||||
// All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package either
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestFirstMonoid tests the FirstMonoid implementation
|
||||
func TestFirstMonoid(t *testing.T) {
|
||||
zero := func() Either[error, int] { return Left[int](errors.New("empty")) }
|
||||
m := FirstMonoid[error, int](zero)
|
||||
|
||||
t.Run("both Right values - returns first", func(t *testing.T) {
|
||||
result := m.Concat(Right[error](2), Right[error](3))
|
||||
assert.Equal(t, Right[error](2), result)
|
||||
})
|
||||
|
||||
t.Run("left Right, right Left", func(t *testing.T) {
|
||||
result := m.Concat(Right[error](2), Left[int](errors.New("err")))
|
||||
assert.Equal(t, Right[error](2), result)
|
||||
})
|
||||
|
||||
t.Run("left Left, right Right", func(t *testing.T) {
|
||||
result := m.Concat(Left[int](errors.New("err")), Right[error](3))
|
||||
assert.Equal(t, Right[error](3), result)
|
||||
})
|
||||
|
||||
t.Run("both Left", func(t *testing.T) {
|
||||
err1 := errors.New("err1")
|
||||
err2 := errors.New("err2")
|
||||
result := m.Concat(Left[int](err1), Left[int](err2))
|
||||
// Should return the second Left
|
||||
assert.True(t, IsLeft(result))
|
||||
_, leftErr := Unwrap(result)
|
||||
assert.Equal(t, err2, leftErr)
|
||||
})
|
||||
|
||||
t.Run("empty value", func(t *testing.T) {
|
||||
empty := m.Empty()
|
||||
assert.True(t, IsLeft(empty))
|
||||
_, leftErr := Unwrap(empty)
|
||||
assert.Equal(t, "empty", leftErr.Error())
|
||||
})
|
||||
|
||||
t.Run("left identity", func(t *testing.T) {
|
||||
x := Right[error](5)
|
||||
result := m.Concat(m.Empty(), x)
|
||||
assert.Equal(t, x, result)
|
||||
})
|
||||
|
||||
t.Run("right identity", func(t *testing.T) {
|
||||
x := Right[error](5)
|
||||
result := m.Concat(x, m.Empty())
|
||||
assert.Equal(t, x, result)
|
||||
})
|
||||
|
||||
t.Run("associativity", func(t *testing.T) {
|
||||
a := Right[error](1)
|
||||
b := Right[error](2)
|
||||
c := Right[error](3)
|
||||
|
||||
left := m.Concat(m.Concat(a, b), c)
|
||||
right := m.Concat(a, m.Concat(b, c))
|
||||
|
||||
assert.Equal(t, left, right)
|
||||
assert.Equal(t, Right[error](1), left)
|
||||
})
|
||||
|
||||
t.Run("multiple concatenations", func(t *testing.T) {
|
||||
// Should return the first Right value encountered
|
||||
result := m.Concat(
|
||||
m.Concat(Left[int](errors.New("err1")), Right[error](1)),
|
||||
m.Concat(Right[error](2), Right[error](3)),
|
||||
)
|
||||
assert.Equal(t, Right[error](1), result)
|
||||
})
|
||||
|
||||
t.Run("with strings", func(t *testing.T) {
|
||||
zeroStr := func() Either[error, string] { return Left[string](errors.New("empty")) }
|
||||
strMonoid := FirstMonoid[error, string](zeroStr)
|
||||
|
||||
result := strMonoid.Concat(Right[error]("first"), Right[error]("second"))
|
||||
assert.Equal(t, Right[error]("first"), result)
|
||||
|
||||
result = strMonoid.Concat(Left[string](errors.New("err")), Right[error]("second"))
|
||||
assert.Equal(t, Right[error]("second"), result)
|
||||
})
|
||||
}
|
||||
|
||||
// TestLastMonoid tests the LastMonoid implementation
|
||||
func TestLastMonoid(t *testing.T) {
|
||||
zero := func() Either[error, int] { return Left[int](errors.New("empty")) }
|
||||
m := LastMonoid[error, int](zero)
|
||||
|
||||
t.Run("both Right values - returns last", func(t *testing.T) {
|
||||
result := m.Concat(Right[error](2), Right[error](3))
|
||||
assert.Equal(t, Right[error](3), result)
|
||||
})
|
||||
|
||||
t.Run("left Right, right Left", func(t *testing.T) {
|
||||
result := m.Concat(Right[error](2), Left[int](errors.New("err")))
|
||||
assert.Equal(t, Right[error](2), result)
|
||||
})
|
||||
|
||||
t.Run("left Left, right Right", func(t *testing.T) {
|
||||
result := m.Concat(Left[int](errors.New("err")), Right[error](3))
|
||||
assert.Equal(t, Right[error](3), result)
|
||||
})
|
||||
|
||||
t.Run("both Left", func(t *testing.T) {
|
||||
err1 := errors.New("err1")
|
||||
err2 := errors.New("err2")
|
||||
result := m.Concat(Left[int](err1), Left[int](err2))
|
||||
// Should return the first Left
|
||||
assert.True(t, IsLeft(result))
|
||||
_, leftErr := Unwrap(result)
|
||||
assert.Equal(t, err1, leftErr)
|
||||
})
|
||||
|
||||
t.Run("empty value", func(t *testing.T) {
|
||||
empty := m.Empty()
|
||||
assert.True(t, IsLeft(empty))
|
||||
_, leftErr := Unwrap(empty)
|
||||
assert.Equal(t, "empty", leftErr.Error())
|
||||
})
|
||||
|
||||
t.Run("left identity", func(t *testing.T) {
|
||||
x := Right[error](5)
|
||||
result := m.Concat(m.Empty(), x)
|
||||
assert.Equal(t, x, result)
|
||||
})
|
||||
|
||||
t.Run("right identity", func(t *testing.T) {
|
||||
x := Right[error](5)
|
||||
result := m.Concat(x, m.Empty())
|
||||
assert.Equal(t, x, result)
|
||||
})
|
||||
|
||||
t.Run("associativity", func(t *testing.T) {
|
||||
a := Right[error](1)
|
||||
b := Right[error](2)
|
||||
c := Right[error](3)
|
||||
|
||||
left := m.Concat(m.Concat(a, b), c)
|
||||
right := m.Concat(a, m.Concat(b, c))
|
||||
|
||||
assert.Equal(t, left, right)
|
||||
assert.Equal(t, Right[error](3), left)
|
||||
})
|
||||
|
||||
t.Run("multiple concatenations", func(t *testing.T) {
|
||||
// Should return the last Right value encountered
|
||||
result := m.Concat(
|
||||
m.Concat(Right[error](1), Right[error](2)),
|
||||
m.Concat(Right[error](3), Left[int](errors.New("err"))),
|
||||
)
|
||||
assert.Equal(t, Right[error](3), result)
|
||||
})
|
||||
|
||||
t.Run("with strings", func(t *testing.T) {
|
||||
zeroStr := func() Either[error, string] { return Left[string](errors.New("empty")) }
|
||||
strMonoid := LastMonoid[error, string](zeroStr)
|
||||
|
||||
result := strMonoid.Concat(Right[error]("first"), Right[error]("second"))
|
||||
assert.Equal(t, Right[error]("second"), result)
|
||||
|
||||
result = strMonoid.Concat(Right[error]("first"), Left[string](errors.New("err")))
|
||||
assert.Equal(t, Right[error]("first"), result)
|
||||
})
|
||||
}
|
||||
|
||||
// TestFirstMonoidVsAltMonoid verifies FirstMonoid and AltMonoid have the same behavior
|
||||
func TestFirstMonoidVsAltMonoid(t *testing.T) {
|
||||
zero := func() Either[error, int] { return Left[int](errors.New("empty")) }
|
||||
firstMonoid := FirstMonoid[error, int](zero)
|
||||
altMonoid := AltMonoid[error, int](zero)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
left Either[error, int]
|
||||
right Either[error, int]
|
||||
}{
|
||||
{"both Right", Right[error](1), Right[error](2)},
|
||||
{"left Right, right Left", Right[error](1), Left[int](errors.New("err"))},
|
||||
{"left Left, right Right", Left[int](errors.New("err")), Right[error](2)},
|
||||
{"both Left", Left[int](errors.New("err1")), Left[int](errors.New("err2"))},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
firstResult := firstMonoid.Concat(tc.left, tc.right)
|
||||
altResult := altMonoid.Concat(tc.left, tc.right)
|
||||
|
||||
// Both should have the same Right/Left status
|
||||
assert.Equal(t, IsRight(firstResult), IsRight(altResult), "FirstMonoid and AltMonoid should have same Right/Left status")
|
||||
|
||||
if IsRight(firstResult) {
|
||||
rightVal1, _ := Unwrap(firstResult)
|
||||
rightVal2, _ := Unwrap(altResult)
|
||||
assert.Equal(t, rightVal1, rightVal2, "FirstMonoid and AltMonoid should have same Right value")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestFirstMonoidVsLastMonoid verifies the difference between FirstMonoid and LastMonoid
|
||||
func TestFirstMonoidVsLastMonoid(t *testing.T) {
|
||||
zero := func() Either[error, int] { return Left[int](errors.New("empty")) }
|
||||
firstMonoid := FirstMonoid[error, int](zero)
|
||||
lastMonoid := LastMonoid[error, int](zero)
|
||||
|
||||
t.Run("both Right - different results", func(t *testing.T) {
|
||||
firstResult := firstMonoid.Concat(Right[error](1), Right[error](2))
|
||||
lastResult := lastMonoid.Concat(Right[error](1), Right[error](2))
|
||||
|
||||
assert.Equal(t, Right[error](1), firstResult)
|
||||
assert.Equal(t, Right[error](2), lastResult)
|
||||
assert.NotEqual(t, firstResult, lastResult)
|
||||
})
|
||||
|
||||
t.Run("with Left values - different behavior", func(t *testing.T) {
|
||||
err1 := errors.New("err1")
|
||||
err2 := errors.New("err2")
|
||||
|
||||
// Both Left: FirstMonoid returns second, LastMonoid returns first
|
||||
firstResult := firstMonoid.Concat(Left[int](err1), Left[int](err2))
|
||||
lastResult := lastMonoid.Concat(Left[int](err1), Left[int](err2))
|
||||
|
||||
assert.True(t, IsLeft(firstResult))
|
||||
assert.True(t, IsLeft(lastResult))
|
||||
_, leftErr1 := Unwrap(firstResult)
|
||||
_, leftErr2 := Unwrap(lastResult)
|
||||
assert.Equal(t, err2, leftErr1)
|
||||
assert.Equal(t, err1, leftErr2)
|
||||
})
|
||||
|
||||
t.Run("mixed values - same results", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
left Either[error, int]
|
||||
right Either[error, int]
|
||||
expected Either[error, int]
|
||||
}{
|
||||
{"left Right, right Left", Right[error](1), Left[int](errors.New("err")), Right[error](1)},
|
||||
{"left Left, right Right", Left[int](errors.New("err")), Right[error](2), Right[error](2)},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
firstResult := firstMonoid.Concat(tc.left, tc.right)
|
||||
lastResult := lastMonoid.Concat(tc.left, tc.right)
|
||||
|
||||
assert.Equal(t, tc.expected, firstResult)
|
||||
assert.Equal(t, tc.expected, lastResult)
|
||||
assert.Equal(t, firstResult, lastResult)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestMonoidLaws verifies monoid laws for FirstMonoid and LastMonoid
|
||||
func TestMonoidLaws(t *testing.T) {
|
||||
t.Run("FirstMonoid laws", func(t *testing.T) {
|
||||
zero := func() Either[error, int] { return Left[int](errors.New("empty")) }
|
||||
m := FirstMonoid[error, int](zero)
|
||||
|
||||
a := Right[error](1)
|
||||
b := Right[error](2)
|
||||
c := Right[error](3)
|
||||
|
||||
// Associativity: (a • b) • c = a • (b • c)
|
||||
left := m.Concat(m.Concat(a, b), c)
|
||||
right := m.Concat(a, m.Concat(b, c))
|
||||
assert.Equal(t, left, right)
|
||||
|
||||
// Left identity: Empty() • a = a
|
||||
leftId := m.Concat(m.Empty(), a)
|
||||
assert.Equal(t, a, leftId)
|
||||
|
||||
// Right identity: a • Empty() = a
|
||||
rightId := m.Concat(a, m.Empty())
|
||||
assert.Equal(t, a, rightId)
|
||||
})
|
||||
|
||||
t.Run("LastMonoid laws", func(t *testing.T) {
|
||||
zero := func() Either[error, int] { return Left[int](errors.New("empty")) }
|
||||
m := LastMonoid[error, int](zero)
|
||||
|
||||
a := Right[error](1)
|
||||
b := Right[error](2)
|
||||
c := Right[error](3)
|
||||
|
||||
// Associativity: (a • b) • c = a • (b • c)
|
||||
left := m.Concat(m.Concat(a, b), c)
|
||||
right := m.Concat(a, m.Concat(b, c))
|
||||
assert.Equal(t, left, right)
|
||||
|
||||
// Left identity: Empty() • a = a
|
||||
leftId := m.Concat(m.Empty(), a)
|
||||
assert.Equal(t, a, leftId)
|
||||
|
||||
// Right identity: a • Empty() = a
|
||||
rightId := m.Concat(a, m.Empty())
|
||||
assert.Equal(t, a, rightId)
|
||||
})
|
||||
|
||||
t.Run("FirstMonoid laws with Left values", func(t *testing.T) {
|
||||
zero := func() Either[error, int] { return Left[int](errors.New("empty")) }
|
||||
m := FirstMonoid[error, int](zero)
|
||||
|
||||
a := Left[int](errors.New("err1"))
|
||||
b := Left[int](errors.New("err2"))
|
||||
c := Left[int](errors.New("err3"))
|
||||
|
||||
// Associativity with Left values
|
||||
left := m.Concat(m.Concat(a, b), c)
|
||||
right := m.Concat(a, m.Concat(b, c))
|
||||
assert.Equal(t, left, right)
|
||||
})
|
||||
|
||||
t.Run("LastMonoid laws with Left values", func(t *testing.T) {
|
||||
zero := func() Either[error, int] { return Left[int](errors.New("empty")) }
|
||||
m := LastMonoid[error, int](zero)
|
||||
|
||||
a := Left[int](errors.New("err1"))
|
||||
b := Left[int](errors.New("err2"))
|
||||
c := Left[int](errors.New("err3"))
|
||||
|
||||
// Associativity with Left values
|
||||
left := m.Concat(m.Concat(a, b), c)
|
||||
right := m.Concat(a, m.Concat(b, c))
|
||||
assert.Equal(t, left, right)
|
||||
})
|
||||
}
|
||||
|
||||
// TestMonoidEdgeCases tests edge cases for monoid operations
|
||||
func TestMonoidEdgeCases(t *testing.T) {
|
||||
t.Run("FirstMonoid with empty concatenations", func(t *testing.T) {
|
||||
zero := func() Either[error, int] { return Left[int](errors.New("empty")) }
|
||||
m := FirstMonoid[error, int](zero)
|
||||
|
||||
// Empty with empty
|
||||
result := m.Concat(m.Empty(), m.Empty())
|
||||
assert.True(t, IsLeft(result))
|
||||
})
|
||||
|
||||
t.Run("LastMonoid with empty concatenations", func(t *testing.T) {
|
||||
zero := func() Either[error, int] { return Left[int](errors.New("empty")) }
|
||||
m := LastMonoid[error, int](zero)
|
||||
|
||||
// Empty with empty
|
||||
result := m.Concat(m.Empty(), m.Empty())
|
||||
assert.True(t, IsLeft(result))
|
||||
})
|
||||
|
||||
t.Run("FirstMonoid chain of operations", func(t *testing.T) {
|
||||
zero := func() Either[error, int] { return Left[int](errors.New("empty")) }
|
||||
m := FirstMonoid[error, int](zero)
|
||||
|
||||
// Chain multiple operations
|
||||
result := m.Concat(
|
||||
m.Concat(
|
||||
m.Concat(Left[int](errors.New("err1")), Left[int](errors.New("err2"))),
|
||||
Right[error](1),
|
||||
),
|
||||
m.Concat(Right[error](2), Right[error](3)),
|
||||
)
|
||||
assert.Equal(t, Right[error](1), result)
|
||||
})
|
||||
|
||||
t.Run("LastMonoid chain of operations", func(t *testing.T) {
|
||||
zero := func() Either[error, int] { return Left[int](errors.New("empty")) }
|
||||
m := LastMonoid[error, int](zero)
|
||||
|
||||
// Chain multiple operations
|
||||
result := m.Concat(
|
||||
m.Concat(Right[error](1), Right[error](2)),
|
||||
m.Concat(
|
||||
Right[error](3),
|
||||
m.Concat(Right[error](4), Left[int](errors.New("err"))),
|
||||
),
|
||||
)
|
||||
assert.Equal(t, Right[error](4), result)
|
||||
})
|
||||
}
|
||||
@@ -31,3 +31,29 @@ import (
|
||||
// err := errors.New("something went wrong")
|
||||
// same := Identity(err) // returns the same error
|
||||
var Identity = F.Identity[error]
|
||||
|
||||
// IsNonNil checks if an error is non-nil.
|
||||
//
|
||||
// This function provides a predicate for testing whether an error value is not nil.
|
||||
// It's useful in functional programming contexts where you need a function to check
|
||||
// error presence, such as in filter operations or conditional logic.
|
||||
//
|
||||
// Parameters:
|
||||
// - err: The error to check
|
||||
//
|
||||
// Returns:
|
||||
// - true if the error is not nil, false otherwise
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// err := errors.New("something went wrong")
|
||||
// if IsNonNil(err) {
|
||||
// // handle error
|
||||
// }
|
||||
//
|
||||
// // Using in functional contexts
|
||||
// errors := []error{nil, errors.New("error1"), nil, errors.New("error2")}
|
||||
// nonNilErrors := F.Filter(IsNonNil)(errors) // [error1, error2]
|
||||
func IsNonNil(err error) bool {
|
||||
return err != nil
|
||||
}
|
||||
|
||||
@@ -43,6 +43,7 @@ func Pipe1[F1 ~func(T0) T1, T0, T1 any](t0 T0, f1 F1) T1 {
|
||||
// The final return value is the result of the last function application
|
||||
//go:inline
|
||||
func Flow1[F1 ~func(T0) T1, T0, T1 any](f1 F1) func(T0) T1 {
|
||||
//go:inline
|
||||
return func(t0 T0) T1 {
|
||||
return Pipe1(t0, f1)
|
||||
}
|
||||
@@ -103,6 +104,7 @@ func Pipe2[F1 ~func(T0) T1, F2 ~func(T1) T2, T0, T1, T2 any](t0 T0, f1 F1, f2 F2
|
||||
// The final return value is the result of the last function application
|
||||
//go:inline
|
||||
func Flow2[F1 ~func(T0) T1, F2 ~func(T1) T2, T0, T1, T2 any](f1 F1, f2 F2) func(T0) T2 {
|
||||
//go:inline
|
||||
return func(t0 T0) T2 {
|
||||
return Pipe2(t0, f1, f2)
|
||||
}
|
||||
@@ -169,6 +171,7 @@ func Pipe3[F1 ~func(T0) T1, F2 ~func(T1) T2, F3 ~func(T2) T3, T0, T1, T2, T3 any
|
||||
// The final return value is the result of the last function application
|
||||
//go:inline
|
||||
func Flow3[F1 ~func(T0) T1, F2 ~func(T1) T2, F3 ~func(T2) T3, T0, T1, T2, T3 any](f1 F1, f2 F2, f3 F3) func(T0) T3 {
|
||||
//go:inline
|
||||
return func(t0 T0) T3 {
|
||||
return Pipe3(t0, f1, f2, f3)
|
||||
}
|
||||
|
||||
@@ -15,6 +15,47 @@
|
||||
|
||||
package function
|
||||
|
||||
// Void represents the unit type, a type with exactly one value.
|
||||
//
|
||||
// In functional programming, Void (also known as Unit) is used to represent
|
||||
// the absence of meaningful information. It's similar to void in other languages,
|
||||
// but as a value rather than the absence of a value.
|
||||
//
|
||||
// Common use cases:
|
||||
// - As a return type for functions that perform side effects but don't return meaningful data
|
||||
// - As a placeholder type parameter when a type is required but no data needs to be passed
|
||||
// - In functional patterns where a value is required but the actual data is irrelevant
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// // Function that performs an action but returns no meaningful data
|
||||
// func logMessage(msg string) Void {
|
||||
// fmt.Println(msg)
|
||||
// return VOID
|
||||
// }
|
||||
//
|
||||
// // Using Void as a type parameter
|
||||
// type Action = func() Void
|
||||
type (
|
||||
Void = struct{}
|
||||
)
|
||||
|
||||
// VOID is the single inhabitant of the Void type.
|
||||
//
|
||||
// This constant represents the only possible value of type Void. Use it when you need
|
||||
// to return or pass a Void value.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// func doSomething() Void {
|
||||
// // perform some action
|
||||
// return VOID
|
||||
// }
|
||||
//
|
||||
// // Ignoring the return value
|
||||
// _ = doSomething()
|
||||
var VOID Void = struct{}{}
|
||||
|
||||
// ToAny converts a value of any type to the any (interface{}) type.
|
||||
//
|
||||
// This function performs an explicit type conversion to the any type, which can be
|
||||
|
||||
@@ -15,15 +15,17 @@
|
||||
|
||||
package io
|
||||
|
||||
import "github.com/IBM/fp-go/v2/function"
|
||||
|
||||
// ChainConsumer converts a Consumer into an IO operator that executes the consumer
|
||||
// as a side effect and returns an empty struct.
|
||||
//
|
||||
// This function bridges the gap between pure consumers (functions that consume values
|
||||
// without returning anything) and the IO monad. It takes a Consumer[A] and returns
|
||||
// an Operator that:
|
||||
// 1. Executes the source IO[A] to get a value
|
||||
// 2. Passes that value to the consumer for side effects
|
||||
// 3. Returns IO[struct{}] to maintain the monadic chain
|
||||
// 1. Executes the source IO[A] to get a value
|
||||
// 2. Passes that value to the consumer for side effects
|
||||
// 3. Returns IO[struct{}] to maintain the monadic chain
|
||||
//
|
||||
// The returned IO[struct{}] allows the operation to be composed with other IO operations
|
||||
// while discarding the consumed value. This is useful for operations like logging,
|
||||
@@ -68,11 +70,11 @@ package io
|
||||
// io.Map(func(struct{}) int { return len(values) }),
|
||||
// )
|
||||
// count := pipeline() // Returns 1, values contains [100]
|
||||
func ChainConsumer[A any](c Consumer[A]) Operator[A, struct{}] {
|
||||
return Chain(FromConsumerK(c))
|
||||
func ChainConsumer[A any](c Consumer[A]) Operator[A, Void] {
|
||||
return Chain(FromConsumer(c))
|
||||
}
|
||||
|
||||
// FromConsumerK converts a Consumer into a Kleisli arrow that wraps the consumer
|
||||
// FromConsumer converts a Consumer into a Kleisli arrow that wraps the consumer
|
||||
// in an IO context.
|
||||
//
|
||||
// This function lifts a Consumer[A] (a function that consumes a value and performs
|
||||
@@ -100,7 +102,7 @@ func ChainConsumer[A any](c Consumer[A]) Operator[A, struct{}] {
|
||||
// }
|
||||
//
|
||||
// // Convert to Kleisli arrow
|
||||
// logKleisli := io.FromConsumerK(logger)
|
||||
// logKleisli := io.FromConsumer(logger)
|
||||
//
|
||||
// // Use with Chain
|
||||
// result := F.Pipe2(
|
||||
@@ -117,11 +119,11 @@ func ChainConsumer[A any](c Consumer[A]) Operator[A, struct{}] {
|
||||
// io.Map(func(struct{}) int { return 1 }),
|
||||
// )
|
||||
// }
|
||||
func FromConsumerK[A any](c Consumer[A]) Kleisli[A, struct{}] {
|
||||
return func(a A) IO[struct{}] {
|
||||
return func() struct{} {
|
||||
func FromConsumer[A any](c Consumer[A]) Kleisli[A, Void] {
|
||||
return func(a A) IO[Void] {
|
||||
return func() Void {
|
||||
c(a)
|
||||
return struct{}{}
|
||||
return function.VOID
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
70
v2/io/run.go
Normal file
70
v2/io/run.go
Normal file
@@ -0,0 +1,70 @@
|
||||
// Copyright (c) 2023 - 2025 IBM Corp.
|
||||
// All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package io
|
||||
|
||||
// Run executes an IO computation and returns its result.
|
||||
//
|
||||
// This function is the primary way to execute IO computations. It takes an IO[A]
|
||||
// (a lazy computation) and immediately evaluates it, returning the computed value.
|
||||
//
|
||||
// Run is the bridge between the pure functional world (where computations are
|
||||
// described but not executed) and the imperative world (where side effects occur).
|
||||
// It should typically be called at the edges of your application, such as in main()
|
||||
// or in test code.
|
||||
//
|
||||
// Parameters:
|
||||
// - fa: The IO computation to execute
|
||||
//
|
||||
// Returns:
|
||||
// - The result of executing the IO computation
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// // Create a lazy computation
|
||||
// greeting := io.Of("Hello, World!")
|
||||
//
|
||||
// // Execute it and get the result
|
||||
// result := io.Run(greeting) // result == "Hello, World!"
|
||||
//
|
||||
// Example with side effects:
|
||||
//
|
||||
// // Create a computation that prints and returns a value
|
||||
// computation := func() string {
|
||||
// fmt.Println("Computing...")
|
||||
// return "Done"
|
||||
// }
|
||||
//
|
||||
// // Nothing is printed yet
|
||||
// io := io.MakeIO(computation)
|
||||
//
|
||||
// // Now the computation runs and "Computing..." is printed
|
||||
// result := io.Run(io) // result == "Done"
|
||||
//
|
||||
// Example with composition:
|
||||
//
|
||||
// result := io.Run(
|
||||
// pipe.Pipe2(
|
||||
// io.Of(5),
|
||||
// io.Map(func(x int) int { return x * 2 }),
|
||||
// io.Map(func(x int) int { return x + 1 }),
|
||||
// ),
|
||||
// ) // result == 11
|
||||
//
|
||||
// Note: Run should be used sparingly in application code. Prefer composing
|
||||
// IO computations and only calling Run at the application boundaries.
|
||||
func Run[A any](fa IO[A]) A {
|
||||
return fa()
|
||||
}
|
||||
228
v2/io/run_test.go
Normal file
228
v2/io/run_test.go
Normal file
@@ -0,0 +1,228 @@
|
||||
// Copyright (c) 2023 - 2025 IBM Corp.
|
||||
// All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package io
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
F "github.com/IBM/fp-go/v2/function"
|
||||
N "github.com/IBM/fp-go/v2/number"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestRun_BasicValue tests that Run executes a simple IO computation
|
||||
func TestRun_BasicValue(t *testing.T) {
|
||||
io := Of(42)
|
||||
result := Run(io)
|
||||
assert.Equal(t, 42, result)
|
||||
}
|
||||
|
||||
// TestRun_String tests Run with string values
|
||||
func TestRun_String(t *testing.T) {
|
||||
io := Of("Hello, World!")
|
||||
result := Run(io)
|
||||
assert.Equal(t, "Hello, World!", result)
|
||||
}
|
||||
|
||||
// TestRun_WithMap tests Run with a mapped computation
|
||||
func TestRun_WithMap(t *testing.T) {
|
||||
io := F.Pipe1(
|
||||
Of(5),
|
||||
Map(N.Mul(2)),
|
||||
)
|
||||
result := Run(io)
|
||||
assert.Equal(t, 10, result)
|
||||
}
|
||||
|
||||
// TestRun_WithChain tests Run with chained computations
|
||||
func TestRun_WithChain(t *testing.T) {
|
||||
io := F.Pipe1(
|
||||
Of(3),
|
||||
Chain(func(x int) IO[int] {
|
||||
return Of(x * x)
|
||||
}),
|
||||
)
|
||||
result := Run(io)
|
||||
assert.Equal(t, 9, result)
|
||||
}
|
||||
|
||||
// TestRun_ComposedOperations tests Run with multiple composed operations
|
||||
func TestRun_ComposedOperations(t *testing.T) {
|
||||
io := F.Pipe3(
|
||||
Of(5),
|
||||
Map(N.Mul(2)), // 10
|
||||
Map(N.Add(3)), // 13
|
||||
Map(N.Sub(1)), // 12
|
||||
)
|
||||
result := Run(io)
|
||||
assert.Equal(t, 12, result)
|
||||
}
|
||||
|
||||
// TestRun_WithSideEffect tests that Run executes side effects
|
||||
func TestRun_WithSideEffect(t *testing.T) {
|
||||
counter := 0
|
||||
io := func() int {
|
||||
counter++
|
||||
return counter
|
||||
}
|
||||
|
||||
// First execution
|
||||
result1 := Run(io)
|
||||
assert.Equal(t, 1, result1)
|
||||
assert.Equal(t, 1, counter)
|
||||
|
||||
// Second execution (side effect happens again)
|
||||
result2 := Run(io)
|
||||
assert.Equal(t, 2, result2)
|
||||
assert.Equal(t, 2, counter)
|
||||
}
|
||||
|
||||
// TestRun_LazyEvaluation tests that IO is lazy until Run is called
|
||||
func TestRun_LazyEvaluation(t *testing.T) {
|
||||
executed := false
|
||||
io := func() bool {
|
||||
executed = true
|
||||
return true
|
||||
}
|
||||
|
||||
// IO created but not executed
|
||||
assert.False(t, executed, "IO should not execute until Run is called")
|
||||
|
||||
// Now execute
|
||||
result := Run(io)
|
||||
assert.True(t, executed, "IO should execute when Run is called")
|
||||
assert.True(t, result)
|
||||
}
|
||||
|
||||
// TestRun_WithFlatten tests Run with nested IO
|
||||
func TestRun_WithFlatten(t *testing.T) {
|
||||
nested := Of(Of(42))
|
||||
flattened := Flatten(nested)
|
||||
result := Run(flattened)
|
||||
assert.Equal(t, 42, result)
|
||||
}
|
||||
|
||||
// TestRun_WithAp tests Run with applicative operations
|
||||
func TestRun_WithAp(t *testing.T) {
|
||||
double := N.Mul(2)
|
||||
io := F.Pipe1(
|
||||
Of(double),
|
||||
Ap[int](Of(21)),
|
||||
)
|
||||
result := Run(io)
|
||||
assert.Equal(t, 42, result)
|
||||
}
|
||||
|
||||
// TestRun_DifferentTypes tests Run with various types
|
||||
func TestRun_DifferentTypes(t *testing.T) {
|
||||
// Test with bool
|
||||
boolIO := Of(true)
|
||||
assert.True(t, Run(boolIO))
|
||||
|
||||
// Test with float
|
||||
floatIO := Of(3.14)
|
||||
assert.Equal(t, 3.14, Run(floatIO))
|
||||
|
||||
// Test with slice
|
||||
sliceIO := Of([]int{1, 2, 3})
|
||||
assert.Equal(t, []int{1, 2, 3}, Run(sliceIO))
|
||||
|
||||
// Test with struct
|
||||
type Person struct {
|
||||
Name string
|
||||
Age int
|
||||
}
|
||||
personIO := Of(Person{Name: "Alice", Age: 30})
|
||||
assert.Equal(t, Person{Name: "Alice", Age: 30}, Run(personIO))
|
||||
}
|
||||
|
||||
// TestRun_WithApFirst tests Run with ApFirst combinator
|
||||
func TestRun_WithApFirst(t *testing.T) {
|
||||
io := F.Pipe1(
|
||||
Of("first"),
|
||||
ApFirst[string](Of("second")),
|
||||
)
|
||||
result := Run(io)
|
||||
assert.Equal(t, "first", result)
|
||||
}
|
||||
|
||||
// TestRun_WithApSecond tests Run with ApSecond combinator
|
||||
func TestRun_WithApSecond(t *testing.T) {
|
||||
io := F.Pipe1(
|
||||
Of("first"),
|
||||
ApSecond[string](Of("second")),
|
||||
)
|
||||
result := Run(io)
|
||||
assert.Equal(t, "second", result)
|
||||
}
|
||||
|
||||
// TestRun_MultipleExecutions tests that Run can be called multiple times
|
||||
func TestRun_MultipleExecutions(t *testing.T) {
|
||||
io := Of(100)
|
||||
|
||||
// Execute multiple times
|
||||
result1 := Run(io)
|
||||
result2 := Run(io)
|
||||
result3 := Run(io)
|
||||
|
||||
assert.Equal(t, 100, result1)
|
||||
assert.Equal(t, 100, result2)
|
||||
assert.Equal(t, 100, result3)
|
||||
}
|
||||
|
||||
// TestRun_WithChainedSideEffects tests Run with multiple side effects
|
||||
func TestRun_WithChainedSideEffects(t *testing.T) {
|
||||
log := []string{}
|
||||
|
||||
io := F.Pipe2(
|
||||
func() string {
|
||||
log = append(log, "step1")
|
||||
return "a"
|
||||
},
|
||||
Chain(func(s string) IO[string] {
|
||||
return func() string {
|
||||
log = append(log, "step2")
|
||||
return s + "b"
|
||||
}
|
||||
}),
|
||||
Chain(func(s string) IO[string] {
|
||||
return func() string {
|
||||
log = append(log, "step3")
|
||||
return s + "c"
|
||||
}
|
||||
}),
|
||||
)
|
||||
|
||||
result := Run(io)
|
||||
assert.Equal(t, "abc", result)
|
||||
assert.Equal(t, []string{"step1", "step2", "step3"}, log)
|
||||
}
|
||||
|
||||
// TestRun_ZeroValue tests Run with zero values
|
||||
func TestRun_ZeroValue(t *testing.T) {
|
||||
// Test with zero int
|
||||
intIO := Of(0)
|
||||
assert.Equal(t, 0, Run(intIO))
|
||||
|
||||
// Test with empty string
|
||||
strIO := Of("")
|
||||
assert.Equal(t, "", Run(strIO))
|
||||
|
||||
// Test with nil slice
|
||||
var nilSlice []int
|
||||
sliceIO := Of(nilSlice)
|
||||
assert.Nil(t, Run(sliceIO))
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"iter"
|
||||
|
||||
"github.com/IBM/fp-go/v2/consumer"
|
||||
"github.com/IBM/fp-go/v2/function"
|
||||
M "github.com/IBM/fp-go/v2/monoid"
|
||||
"github.com/IBM/fp-go/v2/pair"
|
||||
"github.com/IBM/fp-go/v2/predicate"
|
||||
@@ -48,4 +49,6 @@ type (
|
||||
// Predicate represents a function that tests a value of type A and returns a boolean.
|
||||
// It's commonly used for filtering and conditional operations.
|
||||
Predicate[A any] = predicate.Predicate[A]
|
||||
|
||||
Void = function.Void
|
||||
)
|
||||
|
||||
@@ -23,9 +23,9 @@ import "github.com/IBM/fp-go/v2/io"
|
||||
// This function bridges the gap between pure consumers (functions that consume values
|
||||
// without returning anything) and the IOEither monad. It takes a Consumer[A] and returns
|
||||
// an Operator that:
|
||||
// 1. If the IOEither is Right, executes the consumer with the value as a side effect
|
||||
// 2. If the IOEither is Left, propagates the error without calling the consumer
|
||||
// 3. Returns IOEither[E, struct{}] to maintain the monadic chain
|
||||
// 1. If the IOEither is Right, executes the consumer with the value as a side effect
|
||||
// 2. If the IOEither is Left, propagates the error without calling the consumer
|
||||
// 3. Returns IOEither[E, struct{}] to maintain the monadic chain
|
||||
//
|
||||
// The consumer is only executed for successful (Right) values. Errors (Left values) are
|
||||
// propagated unchanged. This is useful for operations like logging successful results,
|
||||
@@ -79,12 +79,13 @@ import "github.com/IBM/fp-go/v2/io"
|
||||
// ioeither.Map[error](func(struct{}) int { return len(successfulValues) }),
|
||||
// )
|
||||
// count := pipeline() // Returns Right(1), successfulValues contains [100]
|
||||
//
|
||||
//go:inline
|
||||
func ChainConsumer[E, A any](c Consumer[A]) Operator[E, A, struct{}] {
|
||||
return ChainIOK[E](io.FromConsumerK(c))
|
||||
return ChainIOK[E](io.FromConsumer(c))
|
||||
}
|
||||
|
||||
//go:inline
|
||||
func ChainFirstConsumer[E, A any](c Consumer[A]) Operator[E, A, A] {
|
||||
return ChainFirstIOK[E](io.FromConsumerK(c))
|
||||
return ChainFirstIOK[E](io.FromConsumer(c))
|
||||
}
|
||||
|
||||
@@ -4,10 +4,10 @@ import "github.com/IBM/fp-go/v2/io"
|
||||
|
||||
//go:inline
|
||||
func ChainConsumer[A any](c Consumer[A]) Operator[A, struct{}] {
|
||||
return ChainIOK(io.FromConsumerK(c))
|
||||
return ChainIOK(io.FromConsumer(c))
|
||||
}
|
||||
|
||||
//go:inline
|
||||
func ChainFirstConsumer[A any](c Consumer[A]) Operator[A, A] {
|
||||
return ChainFirstIOK(io.FromConsumerK(c))
|
||||
return ChainFirstIOK(io.FromConsumer(c))
|
||||
}
|
||||
|
||||
201
v2/ioref/doc.go
Normal file
201
v2/ioref/doc.go
Normal file
@@ -0,0 +1,201 @@
|
||||
// Copyright (c) 2023 - 2025 IBM Corp.
|
||||
// All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Package ioref provides mutable references in the IO monad.
|
||||
//
|
||||
// # Overview
|
||||
//
|
||||
// IORef represents a mutable reference that can be read and written within IO computations.
|
||||
// It provides thread-safe access to shared mutable state using read-write locks, making it
|
||||
// safe to use across multiple goroutines.
|
||||
//
|
||||
// This package is inspired by Haskell's Data.IORef module and provides a functional approach
|
||||
// to managing mutable state with explicit IO effects.
|
||||
//
|
||||
// # Core Operations
|
||||
//
|
||||
// The package provides four main operations:
|
||||
//
|
||||
// - MakeIORef: Creates a new IORef with an initial value
|
||||
// - Read: Atomically reads the current value from an IORef
|
||||
// - Write: Atomically writes a new value to an IORef
|
||||
// - Modify: Atomically modifies the value using a transformation function
|
||||
// - ModifyWithResult: Atomically modifies the value and returns a computed result
|
||||
//
|
||||
// # Thread Safety
|
||||
//
|
||||
// All operations on IORef are thread-safe:
|
||||
//
|
||||
// - Read operations use read locks, allowing multiple concurrent readers
|
||||
// - Write and Modify operations use write locks, ensuring exclusive access
|
||||
// - The underlying sync.RWMutex ensures proper synchronization
|
||||
//
|
||||
// # Basic Usage
|
||||
//
|
||||
// Creating and using an IORef:
|
||||
//
|
||||
// import (
|
||||
// "github.com/IBM/fp-go/v2/ioref"
|
||||
// )
|
||||
//
|
||||
// // Create a new IORef
|
||||
// ref := ioref.MakeIORef(42)()
|
||||
//
|
||||
// // Read the current value
|
||||
// value := ioref.Read(ref)() // 42
|
||||
//
|
||||
// // Write a new value
|
||||
// ioref.Write(100)(ref)()
|
||||
//
|
||||
// // Read the updated value
|
||||
// newValue := ioref.Read(ref)() // 100
|
||||
//
|
||||
// # Modifying Values
|
||||
//
|
||||
// Use Modify to transform the value in place:
|
||||
//
|
||||
// ref := ioref.MakeIORef(10)()
|
||||
//
|
||||
// // Double the value
|
||||
// ioref.Modify(func(x int) int { return x * 2 })(ref)()
|
||||
//
|
||||
// // Chain multiple modifications
|
||||
// ioref.Modify(func(x int) int { return x + 5 })(ref)()
|
||||
// ioref.Modify(func(x int) int { return x * 3 })(ref)()
|
||||
//
|
||||
// result := ioref.Read(ref)() // (10 * 2 + 5) * 3 = 75
|
||||
//
|
||||
// # Atomic Modify with Result
|
||||
//
|
||||
// Use ModifyWithResult when you need to both transform the value and compute a result
|
||||
// from the old value in a single atomic operation:
|
||||
//
|
||||
// ref := ioref.MakeIORef(42)()
|
||||
//
|
||||
// // Increment and return the old value
|
||||
// oldValue := ioref.ModifyWithResult(func(x int) pair.Pair[int, int] {
|
||||
// return pair.MakePair(x+1, x)
|
||||
// })(ref)()
|
||||
//
|
||||
// // oldValue is 42, ref now contains 43
|
||||
//
|
||||
// This is particularly useful for implementing counters, swapping values, or any operation
|
||||
// where you need to know the previous state.
|
||||
//
|
||||
// # Concurrent Usage
|
||||
//
|
||||
// IORef is safe to use across multiple goroutines:
|
||||
//
|
||||
// ref := ioref.MakeIORef(0)()
|
||||
//
|
||||
// // Multiple goroutines can safely modify the same IORef
|
||||
// var wg sync.WaitGroup
|
||||
// for i := 0; i < 100; i++ {
|
||||
// wg.Add(1)
|
||||
// go func() {
|
||||
// defer wg.Done()
|
||||
// ioref.Modify(func(x int) int { return x + 1 })(ref)()
|
||||
// }()
|
||||
// }
|
||||
// wg.Wait()
|
||||
//
|
||||
// result := ioref.Read(ref)() // 100
|
||||
//
|
||||
// # Comparison with Haskell's IORef
|
||||
//
|
||||
// This implementation provides the following Haskell IORef operations:
|
||||
//
|
||||
// - newIORef → MakeIORef
|
||||
// - readIORef → Read
|
||||
// - writeIORef → Write
|
||||
// - modifyIORef → Modify
|
||||
// - atomicModifyIORef → ModifyWithResult
|
||||
//
|
||||
// The main difference is that Go's implementation uses explicit locking (sync.RWMutex)
|
||||
// rather than relying on the runtime's STM (Software Transactional Memory) as Haskell does.
|
||||
//
|
||||
// # Performance Considerations
|
||||
//
|
||||
// IORef operations are highly optimized:
|
||||
//
|
||||
// - Read operations are very fast (~5ns) and allow concurrent access
|
||||
// - Write and Modify operations are slightly slower (~7-8ns) due to exclusive locking
|
||||
// - ModifyWithResult is marginally slower (~9ns) due to tuple creation
|
||||
// - All operations have zero allocations in the common case
|
||||
//
|
||||
// For high-contention scenarios, consider:
|
||||
//
|
||||
// - Using multiple IORefs to reduce lock contention
|
||||
// - Batching modifications when possible
|
||||
// - Using Read locks for read-heavy workloads
|
||||
//
|
||||
// # Examples
|
||||
//
|
||||
// Counter with atomic increment:
|
||||
//
|
||||
// counter := ioref.MakeIORef(0)()
|
||||
//
|
||||
// increment := func() int {
|
||||
// return ioref.ModifyWithResult(func(x int) pair.Pair[int, int] {
|
||||
// return pair.MakePair(x+1, x+1)
|
||||
// })(counter)()
|
||||
// }
|
||||
//
|
||||
// id1 := increment() // 1
|
||||
// id2 := increment() // 2
|
||||
// id3 := increment() // 3
|
||||
//
|
||||
// Shared configuration:
|
||||
//
|
||||
// type Config struct {
|
||||
// MaxRetries int
|
||||
// Timeout time.Duration
|
||||
// }
|
||||
//
|
||||
// configRef := ioref.MakeIORef(Config{
|
||||
// MaxRetries: 3,
|
||||
// Timeout: 5 * time.Second,
|
||||
// })()
|
||||
//
|
||||
// // Update configuration
|
||||
// ioref.Modify(func(c Config) Config {
|
||||
// c.MaxRetries = 5
|
||||
// return c
|
||||
// })(configRef)()
|
||||
//
|
||||
// // Read configuration
|
||||
// config := ioref.Read(configRef)()
|
||||
//
|
||||
// Stack implementation:
|
||||
//
|
||||
// type Stack []int
|
||||
//
|
||||
// stackRef := ioref.MakeIORef(Stack{})()
|
||||
//
|
||||
// push := func(value int) {
|
||||
// ioref.Modify(func(s Stack) Stack {
|
||||
// return append(s, value)
|
||||
// })(stackRef)()
|
||||
// }
|
||||
//
|
||||
// pop := func() option.Option[int] {
|
||||
// return ioref.ModifyWithResult(func(s Stack) pair.Pair[Stack, option.Option[int]] {
|
||||
// if len(s) == 0 {
|
||||
// return pair.MakePair(s, option.None[int]())
|
||||
// }
|
||||
// return pair.MakePair(s[:len(s)-1], option.Some(s[len(s)-1]))
|
||||
// })(stackRef)()
|
||||
// }
|
||||
package ioref
|
||||
180
v2/ioref/ioref.go
Normal file
180
v2/ioref/ioref.go
Normal file
@@ -0,0 +1,180 @@
|
||||
// Copyright (c) 2023 - 2025 IBM Corp.
|
||||
// All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package ioref
|
||||
|
||||
import (
|
||||
"github.com/IBM/fp-go/v2/io"
|
||||
"github.com/IBM/fp-go/v2/pair"
|
||||
)
|
||||
|
||||
// MakeIORef creates a new IORef containing the given initial value.
|
||||
//
|
||||
// This function returns an IO computation that, when executed, creates a new
|
||||
// mutable reference initialized with the provided value. The reference is
|
||||
// thread-safe and can be safely shared across goroutines.
|
||||
//
|
||||
// Parameters:
|
||||
// - a: The initial value to store in the IORef
|
||||
//
|
||||
// Returns:
|
||||
// - An IO computation that produces a new IORef[A]
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// // Create a new IORef with initial value 42
|
||||
// refIO := ioref.MakeIORef(42)
|
||||
// ref := refIO() // Execute the IO to get the IORef
|
||||
//
|
||||
// // Create an IORef with a string
|
||||
// strRefIO := ioref.MakeIORef("hello")
|
||||
// strRef := strRefIO()
|
||||
//
|
||||
//go:inline
|
||||
func MakeIORef[A any](a A) IO[IORef[A]] {
|
||||
return func() IORef[A] {
|
||||
return &ioRef[A]{a: a}
|
||||
}
|
||||
}
|
||||
|
||||
//go:inline
|
||||
func Write[A any](a A) io.Kleisli[IORef[A], A] {
|
||||
return func(ref IORef[A]) IO[A] {
|
||||
return func() A {
|
||||
ref.mu.Lock()
|
||||
defer ref.mu.Unlock()
|
||||
|
||||
ref.a = a
|
||||
return a
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Read atomically reads the current value from an IORef.
|
||||
//
|
||||
// This function returns an IO computation that reads the value stored in the
|
||||
// IORef. The read operation is thread-safe, using a read lock that allows
|
||||
// multiple concurrent readers but excludes writers.
|
||||
//
|
||||
// Parameters:
|
||||
// - ref: The IORef to read from
|
||||
//
|
||||
// Returns:
|
||||
// - An IO computation that produces the current value of type A
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// ref := ioref.MakeIORef(42)()
|
||||
//
|
||||
// // Read the current value
|
||||
// value := ioref.Read(ref)() // 42
|
||||
//
|
||||
// // Use in a pipeline
|
||||
// result := pipe.Pipe2(
|
||||
// ref,
|
||||
// ioref.Read[int],
|
||||
// io.Map(func(x int) int { return x * 2 }),
|
||||
// )()
|
||||
//
|
||||
//go:inline
|
||||
func Read[A any](ref IORef[A]) IO[A] {
|
||||
return func() A {
|
||||
ref.mu.RLock()
|
||||
defer ref.mu.RUnlock()
|
||||
|
||||
return ref.a
|
||||
}
|
||||
}
|
||||
|
||||
// Modify atomically modifies the value in an IORef using the given function.
|
||||
//
|
||||
// This function returns a Kleisli arrow that takes an IORef and produces an IO
|
||||
// computation that applies the transformation function to the current value.
|
||||
// The modification is atomic and thread-safe, using a write lock to ensure
|
||||
// exclusive access during the read-modify-write cycle.
|
||||
//
|
||||
// Parameters:
|
||||
// - f: An endomorphism (function from A to A) that transforms the current value
|
||||
//
|
||||
// Returns:
|
||||
// - A Kleisli arrow from IORef[A] to IO[IORef[A]]
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// ref := ioref.MakeIORef(42)()
|
||||
//
|
||||
// // Double the value
|
||||
// ioref.Modify(func(x int) int { return x * 2 })(ref)()
|
||||
//
|
||||
// // Chain multiple modifications
|
||||
// pipe.Pipe2(
|
||||
// ref,
|
||||
// ioref.Modify(func(x int) int { return x + 10 }),
|
||||
// io.Chain(ioref.Modify(func(x int) int { return x * 2 })),
|
||||
// )()
|
||||
//
|
||||
//go:inline
|
||||
func Modify[A any](f Endomorphism[A]) io.Kleisli[IORef[A], A] {
|
||||
return func(ref IORef[A]) IO[A] {
|
||||
return func() A {
|
||||
ref.mu.Lock()
|
||||
defer ref.mu.Unlock()
|
||||
|
||||
ref.a = f(ref.a)
|
||||
return ref.a
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ModifyWithResult atomically modifies the value in an IORef and returns both
|
||||
// the new value and an additional result computed from the old value.
|
||||
//
|
||||
// This function is useful when you need to both transform the stored value and
|
||||
// compute some result based on the old value in a single atomic operation.
|
||||
// It's similar to Haskell's atomicModifyIORef.
|
||||
//
|
||||
// Parameters:
|
||||
// - f: A function that takes the old value and returns a Pair of (new value, result)
|
||||
//
|
||||
// Returns:
|
||||
// - A Kleisli arrow from IORef[A] to IO[B] that produces the result
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// ref := ioref.MakeIORef(42)()
|
||||
//
|
||||
// // Increment and return the old value
|
||||
// oldValue := ioref.ModifyWithResult(func(x int) pair.Pair[int, int] {
|
||||
// return pair.MakePair(x+1, x)
|
||||
// })(ref)() // Returns 42, ref now contains 43
|
||||
//
|
||||
// // Swap and return the old value
|
||||
// old := ioref.ModifyWithResult(func(x int) pair.Pair[int, int] {
|
||||
// return pair.MakePair(100, x)
|
||||
// })(ref)() // Returns 43, ref now contains 100
|
||||
//
|
||||
//go:inline
|
||||
func ModifyWithResult[A, B any](f func(A) Pair[A, B]) io.Kleisli[IORef[A], B] {
|
||||
return func(ref IORef[A]) IO[B] {
|
||||
return func() B {
|
||||
ref.mu.Lock()
|
||||
defer ref.mu.Unlock()
|
||||
|
||||
result := f(ref.a)
|
||||
ref.a = pair.Head(result)
|
||||
return pair.Tail(result)
|
||||
}
|
||||
}
|
||||
}
|
||||
74
v2/ioref/types.go
Normal file
74
v2/ioref/types.go
Normal file
@@ -0,0 +1,74 @@
|
||||
// Copyright (c) 2023 - 2025 IBM Corp.
|
||||
// All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Package ioref provides mutable references in the IO monad.
|
||||
//
|
||||
// IORef represents a mutable reference that can be read and written within IO computations.
|
||||
// It provides thread-safe access to shared mutable state using read-write locks.
|
||||
//
|
||||
// This is inspired by Haskell's Data.IORef module and provides a functional approach
|
||||
// to managing mutable state with explicit IO effects.
|
||||
//
|
||||
// Example usage:
|
||||
//
|
||||
// // Create a new IORef
|
||||
// ref := ioref.MakeIORef(42)()
|
||||
//
|
||||
// // Read the current value
|
||||
// value := ioref.Read(ref)() // 42
|
||||
//
|
||||
// // Write a new value
|
||||
// ioref.Write(100)(ref)()
|
||||
//
|
||||
// // Modify the value
|
||||
// ioref.Modify(func(x int) int { return x * 2 })(ref)()
|
||||
//
|
||||
// // Read the modified value
|
||||
// newValue := ioref.Read(ref)() // 200
|
||||
package ioref
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/IBM/fp-go/v2/endomorphism"
|
||||
"github.com/IBM/fp-go/v2/io"
|
||||
"github.com/IBM/fp-go/v2/pair"
|
||||
)
|
||||
|
||||
type (
|
||||
// ioRef is the internal implementation of a mutable reference.
|
||||
// It uses a read-write mutex to ensure thread-safe access.
|
||||
ioRef[A any] struct {
|
||||
mu sync.RWMutex
|
||||
a A
|
||||
}
|
||||
|
||||
// IO represents a synchronous computation that may have side effects.
|
||||
// It's a function that takes no arguments and returns a value of type A.
|
||||
IO[A any] = io.IO[A]
|
||||
|
||||
// IORef represents a mutable reference to a value of type A.
|
||||
// Operations on IORef are thread-safe and performed within the IO monad.
|
||||
//
|
||||
// IORef provides a way to work with mutable state in a functional style,
|
||||
// where mutations are explicit and contained within IO computations.
|
||||
IORef[A any] = *ioRef[A]
|
||||
|
||||
// Endomorphism represents a function from A to A.
|
||||
// It's commonly used with Modify to transform the value in an IORef.
|
||||
Endomorphism[A any] = endomorphism.Endomorphism[A]
|
||||
|
||||
Pair[A, B any] = pair.Pair[A, B]
|
||||
)
|
||||
@@ -52,6 +52,7 @@ import (
|
||||
M "github.com/IBM/fp-go/v2/monoid"
|
||||
"github.com/IBM/fp-go/v2/option"
|
||||
"github.com/IBM/fp-go/v2/pair"
|
||||
"github.com/IBM/fp-go/v2/reader"
|
||||
)
|
||||
|
||||
// Of creates a sequence containing a single element.
|
||||
@@ -507,7 +508,7 @@ func MonadAp[B, A any](fab Seq[func(A) B], fa Seq[A]) Seq[B] {
|
||||
//
|
||||
//go:inline
|
||||
func Ap[B, A any](fa Seq[A]) Operator[func(A) B, B] {
|
||||
return F.Bind2nd(MonadAp[B, A], fa)
|
||||
return Chain(F.Bind1st(MonadMap[A, B], fa))
|
||||
}
|
||||
|
||||
// From creates a sequence from a variadic list of elements.
|
||||
@@ -708,9 +709,7 @@ func Fold[A any](m M.Monoid[A]) func(Seq[A]) A {
|
||||
//
|
||||
//go:inline
|
||||
func MonadFoldMap[A, B any](fa Seq[A], f func(A) B, m M.Monoid[B]) B {
|
||||
return MonadReduce(fa, func(b B, a A) B {
|
||||
return m.Concat(b, f(a))
|
||||
}, m.Empty())
|
||||
return MonadFold(MonadMap(fa, f), m)
|
||||
}
|
||||
|
||||
// FoldMap returns a function that maps and folds using a monoid.
|
||||
@@ -728,11 +727,10 @@ func MonadFoldMap[A, B any](fa Seq[A], f func(A) B, m M.Monoid[B]) B {
|
||||
//
|
||||
//go:inline
|
||||
func FoldMap[A, B any](m M.Monoid[B]) func(func(A) B) func(Seq[A]) B {
|
||||
return func(f func(A) B) func(Seq[A]) B {
|
||||
return func(as Seq[A]) B {
|
||||
return MonadFoldMap(as, f, m)
|
||||
}
|
||||
}
|
||||
return F.Pipe1(
|
||||
Map[A, B],
|
||||
reader.Map[func(A) B](reader.Map[Seq[A]](Fold(m))),
|
||||
)
|
||||
}
|
||||
|
||||
// MonadFoldMapWithIndex maps each element with its index to a monoid value and combines them.
|
||||
@@ -903,11 +901,51 @@ func Zip[A, B any](fb Seq[B]) func(Seq[A]) Seq2[A, B] {
|
||||
return F.Bind2nd(MonadZip[A, B], fb)
|
||||
}
|
||||
|
||||
// MonadMapToArray maps each element in a sequence using a function and collects the results into an array.
|
||||
// This is a convenience function that combines Map and collection into a single operation.
|
||||
//
|
||||
// Type Parameters:
|
||||
// - A: The type of elements in the input sequence
|
||||
// - B: The type of elements in the output array
|
||||
//
|
||||
// Parameters:
|
||||
// - fa: The input sequence to map
|
||||
// - f: The mapping function to apply to each element
|
||||
//
|
||||
// Returns:
|
||||
// - A slice containing all mapped elements
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// seq := From(1, 2, 3)
|
||||
// result := MonadMapToArray(seq, N.Mul(2))
|
||||
// // returns: []int{2, 4, 6}
|
||||
//
|
||||
//go:inline
|
||||
func MonadMapToArray[A, B any](fa Seq[A], f func(A) B) []B {
|
||||
return G.MonadMapToArray[Seq[A], []B](fa, f)
|
||||
}
|
||||
|
||||
// MapToArray returns a function that maps elements and collects them into an array.
|
||||
// This is the curried version of MonadMapToArray.
|
||||
//
|
||||
// Type Parameters:
|
||||
// - A: The type of elements in the input sequence
|
||||
// - B: The type of elements in the output array
|
||||
//
|
||||
// Parameters:
|
||||
// - f: The mapping function to apply to each element
|
||||
//
|
||||
// Returns:
|
||||
// - A function that takes a sequence and returns a slice of mapped elements
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// double := MapToArray(N.Mul(2))
|
||||
// seq := From(1, 2, 3)
|
||||
// result := double(seq)
|
||||
// // returns: []int{2, 4, 6}
|
||||
//
|
||||
//go:inline
|
||||
func MapToArray[A, B any](f func(A) B) func(Seq[A]) []B {
|
||||
return G.MapToArray[Seq[A], []B](f)
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"fmt"
|
||||
"maps"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -295,13 +296,13 @@ func TestMakeBy(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestMakeByZero(t *testing.T) {
|
||||
seq := MakeBy(0, func(i int) int { return i })
|
||||
seq := MakeBy(0, F.Identity)
|
||||
result := toSlice(seq)
|
||||
assert.Empty(t, result)
|
||||
}
|
||||
|
||||
func TestMakeByNegative(t *testing.T) {
|
||||
seq := MakeBy(-5, func(i int) int { return i })
|
||||
seq := MakeBy(-5, F.Identity)
|
||||
result := toSlice(seq)
|
||||
assert.Empty(t, result)
|
||||
}
|
||||
@@ -375,17 +376,13 @@ func TestFold(t *testing.T) {
|
||||
|
||||
func TestMonadFoldMap(t *testing.T) {
|
||||
seq := From(1, 2, 3)
|
||||
result := MonadFoldMap(seq, func(x int) string {
|
||||
return fmt.Sprintf("%d", x)
|
||||
}, S.Monoid)
|
||||
result := MonadFoldMap(seq, strconv.Itoa, S.Monoid)
|
||||
assert.Equal(t, "123", result)
|
||||
}
|
||||
|
||||
func TestFoldMap(t *testing.T) {
|
||||
seq := From(1, 2, 3)
|
||||
folder := FoldMap[int](S.Monoid)(func(x int) string {
|
||||
return fmt.Sprintf("%d", x)
|
||||
})
|
||||
folder := FoldMap[int](S.Monoid)(strconv.Itoa)
|
||||
result := folder(seq)
|
||||
assert.Equal(t, "123", result)
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
I "iter"
|
||||
|
||||
"github.com/IBM/fp-go/v2/endomorphism"
|
||||
"github.com/IBM/fp-go/v2/function"
|
||||
"github.com/IBM/fp-go/v2/iterator/stateless"
|
||||
"github.com/IBM/fp-go/v2/optics/lens"
|
||||
"github.com/IBM/fp-go/v2/optics/prism"
|
||||
@@ -29,41 +30,155 @@ import (
|
||||
|
||||
type (
|
||||
// Option represents an optional value, either Some(value) or None.
|
||||
// It is used to handle computations that may or may not return a value,
|
||||
// providing a type-safe alternative to nil pointers or sentinel values.
|
||||
//
|
||||
// Type Parameters:
|
||||
// - A: The type of the value that may be present
|
||||
Option[A any] = option.Option[A]
|
||||
|
||||
// Seq is a single-value iterator sequence from Go 1.23+.
|
||||
// It represents a lazy sequence of values that can be iterated using range.
|
||||
// Operations on Seq are lazy and only execute when the sequence is consumed.
|
||||
//
|
||||
// Type Parameters:
|
||||
// - T: The type of elements in the sequence
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// seq := From(1, 2, 3)
|
||||
// for v := range seq {
|
||||
// fmt.Println(v)
|
||||
// }
|
||||
Seq[T any] = I.Seq[T]
|
||||
|
||||
// Seq2 is a key-value iterator sequence from Go 1.23+.
|
||||
// It represents a lazy sequence of key-value pairs that can be iterated using range.
|
||||
// This is useful for working with map-like data structures in a functional way.
|
||||
//
|
||||
// Type Parameters:
|
||||
// - K: The type of keys in the sequence
|
||||
// - V: The type of values in the sequence
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// seq := MonadZip(From(1, 2, 3), From("a", "b", "c"))
|
||||
// for k, v := range seq {
|
||||
// fmt.Printf("%d: %s\n", k, v)
|
||||
// }
|
||||
Seq2[K, V any] = I.Seq2[K, V]
|
||||
|
||||
// Iterator is a stateless iterator type.
|
||||
// It provides a functional interface for iterating over collections
|
||||
// without maintaining internal state.
|
||||
//
|
||||
// Type Parameters:
|
||||
// - T: The type of elements produced by the iterator
|
||||
Iterator[T any] = stateless.Iterator[T]
|
||||
|
||||
// Predicate is a function that tests a value and returns a boolean.
|
||||
// Predicates are commonly used for filtering operations.
|
||||
//
|
||||
// Type Parameters:
|
||||
// - T: The type of value being tested
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// isEven := func(x int) bool { return x%2 == 0 }
|
||||
// filtered := Filter(isEven)(From(1, 2, 3, 4))
|
||||
Predicate[T any] = predicate.Predicate[T]
|
||||
|
||||
// Kleisli represents a function that takes a value and returns a sequence.
|
||||
// This is the monadic bind operation for sequences.
|
||||
// This is the monadic bind operation for sequences, also known as flatMap.
|
||||
// It's used to chain operations that produce sequences.
|
||||
//
|
||||
// Type Parameters:
|
||||
// - A: The input type
|
||||
// - B: The element type of the output sequence
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// duplicate := func(x int) Seq[int] { return From(x, x) }
|
||||
// result := Chain(duplicate)(From(1, 2, 3))
|
||||
// // yields: 1, 1, 2, 2, 3, 3
|
||||
Kleisli[A, B any] = func(A) Seq[B]
|
||||
|
||||
// Kleisli2 represents a function that takes a value and returns a key-value sequence.
|
||||
// This is the monadic bind operation for key-value sequences.
|
||||
//
|
||||
// Type Parameters:
|
||||
// - K: The key type in the output sequence
|
||||
// - A: The input type
|
||||
// - B: The value type in the output sequence
|
||||
Kleisli2[K, A, B any] = func(A) Seq2[K, B]
|
||||
|
||||
// Operator represents a transformation from one sequence to another.
|
||||
// It's a function that takes a Seq[A] and returns a Seq[B].
|
||||
// Operators are the building blocks for composing sequence transformations.
|
||||
//
|
||||
// Type Parameters:
|
||||
// - A: The element type of the input sequence
|
||||
// - B: The element type of the output sequence
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// double := Map(func(x int) int { return x * 2 })
|
||||
// result := double(From(1, 2, 3))
|
||||
// // yields: 2, 4, 6
|
||||
Operator[A, B any] = Kleisli[Seq[A], B]
|
||||
|
||||
// Operator2 represents a transformation from one key-value sequence to another.
|
||||
// It's a function that takes a Seq2[K, A] and returns a Seq2[K, B].
|
||||
//
|
||||
// Type Parameters:
|
||||
// - K: The key type (preserved in the transformation)
|
||||
// - A: The value type of the input sequence
|
||||
// - B: The value type of the output sequence
|
||||
Operator2[K, A, B any] = Kleisli2[K, Seq2[K, A], B]
|
||||
|
||||
// Lens is an optic that focuses on a field within a structure.
|
||||
// It provides a functional way to get and set values in immutable data structures.
|
||||
//
|
||||
// Type Parameters:
|
||||
// - S: The structure type
|
||||
// - A: The field type being focused on
|
||||
Lens[S, A any] = lens.Lens[S, A]
|
||||
|
||||
// Prism is an optic that focuses on a case of a sum type.
|
||||
// It provides a functional way to work with variant types (like Result or Option).
|
||||
//
|
||||
// Type Parameters:
|
||||
// - S: The sum type
|
||||
// - A: The case type being focused on
|
||||
Prism[S, A any] = prism.Prism[S, A]
|
||||
|
||||
// Endomorphism is a function from a type to itself.
|
||||
// It represents transformations that preserve the type.
|
||||
//
|
||||
// Type Parameters:
|
||||
// - A: The type being transformed
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// increment := func(x int) int { return x + 1 }
|
||||
// result := increment(5) // returns 6
|
||||
Endomorphism[A any] = endomorphism.Endomorphism[A]
|
||||
|
||||
// Pair represents a tuple of two values.
|
||||
// It's used to group two related values together.
|
||||
//
|
||||
// Type Parameters:
|
||||
// - A: The type of the first element
|
||||
// - B: The type of the second element
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// p := pair.MakePair(1, "hello")
|
||||
// first := pair.Head(p) // returns 1
|
||||
// second := pair.Tail(p) // returns "hello"
|
||||
Pair[A, B any] = pair.Pair[A, B]
|
||||
|
||||
// Void represents the absence of a value, similar to void in other languages.
|
||||
// It's used in functions that perform side effects but don't return meaningful values.
|
||||
Void = function.Void
|
||||
)
|
||||
|
||||
@@ -15,7 +15,10 @@
|
||||
|
||||
package iter
|
||||
|
||||
import F "github.com/IBM/fp-go/v2/function"
|
||||
import (
|
||||
"github.com/IBM/fp-go/v2/function"
|
||||
F "github.com/IBM/fp-go/v2/function"
|
||||
)
|
||||
|
||||
// Uniq returns an operator that filters a sequence to contain only unique elements,
|
||||
// where uniqueness is determined by a key extraction function.
|
||||
@@ -92,11 +95,11 @@ import F "github.com/IBM/fp-go/v2/function"
|
||||
func Uniq[A any, K comparable](f func(A) K) Operator[A, A] {
|
||||
return func(s Seq[A]) Seq[A] {
|
||||
return func(yield func(A) bool) {
|
||||
items := make(map[K]struct{})
|
||||
items := make(map[K]Void)
|
||||
for a := range s {
|
||||
k := f(a)
|
||||
if _, ok := items[k]; !ok {
|
||||
items[k] = struct{}{}
|
||||
items[k] = function.VOID
|
||||
if !yield(a) {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -83,6 +83,8 @@ func Monoid[A any]() func(S.Semigroup[A]) M.Monoid[Option[A]] {
|
||||
// intMonoid := monoid.MakeMonoid(func(a, b int) int { return a + b }, 0)
|
||||
// optMonoid := AlternativeMonoid(intMonoid)
|
||||
// result := optMonoid.Concat(Some(2), Some(3)) // Some(5)
|
||||
//
|
||||
//go:inline
|
||||
func AlternativeMonoid[A any](m M.Monoid[A]) M.Monoid[Option[A]] {
|
||||
return M.AlternativeMonoid(
|
||||
Of[A],
|
||||
@@ -103,9 +105,81 @@ func AlternativeMonoid[A any](m M.Monoid[A]) M.Monoid[Option[A]] {
|
||||
// optMonoid.Concat(Some(2), Some(3)) // Some(2) - returns first Some
|
||||
// optMonoid.Concat(None[int](), Some(3)) // Some(3)
|
||||
// optMonoid.Empty() // None
|
||||
//
|
||||
//go:inline
|
||||
func AltMonoid[A any]() M.Monoid[Option[A]] {
|
||||
return M.AltMonoid(
|
||||
None[A],
|
||||
MonadAlt[A],
|
||||
)
|
||||
}
|
||||
|
||||
// takeFirst is a helper function that returns the first Some value, or the second if the first is None.
|
||||
func takeFirst[A any](l, r Option[A]) Option[A] {
|
||||
if IsSome(l) {
|
||||
return l
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// FirstMonoid creates a Monoid for Option[A] that returns the first Some value.
|
||||
// This monoid prefers the left operand when it is Some, otherwise returns the right operand.
|
||||
// The empty value is None.
|
||||
//
|
||||
// This is equivalent to AltMonoid but implemented more directly.
|
||||
//
|
||||
// Truth table:
|
||||
//
|
||||
// | x | y | concat(x, y) |
|
||||
// | ------- | ------- | ------------ |
|
||||
// | none | none | none |
|
||||
// | some(a) | none | some(a) |
|
||||
// | none | some(b) | some(b) |
|
||||
// | some(a) | some(b) | some(a) |
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// optMonoid := FirstMonoid[int]()
|
||||
// optMonoid.Concat(Some(2), Some(3)) // Some(2) - returns first Some
|
||||
// optMonoid.Concat(None[int](), Some(3)) // Some(3)
|
||||
// optMonoid.Concat(Some(2), None[int]()) // Some(2)
|
||||
// optMonoid.Empty() // None
|
||||
//
|
||||
//go:inline
|
||||
func FirstMonoid[A any]() M.Monoid[Option[A]] {
|
||||
return M.MakeMonoid(takeFirst[A], None[A]())
|
||||
}
|
||||
|
||||
// takeLast is a helper function that returns the last Some value, or the first if the last is None.
|
||||
func takeLast[A any](l, r Option[A]) Option[A] {
|
||||
if IsSome(r) {
|
||||
return r
|
||||
}
|
||||
return l
|
||||
}
|
||||
|
||||
// LastMonoid creates a Monoid for Option[A] that returns the last Some value.
|
||||
// This monoid prefers the right operand when it is Some, otherwise returns the left operand.
|
||||
// The empty value is None.
|
||||
//
|
||||
// Truth table:
|
||||
//
|
||||
// | x | y | concat(x, y) |
|
||||
// | ------- | ------- | ------------ |
|
||||
// | none | none | none |
|
||||
// | some(a) | none | some(a) |
|
||||
// | none | some(b) | some(b) |
|
||||
// | some(a) | some(b) | some(b) |
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// optMonoid := LastMonoid[int]()
|
||||
// optMonoid.Concat(Some(2), Some(3)) // Some(3) - returns last Some
|
||||
// optMonoid.Concat(None[int](), Some(3)) // Some(3)
|
||||
// optMonoid.Concat(Some(2), None[int]()) // Some(2)
|
||||
// optMonoid.Empty() // None
|
||||
//
|
||||
//go:inline
|
||||
func LastMonoid[A any]() M.Monoid[Option[A]] {
|
||||
return M.MakeMonoid(takeLast[A], None[A]())
|
||||
}
|
||||
|
||||
445
v2/option/monoid_test.go
Normal file
445
v2/option/monoid_test.go
Normal file
@@ -0,0 +1,445 @@
|
||||
// Copyright (c) 2025 IBM Corp.
|
||||
// All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package option
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
M "github.com/IBM/fp-go/v2/monoid"
|
||||
N "github.com/IBM/fp-go/v2/number"
|
||||
S "github.com/IBM/fp-go/v2/semigroup"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestSemigroupAssociativity tests the associativity law for Semigroup
|
||||
func TestSemigroupAssociativity(t *testing.T) {
|
||||
intSemigroup := S.MakeSemigroup(func(a, b int) int { return a + b })
|
||||
optSemigroup := Semigroup[int]()(intSemigroup)
|
||||
|
||||
a := Some(1)
|
||||
b := Some(2)
|
||||
c := Some(3)
|
||||
|
||||
// Test that (a • b) • c = a • (b • c)
|
||||
left := optSemigroup.Concat(optSemigroup.Concat(a, b), c)
|
||||
right := optSemigroup.Concat(a, optSemigroup.Concat(b, c))
|
||||
|
||||
assert.Equal(t, left, right)
|
||||
assert.Equal(t, Some(6), left)
|
||||
}
|
||||
|
||||
// TestSemigroupWithNone tests Semigroup behavior with None values
|
||||
func TestSemigroupWithNone(t *testing.T) {
|
||||
intSemigroup := S.MakeSemigroup(func(a, b int) int { return a + b })
|
||||
optSemigroup := Semigroup[int]()(intSemigroup)
|
||||
|
||||
t.Run("None with None", func(t *testing.T) {
|
||||
result := optSemigroup.Concat(None[int](), None[int]())
|
||||
assert.Equal(t, None[int](), result)
|
||||
})
|
||||
|
||||
t.Run("associativity with None", func(t *testing.T) {
|
||||
a := None[int]()
|
||||
b := Some(2)
|
||||
c := Some(3)
|
||||
|
||||
left := optSemigroup.Concat(optSemigroup.Concat(a, b), c)
|
||||
right := optSemigroup.Concat(a, optSemigroup.Concat(b, c))
|
||||
|
||||
assert.Equal(t, left, right)
|
||||
assert.Equal(t, Some(5), left)
|
||||
})
|
||||
}
|
||||
|
||||
// TestMonoidIdentityLaws tests the identity laws for Monoid
|
||||
func TestMonoidIdentityLaws(t *testing.T) {
|
||||
intSemigroup := S.MakeSemigroup(func(a, b int) int { return a + b })
|
||||
optMonoid := Monoid[int]()(intSemigroup)
|
||||
|
||||
t.Run("left identity with Some", func(t *testing.T) {
|
||||
x := Some(5)
|
||||
result := optMonoid.Concat(optMonoid.Empty(), x)
|
||||
assert.Equal(t, x, result)
|
||||
})
|
||||
|
||||
t.Run("right identity with Some", func(t *testing.T) {
|
||||
x := Some(5)
|
||||
result := optMonoid.Concat(x, optMonoid.Empty())
|
||||
assert.Equal(t, x, result)
|
||||
})
|
||||
|
||||
t.Run("left identity with None", func(t *testing.T) {
|
||||
x := None[int]()
|
||||
result := optMonoid.Concat(optMonoid.Empty(), x)
|
||||
assert.Equal(t, x, result)
|
||||
})
|
||||
|
||||
t.Run("right identity with None", func(t *testing.T) {
|
||||
x := None[int]()
|
||||
result := optMonoid.Concat(x, optMonoid.Empty())
|
||||
assert.Equal(t, x, result)
|
||||
})
|
||||
}
|
||||
|
||||
// TestAlternativeMonoidIdentityLaws tests identity laws for AlternativeMonoid
|
||||
func TestAlternativeMonoidIdentityLaws(t *testing.T) {
|
||||
intMonoid := N.MonoidSum[int]()
|
||||
optMonoid := AlternativeMonoid(intMonoid)
|
||||
|
||||
t.Run("left identity", func(t *testing.T) {
|
||||
x := Some(5)
|
||||
result := optMonoid.Concat(optMonoid.Empty(), x)
|
||||
assert.Equal(t, x, result)
|
||||
})
|
||||
|
||||
t.Run("right identity", func(t *testing.T) {
|
||||
x := Some(5)
|
||||
result := optMonoid.Concat(x, optMonoid.Empty())
|
||||
assert.Equal(t, x, result)
|
||||
})
|
||||
|
||||
t.Run("empty is Some(0)", func(t *testing.T) {
|
||||
empty := optMonoid.Empty()
|
||||
assert.Equal(t, Some(0), empty)
|
||||
})
|
||||
}
|
||||
|
||||
// TestAltMonoidIdentityLaws tests identity laws for AltMonoid
|
||||
func TestAltMonoidIdentityLaws(t *testing.T) {
|
||||
optMonoid := AltMonoid[int]()
|
||||
|
||||
t.Run("left identity", func(t *testing.T) {
|
||||
x := Some(5)
|
||||
result := optMonoid.Concat(optMonoid.Empty(), x)
|
||||
assert.Equal(t, x, result)
|
||||
})
|
||||
|
||||
t.Run("right identity", func(t *testing.T) {
|
||||
x := Some(5)
|
||||
result := optMonoid.Concat(x, optMonoid.Empty())
|
||||
assert.Equal(t, x, result)
|
||||
})
|
||||
|
||||
t.Run("associativity", func(t *testing.T) {
|
||||
a := Some(1)
|
||||
b := None[int]()
|
||||
c := Some(3)
|
||||
|
||||
left := optMonoid.Concat(optMonoid.Concat(a, b), c)
|
||||
right := optMonoid.Concat(a, optMonoid.Concat(b, c))
|
||||
|
||||
assert.Equal(t, left, right)
|
||||
assert.Equal(t, Some(1), left)
|
||||
})
|
||||
}
|
||||
|
||||
// TestFirstMonoid tests the FirstMonoid implementation
|
||||
func TestFirstMonoid(t *testing.T) {
|
||||
optMonoid := FirstMonoid[int]()
|
||||
|
||||
t.Run("both Some values - returns first", func(t *testing.T) {
|
||||
result := optMonoid.Concat(Some(2), Some(3))
|
||||
assert.Equal(t, Some(2), result)
|
||||
})
|
||||
|
||||
t.Run("left Some, right None", func(t *testing.T) {
|
||||
result := optMonoid.Concat(Some(2), None[int]())
|
||||
assert.Equal(t, Some(2), result)
|
||||
})
|
||||
|
||||
t.Run("left None, right Some", func(t *testing.T) {
|
||||
result := optMonoid.Concat(None[int](), Some(3))
|
||||
assert.Equal(t, Some(3), result)
|
||||
})
|
||||
|
||||
t.Run("both None", func(t *testing.T) {
|
||||
result := optMonoid.Concat(None[int](), None[int]())
|
||||
assert.Equal(t, None[int](), result)
|
||||
})
|
||||
|
||||
t.Run("empty value", func(t *testing.T) {
|
||||
empty := optMonoid.Empty()
|
||||
assert.Equal(t, None[int](), empty)
|
||||
})
|
||||
|
||||
t.Run("left identity", func(t *testing.T) {
|
||||
x := Some(5)
|
||||
result := optMonoid.Concat(optMonoid.Empty(), x)
|
||||
assert.Equal(t, x, result)
|
||||
})
|
||||
|
||||
t.Run("right identity", func(t *testing.T) {
|
||||
x := Some(5)
|
||||
result := optMonoid.Concat(x, optMonoid.Empty())
|
||||
assert.Equal(t, x, result)
|
||||
})
|
||||
|
||||
t.Run("associativity", func(t *testing.T) {
|
||||
a := Some(1)
|
||||
b := Some(2)
|
||||
c := Some(3)
|
||||
|
||||
left := optMonoid.Concat(optMonoid.Concat(a, b), c)
|
||||
right := optMonoid.Concat(a, optMonoid.Concat(b, c))
|
||||
|
||||
assert.Equal(t, left, right)
|
||||
assert.Equal(t, Some(1), left)
|
||||
})
|
||||
|
||||
t.Run("multiple concatenations", func(t *testing.T) {
|
||||
// Should return the first Some value encountered
|
||||
result := optMonoid.Concat(
|
||||
optMonoid.Concat(None[int](), Some(1)),
|
||||
optMonoid.Concat(Some(2), Some(3)),
|
||||
)
|
||||
assert.Equal(t, Some(1), result)
|
||||
})
|
||||
|
||||
t.Run("with strings", func(t *testing.T) {
|
||||
strMonoid := FirstMonoid[string]()
|
||||
|
||||
result := strMonoid.Concat(Some("first"), Some("second"))
|
||||
assert.Equal(t, Some("first"), result)
|
||||
|
||||
result = strMonoid.Concat(None[string](), Some("second"))
|
||||
assert.Equal(t, Some("second"), result)
|
||||
})
|
||||
}
|
||||
|
||||
// TestLastMonoid tests the LastMonoid implementation
|
||||
func TestLastMonoid(t *testing.T) {
|
||||
optMonoid := LastMonoid[int]()
|
||||
|
||||
t.Run("both Some values - returns last", func(t *testing.T) {
|
||||
result := optMonoid.Concat(Some(2), Some(3))
|
||||
assert.Equal(t, Some(3), result)
|
||||
})
|
||||
|
||||
t.Run("left Some, right None", func(t *testing.T) {
|
||||
result := optMonoid.Concat(Some(2), None[int]())
|
||||
assert.Equal(t, Some(2), result)
|
||||
})
|
||||
|
||||
t.Run("left None, right Some", func(t *testing.T) {
|
||||
result := optMonoid.Concat(None[int](), Some(3))
|
||||
assert.Equal(t, Some(3), result)
|
||||
})
|
||||
|
||||
t.Run("both None", func(t *testing.T) {
|
||||
result := optMonoid.Concat(None[int](), None[int]())
|
||||
assert.Equal(t, None[int](), result)
|
||||
})
|
||||
|
||||
t.Run("empty value", func(t *testing.T) {
|
||||
empty := optMonoid.Empty()
|
||||
assert.Equal(t, None[int](), empty)
|
||||
})
|
||||
|
||||
t.Run("left identity", func(t *testing.T) {
|
||||
x := Some(5)
|
||||
result := optMonoid.Concat(optMonoid.Empty(), x)
|
||||
assert.Equal(t, x, result)
|
||||
})
|
||||
|
||||
t.Run("right identity", func(t *testing.T) {
|
||||
x := Some(5)
|
||||
result := optMonoid.Concat(x, optMonoid.Empty())
|
||||
assert.Equal(t, x, result)
|
||||
})
|
||||
|
||||
t.Run("associativity", func(t *testing.T) {
|
||||
a := Some(1)
|
||||
b := Some(2)
|
||||
c := Some(3)
|
||||
|
||||
left := optMonoid.Concat(optMonoid.Concat(a, b), c)
|
||||
right := optMonoid.Concat(a, optMonoid.Concat(b, c))
|
||||
|
||||
assert.Equal(t, left, right)
|
||||
assert.Equal(t, Some(3), left)
|
||||
})
|
||||
|
||||
t.Run("multiple concatenations", func(t *testing.T) {
|
||||
// Should return the last Some value encountered
|
||||
result := optMonoid.Concat(
|
||||
optMonoid.Concat(Some(1), Some(2)),
|
||||
optMonoid.Concat(Some(3), None[int]()),
|
||||
)
|
||||
assert.Equal(t, Some(3), result)
|
||||
})
|
||||
|
||||
t.Run("with strings", func(t *testing.T) {
|
||||
strMonoid := LastMonoid[string]()
|
||||
|
||||
result := strMonoid.Concat(Some("first"), Some("second"))
|
||||
assert.Equal(t, Some("second"), result)
|
||||
|
||||
result = strMonoid.Concat(Some("first"), None[string]())
|
||||
assert.Equal(t, Some("first"), result)
|
||||
})
|
||||
}
|
||||
|
||||
// TestFirstMonoidVsAltMonoid verifies FirstMonoid and AltMonoid have the same behavior
|
||||
func TestFirstMonoidVsAltMonoid(t *testing.T) {
|
||||
firstMonoid := FirstMonoid[int]()
|
||||
altMonoid := AltMonoid[int]()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
left Option[int]
|
||||
right Option[int]
|
||||
}{
|
||||
{"both Some", Some(1), Some(2)},
|
||||
{"left Some, right None", Some(1), None[int]()},
|
||||
{"left None, right Some", None[int](), Some(2)},
|
||||
{"both None", None[int](), None[int]()},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
firstResult := firstMonoid.Concat(tc.left, tc.right)
|
||||
altResult := altMonoid.Concat(tc.left, tc.right)
|
||||
assert.Equal(t, firstResult, altResult, "FirstMonoid and AltMonoid should behave the same")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestFirstMonoidVsLastMonoid verifies the difference between FirstMonoid and LastMonoid
|
||||
func TestFirstMonoidVsLastMonoid(t *testing.T) {
|
||||
firstMonoid := FirstMonoid[int]()
|
||||
lastMonoid := LastMonoid[int]()
|
||||
|
||||
t.Run("both Some - different results", func(t *testing.T) {
|
||||
firstResult := firstMonoid.Concat(Some(1), Some(2))
|
||||
lastResult := lastMonoid.Concat(Some(1), Some(2))
|
||||
|
||||
assert.Equal(t, Some(1), firstResult)
|
||||
assert.Equal(t, Some(2), lastResult)
|
||||
assert.NotEqual(t, firstResult, lastResult)
|
||||
})
|
||||
|
||||
t.Run("with None - same results", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
left Option[int]
|
||||
right Option[int]
|
||||
expected Option[int]
|
||||
}{
|
||||
{"left Some, right None", Some(1), None[int](), Some(1)},
|
||||
{"left None, right Some", None[int](), Some(2), Some(2)},
|
||||
{"both None", None[int](), None[int](), None[int]()},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
firstResult := firstMonoid.Concat(tc.left, tc.right)
|
||||
lastResult := lastMonoid.Concat(tc.left, tc.right)
|
||||
|
||||
assert.Equal(t, tc.expected, firstResult)
|
||||
assert.Equal(t, tc.expected, lastResult)
|
||||
assert.Equal(t, firstResult, lastResult)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestMonoidComparison compares different monoid implementations
|
||||
func TestMonoidComparison(t *testing.T) {
|
||||
t.Run("Monoid vs AlternativeMonoid with addition", func(t *testing.T) {
|
||||
intSemigroup := S.MakeSemigroup(func(a, b int) int { return a + b })
|
||||
regularMonoid := Monoid[int]()(intSemigroup)
|
||||
|
||||
intMonoid := M.MakeMonoid(func(a, b int) int { return a + b }, 0)
|
||||
altMonoid := AlternativeMonoid(intMonoid)
|
||||
|
||||
// Both should combine Some values the same way
|
||||
assert.Equal(t,
|
||||
regularMonoid.Concat(Some(2), Some(3)),
|
||||
altMonoid.Concat(Some(2), Some(3)),
|
||||
)
|
||||
|
||||
// But empty values differ
|
||||
assert.Equal(t, None[int](), regularMonoid.Empty())
|
||||
assert.Equal(t, Some(0), altMonoid.Empty())
|
||||
})
|
||||
}
|
||||
|
||||
// TestMonoidLaws verifies monoid laws for all monoid implementations
|
||||
func TestMonoidLaws(t *testing.T) {
|
||||
t.Run("Monoid with addition", func(t *testing.T) {
|
||||
intSemigroup := N.SemigroupSum[int]()
|
||||
optMonoid := Monoid[int]()(intSemigroup)
|
||||
|
||||
a := Some(1)
|
||||
b := Some(2)
|
||||
c := Some(3)
|
||||
|
||||
// Associativity: (a • b) • c = a • (b • c)
|
||||
left := optMonoid.Concat(optMonoid.Concat(a, b), c)
|
||||
right := optMonoid.Concat(a, optMonoid.Concat(b, c))
|
||||
assert.Equal(t, left, right)
|
||||
|
||||
// Left identity: Empty() • a = a
|
||||
leftId := optMonoid.Concat(optMonoid.Empty(), a)
|
||||
assert.Equal(t, a, leftId)
|
||||
|
||||
// Right identity: a • Empty() = a
|
||||
rightId := optMonoid.Concat(a, optMonoid.Empty())
|
||||
assert.Equal(t, a, rightId)
|
||||
})
|
||||
|
||||
t.Run("FirstMonoid laws", func(t *testing.T) {
|
||||
optMonoid := FirstMonoid[int]()
|
||||
|
||||
a := Some(1)
|
||||
b := Some(2)
|
||||
c := Some(3)
|
||||
|
||||
// Associativity
|
||||
left := optMonoid.Concat(optMonoid.Concat(a, b), c)
|
||||
right := optMonoid.Concat(a, optMonoid.Concat(b, c))
|
||||
assert.Equal(t, left, right)
|
||||
|
||||
// Left identity
|
||||
leftId := optMonoid.Concat(optMonoid.Empty(), a)
|
||||
assert.Equal(t, a, leftId)
|
||||
|
||||
// Right identity
|
||||
rightId := optMonoid.Concat(a, optMonoid.Empty())
|
||||
assert.Equal(t, a, rightId)
|
||||
})
|
||||
|
||||
t.Run("LastMonoid laws", func(t *testing.T) {
|
||||
optMonoid := LastMonoid[int]()
|
||||
|
||||
a := Some(1)
|
||||
b := Some(2)
|
||||
c := Some(3)
|
||||
|
||||
// Associativity
|
||||
left := optMonoid.Concat(optMonoid.Concat(a, b), c)
|
||||
right := optMonoid.Concat(a, optMonoid.Concat(b, c))
|
||||
assert.Equal(t, left, right)
|
||||
|
||||
// Left identity
|
||||
leftId := optMonoid.Concat(optMonoid.Empty(), a)
|
||||
assert.Equal(t, a, leftId)
|
||||
|
||||
// Right identity
|
||||
rightId := optMonoid.Concat(a, optMonoid.Empty())
|
||||
assert.Equal(t, a, rightId)
|
||||
})
|
||||
}
|
||||
@@ -21,7 +21,21 @@ import (
|
||||
S "github.com/IBM/fp-go/v2/semigroup"
|
||||
)
|
||||
|
||||
// Semigroup implements a two level ordering
|
||||
// Semigroup implements a two-level ordering that combines two Ord instances.
|
||||
// The resulting Ord will first compare using the first ordering, and only if
|
||||
// the values are equal according to the first ordering, it will use the second ordering.
|
||||
//
|
||||
// This is useful for implementing multi-level sorting (e.g., sort by last name, then by first name).
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// type Person struct { LastName, FirstName string }
|
||||
// stringOrd := ord.FromStrictCompare[string]()
|
||||
// byLastName := ord.Contramap(func(p Person) string { return p.LastName })(stringOrd)
|
||||
// byFirstName := ord.Contramap(func(p Person) string { return p.FirstName })(stringOrd)
|
||||
// sg := ord.Semigroup[Person]()
|
||||
// personOrd := sg.Concat(byLastName, byFirstName)
|
||||
// // Now persons are ordered by last name, then by first name
|
||||
func Semigroup[A any]() S.Semigroup[Ord[A]] {
|
||||
return S.MakeSemigroup(func(first, second Ord[A]) Ord[A] {
|
||||
return FromCompare(func(a, b A) int {
|
||||
@@ -34,19 +48,48 @@ func Semigroup[A any]() S.Semigroup[Ord[A]] {
|
||||
})
|
||||
}
|
||||
|
||||
// Monoid implements a two level ordering such that
|
||||
// - its `Concat(ord1, ord2)` operation will order first by `ord1`, and then by `ord2`
|
||||
// - its `Empty` value is an `Ord` that always considers compared elements equal
|
||||
// Monoid implements a two-level ordering with an identity element.
|
||||
//
|
||||
// Properties:
|
||||
// - Concat(ord1, ord2) will order first by ord1, and then by ord2
|
||||
// - Empty() returns an Ord that always considers compared elements equal
|
||||
//
|
||||
// The Empty ordering acts as an identity: Concat(ord, Empty()) == ord
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// m := ord.Monoid[int]()
|
||||
// emptyOrd := m.Empty()
|
||||
// result := emptyOrd.Compare(5, 3) // 0 (always equal)
|
||||
//
|
||||
// intOrd := ord.FromStrictCompare[int]()
|
||||
// combined := m.Concat(intOrd, emptyOrd) // same as intOrd
|
||||
func Monoid[A any]() M.Monoid[Ord[A]] {
|
||||
return M.MakeMonoid(Semigroup[A]().Concat, FromCompare(F.Constant2[A, A](0)))
|
||||
}
|
||||
|
||||
// MaxSemigroup returns a semigroup where `concat` will return the maximum, based on the provided order.
|
||||
// MaxSemigroup returns a semigroup where Concat will return the maximum value
|
||||
// according to the provided ordering.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// intOrd := ord.FromStrictCompare[int]()
|
||||
// maxSg := ord.MaxSemigroup(intOrd)
|
||||
// result := maxSg.Concat(5, 3) // 5
|
||||
// result := maxSg.Concat(3, 5) // 5
|
||||
func MaxSemigroup[A any](o Ord[A]) S.Semigroup[A] {
|
||||
return S.MakeSemigroup(Max(o))
|
||||
}
|
||||
|
||||
// MaxSemigroup returns a semigroup where `concat` will return the minimum, based on the provided order.
|
||||
// MinSemigroup returns a semigroup where Concat will return the minimum value
|
||||
// according to the provided ordering.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// intOrd := ord.FromStrictCompare[int]()
|
||||
// minSg := ord.MinSemigroup(intOrd)
|
||||
// result := minSg.Concat(5, 3) // 3
|
||||
// result := minSg.Concat(3, 5) // 3
|
||||
func MinSemigroup[A any](o Ord[A]) S.Semigroup[A] {
|
||||
return S.MakeSemigroup(Min(o))
|
||||
}
|
||||
|
||||
206
v2/ord/ord.go
206
v2/ord/ord.go
@@ -17,6 +17,7 @@ package ord
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"time"
|
||||
|
||||
C "github.com/IBM/fp-go/v2/constraints"
|
||||
E "github.com/IBM/fp-go/v2/eq"
|
||||
@@ -80,35 +81,94 @@ func (self ord[T]) Compare(x, y T) int {
|
||||
return self.c(x, y)
|
||||
}
|
||||
|
||||
// ToEq converts an [Ord] to [E.Eq]
|
||||
|
||||
// ToEq converts an [Ord] to [E.Eq].
|
||||
// This allows using an Ord instance where only equality checking is needed.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// intOrd := ord.FromStrictCompare[int]()
|
||||
// intEq := ord.ToEq(intOrd)
|
||||
// result := intEq.Equals(5, 5) // true
|
||||
//
|
||||
//go:inline
|
||||
func ToEq[T any](o Ord[T]) E.Eq[T] {
|
||||
return o
|
||||
}
|
||||
|
||||
// MakeOrd creates an instance of an Ord
|
||||
// MakeOrd creates an instance of an Ord from a compare function and an equals function.
|
||||
//
|
||||
// Parameters:
|
||||
// - c: A comparison function that returns -1 if x < y, 0 if x == y, 1 if x > y
|
||||
// - e: An equality function that returns true if x and y are equal
|
||||
//
|
||||
// The compare and equals functions must be consistent: c(x, y) == 0 iff e(x, y) == true
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// intOrd := ord.MakeOrd(
|
||||
// func(a, b int) int {
|
||||
// if a < b { return -1 }
|
||||
// if a > b { return 1 }
|
||||
// return 0
|
||||
// },
|
||||
// func(a, b int) bool { return a == b },
|
||||
// )
|
||||
//
|
||||
//go:inline
|
||||
func MakeOrd[T any](c func(x, y T) int, e func(x, y T) bool) Ord[T] {
|
||||
return ord[T]{c: c, e: e}
|
||||
}
|
||||
|
||||
// MakeOrd creates an instance of an Ord from a compare function
|
||||
// FromCompare creates an instance of an Ord from a compare function.
|
||||
// The equals function is automatically derived from the compare function.
|
||||
//
|
||||
// Parameters:
|
||||
// - compare: A comparison function that returns -1 if x < y, 0 if x == y, 1 if x > y
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// stringOrd := ord.FromCompare(func(a, b string) int {
|
||||
// if a < b { return -1 }
|
||||
// if a > b { return 1 }
|
||||
// return 0
|
||||
// })
|
||||
func FromCompare[T any](compare func(T, T) int) Ord[T] {
|
||||
return MakeOrd(compare, func(x, y T) bool {
|
||||
return compare(x, y) == 0
|
||||
})
|
||||
}
|
||||
|
||||
// Reverse creates an inverted ordering
|
||||
// Reverse creates an inverted ordering where the comparison results are reversed.
|
||||
// If the original ordering has x < y, the reversed ordering will have x > y.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// intOrd := ord.FromStrictCompare[int]()
|
||||
// reversedOrd := ord.Reverse(intOrd)
|
||||
// result := reversedOrd.Compare(5, 3) // -1 (reversed from 1)
|
||||
func Reverse[T any](o Ord[T]) Ord[T] {
|
||||
return MakeOrd(func(y, x T) int {
|
||||
return o.Compare(x, y)
|
||||
}, o.Equals)
|
||||
}
|
||||
|
||||
// Contramap creates an ordering under a transformation function
|
||||
// Contramap creates an ordering under a transformation function.
|
||||
// This allows ordering values of type B by first transforming them to type A
|
||||
// and then using the ordering for type A.
|
||||
//
|
||||
// Parameters:
|
||||
// - f: A transformation function from B to A
|
||||
//
|
||||
// Returns a function that takes an Ord[A] and returns an Ord[B]
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// type Person struct { Name string; Age int }
|
||||
// intOrd := ord.FromStrictCompare[int]()
|
||||
// personOrd := ord.Contramap(func(p Person) int {
|
||||
// return p.Age
|
||||
// })(intOrd)
|
||||
// // Now persons are ordered by age
|
||||
func Contramap[A, B any](f func(B) A) func(Ord[A]) Ord[B] {
|
||||
return func(o Ord[A]) Ord[B] {
|
||||
return MakeOrd(func(x, y B) int {
|
||||
@@ -119,7 +179,15 @@ func Contramap[A, B any](f func(B) A) func(Ord[A]) Ord[B] {
|
||||
}
|
||||
}
|
||||
|
||||
// Min takes the minimum of two values. If they are considered equal, the first argument is chosen
|
||||
// Min takes the minimum of two values according to the given ordering.
|
||||
// If the values are considered equal, the first argument is chosen.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// intOrd := ord.FromStrictCompare[int]()
|
||||
// min := ord.Min(intOrd)
|
||||
// result := min(5, 3) // 3
|
||||
// result := min(5, 5) // 5 (first argument)
|
||||
func Min[A any](o Ord[A]) func(A, A) A {
|
||||
return func(a, b A) A {
|
||||
if o.Compare(a, b) < 1 {
|
||||
@@ -129,7 +197,15 @@ func Min[A any](o Ord[A]) func(A, A) A {
|
||||
}
|
||||
}
|
||||
|
||||
// Max takes the maximum of two values. If they are considered equal, the first argument is chosen
|
||||
// Max takes the maximum of two values according to the given ordering.
|
||||
// If the values are considered equal, the first argument is chosen.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// intOrd := ord.FromStrictCompare[int]()
|
||||
// max := ord.Max(intOrd)
|
||||
// result := max(5, 3) // 5
|
||||
// result := max(5, 5) // 5 (first argument)
|
||||
func Max[A any](o Ord[A]) func(A, A) A {
|
||||
return func(a, b A) A {
|
||||
if o.Compare(a, b) >= 0 {
|
||||
@@ -139,7 +215,18 @@ func Max[A any](o Ord[A]) func(A, A) A {
|
||||
}
|
||||
}
|
||||
|
||||
// Clamp clamps a value between a minimum and a maximum
|
||||
// Clamp restricts a value to be within a specified range [low, hi].
|
||||
// If the value is less than low, low is returned.
|
||||
// If the value is greater than hi, hi is returned.
|
||||
// Otherwise, the value itself is returned.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// intOrd := ord.FromStrictCompare[int]()
|
||||
// clamp := ord.Clamp(intOrd)(0, 100)
|
||||
// result := clamp(-10) // 0
|
||||
// result := clamp(50) // 50
|
||||
// result := clamp(150) // 100
|
||||
func Clamp[A any](o Ord[A]) func(A, A) func(A) A {
|
||||
return func(low, hi A) func(A) A {
|
||||
clow := F.Bind2nd(o.Compare, low)
|
||||
@@ -166,14 +253,35 @@ func strictEq[A comparable](a, b A) bool {
|
||||
return a == b
|
||||
}
|
||||
|
||||
// FromStrictCompare implements the ordering based on the built in native order
|
||||
// FromStrictCompare implements the ordering based on the built-in native order
|
||||
// for types that satisfy the Ordered constraint (integers, floats, strings).
|
||||
//
|
||||
// This is the most common way to create an Ord for built-in types.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// intOrd := ord.FromStrictCompare[int]()
|
||||
// result := intOrd.Compare(5, 3) // 1
|
||||
//
|
||||
// stringOrd := ord.FromStrictCompare[string]()
|
||||
// result := stringOrd.Compare("apple", "banana") // -1
|
||||
//
|
||||
//go:inline
|
||||
func FromStrictCompare[A C.Ordered]() Ord[A] {
|
||||
return MakeOrd(strictCompare[A], strictEq[A])
|
||||
}
|
||||
|
||||
// Lt tests whether one value is strictly less than another
|
||||
// Lt tests whether one value is strictly less than another.
|
||||
// Returns a curried function that first takes the comparison value,
|
||||
// then takes the value to test.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// intOrd := ord.FromStrictCompare[int]()
|
||||
// isLessThan5 := ord.Lt(intOrd)(5)
|
||||
// result := isLessThan5(3) // true
|
||||
// result := isLessThan5(5) // false
|
||||
// result := isLessThan5(7) // false
|
||||
func Lt[A any](o Ord[A]) func(A) func(A) bool {
|
||||
return func(second A) func(A) bool {
|
||||
return func(first A) bool {
|
||||
@@ -182,7 +290,17 @@ func Lt[A any](o Ord[A]) func(A) func(A) bool {
|
||||
}
|
||||
}
|
||||
|
||||
// Leq Tests whether one value is less or equal than another
|
||||
// Leq tests whether one value is less than or equal to another.
|
||||
// Returns a curried function that first takes the comparison value,
|
||||
// then takes the value to test.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// intOrd := ord.FromStrictCompare[int]()
|
||||
// isAtMost5 := ord.Leq(intOrd)(5)
|
||||
// result := isAtMost5(3) // true
|
||||
// result := isAtMost5(5) // true
|
||||
// result := isAtMost5(7) // false
|
||||
func Leq[A any](o Ord[A]) func(A) func(A) bool {
|
||||
return func(second A) func(A) bool {
|
||||
return func(first A) bool {
|
||||
@@ -191,9 +309,17 @@ func Leq[A any](o Ord[A]) func(A) func(A) bool {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Test whether one value is strictly greater than another
|
||||
*/
|
||||
// Gt tests whether one value is strictly greater than another.
|
||||
// Returns a curried function that first takes the comparison value,
|
||||
// then takes the value to test.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// intOrd := ord.FromStrictCompare[int]()
|
||||
// isGreaterThan5 := ord.Gt(intOrd)(5)
|
||||
// result := isGreaterThan5(3) // false
|
||||
// result := isGreaterThan5(5) // false
|
||||
// result := isGreaterThan5(7) // true
|
||||
func Gt[A any](o Ord[A]) func(A) func(A) bool {
|
||||
return func(second A) func(A) bool {
|
||||
return func(first A) bool {
|
||||
@@ -202,7 +328,17 @@ func Gt[A any](o Ord[A]) func(A) func(A) bool {
|
||||
}
|
||||
}
|
||||
|
||||
// Geq tests whether one value is greater or equal than another
|
||||
// Geq tests whether one value is greater than or equal to another.
|
||||
// Returns a curried function that first takes the comparison value,
|
||||
// then takes the value to test.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// intOrd := ord.FromStrictCompare[int]()
|
||||
// isAtLeast5 := ord.Geq(intOrd)(5)
|
||||
// result := isAtLeast5(3) // false
|
||||
// result := isAtLeast5(5) // true
|
||||
// result := isAtLeast5(7) // true
|
||||
func Geq[A any](o Ord[A]) func(A) func(A) bool {
|
||||
return func(second A) func(A) bool {
|
||||
return func(first A) bool {
|
||||
@@ -211,7 +347,21 @@ func Geq[A any](o Ord[A]) func(A) func(A) bool {
|
||||
}
|
||||
}
|
||||
|
||||
// Between tests whether a value is between a minimum (inclusive) and a maximum (exclusive)
|
||||
// Between tests whether a value is between a minimum (inclusive) and a maximum (exclusive).
|
||||
// Returns a curried function that first takes the range bounds,
|
||||
// then takes the value to test.
|
||||
//
|
||||
// The range is [lo, hi), meaning lo is included but hi is excluded.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// intOrd := ord.FromStrictCompare[int]()
|
||||
// isBetween3And7 := ord.Between(intOrd)(3, 7)
|
||||
// result := isBetween3And7(2) // false (below range)
|
||||
// result := isBetween3And7(3) // true (at lower bound)
|
||||
// result := isBetween3And7(5) // true (within range)
|
||||
// result := isBetween3And7(7) // false (at upper bound, excluded)
|
||||
// result := isBetween3And7(8) // false (above range)
|
||||
func Between[A any](o Ord[A]) func(A, A) func(A) bool {
|
||||
lt := Lt(o)
|
||||
geq := Geq(o)
|
||||
@@ -220,3 +370,25 @@ func Between[A any](o Ord[A]) func(A, A) func(A) bool {
|
||||
return P.And(lt(hi))(geq(lo))
|
||||
}
|
||||
}
|
||||
|
||||
func compareTime(a, b time.Time) int {
|
||||
if a.Before(b) {
|
||||
return -1
|
||||
} else if a.After(b) {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// OrdTime returns an Ord instance for time.Time values.
|
||||
// Times are ordered chronologically using the Before and After methods.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// timeOrd := ord.OrdTime()
|
||||
// t1 := time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
// t2 := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
// result := timeOrd.Compare(t1, t2) // -1 (t1 is before t2)
|
||||
func OrdTime() Ord[time.Time] {
|
||||
return MakeOrd(compareTime, time.Time.Equal)
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ package ord
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
@@ -581,3 +582,306 @@ func BenchmarkClamp(b *testing.B) {
|
||||
_ = clamp(i % 150)
|
||||
}
|
||||
}
|
||||
|
||||
// Test OrdTime
|
||||
func TestOrdTime(t *testing.T) {
|
||||
timeOrd := OrdTime()
|
||||
|
||||
t1 := time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
t2 := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
t3 := time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
// Test Compare
|
||||
assert.Equal(t, -1, timeOrd.Compare(t1, t2), "t1 should be before t2")
|
||||
assert.Equal(t, 1, timeOrd.Compare(t2, t1), "t2 should be after t1")
|
||||
assert.Equal(t, 0, timeOrd.Compare(t1, t3), "t1 should equal t3")
|
||||
|
||||
// Test Equals
|
||||
assert.True(t, timeOrd.Equals(t1, t3), "t1 should equal t3")
|
||||
assert.False(t, timeOrd.Equals(t1, t2), "t1 should not equal t2")
|
||||
}
|
||||
|
||||
func TestOrdTime_WithDifferentTimezones(t *testing.T) {
|
||||
timeOrd := OrdTime()
|
||||
|
||||
// Same instant in different timezones
|
||||
utc := time.Date(2023, 6, 15, 12, 0, 0, 0, time.UTC)
|
||||
est := utc.In(time.FixedZone("EST", -5*3600))
|
||||
|
||||
// Should be equal (same instant)
|
||||
assert.Equal(t, 0, timeOrd.Compare(utc, est))
|
||||
assert.True(t, timeOrd.Equals(utc, est))
|
||||
}
|
||||
|
||||
func TestOrdTime_WithNanoseconds(t *testing.T) {
|
||||
timeOrd := OrdTime()
|
||||
|
||||
t1 := time.Date(2023, 1, 1, 0, 0, 0, 100, time.UTC)
|
||||
t2 := time.Date(2023, 1, 1, 0, 0, 0, 200, time.UTC)
|
||||
|
||||
assert.Equal(t, -1, timeOrd.Compare(t1, t2))
|
||||
assert.Equal(t, 1, timeOrd.Compare(t2, t1))
|
||||
}
|
||||
|
||||
func TestOrdTime_MinMax(t *testing.T) {
|
||||
timeOrd := OrdTime()
|
||||
|
||||
t1 := time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
t2 := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
min := Min(timeOrd)
|
||||
max := Max(timeOrd)
|
||||
|
||||
assert.Equal(t, t1, min(t1, t2))
|
||||
assert.Equal(t, t1, min(t2, t1))
|
||||
|
||||
assert.Equal(t, t2, max(t1, t2))
|
||||
assert.Equal(t, t2, max(t2, t1))
|
||||
}
|
||||
|
||||
// Example tests for documentation
|
||||
func ExampleFromStrictCompare() {
|
||||
intOrd := FromStrictCompare[int]()
|
||||
|
||||
result1 := intOrd.Compare(5, 3)
|
||||
result2 := intOrd.Compare(3, 5)
|
||||
result3 := intOrd.Compare(5, 5)
|
||||
|
||||
println(result1) // 1
|
||||
println(result2) // -1
|
||||
println(result3) // 0
|
||||
}
|
||||
|
||||
func ExampleMakeOrd() {
|
||||
type Person struct {
|
||||
Name string
|
||||
Age int
|
||||
}
|
||||
|
||||
personOrd := MakeOrd(
|
||||
func(p1, p2 Person) int {
|
||||
if p1.Age < p2.Age {
|
||||
return -1
|
||||
} else if p1.Age > p2.Age {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
},
|
||||
func(p1, p2 Person) bool {
|
||||
return p1.Age == p2.Age
|
||||
},
|
||||
)
|
||||
|
||||
p1 := Person{Name: "Alice", Age: 30}
|
||||
p2 := Person{Name: "Bob", Age: 25}
|
||||
|
||||
result := personOrd.Compare(p1, p2)
|
||||
println(result) // 1 (30 > 25)
|
||||
}
|
||||
|
||||
func ExampleFromCompare() {
|
||||
stringOrd := FromCompare(func(a, b string) int {
|
||||
if a < b {
|
||||
return -1
|
||||
} else if a > b {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
})
|
||||
|
||||
result := stringOrd.Compare("apple", "banana")
|
||||
println(result) // -1
|
||||
}
|
||||
|
||||
func ExampleReverse() {
|
||||
intOrd := FromStrictCompare[int]()
|
||||
reversedOrd := Reverse(intOrd)
|
||||
|
||||
result1 := intOrd.Compare(5, 3)
|
||||
result2 := reversedOrd.Compare(5, 3)
|
||||
|
||||
println(result1) // 1
|
||||
println(result2) // -1
|
||||
}
|
||||
|
||||
func ExampleContramap() {
|
||||
type Person struct {
|
||||
Name string
|
||||
Age int
|
||||
}
|
||||
|
||||
intOrd := FromStrictCompare[int]()
|
||||
|
||||
// Order persons by age
|
||||
personOrd := Contramap(func(p Person) int {
|
||||
return p.Age
|
||||
})(intOrd)
|
||||
|
||||
p1 := Person{Name: "Alice", Age: 30}
|
||||
p2 := Person{Name: "Bob", Age: 25}
|
||||
|
||||
result := personOrd.Compare(p1, p2)
|
||||
println(result) // 1 (30 > 25)
|
||||
}
|
||||
|
||||
func ExampleMin() {
|
||||
intOrd := FromStrictCompare[int]()
|
||||
min := Min(intOrd)
|
||||
|
||||
result := min(5, 3)
|
||||
println(result) // 3
|
||||
}
|
||||
|
||||
func ExampleMax() {
|
||||
intOrd := FromStrictCompare[int]()
|
||||
max := Max(intOrd)
|
||||
|
||||
result := max(5, 3)
|
||||
println(result) // 5
|
||||
}
|
||||
|
||||
func ExampleClamp() {
|
||||
intOrd := FromStrictCompare[int]()
|
||||
clamp := Clamp(intOrd)(0, 100)
|
||||
|
||||
result1 := clamp(-10)
|
||||
result2 := clamp(50)
|
||||
result3 := clamp(150)
|
||||
|
||||
println(result1) // 0
|
||||
println(result2) // 50
|
||||
println(result3) // 100
|
||||
}
|
||||
|
||||
func ExampleLt() {
|
||||
intOrd := FromStrictCompare[int]()
|
||||
isLessThan5 := Lt(intOrd)(5)
|
||||
|
||||
result1 := isLessThan5(3)
|
||||
result2 := isLessThan5(5)
|
||||
result3 := isLessThan5(7)
|
||||
|
||||
println(result1) // true
|
||||
println(result2) // false
|
||||
println(result3) // false
|
||||
}
|
||||
|
||||
func ExampleLeq() {
|
||||
intOrd := FromStrictCompare[int]()
|
||||
isAtMost5 := Leq(intOrd)(5)
|
||||
|
||||
result1 := isAtMost5(3)
|
||||
result2 := isAtMost5(5)
|
||||
result3 := isAtMost5(7)
|
||||
|
||||
println(result1) // true
|
||||
println(result2) // true
|
||||
println(result3) // false
|
||||
}
|
||||
|
||||
func ExampleGt() {
|
||||
intOrd := FromStrictCompare[int]()
|
||||
isGreaterThan5 := Gt(intOrd)(5)
|
||||
|
||||
result1 := isGreaterThan5(3)
|
||||
result2 := isGreaterThan5(5)
|
||||
result3 := isGreaterThan5(7)
|
||||
|
||||
println(result1) // false
|
||||
println(result2) // false
|
||||
println(result3) // true
|
||||
}
|
||||
|
||||
func ExampleGeq() {
|
||||
intOrd := FromStrictCompare[int]()
|
||||
isAtLeast5 := Geq(intOrd)(5)
|
||||
|
||||
result1 := isAtLeast5(3)
|
||||
result2 := isAtLeast5(5)
|
||||
result3 := isAtLeast5(7)
|
||||
|
||||
println(result1) // false
|
||||
println(result2) // true
|
||||
println(result3) // true
|
||||
}
|
||||
|
||||
func ExampleBetween() {
|
||||
intOrd := FromStrictCompare[int]()
|
||||
isBetween3And7 := Between(intOrd)(3, 7)
|
||||
|
||||
result1 := isBetween3And7(2)
|
||||
result2 := isBetween3And7(3)
|
||||
result3 := isBetween3And7(5)
|
||||
result4 := isBetween3And7(7)
|
||||
result5 := isBetween3And7(8)
|
||||
|
||||
println(result1) // false
|
||||
println(result2) // true
|
||||
println(result3) // true
|
||||
println(result4) // false
|
||||
println(result5) // false
|
||||
}
|
||||
|
||||
func ExampleSemigroup() {
|
||||
type Person struct {
|
||||
LastName string
|
||||
FirstName string
|
||||
}
|
||||
|
||||
stringOrd := FromStrictCompare[string]()
|
||||
|
||||
// Order by last name
|
||||
byLastName := Contramap(func(p Person) string {
|
||||
return p.LastName
|
||||
})(stringOrd)
|
||||
|
||||
// Order by first name
|
||||
byFirstName := Contramap(func(p Person) string {
|
||||
return p.FirstName
|
||||
})(stringOrd)
|
||||
|
||||
// Combine: order by last name, then first name
|
||||
sg := Semigroup[Person]()
|
||||
personOrd := sg.Concat(byLastName, byFirstName)
|
||||
|
||||
p1 := Person{LastName: "Smith", FirstName: "Alice"}
|
||||
p2 := Person{LastName: "Smith", FirstName: "Bob"}
|
||||
|
||||
result := personOrd.Compare(p1, p2)
|
||||
println(result) // -1 (Alice < Bob)
|
||||
}
|
||||
|
||||
func ExampleMonoid() {
|
||||
m := Monoid[int]()
|
||||
|
||||
// Empty ordering considers everything equal
|
||||
emptyOrd := m.Empty()
|
||||
result := emptyOrd.Compare(5, 3)
|
||||
println(result) // 0
|
||||
}
|
||||
|
||||
func ExampleMaxSemigroup() {
|
||||
intOrd := FromStrictCompare[int]()
|
||||
maxSg := MaxSemigroup(intOrd)
|
||||
|
||||
result := maxSg.Concat(5, 3)
|
||||
println(result) // 5
|
||||
}
|
||||
|
||||
func ExampleMinSemigroup() {
|
||||
intOrd := FromStrictCompare[int]()
|
||||
minSg := MinSemigroup(intOrd)
|
||||
|
||||
result := minSg.Concat(5, 3)
|
||||
println(result) // 3
|
||||
}
|
||||
|
||||
func ExampleOrdTime() {
|
||||
timeOrd := OrdTime()
|
||||
|
||||
t1 := time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
t2 := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
result := timeOrd.Compare(t1, t2)
|
||||
println(result) // -1 (t1 is before t2)
|
||||
}
|
||||
|
||||
230
v2/pair/monoid.go
Normal file
230
v2/pair/monoid.go
Normal file
@@ -0,0 +1,230 @@
|
||||
// Copyright (c) 2024 - 2025 IBM Corp.
|
||||
// All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pair
|
||||
|
||||
import (
|
||||
F "github.com/IBM/fp-go/v2/function"
|
||||
M "github.com/IBM/fp-go/v2/monoid"
|
||||
)
|
||||
|
||||
// ApplicativeMonoid creates a monoid for [Pair] using applicative functor operations on the tail.
|
||||
//
|
||||
// This is an alias for [ApplicativeMonoidTail], which lifts the right (tail) monoid into the
|
||||
// Pair applicative functor. The left monoid provides the semigroup for combining head values
|
||||
// during applicative operations.
|
||||
//
|
||||
// IMPORTANT: The three monoid constructors (ApplicativeMonoid/ApplicativeMonoidTail and
|
||||
// ApplicativeMonoidHead) produce DIFFERENT results:
|
||||
// - ApplicativeMonoidTail: Combines head values in REVERSE order (right-to-left)
|
||||
// - ApplicativeMonoidHead: Combines tail values in REVERSE order (right-to-left)
|
||||
// - The "focused" component (tail for Tail, head for Head) combines in normal order (left-to-right)
|
||||
//
|
||||
// This difference is significant for non-commutative operations like string concatenation.
|
||||
//
|
||||
// Parameters:
|
||||
// - l: A monoid for the head (left) values of type L
|
||||
// - r: A monoid for the tail (right) values of type R
|
||||
//
|
||||
// Returns:
|
||||
// - A Monoid[Pair[L, R]] that combines pairs using applicative operations on the tail
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// import (
|
||||
// N "github.com/IBM/fp-go/v2/number"
|
||||
// S "github.com/IBM/fp-go/v2/string"
|
||||
// )
|
||||
//
|
||||
// intAdd := N.MonoidSum[int]()
|
||||
// strConcat := S.Monoid
|
||||
//
|
||||
// pairMonoid := pair.ApplicativeMonoid(intAdd, strConcat)
|
||||
//
|
||||
// p1 := pair.MakePair(10, "foo")
|
||||
// p2 := pair.MakePair(20, "bar")
|
||||
//
|
||||
// result := pairMonoid.Concat(p1, p2)
|
||||
// // result is Pair[int, string]{30, "foobar"}
|
||||
// // Note: head combines normally (10+20), tail combines normally ("foo"+"bar")
|
||||
//
|
||||
// empty := pairMonoid.Empty()
|
||||
// // empty is Pair[int, string]{0, ""}
|
||||
//
|
||||
//go:inline
|
||||
func ApplicativeMonoid[L, R any](l M.Monoid[L], r M.Monoid[R]) M.Monoid[Pair[L, R]] {
|
||||
return ApplicativeMonoidTail(l, r)
|
||||
}
|
||||
|
||||
// ApplicativeMonoidTail creates a monoid for [Pair] by lifting the tail monoid into the applicative functor.
|
||||
//
|
||||
// This function constructs a monoid using the applicative structure of Pair, focusing on
|
||||
// the tail (right) value. The head values are combined using the left monoid's semigroup
|
||||
// operation during applicative application.
|
||||
//
|
||||
// CRITICAL BEHAVIOR: Due to the applicative functor implementation, the HEAD values are
|
||||
// combined in REVERSE order (right-to-left), while TAIL values combine in normal order
|
||||
// (left-to-right). This matters for non-commutative operations:
|
||||
//
|
||||
// strConcat := S.Monoid
|
||||
// pairMonoid := pair.ApplicativeMonoidTail(strConcat, strConcat)
|
||||
// p1 := pair.MakePair("hello", "foo")
|
||||
// p2 := pair.MakePair(" world", "bar")
|
||||
// result := pairMonoid.Concat(p1, p2)
|
||||
// // result is Pair[string, string]{" worldhello", "foobar"}
|
||||
// // ^^^^^^^^^^^^^^ ^^^^^^
|
||||
// // REVERSED! normal
|
||||
//
|
||||
// The resulting monoid satisfies the standard monoid laws:
|
||||
// - Associativity: Concat(Concat(p1, p2), p3) = Concat(p1, Concat(p2, p3))
|
||||
// - Left identity: Concat(Empty(), p) = p
|
||||
// - Right identity: Concat(p, Empty()) = p
|
||||
//
|
||||
// Parameters:
|
||||
// - l: A monoid for the head (left) values of type L
|
||||
// - r: A monoid for the tail (right) values of type R
|
||||
//
|
||||
// Returns:
|
||||
// - A Monoid[Pair[L, R]] that combines pairs component-wise
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// import (
|
||||
// N "github.com/IBM/fp-go/v2/number"
|
||||
// M "github.com/IBM/fp-go/v2/monoid"
|
||||
// )
|
||||
//
|
||||
// intAdd := N.MonoidSum[int]()
|
||||
// intMul := N.MonoidProduct[int]()
|
||||
//
|
||||
// pairMonoid := pair.ApplicativeMonoidTail(intAdd, intMul)
|
||||
//
|
||||
// p1 := pair.MakePair(5, 3)
|
||||
// p2 := pair.MakePair(10, 4)
|
||||
//
|
||||
// result := pairMonoid.Concat(p1, p2)
|
||||
// // result is Pair[int, int]{15, 12} (5+10, 3*4)
|
||||
// // Note: Addition is commutative, so order doesn't matter for head
|
||||
//
|
||||
// empty := pairMonoid.Empty()
|
||||
// // empty is Pair[int, int]{0, 1}
|
||||
//
|
||||
// Example with different types:
|
||||
//
|
||||
// import S "github.com/IBM/fp-go/v2/string"
|
||||
//
|
||||
// boolAnd := M.MakeMonoid(func(a, b bool) bool { return a && b }, true)
|
||||
// strConcat := S.Monoid
|
||||
//
|
||||
// pairMonoid := pair.ApplicativeMonoidTail(boolAnd, strConcat)
|
||||
//
|
||||
// p1 := pair.MakePair(true, "hello")
|
||||
// p2 := pair.MakePair(true, " world")
|
||||
//
|
||||
// result := pairMonoid.Concat(p1, p2)
|
||||
// // result is Pair[bool, string]{true, "hello world"}
|
||||
// // Note: Boolean AND is commutative, so order doesn't matter for head
|
||||
//
|
||||
//go:inline
|
||||
func ApplicativeMonoidTail[L, R any](l M.Monoid[L], r M.Monoid[R]) M.Monoid[Pair[L, R]] {
|
||||
return M.ApplicativeMonoid(
|
||||
FromHead[R](l.Empty()),
|
||||
MonadMapTail[L, R, func(R) R],
|
||||
F.Bind1of3(MonadApTail[L, R, R])(l),
|
||||
r)
|
||||
}
|
||||
|
||||
// ApplicativeMonoidHead creates a monoid for [Pair] by lifting the head monoid into the applicative functor.
|
||||
//
|
||||
// This function constructs a monoid using the applicative structure of Pair, focusing on
|
||||
// the head (left) value. The tail values are combined using the right monoid's semigroup
|
||||
// operation during applicative application.
|
||||
//
|
||||
// This is the dual of [ApplicativeMonoidTail], operating on the head instead of the tail.
|
||||
//
|
||||
// CRITICAL BEHAVIOR: Due to the applicative functor implementation, the TAIL values are
|
||||
// combined in REVERSE order (right-to-left), while HEAD values combine in normal order
|
||||
// (left-to-right). This is the opposite of ApplicativeMonoidTail:
|
||||
//
|
||||
// strConcat := S.Monoid
|
||||
// pairMonoid := pair.ApplicativeMonoidHead(strConcat, strConcat)
|
||||
// p1 := pair.MakePair("hello", "foo")
|
||||
// p2 := pair.MakePair(" world", "bar")
|
||||
// result := pairMonoid.Concat(p1, p2)
|
||||
// // result is Pair[string, string]{"hello world", "barfoo"}
|
||||
// // ^^^^^^^^^^^^ ^^^^^^^^
|
||||
// // normal REVERSED!
|
||||
//
|
||||
// The resulting monoid satisfies the standard monoid laws:
|
||||
// - Associativity: Concat(Concat(p1, p2), p3) = Concat(p1, Concat(p2, p3))
|
||||
// - Left identity: Concat(Empty(), p) = p
|
||||
// - Right identity: Concat(p, Empty()) = p
|
||||
//
|
||||
// Parameters:
|
||||
// - l: A monoid for the head (left) values of type L
|
||||
// - r: A monoid for the tail (right) values of type R
|
||||
//
|
||||
// Returns:
|
||||
// - A Monoid[Pair[L, R]] that combines pairs component-wise
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// import (
|
||||
// N "github.com/IBM/fp-go/v2/number"
|
||||
// M "github.com/IBM/fp-go/v2/monoid"
|
||||
// )
|
||||
//
|
||||
// intMul := N.MonoidProduct[int]()
|
||||
// intAdd := N.MonoidSum[int]()
|
||||
//
|
||||
// pairMonoid := pair.ApplicativeMonoidHead(intMul, intAdd)
|
||||
//
|
||||
// p1 := pair.MakePair(3, 5)
|
||||
// p2 := pair.MakePair(4, 10)
|
||||
//
|
||||
// result := pairMonoid.Concat(p1, p2)
|
||||
// // result is Pair[int, int]{12, 15} (3*4, 5+10)
|
||||
// // Note: Both operations are commutative, so order doesn't matter
|
||||
//
|
||||
// empty := pairMonoid.Empty()
|
||||
// // empty is Pair[int, int]{1, 0}
|
||||
//
|
||||
// Example comparing Head vs Tail with non-commutative operations:
|
||||
//
|
||||
// import S "github.com/IBM/fp-go/v2/string"
|
||||
//
|
||||
// strConcat := S.Monoid
|
||||
//
|
||||
// // Using ApplicativeMonoidHead - tail values REVERSED
|
||||
// headMonoid := pair.ApplicativeMonoidHead(strConcat, strConcat)
|
||||
// p1 := pair.MakePair("hello", "foo")
|
||||
// p2 := pair.MakePair(" world", "bar")
|
||||
// result := headMonoid.Concat(p1, p2)
|
||||
// // result is Pair[string, string]{"hello world", "barfoo"}
|
||||
//
|
||||
// // Using ApplicativeMonoidTail - head values REVERSED
|
||||
// tailMonoid := pair.ApplicativeMonoidTail(strConcat, strConcat)
|
||||
// result2 := tailMonoid.Concat(p1, p2)
|
||||
// // result2 is Pair[string, string]{" worldhello", "foobar"}
|
||||
// // DIFFERENT result! Head and tail are swapped in their reversal behavior
|
||||
//
|
||||
//go:inline
|
||||
func ApplicativeMonoidHead[L, R any](l M.Monoid[L], r M.Monoid[R]) M.Monoid[Pair[L, R]] {
|
||||
return M.ApplicativeMonoid(
|
||||
FromTail[L](r.Empty()),
|
||||
MonadMapHead[R, L, func(L) L],
|
||||
F.Bind1of3(MonadApHead[R, L, L])(r),
|
||||
l)
|
||||
}
|
||||
497
v2/pair/monoid_test.go
Normal file
497
v2/pair/monoid_test.go
Normal file
@@ -0,0 +1,497 @@
|
||||
// Copyright (c) 2024 - 2025 IBM Corp.
|
||||
// All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pair
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
M "github.com/IBM/fp-go/v2/monoid"
|
||||
N "github.com/IBM/fp-go/v2/number"
|
||||
S "github.com/IBM/fp-go/v2/string"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestApplicativeMonoidTail tests the ApplicativeMonoidTail implementation
|
||||
func TestApplicativeMonoidTail(t *testing.T) {
|
||||
t.Run("integer addition and string concatenation", func(t *testing.T) {
|
||||
intAdd := N.MonoidSum[int]()
|
||||
strConcat := S.Monoid
|
||||
|
||||
pairMonoid := ApplicativeMonoidTail(intAdd, strConcat)
|
||||
|
||||
p1 := MakePair(5, "hello")
|
||||
p2 := MakePair(3, " world")
|
||||
|
||||
result := pairMonoid.Concat(p1, p2)
|
||||
assert.Equal(t, 8, Head(result))
|
||||
assert.Equal(t, "hello world", Tail(result))
|
||||
})
|
||||
|
||||
t.Run("integer multiplication and addition", func(t *testing.T) {
|
||||
intMul := N.MonoidProduct[int]()
|
||||
intAdd := N.MonoidSum[int]()
|
||||
|
||||
pairMonoid := ApplicativeMonoidTail(intMul, intAdd)
|
||||
|
||||
p1 := MakePair(3, 5)
|
||||
p2 := MakePair(4, 10)
|
||||
|
||||
result := pairMonoid.Concat(p1, p2)
|
||||
assert.Equal(t, 12, Head(result)) // 3 * 4
|
||||
assert.Equal(t, 15, Tail(result)) // 5 + 10
|
||||
})
|
||||
|
||||
t.Run("boolean AND and OR", func(t *testing.T) {
|
||||
boolAnd := M.MakeMonoid(func(a, b bool) bool { return a && b }, true)
|
||||
boolOr := M.MakeMonoid(func(a, b bool) bool { return a || b }, false)
|
||||
|
||||
pairMonoid := ApplicativeMonoidTail(boolAnd, boolOr)
|
||||
|
||||
p1 := MakePair(true, false)
|
||||
p2 := MakePair(true, true)
|
||||
|
||||
result := pairMonoid.Concat(p1, p2)
|
||||
assert.Equal(t, true, Head(result)) // true && true
|
||||
assert.Equal(t, true, Tail(result)) // false || true
|
||||
})
|
||||
|
||||
t.Run("empty value", func(t *testing.T) {
|
||||
intAdd := N.MonoidSum[int]()
|
||||
strConcat := S.Monoid
|
||||
|
||||
pairMonoid := ApplicativeMonoidTail(intAdd, strConcat)
|
||||
|
||||
empty := pairMonoid.Empty()
|
||||
assert.Equal(t, 0, Head(empty))
|
||||
assert.Equal(t, "", Tail(empty))
|
||||
})
|
||||
|
||||
t.Run("left identity law", func(t *testing.T) {
|
||||
intAdd := N.MonoidSum[int]()
|
||||
strConcat := S.Monoid
|
||||
|
||||
pairMonoid := ApplicativeMonoidTail(intAdd, strConcat)
|
||||
|
||||
p := MakePair(5, "test")
|
||||
result := pairMonoid.Concat(pairMonoid.Empty(), p)
|
||||
|
||||
assert.Equal(t, p, result)
|
||||
})
|
||||
|
||||
t.Run("right identity law", func(t *testing.T) {
|
||||
intAdd := N.MonoidSum[int]()
|
||||
strConcat := S.Monoid
|
||||
|
||||
pairMonoid := ApplicativeMonoidTail(intAdd, strConcat)
|
||||
|
||||
p := MakePair(5, "test")
|
||||
result := pairMonoid.Concat(p, pairMonoid.Empty())
|
||||
|
||||
assert.Equal(t, p, result)
|
||||
})
|
||||
|
||||
t.Run("associativity law", func(t *testing.T) {
|
||||
intAdd := N.MonoidSum[int]()
|
||||
strConcat := S.Monoid
|
||||
|
||||
pairMonoid := ApplicativeMonoidTail(intAdd, strConcat)
|
||||
|
||||
p1 := MakePair(1, "a")
|
||||
p2 := MakePair(2, "b")
|
||||
p3 := MakePair(3, "c")
|
||||
|
||||
left := pairMonoid.Concat(pairMonoid.Concat(p1, p2), p3)
|
||||
right := pairMonoid.Concat(p1, pairMonoid.Concat(p2, p3))
|
||||
|
||||
assert.Equal(t, left, right)
|
||||
assert.Equal(t, 6, Head(left))
|
||||
assert.Equal(t, "abc", Tail(left))
|
||||
})
|
||||
|
||||
t.Run("multiple concatenations", func(t *testing.T) {
|
||||
intAdd := N.MonoidSum[int]()
|
||||
intMul := N.MonoidProduct[int]()
|
||||
|
||||
pairMonoid := ApplicativeMonoidTail(intAdd, intMul)
|
||||
|
||||
pairs := []Pair[int, int]{
|
||||
MakePair(1, 2),
|
||||
MakePair(3, 4),
|
||||
MakePair(5, 6),
|
||||
}
|
||||
|
||||
result := pairMonoid.Empty()
|
||||
for _, p := range pairs {
|
||||
result = pairMonoid.Concat(result, p)
|
||||
}
|
||||
|
||||
assert.Equal(t, 9, Head(result)) // 0 + 1 + 3 + 5
|
||||
assert.Equal(t, 48, Tail(result)) // 1 * 2 * 4 * 6
|
||||
})
|
||||
}
|
||||
|
||||
// TestApplicativeMonoidHead tests the ApplicativeMonoidHead implementation
|
||||
func TestApplicativeMonoidHead(t *testing.T) {
|
||||
t.Run("integer multiplication and addition", func(t *testing.T) {
|
||||
intMul := N.MonoidProduct[int]()
|
||||
intAdd := N.MonoidSum[int]()
|
||||
|
||||
pairMonoid := ApplicativeMonoidHead(intMul, intAdd)
|
||||
|
||||
p1 := MakePair(3, 5)
|
||||
p2 := MakePair(4, 10)
|
||||
|
||||
result := pairMonoid.Concat(p1, p2)
|
||||
assert.Equal(t, 12, Head(result)) // 3 * 4
|
||||
assert.Equal(t, 15, Tail(result)) // 5 + 10
|
||||
})
|
||||
|
||||
t.Run("string concatenation and boolean OR", func(t *testing.T) {
|
||||
strConcat := S.Monoid
|
||||
boolOr := M.MakeMonoid(func(a, b bool) bool { return a || b }, false)
|
||||
|
||||
pairMonoid := ApplicativeMonoidHead(strConcat, boolOr)
|
||||
|
||||
p1 := MakePair("hello", false)
|
||||
p2 := MakePair(" world", true)
|
||||
|
||||
result := pairMonoid.Concat(p1, p2)
|
||||
assert.Equal(t, "hello world", Head(result))
|
||||
assert.Equal(t, true, Tail(result))
|
||||
})
|
||||
|
||||
t.Run("empty value", func(t *testing.T) {
|
||||
intMul := N.MonoidProduct[int]()
|
||||
intAdd := N.MonoidSum[int]()
|
||||
|
||||
pairMonoid := ApplicativeMonoidHead(intMul, intAdd)
|
||||
|
||||
empty := pairMonoid.Empty()
|
||||
assert.Equal(t, 1, Head(empty))
|
||||
assert.Equal(t, 0, Tail(empty))
|
||||
})
|
||||
|
||||
t.Run("left identity law", func(t *testing.T) {
|
||||
intMul := N.MonoidProduct[int]()
|
||||
intAdd := N.MonoidSum[int]()
|
||||
|
||||
pairMonoid := ApplicativeMonoidHead(intMul, intAdd)
|
||||
|
||||
p := MakePair(5, 10)
|
||||
result := pairMonoid.Concat(pairMonoid.Empty(), p)
|
||||
|
||||
assert.Equal(t, p, result)
|
||||
})
|
||||
|
||||
t.Run("right identity law", func(t *testing.T) {
|
||||
intMul := N.MonoidProduct[int]()
|
||||
intAdd := N.MonoidSum[int]()
|
||||
|
||||
pairMonoid := ApplicativeMonoidHead(intMul, intAdd)
|
||||
|
||||
p := MakePair(5, 10)
|
||||
result := pairMonoid.Concat(p, pairMonoid.Empty())
|
||||
|
||||
assert.Equal(t, p, result)
|
||||
})
|
||||
|
||||
t.Run("associativity law", func(t *testing.T) {
|
||||
intMul := N.MonoidProduct[int]()
|
||||
intAdd := N.MonoidSum[int]()
|
||||
|
||||
pairMonoid := ApplicativeMonoidHead(intMul, intAdd)
|
||||
|
||||
p1 := MakePair(2, 1)
|
||||
p2 := MakePair(3, 2)
|
||||
p3 := MakePair(4, 3)
|
||||
|
||||
left := pairMonoid.Concat(pairMonoid.Concat(p1, p2), p3)
|
||||
right := pairMonoid.Concat(p1, pairMonoid.Concat(p2, p3))
|
||||
|
||||
assert.Equal(t, left, right)
|
||||
assert.Equal(t, 24, Head(left)) // 2 * 3 * 4
|
||||
assert.Equal(t, 6, Tail(left)) // 1 + 2 + 3
|
||||
})
|
||||
|
||||
t.Run("multiple concatenations", func(t *testing.T) {
|
||||
intAdd := N.MonoidSum[int]()
|
||||
intMul := N.MonoidProduct[int]()
|
||||
|
||||
pairMonoid := ApplicativeMonoidHead(intAdd, intMul)
|
||||
|
||||
pairs := []Pair[int, int]{
|
||||
MakePair(1, 2),
|
||||
MakePair(3, 4),
|
||||
MakePair(5, 6),
|
||||
}
|
||||
|
||||
result := pairMonoid.Empty()
|
||||
for _, p := range pairs {
|
||||
result = pairMonoid.Concat(result, p)
|
||||
}
|
||||
|
||||
assert.Equal(t, 9, Head(result)) // 0 + 1 + 3 + 5
|
||||
assert.Equal(t, 48, Tail(result)) // 1 * 2 * 4 * 6
|
||||
})
|
||||
}
|
||||
|
||||
// TestApplicativeMonoid tests the ApplicativeMonoid alias
|
||||
func TestApplicativeMonoid(t *testing.T) {
|
||||
t.Run("is alias for ApplicativeMonoidTail", func(t *testing.T) {
|
||||
intAdd := N.MonoidSum[int]()
|
||||
strConcat := S.Monoid
|
||||
|
||||
monoid1 := ApplicativeMonoid(intAdd, strConcat)
|
||||
monoid2 := ApplicativeMonoidTail(intAdd, strConcat)
|
||||
|
||||
p1 := MakePair(5, "hello")
|
||||
p2 := MakePair(3, " world")
|
||||
|
||||
result1 := monoid1.Concat(p1, p2)
|
||||
result2 := monoid2.Concat(p1, p2)
|
||||
|
||||
assert.Equal(t, result1, result2)
|
||||
assert.Equal(t, 8, Head(result1))
|
||||
assert.Equal(t, "hello world", Tail(result1))
|
||||
})
|
||||
|
||||
t.Run("empty values are identical", func(t *testing.T) {
|
||||
intAdd := N.MonoidSum[int]()
|
||||
strConcat := S.Monoid
|
||||
|
||||
monoid1 := ApplicativeMonoid(intAdd, strConcat)
|
||||
monoid2 := ApplicativeMonoidTail(intAdd, strConcat)
|
||||
|
||||
assert.Equal(t, monoid1.Empty(), monoid2.Empty())
|
||||
})
|
||||
}
|
||||
|
||||
// TestMonoidHeadVsTail compares ApplicativeMonoidHead and ApplicativeMonoidTail
|
||||
func TestMonoidHeadVsTail(t *testing.T) {
|
||||
t.Run("same result with commutative operations", func(t *testing.T) {
|
||||
intAdd := N.MonoidSum[int]()
|
||||
intMul := N.MonoidProduct[int]()
|
||||
|
||||
headMonoid := ApplicativeMonoidHead(intMul, intAdd)
|
||||
tailMonoid := ApplicativeMonoidTail(intMul, intAdd)
|
||||
|
||||
p1 := MakePair(2, 3)
|
||||
p2 := MakePair(4, 5)
|
||||
|
||||
resultHead := headMonoid.Concat(p1, p2)
|
||||
resultTail := tailMonoid.Concat(p1, p2)
|
||||
|
||||
// Both should give same result since operations are commutative
|
||||
assert.Equal(t, 8, Head(resultHead)) // 2 * 4
|
||||
assert.Equal(t, 8, Tail(resultHead)) // 3 + 5
|
||||
assert.Equal(t, 8, Head(resultTail)) // 2 * 4
|
||||
assert.Equal(t, 8, Tail(resultTail)) // 3 + 5
|
||||
})
|
||||
|
||||
t.Run("different empty values", func(t *testing.T) {
|
||||
intAdd := N.MonoidSum[int]()
|
||||
intMul := N.MonoidProduct[int]()
|
||||
|
||||
headMonoid := ApplicativeMonoidHead(intMul, intAdd)
|
||||
tailMonoid := ApplicativeMonoidTail(intAdd, intMul)
|
||||
|
||||
emptyHead := headMonoid.Empty()
|
||||
emptyTail := tailMonoid.Empty()
|
||||
|
||||
assert.Equal(t, 1, Head(emptyHead)) // intMul empty
|
||||
assert.Equal(t, 0, Tail(emptyHead)) // intAdd empty
|
||||
assert.Equal(t, 0, Head(emptyTail)) // intAdd empty
|
||||
assert.Equal(t, 1, Tail(emptyTail)) // intMul empty
|
||||
})
|
||||
}
|
||||
|
||||
// TestMonoidLaws verifies monoid laws for all implementations
|
||||
func TestMonoidLaws(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
monoid M.Monoid[Pair[int, int]]
|
||||
p1, p2, p3 Pair[int, int]
|
||||
}{
|
||||
{
|
||||
name: "ApplicativeMonoidTail",
|
||||
monoid: ApplicativeMonoidTail(N.MonoidSum[int](), N.MonoidProduct[int]()),
|
||||
p1: MakePair(1, 2),
|
||||
p2: MakePair(3, 4),
|
||||
p3: MakePair(5, 6),
|
||||
},
|
||||
{
|
||||
name: "ApplicativeMonoidHead",
|
||||
monoid: ApplicativeMonoidHead(N.MonoidProduct[int](), N.MonoidSum[int]()),
|
||||
p1: MakePair(2, 1),
|
||||
p2: MakePair(3, 2),
|
||||
p3: MakePair(4, 3),
|
||||
},
|
||||
{
|
||||
name: "ApplicativeMonoid",
|
||||
monoid: ApplicativeMonoid(N.MonoidSum[int](), N.MonoidSum[int]()),
|
||||
p1: MakePair(1, 2),
|
||||
p2: MakePair(3, 4),
|
||||
p3: MakePair(5, 6),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Run("associativity", func(t *testing.T) {
|
||||
left := tc.monoid.Concat(tc.monoid.Concat(tc.p1, tc.p2), tc.p3)
|
||||
right := tc.monoid.Concat(tc.p1, tc.monoid.Concat(tc.p2, tc.p3))
|
||||
assert.Equal(t, left, right)
|
||||
})
|
||||
|
||||
t.Run("left identity", func(t *testing.T) {
|
||||
result := tc.monoid.Concat(tc.monoid.Empty(), tc.p1)
|
||||
assert.Equal(t, tc.p1, result)
|
||||
})
|
||||
|
||||
t.Run("right identity", func(t *testing.T) {
|
||||
result := tc.monoid.Concat(tc.p1, tc.monoid.Empty())
|
||||
assert.Equal(t, tc.p1, result)
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestMonoidEdgeCases tests edge cases for monoid operations
|
||||
func TestMonoidEdgeCases(t *testing.T) {
|
||||
t.Run("concatenating empty with empty", func(t *testing.T) {
|
||||
intAdd := N.MonoidSum[int]()
|
||||
strConcat := S.Monoid
|
||||
|
||||
pairMonoid := ApplicativeMonoidTail(intAdd, strConcat)
|
||||
|
||||
result := pairMonoid.Concat(pairMonoid.Empty(), pairMonoid.Empty())
|
||||
assert.Equal(t, pairMonoid.Empty(), result)
|
||||
})
|
||||
|
||||
t.Run("chain of operations", func(t *testing.T) {
|
||||
intAdd := N.MonoidSum[int]()
|
||||
intMul := N.MonoidProduct[int]()
|
||||
|
||||
pairMonoid := ApplicativeMonoidTail(intAdd, intMul)
|
||||
|
||||
result := pairMonoid.Concat(
|
||||
pairMonoid.Concat(
|
||||
pairMonoid.Concat(MakePair(1, 2), MakePair(2, 3)),
|
||||
MakePair(3, 4),
|
||||
),
|
||||
MakePair(4, 5),
|
||||
)
|
||||
|
||||
assert.Equal(t, 10, Head(result)) // 1 + 2 + 3 + 4
|
||||
assert.Equal(t, 120, Tail(result)) // 2 * 3 * 4 * 5
|
||||
})
|
||||
|
||||
t.Run("zero values", func(t *testing.T) {
|
||||
intAdd := N.MonoidSum[int]()
|
||||
intMul := N.MonoidProduct[int]()
|
||||
|
||||
pairMonoid := ApplicativeMonoidTail(intAdd, intMul)
|
||||
|
||||
p1 := MakePair(0, 0)
|
||||
p2 := MakePair(5, 10)
|
||||
|
||||
result := pairMonoid.Concat(p1, p2)
|
||||
assert.Equal(t, 5, Head(result))
|
||||
assert.Equal(t, 0, Tail(result)) // 0 * 10 = 0
|
||||
})
|
||||
|
||||
t.Run("negative values", func(t *testing.T) {
|
||||
intAdd := N.MonoidSum[int]()
|
||||
intMul := N.MonoidProduct[int]()
|
||||
|
||||
pairMonoid := ApplicativeMonoidTail(intAdd, intMul)
|
||||
|
||||
p1 := MakePair(-5, -2)
|
||||
p2 := MakePair(3, 4)
|
||||
|
||||
result := pairMonoid.Concat(p1, p2)
|
||||
assert.Equal(t, -2, Head(result)) // -5 + 3
|
||||
assert.Equal(t, -8, Tail(result)) // -2 * 4
|
||||
})
|
||||
}
|
||||
|
||||
// TestMonoidWithDifferentTypes tests monoids with various type combinations
|
||||
func TestMonoidWithDifferentTypes(t *testing.T) {
|
||||
t.Run("string and boolean", func(t *testing.T) {
|
||||
strConcat := S.Monoid
|
||||
boolAnd := M.MakeMonoid(func(a, b bool) bool { return a && b }, true)
|
||||
|
||||
pairMonoid := ApplicativeMonoidTail(strConcat, boolAnd)
|
||||
|
||||
p1 := MakePair("hello", true)
|
||||
p2 := MakePair(" world", true)
|
||||
|
||||
result := pairMonoid.Concat(p1, p2)
|
||||
// Note: The order depends on the applicative implementation
|
||||
assert.Equal(t, " worldhello", Head(result))
|
||||
assert.Equal(t, true, Tail(result))
|
||||
})
|
||||
|
||||
t.Run("boolean and string", func(t *testing.T) {
|
||||
boolOr := M.MakeMonoid(func(a, b bool) bool { return a || b }, false)
|
||||
strConcat := S.Monoid
|
||||
|
||||
pairMonoid := ApplicativeMonoidTail(boolOr, strConcat)
|
||||
|
||||
p1 := MakePair(false, "foo")
|
||||
p2 := MakePair(true, "bar")
|
||||
|
||||
result := pairMonoid.Concat(p1, p2)
|
||||
assert.Equal(t, true, Head(result))
|
||||
assert.Equal(t, "foobar", Tail(result))
|
||||
})
|
||||
|
||||
t.Run("float64 addition and multiplication", func(t *testing.T) {
|
||||
floatAdd := N.MonoidSum[float64]()
|
||||
floatMul := N.MonoidProduct[float64]()
|
||||
|
||||
pairMonoid := ApplicativeMonoidTail(floatAdd, floatMul)
|
||||
|
||||
p1 := MakePair(1.5, 2.0)
|
||||
p2 := MakePair(2.5, 3.0)
|
||||
|
||||
result := pairMonoid.Concat(p1, p2)
|
||||
assert.Equal(t, 4.0, Head(result))
|
||||
assert.Equal(t, 6.0, Tail(result))
|
||||
})
|
||||
}
|
||||
|
||||
// TestMonoidCommutativity tests behavior with non-commutative operations
|
||||
func TestMonoidCommutativity(t *testing.T) {
|
||||
t.Run("string concatenation is not commutative", func(t *testing.T) {
|
||||
strConcat := S.Monoid
|
||||
|
||||
pairMonoid := ApplicativeMonoidTail(strConcat, strConcat)
|
||||
|
||||
p1 := MakePair("hello", "foo")
|
||||
p2 := MakePair(" world", "bar")
|
||||
|
||||
result1 := pairMonoid.Concat(p1, p2)
|
||||
result2 := pairMonoid.Concat(p2, p1)
|
||||
|
||||
// The applicative implementation reverses the order for head values
|
||||
assert.Equal(t, " worldhello", Head(result1))
|
||||
assert.Equal(t, "foobar", Tail(result1))
|
||||
assert.Equal(t, "hello world", Head(result2))
|
||||
assert.Equal(t, "barfoo", Tail(result2))
|
||||
assert.NotEqual(t, result1, result2)
|
||||
})
|
||||
}
|
||||
@@ -370,5 +370,3 @@ func TestTailRec_ComplexState(t *testing.T) {
|
||||
assert.Equal(t, 60, result)
|
||||
})
|
||||
}
|
||||
|
||||
// Made with Bob
|
||||
|
||||
@@ -4,10 +4,10 @@ import "github.com/IBM/fp-go/v2/io"
|
||||
|
||||
//go:inline
|
||||
func ChainConsumer[R, A any](c Consumer[A]) Operator[R, A, struct{}] {
|
||||
return ChainIOK[R](io.FromConsumerK(c))
|
||||
return ChainIOK[R](io.FromConsumer(c))
|
||||
}
|
||||
|
||||
//go:inline
|
||||
func ChainFirstConsumer[R, A any](c Consumer[A]) Operator[R, A, A] {
|
||||
return ChainFirstIOK[R](io.FromConsumerK(c))
|
||||
return ChainFirstIOK[R](io.FromConsumer(c))
|
||||
}
|
||||
|
||||
@@ -650,5 +650,3 @@ func TestRetrying_StackSafety(t *testing.T) {
|
||||
assert.Equal(t, maxAttempts, finalResult)
|
||||
assert.Equal(t, maxAttempts, attempts, "Should handle many retries without stack overflow")
|
||||
}
|
||||
|
||||
// Made with Bob
|
||||
|
||||
@@ -22,7 +22,7 @@ import "github.com/IBM/fp-go/v2/io"
|
||||
//
|
||||
//go:inline
|
||||
func ChainConsumer[R, E, A any](c Consumer[A]) Operator[R, E, A, struct{}] {
|
||||
return ChainIOK[R, E](io.FromConsumerK(c))
|
||||
return ChainIOK[R, E](io.FromConsumer(c))
|
||||
}
|
||||
|
||||
// ChainFirstConsumer chains a consumer into a ReaderIOEither computation while preserving
|
||||
@@ -45,5 +45,5 @@ func ChainConsumer[R, E, A any](c Consumer[A]) Operator[R, E, A, struct{}] {
|
||||
//
|
||||
//go:inline
|
||||
func ChainFirstConsumer[R, E, A any](c Consumer[A]) Operator[R, E, A, A] {
|
||||
return ChainFirstIOK[R, E](io.FromConsumerK(c))
|
||||
return ChainFirstIOK[R, E](io.FromConsumer(c))
|
||||
}
|
||||
|
||||
@@ -957,7 +957,7 @@ func MonadTapLeft[A, R, EA, EB, B any](ma ReaderIOEither[R, EA, A], f Kleisli[R,
|
||||
// - An Operator that performs the side effect but always returns the original error if input was Left
|
||||
//
|
||||
//go:inline
|
||||
func ChainFirstLeft[A, R, EA, EB, B any](f Kleisli[R, EB, EA, B]) Operator[R, EA, A, A] {
|
||||
func ChainFirstLeft[A, R, EB, EA, B any](f Kleisli[R, EB, EA, B]) Operator[R, EA, A, A] {
|
||||
return eithert.ChainFirstLeft(
|
||||
readerio.Chain[R, Either[EA, A], Either[EA, A]],
|
||||
readerio.Map[R, Either[EB, B], Either[EA, A]],
|
||||
@@ -966,11 +966,23 @@ func ChainFirstLeft[A, R, EA, EB, B any](f Kleisli[R, EB, EA, B]) Operator[R, EA
|
||||
)
|
||||
}
|
||||
|
||||
func ChainFirstLeftIOK[A, R, EA, B any](f io.Kleisli[EA, B]) Operator[R, EA, A, A] {
|
||||
return ChainFirstLeft[A](function.Flow2(
|
||||
f,
|
||||
FromIO[R, EA],
|
||||
))
|
||||
}
|
||||
|
||||
//go:inline
|
||||
func TapLeft[A, R, EA, EB, B any](f Kleisli[R, EB, EA, B]) Operator[R, EA, A, A] {
|
||||
func TapLeft[A, R, EB, EA, B any](f Kleisli[R, EB, EA, B]) Operator[R, EA, A, A] {
|
||||
return ChainFirstLeft[A](f)
|
||||
}
|
||||
|
||||
//go:inline
|
||||
func TapLeftIOK[A, R, EA, B any](f io.Kleisli[EA, B]) Operator[R, EA, A, A] {
|
||||
return ChainFirstLeftIOK[A, R](f)
|
||||
}
|
||||
|
||||
// Delay creates an operation that passes in the value after some delay
|
||||
//
|
||||
//go:inline
|
||||
|
||||
@@ -796,11 +796,21 @@ func ChainFirstLeft[A, R, B any](f Kleisli[R, error, B]) Operator[R, A, A] {
|
||||
return RIOE.ChainFirstLeft[A](f)
|
||||
}
|
||||
|
||||
//go:inline
|
||||
func ChainFirstLeftIOK[A, R, B any](f io.Kleisli[error, B]) Operator[R, A, A] {
|
||||
return RIOE.ChainFirstLeftIOK[A, R](f)
|
||||
}
|
||||
|
||||
//go:inline
|
||||
func TapLeft[A, R, B any](f Kleisli[R, error, B]) Operator[R, A, A] {
|
||||
return RIOE.TapLeft[A](f)
|
||||
}
|
||||
|
||||
//go:inline
|
||||
func TapLeftIOK[A, R, B any](f io.Kleisli[error, B]) Operator[R, A, A] {
|
||||
return RIOE.TapLeftIOK[A, R](f)
|
||||
}
|
||||
|
||||
// Delay creates an operation that passes in the value after some delay
|
||||
//
|
||||
//go:inline
|
||||
|
||||
@@ -13,7 +13,13 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package option
|
||||
// Package reflect provides functional programming utilities for working with Go's reflect.Value type.
|
||||
// It offers higher-order functions like Map, Reduce, and ReduceWithIndex that operate on
|
||||
// reflective values representing slices or arrays.
|
||||
//
|
||||
// These utilities are particularly useful when working with dynamic types or when implementing
|
||||
// generic algorithms that need to operate on collections discovered at runtime.
|
||||
package reflect
|
||||
|
||||
import (
|
||||
R "reflect"
|
||||
@@ -22,21 +28,84 @@ import (
|
||||
G "github.com/IBM/fp-go/v2/reflect/generic"
|
||||
)
|
||||
|
||||
// ReduceWithIndex applies a reducer function to each element of a reflect.Value (representing a slice or array),
|
||||
// accumulating a result value. The reducer function receives the current index, the accumulated value,
|
||||
// and the current element as a reflect.Value.
|
||||
//
|
||||
// This is a curried function that first takes the reducer function and initial value,
|
||||
// then returns a function that accepts the reflect.Value to reduce.
|
||||
//
|
||||
// Parameters:
|
||||
// - f: A reducer function that takes (index int, accumulator A, element reflect.Value) and returns the new accumulator
|
||||
// - initial: The initial value for the accumulation
|
||||
//
|
||||
// Returns:
|
||||
// - A function that takes a reflect.Value and returns the final accumulated value
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// // Sum all integers in a reflected slice with their indices
|
||||
// sumWithIndex := ReduceWithIndex(func(i int, acc int, v reflect.Value) int {
|
||||
// return acc + i + int(v.Int())
|
||||
// }, 0)
|
||||
// result := sumWithIndex(reflect.ValueOf([]int{10, 20, 30}))
|
||||
// // result = 0 + (0+10) + (1+20) + (2+30) = 63
|
||||
func ReduceWithIndex[A any](f func(int, A, R.Value) A, initial A) func(R.Value) A {
|
||||
return func(val R.Value) A {
|
||||
count := val.Len()
|
||||
current := initial
|
||||
for i := 0; i < count; i++ {
|
||||
for i := range count {
|
||||
current = f(i, current, val.Index(i))
|
||||
}
|
||||
return current
|
||||
}
|
||||
}
|
||||
|
||||
// Reduce applies a reducer function to each element of a reflect.Value (representing a slice or array),
|
||||
// accumulating a result value. Unlike ReduceWithIndex, the reducer function does not receive the index.
|
||||
//
|
||||
// This is a curried function that first takes the reducer function and initial value,
|
||||
// then returns a function that accepts the reflect.Value to reduce.
|
||||
//
|
||||
// Parameters:
|
||||
// - f: A reducer function that takes (accumulator A, element reflect.Value) and returns the new accumulator
|
||||
// - initial: The initial value for the accumulation
|
||||
//
|
||||
// Returns:
|
||||
// - A function that takes a reflect.Value and returns the final accumulated value
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// // Sum all integers in a reflected slice
|
||||
// sum := Reduce(func(acc int, v reflect.Value) int {
|
||||
// return acc + int(v.Int())
|
||||
// }, 0)
|
||||
// result := sum(reflect.ValueOf([]int{10, 20, 30}))
|
||||
// // result = 60
|
||||
func Reduce[A any](f func(A, R.Value) A, initial A) func(R.Value) A {
|
||||
return ReduceWithIndex(F.Ignore1of3[int](f), initial)
|
||||
}
|
||||
|
||||
// Map transforms each element of a reflect.Value (representing a slice or array) using the provided
|
||||
// function, returning a new slice containing the transformed values.
|
||||
//
|
||||
// This is a curried function that first takes the transformation function,
|
||||
// then returns a function that accepts the reflect.Value to map over.
|
||||
//
|
||||
// Parameters:
|
||||
// - f: A transformation function that takes a reflect.Value and returns a value of type A
|
||||
//
|
||||
// Returns:
|
||||
// - A function that takes a reflect.Value and returns a slice of transformed values
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// // Extract integers from a reflected slice and double them
|
||||
// doubleInts := Map(func(v reflect.Value) int {
|
||||
// return int(v.Int()) * 2
|
||||
// })
|
||||
// result := doubleInts(reflect.ValueOf([]int{1, 2, 3}))
|
||||
// // result = []int{2, 4, 6}
|
||||
func Map[A any](f func(R.Value) A) func(R.Value) []A {
|
||||
return G.Map[[]A](f)
|
||||
}
|
||||
|
||||
371
v2/reflect/reflect_test.go
Normal file
371
v2/reflect/reflect_test.go
Normal file
@@ -0,0 +1,371 @@
|
||||
// Copyright (c) 2023 - 2025 IBM Corp.
|
||||
// All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package reflect
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestReduceWithIndex_IntSum tests reducing integers with index awareness
|
||||
func TestReduceWithIndex_IntSum(t *testing.T) {
|
||||
input := []int{10, 20, 30}
|
||||
reflectVal := reflect.ValueOf(input)
|
||||
|
||||
// Sum values plus their indices: (0+10) + (1+20) + (2+30) = 63
|
||||
reducer := ReduceWithIndex(func(i int, acc int, v reflect.Value) int {
|
||||
return acc + i + int(v.Int())
|
||||
}, 0)
|
||||
|
||||
result := reducer(reflectVal)
|
||||
assert.Equal(t, 63, result)
|
||||
}
|
||||
|
||||
// TestReduceWithIndex_StringConcat tests concatenating strings with indices
|
||||
func TestReduceWithIndex_StringConcat(t *testing.T) {
|
||||
input := []string{"a", "b", "c"}
|
||||
reflectVal := reflect.ValueOf(input)
|
||||
|
||||
// Concatenate with indices: "0:a,1:b,2:c"
|
||||
reducer := ReduceWithIndex(func(i int, acc string, v reflect.Value) string {
|
||||
if acc == "" {
|
||||
return string(rune('0'+i)) + ":" + v.String()
|
||||
}
|
||||
return acc + "," + string(rune('0'+i)) + ":" + v.String()
|
||||
}, "")
|
||||
|
||||
result := reducer(reflectVal)
|
||||
assert.Equal(t, "0:a,1:b,2:c", result)
|
||||
}
|
||||
|
||||
// TestReduceWithIndex_EmptySlice tests reducing an empty slice
|
||||
func TestReduceWithIndex_EmptySlice(t *testing.T) {
|
||||
input := []int{}
|
||||
reflectVal := reflect.ValueOf(input)
|
||||
|
||||
reducer := ReduceWithIndex(func(i int, acc int, v reflect.Value) int {
|
||||
return acc + int(v.Int())
|
||||
}, 42)
|
||||
|
||||
result := reducer(reflectVal)
|
||||
assert.Equal(t, 42, result, "Should return initial value for empty slice")
|
||||
}
|
||||
|
||||
// TestReduceWithIndex_SingleElement tests reducing a single-element slice
|
||||
func TestReduceWithIndex_SingleElement(t *testing.T) {
|
||||
input := []int{100}
|
||||
reflectVal := reflect.ValueOf(input)
|
||||
|
||||
reducer := ReduceWithIndex(func(i int, acc int, v reflect.Value) int {
|
||||
return acc + i + int(v.Int())
|
||||
}, 0)
|
||||
|
||||
result := reducer(reflectVal)
|
||||
assert.Equal(t, 100, result, "Should process single element correctly")
|
||||
}
|
||||
|
||||
// TestReduceWithIndex_BuildStruct tests building a complex structure
|
||||
func TestReduceWithIndex_BuildStruct(t *testing.T) {
|
||||
type Result struct {
|
||||
Sum int
|
||||
Count int
|
||||
}
|
||||
|
||||
input := []int{5, 10, 15}
|
||||
reflectVal := reflect.ValueOf(input)
|
||||
|
||||
reducer := ReduceWithIndex(func(i int, acc Result, v reflect.Value) Result {
|
||||
return Result{
|
||||
Sum: acc.Sum + int(v.Int()),
|
||||
Count: acc.Count + 1,
|
||||
}
|
||||
}, Result{Sum: 0, Count: 0})
|
||||
|
||||
result := reducer(reflectVal)
|
||||
assert.Equal(t, 30, result.Sum)
|
||||
assert.Equal(t, 3, result.Count)
|
||||
}
|
||||
|
||||
// TestReduce_IntSum tests basic integer summation
|
||||
func TestReduce_IntSum(t *testing.T) {
|
||||
input := []int{10, 20, 30}
|
||||
reflectVal := reflect.ValueOf(input)
|
||||
|
||||
reducer := Reduce(func(acc int, v reflect.Value) int {
|
||||
return acc + int(v.Int())
|
||||
}, 0)
|
||||
|
||||
result := reducer(reflectVal)
|
||||
assert.Equal(t, 60, result)
|
||||
}
|
||||
|
||||
// TestReduce_IntProduct tests integer multiplication
|
||||
func TestReduce_IntProduct(t *testing.T) {
|
||||
input := []int{2, 3, 4}
|
||||
reflectVal := reflect.ValueOf(input)
|
||||
|
||||
reducer := Reduce(func(acc int, v reflect.Value) int {
|
||||
return acc * int(v.Int())
|
||||
}, 1)
|
||||
|
||||
result := reducer(reflectVal)
|
||||
assert.Equal(t, 24, result)
|
||||
}
|
||||
|
||||
// TestReduce_StringConcat tests string concatenation
|
||||
func TestReduce_StringConcat(t *testing.T) {
|
||||
input := []string{"Hello", " ", "World"}
|
||||
reflectVal := reflect.ValueOf(input)
|
||||
|
||||
reducer := Reduce(func(acc string, v reflect.Value) string {
|
||||
return acc + v.String()
|
||||
}, "")
|
||||
|
||||
result := reducer(reflectVal)
|
||||
assert.Equal(t, "Hello World", result)
|
||||
}
|
||||
|
||||
// TestReduce_EmptySlice tests reducing an empty slice
|
||||
func TestReduce_EmptySlice(t *testing.T) {
|
||||
input := []int{}
|
||||
reflectVal := reflect.ValueOf(input)
|
||||
|
||||
reducer := Reduce(func(acc int, v reflect.Value) int {
|
||||
return acc + int(v.Int())
|
||||
}, 100)
|
||||
|
||||
result := reducer(reflectVal)
|
||||
assert.Equal(t, 100, result, "Should return initial value for empty slice")
|
||||
}
|
||||
|
||||
// TestReduce_FindMax tests finding maximum value
|
||||
func TestReduce_FindMax(t *testing.T) {
|
||||
input := []int{3, 7, 2, 9, 1, 5}
|
||||
reflectVal := reflect.ValueOf(input)
|
||||
|
||||
reducer := Reduce(func(acc int, v reflect.Value) int {
|
||||
val := int(v.Int())
|
||||
if val > acc {
|
||||
return val
|
||||
}
|
||||
return acc
|
||||
}, input[0])
|
||||
|
||||
result := reducer(reflectVal)
|
||||
assert.Equal(t, 9, result)
|
||||
}
|
||||
|
||||
// TestReduce_CountElements tests counting elements matching a condition
|
||||
func TestReduce_CountElements(t *testing.T) {
|
||||
input := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
|
||||
reflectVal := reflect.ValueOf(input)
|
||||
|
||||
// Count even numbers
|
||||
reducer := Reduce(func(acc int, v reflect.Value) int {
|
||||
if int(v.Int())%2 == 0 {
|
||||
return acc + 1
|
||||
}
|
||||
return acc
|
||||
}, 0)
|
||||
|
||||
result := reducer(reflectVal)
|
||||
assert.Equal(t, 5, result, "Should count 5 even numbers")
|
||||
}
|
||||
|
||||
// TestMap_IntToString tests mapping integers to strings
|
||||
func TestMap_IntToString(t *testing.T) {
|
||||
input := []int{1, 2, 3}
|
||||
reflectVal := reflect.ValueOf(input)
|
||||
|
||||
mapper := Map(func(v reflect.Value) string {
|
||||
return "num:" + string(rune('0'+int(v.Int())))
|
||||
})
|
||||
|
||||
result := mapper(reflectVal)
|
||||
expected := []string{"num:1", "num:2", "num:3"}
|
||||
assert.Equal(t, expected, result)
|
||||
}
|
||||
|
||||
// TestMap_DoubleInts tests doubling integer values
|
||||
func TestMap_DoubleInts(t *testing.T) {
|
||||
input := []int{1, 2, 3, 4, 5}
|
||||
reflectVal := reflect.ValueOf(input)
|
||||
|
||||
mapper := Map(func(v reflect.Value) int {
|
||||
return int(v.Int()) * 2
|
||||
})
|
||||
|
||||
result := mapper(reflectVal)
|
||||
expected := []int{2, 4, 6, 8, 10}
|
||||
assert.Equal(t, expected, result)
|
||||
}
|
||||
|
||||
// TestMap_ExtractField tests extracting a field from structs
|
||||
func TestMap_ExtractField(t *testing.T) {
|
||||
type Person struct {
|
||||
Name string
|
||||
Age int
|
||||
}
|
||||
|
||||
input := []Person{
|
||||
{Name: "Alice", Age: 30},
|
||||
{Name: "Bob", Age: 25},
|
||||
{Name: "Charlie", Age: 35},
|
||||
}
|
||||
reflectVal := reflect.ValueOf(input)
|
||||
|
||||
mapper := Map(func(v reflect.Value) string {
|
||||
return v.FieldByName("Name").String()
|
||||
})
|
||||
|
||||
result := mapper(reflectVal)
|
||||
expected := []string{"Alice", "Bob", "Charlie"}
|
||||
assert.Equal(t, expected, result)
|
||||
}
|
||||
|
||||
// TestMap_EmptySlice tests mapping an empty slice
|
||||
func TestMap_EmptySlice(t *testing.T) {
|
||||
input := []int{}
|
||||
reflectVal := reflect.ValueOf(input)
|
||||
|
||||
mapper := Map(func(v reflect.Value) int {
|
||||
return int(v.Int()) * 2
|
||||
})
|
||||
|
||||
result := mapper(reflectVal)
|
||||
assert.Empty(t, result, "Should return empty slice")
|
||||
assert.NotNil(t, result, "Should not return nil")
|
||||
}
|
||||
|
||||
// TestMap_SingleElement tests mapping a single-element slice
|
||||
func TestMap_SingleElement(t *testing.T) {
|
||||
input := []int{42}
|
||||
reflectVal := reflect.ValueOf(input)
|
||||
|
||||
mapper := Map(func(v reflect.Value) int {
|
||||
return int(v.Int()) * 2
|
||||
})
|
||||
|
||||
result := mapper(reflectVal)
|
||||
expected := []int{84}
|
||||
assert.Equal(t, expected, result)
|
||||
}
|
||||
|
||||
// TestMap_BoolToInt tests mapping booleans to integers
|
||||
func TestMap_BoolToInt(t *testing.T) {
|
||||
input := []bool{true, false, true, true, false}
|
||||
reflectVal := reflect.ValueOf(input)
|
||||
|
||||
mapper := Map(func(v reflect.Value) int {
|
||||
if v.Bool() {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
})
|
||||
|
||||
result := mapper(reflectVal)
|
||||
expected := []int{1, 0, 1, 1, 0}
|
||||
assert.Equal(t, expected, result)
|
||||
}
|
||||
|
||||
// TestMap_ComplexTransformation tests a complex transformation
|
||||
func TestMap_ComplexTransformation(t *testing.T) {
|
||||
input := []int{1, 2, 3, 4, 5}
|
||||
reflectVal := reflect.ValueOf(input)
|
||||
|
||||
type Result struct {
|
||||
Original int
|
||||
Squared int
|
||||
IsEven bool
|
||||
}
|
||||
|
||||
mapper := Map(func(v reflect.Value) Result {
|
||||
val := int(v.Int())
|
||||
return Result{
|
||||
Original: val,
|
||||
Squared: val * val,
|
||||
IsEven: val%2 == 0,
|
||||
}
|
||||
})
|
||||
|
||||
result := mapper(reflectVal)
|
||||
assert.Len(t, result, 5)
|
||||
assert.Equal(t, 1, result[0].Original)
|
||||
assert.Equal(t, 1, result[0].Squared)
|
||||
assert.False(t, result[0].IsEven)
|
||||
assert.Equal(t, 4, result[3].Original)
|
||||
assert.Equal(t, 16, result[3].Squared)
|
||||
assert.True(t, result[3].IsEven)
|
||||
}
|
||||
|
||||
// TestMap_StringLength tests mapping strings to their lengths
|
||||
func TestMap_StringLength(t *testing.T) {
|
||||
input := []string{"a", "ab", "abc", "abcd"}
|
||||
reflectVal := reflect.ValueOf(input)
|
||||
|
||||
mapper := Map(func(v reflect.Value) int {
|
||||
return len(v.String())
|
||||
})
|
||||
|
||||
result := mapper(reflectVal)
|
||||
expected := []int{1, 2, 3, 4}
|
||||
assert.Equal(t, expected, result)
|
||||
}
|
||||
|
||||
// TestIntegration_MapThenReduce tests combining Map and Reduce operations
|
||||
func TestIntegration_MapThenReduce(t *testing.T) {
|
||||
input := []int{1, 2, 3, 4, 5}
|
||||
reflectVal := reflect.ValueOf(input)
|
||||
|
||||
// First map: square each number
|
||||
mapper := Map(func(v reflect.Value) int {
|
||||
val := int(v.Int())
|
||||
return val * val
|
||||
})
|
||||
squared := mapper(reflectVal)
|
||||
|
||||
// Then reduce: sum all squared values
|
||||
squaredReflect := reflect.ValueOf(squared)
|
||||
reducer := Reduce(func(acc int, v reflect.Value) int {
|
||||
return acc + int(v.Int())
|
||||
}, 0)
|
||||
result := reducer(squaredReflect)
|
||||
|
||||
// 1^2 + 2^2 + 3^2 + 4^2 + 5^2 = 1 + 4 + 9 + 16 + 25 = 55
|
||||
assert.Equal(t, 55, result)
|
||||
}
|
||||
|
||||
// TestIntegration_ReduceWithIndexToMap tests using ReduceWithIndex to build a map
|
||||
func TestIntegration_ReduceWithIndexToMap(t *testing.T) {
|
||||
input := []string{"apple", "banana", "cherry"}
|
||||
reflectVal := reflect.ValueOf(input)
|
||||
|
||||
// Build a map with index as key
|
||||
reducer := ReduceWithIndex(func(i int, acc map[int]string, v reflect.Value) map[int]string {
|
||||
acc[i] = v.String()
|
||||
return acc
|
||||
}, make(map[int]string))
|
||||
|
||||
result := reducer(reflectVal)
|
||||
expected := map[int]string{
|
||||
0: "apple",
|
||||
1: "banana",
|
||||
2: "cherry",
|
||||
}
|
||||
assert.Equal(t, expected, result)
|
||||
}
|
||||
@@ -52,3 +52,61 @@ func AlternativeMonoid[A any](m M.Monoid[A]) Monoid[A] {
|
||||
func AltMonoid[A any](zero Lazy[Result[A]]) Monoid[A] {
|
||||
return either.AltMonoid(zero)
|
||||
}
|
||||
|
||||
// FirstMonoid creates a Monoid for Result[A] that returns the first Ok (Right) value.
|
||||
// This monoid prefers the left operand when it is Ok, otherwise returns the right operand.
|
||||
// The empty value is provided as a lazy computation.
|
||||
//
|
||||
// This is equivalent to AltMonoid but implemented more directly.
|
||||
//
|
||||
// Truth table:
|
||||
//
|
||||
// | x | y | concat(x, y) |
|
||||
// | --------- | --------- | ------------ |
|
||||
// | err(e1) | err(e2) | err(e2) |
|
||||
// | ok(a) | err(e) | ok(a) |
|
||||
// | err(e) | ok(b) | ok(b) |
|
||||
// | ok(a) | ok(b) | ok(a) |
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// import "errors"
|
||||
// zero := func() result.Result[int] { return result.Error[int](errors.New("empty")) }
|
||||
// m := result.FirstMonoid[int](zero)
|
||||
// m.Concat(result.Of(2), result.Of(3)) // Ok(2) - returns first Ok
|
||||
// m.Concat(result.Error[int](errors.New("err")), result.Of(3)) // Ok(3)
|
||||
// m.Concat(result.Of(2), result.Error[int](errors.New("err"))) // Ok(2)
|
||||
// m.Empty() // Error(error("empty"))
|
||||
//
|
||||
//go:inline
|
||||
func FirstMonoid[A any](zero Lazy[Result[A]]) M.Monoid[Result[A]] {
|
||||
return either.FirstMonoid(zero)
|
||||
}
|
||||
|
||||
// LastMonoid creates a Monoid for Result[A] that returns the last Ok (Right) value.
|
||||
// This monoid prefers the right operand when it is Ok, otherwise returns the left operand.
|
||||
// The empty value is provided as a lazy computation.
|
||||
//
|
||||
// Truth table:
|
||||
//
|
||||
// | x | y | concat(x, y) |
|
||||
// | --------- | --------- | ------------ |
|
||||
// | err(e1) | err(e2) | err(e1) |
|
||||
// | ok(a) | err(e) | ok(a) |
|
||||
// | err(e) | ok(b) | ok(b) |
|
||||
// | ok(a) | ok(b) | ok(b) |
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// import "errors"
|
||||
// zero := func() result.Result[int] { return result.Error[int](errors.New("empty")) }
|
||||
// m := result.LastMonoid[int](zero)
|
||||
// m.Concat(result.Of(2), result.Of(3)) // Ok(3) - returns last Ok
|
||||
// m.Concat(result.Error[int](errors.New("err")), result.Of(3)) // Ok(3)
|
||||
// m.Concat(result.Of(2), result.Error[int](errors.New("err"))) // Ok(2)
|
||||
// m.Empty() // Error(error("empty"))
|
||||
//
|
||||
//go:inline
|
||||
func LastMonoid[A any](zero Lazy[Result[A]]) M.Monoid[Result[A]] {
|
||||
return either.LastMonoid(zero)
|
||||
}
|
||||
|
||||
498
v2/result/monoid_test.go
Normal file
498
v2/result/monoid_test.go
Normal file
@@ -0,0 +1,498 @@
|
||||
// Copyright (c) 2025 IBM Corp.
|
||||
// All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package result
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestFirstMonoid tests the FirstMonoid implementation
|
||||
func TestFirstMonoid(t *testing.T) {
|
||||
zero := func() Result[int] { return Left[int](errors.New("empty")) }
|
||||
m := FirstMonoid[int](zero)
|
||||
|
||||
t.Run("both Right values - returns first", func(t *testing.T) {
|
||||
result := m.Concat(Right(2), Right(3))
|
||||
assert.Equal(t, Right(2), result)
|
||||
})
|
||||
|
||||
t.Run("left Right, right Left", func(t *testing.T) {
|
||||
result := m.Concat(Right(2), Left[int](errors.New("err")))
|
||||
assert.Equal(t, Right(2), result)
|
||||
})
|
||||
|
||||
t.Run("left Left, right Right", func(t *testing.T) {
|
||||
result := m.Concat(Left[int](errors.New("err")), Right(3))
|
||||
assert.Equal(t, Right(3), result)
|
||||
})
|
||||
|
||||
t.Run("both Left", func(t *testing.T) {
|
||||
err1 := errors.New("err1")
|
||||
err2 := errors.New("err2")
|
||||
result := m.Concat(Left[int](err1), Left[int](err2))
|
||||
// Should return the second Left
|
||||
assert.True(t, IsLeft(result))
|
||||
_, leftErr := Unwrap(result)
|
||||
assert.Equal(t, err2, leftErr)
|
||||
})
|
||||
|
||||
t.Run("empty value", func(t *testing.T) {
|
||||
empty := m.Empty()
|
||||
assert.True(t, IsLeft(empty))
|
||||
_, leftErr := Unwrap(empty)
|
||||
assert.Equal(t, "empty", leftErr.Error())
|
||||
})
|
||||
|
||||
t.Run("left identity", func(t *testing.T) {
|
||||
x := Right(5)
|
||||
result := m.Concat(m.Empty(), x)
|
||||
assert.Equal(t, x, result)
|
||||
})
|
||||
|
||||
t.Run("right identity", func(t *testing.T) {
|
||||
x := Right(5)
|
||||
result := m.Concat(x, m.Empty())
|
||||
assert.Equal(t, x, result)
|
||||
})
|
||||
|
||||
t.Run("associativity", func(t *testing.T) {
|
||||
a := Right(1)
|
||||
b := Right(2)
|
||||
c := Right(3)
|
||||
|
||||
left := m.Concat(m.Concat(a, b), c)
|
||||
right := m.Concat(a, m.Concat(b, c))
|
||||
|
||||
assert.Equal(t, left, right)
|
||||
assert.Equal(t, Right(1), left)
|
||||
})
|
||||
|
||||
t.Run("multiple concatenations", func(t *testing.T) {
|
||||
// Should return the first Right value encountered
|
||||
result := m.Concat(
|
||||
m.Concat(Left[int](errors.New("err1")), Right(1)),
|
||||
m.Concat(Right(2), Right(3)),
|
||||
)
|
||||
assert.Equal(t, Right(1), result)
|
||||
})
|
||||
|
||||
t.Run("with strings", func(t *testing.T) {
|
||||
zeroStr := func() Result[string] { return Left[string](errors.New("empty")) }
|
||||
strMonoid := FirstMonoid[string](zeroStr)
|
||||
|
||||
result := strMonoid.Concat(Right("first"), Right("second"))
|
||||
assert.Equal(t, Right("first"), result)
|
||||
|
||||
result = strMonoid.Concat(Left[string](errors.New("err")), Right("second"))
|
||||
assert.Equal(t, Right("second"), result)
|
||||
})
|
||||
}
|
||||
|
||||
// TestLastMonoid tests the LastMonoid implementation
|
||||
func TestLastMonoid(t *testing.T) {
|
||||
zero := func() Result[int] { return Left[int](errors.New("empty")) }
|
||||
m := LastMonoid[int](zero)
|
||||
|
||||
t.Run("both Right values - returns last", func(t *testing.T) {
|
||||
result := m.Concat(Right(2), Right(3))
|
||||
assert.Equal(t, Right(3), result)
|
||||
})
|
||||
|
||||
t.Run("left Right, right Left", func(t *testing.T) {
|
||||
result := m.Concat(Right(2), Left[int](errors.New("err")))
|
||||
assert.Equal(t, Right(2), result)
|
||||
})
|
||||
|
||||
t.Run("left Left, right Right", func(t *testing.T) {
|
||||
result := m.Concat(Left[int](errors.New("err")), Right(3))
|
||||
assert.Equal(t, Right(3), result)
|
||||
})
|
||||
|
||||
t.Run("both Left", func(t *testing.T) {
|
||||
err1 := errors.New("err1")
|
||||
err2 := errors.New("err2")
|
||||
result := m.Concat(Left[int](err1), Left[int](err2))
|
||||
// Should return the first Left
|
||||
assert.True(t, IsLeft(result))
|
||||
_, leftErr := Unwrap(result)
|
||||
assert.Equal(t, err1, leftErr)
|
||||
})
|
||||
|
||||
t.Run("empty value", func(t *testing.T) {
|
||||
empty := m.Empty()
|
||||
assert.True(t, IsLeft(empty))
|
||||
_, leftErr := Unwrap(empty)
|
||||
assert.Equal(t, "empty", leftErr.Error())
|
||||
})
|
||||
|
||||
t.Run("left identity", func(t *testing.T) {
|
||||
x := Right(5)
|
||||
result := m.Concat(m.Empty(), x)
|
||||
assert.Equal(t, x, result)
|
||||
})
|
||||
|
||||
t.Run("right identity", func(t *testing.T) {
|
||||
x := Right(5)
|
||||
result := m.Concat(x, m.Empty())
|
||||
assert.Equal(t, x, result)
|
||||
})
|
||||
|
||||
t.Run("associativity", func(t *testing.T) {
|
||||
a := Right(1)
|
||||
b := Right(2)
|
||||
c := Right(3)
|
||||
|
||||
left := m.Concat(m.Concat(a, b), c)
|
||||
right := m.Concat(a, m.Concat(b, c))
|
||||
|
||||
assert.Equal(t, left, right)
|
||||
assert.Equal(t, Right(3), left)
|
||||
})
|
||||
|
||||
t.Run("multiple concatenations", func(t *testing.T) {
|
||||
// Should return the last Right value encountered
|
||||
result := m.Concat(
|
||||
m.Concat(Right(1), Right(2)),
|
||||
m.Concat(Right(3), Left[int](errors.New("err"))),
|
||||
)
|
||||
assert.Equal(t, Right(3), result)
|
||||
})
|
||||
|
||||
t.Run("with strings", func(t *testing.T) {
|
||||
zeroStr := func() Result[string] { return Left[string](errors.New("empty")) }
|
||||
strMonoid := LastMonoid[string](zeroStr)
|
||||
|
||||
result := strMonoid.Concat(Right("first"), Right("second"))
|
||||
assert.Equal(t, Right("second"), result)
|
||||
|
||||
result = strMonoid.Concat(Right("first"), Left[string](errors.New("err")))
|
||||
assert.Equal(t, Right("first"), result)
|
||||
})
|
||||
}
|
||||
|
||||
// TestAltMonoid tests the AltMonoid implementation
|
||||
func TestAltMonoid(t *testing.T) {
|
||||
zero := func() Result[int] { return Left[int](errors.New("empty")) }
|
||||
m := AltMonoid[int](zero)
|
||||
|
||||
t.Run("both Right values - returns first", func(t *testing.T) {
|
||||
result := m.Concat(Right(2), Right(3))
|
||||
assert.Equal(t, Right(2), result)
|
||||
})
|
||||
|
||||
t.Run("left Right, right Left", func(t *testing.T) {
|
||||
result := m.Concat(Right(2), Left[int](errors.New("err")))
|
||||
assert.Equal(t, Right(2), result)
|
||||
})
|
||||
|
||||
t.Run("left Left, right Right", func(t *testing.T) {
|
||||
result := m.Concat(Left[int](errors.New("err")), Right(3))
|
||||
assert.Equal(t, Right(3), result)
|
||||
})
|
||||
|
||||
t.Run("both Left", func(t *testing.T) {
|
||||
err1 := errors.New("err1")
|
||||
err2 := errors.New("err2")
|
||||
result := m.Concat(Left[int](err1), Left[int](err2))
|
||||
// Should return the second Left
|
||||
assert.True(t, IsLeft(result))
|
||||
_, leftErr := Unwrap(result)
|
||||
assert.Equal(t, err2, leftErr)
|
||||
})
|
||||
|
||||
t.Run("empty value", func(t *testing.T) {
|
||||
empty := m.Empty()
|
||||
assert.True(t, IsLeft(empty))
|
||||
_, leftErr := Unwrap(empty)
|
||||
assert.Equal(t, "empty", leftErr.Error())
|
||||
})
|
||||
|
||||
t.Run("left identity", func(t *testing.T) {
|
||||
x := Right(5)
|
||||
result := m.Concat(m.Empty(), x)
|
||||
assert.Equal(t, x, result)
|
||||
})
|
||||
|
||||
t.Run("right identity", func(t *testing.T) {
|
||||
x := Right(5)
|
||||
result := m.Concat(x, m.Empty())
|
||||
assert.Equal(t, x, result)
|
||||
})
|
||||
|
||||
t.Run("associativity", func(t *testing.T) {
|
||||
a := Right(1)
|
||||
b := Left[int](errors.New("err"))
|
||||
c := Right(3)
|
||||
|
||||
left := m.Concat(m.Concat(a, b), c)
|
||||
right := m.Concat(a, m.Concat(b, c))
|
||||
|
||||
assert.Equal(t, left, right)
|
||||
assert.Equal(t, Right(1), left)
|
||||
})
|
||||
}
|
||||
|
||||
// TestFirstMonoidVsAltMonoid verifies FirstMonoid and AltMonoid have the same behavior
|
||||
func TestFirstMonoidVsAltMonoid(t *testing.T) {
|
||||
zero := func() Result[int] { return Left[int](errors.New("empty")) }
|
||||
firstMonoid := FirstMonoid[int](zero)
|
||||
altMonoid := AltMonoid[int](zero)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
left Result[int]
|
||||
right Result[int]
|
||||
}{
|
||||
{"both Right", Right(1), Right(2)},
|
||||
{"left Right, right Left", Right(1), Left[int](errors.New("err"))},
|
||||
{"left Left, right Right", Left[int](errors.New("err")), Right(2)},
|
||||
{"both Left", Left[int](errors.New("err1")), Left[int](errors.New("err2"))},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
firstResult := firstMonoid.Concat(tc.left, tc.right)
|
||||
altResult := altMonoid.Concat(tc.left, tc.right)
|
||||
|
||||
// Both should have the same Right/Left status
|
||||
assert.Equal(t, IsRight(firstResult), IsRight(altResult), "FirstMonoid and AltMonoid should have same Right/Left status")
|
||||
|
||||
if IsRight(firstResult) {
|
||||
rightVal1, _ := Unwrap(firstResult)
|
||||
rightVal2, _ := Unwrap(altResult)
|
||||
assert.Equal(t, rightVal1, rightVal2, "FirstMonoid and AltMonoid should have same Right value")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestFirstMonoidVsLastMonoid verifies the difference between FirstMonoid and LastMonoid
|
||||
func TestFirstMonoidVsLastMonoid(t *testing.T) {
|
||||
zero := func() Result[int] { return Left[int](errors.New("empty")) }
|
||||
firstMonoid := FirstMonoid[int](zero)
|
||||
lastMonoid := LastMonoid[int](zero)
|
||||
|
||||
t.Run("both Right - different results", func(t *testing.T) {
|
||||
firstResult := firstMonoid.Concat(Right(1), Right(2))
|
||||
lastResult := lastMonoid.Concat(Right(1), Right(2))
|
||||
|
||||
assert.Equal(t, Right(1), firstResult)
|
||||
assert.Equal(t, Right(2), lastResult)
|
||||
assert.NotEqual(t, firstResult, lastResult)
|
||||
})
|
||||
|
||||
t.Run("with Left values - different behavior", func(t *testing.T) {
|
||||
err1 := errors.New("err1")
|
||||
err2 := errors.New("err2")
|
||||
|
||||
// Both Left: FirstMonoid returns second, LastMonoid returns first
|
||||
firstResult := firstMonoid.Concat(Left[int](err1), Left[int](err2))
|
||||
lastResult := lastMonoid.Concat(Left[int](err1), Left[int](err2))
|
||||
|
||||
assert.True(t, IsLeft(firstResult))
|
||||
assert.True(t, IsLeft(lastResult))
|
||||
_, leftErr1 := Unwrap(firstResult)
|
||||
_, leftErr2 := Unwrap(lastResult)
|
||||
assert.Equal(t, err2, leftErr1)
|
||||
assert.Equal(t, err1, leftErr2)
|
||||
})
|
||||
|
||||
t.Run("mixed values - same results", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
left Result[int]
|
||||
right Result[int]
|
||||
expected Result[int]
|
||||
}{
|
||||
{"left Right, right Left", Right(1), Left[int](errors.New("err")), Right(1)},
|
||||
{"left Left, right Right", Left[int](errors.New("err")), Right(2), Right(2)},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
firstResult := firstMonoid.Concat(tc.left, tc.right)
|
||||
lastResult := lastMonoid.Concat(tc.left, tc.right)
|
||||
|
||||
assert.Equal(t, tc.expected, firstResult)
|
||||
assert.Equal(t, tc.expected, lastResult)
|
||||
assert.Equal(t, firstResult, lastResult)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestMonoidLaws verifies monoid laws for all monoid implementations
|
||||
func TestMonoidLaws(t *testing.T) {
|
||||
t.Run("FirstMonoid laws", func(t *testing.T) {
|
||||
zero := func() Result[int] { return Left[int](errors.New("empty")) }
|
||||
m := FirstMonoid[int](zero)
|
||||
|
||||
a := Right(1)
|
||||
b := Right(2)
|
||||
c := Right(3)
|
||||
|
||||
// Associativity: (a • b) • c = a • (b • c)
|
||||
left := m.Concat(m.Concat(a, b), c)
|
||||
right := m.Concat(a, m.Concat(b, c))
|
||||
assert.Equal(t, left, right)
|
||||
|
||||
// Left identity: Empty() • a = a
|
||||
leftId := m.Concat(m.Empty(), a)
|
||||
assert.Equal(t, a, leftId)
|
||||
|
||||
// Right identity: a • Empty() = a
|
||||
rightId := m.Concat(a, m.Empty())
|
||||
assert.Equal(t, a, rightId)
|
||||
})
|
||||
|
||||
t.Run("LastMonoid laws", func(t *testing.T) {
|
||||
zero := func() Result[int] { return Left[int](errors.New("empty")) }
|
||||
m := LastMonoid[int](zero)
|
||||
|
||||
a := Right(1)
|
||||
b := Right(2)
|
||||
c := Right(3)
|
||||
|
||||
// Associativity: (a • b) • c = a • (b • c)
|
||||
left := m.Concat(m.Concat(a, b), c)
|
||||
right := m.Concat(a, m.Concat(b, c))
|
||||
assert.Equal(t, left, right)
|
||||
|
||||
// Left identity: Empty() • a = a
|
||||
leftId := m.Concat(m.Empty(), a)
|
||||
assert.Equal(t, a, leftId)
|
||||
|
||||
// Right identity: a • Empty() = a
|
||||
rightId := m.Concat(a, m.Empty())
|
||||
assert.Equal(t, a, rightId)
|
||||
})
|
||||
|
||||
t.Run("AltMonoid laws", func(t *testing.T) {
|
||||
zero := func() Result[int] { return Left[int](errors.New("empty")) }
|
||||
m := AltMonoid[int](zero)
|
||||
|
||||
a := Right(1)
|
||||
b := Right(2)
|
||||
c := Right(3)
|
||||
|
||||
// Associativity: (a • b) • c = a • (b • c)
|
||||
left := m.Concat(m.Concat(a, b), c)
|
||||
right := m.Concat(a, m.Concat(b, c))
|
||||
assert.Equal(t, left, right)
|
||||
|
||||
// Left identity: Empty() • a = a
|
||||
leftId := m.Concat(m.Empty(), a)
|
||||
assert.Equal(t, a, leftId)
|
||||
|
||||
// Right identity: a • Empty() = a
|
||||
rightId := m.Concat(a, m.Empty())
|
||||
assert.Equal(t, a, rightId)
|
||||
})
|
||||
|
||||
t.Run("FirstMonoid laws with Left values", func(t *testing.T) {
|
||||
zero := func() Result[int] { return Left[int](errors.New("empty")) }
|
||||
m := FirstMonoid[int](zero)
|
||||
|
||||
a := Left[int](errors.New("err1"))
|
||||
b := Left[int](errors.New("err2"))
|
||||
c := Left[int](errors.New("err3"))
|
||||
|
||||
// Associativity with Left values
|
||||
left := m.Concat(m.Concat(a, b), c)
|
||||
right := m.Concat(a, m.Concat(b, c))
|
||||
assert.Equal(t, left, right)
|
||||
})
|
||||
|
||||
t.Run("LastMonoid laws with Left values", func(t *testing.T) {
|
||||
zero := func() Result[int] { return Left[int](errors.New("empty")) }
|
||||
m := LastMonoid[int](zero)
|
||||
|
||||
a := Left[int](errors.New("err1"))
|
||||
b := Left[int](errors.New("err2"))
|
||||
c := Left[int](errors.New("err3"))
|
||||
|
||||
// Associativity with Left values
|
||||
left := m.Concat(m.Concat(a, b), c)
|
||||
right := m.Concat(a, m.Concat(b, c))
|
||||
assert.Equal(t, left, right)
|
||||
})
|
||||
}
|
||||
|
||||
// TestMonoidEdgeCases tests edge cases for monoid operations
|
||||
func TestMonoidEdgeCases(t *testing.T) {
|
||||
t.Run("FirstMonoid with empty concatenations", func(t *testing.T) {
|
||||
zero := func() Result[int] { return Left[int](errors.New("empty")) }
|
||||
m := FirstMonoid[int](zero)
|
||||
|
||||
// Empty with empty
|
||||
result := m.Concat(m.Empty(), m.Empty())
|
||||
assert.True(t, IsLeft(result))
|
||||
})
|
||||
|
||||
t.Run("LastMonoid with empty concatenations", func(t *testing.T) {
|
||||
zero := func() Result[int] { return Left[int](errors.New("empty")) }
|
||||
m := LastMonoid[int](zero)
|
||||
|
||||
// Empty with empty
|
||||
result := m.Concat(m.Empty(), m.Empty())
|
||||
assert.True(t, IsLeft(result))
|
||||
})
|
||||
|
||||
t.Run("FirstMonoid chain of operations", func(t *testing.T) {
|
||||
zero := func() Result[int] { return Left[int](errors.New("empty")) }
|
||||
m := FirstMonoid[int](zero)
|
||||
|
||||
// Chain multiple operations
|
||||
result := m.Concat(
|
||||
m.Concat(
|
||||
m.Concat(Left[int](errors.New("err1")), Left[int](errors.New("err2"))),
|
||||
Right(1),
|
||||
),
|
||||
m.Concat(Right(2), Right(3)),
|
||||
)
|
||||
assert.Equal(t, Right(1), result)
|
||||
})
|
||||
|
||||
t.Run("LastMonoid chain of operations", func(t *testing.T) {
|
||||
zero := func() Result[int] { return Left[int](errors.New("empty")) }
|
||||
m := LastMonoid[int](zero)
|
||||
|
||||
// Chain multiple operations
|
||||
result := m.Concat(
|
||||
m.Concat(Right(1), Right(2)),
|
||||
m.Concat(
|
||||
Right(3),
|
||||
m.Concat(Right(4), Left[int](errors.New("err"))),
|
||||
),
|
||||
)
|
||||
assert.Equal(t, Right(4), result)
|
||||
})
|
||||
|
||||
t.Run("AltMonoid chain of operations", func(t *testing.T) {
|
||||
zero := func() Result[int] { return Left[int](errors.New("empty")) }
|
||||
m := AltMonoid[int](zero)
|
||||
|
||||
// Chain multiple operations - should return first Right
|
||||
result := m.Concat(
|
||||
m.Concat(Left[int](errors.New("err1")), Right(1)),
|
||||
m.Concat(Right(2), Right(3)),
|
||||
)
|
||||
assert.Equal(t, Right(1), result)
|
||||
})
|
||||
}
|
||||
349
v2/state/bind.go
349
v2/state/bind.go
@@ -1,3 +1,18 @@
|
||||
// Copyright (c) 2024 - 2025 IBM Corp.
|
||||
// All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package state
|
||||
|
||||
import (
|
||||
@@ -7,6 +22,39 @@ import (
|
||||
F "github.com/IBM/fp-go/v2/internal/functor"
|
||||
)
|
||||
|
||||
// Do initializes a do-notation computation with an empty value.
|
||||
// This is the entry point for building complex stateful computations using
|
||||
// the do-notation pattern, which allows for imperative-style sequencing of
|
||||
// monadic operations.
|
||||
//
|
||||
// The do-notation pattern is useful for building pipelines where you need to:
|
||||
// - Bind intermediate results to names
|
||||
// - Sequence multiple stateful operations
|
||||
// - Build up complex state transformations step by step
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// type MyState struct {
|
||||
// x int
|
||||
// y int
|
||||
// }
|
||||
//
|
||||
// type Result struct {
|
||||
// sum int
|
||||
// product int
|
||||
// }
|
||||
//
|
||||
// computation := function.Pipe3(
|
||||
// Do[MyState](Result{}),
|
||||
// Bind(func(r Result) func(int) Result {
|
||||
// return func(x int) Result { r.sum = x; return r }
|
||||
// }, Gets(func(s MyState) int { return s.x })),
|
||||
// Bind(func(r Result) func(int) Result {
|
||||
// return func(y int) Result { r.product = r.sum * y; return r }
|
||||
// }, Gets(func(s MyState) int { return s.y })),
|
||||
// Map[MyState](func(r Result) int { return r.product }),
|
||||
// )
|
||||
//
|
||||
//go:inline
|
||||
func Do[ST, A any](
|
||||
empty A,
|
||||
@@ -14,6 +62,40 @@ func Do[ST, A any](
|
||||
return Of[ST](empty)
|
||||
}
|
||||
|
||||
// Bind sequences a stateful computation and binds its result to a field in an
|
||||
// accumulator structure. This is a key building block for do-notation, allowing
|
||||
// you to extract values from State computations and incorporate them into a
|
||||
// growing result structure.
|
||||
//
|
||||
// The setter function takes the computed value T and returns a function that
|
||||
// updates the accumulator from S1 to S2 by setting the field to T.
|
||||
//
|
||||
// Parameters:
|
||||
// - setter: A function that takes a value T and returns a function to update
|
||||
// the accumulator structure from S1 to S2
|
||||
// - f: A Kleisli arrow that takes the current accumulator S1 and produces a
|
||||
// State computation yielding T
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// type Accumulator struct {
|
||||
// value int
|
||||
// doubled int
|
||||
// }
|
||||
//
|
||||
// // Bind the result of a computation to the 'doubled' field
|
||||
// computation := Bind(
|
||||
// func(d int) func(Accumulator) Accumulator {
|
||||
// return func(acc Accumulator) Accumulator {
|
||||
// acc.doubled = d
|
||||
// return acc
|
||||
// }
|
||||
// },
|
||||
// func(acc Accumulator) State[MyState, int] {
|
||||
// return Of[MyState](acc.value * 2)
|
||||
// },
|
||||
// )
|
||||
//
|
||||
//go:inline
|
||||
func Bind[ST, S1, S2, T any](
|
||||
setter func(T) func(S1) S2,
|
||||
@@ -27,6 +109,39 @@ func Bind[ST, S1, S2, T any](
|
||||
)
|
||||
}
|
||||
|
||||
// Let computes a pure value from the current accumulator and binds it to a field.
|
||||
// Unlike Bind, this doesn't execute a stateful computation - it simply applies a
|
||||
// pure function to the accumulator and stores the result.
|
||||
//
|
||||
// This is useful in do-notation when you need to compute derived values without
|
||||
// performing stateful operations.
|
||||
//
|
||||
// Parameters:
|
||||
// - key: A function that takes the computed value T and returns a function to
|
||||
// update the accumulator from S1 to S2
|
||||
// - f: A pure function that extracts or computes T from the current accumulator S1
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// type Accumulator struct {
|
||||
// x int
|
||||
// y int
|
||||
// sum int
|
||||
// }
|
||||
//
|
||||
// // Compute sum from x and y without state operations
|
||||
// computation := Let(
|
||||
// func(s int) func(Accumulator) Accumulator {
|
||||
// return func(acc Accumulator) Accumulator {
|
||||
// acc.sum = s
|
||||
// return acc
|
||||
// }
|
||||
// },
|
||||
// func(acc Accumulator) int {
|
||||
// return acc.x + acc.y
|
||||
// },
|
||||
// )
|
||||
//
|
||||
//go:inline
|
||||
func Let[ST, S1, S2, T any](
|
||||
key func(T) func(S1) S2,
|
||||
@@ -39,6 +154,36 @@ func Let[ST, S1, S2, T any](
|
||||
)
|
||||
}
|
||||
|
||||
// LetTo binds a constant value to a field in the accumulator.
|
||||
// This is a specialized version of Let where the value is already known
|
||||
// and doesn't need to be computed from the accumulator.
|
||||
//
|
||||
// This is useful for initializing fields with constant values or for
|
||||
// setting default values in do-notation pipelines.
|
||||
//
|
||||
// Parameters:
|
||||
// - key: A function that takes the constant value T and returns a function to
|
||||
// update the accumulator from S1 to S2
|
||||
// - b: The constant value to bind
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// type Accumulator struct {
|
||||
// status string
|
||||
// value int
|
||||
// }
|
||||
//
|
||||
// // Set a constant status
|
||||
// computation := LetTo(
|
||||
// func(s string) func(Accumulator) Accumulator {
|
||||
// return func(acc Accumulator) Accumulator {
|
||||
// acc.status = s
|
||||
// return acc
|
||||
// }
|
||||
// },
|
||||
// "initialized",
|
||||
// )
|
||||
//
|
||||
//go:inline
|
||||
func LetTo[ST, S1, S2, T any](
|
||||
key func(T) func(S1) S2,
|
||||
@@ -51,6 +196,37 @@ func LetTo[ST, S1, S2, T any](
|
||||
)
|
||||
}
|
||||
|
||||
// BindTo creates an initial accumulator structure from a value.
|
||||
// This is typically the first operation in a do-notation pipeline,
|
||||
// converting a simple value into a structure that can accumulate
|
||||
// additional fields.
|
||||
//
|
||||
// Parameters:
|
||||
// - setter: A function that takes a value T and creates an accumulator structure S1
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// type Accumulator struct {
|
||||
// initial int
|
||||
// doubled int
|
||||
// }
|
||||
//
|
||||
// // Start a pipeline by binding the initial value
|
||||
// computation := function.Pipe2(
|
||||
// Of[MyState](42),
|
||||
// BindTo(func(x int) Accumulator {
|
||||
// return Accumulator{initial: x}
|
||||
// }),
|
||||
// Bind(func(d int) func(Accumulator) Accumulator {
|
||||
// return func(acc Accumulator) Accumulator {
|
||||
// acc.doubled = d
|
||||
// return acc
|
||||
// }
|
||||
// }, func(acc Accumulator) State[MyState, int] {
|
||||
// return Of[MyState](acc.initial * 2)
|
||||
// }),
|
||||
// )
|
||||
//
|
||||
//go:inline
|
||||
func BindTo[ST, S1, T any](
|
||||
setter func(T) S1,
|
||||
@@ -61,6 +237,39 @@ func BindTo[ST, S1, T any](
|
||||
)
|
||||
}
|
||||
|
||||
// ApS applies a State computation in an applicative style and binds the result
|
||||
// to a field in the accumulator. Unlike Bind, which uses monadic sequencing,
|
||||
// ApS uses applicative composition, which can be more efficient when the
|
||||
// computation doesn't depend on the accumulator value.
|
||||
//
|
||||
// This is useful when you have independent State computations that can be
|
||||
// composed without depending on each other's results.
|
||||
//
|
||||
// Parameters:
|
||||
// - setter: A function that takes the computed value T and returns a function to
|
||||
// update the accumulator from S1 to S2
|
||||
// - fa: A State computation that produces a value of type T
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// type Accumulator struct {
|
||||
// counter int
|
||||
// timestamp int64
|
||||
// }
|
||||
//
|
||||
// // Apply an independent state computation
|
||||
// getTimestamp := Gets(func(s MyState) int64 { return s.timestamp })
|
||||
//
|
||||
// computation := ApS(
|
||||
// func(ts int64) func(Accumulator) Accumulator {
|
||||
// return func(acc Accumulator) Accumulator {
|
||||
// acc.timestamp = ts
|
||||
// return acc
|
||||
// }
|
||||
// },
|
||||
// getTimestamp,
|
||||
// )
|
||||
//
|
||||
//go:inline
|
||||
func ApS[ST, S1, S2, T any](
|
||||
setter func(T) func(S1) S2,
|
||||
@@ -74,6 +283,41 @@ func ApS[ST, S1, S2, T any](
|
||||
)
|
||||
}
|
||||
|
||||
// ApSL is a lens-based version of ApS that uses a lens to focus on a specific
|
||||
// field in the accumulator structure. This provides a more convenient and
|
||||
// type-safe way to update nested fields.
|
||||
//
|
||||
// A lens provides both a getter and setter for a field, making it easier to
|
||||
// work with complex data structures without manually writing setter functions.
|
||||
//
|
||||
// Parameters:
|
||||
// - lens: A lens focusing on field T within structure S
|
||||
// - fa: A State computation that produces a value of type T
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// type MyState struct {
|
||||
// counter int
|
||||
// }
|
||||
//
|
||||
// type Accumulator struct {
|
||||
// value int
|
||||
// doubled int
|
||||
// }
|
||||
//
|
||||
// // Create a lens for the 'doubled' field
|
||||
// doubledLens := MakeLens(
|
||||
// func(acc Accumulator) int { return acc.doubled },
|
||||
// func(d int) func(Accumulator) Accumulator {
|
||||
// return func(acc Accumulator) Accumulator {
|
||||
// acc.doubled = d
|
||||
// return acc
|
||||
// }
|
||||
// },
|
||||
// )
|
||||
//
|
||||
// computation := ApSL(doubledLens, Of[MyState](42))
|
||||
//
|
||||
//go:inline
|
||||
func ApSL[ST, S, T any](
|
||||
lens Lens[S, T],
|
||||
@@ -82,6 +326,49 @@ func ApSL[ST, S, T any](
|
||||
return ApS(lens.Set, fa)
|
||||
}
|
||||
|
||||
// BindL is a lens-based version of Bind that focuses on a specific field,
|
||||
// extracts its value, applies a stateful computation, and updates the field
|
||||
// with the result. This is particularly useful for updating nested fields
|
||||
// based on their current values.
|
||||
//
|
||||
// The computation receives the current value of the focused field and produces
|
||||
// a new value through a State computation.
|
||||
//
|
||||
// Parameters:
|
||||
// - lens: A lens focusing on field T within structure S
|
||||
// - f: A Kleisli arrow that takes the current field value and produces a
|
||||
// State computation yielding the new value
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// type MyState struct {
|
||||
// multiplier int
|
||||
// }
|
||||
//
|
||||
// type Accumulator struct {
|
||||
// value int
|
||||
// }
|
||||
//
|
||||
// valueLens := MakeLens(
|
||||
// func(acc Accumulator) int { return acc.value },
|
||||
// func(v int) func(Accumulator) Accumulator {
|
||||
// return func(acc Accumulator) Accumulator {
|
||||
// acc.value = v
|
||||
// return acc
|
||||
// }
|
||||
// },
|
||||
// )
|
||||
//
|
||||
// // Double the value using state
|
||||
// computation := BindL(
|
||||
// valueLens,
|
||||
// func(v int) State[MyState, int] {
|
||||
// return Gets(func(s MyState) int {
|
||||
// return v * s.multiplier
|
||||
// })
|
||||
// },
|
||||
// )
|
||||
//
|
||||
//go:inline
|
||||
func BindL[ST, S, T any](
|
||||
lens Lens[S, T],
|
||||
@@ -90,6 +377,40 @@ func BindL[ST, S, T any](
|
||||
return Bind(lens.Set, function.Flow2(lens.Get, f))
|
||||
}
|
||||
|
||||
// LetL is a lens-based version of Let that focuses on a specific field,
|
||||
// extracts its value, applies a pure transformation, and updates the field
|
||||
// with the result. This is useful for pure transformations of nested fields.
|
||||
//
|
||||
// Unlike BindL, this doesn't perform stateful computations - it only applies
|
||||
// a pure function to the field value.
|
||||
//
|
||||
// Parameters:
|
||||
// - lens: A lens focusing on field T within structure S
|
||||
// - f: An endomorphism (pure function from T to T) that transforms the field value
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// type Accumulator struct {
|
||||
// counter int
|
||||
// message string
|
||||
// }
|
||||
//
|
||||
// counterLens := MakeLens(
|
||||
// func(acc Accumulator) int { return acc.counter },
|
||||
// func(c int) func(Accumulator) Accumulator {
|
||||
// return func(acc Accumulator) Accumulator {
|
||||
// acc.counter = c
|
||||
// return acc
|
||||
// }
|
||||
// },
|
||||
// )
|
||||
//
|
||||
// // Increment the counter
|
||||
// computation := LetL(
|
||||
// counterLens,
|
||||
// func(c int) int { return c + 1 },
|
||||
// )
|
||||
//
|
||||
//go:inline
|
||||
func LetL[ST, S, T any](
|
||||
lens Lens[S, T],
|
||||
@@ -98,6 +419,34 @@ func LetL[ST, S, T any](
|
||||
return Let[ST](lens.Set, function.Flow2(lens.Get, f))
|
||||
}
|
||||
|
||||
// LetToL is a lens-based version of LetTo that sets a specific field to a
|
||||
// constant value. This provides a convenient way to update nested fields
|
||||
// with known values.
|
||||
//
|
||||
// Parameters:
|
||||
// - lens: A lens focusing on field T within structure S
|
||||
// - b: The constant value to set
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// type Accumulator struct {
|
||||
// status string
|
||||
// value int
|
||||
// }
|
||||
//
|
||||
// statusLens := MakeLens(
|
||||
// func(acc Accumulator) string { return acc.status },
|
||||
// func(s string) func(Accumulator) Accumulator {
|
||||
// return func(acc Accumulator) Accumulator {
|
||||
// acc.status = s
|
||||
// return acc
|
||||
// }
|
||||
// },
|
||||
// )
|
||||
//
|
||||
// // Set status to "completed"
|
||||
// computation := LetToL(statusLens, "completed")
|
||||
//
|
||||
//go:inline
|
||||
func LetToL[ST, S, T any](
|
||||
lens Lens[S, T],
|
||||
|
||||
568
v2/state/bind_test.go
Normal file
568
v2/state/bind_test.go
Normal file
@@ -0,0 +1,568 @@
|
||||
// Copyright (c) 2024 - 2025 IBM Corp.
|
||||
// All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package state
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
F "github.com/IBM/fp-go/v2/function"
|
||||
"github.com/IBM/fp-go/v2/optics/lens"
|
||||
"github.com/IBM/fp-go/v2/pair"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// Test types for bind operations
|
||||
type BindTestState struct {
|
||||
Counter int
|
||||
Multiplier int
|
||||
}
|
||||
|
||||
type Accumulator struct {
|
||||
Value int
|
||||
Doubled int
|
||||
Status string
|
||||
}
|
||||
|
||||
// TestDo verifies that Do initializes a computation with an empty value
|
||||
func TestDo(t *testing.T) {
|
||||
initial := BindTestState{Counter: 5, Multiplier: 2}
|
||||
emptyAcc := Accumulator{Value: 0, Doubled: 0, Status: ""}
|
||||
|
||||
computation := Do[BindTestState](emptyAcc)
|
||||
result := computation(initial)
|
||||
|
||||
assert.Equal(t, initial, pair.Head(result), "state should be unchanged")
|
||||
assert.Equal(t, emptyAcc, pair.Tail(result), "value should be empty accumulator")
|
||||
}
|
||||
|
||||
// TestBind verifies that Bind sequences a computation and binds the result
|
||||
func TestBind(t *testing.T) {
|
||||
initial := BindTestState{Counter: 10, Multiplier: 3}
|
||||
|
||||
// Start with an accumulator
|
||||
startAcc := Accumulator{Value: 5, Doubled: 0, Status: ""}
|
||||
|
||||
// Bind a computation that reads from state and updates the accumulator
|
||||
computation := Bind(
|
||||
func(doubled int) func(Accumulator) Accumulator {
|
||||
return func(acc Accumulator) Accumulator {
|
||||
acc.Doubled = doubled
|
||||
return acc
|
||||
}
|
||||
},
|
||||
func(acc Accumulator) State[BindTestState, int] {
|
||||
return Gets(func(s BindTestState) int {
|
||||
return acc.Value * s.Multiplier
|
||||
})
|
||||
},
|
||||
)
|
||||
|
||||
result := computation(Of[BindTestState](startAcc))(initial)
|
||||
|
||||
assert.Equal(t, initial, pair.Head(result), "state should be unchanged")
|
||||
assert.Equal(t, 5, pair.Tail(result).Value, "value should be preserved")
|
||||
assert.Equal(t, 15, pair.Tail(result).Doubled, "doubled should be 5 * 3 = 15")
|
||||
}
|
||||
|
||||
// TestBindChaining verifies chaining multiple Bind operations
|
||||
func TestBindChaining(t *testing.T) {
|
||||
initial := BindTestState{Counter: 10, Multiplier: 2}
|
||||
|
||||
computation := F.Pipe3(
|
||||
Do[BindTestState](Accumulator{}),
|
||||
Bind(
|
||||
func(v int) func(Accumulator) Accumulator {
|
||||
return func(acc Accumulator) Accumulator {
|
||||
acc.Value = v
|
||||
return acc
|
||||
}
|
||||
},
|
||||
func(acc Accumulator) State[BindTestState, int] {
|
||||
return Gets(func(s BindTestState) int { return s.Counter })
|
||||
},
|
||||
),
|
||||
Bind(
|
||||
func(d int) func(Accumulator) Accumulator {
|
||||
return func(acc Accumulator) Accumulator {
|
||||
acc.Doubled = d
|
||||
return acc
|
||||
}
|
||||
},
|
||||
func(acc Accumulator) State[BindTestState, int] {
|
||||
return Gets(func(s BindTestState) int {
|
||||
return acc.Value * s.Multiplier
|
||||
})
|
||||
},
|
||||
),
|
||||
Bind(
|
||||
func(status string) func(Accumulator) Accumulator {
|
||||
return func(acc Accumulator) Accumulator {
|
||||
acc.Status = status
|
||||
return acc
|
||||
}
|
||||
},
|
||||
func(acc Accumulator) State[BindTestState, string] {
|
||||
return Of[BindTestState]("completed")
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
result := computation(initial)
|
||||
|
||||
assert.Equal(t, 10, pair.Tail(result).Value, "value should be 10")
|
||||
assert.Equal(t, 20, pair.Tail(result).Doubled, "doubled should be 20")
|
||||
assert.Equal(t, "completed", pair.Tail(result).Status, "status should be completed")
|
||||
}
|
||||
|
||||
// TestLet verifies that Let computes a pure value and binds it
|
||||
func TestLet(t *testing.T) {
|
||||
initial := BindTestState{Counter: 5, Multiplier: 2}
|
||||
startAcc := Accumulator{Value: 10, Doubled: 20, Status: ""}
|
||||
|
||||
// Compute sum from existing fields
|
||||
computation := Let[BindTestState](
|
||||
func(sum int) func(Accumulator) Accumulator {
|
||||
return func(acc Accumulator) Accumulator {
|
||||
acc.Status = "sum"
|
||||
acc.Value = sum
|
||||
return acc
|
||||
}
|
||||
},
|
||||
func(acc Accumulator) int {
|
||||
return acc.Value + acc.Doubled
|
||||
},
|
||||
)
|
||||
|
||||
result := computation(Of[BindTestState](startAcc))(initial)
|
||||
|
||||
assert.Equal(t, initial, pair.Head(result), "state should be unchanged")
|
||||
assert.Equal(t, 30, pair.Tail(result).Value, "value should be sum: 10 + 20 = 30")
|
||||
assert.Equal(t, "sum", pair.Tail(result).Status, "status should be set")
|
||||
}
|
||||
|
||||
// TestLetTo verifies that LetTo binds a constant value
|
||||
func TestLetTo(t *testing.T) {
|
||||
initial := BindTestState{Counter: 5, Multiplier: 2}
|
||||
startAcc := Accumulator{Value: 10, Doubled: 0, Status: ""}
|
||||
|
||||
computation := LetTo[BindTestState](
|
||||
func(status string) func(Accumulator) Accumulator {
|
||||
return func(acc Accumulator) Accumulator {
|
||||
acc.Status = status
|
||||
return acc
|
||||
}
|
||||
},
|
||||
"initialized",
|
||||
)
|
||||
|
||||
result := computation(Of[BindTestState](startAcc))(initial)
|
||||
|
||||
assert.Equal(t, initial, pair.Head(result), "state should be unchanged")
|
||||
assert.Equal(t, "initialized", pair.Tail(result).Status, "status should be initialized")
|
||||
assert.Equal(t, 10, pair.Tail(result).Value, "value should be preserved")
|
||||
}
|
||||
|
||||
// TestBindTo verifies that BindTo creates an initial accumulator
|
||||
func TestBindTo(t *testing.T) {
|
||||
initial := BindTestState{Counter: 42, Multiplier: 2}
|
||||
|
||||
computation := F.Pipe1(
|
||||
Of[BindTestState](100),
|
||||
BindTo[BindTestState](func(x int) Accumulator {
|
||||
return Accumulator{Value: x, Doubled: 0, Status: "created"}
|
||||
}),
|
||||
)
|
||||
|
||||
result := computation(initial)
|
||||
|
||||
assert.Equal(t, initial, pair.Head(result), "state should be unchanged")
|
||||
assert.Equal(t, 100, pair.Tail(result).Value, "value should be 100")
|
||||
assert.Equal(t, 0, pair.Tail(result).Doubled, "doubled should be 0")
|
||||
assert.Equal(t, "created", pair.Tail(result).Status, "status should be created")
|
||||
}
|
||||
|
||||
// TestBindToWithPipeline verifies BindTo in a complete pipeline
|
||||
func TestBindToWithPipeline(t *testing.T) {
|
||||
initial := BindTestState{Counter: 10, Multiplier: 3}
|
||||
|
||||
computation := F.Pipe2(
|
||||
Of[BindTestState](5),
|
||||
BindTo[BindTestState](func(x int) Accumulator {
|
||||
return Accumulator{Value: x}
|
||||
}),
|
||||
Bind(
|
||||
func(d int) func(Accumulator) Accumulator {
|
||||
return func(acc Accumulator) Accumulator {
|
||||
acc.Doubled = d
|
||||
return acc
|
||||
}
|
||||
},
|
||||
func(acc Accumulator) State[BindTestState, int] {
|
||||
return Gets(func(s BindTestState) int {
|
||||
return acc.Value * s.Multiplier
|
||||
})
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
result := computation(initial)
|
||||
|
||||
assert.Equal(t, 5, pair.Tail(result).Value, "value should be 5")
|
||||
assert.Equal(t, 15, pair.Tail(result).Doubled, "doubled should be 15")
|
||||
}
|
||||
|
||||
// TestApS verifies applicative-style binding
|
||||
func TestApS(t *testing.T) {
|
||||
initial := BindTestState{Counter: 7, Multiplier: 2}
|
||||
startAcc := Accumulator{Value: 10, Doubled: 0, Status: ""}
|
||||
|
||||
// Independent computation that doesn't depend on accumulator
|
||||
getCounter := Gets(func(s BindTestState) int { return s.Counter })
|
||||
|
||||
computation := ApS(
|
||||
func(c int) func(Accumulator) Accumulator {
|
||||
return func(acc Accumulator) Accumulator {
|
||||
acc.Doubled = c
|
||||
return acc
|
||||
}
|
||||
},
|
||||
getCounter,
|
||||
)
|
||||
|
||||
result := computation(Of[BindTestState](startAcc))(initial)
|
||||
|
||||
assert.Equal(t, initial, pair.Head(result), "state should be unchanged")
|
||||
assert.Equal(t, 10, pair.Tail(result).Value, "value should be preserved")
|
||||
assert.Equal(t, 7, pair.Tail(result).Doubled, "doubled should be counter value")
|
||||
}
|
||||
|
||||
// TestApSL verifies lens-based applicative binding
|
||||
func TestApSL(t *testing.T) {
|
||||
initial := BindTestState{Counter: 42, Multiplier: 2}
|
||||
|
||||
// Create a lens for the Doubled field
|
||||
doubledLens := lens.MakeLens(
|
||||
func(acc Accumulator) int { return acc.Doubled },
|
||||
func(acc Accumulator, d int) Accumulator {
|
||||
acc.Doubled = d
|
||||
return acc
|
||||
},
|
||||
)
|
||||
|
||||
getCounter := Gets(func(s BindTestState) int { return s.Counter })
|
||||
|
||||
computation := F.Pipe1(
|
||||
Of[BindTestState](Accumulator{Value: 10}),
|
||||
ApSL(doubledLens, getCounter),
|
||||
)
|
||||
|
||||
result := computation(initial)
|
||||
|
||||
assert.Equal(t, 10, pair.Tail(result).Value, "value should be preserved")
|
||||
assert.Equal(t, 42, pair.Tail(result).Doubled, "doubled should be set to counter")
|
||||
}
|
||||
|
||||
// TestBindL verifies lens-based monadic binding
|
||||
func TestBindL(t *testing.T) {
|
||||
initial := BindTestState{Counter: 5, Multiplier: 3}
|
||||
|
||||
// Create a lens for the Value field
|
||||
valueLens := lens.MakeLens(
|
||||
func(acc Accumulator) int { return acc.Value },
|
||||
func(acc Accumulator, v int) Accumulator {
|
||||
acc.Value = v
|
||||
return acc
|
||||
},
|
||||
)
|
||||
|
||||
// Multiply the value by the state's multiplier
|
||||
computation := F.Pipe1(
|
||||
Of[BindTestState](Accumulator{Value: 10}),
|
||||
BindL(
|
||||
valueLens,
|
||||
func(v int) State[BindTestState, int] {
|
||||
return Gets(func(s BindTestState) int {
|
||||
return v * s.Multiplier
|
||||
})
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
result := computation(initial)
|
||||
|
||||
assert.Equal(t, 30, pair.Tail(result).Value, "value should be 10 * 3 = 30")
|
||||
}
|
||||
|
||||
// TestLetL verifies lens-based pure transformation
|
||||
func TestLetL(t *testing.T) {
|
||||
initial := BindTestState{Counter: 5, Multiplier: 2}
|
||||
|
||||
// Create a lens for the Value field
|
||||
valueLens := lens.MakeLens(
|
||||
func(acc Accumulator) int { return acc.Value },
|
||||
func(acc Accumulator, v int) Accumulator {
|
||||
acc.Value = v
|
||||
return acc
|
||||
},
|
||||
)
|
||||
|
||||
// Double the value using a pure function
|
||||
computation := F.Pipe1(
|
||||
Of[BindTestState](Accumulator{Value: 21}),
|
||||
LetL[BindTestState](valueLens, func(v int) int { return v * 2 }),
|
||||
)
|
||||
|
||||
result := computation(initial)
|
||||
|
||||
assert.Equal(t, initial, pair.Head(result), "state should be unchanged")
|
||||
assert.Equal(t, 42, pair.Tail(result).Value, "value should be doubled to 42")
|
||||
}
|
||||
|
||||
// TestLetToL verifies lens-based constant binding
|
||||
func TestLetToL(t *testing.T) {
|
||||
initial := BindTestState{Counter: 5, Multiplier: 2}
|
||||
|
||||
// Create a lens for the Status field
|
||||
statusLens := lens.MakeLens(
|
||||
func(acc Accumulator) string { return acc.Status },
|
||||
func(acc Accumulator, s string) Accumulator {
|
||||
acc.Status = s
|
||||
return acc
|
||||
},
|
||||
)
|
||||
|
||||
computation := F.Pipe1(
|
||||
Of[BindTestState](Accumulator{Value: 10}),
|
||||
LetToL[BindTestState](statusLens, "completed"),
|
||||
)
|
||||
|
||||
result := computation(initial)
|
||||
|
||||
assert.Equal(t, 10, pair.Tail(result).Value, "value should be preserved")
|
||||
assert.Equal(t, "completed", pair.Tail(result).Status, "status should be set")
|
||||
}
|
||||
|
||||
// TestComplexDoNotation verifies a complex do-notation pipeline
|
||||
func TestComplexDoNotation(t *testing.T) {
|
||||
initial := BindTestState{Counter: 10, Multiplier: 2}
|
||||
|
||||
// Create lenses
|
||||
valueLens := lens.MakeLens(
|
||||
func(acc Accumulator) int { return acc.Value },
|
||||
func(acc Accumulator, v int) Accumulator {
|
||||
acc.Value = v
|
||||
return acc
|
||||
},
|
||||
)
|
||||
|
||||
doubledLens := lens.MakeLens(
|
||||
func(acc Accumulator) int { return acc.Doubled },
|
||||
func(acc Accumulator, d int) Accumulator {
|
||||
acc.Doubled = d
|
||||
return acc
|
||||
},
|
||||
)
|
||||
|
||||
statusLens := lens.MakeLens(
|
||||
func(acc Accumulator) string { return acc.Status },
|
||||
func(acc Accumulator, s string) Accumulator {
|
||||
acc.Status = s
|
||||
return acc
|
||||
},
|
||||
)
|
||||
|
||||
computation := F.Pipe5(
|
||||
Do[BindTestState](Accumulator{}),
|
||||
// Get counter from state and bind to Value
|
||||
Bind(
|
||||
valueLens.Set,
|
||||
func(acc Accumulator) State[BindTestState, int] {
|
||||
return Gets(func(s BindTestState) int { return s.Counter })
|
||||
},
|
||||
),
|
||||
// Compute doubled value using state
|
||||
BindL(
|
||||
doubledLens,
|
||||
func(d int) State[BindTestState, int] {
|
||||
return Gets(func(s BindTestState) int {
|
||||
return d * s.Multiplier
|
||||
})
|
||||
},
|
||||
),
|
||||
// Add a pure computation
|
||||
LetL[BindTestState](valueLens, func(v int) int { return v + 5 }),
|
||||
// Set a constant
|
||||
LetToL[BindTestState](statusLens, "processed"),
|
||||
// Extract final result
|
||||
Map[BindTestState](func(acc Accumulator) int {
|
||||
return acc.Value + acc.Doubled
|
||||
}),
|
||||
)
|
||||
|
||||
result := computation(initial)
|
||||
|
||||
// Value: 10 (counter) + 5 = 15
|
||||
// Doubled: 0 * 2 = 0 (doubled starts at 0)
|
||||
// Sum: 15 + 0 = 15
|
||||
assert.Equal(t, 15, pair.Tail(result), "final result should be 15")
|
||||
}
|
||||
|
||||
// TestDoNotationWithModify verifies do-notation with state modifications
|
||||
func TestDoNotationWithModify(t *testing.T) {
|
||||
initial := BindTestState{Counter: 0, Multiplier: 2}
|
||||
|
||||
computation := F.Pipe3(
|
||||
Do[BindTestState](Accumulator{}),
|
||||
// Increment counter in state
|
||||
Bind(
|
||||
func(v int) func(Accumulator) Accumulator {
|
||||
return func(acc Accumulator) Accumulator {
|
||||
acc.Value = v
|
||||
return acc
|
||||
}
|
||||
},
|
||||
func(acc Accumulator) State[BindTestState, int] {
|
||||
return MonadChain(
|
||||
Modify(func(s BindTestState) BindTestState {
|
||||
s.Counter++
|
||||
return s
|
||||
}),
|
||||
func(_ Void) State[BindTestState, int] {
|
||||
return Gets(func(s BindTestState) int { return s.Counter })
|
||||
},
|
||||
)
|
||||
},
|
||||
),
|
||||
// Increment again
|
||||
Bind(
|
||||
func(d int) func(Accumulator) Accumulator {
|
||||
return func(acc Accumulator) Accumulator {
|
||||
acc.Doubled = d
|
||||
return acc
|
||||
}
|
||||
},
|
||||
func(acc Accumulator) State[BindTestState, int] {
|
||||
return MonadChain(
|
||||
Modify(func(s BindTestState) BindTestState {
|
||||
s.Counter++
|
||||
return s
|
||||
}),
|
||||
func(_ Void) State[BindTestState, int] {
|
||||
return Gets(func(s BindTestState) int { return s.Counter })
|
||||
},
|
||||
)
|
||||
},
|
||||
),
|
||||
Map[BindTestState](func(acc Accumulator) Accumulator { return acc }),
|
||||
)
|
||||
|
||||
result := computation(initial)
|
||||
|
||||
assert.Equal(t, 2, pair.Head(result).Counter, "counter should be incremented twice")
|
||||
assert.Equal(t, 1, pair.Tail(result).Value, "value should be 1")
|
||||
assert.Equal(t, 2, pair.Tail(result).Doubled, "doubled should be 2")
|
||||
}
|
||||
|
||||
// TestLetWithComplexComputation verifies Let with complex pure computations
|
||||
func TestLetWithComplexComputation(t *testing.T) {
|
||||
initial := BindTestState{Counter: 5, Multiplier: 2}
|
||||
|
||||
computation := F.Pipe2(
|
||||
Do[BindTestState](Accumulator{Value: 10, Doubled: 20}),
|
||||
Let[BindTestState](
|
||||
func(sum int) func(Accumulator) Accumulator {
|
||||
return func(acc Accumulator) Accumulator {
|
||||
acc.Value = sum
|
||||
return acc
|
||||
}
|
||||
},
|
||||
func(acc Accumulator) int {
|
||||
// Complex computation from accumulator
|
||||
return (acc.Value + acc.Doubled) * 2
|
||||
},
|
||||
),
|
||||
Let[BindTestState](
|
||||
func(status string) func(Accumulator) Accumulator {
|
||||
return func(acc Accumulator) Accumulator {
|
||||
acc.Status = status
|
||||
return acc
|
||||
}
|
||||
},
|
||||
func(acc Accumulator) string {
|
||||
if acc.Value > 50 {
|
||||
return "high"
|
||||
}
|
||||
return "low"
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
result := computation(initial)
|
||||
|
||||
// (10 + 20) * 2 = 60
|
||||
assert.Equal(t, 60, pair.Tail(result).Value, "value should be 60")
|
||||
assert.Equal(t, "high", pair.Tail(result).Status, "status should be high")
|
||||
}
|
||||
|
||||
// TestMixedBindAndLet verifies mixing Bind and Let operations
|
||||
func TestMixedBindAndLet(t *testing.T) {
|
||||
initial := BindTestState{Counter: 5, Multiplier: 3}
|
||||
|
||||
computation := F.Pipe4(
|
||||
Do[BindTestState](Accumulator{}),
|
||||
// Bind: stateful operation
|
||||
Bind(
|
||||
func(v int) func(Accumulator) Accumulator {
|
||||
return func(acc Accumulator) Accumulator {
|
||||
acc.Value = v
|
||||
return acc
|
||||
}
|
||||
},
|
||||
func(acc Accumulator) State[BindTestState, int] {
|
||||
return Gets(func(s BindTestState) int { return s.Counter })
|
||||
},
|
||||
),
|
||||
// Let: pure computation
|
||||
Let[BindTestState](
|
||||
func(d int) func(Accumulator) Accumulator {
|
||||
return func(acc Accumulator) Accumulator {
|
||||
acc.Doubled = d
|
||||
return acc
|
||||
}
|
||||
},
|
||||
func(acc Accumulator) int {
|
||||
return acc.Value * 2
|
||||
},
|
||||
),
|
||||
// LetTo: constant
|
||||
LetTo[BindTestState](
|
||||
func(s string) func(Accumulator) Accumulator {
|
||||
return func(acc Accumulator) Accumulator {
|
||||
acc.Status = s
|
||||
return acc
|
||||
}
|
||||
},
|
||||
"done",
|
||||
),
|
||||
Map[BindTestState](func(acc Accumulator) Accumulator { return acc }),
|
||||
)
|
||||
|
||||
result := computation(initial)
|
||||
|
||||
assert.Equal(t, 5, pair.Tail(result).Value, "value should be 5")
|
||||
assert.Equal(t, 10, pair.Tail(result).Doubled, "doubled should be 10")
|
||||
assert.Equal(t, "done", pair.Tail(result).Status, "status should be done")
|
||||
}
|
||||
@@ -20,52 +20,137 @@ import (
|
||||
"github.com/IBM/fp-go/v2/internal/chain"
|
||||
"github.com/IBM/fp-go/v2/internal/functor"
|
||||
"github.com/IBM/fp-go/v2/pair"
|
||||
"github.com/IBM/fp-go/v2/reader"
|
||||
)
|
||||
|
||||
var (
|
||||
undefined any = struct{}{}
|
||||
)
|
||||
|
||||
// Get returns a State computation that retrieves the current state and returns it as the value.
|
||||
// The state is unchanged by this operation.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// type Counter struct { count int }
|
||||
// getState := Get[Counter]()
|
||||
// result := getState(Counter{count: 5})
|
||||
// // result = Pair{head: Counter{count: 5}, tail: Counter{count: 5}}
|
||||
//
|
||||
//go:inline
|
||||
func Get[S any]() State[S, S] {
|
||||
return pair.Of[S]
|
||||
}
|
||||
|
||||
// Gets applies a function to the current state and returns the result as the value.
|
||||
// The state itself remains unchanged. This is useful for extracting or computing
|
||||
// values from the state without modifying it.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// type Counter struct { count int }
|
||||
// getDouble := Gets(func(c Counter) int { return c.count * 2 })
|
||||
// result := getDouble(Counter{count: 5})
|
||||
// // result = Pair{head: Counter{count: 5}, tail: 10}
|
||||
//
|
||||
//go:line
|
||||
func Gets[FCT ~func(S) A, A, S any](f FCT) State[S, A] {
|
||||
return func(s S) Pair[S, A] {
|
||||
return pair.MakePair(s, f(s))
|
||||
}
|
||||
}
|
||||
|
||||
// Put returns a State computation that replaces the current state with a new state.
|
||||
// The returned value is Void, indicating this operation is performed for its side effect.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// type Counter struct { count int }
|
||||
// setState := Put[Counter]()
|
||||
// result := setState(Counter{count: 10})
|
||||
// // result = Pair{head: Counter{count: 10}, tail: Void}
|
||||
//
|
||||
//go:inline
|
||||
func Put[S any]() State[S, any] {
|
||||
return function.Bind2nd(pair.MakePair[S, any], undefined)
|
||||
func Put[S any]() State[S, Void] {
|
||||
return Of[S](function.VOID)
|
||||
}
|
||||
|
||||
func Modify[FCT ~func(S) S, S any](f FCT) State[S, any] {
|
||||
// Modify applies a transformation function to the current state, producing a new state.
|
||||
// The returned value is Void, indicating this operation is performed for its side effect.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// type Counter struct { count int }
|
||||
// increment := Modify(func(c Counter) Counter { return Counter{count: c.count + 1} })
|
||||
// result := increment(Counter{count: 5})
|
||||
// // result = Pair{head: Counter{count: 6}, tail: Void}
|
||||
func Modify[FCT ~func(S) S, S any](f FCT) State[S, Void] {
|
||||
return function.Flow2(
|
||||
f,
|
||||
function.Bind2nd(pair.MakePair[S, any], undefined),
|
||||
Put[S](),
|
||||
)
|
||||
}
|
||||
|
||||
// Of creates a State computation that returns the given value without modifying the state.
|
||||
// This is the Pointed interface implementation for State, lifting a pure value into
|
||||
// the State context.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// type Counter struct { count int }
|
||||
// computation := Of[Counter](42)
|
||||
// result := computation(Counter{count: 5})
|
||||
// // result = Pair{head: Counter{count: 5}, tail: 42}
|
||||
//
|
||||
//go:inline
|
||||
func Of[S, A any](a A) State[S, A] {
|
||||
return function.Bind2nd(pair.MakePair[S, A], a)
|
||||
return pair.FromTail[S](a)
|
||||
}
|
||||
|
||||
// MonadMap transforms the value produced by a State computation using the given function,
|
||||
// while preserving the state. This is the Functor interface implementation for State.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// type Counter struct { count int }
|
||||
// computation := Of[Counter](10)
|
||||
// doubled := MonadMap(computation, func(x int) int { return x * 2 })
|
||||
// result := doubled(Counter{count: 5})
|
||||
// // result = Pair{head: Counter{count: 5}, tail: 20}
|
||||
//
|
||||
//go:inline
|
||||
func MonadMap[S any, FCT ~func(A) B, A, B any](fa State[S, A], f FCT) State[S, B] {
|
||||
return func(s S) Pair[S, B] {
|
||||
p2 := fa(s)
|
||||
return pair.MakePair(pair.Head(p2), f(pair.Tail(p2)))
|
||||
}
|
||||
return reader.MonadMap(fa, pair.Map[S](f))
|
||||
}
|
||||
|
||||
// Map returns a function that transforms the value of a State computation.
|
||||
// This is the curried version of MonadMap, useful for composition and pipelines.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// type Counter struct { count int }
|
||||
// double := Map[Counter](func(x int) int { return x * 2 })
|
||||
// computation := function.Pipe1(Of[Counter](10), double)
|
||||
// result := computation(Counter{count: 5})
|
||||
// // result = Pair{head: Counter{count: 5}, tail: 20}
|
||||
//
|
||||
//go:inline
|
||||
func Map[S any, FCT ~func(A) B, A, B any](f FCT) Operator[S, A, B] {
|
||||
return function.Bind2nd(MonadMap[S, FCT, A, B], f)
|
||||
return reader.Map[S](pair.Map[S](f))
|
||||
}
|
||||
|
||||
// MonadChain sequences two State computations, where the second computation depends
|
||||
// on the value produced by the first. The state is threaded through both computations.
|
||||
// This is the Monad interface implementation for State.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// type Counter struct { count int }
|
||||
// computation := Of[Counter](5)
|
||||
// chained := MonadChain(computation, func(x int) State[Counter, int] {
|
||||
// return func(s Counter) Pair[Counter, int] {
|
||||
// newState := Counter{count: s.count + x}
|
||||
// return pair.MakePair(newState, x * 2)
|
||||
// }
|
||||
// })
|
||||
// result := chained(Counter{count: 10})
|
||||
// // result = Pair{head: Counter{count: 15}, tail: 10}
|
||||
func MonadChain[S any, FCT ~func(A) State[S, B], A, B any](fa State[S, A], f FCT) State[S, B] {
|
||||
return func(s S) Pair[S, B] {
|
||||
a := fa(s)
|
||||
@@ -73,11 +158,38 @@ func MonadChain[S any, FCT ~func(A) State[S, B], A, B any](fa State[S, A], f FCT
|
||||
}
|
||||
}
|
||||
|
||||
// Chain returns a function that sequences State computations.
|
||||
// This is the curried version of MonadChain, useful for composition and pipelines.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// type Counter struct { count int }
|
||||
// addToCounter := func(x int) State[Counter, int] {
|
||||
// return func(s Counter) Pair[Counter, int] {
|
||||
// newState := Counter{count: s.count + x}
|
||||
// return pair.MakePair(newState, newState.count)
|
||||
// }
|
||||
// }
|
||||
// computation := function.Pipe1(Of[Counter](5), Chain(addToCounter))
|
||||
// result := computation(Counter{count: 10})
|
||||
// // result = Pair{head: Counter{count: 15}, tail: 15}
|
||||
//
|
||||
//go:inline
|
||||
func Chain[S any, FCT ~func(A) State[S, B], A, B any](f FCT) Operator[S, A, B] {
|
||||
return function.Bind2nd(MonadChain[S, FCT, A, B], f)
|
||||
}
|
||||
|
||||
// MonadAp applies a State computation containing a function to a State computation
|
||||
// containing a value. Both computations are executed sequentially, threading the state
|
||||
// through both. This is the Applicative interface implementation for State.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// type Counter struct { count int }
|
||||
// fab := Of[Counter](func(x int) int { return x * 2 })
|
||||
// fa := Of[Counter](21)
|
||||
// result := MonadAp(fab, fa)(Counter{count: 5})
|
||||
// // result = Pair{head: Counter{count: 5}, tail: 42}
|
||||
func MonadAp[B, S, A any](fab State[S, func(A) B], fa State[S, A]) State[S, B] {
|
||||
return func(s S) Pair[S, B] {
|
||||
f := fab(s)
|
||||
@@ -87,11 +199,37 @@ func MonadAp[B, S, A any](fab State[S, func(A) B], fa State[S, A]) State[S, B] {
|
||||
}
|
||||
}
|
||||
|
||||
// Ap returns a function that applies a State computation containing a function
|
||||
// to a State computation containing a value. This is the curried version of MonadAp.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// type Counter struct { count int }
|
||||
// computation := function.Pipe1(
|
||||
// Of[Counter](func(x int) int { return x * 2 }),
|
||||
// Ap[int](Of[Counter](21)),
|
||||
// )
|
||||
// result := computation(Counter{count: 5})
|
||||
// // result = Pair{head: Counter{count: 5}, tail: 42}
|
||||
//
|
||||
//go:inline
|
||||
func Ap[B, S, A any](ga State[S, A]) Operator[S, func(A) B, B] {
|
||||
return function.Bind2nd(MonadAp[B, S, A], ga)
|
||||
}
|
||||
|
||||
// MonadChainFirst sequences two State computations but returns the value from the first
|
||||
// computation while still threading the state through both. This is useful when you want
|
||||
// to perform a stateful side effect but keep the original value.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// type Counter struct { count int }
|
||||
// computation := Of[Counter](42)
|
||||
// increment := func(x int) State[Counter, Void] {
|
||||
// return Modify(func(c Counter) Counter { return Counter{count: c.count + 1} })
|
||||
// }
|
||||
// result := MonadChainFirst(computation, increment)(Counter{count: 5})
|
||||
// // result = Pair{head: Counter{count: 6}, tail: 42}
|
||||
func MonadChainFirst[S any, FCT ~func(A) State[S, B], A, B any](ma State[S, A], f FCT) State[S, A] {
|
||||
return chain.MonadChainFirst(
|
||||
MonadChain[S, func(A) State[S, A], A, A],
|
||||
@@ -101,6 +239,18 @@ func MonadChainFirst[S any, FCT ~func(A) State[S, B], A, B any](ma State[S, A],
|
||||
)
|
||||
}
|
||||
|
||||
// ChainFirst returns a function that sequences State computations but keeps the first value.
|
||||
// This is the curried version of MonadChainFirst, useful for composition and pipelines.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// type Counter struct { count int }
|
||||
// increment := func(x int) State[Counter, Void] {
|
||||
// return Modify(func(c Counter) Counter { return Counter{count: c.count + 1} })
|
||||
// }
|
||||
// computation := function.Pipe1(Of[Counter](42), ChainFirst(increment))
|
||||
// result := computation(Counter{count: 5})
|
||||
// // result = Pair{head: Counter{count: 6}, tail: 42}
|
||||
func ChainFirst[S any, FCT ~func(A) State[S, B], A, B any](f FCT) Operator[S, A, A] {
|
||||
return chain.ChainFirst(
|
||||
Chain[S, func(A) State[S, A], A, A],
|
||||
@@ -109,23 +259,64 @@ func ChainFirst[S any, FCT ~func(A) State[S, B], A, B any](f FCT) Operator[S, A,
|
||||
)
|
||||
}
|
||||
|
||||
// Flatten removes one level of nesting from a State computation that produces another
|
||||
// State computation. This is equivalent to MonadChain with the identity function.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// type Counter struct { count int }
|
||||
// nested := Of[Counter](Of[Counter](42))
|
||||
// flattened := Flatten(nested)
|
||||
// result := flattened(Counter{count: 5})
|
||||
// // result = Pair{head: Counter{count: 5}, tail: 42}
|
||||
//
|
||||
//go:inline
|
||||
func Flatten[S, A any](mma State[S, State[S, A]]) State[S, A] {
|
||||
return MonadChain(mma, function.Identity[State[S, A]])
|
||||
}
|
||||
|
||||
// Execute runs a State computation with the given initial state and returns only
|
||||
// the final state, discarding the computed value. This is useful when you only
|
||||
// care about the state transformations.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// type Counter struct { count int }
|
||||
// computation := Modify(func(c Counter) Counter { return Counter{count: c.count + 1} })
|
||||
// finalState := Execute[Void, Counter](Counter{count: 5})(computation)
|
||||
// // finalState = Counter{count: 6}
|
||||
func Execute[A, S any](s S) func(State[S, A]) S {
|
||||
return func(fa State[S, A]) S {
|
||||
return pair.Head(fa(s))
|
||||
}
|
||||
}
|
||||
|
||||
// Evaluate runs a State computation with the given initial state and returns only
|
||||
// the computed value, discarding the final state. This is useful when you only
|
||||
// care about the result of the computation.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// type Counter struct { count int }
|
||||
// computation := Of[Counter](42)
|
||||
// value := Evaluate[int, Counter](Counter{count: 5})(computation)
|
||||
// // value = 42
|
||||
func Evaluate[A, S any](s S) func(State[S, A]) A {
|
||||
return func(fa State[S, A]) A {
|
||||
return pair.Tail(fa(s))
|
||||
}
|
||||
}
|
||||
|
||||
// MonadFlap applies a fixed value to a State computation containing a function.
|
||||
// This is the reverse of MonadAp, where the value is known but the function is
|
||||
// in the State context.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// type Counter struct { count int }
|
||||
// fab := Of[Counter](func(x int) int { return x * 2 })
|
||||
// result := MonadFlap(fab, 21)(Counter{count: 5})
|
||||
// // result = Pair{head: Counter{count: 5}, tail: 42}
|
||||
func MonadFlap[FAB ~func(A) B, S, A, B any](fab State[S, FAB], a A) State[S, B] {
|
||||
return functor.MonadFlap(
|
||||
MonadMap[S, func(FAB) B],
|
||||
@@ -133,6 +324,19 @@ func MonadFlap[FAB ~func(A) B, S, A, B any](fab State[S, FAB], a A) State[S, B]
|
||||
a)
|
||||
}
|
||||
|
||||
// Flap returns a function that applies a fixed value to a State computation containing
|
||||
// a function. This is the curried version of MonadFlap.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// type Counter struct { count int }
|
||||
// applyTwentyOne := Flap[Counter, int, int](21)
|
||||
// computation := function.Pipe1(
|
||||
// Of[Counter](func(x int) int { return x * 2 }),
|
||||
// applyTwentyOne,
|
||||
// )
|
||||
// result := computation(Counter{count: 5})
|
||||
// // result = Pair{head: Counter{count: 5}, tail: 42}
|
||||
func Flap[S, A, B any](a A) Operator[S, func(A) B, B] {
|
||||
return functor.Flap(
|
||||
Map[S, func(func(A) B) B],
|
||||
|
||||
502
v2/state/state_test.go
Normal file
502
v2/state/state_test.go
Normal file
@@ -0,0 +1,502 @@
|
||||
// Copyright (c) 2024 - 2025 IBM Corp.
|
||||
// All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package state
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
F "github.com/IBM/fp-go/v2/function"
|
||||
N "github.com/IBM/fp-go/v2/number"
|
||||
"github.com/IBM/fp-go/v2/pair"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type TestState struct {
|
||||
Counter int
|
||||
Message string
|
||||
}
|
||||
|
||||
// TestGet verifies that Get returns the current state as both state and value
|
||||
func TestGet(t *testing.T) {
|
||||
initial := TestState{Counter: 42, Message: "test"}
|
||||
computation := Get[TestState]()
|
||||
|
||||
result := computation(initial)
|
||||
|
||||
assert.Equal(t, initial, pair.Head(result), "state should be unchanged")
|
||||
assert.Equal(t, initial, pair.Tail(result), "value should equal state")
|
||||
}
|
||||
|
||||
// TestGets verifies that Gets applies a function to state and returns the result
|
||||
func TestGets(t *testing.T) {
|
||||
initial := TestState{Counter: 5, Message: "test"}
|
||||
|
||||
// Extract and double the counter
|
||||
computation := Gets(func(s TestState) int {
|
||||
return s.Counter * 2
|
||||
})
|
||||
|
||||
result := computation(initial)
|
||||
|
||||
assert.Equal(t, initial, pair.Head(result), "state should be unchanged")
|
||||
assert.Equal(t, 10, pair.Tail(result), "value should be doubled counter")
|
||||
}
|
||||
|
||||
// TestPut verifies that Put replaces the state
|
||||
func TestPut(t *testing.T) {
|
||||
newState := TestState{Counter: 10, Message: "new"}
|
||||
|
||||
computation := Put[TestState]()
|
||||
|
||||
result := computation(newState)
|
||||
|
||||
assert.Equal(t, newState, pair.Head(result), "state should be replaced")
|
||||
assert.Equal(t, F.VOID, pair.Tail(result), "value should be Void")
|
||||
}
|
||||
|
||||
// TestModify verifies that Modify transforms the state
|
||||
func TestModify(t *testing.T) {
|
||||
initial := TestState{Counter: 5, Message: "test"}
|
||||
|
||||
increment := Modify(func(s TestState) TestState {
|
||||
return TestState{Counter: s.Counter + 1, Message: s.Message}
|
||||
})
|
||||
|
||||
result := increment(initial)
|
||||
|
||||
assert.Equal(t, 6, pair.Head(result).Counter, "counter should be incremented")
|
||||
assert.Equal(t, "test", pair.Head(result).Message, "message should be unchanged")
|
||||
assert.Equal(t, F.VOID, pair.Tail(result), "value should be Void")
|
||||
}
|
||||
|
||||
// TestOf verifies that Of creates a computation with a value and unchanged state
|
||||
func TestOf(t *testing.T) {
|
||||
initial := TestState{Counter: 5, Message: "test"}
|
||||
computation := Of[TestState](42)
|
||||
|
||||
result := computation(initial)
|
||||
|
||||
assert.Equal(t, initial, pair.Head(result), "state should be unchanged")
|
||||
assert.Equal(t, 42, pair.Tail(result), "value should be 42")
|
||||
}
|
||||
|
||||
// TestMonadMap verifies that MonadMap transforms the value
|
||||
func TestMonadMap(t *testing.T) {
|
||||
initial := TestState{Counter: 5, Message: "test"}
|
||||
computation := Of[TestState](10)
|
||||
|
||||
mapped := MonadMap(computation, N.Mul(2))
|
||||
result := mapped(initial)
|
||||
|
||||
assert.Equal(t, initial, pair.Head(result), "state should be unchanged")
|
||||
assert.Equal(t, 20, pair.Tail(result), "value should be doubled")
|
||||
}
|
||||
|
||||
// TestMap verifies the curried version of MonadMap
|
||||
func TestMap(t *testing.T) {
|
||||
initial := TestState{Counter: 0, Message: ""}
|
||||
|
||||
double := Map[TestState](N.Mul(2))
|
||||
computation := F.Pipe1(
|
||||
Of[TestState](21),
|
||||
double,
|
||||
)
|
||||
|
||||
result := computation(initial)
|
||||
|
||||
assert.Equal(t, initial, pair.Head(result), "state should be unchanged")
|
||||
assert.Equal(t, 42, pair.Tail(result), "value should be doubled")
|
||||
}
|
||||
|
||||
// TestMonadChain verifies that MonadChain sequences computations
|
||||
func TestMonadChain(t *testing.T) {
|
||||
initial := TestState{Counter: 0, Message: ""}
|
||||
|
||||
computation := Of[TestState](5)
|
||||
chained := MonadChain(computation, func(x int) State[TestState, string] {
|
||||
return func(s TestState) Pair[TestState, string] {
|
||||
newState := TestState{Counter: s.Counter + x, Message: fmt.Sprintf("value: %d", x)}
|
||||
return pair.MakePair(newState, fmt.Sprintf("result: %d", x*2))
|
||||
}
|
||||
})
|
||||
|
||||
result := chained(initial)
|
||||
|
||||
assert.Equal(t, "result: 10", pair.Tail(result), "value should be transformed")
|
||||
assert.Equal(t, 5, pair.Head(result).Counter, "counter should be updated")
|
||||
assert.Equal(t, "value: 5", pair.Head(result).Message, "message should be set")
|
||||
}
|
||||
|
||||
// TestChain verifies the curried version of MonadChain
|
||||
func TestChain(t *testing.T) {
|
||||
initial := TestState{Counter: 0, Message: ""}
|
||||
|
||||
addToCounter := func(x int) State[TestState, int] {
|
||||
return func(s TestState) Pair[TestState, int] {
|
||||
newState := TestState{Counter: s.Counter + x, Message: s.Message}
|
||||
return pair.MakePair(newState, s.Counter+x)
|
||||
}
|
||||
}
|
||||
|
||||
computation := F.Pipe1(
|
||||
Of[TestState](5),
|
||||
Chain(addToCounter),
|
||||
)
|
||||
|
||||
result := computation(initial)
|
||||
|
||||
assert.Equal(t, 5, pair.Tail(result), "value should be 5")
|
||||
assert.Equal(t, 5, pair.Head(result).Counter, "counter should be 5")
|
||||
}
|
||||
|
||||
// TestMonadAp verifies applicative application
|
||||
func TestMonadAp(t *testing.T) {
|
||||
initial := TestState{Counter: 0, Message: ""}
|
||||
|
||||
fab := Of[TestState](N.Mul(3))
|
||||
fa := Of[TestState](7)
|
||||
|
||||
result := MonadAp(fab, fa)(initial)
|
||||
|
||||
assert.Equal(t, 21, pair.Tail(result), "value should be 21")
|
||||
assert.Equal(t, initial, pair.Head(result), "state should be unchanged")
|
||||
}
|
||||
|
||||
// TestAp verifies the curried version of MonadAp
|
||||
func TestAp(t *testing.T) {
|
||||
initial := TestState{Counter: 0, Message: ""}
|
||||
|
||||
computation := F.Pipe1(
|
||||
Of[TestState](N.Mul(4)),
|
||||
Ap[int](Of[TestState](10)),
|
||||
)
|
||||
|
||||
result := computation(initial)
|
||||
|
||||
assert.Equal(t, 40, pair.Tail(result), "value should be 40")
|
||||
assert.Equal(t, initial, pair.Head(result), "state should be unchanged")
|
||||
}
|
||||
|
||||
// TestMonadChainFirst verifies that ChainFirst keeps the first value
|
||||
func TestMonadChainFirst(t *testing.T) {
|
||||
initial := TestState{Counter: 0, Message: ""}
|
||||
|
||||
computation := Of[TestState](42)
|
||||
increment := func(x int) State[TestState, Void] {
|
||||
return Modify(func(s TestState) TestState {
|
||||
return TestState{Counter: s.Counter + 1, Message: s.Message}
|
||||
})
|
||||
}
|
||||
|
||||
result := MonadChainFirst(computation, increment)(initial)
|
||||
|
||||
assert.Equal(t, 42, pair.Tail(result), "value should be preserved")
|
||||
assert.Equal(t, 1, pair.Head(result).Counter, "counter should be incremented")
|
||||
}
|
||||
|
||||
// TestChainFirst verifies the curried version of MonadChainFirst
|
||||
func TestChainFirst(t *testing.T) {
|
||||
initial := TestState{Counter: 5, Message: "test"}
|
||||
|
||||
increment := func(x int) State[TestState, Void] {
|
||||
return Modify(func(s TestState) TestState {
|
||||
return TestState{Counter: s.Counter + 1, Message: s.Message}
|
||||
})
|
||||
}
|
||||
|
||||
computation := F.Pipe1(Of[TestState](42), ChainFirst(increment))
|
||||
result := computation(initial)
|
||||
|
||||
assert.Equal(t, 42, pair.Tail(result), "value should be preserved")
|
||||
assert.Equal(t, 6, pair.Head(result).Counter, "counter should be incremented")
|
||||
}
|
||||
|
||||
// TestFlatten verifies that Flatten removes one level of nesting
|
||||
func TestFlatten(t *testing.T) {
|
||||
initial := TestState{Counter: 5, Message: "test"}
|
||||
|
||||
nested := Of[TestState](Of[TestState](42))
|
||||
flattened := Flatten(nested)
|
||||
|
||||
result := flattened(initial)
|
||||
|
||||
assert.Equal(t, 42, pair.Tail(result), "value should be 42")
|
||||
assert.Equal(t, initial, pair.Head(result), "state should be unchanged")
|
||||
}
|
||||
|
||||
// TestExecute verifies that Execute returns only the final state
|
||||
func TestExecute(t *testing.T) {
|
||||
initial := TestState{Counter: 5, Message: "old"}
|
||||
|
||||
computation := Modify(func(s TestState) TestState {
|
||||
return TestState{Counter: s.Counter + 1, Message: "new"}
|
||||
})
|
||||
|
||||
finalState := Execute[Void, TestState](initial)(computation)
|
||||
|
||||
assert.Equal(t, 6, finalState.Counter, "counter should be incremented")
|
||||
assert.Equal(t, "new", finalState.Message, "message should be updated")
|
||||
}
|
||||
|
||||
// TestEvaluate verifies that Evaluate returns only the value
|
||||
func TestEvaluate(t *testing.T) {
|
||||
initial := TestState{Counter: 5, Message: "test"}
|
||||
|
||||
computation := Of[TestState](42)
|
||||
|
||||
value := Evaluate[int, TestState](initial)(computation)
|
||||
|
||||
assert.Equal(t, 42, value, "value should be 42")
|
||||
}
|
||||
|
||||
// TestMonadFlap verifies that MonadFlap applies a value to a function in State
|
||||
func TestMonadFlap(t *testing.T) {
|
||||
initial := TestState{Counter: 5, Message: "test"}
|
||||
|
||||
fab := Of[TestState](func(x int) int { return x * 2 })
|
||||
result := MonadFlap(fab, 21)(initial)
|
||||
|
||||
assert.Equal(t, 42, pair.Tail(result), "value should be 42")
|
||||
assert.Equal(t, initial, pair.Head(result), "state should be unchanged")
|
||||
}
|
||||
|
||||
// TestFlap verifies the curried version of MonadFlap
|
||||
func TestFlap(t *testing.T) {
|
||||
initial := TestState{Counter: 5, Message: "test"}
|
||||
|
||||
applyTwentyOne := Flap[TestState, int, int](21)
|
||||
computation := F.Pipe1(
|
||||
Of[TestState](func(x int) int { return x * 2 }),
|
||||
applyTwentyOne,
|
||||
)
|
||||
|
||||
result := computation(initial)
|
||||
|
||||
assert.Equal(t, 42, pair.Tail(result), "value should be 42")
|
||||
assert.Equal(t, initial, pair.Head(result), "state should be unchanged")
|
||||
}
|
||||
|
||||
// TestChainedOperations verifies complex chained operations
|
||||
func TestChainedOperations(t *testing.T) {
|
||||
initial := TestState{Counter: 0, Message: ""}
|
||||
|
||||
incrementCounter := func(x int) State[TestState, int] {
|
||||
return func(s TestState) Pair[TestState, int] {
|
||||
newState := TestState{Counter: s.Counter + x, Message: s.Message}
|
||||
return pair.MakePair(newState, newState.Counter)
|
||||
}
|
||||
}
|
||||
|
||||
setMessage := func(count int) State[TestState, string] {
|
||||
return func(s TestState) Pair[TestState, string] {
|
||||
msg := fmt.Sprintf("Count is %d", count)
|
||||
newState := TestState{Counter: s.Counter, Message: msg}
|
||||
return pair.MakePair(newState, msg)
|
||||
}
|
||||
}
|
||||
|
||||
computation := F.Pipe2(
|
||||
Of[TestState](5),
|
||||
Chain(incrementCounter),
|
||||
Chain(setMessage),
|
||||
)
|
||||
|
||||
result := computation(initial)
|
||||
|
||||
assert.Equal(t, "Count is 5", pair.Tail(result), "value should be message")
|
||||
assert.Equal(t, 5, pair.Head(result).Counter, "counter should be 5")
|
||||
assert.Equal(t, "Count is 5", pair.Head(result).Message, "message should be set")
|
||||
}
|
||||
|
||||
// TestMapPreservesState verifies that Map operations don't modify state
|
||||
func TestMapPreservesState(t *testing.T) {
|
||||
initial := TestState{Counter: 42, Message: "important"}
|
||||
|
||||
computation := F.Pipe2(
|
||||
Of[TestState](10),
|
||||
Map[TestState](N.Mul(2)),
|
||||
Map[TestState](N.Add(5)),
|
||||
)
|
||||
|
||||
result := computation(initial)
|
||||
|
||||
// Value should be transformed: 10 * 2 + 5 = 25
|
||||
assert.Equal(t, 25, pair.Tail(result), "value should be 25")
|
||||
// State should be unchanged
|
||||
assert.Equal(t, initial, pair.Head(result), "state should be unchanged")
|
||||
}
|
||||
|
||||
// TestChainModifiesState verifies that Chain operations can modify state
|
||||
func TestChainModifiesState(t *testing.T) {
|
||||
initial := TestState{Counter: 0, Message: ""}
|
||||
|
||||
addOne := func(x int) State[TestState, int] {
|
||||
return func(s TestState) Pair[TestState, int] {
|
||||
newState := TestState{Counter: s.Counter + 1, Message: s.Message}
|
||||
return pair.MakePair(newState, x+1)
|
||||
}
|
||||
}
|
||||
|
||||
computation := F.Pipe2(
|
||||
Of[TestState](0),
|
||||
Chain(addOne),
|
||||
Chain(addOne),
|
||||
)
|
||||
|
||||
result := computation(initial)
|
||||
|
||||
assert.Equal(t, 2, pair.Tail(result), "value should be 2")
|
||||
assert.Equal(t, 2, pair.Head(result).Counter, "counter should be 2")
|
||||
}
|
||||
|
||||
// TestApplicativeComposition verifies applicative composition
|
||||
func TestApplicativeComposition(t *testing.T) {
|
||||
initial := TestState{Counter: 0, Message: ""}
|
||||
|
||||
add := func(x int) func(int) int {
|
||||
return func(y int) int {
|
||||
return x + y
|
||||
}
|
||||
}
|
||||
|
||||
computation := F.Pipe1(
|
||||
Of[TestState](add(10)),
|
||||
Ap[int](Of[TestState](32)),
|
||||
)
|
||||
|
||||
result := computation(initial)
|
||||
|
||||
assert.Equal(t, 42, pair.Tail(result), "value should be 42")
|
||||
}
|
||||
|
||||
// TestStatefulComputation verifies a computation that reads and modifies state
|
||||
func TestStatefulComputation(t *testing.T) {
|
||||
initial := TestState{Counter: 10, Message: "start"}
|
||||
|
||||
// A computation that reads and modifies state
|
||||
computation := func(s TestState) Pair[TestState, int] {
|
||||
newState := TestState{
|
||||
Counter: s.Counter * 2,
|
||||
Message: fmt.Sprintf("%s -> doubled", s.Message),
|
||||
}
|
||||
return pair.MakePair(newState, newState.Counter)
|
||||
}
|
||||
|
||||
result := computation(initial)
|
||||
|
||||
assert.Equal(t, 20, pair.Tail(result), "value should be 20")
|
||||
assert.Equal(t, 20, pair.Head(result).Counter, "counter should be 20")
|
||||
assert.Equal(t, "start -> doubled", pair.Head(result).Message, "message should be updated")
|
||||
}
|
||||
|
||||
// TestGetAndModify verifies combining Get and Modify
|
||||
func TestGetAndModify(t *testing.T) {
|
||||
initial := TestState{Counter: 5, Message: "test"}
|
||||
|
||||
step1 := Chain(func(s TestState) State[TestState, Void] {
|
||||
return Modify(func(_ TestState) TestState {
|
||||
return TestState{Counter: s.Counter * 2, Message: s.Message + "!"}
|
||||
})
|
||||
})
|
||||
|
||||
step2 := Chain(func(_ Void) State[TestState, TestState] {
|
||||
return Get[TestState]()
|
||||
})
|
||||
|
||||
computation := step2(step1(Get[TestState]()))
|
||||
|
||||
result := computation(initial)
|
||||
|
||||
assert.Equal(t, 10, pair.Tail(result).Counter, "counter should be doubled")
|
||||
assert.Equal(t, "test!", pair.Tail(result).Message, "message should have exclamation")
|
||||
}
|
||||
|
||||
// TestGetsWithComplexExtraction verifies Gets with complex state extraction
|
||||
func TestGetsWithComplexExtraction(t *testing.T) {
|
||||
initial := TestState{Counter: 5, Message: "hello"}
|
||||
|
||||
computation := Gets(func(s TestState) string {
|
||||
return fmt.Sprintf("%s: %d", s.Message, s.Counter)
|
||||
})
|
||||
|
||||
result := computation(initial)
|
||||
|
||||
assert.Equal(t, initial, pair.Head(result), "state should be unchanged")
|
||||
assert.Equal(t, "hello: 5", pair.Tail(result), "value should be formatted string")
|
||||
}
|
||||
|
||||
// TestMultipleModifications verifies multiple state modifications
|
||||
func TestMultipleModifications(t *testing.T) {
|
||||
initial := TestState{Counter: 0, Message: ""}
|
||||
|
||||
increment := Modify(func(s TestState) TestState {
|
||||
return TestState{Counter: s.Counter + 1, Message: s.Message}
|
||||
})
|
||||
|
||||
setMessage := Modify(func(s TestState) TestState {
|
||||
return TestState{Counter: s.Counter, Message: "done"}
|
||||
})
|
||||
|
||||
step1 := ChainFirst(func(_ Void) State[TestState, Void] { return increment })
|
||||
step2 := ChainFirst(func(_ Void) State[TestState, Void] { return setMessage })
|
||||
|
||||
computation := step2(step1(increment))
|
||||
|
||||
result := computation(initial)
|
||||
|
||||
assert.Equal(t, 2, pair.Head(result).Counter, "counter should be 2")
|
||||
assert.Equal(t, "done", pair.Head(result).Message, "message should be 'done'")
|
||||
}
|
||||
|
||||
// TestExecuteWithComplexState verifies Execute with complex state transformations
|
||||
func TestExecuteWithComplexState(t *testing.T) {
|
||||
initial := TestState{Counter: 1, Message: "start"}
|
||||
|
||||
step1 := Modify(func(s TestState) TestState {
|
||||
return TestState{Counter: s.Counter * 2, Message: s.Message}
|
||||
})
|
||||
|
||||
step2 := ChainFirst(func(_ Void) State[TestState, Void] {
|
||||
return Modify(func(s TestState) TestState {
|
||||
return TestState{Counter: s.Counter + 10, Message: "end"}
|
||||
})
|
||||
})
|
||||
|
||||
computation := step2(step1)
|
||||
|
||||
finalState := Execute[Void, TestState](initial)(computation)
|
||||
|
||||
assert.Equal(t, 12, finalState.Counter, "counter should be (1*2)+10 = 12")
|
||||
assert.Equal(t, "end", finalState.Message, "message should be 'end'")
|
||||
}
|
||||
|
||||
// TestEvaluateWithChain verifies Evaluate with chained computations
|
||||
func TestEvaluateWithChain(t *testing.T) {
|
||||
initial := TestState{Counter: 5, Message: "test"}
|
||||
|
||||
computation := F.Pipe2(
|
||||
Of[TestState](10),
|
||||
Map[TestState](N.Mul(2)),
|
||||
Chain(func(x int) State[TestState, string] {
|
||||
return Of[TestState](fmt.Sprintf("result: %d", x))
|
||||
}),
|
||||
)
|
||||
|
||||
value := Evaluate[string, TestState](initial)(computation)
|
||||
|
||||
assert.Equal(t, "result: 20", value, "value should be 'result: 20'")
|
||||
}
|
||||
@@ -17,6 +17,7 @@ package state
|
||||
|
||||
import (
|
||||
"github.com/IBM/fp-go/v2/endomorphism"
|
||||
"github.com/IBM/fp-go/v2/function"
|
||||
"github.com/IBM/fp-go/v2/optics/lens"
|
||||
"github.com/IBM/fp-go/v2/pair"
|
||||
"github.com/IBM/fp-go/v2/reader"
|
||||
@@ -61,4 +62,6 @@ type (
|
||||
// It transforms a State[S, A] into a State[S, B], making it useful for
|
||||
// building pipelines of stateful transformations while maintaining the state type S.
|
||||
Operator[S, A, B any] = Kleisli[S, State[S, A], B]
|
||||
|
||||
Void = function.Void
|
||||
)
|
||||
|
||||
134
v2/stateio/TEST_COVERAGE.md
Normal file
134
v2/stateio/TEST_COVERAGE.md
Normal file
@@ -0,0 +1,134 @@
|
||||
# StateIO Test Coverage Summary
|
||||
|
||||
## Overview
|
||||
Comprehensive test suite for the `stateio` package with **90.7% code coverage**.
|
||||
|
||||
## Test Files Created
|
||||
|
||||
### 1. state_test.go
|
||||
Tests for core StateIO operations:
|
||||
- **Of**: Creating successful computations
|
||||
- **MonadMap / Map**: Transforming values with functors
|
||||
- **MonadChain / Chain**: Sequencing dependent computations (monadic bind)
|
||||
- **MonadAp / Ap**: Applicative operations
|
||||
- **FromIO / FromIOK**: Lifting IO computations into StateIO
|
||||
- **Stateful operations**: Testing state threading through computations
|
||||
- **Composition**: Testing chained operations and state preservation
|
||||
|
||||
### 2. bind_test.go
|
||||
Tests for do-notation and binding operations:
|
||||
- **Do**: Starting do-notation chains
|
||||
- **Bind**: Binding computation results to state fields
|
||||
- **Let**: Computing derived values
|
||||
- **LetTo**: Setting constant values
|
||||
- **BindTo**: Wrapping values in constructors
|
||||
- **ApS**: Applicative sequencing
|
||||
- **Lens-based operations**: ApSL, BindL, LetL, LetToL for nested structures
|
||||
- **Complex do-notation**: Multi-step stateful computations
|
||||
|
||||
### 3. monad_test.go
|
||||
Tests for monadic laws and algebraic properties:
|
||||
|
||||
#### Monad Laws
|
||||
- **Left Identity**: `Of(a) >>= f ≡ f(a)`
|
||||
- **Right Identity**: `m >>= Of ≡ m`
|
||||
- **Associativity**: `(m >>= f) >>= g ≡ m >>= (x => f(x) >>= g)`
|
||||
|
||||
#### Functor Laws
|
||||
- **Identity**: `Map(id) ≡ id`
|
||||
- **Composition**: `Map(f . g) ≡ Map(f) . Map(g)`
|
||||
|
||||
#### Applicative Laws
|
||||
- **Identity**: `Ap(Of(id), v) ≡ v`
|
||||
- **Homomorphism**: `Ap(Of(f), Of(x)) ≡ Of(f(x))`
|
||||
- **Interchange**: `Ap(u, Of(y)) ≡ Ap(Of(f => f(y)), u)`
|
||||
|
||||
#### Type Class Implementations
|
||||
- **Pointed**: Tests the Pointed interface implementation
|
||||
- **Functor**: Tests the Functor interface implementation
|
||||
- **Applicative**: Tests the Applicative interface implementation
|
||||
- **Monad**: Tests the Monad interface implementation
|
||||
|
||||
#### Equality Operations
|
||||
- **Eq**: Testing equality predicates for StateIO values
|
||||
- **FromStrictEquals**: Testing strict equality construction
|
||||
|
||||
### 4. resource_test.go
|
||||
Tests for resource management:
|
||||
- **WithResource**: Resource acquisition and release patterns
|
||||
- **Resource chaining**: Using resources in chained computations
|
||||
- **State tracking**: Verifying state changes during resource lifecycle
|
||||
|
||||
## Test Statistics
|
||||
|
||||
- **Total Tests**: 43
|
||||
- **All Tests Passing**: ✅
|
||||
- **Code Coverage**: 90.7%
|
||||
- **Test Execution Time**: ~3 seconds
|
||||
|
||||
## Functions Tested
|
||||
|
||||
### Core Operations (state.go)
|
||||
- ✅ Of
|
||||
- ✅ MonadMap
|
||||
- ✅ Map
|
||||
- ✅ MonadChain
|
||||
- ✅ Chain
|
||||
- ✅ MonadAp
|
||||
- ✅ Ap
|
||||
- ✅ FromIO
|
||||
- ✅ FromIOK
|
||||
|
||||
### Do-Notation (bind.go)
|
||||
- ✅ Do
|
||||
- ✅ Bind
|
||||
- ✅ Let
|
||||
- ✅ LetTo
|
||||
- ✅ BindTo
|
||||
- ✅ ApS
|
||||
- ✅ ApSL
|
||||
- ✅ BindL
|
||||
- ✅ LetL
|
||||
- ✅ LetToL
|
||||
|
||||
### Type Classes (monad.go)
|
||||
- ✅ Pointed
|
||||
- ✅ Functor
|
||||
- ✅ Applicative
|
||||
- ✅ Monad
|
||||
|
||||
### Equality (eq.go)
|
||||
- ✅ Eq
|
||||
- ✅ FromStrictEquals
|
||||
|
||||
### Resource Management (resource.go)
|
||||
- ✅ WithResource
|
||||
- ✅ uncurryState (internal, tested via WithResource)
|
||||
|
||||
## Monadic Laws Verification
|
||||
|
||||
All three fundamental monad laws have been verified:
|
||||
|
||||
1. **Left Identity Law**: Verified that wrapping a value and immediately binding it is equivalent to just applying the function
|
||||
2. **Right Identity Law**: Verified that binding with the unit function returns the original computation
|
||||
3. **Associativity Law**: Verified that the order of binding operations doesn't matter
|
||||
|
||||
Additionally, functor and applicative laws have been verified to ensure the type class hierarchy is correctly implemented.
|
||||
|
||||
## Documentation Review
|
||||
|
||||
The package documentation in `doc.go` has been reviewed and is comprehensive, including:
|
||||
- Clear explanation of the StateIO monad transformer
|
||||
- Fantasy Land specification compliance
|
||||
- Core operations documentation
|
||||
- Example usage patterns
|
||||
- Monad laws statement
|
||||
|
||||
## Notes
|
||||
|
||||
- The StateIO monad correctly threads state through all operations
|
||||
- All monadic laws are satisfied
|
||||
- Resource management works correctly with proper cleanup
|
||||
- Lens-based operations enable working with nested state structures
|
||||
- The implementation follows functional programming best practices
|
||||
- Test coverage is excellent at 90.7%, with the remaining 9.3% likely being edge cases or internal helper functions
|
||||
270
v2/stateio/bind.go
Normal file
270
v2/stateio/bind.go
Normal file
@@ -0,0 +1,270 @@
|
||||
// Copyright (c) 2024 - 2025 IBM Corp.
|
||||
// All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package stateio
|
||||
|
||||
import (
|
||||
"github.com/IBM/fp-go/v2/function"
|
||||
A "github.com/IBM/fp-go/v2/internal/apply"
|
||||
C "github.com/IBM/fp-go/v2/internal/chain"
|
||||
F "github.com/IBM/fp-go/v2/internal/functor"
|
||||
)
|
||||
|
||||
// Do starts a do-notation chain for building computations in a fluent style.
|
||||
// This is typically used with Bind, Let, and other combinators to compose
|
||||
// stateful computations with side effects.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// type Result struct {
|
||||
// name string
|
||||
// age int
|
||||
// }
|
||||
// result := function.Pipe2(
|
||||
// Do[AppState](Result{}),
|
||||
// Bind(...),
|
||||
// Let(...),
|
||||
// )
|
||||
//
|
||||
//go:inline
|
||||
func Do[ST, A any](
|
||||
empty A,
|
||||
) StateIO[ST, A] {
|
||||
return Of[ST](empty)
|
||||
}
|
||||
|
||||
// Bind executes a computation and binds its result to a field in the accumulator state.
|
||||
// This is used in do-notation to sequence dependent computations.
|
||||
//
|
||||
// The setter function takes the computed value and returns a function that updates
|
||||
// the accumulator state. The computation function (f) receives the current accumulator
|
||||
// state and returns a StateIO computation.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// result := function.Pipe2(
|
||||
// Do[AppState](Result{}),
|
||||
// Bind(
|
||||
// func(name string) func(Result) Result {
|
||||
// return func(r Result) Result { return Result{name: name, age: r.age} }
|
||||
// },
|
||||
// func(r Result) StateIO[AppState, string] {
|
||||
// return Of[AppState]("John")
|
||||
// },
|
||||
// ),
|
||||
// )
|
||||
//
|
||||
//go:inline
|
||||
func Bind[ST, S1, S2, T any](
|
||||
setter func(T) func(S1) S2,
|
||||
f Kleisli[ST, S1, T],
|
||||
) Operator[ST, S1, S2] {
|
||||
return C.Bind(
|
||||
Chain[ST, S1, S2],
|
||||
Map[ST, T, S2],
|
||||
setter,
|
||||
f,
|
||||
)
|
||||
}
|
||||
|
||||
// Let computes a derived value and binds it to a field in the accumulator state.
|
||||
// Unlike Bind, this does not execute a monadic computation, just a pure function.
|
||||
//
|
||||
// The key function takes the computed value and returns a function that updates
|
||||
// the accumulator state. The computation function (f) receives the current accumulator
|
||||
// state and returns a pure value.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// result := function.Pipe2(
|
||||
// Do[AppState](Result{age: 25}),
|
||||
// Let(
|
||||
// func(isAdult bool) func(Result) Result {
|
||||
// return func(r Result) Result { return Result{age: r.age, isAdult: isAdult} }
|
||||
// },
|
||||
// func(r Result) bool { return r.age >= 18 },
|
||||
// ),
|
||||
// )
|
||||
//
|
||||
//go:inline
|
||||
func Let[ST, S1, S2, T any](
|
||||
key func(T) func(S1) S2,
|
||||
f func(S1) T,
|
||||
) Operator[ST, S1, S2] {
|
||||
return F.Let(
|
||||
Map[ST, S1, S2],
|
||||
key,
|
||||
f,
|
||||
)
|
||||
}
|
||||
|
||||
// LetTo binds a constant value to a field in the accumulator state.
|
||||
// This is useful for setting fixed values in the accumulator during do-notation.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// result := function.Pipe2(
|
||||
// Do[AppState](Result{}),
|
||||
// LetTo(
|
||||
// func(status string) func(Result) Result {
|
||||
// return func(r Result) Result { return Result{status: status} }
|
||||
// },
|
||||
// "active",
|
||||
// ),
|
||||
// )
|
||||
//
|
||||
//go:inline
|
||||
func LetTo[ST, S1, S2, T any](
|
||||
key func(T) func(S1) S2,
|
||||
b T,
|
||||
) Operator[ST, S1, S2] {
|
||||
return F.LetTo(
|
||||
Map[ST, S1, S2],
|
||||
key,
|
||||
b,
|
||||
)
|
||||
}
|
||||
|
||||
// BindTo wraps a value in a simple constructor, typically used to start a do-notation chain
|
||||
// after getting an initial value. This transforms a StateIO[S, T] into StateIO[S, S1]
|
||||
// by applying a constructor function.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// result := function.Pipe2(
|
||||
// Of[AppState](42),
|
||||
// BindTo[AppState](func(x int) Result { return Result{value: x} }),
|
||||
// )
|
||||
//
|
||||
//go:inline
|
||||
func BindTo[ST, S1, T any](
|
||||
setter func(T) S1,
|
||||
) Operator[ST, T, S1] {
|
||||
return C.BindTo(
|
||||
Map[ST, T, S1],
|
||||
setter,
|
||||
)
|
||||
}
|
||||
|
||||
// ApS applies a computation in sequence and binds the result to a field.
|
||||
// This is the applicative version of Bind, useful for parallel-style composition
|
||||
// where computations don't depend on each other's results.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// result := function.Pipe2(
|
||||
// Do[AppState](Result{}),
|
||||
// ApS(
|
||||
// func(count int) func(Result) Result {
|
||||
// return func(r Result) Result { return Result{count: count} }
|
||||
// },
|
||||
// Of[AppState](42),
|
||||
// ),
|
||||
// )
|
||||
//
|
||||
//go:inline
|
||||
func ApS[ST, S1, S2, T any](
|
||||
setter func(T) func(S1) S2,
|
||||
fa StateIO[ST, T],
|
||||
) Operator[ST, S1, S2] {
|
||||
return A.ApS(
|
||||
Ap[S2, ST, T],
|
||||
Map[ST, S1, func(T) S2],
|
||||
setter,
|
||||
fa,
|
||||
)
|
||||
}
|
||||
|
||||
// ApSL is a lens-based variant of ApS for working with nested structures.
|
||||
// It uses a lens to focus on a specific field in the accumulator state,
|
||||
// making it easier to update nested fields without manual destructuring.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// nameLens := lens.Prop[Result, string]("name")
|
||||
// result := function.Pipe2(
|
||||
// Do[AppState](Result{}),
|
||||
// ApSL(nameLens, Of[AppState]("John")),
|
||||
// )
|
||||
//
|
||||
//go:inline
|
||||
func ApSL[ST, S, T any](
|
||||
lens Lens[S, T],
|
||||
fa StateIO[ST, T],
|
||||
) Endomorphism[StateIO[ST, S]] {
|
||||
return ApS(lens.Set, fa)
|
||||
}
|
||||
|
||||
// BindL is a lens-based variant of Bind for working with nested structures.
|
||||
// It uses a lens to focus on a specific field in the accumulator state,
|
||||
// allowing you to update that field based on a computation that depends on its current value.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// counterLens := lens.Prop[Result, int]("counter")
|
||||
// result := function.Pipe2(
|
||||
// Do[AppState](Result{counter: 0}),
|
||||
// BindL(counterLens, func(n int) StateIO[AppState, int] {
|
||||
// return Of[AppState](n + 1)
|
||||
// }),
|
||||
// )
|
||||
//
|
||||
//go:inline
|
||||
func BindL[ST, S, T any](
|
||||
lens Lens[S, T],
|
||||
f Kleisli[ST, T, T],
|
||||
) Endomorphism[StateIO[ST, S]] {
|
||||
return Bind(lens.Set, function.Flow2(lens.Get, f))
|
||||
}
|
||||
|
||||
// LetL is a lens-based variant of Let for working with nested structures.
|
||||
// It uses a lens to focus on a specific field in the accumulator state,
|
||||
// allowing you to update that field using a pure function.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// counterLens := lens.Prop[Result, int]("counter")
|
||||
// result := function.Pipe2(
|
||||
// Do[AppState](Result{counter: 5}),
|
||||
// LetL(counterLens, func(n int) int { return n * 2 }),
|
||||
// )
|
||||
//
|
||||
//go:inline
|
||||
func LetL[ST, S, T any](
|
||||
lens Lens[S, T],
|
||||
f Endomorphism[T],
|
||||
) Endomorphism[StateIO[ST, S]] {
|
||||
return Let[ST](lens.Set, function.Flow2(lens.Get, f))
|
||||
}
|
||||
|
||||
// LetToL is a lens-based variant of LetTo for working with nested structures.
|
||||
// It uses a lens to focus on a specific field in the accumulator state,
|
||||
// allowing you to set that field to a constant value.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// statusLens := lens.Prop[Result, string]("status")
|
||||
// result := function.Pipe2(
|
||||
// Do[AppState](Result{}),
|
||||
// LetToL(statusLens, "active"),
|
||||
// )
|
||||
//
|
||||
//go:inline
|
||||
func LetToL[ST, S, T any](
|
||||
lens Lens[S, T],
|
||||
b T,
|
||||
) Endomorphism[StateIO[ST, S]] {
|
||||
return LetTo[ST](lens.Set, b)
|
||||
}
|
||||
410
v2/stateio/bind_test.go
Normal file
410
v2/stateio/bind_test.go
Normal file
@@ -0,0 +1,410 @@
|
||||
// Copyright (c) 2024 - 2025 IBM Corp.
|
||||
// All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package stateio
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
F "github.com/IBM/fp-go/v2/function"
|
||||
"github.com/IBM/fp-go/v2/optics/lens"
|
||||
"github.com/IBM/fp-go/v2/pair"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type BindTestState struct {
|
||||
Name string
|
||||
Age int
|
||||
Email string
|
||||
}
|
||||
|
||||
func TestDo(t *testing.T) {
|
||||
initial := BindTestState{}
|
||||
computation := Do[BindTestState](BindTestState{Name: "initial"})
|
||||
|
||||
result := computation(initial)()
|
||||
|
||||
assert.Equal(t, BindTestState{Name: "initial"}, pair.Tail(result))
|
||||
assert.Equal(t, initial, pair.Head(result))
|
||||
}
|
||||
|
||||
func TestBind(t *testing.T) {
|
||||
initial := BindTestState{}
|
||||
|
||||
getName := func(s BindTestState) StateIO[BindTestState, string] {
|
||||
return Of[BindTestState]("John")
|
||||
}
|
||||
|
||||
computation := F.Pipe2(
|
||||
Do[BindTestState](BindTestState{}),
|
||||
Bind(
|
||||
func(name string) func(BindTestState) BindTestState {
|
||||
return func(s BindTestState) BindTestState {
|
||||
s.Name = name
|
||||
return s
|
||||
}
|
||||
},
|
||||
getName,
|
||||
),
|
||||
Map[BindTestState](func(s BindTestState) string { return s.Name }),
|
||||
)
|
||||
|
||||
result := computation(initial)()
|
||||
|
||||
assert.Equal(t, "John", pair.Tail(result))
|
||||
}
|
||||
|
||||
func TestBindMultiple(t *testing.T) {
|
||||
initial := BindTestState{}
|
||||
|
||||
getName := func(s BindTestState) StateIO[BindTestState, string] {
|
||||
return Of[BindTestState]("Jane")
|
||||
}
|
||||
|
||||
getAge := func(s BindTestState) StateIO[BindTestState, int] {
|
||||
return Of[BindTestState](30)
|
||||
}
|
||||
|
||||
computation := F.Pipe3(
|
||||
Do[BindTestState](BindTestState{}),
|
||||
Bind(
|
||||
func(name string) func(BindTestState) BindTestState {
|
||||
return func(s BindTestState) BindTestState {
|
||||
s.Name = name
|
||||
return s
|
||||
}
|
||||
},
|
||||
getName,
|
||||
),
|
||||
Bind(
|
||||
func(age int) func(BindTestState) BindTestState {
|
||||
return func(s BindTestState) BindTestState {
|
||||
s.Age = age
|
||||
return s
|
||||
}
|
||||
},
|
||||
getAge,
|
||||
),
|
||||
Map[BindTestState](func(s BindTestState) BindTestState { return s }),
|
||||
)
|
||||
|
||||
result := computation(initial)()
|
||||
finalState := pair.Tail(result)
|
||||
|
||||
assert.Equal(t, "Jane", finalState.Name)
|
||||
assert.Equal(t, 30, finalState.Age)
|
||||
}
|
||||
|
||||
func TestLet(t *testing.T) {
|
||||
initial := BindTestState{Age: 25}
|
||||
|
||||
computation := F.Pipe2(
|
||||
Do[BindTestState](BindTestState{Age: 25}),
|
||||
Let[BindTestState](
|
||||
func(isAdult bool) func(BindTestState) BindTestState {
|
||||
return func(s BindTestState) BindTestState {
|
||||
if isAdult {
|
||||
s.Email = "adult@example.com"
|
||||
} else {
|
||||
s.Email = "minor@example.com"
|
||||
}
|
||||
return s
|
||||
}
|
||||
},
|
||||
func(s BindTestState) bool { return s.Age >= 18 },
|
||||
),
|
||||
Map[BindTestState](func(s BindTestState) string { return s.Email }),
|
||||
)
|
||||
|
||||
result := computation(initial)()
|
||||
|
||||
assert.Equal(t, "adult@example.com", pair.Tail(result))
|
||||
}
|
||||
|
||||
func TestLetTo(t *testing.T) {
|
||||
initial := BindTestState{}
|
||||
|
||||
computation := F.Pipe2(
|
||||
Do[BindTestState](BindTestState{}),
|
||||
LetTo[BindTestState](
|
||||
func(email string) func(BindTestState) BindTestState {
|
||||
return func(s BindTestState) BindTestState {
|
||||
s.Email = email
|
||||
return s
|
||||
}
|
||||
},
|
||||
"constant@example.com",
|
||||
),
|
||||
Map[BindTestState](func(s BindTestState) string { return s.Email }),
|
||||
)
|
||||
|
||||
result := computation(initial)()
|
||||
|
||||
assert.Equal(t, "constant@example.com", pair.Tail(result))
|
||||
}
|
||||
|
||||
func TestBindTo(t *testing.T) {
|
||||
initial := BindTestState{}
|
||||
|
||||
computation := F.Pipe2(
|
||||
Of[BindTestState]("Alice"),
|
||||
BindTo[BindTestState](func(name string) BindTestState {
|
||||
return BindTestState{Name: name}
|
||||
}),
|
||||
Map[BindTestState](func(s BindTestState) string { return s.Name }),
|
||||
)
|
||||
|
||||
result := computation(initial)()
|
||||
|
||||
assert.Equal(t, "Alice", pair.Tail(result))
|
||||
}
|
||||
|
||||
func TestApS(t *testing.T) {
|
||||
initial := BindTestState{}
|
||||
|
||||
computation := F.Pipe3(
|
||||
Do[BindTestState](BindTestState{}),
|
||||
ApS(
|
||||
func(name string) func(BindTestState) BindTestState {
|
||||
return func(s BindTestState) BindTestState {
|
||||
s.Name = name
|
||||
return s
|
||||
}
|
||||
},
|
||||
Of[BindTestState]("Bob"),
|
||||
),
|
||||
ApS(
|
||||
func(age int) func(BindTestState) BindTestState {
|
||||
return func(s BindTestState) BindTestState {
|
||||
s.Age = age
|
||||
return s
|
||||
}
|
||||
},
|
||||
Of[BindTestState](40),
|
||||
),
|
||||
Map[BindTestState](func(s BindTestState) BindTestState { return s }),
|
||||
)
|
||||
|
||||
result := computation(initial)()
|
||||
finalState := pair.Tail(result)
|
||||
|
||||
assert.Equal(t, "Bob", finalState.Name)
|
||||
assert.Equal(t, 40, finalState.Age)
|
||||
}
|
||||
|
||||
func TestComplexDoNotation(t *testing.T) {
|
||||
initial := BindTestState{}
|
||||
|
||||
fetchName := func(s BindTestState) StateIO[BindTestState, string] {
|
||||
return Of[BindTestState]("Charlie")
|
||||
}
|
||||
|
||||
fetchAge := func(s BindTestState) StateIO[BindTestState, int] {
|
||||
return Of[BindTestState](35)
|
||||
}
|
||||
|
||||
computation := F.Pipe4(
|
||||
Do[BindTestState](BindTestState{}),
|
||||
Bind(
|
||||
func(name string) func(BindTestState) BindTestState {
|
||||
return func(s BindTestState) BindTestState {
|
||||
s.Name = name
|
||||
return s
|
||||
}
|
||||
},
|
||||
fetchName,
|
||||
),
|
||||
Bind(
|
||||
func(age int) func(BindTestState) BindTestState {
|
||||
return func(s BindTestState) BindTestState {
|
||||
s.Age = age
|
||||
return s
|
||||
}
|
||||
},
|
||||
fetchAge,
|
||||
),
|
||||
Let[BindTestState](
|
||||
func(email string) func(BindTestState) BindTestState {
|
||||
return func(s BindTestState) BindTestState {
|
||||
s.Email = email
|
||||
return s
|
||||
}
|
||||
},
|
||||
func(s BindTestState) string {
|
||||
return fmt.Sprintf("%s@example.com", s.Name)
|
||||
},
|
||||
),
|
||||
Map[BindTestState](func(s BindTestState) BindTestState { return s }),
|
||||
)
|
||||
|
||||
result := computation(initial)()
|
||||
finalState := pair.Tail(result)
|
||||
|
||||
assert.Equal(t, "Charlie", finalState.Name)
|
||||
assert.Equal(t, 35, finalState.Age)
|
||||
assert.Equal(t, "Charlie@example.com", finalState.Email)
|
||||
}
|
||||
|
||||
// Lens-based tests
|
||||
type NestedState struct {
|
||||
User BindTestState
|
||||
ID int
|
||||
}
|
||||
|
||||
var userLens = lens.MakeLensCurried(
|
||||
func(s NestedState) BindTestState { return s.User },
|
||||
func(user BindTestState) func(NestedState) NestedState {
|
||||
return func(s NestedState) NestedState {
|
||||
s.User = user
|
||||
return s
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
var nameLens = lens.MakeLensCurried(
|
||||
func(s BindTestState) string { return s.Name },
|
||||
func(name string) func(BindTestState) BindTestState {
|
||||
return func(s BindTestState) BindTestState {
|
||||
s.Name = name
|
||||
return s
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
func TestApSL(t *testing.T) {
|
||||
initial := NestedState{User: BindTestState{}, ID: 1}
|
||||
|
||||
computation := F.Pipe2(
|
||||
Do[NestedState](NestedState{User: BindTestState{}, ID: 1}),
|
||||
ApSL(
|
||||
userLens,
|
||||
Of[NestedState](BindTestState{Name: "David", Age: 28}),
|
||||
),
|
||||
Map[NestedState](func(s NestedState) string { return s.User.Name }),
|
||||
)
|
||||
|
||||
result := computation(initial)()
|
||||
|
||||
assert.Equal(t, "David", pair.Tail(result))
|
||||
}
|
||||
|
||||
func TestBindL(t *testing.T) {
|
||||
initial := NestedState{User: BindTestState{Name: "Eve"}, ID: 2}
|
||||
|
||||
updateUser := func(user BindTestState) StateIO[NestedState, BindTestState] {
|
||||
return Of[NestedState](BindTestState{
|
||||
Name: user.Name + " Updated",
|
||||
Age: user.Age + 1,
|
||||
Email: user.Email,
|
||||
})
|
||||
}
|
||||
|
||||
computation := F.Pipe2(
|
||||
Do[NestedState](NestedState{User: BindTestState{Name: "Eve", Age: 20}, ID: 2}),
|
||||
BindL(userLens, updateUser),
|
||||
Map[NestedState](func(s NestedState) string { return s.User.Name }),
|
||||
)
|
||||
|
||||
result := computation(initial)()
|
||||
|
||||
assert.Equal(t, "Eve Updated", pair.Tail(result))
|
||||
}
|
||||
|
||||
func TestLetL(t *testing.T) {
|
||||
initial := NestedState{User: BindTestState{Name: "Frank"}, ID: 3}
|
||||
|
||||
uppercase := func(name string) string {
|
||||
return fmt.Sprintf("%s (UPPERCASE)", name)
|
||||
}
|
||||
|
||||
composedLens := F.Pipe1(userLens, lens.Compose[NestedState](nameLens))
|
||||
|
||||
computation := F.Pipe2(
|
||||
Do[NestedState](NestedState{User: BindTestState{Name: "Frank"}, ID: 3}),
|
||||
LetL[NestedState](
|
||||
composedLens,
|
||||
uppercase,
|
||||
),
|
||||
Map[NestedState](func(s NestedState) string { return s.User.Name }),
|
||||
)
|
||||
|
||||
result := computation(initial)()
|
||||
|
||||
assert.Equal(t, "Frank (UPPERCASE)", pair.Tail(result))
|
||||
}
|
||||
|
||||
func TestLetToL(t *testing.T) {
|
||||
initial := NestedState{User: BindTestState{}, ID: 4}
|
||||
|
||||
composedLens := F.Pipe1(userLens, lens.Compose[NestedState](nameLens))
|
||||
|
||||
computation := F.Pipe2(
|
||||
Do[NestedState](NestedState{User: BindTestState{}, ID: 4}),
|
||||
LetToL[NestedState](
|
||||
composedLens,
|
||||
"Grace",
|
||||
),
|
||||
Map[NestedState](func(s NestedState) string { return s.User.Name }),
|
||||
)
|
||||
|
||||
result := computation(initial)()
|
||||
|
||||
assert.Equal(t, "Grace", pair.Tail(result))
|
||||
}
|
||||
|
||||
func TestDoNotationWithStatefulOperations(t *testing.T) {
|
||||
type Counter struct {
|
||||
Value int
|
||||
}
|
||||
|
||||
initial := Counter{Value: 0}
|
||||
|
||||
increment := func(c Counter) StateIO[Counter, int] {
|
||||
return func(s Counter) IO[Pair[Counter, int]] {
|
||||
return func() Pair[Counter, int] {
|
||||
newState := Counter{Value: s.Value + 1}
|
||||
return pair.MakePair(newState, newState.Value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
computation := F.Pipe3(
|
||||
Do[Counter](Counter{Value: 0}),
|
||||
Bind(
|
||||
func(v int) func(Counter) Counter {
|
||||
return func(c Counter) Counter {
|
||||
return Counter{Value: v}
|
||||
}
|
||||
},
|
||||
increment,
|
||||
),
|
||||
Bind(
|
||||
func(v int) func(Counter) Counter {
|
||||
return func(c Counter) Counter {
|
||||
return Counter{Value: v}
|
||||
}
|
||||
},
|
||||
increment,
|
||||
),
|
||||
Map[Counter](func(c Counter) int { return c.Value }),
|
||||
)
|
||||
|
||||
result := computation(initial)()
|
||||
|
||||
// After two increments starting from 0, we should have 2
|
||||
assert.Equal(t, 2, pair.Tail(result))
|
||||
assert.Equal(t, 2, pair.Head(result).Value)
|
||||
}
|
||||
128
v2/stateio/doc.go
Normal file
128
v2/stateio/doc.go
Normal file
@@ -0,0 +1,128 @@
|
||||
// Copyright (c) 2024 - 2025 IBM Corp.
|
||||
// All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Package stateio provides a functional programming abstraction that combines
|
||||
// stateful computations with side effects.
|
||||
//
|
||||
// # Fantasy Land Specification
|
||||
//
|
||||
// This is a monad transformer combining:
|
||||
// - State monad: https://github.com/fantasyland/fantasy-land
|
||||
// - IO monad: https://github.com/fantasyland/fantasy-land
|
||||
//
|
||||
// Implemented Fantasy Land algebras:
|
||||
// - Functor: https://github.com/fantasyland/fantasy-land#functor
|
||||
// - Apply: https://github.com/fantasyland/fantasy-land#apply
|
||||
// - Applicative: https://github.com/fantasyland/fantasy-land#applicative
|
||||
// - Chain: https://github.com/fantasyland/fantasy-land#chain
|
||||
// - Monad: https://github.com/fantasyland/fantasy-land#monad
|
||||
//
|
||||
// # StateIO
|
||||
//
|
||||
// StateIO[S, A] represents a computation that:
|
||||
// - Manages state of type S (State monad)
|
||||
// - Performs side effects (IO monad)
|
||||
// - Produces a value of type A
|
||||
//
|
||||
// The type is defined as:
|
||||
//
|
||||
// StateIO[S, A] = Reader[S, IO[Pair[S, A]]]
|
||||
//
|
||||
// This is particularly useful for:
|
||||
// - Stateful computations with side effects
|
||||
// - Managing application state while performing IO operations
|
||||
// - Composing operations that need both state management and effectful computation
|
||||
//
|
||||
// # Core Operations
|
||||
//
|
||||
// Construction:
|
||||
// - Of: Create a computation with a pure value
|
||||
// - FromIO: Lift an IO computation into StateIO
|
||||
//
|
||||
// Transformation:
|
||||
// - Map: Transform the value within the computation
|
||||
// - Chain: Sequence dependent computations (monadic bind)
|
||||
//
|
||||
// Combination:
|
||||
// - Ap: Apply a function in a context to a value in a context
|
||||
//
|
||||
// Kleisli Arrows:
|
||||
// - FromIOK: Lift an IO-returning function to a Kleisli arrow
|
||||
//
|
||||
// Do Notation (Monadic Composition):
|
||||
// - Do: Start a do-notation chain
|
||||
// - Bind: Bind a value from a computation
|
||||
// - BindTo: Bind a value to a simple constructor
|
||||
// - Let: Compute a derived value
|
||||
// - LetTo: Set a constant value
|
||||
// - ApS: Apply in sequence (for applicative composition)
|
||||
// - BindL/ApSL/LetL/LetToL: Lens-based variants for working with nested structures
|
||||
//
|
||||
// # Example Usage
|
||||
//
|
||||
// type AppState struct {
|
||||
// RequestCount int
|
||||
// LastError error
|
||||
// }
|
||||
//
|
||||
// // A computation that manages state and performs IO
|
||||
// func incrementCounter(data string) StateIO[AppState, string] {
|
||||
// return func(state AppState) IO[Pair[AppState, string]] {
|
||||
// return func() Pair[AppState, string] {
|
||||
// // Update state.RequestCount
|
||||
// // Perform IO operations
|
||||
// newState := AppState{RequestCount: state.RequestCount + 1}
|
||||
// result := "processed: " + data
|
||||
// return pair.MakePair(newState, result)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// // Compose operations using do-notation
|
||||
// type Result struct {
|
||||
// result string
|
||||
// count int
|
||||
// }
|
||||
//
|
||||
// computation := function.Pipe3(
|
||||
// Do[AppState](Result{}),
|
||||
// Bind(
|
||||
// func(result string) func(Result) Result {
|
||||
// return func(r Result) Result { return Result{result: result, count: r.count} }
|
||||
// },
|
||||
// func(r Result) StateIO[AppState, string] {
|
||||
// return incrementCounter("data")
|
||||
// },
|
||||
// ),
|
||||
// Map[AppState](func(r Result) string { return r.result }),
|
||||
// )
|
||||
//
|
||||
// // Execute with initial state
|
||||
// initialState := AppState{RequestCount: 0}
|
||||
// outcome := computation(initialState)() // Returns Pair[AppState, string]
|
||||
//
|
||||
// # Monad Laws
|
||||
//
|
||||
// StateIO satisfies the monad laws:
|
||||
// - Left Identity: Of(a) >>= f ≡ f(a)
|
||||
// - Right Identity: m >>= Of ≡ m
|
||||
// - Associativity: (m >>= f) >>= g ≡ m >>= (x => f(x) >>= g)
|
||||
//
|
||||
// Where >>= represents the Chain operation (monadic bind).
|
||||
//
|
||||
// These laws ensure that StateIO computations compose predictably and that
|
||||
// the order of composition doesn't affect the final result (beyond the order
|
||||
// of effects and state updates).
|
||||
package stateio
|
||||
64
v2/stateio/eq.go
Normal file
64
v2/stateio/eq.go
Normal file
@@ -0,0 +1,64 @@
|
||||
// Copyright (c) 2024 - 2025 IBM Corp.
|
||||
// All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package stateio
|
||||
|
||||
import (
|
||||
"github.com/IBM/fp-go/v2/eq"
|
||||
"github.com/IBM/fp-go/v2/function"
|
||||
"github.com/IBM/fp-go/v2/io"
|
||||
)
|
||||
|
||||
// Eq constructs an equality checker for StateIO values.
|
||||
// It takes an equality checker for IO[Pair[S, A]] and returns a function that,
|
||||
// given an initial state S, produces an equality checker for StateIO[S, A].
|
||||
//
|
||||
// Two StateIO values are considered equal if, when executed with the same initial state,
|
||||
// they produce equal IO[Pair[S, A]] results.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// eqIO := io.FromStrictEquals[Pair[AppState, int]]()
|
||||
// eqStateIO := Eq[AppState, int](eqIO)
|
||||
// initialState := AppState{}
|
||||
// areEqual := eqStateIO(initialState).Equals(stateIO1, stateIO2)
|
||||
func Eq[
|
||||
S, A any](eqr eq.Eq[IO[Pair[S, A]]]) func(S) eq.Eq[StateIO[S, A]] {
|
||||
return func(s S) eq.Eq[StateIO[S, A]] {
|
||||
return eq.FromEquals(func(l, r StateIO[S, A]) bool {
|
||||
return eqr.Equals(l(s), r(s))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// FromStrictEquals constructs an equality checker for StateIO values where both
|
||||
// the state S and value A are comparable types.
|
||||
//
|
||||
// This is a convenience function that uses Go's built-in equality (==) for comparison.
|
||||
// It returns a function that, given an initial state, produces an equality checker
|
||||
// for StateIO[S, A].
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// eqStateIO := FromStrictEquals[AppState, int]()
|
||||
// initialState := AppState{}
|
||||
// areEqual := eqStateIO(initialState).Equals(stateIO1, stateIO2)
|
||||
func FromStrictEquals[
|
||||
S, A comparable]() func(S) eq.Eq[StateIO[S, A]] {
|
||||
return function.Pipe1(
|
||||
io.FromStrictEquals[Pair[S, A]](),
|
||||
Eq[S, A],
|
||||
)
|
||||
}
|
||||
152
v2/stateio/monad.go
Normal file
152
v2/stateio/monad.go
Normal file
@@ -0,0 +1,152 @@
|
||||
// Copyright (c) 2024 - 2025 IBM Corp.
|
||||
// All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package stateio
|
||||
|
||||
import (
|
||||
"github.com/IBM/fp-go/v2/internal/applicative"
|
||||
"github.com/IBM/fp-go/v2/internal/functor"
|
||||
"github.com/IBM/fp-go/v2/internal/monad"
|
||||
"github.com/IBM/fp-go/v2/internal/pointed"
|
||||
)
|
||||
|
||||
// StateIOPointed implements the Pointed typeclass for StateIO.
|
||||
// It provides the 'Of' operation to lift pure values into the StateIO context.
|
||||
type StateIOPointed[
|
||||
S, A any,
|
||||
] struct{}
|
||||
|
||||
// StateIOFunctor implements the Functor typeclass for StateIO.
|
||||
// It provides the 'Map' operation to transform values within the StateIO context.
|
||||
type StateIOFunctor[
|
||||
S, A, B any,
|
||||
] struct{}
|
||||
|
||||
// StateIOApplicative implements the Applicative typeclass for StateIO.
|
||||
// It provides 'Of', 'Map', and 'Ap' operations for applicative composition.
|
||||
type StateIOApplicative[
|
||||
S, A, B any,
|
||||
] struct{}
|
||||
|
||||
// StateIOMonad implements the Monad typeclass for StateIO.
|
||||
// It provides 'Of', 'Map', 'Chain', and 'Ap' operations for monadic composition.
|
||||
type StateIOMonad[
|
||||
S, A, B any,
|
||||
] struct{}
|
||||
|
||||
// Of lifts a pure value into the StateIO context.
|
||||
func (o *StateIOPointed[S, A]) Of(a A) StateIO[S, A] {
|
||||
return Of[S](a)
|
||||
}
|
||||
|
||||
// Of lifts a pure value into the StateIO context.
|
||||
func (o *StateIOMonad[S, A, B]) Of(a A) StateIO[S, A] {
|
||||
return Of[S](a)
|
||||
}
|
||||
|
||||
// Of lifts a pure value into the StateIO context.
|
||||
func (o *StateIOApplicative[S, A, B]) Of(a A) StateIO[S, A] {
|
||||
return Of[S](a)
|
||||
}
|
||||
|
||||
// Map transforms the value within a StateIO using the provided function.
|
||||
func (o *StateIOMonad[S, A, B]) Map(f func(A) B) Operator[S, A, B] {
|
||||
return Map[S](f)
|
||||
}
|
||||
|
||||
// Map transforms the value within a StateIO using the provided function.
|
||||
func (o *StateIOApplicative[S, A, B]) Map(f func(A) B) Operator[S, A, B] {
|
||||
return Map[S](f)
|
||||
}
|
||||
|
||||
// Map transforms the value within a StateIO using the provided function.
|
||||
func (o *StateIOFunctor[S, A, B]) Map(f func(A) B) Operator[S, A, B] {
|
||||
return Map[S](f)
|
||||
}
|
||||
|
||||
// Chain sequences two StateIO computations, threading state through both.
|
||||
func (o *StateIOMonad[S, A, B]) Chain(f Kleisli[S, A, B]) Operator[S, A, B] {
|
||||
return Chain(f)
|
||||
}
|
||||
|
||||
// Ap applies a function wrapped in StateIO to a value wrapped in StateIO.
|
||||
func (o *StateIOMonad[S, A, B]) Ap(fa StateIO[S, A]) Operator[S, func(A) B, B] {
|
||||
return Ap[B](fa)
|
||||
}
|
||||
|
||||
// Ap applies a function wrapped in StateIO to a value wrapped in StateIO.
|
||||
func (o *StateIOApplicative[S, A, B]) Ap(fa StateIO[S, A]) Operator[S, func(A) B, B] {
|
||||
return Ap[B](fa)
|
||||
}
|
||||
|
||||
// Pointed returns a Pointed instance for StateIO.
|
||||
// The Pointed typeclass provides the 'Of' operation to lift pure values
|
||||
// into the StateIO context.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// p := Pointed[AppState, int]()
|
||||
// result := p.Of(42)
|
||||
func Pointed[
|
||||
S, A any,
|
||||
]() pointed.Pointed[A, StateIO[S, A]] {
|
||||
return &StateIOPointed[S, A]{}
|
||||
}
|
||||
|
||||
// Functor returns a Functor instance for StateIO.
|
||||
// The Functor typeclass provides the 'Map' operation to transform values
|
||||
// within the StateIO context.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// f := Functor[AppState, int, string]()
|
||||
// result := f.Map(strconv.Itoa)(Of[AppState](42))
|
||||
func Functor[
|
||||
S, A, B any,
|
||||
]() functor.Functor[A, B, StateIO[S, A], StateIO[S, B]] {
|
||||
return &StateIOFunctor[S, A, B]{}
|
||||
}
|
||||
|
||||
// Applicative returns an Applicative instance for StateIO.
|
||||
// The Applicative typeclass provides 'Of', 'Map', and 'Ap' operations
|
||||
// for applicative-style composition of StateIO computations.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// app := Applicative[AppState, int, string]()
|
||||
// fab := Of[AppState](func(x int) string { return strconv.Itoa(x) })
|
||||
// fa := Of[AppState](42)
|
||||
// result := app.Ap(fa)(fab)
|
||||
func Applicative[
|
||||
S, A, B any,
|
||||
]() applicative.Applicative[A, B, StateIO[S, A], StateIO[S, B], StateIO[S, func(A) B]] {
|
||||
return &StateIOApplicative[S, A, B]{}
|
||||
}
|
||||
|
||||
// Monad returns a Monad instance for StateIO.
|
||||
// The Monad typeclass provides 'Of', 'Map', 'Chain', and 'Ap' operations
|
||||
// for monadic composition of StateIO computations.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// m := Monad[AppState, int, string]()
|
||||
// result := m.Chain(func(x int) StateIO[AppState, string] {
|
||||
// return Of[AppState](strconv.Itoa(x))
|
||||
// })(Of[AppState](42))
|
||||
func Monad[
|
||||
S, A, B any,
|
||||
]() monad.Monad[A, B, StateIO[S, A], StateIO[S, B], StateIO[S, func(A) B]] {
|
||||
return &StateIOMonad[S, A, B]{}
|
||||
}
|
||||
304
v2/stateio/monad_test.go
Normal file
304
v2/stateio/monad_test.go
Normal file
@@ -0,0 +1,304 @@
|
||||
// Copyright (c) 2024 - 2025 IBM Corp.
|
||||
// All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package stateio
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
EQ "github.com/IBM/fp-go/v2/eq"
|
||||
F "github.com/IBM/fp-go/v2/function"
|
||||
"github.com/IBM/fp-go/v2/pair"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type MonadTestState struct {
|
||||
Value int
|
||||
}
|
||||
|
||||
// Test Left Identity Law: Of(a) >>= f ≡ f(a)
|
||||
func TestMonadLeftIdentity(t *testing.T) {
|
||||
initial := MonadTestState{Value: 0}
|
||||
a := 42
|
||||
|
||||
f := func(x int) StateIO[MonadTestState, string] {
|
||||
return func(s MonadTestState) IO[Pair[MonadTestState, string]] {
|
||||
return func() Pair[MonadTestState, string] {
|
||||
newState := MonadTestState{Value: s.Value + x}
|
||||
return pair.MakePair(newState, fmt.Sprintf("%d", x))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Left side: Of(a) >>= f
|
||||
left := MonadChain(Of[MonadTestState](a), f)
|
||||
leftResult := left(initial)()
|
||||
|
||||
// Right side: f(a)
|
||||
right := f(a)
|
||||
rightResult := right(initial)()
|
||||
|
||||
assert.Equal(t, pair.Tail(rightResult), pair.Tail(leftResult))
|
||||
assert.Equal(t, pair.Head(rightResult).Value, pair.Head(leftResult).Value)
|
||||
}
|
||||
|
||||
// Test Right Identity Law: m >>= Of ≡ m
|
||||
func TestMonadRightIdentity(t *testing.T) {
|
||||
initial := MonadTestState{Value: 10}
|
||||
|
||||
m := func(s MonadTestState) IO[Pair[MonadTestState, int]] {
|
||||
return func() Pair[MonadTestState, int] {
|
||||
newState := MonadTestState{Value: s.Value * 2}
|
||||
return pair.MakePair(newState, newState.Value)
|
||||
}
|
||||
}
|
||||
|
||||
// Left side: m >>= Of
|
||||
left := MonadChain(m, func(x int) StateIO[MonadTestState, int] {
|
||||
return Of[MonadTestState](x)
|
||||
})
|
||||
leftResult := left(initial)()
|
||||
|
||||
// Right side: m
|
||||
rightResult := m(initial)()
|
||||
|
||||
assert.Equal(t, pair.Tail(rightResult), pair.Tail(leftResult))
|
||||
assert.Equal(t, pair.Head(rightResult).Value, pair.Head(leftResult).Value)
|
||||
}
|
||||
|
||||
// Test Associativity Law: (m >>= f) >>= g ≡ m >>= (x => f(x) >>= g)
|
||||
func TestMonadAssociativity(t *testing.T) {
|
||||
initial := MonadTestState{Value: 5}
|
||||
|
||||
m := Of[MonadTestState](10)
|
||||
|
||||
f := func(x int) StateIO[MonadTestState, int] {
|
||||
return func(s MonadTestState) IO[Pair[MonadTestState, int]] {
|
||||
return func() Pair[MonadTestState, int] {
|
||||
newState := MonadTestState{Value: s.Value + x}
|
||||
return pair.MakePair(newState, x*2)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
g := func(y int) StateIO[MonadTestState, string] {
|
||||
return func(s MonadTestState) IO[Pair[MonadTestState, string]] {
|
||||
return func() Pair[MonadTestState, string] {
|
||||
newState := MonadTestState{Value: s.Value + y}
|
||||
return pair.MakePair(newState, fmt.Sprintf("%d", y))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Left side: (m >>= f) >>= g
|
||||
left := MonadChain(MonadChain(m, f), g)
|
||||
leftResult := left(initial)()
|
||||
|
||||
// Right side: m >>= (x => f(x) >>= g)
|
||||
right := MonadChain(m, func(x int) StateIO[MonadTestState, string] {
|
||||
return MonadChain(f(x), g)
|
||||
})
|
||||
rightResult := right(initial)()
|
||||
|
||||
assert.Equal(t, pair.Tail(rightResult), pair.Tail(leftResult))
|
||||
assert.Equal(t, pair.Head(rightResult).Value, pair.Head(leftResult).Value)
|
||||
}
|
||||
|
||||
// Test Functor Identity Law: Map(id) ≡ id
|
||||
func TestFunctorIdentity(t *testing.T) {
|
||||
initial := MonadTestState{Value: 7}
|
||||
|
||||
m := Of[MonadTestState](42)
|
||||
|
||||
// Map with identity function
|
||||
mapped := MonadMap(m, F.Identity[int])
|
||||
mappedResult := mapped(initial)()
|
||||
|
||||
// Original computation
|
||||
originalResult := m(initial)()
|
||||
|
||||
assert.Equal(t, pair.Tail(originalResult), pair.Tail(mappedResult))
|
||||
assert.Equal(t, pair.Head(originalResult).Value, pair.Head(mappedResult).Value)
|
||||
}
|
||||
|
||||
// Test Functor Composition Law: Map(f . g) ≡ Map(f) . Map(g)
|
||||
func TestFunctorComposition(t *testing.T) {
|
||||
initial := MonadTestState{Value: 3}
|
||||
|
||||
m := Of[MonadTestState](10)
|
||||
|
||||
f := func(x int) int { return x * 2 }
|
||||
g := func(x int) int { return x + 5 }
|
||||
|
||||
// Left side: Map(f . g)
|
||||
left := MonadMap(m, F.Flow2(g, f))
|
||||
leftResult := left(initial)()
|
||||
|
||||
// Right side: Map(f) . Map(g)
|
||||
right := F.Pipe1(m, F.Flow2(Map[MonadTestState](g), Map[MonadTestState](f)))
|
||||
rightResult := right(initial)()
|
||||
|
||||
assert.Equal(t, pair.Tail(rightResult), pair.Tail(leftResult))
|
||||
assert.Equal(t, pair.Head(rightResult).Value, pair.Head(leftResult).Value)
|
||||
}
|
||||
|
||||
// Test Applicative Identity Law: Ap(Of(id), v) ≡ v
|
||||
func TestApplicativeIdentity(t *testing.T) {
|
||||
initial := MonadTestState{Value: 1}
|
||||
|
||||
v := Of[MonadTestState](42)
|
||||
|
||||
// Ap(Of(id), v)
|
||||
applied := MonadAp(Of[MonadTestState](F.Identity[int]), v)
|
||||
appliedResult := applied(initial)()
|
||||
|
||||
// v
|
||||
originalResult := v(initial)()
|
||||
|
||||
assert.Equal(t, pair.Tail(originalResult), pair.Tail(appliedResult))
|
||||
assert.Equal(t, pair.Head(originalResult).Value, pair.Head(appliedResult).Value)
|
||||
}
|
||||
|
||||
// Test Applicative Homomorphism Law: Ap(Of(f), Of(x)) ≡ Of(f(x))
|
||||
func TestApplicativeHomomorphism(t *testing.T) {
|
||||
initial := MonadTestState{Value: 2}
|
||||
|
||||
f := func(x int) int { return x * 3 }
|
||||
x := 7
|
||||
|
||||
// Left side: Ap(Of(f), Of(x))
|
||||
left := MonadAp(Of[MonadTestState](f), Of[MonadTestState](x))
|
||||
leftResult := left(initial)()
|
||||
|
||||
// Right side: Of(f(x))
|
||||
right := Of[MonadTestState](f(x))
|
||||
rightResult := right(initial)()
|
||||
|
||||
assert.Equal(t, pair.Tail(rightResult), pair.Tail(leftResult))
|
||||
assert.Equal(t, pair.Head(rightResult).Value, pair.Head(leftResult).Value)
|
||||
}
|
||||
|
||||
// Test Applicative Interchange Law: Ap(u, Of(y)) ≡ Ap(Of(f => f(y)), u)
|
||||
func TestApplicativeInterchange(t *testing.T) {
|
||||
initial := MonadTestState{Value: 4}
|
||||
|
||||
u := Of[MonadTestState](func(x int) int { return x + 10 })
|
||||
y := 5
|
||||
|
||||
// Left side: Ap(u, Of(y))
|
||||
left := MonadAp(u, Of[MonadTestState](y))
|
||||
leftResult := left(initial)()
|
||||
|
||||
// Right side: Ap(Of(f => f(y)), u)
|
||||
right := MonadAp(
|
||||
Of[MonadTestState](func(f func(int) int) int { return f(y) }),
|
||||
u,
|
||||
)
|
||||
rightResult := right(initial)()
|
||||
|
||||
assert.Equal(t, pair.Tail(rightResult), pair.Tail(leftResult))
|
||||
assert.Equal(t, pair.Head(rightResult).Value, pair.Head(leftResult).Value)
|
||||
}
|
||||
|
||||
// Test that StateIO implements Pointed interface
|
||||
func TestPointed(t *testing.T) {
|
||||
pointed := Pointed[MonadTestState, int]()
|
||||
computation := pointed.Of(42)
|
||||
|
||||
initial := MonadTestState{Value: 0}
|
||||
result := computation(initial)()
|
||||
|
||||
assert.Equal(t, 42, pair.Tail(result))
|
||||
assert.Equal(t, initial, pair.Head(result))
|
||||
}
|
||||
|
||||
// Test that StateIO implements Functor interface
|
||||
func TestFunctor(t *testing.T) {
|
||||
functor := Functor[MonadTestState, int, string]()
|
||||
|
||||
computation := Of[MonadTestState](42)
|
||||
mapped := functor.Map(func(x int) string { return fmt.Sprintf("%d", x) })(computation)
|
||||
|
||||
initial := MonadTestState{Value: 0}
|
||||
result := mapped(initial)()
|
||||
|
||||
assert.Equal(t, "42", pair.Tail(result))
|
||||
}
|
||||
|
||||
// Test that StateIO implements Applicative interface
|
||||
func TestApplicative(t *testing.T) {
|
||||
applicative := Applicative[MonadTestState, int, string]()
|
||||
|
||||
fab := Of[MonadTestState](func(x int) string { return fmt.Sprintf("%d", x) })
|
||||
fa := Of[MonadTestState](42)
|
||||
result := applicative.Ap(fa)(fab)
|
||||
|
||||
initial := MonadTestState{Value: 0}
|
||||
output := result(initial)()
|
||||
|
||||
assert.Equal(t, "42", pair.Tail(output))
|
||||
}
|
||||
|
||||
// Test that StateIO implements Monad interface
|
||||
func TestMonad(t *testing.T) {
|
||||
monad := Monad[MonadTestState, int, string]()
|
||||
|
||||
computation := monad.Of(42)
|
||||
chained := monad.Chain(func(x int) StateIO[MonadTestState, string] {
|
||||
return Of[MonadTestState](fmt.Sprintf("%d", x))
|
||||
})(computation)
|
||||
|
||||
initial := MonadTestState{Value: 0}
|
||||
result := chained(initial)()
|
||||
|
||||
assert.Equal(t, "42", pair.Tail(result))
|
||||
}
|
||||
|
||||
// Test Eq functionality
|
||||
func TestEq(t *testing.T) {
|
||||
initial := MonadTestState{Value: 0}
|
||||
|
||||
comp1 := Of[MonadTestState](42)
|
||||
comp2 := Of[MonadTestState](42)
|
||||
comp3 := Of[MonadTestState](43)
|
||||
|
||||
// Create equality predicate for IO[Pair[MonadTestState, int]]
|
||||
eqIO := EQ.FromEquals(func(l, r IO[Pair[MonadTestState, int]]) bool {
|
||||
lResult := l()
|
||||
rResult := r()
|
||||
return pair.Tail(lResult) == pair.Tail(rResult) &&
|
||||
pair.Head(lResult).Value == pair.Head(rResult).Value
|
||||
})
|
||||
|
||||
eq := Eq(eqIO)(initial)
|
||||
|
||||
assert.True(t, eq.Equals(comp1, comp2))
|
||||
assert.False(t, eq.Equals(comp1, comp3))
|
||||
}
|
||||
|
||||
// Test FromStrictEquals
|
||||
func TestFromStrictEquals(t *testing.T) {
|
||||
initial := MonadTestState{Value: 0}
|
||||
|
||||
comp1 := Of[MonadTestState](42)
|
||||
comp2 := Of[MonadTestState](42)
|
||||
comp3 := Of[MonadTestState](43)
|
||||
|
||||
eq := FromStrictEquals[MonadTestState, int]()(initial)
|
||||
|
||||
assert.True(t, eq.Equals(comp1, comp2))
|
||||
assert.False(t, eq.Equals(comp1, comp3))
|
||||
}
|
||||
82
v2/stateio/resource.go
Normal file
82
v2/stateio/resource.go
Normal file
@@ -0,0 +1,82 @@
|
||||
// Copyright (c) 2023 - 2025 IBM Corp.
|
||||
// All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package stateio
|
||||
|
||||
import (
|
||||
"github.com/IBM/fp-go/v2/io"
|
||||
"github.com/IBM/fp-go/v2/pair"
|
||||
)
|
||||
|
||||
// uncurryState transforms a curried function into an uncurried function that operates on pairs.
|
||||
// This is an internal helper function used by WithResource to adapt StateIO computations
|
||||
// to work with the IO resource management functions.
|
||||
//
|
||||
// It converts: func(A) io.Kleisli[S, B] -> io.Kleisli[Pair[S, A], B]
|
||||
func uncurryState[S, A, B any](f func(A) io.Kleisli[S, B]) io.Kleisli[Pair[S, A], B] {
|
||||
return func(r Pair[S, A]) IO[B] {
|
||||
return f(pair.Tail(r))(pair.Head(r))
|
||||
}
|
||||
}
|
||||
|
||||
// WithResource provides safe resource management for StateIO computations.
|
||||
// It ensures that resources are properly acquired and released, even if errors occur.
|
||||
//
|
||||
// The function takes:
|
||||
// - onCreate: A StateIO computation that creates/acquires the resource
|
||||
// - onRelease: A Kleisli arrow that releases the resource (receives the resource, returns any value)
|
||||
//
|
||||
// It returns a Kleisli arrow that takes a resource-using computation and ensures proper cleanup.
|
||||
//
|
||||
// The pattern follows the bracket pattern (acquire-use-release):
|
||||
// 1. Acquire the resource using onCreate
|
||||
// 2. Use the resource with the provided computation
|
||||
// 3. Release the resource using onRelease (guaranteed to run)
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// // Create a file resource
|
||||
// openFile := func(s AppState) IO[Pair[AppState, *os.File]] {
|
||||
// return io.Of(pair.MakePair(s, file))
|
||||
// }
|
||||
//
|
||||
// // Release the file resource
|
||||
// closeFile := func(f *os.File) StateIO[AppState, error] {
|
||||
// return FromIO[AppState](io.Of(f.Close()))
|
||||
// }
|
||||
//
|
||||
// // Use the resource safely
|
||||
// withFile := WithResource[string, AppState, *os.File, error](
|
||||
// openFile,
|
||||
// closeFile,
|
||||
// )
|
||||
//
|
||||
// // Apply to a computation that uses the file
|
||||
// result := withFile(func(f *os.File) StateIO[AppState, string] {
|
||||
// // Use file f here
|
||||
// return Of[AppState]("data")
|
||||
// })
|
||||
func WithResource[A, S, RES, ANY any](
|
||||
onCreate StateIO[S, RES],
|
||||
onRelease Kleisli[S, RES, ANY],
|
||||
) Kleisli[S, Kleisli[S, RES, A], A] {
|
||||
release := uncurryState(onRelease)
|
||||
return func(f Kleisli[S, RES, A]) StateIO[S, A] {
|
||||
use := uncurryState(f)
|
||||
return func(s S) IO[Pair[S, A]] {
|
||||
return io.WithResource[Pair[S, RES], Pair[S, A]](onCreate(s), release)(use)
|
||||
}
|
||||
}
|
||||
}
|
||||
125
v2/stateio/resource_test.go
Normal file
125
v2/stateio/resource_test.go
Normal file
@@ -0,0 +1,125 @@
|
||||
// Copyright (c) 2024 - 2025 IBM Corp.
|
||||
// All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package stateio
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/IBM/fp-go/v2/pair"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type ResourceState struct {
|
||||
ResourceCreated bool
|
||||
Value int
|
||||
}
|
||||
|
||||
func TestWithResource(t *testing.T) {
|
||||
initial := ResourceState{ResourceCreated: false, Value: 0}
|
||||
|
||||
// Create resource
|
||||
onCreate := func(s ResourceState) IO[Pair[ResourceState, string]] {
|
||||
return func() Pair[ResourceState, string] {
|
||||
newState := ResourceState{ResourceCreated: true, Value: s.Value}
|
||||
return pair.MakePair(newState, "resource-handle")
|
||||
}
|
||||
}
|
||||
|
||||
// Release resource (cleanup function)
|
||||
onRelease := func(res string) StateIO[ResourceState, int] {
|
||||
return func(s ResourceState) IO[Pair[ResourceState, int]] {
|
||||
return func() Pair[ResourceState, int] {
|
||||
// Release doesn't modify state in this test, just returns
|
||||
return pair.MakePair(s, 0)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Use resource
|
||||
useResource := func(res string) StateIO[ResourceState, int] {
|
||||
return func(s ResourceState) IO[Pair[ResourceState, int]] {
|
||||
return func() Pair[ResourceState, int] {
|
||||
// Verify we received the resource handle
|
||||
assert.Equal(t, "resource-handle", res)
|
||||
newState := ResourceState{ResourceCreated: s.ResourceCreated, Value: 42}
|
||||
return pair.MakePair(newState, 42)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create the resource management computation
|
||||
withRes := WithResource[int](onCreate, onRelease)
|
||||
computation := withRes(useResource)
|
||||
|
||||
result := computation(initial)()
|
||||
|
||||
// Verify the resource was created and used
|
||||
assert.Equal(t, 42, pair.Tail(result))
|
||||
finalState := pair.Head(result)
|
||||
assert.True(t, finalState.ResourceCreated)
|
||||
assert.Equal(t, 42, finalState.Value)
|
||||
}
|
||||
|
||||
func TestWithResourceChained(t *testing.T) {
|
||||
initial := ResourceState{ResourceCreated: false, Value: 0}
|
||||
|
||||
// Create resource
|
||||
onCreate := func(s ResourceState) IO[Pair[ResourceState, int]] {
|
||||
return func() Pair[ResourceState, int] {
|
||||
newState := ResourceState{ResourceCreated: true, Value: s.Value}
|
||||
return pair.MakePair(newState, 100)
|
||||
}
|
||||
}
|
||||
|
||||
// Release resource
|
||||
onRelease := func(res int) StateIO[ResourceState, int] {
|
||||
return func(s ResourceState) IO[Pair[ResourceState, int]] {
|
||||
return func() Pair[ResourceState, int] {
|
||||
return pair.MakePair(s, 0)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Use resource with chaining
|
||||
useResource := func(res int) StateIO[ResourceState, int] {
|
||||
return MonadChain(
|
||||
Of[ResourceState](res),
|
||||
func(r int) StateIO[ResourceState, int] {
|
||||
return func(s ResourceState) IO[Pair[ResourceState, int]] {
|
||||
return func() Pair[ResourceState, int] {
|
||||
newState := ResourceState{
|
||||
ResourceCreated: s.ResourceCreated,
|
||||
Value: r * 2,
|
||||
}
|
||||
return pair.MakePair(newState, r*2)
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// Create the resource management computation
|
||||
withRes := WithResource[int](onCreate, onRelease)
|
||||
computation := withRes(useResource)
|
||||
|
||||
result := computation(initial)()
|
||||
|
||||
// Verify the resource was created and used
|
||||
assert.Equal(t, 200, pair.Tail(result))
|
||||
finalState := pair.Head(result)
|
||||
assert.True(t, finalState.ResourceCreated)
|
||||
assert.Equal(t, 200, finalState.Value)
|
||||
}
|
||||
164
v2/stateio/state.go
Normal file
164
v2/stateio/state.go
Normal file
@@ -0,0 +1,164 @@
|
||||
// Copyright (c) 2024 - 2025 IBM Corp.
|
||||
// All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package stateio
|
||||
|
||||
import (
|
||||
"github.com/IBM/fp-go/v2/function"
|
||||
"github.com/IBM/fp-go/v2/internal/statet"
|
||||
"github.com/IBM/fp-go/v2/io"
|
||||
)
|
||||
|
||||
// Of creates a StateIO that wraps a pure value.
|
||||
// The value is wrapped and the state is passed through unchanged.
|
||||
//
|
||||
// This is the Pointed/Applicative 'of' operation that lifts a pure value
|
||||
// into the StateIO context.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// result := Of[AppState](42)
|
||||
// // Returns a computation containing 42 that passes state through unchanged
|
||||
func Of[S, A any](a A) StateIO[S, A] {
|
||||
return statet.Of[StateIO[S, A]](io.Of[Pair[S, A]], a)
|
||||
}
|
||||
|
||||
// MonadMap transforms the value of a StateIO using the provided function.
|
||||
// The state is threaded through the computation unchanged.
|
||||
// This is the functor map operation.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// result := MonadMap(
|
||||
// Of[AppState](21),
|
||||
// func(x int) int { return x * 2 },
|
||||
// ) // Result contains 42
|
||||
func MonadMap[S, A, B any](fa StateIO[S, A], f func(A) B) StateIO[S, B] {
|
||||
return statet.MonadMap[StateIO[S, A], StateIO[S, B]](
|
||||
io.MonadMap[Pair[S, A], Pair[S, B]],
|
||||
fa,
|
||||
f,
|
||||
)
|
||||
}
|
||||
|
||||
// Map is the curried version of [MonadMap].
|
||||
// Returns a function that transforms a StateIO.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// double := Map[AppState](func(x int) int { return x * 2 })
|
||||
// result := function.Pipe1(Of[AppState](21), double)
|
||||
func Map[S, A, B any](f func(A) B) Operator[S, A, B] {
|
||||
return statet.Map[StateIO[S, A], StateIO[S, B]](
|
||||
io.Map[Pair[S, A], Pair[S, B]],
|
||||
f,
|
||||
)
|
||||
}
|
||||
|
||||
// MonadChain sequences two computations, passing the result of the first to a function
|
||||
// that produces the second computation. This is the monadic bind operation.
|
||||
// The state is threaded through both computations sequentially.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// result := MonadChain(
|
||||
// Of[AppState](5),
|
||||
// func(x int) StateIO[AppState, string] {
|
||||
// return Of[AppState](fmt.Sprintf("value: %d", x))
|
||||
// },
|
||||
// )
|
||||
func MonadChain[S, A, B any](fa StateIO[S, A], f Kleisli[S, A, B]) StateIO[S, B] {
|
||||
return statet.MonadChain(
|
||||
io.MonadChain[Pair[S, A], Pair[S, B]],
|
||||
fa,
|
||||
f,
|
||||
)
|
||||
}
|
||||
|
||||
// Chain is the curried version of [MonadChain].
|
||||
// Returns a function that sequences computations.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// stringify := Chain(func(x int) StateIO[AppState, string] {
|
||||
// return Of[AppState](fmt.Sprintf("%d", x))
|
||||
// })
|
||||
// result := function.Pipe1(Of[AppState](42), stringify)
|
||||
func Chain[S, A, B any](f Kleisli[S, A, B]) Operator[S, A, B] {
|
||||
return statet.Chain[StateIO[S, A]](
|
||||
io.Chain[Pair[S, A], Pair[S, B]],
|
||||
f,
|
||||
)
|
||||
}
|
||||
|
||||
// MonadAp applies a function wrapped in a StateIO to a value wrapped in a StateIO.
|
||||
// The state is threaded through both computations sequentially.
|
||||
// This is the applicative apply operation.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// fab := Of[AppState](func(x int) int { return x * 2 })
|
||||
// fa := Of[AppState](21)
|
||||
// result := MonadAp(fab, fa) // Result contains 42
|
||||
func MonadAp[B, S, A any](fab StateIO[S, func(A) B], fa StateIO[S, A]) StateIO[S, B] {
|
||||
return statet.MonadAp[StateIO[S, A], StateIO[S, B]](
|
||||
io.MonadMap[Pair[S, A], Pair[S, B]],
|
||||
io.MonadChain[Pair[S, func(A) B], Pair[S, B]],
|
||||
fab,
|
||||
fa,
|
||||
)
|
||||
}
|
||||
|
||||
// Ap is the curried version of [MonadAp].
|
||||
// Returns a function that applies a wrapped function to the given wrapped value.
|
||||
func Ap[B, S, A any](fa StateIO[S, A]) Operator[S, func(A) B, B] {
|
||||
return statet.Ap[StateIO[S, A], StateIO[S, B], StateIO[S, func(A) B]](
|
||||
io.Map[Pair[S, A], Pair[S, B]],
|
||||
io.Chain[Pair[S, func(A) B], Pair[S, B]],
|
||||
fa,
|
||||
)
|
||||
}
|
||||
|
||||
// FromIO lifts an IO computation into StateIO.
|
||||
// The IO computation is executed and its result is wrapped in StateIO.
|
||||
// The state is passed through unchanged.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// ioAction := io.Of(42)
|
||||
// stateIOAction := FromIO[AppState](ioAction)
|
||||
func FromIO[S, A any](fa IO[A]) StateIO[S, A] {
|
||||
return statet.FromF[StateIO[S, A]](
|
||||
io.MonadMap[A],
|
||||
fa,
|
||||
)
|
||||
}
|
||||
|
||||
// Combinators
|
||||
|
||||
// FromIOK lifts an IO-returning function into a Kleisli arrow for StateIO.
|
||||
// This is useful for composing functions that return IO actions with StateIO computations.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// readFile := func(path string) IO[string] { ... }
|
||||
// kleisli := FromIOK[AppState](readFile)
|
||||
// // kleisli can now be used with Chain
|
||||
func FromIOK[S, A, B any](f func(A) IO[B]) Kleisli[S, A, B] {
|
||||
return function.Flow2(
|
||||
f,
|
||||
FromIO[S, B],
|
||||
)
|
||||
}
|
||||
282
v2/stateio/state_test.go
Normal file
282
v2/stateio/state_test.go
Normal file
@@ -0,0 +1,282 @@
|
||||
// Copyright (c) 2024 - 2025 IBM Corp.
|
||||
// All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package stateio
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
F "github.com/IBM/fp-go/v2/function"
|
||||
N "github.com/IBM/fp-go/v2/number"
|
||||
"github.com/IBM/fp-go/v2/pair"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type TestState struct {
|
||||
Counter int
|
||||
Message string
|
||||
}
|
||||
|
||||
func TestOf(t *testing.T) {
|
||||
initial := TestState{Counter: 0, Message: ""}
|
||||
computation := Of[TestState](42)
|
||||
|
||||
result := computation(initial)()
|
||||
|
||||
assert.Equal(t, 42, pair.Tail(result))
|
||||
assert.Equal(t, initial, pair.Head(result))
|
||||
}
|
||||
|
||||
func TestMonadMap(t *testing.T) {
|
||||
initial := TestState{Counter: 5, Message: "test"}
|
||||
computation := Of[TestState](10)
|
||||
|
||||
mapped := MonadMap(computation, N.Mul(2))
|
||||
result := mapped(initial)()
|
||||
|
||||
assert.Equal(t, 20, pair.Tail(result))
|
||||
assert.Equal(t, initial, pair.Head(result))
|
||||
}
|
||||
|
||||
func TestMap(t *testing.T) {
|
||||
initial := TestState{Counter: 0, Message: ""}
|
||||
|
||||
double := Map[TestState](N.Mul(2))
|
||||
computation := F.Pipe1(
|
||||
Of[TestState](21),
|
||||
double,
|
||||
)
|
||||
|
||||
result := computation(initial)()
|
||||
|
||||
assert.Equal(t, 42, pair.Tail(result))
|
||||
assert.Equal(t, initial, pair.Head(result))
|
||||
}
|
||||
|
||||
func TestMonadChain(t *testing.T) {
|
||||
initial := TestState{Counter: 0, Message: ""}
|
||||
|
||||
computation := Of[TestState](5)
|
||||
chained := MonadChain(computation, func(x int) StateIO[TestState, string] {
|
||||
return func(s TestState) IO[Pair[TestState, string]] {
|
||||
return func() Pair[TestState, string] {
|
||||
newState := TestState{Counter: s.Counter + x, Message: fmt.Sprintf("value: %d", x)}
|
||||
return pair.MakePair(newState, fmt.Sprintf("result: %d", x*2))
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
result := chained(initial)()
|
||||
|
||||
assert.Equal(t, "result: 10", pair.Tail(result))
|
||||
assert.Equal(t, 5, pair.Head(result).Counter)
|
||||
assert.Equal(t, "value: 5", pair.Head(result).Message)
|
||||
}
|
||||
|
||||
func TestChain(t *testing.T) {
|
||||
initial := TestState{Counter: 0, Message: ""}
|
||||
|
||||
addToCounter := func(x int) StateIO[TestState, int] {
|
||||
return func(s TestState) IO[Pair[TestState, int]] {
|
||||
return func() Pair[TestState, int] {
|
||||
newState := TestState{Counter: s.Counter + x, Message: s.Message}
|
||||
return pair.MakePair(newState, s.Counter+x)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
computation := F.Pipe1(
|
||||
Of[TestState](5),
|
||||
Chain(addToCounter),
|
||||
)
|
||||
|
||||
result := computation(initial)()
|
||||
|
||||
assert.Equal(t, 5, pair.Tail(result))
|
||||
assert.Equal(t, 5, pair.Head(result).Counter)
|
||||
}
|
||||
|
||||
func TestMonadAp(t *testing.T) {
|
||||
initial := TestState{Counter: 0, Message: ""}
|
||||
|
||||
fab := Of[TestState](N.Mul(3))
|
||||
fa := Of[TestState](7)
|
||||
|
||||
result := MonadAp(fab, fa)(initial)()
|
||||
|
||||
assert.Equal(t, 21, pair.Tail(result))
|
||||
assert.Equal(t, initial, pair.Head(result))
|
||||
}
|
||||
|
||||
func TestAp(t *testing.T) {
|
||||
initial := TestState{Counter: 0, Message: ""}
|
||||
|
||||
computation := F.Pipe1(
|
||||
Of[TestState](N.Mul(4)),
|
||||
Ap[int](Of[TestState](10)),
|
||||
)
|
||||
|
||||
result := computation(initial)()
|
||||
|
||||
assert.Equal(t, 40, pair.Tail(result))
|
||||
assert.Equal(t, initial, pair.Head(result))
|
||||
}
|
||||
|
||||
func TestFromIO(t *testing.T) {
|
||||
initial := TestState{Counter: 0, Message: ""}
|
||||
|
||||
ioComputation := func() int { return 42 }
|
||||
computation := FromIO[TestState](ioComputation)
|
||||
|
||||
result := computation(initial)()
|
||||
|
||||
assert.Equal(t, 42, pair.Tail(result))
|
||||
assert.Equal(t, initial, pair.Head(result))
|
||||
}
|
||||
|
||||
func TestFromIOK(t *testing.T) {
|
||||
initial := TestState{Counter: 0, Message: ""}
|
||||
|
||||
ioFunc := func(x int) IO[string] {
|
||||
return func() string {
|
||||
return fmt.Sprintf("value: %d", x*2)
|
||||
}
|
||||
}
|
||||
|
||||
kleisli := FromIOK[TestState](ioFunc)
|
||||
computation := kleisli(21)
|
||||
|
||||
result := computation(initial)()
|
||||
|
||||
assert.Equal(t, "value: 42", pair.Tail(result))
|
||||
assert.Equal(t, initial, pair.Head(result))
|
||||
}
|
||||
|
||||
func TestChainedOperations(t *testing.T) {
|
||||
initial := TestState{Counter: 0, Message: ""}
|
||||
|
||||
incrementCounter := func(x int) StateIO[TestState, int] {
|
||||
return func(s TestState) IO[Pair[TestState, int]] {
|
||||
return func() Pair[TestState, int] {
|
||||
newState := TestState{Counter: s.Counter + x, Message: s.Message}
|
||||
return pair.MakePair(newState, newState.Counter)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
setMessage := func(count int) StateIO[TestState, string] {
|
||||
return func(s TestState) IO[Pair[TestState, string]] {
|
||||
return func() Pair[TestState, string] {
|
||||
msg := fmt.Sprintf("Count is %d", count)
|
||||
newState := TestState{Counter: s.Counter, Message: msg}
|
||||
return pair.MakePair(newState, msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
computation := F.Pipe2(
|
||||
Of[TestState](5),
|
||||
Chain(incrementCounter),
|
||||
Chain(setMessage),
|
||||
)
|
||||
|
||||
result := computation(initial)()
|
||||
|
||||
assert.Equal(t, "Count is 5", pair.Tail(result))
|
||||
assert.Equal(t, 5, pair.Head(result).Counter)
|
||||
assert.Equal(t, "Count is 5", pair.Head(result).Message)
|
||||
}
|
||||
|
||||
func TestStatefulComputation(t *testing.T) {
|
||||
initial := TestState{Counter: 10, Message: "start"}
|
||||
|
||||
// A computation that reads and modifies state
|
||||
computation := func(s TestState) IO[Pair[TestState, int]] {
|
||||
return func() Pair[TestState, int] {
|
||||
newState := TestState{
|
||||
Counter: s.Counter * 2,
|
||||
Message: fmt.Sprintf("%s -> doubled", s.Message),
|
||||
}
|
||||
return pair.MakePair(newState, newState.Counter)
|
||||
}
|
||||
}
|
||||
|
||||
result := computation(initial)()
|
||||
|
||||
assert.Equal(t, 20, pair.Tail(result))
|
||||
assert.Equal(t, 20, pair.Head(result).Counter)
|
||||
assert.Equal(t, "start -> doubled", pair.Head(result).Message)
|
||||
}
|
||||
|
||||
func TestMapPreservesState(t *testing.T) {
|
||||
initial := TestState{Counter: 42, Message: "important"}
|
||||
|
||||
computation := F.Pipe2(
|
||||
Of[TestState](10),
|
||||
Map[TestState](N.Mul(2)),
|
||||
Map[TestState](N.Add(5)),
|
||||
)
|
||||
|
||||
result := computation(initial)()
|
||||
|
||||
// Value should be transformed: 10 * 2 + 5 = 25
|
||||
assert.Equal(t, 25, pair.Tail(result))
|
||||
// State should be unchanged
|
||||
assert.Equal(t, initial, pair.Head(result))
|
||||
}
|
||||
|
||||
func TestChainModifiesState(t *testing.T) {
|
||||
initial := TestState{Counter: 0, Message: ""}
|
||||
|
||||
addOne := func(x int) StateIO[TestState, int] {
|
||||
return func(s TestState) IO[Pair[TestState, int]] {
|
||||
return func() Pair[TestState, int] {
|
||||
newState := TestState{Counter: s.Counter + 1, Message: s.Message}
|
||||
return pair.MakePair(newState, x+1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
computation := F.Pipe2(
|
||||
Of[TestState](0),
|
||||
Chain(addOne),
|
||||
Chain(addOne),
|
||||
)
|
||||
|
||||
result := computation(initial)()
|
||||
|
||||
assert.Equal(t, 2, pair.Tail(result))
|
||||
assert.Equal(t, 2, pair.Head(result).Counter)
|
||||
}
|
||||
|
||||
func TestApplicativeComposition(t *testing.T) {
|
||||
initial := TestState{Counter: 0, Message: ""}
|
||||
|
||||
add := func(x int) func(int) int {
|
||||
return func(y int) int {
|
||||
return x + y
|
||||
}
|
||||
}
|
||||
|
||||
computation := F.Pipe1(
|
||||
Of[TestState](add(10)),
|
||||
Ap[int](Of[TestState](32)),
|
||||
)
|
||||
|
||||
result := computation(initial)()
|
||||
|
||||
assert.Equal(t, 42, pair.Tail(result))
|
||||
}
|
||||
91
v2/stateio/type.go
Normal file
91
v2/stateio/type.go
Normal file
@@ -0,0 +1,91 @@
|
||||
// Copyright (c) 2024 - 2025 IBM Corp.
|
||||
// All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package stateio
|
||||
|
||||
import (
|
||||
"github.com/IBM/fp-go/v2/either"
|
||||
"github.com/IBM/fp-go/v2/endomorphism"
|
||||
"github.com/IBM/fp-go/v2/io"
|
||||
"github.com/IBM/fp-go/v2/optics/iso/lens"
|
||||
"github.com/IBM/fp-go/v2/pair"
|
||||
"github.com/IBM/fp-go/v2/predicate"
|
||||
"github.com/IBM/fp-go/v2/reader"
|
||||
"github.com/IBM/fp-go/v2/state"
|
||||
)
|
||||
|
||||
type (
|
||||
// Endomorphism represents a function from A to A.
|
||||
Endomorphism[A any] = endomorphism.Endomorphism[A]
|
||||
|
||||
// Lens is an optic that focuses on a field of type A within a structure of type S.
|
||||
Lens[S, A any] = lens.Lens[S, A]
|
||||
|
||||
// State represents a stateful computation that takes an initial state S and returns
|
||||
// a pair of the new state S and a value A.
|
||||
State[S, A any] = state.State[S, A]
|
||||
|
||||
// Pair represents a tuple of two values.
|
||||
Pair[L, R any] = pair.Pair[L, R]
|
||||
|
||||
// Reader represents a computation that depends on an environment/context of type R
|
||||
// and produces a value of type A.
|
||||
Reader[R, A any] = reader.Reader[R, A]
|
||||
|
||||
// Either represents a value that can be either a Left (error) or Right (success).
|
||||
Either[E, A any] = either.Either[E, A]
|
||||
|
||||
// IO represents a computation that performs side effects and produces a value of type A.
|
||||
IO[A any] = io.IO[A]
|
||||
|
||||
// StateIO represents a stateful computation that performs side effects.
|
||||
// It combines the State monad with the IO monad, allowing computations that:
|
||||
// - Manage state of type S
|
||||
// - Perform side effects (IO)
|
||||
// - Produce a value of type A
|
||||
//
|
||||
// The computation takes an initial state S and returns an IO action that produces
|
||||
// a Pair containing the new state S and the result value A.
|
||||
//
|
||||
// Type definition: StateIO[S, A] = Reader[S, IO[Pair[S, A]]]
|
||||
//
|
||||
// This is useful for:
|
||||
// - Stateful computations with side effects
|
||||
// - Managing application state while performing IO operations
|
||||
// - Composing operations that need both state management and IO
|
||||
StateIO[S, A any] = Reader[S, IO[Pair[S, A]]]
|
||||
|
||||
// Kleisli represents a Kleisli arrow for StateIO.
|
||||
// It's a function from A to StateIO[S, B], enabling composition of
|
||||
// stateful, effectful computations.
|
||||
//
|
||||
// Kleisli arrows are used for:
|
||||
// - Chaining dependent computations
|
||||
// - Building pipelines of stateful operations
|
||||
// - Monadic composition with Chain/Bind operations
|
||||
Kleisli[S, A, B any] = Reader[A, StateIO[S, B]]
|
||||
|
||||
// Operator represents a transformation from one StateIO to another.
|
||||
// It's a function that takes StateIO[S, A] and returns StateIO[S, B].
|
||||
//
|
||||
// Operators are used for:
|
||||
// - Transforming computations (Map, Chain, etc.)
|
||||
// - Building reusable computation transformers
|
||||
// - Composing higher-order operations
|
||||
Operator[S, A, B any] = Reader[StateIO[S, A], StateIO[S, B]]
|
||||
|
||||
// Predicate represents a function that tests a value of type A.
|
||||
Predicate[A any] = predicate.Predicate[A]
|
||||
)
|
||||
Reference in New Issue
Block a user