Files
Data-Structure/Matrix/matrix.h
2025-07-20 21:53:15 +08:00

169 lines
4.2 KiB
C++

#pragma once
#include <iostream>
#include <stdexcept>
#include <utility>
template <class T>
class matrix {
private:
    int rows;    // number of rows
    int cols;    // number of columns
    T* element;  // one linear row-major buffer of rows*cols entries (T*, not T**):
                 // entry (i, j), 1-based, lives at element[(i - 1) * cols + (j - 1)]
public:
    // Builds a rows x cols matrix; with no arguments this is an empty 0x0 matrix.
    // If x is non-null, the first rows*cols values of x are copied in row-major order.
    // Throws std::out_of_range if rows or cols is negative.
    // There is intentionally no separate deleted default constructor: it made
    // `matrix<T> m;` ambiguous with this all-defaulted overload.
    matrix(int rows = 0, int cols = 0, T* x = nullptr);

    // Deep copy. The compiler-generated copy would share `element`, and both
    // destructors would then delete[] the same buffer.
    matrix(const matrix<T>& other) : rows(other.rows), cols(other.cols), element(nullptr) {
        const int sum = rows * cols;
        T* buffer = new T[sum];
        try {
            for (int i = 0; i < sum; i++) buffer[i] = other.element[i];
        } catch (...) {
            delete[] buffer;  // the destructor does not run for a half-built object
            throw;
        }
        element = buffer;
    }

    // Takes over other's buffer; other is left as an empty 0x0 matrix.
    matrix(matrix<T>&& other) noexcept
        : rows(other.rows), cols(other.cols), element(other.element) {
        other.rows = 0;
        other.cols = 0;
        other.element = nullptr;
    }

    ~matrix() { delete[] element; }

    int rowss() const { return rows; }    // number of rows
    int columns() const { return cols; }  // number of columns

    // Dot product of row i of `left` with column j of `right` (1-based indices),
    // i.e. entry (i, j) of left * right. Accumulates and returns T so that
    // floating-point products are not truncated. Does not read *this; it is a
    // multiplication helper and could be made private.
    T times(int i, int j, const matrix<T>& left, const matrix<T>& right) const;

    // 1-based element access; throws std::out_of_range for an index outside the matrix.
    // NOTE(review): declared const yet returns a mutable T&, so entries of a const
    // matrix can be modified through it.
    T& operator()(int i, int j) const;

    // Copy assignment; defined out of class.
    matrix<T>& operator=(const matrix<T>& x);

    // Move assignment: swaps shape and buffer with other, so other ends up owning
    // (and eventually freeing) whatever this matrix held before.
    matrix<T>& operator=(matrix<T>&& other) noexcept {
        std::swap(rows, other.rows);
        std::swap(cols, other.cols);
        std::swap(element, other.element);
        return *this;
    }

    // Element-wise sum; throws std::invalid_argument on a shape mismatch.
    matrix<T> operator+(const matrix<T>& x) const;
    // Element-wise difference; throws std::invalid_argument on a shape mismatch.
    matrix<T> operator-(const matrix<T>& x) const;
    // Matrix product (row of left times column of right); throws
    // std::invalid_argument unless columns() == x.rowss().
    matrix<T> operator*(const matrix<T>& x) const;
    // In-place element-wise sum; throws std::invalid_argument on a shape mismatch.
    matrix<T>& operator+=(const matrix<T>& x);
    // Transposes in place (rows and columns swap) and returns *this.
    matrix<T>& transpose();

    // Stream output, one matrix row per line. It uses its own template parameter
    // name because reusing T would shadow the class template's T, which is ill-formed.
    template <class U>
    friend std::ostream& operator<<(std::ostream& os, const matrix<U>& x);

    // Two matrices are equal iff they have the same shape and equal elements.
    bool operator==(const matrix<T>& other) const {
        if (rows != other.rows || cols != other.cols)
            return false;
        for (int i = 0; i < rows * cols; ++i)
            if (element[i] != other.element[i])
                return false;
        return true;
    }
    bool operator!=(const matrix<T>& other) const {
        return !(*this == other);
    }
};
template <class T>
T matrix<T>::times(int i, int j, const matrix<T>& left, const matrix<T>& right) const {
    T sum = T();  // accumulate in T, not int, so e.g. matrix<double> keeps fractions
    for (int k = 1; k <= left.cols; k++) {
        sum += left(i, k) * right(k, j);
    }
    return sum;
}
// Builds a rows x cols matrix stored row-major in one linear buffer.
// If x is non-null, its first rows*cols values are copied in; otherwise every
// entry is value-initialized (T(), i.e. zero for arithmetic types).
// Throws std::out_of_range if either dimension is negative.
template <class T>
matrix<T>::matrix(int rows, int cols, T* x) : rows(rows), cols(cols), element(nullptr) {
    if (rows < 0 || cols < 0) {
        throw std::out_of_range("Matrix subscript out of range");
    }
    const int sum = rows * cols;
    element = new T[sum]();  // value-initialized, replacing the explicit T() fill
    if (x) {
        try {
            for (int i = 0; i < sum; i++) element[i] = x[i];
        } catch (...) {
            // The destructor never runs for a partially constructed object, so
            // release the buffer here if T's assignment throws.
            delete[] element;
            throw;
        }
    }
}
// 1-based access to entry (i, j). Throws std::out_of_range when i is not in
// [1, rows] or j is not in [1, cols].
template <class T>
T& matrix<T>::operator()(int i, int j) const {
    const bool rowInRange = 1 <= i && i <= rows;
    const bool colInRange = 1 <= j && j <= cols;
    if (!(rowInRange && colInRange)) {
        throw std::out_of_range("Matrix subscript out of range");
    }
    // Row-major layout: skip (i - 1) whole rows of `cols` entries each (the row
    // length is cols, not rows), then step (j - 1) into that row.
    const int offset = (i - 1) * cols + (j - 1);
    return element[offset];
}
// Copy assignment: makes *this an independent copy of x, taking x's shape and
// elements. The new buffer is fully built before the old one is released, so a
// throwing allocation or element copy leaves *this unchanged.
template <class T>
matrix<T>& matrix<T>::operator=(const matrix<T>& x) {
    if (this == &x) {  // identity check, not an O(n) value comparison
        return *this;
    }
    const int sum = x.rows * x.cols;
    T* fresh = new T[sum];
    try {
        for (int i = 0; i < sum; i++) {
            fresh[i] = x.element[i];
        }
    } catch (...) {
        delete[] fresh;
        throw;
    }
    delete[] element;
    element = fresh;
    // The shape must follow the data; otherwise the buffer size and rows/cols disagree.
    rows = x.rows;
    cols = x.cols;
    return *this;
}
// Element-wise sum *this + x. Both operands must share the same shape;
// throws std::invalid_argument otherwise.
template <class T>
matrix<T> matrix<T>::operator+(const matrix<T>& x) const {
    const bool sameShape = (rows == x.rows) && (cols == x.cols);
    if (!sameShape) {
        throw std::invalid_argument("matrix can't add: size mismatch");
    }
    matrix<T> sum(rows, cols);
    const int count = rows * cols;
    for (int k = 0; k != count; ++k) {
        sum.element[k] = element[k] + x.element[k];
    }
    return sum;
}
// Element-wise difference *this - x. Both operands must share the same shape;
// throws std::invalid_argument otherwise.
template <class T>
matrix<T> matrix<T>::operator-(const matrix<T>& x) const {
    if (x.rows != rows || x.cols != cols)
        throw std::invalid_argument("matrix can't minus: size mismatch");

    matrix<T> diff(rows, cols);
    const T* lhs = element;
    const T* rhs = x.element;
    T* out = diff.element;
    for (int remaining = rows * cols; remaining > 0; --remaining) {
        *out++ = *lhs++ - *rhs++;
    }
    return diff;
}
// Matrix product *this * x. Requires columns() == x.rowss(); throws
// std::invalid_argument otherwise. The result is rows x x.cols, and entry (i, j)
// is the dot product of row i of the left operand with column j of the right.
template <class T>
matrix<T> matrix<T>::operator*(const matrix<T>& x) const {
    if (cols != x.rows) {
        throw std::invalid_argument("matrix can't times: size mismatch");
    }
    // Construct directly: the previous scratch `new T[]` buffer was never freed
    // and was copied from before being initialized.
    matrix<T> product(rows, x.cols);
    for (int i = 0; i < rows; i++) {
        for (int j = 0; j < x.cols; j++) {
            T sum = T();  // accumulate in T (not int) so non-integral T is not truncated
            for (int k = 0; k < cols; k++) {
                sum += element[i * cols + k] * x.element[k * x.cols + j];
            }
            product.element[i * x.cols + j] = sum;
        }
    }
    return product;
}
// Adds x to *this element by element and returns *this. Both operands must
// share the same shape; throws std::invalid_argument otherwise.
template <class T>
matrix<T>& matrix<T>::operator+=(const matrix<T>& x) {
    if (!(rows == x.rows && cols == x.cols)) {
        throw std::invalid_argument("matrix += error: size mismatch");
    }
    T* dst = element;
    const T* src = x.element;
    for (int remaining = rows * cols; remaining > 0; --remaining) {
        *dst++ += *src++;
    }
    return *this;
}
// Writes the matrix one row per line, each element followed by a single space.
// Uses '\n' rather than std::endl so the stream is not flushed after every row;
// callers that need a flush can do it once after printing.
template <class T>
std::ostream& operator<<(std::ostream& os, const matrix<T>& x) {
    for (int i = 0; i < x.rows; i++) {
        for (int j = 0; j < x.cols; j++) {
            os << x.element[i * x.cols + j] << ' ';
        }
        os << '\n';
    }
    return os;
}
// Transposes the matrix in place: a rows x cols matrix becomes cols x rows with
// entry (i, j) moved to (j, i). Returns *this. If copying an element throws,
// the new buffer is released and *this is left unchanged.
template <class T>
matrix<T>& matrix<T>::transpose() {
    T* flipped = new T[rows * cols];
    try {
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                // In the cols x rows result each row has `rows` entries, so the
                // 0-based source (i, j) lands at j * rows + i. Indices are in range
                // by construction, so the bounds-checked operator() is not needed.
                flipped[j * rows + i] = element[i * cols + j];
            }
        }
    } catch (...) {
        delete[] flipped;
        throw;
    }
    delete[] element;
    std::swap(rows, cols);
    element = flipped;
    return *this;
}