shalinib committed
Commit 3e2a209 · 1 Parent(s): b305121

ggml : refactor llamafile_sgemm PPC code (llama/14673)


Remove unnecessary templates from the class definition and packing functions
Reduce deeply nested conditionals and if-else switching in the mnpack function
Replace repetitive code with inline helper functions in the packing functions (see the scalar sketch below)

2-7% improvement in Q8 models
15-50% improvement in Q4 models

Signed-off-by: Shalini Salomi Bodapati <[email protected]>
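For context, the repetitive per-row unpacking that this patch folds into the new process_q4_elements helper boils down to the per-block work sketched below. This is a simplified, portable scalar sketch in plain C++ (q4_block is a hypothetical stand-in for block_q4_0 with the scale field omitted); the committed code performs the same steps on 16-byte VSX registers using vec_and, vec_sr, vec_sub and vec_sum4s.

    #include <cstdint>

    // Hypothetical stand-in for ggml's block_q4_0: 32 weights packed into 16 bytes,
    // two 4-bit values per byte (the fp16 scale is omitted for brevity).
    struct q4_block {
        uint8_t qs[16];
    };

    // Scalar analogue of process_q4_elements(): split each byte into its low and
    // high nibble, subtract 8 so the unsigned 4-bit values become signed int8,
    // and accumulate the row sum that is later used to compensate for this offset.
    static inline void process_q4_scalar(const q4_block &blk,
                                         int8_t lo[16], int8_t hi[16], int &rowsum) {
        rowsum = 0;
        for (int i = 0; i < 16; ++i) {
            lo[i] = (int8_t)((blk.qs[i] & 0x0F) - 8);  // low nibble  -> [-8, 7]
            hi[i] = (int8_t)((blk.qs[i] >> 4) - 8);    // high nibble -> [-8, 7]
            rowsum += lo[i] + hi[i];
        }
    }

In the patch the same steps run on whole vectors (c[0] = vec_sub(vec_and(c[1], lowMask), v8), c[1] = vec_sub(vec_sr(c[1], v4), v8), with vec_sum4s accumulating the row sum), so the eight near-identical unpacking blocks in the old packNormalInt4 collapse into eight one-line helper calls.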

Files changed (1)
  1. ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
ggml/src/ggml-cpu/llamafile/sgemm.cpp CHANGED
@@ -1541,7 +1541,7 @@ class tinyBLAS_BF16_PPC {
1541
  } else if constexpr(RM == 8 && RN == 4) {
1542
  KERNEL_8x4(ii,jj);
1543
  } else {
1544
- static_assert(false, "RN/RM values not supported");
1545
  }
1546
  }
1547
 
@@ -1573,13 +1573,13 @@ class tinyBLAS_BF16_PPC {
1573
  const int nth;
1574
  };
1575
 
1576
- template <typename TA, typename TB, typename TC>
1577
  class tinyBLAS_Q0_PPC {
1578
  public:
1579
  tinyBLAS_Q0_PPC(int64_t k,
1580
  const TA *A, int64_t lda,
1581
- const TB *B, int64_t ldb,
1582
- TC *C, int64_t ldc,
1583
  int ith, int nth)
1584
  : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
1585
  }
@@ -1590,8 +1590,7 @@ class tinyBLAS_Q0_PPC {
1590
 
1591
  private:
1592
 
1593
- template<int RM, int RN>
1594
- inline void save_res(int ii, int jj, int idx, vector float* fin_res) {
1595
  for (int I = 0; I < RM; I++) {
1596
  for (int J = 0; J < RN; J++) {
1597
  *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J);
@@ -1611,29 +1610,67 @@ class tinyBLAS_Q0_PPC {
1611
  fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]);
1612
  }
1613
  }
1614
-
1615
- template<typename VA, typename VB, int size>
1616
- void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, VA* vec, std::array<int, size>& comparray) {
1617
- int64_t i, j;
1618
- TA *aoffset = NULL;
1619
- VA *vecOffset = NULL;
1620
- TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
1621
- TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
1622
- VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
1623
- VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
1624
- VB t1, t2, t3, t4, t5, t6, t7, t8;
1625
  const vector signed char lowMask = vec_splats((signed char)0xF);
1626
  const vector unsigned char v4 = vec_splats((unsigned char)0x4);
1627
  const vector signed char v8 = vec_splats((signed char)0x8);
1628
- aoffset = const_cast<TA*>(a);
1629
- vecOffset = vec;
1630
  vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
1631
  vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
1632
  vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
1633
  vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
1634
- vector signed int vsum = {0};
1635
- vector signed int vsum2 = {0};
1636
 
1637
  j = (rows >> 3);
1638
  if (j > 0) {
1639
  do {
@@ -1646,159 +1683,30 @@ class tinyBLAS_Q0_PPC {
1646
  aoffset7 = aoffset6 + lda;
1647
  aoffset8 = aoffset7 + lda;
1648
  aoffset += 8 * lda;
1649
-
1650
  i = (cols >> 2);
1651
  if (i > 0) {
1652
  do {
1653
- c1[1] = reinterpret_cast<VB>(vec_xl(0, aoffset1->qs));
1654
- c2[1] = reinterpret_cast<VB>(vec_xl(0, aoffset2->qs));
1655
- c3[1] = reinterpret_cast<VB>(vec_xl(0, aoffset3->qs));
1656
- c4[1] = reinterpret_cast<VB>(vec_xl(0, aoffset4->qs));
1657
- c5[1] = reinterpret_cast<VB>(vec_xl(0, aoffset5->qs));
1658
- c6[1] = reinterpret_cast<VB>(vec_xl(0, aoffset6->qs));
1659
- c7[1] = reinterpret_cast<VB>(vec_xl(0, aoffset7->qs));
1660
- c8[1] = reinterpret_cast<VB>(vec_xl(0, aoffset8->qs));
1661
-
1662
- c1[0] = vec_and(c1[1], lowMask);
1663
- c1[1] = vec_sr(c1[1], v4);
1664
- c1[0] = vec_sub(c1[0], v8);
1665
- c1[1] = vec_sub(c1[1], v8);
1666
- vsum = vec_sum4s(c1[0], vsum);
1667
- vsum2 = vec_sum4s(c1[1], vsum2);
1668
- vsum = vec_add(vsum, vsum2);
1669
- comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1670
- vsum = vec_splats(0);
1671
- vsum2 = vec_splats(0);
1672
-
1673
- c2[0] = vec_and(c2[1], lowMask);
1674
- c2[1] = vec_sr(c2[1], v4);
1675
- c2[0] = vec_sub(c2[0], v8);
1676
- c2[1] = vec_sub(c2[1], v8);
1677
- vsum = vec_sum4s(c2[0], vsum);
1678
- vsum2 = vec_sum4s(c2[1], vsum2);
1679
- vsum = vec_add(vsum, vsum2);
1680
- comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1681
- vsum = vec_splats(0);
1682
- vsum2 = vec_splats(0);
1683
-
1684
- c3[0] = vec_and(c3[1], lowMask);
1685
- c3[1] = vec_sr(c3[1], v4);
1686
- c3[0] = vec_sub(c3[0], v8);
1687
- c3[1] = vec_sub(c3[1], v8);
1688
- vsum = vec_sum4s(c3[0], vsum);
1689
- vsum2 = vec_sum4s(c3[1], vsum2);
1690
- vsum = vec_add(vsum, vsum2);
1691
- comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1692
- vsum = vec_splats(0);
1693
- vsum2 = vec_splats(0);
1694
-
1695
- c4[0] = vec_and(c4[1], lowMask);
1696
- c4[1] = vec_sr(c4[1], v4);
1697
- c4[0] = vec_sub(c4[0], v8);
1698
- c4[1] = vec_sub(c4[1], v8);
1699
- vsum = vec_sum4s(c4[0], vsum);
1700
- vsum2 = vec_sum4s(c4[1], vsum2);
1701
- vsum = vec_add(vsum, vsum2);
1702
- comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1703
- vsum = vec_splats(0);
1704
- vsum2 = vec_splats(0);
1705
-
1706
- c5[0] = vec_and(c5[1], lowMask);
1707
- c5[1] = vec_sr(c5[1], v4);
1708
- c5[0] = vec_sub(c5[0], v8);
1709
- c5[1] = vec_sub(c5[1], v8);
1710
- vsum = vec_sum4s(c5[0], vsum);
1711
- vsum2 = vec_sum4s(c5[1], vsum2);
1712
- vsum = vec_add(vsum, vsum2);
1713
- comparray[4] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1714
- vsum = vec_splats(0);
1715
- vsum2 = vec_splats(0);
1716
-
1717
- c6[0] = vec_and(c6[1], lowMask);
1718
- c6[1] = vec_sr(c6[1], v4);
1719
- c6[0] = vec_sub(c6[0], v8);
1720
- c6[1] = vec_sub(c6[1], v8);
1721
- vsum = vec_sum4s(c6[0], vsum);
1722
- vsum2 = vec_sum4s(c6[1], vsum2);
1723
- vsum = vec_add(vsum, vsum2);
1724
- comparray[5] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1725
- vsum = vec_splats(0);
1726
- vsum2 = vec_splats(0);
1727
-
1728
- c7[0] = vec_and(c7[1], lowMask);
1729
- c7[1] = vec_sr(c7[1], v4);
1730
- c7[0] = vec_sub(c7[0], v8);
1731
- c7[1] = vec_sub(c7[1], v8);
1732
- vsum = vec_sum4s(c7[0], vsum);
1733
- vsum2 = vec_sum4s(c7[1], vsum2);
1734
- vsum = vec_add(vsum, vsum2);
1735
- comparray[6] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1736
- vsum = vec_splats(0);
1737
- vsum2 = vec_splats(0);
1738
-
1739
- c8[0] = vec_and(c8[1], lowMask);
1740
- c8[1] = vec_sr(c8[1], v4);
1741
- c8[0] = vec_sub(c8[0], v8);
1742
- c8[1] = vec_sub(c8[1], v8);
1743
- vsum = vec_sum4s(c8[0], vsum);
1744
- vsum2 = vec_sum4s(c8[1], vsum2);
1745
- vsum = vec_add(vsum, vsum2);
1746
- comparray[7] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1747
- vsum = vec_splats(0);
1748
- vsum2 = vec_splats(0);
1749
-
1750
- t1 = vec_perm(c1[0], c2[0], swiz1);
1751
- t2 = vec_perm(c1[0], c2[0], swiz2);
1752
- t3 = vec_perm(c3[0], c4[0], swiz1);
1753
- t4 = vec_perm(c3[0], c4[0], swiz2);
1754
- t5 = vec_perm(t1, t3, swiz3);
1755
- t6 = vec_perm(t1, t3, swiz4);
1756
- t7 = vec_perm(t2, t4, swiz3);
1757
- t8 = vec_perm(t2, t4, swiz4);
1758
- vec_xst(t5, 0, vecOffset);
1759
- vec_xst(t6, 0, vecOffset+16);
1760
- vec_xst(t7, 0, vecOffset+32);
1761
- vec_xst(t8, 0, vecOffset+48);
1762
-
1763
- t1 = vec_perm(c1[1], c2[1], swiz1);
1764
- t2 = vec_perm(c1[1], c2[1], swiz2);
1765
- t3 = vec_perm(c3[1], c4[1], swiz1);
1766
- t4 = vec_perm(c3[1], c4[1], swiz2);
1767
- t5 = vec_perm(t1, t3, swiz3);
1768
- t6 = vec_perm(t1, t3, swiz4);
1769
- t7 = vec_perm(t2, t4, swiz3);
1770
- t8 = vec_perm(t2, t4, swiz4);
1771
- vec_xst(t5, 0, vecOffset+64);
1772
- vec_xst(t6, 0, vecOffset+80);
1773
- vec_xst(t7, 0, vecOffset+96);
1774
- vec_xst(t8, 0, vecOffset+112);
1775
-
1776
- t1 = vec_perm(c5[0], c6[0], swiz1);
1777
- t2 = vec_perm(c5[0], c6[0], swiz2);
1778
- t3 = vec_perm(c7[0], c8[0], swiz1);
1779
- t4 = vec_perm(c7[0], c8[0], swiz2);
1780
- t5 = vec_perm(t1, t3, swiz3);
1781
- t6 = vec_perm(t1, t3, swiz4);
1782
- t7 = vec_perm(t2, t4, swiz3);
1783
- t8 = vec_perm(t2, t4, swiz4);
1784
- vec_xst(t5, 0, vecOffset+128);
1785
- vec_xst(t6, 0, vecOffset+144);
1786
- vec_xst(t7, 0, vecOffset+160);
1787
- vec_xst(t8, 0, vecOffset+176);
1788
-
1789
- t1 = vec_perm(c5[1], c6[1], swiz1);
1790
- t2 = vec_perm(c5[1], c6[1], swiz2);
1791
- t3 = vec_perm(c7[1], c8[1], swiz1);
1792
- t4 = vec_perm(c7[1], c8[1], swiz2);
1793
- t5 = vec_perm(t1, t3, swiz3);
1794
- t6 = vec_perm(t1, t3, swiz4);
1795
- t7 = vec_perm(t2, t4, swiz3);
1796
- t8 = vec_perm(t2, t4, swiz4);
1797
- vec_xst(t5, 0, vecOffset+192);
1798
- vec_xst(t6, 0, vecOffset+208);
1799
- vec_xst(t7, 0, vecOffset+224);
1800
- vec_xst(t8, 0, vecOffset+240);
1801
-
1802
  aoffset1 += lda;
1803
  aoffset2 += lda;
1804
  aoffset3 += lda;
@@ -1821,85 +1729,20 @@ class tinyBLAS_Q0_PPC {
1821
  aoffset3 = aoffset2 + lda;
1822
  aoffset4 = aoffset3 + lda;
1823
  aoffset += 4 * lda;
1824
-
1825
  i = (cols >> 2);
1826
  if (i > 0) {
1827
  do {
1828
- c1[1] = reinterpret_cast<VB>(vec_xl(0, aoffset1->qs));
1829
- c2[1] = reinterpret_cast<VB>(vec_xl(0, aoffset2->qs));
1830
- c3[1] = reinterpret_cast<VB>(vec_xl(0, aoffset3->qs));
1831
- c4[1] = reinterpret_cast<VB>(vec_xl(0, aoffset4->qs));
1832
-
1833
- c1[0] = vec_and(c1[1], lowMask);
1834
- c1[1] = vec_sr(c1[1], v4);
1835
- c1[0] = vec_sub(c1[0], v8);
1836
- c1[1] = vec_sub(c1[1], v8);
1837
- vsum = vec_sum4s(c1[0], vsum);
1838
- vsum2 = vec_sum4s(c1[1], vsum2);
1839
- vsum = vec_add(vsum, vsum2);
1840
- comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1841
- vsum = vec_splats(0);
1842
- vsum2 = vec_splats(0);
1843
-
1844
- c2[0] = vec_and(c2[1], lowMask);
1845
- c2[1] = vec_sr(c2[1], v4);
1846
- c2[0] = vec_sub(c2[0], v8);
1847
- c2[1] = vec_sub(c2[1], v8);
1848
- vsum = vec_sum4s(c2[0], vsum);
1849
- vsum2 = vec_sum4s(c2[1], vsum2);
1850
- vsum = vec_add(vsum, vsum2);
1851
- comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1852
- vsum = vec_splats(0);
1853
- vsum2 = vec_splats(0);
1854
-
1855
- c3[0] = vec_and(c3[1], lowMask);
1856
- c3[1] = vec_sr(c3[1], v4);
1857
- c3[0] = vec_sub(c3[0], v8);
1858
- c3[1] = vec_sub(c3[1], v8);
1859
- vsum = vec_sum4s(c3[0], vsum);
1860
- vsum2 = vec_sum4s(c3[1], vsum2);
1861
- vsum = vec_add(vsum, vsum2);
1862
- comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1863
- vsum = vec_splats(0);
1864
- vsum2 = vec_splats(0);
1865
-
1866
- c4[0] = vec_and(c4[1], lowMask);
1867
- c4[1] = vec_sr(c4[1], v4);
1868
- c4[0] = vec_sub(c4[0], v8);
1869
- c4[1] = vec_sub(c4[1], v8);
1870
- vsum = vec_sum4s(c4[0], vsum);
1871
- vsum2 = vec_sum4s(c4[1], vsum2);
1872
- vsum = vec_add(vsum, vsum2);
1873
- comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1874
- vsum = vec_splats(0);
1875
- vsum2 = vec_splats( 0);
1876
-
1877
- t1 = vec_perm(c1[0], c2[0], swiz1);
1878
- t2 = vec_perm(c1[0], c2[0], swiz2);
1879
- t3 = vec_perm(c3[0], c4[0], swiz1);
1880
- t4 = vec_perm(c3[0], c4[0], swiz2);
1881
- t5 = vec_perm(t1, t3, swiz3);
1882
- t6 = vec_perm(t1, t3, swiz4);
1883
- t7 = vec_perm(t2, t4, swiz3);
1884
- t8 = vec_perm(t2, t4, swiz4);
1885
- vec_xst(t5, 0, vecOffset);
1886
- vec_xst(t6, 0, vecOffset+16);
1887
- vec_xst(t7, 0, vecOffset+32);
1888
- vec_xst(t8, 0, vecOffset+48);
1889
-
1890
- t1 = vec_perm(c1[1], c2[1], swiz1);
1891
- t2 = vec_perm(c1[1], c2[1], swiz2);
1892
- t3 = vec_perm(c3[1], c4[1], swiz1);
1893
- t4 = vec_perm(c3[1], c4[1], swiz2);
1894
- t5 = vec_perm(t1, t3, swiz3);
1895
- t6 = vec_perm(t1, t3, swiz4);
1896
- t7 = vec_perm(t2, t4, swiz3);
1897
- t8 = vec_perm(t2, t4, swiz4);
1898
- vec_xst(t5, 0, vecOffset+64);
1899
- vec_xst(t6, 0, vecOffset+80);
1900
- vec_xst(t7, 0, vecOffset+96);
1901
- vec_xst(t8, 0, vecOffset+112);
1902
-
1903
  aoffset1 += lda;
1904
  aoffset2 += lda;
1905
  aoffset3 += lda;
@@ -1918,80 +1761,17 @@ class tinyBLAS_Q0_PPC {
1918
  if (i > 0) {
1919
  do {
1920
  switch(rows) {
1921
- case 3: c3[1] = reinterpret_cast<VB>(vec_xl(0, aoffset3->qs));
1922
- case 2: c2[1] = reinterpret_cast<VB>(vec_xl(0, aoffset2->qs));
1923
- case 1: c1[1] = reinterpret_cast<VB>(vec_xl(0, aoffset1->qs));
1924
  break;
1925
  }
1926
- c1[0] = vec_and(c1[1], lowMask);
1927
- c1[1] = vec_sr(c1[1], v4);
1928
- c1[0] = vec_sub(c1[0], v8);
1929
- c1[1] = vec_sub(c1[1], v8);
1930
- vsum = vec_sum4s(c1[0], vsum);
1931
- vsum2 = vec_sum4s(c1[1], vsum2);
1932
- vsum = vec_add(vsum, vsum2);
1933
- comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1934
- vsum = vec_splats(0);
1935
- vsum2 = vec_splats(0);
1936
-
1937
- c2[0] = vec_and(c2[1], lowMask);
1938
- c2[1] = vec_sr(c2[1], v4);
1939
- c2[0] = vec_sub(c2[0], v8);
1940
- c2[1] = vec_sub(c2[1], v8);
1941
- vsum = vec_sum4s(c2[0], vsum);
1942
- vsum2 = vec_sum4s(c2[1], vsum2);
1943
- vsum = vec_add(vsum, vsum2);
1944
- comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1945
- vsum = vec_splats(0);
1946
- vsum2 = vec_splats(0);
1947
-
1948
- c3[0] = vec_and(c3[1], lowMask);
1949
- c3[1] = vec_sr(c3[1], v4);
1950
- c3[0] = vec_sub(c3[0], v8);
1951
- c3[1] = vec_sub(c3[1], v8);
1952
- vsum = vec_sum4s(c3[0], vsum);
1953
- vsum2 = vec_sum4s(c3[1], vsum2);
1954
- vsum = vec_add(vsum, vsum2);
1955
- comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1956
- vsum = vec_splats(0);
1957
- vsum2 = vec_splats(0);
1958
-
1959
- c4[0] = vec_and(c4[1], lowMask);
1960
- c4[1] = vec_sr(c4[1], v4);
1961
- c4[0] = vec_sub(c4[0], v8);
1962
- c4[1] = vec_sub(c4[1], v8);
1963
- vsum = vec_sum4s(c4[0], vsum);
1964
- vsum2 = vec_sum4s(c4[1], vsum2);
1965
- vsum = vec_add(vsum, vsum2);
1966
- comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1967
- vsum = vec_splats(0);
1968
- vsum2 = vec_splats(0);
1969
-
1970
- t1 = vec_perm(c1[0], c2[0], swiz1);
1971
- t2 = vec_perm(c1[0], c2[0], swiz2);
1972
- t3 = vec_perm(c3[0], c4[0], swiz1);
1973
- t4 = vec_perm(c3[0], c4[0], swiz2);
1974
- t5 = vec_perm(t1, t3, swiz3);
1975
- t6 = vec_perm(t1, t3, swiz4);
1976
- t7 = vec_perm(t2, t4, swiz3);
1977
- t8 = vec_perm(t2, t4, swiz4);
1978
- vec_xst(t5, 0, vecOffset);
1979
- vec_xst(t6, 0, vecOffset+16);
1980
- vec_xst(t7, 0, vecOffset+32);
1981
- vec_xst(t8, 0, vecOffset+48);
1982
-
1983
- t1 = vec_perm(c1[1], c2[1], swiz1);
1984
- t2 = vec_perm(c1[1], c2[1], swiz2);
1985
- t3 = vec_perm(c3[1], c4[1], swiz1);
1986
- t4 = vec_perm(c3[1], c4[1], swiz2);
1987
- t5 = vec_perm(t1, t3, swiz3);
1988
- t6 = vec_perm(t1, t3, swiz4);
1989
- t7 = vec_perm(t2, t4, swiz3);
1990
- t8 = vec_perm(t2, t4, swiz4);
1991
- vec_xst(t5, 0, vecOffset+64);
1992
- vec_xst(t6, 0, vecOffset+80);
1993
- vec_xst(t7, 0, vecOffset+96);
1994
- vec_xst(t8, 0, vecOffset+112);
1995
  aoffset1 += lda;
1996
  aoffset2 += lda;
1997
  aoffset3 += lda;
@@ -2001,146 +1781,40 @@ class tinyBLAS_Q0_PPC {
2001
  }
2002
  }
2003
  }
2004
-
2005
  template<typename VA, typename VB>
2006
- void packNormal(const TB* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
2007
  int64_t i, j;
2008
- TB *aoffset = NULL;
2009
  VA *vecOffset = NULL;
2010
- TB *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
2011
- TB *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
2012
- __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
2013
- VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2]={0};
2014
- VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2]={0};
2015
- VB t1, t2, t3, t4, t5, t6, t7, t8;
2016
- vector unsigned char xor_vector;
2017
- uint8_t flip_vec = 0x80;
2018
- xor_vector = vec_splats(flip_vec);
2019
- vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
2020
- vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
2021
- vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
2022
- vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
2023
-
2024
- aoffset = const_cast<TB*>(a);
2025
  vecOffset = vec;
2026
  j = (rows >> 3);
2027
  if (j > 0) {
2028
  do {
2029
- aoffset1 = aoffset;
2030
- aoffset2 = aoffset1 + lda;
2031
- aoffset3 = aoffset2 + lda;
2032
- aoffset4 = aoffset3 + lda;
2033
- aoffset5 = aoffset4 + lda;
2034
- aoffset6 = aoffset5 + lda;
2035
- aoffset7 = aoffset6 + lda;
2036
- aoffset8 = aoffset7 + lda;
2037
  aoffset += 8 * lda;
2038
 
2039
  i = (cols >> 3);
2040
  if (i > 0) {
2041
  do {
2042
- C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
2043
- C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
2044
- C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
2045
- C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4->qs);
2046
- C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5->qs);
2047
- C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6->qs);
2048
- C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7->qs);
2049
- C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8->qs);
2050
-
2051
- __builtin_vsx_disassemble_pair(c1, &C1);
2052
- __builtin_vsx_disassemble_pair(c2, &C2);
2053
- __builtin_vsx_disassemble_pair(c3, &C3);
2054
- __builtin_vsx_disassemble_pair(c4, &C4);
2055
- __builtin_vsx_disassemble_pair(c5, &C5);
2056
- __builtin_vsx_disassemble_pair(c6, &C6);
2057
- __builtin_vsx_disassemble_pair(c7, &C7);
2058
- __builtin_vsx_disassemble_pair(c8, &C8);
2059
-
2060
- t1 = vec_perm(c1[0], c2[0], swiz1);
2061
- t2 = vec_perm(c1[0], c2[0], swiz2);
2062
- t3 = vec_perm(c3[0], c4[0], swiz1);
2063
- t4 = vec_perm(c3[0], c4[0], swiz2);
2064
- t5 = vec_perm(t1, t3, swiz3);
2065
- t6 = vec_perm(t1, t3, swiz4);
2066
- t7 = vec_perm(t2, t4, swiz3);
2067
- t8 = vec_perm(t2, t4, swiz4);
2068
- if (flip == true) {
2069
- t5 = vec_xor(t5, xor_vector);
2070
- t6 = vec_xor(t6, xor_vector);
2071
- t7 = vec_xor(t7, xor_vector);
2072
- t8 = vec_xor(t8, xor_vector);
2073
- }
2074
- vec_xst(t5, 0, vecOffset);
2075
- vec_xst(t6, 0, vecOffset+16);
2076
- vec_xst(t7, 0, vecOffset+32);
2077
- vec_xst(t8, 0, vecOffset+48);
2078
-
2079
- t1 = vec_perm(c1[1], c2[1], swiz1);
2080
- t2 = vec_perm(c1[1], c2[1], swiz2);
2081
- t3 = vec_perm(c3[1], c4[1], swiz1);
2082
- t4 = vec_perm(c3[1], c4[1], swiz2);
2083
- t5 = vec_perm(t1, t3, swiz3);
2084
- t6 = vec_perm(t1, t3, swiz4);
2085
- t7 = vec_perm(t2, t4, swiz3);
2086
- t8 = vec_perm(t2, t4, swiz4);
2087
- if (flip == true) {
2088
- t5 = vec_xor(t5, xor_vector);
2089
- t6 = vec_xor(t6, xor_vector);
2090
- t7 = vec_xor(t7, xor_vector);
2091
- t8 = vec_xor(t8, xor_vector);
2092
- }
2093
- vec_xst(t5, 0, vecOffset+64);
2094
- vec_xst(t6, 0, vecOffset+80);
2095
- vec_xst(t7, 0, vecOffset+96);
2096
- vec_xst(t8, 0, vecOffset+112);
2097
-
2098
- t1 = vec_perm(c5[0], c6[0], swiz1);
2099
- t2 = vec_perm(c5[0], c6[0], swiz2);
2100
- t3 = vec_perm(c7[0], c8[0], swiz1);
2101
- t4 = vec_perm(c7[0], c8[0], swiz2);
2102
- t5 = vec_perm(t1, t3, swiz3);
2103
- t6 = vec_perm(t1, t3, swiz4);
2104
- t7 = vec_perm(t2, t4, swiz3);
2105
- t8 = vec_perm(t2, t4, swiz4);
2106
- if (flip == true) {
2107
- t5 = vec_xor(t5, xor_vector);
2108
- t6 = vec_xor(t6, xor_vector);
2109
- t7 = vec_xor(t7, xor_vector);
2110
- t8 = vec_xor(t8, xor_vector);
2111
- }
2112
- vec_xst(t5, 0, vecOffset+128);
2113
- vec_xst(t6, 0, vecOffset+144);
2114
- vec_xst(t7, 0, vecOffset+160);
2115
- vec_xst(t8, 0, vecOffset+176);
2116
-
2117
- t1 = vec_perm(c5[1], c6[1], swiz1);
2118
- t2 = vec_perm(c5[1], c6[1], swiz2);
2119
- t3 = vec_perm(c7[1], c8[1], swiz1);
2120
- t4 = vec_perm(c7[1], c8[1], swiz2);
2121
- t5 = vec_perm(t1, t3, swiz3);
2122
- t6 = vec_perm(t1, t3, swiz4);
2123
- t7 = vec_perm(t2, t4, swiz3);
2124
- t8 = vec_perm(t2, t4, swiz4);
2125
- if (flip == true) {
2126
- t5 = vec_xor(t5, xor_vector);
2127
- t6 = vec_xor(t6, xor_vector);
2128
- t7 = vec_xor(t7, xor_vector);
2129
- t8 = vec_xor(t8, xor_vector);
2130
  }
2131
- vec_xst(t5, 0, vecOffset+192);
2132
- vec_xst(t6, 0, vecOffset+208);
2133
- vec_xst(t7, 0, vecOffset+224);
2134
- vec_xst(t8, 0, vecOffset+240);
2135
-
2136
- aoffset1 += lda;
2137
- aoffset2 += lda;
2138
- aoffset3 += lda;
2139
- aoffset4 += lda;
2140
- aoffset5 += lda;
2141
- aoffset6 += lda;
2142
- aoffset7 += lda;
2143
- aoffset8 += lda;
2144
  vecOffset += 256;
2145
  i--;
2146
  } while(i > 0);
@@ -2150,129 +1824,53 @@ class tinyBLAS_Q0_PPC {
2150
  }
2151
 
2152
  if (rows & 4) {
2153
- aoffset1 = aoffset;
2154
- aoffset2 = aoffset1 + lda;
2155
- aoffset3 = aoffset2 + lda;
2156
- aoffset4 = aoffset3 + lda;
2157
- aoffset += 4 * lda;
2158
-
2159
  i = (cols >> 3);
2160
  if (i > 0) {
2161
  do {
2162
- C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
2163
- C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
2164
- C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
2165
- C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4->qs);
2166
-
2167
- __builtin_vsx_disassemble_pair(c1, &C1);
2168
- __builtin_vsx_disassemble_pair(c2, &C2);
2169
- __builtin_vsx_disassemble_pair(c3, &C3);
2170
- __builtin_vsx_disassemble_pair(c4, &C4);
2171
-
2172
- t1 = vec_perm(c1[0], c2[0], swiz1);
2173
- t2 = vec_perm(c1[0], c2[0], swiz2);
2174
- t3 = vec_perm(c3[0], c4[0], swiz1);
2175
- t4 = vec_perm(c3[0], c4[0], swiz2);
2176
- t5 = vec_perm(t1, t3, swiz3);
2177
- t6 = vec_perm(t1, t3, swiz4);
2178
- t7 = vec_perm(t2, t4, swiz3);
2179
- t8 = vec_perm(t2, t4, swiz4);
2180
- if (flip == true) {
2181
- t5 = vec_xor(t5, xor_vector);
2182
- t6 = vec_xor(t6, xor_vector);
2183
- t7 = vec_xor(t7, xor_vector);
2184
- t8 = vec_xor(t8, xor_vector);
2185
  }
2186
- vec_xst(t5, 0, vecOffset);
2187
- vec_xst(t6, 0, vecOffset+16);
2188
- vec_xst(t7, 0, vecOffset+32);
2189
- vec_xst(t8, 0, vecOffset+48);
2190
-
2191
- t1 = vec_perm(c1[1], c2[1], swiz1);
2192
- t2 = vec_perm(c1[1], c2[1], swiz2);
2193
- t3 = vec_perm(c3[1], c4[1], swiz1);
2194
- t4 = vec_perm(c3[1], c4[1], swiz2);
2195
- t5 = vec_perm(t1, t3, swiz3);
2196
- t6 = vec_perm(t1, t3, swiz4);
2197
- t7 = vec_perm(t2, t4, swiz3);
2198
- t8 = vec_perm(t2, t4, swiz4);
2199
- if (flip == true) {
2200
- t5 = vec_xor(t5, xor_vector);
2201
- t6 = vec_xor(t6, xor_vector);
2202
- t7 = vec_xor(t7, xor_vector);
2203
- t8 = vec_xor(t8, xor_vector);
2204
  }
2205
- vec_xst(t5, 0, vecOffset+64);
2206
- vec_xst(t6, 0, vecOffset+80);
2207
- vec_xst(t7, 0, vecOffset+96);
2208
- vec_xst(t8, 0, vecOffset+112);
2209
-
2210
- aoffset1 += lda;
2211
- aoffset2 += lda;
2212
- aoffset3 += lda;
2213
- aoffset4 += lda;
2214
  vecOffset += 128;
2215
  i--;
2216
  } while(i > 0);
2217
  }
2218
  }
 
2219
  if (rows & 3) {
2220
- aoffset1 = aoffset;
2221
- aoffset2 = aoffset1 + lda;
2222
- aoffset3 = aoffset2 + lda;
2223
  i = (cols >> 3);
2224
  if (i > 0) {
2225
  do {
2226
  switch(rows) {
2227
- case 3: C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
2228
- __builtin_vsx_disassemble_pair(c3, &C3);
2229
- case 2: C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
2230
- __builtin_vsx_disassemble_pair(c2, &C2);
2231
- case 1: C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
2232
- __builtin_vsx_disassemble_pair(c1, &C1);
2233
  break;
2234
  }
2235
- t1 = vec_perm(c1[0], c2[0], swiz1);
2236
- t2 = vec_perm(c1[0], c2[0], swiz2);
2237
- t3 = vec_perm(c3[0], c4[0], swiz1);
2238
- t4 = vec_perm(c3[0], c4[0], swiz2);
2239
- t5 = vec_perm(t1, t3, swiz3);
2240
- t6 = vec_perm(t1, t3, swiz4);
2241
- t7 = vec_perm(t2, t4, swiz3);
2242
- t8 = vec_perm(t2, t4, swiz4);
2243
- if (flip == true) {
2244
- t5 = vec_xor(t5, xor_vector);
2245
- t6 = vec_xor(t6, xor_vector);
2246
- t7 = vec_xor(t7, xor_vector);
2247
- t8 = vec_xor(t8, xor_vector);
2248
- }
2249
- vec_xst(t5, 0, vecOffset);
2250
- vec_xst(t6, 0, vecOffset+16);
2251
- vec_xst(t7, 0, vecOffset+32);
2252
- vec_xst(t8, 0, vecOffset+48);
2253
-
2254
- t1 = vec_perm(c1[1], c2[1], swiz1);
2255
- t2 = vec_perm(c1[1], c2[1], swiz2);
2256
- t3 = vec_perm(c3[1], c4[1], swiz1);
2257
- t4 = vec_perm(c3[1], c4[1], swiz2);
2258
- t5 = vec_perm(t1, t3, swiz3);
2259
- t6 = vec_perm(t1, t3, swiz4);
2260
- t7 = vec_perm(t2, t4, swiz3);
2261
- t8 = vec_perm(t2, t4, swiz4);
2262
- if (flip == true) {
2263
- t5 = vec_xor(t5, xor_vector);
2264
- t6 = vec_xor(t6, xor_vector);
2265
- t7 = vec_xor(t7, xor_vector);
2266
- t8 = vec_xor(t8, xor_vector);
2267
- }
2268
- vec_xst(t5, 0, vecOffset+64);
2269
- vec_xst(t6, 0, vecOffset+80);
2270
- vec_xst(t7, 0, vecOffset+96);
2271
- vec_xst(t8, 0, vecOffset+112);
2272
-
2273
- aoffset1 += lda;
2274
- aoffset2 += lda;
2275
- aoffset3 += lda;
2276
  vecOffset += 128;
2277
  i--;
2278
  } while(i > 0);
@@ -2281,159 +1879,42 @@ class tinyBLAS_Q0_PPC {
2281
  }
2282
 
2283
  void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
2284
- int64_t mc, nc, mp, np;
2285
- int m_rem = MIN(m - m0, 8);
2286
- int n_rem = MIN(n - n0, 8);
2287
- // TO-DO: KERNEL_16x8 and KERNEL_8x16 are having some performance
2288
- // issues. After resolving them, below code will be enabled.
2289
- /*if (m_rem >= 16 && n_rem >= 8) {
2290
- mc = 16;
2291
- nc = 8;
2292
- gemm<16,8>(m0, m, n0, n);
2293
- } else if(m_rem >= 8 && n_rem >= 16) {
2294
- mc = 8;
2295
- nc = 16;
2296
- gemm<8,16>(m0, m, n0, n);
2297
- }*/
2298
  if (m_rem >= 8 && n_rem >= 8) {
2299
- mc = 8;
2300
- nc = 8;
2301
- gemm<8,8>(m0, m, n0, n);
2302
  } else if (m_rem >= 4 && n_rem >= 8) {
2303
  mc = 4;
2304
  nc = 8;
2305
- gemm<4,8>(m0, m, n0, n);
2306
  } else if (m_rem >= 8 && n_rem >= 4) {
2307
  mc = 8;
2308
  nc = 4;
2309
- gemm<8,4>(m0, m, n0, n);
2310
  } else if (m_rem >= 4 && n_rem >= 4) {
2311
  mc = 4;
2312
  nc = 4;
2313
- gemm_small<4, 4>(m0, m, n0, n);
2314
- } else if ((m_rem < 4) && (n_rem > 4)) {
2315
- nc = 4;
2316
- switch(m_rem) {
2317
- case 1:
2318
- mc = 1;
2319
- gemm_small<1, 4>(m0, m, n0, n);
2320
- break;
2321
- case 2:
2322
- mc = 2;
2323
- gemm_small<2, 4>(m0, m, n0, n);
2324
- break;
2325
- case 3:
2326
- mc = 3;
2327
- gemm_small<3, 4>(m0, m, n0, n);
2328
- break;
2329
- default:
2330
- return;
2331
- }
2332
- } else if ((m_rem > 4) && (n_rem < 4)) {
2333
- mc = 4;
2334
- switch(n_rem) {
2335
- case 1:
2336
- nc = 1;
2337
- gemm_small<4, 1>(m0, m, n0, n);
2338
- break;
2339
- case 2:
2340
- nc = 2;
2341
- gemm_small<4, 2>(m0, m, n0, n);
2342
- break;
2343
- case 3:
2344
- nc = 3;
2345
- gemm_small<4, 3>(m0, m, n0, n);
2346
- break;
2347
- default:
2348
- return;
2349
- }
2350
  } else {
2351
- switch((m_rem << 4) | n_rem) {
2352
- case 0x43:
2353
- mc = 4;
2354
- nc = 3;
2355
- gemm_small<4, 3>(m0, m, n0, n);
2356
- break;
2357
- case 0x42:
2358
- mc = 4;
2359
- nc = 2;
2360
- gemm_small<4, 2>(m0, m, n0, n);
2361
- break;
2362
- case 0x41:
2363
- mc = 4;
2364
- nc = 1;
2365
- gemm_small<4, 1>(m0, m, n0, n);
2366
- break;
2367
- case 0x34:
2368
- mc = 3;
2369
- nc = 4;
2370
- gemm_small<3, 4>(m0, m, n0, n);
2371
- break;
2372
- case 0x33:
2373
- mc = 3;
2374
- nc = 3;
2375
- gemm_small<3, 3>(m0, m, n0, n);
2376
- break;
2377
- case 0x32:
2378
- mc = 3;
2379
- nc = 2;
2380
- gemm_small<3, 2>(m0, m, n0, n);
2381
- break;
2382
- case 0x31:
2383
- mc = 3;
2384
- nc = 1;
2385
- gemm_small<3, 1>(m0, m, n0, n);
2386
- break;
2387
- case 0x24:
2388
- mc = 2;
2389
- nc = 4;
2390
- gemm_small<2, 4>(m0, m, n0, n);
2391
- break;
2392
- case 0x23:
2393
- mc = 2;
2394
- nc = 3;
2395
- gemm_small<2, 3>(m0, m, n0, n);
2396
- break;
2397
- case 0x22:
2398
- mc = 2;
2399
- nc = 2;
2400
- gemm_small<2, 2>(m0, m, n0, n);
2401
- break;
2402
- case 0x21:
2403
- mc = 2;
2404
- nc = 1;
2405
- gemm_small<2, 1>(m0, m, n0, n);
2406
- break;
2407
- case 0x14:
2408
- mc = 1;
2409
- nc = 4;
2410
- gemm_small<1, 4>(m0, m, n0, n);
2411
- break;
2412
- case 0x13:
2413
- mc = 1;
2414
- nc = 3;
2415
- gemm_small<1, 3>(m0, m, n0, n);
2416
- break;
2417
- case 0x12:
2418
- mc = 1;
2419
- nc = 2;
2420
- gemm_small<1, 2>(m0, m, n0, n);
2421
- break;
2422
- case 0x11:
2423
- mc = 1;
2424
- nc = 1;
2425
- gemm_small<1, 1>(m0, m, n0, n);
2426
- break;
2427
- default:
2428
- return;
2429
- }
2430
  }
2431
- mp = m0 + (m - m0) / mc * mc;
2432
- np = n0 + (n - n0) / nc * nc;
 
2433
  mnpack(mp, m, n0, np);
2434
  mnpack(m0, m, np, n);
2435
  }
2436
 
 
2437
  void KERNEL_4x8(int64_t ii, int64_t jj) {
2438
  vec_t vec_A[8], vec_B[16] = {0};
2439
  acc_t acc_0, acc_1;
@@ -2445,9 +1926,9 @@ class tinyBLAS_Q0_PPC {
2445
  __builtin_mma_xxsetaccz(&acc_0);
2446
  __builtin_mma_xxsetaccz(&acc_1);
2447
  if (std::is_same_v<TA, block_q4_0>) {
2448
- packNormalInt4<int8_t, vector signed char, 4>((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray);
2449
  } else {
2450
- packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
2451
  }
2452
  packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
2453
  for(int x = 0; x < 8; x++) {
@@ -2475,8 +1956,8 @@ class tinyBLAS_Q0_PPC {
2475
  compute<4>(&acc_0, 0, 0, comparray, vs, fin_res);
2476
  compute<4>(&acc_1, 0, 4, comparray, vs, fin_res);
2477
  }
2478
- save_res<4, 4>(ii, jj, 0, fin_res);
2479
- save_res<4, 4>(ii, jj+4, 4, fin_res);
2480
  }
2481
 
2482
  void KERNEL_8x4(int64_t ii, int64_t jj) {
@@ -2490,9 +1971,9 @@ class tinyBLAS_Q0_PPC {
2490
  __builtin_mma_xxsetaccz(&acc_0);
2491
  __builtin_mma_xxsetaccz(&acc_1);
2492
  if (std::is_same_v<TA, block_q4_0>) {
2493
- packNormalInt4<int8_t, vector signed char, 8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
2494
  } else {
2495
- packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
2496
  }
2497
  packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true);
2498
  for(int x = 0; x < 8; x++) {
@@ -2519,8 +2000,8 @@ class tinyBLAS_Q0_PPC {
2519
  compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
2520
  compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
2521
  }
2522
- save_res<4, 4>(ii, jj, 0, fin_res);
2523
- save_res<4, 4>(ii+4, jj, 4, fin_res);
2524
  }
2525
 
2526
  void KERNEL_8x8(int64_t ii, int64_t jj) {
@@ -2536,9 +2017,9 @@ class tinyBLAS_Q0_PPC {
2536
  __builtin_mma_xxsetaccz(&acc_2);
2537
  __builtin_mma_xxsetaccz(&acc_3);
2538
  if (std::is_same_v<TA, block_q4_0>) {
2539
- packNormalInt4<int8_t, vector signed char, 8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
2540
  } else {
2541
- packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
2542
  }
2543
  packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
2544
  for(int x = 0; x < 8; x++) {
@@ -2570,14 +2051,13 @@ class tinyBLAS_Q0_PPC {
2570
  compute<8>(&acc_2, 0, 8, comparray, vs, fin_res);
2571
  compute<8>(&acc_3, 4, 12, comparray, vs, fin_res);
2572
  }
2573
- save_res<4, 4>(ii, jj, 0, fin_res);
2574
- save_res<4, 4>(ii+4, jj, 4, fin_res);
2575
- save_res<4, 4>(ii, jj+4, 8, fin_res);
2576
- save_res<4, 4>(ii+4, jj+4, 12, fin_res);
2577
  }
2578
 
2579
- template<int RM, int RN>
2580
- void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n) {
2581
  int64_t ytiles = (m - m0) / RM;
2582
  int64_t xtiles = (n - n0) / RN;
2583
  int64_t tiles = xtiles * ytiles;
@@ -2606,9 +2086,9 @@ class tinyBLAS_Q0_PPC {
2606
  __builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead
2607
  __builtin_mma_xxsetaccz(&acc_0);
2608
  if (isAblock_q4) {
2609
- packNormalInt4<int8_t, vector signed char, 4>((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray);
2610
  } else {
2611
- packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
2612
  }
2613
  packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true);
2614
  for(int x = 0; x < 8; x+=4) {
@@ -2641,7 +2121,7 @@ class tinyBLAS_Q0_PPC {
2641
  fin_res[i] = vec_madd(res[i], vs[i], fin_res[i]);
2642
  }
2643
  }
2644
- save_res<RM, RN>(ii, jj, 0, fin_res);
2645
  }
2646
  }
2647
 
@@ -2654,7 +2134,7 @@ class tinyBLAS_Q0_PPC {
2654
  } else if constexpr(RM == 8 && RN == 8) {
2655
  KERNEL_8x8(ii,jj);
2656
  } else {
2657
- static_assert(false, "RN/RM values not supported");
2658
  }
2659
  }
2660
 
@@ -2676,10 +2156,8 @@ class tinyBLAS_Q0_PPC {
2676
  }
2677
 
2678
  const TA *const A;
2679
- const TB *const B;
2680
- TC *C;
2681
- TA *At;
2682
- TB *Bt;
2683
  const int64_t k;
2684
  const int64_t lda;
2685
  const int64_t ldb;
@@ -2688,13 +2166,12 @@ class tinyBLAS_Q0_PPC {
2688
  const int nth;
2689
  };
2690
 
2691
- template <typename TA, typename TB, typename TC>
2692
  class tinyBLAS_PPC {
2693
  public:
2694
  tinyBLAS_PPC(int64_t k,
2695
- const TA *A, int64_t lda,
2696
- const TB *B, int64_t ldb,
2697
- TC *C, int64_t ldc,
2698
  int ith, int nth)
2699
  : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
2700
  }
@@ -2707,247 +2184,139 @@ class tinyBLAS_PPC {
2707
 
2708
  void (tinyBLAS_PPC::*kernel)(int64_t, int64_t);
2709
 
2710
- template<typename VA>
2711
- void packTranspose(const TA* a, int64_t lda, int rows, int cols, TA* vec) {
2712
  int64_t i, j;
2713
- TA *aoffset = NULL, *boffset = NULL;
2714
- TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
2715
- TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
2716
- __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
2717
- VA c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
2718
- VA c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
2719
- VA t1, t2, t3, t4, t5, t6, t7, t8;
2720
- aoffset = const_cast<TA*>(a);
2721
  boffset = vec;
2722
  j = (rows >> 3);
2723
  if (j > 0) {
2724
 
2725
  do {
2726
- aoffset1 = aoffset;
2727
- aoffset2 = aoffset1 + lda;
2728
- aoffset3 = aoffset2 + lda;
2729
- aoffset4 = aoffset3 + lda;
2730
- aoffset5 = aoffset4 + lda;
2731
- aoffset6 = aoffset5 + lda;
2732
- aoffset7 = aoffset6 + lda;
2733
- aoffset8 = aoffset7 + lda;
2734
  aoffset += 8 * lda;
2735
  i = (cols >> 3);
2736
  if (i > 0) {
2737
  do {
2738
- C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
2739
- C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
2740
- C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3);
2741
- C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4);
2742
- C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5);
2743
- C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6);
2744
- C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7);
2745
- C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8);
2746
- __builtin_vsx_disassemble_pair(c1, &C1);
2747
- __builtin_vsx_disassemble_pair(c2, &C2);
2748
- __builtin_vsx_disassemble_pair(c3, &C3);
2749
- __builtin_vsx_disassemble_pair(c4, &C4);
2750
- __builtin_vsx_disassemble_pair(c5, &C5);
2751
- __builtin_vsx_disassemble_pair(c6, &C6);
2752
- __builtin_vsx_disassemble_pair(c7, &C7);
2753
- __builtin_vsx_disassemble_pair(c8, &C8);
2754
-
2755
- t1 = vec_mergeh(c1[0], c2[0]);
2756
- t2 = vec_mergeh(c3[0], c4[0]);
2757
- t3 = vec_mergeh(c5[0], c6[0]);
2758
- t4 = vec_mergeh(c7[0], c8[0]);
2759
- t5 = vec_xxpermdi(t1, t2, 0);
2760
- t6 = vec_xxpermdi(t3, t4, 0);
2761
- t7 = vec_xxpermdi(t1, t2, 3);
2762
- t8 = vec_xxpermdi(t3, t4, 3);
2763
- vec_xst(t5, 0, boffset);
2764
- vec_xst(t6, 0, boffset+4);
2765
- vec_xst(t7, 0, boffset+8);
2766
- vec_xst(t8, 0, boffset+12);
2767
-
2768
- t1 = vec_mergel(c1[0], c2[0]);
2769
- t2 = vec_mergel(c3[0], c4[0]);
2770
- t3 = vec_mergel(c5[0], c6[0]);
2771
- t4 = vec_mergel(c7[0], c8[0]);
2772
- t5 = vec_xxpermdi(t1, t2, 0);
2773
- t6 = vec_xxpermdi(t3, t4, 0);
2774
- t7 = vec_xxpermdi(t1, t2, 3);
2775
- t8 = vec_xxpermdi(t3, t4, 3);
2776
- vec_xst(t5, 0, boffset+16);
2777
- vec_xst(t6, 0, boffset+20);
2778
- vec_xst(t7, 0, boffset+24);
2779
- vec_xst(t8, 0, boffset+28);
2780
-
2781
- t1 = vec_mergeh(c1[1], c2[1]);
2782
- t2 = vec_mergeh(c3[1], c4[1]);
2783
- t3 = vec_mergeh(c5[1], c6[1]);
2784
- t4 = vec_mergeh(c7[1], c8[1]);
2785
- t5 = vec_xxpermdi(t1, t2, 0);
2786
- t6 = vec_xxpermdi(t3, t4, 0);
2787
- t7 = vec_xxpermdi(t1, t2, 3);
2788
- t8 = vec_xxpermdi(t3, t4, 3);
2789
- vec_xst(t5, 0, boffset+32);
2790
- vec_xst(t6, 0, boffset+36);
2791
- vec_xst(t7, 0, boffset+40);
2792
- vec_xst(t8, 0, boffset+44);
2793
-
2794
- t1 = vec_mergel(c1[1], c2[1]);
2795
- t2 = vec_mergel(c3[1], c4[1]);
2796
- t3 = vec_mergel(c5[1], c6[1]);
2797
- t4 = vec_mergel(c7[1], c8[1]);
2798
- t5 = vec_xxpermdi(t1, t2, 0);
2799
- t6 = vec_xxpermdi(t3, t4, 0);
2800
- t7 = vec_xxpermdi(t1, t2, 3);
2801
- t8 = vec_xxpermdi(t3, t4, 3);
2802
- vec_xst(t5, 0, boffset+48);
2803
- vec_xst(t6, 0, boffset+52);
2804
- vec_xst(t7, 0, boffset+56);
2805
- vec_xst(t8, 0, boffset+60);
2806
-
2807
- aoffset1 += 8*lda;
2808
- aoffset2 += 8*lda;
2809
- aoffset3 += 8*lda;
2810
- aoffset4 += 8*lda;
2811
  boffset += 64;
2812
  i--;
2813
  } while(i > 0);
2814
  }
2815
  if (cols & 4) {
2816
- c1[0] = vec_xl(0, aoffset1);
2817
- c2[0] = vec_xl(0, aoffset2);
2818
- c3[0] = vec_xl(0, aoffset3);
2819
- c4[0] = vec_xl(0, aoffset4);
2820
- c5[0] = vec_xl(0, aoffset5);
2821
- c6[0] = vec_xl(0, aoffset6);
2822
- c7[0] = vec_xl(0, aoffset7);
2823
- c8[0] = vec_xl(0, aoffset8);
2824
-
2825
- t1 = vec_mergeh(c1[0], c2[0]);
2826
- t2 = vec_mergeh(c3[0], c4[0]);
2827
- t3 = vec_mergeh(c5[0], c6[0]);
2828
- t4 = vec_mergeh(c7[0], c8[0]);
2829
- t5 = vec_xxpermdi(t1, t2, 0);
2830
- t6 = vec_xxpermdi(t3, t4, 0);
2831
- t7 = vec_xxpermdi(t1, t2, 3);
2832
- t8 = vec_xxpermdi(t3, t4, 3);
2833
- vec_xst(t5, 0, boffset);
2834
- vec_xst(t6, 0, boffset+4);
2835
- vec_xst(t7, 0, boffset+8);
2836
- vec_xst(t8, 0, boffset+12);
2837
-
2838
- t1 = vec_mergel(c1[0], c2[0]);
2839
- t2 = vec_mergel(c3[0], c4[0]);
2840
- t3 = vec_mergel(c5[0], c6[0]);
2841
- t4 = vec_mergel(c7[0], c8[0]);
2842
- t5 = vec_xxpermdi(t1, t2, 0);
2843
- t6 = vec_xxpermdi(t3, t4, 0);
2844
- t7 = vec_xxpermdi(t1, t2, 3);
2845
- t8 = vec_xxpermdi(t3, t4, 3);
2846
- vec_xst(t5, 0, boffset+16);
2847
- vec_xst(t6, 0, boffset+20);
2848
- vec_xst(t7, 0, boffset+24);
2849
- vec_xst(t8, 0, boffset+28);
2850
  }
2851
  j--;
2852
  } while(j > 0);
2853
  }
2854
 
2855
  if (rows & 4) {
2856
- aoffset1 = aoffset;
2857
- aoffset2 = aoffset1 + lda;
2858
- aoffset3 = aoffset2 + lda;
2859
- aoffset4 = aoffset3 + lda;
2860
  aoffset += 4 * lda;
2861
  i = (cols >> 3);
2862
  if (i > 0) {
2863
  do {
2864
- C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
2865
- C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
2866
- C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3);
2867
- C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4);
2868
- __builtin_vsx_disassemble_pair(c1, &C1);
2869
- __builtin_vsx_disassemble_pair(c2, &C2);
2870
- __builtin_vsx_disassemble_pair(c3, &C3);
2871
- __builtin_vsx_disassemble_pair(c4, &C4);
2872
-
2873
- t1 = vec_mergeh(c1[0], c2[0]);
2874
- t2 = vec_mergeh(c3[0], c4[0]);
2875
- t3 = vec_mergel(c1[0], c2[0]);
2876
- t4 = vec_mergel(c3[0], c4[0]);
2877
- t5 = vec_xxpermdi(t1, t2, 0);
2878
- t6 = vec_xxpermdi(t1, t2, 3);
2879
- t7 = vec_xxpermdi(t3, t4, 0);
2880
- t8 = vec_xxpermdi(t3, t4, 3);
2881
- vec_xst(t5, 0, boffset);
2882
- vec_xst(t6, 0, boffset+4);
2883
- vec_xst(t7, 0, boffset+8);
2884
- vec_xst(t8, 0, boffset+12);
2885
-
2886
- t1 = vec_mergeh(c1[1], c2[1]);
2887
- t2 = vec_mergeh(c3[1], c4[1]);
2888
- t3 = vec_mergel(c1[1], c2[1]);
2889
- t4 = vec_mergel(c3[1], c4[1]);
2890
- t5 = vec_xxpermdi(t1, t2, 0);
2891
- t6 = vec_xxpermdi(t1, t2, 3);
2892
- t7 = vec_xxpermdi(t3, t4, 0);
2893
- t8 = vec_xxpermdi(t3, t4, 3);
2894
- vec_xst(t5, 0, boffset+16);
2895
- vec_xst(t6, 0, boffset+20);
2896
- vec_xst(t7, 0, boffset+24);
2897
- vec_xst(t8, 0, boffset+28);
2898
-
2899
- aoffset1 += 8*lda;
2900
- aoffset2 += 8*lda;
2901
- aoffset3 += 8*lda;
2902
- aoffset4 += 8*lda;
2903
  boffset += 32;
2904
  i--;
2905
  } while(i > 0);
2906
  }
2907
 
2908
  if (cols & 4) {
2909
- c1[0] = vec_xl(0, aoffset1);
2910
- c2[0] = vec_xl(0, aoffset2);
2911
- c3[0] = vec_xl(0, aoffset3);
2912
- c4[0] = vec_xl(0, aoffset4);
2913
-
2914
- t1 = vec_mergeh(c1[0], c2[0]);
2915
- t2 = vec_mergeh(c3[0], c4[0]);
2916
- t3 = vec_xxpermdi(t1, t2, 0);
2917
- t4 = vec_xxpermdi(t1, t2, 3);
2918
- vec_xst(t3, 0, boffset);
2919
- vec_xst(t4, 0, boffset+4);
2920
-
2921
- t1 = vec_mergel(c1[0], c2[0]);
2922
- t2 = vec_mergel(c3[0], c4[0]);
2923
- t3 = vec_xxpermdi(t1, t2, 0);
2924
- t4 = vec_xxpermdi(t1, t2, 3);
2925
- vec_xst(t3, 0, boffset+8);
2926
- vec_xst(t4, 0, boffset+12);
2927
  }
2928
  }
2929
  if (rows & 3) {
2930
- aoffset1 = aoffset;
2931
- aoffset2 = aoffset1 + lda;
2932
- aoffset3 = aoffset2 + lda;
2933
  if (cols & 4) {
2934
- c1[0] = vec_xl(0, aoffset1);
2935
- c2[0] = vec_xl(0, aoffset2);
2936
- c3[0] = vec_xl(0, aoffset3);
2937
-
2938
- t1 = vec_mergeh(c1[0], c2[0]);
2939
- t2 = vec_mergeh(c3[0], c4[0]);
2940
- t3 = vec_xxpermdi(t1, t2, 0);
2941
- t4 = vec_xxpermdi(t1, t2, 3);
2942
- vec_xst(t3, 0, boffset);
2943
- vec_xst(t4, 0, boffset+4);
2944
-
2945
- t1 = vec_mergel(c1[0], c2[0]);
2946
- t2 = vec_mergel(c3[0], c4[0]);
2947
- t3 = vec_xxpermdi(t1, t2, 0);
2948
- t4 = vec_xxpermdi(t1, t2, 3);
2949
- vec_xst(t3, 0, boffset+8);
2950
- vec_xst(t4, 0, boffset+12);
2951
  }
2952
  }
2953
  }
@@ -2957,8 +2326,8 @@ class tinyBLAS_PPC {
2957
  acc_t acc_0;
2958
  __builtin_mma_xxsetaccz(&acc_0);
2959
  for (int l = 0; l < k; l+=4) {
2960
- packTranspose<vector float>(A+(ii*lda)+l, lda, 4, 4, (TA*)vec_A);
2961
- packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
2962
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
2963
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
2964
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
@@ -2973,8 +2342,8 @@ class tinyBLAS_PPC {
2973
  __builtin_mma_xxsetaccz(&acc_0);
2974
  __builtin_mma_xxsetaccz(&acc_1);
2975
  for (int64_t l = 0; l < k; l+=4) {
2976
- packTranspose<vector float>(A+(ii*lda)+l, lda, 4, 4, (TA*)vec_A);
2977
- packTranspose<vector float>(B+(jj*ldb)+l, ldb, 8, 4, (TA*)vec_B);
2978
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]);
2979
  __builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]);
2980
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]);
@@ -2994,8 +2363,8 @@ class tinyBLAS_PPC {
2994
  __builtin_mma_xxsetaccz(&acc_0);
2995
  __builtin_mma_xxsetaccz(&acc_1);
2996
  for (int64_t l = 0; l < k; l+=4) {
2997
- packTranspose<vector float>(A+(ii*lda)+l, lda, 8, 4, (TA*)vec_A);
2998
- packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
2999
  __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]);
3000
  __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]);
3001
  __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]);
@@ -3017,8 +2386,8 @@ class tinyBLAS_PPC {
3017
  __builtin_mma_xxsetaccz(&acc_2);
3018
  __builtin_mma_xxsetaccz(&acc_3);
3019
  for (int l = 0; l < k; l+=8) {
3020
- packTranspose<vector float>(A+(ii*lda)+l, lda, 8, 8, (TA*)vec_A);
3021
- packTranspose<vector float>(B+(jj*ldb)+l, ldb, 8, 8, (TA*)vec_B);
3022
  for(int x = 0; x < 16; x+=2) {
3023
  __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]);
3024
  __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x+1]);
@@ -3033,155 +2402,37 @@ class tinyBLAS_PPC {
3033
  }
3034
 
3035
  void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
3036
- int64_t mc, nc, mp, np;
3037
- int m_rem = MIN(m - m0, 16);
3038
- int n_rem = MIN(n - n0, 16);
3039
- if (m_rem >= 16 && n_rem >= 8) {
3040
- mc = 8;
3041
- nc = 8;
3042
- gemm<8,8>(m0, m, n0, n);
3043
- } else if(m_rem >= 8 && n_rem >= 16) {
3044
- mc = 8;
3045
- nc = 8;
3046
- gemm<8,8>(m0, m, n0, n);
3047
- } else if (m_rem >= 8 && n_rem >= 8) {
3048
- mc = 8;
3049
- nc = 8;
3050
- gemm<8,8>(m0, m, n0, n);
3051
  } else if (m_rem >= 4 && n_rem >= 8) {
3052
- mc = 4;
3053
- nc = 8;
3054
- gemm<4,8>(m0, m, n0, n);
3055
  } else if (m_rem >= 8 && n_rem >= 4) {
3056
- mc = 8;
3057
- nc = 4;
3058
- gemm<8,4>(m0, m, n0, n);
3059
  } else if (m_rem >= 4 && n_rem >= 4) {
3060
- mc = 4;
3061
- nc = 4;
3062
- gemm<4,4>(m0, m, n0, n);
3063
- } else if ((m_rem < 4) && (n_rem > 4)) {
3064
- nc = 4;
3065
- switch(m_rem) {
3066
- case 1:
3067
- mc = 1;
3068
- gemm_small(m0, m, n0, n, mc, nc);
3069
- break;
3070
- case 2:
3071
- mc = 2;
3072
- gemm_small(m0, m, n0, n, mc, nc);
3073
- break;
3074
- case 3:
3075
- mc = 3;
3076
- gemm_small(m0, m, n0, n, mc, nc);
3077
- break;
3078
- default:
3079
- return;
3080
- }
3081
- } else if ((m_rem > 4) && (n_rem < 4)) {
3082
- mc = 4;
3083
- switch(n_rem) {
3084
- case 1:
3085
- nc = 1;
3086
- gemm_small(m0, m, n0, n, mc, nc);
3087
- break;
3088
- case 2:
3089
- nc = 2;
3090
- gemm_small(m0, m, n0, n, mc, nc);
3091
- break;
3092
- case 3:
3093
- nc = 3;
3094
- gemm_small(m0, m, n0, n, mc, nc);
3095
- break;
3096
- default:
3097
- return;
3098
- }
3099
  } else {
3100
- switch((m_rem << 4) | n_rem) {
3101
- case 0x43:
3102
- mc = 4;
3103
- nc = 3;
3104
- gemm_small(m0, m, n0, n, mc, nc);
3105
- break;
3106
- case 0x42:
3107
- mc = 4;
3108
- nc = 2;
3109
- gemm_small(m0, m, n0, n, mc, nc);
3110
- break;
3111
- case 0x41:
3112
- mc = 4;
3113
- nc = 1;
3114
- gemm_small(m0, m, n0, n, mc, nc);
3115
- break;
3116
- case 0x34:
3117
- mc = 3;
3118
- nc = 4;
3119
- gemm_small(m0, m, n0, n, mc, nc);
3120
- break;
3121
- case 0x33:
3122
- mc = 3;
3123
- nc = 3;
3124
- gemm_small(m0, m, n0, n, mc, nc);
3125
- break;
3126
- case 0x32:
3127
- mc = 3;
3128
- nc = 2;
3129
- gemm_small(m0, m, n0, n, mc, nc);
3130
- break;
3131
- case 0x31:
3132
- mc = 3;
3133
- nc = 1;
3134
- gemm_small(m0, m, n0, n, mc, nc);
3135
- break;
3136
- case 0x24:
3137
- mc = 2;
3138
- nc = 4;
3139
- gemm_small(m0, m, n0, n, mc, nc);
3140
- break;
3141
- case 0x23:
3142
- mc = 2;
3143
- nc = 3;
3144
- gemm_small(m0, m, n0, n, mc, nc);
3145
- break;
3146
- case 0x22:
3147
- mc = 2;
3148
- nc = 2;
3149
- gemm_small(m0, m, n0, n, mc, nc);
3150
- break;
3151
- case 0x21:
3152
- mc = 2;
3153
- nc = 1;
3154
- gemm_small(m0, m, n0, n, mc, nc);
3155
- break;
3156
- case 0x14:
3157
- mc = 1;
3158
- nc = 4;
3159
- gemm_small(m0, m, n0, n, mc, nc);
3160
- break;
3161
- case 0x13:
3162
- mc = 1;
3163
- nc = 3;
3164
- gemm_small(m0, m, n0, n, mc, nc);
3165
- break;
3166
- case 0x12:
3167
- mc = 1;
3168
- nc = 2;
3169
- gemm_small(m0, m, n0, n, mc, nc);
3170
- break;
3171
- case 0x11:
3172
- mc = 1;
3173
- nc = 1;
3174
- gemm_small(m0, m, n0, n, mc, nc);
3175
- break;
3176
- default:
3177
- return;
3178
- }
3179
  }
3180
- mp = m0 + (m - m0) / mc * mc;
3181
- np = n0 + (n - n0) / nc * nc;
3182
  mnpack(mp, m, n0, np);
3183
  mnpack(m0, m, np, n);
3184
- }
3185
 
3186
  void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
3187
  int64_t ytiles = (m - m0) / RM;
@@ -3206,22 +2457,22 @@ class tinyBLAS_PPC {
3206
  * matrix elements.
3207
  */
3208
  if (RM == 1) {
3209
- TA* a = const_cast<TA*>(A+(ii)*lda+l);
3210
- packTranspose<vector float>(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
3211
  vec_A[0] = (vec_t)vec_xl(0,a);
3212
- vec_A[1] = (vec_t)vec_splats(*((TA*)&vec_A+1));
3213
- vec_A[2] = (vec_t)vec_splats(*((TA*)&vec_A+2));
3214
- vec_A[3] = (vec_t)vec_splats(*((TA*)&vec_A+3));
3215
  } else if (RN == 1) {
3216
- packTranspose<vector float>(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
3217
- TB* b = const_cast<TB*>(B+(jj)*ldb+l);
3218
  vec_B[0] = (vec_t)vec_xl(0,b);
3219
- vec_B[1] = (vec_t)vec_splats(*((TB*)&vec_B+1));
3220
- vec_B[2] = (vec_t)vec_splats(*((TB*)&vec_B+2));
3221
- vec_B[3] = (vec_t)vec_splats(*((TB*)&vec_B+3));
3222
  } else {
3223
- packTranspose<vector float>(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
3224
- packTranspose<vector float>(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
3225
  }
3226
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
3227
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
@@ -3231,7 +2482,7 @@ class tinyBLAS_PPC {
3231
  __builtin_mma_disassemble_acc(vec_C, &acc_0);
3232
  for (int I = 0; I < RM; I++) {
3233
  for (int J = 0; J < RN; J++) {
3234
- *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
3235
  }
3236
  }
3237
  }
@@ -3263,11 +2514,9 @@ class tinyBLAS_PPC {
3263
  }
3264
  }
3265
 
3266
- const TA *const A;
3267
- const TB *const B;
3268
- TC *C;
3269
- TA *At;
3270
- TB *Bt;
3271
  const int64_t k;
3272
  const int64_t lda;
3273
  const int64_t ldb;
@@ -3366,7 +2615,7 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
3366
  #elif defined(__MMA__)
3367
  if (k % 8)
3368
  return false;
3369
- tinyBLAS_PPC<float, float, float> tb{
3370
  k, (const float *)A, lda,
3371
  (const float *)B, ldb,
3372
  (float *)C, ldc,
@@ -3493,7 +2742,7 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
3493
  return false;
3494
  if (m < 8 && m != 4)
3495
  return false;
3496
- tinyBLAS_Q0_PPC<block_q8_0, block_q8_0, float> tb{
3497
  k, (const block_q8_0 *)A, lda,
3498
  (const block_q8_0 *)B, ldb,
3499
  (float *)C, ldc,
@@ -3530,7 +2779,7 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
3530
  return false;
3531
  if (m < 8 && m != 4)
3532
  return false;
3533
- tinyBLAS_Q0_PPC<block_q4_0, block_q8_0, float> tb{
3534
  k, (const block_q4_0 *)A, lda,
3535
  (const block_q8_0 *)B, ldb,
3536
  (float *)C, ldc,
 
1541
  } else if constexpr(RM == 8 && RN == 4) {
1542
  KERNEL_8x4(ii,jj);
1543
  } else {
1544
+ assert(false && "RN/RM values not supported");
1545
  }
1546
  }
1547
 
 
1573
  const int nth;
1574
  };
1575
 
1576
+ template <typename TA>
1577
  class tinyBLAS_Q0_PPC {
1578
  public:
1579
  tinyBLAS_Q0_PPC(int64_t k,
1580
  const TA *A, int64_t lda,
1581
+ const block_q8_0 *B, int64_t ldb,
1582
+ float *C, int64_t ldc,
1583
  int ith, int nth)
1584
  : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
1585
  }
 
1590
 
1591
  private:
1592
 
1593
+ inline void save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) {
 
1594
  for (int I = 0; I < RM; I++) {
1595
  for (int J = 0; J < RN; J++) {
1596
  *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J);
 
1610
  fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]);
1611
  }
1612
  }
1613
+ /* This function processes quantized data from block_q4_0 elements.
1614
+ * First we extract the two int4 values stored in a single int8_t into two signed int8 values.
1615
+ * Then we subtract 8 from each resulting element to convert the unsigned int4 values to signed int8.
1616
+ * We also compute the row sum, which is needed to compensate for the above conversion. */
1617
+ inline void process_q4_elements(vector signed char (&c)[2], int* ca) {
1618
  const vector signed char lowMask = vec_splats((signed char)0xF);
1619
  const vector unsigned char v4 = vec_splats((unsigned char)0x4);
1620
  const vector signed char v8 = vec_splats((signed char)0x8);
1621
+ vector signed int vsum = {0};
1622
+ vector signed int vsum2 = {0};
1623
+ c[0] = vec_and(c[1], lowMask);
1624
+ c[1] = vec_sr(c[1], v4);
1625
+ c[0] = vec_sub(c[0], v8);
1626
+ c[1] = vec_sub(c[1], v8);
1627
+ vsum = vec_sum4s(c[0], vsum);
1628
+ vsum2 = vec_sum4s(c[1], vsum2);
1629
+ vsum = vec_add(vsum, vsum2);
1630
+ *(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1631
+ }
1632
+
1633
+ template <typename V1, typename V2>
1634
+ inline void vector_permute_store(V2 &s1, V2 &s2, V2 &s3, V2 &s4, V1 *vecOffset, bool flip) {
1635
  vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
1636
  vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
1637
  vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
1638
  vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
1639
+ V2 t1, t2, t3, t4, t5, t6, t7, t8;
1640
+ vector unsigned char xor_vector;
1641
+ uint8_t flip_vec = 0x80;
1642
+ xor_vector = vec_splats(flip_vec);
1643
+ t1 = vec_perm(s1, s2, swiz1);
1644
+ t2 = vec_perm(s1, s2, swiz2);
1645
+ t3 = vec_perm(s3, s4, swiz1);
1646
+ t4 = vec_perm(s3, s4, swiz2);
1647
+ t5 = vec_perm(t1, t3, swiz3);
1648
+ t6 = vec_perm(t1, t3, swiz4);
1649
+ t7 = vec_perm(t2, t4, swiz3);
1650
+ t8 = vec_perm(t2, t4, swiz4);
1651
+ if (flip == true) {
1652
+ t5 = vec_xor(t5, xor_vector);
1653
+ t6 = vec_xor(t6, xor_vector);
1654
+ t7 = vec_xor(t7, xor_vector);
1655
+ t8 = vec_xor(t8, xor_vector);
1656
+ }
1657
+ vec_xst(t5, 0, vecOffset);
1658
+ vec_xst(t6, 0, vecOffset+16);
1659
+ vec_xst(t7, 0, vecOffset+32);
1660
+ vec_xst(t8, 0, vecOffset+48);
1661
+ }
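// Illustrative aside (hypothetical helper, not part of this patch): the flip path
// above XORs every packed byte with 0x80. For two's-complement int8 that equals
// adding 128, remapping signed bytes to unsigned ones while preserving order.
// B is packed with flip=true (as uint8_t / vector unsigned char) so the kernels
// consume it as unsigned data, and the row sums gathered in comparray are used
// in compute() to account for that offset.
static inline uint8_t flip_sign_bit(int8_t s) {
    return (uint8_t)(s ^ 0x80);   // e.g. flip_sign_bit(-3) == 125 == -3 + 128
}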
1662
 
1663
+ template<int size>
1664
+ void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array<int, size>& comparray) {
1665
+ int64_t i, j;
1666
+ TA *aoffset = NULL;
1667
+ int8_t *vecOffset = NULL;
1668
+ TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
1669
+ TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
1670
+ vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
1671
+ vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
1672
+ aoffset = const_cast<TA*>(a);
1673
+ vecOffset = vec;
1674
  j = (rows >> 3);
1675
  if (j > 0) {
1676
  do {
 
1683
  aoffset7 = aoffset6 + lda;
1684
  aoffset8 = aoffset7 + lda;
1685
  aoffset += 8 * lda;
 
1686
  i = (cols >> 2);
1687
  if (i > 0) {
1688
  do {
1689
+ c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
1690
+ c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
1691
+ c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
1692
+ c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset4->qs));
1693
+ c5[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset5->qs));
1694
+ c6[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset6->qs));
1695
+ c7[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset7->qs));
1696
+ c8[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset8->qs));
1697
+
1698
+ process_q4_elements(c1, &comparray[0]);
1699
+ process_q4_elements(c2, &comparray[1]);
1700
+ process_q4_elements(c3, &comparray[2]);
1701
+ process_q4_elements(c4, &comparray[3]);
1702
+ process_q4_elements(c5, &comparray[4]);
1703
+ process_q4_elements(c6, &comparray[5]);
1704
+ process_q4_elements(c7, &comparray[6]);
1705
+ process_q4_elements(c8, &comparray[7]);
1706
+ vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
1707
+ vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
1708
+ vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false);
1709
+ vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false);
1710
  aoffset1 += lda;
1711
  aoffset2 += lda;
1712
  aoffset3 += lda;
 
1729
  aoffset3 = aoffset2 + lda;
1730
  aoffset4 = aoffset3 + lda;
1731
  aoffset += 4 * lda;
 
1732
  i = (cols >> 2);
1733
  if (i > 0) {
1734
  do {
1735
+ c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
1736
+ c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
1737
+ c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
1738
+ c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset4->qs));
1739
+
1740
+ process_q4_elements(c1, &comparray[0]);
1741
+ process_q4_elements(c2, &comparray[1]);
1742
+ process_q4_elements(c3, &comparray[2]);
1743
+ process_q4_elements(c4, &comparray[3]);
1744
+ vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
1745
+ vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
1746
  aoffset1 += lda;
1747
  aoffset2 += lda;
1748
  aoffset3 += lda;
 
1761
  if (i > 0) {
1762
  do {
1763
  switch(rows) {
1764
+ case 3: c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
1765
+ case 2: c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
1766
+ case 1: c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
1767
  break;
1768
  }
1769
+ process_q4_elements(c1, &comparray[0]);
1770
+ process_q4_elements(c2, &comparray[1]);
1771
+ process_q4_elements(c3, &comparray[2]);
1772
+ process_q4_elements(c4, &comparray[3]);
1773
+ vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
1774
+ vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
1775
  aoffset1 += lda;
1776
  aoffset2 += lda;
1777
  aoffset3 += lda;
 
1781
  }
1782
  }
1783
  }
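The repeated nibble handling that packNormalInt4 used to spell out per row pointer is now funnelled through process_q4_elements. As a plain C++ reference of what that helper computes per Q4_0 block, assuming the standard GGML Q4_0 layout of 16 qs bytes holding 32 4-bit quants with a +8 bias; the function name below is illustrative, not from sgemm.cpp:

#include <cstdint>

// Reference only: expand one Q4_0 block into 32 signed quants and return the
// per-block sum that the MMA kernel later folds in as its offset correction
// (the value packNormalInt4 stores into comparray).
static int unpack_q4_block_ref(const uint8_t qs[16], int8_t out[32]) {
    int sum = 0;
    for (int i = 0; i < 16; ++i) {
        int8_t lo = (int8_t)(qs[i] & 0x0F) - 8;   // low nibble  -> quants 0..15
        int8_t hi = (int8_t)(qs[i] >> 4)  - 8;    // high nibble -> quants 16..31
        out[i]      = lo;
        out[i + 16] = hi;
        sum += lo + hi;
    }
    return sum;
}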
 
1784
  template<typename VA, typename VB>
1785
+ void packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
1786
  int64_t i, j;
1787
+ block_q8_0 *aoffset = NULL;
1788
  VA *vecOffset = NULL;
1789
+ block_q8_0* aoffsets[8];
1790
+ __vector_pair arr[8];
1791
+ VB c[8][2] = {0};
1792
+ VB c1[8] = {0}; VB c2[8] = {0};
1793
+ aoffset = const_cast<block_q8_0*>(a);
1794
  vecOffset = vec;
1795
  j = (rows >> 3);
1796
  if (j > 0) {
1797
  do {
1798
+ aoffsets[0] = aoffset;
1799
+ for (int it = 1; it < 8; it++)
1800
+ aoffsets[it] = aoffsets[it-1] + lda;
1801
  aoffset += 8 * lda;
1802
 
1803
  i = (cols >> 3);
1804
  if (i > 0) {
1805
  do {
1806
+ for (int it = 0; it < 8; it++) {
1807
+ arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
1808
+ __builtin_vsx_disassemble_pair(c[it], &arr[it]);
1809
+ c1[it] = c[it][0];
1810
+ c2[it] = c[it][1];
1811
  }
1812
+ vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
1813
+ vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
1814
+ vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip);
1815
+ vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip);
1816
+ for (int it = 0; it < 8; it++)
1817
+ aoffsets[it] += lda;
1818
  vecOffset += 256;
1819
  i--;
1820
  } while(i > 0);
 
1824
  }
1825
 
1826
  if (rows & 4) {
1827
+ aoffsets[0] = aoffset;
1828
+ for (int it = 1; it < 4; it++ )
1829
+ aoffsets[it] = aoffsets[it-1] + lda;
1830
+ aoffset += 4 * lda;
1831
  i = (cols >> 3);
1832
  if (i > 0) {
1833
  do {
1834
+ for (int it = 0; it < 4; it++) {
1835
+ arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
1836
+ __builtin_vsx_disassemble_pair(c[it], &arr[it]);
1837
+ c1[it] = c[it][0];
1838
+ c2[it] = c[it][1];
1839
  }
1840
+ vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
1841
+ vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
1842
+ for (int it = 0; it < 4; it++) {
1843
+ aoffsets[it] += lda;
1844
  }
1845
  vecOffset += 128;
1846
  i--;
1847
  } while(i > 0);
1848
  }
1849
  }
1850
+
1851
  if (rows & 3) {
1852
+ aoffsets[0] = aoffset;
1853
+ for (int it = 1; it < 3; it++ )
1854
+ aoffsets[it] = aoffsets[it-1] + lda;
1855
  i = (cols >> 3);
1856
  if (i > 0) {
1857
  do {
1858
  switch(rows) {
1859
+ case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[2]->qs);
1860
+ __builtin_vsx_disassemble_pair(c[2], &arr[2]);
1861
+ c1[2] = c[2][0]; c2[2] = c[2][1];
1862
+ case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[1]->qs);
1863
+ __builtin_vsx_disassemble_pair(c[1], &arr[1]);
1864
+ c1[1] = c[1][0]; c2[1] = c[1][1];
1865
+ case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[0]->qs);
1866
+ __builtin_vsx_disassemble_pair(c[0], &arr[0]);
1867
+ c1[0] = c[0][0]; c2[0] = c[0][1];
1868
  break;
1869
  }
1870
+ vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
1871
+ vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
1872
+ for (int it = 0; it < 3; it++)
1873
+ aoffsets[it] += lda;
1874
  vecOffset += 128;
1875
  i--;
1876
  } while(i > 0);
 
1879
  }
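The rewritten packNormal drops the per-row vec_xl pairs in favour of POWER10 paired loads: one lxvp fetches the entire 32-byte qs[] of a block_q8_0 and is then split into two 16-byte vectors. A minimal sketch of that load pattern, assuming a GCC/Clang toolchain targeting POWER10 with MMA/VSX enabled; the block type here is a stand-in that mirrors block_q8_0:

#include <altivec.h>
#include <cstdint>

typedef uint16_t ggml_half_t;      // stand-in for ggml's fp16 scale type

struct q8_block {                  // mirrors block_q8_0: fp16 scale + 32 int8 quants
    ggml_half_t d;
    int8_t      qs[32];
};

// Load both 16-byte halves of a block's quants with one paired load.
static inline void load_q8_quants(q8_block *b,
                                  vector signed char *lo,
                                  vector signed char *hi) {
    __vector_pair p = __builtin_vsx_lxvp(0, (__vector_pair *)b->qs);
    vector signed char halves[2];
    __builtin_vsx_disassemble_pair(halves, &p);
    *lo = halves[0];               // qs[0..15]
    *hi = halves[1];               // qs[16..31]
}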
1880
 
1881
  void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
1882
+ int m_rem = MIN(m - m0, 16);
1883
+ int n_rem = MIN(n - n0, 16);
1884
+
1885
+ int mc = 0, nc = 0;
1886
+
1887
  if (m_rem >= 8 && n_rem >= 8) {
1888
+ mc = 8;
1889
+ nc = 8;
1890
+ gemm<8, 8>(m0, m, n0, n);
1891
  } else if (m_rem >= 4 && n_rem >= 8) {
1892
  mc = 4;
1893
  nc = 8;
1894
+ gemm<4, 8>(m0, m, n0, n);
1895
  } else if (m_rem >= 8 && n_rem >= 4) {
1896
  mc = 8;
1897
  nc = 4;
1898
+ gemm<8, 4>(m0, m, n0, n);
1899
  } else if (m_rem >= 4 && n_rem >= 4) {
1900
  mc = 4;
1901
  nc = 4;
1902
+ gemm_small(m0, m, n0, n, mc, nc);
1903
  } else {
1904
+ mc = (m_rem >= 4) ? 4 : m_rem;
1905
+ nc = (n_rem >= 4) ? 4 : n_rem;
1906
+ if (mc == 0 || nc == 0)
1907
+ return;
1908
+ gemm_small(m0, m, n0, n, mc, nc);
1909
  }
1910
+
1911
+ int64_t mp = m0 + ((m - m0) / mc) * mc;
1912
+ int64_t np = n0 + ((n - n0) / nc) * nc;
1913
  mnpack(mp, m, n0, np);
1914
  mnpack(m0, m, np, n);
1915
  }
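The flattened mnpack above replaces the old nested switch over every (mc, nc) pair with a clamp-and-pick scheme: clamp the remaining rows/columns, dispatch the largest kernel that fits, then recurse on the bottom and right strips. A self-contained sketch of that control flow (run_tile stands in for the gemm<RM, RN> / gemm_small calls):

#include <algorithm>
#include <cstdint>
#include <functional>

// Illustrative control flow only: choose the largest tile that fits the
// remaining sub-matrix, process it, then recurse on what is left over.
static void mnpack_sketch(int64_t m0, int64_t m, int64_t n0, int64_t n,
                          const std::function<void(int64_t, int64_t, int, int)> &run_tile) {
    int64_t m_rem = std::min<int64_t>(m - m0, 8);
    int64_t n_rem = std::min<int64_t>(n - n0, 8);
    int mc = (m_rem >= 8) ? 8 : (m_rem >= 4) ? 4 : (int)m_rem;
    int nc = (n_rem >= 8) ? 8 : (n_rem >= 4) ? 4 : (int)n_rem;
    if (mc == 0 || nc == 0)
        return;
    run_tile(m0, n0, mc, nc);                 // all full mc x nc tiles in this block
    int64_t mp = m0 + ((m - m0) / mc) * mc;   // first row left uncovered
    int64_t np = n0 + ((n - n0) / nc) * nc;   // first column left uncovered
    mnpack_sketch(mp, m, n0, np, run_tile);   // bottom strip
    mnpack_sketch(m0, m, np, n, run_tile);    // right strip
}

The quantized variant above follows the same shape but clamps the remainders to 16 and routes the 4x4 and ragged cases through gemm_small.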
1916
 
1917
+
1918
  void KERNEL_4x8(int64_t ii, int64_t jj) {
1919
  vec_t vec_A[8], vec_B[16] = {0};
1920
  acc_t acc_0, acc_1;
 
1926
  __builtin_mma_xxsetaccz(&acc_0);
1927
  __builtin_mma_xxsetaccz(&acc_1);
1928
  if (std::is_same_v<TA, block_q4_0>) {
1929
+ packNormalInt4<4>((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray);
1930
  } else {
1931
+ packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
1932
  }
1933
  packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
1934
  for(int x = 0; x < 8; x++) {
 
1956
  compute<4>(&acc_0, 0, 0, comparray, vs, fin_res);
1957
  compute<4>(&acc_1, 0, 4, comparray, vs, fin_res);
1958
  }
1959
+ save_res(ii, jj, 0, fin_res);
1960
+ save_res(ii, jj+4, 4, fin_res);
1961
  }
1962
 
1963
  void KERNEL_8x4(int64_t ii, int64_t jj) {
 
1971
  __builtin_mma_xxsetaccz(&acc_0);
1972
  __builtin_mma_xxsetaccz(&acc_1);
1973
  if (std::is_same_v<TA, block_q4_0>) {
1974
+ packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
1975
  } else {
1976
+ packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
1977
  }
1978
  packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true);
1979
  for(int x = 0; x < 8; x++) {
 
2000
  compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
2001
  compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
2002
  }
2003
+ save_res(ii, jj, 0, fin_res);
2004
+ save_res(ii+4, jj, 4, fin_res);
2005
  }
2006
 
2007
  void KERNEL_8x8(int64_t ii, int64_t jj) {
 
2017
  __builtin_mma_xxsetaccz(&acc_2);
2018
  __builtin_mma_xxsetaccz(&acc_3);
2019
  if (std::is_same_v<TA, block_q4_0>) {
2020
+ packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
2021
  } else {
2022
+ packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
2023
  }
2024
  packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
2025
  for(int x = 0; x < 8; x++) {
 
2051
  compute<8>(&acc_2, 0, 8, comparray, vs, fin_res);
2052
  compute<8>(&acc_3, 4, 12, comparray, vs, fin_res);
2053
  }
2054
+ save_res(ii, jj, 0, fin_res);
2055
+ save_res(ii+4, jj, 4, fin_res);
2056
+ save_res(ii, jj+4, 8, fin_res);
2057
+ save_res(ii+4, jj+4, 12, fin_res);
2058
  }
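The same de-templating shows up at the save_res call sites: the fixed-size kernels call save_res(ii, jj, idx, fin_res) while gemm_small passes the tile shape explicitly, which suggests the tile dimensions are now ordinary (defaulted) arguments rather than template parameters. A hedged sketch of that pattern, not the exact signature in sgemm.cpp:

#include <cstdint>

struct VecF4 { float v[4]; };          // stand-in for vector float

// Sketch: write an RM x RN tile of results into column-major C, with the
// shape passed at run time and defaulted to the common 4x4 case.
static void save_res_sketch(float *C, int64_t ldc, int64_t ii, int64_t jj, int idx,
                            const VecF4 *fin_res, int RM = 4, int RN = 4) {
    for (int I = 0; I < RM; I++)
        for (int J = 0; J < RN; J++)
            C[ii + (jj + J) * ldc + I] = fin_res[idx + I].v[J];
}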
2059
 
2060
+ void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
 
2061
  int64_t ytiles = (m - m0) / RM;
2062
  int64_t xtiles = (n - n0) / RN;
2063
  int64_t tiles = xtiles * ytiles;
 
2086
  __builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead
2087
  __builtin_mma_xxsetaccz(&acc_0);
2088
  if (isAblock_q4) {
2089
+ packNormalInt4<4>((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray);
2090
  } else {
2091
+ packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
2092
  }
2093
  packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true);
2094
  for(int x = 0; x < 8; x+=4) {
 
2121
  fin_res[i] = vec_madd(res[i], vs[i], fin_res[i]);
2122
  }
2123
  }
2124
+ save_res(ii, jj, 0, fin_res, RM, RN);
2125
  }
2126
  }
2127
 
 
2134
  } else if constexpr(RM == 8 && RN == 8) {
2135
  KERNEL_8x8(ii,jj);
2136
  } else {
2137
+ assert(false && "RN/RM values not supported");
2138
  }
2139
  }
2140
 
 
2156
  }
2157
 
2158
  const TA *const A;
2159
+ const block_q8_0 *const B;
2160
+ float *C;
2161
  const int64_t k;
2162
  const int64_t lda;
2163
  const int64_t ldb;
 
2166
  const int nth;
2167
  };
2168
 
 
2169
  class tinyBLAS_PPC {
2170
  public:
2171
  tinyBLAS_PPC(int64_t k,
2172
+ const float *A, int64_t lda,
2173
+ const float *B, int64_t ldb,
2174
+ float *C, int64_t ldc,
2175
  int ith, int nth)
2176
  : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
2177
  }
 
2184
 
2185
  void (tinyBLAS_PPC::*kernel)(int64_t, int64_t);
2186
 
2187
+ inline void vector_permute_store_4(vector float *src, float *vecOffset) {
2188
+ vector float t1, t2, t3, t4, t5, t6, t7, t8;
2189
+ t1 = vec_mergeh(src[0], src[1]);
2190
+ t2 = vec_mergeh(src[2], src[3]);
2191
+ t3 = vec_mergel(src[0], src[1]);
2192
+ t4 = vec_mergel(src[2], src[3]);
2193
+
2194
+ t5 = vec_xxpermdi(t1, t2, 0);
2195
+ t6 = vec_xxpermdi(t1, t2, 3);
2196
+ t7 = vec_xxpermdi(t3, t4, 0);
2197
+ t8 = vec_xxpermdi(t3, t4, 3);
2198
+
2199
+ vec_xst(t5, 0, vecOffset);
2200
+ vec_xst(t6, 0, vecOffset + 4);
2201
+ vec_xst(t7, 0, vecOffset + 8);
2202
+ vec_xst(t8, 0, vecOffset + 12);
2203
+ }
2204
+
2205
+ inline void vector_permute_store_8(vector float *src, float *vecOffset) {
2206
+ vector float t1, t2, t3, t4, t5, t6, t7, t8;
2207
+ t1 = vec_mergeh(src[0], src[1]);
2208
+ t2 = vec_mergeh(src[2], src[3]);
2209
+ t3 = vec_mergeh(src[4], src[5]);
2210
+ t4 = vec_mergeh(src[6], src[7]);
2211
+
2212
+ t5 = vec_xxpermdi(t1, t2, 0);
2213
+ t6 = vec_xxpermdi(t3, t4, 0);
2214
+ t7 = vec_xxpermdi(t1, t2, 3);
2215
+ t8 = vec_xxpermdi(t3, t4, 3);
2216
+
2217
+ vec_xst(t5, 0, vecOffset);
2218
+ vec_xst(t6, 0, vecOffset + 4);
2219
+ vec_xst(t7, 0, vecOffset + 8);
2220
+ vec_xst(t8, 0, vecOffset + 12);
2221
+
2222
+ t1 = vec_mergel(src[0], src[1]);
2223
+ t2 = vec_mergel(src[2], src[3]);
2224
+ t3 = vec_mergel(src[4], src[5]);
2225
+ t4 = vec_mergel(src[6], src[7]);
2226
+
2227
+ t5 = vec_xxpermdi(t1, t2, 0);
2228
+ t6 = vec_xxpermdi(t3, t4, 0);
2229
+ t7 = vec_xxpermdi(t1, t2, 3);
2230
+ t8 = vec_xxpermdi(t3, t4, 3);
2231
+
2232
+ vec_xst(t5, 0, vecOffset + 16);
2233
+ vec_xst(t6, 0, vecOffset + 20);
2234
+ vec_xst(t7, 0, vecOffset + 24);
2235
+ vec_xst(t8, 0, vecOffset + 28);
2236
+ }
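The vec_mergeh/vec_mergel/vec_xxpermdi sequence in these helpers is a 4x4 transpose: four row vectors in, four column vectors stored back to back. A scalar reference of the layout vector_permute_store_4 is meant to produce, handy for spot-checking the intrinsic version (the function name is illustrative):

// Reference: transpose a 4x4 tile given as four rows of four floats and
// store its columns contiguously, i.e. dst[j*4 + i] = src_rows[i][j].
static void transpose4x4_ref(const float src_rows[4][4], float dst[16]) {
    for (int j = 0; j < 4; ++j)        // output column j
        for (int i = 0; i < 4; ++i)    // input row i
            dst[j * 4 + i] = src_rows[i][j];
}

vector_permute_store_8 does the 8x4 analogue, emitting four columns of eight values.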
2237
+
2238
+ void packTranspose(const float* a, int64_t lda, int rows, int cols, float* vec) {
2239
  int64_t i, j;
2240
+ float * aoffsets[8];
2241
+ float *aoffset = NULL, *boffset = NULL;
2242
+ __vector_pair arr[8];
2243
+ vector float c[8][2] = {0};
2244
+ vector float c1[8] = {0};
2245
+ vector float c2[8] = {0};
2246
+ aoffset = const_cast<float*>(a);
 
2247
  boffset = vec;
2248
  j = (rows >> 3);
2249
  if (j > 0) {
2250
 
2251
  do {
2252
+ aoffsets[0] = aoffset;
2253
+ for (int it = 1; it< 8; it++)
2254
+ aoffsets[it] = aoffsets[it-1] + lda;
2255
  aoffset += 8 * lda;
2256
  i = (cols >> 3);
2257
  if (i > 0) {
2258
  do {
2259
+ for (int it = 0; it< 8; it++) {
2260
+ arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]);
2261
+ __builtin_vsx_disassemble_pair(c[it], &arr[it]);
2262
+ c1[it] = c[it][0];
2263
+ c2[it] = c[it][1];
2264
+ }
2265
+
2266
+ vector_permute_store_8(c1, boffset);
2267
+ vector_permute_store_8(c2, boffset+32);
2268
+ for (int it = 0; it < 4; it++)
2269
+ aoffsets[it] = aoffsets[it] + 8*lda;
2270
  boffset += 64;
2271
  i--;
2272
  } while(i > 0);
2273
  }
2274
  if (cols & 4) {
2275
+ for (int it = 0; it < 8 ; it++)
2276
+ c1[it] = vec_xl(0, aoffsets[it]);
2277
+ vector_permute_store_8(c1, boffset);
2278
  }
2279
  j--;
2280
  } while(j > 0);
2281
  }
2282
 
2283
  if (rows & 4) {
2284
+ aoffsets[0] = aoffset;
2285
+ for (int it = 1; it < 4; it++)
2286
+ aoffsets[it] = aoffsets[it-1] + lda;
 
2287
  aoffset += 4 * lda;
2288
  i = (cols >> 3);
2289
  if (i > 0) {
2290
  do {
2291
+ for (int it = 0; it < 4; it++) {
2292
+ arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]);
2293
+ __builtin_vsx_disassemble_pair(c[it], &arr[it]);
2294
+ c1[it] = c[it][0];
2295
+ c2[it] = c[it][1];
2296
+ }
2297
+ vector_permute_store_4(c1, boffset);
2298
+ vector_permute_store_4(c2, boffset+16);
2299
+ for (int it = 0; it < 4; it++)
2300
+ aoffsets[it] += 8*lda;
2301
  boffset += 32;
2302
  i--;
2303
  } while(i > 0);
2304
  }
2305
 
2306
  if (cols & 4) {
2307
+ for (int it = 0; it < 4; it++)
2308
+ c1[it] = vec_xl(0, aoffsets[it]);
2309
+ vector_permute_store_4(c1, boffset);
2310
  }
2311
  }
2312
  if (rows & 3) {
2313
+ aoffsets[0] = aoffset;
2314
+ for (int it = 1; it < 3; it++)
2315
+ aoffsets[it] = aoffsets[it-1] + lda;
2316
  if (cols & 4) {
2317
+ for (int it = 0; it < 3; it++)
2318
+ c1[it] = vec_xl(0, aoffsets[it]);
2319
+ vector_permute_store_4(c1, boffset);
2320
  }
2321
  }
2322
  }
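With packTranspose emitting these column panels, each float kernel is a chain of xvf32gerpp rank-1 updates into an MMA accumulator. A scalar reference of one 4x4 micro-kernel step, useful for reasoning about the packed layout (illustrative, not the MMA path itself):

#include <cstdint>

// Reference: for each k step, accumulate the outer product of one packed
// column of A (4 floats) and one packed column of B (4 floats) into a 4x4
// tile: C[i][j] += a[i] * b[j], which is what one xvf32gerpp contributes.
static void kernel_4x4_ref(const float *A_pack, const float *B_pack,
                           int64_t k, float C_tile[4][4]) {
    for (int64_t l = 0; l < k; ++l) {
        const float *a = A_pack + 4 * l;
        const float *b = B_pack + 4 * l;
        for (int i = 0; i < 4; ++i)
            for (int j = 0; j < 4; ++j)
                C_tile[i][j] += a[i] * b[j];
    }
}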
 
2326
  acc_t acc_0;
2327
  __builtin_mma_xxsetaccz(&acc_0);
2328
  for (int l = 0; l < k; l+=4) {
2329
+ packTranspose(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
2330
+ packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
2331
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
2332
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
2333
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
 
2342
  __builtin_mma_xxsetaccz(&acc_0);
2343
  __builtin_mma_xxsetaccz(&acc_1);
2344
  for (int64_t l = 0; l < k; l+=4) {
2345
+ packTranspose(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
2346
+ packTranspose(B+(jj*ldb)+l, ldb, 8, 4, (float*)vec_B);
2347
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]);
2348
  __builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]);
2349
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]);
 
2363
  __builtin_mma_xxsetaccz(&acc_0);
2364
  __builtin_mma_xxsetaccz(&acc_1);
2365
  for (int64_t l = 0; l < k; l+=4) {
2366
+ packTranspose(A+(ii*lda)+l, lda, 8, 4, (float*)vec_A);
2367
+ packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
2368
  __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]);
2369
  __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]);
2370
  __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]);
 
2386
  __builtin_mma_xxsetaccz(&acc_2);
2387
  __builtin_mma_xxsetaccz(&acc_3);
2388
  for (int l = 0; l < k; l+=8) {
2389
+ packTranspose(A+(ii*lda)+l, lda, 8, 8, (float*)vec_A);
2390
+ packTranspose(B+(jj*ldb)+l, ldb, 8, 8, (float*)vec_B);
2391
  for(int x = 0; x < 16; x+=2) {
2392
  __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]);
2393
  __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x+1]);
 
2402
  }
2403
 
2404
  void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
2405
+ int m_rem = MIN(m - m0, 8);
2406
+ int n_rem = MIN(n - n0, 8);
2407
+ int mc = 0, nc = 0;
2408
+ if (m_rem >= 8 && n_rem >= 8) {
2409
+ mc = 8;
2410
+ nc = 8;
2411
+ gemm<8, 8>(m0, m, n0, n);
2412
  } else if (m_rem >= 4 && n_rem >= 8) {
2413
+ mc = 4;
2414
+ nc = 8;
2415
+ gemm<4, 8>(m0, m, n0, n);
2416
  } else if (m_rem >= 8 && n_rem >= 4) {
2417
+ mc = 8;
2418
+ nc = 4;
2419
+ gemm<8, 4>(m0, m, n0, n);
2420
  } else if (m_rem >= 4 && n_rem >= 4) {
2421
+ mc = 4;
2422
+ nc = 4;
2423
+ gemm<4, 4>(m0, m, n0, n);
2424
  } else {
2425
+ mc = (m_rem >= 4) ? 4 : m_rem;
2426
+ nc = (n_rem >= 4) ? 4 : n_rem;
2427
+ if (mc == 0 || nc == 0)
2428
+ return;
2429
+ gemm_small(m0, m, n0, n, mc, nc);
2430
  }
2431
+ int64_t mp = m0 + ((m - m0) / mc) * mc;
2432
+ int64_t np = n0 + ((n - n0) / nc) * nc;
2433
  mnpack(mp, m, n0, np);
2434
  mnpack(m0, m, np, n);
2435
+ }
2436
 
2437
  void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
2438
  int64_t ytiles = (m - m0) / RM;
 
2457
  * matrix elements.
2458
  */
2459
  if (RM == 1) {
2460
+ float* a = const_cast<float*>(A+(ii)*lda+l);
2461
+ packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B);
2462
  vec_A[0] = (vec_t)vec_xl(0,a);
2463
+ vec_A[1] = (vec_t)vec_splats(*((float*)&vec_A+1));
2464
+ vec_A[2] = (vec_t)vec_splats(*((float*)&vec_A+2));
2465
+ vec_A[3] = (vec_t)vec_splats(*((float*)&vec_A+3));
2466
  } else if (RN == 1) {
2467
+ packTranspose(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A);
2468
+ float* b = const_cast<float*>(B+(jj)*ldb+l);
2469
  vec_B[0] = (vec_t)vec_xl(0,b);
2470
+ vec_B[1] = (vec_t)vec_splats(*((float*)&vec_B+1));
2471
+ vec_B[2] = (vec_t)vec_splats(*((float*)&vec_B+2));
2472
+ vec_B[3] = (vec_t)vec_splats(*((float*)&vec_B+3));
2473
  } else {
2474
+ packTranspose(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A);
2475
+ packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B);
2476
  }
2477
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
2478
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
 
2482
  __builtin_mma_disassemble_acc(vec_C, &acc_0);
2483
  for (int I = 0; I < RM; I++) {
2484
  for (int J = 0; J < RN; J++) {
2485
+ *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J);
2486
  }
2487
  }
2488
  }
 
2514
  }
2515
  }
2516
 
2517
+ const float *const A;
2518
+ const float *const B;
2519
+ float *C;
2520
  const int64_t k;
2521
  const int64_t lda;
2522
  const int64_t ldb;
 
2615
  #elif defined(__MMA__)
2616
  if (k % 8)
2617
  return false;
2618
+ tinyBLAS_PPC tb{
2619
  k, (const float *)A, lda,
2620
  (const float *)B, ldb,
2621
  (float *)C, ldc,
 
2742
  return false;
2743
  if (m < 8 && m != 4)
2744
  return false;
2745
+ tinyBLAS_Q0_PPC<block_q8_0> tb{
2746
  k, (const block_q8_0 *)A, lda,
2747
  (const block_q8_0 *)B, ldb,
2748
  (float *)C, ldc,
 
2779
  return false;
2780
  if (m < 8 && m != 4)
2781
  return false;
2782
+ tinyBLAS_Q0_PPC<block_q4_0> tb{
2783
  k, (const block_q4_0 *)A, lda,
2784
  (const block_q8_0 *)B, ldb,
2785
  (float *)C, ldc,