Safe representation of restricted values

Table of contents

  1. A new type for probabilities
    1. Creating a Probability
    2. Testing the implementation
  2. Interacting with probabilities
    1. Non-mutable references
    2. Take the inner value
    3. Trait implementations
    4. Chaining operations
  3. Is it worth it?

In my previous article, we explored how to build a complete Lambda function step by step, collaborating with the compiler along the way. We ended up with a function that could handle all edge cases, often by checking the values of Options and Results, with one notable exception.

When we checked whether a probability was between 0.0 and 1.0, we used an if expression instead.

    // Return an error if the probability is not between 0 and 1
    if !(0.0..=1.0).contains(&probability) {
        return Ok(user_error("'probability' must be between 0 and 1"));
    }

This is perfectly fine, but we could abstract those checks away by creating a specialized Probability struct that can only hold values between 0.0 and 1.0. This will add a fair amount of complexity to our codebase, but it introduces an extra layer of safety: if you’re manipulating a Probability, you know it contains a valid probability value.

A new type for probabilities

Let’s define that new type.

    #[derive(Clone, Debug, PartialEq, PartialOrd)]
    struct Probability(f64);

We only need to hold a single f64, so we can define it as a tuple struct with a single value. Under the hood, Rust will store it as an f64 with no overhead.
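We can check that claim with a quick assertion (a minimal sketch):

    // The newtype wrapper adds no memory overhead compared to a plain f64
    assert_eq!(
        std::mem::size_of::<Probability>(),
        std::mem::size_of::<f64>()
    );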

Creating a Probability

The inner value is private, so we cannot access it directly outside of the module where we define it. That means that users cannot accidentally pass invalid probability values.

Instead, we need to add a method that creates a new Probability from an f64. As a rule of thumb, if there is a standard trait that matches what we need to do, we should use it. In this case, we can use the TryFrom trait:

    pub trait TryFrom<T> {
        type Error;
        fn try_from(value: T) -> Result<Self, Self::Error>;
    }

Compared to the From trait, TryFrom returns a Result, so the conversion is allowed to fail. Since we might receive f64 values outside the valid range, creating the Probability can fail, and we can use that Result to return an error instead.

    impl TryFrom<f64> for Probability {
        type Error = ProbabilityError;

        fn try_from(value: f64) -> Result<Self, Self::Error> {
            // Check if the value is between 0.0 and 1.0
            if (0.0..=1.0).contains(&value) {
                Ok(Probability(value))
            // Return an error otherwise
            } else {
                Err(ProbabilityError::OutOfBounds)
            }
        }
    }
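The listing above relies on a ProbabilityError type that isn’t shown here. A minimal sketch could be a single-variant enum, deriving Debug so the unwrap calls in the tests below compile; the real definition may well do more (for example, implementing Display and Error):

    // Minimal sketch of the error type used by try_from
    #[derive(Debug, Clone, PartialEq)]
    pub enum ProbabilityError {
        /// The value is outside the valid range [0.0, 1.0]
        OutOfBounds,
    }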

Testing the implementation

Since this is the only way to create a Probability, we can be pretty sure that the value is valid – as long as the try_from function is implemented correctly. Let’s add a few simple tests to make sure it works.

I’m purposefully omitting some tests here, but you might want to add checks for edge cases, like 0.0 and 1.0.

    /// Test if the value is within 0.0 and 1.0
    #[test]
    fn test_in_bound() {
        let p = Probability::try_from(0.5).unwrap();
        assert_eq!(p, Probability(0.5));
    }

    /// Test if the value is under 0.0
    #[test]
    fn test_under_bounds() {
        match Probability::try_from(-0.1).unwrap_err() {
            ProbabilityError::OutOfBounds => (),
            _ => assert!(false),
        }
    }

    /// Test if the value is over 1.0
    #[test]
    fn test_over_bounds() {
        match Probability::try_from(1.1).unwrap_err() {
            ProbabilityError::OutOfBounds => (),
            _ => assert!(false),
        }
    }
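If you did want the edge-case checks mentioned above, they could look something like this (a sketch, which also covers NaN, rejected by the range check as well):

    /// The bounds themselves are valid probabilities
    #[test]
    fn test_bounds_are_valid() {
        assert_eq!(Probability::try_from(0.0).unwrap(), Probability(0.0));
        assert_eq!(Probability::try_from(1.0).unwrap(), Probability(1.0));
    }

    /// NaN is not contained in 0.0..=1.0, so it is rejected too
    #[test]
    fn test_nan_is_rejected() {
        assert!(Probability::try_from(f64::NAN).is_err());
    }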

I’m a big fan of randomized testing. While you lose a bit of repeatability, you gain the ability to test a wide variety of cases and make sure your implementation doesn’t break. With the rand crate, we can generate numbers within a given range, which makes it quite easy to generate valid values or values under 0.0.

It’s a bit trickier for values over 1.0, because we cannot define a range that excludes its lower bound. I’ll work around that by setting the minimum slightly above 1.0, but that means I will never test cases between 1.0 and that minimum.

    // Needs the rand crate (0.8 API) in scope for thread_rng and gen_range
    use rand::{thread_rng, Rng};

    /// Test if the value is within 0.0 and 1.0
    #[test]
    fn test_randomized_in_bound() {
        let mut rng = thread_rng();

        for _ in 0..2048 {
            let inner: f64 = rng.gen_range(0.0..=1.0);
            let p = Probability::try_from(inner).unwrap();
            assert_eq!(p, Probability(inner));
        }
    }

    /// Test if the value is under 0.0
    #[test]
    fn test_randomized_under_bounds() {
        let mut rng = thread_rng();

        for _ in 0..2048 {
            let inner: f64 = rng.gen_range(f64::MIN..0.0);
            match Probability::try_from(inner).unwrap_err() {
                ProbabilityError::OutOfBounds => (),
                _ => assert!(false),
            }
        }
    }

    /// Test if the value is over 1.0
    #[test]
    fn test_randomized_over_bounds() {
        let mut rng = thread_rng();

        for _ in 0..2048 {
            let inner: f64 = rng.gen_range(1.1..f64::MAX);
            match Probability::try_from(inner).unwrap_err() {
                ProbabilityError::OutOfBounds => (),
                _ => assert!(false),
            }
        }
    }

That’s a lot of tests compared to what we had before! But all these tests are quite useful: they give us higher confidence that our implementation is correct. Remember: this is the critical part where we validate the input!

Interacting with probabilities

At the moment, we’ve defined a way to create a Probability that contains a valid value, but we don’t have a way to interact with it. Since the inner f64 is private, we can’t just refer to it directly. However, there’s a catch here: we should be able to read the value, but never modify it – at least not in safe Rust.

There are three things we can do about this:

  1. Allow a non-mutable reference to the inner value.
  2. Allow taking ownership of the inner value.
  3. Implement the same traits as f64.

This last option is a bit tricky because f64 implements many traits. Furthermore, because we have a limited range of valid values, we’d need to always return a Result: while 0.6 is a valid probability, 0.6 + 0.6 is not, so we cannot just add two probabilities without checking the result.

Non-mutable references

Let’s go back to the rule of thumb I mentioned before: if there is a standard trait that matches what we need to do, we should use it. In this case, multiple traits could be interesting, and they all have one nice property in common: each comes in an immutable and a mutable version. For example, we have the AsRef trait and its AsMut counterpart, or Borrow and BorrowMut.

Now comes the tricky question: how do we know which one to use? The documentation for the AsRef trait provides a detailed explanation of the topic:

AsRef has the same signature as Borrow, but Borrow is different in a few aspects:

  • Unlike AsRef, Borrow has a blanket impl for any T, and can be used to accept either a reference or a value.
  • Borrow also requires that Hash, Eq and Ord for a borrowed value are equivalent to those of the owned value. For this reason, if you want to borrow only a single field of a struct you can implement AsRef, but not Borrow.

In this case, a Probability will behave the same way for those three traits:

  1. It should be equal to another Probability if the inner values are equal.
  2. It should be less than another Probability if its inner value is less than the other’s.
  3. It should hash to the same value as the inner value.

However, there is a small detail that would cause an issue with Borrow: Borrow has a blanket impl for any T. If we have a probability p and we try to do &p == &0.1, we’ll get an error message about comparing Probability and f64. You can take a look at this Rust playground to see the error message.

Let’s implement AsRef instead:

    impl std::convert::AsRef<f64> for Probability {
        fn as_ref(&self) -> &f64 {
            &self.0
        }
    }
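With this in place, callers can read the inner value without being able to modify it. A quick usage sketch, with a small throwaway helper function to show the generic side:

    let p = Probability::try_from(0.25).unwrap();

    // as_ref hands out a shared reference to the inner f64
    assert_eq!(*p.as_ref(), 0.25);

    // It also lets Probability flow into generic code that accepts AsRef<f64>
    fn print_value(value: impl AsRef<f64>) {
        println!("{}", value.as_ref());
    }
    print_value(p);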

Take the inner value

There are situations where we no longer need the Probability and just want to take ownership of the inner value. We could use the Into trait to do exactly that.

The documentation advises against implementing Into directly and recommends implementing the From trait instead: the standard library provides a blanket implementation that derives Into from From. Formally speaking, From<T> for U implies Into<U> for T.

    impl From<Probability> for f64 {
        fn from(value: Probability) -> Self {
            value.0
        }
    }

This looks pretty similar to the AsRef implementation, except that we take ownership of the value.
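As a quick usage sketch, both spellings now work thanks to that blanket implementation:

    let p = Probability::try_from(0.75).unwrap();

    // Explicitly through From...
    let a = f64::from(p.clone());
    // ...or through the blanket Into implementation
    let b: f64 = p.into();

    assert_eq!(a, b);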

Trait implementations

Now, all we have left to do is implement a variety of traits so users can interact with Probability directly. At the moment, users can only create a Probability, and then either take the inner value or use the as_ref method. That means that adding two probabilities together looks like this:

    let p1 = Probability::try_from(0.3).unwrap();
    let p2 = Probability::try_from(0.4).unwrap();

    let p_sum = Probability::try_from(p1.as_ref() + p2.as_ref()).unwrap();

This is pretty lengthy, but we can’t do better with what we have at the moment. To make it easier, we should implement a bunch of traits that will allow us to perform these manipulations, like p1 + p2, directly.

For this, there are two categories of manipulations: those that may fail and those that won’t. For example, 0.6 + 0.7 should fail, as the sum exceeds our upper bound of 1.0. On the other hand, multiplying two numbers between 0.0 and 1.0 will always result in a value in that range.

If we look at the std::ops module, there are quite a few traits that could be interesting here. Depending on what you are building, you might decide only some of them are relevant. First, I’ll focus on maths operations: Add, Div, Mul, Rem, and Sub. Since Probability cannot contain negative values, we don’t need to implement Neg. We could also implement the assignment counterparts (AddAssign and friends), but those mutate in place and return nothing, so they have no way to surface a Result.

This is a significant number of traits to implement, but we don’t need to do all of that by hand. Instead, we can use a macro to generate them. In a nutshell, a macro is something that will generate Rust code based on inputs. Here’s one that will work with all the maths traits mentioned earlier:

    macro_rules! impl_op {
        // Pattern with a single identifier
        ($trait:ident) => {
            // Generating the trait implementation
            impl std::ops::$trait for Probability {
                type Output = Result<Self, ProbabilityError>;

                // Implement the trait's method
                paste::paste! {
                    fn [<$trait:snake>](self, rhs: Self) -> Self::Output {
                        Probability::try_from(self.0.[<$trait:snake>](rhs.0))
                    }
                }
            }
        };
    }

A few things are going on here, so let’s break that down.

The first is macro_rules!: a macro used to define other macros. With it, we define a declarative macro that takes a trait name as input and implements that trait for Probability.

The second is paste::paste!, a macro that concatenates and transforms identifiers. We use it to turn the trait name into its method identifier: Add needs an add() method, Div needs a div(), and so on. The :snake modifier converts the trait name into its snake case equivalent.

Finally, you might notice that the macro never uses the operator symbol itself and relies on the method name instead. If it did, we’d have to tell the macro which operator goes with which trait, whereas the method name can be derived from the trait name directly.
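To make this concrete, here is roughly what impl_op!(Add) expands to (a hand-written sketch, not the literal compiler expansion):

    impl std::ops::Add for Probability {
        type Output = Result<Self, ProbabilityError>;

        fn add(self, rhs: Self) -> Self::Output {
            // Delegate to f64's add, then re-validate the result
            Probability::try_from(self.0.add(rhs.0))
        }
    }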

We can then use the macro to cover four of our traits in just four lines:

    impl_op!(Add);
    impl_op!(Div);
    impl_op!(Rem);
    impl_op!(Sub);
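A quick sanity check of the generated implementations (a sketch, assuming the types defined above):

    let p1 = Probability::try_from(0.3).unwrap();
    let p2 = Probability::try_from(0.4).unwrap();

    // 0.3 + 0.4 stays within bounds, so we get Ok back
    assert!((p1.clone() + p2.clone()).is_ok());

    // Subtracting the larger probability from the smaller one goes below 0.0
    assert!((p1 - p2).is_err());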

We are only missing Mul now. We could use the same macro to implement it, but multiplying two probabilities will always result in values between 0.0 and 1.0. And here we have another choice to make: should we keep it consistent with the other operations, or should we make it return the value directly?

I’ve opted to return the Probability directly, as users would just unwrap it anyway and multiplication is a common operation on probabilities.

    impl std::ops::Mul for Probability {
        type Output = Probability;

        fn mul(self, rhs: Self) -> Self::Output {
            // We can create the probability without using try_from because the
            // product of values between [0, 1] is always between [0, 1].
            Probability(self.0 * rhs.0)
        }
    }

Here, I’m using Probability(self.0 * rhs.0) directly instead of unwrapping from try_from. This is because we can prove mathematically that the product of two probabilities is always between 0.0 and 1.0 – and therefore I’m fairly confident that the result will always be a valid probability.

If I were less confident, I could use try_from and .expect() to unwrap the result. This would have a small performance penalty, but it would catch any error in my logic.
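That more defensive variant could look like this (a sketch of the same method body, trading the mathematical argument for a runtime check):

    // Inside the same impl std::ops::Mul for Probability block
    fn mul(self, rhs: Self) -> Self::Output {
        // Re-validate the product and fail loudly if the reasoning above is wrong
        Probability::try_from(self.0 * rhs.0)
            .expect("the product of two probabilities should be between 0.0 and 1.0")
    }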

Chaining operations

Let’s say we want to add three probabilities together. With the implementation we have right now, if we do p1 + p2 + p3, we’ll get the following error:

error[E0369]: cannot add `Probability` to `Result<Probability, ProbabilityError>`
   --> probability/src/lib.rs:189:29
    |
189 |         assert_eq!((p1 + p2 + p3).unwrap(), expected);
    |                     ------- ^ -- Probability
    |                     |
    |                     Result<Probability, ProbabilityError>

That makes sense given what we have at the moment: when we add two probabilities together, we get a Result back, and we haven’t implemented adding a Probability to that Result. Since the chained operation always produces the same error type, we can implement the operation on the Result itself: perform the addition if it holds a value, and return the error otherwise.

Since this fits naturally into the macro we defined earlier, we’ll simply add a second implementation to it:

    macro_rules! impl_op {
        ($trait:ident) => {
            impl std::ops::$trait for Probability {
                type Output = Result<Self, ProbabilityError>;

                paste::paste! {
                    fn [<$trait:snake>](self, rhs: Self) -> Self::Output {
                        Probability::try_from(self.0.[<$trait:snake>](rhs.0))
                    }
                }
            }

            impl std::ops::$trait<Probability> for Result<Probability, ProbabilityError> {
                type Output = Result<Probability, ProbabilityError>;

                paste::paste! {
                    fn [<$trait:snake>](self, rhs: Probability) -> Self::Output {
                        match self {
                            // The Result contains a value - we can add it to the probability
                            Ok(value) => Probability::try_from(value.0.[<$trait:snake>](rhs.0)),
                            // The Result already contains an error
                            Err(error) => Err(error),
                        }
                    }
                }
            }
        };
    }
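With that second implementation in place, chained operations compile and short-circuit on the first error. A usage sketch:

    let p1 = Probability::try_from(0.25).unwrap();
    let p2 = Probability::try_from(0.25).unwrap();
    let p3 = Probability::try_from(0.25).unwrap();

    // p1 + p2 produces a Result, and the new impl lets us add p3 to that Result
    assert_eq!((p1 + p2 + p3).unwrap(), Probability::try_from(0.75).unwrap());

    // An intermediate result above 1.0 turns the whole chain into an error
    let q1 = Probability::try_from(0.5).unwrap();
    let q2 = Probability::try_from(0.75).unwrap();
    let q3 = Probability::try_from(0.25).unwrap();
    assert!((q1 + q2 + q3).is_err());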

Is it worth it?

Doing all of this was quite a bit of work, with many considerations along the way. Is it worth it?

It depends on what you’ll do with it. If this was just for a single program like in my previous article, it wouldn’t be worth it. If you’re only going to check something once, you don’t need to create a custom type and do all the extra work of implementing various traits to perform maths operations.

However, if you’re building a library that will be used by hundreds or thousands of developers, it’s probably a good thing. All that extra effort from one developer will pay off by ensuring others will have the tools to safely manipulate values between 0.0 and 1.0.

You can find examples of such wrapper structs in the std::num module of the standard library: the NonZero* types hold values that cannot be zero (which enables memory layout optimisations), and Wrapping<T> wraps numbers around on overflow.
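For example, a quick sketch of what those standard-library types give you:

    use std::num::{NonZeroU32, Wrapping};

    // The forbidden zero leaves a niche, so Option<NonZeroU32> is as small as u32
    assert_eq!(
        std::mem::size_of::<Option<NonZeroU32>>(),
        std::mem::size_of::<u32>()
    );

    // Wrapping<T> wraps around on overflow instead of panicking in debug builds
    assert_eq!(Wrapping(u8::MAX) + Wrapping(1), Wrapping(0));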