Nice write‑up. A couple of notes from doing roughly the same dance on Cortex‑M0 and M3 boards for sensor fusion.
1. You can, in fact, get rid of every FP instruction on M0. The trick is to pre‑bake the scale and zero_point into a single fixed‑point multiplier per layer (the dyadic form you mentioned). The formula is
y = ((Wx + b) * M) >> s, where M fits in an int32 and s is the power-of-two shift. You compute M and s once on the host, write them out as const tables, and your inner loop is literally a MAC loop followed by a single multiply-and-shift requantization. No soft-float library, no division.
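For concreteness, a minimal host-side sketch of that step in Python (the function names and example scales are mine, not the parent's code): decompose the float multiplier into a dyadic pair (M, s), then check the integer-only requantization against the float path.

    import math

    def quantize_multiplier(real_multiplier, bits=31):
        """Approximate a positive float scale < 1 as M * 2**-s with M an int32."""
        assert 0.0 < real_multiplier < 1.0
        s = 0
        while real_multiplier < 0.5:    # normalize mantissa into [0.5, 1)
            real_multiplier *= 2.0
            s += 1
        M = int(round(real_multiplier * (1 << bits)))
        # (production code also handles M rounding up to 2**bits)
        return M, s + bits              # combined right-shift

    def requantize(acc, M, s):
        """Integer-only (acc * M) >> s with round-to-nearest."""
        prod = acc * M                  # fits in an int64 accumulator
        return (prod + (1 << (s - 1))) >> s

    # Example: fold scale_in * scale_w / scale_out into one multiplier.
    real = (0.02 * 0.004) / 0.05
    M, s = quantize_multiplier(real)
    acc = 12345                         # int32 accumulator from the MAC loop
    print(requantize(acc, M, s), round(acc * real))   # the two should agree closely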
2. CMSIS‑NN already gives you the fast int8 kernels. The docs are painful but you can steal just four files: arm_fully_connected_q7.c, arm_nnsupportfunctions.c, and their headers. On M0 this compiled to ~3 kB for me. Feed those kernels fixed‑point activations and you only pay for the ops you call.
3. Workflow that kept me sane
Prototype in PyTorch. Tiny net, ReLU, MSE, Adam, done.
torch.quantization.quantize_qat for quantization‑aware training. Export to ONNX, then run a one‑page Python script that dumps .h files with weight, bias, M, s.
Hand‑roll the inference loop in C. It is about 40 lines per layer, easy to unit‑test on the host with the same vectors you trained on.
By starting with a known‑good fp32 model you always have a checksum: the int8 path must match fp32 within tolerance or you know exactly where to look.
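The ".h dump" step really can be a one-pager. A hedged sketch (the array naming and layer layout are illustrative, not the actual script):

    import numpy as np

    def dump_layer_header(name, w_q, b_q, M, s, path):
        """Emit a C header with int8 weights, int32 biases and the (M, s) pair."""
        rows, cols = w_q.shape
        lines = [
            f"// generated for layer '{name}', do not edit by hand",
            f"#define {name.upper()}_ROWS {rows}",
            f"#define {name.upper()}_COLS {cols}",
            f"static const int8_t {name}_weights[{rows * cols}] = {{",
            "    " + ", ".join(str(int(v)) for v in w_q.flatten()),
            "};",
            f"static const int32_t {name}_bias[{rows}] = {{",
            "    " + ", ".join(str(int(v)) for v in b_q),
            "};",
            f"static const int32_t {name}_mult = {M};",
            f"static const int32_t {name}_shift = {s};",
        ]
        with open(path, "w") as f:
            f.write("\n".join(lines) + "\n")

    # Made-up numbers; in practice w_q, b_q, M, s come from the QAT export.
    w_q = np.random.randint(-127, 128, size=(8, 4), dtype=np.int8)
    b_q = np.random.randint(-1000, 1000, size=8, dtype=np.int32)
    dump_layer_header("fc1", w_q, b_q, M=1759218604, s=40, path="fc1_weights.h")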
Awesome, thanks! This is exactly the kind of experienced take I was hoping my blog post would summon =D
Re: computing M and s, does torch.quantization.quantize_qat do this or do you do it yourself from the (presumably f32) activation scaling that torch finds?
I don't have much experience with this kind of numerical computing, so I have no intuition about how much the "quantization" of selecting M and s might impact the overall performance of the network. I.e., whether
- M and s should be trained as part of QAT (e.g., the "Learned Step Size Quantization" paper)
- it's fine to just deterministically compute M and s from the f32 activation scaling.
Also: Thanks for the tips re: CMSIS-NN, glad to know it's possible to use in a non-framework way. Any chance your example is open source somewhere?
My suggestion would be that, since you want a tiny integer-only NN tailored for a specific computer, are only occasionally training one for a specific task, and you have a simulator to generate unlimited data, you simply do random search or an evolutionary method like CMA-ES.
They are easy to understand and code up by hand in a few lines (which is one reason you won't find any libraries for them: they are the 'leftpad' or 'isEven' of NNs, and the effort it would take to install, understand, and use a library often exceeds what it would take to just write it yourself), will handle any NN topology or numeric type you can invent, and will train very fast in this scenario.
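As a hedged illustration of how little code that takes: a (1+1) evolution strategy over an integer-only two-layer net in numpy. CMA-ES adds an adapted covariance on top of the same loop. The shapes, the >>7 rescale, and the toy target are all made up.

    import numpy as np

    rng = np.random.default_rng(0)

    def forward(params, X):
        """Tiny integer MLP: int8 weights, int32 accumulators, >>7 as a crude rescale."""
        w1, b1, w2, b2 = params
        h = np.maximum((X @ w1.astype(np.int32) + b1) >> 7, 0)   # ReLU
        return (h @ w2.astype(np.int32) + b2) >> 7

    def loss(params, X, Y):
        return float(np.mean((forward(params, X) - Y) ** 2))

    def mutate(params, step=2):
        """Add small integer noise and clip back into int8 range."""
        return [np.clip(p + rng.integers(-step, step + 1, p.shape), -127, 127).astype(p.dtype)
                for p in params]

    # Fake data standing in for simulator output: learn y ~ sum of inputs.
    X = rng.integers(-100, 100, size=(256, 4)).astype(np.int32)
    Y = X.sum(axis=1, keepdims=True)

    params = [rng.integers(-20, 20, (4, 8)).astype(np.int8),
              np.zeros(8, dtype=np.int32),
              rng.integers(-20, 20, (8, 1)).astype(np.int8),
              np.zeros(1, dtype=np.int32)]

    best = loss(params, X, Y)
    for it in range(5000):                   # (1+1) evolution strategy
        cand = mutate(params)
        l = loss(cand, X, Y)
        if l <= best:
            params, best = cand, l
    print("final MSE:", best)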
Experienced practitioner here: the second half of the post describes doing everything exactly the way I have done it. The only differences are that I picked C++ and Eigen instead of Rust and nalgebra for inference, and I used torch's ndarray and backprop tools instead of JAX's, with the analogous "just print out C++ code from Python" approach to weight serialization. You picked up on the key insight, which is that the size of the code needed to just directly implement the inference equations is much smaller than the size of the configuration file of any possible framework flexible enough to meet your requirements (Rust, no inference-time allocation, no inference-time floating point, trained from scratch, ultra-small parameter count, …).
I wonder how well BitNet (ternary weights) would work for this. It seems like a promising way forward for constrained hardware.
https://arxiv.org/abs/2310.11453
https://github.com/cpldcpu/BitNetMCU/blob/main/docs/document...
If people like "give me the simplest possible working coding example" of neural networks, I highly recommend this one:
"A Neural Network in 11 lines of Python (Part 1)": https://iamtrask.github.io/2015/07/12/basic-python-network/
Out of curiosity, did you consider Bayesian state estimators?
For example, an unscented kalman filter: https://www.mathworks.com/help/control/ug/nonlinear-state-es...
Great article. For a moment, I thought this would be about a gen AI that would turn any input into a "kawaii" version of it
Anyway, excellent insights and detail
I gotta say, I'm always interested in new ways to make stuff lighter, especially for small devices - do you think these clever tricks actually hold up in real-world use, or do they just look cool on paper?
What benefit does jax.nn provide over rolling one's own? There are countless examples on the web of small neural networks, written from scratch.
Could you point to an example that you like more? One of the author’s goals is to:
> solicit “why don’t you just …” emails from experienced practitioners who can point me to the library/tutorial I’ve been missing =D (see the alternatives-considered down the page for what I struck out on)
The last time I did anything like this, the easiest workflow I found was to use your favorite high-level runtime for training and just implement a serializer converting the model into source code for your target embedded system. Hand-code the inference loop. This is exactly the strategy TFA landed on.
One advantage of having it implemented in code is that you can observe and think about the instructions being generated. TFA didn't talk at all about something pretty important for small/fast neural networks -- the normal "cleanup" code (padding, alignment, length alignment, data-dependent horizontal sums, etc) can dwarf the actual mul->add execution times. You might want to, e.g., ensure your dimensions are all multiples of 8. You definitely want to store weights as column-major instead of row-major if the network is written as vec @ mat instead of mat @ vec (and vice versa for the latter).
When you're baking weights and biases into code like that, use an affine representation: explicitly pad the input with the constant one, along with however many extra zeroes are needed for whatever other length-padding requirements make sense for your problem (usually none for embedded, but this is a similar workflow to low-resource networks on traditional computers, where you probably want vectorization).
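A tiny numpy check of the affine trick, just to make it concrete (the shapes are arbitrary):

    import numpy as np

    rng = np.random.default_rng(0)
    W = rng.standard_normal((4, 8))     # vec @ mat convention: 4 inputs, 8 outputs
    b = rng.standard_normal(8)
    x = rng.standard_normal(4)

    # Affine form: stack the bias as one extra row and pad the input with 1.
    W_aff = np.vstack([W, b])           # shape (5, 8)
    x_aff = np.append(x, 1.0)           # shape (5,)

    assert np.allclose(x @ W + b, x_aff @ W_aff)

The padded matrix W_aff is then what you flatten (column-major, for the vec @ mat convention) into the generated source.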
Floats are a tiny bit hard to avoid for dot products. For similar precision, you require nearly twice the bit count in a fixed-point representation just to make the multiplies work, plus some extra bits proportional to the log2 of the dimension. E.g., if you trained on f16 inputs then you'll have roughly comparable precision with i32 fixed-point weights, and that's assuming you go through the effort to scale and shift everything into an appropriate numerical regime. Twice the instruction count (or thereabouts) on twice the register width makes fixed-point 2-4x slower for similar precision than a hardware float, supposing those wide instructions exist for your microcontroller, and soft floats are closer to 10x slower for multiply-accumulate. If you're emulating wide integer instructions, just use soft floats. If you don't care about a 4x slowdown, just use soft floats.
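A rough version of that bit accounting, as a sketch rather than anything from the parent comment:

    import math

    def accumulator_bits(weight_bits, act_bits, dim):
        """Worst-case bit width for a signed fixed-point dot product."""
        # Each product needs (weight_bits - 1) + (act_bits - 1) magnitude bits,
        # summing `dim` of them adds ceil(log2(dim)) more, plus one sign bit.
        return (weight_bits - 1) + (act_bits - 1) + math.ceil(math.log2(dim)) + 1

    print(accumulator_bits(8, 8, 64))    # 21 -> fits an int32 with headroom
    print(accumulator_bits(16, 16, 64))  # 37 -> already past int32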
Training can be a little finicky for small networks. At a minimum, you probably want to create train/test/validate sets and have many training runs. There are other techniques if you want to go down a rabbit hole.
Other ML architectures can be much more performant here. Gradient-boosted trees are already SOTA on many of these problems, and oblivious trees map extremely well to normal microcontroller instruction sets. By skipping the multiplies, your fixed-point precision is on par with floats of similar bit-width, making quantization a breeze.
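To make the oblivious-tree point concrete, a hedged sketch with made-up features, thresholds, and leaves: every level shares one (feature, threshold) pair, so inference is a handful of compares packed into a leaf index, with no multiplies at all.

    def oblivious_tree_predict(x, features, thresholds, leaves):
        """One oblivious tree of depth len(features): comparisons form a leaf index."""
        idx = 0
        for f, t in zip(features, thresholds):
            idx = (idx << 1) | (x[f] > t)      # one compare and shift per level
        return leaves[idx]

    # Depth-3 tree over integer sensor readings; 2**3 = 8 leaf values.
    features   = [0, 2, 1]
    thresholds = [150, -20, 400]
    leaves     = [-3, -1, 0, 2, 4, 5, 7, 9]

    print(oblivious_tree_predict([120, 512, -5], features, thresholds, leaves))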
This is such a confused blogpost I swear this had to be a teenager just banging their head against the wall.
Wanting to natively train a quantized neural network is stupid unless you are training directly on your microcontroller. I was constantly waiting for the author to explain their special circumstances and it turns out they don't have any. They just have a standard TinyML [0] use case that's been done to death with fixed-point quantization-aware training, which, unlike what the author of the blog post said, doesn't rely on terabytes of data.
QAT is done on a conventionally trained model with much less data than the full training process. Doing QAT early has no benefits. The big downside of QAT isn't that you need a lot of data, it's that you need the same data distribution as the original training data and nobody has access to that, because only the weights are published.
[0] https://medium.com/@thommaskevin/tinyml-quantization-aware-t...
> since our input data comes from multiple sensors and the the output pose has six components (three spatial positions and three spatial rotations)
Typo: two "the"
For robotics/inverse pose applications, don't people usually use a 3x3 matrix (three rotations, three spatial) for coordinate representation? Otherwise you get weird gimbal lock issues (I think).
For my application I need just the translations and Euler angles. The range of poses is mechanically constrained so I don't have to worry about gimbal lock. But yeah, my limited understanding matches yours that other parameterizations are more useful in general contexts.
This post and interactive explanations have been on my backlog to read and internalize: https://thenumb.at/Exponential-Rotations/
(Also: Thanks for pointing out the typo, I just deployed a fix.)
Hey there op. I don't know what your sensors are measuring (distance to a point maybe? Or angle from a Valve lighthouse for inside-out tracking?)
But here's my "why didn't you just"
Since you have a forward simulation function (pose to measurements), why didn't you use an iterative solver to reverse it? Coordinate descent is easy to code and if you have a constrained range of poses you can probably just use multiple starting points to avoid getting stuck with a local minimum. Then use the last solution as a starting point for the next one to save iterations.
Sure it's not closed-form like an NN and it can still have pathological cases, but the code is a little more transparent
That's a reasonable idea, but unfortunately wouldn't work in my case since the simulation relies on a lot of scientific libraries in Python and I need the inversion to happen on the microcontroller.
When you say "coordinate descent" do you mean gradient descent? I.e., updating a potential pose using the gradient of a loss term (e.g., (predicted sensor reading - actual sensor reading)**2)?
I bet that would work, but a tricky part would be calculating gradients. I'm not sure if the Python libraries I'm using support that. My understanding is that automatic differentiation through libraries might be easier in a language like Julia where dual numbers flow through everything via the multiple dispatch mechanism.
Ah makes sense.
No, coordinate descent is a stupider gradient-optional method: https://en.wikipedia.org/wiki/Coordinate_descent
It's slow and sub-optimal, but the code is very easy to follow and you don't have to wonder whether your gradient is correct.
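A hedged sketch of what that could look like for the pose problem; forward_sim here is a toy stand-in for the real simulator and the step schedule is arbitrary:

    import numpy as np

    def coordinate_descent(residual, x0, lower, upper, step0=0.1, iters=200):
        """Minimize residual(x) by nudging one coordinate at a time, no gradients."""
        x = np.array(x0, dtype=float)
        step = np.full_like(x, step0)
        best = residual(x)
        for _ in range(iters):
            for i in range(len(x)):
                for delta in (+step[i], -step[i]):
                    trial = x.copy()
                    trial[i] = np.clip(trial[i] + delta, lower[i], upper[i])
                    val = residual(trial)
                    if val < best:
                        x, best = trial, val
                        break
                else:
                    step[i] *= 0.5      # no improvement: shrink this coordinate's step
        return x, best

    # Toy stand-in for "pose -> predicted sensor readings".
    def forward_sim(pose):
        return np.array([pose[0] + 2 * pose[1], np.sin(pose[2]), pose[0] * pose[2]])

    measured = forward_sim(np.array([0.3, -0.1, 0.5]))
    residual = lambda p: float(np.sum((forward_sim(p) - measured) ** 2))

    pose, err = coordinate_descent(residual, x0=[0, 0, 0],
                                   lower=[-1, -1, -1], upper=[1, 1, 1])
    print(pose, err)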
All of this is absurdly complicated. Exactly what I would expect from a new student who doesn't know what they're doing and has no one to teach them how to do engineering in a systematic manner. I don't mean this as an insult. I teach this stuff and have seen it hundreds of times.
You should look for "post-training static quantization", also called PTSQ. There are countless ways to quantize. This one will quantize both the weights and the activations after training.
You're doing this on hard mode for no reason. This is typical and something I often need to break people out of. Optimizing for performance by doing custom things in Jax when you're a beginner is a terrible path to take.
Performance is not your problem. You're training a trivial network that would have run on a CPU 20 years ago.
There's no clear direction here, just trying complicated stuff in no logical order with no learning or dependencies between steps. You need to treat these problems as scientific experiments: what do I do to learn more about my domain, what do I change depending on the answer I get, etc. Not "now it's time to try something else random, like JAX."
Worse, you need to learn the key lesson in this space: credit assignment for problems is extremely hard. If something isn't working, why isn't it? Because of a bug? A hopeless problem? A crappy optimizer? Etc. That's why you should start in a framework that works and escape it later if you want.
Here's a simple plan to do this:
First, forget about quantization. Use PyTorch. Implement your trivial network in 5 lines. Train it with Adam. Make sure it works. Make sure your problem is solvable with the data that you have, the network you've chosen, your activation functions, the loss, and the optimizer (use Adam; forget about doing stuff by hand for now).
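Concretely, that first step can be as small as this (a sketch with made-up dimensions: 8 sensor inputs, 6 pose outputs, random tensors standing in for the simulator data):

    import torch
    import torch.nn as nn

    # Placeholder data; in practice these come from the simulator.
    X = torch.randn(10_000, 8)            # sensor readings
    Y = torch.randn(10_000, 6)            # poses (3 translations + 3 Euler angles)

    model = nn.Sequential(nn.Linear(8, 32), nn.ReLU(),
                          nn.Linear(32, 32), nn.ReLU(),
                          nn.Linear(32, 6))
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.MSELoss()

    for epoch in range(200):               # full-batch for simplicity
        opt.zero_grad()
        loss = loss_fn(model(X), Y)
        loss.backward()
        opt.step()
    print("final train MSE:", loss.item())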
> Unless I had an expert guide who was absolutely sure it’d be straightforward (email me!), I’d avoid high-level frameworks like TensorFlow and PyTorch and instead implement the quantization-aware training myself.
This is exactly backwards. Unless you have an expert, never implement anything yourself. If you don't have one, rely on what already exists, because then you can logically narrow down the options for what works and what's wrong. If you do it yourself, you're always lost.
Once you have that working, start backing off. Slowly change the working network into what you need, step by step. At every step, write down why you think your change is good and what you would do if it isn't. Then look at the results.
Forget about microflow-rs or whatever. Train with PyTorch, export to ONNX, and generate C code from your ONNX for inference.
Read the PyTorch guide on PTSQ and use it.
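For reference, the eager-mode PTSQ flow in recent PyTorch versions looks roughly like this (a sketch, not a drop-in recipe; the model and calibration batches are placeholders):

    import torch
    import torch.nn as nn

    # The already-trained float model from the previous step (same made-up sizes).
    model = nn.Sequential(nn.Linear(8, 32), nn.ReLU(),
                          nn.Linear(32, 32), nn.ReLU(),
                          nn.Linear(32, 6))
    model.eval()

    # Eager-mode PTSQ: add quant/dequant stubs, pick a backend config,
    # calibrate on representative inputs, then convert to int8 modules.
    qmodel = nn.Sequential(torch.ao.quantization.QuantStub(),
                           model,
                           torch.ao.quantization.DeQuantStub())
    qmodel.qconfig = torch.ao.quantization.get_default_qconfig("qnnpack")
    torch.ao.quantization.prepare(qmodel, inplace=True)

    with torch.no_grad():                        # calibration pass
        for _ in range(100):
            qmodel(torch.randn(32, 8))           # placeholder calibration batches

    torch.ao.quantization.convert(qmodel, inplace=True)
    print(qmodel)   # converted modules carry the per-layer scales and zero points

The converted modules hold the per-layer scales and zero points, which you can dump for a hand-rolled C loop; or, if you go the ONNX route from the parent comment, export with torch.onnx.export and hand the result to something like onnx2c.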
I kind of see your point, but only in the context of working on a time-sensitive task that others rely upon. If it is a hobby/educational project, what is wrong with doing things yourself, and resorting to decomposing an existing solution if you can't figure out why yours is not working?
There's nothing better for understanding something than trying to do that "something" from scratch yourself.
I think the point is that OP is learning things about a wide variety of topics that aren't really relevant to their stated goal, i.e. solving the sensor/state inference problem.
Which, as you say, can be valuable! There's nothing wrong with that. But the more complexity you add the less likely you are to actually solve the problem (all else being equal, some problems are just inherently complex).
Despite the tone this is excellent advice! I had similar impressions reading the article and was wondering if I missed something.
Targeting ONNX and using something like https://github.com/kraiskil/onnx2c as parent mentioned is good advice.
Well said. Thanks.