As a side product of some work I'm doing on automated circuit solving, I've hacked the ability to output expressions as SSE intrinsic calls into the JUCE Expression class. The class can now be used to automatically convert infix scalar expressions into prefix SSE intrinsic calls.
http://www.cytomic.com/files/dsp/juce_ExpressionSse.zip
I've also added a function called "balance". It tries to balance the expression tree as much as possible for parallel execution, and also tries to remove divisions and replace them with multiplications. This is not a replacement for computer-algebra simplification of complex expressions. It just handles a few basic cases and does the last step of turning regular C-style expressions into SSE.
There are probably errors but it's generating useful expressions for me right now so I thought I would share it with everyone. Here are some examples of what it does:
input = a1 + a2 + a3 + a4
output = add_pd (add_pd (a1, a2), add_pd (a3, a4))

input = a1 + a2 + a3 - a4
output = add_pd (add_pd (a1, a2), sub_pd (a3, a4))

input = a1 - a2 - a3 + a4
output = sub_pd (add_pd (a1, a4), add_pd (a2, a3))

input = -a1 - a2 - a3 - a4
output = neg_pd (add_pd (add_pd (a1, a2), add_pd (a3, a4)))

input = a1 - (-a2 + a3 - (-a4))
output = sub_pd (add_pd (a1, a2), add_pd (a3, a4))

input = a1 + a2 + a3 + a4 + a5 + a6 + a7
output = add_pd (add_pd (add_pd (a1, a2), add_pd (a3, a4)), add_pd (add_pd (a5, a6), a7))

input = m1*m2*m3*m4
output = mul_pd (mul_pd (m1, m2), mul_pd (m3, m4))

input = m1*m2*m3/m4
output = mul_pd (mul_pd (m1, m2), div_pd (m3, m4))

input = m1/m2/m3*m4
output = div_pd (mul_pd (m1, m4), mul_pd (m2, m3))

input = m1/m2/m3/m4
output = div_pd (m1, mul_pd (mul_pd (m2, m3), m4))

input = a1/a2/(a3/a4)
output = div_pd (mul_pd (a1, a4), mul_pd (a2, a3))
And the code that generated the above:
StringArray exs; exs.add ("a1 + a2 + a3 + a4"); exs.add ("a1 + a2 + a3 - a4"); exs.add ("a1 - a2 - a3 + a4"); exs.add ("-a1 - a2 - a3 - a4"); exs.add ("a1 - (-a2 + a3 - (-a4))"); exs.add ("a1 + a2 + a3 + a4 + a5 + a6 + a7"); exs.add ("m1*m2*m3*m4"); exs.add ("m1*m2*m3/m4"); exs.add ("m1/m2/m3*m4"); exs.add ("m1/m2/m3/m4"); exs.add ("a1/a2/(a3/a4)"); ScopedPointer <Expression::SsePdFormat> format (new Expression::SsePdFormat); for (int i=0; i<exs.size (); i++) { std::cout << "input = " << exs[i] << "\n"; std::cout << "output = " << Expression (exs[i]).basicSimplify ().balance ().toString (format) << "\n\n"; }
I've added two methods to the Expression class, basicSimplify and balance. basicSimplify removes some basic identities, such as:
a*1=a, a/1=a, a*-1=-a, a/-1=-a, a*0=0, 0/a=0, a+0=a, a-0=a, a-(-b)=a+b, a+(-b)=a-b, -a-(-b)=b-a
balance tries to even out the left and right branches of an expression to make the evaluation as parallel as possible. This only pays off when the resulting string is copied and pasted into a C file and then compiled with SSE intrinsics.
With a bit more work, this could automatically convert whole pieces of code into SSE code. I've written a very basic converter as an example to get people up and running, and it only handles very limited cases. It relies on a single equals sign and a semicolon appearing on the same line. If you want to handle more complicated code as input, you'll need to write a proper parser instead of relying on this basic hack:
/** Converts simple lines of scalar C code into SSE (double-precision) code.

    Each line is handled independently.  A line is converted only when it
    contains an '=' and is not a whole-line "//" comment; every other line is
    copied through unchanged.  For a converted line:
      - the text before the first '=' is the left-hand side (e.g. " double x ");
      - the text from the first ';' onwards (the semicolon plus any
        end-of-line comment) is kept verbatim;
      - the expression in between is run through basicSimplify () and
        balance (), then printed with Expression::SsePdFormat;
      - a compound assignment ("x += y", "x -= y", "x *= y", "x /= y") is
        rewritten as the plain assignment "x = x op (y)";
      - "double" in the left-hand side is replaced with "vec2d".

    The parsing is deliberately naive: one assignment per line, and no error
    checking is done on the input.

    @param code  the source text, one statement per line
    @returns     the converted text, each line terminated with "\n"
*/
String basicCodeToSse (const String& code)
{
    String out;
    ScopedPointer <Expression::SsePdFormat> format (new Expression::SsePdFormat);

    StringArray lines;
    lines.addLines (code);

    for (int i = 0; i < lines.size (); i++)
    {
        const String line (lines[i]);

        // Whole-line comments may contain '=' but are never assignments.
        if (line.contains ("=") && ! line.trimStart ().startsWith ("//"))
        {
            // The assignment operator is the FIRST '=': anything after it (an
            // "==" in the expression, or an '=' inside an end-of-line comment)
            // belongs to the right-hand side.
            String lhs = line.upToFirstOccurrenceOf ("=", false, false);
            const String rhs = line.substring (lhs.length () + 1).trim ();
            String toconvert = rhs.upToFirstOccurrenceOf (";", false, false);
            const String trailer = rhs.substring (toconvert.length ());

            // Check for += -= *= /=.  String::trimEnd () returns a new string,
            // so its result has to be kept.  getLastCharacter () returns 0 for
            // an empty string, so a line starting with '=' is not indexed at -1.
            const String target (lhs.trimEnd ());
            const juce_wchar c = target.getLastCharacter ();

            if (c == '+' || c == '-' || c == '*' || c == '/')
            {
                // e.g. lhs " a5 *" and rhs "a3 + a5" become the expression
                // "a5 *(a3 + a5)" assigned to " a5 ".
                toconvert = target.trimStart () + "(" + toconvert + ")";
                lhs = target.dropLastCharacters (1);
            }

            const String converted = Expression (toconvert).basicSimplify ().balance ().toString (format);

            if (lhs.contains ("double"))
                lhs = lhs.replace ("double", "vec2d");

            lines.set (i, lhs + "= " + converted + trailer);
        }

        out << lines[i] + "\n";
    }

    return out;
}
And here is an example of calling this:
// Run a few lines of scalar C code through basicCodeToSse () and print the
// converted SSE version.
const String sampleSource (
    " // can only handle end of line comments and single line expressions\n"
    " // can only handle a single variable before any asignments\n"
    " // no error checking at all\n"
    "\n"
    " const double a1 = a2*a3 - a4;\n"
    " double a5 = a1*a3 + a4*a5 + a6*a7 - a8;\n"
    " a5 *= a3 + a5;\n"
    " a5 -= a2 + a3;\n");

std::cout << basicCodeToSse (sampleSource);
And the output:
// can only handle end of line comments and single line expressions
// can only handle a single variable before any asignments
// no error checking at all

const vec2d a1 = sub_pd (mul_pd (a2, a3), a4);
vec2d a5 = add_pd (add_pd (mul_pd (a1, a3), mul_pd (a4, a5)), sub_pd (mul_pd (a6, a7), a8));
a5 = mul_pd (a5, add_pd (a3, a5));
a5 = sub_pd (a5, add_pd (a2, a3));
I use macros to make SSE code a little more readable by dropping the _mm_ prefix from every intrinsic. These are the names I use by default, but you can change them to anything:
// Short aliases for the SSE register types: two doubles / four floats.
typedef __m128d vec2d;
typedef __m128 vec4f;

// Thin wrappers that drop the "_mm_" prefix from the SSE/SSE2 intrinsics so
// generated code reads like ordinary function calls.
//
// NOTE: set_pd / set_ps take their arguments in lane order (x0 is lane 0, the
// lowest element), which is the reverse of the native _mm_set_* order.

// ---- double precision (vec2d) ----
#define set_pd(x0,x1) _mm_set_pd(x1, x0)
#define set1_pd(x0) _mm_set1_pd(x0)
#define add_pd(x0,x1) _mm_add_pd(x0,x1)
#define sub_pd(x0,x1) _mm_sub_pd(x0,x1)
#define mul_pd(x0,x1) _mm_mul_pd(x0,x1)
#define div_pd(x0,x1) _mm_div_pd(x0,x1)
#define max_pd(x0,x1) _mm_max_pd(x0,x1)
#define min_pd(x0,x1) _mm_min_pd(x0,x1)
// Clamp each lane of x0 into [lo, hi]: min (max (x0, lo), hi).
#define clip_pd(x0,lo,hi) _mm_min_pd(_mm_max_pd(x0,lo),hi)
#define cvtps_pd(x0) _mm_cvtps_pd(x0)
#define andnot_pd(x0,x1) _mm_andnot_pd(x0,x1)
#define cmplt_pd(x0,x1) _mm_cmplt_pd(x0,x1)
#define cmple_pd(x0,x1) _mm_cmple_pd(x0,x1)
#define cmpgt_pd(x0,x1) _mm_cmpgt_pd(x0,x1)
#define cmpge_pd(x0,x1) _mm_cmpge_pd(x0,x1)
#define or_pd(x0,x1) _mm_or_pd(x0,x1)
#define xor_pd(x0,x1) _mm_xor_pd(x0,x1)
#define and_pd(x0,x1) _mm_and_pd(x0,x1)
#define setzero_pd() _mm_setzero_pd()
#define load_pd(x0) _mm_load_pd(x0)
#define loadu_pd(x0) _mm_loadu_pd(x0)
#define store_pd(x0,x1) _mm_store_pd(x0,x1)
#define storeu_pd(x0,x1) _mm_storeu_pd(x0,x1)
// Widen the low (lanes 0-1) / high (lanes 2-3) pair of floats in a vec4f to
// a vec2d.  Note that cvthps_pd pastes its argument twice.
#define cvtlps_pd(x0) _mm_cvtps_pd(x0)
#define cvthps_pd(x0) _mm_cvtps_pd(_mm_movehl_ps(x0,x0))

// ---- single precision (vec4f) ----
#define set_ps(x0,x1,x2,x3) _mm_set_ps(x3, x2, x1, x0)
#define set1_ps(x0) _mm_set1_ps(x0)
#define add_ps(x0,x1) _mm_add_ps(x0,x1)
#define sub_ps(x0,x1) _mm_sub_ps(x0,x1)
#define mul_ps(x0,x1) _mm_mul_ps(x0,x1)
#define div_ps(x0,x1) _mm_div_ps(x0,x1)
// Approximate reciprocal (not a full-precision 1/x).
#define rcp_ps(x0) _mm_rcp_ps(x0)
#define max_ps(x0,x1) _mm_max_ps(x0,x1)
#define min_ps(x0,x1) _mm_min_ps(x0,x1)
// Clamp each lane of x0 into [lo, hi]: min (max (x0, lo), hi).
#define clip_ps(x0,lo,hi) _mm_min_ps(_mm_max_ps(x0,lo),hi)
#define and_ps(x0,x1) _mm_and_ps(x0,x1)
#define andnot_ps(x0,x1) _mm_andnot_ps(x0,x1)
#define or_ps(x0,x1) _mm_or_ps(x0,x1)
#define xor_ps(x0,x1) _mm_xor_ps(x0,x1)
#define movelh_ps(x0,x1) _mm_movelh_ps(x0,x1)
#define movehl_ps(x0,x1) _mm_movehl_ps(x0,x1)
#define cvtepi32_ps(x0) _mm_cvtepi32_ps(x0)
#define cvtpd_ps(x0) _mm_cvtpd_ps(x0)
#define cmplt_ps(x0,x1) _mm_cmplt_ps(x0,x1)
#define cmple_ps(x0,x1) _mm_cmple_ps(x0,x1)
#define cmpgt_ps(x0,x1) _mm_cmpgt_ps(x0,x1)
#define cmpge_ps(x0,x1) _mm_cmpge_ps(x0,x1)
#define setzero_ps() _mm_setzero_ps()
#define load_ps(x0) _mm_load_ps(x0)
#define loadu_ps(x0) _mm_loadu_ps(x0)
#define store_ps(x0,x1) _mm_store_ps(x0,x1)
#define storeu_ps(x0,x1) _mm_storeu_ps(x0,x1)
// Narrow two vec2d into one vec4f: x0 fills lanes 0-1, x1 fills lanes 2-3.
#define cvtpd2_ps(x0,x1) _mm_movelh_ps(_mm_cvtpd_ps(x0),_mm_cvtpd_ps(x1))
#define cvtps_epi32(x0) _mm_cvtps_epi32(x0)
// Commonly used constants, broadcast to every lane.
// NOTE: these are namespace-scope `const` objects, so in C++ every translation
// unit that includes them gets its own internal copy.  _mm_set1_* is not
// guaranteed to be a constant expression, so they may be initialised at
// start-up; avoid reading them from other global initialisers.
const vec2d zero_pd    = set1_pd ( 0.0);
const vec2d negzero_pd = set1_pd (-0.0);   // only the sign bit set
const vec2d half_pd    = set1_pd ( 0.5);
const vec2d neghalf_pd = set1_pd (-0.5);
const vec2d one_pd     = set1_pd ( 1.0);
const vec2d negone_pd  = set1_pd (-1.0);
const vec2d two_pd     = set1_pd ( 2.0);
const vec2d negtwo_pd  = set1_pd (-2.0);
const vec2d pi_pd      = set1_pd (3.141592653589793);
const vec2d halfpi_pd  = set1_pd (0.5*3.141592653589793);
const vec2d twopi_pd   = set1_pd (2.0*3.141592653589793);

const vec4f zero_ps    = set1_ps ( 0.0f);
const vec4f negzero_ps = set1_ps (-0.0f);  // only the sign bit set
const vec4f half_ps    = set1_ps ( 0.5f);
const vec4f neghalf_ps = set1_ps (-0.5f);  // mirrors neghalf_pd
const vec4f one_ps     = set1_ps ( 1.0f);
const vec4f negone_ps  = set1_ps (-1.0f);
const vec4f two_ps     = set1_ps ( 2.0f);
const vec4f negtwo_ps  = set1_ps (-2.0f);
const vec4f pi_ps      = set1_ps (3.141592653589793f);
const vec4f halfpi_ps  = set1_ps (0.5f*3.141592653589793f);
const vec4f twopi_ps   = set1_ps (2.0f*3.141592653589793f);
// Sign-bit helpers using the negzero_* masks (-0.0 has only the sign bit set):
//   abs_*  clears the sign bit of every lane  (~negzero & v)
//   neg_*  flips the sign bit of every lane   ( negzero ^ v)
#define abs_ps(v) _mm_andnot_ps(negzero_ps, v)
#define neg_ps(v) _mm_xor_ps(negzero_ps, v)
#define abs_pd(v) _mm_andnot_pd(negzero_pd, v)
#define neg_pd(v) _mm_xor_pd(negzero_pd, v)
/** Squares every lane of a vec2d (two doubles): returns x0 * x0.
    Uses the standard C++ `inline` keyword rather than the compiler-specific
    `__inline`, so the definition may safely appear in a header.
*/
inline vec2d sqr_pd (vec2d x0)
{
    return mul_pd (x0, x0);
}

/** Squares every lane of a vec4f (four floats): returns x0 * x0. */
inline vec4f sqr_ps (vec4f x0)
{
    return mul_ps (x0, x0);
}