Files
mysql-scanner/mysql-scanner/main.go
2026-02-26 22:47:32 +00:00

580 lines
13 KiB
Go

package main
import (
"database/sql"
"encoding/json"
"flag"
"fmt"
"log"
"os"
"regexp"
"strings"
_ "github.com/go-sql-driver/mysql"
)
// Config holds scanner configuration. It is assembled by parseFlags from
// command-line flags, several of which default to DB_* environment variables.
type Config struct {
// Host is the database host (flag --host, default from DB_HOST or "localhost").
Host string
// Port is the database port (flag --port, default from DB_PORT or 3306).
Port int
// User is the database user (flag --user, default from DB_USER); main requires it.
User string
// Password is the database password (flag --password, default from DB_PASSWORD).
Password string
// Database is the schema to scan (flag --db, default from DB_NAME); main requires it.
Database string
// Patterns lists the pattern names to run; an empty list means all patterns.
Patterns []string
// JSON selects JSON output instead of the human-readable report.
JSON bool
// Output is the file path for JSON output; empty means stdout.
Output string
}
// Finding represents a detected issue: one regular-expression match inside one
// column value of one row.
type Finding struct {
// Table is the name of the table the value was read from.
Table string `json:"table"`
// Column is the name of the column the value was read from.
Column string `json:"column"`
// RowID is the formatted value of the row's ID column, or "" when none was identified.
RowID string `json:"row_id"`
// RiskLevel is copied from the matching Pattern's Risk ("high", "medium" or "low").
RiskLevel string `json:"risk_level"`
// Pattern is the Name of the Pattern that matched.
Pattern string `json:"pattern"`
// Match is the text matched by the pattern's regular expression.
Match string `json:"match"`
// Snippet is the beginning of the column value, shortened with truncate.
Snippet string `json:"snippet"`
// LineNumbers is left out of the JSON output when empty.
// NOTE(review): nothing in the code shown here fills this field.
LineNumbers []int `json:"line_numbers,omitempty"`
}
// ScanResult holds complete scan results for one database. It is built by
// scanDatabase and rendered by outputHuman or outputJSON.
type ScanResult struct {
// Database is the name of the scanned schema.
Database string `json:"database"`
// Tables counts the tables whose scan completed without error.
Tables int `json:"tables_scanned"`
// Rows is the row counter reported as "rows_scanned".
Rows int `json:"rows_scanned"`
// Findings lists every match that was recorded, in scan order.
Findings []Finding `json:"findings"`
// HighRisk is the number of findings whose RiskLevel is "high".
HighRisk int `json:"high_risk_count"`
// MedRisk is the number of findings whose RiskLevel is "medium".
MedRisk int `json:"medium_risk_count"`
// LowRisk is the number of findings whose RiskLevel is "low".
LowRisk int `json:"low_risk_count"`
}
// Pattern defines a detection pattern: a named regular expression together
// with the risk level assigned to anything it matches.
type Pattern struct {
// Name identifies the pattern; it is reported in findings and is the value
// compared against the names given with --patterns.
Name string
// Risk is the risk level recorded on findings: "high", "medium" or "low".
Risk string
// Regex is the compiled expression run against column values.
Regex *regexp.Regexp
// NOTE(review): Sensitive is neither set nor read in the code shown here;
// its intended meaning should be confirmed with the author.
Sensitive bool
}
// patterns is the built-in catalogue of detection patterns that scanTable runs
// against every scanned column value. It is initialized directly at package
// level (the regular expressions are compiled once at program start) instead
// of being filled in by an init function.
//
// Two matching rules apply on top of the plain expressions:
//   - PHP resolves function names case-insensitively, so every pattern that
//     looks for a PHP function call carries the (?i) flag; otherwise "EVAL("
//     or "Base64_Decode(" would slip through.
//   - script_tag carries the (?s) flag so that "." also matches newlines and a
//     <script> element spanning several lines is detected.
var patterns = []Pattern{
	// PHP dangerous functions - HIGH RISK
	{Name: "eval_function", Risk: "high", Regex: regexp.MustCompile(`(?i)\beval\s*\(`)},
	{Name: "assert_function", Risk: "high", Regex: regexp.MustCompile(`(?i)\bassert\s*\(`)},
	{Name: "create_function", Risk: "high", Regex: regexp.MustCompile(`(?i)\bcreate_function\s*\(`)},
	{Name: "exec_function", Risk: "high", Regex: regexp.MustCompile(`(?i)\b(?:exec|system|passthru|shell_exec)\s*\(`)},

	// Decoding / decompression helpers typical of obfuscated payloads - HIGH RISK
	{Name: "base64_decode", Risk: "high", Regex: regexp.MustCompile(`(?i)\bbase64_decode\s*\(`)},
	{Name: "gzinflate", Risk: "high", Regex: regexp.MustCompile(`(?i)\bgzinflate\s*\(`)},
	{Name: "str_rot13", Risk: "medium", Regex: regexp.MustCompile(`(?i)\bstr_rot13\s*\(`)},

	// Long base64 strings - MEDIUM RISK
	{Name: "long_base64", Risk: "medium", Regex: regexp.MustCompile(`[A-Za-z0-9+/]{200,}={0,2}`)},

	// HTML/JS injection - MEDIUM RISK
	{Name: "script_tag", Risk: "medium", Regex: regexp.MustCompile(`(?is)<script[^>]*>.*?</script>`)},
	{Name: "iframe_tag", Risk: "medium", Regex: regexp.MustCompile(`(?i)<iframe[^>]*(?:hidden|width=0|height=0)[^>]*>`)},
	{Name: "javascript_protocol", Risk: "medium", Regex: regexp.MustCompile(`(?i)javascript:\s*[^\s"'>]+`)},
	{Name: "event_handler", Risk: "medium", Regex: regexp.MustCompile(`(?i)\b(?:onerror|onload|onclick|onmouseover)\s*=`)},

	// Web shell patterns - HIGH RISK
	// PHP variable names are case-sensitive, so the superglobals stay case-sensitive.
	{Name: "superglobal_eval", Risk: "high", Regex: regexp.MustCompile(`(?:\$_(?:GET|POST|REQUEST|COOKIE))\[.*\]\s*\(\s*\$`)},
	{Name: "error_reporting_suppress", Risk: "medium", Regex: regexp.MustCompile(`(?i)error_reporting\s*\(\s*0\s*\)`)},

	// Variable function calls - MEDIUM RISK
	{Name: "variable_function", Risk: "medium", Regex: regexp.MustCompile(`\$[a-zA-Z_]\w*\s*\(`)},

	// Common malware comment signatures - LOW RISK
	// Deliberately limited to a single line: with (?s) the greedy ".*" would
	// stretch from the first "/*" to the last "*/" of the whole value.
	{Name: "malware_comment", Risk: "low", Regex: regexp.MustCompile(`(?i)/\*.*(?:shell|backdoor|webshell|hack|c99|r57).*\*/`)},

	// Obfuscation indicators - MEDIUM RISK
	{Name: "hex_encode", Risk: "medium", Regex: regexp.MustCompile(`\\x[0-9a-fA-F]{2}`)},
}
// main is the command-line entry point: it reads the configuration, connects
// to the database, runs the scan and prints the result either as JSON or as a
// human-readable report.
func main() {
	config := parseFlags()
	if config.User == "" || config.Database == "" {
		log.Fatal("Error: --user and --db are required")
	}

	// Connect to database
	db, err := connectDB(config)
	if err != nil {
		log.Fatalf("Failed to connect to database: %v", err)
	}

	// Status messages go to stderr so that JSON written to stdout stays parseable.
	fmt.Fprintf(os.Stderr, "Connected to %s@%s:%d/%s\n", config.User, config.Host, config.Port, config.Database)

	// Scan database
	result, err := scanDatabase(db, config)

	// The connection is closed explicitly instead of with defer: log.Fatalf
	// exits the process via os.Exit, which does not run deferred calls.
	if closeErr := db.Close(); closeErr != nil {
		log.Printf("Warning: failed to close database connection: %v", closeErr)
	}
	if err != nil {
		log.Fatalf("Scan failed: %v", err)
	}

	// Output results
	if config.JSON {
		outputJSON(result, config.Output)
		return
	}
	outputHuman(result)
}
// parseFlags defines and parses the command-line flags and returns the
// resulting Config. Connection flags default to the DB_HOST, DB_PORT, DB_USER,
// DB_PASSWORD and DB_NAME environment variables.
//
// The --patterns value is split on commas; surrounding whitespace is trimmed
// and empty entries are dropped. Without that, input such as "," or a trailing
// comma would produce a non-empty filter that matches no pattern and silently
// disables the whole scan. An empty list means "run all patterns".
func parseFlags() Config {
	host := flag.String("host", getEnv("DB_HOST", "localhost"), "Database host")
	port := flag.Int("port", getEnvInt("DB_PORT", 3306), "Database port")
	user := flag.String("user", getEnv("DB_USER", ""), "Database user")
	password := flag.String("password", getEnv("DB_PASSWORD", ""), "Database password")
	db := flag.String("db", getEnv("DB_NAME", ""), "Database name")
	// Named patternArg so the package-level patterns table is not shadowed.
	patternArg := flag.String("patterns", "", "Comma-separated patterns to scan (default: all)")
	jsonOut := flag.Bool("json", false, "Output JSON instead of human-readable")
	output := flag.String("output", "", "Output file path (default: stdout)")
	flag.Parse()

	var patternList []string
	for _, name := range strings.Split(*patternArg, ",") {
		if name = strings.TrimSpace(name); name != "" {
			patternList = append(patternList, name)
		}
	}

	return Config{
		Host:     *host,
		Port:     *port,
		User:     *user,
		Password: *password,
		Database: *db,
		Patterns: patternList,
		JSON:     *jsonOut,
		Output:   *output,
	}
}
// connectDB opens a MySQL connection pool for the server and schema described
// by config and verifies with Ping that the server is actually reachable.
//
// It returns the ready-to-use handle, or an error from sql.Open or Ping. When
// Ping fails the handle is closed before returning, so a failed connection
// attempt does not leak the pool.
func connectDB(config Config) (*sql.DB, error) {
	dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
		config.User,
		config.Password,
		config.Host,
		config.Port,
		config.Database,
	)

	db, err := sql.Open("mysql", dsn)
	if err != nil {
		return nil, err
	}

	if err := db.Ping(); err != nil {
		// The ping error is what the caller needs to see; a secondary error
		// from Close would add nothing, so it is deliberately ignored.
		_ = db.Close()
		return nil, err
	}
	return db, nil
}
// scanDatabase scans every base table of config.Database and collects the
// findings into a ScanResult.
//
// Tables without text columns are skipped. A table whose columns cannot be
// listed or whose scan fails is reported as a warning and skipped; it does not
// abort the run. Tables counts the tables scanned successfully and Rows counts
// the rows actually read from them. Progress messages are written to stderr so
// that JSON printed to stdout is not mixed with them.
//
// An error is returned only when the list of tables cannot be retrieved.
func scanDatabase(db *sql.DB, config Config) (*ScanResult, error) {
	result := &ScanResult{
		Database: config.Database,
		// Non-nil so that an empty result is encoded as [] rather than null.
		Findings: []Finding{},
	}

	// Get all tables
	tables, err := getTables(db, config.Database)
	if err != nil {
		return nil, fmt.Errorf("failed to get tables: %w", err)
	}
	fmt.Fprintf(os.Stderr, "Found %d tables\n", len(tables))

	for _, table := range tables {
		fmt.Fprintf(os.Stderr, "Scanning table: %s\n", table)

		columns, err := getTextColumns(db, config.Database, table)
		if err != nil {
			log.Printf("Warning: failed to get columns for %s: %v", table, err)
			continue
		}
		if len(columns) == 0 {
			fmt.Fprintf(os.Stderr, " No text columns, skipping\n")
			continue
		}
		fmt.Fprintf(os.Stderr, " Scanning %d text columns: %v\n", len(columns), columns)

		findings, rowCount, err := scanTableRows(db, table, columns, config.Patterns)
		if err != nil {
			log.Printf("Warning: failed to scan table %s: %v", table, err)
			continue
		}
		result.Findings = append(result.Findings, findings...)
		result.Tables++
		result.Rows += rowCount
	}

	// Count risk levels
	for _, f := range result.Findings {
		switch f.RiskLevel {
		case "high":
			result.HighRisk++
		case "medium":
			result.MedRisk++
		case "low":
			result.LowRisk++
		}
	}
	return result, nil
}

// getTables returns the names of all base tables (views are excluded) in the
// schema dbName, as listed by information_schema.tables.
func getTables(db *sql.DB, dbName string) ([]string, error) {
	query := `SELECT table_name FROM information_schema.tables
		WHERE table_schema = ? AND table_type = 'BASE TABLE'`
	rows, err := db.Query(query, dbName)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var tables []string
	for rows.Next() {
		var table string
		if err := rows.Scan(&table); err != nil {
			return nil, err
		}
		tables = append(tables, table)
	}
	// Next also returns false when iteration broke off with an error, so the
	// list is only complete if Err reports nil.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return tables, nil
}

// getTextColumns returns the names of the character columns (char, varchar and
// the four text types) of the given table, in their ordinal position.
func getTextColumns(db *sql.DB, dbName, table string) ([]string, error) {
	query := `SELECT column_name FROM information_schema.columns
		WHERE table_schema = ? AND table_name = ?
		AND data_type IN ('tinytext', 'text', 'mediumtext', 'longtext', 'varchar', 'char')
		ORDER BY ordinal_position`
	rows, err := db.Query(query, dbName, table)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var columns []string
	for rows.Next() {
		var col string
		if err := rows.Scan(&col); err != nil {
			return nil, err
		}
		columns = append(columns, col)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return columns, nil
}

// scanTable runs the detection patterns over every row of table and returns
// the findings. activePatterns restricts the scan to the named patterns; an
// empty list runs all of them. It is a thin wrapper around scanTableRows for
// callers that do not need the row count.
func scanTable(db *sql.DB, table string, columns, activePatterns []string) ([]Finding, error) {
	findings, _, err := scanTableRows(db, table, columns, activePatterns)
	return findings, err
}

// scanTableRows does the work of scanTable and additionally returns the number
// of rows that were read successfully.
//
// The query built by buildSelectQuery selects all columns, so every non-NULL
// column value except the row's ID column is checked. At most 10 matches per
// pattern and value are recorded. A row that cannot be read is logged and
// skipped; an error that ends the iteration early is returned.
func scanTableRows(db *sql.DB, table string, columns, activePatterns []string) ([]Finding, int, error) {
	query, idColumn := buildSelectQuery(table, columns)
	rows, err := db.Query(query)
	if err != nil {
		return nil, 0, err
	}
	defer rows.Close()

	colNames, err := rows.Columns()
	if err != nil {
		return nil, 0, err
	}
	idIndex := findIDColumn(colNames, idColumn)

	// Resolve the pattern filter once instead of once per cell.
	active := make([]Pattern, 0, len(patterns))
	for _, p := range patterns {
		if len(activePatterns) == 0 || contains(activePatterns, p.Name) {
			active = append(active, p)
		}
	}

	// Scan overwrites every destination on each call, so one pair of slices
	// can be reused for all rows.
	values := make([]interface{}, len(colNames))
	valuePtrs := make([]interface{}, len(colNames))
	for i := range values {
		valuePtrs[i] = &values[i]
	}

	var findings []Finding
	rowCount := 0
	for rows.Next() {
		if err := rows.Scan(valuePtrs...); err != nil {
			log.Printf("Warning: skipping unreadable row in %s: %v", table, err)
			continue
		}
		rowCount++

		// Get row identifier
		rowID := ""
		if idIndex >= 0 && values[idIndex] != nil {
			rowID = cellString(values[idIndex])
		}

		for i, name := range colNames {
			if i == idIndex || values[i] == nil {
				continue
			}
			content := cellString(values[i])
			if content == "" {
				continue
			}
			for _, p := range active {
				// FindAllString returns nil when nothing matches, so no
				// separate MatchString pass is needed.
				for _, match := range p.Regex.FindAllString(content, 10) {
					findings = append(findings, Finding{
						Table:     table,
						Column:    name,
						RowID:     rowID,
						RiskLevel: p.Risk,
						Pattern:   p.Name,
						Match:     match,
						Snippet:   truncate(content, 200),
					})
				}
			}
		}
	}
	if err := rows.Err(); err != nil {
		return nil, 0, err
	}
	return findings, rowCount, nil
}

// cellString converts a value scanned into an interface{} to its text form.
// Values scanned that way are passed through from the driver unconverted, and
// the MySQL driver delivers text as []byte; formatting a []byte with %v would
// yield a list of numbers such as "[60 115 99]" instead of the stored text.
func cellString(v interface{}) string {
	switch val := v.(type) {
	case []byte:
		return string(val)
	case string:
		return val
	default:
		return fmt.Sprintf("%v", val)
	}
}

// findIDColumn returns the index in colNames of the column used to identify a
// row, or -1 when there is none. The preferred name (if any) is tried first,
// then a list of common key names; names are compared case-insensitively. The
// lookup uses the columns of the result set itself because key columns are
// usually numeric and therefore missing from the list of text columns.
func findIDColumn(colNames []string, preferred string) int {
	candidates := []string{"id", "post_id", "comment_id", "user_id", "option_id", "term_id"}
	if preferred != "" {
		candidates = append([]string{preferred}, candidates...)
	}
	for _, candidate := range candidates {
		for i, name := range colNames {
			if strings.EqualFold(name, candidate) {
				return i
			}
		}
	}
	return -1
}
// buildSelectQuery returns the SELECT statement used to read all rows of table
// and the name of the column that identifies a row ("" if none is recognised).
//
// The statement selects every column. The table name is quoted as a MySQL
// identifier (backticks, with embedded backticks doubled) so that reserved
// words and unusual characters in table names cannot break or alter the
// statement.
//
// The ID column is the first entry of a list of common key names that occurs
// in columns, compared case-insensitively. The name is returned exactly as it
// is spelled in columns so that callers can compare it with real column names.
func buildSelectQuery(table string, columns []string) (string, string) {
	// Candidates in order of preference.
	idCandidates := []string{"id", "ID", "post_id", "comment_id", "user_id", "option_id", "term_id"}

	idColumn := ""
search:
	for _, candidate := range idCandidates {
		for _, column := range columns {
			if strings.EqualFold(column, candidate) {
				idColumn = column
				break search
			}
		}
	}

	quotedTable := "`" + strings.ReplaceAll(table, "`", "``") + "`"
	return "SELECT * FROM " + quotedTable, idColumn
}
// outputHuman prints the scan result to stdout as a plain-text report: a
// summary header followed by the findings grouped by risk level, highest risk
// first. Each group shows only its first few findings (5 high, 3 medium,
// 2 low); findings with any other risk level are not listed.
func outputHuman(result *ScanResult) {
	total := len(result.Findings)

	fmt.Printf("\n=== SCAN RESULTS ===\n")
	fmt.Printf("Database: %s\n", result.Database)
	fmt.Printf("Tables scanned: %d\n", result.Tables)
	fmt.Printf("Findings: %d total (High: %d, Medium: %d, Low: %d)\n\n",
		total, result.HighRisk, result.MedRisk, result.LowRisk)

	if total == 0 {
		fmt.Println("✓ No suspicious patterns detected")
		return
	}

	// Bucket the findings by risk level, keeping their original order.
	var high, medium, low []Finding
	for _, finding := range result.Findings {
		if finding.RiskLevel == "high" {
			high = append(high, finding)
		} else if finding.RiskLevel == "medium" {
			medium = append(medium, finding)
		} else if finding.RiskLevel == "low" {
			low = append(low, finding)
		}
	}

	// Sections in print order. Every heading except the first starts with a
	// blank line to separate it from the section before.
	sections := []struct {
		heading string
		items   []Finding
		limit   int
	}{
		{"🔴 HIGH RISK", high, 5},
		{"\n🟡 MEDIUM RISK", medium, 3},
		{"\n🟢 LOW RISK", low, 2},
	}
	for _, section := range sections {
		if len(section.items) == 0 {
			continue
		}
		fmt.Printf("%s (%d)\n", section.heading, len(section.items))
		printFindings(section.items, section.limit)
	}
}
// printFindings prints up to maxDisplay findings to stdout, each with its
// location, pattern name, match (shortened to 100 bytes) and snippet. When the
// list is longer than maxDisplay, a final line states how many were left out.
func printFindings(findings []Finding, maxDisplay int) {
	for idx, finding := range findings {
		if idx >= maxDisplay {
			break
		}

		location := fmt.Sprintf("\n [%s] %s.%s", finding.RiskLevel, finding.Table, finding.Column)
		if finding.RowID != "" {
			location += fmt.Sprintf(" (ID: %s)", finding.RowID)
		}
		fmt.Print(location)

		fmt.Printf("\n Pattern: %s\n", finding.Pattern)
		fmt.Printf(" Match: %s\n", truncate(finding.Match, 100))
		fmt.Printf(" Snippet: %s\n", finding.Snippet)
	}

	if hidden := len(findings) - maxDisplay; hidden > 0 {
		fmt.Printf("\n ... and %d more\n", hidden)
	}
}
// outputJSON encodes the scan result as indented JSON. With an empty
// outputPath the JSON is printed to stdout; otherwise it is written to that
// file (mode 0644) and a confirmation line is printed. Encoding or write
// failures terminate the program through log.Fatalf.
func outputJSON(result *ScanResult, outputPath string) {
	encoded, err := json.MarshalIndent(result, "", " ")
	if err != nil {
		log.Fatalf("Failed to marshal JSON: %v", err)
	}

	// No file requested: the document goes to stdout.
	if outputPath == "" {
		fmt.Println(string(encoded))
		return
	}

	err = os.WriteFile(outputPath, encoded, 0644)
	if err != nil {
		log.Fatalf("Failed to write output file: %v", err)
	}
	fmt.Printf("Results written to %s\n", outputPath)
}
// Utility functions
// getEnv returns the value of the environment variable key, or fallback when
// the variable is unset or set to the empty string.
func getEnv(key, fallback string) string {
	value, found := os.LookupEnv(key)
	if !found || value == "" {
		return fallback
	}
	return value
}
// getEnvInt returns the environment variable key parsed as a decimal integer.
// It returns fallback when the variable is unset or empty, and also when its
// value does not start with an integer; in that case a warning is logged so
// that a typo such as DB_PORT=abc does not silently turn into 0.
func getEnvInt(key string, fallback int) int {
	value := os.Getenv(key)
	if value == "" {
		return fallback
	}

	var parsed int
	if _, err := fmt.Sscanf(value, "%d", &parsed); err != nil {
		log.Printf("Warning: ignoring invalid integer %q in %s, using %d", value, key, fallback)
		return fallback
	}
	return parsed
}
// contains reports whether item occurs in slice, comparing the strings
// case-insensitively (Unicode case folding).
func contains(slice []string, item string) bool {
	found := false
	for i := 0; i < len(slice) && !found; i++ {
		found = strings.EqualFold(slice[i], item)
	}
	return found
}
// truncate shortens s to at most maxLen bytes and appends "..." when anything
// was cut off; a string that already fits is returned unchanged.
//
// The cut is moved back to the nearest character boundary, so a multi-byte
// UTF-8 character is never split (which would leave invalid UTF-8 in reports
// and JSON output). A negative maxLen is treated like 0.
func truncate(s string, maxLen int) string {
	if len(s) <= maxLen {
		return s
	}

	// Ranging over a string yields the byte offset at which each character
	// starts; keep the last such offset that does not exceed maxLen.
	cut := 0
	for offset := range s {
		if offset > maxLen {
			break
		}
		cut = offset
	}
	return s[:cut] + "..."
}
// min returns the smaller of the two integers a and b.
func min(a, b int) int {
	smaller := b
	if a < b {
		smaller = a
	}
	return smaller
}